
#version 460

// Atmosphere parameters (set 0, binding 15).
// Field naming mirrors Bruneton's precomputed-atmosphere model: each density
// profile has two layers (*_a / *_b) with width, exp_term, exp_scale,
// linear_term and constant_term — presumably filled from the same host-side
// values; confirm against the host struct.
// NOTE(review): no explicit layout qualifier, so this is presumably the Vulkan
// default std140 layout (a float following a vec3 packs into the same 16-byte
// slot) — the host struct must match that packing.
layout(set = 0, binding = 15) uniform atmosphere {
    vec3 solar_irradiance;     // per-channel sun irradiance used for next-event estimation in integrate()
    float sun_angular_radius;  // not read in this shader
    float bottom_radius;       // planet (ground) sphere radius; altitudes are measured from here
    float top_radius;          // outer radius of the atmosphere sphere
    
    // Rayleigh density profile.
    // NOTE(review): of these, only ray_exp_scale_b is read in this shader
    // (density = clamp(exp(ray_exp_scale_b * altitude), 0, 1)).
    float ray_width_a;
    float ray_exp_term_a;
    float ray_exp_scale_a;
    float ray_linear_term_a;
    float ray_constant_term_a;
    
    float ray_width_b;
    float ray_exp_term_b;
    float ray_exp_scale_b;
    float ray_linear_term_b;
    float ray_constant_term_b;
        
    vec3 ray_scattering;       // Rayleigh scattering coefficient per channel (also used as its extinction)
    
    // Mie density profile.
    // NOTE(review): of these, only mie_exp_scale_b is read in this shader.
    float mie_width_a;
    float mie_exp_term_a;
    float mie_exp_scale_a;
    float mie_linear_term_a;
    float mie_constant_term_a;
        
    float mie_width_b;
    float mie_exp_term_b;
    float mie_exp_scale_b;
    float mie_linear_term_b;
    float mie_constant_term_b;
    
    vec3 mie_scattering;       // Mie scattering coefficient per channel
    vec3 mie_extinction;       // Mie extinction coefficient per channel
    float mie_g;               // asymmetry parameter of the Cornette-Shanks Mie phase function
    
    // Absorption (ozone) density profile: piecewise linear, layer a below
    // absorb_width_a and layer b above it, clamped to [0, 1].
    // The exp_* fields and absorb_width_b are not read in this shader.
    float absorb_width_a;
    float absorb_exp_term_a;
    float absorb_exp_scale_a;
    float absorb_linear_term_a;
    float absorb_constant_term_a; 
    
    float absorb_width_b;
    float absorb_exp_term_b;
    float absorb_exp_scale_b;
    float absorb_linear_term_b;
    float absorb_constant_term_b;
        
    vec3 absorb_extinction;    // absorption (ozone) extinction coefficient per channel
    vec3 ground_albedo;        // not read in this shader (ground hits terminate paths)
    float mu_s_min;            // minimum sun zenith cosine, used by the mode-3 mu_s parametrization
};

// Render-control parameters (set 0, binding 12).
layout(set = 0, binding = 12) uniform view {
    float mode; // which table main() renders (0..5); any other value outputs magenta
    float depth; // maximum number of scattering vertices per path: integrate() loops uint(depth) times
    float tilesize; // number of tiles along each tiled axis of fragcoord (modes 1, 3, 5; mode 4 tiles x only)
};

// Push constants, read only for seeding random_state.
layout(push_constant) uniform push {
    uint layer; // pass / iteration index, mixed into random_state so each pass draws fresh samples
    uint width; // presumably the render-target width in pixels (fragcoord.x * width gives a pixel coordinate)
    uint height; // presumably the render-target height in pixels
};

// Interpolated fragment coordinate. The shader treats it as normalized: it is
// scaled by (width, height) for a pixel position and used directly as a [0, 1]
// table parameter — presumably produced by a full-screen pass; confirm.
layout(location = 0) in noperspective centroid vec2 fragcoord;
// Table sample for this fragment: one colour channel chosen by random_channel()
// (plus alpha in mode 4); magenta for an unknown mode.
layout(location = 0) out vec4 fragcolor;

