import matplotlib.pyplot as plt
import numpy as np
import scipy
import concurrent.futures as futures

# Shared worker pool that fit() submits its per-channel fit_work jobs to.
pool = futures.ThreadPoolExecutor(max_workers=31)

# small numeric helpers shared by the model functions below
def clamp(x):
    """Clamp *x* element-wise into the closed interval [0, 1]."""
    floored = np.maximum(x, 0.0)
    return np.minimum(floored, 1.0)

def smoothstep(x):
    """Squash *x* element-wise into (0, 1) with the Gumbel CDF ``exp(-exp(-x))``.

    Despite the name, this is not the classic cubic smoothstep.
    ``x*x*(3 - 2*x)`` applied to ``clamp(x)`` used to sit after this
    ``return``, where it could never execute. The Gumbel curve is therefore
    what every caller has always received.

    The Gumbel curve was presumably preferred because, unlike the clamped
    cubic, it has no flat saturated regions for the optimizer to get stuck in.
    """
    return np.exp(-np.exp(-x))

# candidate model functions for curve_fit; `p` carries the input coordinates on its last axis
def poly3_2d(p, a, b, c, d, e, f, g):
    """Separable cubic in two inputs.

    Evaluates ``g + (a*x + b*x^2 + c*x^3) + (d*y + e*y^2 + f*y^3)``.

    p : array whose last axis holds (x, y); the result drops that axis.
    a, b, c : linear, quadratic and cubic coefficients for x.
    d, e, f : linear, quadratic and cubic coefficients for y.
    g : constant term.
    """
    x = p[..., 0]
    y = p[..., 1]
    return g + a*x + b*x*x + c*x*x*x + d*y + e*y*y + f*y*y*y

def poly3_4d(p, a, b, c, d, e, f, g, h, i, j, k, l, m):
    """Separable cubic in four inputs (x, y, z, w) plus a constant term.

    p : array whose last axis holds at least (x, y, z, w); the result drops
        that axis.
    a, b, c : linear, quadratic and cubic coefficients for x.
    d, e, f : the same for y.
    g, h, i : the same for z.
    j, k, l : the same for w.
    m : constant term.
    """
    x, y, z, w = (p.take(n, axis=-1) for n in range(4))
    # The w-linear term must be j*w. A former `j+w` typo turned j into a
    # second constant (degenerate with m) and left w's linear term unscaled.
    return m + a*x + b*x*x + c*x*x*x + d*y + e*y*y + f*y*y*y \
             + g*z + h*z*z + i*z*z*z + j*w + k*w*w + l*w*w*w

def poly3_step_2d(p, a, b, c, d, e, f, g):
    """Separable 2D cubic (see poly3_2d) squashed through smoothstep.

    p : array whose last axis holds (x, y).
    a..g : same coefficients as poly3_2d.
    """
    raw = poly3_2d(p, a, b, c, d, e, f, g)
    return smoothstep(raw)

def poly3_step_4d(p, a, b, c, d, e, f, g, h, i, j, k, l, m):
    """Separable 4D cubic squashed through smoothstep.

    p : array whose last axis holds at least (x, y, z, w).
    a, b, c / d, e, f / g, h, i / j, k, l : linear, quadratic and cubic
        coefficients for x / y / z / w respectively.
    m : constant term.
    """
    x, y, z, w = (p.take(n, axis=-1) for n in range(4))
    # j scales w. This was previously the typo `j+w`, which made j a
    # redundant constant alongside m.
    return smoothstep(m + a*x + b*x*x + c*x*x*x + d*y + e*y*y + f*y*y*y
             + g*z + h*z*z + i*z*z*z + j*w + k*w*w + l*w*w*w)

def step_2d(p, a, b, d):
    """Two-input Gumbel-style step model.

    h is read from channel 0 of *p*. Channel 1 is remapped as ``2*v - 1``
    (so inputs in [0, 1] map to c in [-1, 1]).

    Returns ``exp(-exp(a - b*(h^2 - d)*(c + 1)))``.
    """
    h = p[..., 0]
    c = p[..., 1] * 2 - 1
    exponent = a - b * (h * h - d) * (c + 1)
    return np.exp(-np.exp(exponent))

