#version 460

// Atmosphere parameters. Not every field is read by this shader; only some of
// them feed the sky model below (see the sky_luminance() call in main()).
// NOTE(review): member order and types define the block's memory layout and
// presumably must match a host-side struct — do not reorder without updating
// the host side as well.
layout(set = 0, binding = 15) uniform atmosphere {    
    vec3 solar_irradiance;   // passed to sky_luminance() as the sun intensity (iSun)
    float sun_angular_radius; // not read by this shader
    float bottom_radius;      // planet radius; main() also lifts the camera by it along +y
    float top_radius;         // outer radius of the atmosphere shell (rAtmos)
    
    // Rayleigh density profile, split into two layers (_a and _b), each with a
    // width, exponential term/scale, linear term and constant term.
    // Only ray_exp_scale_b is read here (as the exponent scale of the density).
    float ray_width_a;
    float ray_exp_term_a;
    float ray_exp_scale_a;
    float ray_linear_term_a;
    float ray_constant_term_a;
    
    float ray_width_b;
    float ray_exp_term_b;
    float ray_exp_scale_b;
    float ray_linear_term_b;
    float ray_constant_term_b;
        
    vec3 ray_scattering;     // Rayleigh scattering coefficients (also used as Rayleigh extinction)
    
    // Mie density profile, same two-layer layout as Rayleigh.
    // Only mie_exp_scale_b is read here.
    float mie_width_a;
    float mie_exp_term_a;
    float mie_exp_scale_a;
    float mie_linear_term_a;
    float mie_constant_term_a;
        
    float mie_width_b;
    float mie_exp_term_b;
    float mie_exp_scale_b;
    float mie_linear_term_b;
    float mie_constant_term_b;
    
    vec3 mie_scattering;     // Mie scattering coefficients
    vec3 mie_extinction;     // Mie extinction coefficients (scattering + absorption)
    float mie_g;             // Mie phase-function asymmetry factor
    
    // Absorption (ozone) density profile, same two-layer layout.
    // main() forwards some of these fields, but sky_luminance() does not use them yet.
    float absorb_width_a;
    float absorb_exp_term_a;
    float absorb_exp_scale_a;
    float absorb_linear_term_a;
    float absorb_constant_term_a; 
    
    float absorb_width_b;
    float absorb_exp_term_b;
    float absorb_exp_scale_b;
    float absorb_linear_term_b;
    float absorb_constant_term_b;
        
    vec3 absorb_extinction;  // ozone extinction coefficients (forwarded, currently unused)
    vec3 ground_albedo;      // not read by this shader
};

// Per-view parameters.
layout(set = 0, binding = 12) uniform view {
    vec3 cam_pos; // camera position; main() adds bottom_radius along +y, so cam_pos.y acts as altitude above the surface
    vec3 sun_dir; // sun direction (re-normalized inside sky_luminance)
    float exposure; // scalar multiplier applied to the final luminance
};

// Push constants. width/height are only used for the aspect ratio in main().
layout(push_constant) uniform push {
    uint layer;  // not read by this shader
    uint width;  // presumably render-target width in pixels — confirm with host code
    uint height; // presumably render-target height in pixels
};

// Interpolated screen coordinate. main() treats (0.5, 0.5) as the view centre,
// so this is presumably in [0, 1] — confirm against the vertex shader.
layout(location = 0) in noperspective centroid vec2 fragcoord;
// Output: exposed sky luminance with alpha fixed at 1.0.
layout(location = 0) out vec4 fragcolor;

// Adapted from https://github.com/wwwtyro/glsl-atmosphere/blob/master/index.glsl
// Unlicense
// Modified to replace the inner (light-ray) loop with the Chapman function approximation

// <START>
// Pi to full single precision (the former 3.141592 literal was truncated and
// coarser than float resolution).
#define PI 3.14159265359
// Number of primary-ray samples taken by sky_luminance(). Guarded so it can be
// overridden at compile time (e.g. -DiSteps=32) without a redefinition error.
#ifndef iSteps
#define iSteps 24
#endif

