import csv, os, time,shutil
from workflow.validate import (
    validate_config_options, 
    validate_args, 
    validate_additional_args, 
    validate_env, 
    validate_manifest_files
)
from workflow.log import (
    logcolor,
    build_logger,
    human_readable_time
)
from modules import (
    AdapterRemovalRunner, 
    Bowtie2Runner, 
    FastQCRunner, 
    SamToolsRunner, 
    PicardRunner,
    SplitByRefRunner,
    MergeByHorizonRunner,
    FilterPresenceRunner,
    CleanSQRunner,
    MapDamageRunner
)

class Runner:
    """
    Drive a workflow described by a configparser-style config object.

    The runner validates the config, reads the GENERAL section and the
    sample manifest, instantiates one step object per section listed in
    ``GENERAL.order`` (looked up through ``software_dictionary``), validates
    those steps, and then either performs a dry run or executes every step
    in order.
    """

    # Maps the ``software`` value of a config section to the class that
    # implements that step.
    software_dictionary = {
        "bowtie2": Bowtie2Runner,
        "fastqc": FastQCRunner,
        "picard": PicardRunner,
        "mapdamage": MapDamageRunner,
        "adapterremoval": AdapterRemovalRunner,
        "samtools": SamToolsRunner,
        "splitbyref": SplitByRefRunner,
        "mergebyhorizon": MergeByHorizonRunner,
        "filterpresence": FilterPresenceRunner,
        "cleansq": CleanSQRunner
    }

    def __init__(self, config, config_path):
        """
        Args:
            config: parsed config, indexed as ``config[section][key]``.
            config_path: path of the config file; printed after a dry run and
                copied into the output directory after a full run.
        """
        # Placeholders, populated by read_config() / setup_software_steps()
        self.software_steps = []
        self.input_fastq_path = ""
        self.output_path = ""
        self.threads = 2  # overwritten from GENERAL.threads in read_config()
        self.config = config
        self.config_path = config_path
        self.dry_run = False
        self.manifest_path = ""
        self.manifest_list = []

        # One log file per invocation, named after the start time.
        # NOTE(review): ':' is not a valid filename character on Windows.
        timestamp = time.strftime("%Y-%m-%d_%H:%M:%S")
        self.logpath = f"workflow_run_{timestamp}.log"
        self.logger = build_logger(self.logpath)

    def main(self):
        """Validate and load the workflow, then dry-run or fully run it."""
        self.logger.info(f"{logcolor.BOLD}=== VALIDATE & LOAD ==={logcolor.ENDC}")
        # Config options, step args and manifest files.
        # NOTE(review): assumes the validate_* helpers abort on failure —
        # confirm in workflow.validate.
        self.validate_all()

        # Pull GENERAL settings and the manifest rows into the runner.
        self.read_config()

        # Build one step object per section in GENERAL.order.
        self.setup_software_steps()

        # Per-step additional-arg checks plus a software environment check.
        self.validate_steps()

        if self.dry_run:
            self.dryrun()
        else:
            self.fullrun()

    def dryrun(self):
        """Call each step's dry_run() in order without executing anything."""
        self.logger.info(f"\n{logcolor.BOLD}=== DRY RUN ==={logcolor.ENDC}")
        for step in self.software_steps:
            step.dry_run()

        print(self.config_path)

    def fullrun(self):
        """
        Execute every step in order, chaining each step's file list into the
        next, then log a timing summary and copy the log, config and
        manifest into the output directory.
        """
        start_time_all = time.time()

        self.logger.info(f"\n{logcolor.BOLD}=== FULL RUN ==={logcolor.ENDC}")
        # Seed the chain with the raw FASTQ files listed in the manifest.
        step_n_files = self.convert_manifest()
        for step in self.software_steps:
            step_n_files = step.find_input_files(step_n_files)
            step.build_command_queue()
            step.create_output_dir()
            run_output = step.run()

            # Steps that have to rediscover their output return it; all
            # other steps keep the list produced by find_input_files().
            if run_output is not None:
                step_n_files = run_output

        end_time_all = time.time()

        # Summarize run times
        self.logger.info(f"\n{logcolor.BOLD}Run complete!{logcolor.ENDC}")
        self.logger.info(f"\n{logcolor.BOLD}=== Summary ==={logcolor.ENDC}")
        total_time = human_readable_time(end_time_all - start_time_all)
        self.logger.info(f"{logcolor.INFO}Total time elapsed: {total_time}{logcolor.ENDC}\n")
        # Header only — the per-step times follow on their own lines
        # (previously this line wrongly repeated the total time).
        self.logger.info(f"{logcolor.INFO}Individual step times:{logcolor.ENDC}")
        for step in self.software_steps:
            step_time = human_readable_time(step.end_time - step.start_time)
            self.logger.info(f"{logcolor.INFO} - {step.stepname} time: {step_time}{logcolor.ENDC}")

        # Keep a record of how this run was produced next to its results.
        for path in (self.logpath, self.config_path, self.manifest_path):
            shutil.copy(path, self.output_path)

    def convert_manifest(self):
        """
        Turn manifest rows into the file records handed to the first step.

        Each record has ``id``, ``input1``/``input2`` (always None here) and
        ``output1``/``output2`` pointing at the forward/reverse FASTQ files
        under ``input_fastq_path``. ``output2`` is None when the row has no
        (or an empty) ``reverse_file``.
        """
        # The raw files are recorded as "outputs" — presumably so the first
        # step's find_input_files() can treat them like a previous step's
        # results; confirm against the step implementations.
        return [
            {
                "id": row["id"],
                "input1": None,
                "input2": None,
                "output1": os.path.join(self.input_fastq_path, row["forward_file"]),
                "output2": (
                    os.path.join(self.input_fastq_path, row["reverse_file"])
                    if row.get("reverse_file")
                    else None
                ),
            }
            for row in self.manifest_list
        ]

    def validate_all(self):
        """Validate config options, user-supplied step args and manifest files."""
        # Config file options
        validate_config_options(config=self.config, logger=self.logger)
        # User input from the config file, checked against known software
        validate_args(config=self.config, software_dictionary=self.software_dictionary, logger=self.logger)
        # Manifest files
        validate_manifest_files(config=self.config, logger=self.logger)

    def validate_steps(self):
        """Validate each built step's additional args, then the software environment."""
        for step in self.software_steps:
            validate_additional_args(
                stepname=step.stepname,
                unneeded_args=step.unneeded_args,
                additional_step_args=step.additional_args,
                logger=self.logger
            )
        validate_env(software_steps=self.software_steps, logger=self.logger)

    def read_config(self):
        """
        Load GENERAL settings and the sample manifest into the runner.

        Expects validate_all() to have run first, so the keys are present.
        Manifest rows are appended to ``manifest_list`` as dicts keyed by the
        CSV header.
        """
        general = self.config["GENERAL"]
        # Values may be wrapped in double quotes in the config file.
        self.input_fastq_path = general["input_fastq_path"].replace('"', "")
        self.threads = int(general["threads"])
        self.output_path = general["output_path"].replace('"', "")
        self.manifest_path = general["input_manifest"].replace('"', "")
        # Case-insensitive, quote-tolerant: accepts everything the old
        # check did ("True", "true", True) plus e.g. "TRUE" or '"true"'.
        if str(general["dry_run"]).replace('"', "").strip().lower() == "true":
            self.dry_run = True

        # newline="" is what the csv module expects for files it reads.
        with open(self.manifest_path, "r", newline="") as data:
            self.manifest_list.extend(csv.DictReader(data))

    def setup_software_steps(self):
        """
        Instantiate one step object per section listed in GENERAL.order.

        Steps are numbered from 1 and write to ``<output_path>/<NN>_<section>``.
        Each step reads from the output directory of the most recent non-leaf
        step (the raw FASTQ directory for the first step), so a leaf step's
        output is never used as the next step's input.
        """
        # Input for the first step: the raw FASTQ directory.
        prev_step_output = self.input_fastq_path

        # split() (no argument) tolerates repeated or trailing whitespace.
        order = self.config["GENERAL"]["order"].split()
        for step_num, section in enumerate(order, start=1):
            section_conf = self.config[section]
            # Zero-padded so output directories sort in run order.
            section_output = os.path.join(self.output_path, f"{step_num:02d}_{section}")

            step = self.software_dictionary[section_conf["software"]](
                logger=self.logger,
                manifest_list=self.manifest_list,
                input_path=prev_step_output,
                output_path=section_output,
                additional_args=section_conf["args"].replace('"', ""),
                threads=self.threads,
                stepname=section,
                mode=section_conf.get("mode"),
            )

            # Non-leaf steps feed the next step; after a leaf step the next
            # step keeps reading the previous non-leaf output.
            if not step.leaf:
                prev_step_output = section_output
            self.software_steps.append(step)
