from pycompss.api.api import compss_barrier
from pycompss.api.task import task
from pycompss.api.constraint import constraint
from pycompss.api.parameter import Cache, Type, INOUT, Weight

import dislib as ds
from dislib.data.array import Array
import os
import glob
import time
import numpy
import csv
import sys



@constraint(
    processors=[
        {"processorType": "CPU", "computingUnits": "1"},
        {"processorType": "GPU", "computingUnits": "1"},
    ]
)
@task(A={Cache: True}, B={Cache: True}, cache_returns=False)
def fused_multiply_add_case1(A, B):
    """Multiply two blocks with CuPy and return the product.

    Both inputs are declared with ``Cache: True`` and the return value with
    ``cache_returns=False`` in the task decorator. Despite the name, nothing
    is added to the product here: the result is just ``A @ B``.

    :param A: left operand block (anything ``cupy.asarray`` accepts).
    :param B: right operand block (anything ``cupy.asarray`` accepts).
    :return: the matrix product as returned by ``cupy.matmul``.
    """
    # Imported lazily so that only the worker executing the task needs CuPy.
    import cupy as cp

    lhs = cp.asarray(A)
    rhs = cp.asarray(B)
    product = cp.matmul(lhs, rhs)
    return product


@constraint(
    processors=[
        {"processorType": "CPU", "computingUnits": "1"},
        {"processorType": "GPU", "computingUnits": "1"},
    ]
)
@task(
    A={Cache: False, Weight: "1"},
    B={Cache: False, Weight: "1"},
    C={Cache: True, Type: INOUT, Weight: "100000"},
)
def fused_multiply_add_case2(A, B, C):
    """Run one cuBLAS DGEMM on the blocks, writing into ``C`` in place.

    The scalar 1.0 is passed in both scalar slots of the GEMM call, so per
    the cuBLAS GEMM argument order this requests ``C = A * B + C``. Only
    ``C`` is declared ``INOUT``/cached in the task decorator; ``A`` and ``B``
    are declared with ``Cache: False``.

    NOTE(review): the arguments are used through ``.data.ptr``, so they are
    presumably CuPy arrays by the time the body runs -- confirm against the
    runtime's deserialisation path.
    NOTE(review): cuBLAS expects column-major storage; the leading dimensions
    passed here (``A.shape[0]``, ``B.shape[0]``, ``A.shape[0]``) only match
    that convention -- verify the memory layout of the blocks.

    :param A: left operand block.
    :param B: right operand block.
    :param C: accumulator block, updated in place.
    :return: None.
    """
    import cupy as cp
    from cupy_backends.cuda.libs import cublas

    cublas_handle = cp.cuda.device.get_cublas_handle()

    # Host-side float64 scalar 1.0; its address is handed to cuBLAS twice.
    # The array must stay referenced until the call has been issued.
    unit = numpy.array(1, dtype=numpy.float64)
    unit_ptr = unit.ctypes.data

    m = A.shape[0]
    n = B.shape[1]
    k = A.shape[1]
    ldb = B.shape[0]

    # The start event is recorded but never read in this function.
    start_gpu = cp.cuda.Event()
    end_gpu = cp.cuda.Event()
    start_gpu.record()

    cublas.dgemm(
        cublas_handle,
        0,  # transa
        0,  # transb
        m, n, k,
        unit_ptr,
        A.data.ptr, m,
        B.data.ptr, ldb,
        unit_ptr,
        C.data.ptr, m,
    )

    # Block until the event recorded after the GEMM has completed.
    end_gpu.record()
    end_gpu.synchronize()


