// shadertoy version of the path tracer, published at https://www.shadertoy.com/view/Wfd3D8

// Irradiance of the sun per color channel; integrate() multiplies it into every
// next-event (sun) contribution.
const vec3 solar_irradiance = vec3(4.0);
// Angular radius of the sun disc (presumably radians). Not referenced by the visible
// code, which treats the sun as the single direction sun_dir.
const float sun_angular_radius = 0.004675;
// Radii of the ground sphere and the top-of-atmosphere sphere, both centered at the
// origin (see the dot(ray, ray) - r * r sphere tests). Altitude everywhere is
// length(pos) - bottom_radius. Units presumably km — TODO confirm.
const float bottom_radius = 6360.0;
const float top_radius = 6460.0;

// rayleigh density
// Two-layer density profile parameters (layer a below ray_width_a, layer b above it).
// In the visible code only ray_exp_scale_b is read: the density is
// clamp(exp(ray_exp_scale_b * altitude), 0, 1), i.e. it falls by a factor e every
// 1 / 0.125 = 8 altitude units. The remaining terms are kept for reference only.
const float ray_width_a = 0.0;
const float ray_exp_term_a = 0.0;
const float ray_exp_scale_a = 0.0;
const float ray_linear_term_a = 0.0;
const float ray_constant_term_a = 0.0;

const float ray_width_b = 0.0;
const float ray_exp_term_b = 1.0;
const float ray_exp_scale_b = -0.125;
const float ray_linear_term_b = 0.0;
const float ray_constant_term_b = 0.0;
 
// Rayleigh scattering coefficient at density 1, per channel (r, g, b). Rayleigh is
// treated as non-absorbing, so extinct_density reuses it as the Rayleigh extinction.
const vec3 ray_scattering = vec3(0.005802, 0.013558, 0.033100);

// mie density
// Two-layer density profile parameters (layer a below mie_width_a, layer b above it).
// In the visible code only mie_exp_scale_b is read: the density is
// clamp(exp(mie_exp_scale_b * altitude), 0, 1), i.e. it falls by a factor e every
// 1 / 0.833333 = 1.2 altitude units. The remaining terms are kept for reference only.
const float mie_width_a = 0.0;
const float mie_exp_term_a = 0.0;
const float mie_exp_scale_a = 0.0;
const float mie_linear_term_a = 0.0;
const float mie_constant_term_a = 0.0;

const float mie_width_b = 0.0;
const float mie_exp_term_b = 1.0;
const float mie_exp_scale_b = -0.833333;
const float mie_linear_term_b = 0.0;
const float mie_constant_term_b = 0.0;

// Mie scattering and extinction coefficients at density 1 (extinction > scattering,
// the difference being Mie absorption).
const vec3 mie_scattering = vec3(0.003996);
const vec3 mie_extinction = vec3(0.004440);
// Asymmetry parameter of the Cornette-Shanks phase function in phase_mie.
const float mie_g = 0.8;

// absorption (ozone) density
// Piecewise-linear profile evaluated in extinct_density and clamped there to [0, 1]:
// layer a (linear_term_a * h + constant_term_a) applies below absorb_width_a,
// layer b (linear_term_b * h + constant_term_b) applies above it.
// With these values the density ramps 0 -> 1 between h = 10 and h = 25,
// then back down 1 -> 0 between h = 25 and h = 40.
const float absorb_width_a = 25.0;
const float absorb_exp_term_a = 0.0;
const float absorb_exp_scale_a = 0.0;
const float absorb_linear_term_a = 0.066667;
const float absorb_constant_term_a = -0.666667;

const float absorb_width_b = 0.0;
const float absorb_exp_term_b = 0.0;
const float absorb_exp_scale_b = 0.0;
// -1/15, the mirror of layer a's +1/15 slope, so both layers meet at 1.0 at h = 25.
// (A value of -0.66667 would make layer b negative at h = 25 and remove all ozone above it.)
const float absorb_linear_term_b = -0.066667;
const float absorb_constant_term_b = 2.666667;

// Ozone absorption coefficient at density 1, per channel (r, g, b).
const vec3 absorb_extinction = vec3(0.000650, 0.001881, 0.000085);
// Ground reflectance; not currently read by integrate(), which ends paths at the ground.
const vec3 ground_albedo = vec3(0.0);

// Camera position relative to the planet surface; mainImage adds bottom_radius to y.
const vec3 cam_pos = vec3(0.0, 5.0, 0.0); // camera position
// Unit direction pointing toward the sun; shadow rays are traced along it.
const vec3 sun_dir = normalize(vec3(0.0, 0.1, 1.0)); // sun direction
// Maximum number of scattering events per path (converted with uint() in integrate).
const float depth = 5.0; // number of bounces

