#!/usr/bin/env python3
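"""
Identify DNA-damaged samples from DamageProfiler output directories.

Example invocation (script name and paths are illustrative, not prescribed
by this repository):

    python identify_damaged_samples.py /path/to/damageprofiler_runs \
        --c_to_t_threshold 0.10 --g_to_a_threshold 0.05 \
        --bases_5p_end 10 --bases_3p_end 10

Writes three CSVs to the current working directory: per-sample metrics, the
flagged (damaged) subset, and a per-horizon damage summary.
"""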

import pandas as pd
import os
import argparse

# --- Helper function to extract sampling horizon ---
def get_sampling_horizon(sample_id):
    """
    Extracts the sampling horizon from a sample ID (directory name).
    Assumes the horizon is encoded as the first character of the ID, a digit 1-8.
    Returns the horizon as an integer, or None if no valid horizon is found.
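
    Examples (hypothetical sample IDs):
        get_sampling_horizon("3_siteA_bone")  -> 3
        get_sampling_horizon("9_siteB_bone")  -> None  (9 is outside 1-8)
        get_sampling_horizon("blank_ctrl")    -> None  (no leading digit)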
    """
    if len(sample_id) > 0 and sample_id[0].isdigit():
        horizon_digit = int(sample_id[0])
        if 1 <= horizon_digit <= 8:
            return horizon_digit
    return None

# --- Main analysis function (simplified for DamageProfiler) ---
def analyze_damageprofiler_outputs(input_dir, c_to_t_threshold, g_to_a_threshold, bases_5p_end, bases_3p_end):
    """
    Analyzes DamageProfiler output files (misincorporation.txt) to identify samples exhibiting DNA damage.
    Computes C>T and G>A misincorporation frequencies (normalized by reference C/G counts) over a
    specified number of positions at each read end, and derives a combined 'damage_score' as their sum.

    Args:
        input_dir (str): Path to the directory containing individual DamageProfiler run directories.
                         Each run should be in its own subdirectory (e.g., /input_dir/sample_id/).
        c_to_t_threshold (float): Minimum C>T frequency over the first N 5' positions
                                  for a sample to be considered damaged.
        g_to_a_threshold (float): Minimum G>A frequency over the first N 3' positions
                                  for a sample to be considered damaged.
        bases_5p_end (int): Number of base positions from the 5' end used for the C>T frequency.
        bases_3p_end (int): Number of base positions from the 3' end used for the G>A frequency.

    Returns:
        tuple[pandas.DataFrame, pandas.DataFrame]: A DataFrame with metrics for all
            processed samples, and a subset DataFrame of the identified damaged samples.
    """

    # Ensure input_dir is absolute for robust pathing
    input_dir = os.path.abspath(input_dir)

    if not os.path.isdir(input_dir):
        print(f"Error: input directory '{input_dir}' does not exist or is not a directory.")
        return pd.DataFrame(), pd.DataFrame()

    sample_dirs = [d for d in os.listdir(input_dir) if os.path.isdir(os.path.join(input_dir, d))]

    if not sample_dirs:
        print(f"No sample directories found in '{input_dir}'. Please ensure each run is in a subdirectory.")
        return pd.DataFrame(), pd.DataFrame()

    all_sample_metrics = []

    print(f"Starting analysis of {len(sample_dirs)} samples from '{input_dir}'...")

    for sample_id in sorted(sample_dirs): # Sort for consistent output
        base_sample_dir = os.path.join(input_dir, sample_id)
        
        # DamageProfiler typically puts misincorporation.txt directly in the sample directory.
        misincorporation_file = os.path.join(base_sample_dir, "misincorporation.txt")

        if not os.path.exists(misincorporation_file):
            print(f"Error: misincorporation.txt not found in '{base_sample_dir}' for sample '{sample_id}'. Skipping this sample.")
            # Add this sample with None values for metrics, but try to get horizon
            all_sample_metrics.append({
                'sample_id': sample_id,
                f'C_to_T_5p_first{bases_5p_end}_freq': None,
                f'G_to_A_3p_first{bases_3p_end}_freq': None,
                'damage_score': None,
                'sampling_horizon': get_sampling_horizon(sample_id) 
            })
            continue # Skip to next sample in the loop

        # Initialize metrics for the current sample
        c_to_t_5p_freq = None
        g_to_a_3p_freq = None
        damage_score = None 

        # --- Process misincorporation.txt ---
        try: 
            # comment='#' makes pandas skip metadata lines that DamageProfiler prepends
            df_mis = pd.read_csv(misincorporation_file, sep='\t', engine='python', comment='#')
            df_mis.columns = df_mis.columns.str.strip() # Clean column names
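            # Assumed (mapDamage-compatible) misincorporation.txt layout; exact
            # columns may vary between DamageProfiler versions:
            #   Chr  End  Std  Pos  A  C  G  T  Total  G>A  C>T  ...
            # with End in {'5p', '3p'} and Pos counted 1..N from each read end.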
            
            if 'End' in df_mis.columns:
                df_mis['End'] = df_mis['End'].astype(str).str.strip()
            
            if 'Pos' in df_mis.columns:
                df_mis['Pos'] = pd.to_numeric(df_mis['Pos'], errors='coerce')
                df_mis = df_mis.dropna(subset=['Pos']) # Drop rows where 'Pos' could not be converted
                df_mis['Pos'] = df_mis['Pos'].astype(int)

            # Calculate C>T frequency over the first N positions at the 5' end.
            # Normalize by the count of reference C bases (the 'C' column), not by
            # 'Total': C>T frequency is conventionally reported per reference C,
            # which is the scale the default thresholds assume.
            ct_5p_data = df_mis[(df_mis['Pos'] >= 1) & (df_mis['Pos'] <= bases_5p_end) & (df_mis['End'] == '5p')]
            if not ct_5p_data.empty and 'C>T' in ct_5p_data.columns and 'C' in ct_5p_data.columns:
                total_ct_counts = ct_5p_data['C>T'].sum()
                total_c_bases_5p = ct_5p_data['C'].sum()
                if total_c_bases_5p > 0:
                    c_to_t_5p_freq = total_ct_counts / total_c_bases_5p
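            # Worked example with hypothetical counts: if C>T events sum to 120
            # and reference C bases sum to 1000 over 5' positions 1-10, then
            # c_to_t_5p_freq = 120 / 1000 = 0.12 (12% terminal C>T).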

            # Calculate G>A frequency over the first N positions at the 3' end,
            # normalized by the count of reference G bases (the 'G' column).
            ga_3p_data = df_mis[(df_mis['Pos'] >= 1) & (df_mis['Pos'] <= bases_3p_end) & (df_mis['End'] == '3p')]
            if not ga_3p_data.empty and 'G>A' in ga_3p_data.columns and 'G' in ga_3p_data.columns:
                total_ga_counts = ga_3p_data['G>A'].sum()
                total_g_bases_3p = ga_3p_data['G'].sum()
                if total_g_bases_3p > 0:
                    g_to_a_3p_freq = total_ga_counts / total_g_bases_3p
                
        except pd.errors.EmptyDataError:
            print(f"Warning: {misincorporation_file} is empty for {sample_id}. C>T/G>A metrics will be None.")
        except KeyError as e:
            print(f"Warning: Missing expected column or invalid data in {misincorporation_file} for {sample_id}: {e}. C>T/G>A metrics will be None.")
        except Exception as e:
            print(f"Error processing {misincorporation_file} for {sample_id}: {e}. C>T/G>A metrics will be None.")

        # --- Calculate damage_score ---
        # Sum the two frequencies, treating a missing component as 0 so that a
        # score is still produced from partial data. The score stays None only
        # when both components are missing.
        if c_to_t_5p_freq is not None or g_to_a_3p_freq is not None:
            damage_score = (c_to_t_5p_freq or 0) + (g_to_a_3p_freq or 0)
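        # Hypothetical example: c_to_t_5p_freq = 0.12 and g_to_a_3p_freq = 0.08
        # give damage_score = 0.20.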

        all_sample_metrics.append({
            'sample_id': sample_id,
            f'C_to_T_5p_first{bases_5p_end}_freq': c_to_t_5p_freq,
            f'G_to_A_3p_first{bases_3p_end}_freq': g_to_a_3p_freq,
            'damage_score': damage_score,
            'sampling_horizon': get_sampling_horizon(sample_id) 
        })

    master_df = pd.DataFrame(all_sample_metrics)

    # Filter for damaged samples based on thresholds
    c_to_t_col = f'C_to_T_5p_first{bases_5p_end}_freq'
    g_to_a_col = f'G_to_A_3p_first{bases_3p_end}_freq'

    # Coerce the metric columns to numeric so the comparisons below are safe
    # even when every value is missing (an all-None column would otherwise be
    # object-typed and '>=' would raise a TypeError). NaN compares as False,
    # so missing values are excluded automatically.
    master_df[c_to_t_col] = pd.to_numeric(master_df[c_to_t_col], errors='coerce')
    master_df[g_to_a_col] = pd.to_numeric(master_df[g_to_a_col], errors='coerce')

    # A sample is flagged as damaged only if both frequencies meet their thresholds.
    damaged_samples_subset = master_df[
        (master_df[c_to_t_col] >= c_to_t_threshold) &
        (master_df[g_to_a_col] >= g_to_a_threshold)
    ].copy()  # .copy() avoids SettingWithCopyWarning in later operations

    return master_df, damaged_samples_subset
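
# Example programmatic use (hypothetical path):
#   all_df, damaged_df = analyze_damageprofiler_outputs(
#       "/path/to/damageprofiler_runs", 0.10, 0.05, 10, 10)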

if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Identify DNA damaged samples from DamageProfiler output directories.",
        formatter_class=argparse.RawTextHelpFormatter
    )
    parser.add_argument(
        "input_directory",
        help="Path to the parent directory containing all DamageProfiler run subdirectories (e.g., /path/to/my_damageprofiler_runs/).\n"
             "Each DamageProfiler output for a sample should be in a directory like: /path/to/my_damageprofiler_runs/TAXON_ID/."
    )
    parser.add_argument(
        "--c_to_t_threshold",
        type=float,
        default=0.10,
        help="Minimum C>T frequency at 5' end, first N positions to flag as damaged (default: 0.10, i.e., 10%%)."
    )
    parser.add_argument(
        "--g_to_a_threshold",
        type=float,
        default=0.05,
        help="Minimum G>A frequency at 3' end, first N positions to flag as damaged (default: 0.05, i.e., 5%%)."
    )
    # DeltaS/DeltaD thresholds are mapDamage-specific and do not apply to DamageProfiler output.
    parser.add_argument(
        "--bases_5p_end",
        type=int,
        default=10,
        help="Number of base positions from the 5' end to consider for C>T frequency calculation (default: 10)."
    )
    parser.add_argument(
        "--bases_3p_end",
        type=int,
        default=10,
        help="Number of base positions from the 3' end to consider for G>A frequency calculation (default: 10)."
    )
    parser.add_argument(
        "--output_all_metrics",
        type=str,
        default="all_sample_damage_metrics.csv",
        help="Filename for the CSV output containing metrics for all processed samples (default: all_sample_damage_metrics.csv)."
    )
    parser.add_argument(
        "--output_damaged_list",
        type=str,
        default="damaged_samples_list.csv",
        help="Filename for the CSV output containing only the identified damaged samples (default: damaged_samples_list.csv)."
    )
    parser.add_argument(
        "--output_horizon_summary",
        type=str,
        default="sampling_horizon_damage_summary.csv",
        help="Filename for the CSV output containing the aggregated damage scores by sampling horizon (default: sampling_horizon_damage_summary.csv)."
    )

    args = parser.parse_args()

    # Run the analysis (simplified call)
    all_metrics_df, damaged_subset_df = analyze_damageprofiler_outputs(
        args.input_directory,
        args.c_to_t_threshold,
        args.g_to_a_threshold,
        args.bases_5p_end,
        args.bases_3p_end
    )

    if not all_metrics_df.empty:
        # Ensure sampling_horizon is integer where possible in the master dataframe
        all_metrics_df['sampling_horizon'] = all_metrics_df['sampling_horizon'].astype(pd.Int64Dtype())
        
        # Float columns to round, with names derived from the CLI arguments
        float_cols_to_round = [
            f'C_to_T_5p_first{args.bases_5p_end}_freq',
            f'G_to_A_3p_first{args.bases_3p_end}_freq',
            'damage_score'
        ]
        # Apply rounding only to these specific float columns
        for col in float_cols_to_round:
            if col in all_metrics_df.columns and pd.api.types.is_float_dtype(all_metrics_df[col]):
                all_metrics_df[col] = all_metrics_df[col].round(5)

        # Save all metrics to a CSV
        output_all_path = os.path.join(os.getcwd(), args.output_all_metrics)
        all_metrics_df.to_csv(output_all_path, index=False)
        print(f"\nAll sample damage metrics saved to: {output_all_path}")
        print(f"Total samples processed: {len(all_metrics_df)}")

        # --- Aggregate by Sampling Horizon ---
        # Keep only samples with a valid sampling horizon; missing horizons are excluded
        horizon_df = all_metrics_df[all_metrics_df['sampling_horizon'].notna()].copy()

        if not horizon_df.empty:
            # Group by sampling horizon and calculate mean, std, and count of damage scores
            horizon_summary = horizon_df.groupby('sampling_horizon').agg(
                mean_damage_score=('damage_score', 'mean'),
                std_damage_score=('damage_score', 'std'),
                num_samples=('sample_id', 'count') # Count based on sample_id for total samples in group
            ).reset_index()
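            # Hypothetical shape of the resulting summary (values illustrative):
            #   sampling_horizon  mean_damage_score  std_damage_score  num_samples
            #   1                 0.21456            0.04012           12
            #   2                 0.15201            0.06110           9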
            
            # Ensure the 'sampling_horizon' column in the summary is also an integer type
            horizon_summary['sampling_horizon'] = horizon_summary['sampling_horizon'].astype(pd.Int64Dtype())
            
            # Float columns to round in the horizon summary
            float_cols_to_round_horizon = ['mean_damage_score', 'std_damage_score']
            # Apply rounding only to these specific float columns
            for col in float_cols_to_round_horizon:
                if col in horizon_summary.columns and pd.api.types.is_float_dtype(horizon_summary[col]):
                    horizon_summary[col] = horizon_summary[col].round(5)

            # Sort by sampling horizon for cleaner output
            horizon_summary = horizon_summary.sort_values('sampling_horizon')

            # Save the aggregated summary
            output_horizon_path = os.path.join(os.getcwd(), args.output_horizon_summary)
            horizon_summary.to_csv(output_horizon_path, index=False)
            print(f"\nSampling horizon damage summary saved to: {output_horizon_path}")
        else:
            print("\nNo samples with valid sampling horizon (1-8) found for aggregation.")

    if not damaged_subset_df.empty:
        # Ensure sampling_horizon is integer where possible in the damaged_subset_df dataframe
        damaged_subset_df['sampling_horizon'] = damaged_subset_df['sampling_horizon'].astype(pd.Int64Dtype())
        
        # Apply rounding only to the specific float columns (using the same list from all_metrics_df)
        for col in float_cols_to_round: # Reusing the list
            if col in damaged_subset_df.columns and pd.api.types.is_float_dtype(damaged_subset_df[col]):
                damaged_subset_df[col] = damaged_subset_df[col].round(5)

        # Save the list of damaged samples to a CSV
        output_damaged_path = os.path.join(os.getcwd(), args.output_damaged_list)
        damaged_subset_df.to_csv(output_damaged_path, index=False)
        print(f"\nIdentified {len(damaged_subset_df)} damaged samples.")
        print(f"List of damaged samples saved to: {output_damaged_path}")
    else:
        print("\nNo samples met the criteria for significant damage.")

    print("\nAnalysis complete.")