#version 460
#extension GL_ARB_shading_language_include : require
#extension GL_EXT_samplerless_texture_functions : require

// Adapted from atmosphere/model.cc in Bruneton's precomputed atmospheric scattering.
// These are the portability shims that model.cc prepends to its GLSL sources so
// the shared C++/GLSL code in the included definitions.glsl / functions.glsl
// compiles as GLSL: parameter-qualifier macros, and template markers that
// expand to nothing.
#define IN(x) const in x
#define OUT(x) out x
#define TEMPLATE(x)
#define TEMPLATE_ARGUMENT(x)
// GLSL has no assert(), so any assert in the included code compiles away.
#define assert(x)

// Look-up-table sizes, copied from atmosphere/constants.h.
// NOTE(review): presumably read by the included definitions.glsl / functions.glsl
// to map between texel coordinates and atmosphere parameters, so they must match
// the sizes of the textures the host actually allocates — confirm on the host side.
const int TRANSMITTANCE_TEXTURE_WIDTH = 256;
const int TRANSMITTANCE_TEXTURE_HEIGHT = 64;

// Per-dimension sizes of the scattering table (r, mu, mu_s, nu).
const int SCATTERING_TEXTURE_R_SIZE = 32;
const int SCATTERING_TEXTURE_MU_SIZE = 128;
const int SCATTERING_TEXTURE_MU_S_SIZE = 32;
const int SCATTERING_TEXTURE_NU_SIZE = 8;

// Ground irradiance table size.
const int IRRADIANCE_TEXTURE_WIDTH = 64;
const int IRRADIANCE_TEXTURE_HEIGHT = 16;

// Separate sampler object paired with every texture lookup in this shader.
layout(set = 0, binding = 13) uniform sampler smooth_sampler;

// Bruneton's code is written against combined image samplers (sampler2D /
// sampler3D), but this pipeline binds plain textures plus one separate sampler
// (the same workaround the skyatmosphere shader uses). The overloads below build
// the combined sampler at each call. The macros that follow rewrite the included
// code so that its sampler2D/sampler3D parameter types become texture2D/texture3D
// and its texture() calls land here.
// These helpers must stay ABOVE the macros, otherwise their own texture() and
// sampler2D()/sampler3D() tokens would be rewritten too.
vec4 sample_tex(texture2D image, vec2 coord) {
    return texture(sampler2D(image, smooth_sampler), coord);
}

vec4 sample_tex(texture3D volume, vec3 coord) {
    return texture(sampler3D(volume, smooth_sampler), coord);
}

#define texture sample_tex

#define sampler2D texture2D
#define sampler3D texture3D

#include "definitions.glsl"
#include "functions.glsl"

// 2D look-up textures.
layout(set = 0, binding = 0) uniform texture2D transmittance_tex;
// Passed as the irradiance input of scatter_density().
// NOTE(review): model.cc feeds the *delta* irradiance of the previous order into
// that input — confirm the host binds the delta texture here for that pass.
layout(set = 0, binding = 1) uniform texture2D irradiance_tex;

// 3D look-up textures. Passes that render into a 3D texture select the slice
// with the `layer` push constant.
// NOTE(review): single_ray_scatter_tex is also read by scatter_step() (as the
// running sum it adds to) and by GetSkyRadiance() (as the scattering texture),
// so the image bound at binding 8 presumably differs per pass — confirm.
layout(set = 0, binding = 8)  uniform texture3D single_ray_scatter_tex;
layout(set = 0, binding = 9)  uniform texture3D single_mie_scatter_tex;
layout(set = 0, binding = 10) uniform texture3D scatter_density_tex;
layout(set = 0, binding = 11) uniform texture3D multi_scatter_tex;

// Per-draw parameters; only read by the RENDER pass (render()).
layout(set = 0, binding = 12) uniform view {
    vec3 cam_pos; // camera position as an offset from the sphere's top point; render() adds (0, bottom_radius, 0)
    vec3 sun_dir; // sun direction; render() normalizes it before use, so it need not be unit length
    float exposure; // linear multiplier applied to the sky radiance in render()
};

// Scattering order for the scatter_density() and indirect_irradiance() passes
// (both convert it with int(), which truncates).
// NOTE(review): this block shares set 0 / binding 12 with `view` above. That only
// works because no pass selected in main() reads both blocks, and the host must
// bind a buffer with the matching layout for each pass. Consider giving this block
// its own binding.
layout(set = 0, binding = 12) uniform multiscatter {
    float order; // scattering order
};

