# -*- coding: UTF-8 -*-

author: Kristian K Ullrich
date: July 2017
License: MIT

The MIT License (MIT)

Copyright (c) 2017 Kristian K Ullrich

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

This software relies on the following python packages: numpy, pyinterval (module `interval`), pytabix (module `tabix`) and biopython (package `Bio`).

import sys
import os
import re
import numpy as np
np.seterr(divide='ignore', invalid='ignore')
from interval import interval, imath, inf
import random
import argparse
import tabix
import gzip
import itertools
from itertools import repeat
from Bio import SeqIO
from Bio.Data import CodonTable as CT
from Bio import SeqRecord
from Bio.Seq import Seq
from Bio.Alphabet import IUPAC
from Bio.SeqUtils import GC
import collections

class vcfRecord:
    """One data line of a VCF file, split into its columns.

    Fixed columns are stored as attributes: CHROM, POS, ID, QUAL, FILTER and
    INFO as strings; REF as a one-element list; ALT and FORMAT as lists.
    Per-sample FORMAT values live in self.SAMPLES, an OrderedDict keyed by the
    sample's column index. Each entry maps FORMAT key -> value, which is a list
    when the raw value contains ',' and otherwise a plain string.

    NOTE(review): several method bodies in this copy of the file are missing
    lines (marked NOTE(review) below); as written the class does not parse.
    """
    def __init__(self, record, CHROM, POS, ID, REF, ALT, QUAL, FILTER, INFO, FORMAT, SAMPLES, tabix):
        """Parse *record* into the attributes described on the class.

        record -- a raw tab-separated VCF line when *tabix* is false, or an
            already split sequence of fields when *tabix* is true.
        CHROM, POS, ID, REF, ALT, QUAL, FILTER, INFO, FORMAT -- column indices
            of the fixed VCF columns within the split record.
        SAMPLES -- iterable of column indices of the samples to parse.
        tabix -- selects the input form of *record*. Inside this method it
            shadows the module-level `tabix` import.
        """
        if not tabix:
            self.splitted = record.strip().split('\t')
        if tabix:
            self.splitted = record
        self.CHROM = self.splitted[CHROM]
        self.POS = self.splitted[POS]
        self.ID = self.splitted[ID]
        # REF is wrapped in a list so that REF+ALT (get_ref_alt) is one allele list
        self.REF = [self.splitted[REF]]
        self.ALT = self.splitted[ALT].split(',')
        self.QUAL = self.splitted[QUAL]
        self.FILTER = self.splitted[FILTER]
        self.INFO = self.splitted[INFO]
        self.FORMAT = self.splitted[FORMAT].split(':')
        self.SAMPLES = collections.OrderedDict()
        for s in SAMPLES:
            self.SAMPLES[s] = collections.OrderedDict()
            if len(self.splitted[s].split(':')) != len(self.FORMAT):
                # NOTE(review): the body of this branch (sample field count
                # differs from FORMAT) is missing in this copy -- TODO restore.
            for p,f in enumerate(self.FORMAT):
                self.SAMPLES[s][f] = self.splitted[s].split(':')[p].split(',')
                # single values are unwrapped to a plain string
                if len(self.SAMPLES[s][f]) == 1:
                    self.SAMPLES[s][f] = self.SAMPLES[s][f][0]
    def get_var_len(self):
        """ Return the string length of every allele, REF first, then each ALT """
        var_len = [len(x) for x in self.REF]
        var_len += [len(x) for x in self.ALT]
        return var_len
    def get_var_number(self):
        """ Return total number of alleles (one REF plus every ALT entry) """
        # len([self.REF]) is always 1 because self.REF is already a one-element list
        return len([self.REF])+len(self.ALT)
    def get_ref_alt(self):
        """ Return all alleles as one list, REF first """
        REF_ALT = self.REF+self.ALT
        return REF_ALT
    def is_filtered(self, argsDICT):
        """Return True if the record should be filtered out.

        That is the case when the record's FILTER value is NOT among
        argsDICT['filterarg']. A filterarg of exactly [''] disables filtering
        and always yields False.
        """
        if len(argsDICT['filterarg'])!=1:
            if self.FILTER not in argsDICT['filterarg']:
                return True
            if self.FILTER in argsDICT['filterarg']:
                return False
        if len(argsDICT['filterarg']) == 1 and argsDICT['filterarg'][0]!='':
            if self.FILTER not in argsDICT['filterarg']:
                return True
            if self.FILTER in argsDICT['filterarg']:
                return False
        if len(argsDICT['filterarg']) == 1 and argsDICT['filterarg'][0] == '':
            return False
    def get_var_type(self):
        """Classify the variant by allele lengths (categories modelled on GATK).

        Returns one of:
          'noVAR'      -- ALT is '.'
          'complexVAR' -- ALT is '*'
          bi-allelic: 'biSNP' (both length 1), 'repSNP' (equal length > 1),
            'deletion' (REF longer) or 'insertion' (REF shorter)
          multi-allelic: 'muSNP' (all length 1), 'mudeletion' / 'muinsertion'
            (every ALT shorter / longer than REF), or 'complexDS',
            'complexIS', 'complexDI', 'complexDIS' (ALT lengths mixing
            shorter (D), longer (I) and same (S) relative to REF)
          'NA' when no rule matches (multi-allelic, all lengths equal and > 1).
        In the multi-allelic section later checks overwrite earlier ones, so
        the most specific complex category wins.
        """
        var_type = 'NA'
        var_len = self.get_var_len()
        var_number = self.get_var_number()
        if len(self.ALT) == 1 and self.ALT[0] == '.':
            var_type = 'noVAR'
            return var_type
        if len(self.ALT) == 1 and self.ALT[0] == '*':
            var_type = 'complexVAR'
            return var_type
        if var_number == 2:
            if var_len[0] == 1 and var_len[1] == 1:
                #equals GATK SNP
                var_type = 'biSNP'
            if var_len[0] > 1 and var_len[1] > 1 and var_len[0] == var_len[1]:
                #equals GATK MNP
                var_type = 'repSNP'
            if var_len[0] > var_len[1]:
                var_type = 'deletion'
            if var_len[0] < var_len[1]:
                var_type = 'insertion'
            return var_type
        #equals GATK MULTI-ALLELIC
        if var_number != 2:
            if all( [x == var_len[0] for x in var_len[1::]] ) and all( [x == 1 for x in var_len] ):
                #equals GATK SNP
                var_type = 'muSNP'
            #TODO 'murepSNP'
            if all( [x < var_len[0] for x in var_len[1::]] ):
                var_type = 'mudeletion'
            if all( [x > var_len[0] for x in var_len[1::]] ):
                var_type = 'muinsertion'
            if any( [x < var_len[0] for x in var_len[1::]] ) and any( [x == var_len[0] for x in var_len[1::]] ):
                var_type = 'complexDS'
            if any( [x > var_len[0] for x in var_len[1::]] ) and any( [x == var_len[0] for x in var_len[1::]] ):
                var_type = 'complexIS'
            if any( [x < var_len[0] for x in var_len[1::]] ) and any( [x > var_len[0] for x in var_len[1::]] ):
                var_type = 'complexDI'
            if any( [x < var_len[0] for x in var_len[1::]] ) and any( [x > var_len[0] for x in var_len[1::]] ) and any( [x == var_len[0] for x in var_len[1::]] ):
                var_type = 'complexDIS'
            return var_type
    def keep_samples(self, argsDICT):
        """ Keep only samples in record which are listed in argsDICT['samples_pos'] """
        # iterates over a copy, presumably so entries can be deleted during the loop
        for s in self.SAMPLES.copy():
            if s not in argsDICT['samples_pos']:
                # NOTE(review): body missing in this copy -- presumably removes
                # self.SAMPLES[s]; TODO restore.
    def get_info(self, value):
        """Return the value of the first INFO entry whose key equals *value*.

        Raises IndexError when the key is absent, or when it is a flag entry
        that has no '='.
        """
        return [x[1] for x in [x.split('=') for x in self.INFO.split(';')] if x[0] == value][0]
    def add_info(self, value):
        """ Add INFO value """
        # NOTE(review): no statement beyond the docstring is present in this copy,
        # so this method currently does nothing -- TODO restore.
    def add_format_on_samples(self, s, f, v):
        """ Add new FORMAT field f with value v to sample s """
        self.SAMPLES[s][f] = v
    def replace_format_on_samples(self, s, f, v):
        """ Replace FORMAT field f of sample s with v (same operation as add_format_on_samples) """
        self.SAMPLES[s][f] = v
    def get_samples_ad(self):
        """ Return samples AD as list """
        return [self.SAMPLES[x]['AD'] for x in self.SAMPLES]
    def get_samples_gt(self):
        """ Return samples GT as list """
        return [self.SAMPLES[x]['GT'] for x in self.SAMPLES]
    def get_samples_gq(self):
        """ Return samples GQ as list of int; a missing GQ ('.') counts as 0 """
        return [0 if self.SAMPLES[x]['GQ'] == "." else int(self.SAMPLES[x]['GQ']) for x in self.SAMPLES]
    def get_samples_gp(self):
        """ Return samples GP as list of float lists; '.' entries count as 0 """
        return [[0 if y == "." else float(y) for y in self.SAMPLES[x]['GP']] for x in self.SAMPLES]

    def get_samples_dp(self):
        """ Return samples DP as list of int; a missing DP ('.') counts as 0 """
        return [0 if self.SAMPLES[x]['DP'] == "." else int(self.SAMPLES[x]['DP']) for x in self.SAMPLES]
    def get_samples_pl(self):
        """ Return samples PL as list """
        return [self.SAMPLES[x]['PL'] for x in self.SAMPLES]
    def convert_pl_gp(self):
        """Convert each sample's PL values into weights that sum to 1.

        Returns one list per sample, in SAMPLES order.
        """
        # NOTE(review): PL is phred-scaled (PL = -10*log10(L)), which converts back
        # with 10**(-PL/10); this uses exp(-PL/10) -- confirm this is intended.
        return [list([np.exp(float(x)/-10) for x in y]/np.sum([np.exp(float(x)/-10) for x in y])) for y in self.get_samples_pl()]
    def get_samples_ad_len(self):
        """ Return samples AD len as list """
        return [len(self.SAMPLES[x]['AD']) for x in self.SAMPLES]
    def get_samples_ad_ref(self):
        """ Return samples AD REF (first AD entry) as list """
        return [self.SAMPLES[x]['AD'][0] for x in self.SAMPLES]
    def get_samples_ad_alt(self):
        """ Return samples AD ALT (all AD entries after the first) as list """
        return [self.SAMPLES[x]['AD'][1:] for x in self.SAMPLES]
    def gt_type(self, x):
        """Classify one split genotype *x* (e.g. ['0', '1']).

        Returns 'NA' (missing), 'hap' (single called allele), 'hom' or 'het'.
        Mixed calls such as ['.', '1'] match no rule and return None.
        """
        if len(x)==1:
            if x[0] == '.':
                return 'NA'
            if x[0] != '.':
                return 'hap'
        if x[0] == '.' and x[1] == '.':
            return 'NA'
        if x[0] != '.' and x[1] != '.' and x[0] == x[1]:
            return 'hom'
        if x[0] != '.' and x[1] != '.' and x[0] != x[1]:
            return 'het'
    def get_gt_type(self):
        """ Return GT type of every sample; GT is split on '|' or '/' """
        gt_split = re.compile("[|/]")
        GTs = [re.split(gt_split,x) for x in self.get_samples_gt()]
        return [self.gt_type(x) for x in GTs]
    def get_gt_na_pos(self):
        """ Return positions (in SAMPLES order) of samples whose GT is NA """
        samplesgttype = self.get_gt_type()
        samplesgtnapos = [x for x,y in enumerate(samplesgttype) if y == 'NA']
        return samplesgtnapos
    def fill_empty_ad(self):
        """ Replace "." with 0 in each AD, right-pad with 0 to get_var_number() entries, store as int lists """
        x_ = [x+list(np.repeat(0,self.get_var_number()-len(x))) for x in [[int(x.replace(".","0")) for x in y] for y in self.get_samples_ad()]]
        for x,y in zip(x_, self.SAMPLES.keys()):
            self.SAMPLES[y]['AD'] = x
    def fill_empty_pl(self):
        """Replace "." with 0 in each PL and right-pad with 0.

        The padded length is nCk(get_var_number()+1, 2), the number of unordered
        allele pairs. Values are stored back as int lists.
        """
        x_ = [x+list(np.repeat(0,nCk(self.get_var_number()+1,2)-len(x))) for x in [[int(x.replace(".","0")) for x in y] for y in self.get_samples_pl()]]
        for x,y in zip(x_, self.SAMPLES.keys()):
            self.SAMPLES[y]['PL'] = x
    def set_samples_ad_by_gt(self, argsDICT):
        """Reset AD of samples whose GT is missing, according to argsDICT['missingGT'].

        argsDICT['missingGT'] appears to be the list built by get_missing_type().
        Its first element selects the mode:
          'keep' -- branch body missing in this copy
          'zero' -- all-zero AD
          'set'  -- [missingGT[1]] + [missingGT[2]] * (number of alleles - 1)
        """
        if argsDICT['missingGT'][0] == 'keep':
            # NOTE(review): body missing in this copy -- TODO restore.
        if argsDICT['missingGT'][0] == 'zero':
            samplesgtnapos = self.get_gt_na_pos()
            #set AD to zero for all samples with GT equals NA
            # NOTE: positional indexing of keys() only works on Python 2 (list), not on Python 3 views
            for s in self.get_gt_na_pos():
                self.replace_format_on_samples(self.SAMPLES.keys()[s], 'AD', list(np.repeat(0, self.get_var_number())))
        if argsDICT['missingGT'][0] == 'set':
            samplesgtnapos = self.get_gt_na_pos()
            #set AD to certain value for all samples with GT equals NA
            for s in self.get_gt_na_pos():
                self.replace_format_on_samples(self.SAMPLES.keys()[s], 'AD', [argsDICT['missingGT'][1]]+list(np.repeat(argsDICT['missingGT'][2], self.get_var_number()-1)))
    def set_dp_by_ad(self):
        """ Set each sample's DP to the sum of its AD values """
        samplesad = self.get_samples_ad()
        samplesadsum = [np.sum([int(x) for x in x]) for x in samplesad]
        #set DP to ADSUM
        for i in range(0, len(self.SAMPLES)):
            self.replace_format_on_samples(self.SAMPLES.keys()[i], 'DP', samplesadsum[i])
    def set_samples_gp_by_gt(self, argsDICT):
        """Store GP (computed by convert_pl_gp) on every sample, handling missing GT.

        argsDICT['missingGT'][0] selects the mode:
          'keep'  -- GP from PL for all samples
          'equal' -- samples with missing GT are meant to get fixed values
          'set'   -- as 'equal'
        For 'equal' and 'set' the statement that sets those values is missing
        in this copy.
        """
        samplespl = self.get_samples_pl()
        samplesgp = self.convert_pl_gp()
        if argsDICT['missingGT'][0] == 'keep':
            for i in range(0, len(self.SAMPLES)):
                self.replace_format_on_samples(self.SAMPLES.keys()[i], 'GP', samplesgp[i])
        if argsDICT['missingGT'][0] == 'equal':
            samplesgtnapos = self.get_gt_na_pos()
            #set GP to 0.333 for all samples with GT equals NA
            for s in self.get_gt_na_pos():
                # NOTE(review): loop body missing in this copy -- TODO restore.
            for i in range(0, len(self.SAMPLES)):
                self.replace_format_on_samples(self.SAMPLES.keys()[i], 'GP', samplesgp[i])
        if argsDICT['missingGT'][0] == 'set':
            samplesgtnapos = self.get_gt_na_pos()
            #set GP to certain value for all samples with GT equals NA
            for s in self.get_gt_na_pos():
                # NOTE(review): loop body missing in this copy -- TODO restore.
            for i in range(0, len(self.SAMPLES)):
                self.replace_format_on_samples(self.SAMPLES.keys()[i], 'GP', samplesgp[i])

    def set_samples_gp_by_dp(self, argsDICT):
        """Handle samples with DP outside [minDP, maxDP], then write GP back to every sample."""
        samplesdp = self.get_samples_dp()
        samplesgp = self.get_samples_gp()
        #get NotCalled positions
        minDP_NCpos = [x for x,y in enumerate(self.get_samples_dp()) if y < argsDICT['minDP']]
        maxDP_NCpos = [x for x,y in enumerate(self.get_samples_dp()) if y > argsDICT['maxDP']]
        #combine and unique NotCalled positions
        NCpos = [x for x in set(minDP_NCpos+maxDP_NCpos)]
        for x in NCpos:
            # NOTE(review): loop body missing in this copy -- TODO restore.
        for i in range(0, len(self.SAMPLES)):
            self.replace_format_on_samples(self.SAMPLES.keys()[i], 'GP', samplesgp[i])
    def set_samples_gp_by_gq(self, argsDICT):
        """Handle samples with GQ below minGQ, then write GP back to every sample."""
        samplesgq = self.get_samples_gq()
        samplesgp = self.get_samples_gp()
        #get NotCalled positions
        minGQ_NCpos = [x for x,y in enumerate(self.get_samples_gq()) if y < argsDICT['minGQ']]
        #combine and unique NotCalled positions
        NCpos = [x for x in set(minGQ_NCpos)]
        for x in NCpos:
            # NOTE(review): loop body missing in this copy -- TODO restore.
        for i in range(0, len(self.SAMPLES)):
            self.replace_format_on_samples(self.SAMPLES.keys()[i], 'GP', samplesgp[i])
    def set_samples_gt_by_ad(self, argsDICT):
        """Recompute GT from AD when argsDICT['resetGT'] == 'AD'.

        For each sample, the REF depth and the largest ALT depth are passed to
        evaluate_ref_alt_max(). Samples with DP < minDP, DP > maxDP or
        GQ < minGQ are collected as not-called; the loop handling them is
        missing in this copy.
        """
        if argsDICT['resetGT'] == 'keep':
            # NOTE(review): body missing in this copy -- TODO restore.
        if argsDICT['resetGT'] == 'AD':
            samplesad_ref = self.get_samples_ad_ref()
            samplesad_alt = self.get_samples_ad_alt()
            #NOTE np.argmax return the first element for equal values which would bias GT setting always to the first ALT entry if equal depth exist
            #NOTE this is only used to determine NotCalledFraction, AD for other alleles will be considered
            samplesad_alt_maxpos = [np.argmax(x) for x in samplesad_alt]
            samplesad_alt_max = [np.max(x) for x in samplesad_alt]
            samplesad_ref_alt_max = map(lambda a, b: [a,b], samplesad_ref, samplesad_alt_max)
            samplesad_ref_alt_maxpos = map(lambda a, b: [a,b], list(np.repeat(0,len(samplesad_ref))), [x+1 for x in samplesad_alt_maxpos])
            samplesgt = []
            for a, b in zip(samplesad_ref_alt_max, samplesad_ref_alt_maxpos):
                samplesgt.append(evaluate_ref_alt_max(a, b))
            #get NotCalled positions
            minDP_NCpos = [x for x,y in enumerate(self.get_samples_dp()) if y < argsDICT['minDP']]
            maxDP_NCpos = [x for x,y in enumerate(self.get_samples_dp()) if y > argsDICT['maxDP']]
            minGQ_NCpos = [x for x,y in enumerate(self.get_samples_gq()) if y < argsDICT['minGQ']]
            #combine and unique NotCalled positions
            NCpos = [x for x in set(minDP_NCpos+maxDP_NCpos+minGQ_NCpos)]
            for x in NCpos:
                # NOTE(review): loop body missing in this copy -- TODO restore.
            for i in range(0, len(self.SAMPLES)):
                self.replace_format_on_samples(self.SAMPLES.keys()[i], 'GT', samplesgt[i])
    def get_consensus_calldata(self, argsDICT):
        """Merge all samples into one FORMAT dict.

        Every FORMAT key starts as None. Then:
          AD -- per-allele sum over samples whose GT is not missing
          DP -- sum of the merged AD
          GQ -- median GQ
          PL -- per-index median PL
          GT -- derived from the merged AD
        """
        consensusdict = collections.OrderedDict()
        for rf in self.FORMAT:
            consensusdict[rf] = None
        consensusdict['AD'] = self.merge_ad()
        consensusdict['DP'] = np.sum(consensusdict['AD'])
        consensusdict['GQ'] = self.merge_gq()
        consensusdict['PL'] = self.merge_pl()
        consensusdict['GT'] = self.set_consensus_gt_by_ad(consensusdict)
        return consensusdict
    def set_consensus_gt_by_ad(self, consensusdict):
        """ Return a GT string for the merged AD in consensusdict via evaluate_ref_alt_max() """
        consensusad_ref = consensusdict['AD'][0]
        consensusad_alt = consensusdict['AD'][1:]
        #NOTE np.argmax return the first element for equal values which would bias GT setting always to the first ALT entry if equal depth exist
        consensusad_alt_maxpos = np.argmax(consensusad_alt)
        consensusad_alt_max = np.max(consensusad_alt)
        consensusad_ref_alt_max = [consensusad_ref, consensusad_alt_max]
        consensusad_ref_alt_maxpos = [0, consensusad_alt_maxpos+1]
        consensusgt = evaluate_ref_alt_max(consensusad_ref_alt_max, consensusad_ref_alt_maxpos)
        return consensusgt
    def merge_ad(self):
        """ Return the per-allele sum of AD over samples whose GT is not missing """
        ADOUT = []
        # get_samples_ad()/get_gt_na_pos() are recomputed for every element; hoisting them would avoid repeated work
        for i in range(0, self.get_var_number()):
            ADOUT.append(np.sum([int(y[i]) for x,y in enumerate(self.get_samples_ad()) if x not in self.get_gt_na_pos()]))
        return ADOUT
    def merge_gq(self):
        """ Return the median GQ (as int) over samples whose GT is not missing """
        return int(np.median([y for x,y in enumerate(self.get_samples_gq()) if x not in self.get_gt_na_pos()]))
    def merge_dp(self):
        """ Return the summed DP over samples whose GT is not missing """
        return np.sum([y for x,y in enumerate(self.get_samples_dp()) if x not in self.get_gt_na_pos()])
    def merge_pl(self):
        """ Return the per-index median PL (as int) over samples whose GT is not missing """
        PLOUT = []
        for i in range(0,nCk(self.get_var_number()+1,2)):
            PLOUT.append(int(np.median([y[i] for x,y in enumerate(self.get_samples_pl()) if x not in self.get_gt_na_pos()])))
        return PLOUT
    def get_ACGT_allele_counts(self, var_type):
        """Attach an 'ACGT_count' dict (A/C/G/T -> depth) built from AD to every sample.

        NOTE(review): the body of the first branch (non-SNP variant types) is
        missing in this copy. As visible, the repSNP branch falls through to the
        generic loop below, which re-initialises ACGT_count. Confirm the
        intended control flow.
        """
        if var_type != 'noVAR' and var_type != 'biSNP' and var_type != 'muSNP':
            # NOTE(review): body missing in this copy -- TODO restore.
        if var_type == 'repSNP':
            for s_pos, sample_ in enumerate(self.SAMPLES.keys()):
                self.add_format_on_samples(sample_, 'ACGT_count', get_ACGT_count_dict(0,0,0,0))
                # list(set(...))[0] picks one distinct base of the allele string
                self.SAMPLES[sample_]['ACGT_count'][list(set(self.REF[0]))[0]] = self.get_samples_ad_ref()[s_pos]
                for alt_pos, alt_ in enumerate(self.ALT):
                    self.SAMPLES[sample_]['ACGT_count'][list(set(alt_))[0]] = self.get_samples_ad_alt()[s_pos][alt_pos]
        for s_pos, sample_ in enumerate(self.SAMPLES.keys()):
            self.add_format_on_samples(sample_, 'ACGT_count', get_ACGT_count_dict(0,0,0,0))
            self.SAMPLES[sample_]['ACGT_count'][self.REF[0]] = self.get_samples_ad_ref()[s_pos]
            for alt_pos, alt_ in enumerate(self.ALT):
                self.SAMPLES[sample_]['ACGT_count'][alt_] = self.get_samples_ad_alt()[s_pos][alt_pos]
    def get_alleles(self, var_type_, type_):
        """Return [MAJOUT, MINOUT, IUPACOUT], each a list with one entry per sample.

        Dispatches on var_type_ to a per-type helper (_biSNP, _muSNP, _repSNP,
        _deletion, _mudeletion, _complexDS, _complexIS, _complexDI). Each helper
        receives this record, the sample key and type_, and returns three values.
        'noVAR' is also handled by _biSNP. 'insertion' and 'muinsertion' return
        REF for every sample without calling a helper. Any other var_type_
        returns None.
        """
        if var_type_ == 'noVAR':
            MAJOUT = list(np.repeat(self.REF,len(self.SAMPLES)))
            MINOUT = list(np.repeat(self.REF,len(self.SAMPLES)))
            IUPACOUT = list(np.repeat(self.REF,len(self.SAMPLES)))
            for s_pos, sample_ in enumerate(self.SAMPLES.keys()):
                MAJOUT[s_pos], MINOUT[s_pos], IUPACOUT[s_pos] = _biSNP(self, sample_, type_)
            return [MAJOUT,MINOUT,IUPACOUT]
        if var_type_ == 'biSNP':
            MAJOUT = list(np.repeat('N',len(self.SAMPLES)))
            MINOUT = list(np.repeat('N',len(self.SAMPLES)))
            IUPACOUT = list(np.repeat('N',len(self.SAMPLES)))
            for s_pos, sample_ in enumerate(self.SAMPLES.keys()):
                MAJOUT[s_pos], MINOUT[s_pos], IUPACOUT[s_pos] = _biSNP(self, sample_, type_)
            return [MAJOUT,MINOUT,IUPACOUT]
        if var_type_ == 'muSNP':
            MAJOUT = list(np.repeat('N',len(self.SAMPLES)))
            MINOUT = list(np.repeat('N',len(self.SAMPLES)))
            IUPACOUT = list(np.repeat('N',len(self.SAMPLES)))
            for s_pos, sample_ in enumerate(self.SAMPLES.keys()):
                MAJOUT[s_pos], MINOUT[s_pos], IUPACOUT[s_pos] = _muSNP(self, sample_, type_)
            return [MAJOUT,MINOUT,IUPACOUT]
        if var_type_ == 'repSNP':
            MAJOUT = list(np.repeat('N'*np.max(self.get_var_len()),len(self.SAMPLES)))
            MINOUT = list(np.repeat('N'*np.max(self.get_var_len()),len(self.SAMPLES)))
            IUPACOUT = list(np.repeat('N'*np.max(self.get_var_len()),len(self.SAMPLES)))
            for s_pos, sample_ in enumerate(self.SAMPLES.keys()):
                MAJOUT[s_pos], MINOUT[s_pos], IUPACOUT[s_pos] = _repSNP(self, sample_, type_)
            return [MAJOUT,MINOUT,IUPACOUT]
        if var_type_ == 'deletion':
            MAJOUT = list(np.repeat('N'*np.max(self.get_var_len()),len(self.SAMPLES)))
            MINOUT = list(np.repeat('N'*np.max(self.get_var_len()),len(self.SAMPLES)))
            IUPACOUT = list(np.repeat('N'*np.max(self.get_var_len()),len(self.SAMPLES)))
            for s_pos, sample_ in enumerate(self.SAMPLES.keys()):
                MAJOUT[s_pos], MINOUT[s_pos], IUPACOUT[s_pos] = _deletion(self, sample_, type_)
            return [MAJOUT,MINOUT,IUPACOUT]
        if var_type_ == 'mudeletion':
            MAJOUT = list(np.repeat('N'*np.max(self.get_var_len()),len(self.SAMPLES)))
            MINOUT = list(np.repeat('N'*np.max(self.get_var_len()),len(self.SAMPLES)))
            IUPACOUT = list(np.repeat('N'*np.max(self.get_var_len()),len(self.SAMPLES)))
            for s_pos, sample_ in enumerate(self.SAMPLES.keys()):
                MAJOUT[s_pos], MINOUT[s_pos], IUPACOUT[s_pos] = _mudeletion(self, sample_, type_)
            return [MAJOUT,MINOUT,IUPACOUT]
        if var_type_ == 'insertion':
            MAJOUT = list(np.repeat(self.REF,len(self.SAMPLES)))
            MINOUT = list(np.repeat(self.REF,len(self.SAMPLES)))
            IUPACOUT = list(np.repeat(self.REF,len(self.SAMPLES)))
            return [MAJOUT,MINOUT,IUPACOUT]
        if var_type_ == 'muinsertion':
            MAJOUT = list(np.repeat(self.REF,len(self.SAMPLES)))
            MINOUT = list(np.repeat(self.REF,len(self.SAMPLES)))
            IUPACOUT = list(np.repeat(self.REF,len(self.SAMPLES)))
            return [MAJOUT,MINOUT,IUPACOUT]
        if var_type_ == 'complexDS':
            MAJOUT = list(np.repeat('N'*np.max(self.get_var_len()),len(self.SAMPLES)))
            MINOUT = list(np.repeat('N'*np.max(self.get_var_len()),len(self.SAMPLES)))
            IUPACOUT = list(np.repeat('N'*np.max(self.get_var_len()),len(self.SAMPLES)))
            for s_pos, sample_ in enumerate(self.SAMPLES.keys()):
                MAJOUT[s_pos], MINOUT[s_pos], IUPACOUT[s_pos] = _complexDS(self, sample_, type_)
            return [MAJOUT,MINOUT,IUPACOUT]
        if var_type_ == 'complexIS':
            MAJOUT = list(np.repeat('N',len(self.SAMPLES)))
            MINOUT = list(np.repeat('N',len(self.SAMPLES)))
            IUPACOUT = list(np.repeat('N',len(self.SAMPLES)))
            for s_pos, sample_ in enumerate(self.SAMPLES.keys()):
                MAJOUT[s_pos], MINOUT[s_pos], IUPACOUT[s_pos] = _complexIS(self, sample_, type_)
            return [MAJOUT,MINOUT,IUPACOUT]
        if var_type_ == 'complexDI':
            MAJOUT = list(np.repeat('N'*np.max(self.get_var_len()),len(self.SAMPLES)))
            MINOUT = list(np.repeat('N'*np.max(self.get_var_len()),len(self.SAMPLES)))
            IUPACOUT = list(np.repeat('N'*np.max(self.get_var_len()),len(self.SAMPLES)))
            for s_pos, sample_ in enumerate(self.SAMPLES.keys()):
                MAJOUT[s_pos], MINOUT[s_pos], IUPACOUT[s_pos] = _complexDI(self, sample_, type_)
            return [MAJOUT,MINOUT,IUPACOUT]
        if var_type_ == 'complexDIS':
            # NOTE(review): body missing in this copy -- TODO restore.

