This is an automated email from the ASF dual-hosted git repository.

mboehm7 pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/systemds.git


The following commit(s) were added to refs/heads/main by this push:
     new 06ad1b5a82 [SYSTEMDS-3258] Builtin for matrix square root (multiple strategies)
06ad1b5a82 is described below

commit 06ad1b5a825dc1fddbeea317e89ac7550366c2d8
Author: trp-ex <[email protected]>
AuthorDate: Thu Jan 16 08:07:19 2025 +0100

    [SYSTEMDS-3258] Builtin for matrix square root (multiple strategies)
    
    Closes #2178.
    
    Co-authored-by: Florian Hoffmann <[email protected]>
    Co-authored-by: Melisa Akbaydar <[email protected]>
---
 scripts/builtin/sqrtMatrix.dml                     | 114 ++++++++++
 .../java/org/apache/sysds/common/Builtins.java     |   2 +
 src/main/java/org/apache/sysds/common/Types.java   |   2 +-
 src/main/java/org/apache/sysds/hops/UnaryOp.java   |   2 +-
 .../sysds/parser/BuiltinFunctionExpression.java    |  13 ++
 .../org/apache/sysds/parser/DMLTranslator.java     |   1 +
 .../runtime/instructions/CPInstructionParser.java  |   1 +
 .../sysds/runtime/matrix/data/LibCommonsMath.java  |  18 +-
 .../builtin/part2/BuiltinSQRTMatrixTest.java       | 235 +++++++++++++++++++++
 src/test/scripts/functions/builtin/SQRTMatrix.dml  |  32 +++
 10 files changed, 416 insertions(+), 4 deletions(-)

