
xplor.requireVersion("2.41")


#
# Docking/refinement of the nprSel domain (resid 1:85) against the einSel
# domain (resid 170:424) using PCS, RDC, NOE, chemical-shift-mapping (csMap)
# and dihedral restraints, followed by torsion-angle simulated annealing.
# The script grew out of a protein G slow-cooling (torsion-angle annealing)
# template.
#
# CDS 2009/07/24
#

# parseArguments reports unrecognized command-line options (catching typos);
# the list names the script-specific options that are accepted.
#
(opts,args) = xplor.parseArguments(["quick"])

# -quick: run a reduced calculation just to check that the script works.
quick = any(opt[0] == "quick" for opt in opts)


# Output filename template. The STRUCTURE literal is required so that every
# calculated structure gets a distinct file name; the optional SCRIPT literal
# is replaced by this script's filename (or stdin when input is redirected
# with <).
#
pdbTemplate = "SCRIPT_STRUCTURE.sa"
numberOfStructures = 3 if quick else 100  # normally create at least 20
numDock = 5     # number of docking attempts per structure before refinement

# protocol module has many high-level helper functions.
#
import protocol

protocol.initRandomSeed()   #set random seed - by time

# Short alias: the nonbonded ramp strings further down are evaluated as
# "command('param nbonds ... end end')".
command = xplor.command


# Atom selections used throughout. einSel is fixed in the IVM setups (fully
# in dynRigid, backbone only in dyn); nprSel is the domain that is randomly
# displaced before each docking attempt.
einSel="resid 170:424"
nprSel="resid 1:85"

# Build the molecular structure and read starting coordinates.
# NOTE(review): nprTagged.pdb is read first and refineEIN_58.sa.best is loaded
# afterwards - presumably to replace the EIN coordinates with a refined model;
# deleteUnknownAtoms=True presumably drops atoms not present in the PSF.
# Confirm against the Xplor-NIH protocol documentation.
protocol.initStruct('param/new_tsap.psf')
protocol.initCoords('nprTagged.pdb')
protocol.loadPDB('refineEIN_58.sa.best',deleteUnknownAtoms=True)

# Force-field parameter file.
protocol.initParams("./param/tsap_gen.par")


#
# a PotList contains a list of potential terms. This is used to specify which
# terms are active during refinement.
#
from potList import PotList
potList = PotList()

# parameters to ramp up during the simulated annealing protocol
#
from simulationTools import MultRamp, StaticRamp, InitialParams, FinalParams

# rampedParams: ramped by the AnnealIVM cooling loop and reset/finalized via
#   InitialParams/FinalParams inside calcOneStructure.
# highTempParams: applied after rampedParams for the docking stage and the
#   high-temperature dynamics stage.
rampedParams=[]
highTempParams=[]

# compare atomic Cartesian rmsd with a reference structure
#  backbone and heavy atom RMSDs will be printed in the output
#  structure files  (currently disabled)
#
#from posDiffPotTools import create_PosDiffPot
#refRMSD = create_PosDiffPot("refRMSD",
#                            "(name CA or name C or name N)",
#                            pdbFile='start_ein.pdb')
#refRMSD.setUpperBound(2) # difference in angstrom
#potList.append(refRMSD)


# PCS restraints (presumably lanthanide pseudocontact shifts, given the "Yb"
# data files), implemented as RDC potentials with distance dependence turned
# on (setUseDistance). All PCS terms live in the 'pcs' PotList.
pcs = PotList('pcs')
potList.append( pcs )
from rdcPotTools import create_RDCPot
from varTensorTools import create_VarTensor
# One tensor shared by every PCS data set. setDaMax raises the allowed Da
# ceiling (presumably the default is too small for PCS magnitudes); Rh and Da
# are flagged as fixed via setFreedom.
oTensor_tsap = create_VarTensor("oTensor")
oTensor_tsap.setDaMax(1e5)
oTensor_tsap.setFreedom('fixRh, fixDa')

