#!/usr/local/bin/pyXplor
# pylint: disable=import-error

import sys
import re

from simulationTools import StructureLoop, AnnealIVM
import xplorPotTools
from posDiffPotTools import create_PosDiffPot
from simulationTools import MultRamp, StaticRamp, InitialParams
from avePot import AvePot
from xplorPot import XplorPot
from complexationPot import OctoPot
import rdcPotTools
from potList import PotList
from atomSel import AtomSel
from pcsTools import calcXTensor
from varTensorTools import create_VarTensor, calcTensor, calcTensorOrientation
from ivm import IVM
import protocol

# Bath temperatures at the start and end of simulated annealing (used as
# `bathTemp`/`initTemp` and `finalTemp` in `anneal`).
init_temp = 1000
stop_temp = 100

# Before high-temperature dynamics the experimental (PCS) restraints are to
# be ramped in: `expramp_duration` is the `finalTime` of that first dynamics
# stage, and `expramp_pcs_target` is the final PCS scale used by
# `scalepcs_for_time`.
expramp_duration = 1
expramp_pcs_target = 40

# Before annealing, high-temperature dynamics run for `hightemp_duration`
# (the `finalTime` of the second dynamics stage).
hightemp_duration = 9

# Annealing: each temperature step is simulated for `annealing_timestep`
# (its `finalTime`, presumably in ps). `annealing_tempstep` is passed to
# AnnealIVM as `tempStep`. In Xplor-NIH that is the temperature decrement per
# step, not the number of steps -- confirm against the AnnealIVM docs.
annealing_timestep = 1
annealing_tempstep = 5

# Number of structures calculated by the StructureLoop.
n_structures = 2


# Tabulated susceptibility values per lanthanide, keyed by lower-case element
# symbol. By their names, `dXax`/`dXrh` are the axial/rhombic anisotropy
# components and `X` presumably the isotropic value. All are stored on a
# 1e-32 scale (presumably m^3 -- confirm against the literature source).
_X_COMPONENTS = {
    'ce': {'dXax': 2.1e-32, 'dXrh': 0.7e-32, 'X': 5.6e-32},
    'pr': {'dXax': 3.4e-32, 'dXrh': 2.1e-32, 'X': 11.2e-32},
    'nd': {'dXax': 1.7e-32, 'dXrh': 0.5e-32, 'X': 11.4e-32},
    'sm': {'dXax': 0.2e-32, 'dXrh': -0.1e-32, 'X': 0.6e-32},
    'eu': {'dXax': -2.3e-32, 'dXrh': -1.6e-32, 'X': 6.0e-32},
    'gd': {'dXax': 0e-32, 'dXrh': 0e-32, 'X': 55.1e-32},
    'tb': {'dXax': 42.1e-32, 'dXrh': 11.2e-32, 'X': 82.7e-32},
    'dy': {'dXax': 34.7e-32, 'dXrh': 20.3e-32, 'X': 99.2e-32},
    'ho': {'dXax': 18.5e-32, 'dXrh': 5.8e-32, 'X': 98.5e-32},
    'er': {'dXax': -11.6e-32, 'dXrh': -8.6e-32, 'X': 80.3e-32},
    'tm': {'dXax': -21.9e-32, 'dXrh': -20.1e-32, 'X': 50.0e-32},
    'yb': {'dXax': -8.3e-32, 'dXrh': -5.8e-32, 'X': 18.0e-32},
}


def get_X_components(ln):
    """Return the tabulated susceptibility components for lanthanide ``ln``.

    Args:
        ln: element symbol, case-insensitive (e.g. ``'Tb'`` or ``'tb'``).

    Returns:
        A new dict with keys ``'dXax'``, ``'dXrh'`` and ``'X'``. A fresh copy
        is returned on every call, so callers may modify it without affecting
        the shared table.

    Raises:
        KeyError: if ``ln`` is not one of the tabulated lanthanides.
    """
    try:
        components = _X_COMPONENTS[ln.lower()]
    except KeyError:
        known = ', '.join(sorted(_X_COMPONENTS))
        raise KeyError(
            f"unknown lanthanide {ln!r}; expected one of: {known}") from None
    return dict(components)


