#version 460

// not every field below is needed here; this model only reads some of them
// Atmosphere parameters (set 0, binding 15).
// The per-layer fields (width, exp_term, exp_scale, linear_term, constant_term) appear to
// mirror the DensityProfileLayer of Bruneton's precomputed-scattering model, with two
// layers suffixed _a and _b — NOTE(review): confirm against the host-side struct layout.
// Only some fields are read in this file; the rest are presumably consumed elsewhere.
layout(set = 0, binding = 15) uniform atmosphere {    
    vec3 solar_irradiance;     // read: multiplies the final radiance in scatter()
    float sun_angular_radius;  // not read in this file
    float bottom_radius;       // read: ground sphere radius (spheres are centered at the origin)
    float top_radius;          // read: outer atmosphere sphere radius
    
    // rayleigh density (only ray_exp_scale_b is read in this file, by scaled_depth)
    float ray_width_a;
    float ray_exp_term_a;
    float ray_exp_scale_a;
    float ray_linear_term_a;
    float ray_constant_term_a;
    
    float ray_width_b;
    float ray_exp_term_b;
    float ray_exp_scale_b;
    float ray_linear_term_b;
    float ray_constant_term_b;
        
    // read: rayleigh scattering coefficient, also used as rayleigh extinction in scatter()
    vec3 ray_scattering;
    
    // mie density (only mie_exp_scale_b is read in this file, by scaled_depth)
    float mie_width_a;
    float mie_exp_term_a;
    float mie_exp_scale_a;
    float mie_linear_term_a;
    float mie_constant_term_a;
        
    float mie_width_b;
    float mie_exp_term_b;
    float mie_exp_scale_b;
    float mie_linear_term_b;
    float mie_constant_term_b;
    
    vec3 mie_scattering;  // read: scales the mie in-scattering in scatter()
    vec3 mie_extinction;  // read: mie term of the extinction matrix in scatter()
    float mie_g;          // read: asymmetry parameter of phase_mie()
    
    // absorption (ozone) density
    // (absorb_width_a, absorb_linear_term_a/b and absorb_constant_term_a/b are read by scaled_depth)
    float absorb_width_a;
    float absorb_exp_term_a;
    float absorb_exp_scale_a;
    float absorb_linear_term_a;
    float absorb_constant_term_a; 
    
    float absorb_width_b;
    float absorb_exp_term_b;
    float absorb_exp_scale_b;
    float absorb_linear_term_b;
    float absorb_constant_term_b;
        
    vec3 absorb_extinction;  // read: ozone term of the extinction matrix in scatter()
    vec3 ground_albedo;      // not read in this file
};

// Per-view parameters (set 0, binding 12).
layout(set = 0, binding = 12) uniform view {
    vec3 cam_pos; // camera position; main() adds bottom_radius in y before tracing
    vec3 sun_dir; // sun direction; main() normalizes it and scatter() traces towards it
    float exposure; // linear multiplier applied to the final radiance in main()
};

// Push constants.
layout(push_constant) uniform push {
    uint layer;   // not read in this file
    uint width;   // presumably the render target width; only width/height is used (aspect ratio in main())
    uint height;  // presumably the render target height; see width
};

// Interpolated screen position. main() treats (0.5, 0.5) as the image center, so this is
// presumably in [0, 1] — NOTE(review): confirm against the vertex shader.
layout(location = 0) in noperspective centroid vec2 fragcoord;
// Output: exposed radiance in rgb, alpha always 1.0 (written by main()).
layout(location = 0) out vec4 fragcolor;

// Adapted from https://github.com/wwwtyro/glsl-atmosphere/blob/master/index.glsl
// Unlicense

// <START>
// pi; an unsuffixed literal is single precision in GLSL, so this is stored as a float
const float PI = 3.14159265358979323846;
// number of integration segments along each view ray in scatter()
const uint STEPS = 5u; // 5 gives good enough quality

// Ray/sphere intersection against a sphere of the given radius centered at the origin.
// adapted from: https://iquilezles.org/articles/spherefunctions/
// ray:    ray origin
// dir:    ray direction, assumed normalized (the quadratic's leading coefficient is taken as 1)
// radius: sphere radius
// Returns the near (x) and far (y) distances along dir. Either may be negative when that
// intersection lies behind the origin. Returns vec2(-1.0) when the line misses the sphere.
vec2 sphere_bounds(vec3 ray, vec3 dir, float radius) {
    float b = dot(ray, dir);
    float c = dot(ray, ray) - radius * radius;
    float h = b * b - c;
    if (h < 0.0) return vec2(-1.0);
    float root = sqrt(h);
    return vec2(-b - root, -b + root);
}

// where does the atmosphere start and end? (outer sphere of radius top_radius)
vec2 atmo_bounds(vec3 ray, vec3 dir) {
    return sphere_bounds(ray, dir, top_radius);
}

