This is an automated email from the ASF dual-hosted git repository.

mboehm7 pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/systemds.git


The following commit(s) were added to refs/heads/main by this push:
     new 09493ef  [SYSTEMDS-3152] Min/max support in grouped aggregates (CP and Spark)
09493ef is described below

commit 09493efdebdbb03afb798418d2047290c7f5ebfa
Author: Thomas Reichel <[email protected]>
AuthorDate: Thu Mar 3 21:01:35 2022 +0100

    [SYSTEMDS-3152] Min/max support in grouped aggregates (CP and Spark)
    
    DIA project WS2021/22
    Closes #1507.
    
    Co-authored-by: Thomas Moder <[email protected]>
    Co-authored-by: burimvrella <[email protected]>
---
 .../ParameterizedBuiltinFunctionExpression.java    |   4 +-
 .../java/org/apache/sysds/parser/Statement.java    |   2 +
 .../apache/sysds/runtime/functionobjects/CM.java   |  36 ++++++
 .../runtime/instructions/InstructionUtils.java     |  28 ++--
 .../runtime/instructions/cp/CM_COV_Object.java     |  23 +++-
 .../spark/ParameterizedBuiltinSPInstruction.java   |   6 +
 .../sysds/runtime/matrix/operators/CMOperator.java |   8 +-
 .../aggregate/FullGroupedAggregateMatrixTest.java  |  69 ++++++----
 .../aggregate/FullGroupedAggregateTest.java        | 143 ++++++++++++++++-----
 .../scripts/functions/aggregate/GroupedAggregate.R |  10 ++
 .../functions/aggregate/GroupedAggregate.dml       |   8 ++
 .../functions/aggregate/GroupedAggregateMatrix.R   |  10 ++
 .../functions/aggregate/GroupedAggregateMatrix.dml |   8 ++
 .../aggregate/GroupedAggregateMatrixNoDims.dml     |   8 ++
 .../functions/aggregate/GroupedAggregateWeights.R  |  10 ++
 .../aggregate/GroupedAggregateWeights.dml          |   9 ++
 16 files changed, 312 insertions(+), 70 deletions(-)

diff --git 
a/src/main/java/org/apache/sysds/parser/ParameterizedBuiltinFunctionExpression.java
 
