IGNITE-5925: Implemented get row/col for matrices.
Project: http://git-wip-us.apache.org/repos/asf/ignite/repo Commit: http://git-wip-us.apache.org/repos/asf/ignite/commit/45708b97 Tree: http://git-wip-us.apache.org/repos/asf/ignite/tree/45708b97 Diff: http://git-wip-us.apache.org/repos/asf/ignite/diff/45708b97 Branch: refs/heads/ignite-5578 Commit: 45708b9725b2d94d54380abf635508358a288a66 Parents: f00449f Author: Yury Babak <[email protected]> Authored: Fri Aug 18 16:12:05 2017 +0300 Committer: Igor Sapego <[email protected]> Committed: Fri Aug 18 16:12:05 2017 +0300 ---------------------------------------------------------------------- .../java/org/apache/ignite/ml/math/Blas.java | 4 ++-- .../java/org/apache/ignite/ml/math/Matrix.java | 16 +++++++++++++ .../ml/math/impls/matrix/AbstractMatrix.java | 25 ++++++++++++++++++++ .../impls/matrix/MatrixImplementationsTest.java | 17 +++++++++++++ 4 files changed, 60 insertions(+), 2 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/ignite/blob/45708b97/modules/ml/src/main/java/org/apache/ignite/ml/math/Blas.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/math/Blas.java b/modules/ml/src/main/java/org/apache/ignite/ml/math/Blas.java index a61d796..4b83ede 100644 --- a/modules/ml/src/main/java/org/apache/ignite/ml/math/Blas.java +++ b/modules/ml/src/main/java/org/apache/ignite/ml/math/Blas.java @@ -298,7 +298,7 @@ public class Blas { throw new CardinalityException(a.columnSize(), y.size()); checkMatrixType(a, "gemv"); - checkVectorType(x,"gemv"); + checkVectorType(x, "gemv"); checkVectorType(y, "gemv"); if (alpha == 0.0 && beta == 1.0) @@ -322,7 +322,7 @@ public class Blas { /** * M := alpha * M. * @param m Matrix M. - * @param alpha Aplha. + * @param alpha Alpha. 
*/ private static void scal(Matrix m, double alpha) { if (alpha != 1.0) http://git-wip-us.apache.org/repos/asf/ignite/blob/45708b97/modules/ml/src/main/java/org/apache/ignite/ml/math/Matrix.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/math/Matrix.java b/modules/ml/src/main/java/org/apache/ignite/ml/math/Matrix.java index 66de1a1..8c171a6 100644 --- a/modules/ml/src/main/java/org/apache/ignite/ml/math/Matrix.java +++ b/modules/ml/src/main/java/org/apache/ignite/ml/math/Matrix.java @@ -414,6 +414,14 @@ public interface Matrix extends MetaAttributes, Externalizable, StorageOpsMetric public Matrix setRow(int row, double[] data); /** + * Gets a specific row from the matrix. + * + * @param row Row index. + * @return Row. + */ + public Vector getRow(int row); + + /** * Sets values for given column. * * @param col Column index. @@ -425,6 +433,14 @@ public interface Matrix extends MetaAttributes, Externalizable, StorageOpsMetric public Matrix setColumn(int col, double[] data); /** + * Gets a specific column from the matrix. + * + * @param col Column index. + * @return Column. + */ + public Vector getCol(int col); + + /** * Sets given value without checking for index bounds. This method is marginally faster * than its {@link #set(int, int, double)} sibling.
* http://git-wip-us.apache.org/repos/asf/ignite/blob/45708b97/modules/ml/src/main/java/org/apache/ignite/ml/math/impls/matrix/AbstractMatrix.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/math/impls/matrix/AbstractMatrix.java b/modules/ml/src/main/java/org/apache/ignite/ml/math/impls/matrix/AbstractMatrix.java index 2195a70..06fb34c 100644 --- a/modules/ml/src/main/java/org/apache/ignite/ml/math/impls/matrix/AbstractMatrix.java +++ b/modules/ml/src/main/java/org/apache/ignite/ml/math/impls/matrix/AbstractMatrix.java @@ -41,6 +41,7 @@ import org.apache.ignite.ml.math.functions.IgniteDoubleFunction; import org.apache.ignite.ml.math.functions.IgniteFunction; import org.apache.ignite.ml.math.functions.IgniteTriFunction; import org.apache.ignite.ml.math.functions.IntIntToDoubleFunction; +import org.apache.ignite.ml.math.impls.vector.DenseLocalOnHeapVector; import org.apache.ignite.ml.math.impls.vector.MatrixVectorView; /** @@ -709,6 +710,18 @@ public abstract class AbstractMatrix implements Matrix { } /** {@inheritDoc} */ + @Override public Vector getRow(int row) { + checkRowIndex(row); + + Vector res = new DenseLocalOnHeapVector(columnSize()); + + for (int i = 0; i < columnSize(); i++) + res.setX(i, getX(row, i)); + + return res; + } + + /** {@inheritDoc} */ @Override public Matrix setColumn(int col, double[] data) { checkColumnIndex(col); @@ -724,6 +737,18 @@ public abstract class AbstractMatrix implements Matrix { } /** {@inheritDoc} */ + @Override public Vector getCol(int col) { + checkColumnIndex(col); + + Vector res = new DenseLocalOnHeapVector(rowSize()); + + for (int i = 0; i < rowSize(); i++) + res.setX(i, getX(i, col)); + + return res; + } + + /** {@inheritDoc} */ @Override public Matrix setX(int row, int col, double val) { storageSet(row, col, val);
http://git-wip-us.apache.org/repos/asf/ignite/blob/45708b97/modules/ml/src/test/java/org/apache/ignite/ml/math/impls/matrix/MatrixImplementationsTest.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/math/impls/matrix/MatrixImplementationsTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/math/impls/matrix/MatrixImplementationsTest.java index 89b6224..8270da1 100644 --- a/modules/ml/src/test/java/org/apache/ignite/ml/math/impls/matrix/MatrixImplementationsTest.java +++ b/modules/ml/src/test/java/org/apache/ignite/ml/math/impls/matrix/MatrixImplementationsTest.java @@ -927,6 +927,23 @@ public class MatrixImplementationsTest extends ExternalizeTest<Matrix> { } /** */ + @Test + public void testGetRowCol() { + consumeSampleMatrix((m, desc) -> { + if (!(m instanceof RandomMatrix)) + for (int i = 0; i < m.rowSize(); i++) + for (int j = 0; j < m.columnSize(); j++) + m.setX(i, j, i + j); + + for (int i = 0; i < m.rowSize(); i++) + assertNotNull("Unexpected value for " + desc + " at row " + i, m.getRow(i)); + + for (int i = 0; i < m.columnSize(); i++) + assertNotNull("Unexpected value for " + desc + " at col " + i, m.getCol(i)); + }); + } + + /** */ private double[] fillArray(int len) { double[] newValues = new double[len];
