// shadertoy version of the path tracer, published at https://www.shadertoy.com/view/wcdGRX

// Global atmosphere setup.
// solar_irradiance scales the final colour returned by render().
// sun_angular_radius is not read by any code visible in this file.
// bottom_radius / top_radius are the radii of the planet-surface sphere and
// of the atmosphere-top sphere, both centred on the origin
// (presumably kilometres - TODO confirm).
const vec3 solar_irradiance = vec3(4.0);
const float sun_angular_radius = 0.004675;
const float bottom_radius = 6360.0;
const float top_radius = 6460.0;

// rayleigh density
// Laid out as two layers ("a" and "b"), each with a width, an exponential
// term and scale, a linear term and a constant term.
// NOTE(review): of these, the density functions visible in this file only read
// ray_exp_scale_b (density = clamp(exp(ray_exp_scale_b * h), 0, 1), with h the
// height above bottom_radius); the other terms appear to be kept for reference.
const float ray_width_a = 0.0;
const float ray_exp_term_a = 0.0;
const float ray_exp_scale_a = 0.0;
const float ray_linear_term_a = 0.0;
const float ray_constant_term_a = 0.0;

const float ray_width_b = 0.0;
const float ray_exp_term_b = 1.0;
const float ray_exp_scale_b = -0.125;
const float ray_linear_term_b = 0.0;
const float ray_constant_term_b = 0.0;

// per-channel rayleigh scattering coefficient, multiplied by the rayleigh density
const vec3 ray_scattering = vec3(0.005802, 0.013558, 0.033100);

// mie density
// Same two-layer layout as the rayleigh profile.
// NOTE(review): the density functions visible in this file only read
// mie_exp_scale_b (density = clamp(exp(mie_exp_scale_b * h), 0, 1)); the other
// terms appear to be kept for reference.
const float mie_width_a = 0.0;
const float mie_exp_term_a = 0.0;
const float mie_exp_scale_a = 0.0;
const float mie_linear_term_a = 0.0;
const float mie_constant_term_a = 0.0;

const float mie_width_b = 0.0;
const float mie_exp_term_b = 1.0;
const float mie_exp_scale_b = -0.833333;
const float mie_linear_term_b = 0.0;
const float mie_constant_term_b = 0.0;

// mie_scattering feeds the in-scattering terms, mie_extinction feeds the
// extinction (transmittance) term; both are multiplied by the mie density.
// mie_g is the asymmetry parameter used by phase_mie().
const vec3 mie_scattering = vec3(0.003996);
const vec3 mie_extinction = vec3(0.004440);
const float mie_g = 0.8;

// absorption (ozone) density
// Piecewise-linear profile in the height h above bottom_radius:
//   h <  absorb_width_a : h * absorb_linear_term_a + absorb_constant_term_a
//   h >= absorb_width_a : h * absorb_linear_term_b + absorb_constant_term_b
// (the result is clamped to [0, 1] where it is evaluated).
// The two layers must agree at h = absorb_width_a, where both evaluate to 1:
//   25 *  (1/15) - 2/3 = 1    and    25 * (-1/15) + 8/3 = 1
// so the profile is a tent that peaks at h = 25 and reaches 0 at h = 10 and h = 40.
// The exp terms and absorb_width_b are unused by the code visible in this file.
const float absorb_width_a = 25.0;
const float absorb_exp_term_a = 0.0;
const float absorb_exp_scale_a = 0.0;
const float absorb_linear_term_a = 0.066667;   //  1/15
const float absorb_constant_term_a = -0.666667; // -2/3

const float absorb_width_b = 0.0;
const float absorb_exp_term_b = 0.0;
const float absorb_exp_scale_b = 0.0;
// -1/15. This was -0.66667 (a dropped zero), which made the upper layer evaluate
// to about -14 at h = 25 and therefore clamp to 0, removing all ozone above 25.
const float absorb_linear_term_b = -0.066667;
const float absorb_constant_term_b = 2.666667;  //  8/3
   
// per-channel extinction of the absorbing (ozone) layer, multiplied by the ozone density
const vec3 absorb_extinction = vec3(0.000650, 0.001881, 0.000085);
// NOTE(review): ground_albedo is not read by any code visible in this file
const vec3 ground_albedo = vec3(0.0);

// Camera position is given relative to the planet surface: the code adds
// vec3(0, bottom_radius, 0) to it before use, so y is the height above ground.
const vec3 cam_pos = vec3(0.0, 5.0, 0.0); // camera position
const vec3 sun_dir = normalize(vec3(0.0, 0.1, 1.0)); // sun direction

