#version 460

// Atmosphere model parameters (set 0, binding 15).
// Not every member is read by this file: only solar_irradiance, bottom_radius
// and top_radius are referenced here. The member list and order must still
// match the host-side buffer, so keep unused members in place.
//
// NOTE(review): Vulkan GLSL uniform blocks default to std140, where a vec3 is
// 16-byte aligned. By std140 rules this inserts implicit padding before
// mie_scattering (12 bytes), mie_extinction (4), absorb_extinction (8) and
// ground_albedo (4). Confirm the host-side struct mirrors these offsets.
layout(set = 0, binding = 15) uniform atmosphere {
    vec3 solar_irradiance;      // multiplied into the rendered radiance in render()
    float sun_angular_radius;   // packs into solar_irradiance's trailing 4 bytes
    float bottom_radius;        // ground sphere radius (planet centred at origin)
    float top_radius;           // top-of-atmosphere sphere radius

    // rayleigh density: two profile layers (a, b), each described by
    // width / exp term / exp scale / linear term / constant term
    // (presumably mirrors a Bruneton-style DensityProfileLayer; unused here)
    float ray_width_a;
    float ray_exp_term_a;
    float ray_exp_scale_a;
    float ray_linear_term_a;
    float ray_constant_term_a;

    float ray_width_b;
    float ray_exp_term_b;
    float ray_exp_scale_b;
    float ray_linear_term_b;
    float ray_constant_term_b;

    vec3 ray_scattering;

    // mie density: same two-layer layout as the rayleigh profile
    float mie_width_a;
    float mie_exp_term_a;
    float mie_exp_scale_a;
    float mie_linear_term_a;
    float mie_constant_term_a;

    float mie_width_b;
    float mie_exp_term_b;
    float mie_exp_scale_b;
    float mie_linear_term_b;
    float mie_constant_term_b;

    vec3 mie_scattering;        // std140: 12 bytes of padding precede this
    vec3 mie_extinction;        // std140: 4 bytes of padding precede this
    float mie_g;

    // absorption (ozone) density: same two-layer layout
    float absorb_width_a;
    float absorb_exp_term_a;
    float absorb_exp_scale_a;
    float absorb_linear_term_a;
    float absorb_constant_term_a;

    float absorb_width_b;
    float absorb_exp_term_b;
    float absorb_exp_scale_b;
    float absorb_linear_term_b;
    float absorb_constant_term_b;

    vec3 absorb_extinction;     // std140: 8 bytes of padding precede this
    vec3 ground_albedo;         // std140: 4 bytes of padding precede this
};

// Per-view parameters (set 0, binding 12).
// NOTE(review): under std140 sun_dir is 16-byte aligned, so 4 bytes of
// padding follow cam_pos; exposure fills sun_dir's trailing 4 bytes.
layout(set = 0, binding = 12) uniform view {
    vec3 cam_pos; // camera position relative to the ground; the passes add bottom_radius on y to make it planet-centred
    vec3 sun_dir; // sun direction; dotted with unit vectors, so presumably normalized on the host — confirm
    float exposure; // scalar multiplier applied to the rendered radiance
};

// Push constants.
layout(push_constant) uniform push {
    uint layer;  // not referenced anywhere in this file
    uint width;  // presumably render-target width in pixels; used for the aspect ratio
    uint height; // presumably render-target height in pixels; NOTE: the network
                 // functions take a parameter also named `height` (altitude),
                 // which shadows this member inside those functions
};

// Interpolated screen coordinate. The passes subtract 0.5 and treat the result
// as a centred UV, so it presumably spans [0, 1] — confirm against the vertex shader.
layout(location = 0) in noperspective centroid vec2 fragcoord;
// Output colour: the passes write .xyz, and main() initialises the whole vec4.
layout(location = 0) out vec4 fragcolor;

// <START>
// neural networks
// Component-wise logistic function 1 / (1 + e^-x); the output activation of
// transmittance_nn. Large negative inputs make exp overflow to +inf, which
// still yields the correct limit 0.
vec3 sigmoid(vec3 x) {
    vec3 denom = exp(-x) + 1.0;
    return 1.0 / denom;
}