const float PI = 3.14159265358979323846;

// maximum number of tries before exiting a loop
// (caps the delta-tracking and ratio-tracking loops)
const uint MAX_STEPS = 8192u;

// Index (0 = red, 1 = green, 2 = blue) of the color channel currently being traced;
// set by mainImage before each integrate() call and used to pick per-channel coefficients.
uint channel = 0u;

// Number of frames already accumulated for this pixel (read back from the alpha channel
// in mainImage; 0 means restart the running average).
uint layer = 0u; // init later

// One mixing round of pcg4d: each lane accumulates the product of two other lanes,
// seeing the lanes already updated earlier in the same round.
void pcg4d_round(inout uvec4 s) {
    s.x += s.y * s.w;
    s.y += s.z * s.x;
    s.z += s.x * s.y;
    s.w += s.y * s.z;
}

// 4D PCG-style hash: LCG step on every lane, mixing round, xor-shift by 16, mixing round.
// The hashed value is written back into `v` (so repeated calls walk a sequence) and returned.
// pcg, see https://www.pcg-random.org/
uvec4 pcg4d(inout uvec4 v) {
    v = v * 1664525u + 1013904223u;
    pcg4d_round(v);
    v ^= v >> 16u;
    pcg4d_round(v);
    return v;
}

// PRNG state, advanced in place by every random() call; seeded per pixel and frame in mainImage.
uvec4 random_state = uvec4(0u); // init later

// Uniform float in [0, 1). Hashes the state forward, then places the top 23 bits of lane x
// in the mantissa of a float with exponent 0 (a value in [1, 2)) and subtracts 1.
// https://experilous.com/1/blog/post/perfect-fast-random-floating-point-numbers
float random() {
    uint bits = pcg4d(random_state).x;
    float one_to_two = uintBitsToFloat(0x3f800000u | (bits >> 9u));
    return one_to_two - 1.0;
}

// Uniformly distributed unit vector on the sphere (pdf 1 / (4 * PI)).
// First random number picks the azimuth, second picks cos(theta) uniformly in (-1, 1].
vec3 sample_uniform_dir() {
    float u_azimuth = random();
    float u_height = random();

    float z = 1.0 - 2.0 * u_height;
    float ring_radius = sqrt(1.0 - z * z);
    float azimuth = 2.0 * PI * u_azimuth;

    return vec3(ring_radius * cos(azimuth), ring_radius * sin(azimuth), z);
}

// Rayleigh phase function: the Cornette-Shanks form with g = 0, matching Bruneton's model.
// Normalized to integrate to 1 over the unit sphere.
float phase_ray(float cos_theta) {
    float mu2 = cos_theta * cos_theta;
    return (3.0 / (16.0 * PI)) * (1.0 + mu2);
}

// Cornette-Shanks phase function for Mie scattering with asymmetry g = mie_g.
float phase_mie(float cos_theta) {
    float g2 = mie_g * mie_g;
    float norm = 3.0 / (8.0 * PI) * (1.0 - g2) / (2.0 + g2);
    float denom = pow(1.0 + g2 - 2.0 * mie_g * cos_theta, 1.5);
    return norm * (1.0 + cos_theta * cos_theta) / denom;
}

// Scattering coefficient (Rayleigh + Mie) of the current `channel` at `pos`.
// Altitude is measured from the ground sphere; each profile is an exponential in
// altitude clamped to [0, 1].
float scatter_density(vec3 pos) {
    float altitude = length(pos) - bottom_radius;
    float rayleigh = clamp(exp(ray_exp_scale_b * altitude), 0.0, 1.0) * ray_scattering[channel];
    float mie = clamp(exp(mie_exp_scale_b * altitude), 0.0, 1.0) * mie_scattering[channel];
    return rayleigh + mie;
}

// Extinction coefficient (Rayleigh + Mie + ozone) of the current `channel` at `pos`.
// Rayleigh extinction equals its scattering; Mie uses mie_extinction; ozone follows a
// two-layer piecewise-linear profile. All densities are clamped to [0, 1].
float extinct_density(vec3 pos) {
    float altitude = length(pos) - bottom_radius;

    float rayleigh = clamp(exp(ray_exp_scale_b * altitude), 0.0, 1.0);
    float mie = clamp(exp(mie_exp_scale_b * altitude), 0.0, 1.0);

    // ozone: layer a below absorb_width_a, layer b at or above it
    float ozone_raw;
    if (altitude < absorb_width_a) {
        ozone_raw = altitude * absorb_linear_term_a + absorb_constant_term_a;
    } else {
        ozone_raw = altitude * absorb_linear_term_b + absorb_constant_term_b;
    }
    float ozone = clamp(ozone_raw, 0.0, 1.0);

    return rayleigh * ray_scattering[channel]
        + mie * mie_extinction[channel]
        + ozone * absorb_extinction[channel];
}

