This is an automated email from the ASF dual-hosted git repository.

baunsgaard pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/systemds.git

commit 35b8e03cbb62d9ea26f0417abfd100dbdef2e002
Author: Mufan Wang <[email protected]>
AuthorDate: Thu Apr 4 17:18:09 2024 +0200

    [SYSTEMDS-3686] STFT
    
    This commit adds a short-time Fourier transform (STFT) to the system.
    It applies fast Fourier transforms to windows of configurable
    stride and width, enabling applications such as sound classification.
    
    LDE 23/24 project
    
    Co-authored-by: Mufan Wang <[email protected]>
    Co-authored-by: Frederic Caspar Zoepffel <[email protected]>
    Co-authored-by: Jessica Eva Sophie Priebe <[email protected]>
    
    Closes #2000
---
 .../java/org/apache/sysds/common/Builtins.java     |   1 +
 .../java/org/apache/sysds/hops/FunctionOp.java     |   9 +
 .../sysds/parser/BuiltinFunctionExpression.java    |  66 ++++++++
 .../org/apache/sysds/parser/DMLTranslator.java     |   1 +
 .../runtime/instructions/CPInstructionParser.java  |   1 +
 .../instructions/cp/ComputationCPInstruction.java  |  18 +-
 .../cp/MultiReturnBuiltinCPInstruction.java        |   8 +
 ...ltiReturnComplexMatrixBuiltinCPInstruction.java |  65 +++++++-
 .../sysds/runtime/matrix/data/LibCommonsMath.java  |  53 ++++++
 .../sysds/runtime/matrix/data/LibMatrixSTFT.java   | 121 ++++++++++++++
 .../test/component/matrix/EigenDecompTest.java     |   3 +
 .../sysds/test/component/matrix/STFTTest.java      | 182 +++++++++++++++++++++
 12 files changed, 525 insertions(+), 3 deletions(-)

diff --git a/src/main/java/org/apache/sysds/common/Builtins.java 
b/src/main/java/org/apache/sysds/common/Builtins.java
index 7e83984e47..8f113c092f 100644
--- a/src/main/java/org/apache/sysds/common/Builtins.java
+++ b/src/main/java/org/apache/sysds/common/Builtins.java
@@ -310,6 +310,7 @@ public enum Builtins {
        STATSNA("statsNA", true),
        STRATSTATS("stratstats", true),
        STEPLM("steplm",true, ReturnType.MULTI_RETURN),
+       STFT("stft", false, ReturnType.MULTI_RETURN),
        SQRT("sqrt", false),
        SUM("sum", false),
        SVD("svd", false, ReturnType.MULTI_RETURN),
diff --git a/src/main/java/org/apache/sysds/hops/FunctionOp.java 
b/src/main/java/org/apache/sysds/hops/FunctionOp.java
index ffc12c30ee..95b5411500 100644
--- a/src/main/java/org/apache/sysds/hops/FunctionOp.java
+++ b/src/main/java/org/apache/sysds/hops/FunctionOp.java
@@ -221,6 +221,11 @@ public class FunctionOp extends Hop
                                long outputIm = 
OptimizerUtils.estimateSizeExactSparsity(getOutputs().get(1).getDim1(), 
getOutputs().get(1).getDim2(), 1.0);
                                return outputRe+outputIm;
                        }
+                       else if ( getFunctionName().equalsIgnoreCase("stft") ) {
+                               long outputRe = 
OptimizerUtils.estimateSizeExactSparsity(getOutputs().get(0).getDim1(), 
getOutputs().get(0).getDim2(), 1.0);
+                               long outputIm = 
OptimizerUtils.estimateSizeExactSparsity(getOutputs().get(1).getDim1(), 
getOutputs().get(1).getDim2(), 1.0);
+                               return outputRe+outputIm;
+                       }
                        else if ( getFunctionName().equalsIgnoreCase("lstm") || 
getFunctionName().equalsIgnoreCase("lstm_backward") ) {
                                // TODO: To allow for initial version to always 
run on the GPU
                                return 0; 
@@ -286,6 +291,10 @@ public class FunctionOp extends Hop
                                // 2 matrices of size same as the input
                                return 
2*OptimizerUtils.estimateSizeExactSparsity(getInput().get(0).getDim1(), 
getInput().get(0).getDim2(), 1.0);
                        }
+                       else if ( getFunctionName().equalsIgnoreCase("stft") ) {
+                               // 2 matrices of size same as the input
+                               return 
2*OptimizerUtils.estimateSizeExactSparsity(getInput().get(0).getDim1(), 
getInput().get(0).getDim2(), 1.0);
+                       }
                        else if 
