#version 460

// Atmosphere parameters. Not every member is read by this approximate model; the unused
// ones are noted below.
// NOTE(review): the density groups (*_width, *_exp_term, *_exp_scale, *_linear_term,
// *_constant_term) with suffixes _a / _b look like the two layers of Bruneton's
// DensityProfile — confirm against the host-side struct.
// NOTE(review): there is no explicit layout qualifier, so the Vulkan default (std140) applies;
// vec3 members are 16-byte aligned under std140 and the host struct must match that packing.
layout(set = 0, binding = 15) uniform atmosphere {    
    vec3 solar_irradiance;    // multiplies the scattered radiance in scatter()
    float sun_angular_radius; // not read in the visible code
    float bottom_radius;      // ground sphere radius; the planet is centered at the origin
    float top_radius;         // not read in the visible code
    
    // rayleigh density; only ray_exp_scale_b is read (in scaled_depth)
    float ray_width_a;
    float ray_exp_term_a;
    float ray_exp_scale_a;
    float ray_linear_term_a;
    float ray_constant_term_a;
    
    float ray_width_b;
    float ray_exp_term_b;
    // scaled_depth uses exp(height * ray_exp_scale_b) as the density and 1.0 / ray_exp_scale_b
    // as a subtracted height offset, so this is presumably negative (-1 / scale height) — confirm
    float ray_exp_scale_b;
    float ray_linear_term_b;
    float ray_constant_term_b;
        
    // rayleigh scattering coefficients; also used as rayleigh extinction (no absorption term)
    vec3 ray_scattering;
    
    // mie density; only mie_exp_scale_b is read (in scaled_depth)
    float mie_width_a;
    float mie_exp_term_a;
    float mie_exp_scale_a;
    float mie_linear_term_a;
    float mie_constant_term_a;
        
    float mie_width_b;
    float mie_exp_term_b;
    // used like ray_exp_scale_b; presumably negative (-1 / scale height) — confirm
    float mie_exp_scale_b;
    float mie_linear_term_b;
    float mie_constant_term_b;
    
    vec3 mie_scattering;  // scattering coefficients, used for the in-scattered light
    vec3 mie_extinction;  // extinction coefficients, used for attenuation
    float mie_g;          // Cornette-Shanks asymmetry parameter in phase_mie()
    
    // absorption (ozone) density; scaled_depth only reads absorb_width_a and the
    // constant / linear term ratios of both layers, to place a fixed ozone shell
    float absorb_width_a;
    float absorb_exp_term_a;
    float absorb_exp_scale_a;
    float absorb_linear_term_a;
    float absorb_constant_term_a; 
    
    float absorb_width_b;
    float absorb_exp_term_b;
    float absorb_exp_scale_b;
    float absorb_linear_term_b;
    float absorb_constant_term_b;
        
    vec3 absorb_extinction; // ozone extinction coefficients
    vec3 ground_albedo;     // not read in the visible code
};

// Per-view parameters.
layout(set = 0, binding = 12) uniform view {
    vec3 cam_pos; // camera position, relative to the ground point (0, bottom_radius, 0) on top of the planet
    vec3 sun_dir; // direction pointing towards the sun; normalized in render() before use
    float exposure; // linear scale applied to the scattered radiance in render()
};

// Push constants.
layout(push_constant) uniform push {
    uint layer;  // not read in the visible code
    uint width;  // presumably the render target size in pixels; only width / height (aspect) is used
    uint height;
};

// Interpolated screen coordinate. The entry points center it with (fragcoord - 0.5), so it is
// presumably in [0, 1] — confirm against the vertex shader.
layout(location = 0) in noperspective centroid vec2 fragcoord;
// Output color: main() writes the full vec4 first, then the entry points overwrite .xyz.
layout(location = 0) out vec4 fragcolor;

const float PI = 3.14159265358979323846;

// <START>
// Intersect the line ray + t * dir with the ground sphere (radius bottom_radius, centered at
// the origin). The quadratic assumes dir is unit length.
// Returns (t_near, t_far), or vec2(-1.0) when the line misses the sphere. A negative t_near
// also means the near hit is behind the origin (or the origin is inside the planet).
// adapted from: https://iquilezles.org/articles/spherefunctions/
vec2 planet_bounds(vec3 ray, vec3 dir) {
    float half_b = dot(ray, dir);
    float disc = half_b * half_b - (dot(ray, ray) - bottom_radius * bottom_radius);

    // no real roots: the line never touches the planet
    if (disc < 0.0) {
        return vec2(-1.0);
    }

    float root = sqrt(disc);
    return vec2(-half_b - root, -half_b + root);
}