#common dictionaries
# IUPACdict maps a string of observed alleles to its IUPAC nucleotide code.
# Callers such as getIUPAC(''.join(ref_alt_nonzero)) build these strings.
#   - Single bases, homozygous pairs ('AA') and the placeholders '-', 'N' and
#     '.' map to themselves.
#   - Every ordering of two, three or four distinct bases maps to the
#     corresponding ambiguity code, e.g. 'AG' and 'GA' both map to 'R'.
IUPACdict = {'A':'A','AA':'A','C':'C','CC':'C','G':'G','GG':'G','T':'T','TT':'T','-':'-','N':'N','.':'.'}
# Standard IUPAC ambiguity codes, keyed by the sorted set of bases they cover.
_IUPAC_AMBIGUITY_CODES = {'AG':'R','CT':'Y','CG':'S','AT':'W','GT':'K','AC':'M',
                          'CGT':'B','AGT':'D','ACT':'H','ACG':'V','ACGT':'N'}
# Register every permutation, because callers join alleles in record order.
IUPACdict.update(
    (''.join(perm), code)
    for bases, code in _IUPAC_AMBIGUITY_CODES.items()
    for perm in itertools.permutations(bases)
)

#common functions
def sliding_window_steps_generator(sw,j,start,end):
    """Compute sliding-window coordinates over the range start..end.

    sw -- window size; j -- step between consecutive window starts.

    Window starts are start, start+j, ... (all < end). Each window ends sw-1
    positions after its start, clipped to end.

    Returns [start_seq, mid_seq, end_seq] as numpy arrays. mid_seq holds the
    floored midpoint (start+end)//2 of each window.
    NOTE(review): the original computed these arrays without returning them;
    confirm the return order against callers.

    Raises ValueError when end <= start (every array would be empty).
    """
    if end <= start:
        raise ValueError('start must be smaller than end')
    start_seq = np.arange(start, end, j)
    end_seq = np.arange(start + sw - 1, end, j)
    # windows whose nominal end reaches past `end` are clipped to `end`
    end_seq = np.append(end_seq, np.repeat(end, len(start_seq) - len(end_seq)))
    # floor division keeps integer midpoints, as Python 2 '/' did on int arrays
    mid_seq = (start_seq + end_seq) // 2
    return [start_seq, mid_seq, end_seq]