b/src/main/java/org/apache/sysds/parser/ParameterizedBuiltinFunctionExpression.java
index 6b6ca9b..a6bcc2d 100644
--- 
a/src/main/java/org/apache/sysds/parser/ParameterizedBuiltinFunctionExpression.java
+++ 
b/src/main/java/org/apache/sysds/parser/ParameterizedBuiltinFunctionExpression.java
@@ -790,7 +790,9 @@ public class ParameterizedBuiltinFunctionExpression extends 
DataIdentifier
                        else if (fnameStr.equals(Statement.GAGG_FN_COUNT) 
                                        || 
fnameStr.equals(Statement.GAGG_FN_SUM) 
                                        || 
fnameStr.equals(Statement.GAGG_FN_MEAN)
-                                       || 
fnameStr.equals(Statement.GAGG_FN_VARIANCE)){}
+                                       || 
fnameStr.equals(Statement.GAGG_FN_VARIANCE)
+                                       || 
fnameStr.equals(Statement.GAGG_FN_MIN)
+                                       || 
fnameStr.equals(Statement.GAGG_FN_MAX)){}
                        else { 
                                raiseValidateError("fname is " + fnameStr + " 
but must be either centeralmoment, count, sum, mean, variance", conditional);
                        }
diff --git a/src/main/java/org/apache/sysds/parser/Statement.java 
b/src/main/java/org/apache/sysds/parser/Statement.java
index 4b0237d..995a1e2 100644
--- a/src/main/java/org/apache/sysds/parser/Statement.java
+++ b/src/main/java/org/apache/sysds/parser/Statement.java
@@ -57,6 +57,8 @@ public abstract class Statement implements ParseInfo
        public static final String GAGG_FN_MEAN     = "mean";
        public static final String GAGG_FN_VARIANCE = "variance";
        public static final String GAGG_FN_CM       = "centralmoment";
+       public static final String GAGG_FN_MIN      = "min";
+       public static final String GAGG_FN_MAX      = "max";
        public static final String GAGG_FN_CM_ORDER = "order";
        public static final String GAGG_NUM_GROUPS  = "ngroups";
 
diff --git a/src/main/java/org/apache/sysds/runtime/functionobjects/CM.java 
b/src/main/java/org/apache/sysds/runtime/functionobjects/CM.java
index ae3e718..54a2d83 100644
--- a/src/main/java/org/apache/sysds/runtime/functionobjects/CM.java
+++ b/src/main/java/org/apache/sysds/runtime/functionobjects/CM.java
@@ -93,6 +93,8 @@ public class CM extends ValueFunction
                if(cm1.isCMAllZeros()) {
                        cm1.w=1;
                        cm1.mean.set(in2, 0);
+                       cm1.min = in2;
+                       cm1.max = in2;
                        cm1.m2.set(0,0);
                        cm1.m3.set(0,0);
                        cm1.m4.set(0,0);
@@ -114,6 +116,16 @@ public class CM extends ValueFunction
                                cm1.w=w;
                                break;
                        }
+                       case MIN:
+                       {
+                               cm1.min = Math.min(cm1.min, in2);
+                               break;
+                       }
+                       case MAX:
+                       {
+                               cm1.max = Math.max(cm1.max, in2);
+                               break;
+                       }
                        case CM2:
                        {
                                double w= cm1.w + 1;
@@ -197,6 +209,8 @@ public class CM extends ValueFunction
                {
                        cm1.w=w2;
                        cm1.mean.set(in2, 0);
+                       cm1.min = in2 * w2;
+                       cm1.max = in2 * w2;
                        cm1.m2.set(0,0);
                        cm1.m3.set(0,0);
                        cm1.m4.set(0,0);
@@ -210,6 +224,16 @@ public class CM extends ValueFunction
                                cm1.w = Math.round(cm1.w + w2);
                                break;
                        }
+                       case MIN:
+                       {
+                               cm1.min = Math.min(cm1.min, in2 * w2);
+                               break;
+                       }
+                       case MAX:
+                       {
+                               cm1.max = Math.max(cm1.max, in2 * w2);
+                               break;
+                       }
                        case MEAN:
                        {
                                double w = cm1.w + w2;
@@ -303,6 +327,8 @@ public class CM extends ValueFunction
                {
                        cm1.w=cm2.w;
                        cm1.mean.set(cm2.mean);
+                       cm1.min = cm2.min;
+                       cm1.max = cm2.max;
                        cm1.m2.set(cm2.m2);
                        cm1.m3.set(cm2.m3);
                        cm1.m4.set(cm2.m4);
@@ -318,6 +344,16 @@ public class CM extends ValueFunction
                                cm1.w = Math.round(cm1.w + cm2.w);              
                
                                break;
                        }
+                       case MIN:
+                       {
+                               cm1.min = Math.min(cm1.min, cm2.min);
+                               break;
+                       }
+                       case MAX:
+                       {
+                               cm1.max = Math.max(cm1.max, cm2.max);
+                               break;
+                       }
                        case MEAN:
                        {
                                double w = cm1.w + cm2.w;
diff --git 
a/src/main/java/org/apache/sysds/runtime/instructions/InstructionUtils.java 
b/src/main/java/org/apache/sysds/runtime/instructions/InstructionUtils.java
index f9d9ab3..f22fdfe 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/InstructionUtils.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/InstructionUtils.java
@@ -1000,19 +1000,23 @@ public class InstructionUtils
                        op = CMOperator.getAggOpType(fn, null);
        
                switch(op) {
-               case SUM:
-                       return new AggregateOperator(0, 
KahanPlus.getKahanPlusFnObject(), CorrectionLocationType.LASTCOLUMN);
+                       case SUM:
+                               return new AggregateOperator(0, 
KahanPlus.getKahanPlusFnObject(), CorrectionLocationType.LASTCOLUMN);
+                               
+                       case COUNT:
+                       case MEAN:
+                       case VARIANCE:
+                       case CM2:
+                       case CM3:
+                       case CM4:
                        
-               case COUNT:
-               case MEAN:
-               case VARIANCE:
-               case CM2:
-               case CM3:
-               case CM4:
-                       return new CMOperator(CM.getCMFnObject(op), op);
-               case INVALID:
-               default:
-                       throw new DMLRuntimeException("Invalid Aggregate 
Operation in GroupedAggregateInstruction: " + op);
+                       //TODO use appropriate function objects for min/max 
(see sum)
+                       case MIN:
+                       case MAX:
+                               return new CMOperator(CM.getCMFnObject(op), op);
+                       case INVALID:
+                       default:
+                               throw new DMLRuntimeException("Invalid 
Aggregate Operation in GroupedAggregateInstruction: " + op);
                }
        }
        
diff --git 
a/src/main/java/org/apache/sysds/runtime/instructions/cp/CM_COV_Object.java 
b/src/main/java/org/apache/sysds/runtime/instructions/cp/CM_COV_Object.java
index 39abdcd..8591c3b 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/cp/CM_COV_Object.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/cp/CM_COV_Object.java
@@ -37,13 +37,15 @@ public class CM_COV_Object extends Data
        public KahanObject m2;
        public KahanObject m3;
        public KahanObject m4;
+       public double min;
+       public double max;
        
        public KahanObject mean_v;
        public KahanObject c2;
        
        @Override
        public String toString() {
-               return "weight: "+w+", mean: "+mean+", m2: "+m2+", m3: "+m3+", 
m4: "+m4+", mean2: "+mean_v+", c2: "+c2;
+               return "weight: "+w+", mean: "+mean+", m2: "+m2+", m3: "+m3+", 
m4: "+m4+", min: "+min+", max: "+max+", mean2: "+mean_v+", c2: "+c2;
        }
        
        public CM_COV_Object()
@@ -56,6 +58,8 @@ public class CM_COV_Object extends Data
                m4=new KahanObject(0,0);
                mean_v=new KahanObject(0,0);
                c2=new KahanObject(0,0);
+               min=0;
+               max=0;
        }
        
        public void reset()
@@ -67,6 +71,8 @@ public class CM_COV_Object extends Data
                m4=new KahanObject(0,0);
                mean_v=new KahanObject(0,0);
                c2=new KahanObject(0,0);
+               min=0;
+               max=0;
        }
        
        public int compareTo(CM_COV_Object that)
@@ -83,6 +89,10 @@ public class CM_COV_Object extends Data
                        return KahanObject.compare(m4, that.m4);
                else if(mean_v!=that.mean_v)
                        return KahanObject.compare(mean_v, that.mean_v);
+               else if(min!=that.min)
+                       return Double.compare(min, that.min);
+               else if(max!=that.max)
+                       return Double.compare(max, that.max);
                else
                        return KahanObject.compare(c2, that.c2);
        }
@@ -96,7 +106,8 @@ public class CM_COV_Object extends Data
                CM_COV_Object that = (CM_COV_Object)o;
                return (w==that.w && mean.equals(that.mean) && 
m2.equals(that.m2))
                                && m3.equals(that.m3) && m4.equals(that.m4) 
-                               && mean_v.equals(that.mean_v) && 
c2.equals(that.c2);
+                               && mean_v.equals(that.mean_v) && 
c2.equals(that.c2)
+                               && min==that.min && max == that.max;
        }
        
        @Override
@@ -113,11 +124,13 @@ public class CM_COV_Object extends Data
                this.m4.set(that.m4);
                this.mean_v.set(that.mean_v);
                this.c2.set(that.c2);
+               this.min=that.min;
+               this.max=that.max;
        }
        
        public boolean isCMAllZeros()
        {
-               return w==0 && mean.isAllZero() && m2.isAllZero()  && 
m3.isAllZero()  && m4.isAllZero() ;
+               return w==0 && mean.isAllZero() && m2.isAllZero()  && 
m3.isAllZero()  && m4.isAllZero() && min==0 && max==0;
        }
        
        public boolean isCOVAllZeros()
