#! /usr/bin/env python
# Time-stamp: <2015-02-17 17:01:24 pbrowne>
# MATLAB original by Javier Amezcua
# Python conversion by David Livings
# Python modifications by Philip Browne

###############################################################################
#    lorenz96_empire.py Implements Lorenz 1996 model with EMPIRE coupling
#
#The MIT License (MIT)
#
#Copyright (c) 2014 Philip A. Browne
#
#Permission is hereby granted, free of charge, to any person obtaining a copy
#of this software and associated documentation files (the "Software"), to deal
#in the Software without restriction, including without limitation the rights
#to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
#copies of the Software, and to permit persons to whom the Software is
#furnished to do so, subject to the following conditions:
#
#The above copyright notice and this permission notice shall be included in all
#copies or substantial portions of the Software.
#
#THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
#IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
#FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
#AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
#LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
#OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
#SOFTWARE.
#
#Email: p.browne@reading.ac.uk
###############################################################################

from mpi4py import MPI #import MPI stuff
import sys
import numpy as np

def initialise_mpi():
    """Set up the MPI communicators used to couple this model to EMPIRE.

    MPI_Init is not called explicitly: importing mpi4py already does it.

    COMM_WORLD is split with colour 0 (this process's colour) into a
    "models" communicator, which is split again into per-model communicators
    of mdl_num_proc processes each.  Rank 0 of each model then joins the
    coupling communicator (colour 9999) and works out which rank on it to
    exchange data with.

    Sets the module globals:
      cpl_mpi_comm -- the coupling communicator (the result of a Split with
                      MPI.UNDEFINED, i.e. MPI.COMM_NULL, on processes that are
                      not rank 0 of their model).
      cpl_root     -- rank on cpl_mpi_comm to send to / receive from, or -1
                      if no partner could be assigned.
    """
    global cpl_root, cpl_mpi_comm
    mdl_num_proc = 1  # number of MPI processes per model instance
    model_colour = np.array(0, dtype='i')

    world = MPI.COMM_WORLD
    world_id = world.Get_rank()
    world_size = world.Get_size()

    models = world.Split(model_colour, world_id)
    models_size = models.Get_size()
    models_id = models.Get_rank()
    # Floor division keeps the colour an int on both Python 2 and Python 3;
    # MPI Split requires an integer colour.
    mdlcolour = models_id // mdl_num_proc
    mdl_mpi_comm = models.Split(mdlcolour, models_id)
    mdl_id = mdl_mpi_comm.Get_rank()

    # Only the root process of each model takes part in coupling.
    if mdl_id == 0:
        couple_colour = 9999
    else:
        couple_colour = MPI.UNDEFINED
    cpl_mpi_comm = world.Split(couple_colour, mdlcolour)

    # Default when no partner is found (also covers nda == 0, where the
    # search loop below never runs).
    cpl_root = -1
    if mdl_id == 0:
        ptcl_id = cpl_mpi_comm.Get_rank()
        # Processes outside the models communicator are counted as the nda
        # data-assimilation processes; the rest of cpl_mpi_comm are ensemble
        # members.
        nda = world_size - models_size
        nens = cpl_mpi_comm.Get_size() - nda
        # Share the nens ensemble members out in contiguous blocks among the
        # nda DA processes.  Assumes the DA processes occupy ranks
        # nens .. nens+nda-1 on cpl_mpi_comm -- TODO confirm on EMPIRE side.
        for da_index in range(1, nda + 1):
            if ptcl_id < np.float64(nens * da_index) / np.float64(nda):
                cpl_root = da_index - 1 + nens
                break


def lorenz96(x_0):
    """Integrate the Lorenz 1996 model from x_0, exchanging state with EMPIRE.

    It is the general case for N variables; often N=40 is used.

    The Lorenz 1996 model is cyclical: dx[j]/dt=(x[j+1]-x[j-2])*x[j-1]-x[j]+F

    The state is advanced one RK4 step of size tstep at a time, covering the
    times 0, tstep, ..., tmax.  After every step the state is sent to rank
    cpl_root on cpl_mpi_comm, overwritten in place by the state received
    back, and printed as one space-separated line.

    Inputs:   - x_0, initial state; it is updated in place.
    Outputs:  - x,   the final state (the same array object as x_0).
    """
    # Step count matches integrating over np.arange(0, tmax+tstep/2, tstep);
    # the extra half step keeps tmax itself in the range despite round-off.
    nsteps = len(np.arange(0, tmax + tstep/2, tstep)) - 1

    x = x_0

    for _ in range(nsteps):
        x[:] = rk4(x)  # advance one step, writing back into the same buffer
        cpl_mpi_comm.Send([x, len(x), MPI.DOUBLE], dest=cpl_root, tag=1)
        cpl_mpi_comm.Recv([x, len(x), MPI.DOUBLE], source=cpl_root,
                          tag=MPI.ANY_TAG)
        print(' '.join(map(str, x)))

    return x