def getIUPAC(x):
    """Translate an allele string (e.g. 'AG') into its IUPAC code.

    The lookup goes through the module-level IUPACdict; a KeyError propagates
    for strings that IUPACdict does not contain.
    """
    code = IUPACdict[x]
    return code

def POINT2zero(x):
    """Return '0' for the missing-value placeholder '.', otherwise *x* unchanged."""
    # Every input must reach a return; previously only '.' did, and all
    # other values came back as None.
    return '0' if x == '.' else x

def POINTNA2zero(x):
    """Return '0' for the placeholders '.' and 'NA', otherwise *x* unchanged."""
    # Every input must reach a return; previously non-placeholder values came back as None.
    return '0' if x in ('.', 'NA') else x

def NA2zero(x):
    """Return 0 for 'NA', otherwise int(x).

    Raises ValueError (from int()) for other non-numeric strings such as '.'.
    """
    # Every input must reach a return; previously non-'NA' values came back as None.
    return 0 if x == 'NA' else int(x)

def nCk(n, k):
    """Return the binomial coefficient C(n, k).

    The result equals len(list(itertools.combinations(range(n), k))):
      - 1 for k == 0
      - 0 when k > n
      - ValueError for negative k
    It is computed arithmetically instead of enumerating every combination.
    """
    if k < 0:
        raise ValueError('k must be non-negative')
    if k == 0:
        return 1
    if k > n:
        return 0
    result = 1
    # after step i, result == C(n-k+i, i), so each floor division is exact
    for i in range(1, k + 1):
        result = result * (n - k + i) // i
    return result