// Upper bound of scatter_density for the current channel: every profile is clamped to
// at most 1, so the sum of the raw coefficients bounds it everywhere.
float max_scatter_density() {
    float bound = ray_scattering[channel];
    bound += mie_scattering[channel];
    return bound;
}

// Majorant for delta/ratio tracking: the extinction of every component at density 1,
// an upper bound of extinct_density for the current channel.
float max_extinct_density() {
    float rayleigh = ray_scattering[channel];
    float mie = mie_extinction[channel];
    float ozone = absorb_extinction[channel];
    return rayleigh + mie + ozone;
}

// Does the ray (origin `ray`, normalized direction `dir`) hit the ground sphere?
// True when the far intersection lies at or ahead of the origin, which also covers
// origins inside the planet.
// adapted from: https://iquilezles.org/articles/spherefunctions/
bool planet_intersect(vec3 ray, vec3 dir) {
    float half_b = dot(ray, dir);
    float disc = half_b * half_b - (dot(ray, ray) - bottom_radius * bottom_radius);
    return disc >= 0.0 && sqrt(disc) - half_b >= 0.0;
}

// Near and far distances along the ray (normalized `dir`) to the top-of-atmosphere sphere.
// Returns vec2(-1.0) if the ray's line misses it; values can be negative when the
// intersection lies behind the origin.
// adapted from: https://iquilezles.org/articles/spherefunctions/
vec2 atmo_bounds(vec3 ray, vec3 dir) {
    float half_b = dot(ray, dir);
    float c = dot(ray, ray) - top_radius * top_radius;
    float disc = half_b * half_b - c;
    if (disc < 0.0) {
        return vec2(-1.0);
    }
    float root = sqrt(disc);
    return vec2(-half_b - root, -half_b + root);
}

// Near and far distances along the ray (normalized `dir`) to the ground sphere.
// Returns vec2(-1.0) if the ray's line misses it; values can be negative when the
// intersection lies behind the origin.
// adapted from: https://iquilezles.org/articles/spherefunctions/
vec2 planet_bounds(vec3 ray, vec3 dir) {
    float half_b = dot(ray, dir);
    float c = dot(ray, ray) - bottom_radius * bottom_radius;
    float disc = half_b * half_b - c;
    if (disc < 0.0) {
        return vec2(-1.0);
    }
    float root = sqrt(disc);
    return vec2(-half_b - root, -half_b + root);
}

// Ratio-tracking estimate of the transmittance from `start` along `dir` to the top of
// the atmosphere. Returns 0 if the ray hits the ground. Otherwise tentative collisions
// are sampled with the majorant max_extinct_density(), and each one scales the estimate
// by the null-collision probability 1 - extinct / majorant.
float ratio_tracking(vec3 start, vec3 dir) {
    if (planet_intersect(start, dir)) {
        return 0.0;
    }

    float t_exit = atmo_bounds(start, dir).y;
    float sigma_max = max_extinct_density();

    float transmittance = 1.0;
    float t = 0.0;
    for (uint n = 0u; n < MAX_STEPS; ++n) {
        float xi = random();
        t += -log(1.0 - xi) / sigma_max;
        if (t > t_exit) break;
        transmittance *= 1.0 - (extinct_density(start + dir * t) / sigma_max);
    }
    return transmittance;
}

// scattering types
// Event codes returned by delta_tracking; integrate() continues the path only on
// I_RAY / I_MIE and terminates on the others.
const uint I_MISS = 0u; // no scattering
const uint I_GROUND = 1u; // hit the ground
const uint I_RAY = 2u; // rayleigh scattering
const uint I_MIE = 3u; // mie scattering
const uint I_ABSORB = 4u; // absorb

