import crypt4gh
import crypt4gh.lib as lib
from crypt4gh.keys import get_public_key, get_private_key
import sys

import base64

def extract_raw_private_key(encoded_key):
    """Return the trailing 32 bytes of a PEM-armoured Crypt4GH private key.

    The ``-----BEGIN/END CRYPT4GH PRIVATE KEY-----`` armour lines are removed,
    the remaining text is base64-decoded, and the last 32 bytes of the decoded
    blob are returned.

    NOTE(review): this assumes the decoded blob ends with the raw, unencrypted
    32-byte secret. The key this script actually uses is passphrase-protected
    (it is loaded through ``get_private_key`` with a passphrase callback); for
    such keys the tail is presumably not the plain secret, so prefer
    ``crypt4gh.keys.get_private_key`` for real key files.

    Args:
        encoded_key: armoured key file contents, as UTF-8 ``bytes`` or ``str``.

    Returns:
        bytes: the last 32 bytes of the decoded key blob.

    Raises:
        ValueError: if the decoded blob is shorter than 32 bytes.
    """
    if isinstance(encoded_key, (bytes, bytearray)):
        encoded_key = bytes(encoded_key).decode('utf-8')

    key_text = (
        encoded_key
        .replace('-----BEGIN CRYPT4GH PRIVATE KEY-----', '')
        .replace('-----END CRYPT4GH PRIVATE KEY-----', '')
        .strip()
    )
    decoded_key = base64.b64decode(key_text)

    secret_key = decoded_key[-32:]
    # A slice can only come up short when the blob itself has < 32 bytes.
    if len(secret_key) != 32:
        raise ValueError("Extracted key is not 32 bytes!")

    # Hand back the extracted key (not the whole decoded blob), so the value
    # returned is the one the length check above validated.
    return secret_key

def extract_raw_public_key(encoded_key):
    """Return the raw 32-byte key from a PEM-armoured Crypt4GH public key.

    The ``-----BEGIN/END CRYPT4GH PUBLIC KEY-----`` armour lines are removed,
    the remaining text is base64-decoded, and the first 32 bytes of the
    decoded data are returned.

    Args:
        encoded_key: armoured public-key file contents, as UTF-8 ``bytes`` or
            ``str``.

    Returns:
        bytes: the first 32 bytes of the decoded key data.

    Raises:
        ValueError: if the decoded data is shorter than 32 bytes.
    """
    if isinstance(encoded_key, (bytes, bytearray)):
        encoded_key = bytes(encoded_key).decode('utf-8')

    key_text = (
        encoded_key
        .replace('-----BEGIN CRYPT4GH PUBLIC KEY-----', '')
        .replace('-----END CRYPT4GH PUBLIC KEY-----', '')
        .strip()
    )
    decoded_key = base64.b64decode(key_text)

    public_key = decoded_key[:32]
    # A slice can only come up short when the decoded data has < 32 bytes.
    if len(public_key) != 32:
        raise ValueError("Extracted key is not 32 bytes!")

    # Return the extracted key itself, not the whole decoded payload.
    return public_key

def get_passphrase():
    """Return the passphrase that protects the Crypt4GH private key.

    Used as the zero-argument ``callback`` for ``get_private_key``. The value
    is taken from the ``C4GH_PASSPHRASE`` environment variable (the name the
    crypt4gh command-line tool uses). When that variable is unset or empty,
    the function falls back to the hard-coded development passphrase.

    Returns:
        str: the passphrase.
    """
    import os

    # NOTE(review): the literal fallback is a development credential kept for
    # backward compatibility; remove it once callers always set the variable.
    return os.environ.get("C4GH_PASSPHRASE") or "12345678"

# Key files used when callers do not pass explicit paths (the paths this
# script has always used).
_DEFAULT_PUBKEY_PATH = "credentials/alice.pub"
_DEFAULT_SECKEY_PATH = "credentials/alice.sec"

_USAGE = ("Usage: python workflow.py [encrypt|decrypt] "
          "input_file output_file private_key_file public_key_file")