const float PI = 3.14159265358979323846;

// upper bound on the number of iterations of each tracking loop
const uint MAX_STEPS = 512u;

// colour channel (0 = r, 1 = g, 2 = b) simulated by this fragment;
// selected by random_channel() and used to index the per-channel coefficients
uint channel = 0u;

// One cross-lane mixing round of pcg4d: each lane absorbs the product of two
// other (already updated) lanes, in x, y, z, w order.
void pcg4d_mix(inout uvec4 s) {
    s.x += s.y * s.w;
    s.y += s.z * s.x;
    s.z += s.x * s.y;
    s.w += s.y * s.z;
}

// pcg4d hash / generator, see https://www.pcg-random.org/
// Advances `v` in place (LCG step, mix, 16-bit xorshift, mix) and returns the
// new state, so repeated calls on the same state yield a sequence.
uvec4 pcg4d(inout uvec4 v) {
    v *= 1664525u;
    v += 1013904223u;
    pcg4d_mix(v);
    v ^= v >> 16u;
    pcg4d_mix(v);
    return v;
}

// Per-fragment RNG state advanced by pcg4d():
//   x, y: pixel coordinate + 1 (presumably to keep the seed lanes non-zero)
//   z:    pass index `layer` + 1, so each pass gets a different stream
//   w:    extra mix of position, image size and layer
uvec4 random_state = uvec4(
    uvec2(fragcoord * vec2(width, height)) + uvec2(1u),
    layer + 1u,
    uint((fragcoord.x + fragcoord.y) * float(width + height)) + layer + 1u
);

// Uniform float in [0, 1): the top 23 bits of the next pcg4d output become the
// mantissa of a float in [1, 2), and 1.0 is subtracted.
// See https://experilous.com/1/blog/post/perfect-fast-random-floating-point-numbers
float random() {
    uint bits = pcg4d(random_state).x;
    float one_to_two = uintBitsToFloat(0x3f800000u | (bits >> 9u));
    return one_to_two - 1.0;
}

// Draws a colour channel index (0, 1 or 2) from the RNG, stores it in the
// global `channel` and returns it.
uint random_channel() {
    uint picked = pcg4d(random_state).x % 3u;
    channel = picked;
    return channel;
}

// Uniformly distributed unit direction on the sphere (pdf 1 / (4 pi)).
// Draws two random numbers: the azimuth first, then the polar cosine.
vec3 sample_uniform_dir() {
    float u_phi = random();
    float u_cos = random();

    float phi = 2.0 * PI * u_phi;
    float z = 1.0 - 2.0 * u_cos;
    float ring = sqrt(1.0 - z * z);

    return vec3(ring * vec2(cos(phi), sin(phi)), z);
}

// Phase functions follow Bruneton's model, which uses the Cornette-Shanks form.
// Both take the cosine of the scattering angle.

// Rayleigh phase function: Cornette-Shanks with g = 0, i.e.
// 3 / (16 pi) * (1 + cos_theta^2).
float phase_ray(float cos_theta) {
    const float norm = 3.0 / (16.0 * PI);
    return norm * (1.0 + cos_theta * cos_theta);
}

// Mie phase function: Cornette-Shanks with asymmetry g = mie_g,
//   3 / (8 pi) * (1 - g^2) / (2 + g^2) * (1 + c^2) / (1 + g^2 - 2 g c)^(3/2)
float phase_mie(float cos_theta) {
    float g2 = mie_g * mie_g;
    float norm = 3.0 / (8.0 * PI) * (1.0 - g2) / (2.0 + g2);
    float denom = pow(1.0 + g2 - 2.0 * mie_g * cos_theta, 1.5);
    return norm * (1.0 + cos_theta * cos_theta) / denom;
}

