import queue
import numpy as np
from numpy.random import default_rng
import time
from copy import deepcopy
from queue import LifoQueue
import matplotlib.pyplot as plt
import multiprocessing as mp

class MCMC:
    """Wolff-cluster Monte Carlo sampler for the 2D Ising model.

    The lattice is an ``L x L`` square grid with periodic boundary conditions
    and spins in {-1, +1}.  ``L``, ``J`` and ``lattice`` start as ``None`` and
    are populated by :meth:`MCMC`.  All randomness is drawn from the
    per-instance generator ``self.rng``.
    """

    def __init__(self) -> None:
        self.rng = default_rng()
        self.L = None  # linear lattice size, set by MCMC()
        self.J = None  # coupling constant, set by MCMC()
        self.lattice = None  # (L, L) array of +/-1 spins, set by MCMC()

    def init_lattice(self, L, init_type):
        """Return a new ``(L, L)`` spin configuration.

        Args:
            L: Linear lattice size.  If ``None``, ``self.L`` is used.
            init_type: ``'random'`` draws every spin uniformly from {-1, +1};
                ``'up'`` / ``'down'`` return a fully ordered lattice of
                +1 / -1 spins.

        Raises:
            ValueError: If ``init_type`` is not one of the supported values.
        """
        size = self.L if L is None else L
        if init_type == 'random':
            return self.rng.choice([-1, 1], size=(size, size))
        if init_type == 'up':
            return np.ones((size, size), dtype=int)
        if init_type == 'down':
            return -np.ones((size, size), dtype=int)
        raise ValueError(
            f"unknown init_type {init_type!r}; expected 'random', 'up' or 'down'"
        )

    def get_energy(self):
        """Return the Ising energy ``-J * sum_<ij> s_i s_j`` of the lattice.

        Every nearest-neighbour bond is counted exactly once by pairing each
        site with its periodic neighbour along axis 0 and along axis 1.
        """
        neighbours_axis0 = np.roll(self.lattice, 1, axis=0)
        neighbours_axis1 = np.roll(self.lattice, 1, axis=1)
        energy = 0
        energy -= self.J * (self.lattice * neighbours_axis0).sum()
        energy -= self.J * (self.lattice * neighbours_axis1).sum()
        return energy

    def get_random_pos(self):
        """Return a uniformly random site ``(i, j)`` with ``0 <= i, j < L``.

        Uses ``self.rng`` rather than the global ``np.random`` state. Forked
        worker processes inherit identical global state, so they would
        otherwise pick correlated seed sites.
        """
        i, j = self.rng.integers(0, self.L, size=2)
        return int(i), int(j)

    def get_magnetization(self):
        """Return the absolute mean spin ``|M|`` of the current lattice."""
        return np.abs(self.lattice.mean())

    def step_wolff(self, beta, J):
        """Perform one Wolff cluster update of ``self.lattice`` in place.

        A random seed site is flipped, then the cluster is grown depth-first.
        Each neighbour that still carries the seed's original spin is added,
        and flipped immediately, with probability ``1 - exp(-2 * beta * J)``.
        Flipping on insertion doubles as the "visited" mark, so no site can
        join the cluster twice.

        Args:
            beta: Inverse temperature.
            J: Coupling constant used for the bond-activation probability.
        """
        L = self.L
        lattice = self.lattice
        add_prob = 1 - np.exp(-2 * beta * J)
        i, j = self.get_random_pos()
        seed_spin = lattice[i, j]
        lattice[i, j] = -seed_spin
        # A plain list is a far cheaper LIFO stack than the thread-safe
        # queue.LifoQueue.
        stack = [(i, j)]

        while stack:
            i1, j1 = stack.pop()
            for ni, nj in (
                ((i1 + 1) % L, j1),
                ((i1 - 1) % L, j1),
                (i1, (j1 + 1) % L),
                (i1, (j1 - 1) % L),
            ):
                if lattice[ni, nj] == seed_spin and self.rng.uniform() < add_prob:
                    lattice[ni, nj] = -seed_spin
                    stack.append((ni, nj))

    def get_autocorrelation(self, sample_result):
        """Return the normalised autocorrelation function of a 1D series.

        Args:
            sample_result: Sequence of scalar samples.

        Returns:
            Array ``acf`` of the same length with ``acf[0] == 1``. If the
            series has zero variance, all entries are 0. An empty input
            gives an empty array.

        Raises:
            NotImplementedError: If the samples are not scalars (ndim != 1).
        """

        def autocorr_fft(x):
            # Biased (divide-by-n) autocovariance via FFT, divided by the variance.
            n = x.size
            xp = x - np.mean(x)
            var = np.var(x)
            if var == 0:
                return np.zeros(n)
            # Zero-pad to at least 2n - 1, rounded up to a power of two, so
            # that the circular FFT correlation equals the linear one.
            fsize = 2 ** int(np.ceil(np.log2(2 * n - 1)))
            cf = np.fft.fft(xp, fsize)
            corr = np.fft.ifft(cf.conjugate() * cf).real
            return corr[:n] / var / n

        sample_result = np.array([np.array(a) for a in sample_result])
        if sample_result.ndim != 1:
            raise NotImplementedError(
                "autocorrelation is only implemented for 1D series of scalars"
            )
        if sample_result.size == 0:
            return np.zeros(0)

        acf = autocorr_fft(sample_result)
        if acf[0] != 0:
            acf = acf / acf[0]
        return acf

    def get_autocorrelation_time(self, sample_result, cutoff=None):
        """Return the integrated autocorrelation time ``tau_int``.

        ``tau_int = 1/2 + sum_{t=1}^{cutoff-1} acf[t] * (1 - t / N)``, where
        ``N`` is the number of samples.

        Args:
            sample_result: 1D sequence of samples.
            cutoff: Exclusive upper lag of the summation window. If ``None``,
                the first lag at which the acf drops to <= 0 is used. If the
                acf never drops, the whole series is used.

        Returns:
            ``tau_int``, or ``0.5`` for an empty or constant series.
        """
        acf = self.get_autocorrelation(sample_result)
        tau = 0.5
        n = acf.size
        if n == 0 or acf[0] == 0:
            return tau
        if cutoff is None:
            non_positive = np.flatnonzero(acf <= 0)
            cutoff = non_positive[0] if non_positive.size else n
        window = acf[1:cutoff]
        lags = np.arange(1, window.size + 1)
        # Taper by the full series length N, not by the window length. The
        # window length would give the last included lag zero weight.
        return tau + np.sum(window * (1 - lags / n))

    def MCMC(self, L, J, beta, N_steps, N_therm, N_samples, init_type='random'):
        """Run a Wolff simulation and record ``|M|`` after every sample.

        Args:
            L: Linear lattice size.
            J: Coupling constant.
            beta: Inverse temperature.
            N_steps: Wolff updates performed between consecutive samples.
            N_therm: Wolff updates discarded as burn-in before sampling.
            N_samples: Number of samples to record.
            init_type: Initial configuration passed to :meth:`init_lattice`.

        Returns:
            List of ``N_samples`` magnetisation values.

        Side effects:
            Saves every sampled lattice, as an ``(N_samples, L, L)`` float
            array, to ``lattice_<beta rounded to 3 decimals>.npy`` in the
            current working directory.
        """
        self.L = L
        self.J = J
        self.lattice = self.init_lattice(self.L, init_type)
        output_sample_array = np.empty((N_samples, L, L))

        # Thermalization ("burn-in"): these updates are not recorded.
        for _ in range(N_therm):
            self.step_wolff(beta, J)

        magnetization_list_scratch = []
        for i in range(N_samples):
            for _ in range(N_steps):
                self.step_wolff(beta, J)
            magnetization_list_scratch.append(self.get_magnetization())
            # Slice assignment copies the data, so no explicit deepcopy is needed.
            output_sample_array[i] = self.lattice

        np.save(f'lattice_{np.around(beta, 3)}', output_sample_array)
        return magnetization_list_scratch