@@ -166,6 +179,10 @@ public class CM_COV_Object extends Data
                                return m3._sum/w;
                        case CM4:
                                return m4._sum/w;
+                       case MIN:
+                               return min;
+                       case MAX:
+                               return max;
                        case VARIANCE:
                                return w==1.0? 0:m2._sum/(w-1);
                        default:
diff --git 
a/src/main/java/org/apache/sysds/runtime/instructions/spark/ParameterizedBuiltinSPInstruction.java
 
b/src/main/java/org/apache/sysds/runtime/instructions/spark/ParameterizedBuiltinSPInstruction.java
index 1478c5f..0494f42 100644
--- 
a/src/main/java/org/apache/sysds/runtime/instructions/spark/ParameterizedBuiltinSPInstruction.java
+++ 
b/src/main/java/org/apache/sysds/runtime/instructions/spark/ParameterizedBuiltinSPInstruction.java
@@ -807,6 +807,12 @@ public class ParameterizedBuiltinSPInstruction extends 
ComputationSPInstruction
                                        case MEAN:
                                                val = kv.getValue();
                                                break;
+                                       case MIN:
+                                               val = kv.getValue();
+                                               break;
+                                       case MAX:
+                                               val = kv.getValue();
+                                               break;
                                        case CM2:
                                                val = kv.getValue() / 