const float PI = 3.14159265358979323846;
// number of ray-march samples taken along each ray in scatter()
const uint STEPS = 40u; // 40 used in the original implementation

// Intersect a ray with the outer atmosphere sphere (radius top_radius, centred
// on the origin).
// adapted from: https://iquilezles.org/articles/spherefunctions/
//   ray - ray origin
//   dir - ray direction; the quadratic is solved with a leading coefficient of 1,
//         so dir is expected to be unit length
// Returns (near, far) distances along dir, or vec2(-1.0) when the ray misses.
vec2 atmo_bounds(vec3 ray, vec3 dir) {
    float half_b = dot(ray, dir);
    float dist_term = dot(ray, ray) - top_radius * top_radius;
    float discriminant = half_b * half_b - dist_term;

    // no real roots: the ray never touches the sphere
    if (discriminant < 0.0) {
        return vec2(-1.0);
    }

    float root = sqrt(discriminant);
    return vec2(-half_b - root, -half_b + root);
}

// Intersect a ray with the planet sphere (radius bottom_radius, centred on the
// origin).
// adapted from: https://iquilezles.org/articles/spherefunctions/
//   ray - ray origin
//   dir - ray direction; the quadratic is solved with a leading coefficient of 1,
//         so dir is expected to be unit length
// Returns (near, far) distances along dir, or vec2(-1.0) when the ray misses.
vec2 planet_bounds(vec3 ray, vec3 dir) {
    float half_b = dot(ray, dir);
    float dist_term = dot(ray, ray) - bottom_radius * bottom_radius;
    float discriminant = half_b * half_b - dist_term;

    // no real roots: the ray never touches the planet
    if (discriminant < 0.0) {
        return vec2(-1.0);
    }

    float root = sqrt(discriminant);
    return vec2(-half_b - root, -half_b + root);
}

// Rayleigh phase function.
// bruneton uses the cornette-shanks phase function, so the same is used here;
// for rayleigh g = 0, which leaves 3 / (16 pi) * (1 + cos^2).
//   cos_theta - cosine of the angle between the view and light directions
float phase_ray(float cos_theta) {
    float normalization = 3.0 / (16.0 * PI);
    float lobe = 1.0 + cos_theta * cos_theta;
    return normalization * lobe;
}

// Mie phase function (cornette-shanks form) with asymmetry g = mie_g.
//   cos_theta - cosine of the angle between the view and light directions
float phase_mie(float cos_theta) {
    float g_sq = mie_g * mie_g;
    float k = 3.0 / (8.0 * PI) * (1.0 - g_sq) / (2.0 + g_sq);
    float denominator = pow(1.0 + g_sq - 2.0 * mie_g * cos_theta, 1.5);
    return k * (1.0 + cos_theta * cos_theta) / denominator;
}

// densities
// Total scattering coefficient (rayleigh + mie) at a position.
//   pos - position relative to the planet centre
// Each species' density falls off exponentially with the height above
// bottom_radius and is clamped to [0, 1].
vec3 scatter_density(vec3 pos) {
    float altitude = length(pos) - bottom_radius;
    float ray_weight = clamp(exp(ray_exp_scale_b * altitude), 0.0, 1.0);
    float mie_weight = clamp(exp(mie_exp_scale_b * altitude), 0.0, 1.0);

    vec3 total = ray_weight * ray_scattering;
    total += mie_weight * mie_scattering;
    return total;
}

// Total extinction coefficient at a position: rayleigh scattering, mie
// extinction and the absorbing (ozone) layer combined.
//   pos - position relative to the planet centre
// The ozone profile is piecewise linear in the height above bottom_radius,
// switching layers at absorb_width_a; every density is clamped to [0, 1].
vec3 absorb_density(vec3 pos) {
    float altitude = length(pos) - bottom_radius;
    float ray_weight = clamp(exp(ray_exp_scale_b * altitude), 0.0, 1.0);
    float mie_weight = clamp(exp(mie_exp_scale_b * altitude), 0.0, 1.0);

    // pick the linear layer for this altitude
    float ozone_profile;
    if (altitude < absorb_width_a) {
        ozone_profile = altitude * absorb_linear_term_a + absorb_constant_term_a;
    } else {
        ozone_profile = altitude * absorb_linear_term_b + absorb_constant_term_b;
    }
    float ozone_weight = clamp(ozone_profile, 0.0, 1.0);

    vec3 total = ray_weight * ray_scattering;
    total += mie_weight * mie_extinction;
    total += ozone_weight * absorb_extinction;
    return total;
}