def f(L, J, beta, N, init_type='random'):
    """Run one Wolff simulation at inverse temperature ``beta``.

    Intended as the ``multiprocessing`` worker: it builds a fresh sampler,
    runs ``N`` samples with ``10 * L`` burn-in updates, and returns the
    magnetisations.

    Args:
        L: Linear lattice size.
        J: Coupling constant.
        beta: Inverse temperature.
        N: Number of samples to record.
        init_type: Initial lattice configuration, forwarded to ``MCMC.MCMC``.

    Returns:
        1D numpy array of ``N`` magnetisation values ``|M|``.
    """
    mcmc = MCMC()
    # Fewer Wolff updates between samples for beta > 0.4. Presumably this is
    # because clusters grow large in the ordered phase, so each update
    # decorrelates more; the threshold is a hand-tuned heuristic.
    N_steps = int(L) if beta > 0.4 else int(L * L / 2)

    magnetization_list_scratch = mcmc.MCMC(
        L=L,
        J=J,
        beta=beta,
        N_steps=N_steps,
        N_therm=10 * L,
        N_samples=N,
        init_type=init_type,
    )
    return np.array(magnetization_list_scratch)

if __name__ == '__main__':

    L = 40
    N = 500

    start = time.perf_counter()

    betas = np.linspace(0.2, 1.0, 33)

    # One simulation per beta, distributed over a pool of worker processes.
    with mp.Pool(processes=6) as p:
        args = [(L, 1, beta, N) for beta in betas]
        result_array = p.starmap(f, args, chunksize=1)

    # Instance used only for its autocorrelation helpers.
    mcmc = MCMC()

    # Shape (len(betas), N): one row of |M| samples per beta.
    result_array = np.array(result_array)
    op_array = np.mean(result_array, axis=1)
    # Autocorrelation-corrected standard error of the mean:
    # sigma^2 = 2 * tau_int * Var(|M|) / N.
    op_array_error_auto = np.apply_along_axis(mcmc.get_autocorrelation_time, 1, result_array)
    op_array_error_var = np.var(result_array, axis=1)
    op_array_error = np.sqrt(2 * op_array_error_auto * op_array_error_var / result_array.shape[1])

    fig, ax = plt.subplots()
    ax.errorbar(betas, op_array, yerr=op_array_error)
    ax.set(xlabel=r'Inverse temperature $\beta$')
    ax.set(ylabel=r'Magnetization $|M|$')
    ax.set_ylim([0, 1])
    fig.savefig('magnetization.pdf')
    plt.close(fig)

    op_4 = np.mean(result_array**4, axis=1)
    op_2 = np.mean(result_array**2, axis=1)
    # Guard against a vanishing second moment. Compute the mask BEFORE
    # patching op_2; otherwise op_4 would never be zeroed.
    zero_m2 = op_2 == 0
    op_4[zero_m2] = 0
    op_2[zero_m2] = 1e-5
    op_binder = 3/2 * (1 - op_4/(3*op_2**2))

    fig, ax = plt.subplots()
    ax.plot(betas, op_binder)
    ax.set(xlabel=r'Inverse temperature $\beta$')
    ax.set(ylabel=r'Binder cumulant $C_d$')
    fig.savefig('binder.pdf')
    plt.close(fig)

    end = time.perf_counter()
    run_time = end - start

    with open("parameters.txt", "w", encoding="utf-8") as text_file:
        print(
            f"The simulation parameters can be found here."
            f"\n\nSimulation Runtime: {run_time}"
            f"\n\nLattice Parameters:"
            f"\nL: {L}"
            f"\nN_samples_per_beta: {N}"
            f"\nbetas: {betas}",
            file=text_file)

    print("Finished in " + str(run_time) + " seconds.")