@constraint(
    processors=[
        {"processorType": "CPU", "computingUnits": "1"},
        {"processorType": "GPU", "computingUnits": "1"},
    ]
)
@task(
    A={Cache: True, Weight: "1"},
    B={Cache: True, Weight: "1"},
    C={Cache: True, Type: INOUT, Weight: "2"},
)
def fused_multiply_add_case3(A, B, C):
    """Run one cuBLAS DGEMM on the blocks, writing into ``C`` in place.

    The scalar 1.0 is passed in both scalar slots of the GEMM call, so per
    the cuBLAS GEMM argument order this requests ``C = A * B + C``. All three
    parameters are declared with ``Cache: True`` in the task decorator and
    ``C`` is additionally ``INOUT``.

    NOTE(review): the arguments are used through ``.data.ptr``, so they are
    presumably CuPy arrays by the time the body runs -- confirm against the
    runtime's deserialisation path.
    NOTE(review): cuBLAS expects column-major storage; the leading dimensions
    passed here (``A.shape[0]``, ``B.shape[0]``, ``A.shape[0]``) only match
    that convention -- verify the memory layout of the blocks.

    :param A: left operand block.
    :param B: right operand block.
    :param C: accumulator block, updated in place.
    :return: None.
    """
    import cupy as cp
    from cupy_backends.cuda.libs import cublas

    cublas_handle = cp.cuda.device.get_cublas_handle()

    # Host-side float64 scalar 1.0; its address is handed to cuBLAS twice.
    # The array must stay referenced until the call has been issued.
    unit = numpy.array(1, dtype=numpy.float64)
    unit_ptr = unit.ctypes.data

    m = A.shape[0]
    n = B.shape[1]
    k = A.shape[1]
    ldb = B.shape[0]

    # The start event is recorded but never read in this function.
    start_gpu = cp.cuda.Event()
    end_gpu = cp.cuda.Event()
    start_gpu.record()

    cublas.dgemm(
        cublas_handle,
        0,  # transa
        0,  # transb
        m, n, k,
        unit_ptr,
        A.data.ptr, m,
        B.data.ptr, ldb,
        unit_ptr,
        C.data.ptr, m,
    )

    # Block until the event recorded after the GEMM has completed.
    end_gpu.record()
    end_gpu.synchronize()


def matmul_case1(A, B):
    """Launch one ``fused_multiply_add_case1`` task per (i, j, k) block triple.

    NOTE(review): every ``k`` iteration rebinds ``C[i][j]``, so the returned
    grid only holds the result of the last ``k`` for each output block; the
    partial products are not summed. Confirm this is intended for the
    benchmark rather than a full matrix product.

    :param A: left operand; its ``_blocks`` grid is indexed as ``[i][k]``.
    :param B: right operand; its ``_blocks`` grid is indexed as ``[k][j]``.
    :return: the nested block list created by ``Array._get_out_blocks``
        (not wrapped in a ds-array), with each cell replaced by a task result.
    """
    lhs_blocks = A._blocks
    rhs_blocks = B._blocks

    n_rows, n_cols = len(lhs_blocks), len(rhs_blocks[0])
    n_inner = len(lhs_blocks[0])

    C = Array._get_out_blocks((n_rows, n_cols))

    for i, lhs_row in enumerate(lhs_blocks):
        out_row = C[i]
        for j in range(n_cols):
            for k in range(n_inner):
                out_row[j] = fused_multiply_add_case1(lhs_row[k], rhs_blocks[k][j])

    return C


def matmul_case2(A, B):
    """Block matrix product that accumulates into a zero-initialised ds-array.

    A float64 ``ds.zeros`` array with one output block per (i, j) pair is
    created, and ``fused_multiply_add_case2`` is called once for every
    (i, j, k) triple with the matching output block as its in/out argument.

    :param A: left operand; its ``_blocks`` grid is indexed as ``[i][k]``.
    :param B: right operand; its ``_blocks`` grid is indexed as ``[k][j]``.
    :return: the ds-array whose blocks were passed to the tasks.
    """
    lhs_blocks = A._blocks
    rhs_blocks = B._blocks

    n_rows, n_cols = len(lhs_blocks), len(rhs_blocks[0])
    n_inner = len(lhs_blocks[0])

    # Output block shape: rows from A's regular block, columns from B's.
    block_shape = A._reg_shape[0], B._reg_shape[1]
    full_shape = (n_rows * block_shape[0], n_cols * block_shape[1])
    C = ds.zeros(full_shape, block_shape, dtype=numpy.float64)
    out_blocks = C._blocks

    for i, lhs_row in enumerate(lhs_blocks):
        for j in range(n_cols):
            for k in range(n_inner):
                fused_multiply_add_case2(lhs_row[k], rhs_blocks[k][j], out_blocks[i][j])

    return C


