import os
import json
import argparse

def create_experiment_tree(experiment_name, experiment_path, data_paths, hyperparameters):
    """
    Create the experiment folder tree structure for µPIX and save the hyperparameters.

    Resulting layout::

        <experiment_path>/<experiment_name>/
            hyperparameters.json   # hyperparameters plus a 'data_paths' entry
            results/
                networks/
                images/
                log.txt            # created empty if missing, never truncated
            predictions/

    Existing directories are reused. An existing ``hyperparameters.json`` is
    overwritten.

    Args:
        experiment_name (str): Name of the experiment.
        experiment_path (str): Path where the experiment folder will be created.
        data_paths (dict): Dictionary containing paths to clean, noisy, and test images.
        hyperparameters (dict): Dictionary containing training hyperparameters.
            It is not modified; a copy extended with ``data_paths`` is written.

    Returns:
        str: Path of the created experiment directory.
    """
    # Create the base experiment directory
    experiment_dir = os.path.join(experiment_path, experiment_name)
    os.makedirs(experiment_dir, exist_ok=True)

    # Build the saved record from a copy so the caller's dict is left untouched.
    # An existing 'data_paths' key is replaced by the argument, as before.
    saved_hyperparameters = {**hyperparameters, 'data_paths': data_paths}

    # Write hyperparameters to hyperparameters.json
    hyperparameters_file = os.path.join(experiment_dir, 'hyperparameters.json')
    with open(hyperparameters_file, 'w', encoding='utf-8') as f:
        json.dump(saved_hyperparameters, f, indent=4)

    # Create the results folder with its subfolders
    results_dir = os.path.join(experiment_dir, 'results')
    for subfolder in ('networks', 'images'):
        os.makedirs(os.path.join(results_dir, subfolder), exist_ok=True)

    # Ensure log.txt exists. Append mode keeps any existing log content
    # when the experiment folder is reused.
    log_file = os.path.join(results_dir, 'log.txt')
    with open(log_file, 'a', encoding='utf-8'):
        pass

    # Create the predictions folder
    predictions_dir = os.path.join(experiment_dir, 'predictions')
    os.makedirs(predictions_dir, exist_ok=True)

    print(f"Experiment '{experiment_name}' created successfully at {experiment_dir}")
    return experiment_dir


def main(argv=None):
    """Parse command-line arguments and create a µPIX experiment folder.

    Every training hyperparameter can be overridden with a ``--<name>`` flag.
    The defaults are the values this script has always written.

    Args:
        argv (list[str] | None): Arguments to parse. ``None`` (the default)
            makes argparse read ``sys.argv[1:]``, as when run as a script.
    """
    # Set up the argument parser
    parser = argparse.ArgumentParser(description='Create experiment folder for µPIX and save hyperparameters.')

    parser.add_argument('--experiment_name', type=str, required=True, help='Name of the experiment')
    parser.add_argument('--experiment_path', type=str, required=True, help='Directory where the experiment will be saved')
    parser.add_argument('--clean_data_path', type=str, required=True, help='Path to clean images')
    parser.add_argument('--noisy_data_path', type=str, required=True, help='Path to noisy images')
    parser.add_argument('--test_data_path', type=str, default=None, help='Path to test images')

    # (name, type, default) for each hyperparameter written to hyperparameters.json.
    # The tuple order is the key order in the saved file.
    hyperparameter_specs = (
        ('learning_rate_generator', float, 1e-4),
        ('learning_rate_discriminator', float, 1e-4),
        ('batch_size', int, 16),
        ('num_epochs', int, 100),
        ('loss_weight', float, 10),  # argparse does not convert defaults, so this stays int 10 unless overridden
        ('tile_size', int, 256),
        ('patience', int, 20),
        ('valid_size', float, 0.1),
        ('seed', int, 42),
    )
    for name, value_type, default in hyperparameter_specs:
        parser.add_argument(
            f'--{name}',
            type=value_type,
            default=default,
            help=f'Value saved as "{name}" in hyperparameters.json (default: %(default)s)',
        )

    args = parser.parse_args(argv)

    # Organize the data paths into a dictionary
    data_paths = {
        'clean': args.clean_data_path,
        'noisy': args.noisy_data_path,
        'test': args.test_data_path,
    }

    # Collect the (possibly overridden) hyperparameters in declaration order
    hyperparameters = {name: getattr(args, name) for name, _, _ in hyperparameter_specs}

    # Create the experiment tree and save the hyperparameters
    create_experiment_tree(args.experiment_name, args.experiment_path, data_paths, hyperparameters)


if __name__ == '__main__':
    main()