kv.getWeight();
                                                break;
diff --git 
a/src/main/java/org/apache/sysds/runtime/matrix/operators/CMOperator.java 
b/src/main/java/org/apache/sysds/runtime/matrix/operators/CMOperator.java
index 579e681..7e56e0e 100644
--- a/src/main/java/org/apache/sysds/runtime/matrix/operators/CMOperator.java
+++ b/src/main/java/org/apache/sysds/runtime/matrix/operators/CMOperator.java
@@ -35,6 +35,8 @@ public class CMOperator extends Operator
                CM2,
                CM3,
                CM4,
+               MIN,
+               MAX,
                VARIANCE,
                INVALID
        }
@@ -103,6 +105,10 @@ public class CMOperator extends Operator
                                return AggregateOperationTypes.CM4;
                        else
                                return AggregateOperationTypes.INVALID;
+               } else if (fn.equalsIgnoreCase("min")) {
+                       return AggregateOperationTypes.MIN;
+               } else if (fn.equalsIgnoreCase("max")) {
+                       return AggregateOperationTypes.MAX;
                }
                return AggregateOperationTypes.INVALID;
        }
@@ -114,7 +120,7 @@ public class CMOperator extends Operator
                switch( aggOpType )
                {
                        case COUNT:
-                       case MEAN: 
+                       case MEAN:
                                ret = true; break;
                                
                        //NOTE: the following aggregation operators are not 
marked for partial aggregation 
diff --git 
a/src/test/java/org/apache/sysds/test/functions/aggregate/FullGroupedAggregateMatrixTest.java
 
b/src/test/java/org/apache/sysds/test/functions/aggregate/FullGroupedAggregateMatrixTest.java
index dcf439a..10036cf 100644
--- 
a/src/test/java/org/apache/sysds/test/functions/aggregate/FullGroupedAggregateMatrixTest.java
+++ 
b/src/test/java/org/apache/sysds/test/functions/aggregate/FullGroupedAggregateMatrixTest.java
@@ -22,7 +22,6 @@ package org.apache.sysds.test.functions.aggregate;
 import java.io.IOException;
 import java.util.HashMap;
 
-import org.apache.sysds.api.DMLScript;
 import org.apache.sysds.common.Types.ExecMode;
 import org.apache.sysds.common.Types.FileFormat;
 import org.apache.sysds.common.Types.ValueType;
@@ -38,9 +37,6 @@ import org.junit.AfterClass;
 import org.junit.BeforeClass;
 import org.junit.Test;
 
-/**
- * 
- */
 public class FullGroupedAggregateMatrixTest extends AutomatedTestBase 
 {
        private final static String TEST_NAME1 = "GroupedAggregateMatrix";
@@ -66,9 +62,10 @@ public class FullGroupedAggregateMatrixTest extends 
AutomatedTestBase
                VARIANCE,
                MOMENT3,
                MOMENT4,
+               MIN,
+               MAX
        }
        
-       
        @Override
        public void setUp() 
        {
@@ -158,6 +155,26 @@ public class FullGroupedAggregateMatrixTest extends 
AutomatedTestBase
        }
 
        @Test
+       public void testGroupedAggMinDenseCP() {
+               runGroupedAggregateOperationTest(TEST_NAME1, OpType.MIN, false, 
ExecType.CP);
+       }
+
+       @Test
+       public void testGroupedAggMinSparseCP() {
+               runGroupedAggregateOperationTest(TEST_NAME1, OpType.MIN, true, 
ExecType.CP);
+       }
+
+       @Test
+       public void testGroupedAggMaxDenseCP() {
+               runGroupedAggregateOperationTest(TEST_NAME1, OpType.MAX, false, 
ExecType.CP);
+       }
+
+       @Test
+       public void testGroupedAggMaxSparseCP() {
+               runGroupedAggregateOperationTest(TEST_NAME1, OpType.MAX, true, 
ExecType.CP);
+       }
+
+       @Test
        public void testGroupedAggSumDenseWideCP() {
                runGroupedAggregateOperationTest(TEST_NAME1, OpType.SUM, false, 
ExecType.CP, cols2);
        }
@@ -238,6 +255,26 @@ public class FullGroupedAggregateMatrixTest extends 
AutomatedTestBase
        }
 
        @Test
