import argparse
import tomllib
import statistics as stats

# Command-line interface: which outputs to report, the timings file to read,
# and whether to print the per-shader results as CSV.
parser = argparse.ArgumentParser(
    description='Summarize shader timings from a TOML timings file.')
parser.add_argument('mode', help='which outputs to include: `all`, `out` '
                    '(only names containing `orbit` or `space`) or `in` '
                    '(only names containing neither)')
parser.add_argument('file', help='TOML file with the recorded timings')
parser.add_argument('-c', '--csv', action='store_true',
                    help='print the per-shader results as CSV')
args = parser.parse_args()

if args.mode not in ['all', 'out', 'in']:
    # SystemExit with a message writes it to stderr and exits with status 1,
    # so shells and calling scripts can detect the bad invocation. The
    # builtin itself is used instead of `exit()`, which only exists when the
    # `site` module is loaded.
    raise SystemExit('the first argument must be either `all`, `out` or `in`')

def skip(n):
    """Tell whether the output named `n` is excluded by the selected mode.

    - `all`: nothing is excluded (returns False).
    - `out`: keep only names mentioning `orbit` or `space`.
    - `in`: drop exactly those names.

    Any other mode yields None. The script validates the mode at startup,
    so that case is not expected.
    """
    mode = args.mode
    if mode == 'all':
        return False
    if mode not in ('out', 'in'):
        return None
    # Short-circuits on the first tag found, like the chained `or`/`and`.
    mentions_space = any(tag in n for tag in ('orbit', 'space'))
    return mentions_space if mode == 'in' else not mentions_space


# tomllib requires a binary file object; `with` closes the handle once the
# document has been parsed.
with open(args.file, 'rb') as timings_file:
    timings = tomllib.load(timings_file)

# Map each shader to the list of its per-output average timings.
# Names are split on '-': the first component is the shader, except that
# outputs starting with `transmittance` keep their second component too,
# so each transmittance variant is reported as its own entry.
shaders = {}

for name, times in timings['outputs'].items():
    if skip(name):
        continue

    parts = name.split('-')
    if name.startswith('transmittance'):
        shader = 'transmittance-' + parts[1]
    else:
        shader = parts[0]
    total = times['average']

    # skyatmo is special: also add the average of the pass with the same
    # name as this output.
    if shader == 'skyatmo':
        total += timings['passes'][name]['average']

    shaders.setdefault(shader, []).append(total)

print(timings['gpu']['name'], timings['gpu']['backend'],
      timings['gpu']['driver'])

print('times in miliseconds')

# Width of the right-aligned label column. The precompute labels share this
# column with the shader names, so they must count toward the width too.
# Including them also keeps `max` from failing on an empty sequence when the
# selected mode filtered out every output.
pad = max(map(len, [*shaders,
                    'bruneton precompute',
                    'skyatmo precompute'])) + 1

print(' ' * pad + '  ---')

# bruneton has a precomputation step: sum the `total` of every pass whose
# name starts with `bruneton`.
bruneton_precomp = sum(
    v['total'] for k, v in timings['passes'].items()
    if k.startswith('bruneton')
)

# skyatmo too: its transmittance and multi-scatter passes.
skyatmo_precomp = sum(
    v['total'] for k, v in timings['passes'].items()
    if k in ('skyatmo-transmittance', 'skyatmo-multi-scatter')
)

# Values are scaled by 1e-6 to match the milliseconds header (the raw unit
# is presumably nanoseconds -- TODO confirm against the timings producer).
print('bruneton precompute'.rjust(pad) + f": {bruneton_precomp * 1e-6:0.6}")
print('skyatmo precompute'.rjust(pad) + f": {skyatmo_precomp * 1e-6:0.6}")

print(' ' * pad + '  ---')

def _summarize(samples):
    """Return (mean, best, worst, stdev) of `samples`, each scaled by 1e-6.

    The 1e-6 factor converts the raw values to the milliseconds announced
    in the header. `statistics.stdev` needs at least two data points, so a
    shader measured only once reports NaN for its standard deviation
    instead of aborting the whole report.
    """
    scaled = [s * 1e-6 for s in samples]
    mean = stats.mean(samples) * 1e-6
    best = min(samples) * 1e-6
    worst = max(samples) * 1e-6
    stdev = stats.stdev(scaled) if len(scaled) > 1 else float('nan')
    return mean, best, worst, stdev


# sort by average, fastest first
shaders = sorted(shaders.items(), key=lambda x: stats.mean(x[1]))

if not args.csv:
    # print results, right-aligned under the precompute section
    for shader, times in shaders:
        mean, best, worst, stdev = _summarize(times)
        print(
            shader.rjust(pad) + f": {mean:0.6}" +
            f" (best: {best:0.6}," +
            f" worst: {worst:0.6}," +
            f" stdev: {stdev:0.6})")
else:
    print('model,average,best,worst,stdev')
    for shader, times in shaders:
        mean, best, worst, stdev = _summarize(times)
        print(f"{shader},{mean},{best},{worst},{stdev:0.6}")
