import matplotlib.pyplot as plt
import numpy as np
import torch
from torch import nn
from torch.utils.data import TensorDataset, DataLoader
import flip_evaluator as flip

class Exp(nn.Module):
    """Parameter-free layer applying the element-wise exponential ``exp(x)``."""

    def forward(self, x):
        # element-wise exp; the module owns no parameters or buffers
        return x.exp()

class NN(nn.Module):
    """Two-layer MLP with a sigmoid after both the hidden and the output layer.

    Args:
        channels: number of input features.
        hidden: width of the hidden layer.
        colors: number of outputs.
    """

    def __init__(self, channels, hidden, colors):
        super().__init__()
        # attribute names form the state_dict keys ("l1.weight", ...), keep them stable
        self.l1 = nn.Linear(channels, hidden)
        self.a1 = nn.Sigmoid()
        self.l2 = nn.Linear(hidden, colors)
        self.a2 = nn.Sigmoid()

    def forward(self, x):
        """Return ``sigmoid(l2(sigmoid(l1(x))))``."""
        hidden_act = self.a1(self.l1(x))
        return self.a2(self.l2(hidden_act))

    def save(self, path):
        """Write the layer weights/biases to ``path`` with ``np.savez`` and echo them.

        Arrays are stored under the keys ``w1``, ``b1``, ``w2``, ``b2``.
        """
        arrays = {}
        for w_key, b_key, layer in (('w1', 'b1', self.l1), ('w2', 'b2', self.l2)):
            arrays[w_key] = layer.weight.detach().numpy()
            arrays[b_key] = layer.bias.detach().numpy()

        # save with numpy
        np.savez(path, **arrays)
        for key, value in arrays.items():
            print(path, key, value)

class NN2(nn.Module):
    """Three-layer MLP: two sigmoid hidden layers followed by an ``Exp`` output.

    The exponential output keeps predictions strictly positive.

    Args:
        channels: number of input features.
        hidden: width of both hidden layers.
        colors: number of outputs.
    """

    def __init__(self, channels, hidden, colors):
        super().__init__()
        # attribute names form the state_dict keys ("l1.weight", ...), keep them stable
        self.l1 = nn.Linear(channels, hidden)
        self.a1 = nn.Sigmoid()
        self.l2 = nn.Linear(hidden, hidden)
        self.a2 = nn.Sigmoid()
        self.l3 = nn.Linear(hidden, colors)
        self.a3 = Exp()

    def forward(self, x):
        """Apply (l1, a1), (l2, a2), (l3, a3) in sequence."""
        out = x
        for linear, activation in ((self.l1, self.a1), (self.l2, self.a2), (self.l3, self.a3)):
            out = activation(linear(out))
        return out

    def save(self, path):
        """Write all layer weights/biases to ``path`` with ``np.savez`` and echo them.

        Arrays are stored under the keys ``w1``, ``b1``, ``w2``, ``b2``, ``w3``, ``b3``.
        """
        layers = (('w1', 'b1', self.l1), ('w2', 'b2', self.l2), ('w3', 'b3', self.l3))
        arrays = {}
        for w_key, b_key, layer in layers:
            arrays[w_key] = layer.weight.detach().numpy()
            arrays[b_key] = layer.bias.detach().numpy()

        # save with numpy
        np.savez(path, **arrays)
        for key, value in arrays.items():
            print(path, key, value)

