import click
import gzip
from collections import defaultdict
import struct
from scipy.stats.stats import spearmanr
import numpy as np
import sys

def _sum_csv_cells(cells):
    """Return the sum of every count in *cells* (a list of rows of numeric strings).

    Prints a progress line every 100 rows and "ERROR" for each row whose
    counts sum to zero.
    """
    total = 0
    for count, cl in enumerate(cells, start=1):
        if count % 100 == 0:
            print("\r Done reading " + str(count) + " cells", end="")
            sys.stdout.flush()

        rc = sum([float(x) for x in cl])
        if rc == 0:
            print("ERROR")
        total += rc
    return total


@click.command()
@click.option('--b1',  help='barcodes.txt file')
@click.option('--b2',  help='barcodes.txt file')
@click.option('--two',  help='Eqclass Binary file')
@click.option('--one',  help='Eqclass Binary file')
@click.option('--type',  help='[eq, sf, csv]')
@click.option('--txps',  help='Eqclass Binary file')
def match(one, two, type, b1, b2, txps):
    """Compare read totals between files ONE and TWO and print FAILED on mismatch."""
    # Only the inputs a given mode actually uses are parsed, so e.g. the csv
    # mode works without --b1/--b2/--txps.
    if type == "eq":
        b1_idx, _ = read_bc(b1)
        _, b2_rev_idx = read_bc(b2)
        c1 = read_eq_bin(one)
        c2 = read_eq_bin(two)

        # Per-cell totals keyed by the cell's index in the --b1 barcode list.
        # Dicts (rather than dense lists sized by the largest id) cope with
        # empty inputs, sparse ids, and cells present in only one file.
        totals1 = {cid: sum(cell.values()) for cid, cell in c1.items()}
        totals2 = {}
        for cid, cell in c2.items():
            # Translate the second file's cell id into the first file's
            # index space via the shared barcode string.
            totals2[b1_idx[b2_rev_idx[cid]]] = sum(cell.values())

        print("Diff in number of reads across all cells")
        diff = sum(
            abs(totals1.get(cid, 0) - totals2.get(cid, 0))
            for cid in set(totals1) | set(totals2)
        )
        print(diff)
        if diff > 0:
            print("FAILED")
    elif type == "sf":
        if txps is None:
            raise click.UsageError("--txps is required when --type is sf")
        numTxps = int(txps)
        _, b1_rev_idx = read_bc(b1)
        b2_idx, _ = read_bc(b2)
        d1 = read_sf_bin(one, numTxps)
        d2 = read_sf_bin(two, numTxps)
        # Explicit raise instead of `assert` so the check survives `python -O`.
        if len(d1) != len(d2):
            raise AssertionError("length not equal")

        # NOTE(review): read_sf_bin omits cells with zero reads, so list
        # positions may not line up with barcode indices when such cells
        # exist -- confirm against the producer of these files.
        diff_sum = 0.0
        for cid, ct1 in enumerate(d1):
            ct2 = d2[b2_idx[b1_rev_idx[cid]]]
            diff_sum += abs(sum(ct1) - sum(ct2))
        print(diff_sum)
        if diff_sum > 0.0:
            print("FAILED")
    elif type == "csv":
        cl1 = read_csv(one)
        cl2 = read_csv(two)
        ts1 = _sum_csv_cells(cl1)
        ts2 = _sum_csv_cells(cl2)
        print(ts1, ts2)
        diff = ts1 - ts2
        print(diff)
        # A mismatch in either direction is a failure; the signed difference
        # is still printed above for diagnosis.
        if diff != 0:
            print("FAILED")
    else:
        raise click.UsageError("--type must be one of: eq, sf, csv")

def read_csv(bf):
    """Load a comma-separated file as a list of rows of string fields.

    Each line is stripped and split on ","; the last field of every line is
    discarded. A progress message is printed every 100 lines.
    """
    rows = []
    with open(bf) as handle:
        for line_no, raw in enumerate(handle, start=1):
            if line_no % 100 == 0:
                print ("\r Done reading " + str(line_no) + " cells", end="")
                sys.stdout.flush()
            fields = raw.strip().split(",")
            rows.append(fields[:-1])
    return rows