// Total scattering coefficient (Rayleigh + Mie) at `pos` for the active channel.
// Each density is exp(exp_scale_b * altitude) clamped to [0, 1], with the
// altitude measured from bottom_radius.
// NOTE(review): only the *_exp_scale_b term of each profile is used here.
float scatter_density(vec3 pos) {
    float altitude = length(pos) - bottom_radius;

    float ray_density = clamp(exp(ray_exp_scale_b * altitude), 0.0, 1.0);
    float mie_density = clamp(exp(mie_exp_scale_b * altitude), 0.0, 1.0);

    float ray_part = ray_density * ray_scattering[channel];
    float mie_part = mie_density * mie_scattering[channel];
    return ray_part + mie_part;
}

// Total extinction coefficient at `pos` for the active channel:
// Rayleigh (scattering only, so its scattering coefficient is its extinction),
// Mie extinction and ozone absorption.
// The ozone density is piecewise linear in altitude (layer a below
// absorb_width_a, layer b above), clamped to [0, 1].
float extinct_density(vec3 pos) {
    float altitude = length(pos) - bottom_radius;

    float ray_density = clamp(exp(ray_exp_scale_b * altitude), 0.0, 1.0);
    float mie_density = clamp(exp(mie_exp_scale_b * altitude), 0.0, 1.0);

    float ozone_density;
    if (altitude < absorb_width_a) {
        ozone_density = altitude * absorb_linear_term_a + absorb_constant_term_a;
    } else {
        ozone_density = altitude * absorb_linear_term_b + absorb_constant_term_b;
    }
    ozone_density = clamp(ozone_density, 0.0, 1.0);

    return ray_density * ray_scattering[channel]
        + mie_density * mie_extinction[channel]
        + ozone_density * absorb_extinction[channel];
}

// Upper bound of scatter_density() for the active channel: every density
// profile is clamped to at most 1, so the bound is the sum of the coefficients.
float max_scatter_density() {
    float rayleigh_max = ray_scattering[channel];
    return rayleigh_max + mie_scattering[channel];
}

// Upper bound of extinct_density() for the active channel (profiles are
// clamped to [0, 1]); used as the majorant by ratio_tracking() and
// delta_tracking().
float max_extinct_density() {
    float total = ray_scattering[channel];
    total += mie_extinction[channel];
    total += absorb_extinction[channel];
    return total;
}

// True when the half-line ray + s * dir (s >= 0) touches the planet sphere of
// radius bottom_radius, i.e. the far intersection is not behind the origin.
// Assumes `dir` is normalized (the quadratic omits dot(dir, dir)).
// adapted from: https://iquilezles.org/articles/spherefunctions/
bool planet_intersect(vec3 ray, vec3 dir) {
    float half_b = dot(ray, dir);
    float offset = dot(ray, ray) - bottom_radius * bottom_radius;
    float disc = half_b * half_b - offset;
    return !(disc < 0.0) && sqrt(disc) - half_b >= 0.0;
}

// Signed distances (near, far) along `dir` from `ray` to the top-of-atmosphere
// sphere (radius top_radius), or vec2(-1.0) if the line misses it.
// Assumes `dir` is normalized (the quadratic omits dot(dir, dir)).
// adapted from: https://iquilezles.org/articles/spherefunctions/
vec2 atmo_bounds(vec3 ray, vec3 dir) {
    float half_b = dot(ray, dir);
    float offset = dot(ray, ray) - top_radius * top_radius;
    float disc = half_b * half_b - offset;
    if (disc < 0.0) {
        return vec2(-1.0);
    }
    float root = sqrt(disc);
    return vec2(-half_b - root, -half_b + root);
}

// Signed distances (near, far) along `dir` from `ray` to the planet sphere
// (radius bottom_radius), or vec2(-1.0) if the line misses it.
// Assumes `dir` is normalized (the quadratic omits dot(dir, dir)).
// adapted from: https://iquilezles.org/articles/spherefunctions/
vec2 planet_bounds(vec3 ray, vec3 dir) {
    float half_b = dot(ray, dir);
    float offset = dot(ray, ray) - bottom_radius * bottom_radius;
    float disc = half_b * half_b - offset;
    if (disc < 0.0) {
        return vec2(-1.0);
    }
    float root = sqrt(disc);
    return vec2(-half_b - root, -half_b + root);
}

