from mpi4py import MPI
import sys
import time
import numpy as np

#import our library
from lib_SI_CI_V6 import *

# Names of the observables that can be computed, keyed by the string form of
# `switch` (argv[1]). Also used to build the output file name.
func_dict={'1':'M_spx','2':'M_spy','3':'M_spz','4':'M_ax','5':'M_ay','6':'M_az','7':'Q'}

# Expected positional command line: switch key z Ez eVmin eVmax
# (eVmin/eVmax, argv[5] and argv[6], are read further down).
if (len(sys.argv)<7):
	sys.exit('Usage: '+sys.argv[0]+' switch key z Ez eVmin eVmax')

switch=int(sys.argv[1]) #Choosing function to calculate: see func_dict above
# Reject an unknown switch now, before the (expensive) system is built,
# instead of failing later with an IndexError/KeyError.
if (str(switch) not in func_dict):
	sys.exit('ERROR: switch=1-7')
key=int(sys.argv[2]) #Choosing a set of parameters: see below
z=float(sys.argv[3]) #choosing a distance from the interface between a Ferromagnet and a Superconductor
Ez=float(sys.argv[4]) #Zeeman splitting energy

# Model parameters passed to SI_CI; their physical meaning is defined in
# lib_SI_CI_V6 (not visible here).
T_up=0.06
l_sf=100.
kT=0.01
alpha=0.3*np.pi
zeta=0.2*np.pi
reg_c=1e-7

# Parameter sets selected by `key`:
#   1-5 -> non-zero theta (T_down stays equal to T_up)
#   6   -> theta=0 and a lower T_down
#   7   -> defaults (theta=0, T_down=T_up)
theta=0.0
T_down=T_up
_theta_by_key={1:0.15,2:0.32,3:0.49,4:0.66,5:0.83}
if (key in _theta_by_key):
	theta=_theta_by_key[key]
elif (key==6):
	T_down=0.02
elif (key!=7):
	sys.exit('ERROR: key=1-7')

# Energy window [eVmin, eVmax) to scan. Validate it before building the system:
# eVmax < eVmin would make np.linspace fail on a negative sample count, and
# eVmax == eVmin would silently produce an empty result file.
eVmin=float(sys.argv[5])
eVmax=float(sys.argv[6])
if (not eVmax>eVmin):
	sys.exit('ERROR: eVmax must be larger than eVmin')

eV=0.0
#Create our system with eV=0; Sys.eV is updated on the fly in the calculation loop below
Sys=SI_CI(T_up=T_up,T_down=T_down,theta=theta,Ez=Ez,eV=eV,l_sf=l_sf,kT=kT,reg_c=reg_c,alpha=alpha,zeta=zeta)

comm = MPI.COMM_WORLD
size = comm.Get_size() # total number of MPI processes
rank = comm.Get_rank() # id of this process, 0..size-1

# Cutoffs passed to the M_*/Q routines (their meaning is defined in lib_SI_CI_V6)
u_cutoff=0.1
x_cutoff=10.0

values=[] # rows [eV, res, error, Nevals, Nwarn, dt] computed by this rank

# Split [eVmin, eVmax) into `size` equal sub-intervals, one per rank, each
# sampled with a fixed density of points per unit of eV.
points_per_eV=350
step_eV=(eVmax-eVmin)/size
eV_start=eVmin+step_eV*rank
eV_end=eV_start+step_eV
# round() rather than bare int(): the float product can land just below an
# integer (e.g. (0.3/3)*350 == 34.999...), and truncation would give some
# ranks one point fewer and a different spacing.
Npoints=int(round((eV_end-eV_start)*points_per_eV))
eVs=np.linspace(eV_start,eV_end,Npoints,endpoint=False)
counter=1
for eV in eVs:
	print "I am processor number ",rank," and I am calculating point number ",counter," of "+str(Npoints)
	#We first update the value of eV
	Sys.eV=eV
	time_start=time.time()
	data=[]
	if (switch==1):
		data=M_spx(Sys,x_cutoff,z)
	elif (switch==2):
		data=M_spy(Sys,x_cutoff,z)
	elif (switch==3):
		data=M_spz(Sys,x_cutoff,z)
	elif (switch==4):
		data=M_ax(u_cutoff,x_cutoff,Sys,z)
	elif (switch==5):
		data=M_ay(u_cutoff,x_cutoff,Sys,z)
	elif (switch==6):
		data=M_az(u_cutoff,x_cutoff,Sys,z)
	elif (switch==7):
		data=Q(u_cutoff,x_cutoff,Sys,z)
	time_end=time.time()
	dt=(time_end-time_start)/60. #time elapsed for evaluation
	result=[eV,data[0],data[1],int(data[2]),int(data[3]),dt]
	values.append(result)
	counter+=1

##As soon as the values are obtained, we gather them in one file.
# comm.gather is collective: every rank contributes its `values` list and
# rank 0 receives them as a list ordered by rank (other ranks get None).
# Rank r scanned the r-th eV sub-interval, so this order is increasing eV.
all_values=comm.gather(values,root=0)
if (rank==0):
	NAME=func_dict[str(switch)]+'_eVmin='+str(eVmin)+'_eVmax='+str(eVmax)+'_theta='+str(theta)+'_z='+str(z)+'_Ez='+str(Ez)
	# key=6 and key=7 share theta=0 and differ only in T_down; without this
	# tag both runs would write (and overwrite) the same file.
	if (T_down!=T_up):
		NAME+='_Tdown='+str(T_down)
	with open(NAME+'.dat','w') as f:
		f.write("#Here goes the calculation of "+func_dict[str(switch)]+" as a function of eV\n")
		f.write("#In colums: eV, res, error, Nevals, Nwarn, dt(mins)\n")
		for rank_values in all_values:
			for res in rank_values:
				# one space-separated row per eV point, same layout as the header
				f.write(' '.join(map(str,res))+'\n')
