import glob
import os
import sys

import snakemake.io

configfile: 'config/config.yaml'

# Discover the Hi-C inputs: every "<sample>_<read>.fastq.gz" file found under
# config['hic_dir'] contributes one sample name and one read label.
(SAMPLES, READS,) = glob_wildcards(
    os.path.join(config['hic_dir'], "{sample}_{read}.fastq.gz")
)

# De-duplicate, then sort: a plain set has no stable iteration order, and
# SAMPLES is fed to expand() below, so sorting keeps the order of the merge
# inputs identical from one invocation to the next.
SAMPLES = sorted(set(SAMPLES))
READS = sorted(set(READS))

# Diagnostic only. Written to stderr so that stdout stays clean when the
# workflow output is piped (e.g. `snakemake --dag | dot`).
print(SAMPLES, READS, file=sys.stderr)

# Default target: requests the two final files named after config['prefix'],
# the zoomified cooler (.dedup.mcool) and the Pretext contact map (.pretext).
rule all:
    input:
        mcool=config["prefix"] + ".dedup.mcool",
        pretext_map=config["prefix"] + ".pretext"

# Prepare the reference (config['ref']) for the rest of the workflow:
#   * samtools faidx  -> <ref>.fai
#   * <ref>.genome    -> first two columns of the .fai (name, length),
#                        tab-separated; used later as the chromosome-sizes file
#   * bwa-mem2 index  -> the .0123/.amb/.ann/.bwt.2bit.64/.pac index files
rule index_fasta:
    input:
        ref=config['ref']
    output:
        config['ref']+'.pac',
        config['ref']+'.0123',
        config['ref']+'.amb',
        config['ref']+'.ann',
        config['ref']+'.bwt.2bit.64',
        fai=config['ref']+'.fai',
        genome=config['ref']+'.genome'
    conda:
        'bwamem2.yaml'
    threads:
        1
    resources:
        mem_mb=100000
    shell:
        # `cut` replaces the former awk + sed pair: the .fai is already
        # tab-separated, so taking fields 1-2 yields the same two-column file
        # without needing `\$` / `\t` escapes inside a Python string.
        """
        samtools faidx {input.ref}
        cut -f1,2 {output.fai} > {output.genome}
        bwa-mem2 index {input.ref}
        """

# Align one sample's paired FASTQ files (<sample>_1 / <sample>_2 under
# config['hic_dir']) to the reference with bwa-mem2 and write an uncompressed
# SAM next to the reads.
rule map_reads:
    input:
        ref=config['ref'],
        # Declare every bwa-mem2 index file produced by index_fasta, not just
        # .pac, so a missing or stale index file is detected by Snakemake
        # instead of failing inside the aligner.
        ref_idx=[
            config['ref'] + ext
            for ext in ('.pac', '.0123', '.amb', '.ann', '.bwt.2bit.64')
        ],
        r1=os.path.join(config['hic_dir'],"{sample}_1.fastq.gz"),
        r2=os.path.join(config['hic_dir'],"{sample}_2.fastq.gz")
    output:
        out_sam=os.path.join(config['hic_dir'],"{sample}.bwa.sam")
    conda:
        'bwamem2.yaml'
    threads:
        24
    resources:
        mem_mb=100000
    shell:
        """
        bwa-mem2 mem -T0 -5SP -t {threads} {input.ref} {input.r1} {input.r2} -o {output.out_sam}
        """

# Combine the per-sample alignments from map_reads into a single SAM under
# mapped_reads/, named after config['prefix'].
rule merge_sams:
    input:
        expand(os.path.join(config['hic_dir'],"{sample}.bwa.sam"),sample=SAMPLES)
    output:
        merged_sam='mapped_reads/'+config['prefix']+'.bwa.sam'
    conda:
        'bwamem2.yaml'
    threads:
        6
    resources:
        mem_mb=100000
    shell:
        # -n: merge by read name rather than by coordinate. The inputs are raw
        # bwa-mem2 output (map_reads does no sorting), so a coordinate-order
        # merge can interleave records from different files between the two
        # mates of a pair. Comparing by name keeps records that share a read
        # name next to each other, which the downstream pairtools parse step
        # relies on.
        """
        samtools merge -n -@ {threads} -o {output.merged_sam} {input}
        """

# Stream the merged SAM into PretextMap to produce <prefix>.pretext
# (no sorting, MAPQ threshold 0, high-resolution mode).
rule create_pretextMap:
    input:
        merged_sam="mapped_reads/" + config["prefix"] + ".bwa.sam"
    output:
        pretext_map=config["prefix"] + ".pretext"
    conda: "pretext.yaml"
    threads: 1
    resources: mem_mb=20000
    shell:
        """
        cat {input.merged_sam} | PretextMap -o {output.pretext_map} --sortby nosort --mapq 0 --highRes
        """

# Run `pairtools parse` on the merged SAM, using <ref>.genome as the
# chromosome list. Writes the parsed .pairsam (via stdout redirection) and a
# statistics file.
rule pairtools_parse:
    input:
        merged_sam="mapped_reads/" + config["prefix"] + ".bwa.sam",
        chroms=config["ref"] + ".genome"
    output:
        parsed_stats="mapped_reads/" + config["prefix"] + ".parsed.stats",
        parsed_pairsam="mapped_reads/" + config["prefix"] + ".parsed.pairsam"
    conda: "pairtools.yaml"
    threads: 24
    resources: mem_mb=20000
    shell:
        """
        pairtools parse --output-stats {output.parsed_stats} --min-mapq 0 --walks-policy 5unique --max-inter-align-gap 30 --nproc-in {threads} --nproc-out {threads} --chroms-path {input.chroms} {input.merged_sam} > {output.parsed_pairsam}
        """

