import time
import numpy as np
from scipy.stats.mstats import gmean
import torch
from torchdiffeq import odeint
import detest

# Make float64 the default floating-point dtype: main() requests solver
# tolerances down to 1e-12, and tensors such as torch.tensor(20.) are created
# without an explicit dtype.
# torch.set_default_tensor_type is deprecated; set_default_dtype is the
# supported way to select the default dtype.
torch.set_default_dtype(torch.float64)


class NFEDiffEq:
    """Callable wrapper that counts how often the wrapped function is called.

    Attributes:
        diffeq: The wrapped callable, invoked as ``diffeq(t, y)``.
        nfe: Number of calls made through this wrapper so far. It starts at
            zero and may be reset by assigning to it.
    """

    def __init__(self, diffeq):
        """Wrap ``diffeq`` and start the evaluation counter at zero."""
        self.nfe = 0
        self.diffeq = diffeq

    def __call__(self, t, y):
        """Record one evaluation and return ``diffeq(t, y)``."""
        # The counter is bumped before delegating, so a call that raises
        # is still counted.
        self.nfe = self.nfe + 1
        value = self.diffeq(t, y)
        return value


def main():
    """Benchmark the ``dopri5`` and ``adams`` solvers on the ``detest`` problems.

    For every (method, tolerance) pair, each problem ``A1`` ... ``E5`` exposed
    by the ``detest`` module is integrated from its initial time to ``t = 20``
    with ``atol = rtol = tol``.  The final state is compared against a
    reference solution that is computed once per problem with ``dopri5`` at
    ``atol = rtol = 1e-12`` and reused for all later runs.

    Prints, for each problem, the number of function evaluations, the elapsed
    time of the solve and the root-mean-square error of the final state,
    followed by the total evaluations, total time and geometric mean of the
    errors for that (method, tolerance) pair.  Returns nothing.
    """
    # Problem names in the order they are run: A1..A5, B1..B5, ..., E1..E5.
    problems = [c + i for c in 'ABCDE' for i in '12345']

    sol = dict()  # problem name -> reference state at the final time
    for method in ['dopri5', 'adams']:
        for tol in [1e-3, 1e-6, 1e-9]:
            print('======= {} | tol={:e} ======='.format(method, tol))
            nfes = []
            times = []
            errs = []
            for name in problems:
                diffeq, init, _ = getattr(detest, name)()
                t0, y0 = init()
                diffeq = NFEDiffEq(diffeq)
                # Same two-point time grid for the reference and the timed solve.
                t_span = torch.stack([t0, torch.tensor(20.)])

                if name not in sol:
                    sol[name] = odeint(
                        diffeq, y0, t_span, atol=1e-12, rtol=1e-12, method='dopri5'
                    )[1]
                    # The reference solve went through the counting wrapper
                    # as well; discard its evaluations so only the timed
                    # solve below is reported.
                    diffeq.nfe = 0

                # perf_counter is a monotonic, high-resolution clock, so the
                # measured interval is not affected by system clock updates.
                start_time = time.perf_counter()
                est = odeint(diffeq, y0, t_span, atol=tol, rtol=tol, method=method)
                time_spent = time.perf_counter() - start_time

                # Root-mean-square difference between reference and estimate
                # at the final time point.
                error = torch.sqrt(torch.mean((sol[name] - est[1])**2)).item()

                errs.append(error)
                nfes.append(diffeq.nfe)
                times.append(time_spent)

                print('{}: NFE {} | Time {} | Err {:e}'.format(name, diffeq.nfe, time_spent, error))

            print('Total NFE {} | Total Time {} | GeomAvg Error {:e}'.format(np.sum(nfes), np.sum(times), gmean(errs)))


# Run the benchmark only when this file is executed as a script, not when it
# is imported as a module.
if __name__ == '__main__':
    main()