def opt_2d(p, a, b, d, e):
    """Two-input model ``exp(-a * exp(e - b*h) / (1 - d*c))``.

    h and c are channels 0 and 1 of *p*. The author notes that this model
    also fails to train.

    NOTE(review): the denominator is zero wherever ``c * d == 1``.
    """
    h = p[..., 0]
    c = p[..., 1]
    return np.exp(-a * np.exp(e - h * b) / (1 - c * d))

# model selection: swap these wrappers to change which model main() fits
#def fn_2d(p, a, b, c, d, e, f, g):
#    return poly3_2d(p, a, b, c, d, e, f, g)

def fn_2d(p, a, b, c, d):
    """Active 2D model fitted and evaluated by main(); currently opt_2d."""
    coeffs = (a, b, c, d)
    return opt_2d(p, *coeffs)

def fn_4d(p, a, b, c, d, e, f, g, h, i, j, k, l, m):
    """Active 4D model fitted by main(); currently poly3_4d."""
    coeffs = (a, b, c, d, e, f, g, h, i, j, k, l, m)
    return poly3_4d(p, *coeffs)

# fitting helpers: fit_work fits one channel; fit submits one fit_work job per r/g/b channel
def fit_work(name, fn, x, y, start):
    """Fit *fn* to the samples (x, y) with scipy.optimize.curve_fit.

    fit() submits this to the thread pool, one job per colour channel.

    name : label used in the progress and result printouts.
    fn : model ``fn(x, *params)``; curve_fit infers the parameter count
        from its signature when *start* is None.
    start : initial guess ``p0`` (None lets curve_fit default to all ones).

    Returns a zero-argument callable. Calling it prints ``"<name>: <params>"``
    and returns the fitted parameter array. Fitting errors raised by
    curve_fit propagate to the caller (through the future when run on the pool).
    """
    # Import the submodule explicitly: `import scipy` alone only exposes
    # `scipy.optimize` on SciPy releases with lazy submodule loading.
    from scipy import optimize

    params, _ = optimize.curve_fit(fn, x, y, p0=start)
    print(name, 'done')

    def report():
        # Deferred so the caller controls when the parameters are printed.
        print(name + ":", params)
        return params

    return report

def fit(name, fn, x, y, start=None):
    """Submit one fit_work job per colour channel of *y* to the shared pool.

    Columns 0, 1 and 2 of *y* are fitted against *x* under the labels
    ``name + " r"``, ``name + " g"`` and ``name + " b"``.

    Returns the (r, g, b) futures; each resolves to fit_work's reporting
    callable.
    """
    return tuple(
        pool.submit(fit_work, name + suffix, fn, x, y[:, column], start)
        for column, suffix in enumerate((" r", " g", " b"))
    )

def _unit_grid_2d(table):
    """Return per-texel (x, y) coordinates in [0, 1] shaped like *table*.

    The result has shape ``table.shape[:2] + (2,)``: x varies along axis 1
    (columns) and y along axis 0 (rows), so it pairs element-for-element
    with the table.
    """
    rows, cols = table.shape[0], table.shape[1]
    # meshgrid's default 'xy' indexing returns arrays of shape (len(y), len(x)),
    # so x must be sampled over the column count and y over the row count.
    # Sampling them the other way round transposes the grid relative to the
    # table and mis-pairs inputs with outputs whenever rows != cols.
    x, y = np.meshgrid(np.linspace(0, 1, cols), np.linspace(0, 1, rows))
    return np.stack([x, y], axis=-1)


def _tiled_grid_4d(table, tiles=16):
    """Unfold a tiles x tiles atlas into per-texel (x, y, z, w) coordinates.

    u and v span [0, tiles] across the columns and rows. The fractional
    parts give the in-tile coordinates (x, y); the integer parts divided by
    *tiles* give the tile indices (z, w). The result has shape
    ``table.shape[:2] + (4,)``.
    """
    rows, cols = table.shape[0], table.shape[1]
    # Same 'xy' meshgrid ordering as _unit_grid_2d, so the grid matches the table.
    u, v = np.meshgrid(np.linspace(0, tiles, cols), np.linspace(0, tiles, rows))
    u_frac, u_tile = np.modf(u)
    v_frac, v_tile = np.modf(v)
    # NOTE(review): linspace includes its endpoint, so the last column/row lands
    # exactly on u == tiles (x = 0, z = 1); confirm this matches the shader.
    return np.stack([u_frac, v_frac, u_tile / tiles, v_tile / tiles], axis=-1)


