import matplotlib.pyplot as plt
import numpy as np
from scipy.integrate import odeint

def standard_param_values():
    """Return the standard parameter set ``(sigma, b, r)`` for the system in ``f``.

    The values are sigma = 10, b = 8/3 and r = 28; ``integrate_system`` uses
    them when no parameters are supplied.
    """
    # STANDARD PARAMETER VALUES, in the order (sigma, b, r) expected by f().
    return (10., 8. / 3., 28.)

def f(w, t, p):
    """Evaluate the time derivative of the Lorenz system at state ``w``.

    The signature follows the ``func(y, t, *args)`` convention required by
    ``scipy.integrate.odeint``.

    Parameters
    ----------
    w : sequence of 3 numbers
        Current state ``(x, y, z)``.
    t : float
        Current time. It is not used; the right-hand side has no explicit
        time dependence.
    p : sequence of 3 numbers
        Parameters ``(sigma, b, r)``.

    Returns
    -------
    list
        ``[dx/dt, dy/dt, dz/dt]``.
    """
    x, y, z = w
    sigma, b, r = p

    dxdt = sigma * (y - x)
    dydt = r * x - y - x * z
    dzdt = x * y - b * z

    return [dxdt, dydt, dzdt]

def integrate_system(w0, t1, dt, p=None):
    """Integrate the system defined by ``f`` starting from ``w0``.

    Parameters
    ----------
    w0 : sequence of 3 floats
        Initial state ``(x0, y0, z0)``.
    t1 : float
        End time (exclusive). Output times are ``np.arange(0, t1, dt)``.
    dt : float
        Spacing of the output times. ``odeint`` chooses its own internal
        step sizes.
    p : sequence of 3 floats, optional
        Parameters ``(sigma, b, r)``. When omitted,
        ``standard_param_values()`` is used.

    Returns
    -------
    w : ndarray
        Solution from ``odeint``, one row per output time.
    t : ndarray
        The output times.
    """
    params = standard_param_values() if p is None else p

    # Output time grid.
    times = np.arange(0, t1, dt)

    # Solve the ODE system; the parameters are forwarded to f() as its 3rd argument.
    trajectory = odeint(f, w0, times, args=(params,))

    return trajectory, times

def unpack_state(w0):
    """Return the first three components of ``w0`` as an ``(x, y, z)`` tuple.

    Raises IndexError if ``w0`` has fewer than three components.
    """
    return tuple(w0[i] for i in range(3))
    
def generate_background_attractor():
    """Integrate one long reference trajectory with the default parameters.

    The trajectory starts at (x0, y0, z0) = (5.86, 7.11, 21.7). It is
    integrated for 250 time units and output every 0.01 time units.

    Returns
    -------
    w : ndarray
        Trajectory as returned by ``integrate_system``.
    t : ndarray
        The output times.
    """
    initial_state = [5.86, 7.11, 21.7]  # (x0, y0, z0)
    total_time = 250.                   # total integration time
    output_step = 0.01                  # output timestep

    trajectory, times = integrate_system(initial_state, total_time, output_step)
    return trajectory, times

def generate_ensemble_initial_conditions(w0_centre, N, variance=0.05):
    """Draw ``N`` initial states from a Gaussian centred on ``w0_centre``.

    The three components are sampled independently, with the same variance
    for each.

    Parameters
    ----------
    w0_centre : sequence of 3 floats
        Mean state ``(x0, y0, z0)``.
    N : int
        Number of ensemble members to draw.
    variance : float, optional
        Variance of each component (default 0.05, the previously hard-coded
        value).

    Returns
    -------
    ndarray, shape (N, 3)
        One sampled initial state per row.
    """
    # Covariance for uncorrelated variables: variance * identity.
    cov = variance * np.eye(3)

    # A single-argument print(...) produces the same output on Python 2 and 3;
    # the old print statement was a syntax error on Python 3.
    print('Usando %d condiciones iniciales alrededor del punto ' % (N) +
          '(x0, y0, z0) = (%5.2f, %5.2f, %5.2f)' %
          (unpack_state(w0_centre)))

    # Generate initial conditions as samples from a 3D normal distribution.
    w0 = np.random.multivariate_normal(w0_centre, cov, N)

    return w0

