This is an automated email from the ASF dual-hosted git repository.

baunsgaard pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/systemds.git


The following commit(s) were added to refs/heads/main by this push:
     new 38ec722a53 [MINOR] Adding a factory method for MatrixSketch
38ec722a53 is described below

commit 38ec722a53557037b79eb6259ce17cf2b850af4e
Author: Badrul Chowdhury <[email protected]>
AuthorDate: Fri Nov 25 11:02:30 2022 -0800

    [MINOR] Adding a factory method for MatrixSketch
    
    This patch introduces a factory method for sketches.
    This will centralize the creation of all sketches in one place and
    prevent duplication of operator switching and validation logic.
    
    Closes #1738
---
 .../spark/AggregateUnarySketchSPInstruction.java   | 16 ++++----
 .../matrix/data/LibMatrixCountDistinct.java        | 40 ++++++--------------
 .../runtime/matrix/data/sketch/SketchFactory.java  | 44 ++++++++++++++++++++++
 3 files changed, 63 insertions(+), 37 deletions(-)

diff --git a/src/main/java/org/apache/sysds/runtime/instructions/spark/AggregateUnarySketchSPInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/spark/AggregateUnarySketchSPInstruction.java
index 767e4b0c0b..bfdecc635a 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/spark/AggregateUnarySketchSPInstruction.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/spark/AggregateUnarySketchSPInstruction.java
@@ -117,7 +117,7 @@ public class AggregateUnarySketchSPInstruction extends UnarySPInstruction {
                     out1.fold(new CorrMatrixBlock(new MatrixBlock()),
                               new AggregateUnarySketchUnionAllFunction(this.op));
 
-            MatrixBlock out3 = LibMatrixCountDistinct.countDistinctValuesFromSketch(out2, this.op);
+            MatrixBlock out3 = LibMatrixCountDistinct.countDistinctValuesFromSketch(this.op, out2);
 
             // put output block into symbol table (no lineage because single block)
             // this also includes implicit maintenance of matrix characteristics
@@ -180,7 +180,7 @@ public class AggregateUnarySketchSPInstruction extends UnarySPInstruction {
             MatrixIndexes ixOut = new MatrixIndexes();
             this.op.indexFn.execute(ixIn, ixOut);
 
-            return LibMatrixCountDistinct.createSketch(blkIn, this.op);
+            return LibMatrixCountDistinct.createSketch(this.op, blkIn);
         }
     }
 
@@ -207,7 +207,7 @@ public class AggregateUnarySketchSPInstruction extends UnarySPInstruction {
                 return arg0;
             }
 
-            return LibMatrixCountDistinct.unionSketch(arg0, arg1, this.op);
+            return LibMatrixCountDistinct.unionSketch(this.op, arg0, arg1);
         }
     }
 
@@ -246,7 +246,7 @@ public class AggregateUnarySketchSPInstruction extends UnarySPInstruction {
         public CorrMatrixBlock call(MatrixBlock arg0)
                 throws Exception {
 
-            return LibMatrixCountDistinct.createSketch(arg0, this.op);
+            return LibMatrixCountDistinct.createSketch(this.op, arg0);
         }
     }
 
@@ -261,8 +261,8 @@ public class AggregateUnarySketchSPInstruction extends UnarySPInstruction {
 
         @Override
         public CorrMatrixBlock call(CorrMatrixBlock arg0, MatrixBlock arg1) throws Exception {
-            CorrMatrixBlock arg1WithCorr = LibMatrixCountDistinct.createSketch(arg1, this.op);
-            return LibMatrixCountDistinct.unionSketch(arg0, arg1WithCorr, this.op);
+            CorrMatrixBlock arg1WithCorr = LibMatrixCountDistinct.createSketch(this.op, arg1);
+            return LibMatrixCountDistinct.unionSketch(this.op, arg0, arg1WithCorr);
         }
     }
 
@@ -277,7 +277,7 @@ public class AggregateUnarySketchSPInstruction extends UnarySPInstruction {
 
         @Override
         public CorrMatrixBlock call(CorrMatrixBlock arg0, CorrMatrixBlock arg1) throws Exception {
-            return LibMatrixCountDistinct.unionSketch(arg0, arg1, this.op);
+            return LibMatrixCountDistinct.unionSketch(this.op, arg0, arg1);
         }
     }
 
@@ -292,7 +292,7 @@ public class AggregateUnarySketchSPInstruction extends UnarySPInstruction {
 
         @Override
         public MatrixBlock call(CorrMatrixBlock arg0) throws Exception {
-            return LibMatrixCountDistinct.countDistinctValuesFromSketch(arg0, this.op);
+            return LibMatrixCountDistinct.countDistinctValuesFromSketch(this.op, arg0);
         }
     }
 }
diff --git a/src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixCountDistinct.java b/src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixCountDistinct.java
index ccddb4db80..72bcd64b43 100644
--- a/src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixCountDistinct.java
+++ b/src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixCountDistinct.java
@@ -31,8 +31,8 @@ import org.apache.sysds.api.DMLException;
 import org.apache.sysds.runtime.compress.CompressedMatrixBlock;
 import org.apache.sysds.runtime.data.*;
 import org.apache.sysds.runtime.instructions.spark.data.CorrMatrixBlock;
