# ---
# jupyter:
#   jupytext:
#     text_representation:
#       extension: .py
#       format_name: light
#       format_version: '1.5'
#       jupytext_version: 1.16.6
#   kernelspec:
#     display_name: Python 3 (ipykernel)
#     language: python
#     name: python3
# ---

# # Current visualization
# In this notebook we address current visualization in a 3D (nanowire) system using Kwant.
# If the ordering of the hoppings in the 2D system differs from the ordering in the corresponding cut through the 3D system, `kwant.plotter.current` gives incorrect results. Here we give methods to obtain the correct ordering and to check whether the ordering is right.

import kwant
import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D

# +
L = 15  # size parameter of the wire along z
W = 10  # size parameter of the cross-section along x and y


def shape(site):
    """Return whether ``site`` lies inside the rectangular nanowire region.

    The region is ``(-W)//2 <= x, y <= W//2`` and ``(-L)//2 <= z <= L//2``.
    Unary minus binds tighter than ``//``, so for odd ``L`` the z-range is
    asymmetric: ``(-15)//2 == -8`` while ``15//2 == 7``.
    """
    x, y, z = site.pos
    lo_w, hi_w = (-W) // 2, W // 2
    lo_l, hi_l = (-L) // 2, L // 2
    inside_x = lo_w <= x <= hi_w
    inside_y = lo_w <= y <= hi_w
    inside_z = lo_l <= z <= hi_l
    return inside_x and inside_y and inside_z

def sysf_FCC_structure():
    """Build and finalize the nanowire from an eight-site 2x2x2 cubic cell.

    The bulk template has period 2 along x, y and z, and all eight corners
    of the 2x2x2 cell are occupied explicitly. Despite the name, the result
    is therefore geometrically a simple cubic lattice. The notebook uses it
    to compare hopping orderings with ``sysf_cubic_structure``.

    Returns
    -------
    The finalized nanowire, with two leads along z (one and its reverse).
    """
    lat = kwant.lattice.cubic(norbs=1)
    syst = kwant.Builder(kwant.TranslationalSymmetry((2,0,0), (0,2,0), (0,0,2)))

    # Keep this explicit insertion order unchanged: it may influence the
    # hopping ordering of the finalized system, which this notebook examines.
    syst[lat(0, 0, 0)] = 0
    syst[lat(0, 1, 0)] = 0
    syst[lat(0, 0, 1)] = 0
    syst[lat(1, 0, 0)] = 0
    syst[lat(0, 1, 1)] = 0
    syst[lat(1, 0, 1)] = 0
    syst[lat(1, 1, 0)] = 0
    syst[lat(1, 1, 1)] = 0

    syst[lat.neighbors(1)] = 1
    syst[lat.neighbors(2)] = 0.5

    nanowire = kwant.Builder()
    nanowire.fill(syst, shape, (0, 0, 0))

    # Lead along z built from the same bulk template.
    lead_0 = kwant.Builder(kwant.TranslationalSymmetry((0, 0, 2)))
    lead_0.fill(syst, shape, start=(0, 0, 0))
    nanowire.attach_lead(lead_0)
    nanowire.attach_lead(lead_0.reversed())

    # Only the nanowire is finalized; attach_lead works on the lead builder,
    # and a separately finalized lead was never used.
    return nanowire.finalized()

def sysf_cubic_structure():
    """Build and finalize the nanowire from a primitive (one-site) cubic cell.

    The bulk template has period 1 along x, y and z with a single site per
    cell. First- and second-neighbor hoppings are set to 1 and 0.5.

    Returns
    -------
    The finalized nanowire, with two leads along z (one and its reverse).
    """
    lat = kwant.lattice.cubic(norbs=1)
    syst = kwant.Builder(kwant.TranslationalSymmetry((1,0,0), (0,1,0), (0,0,1)))

    syst[lat(0, 0, 0)] = 0
    syst[lat.neighbors(1)] = 1
    syst[lat.neighbors(2)] = 0.5

    nanowire = kwant.Builder()
    nanowire.fill(syst, shape, (0, 0, 0))

    # Lead along z (period 2, matching the FCC-built wire's lead).
    lead_0 = kwant.Builder(kwant.TranslationalSymmetry((0, 0, 2)))
    lead_0.fill(syst, shape, start=(0, 0, 0))
    nanowire.attach_lead(lead_0)
    nanowire.attach_lead(lead_0.reversed())

    # Only the nanowire is finalized; attach_lead works on the lead builder,
    # and a separately finalized lead was never used.
    return nanowire.finalized()