// vec4 overload of the logistic function; the hidden-layer activation in
// transmittance_nn and scattering_nn.
vec4 sigmoid(vec4 x) {
    vec4 denom = exp(-x) + 1.0;
    return 1.0 / denom;
}

// Baked 2 -> 4 -> 3 MLP (sigmoid activations) returning an RGB transmittance
// in (0, 1).
//   height    : altitude above the ground sphere (callers pass
//               length(p) - bottom_radius)
//   cos_theta : cosine between the ray direction and the local up vector.
//               NOTE(review): callers pass both +cos and -cos depending on ray
//               direction; the sign convention comes from training — confirm.
// GLSL matrix constructors fill columns first, so the weight order below is
// significant and must not be rearranged.
vec3 transmittance_nn(float height, float cos_theta) {
    const mat2x4 w_hidden = mat2x4(0.0214, -0.0301, 0.0167, -0.0137, -7.7263, 0.3688, -23.9130, 27.3124);
    const vec4 b_hidden = vec4(0.2639, 0.9069, 2.9238, -3.3962);
    const mat4x3 w_out = mat4x3(5.1255, 6.4604, 5.3080, -10.4257, -12.4410, -19.4679, 5.5411, 6.7883, 8.2314, -13.7020, -16.9345, -24.2136);
    const vec3 b_out = vec3(-1.3815, -3.5609, 0.0589);

    vec4 hidden = sigmoid(w_hidden * vec2(height, cos_theta) + b_hidden);
    return sigmoid(w_out * hidden + b_out);
}

// Baked MLP approximating RGB in-scattering: 4 inputs -> 8 hidden -> 8 hidden
// -> 3 outputs. Each 8-wide layer is split into two vec4 halves (x*_0, x*_1)
// because GLSL's largest matrix is 4x4.
//   height    : altitude above the ground sphere
//   cos_theta : cosine between ray direction and local up (callers negate it;
//               sign convention comes from training — NOTE(review): confirm)
//   cos_light : cosine between sun_dir and local up
//   cos_gamma : cosine between ray direction and sun_dir
// The output layer uses exp(), so the result is strictly positive and the
// final linear layer effectively predicts log-scattering. Callers multiply the
// result by solar_irradiance * exposure.
// GLSL matrix constructors fill columns first; weight order is significant.
vec3 scattering_nn(float height, float cos_theta, float cos_light, float cos_gamma) {
    // input
    vec4 x0 = vec4(height, cos_theta, cos_light, cos_gamma);

    // layer 1: 4 -> 8 (two vec4 halves), sigmoid
    vec4 x1_0 = sigmoid(
        mat4x4(-0.0172, -0.1585, 0.0333, -0.0285, 22.3921, -0.2200, 1.4461, 25.3310, -0.0007, -0.0781, 0.1343, 0.0405, -0.0648, 0.0780, -0.3812, 0.0019)
        * x0 + vec4(-0.6036, -0.7347, -0.1642, -1.1600)
    );
    vec4 x1_1 = sigmoid(
        mat4x4(0.0037, -0.0006, -0.0144, 0.0025, -2.6395, 0.1363, 35.6077, 3.8818, -0.4070, -4.3260, 0.1487, 4.8999, 0.4393, -0.0526, -0.0515, -5.9698)
        * x0 + vec4(-1.2288, -2.2124, -0.0969, -6.3626)
    );

    // layer 2: 8 -> 8, sigmoid; each output half reads both input halves
    vec4 x2_0 = sigmoid(
        mat4x4(-3.6339, -1.7529, -160.8163, -1.2848, 2.3444, 4.6788, 0.3880, 0.5233, 0.1393, 0.4772, 0.6404, 1.2324, 1.9955, 3.0016, -96.2228, -1.2282) * x1_0
        + mat4x4(0.8563, -10.3158, 10.8031, 0.9464, 32.1428, 6.0045, -19.5529, 5.2667, 2.1930, 1.2001, 6.2481, -1.3785, -4.7430, 1.5145, 3.7958, -0.7921) * x1_1
        + vec4(-5.7566, -3.9390, -3.9847, -0.7349)
    );

    vec4 x2_1 = sigmoid(
        mat4x4(-7.8952, -3.5147, -2.8325, -3.5076, -0.4839, -1.4481, 3.8031, 42.6138, 5.0020, 4.6249, 1.0171, 1.1628, -1.0066, -6.9750, 4.4482, -1.1279) * x1_0
        + mat4x4(5.6382, 2.8548, -10.0346, -10.7135, 1.1769, -5.9692, -7.9183, -0.9696, 1.3309, 3.4965, 3.7415, 9.3012, -0.8607, -1.1860, -2.3595, -1.5907) * x1_1
        + vec4(-2.1066, -4.0293, -3.1043, -18.8468)
    );

    // layer 3: 8 -> 3 (RGB), exp activation keeps the output positive
    return exp(
        mat4x3(-10.5722, -14.9728, -15.4197, -5.1958, -5.2867, -5.5605, -0.9151, -0.9667, -1.0897, -2.0293, -3.0072, -7.4153) * x2_0
        + mat4x3(0.0178, 1.6596, 5.6463, -13.0418, -12.7041, -12.4718, -3.0943, -2.1226, -0.9017, -27.3160, -20.7745, -18.8013) * x2_1
        + vec3(-1.1564 ,-1.2119 ,-1.2709)
    );
}

