This is an automated email from the ASF dual-hosted git repository. baunsgaard pushed a commit to branch main in repository https://gitbox.apache.org/repos/asf/systemds.git
commit 35b8e03cbb62d9ea26f0417abfd100dbdef2e002 Author: Mufan Wang <[email protected]> AuthorDate: Thu Apr 4 17:18:09 2024 +0200 [SYSTEMDS-3686] STFT This commit adds a short-time Fourier transform (STFT) to the system. It applies fast Fourier transforms to windows of configurable width and stride, enabling applications such as sound classification. LDE 23/24 project Co-authored-by: Mufan Wang <[email protected]> Co-authored-by: Frederic Caspar Zoepffel <[email protected]> Co-authored-by: Jessica Eva Sophie Priebe <[email protected]> Closes #2000 --- .../java/org/apache/sysds/common/Builtins.java | 1 + .../java/org/apache/sysds/hops/FunctionOp.java | 9 + .../sysds/parser/BuiltinFunctionExpression.java | 66 ++++++++ .../org/apache/sysds/parser/DMLTranslator.java | 1 + .../runtime/instructions/CPInstructionParser.java | 1 + .../instructions/cp/ComputationCPInstruction.java | 18 +- .../cp/MultiReturnBuiltinCPInstruction.java | 8 + ...ltiReturnComplexMatrixBuiltinCPInstruction.java | 65 +++++++- .../sysds/runtime/matrix/data/LibCommonsMath.java | 53 ++++++ .../sysds/runtime/matrix/data/LibMatrixSTFT.java | 121 ++++++++++++++ .../test/component/matrix/EigenDecompTest.java | 3 + .../sysds/test/component/matrix/STFTTest.java | 182 +++++++++++++++++++++ 12 files changed, 525 insertions(+), 3 deletions(-) diff --git a/src/main/java/org/apache/sysds/common/Builtins.java b/src/main/java/org/apache/sysds/common/Builtins.java index 7e83984e47..8f113c092f 100644 --- a/src/main/java/org/apache/sysds/common/Builtins.java +++ b/src/main/java/org/apache/sysds/common/Builtins.java @@ -310,6 +310,7 @@ public enum Builtins { STATSNA("statsNA", true), STRATSTATS("stratstats", true), STEPLM("steplm",true, ReturnType.MULTI_RETURN), + STFT("stft", false, ReturnType.MULTI_RETURN), SQRT("sqrt", false), SUM("sum", false), SVD("svd", false, ReturnType.MULTI_RETURN), diff --git a/src/main/java/org/apache/sysds/hops/FunctionOp.java b/src/main/java/org/apache/sysds/hops/FunctionOp.java index
ffc12c30ee..95b5411500 100644 --- a/src/main/java/org/apache/sysds/hops/FunctionOp.java +++ b/src/main/java/org/apache/sysds/hops/FunctionOp.java @@ -221,6 +221,11 @@ public class FunctionOp extends Hop long outputIm = OptimizerUtils.estimateSizeExactSparsity(getOutputs().get(1).getDim1(), getOutputs().get(1).getDim2(), 1.0); return outputRe+outputIm; } + else if ( getFunctionName().equalsIgnoreCase("stft") ) { + long outputRe = OptimizerUtils.estimateSizeExactSparsity(getOutputs().get(0).getDim1(), getOutputs().get(0).getDim2(), 1.0); + long outputIm = OptimizerUtils.estimateSizeExactSparsity(getOutputs().get(1).getDim1(), getOutputs().get(1).getDim2(), 1.0); + return outputRe+outputIm; + } else if ( getFunctionName().equalsIgnoreCase("lstm") || getFunctionName().equalsIgnoreCase("lstm_backward") ) { // TODO: To allow for initial version to always run on the GPU return 0; @@ -286,6 +291,10 @@ public class FunctionOp extends Hop // 2 matrices of size same as the input return 2*OptimizerUtils.estimateSizeExactSparsity(getInput().get(0).getDim1(), getInput().get(0).getDim2(), 1.0); } + else if ( getFunctionName().equalsIgnoreCase("stft") ) { + // 2 matrices of size same as the input + return 2*OptimizerUtils.estimateSizeExactSparsity(getInput().get(0).getDim1(), getInput().get(0).getDim2(), 1.0); + } else if (getFunctionName().equalsIgnoreCase("batch_norm2d") || getFunctionName().equalsIgnoreCase("batch_norm2d_backward") || getFunctionName().equalsIgnoreCase("batch_norm2d_train") || getFunctionName().equalsIgnoreCase("batch_norm2d_test")) { return 0; diff --git a/src/main/java/org/apache/sysds/parser/BuiltinFunctionExpression.java b/src/main/java/org/apache/sysds/parser/BuiltinFunctionExpression.java index 4b3c8e82f7..c3f1026627 100644 --- a/src/main/java/org/apache/sysds/parser/BuiltinFunctionExpression.java +++ b/src/main/java/org/apache/sysds/parser/BuiltinFunctionExpression.java @@ -589,6 +589,72 @@ public class BuiltinFunctionExpression extends DataIdentifier { 
break; } + case STFT: { + checkMatrixParam(getFirstExpr()); + + if((getFirstExpr() == null || getSecondExpr() == null || getThirdExpr() == null) && _args.length > 0) { + raiseValidateError("Missing argument for function " + this.getOpCode(), false, + LanguageErrorCodes.INVALID_PARAMETERS); + } + else if(getFifthExpr() != null) { + raiseValidateError("Invalid number of arguments for function " + this.getOpCode().toString().toLowerCase() + + "(). This function only takes 3 or 4 arguments.", false); + } + else if(_args.length == 3) { + checkScalarParam(getSecondExpr()); + checkScalarParam(getThirdExpr()); + if(!isPowerOfTwo(((ConstIdentifier) getSecondExpr().getOutput()).getLongValue())) { + raiseValidateError( + "This FFT implementation is only defined for matrices with dimensions that are powers of 2." + + "The window size (2nd argument) is not a power of two", + false, LanguageErrorCodes.INVALID_PARAMETERS); + } + else if(((ConstIdentifier) getSecondExpr().getOutput()) + .getLongValue() <= ((ConstIdentifier) getThirdExpr().getOutput()).getLongValue()) { + raiseValidateError("Overlap can't be larger than or equal to the window size.", false, + LanguageErrorCodes.INVALID_PARAMETERS); + } + } + else if(_args.length == 4) { + checkMatrixParam(getSecondExpr()); + checkScalarParam(getThirdExpr()); + checkScalarParam(getFourthExpr()); + if(!isPowerOfTwo(((ConstIdentifier) getThirdExpr().getOutput()).getLongValue())) { + raiseValidateError( + "This FFT implementation is only defined for matrices with dimensions that are powers of 2." 
+ + "The window size (3rd argument) is not a power of two", + false, LanguageErrorCodes.INVALID_PARAMETERS); + } + else if(getFirstExpr().getOutput().getDim1() != getSecondExpr().getOutput().getDim1() || + getFirstExpr().getOutput().getDim2() != getSecondExpr().getOutput().getDim2()) { + raiseValidateError("The real and imaginary part of the provided matrix are of different dimensions.", + false); + } + else if(((ConstIdentifier) getThirdExpr().getOutput()) + .getLongValue() <= ((ConstIdentifier) getFourthExpr().getOutput()).getLongValue()) { + raiseValidateError("Overlap can't be larger than or equal to the window size.", false, + LanguageErrorCodes.INVALID_PARAMETERS); + } + } + + // setup output properties + DataIdentifier stftOut1 = (DataIdentifier) getOutputs()[0]; + DataIdentifier stftOut2 = (DataIdentifier) getOutputs()[1]; + + // Output1 - stft Values + stftOut1.setDataType(DataType.MATRIX); + stftOut1.setValueType(ValueType.FP64); + stftOut1.setDimensions(getFirstExpr().getOutput().getDim1(), getFirstExpr().getOutput().getDim2()); + stftOut1.setBlocksize(getFirstExpr().getOutput().getBlocksize()); + + // Output2 - stft Vectors + stftOut2.setDataType(DataType.MATRIX); + stftOut2.setValueType(ValueType.FP64); + stftOut2.setDimensions(getFirstExpr().getOutput().getDim1(), getFirstExpr().getOutput().getDim2()); + stftOut2.setBlocksize(getFirstExpr().getOutput().getBlocksize()); + + break; + } case REMOVE: { checkNumParameters(2); checkListParam(getFirstExpr()); diff --git a/src/main/java/org/apache/sysds/parser/DMLTranslator.java b/src/main/java/org/apache/sysds/parser/DMLTranslator.java index 97d5523961..2b876c12be 100644 --- a/src/main/java/org/apache/sysds/parser/DMLTranslator.java +++ b/src/main/java/org/apache/sysds/parser/DMLTranslator.java @@ -2240,6 +2240,7 @@ public class DMLTranslator case IFFT: case FFT_LINEARIZED: case IFFT_LINEARIZED: + case STFT: case LSTM: case LSTM_BACKWARD: case BATCH_NORM2D: diff --git 
a/src/main/java/org/apache/sysds/runtime/instructions/CPInstructionParser.java b/src/main/java/org/apache/sysds/runtime/instructions/CPInstructionParser.java index b78c1e2b49..994c1cd51a 100644 --- a/src/main/java/org/apache/sysds/runtime/instructions/CPInstructionParser.java +++ b/src/main/java/org/apache/sysds/runtime/instructions/CPInstructionParser.java @@ -338,6 +338,7 @@ public class CPInstructionParser extends InstructionParser { String2CPInstructionType.put( "ifft", CPType.MultiReturnComplexMatrixBuiltin); String2CPInstructionType.put( "fft_linearized", CPType.MultiReturnBuiltin); String2CPInstructionType.put( "ifft_linearized", CPType.MultiReturnComplexMatrixBuiltin); + String2CPInstructionType.put( "stft", CPType.MultiReturnComplexMatrixBuiltin); String2CPInstructionType.put( "svd", CPType.MultiReturnBuiltin); String2CPInstructionType.put( "partition", CPType.Partition); diff --git a/src/main/java/org/apache/sysds/runtime/instructions/cp/ComputationCPInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/cp/ComputationCPInstruction.java index de45036991..cc32f624d0 100644 --- a/src/main/java/org/apache/sysds/runtime/instructions/cp/ComputationCPInstruction.java +++ b/src/main/java/org/apache/sysds/runtime/instructions/cp/ComputationCPInstruction.java @@ -35,7 +35,7 @@ import org.apache.sysds.runtime.matrix.operators.Operator; public abstract class ComputationCPInstruction extends CPInstruction implements LineageTraceable { public final CPOperand output; - public final CPOperand input1, input2, input3; + public final CPOperand input1, input2, input3, input4; protected ComputationCPInstruction(CPType type, Operator op, CPOperand in1, CPOperand in2, CPOperand out, String opcode, String istr) { @@ -43,6 +43,7 @@ public abstract class ComputationCPInstruction extends CPInstruction implements input1 = in1; input2 = in2; input3 = null; + input4 = null; output = out; } @@ -52,6 +53,17 @@ public abstract class ComputationCPInstruction extends 
CPInstruction implements input1 = in1; input2 = in2; input3 = in3; + input4 = null; + output = out; + } + + protected ComputationCPInstruction(CPType type, Operator op, CPOperand in1, CPOperand in2, CPOperand in3, + CPOperand in4, CPOperand out, String opcode, String istr) { + super(type, op, opcode, istr); + input1 = in1; + input2 = in2; + input3 = in3; + input4 = in4; output = out; } @@ -64,7 +76,7 @@ public abstract class ComputationCPInstruction extends CPInstruction implements } public CPOperand[] getInputs(){ - return new CPOperand[]{input1, input2, input3}; + return new CPOperand[]{input1, input2, input3, input4}; } public boolean hasFrameInput() { @@ -74,6 +86,8 @@ public abstract class ComputationCPInstruction extends CPInstruction implements return true; if (input3 != null && input3.isFrame()) return true; + if (input4 != null && input4.isFrame()) + return true; return false; } diff --git a/src/main/java/org/apache/sysds/runtime/instructions/cp/MultiReturnBuiltinCPInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/cp/MultiReturnBuiltinCPInstruction.java index 4719948ec0..a65b56c8ac 100644 --- a/src/main/java/org/apache/sysds/runtime/instructions/cp/MultiReturnBuiltinCPInstruction.java +++ b/src/main/java/org/apache/sysds/runtime/instructions/cp/MultiReturnBuiltinCPInstruction.java @@ -124,6 +124,14 @@ public class MultiReturnBuiltinCPInstruction extends ComputationCPInstruction { return new MultiReturnBuiltinCPInstruction(null, null, outputs, opcode, str); } + else if ( opcode.equalsIgnoreCase("stft") ) { + // one input and two outputs + CPOperand in1 = new CPOperand(parts[1]); + outputs.add ( new CPOperand(parts[2], ValueType.FP64, DataType.MATRIX) ); + outputs.add ( new CPOperand(parts[3], ValueType.FP64, DataType.MATRIX) ); + + return new MultiReturnBuiltinCPInstruction(null, in1, outputs, opcode, str); + } else if ( opcode.equalsIgnoreCase("svd") ) { CPOperand in1 = new CPOperand(parts[1]); diff --git 
a/src/main/java/org/apache/sysds/runtime/instructions/cp/MultiReturnComplexMatrixBuiltinCPInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/cp/MultiReturnComplexMatrixBuiltinCPInstruction.java index f9c92d4858..f75b231052 100644 --- a/src/main/java/org/apache/sysds/runtime/instructions/cp/MultiReturnComplexMatrixBuiltinCPInstruction.java +++ b/src/main/java/org/apache/sysds/runtime/instructions/cp/MultiReturnComplexMatrixBuiltinCPInstruction.java @@ -50,6 +50,12 @@ public class MultiReturnComplexMatrixBuiltinCPInstruction extends ComputationCPI _outputs = outputs; } + private MultiReturnComplexMatrixBuiltinCPInstruction(Operator op, CPOperand input1, CPOperand input2, + CPOperand input3, CPOperand input4, ArrayList<CPOperand> outputs, String opcode, String istr) { + super(CPType.MultiReturnBuiltin, op, input1, input2, input3, input4, outputs.get(0), opcode, istr); + _outputs = outputs; + } + public CPOperand getOutput(int i) { return _outputs.get(i); } @@ -102,6 +108,27 @@ public class MultiReturnComplexMatrixBuiltinCPInstruction extends ComputationCPI return new MultiReturnComplexMatrixBuiltinCPInstruction(null, in1, outputs, opcode, str); } + else if(parts.length == 6 && opcode.equalsIgnoreCase("stft")) { + CPOperand in1 = new CPOperand(parts[1]); + CPOperand windowSize = new CPOperand(parts[2]); + CPOperand overlap = new CPOperand(parts[3]); + outputs.add(new CPOperand(parts[4], ValueType.FP64, DataType.MATRIX)); + outputs.add(new CPOperand(parts[5], ValueType.FP64, DataType.MATRIX)); + + return new MultiReturnComplexMatrixBuiltinCPInstruction(null, in1, null, windowSize, overlap, outputs, opcode, + str); + } + else if(parts.length == 7 && opcode.equalsIgnoreCase("stft")) { + CPOperand in1 = new CPOperand(parts[1]); + CPOperand in2 = new CPOperand(parts[2]); + CPOperand windowSize = new CPOperand(parts[3]); + CPOperand overlap = new CPOperand(parts[4]); + outputs.add(new CPOperand(parts[5], ValueType.FP64, DataType.MATRIX)); + 
outputs.add(new CPOperand(parts[6], ValueType.FP64, DataType.MATRIX)); + + return new MultiReturnComplexMatrixBuiltinCPInstruction(null, in1, in2, windowSize, overlap, outputs, opcode, + str); + } else { throw new DMLRuntimeException("Invalid opcode in MultiReturnBuiltin instruction: " + opcode); } @@ -114,7 +141,11 @@ public class MultiReturnComplexMatrixBuiltinCPInstruction extends ComputationCPI @Override public void processInstruction(ExecutionContext ec) { - if(input2 == null) + if(getOpcode().equals("stft") && input2 == null) + processSTFTInstruction(ec); + else if(getOpcode().equals("stft")) + processSTFTTwoInstruction(ec); + else if(input2 == null) processOneInputInstruction(ec); else processTwoInputInstruction(ec); @@ -148,6 +179,38 @@ public class MultiReturnComplexMatrixBuiltinCPInstruction extends ComputationCPI } } + private void processSTFTInstruction(ExecutionContext ec) { + if(!LibCommonsMath.isSupportedMultiReturnOperation(getOpcode())) + throw new DMLRuntimeException("Invalid opcode in MultiReturnBuiltin instruction: " + getOpcode()); + + MatrixBlock in1 = ec.getMatrixInput(input1.getName()); + // MatrixBlock in2 = ec.getMatrixInput(input2.getName()); + int windowSize = Integer.parseInt(input3.getName()); + int overlap = Integer.parseInt(input4.getName()); + MatrixBlock[] out = LibCommonsMath.multiReturnOperations(in1, getOpcode(), windowSize, overlap); + ec.releaseMatrixInput(input1.getName()); + + for(int i = 0; i < _outputs.size(); i++) { + ec.setMatrixOutput(_outputs.get(i).getName(), out[i]); + } + } + + private void processSTFTTwoInstruction(ExecutionContext ec) { + if(!LibCommonsMath.isSupportedMultiReturnOperation(getOpcode())) + throw new DMLRuntimeException("Invalid opcode in MultiReturnBuiltin instruction: " + getOpcode()); + + MatrixBlock in1 = ec.getMatrixInput(input1.getName()); + MatrixBlock in2 = ec.getMatrixInput(input2.getName()); + int windowSize = Integer.parseInt(input3.getName()); + int overlap = 
Integer.parseInt(input4.getName()); + MatrixBlock[] out = LibCommonsMath.multiReturnOperations(in1, in2, getOpcode(), windowSize, overlap); + ec.releaseMatrixInput(input1.getName(), input2.getName()); + + for(int i = 0; i < _outputs.size(); i++) { + ec.setMatrixOutput(_outputs.get(i).getName(), out[i]); + } + } + @Override public boolean hasSingleLineage() { return false; diff --git a/src/main/java/org/apache/sysds/runtime/matrix/data/LibCommonsMath.java b/src/main/java/org/apache/sysds/runtime/matrix/data/LibCommonsMath.java index dc42b15a96..b7fe116707 100644 --- a/src/main/java/org/apache/sysds/runtime/matrix/data/LibCommonsMath.java +++ b/src/main/java/org/apache/sysds/runtime/matrix/data/LibCommonsMath.java @@ -53,6 +53,7 @@ import static org.apache.sysds.runtime.matrix.data.LibMatrixFourier.fft; import static org.apache.sysds.runtime.matrix.data.LibMatrixFourier.ifft; import static org.apache.sysds.runtime.matrix.data.LibMatrixFourier.fft_linearized; import static org.apache.sysds.runtime.matrix.data.LibMatrixFourier.ifft_linearized; +import static org.apache.sysds.runtime.matrix.data.LibMatrixSTFT.stft; /** * Library for matrix operations that need invocation of @@ -85,6 +86,7 @@ public class LibCommonsMath case "ifft": case "fft_linearized": case "ifft_linearized": + case "stft": case "svd": return true; default: return false; } @@ -415,6 +417,57 @@ public class LibCommonsMath return computeIFFT_LINEARIZED(re, null, threads); } + + /** + * Function to perform STFT on a given matrix. 
+ * + * @param re matrix object + * @param im matrix object + * @param windowSize of stft + * @param overlap of stft + * @return array of matrix blocks + */ + private static MatrixBlock[] computeSTFT(MatrixBlock re, MatrixBlock im, int windowSize, int overlap, int threads) { + if (re == null) { + throw new DMLRuntimeException("Invalid empty block"); + } else if (im != null && !im.isEmptyBlock(false)) { + re.sparseToDense(); + im.sparseToDense(); + return stft(re, im, windowSize, overlap, threads); + } else { + if (re.isEmptyBlock(false)) { + // Return the original matrix as the result + int rows = re.getNumRows(); + int cols = re.getNumColumns(); + + int stepSize = windowSize - overlap; + if (stepSize == 0) { + throw new IllegalArgumentException("windowSize - overlap is zero"); + } + + int numberOfFramesPerRow = (cols - overlap + stepSize - 1) / stepSize; + int rowLength= numberOfFramesPerRow * windowSize; + int out_len = rowLength * rows; + + double[] out_zero = new double[out_len]; + + return new MatrixBlock[]{new MatrixBlock(rows, rowLength, out_zero), new MatrixBlock(rows, rowLength, out_zero)}; + } + re.sparseToDense(); + return stft(re, windowSize, overlap, threads); + } + } + + /** + * Function to perform STFT on a given matrix. + * + * @param re matrix object + * @return array of matrix blocks + */ + private static MatrixBlock[] computeSTFT(MatrixBlock re, int windowSize, int overlap, int threads) { + return computeSTFT(re, null, windowSize, overlap, threads); + } + /** * Performs Singular Value Decomposition. Calls Apache Commons Math SVD. 
* X = U * Sigma * Vt, where X is the input matrix, diff --git a/src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixSTFT.java b/src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixSTFT.java new file mode 100644 index 0000000000..9251d4cdcd --- /dev/null +++ b/src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixSTFT.java @@ -0,0 +1,121 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.sysds.runtime.matrix.data; + +import static org.apache.sysds.runtime.matrix.data.LibMatrixFourier.fft_one_dim; + +import org.apache.sysds.runtime.util.CommonThreadPool; +import java.util.ArrayList; +import java.util.concurrent.Future; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.ExecutionException; +import java.util.List; + +/** + * Liberary file containing methods to perform short time fourier transformations. + */ +public class LibMatrixSTFT { + + /** + * Function to perform STFT on two given matrices with windowSize and overlap. The first one represents the real + * values and the second one the imaginary values. The output also contains one matrix for the real and one for the + * imaginary values. 
The results of the fourier transformations are appended to each other in the output. + * + * @param re Matrix object representing the real values + * @param im Matrix object representing the imaginary values + * @param windowSize Size of window + * @param overlap Size of overlap + * @param threads The number of threads to use + * @return array of two matrix blocks + */ + public static MatrixBlock[] stft(MatrixBlock re, MatrixBlock im, int windowSize, int overlap, int threads) { + + int rows = re.getNumRows(); + int cols = re.getNumColumns(); + + int stepSize = windowSize - overlap; + if(stepSize == 0) { + throw new IllegalArgumentException("windowSize - overlap is zero"); + } + + int numberOfFramesPerRow = (cols - overlap + stepSize - 1) / stepSize; + int rowLength = numberOfFramesPerRow * windowSize; + int out_len = rowLength * rows; + + double[] stftOutput_re = new double[out_len]; + double[] stftOutput_im = new double[out_len]; + + double[] re_inter = new double[out_len]; + double[] im_inter = new double[out_len]; + + final ExecutorService pool = CommonThreadPool.get(threads); + + final List<Future<?>> tasks = new ArrayList<>(); + + try { + for(int h = 0; h < rows; h++) { + final int finalH = h; + tasks.add(pool.submit(() -> { + for(int i = 0; i < numberOfFramesPerRow; i++) { + for(int j = 0; j < windowSize; j++) { + if((i * stepSize + j) < cols) { + stftOutput_re[finalH * rowLength + i * windowSize + j] = re + .getDenseBlockValues()[finalH * cols + i * stepSize + j]; + stftOutput_im[finalH * rowLength + i * windowSize + j] = im + .getDenseBlockValues()[finalH * cols + i * stepSize + j]; + } + } + fft_one_dim(stftOutput_re, stftOutput_im, re_inter, im_inter, finalH * rowLength + i * windowSize, + finalH * rowLength + (i + 1) * windowSize, windowSize, 1); + } + })); + } + for(Future<?> f : tasks) + f.get(); + } + catch(InterruptedException | ExecutionException e) { + throw new RuntimeException(e); + } + finally { + pool.shutdown(); + } + + return new 
MatrixBlock[] {new MatrixBlock(rows, rowLength, stftOutput_re), + new MatrixBlock(rows, rowLength, stftOutput_im)}; + } + + /** + * Function to perform STFT on a given matrices with windowSize and overlap. The matrix represents the real values. + * The output contains one matrix for the real and one for the imaginary values. The results of the fourier + * transformations are appended to each other in the output. + * + * @param re matrix object representing the real values + * @param windowSize size of window + * @param overlap size of overlap + * @param threads The number of threads to use + * @return array of two matrix blocks + */ + public static MatrixBlock[] stft(MatrixBlock re, int windowSize, int overlap, int threads) { + return stft(re, + new MatrixBlock(re.getNumRows(), re.getNumColumns(), new double[re.getNumRows() * re.getNumColumns()]), + windowSize, overlap, threads); + } + +} diff --git a/src/test/java/org/apache/sysds/test/component/matrix/EigenDecompTest.java b/src/test/java/org/apache/sysds/test/component/matrix/EigenDecompTest.java index f79b44f5a1..6292a14138 100644 --- a/src/test/java/org/apache/sysds/test/component/matrix/EigenDecompTest.java +++ b/src/test/java/org/apache/sysds/test/component/matrix/EigenDecompTest.java @@ -21,6 +21,7 @@ package org.apache.sysds.test.component.matrix; import static org.junit.Assert.fail; +import org.apache.commons.lang3.NotImplementedException; import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; import org.apache.sysds.runtime.matrix.data.LibCommonsMath; @@ -120,6 +121,8 @@ public class EigenDecompTest { case QR: m = LibCommonsMath.multiReturnOperations(in, "eigen_qr", threads, 1); break; + default: + throw new NotImplementedException(); } isValidDecomposition(in, m[1], m[0], tol, t.toString()); diff --git a/src/test/java/org/apache/sysds/test/component/matrix/STFTTest.java b/src/test/java/org/apache/sysds/test/component/matrix/STFTTest.java new file mode 100644 index 
0000000000..7d30528e01 --- /dev/null +++ b/src/test/java/org/apache/sysds/test/component/matrix/STFTTest.java @@ -0,0 +1,182 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + + package org.apache.sysds.test.component.matrix; + + import org.apache.sysds.runtime.matrix.data.MatrixBlock; + import org.junit.Test; + + import static org.junit.Assert.assertArrayEquals; + import static org.apache.sysds.runtime.matrix.data.LibMatrixSTFT.stft; + + public class STFTTest { + + int threads = Runtime.getRuntime().availableProcessors(); + + @Test + public void simple_test() { + + MatrixBlock re = new MatrixBlock(1, 16, new double[]{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}); + + MatrixBlock[] res = stft(re, 4, 2, threads); + + double[] res_re = res[0].getDenseBlockValues(); + double[] res_im = res[1].getDenseBlockValues(); + + double[] expected_re = {6, -2, -2, -2, 14, -2, -2, -2, 22, -2, -2, -2, 30, -2, -2, -2, 38, -2, -2, -2, 46, -2, -2, -2, 54, -2, -2, -2}; + double[] expected_im = {0, 2, 0, -2, 0, 2, 0, -2, 0, 2, 0, -2, 0, 2, 0, -2, 0, 2, 0, -2, 0, 2, 0, -2, 0, 2, 0, -2}; + + assertArrayEquals(expected_re, res_re, 0.0001); + assertArrayEquals(expected_im, res_im, 0.0001); + + } + + @Test + public void 
simple_test_two() { + + MatrixBlock re = new MatrixBlock(1, 15, new double[]{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14}); + + MatrixBlock[] res = stft(re, 4, 2, threads); + + double[] res_re = res[0].getDenseBlockValues(); + double[] res_im = res[1].getDenseBlockValues(); + + double[] expected_re = {6, -2, -2, -2, 14, -2, -2, -2, 22, -2, -2, -2, 30, -2, -2, -2, 38, -2, -2, -2, 46, -2, -2, -2, 39, -2, 13, -2}; + double[] expected_im = {0, 2, 0, -2, 0, 2, 0, -2, 0, 2, 0, -2, 0, 2, 0, -2, 0, 2, 0, -2, 0, 2, 0, -2, 0, -13, 0, 13}; + + /* + for (int i = 0; i < res_re.length; i++) { + System.out.println(res_re[i] + " " + res_im[i]); + } + */ + assertArrayEquals(expected_re, res_re, 0.0001); + assertArrayEquals(expected_im, res_im, 0.0001); + + } + + @Test + public void matrix_block_one_dim_test(){ + + MatrixBlock re = new MatrixBlock(1, 4, new double[]{0, 18, -15, 3}); + + MatrixBlock[] res = stft(re, 4, 0, threads); + + double[] res_re = res[0].getDenseBlockValues(); + double[] res_im = res[1].getDenseBlockValues(); + + double[] expected_re = {6, 15, -36, 15}; + double[] expected_im = {0, -15, 0, 15}; + + assertArrayEquals(expected_re, res_re, 0.0001); + assertArrayEquals(expected_im, res_im, 0.0001); + + } + + @Test + public void matrix_block_one_dim_test2(){ + + MatrixBlock re = new MatrixBlock(1, 8, new double[]{10, 5, -3, 8, 15, -6, 2, 0}); + + MatrixBlock[] res = stft(re, 4, 2, threads); + + double[] res_re = res[0].getDenseBlockValues(); + double[] res_im = res[1].getDenseBlockValues(); + + double[] expected_re = {20.0, 13.0, -6.0, 13.0, 14.0, -18.0, 10.0, -18.0, 11.0, 13.0, 23.0, 13.0}; + double[] expected_im = {0.0, 3.0, 0.0, -3.0, 0.0, -14.0, 0.0, 14.0, 0.0, 6.0, 0.0, -6.0}; + + assertArrayEquals(expected_re, res_re, 0.0001); + assertArrayEquals(expected_im, res_im, 0.0001); + + } + + @Test + public void matrix_block_one_dim_test3(){ + + MatrixBlock re = new MatrixBlock(1, 8, new double[]{10, 5, -3, 8, 15, -6, 2, 0}); + MatrixBlock im = new 
MatrixBlock(1, 8, new double[]{0, 0, 0, 0, 0, 0, 0, 0}); + + MatrixBlock[] res = stft(re, im, 4, 2, threads); + + double[] res_re = res[0].getDenseBlockValues(); + double[] res_im = res[1].getDenseBlockValues(); + + double[] expected_re = {20.0, 13.0, -6.0, 13.0, 14.0, -18.0, 10.0, -18.0, 11.0, 13.0, 23.0, 13.0}; + double[] expected_im = {0.0, 3.0, 0.0, -3.0, 0.0, -14.0, 0.0, 14.0, 0.0, 6.0, 0.0, -6.0}; + + assertArrayEquals(expected_re, res_re, 0.0001); + assertArrayEquals(expected_im, res_im, 0.0001); + + } + + @Test + public void test_two_x_eight() { + + MatrixBlock re = new MatrixBlock(2, 8, new double[]{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}); + + MatrixBlock[] res = stft(re, 4, 1, threads); + + double[] res_re = res[0].getDenseBlockValues(); + double[] res_im = res[1].getDenseBlockValues(); + + double[] expected_re = {6, -2, -2, -2, 18, -2, -2, -2, 13, 6, -1, 6, 38, -2, -2, -2, 50, -2, -2, -2, 29, 14, -1, 14}; + double[] expected_im = {0, 2, 0, -2, 0, 2, 0, -2, 0, -7, 0, 7, 0, 2, 0, -2, 0, 2, 0, -2, 0, -15, 0, 15}; + + assertArrayEquals(expected_re, res_re, 0.0001); + assertArrayEquals(expected_im, res_im, 0.0001); + + } + + @Test + public void test_four_x_four() { + + MatrixBlock re = new MatrixBlock(4, 4, new double[]{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}); + + MatrixBlock[] res = stft(re, 4, 0, threads); + + double[] res_re = res[0].getDenseBlockValues(); + double[] res_im = res[1].getDenseBlockValues(); + + double[] expected_re = {6, -2, -2, -2, 22, -2, -2, -2, 38, -2, -2, -2, 54, -2, -2, -2}; + double[] expected_im = {0, 2, 0, -2, 0, 2, 0, -2, 0, 2, 0, -2, 0, 2, 0, -2}; + + assertArrayEquals(expected_re, res_re, 0.0001); + assertArrayEquals(expected_im, res_im, 0.0001); + + } + + @Test + public void test_four_x_four_two() { + + MatrixBlock re = new MatrixBlock(4, 5, new double[]{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19}); + + MatrixBlock[] res = stft(re, 4, 1, threads); + + double[] res_re = 
res[0].getDenseBlockValues(); + double[] res_im = res[1].getDenseBlockValues(); + + double[] expected_re = {6.0, -2.0, -2.0, -2.0, 7.0, 3.0, -1.0, 3.0, 26.0, -2.0, -2.0, -2.0, 17.0, 8.0, -1.0, 8.0, 46.0, -2.0, -2.0, -2.0, 27.0, 13.0, -1.0, 13.0, 66.0, -2.0, -2.0, -2.0, 37.0, 18.0, -1.0, 18.0}; + double[] expected_im = {0.0, 2.0, 0.0, -2.0, 0.0, -4.0, 0.0, 4.0, 0.0, 2.0, 0.0, -2.0, 0.0, -9.0, 0.0, 9.0, 0.0, 2.0, 0.0, -2.0, 0.0, -14.0, 0.0, 14.0, 0.0, 2.0, 0.0, -2.0, 0.0, -19.0, 0.0, 19.0}; + + assertArrayEquals(expected_re, res_re, 0.0001); + assertArrayEquals(expected_im, res_im, 0.0001); + + } + + } \ No newline at end of file