class SplitNN(nn.Module):
    """Two independent MLP branches weighted by analytic phase functions.

    The ``ray`` branch is multiplied by the Rayleigh phase function
    ``3 / (16 pi) * (1 + mu^2)``. The ``mie`` branch is multiplied by the
    Cornette-Shanks phase function
    ``3 / (8 pi) * (1 - g^2) / (2 + g^2) * (1 + mu^2) / (1 + g^2 - 2 g mu)^1.5``.
    Here ``mu`` is the input feature at index ``cos_theta_index``.

    Args:
        channels: number of input features.
        hidden: hidden width of each branch.
        colors: number of outputs per branch.
        g: Mie asymmetry parameter; must satisfy ``-1 < g < 1``.
        cos_theta_index: position of the cos-theta feature in the last input dimension.

    Raises:
        ValueError: if ``g`` is outside ``(-1, 1)``. At ``|g| == 1`` the phase
            function degenerates to ``0 / 0``.
    """

    def __init__(self, channels, hidden, colors, g=0.7, cos_theta_index=3):
        super().__init__()
        if not -1.0 < g < 1.0:
            raise ValueError(f"Mie asymmetry g must be in (-1, 1), got {g}")
        # plain attributes (not parameters): they do not appear in the state_dict
        self.g = g
        self.cos_theta_index = cos_theta_index

        self.ray = nn.Sequential(
            nn.Linear(channels, hidden), nn.Sigmoid(),
            nn.Linear(hidden, colors), nn.Sigmoid(),
        )
        self.mie = nn.Sequential(
            nn.Linear(channels, hidden), nn.Sigmoid(),
            nn.Linear(hidden, colors), nn.Sigmoid(),
        )

    def forward(self, x):
        # keep a trailing singleton dim so the phase terms broadcast over the color channels
        cos_theta = x.select(-1, self.cos_theta_index).unsqueeze(-1)
        cos2 = cos_theta * cos_theta
        g = self.g
        g2 = g * g

        ray = self.ray(x)
        mie = self.mie(x)

        # phase functions
        k_ray = 3.0 / (16.0 * torch.pi)
        phase_ray = k_ray * (1.0 + cos2)

        k_mie = 3.0 / (8.0 * torch.pi) * (1.0 - g2) / (2.0 + g2)
        phase_mie = k_mie * (1.0 + cos2) / torch.pow(1.0 + g2 - 2.0 * g * cos_theta, 1.5)

        return ray * phase_ray + mie * phase_mie

class Chap(nn.Module):
    """Per-channel product of two logistic curves over (height, cos_theta).

    Each term is ``sigmoid(p0 + p1 * h + p2 * mu + p3 * mu^2)``, with one
    coefficient column per output channel. ``h`` is input feature 0 and
    ``mu`` is input feature 1. The model returns ``ray * mie``.

    Args:
        colors: number of output channels.
    """

    def __init__(self, colors):
        super().__init__()

        # TODO: try random stuff
        # idea 1: look at the tables and try to figure out if it's fittable with smoothstep
        # probably possible for transmittance at least
        # sigmoid(height + polynomial for cos_theta), one for each channel

        # polynomial coefficients: row k multiplies the k-th term of [1, h, mu, mu^2],
        # one column per output channel
        self.ray_pars = nn.Parameter(torch.rand(4, colors))
        self.mie_pars = nn.Parameter(torch.rand(4, colors))

    @staticmethod
    def _logit(pars, height, cos_theta):
        """Evaluate the per-channel polynomial ``p0 + p1*h + p2*mu + p3*mu^2``."""
        return pars[0] + pars[1] * height + pars[2] * cos_theta + pars[3] * cos_theta * cos_theta

    def forward(self, x):
        # get params out; trailing singleton dim broadcasts against the per-channel coefficients
        height = x.select(-1, 0).unsqueeze(-1)
        cos_theta = x.select(-1, 1).unsqueeze(-1)

        # torch.sigmoid is numerically stable; a hand-written 1 / (1 + exp(-z))
        # overflows for very negative z and produces NaN gradients
        ray = torch.sigmoid(self._logit(self.ray_pars, height, cos_theta))
        mie = torch.sigmoid(self._logit(self.mie_pars, height, cos_theta))

        # TODO: ozone?

        return ray * mie

def train_model(x, y, model, name=None, batch_size=8192, epochs=10, lr=1e-2):
    """Fit ``model`` to map ``x`` to ``y`` using Adam and an L1 loss.

    Args:
        x: input samples, array-like whose first dimension indexes samples.
        y: targets, array-like with the same first dimension as ``x``.
        model: the ``nn.Module`` to train; it is updated in place.
        name: label used in progress output (``None`` is allowed).
        batch_size: mini-batch size. When the sample count is not a multiple
            of it, the last batch of each epoch is smaller.
        epochs: number of passes over freshly shuffled data.
        lr: Adam learning rate.

    Returns:
        ``(model, loss_curve)``: the trained model (left in eval mode) and a
        list with the sample-weighted mean training loss of each epoch.
    """
    # to tensors (as_tensor avoids a warning/copy when tensors are passed in)
    x = torch.as_tensor(x, dtype=torch.float)
    y = torch.as_tensor(y, dtype=torch.float)
    n = x.shape[0]

    # network
    loss_fn = nn.L1Loss()
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    loss_curve = []

    model.train()
    for epoch in range(epochs):
        perm = torch.randperm(n)

        # slice the permutation instead of reshaping it so that a sample count
        # that is not a multiple of batch_size still works (last batch is smaller)
        epoch_loss = 0.0
        for start in range(0, n, batch_size):
            idx = perm[start:start + batch_size]

            # forward
            optimizer.zero_grad()
            loss = loss_fn(model(x[idx]), y[idx])

            # backprop
            loss.backward()
            optimizer.step()

            # weight by batch size so the epoch value is the mean over all samples
            epoch_loss += loss.item() * idx.numel() / n

        loss_curve.append(epoch_loss)
        print(name, "epoch", epoch + 1, "loss", epoch_loss)

    model.eval()

    # f-string so the default name=None does not raise a TypeError
    print(f"{name}:")
    for key, value in model.state_dict().items():
        print(' - ', key + ':', value)

    return model, loss_curve