// Atmosphere parameters shared by every pass: the LUT precomputation needs them
// as much as render() does. atmosphere() copies them into Bruneton's
// AtmosphereParameters struct.
// Each density profile has two layers: the _a fields fill layers[0] and the _b
// fields fill layers[1], one field per DensityProfileLayer member.
// NOTE(review): no explicit layout qualifier, so the Vulkan default (std140) is
// presumably in effect. The host-side struct must then honor 16-byte vec3
// alignment, with a following float packed into the vec3's 4th slot — confirm.
layout(set = 0, binding = 15) uniform globals {
    vec3 solar_irradiance;
    float sun_angular_radius;
    float bottom_radius;
    float top_radius;

    // rayleigh density
    float ray_width_a;
    float ray_exp_term_a;
    float ray_exp_scale_a;
    float ray_linear_term_a;
    float ray_constant_term_a;

    float ray_width_b;
    float ray_exp_term_b;
    float ray_exp_scale_b;
    float ray_linear_term_b;
    float ray_constant_term_b;

    vec3 ray_scattering;

    // mie density
    float mie_width_a;
    float mie_exp_term_a;
    float mie_exp_scale_a;
    float mie_linear_term_a;
    float mie_constant_term_a;

    float mie_width_b;
    float mie_exp_term_b;
    float mie_exp_scale_b;
    float mie_linear_term_b;
    float mie_constant_term_b;

    vec3 mie_scattering;
    vec3 mie_extinction;
    float mie_g; // Mie phase function asymmetry parameter (mie_phase_function_g)

    // absorption (ozone) density
    float absorb_width_a;
    float absorb_exp_term_a;
    float absorb_exp_scale_a;
    float absorb_linear_term_a;
    float absorb_constant_term_a;

    float absorb_width_b;
    float absorb_exp_term_b;
    float absorb_exp_scale_b;
    float absorb_linear_term_b;
    float absorb_constant_term_b;

    vec3 absorb_extinction;
    vec3 ground_albedo;
};

// Stage interface. fragcoord is only read by render(); the precomputation passes
// address their texels through gl_FragCoord instead.
layout(location = 0) in noperspective centroid vec2 fragcoord;
// Single color output shared by every pass; each pass packs its result into it.
layout(location = 0) out vec4 fragcolor;

// Push constants.
layout(push_constant) uniform push {
    uint layer; // slice of the 3D texture being rendered; passes use float(layer) + 0.5 as the z texel coordinate
    uint width; // width of the current render target; render() uses width / height as the aspect ratio
    uint height; // height of the current render target
};

// Builds one DensityProfileLayer field by field. Named members are assigned
// explicitly, so the helper does not depend on the struct's declaration order.
DensityProfileLayer atmosphere_density_layer(
    float layer_width,
    float layer_exp_term,
    float layer_exp_scale,
    float layer_linear_term,
    float layer_constant_term
) {
    DensityProfileLayer result;
    result.width = layer_width;
    result.exp_term = layer_exp_term;
    result.exp_scale = layer_exp_scale;
    result.linear_term = layer_linear_term;
    result.constant_term = layer_constant_term;
    return result;
}

// Assembles Bruneton's AtmosphereParameters from the `globals` uniform block.
// The _a / _b uniform fields become layers[0] / layers[1] of each density profile.
AtmosphereParameters atmosphere() {
    AtmosphereParameters params;

    // Sun and planet geometry.
    params.solar_irradiance = solar_irradiance;
    params.sun_angular_radius = sun_angular_radius;
    params.bottom_radius = bottom_radius;
    params.top_radius = top_radius;

    // Rayleigh: two-layer density profile plus scattering coefficient.
    params.rayleigh_density.layers[0] = atmosphere_density_layer(
        ray_width_a, ray_exp_term_a, ray_exp_scale_a, ray_linear_term_a, ray_constant_term_a);
    params.rayleigh_density.layers[1] = atmosphere_density_layer(
        ray_width_b, ray_exp_term_b, ray_exp_scale_b, ray_linear_term_b, ray_constant_term_b);
    params.rayleigh_scattering = ray_scattering;

    // Mie: two-layer density profile, coefficients and phase asymmetry.
    params.mie_density.layers[0] = atmosphere_density_layer(
        mie_width_a, mie_exp_term_a, mie_exp_scale_a, mie_linear_term_a, mie_constant_term_a);
    params.mie_density.layers[1] = atmosphere_density_layer(
        mie_width_b, mie_exp_term_b, mie_exp_scale_b, mie_linear_term_b, mie_constant_term_b);
    params.mie_scattering = mie_scattering;
    params.mie_extinction = mie_extinction;
    params.mie_phase_function_g = mie_g;

    // Absorption (ozone): two-layer density profile plus extinction.
    params.absorption_density.layers[0] = atmosphere_density_layer(
        absorb_width_a, absorb_exp_term_a, absorb_exp_scale_a, absorb_linear_term_a, absorb_constant_term_a);
    params.absorption_density.layers[1] = atmosphere_density_layer(
        absorb_width_b, absorb_exp_term_b, absorb_exp_scale_b, absorb_linear_term_b, absorb_constant_term_b);
    params.absorption_extinction = absorb_extinction;

    params.ground_albedo = ground_albedo;

    // Cosine of the largest sun zenith angle the LUTs are precomputed for.
    // -0.2 (about 101.5 degrees) is the value Bruneton recommends for Earth.
    params.mu_s_min = -0.2;

    return params;
}

