#!/usr/bin/env python3

# originally written by Jingwei Dai, Spring 2024
 
# fixme:
# add support for threads in config

import argparse
from enum import verify
import os
import re
import subprocess
import json

from typing import List
from datetime import datetime
from file_utils import *


class WGSParameters:
    """Settings for one WGS run: locations, naming rules, fastp/kraken2 options.

    ``source_directory`` and ``working_directory`` are stored as absolute
    paths. ``log_filename`` is stored already joined onto the working
    directory, so it is an absolute path as well.
    """

    # Keys a config file is allowed to set. An explicit whitelist (rather than
    # ``hasattr``) keeps a stray key such as "verify" from replacing a method.
    _CONFIG_KEYS = (
        "source_directory",
        "working_directory",
        "log_filename",
        "token_count",
        "token_separator",
        "debug",
        "dry_run",
        "low_complexity_filter",
        "qualified_quality_phred",
        "unqualified_percent_limit",
        "database",
        "confidence",
        "skip_kraken",
        "threads",
    )
    # Config keys that are normalized to absolute paths, as the constructor does.
    _PATH_KEYS = ("source_directory", "working_directory")

    def __init__(
        self,
        source_directory: str = "",
        working_directory: str = "",
        log_filename: str = "logs.txt",
        token_count: int = 2,
        token_separator: str = "_",
        debug: bool = False,
        dry_run: bool = False,
        low_complexity_filter: int = 30,
        qualified_quality_phred: int = 15,
        unqualified_percent_limit: int = 40,
        database: str = "/mounts/tmpfs/kraken2-nt-local",
        confidence: float = 0.2,
        skip_kraken: bool = False,
        threads: int = 8,
    ):
        """Store the run settings.

        Args:
            source_directory: Directory scanned for FASTQ inputs.
            working_directory: Directory that receives the outputs.
            log_filename: Log file name, resolved against working_directory.
            token_count: Number of filename tokens kept in output names.
            token_separator: Separator between filename tokens.
            debug: Log the resolved inputs and commands.
            dry_run: Log the commands without executing them.
            low_complexity_filter: fastp ``--low_complexity_filter`` value.
            qualified_quality_phred: fastp ``--qualified_quality_phred`` value.
            unqualified_percent_limit: fastp ``--unqualified_percent_limit`` value.
            database: kraken2 ``--db`` value.
            confidence: kraken2 ``--confidence`` value.
            skip_kraken: Run fastp only.
            threads: kraken2 ``--threads`` value.
        """
        self.source_directory = os.path.abspath(source_directory)
        self.working_directory = os.path.abspath(working_directory)
        # Remember the name as given so the log can follow the working
        # directory if a config file changes it later.
        self._log_name = log_filename
        self.log_filename = os.path.join(self.working_directory, log_filename)
        self.token_count = token_count
        self.token_separator = token_separator
        self.debug = debug
        self.dry_run = dry_run
        self.low_complexity_filter = low_complexity_filter
        self.qualified_quality_phred = qualified_quality_phred
        self.unqualified_percent_limit = unqualified_percent_limit
        self.database = database
        self.confidence = confidence
        self.skip_kraken = skip_kraken
        self.threads = threads

    def verify(self):
        """Validate the working, source and database directories.

        Delegates to the project's ``validate_directories`` helper with
        ``should_create=False``; how a missing directory is reported is up to
        that helper.
        """
        # could auto create, but for testing purposes it's set to False for now
        validate_directories([self.working_directory], should_create=False)
        validate_directories([self.source_directory, self.database], should_create=False)

    def _apply_config(self, config_data):
        """Copy recognised settings from a decoded config mapping onto self.

        Only ``None`` and ``""`` count as "not provided"; legitimate falsy
        values such as ``0``, ``0.0`` and ``False`` are applied.

        Raises:
            TypeError: if config_data is neither ``None`` nor a dict.
        """
        if config_data is None:
            return
        if not isinstance(config_data, dict):
            raise TypeError("config file must contain a JSON object of parameters")

        # Only move the log file if it still sits at its constructor-derived
        # location (i.e. nobody assigned a custom path in the meantime).
        reanchor_log = self.log_filename == os.path.join(
            self.working_directory, self._log_name
        )

        for key in self._CONFIG_KEYS:
            value = config_data.get(key)
            if value is None or value == "":
                continue
            if key == "log_filename":
                # Resolved below, once the final working directory is known.
                self._log_name = value
                reanchor_log = True
                continue
            if key in self._PATH_KEYS:
                value = os.path.abspath(value)
            setattr(self, key, value)

        if reanchor_log:
            self.log_filename = os.path.join(self.working_directory, self._log_name)

    def read_from(self, config_file):
        """Load settings from a JSON config file, then verify them.

        Args:
            config_file: Path to a JSON file holding an object whose keys are
                parameter names.

        Raises:
            FileNotFoundError: if config_file does not exist.
            json.JSONDecodeError: if the file is not valid JSON.
            TypeError: if the JSON document is not an object.
        """
        # could throw json.JSONDecodeError or FileNotFoundError
        # allow crash because the program should not proceed without params
        with open(config_file, "r") as file:
            config_data = json.load(file)

        self._apply_config(config_data)
        self.verify()