// Ratio-tracking estimate of the transmittance from `start` along `dir` up to
// the far boundary of the atmosphere, for the active channel.
// light == true: returns 0 when the ray hits the planet (sun is occluded).
// light == false: the planet is ignored (transmittance-table modes).
//   NOTE(review): such rays are then tracked through the planet interior,
//   where the clamped density profiles presumably evaluate as non-zero.
// Consumes one random number per tracking step.
float ratio_tracking(vec3 start, vec3 dir, bool light) {
    if (light && planet_intersect(start, dir)) {
        return 0.0;
    }

    vec2 t_bounds = atmo_bounds(start, dir);
    float majorant = max_extinct_density();

    float t = 0.0;
    float transmittance = 1.0;
    for (uint iter = 0u; iter < MAX_STEPS; ++iter) {
        // exponential free-flight step against the majorant
        t -= log(1.0 - random()) / majorant;
        if (t > t_bounds.y) {
            break;
        }
        // multiply by the probability of this collision being a null collision
        transmittance *= 1.0 - (extinct_density(start + dir * t) / majorant);
    }

    return transmittance;
}

// Event codes returned by delta_tracking()
const uint I_MISS   = 0u; // left the atmosphere (also returned when tracking runs out of steps)
const uint I_GROUND = 1u; // reached the planet surface
const uint I_RAY    = 2u; // Rayleigh scattering event
const uint I_MIE    = 3u; // Mie scattering event
const uint I_ABSORB = 4u; // absorption event

// Delta (Woodcock) tracking: finds the next real collision along
// start + dir * t, sampling free-flight distances against the majorant
// max_extinct_density() for the active channel.
// Returns:
//   I_GROUND      - the planet surface ahead of `start` was reached
//   I_MISS        - the ray left the atmosphere, or MAX_STEPS was exhausted
//   I_RAY / I_MIE - a real scattering collision; the type is picked in
//                   proportion to the Rayleigh / Mie scattering densities there
//   I_ABSORB      - a real collision that is not scattering (absorption)
// `t` receives the tracked distance at the point of return.
// Three random numbers are drawn every step, whatever the outcome.
uint delta_tracking(vec3 start, vec3 dir, out float t) {
    t = 0.0;

    // signed distances to the planet sphere; vec2(-1.0) when it is missed
    vec2 t_max = planet_bounds(start, dir);
    // signed distances to the top of the atmosphere; vec2(-1.0) when missed
    vec2 t_bounds = atmo_bounds(start, dir);
    // the planet can only block the ray if its far intersection lies ahead
    bool ground_ahead = t_max.y > 0.0;

    float majorant = max_extinct_density();

    for (uint i = 0u; i < MAX_STEPS; i++) {
        float xi = random();
        float ki = random();
        float zi = random();
        t += -log(1.0 - xi) / majorant;

        // Test the ground before the outer boundary: for a ray starting inside
        // the atmosphere the planet is reached before the top boundary, so a
        // step that overshoots both has hit the ground.
        if (ground_ahead && t > t_max.x) {
            return I_GROUND;
        }
        if (t > t_bounds.y) {
            return I_MISS;
        }

        vec3 pos = start + dir * t;
        if (ki < scatter_density(pos) / majorant) {
            // real scattering collision: choose Rayleigh vs Mie by their share
            // of the scattering density, using the same clamped profiles as
            // scatter_density()
            float h = length(pos) - bottom_radius;
            float d_ray = clamp(exp(ray_exp_scale_b * h), 0.0, 1.0) * ray_scattering[channel];
            float d_mie = clamp(exp(mie_exp_scale_b * h), 0.0, 1.0) * mie_scattering[channel];
            return zi < d_mie / (d_mie + d_ray) ? I_MIE : I_RAY;
        }
        if (ki < extinct_density(pos) / majorant) {
            return I_ABSORB;
        }
        // otherwise a null collision: keep tracking
    }

    return I_MISS;
}

