#!/usr/bin/env python3
# -*- coding: utf-8 -*-

import os
import pysam
import pysamstats
import numpy as np
import argparse
import sys
import collections
import pandas as pd
import gffutils
import sqlite3


class Sample(pysam.AlignmentFile):
    """Extension of the AlignmentFile class of pysam.

    Contains methods to compute coverage statistics for each reference
    present in the AlignmentFile (Nils Giordano).
    """

    @staticmethod
    def _new_stats():
        """Return a zero-initialised statistics record for one feature.

        The key order defines the column order of the tables built from
        the statistics dictionary, so it must stay stable.
        """
        return {
            'length': 0,
            'reads': 0,
            'bases_covered': 0,
            'total_coverage': 0,
            'reads_properly_paired': 0,
            'bases_covered_properly_paired': 0,
            'total_coverage_properly_paired': 0,
            'breath_coverage': 0,
            'breath_coverage_properly_paired': 0,
            'depth_avg_coverage': 0,
            'depth_avg_coverage_properly_paired': 0,
        }

    def _fill_stats(self, stats, covarray, length, N, contig,
                    start=None, stop=None):
        """Populate `stats` in place for one feature (contig or region).

        Parameters
        ----------
        stats : dict
            Record created by `_new_stats()`; updated in place.
        covarray : numpy array with two columns
            Per-base depth of the feature: column 0 for all reads, column 1
            for properly-paired reads only.
        length : int
            Feature length in bases, used as denominator for the breadth
            and average depth.
        N : int
            Minimum depth for a base to count as covered.
        contig : str
            Reference name forwarded to `self.count()`.
        start, stop : int or None
            0-based half-open interval forwarded to `self.count()`;
            None means the whole contig.
        """
        stats['length'] = length
        # bases covered at least N times (all reads / properly-paired reads)
        stats['bases_covered'] = (covarray[:, 0] >= N).sum()
        stats['bases_covered_properly_paired'] = (covarray[:, 1] >= N).sum()
        # total base coverage (all reads / properly-paired reads)
        stats['total_coverage'] = covarray[:, 0].sum()
        stats['total_coverage_properly_paired'] = covarray[:, 1].sum()
        # Count total number of reads
        stats['reads'] = self.count(
            contig=contig,
            start=start,
            stop=stop,
        )
        # Count total number of properly-paired reads
        stats['reads_properly_paired'] = self.count(
            contig=contig,
            start=start,
            stop=stop,
            read_callback=lambda read: read.is_proper_pair,
        )
        # Breadth is a percentage of the feature length, depth is a mean
        stats['breath_coverage'] = (
            stats['bases_covered'] / length * 100)
        stats['breath_coverage_properly_paired'] = (
            stats['bases_covered_properly_paired'] / length * 100)
        stats['depth_avg_coverage'] = (
            stats['total_coverage'] / length)
        stats['depth_avg_coverage_properly_paired'] = (
            stats['total_coverage_properly_paired'] / length)

    def reference_coverage(
        self,
        out=None,
        gff=None,
        min_baseq=13,
        N=1,
        only_covered=False,
        **kwargs):
        """Count base coverage for each contig present in the index.

        Any extra argument will be passed to `pysamstats.stat_coverage()`.

        Parameters
        ----------
        out : str (default None)
            If not None, path of a CSV file in which the statistics table
            (one row per feature, one column per statistic) is written.

        gff : str (default None)
            Location of the gff file containing the genomic regions for which
            the coverage will be computed. If None, coverage will be computed
            for each contig.

            Note: this can be a concatenation of gff from different genomes.
            It is way faster than launching metacovest for every single gff.

            A gffutils database is created next to the gff file
            (`<gff>.db`). It is left on disk and reused by later runs.

        min_baseq : int (default 13)
            Only reads with base quality equal to or greater than this value
            will be counted.

            Warning: If you put this value to 0, overlapping paired reads
            will be double-counted in the region of the overlap.
            See https://github.com/alimanfoo/pysamstats/issues/98

        N : int (default 1)
            Minimum coverage for a base to count for the `bases_covered`
            statistics.

        only_covered : bool (default False)
            Only include a feature (contig or genomic region) if its coverage
            is not null. This could save a lot of memory.

        Returns
        -------
        covstats_dict : dict
            Nested dictionary containing coverage information for each region.

            {
                region_name: {
                    "length": int,
                    "reads": int,
                    "reads_properly_paired": int,
                    "total_coverage": int,
                    "total_coverage_properly_paired": int,
                    "bases_covered": int,
                    "bases_covered_properly_paired": int,
                    "breath_coverage": float,
                    "breath_coverage_properly_paired": float,
                    "depth_avg_coverage": float,
                    "depth_avg_coverage_properly_paired": float,
                }
            }

            length: the feature length in bases
            reads: number of reads mapping on the feature
            total_coverage: sum of base coverage depth
            bases_covered: number of bases covered at least N times
            breath_coverage: percentage of the feature length covered by
                reads (breadth of coverage; the key keeps its historical
                spelling so existing consumers keep working)
            depth_avg_coverage: mean depth coverage across the feature
            *_properly_paired: same as above with properly paired reads only

            Note: When providing a gff file, region_name will be of the form
            `contigname_N` with N an increment (Nth region in contig). This
            mimics the way Prodigal names predicted proteins.
            If no gff file was provided, region_name will be the contig
            name (default behavior).

        Raises
        ------
        NameError
            If the id generated by gffutils for a region does not end with
            the same suffix as the region's `ID` attribute.
        """
        # Per-contig depth arrays and contig lengths from the BAM header
        coverage_dict = {}
        len_dict = dict(zip(self.references, self.lengths))

        # Final statistics, populated either per contig or per gff region.
        # A defaultdict lets us create zeroed records on first access.
        covstats_dict = collections.defaultdict(self._new_stats)

        # Iterator over every position reported by pysamstats
        bp_iterator = pysamstats.stat_coverage(
            self,
            min_baseq=min_baseq,
            one_based=False,         # Important since np.array use 0-indexing
            **kwargs,
        )

        for bp in bp_iterator:
            chrom = bp['chrom']
            covarray = coverage_dict.get(chrom)
            if covarray is None:
                # First position seen on this contig: allocate its array.
                # Column 0: all reads, column 1: properly-paired reads.
                covarray = coverage_dict[chrom] = np.zeros(
                    (len_dict[chrom], 2), dtype=np.uint32
                )
            pos = bp['pos']
            covarray[pos, 0] = bp['reads_all']
            covarray[pos, 1] = bp['reads_pp']

        # At this point coverage_dict only holds the contigs for which
        # pysamstats reported at least one position.

        if gff is None:
            # No gff provided: compute stats for every covered contig
            for contig, covarray in coverage_dict.items():
                self._fill_stats(
                    covstats_dict[contig],
                    covarray,
                    len_dict.get(contig),
                    N,
                    contig,
                )

            if not only_covered:
                # Load length values and also create entries for each remaining
                # contig (expected to have null coverage); the missing entries
                # are created because covstats_dict is a defaultdict
                for contig, length in len_dict.items():
                    covstats_dict[contig]['length'] = length
        else:
            # Create the gffutils DB next to the gff file. The id_spec
            # generates ids similar to the Prodigal protein names.
            try:
                db = gffutils.create_db(
                    gff,
                    gff + '.db',
                    id_spec=lambda d: 'autoincrement:' + d.seqid,
                    merge_strategy='error',
                )
            except sqlite3.OperationalError:
                # The DB file is already there: reuse it
                print('{} already exists, resuming...'.format(gff + '.db'))
                db = gffutils.FeatureDB(gff + '.db')

            # Iterate over every region and compute coverage
            for region in db.all_features():
                # Quick test to check if the names are coherent
                if (
                    region.attributes['ID'][0].split('_')[-1] !=
                    region.id.split('_')[-1]
                ):
                    raise NameError(
                        'Something went wrong when loading the gff file. '
                        'The generated feature.id might not match with '
                        'the protein name used by Prodigal. Aborting. '
                        'This is very peculiar and should be replaced by '
                        'unittests later. '
                    )

                # gff coordinates are 1-based and inclusive on both ends,
                # which maps to the 0-based half-open interval
                # [start - 1, end) used by numpy slices and pysam.
                start0 = region.start - 1
                stop0 = region.end

                if region.seqid not in coverage_dict:
                    if only_covered:
                        # we do not include this region in the dictionary
                        continue
                    # Contig without any coverage: an empty array yields
                    # zero for every sum (valid for N >= 1) at no memory cost
                    covarray = np.zeros((0, 2), dtype=np.uint32)
                else:
                    covarray = coverage_dict[region.seqid][start0:stop0, :]

                self._fill_stats(
                    covstats_dict[region.id],
                    covarray,
                    region.end - region.start + 1,
                    N,
                    region.seqid,
                    start=start0,
                    stop=stop0,
                )

        # Go back to classic dict structure (safer for Exception handling)
        covstats_dict = dict(covstats_dict)
        if out is not None:
            # One row per feature, one column per statistic
            df = pd.DataFrame.from_dict(covstats_dict, orient='index')
            df.to_csv(out)
        # Return resulting dictionary
        return covstats_dict
    
    

    