// Intersects a ray with the top-of-atmosphere sphere (radius top_radius,
// centred at the origin).
// adapted from: https://iquilezles.org/articles/spherefunctions/
//   ray : ray origin
//   dir : ray direction; must be unit length (callers pass normalize(...))
// Returns (t_near, t_far) along dir, or vec2(-1.0) if the ray misses.
// t_near is negative when the origin is already inside the sphere.
vec2 atmo_bounds(vec3 ray, vec3 dir) {
    float half_b = dot(ray, dir);
    float disc = half_b * half_b - (dot(ray, ray) - top_radius * top_radius);
    if (disc < 0.0) {
        return vec2(-1.0);
    }
    float root = sqrt(disc);
    return vec2(-half_b - root, -half_b + root);
}

// Intersects a ray with the ground sphere (radius bottom_radius, centred at
// the origin), i.e. where the planet surface starts along the ray.
// adapted from: https://iquilezles.org/articles/spherefunctions/
//   ray : ray origin
//   dir : ray direction; must be unit length (callers pass normalize(...))
// Returns (t_near, t_far) along dir, or vec2(-1.0) if the ray misses.
vec2 planet_bounds(vec3 ray, vec3 dir) {
    float half_b = dot(ray, dir);
    float disc = half_b * half_b - (dot(ray, ray) - bottom_radius * bottom_radius);
    if (disc < 0.0) {
        return vec2(-1.0);
    }
    float root = sqrt(disc);
    return vec2(-half_b - root, -half_b + root);
}

// Transmittance pass: writes the network-predicted transmittance along this
// pixel's camera ray to fragcolor.xyz, clamped to [0, 1].
//   - ray misses the atmosphere -> 1.0 (no attenuation)
//   - ray does not hit the ground -> transmittance_nn at the atmosphere entry
//     point (or at the camera, if it is inside), with the negated view cosine
//   - ray hits the ground -> transmittance_nn at the ground point divided by
//     transmittance_nn at the entry point, both with the un-negated cosine
void transmittance() {
    // Planet centre is the origin; cam_pos is measured from the ground.
    vec3 origin = cam_pos + vec3(0.0, bottom_radius, 0.0);

    // Pinhole camera looking down +z, with aspect-corrected UV.
    float aspect = float(width) / float(height);
    vec3 dir = normalize(vec3((fragcoord - 0.5) * vec2(aspect, 1.0), 1.0));

    vec2 atmo_hit = atmo_bounds(origin, dir);
    vec2 ground_hit = planet_bounds(origin, dir);

    vec3 result;
    if (atmo_hit.y <= 0.0) {
        result = vec3(1.0);
    } else {
        // Clamping t to 0 keeps a camera already inside the atmosphere in place.
        vec3 entry = origin + dir * max(atmo_hit.x, 0.0);
        float entry_alt = length(entry) - bottom_radius;
        float entry_cos = dot(dir, normalize(entry));

        if (ground_hit.y <= 0.0) {
            result = transmittance_nn(entry_alt, -entry_cos);
        } else {
            // Measured from the original origin, not from the entry point.
            vec3 ground = origin + dir * max(ground_hit.x, 0.0);
            float ground_alt = length(ground) - bottom_radius;
            float ground_cos = dot(dir, normalize(ground));
            result = transmittance_nn(ground_alt, ground_cos) / transmittance_nn(entry_alt, entry_cos);
        }
    }

    fragcolor.xyz = clamp(result, vec3(0.0), vec3(1.0));
}