// where does the planet start? (ground sphere of radius bottom_radius)
vec2 planet_bounds(vec3 ray, vec3 dir) {
    return sphere_bounds(ray, dir, bottom_radius);
}
// Phase functions: Bruneton's model uses the Cornette-Shanks phase function, so it is
// used here as well. For rayleigh the asymmetry g is 0, which reduces the normalization
// 3/(8*pi) * (1 - g^2) / (2 + g^2) to 3/(16*pi).
// cos_theta: cosine of the angle between the view and light directions
float phase_ray(float cos_theta) {
    float mu_sq = cos_theta * cos_theta;
    return (3.0 / (16.0 * PI)) * (1.0 + mu_sq);
}

// Cornette-Shanks phase function for mie scattering, with asymmetry g = mie_g.
// cos_theta: cosine of the angle between the view and light directions
float phase_mie(float cos_theta) {
    float g = mie_g;
    float g_sq = g * g;
    float norm = 3.0 / (8.0 * PI) * (1.0 - g_sq) / (2.0 + g_sq);
    float falloff = pow(1.0 + g_sq - 2.0 * g * cos_theta, 1.5);
    return norm * (1.0 + cos_theta * cos_theta) / falloff;
}

// optical depth, without the coefficients applied
// rayleigh in x, mie in y, ozone in z
// Cheap analytic approximation of the optical depth from `ray` along `dir` to where the
// path leaves the relevant sphere. It is called by opt().
// ray: sample position relative to the planet center (the origin of the sphere tests)
// dir: direction; the exit-distance maths assumes it is normalized (the callers pass
//      normalized view/light directions or their negations)
// Rayleigh/mie: exit distance through a sphere above the sample, multiplied by the
// exponential density factor at the sample's altitude.
// Ozone: difference between the exit distances of two heuristic spheres built from the
// ozone layer parameters.
vec3 scaled_depth(vec3 ray, vec3 dir) {
    // projection of the position onto the direction
    float b = dot(ray, dir);
    // squared distance to the planet center
    float c = dot(ray, ray);
    // distance to the planet center (a radius, not an altitude)
    float h = sqrt(c);

    // sphere sizes
    vec4 r = vec4(
        // ray and mie
        // NOTE(review): for a decaying profile (exp_scale_b < 0, e.g. -1/H) this is h + H,
        // i.e. one scale height above the sample; never smaller than the ground radius
        max(h - 1.0 / vec2(ray_exp_scale_b, mie_exp_scale_b), bottom_radius),
        // ozone, no defined width of the ozone layer, so this is the best next thing
        // height + const / linear gives how far it extends outside the middle of the ozone layer
        // z uses the _b layer's constant/linear ratio, w the _a layer's; both are at least h
        // NOTE(review): a zero absorb_linear_term_* makes these divisions inf/NaN
        max(h, bottom_radius + 1.5 * absorb_width_a + 0.5 * absorb_constant_term_b / absorb_linear_term_b),
        max(h, bottom_radius + 1.5 * absorb_width_a + 0.5 * absorb_constant_term_a / absorb_linear_term_a)
    );

    // scales: exponential density factor at the sample's altitude above the ground.
    // The altitude is clamped at 0, so points below the ground use the ground-level factor
    // exp(0) = 1 instead of growing without bound.
    vec2 s = exp(max(0.0, h - bottom_radius) * vec2(ray_exp_scale_b, mie_exp_scale_b));

    // half-chord term d = sqrt(b^2 - (c - r^2)) per sphere; clamped at 0 so a negative
    // discriminant never produces NaN. -b + d is the distance to the far intersection.
    vec4 d = sqrt(max(b*b + r*r - c, 0.0));

    // scaled optical depths, as length inside a sphere
    // x,y: density factor * exit distance; z: d.w - d.z (the -b terms cancel)
    return vec3(s * (d.xy - b), d.w - d.z);
}

// optical depth
// Coefficient-free optical depth from `ray` along `dir` through the whole atmosphere
// (rayleigh in x, mie in y, ozone in z), built from scaled_depth() half-chords.
vec3 opt(vec3 ray, vec3 dir) {
    // signed distance of the origin past the point of the line closest to the planet center
    float along = dot(ray, dir);

    // moving away from the center (looking up): a single half-chord covers the path
    if (along > 0.0) {
        return scaled_depth(ray, dir);
    }

    // moving towards the center (looking down): take both half-chords through the closest
    // point, then remove the part that lies behind the viewer
    vec3 closest = ray - dir * along;
    vec3 full_chord = 2.0 * scaled_depth(closest, dir);
    vec3 behind = scaled_depth(ray, -dir);
    return full_chord - behind;
}