def ensure_parent_dir(outfile):
    """Create the parent directory of `outfile` when there is one to create.

    Nothing is done when `outfile` is not a path string (e.g. the
    `sys.stdout` default of the command line) or when it has no directory
    component (e.g. "stats.tsv"), because `os.makedirs('')` would raise.
    """
    if not isinstance(outfile, str):
        return
    outdir = os.path.dirname(outfile)
    if outdir:
        os.makedirs(outdir, exist_ok=True)


def main():
    """Compute coverage statistics for a BAM file and write them as a table.

    Settings are read from the `snakemake` object when the script is run
    through a Snakemake `script:` directive, and from the command line
    otherwise.
    """
    try:
        smk = snakemake  # noqa: F821 -- injected by Snakemake when present
    except NameError:
        smk = None

    if smk is not None:
        bam = smk.input.align
        gff = None
        sep = "\t"
        outfile = smk.output.stats
        only_covered = smk.params.only_covered
        # The Snakemake parameter is named `min_baseq` but, like the
        # command-line `--mapping-quality` option, it is forwarded as the
        # minimum mapping quality.
        min_mapq = smk.params.min_baseq
        depth_per_base = smk.params.depth_per_base
        no_label = False
    else:
        # parsing script arguments
        parser = argparse.ArgumentParser(
            prog='metacovest',
            description='Estimate coverage depth and breadth from indexed '
            'BAM file and filter reference assembly.')
        parser.add_argument(
            'infile', type=str,
            help='input indexed BAM file')
        parser.add_argument(
            '-g', '--gff', nargs='?', type=str,
            default=None,
            help='(Optional) input gff file for coverage over some regions')
        parser.add_argument(
            '--outfile', nargs='?',
            default=sys.stdout,
            help='statistics csv outfile (default: stdout)')
        parser.add_argument(
            '-s', '--sep', nargs='?', type=str,
            default="\t",
            help='separator for csv file (default: tab)')
        parser.add_argument(
            '-q', '--mapping-quality', nargs='?', type=int,
            default=20,
            help='minimum mapping quality of counted reads (default: 20)')
        parser.add_argument(
            '--depth-per-base', nargs='?', type=int,
            default=20,
            help='minimum depth per base')
        parser.add_argument(
            '--no-label', action='store_true',
            help='remove column label from output')
        parser.add_argument(
            '--only-covered', action='store_true',
            help='remove features never covered from output')

        args = parser.parse_args()

        bam = args.infile
        gff = args.gff
        sep = args.sep
        outfile = args.outfile
        only_covered = args.only_covered
        min_mapq = args.mapping_quality
        depth_per_base = args.depth_per_base
        no_label = args.no_label

    # Only meaningful for real paths; stdout and bare filenames are skipped
    ensure_parent_dir(outfile)

    with Sample(bam, 'rb') as sample:
        # Parse file to get coverage for each contig (or gff region).
        # `min_mapq` travels through **kwargs to pysamstats.stat_coverage().
        coverage_dict = sample.reference_coverage(
            gff=gff,
            min_mapq=min_mapq,
            N=depth_per_base,       # threshold for the bases_covered stats
            only_covered=only_covered,
        )
        # Convert into a DataFrame
        coverage_df = pd.DataFrame.from_dict(coverage_dict, orient='index')
        # Sort by contig name
        coverage_df.sort_index(inplace=True)
        # Save to a csv file (or stream), honouring --no-label
        coverage_df.to_csv(
            outfile,
            sep=sep,
            na_rep='NA',
            header=not no_label,
        )


if __name__ == '__main__':
    main()