def main():
    """Fit small MLPs to the path-traced atmosphere tables, evaluate, plot and export.

    Loads three ``.npy`` tables from ``../../images/tables/`` and rebuilds the
    shader's lookup coordinates for each one. It then trains one network per
    table with ``train_model`` and reports FLIP and MSE errors against the
    tables. Finally it writes ``scattering.png``, saves the transmittance and
    scattering weights to ``.npz``, and shows a comparison figure.
    """
    # load tables
    t1 = np.load('../../images/tables/pathtrace-transmittance.npy')[:, :, :-1] # x = height, y = cos-theta
    s1 = np.load('../../images/tables/pathtrace-scattering-proper.npy')[:, :, :-1] # x = height, y = cos-theta, z = cos-light, w = cos-gamma
    sk = np.load('../../images/tables/pathtrace-scattering-no-mie.npy') # x = height, y = cos_theta, z = cos-light
    #s1 = np.load('../../images/tables/pathtrace-scattering-skewed.npy')[:,:, :-1]

    # clip to some maximum value to reduce fireflies
    s1 = np.clip(s1, 0, 4)
    sk = np.clip(sk, 0, 4)

    # generate parameters
    # np.meshgrid (default 'xy' indexing) returns arrays shaped (len(2nd), len(1st)),
    # so the column count goes first to match the table's (rows, cols) layout
    t1_x, t1_y = np.meshgrid(np.linspace(0, 1, t1.shape[1]), np.linspace(0, 1, t1.shape[0]))

    # stack for channels, makes it easier with flattening
    t1_p = np.stack([t1_x * (6460 - 6360), t1_y * 2 - 1], axis=-1)
    #t1_p = np.stack([t1_x, t1_y, t1_x * t1_y, t1_x + t1_y], axis=-1)

    # scattering is the 4d-ish grid, just redo what was done in the shader
    # it wraps every 16 times
    # see the path tracer (../comparison/shaders/pt-tables) for reasoning
    # (same meshgrid axis order as above: columns first)
    s1_u, s1_v = np.meshgrid(np.linspace(0, 16, s1.shape[1]), np.linspace(0, 16, s1.shape[0]))
    s1_x, s1_z_int = np.modf(s1_u)
    s1_y, s1_w_int = np.modf(s1_v)
    s1_z, s1_w = s1_z_int / 16, s1_w_int / 16

    # clamp nu
    s1_w = np.clip(s1_w * 2 - 1,
        s1_y * s1_z - np.sqrt((1.0 - s1_y * s1_y) * (1.0 - s1_z * s1_z)),
        s1_y * s1_z + np.sqrt((1.0 - s1_y * s1_y) * (1.0 - s1_z * s1_z))
    )

    # again, stack
    s1_p = np.stack([s1_x * s1_x * (6460 - 6360), s1_y * 2 - 1, s1_z * 1.2 - 0.2, s1_w], axis=-1)

    # 3d grid for the scattering
    # again, redo the shader
    sk_x, sk_y = np.meshgrid(np.linspace(0, 1, sk.shape[1]), np.linspace(0, 1, sk.shape[0]))
    sk_x, sk_z = np.modf(sk_x * 32)
    sk_x *= sk_x
    sk_z /= 32

    # stack
    sk_p = np.stack([sk_x * (6460 - 6360), sk_y * 2 - 1, sk_z * 1.2 - 0.2], axis=-1)

    # flat input-output pairs
    t1_p_f = np.reshape(t1_p, (-1, t1_p.shape[-1]))
    t1_r_f = np.reshape(t1, (-1, 3))

    s1_p_f = np.reshape(s1_p, (-1, s1_p.shape[-1]))
    s1_r_f = np.reshape(s1, (-1, 3))

    sk_p_f = np.reshape(sk_p, (-1, sk_p.shape[-1]))
    sk_r_f = np.reshape(sk, (-1, 4))

    # fit 3 neural nets
    t1_model, t1_loss = train_model(t1_p_f, t1_r_f, NN(2,4,3), "t1", lr=1e-2, epochs=40) # 2 works but misses ozone, 4 works best
    s1_model, s1_loss = train_model(s1_p_f, s1_r_f, NN2(4, 8, 3), "s1", lr=1e-2, epochs=40) # 4 works but misses height, 16 works best
    sk_model, sk_loss = train_model(sk_p_f, sk_r_f, NN(3, 4, 4), "sk", lr=1e-3, epochs=2) # 8 works kinda, 16 works best

    # no more training: torch.no_grad() only takes effect when entered as a context manager
    with torch.no_grad():
        t1_img = t1_model(torch.tensor(t1_p, dtype=torch.float)).numpy()
        s1_img = s1_model(torch.tensor(s1_p, dtype=torch.float)).numpy()
        sk_img = sk_model(torch.tensor(sk_p, dtype=torch.float)).numpy()

    # sanity check of the predicted scattering range
    print(s1_img.shape, np.min(s1_img), np.max(s1_img), np.any(np.isnan(s1_img)))

    # save the image for inspection (gamma 2.2, clipped to [0, 1])
    print("s1 bounds", np.min(s1_img), np.max(s1_img))
    plt.imsave("scattering.png", np.clip(s1_img ** (1.0 / 2.2), 0.0, 1.0))

    # done, plot
    print("plotting...")

    # NOTE(review): sk/sk_img have 4 channels; confirm flip.evaluate accepts non-RGB input
    t1_comp, t1_err, _ = flip.evaluate(t1, t1_img, 'HDR', computeMeanError=True)
    s1_comp, s1_err, _ = flip.evaluate(s1, s1_img, 'HDR', computeMeanError=True)
    sk_comp, sk_err, _ = flip.evaluate(sk, sk_img, 'HDR', computeMeanError=True)
    t1_l1 = np.abs(t1_img - t1)
    s1_l1 = np.abs(s1_img - s1)
    sk_l1 = np.abs(sk_img - sk)

    print("t1 flip", t1_err, "t1 MSE", np.mean((t1_img - t1)**2))
    print("s1 flip", s1_err, "s1 MSE", np.mean((s1_img - s1)**2))
    print("sk flip", sk_err, "sk MSE", np.mean((sk_img - sk)**2))

    #fig = plt.figure()
    #fig.add_subplot(4, 5, 1); plt.imshow(t1)
    #fig.add_subplot(4, 5, 2); plt.imshow(t1_img)
    #fig.add_subplot(4, 5, 3); plt.imshow(t1_comp)
    #fig.add_subplot(4, 5, 4); plt.imshow(t1_l1)
    #fig.add_subplot(4, 5, 5); plt.plot(t1_loss)
    #
    #fig.add_subplot(4, 5, 6); plt.imshow(s1 * 4)
    #fig.add_subplot(4, 5, 7); plt.imshow(s1_img * 4)
    #fig.add_subplot(4, 5, 8); plt.imshow(s1_comp)
    #fig.add_subplot(4, 5, 9); plt.imshow(s1_l1 * 8)
    #fig.add_subplot(4, 5, 10); plt.plot(s1_loss)

    #fig.add_subplot(4, 2, 5); plt.imshow(sk[:,:,:3] * 4)
    #fig.add_subplot(4, 2, 6); plt.imshow(sk_img[:,:,:3] * 4)
    #fig.add_subplot(4, 3, 10); plt.imshow(sk_comp * 4)
    #fig.add_subplot(4, 3, 11); plt.imshow(sk_l1[:,:,:-1] * 8)
    #fig.add_subplot(4, 3, 12); plt.plot(sk_loss)

    fig = plt.figure()
    fig.add_subplot(2, 2, 1); plt.imshow(s1 * 6)
    fig.add_subplot(2, 2, 2); plt.imshow(s1_img * 6)
    fig.add_subplot(2, 2, 3); plt.imshow(s1_comp)
    fig.add_subplot(2, 2, 4); plt.imshow(s1_l1 * 10)

    # save trained models before the blocking plt.show(), so closing or killing
    # the plot window cannot lose them
    t1_model.save('transmittance-model.npz')
    s1_model.save('scattering-model.npz')

    plt.show()

    # TODO: turn into code
    # TODO: train independent models for scattering?

# Script entry point: run the fit / evaluate / export pipeline only when this
# file is executed directly, not when it is imported.
if __name__ == "__main__":
    main()

