#version 460

// Atmosphere parameters: planet/atmosphere radii, per-species density profiles,
// and scattering/extinction coefficients. Member order and types define the
// buffer layout, so the host-side struct must match it exactly.
// The planet and atmosphere are spheres centered at the origin (see
// atmo_bounds / planet_bounds). Distances are presumably in the same units as
// cam_pos — TODO confirm against the host code.
// Of the density-profile fields, this shader only reads ray_exp_scale_b,
// mie_exp_scale_b, absorb_width_a and the absorb linear/constant terms.
// solar_irradiance, sun_angular_radius, mie_scattering, mie_g and
// ground_albedo are not read by this shader either.
layout(set = 0, binding = 15) uniform atmosphere {    
    vec3 solar_irradiance;
    float sun_angular_radius;
    float bottom_radius; // planet radius: the ground sphere
    float top_radius;    // radius of the atmosphere's outer boundary
    
    // rayleigh density profile (two layers, a and b).
    // Only ray_exp_scale_b is used here:
    // density = clamp(exp(ray_exp_scale_b * altitude), 0, 1).
    float ray_width_a;
    float ray_exp_term_a;
    float ray_exp_scale_a;
    float ray_linear_term_a;
    float ray_constant_term_a;
    
    float ray_width_b;
    float ray_exp_term_b;
    float ray_exp_scale_b;
    float ray_linear_term_b;
    float ray_constant_term_b;
        
    // rayleigh scattering coefficient; also used as rayleigh extinction here
    vec3 ray_scattering;
    
    // mie density profile (two layers, a and b).
    // Only mie_exp_scale_b is used here:
    // density = clamp(exp(mie_exp_scale_b * altitude), 0, 1).
    float mie_width_a;
    float mie_exp_term_a;
    float mie_exp_scale_a;
    float mie_linear_term_a;
    float mie_constant_term_a;
        
    float mie_width_b;
    float mie_exp_term_b;
    float mie_exp_scale_b;
    float mie_linear_term_b;
    float mie_constant_term_b;
    
    vec3 mie_scattering;
    vec3 mie_extinction; // scaled by the mie density when computing extinction
    float mie_g;
    
    // absorption (ozone) density profile: piecewise linear in altitude, using
    // layer a below absorb_width_a and layer b above it, clamped to [0, 1].
    // The exp terms and absorb_width_b are not read by this shader.
    float absorb_width_a;
    float absorb_exp_term_a;
    float absorb_exp_scale_a;
    float absorb_linear_term_a;
    float absorb_constant_term_a; 
    
    float absorb_width_b;
    float absorb_exp_term_b;
    float absorb_exp_scale_b;
    float absorb_linear_term_b;
    float absorb_constant_term_b;
        
    vec3 absorb_extinction; // scaled by the absorption density when computing extinction
    vec3 ground_albedo;
};

// Per-view data.
layout(set = 0, binding = 12) uniform view {
    // camera position, relative to the ground point (0, bottom_radius, 0):
    // main() adds that offset before tracing
    vec3 cam_pos;
};

// Push constants for this pass.
layout(push_constant) uniform push {
    uint layer;  // not read by this shader
    uint width;  // render target width; only used for the aspect ratio in main()
    uint height; // render target height; only used for the aspect ratio in main()
};

// Screen position of this fragment. main() centers it by subtracting 0.5, so it
// presumably spans [0, 1] across the target — TODO confirm against the vertex stage.
layout(location = 0) in noperspective centroid vec2 fragcoord;
// RGB = transmittance along the view ray, A = 1.
layout(location = 0) out vec4 fragcolor;

const float PI = 3.14159265358979323846; // not used in this shader
const uint STEPS = 128u; // raymarch samples per ray; more for better accuracy

// Intersects the ray `ray + t * dir` with the atmosphere's outer sphere
// (radius top_radius, centered at the origin).
// Returns (t_near, t_far), or vec2(-1.0) when the ray misses the sphere.
// `dir` must be normalized: the quadratic's leading coefficient is taken as 1.
// adapted from: https://iquilezles.org/articles/spherefunctions/
vec2 atmo_bounds(vec3 ray, vec3 dir) {
    float half_b = dot(ray, dir);
    float r2 = top_radius * top_radius;
    float disc = half_b * half_b - (dot(ray, ray) - r2);

    // no real roots: the ray never touches the sphere
    if (disc < 0.0) {
        return vec2(-1.0);
    }

    float root = sqrt(disc);
    return vec2(-half_b - root, -half_b + root);
}

