import matplotlib.pyplot as plt
import numpy as np
import torch
from kan import *
from kan.utils import create_dataset_from_data
import flip_evaluator as flip

def train_kan(x, y, name=None, batch_size=4096, epochs=10, hidden=4, lr=1e-2):
    """Fit a KAN with one hidden layer mapping inputs ``x`` to targets ``y``.

    Args:
        x: Inputs accepted by ``torch.tensor``; the last axis holds the input
            channels (rows are samples, e.g. an ``(N, channels)`` array).
        y: Targets accepted by ``torch.tensor``; the last axis holds the
            output channels.
        name: Label for the table. NOTE(review): currently unused.
        batch_size: NOTE(review): currently unused -- ``model.fit`` is called
            with a fixed ``batch=1024``.
        epochs: Training length; the optimizer runs ``epochs * 10`` steps.
        hidden: Width of the single hidden KAN layer.
        lr: NOTE(review): currently not forwarded to ``model.fit``, so
            whatever default learning rate ``fit`` uses applies.

    Returns:
        ``(run, losses)``. ``run(inputs)`` evaluates the trained model on
        each slice of ``inputs`` along its first axis and stacks the results.
        ``losses`` is the ``"train_loss"`` entry of the results returned by
        ``model.fit``, or an empty list if none is reported.
    """
    # layer widths come from the channel counts of the data
    channels = x.shape[-1]
    colors = y.shape[-1]
    x = torch.tensor(x, dtype=torch.float)
    y = torch.tensor(y, dtype=torch.float)

    # make dataset
    ds = create_dataset_from_data(x, y)

    # make model
    model = KAN(width=[channels, hidden, colors], grid=3, k=3, seed=42)

    # fit the model; keep the returned results so the loss curve can be plotted
    model.speed()
    results = model.fit(ds, opt="LBFGS", steps=epochs * 10, lamb=0.0, batch=1024)

    model = model.eval()

    def run(inputs):
        """Evaluate the model slice by slice along the first axis of ``inputs``."""
        # Process one slice at a time, and with autograd disabled, so neither
        # the activations nor a gradient graph for the whole table stay alive
        # at once (the grad-tracking version needed ~30G of RAM).
        with torch.no_grad():
            return torch.stack([model(row) for row in inputs])

    losses = list(results.get("train_loss", [])) if isinstance(results, dict) else []
    return run, losses


