Repository: ignite Updated Branches: refs/heads/master 5d8e31806 -> 0abf6601f
http://git-wip-us.apache.org/repos/asf/ignite/blob/0abf6601/modules/ml/src/test/java/org/apache/ignite/ml/math/impls/vector/VectorIterableTest.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/math/impls/vector/VectorIterableTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/math/impls/vector/VectorIterableTest.java new file mode 100644 index 0000000..16c2571 --- /dev/null +++ b/modules/ml/src/test/java/org/apache/ignite/ml/math/impls/vector/VectorIterableTest.java @@ -0,0 +1,376 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.ignite.ml.math.impls.vector; + +import java.util.Arrays; +import java.util.Iterator; +import java.util.NoSuchElementException; +import java.util.Spliterator; +import java.util.function.BiConsumer; +import org.apache.ignite.ml.math.Vector; +import org.apache.ignite.ml.math.impls.MathTestConstants; +import org.junit.Test; + +import static java.util.Spliterator.ORDERED; +import static java.util.Spliterator.SIZED; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotNull; +import static org.junit.Assert.assertNull; +import static org.junit.Assert.assertTrue; +import static org.junit.Assert.fail; + +/** */ +public class VectorIterableTest { + /** */ + @Test + public void allTest() { + consumeSampleVectors( + (v, desc) -> { + int expIdx = 0; + + for (Vector.Element e : v.all()) { + int actualIdx = e.index(); + + assertEquals("Unexpected index for " + desc, + expIdx, actualIdx); + + expIdx++; + } + + assertEquals("Unexpected amount of elements for " + desc, + expIdx, v.size()); + } + ); + } + + /** */ + @Test + public void allTestBound() { + consumeSampleVectors( + (v, desc) -> iteratorTestBound(v.all().iterator(), desc) + ); + } + + /** */ + @Test + public void nonZeroesTestBasic() { + final int size = 5; + + final double[] nonZeroesOddData = new double[size], nonZeroesEvenData = new double[size]; + + for (int idx = 0; idx < size; idx++) { + final boolean odd = (idx & 1) == 1; + + nonZeroesOddData[idx] = odd ? 1 : 0; + + nonZeroesEvenData[idx] = odd ? 0 : 1; + } + + assertTrue("Arrays failed to initialize.", + !isZero(nonZeroesEvenData[0]) + && isZero(nonZeroesEvenData[1]) + && isZero(nonZeroesOddData[0]) + && !isZero(nonZeroesOddData[1])); + + final Vector nonZeroesEvenVec = new DenseLocalOnHeapVector(nonZeroesEvenData), + nonZeroesOddVec = new DenseLocalOnHeapVector(nonZeroesOddData); + + assertTrue("Vectors failed to initialize.", + !isZero(nonZeroesEvenVec.getElement(0).get()) + && isZero(nonZeroesEvenVec.getElement(1).get()) + && isZero(nonZeroesOddVec.getElement(0).get()) + && !isZero(nonZeroesOddVec.getElement(1).get())); + + assertTrue("Iterator(s) failed to start.", + nonZeroesEvenVec.nonZeroes().iterator().next() != null + && nonZeroesOddVec.nonZeroes().iterator().next() != null); + + int nonZeroesActual = 0; + + for (Vector.Element e : nonZeroesEvenVec.nonZeroes()) { + final int idx = e.index(); + + final boolean odd = (idx & 1) == 1; + + final double val = e.get(); + + assertTrue("Not an even index " + idx + ", for value " + val, !odd); + + assertTrue("Zero value " + val + " at even index " + idx, !isZero(val)); + + nonZeroesActual++; + } + + final int nonZeroesOddExp = (size + 1) / 2; + + assertEquals("Unexpected num of iterated odd non-zeroes.", nonZeroesOddExp, nonZeroesActual); + + assertEquals("Unexpected nonZeroElements of odd.", nonZeroesOddExp, nonZeroesEvenVec.nonZeroElements()); + + nonZeroesActual = 0; + + for (Vector.Element e : nonZeroesOddVec.nonZeroes()) { + final int idx = e.index(); + + final boolean odd = (idx & 1) == 1; + + final double val = e.get(); + + assertTrue("Not an odd index " + idx + ", for value " + val, odd); + + assertTrue("Zero value " + val + " at even index " + idx, !isZero(val)); + + nonZeroesActual++; + } + + final int nonZeroesEvenExp = size / 2; + + assertEquals("Unexpected num of iterated even non-zeroes", nonZeroesEvenExp, nonZeroesActual); + + assertEquals("Unexpected nonZeroElements of even", nonZeroesEvenExp, nonZeroesOddVec.nonZeroElements()); + } + + /** */ + @Test + public void nonZeroesTest() { + // todo make RandomVector constructor that accepts a function and use it here + // in order to *reliably* test non-zeroes in there + consumeSampleVectors( + (v, desc) -> consumeSampleVectorsWithZeroes(v, (vec, numZeroes) + -> { + int numZeroesActual = vec.size(); + + for (Vector.Element e : vec.nonZeroes()) { + numZeroesActual--; + + assertTrue("Unexpected zero at " + desc + ", index " + e.index(), !isZero(e.get())); + } + + assertEquals("Unexpected num zeroes at " + desc, (int)numZeroes, numZeroesActual); + })); + } + + /** */ + @Test + public void nonZeroesTestBound() { + consumeSampleVectors( + (v, desc) -> consumeSampleVectorsWithZeroes(v, (vec, numZeroes) + -> iteratorTestBound(vec.nonZeroes().iterator(), desc))); + } + + /** */ + @Test + public void nonZeroElementsTest() { + consumeSampleVectors( + (v, desc) -> consumeSampleVectorsWithZeroes(v, (vec, numZeroes) + -> assertEquals("Unexpected num zeroes at " + desc, + (int)numZeroes, vec.size() - vec.nonZeroElements()))); + } + + /** */ + @Test + public void allSpliteratorTest() { + consumeSampleVectors( + (v, desc) -> { + final String desc1 = " " + desc; + + Spliterator<Double> spliterator = v.allSpliterator(); + + assertNotNull(MathTestConstants.NULL_VAL + desc1, spliterator); + + assertNull(MathTestConstants.NOT_NULL_VAL + desc1, spliterator.trySplit()); + + assertTrue(MathTestConstants.UNEXPECTED_VAL + desc1, spliterator.hasCharacteristics(ORDERED | SIZED)); + + if (!readOnly(v)) + fillWithNonZeroes(v); + + spliterator = v.allSpliterator(); + + assertNotNull(MathTestConstants.NULL_VAL + desc1, spliterator); + + assertEquals(MathTestConstants.VAL_NOT_EQUALS + desc1, spliterator.estimateSize(), v.size()); + + assertEquals(MathTestConstants.VAL_NOT_EQUALS + desc1, spliterator.getExactSizeIfKnown(), v.size()); + + assertTrue(MathTestConstants.UNEXPECTED_VAL + desc1, spliterator.hasCharacteristics(ORDERED | SIZED)); + + Spliterator<Double> secondHalf = spliterator.trySplit(); + + assertNull(MathTestConstants.NOT_NULL_VAL + desc1, secondHalf); + + spliterator.tryAdvance(x -> { + }); + } + ); + } + + /** */ + @Test + public void nonZeroSpliteratorTest() { + consumeSampleVectors( + (v, desc) -> consumeSampleVectorsWithZeroes(v, (vec, numZeroes) + -> { + final String desc1 = " Num zeroes " + numZeroes + " " + desc; + + Spliterator<Double> spliterator = vec.nonZeroSpliterator(); + + assertNotNull(MathTestConstants.NULL_VAL + desc1, spliterator); + + assertNull(MathTestConstants.NOT_NULL_VAL + desc1, spliterator.trySplit()); + + assertTrue(MathTestConstants.UNEXPECTED_VAL + desc1, spliterator.hasCharacteristics(ORDERED | SIZED)); + + spliterator = vec.nonZeroSpliterator(); + + assertNotNull(MathTestConstants.NULL_VAL + desc1, spliterator); + + assertEquals(MathTestConstants.VAL_NOT_EQUALS + desc1, spliterator.estimateSize(), vec.size() - numZeroes); + + assertEquals(MathTestConstants.VAL_NOT_EQUALS + desc1, spliterator.getExactSizeIfKnown(), vec.size() - numZeroes); + + assertTrue(MathTestConstants.UNEXPECTED_VAL + desc1, spliterator.hasCharacteristics(ORDERED | SIZED)); + + Spliterator<Double> secondHalf = spliterator.trySplit(); + + assertNull(MathTestConstants.NOT_NULL_VAL + desc1, secondHalf); + + double[] data = new double[vec.size()]; + + for (Vector.Element e : vec.all()) + data[e.index()] = e.get(); + + spliterator = vec.nonZeroSpliterator(); + + assertNotNull(MathTestConstants.NULL_VAL + desc1, spliterator); + + assertEquals(MathTestConstants.VAL_NOT_EQUALS + desc1, spliterator.estimateSize(), + Arrays.stream(data).filter(x -> x != 0d).count()); + + assertEquals(MathTestConstants.VAL_NOT_EQUALS + desc1, spliterator.getExactSizeIfKnown(), + Arrays.stream(data).filter(x -> x != 0d).count()); + + assertTrue(MathTestConstants.UNEXPECTED_VAL + desc1, spliterator.hasCharacteristics(ORDERED | SIZED)); + + secondHalf = spliterator.trySplit(); + + assertNull(MathTestConstants.NOT_NULL_VAL + desc1, secondHalf); + + if (!spliterator.tryAdvance(x -> { + })) + fail(MathTestConstants.NO_NEXT_ELEMENT + desc1); + })); + } + + /** */ + private void iteratorTestBound(Iterator<Vector.Element> it, String desc) { + while (it.hasNext()) + assertNotNull(it.next()); + + boolean expECaught = false; + + try { + it.next(); + } + catch (NoSuchElementException e) { + expECaught = true; + } + + assertTrue("Expected exception missed for " + desc, + expECaught); + } + + /** */ + private void consumeSampleVectorsWithZeroes(Vector sample, + BiConsumer<Vector, Integer> consumer) { + if (readOnly(sample)) { + int numZeroes = 0; + + for (Vector.Element e : sample.all()) + if (isZero(e.get())) + numZeroes++; + + consumer.accept(sample, numZeroes); + + return; + } + + fillWithNonZeroes(sample); + + consumer.accept(sample, 0); + + final int sampleSize = sample.size(); + + if (sampleSize == 0) + return; + + for (Vector.Element e : sample.all()) + e.set(0); + + consumer.accept(sample, sampleSize); + + fillWithNonZeroes(sample); + + for (int testIdx : new int[] {0, sampleSize / 2, sampleSize - 1}) { + final Vector.Element e = sample.getElement(testIdx); + + final double backup = e.get(); + + e.set(0); + + consumer.accept(sample, 1); + + e.set(backup); + } + + if (sampleSize < 3) + return; + + sample.getElement(sampleSize / 3).set(0); + + sample.getElement((2 * sampleSize) / 3).set(0); + + consumer.accept(sample, 2); + } + + /** */ + private void fillWithNonZeroes(Vector sample) { + int idx = 0; + + for (Vector.Element e : sample.all()) + e.set(1 + idx++); + + assertEquals("Not all filled with non-zeroes", idx, sample.size()); + } + + /** */ + private void consumeSampleVectors(BiConsumer<Vector, String> consumer) { + new VectorImplementationsFixtures().consumeSampleVectors(null, consumer); + } + + /** */ + private boolean isZero(double val) { + return val == 0.0; + } + + /** */ + private boolean readOnly(Vector v) { + return v instanceof RandomVector || v instanceof ConstantVector; + } +} + http://git-wip-us.apache.org/repos/asf/ignite/blob/0abf6601/modules/ml/src/test/java/org/apache/ignite/ml/math/impls/vector/VectorNormTest.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/math/impls/vector/VectorNormTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/math/impls/vector/VectorNormTest.java new file mode 100644 index 0000000..4e4f212 --- /dev/null +++ b/modules/ml/src/test/java/org/apache/ignite/ml/math/impls/vector/VectorNormTest.java @@ -0,0 +1,247 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.ignite.ml.math.impls.vector; + +import java.util.function.BiConsumer; +import java.util.function.BiFunction; +import java.util.function.Function; +import org.apache.ignite.ml.math.Vector; +import org.junit.Test; + +import static org.junit.Assert.assertTrue; + +/** */ +public class VectorNormTest { + /** */ + @Test + public void normalizeTest() { + normalizeTest(2, (val, len) -> val / len, Vector::normalize); + } + + /** */ + @Test + public void normalizePowerTest() { + for (double pow : new double[] {0, 0.5, 1, 2, 2.5, Double.POSITIVE_INFINITY}) + normalizeTest(pow, (val, norm) -> val / norm, (v) -> v.normalize(pow)); + } + + /** */ + @Test + public void logNormalizeTest() { + normalizeTest(2, (val, len) -> Math.log1p(val) / (len * Math.log(2)), Vector::logNormalize); + } + + /** */ + @Test + public void logNormalizePowerTest() { + for (double pow : new double[] {1.1, 2, 2.5}) + normalizeTest(pow, (val, norm) -> Math.log1p(val) / (norm * Math.log(pow)), (v) -> v.logNormalize(pow)); + } + + /** */ + @Test + public void kNormTest() { + for (double pow : new double[] {0, 0.5, 1, 2, 2.5, Double.POSITIVE_INFINITY}) + toDoubleTest(pow, ref -> new Norm(ref, pow).calculate(), v -> v.kNorm(pow)); + } + + /** */ + @Test + public void getLengthSquaredTest() { + toDoubleTest(2.0, ref -> new Norm(ref, 2).sumPowers(), Vector::getLengthSquared); + } + + /** */ + @Test + public void getDistanceSquaredTest() { + consumeSampleVectors((v, desc) -> { + new VectorImplementationsTest.ElementsChecker(v, desc); // IMPL NOTE this initialises vector + + final int size = v.size(); + final Vector vOnHeap = new DenseLocalOnHeapVector(size); + final Vector vOffHeap = new DenseLocalOffHeapVector(size); + + invertValues(v, vOnHeap); + invertValues(v, vOffHeap); + + for (int idx = 0; idx < size; idx++) { + final double exp = v.get(idx); + final int idxMirror = size - 1 - idx; + + assertTrue("On heap vector difference at " + desc + ", idx " + idx, + exp - vOnHeap.get(idxMirror) == 0); + assertTrue("Off heap vector difference at " + desc + ", idx " + idx, + exp - vOffHeap.get(idxMirror) == 0); + } + + final double exp = vOnHeap.minus(v).getLengthSquared(); // IMPL NOTE this won't mutate vOnHeap + final VectorImplementationsTest.Metric metric = new VectorImplementationsTest.Metric(exp, v.getDistanceSquared(vOnHeap)); + + assertTrue("On heap vector not close enough at " + desc + ", " + metric, + metric.closeEnough()); + + final VectorImplementationsTest.Metric metric1 = new VectorImplementationsTest.Metric(exp, v.getDistanceSquared(vOffHeap)); + + assertTrue("Off heap vector not close enough at " + desc + ", " + metric1, + metric1.closeEnough()); + }); + } + + /** */ + @Test + public void dotTest() { + consumeSampleVectors((v, desc) -> { + new VectorImplementationsTest.ElementsChecker(v, desc); // IMPL NOTE this initialises vector + + final int size = v.size(); + final Vector v1 = new DenseLocalOnHeapVector(size); + + invertValues(v, v1); + + final double actual = v.dot(v1); + + double exp = 0; + + for (Vector.Element e : v.all()) + exp += e.get() * v1.get(e.index()); + + final VectorImplementationsTest.Metric metric = new VectorImplementationsTest.Metric(exp, actual); + + assertTrue("Dot product not close enough at " + desc + ", " + metric, + metric.closeEnough()); + }); + } + + /** */ + private void invertValues(Vector src, Vector dst) { + final int size = src.size(); + + for (Vector.Element e : src.all()) { + final int idx = size - 1 - e.index(); + final double val = e.get(); + + dst.set(idx, val); + } + } + + /** */ + private void toDoubleTest(Double val, Function<double[], Double> calcRef, Function<Vector, Double> calcVec) { + consumeSampleVectors((v, desc) -> { + final int size = v.size(); + final double[] ref = new double[size]; + + new VectorImplementationsTest.ElementsChecker(v, ref, desc); // IMPL NOTE this initialises vector and reference array + + final double exp = calcRef.apply(ref); + final double obtained = calcVec.apply(v); + final VectorImplementationsTest.Metric metric = new VectorImplementationsTest.Metric(exp, obtained); + + assertTrue("Not close enough at " + desc + + (val == null ? "" : ", value " + val) + ", " + metric, metric.closeEnough()); + }); + } + + /** */ + private void normalizeTest(double pow, BiFunction<Double, Double, Double> operation, + Function<Vector, Vector> vecOperation) { + consumeSampleVectors((v, desc) -> { + final int size = v.size(); + final double[] ref = new double[size]; + final boolean nonNegative = pow != (int)pow; + + final VectorImplementationsTest.ElementsChecker checker = new VectorImplementationsTest.ElementsChecker(v, ref, desc + ", pow = " + pow, nonNegative); + final double norm = new Norm(ref, pow).calculate(); + + for (int idx = 0; idx < size; idx++) + ref[idx] = operation.apply(ref[idx], norm); + + checker.assertCloseEnough(vecOperation.apply(v), ref); + }); + } + + /** */ + private void consumeSampleVectors(BiConsumer<Vector, String> consumer) { + new VectorImplementationsFixtures().consumeSampleVectors(null, consumer); + } + + /** */ + private static class Norm { + /** */ + private final double[] arr; + + /** */ + private final Double pow; + + /** */ + Norm(double[] arr, double pow) { + this.arr = arr; + this.pow = pow; + } + + /** */ + double calculate() { + if (pow.equals(0.0)) + return countNonZeroes(); // IMPL NOTE this is beautiful if you think of it + + if (pow.equals(Double.POSITIVE_INFINITY)) + return maxAbs(); + + return Math.pow(sumPowers(), 1 / pow); + } + + /** */ + double sumPowers() { + if (pow.equals(0.0)) + return countNonZeroes(); + + double norm = 0; + + for (double val : arr) + norm += pow == 1 ? Math.abs(val) : Math.pow(val, pow); + + return norm; + } + + /** */ + private int countNonZeroes() { + int cnt = 0; + + final Double zero = 0.0; + + for (double val : arr) + if (!zero.equals(val)) + cnt++; + + return cnt; + } + + /** */ + private double maxAbs() { + double res = 0; + + for (double val : arr) { + final double abs = Math.abs(val); + + if (abs > res) + res = abs; + } + + return res; + } + } +} http://git-wip-us.apache.org/repos/asf/ignite/blob/0abf6601/modules/ml/src/test/java/org/apache/ignite/ml/math/impls/vector/VectorToMatrixTest.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/math/impls/vector/VectorToMatrixTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/math/impls/vector/VectorToMatrixTest.java new file mode 100644 index 0000000..4d5bc56 --- /dev/null +++ b/modules/ml/src/test/java/org/apache/ignite/ml/math/impls/vector/VectorToMatrixTest.java @@ -0,0 +1,308 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.ignite.ml.math.impls.vector; + +import java.util.Arrays; +import java.util.LinkedHashMap; +import java.util.List; +import java.util.Map; +import java.util.function.BiConsumer; +import org.apache.ignite.ml.math.Matrix; +import org.apache.ignite.ml.math.Vector; +import org.apache.ignite.ml.math.exceptions.UnsupportedOperationException; +import org.apache.ignite.ml.math.impls.matrix.DenseLocalOffHeapMatrix; +import org.apache.ignite.ml.math.impls.matrix.DenseLocalOnHeapMatrix; +import org.apache.ignite.ml.math.impls.matrix.RandomMatrix; +import org.apache.ignite.ml.math.impls.matrix.SparseLocalOnHeapMatrix; +import org.junit.Test; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotNull; +import static org.junit.Assert.assertNull; +import static org.junit.Assert.assertTrue; + +/** Tests for methods of Vector that involve Matrix. */ +public class VectorToMatrixTest { + /** */ + private static final Map<Class<? extends Vector>, Class<? extends Matrix>> typesMap = typesMap(); + + /** */ + private static final List<Class<? extends Vector>> likeMatrixUnsupported = Arrays.asList(FunctionVector.class, + SingleElementVector.class, SingleElementVectorView.class, ConstantVector.class); + + /** */ + @Test + public void testHaveLikeMatrix() throws InstantiationException, IllegalAccessException { + for (Class<? extends Vector> key : typesMap.keySet()) { + Class<? extends Matrix> val = typesMap.get(key); + + if (val == null && likeMatrixSupported(key)) + System.out.println("Missing test for implementation of likeMatrix for " + key.getSimpleName()); + } + } + + /** */ + @Test + public void testLikeMatrixUnsupported() throws Exception { + consumeSampleVectors((v, desc) -> { + if (likeMatrixSupported(v.getClass())) + return; + + boolean expECaught = false; + + try { + assertNull("Null view instead of exception in " + desc, v.likeMatrix(1, 1)); + } + catch (UnsupportedOperationException uoe) { + expECaught = true; + } + + assertTrue("Expected exception was not caught in " + desc, expECaught); + }); + } + + /** */ + @Test + public void testLikeMatrix() { + consumeSampleVectors((v, desc) -> { + if (!availableForTesting(v)) + return; + + final Matrix matrix = v.likeMatrix(1, 1); + + Class<? extends Vector> key = v.getClass(); + + Class<? extends Matrix> expMatrixType = typesMap.get(key); + + assertNotNull("Expect non-null matrix for " + key.getSimpleName() + " in " + desc, matrix); + + Class<? extends Matrix> actualMatrixType = matrix.getClass(); + + assertTrue("Expected matrix type " + expMatrixType.getSimpleName() + + " should be assignable from actual type " + actualMatrixType.getSimpleName() + " in " + desc, + expMatrixType.isAssignableFrom(actualMatrixType)); + + for (int rows : new int[] {1, 2}) + for (int cols : new int[] {1, 2}) { + final Matrix actualMatrix = v.likeMatrix(rows, cols); + + String details = "rows " + rows + " cols " + cols; + + assertNotNull("Expect non-null matrix for " + details + " in " + desc, + actualMatrix); + + assertEquals("Unexpected number of rows in " + desc, rows, actualMatrix.rowSize()); + + assertEquals("Unexpected number of cols in " + desc, cols, actualMatrix.columnSize()); + } + }); + } + + /** */ + @Test + public void testToMatrix() { + consumeSampleVectors((v, desc) -> { + if (!availableForTesting(v)) + return; + + fillWithNonZeroes(v); + + final Matrix matrixRow = v.toMatrix(true); + + final Matrix matrixCol = v.toMatrix(false); + + for (Vector.Element e : v.all()) + assertToMatrixValue(desc, matrixRow, matrixCol, e.get(), e.index()); + }); + } + + /** */ + @Test + public void testToMatrixPlusOne() { + consumeSampleVectors((v, desc) -> { + if (!availableForTesting(v)) + return; + + fillWithNonZeroes(v); + + for (double zeroVal : new double[] {-1, 0, 1, 2}) { + final Matrix matrixRow = v.toMatrixPlusOne(true, zeroVal); + + final Matrix matrixCol = v.toMatrixPlusOne(false, zeroVal); + + final Metric metricRow0 = new Metric(zeroVal, matrixRow.get(0, 0)); + + assertTrue("Not close enough row like " + metricRow0 + " at index 0 in " + desc, + metricRow0.closeEnough()); + + final Metric metricCol0 = new Metric(zeroVal, matrixCol.get(0, 0)); + + assertTrue("Not close enough cols like " + metricCol0 + " at index 0 in " + desc, + metricCol0.closeEnough()); + + for (Vector.Element e : v.all()) + assertToMatrixValue(desc, matrixRow, matrixCol, e.get(), e.index() + 1); + } + }); + } + + /** */ + @Test + public void testCross() { + consumeSampleVectors((v, desc) -> { + if (!availableForTesting(v)) + return; + + fillWithNonZeroes(v); + + for (int delta : new int[] {-1, 0, 1}) { + final int size2 = v.size() + delta; + + if (size2 < 1) + return; + + final Vector v2 = new DenseLocalOnHeapVector(size2); + + for (Vector.Element e : v2.all()) + e.set(size2 - e.index()); + + assertCross(v, v2, desc); + } + }); + } + + /** */ + private void assertCross(Vector v1, Vector v2, String desc) { + assertNotNull(v1); + assertNotNull(v2); + + final Matrix res = v1.cross(v2); + + assertNotNull("Cross matrix is expected to be not null in " + desc, res); + + assertEquals("Unexpected number of rows in cross Matrix in " + desc, v1.size(), res.rowSize()); + + assertEquals("Unexpected number of cols in cross Matrix in " + desc, v2.size(), res.columnSize()); + + for (int row = 0; row < v1.size(); row++) + for (int col = 0; col < v2.size(); col++) { + final Metric metric = new Metric(v1.get(row) * v2.get(col), res.get(row, col)); + + assertTrue("Not close enough cross " + metric + " at row " + row + " at col " + col + + " in " + desc, metric.closeEnough()); + } + } + + /** */ + private void assertToMatrixValue(String desc, Matrix matrixRow, Matrix matrixCol, double exp, int idx) { + final Metric metricRow = new Metric(exp, matrixRow.get(0, idx)); + + assertTrue("Not close enough row like " + metricRow + " at index " + idx + " in " + desc, + metricRow.closeEnough()); + + final Metric metricCol = new Metric(exp, matrixCol.get(idx, 0)); + + assertTrue("Not close enough cols like " + matrixCol + " at index " + idx + " in " + desc, + metricCol.closeEnough()); + } + + /** */ + private void fillWithNonZeroes(Vector sample) { + if (sample instanceof RandomVector) + return; + + for (Vector.Element e : sample.all()) + e.set(1 + e.index()); + } + + /** */ + private boolean availableForTesting(Vector v) { + assertNotNull("Error in test: vector is null", v); + + if (!likeMatrixSupported(v.getClass())) + return false; + + final boolean availableForTesting = typesMap.get(v.getClass()) != null; + + final Matrix actualLikeMatrix = v.likeMatrix(1, 1); + + assertTrue("Need to enable matrix testing for vector type " + v.getClass().getSimpleName(), + availableForTesting || actualLikeMatrix == null); + + return availableForTesting; + } + + /** Ignore test for given vector type. */ + private boolean likeMatrixSupported(Class<? extends Vector> clazz) { + for (Class<? extends Vector> ignoredClass : likeMatrixUnsupported) + if (ignoredClass.isAssignableFrom(clazz)) + return false; + + return true; + } + + /** */ + private void consumeSampleVectors(BiConsumer<Vector, String> consumer) { + new VectorImplementationsFixtures().consumeSampleVectors(null, consumer); + } + + /** */ + private static Map<Class<? extends Vector>, Class<? extends Matrix>> typesMap() { + return new LinkedHashMap<Class<? extends Vector>, Class<? extends Matrix>>() {{ + put(DenseLocalOnHeapVector.class, DenseLocalOnHeapMatrix.class); + put(DenseLocalOffHeapVector.class, DenseLocalOffHeapMatrix.class); + put(RandomVector.class, RandomMatrix.class); + put(SparseLocalVector.class, SparseLocalOnHeapMatrix.class); + put(SingleElementVector.class, null); // todo find out if we need SingleElementMatrix to match, or skip it + put(ConstantVector.class, null); + put(FunctionVector.class, null); + put(PivotedVectorView.class, DenseLocalOnHeapMatrix.class); // IMPL NOTE per fixture + put(SingleElementVectorView.class, null); + put(MatrixVectorView.class, DenseLocalOnHeapMatrix.class); // IMPL NOTE per fixture + put(DelegatingVector.class, DenseLocalOnHeapMatrix.class); // IMPL NOTE per fixture + // IMPL NOTE check for presence of all implementations here will be done in testHaveLikeMatrix via Fixture + }}; + } + + /** */ + private static class Metric { // todo consider if softer tolerance (like say 0.1 or 0.01) would make sense here + /** */ + private final double exp; + + /** */ + private final double obtained; + + /** **/ + Metric(double exp, double obtained) { + this.exp = exp; + this.obtained = obtained; + } + + /** */ + boolean closeEnough() { + return new Double(exp).equals(obtained); + } + + /** {@inheritDoc} */ + @Override public String toString() { + return "Metric{" + "expected=" + exp + + ", obtained=" + obtained + + '}'; + } + } +} http://git-wip-us.apache.org/repos/asf/ignite/blob/0abf6601/modules/ml/src/test/java/org/apache/ignite/ml/math/impls/vector/VectorViewTest.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/math/impls/vector/VectorViewTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/math/impls/vector/VectorViewTest.java new file mode 100644 index 0000000..ad2bc3f --- /dev/null +++ b/modules/ml/src/test/java/org/apache/ignite/ml/math/impls/vector/VectorViewTest.java @@ -0,0 +1,162 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.ignite.ml.math.impls.vector; + +import java.io.File; +import java.io.FileInputStream; +import java.io.FileOutputStream; +import java.io.IOException; +import java.io.ObjectInputStream; +import java.io.ObjectOutputStream; +import java.nio.file.Files; +import java.nio.file.Paths; +import java.util.function.BiConsumer; +import java.util.stream.IntStream; +import org.apache.ignite.ml.math.Vector; +import org.apache.ignite.ml.math.exceptions.UnsupportedOperationException; +import org.apache.ignite.ml.math.impls.MathTestConstants; +import org.junit.AfterClass; +import org.junit.Before; +import org.junit.Test; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotNull; +import static org.junit.Assert.assertNull; +import static org.junit.Assert.assertTrue; +import static org.junit.Assert.fail; + +/** + * Unit tests for {@link VectorView}. + */ +public class VectorViewTest { + /** */ + private static final int OFFSET = 10; + + /** */ + private static final int VIEW_LENGTH = 80; + + /** */ + private static final String EXTERNALIZE_TEST_FILE_NAME = "externalizeTest"; + + /** */ + private VectorView testVector; + + /** */ + private DenseLocalOnHeapVector parentVector; + + /** */ + private double[] parentData; + + /** */ + @Before + public void setup() { + parentVector = new DenseLocalOnHeapVector(MathTestConstants.STORAGE_SIZE); + + IntStream.range(0, MathTestConstants.STORAGE_SIZE).forEach(idx -> parentVector.set(idx, Math.random())); + + parentData = parentVector.getStorage().data().clone(); + + testVector = new VectorView(parentVector, OFFSET, VIEW_LENGTH); + } + + /** */ + @AfterClass + public static void cleanup() throws IOException { + Files.deleteIfExists(Paths.get(EXTERNALIZE_TEST_FILE_NAME)); + } + + /** */ + @Test + public void testCopy() throws Exception { + Vector cp = testVector.copy(); + + assertTrue(MathTestConstants.VAL_NOT_EQUALS, cp.equals(testVector)); + } + + /** */ + @Test(expected = UnsupportedOperationException.class) + public void testLike() throws Exception { + for (int card : new int[] {1, 2, 4, 8, 16, 32, 64, 128}) + consumeSampleVectors((v, desc) -> { + Vector vLike = new VectorView(v, 0, 1).like(card); + + Class<? extends Vector> expType = v.getClass(); + + assertNotNull("Expect non-null like vector for " + expType.getSimpleName() + " in " + desc, vLike); + + assertEquals("Expect size equal to cardinality at " + desc, card, vLike.size()); + + Class<? extends Vector> actualType = vLike.getClass(); + + assertTrue("Expected matrix type " + expType.getSimpleName() + + " should be assignable from actual type " + actualType.getSimpleName() + " in " + desc, + expType.isAssignableFrom(actualType)); + + }); + } + + /** See also {@link VectorToMatrixTest#testLikeMatrix()}. */ + @Test + public void testLikeMatrix() { + consumeSampleVectors((v, desc) -> { + boolean expECaught = false; + + try { + assertNull("Null view instead of exception in " + desc, new VectorView(v, 0, 1).likeMatrix(1, 1)); + } + catch (UnsupportedOperationException uoe) { + expECaught = true; + } + + assertTrue("Expected exception was not caught in " + desc, expECaught); + }); + } + + /** */ + @Test + public void testWriteReadExternal() throws Exception { + assertNotNull("Unexpected null parent data", parentData); + + File f = new File(EXTERNALIZE_TEST_FILE_NAME); + + try { + ObjectOutputStream objOutputStream = new ObjectOutputStream(new FileOutputStream(f)); + + objOutputStream.writeObject(testVector); + + objOutputStream.close(); + + ObjectInputStream objInputStream = new ObjectInputStream(new FileInputStream(f)); + + VectorView readVector = (VectorView)objInputStream.readObject(); + + objInputStream.close(); + + assertTrue(MathTestConstants.VAL_NOT_EQUALS, testVector.equals(readVector)); + } + catch (ClassNotFoundException | IOException e) { + fail(e.getMessage()); + } + } + + /** */ + private void consumeSampleVectors(BiConsumer<Vector, String> consumer) { + new VectorImplementationsFixtures().consumeSampleVectors(null, consumer); + } + +}