#!/usr/bin/env python3
# -*- coding: utf-8 -*-

import argparse
import sys
import os
import pandas as pd
from metacovest import Sample


"""
{ genome_id :
        {
      length : O,
      reads : 0,
      reads_properly_paired : 0,
      total_coverage:0,
      total_coverage_properly_paired': 0
      abundance: 0,
      abundance_properly_paired:0,
        }
    }
"""

def generate_tables(files, min_baseq, dpb, only_covered, properly_paired, is_genomes):
    """
    Build per-sample read-count and coverage tables from BAM files.

    Each file is opened through the Sample object from metacovest and the
    statistics returned by its ``reference_coverage`` method are summed per
    reference (or per genome when ``is_genomes`` is set).

    Parameters
    ----------
    files : iterable of str
        Paths to the BAM files. The sample id is the file's base name up to
        its first ``.``; two files sharing that prefix overwrite each other
        in the returned dictionaries.
    min_baseq : int
        Forwarded to ``Sample.reference_coverage`` as ``min_baseq``.
    dpb : int
        Minimum depth per base, forwarded to ``reference_coverage`` as ``N``.
        With ``dpb = 1`` the "bases_covered" statistic represents the
        horizontal coverage.
    only_covered : bool
        Forwarded to ``reference_coverage`` as ``only_covered``.
    properly_paired : bool
        If true, the ``*_properly_paired`` variant of each statistic is used.
    is_genomes : bool
        If true, reference names are cut at the first ``|``
        (``genome_id|contig_id``) and statistics are summed by genome id.

    Returns
    -------
    tuple of (dict, dict, dict)
        ``(reads counts, total coverage, bases covered)``; each dictionary
        maps ``sample id -> {reference or genome id -> summed value}``.
    """
    # Names of the statistics to read for each reference; the properly-paired
    # variants only differ by their suffix.
    suffix = "_properly_paired" if properly_paired else ""
    reads = "reads" + suffix
    total_coverage = "total_coverage" + suffix
    bases_covered = "bases_covered" + suffix

    # One dictionary per table: sample id -> {contig/genome id -> value}.
    sample_reads_counts = {}
    sample_coverage_sum = {}
    sample_horizontal_coverage = {}

    for f in files:
        sampleid = os.path.basename(f).split(".")[0]
        # Progress goes to stderr: the tables may themselves be written to
        # stdout by the caller and must not be mixed with diagnostics.
        print("processing {} (sample id: {})".format(f, sampleid), file=sys.stderr)

        samfile = Sample(f, "rb")
        contigstats = samfile.reference_coverage(
            min_baseq=min_baseq, N=dpb, only_covered=only_covered
        )

        reads_count = {}
        coverage_sum = {}
        horizontal_coverage = {}

        for ref in contigstats.keys():
            stats = contigstats[ref]
            key = ref.split('|')[0] if is_genomes else ref
            # Always add the current reference, including the first one seen
            # for a key; initialising to 0 without adding would drop it.
            reads_count[key] = reads_count.get(key, 0) + stats[reads]
            coverage_sum[key] = coverage_sum.get(key, 0) + stats[total_coverage]
            horizontal_coverage[key] = (
                horizontal_coverage.get(key, 0) + stats[bases_covered]
            )

        sample_reads_counts[sampleid] = reads_count
        sample_coverage_sum[sampleid] = coverage_sum
        sample_horizontal_coverage[sampleid] = horizontal_coverage

    return sample_reads_counts, sample_coverage_sum, sample_horizontal_coverage