// Path-traces the in-scattered radiance arriving at `start` from direction
// `dir` (the view ray), with the sun in direction `light`, for the active channel.
// At each real scattering vertex it adds a next-event estimate of direct
// sunlight: throughput * sun transmittance (ratio_tracking) *
// solar_irradiance * phase. The path then continues in a uniformly sampled
// direction with throughput weighted by phase * 4 pi (phase / uniform pdf).
// Paths end on ground, absorption or escape, or after uint(depth) vertices.
// separate_mie: at the first vertex both Rayleigh and Mie use the isotropic
// phase 1 / (4 pi) with no continuation weight, and a first-vertex Mie
// sun estimate goes to .y instead of .x.
// Returns vec2(luminance, first-vertex Mie luminance).
vec2 integrate(vec3 start, vec3 dir, vec3 light, bool separate_mie) {
    vec3 sun = normalize(light);

    float luminance = 0.0;
    float mie_luminance = 0.0;
    float throughput = 1.0;

    uint vertices = uint(depth);
    for (uint bounce = 0u; bounce < vertices; ++bounce) {
        float t;
        uint hit = delta_tracking(start, dir, t);

        // ground, escape and absorption all terminate the path
        if (hit != I_RAY && hit != I_MIE) {
            break;
        }

        bool is_ray = hit == I_RAY;
        // first vertex with separate_mie: the phase is factored out (isotropic)
        bool isotropic = separate_mie && bounce == 0u;
        vec3 vertex = start + dir * t;

        float sun_phase;
        if (isotropic) {
            sun_phase = 0.25 / PI;
        } else {
            float cos_sun = dot(sun, dir);
            sun_phase = is_ray ? phase_ray(cos_sun) : phase_mie(cos_sun);
        }

        float direct = throughput
            * ratio_tracking(vertex, sun, true)
            * solar_irradiance[channel]
            * sun_phase;

        if (isotropic && !is_ray) {
            mie_luminance += direct;
        } else {
            luminance += direct;
        }

        vec3 next_dir = sample_uniform_dir();
        if (!isotropic) {
            float cos_turn = dot(next_dir, dir);
            float turn_phase = is_ray ? phase_ray(cos_turn) : phase_mie(cos_turn);
            throughput *= turn_phase * 4.0 * PI;
        }

        start = vertex;
        dir = next_dir;
    }

    return vec2(luminance, mie_luminance);
}

// Splits fragcoord into tiles, per component: `inner` is the position inside
// the current tile, `outer` is the tile's lower corner snapped to the grid.
void tile_coords(out vec2 inner, out vec2 outer) {
    vec2 scaled = fragcoord * tilesize;
    inner = fract(scaled);
    outer = floor(scaled) / tilesize;
}

// Direction toward the sun with zenith cosine cos_light. It lies in the y-z
// plane (y = zenith, non-negative z), which fixes the sun's azimuth at 0.
vec3 sun_direction(float cos_light) {
    return vec3(0.0, cos_light, sqrt(1.0 - cos_light * cos_light));
}

// View direction with zenith cosine cos_zenith and azimuth cosine cos_gamma,
// where the azimuth is measured from +z (the sun's azimuth in sun_direction):
//   dot(view_direction(mu, g), sun_direction(mu_s))
//     = mu * mu_s + sqrt(1 - mu^2) * sqrt(1 - mu_s^2) * g
vec3 view_direction(float cos_zenith, float cos_gamma) {
    float sin_zenith = sqrt(1.0 - cos_zenith * cos_zenith);
    float sin_gamma = sqrt(1.0 - cos_gamma * cos_gamma);
    return vec3(sin_zenith * sin_gamma, cos_zenith, sin_zenith * cos_gamma);
}

// Inverts the relation above: the azimuth cosine for which the view/sun
// cosine equals nu, given the view zenith cosine mu and sun zenith cosine mu_s.
float azimuth_for_nu(float mu, float mu_s, float nu) {
    float sin_product = sqrt((1.0 - mu * mu) * (1.0 - mu_s * mu_s));
    // view or sun on the zenith axis: nu does not depend on the azimuth
    if (!(sin_product > 0.0)) return 1.0;
    return clamp((nu - mu * mu_s) / sin_product, -1.0, 1.0);
}