# +
sysf_FCC_build = sysf_FCC_structure()
sysf_cubic_build = sysf_cubic_structure()

print(sysf_FCC_build, sysf_cubic_build, sep="\n")

# Are the sites identical *and* identically ordered?
print(list(sysf_FCC_build.sites) == list(sysf_cubic_build.sites))
# Are the hoppings (graph edges) identical *and* identically ordered?
print(list(sysf_FCC_build.graph) == list(sysf_cubic_build.graph))
# -

# These systems look exactly the same: they contain the same sites in the same order, and only the hoppings are ordered differently. This leads to a wrong current-visualization plot, because the `J` values get mapped onto the wrong hoppings of the system.
# Sorting the hoppings manually shows that the hoppings are indeed the same, only in a different order.

# +
# Sorting the (from, to) edge tuples removes any dependence on ordering.
fcc_hoppings_sorted = sorted(sysf_FCC_build.graph)
cubic_hoppings_sorted = sorted(sysf_cubic_build.graph)

print(fcc_hoppings_sorted == cubic_hoppings_sorted)
# -


# ## Longitudinal transport

# ### Cubic system

# Let's build a nanowire, first using the simple cubic structure, and compute its scattering wave function in the scattering region for states coming from lead 0 at energy $E = 0$.

# +
fig = plt.figure()
ax = fig.add_subplot(projection='3d')
ax.set(xlabel='x', ylabel='y', zlabel='z')
ax.set_box_aspect(None, zoom=0.85)

# NOTE(review): kwant.plot may skip its own show() when an existing `ax` is
# passed; the explicit plt.show() below ensures the figure is displayed.
kwant.plot(sysf_cubic_build, ax=ax, num_lead_cells=4, show=True)
plt.show()
# -

# Scattering wave functions of the cubic nanowire at E = 0.
wf = kwant.wave_function(sysf_cubic_build, energy=0)
# States incoming from lead 0; iterated one state at a time below.
wf_from_lead_0 = wf(0)


# Current visualization for a 3D system is not supported by Kwant. To still get an impression of how the current is distributed through the nanowire, we can try to visualize the current on a single facet (which is 2D).

# +
def longitudinal_intra_layer_cut(site_to, site_from):
    """Select hoppings whose two sites both lie in the facet ``y == W//2``.

    Used as the ``where`` filter of ``kwant.operator.Current``. For a
    finalized builder, kwant passes the two ``Site`` objects of each hopping.

    Positions are compared with ``np.isclose`` instead of ``==``, so the
    filter does not rely on exact floating-point equality of site positions.
    This is the same tolerance ``check_ordering`` uses for its layer test.
    """
    facet_y = W // 2
    return bool(np.isclose(site_from.pos[1], facet_y)
                and np.isclose(site_to.pos[1], facet_y))

J_operator = kwant.operator.Current(sysf_cubic_build, where=longitudinal_intra_layer_cut)
# Facet current summed over every scattering state in wf_from_lead_0.
J = sum(map(J_operator, wf_from_lead_0))


# -

# To use `kwant.plotter.current`, Kwant needs a system that tells it which $J$ values belong to which hoppings. Passing the original 3D nanowire system as input does not work: with the `where` argument we only look at a subset of its hoppings (i.e. the single facet of the nanowire).
# We define a new (auxiliary) system `plotf` for the facet and use this to plot the current. 

# +
def shape_plot_longitudinal_cut(site):
    """Return whether a 2D site lies in the rectangle used for the facet plot.

    The 2D x-range matches the wire's x-range and the 2D y-range matches the
    wire's z-range. Both use ``(-W)//2`` and ``(-L)//2`` as lower bounds,
    exactly like ``shape`` does.
    """
    x, y = site.pos
    within_x = (-W) // 2 <= x <= W // 2
    within_y = (-L) // 2 <= y <= L // 2
    return within_x and within_y

