This is an automated email from the ASF dual-hosted git repository.

baunsgaard pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/systemds.git

commit d22ffeccd7b10af366574f7fe03d637be9db49d5
Author: Frederic Caspar Zoepffel <[email protected]>
AuthorDate: Thu Apr 4 17:17:44 2024 +0200

    [SYSTEMDS-3685] Python FFT
    
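    This commit adds fft and ifft to the Python API (SystemDSContext.fft and
    SystemDSContext.ifft) and tightens the validation of the FFT and IFFT
    builtins: inputs must be non-empty, real and imaginary parts must have
    matching dimensions, and both dimensions must be powers of two.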
    
    Future work is to expose the linearized variants (FFT_LINEARIZED and
    IFFT_LINEARIZED) in the Python API as well.
    
    LDE 23/24 project
    
    Co-authored-by: Mufan Wang <[email protected]>
    Co-authored-by: Frederic Caspar Zoepffel <[email protected]>
    Co-authored-by: Jessica Eva Sophie Priebe <[email protected]>
    
    Closes #1983
---
 .../sysds/parser/BuiltinFunctionExpression.java    | 148 ++++++---
 .../python/systemds/context/systemds_context.py    |  37 ++-
 src/main/python/tests/matrix/test_fft.py           | 333 +++++++++++++++++++++
 3 files changed, 479 insertions(+), 39 deletions(-)

diff --git 
a/src/main/java/org/apache/sysds/parser/BuiltinFunctionExpression.java 
b/src/main/java/org/apache/sysds/parser/BuiltinFunctionExpression.java
index 5e86a2fd8e..4b3c8e82f7 100644
--- a/src/main/java/org/apache/sysds/parser/BuiltinFunctionExpression.java
+++ b/src/main/java/org/apache/sysds/parser/BuiltinFunctionExpression.java
@@ -381,20 +381,41 @@ public class BuiltinFunctionExpression extends 
DataIdentifier {
                        break;
                }
                case FFT: {
+
+                       Expression expressionOne = getFirstExpr();
+                       Expression expressionTwo = getSecondExpr();
+
+                       if(expressionOne == null) {
+                               raiseValidateError("The first argument to " + 
_opcode + " cannot be null.", false,
+                                       LanguageErrorCodes.INVALID_PARAMETERS);
+                       }
+                       else if(expressionOne.getOutput() == null || 
expressionOne.getOutput().getDim1() == 0 ||
+                               expressionOne.getOutput().getDim2() == 0) {
+                               raiseValidateError("The first argument to " + 
_opcode + " cannot be an empty matrix.", false,
+                                       LanguageErrorCodes.INVALID_PARAMETERS);
+                       }
+                       else if(expressionTwo != null) {
+                               raiseValidateError("Too many arguments. This 
FFT implementation is only defined for real inputs.", false,
+                                       LanguageErrorCodes.INVALID_PARAMETERS);
+                       }
+                       else 
if(!isPowerOfTwo(expressionOne.getOutput().getDim1()) ||
+                               
!isPowerOfTwo(expressionOne.getOutput().getDim2())) {
+                               raiseValidateError(
+                                       "This FFT implementation is only 
defined for matrices with dimensions that are powers of 2.", false,
+                                       LanguageErrorCodes.INVALID_PARAMETERS);
+                       }
+
                        checkNumParameters(1);
-                       checkMatrixParam(getFirstExpr());
+                       checkMatrixParam(expressionOne);
 
-                       // setup output properties
                        DataIdentifier fftOut1 = (DataIdentifier) 
getOutputs()[0];
                        DataIdentifier fftOut2 = (DataIdentifier) 
getOutputs()[1];
 
-                       // Output1 - FFT Values
                        fftOut1.setDataType(DataType.MATRIX);
                        fftOut1.setValueType(ValueType.FP64);
                        
fftOut1.setDimensions(getFirstExpr().getOutput().getDim1(), 
getFirstExpr().getOutput().getDim2());
                        
fftOut1.setBlocksize(getFirstExpr().getOutput().getBlocksize());
 
-                       // Output2 - FFT Vectors
                        fftOut2.setDataType(DataType.MATRIX);
                        fftOut2.setValueType(ValueType.FP64);
                        
fftOut2.setDimensions(getFirstExpr().getOutput().getDim1(), 
getFirstExpr().getOutput().getDim2());
@@ -405,16 +426,53 @@ public class BuiltinFunctionExpression extends 
DataIdentifier {
                }
                case IFFT: {
                        Expression expressionTwo = getSecondExpr();
-                       checkNumParameters(getSecondExpr() != null ? 2 : 1);
-                       checkMatrixParam(getFirstExpr());
-                       if (expressionTwo != null)
-                               checkMatrixParam(getSecondExpr());
+                       Expression expressionOne = getFirstExpr();
+
+                       if(expressionOne == null) {
+                               raiseValidateError("The first argument to " + 
_opcode + " cannot be null.", false,
+                                       LanguageErrorCodes.INVALID_PARAMETERS);
+                       }
+                       else if(expressionOne.getOutput() == null || 
expressionOne.getOutput().getDim1() == 0 ||
+                               expressionOne.getOutput().getDim2() == 0) {
+                               raiseValidateError("The first argument to " + 
_opcode + " cannot be an empty matrix.", false,
+                                       LanguageErrorCodes.INVALID_PARAMETERS);
+                       }
+                       else if(expressionTwo != null) {
+                               if(expressionTwo.getOutput() == null || 
expressionTwo.getOutput().getDim1() == 0 ||
+                                       expressionTwo.getOutput().getDim2() == 
0) {
+                                       raiseValidateError("The second argument 
to " + _opcode
+                                               + " cannot be an empty matrix. 
Provide either only a real matrix or a filled real and imaginary one.",
+                                               false, 
LanguageErrorCodes.INVALID_PARAMETERS);
+                               }
+                       }
+
+                       checkNumParameters(expressionTwo != null ? 2 : 1);
+                       checkMatrixParam(expressionOne);
+                       if(expressionTwo != null && expressionOne != null) {
+                               checkMatrixParam(expressionTwo);
+                               if(expressionOne.getOutput().getDim1() != 
expressionTwo.getOutput().getDim1() ||
+                                       expressionOne.getOutput().getDim2() != 
expressionTwo.getOutput().getDim2())
+                                       raiseValidateError("The real and 
imaginary part of the provided matrix are of different dimensions.",
+                                               false);
+                               else 
if(!isPowerOfTwo(expressionTwo.getOutput().getDim1()) ||
+                                       
!isPowerOfTwo(expressionTwo.getOutput().getDim2())) {
+                                       raiseValidateError(
+                                               "This IFFT implementation is 
only defined for matrices with dimensions that are powers of 2.", false,
+                                               
LanguageErrorCodes.INVALID_PARAMETERS);
+                               }
+                       }
+                       else if(expressionOne != null) {
+                               
if(!isPowerOfTwo(expressionOne.getOutput().getDim1()) ||
+                                       
!isPowerOfTwo(expressionOne.getOutput().getDim2())) {
+                                       raiseValidateError(
+                                               "This IFFT implementation is 
only defined for matrices with dimensions that are powers of 2.", false,
+                                               
LanguageErrorCodes.INVALID_PARAMETERS);
+                               }
+                       }
 
-                       // setup output properties
                        DataIdentifier ifftOut1 = (DataIdentifier) 
getOutputs()[0];
                        DataIdentifier ifftOut2 = (DataIdentifier) 
getOutputs()[1];
 
-                       // Output1 - ifft Values
                        ifftOut1.setDataType(DataType.MATRIX);
                        ifftOut1.setValueType(ValueType.FP64);
                        
ifftOut1.setDimensions(getFirstExpr().getOutput().getDim1(), 
getFirstExpr().getOutput().getDim2());
@@ -433,20 +491,24 @@ public class BuiltinFunctionExpression extends 
DataIdentifier {
                        Expression expressionOne = getFirstExpr();
                        Expression expressionTwo = getSecondExpr();
 
-                       if (expressionOne == null) {
-                               raiseValidateError("The first argument to " + 
_opcode + " cannot be null.", false, LanguageErrorCodes.INVALID_PARAMETERS);
+                       if(expressionOne == null) {
+                               raiseValidateError("The first argument to " + 
_opcode + " cannot be null.", false,
+                                       LanguageErrorCodes.INVALID_PARAMETERS);
                        }
-
-                       else if (expressionOne.getOutput() == null || 
expressionOne.getOutput().getDim1() == 0 || expressionOne.getOutput().getDim2() 
== 0) {
-                               raiseValidateError("The first argument to " + 
_opcode + " cannot be an empty matrix.", false, 
LanguageErrorCodes.INVALID_PARAMETERS);
+                       else if(expressionOne.getOutput() == null || 
expressionOne.getOutput().getDim1() == 0 ||
+                               expressionOne.getOutput().getDim2() == 0) {
+                               raiseValidateError("The first argument to " + 
_opcode + " cannot be an empty matrix.", false,
+                                       LanguageErrorCodes.INVALID_PARAMETERS);
                        }
-
-                       else if (expressionTwo != null) {
-                               raiseValidateError("Too many arguments. This 
FFT_LINEARIZED implementation is only defined for real inputs.", false, 
LanguageErrorCodes.INVALID_PARAMETERS);
+                       else if(expressionTwo != null) {
+                               raiseValidateError(
+                                       "Too many arguments. This 
FFT_LINEARIZED implementation is only defined for real inputs.", false,
+                                       LanguageErrorCodes.INVALID_PARAMETERS);
                        }
-
-                       else if 
(!isPowerOfTwo(expressionOne.getOutput().getDim2())) {
-                               raiseValidateError("This FFT_LINEARIZED 
implementation is only defined for matrices with columns that are powers of 
2.", false, LanguageErrorCodes.INVALID_PARAMETERS);
+                       else 
if(!isPowerOfTwo(expressionOne.getOutput().getDim2())) {
+                               raiseValidateError(
+                                       "This FFT_LINEARIZED implementation is 
only defined for matrices with columns that are powers of 2.",
+                                       false, 
LanguageErrorCodes.INVALID_PARAMETERS);
                        }
 
                        checkNumParameters(1);
@@ -472,33 +534,43 @@ public class BuiltinFunctionExpression extends 
DataIdentifier {
                        Expression expressionTwo = getSecondExpr();
                        Expression expressionOne = getFirstExpr();
 
-                       if (expressionOne == null) {
-                               raiseValidateError("The first argument to " + 
_opcode + " cannot be null.", false, LanguageErrorCodes.INVALID_PARAMETERS);
+                       if(expressionOne == null) {
+                               raiseValidateError("The first argument to " + 
_opcode + " cannot be null.", false,
+                                       LanguageErrorCodes.INVALID_PARAMETERS);
                        }
-
-                       else if (expressionOne.getOutput() == null || 
expressionOne.getOutput().getDim1() == 0 || expressionOne.getOutput().getDim2() 
== 0) {
-                               raiseValidateError("The first argument to " + 
_opcode + " cannot be an empty matrix.", false, 
LanguageErrorCodes.INVALID_PARAMETERS);
+                       else if(expressionOne.getOutput() == null || 
expressionOne.getOutput().getDim1() == 0 ||
+                               expressionOne.getOutput().getDim2() == 0) {
+                               raiseValidateError("The first argument to " + 
_opcode + " cannot be an empty matrix.", false,
+                                       LanguageErrorCodes.INVALID_PARAMETERS);
                        }
-
-                       else if (expressionTwo != null){
-                               if(expressionTwo.getOutput() == null || 
expressionTwo.getOutput().getDim1() == 0 || expressionTwo.getOutput().getDim2() 
== 0) {
-                                       raiseValidateError("The second argument 
to " + _opcode + " cannot be an empty matrix. Provide either only a real matrix 
or a filled real and imaginary one.", false, 
LanguageErrorCodes.INVALID_PARAMETERS);
+                       else if(expressionTwo != null) {
+                               if(expressionTwo.getOutput() == null || 
expressionTwo.getOutput().getDim1() == 0 ||
+                                       expressionTwo.getOutput().getDim2() == 
0) {
+                                       raiseValidateError("The second argument 
to " + _opcode
+                                               + " cannot be an empty matrix. 
Provide either only a real matrix or a filled real and imaginary one.",
+                                               false, 
LanguageErrorCodes.INVALID_PARAMETERS);
                                }
                        }
 
                        checkNumParameters(expressionTwo != null ? 2 : 1);
                        checkMatrixParam(expressionOne);
-                       if(expressionTwo != null && expressionOne != null){
+                       if(expressionTwo != null && expressionOne != null) {
                                checkMatrixParam(expressionTwo);
-                               if(expressionOne.getOutput().getDim1() != 
expressionTwo.getOutput().getDim1() || expressionOne.getOutput().getDim2() != 
expressionTwo.getOutput().getDim2())
-                                       raiseValidateError("The real and 
imaginary part of the provided matrix are of different dimensions.", false);
-                               else if 
(!isPowerOfTwo(expressionTwo.getOutput().getDim2())) {
-                                       raiseValidateError("This 
IFFT_LINEARIZED implementation is only defined for matrices with columns that 
are powers of 2.", false, LanguageErrorCodes.INVALID_PARAMETERS);
+                               if(expressionOne.getOutput().getDim1() != 
expressionTwo.getOutput().getDim1() ||
+                                       expressionOne.getOutput().getDim2() != 
expressionTwo.getOutput().getDim2())
+                                       raiseValidateError("The real and 
imaginary part of the provided matrix are of different dimensions.",
+                                               false);
+                               else 
if(!isPowerOfTwo(expressionTwo.getOutput().getDim2())) {
+                                       raiseValidateError(
+                                               "This IFFT_LINEARIZED 
implementation is only defined for matrices with columns that are powers of 2.",
+                                               false, 
LanguageErrorCodes.INVALID_PARAMETERS);
                                }
                        }
-                       else if(expressionOne != null){
-                               if 
(!isPowerOfTwo(expressionOne.getOutput().getDim2())) {
-                                       raiseValidateError("This 
IFFT_LINEARIZED implementation is only defined for matrices with columns that 
are powers of 2.", false, LanguageErrorCodes.INVALID_PARAMETERS);
+                       else if(expressionOne != null) {
+                               
if(!isPowerOfTwo(expressionOne.getOutput().getDim2())) {
+                                       raiseValidateError(
+                                               "This IFFT_LINEARIZED 
implementation is only defined for matrices with columns that are powers of 2.",
+                                               false, 
LanguageErrorCodes.INVALID_PARAMETERS);
                                }
                        }
 
diff --git a/src/main/python/systemds/context/systemds_context.py b/src/main/python/systemds/context/systemds_context.py
index 5f34086807..fa0073ffca 100644
--- a/src/main/python/systemds/context/systemds_context.py
+++ b/src/main/python/systemds/context/systemds_context.py
@@ -38,7 +38,7 @@ import numpy as np
 import pandas as pd
 from py4j.java_gateway import GatewayParameters, JavaGateway, Py4JNetworkError
 from systemds.operator import (Frame, List, Matrix, OperationNode, Scalar,
-                               Source, Combine)
+                               Source, Combine, MultiReturn)
 from systemds.script_building import DMLScript, OutputType
 from systemds.utils.consts import VALID_INPUT_TYPES
 from systemds.utils.helpers import get_module_dir
@@ -402,6 +402,41 @@ class SystemDSContext(object):
         named_input_nodes = {'rows': shape[0], 'cols': shape[1]}
         return Matrix(self, 'matrix', unnamed_input_nodes, named_input_nodes)
 
+    def fft(self, real_input: 'Matrix') -> 'MultiReturn':
+        """
+        Performs the two-dimensional Fast Fourier Transform (FFT) on a real matrix.
+
+        Both dimensions of ``real_input`` must be powers of two; otherwise the
+        script fails validation when it is computed.
+
+        :param real_input: The real input matrix.
+        :return: A MultiReturn object whose two outputs are the real and the
+            imaginary part of the FFT output.
+        """
+        real_output = OperationNode(self, '', output_type=OutputType.MATRIX, is_python_local_data=False)
+        imag_output = OperationNode(self, '', output_type=OutputType.MATRIX, is_python_local_data=False)
+        return MultiReturn(self, 'fft', [real_output, imag_output], [real_input])
+
+    def ifft(self, real_input: 'Matrix', imag_input: 'Matrix' = None) -> 'MultiReturn':
+        """
+        Performs the two-dimensional Inverse Fast Fourier Transform (IFFT) on a complex matrix.
+
+        The complex input is passed as separate real and imaginary parts. If
+        ``imag_input`` is omitted, the input is treated as purely real. Both
+        parts must have the same dimensions, and those dimensions must be
+        powers of two; otherwise the script fails validation when it is computed.
+
+        :param real_input: The real part of the input matrix.
+        :param imag_input: The imaginary part of the input matrix (optional).
+        :return: A MultiReturn object whose two outputs are the real and the
+            imaginary part of the IFFT output.
+        """
+        real_output = OperationNode(self, '', output_type=OutputType.MATRIX, is_python_local_data=False)
+        imag_output = OperationNode(self, '', output_type=OutputType.MATRIX, is_python_local_data=False)
+        # Forward the imaginary part only when it is given (one- vs. two-argument ifft).
+        inputs = [real_input] if imag_input is None else [real_input, imag_input]
+        return MultiReturn(self, 'ifft', [real_output, imag_output], inputs)
+
     def seq(self, start: Union[float, int], stop: Union[float, int] = None,
             step: Union[float, int] = 1) -> 'Matrix':
         """Create a single column vector with values from `start` to `stop` and an increment of `step`.
diff --git a/src/main/python/tests/matrix/test_fft.py b/src/main/python/tests/matrix/test_fft.py
new file mode 100644
index 0000000000..e95a8f2c8d
--- /dev/null
+++ b/src/main/python/tests/matrix/test_fft.py
@@ -0,0 +1,329 @@
+# -------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+# -------------------------------------------------------------
+
+import unittest
+import numpy as np
+from systemds.context import SystemDSContext
+
+
+class TestFFT(unittest.TestCase):
+    def setUp(self):
+        self.sds = SystemDSContext()
+
+    def tearDown(self):
+        self.sds.close()
+
+    def test_fft_basic(self):
+        input_matrix = np.array([[1, 2, 3, 4],
+                                 [5, 6, 7, 8],
+                                 [9, 10, 11, 12],
+                                 [13, 14, 15, 16]])
+
+        sds_input = self.sds.from_numpy(input_matrix)
+        fft_result = self.sds.fft(sds_input).compute()
+
+        real_part, imag_part = fft_result
+
+        np_fft_result = np.fft.fft2(input_matrix)
+        expected_real = np.real(np_fft_result)
+        expected_imag = np.imag(np_fft_result)
+
+        np.testing.assert_array_almost_equal(real_part, expected_real, decimal=5)
+        np.testing.assert_array_almost_equal(imag_part, expected_imag, decimal=5)
+
+    def test_fft_random_1d(self):
+        np.random.seed(123)
+        for _ in range(10):
+            input_matrix = np.random.rand(1, 16)
+
+            sds_input = self.sds.from_numpy(input_matrix)
+
+            fft_result = self.sds.fft(sds_input).compute()
+
+            real_part, imag_part = fft_result
+
+            np_fft_result = np.fft.fft(input_matrix[0])
+            expected_real = np.real(np_fft_result)
+            expected_imag = np.imag(np_fft_result)
+
+            np.testing.assert_array_almost_equal(real_part.flatten(), expected_real, decimal=5)
+            np.testing.assert_array_almost_equal(imag_part.flatten(), expected_imag, decimal=5)
+
+    def test_fft_2d(self):
+        np.random.seed(123)
+        for _ in range(10):
+            input_matrix = np.random.rand(8, 8)
+
+            sds_input = self.sds.from_numpy(input_matrix)
+
+            fft_result = self.sds.fft(sds_input).compute()
+
+            real_part, imag_part = fft_result
+
+            np_fft_result = np.fft.fft2(input_matrix)
+            expected_real = np.real(np_fft_result)
+            expected_imag = np.imag(np_fft_result)
+
+            np.testing.assert_array_almost_equal(real_part, expected_real, decimal=5)
+            np.testing.assert_array_almost_equal(imag_part, expected_imag, decimal=5)
+
+    def test_fft_non_power_of_two_matrix(self):
+        input_matrix = np.random.rand(3, 5)
+        sds_input = self.sds.from_numpy(input_matrix)
+
+        with self.assertRaisesRegex(RuntimeError, "This FFT implementation is only defined for matrices with dimensions that are powers of 2."):
+            _ = self.sds.fft(sds_input).compute()
+
+    def test_ifft_basic(self):
+        real_input_matrix = np.array([[1, 2, 3, 4],
+                                       [5, 6, 7, 8],
+                                       [9, 10, 11, 12],
+                                       [13, 14, 15, 16]])
+
+        imag_input_matrix = np.array([[1, 2, 3, 4],
+                                       [5, 6, 7, 8],
+                                       [9, 10, 11, 12],
+                                       [13, 14, 15, 16]])
+
+        sds_real_input = self.sds.from_numpy(real_input_matrix)
+        sds_imag_input = self.sds.from_numpy(imag_input_matrix)
+
+        ifft_result = self.sds.ifft(sds_real_input, sds_imag_input).compute()
+
+        real_part, imag_part = ifft_result
+
+        np_ifft_result = np.fft.ifft2(real_input_matrix + 1j * imag_input_matrix)
+        expected_real = np.real(np_ifft_result)
+        expected_imag = np.imag(np_ifft_result)
+
+        np.testing.assert_array_almost_equal(real_part, expected_real, decimal=5)
+        np.testing.assert_array_almost_equal(imag_part, expected_imag, decimal=5)
+
+    def test_ifft_only_zeros_imag(self):
+        real_input_matrix = np.array([[1, 2, 3, 4],
+                                       [5, 6, 7, 8],
+                                       [9, 10, 11, 12],
+                                       [13, 14, 15, 16]])
+
+        imag_input_matrix = np.array([[0, 0, 0, 0],
+                                       [0, 0, 0, 0],
+                                       [0, 0, 0, 0],
+                                       [0, 0, 0, 0]])
+
+        sds_real_input = self.sds.from_numpy(real_input_matrix)
+        sds_imag_input = self.sds.from_numpy(imag_input_matrix)
+
+        ifft_result = self.sds.ifft(sds_real_input, sds_imag_input).compute()
+
+        real_part, imag_part = ifft_result
+
+        np_ifft_result = np.fft.ifft2(real_input_matrix + 1j * imag_input_matrix)
+        expected_real = np.real(np_ifft_result)
+        expected_imag = np.imag(np_ifft_result)
+
+        np.testing.assert_array_almost_equal(real_part, expected_real, decimal=5)
+        np.testing.assert_array_almost_equal(imag_part, expected_imag, decimal=5)
+
+    def test_ifft_empty_matrix_imag(self):
+        real_input_matrix = np.array([[1, 2, 3, 4],
+                                       [5, 6, 7, 8],
+                                       [9, 10, 11, 12],
+                                       [13, 14, 15, 16]])
+
+        imag_input_matrix = np.array([])
+
+        sds_real_input = self.sds.from_numpy(real_input_matrix)
+        sds_imag_input = self.sds.from_numpy(imag_input_matrix)
+
+        with self.assertRaisesRegex(RuntimeError, "The second argument to IFFT cannot be an empty matrix. Provide either only a real matrix or a filled real and imaginary one."):
+            self.sds.ifft(sds_real_input, sds_imag_input).compute()
+
+    def test_ifft_empty_2dmatrix_imag(self):
+        real_input_matrix = np.array([[1, 2, 3, 4],
+                                       [5, 6, 7, 8],
+                                       [9, 10, 11, 12],
+                                       [13, 14, 15, 16]])
+
+        imag_input_matrix = np.array([[]])
+
+        sds_real_input = self.sds.from_numpy(real_input_matrix)
+        sds_imag_input = self.sds.from_numpy(imag_input_matrix)
+
+        with self.assertRaisesRegex(RuntimeError, "The second argument to IFFT cannot be an empty matrix. Provide either only a real matrix or a filled real and imaginary one."):
+            self.sds.ifft(sds_real_input, sds_imag_input).compute()
+
+    def test_ifft_random_1d(self):
+        np.random.seed(123)
+        for _ in range(10):
+            real_part = np.random.rand(1, 16)
+            imag_part = np.random.rand(1, 16)
+            complex_input = real_part + 1j * imag_part
+
+            np_fft_result = np.fft.fft(complex_input[0])
+
+            sds_real_input = self.sds.from_numpy(np.real(np_fft_result).reshape(1, -1))
+            sds_imag_input = self.sds.from_numpy(np.imag(np_fft_result).reshape(1, -1))
+
+            ifft_result = self.sds.ifft(sds_real_input, sds_imag_input).compute()
+
+            real_part_result, imag_part_result = ifft_result
+
+            real_part_result = real_part_result.flatten()
+            imag_part_result = imag_part_result.flatten()
+
+            expected_ifft = np.fft.ifft(np_fft_result)
+            expected_real = np.real(expected_ifft)
+            expected_imag = np.imag(expected_ifft)
+
+            np.testing.assert_array_almost_equal(real_part_result, expected_real, decimal=5)
+            np.testing.assert_array_almost_equal(imag_part_result, expected_imag, decimal=5)
+
+    def test_ifft_real_only_basic(self):
+        real = np.array([1, 2, 3, 4, 4, 3, 2, 1])
+
+        sds_real_input = self.sds.from_numpy(real)
+
+        ifft_result = self.sds.ifft(sds_real_input).compute()
+
+        real_part_result, imag_part_result = ifft_result
+
+        real_part_result = real_part_result.flatten()
+        imag_part_result = imag_part_result.flatten()
+
+        expected_ifft = np.fft.ifft(real)
+        expected_real = np.real(expected_ifft)
+        expected_imag = np.imag(expected_ifft)
+
+        np.testing.assert_array_almost_equal(real_part_result, expected_real, decimal=5)
+        np.testing.assert_array_almost_equal(imag_part_result, expected_imag, decimal=5)
+
+    def test_ifft_real_only_random(self):
+        np.random.seed(123)
+        for _ in range(10):
+            input_matrix = np.random.rand(1, 16)
+
+            sds_input = self.sds.from_numpy(input_matrix)
+
+            ifft_result = self.sds.ifft(sds_input).compute()
+
+            real_part, imag_part = ifft_result
+
+            np_ifft_result = np.fft.ifft(input_matrix[0])
+            expected_real = np.real(np_ifft_result)
+            expected_imag = np.imag(np_ifft_result)
+
+            np.testing.assert_array_almost_equal(real_part.flatten(), expected_real, decimal=5)
+            np.testing.assert_array_almost_equal(imag_part.flatten(), expected_imag, decimal=5)
+
+    def test_ifft_2d(self):
+        np.random.seed(123)
+        for _ in range(10):
+            input_matrix = np.random.rand(8, 8) + 1j * np.random.rand(8, 8)
+
+            fft_result = np.fft.fft2(input_matrix)
+
+            sds_real_input = self.sds.from_numpy(np.real(fft_result))
+            sds_imag_input = self.sds.from_numpy(np.imag(fft_result))
+
+            ifft_result = self.sds.ifft(sds_real_input, sds_imag_input).compute()
+
+            real_part, imag_part = ifft_result
+
+            expected_ifft_result = np.fft.ifft2(fft_result)
+            expected_real = np.real(expected_ifft_result)
+            expected_imag = np.imag(expected_ifft_result)
+
+            np.testing.assert_array_almost_equal(real_part, expected_real, decimal=5)
+            np.testing.assert_array_almost_equal(imag_part, expected_imag, decimal=5)
+
+    def test_fft_empty_matrix(self):
+        input_matrix = np.array([])
+        sds_input = self.sds.from_numpy(input_matrix)
+
+        with self.assertRaisesRegex(RuntimeError, "The first argument to FFT cannot be an empty matrix."):
+            _ = self.sds.fft(sds_input).compute()
+
+    def test_ifft_empty_matrix(self):
+        input_matrix = np.array([])
+        sds_input = self.sds.from_numpy(input_matrix)
+
+        with self.assertRaisesRegex(RuntimeError, "The first argument to IFFT cannot be an empty matrix."):
+            _ = self.sds.ifft(sds_input).compute()
+
+    def test_fft_single_element(self):
+        input_matrix = np.array([[5]])
+        sds_input = self.sds.from_numpy(input_matrix)
+        fft_result = self.sds.fft(sds_input).compute()
+
+        real_part, imag_part = fft_result
+        np.testing.assert_array_almost_equal(real_part, [[5]], decimal=5)
+        np.testing.assert_array_almost_equal(imag_part, [[0]], decimal=5)
+
+    def test_ifft_single_element(self):
+        input_matrix = np.array([[5]])
+        sds_input = self.sds.from_numpy(input_matrix)
+        ifft_result = self.sds.ifft(sds_input).compute()
+
+        real_part, imag_part = ifft_result
+        np.testing.assert_array_almost_equal(real_part, [[5]], decimal=5)
+        np.testing.assert_array_almost_equal(imag_part, [[0]], decimal=5)
+
+    def test_fft_zeros_matrix(self):
+        input_matrix = np.zeros((4, 4))
+        sds_input = self.sds.from_numpy(input_matrix)
+        fft_result = self.sds.fft(sds_input).compute()
+
+        real_part, imag_part = fft_result
+        np.testing.assert_array_almost_equal(real_part, np.zeros((4, 4)), decimal=5)
+        np.testing.assert_array_almost_equal(imag_part, np.zeros((4, 4)), decimal=5)
+
+    def test_ifft_zeros_matrix(self):
+        input_matrix = np.zeros((4, 4))
+        sds_input = self.sds.from_numpy(input_matrix)
+        ifft_result = self.sds.ifft(sds_input).compute()
+
+        real_part, imag_part = ifft_result
+        np.testing.assert_array_almost_equal(real_part, np.zeros((4, 4)), decimal=5)
+        np.testing.assert_array_almost_equal(imag_part, np.zeros((4, 4)), decimal=5)
+
+    def test_ifft_real_and_imaginary_dimensions_check(self):
+        real_part = np.random.rand(1, 16)
+        imag_part = np.random.rand(1, 14)
+
+        sds_real_input = self.sds.from_numpy(real_part)
+        sds_imag_input = self.sds.from_numpy(imag_part)
+
+        with self.assertRaisesRegex(RuntimeError, "The real and imaginary part of the provided matrix are of different dimensions."):
+            self.sds.ifft(sds_real_input, sds_imag_input).compute()
+
+    def test_ifft_non_power_of_two_matrix(self):
+        real_part = np.random.rand(3, 5)
+        imag_part = np.random.rand(3, 5)
+
+        sds_real_input = self.sds.from_numpy(real_part)
+        sds_imag_input = self.sds.from_numpy(imag_part)
+
+        with self.assertRaisesRegex(RuntimeError, "This IFFT implementation is only defined for matrices with dimensions that are powers of 2."):
+            _ = self.sds.ifft(sds_real_input, sds_imag_input).compute()
+
+
+if __name__ == '__main__':
+    unittest.main()

Reply via email to