import numpy as np
ma = np.ma

def shiftgrid(lon0,datain,lonsin,start=True,cyclic=360.0):
    """
    Shift global lat/lon grid east or west.

    The input grid may either include the wraparound (cyclic) point, i.e.
    ``lonsin[-1] - lonsin[0] == cyclic``, or omit it. When it is included,
    the duplicated first column is skipped in the wrapped segment so it
    does not appear twice in the output.

    .. tabularcolumns:: |l|L|

    ==============   ====================================================
    Arguments        Description
    ==============   ====================================================
    lon0             starting longitude for shifted grid
                     (ending longitude if start=False). lon0 must lie
                     within the range of lonsin; the input longitude
                     nearest to lon0 is used as the split point.
    datain           original data (array-like, may be masked).
                     Longitude is the last axis and its length must
                     equal len(lonsin).
    lonsin           original longitudes (array-like, may be masked),
                     assumed increasing: the range check compares lon0
                     against the first and last entries.
    ==============   ====================================================

    .. tabularcolumns:: |l|L|

    ==============   ====================================================
    Keywords         Description
    ==============   ====================================================
    start            if True, lon0 represents the starting longitude
                     of the new grid. if False, lon0 is the ending
                     longitude. Default True.
    cyclic           length of one full longitude period (default
                     360.0). Used to detect the cyclic point and to
                     offset the wrapped longitudes.
    ==============   ====================================================

    returns ``dataout,lonsout`` (data and longitudes on shifted grid).
    Masked-array inputs yield masked-array outputs.

    Raises ``ValueError`` if lon0 is outside the range of lonsin, or if the
    last axis of datain does not have length len(lonsin).
    """
    # Accept any array-like; asanyarray keeps MaskedArray inputs intact.
    datain = np.asanyarray(datain)
    lonsin = np.asanyarray(lonsin)
    nlons = len(lonsin)
    if datain.ndim == 0 or datain.shape[-1] != nlons:
        raise ValueError('last axis of datain must match length of lonsin')
    if np.fabs(lonsin[-1]-lonsin[0]-cyclic) > 1.e-4:
        # No cyclic point: use every column rather than rejecting the grid.
        start_idx = 0
    else:
        # Cyclic point present: skip the duplicate column when wrapping.
        start_idx = 1
    if lon0 < lonsin[0] or lon0 > lonsin[-1]:
        raise ValueError('lon0 outside of range of lonsin')
    # Input column closest to lon0; it becomes column 0 of the output.
    i0 = np.argmin(np.fabs(lonsin-lon0))
    # Number of output columns taken from the tail lonsin[i0:].
    i0_shift = nlons-i0
    if np.ma.isMA(datain):
        dataout = np.ma.zeros(datain.shape,datain.dtype)
    else:
        dataout = np.zeros(datain.shape,datain.dtype)
    if np.ma.isMA(lonsin):
        lonsout = np.ma.zeros(lonsin.shape,lonsin.dtype)
    else:
        lonsout = np.zeros(lonsin.shape,lonsin.dtype)
    # The tail of the input goes first ...
    if start:
        lonsout[0:i0_shift] = lonsin[i0:]
    else:
        lonsout[0:i0_shift] = lonsin[i0:]-cyclic
    # Ellipsis indexing keeps longitude on the last axis for any rank.
    dataout[...,0:i0_shift] = datain[...,i0:]
    # ... followed by the head, offset by one period relative to the tail.
    if start:
        lonsout[i0_shift:] = lonsin[start_idx:i0+start_idx]+cyclic
    else:
        lonsout[i0_shift:] = lonsin[start_idx:i0+start_idx]
    dataout[...,i0_shift:] = datain[...,start_idx:i0+start_idx]
    return dataout,lonsout
    
def run_test(lon, grid, lonout, gridout, lon0=180, start=False):
    """Shift ``grid`` with ``shiftgrid`` and compare with the expected result.

    All four array arguments are converted to float arrays first. By default
    the grid is shifted so that ``lon0=180`` is its *ending* longitude
    (``start=False``); pass ``lon0``/``start`` to exercise other shifts.

    Raises ``AssertionError`` if the shifted longitudes or data differ from
    ``lonout``/``gridout`` in shape or in any element.
    """
    lon = np.asarray(lon, dtype=float)
    grid = np.asarray(grid, dtype=float)
    lonout = np.asarray(lonout, dtype=float)
    gridout = np.asarray(gridout, dtype=float)

    testgrid, testlon = shiftgrid(lon0, grid, lon, start=start)

    # Unlike ``assert (a == b).all()``, these reject mismatched shapes instead
    # of broadcasting, report the differing elements, and survive ``python -O``.
    np.testing.assert_array_equal(testlon, lonout)
    np.testing.assert_array_equal(testgrid, gridout)


def test_cyclic():
    """Shift a grid whose last column duplicates the first (cyclic point)."""
    # 0..360 in 30-degree steps; the value at 360 repeats the value at 0.
    lon = list(range(0, 361, 30))
    grid = [list(range(12)) + [0]]

    # Ending the grid at 180 means 180 (value 6) appears at both ends.
    lonout = list(range(-180, 181, 30))
    gridout = [list(range(6, 12)) + list(range(0, 7))]

    run_test(lon, grid, lonout, gridout)
    
def test_no_cyclic():
    """Shift a grid that omits the duplicated wraparound column."""
    # 0..330 in 30-degree steps; there is no column at 360.
    lon = list(range(0, 360, 30))
    grid = [list(range(12))]

    # Ending the grid at 180 rotates the second half to the front.
    lonout = list(range(-180, 180, 30))
    gridout = [list(range(6, 12)) + list(range(6))]

    run_test(lon, grid, lonout, gridout)

if __name__ == '__main__':
    # Run the tests in order when this file is executed as a script.
    for _check in (test_cyclic, test_no_cyclic):
        _check()
