import numpy as np
import matplotlib.pyplot as plt
import matplotlib.mlab as mlab


fig = plt.figure()
ax = fig.add_subplot(1, 1, 1)

# Vertices (x, y) of the editable polyline, in data coordinates.
linedata = [
    (1.58, -2.57),
    (0.35, -1.1),
    (-1.75, 2.0),
    (0.375, 2.0),
    (0.85, 1.15),
    (2.2, 3.2),
    (3, 0.05),
    (2.0, -0.5),
]
# Repeat the first vertex at the end so the outline is drawn closed.
linedata.append(linedata[0])

class LineInteractor:
    """
    A polyline vertex editor.

    Drag a vertex with the left mouse button to move it.

    Key-bindings

      't' toggle vertex markers on and off.  When vertex markers are on,
          you can drag them with the left mouse button.

    The canvas callbacks are bound methods of the instance, so keep a
    reference to it for as long as editing should work (recent matplotlib
    versions hold bound-method callbacks only weakly).
    """

    showverts = True
    epsilon = 5  # max pixel distance to count as a vertex hit

    def __init__(self, axes=None, data=None):
        """Plot the vertices and connect the canvas event handlers.

        axes -- Axes to edit on; defaults to the module-level ``ax``.
        data -- non-empty sequence of (x, y) vertices; defaults to the
                module-level ``linedata``.
        """
        self.ax = ax if axes is None else axes
        if data is None:
            data = linedata
        canvas = self.ax.figure.canvas

        x, y = zip(*data)

        # animated=True: the line is painted explicitly with draw_artist +
        # blit in the callbacks below instead of by the normal canvas draw.
        self.line, = self.ax.plot(x, y, marker='o', markerfacecolor='r',
                                  animated=True)

        self._ind = None  # index of the vertex being dragged, or None
        self.background = None  # pixels saved by draw_callback for blitting

        canvas.mpl_connect('draw_event', self.draw_callback)
        canvas.mpl_connect('button_press_event', self.button_press_callback)
        canvas.mpl_connect('button_release_event', self.button_release_callback)
        canvas.mpl_connect('motion_notify_event', self.motion_notify_callback)
        canvas.mpl_connect('key_press_event', self.key_press_callback)
        self.canvas = canvas

    def draw_callback(self, event):
        'after a full canvas draw, cache the axes background and blit the line'
        self.background = self.canvas.copy_from_bbox(self.ax.bbox)
        self.ax.draw_artist(self.line)
        self.canvas.blit(self.ax.bbox)

    def key_press_callback(self, event):
        'whenever a key is pressed'
        if event.inaxes is None: return
        if event.key == 't':
            self.showverts = not self.showverts
            self.line.set_visible(self.showverts)
            if not self.showverts: self._ind = None
            # Only a state change needs a full redraw.
            self.canvas.draw()

    def get_ind_under_point(self, event):
        'get the index of the vertex under point if within epsilon tolerance'
        # Vertex positions as an (N, 2) array.  column_stack works on both
        # Python 2 and 3, unlike asarray(zip(...)).
        xy = np.column_stack(self.line.get_data())
        # Convert to display (pixel) coords, the same space as event.x/event.y.
        xyt = self.line.get_transform().transform(xy)
        return self._nearest_vertex(xyt[:, 0], xyt[:, 1],
                                    event.x, event.y, self.epsilon)

    @staticmethod
    def _nearest_vertex(xt, yt, px, py, epsilon):
        'index of the point (xt[i], yt[i]) nearest (px, py) if closer than epsilon, else None'
        d = np.hypot(np.asarray(xt, dtype=float) - px,
                     np.asarray(yt, dtype=float) - py)
        if d.size == 0:
            return None
        ind = int(d.argmin())
        if d[ind] >= epsilon:
            return None
        return ind

    @staticmethod
    def _moved_vertex(data, ind, newx, newy):
        'return float copies of the (x, y) arrays in data with vertex ind moved to (newx, newy)'
        # Copy rather than mutate: get_data() may return the caller's original
        # sequences (e.g. tuples, which reject item assignment) or integer
        # arrays (which would truncate the new coordinate).
        x, y = (np.array(a, dtype=float) for a in data)
        x[ind] = newx
        y[ind] = newy
        return x, y

    def button_press_callback(self, event):
        'whenever a mouse button is pressed, pick the vertex under the cursor'
        if not self.showverts: return
        # xdata/ydata are relative to event.inaxes, so only accept our axes.
        if event.inaxes is not self.ax: return
        if event.button != 1: return
        self._ind = self.get_ind_under_point(event)

    def button_release_callback(self, event):
        'whenever a mouse button is released, stop dragging'
        if not self.showverts: return
        if event.button != 1: return
        self._ind = None

    def motion_notify_callback(self, event):
        'on mouse movement, drag the picked vertex and blit the updated line'
        if not self.showverts: return
        if self._ind is None: return
        if event.inaxes is not self.ax: return
        if event.button != 1: return
        x, y = self._moved_vertex(self.line.get_data(), self._ind,
                                  event.xdata, event.ydata)
        self.line.set_data(x, y)
        if self.background is None:
            # No draw_event yet, so there is nothing to blit onto; the next
            # full draw will show the updated line.
            return
        self.canvas.restore_region(self.background)
        self.ax.draw_artist(self.line)
        self.canvas.blit(self.ax.bbox)


# Bind the editor to a module-level name so it stays alive for the whole
# session; its event handlers are bound methods of this object.
interactor = LineInteractor()
ax.set_title('drag vertices to update line')
for set_limits in (ax.set_xlim, ax.set_ylim):
    set_limits(-3, 4)

plt.show()


