#!/usr/bin/env python3
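"""Identify DNA-damaged samples from mapDamage 2.0 output directories.

Example invocation (script name, paths, and option values are illustrative):

    python find_damaged_samples.py /path/to/mapdamage_runs \
        --c_to_t_threshold 0.10 --g_to_a_threshold 0.05 \
        --deltas_threshold 0.05 --deltad_threshold 0.01
"""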

import pandas as pd
import os
import argparse

# --- Helper function to extract sampling horizon ---
def get_sampling_horizon(sample_id):
    """
    Extracts the sampling horizon from a sample ID (directory name).
    Assumes horizon is the first digit (1-8) of the directory name.
    Returns an integer for valid horizons, None otherwise.
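
    Example (hypothetical sample IDs):
        >>> get_sampling_horizon("3_sampleA")
        3
        >>> get_sampling_horizon("sampleB") is None
        True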
    """
    if len(sample_id) > 0 and sample_id[0].isdigit():
        horizon_digit = int(sample_id[0])
        if 1 <= horizon_digit <= 8:
            return horizon_digit
    return None

def analyze_mapdamage_outputs(input_dir, c_to_t_threshold, g_to_a_threshold, deltas_threshold, deltad_threshold, bases_5p_end, bases_3p_end):
    """
    Analyzes mapDamage 2.0 output files to identify samples exhibiting DNA damage.
    Calculates misincorporation frequencies across a specified number of positions at each end,
    and derives a combined 'damage_score' (the sum of the four damage metrics).

    Args:
        input_dir (str): Path to the directory containing individual mapDamage run directories.
                         Each run should be in its own subdirectory (e.g., /input_dir/sample_id/).
        c_to_t_threshold (float): Minimum C>T frequency over the first N positions at the 5' end for a sample to be flagged as damaged.
        g_to_a_threshold (float): Minimum G>A frequency over the first N positions at the 3' end for a sample to be flagged as damaged.
        deltas_threshold (float): Minimum DeltaS (deamination rate in single-stranded DNA) mean value to be considered damaged.
        deltad_threshold (float): Minimum DeltaD (deamination rate in double-stranded DNA) mean value to be considered damaged.
        bases_5p_end (int): Number of base positions from the 5' end to consider for C>T frequency.
        bases_3p_end (int): Number of base positions from the 3' end to consider for G>A frequency.

    Returns:
        tuple of (pandas.DataFrame, pandas.DataFrame): A DataFrame with metrics
        for all processed samples, and a subset DataFrame of the samples
        identified as damaged.
    """

    # Ensure input_dir is absolute for robust pathing
    input_dir = os.path.abspath(input_dir)

    if not os.path.isdir(input_dir):
        print(f"Error: Input directory '{input_dir}' does not exist or is not a directory.")
        return pd.DataFrame(), pd.DataFrame()

    sample_dirs = [d for d in os.listdir(input_dir) if os.path.isdir(os.path.join(input_dir, d))]

    if not sample_dirs:
        print(f"No sample directories found in '{input_dir}'. Please ensure each run is in a subdirectory.")
        return pd.DataFrame(), pd.DataFrame()

    all_sample_metrics = []

    print(f"Starting analysis of {len(sample_dirs)} samples from '{input_dir}'...")

    for sample_id in sorted(sample_dirs): # Sort for consistent output
        base_sample_dir = os.path.join(input_dir, sample_id)
        
        # Assume files are directly in base_sample_dir first
        misincorporation_file = os.path.join(base_sample_dir, "misincorporation.txt")
        stats_file = os.path.join(base_sample_dir, "Stats_out_MCMC_iter_summ_stat.csv")

        # Check if files exist in the base directory
        files_found = os.path.exists(misincorporation_file) and os.path.exists(stats_file)

        # If not found, try the 'results' subdirectory (common mapDamage output structure)
        if not files_found:
            misincorporation_file_results = os.path.join(base_sample_dir, "results", "misincorporation.txt")
            stats_file_results = os.path.join(base_sample_dir, "results", "Stats_out_MCMC_iter_summ_stat.csv")
            
            if os.path.exists(misincorporation_file_results) and os.path.exists(stats_file_results):
                misincorporation_file = misincorporation_file_results
                stats_file = stats_file_results
                files_found = True
            
        if not files_found:
            print(f"Error: MapDamage output files (misincorporation.txt, Stats_out_MCMC_iter_summ_stat.csv) not found in '{base_sample_dir}' or '{os.path.join(base_sample_dir, 'results')}' for sample '{sample_id}'. Skipping this sample.")
            # Add this sample with None values for metrics, but try to get horizon
            all_sample_metrics.append({
                'sample_id': sample_id,
                f'C_to_T_5p_first{bases_5p_end}_freq': None,
                f'G_to_A_3p_first{bases_3p_end}_freq': None,
                'deltaS_mean': None,
                'deltaD_mean': None, # Added DeltaD
                'damage_score': None,
                'sampling_horizon': get_sampling_horizon(sample_id) 
            })
            continue # Skip to next sample in the loop

        # Initialize metrics for the current sample
        c_to_t_5p_freq = None
        g_to_a_3p_freq = None
        delta_s_mean = None
        delta_d_mean = None
        damage_score = None

        # --- Process misincorporation.txt ---
        try:
            # mapDamage prefixes misincorporation.txt with '#' comment lines;
            # skip them when parsing.
            df_mis = pd.read_csv(misincorporation_file, sep='\t', comment='#', engine='python')
            df_mis.columns = df_mis.columns.str.strip()  # Clean column names
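            # Expected layout (typical of mapDamage 2.x output): one row per
            # Chr/End/Std/Pos combination, with reference base counts
            # ('A', 'C', 'G', 'T', 'Total') followed by substitution counts
            # such as 'C>T' and 'G>A'.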
            
            if 'End' in df_mis.columns:
                df_mis['End'] = df_mis['End'].astype(str).str.strip()
            
            if 'Pos' in df_mis.columns:
                df_mis['Pos'] = pd.to_numeric(df_mis['Pos'], errors='coerce')
                df_mis = df_mis.dropna(subset=['Pos']) # Drop rows where 'Pos' could not be converted
                df_mis['Pos'] = df_mis['Pos'].astype(int)

            # Calculate C>T frequency over the first N positions at the 5' end:
            # C>T events divided by the reference C count in that window (i.e.
            # the fraction of reference Cs read as T). Dividing by 'Total'
            # instead would dilute the signal with non-C bases.
            ct_5p_data = df_mis[(df_mis['Pos'] >= 1) & (df_mis['Pos'] <= bases_5p_end) & (df_mis['End'] == '5p')]
            if not ct_5p_data.empty and 'C>T' in ct_5p_data.columns and 'C' in ct_5p_data.columns:
                total_ct_counts = ct_5p_data['C>T'].sum()
                total_c_at_5p_range = ct_5p_data['C'].sum()
                if total_c_at_5p_range > 0:
                    c_to_t_5p_freq = total_ct_counts / total_c_at_5p_range

            # Calculate G>A frequency over the first N positions at the 3' end
            # (for 'End' == '3p' rows, Pos is counted from the 3' terminus).
            ga_3p_data = df_mis[(df_mis['Pos'] >= 1) & (df_mis['Pos'] <= bases_3p_end) & (df_mis['End'] == '3p')]
            if not ga_3p_data.empty and 'G>A' in ga_3p_data.columns and 'G' in ga_3p_data.columns:
                total_ga_counts = ga_3p_data['G>A'].sum()
                total_g_at_3p_range = ga_3p_data['G'].sum()
                if total_g_at_3p_range > 0:
                    g_to_a_3p_freq = total_ga_counts / total_g_at_3p_range
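            # Worked example (hypothetical counts): 150 C>T events over 1,200
            # reference Cs in positions 1-10 at the 5' end gives a frequency of
            # 0.125, above the default --c_to_t_threshold of 0.10.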
                
        except pd.errors.EmptyDataError:
            print(f"Warning: {misincorporation_file} is empty for {sample_id}. C>T/G>A metrics will be None.")
        except KeyError as e:
            print(f"Warning: Missing expected column or invalid data in {misincorporation_file} for {sample_id}: {e}. C>T/G>A metrics will be None.")
        except Exception as e:
            print(f"Error processing {misincorporation_file} for {sample_id}: {e}. C>T/G>A metrics will be None.")

        # --- Process Stats_out_MCMC_iter_summ_stat.csv ---
        try:
            df_stats = pd.read_csv(stats_file, index_col=0)
            df_stats.columns = df_stats.columns.str.strip()  # Clean column names
            # The MCMC summary table has one row per statistic (e.g. Mean, Std)
            # and one column per parameter (e.g. DeltaS, DeltaD). Prefer the row
            # labelled 'Mean' if present; otherwise fall back to the first row.
            summary_row = df_stats.loc['Mean'] if 'Mean' in df_stats.index else df_stats.iloc[0]
            if 'DeltaS' in df_stats.columns:
                delta_s_mean = summary_row['DeltaS']
            else:
                print(f"Warning: 'DeltaS' column not found in {stats_file} for {sample_id}. DeltaS metric will be None.")

            if 'DeltaD' in df_stats.columns:
                delta_d_mean = summary_row['DeltaD']
            else:
                print(f"Warning: 'DeltaD' column not found in {stats_file} for {sample_id}. DeltaD metric will be None.")

        except pd.errors.EmptyDataError:
            print(f"Warning: {stats_file} is empty for {sample_id}. DeltaS/DeltaD metrics will be None.")
        except KeyError as e:
            print(f"Warning: Missing expected column in {stats_file} for {sample_id}: {e}. DeltaS/DeltaD metrics will be None.")
        except Exception as e:
            print(f"Error processing {stats_file} for {sample_id}: {e}. DeltaS/DeltaD metrics will be None.")

        # --- Calculate damage_score ---
        # Treat None values as 0 in the summation so a score is produced even
        # from partial data.
        score_components = [
            c_to_t_5p_freq if c_to_t_5p_freq is not None else 0,
            g_to_a_3p_freq if g_to_a_3p_freq is not None else 0,
            delta_s_mean if delta_s_mean is not None else 0,
            delta_d_mean if delta_d_mean is not None else 0
        ]

        # Only compute a score if at least one metric was actually parsed;
        # otherwise damage_score stays None (as initialized above).
        if any(x is not None for x in [c_to_t_5p_freq, g_to_a_3p_freq, delta_s_mean, delta_d_mean]):
            damage_score = sum(score_components)
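        # Worked example (hypothetical values): C>T 0.12 + G>A 0.06 +
        # DeltaS 0.08 + DeltaD 0.02 gives damage_score = 0.28.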

        all_sample_metrics.append({
            'sample_id': sample_id,
            f'C_to_T_5p_first{bases_5p_end}_freq': c_to_t_5p_freq,
            f'G_to_A_3p_first{bases_3p_end}_freq': g_to_a_3p_freq,
            'deltaS_mean': delta_s_mean,
            'deltaD_mean': delta_d_mean,
            'damage_score': damage_score,
            'sampling_horizon': get_sampling_horizon(sample_id) 
        })

    master_df = pd.DataFrame(all_sample_metrics)

    # Filter for damaged samples based on thresholds
    c_to_t_col = f'C_to_T_5p_first{bases_5p_end}_freq'
    g_to_a_col = f'G_to_A_3p_first{bases_3p_end}_freq'

    damaged_samples_subset = master_df[
        (master_df[c_to_t_col].notna()) & (master_df[c_to_t_col] >= c_to_t_threshold) &
        (master_df[g_to_a_col].notna()) & (master_df[g_to_a_col] >= g_to_a_threshold) &
        (master_df['deltaS_mean'].notna()) & (master_df['deltaS_mean'] >= deltas_threshold) &
        (master_df['deltaD_mean'].notna()) & (master_df['deltaD_mean'] >= deltad_threshold)
    ].copy() # .copy() to avoid SettingWithCopyWarning in future operations

    return master_df, damaged_samples_subset
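
# Programmatic usage (sketch; the directory name and threshold values below are
# hypothetical):
#     all_df, damaged_df = analyze_mapdamage_outputs(
#         "mapdamage_runs", c_to_t_threshold=0.10, g_to_a_threshold=0.05,
#         deltas_threshold=0.05, deltad_threshold=0.01,
#         bases_5p_end=10, bases_3p_end=10)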

if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Identify DNA damaged samples from mapDamage 2.0 output directories.",
        formatter_class=argparse.RawTextHelpFormatter
    )
    parser.add_argument(
        "input_directory",
        help="Path to the parent directory containing all mapDamage run subdirectories (e.g., /path/to/my_mapdamage_runs/).\n"
             "Each mapDamage output for a sample should be in a directory like: /path/to/my_mapdamage_runs/TAXON_ID/ or /path/to/my_mapdamage_runs/TAXON_ID/results/."
    )
    parser.add_argument(
        "--c_to_t_threshold",
        type=float,
        default=0.10,
        help="Minimum C>T frequency at 5' end, first N positions to flag as damaged (default: 0.10, i.e., 10%%)."
    )
    parser.add_argument(
        "--g_to_a_threshold",
        type=float,
        default=0.05,
        help="Minimum G>A frequency at 3' end, first N positions to flag as damaged (default: 0.05, i.e., 5%%)."
    )
    parser.add_argument(
        "--deltas_threshold",
        type=float,
        default=0.05,
        help="Minimum DeltaS (deamination rate in single-stranded DNA) value to flag as damaged (default: 0.05).\n"
             "This refers to the value in the 'DeltaS' column of Stats_out_MCMC_iter_summ_stat.csv."
    )
    parser.add_argument(
        "--deltad_threshold", # Corrected argument name
        type=float,
        default=0.01, # A common baseline for DeltaD, can be adjusted
        help="Minimum DeltaD (deamination rate in double-stranded DNA) value to flag as damaged (default: 0.01).\n"
             "This refers to the value in the 'DeltaD' column of Stats_out_MCMC_iter_summ_stat.csv."
    )
    parser.add_argument(
        "--bases_5p_end",
        type=int,
        default=10,
        help="Number of base positions from the 5' end to consider for C>T frequency calculation (default: 10)."
    )
    parser.add_argument(
        "--bases_3p_end",
        type=int,
        default=10,
        help="Number of base positions from the 3' end to consider for G>A frequency calculation (default: 10)."
    )
    parser.add_argument(
        "--output_all_metrics",
        type=str,
        default="all_sample_damage_metrics.csv",
        help="Filename for the CSV output containing metrics for all processed samples (default: all_sample_damage_metrics.csv)."
    )
    parser.add_argument(
        "--output_damaged_list",
        type=str,
        default="damaged_samples_list.csv",
        help="Filename for the CSV output containing only the identified damaged samples (default: damaged_samples_list.csv)."
    )
    parser.add_argument(
        "--output_horizon_summary",
        type=str,
        default="sampling_horizon_damage_summary.csv",
        help="Filename for the CSV output containing the aggregated damage scores by sampling horizon (default: sampling_horizon_damage_summary.csv)."
    )

    args = parser.parse_args()

    # Run the analysis
    all_metrics_df, damaged_subset_df = analyze_mapdamage_outputs(
        args.input_directory,
        args.c_to_t_threshold,
        args.g_to_a_threshold,
        args.deltas_threshold,
        args.deltad_threshold,
        args.bases_5p_end,
        args.bases_3p_end
    )

    if not all_metrics_df.empty:
        # Ensure sampling_horizon is integer where possible in the master dataframe
        # Use pd.Int64Dtype() to allow for NaN/None values while keeping integer type
        all_metrics_df['sampling_horizon'] = all_metrics_df['sampling_horizon'].astype(pd.Int64Dtype())
        
        # Define float columns to round dynamically
        float_cols_to_round = [
            f'C_to_T_5p_first{args.bases_5p_end}_freq',
            f'G_to_A_3p_first{args.bases_3p_end}_freq',
            'deltaS_mean',
            'deltaD_mean',
            'damage_score'
        ]
        # Apply rounding only to these specific float columns
        for col in float_cols_to_round:
            if col in all_metrics_df.columns and pd.api.types.is_float_dtype(all_metrics_df[col]):
                all_metrics_df[col] = all_metrics_df[col].round(5)

        # Save all metrics to a CSV
        output_all_path = os.path.join(os.getcwd(), args.output_all_metrics)
        all_metrics_df.to_csv(output_all_path, index=False)
        print(f"\nAll sample damage metrics saved to: {output_all_path}")
        print(f"Total samples processed: {len(all_metrics_df)}")

        # --- Aggregate by Sampling Horizon ---
        # Filter out samples that don't have a valid sampling horizon (None in 'sampling_horizon' column)
        horizon_df = all_metrics_df[all_metrics_df['sampling_horizon'].notna()].copy()

        if not horizon_df.empty:
            # Group by sampling horizon: mean/std of the damage score and the
            # deamination parameters, plus a per-horizon sample count
            horizon_summary = horizon_df.groupby('sampling_horizon').agg(
                mean_damage_score=('damage_score', 'mean'),
                std_damage_score=('damage_score', 'std'),
                mean_deltaS=('deltaS_mean', 'mean'),
                std_deltaS=('deltaS_mean', 'std'),
                mean_deltaD=('deltaD_mean', 'mean'),
                std_deltaD=('deltaD_mean', 'std'),
                num_samples=('sample_id', 'count')  # Total samples in each horizon group
            ).reset_index()
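            # Illustrative shape of the summary (values are hypothetical):
            #   sampling_horizon  mean_damage_score  std_damage_score  ...  num_samples
            #                  1            0.21400           0.03100  ...           12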
            
            # Ensure the 'sampling_horizon' column in the summary is also an integer type
            horizon_summary['sampling_horizon'] = horizon_summary['sampling_horizon'].astype(pd.Int64Dtype())
            
            # Define float columns for horizon_summary to round
            float_cols_to_round_horizon = [
                'mean_damage_score', 'std_damage_score',
                'mean_deltaS', 'std_deltaS',
                'mean_deltaD', 'std_deltaD'
            ]
            # Apply rounding only to these specific float columns
            for col in float_cols_to_round_horizon:
                if col in horizon_summary.columns and pd.api.types.is_float_dtype(horizon_summary[col]):
                    horizon_summary[col] = horizon_summary[col].round(5)

            # Sort by sampling horizon for cleaner output
            horizon_summary = horizon_summary.sort_values('sampling_horizon')

            # Save the aggregated summary
            output_horizon_path = os.path.join(os.getcwd(), args.output_horizon_summary)
            horizon_summary.to_csv(output_horizon_path, index=False)
            print(f"\nSampling horizon damage summary saved to: {output_horizon_path}")
        else:
            print("\nNo samples with valid sampling horizon (1-8) found for aggregation.")

    if not damaged_subset_df.empty:
        # Ensure sampling_horizon is integer where possible in the damaged_subset_df dataframe
        damaged_subset_df['sampling_horizon'] = damaged_subset_df['sampling_horizon'].astype(pd.Int64Dtype())
        
        # Round the same float metric columns as in all_metrics_df. The list is
        # rebuilt here so this block does not depend on the one above having run.
        float_cols_damaged = [
            f'C_to_T_5p_first{args.bases_5p_end}_freq',
            f'G_to_A_3p_first{args.bases_3p_end}_freq',
            'deltaS_mean', 'deltaD_mean', 'damage_score'
        ]
        for col in float_cols_damaged:
            if col in damaged_subset_df.columns and pd.api.types.is_float_dtype(damaged_subset_df[col]):
                damaged_subset_df[col] = damaged_subset_df[col].round(5)

        # Save the list of damaged samples to a CSV
        output_damaged_path = os.path.join(os.getcwd(), args.output_damaged_list)
        damaged_subset_df.to_csv(output_damaged_path, index=False)
        print(f"\nIdentified {len(damaged_subset_df)} damaged samples.")
        print(f"List of damaged samples saved to: {output_damaged_path}")
    else:
        print("\nNo samples met the criteria for significant damage.")

    print("\nAnalysis complete.")