def print_header(s):
    """Print ``s`` as an 80-column banner, padded on both sides with ``=``.

    The title is surrounded by single spaces. When the padding is odd, the
    extra ``=`` goes on the right. Titles of 78 or more characters get no
    padding at all.
    """
    half, odd = divmod(78 - len(s), 2)
    left, right = "=" * half, "=" * (half + odd)
    print(left + f" {s} " + right)


# --------------------------------- load PDB --------------------------------- #

# Load custom parameters for the lanthanide ions (ion.par). Presumably this
# must precede loadPDB so the metal residue in lbt-cbm.pdb is parameterized.
protocol.initParams(files=["ion.par"])
protocol.loadPDB('lbt-cbm.pdb')

# ------------------------------ load PCS data ------------------------------- #
# The metal-ion atom (residue 200, atom TB+3). Every tensor summary printed by
# fmt_tensor reports this atom's position. Presumably indexing raises if the
# selection is empty.
lnAtom = AtomSel('resid 200 and name TB+3')[0]
# Container for all PCS potentials. It is appended as a single term to both
# the high- and low-temperature potential lists.
pcsPots = PotList('pcs')


def init_tensor(name, pos, ln):
    """Create VarTensor ``name`` seeded from the tabulated values for ``ln``.

    Da and Rh are set to the magnitudes of the tabulated ``dXax`` and
    ``dXrh`` values, scaled by 1e30. ``pos`` is accepted for interface
    compatibility but is not used.

    NOTE(review): abs() discards the sign of negative anisotropies (e.g. Er,
    Tm, Yb). Confirm this is intended, and confirm that VarTensor.setRh
    expects this value rather than a rhombicity ratio.
    """
    tensor = create_VarTensor(name)
    components = get_X_components(ln)
    scale = 1e30
    tensor.setDa(abs(components['dXax']) * scale)
    tensor.setRh(abs(components['dXrh']) * scale)
    return tensor


def fmt_tensor(t):
    """Return a one-line summary of tensor ``t``: Da, Rh and the metal position.

    The position shown is always that of the module-level ``lnAtom``, not one
    derived from ``t``.
    """
    fields = (
        "dXax: %7.4e" % t.Da(),
        "dXrh: %7.4e" % t.Rh(),
        "(x,y,z): %s" % lnAtom.pos(),
    )
    return ', '.join(fields)


def print_tensors():
    """Print one fmt_tensor summary line per PCS potential's tensor.

    Each potential's ``oTensor`` is formatted. This matches the tensor
    calculation loop, which also calls ``fmt_tensor`` on ``pot.oTensor``
    rather than on the potential itself.
    """
    print('\n'.join(fmt_tensor(pot.oTensor) for pot in pcsPots))


# Lanthanides whose PCS data sets are used together. One tensor and one PCS
# potential are created per entry, with restraints read from `<Ln>.tbl`.
# The commented-out alternatives record earlier trials: every combination of
# two or more lanthanides failed, while each lanthanide on its own worked.
# NOTE(review): the VarTensor objects created here are never passed to
# varTensorTools.topologySetup for the IVMs set up further down. Worth
# checking as a possible cause of the multi-lanthanide failures.
LANTHANIDES = ['Tb', 'Ho', 'Yb', 'Tm']  # -> bad
# LANTHANIDES = ['Tb', 'Ho', 'Yb']  # -> bad(few steps before fail)
# LANTHANIDES = ['Tb', 'Ho', 'Tm']  # -> bad (fails after 1 step)
# LANTHANIDES = ['Tb', 'Yb', 'Tm']  # -> bad (few steps before fail)
# LANTHANIDES = ['Ho', 'Yb', 'Tm']  # -> bad (fails instantly)
# LANTHANIDES = ['Tb', 'Ho']  # -> bad (few steps before fail)
# LANTHANIDES = ['Tb', 'Yb']  # -> bad (few steps before fail)
# LANTHANIDES = ['Tb', 'Tm']  # -> bad (few steps before fail)
# LANTHANIDES = ['Ho', 'Yb']  # -> bad (fails instantly)
# LANTHANIDES = ['Ho', 'Tm']  # -> bad (fails instantly)
# LANTHANIDES = ['Yb', 'Tm']  # -> bad (few steps before fail)
# LANTHANIDES = ['Tb']  # -> good
# LANTHANIDES = ['Ho']  # -> good
# LANTHANIDES = ['Yb']  # -> good
# LANTHANIDES = ['Tm']  # -> good
# (potential name, restraint table, seeded tensor) for each lanthanide. All
# tensors are created by init_tensor before any potential exists.
# init_tensor receives `lnAtom` but does not use it.
pcsSpecs = [(f'{ln}PCS', f'{ln}.tbl', init_tensor(f'{ln}PCS', lnAtom, ln))
            for ln in LANTHANIDES]

