import numpy
import scipy.interpolate
import matplotlib.pyplot as plt


def polynomial(x):
    """
    Evaluate the hard-coded dummy polynomial at *x*.

    The degree-6 coefficients (highest power first, as ``numpy.polyval``
    expects) were picked to track the real data set closely. *x* may be
    anything ``numpy.polyval`` accepts (scalar, sequence or array).
    """
    return numpy.polyval(
        [
            -1.29804361e-22,  # x**6
            1.57325809e-17,   # x**5
            -7.55023475e-13,  # x**4
            1.82307416e-08,   # x**3
            -2.31711156e-04,  # x**2
            1.45808228e+00,   # x**1
            -1.50272586e+02,  # x**0
        ],
        x,
    )


def get_normal_points(cx, cy, cos_t, sin_t, length, scaling):
    """
    Offset (*cx*, *cy*) by *length* along both normals of a line.

    The line passes through (*cx*, *cy*) with direction (*cos_t*, *sin_t*).
    The offset is taken in a space where y is divided by *scaling*, so the
    y component of each offset is multiplied by *scaling* when mapped back
    to the caller's coordinates. Scalars or equally-shaped arrays work.

    Returns ``(x1, y1, x2, y2)``: (x1, y1) lies along the normal
    (sin_t, -cos_t) and (x2, y2) along (-sin_t, cos_t). When *length* is
    zero the centre point is returned twice.
    """
    if length == 0.:
        return cx, cy, cx, cy

    # Both normals share the same magnitudes and differ only in sign.
    dx = length * sin_t
    dy = length * cos_t * scaling
    return cx + dx, cy - dy, cx - dx, cy + dy


def _display_scaling(fig, ax, xlims, ylims):
    """
    Return how many y data units are drawn as long as one x data unit.

    Combines the ratio of the axis ranges with the physical aspect ratio of
    the axes box, so dividing y by the result gives a space in which
    distances and angles match what is seen on the figure.

    NOTE: uses the axes position at call time; it is only exact while the
    figure keeps that size/layout (e.g. not after interactive resizing).
    """
    fig_width, fig_height = fig.get_size_inches()
    box = ax.get_position()  # in figure-fraction units
    x_per_inch = (xlims[1] - xlims[0]) / (box.width * fig_width)
    y_per_inch = (ylims[1] - ylims[0]) / (box.height * fig_height)
    return y_per_inch / x_per_inch


def main():
    """
    Plot the dummy curve and two curves offset from it by a fixed distance.

    The x and y axes span very different data ranges, so the offset is
    computed in a rescaled space (y divided by ``scaling``) where one y unit
    is drawn as long as one x unit. ``distance`` is therefore expressed in x
    data units, and the offset curves look parallel and equidistant on
    screen.
    """
    # Create the X, Y pairs of data for the original curve
    x = numpy.linspace(0, 36000, 1000)
    y = polynomial(x)

    # These 5 numbers are fixed... I can't change them
    distance = 800
    xlims = (0., 40000.)
    ylims = (1500., 4000.)

    # The axes must exist (with their final limits) before the on-screen
    # ratio between x and y units can be measured.
    fig, ax = plt.subplots()
    ax.set_xlim(xlims)
    ax.set_ylim(ylims)
    scaling = _display_scaling(fig, ax, xlims, ylims)

    # Interpolate with splines and get dy/dx in data units
    tck = scipy.interpolate.splrep(x, y)
    y_deriv = scipy.interpolate.splev(x, tck, der=1)

    # The tangent angle must be measured in the same rescaled space the
    # normal offsets are applied in; using the raw data-space slope would
    # make the normals visibly non-perpendicular and the gap uneven.
    t = numpy.arctan(y_deriv / scaling)
    cos_t, sin_t = numpy.cos(t), numpy.sin(t)

    # Get the normals at every point
    xhigh, yhigh, xlow, ylow = get_normal_points(
        x, y, cos_t, sin_t, distance, scaling)

    # Plotting stuff
    ax.plot(x, y, color='orange', ls='-', lw=3, label='data', zorder=30)

    ax.plot(xlow, ylow, 'b-', label='low', zorder=30)
    ax.plot(xhigh, yhigh, 'g-', label='high', zorder=30)

    ax.grid()

    ax.invert_yaxis()
    ax.legend()

    # Check points where a wrongly computed offset is easiest to see
    # (gap too large / too small).
    ax.annotate('LARGER',
                xy=(1860, 1660), xycoords='data',
                xytext=(5000, 1660), textcoords='data',
                ha='center', va='center',
                arrowprops=dict(arrowstyle='->',
                                connectionstyle='arc3'),
                )

    ax.annotate('SMALLER',
                xy=(5000, 3160), xycoords='data',
                xytext=(7000, 2900), textcoords='data',
                ha='center', va='center',
                arrowprops=dict(arrowstyle='->',
                                connectionstyle='arc3'),
                )

    plt.show()


if __name__ == '__main__':
    main()

    