def _fitted_image(fn, coords, channel_fits):
    """Evaluate *fn* on *coords* once per fitted channel and stack the results.

    channel_fits are the futures returned by fit(). They are resolved in
    order, and each resolution prints that channel's parameters.
    """
    return np.stack([fn(coords, *fut.result()()) for fut in channel_fits], axis=-1)


def main():
    """Fit models to the precomputed lookup tables and preview the t1 fit.

    Loads the transmittance (t1, t2) and scattering (s1, s2) tables and
    submits per-channel fits for each. It then shows the original t1 table
    followed by the image reconstructed from its fitted model.
    """
    # Load tables. `[:, :, :-1]` drops the last channel (presumably alpha; confirm).
    # NOTE(review): these paths are relative to the current working directory.
    t1 = np.load('../../images/tables/pathtrace-transmittance.npy')[:, :, :-1]  # x = height, y = cos-theta
    s1 = np.load('../../images/tables/pathtrace-scattering.npy')[:, :, :-1]  # x = cos-theta, y = cos-gamma, z = cos-light, w = height
    t2 = np.load('../../images/tables/pathtrace-transmittance-skewed.npy')[:, :, :-1]  # x = cos-theta, y = height
    s2 = np.load('../../images/tables/pathtrace-scattering-skewed.npy')[:, :, :-1]

    # per-texel input coordinates, laid out exactly like the tables
    t1_p = _unit_grid_2d(t1)
    t2_p = _unit_grid_2d(t2)

    # Scattering is the 4d-ish grid: redo the shader's 16x16 tiling.
    s1_p = _tiled_grid_4d(s1)
    s2_p = _tiled_grid_4d(s2)

    # flat input-output pairs: row k of each *_p_f matches row k of its *_r_f
    t1_p_f, t1_r_f = np.reshape(t1_p, (-1, 2)), np.reshape(t1, (-1, 3))
    t2_p_f, t2_r_f = np.reshape(t2_p, (-1, 2)), np.reshape(t2, (-1, 3))
    s1_p_f, s1_r_f = np.reshape(s1_p, (-1, 4)), np.reshape(s1, (-1, 3))
    s2_p_f, s2_r_f = np.reshape(s2_p, (-1, 4)), np.reshape(s2, (-1, 3))

    # Per-channel fits; fit() submits them to the shared pool.
    t1_fits = fit("t1", fn_2d, t1_p_f, t1_r_f, start=np.array([1, 1, 0.8, 1]))
    t2_fits = fit("t2", fn_2d, t2_p_f, t2_r_f, start=np.array([1, 1, 0.8, 1]))

    # NOTE(review): the s1/s2 fits are only consumed by the commented-out preview below.
    s1_fits = fit("s1", fn_4d, s1_p_f, s1_r_f)
    s2_fits = fit("s2", fn_4d, s2_p_f, s2_r_f)

    # images reconstructed from the fitted models
    t1_img = _fitted_image(fn_2d, t1_p, t1_fits)
    t2_img = _fitted_image(fn_2d, t2_p, t2_fits)
    # s1_img = _fitted_image(fn_4d, s1_p, s1_fits)
    # s2_img = _fitted_image(fn_4d, s2_p, s2_fits)

    # show images
    plt.imshow(clamp(t1))
    plt.show()
    plt.imshow(clamp(t1_img))
    plt.show()
    # plt.imshow(clamp(t2_img))
    # plt.show()
    # plt.imshow(np.power(clamp(s1_img * 4), 1 / 2.2))
    # plt.show()
    # plt.imshow(np.power(clamp(s2_img * 4), 1 / 2.2))
    # plt.show()

    # TODO: reshape the scattering tables to fit the parameters
    # TODO: assign params to the curve fit
    # fit with scipy.curve_fit and np.polyfit


if __name__ == "__main__":
    main()