diff --git a/scripts/builtin/sqrtMatrix.dml b/scripts/builtin/sqrtMatrix.dml
new file mode 100644
index 0000000000..12e1829ad4
--- /dev/null
+++ b/scripts/builtin/sqrtMatrix.dml
@@ -0,0 +1,114 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+# Computes the matrix square root B of a square matrix A, such that
+# A = B %*% B.
+#
+# INPUT:
+# ------------------------------------------------------------------------------
+# A          Input matrix A (must be square)
+# S          Strategy: "COMMON" uses the Java-based Apache Commons Math
+#            implementation (sqrtMatrixJava); "DML" uses the DML-based
+#            implementation in this script
+# ------------------------------------------------------------------------------
+#
+# OUTPUT:
+# ------------------------------------------------------------------------------
+# B    Output matrix B
+# ------------------------------------------------------------------------------
+
+
+m_sqrtMatrix = function(Matrix[Double] A, String S)
+  return(Matrix[Double] B)
+{
+  if (S == "COMMON") {
+    # Java-based implementation backed by Apache Commons Math
+    B = sqrtMatrixJava(A)
+  } else if (S == "DML") {
+    N = nrow(A);
+    D = ncol(A);
+
+    # check that matrix is square
+    if (D != N){
+      stop("matrixSqrt Input Error: matrix not square!")
+    }
+
+    isDiag = isDiagonal(A)
+    if(isDiag) {
+      # a diagonal matrix has a real square root only if all diagonal entries
+      # are non-negative; otherwise sqrt() would silently produce NaN
+      if(min(diag(A)) < 0) {
+        stop("matrixSqrt exec Error: matrix has imaginary square root");
+      }
+      B = sqrtDiagMatrix(A);
+    } else {
+      [eValues, eVectors] = eigen(A);
+
+      # tolerate tiny negative eigenvalues caused by floating point rounding
+      # (e.g., for singular positive semi-definite inputs), relative to the
+      # largest eigenvalue magnitude
+      eTol = 1e-12 * max(1, max(abs(eValues)));
+      hasNonNegativeEigenValues = (min(eValues) >= -eTol);
+
+      if(!hasNonNegativeEigenValues) {
+        stop("matrixSqrt exec Error: matrix has imaginary square root");
+      }
+
+      isSymmetric = sum(A == t(A)) == length(A);
+      allEigenValuesUnique = length(eValues) == length(unique(eValues));
+
+      if(allEigenValuesUnique | isSymmetric) {
+        # A = V D V^(-1)  ->  sqrt(A) = V sqrt(D) V^(-1);
+        # clamp rounding-induced negative eigenvalues to zero before the root
+        sqrtD = sqrtDiagMatrix(diag(max(eValues, 0)));
+        V_Inv = inv(eVectors);
+        B = eVectors %*% sqrtD %*% V_Inv;
+      } else {
+        # Denman-Beavers iteration: Y converges to sqrt(A), Z to sqrt(A)^(-1)
+        Y = A
+        Z = diag(matrix(1.0, rows=N, cols=1)) # identity matrix
+        maxIter = 100
+        convTol = 1e-12
+        iter = 0
+        converged = FALSE
+        while(!converged & (iter < maxIter)) {
+          # simultaneous update: both new iterates use the previous Y and Z
+          Y_new = 0.5 * (Y + inv(Z))
+          Z_new = 0.5 * (Z + inv(Y))
+          # stop once the relative (squared Frobenius) change is negligible
+          converged = sum((Y_new - Y)^2) <= convTol^2 * sum(Y_new^2)
+          Y = Y_new
+          Z = Z_new
+          iter = iter + 1
+        }
+        B = Y
+      }
+    }
+  } else {
+    stop("Error: Unknown strategy for matrix square root.")
+  }
+}
+
+# Computes the square root of a square, diagonal matrix X by taking the
+# element-wise square root of its diagonal. Assumes X is square and diagonal;
+# negative diagonal entries yield NaN, so callers must validate them first.
+sqrtDiagMatrix = function(Matrix[Double] X)
+  return(Matrix[Double] sqrt_x)
+{
+  # no identity special case needed: sqrt(1) = 1 already reproduces the
+  # identity exactly, so a separate O(n^2) identity check is pure overhead
+  sqrt_x = diag(sqrt(diag(X)));
+}
+
+# Returns TRUE iff the square matrix X equals diag(diag(X)) cell-wise,
+# i.e., all of its off-diagonal cells are zero.
+isDiagonal = function (Matrix[Double] X)
+  return(boolean diagonal)
+{
+  numMismatches = sum(X != diag(diag(X)));
+  diagonal = (numMismatches == 0);
+}
+
diff --git a/src/main/java/org/apache/sysds/common/Builtins.java 
b/src/main/java/org/apache/sysds/common/Builtins.java
index a6331905ac..5429cb287c 100644
--- a/src/main/java/org/apache/sysds/common/Builtins.java
+++ b/src/main/java/org/apache/sysds/common/Builtins.java
@@ -325,6 +325,8 @@ public enum Builtins {
        STEPLM("steplm",true, ReturnType.MULTI_RETURN),
        STFT("stft", false, ReturnType.MULTI_RETURN),
        SQRT("sqrt", false),
+       SQRT_MATRIX("sqrtMatrix", true),
+       SQRT_MATRIX_JAVA("sqrtMatrixJava", false, ReturnType.SINGLE_RETURN),
        SUM("sum", false),
        SVD("svd", false, ReturnType.MULTI_RETURN),
        TABLE("table", "ctable", false),
diff --git a/src/main/java/org/apache/sysds/common/Types.java 
b/src/main/java/org/apache/sysds/common/Types.java
index dd351ae894..21595efd03 100644
--- a/src/main/java/org/apache/sysds/common/Types.java
+++ b/src/main/java/org/apache/sysds/common/Types.java
@@ -542,7 +542,7 @@ public interface Types {
                CUMSUMPROD, DETECTSCHEMA, COLNAMES, EIGEN, EXISTS, EXP, FLOOR, 
INVERSE,
                IQM, ISNA, ISNAN, ISINF, LENGTH, LINEAGE, LOG, NCOL, NOT, NROW,
                MEDIAN, PREFETCH, PRINT, ROUND, SIN, SINH, SIGN, SOFTMAX, SQRT, 
STOP, _EVICT,
-               SVD, TAN, TANH, TYPEOF, TRIGREMOTE,
+               SVD, TAN, TANH, TYPEOF, TRIGREMOTE, SQRT_MATRIX_JAVA,
                //fused ML-specific operators for performance 
                SPROP, //sample proportion: P * (1 - P)
                SIGMOID, //sigmoid function: 1 / (1 + exp(-X))
diff --git a/src/main/java/org/apache/sysds/hops/UnaryOp.java 
b/src/main/java/org/apache/sysds/hops/UnaryOp.java
index 2c0cd4a61b..1bda77530b 100644
--- a/src/main/java/org/apache/sysds/hops/UnaryOp.java
+++ b/src/main/java/org/apache/sysds/hops/UnaryOp.java
@@ -512,7 +512,7 @@ public class UnaryOp extends MultiThreadedHop
                
                //ensure cp exec type for single-node operations
                if( _op == OpOp1.PRINT || _op == OpOp1.ASSERT || _op == 
OpOp1.STOP || _op == OpOp1.TYPEOF
-                       || _op == OpOp1.INVERSE || _op == OpOp1.EIGEN || _op == 
OpOp1.CHOLESKY || _op == OpOp1.SVD
+                       || _op == OpOp1.INVERSE || _op == OpOp1.EIGEN || _op == 
OpOp1.CHOLESKY || _op == OpOp1.SVD || _op == OpOp1.SQRT_MATRIX_JAVA
                        || getInput().get(0).getDataType() == DataType.LIST || 
isMetadataOperation() )
                {
                        _etype = ExecType.CP;
diff --git 
a/src/main/java/org/apache/sysds/parser/BuiltinFunctionExpression.java 
b/src/main/java/org/apache/sysds/parser/BuiltinFunctionExpression.java
index 1de3442dd9..c12e4c4705 100644
--- a/src/main/java/org/apache/sysds/parser/BuiltinFunctionExpression.java
+++ b/src/main/java/org/apache/sysds/parser/BuiltinFunctionExpression.java
@@ -1759,6 +1759,19 @@ public class BuiltinFunctionExpression extends 
DataIdentifier {
                        output.setDimensions(in.getDim1(), in.getDim2());
                        output.setBlocksize(in.getBlocksize());
                        break;
+
+               case SQRT_MATRIX_JAVA:
+
+                       checkNumParameters(1);
+                       checkMatrixParam(getFirstExpr());
+                       output.setDataType(DataType.MATRIX);
+                       output.setValueType(ValueType.FP64);
+                       Identifier sqrt = getFirstExpr().getOutput();
+                       if(sqrt.dimsKnown() && sqrt.getDim1() != sqrt.getDim2())
+                               raiseValidateError("Input to sqrtMatrix() must 
be square matrix -- given: a " + sqrt.getDim1() + "x" + sqrt.getDim2() + " 
matrix.", conditional);
+                       output.setDimensions( sqrt.getDim1(),  sqrt.getDim2());
+                       output.setBlocksize( sqrt.getBlocksize());
+                       break;
                
                case CHOLESKY:
                {
diff --git a/src/main/java/org/apache/sysds/parser/DMLTranslator.java 
b/src/main/java/org/apache/sysds/parser/DMLTranslator.java
index 6121711933..b0673be092 100644
--- a/src/main/java/org/apache/sysds/parser/DMLTranslator.java
+++ b/src/main/java/org/apache/sysds/parser/DMLTranslator.java
@@ -2749,6 +2749,7 @@ public class DMLTranslator
                        break;
 
                case INVERSE:
+               case SQRT_MATRIX_JAVA:
                case CHOLESKY:
                case TYPEOF:
                case DETECTSCHEMA:
diff --git 
a/src/main/java/org/apache/sysds/runtime/instructions/CPInstructionParser.java 
b/src/main/java/org/apache/sysds/runtime/instructions/CPInstructionParser.java
index a0270f6b20..2d19b39f8a 100644
--- 
a/src/main/java/org/apache/sysds/runtime/instructions/CPInstructionParser.java
+++ 
b/src/main/java/org/apache/sysds/runtime/instructions/CPInstructionParser.java
@@ -208,6 +208,7 @@ public class CPInstructionParser extends InstructionParser {
                String2CPInstructionType.put( "ucummax", CPType.Unary);
                String2CPInstructionType.put( "stop"  , CPType.Unary);
                String2CPInstructionType.put( "inverse", CPType.Unary);
+               String2CPInstructionType.put( "sqrt_matrix_java", CPType.Unary);
                String2CPInstructionType.put( "cholesky",CPType.Unary);
                String2CPInstructionType.put( "sprop", CPType.Unary);
                String2CPInstructionType.put( "sigmoid", CPType.Unary);
diff --git 
a/src/main/java/org/apache/sysds/runtime/matrix/data/LibCommonsMath.java 
b/src/main/java/org/apache/sysds/runtime/matrix/data/LibCommonsMath.java
index 61a5f0d784..5365944a3b 100644
--- a/src/main/java/org/apache/sysds/runtime/matrix/data/LibCommonsMath.java
+++ b/src/main/java/org/apache/sysds/runtime/matrix/data/LibCommonsMath.java
@@ -80,7 +80,7 @@ public class LibCommonsMath
        }
        
        public static boolean isSupportedUnaryOperation( String opcode ) {
-               return ( opcode.equals("inverse") || opcode.equals("cholesky") 
);
+               return ( opcode.equals("inverse") || opcode.equals("cholesky") 
|| opcode.equals("sqrt_matrix_java") );
        }
        
        public static boolean isSupportedMultiReturnOperation( String opcode ) {
@@ -111,6 +111,8 @@ public class LibCommonsMath
                        return computeMatrixInverse(matrixInput);
                else if (opcode.equals("cholesky"))
                        return computeCholesky(matrixInput);
+               else if (opcode.equals("sqrt_matrix_java"))
+                       return computeSqrt(inj);
                return null;
        }
 
@@ -512,7 +514,19 @@ public class LibCommonsMath
 
                return new MatrixBlock[] { U, Sigma, V };
        }
-       
+
+       /**
+        * Computes the matrix square root B of the given matrix A (such that A = B %*% B)
+        * via Apache Commons Math {@link EigenDecomposition#getSquareRoot()}.
+        * The Commons Math implementation only supports symmetric positive definite
+        * inputs; rejected inputs are reported as a {@link DMLRuntimeException}.
+        *
+        * @param in Input matrix (expected to be square, symmetric and positive definite)
+        * @return matrix block holding the square root of {@code in}
+        * @throws DMLRuntimeException if Commons Math rejects the input
+        */
+       private static MatrixBlock computeSqrt(MatrixBlock in) {
+               Array2DRowRealMatrix matrixInput = DataConverter.convertToArray2DRowRealMatrix(in);
+               EigenDecomposition ed = new EigenDecomposition(matrixInput);
+               try {
+                       return DataConverter.convertToMatrixBlock(ed.getSquareRoot());
+               }
+               catch(UnsupportedOperationException ex) {
+                       // getSquareRoot throws a message-less MathUnsupportedOperationException
+                       // (a subclass of UnsupportedOperationException) for non-symmetric or
+                       // non-positive-definite inputs; rethrow with context for the user
+                       throw new DMLRuntimeException("sqrtMatrixJava: matrix square root via "
+                               + "EigenDecomposition requires a symmetric positive definite matrix.", ex);
+               }
+       }
+
        /**
         * Function to compute matrix inverse via matrix decomposition.
         * 
diff --git 
a/src/test/java/org/apache/sysds/test/functions/builtin/part2/BuiltinSQRTMatrixTest.java
 
b/src/test/java/org/apache/sysds/test/functions/builtin/part2/BuiltinSQRTMatrixTest.java
new file mode 100644
index 0000000000..a86a6892a8
--- /dev/null
+++ 
b/src/test/java/org/apache/sysds/test/functions/builtin/part2/BuiltinSQRTMatrixTest.java
@@ -0,0 +1,235 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.test.functions.builtin.part2;
+
+import org.apache.sysds.common.Types;
+import org.apache.sysds.common.Types.ExecType;
+import org.apache.sysds.runtime.matrix.data.MatrixValue;
+import org.apache.sysds.test.AutomatedTestBase;
+import org.apache.sysds.test.TestConfiguration;
+import org.apache.sysds.test.TestUtils;
+import org.junit.Test;
+
+import java.util.HashMap;
+
+/**
+ * Tests the sqrtMatrix builtin for both strategies ("COMMON" and "DML"): the test
+ * script squares the computed root (via %*%) and the result is compared against
+ * the original input matrix.
+ */
+public class BuiltinSQRTMatrixTest extends AutomatedTestBase {
+       private final static String TEST_NAME = "SQRTMatrix";
+       private final static String TEST_DIR = "functions/builtin/";
+       private static final String TEST_CLASS_DIR = TEST_DIR + BuiltinSQRTMatrixTest.class.getSimpleName() + "/";
+
+       // tolerance for comparing sqrtMatrix(X) %*% sqrtMatrix(X) against X
+       private final static double eps = 1e-8;
+
+       @Override
+       public void setUp() {
+               addTestConfiguration(TEST_NAME, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME, new String[] {"C"}));
+       }
+
+       // tests for strategy "COMMON"
+       @Test
+       public void testSQRTMatrixJavaSize1x1() {
+               runSQRTMatrix(ExecType.CP, "COMMON", 1);
+       }
+
+       // NOTE: test case 2 is a symmetric singular PSD matrix, not an upper triangular
+       // one; the method name is kept unchanged for stable test ids
+       @Test
+       public void testSQRTMatrixJavaUpperTriangularMatrixSize2x2() {
+               runSQRTMatrix(ExecType.CP, "COMMON", 2);
+       }
+
+       @Test
+       public void testSQRTMatrixJavaDiagonalMatrixSize2x2() {
+               runSQRTMatrix(ExecType.CP, "COMMON", 3);
+       }
+
+       @Test
+       public void testSQRTMatrixJavaPSDMatrixSize2x2() {
+               runSQRTMatrix(ExecType.CP, "COMMON", 4);
+       }
+
+       @Test
+       public void testSQRTMatrixJavaPSDMatrixSize3x3() {
+               runSQRTMatrix(ExecType.CP, "COMMON", 5);
+       }
+
+       @Test
+       public void testSQRTMatrixJavaPSDMatrixSize4x4() {
+               runSQRTMatrix(ExecType.CP, "COMMON", 6);
+       }
+
+       @Test
+       public void testSQRTMatrixJavaPSDMatrixSize8x8() {
+               runSQRTMatrix(ExecType.CP, "COMMON", 7);
+       }
+
+       // tests for strategy "DML"
+       @Test
+       public void testSQRTMatrixDMLSize1x1() {
+               runSQRTMatrix(ExecType.CP, "DML", 1);
+       }
+
+       // NOTE: see the note on testSQRTMatrixJavaUpperTriangularMatrixSize2x2
+       @Test
+       public void testSQRTMatrixDMLUpperTriangularMatrixSize2x2() {
+               runSQRTMatrix(ExecType.CP, "DML", 2);
+       }
+
+       @Test
+       public void testSQRTMatrixDMLDiagonalMatrixSize2x2() {
+               runSQRTMatrix(ExecType.CP, "DML", 3);
+       }
+
+       @Test
+       public void testSQRTMatrixDMLPSDMatrixSize2x2() {
+               runSQRTMatrix(ExecType.CP, "DML", 4);
+       }
+
+       @Test
+       public void testSQRTMatrixDMLPSDMatrixSize3x3() {
+               runSQRTMatrix(ExecType.CP, "DML", 5);
+       }
+
+       @Test
+       public void testSQRTMatrixDMLPSDMatrixSize4x4() {
+               runSQRTMatrix(ExecType.CP, "DML", 6);
+       }
+
+       @Test
+       public void testSQRTMatrixDMLPSDMatrixSize8x8() {
+               runSQRTMatrix(ExecType.CP, "DML", 7);
+       }
+
+       /**
+        * Runs SQRTMatrix.dml, which writes Y = sqrtMatrix(X, strategy) %*% sqrtMatrix(X, strategy),
+        * and checks that Y reproduces the input matrix X.
+        *
+        * @param instType execution type to run the script with
+        * @param strategy matrix square root strategy passed to sqrtMatrix ("COMMON" or "DML")
+        * @param testCase id of the input matrix, see {@link #getInputMatrix(int)}
+        */
+       private void runSQRTMatrix(ExecType instType, String strategy, int testCase) {
+               Types.ExecMode platformOld = setExecMode(instType);
+
+               try {
+                       loadTestConfiguration(getTestConfiguration(TEST_NAME));
+
+                       // find path to associated dml script and define parameters
+                       String HOME = SCRIPT_DIR + TEST_DIR;
+                       fullDMLScriptName = HOME + TEST_NAME + ".dml";
+                       programArgs = new String[] {"-args", input("X"), strategy, output("Y")};
+
+                       double[][] X = getInputMatrix(testCase);
+
+                       // write the input matrix (with metadata) for the dml script
+                       writeInputMatrixWithMTD("X", X, true);
+
+                       // run the test dml script
+                       runTest(true, false, null, -1);
+
+                       // read the result Y = sqrtMatrix(X) %*% sqrtMatrix(X)
+                       HashMap<MatrixValue.CellIndex, Double> actual_Y = readDMLMatrixFromOutputDir("Y");
+
+                       // the expected result is the input matrix X itself (1-based cell indexes)
+                       HashMap<MatrixValue.CellIndex, Double> expected_Y = new HashMap<>();
+                       for(int r = 0; r < X.length; r++) {
+                               for(int c = 0; c < X[r].length; c++) {
+                                       expected_Y.put(new MatrixValue.CellIndex(r + 1, c + 1), X[r][c]);
+                               }
+                       }
+
+                       TestUtils.compareMatrices(expected_Y, actual_Y, eps, "Expected-DML", "Actual-DML");
+               }
+               finally {
+                       resetExecMode(platformOld);
+               }
+       }
+
+       /**
+        * Returns the square input matrix for the given test case.
+        *
+        * @param testCase test case id in [1, 7]
+        * @return input matrix X
+        * @throws IllegalArgumentException for an unknown test case id
+        */
+       private static double[][] getInputMatrix(int testCase) {
+               switch(testCase) {
+                       case 1: // square matrix of dimension 1x1 (PSD)
+                               return new double[][] {
+                                       {4}
+                               };
+                       case 2: // symmetric, singular PSD matrix of dimension 2x2 (eigenvalues 0 and 2)
+                               return new double[][] {
+                                       {1, 1},
+                                       {1, 1}
+                               };
+                       case 3: // diagonal (identity) matrix of dimension 2x2 (PSD)
+                               return new double[][] {
+                                       {1, 0},
+                                       {0, 1}
+                               };
+                       case 4: // PSD matrix of dimension 2x2,
+                               // generated as A %*% t(A) for A = [[1, 0], [2, 3]]
+                               return new double[][] {
+                                       {1, 2},
+                                       {2, 13}
+                               };
+                       case 5: // PSD matrix of dimension 3x3, generated as A %*% t(A) for
+                               // A = [[1.5, 0, 1.2],
+                               //      [2.2, 3.8, 4.4],
+                               //      [4.2, 6.1, 0.2]]
+                               return new double[][] {
+                                       {3.69, 8.58, 6.54},
+                                       {8.58, 38.64, 33.30},
+                                       {6.54, 33.3, 54.89}
+                               };
+                       case 6: // PSD matrix of dimension 4x4, generated as A %*% t(A) for
+                               // A = [[1, 0, 5, 6],
+                               //      [2, 3, 0, 2],
+                               //      [5, 0, 1, 1],
+                               //      [2, 3, 4, 8]]
+                               return new double[][] {
+                                       {62, 14, 16, 70},
+                                       {14, 17, 12, 29},
+                                       {16, 12, 27, 22},
+                                       {70, 29, 22, 93}
+                               };
+                       case 7: // PSD matrix of dimension 8x8, generated as (A^T)A for the symmetric A below;
+                               // since A is symmetric, (A^T)A = A %*% A, i.e., A approximates the square root of X
+                               // [[ 8.41557894,  3.44748042,  1.44911908,  4.95381036,  4.42875187,  4.14710712, -0.42719386,  6.1366026 ],
+                               //  [ 3.44748042, 11.38083039,  4.99475137,  3.36734826,  4.08943809,  4.23308448,  4.50030176,  3.92552912],
+                               //  [ 1.44911908,  4.99475137,  9.78651357,  4.00347878,  4.60244914,  4.24468227,  3.62945751,  6.54033601],
+                               //  [ 4.95381036,  3.36734826,  4.00347878, 12.75936071,  3.78643598,  1.99998784,  5.41689723,  7.9756991 ],
+                               //  [ 4.42875187,  4.08943809,  4.60244914,  3.78643598, 12.49158813,  6.69560056,  3.87176913,  5.5028702 ],
+                               //  [ 4.14710712,  4.23308448,  4.24468227,  1.99998784,  6.69560056,  7.66015758,  4.21792513,  4.53489207],
+                               //  [-0.42719386,  4.50030176,  3.62945751,  5.41689723,  3.87176913,  4.21792513,  9.07079513,  2.64352781],
+                               //  [ 6.1366026 ,  3.92552912,  6.54033601,  7.9756991 ,  5.5028702 ,  4.53489207,  2.64352781,  8.92801728]]
+                               return new double[][] {
+                                       {184, 150, 140, 194, 192, 153,  91, 211},
+                                       {150, 248, 203, 198, 216, 187, 171, 214},
+                                       {140, 203, 234, 212, 223, 185, 165, 237},
+                                       {194, 198, 212, 326, 228, 177, 190, 287},
+                                       {192, 216, 223, 228, 318, 239, 180, 262},
+                                       {153, 187, 185, 177, 239, 199, 152, 209},
+                                       { 91, 171, 165, 190, 180, 152, 185, 170},
+                                       {211, 214, 237, 287, 262, 209, 170, 297}
+                               };
+                       default:
+                               // fail loudly instead of relying on a JVM assert (disabled without -ea)
+                               throw new IllegalArgumentException("Unknown test case: " + testCase);
+               }
+       }
+}
diff --git a/src/test/scripts/functions/builtin/SQRTMatrix.dml 
b/src/test/scripts/functions/builtin/SQRTMatrix.dml
new file mode 100644
index 0000000000..f14d72e6a3
--- /dev/null
+++ b/src/test/scripts/functions/builtin/SQRTMatrix.dml
@@ -0,0 +1,32 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+# 
+#   http://www.apache.org/licenses/LICENSE-2.0
+# 
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+# DML test script for the matrix square root builtin (sqrtMatrix):
+# the root is correct if root %*% root reproduces the input matrix.
+
+X = read($1)
+strategy = $2
+
+# square the computed root again; the Java test compares this to the input
+sqrtX = sqrtMatrix(X, strategy)
+Y = sqrtX %*% sqrtX
+write(Y, $3);
+

Reply via email to