def evaluate_ref_alt_max(x,y):
    """Build a diploid GT string from the REF depth and the strongest ALT depth.

    x -- [ref_depth, max_alt_depth].
    y -- [ref_index, alt_index], the allele numbers written into the result.

    Returns:
      './.'         when both depths are 0
      a homozygous call when only one of the two alleles has depth
      'ref/alt'     when both have depth
      None          for any other combination (e.g. negative depths)
    """
    ref_depth = x[0]
    if ref_depth == 0:
        alt_depth = x[1]
        if alt_depth == 0:
            return './.'
        if alt_depth > 0:
            called = (y[1], y[1])
        else:
            return None
    elif ref_depth > 0:
        alt_depth = x[1]
        if alt_depth == 0:
            called = (y[0], y[0])
        elif alt_depth > 0:
            called = (y[0], y[1])
        else:
            return None
    else:
        return None
    return '/'.join([str(called[0]), str(called[1])])

def get_header_from_vcf( argsDICT ):
    """Collect header lines from the plain-text VCF file argsDICT['ivcf'].

    On the '#CHR...' line, these are appended to HEADER in order:
      - two INFO definitions (NCFnumber, NCFfraction)
      - a line built from argsDICT['args'][0]
      - the output of changeHEADER(...) for that line
    Returns the HEADER list.

    NOTE(review): several bodies are missing in this copy (see below), so the
    function does not parse as written.
    """
    HEADER = []
    with open(argsDICT['ivcf'],'rb') as f:
        for flines in f:
           if flines[:2]=='##':
               # NOTE(review): body missing -- presumably keeps the '##' meta line; TODO restore.
           if flines[:4]=='#CHR':
               HEADER.append('##INFO=<ID=NCFnumber,Number=1,Type=String,Description="NotCalledFraction number">\n')
               HEADER.append('##INFO=<ID=NCFfraction,Number=1,Type=Float,Description="NotCalledFraction fraction">\n')
               # NOTE(review): this line starts with a single space, which is not a valid
               # VCF meta line -- a prefix may be missing; confirm.
               HEADER.append(' '+argsDICT['args'][0]+'\n')
               HEADER.append(changeHEADER(flines, argsDICT['add'], argsDICT['id'], argsDICT['samples']))
           if flines[0]!='#':
               # NOTE(review): body missing -- presumably stops at the first data line; TODO restore.
    return HEADER