+       public void testGroupedAggMinDenseSP() {
+               runGroupedAggregateOperationTest(TEST_NAME1, OpType.MIN, false, 
ExecType.SPARK);
+       }
+
+       @Test
+       public void testGroupedAggMinSparseSP() {
+               runGroupedAggregateOperationTest(TEST_NAME1, OpType.MIN, true, 
ExecType.SPARK);
+       }
+
+       @Test
+       public void testGroupedAggMaxDenseSP() {
+               runGroupedAggregateOperationTest(TEST_NAME1, OpType.MAX, false, 
ExecType.SPARK);
+       }
+
+       @Test
+       public void testGroupedAggMaxSparseSP() {
+               runGroupedAggregateOperationTest(TEST_NAME1, OpType.MAX, true, 
ExecType.SPARK);
+       }
+
+       @Test
        public void testGroupedAggSumDenseWideSP() {
                runGroupedAggregateOperationTest(TEST_NAME1, OpType.SUM, false, 
ExecType.SPARK, cols2);
        }
@@ -254,17 +291,8 @@ public class FullGroupedAggregateMatrixTest extends 
AutomatedTestBase
        @SuppressWarnings("rawtypes")
        private void runGroupedAggregateOperationTest( String testname, OpType 
type, boolean sparse, ExecType instType, int numCols) 
        {
-               //rtplatform for MR
-               ExecMode platformOld = rtplatform;
-               switch( instType ){
-                       case SPARK: rtplatform = ExecMode.SPARK; break;
-                       default: rtplatform = ExecMode.HYBRID; break;
-               }
-       
-               boolean sparkConfigOld = DMLScript.USE_LOCAL_SPARK_CONFIG;
-               if( rtplatform == ExecMode.SPARK )
-                       DMLScript.USE_LOCAL_SPARK_CONFIG = true;
-       
+               ExecMode platformOld = setExecMode(instType);
+               
                try
                {
                        //determine script and function name
@@ -313,15 +341,12 @@ public class FullGroupedAggregateMatrixTest extends 
AutomatedTestBase
                                checkDMLMetaDataFile("C", new 
MatrixCharacteristics(numGroups,numCols,1,1));
                        }
                }
-               catch(IOException ex)
-               {
+               catch(IOException ex) {
                        ex.printStackTrace();
                        throw new RuntimeException(ex);
                }
-               finally
-               {
-                       rtplatform = platformOld;
-                       DMLScript.USE_LOCAL_SPARK_CONFIG = sparkConfigOld;
+               finally {
+                       resetExecMode(platformOld);
                }
        }
 }
diff --git 
a/src/test/java/org/apache/sysds/test/functions/aggregate/FullGroupedAggregateTest.java
 
b/src/test/java/org/apache/sysds/test/functions/aggregate/FullGroupedAggregateTest.java
index 9d810f2..ece8615 100644
--- 
a/src/test/java/org/apache/sysds/test/functions/aggregate/FullGroupedAggregateTest.java
+++ 
b/src/test/java/org/apache/sysds/test/functions/aggregate/FullGroupedAggregateTest.java
@@ -25,7 +25,6 @@ import java.util.HashMap;
 import org.junit.AfterClass;
 import org.junit.BeforeClass;
 import org.junit.Test;
-import org.apache.sysds.api.DMLScript;
 import org.apache.sysds.common.Types.ExecMode;
 import org.apache.sysds.common.Types.FileFormat;
 import org.apache.sysds.common.Types.ExecType;
@@ -37,9 +36,7 @@ import org.apache.sysds.test.AutomatedTestBase;
 import org.apache.sysds.test.TestConfiguration;
 import org.apache.sysds.test.TestUtils;
 
-/**
- * 
- */
+
 public class FullGroupedAggregateTest extends AutomatedTestBase 
 {
        private final static String TEST_NAME1 = "GroupedAggregate";
@@ -64,6 +61,8 @@ public class FullGroupedAggregateTest extends 
AutomatedTestBase
                VARIANCE,
                MOMENT3,
                MOMENT4,
+               MIN,
+               MAX
        }
        
        