// Approximate Chapman grazing-incidence function, multiplied by exp(-h)
// (approximation described in GPU Pro 3). The natural exp() is used so the
// result needs no later rescaling.
//   X      - planet radius expressed in scale heights
//   h      - altitude above the planet surface, in scale heights
//   coschi - cosine of the angle between the local zenith and the ray
// For rays pointing below the horizon (coschi < 0) the value is built from the
// horizontal case via the standard reflection identity.
float chapman(float X, float h, float coschi) {
    float c = sqrt(X + h);
    float falloff = exp(-h);

    // Ray at or above the local horizon: direct approximation.
    if (coschi >= 0.0) {
        return c / (c * coschi + 1.0) * falloff;
    }

    // Ray below the horizon: twice the horizontal path from the tangent point,
    // minus the path in the mirrored (upward) direction.
    float sinchi = sqrt(1.0 - coschi * coschi);
    float x0 = sinchi * (X + h);
    float grazing = 2.0 * sqrt(x0) * exp(X - x0);
    return grazing - c / (1.0 - c * coschi) * falloff;
}

// Optical depth from pos towards dir through an exponentially decaying layer
// with the given scale height, using the global bottom_radius as the planet
// radius. The result is scale_height * chapman(...), i.e. density integrated
// over path length (presumably relative to unit density at the surface).
float optical_depth(vec3 pos, vec3 dir, float scale_height) {
    // chapman() expects radius and altitude measured in scale heights.
    float radius_sh = bottom_radius / scale_height;
    float altitude_sh = (length(pos) - bottom_radius) / scale_height;
    float cos_zenith = dot(normalize(pos), dir);

    float ch = chapman(radius_sh, altitude_sh, cos_zenith);
    return scale_height * ch;
}

// Ray / origin-centred sphere intersection.
// adapted from: https://iquilezles.org/articles/spherefunctions/
// Returns (near, far) distances along dir, or (-1, -1) when the line misses the
// sphere of radius sr. Either distance may be negative (intersection behind the
// origin). dir is assumed to be unit length (the quadratic's a = 1).
vec2 rsi(vec3 ray, vec3 dir, float sr) {
    float half_b = dot(ray, dir);
    float disc = half_b * half_b - (dot(ray, ray) - sr * sr);

    // No real roots: the line never touches the sphere.
    if (disc < 0.0) return vec2(-1.0);

    float root = sqrt(disc);
    return vec2(-half_b - root, -half_b + root);
}