def integrate_ensemble(w0, t1, dt):
    """Integrate every member of an ensemble of initial conditions.

    Also prints the sample standard deviation of x, y and z across the
    ensemble at the final output time.

    Parameters
    ----------
    w0 : array_like, shape (N, 3)
        One initial state per row. A single 1-D state is treated as N = 1.
    t1 : float
        End time, passed to ``integrate_system``.
    dt : float
        Output spacing, passed to ``integrate_system``.

    Returns
    -------
    w_pert : ndarray, shape (T, 3, N)
        ``w_pert[:, :, nn]`` is the trajectory of member ``nn``.
    t : ndarray, shape (T,)
        Output times, the same for every member.

    Raises
    ------
    ValueError
        If ``w0`` contains no initial conditions.
    """
    w0 = np.atleast_2d(w0)
    N = w0.shape[0]
    if N == 0:
        raise ValueError('integrate_ensemble needs at least one initial condition')

    runs = [integrate_system(w0[nn, :], t1, dt) for nn in range(N)]
    t = runs[0][1]

    # Stack the (T, 3) trajectories along a new last axis -> (T, 3, N).
    # T is taken from the solver output rather than int(t1/dt). Because of
    # float rounding, np.arange(0, t1, dt) can have one more sample than
    # int(t1/dt), which made the old preallocated array the wrong size.
    w_pert = np.dstack([w for w, _ in runs])

    # Sample standard deviation (ddof=1, the same normalisation np.cov uses)
    # across members at the final time. The old code printed the covariance
    # diagonal (variances) under a "std" label.
    spread = np.std(w_pert[-1, :, :], axis=1, ddof=1)
    print('std(x) = %5.2f' % (spread[0]))
    print('std(y) = %5.2f' % (spread[1]))
    print('std(z) = %5.2f' % (spread[2]))

    return w_pert, t

def plot_background_attractor(proj='xz', axes=None):
    """Plot the reference trajectory in grey as a 2-D projection.

    Parameters
    ----------
    proj : {'xz', 'xy', 'yz'}
        The pair of state components to plot (horizontal, vertical).
    axes : matplotlib Axes, optional
        Axes to draw on. When omitted, ``plt.plot`` draws on the current
        pyplot axes.

    Raises
    ------
    ValueError
        If ``proj`` is not a supported projection. This is checked before the
        long background integration runs.
    """
    # Column indices of the (horizontal, vertical) components per projection.
    columns = {'xz': (0, 2), 'xy': (0, 1), 'yz': (1, 2)}
    if proj not in columns:
        raise ValueError("proj must be one of 'xz', 'xy', 'yz'; got %r" % (proj,))
    i, j = columns[proj]

    w, _ = generate_background_attractor()

    target = plt if axes is None else axes
    target.plot(w[:, i], w[:, j], color='grey')

def plot_orbit(w, colour='black', extremes=False, proj='xz', axes=None):
    """Plot a trajectory as a 2-D projection, optionally marking its ends.

    Parameters
    ----------
    w : ndarray, shape (T, 3)
        Trajectory, one state per row.
    colour : matplotlib colour
        Colour for the line and the end markers.
    extremes : bool
        If True, mark the first point with a triangle (``'^'``) and the last
        point with a circle (``'o'``). Ignored for an empty trajectory.
    proj : {'xz', 'xy', 'yz'}
        The pair of state components to plot (horizontal, vertical).
    axes : matplotlib Axes, optional
        Axes to draw on. When omitted, pyplot's current axes are used.

    Raises
    ------
    ValueError
        If ``proj`` is not a supported projection.
    """
    # Column indices of the (horizontal, vertical) components per projection.
    columns = {'xz': (0, 2), 'xy': (0, 1), 'yz': (1, 2)}
    if proj not in columns:
        raise ValueError("proj must be one of 'xz', 'xy', 'yz'; got %r" % (proj,))
    i, j = columns[proj]

    u1 = w[:, i]
    u2 = w[:, j]

    # plt.plot and Axes.plot take the same arguments, so one code path serves both.
    target = plt if axes is None else axes
    target.plot(u1, u2, color=colour)
    if extremes and len(u1) > 0:
        target.plot(u1[0], u2[0], marker='^', linestyle='none', color=colour)
        target.plot(u1[-1], u2[-1], marker='o', linestyle='none', color=colour)