@@ -227,7 +226,55 @@ public class FullGroupedAggregateTest extends 
AutomatedTestBase
        {
                runGroupedAggregateOperationTest(OpType.MOMENT4, true, false, 
false, ExecType.SPARK);
        }
-       
+
+       @Test
+       public void testGroupedAggMinDenseSP()
+       {
+               runGroupedAggregateOperationTest(OpType.MIN, false, false, 
false, ExecType.SPARK);
+       }
+
+       @Test
+       public void testGroupedAggMinSparseSP()
+       {
+               runGroupedAggregateOperationTest(OpType.MIN, true, false, 
false, ExecType.SPARK);
+       }
+
+       @Test
+       public void testGroupedAggMinDenseWeightsSP()
+       {
+               runGroupedAggregateOperationTest(OpType.MIN, false, true, 
false, ExecType.SPARK);
+       }
+
+       @Test
+       public void testGroupedAggMinSparseWeightsSP()
+       {
+               runGroupedAggregateOperationTest(OpType.MIN, true, true, false, 
ExecType.SPARK);
+       }
+
+       @Test
+       public void testGroupedAggMaxDenseSP()
+       {
+               runGroupedAggregateOperationTest(OpType.MAX, false, false, 
false, ExecType.SPARK);
+       }
+
+       @Test
+       public void testGroupedAggMaxSparseSP()
+       {
+               runGroupedAggregateOperationTest(OpType.MAX, true, false, 
false, ExecType.SPARK);
+       }
+
+       @Test
+       public void testGroupedAggMaxDenseWeightsSP()
+       {
+               runGroupedAggregateOperationTest(OpType.MAX, false, true, 
false, ExecType.SPARK);
+       }
+
+       @Test
+       public void testGroupedAggMaxSparseWeightsSP()
+       {
+               runGroupedAggregateOperationTest(OpType.MAX, true, true, false, 
ExecType.SPARK);
+       }
+
        // 
-----------------------------------------------------------------------
        
        @Test
@@ -366,13 +413,13 @@ public class FullGroupedAggregateTest extends 
AutomatedTestBase
        
        /* TODO weighted central moment in R
        @Test
-       public void testGroupedAggMoment3DenseWeightsCP() 
+       public void testGroupedAggMoment3DenseWeightsCP()
        {
                runGroupedAggregateOperationTest(OpType.MOMENT3, false, true, 
false, ExecType.CP);
        }
-       
+
        @Test
-       public void testGroupedAggMoment3SparseWeightsCP() 
+       public void testGroupedAggMoment3SparseWeightsCP()
        {
                runGroupedAggregateOperationTest(OpType.MOMENT3, true, true, 
false, ExecType.CP);
        }
@@ -389,33 +436,72 @@ public class FullGroupedAggregateTest extends 
AutomatedTestBase
        {
                runGroupedAggregateOperationTest(OpType.MOMENT4, true, false, 
false, ExecType.CP);
        }
-       
+
        /* TODO weighted central moment in R
        @Test
-       public void testGroupedAggMoment4DenseWeightsCP() 
+       public void testGroupedAggMoment4DenseWeightsCP()
        {
                runGroupedAggregateOperationTest(OpType.MOMENT4, false, true, 
false, ExecType.CP);
        }
-       
+
        @Test
-       public void testGroupedAggMoment4SparseWeightsCP() 
+       public void testGroupedAggMoment4SparseWeightsCP()
        {
                runGroupedAggregateOperationTest(OpType.MOMENT4, true, true, 
false, ExecType.CP);
        }
        */
