#!/usr/bin/env python3
# -*- coding: utf-8 -*-



import os
import sys
import argparse
import re
import pandas as pd
import numpy as np
import matplotlib; matplotlib.use('agg')
import matplotlib.pyplot as plt;

from Bio import SeqIO,SeqRecord
from pathlib import Path
from tqdm import tqdm




def contigs_per_tax(bin_fasta,df):
    """Group the contigs of a bin by the lineage assigned to them.

    Parameters
    ----------
    bin_fasta
        FASTA file of the bin, read with ``SeqIO.parse(bin_fasta, "fasta")``.
    df : pandas.DataFrame
        Contig summary table indexed by contig id, with a ``lineage`` column.

    Returns
    -------
    tuple
        ``(taxo_dict, bin_total_length)``. ``taxo_dict`` maps every lineage
        present in ``df`` to the list of records of this bin carrying that
        lineage (the list is empty for lineages absent from the bin);
        ``bin_total_length`` is the summed sequence length of the bin.
        Each record's description gets `` taxonomy=<lineage>`` appended.

    Raises
    ------
    ValueError
        If a contig of the bin cannot be resolved to a lineage of ``df``.
    """
    # Every lineage of the table gets a bucket, even if this bin has no
    # contig for it, so all bins share the same set of keys.
    taxo_dict = {taxo: [] for taxo in set(df["lineage"])}
    bin_total_length = 0
    for record in SeqIO.parse(bin_fasta, "fasta"):
        # Only the lookups are guarded: KeyError means the contig id (or its
        # lineage) is unknown, TypeError means the lookup did not yield a
        # single hashable lineage.
        try:
            contigs_taxo = df.loc[record.id].lineage
            bucket = taxo_dict[contigs_taxo]
        except (KeyError, TypeError) as err:
            raise ValueError("Contigs %s from %s not found in %s" % ( record.id , bin_fasta , "contigs summary")) from err
        record.description = str(record.description)+" taxonomy="+str(contigs_taxo)
        bucket.append(record)
        bin_total_length += len(record.seq)

    return taxo_dict, bin_total_length


def filteroutcontigs(bin_fasta,df,out):
    """Keep the contigs of a bin that agree with its majority lineage.

    The consensus lineage is the one carried by the largest number of contigs
    (the first one encountered in case of a tie). Contigs with exactly that
    lineage, or with a lineage starting with it, are written to ``out`` in
    ``fasta-2line`` format; all other contigs are dropped.

    Parameters
    ----------
    bin_fasta : str
        Path to the FASTA file of the bin.
    df : pandas.DataFrame
        Contig summary table, passed as-is to ``contigs_per_tax``.
    out
        Destination handed to ``SeqIO.write`` (file path or open handle).

    Returns
    -------
    dict
        ``{bin_name: stats}`` where ``bin_name`` is the file name without its
        last extension and ``stats`` holds:
        ``composition`` (contig count per lineage), ``length`` (total bin
        length), ``contigs`` (contig count), ``contigs_out`` (contigs
        removed), ``contigs_consensus`` (contigs with exactly the consensus
        lineage), ``contamination`` (percentage of contigs removed, 0.0 for
        an empty bin) and ``taxonomy`` (consensus lineage, ``None`` when the
        table holds no lineage at all).
    """
    per_taxa , bin_total_length = contigs_per_tax(bin_fasta,df)

    counts_per_taxa_dict = {taxo: len(recs) for taxo, recs in per_taxa.items()}
    n_contigs = sum(counts_per_taxa_dict.values())

    # max() returns the first key reaching the highest count; default=None
    # covers a table without any lineage.
    taxonomy = max(counts_per_taxa_dict, key=counts_per_taxa_dict.get, default=None)
    n_contigs_maj_taxo = counts_per_taxa_dict.get(taxonomy, 0)

    # Start from a copy so the per-lineage lists are left untouched.
    records = list(per_taxa.get(taxonomy, []))
    if taxonomy is not None:
        # Literal prefix comparison rather than a regex, so characters such
        # as '(' or '.' in a lineage cannot be misread as a pattern.
        # NOTE(review): assumes a lineage that starts with the consensus
        # lineage is a refinement of it -- confirm against the lineage format.
        consensus_prefix = str(taxonomy)
        for taxo, taxo_records in per_taxa.items():
            if taxo != taxonomy and str(taxo).startswith(consensus_prefix):
                records.extend(taxo_records)

    SeqIO.write(records,out,"fasta-2line")

    filename = ".".join(os.path.basename(bin_fasta).split(".")[0:-1])
    contigsout = n_contigs - len(records)

    # Share of contigs that were filtered out; guarded for empty bins.
    contamination = (contigsout / n_contigs) * 100 if n_contigs else 0.0

    return {filename : {"composition":counts_per_taxa_dict,
                        "length": bin_total_length,
                        "contigs":n_contigs,
                        "contigs_out":contigsout,
                        "contigs_consensus": n_contigs_maj_taxo,
                        "contamination":contamination,
                        "taxonomy":taxonomy
                        }}