def main():
    """Fit KAN models to the precomputed tables, evaluate, and plot them.

    Loads the transmittance (``t1``), scattering (``s1``) and no-Mie
    scattering (``sk``) tables and rebuilds the per-texel shader parameters
    for each. It then fits one KAN per table, prints FLIP and MSE errors, and
    shows reference / prediction / FLIP map / absolute error / loss plots.
    """
    # presumably atmosphere top radius minus ground radius, as in the shader
    # -- confirm against ../comparison/shaders/pt-tables
    atmosphere_height = 6460 - 6360

    # load tables
    t1 = np.load('../../images/tables/pathtrace-transmittance.npy')[:, :, :-1]  # x = height, y = cos-theta
    s1 = np.load('../../images/tables/pathtrace-scattering.npy')[:, :, :-1]  # x = cos-theta, y = cos-gamma, z = cos-light, w = height
    sk = np.load('../../images/tables/pathtrace-scattering-no-mie.npy')[:, :, :]  # x = height, y = cos_theta, z = cos-light

    # ensure it's clipped to some maximum value to reduce fireflies
    s1 = np.clip(s1, 0, 4)
    sk = np.clip(sk, 0, 4)

    # generate parameters
    # meshgrid (indexing='xy') returns arrays shaped (len(second), len(first)),
    # so the column coordinate must take the table width (shape[1]) for the
    # parameter grid to line up with the (H, W) table -- same as sk below.
    t1_x, t1_y = np.meshgrid(np.linspace(0, 1, t1.shape[1]), np.linspace(0, 1, t1.shape[0]))

    # stack for channels, makes it easier with flattening
    t1_p = np.stack([t1_x * atmosphere_height, t1_y * 2 - 1], axis=-1)

    # scattering is the 4d-ish grid, just redo what was done in the shader
    # it wraps every 16 times
    # see the path tracer (../comparison/shaders/pt-tables) for reasoning
    s1_u, s1_v = np.meshgrid(np.linspace(0, 16, s1.shape[1]), np.linspace(0, 16, s1.shape[0]))
    # fractional part -> position inside a tile, integer part -> tile index
    s1_x, s1_z = np.modf(s1_u)
    s1_y, s1_w = np.modf(s1_v)
    s1_z /= 16
    s1_w /= 16

    # again, stack
    s1_p = np.stack([s1_x * 2 - 1, s1_y * 2 - 1, s1_z * 1.1 - 0.1, s1_w * s1_w * atmosphere_height], axis=-1)

    # 3d grid for the scattering
    # again, redo the shader
    sk_x, sk_y = np.meshgrid(np.linspace(0, 1, sk.shape[1]), np.linspace(0, 1, sk.shape[0]))
    sk_x, sk_z = np.modf(sk_x * 32)
    sk_x *= sk_x
    sk_z /= 32

    # stack
    sk_p = np.stack([sk_x * atmosphere_height, sk_y * 2 - 1, sk_z * 1.2 - 0.2], axis=-1)

    # flat input-output pairs; channel counts come from the data itself
    t1_p_f = t1_p.reshape(-1, t1_p.shape[-1])
    t1_r_f = t1.reshape(-1, t1.shape[-1])

    s1_p_f = s1_p.reshape(-1, s1_p.shape[-1])
    s1_r_f = s1.reshape(-1, s1.shape[-1])

    sk_p_f = sk_p.reshape(-1, sk_p.shape[-1])
    sk_r_f = sk.reshape(-1, sk.shape[-1])

    # fit 3 KANs, one per table
    t1_model, t1_loss = train_kan(t1_p_f, t1_r_f, "t1", hidden=4, lr=1e-2, epochs=10)  # 2 works but misses ozone, 4 works best
    s1_model, s1_loss = train_kan(s1_p_f, s1_r_f, "s1", hidden=16, lr=1e-3, epochs=20)  # 4 works but misses height, 16 works best
    sk_model, sk_loss = train_kan(sk_p_f, sk_r_f, "sk", hidden=16, lr=1e-3, epochs=20)  # 8 works kinda, 16 works best

    # no more training: a bare `torch.no_grad()` call does nothing, it only
    # disables autograd when used as a context manager
    with torch.no_grad():
        t1_img = t1_model(torch.tensor(t1_p, dtype=torch.float)).detach().numpy()
        s1_img = s1_model(torch.tensor(s1_p, dtype=torch.float)).detach().numpy()
        sk_img = sk_model(torch.tensor(sk_p, dtype=torch.float)).detach().numpy()

    # ensure it's clipped to some maximum value to reduce fireflies
    s1_img = np.clip(s1_img, 0, 8)
    print(s1_img.shape, np.min(s1_img), np.max(s1_img), np.any(np.isnan(s1_img)))

    # done, plot
    print("plotting...")
    t1_comp, t1_err, _ = flip.evaluate(t1, t1_img, 'HDR', computeMeanError=True)
    s1_comp, s1_err, _ = flip.evaluate(s1, s1_img, 'HDR', computeMeanError=True)
    sk_comp, sk_err, _ = flip.evaluate(sk, sk_img, 'HDR', computeMeanError=True)
    t1_l1 = np.abs(t1_img - t1)
    s1_l1 = np.abs(s1_img - s1)
    sk_l1 = np.abs(sk_img - sk)

    print("t1 flip", t1_err, "t1 MSE", np.mean((t1_img - t1)**2))
    print("s1 flip", s1_err, "s1 MSE", np.mean((s1_img - s1)**2))
    print("sk flip", sk_err, "sk MSE", np.mean((sk_img - sk)**2))

    # rows 1-2 use a 4x5 grid, row 3 a 4x2 grid and row 4 a 4x3 grid
    fig = plt.figure()
    fig.add_subplot(4, 5, 1); plt.imshow(t1)
    fig.add_subplot(4, 5, 2); plt.imshow(t1_img)
    fig.add_subplot(4, 5, 3); plt.imshow(t1_comp)
    fig.add_subplot(4, 5, 4); plt.imshow(t1_l1)
    fig.add_subplot(4, 5, 5); plt.plot(t1_loss)

    fig.add_subplot(4, 5, 6); plt.imshow(s1 * 4)
    fig.add_subplot(4, 5, 7); plt.imshow(s1_img * 4)
    fig.add_subplot(4, 5, 8); plt.imshow(s1_comp)
    fig.add_subplot(4, 5, 9); plt.imshow(s1_l1 * 8)
    fig.add_subplot(4, 5, 10); plt.plot(s1_loss)

    fig.add_subplot(4, 2, 5); plt.imshow(sk[:,:,:3] * 4)
    fig.add_subplot(4, 2, 6); plt.imshow(sk_img[:,:,:3] * 4)
    # NOTE(review): only sk's FLIP map is scaled by 4 (t1/s1 maps are not) -- confirm intended
    fig.add_subplot(4, 3, 10); plt.imshow(sk_comp * 4)
    fig.add_subplot(4, 3, 11); plt.imshow(sk_l1[:,:,:-1] * 8)
    fig.add_subplot(4, 3, 12); plt.plot(sk_loss)

    plt.show()

if __name__ == "__main__":
    # Run the fit-and-plot pipeline only when executed as a script, not on import.
    main()