// attenuation integral
// Average of exp(-x) for x between a and b, per component:
//   (exp(-a) - exp(-b)) / (b - a), or exp(-a) when a == b.
// scatter() uses it as the mean transmittance over a segment whose optical depth is taken
// to vary linearly from a to b.
// The difference is rewritten as exp(-lo) * (1 - exp(-d)) / d with lo = min(a, b) and
// d = |b - a|. This form is symmetric in a and b, exp(-d) <= 1 cannot overflow, and a
// Taylor series replaces the cancellation-prone quotient when a and b are nearly equal,
// not only when they are exactly equal.
vec3 attenuate(vec3 a, vec3 b) {
    vec3 lo = min(a, b);
    vec3 d = abs(b - a);

    // (1 - exp(-d)) / d ~= 1 - d/2 + d^2/6 for small d; truncation error ~d^3/24
    vec3 series = 1.0 - d * (0.5 - d / 6.0);
    vec3 direct = (1.0 - exp(-d)) / d;

    // mix with a bool vector selects per component; the unselected 0/0 at d == 0 has no effect
    vec3 factor = mix(direct, series, lessThan(d, vec3(1e-2)));
    return exp(-lo) * factor;
}

// trace a single ray
// Single-scattering radiance seen along one view ray, integrated over STEPS segments.
// start: ray origin relative to the planet center (the origin used by the sphere tests)
// dir:   view direction; main() passes a normalized vector
// light: direction towards the sun; main() passes a normalized vector
// Returns solar_irradiance times the rayleigh + mie in-scattered light, or vec3(0.0)
// when no part of the atmosphere path lies ahead of the origin.
vec3 scatter(vec3 start, vec3 dir, vec3 light) {
    // intersect each sphere once and reuse the results
    vec2 atmo = atmo_bounds(start, dir);
    vec2 planet = planet_bounds(start, dir);

    // how far to travel in the atmosphere: from where the ray enters it, to the ground if
    // the planet lies ahead, otherwise to where the ray leaves the atmosphere
    float begin = atmo.x;
    float end = planet.y > 0.0 ? planet.x : atmo.y;

    // stop if nothing lies ahead: the ray misses the atmosphere, the atmosphere is behind
    // the origin, or the origin is inside the planet
    if (end < 0.0) return vec3(0.0);

    // if the origin is outside the atmosphere, move it to the atmosphere entry point
    float entry = max(0.0, begin);
    start += dir * entry;

    // step size
    float dt = (end - entry) / float(STEPS);

    // total (coefficient-free) amount of scattering
    vec3 scatter_ray = vec3(0.0);
    vec3 scatter_mie = vec3(0.0);

    // total optical depth from the start, assuming the ray never ends
    vec3 opt_total = opt(start, dir);

    // optical depth at the previous step
    vec3 opt_prev = opt_total;

    // optical depth to light at the previous step
    vec3 opt_light_prev = opt(start, light);

    // maps (rayleigh, mie, ozone) optical depth to extinction;
    // the rayleigh extinction is taken to equal its scattering coefficient
    mat3 extinct = mat3(ray_scattering, mie_extinction, absorb_extinction);

    for (uint i = 1u; i <= STEPS; ++i) {
        // current position, at the end of the segment
        vec3 pos = start + dir * float(i) * dt;

        // optical depth from here
        vec3 opt_step = opt(pos, dir);

        // optical depth to light from here
        vec3 opt_light = opt(pos, light);

        // segment optical depth
        vec3 opt_segment = opt_prev - opt_step;

        // integrate
        vec3 attn = attenuate(
            // begin is optical depth from before the segment + from the light at the start
            extinct * (opt_total - opt_prev + opt_light_prev),
            // end is the optical depth from before + the segment + from the light at the end
            extinct * (opt_total - opt_step + opt_light)
        );

        // scatter
        scatter_ray += opt_segment.x * attn;
        scatter_mie += opt_segment.y * attn;

        // update
        opt_light_prev = opt_light;
        opt_prev = opt_step;
    }

    // apply coefficients and phase functions, then scale by the incoming sunlight
    float cos_gamma = dot(dir, light);
    return solar_irradiance * (
        scatter_ray * ray_scattering * phase_ray(cos_gamma)
        + scatter_mie * mie_scattering * phase_mie(cos_gamma)
    );
}

// Fragment entry point: builds a view ray for this fragment, traces it through the
// atmosphere and writes the exposed radiance.
void main() {
    // the spheres are centered at the origin, so raise the camera by the ground radius in y
    vec3 origin = cam_pos + vec3(0.0, bottom_radius, 0.0);

    // centered, aspect-corrected point on an image plane at z = 1
    float aspect = float(width) / float(height);
    vec2 plane = (fragcoord - vec2(0.5)) * vec2(aspect, 1.0);
    vec3 view_dir = normalize(vec3(plane, 1.0));

    vec3 light_dir = normalize(sun_dir);
    vec3 radiance = scatter(origin, view_dir, light_dir);

    // linear exposure only; alpha is always opaque
    fragcolor = vec4(exposure * radiance, 1.0);
}
// <END>