// Scattering coefficient at a position with each species weighted by its own
// phase function value. Needed in addition to scatter_density() because the
// rayleigh and mie terms are multiplied by different phases.
//   pos - position relative to the planet centre
//   ray - phase value applied to the rayleigh term
//   mie - phase value applied to the mie term
vec3 scatter_density_phases(vec3 pos, float ray, float mie) {
    float altitude = length(pos) - bottom_radius;
    float ray_weight = clamp(exp(ray_exp_scale_b * altitude), 0.0, 1.0);
    float mie_weight = clamp(exp(mie_exp_scale_b * altitude), 0.0, 1.0);

    vec3 total = ray_weight * ray_scattering * ray;
    total += mie_weight * mie_scattering * mie;
    return total;
}

// trace a single scattering ray
// adapted from the original skyatmosphere implementation, in RenderSkyRayMarching.hlsl
//
// Marches STEPS samples from `start` along `dir` until the ray reaches the
// planet or leaves the atmosphere, accumulating in-scattered light.
//   start      - ray origin, relative to the planet centre
//   dir        - ray direction (expected to be unit length)
//   light_dir  - direction towards the light
//   isotropic  - when true, a uniform 1 / (4 pi) phase replaces the rayleigh
//                and mie phase functions
//   iChannel0  - transmittance lookup table
//   iChannel1  - multiple scattering lookup table
//   iChannel2  - not sampled here; kept so all passes share one signature
//   transmittance   (out) - transmittance along the marched segment
//   scattering      (out) - accumulated in-scattered light
//   multiscattering (out) - the same integral with unit incoming light and
//                           phase, used by the multiple scattering pass
// If there is nothing to march through the outputs are (1, 0, 0).
void scatter(
    vec3 start,
    vec3 dir,
    vec3 light_dir,
    bool isotropic,
    sampler2D iChannel0, sampler2D iChannel1, sampler2D iChannel2,
    out vec3 transmittance,
    out vec3 scattering,
    out vec3 multiscattering
) {
    // results for an empty path; also the starting values for the march
    scattering = vec3(0.0);
    transmittance = vec3(1.0);
    multiscattering = vec3(0.0);

    // how far to travel in the atmosphere: up to the planet if it is in the
    // way, otherwise up to the far side of the atmosphere
    vec2 planet_hit = planet_bounds(start, dir);
    float end = planet_hit.y > 0.0
        ? planet_hit.x
        : atmo_bounds(start, dir).y;

    // stop if there is nothing in front of us
    if (end < 0.0) {
        return;
    }

    // phase: constant along the ray, so pick it once
    float view_cos = dot(dir, light_dir);
    float uniform_phase = 1.0 / (4.0 * PI);
    float ray_phase = isotropic ? uniform_phase : phase_ray(view_cos);
    float mie_phase = isotropic ? uniform_phase : phase_mie(view_cos);

    // distance to the horizon from the top of the atmosphere, used by the
    // transmittance table parametrization
    float h_max = sqrt(max(0.0, top_radius * top_radius - bottom_radius * bottom_radius));

    // raymarch
    // we won't use a variable sample count; each sample sits 0.3 of the way
    // into its segment and dt is the distance from the previous sample
    float t = 0.0;
    float sample_t = 0.3;
    for (uint i = 0u; i < STEPS; i++) {
        // move
        float new_t = end * (float(i) + sample_t) / float(STEPS);
        float dt = new_t - t;
        t = new_t;

        // position
        vec3 pos = start + dir * t;
        float radius = length(pos);

        // sample medium (evaluated once per step and reused below)
        vec3 extinction = absorb_density(pos);
        vec3 scatter_coeff = scatter_density(pos);
        vec3 sample_optical_depth = extinction * dt;
        vec3 sample_transmittance = exp(-sample_optical_depth);

        // cosine of the light's zenith angle at this sample
        float sun_cos = dot(normalize(pos), light_dir);

        // transmittance table coordinates
        float rho = sqrt(max(0.0, radius * radius - bottom_radius * bottom_radius));

        // distance to the top of the atmosphere towards the light; the
        // discriminant is clamped because rounding can push it slightly below
        // zero near the top boundary, and sqrt of a negative value is undefined
        float disc = radius * radius * (sun_cos * sun_cos - 1.0) + top_radius * top_radius;
        float d = max(0.0, (-radius * sun_cos + sqrt(max(0.0, disc))));

        float d_min = top_radius - radius;
        float d_max = rho + h_max;
        float x_mu = (d - d_min) / (d_max - d_min);
        float x_r = rho / h_max;

        // read from the transmittance texture
        vec3 transmittance_to_light = texture(iChannel0, vec2(x_mu, x_r)).xyz;

        // scattering, with phase function
        vec3 phase_scattering = scatter_density_phases(pos, ray_phase, mie_phase);

        // planet shadow
        // 1 = not shadowed
        float shadow = planet_bounds(pos, light_dir).y > 0.0 ? 0.0 : 1.0;

        // read from the multiple scattering table
        vec2 ms_uv = vec2(sun_cos * 0.5 + 0.5, (radius - bottom_radius) / (top_radius - bottom_radius));
        vec3 multiscattered_luminance = texture(iChannel1, ms_uv).xyz;

        // scattering contribution
        vec3 s = shadow * transmittance_to_light * phase_scattering + multiscattered_luminance * scatter_coeff;

        // integrate multiple scattering (analytic integral over the segment)
        vec3 sample_multiscatter = (scatter_coeff - scatter_coeff * sample_transmittance) / extinction;
        multiscattering += transmittance * sample_multiscatter;

        // integrate
        vec3 sample_scattering = (s - s * sample_transmittance) / extinction;

        scattering += transmittance * sample_scattering;
        transmittance *= sample_transmittance;
    }
}