# (name, tensor, restraint table); each potential is registered as
# 'pcs-' + name.  The loop variable `file` shadows the Python 2 builtin.
# NOTE(review): the EIN HN table is spelled "e1n" while the commented-out N
# entries use "ein" - confirm the path is the intended one.
for (name, oTensor, file) in [
#    ('tsap-ein-n' , oTensor_tsap, "data/PCS_N_e45c_Yb_major_ein_only.tbl"),
#    ('tsap-npr-n' , oTensor_tsap, "data/PCS_N_e45c_Yb_major_npr_only.tbl"),
    ('tsap-ein-hn', oTensor_tsap, "data/PCS_HN_e45c_Yb_major_e1n_only.tbl"),
    ('tsap-npr-hn', oTensor_tsap, "data/PCS_HN_e45c_Yb_major_npr_only.tbl"),
                               ]:
    pot = create_RDCPot('pcs-'+name,file,oTensor=oTensor)
    pot.setUseDistance(True)
    pot.setAveType('sum')
    pot.setScale(100)               # per-data-set weight; the 'pcs' list
                                    # weight is ramped separately below
    pot.setShowAllRestraints(True)
    pcs.append(pot)
    pass

# Initial fit of the shared tensor at the starting coordinates (presumably
# against all data sets attached to it).
from varTensorTools import calcTensor
calcTensor(oTensor_tsap)

# Overall PCS weight ramps from 1 to 150. The StaticRamp refits the tensor
# using only the 'pcs-tsap-npr-hn' data set whenever the ramped parameters are
# applied (InitialParams/FinalParams and, presumably, each annealing step).
rampedParams.append( MultRamp(1,150,"pcs.setScale(VALUE)") )
rampedParams.append(
    StaticRamp("calcTensor(oTensor_tsap,expts=[ pcs['pcs-tsap-npr-hn'] ])") )


# orientation Tensor - used with the dipolar coupling term
#  one for each medium
#   For each medium, specify a name, and initial values of Da, Rh.
#
from varTensorTools import create_VarTensor   # (re-import; harmless)
media={}
#                        medium  Da   rhombicity
for (medium,Da,Rh) in [ ('m1',  10.0, 0.57),
                        ('m2',   9.1, 0.57),
                        ]:
    oTensor = create_VarTensor(medium)
    oTensor.setDa(Da)
    oTensor.setRh(Rh)
    media[medium] = oTensor
    pass

# m2 is presumably tied to m1's axis orientation and rhombicity with its own
# Da held fixed; m1 has Rh and Da fixed.
# NOTE(review): the m2 RDC table in the RDC setup is commented out, so m2
# currently has no restraints attached.
media['m2'].setFreedom('fixAxisTo m1, fixRhTo m1, fixDa')
media['m1'].setFreedom('fixRh, fixDa')



# dipolar coupling restraints for protein amide NH.
#
# collect all RDCs in the rdcs PotList
#
# RDC scaling. Three possible contributions.
#   1) gamma_A * gamma_B / r_AB^3 prefactor. So that the same Da can be used
#      for different expts. in the same medium. Sometimes the data is
#      prescaled so that this is not needed. scale_toNH() is used for this.
#      Note that if the expt. data has been prescaled, the values for rdc rmsd
#      reported in the output will be relative to the scaled values - not the
#      expt. values.
#   2) expt. error scaling. A scale factor equal to 1/err^2 (relative to that
#      for NH) is used; the per-table weight column below carries it.
#   3) sometimes the reciprocal of the Da^2 is used if there is a large
#      spread in Da values. Not used here.
#
from rdcPotTools import create_RDCPot, scale_toNH
rdcs = PotList('rdc') 
# (medium, experiment name, restraint table, weight); each potential is named
# "<medium>_<expt>". Only the m1 table is currently active.
for (medium,expt,file,                 scale) in \
    [('m1','NH_e1n' , 'data/e1n_rdc_no_overlap_051415.tbl'            ,1),
#     ('m2','NH_npr' , 'data/npr_rdc_062915_newshifts.tbl'            ,1),
     ]:
    rdc = create_RDCPot("%s_%s"%(medium,expt),file,media[medium])

    #1) scale prefactor relative to NH
    #   see python/rdcPotTools.py for exact calculation
    # scale_toNH(rdc) - not needed for these datasets -
    #                        but non-NH reported rmsd values will be wrong.

    #3) Da rescaling factor (separate multiplicative factor)
    # scale *= ( 1. / rdc.oTensor.Da(0) )**2
    rdc.setScale(scale)
    rdc.setShowAllRestraints(1) #all restraints are printed during analysis
    rdc.setThreshold(1.5)       # in Hz
    rdcs.append(rdc)
    pass