def _load_key_pair(pubkey_path, seckey_path, passphrase_callback):
    """Load a public key and a passphrase-protected private key from disk.

    Returns ``(public_key, private_key)`` as produced by ``crypt4gh.keys``.
    When *passphrase_callback* is None, ``get_passphrase`` is used.
    """
    public_key = get_public_key(pubkey_path)
    private_key = get_private_key(
        seckey_path, callback=passphrase_callback or get_passphrase
    )
    return public_key, private_key


def encrypt_file(input_file, output_file, pubkey=None, seckey=None, *,
                 pubkey_path=_DEFAULT_PUBKEY_PATH,
                 seckey_path=_DEFAULT_SECKEY_PATH,
                 passphrase_callback=None):
    """Encrypt *input_file* into the Crypt4GH file *output_file*.

    The header is encrypted with the private key at *seckey_path* for the
    recipient whose public key is at *pubkey_path*. By default both point at
    the same "alice" key pair, so the owner encrypts for themselves.

    Args:
        input_file: path of the plaintext file to read.
        output_file: path of the Crypt4GH file to write (overwritten).
        pubkey, seckey: unused. They are accepted so that existing callers,
            which pass the key files' contents positionally, keep working.
        pubkey_path: recipient public-key file.
        seckey_path: sender private-key file.
        passphrase_callback: zero-argument callable returning the private key
            passphrase; defaults to ``get_passphrase``.
    """
    # Load keys first, so a bad path or a wrong passphrase fails before the
    # output file is opened (and truncated).
    recipient_pubkey, sender_seckey = _load_key_pair(
        pubkey_path, seckey_path, passphrase_callback
    )
    with open(input_file, 'rb') as f_in, open(output_file, 'wb') as f_out:
        # Encryption key tuples are (method, sender private key, recipient
        # public key); method 0 is X25519 per the Crypt4GH spec.
        lib.encrypt([(0, sender_seckey, recipient_pubkey)], f_in, f_out)


def decrypt_file(input_file, output_file, seckey=None, pubkey=None, *,
                 pubkey_path=_DEFAULT_PUBKEY_PATH,
                 seckey_path=_DEFAULT_SECKEY_PATH,
                 passphrase_callback=None):
    """Decrypt the Crypt4GH file *input_file* into *output_file*.

    The header is opened with the private key at *seckey_path*. The public key
    at *pubkey_path* is passed to crypt4gh as ``sender_pubkey``, so headers not
    written with the matching private key are rejected. This matches
    ``encrypt_file``, which by default encrypts with that key pair.

    Args:
        input_file: path of the Crypt4GH file to read.
        output_file: path of the plaintext file to write (overwritten).
        seckey, pubkey: unused. They are accepted so that existing callers,
            which pass the key files' contents positionally, keep working.
        pubkey_path: expected sender's public-key file.
        seckey_path: recipient private-key file.
        passphrase_callback: zero-argument callable returning the private key
            passphrase; defaults to ``get_passphrase``.

    Raises:
        ValueError: raised by crypt4gh if no header packet can be decrypted.
    """
    sender_pubkey, recipient_seckey = _load_key_pair(
        pubkey_path, seckey_path, passphrase_callback
    )
    with open(input_file, 'rb') as f_in, open(output_file, 'wb') as f_out:
        # Decryption key tuples are (method, recipient private key, unused);
        # the previous code called lib.encrypt here and re-encrypted the file.
        lib.decrypt(
            [(0, recipient_seckey, None)],
            f_in,
            f_out,
            sender_pubkey=sender_pubkey,
        )


def main(argv=None):
    """Command-line entry point; returns the process exit status.

    Expects ``[encrypt|decrypt] input_file output_file private_key_file
    public_key_file``. Prints usage to stderr and returns 2 on bad arguments.
    """
    args = sys.argv[1:] if argv is None else list(argv)
    if len(args) != 5 or args[0] not in ("encrypt", "decrypt"):
        print(_USAGE, file=sys.stderr)
        return 2

    action, input_file, output_file, priv_key_file, pub_key_file = args
    operation = encrypt_file if action == "encrypt" else decrypt_file
    operation(
        input_file,
        output_file,
        pubkey_path=pub_key_file,
        seckey_path=priv_key_file,
    )
    return 0


if __name__ == "__main__":
    sys.exit(main())