def plot_components(w, t, colour='black', axarr=None):
    """Plot x, y and z against time on three stacked panels.

    Parameters
    ----------
    w : ndarray, shape (len(t), 3)
        Trajectory, one state per row.
    t : ndarray
        Times matching the rows of ``w``.
    colour : matplotlib colour
        Line colour.
    axarr : sequence of 3 Axes, optional
        Existing panels to draw into. When omitted, a new 3x1 figure with
        shared x and y axes is created and its panels are labelled.

    Returns
    -------
    The sequence of three panels that were drawn into.
    """
    label_panels = axarr is None
    if label_panels:
        _, axarr = plt.subplots(3, 1, sharex=True, sharey=True)

    for col, name in enumerate(('x', 'y', 'z')):
        panel = axarr[col]
        panel.plot(t, w[:, col], color=colour)
        if not label_panels:
            continue
        if col == 2:
            # Only the bottom panel gets the time label.
            panel.set_xlabel('t')
        panel.set_ylabel(name)

    return axarr

def plot_attractor(w, colour='black', axes=None):
    """Plot a trajectory as a 3-D curve in (x, y, z).

    Parameters
    ----------
    w : ndarray, shape (T, 3)
        Trajectory, one state per row.
    colour : matplotlib colour
        Line colour.
    axes : 3-D matplotlib Axes, optional
        Axes to draw on. When omitted, the current figure's current axes are
        reused if they are already 3-D. Otherwise a new 3-D subplot is added
        to the current figure.
    """
    if axes is None:
        # Importing mplot3d registers the '3d' projection on older matplotlib.
        from mpl_toolkits.mplot3d import Axes3D  # noqa: F401
        fig = plt.gcf()
        current = fig.gca() if fig.axes else None
        if current is not None and getattr(current, 'name', None) == '3d':
            axes = current
        else:
            # plt.gca(projection='3d') stopped accepting keyword arguments
            # (deprecated in matplotlib 3.4, later removed); add_subplot with
            # projection works on old and new versions.
            axes = fig.add_subplot(111, projection='3d')

    u1 = w[:, 0]
    u2 = w[:, 1]
    u3 = w[:, 2]

    axes.plot(u1, u2, u3, color=colour)

def find_system_state(w, t, t_0, tol=0.001):
    """Return the row of ``w`` at the output time closest to ``t_0``.

    Parameters
    ----------
    w : array_like, shape (len(t), 3)
        Trajectory, one state per entry of ``t``.
    t : array_like
        Output times.
    t_0 : float
        Requested time.
    tol : float
        A time is accepted only if ``|t - t_0| < tol``.

    Returns
    -------
    ndarray
        A 1-D copy of the nearest acceptable row of ``w``. If no time lies
        within ``tol`` of ``t_0``, an empty array is returned (as before).
    """
    w = np.asarray(w)
    distance = np.abs(np.asarray(t) - t_0)

    if distance.size == 0 or not distance.min() < tol:
        # No acceptable time: keep the previous behaviour of returning an
        # empty array rather than raising.
        return np.empty(0, dtype=w.dtype)

    # The old code selected every time inside the window and flattened the
    # rows together. A tol larger than dt/2 then returned 6, 9, ... values
    # instead of one state, so only the single nearest time is used here.
    return w[np.argmin(distance), :].flatten()

def plot_pdfs(w_fct, w_ana=None, axarr=None):
    """Plot normalised histograms of each component at the final time.

    Parameters
    ----------
    w_fct : ndarray, shape (T, 3, N)
        Ensemble trajectories, laid out like the output of
        ``integrate_ensemble``. Only the last time index is used.
    w_ana : sequence of 3 floats, optional
        Reference state, drawn as a vertical black line on each panel.
    axarr : sequence of 3 Axes, optional
        Panels to draw into. When omitted, a new 1x3 figure (10x5 inches)
        with a shared y axis is created.

    Returns
    -------
    The sequence of three panels that were drawn into.
    """
    if axarr is None:
        _, axarr = plt.subplots(1, 3, sharey=True, figsize=(10, 5))

    # (component index, axis label, histogram bin edges) for each panel.
    panels = [
        (0, 'x', np.linspace(-20, 20, 41)),
        (1, 'y', np.linspace(-30, 30, 61)),
        (2, 'z', np.linspace(0, 50, 51)),
    ]
    for i, label, bins in panels:
        ax = axarr[i]
        # 'density' replaces 'normed', which matplotlib deprecated in 2.1
        # and removed in 3.1.
        ax.hist(w_fct[-1, i, :], bins=bins, density=True)
        if w_ana is not None:
            ax.axvline(w_ana[i], color='black', linewidth=2)
        ax.set_xlabel(label)

    return axarr