# Build one RDC-type potential per lanthanide on its own tensor and collect
# them in `pcsPots`. setUseDistance(True) presumably makes the potential
# distance dependent, as PCS data require (confirm in the rdcPot docs).
# The scale of 1 set here is not changed later in the code shown, because
# scalepcs_for_time is never called.
for (name, pcs_list, tensor) in pcsSpecs:
    pcsPot = rdcPotTools.create_RDCPot(name, pcs_list, oTensor=tensor)
    pcsPot.setUseDistance(True)
    pcsPot.setScale(1)
    pcsPot.setShowAllRestraints(True)
    pcsPot.setVerbose(True)
    pcsPots.append(pcsPot)

# Determine each PCS tensor with pcsTools.calcXTensor (presumably a fit to
# that potential's PCS data) and print the resulting summary line.
# NOTE(review): maxDisplacement=35 presumably bounds how far the tensor may
# move from its starting position. Confirm units and semantics in the
# pcsTools documentation.
for pot in pcsPots:
    print_header(f"{pot.instanceName()} tensor calculation")
    calcXTensor(pot.oTensor, maxDisplacement=35)
    print(fmt_tensor(pot.oTensor))

# Lanthanide coordination restraint (OctoPot). It ties the TB+3 ion of
# residue 200 to six ligand oxygens, given as three pairs of atom selections.
# optDist/distTol presumably give the target metal-oxygen distance and its
# tolerance, and distScale/angleScale weight the distance and angle terms.
# Confirm against the complexationPot documentation.
tbCoord = OctoPot('TbCoord',
                  'resid 200 and name TB+3',
                  [['(resid 14 and name OE2)', '(resid  7 and name OD2)'],
                   ['(resid  3 and name OD2)', '(resid 11 and name OE2)'],
                   ['(resid  9 and name O)  ', '(resid  5 and name OD1)']],
                  optDist=2.48,
                  distTol=0.26,
                  distScale=2,
                  angleScale=1.5)

# ---------------------------- create potentials ----------------------------- #

# Repulsive (VDW-like) nonbonded term.
protocol.initNBond(
    nbxmod=3,  # only use potential when 3 or more bonds between atoms
    repel=0.9,  # starting repel radius setting; radRamp later resets it
)

# Position-difference (RMSD) restraint to the XRD-CBM backbone heavy atoms
# (N, CA, C of residues 20-157), with reference coordinates from
# lbt-cbm.pdb.
# NOTE(review): `selection` covers residues 20:157 but `cmpSel` covers
# 20:162. Confirm the different ranges are intended.
xrdRMSD = create_PosDiffPot(
    "xrdRMSD",
    # selection for LBT-CBM
    selection="(name CA or name C or name N) and (residue 20:157)",
    # comparison for additional, non-energy RMSD
    cmpSel="(not name H) and (resid 20:162)",
    pdbFile="lbt-cbm.pdb")

# Dihedral-angle restraints read from all-dihedrals.tbl. The potential lists
# below label this CDIH term as taken from the minimized XRD structure.
protocol.initDihedrals('all-dihedrals.tbl', scale=1, useDefaults=False)

# Parameters ramped from a start to an end value during annealing. They are
# collected in `ramps`, reset by InitialParams and passed to AnnealIVM in
# `anneal`. Each command string presumably gets VALUE replaced by the current
# ramp value (MultRamp presumably changes it geometrically -- confirm).
# The `potList` named in the ANGL/IMPR commands is the module-level
# low-temperature potential list.
radRamp = MultRamp(0.4, 0.8,
                   "xplor.command('param nbonds repel VALUE end end')")