// compute transmittance
// Renders one texel of the transmittance table.
//   fragcoord - normalized texture coordinate; x selects the view angle and
//               y selects the height
//   iChannel0..2 - forwarded unchanged to scatter()
// Returns the transmittance from that height in that direction, alpha = 1.
vec4 transmittance(vec2 fragcoord, sampler2D iChannel0, sampler2D iChannel1, sampler2D iChannel2) {
    // height at this pixel
    float horizon_dist = sqrt(top_radius * top_radius - bottom_radius * bottom_radius);
    float rho = horizon_dist * (fragcoord.y);
    float view_height = sqrt(rho * rho + bottom_radius * bottom_radius);

    // view angle at this pixel: d runs from the shortest to the longest
    // possible distance to the top of the atmosphere
    float d_min = top_radius - view_height;
    float d_max = rho + horizon_dist;
    float d = d_min + fragcoord.x * (d_max - d_min);

    // looking straight up when the distance is zero
    float cos_zenith = 1.0;
    if (d != 0.0) {
        cos_zenith = clamp(
            (horizon_dist * horizon_dist - rho * rho - d * d) / (2.0 * view_height * d),
            -1.0, 1.0);
    }

    // start position and direction, with +z as the local up axis
    vec3 origin = vec3(0.0, 0.0, view_height);
    vec3 view_dir = vec3(0.0, sqrt(1.0 - cos_zenith * cos_zenith), cos_zenith);

    // only the transmittance output is used; the light direction is a placeholder
    vec3 path_transmittance, path_scattering, path_multiscattering;
    scatter(origin, view_dir, vec3(1.0), false, iChannel0, iChannel1, iChannel2,
        path_transmittance, path_scattering, path_multiscattering);

    // transmittance out
    return vec4(path_transmittance, 1.0);
}