// Render pass: writes exposure-scaled in-scattered radiance along this pixel's
// camera ray to fragcolor.xyz, clamped to [0, 1].
//   - ray misses the atmosphere -> black
//   - ray does not hit the ground -> scattering_nn at the atmosphere entry
//     point (or the camera, if inside)
//   - ray hits the ground -> entry-point scattering minus ground-point
//     scattering scaled by the transmittance ratio between the two points
//     (presumably S(x) - T(x, x0) * S(x0), the in-scattering of the visible
//     segment only)
void render() {
    // Planet centre is the origin; cam_pos is measured from the ground.
    vec3 origin = cam_pos + vec3(0.0, bottom_radius, 0.0);

    // Pinhole camera looking down +z, with aspect-corrected UV.
    float aspect = float(width) / float(height);
    vec3 dir = normalize(vec3((fragcoord - 0.5) * vec2(aspect, 1.0), 1.0));

    vec2 atmo_hit = atmo_bounds(origin, dir);
    vec2 ground_hit = planet_bounds(origin, dir);

    vec3 radiance;
    if (atmo_hit.y <= 0.0) {
        radiance = vec3(0.0);
    } else {
        // View/sun angle does not depend on position, so compute it once.
        float cos_gamma = dot(dir, sun_dir);

        vec3 entry = origin + dir * max(atmo_hit.x, 0.0);
        vec3 entry_up = normalize(entry);
        float entry_alt = length(entry) - bottom_radius;
        float entry_cos = dot(dir, entry_up);
        float entry_sun = dot(sun_dir, entry_up);

        vec3 entry_scatter = scattering_nn(entry_alt, -entry_cos, entry_sun, cos_gamma) * solar_irradiance * exposure;

        if (ground_hit.y <= 0.0) {
            radiance = entry_scatter;
        } else {
            vec3 ground = origin + dir * max(ground_hit.x, 0.0);
            vec3 ground_up = normalize(ground);
            float ground_alt = length(ground) - bottom_radius;
            float ground_cos = dot(dir, ground_up);
            float ground_sun = dot(sun_dir, ground_up);

            vec3 t_ratio = transmittance_nn(ground_alt, ground_cos) / transmittance_nn(entry_alt, entry_cos);
            radiance = entry_scatter
                - scattering_nn(ground_alt, -ground_cos, ground_sun, cos_gamma) * solar_irradiance * exposure * t_ratio;
        }
    }

    fragcolor.xyz = clamp(radiance, vec3(0.0), vec3(1.0));
}
// <END>

// Entry point. The TRANSMITTANCE / RENDER preprocessor defines select which
// pass this shader variant runs.
// NOTE(review): if both are defined, both passes run and render() overwrites
// transmittance()'s output.
void main() {
    // Magenta debug fill (the original comment promised purple, but the code
    // wrote yellow). It stays visible only if no pass writes fragcolor.xyz,
    // i.e. neither define is set. Alpha is only ever set here (1.0).
    fragcolor = vec4(1.0, 0.0, 1.0, 1.0);
#ifdef TRANSMITTANCE
    transmittance();
#endif
#ifdef RENDER
    render();
#endif
}