-import org.apache.sysds.runtime.matrix.data.sketch.CountDistinctSketch;
-import org.apache.sysds.runtime.matrix.data.sketch.countdistinct.CountDistinctFunctionSketch;
+import org.apache.sysds.runtime.matrix.data.sketch.MatrixSketch;
+import org.apache.sysds.runtime.matrix.data.sketch.SketchFactory;
 import org.apache.sysds.runtime.matrix.data.sketch.countdistinctapprox.KMVSketch;
 import org.apache.sysds.runtime.matrix.operators.CountDistinctOperator;
 import org.apache.sysds.runtime.matrix.operators.CountDistinctOperatorTypes;
@@ -356,36 +356,18 @@ public interface LibMatrixCountDistinct {
                return distinct.size();
        }
 
-       static MatrixBlock countDistinctValuesFromSketch(CorrMatrixBlock arg0, CountDistinctOperator op) {
-               if(op.getOperatorType() == CountDistinctOperatorTypes.COUNT)
-                       return new CountDistinctFunctionSketch(op).getValueFromSketch(arg0);
-               else if(op.getOperatorType() == CountDistinctOperatorTypes.KMV)
-                       return new KMVSketch(op).getValueFromSketch(arg0);
-               else if(op.getOperatorType() == CountDistinctOperatorTypes.HLL)
-                       throw new NotImplementedException("Not implemented yet");
-               else
-                       throw new NotImplementedException("Not implemented yet");
+       static MatrixBlock countDistinctValuesFromSketch(CountDistinctOperator op, CorrMatrixBlock corrBlkIn) {
+               MatrixSketch sketch = SketchFactory.get(op);
+               return sketch.getValueFromSketch(corrBlkIn);
        }
 
-       static CorrMatrixBlock createSketch(MatrixBlock blkIn, CountDistinctOperator op) {
-               if(op.getOperatorType() == CountDistinctOperatorTypes.COUNT)
-                       return new CountDistinctFunctionSketch(op).create(blkIn);
-               else if(op.getOperatorType() == CountDistinctOperatorTypes.KMV)
-                       return new KMVSketch(op).create(blkIn);
-               else if(op.getOperatorType() == CountDistinctOperatorTypes.HLL)
-                       throw new NotImplementedException("Not implemented yet");
-               else
-                       throw new NotImplementedException("Not implemented yet");
+       static CorrMatrixBlock createSketch(CountDistinctOperator op, MatrixBlock blkIn) {
+               MatrixSketch sketch = SketchFactory.get(op);
+               return sketch.create(blkIn);
        }
 
-       static CorrMatrixBlock unionSketch(CorrMatrixBlock arg0, CorrMatrixBlock arg1, CountDistinctOperator op) {
-               if(op.getOperatorType() == CountDistinctOperatorTypes.COUNT)
-                       return new CountDistinctFunctionSketch(op).union(arg0, arg1);
-               else if(op.getOperatorType() == CountDistinctOperatorTypes.KMV)
-                       return new KMVSketch(op).union(arg0, arg1);
-               else if(op.getOperatorType() == CountDistinctOperatorTypes.HLL)
-                       throw new NotImplementedException("Not implemented yet");
-               else
-                       throw new NotImplementedException("Not implemented yet");
+       static CorrMatrixBlock unionSketch(CountDistinctOperator op, CorrMatrixBlock corrBlkIn0, CorrMatrixBlock corrBlkIn1) {
+               MatrixSketch sketch = SketchFactory.get(op);
+               return sketch.union(corrBlkIn0, corrBlkIn1);
        }
 }
diff --git a/src/main/java/org/apache/sysds/runtime/matrix/data/sketch/SketchFactory.java b/src/main/java/org/apache/sysds/runtime/matrix/data/sketch/SketchFactory.java
new file mode 100644
index 0000000000..434582374d
--- /dev/null
+++ b/src/main/java/org/apache/sysds/runtime/matrix/data/sketch/SketchFactory.java
@@ -0,0 +1,44 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.runtime.matrix.data.sketch;
+
+import org.apache.commons.lang.NotImplementedException;
+import org.apache.sysds.runtime.matrix.data.sketch.countdistinct.CountDistinctFunctionSketch;
+import org.apache.sysds.runtime.matrix.data.sketch.countdistinctapprox.KMVSketch;
+import org.apache.sysds.runtime.matrix.operators.CountDistinctOperator;
+import org.apache.sysds.runtime.matrix.operators.CountDistinctOperatorTypes;
+import org.apache.sysds.runtime.matrix.operators.Operator;
+
+public class SketchFactory {
+       public static MatrixSketch get(Operator op) {
+               if (op instanceof CountDistinctOperator) {
+                       CountDistinctOperator cdop = (CountDistinctOperator) op;
+                       if (cdop.getOperatorType() == CountDistinctOperatorTypes.COUNT) {
+                               return new CountDistinctFunctionSketch(op);
+                       } else if (cdop.getOperatorType() == CountDistinctOperatorTypes.KMV) {
+                               return new KMVSketch(op);
+                       } else {
+                               throw new NotImplementedException("Only COUNT and KMV count distinct sketches are supported for now");
+                       }
+               } else {
+                       throw new IllegalArgumentException("Only sketches for count distinct operators are supported for now");
+               }
+       }
+}

Reply via email to