import numpy as np
import matplotlib.pyplot as plt
from matplotlib.colors import ListedColormap

# Some landcover classifications: a 10 x 10 grid of integer class codes
# (only 11, 21 and 41 occur here). The values are meant to be looked up in
# the color table *ctable*, i.e. they act as indices into the colormap, not
# as continuous magnitudes.
d = np.array([[21, 21, 41, 41, 21, 41, 41, 41, 41, 41],
              [41, 41, 21, 41, 41, 41, 11, 11, 11, 11],
              [41, 11, 11, 11, 11, 11, 11, 11, 11, 11],
              [11, 11, 11, 11, 11, 11, 11, 11, 11, 11],
              [11, 11, 11, 11, 11, 11, 11, 11, 11, 11],
              [11, 11, 11, 11, 11, 11, 11, 11, 11, 11],
              [11, 11, 11, 11, 11, 11, 11, 11, 11, 11],
              [11, 11, 11, 11, 11, 11, 41, 41, 41, 41],
              [11, 41, 41, 41, 41, 41, 41, 41, 41, 41],
              [41, 41, 41, 41, 41, 41, 41, 41, 41, 41]])

# Color table from the national land cover dataset: maps each integer class
# code to an (R, G, B, A) color with every component in [0, 1]. All entries
# are fully opaque (alpha 1.0); code 0 is black. Written as a dict literal
# rather than dict() over a tuple of pairs.
ctable = {
    0: (0.0, 0.0, 0.0, 1.0),
    11: (0.27843137254901962, 0.41960784313725491, 0.62745098039215685, 1.0),
    12: (0.81960784313725488, 0.8666666666666667, 0.97647058823529409, 1.0),
    21: (0.8666666666666667, 0.78823529411764703, 0.78823529411764703, 1.0),
    22: (0.84705882352941175, 0.57647058823529407, 0.50980392156862742, 1.0),
    23: (0.92941176470588238, 0.0, 0.0, 1.0),
    24: (0.66666666666666663, 0.0, 0.0, 1.0),
    31: (0.69803921568627447, 0.67843137254901964, 0.63921568627450975, 1.0),
    41: (0.40784313725490196, 0.66666666666666663, 0.38823529411764707, 1.0),
    42: (0.10980392156862745, 0.38823529411764707, 0.18823529411764706, 1.0),
    43: (0.70980392156862748, 0.78823529411764703, 0.55686274509803924, 1.0),
    51: (0.6470588235294118, 0.5490196078431373, 0.18823529411764706, 1.0),
    52: (0.80000000000000004, 0.72941176470588232, 0.48627450980392156, 1.0),
    71: (0.88627450980392153, 0.88627450980392153, 0.75686274509803919, 1.0),
    72: (0.78823529411764703, 0.78823529411764703, 0.46666666666666667, 1.0),
    73: (0.59999999999999998, 0.75686274509803919, 0.27843137254901962, 1.0),
    74: (0.46666666666666667, 0.67843137254901964, 0.57647058823529407, 1.0),
    81: (0.85882352941176465, 0.84705882352941175, 0.23921568627450981, 1.0),
    82: (0.66666666666666663, 0.4392156862745098, 0.15686274509803921, 1.0),
    90: (0.72941176470588232, 0.84705882352941175, 0.91764705882352937, 1.0),
    95: (0.4392156862745098, 0.63921568627450975, 0.72941176470588232, 1.0),
}

# Create a dense lookup table indexed directly by class code: entry k is
# ctable[k], and any index missing from *ctable* defaults to transparent
# white. The table is sized from the largest code in *ctable* (96 entries with
# the current table, whose largest code is 95) so no code is silently dropped,
# and so that calling cm() on an integer code indexes straight into it.
# range() (not the Python-2-only xrange) keeps this runnable on Python 3.
index_colors = [ctable.get(code, (1.0, 1.0, 1.0, 0.0))
                for code in range(max(ctable) + 1)]
# ListedColormap takes its size (N) from len(index_colors) when N is omitted.
cm = ListedColormap(index_colors, 'nlcd')

# Draw the same grid twice, side by side, to contrast two ways of coloring
# class codes with the lookup colormap *cm*.
#
# 'Bad': handing imshow the raw codes together with cmap=cm. imshow first
# rescales the data linearly (default normalization: data min -> 0, data
# max -> 1) and then samples across the whole colormap, so the codes
# generally do not land on their own table entries.
#
# 'Good': calling cm(d) directly. A colormap called with an integer array
# uses the values as indices into its table and returns an M x N x 4 RGBA
# array, which imshow displays as-is.
fig, axes = plt.subplots(1, 2)
panels = [
    (d, {'cmap': cm}, 'Bad'),
    (cm(d), {}, 'Good'),
]
for ax, (image, imshow_kwargs, title) in zip(axes, panels):
    ax.imshow(image, interpolation='none', **imshow_kwargs)
    ax.set_title(title)

plt.show()
