import matplotlib.pyplot as plt

import numpy as np


class BaseFilter(object):
    """Base class for agg filters that pad the image before processing it.

    Subclasses must implement ``process_image(padded_src, dpi)`` (it is
    called by ``__call__`` but not defined here) and may override
    ``get_pad`` to request a border around the image.
    """

    def prepare_image(self, src_image, dpi, pad):
        """Return *src_image* as float64, surrounded by a zero border.

        Parameters
        ----------
        src_image : ndarray, shape (ny, nx, depth)
            Image to pad.
        dpi : float
            Unused here; kept so subclasses share one signature.
        pad : int
            Border width in pixels on every side; may be 0.

        Returns
        -------
        ndarray, shape (ny + 2*pad, nx + 2*pad, depth), dtype float64
        """
        ny, nx, depth = src_image.shape
        padded_src = np.zeros([pad * 2 + ny, pad * 2 + nx, depth], dtype="d")
        # Explicit end indices: ``pad:-pad`` collapses to the empty slice
        # ``0:0`` when pad == 0 and the assignment fails to broadcast.
        padded_src[pad:pad + ny, pad:pad + nx, :] = src_image
        return padded_src

    def get_pad(self, dpi):
        """Return the border width in pixels; no padding by default."""
        return 0

    def __call__(self, im, dpi):
        """Pad *im*, run ``process_image`` on it and return the result.

        Returns
        -------
        (image, ox, oy)
            The processed padded image plus offsets ``-pad, -pad`` that
            account for the border added around the original image.
        """
        pad = self.get_pad(dpi)
        padded_src = self.prepare_image(im, dpi, pad)
        tgt_image = self.process_image(padded_src, dpi)
        return tgt_image, -pad, -pad



class GradientFilter(BaseFilter):
    """Agg filter that repaints an artist with a vertical colormap gradient.

    The artist's own RGB values are discarded: every row of the padded image
    gets one colormap color, and the resulting alpha is multiplied by the
    artist's alpha so the gradient only shows where the artist was drawn.
    """

    def __init__(self, cmap=None, pad=3):
        """
        Parameters
        ----------
        cmap : str, Colormap or None
            Colormap for the gradient; ``None`` selects the default
            colormap returned by ``plt.get_cmap()``.
        pad : int or float
            Base padding; the image is padded by ``int(pad * 3)`` pixels on
            every side (see ``get_pad``).
        """
        # Always resolve the colormap; previously a user-supplied cmap was
        # ignored and self._cmap stayed unset, crashing in process_image.
        self._cmap = plt.get_cmap(cmap)
        self._pad = pad

    def get_pad(self, dpi):
        """Return the border width in pixels (``dpi`` is not used)."""
        return int(self._pad * 3)

    def process_image(self, padded_src, dpi):
        """Return an image filled row-by-row with the colormap gradient.

        Row ``i`` receives the colormap sampled at evenly spaced values from
        0 (first row) to 1 (last row). The alpha channel is then multiplied
        by ``padded_src[:, :, 3]``.
        NOTE(review): assumes *padded_src* is RGBA (depth 4), since the
        colormap returns 4-component colors.
        """
        nrows = padded_src.shape[0]
        tgt_image = np.empty_like(padded_src)
        colors = self._cmap(np.linspace(0, 1, nrows))  # one RGBA per row
        # Broadcast each row color across all columns.
        tgt_image[:, :, :] = colors[:, np.newaxis, :]
        tgt_image[:, :, -1] *= padded_src[:, :, 3]
        return tgt_image



def gradient_line(ax):
    """Plot a sample polyline on *ax* rendered through a GradientFilter.

    The axes limits are fixed to the unit square and both axes are hidden.
    """
    xs = [0.1, 0.5, 0.9]
    ys = [0.1, 0.9, 0.5]
    line = ax.plot(xs, ys, "bo-",
                   mec="b", mfc="w", lw=5, mew=3, ms=10, label="Line 1")[0]

    line.set_agg_filter(GradientFilter())
    # Rasterize to support mixed-mode renderers.
    line.set_rasterized(True)

    ax.set_xlim(0., 1.)
    ax.set_ylim(0., 1.)

    for axis in (ax.xaxis, ax.yaxis):
        axis.set_visible(False)

if __name__ == "__main__":
    # Demo entry point. It is guarded so that importing this module for its
    # filter classes does not create and show a figure as a side effect.
    plt.figure(1, figsize=(6, 6))
    plt.subplots_adjust(left=0.05, right=0.95)

    ax = plt.subplot(111)
    gradient_line(ax)

    plt.show()


