# ---
# jupyter:
#   jupytext:
#     text_representation:
#       extension: .py
#       format_name: light
#       format_version: '1.5'
#       jupytext_version: 1.16.6
#   kernelspec:
#     display_name: Python 3 (ipykernel)
#     language: python
#     name: python3
# ---

# # Current visualization
# In this notebook we address current visualization in a 3D (nanowire) system using Kwant.
# A minimal example is provided that illustrates some issues with `kwant.plotter.current`.

import kwant
import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D

# +
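# let's define a simple system with zero onsite energy, nearest-neighbor hopping 1 and next-nearest-neighbor hopping 0.5
lat = kwant.lattice.cubic(norbs=1)
# the FCC symmetry, together with the two sites (0, 0, 0) and (0, 0, 1) below, covers every site of the cubic lattice
syst = kwant.Builder(kwant.TranslationalSymmetry((1,1,0), (1,0,1), (0,1,1)))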

syst[lat(0, 0, 0)] = 0
syst[lat(0, 0, 1)] = 0

syst[lat.neighbors(1)] = 1    
syst[lat.neighbors(2)] = 0.5
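# -

# As a quick sanity check of the construction above, we can verify two things. This is a sketch that uses only NumPy and the standard library, not Kwant. First, the two basis sites `(0, 0, 0)` and `(0, 0, 1)`, repeated with the FCC translation vectors `(1, 1, 0)`, `(1, 0, 1)` and `(0, 1, 1)`, reach every site of the cubic lattice. Second, we count how many neighbors each site couples to through `lat.neighbors(1)` (distance 1) and `lat.neighbors(2)` (distance $\sqrt{2}$). The names `fcc`, `basis` and `covered` are introduced only for this check.

# +
import itertools

import numpy as np

# FCC translation vectors of the Builder symmetry above (one per row)
fcc = np.array([(1, 1, 0), (1, 0, 1), (0, 1, 1)])
# the two sites that were given an onsite energy
basis = [(0, 0, 0), (0, 0, 1)]

# collect basis + n1*a1 + n2*a2 + n3*a3 for small integer coefficients n_i
covered = set()
for n in itertools.product(range(-4, 5), repeat=3):
    shift = np.array(n) @ fcc
    for b in basis:
        covered.add(tuple(int(c) for c in shift + np.array(b)))

# every cubic site in a small test box should have been reached
box = itertools.product(range(-2, 3), repeat=3)
missing = [p for p in box if p not in covered]
print(len(missing))  # → 0

# squared distances from the origin to all other sites of the surrounding 3x3x3 cube
offsets = [np.array(p) for p in itertools.product((-1, 0, 1), repeat=3) if any(p)]
dist2 = [int(o @ o) for o in offsets]
print(dist2.count(1), dist2.count(2))  # → 6 12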

# +
L = 15  
W = 10    

def shape(site):
    x, y, z = site.pos
    return (-W//2 <= x <= W//2) and (-W//2 <= y <= W//2) and (-L//2 <= z <= L//2)
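# -

# Note that Python parses `-L//2` as `(-L)//2`, i.e. floor division of a negative number. For odd `L`, the bounds in `shape` are therefore not symmetric about zero. With `L = 15`, the kept planes run from $z = -8$ to $z = 7$, so the scattering region is 16 planes long rather than 15. The same bounds reappear in `shape_plot_longitudinal_cut` further down, so the two shapes stay consistent with each other. Still, it is worth keeping in mind when reading positions off the plots. Here is a quick check in plain Python. The names `L_demo` and `W_demo` and the symmetric variant `abs(z) <= L_demo//2` are ours, for illustration only.

# +
L_demo, W_demo = 15, 10  # same values as L and W above

# -L//2 is (-L)//2: floor(-7.5) = -8, while L//2 = 7
print(-L_demo // 2, L_demo // 2)  # → -8 7
print(-W_demo // 2, W_demo // 2)  # → -5 5

# integer z planes kept by the condition used in `shape`
z_kept = [z for z in range(-20, 21) if -L_demo//2 <= z <= L_demo//2]
print(len(z_kept))  # → 16

# a symmetric alternative keeps 2 * (L//2) + 1 planes
z_symmetric = [z for z in range(-20, 21) if abs(z) <= L_demo//2]
print(len(z_symmetric))  # → 15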


# +
nanowire = kwant.Builder()
nanowire.fill(syst, shape, (0, 0, 0))

lead_0 = kwant.Builder(kwant.TranslationalSymmetry((0, 0, 2)))
lead_0.fill(syst, shape, start=(0, 0, 0))
nanowire.attach_lead(lead_0)
nanowire.attach_lead(lead_0.reversed())

nanowire = nanowire.finalized()
lead_0 = lead_0.finalized()

sysf = nanowire
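# -

# The lead uses the period `(0, 0, 2)` rather than `(0, 0, 1)`. Our reading is that `fill` requires the symmetry of the builder being filled to be a subgroup of the template's symmetry (an assumption worth checking against the `Builder.fill` documentation). If so, the lead period must be an integer combination of the FCC vectors of `syst`. The NumPy sketch below tests both candidate periods by solving for their FCC coefficients. `fcc_coefficients` is a helper introduced only here.

# +
import numpy as np

# FCC translation vectors of `syst`, stored as columns
fcc_columns = np.array([(1, 1, 0), (1, 0, 1), (0, 1, 1)]).T


def fcc_coefficients(v):
    """Solve fcc_columns @ n = v for the coefficients n of the FCC vectors."""
    return np.linalg.solve(fcc_columns, np.array(v, dtype=float))


is_lattice_vector = {}
for period in [(0, 0, 1), (0, 0, 2)]:
    n = fcc_coefficients(period)
    # v is a translation of the FCC symmetry iff all coefficients are integers
    is_lattice_vector[period] = np.allclose(n, np.round(n))

print(is_lattice_vector)  # → {(0, 0, 1): False, (0, 0, 2): True}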

# +
fig = plt.figure()
ax = fig.add_subplot(111, projection='3d') 
ax.set_xlabel('x')
ax.set_ylabel('y')
ax.set_zlabel('z')
ax.set_box_aspect(None, zoom=0.85)

kwant.plot(sysf, ax=ax, num_lead_cells=4)  # `show` is ignored when `ax` is given; plt.show() below displays the figure
plt.show()
# -

# ## Longitudinal transport

# Let's compute the scattering wave functions in the scattering region for states incoming from lead 0 at energy $E = 0$.

wf = kwant.wave_function(sysf, energy=0)
wf_from_lead_0 = wf(0)


# Kwant does not support current visualization for 3D systems. To still get an impression of how the current is distributed through the nanowire, we can visualize the current on a single facet, which is 2D. We use the side face at $y = -W/2$ (`-W//2` in the code), spanned by the $x$ and $z$ directions.

# +
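def longitudinal_intra_layer_cut(site_to, site_from):
    # keep only hoppings whose two sites both lie on the facet y = -W//2
    return site_from.pos[1] == -W//2 and site_to.pos[1] == -W//2

J_operator = kwant.operator.Current(sysf, where=longitudinal_intra_layer_cut)
# total current carried by all scattering states incoming from lead 0
J = sum(J_operator(psi) for psi in wf_from_lead_0)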


# -
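
# To make explicit what is being summed, here is a minimal NumPy sketch for a single-orbital 1D chain with hopping $H_{n,n+1} = -t$. It assumes the standard tight-binding bond current $J_{a \leftarrow b} = 2\,\mathrm{Im}(\psi_a^* H_{ab} \psi_b)$ for current flowing from site $b$ to site $a$. We have not checked this against Kwant's exact sign convention. The sketch takes a few plane waves as hypothetical incoming modes and sums their currents, just as `J` sums over `wf_from_lead_0` above. In a homogeneous chain, each plane wave carries the same current through every bond. This is the behavior one expects along the homogeneous nanowire.

# +
import numpy as np


def bond_currents(psi, t=1.0):
    """Current from site n+1 to site n on each bond of a 1D chain with H_{n,n+1} = -t.

    Uses the assumed expression J = 2 Im(conj(psi_a) * H_ab * psi_b).
    """
    h_ab = -t
    return 2 * np.imag(np.conj(psi[:-1]) * h_ab * psi[1:])


sites = np.arange(20)
ks = [0.3, 0.9, 1.5]  # hypothetical wave numbers of the incoming modes
modes = [np.exp(1j * k * sites) for k in ks]

# total current, summed over modes like J = sum(J_operator(psi) for psi in wf_from_lead_0)
J_chain = sum(bond_currents(psi) for psi in modes)

# homogeneous chain: every bond carries the same current
print(np.allclose(J_chain, J_chain[0]))  # → True
# a right-moving plane wave carries 2 t sin(k) from n to n+1, i.e. -2 t sin(k) from n+1 to n
print(np.isclose(J_chain[0], -sum(2 * np.sin(k) for k in ks)))  # → True
# -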

# `kwant.plotter.current` needs a system in addition to the current values, so that it knows which value of $J$ belongs to which hopping. Passing the original nanowire system `sysf` does not work: because of the `where` argument, $J$ only contains values for a subset of its hoppings (those on the single facet of the nanowire).
# We therefore define a new (auxiliary) 2D system `plotf` for the facet and use it to plot the current.

# +
def shape_plot_longitudinal_cut(site):
    # the coordinates (x, y) of this 2D system stand for (x, z) on the nanowire facet
    x, y = site.pos
    return (-W//2 <= x <= W//2) and (-L//2 <= y <= L//2)

lat = kwant.lattice.square(norbs=1)
plot_syst = kwant.Builder(kwant.TranslationalSymmetry((0, 1), (1, 0)))
plot_syst[lat(0, 0)] = 1
plot_syst[lat.neighbors(1)] = 1    
plot_syst[lat.neighbors(2)] = 0.5

plot = kwant.Builder()
plot.fill(plot_syst, shape_plot_longitudinal_cut, (0, 0))
plotf = plot.finalized()
# -
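
# As described above, `kwant.plotter.current` expects one value per hopping of the system it is given, in the order of that system's hoppings. As far as we can tell, the values in `J` are produced in the order in which the facet hoppings appear in `sysf`. Nothing in the construction of `plotf` guarantees that its hoppings come in that same order. Whether this explains the plot below is not established here.
#
# The plain-Python sketch below illustrates how values can instead be matched by position, mapping a facet position $(x, -W/2, z)$ to the 2D position $(x, z)$ of `plotf`. The hoppings and current values are made-up toy data, and `to_2d` and `reorder_by_position` are hypothetical helpers. With real Kwant systems the positions are float arrays, so they would need to be rounded and converted to tuples before being used as dictionary keys.

# +
# Toy data: directed hoppings (to_position, from_position) on the facet y = -5,
# with current values, listed in the order of the 3D system
facet_values_3d = [
    (((0, -5, 1), (0, -5, 0)), 0.7),
    (((1, -5, 0), (0, -5, 0)), 0.1),
    (((0, -5, 0), (0, -5, 1)), -0.7),
    (((0, -5, 0), (1, -5, 0)), -0.1),
]

# The same hoppings in the auxiliary 2D system, listed in its own (different) order
hoppings_2d = [
    ((0, 0), (1, 0)),
    ((0, 1), (0, 0)),
    ((1, 0), (0, 0)),
    ((0, 0), (0, 1)),
]


def to_2d(pos):
    """Map a facet position (x, y, z) to the 2D position (x, z)."""
    x, _, z = pos
    return (x, z)


def reorder_by_position(values_3d, targets_2d):
    """Return one value per 2D hopping, matched by position instead of by index."""
    lookup = {(to_2d(a), to_2d(b)): value for (a, b), value in values_3d}
    return [lookup[hop] for hop in targets_2d]


J_matched = reorder_by_position(facet_values_3d, hoppings_2d)
print(J_matched)  # → [-0.1, 0.7, 0.1, -0.7]

# taking the values in the 3D order pairs them with the wrong 2D hoppings
J_by_index = [value for _, value in facet_values_3d]
print(J_by_index == J_matched)  # → False
# -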

kwant.plotter.current(plotf, J)

# The result is troubling. The nanowire is homogeneous (leads and scattering region are identical), so we would expect a homogeneous current flow through the facet. Instead, the plot appears to show an increased current density near the leads, as well as an asymmetry between the two leads. Neither can occur in a homogeneous nanowire.