potList.append(rdcs)
# overall RDC weight ramps from 0.05 to 5.0 during annealing
rampedParams.append( MultRamp(0.05,5.0, "rdcs.setScale( VALUE )") )

# calc. initial tensor orientation
# (the calcTensor StaticRamp is commented out, so these tensors are not refit
#  by calcTensor during simulated annealing)
# NOTE(review): this also runs for media['m2'], which has no active RDC table;
# confirm calcTensorOrientation handles a tensor without data.
from varTensorTools import calcTensorOrientation, calcTensor
for medium in media.keys():
    calcTensorOrientation(media[medium])
#    rampedParams.append( StaticRamp("calcTensor(media['%s'])" % medium) )
    pass

# set up NOE potential
noe=PotList('noe')
potList.append(noe)
from noePotTools import create_NOEPot
# (name, weight, restraint table)
for (name,scale,file) in [
    ('unambig',1,"data/15N_NOESY_EIN_unambig_061815_20160517.tbl"),
                          #add entries for additional tables
                          ]:
    pot = create_NOEPot(name,file)
    # pot.setPotType("soft") # if you think there may be bad NOEs
    pot.setScale(scale)
    noe.append(pot)
# overall NOE weight ramps from 2 to 30 during annealing
rampedParams.append( MultRamp(2,30, "noe.setScale( VALUE )") )

# csMap: restraint table read through the NOE potential (presumably
# chemical-shift-mapping derived interface restraints). It is used on its own
# during docking and in the docking score; its weight comes only from the
# ramp below (0.01 -> 30) once InitialParams(rampedParams) is applied.
csMap = create_NOEPot('csMap','data/shifts_noe_CLORE_20160517.tbl')
potList.append(csMap)
# csMap.setPotType("soft") # if you think there may be bad restraints
rampedParams.append( MultRamp(0.01,30, "csMap.setScale( VALUE )") )



## set up J coupling - with Karplus coefficients  (currently disabled)
#from jCoupPotTools import create_JCoupPot
#jCoup = create_JCoupPot("jcoup","jna_coup.tbl",
#                        A=6.98,B=-1.38,C=1.72,phase=-60.0)
#potList.append(jCoup)

# Set up dihedral angles - EIN only.
# NOTE(review): the table name contains "COMPLEX"; confirm it really holds
# only EIN restraints.
from xplorPot import XplorPot
protocol.initDihedrals("data/TALOS_COMPLEX_051815_20160517.tbl",
                       #useDefaults=False  # by default, symmetric sidechain
                                           # restraints are included
                       )
potList.append( XplorPot('CDIH') )
# CDIH weight: 10 for docking/high-temperature stages (highTempParams is
# applied after rampedParams), 200 during cooling and final minimization.
highTempParams.append( StaticRamp("potList['CDIH'].setScale(10)") )
rampedParams.append( StaticRamp("potList['CDIH'].setScale(200)") )
# set custom values of threshold values for violation calculation
#
potList['CDIH'].setThreshold( 5 ) #5 degrees is the default value, though



# gyration volume term
# (no selection argument is passed, so it presumably applies to all atoms)
#
from gyrPotTools import create_GyrPot
gyr = create_GyrPot("Vgyr",
#                    "resid 1:56" # selection should exclude disordered tails
                    )
potList.append(gyr)
# weight ramps from .002 to 1 during annealing
rampedParams.append( MultRamp(.002,1,"gyr.setScale(VALUE)") )


#contact term
# Residue-affinity contact potential. It has no ramp, so its weight stays at
# the value set by create_ResidueAffPot. It is also included in the second
# docking minimization and in the docking score.
from residueAffPotTools import create_ResidueAffPot
contact = create_ResidueAffPot("contact")
potList.append(contact)