def get_header_from_gz( argsDICT ):
    """Collect header lines from the compressed VCF file argsDICT['ivcf'].

    Same logic as get_header_from_vcf, but for a compressed input.

    NOTE(review): the opener call in the `with` statement and several bodies
    are missing in this copy; the `with` line was presumably
    gzip.open(argsDICT['ivcf'], 'rb'). TODO restore.
    """
    HEADER = []
    with['ivcf'],'rb') as f:
        for flines in f:
           if flines[:2]=='##':
               # NOTE(review): body missing -- presumably keeps the '##' meta line; TODO restore.
           if flines[:4]=='#CHR':
               HEADER.append('##INFO=<ID=NCFnumber,Number=1,Type=String,Description="NotCalledFraction number">\n')
               HEADER.append('##INFO=<ID=NCFfraction,Number=1,Type=Float,Description="NotCalledFraction fraction">\n')
               # NOTE(review): this line starts with a single space -- a prefix may be missing; confirm.
               HEADER.append(' '+argsDICT['args'][0]+'\n')
               HEADER.append(changeHEADER(flines, argsDICT['add'], argsDICT['id'], argsDICT['samples']))
           if flines[0]!='#':
               # NOTE(review): body missing -- presumably stops at the first data line; TODO restore.
    return HEADER

