This is an automated email from the ASF dual-hosted git repository.
mboehm7 pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/systemds.git
The following commit(s) were added to refs/heads/main by this push:
new 09493ef [SYSTEMDS-3152] Min/max support in grouped aggregates (CP and Spark)
09493ef is described below
commit 09493efdebdbb03afb798418d2047290c7f5ebfa
Author: Thomas Reichel <[email protected]>
AuthorDate: Thu Mar 3 21:01:35 2022 +0100
[SYSTEMDS-3152] Min/max support in grouped aggregates (CP and Spark)
DIA project WS2021/22
Closes #1507.
Co-authored-by: Thomas Moder <[email protected]>
Co-authored-by: burimvrella <[email protected]>
---
.../ParameterizedBuiltinFunctionExpression.java | 4 +-
.../java/org/apache/sysds/parser/Statement.java | 2 +
.../apache/sysds/runtime/functionobjects/CM.java | 36 ++++++
.../runtime/instructions/InstructionUtils.java | 28 ++--
.../runtime/instructions/cp/CM_COV_Object.java | 23 +++-
.../spark/ParameterizedBuiltinSPInstruction.java | 6 +
.../sysds/runtime/matrix/operators/CMOperator.java | 8 +-
.../aggregate/FullGroupedAggregateMatrixTest.java | 69 ++++++----
.../aggregate/FullGroupedAggregateTest.java | 143 ++++++++++++++++-----
.../scripts/functions/aggregate/GroupedAggregate.R | 10 ++
.../functions/aggregate/GroupedAggregate.dml | 8 ++
.../functions/aggregate/GroupedAggregateMatrix.R | 10 ++
.../functions/aggregate/GroupedAggregateMatrix.dml | 8 ++
.../aggregate/GroupedAggregateMatrixNoDims.dml | 8 ++
.../functions/aggregate/GroupedAggregateWeights.R | 10 ++
.../aggregate/GroupedAggregateWeights.dml | 9 ++
16 files changed, 312 insertions(+), 70 deletions(-)
diff --git
a/src/main/java/org/apache/sysds/parser/ParameterizedBuiltinFunctionExpression.java
b/src/main/java/org/apache/sysds/parser/ParameterizedBuiltinFunctionExpression.java
index 6b6ca9b..a6bcc2d 100644
---
a/src/main/java/org/apache/sysds/parser/ParameterizedBuiltinFunctionExpression.java
+++
b/src/main/java/org/apache/sysds/parser/ParameterizedBuiltinFunctionExpression.java
@@ -790,7 +790,9 @@ public class ParameterizedBuiltinFunctionExpression extends
DataIdentifier
else if (fnameStr.equals(Statement.GAGG_FN_COUNT)
||
fnameStr.equals(Statement.GAGG_FN_SUM)
||
fnameStr.equals(Statement.GAGG_FN_MEAN)
- ||
fnameStr.equals(Statement.GAGG_FN_VARIANCE)){}
+ ||
fnameStr.equals(Statement.GAGG_FN_VARIANCE)
+ ||
fnameStr.equals(Statement.GAGG_FN_MIN)
+ ||
fnameStr.equals(Statement.GAGG_FN_MAX)){}
else {
- raiseValidateError("fname is " + fnameStr + " but must be either centeralmoment, count, sum, mean, variance", conditional);
+ raiseValidateError("fname is " + fnameStr + " but must be either centralmoment, count, sum, mean, variance, min, max", conditional);
}
diff --git a/src/main/java/org/apache/sysds/parser/Statement.java
b/src/main/java/org/apache/sysds/parser/Statement.java
index 4b0237d..995a1e2 100644
--- a/src/main/java/org/apache/sysds/parser/Statement.java
+++ b/src/main/java/org/apache/sysds/parser/Statement.java
@@ -57,6 +57,8 @@ public abstract class Statement implements ParseInfo
public static final String GAGG_FN_MEAN = "mean";
public static final String GAGG_FN_VARIANCE = "variance";
public static final String GAGG_FN_CM = "centralmoment";
+ public static final String GAGG_FN_MIN = "min";
+ public static final String GAGG_FN_MAX = "max";
public static final String GAGG_FN_CM_ORDER = "order";
public static final String GAGG_NUM_GROUPS = "ngroups";
diff --git a/src/main/java/org/apache/sysds/runtime/functionobjects/CM.java
b/src/main/java/org/apache/sysds/runtime/functionobjects/CM.java
index ae3e718..54a2d83 100644
--- a/src/main/java/org/apache/sysds/runtime/functionobjects/CM.java
+++ b/src/main/java/org/apache/sysds/runtime/functionobjects/CM.java
@@ -93,6 +93,8 @@ public class CM extends ValueFunction
if(cm1.isCMAllZeros()) {
cm1.w=1;
cm1.mean.set(in2, 0);
+ cm1.min = in2;
+ cm1.max = in2;
cm1.m2.set(0,0);
cm1.m3.set(0,0);
cm1.m4.set(0,0);
@@ -114,6 +116,16 @@ public class CM extends ValueFunction
cm1.w=w;
break;
}
+ case MIN:
+ {
+ cm1.min = Math.min(cm1.min, in2);
+ break;
+ }
+ case MAX:
+ {
+ cm1.max = Math.max(cm1.max, in2);
+ break;
+ }
case CM2:
{
double w= cm1.w + 1;
@@ -197,6 +209,8 @@ public class CM extends ValueFunction
{
cm1.w=w2;
cm1.mean.set(in2, 0);
+ cm1.min = in2 * w2;
+ cm1.max = in2 * w2;
cm1.m2.set(0,0);
cm1.m3.set(0,0);
cm1.m4.set(0,0);
@@ -210,6 +224,16 @@ public class CM extends ValueFunction
cm1.w = Math.round(cm1.w + w2);
break;
}
+ case MIN:
+ {
+ cm1.min = Math.min(cm1.min, in2 * w2);
+ break;
+ }
+ case MAX:
+ {
+ cm1.max = Math.max(cm1.max, in2 * w2);
+ break;
+ }
case MEAN:
{
double w = cm1.w + w2;
@@ -303,6 +327,8 @@ public class CM extends ValueFunction
{
cm1.w=cm2.w;
cm1.mean.set(cm2.mean);
+ cm1.min = cm2.min;
+ cm1.max = cm2.max;
cm1.m2.set(cm2.m2);
cm1.m3.set(cm2.m3);
cm1.m4.set(cm2.m4);
@@ -318,6 +344,16 @@ public class CM extends ValueFunction
cm1.w = Math.round(cm1.w + cm2.w);
break;
}
+ case MIN:
+ {
+ cm1.min = Math.min(cm1.min, cm2.min);
+ break;
+ }
+ case MAX:
+ {
+ cm1.max = Math.max(cm1.max, cm2.max);
+ break;
+ }
case MEAN:
{
double w = cm1.w + cm2.w;
diff --git
a/src/main/java/org/apache/sysds/runtime/instructions/InstructionUtils.java
b/src/main/java/org/apache/sysds/runtime/instructions/InstructionUtils.java
index f9d9ab3..f22fdfe 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/InstructionUtils.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/InstructionUtils.java
@@ -1000,19 +1000,23 @@ public class InstructionUtils
op = CMOperator.getAggOpType(fn, null);
switch(op) {
- case SUM:
- return new AggregateOperator(0,
KahanPlus.getKahanPlusFnObject(), CorrectionLocationType.LASTCOLUMN);
+ case SUM:
+ return new AggregateOperator(0,
KahanPlus.getKahanPlusFnObject(), CorrectionLocationType.LASTCOLUMN);
+
+ case COUNT:
+ case MEAN:
+ case VARIANCE:
+ case CM2:
+ case CM3:
+ case CM4:
- case COUNT:
- case MEAN:
- case VARIANCE:
- case CM2:
- case CM3:
- case CM4:
- return new CMOperator(CM.getCMFnObject(op), op);
- case INVALID:
- default:
- throw new DMLRuntimeException("Invalid Aggregate
Operation in GroupedAggregateInstruction: " + op);
+ //TODO use appropriate function objects for min/max
(see sum)
+ case MIN:
+ case MAX:
+ return new CMOperator(CM.getCMFnObject(op), op);
+ case INVALID:
+ default:
+ throw new DMLRuntimeException("Invalid
Aggregate Operation in GroupedAggregateInstruction: " + op);
}
}
diff --git
a/src/main/java/org/apache/sysds/runtime/instructions/cp/CM_COV_Object.java
b/src/main/java/org/apache/sysds/runtime/instructions/cp/CM_COV_Object.java
index 39abdcd..8591c3b 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/cp/CM_COV_Object.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/cp/CM_COV_Object.java
@@ -37,13 +37,15 @@ public class CM_COV_Object extends Data
public KahanObject m2;
public KahanObject m3;
public KahanObject m4;
+ public double min;
+ public double max;
public KahanObject mean_v;
public KahanObject c2;
@Override
public String toString() {
- return "weight: "+w+", mean: "+mean+", m2: "+m2+", m3: "+m3+",
m4: "+m4+", mean2: "+mean_v+", c2: "+c2;
+ return "weight: "+w+", mean: "+mean+", m2: "+m2+", m3: "+m3+",
m4: "+m4+", min: "+min+", max: "+max+", mean2: "+mean_v+", c2: "+c2;
}
public CM_COV_Object()
@@ -56,6 +58,8 @@ public class CM_COV_Object extends Data
m4=new KahanObject(0,0);
mean_v=new KahanObject(0,0);
c2=new KahanObject(0,0);
+ min=0;
+ max=0;
}
public void reset()
@@ -67,6 +71,8 @@ public class CM_COV_Object extends Data
m4=new KahanObject(0,0);
mean_v=new KahanObject(0,0);
c2=new KahanObject(0,0);
+ min=0;
+ max=0;
}
public int compareTo(CM_COV_Object that)
@@ -83,6 +89,10 @@ public class CM_COV_Object extends Data
return KahanObject.compare(m4, that.m4);
else if(mean_v!=that.mean_v)
return KahanObject.compare(mean_v, that.mean_v);
+ else if(min!=that.min)
+ return Double.compare(min, that.min);
+ else if(max!=that.max)
+ return Double.compare(max, that.max);
else
return KahanObject.compare(c2, that.c2);
}
@@ -96,7 +106,8 @@ public class CM_COV_Object extends Data
CM_COV_Object that = (CM_COV_Object)o;
return (w==that.w && mean.equals(that.mean) &&
m2.equals(that.m2))
&& m3.equals(that.m3) && m4.equals(that.m4)
- && mean_v.equals(that.mean_v) &&
c2.equals(that.c2);
+ && mean_v.equals(that.mean_v) &&
c2.equals(that.c2)
+ && min==that.min && max == that.max;
}
@Override
@@ -113,11 +124,13 @@ public class CM_COV_Object extends Data
this.m4.set(that.m4);
this.mean_v.set(that.mean_v);
this.c2.set(that.c2);
+ this.min=that.min;
+ this.max=that.max;
}
public boolean isCMAllZeros()
{
- return w==0 && mean.isAllZero() && m2.isAllZero() &&
m3.isAllZero() && m4.isAllZero() ;
+ return w==0 && mean.isAllZero() && m2.isAllZero() &&
m3.isAllZero() && m4.isAllZero() && min==0 && max==0;
}
public boolean isCOVAllZeros()
@@ -166,6 +179,10 @@ public class CM_COV_Object extends Data
return m3._sum/w;
case CM4:
return m4._sum/w;
+ case MIN:
+ return min;
+ case MAX:
+ return max;
case VARIANCE:
return w==1.0? 0:m2._sum/(w-1);
default:
diff --git
a/src/main/java/org/apache/sysds/runtime/instructions/spark/ParameterizedBuiltinSPInstruction.java
b/src/main/java/org/apache/sysds/runtime/instructions/spark/ParameterizedBuiltinSPInstruction.java
index 1478c5f..0494f42 100644
---
a/src/main/java/org/apache/sysds/runtime/instructions/spark/ParameterizedBuiltinSPInstruction.java
+++
b/src/main/java/org/apache/sysds/runtime/instructions/spark/ParameterizedBuiltinSPInstruction.java
@@ -807,6 +807,12 @@ public class ParameterizedBuiltinSPInstruction extends
ComputationSPInstruction
case MEAN:
val = kv.getValue();
break;
+ case MIN:
+ val = kv.getValue();
+ break;
+ case MAX:
+ val = kv.getValue();
+ break;
case CM2:
val = kv.getValue() /
kv.getWeight();
break;
diff --git
a/src/main/java/org/apache/sysds/runtime/matrix/operators/CMOperator.java
b/src/main/java/org/apache/sysds/runtime/matrix/operators/CMOperator.java
index 579e681..7e56e0e 100644
--- a/src/main/java/org/apache/sysds/runtime/matrix/operators/CMOperator.java
+++ b/src/main/java/org/apache/sysds/runtime/matrix/operators/CMOperator.java
@@ -35,6 +35,8 @@ public class CMOperator extends Operator
CM2,
CM3,
CM4,
+ MIN,
+ MAX,
VARIANCE,
INVALID
}
@@ -103,6 +105,10 @@ public class CMOperator extends Operator
return AggregateOperationTypes.CM4;
else
return AggregateOperationTypes.INVALID;
+ } else if (fn.equalsIgnoreCase("min")) {
+ return AggregateOperationTypes.MIN;
+ } else if (fn.equalsIgnoreCase("max")) {
+ return AggregateOperationTypes.MAX;
}
return AggregateOperationTypes.INVALID;
}
@@ -114,7 +120,7 @@ public class CMOperator extends Operator
switch( aggOpType )
{
case COUNT:
- case MEAN:
+ case MEAN:
ret = true; break;
//NOTE: the following aggregation operators are not
marked for partial aggregation
diff --git
a/src/test/java/org/apache/sysds/test/functions/aggregate/FullGroupedAggregateMatrixTest.java
b/src/test/java/org/apache/sysds/test/functions/aggregate/FullGroupedAggregateMatrixTest.java
index dcf439a..10036cf 100644
---
a/src/test/java/org/apache/sysds/test/functions/aggregate/FullGroupedAggregateMatrixTest.java
+++
b/src/test/java/org/apache/sysds/test/functions/aggregate/FullGroupedAggregateMatrixTest.java
@@ -22,7 +22,6 @@ package org.apache.sysds.test.functions.aggregate;
import java.io.IOException;
import java.util.HashMap;
-import org.apache.sysds.api.DMLScript;
import org.apache.sysds.common.Types.ExecMode;
import org.apache.sysds.common.Types.FileFormat;
import org.apache.sysds.common.Types.ValueType;
@@ -38,9 +37,6 @@ import org.junit.AfterClass;
import org.junit.BeforeClass;
import org.junit.Test;
-/**
- *
- */
public class FullGroupedAggregateMatrixTest extends AutomatedTestBase
{
private final static String TEST_NAME1 = "GroupedAggregateMatrix";
@@ -66,9 +62,10 @@ public class FullGroupedAggregateMatrixTest extends
AutomatedTestBase
VARIANCE,
MOMENT3,
MOMENT4,
+ MIN,
+ MAX
}
-
@Override
public void setUp()
{
@@ -158,6 +155,26 @@ public class FullGroupedAggregateMatrixTest extends
AutomatedTestBase
}
@Test
+ public void testGroupedAggMinDenseCP() {
+ runGroupedAggregateOperationTest(TEST_NAME1, OpType.MIN, false,
ExecType.CP);
+ }
+
+ @Test
+ public void testGroupedAggMinSparseCP() {
+ runGroupedAggregateOperationTest(TEST_NAME1, OpType.MIN, true,
ExecType.CP);
+ }
+
+ @Test
+ public void testGroupedAggMaxDenseCP() {
+ runGroupedAggregateOperationTest(TEST_NAME1, OpType.MAX, false,
ExecType.CP);
+ }
+
+ @Test
+ public void testGroupedAggMaxSparseCP() {
+ runGroupedAggregateOperationTest(TEST_NAME1, OpType.MAX, true,
ExecType.CP);
+ }
+
+ @Test
public void testGroupedAggSumDenseWideCP() {
runGroupedAggregateOperationTest(TEST_NAME1, OpType.SUM, false,
ExecType.CP, cols2);
}
@@ -238,6 +255,26 @@ public class FullGroupedAggregateMatrixTest extends
AutomatedTestBase
}
@Test
+ public void testGroupedAggMinDenseSP() {
+ runGroupedAggregateOperationTest(TEST_NAME1, OpType.MIN, false,
ExecType.SPARK);
+ }
+
+ @Test
+ public void testGroupedAggMinSparseSP() {
+ runGroupedAggregateOperationTest(TEST_NAME1, OpType.MIN, true,
ExecType.SPARK);
+ }
+
+ @Test
+ public void testGroupedAggMaxDenseSP() {
+ runGroupedAggregateOperationTest(TEST_NAME1, OpType.MAX, false,
ExecType.SPARK);
+ }
+
+ @Test
+ public void testGroupedAggMaxSparseSP() {
+ runGroupedAggregateOperationTest(TEST_NAME1, OpType.MAX, true,
ExecType.SPARK);
+ }
+
+ @Test
public void testGroupedAggSumDenseWideSP() {
runGroupedAggregateOperationTest(TEST_NAME1, OpType.SUM, false,
ExecType.SPARK, cols2);
}
@@ -254,17 +291,8 @@ public class FullGroupedAggregateMatrixTest extends
AutomatedTestBase
@SuppressWarnings("rawtypes")
private void runGroupedAggregateOperationTest( String testname, OpType
type, boolean sparse, ExecType instType, int numCols)
{
- //rtplatform for MR
- ExecMode platformOld = rtplatform;
- switch( instType ){
- case SPARK: rtplatform = ExecMode.SPARK; break;
- default: rtplatform = ExecMode.HYBRID; break;
- }
-
- boolean sparkConfigOld = DMLScript.USE_LOCAL_SPARK_CONFIG;
- if( rtplatform == ExecMode.SPARK )
- DMLScript.USE_LOCAL_SPARK_CONFIG = true;
-
+ ExecMode platformOld = setExecMode(instType);
+
try
{
//determine script and function name
@@ -313,15 +341,12 @@ public class FullGroupedAggregateMatrixTest extends
AutomatedTestBase
checkDMLMetaDataFile("C", new
MatrixCharacteristics(numGroups,numCols,1,1));
}
}
- catch(IOException ex)
- {
+ catch(IOException ex) {
ex.printStackTrace();
throw new RuntimeException(ex);
}
- finally
- {
- rtplatform = platformOld;
- DMLScript.USE_LOCAL_SPARK_CONFIG = sparkConfigOld;
+ finally {
+ resetExecMode(platformOld);
}
}
}
diff --git
a/src/test/java/org/apache/sysds/test/functions/aggregate/FullGroupedAggregateTest.java
b/src/test/java/org/apache/sysds/test/functions/aggregate/FullGroupedAggregateTest.java
index 9d810f2..ece8615 100644
---
a/src/test/java/org/apache/sysds/test/functions/aggregate/FullGroupedAggregateTest.java
+++
b/src/test/java/org/apache/sysds/test/functions/aggregate/FullGroupedAggregateTest.java
@@ -25,7 +25,6 @@ import java.util.HashMap;
import org.junit.AfterClass;
import org.junit.BeforeClass;
import org.junit.Test;
-import org.apache.sysds.api.DMLScript;
import org.apache.sysds.common.Types.ExecMode;
import org.apache.sysds.common.Types.FileFormat;
import org.apache.sysds.common.Types.ExecType;
@@ -37,9 +36,7 @@ import org.apache.sysds.test.AutomatedTestBase;
import org.apache.sysds.test.TestConfiguration;
import org.apache.sysds.test.TestUtils;
-/**
- *
- */
+
public class FullGroupedAggregateTest extends AutomatedTestBase
{
private final static String TEST_NAME1 = "GroupedAggregate";
@@ -64,6 +61,8 @@ public class FullGroupedAggregateTest extends
AutomatedTestBase
VARIANCE,
MOMENT3,
MOMENT4,
+ MIN,
+ MAX
}
@@ -227,7 +226,55 @@ public class FullGroupedAggregateTest extends
AutomatedTestBase
{
runGroupedAggregateOperationTest(OpType.MOMENT4, true, false,
false, ExecType.SPARK);
}
-
+
+ @Test
+ public void testGroupedAggMinDenseSP()
+ {
+ runGroupedAggregateOperationTest(OpType.MIN, false, false,
false, ExecType.SPARK);
+ }
+
+ @Test
+ public void testGroupedAggMinSparseSP()
+ {
+ runGroupedAggregateOperationTest(OpType.MIN, true, false,
false, ExecType.SPARK);
+ }
+
+ @Test
+ public void testGroupedAggMinDenseWeightsSP()
+ {
+ runGroupedAggregateOperationTest(OpType.MIN, false, true,
false, ExecType.SPARK);
+ }
+
+ @Test
+ public void testGroupedAggMinSparseWeightsSP()
+ {
+ runGroupedAggregateOperationTest(OpType.MIN, true, true, false,
ExecType.SPARK);
+ }
+
+ @Test
+ public void testGroupedAggMaxDenseSP()
+ {
+ runGroupedAggregateOperationTest(OpType.MAX, false, false,
false, ExecType.SPARK);
+ }
+
+ @Test
+ public void testGroupedAggMaxSparseSP()
+ {
+ runGroupedAggregateOperationTest(OpType.MAX, true, false,
false, ExecType.SPARK);
+ }
+
+ @Test
+ public void testGroupedAggMaxDenseWeightsSP()
+ {
+ runGroupedAggregateOperationTest(OpType.MAX, false, true,
false, ExecType.SPARK);
+ }
+
+ @Test
+ public void testGroupedAggMaxSparseWeightsSP()
+ {
+ runGroupedAggregateOperationTest(OpType.MAX, true, true, false,
ExecType.SPARK);
+ }
+
//
-----------------------------------------------------------------------
@Test
@@ -366,13 +413,13 @@ public class FullGroupedAggregateTest extends
AutomatedTestBase
/* TODO weighted central moment in R
@Test
- public void testGroupedAggMoment3DenseWeightsCP()
+ public void testGroupedAggMoment3DenseWeightsCP()
{
runGroupedAggregateOperationTest(OpType.MOMENT3, false, true,
false, ExecType.CP);
}
-
+
@Test
- public void testGroupedAggMoment3SparseWeightsCP()
+ public void testGroupedAggMoment3SparseWeightsCP()
{
runGroupedAggregateOperationTest(OpType.MOMENT3, true, true,
false, ExecType.CP);
}
@@ -389,33 +436,72 @@ public class FullGroupedAggregateTest extends
AutomatedTestBase
{
runGroupedAggregateOperationTest(OpType.MOMENT4, true, false,
false, ExecType.CP);
}
-
+
/* TODO weighted central moment in R
@Test
- public void testGroupedAggMoment4DenseWeightsCP()
+ public void testGroupedAggMoment4DenseWeightsCP()
{
runGroupedAggregateOperationTest(OpType.MOMENT4, false, true,
false, ExecType.CP);
}
-
+
@Test
- public void testGroupedAggMoment4SparseWeightsCP()
+ public void testGroupedAggMoment4SparseWeightsCP()
{
runGroupedAggregateOperationTest(OpType.MOMENT4, true, true,
false, ExecType.CP);
}
*/
-
+
+ @Test
+ public void testGroupedAggMinDenseCP()
+ {
+ runGroupedAggregateOperationTest(OpType.MIN, false, false,
false, ExecType.CP);
+ }
+
+ @Test
+ public void testGroupedAggMinSparseCP()
+ {
+ runGroupedAggregateOperationTest(OpType.MIN, true, false,
false, ExecType.CP);
+ }
+
+ @Test
+ public void testGroupedAggMinDenseWeightsCP()
+ {
+ runGroupedAggregateOperationTest(OpType.MIN, false, true,
false, ExecType.CP);
+ }
+
+ @Test
+ public void testGroupedAggMinSparseWeightsCP()
+ {
+ runGroupedAggregateOperationTest(OpType.MIN, true, true, false,
ExecType.CP);
+ }
+
+ @Test
+ public void testGroupedAggMaxDenseCP()
+ {
+ runGroupedAggregateOperationTest(OpType.MAX, false, false,
false, ExecType.CP);
+ }
+
+ @Test
+ public void testGroupedAggMaxSparseCP()
+ {
+ runGroupedAggregateOperationTest(OpType.MAX, true, false,
false, ExecType.CP);
+ }
+
+ @Test
+ public void testGroupedAggMaxDenseWeightsCP()
+ {
+ runGroupedAggregateOperationTest(OpType.MAX, false, true,
false, ExecType.CP);
+ }
+
+ @Test
+ public void testGroupedAggMaxSparseWeightsCP()
+ {
+ runGroupedAggregateOperationTest(OpType.MAX, true, true, false,
ExecType.CP);
+ }
+
private void runGroupedAggregateOperationTest( OpType type, boolean
sparse, boolean weights, boolean transpose, ExecType instType)
{
- //rtplatform for MR
- ExecMode platformOld = rtplatform;
- switch( instType ){
- case SPARK: rtplatform = ExecMode.SPARK; break;
- default: rtplatform = ExecMode.HYBRID; break;
- }
-
- boolean sparkConfigOld = DMLScript.USE_LOCAL_SPARK_CONFIG;
- if( rtplatform == ExecMode.SPARK )
- DMLScript.USE_LOCAL_SPARK_CONFIG = true;
+ ExecMode platformOld = setExecMode(instType);
try
{
@@ -474,17 +560,12 @@ public class FullGroupedAggregateTest extends
AutomatedTestBase
HashMap<CellIndex, Double> rfile =
readRMatrixFromExpectedDir(weights?"D":"C");
TestUtils.compareMatrices(dmlfile, rfile, eps,
"Stat-DML", "Stat-R");
}
- catch(IOException ex)
- {
+ catch(IOException ex) {
ex.printStackTrace();
throw new RuntimeException(ex);
}
- finally
- {
- rtplatform = platformOld;
- DMLScript.USE_LOCAL_SPARK_CONFIG = sparkConfigOld;
+ finally {
+ resetExecMode(platformOld);
}
}
-
-
}
\ No newline at end of file
diff --git a/src/test/scripts/functions/aggregate/GroupedAggregate.R
b/src/test/scripts/functions/aggregate/GroupedAggregate.R
index 63acceb..324a945 100644
--- a/src/test/scripts/functions/aggregate/GroupedAggregate.R
+++ b/src/test/scripts/functions/aggregate/GroupedAggregate.R
@@ -58,4 +58,14 @@ if( fn==5 )
C = aggregate(as.vector(A), by=list(as.vector(B)), FUN=moment, order=4,
central=TRUE)[,2]
}
+if ( fn==6 )
+{
+ C = aggregate(as.vector(A), by=list(as.vector(B)), FUN=min)[,2]
+}
+
+if ( fn==7 )
+{
+ C = aggregate(as.vector(A), by=list(as.vector(B)), FUN=max)[,2]
+}
+
writeMM(as(C, "CsparseMatrix"), paste(args[3], "C", sep=""));
\ No newline at end of file
diff --git a/src/test/scripts/functions/aggregate/GroupedAggregate.dml
b/src/test/scripts/functions/aggregate/GroupedAggregate.dml
index 6381ffc..132cf59 100644
--- a/src/test/scripts/functions/aggregate/GroupedAggregate.dml
+++ b/src/test/scripts/functions/aggregate/GroupedAggregate.dml
@@ -47,5 +47,13 @@ else if( fn==5 )
{
C = aggregate(target=A, groups=B, fn="centralmoment", order="4");
}
+else if( fn==6 )
+{
+ C = aggregate(target=A, groups=B, fn="min");
+}
+else if( fn==7 )
+{
+ C = aggregate(target=A, groups=B, fn="max");
+}
write(C, $4, format="text");
\ No newline at end of file
diff --git a/src/test/scripts/functions/aggregate/GroupedAggregateMatrix.R
b/src/test/scripts/functions/aggregate/GroupedAggregateMatrix.R
index d67b978..f052e46 100644
--- a/src/test/scripts/functions/aggregate/GroupedAggregateMatrix.R
+++ b/src/test/scripts/functions/aggregate/GroupedAggregateMatrix.R
@@ -64,6 +64,16 @@ if( fn==5 )
C = aggregate(as.vector(Ai), by=list(as.vector(B)), FUN=moment, order=4,
central=TRUE)[,2]
}
+if( fn==6 )
+{
+ C = aggregate(as.vector(Ai), by=list(as.vector(B)), FUN=min)[,2]
+}
+
+if( fn==7 )
+{
+ C = aggregate(as.vector(Ai), by=list(as.vector(B)), FUN=max)[,2]
+}
+
R[,j] = C;
}
diff --git a/src/test/scripts/functions/aggregate/GroupedAggregateMatrix.dml
b/src/test/scripts/functions/aggregate/GroupedAggregateMatrix.dml
index c4e70c8..d4ca660 100644
--- a/src/test/scripts/functions/aggregate/GroupedAggregateMatrix.dml
+++ b/src/test/scripts/functions/aggregate/GroupedAggregateMatrix.dml
@@ -47,5 +47,13 @@ else if( fn==5 )
{
C = aggregate(target=A, groups=B, fn="centralmoment", order="4",
ngroups=$4);
}
+else if( fn==6 )
+{
+ C = aggregate(target=A, groups=B, fn="min", ngroups=$4);
+}
+else if( fn==7 )
+{
+ C = aggregate(target=A, groups=B, fn="max", ngroups=$4);
+}
write(C, $5, format="text");
\ No newline at end of file
diff --git
a/src/test/scripts/functions/aggregate/GroupedAggregateMatrixNoDims.dml
b/src/test/scripts/functions/aggregate/GroupedAggregateMatrixNoDims.dml
index d92366d..1366612 100644
--- a/src/test/scripts/functions/aggregate/GroupedAggregateMatrixNoDims.dml
+++ b/src/test/scripts/functions/aggregate/GroupedAggregateMatrixNoDims.dml
@@ -47,5 +47,13 @@ else if( fn==5 )
{
C = aggregate(target=A, groups=B, fn="centralmoment", order="4");
}
+else if( fn==6 )
+{
+ C = aggregate(target=A, groups=B, fn="min");
+}
+else if( fn==7 )
+{
+ C = aggregate(target=A, groups=B, fn="max");
+}
write(C, $5, format="text");
\ No newline at end of file
diff --git a/src/test/scripts/functions/aggregate/GroupedAggregateWeights.R
b/src/test/scripts/functions/aggregate/GroupedAggregateWeights.R
index eea2f94..1fed6c7 100644
--- a/src/test/scripts/functions/aggregate/GroupedAggregateWeights.R
+++ b/src/test/scripts/functions/aggregate/GroupedAggregateWeights.R
@@ -68,4 +68,14 @@ if( fn==5 )
D = aggregate(as.vector(A*C), by=list(as.vector(B)), FUN=moment, order=4,
central=TRUE)[,2]
}
+if( fn==6 )
+{
+ D = aggregate(as.vector(A*C), by=list(as.vector(B)), FUN=min)[,2]
+}
+
+if( fn==7 )
+{
+ D = aggregate(as.vector(A*C), by=list(as.vector(B)), FUN=max)[,2]
+}
+
writeMM(as(D, "CsparseMatrix"), paste(args[3], "D", sep=""));
\ No newline at end of file
diff --git a/src/test/scripts/functions/aggregate/GroupedAggregateWeights.dml
b/src/test/scripts/functions/aggregate/GroupedAggregateWeights.dml
index 1e1ecf7..1d86907 100644
--- a/src/test/scripts/functions/aggregate/GroupedAggregateWeights.dml
+++ b/src/test/scripts/functions/aggregate/GroupedAggregateWeights.dml
@@ -48,5 +48,14 @@ else if( fn==5 )
{
D = aggregate(target=A, groups=B, weights=C, fn="centralmoment", order="4");
}
+else if( fn==6 )
+{
+ D = aggregate(target=A, groups=B, weights=C, fn="min");
+}
+else if( fn==7 )
+{
+ D = aggregate(target=A, groups=B, weights=C, fn="max");
+}
+
write(D, $5, format="text");
\ No newline at end of file