def plot_and_stats(bin_dict,out):
    """Write per-bin statistics and summary plots into directory ``out``.

    Parameters
    ----------
    bin_dict : dict
        Maps each bin name to a dict providing the keys ``composition``,
        ``length``, ``contigs``, ``contigs_out``, ``contigs_consensus``,
        ``contamination`` and ``taxonomy``.
    out : str
        Output directory, created if it does not exist.

    Files written to ``out``
    ------------------------
    bin_stats.csv                  tab-separated table of all keys but ``composition``
    bins_contam.pdf                stacked horizontal bars of ``composition`` per bin
    bins_Ncontigs2contam.pdf       scatter of ``contigs`` vs ``contamination``
    bins_length2contam.pdf         scatter of ``length`` vs ``contamination``
    bins_Ncontigs2contigsout.pdf   scatter of ``contigs`` vs ``contigs_out``
    """
    os.makedirs(out, exist_ok=True)
    bin_df = pd.DataFrame(bin_dict).transpose()

    # One row per bin, one column per lineage.
    bin_contamination = pd.DataFrame(bin_df["composition"].to_dict()).transpose()

    # save some stats
    bin_stats = bin_df[['length', 'contigs', 'contigs_out', 'contigs_consensus',
                        'contamination', 'taxonomy']]
    bin_stats.to_csv(os.path.join(out,"bin_stats.csv"),sep="\t",header=True,index=True)

    def _save_and_close(ax, filename):
        # Save the figure owning ``ax`` and release it, so figures do not
        # pile up in memory across calls.
        fig = ax.get_figure()
        fig.savefig(os.path.join(out, filename))
        plt.close(fig)

    # save stacked barchart
    _save_and_close(
        bin_contamination.plot(kind="barh", stacked=True , legend=False, fontsize = 3),
        "bins_contam.pdf")

    # Each scatter plot goes to its own file so none overwrites another.
    scatter_plots = (
        ('contigs', 'contamination', "bins_Ncontigs2contam.pdf"),
        ('length', 'contamination', "bins_length2contam.pdf"),
        ('contigs', 'contigs_out', "bins_Ncontigs2contigsout.pdf"),
    )
    for x_col, y_col, filename in scatter_plots:
        _save_and_close(
            bin_stats.plot.scatter(x=x_col, y=y_col, c='DarkBlue', rot=45),
            filename)




if __name__ == "__main__":
    # Entry point: settings come from the ``snakemake`` object when the script
    # is run by Snakemake, otherwise from the command line.
    try :
        fasta = str(snakemake.params.bins)
        allcontigs = str(snakemake.input[1])
        outdir = str(snakemake.output)
        outdir_stats = str(snakemake.params.stats)
        sep = "\t"
    except NameError:
        # ``snakemake`` is not defined: plain command-line invocation.
        parser = argparse.ArgumentParser(
            prog='filtermags',
            description='filter out contigs from bin for which assignment is not relevant')
        parser.add_argument(
            'bin', type=str,
            help='(required) bin file or bins directory (fasta format) ')
        parser.add_argument(
            'contigs', type=str,
            help='(required) contigs assignement file (contigs id : contigs taxon)')
        parser.add_argument(
            '--sep', type=str,
            default = "\t",
            help='(optional) contigs file separator)')
        parser.add_argument(
            '-o','--out', type=str,
            default=None,
            help='(optional) Output directory where filtered bin will be saved, default stdout')
        parser.add_argument(
            '-s','--stats', type=str, default=None,
            help='(if <bin> is a directory) , directory where bin statistics and plot will be saved (length, #contigs ...)')
        args = parser.parse_args()

        fasta=args.bin
        allcontigs=args.contigs
        outdir = args.out
        sep = args.sep
        outdir_stats = args.stats

    # load contigs file: first column is used as the index (contig id)
    df = pd.read_csv(allcontigs,header=0,sep=sep,index_col=0)

    if os.path.isdir(fasta):
        # Only regular files are bins; sorting makes the processing order
        # (and the row order of the statistics) reproducible.
        candidates = (os.path.join(fasta,i) for i in os.listdir(fasta))
        files = sorted(path for path in candidates if os.path.isfile(path))
    else:
        files = [fasta]

    # Without an output directory the filtered bins go to stdout, so there is
    # nothing to create.
    if outdir is not None:
        os.makedirs(outdir, exist_ok=True)

    bin_dict = {}
    for f_mags in tqdm(files):
        if outdir is None:
            outfile = sys.stdout
        else:
            # Bin file name without its last extension.
            bin_name = ".".join(os.path.basename(f_mags).split(".")[0:-1])
            outfile = os.path.join(outdir,"{}_filtered.fa".format(bin_name))
        bin_dict.update(filteroutcontigs(f_mags,df,outfile))

    if outdir_stats is not None:
        plot_and_stats(bin_dict, outdir_stats)

    