'''
HMC, adaptive HMC, and Riemannian-manifold HMC samplers for Neal's funnel
distribution.
'''
import math
import random
import warnings

import numpy as np
import matplotlib.pyplot as plt
import torch
from torch.autograd import grad
from torch.distributions.multivariate_normal import MultivariateNormal

np.random.seed(4)
random.seed(4)
torch.manual_seed(4)



def isinf(q):
    '''
    Warn if any entry of q has overflowed to +inf or -inf.
    '''
    if torch.sum(q == float('inf')):
        warnings.warn("Overflow encountered")
    if torch.sum(q == float('-inf')):
        warnings.warn("Underflow encountered")


def funnel_dist(q):
    '''
    Unnormalized density of Neal's funnel:
        v ~ N(0, 3^2),   x_i | v ~ N(0, exp(v))

    :param q: state vector; q[0] is the funnel variable v, q[1:] are the x_i
    :return: the unnormalized density at q
    '''

    v = q[0]
    x = q[1:]

    # the joint density factorizes as a product of the Gaussian terms
    return torch.exp(-0.5 * v ** 2 / 3 ** 2) * torch.prod(torch.exp(-0.5 * x ** 2 / torch.exp(v)))

def U(q):
    '''
    Negative log of the unnormalized target distribution, i.e.
        U(q) = v^2 / 18 + sum_i x_i^2 / (2 exp(v)),
    so that exp(-U(q)) matches funnel_dist(q).
    '''

    result = 0.5 * q[0] ** 2 / 9 + torch.sum(q[1:] ** 2 / (2 * torch.exp(q[0])))

    isinf(result)

    return result



def grad_U(q):
    '''
    Gradient of U at q via autograd.
    '''
    q = q.detach().requires_grad_(True)
    U(q).backward()

    return q.grad

def K(p, inv_mass):
    '''Kinetic energy 0.5 * p^T M^{-1} p for mass matrix M.'''
    return p @ inv_mass @ p / 2
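
# A minimal sanity check (a sketch, not part of the experiments below):
# grad_U should agree with a central finite-difference approximation of U
# at an arbitrary point.
_q_test = torch.randn(3)
_g_auto = grad_U(_q_test.clone())
_g_fd = torch.zeros(3)
_h = 1e-4
for _i in range(3):
    _e = torch.zeros(3)
    _e[_i] = _h
    _g_fd[_i] = (U(_q_test + _e) - U(_q_test - _e)) / (2 * _h)
assert torch.allclose(_g_auto, _g_fd, atol=1e-2), "grad_U disagrees with finite differences"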

def leapfrog(U, K, grad_U, cur_q, mass, inv_mass, eps=0.1, L=200):
    '''
    Standard leapfrog integrator (as in Neal's HMC) with a Metropolis
    accept/reject step.
    '''
    D = cur_q.shape[0]
    q = cur_q.clone()

    p_0 = MultivariateNormal(torch.zeros(D), mass).rsample(torch.Size([1])).view(D,)
    p = p_0.clone()

    # half step for momentum at the beginning
    p = p - eps * grad_U(q) / 2

    # alternate full steps for position and momentum
    for i in range(L):
        q += eps * (inv_mass @ p)  # dq/dt = M^{-1} p
        if i != L - 1:
            p -= eps * grad_U(q)
    # half step for momentum at the end, then negate for a symmetric proposal
    p = p - eps * grad_U(q) / 2
    p = -p

    cur_U = U(cur_q.view(D,))
    cur_K = K(p_0.view(D,), inv_mass)
    proposed_U = U(q.view(D,))
    proposed_K = K(p.view(D,), inv_mass)

    # Metropolis acceptance in log space
    if torch.log(torch.Tensor(1).uniform_(0, 1)) < (cur_U - proposed_U + cur_K - proposed_K):
        return q
    return cur_q

def simple_hmc(U, K, grad_U, mass, inv_mass, iters, q_0, integrator, L=200, eps=0.175):
    D = q_0.shape[0]
    q_hist = [np.asarray(q_0.reshape(D,))]
    accepted_num = 0
    cur_q = q_0.clone()

    for i in range(iters):

        nxt_q = integrator(U, K, grad_U, cur_q, mass, inv_mass, L=L, eps=eps)
        if torch.sum(torch.ne(nxt_q, cur_q)) > 0:
            accepted_num += 1
        cur_q = nxt_q
        # record the current state even on rejection, so that MCMC averages
        # over q_hist remain unbiased
        q_hist.append(np.asarray(cur_q.reshape(D,)))

        if i % 50 == 0:
            print("progressed {}%".format(i * 100 / iters))
    print("The acceptance rate is {}".format(accepted_num / iters))
    return q_hist


def globally_adaptive_metric(q_hist):
    '''
    Empirical covariance of recent samples, used as an adapted mass matrix.
    '''
    q_hist = np.array(q_hist)
    return torch.Tensor(np.cov(q_hist.T))
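
# A small sketch of the idea behind the adapted metric (synthetic data,
# illustration only): for strongly correlated draws the empirical covariance
# captures correlation that the identity mass matrix ignores.
_draws = np.random.multivariate_normal([0.0, 0.0], [[1.0, 0.9], [0.9, 1.0]], size=500)
_adapted_mass = globally_adaptive_metric(list(_draws))
# the off-diagonal entries of _adapted_mass should be close to 0.9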

def global_adaptive_hmc(U, K, grad_U, mass, inv_mass, iters, q_0, integrator, L=200, eps=0.175):
    D = q_0.shape[0]
    q_hist = [np.asarray(q_0.reshape(D,))]
    accepted_num = 0
    cur_q = q_0.clone()

    for i in range(iters):

        nxt_q = integrator(U, K, grad_U, cur_q, mass, inv_mass, L=L, eps=eps)
        if torch.sum(torch.ne(nxt_q, cur_q)) > 0:
            accepted_num += 1
        cur_q = nxt_q
        # record the current state even on rejection (see simple_hmc)
        q_hist.append(np.asarray(cur_q.reshape(D,)))

        if i % 50 == 0:
            print("progressed {}%".format(i * 100 / iters))

        # periodically re-estimate the mass matrix from the last 200 samples
        if i % 1000 == 0 and len(q_hist) > 200:
            mass = globally_adaptive_metric(q_hist[-200:])
            # jitter the diagonal so the covariance is safely invertible
            inv_mass = torch.inverse(mass + torch.eye(D) * 1e-5)
    print("The acceptance rate is {}".format(accepted_num / iters))
    return q_hist

def estimated_mean(q_hist):
    '''
    Every 50 samples, recompute the running empirical mean of the chain.
    :param q_hist: list/array of sampled states
    :return: array of running means
    '''
    N = len(q_hist)
    result = []
    for i in range(50, N, 50):
        result.append(np.mean(q_hist[:i], axis=0))

    return np.asarray(result)


def plot_analysis(ETS):
    '''
    Plot each coordinate of ETS (e.g. the running means from estimated_mean)
    in its own panel; for the funnel, each should approach 0.
    '''
    D = ETS.shape[1]
    plt.figure(figsize=(20, 10))
    for i in range(D):
        plt.subplot(3, 5, i + 1)
        plt.axhline(y=0, linewidth=4, color='r')
        plt.ylim(-20, 20)
        plt.plot(ETS[:, i])
        plt.xlabel('iterations')
        plt.ylabel("mean of q_{}".format(i))

    plt.tight_layout()
    plt.show()


def trajectory_trace(Q):
    contour(Q)


def generate_grid(h, w):
    '''Meshgrid over x in [-10, 10] (h points) and y in [-20, 20] (w points).'''
    x = torch.linspace(-10, 10, h)
    y = torch.linspace(-20, 20, w)
    grid = torch.stack([x.repeat(w), y.repeat(h, 1).t().contiguous().view(-1)], 1)
    return [grid[:, 0].view(w, h), grid[:, 1].view(w, h)]


def contour(q_hist):
    '''
    Visualize one direction of the funnel: contour of the unnormalized
    density in the (q_1, q_0) plane, with the sampled trajectory overlaid.
    '''
    q, v = generate_grid(300, 300)
    # unnormalized density of a single (x, v) slice of the funnel
    z = torch.exp(-0.5 * v ** 2 / 3 ** 2 - 0.5 * q ** 2 / torch.exp(v))
    plt.contour(q, v, z)
    plt.plot(q_hist[:, 1], q_hist[:, 0])
    plt.xlabel("q_1")
    plt.ylabel("q_0")
    plt.show()



def corrplot(trace, maxlags=100):
    '''Autocorrelation of the funnel direction q_0.'''
    trace = trace[:, 0]

    plt.acorr(trace - np.mean(trace), normed=True, maxlags=maxlags)
    plt.xlim([0, maxlags])
    plt.xlabel("lag")
    plt.ylabel("autocorrelation of q_0")
    plt.show()



########################################################################
########################################################################
# Riemannian-manifold HMC


def rHMC(U, G_metric, iters, q_0, num_samples, eps=0.175):
    '''
    Riemannian-manifold HMC with the generalized (implicit) leapfrog
    integrator.

    :param iters: number of fixed-point iterations per implicit update
    :param G_metric: function (U, q) -> Riemannian metric tensor G(q)
    '''
    accepted_num = 0
    q_hist = []
    D = q_0.size()[0]

    for j in range(num_samples):
        mass = G_metric(U, q_0)

        p_0 = MultivariateNormal(torch.zeros(D), mass).rsample(torch.Size([1])).view(D, )

        currentH = Hamiltonian(U, G_metric, p_0, q_0)
        isinf(currentH)
        L = random.randint(5, 20)  # jitter the trajectory length

        p = p_0.clone()
        q = q_0.clone()

        trajectory = []
        for i in range(L):

            # implicit half step for the momentum (fixed-point iteration)
            p = FixPointIteration(iters, p, q, grad_H, eps, 'p', U, G_metric)

            # implicit full step for the state (fixed-point iteration)
            q = FixPointIteration(iters, p, q, grad_H, eps, 'q', U, G_metric)

            # save the state into the trajectory
            trajectory.append(np.asarray(q))

            # explicit half step for the momentum: p <- p - eps/2 * dH/dq
            p = p - eps / 2 * grad_H(U, G_metric, p, q)[1]

        trajectory = np.asarray(trajectory)
        # trajectory_trace(trajectory)  # optional: visualize one trajectory
        # no momentum flip is needed here: H is symmetric in p, so negating
        # p would leave the acceptance ratio unchanged

        # Metropolis acceptance: H is already a negative log density, so
        # log(alpha) = currentH - proposedH
        proposedH = Hamiltonian(U, G_metric, p, q)
        isinf(proposedH)
        Ratio = currentH - proposedH
        isinf(Ratio)

        u = torch.log(torch.Tensor(1).uniform_(0, 1))

        if Ratio > u:
            accepted_num += 1
            q_0 = q
        # record the current state even on rejection, so MCMC averages
        # over q_hist remain unbiased
        q_hist.append(np.asarray(q_0))
        if j % 10 == 0:
            print("Progress {}%".format(j / num_samples * 100))

    print("Acceptance Rate is {}".format(accepted_num / num_samples))
    return q_hist

def Hamiltonian(U, G_metric, p, q):
    '''
    :param U: -log(unnormalized distribution of interest)
    :param G_metric: Riemannian metric tensor G(q)
    :param p: momentum variable
    :param q: state variable
    :return: H(q, p) = U(q) + 0.5 log((2 pi)^D det G) + 0.5 p^T G^{-1} p
    '''

    D = p.size()[0]  # dimension of p and q
    G = G_metric(U, q)

    # work with the log-determinant directly to avoid overflow in det(G)
    logdetG = torch.logdet(G)
    invG = torch.inverse(G)

    return U(q) + 0.5 * (D * math.log(2 * math.pi) + logdetG) + 0.5 * p @ invG @ p
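
# A small consistency sketch (using an identity metric assumed only for this
# check): with G = I the Riemannian Hamiltonian should reduce to
# U(q) + K(p, I) + (D/2) * log(2 pi), since log det I = 0.
_eye_metric = lambda U_fn, q_vec: torch.eye(q_vec.size()[0])
_q, _p = torch.randn(3), torch.randn(3)
_H_val = Hamiltonian(U, _eye_metric, _p, _q)
_expected = U(_q) + K(_p, torch.eye(3)) + 0.5 * 3 * math.log(2 * math.pi)
assert torch.allclose(_H_val, _expected, atol=1e-5)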

def grad_H(U, G_metric, p, q):
    '''
    :return: the gradient of the Hamiltonian w.r.t. p and q, as (dH/dp, dH/dq)
    '''
    q = q.detach().requires_grad_(True)
    p = p.detach().requires_grad_(True)
    isinf(q)
    isinf(p)

    Hamiltonian(U, G_metric, p, q).backward()

    grad_q = q.grad
    grad_p = p.grad

    isinf(grad_q)
    isinf(grad_p)

    return grad_p, grad_q
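
# Sketch of a gradient check under the identity metric (assumed only for the
# check): the Hamiltonian then separates, so dH/dp = p and dH/dq = grad_U(q).
_q, _p = torch.randn(3), torch.randn(3)
_gp, _gq = grad_H(U, lambda U_fn, q_vec: torch.eye(q_vec.size()[0]), _p, _q)
assert torch.allclose(_gp, _p, atol=1e-5)
assert torch.allclose(_gq, grad_U(_q.clone()), atol=1e-5)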

def FixPointIteration(iters, p0, q0, grad_H, eps, param, U, G_metric):
    '''
    Solve the implicit generalized-leapfrog updates by fixed-point iteration.
    See http://home.iitk.ac.in/~psraj/mth101/lecture_notes/lecture8.pdf for a
    more detailed explanation of why fixed-point iteration is a good
    approximator here.

    :param param: which variable to update ('p' or 'q')
    :param grad_H: gradient of the Hamiltonian, returning (dH/dp, dH/dq)
    :return: the updated variable
    '''
    p = p0
    q = q0
    if param == 'p':
        # implicit half step: p = p0 - eps/2 * dH/dq(p, q)
        for i in range(iters):
            p = p0 - eps / 2 * grad_H(U, G_metric, p, q)[1]
            isinf(p)
        return p
    else:
        # implicit full step: q = q0 + eps/2 * [dH/dp(p, q0) + dH/dp(p, q)]
        for i in range(iters):
            q = q0 + eps / 2 * grad_H(U, G_metric, p, q0)[0] + eps / 2 * grad_H(U, G_metric, p, q)[0]
            isinf(q)

        return q
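
# Illustration of the fixed-point principle used above (a classic scalar
# example, independent of the sampler): iterating x <- cos(x) converges to
# the unique solution of x = cos(x), roughly 0.7391, because cos is a
# contraction near that point.
_x = torch.tensor(1.0)
for _ in range(50):
    _x = torch.cos(_x)
assert torch.allclose(_x, torch.cos(_x), atol=1e-6)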


def Hessian(f, q):
    '''
    Compute the full Hessian matrix of f w.r.t. the vector q by
    differentiating each component of the gradient.
    '''
    D = q.size()[0]
    H = torch.zeros(D, D)

    q = q.detach().requires_grad_(True)

    # first-order gradient, kept in the graph so it can be differentiated again
    df = grad(f(q), q, create_graph=True)[0]

    for i in range(D):
        H[i] = grad(df[i], q, retain_graph=True)[0]
    return H
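
# Sanity sketch against a closed form: for f(q) = 0.5 * q @ A @ q the
# Hessian is (A + A.T) / 2.
_A = torch.randn(3, 3)
_H_num = Hessian(lambda q_vec: 0.5 * q_vec @ _A @ q_vec, torch.randn(3))
assert torch.allclose(_H_num, (_A + _A.t()) / 2, atol=1e-4)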


def Expected_Fish_Info(U, q):
    '''
    Placeholder for when the target distribution is conditioned on some
    random variable (e.g. a posterior that depends on draws from a prior).

    :param U: -log(unnormalized distribution of interest)
    :param q: the collection of empirically sampled q
    :return: the expected Fisher information matrix
    '''
    # sample some prior random variable
    pass

def EignDecomp(H):
    '''
    Factor H = Q diag(V) Q^T, where V holds the eigenvalues and Q the
    corresponding orthonormal eigenvectors.
    '''
    V, Q = torch.linalg.eigh(H)

    return V, Q
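
# Minimal reconstruction check on a random symmetric matrix:
# Q diag(V) Q^T should recover the input.
_A = torch.randn(4, 4)
_S = (_A + _A.t()) / 2
_V, _Q = EignDecomp(_S)
assert torch.allclose(_Q @ torch.diag(_V) @ _Q.t(), _S, atol=1e-5)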

alpha = 10  # SoftAbs sharpness: larger alpha tracks |lambda| more closely


def SoftAbs_metric(U, q):
    '''
    SoftAbs metric (Betancourt, 2012): map each eigenvalue lambda of the
    Hessian to lambda * coth(alpha * lambda), a smooth, strictly positive
    surrogate for |lambda|, so the metric is always positive definite.
    '''
    D = q.size()[0]

    H = Hessian(U, q)

    try:
        V, Q = EignDecomp(H)
    except RuntimeError:
        return torch.eye(D) * 1e-4

    coth = torch.zeros(D)

    # numerically stable evaluation of lambda * coth(alpha * lambda) on
    # either side of zero
    for i in range(D):
        if V[i] > 0:
            coth[i] = V[i] * (1 + torch.exp(-2 * alpha * V[i])) / (1 - torch.exp(-2 * alpha * V[i]))
        else:
            coth[i] = V[i] * (1 + torch.exp(2 * alpha * V[i])) / (torch.exp(2 * alpha * V[i]) - 1)
    V = torch.diag(coth)
    H = Q @ V @ Q.t()
    return H
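
# Quick positive-definiteness sketch: even where the funnel's Hessian is
# indefinite, the SoftAbs metric should have strictly positive eigenvalues
# (bounded below by roughly 1/alpha).
_G = SoftAbs_metric(U, torch.randn(3))
assert (torch.linalg.eigvalsh(_G) > 0).all(), "SoftAbs metric is not positive definite"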

c = 0.1  # smoothing constant for the smoothed-absolute-value metric
def Smooth_Abs_metric(U, q):
    '''
    Smoothed absolute-value metric: map each Hessian eigenvalue lambda to
    sqrt(lambda^2 + c^2), another smooth positive surrogate for |lambda|.
    '''
    D = q.size()[0]

    H = Hessian(U, q)

    try:
        V, Q = EignDecomp(H)
    except RuntimeError:
        return torch.eye(D) * 1e-4

    V = (V ** 2 + c ** 2) ** (1 / 2)

    V = torch.diag(V)
    H = Q @ V @ Q.t()
    return H

def Hessian_metric(U, q):
    '''
    Squared-Hessian metric: H @ H is positive semidefinite for symmetric H,
    and a small random jitter on the diagonal nudges it to positive definite.
    '''
    D = q.size()[0]
    H = Hessian(U, q)
    sigma = torch.abs(torch.Tensor(1).normal_(0, 1e-5))
    return H @ H + torch.eye(D) * sigma

def Euclidean_metric(U, q):
    '''
    Identity metric: reduces Riemannian HMC to standard Euclidean HMC.
    (A Riemannian metric must be a nonsingular positive-definite matrix;
    the identity is the trivial choice.)
    '''
    D = q.size()[0]
    return torch.eye(D)


########### Experiments on Simple HMC
######################################################################
D=10
x0 = torch.ones(D)
x0[0] = torch.Tensor(1).normal_(0,1)


mass=torch.eye(D)
inv_mass = mass

q_hist = simple_hmc(U, K, grad_U, mass, inv_mass, 1500, x0, leapfrog,L=20,eps=0.005)
q_hist = np.asarray(q_hist)
plot_analysis(q_hist)
contour(q_hist)
corrplot(q_hist)

########## Experiments on Adaptive Global HMC
#####################################################################

D=10
x0 = torch.ones(D)
x0[0] = torch.Tensor(1).normal_(0,1)


mass=torch.eye(D)
inv_mass = mass

q_hist = global_adaptive_hmc(U, K, grad_U, mass, inv_mass, 1500, x0, leapfrog,L=20,eps=0.3)
q_hist = np.asarray(q_hist)
plot_analysis(q_hist)
contour(q_hist)
corrplot(q_hist)



########## Experiments on Riemannian HMC with Hessian Metric ##########
#####################################################################

D=10
x0 = torch.ones(D)
x0[0] = torch.Tensor(1).normal_(0,1)


mass=torch.eye(D)
inv_mass = mass

q_hist = rHMC(U, Hessian_metric,2, x0,1500,eps=0.3)
q_hist = np.asarray(q_hist)
plot_analysis(q_hist)
contour(q_hist)
corrplot(q_hist)


########## Experiments on Riemannian HMC with SoftAbs Metric ##########
#####################################################################
D=10
x0 = torch.ones(D)
x0[0] = torch.Tensor(1).normal_(0,1)


mass=torch.eye(D)
inv_mass = mass

q_hist = rHMC(U, SoftAbs_metric, 3, x0, 500, eps=0.2)
q_hist = np.asarray(q_hist)
plot_analysis(q_hist)
contour(q_hist)
corrplot(q_hist)