# hbdb - knowledge-based backbone hydrogen bond term (weight not ramped)
#
protocol.initHBDB()
potList.append( XplorPot('HBDB') )

#New torsion angle database potential
# weight ramps from .002 to 2 during annealing
#
from torsionDBPotTools import create_TorsionDBPot
torsionDB = create_TorsionDBPot('torsionDB')
potList.append( torsionDB )
rampedParams.append( MultRamp(.002,2,"torsionDB.setScale(VALUE)") )

#
# setup parameters for atom-atom repulsive term. (van der Waals-like term)
# Cooling stage: nonbonded setup with nbxmod=4; the XPLOR nbonds 'repel'
# value ramps 0.9 -> 0.8 and 'rcon' ramps .004 -> 4 (presumably the radius
# scale factor and force constant, respectively).
#
vdw = XplorPot('VDW')
potList.append( vdw )
rampedParams.append( StaticRamp("protocol.initNBond(nbxmod=4)") )
rampedParams.append( MultRamp(0.9,0.8,
                              "command('param nbonds repel VALUE end end')") )
rampedParams.append( MultRamp(.004,4,
                              "command('param nbonds rcon VALUE end end')") )
# docking / high-temperature stage:
# nonbonded interaction only between CA atoms
highTempParams.append( StaticRamp("""protocol.initNBond(cutnb=100,
                                                        rcon=0.004,
                                                        tolerance=45,
                                                        repel=1.2,
                                                        onlyCA=1)""") )


# covalent geometry terms; ANGL and IMPR weights are ramped up to 1 and use a
# violation threshold of 5.
potList.append( XplorPot("BOND") )
potList.append( XplorPot("ANGL") )
potList['ANGL'].setThreshold( 5 )
rampedParams.append( MultRamp(0.4,1,"potList['ANGL'].setScale(VALUE)") )
potList.append( XplorPot("IMPR") )
potList['IMPR'].setThreshold( 5 )
rampedParams.append( MultRamp(0.1,1,"potList['IMPR'].setScale(VALUE)") )
      


# Give atoms uniform weights, except for the anisotropy axis
# (template comment; presumably refers to the tensor axis atoms created by
#  create_VarTensor above)
#
protocol.massSetup()

from ivm import IVM
dynRigid = IVM() # used for initial rigid-body docking

# einSel fully fixed; nprSel moved as a single rigid body.
# NOTE(review): dynRigid is not referenced anywhere else in this file -
# calcOneStructure docks with `dyn`. Confirm which IVM docking should use.
dynRigid.fix(einSel)
dynRigid.group(nprSel)

protocol.torsionTopology(dynRigid)


dyn = IVM()   # this used for refinement

# EIN backbone fixed, NPr backbone moved as one rigid group, residue 45 kept
# rigid; the remaining degrees of freedom (presumably side-chain torsions) are
# set up by torsionTopology.
dyn.fix("(%s) and (name C or name CA or name N)" % einSel)
dyn.group("(%s) and (name C or name CA or name N)" % nprSel)
dyn.group("resid 45") # tag sidechain should be rigid
protocol.torsionTopology(dyn)

# Cooling schedule: from init_t down to 25 in steps of 25 (presumably kelvin),
# ramping every entry of rampedParams along the way.
from simulationTools import AnnealIVM
init_t  = 3000
cool = AnnealIVM(initTemp =init_t,
                 finalTemp=25,
                 tempStep =25,
                 ivm=dyn,
                 rampedParams = rampedParams)





