"""
A graph with an axis break in it. Some questions I have:
- how can I hide the y-ticks in the middle of the graph?
- Can I use a transAxes transform so the axis break symbol
  (now implemented as a LineCollection)
  stays in the same position when panning? Is a LineCollection the best choice?
  The problem with using a Patch is that the vertices of this symbol are not
  connected.
- Is there a better way to make an axis break symbol?

I have to write some functions to make creation of graphs for other
datasets easier. 
"""
from scipy import *
from pylab import *
from matplotlib.patches import Rectangle
from matplotlib.collections import LineCollection

# Input data: thirteen phase-shift values and the matching m / n values.
phaseshift_a = array([
    -62, 49.5, 44.7, 41.1, 40.9, 37.2, 37.5,
    37.9, 33.4, 32.3, 30.3, 26.6, 31.5,
])
m_a = array([
    1, 3, 4, 6, 11, 21, 26,
    51, 101, 201, 1001, 2001, 10001,
])
n_a = array([
    1, 2, 3, 5, 10, 20, 25,
    50, 100, 200, 1000, 2000, 10000,
])

# x coordinate used for plotting: log of the element-wise product m * n.
plotaxis_a = log(m_a * n_a)

def yaxisbreak(ymin1, ymax1, ymin2, ymax2, breakpos=0.3, fignum=1):
    """Create two stacked axes that together act as one broken y-axis.

    Figure ``fignum`` is selected and cleared, then two axes sharing one
    x-axis are placed on top of each other.  Both are 0.8 wide (figure
    fraction) and together they are 0.75 high; ``breakpos`` is the share of
    that height given to the bottom axes.  The x-axis of the top axes is
    hidden.

    Parameters
    ----------
    ymin1, ymax1
        y-limits of the lower value range; applied to the bottom axes.
    ymin2, ymax2
        y-limits of the upper value range; applied to the top axes.
    breakpos
        Fraction (0..1) of the combined height used by the bottom axes.
    fignum
        Number of the figure to draw into.

    Returns
    -------
    tuple
        ``(ax1, ax2)`` where ``ax1`` is the top axes and ``ax2`` the bottom
        axes.  Note the numbering: range 1 belongs to ``ax2`` and range 2
        to ``ax1``.
    """
    figure(fignum)
    clf()
    bottom_size = 0.75 * breakpos
    top_size = 0.75 * (1 - breakpos)
    # The top axes start 0.00001 above the bottom axes' upper edge --
    # presumably so the two rectangles never exactly touch/overlap.
    ax1 = subplot(211, position=[0.1, 0.12501 + bottom_size, 0.8, top_size])
    ax2 = subplot(212, sharex=ax1, position=[0.1, 0.125, 0.8, bottom_size])
    ax1.xaxis.set_visible(False)
    # The requested ranges used to be silently ignored; apply them so the
    # function works for other datasets without relying on autoscaling.
    ax2.set_ylim(ymin1, ymax1)
    ax1.set_ylim(ymin2, ymax2)
    return (ax1, ax2)

# Build the broken-axis figure; ax1 is the upper panel, ax2 the lower one.
ax1, ax2 = yaxisbreak(-70, -40, 25, 50)

# Every sample but the first is drawn in the upper panel ...
ax1.plot(plotaxis_a[1:], phaseshift_a[1:], 'ko')
# ... and the first sample alone is drawn in the lower panel.
ax2.plot(plotaxis_a[:1], phaseshift_a[:1], 'ko')
    
# Tick positions.  The x ticks are multiples of five from 0 to 20, preceded
# by one extra tick at -1.
x = array([-1, 0, 5, 10, 15, 20])
# y ticks for the upper (y1) and lower (y2) panel.
y1 = array([20, 30, 40, 50])
y2 = array([-70, -50, -40])
def rads(x, pos):
    """Format a tick value as mathtext with one decimal place.

    Meant to be wrapped in a ``FuncFormatter``, which passes the tick value
    ``x`` and the tick position ``pos``; ``pos`` is not used.

    Returns the label string, e.g. ``'$1.0$'`` for ``x == 1``.
    """
    # The closing '$' was missing, leaving the math delimiter unbalanced
    # (all other tick labels in this file use balanced '$...$').
    return r'$%1.1f$' % x

# Install the one-decimal formatter on all four axis objects.
# NOTE(review): the explicit tick labels set further down appear to take
# precedence over this formatter -- confirm whether it is still needed.
formatter = FuncFormatter(rads)
ax1.xaxis.set_major_formatter(formatter)
ax1.yaxis.set_major_formatter(formatter)
ax2.xaxis.set_major_formatter(formatter)
ax2.yaxis.set_major_formatter(formatter)

# Same x tick positions on both panels; only every other tick is labelled.
ax1.set_xticks(x)
ax2.set_xticks(x)
# The last label of the upper panel used to be the empty math string '$$';
# it now matches the lower panel's '$20$'.
ax1.set_xticklabels([r'', r'$0$', r'', r'$10$', r'', r'$20$'])
ax2.set_xticklabels([r'', r'$0$', r'', r'$10$', r'', r'$20$'])

# y ticks: on each panel the tick next to the axis break (20 on the upper
# panel, -40 on the lower one) is left unlabelled.
ax1.set_yticks(y1)
ax2.set_yticks(y2)
ax1.set_yticklabels([r'', r'30', r'40', r'50'])
# The tick at -50 was previously labelled '-40'; label it with its value.
ax2.set_yticklabels([r'-70', r'-50', r''])

# White, unclipped strip (axes coordinates of the lower panel) drawn over
# the lower panel's top edge, covering the frame line between the panels.
rect = Rectangle((-0.002, 0.98), 1.0, 0.04, transform=ax2.transAxes, fc='w', ec='w')
ax2.add_patch(rect)
rect.set_clip_on(False)
rect.set_zorder(4)

# Axis-break symbol: two short parallel strokes across the top-left corner
# of the lower panel.  The segments are given in axes-fraction coordinates
# (2.5% of the width to either side of the y-axis, 5% of the height per
# stroke), which is the same geometry as before but no longer depends on
# the data limits at creation time, so the symbol stays in place when the
# view is panned, zoomed or rescaled.
break_symbol = LineCollection(
    [
        [(-0.025, 0.95), (0.025, 1.0)],
        [(-0.025, 1.0), (0.025, 1.05)],
    ],
    transform=ax2.transAxes,
)
break_symbol.set_color('k')
# Pure decoration: keep it out of the axes' data limits.
ax2.add_collection(break_symbol, autolim=False)
break_symbol.set_clip_on(False)
# Draw above the white strip (zorder 4).
break_symbol.set_zorder(5)

# Uncomment the next line to also save the figure as 'axisbreak.png'.
#savefig('axisbreak.png')
# Display the figure.
show()