def matmul_case3(A, B):
    """Block matrix product that accumulates into a zero-initialised ds-array.

    A float64 ``ds.zeros`` array with one output block per (i, j) pair is
    created, and ``fused_multiply_add_case3`` is called once for every
    (i, j, k) triple with the matching output block as its in/out argument.

    :param A: left operand; its ``_blocks`` grid is indexed as ``[i][k]``.
    :param B: right operand; its ``_blocks`` grid is indexed as ``[k][j]``.
    :return: the ds-array whose blocks were passed to the tasks.
    """
    lhs_blocks = A._blocks
    rhs_blocks = B._blocks

    n_rows, n_cols = len(lhs_blocks), len(rhs_blocks[0])
    n_inner = len(lhs_blocks[0])

    # Output block shape: rows from A's regular block, columns from B's.
    block_shape = A._reg_shape[0], B._reg_shape[1]
    full_shape = (n_rows * block_shape[0], n_cols * block_shape[1])
    C = ds.zeros(full_shape, block_shape, dtype=numpy.float64)
    out_blocks = C._blocks

    for i, lhs_row in enumerate(lhs_blocks):
        for j in range(n_cols):
            for k in range(n_inner):
                fused_multiply_add_case3(lhs_row[k], rhs_blocks[k][j], out_blocks[i][j])

    return C


if __name__ == '__main__':
    # Benchmark driver. Usage: <script> <gpus> <cache> <case>
    #
    # Start by deleting everything matched under the hard-coded scratch cache
    # path. NOTE(review): os.remove raises if a match is a directory --
    # confirm only regular files are ever left there.
    for file in glob.glob('/gpfs/scratch/bsc19/bsc19959/cache/kk/*'):
        os.remove(file)

    # gpus: kept as a string and converted with int() where it is used to
    #       size the number of row blocks of x.
    # cache: read from the command line but not used in the lines visible here.
    # case: selects which of the three benchmark configurations to run.
    gpus = sys.argv[1]
    cache = sys.argv[2]
    case = int(sys.argv[3])


    # Each case builds two random ds-arrays (shape, block shape) and picks the
    # matching matmul driver. The trailing comments give the resulting block
    # grid (rows, cols).
    if case == 1:
        x = ds.random_array((10*int(gpus), 11380000*5), (10, 11380000))     # num_blocks (gpus, 5)
        y = ds.random_array((11380000*5, 10*1), (11380000, 10))             # num_blocks (5, 1)
        matmul_func = matmul_case1
    elif case == 2:
        x = ds.random_array((16000*int(gpus), 10*5), (16000, 10))           # num_blocks (gpus, 5)
        y = ds.random_array((10*5, 16000*1), (10, 16000))                   # num_blocks (5, 1)
        matmul_func = matmul_case2
    elif case == 3:
        x = ds.random_array((15000*int(gpus), 15000*5), (15000, 15000))     # num_blocks (gpus, 5)
        y = ds.random_array((15000*5, 15000*1), (15000, 15000))             # num_blocks (5, 1)
        matmul_func = matmul_case3
    else:
        raise Exception('UNKNOWN CASE')


    # Submit the tasks for the selected case. NOTE(review): no barrier or
    # synchronisation on r is visible in this part of the file -- confirm the
    # results are awaited before any timing is taken.
    r = matmul_func(x, y)