This is an automated email from the ASF dual-hosted git repository. baunsgaard pushed a commit to branch main in repository https://gitbox.apache.org/repos/asf/systemds.git
commit d22ffeccd7b10af366574f7fe03d637be9db49d5 Author: Frederic Caspar Zoepffel <[email protected]> AuthorDate: Thu Apr 4 17:17:44 2024 +0200 [SYSTEMDS-3685] Python FFT This commit adds support in the Python API for fft and ifft. Future work is to add the linearized versions of the commands. LDE 23/24 project Co-authored-by: Mufan Wang <[email protected]> Co-authored-by: Frederic Caspar Zoepffel <[email protected]> Co-authored-by: Jessica Eva Sophie Priebe <[email protected]> Closes #1983 --- .../sysds/parser/BuiltinFunctionExpression.java | 148 ++++++--- .../python/systemds/context/systemds_context.py | 37 ++- src/main/python/tests/matrix/test_fft.py | 333 +++++++++++++++++++++ 3 files changed, 479 insertions(+), 39 deletions(-) diff --git a/src/main/java/org/apache/sysds/parser/BuiltinFunctionExpression.java b/src/main/java/org/apache/sysds/parser/BuiltinFunctionExpression.java index 5e86a2fd8e..4b3c8e82f7 100644 --- a/src/main/java/org/apache/sysds/parser/BuiltinFunctionExpression.java +++ b/src/main/java/org/apache/sysds/parser/BuiltinFunctionExpression.java @@ -381,20 +381,41 @@ public class BuiltinFunctionExpression extends DataIdentifier { break; } case FFT: { + + Expression expressionOne = getFirstExpr(); + Expression expressionTwo = getSecondExpr(); + + if(expressionOne == null) { + raiseValidateError("The first argument to " + _opcode + " cannot be null.", false, + LanguageErrorCodes.INVALID_PARAMETERS); + } + else if(expressionOne.getOutput() == null || expressionOne.getOutput().getDim1() == 0 || + expressionOne.getOutput().getDim2() == 0) { + raiseValidateError("The first argument to " + _opcode + " cannot be an empty matrix.", false, + LanguageErrorCodes.INVALID_PARAMETERS); + } + else if(expressionTwo != null) { + raiseValidateError("Too many arguments. This FFT implementation is only defined for real inputs.", false, + LanguageErrorCodes.INVALID_PARAMETERS); + } + else if(!isPowerOfTwo(expressionOne.getOutput().getDim1()) || + !isPowerOfTwo(expressionOne.getOutput().getDim2())) { + raiseValidateError( + "This FFT implementation is only defined for matrices with dimensions that are powers of 2.", false, + LanguageErrorCodes.INVALID_PARAMETERS); + } + checkNumParameters(1); - checkMatrixParam(getFirstExpr()); + checkMatrixParam(expressionOne); - // setup output properties DataIdentifier fftOut1 = (DataIdentifier) getOutputs()[0]; DataIdentifier fftOut2 = (DataIdentifier) getOutputs()[1]; - // Output1 - FFT Values fftOut1.setDataType(DataType.MATRIX); fftOut1.setValueType(ValueType.FP64); fftOut1.setDimensions(getFirstExpr().getOutput().getDim1(), getFirstExpr().getOutput().getDim2()); fftOut1.setBlocksize(getFirstExpr().getOutput().getBlocksize()); - // Output2 - FFT Vectors fftOut2.setDataType(DataType.MATRIX); fftOut2.setValueType(ValueType.FP64); fftOut2.setDimensions(getFirstExpr().getOutput().getDim1(), getFirstExpr().getOutput().getDim2()); @@ -405,16 +426,53 @@ public class BuiltinFunctionExpression extends DataIdentifier { } case IFFT: { Expression expressionTwo = getSecondExpr(); - checkNumParameters(getSecondExpr() != null ? 2 : 1); - checkMatrixParam(getFirstExpr()); - if (expressionTwo != null) - checkMatrixParam(getSecondExpr()); + Expression expressionOne = getFirstExpr(); + + if(expressionOne == null) { + raiseValidateError("The first argument to " + _opcode + " cannot be null.", false, + LanguageErrorCodes.INVALID_PARAMETERS); + } + else if(expressionOne.getOutput() == null || expressionOne.getOutput().getDim1() == 0 || + expressionOne.getOutput().getDim2() == 0) { + raiseValidateError("The first argument to " + _opcode + " cannot be an empty matrix.", false, + LanguageErrorCodes.INVALID_PARAMETERS); + } + else if(expressionTwo != null) { + if(expressionTwo.getOutput() == null || expressionTwo.getOutput().getDim1() == 0 || + expressionTwo.getOutput().getDim2() == 0) { + raiseValidateError("The second argument to " + _opcode + + " cannot be an empty matrix. Provide either only a real matrix or a filled real and imaginary one.", + false, LanguageErrorCodes.INVALID_PARAMETERS); + } + } + + checkNumParameters(expressionTwo != null ? 2 : 1); + checkMatrixParam(expressionOne); + if(expressionTwo != null && expressionOne != null) { + checkMatrixParam(expressionTwo); + if(expressionOne.getOutput().getDim1() != expressionTwo.getOutput().getDim1() || + expressionOne.getOutput().getDim2() != expressionTwo.getOutput().getDim2()) + raiseValidateError("The real and imaginary part of the provided matrix are of different dimensions.", + false); + else if(!isPowerOfTwo(expressionTwo.getOutput().getDim1()) || + !isPowerOfTwo(expressionTwo.getOutput().getDim2())) { + raiseValidateError( + "This IFFT implementation is only defined for matrices with dimensions that are powers of 2.", false, + LanguageErrorCodes.INVALID_PARAMETERS); + } + } + else if(expressionOne != null) { + if(!isPowerOfTwo(expressionOne.getOutput().getDim1()) || + !isPowerOfTwo(expressionOne.getOutput().getDim2())) { + raiseValidateError( + "This IFFT implementation is only defined for matrices with dimensions that are powers of 2.", false, + LanguageErrorCodes.INVALID_PARAMETERS); + } + } - // setup output properties DataIdentifier ifftOut1 = (DataIdentifier) getOutputs()[0]; DataIdentifier ifftOut2 = (DataIdentifier) getOutputs()[1]; - // Output1 - ifft Values ifftOut1.setDataType(DataType.MATRIX); ifftOut1.setValueType(ValueType.FP64); ifftOut1.setDimensions(getFirstExpr().getOutput().getDim1(), getFirstExpr().getOutput().getDim2()); @@ -433,20 +491,24 @@ public class BuiltinFunctionExpression extends DataIdentifier { Expression expressionOne = getFirstExpr(); Expression expressionTwo = getSecondExpr(); - if (expressionOne == null) { - raiseValidateError("The first argument to " + _opcode + " cannot be null.", false, LanguageErrorCodes.INVALID_PARAMETERS); + if(expressionOne == null) { + raiseValidateError("The first argument to " + _opcode + " cannot be null.", false, + LanguageErrorCodes.INVALID_PARAMETERS); } - - else if (expressionOne.getOutput() == null || expressionOne.getOutput().getDim1() == 0 || expressionOne.getOutput().getDim2() == 0) { - raiseValidateError("The first argument to " + _opcode + " cannot be an empty matrix.", false, LanguageErrorCodes.INVALID_PARAMETERS); + else if(expressionOne.getOutput() == null || expressionOne.getOutput().getDim1() == 0 || + expressionOne.getOutput().getDim2() == 0) { + raiseValidateError("The first argument to " + _opcode + " cannot be an empty matrix.", false, + LanguageErrorCodes.INVALID_PARAMETERS); } - - else if (expressionTwo != null) { - raiseValidateError("Too many arguments. This FFT_LINEARIZED implementation is only defined for real inputs.", false, LanguageErrorCodes.INVALID_PARAMETERS); + else if(expressionTwo != null) { + raiseValidateError( + "Too many arguments. This FFT_LINEARIZED implementation is only defined for real inputs.", false, + LanguageErrorCodes.INVALID_PARAMETERS); } - - else if (!isPowerOfTwo(expressionOne.getOutput().getDim2())) { - raiseValidateError("This FFT_LINEARIZED implementation is only defined for matrices with columns that are powers of 2.", false, LanguageErrorCodes.INVALID_PARAMETERS); + else if(!isPowerOfTwo(expressionOne.getOutput().getDim2())) { + raiseValidateError( + "This FFT_LINEARIZED implementation is only defined for matrices with columns that are powers of 2.", + false, LanguageErrorCodes.INVALID_PARAMETERS); } checkNumParameters(1); @@ -472,33 +534,43 @@ public class BuiltinFunctionExpression extends DataIdentifier { Expression expressionTwo = getSecondExpr(); Expression expressionOne = getFirstExpr(); - if (expressionOne == null) { - raiseValidateError("The first argument to " + _opcode + " cannot be null.", false, LanguageErrorCodes.INVALID_PARAMETERS); + if(expressionOne == null) { + raiseValidateError("The first argument to " + _opcode + " cannot be null.", false, + LanguageErrorCodes.INVALID_PARAMETERS); } - - else if (expressionOne.getOutput() == null || expressionOne.getOutput().getDim1() == 0 || expressionOne.getOutput().getDim2() == 0) { - raiseValidateError("The first argument to " + _opcode + " cannot be an empty matrix.", false, LanguageErrorCodes.INVALID_PARAMETERS); + else if(expressionOne.getOutput() == null || expressionOne.getOutput().getDim1() == 0 || + expressionOne.getOutput().getDim2() == 0) { + raiseValidateError("The first argument to " + _opcode + " cannot be an empty matrix.", false, + LanguageErrorCodes.INVALID_PARAMETERS); } - - else if (expressionTwo != null){ - if(expressionTwo.getOutput() == null || expressionTwo.getOutput().getDim1() == 0 || expressionTwo.getOutput().getDim2() == 0) { - raiseValidateError("The second argument to " + _opcode + " cannot be an empty matrix. Provide either only a real matrix or a filled real and imaginary one.", false, LanguageErrorCodes.INVALID_PARAMETERS); + else if(expressionTwo != null) { + if(expressionTwo.getOutput() == null || expressionTwo.getOutput().getDim1() == 0 || + expressionTwo.getOutput().getDim2() == 0) { + raiseValidateError("The second argument to " + _opcode + + " cannot be an empty matrix. Provide either only a real matrix or a filled real and imaginary one.", + false, LanguageErrorCodes.INVALID_PARAMETERS); } } checkNumParameters(expressionTwo != null ? 2 : 1); checkMatrixParam(expressionOne); - if(expressionTwo != null && expressionOne != null){ + if(expressionTwo != null && expressionOne != null) { checkMatrixParam(expressionTwo); - if(expressionOne.getOutput().getDim1() != expressionTwo.getOutput().getDim1() || expressionOne.getOutput().getDim2() != expressionTwo.getOutput().getDim2()) - raiseValidateError("The real and imaginary part of the provided matrix are of different dimensions.", false); - else if (!isPowerOfTwo(expressionTwo.getOutput().getDim2())) { - raiseValidateError("This IFFT_LINEARIZED implementation is only defined for matrices with columns that are powers of 2.", false, LanguageErrorCodes.INVALID_PARAMETERS); + if(expressionOne.getOutput().getDim1() != expressionTwo.getOutput().getDim1() || + expressionOne.getOutput().getDim2() != expressionTwo.getOutput().getDim2()) + raiseValidateError("The real and imaginary part of the provided matrix are of different dimensions.", + false); + else if(!isPowerOfTwo(expressionTwo.getOutput().getDim2())) { + raiseValidateError( + "This IFFT_LINEARIZED implementation is only defined for matrices with columns that are powers of 2.", + false, LanguageErrorCodes.INVALID_PARAMETERS); } } - else if(expressionOne != null){ - if (!isPowerOfTwo(expressionOne.getOutput().getDim2())) { - raiseValidateError("This IFFT_LINEARIZED implementation is only defined for matrices with columns that are powers of 2.", false, LanguageErrorCodes.INVALID_PARAMETERS); + else if(expressionOne != null) { + if(!isPowerOfTwo(expressionOne.getOutput().getDim2())) { + raiseValidateError( + "This IFFT_LINEARIZED implementation is only defined for matrices with columns that are powers of 2.", + false, LanguageErrorCodes.INVALID_PARAMETERS); } } diff --git a/src/main/python/systemds/context/systemds_context.py b/src/main/python/systemds/context/systemds_context.py index 5f34086807..fa0073ffca 100644 --- a/src/main/python/systemds/context/systemds_context.py +++ b/src/main/python/systemds/context/systemds_context.py @@ -38,7 +38,7 @@ import numpy as np import pandas as pd from py4j.java_gateway import GatewayParameters, JavaGateway, Py4JNetworkError from systemds.operator import (Frame, List, Matrix, OperationNode, Scalar, - Source, Combine) + Source, Combine, MultiReturn) from systemds.script_building import DMLScript, OutputType from systemds.utils.consts import VALID_INPUT_TYPES from systemds.utils.helpers import get_module_dir @@ -402,6 +402,41 @@ class SystemDSContext(object): named_input_nodes = {'rows': shape[0], 'cols': shape[1]} return Matrix(self, 'matrix', unnamed_input_nodes, named_input_nodes) + + def fft(self, real_input: 'Matrix') -> 'MultiReturn': + """ + Performs the Fast Fourier Transform (FFT) on the matrix. + :param real_input: The real part of the input matrix. + :return: A MultiReturn object representing the real and imaginary parts of the FFT output. + """ + + real_output = OperationNode(self, '', output_type=OutputType.MATRIX, is_python_local_data=False) + imag_output = OperationNode(self, '', output_type=OutputType.MATRIX, is_python_local_data=False) + + fft_node = MultiReturn(self, 'fft', [real_output, imag_output], [real_input]) + + return fft_node + + + def ifft(self, real_input: 'Matrix', imag_input: 'Matrix' = None) -> 'MultiReturn': + """ + Performs the Inverse Fast Fourier Transform (IFFT) on a complex matrix. + + :param real_input: The real part of the input matrix. + :param imag_input: The imaginary part of the input matrix (optional). + :return: A MultiReturn object representing the real and imaginary parts of the IFFT output. + """ + + real_output = OperationNode(self, '', output_type=OutputType.MATRIX, is_python_local_data=False) + imag_output = OperationNode(self, '', output_type=OutputType.MATRIX, is_python_local_data=False) + + if imag_input is not None: + ifft_node = MultiReturn(self, 'ifft', [real_output, imag_output], [real_input, imag_input]) + else: + ifft_node = MultiReturn(self, 'ifft', [real_output, imag_output], [real_input]) + + return ifft_node + def seq(self, start: Union[float, int], stop: Union[float, int] = None, step: Union[float, int] = 1) -> 'Matrix': """Create a single column vector with values from `start` to `stop` and an increment of `step`. diff --git a/src/main/python/tests/matrix/test_fft.py b/src/main/python/tests/matrix/test_fft.py new file mode 100644 index 0000000000..e95a8f2c8d --- /dev/null +++ b/src/main/python/tests/matrix/test_fft.py @@ -0,0 +1,333 @@ +# ------------------------------------------------------------- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +# ------------------------------------------------------------- + +import unittest +import numpy as np +from systemds.context import SystemDSContext + +class TestFFT(unittest.TestCase): + def setUp(self): + self.sds = SystemDSContext() + + def tearDown(self): + self.sds.close() + + def test_fft_basic(self): + + input_matrix = np.array([[1, 2, 3, 4], + [5, 6, 7, 8], + [9, 10, 11, 12], + [13, 14, 15, 16]]) + + sds_input = self.sds.from_numpy(input_matrix) + fft_result = self.sds.fft(sds_input).compute() + + real_part, imag_part = fft_result + + np_fft_result = np.fft.fft2(input_matrix) + expected_real = np.real(np_fft_result) + expected_imag = np.imag(np_fft_result) + + np.testing.assert_array_almost_equal(real_part, expected_real, decimal=5) + np.testing.assert_array_almost_equal(imag_part, expected_imag, decimal=5) + + def test_fft_random_1d(self): + np.random.seed(123) + for _ in range(10): + input_matrix = np.random.rand(1, 16) + + sds_input = self.sds.from_numpy(input_matrix) + + fft_result = self.sds.fft(sds_input).compute() + + real_part, imag_part = fft_result + + np_fft_result = np.fft.fft(input_matrix[0]) + expected_real = np.real(np_fft_result) + expected_imag = np.imag(np_fft_result) + + np.testing.assert_array_almost_equal(real_part.flatten(), expected_real, decimal=5) + np.testing.assert_array_almost_equal(imag_part.flatten(), expected_imag, decimal=5) + + def test_fft_2d(self): + np.random.seed(123) + for _ in range(10): + input_matrix = np.random.rand(8, 8) + + sds_input = self.sds.from_numpy(input_matrix) + + fft_result = self.sds.fft(sds_input).compute() + + real_part, imag_part = fft_result + + np_fft_result = np.fft.fft2(input_matrix) + expected_real = np.real(np_fft_result) + expected_imag = np.imag(np_fft_result) + + np.testing.assert_array_almost_equal(real_part, expected_real, decimal=5) + np.testing.assert_array_almost_equal(imag_part, expected_imag, decimal=5) + + def test_fft_non_power_of_two_matrix(self): + + input_matrix = np.random.rand(3, 5) + sds_input = self.sds.from_numpy(input_matrix) + + with self.assertRaisesRegex(RuntimeError, "This FFT implementation is only defined for matrices with dimensions that are powers of 2."): + _ = self.sds.fft(sds_input).compute() + + def test_ifft_basic(self): + real_input_matrix = np.array([[1, 2, 3, 4], + [5, 6, 7, 8], + [9, 10, 11, 12], + [13, 14, 15, 16]]) + + imag_input_matrix = np.array([[1, 2, 3, 4], + [5, 6, 7, 8], + [9, 10, 11, 12], + [13, 14, 15, 16]]) + + sds_real_input = self.sds.from_numpy(real_input_matrix) + sds_imag_input = self.sds.from_numpy(imag_input_matrix) + + ifft_result = self.sds.ifft(sds_real_input, sds_imag_input).compute() + + real_part, imag_part = ifft_result + + np_ifft_result = np.fft.ifft2(real_input_matrix + 1j * imag_input_matrix) + expected_real = np.real(np_ifft_result) + expected_imag = np.imag(np_ifft_result) + + + np.testing.assert_array_almost_equal(real_part, expected_real, decimal=5) + np.testing.assert_array_almost_equal(imag_part, expected_imag, decimal=5) + + def test_ifft_only_zeros_imag(self): + real_input_matrix = np.array([[1, 2, 3, 4], + [5, 6, 7, 8], + [9, 10, 11, 12], + [13, 14, 15, 16]]) + + imag_input_matrix = np.array([[0, 0, 0, 0], + [0, 0, 0, 0], + [0, 0, 0, 0], + [0, 0, 0, 0]]) + + sds_real_input = self.sds.from_numpy(real_input_matrix) + sds_imag_input = self.sds.from_numpy(imag_input_matrix) + + ifft_result = self.sds.ifft(sds_real_input, sds_imag_input).compute() + + real_part, imag_part = ifft_result + + np_ifft_result = np.fft.ifft2(real_input_matrix + 1j * imag_input_matrix) + expected_real = np.real(np_ifft_result) + expected_imag = np.imag(np_ifft_result) + + + np.testing.assert_array_almost_equal(real_part, expected_real, decimal=5) + np.testing.assert_array_almost_equal(imag_part, expected_imag, decimal=5) + + def test_ifft_empty_matrix_imag(self): + real_input_matrix = np.array([[1, 2, 3, 4], + [5, 6, 7, 8], + [9, 10, 11, 12], + [13, 14, 15, 16]]) + + imag_input_matrix = np.array([]) + + sds_real_input = self.sds.from_numpy(real_input_matrix) + sds_imag_input = self.sds.from_numpy(imag_input_matrix) + + with self.assertRaisesRegex(RuntimeError, "The second argument to IFFT cannot be an empty matrix. Provide either only a real matrix or a filled real and imaginary one."): + self.sds.ifft(sds_real_input, sds_imag_input).compute() + + def test_ifft_empty_2dmatrix_imag(self): + real_input_matrix = np.array([[1, 2, 3, 4], + [5, 6, 7, 8], + [9, 10, 11, 12], + [13, 14, 15, 16]]) + + imag_input_matrix = np.array([[]]) + + sds_real_input = self.sds.from_numpy(real_input_matrix) + sds_imag_input = self.sds.from_numpy(imag_input_matrix) + + with self.assertRaisesRegex(RuntimeError, "The second argument to IFFT cannot be an empty matrix. Provide either only a real matrix or a filled real and imaginary one."): + self.sds.ifft(sds_real_input, sds_imag_input).compute() + + def test_ifft_random_1d(self): + np.random.seed(123) + for _ in range(10): + real_part = np.random.rand(1, 16) + imag_part = np.random.rand(1, 16) + complex_input = real_part + 1j * imag_part + + np_fft_result = np.fft.fft(complex_input[0]) + + sds_real_input = self.sds.from_numpy(np.real(np_fft_result).reshape(1, -1)) + sds_imag_input = self.sds.from_numpy(np.imag(np_fft_result).reshape(1, -1)) + + ifft_result = self.sds.ifft(sds_real_input, sds_imag_input).compute() + + real_part_result, imag_part_result = ifft_result + + real_part_result = real_part_result.flatten() + imag_part_result = imag_part_result.flatten() + + expected_ifft = np.fft.ifft(np_fft_result) + expected_real = np.real(expected_ifft) + expected_imag = np.imag(expected_ifft) + + np.testing.assert_array_almost_equal(real_part_result, expected_real, decimal=5) + np.testing.assert_array_almost_equal(imag_part_result, expected_imag, decimal=5) + + def test_ifft_real_only_basic(self): + np.random.seed(123) + real = np.array([1, 2, 3, 4, 4, 3, 2, 1]) + + sds_real_input = self.sds.from_numpy(real) + + ifft_result = self.sds.ifft(sds_real_input).compute() + + real_part_result, imag_part_result = ifft_result + + real_part_result = real_part_result.flatten() + imag_part_result = imag_part_result.flatten() + + expected_ifft = np.fft.ifft(real) + expected_real = np.real(expected_ifft) + expected_imag = np.imag(expected_ifft) + + np.testing.assert_array_almost_equal(real_part_result, expected_real, decimal=5) + np.testing.assert_array_almost_equal(imag_part_result, expected_imag, decimal=5) + + def test_ifft_real_only_random(self): + np.random.seed(123) + for _ in range(10): + input_matrix = np.random.rand(1, 16) + + sds_input = self.sds.from_numpy(input_matrix) + + ifft_result = self.sds.ifft(sds_input).compute() + + real_part, imag_part = ifft_result + + np_ifft_result = np.fft.ifft(input_matrix[0]) + expected_real = np.real(np_ifft_result) + expected_imag = np.imag(np_ifft_result) + + np.testing.assert_array_almost_equal(real_part.flatten(), expected_real, decimal=5) + np.testing.assert_array_almost_equal(imag_part.flatten(), expected_imag, decimal=5) + + + def test_ifft_2d(self): + np.random.seed(123) + for _ in range(10): + input_matrix = np.random.rand(8, 8) + 1j * np.random.rand(8, 8) + + fft_result = np.fft.fft2(input_matrix) + + sds_real_input = self.sds.from_numpy(np.real(fft_result)) + sds_imag_input = self.sds.from_numpy(np.imag(fft_result)) + + ifft_result = self.sds.ifft(sds_real_input, sds_imag_input).compute() + + real_part, imag_part = ifft_result + + expected_ifft_result = np.fft.ifft2(fft_result) + expected_real = np.real(expected_ifft_result) + expected_imag = np.imag(expected_ifft_result) + + np.testing.assert_array_almost_equal(real_part, expected_real, decimal=5) + np.testing.assert_array_almost_equal(imag_part, expected_imag, decimal=5) + + def test_fft_empty_matrix(self): + input_matrix = np.array([]) + sds_input = self.sds.from_numpy(input_matrix) + + with self.assertRaisesRegex(RuntimeError, "The first argument to FFT cannot be an empty matrix."): + _ = self.sds.fft(sds_input).compute() + + def test_ifft_empty_matrix(self): + input_matrix = np.array([]) + sds_input = self.sds.from_numpy(input_matrix) + + with self.assertRaisesRegex(RuntimeError, "The first argument to IFFT cannot be an empty matrix."): + _ = self.sds.ifft(sds_input).compute() + + def test_fft_single_element(self): + input_matrix = np.array([[5]]) + sds_input = self.sds.from_numpy(input_matrix) + fft_result = self.sds.fft(sds_input).compute() + + real_part, imag_part = fft_result + np.testing.assert_array_almost_equal(real_part, [[5]], decimal=5) + np.testing.assert_array_almost_equal(imag_part, [[0]], decimal=5) + + def test_ifft_single_element(self): + input_matrix = np.array([[5]]) + sds_input = self.sds.from_numpy(input_matrix) + ifft_result = self.sds.ifft(sds_input).compute() + + real_part, imag_part = ifft_result + np.testing.assert_array_almost_equal(real_part, [[5]], decimal=5) + np.testing.assert_array_almost_equal(imag_part, [[0]], decimal=5) + + def test_fft_zeros_matrix(self): + input_matrix = np.zeros((4, 4)) + sds_input = self.sds.from_numpy(input_matrix) + fft_result = self.sds.fft(sds_input).compute() + + real_part, imag_part = fft_result + np.testing.assert_array_almost_equal(real_part, np.zeros((4, 4)), decimal=5) + np.testing.assert_array_almost_equal(imag_part, np.zeros((4, 4)), decimal=5) + + def test_ifft_zeros_matrix(self): + input_matrix = np.zeros((4, 4)) + sds_input = self.sds.from_numpy(input_matrix) + ifft_result = self.sds.ifft(sds_input).compute() + + real_part, imag_part = ifft_result + np.testing.assert_array_almost_equal(real_part, np.zeros((4, 4)), decimal=5) + np.testing.assert_array_almost_equal(imag_part, np.zeros((4, 4)), decimal=5) + + def test_ifft_real_and_imaginary_dimensions_check(self): + real_part = np.random.rand(1, 16) + imag_part = np.random.rand(1, 14) + + sds_real_input = self.sds.from_numpy(real_part) + sds_imag_input = self.sds.from_numpy(imag_part) + + with self.assertRaisesRegex(RuntimeError, "The real and imaginary part of the provided matrix are of different dimensions."): + self.sds.ifft(sds_real_input, sds_imag_input).compute() + + def test_ifft_non_power_of_two_matrix(self): + real_part = np.random.rand(3, 5) + imag_part = np.random.rand(3, 5) + + sds_real_input = self.sds.from_numpy(real_part) + sds_imag_input = self.sds.from_numpy(imag_part) + + with self.assertRaisesRegex(RuntimeError, "This IFFT implementation is only defined for matrices with dimensions that are powers of 2."): + _ = self.sds.ifft(sds_real_input, sds_imag_input).compute() + +if __name__ == '__main__': + unittest.main() \ No newline at end of file