// <START>
// Precomputation passes, ported from the shaders embedded in model.cc.
// Pass 1: transmittance to the top atmosphere boundary, one texel per fragment.
void transmittance() {
    vec3 to_top = ComputeTransmittanceToTopAtmosphereBoundaryTexture(atmosphere(), gl_FragCoord.xy);
    fragcolor = vec4(to_top, 0.0);
}

// Pass 2: direct ground irradiance (Bruneton's delta_irradiance), computed from
// the transmittance LUT.
void direct_irradiance() {
    vec3 delta_irradiance = ComputeDirectIrradianceTexture(atmosphere(), transmittance_tex, gl_FragCoord.xy);
    fragcolor = vec4(delta_irradiance, 0.0);
}

// Pass 3a: single Rayleigh scattering (delta_rayleigh) for the 3D-texture slice
// selected by `layer`. The Mie term is produced by the same call but discarded.
void single_scatter_ray() {
    vec3 delta_ray;
    vec3 delta_mie; // required out-argument, unused by this pass
    ComputeSingleScatteringTexture(
        atmosphere(),
        transmittance_tex,
        vec3(gl_FragCoord.xy, float(layer) + 0.5),
        delta_ray,
        delta_mie
    );
    fragcolor = vec4(delta_ray, 0.0);
}

// Pass 3b: single Mie scattering (delta_mie) for the 3D-texture slice selected
// by `layer`. The Rayleigh term is produced by the same call but discarded.
void single_scatter_mie() {
    vec3 delta_ray; // required out-argument, unused by this pass
    vec3 delta_mie;
    ComputeSingleScatteringTexture(
        atmosphere(),
        transmittance_tex,
        vec3(gl_FragCoord.xy, float(layer) + 0.5),
        delta_ray,
        delta_mie
    );
    fragcolor = vec4(delta_mie, 0.0);
}

// Pass 3c: single scattering packed into one RGBA target, like model.cc's
// `scattering` output (without its luminance conversion): rgb holds Rayleigh and
// alpha holds only the red channel of Mie.
void scatter() {
    vec3 delta_ray;
    vec3 delta_mie;
    ComputeSingleScatteringTexture(
        atmosphere(),
        transmittance_tex,
        vec3(gl_FragCoord.xy, float(layer) + 0.5),
        delta_ray,
        delta_mie
    );
    fragcolor = vec4(delta_ray.rgb, delta_mie.r);
}

// Passes 4-6 are repeated once per multiple-scattering order (2, 3 and 4).
// Pass 4: scattering density (delta_scattering_density) for the 3D-texture slice
// selected by `layer`, at the scattering order held in the `order` uniform.
void scatter_density() {
    int scattering_order = int(order);
    vec3 texel = vec3(gl_FragCoord.xy, float(layer) + 0.5);
    vec3 density = ComputeScatteringDensityTexture(
        atmosphere(),
        transmittance_tex,
        single_ray_scatter_tex,
        single_mie_scatter_tex,
        multi_scatter_tex,
        irradiance_tex,
        texel,
        scattering_order
    );
    fragcolor = vec4(density, 0.0);
}

// Pass 5 (per order): indirect ground irradiance (delta_indirect_irradiance) from
// the scattering textures.
// NOTE(review): model.cc binds `scattering_order - 1` for this pass but
// `scattering_order` for the density pass. Both read the same `order` uniform here,
// so confirm the host uploads the right value before each pass.
void indirect_irradiance() {
    int scattering_order = int(order);
    vec3 delta_indirect = ComputeIndirectIrradianceTexture(
        atmosphere(),
        single_ray_scatter_tex,
        single_mie_scatter_tex,
        multi_scatter_tex,
        gl_FragCoord.xy,
        scattering_order
    );
    fragcolor = vec4(delta_indirect, 0.0);
}