def sysf_cubic_structure_2D():
    """Build the finalized 2D square-lattice system used to plot facet currents.

    The template is primitive (period 1), like ``sysf_cubic_structure``.
    Here only the geometry and hopping order matter, so the onsite value is
    not significant.
    """
    square = kwant.lattice.square(norbs=1)
    template = kwant.Builder(kwant.TranslationalSymmetry((1, 0), (0, 1)))
    template[square(0, 0)] = 1
    template[square.neighbors(1)] = 1
    template[square.neighbors(2)] = 0.5

    facet = kwant.Builder()
    facet.fill(template, shape_plot_longitudinal_cut, (0, 0))
    return facet.finalized()


# -

# 2D stand-in for the facet; it supplies geometry and hopping order for plotting.
plotf = sysf_cubic_structure_2D()
kwant.plot(plotf);

# Map the facet currents J onto the hoppings of plotf (relies on matching order).
kwant.plotter.current(plotf, J, relwidth = 0.1)

# ### FCC system

# +
# Scattering states of the FCC-built wire at E = 0, incoming from lead 0.
wf = kwant.wave_function(sysf_FCC_build, energy=0)
wf_from_lead_0 = wf(0)

J_operator = kwant.operator.Current(sysf_FCC_build, where=longitudinal_intra_layer_cut)
# Facet current summed over every scattering state in wf_from_lead_0.
J = sum(map(J_operator, wf_from_lead_0))


# -

def sysf_FCC_structure_2D():
    """Build the finalized 2D plotting system with a four-site, period-2 cell.

    This mirrors the period-2 template of ``sysf_FCC_structure`` in a plane.
    Sites are inserted in the fixed order (0,0), (0,1), (1,0), (1,1), so that
    the resulting hopping order can be compared with ``check_ordering``.
    """
    square = kwant.lattice.square(norbs=1)
    template = kwant.Builder(kwant.TranslationalSymmetry((2, 0), (0, 2)))
    for tag in [(0, 0), (0, 1), (1, 0), (1, 1)]:
        template[square(*tag)] = 1
    template[square.neighbors(1)] = 1
    template[square.neighbors(2)] = 0.5

    facet = kwant.Builder()
    facet.fill(template, shape_plot_longitudinal_cut, (0, 0))
    return facet.finalized()


# 2D plotting system built from the same period-2 template as the FCC wire;
# check_ordering below verifies that its hopping order matches the 3D cut.
plotf = sysf_FCC_structure_2D()
kwant.plot(plotf);

# Map the facet currents J of the FCC-built wire onto the hoppings of plotf.
kwant.plotter.current(plotf, J, relwidth = 0.1)


# +
def extract_hopping_positions(sysf, hopping_indices, normal_vector=None):
    """Return the start and end positions of the selected hoppings.

    Parameters
    ----------
    sysf : finalized kwant system
        Must provide ``graph`` (iterable of ``(from_idx, to_idx)`` edges) and
        ``sites`` (indexable; each site has a ``pos``).
    hopping_indices : iterable of int
        Indices into the iteration order of ``sysf.graph``.
    normal_vector : sequence of float, optional
        If given, the coordinate along its dominant component
        (``argmax(|normal_vector|)``) is removed from every position. This
        projects the positions onto the plane perpendicular to that axis.

    Returns
    -------
    pos_from, pos_to : numpy.ndarray
        One row per requested hopping, in the order of ``hopping_indices``.
    """
    drop_axis = None
    if normal_vector is not None:
        drop_axis = int(np.argmax(np.abs(np.asarray(normal_vector, dtype=float))))

    # Materialize the edge list once; rebuilding it for every hopping made
    # this function quadratic in the number of hoppings.
    edges = list(sysf.graph)
    sites = sysf.sites

    pos_from_array, pos_to_array = [], []
    for hopping_idx in hopping_indices:
        site_from_idx, site_to_idx = edges[hopping_idx]
        pos_from = np.array(sites[site_from_idx].pos)
        pos_to = np.array(sites[site_to_idx].pos)

        if drop_axis is not None:
            pos_from = np.delete(pos_from, drop_axis)
            pos_to = np.delete(pos_to, drop_axis)

        pos_from_array.append(pos_from)
        pos_to_array.append(pos_to)

    return np.array(pos_from_array), np.array(pos_to_array)