def read_bc(bf):
    """Read a barcode file (one barcode per line) into two lookup tables.

    Returns a pair ``(barcode -> line index, line index -> barcode)`` where
    indices are zero-based and barcodes are whitespace-stripped lines.
    """
    by_barcode, by_position = {}, {}
    with open(bf) as handle:
        for position, raw in enumerate(handle):
            barcode = raw.strip()
            by_barcode[barcode] = position
            by_position[position] = barcode
    return by_barcode, by_position

def read_sf_bin(bf, numTxps):
    """Read per-cell count vectors from a gzip-compressed binary file.

    The file is read as consecutive records of ``numTxps`` native doubles,
    one record per cell.

    Args:
        bf: Path to the gzip-compressed file.
        numTxps: Number of doubles per cell record; must be positive.

    Returns:
        A list of tuples of floats, one per cell whose counts sum to more
        than zero. Cells with a non-positive total are counted and reported
        but left out of the list.

    Raises:
        ValueError: If ``numTxps`` is not positive (a zero-size record would
            otherwise never reach end-of-file).
        OSError: If the file cannot be opened or is not valid gzip data.
    """
    if numTxps <= 0:
        raise ValueError("numTxps must be a positive integer, got " + str(numTxps))

    header_struct = struct.Struct("d" * numTxps)
    umiCounts = []
    count = 0
    no_read_count = 0
    tot_read_count = 0
    with gzip.open(bf) as f:
        while True:
            chunk = f.read(header_struct.size)
            # A short read marks the end of the data; detecting it explicitly
            # (instead of catching every exception) lets real I/O and gzip
            # errors propagate.
            if len(chunk) < header_struct.size:
                if chunk:
                    print("\nWARNING: ignoring " + str(len(chunk))
                          + " trailing bytes of an incomplete cell record")
                break
            cell_counts = header_struct.unpack(chunk)

            count += 1
            if count % 100 == 0:
                print("\r Done reading " + str(count) + " cells", end="")
                sys.stdout.flush()

            read_count = float(sum(cell_counts))
            tot_read_count += read_count
            if read_count > 0.0:
                umiCounts.append(cell_counts)
            else:
                no_read_count += 1

    print("\nRead total " + str(count) + " cells")
    print("Found total " + str(tot_read_count) + " reads")
    print("No Read Count Cells: " + str(no_read_count))
    return umiCounts

def read_eq_bin(bf):
    """Read per-cell equivalence-class counts from a gzip-compressed binary file.

    Each record is a header of two native unsigned 64-bit integers
    ``(bc, num_classes)`` followed by ``2 * num_classes`` native unsigned
    32-bit integers: first all class ids, then the matching counts.

    Args:
        bf: Path to the gzip-compressed file.

    Returns:
        A nested ``defaultdict`` mapping ``bc -> class id -> summed count``.
        Counts of repeated ``(bc, class id)`` pairs are accumulated.

    Raises:
        SystemExit: With status 1 if a record declares zero classes.
        struct.error: If a record's data section is shorter than declared.
        OSError: If the file cannot be opened or is not valid gzip data.
    """
    header_struct = struct.Struct("Q" * 2)
    umiCounts = defaultdict(lambda: defaultdict(int))
    count = 0
    read_count = 0
    with gzip.open(bf) as f:
        while True:
            header = f.read(header_struct.size)
            # A short read marks the end of the data; checking for it
            # explicitly (instead of catching every exception) lets real
            # I/O and gzip errors propagate.
            if len(header) < header_struct.size:
                if header:
                    print("\nWARNING: ignoring " + str(len(header))
                          + " trailing bytes of an incomplete record header")
                break
            bc, num_classes = header_struct.unpack(header)

            count += 1
            if count % 100 == 0:
                print("\r Done reading " + str(count) + " cells", end="")
                sys.stdout.flush()

            if num_classes == 0:
                print("ERROR: Cell with no eqclass")
                # sys.exit rather than the site-provided `exit` builtin,
                # which is not guaranteed to exist.
                sys.exit(1)

            data_struct = struct.Struct("I" * 2 * num_classes)
            data = data_struct.unpack_from(f.read(data_struct.size))
            # First half of the payload holds class ids, second half counts.
            for eqId, eqCount in zip(data[:num_classes], data[num_classes:]):
                read_count += eqCount
                umiCounts[bc][eqId] += eqCount

    print("\nRead total " + str(count) + " cells")
    print("Found total " + str(read_count) + " reads")
    return umiCounts

# Run the click command only when this file is executed as a script.
if __name__ == "__main__":
    match()
