This is an automated email from the ASF dual-hosted git repository.

mboehm7 pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/systemds.git


The following commit(s) were added to refs/heads/main by this push:
     new 52ca491315 [SYSTEMDS-3831] New builtin for vectorized simple 
exponential smoothing
52ca491315 is described below

commit 52ca4913155401180ccf66c85db8be08e86f9388
Author: Matthias Boehm <[email protected]>
AuthorDate: Thu Feb 6 17:27:32 2025 +0100

    [SYSTEMDS-3831] New builtin for vectorized simple exponential smoothing
    
    This patch introduces a new builtin function for vectorized simple
    exponential smoothing (SES), which largely relies on cumsumprod.
---
 scripts/builtin/ses.dml                            | 55 +++++++++++++++++
 .../java/org/apache/sysds/common/Builtins.java     |  1 +
 .../runtime/instructions/SPInstructionParser.java  |  1 -
 .../instructions/cp/CompressionCPInstruction.java  |  1 -
 .../functions/builtin/part2/BuiltinSESTest.java    | 68 ++++++++++++++++++++++
 src/test/scripts/functions/builtin/ses.dml         | 26 +++++++++
 6 files changed, 150 insertions(+), 2 deletions(-)

diff --git a/scripts/builtin/ses.dml b/scripts/builtin/ses.dml
new file mode 100644
index 0000000000..f4b82ad390
--- /dev/null
+++ b/scripts/builtin/ses.dml
@@ -0,0 +1,55 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+# Builtin function for simple exponential smoothing (SES).
+#
+# INPUT:
+# ------------------------------------------------------------------------------
+# x        Time series vector [shape: n-by-1]
+# h        Forecasting horizon, i.e., number of forecasted values (>= 1)
+# alpha    Smoothing parameter in [0,1]: yhat_t = alpha * x_t + (1-alpha) * yhat_{t-1}
+# ------------------------------------------------------------------------------
+#
+# OUTPUT:
+# ------------------------------------------------------------------------------
+# yhat     Forecasts [shape: h-by-1]
+# ------------------------------------------------------------------------------
+
+m_ses = function(Matrix[Double] x, Integer h = 1, Double alpha = 0.5)
+  return (Matrix[Double] yhat)
+{
+  # check and ensure valid parameters (invalid values fall back to defaults)
+  if(h < 1) {
+    print("SES: forecasting horizon should be at least one.");
+    h = 1;
+  }
+  if(alpha < 0 | alpha > 1) {
+    print("SES: smoothing parameter should be in [0,1].");
+    alpha = 0.5;
+  }
+
+  # vectorized forecasting via cumsumprod, i.e., y_t = a_t + w_t * y_{t-1}, with
+  # a = (x_1, alpha*x_2, ...), w = (0, 1-alpha, ...) -> y_1 = x_1 (no 1/alpha, alpha=0 ok)
+  a = alpha * x; a[1,1] = x[1,1];
+  w = rbind(matrix(0, 1, 1), matrix(1-alpha, nrow(x)-1, 1));
+  y = cumsumprod(cbind(a, w));
+  yhat = matrix(as.scalar(y[nrow(x),1]), h, 1);
+}
diff --git a/src/main/java/org/apache/sysds/common/Builtins.java b/src/main/java/org/apache/sysds/common/Builtins.java
index 1a7ba207b8..4ff5654de0 100644
--- a/src/main/java/org/apache/sysds/common/Builtins.java
+++ b/src/main/java/org/apache/sysds/common/Builtins.java
@@ -301,6 +301,7 @@ public enum Builtins {
        SD("sd", false),
        SELVARTHRESH("selectByVarThresh", true),
        SEQ("seq", false),
+       SES("ses", true),
        SYMMETRICDIFFERENCE("symmetricDifference", true),
        SHAPEXPLAINER("shapExplainer", true),
        SHERLOCK("sherlock", true),
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/SPInstructionParser.java b/src/main/java/org/apache/sysds/runtime/instructions/SPInstructionParser.java
index 5014c0ac30..e08ef64ab8 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/SPInstructionParser.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/SPInstructionParser.java
@@ -39,7 +39,6 @@ import org.apache.sysds.lops.WeightedSquaredLossR;
 import org.apache.sysds.lops.WeightedUnaryMM;
 import org.apache.sysds.lops.WeightedUnaryMMR;
 import org.apache.sysds.runtime.DMLRuntimeException;
-import org.apache.sysds.runtime.instructions.cp.CPInstruction.CPType;
 import org.apache.sysds.runtime.instructions.cp.CPOperand;
 import org.apache.sysds.runtime.instructions.spark.AggregateTernarySPInstruction;
 import org.apache.sysds.runtime.instructions.spark.AggregateUnarySPInstruction;
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/cp/CompressionCPInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/cp/CompressionCPInstruction.java
index 4216385b72..efc8e21777 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/cp/CompressionCPInstruction.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/cp/CompressionCPInstruction.java
@@ -22,7 +22,6 @@ package org.apache.sysds.runtime.instructions.cp;
 import java.util.ArrayList;
 import java.util.List;
 
-import org.apache.commons.lang3.NotImplementedException;
 import org.apache.commons.lang3.tuple.Pair;
 import org.apache.commons.logging.Log;
 import org.apache.commons.logging.LogFactory;
diff --git a/src/test/java/org/apache/sysds/test/functions/builtin/part2/BuiltinSESTest.java b/src/test/java/org/apache/sysds/test/functions/builtin/part2/BuiltinSESTest.java
new file mode 100644
index 0000000000..5573549552
--- /dev/null
+++ b/src/test/java/org/apache/sysds/test/functions/builtin/part2/BuiltinSESTest.java
@@ -0,0 +1,68 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.test.functions.builtin.part2;
+
+import org.apache.sysds.runtime.matrix.data.MatrixValue.CellIndex;
+import org.apache.sysds.test.AutomatedTestBase;
+import org.apache.sysds.test.TestConfiguration;
+import org.junit.Assert;
+import org.junit.Test;
+
+import java.util.HashMap;
+
+public class BuiltinSESTest extends AutomatedTestBase {
+       private final static String TEST_NAME = "ses";
+       private final static String TEST_DIR = "functions/builtin/";
+       private final static String TEST_CLASS_DIR = TEST_DIR + BuiltinSESTest.class.getSimpleName() + "/";
+
+       private final static int rows = 200; //length of the linear input series seq(1, rows)
+
+       @Override
+       public void setUp() {
+               addTestConfiguration(TEST_NAME, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME, new String[]{"y"}));
+       }
+
+       @Test
+       public void testSES05() {
+               runSESTest(0.5, 199d);
+       }
+       
+       @Test
+       public void testSES077() {
+               runSESTest(0.77, 199.7013);
+       }
+       
+       @Test
+       public void testSES10() {
+               runSESTest(1.0, 200d);
+       }
+
+       private void runSESTest(double alpha, double expected) { //runs ses.dml and checks all forecasts
+               loadTestConfiguration(getTestConfiguration(TEST_NAME));
+               fullDMLScriptName = SCRIPT_DIR + TEST_DIR + TEST_NAME + ".dml";
+               programArgs = new String[] {"-args", 
+                       String.valueOf(rows), String.valueOf(alpha), output("y")};
+               runTest(true, EXCEPTION_NOT_EXPECTED, null, -1);
+               HashMap<CellIndex, Double> dmlfile = readDMLMatrixFromOutputDir("y");
+               Assert.assertEquals(7, dmlfile.size()); //forecast horizon 7
+               for(int i = 1; i <= 7; i++) //SES forecasts are flat, so every cell must match
+                       Assert.assertEquals(expected, dmlfile.get(new CellIndex(i,1)), 1e-3);
+       }
+}
diff --git a/src/test/scripts/functions/builtin/ses.dml b/src/test/scripts/functions/builtin/ses.dml
new file mode 100644
index 0000000000..2148854c6e
--- /dev/null
+++ b/src/test/scripts/functions/builtin/ses.dml
@@ -0,0 +1,26 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+# forecast the linear series seq(1, $1) via ses with alpha=$2 and horizon 7; write to $3
+X = seq(1, $1)
+Y = ses(x=X, h=7, alpha=$2)
+write(Y, $3)
+

Reply via email to