angRamp = MultRamp(0.4, 1.0, "potList['ANGL'].setScale(VALUE)")
impRamp = MultRamp(0.1, 1.0, "potList['IMPR'].setScale(VALUE)")
vdwRamp = MultRamp(0.004, 4,
                   "xplor.command('param nbonds rcon VALUE end end')")


def scalepcs_for_time(t):
    """Set every PCS potential's scale for ramp time ``t`` and return it.

    The scale rises linearly from 0 at ``t == 0`` to ``expramp_pcs_target``
    at ``t == expramp_duration``. Outside that range it is clamped to
    [0, expramp_pcs_target] (assumes a non-negative target). A non-positive
    ``expramp_duration`` means "no ramp": the target scale is applied
    immediately instead of dividing by zero.

    Args:
        t: elapsed time into the ramp, in the same units as
            ``expramp_duration``.

    Returns:
        The scale applied to each potential in ``pcsPots``.
    """
    if expramp_duration <= 0:
        scale = expramp_pcs_target
    else:
        scale = expramp_pcs_target * t / expramp_duration
        # Clamp: never below 0 (t < 0) nor above the target (t > duration).
        scale = max(0, min(scale, expramp_pcs_target))
    for p in pcsPots:
        p.setScale(scale)
    return scale


# Common scaling: the same xrdRMSD/tbCoord objects are appended to both
# potential lists below, so these scales apply in every stage.
xrdRMSD.setScale(1.5)
tbCoord.setScale(1.5)

# High-temperature potential list: the same terms as `potList` below but
# without VDW, i.e. no nonbonded repulsion during the first two dynamics
# stages of `anneal`.
hightempPotList = PotList()
hightempPotList.append(XplorPot('BOND'))
hightempPotList.append(XplorPot('ANGL'))
hightempPotList.append(XplorPot('IMPR'))
hightempPotList.append(XplorPot('CDIH'))  # taken from minimized XRD
hightempPotList.append(xrdRMSD)  # taken from minimized XRD
hightempPotList.append(tbCoord)  # lanthanide coordination potential
hightempPotList.append(pcsPots)  # experimental restraints

# Low-temperature potential list (annealing and minimization): every term,
# including VDW. Also passed to StructureLoop as averagePotList.
potList, ramps = PotList(), []
potList.append(XplorPot('BOND'))
potList.append(XplorPot('ANGL'))
potList.append(XplorPot('IMPR'))
potList.append(XplorPot('VDW'))
potList.append(XplorPot('CDIH'))  # taken from minimized XRD
potList.append(xrdRMSD)  # taken from minimized XRD
potList.append(tbCoord)  # lanthanide coordination potential
potList.append(pcsPots)  # experimental restraints

# Replaces the empty `ramps` assigned above. These are the parameters reset
# by InitialParams and ramped by AnnealIVM in `anneal`.
ramps = [radRamp, angRamp, impRamp, vdwRamp]

# ------------------------ simulated annealing setup ------------------------- #

# Two IVM integrators:
# - `cartesianDynamics` (Cartesian topology) is used only for the final
#   Cartesian minimization in `anneal`.
# - `torsionDynamics` (torsion-angle topology) is used for all dynamics
#   stages and the torsion minimization.
# NOTE(review): the PCS VarTensor objects are not registered with either IVM
# here. Xplor-NIH provides varTensorTools.topologySetup(ivm, tensors),
# presumably called before the *Topology call. Unregistered tensor
# pseudo-atoms may explain the failures seen with more than one lanthanide
# -- confirm.
cartesianDynamics = IVM()
protocol.cartesianTopology(cartesianDynamics)
torsionDynamics = IVM()
protocol.torsionTopology(torsionDynamics)