# Sort the parsed .pairsam with `pairtools sort`; temporary files go to
# config['tmp_dir'].
rule pairtools_sort:
    input:
        parsed_pairsam="mapped_reads/" + config["prefix"] + ".parsed.pairsam"
    output:
        sorted_pairsam="mapped_reads/" + config["prefix"] + ".sorted.pairsam"
    params:
        TMPDIR=config["tmp_dir"]
    conda: "pairtools.yaml"
    threads: 24
    resources: mem_mb=50000
    shell:
        """
        pairtools sort --nproc {threads} --tmpdir={params.TMPDIR} {input.parsed_pairsam} > {output.sorted_pairsam}
        """

# Run `pairtools dedup --mark-dups` on the sorted .pairsam, producing a
# .dedup.pairsam and a statistics file.
#
# NOTE(review): --output-dups and --output point at the same file, so the
# duplicate pairs appear to be written (marked) into the same .dedup.pairsam
# rather than dropped. Confirm this is intended, since the later rules build
# the contact matrix from this file.
rule pairtools_dedup:
    input:
        sorted_pairsam="mapped_reads/" + config["prefix"] + ".sorted.pairsam"
    output:
        dedup_pairsam="mapped_reads/" + config["prefix"] + ".dedup.pairsam",
        dedup_stats="mapped_reads/" + config["prefix"] + ".dedup.stats"
    conda: "pairtools.yaml"
    threads: 24
    resources: mem_mb=20000
    shell:
        """
        pairtools dedup --nproc-in {threads} --nproc-out {threads} --mark-dups --output-stats {output.dedup_stats} --output-dups {output.dedup_pairsam} --output {output.dedup_pairsam} {input.sorted_pairsam}
        """

# Split the .dedup.pairsam with `pairtools split` into its two halves:
# a .pairs file and an alignment file (.bam).
rule pairtools_split:
    input:
        dedup_pairsam="mapped_reads/" + config["prefix"] + ".dedup.pairsam"
    output:
        dedup_pairs="mapped_reads/" + config["prefix"] + ".dedup.pairs",
        dedup_bam="mapped_reads/" + config["prefix"] + ".dedup.bam"
    conda: "pairtools.yaml"
    threads: 24
    resources: mem_mb=20000
    shell:
        """
        pairtools split --nproc-in {threads} --nproc-out {threads} --output-pairs {output.dedup_pairs} --output-sam {output.dedup_bam} {input.dedup_pairsam}
        """

# Compress the .dedup.pairs file with bgzip and index the result with pairix,
# producing <prefix>.dedup.pairs.gz and its .px2 index.
rule index_pairs:
    input:
        dedup_pairs='mapped_reads/'+config['prefix']+'.dedup.pairs'
    output:
        dedup_pairs_gz='mapped_reads/'+config['prefix']+'.dedup.pairs.gz',
        dedup_pairs_idx='mapped_reads/'+config['prefix']+'.dedup.pairs.gz.px2'
    conda:
        'cooler.yaml'
    threads:
        1
    resources:
        mem_mb=10000
    shell:
        # bgzip -c writes to stdout and leaves the input in place. A bare
        # `bgzip <file>` compresses in place and removes <file>, i.e. it would
        # delete an output of pairtools_split behind Snakemake's back.
        # Trade-off: the uncompressed .pairs file stays on disk.
        """
        bgzip -c {input.dedup_pairs} > {output.dedup_pairs_gz}
        pairix {output.dedup_pairs_gz}
        """

# Load the pairix-indexed pairs into a single-resolution cooler, binning the
# chromosomes listed in <ref>.genome at 1000 bp.
rule cooler_cload:
    input:
        dedup_pairs_gz='mapped_reads/'+config['prefix']+'.dedup.pairs.gz',
        # `cooler cload pairix` reads the .px2 index that sits next to the
        # .pairs.gz; declaring it makes that dependency explicit even though
        # it is not named on the command line.
        dedup_pairs_idx='mapped_reads/'+config['prefix']+'.dedup.pairs.gz.px2',
        chroms=config['ref']+'.genome'
    output:
        cool='mapped_reads/'+config['prefix']+'.dedup.cool'
    conda:
        'cooler.yaml'
    threads:
        24
    resources:
        mem_mb=100000
    shell:
        """
        cooler cload pairix -p {threads} {input.chroms}:1000 {input.dedup_pairs_gz} {output.cool}
        """

# Run `cooler zoomify` on the single-resolution cooler to produce the final
# <prefix>.dedup.mcool requested by rule all.
rule cooler_zoomify:
    input:
        cool="mapped_reads/" + config["prefix"] + ".dedup.cool"
    output:
        mcool=config["prefix"] + ".dedup.mcool"
    conda: "cooler.yaml"
    threads: 24
    resources: mem_mb=10000
    shell:
        """
        cooler zoomify -p {threads} {input.cool} -o {output.mcool}
        """
