import matplotlib as mpl
mpl.use('tkagg')  # Use an interactive backend for this example.
import matplotlib.pyplot as plt
import numpy as np

def multi_scatter(x, y, z, color_dict, ax=None, **scatter_kwargs):
    """Scatter-plot the points (x, y), coloring each group by its label in z.

    One ``ax.scatter`` call is made per *distinct* label in ``z``, so every
    point is drawn exactly once and each label gets its own artist.

    Parameters
    ----------
    x, y : array-like
        Point coordinates; expected to be the same length as ``z``.
    z : array-like
        Group label for each point.
    color_dict : mapping
        Maps each distinct value in ``z`` to the color passed as ``c=`` to
        ``ax.scatter``.
    ax : matplotlib Axes, optional
        Axes to draw on; defaults to ``plt.gca()``.
    **scatter_kwargs
        Extra keyword arguments (other than ``c``) forwarded to every
        ``ax.scatter`` call.

    Raises
    ------
    KeyError
        If a label present in ``z`` has no entry in ``color_dict``.
    """
    if ax is None:
        ax = plt.gca()
    # asarray so plain lists also support the boolean-mask indexing below.
    x = np.asarray(x)
    y = np.asarray(y)
    z = np.asarray(z)
    # Iterate over the distinct labels, not over every element of z: looping
    # over z itself re-plots each group once per occurrence of its label.
    for label in np.unique(z):
        mask = z == label
        ax.scatter(x[mask], y[mask], c=color_dict[label], **scatter_kwargs)

def test():
    """Draw a small two-color demo scatter (labels 0 -> blue, 1 -> red) and show it."""
    xs = np.arange(1, 5)
    ys = xs + 1
    labels = np.array([0, 1, 0, 1])
    palette = {0: 'b', 1: 'r'}
    multi_scatter(xs, ys, labels, palette)
    plt.show()

if __name__ == '__main__':
    # Run the demo only when this file is executed directly, not on import.
    test()
