import numpy as np

# Model parameter archives, one .npz file per network. The individual arrays
# are pulled out by key further down.
transmittance, scattering = (
    np.load(archive)
    for archive in ('transmittance-model.npz', 'scattering-model.npz')
)

# GLSL helper shared by the generated functions: a logistic sigmoid, emitted
# once per vector type from a single template.
_SIGMOID_TEMPLATE = '''\
{vec} sigmoid({vec} x) {{
    return 1.0 / (1.0 + exp(-x));
}}
'''

# The accumulated shader source starts with a blank line followed by the
# overloads, which are separated from each other by one blank line.
shader = '\n' + '\n'.join(
    _SIGMOID_TEMPLATE.format(vec=vec_type) for vec_type in ('vec3', 'vec4')
)

# TODO: properly convert?
# https://www.youtube.com/watch?v=8pwXpfi-0bU&t=749s

# transmittance
# Emits `transmittance_nn`: vec2 input -> vec4 hidden layer (sigmoid)
# -> vec3 output (sigmoid).
tw1 = transmittance['w1']
tw2 = transmittance['w2']
tb1 = transmittance['b1']
tb2 = transmittance['b2']

# GLSL matrix constructors consume their arguments column by column, so each
# weight matrix is written out transposed.
# NOTE(review): this assumes the weights are stored as (outputs, inputs),
# i.e. w1 is 4x2 and w2 is 3x4 -- confirm against the script that wrote the
# .npz file.
# Every value is written with four decimal places, and all argument lists use
# the same ", " separator.
t_w1_args = ", ".join(f"{x:.4f}" for x in tw1.T.flatten())
t_b1_args = ", ".join(f"{x:.4f}" for x in tb1)
t_w2_args = ", ".join(f"{x:.4f}" for x in tw2.T.flatten())
t_b2_args = ", ".join(f"{x:.4f}" for x in tb2)

shader += f'''
vec3 transmittance_nn(float height, float cos_theta) {{
    // input
    vec2 x0 = vec2(height, cos_theta);

    // layer 1
    vec4 x1 = sigmoid(
        mat2x4({t_w1_args})
        * x0 + vec4({t_b1_args})
    );

    // layer 2
    return sigmoid(
        mat4x3({t_w2_args})
        * x1 + vec3({t_b2_args})
    );
}}
'''

# scattering
# Emits `scattering_nn`: vec4 input -> 8 hidden units (sigmoid)
# -> 8 hidden units (sigmoid) -> vec3 output (exp).
# The 8-wide layers are carried as two vec4 halves (`*_0` holds units 0-3,
# `*_1` holds units 4-7), so each weight matrix is emitted as 4-column
# sub-blocks whose products are summed.
tw1 = scattering['w1']
tw2 = scattering['w2']
tw3 = scattering['w3']
tb1 = scattering['b1']
tb2 = scattering['b2']
tb3 = scattering['b3']


def _glsl_args(values):
    """Format numbers as a ", "-separated GLSL constructor argument list.

    Every value is written with four decimal places.
    """
    return ", ".join(f"{x:.4f}" for x in values)


def _glsl_matrix_args(block):
    """Format a 2-D weight block as GLSL matrix constructor arguments.

    GLSL matrix constructors take their arguments column by column, hence
    the transpose before flattening.
    NOTE(review): assumes weights are stored as (outputs, inputs) -- confirm
    against the script that wrote the .npz file.
    """
    return _glsl_args(block.T.flatten())


# TODO: more, because we have bigger layers
shader += f'''
vec3 scattering_nn(float height, float cos_theta, float cos_light, float cos_gamma) {{
    // input
    vec4 x0 = vec4(height, cos_theta, cos_light, cos_gamma);

    // layer 1
    vec4 x1_0 = sigmoid(
        mat4x4({_glsl_matrix_args(tw1[0:4, :])})
        * x0 + vec4({_glsl_args(tb1[0:4])})
    );
    vec4 x1_1 = sigmoid(
        mat4x4({_glsl_matrix_args(tw1[4:8, :])})
        * x0 + vec4({_glsl_args(tb1[4:8])})
    );

    // layer 2
    vec4 x2_0 = sigmoid(
        mat4x4({_glsl_matrix_args(tw2[0:4, 0:4])}) * x1_0
        + mat4x4({_glsl_matrix_args(tw2[0:4, 4:8])}) * x1_1
        + vec4({_glsl_args(tb2[0:4])})
    );

    vec4 x2_1 = sigmoid(
        mat4x4({_glsl_matrix_args(tw2[4:8, 0:4])}) * x1_0
        + mat4x4({_glsl_matrix_args(tw2[4:8, 4:8])}) * x1_1
        + vec4({_glsl_args(tb2[4:8])})
    );

    // layer 3
    return exp(
        mat4x3({_glsl_matrix_args(tw3[:, 0:4])}) * x2_0
        + mat4x3({_glsl_matrix_args(tw3[:, 4:8])}) * x2_1
        + vec3({_glsl_args(tb3)})
    );
}}
'''

# Emit the assembled GLSL source on stdout.
print(shader)