def get_sample_pos_from_header( argsDICT ):
    """Return the column index of every name in argsDICT['samples'], in order.

    Indices refer to the tab-split '#CHR...' header line of argsDICT['ivcf'].
    A name that equals a fixed VCF column name (CHROM..FORMAT) uses its
    second occurrence, i.e. the sample column rather than the fixed one.

    NOTE(review): the header's first field starts with '#' (e.g. '#CHROM'), so
    a sample named 'CHROM' matches only once and [1] raises IndexError.
    flines_stripped is unbound if no '#CHR' line is found.
    """
    SAMPLE_POS = []
    if argsDICT['tabix']:
        # NOTE(review): the opener call is missing here -- presumably
        # gzip.open(argsDICT['ivcf'],'rb'); TODO restore.
        with['ivcf'],'rb') as f:
            for flines in f:
               if flines[:4]=='#CHR':
                   flines_stripped = flines.strip().split('\t')
               if flines[0]!='#':
                   # NOTE(review): body missing -- presumably stops reading; TODO restore.
    if not argsDICT['tabix']:
        with open(argsDICT['ivcf'],'rb') as f:
            for flines in f:
                if flines[:4]=='#CHR':
                    flines_stripped = flines.strip().split('\t')
                if flines[0]!='#':
                    # NOTE(review): body missing -- presumably stops reading; TODO restore.
    for s in argsDICT['samples']:
        if s in ['CHROM','POS','ID','REF','ALT','QUAL','FILTER','INFO','FORMAT']:
            SAMPLE_POS.append([x for x,y in enumerate(flines_stripped) if y == s][1])
        if s not in ['CHROM','POS','ID','REF','ALT','QUAL','FILTER','INFO','FORMAT']:
            SAMPLE_POS.append([x for x,y in enumerate(flines_stripped) if y == s][0])
    return SAMPLE_POS

def get_sample_pos_dict_from_header( argsDICT ):
    """Return {sample name: column index} for every name in argsDICT['samples'].

    Dict variant of get_sample_pos_from_header, with the same lookup rules:
    names equal to a fixed column name (CHROM..FORMAT) use their second
    occurrence on the '#CHR...' line.

    NOTE(review): the opener call and several bodies are missing in this copy;
    TODO restore.
    """
    sampleposDICT = {}
    if argsDICT['tabix']:
        # NOTE(review): opener missing -- presumably gzip.open(argsDICT['ivcf'],'rb').
        with['ivcf'],'rb') as f:
            for flines in f:
               if flines[:4]=='#CHR':
                   flines_stripped = flines.strip().split('\t')
               if flines[0]!='#':
                   # NOTE(review): body missing -- presumably stops reading; TODO restore.
    if not argsDICT['tabix']:
        with open(argsDICT['ivcf'],'rb') as f:
            for flines in f:
                if flines[:4]=='#CHR':
                    flines_stripped = flines.strip().split('\t')
                if flines[0]!='#':
                    # NOTE(review): body missing -- presumably stops reading; TODO restore.
    for s in argsDICT['samples']:
        if s in ['CHROM','POS','ID','REF','ALT','QUAL','FILTER','INFO','FORMAT']:
            sampleposDICT[s]=[x for x,y in enumerate(flines_stripped) if y == s][1]
        if s not in ['CHROM','POS','ID','REF','ALT','QUAL','FILTER','INFO','FORMAT']:
            sampleposDICT[s]=[x for x,y in enumerate(flines_stripped) if y == s][0]
    return sampleposDICT

def changeHEADER(header, add, name, samples):
    """Rewrite the '#CHROM' header line; split it at '\\tFORMAT'.

    NOTE(review): only the split is visible in this copy. Both branches compute
    the same `headersplit`, which is never used, and there is no return, so
    callers currently receive None. The remainder of this function appears to
    be missing -- TODO restore.
    """
    if add:
        headersplit = header.strip().split('\tFORMAT')
    if not add:
        headersplit = header.strip().split('\tFORMAT')

def get_missing_type(x):
    """Parse a missing-GT AD option into [mode, ref_value, alt_value].

    'keep' and 'zero' map to [mode, 0, 0]. Any value containing 'set' is read
    as 'set:<ref>:<alt>' and yields ['set', int(ref), int(alt)]. Anything
    else returns None.
    """
    if x in ('keep', 'zero'):
        return [x, 0, 0]
    if 'set' not in x:
        return None
    fields = x.split(':')
    return ['set', int(fields[1]), int(fields[2])]

def get_missing_type2(x):
    """Parse a missing-GT GP option into [mode, p0, p1, p2].

    'keep' and 'equal' map to [mode, 0.333, 0.333, 0.333]. Any value
    containing 'set' is read as 'set:<p0>:<p1>:<p2>' with float fields.
    Anything else returns None.
    """
    if x in ('keep', 'equal'):
        return [x] + [0.333] * 3
    if 'set' not in x:
        return None
    fields = x.split(':')
    return ['set'] + [float(fields[i]) for i in (1, 2, 3)]

def get_reference_len( argsDICT ):
    """Return a dict of sequence lengths read from the FASTA file argsDICT['R'].

    Only sequences whose name is in argsDICT['chr'] are meant to be recorded.

    NOTE(review): in this copy the operand before `in` (on two lines) and the
    dict key are missing -- presumably the record id (seq.id). Two branch
    bodies are missing too. TODO restore.
    """
    print 'start obtaining reference length'
    length_dict = {}
    for seq in SeqIO.parse( argsDICT['R'], 'fasta' ):
        if argsDICT['verbose']:
            # NOTE(review): body missing -- presumably reports the current record; TODO restore.
        if not in argsDICT['chr']:
            # NOTE(review): body missing -- presumably skips unselected records; TODO restore.
        if in argsDICT['chr']:
            length_dict[] = len( seq )
    print 'finished obtaining reference length'
    return length_dict