// Rayleigh phase function. Bruneton's model uses the Cornette-Shanks phase function, and with
// g = 0 that reduces to 3 / (16 pi) * (1 + cos^2 theta), which is what is evaluated here.
float phase_ray(float cos_theta) {
    const float norm = 3.0 / (16.0 * PI);
    return (1.0 + cos_theta * cos_theta) * norm;
}

// Mie phase function: Cornette-Shanks with asymmetry g = mie_g,
//   3 / (8 pi) * (1 - g^2) / (2 + g^2) * (1 + cos^2 theta) / (1 + g^2 - 2 g cos theta)^1.5
float phase_mie(float cos_theta) {
    float g2 = mie_g * mie_g;
    float norm = 3.0 / (8.0 * PI) * (1.0 - g2) / (2.0 + g2);
    float falloff = pow(1.0 + g2 - 2.0 * mie_g * cos_theta, 1.5);
    return norm * (1.0 + cos_theta * cos_theta) / falloff;
}

// Approximate optical depth from `ray` (a position relative to the planet center) along the
// unit direction `dir`, before the extinction coefficients are applied:
// rayleigh in x, mie in y, ozone in z.
// Callers in this file only pass dot(ray, dir) >= 0 (heading away from the center);
// optical_depth() handles the other case.
//
// rayleigh / mie: density at the sample height, exp(height * exp_scale), times the distance to
//   leave a sphere of radius (h - 1 / exp_scale), but at least bottom_radius. Assuming
//   exp_scale = -1 / H < 0, that sphere lies one scale height H above the sample point.
// ozone: length of the ray between two shells built from absorb_width_a and the
//   constant / linear ratios of both ozone layers, with unit density (no height scaling).
vec3 scaled_depth(vec3 ray, vec3 dir) {
    float along = dot(ray, dir);
    float dist_sq = dot(ray, ray);
    float height = sqrt(dist_sq);

    // bounding sphere for the exponential rayleigh / mie densities
    vec2 scale_height = 1.0 / vec2(ray_exp_scale_b, mie_exp_scale_b);
    vec2 haze_radius = max(height - scale_height, bottom_radius);

    // no explicit thickness is given for the ozone layer, so the shell is derived from the
    // layer width plus half of each layer's constant / linear ratio; never below the sample point
    float ozone_base = bottom_radius + 1.5 * absorb_width_a;
    float ozone_lower = max(height, ozone_base + 0.5 * absorb_constant_term_b / absorb_linear_term_b);
    float ozone_upper = max(height, ozone_base + 0.5 * absorb_constant_term_a / absorb_linear_term_a);

    vec4 radius = vec4(haze_radius, ozone_lower, ozone_upper);

    // half chord from the closest-approach point to each sphere; clamped so sqrt never sees < 0
    vec4 half_chord = sqrt(max(along * along + radius * radius - dist_sq, 0.0));

    // density at the sample height; below the ground it is clamped to the ground value
    vec2 density = exp(max(0.0, height - bottom_radius) * vec2(ray_exp_scale_b, mie_exp_scale_b));

    // haze: density * distance to the exit point; ozone: distance spent inside the shell
    return vec3(density * (half_chord.xy - along), half_chord.w - half_chord.z);
}

// Approximate optical depth (unscaled; rayleigh in x, mie in y, ozone in z) from `ray` along
// the unit direction `dir`. The ground is not tested, so the ray passes through the planet.
vec3 optical_depth(vec3 ray, vec3 dir) {
    // dot(ray, dir) is minus the ray parameter of the closest approach to the planet center
    float along = dot(ray, dir);

    // heading away from the center: scaled_depth covers this case directly
    if (along > 0.0) {
        return scaled_depth(ray, dir);
    }

    // heading towards the center: the closest-approach point splits the line into two equal
    // halves, so twice its forward depth is the whole line; remove the part behind the viewer
    vec3 closest = ray - dir * along;
    vec3 whole_line = scaled_depth(closest, dir) * 2.0;
    return whole_line - scaled_depth(ray, -dir);
}

// attenuation integral: mean transmittance along a path whose total optical depth varies
// linearly from a to b, evaluated per component:
//   integral_0^1 exp(-mix(a, b, t)) dt = (exp(-a) - exp(-b)) / (b - a),   -> exp(-a) as b -> a
// The expression is symmetric in a and b. Evaluating it directly loses nearly all precision when
// a and b are close but not equal (catastrophic cancellation), so it is rewritten as
//   exp(-min(a, b)) * (1 - exp(-d)) / d,   d = |b - a|
// with a Taylor series for small d. Since d >= 0, exp(-d) stays in (0, 1] and cannot overflow.
vec3 attenuate(vec3 a, vec3 b) {
    vec3 lo = min(a, b);
    vec3 delta = abs(b - a);

    // (1 - exp(-d)) / d: series 1 - d/2 + d^2/6 for small d (truncation error ~d^3/24), direct
    // form elsewhere; the max() only keeps the unselected direct branch finite at d == 0
    vec3 series = 1.0 - delta * (0.5 - delta * (1.0 / 6.0));
    vec3 direct = (1.0 - exp(-delta)) / max(delta, vec3(1e-30));
    vec3 mean = mix(direct, series, lessThan(delta, vec3(1e-2)));

    return exp(-lo) * mean;
}