(getFunctionName().equalsIgnoreCase("batch_norm2d") || 
getFunctionName().equalsIgnoreCase("batch_norm2d_backward") ||
                                        
getFunctionName().equalsIgnoreCase("batch_norm2d_train") || 
getFunctionName().equalsIgnoreCase("batch_norm2d_test")) {
                                return 0; 
diff --git 
a/src/main/java/org/apache/sysds/parser/BuiltinFunctionExpression.java 
b/src/main/java/org/apache/sysds/parser/BuiltinFunctionExpression.java
index 4b3c8e82f7..c3f1026627 100644
--- a/src/main/java/org/apache/sysds/parser/BuiltinFunctionExpression.java
+++ b/src/main/java/org/apache/sysds/parser/BuiltinFunctionExpression.java
@@ -589,6 +589,72 @@ public class BuiltinFunctionExpression extends 
DataIdentifier {
 
                        break;
                }
+               case STFT: {
+                       checkMatrixParam(getFirstExpr());
+
+                       if((getFirstExpr() == null || getSecondExpr() == null 
|| getThirdExpr() == null) && _args.length > 0) {
+                               raiseValidateError("Missing argument for 
function " + this.getOpCode(), false,
+                                       LanguageErrorCodes.INVALID_PARAMETERS);
+                       }
+                       else if(getFifthExpr() != null) {
+                               raiseValidateError("Invalid number of arguments 
for function " + this.getOpCode().toString().toLowerCase()
+                                       + "(). This function only takes 3 or 4 
arguments.", false);
+                       }
+                       else if(_args.length == 3) {
+                               checkScalarParam(getSecondExpr());
+                               checkScalarParam(getThirdExpr());
+                               if(!isPowerOfTwo(((ConstIdentifier) 
getSecondExpr().getOutput()).getLongValue())) {
+                                       raiseValidateError(
+                                               "This FFT implementation is 
only defined for matrices with dimensions that are powers of 2."
+                                                       + "The window size (2nd 
argument) is not a power of two",
+                                               false, 
LanguageErrorCodes.INVALID_PARAMETERS);
+                               }
+                               else if(((ConstIdentifier) 
getSecondExpr().getOutput())
+                                       .getLongValue() <= ((ConstIdentifier) 
getThirdExpr().getOutput()).getLongValue()) {
+                                       raiseValidateError("Overlap can't be 
larger than or equal to the window size.", false,
+                                               
LanguageErrorCodes.INVALID_PARAMETERS);
+                               }
+                       }
+                       else if(_args.length == 4) {
+                               checkMatrixParam(getSecondExpr());
+                               checkScalarParam(getThirdExpr());
+                               checkScalarParam(getFourthExpr());
+                               if(!isPowerOfTwo(((ConstIdentifier) 
getThirdExpr().getOutput()).getLongValue())) {
+                                       raiseValidateError(
+                                               "This FFT implementation is 
only defined for matrices with dimensions that are powers of 2."
+                                                       + "The window size (3rd 
argument) is not a power of two",
+                                               false, 
LanguageErrorCodes.INVALID_PARAMETERS);
+                               }
+                               else if(getFirstExpr().getOutput().getDim1() != 
getSecondExpr().getOutput().getDim1() ||
+                                       getFirstExpr().getOutput().getDim2() != 
getSecondExpr().getOutput().getDim2()) {
+                                       raiseValidateError("The real and 
imaginary part of the provided matrix are of different dimensions.",
+                                               false);
+                               }
+                               else if(((ConstIdentifier) 
getThirdExpr().getOutput())
+                                       .getLongValue() <= ((ConstIdentifier) 
getFourthExpr().getOutput()).getLongValue()) {
+                                       raiseValidateError("Overlap can't be 
larger than or equal to the window size.", false,
+                                               
LanguageErrorCodes.INVALID_PARAMETERS);
+                               }
+                       }
+
+                       // setup output properties
+                       DataIdentifier stftOut1 = (DataIdentifier) 
getOutputs()[0];
+                       DataIdentifier stftOut2 = (DataIdentifier) 
getOutputs()[1];
+
+                       // Output1 - stft Values
+                       stftOut1.setDataType(DataType.MATRIX);
+                       stftOut1.setValueType(ValueType.FP64);
+                       
stftOut1.setDimensions(getFirstExpr().getOutput().getDim1(), 
getFirstExpr().getOutput().getDim2());
+                       
stftOut1.setBlocksize(getFirstExpr().getOutput().getBlocksize());
+
+                       // Output2 - stft Vectors
+                       stftOut2.setDataType(DataType.MATRIX);
+                       stftOut2.setValueType(ValueType.FP64);
+                       
stftOut2.setDimensions(getFirstExpr().getOutput().getDim1(), 
getFirstExpr().getOutput().getDim2());
+                       
stftOut2.setBlocksize(getFirstExpr().getOutput().getBlocksize());
+
+                       break;
+               }
                case REMOVE: {
                        checkNumParameters(2);
                        checkListParam(getFirstExpr());
diff --git a/src/main/java/org/apache/sysds/parser/DMLTranslator.java 
b/src/main/java/org/apache/sysds/parser/DMLTranslator.java
index 97d5523961..2b876c12be 100644
--- a/src/main/java/org/apache/sysds/parser/DMLTranslator.java
+++ b/src/main/java/org/apache/sysds/parser/DMLTranslator.java
@@ -2240,6 +2240,7 @@ public class DMLTranslator
                        case IFFT:
                        case FFT_LINEARIZED:
                        case IFFT_LINEARIZED:
+                       case STFT:
                        case LSTM:
                        case LSTM_BACKWARD:
                        case BATCH_NORM2D:
diff --git 
a/src/main/java/org/apache/sysds/runtime/instructions/CPInstructionParser.java 
b/src/main/java/org/apache/sysds/runtime/instructions/CPInstructionParser.java
index b78c1e2b49..994c1cd51a 100644
--- 
a/src/main/java/org/apache/sysds/runtime/instructions/CPInstructionParser.java
+++ 
b/src/main/java/org/apache/sysds/runtime/instructions/CPInstructionParser.java
@@ -338,6 +338,7 @@ public class CPInstructionParser extends InstructionParser {
                String2CPInstructionType.put( "ifft",  
CPType.MultiReturnComplexMatrixBuiltin);
                String2CPInstructionType.put( "fft_linearized", 
CPType.MultiReturnBuiltin);
                String2CPInstructionType.put( "ifft_linearized", 
CPType.MultiReturnComplexMatrixBuiltin);
+               String2CPInstructionType.put( "stft", 
CPType.MultiReturnComplexMatrixBuiltin);
                String2CPInstructionType.put( "svd",   
CPType.MultiReturnBuiltin);
 
                String2CPInstructionType.put( "partition", CPType.Partition);
diff --git 
a/src/main/java/org/apache/sysds/runtime/instructions/cp/ComputationCPInstruction.java
 
b/src/main/java/org/apache/sysds/runtime/instructions/cp/ComputationCPInstruction.java
index de45036991..cc32f624d0 100644
--- 
a/src/main/java/org/apache/sysds/runtime/instructions/cp/ComputationCPInstruction.java
+++ 
b/src/main/java/org/apache/sysds/runtime/instructions/cp/ComputationCPInstruction.java
@@ -35,7 +35,7 @@ import org.apache.sysds.runtime.matrix.operators.Operator;
 public abstract class ComputationCPInstruction extends CPInstruction 
implements LineageTraceable {
 
        public final CPOperand output;
-       public final CPOperand input1, input2, input3;
+       public final CPOperand input1, input2, input3, input4;
 
        protected ComputationCPInstruction(CPType type, Operator op, CPOperand 
in1, CPOperand in2, CPOperand out, String opcode,
                        String istr) {
@@ -43,6 +43,7 @@ public abstract class ComputationCPInstruction extends 
CPInstruction implements
                input1 = in1;
                input2 = in2;
                input3 = null;
+               input4 = null;
                output = out;
        }
 
@@ -52,6 +53,17 @@ public abstract class ComputationCPInstruction extends 
CPInstruction implements
                input1 = in1;
                input2 = in2;
                input3 = in3;
+               input4 = null;
+               output = out;
+       }
+
+       protected ComputationCPInstruction(CPType type, Operator op, CPOperand 
in1, CPOperand in2, CPOperand in3,
+               CPOperand in4, CPOperand out, String opcode, String istr) {
+               super(type, op, opcode, istr);
+               input1 = in1;
+               input2 = in2;
+               input3 = in3;
+               input4 = in4;
                output = out;
        }
 
@@ -64,7 +76,7 @@ public abstract class ComputationCPInstruction extends 
CPInstruction implements
        }
 
        public CPOperand[] getInputs(){
-               return new CPOperand[]{input1, input2, input3};
+               return new CPOperand[]{input1, input2, input3, input4};
        }
 
        public boolean hasFrameInput() {
@@ -74,6 +86,8 @@ public abstract class ComputationCPInstruction extends 
CPInstruction implements
                        return true;
                if (input3 != null && input3.isFrame())
                        return true;
+               if (input4 != null && input4.isFrame())
+                       return true;
                return false;
        }
 
diff --git 
a/src/main/java/org/apache/sysds/runtime/instructions/cp/MultiReturnBuiltinCPInstruction.java
 
b/src/main/java/org/apache/sysds/runtime/instructions/cp/MultiReturnBuiltinCPInstruction.java
index 4719948ec0..a65b56c8ac 100644
--- 
a/src/main/java/org/apache/sysds/runtime/instructions/cp/MultiReturnBuiltinCPInstruction.java
+++ 
b/src/main/java/org/apache/sysds/runtime/instructions/cp/MultiReturnBuiltinCPInstruction.java
@@ -124,6 +124,14 @@ public class MultiReturnBuiltinCPInstruction extends 
ComputationCPInstruction {
                        return new MultiReturnBuiltinCPInstruction(null, null, 
outputs, opcode, str);
 
                }
+               else if ( opcode.equalsIgnoreCase("stft") ) {
+                       // one input and two outputs
+                       CPOperand in1 = new CPOperand(parts[1]);
+                       outputs.add ( new CPOperand(parts[2], ValueType.FP64, 
DataType.MATRIX) );
+                       outputs.add ( new CPOperand(parts[3], ValueType.FP64, 
DataType.MATRIX) );
+
+                       return new MultiReturnBuiltinCPInstruction(null, in1, 
outputs, opcode, str);
+               }
                else if ( opcode.equalsIgnoreCase("svd") ) {
                        CPOperand in1 = new CPOperand(parts[1]);
 
diff --git 
a/src/main/java/org/apache/sysds/runtime/instructions/cp/MultiReturnComplexMatrixBuiltinCPInstruction.java
 
b/src/main/java/org/apache/sysds/runtime/instructions/cp/MultiReturnComplexMatrixBuiltinCPInstruction.java
index f9c92d4858..f75b231052 100644
--- 
a/src/main/java/org/apache/sysds/runtime/instructions/cp/MultiReturnComplexMatrixBuiltinCPInstruction.java
+++ 
b/src/main/java/org/apache/sysds/runtime/instructions/cp/MultiReturnComplexMatrixBuiltinCPInstruction.java
@@ -50,6 +50,12 @@ public class MultiReturnComplexMatrixBuiltinCPInstruction 
extends ComputationCPI
                _outputs = outputs;
        }
 
+       private MultiReturnComplexMatrixBuiltinCPInstruction(Operator op, 
CPOperand input1, CPOperand input2,
+               CPOperand input3, CPOperand input4, ArrayList<CPOperand> 
outputs, String opcode, String istr) {
+               super(CPType.MultiReturnBuiltin, op, input1, input2, input3, 
input4, outputs.get(0), opcode, istr);
+               _outputs = outputs;
+       }
+
        public CPOperand getOutput(int i) {
                return _outputs.get(i);
        }
@@ -102,6 +108,27 @@ public class MultiReturnComplexMatrixBuiltinCPInstruction 
extends ComputationCPI
 
                        return new 
MultiReturnComplexMatrixBuiltinCPInstruction(null, in1, outputs, opcode, str);
                }
