from numpy import zeros, array
from scipy.linalg import lu_factor, lu_solve
from scipy.sparse.linalg import cg, bicg
from scipy.sparse import spdiags

def pressureMatrix(N, dx, dy, linSol):
    """Assemble the 5-point finite-volume Laplacian for cell-centred pressure.

    Unknowns are numbered row-major, ``k = i * N + j``.  The coefficient that
    couples cell (i, j) to a neighbour is face length over centre-to-centre
    distance:

    * j-neighbours (k -/+ 1):  ``2*dy[i][j] / (dx[i][j] + dx[i][j-/+1])``
    * i-neighbours (k -/+ N):  ``2*dx[i][j] / (dy[i][j] + dy[i-/+1][j])``

    i.e. j is treated as the x-direction (spacing dx) and i as the
    y-direction (spacing dy).  NOTE(review): this is the convention the
    diagonal terms were always built with; confirm it matches how the caller
    lays out ``dx``/``dy``.

    Neighbours outside the N x N grid are omitted (zero-gradient boundary)
    and the diagonal is the sum of the neighbour coefficients, so every row
    sums to zero; the matrix is therefore singular, with the constant vector
    in its null space.

    Parameters
    ----------
    N : int
        Cells per direction; the system has N*N unknowns.  N >= 2 is assumed
        (for N == 1 the diagonal offsets passed to ``spdiags`` coincide).
    dx, dy : 2-D arrays or nested sequences indexed ``[i][j]``
        Cell widths.
    linSol : int
        1 returns the dense LU factorization from ``scipy.linalg.lu_factor``;
        any other value returns the sparse matrix itself.

    Returns
    -------
    A ``scipy.sparse`` DIA matrix of shape (N*N, N*N) with a negative
    diagonal and positive off-diagonals (for positive cell widths), or the
    ``(lu, piv)`` tuple from ``lu_factor`` when ``linSol == 1``.
    """
    # Column c of ``pMat`` holds diagonal offset d[c] = (-N, -1, 0, 1, N).
    # spdiags places data[c, col] at matrix position (col - d[c], col), so an
    # off-diagonal coefficient of row k is stored at the *neighbour's* index.
    pMat = zeros([N*N, 5], float)
    for i in range(N):
        for j in range(N):
            k = i * N + j
            diag = 0.0
            if j > 0:          # neighbour (i, j-1), offset -1
                c = 2. * dy[i][j] / (dx[i][j] + dx[i][j-1])
                pMat[k-1, 1] = -c
                diag += c
            if i > 0:          # neighbour (i-1, j), offset -N
                c = 2. * dx[i][j] / (dy[i][j] + dy[i-1][j])
                pMat[k-N, 0] = -c
                diag += c
            if i < N - 1:      # neighbour (i+1, j), offset +N
                c = 2. * dx[i][j] / (dy[i][j] + dy[i+1][j])
                pMat[k+N, 4] = -c
                diag += c
            if j < N - 1:      # neighbour (i, j+1), offset +1
                c = 2. * dy[i][j] / (dx[i][j] + dx[i][j+1])
                pMat[k+1, 3] = -c
                diag += c
            pMat[k, 2] = diag

    d = array([-N, -1, 0, 1, N])

    # Negate so the diagonal is negative and the off-diagonals positive.
    A = spdiags(-pMat.T, d, N*N, N*N)

    if linSol == 1:
        # lu_factor works on dense arrays only.
        A = lu_factor(A.toarray())

    return A

def _krylov_solve(solver, A, f, tol):
    """Run a SciPy Krylov solver (``cg``/``bicg``) with relative tolerance *tol*.

    SciPy 1.12 renamed the ``tol`` keyword to ``rtol`` and later releases
    removed ``tol``, so try the new name first and fall back to the old one
    on SciPy versions that do not accept ``rtol``.

    Returns the solver's ``(x, info)`` tuple unchanged.
    """
    try:
        return solver(A, f, rtol=tol)
    except TypeError:
        return solver(A, f, tol=tol)


def solvePoisson(A, f, N, linSol, resTol):
    """Solve the pressure Poisson system ``A p = f`` and return p on the grid.

    Parameters
    ----------
    A
        For ``linSol == 1`` the ``(lu, piv)`` tuple from
        ``scipy.linalg.lu_factor``; otherwise a (sparse) matrix accepted by
        ``scipy.sparse.linalg.cg`` / ``bicg``.
    f : array of length N*N
        Right-hand side, ordered row-major (``k = i * N + j``).
    N : int
        Grid size; the result is N x N.
    linSol : int
        1 = direct LU solve, 2 = conjugate gradient, 3 = BiCG.
    resTol : float or None
        Relative residual tolerance for the iterative solvers (not used by
        the LU path).  ``None`` selects 1e-15.

    Returns
    -------
    numpy.ndarray of shape (N, N) with ``p[i][j] == pressure[i * N + j]``.

    Raises
    ------
    ValueError
        If *linSol* is not 1, 2 or 3.
    """
    import warnings

    if linSol == 1:
        pressure = lu_solve(A, f)
    elif linSol in (2, 3):
        solver = cg if linSol == 2 else bicg
        tol = 1e-15 if resTol is None else resTol
        pressure, info = _krylov_solve(solver, A, f, tol)
        if info != 0:
            # info > 0: tolerance not reached within maxiter; info < 0: breakdown
            # or illegal input.  The last iterate is still returned.
            warnings.warn(
                f"iterative pressure solve (linSol={linSol}) did not converge, info={info}",
                RuntimeWarning,
                stacklevel=2,
            )
    else:
        raise ValueError(
            f"linSol must be 1 (LU), 2 (CG) or 3 (BiCG), got {linSol!r}")

    # Store pressures in cells: unknown k = i*N + j is row-major, so a C-order
    # reshape puts pressure[k] at p[i][j].
    return array(pressure, dtype=float).reshape(N, N)