// Approximate single-scattered sunlight reaching `ray` along the unit view direction `dir`,
// for light arriving from the unit direction `light` (pointing towards the sun).
// `depth` is the distance to the ground hit. A negative value means no hit, and the view ray
// then runs to infinity, where both end depths are taken as zero.
// Result = solar_irradiance * mean transmittance of the sun -> point -> camera path
//          * scattering optical depth of the view segment * phase function.
// NOTE: the light paths use optical_depth(), which ignores the ground, so the planet's shadow
// is not modeled.
vec3 scatter(vec3 ray, vec3 dir, vec3 light, float depth) {
    // per-component coefficients -> per-wavelength extinction in one matrix multiply
    mat3 extinct = mat3(ray_scattering, mie_extinction, absorb_extinction);

    // depths from the camera
    vec3 view_near = optical_depth(ray, dir);
    vec3 light_near = optical_depth(ray, light);

    // depths from the far end of the view segment
    vec3 view_far;
    vec3 light_far;
    if (depth < 0.0) {
        view_far = vec3(0.0);
        light_far = vec3(0.0);
    } else {
        vec3 ground = ray + dir * depth;
        view_far = optical_depth(ground, dir);
        light_far = optical_depth(ground, light);
    }

    // total sun -> point -> camera depth is light_near at the camera and
    // light_far + (view segment) at the far end; average the transmittance in between
    vec3 attn = attenuate(
        extinct * light_near,
        extinct * (light_far + view_near - view_far)
    );

    vec3 view_segment = view_near - view_far;
    float cos_gamma = dot(dir, light);

    vec3 ray_term = attn * view_segment.x * ray_scattering * phase_ray(cos_gamma);
    vec3 mie_term = attn * view_segment.y * mie_scattering * phase_mie(cos_gamma);
    return solar_irradiance * (ray_term + mie_term);
}

// TRANSMITTANCE entry point: writes to fragcolor.xyz the transmittance exp(-extinction * depth)
// along this pixel's view ray. The path runs to infinity, or stops at the ground if the ray
// hits it in front of the camera.
void transmittance() {
    // planet center is the origin; cam_pos is relative to the ground point (0, bottom_radius, 0)
    vec3 origin = cam_pos + vec3(0.0, bottom_radius, 0.0);

    // pinhole ray through this pixel, corrected for the aspect ratio
    float aspect = float(width) / float(height);
    vec2 uv = (fragcoord - 0.5) * vec2(aspect, 1.0);
    vec3 dir = normalize(vec3(uv, 1.0));

    vec2 ground = planet_bounds(origin, dir);
    mat3 extinct = mat3(ray_scattering, mie_extinction, absorb_extinction);

    // depth to infinity, minus whatever lies beyond the ground hit
    vec3 depth = optical_depth(origin, dir);
    bool misses_ground = ground.x < 0.0;
    if (!misses_ground) {
        depth = depth - optical_depth(origin + ground.x * dir, dir);
    }

    fragcolor.xyz = exp(extinct * -depth);
}

// RENDER entry point: writes to fragcolor.xyz the exposure-scaled single-scattered sunlight
// along this pixel's view ray (see scatter()).
void render() {
    // planet center is the origin; cam_pos is relative to the ground point (0, bottom_radius, 0)
    vec3 origin = cam_pos + vec3(0.0, bottom_radius, 0.0);

    // pinhole ray through this pixel, corrected for the aspect ratio
    float aspect = float(width) / float(height);
    vec3 dir = normalize(vec3((fragcoord - 0.5) * vec2(aspect, 1.0), 1.0));

    // near ground hit distance, negative when there is none in front of the camera
    float ground_distance = planet_bounds(origin, dir).x;

    vec3 radiance = scatter(origin, dir, normalize(sun_dir), ground_distance);
    fragcolor.xyz = radiance * exposure;
}
// <END>

// main, decides the entrypoint: compile with TRANSMITTANCE and/or RENDER defined.
// If both are defined, both run and render() writes last, so its color wins.
void main() {
    // magenta sentinel, visible when no entry point is compiled in (the previous value,
    // (1, 1, 0), was yellow despite the comment saying purple). The entry points only write
    // .xyz, so this also provides the output alpha of 1.0.
    fragcolor = vec4(1.0, 0.0, 1.0, 1.0);
#ifdef TRANSMITTANCE
    transmittance();
#endif
#ifdef RENDER
    render();
#endif
}