##____________________________________
## Functions for the integration
def rk4(Varsold):
    """Take one classical fourth-order Runge-Kutta step of length tstep.

    Inputs:   - Varsold, the state at the start of the step.
    Outputs:  - the state one timestep (module-level tstep) later, computed
                from four evaluations of the tendency function f.
    """
    half_dt = 0.5 * tstep
    stage1 = f(Varsold)
    stage2 = f(Varsold + half_dt * stage1)
    stage3 = f(Varsold + half_dt * stage2)
    stage4 = f(Varsold + tstep * stage3)
    # Combine the four slope estimates with weights 1:2:2:1.
    weighted_sum = stage1 + 2 * stage2 + 2 * stage3 + stage4
    return Varsold + tstep * weighted_sum / 6.0

def f(x, forcing=None):
    """Return the Lorenz 1996 tendency dx/dt for the state x.

    The model is cyclical, so indices wrap around:
        dx[j]/dt = (x[j+1] - x[j-2]) * x[j-1] - x[j] + F

    Inputs:   - x, the current state (1-D sequence of any length).
              - forcing, the forcing constant; defaults to the module-level
                F read from lorenz96.dat.
    Outputs:  - the tendency, an array with the same length as x.
    """
    if forcing is None:
        forcing = F
    x = np.asarray(x)
    # np.roll(x, s)[j] == x[(j - s) % len(x)], which gives the cyclic
    # neighbours in one vectorised operation instead of a Python loop.
    x_next = np.roll(x, -1)    # x[j+1]
    x_back2 = np.roll(x, 2)    # x[j-2]
    x_prev = np.roll(x, 1)     # x[j-1]
    return (x_next - x_back2) * x_prev - x + forcing


initialise_mpi()

# Load the timestep, maximum time, forcing and number of variables.  Each
# non-blank line of lorenz96.dat holds "<value> <name>".  These become the
# module globals read by lorenz96(), rk4() and f().
PARAM_NAMES = ('tstep', 'tmax', 'F', 'N')
params = {}
with open('lorenz96.dat', 'r') as datafile:
    for line in datafile:
        fields = line.split()
        if not fields:
            continue  # tolerate blank lines instead of crashing on them
        if len(fields) < 2 or fields[1] not in PARAM_NAMES:
            print('Error in data file lorenz96.dat')
            sys.exit(1)
        # NOTE(review): eval() executes arbitrary code from the data file.
        # It is kept so values may be written as Python expressions, but the
        # file must come from a trusted source.
        params[fields[1]] = eval(fields[0])

missing = [name for name in PARAM_NAMES if name not in params]
if missing:
    print('Error in data file lorenz96.dat: missing ' + ', '.join(missing))
    sys.exit(1)

tstep = params['tstep']
tmax = params['tmax']
F = params['F']
N = params['N']

# x = F in every variable makes every tendency zero (a fixed point of the
# model), so the middle variable is perturbed slightly.
x_0 = np.zeros(N)
x_0[:] = F
pert = 0.05
# Cast to int: recent NumPy versions reject floating-point indices.
pospert = int(np.ceil(N / 2.0)) - 1
x_0[pospert] = F + pert

print(' '.join(map(str, x_0)))

# Send the initial state to EMPIRE and receive the (possibly modified) state
# back into the same buffer before integrating.
cpl_mpi_comm.Send([x_0, len(x_0), MPI.DOUBLE_PRECISION], dest=cpl_root, tag=1)
cpl_mpi_comm.Recv([x_0, len(x_0), MPI.DOUBLE_PRECISION], source=cpl_root,
                  tag=MPI.ANY_TAG, status=None)

lorenz96(x_0)