def duplicate_sequence(fastafile, id):
    """Return a mutable copy of the first FASTA record whose id equals *id*.

    fastafile -- anything accepted by Bio.SeqIO.parse in "fasta" format.
    id -- the record identifier to look for.

    Only the sequence (as a mutable sequence) and the description are copied
    into the new SeqRecord. Returns None when no record matches.
    """
    for rec in SeqIO.parse(fastafile, "fasta"):
        # the comparison operand was missing; records are matched on their id
        if rec.id == id:
            duplicated_seq = SeqRecord.SeqRecord(rec.seq.tomutable())
            duplicated_seq.description = rec.description
            return duplicated_seq
    return None

def get_mask_dict ( argsDICT ):
    """Build per-chromosome masking intervals from the coverage file argsDICT['ibga'].

    Each line is tab-separated: chr, start, end, coverage. This looks like
    bedtools genomecov -bga output (0-based start, converted to 1-based
    below) -- confirm.

    For chromosomes in argsDICT['chr'], regions with coverage
    <= argsDICT['cov2N'] are collected as 'start_end' keys. They are then
    merged per chromosome with interval.union(...). Returns
    {chromosome: interval}.

    NOTE(review): maskDICT is never initialised in this copy, and two branch
    bodies are missing (below). TODO restore.
    """
    cov = int(argsDICT['cov2N'])
    chromosomes = argsDICT['chr']
    ibga = argsDICT['ibga']
    print 'start loading masking bga'
    for lines in open(ibga, 'r'):
        chr_, start_, end_, cov_ = lines.strip().split( '\t' )
        start_ = int(start_)
        # NOTE(review): `end` is unused; the key below uses the string end_
        end = int(end_)
        cov_ = int(cov_)
        if chr_ not in chromosomes:
            # NOTE(review): body missing -- presumably skips this line; TODO restore.
        if chr_ not in maskDICT:
            print 'working on %s' % ( chr_ )
            maskDICT[chr_] = {}
        if cov_ > cov:
            # NOTE(review): body missing -- presumably skips well-covered regions; TODO restore.
        if cov_ <= cov:
            maskDICT[chr_][str( int( start_ ) + 1 ) + '_' + end_] = 0
    # NOTE(review): interval bounds below are passed as strings from split('_');
    # confirm pyinterval accepts them.
    for mchr in maskDICT.keys():
        maskDICT[mchr] = interval.union([interval([x.split('_')[0],x.split('_')[1]]) for x in maskDICT[mchr].keys()])
    print 'finished loading masking bga'
    return maskDICT

def get_mask_dict_multiple_by_region ( argsDICT, referenceDICT ):
    """Fill per-sample low-coverage masks for one region into referenceDICT.

    For each sample in argsDICT['samples_pos'], the file argsDICT['ibga'][i]
    is read. Regions with coverage <= argsDICT['cov2N'] are stored as
    'start_end' keys in referenceDICT[sample][chr]['BGA']. That dict is then
    replaced by the union of those intervals, intersected with
    [argsDICT['chr_start'], argsDICT['chr_end']].

    referenceDICT is modified in place; no return statement is visible in this
    copy.

    NOTE(review): argsDICT['chr'] is used as a single dict key below, so
    `chr_ not in chromosome` is presumably a substring test on a string
    (e.g. 'chr1' in 'chr10' is True) -- confirm. Two branch bodies are
    missing. TODO restore.
    """
    cov = int(argsDICT['cov2N'])
    chromosome = argsDICT['chr']
    print 'start loading masking bga'
    for s_pos, sample_ in enumerate(argsDICT['samples_pos']):
        print 'working on %s' % ( argsDICT['samples'][s_pos] )
        ibga = argsDICT['ibga'][s_pos]
        for lines in open(ibga, 'r'):
            chr_, start_, end_, cov_ = lines.strip().split( '\t' )
            start_ = int(start_)
            # NOTE(review): `end` is unused; the key below uses the string end_
            end = int(end_)
            cov_ = int(cov_)
            if chr_ not in chromosome:
                # NOTE(review): body missing -- presumably skips this line; TODO restore.
            if cov_ > cov:
                # NOTE(review): body missing -- presumably skips well-covered regions; TODO restore.
            if cov_ <= cov:
                referenceDICT[sample_][chr_]['BGA'][str( int( start_ ) + 1 ) + '_' + end_] = 0
        referenceDICT[sample_][chromosome]['BGA'] = interval([int(argsDICT['chr_start']),int(argsDICT['chr_end'])]) & interval.union([interval([x.split('_')[0],x.split('_')[1]]) for x in referenceDICT[sample_][chromosome]['BGA'].keys()])
        print 'finished working on %s' % ( argsDICT['samples'][s_pos] )
    print 'finished loading masking bga'

def get_ACGT_count_dict(A_, C_, G_, T_):
    """Return an OrderedDict mapping 'A', 'C', 'G', 'T' (in that order) to the given counts."""
    return collections.OrderedDict(zip('ACGT', (A_, C_, G_, T_)))

