import os
import subprocess
import matplotlib.pyplot as plt

class Animation(object):
    '''
    This class wraps the creation of an animation using matplotlib.

    frame_info is a list, with each list item being information for a single
    frame.  Subclasses must implement _draw_frame(num) to render frame
    ``num`` and may override _init_anim() for one-time setup.
    '''
    def __init__(self, frame_info):
        self._frame_info = frame_info
        self._init_anim()
        self._set_event_handler()

    def _init_anim(self):
        '''
        Hook for one-time setup, called from __init__ before any frame is
        drawn.  Does nothing by default; subclasses override it as needed.
        '''
        pass

    def _set_event_handler(self):
        '''
        Sets up the proper idle handler to use to update frames. Right now
        hardcoded to GTK, but should be able to use others.
        '''
        import gobject
        # Called by run() as self._idle_add(milliseconds, callback); the
        # callback (_iter_loop) returns True to keep being called and False
        # to stop, following the GLib timeout-callback convention.
        self._idle_add = gobject.timeout_add

    def run(self, interval=0.2, repeat=True, show=True):
        '''
        Starts interactive animation. Adds the draw frame command to the GUI
        handler, calls show to start the event loop.

        interval : seconds between frames (converted to integer milliseconds
                   for the GUI timer).
        repeat   : if True, wrap back to the first frame after the last one.
        show     : if True, call plt.show() to enter the GUI event loop.
                   Pass False when several animations should share a single
                   later plt.show() call.
        '''
        self._repeat = repeat
        self._cur_frame = 0
        self._idle_add(int(interval * 1000), self._iter_loop)
        if show:
            plt.show()

    def save(self, filename, fps=5, codec='mpeg4', clear_temp=True,
        frame_prefix='_tmp'):
        '''
        Saves a movie file by drawing every frame.

        Each frame is drawn and written to ``<frame_prefix>NNNN.png`` (four
        digit, zero-padded index).  The PNGs are then encoded into
        ``filename`` with ffmpeg, using video codec ``codec`` at ``fps``
        frames per second.  When ``clear_temp`` is True the temporary PNGs
        are deleted afterwards, even if drawing or encoding raised.
        '''
        fnames = []
        try:
            for idx in range(len(self._frame_info)):
                self._draw_frame(idx)
                fname = '%s%04d.png' % (frame_prefix, idx)
                plt.savefig(fname)
                # Record only files that were actually written, so cleanup
                # never tries to remove a file that does not exist.
                fnames.append(fname)

            self._make_movie(filename, fps, codec, frame_prefix)
        finally:
            # Delete temporary files, also on failure, so a crashed save
            # does not leave stray frame images behind.
            if clear_temp:
                for fname in fnames:
                    os.remove(fname)

    def _make_movie(self, fname, fps, codec, frame_prefix):
        '''
        Encodes the numbered PNG frames ``<frame_prefix>NNNN.png`` into the
        movie file ``fname`` by running ffmpeg.

        Returns ffmpeg's exit status (0 on success) so callers can detect a
        failed encode.
        '''
        return subprocess.call(self._movie_command(fname, fps, codec,
            frame_prefix))

    def _movie_command(self, fname, fps, codec, frame_prefix):
        '''
        Builds the ffmpeg argument list used by _make_movie.

        ffmpeg applies each option to the next file on the command line:
        options before ``-i`` describe the input, options after it describe
        the output.  ``-r`` therefore sets the frame rate of the PNG
        sequence, while the codec and bitrate are output settings and must
        follow the input.
        '''
        return ['ffmpeg', '-r', str(fps), '-i', '%s%%04d.png' % frame_prefix,
            '-vcodec', codec, '-b', '1800k', fname]

    def _iter_loop(self, *args):
        '''
        Animation helper that iterates the animation loop. This tracks the
        current frame so that it can play it again in the GUI.

        Returns True while the GUI timer should keep firing, and False once
        the last frame has been drawn and repeat is disabled.
        '''
        self._draw_frame(self._cur_frame)

        #Repeat until we have exhausted all of the data in the sequence
        self._cur_frame += 1
        if self._cur_frame < len(self._frame_info):
            return True
        elif self._repeat:
            self._cur_frame = 0
            return True
        else:
            return False

    def _draw_frame(self, num):
        '''
        Renders frame ``num``.  Abstract: subclasses must override this.

        Raises NotImplementedError when not overridden.
        '''
        raise NotImplementedError('Needs to be implemented by subclasses to'
            ' actually make an animation.')

class ArtistAnimation(Animation):
    '''
    Animation that flips pre-built artists on and off, one frame at a time.

    All plotting must already be done before this animation is created, with
    the resulting artists kept around.  frame_info is a list whose entries
    are collections of artists: the entry for a frame lists the artists that
    are shown on that frame, and those artists are hidden on other frames.
    '''
    def _init_anim(self):
        # Start with every artist that belongs to any frame hidden; frames
        # are revealed one at a time by _draw_frame.
        every_artist = (art for frame in self._frame_info for art in frame)
        for art in every_artist:
            art.set_visible(False)

    def _draw_frame(self, num):
        '''
        Shows the artists of frame ``num`` after hiding those of the frame
        before it (wrapping around, so frame 0 hides the last frame), then
        asks pyplot to redraw.
        '''
        frames = self._frame_info
        previous = frames[(num - 1) % len(frames)]

        for art in previous:
            art.set_visible(False)

        for art in frames[num]:
            art.set_visible(True)

        plt.draw()

class FuncAnimation(Animation):
    '''
    Animation driven by repeatedly calling a user-supplied function.

    frame_info determines the number of frames and may also carry per-frame
    data.  To render frame ``num`` the function is invoked as
    ``func(num, frame_info, *args)``, i.e. with the frame number first, then
    the whole of frame_info, then any extra positional args given here.
    The figure is redrawn after each call.
    '''
    def __init__(self, frame_info, func, *args):
        # Store the callback state before delegating to the base
        # initializer, which calls overridable hooks (_init_anim).
        self._func, self._args = func, args
        Animation.__init__(self, frame_info)

    def _draw_frame(self, num):
        '''Renders frame ``num`` via the user callback, then redraws.'''
        callback = self._func
        callback(num, self._frame_info, *self._args)
        plt.draw()

if __name__ == '__main__':
    import numpy as np

    def update_line(num, data, line):
        # Show the first ``num`` points.  Each row of data is an (x, y) pair,
        # so the transpose yields the (xs, ys) pair that set_data accepts.
        line.set_data(data[:num].T)

    # Demo 1: random points connected by a line, revealed one per frame.
    points = np.random.rand(2, 25)
    line, = plt.plot(points[0], points[1], 'r-')
    plt.xlim(0, 1)
    plt.ylim(0, 1)
    plt.xlabel('x')
    plt.title('test')
    line_ani = FuncAnimation(points.T, update_line, line)

    # Demo 2: a sequence of pcolor images drawn in a second figure.
    plt.figure()
    xs = np.arange(-9, 10)
    ys = np.arange(-9, 10).reshape(-1, 1)
    base = np.hypot(xs, ys)
    ims = [(plt.pcolor(xs, ys, base + offset, norm=plt.Normalize(0, 30)),)
           for offset in np.arange(15)]
    im_ani = ArtistAnimation(ims)

    # Register both timers first; only the last run() enters plt.show().
    im_ani.run(show=False)
    line_ani.run()
