import matplotlib.pyplot as plt
import numpy as np
import scipy
from pysr import PySRRegressor

# Number of random sample rows drawn from each lookup table before fitting.
N = 10_000

def fit_model(x, y):
    """Fit a PySR symbolic regressor that maps the columns of ``x`` to ``y``.

    Args:
        x: 2D array of inputs, one row per sample. Its columns are named
            "h", "c", "z", "g" in that order (only as many names as there
            are columns are used).
        y: Target values passed straight to ``PySRRegressor.fit``.

    Returns:
        The fitted ``PySRRegressor`` instance.
    """
    column_names = ["h", "c", "z", "g"]

    # populations=30 and population_size=64 are left disabled here.
    search_settings = {
        "maxsize": 20,
        "niterations": 10,
        "binary_operators": ["+", "-", "*", "/"],
        "warm_start": False,
        "turbo": True,
        "unary_operators": ["exp", "erf"],
    }
    regressor = PySRRegressor(**search_settings)

    n_columns = np.shape(x)[1]
    regressor.fit(x, y, variable_names=column_names[:n_columns])
    return regressor

def _unit_grid(table):
    """Build normalised (x, y) coordinates for every texel of a 2D table.

    x runs from 0 to 1 along axis 1 and y from 0 to 1 along axis 0, so the
    result has shape ``table.shape[:2] + (2,)`` and flattens row-for-row
    with the table itself.
    """
    rows, cols = table.shape[:2]
    # meshgrid defaults to indexing="xy": its outputs are shaped
    # (len(second arg), len(first arg)), hence the column count goes first.
    grid_x, grid_y = np.meshgrid(np.linspace(0, 1, cols), np.linspace(0, 1, rows))
    return np.stack([grid_x, grid_y], axis=-1)


def _tiled_grid(table, tiles=16):
    """Build the 4D coordinates of a table stored as a tiles x tiles atlas.

    Mirrors the shader lookup: the position inside a tile gives the first
    two coordinates and the tile index (scaled by 1 / tiles) gives the last
    two. The result has shape ``table.shape[:2] + (4,)``.
    """
    rows, cols = table.shape[:2]
    # Same axis ordering as _unit_grid so coordinates line up with the table.
    u, v = np.meshgrid(np.linspace(0, tiles, cols), np.linspace(0, tiles, rows))
    u_frac, u_tile = np.modf(u)
    v_frac, v_tile = np.modf(v)
    return np.stack([u_frac, v_frac, u_tile / tiles, v_tile / tiles], axis=-1)


def _sample_pairs(grid, table, count):
    """Draw ``count`` random (coordinate row, value row) pairs with replacement."""
    flat_grid = grid.reshape(-1, grid.shape[-1])
    flat_table = table.reshape(-1, table.shape[-1])
    # Sample over the real number of rows; using the sample count as the upper
    # bound would only ever see the first ``count`` texels of the table.
    picks = np.random.randint(flat_table.shape[0], size=count)
    return flat_grid[picks], flat_table[picks]


def main(first_fit_only=True):
    """Fit symbolic models to the path-traced transmittance/scattering tables.

    Loads the four lookup tables, builds the coordinates of every texel,
    draws ``N`` random samples per table and fits one model per colour
    channel with ``fit_model``.

    Args:
        first_fit_only: When True (the default), stop after fitting and
            printing the red channel of the first transmittance table.
            When False, fit every table and channel, evaluate the models on
            the full grids and display the resulting images.
    """
    # load tables; the trailing slice drops the last channel of each table
    t1 = np.load('../../images/tables/pathtrace-transmittance.npy')[:, :, :-1]  # x = height, y = cos-theta
    s1 = np.load('../../images/tables/pathtrace-scattering.npy')[:, :, :-1]  # x = cos-theta, y = cos-gamma, z = cos-light, w = height

    t2 = np.load('../../images/tables/pathtrace-transmittance-skewed.npy')[:, :, :-1]  # x = cos-theta, y = height
    s2 = np.load('../../images/tables/pathtrace-scattering-skewed.npy')[:, :, :-1]

    # table name -> (values, per-texel coordinates)
    # scattering is the 4d-ish grid that wraps every 16 times, as in the shader
    tables = {
        "t1": (t1, _unit_grid(t1)),
        "t2": (t2, _unit_grid(t2)),
        "s1": (s1, _tiled_grid(s1)),
        "s2": (s2, _tiled_grid(s2)),
    }

    # flat input-output pairs, N samples per table
    samples = {
        name: _sample_pairs(grid, values, N)
        for name, (values, grid) in tables.items()
    }

    # function fit!
    t1_inputs, t1_targets = samples["t1"]
    m_t1_r = fit_model(t1_inputs, t1_targets[:, 0])
    print(m_t1_r)

    if first_fit_only:
        # Plain return instead of the interactive-only exit() helper.
        return

    # Fit the remaining channels (0 = red, 1 = green, 2 = blue) of every table.
    models = {("t1", 0): m_t1_r}
    for channel in range(3):
        for name, (inputs, targets) in samples.items():
            if (name, channel) not in models:
                models[name, channel] = fit_model(inputs, targets[:, channel])

    # image: evaluate each fitted model on the full grid of its table
    images = {}
    for name, (_, grid) in tables.items():
        flat_grid = grid.reshape(-1, grid.shape[-1])
        images[name] = np.stack(
            [
                models[name, channel].predict(flat_grid).reshape(grid.shape[:-1])
                for channel in range(3)
            ],
            axis=-1,
        )

    # show image
    # NOTE(review): values are clamped to [0, 1] for display; the scattering
    # images are scaled by 4 and gamma-corrected first. Confirm the range.
    plt.imshow(np.clip(images["t1"], 0.0, 1.0))
    plt.show()
    plt.imshow(np.clip(images["t2"], 0.0, 1.0))
    plt.show()
    plt.imshow(np.power(np.clip(images["s1"] * 4, 0.0, 1.0), 1 / 2.2))
    plt.show()
    plt.imshow(np.power(np.clip(images["s2"] * 4, 0.0, 1.0), 1 / 2.2))
    plt.show()

    # reshape the scattering tables to fit the parameters
    # TODO

    # TODO: assign params to the curve fit

    # fit with scipy.curve_fit and np.polyfit

# Run the fitting script only when this file is executed directly.
if __name__ == "__main__":
    main()