// Delta (Woodcock) tracking: sample the next real collision along the ray from `start`
// in direction `dir`, using max_extinct_density() as the majorant. A tentative collision
// is a scattering event with probability scatter / majorant, an absorption with
// probability (extinct - scatter) / majorant, and a null collision otherwise.
// Returns I_MISS, I_GROUND, I_RAY, I_MIE or I_ABSORB. `t` receives the distance of the
// last tentative collision, which is the event position for scattering events.
uint delta_tracking(vec3 start, vec3 dir, out float t) {
    t = 0.0;

    // intersections with the ground sphere and the top-of-atmosphere sphere
    vec2 t_ground = planet_bounds(start, dir);
    vec2 t_bounds = atmo_bounds(start, dir);
    float sigma_max = max_extinct_density();

    for (uint i = 0u; i < MAX_STEPS; i++) {
        // free-flight distance, event type, and scattering type
        float xi = random();
        float ki = random();
        float zi = random();
        t += -log(1.0 - xi) / sigma_max;

        // Ground first: the planet lies inside the atmosphere, so a step that overshoots
        // both the ground and the far atmosphere boundary still passed through the ground.
        if (t_ground.y > 0.0 && t > t_ground.x) return I_GROUND;
        if (t > t_bounds.y) return I_MISS;

        vec3 pos = start + dir * t;
        if (ki < scatter_density(pos) / sigma_max) {
            // scattering: pick Mie vs Rayleigh in proportion to their coefficients,
            // using the same clamped profiles as scatter_density
            float h = length(pos) - bottom_radius;
            float d_ray = clamp(exp(ray_exp_scale_b * h), 0.0, 1.0) * ray_scattering[channel];
            float d_mie = clamp(exp(mie_exp_scale_b * h), 0.0, 1.0) * mie_scattering[channel];
            return zi < d_mie / (d_mie + d_ray) ? I_MIE : I_RAY;
        }
        if (ki < extinct_density(pos) / sigma_max) {
            // extinction that is not scattering: absorption
            return I_ABSORB;
        }
        // otherwise a null collision: keep tracking
    }

    return I_MISS;
}

// Estimate the radiance of the current `channel` arriving at `start` from direction `dir`
// with a delta-tracking random walk of at most uint(depth) scattering events.
// At each real scattering event the sun is sampled directly (next-event estimation):
// throughput * shadow-ray transmittance * solar irradiance * phase(sun direction).
// The walk then continues in a uniformly sampled direction, so the throughput is
// weighted by phase / pdf with pdf = 1 / (4 * PI).
// Ground hits, escapes and absorptions all end the walk (ground_albedo is not used).
float integrate(vec3 start, vec3 dir) {
    float luminance = 0.0;
    float throughput = 1.0;

    for (uint i = 0u; i < uint(depth); i++) {
        float t;
        uint hit = delta_tracking(start, dir, t);

        // only real scattering continues the path; ground, miss and absorption terminate it
        if (hit != I_RAY && hit != I_MIE) break;

        bool is_ray = hit == I_RAY;
        vec3 pos = start + dir * t;

        // next-event estimation toward the sun (sun_dir is already normalized)
        float cos_sun = dot(sun_dir, dir);
        float phase_sun = is_ray ? phase_ray(cos_sun) : phase_mie(cos_sun);
        luminance += throughput
            * ratio_tracking(pos, sun_dir)
            * solar_irradiance[channel]
            * phase_sun;

        // continue in a uniformly sampled direction: weight = phase / (1 / (4 * PI))
        vec3 new_dir = sample_uniform_dir();
        float cos_new = dot(new_dir, dir);
        throughput *= (is_ray ? phase_ray(cos_new) : phase_mie(cos_new)) * 4.0 * PI;

        start = pos;
        dir = new_dir;
    }

    return luminance;
}

// Shadertoy entry point: traces one sample per color channel for this pixel and blends it
// into the running average stored in iChannel0 (rgb = average, alpha = sample count).
void mainImage(out vec4 fragcolor, vec2 fragcoord) {
    ivec2 texel = ivec2(fragcoord);
    vec4 previous = texelFetch(iChannel0, texel, 0);

    // restart accumulation on the first frame or on a mouse click
    bool restart = iFrame == 0 || iMouse.w > 0.0;
    layer = restart ? 0u : uint(previous.a);

    // seed from pixel position, accumulation index, and channel (still 0 at this point)
    random_state = uvec4(uvec2(fragcoord.xy) + 1u, layer + 1u, channel + 1u);

    // camera sits cam_pos above the planet surface
    vec3 ray_start = cam_pos + vec3(0.0, bottom_radius, 0.0);

    float aspect = float(iResolution.x) / float(iResolution.y);
    vec2 uv = ((fragcoord / iResolution.xy) - 0.5) * vec2(aspect, 1.0);
    vec3 ray_dir = normalize(vec3(uv, 1.0));

    // if the camera is outside the atmosphere, jump ahead to where the ray enters it
    float entry = atmo_bounds(ray_start, ray_dir).x;
    if (entry >= 0.0) ray_start += ray_dir * entry;

    // red, green, blue in order; the PRNG state carries over between channels
    vec3 color;
    for (uint c = 0u; c < 3u; c++) {
        channel = c;
        color[c] = integrate(ray_start, ray_dir);
    }

    // running average over layer + 1 samples
    vec3 history = layer == 0u ? vec3(0.0) : previous.rgb;
    fragcolor = vec4(mix(history, color, 1.0 / float(layer + 1u)), float(layer + 1u));
}