// Pass 6 (per order): multiple scattering of the current order
// (delta_multiple_scattering) for the 3D-texture slice selected by `layer`.
void multiple_scattering() {
    float nu; // written by the helper; not needed by this pass
    vec3 delta_multiple = ComputeMultipleScatteringTexture(
        atmosphere(),
        transmittance_tex,
        scatter_density_tex,
        vec3(gl_FragCoord.xy, float(layer) + 0.5),
        nu
    );
    fragcolor = vec4(delta_multiple, 0.0);
}

// Pass 6, accumulating variant: adds this order's multiple scattering to the
// running scattering sum without relying on blending. As in model.cc, the
// Rayleigh phase function is divided out before accumulating. The result is added
// to the texel of single_ray_scatter_tex at the same position.
// NOTE(review): this presumably requires the host to bind the running sum as
// single_ray_scatter_tex and to render into a different image (reading and writing
// the same image would be a feedback loop) — confirm.
void scatter_step() {
    float nu;
    vec3 delta = ComputeMultipleScatteringTexture(
        atmosphere(),
        transmittance_tex,
        scatter_density_tex,
        vec3(gl_FragCoord.xy, float(layer) + 0.5),
        nu
    );
    // gl_FragCoord.xy sits on texel centers, so truncating to ivec3 yields the texel itself.
    vec4 previous = texelFetch(single_ray_scatter_tex, ivec3(gl_FragCoord.xy, layer), 0);
    fragcolor = previous + vec4(delta / RayleighPhaseFunction(nu), 0.0);
}
// end repeat

// Convenience overload used by render(). It forwards to Bruneton's GetSkyRadiance
// with this shader's bound LUTs and a shadow length of zero.
// See https://ebruneton.github.io/precomputed_atmospheric_scattering/atmosphere/model.cc.html
// NOTE(review): single_ray_scatter_tex is passed as the scattering texture, where
// Bruneton expects the accumulated all-orders scattering. Confirm the host binds
// that texture at binding 8 for the render pass.
vec3 GetSkyRadiance(vec3 camera, vec3 view_ray, vec3 sun_direction, out vec3 transmittance) {
    const float no_shadow_length = 0.0;
    return GetSkyRadiance(
        atmosphere(),
        transmittance_tex,
        single_ray_scatter_tex,
        single_mie_scatter_tex,
        camera,
        view_ray,
        no_shadow_length,
        sun_direction,
        transmittance
    );
}

// RENDER pass: draws the sky seen from cam_pos through a simple pinhole camera
// looking along +z, with y as the screen's vertical axis.
void render() {
    // cam_pos is relative to the sphere's top point, so lift it by bottom_radius on y.
    vec3 eye = cam_pos + vec3(0.0, bottom_radius, 0.0);

    // Center the screen coordinate and stretch x by the aspect ratio
    // (assumes fragcoord spans [0, 1] across the target — TODO confirm).
    float aspect = float(width) / float(height);
    vec2 screen = (fragcoord - 0.5) * vec2(aspect, 1.0);
    vec3 view_dir = normalize(vec3(screen, 1.0));

    // Ground-reflected radiance and the sun disk are intentionally not drawn.
    vec3 sky_transmittance;
    vec3 sky = GetSkyRadiance(eye, view_dir, normalize(sun_dir), sky_transmittance);

    fragcolor = vec4(sky * exposure, 1.0);
}
// <END>

// Entry point. The pass to run is selected at compile time: each shader variant
// is presumably built with exactly one of the defines below. If several are
// defined, every selected pass runs and the last one's fragcolor wins.
void main() {
    // Magenta sentinel, so a variant compiled without any pass define is easy to spot.
    fragcolor = vec4(1.0, 0.0, 1.0, 1.0);
#ifdef TRANSMITTANCE
    transmittance();
#endif
#ifdef DIRECT_IRRADIANCE
    direct_irradiance();
#endif
#ifdef SINGLE_SCATTER_RAY
    single_scatter_ray();
#endif
#ifdef SINGLE_SCATTER_MIE
    single_scatter_mie();
#endif
#ifdef SCATTER
    scatter();
#endif
#ifdef SCATTER_DENSITY
    scatter_density();
#endif
#ifdef INDIRECT_IRRADIANCE
    indirect_irradiance();
#endif
#ifdef MULTIPLE_SCATTERING
    multiple_scattering();
#endif
#ifdef SCATTER_STEP
    scatter_step();
#endif
#ifdef RENDER
    render();
#endif
}