// Single-scattering sky luminance seen along a view ray, integrated with
// iSteps midpoint samples between the atmosphere entry point (or the camera,
// if already inside) and either the atmosphere exit or the planet surface.
// Sun-ray optical depth uses the Chapman approximation instead of a loop.
//
//   r            view direction (normalized here)
//   r0           ray origin, relative to the planet centre
//   pSun         direction towards the sun (normalized here)
//   iSun         sun irradiance; scales the result
//   rPlanet      planet radius
//   rAtmos       outer radius of the atmosphere
//   kRlh         Rayleigh scattering coefficients (also used as Rayleigh extinction)
//   kMie         Mie scattering coefficients
//   aMie         Mie extinction coefficients
//   shRlh/shMie  density exponent scales: density = exp(height * sh). Expected
//                to be negative, since the scale height is taken as -1 / sh.
//   aOzo, whOzo, ltOzo, ctOzo
//                ozone absorption parameters — accepted but not used yet
//   g            Mie phase-function asymmetry factor
// Returns the in-scattered luminance (Rayleigh + Mie), before exposure.
vec3 sky_luminance(
    vec3 r,
    vec3 r0,
    vec3 pSun,
    vec3 iSun,
    float rPlanet,
    float rAtmos,
    vec3 kRlh,
    vec3 kMie,
    vec3 aMie,
    vec3 aOzo,
    float shRlh,
    float shMie,
    float whOzo,
    vec2 ltOzo,
    vec2 ctOzo,
    float g
) {
    // Normalize the sun and view directions.
    pSun = normalize(pSun);
    r = normalize(r);

    // Intersections of the view ray with the planet and the atmosphere shell.
    vec2 iPlanet = rsi(r0, r, rPlanet);
    vec2 iAtmos = rsi(r0, r, rAtmos);

    // Atmosphere missed (or entirely behind the origin): nothing scatters.
    if (iAtmos.y <= 0.0) return vec3(0.0);

    // Integration interval: x = start, y = end, both clamped to >= 0 so the
    // march starts at the camera when it is already inside the atmosphere.
    vec2 p = max(vec2(
        iAtmos.x,
        iPlanet.y <= 0.0
            ? iAtmos.y   // planet missed: end at the atmosphere boundary
            : iPlanet.x  // planet hit: end at its surface
    ), 0.0);

    // Primary-ray step size and current ray parameter.
    float iStepSize = (p.y - p.x) / float(iSteps);
    float iTime = p.x;

    // Scale heights, and the planet radius in scale-height units, for chapman().
    // Loop-invariant, so computed once.
    float hRlh = -1.0 / shRlh;
    float hMie = -1.0 / shMie;
    float xRlh = rPlanet / hRlh;
    float xMie = rPlanet / hMie;

    // Scattering accumulators.
    vec3 totalRlh = vec3(0.0);
    vec3 totalMie = vec3(0.0);

    // Optical depth accumulated along the primary ray so far.
    float iOdRlh = 0.0;
    float iOdMie = 0.0;

    // Rayleigh phase and Cornette-Shanks Mie phase.
    float mu = dot(r, pSun);
    float mumu = mu * mu;
    float gg = g * g;
    float pRlh = (3.0 / (16.0 * PI)) * (1.0 + mumu);
    float pMie = (3.0 / (8.0 * PI)) * ((1.0 - gg) * (mumu + 1.0)) / (pow(1.0 + gg - 2.0 * mu * g, 1.5) * (2.0 + gg));

    for (int i = 0; i < iSteps; i++) {
        // Midpoint of the current primary-ray segment, and its altitude.
        vec3 iPos = r0 + r * (iTime + iStepSize * 0.5);
        float iHeight = length(iPos) - rPlanet;

        // Optical depth of this segment; it is included in the running total
        // before attenuation is evaluated.
        float odStepRlh = exp(iHeight * shRlh) * iStepSize;
        float odStepMie = exp(iHeight * shMie) * iStepSize;
        iOdRlh += odStepRlh;
        iOdMie += odStepMie;

        // Only samples whose sun ray misses the planet contribute. The sun-ray
        // optical depth is evaluated only for those: for deeply shadowed
        // samples the below-horizon branch of chapman() can overflow
        // (exp(X - x0)) to inf/NaN, and the result would be discarded anyway.
        if (rsi(iPos, pSun, rPlanet).y < 0.0) {
            float cosSun = dot(normalize(iPos), pSun);
            float jOdRlh = hRlh * chapman(xRlh, iHeight / hRlh, cosSun);
            float jOdMie = hMie * chapman(xMie, iHeight / hMie, cosSun);

            // Transmittance along camera->sample and sample->sun.
            vec3 attn = exp(
                - aMie * (iOdMie + jOdMie)
                - kRlh * (iOdRlh + jOdRlh)
            );

            totalRlh += odStepRlh * attn;
            totalMie += odStepMie * attn;
        }

        // Advance to the next segment.
        iTime += iStepSize;
    }

    return iSun * (pRlh * kRlh * totalRlh + pMie * kMie * totalMie);
}

// Distance along dir from ray to the near intersection with the planet sphere
// (radius bottom_radius), or -1.0 when the line misses the planet. The value
// can also be negative when the planet lies behind the ray origin; dir is
// assumed to be unit length.
// adapted from: https://iquilezles.org/articles/spherefunctions/
float planet_start(vec3 ray, vec3 dir) {
    // The near root of the ray/sphere quadratic, as computed by rsi().
    return rsi(ray, dir, bottom_radius).x;
}

// Fragment entry point: builds a pinhole-camera view ray for this pixel,
// evaluates the sky model and writes the exposed luminance.
void main() {
    // View-ray origin relative to the planet centre: cam_pos is lifted by the
    // planet radius along +y, so cam_pos.y reads as altitude above the ground.
    vec3 origin = cam_pos + vec3(0.0, bottom_radius, 0.0);

    // Camera looks down +z; the horizontal axis is stretched by the aspect ratio.
    float aspect = float(width) / float(height);
    vec2 screen = (fragcoord - 0.5) * vec2(aspect, 1.0);
    vec3 view_dir = normalize(vec3(screen, 1.0));

    vec3 sky = sky_luminance(
        view_dir, origin,
        sun_dir, solar_irradiance,
        bottom_radius, top_radius,
        ray_scattering, mie_scattering, mie_extinction, absorb_extinction,
        ray_exp_scale_b, mie_exp_scale_b,
        absorb_width_a,
        vec2(absorb_linear_term_a, absorb_linear_term_b),
        vec2(absorb_constant_term_a, absorb_constant_term_b),
        mie_g
    );

    fragcolor = vec4(sky * exposure, 1.0);
}
// <END>
