This is an automated email from the ASF dual-hosted git repository. zaleslaw pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/ignite.git
The following commit(s) were added to refs/heads/master by this push: new e302fdb IGNITE-13386: add BrayCurtis,Canberra,JensenShannon,WeightedMinkowski distances (#8197) e302fdb is described below commit e302fdb7f7386e15e711c0f3de7bd1ad3ec84667 Author: Mark Andreev <mrk.andreev+git...@yandex.ru> AuthorDate: Fri Oct 9 12:08:09 2020 +0300 IGNITE-13386: add BrayCurtis,Canberra,JensenShannon,WeightedMinkowski distances (#8197) --- .../ml/math/distances/BrayCurtisDistance.java | 54 ++++++++++ .../ignite/ml/math/distances/CanberraDistance.java | 64 ++++++++++++ .../ml/math/distances/JensenShannonDistance.java | 91 +++++++++++++++++ .../math/distances/WeightedMinkowskiDistance.java | 73 ++++++++++++++ .../ml/math/distances/BrayCurtisDistanceTest.java | 103 +++++++++++++++++++ .../ml/math/distances/CanberraDistanceTest.java | 103 +++++++++++++++++++ .../ignite/ml/math/distances/DistanceTest.java | 55 ++++++++++ .../math/distances/JensenShannonDistanceTest.java | 105 +++++++++++++++++++ .../distances/WeightedMinkowskiDistanceTest.java | 112 +++++++++++++++++++++ 9 files changed, 760 insertions(+) diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/math/distances/BrayCurtisDistance.java b/modules/ml/src/main/java/org/apache/ignite/ml/math/distances/BrayCurtisDistance.java new file mode 100644 index 0000000..0b43159 --- /dev/null +++ b/modules/ml/src/main/java/org/apache/ignite/ml/math/distances/BrayCurtisDistance.java @@ -0,0 +1,54 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.ignite.ml.math.distances; + +import org.apache.ignite.ml.math.exceptions.math.CardinalityException; +import org.apache.ignite.ml.math.primitives.vector.Vector; +import org.apache.ignite.ml.math.util.MatrixUtil; + +/** + * Calculates the Bray Curtis distance between two points. + * + * @see <a href="https://en.wikipedia.org/wiki/Bray%E2%80%93Curtis_dissimilarity"> + * Bray–Curtis dissimilarity</a> + */ +public class BrayCurtisDistance implements DistanceMeasure { + /** Serializable version identifier. */ + private static final long serialVersionUID = 1771556549784040091L; + + /** {@inheritDoc} */ + @Override public double compute(Vector a, Vector b) + throws CardinalityException { + double diff = MatrixUtil.localCopyOf(a).minus(b).kNorm(1); + double sum = MatrixUtil.localCopyOf(a).plus(b).kNorm(1); + + return diff / sum; + } + + /** {@inheritDoc} */ + @Override public boolean equals(Object obj) { + if (this == obj) + return true; + + return obj != null && getClass() == obj.getClass(); + } + + /** {@inheritDoc} */ + @Override public int hashCode() { + return getClass().hashCode(); + } +} diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/math/distances/CanberraDistance.java b/modules/ml/src/main/java/org/apache/ignite/ml/math/distances/CanberraDistance.java new file mode 100644 index 0000000..66e349c --- /dev/null +++ b/modules/ml/src/main/java/org/apache/ignite/ml/math/distances/CanberraDistance.java @@ -0,0 +1,64 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license 
agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.ignite.ml.math.distances; + +import org.apache.ignite.ml.math.exceptions.math.CardinalityException; +import org.apache.ignite.ml.math.primitives.vector.Vector; +import org.apache.ignite.ml.math.util.MatrixUtil; + +/** + * Calculates the Canberra distance between two points. + * + * @see <a href="https://en.wikipedia.org/wiki/Canberra_distance">Canberra distance</a> + */ +public class CanberraDistance implements DistanceMeasure { + /** + * Serializable version identifier. + */ + private static final long serialVersionUID = 1771556549784040092L; + + /** + * {@inheritDoc} + */ + @Override public double compute(Vector a, Vector b) + throws CardinalityException { + Vector top = MatrixUtil.localCopyOf(a).minus(b).map(Math::abs); + Vector down = MatrixUtil.localCopyOf(a).map(Math::abs) + .plus(MatrixUtil.localCopyOf(b).map(Math::abs)) + .map(value -> value != 0 ? 
1 / value : 0); + + return top.times(down).sum(); + } + + /** + * {@inheritDoc} + */ + @Override public boolean equals(Object obj) { + if (this == obj) { + return true; + } + + return obj != null && getClass() == obj.getClass(); + } + + /** + * {@inheritDoc} + */ + @Override public int hashCode() { + return getClass().hashCode(); + } +} diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/math/distances/JensenShannonDistance.java b/modules/ml/src/main/java/org/apache/ignite/ml/math/distances/JensenShannonDistance.java new file mode 100644 index 0000000..ec48888 --- /dev/null +++ b/modules/ml/src/main/java/org/apache/ignite/ml/math/distances/JensenShannonDistance.java @@ -0,0 +1,91 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.ignite.ml.math.distances; + +import org.apache.ignite.ml.math.exceptions.math.CardinalityException; +import org.apache.ignite.ml.math.primitives.vector.Vector; +import org.apache.ignite.ml.math.util.MatrixUtil; + +/** + * Calculates the JensenShannonDistance distance between two points. 
+ * + * @see <a href="https://en.wikipedia.org/wiki/Jensen%E2%80%93Shannon_divergence"> + * Jensen–Shannon divergence</a> + */ +public class JensenShannonDistance implements DistanceMeasure { + /** + * Serializable version identifier. + */ + private static final long serialVersionUID = 1771556549784040093L; + + private final Double base; + + public JensenShannonDistance() { + base = Math.E; + } + + public JensenShannonDistance(Double base) { + this.base = java.util.Objects.requireNonNull(base, "base"); + } + + /** + * {@inheritDoc} + */ + @Override public double compute(Vector a, Vector b) + throws CardinalityException { + Vector aNormalized = MatrixUtil.localCopyOf(a).divide(a.sum()); + Vector bNormalized = MatrixUtil.localCopyOf(b).divide(b.sum()); + + Vector mean = aNormalized.plus(bNormalized).divide(2d); + + double js = aNormalized.map(mean, this::relativeEntropy).sum() + + bNormalized.map(mean, this::relativeEntropy).sum(); + + js /= Math.log(base); + + return Math.sqrt(js / 2d); + } + + private double relativeEntropy(double x, double y) { + if (x > 0 && y > 0) { + return x * Math.log(x / y); + } + if (x == 0 && y >= 0) { + return 0; + } + + return Double.POSITIVE_INFINITY; + } + + /** + * {@inheritDoc} Two instances are equal only if they use the same logarithm base, since the base rescales the result of compute(). + */ + @Override public boolean equals(Object obj) { + if (this == obj) + return true; + if (obj == null || getClass() != obj.getClass()) + return false; + return base.equals(((JensenShannonDistance)obj).base); + } + + /** + * {@inheritDoc} Depends on the logarithm base, consistently with equals(Object). + */ + @Override public int hashCode() { + return 31 * getClass().hashCode() + base.hashCode(); + } +} diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/math/distances/WeightedMinkowskiDistance.java b/modules/ml/src/main/java/org/apache/ignite/ml/math/distances/WeightedMinkowskiDistance.java new file mode 100644 index 0000000..662bf90 --- /dev/null +++ b/modules/ml/src/main/java/org/apache/ignite/ml/math/distances/WeightedMinkowskiDistance.java @@ -0,0 +1,73 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements.
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.ignite.ml.math.distances; + +import org.apache.ignite.ml.math.exceptions.math.CardinalityException; +import org.apache.ignite.ml.math.primitives.vector.Vector; +import org.apache.ignite.ml.math.util.MatrixUtil; + +/** + * Calculates the Weighted Minkowski distance between two points. + */ +public class WeightedMinkowskiDistance implements DistanceMeasure { + /** + * Serializable version identifier. 
+ */ + private static final long serialVersionUID = 1771556549784040096L; + + private final int p; + + private final Vector weight; + + public WeightedMinkowskiDistance(int p, Vector weight) { + this.p = p; + this.weight = weight.copy().map(x -> Math.pow(Math.abs(x), p)); + } + + /** + * {@inheritDoc} + */ + @Override public double compute(Vector a, Vector b) + throws CardinalityException { + + return Math.pow( + MatrixUtil.localCopyOf(a).minus(b) + .map(x -> Math.pow(Math.abs(x), p)) + .times(weight) + .sum(), + 1 / (double) p + ); + } + + /** + * {@inheritDoc} Two instances are equal only if they share the same {@code p} and the same |w|^p-transformed weights, since both define the metric. + */ + @Override public boolean equals(Object obj) { + if (obj == null || getClass() != obj.getClass()) + return false; + WeightedMinkowskiDistance that = (WeightedMinkowskiDistance)obj; + + return p == that.p && java.util.Arrays.equals(weight.asArray(), that.weight.asArray()); + } + + /** + * {@inheritDoc} Depends on {@code p} and the transformed weights, consistently with equals(Object). + */ + @Override public int hashCode() { + return 31 * (31 * getClass().hashCode() + p) + java.util.Arrays.hashCode(weight.asArray()); + } +} diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/math/distances/BrayCurtisDistanceTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/math/distances/BrayCurtisDistanceTest.java new file mode 100644 index 0000000..57ca7de --- /dev/null +++ b/modules/ml/src/test/java/org/apache/ignite/ml/math/distances/BrayCurtisDistanceTest.java @@ -0,0 +1,103 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License.
+ */ +package org.apache.ignite.ml.math.distances; + +import java.util.Arrays; +import java.util.Collection; +import org.apache.ignite.ml.math.primitives.vector.Vector; +import org.apache.ignite.ml.math.primitives.vector.impl.DenseVector; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.Parameterized; + +import static org.junit.Assert.assertEquals; + +/** + * Evaluate BrayCurtisDistance in multiple test datasets + */ +@RunWith(Parameterized.class) +public class BrayCurtisDistanceTest { + /** Precision. */ + private static final double PRECISION = 0.01; + + /** */ + @Parameterized.Parameters(name = "{0}") + public static Collection<TestData> data() { + return Arrays.asList( + new TestData( + new double[] {0, 0, 0}, + new double[] {2, 1, 0}, + 1.0 + ), + new TestData( + new double[] {1, 2, 3}, + new double[] {2, 1, 0}, + 0.55 + ), + new TestData( + new double[] {1, 2, 3}, + new double[] {2, 1, 50}, + 0.83 + ), + new TestData( + new double[] {1, -100, 3}, + new double[] {2, 1, -50}, + 1.04 + ) + ); + } + + private final TestData testData; + + /** */ + public BrayCurtisDistanceTest(TestData testData) { + this.testData = testData; + } + + /** */ + @Test + public void testBrayCurtisDistance() { + DistanceMeasure distanceMeasure = new BrayCurtisDistance(); + + assertEquals(testData.expRes, + distanceMeasure.compute(testData.vectorA, testData.vectorB), PRECISION); + assertEquals(testData.expRes, + distanceMeasure.compute(testData.vectorA, testData.vectorB), PRECISION); + } + + private static class TestData { + public final Vector vectorA; + + public final Vector vectorB; + + public final double expRes; + + private TestData(double[] vectorA, double[] vectorB, double expRes) { + this.vectorA = new DenseVector(vectorA); + this.vectorB = new DenseVector(vectorB); + this.expRes = expRes; + } + + @Override public String toString() { + return String.format("d(%s,%s) = %s", + Arrays.toString(vectorA.asArray()), + 
Arrays.toString(vectorB.asArray()), + expRes + ); + } + } +} diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/math/distances/CanberraDistanceTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/math/distances/CanberraDistanceTest.java new file mode 100644 index 0000000..c40251b --- /dev/null +++ b/modules/ml/src/test/java/org/apache/ignite/ml/math/distances/CanberraDistanceTest.java @@ -0,0 +1,103 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.ignite.ml.math.distances; + +import java.util.Arrays; +import java.util.Collection; +import org.apache.ignite.ml.math.primitives.vector.Vector; +import org.apache.ignite.ml.math.primitives.vector.impl.DenseVector; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.Parameterized; + +import static org.junit.Assert.assertEquals; + +/** + * Evaluate CanberraDistance in multiple test datasets + */ +@RunWith(Parameterized.class) +public class CanberraDistanceTest { + /** Precision. 
*/ + private static final double PRECISION = 0.01; + + /** */ + @Parameterized.Parameters(name = "{0}") + public static Collection<TestData> data() { + return Arrays.asList( + new TestData( + new double[] {0, 0, 0}, + new double[] {2, 1, 0}, + 2.0 + ), + new TestData( + new double[] {1, 2, 3}, + new double[] {2, 1, 0}, + 1.66 + ), + new TestData( + new double[] {1, 2, 3}, + new double[] {2, 1, 50}, + 1.55 + ), + new TestData( + new double[] {1, -100, 3}, + new double[] {2, 1, -50}, + 2.33 + ) + ); + } + + private final TestData testData; + + /** */ + public CanberraDistanceTest(TestData testData) { + this.testData = testData; + } + + /** */ + @Test + public void testCanberraDistance() { + DistanceMeasure distanceMeasure = new CanberraDistance(); + + assertEquals(testData.expRes, + distanceMeasure.compute(testData.vectorA, testData.vectorB), PRECISION); + assertEquals(testData.expRes, + distanceMeasure.compute(testData.vectorA, testData.vectorB), PRECISION); + } + + private static class TestData { + public final Vector vectorA; + + public final Vector vectorB; + + public final double expRes; + + private TestData(double[] vectorA, double[] vectorB, double expRes) { + this.vectorA = new DenseVector(vectorA); + this.vectorB = new DenseVector(vectorB); + this.expRes = expRes; + } + + @Override public String toString() { + return String.format("d(%s,%s) = %s", + Arrays.toString(vectorA.asArray()), + Arrays.toString(vectorB.asArray()), + expRes + ); + } + } +} diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/math/distances/DistanceTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/math/distances/DistanceTest.java index d3b8058..0be0b54 100644 --- a/modules/ml/src/test/java/org/apache/ignite/ml/math/distances/DistanceTest.java +++ b/modules/ml/src/test/java/org/apache/ignite/ml/math/distances/DistanceTest.java @@ -40,6 +40,10 @@ public class DistanceTest { new EuclideanDistance(), new HammingDistance(), new ManhattanDistance(), + new BrayCurtisDistance(), 
+ new CanberraDistance(), + new JensenShannonDistance(), + new WeightedMinkowskiDistance(4, new DenseVector(new double[]{1, 1, 1})), new MinkowskiDistance(Math.random())); /** */ @@ -150,6 +154,57 @@ public class DistanceTest { assertEquals(expRes, distanceMeasure.compute(v1, v2), PRECISION); } + /** */ + @Test + public void brayCurtisDistance() { + double expRes = 1.0; + + DistanceMeasure distanceMeasure = new BrayCurtisDistance(); + + assertEquals(expRes, distanceMeasure.compute(v1, data2), PRECISION); + assertEquals(expRes, distanceMeasure.compute(v1, v2), PRECISION); + } + + /** */ + @Test + public void canberraDistance() { + double expRes = 2.0; + + DistanceMeasure distanceMeasure = new CanberraDistance(); + + assertEquals(expRes, distanceMeasure.compute(v1, data2), PRECISION); + assertEquals(expRes, distanceMeasure.compute(v1, v2), PRECISION); + } + + /** */ + @Test + public void jensenShannonDistance() { + double precistion = 0.01; + double expRes = 0.83; + double[] pData = new double[] {1.0, 0.0, 0.0}; + Vector pV1 = new DenseVector(new double[] {0.0, 1.0, 0.0}); + Vector pV2 = new DenseVector(pData); + + DistanceMeasure distanceMeasure = new JensenShannonDistance(); + + assertEquals(expRes, distanceMeasure.compute(pV1, pData), precistion); + assertEquals(expRes, distanceMeasure.compute(pV1, pV2), precistion); + } + + /** */ + @Test + public void weightedMinkowskiDistance() { + double precistion = 0.01; + int p = 2; + double expRes = 5.0; + Vector v = new DenseVector(new double[]{2, 3, 4}); + + DistanceMeasure distanceMeasure = new WeightedMinkowskiDistance(p, v); + + assertEquals(expRes, distanceMeasure.compute(v1, data2), precistion); + assertEquals(expRes, distanceMeasure.compute(v1, v2), precistion); + } + /** Returns a random vector */ private static Vector randomVector(int length) { double[] vec = new double[length]; diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/math/distances/JensenShannonDistanceTest.java 
b/modules/ml/src/test/java/org/apache/ignite/ml/math/distances/JensenShannonDistanceTest.java new file mode 100644 index 0000000..763a362 --- /dev/null +++ b/modules/ml/src/test/java/org/apache/ignite/ml/math/distances/JensenShannonDistanceTest.java @@ -0,0 +1,105 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.ignite.ml.math.distances; + +import java.util.Arrays; +import java.util.Collection; +import org.apache.ignite.ml.math.primitives.vector.Vector; +import org.apache.ignite.ml.math.primitives.vector.impl.DenseVector; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.Parameterized; + +import static org.junit.Assert.assertEquals; + +/** + * Evaluate JensenShannonDistance in multiple test datasets + */ +@RunWith(Parameterized.class) +public class JensenShannonDistanceTest { + /** Precision. 
*/ + private static final double PRECISION = 0.01; + + /** */ + @Parameterized.Parameters(name = "{0}") + public static Collection<TestData> data() { + return Arrays.asList( + new TestData( + new double[] {1.0, 0.0, 0.0}, + new double[] {0.0, 1.0, 0.0}, + 2.0, + 1.0 + ), + new TestData( + new double[] {1.0, 0.0}, + new double[] {0.5, 0.5}, + Math.E, + 0.46 + ), + new TestData( + new double[] {1.0, 0.0, 0.0}, + new double[] {1.0, 0.5, 0.0}, + Math.E, + 0.36 + ) + ); + } + + private final TestData testData; + + /** */ + public JensenShannonDistanceTest(TestData testData) { + this.testData = testData; + } + + /** */ + @Test + public void testJensenShannonDistance() { + DistanceMeasure distanceMeasure = new JensenShannonDistance(testData.base); + + assertEquals(testData.expRes, + distanceMeasure.compute(testData.vectorA, testData.vectorB), PRECISION); + assertEquals(testData.expRes, + distanceMeasure.compute(testData.vectorA, testData.vectorB), PRECISION); + } + + private static class TestData { + public final Vector vectorA; + + public final Vector vectorB; + + public final Double expRes; + + public final Double base; + + private TestData(double[] vectorA, double[] vectorB, Double base, Double expRes) { + this.vectorA = new DenseVector(vectorA); + this.vectorB = new DenseVector(vectorB); + this.base = base; + this.expRes = expRes; + } + + @Override public String toString() { + return String.format("d(%s,%s;%s) = %s", + Arrays.toString(vectorA.asArray()), + Arrays.toString(vectorB.asArray()), + base, + expRes + ); + } + } +} diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/math/distances/WeightedMinkowskiDistanceTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/math/distances/WeightedMinkowskiDistanceTest.java new file mode 100644 index 0000000..1ab93a1 --- /dev/null +++ b/modules/ml/src/test/java/org/apache/ignite/ml/math/distances/WeightedMinkowskiDistanceTest.java @@ -0,0 +1,112 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one 
or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.ignite.ml.math.distances; + +import java.util.Arrays; +import java.util.Collection; +import org.apache.ignite.ml.math.primitives.vector.Vector; +import org.apache.ignite.ml.math.primitives.vector.impl.DenseVector; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.Parameterized; + +import static org.junit.Assert.assertEquals; + +/** + * Evaluate WeightedMinkowski in multiple test datasets + */ +@RunWith(Parameterized.class) +public class WeightedMinkowskiDistanceTest { + /** Precision. 
*/ + private static final double PRECISION = 0.01; + + /** */ + @Parameterized.Parameters(name = "{0}") + public static Collection<TestData> data() { + return Arrays.asList( + new TestData( + new double[] {1.0, 0.0, 0.0}, + new double[] {0.0, 1.0, 0.0}, + 1, + new double[] {2.0, 3.0, 4.0}, + 5.0 + ), + new TestData( + new double[] {1.0, 0.0, 0.0}, + new double[] {0.0, 1.0, 0.0}, + 2, + new double[] {2.0, 3.0, 4.0}, + 3.60 + ), + new TestData( + new double[] {1.0, 0.0, 0.0}, + new double[] {0.0, 1.0, 0.0}, + 3, + new double[] {2.0, 3.0, 4.0}, + 3.27 + ) + ); + } + + private final TestData testData; + + /** */ + public WeightedMinkowskiDistanceTest(TestData testData) { + this.testData = testData; + } + + /** */ + @Test + public void testWeightedMinkowski() { + DistanceMeasure distanceMeasure = new WeightedMinkowskiDistance(testData.p, testData.weight); + + assertEquals(testData.expRes, + distanceMeasure.compute(testData.vectorA, testData.vectorB), PRECISION); + assertEquals(testData.expRes, + distanceMeasure.compute(testData.vectorA, testData.vectorB), PRECISION); + } + + private static class TestData { + public final Vector vectorA; + + public final Vector vectorB; + + public final Integer p; + + public final Vector weight; + + public final Double expRes; + + private TestData(double[] vectorA, double[] vectorB, Integer p, double[] weight, double expRes) { + this.vectorA = new DenseVector(vectorA); + this.vectorB = new DenseVector(vectorB); + this.p = p; + this.weight = new DenseVector(weight); + this.expRes = expRes; + } + + @Override public String toString() { + return String.format("d(%s,%s;%s,%s) = %s", + Arrays.toString(vectorA.asArray()), + Arrays.toString(vectorB.asArray()), + p, + Arrays.toString(weight.asArray()), + expRes + ); + } + } +}