-       
+
+       @Test
+       public void testGroupedAggMinDenseCP()
+       {
+               runGroupedAggregateOperationTest(OpType.MIN, false, false, 
false, ExecType.CP);
+       }
+
+       @Test
+       public void testGroupedAggMinSparseCP()
+       {
+               runGroupedAggregateOperationTest(OpType.MIN, true, false, 
false, ExecType.CP);
+       }
+
+       @Test
+       public void testGroupedAggMinDenseWeightsCP()
+       {
+               runGroupedAggregateOperationTest(OpType.MIN, false, true, 
false, ExecType.CP);
+       }
+
+       @Test
+       public void testGroupedAggMinSparseWeightsCP()
+       {
+               runGroupedAggregateOperationTest(OpType.MIN, true, true, false, 
ExecType.CP);
+       }
+
+       @Test
+       public void testGroupedAggMaxDenseCP()
+       {
+               runGroupedAggregateOperationTest(OpType.MAX, false, false, 
false, ExecType.CP);
+       }
+
+       @Test
+       public void testGroupedAggMaxSparseCP()
+       {
+               runGroupedAggregateOperationTest(OpType.MAX, true, false, 
false, ExecType.CP);
+       }
+
+       @Test
+       public void testGroupedAggMaxDenseWeightsCP()
+       {
+               runGroupedAggregateOperationTest(OpType.MAX, false, true, 
false, ExecType.CP);
+       }
+
+       @Test
+       public void testGroupedAggMaxSparseWeightsCP()
+       {
+               runGroupedAggregateOperationTest(OpType.MAX, true, true, false, 
ExecType.CP);
+       }
+
        private void runGroupedAggregateOperationTest( OpType type, boolean 
sparse, boolean weights, boolean transpose, ExecType instType) 
        {
-               //rtplatform for MR
-               ExecMode platformOld = rtplatform;
-               switch( instType ){
-                       case SPARK: rtplatform = ExecMode.SPARK; break;
-                       default: rtplatform = ExecMode.HYBRID; break;
-               }
-       
-               boolean sparkConfigOld = DMLScript.USE_LOCAL_SPARK_CONFIG;
-               if( rtplatform == ExecMode.SPARK )
-                       DMLScript.USE_LOCAL_SPARK_CONFIG = true;
+               ExecMode platformOld = setExecMode(instType);
        
                try
                {
@@ -474,17 +560,12 @@ public class FullGroupedAggregateTest extends 
AutomatedTestBase
                        HashMap<CellIndex, Double> rfile  = 
readRMatrixFromExpectedDir(weights?"D":"C");
                        TestUtils.compareMatrices(dmlfile, rfile, eps, 
"Stat-DML", "Stat-R");
                }
-               catch(IOException ex)
-               {
+               catch(IOException ex) {
                        ex.printStackTrace();
                        throw new RuntimeException(ex);
                }
-               finally
-               {
-                       rtplatform = platformOld;
-                       DMLScript.USE_LOCAL_SPARK_CONFIG = sparkConfigOld;
+               finally {
+                       resetExecMode(platformOld);
                }
        }
-       
-               
 }
\ No newline at end of file
diff --git a/src/test/scripts/functions/aggregate/GroupedAggregate.R 
b/src/test/scripts/functions/aggregate/GroupedAggregate.R
index 63acceb..324a945 100644
--- a/src/test/scripts/functions/aggregate/GroupedAggregate.R
+++ b/src/test/scripts/functions/aggregate/GroupedAggregate.R
@@ -58,4 +58,14 @@ if( fn==5 )
    C = aggregate(as.vector(A), by=list(as.vector(B)), FUN=moment, order=4, 
central=TRUE)[,2]
 }
 
+if ( fn==6 )
+{
+   C = aggregate(as.vector(A), by=list(as.vector(B)), FUN=min)[,2]
+}
+
+if ( fn==7 )
+{
+   C = aggregate(as.vector(A), by=list(as.vector(B)), FUN=max)[,2]
+}
+
 writeMM(as(C, "CsparseMatrix"), paste(args[3], "C", sep="")); 
\ No newline at end of file
diff --git a/src/test/scripts/functions/aggregate/GroupedAggregate.dml 
b/src/test/scripts/functions/aggregate/GroupedAggregate.dml
index 6381ffc..132cf59 100644
--- a/src/test/scripts/functions/aggregate/GroupedAggregate.dml
+++ b/src/test/scripts/functions/aggregate/GroupedAggregate.dml
@@ -47,5 +47,13 @@ else if( fn==5 )
 {
    C = aggregate(target=A, groups=B, fn="centralmoment", order="4");
 }
+else if( fn==6 )
+{
+   C = aggregate(target=A, groups=B, fn="min");
+}
+else if( fn==7 )
+{
+   C = aggregate(target=A, groups=B, fn="max");
+}
 
 write(C, $4, format="text");