class WGSRunner:
    """Builds and runs the fastp -> kraken2 command sequence for each sample.

    R1 FASTQ files are discovered in the source directory; the matching R2
    file is assumed to have the same name with "R1" replaced by "R2". When the
    source directory holds both "L001" and "L002" files, the two lanes are
    concatenated into the working directory first.
    """

    def __init__(self, params: "WGSParameters"):
        """Copy the settings from *params* and create the logger.

        ``skip_kraken`` and ``threads`` are read with defaults (False / 8) so
        that a params object lacking those attributes still works.
        """
        self.source_directory = params.source_directory
        self.working_directory = params.working_directory
        self.token_count = params.token_count
        self.token_separator = params.token_separator
        self.debug = params.debug
        self.dry_run = params.dry_run
        self.low_complexity_filter = params.low_complexity_filter
        self.qualified_quality_phred = params.qualified_quality_phred
        self.unqualified_percent_limit = params.unqualified_percent_limit
        self.database = params.database
        self.confidence = params.confidence
        # build_command_queue reads skip_kraken, so it must always be defined.
        self.skip_kraken = getattr(params, "skip_kraken", False)
        self.threads = getattr(params, "threads", 8)

        self.logger = build_logger(os.path.join(params.working_directory, params.log_filename))

        # formatters
        self.standardize_basename = lambda filename: preserve_first_n_tokens(
            filename, params.token_separator, params.token_count
        )
        self.basename_without_extension = lambda filename: os.path.basename(
            filename
        ).split(".")[0]

    @staticmethod
    def _swap_in_basename(path, old, new):
        """Replace *old* with *new* in the final component of *path* only.

        Directory components are left alone so that a folder such as
        ``/data/R1_batch/`` is not rewritten when deriving the R2 path.
        """
        directory, name = os.path.split(path)
        return os.path.join(directory, name.replace(old, new))

    @staticmethod
    def _is_r1_fastq(name):
        """Return True for a file name containing "R1" that ends in .fastq or .fastq.gz."""
        return "R1" in name and name.endswith((".fastq", ".fastq.gz"))

    def _build_fastp_command(self, source_file_name, source_file_path):
        """Return ``(fastp_argv, merged_output_path)`` for one R1/R2 pair."""
        basename = self.standardize_basename(source_file_name)  # Sample_S1_Something_Redundant_R1.fastq -> Sample_S1
        merged_path = os.path.join(self.working_directory, f"{basename}.fastq")
        command = [
            "fastp",
            "--low_complexity_filter", str(self.low_complexity_filter),
            "--qualified_quality_phred", str(self.qualified_quality_phred),
            "--unqualified_percent_limit", str(self.unqualified_percent_limit),
            "--merge",
            "--include_unmerged",
            "--in1", source_file_path,
            "--in2", self._swap_in_basename(source_file_path, "R1", "R2"),
            "--merged_out", merged_path,
            "-h", os.path.join(self.working_directory, f"{basename}.html"),
            "-j", os.path.join(self.working_directory, f"{basename}.json"),
        ]
        return command, merged_path

    def _build_kraken2_command(self, fastp_output_file_path):
        """Return the kraken2 argv that classifies one merged fastp output."""
        basename = self.basename_without_extension(fastp_output_file_path)  # Sample_S1.fastq -> Sample_S1
        return [
            "kraken2",
            "--db", self.database,
            "--threads", str(self.threads),
            "--classified-out", os.path.join(self.working_directory, f"{basename}-classified.out"),
            "--unclassified-out", os.path.join(self.working_directory, f"{basename}-unclassified.out"),
            "--confidence", str(self.confidence),
            "--report-zero-counts",
            "--report",
            os.path.join(self.working_directory, f"{basename}-report.out"),
            "--use-names",
            fastp_output_file_path,
        ]

    def _is_double_laned(self):
        """Return True when the source directory has both L001 and L002 files."""
        names = os.listdir(self.source_directory)
        return any("L001" in name for name in names) and any(
            "L002" in name for name in names
        )

    def _get_r1_source_files(self):
        """Yield ``(name, path)`` for every R1 input to process.

        For double-laned input the L001/L002 files of each read are first
        concatenated into the working directory (skipped on a dry run) and the
        concatenated R1 file is yielded instead of the originals.
        """
        if not self._is_double_laned():
            for entry in os.scandir(self.source_directory):
                if self._is_r1_fastq(entry.name):
                    yield entry.name, entry.path
            return

        # concatenate double-laned files
        for entry in os.scandir(self.source_directory):
            if not (self._is_r1_fastq(entry.name) and "L001" in entry.name):
                continue

            r1_lane1 = entry.path
            # Sample_S1_L001_R1.fastq -> Sample_S1_L002_R1.fastq
            r1_lane2 = self._swap_in_basename(r1_lane1, "L001", "L002")
            # Sample_S1_L001_R1.fastq -> Sample_S1_R1.fastq
            r1_output_name = entry.name.replace("_L001", "")
            r1_output_path = os.path.join(self.working_directory, r1_output_name)
            # Same three paths for the R2 mate
            r2_lane1 = self._swap_in_basename(r1_lane1, "R1", "R2")
            r2_lane2 = self._swap_in_basename(r1_lane2, "R1", "R2")
            r2_output_path = self._swap_in_basename(r1_output_path, "R1", "R2")

            if self.dry_run or self.debug:
                self.logger.info("0. Concatenated source files. New paths:")
                self.logger.info(f"   R1: {r1_output_path}")
                self.logger.info(f"   R2: {r2_output_path}")
            if not self.dry_run:
                combine_files(r1_lane1, r1_lane2, r1_output_path)
                combine_files(r2_lane1, r2_lane2, r2_output_path)
            # yield updated source_file_name and source_file_path
            yield r1_output_name, r1_output_path

    def build_command_queue(self):
        """Yield the commands to run: fastp, then kraken2 unless skipped, per sample."""
        verbose = self.debug or self.dry_run
        for source_file_name, source_file_path in self._get_r1_source_files():
            if verbose:
                r2_path = self._swap_in_basename(source_file_path, "R1", "R2")
                self.logger.info("1. Using source files:")
                self.logger.info(f"   R1: {source_file_path}")
                self.logger.info(f"   R2: {r2_path}")
            fastp_command, fastp_output_path = self._build_fastp_command(
                source_file_name, source_file_path
            )
            kraken2_command = self._build_kraken2_command(fastp_output_path)
            if verbose:
                self.logger.info("2. Commands:")
                self.logger.info(f"   → {' '.join(fastp_command)}")
                self.logger.info(f"   → {' '.join(kraken2_command)}\n")
            yield fastp_command
            if not self.skip_kraken:
                yield kraken2_command

    def run_command(self, command):
        """Run one command in the working directory and log its output.

        stdout is logged at info level and stderr at error level; empty
        streams are not logged.

        Returns:
            The process exit code.
        """
        completed = subprocess.run(
            command,
            stdout=subprocess.PIPE,
            stderr=subprocess.PIPE,
            shell=False,
            cwd=self.working_directory,
        )

        # errors="replace": tool output is not guaranteed to be valid UTF-8
        # and a decoding failure should not abort the whole run.
        stdout_text = completed.stdout.decode(errors="replace")
        stderr_text = completed.stderr.decode(errors="replace")
        if stdout_text:
            self.logger.info(stdout_text)
        if stderr_text:
            self.logger.error(stderr_text)

        return completed.returncode

    def run(self):
        """Execute every queued command, or only log them on a dry run."""
        change_working_directory(self.working_directory, self.logger)

        self.logger.info(
            "\n*** Dry Run ***" if self.dry_run else "\n*** Captured Output ***"
        )
        for command in self.build_command_queue():
            if self.dry_run:
                continue
            return_code = self.run_command(command)
            if return_code != 0:
                # Make failures visible; the remaining commands still run.
                self.logger.error(
                    "Command exited with status %s: %s",
                    return_code,
                    " ".join(command),
                )