def anneal(loopInfo):
    """Compute and write one structure; called by StructureLoop per structure.

    ``loopInfo`` is the per-structure loop object. It is used here only to
    write the final structure.

    Stages (torsion space except the final Cartesian minimization):
      1. ``expramp_duration`` of dynamics at ``init_temp`` on
         ``hightempPotList``. NOTE(review): despite its header, this stage
         does not change the PCS scale. ``scalepcs_for_time`` is never called
         in this file, so the PCS potentials keep the scale set when they were
         created.
      2. ``hightemp_duration`` of dynamics at ``init_temp`` on
         ``hightempPotList`` (no VDW term).
      3. AnnealIVM from ``init_temp`` to ``stop_temp`` on ``potList``, with
         ``ramps`` as ramped parameters.
      4. Torsion minimization, then Cartesian minimization on ``potList``.
    """
    global ramps  # `ramps` is only read below; this declaration is not needed

    # Reset every ramped parameter (repel, ANGL/IMPR scales, rcon) to its
    # initial value.
    InitialParams(ramps)
    # ---------------- ramp up experimental restraints ---------------- #
    # NOTE(review): no PCS scale change happens in this stage (see docstring).
    print_header("Ramping up PCS restraints")
    protocol.initDynamics(torsionDynamics,
                          potList=hightempPotList,
                          bathTemp=init_temp,
                          initVelocities=1,
                          finalTime=expramp_duration,
                          printInterval=100)
    # Presumably the allowed energy change for the adaptive time step --
    # confirm against the IVM documentation.
    torsionDynamics.setETolerance(init_temp / 100)
    torsionDynamics.run()

    # ---------------- high temp dynamics ---------------- #
    print_header("Starting high temp dynamics")
    # `xplor` is not imported in this file. It is presumably provided as a
    # builtin by the pyXplor interpreter, hence the pylint disables.
    # NOTE(review): InitialParams(ramps) right after these commands
    # re-applies the initial values of radRamp (repel) and vdwRamp (rcon,
    # starting at 0.004), and hightempPotList has no VDW term. These two
    # commands therefore look redundant.
    xplor.command(  # pylint: disable=undefined-variable
        "parameters nbonds repel %f end end" % radRamp.value())
    xplor.command(  # pylint: disable=undefined-variable
        "parameters nbonds rcon  %f end end" % 0.004)
    InitialParams(ramps)
    protocol.initDynamics(torsionDynamics,
                          potList=hightempPotList,
                          bathTemp=init_temp,
                          initVelocities=1,
                          finalTime=hightemp_duration,
                          printInterval=100)
    torsionDynamics.setETolerance(init_temp / 100)
    torsionDynamics.run()

    # ---------------- annealing ---------------- #
    print_header("Starting annealing")
    # Configure torsionDynamics with the full potList. AnnealIVM below then
    # presumably runs it for `finalTime` at each temperature step.
    protocol.initDynamics(
        torsionDynamics,
        potList=potList,
        bathTemp=init_temp,
        initVelocities=1,
        finalTime=annealing_timestep,
        printInterval=500,
        stepsize=1e-3  # initial integration time step (presumably ps)
    )

    # Cool from init_temp to stop_temp with tempStep=annealing_tempstep,
    # updating the parameters in `ramps` along the way.
    AnnealIVM(initTemp=init_temp,
              finalTemp=stop_temp,
              tempStep=annealing_tempstep,
              ivm=torsionDynamics,
              rampedParams=ramps).run()

    # ---------------- final polish ---------------- #
    print_header("Final torsion minization")
    # No potList passed. Presumably torsionDynamics keeps the potList set by
    # the annealing initDynamics call above -- confirm.
    protocol.initMinimize(torsionDynamics, printInterval=100)
    torsionDynamics.run()

    print_header("Final cartesic minization")
    protocol.initMinimize(cartesianDynamics,
                          potList=potList,
                          printInterval=100)
    cartesianDynamics.run()

    # ---------------- generate output ---------------- #
    print_header("Writing final structure")
    # Write this structure through the loop object, passing potList
    # (presumably for the energy/violation report written with it).
    loopInfo.writeStructure(potList)


# -------------------------------- now run it -------------------------------- #

# Run `anneal` once for each of `n_structures` structures. `anneal` writes
# each one through loopInfo.writeStructure; 'STRUCTURE.pdb' is the output
# filename template (presumably expanded per structure). Violation statistics
# and averages are generated over `potList`.
StructureLoop(numStructures=n_structures,
              pdbTemplate='STRUCTURE.pdb',
              structLoopAction=anneal,
              genViolationStats=1,
              averagePotList=potList).run()

# Write the coordinates left in memory after the loop (presumably those of
# the last calculated structure) to final.pdb.
protocol.writePDB('final.pdb')
