import matplotlib.pyplot as plt
import numpy as np
from mpl_toolkits.mplot3d import Axes3D
from matplotlib import cm

# Draw the surface V = (r**2 - 1)**2 over a disc in the plane, where the
# x/y axes are labelled as the real and imaginary parts of phi and r is the
# distance from the origin.
fig = plt.figure()
# Create the 3D axes through the figure instead of calling Axes3D(fig):
# newer Matplotlib releases no longer let Axes3D attach itself to the figure
# (deprecated in 3.4), which would leave the window empty. The Axes3D import
# at the top of the file is still what registers the '3d' projection on old
# Matplotlib versions.
ax = fig.add_subplot(111, projection='3d')

# Create supporting points in polar coordinates: 50 radii in [0, 1.25] and
# 50 angles in [0, 2*pi] (both endpoints included, so the surface closes).
r = np.linspace(0, 1.25, 50)
phi = np.linspace(0, 2 * np.pi, 50)
R, PHI = np.meshgrid(r, phi)

# Transform them to the cartesian system expected by plot_surface.
X, Y = R * np.cos(PHI), R * np.sin(PHI)

# Height depends on the radius only, so the surface is rotationally symmetric.
Z = (R**2 - 1)**2

ax.plot_surface(X, Y, Z, rstride=1, cstride=1, cmap=cm.jet)
ax.set_zlim(0, 1)
ax.set_xlabel(r'$\phi_\mathrm{real}$')
ax.set_ylabel(r'$\phi_\mathrm{im}$')
ax.set_zlabel(r'$V(\phi)$')
# Hide the tick marks on the x axis (the original script only clears x).
ax.set_xticks([])
plt.show()
