mdiggory 2003/06/15 10:01:39 Modified: math/src/java/org/apache/commons/math RealMatrix.java RealMatrixImpl.java math/src/test/org/apache/commons/math RealMatrixImplTest.java Log: PR: http://nagoya.apache.org/bugzilla/show_bug.cgi?id=20783 Submitted by: [EMAIL PROTECTED] Revision Changes Path 1.4 +105 -58 jakarta-commons-sandbox/math/src/java/org/apache/commons/math/RealMatrix.java Index: RealMatrix.java =================================================================== RCS file: /home/cvs/jakarta-commons-sandbox/math/src/java/org/apache/commons/math/RealMatrix.java,v retrieving revision 1.3 retrieving revision 1.4 diff -u -r1.3 -r1.4 --- RealMatrix.java 11 Jun 2003 14:50:29 -0000 1.3 +++ RealMatrix.java 15 Jun 2003 17:01:39 -0000 1.4 @@ -61,176 +61,223 @@ */ public interface RealMatrix { - public RealMatrix copy(); + /** + * Returns a (deep) copy of this. + * + * @return matrix copy + */ + RealMatrix copy(); /** - * Compute the sum of *this and m + * Compute the sum of this and m. + * * @param m matrix to be added * @return this + m - * @exception IllegalArgumentException if m is not the same size as *this + * @exception IllegalArgumentException if m is not the same size as this */ - public RealMatrix add(RealMatrix m); + RealMatrix add(RealMatrix m) throws IllegalArgumentException; /** - * Compute *this minus m + * Compute this minus m. + * * @param m matrix to be subtracted * @return this + m - * @exception IllegalArgumentException if m is not the same size as *this + * @exception IllegalArgumentException if m is not the same size as this */ - public RealMatrix subtract(RealMatrix m); + RealMatrix subtract(RealMatrix m) throws IllegalArgumentException; /** - * Returns the rank of the matrix - * @return the rank of this matrix + * Returns the rank of the matrix. + * + * @return the rank of this matrix */ - public int getRank(); + int getRank(); /** - * Returns the result of adding d to each entry of *this + * Returns the result of adding d to each entry of this. 
+ * * @param d value to be added to each entry * @return d + this */ - public RealMatrix scalarAdd(double d); + RealMatrix scalarAdd(double d); /** - * Returns the result multiplying each entry of *this by d + * Returns the result multiplying each entry of this by d. + * * @param d value to multiply all entries by - * @return d*this + * @return d * this */ - public RealMatrix scalarMultiply(double d); + RealMatrix scalarMultiply(double d); /** - * Returns the result postmultiplyin *this by m + * Returns the result postmultiplying this by m. + * * @param m matrix to postmultiply by - * @return this*m + * @return this * m * @throws IllegalArgumentException * if columnDimension(this) != rowDimension(m) */ - public RealMatrix multiply(RealMatrix m); + RealMatrix multiply(RealMatrix m) throws IllegalArgumentException; /** - * Returns matrix entries as a two-dimensional array + * Returns matrix entries as a two-dimensional array. + * * @return 2-dimensional array of entries */ - public double[][] getData(); + double[][] getData(); /** - * Sets/overwrites the underlying data for the matrix + * Overwrites the underlying data for the matrix with + * a fresh copy of <code>data</code>. + * * @param data 2-dimensional array of entries */ - public void setData(double[][] data); + void setData(double[][] data); /** - * Returns the norm of the matrix + * Returns the <a href="http://mathworld.wolfram.com/ + * MaximumAbsoluteRowSumNorm.html">maximum absolute row sum norm</a> + * of the matrix. + * * @return norm */ - public double getNorm(); + double getNorm(); /** - * Returns entries in row as an array - * @param row the row to be fetched - * @return array of entries in the row - * @throws IllegalArgumentException if row > rowDimension + * Returns the entries in row number <code>row</code> as an array. 
+ * + * @param row the row to be fetched + * @return array of entries in the row + * @throws IllegalArgumentException if row > rowDimension */ - public double[] getRow(int row); + double[] getRow(int row) throws IllegalArgumentException; /** - * Returns entries in column as an array + * Returns the entries in column number <code>col</code> as an array. + * * @param col column to fetch - * @return array of entries in the column - * @throws IllegalArgumentException if column > columnDimension + * @return array of entries in the column + * @throws IllegalArgumentException if column > columnDimension */ - public double[] getColumn(int col); + double[] getColumn(int col) throws IllegalArgumentException; /** - * Returns the entry in the specified row and column + * Returns the entry in the specified row and column. + * * @param row row location of entry to be fetched * @param column column location of entry to be fetched * @return matrix entry in row,column * @throws IllegalArgumentException if entry does not exist */ - public double getEntry(int row, int column); + double getEntry(int row, int column) throws IllegalArgumentException; /** - * Sets the entry in the specified row and column to the specified value + * Sets the entry in the specified row and column to the specified value. + * * @param row row location of entry to be set * @param column column location of entry to be set * @param value value to set * @throws IllegalArgumentException if entry does not exist */ - public void setEntry(int row, int column, double value); + void setEntry(int row, int column, double value) + throws IllegalArgumentException; /** - * Returns the transpose of this matrix + * Returns the transpose of this matrix. + * * @return transpose matrix */ - public RealMatrix transpose(); + RealMatrix transpose(); /** - * Returns the inverse of this matrix + * Returns the inverse of this matrix. 
+ * * @return inverse matrix * @throws IllegalArgumentException if *this is not invertible */ - public RealMatrix inverse(); + RealMatrix inverse() throws IllegalArgumentException; /** - * Returns the determinant of this matrix + * Returns the determinant of this matrix. + * * @return determinant */ - public double getDeterminant(); + double getDeterminant(); /** * Is this a square matrix? * @return true if the matrix is square (rowDimension = columnDimension) */ - public boolean isSquare(); + boolean isSquare(); /** * Is this a singular matrix? * @return true if the matrix is singular */ - public boolean isSingular(); + boolean isSingular(); /** - * Returns the number of rows in the matrix + * Returns the number of rows in the matrix. + * * @return rowDimension */ - public int getRowDimension(); + int getRowDimension(); /** - * Returns the number of columns in the matrix + * Returns the number of columns in the matrix. + * * @return columnDimension */ - public int getColumnDimension(); + int getColumnDimension(); /** - * Returns the trace of the matrix + * Returns the <a href="http://mathworld.wolfram.com/MatrixTrace.html"> + * trace</a> of the matrix (the sum of the elements on the main diagonal). + * * @return trace */ - public double getTrace(); + double getTrace(); /** - * Returns the result of multiplying this by vector v + * Returns the result of multiplying this by the vector <code>v</code>. + * + * @param v the vector to operate on * @return this*v * @throws IllegalArgumentException if columnDimension != v.size() */ - public double[] operate(double[] v); + double[] operate(double[] v) throws IllegalArgumentException; /** - * Returns the result of premultiplying this by vector v + * Returns the result of premultiplying this by the vector <code>v</code>. 
+ * + * @param v the row vector to premultiply by * @return v*this * @throws IllegalArgumentException if rowDimension != v.size() */ - public RealMatrix preMultiply(double[] v); + RealMatrix preMultiply(double[] v) throws IllegalArgumentException; /** * Returns the solution vector for a linear system with coefficient - * matrix = *this and constant vector = b + * matrix = this and constant vector = <code>b</code>. + * * @param b constant vector * @return vector of solution values to AX = b, where A is *this * @throws IllegalArgumentException if rowDimension != b.length or matrix * is singular */ - public double[] solve(double[] b); + double[] solve(double[] b) throws IllegalArgumentException; + + /** + * Returns a matrix of (column) solution vectors for linear systems with + * coefficient matrix = this and constant vectors = columns of + * <code>b</code>. + * + * @param b matrix of constant vectors forming RHS of linear systems + * to solve + * @return matrix of solution vectors + * @throws IllegalArgumentException if rowDimension != row dimension of b + * or this is singular + */ + RealMatrix solve(RealMatrix b) throws IllegalArgumentException; } 1.4 +510 -133 jakarta-commons-sandbox/math/src/java/org/apache/commons/math/RealMatrixImpl.java Index: RealMatrixImpl.java =================================================================== RCS file: /home/cvs/jakarta-commons-sandbox/math/src/java/org/apache/commons/math/RealMatrixImpl.java,v retrieving revision 1.3 retrieving revision 1.4 diff -u -r1.3 -r1.4 --- RealMatrixImpl.java 11 Jun 2003 14:50:29 -0000 1.3 +++ RealMatrixImpl.java 15 Jun 2003 17:01:39 -0000 1.4 @@ -56,50 +56,118 @@ import java.io.Serializable; /** - * Implementation for RealMatrix using double[][] array - * @author Phil Stetiz + * Implementation for RealMatrix using a double[][] array to store entries + * and <a href="http://www.math.gatech.edu/~bourbaki/ + * math2601/Web-notes/2num.pdf">LU decomposition</a> to support linear system + * solution
and inverse. + * <p> + * The <a href="http://www.math.gatech.edu/~bourbaki/math2601/Web-notes + * /2num.pdf">LU decompostion</a> is performed as needed, to support the + * following operations: <ul> + * <li>solve</li> + * <li>isSingular</li> + * <li>getDeterminant</li> + * <li>inverse</li> </ul> + * <p> + * <strong>Usage note</strong>:<br> + * The LU decomposition is stored and reused on subsequent calls. If matrix + * data are modified using any of the public setXxx methods, the saved + * decomposition is discarded. If data are modified via references to the + * underlying array obtained using <code>getDataRef()</code>, then the stored + * LU decomposition will not be discarded. In this case, you need to + * explicitly invoke <code>LUDecompose()</code> to recompute the decomposition + * before using any of the methods above. + * + * @author Phil Steitz * @version $Revision$ $Date$ */ public class RealMatrixImpl implements RealMatrix, Serializable { /** Entries of the matrix */ - private double data[][]; + private double data[][] = null; + /** Entries of LU decomposition. + * All updates to data (other than luDecompostion) *must* set this to null + */ + private double lu[][] = null; + + /** Pivot array associated with LU decompostion */ + private int[] pivot = null; + + /** Parity of the permutation associated with the LU decomposition */ + private int parity = 1; + + /** Bound to determine effective singularity in LU decomposition */ + private static double TOO_SMALL = 10E-12; + + /** + * Creates a matrix with no data + */ public RealMatrixImpl() { } - - /** - * Create a new RealMatrix with the supplied row and column dimensions + + /** + * Create a new RealMatrix with the supplied row and column dimensions. 
+ * * @param rowDimension the number of rows in the new matrix * @param columnDimension the number of columns in the new matrix - */ + */ public RealMatrixImpl(int rowDimension, - int columnDimension) { + int columnDimension) { data = new double[rowDimension][columnDimension]; + lu = null; } - public RealMatrixImpl(double[][] data) { - this.data = data; + /** + * Create a new RealMatrix using the <code>data</code> as the underlying + * data array. + * <p> + * The input array is copied, not referenced. + * + * @param d data for new matrix + */ + public RealMatrixImpl(double[][] d) { + this.copyIn(d); + lu = null; } /** - * Create a new RealMatrix which is a copy of *this + * Create a new (column) RealMatrix using <code>v</code> as the + * data for the unique column of the <code>v.length x 1</code> matrix + * created. + * <p> + * The input array is copied, not referenced. + * + * @param v column vector holding data for new matrix + */ + public RealMatrixImpl(double[] v) { + int nRows = v.length; + data = new double[nRows][1]; + for (int row = 0; row < nRows; row++) { + data[row][0] = v[row]; + } + } + + /** + * Create a new RealMatrix which is a copy of this. + * * @return the cloned matrix */ public RealMatrix copy() { - throw new UnsupportedOperationException("not implemented yet"); + return new RealMatrixImpl(this.copyOut()); } /** - * Compute the sum of *this and m + * Compute the sum of this and <code>m</code>. 
+ * * @param m matrix to be added * @return this + m - * @exception IllegalArgumentException if m is not the same size as *this + * @exception IllegalArgumentException if m is not the same size as this */ - public RealMatrix add(RealMatrix m) { + public RealMatrix add(RealMatrix m) throws IllegalArgumentException { if (this.getColumnDimension() != m.getColumnDimension() || - this.getRowDimension() != m.getRowDimension()) { - throw new IllegalArgumentException("matrix dimension mismatch"); + this.getRowDimension() != m.getRowDimension()) { + throw new IllegalArgumentException("matrix dimension mismatch"); } int rowCount = this.getRowDimension(); int columnCount = this.getColumnDimension(); @@ -114,15 +182,16 @@ } /** - * Compute *this minus m + * Compute this minus <code>m</code>. + * * @param m matrix to be subtracted * @return this + m * @exception IllegalArgumentException if m is not the same size as *this */ - public RealMatrix subtract(RealMatrix m) { + public RealMatrix subtract(RealMatrix m) throws IllegalArgumentException { if (this.getColumnDimension() != m.getColumnDimension() || - this.getRowDimension() != m.getRowDimension()) { - throw new IllegalArgumentException("matrix dimension mismatch"); + this.getRowDimension() != m.getRowDimension()) { + throw new IllegalArgumentException("matrix dimension mismatch"); } int rowCount = this.getRowDimension(); int columnCount = this.getColumnDimension(); @@ -137,16 +206,19 @@ } /** - * Returns the rank of the matrix - * @return the rank of this matrix + * Returns the rank of the matrix. + * + * @return the rank of this matrix */ public int getRank() { + // FIXME: need to add singular value decomposition or drop this throw new UnsupportedOperationException("not implemented yet"); } - - /** - * Returns the result of adding d to each entry of *this + + /** + * Returns the result of adding d to each entry of this. 
+ * * @param d value to be added to each entry * @return d + this */ @@ -161,11 +233,11 @@ } return new RealMatrixImpl(outData); } - + /** - * Returns the result multiplying each entry of *this by d - * @param d value to multiply all entries by - * @return d*this + * Returns the result multiplying each entry of this by <code>d</code> + * @param d value to multiply all entries by + * @return d * this */ public RealMatrix scalarMultiply(double d) { int rowCount = this.getRowDimension(); @@ -173,152 +245,216 @@ double[][] outData = new double[rowCount][columnCount]; for (int row = 0; row < rowCount; row++) { for (int col = 0; col < columnCount; col++) { - outData[row][col] = data[row][col]*d; + outData[row][col] = data[row][col] * d; } } return new RealMatrixImpl(outData); } /** - * Returns the result postmultiplying *this by m + * Returns the result postmultiplying this by <code>m</code>. * @param m matrix to postmultiply by * @return this*m - * @throws IllegalArgumentException + * @throws IllegalArgumentException * if columnDimension(this) != rowDimension(m) */ - public RealMatrix multiply(RealMatrix m) { - if (this.getColumnDimension() != m.getRowDimension()) { - throw new IllegalArgumentException + public RealMatrix multiply(RealMatrix m) throws IllegalArgumentException { + if (this.getColumnDimension() != m.getRowDimension()) { + throw new IllegalArgumentException ("Matrices are not multiplication compatible."); - } - double[][] mData = m.getData(); - double[][] outData = - new double[this.getRowDimension()][m.getColumnDimension()]; - double sum = 0; - for (int row = 0; row < this.getRowDimension(); row++) { - for (int col = 0; col < m.getColumnDimension(); col++) { - sum = 0; - for (int i = 0; i < this.getColumnDimension(); i++) { - sum += data[row][i] * mData[i][col]; - } - outData[row][col] = sum; - } - } - return new RealMatrixImpl(outData); + } + int nRows = this.getRowDimension(); + int nCols = this.getColumnDimension(); + double[][] mData = m.getData(); + 
double[][] outData = + new double[nRows][nCols]; + double sum = 0; + for (int row = 0; row < nRows; row++) { + for (int col = 0; col < nCols; col++) { + sum = 0; + for (int i = 0; i < nCols; i++) { + sum += data[row][i] * mData[i][col]; + } + outData[row][col] = sum; + } + } + return new RealMatrixImpl(outData); } /** - * Returns matrix entries as a two-dimensional array + * Returns matrix entries as a two-dimensional array. + * <p> + * Makes a fresh copy of the underlying data. + * * @return 2-dimensional array of entries */ public double[][] getData() { + return copyOut(); + } + + /** + * Overwrites the underlying data for the matrix + * with a fresh copy of <code>inData</code>. + * + * @param inData 2-dimensional array of entries + */ + public void setData(double[][] inData) { + copyIn(inData); + lu = null; + } + + /** + * Returns a reference to the underlying data array. + * <p> + * Does not make a fresh copy of the underlying data. + * + * @return 2-dimensional array of entries + */ + public double[][] getDataRef() { return data; } /** - * Sets/overwrites the underlying data for the matrix - * @param data 2-dimensional array of entries + * Overwrites the underlying data for the matrix + * with a reference to <code>inData</code>. + * <p> + * Does not make a fresh copy of <code>data</code>. 
+ * + * @param inData 2-dimensional array of entries */ - public void setData(double[][] data) { - this.data = data; + public void setDataRef(double[][] inData) { + this.data = inData; + lu = null; } /** - * Returns the 1-norm of the matrix (max column sum) + * * @return norm */ public double getNorm() { - double maxColSum = 0; - for (int col = 0; col < this.getColumnDimension(); col++) { - double sum = 0; - for (int row = 0; row < this.getRowDimension(); row++) { - sum += Math.abs(data[row][col]); - } - maxColSum = Math.max(maxColSum,sum); - } - return maxColSum; + double maxColSum = 0; + for (int col = 0; col < this.getColumnDimension(); col++) { + double sum = 0; + for (int row = 0; row < this.getRowDimension(); row++) { + sum += Math.abs(data[row][col]); + } + maxColSum = Math.max(maxColSum, sum); + } + return maxColSum; } /** - * Returns entries in row as an array + * * @param row the row to be fetched - * @return array of entries in the row - * @throws IllegalArgumentException if row > rowDimension + * @return array of entries in the row + * @throws IllegalArgumentException if row > rowDimension or row < 1 */ - public double[] getRow(int row) { - return data[row]; + public double[] getRow(int row) throws IllegalArgumentException { + if (row > this.getRowDimension() || row < 1) { + throw new IllegalArgumentException("illegal row argument"); + } + int ncols = this.getColumnDimension(); + double[] out = new double[ncols]; + System.arraycopy(data[row - 1], 0, out, 0, ncols); + return out; } /** - * Returns entries in column as an array - * @param col column to fetch - * @return array of entries in the column - * @throws IllegalArgumentException if column > columnDimension + * @param col column to fetch + * @return array of entries in the column + * @throws IllegalArgumentException if column > columnDimension or + * column < 1 */ - public double[] getColumn(int col) { - throw new UnsupportedOperationException("not implemented yet"); + public double[] getColumn(int 
col) throws IllegalArgumentException { + if (col > this.getColumnDimension() || col < 1) { + throw new IllegalArgumentException("illegal column argument"); + } + int nRows = this.getRowDimension(); + double[] out = new double[nRows]; + for (int row = 0; row < nRows; row++) { + out[row] = data[row][col - 1]; + } + return out; } /** - * Returns the entry in the specified row and column - * @param row row location of entry to be fetched + * @param row row location of entry to be fetched * @param column column location of entry to be fetched - * @return matrix entry in row,column - * @throws IllegalArgumentException if entry does not exist + * @return matrix entry in row,column + * @throws IllegalArgumentException if entry does not exist */ - public double getEntry(int row, int column) { - if (row < 1 || column < 1 || row > this.getRowDimension() - || column > this.getColumnDimension()) { - throw new IllegalArgumentException - ("matrix entry does not exist"); + public double getEntry(int row, int column) + throws IllegalArgumentException { + if (row < 1 || column < 1 || row > this.getRowDimension() + || column > this.getColumnDimension()) { + throw new IllegalArgumentException + ("matrix entry does not exist"); } - return data[row-1][column-1]; + return data[row - 1][column - 1]; } /** - * Sets the entry in the specified row and column to the specified value - * @param row row location of entry to be set + * @param row row location of entry to be set * @param column column location of entry to be set - * @param value value to set + * @param value value to set * @throws IllegalArgumentException if entry does not exist */ - public void setEntry(int row, int column, double value) { + public void setEntry(int row, int column, double value) + throws IllegalArgumentException { if (row < 1 || column < 1 || row > this.getRowDimension() - || column > this.getColumnDimension()) { - throw new IllegalArgumentException - ("matrix entry does not exist"); + || column > 
this.getColumnDimension()) { + throw new IllegalArgumentException + ("matrix entry does not exist"); } - data[row-1][column-1] = value; + data[row - 1][column - 1] = value; + lu = null; } /** - * Returns the transpose of this matrix + * * @return transpose matrix */ public RealMatrix transpose() { - throw new UnsupportedOperationException("not implemented yet"); - } - + int nRows = this.getRowDimension(); + int nCols = this.getColumnDimension(); + RealMatrixImpl out = new RealMatrixImpl(nCols, nRows); + double[][] outData = out.getDataRef(); + for (int row = 0; row < nRows; row++) { + for (int col = 0; col < nCols; col++) { + outData[col][row] = data[row][col]; + } + } + return out; + } /** - * Returns the inverse of this matrix * @return inverse matrix - * @throws IllegalArgumentException if *this is not invertible + * @throws IllegalArgumentException if this is not invertible */ - public RealMatrix inverse() { - throw new UnsupportedOperationException("not implemented yet"); + public RealMatrix inverse() throws IllegalArgumentException { + return solve(getIdentity(this.getRowDimension())); } /** - * Returns the determinant of this matrix * @return determinant + * @throws IllegalArgumentException if matrix is not square */ - public double getDeterminant() { - throw new UnsupportedOperationException("not implemented yet"); + public double getDeterminant() throws IllegalArgumentException { + if (!isSquare()) { + throw new IllegalArgumentException("matrix is not square"); + } + if (isSingular()) { // note: this has side effect of attempting LU + return 0d; // decomp if lu == null + } else { + double det = (double) parity; + for (int i = 0; i < this.getRowDimension(); i++) { + det *= lu[i][i]; + } + return det; + } } /** - * Is this a square matrix? * @return true if the matrix is square (rowDimension = columnDimension) */ public boolean isSquare() { @@ -326,23 +462,29 @@ } /** - * Is this a singular matrix? 
* @return true if the matrix is singular */ public boolean isSingular() { - throw new UnsupportedOperationException("not implemented yet"); + if (lu == null) { + try { + LUDecompose(); + return false; + } catch (IllegalArgumentException ex) { + return true; + } + } else { // LU decomp must have been successfully performed + return false; // so the matrix is not singular + } } /** - * Returns the number of rows in the matrix * @return rowDimension */ public int getRowDimension() { - return data.length; + return data.length; } /** - * Returns the number of columns in the matrix * @return columnDimension */ public int getColumnDimension() { @@ -350,41 +492,276 @@ } /** - * Returns the trace of the matrix * @return trace + * @throws IllegalArgumentException if the matrix is not square */ - public double getTrace() { - throw new UnsupportedOperationException("not implemented yet"); + public double getTrace() throws IllegalArgumentException { + if (!isSquare()) { + throw new IllegalArgumentException("matrix is not square"); + } + double trace = data[0][0]; + for (int i = 1; i < this.getRowDimension(); i++) { + trace += data[i][i]; + } + return trace; } /** - * Returns the result of multiplying this by the vector b - * @return this*v - * @throws IllegalArgumentException if columnDimension != v.size() + * @param v vector to operate on + * @throws IllegalArgumentException if columnDimension != v.length + * @return resulting vector */ - public double[] operate(double[] v) { - throw new UnsupportedOperationException("not implemented yet"); + public double[] operate(double[] v) throws IllegalArgumentException { + if (v.length != this.getColumnDimension()) { + throw new IllegalArgumentException("vector has wrong length"); + } + int nRows = this.getRowDimension(); + int nCols = this.getColumnDimension(); + double[] out = new double[v.length]; + for (int row = 0; row < nRows; row++) { + double sum = 0; + for (int i = 0; i < nCols; i++) { + sum += data[row][i] * v[i]; + } + 
out[row] = sum; + } + return out; } /** - * Returns the result of premultiplying this by the vector v - * @return v*this - * @throws IllegalArgumentException if rowDimension != v.size() + * @param v vector to premultiply by + * @throws IllegalArgumentException if rowDimension != v.length + * @return resulting matrix */ - public RealMatrix preMultiply(double[] v) { - throw new UnsupportedOperationException("not implemented yet"); + public RealMatrix preMultiply(double[] v) throws IllegalArgumentException { + int nCols = this.getColumnDimension(); + if (v.length != nCols) { + throw new IllegalArgumentException("vector has wrong length"); + } + // being a bit lazy here -- probably should implement directly, like + // operate + RealMatrix pm = new RealMatrixImpl(v).transpose(); + return pm.multiply(this); } /** - * Returns the solution vector for a linear system with coefficient - * matrix = *this and constant vector = b * @param b constant vector - * @return vector of solution values to AX = b, where A is *this - * @throws IllegalArgumentException if rowDimension != b.length or matrix + * @return vector of solution values to AX = b, where A is this + * @throws IllegalArgumentException if rowDimension != b.length or matrix * is singular */ - public double[] solve(double[] b) { - throw new UnsupportedOperationException("not implemented yet"); - } + public double[] solve(double[] b) throws IllegalArgumentException { + int nRows = this.getRowDimension(); + if (b.length != nRows) { + throw new IllegalArgumentException + ("constant vector has wrong length"); + } + RealMatrix bMatrix = new RealMatrixImpl(b); + double[][] solution = ((RealMatrixImpl) (solve(bMatrix))).getDataRef(); + double[] out = new double[nRows]; + for (int row = 0; row < nRows; row++) { + out[row] = solution[row][0]; + } + return out; + } + + /** + * Uses LU decomposition, performing the decomposition if the matrix has + * not been decomposed, or if there have been changes to the matrix since + * the last
decomposition. + * + * @param b the constant vector + * @return solution matrix + * @throws IllegalArgumentException if this is singular or dimensions + * do not match. + */ + public RealMatrix solve(RealMatrix b) throws IllegalArgumentException { + if (b.getRowDimension() != this.getRowDimension()) { + throw new IllegalArgumentException("Incorrect row dimension"); + } + if (this.isSingular()) { // side effect: compute LU decomp + throw new IllegalArgumentException("Matrix is singular."); + } + + int nCol = this.getColumnDimension(); + int nRow = this.getRowDimension(); + int nColB = b.getColumnDimension(); + int nRowB = b.getRowDimension(); + + // Apply permutations to b + double[][] bv = b.getData(); + double[][] bp = new double[nRowB][nColB]; + for (int row = 0; row < nRowB; row++) { + for (int col = 0; col < nColB; col++) { + bp[row][col] = bv[pivot[row]][col]; + } + } + bv = null; + + // Solve LY = b + for (int col = 0; col < nCol; col++) { + for (int i = col + 1; i < nCol; i++) { + for (int j = 0; j < nColB; j++) { + bp[i][j] -= bp[col][j] * lu[i][col]; + } + } + } + + // Solve UX = Y + for (int col = nCol - 1; col >= 0; col--) { + for (int j = 0; j < nColB; j++) { + bp[col][j] /= lu[col][col]; + } + for (int i = 0; i < col; i++) { + for (int j = 0; j < nColB; j++) { + bp[i][j] -= bp[col][j] * lu[i][col]; + } + } + } + + RealMatrixImpl outMat = new RealMatrixImpl(bp); + return outMat; + } + + /** + * Computes a new <a href="http://www.math.gatech.edu/~bourbaki/ + * math2601/Web-notes/2num.pdf">LU decomposition</a> for this matrix, + * storing the result for use by other methods. + * <p> + * <strong>Implementation Note</strong>:<br> + * Uses <a href="http://www.damtp.cam.ac.uk/user/fdl/ + * people/sd/lectures/nummeth98/linear.htm">Crout's algorithm</a>, + * with partial pivoting. + * <p> + * <strong>Usage Note</strong>:<br> + * This method should rarely be invoked directly.
Its only use is + * to force recomputation of the LU decomposition when changes have been + * made to the underlying data using direct array references. Changes + * made using setXxx methods will trigger recomputation when needed + * automatically. + * + * @throws IllegalArgumentException if the matrix is singular + */ + public void LUDecompose() throws IllegalArgumentException { + int nRows = this.getRowDimension(); + int nCols = this.getColumnDimension(); + lu = this.getData(); + + // Initialize pivot array and parity + pivot = new int[nRows]; + for (int row = 0; row < nRows; row++) { + pivot[row] = row; + } + parity = 1; + + // Loop over columns + for (int col = 0; col < nCols; col++) { + + double sum = 0; + + // upper + for (int row = 0; row < col; row++) { + sum = lu[row][col]; + for (int i = 0; i < row; i++) { + sum -= lu[row][i] * lu[i][col]; + } + lu[row][col] = sum; + } + + // lower + int max = col; // pivot row + double largest = 0d; + for (int row = col; row < nRows; row++) { + sum = lu[row][col]; + for (int i = 0; i < col; i++) { + sum -= lu[row][i] * lu[i][col]; + } + lu[row][col] = sum; + + // maintain best pivot choice + if (Math.abs(sum) > largest) { + largest = Math.abs(sum); + max = row; + } + } + + // Singularity check + if (Math.abs(lu[max][col]) < TOO_SMALL) { + lu = null; + throw new IllegalArgumentException("matrix is singular"); + } + + // Pivot if necessary + if (max != col) { + double tmp = 0; + for (int i = 0; i < nCols; i++) { + tmp = lu[max][i]; + lu[max][i] = lu[col][i]; + lu[col][i] = tmp; + } + int temp = pivot[max]; + pivot[max] = pivot[col]; + pivot[col] = temp; + parity = -parity; + } + + //Divide the lower elements by the "winning" diagonal elt. + for (int row = col + 1; row < nRows; row++) { + lu[row][col] /= lu[col][col]; + } + } + } + + //------------------------ Protected methods + + /** + * Returns <code>dimension x dimension</code> identity matrix. 
+ * + * @param dimension dimension of identity matrix to generate + * @return identity matrix + */ + protected RealMatrix getIdentity(int dimension) { + RealMatrixImpl out = new RealMatrixImpl(dimension, dimension); + double[][] d = out.getDataRef(); + for (int row = 0; row < dimension; row++) { + for (int col = 0; col < dimension; col++) { + d[row][col] = row == col ? 1d : 0d; + } + } + return out; + } + + //------------------------ Private methods + + /** + * Returns a fresh copy of the underlying data array. + * + * @return a copy of the underlying data array. + */ + private double[][] copyOut() { + int nRows = this.getRowDimension(); + double[][] out = + new double[nRows][this.getColumnDimension()]; + // can't copy 2-d array in one shot, otherwise get row references + for (int i = 0; i < nRows; i++) { + System.arraycopy(data[i], 0, out[i], 0, data[i].length); + } + return out; + } + /** + * Replaces data with a fresh copy of the input array. + * + * @param in data to copy in + */ + private void copyIn(double[][] in) { + int nRows = in.length; + int nCols = in[0].length; + data = new double[nRows][nCols]; + System.arraycopy(in, 0, data, 0, in.length); + for (int i = 0; i < nRows ; i++) { + System.arraycopy(in[i], 0, data[i], 0, nCols); + } + lu = null; + } } 1.2 +232 -6 jakarta-commons-sandbox/math/src/test/org/apache/commons/math/RealMatrixImplTest.java Index: RealMatrixImplTest.java =================================================================== RCS file: /home/cvs/jakarta-commons-sandbox/math/src/test/org/apache/commons/math/RealMatrixImplTest.java,v retrieving revision 1.1 retrieving revision 1.2 diff -u -r1.1 -r1.2 --- RealMatrixImplTest.java 12 May 2003 19:02:53 -0000 1.1 +++ RealMatrixImplTest.java 15 Jun 2003 17:01:39 -0000 1.2 @@ -67,15 +67,28 @@ public final class RealMatrixImplTest extends TestCase { private double[][] testData = { {1d,2d,3d}, {2d,5d,3d}, {1d,0d,8d} }; + private double[][] testDataPlus2 = { {3d,4d,5d}, {4d,7d,5d}, {3d,2d,10d} }; + 
+ private double[][] testDataMinus = { {-1d,-2d,-3d}, {-2d,-5d,-3d}, + {-1d,0d,-8d} }; + private double[] testDataRow1 = {1d,2d,3d}; + private double[] testDataCol3 = {3d,3d,8d}; private double[][] testDataInv = { {-40d,16d,9d}, {13d,-5d,-3d}, {5d,-2d,-1d} }; + private double[][] preMultTest = {{8,12,33}}; private double[][] testData2 ={ {1d,2d,3d}, {2d,5d,3d}}; + private double[][] testData2T = { {1d,2d}, {2d,5d}, {3d,3d}}; private double[][] testDataPlusInv = { {-39d,18d,12d}, {15d,0d,0d}, {6d,-2d,7d} }; private double[][] id = { {1d,0d,0d}, {0d,1d,0d}, {0d,0d,1d} }; + private double[][] luData = { {2d,3d,3d}, {0d,5d,7d}, {6d,9d,8d} }; + private double[][] singular = { {2d,3d}, {2d,3d} }; + private double[][] bigSingular = {{1d,2d,3d,4d}, {2d,5d,3d,4d}, + {7d,3d,256d,1930d}, {3d,7d,6d,8d}}; // 4th row = 1st + 2nd + private double[][] detData = { {1d,2d,3d}, {4d,5d,6d}, {7d,8d,10d} }; private double[] testVector = {1,2,3}; - private double entryTolerance = Math.pow(2,-64); - private double normTolerance = Math.pow(2,-64); + private double[] testVector2 = {1,2,3,4}; + private double entryTolerance = 10E-16; // note: 10E-16 == 1.0E-15 + private double normTolerance = 10E-14; // note: 10E-14 == 1.0E-13 public RealMatrixImplTest(String name) { super(name); @@ -101,7 +114,24 @@ assertEquals("testData2 row dimension",m2.getRowDimension(),2); assertEquals("testData2 column dimension",m2.getColumnDimension(),3); assertTrue("testData2 is not square",!m2.isSquare()); - } + } + + /** test copy functions */ + public void testCopyFunctions() { + RealMatrixImpl m = new RealMatrixImpl(testData); + RealMatrixImpl m2 = new RealMatrixImpl(testData2); + m2.setData(m.getData()); + assertClose("getData",m2,m,entryTolerance); + // no dangling reference... 
+ m2.setEntry(1,1,2000d); + RealMatrixImpl m3 = new RealMatrixImpl(testData); + assertClose("no getData side effect",m,m3,entryTolerance); + m3 = (RealMatrixImpl) m.copy(); + double[][] stompMe = {{1d,2d,3d}}; + m3.setDataRef(stompMe); + assertClose("no copy side effect",m,new RealMatrixImpl(testData), + entryTolerance); + } /** test add */ public void testAdd() { @@ -143,7 +173,13 @@ RealMatrixImpl m = new RealMatrixImpl(testData); RealMatrixImpl m2 = new RealMatrixImpl(testDataInv); assertClose("m-n = m + -n",m.subtract(m2), - m2.scalarMultiply(-1d).add(m),entryTolerance); + m2.scalarMultiply(-1d).add(m),entryTolerance); + try { + RealMatrix a = m.subtract(new RealMatrixImpl(testData2)); + fail("Expecting illegalArgumentException"); + } catch (IllegalArgumentException ex) { + ; + } } /** test multiply */ @@ -161,12 +197,202 @@ assertClose("identity multiply",identity.multiply(mInv), mInv,entryTolerance); assertClose("identity multiply",m2.multiply(identity), - m2,entryTolerance); + m2,entryTolerance); + try { + RealMatrix a = m.multiply(new RealMatrixImpl(bigSingular)); + fail("Expecting illegalArgumentException"); + } catch (IllegalArgumentException ex) { + ; + } + } + + /** test isSingular */ + public void testIsSingular() { + RealMatrixImpl m = new RealMatrixImpl(singular); + assertTrue("singular",m.isSingular()); + m = new RealMatrixImpl(bigSingular); + assertTrue("big singular",m.isSingular()); + m = new RealMatrixImpl(id); + assertTrue("identity nonsingular",!m.isSingular()); + m = new RealMatrixImpl(testData); + assertTrue("testData nonsingular",!m.isSingular()); + } + + /** test inverse */ + public void testInverse() { + RealMatrixImpl m = new RealMatrixImpl(testData); + RealMatrix mInv = new RealMatrixImpl(testDataInv); + assertClose("inverse",mInv,m.inverse(),normTolerance); + assertClose("inverse^2",m,m.inverse().inverse(),10E-12); + } + + /** test solve */ + public void testSolve() { + RealMatrixImpl m = new RealMatrixImpl(testData); + RealMatrix mInv 
= new RealMatrixImpl(testDataInv); + // being a bit slothful here -- actually testing that X = A^-1 * B + assertClose("inverse-operate",mInv.operate(testVector), + m.solve(testVector),normTolerance); + try { + double[] x = m.solve(testVector2); + fail("expecting IllegalArgumentException"); + } catch (IllegalArgumentException ex) { + ; + } + RealMatrix bs = new RealMatrixImpl(bigSingular); + try { + RealMatrix a = bs.solve(bs); + fail("Expecting illegalArgumentException"); + } catch (IllegalArgumentException ex) { + ; + } + try { + RealMatrix a = m.solve(bs); + fail("Expecting illegalArgumentException"); + } catch (IllegalArgumentException ex) { + ; + } + } + + /** test determinant */ + public void testDeterminant() { + RealMatrix m = new RealMatrixImpl(bigSingular); + assertEquals("singular determinant",0,m.getDeterminant(),0); + m = new RealMatrixImpl(detData); + assertEquals("nonsingular test",-3d,m.getDeterminant(),normTolerance); + try { + double a = new RealMatrixImpl(testData2).getDeterminant(); + fail("Expecting illegalArgumentException"); + } catch (IllegalArgumentException ex) { + ; + } + } + + /** test trace */ + public void testTrace() { + RealMatrix m = new RealMatrixImpl(id); + assertEquals("identity trace",3d,m.getTrace(),entryTolerance); + m = new RealMatrixImpl(testData2); + try { + double x = m.getTrace(); + fail("Expecting illegalArgumentException"); + } catch (IllegalArgumentException ex) { + ; + } } + /** test scalarAdd */ + public void testScalarAdd() { + RealMatrix m = new RealMatrixImpl(testData); + assertClose("scalar add",new RealMatrixImpl(testDataPlus2), + m.scalarAdd(2d),entryTolerance); + } + + /** test operate */ + public void testOperate() { + RealMatrix m = new RealMatrixImpl(id); + double[] x = m.operate(testVector); + assertClose("identity operate",testVector,x,entryTolerance); + m = new RealMatrixImpl(bigSingular); + try { + x = m.operate(testVector); + fail("Expecting illegalArgumentException"); + } catch (IllegalArgumentException 
ex) { + ; + } + } + + /** test transpose */ + public void testTranspose() { + RealMatrix m = new RealMatrixImpl(testData); + assertClose("inverse-transpose",m.inverse().transpose(), + m.transpose().inverse(),normTolerance); + m = new RealMatrixImpl(testData2); + RealMatrix mt = new RealMatrixImpl(testData2T); + assertClose("transpose",mt,m.transpose(),normTolerance); + } + + /** test preMultiply */ + public void testPremultiply() { + RealMatrix m = new RealMatrixImpl(testData); + RealMatrix mp = new RealMatrixImpl(preMultTest); + assertClose("premultiply",m.preMultiply(testVector),mp,normTolerance); + m = new RealMatrixImpl(bigSingular); + try { + RealMatrix x = m.preMultiply(testVector); + fail("expecting IllegalArgumentException"); + } catch (IllegalArgumentException ex) { + ; + } + } + + public void testGetVectors() { // getRow/getColumn indices are 1-based + RealMatrix m = new RealMatrixImpl(testData); + assertClose("get row",m.getRow(1),testDataRow1,entryTolerance); + assertClose("get col",m.getColumn(3),testDataCol3,entryTolerance); + try { + double[] x = m.getRow(10); + fail("expecting IllegalArgumentException"); + } catch (IllegalArgumentException ex) { + ; + } + try { + double[] x = m.getColumn(-1); + fail("expecting IllegalArgumentException"); + } catch (IllegalArgumentException ex) { + ; + } + } + + public void testEntryMutators() { // getEntry/setEntry indices are 1-based + RealMatrix m = new RealMatrixImpl(testData); + assertEquals("get entry",m.getEntry(1,2),2d,entryTolerance); // NOTE(review): expected/actual order swapped; only affects the failure message + m.setEntry(1,2,100d); + assertEquals("get entry",m.getEntry(1,2),100d,entryTolerance); + try { + double x = m.getEntry(0,2); + fail("expecting IllegalArgumentException"); + } catch (IllegalArgumentException ex) { + ; + } + try { + m.setEntry(1,4,200d); + fail("expecting IllegalArgumentException"); + } catch (IllegalArgumentException ex) { + ; + } + } + + + //-------------------------------- Private 
methods + + /** verifies that two matrices are close (1-norm) */ private void assertClose(String msg, RealMatrix m, RealMatrix n, double tolerance) { 
assertTrue(msg,m.subtract(n).getNorm() < tolerance); + } + + /** verifies that two vectors are close (sup norm) */ + private void assertClose(String msg, double[] m, double[] n, + double tolerance) { + if (m.length != n.length) { + fail("vectors not same length"); + } + for (int i = 0; i < m.length; i++) { + assertEquals(msg + " " + i + " elements differ", + m[i],n[i],tolerance); + } + } + + /** Prints m to System.out, one row per line; useful for debugging */ + private void dumpMatrix(RealMatrix m) { + for (int i = 0; i < m.getRowDimension(); i++) { + String os = ""; + for (int j = 0; j < m.getColumnDimension(); j++) { + os += m.getEntry(i+1, j+1) + " "; + } + System.out.println(os); + } } }
--------------------------------------------------------------------- To unsubscribe, e-mail: [EMAIL PROTECTED] For additional commands, e-mail: [EMAIL PROTECTED]