def check_ordering(sysf, sysf_2D_proj, layer, normal_vector):
    """Check that a 2D system orders its hoppings like a planar cut of a 3D system.

    The hoppings of ``sysf`` whose two sites both lie in the plane
    ``pos[axis] == layer`` (with ``axis = argmax(|normal_vector|)``) are
    projected onto that plane. They are then compared, in graph order,
    hopping by hopping with all hoppings of ``sysf_2D_proj``.

    Parameters
    ----------
    sysf : finalized 3D kwant system
    sysf_2D_proj : finalized 2D kwant system used for plotting
    layer : float
        Coordinate of the cut plane along the normal axis.
    normal_vector : sequence of float
        Normal of the cut plane; only its dominant component is used.

    Returns
    -------
    bool
        ``True`` when every projected 3D hopping matches its 2D counterpart.

    Raises
    ------
    ValueError
        If the two hopping sets have different sizes, or if their positions
        differ hopping by hopping (i.e. the ordering is not the same).
    """
    # 2D projection system: every hopping, in graph order.
    subset_hoppings_2D_proj = [i for i, _ in enumerate(sysf_2D_proj.graph)]
    print("Number of hoppings in the 2D projection subset:", len(subset_hoppings_2D_proj))

    pos_from_2D_proj, pos_to_2D_proj = extract_hopping_positions(sysf_2D_proj, subset_hoppings_2D_proj)

    # 3D system: only hoppings with both sites in the requested layer.
    # The normal axis is loop-invariant, so compute it once.
    drop_axis = int(np.argmax(np.abs(np.asarray(normal_vector, dtype=float))))

    def in_layer(site_idx):
        return np.isclose(sysf.sites[site_idx].pos[drop_axis], layer)

    subset_hoppings = [
        i for i, (site_from_idx, site_to_idx) in enumerate(sysf.graph)
        if in_layer(site_from_idx) and in_layer(site_to_idx)
    ]
    print("Number of hoppings in the 3D system subset:", len(subset_hoppings))

    pos_from_3D, pos_to_3D = extract_hopping_positions(sysf, subset_hoppings, normal_vector)

    # np.allclose broadcasts its operands. Mismatched hopping counts would
    # either raise an unrelated broadcasting error or, when one side has a
    # single hopping, compare that row against every row of the other side.
    if pos_from_3D.shape != pos_from_2D_proj.shape or pos_to_3D.shape != pos_to_2D_proj.shape:
        raise ValueError(
            "The 3D cut and the 2D system have different numbers of hoppings "
            f"(position arrays of shape {pos_from_3D.shape} vs {pos_from_2D_proj.shape}). "
            "Check the layer, normal_vector and the shapes of the systems."
        )

    if np.allclose(pos_from_3D, pos_from_2D_proj) and np.allclose(pos_to_3D, pos_to_2D_proj):
        print("The 3D system and 2D projection have the same ordering!")
        return True

    raise ValueError("The ordering between the 3D and 2D systems are not the same. Check if the translational symmetries of the systems match.")


# -

# The following code checks whether a 2D cut through the 3D FCC system has the same hopping ordering as the 2D FCC-ordered plotting system (`sysf_FCC_structure_2D`).

# Reuse the already-finalized 3D FCC system instead of building it again.
check_ordering(sysf_FCC_build, sysf_FCC_structure_2D(), layer=0, normal_vector=[0, 1, 0])

# The following code raises an error, because a 2D cut through the 3D FCC system has a different hopping ordering from the 2D cubic system.

# Expected to raise ValueError. Reuses the already-finalized 3D FCC system.
check_ordering(sysf_FCC_build, sysf_cubic_structure_2D(), layer=0, normal_vector=[0, 1, 0])