#allele functions
def _biSNP(self, sample_pos, type_):
    """Choose output alleles for one sample at a bi-allelic SNP.

    self -- the vcfRecord. This is a module-level function that receives the
        record explicitly.
    sample_pos -- key of the sample in self.SAMPLES (callers pass the key,
        not a numeric position).
    type_ -- output mode: 'refaltN', 'refaltsample', 'refmajorsample',
        'refminorsample' or 'iupac'.

    Only alleles with non-zero AD are considered. Returns
    [ref_out, alt_out, iupac_out]. All three are 'N' when no allele has depth;
    with a single supported allele, all three equal that allele. With two
    supported alleles:
      'refaltN'        -> 'N', 'N', 'N'
      'refaltsample'   -> random order, IUPAC code
      'refmajorsample' -> higher-depth allele first (random on ties), IUPAC code
      'refminorsample' -> lower-depth allele first (random on ties), IUPAC code
      'iupac'          -> 'N', 'N', IUPAC code
    Other type_ values return None.

    NOTE(review): the body of the first `if` is missing in this copy -- TODO
    restore. ref_, alt_ and admin are computed but unused.
    """
    if sample_pos not in self.SAMPLES.keys():
        # NOTE(review): body missing in this copy -- TODO restore.
    samplesad_ = self.SAMPLES[sample_pos]['AD']
    ref_alt_ = self.get_ref_alt()
    ref_ = ref_alt_[0]
    alt_ = ref_alt_[1:]
    ref_out = 'N'
    alt_out = 'N'
    iupac_out = 'N'
    samplesad_ = [int(x) for x in samplesad_]
    # keep only alleles (and their depths) that have at least one read
    ref_alt_nonzero = [ref_alt_[x] for x,y in enumerate(samplesad_) if samplesad_[x]!=0]
    samplesad_nonzero = [samplesad_[x] for x,y in enumerate(samplesad_) if samplesad_[x]!=0]
    if len(samplesad_nonzero) == 0:
        return [ref_out, alt_out, iupac_out]
    admax = np.max([int(x) for x in samplesad_nonzero])
    admin = np.min([int(x) for x in samplesad_nonzero])
    admaxpos = [x for x,y in enumerate(samplesad_nonzero) if samplesad_nonzero[x]==admax]
    adminpos = [x for x,y in enumerate(samplesad_nonzero) if samplesad_nonzero[x]!=admax]
    if type_ == 'refaltN':
        if len(ref_alt_nonzero)==1:
            ref_out = ref_alt_nonzero[0]
            alt_out = ref_alt_nonzero[0]
            iupac_out = ref_alt_nonzero[0]
            return [ref_out, alt_out, iupac_out]
        if len(ref_alt_nonzero)==2:
            return [ref_out, alt_out, iupac_out]
    if type_ == 'refaltsample':
        if len(ref_alt_nonzero)==1:
            ref_out = ref_alt_nonzero[0]
            alt_out = ref_alt_nonzero[0]
            iupac_out = ref_alt_nonzero[0]
            return [ref_out, alt_out, iupac_out]
        if len(ref_alt_nonzero)==2:
            ref_out = random.sample(ref_alt_nonzero,1)[0]
            alt_out = [x for x in ref_alt_nonzero if x != ref_out][0]
            iupac_out = getIUPAC(''.join(ref_alt_nonzero))
            return [ref_out, alt_out, iupac_out]
    if type_ == 'refmajorsample':
        if len(ref_alt_nonzero)==1:
            ref_out = ref_alt_nonzero[0]
            alt_out = ref_alt_nonzero[0]
            iupac_out = ref_alt_nonzero[0]
            return [ref_out, alt_out, iupac_out]
        # distinct depths: the higher-depth allele is reported first
        if len(ref_alt_nonzero)==2 and len(set(samplesad_nonzero)) != 1:
            ref_out = ref_alt_nonzero[admaxpos[0]]
            alt_out = ref_alt_nonzero[adminpos[0]]
            iupac_out = getIUPAC(''.join(ref_alt_nonzero))
            return [ref_out, alt_out, iupac_out]
        # tied depths: pick the order at random
        if len(ref_alt_nonzero)==2 and len(set(samplesad_nonzero)) == 1:
            ref_out = random.sample(ref_alt_nonzero,1)[0]
            alt_out = [x for x in ref_alt_nonzero if x != ref_out][0]
            iupac_out = getIUPAC(''.join(ref_alt_nonzero))
            return [ref_out, alt_out, iupac_out]
    if type_ == 'refminorsample':
        if len(ref_alt_nonzero)==1:
            ref_out = ref_alt_nonzero[0]
            alt_out = ref_alt_nonzero[0]
            iupac_out = ref_alt_nonzero[0]
            return [ref_out, alt_out, iupac_out]
        # distinct depths: the lower-depth allele is reported first
        if len(ref_alt_nonzero)==2 and len(set(samplesad_nonzero)) != 1:
            ref_out = ref_alt_nonzero[adminpos[0]]
            alt_out = ref_alt_nonzero[admaxpos[0]]
            iupac_out = getIUPAC(''.join(ref_alt_nonzero))
            return [ref_out, alt_out, iupac_out]
        # tied depths: pick the order at random
        if len(ref_alt_nonzero)==2 and len(set(samplesad_nonzero)) == 1:
            ref_out = random.sample(ref_alt_nonzero,1)[0]
            alt_out = [x for x in ref_alt_nonzero if x != ref_out][0]
            iupac_out = getIUPAC(''.join(ref_alt_nonzero))
            return [ref_out, alt_out, iupac_out]
    if type_ == 'iupac':
        if len(ref_alt_nonzero)==1:
            ref_out = ref_alt_nonzero[0]
            alt_out = ref_alt_nonzero[0]
            iupac_out = ref_alt_nonzero[0]
            return [ref_out, alt_out, iupac_out]
        if len(ref_alt_nonzero)==2:
            iupac_out = getIUPAC(''.join(ref_alt_nonzero))
            return [ref_out, alt_out, iupac_out]

def _muSNP(self, sample_pos, type_):
    """Call a [ref, alt, iupac] triple for one sample at a possibly multi-allelic site.

    The site's alleles come from ``self.get_ref_alt()`` (REF first, then the
    ALT alleles) and are paired position-wise with the sample's ``AD`` values
    (converted with ``int``).  Only alleles with a non-zero depth are treated
    as observed.

    Parameters
    ----------
    sample_pos
        Key into ``self.SAMPLES`` identifying the sample.
    type_ : str
        How to resolve a site where two or more alleles are observed:

        ``'refaltN'``
            return ``['N', 'N', 'N']``.
        ``'refaltsample'``
            ref drawn at random from the observed alleles, alt from the rest.
        ``'refmajorsample'``
            ref from the allele(s) with the highest depth, alt from the alleles
            with a lower depth (a random pick when there are more than two
            observed alleles).  If all observed depths are equal, ref/alt are
            drawn as for ``'refaltsample'``.
        ``'refminorsample'``
            mirror image of ``'refmajorsample'``: ref from the lower-depth
            alleles, alt from the highest-depth allele(s).
        ``'iupac'``
            ref and alt stay ``'N'``; only the third element is filled.

    Returns
    -------
    list of str or None
        ``[ref, alt, iupac]``.  ``['N', 'N', 'N']`` when no allele has a
        non-zero depth; ``[a, a, a]`` when exactly one allele ``a`` is
        observed.  Otherwise the third element (except for ``'refaltN'``) is
        ``getIUPAC`` applied to the concatenated observed alleles.  An
        unrecognised ``type_`` returns ``None`` once at least one allele is
        observed.

    Raises
    ------
    KeyError
        If ``sample_pos`` is not a sample of this record.
    """
    if sample_pos not in self.SAMPLES:
        # Fail with a message naming the sample instead of a bare KeyError
        # from the lookup below.
        raise KeyError('sample_pos %r not in SAMPLES' % (sample_pos,))
    depths = [int(x) for x in self.SAMPLES[sample_pos]['AD']]
    ref_alt_ = self.get_ref_alt()
    # Observed alleles (non-zero depth), kept in REF, ALT1, ALT2, ... order.
    keep = [i for i, d in enumerate(depths) if d != 0]
    alleles = [ref_alt_[i] for i in keep]
    counts = [depths[i] for i in keep]
    if not alleles:
        return ['N', 'N', 'N']
    if type_ not in ('refaltN', 'refaltsample', 'refmajorsample',
                     'refminorsample', 'iupac'):
        return None
    if len(alleles) == 1:
        return [alleles[0]] * 3
    if type_ == 'refaltN':
        return ['N', 'N', 'N']
    if type_ == 'iupac':
        return ['N', 'N', getIUPAC(''.join(alleles))]

    top = max(counts)
    major = [a for a, c in zip(alleles, counts) if c == top]
    # NOTE: "minor" is every allele below the top depth, not only the
    # allele(s) at the minimum depth.
    minor = [a for a, c in zip(alleles, counts) if c != top]
    if type_ == 'refaltsample' or not minor:
        # Depths are not used (or all tie): random ref, alt from the others.
        ref_out = random.sample(alleles, 1)[0]
        others = [x for x in alleles if x != ref_out]
        # With two alleles the alt is already determined, so no second
        # random draw is made.
        alt_out = others[0] if len(alleles) == 2 else random.sample(others, 1)[0]
    else:
        if type_ == 'refmajorsample':
            ref_pool, alt_pool = major, minor
        else:  # 'refminorsample'
            ref_pool, alt_pool = minor, major
        if len(alleles) == 2:
            # Exactly one allele in each pool: deterministic choice.
            ref_out, alt_out = ref_pool[0], alt_pool[0]
        else:
            ref_out = random.sample(ref_pool, 1)[0]
            alt_out = random.sample(alt_pool, 1)[0]
    return [ref_out, alt_out, getIUPAC(''.join(alleles))]

def _repSNP(self, sample_pos, type_):
    if sample_pos not in self.SAMPLES.keys():