// Writes a transmittance sample into one randomly chosen colour channel.
// The * 3.0 presumably compensates for each channel being chosen in ~1/3 of
// the accumulated passes.
void write_transmittance(vec3 start, vec3 dir) {
    fragcolor = vec4(0.0, 0.0, 0.0, 1.0);
    random_channel();
    fragcolor[channel] = ratio_tracking(start, dir, false) * 3.0;
}

// Writes an in-scattering sample (no separate Mie) into one randomly chosen
// colour channel, with the same * 3.0 scaling as write_transmittance().
void write_scattering(vec3 start, vec3 dir, vec3 light) {
    fragcolor = vec4(0.0, 0.0, 0.0, 1.0);
    random_channel();
    fragcolor[channel] = integrate(start, dir, light, false).x * 3.0;
}

// Renders one lookup table, selected by `mode`:
//   0: transmittance, direct (x = height, y = view zenith cosine)
//   1: scattering, direct, tiled (inner x = view zenith, inner y = view azimuth,
//      outer x = sun zenith, outer y = height)
//   2: transmittance, Bruneton's (r, mu) parametrization
//   3: scattering, Bruneton's (r, mu, mu_s, nu) parametrization
//   4: scattering, direct, tiled in x only, first-vertex Mie in alpha
//   5: scattering, unskewed (r, mu, mu_s, nu) grid
// Any other mode outputs magenta.
// In modes 3 and 5, nu is the cosine between view and sun (as in Bruneton), so
// it is converted to the azimuth used to build the view direction.
void main() {
    if (mode == 0.0) {
        float height = bottom_radius + fragcoord.x * (top_radius - bottom_radius);
        float cos_theta = fragcoord.y * 2.0 - 1.0;

        vec3 start = vec3(0.0, height, 0.0);
        vec3 dir = vec3(0.0, cos_theta, sqrt(1.0 - cos_theta * cos_theta));
        write_transmittance(start, dir);
    } else if (mode == 1.0) {
        vec2 inner, outer;
        tile_coords(inner, outer);

        // squared so there is more coverage at lower altitudes
        float height = bottom_radius + outer.y * outer.y * (top_radius - bottom_radius);
        float cos_zenith = inner.x * 2.0 - 1.0;
        // sun zenith cosine in [-0.1, 1): little of the night side, which is just black
        float cos_light = outer.x * 1.1 - 0.1;
        float cos_gamma = inner.y * 2.0 - 1.0;

        write_scattering(vec3(0.0, height, 0.0),
            view_direction(cos_zenith, cos_gamma),
            sun_direction(cos_light));
    } else if (mode == 2.0) {
        // transmittance parametrization (from skyatmo / Bruneton)
        float h = sqrt(top_radius * top_radius - bottom_radius * bottom_radius);
        float rho = h * (1.0 - fragcoord.y);
        float view_height = sqrt(rho * rho + bottom_radius * bottom_radius);

        float d_min = top_radius - view_height;
        float d_max = rho + h;
        float d = d_min + fragcoord.x * (d_max - d_min);
        float cos_zenith
            = d == 0.0
            ? 1.0
            : clamp((h * h - rho * rho - d * d) / (2.0 * view_height * d), -1.0, 1.0);

        // this mode uses +z as the zenith axis
        vec3 start = vec3(0.0, 0.0, view_height);
        vec3 dir = vec3(0.0, sqrt(1.0 - cos_zenith * cos_zenith), cos_zenith);
        write_transmittance(start, dir);
    } else if (mode == 3.0) {
        // GetRMuMuSNuFromScatteringTextureUvwz from Bruneton
        vec2 inner, outer;
        tile_coords(inner, outer);
        vec4 uvwz = vec4(inner.x, inner.y, outer.x, outer.y);

        // distance from the ground to the top along the horizon
        float h = sqrt(top_radius * top_radius - bottom_radius * bottom_radius);
        float rho = h * uvwz.w;
        float r = sqrt(rho * rho + bottom_radius * bottom_radius);

        float mu;
        // NOTE: flipped compared to the original, this avoids a hard cutoff
        if (uvwz.z >= 0.5) {
            uvwz.z -= 0.5;
            // rays hitting the ground: distance to the ground
            float d_min = r - bottom_radius;
            float d_max = rho;
            float d = d_min + (d_max - d_min) * (1.0 - 2.0 * uvwz.z);
            mu = d == 0.0 ? -1.0 : clamp(-(rho * rho + d * d) / (2.0 * r * d), -1.0, 1.0);
        } else {
            uvwz.z += 0.5;
            // rays reaching the sky: distance to the top of the atmosphere
            float d_min = top_radius - r;
            float d_max = rho + h;
            float d = d_min + (d_max - d_min) * (2.0 * uvwz.z - 1.0);
            mu = d == 0.0 ? 1.0 : clamp((h * h - rho * rho - d * d) / (2.0 * r * d), -1.0, 1.0);
        }

        float x_mu_s = uvwz.y;
        float d_min = top_radius - bottom_radius;
        float d_max = h;

        // distance to the top atmosphere boundary for mu_s = mu_s_min
        float d_cap = -bottom_radius * mu_s_min
            + sqrt(bottom_radius * bottom_radius * (mu_s_min * mu_s_min - 1.0) + top_radius * top_radius);
        float a_cap = (d_cap - d_min) / (d_max - d_min);
        float a = (a_cap - x_mu_s * a_cap) / (1.0 + x_mu_s * a_cap);
        float d = d_min + min(a, a_cap) * (d_max - d_min);
        float mu_s = d == 0.0 ? 1.0 : clamp((h * h - d * d) / (2.0 * bottom_radius * d), -1.0, 1.0);

        // cosine between view and sun
        float nu = uvwz.x * 2.0 - 1.0;

        write_scattering(vec3(0.0, r, 0.0),
            view_direction(mu, azimuth_for_nu(mu, mu_s, nu)),
            sun_direction(mu_s));
    } else if (mode == 4.0) {
        // scattering, direct, with first-vertex Mie kept separate:
        // inner x = height, y = view zenith (untiled), outer x = sun zenith
        vec2 inner, outer;
        tile_coords(inner, outer);

        // squared so there is more coverage at lower altitudes
        float height = bottom_radius + inner.x * inner.x * (top_radius - bottom_radius);
        float cos_zenith = fragcoord.y * 2.0 - 1.0;
        float cos_light = outer.x * 1.2 - 0.2;

        // azimuth fixed at cos_gamma = 1 (view in the sun's vertical half-plane);
        // the first-vertex phase is isotropic in this mode
        vec3 dir = view_direction(cos_zenith, 1.0);
        vec3 light = sun_direction(cos_light);
        vec3 start = vec3(0.0, height, 0.0);

        fragcolor = vec4(0.0, 0.0, 0.0, 1.0);
        random_channel();
        vec2 int_result = integrate(start, dir, light, true);

        fragcolor[channel] = int_result.x * 3.0; // random channel
        // NOTE(review): unlike the colour channel, alpha is not scaled by 3 — confirm intended
        fragcolor[3] = int_result.y; // separate mie
    } else if (mode == 5.0) {
        // Bruneton's parametrization, but not skewed
        vec2 inner, outer;
        tile_coords(inner, outer);

        // squared so there is more coverage at lower altitudes
        float height = bottom_radius + inner.x * inner.x * (top_radius - bottom_radius);

        float mu = inner.y * 2.0 - 1.0;
        float mu_s = outer.x * 1.2 - 0.2;
        float nu = outer.y * 2.0 - 1.0;

        // keep nu in the range reachable for this mu / mu_s
        // see GetRMuMuSNuFromScatteringTextureFragCoord
        float spread = sqrt((1.0 - mu * mu) * (1.0 - mu_s * mu_s));
        nu = clamp(nu, mu * mu_s - spread, mu * mu_s + spread);

        write_scattering(vec3(0.0, height, 0.0),
            view_direction(mu, azimuth_for_nu(mu, mu_s, nu)),
            sun_direction(mu_s));
    } else {
        // unknown mode
        fragcolor = vec4(1.0, 0.0, 1.0, 1.0);
    }
}