// render the single scattering table
// Renders one texel of the sky-view table for the camera at cam_pos.
//   fragcoord - normalized texture coordinate; y selects the view zenith angle
//               (non-linear, split at the horizon), x selects the angle between
//               the view and light directions around the up axis
//   iChannel0..2 - forwarded unchanged to scatter()
// Returns the in-scattered light for that direction, alpha = 1.
vec4 single_scatter(vec2 fragcoord, sampler2D iChannel0, sampler2D iChannel1, sampler2D iChannel2) {
    // camera relative to the planet centre
    vec3 planet_cam = cam_pos + vec3(0.0, bottom_radius, 0.0);
    float view_height = length(planet_cam);
    vec3 up_dir = normalize(planet_cam);

    // using non-linear lut: angle between the zenith and the horizon
    float v_horizon = sqrt(view_height * view_height - bottom_radius * bottom_radius);
    float beta = acos(v_horizon / view_height);
    float zenith_angle = PI - beta;

    // lower half of the table covers zenith -> horizon, upper half covers
    // horizon -> nadir; both are squared to add resolution at the horizon
    float cos_zenith;
    if (fragcoord.y < 0.5) {
        float k = 1.0 - 2.0 * (fragcoord.y);
        float above = 1.0 - (k * k);
        cos_zenith = cos(above * zenith_angle);
    } else {
        float k = (fragcoord.y) * 2.0 - 1.0;
        float below = k * k;
        cos_zenith = cos(zenith_angle + beta * below);
    }

    // squared mapping for the light / view angle
    float x_sq = fragcoord.x * fragcoord.x;
    float cos_gamma = -(x_sq * 2.0 - 1.0);

    // set up positions and directions
    vec3 origin = vec3(0.0, view_height, 0.0);

    // light placed in the y-z plane at its zenith angle
    float light_theta = dot(up_dir, normalize(sun_dir));
    vec3 light_dir = normalize(vec3(0.0, light_theta, sqrt(1.0 - light_theta * light_theta)));

    float sin_zenith = sqrt(1.0 - cos_zenith * cos_zenith);
    float sin_gamma = sqrt(1.0 - cos_gamma * cos_gamma);
    vec3 view_dir = vec3(sin_zenith * sin_gamma, cos_zenith, sin_zenith * cos_gamma);

    vec3 path_transmittance, path_scattering, path_multiscattering;
    scatter(origin, view_dir, light_dir, false, iChannel0, iChannel1, iChannel2,
        path_transmittance, path_scattering, path_multiscattering);

    // scattering out
    return vec4(path_scattering, 1.0);
}

// render the multiple scattering table
// Renders one texel of the multiple scattering table.
//   fragcoord - normalized texture coordinate; x selects the cosine of the
//               light zenith angle, y selects the height in the atmosphere
//   iChannel0..2 - forwarded unchanged to scatter()
// Returns the multiple scattering contribution, alpha = 1.
vec4 multi_scatter(vec2 fragcoord, sampler2D iChannel0, sampler2D iChannel1, sampler2D iChannel2) {
    // this is originally a compute shader, that runs each sample direction in a different thread
    // then combines the result afterward
    // Here we just do it with a loop in the pixel shader
    float cos_gamma = fragcoord.x * 2.0 - 1.0;
    vec3 light_dir = vec3(0.0, sqrt(1.0 - cos_gamma * cos_gamma), cos_gamma);

    // sample point, with +z as the local up axis
    float radius = bottom_radius + (fragcoord.y) * (top_radius - bottom_radius);
    vec3 origin = vec3(0.0, 0.0, radius);

    // results to accumulate
    vec3 scattering_sum = vec3(0.0);
    vec3 multiscattering_sum = vec3(0.0);

    // integrate over the sphere on an 8 x 8 grid of directions
    for (uint i = 0u; i < 8u; i++) {
        // azimuth only depends on the outer index
        float azimuth = 2.0 * PI * ((float(i) + 0.5) / 8.0);
        float cos_azimuth = cos(azimuth);
        float sin_azimuth = sin(azimuth);

        for (uint j = 0u; j < 8u; j++) {
            // polar angle, distributed uniformly in its cosine
            float polar = acos(1.0 - 2.0 * ((float(j) + 0.5) / 8.0));
            float cos_polar = cos(polar);
            float sin_polar = sin(polar);

            vec3 sample_dir = vec3(cos_azimuth * sin_polar, sin_azimuth * sin_polar, cos_polar);

            // sample scattering with the isotropic phase
            vec3 path_transmittance, path_scattering, path_multiscattering;
            scatter(origin, sample_dir, light_dir, true, iChannel0, iChannel1, iChannel2,
                path_transmittance, path_scattering, path_multiscattering);

            // each sample contributes 1 / (8 * 8) of the sphere
            scattering_sum += path_scattering / 64.0;
            multiscattering_sum += path_multiscattering / 64.0;
        }
    }

    // power series: 1 + r + r^2 + ... = 1 / (1 - r)
    vec3 series = 1.0 / (1.0 - multiscattering_sum);
    vec3 luminance = scattering_sum * series;

    return vec4(luminance, 1.0);
}

