This is an automated email from the ASF dual-hosted git repository.
mboehm7 pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/systemds.git
The following commit(s) were added to refs/heads/main by this push:
new 52ca491315 [SYSTEMDS-3831] New builtin for vectorized simple
exponential smoothing
52ca491315 is described below
commit 52ca4913155401180ccf66c85db8be08e86f9388
Author: Matthias Boehm <[email protected]>
AuthorDate: Thu Feb 6 17:27:32 2025 +0100
[SYSTEMDS-3831] New builtin for vectorized simple exponential smoothing
This patch introduces a new builtin function for vectorized simple
exponential smoothing (SES), which largely relies on cumsumprod.
---
scripts/builtin/ses.dml | 55 +++++++++++++++++
.../java/org/apache/sysds/common/Builtins.java | 1 +
.../runtime/instructions/SPInstructionParser.java | 1 -
.../instructions/cp/CompressionCPInstruction.java | 1 -
.../functions/builtin/part2/BuiltinSESTest.java | 68 ++++++++++++++++++++++
src/test/scripts/functions/builtin/ses.dml | 26 +++++++++
6 files changed, 150 insertions(+), 2 deletions(-)
diff --git a/scripts/builtin/ses.dml b/scripts/builtin/ses.dml
new file mode 100644
index 0000000000..f4b82ad390
--- /dev/null
+++ b/scripts/builtin/ses.dml
@@ -0,0 +1,55 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+# Builtin function for simple exponential smoothing (SES).
+#
+# INPUT:
+# ------------------------------------------------------------------------------
+# x Time series vector [shape: n-by-1]
+# h Forecasting horizon, i.e., number of future steps to forecast
+# alpha Smoothing parameter in [0,1]: yhat_t = alpha * x_t + (1-alpha) * yhat_(t-1)
+# ------------------------------------------------------------------------------
+#
+# OUTPUT:
+# ------------------------------------------------------------------------------
+# yhat Forecasts [shape: h-by-1], all equal to the last smoothed level
+# ------------------------------------------------------------------------------
+
+m_ses = function(Matrix[Double] x, Integer h = 1, Double alpha = 0.5)
+  return (Matrix[Double] yhat)
+{
+  # check and ensure valid parameters (x must be a non-empty column vector)
+  if(ncol(x) != 1 | nrow(x) < 1)
+    stop("SES: input time series x should be a non-empty n-by-1 vector.");
+  if(h < 1) {
+    print("SES: forecasting horizon should be at least one.");
+    h = 1;
+  }
+  if(alpha < 0 | alpha > 1) {
+    print("SES: smoothing parameter should be in [0,1].");
+    alpha = 0.5;
+  }
+  # vectorized forecasting: cumsumprod computes y_i = a_i + w_i * y_(i-1), y_0 = 0,
+  # so a_1 = x_1 initializes the level directly (no 1/alpha, which is Inf for alpha=0)
+  a = alpha * x; a[1,1] = x[1,1];
+  y = cumsumprod(cbind(a, matrix(1-alpha, nrow(x), 1)));
+  yhat = matrix(as.scalar(y[nrow(x),1]), h, 1);
+}
diff --git a/src/main/java/org/apache/sysds/common/Builtins.java
b/src/main/java/org/apache/sysds/common/Builtins.java
index 1a7ba207b8..4ff5654de0 100644
--- a/src/main/java/org/apache/sysds/common/Builtins.java
+++ b/src/main/java/org/apache/sysds/common/Builtins.java
@@ -301,6 +301,7 @@ public enum Builtins {
SD("sd", false),
SELVARTHRESH("selectByVarThresh", true),
SEQ("seq", false),
+ SES("ses", true),
SYMMETRICDIFFERENCE("symmetricDifference", true),
SHAPEXPLAINER("shapExplainer", true),
SHERLOCK("sherlock", true),
diff --git
a/src/main/java/org/apache/sysds/runtime/instructions/SPInstructionParser.java
b/src/main/java/org/apache/sysds/runtime/instructions/SPInstructionParser.java
index 5014c0ac30..e08ef64ab8 100644
---
a/src/main/java/org/apache/sysds/runtime/instructions/SPInstructionParser.java
+++
b/src/main/java/org/apache/sysds/runtime/instructions/SPInstructionParser.java
@@ -39,7 +39,6 @@ import org.apache.sysds.lops.WeightedSquaredLossR;
import org.apache.sysds.lops.WeightedUnaryMM;
import org.apache.sysds.lops.WeightedUnaryMMR;
import org.apache.sysds.runtime.DMLRuntimeException;
-import org.apache.sysds.runtime.instructions.cp.CPInstruction.CPType;
import org.apache.sysds.runtime.instructions.cp.CPOperand;
import
org.apache.sysds.runtime.instructions.spark.AggregateTernarySPInstruction;
import org.apache.sysds.runtime.instructions.spark.AggregateUnarySPInstruction;
diff --git
a/src/main/java/org/apache/sysds/runtime/instructions/cp/CompressionCPInstruction.java
b/src/main/java/org/apache/sysds/runtime/instructions/cp/CompressionCPInstruction.java
index 4216385b72..efc8e21777 100644
---
a/src/main/java/org/apache/sysds/runtime/instructions/cp/CompressionCPInstruction.java
+++
b/src/main/java/org/apache/sysds/runtime/instructions/cp/CompressionCPInstruction.java
@@ -22,7 +22,6 @@ package org.apache.sysds.runtime.instructions.cp;
import java.util.ArrayList;
import java.util.List;
-import org.apache.commons.lang3.NotImplementedException;
import org.apache.commons.lang3.tuple.Pair;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
diff --git
a/src/test/java/org/apache/sysds/test/functions/builtin/part2/BuiltinSESTest.java
b/src/test/java/org/apache/sysds/test/functions/builtin/part2/BuiltinSESTest.java
new file mode 100644
index 0000000000..5573549552
--- /dev/null
+++
b/src/test/java/org/apache/sysds/test/functions/builtin/part2/BuiltinSESTest.java
@@ -0,0 +1,68 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.test.functions.builtin.part2;
+
+import org.apache.sysds.runtime.matrix.data.MatrixValue.CellIndex;
+import org.apache.sysds.test.AutomatedTestBase;
+import org.apache.sysds.test.TestConfiguration;
+import org.junit.Assert;
+import org.junit.Test;
+
+import java.util.HashMap;
+
+public class BuiltinSESTest extends AutomatedTestBase {
+	private static final String TEST_NAME = "ses";
+	private static final String TEST_DIR = "functions/builtin/";
+	private static final String TEST_CLASS_DIR = TEST_DIR + BuiltinSESTest.class.getSimpleName() + "/";
+
+	/** Length of the synthetic input series seq(1, ROWS). */
+	private static final int ROWS = 200;
+	/** Forecasting horizon; must match the h=7 hard-coded in the test script ses.dml. */
+	private static final int HORIZON = 7;
+
+	@Override
+	public void setUp() {
+		addTestConfiguration(TEST_NAME, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME, new String[]{"y"}));
+	}
+
+	@Test
+	public void testSES05() {
+		runSESTest(0.5, 199d);
+	}
+
+	@Test
+	public void testSES077() {
+		runSESTest(0.77, 199.7013);
+	}
+
+	@Test
+	public void testSES10() {
+		runSESTest(1.0, 200d);
+	}
+
+	/**
+	 * Runs the SES test script on the linear series 1..ROWS and checks all forecasts.
+	 * For such a trend, the smoothed level lags the last value by about (1-alpha)/alpha,
+	 * e.g., 199 for alpha=0.5, 199.7013 for alpha=0.77, and 200 for alpha=1.
+	 *
+	 * @param alpha smoothing parameter passed to the script
+	 * @param expected expected value of every one of the HORIZON forecast cells
+	 */
+	private void runSESTest(double alpha, double expected) {
+		loadTestConfiguration(getTestConfiguration(TEST_NAME));
+		String HOME = SCRIPT_DIR + TEST_DIR;
+		fullDMLScriptName = HOME + TEST_NAME + ".dml";
+		programArgs = new String[] {"-args", String.valueOf(ROWS), String.valueOf(alpha), output("y")};
+		runTest(true, EXCEPTION_NOT_EXPECTED, null, -1);
+		HashMap<CellIndex, Double> dmlfile = readDMLMatrixFromOutputDir("y");
+		Assert.assertEquals(HORIZON, dmlfile.size());
+		// SES yields a flat forecast, so every horizon step must carry the same value
+		for(int i = 1; i <= HORIZON; i++) {
+			Double val = dmlfile.get(new CellIndex(i, 1));
+			Assert.assertNotNull("missing forecast at step " + i, val);
+			Assert.assertEquals("forecast at step " + i, expected, val, 1e-3);
+		}
+	}
+}
diff --git a/src/test/scripts/functions/builtin/ses.dml
b/src/test/scripts/functions/builtin/ses.dml
new file mode 100644
index 0000000000..2148854c6e
--- /dev/null
+++ b/src/test/scripts/functions/builtin/ses.dml
@@ -0,0 +1,26 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+
+series = seq(1, $1);  # synthetic linear trend 1, 2, ..., n
+forecast = ses(x=series, h=7, alpha=$2);  # h=7 must match the horizon checked in BuiltinSESTest
+write(forecast, $3);
+