// Where does the planet start?
// Intersects the ray `ray + t * dir` with the ground sphere
// (radius bottom_radius, centered at the origin).
// Returns (t_near, t_far), or vec2(-1.0) when the ray misses the planet.
// `dir` must be normalized: the quadratic's leading coefficient is taken as 1.
// adapted from: https://iquilezles.org/articles/spherefunctions/
vec2 planet_bounds(vec3 ray, vec3 dir) {
    float half_b = dot(ray, dir);
    float r2 = bottom_radius * bottom_radius;
    float disc = half_b * half_b - (dot(ray, ray) - r2);

    // no real roots: the ray never touches the planet
    if (disc < 0.0) {
        return vec2(-1.0);
    }

    float root = sqrt(disc);
    return vec2(-half_b - root, -half_b + root);
}

// Total extinction coefficient (rayleigh + mie + ozone) at world position `pos`.
// Despite the name, this is not just the absorption density: it returns each
// species' density multiplied by its extinction coefficient, summed.
// Altitude is measured from the ground sphere (radius bottom_radius).
// NOTE(review): only a subset of the uniform density-profile fields is used.
// Rayleigh and mie use only the "b" exponential scale (exp_term, linear and
// constant terms are ignored), and ozone uses only the linear/constant terms.
vec3 absorb_density(vec3 pos) {
    float altitude = length(pos) - bottom_radius;

    // exponential falloff, clamped so density never exceeds 1
    float rayleigh = clamp(exp(ray_exp_scale_b * altitude), 0.0, 1.0);
    float mie = clamp(exp(mie_exp_scale_b * altitude), 0.0, 1.0);

    // ozone: one linear layer below absorb_width_a, another above it
    float ozone;
    if (altitude < absorb_width_a) {
        ozone = altitude * absorb_linear_term_a + absorb_constant_term_a;
    } else {
        ozone = altitude * absorb_linear_term_b + absorb_constant_term_b;
    }
    ozone = clamp(ozone, 0.0, 1.0);

    vec3 extinction = rayleigh * ray_scattering
        + mie * mie_extinction
        + ozone * absorb_extinction;
    return extinction;
}

// Transmittance only (no in-scattering): the per-channel fraction of light that
// survives along the ray `start + t * dir`.
// The ray runs from t = 0 to the first planet hit, or, when the planet is not
// ahead, to where the ray leaves the atmosphere.
// `dir` must be normalized (the sphere intersection helpers assume it).
// Returns vec3(1.0) when there is nothing to march through (see `end < 0` below).
vec3 transmittance(
    vec3 start,
    vec3 dir
) {
    // how far to travel in the atmosphere: to the planet if the ray's far
    // intersection with it is ahead, otherwise to the atmosphere exit
    vec2 planet = planet_bounds(start, dir);
    float end = planet.y > 0.0
        ? planet.x
        : atmo_bounds(start, dir).y;

    // negative end: the ray misses or points away from the atmosphere, or
    // `start` lies inside the planet (planet.x < 0 < planet.y)
    // NOTE(review): when start is exactly on the ground, planet.y for upward
    // rays is ~0 and its sign depends on rounding.
    if (end < 0.0) return vec3(1.0);

    // we won't use a variable sample count: STEPS equal-length segments
    float dt = end / float(STEPS);

    // relative position of the sample inside each segment (0 = start, 1 = end)
    const float sample_t = 0.3;

    // raymarch: one sample per segment, each weighted by the full segment
    // length so that the weights add up to exactly `end`
    vec3 optical_depth = vec3(0.0);
    for (uint i = 0u; i < STEPS; i++) {
        float t = (float(i) + sample_t) * dt;
        vec3 pos = start + dir * t;
        optical_depth += absorb_density(pos) * dt;
    }

    // Beer-Lambert law; exp(-sum) equals the product of per-segment exp(-depth)
    return exp(-optical_depth);
}

// Render the final image: build a pinhole-camera ray for this fragment and
// output the atmospheric transmittance along it (alpha is always 1).
void main() {
    // view direction: centered screen coordinate, x scaled by the aspect ratio
    float aspect = float(width) / float(height);
    vec2 centered = fragcoord - 0.5;
    vec3 ray_dir = normalize(vec3(centered.x * aspect, centered.y, 1.0));

    // cam_pos is relative to the ground point (0, bottom_radius, 0)
    vec3 origin = cam_pos + vec3(0.0, bottom_radius, 0.0);

    // camera outside the atmosphere: begin marching at the entry point instead
    float entry = atmo_bounds(origin, ray_dir).x;
    if (entry >= 0.0) {
        origin += ray_dir * entry;
    }

    fragcolor = vec4(transmittance(origin, ray_dir), 1.0);
}