+               else if(parts.length == 6 && opcode.equalsIgnoreCase("stft")) {
+                       CPOperand in1 = new CPOperand(parts[1]);
+                       CPOperand windowSize = new CPOperand(parts[2]);
+                       CPOperand overlap = new CPOperand(parts[3]);
+                       outputs.add(new CPOperand(parts[4], ValueType.FP64, 
DataType.MATRIX));
+                       outputs.add(new CPOperand(parts[5], ValueType.FP64, 
DataType.MATRIX));
+
+                       return new 
MultiReturnComplexMatrixBuiltinCPInstruction(null, in1, null, windowSize, 
overlap, outputs, opcode,
+                               str);
+               }
+               else if(parts.length == 7 && opcode.equalsIgnoreCase("stft")) {
+                       CPOperand in1 = new CPOperand(parts[1]);
+                       CPOperand in2 = new CPOperand(parts[2]);
+                       CPOperand windowSize = new CPOperand(parts[3]);
+                       CPOperand overlap = new CPOperand(parts[4]);
+                       outputs.add(new CPOperand(parts[5], ValueType.FP64, 
DataType.MATRIX));
+                       outputs.add(new CPOperand(parts[6], ValueType.FP64, 
DataType.MATRIX));
+
+                       return new 
MultiReturnComplexMatrixBuiltinCPInstruction(null, in1, in2, windowSize, 
overlap, outputs, opcode,
+                               str);
+               }
                else {
                        throw new DMLRuntimeException("Invalid opcode in 
MultiReturnBuiltin instruction: " + opcode);
                }
@@ -114,7 +141,11 @@ public class MultiReturnComplexMatrixBuiltinCPInstruction 
extends ComputationCPI
 
        @Override
        public void processInstruction(ExecutionContext ec) {
-               if(input2 == null)
+               if(getOpcode().equals("stft") && input2 == null)
+                       processSTFTInstruction(ec);
+               else if(getOpcode().equals("stft"))
+                       processSTFTTwoInstruction(ec);
+               else if(input2 == null)
                        processOneInputInstruction(ec);
                else
                        processTwoInputInstruction(ec);
@@ -148,6 +179,38 @@ public class MultiReturnComplexMatrixBuiltinCPInstruction 
extends ComputationCPI
                }
        }
 