// render the final image
//   fragcoord - normalized screen coordinate
//   width, height - output size, used only for the aspect ratio
//   iChannel0, iChannel1 - forwarded to scatter() when the camera is outside
//                          the atmosphere
//   iChannel2 - single scattering (sky view) table, sampled when the camera is
//               inside the atmosphere
// Returns the sky colour scaled by solar_irradiance, alpha = 1.
vec4 render(vec2 fragcoord, float width, float height, sampler2D iChannel0, sampler2D iChannel1, sampler2D iChannel2) {
    // camera offset
    vec3 ray_start = cam_pos + vec3(0.0, bottom_radius, 0.0);
    vec3 sun = normalize(sun_dir);

    // ray direction
    vec2 uv = (fragcoord - 0.5) * vec2(width / height, 1.0);
    vec3 ray_dir = normalize(vec3(uv, 1.0));

    // outside?
    // raymarch
    if (length(ray_start) > top_radius) {
        // move to start
        float start = atmo_bounds(ray_start, ray_dir).x;
        if (start >= 0.0) ray_start += ray_dir * start;

        // raymarch
        vec3 transmittance, scattering, multiscattering;
        scatter(ray_start, ray_dir, sun, false, iChannel0, iChannel1, iChannel2, transmittance, scattering, multiscattering);

        // result is the scattering
        return vec4(scattering * solar_irradiance, 1.0);
    }

    // otherwise, look up from the texture
    bool hits_ground = planet_bounds(ray_start, ray_dir).y > 0.0;

    // get texture coords
    // clamped: the dot product of two unit vectors can round just outside
    // [-1, 1], where acos is undefined
    vec3 up = normalize(ray_start);
    float cos_zenith = clamp(dot(ray_dir, up), -1.0, 1.0);
    float view_angle = acos(cos_zenith);

    // light direction, projected onto the plane perpendicular to up
    vec3 side = normalize(cross(up, ray_dir));
    vec3 forward = normalize(cross(side, up));
    vec2 light_on_plane = normalize(vec2(dot(sun, forward), dot(sun, side)));
    float cos_light = light_on_plane.x;

    // camera direction
    float h = length(ray_start);
    float v_horizon = sqrt(h * h - bottom_radius * bottom_radius);
    float cos_beta = v_horizon / h;
    float beta = acos(cos_beta);
    float zenith_angle = PI - beta;

    // inverse of the non-linear mapping used when the table was rendered.
    // The sqrt arguments are clamped to zero: right at the horizon the
    // ground-hit test and the angle can disagree by rounding, which would
    // otherwise produce an undefined texture coordinate.
    if (!hits_ground) {
        float coord = view_angle / zenith_angle;
        coord = sqrt(max(0.0, 1.0 - coord));
        uv.y = 0.5 * (1.0 - coord);
    } else {
        float coord = (view_angle - zenith_angle) / beta;
        uv.y = sqrt(max(0.0, coord)) * 0.5 + 0.5;
    }

    uv.x = sqrt(max(0.0, -cos_light * 0.5 + 0.5));

    // sample scattering texture
    return vec4(texture(iChannel2, uv).xyz * solar_irradiance, 1.0);
}

// Entry points for the individual shadertoy passes; each pass expands exactly
// one of these macros to get its mainImage().
// The three table passes only recompute on the first frame or while
// iMouse.w > 0.0; on every other frame they copy their previous output back
// with texelFetch.
// NOTE(review): iChannel3 is passed where a pass should not read a table -
// presumably it is left unbound so those lookups return zero; confirm the
// channel bindings of each buffer.

// transmittance table pass; previous output is read from iChannel0
#define MAIN_TRANS void mainImage(out vec4 col, vec2 coord) { col = iFrame == 0 || iMouse.w > 0.0 ? transmittance(coord / iResolution.xy, iChannel0, iChannel3, iChannel3) : texelFetch(iChannel0,   ivec2(coord), 0); }
// multiple scattering table pass; previous output is read from iChannel1
#define MAIN_MULTI void mainImage(out vec4 col, vec2 coord) { col = iFrame == 0 || iMouse.w > 0.0 ? multi_scatter(coord / iResolution.xy, iChannel0, iChannel3, iChannel3) : texelFetch(iChannel1,   ivec2(coord), 0); }
// single scattering (sky view) table pass; previous output is read from iChannel2
#define MAIN_SINGLE void mainImage(out vec4 col, vec2 coord) { col = iFrame == 0 || iMouse.w > 0.0 ? single_scatter(coord / iResolution.xy, iChannel0, iChannel1, iChannel2) : texelFetch(iChannel2, ivec2(coord), 0); }
// final image pass; applies a 1 / 2.2 gamma to the rendered colour
#define MAIN_FRAG void mainImage(out vec4 col, vec2 coord) { col = pow(render(coord / iResolution.xy, iResolution.x, iResolution.y, iChannel0, iChannel1, iChannel2), vec4(1.0 / 2.2)); }