if __name__ == "__main__":
    # Entry point. The script runs either from a Snakemake rule (a global
    # `snakemake` object is then available) or as a standalone command-line
    # tool configured through argparse.
    try:
        smk = snakemake  # noqa: F821 -- only defined when run by Snakemake
    except NameError:
        smk = None

    if smk is not None:
        # Snakemake mode: inputs, outputs and parameters come from the rule.
        # Errors in the rule definition (e.g. a missing param) now surface
        # directly instead of silently falling back to argparse.
        bams = smk.input
        f_bams = None
        reads_count_output = smk.output.reads_count
        coverage_output = smk.output.bp_covered
        horizontal_coverage_output = smk.output.horizontal_coverage
        sep = "\t"
        min_baseq = smk.params.min_baseq
        dpb = smk.params.min_depth_bp
        only_covered = smk.params.only_covered
        properly_paired = smk.params.properly_paired
        is_genomes = smk.params.is_genomes
        print(bams, file=sys.stderr)
    else:
        # parsing script arguments
        parser = argparse.ArgumentParser(
            prog='generate_raw_table',
            description='generate reads counts and sum(bp_covered) per regions from bam file.'
        )
        parser.add_argument(
            '-i', type=str, nargs='+',
            help='list of indexed BAM file'
        )
        parser.add_argument(
            '-f', type=str,
            default=None,
            help='file containing path to each bam files (mutually exclusive with -i)'
        )
        parser.add_argument(
            '--reads_count', nargs='?', type=argparse.FileType('w'),
            default=sys.stdout,
            help='read count table outfile (default: stdout)'
        )
        parser.add_argument(
            '--bp_cov', nargs='?', type=argparse.FileType('w'),
            default=sys.stdout,
            help='base covered count table outfile (default: stdout)'
        )
        parser.add_argument(
            '--hor_cov', nargs='?', type=argparse.FileType('w'),
            default=sys.stdout,
            help='horizontal coverage table outfile (default: stdout)'
        )
        parser.add_argument(
            '--genomes', action='store_true',
            default=False,
            help='if true, regions id will be split on |  (genome_id|contigs_id) and counts will be summed by genome_id'
        )
        parser.add_argument(
            '-s', '--sep', nargs='?', type=str,
            default="\t",
            help='separator for csv file (default: tab)'
        )
        # NOTE(review): this option is described as a mapping quality but is
        # passed on as `min_baseq` -- confirm which quality metacovest expects.
        parser.add_argument(
            '-q', '--mapping-quality', nargs='?', type=int,
            default=20,
            help='minimum mapping quality of counted reads (default: 20)'
        )
        parser.add_argument(
            '--depth-per-base', nargs='?', type=int,
            default=1,
            help='minimum depth per base (default : 1)'
        )
        parser.add_argument(
            '--only-covered', action='store_true',
            help='remove features never covered from output'
        )
        parser.add_argument(
            '-p', '--properly-paired', action='store_true',
            help='properly paired reads statistics are used for filtering '
        )
        args = parser.parse_args()

        bams = args.i
        f_bams = args.f
        reads_count_output = args.reads_count
        coverage_output = args.bp_cov
        horizontal_coverage_output = args.hor_cov
        sep = args.sep
        min_baseq = args.mapping_quality
        dpb = args.depth_per_base
        only_covered = args.only_covered
        properly_paired = args.properly_paired
        is_genomes = args.genomes

    # Exactly one input source must be given: a list of BAM files or a file
    # listing them.
    if f_bams is not None and bams is not None:
        raise ValueError("input list and input file are mutually exclusive")
    elif f_bams is None and bams is None:
        raise ValueError("input list ({}) or input file ({}) is mandatory".format(bams, f_bams))

    if f_bams is not None:
        # Read one BAM path per line; diagnostics go to stderr because the
        # tables default to stdout.
        try:
            print("opening {}".format(f_bams), file=sys.stderr)
            with open(f_bams) as f:
                # Skip blank lines so they are not treated as BAM paths.
                infiles = [line.strip() for line in f if line.strip()]
        except OSError as err:
            print("OS error: {0}".format(err), file=sys.stderr)
            raise
        except TypeError as err:
            print("Type erro {0}".format(err), file=sys.stderr)
            raise
    else:
        infiles = list(bams)

    reads_count, bp_covered_count, horizontal_coverage = generate_tables(
        infiles, min_baseq, dpb, only_covered, properly_paired, is_genomes
    )

    # One table per statistic: rows are contigs/genomes, columns are samples.
    rc_table = pd.DataFrame(reads_count)
    bpcc_table = pd.DataFrame(bp_covered_count)
    hor_cov_table = pd.DataFrame(horizontal_coverage)

    # Both run modes define the same output names, and the requested
    # separator is honoured (tab by default).
    rc_table.to_csv(reads_count_output, sep=sep)
    bpcc_table.to_csv(coverage_output, sep=sep)
    hor_cov_table.to_csv(horizontal_coverage_output, sep=sep)