\ No newline at end of file
diff --git a/src/test/scripts/functions/aggregate/GroupedAggregateMatrix.R 
b/src/test/scripts/functions/aggregate/GroupedAggregateMatrix.R
index d67b978..f052e46 100644
--- a/src/test/scripts/functions/aggregate/GroupedAggregateMatrix.R
+++ b/src/test/scripts/functions/aggregate/GroupedAggregateMatrix.R
@@ -64,6 +64,16 @@ if( fn==5 )
    C = aggregate(as.vector(Ai), by=list(as.vector(B)), FUN=moment, order=4, 
central=TRUE)[,2]
 }
 
+if( fn==6 )
+{
+   C = aggregate(as.vector(Ai), by=list(as.vector(B)), FUN=min)[,2]
+}
+
+if( fn==7 )
+{
+   C = aggregate(as.vector(Ai), by=list(as.vector(B)), FUN=max)[,2]
+}
+
 R[,j] = C;
 }
 
diff --git a/src/test/scripts/functions/aggregate/GroupedAggregateMatrix.dml 
b/src/test/scripts/functions/aggregate/GroupedAggregateMatrix.dml
index c4e70c8..d4ca660 100644
--- a/src/test/scripts/functions/aggregate/GroupedAggregateMatrix.dml
+++ b/src/test/scripts/functions/aggregate/GroupedAggregateMatrix.dml
@@ -47,5 +47,13 @@ else if( fn==5 )
 {
    C = aggregate(target=A, groups=B, fn="centralmoment", order="4", 
ngroups=$4);
 }
+else if( fn==6 )
+{
+   C = aggregate(target=A, groups=B, fn="min", ngroups=$4);
+}
+else if( fn==7 )
+{
+   C = aggregate(target=A, groups=B, fn="max", ngroups=$4);
+}
 
 write(C, $5, format="text");
\ No newline at end of file
diff --git 
a/src/test/scripts/functions/aggregate/GroupedAggregateMatrixNoDims.dml 
b/src/test/scripts/functions/aggregate/GroupedAggregateMatrixNoDims.dml
index d92366d..1366612 100644
--- a/src/test/scripts/functions/aggregate/GroupedAggregateMatrixNoDims.dml
+++ b/src/test/scripts/functions/aggregate/GroupedAggregateMatrixNoDims.dml
@@ -47,5 +47,13 @@ else if( fn==5 )
 {
    C = aggregate(target=A, groups=B, fn="centralmoment", order="4");
 }
+else if( fn==6 )
+{
+   C = aggregate(target=A, groups=B, fn="min");
+}
+else if( fn==7 )
+{
+   C = aggregate(target=A, groups=B, fn="max");
+}
 
 write(C, $5, format="text");
\ No newline at end of file
diff --git a/src/test/scripts/functions/aggregate/GroupedAggregateWeights.R 
b/src/test/scripts/functions/aggregate/GroupedAggregateWeights.R
index eea2f94..1fed6c7 100644
--- a/src/test/scripts/functions/aggregate/GroupedAggregateWeights.R
+++ b/src/test/scripts/functions/aggregate/GroupedAggregateWeights.R
@@ -68,4 +68,14 @@ if( fn==5 )
    D = aggregate(as.vector(A*C), by=list(as.vector(B)), FUN=moment, order=4, 
central=TRUE)[,2]
 }
 
+if( fn==6 )
+{
+   D = aggregate(as.vector(A*C), by=list(as.vector(B)), FUN=min)[,2]
+}
+
+if( fn==7 )
+{
+   D = aggregate(as.vector(A*C), by=list(as.vector(B)), FUN=max)[,2]
+}
+
 writeMM(as(D, "CsparseMatrix"), paste(args[3], "D", sep="")); 
\ No newline at end of file
diff --git a/src/test/scripts/functions/aggregate/GroupedAggregateWeights.dml 
b/src/test/scripts/functions/aggregate/GroupedAggregateWeights.dml
index 1e1ecf7..1d86907 100644
--- a/src/test/scripts/functions/aggregate/GroupedAggregateWeights.dml
+++ b/src/test/scripts/functions/aggregate/GroupedAggregateWeights.dml
@@ -48,5 +48,14 @@ else if( fn==5 )
 {
    D = aggregate(target=A, groups=B, weights=C, fn="centralmoment", order="4");
 }
+else if( fn==6 )
+{
+   D = aggregate(target=A, groups=B, weights=C, fn="min");
+}
+else if( fn==7 )
+{
+   D = aggregate(target=A, groups=B, weights=C, fn="max");
+}
+
 
 write(D, $5, format="text");
\ No newline at end of file

Reply via email to