def calcOneStructure(loopInfo):
    """Calculate a single structure: dock, anneal and minimize.

    Called by StructureLoop once per output structure; StructureLoop performs
    the analysis and writes the PDB file (with remarks) after this returns.

    Stages:
      1. Docking - numDock attempts. Each attempt restores the starting
         coordinates and the initial/high-temperature parameter values,
         randomly displaces the nprSel domain, then minimizes against
         csMap + vdw, then against pcs + csMap + vdw + contact. The pose with
         the lowest csMap + vdw + contact energy is kept.
      2. High-temperature torsion-angle dynamics with the full potList.
      3. Simulated annealing (``cool``) while ramping rampedParams.
      4. Final torsion-angle minimization at the final ramped values.

    loopInfo -- per-structure information supplied by StructureLoop (unused).
    """
    from atomAction import randomizeDomainPos

    startPos  = xplor.simulation.atomPosArr()
    # Fallback so minPos is always bound (e.g. when numDock <= 0).
    minPos    = startPos
    minEnergy = 1e30  # big number

    # Energy used to rank docking attempts; loop-invariant, so build it once.
    scorePot = PotList()
    for term in (csMap,vdw,contact):
        scorePot.append(term)
        pass

    for k in range(numDock):
        xplor.simulation.setAtomPosArr(startPos)

        # Start every attempt from the same parameter state. Applying these
        # only once before the loop let the InitialParams(rampedParams) call
        # below leak into later attempts, which then docked with full
        # all-atom repulsion instead of the CA-only high-temperature setup.
        InitialParams( rampedParams )
        InitialParams( highTempParams )

        randomizeDomainPos( nprSel, deltaPos=45 )

        # coarse docking against csMap and vdw (high-temperature nonbonded
        # settings)
        protocol.initMinimize(dyn,
                              potList=[csMap,vdw],
                              numSteps=1000)
        dyn.run()
        dyn.setVerbose( 0 )

        # refine the pose with PCS and contact terms at the initial ramp values
        InitialParams( rampedParams )
        protocol.initMinimize(dyn,
                              potList=[pcs,csMap,vdw,contact],
                              numSteps=1000)
        dyn.run()

        scoreEnergy = scorePot.calcEnergy()
        print('dock iteration: %d   score energy: %.2f' % (k,scoreEnergy))

        if scoreEnergy < minEnergy:
            minPos    = xplor.simulation.atomPosArr()
            minEnergy = scoreEnergy
            pass
        pass

    xplor.simulation.setAtomPosArr(minPos)

    InitialParams( rampedParams )
    InitialParams( highTempParams )

    # high-temperature dynamics
    protocol.initDynamics(dyn,
                          potList=potList,
                          bathTemp=init_t,
                          initVelocities=1,
                          finalTime=800,   # stops at 800ps or 8000 steps
                          numSteps=8000,   # whichever comes first
                          printInterval=100)

    dyn.setETolerance( init_t/100 )  #used to det. stepsize. default: t/1000
    dyn.run()

    # initialize parameters for cooling loop
    InitialParams( rampedParams )

    # initialize integrator for simulated annealing
    #
    protocol.initDynamics(dyn,
                          potList=potList,
                          numSteps=100,       #at each temp: 100 steps or
                          finalTime=.2 ,      # .2ps, whichever is less
                          printInterval=100)

    # perform simulated annealing
    #
    cool.run()

    # set ramped parameters to their final values for the last minimization
    FinalParams( rampedParams )

    # final torsion angle minimization
    #
    protocol.initMinimize(dyn)
    dyn.run()

    #do analysis and write structure after return
    pass



from simulationTools import StructureLoop

# Run calcOneStructure numberOfStructures times, writing each result through
# pdbTemplate. calcMissingStructs=True presumably computes only structures
# whose output files do not yet exist. Violation statistics are generated, and
# an average over the top 10% of structures is computed after fitting on the
# nprSel CA atoms, without regularization. FinalParams(rampedParams) is
# evaluated right here, before the loop object is constructed.
structureLoop = StructureLoop(numStructures=numberOfStructures,
                              pdbTemplate=pdbTemplate,
                              structLoopAction=calcOneStructure,
                              doWriteStructures=True,
                              calcMissingStructs=True,
                              genViolationStats=True,
                              averageRegularize=False,
                              averagePotList=potList,
#                              averageCrossTerms=refRMSD,
                              averageFitSel="name CA and (%s)" % nprSel,
                              averageTopFraction=0.1, # top 10% of structs
                              averageContext=FinalParams(rampedParams))
structureLoop.run()