def parse_arguments_in_command(parser):
    """Register the workflow options on *parser*, parse argv, return WGSParameters.

    Only the source and working directories are mandatory; every other option
    falls back to the default stated in its help text. The two directories are
    converted to absolute paths so they stay valid after the runner changes
    its working directory.

    Args:
        parser: An ``argparse.ArgumentParser`` to extend. It is parsed against
            ``sys.argv``.

    Returns:
        A ``WGSParameters`` instance populated from the command line.
    """
    parser.add_argument(
        "--source-directory",
        required=True,
        help="Directory containing FASTQ source files.",
    )
    parser.add_argument(
        "--working-directory",
        required=True,
        help="Directory where all output files will be saved.",
    )
    parser.add_argument(
        "--log-filename",
        required=False,
        help="Name of the output log file. Default is 'logs.txt'.",
        default="logs.txt",
    )

    # These two document a default, so they must not also be required.
    parser.add_argument(
        "--token-count",
        required=False,
        help="Number of tokens to be used in output filenames. Default is 2.",
        type=int,
        default=2,
    )
    parser.add_argument(
        "--token-separator",
        required=False,
        help="Token separator to be used in output filenames. Default is '_'.",
        default="_",
    )

    parser.add_argument(
        "--debug", action="store_true", help="Enable debug mode for detailed logging."
    )
    parser.add_argument(
        "--dry-run",
        action="store_true",
        help="Simulate the execution without making any changes.",
    )
    parser.add_argument(
        "--skip-kraken", action="store_true", help="Skip running kraken2."
    )

    # https://github.com/OpenGene/fastp
    # fastp optional params
    parser.add_argument(
        "--low-complexity-filter",
        required=False,
        help="Threshold for low complexity filter in fastp.",
        type=int,
        default=30,
    )
    parser.add_argument(
        "--qualified-quality-phred",
        required=False,
        help="Quality threshold in PHRED score for read trimming in fastp.",
        type=int,
        default=15,
    )
    parser.add_argument(
        "--unqualified-percent-limit",
        required=False,
        help="Maximum percentage of unqualified bases allowed in reads in fastp.",
        type=int,
        default=40,
    )

    # https://github.com/DerrickWood/kraken2
    # kraken optional params
    parser.add_argument(
        "--database",
        required=False,
        help="Path to the kraken2 database.",
        default="/mounts/tmpfs/kraken2-nt-local",
    )
    parser.add_argument(
        "--confidence",
        required=False,
        help="Minimum confidence threshold for taxonomy assignments in kraken2.",
        type=float,
        default=0.2,
    )
    parser.add_argument(
        "--threads",
        required=False,
        help="Number of threads passed to kraken2. Default is 8.",
        type=int,
        default=8,
    )

    args = parser.parse_args()

    params = WGSParameters()
    params.source_directory = os.path.abspath(args.source_directory)
    params.working_directory = os.path.abspath(args.working_directory)
    params.log_filename = args.log_filename
    params.token_count = args.token_count
    params.token_separator = args.token_separator
    params.low_complexity_filter = args.low_complexity_filter
    params.qualified_quality_phred = args.qualified_quality_phred
    params.unqualified_percent_limit = args.unqualified_percent_limit
    params.database = args.database
    params.confidence = args.confidence
    params.debug = args.debug
    params.dry_run = args.dry_run
    # --skip-kraken used to be parsed and then dropped.
    params.skip_kraken = args.skip_kraken
    params.threads = args.threads

    return params

