from multiprocessing import Pool
import tomllib
import numpy as np
import os
import flip_evaluator as flip
from pathlib import Path
import csv
import matplotlib.pyplot as plt


def eval(x, dyn_range='HDR'):
    """Compare two images with FLIP and with the mean squared error.

    NOTE: the name shadows the builtin ``eval``; it is kept because the rest
    of the script refers to this function by that name.

    Args:
        x: pair of arrays ``(x[0], x[1])``; callers in this file pass the
            reference image first and the image under test second.
        dyn_range: dynamic range string forwarded to ``flip.evaluate``
            (``'HDR'`` by default, ``'LDR'`` via ``eval_ldr``).

    Returns:
        Tuple ``(flip_map, mse_map, flip_error, mse)`` where ``flip_map`` and
        ``flip_error`` are the first two values returned by
        ``flip.evaluate``, ``mse_map`` is the per-element squared difference
        after gamma and clipping to ``[0, 1]``, and ``mse`` is the mean
        squared difference as a Python float.
    """
    # flip
    flip_map, flip_error, _ = flip.evaluate(x[0],
                                            x[1],
                                            dyn_range,
                                            computeMeanError=True,
                                            inputsRGB=False)

    # mean squared error
    squared_diff = (x[0] - x[1])**2
    mse = float(np.mean(squared_diff))

    # apply gamma (np.power rather than np.pow: the latter only exists in
    # NumPy >= 2.0)
    mse_map = np.power(squared_diff, 1.0 / 2.2)

    # clip so the map can be written out as an image
    return flip_map, np.clip(mse_map, 0, 1), flip_error, mse


def eval_ldr(x):
    """Run ``eval`` on the image pair ``x`` with the dynamic range ``'LDR'``.

    Exists as a named module-level function so it can be handed to
    ``pool.map`` directly.
    """
    ldr_result = eval(x, 'LDR')
    return ldr_result


def save_imgs(x):
    """Write the FLIP map and the MSE map of one comparison as PNG files.

    ``x`` is a tuple ``(view, t, img, img2)``: the view name, the shader
    name, the image saved with the ``flip-`` prefix and the image saved with
    the ``mse-`` prefix. Both files go to ``../images/compare/diff/``.
    """
    view, t, img, img2 = x
    # both files share everything but the metric prefix
    tail = '-' + t + '-' + view + '.png'
    for metric, pixels in (('flip', img), ('mse', img2)):
        plt.imsave('../images/compare/diff/' + metric + tail, pixels)


# number of worker processes: one pool evaluates the metrics, the other
# writes the difference images
EVAL_WORKERS = 1
SAVE_WORKERS = 30


def _load_rgb(name, view):
    """Load ``../images/compare/<name>-<view>.npy``, keeping the first three channels."""
    # add 1e-9 else flip will just not work
    return np.load('../images/compare/' + name + '-' + view +
                   '.npy')[:, :, :3] + 1e-9


def _compare_views(views, names, reference, evaluate, csv_flip, csv_mse, pool,
                   pool2):
    """Compare every image in ``names`` against ``reference`` for each view.

    Args:
        views: iterable of view names.
        names: shader names whose images are compared; also the CSV columns.
        reference: name of the shader used as the reference image.
        evaluate: picklable callable taking ``(reference_img, img)`` and
            returning ``(flip_map, mse_map, flip_error, mse)``.
        csv_flip: csv writer receiving one row of FLIP errors per view.
        csv_mse: csv writer receiving one row of MSE values per view.
        pool: pool used to run ``evaluate``.
        pool2: pool used to run ``save_imgs``.
    """
    for view in views:
        # load all images and the reference
        imgs = [_load_rgb(t, view) for t in names]
        ref = _load_rgb(reference, view)

        # compare
        results = pool.map(evaluate, [(ref, y) for y in imgs])

        # write to csv
        csv_flip.writerow(
            [view] + ["{:.12f}".format(error) for _, _, error, _ in results])
        csv_mse.writerow(
            [view] + ["{:.12f}".format(mse) for _, _, _, mse in results])

        # and terminal
        print(view, 'done')

        # save images
        pool2.map(save_imgs,
                  [(view, t, img, img2)
                   for t, (img, img2, _, _) in zip(names, results)])


def main():
    """Compare all rendered views and write CSV tables plus difference images.

    Every shader listed in ``render/`` (except the path tracer and the
    transmittance variants) is compared against the path tracer output; the
    transmittance variants are compared against ``transmittance-naive``.
    Results are written to ``../images/compare/diff/``.
    """
    # all views to compare
    with open("views.toml", "rb") as f:
        views = tomllib.load(f)

    # os.listdir returns entries in arbitrary order and this order defines
    # the CSV columns, so sort it to make the output reproducible
    shader_files = sorted(os.listdir("render"))

    # list of all shaders besides the path tracer
    # also ignore transmittance, as we don't want to compare that one to the
    # path tracer output
    types = [
        x.removesuffix('.toml') for x in shader_files
        if x != "pathtrace.toml" and not x.startswith("transmittance-")
    ]

    # list of transmittance types
    trans_types = [
        x.removesuffix('.toml') for x in shader_files
        if x.startswith('transmittance-')
        and x != 'transmittance-naive.toml'  # don't include the actual one
    ]

    print(types, trans_types)

    # timings
    # NOTE(review): the parsed timings are not used anywhere below; the load
    # is kept so unreadable timing files still abort the run - confirm
    # whether it can be dropped.
    timings = []
    for x in os.listdir("../images/timings"):
        with open('../images/timings/' + x, "rb") as f:
            timings.append(tomllib.load(f))

    # csv files to output to
    Path('../images/compare/diff').mkdir(parents=True, exist_ok=True)

    # the context managers make sure the pools are shut down and the CSV
    # files are flushed and closed; newline='' is what the csv module expects
    with (
            Pool(EVAL_WORKERS) as pool,
            Pool(SAVE_WORKERS) as pool2,
            open('../images/compare/diff/flip.csv', 'w+',
                 newline='') as f_flip,
            open('../images/compare/diff/mse.csv', 'w+',
                 newline='') as f_mse,
            open('../images/compare/diff/flip-transmittance.csv', 'w+',
                 newline='') as f_flip_tr,
            open('../images/compare/diff/mse-transmittance.csv', 'w+',
                 newline='') as f_mse_tr,
    ):
        csv_flip = csv.writer(f_flip, delimiter=',')
        csv_mse = csv.writer(f_mse, delimiter=',')
        csv_flip.writerow(['view'] + types)
        csv_mse.writerow(['view'] + types)

        csv_flip_tr = csv.writer(f_flip_tr, delimiter=',')
        csv_mse_tr = csv.writer(f_mse_tr, delimiter=',')
        csv_flip_tr.writerow(['view'] + trans_types)
        csv_mse_tr.writerow(['view'] + trans_types)

        # for all views: compare against the path tracer
        _compare_views(views, types, 'pathtrace', eval, csv_flip, csv_mse,
                       pool, pool2)

        print('doing tansmittance')

        # compare transmittance
        # use the transmittance result instead of path tracer as reference
        _compare_views(views, trans_types, 'transmittance-naive', eval_ldr,
                       csv_flip_tr, csv_mse_tr, pool, pool2)


# The guard is required for multiprocessing: with the spawn/forkserver start
# methods the workers re-import this module and must not re-run the script.
if __name__ == '__main__':
    main()