+       private void processSTFTInstruction(ExecutionContext ec) {
+               if(!LibCommonsMath.isSupportedMultiReturnOperation(getOpcode()))
+                       throw new DMLRuntimeException("Invalid opcode in 
MultiReturnBuiltin instruction: " + getOpcode());
+
+               MatrixBlock in1 = ec.getMatrixInput(input1.getName());
+               // MatrixBlock in2 = ec.getMatrixInput(input2.getName());
+               int windowSize = Integer.parseInt(input3.getName());
+               int overlap = Integer.parseInt(input4.getName());
+               MatrixBlock[] out = LibCommonsMath.multiReturnOperations(in1, 
getOpcode(), windowSize, overlap);
+               ec.releaseMatrixInput(input1.getName());
+
+               for(int i = 0; i < _outputs.size(); i++) {
+                       ec.setMatrixOutput(_outputs.get(i).getName(), out[i]);
+               }
+       }
+
+       private void processSTFTTwoInstruction(ExecutionContext ec) {
+               if(!LibCommonsMath.isSupportedMultiReturnOperation(getOpcode()))
+                       throw new DMLRuntimeException("Invalid opcode in 
MultiReturnBuiltin instruction: " + getOpcode());
+
+               MatrixBlock in1 = ec.getMatrixInput(input1.getName());
+               MatrixBlock in2 = ec.getMatrixInput(input2.getName());
+               int windowSize = Integer.parseInt(input3.getName());
+               int overlap = Integer.parseInt(input4.getName());
+               MatrixBlock[] out = LibCommonsMath.multiReturnOperations(in1, 
in2, getOpcode(), windowSize, overlap);
+               ec.releaseMatrixInput(input1.getName(), input2.getName());
+
+               for(int i = 0; i < _outputs.size(); i++) {
+                       ec.setMatrixOutput(_outputs.get(i).getName(), out[i]);
+               }
+       }
+
        @Override
        public boolean hasSingleLineage() {
                return false;
diff --git 
a/src/main/java/org/apache/sysds/runtime/matrix/data/LibCommonsMath.java 
b/src/main/java/org/apache/sysds/runtime/matrix/data/LibCommonsMath.java
index dc42b15a96..b7fe116707 100644
--- a/src/main/java/org/apache/sysds/runtime/matrix/data/LibCommonsMath.java
+++ b/src/main/java/org/apache/sysds/runtime/matrix/data/LibCommonsMath.java
@@ -53,6 +53,7 @@ import static 
org.apache.sysds.runtime.matrix.data.LibMatrixFourier.fft;
 import static org.apache.sysds.runtime.matrix.data.LibMatrixFourier.ifft;
 import static 
org.apache.sysds.runtime.matrix.data.LibMatrixFourier.fft_linearized;
 import static 
org.apache.sysds.runtime.matrix.data.LibMatrixFourier.ifft_linearized;
+import static org.apache.sysds.runtime.matrix.data.LibMatrixSTFT.stft;
 
 /**
  * Library for matrix operations that need invocation of 
@@ -85,6 +86,7 @@ public class LibCommonsMath
                        case "ifft":
                        case "fft_linearized":
                        case "ifft_linearized":
+                       case "stft":
                        case "svd": return true;
                        default: return false;
                }
@@ -415,6 +417,57 @@ public class LibCommonsMath
                return computeIFFT_LINEARIZED(re, null, threads);
        }
 
+
+       /**
+        * Function to perform STFT on a given matrix.
+        *
+        * @param re matrix object
+        * @param im matrix object
+        * @param windowSize of stft
+        * @param overlap of stft
+        * @return array of matrix blocks
+        */
+       private static MatrixBlock[] computeSTFT(MatrixBlock re, MatrixBlock 
im, int windowSize, int overlap, int threads) {
+               if (re == null) {
+                       throw new DMLRuntimeException("Invalid empty block");
+               } else if (im != null && !im.isEmptyBlock(false)) {
+                       re.sparseToDense();
+                       im.sparseToDense();
+                       return stft(re, im, windowSize, overlap, threads);
+               } else {
+                       if (re.isEmptyBlock(false)) {
+                               // Return the original matrix as the result
+                               int rows = re.getNumRows();
+                               int cols = re.getNumColumns();
+
+                               int stepSize = windowSize - overlap;
+                               if (stepSize == 0) {
+                                       throw new 
IllegalArgumentException("windowSize - overlap is zero");
+                               }
+
+                               int numberOfFramesPerRow = (cols - overlap + 
stepSize - 1) / stepSize;
+                               int rowLength= numberOfFramesPerRow * 
windowSize;
+                               int out_len = rowLength * rows;
+
+                               double[] out_zero = new double[out_len];
+
+                               return new MatrixBlock[]{new MatrixBlock(rows, 
rowLength, out_zero), new MatrixBlock(rows, rowLength, out_zero)};
+                               }
+                       re.sparseToDense();
+                       return stft(re, windowSize, overlap, threads);
+               }
+       }
+
+       /**
+        * Function to perform STFT on a given matrix.
+        *
+        * @param re matrix object
+        * @return array of matrix blocks
+        */
+       private static MatrixBlock[] computeSTFT(MatrixBlock re, int 
windowSize, int overlap, int threads) {
+               return computeSTFT(re, null, windowSize, overlap, threads);
+       }
+
        /**
         * Performs Singular Value Decomposition. Calls Apache Commons Math SVD.
         * X = U * Sigma * Vt, where X is the input matrix,
diff --git 
a/src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixSTFT.java 
b/src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixSTFT.java
new file mode 100644
index 0000000000..9251d4cdcd
--- /dev/null
+++ b/src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixSTFT.java
@@ -0,0 +1,121 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.runtime.matrix.data;
+
+import static 
org.apache.sysds.runtime.matrix.data.LibMatrixFourier.fft_one_dim;
+
+import org.apache.sysds.runtime.util.CommonThreadPool;
+import java.util.ArrayList;
+import java.util.concurrent.Future;
+import java.util.concurrent.ExecutorService;
+import java.util.concurrent.ExecutionException;
+import java.util.List;
+
+/**
+ * Library file containing methods to perform short-time Fourier transformations.
+ */
+public class LibMatrixSTFT {
+
+       /**
+        * Function to perform STFT on two given matrices with windowSize and 
overlap. The first one represents the real
+        * values and the second one the imaginary values. The output also 
contains one matrix for the real and one for the
+        * imaginary values. The results of the fourier transformations are 
appended to each other in the output.
+        *
+        * @param re         Matrix object representing the real values
+        * @param im         Matrix object representing the imaginary values
+        * @param windowSize Size of window
+        * @param overlap    Size of overlap
+        * @param threads    The number of threads to use
+        * @return array of two matrix blocks
+        */
+       public static MatrixBlock[] stft(MatrixBlock re, MatrixBlock im, int 
windowSize, int overlap, int threads) {
+
+               int rows = re.getNumRows();
+               int cols = re.getNumColumns();
+
+               int stepSize = windowSize - overlap;
+               if(stepSize == 0) {
+                       throw new IllegalArgumentException("windowSize - 
overlap is zero");
+               }
+
+               int numberOfFramesPerRow = (cols - overlap + stepSize - 1) / 
stepSize;
+               int rowLength = numberOfFramesPerRow * windowSize;
+               int out_len = rowLength * rows;
+
+               double[] stftOutput_re = new double[out_len];
+               double[] stftOutput_im = new double[out_len];
+
+               double[] re_inter = new double[out_len];
+               double[] im_inter = new double[out_len];
+
+               final ExecutorService pool = CommonThreadPool.get(threads);
+
+               final List<Future<?>> tasks = new ArrayList<>();
+
+               try {
+                       for(int h = 0; h < rows; h++) {
+                               final int finalH = h;
+                               tasks.add(pool.submit(() -> {
+                                       for(int i = 0; i < 
numberOfFramesPerRow; i++) {
+                                               for(int j = 0; j < windowSize; 
j++) {
+                                                       if((i * stepSize + j) < 
cols) {
+                                                               
stftOutput_re[finalH * rowLength + i * windowSize + j] = re
+                                                                       
.getDenseBlockValues()[finalH * cols + i * stepSize + j];
+                                                               
stftOutput_im[finalH * rowLength + i * windowSize + j] = im
+                                                                       
.getDenseBlockValues()[finalH * cols + i * stepSize + j];
+                                                       }
+                                               }
+                                               fft_one_dim(stftOutput_re, 
stftOutput_im, re_inter, im_inter, finalH * rowLength + i * windowSize,
+                                                       finalH * rowLength + (i 
+ 1) * windowSize, windowSize, 1);
+                                       }
+                               }));
+                       }
+                       for(Future<?> f : tasks)
+                               f.get();
+               }
+               catch(InterruptedException | ExecutionException e) {
+                       throw new RuntimeException(e);
+               }
+               finally {
+                       pool.shutdown();
+               }
+
+               return new MatrixBlock[] {new MatrixBlock(rows, rowLength, 
stftOutput_re),
+                       new MatrixBlock(rows, rowLength, stftOutput_im)};
+       }
+
+       /**
+        * Function to perform STFT on a given matrix with windowSize and overlap. The matrix represents the real values.
+        * The output contains one matrix for the real and one for the 
imaginary values. The results of the fourier
+        * transformations are appended to each other in the output.
+        *
+        * @param re         matrix object representing the real values
+        * @param windowSize size of window
+        * @param overlap    size of overlap
+        * @param threads    The number of threads to use
+        * @return array of two matrix blocks
+        */
+       public static MatrixBlock[] stft(MatrixBlock re, int windowSize, int 
overlap, int threads) {
+               return stft(re,
+                       new MatrixBlock(re.getNumRows(), re.getNumColumns(), 
new double[re.getNumRows() * re.getNumColumns()]),
+                       windowSize, overlap, threads);
+       }
+
+}
diff --git 
a/src/test/java/org/apache/sysds/test/component/matrix/EigenDecompTest.java 
b/src/test/java/org/apache/sysds/test/component/matrix/EigenDecompTest.java
index f79b44f5a1..6292a14138 100644
--- a/src/test/java/org/apache/sysds/test/component/matrix/EigenDecompTest.java
+++ b/src/test/java/org/apache/sysds/test/component/matrix/EigenDecompTest.java
@@ -21,6 +21,7 @@ package org.apache.sysds.test.component.matrix;
 
 import static org.junit.Assert.fail;
 
+import org.apache.commons.lang3.NotImplementedException;
 import org.apache.commons.logging.Log;
 import org.apache.commons.logging.LogFactory;
 import org.apache.sysds.runtime.matrix.data.LibCommonsMath;
@@ -120,6 +121,8 @@ public class EigenDecompTest {
                                case QR:
                                        m = 
LibCommonsMath.multiReturnOperations(in, "eigen_qr", threads, 1);
                                        break;
+                               default:
+                                       throw new NotImplementedException();
                        }
 
                        isValidDecomposition(in, m[1], m[0], tol, t.toString());
diff --git a/src/test/java/org/apache/sysds/test/component/matrix/STFTTest.java 
b/src/test/java/org/apache/sysds/test/component/matrix/STFTTest.java
new file mode 100644
index 0000000000..7d30528e01
--- /dev/null
+++ b/src/test/java/org/apache/sysds/test/component/matrix/STFTTest.java
@@ -0,0 +1,182 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+ package org.apache.sysds.test.component.matrix;
+
+ import org.apache.sysds.runtime.matrix.data.MatrixBlock;
+ import org.junit.Test;
+ 
+ import static org.junit.Assert.assertArrayEquals;
+ import static org.apache.sysds.runtime.matrix.data.LibMatrixSTFT.stft;
+ 
+ public class STFTTest {
+ 
+        int threads = Runtime.getRuntime().availableProcessors();
+ 
+        @Test
+        public void simple_test() {
+ 
+                MatrixBlock re = new MatrixBlock(1, 16,  new double[]{0, 1, 2, 
3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15});
+ 
+                MatrixBlock[] res = stft(re, 4, 2, threads);
+ 
+                double[] res_re = res[0].getDenseBlockValues();
+                double[] res_im = res[1].getDenseBlockValues();
+ 
+                double[] expected_re = {6, -2, -2, -2, 14, -2, -2, -2, 22, -2, 
-2, -2, 30, -2, -2, -2, 38, -2, -2, -2, 46, -2, -2, -2, 54, -2, -2, -2};
+                double[] expected_im = {0, 2, 0, -2, 0, 2, 0, -2, 0, 2, 0, -2, 
0, 2, 0, -2, 0, 2, 0, -2, 0, 2, 0, -2, 0, 2, 0, -2};
+ 
+                assertArrayEquals(expected_re, res_re, 0.0001);
+                assertArrayEquals(expected_im, res_im, 0.0001);
+ 
+        }
+ 
+        @Test
+        public void simple_test_two() {
+ 
+                MatrixBlock re = new MatrixBlock(1, 15,  new double[]{0, 1, 2, 
3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14});
+ 
+                MatrixBlock[] res = stft(re, 4, 2, threads);
+ 
+                double[] res_re = res[0].getDenseBlockValues();
+                double[] res_im = res[1].getDenseBlockValues();
+ 
+                double[] expected_re = {6, -2, -2, -2, 14, -2, -2, -2, 22, -2, 
-2, -2, 30, -2, -2, -2, 38, -2, -2, -2, 46, -2, -2, -2, 39, -2, 13, -2};
+                double[] expected_im = {0, 2, 0, -2, 0, 2, 0, -2, 0, 2, 0, -2, 
0, 2, 0, -2, 0, 2, 0, -2, 0, 2, 0, -2, 0, -13, 0, 13};
+ 
+                /*
+                for (int i = 0; i < res_re.length; i++) {
+                        System.out.println(res_re[i] + " " + res_im[i]);
+                }
+                 */
+                assertArrayEquals(expected_re, res_re, 0.0001);
+                assertArrayEquals(expected_im, res_im, 0.0001);
+ 
+        }
+ 
+        @Test
+        public void matrix_block_one_dim_test(){
+ 
+                MatrixBlock re = new MatrixBlock(1, 4,  new double[]{0, 18, 
-15, 3});
+ 
+                MatrixBlock[] res = stft(re, 4, 0, threads);
+ 
+                double[] res_re = res[0].getDenseBlockValues();
+                double[] res_im = res[1].getDenseBlockValues();
+ 
+                double[] expected_re = {6, 15, -36, 15};
+                double[] expected_im = {0, -15, 0, 15};
+ 
+                assertArrayEquals(expected_re, res_re, 0.0001);
+                assertArrayEquals(expected_im, res_im, 0.0001);
+ 
+        }
+ 
+        @Test
+        public void matrix_block_one_dim_test2(){
+ 
+                MatrixBlock re = new MatrixBlock(1, 8,  new double[]{10, 5, 
-3, 8, 15, -6, 2, 0});
+ 
+                MatrixBlock[] res = stft(re, 4, 2, threads);
+ 
+                double[] res_re = res[0].getDenseBlockValues();
+                double[] res_im = res[1].getDenseBlockValues();
+ 
+                double[] expected_re = {20.0, 13.0, -6.0, 13.0, 14.0, -18.0, 
10.0, -18.0, 11.0, 13.0, 23.0, 13.0};
+                double[] expected_im = {0.0, 3.0, 0.0, -3.0, 0.0, -14.0, 0.0, 
14.0, 0.0, 6.0, 0.0, -6.0};
+ 
+                assertArrayEquals(expected_re, res_re, 0.0001);
+                assertArrayEquals(expected_im, res_im, 0.0001);
+ 
+        }
+ 
+        @Test
+        public void matrix_block_one_dim_test3(){
+ 
+                MatrixBlock re = new MatrixBlock(1, 8,  new double[]{10, 5, 
-3, 8, 15, -6, 2, 0});
+                MatrixBlock im = new MatrixBlock(1, 8,  new double[]{0, 0, 0, 
0, 0, 0, 0, 0});
+ 
+                MatrixBlock[] res = stft(re, im, 4, 2, threads);
+ 
+                double[] res_re = res[0].getDenseBlockValues();
+                double[] res_im = res[1].getDenseBlockValues();
+ 
+                double[] expected_re = {20.0, 13.0, -6.0, 13.0, 14.0, -18.0, 
10.0, -18.0, 11.0, 13.0, 23.0, 13.0};
+                double[] expected_im = {0.0, 3.0, 0.0, -3.0, 0.0, -14.0, 0.0, 
14.0, 0.0, 6.0, 0.0, -6.0};
+ 
+                assertArrayEquals(expected_re, res_re, 0.0001);
+                assertArrayEquals(expected_im, res_im, 0.0001);
+ 
+        }
+ 
+        @Test
+        public void test_two_x_eight() {
+ 
+                MatrixBlock re = new MatrixBlock(2, 8,  new double[]{0, 1, 2, 
3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15});
+ 
+                MatrixBlock[] res = stft(re, 4, 1, threads);
+ 
+                double[] res_re = res[0].getDenseBlockValues();
+                double[] res_im = res[1].getDenseBlockValues();
+ 
+                double[] expected_re = {6, -2, -2, -2, 18, -2, -2, -2, 13, 6, 
-1, 6, 38, -2, -2, -2, 50, -2, -2, -2, 29, 14, -1, 14};
+                double[] expected_im = {0, 2, 0, -2, 0, 2, 0, -2, 0, -7, 0, 7, 
0, 2, 0, -2, 0, 2, 0, -2, 0, -15, 0, 15};
+ 
+                assertArrayEquals(expected_re, res_re, 0.0001);
+                assertArrayEquals(expected_im, res_im, 0.0001);
+ 
+        }
+ 
+        @Test
+        public void test_four_x_four() {
+ 
+                MatrixBlock re = new MatrixBlock(4, 4,  new double[]{0, 1, 2, 
3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15});
+ 
+                MatrixBlock[] res = stft(re, 4, 0, threads);
+ 
+                double[] res_re = res[0].getDenseBlockValues();
+                double[] res_im = res[1].getDenseBlockValues();
+ 
+                double[] expected_re = {6, -2, -2, -2, 22, -2, -2, -2, 38, -2, 
-2, -2, 54, -2, -2, -2};
+                double[] expected_im = {0, 2, 0, -2, 0, 2, 0, -2, 0, 2, 0, -2, 
0, 2, 0, -2};
+ 
+                assertArrayEquals(expected_re, res_re, 0.0001);
+                assertArrayEquals(expected_im, res_im, 0.0001);
+ 
+        }
+ 
+        @Test
+        public void test_four_x_four_two() {
+ 
+                MatrixBlock re = new MatrixBlock(4, 5,  new double[]{0, 1, 2, 
3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19});
+ 
+                MatrixBlock[] res = stft(re, 4, 1, threads);
+ 
+                double[] res_re = res[0].getDenseBlockValues();
+                double[] res_im = res[1].getDenseBlockValues();
+ 
+                double[] expected_re = {6.0, -2.0, -2.0, -2.0, 7.0, 3.0, -1.0, 
3.0, 26.0, -2.0, -2.0, -2.0, 17.0, 8.0, -1.0, 8.0, 46.0, -2.0, -2.0, -2.0, 
27.0, 13.0, -1.0, 13.0, 66.0, -2.0, -2.0, -2.0, 37.0, 18.0, -1.0, 18.0};
+                double[] expected_im = {0.0, 2.0, 0.0, -2.0, 0.0, -4.0, 0.0, 
4.0, 0.0, 2.0, 0.0, -2.0, 0.0, -9.0, 0.0, 9.0, 0.0, 2.0, 0.0, -2.0, 0.0, -14.0, 
0.0, 14.0, 0.0, 2.0, 0.0, -2.0, 0.0, -19.0, 0.0, 19.0};
+ 
+                assertArrayEquals(expected_re, res_re, 0.0001);
+                assertArrayEquals(expected_im, res_im, 0.0001);
+ 
+        }
+ 
+ }
\ No newline at end of file

Reply via email to