def main():
    """Entry point: resolve parameters, run the workflow, log diagnostics.

    Parameters come from the JSON file given with ``--config`` when present,
    otherwise from the individual command-line options.
    """
    # Pre-parser that only knows --config. parse_known_args is required here:
    # a plain parse_args() would reject every other option as "unrecognized"
    # before parse_arguments_in_command had a chance to register it.
    # allow_abbrev=False keeps prefixes shared with --confidence from being
    # read as --config.
    config_parser = argparse.ArgumentParser(add_help=False, allow_abbrev=False)
    config_parser.add_argument(
        "--config",
        required=False,
        help="Config file from where all command parameters should be extracted.",
        default=None,
    )
    known_args, _ = config_parser.parse_known_args()

    if known_args.config:
        wgs_params = WGSParameters()
        wgs_params.read_from(known_args.config)
    else:
        parser = argparse.ArgumentParser(
            description="Process FASTQ files.", parents=[config_parser]
        )
        wgs_params = parse_arguments_in_command(parser)

    wgs_runner = WGSRunner(wgs_params)
    logger = wgs_runner.logger

    logger.info("\n********************")
    logger.info(f"{datetime.now()}: Starting WGS Runner.")

    wgs_runner.run()

    logger.info("*** Diagnostics ***")
    logger.info(os.getcwd())
    logger.info("New files created during this run:")
    for filename, size in get_new_files_info(os.getcwd(), "timestamp"):
        logger.info("'%s' at %s bytes", filename, size)
    logger.info("")

    logger.info(f"{datetime.now()}: Workflow complete.")
    logger.info("\n********************")


if __name__ == "__main__":
    main()
