This is an automated email from the ASF dual-hosted git repository.

mboehm7 pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/systemds.git


The following commit(s) were added to refs/heads/main by this push:
     new c00c89d670 [SYSTEMDS-3547] Tensor in-place permutation operations
c00c89d670 is described below

commit c00c89d670177f6173f38c0a65c9abe12eb05be2
Author: dogakarakas <[email protected]>
AuthorDate: Tue Mar 31 12:38:29 2026 +0200

    [SYSTEMDS-3547] Tensor in-place permutation operations
    
    Closes #2412.
    
    Co-authored-by: bakiberkay <[email protected]>
---
 .../sysds/runtime/matrix/data/LibMatrixReorg.java  | 443 +++++++++++++--
 src/test/java/org/apache/sysds/test/TestUtils.java |   6 +
 .../TransposeInPlaceBrennerTest.java               | 591 ++++++++++++++++++++-
 3 files changed, 986 insertions(+), 54 deletions(-)

diff --git 
a/src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixReorg.java 
b/src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixReorg.java
index ffd7b17a20..f00e0015a8 100644
--- a/src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixReorg.java
+++ b/src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixReorg.java
@@ -111,7 +111,7 @@ public class LibMatrixReorg {
        }
        
        /////////////////////////
-       // public interface    //
+       // public interface     //
        /////////////////////////
 
        public static boolean isSupportedReorgOperator( ReorgOperator op ) {
@@ -835,9 +835,9 @@ public class LibMatrixReorg {
         * default, while R uses always a column-wise read, rowwise specifying 
the write order and column-wise being the
         * default.
         * 
-        * @param in      input matrix
-        * @param rows    number of rows
-        * @param cols    number of columns
+        * @param in      input matrix
+        * @param rows  number of rows
+        * @param cols  number of columns
         * @param rowwise if true, reshape by row
         * @return output matrix
         */
@@ -852,10 +852,10 @@ public class LibMatrixReorg {
         * default, while R uses always a column-wise read, rowwise specifying 
the write order and column-wise being the
         * default.
         * 
-        * @param in      input matrix
-        * @param out     output matrix
-        * @param rows    number of rows
-        * @param cols    number of columns
+        * @param in      input matrix
+        * @param out    output matrix
+        * @param rows  number of rows
+        * @param cols  number of columns
         * @param rowwise if true, reshape by row
         * @return output matrix
         */
@@ -870,12 +870,12 @@ public class LibMatrixReorg {
         * default, while R uses always a column-wise read, rowwise specifying 
the write order and column-wise being the
         * default.
         * 
-        * @param in      input matrix
-        * @param out     output matrix
-        * @param rows    number of rows
-        * @param cols    number of columns
+        * @param in      input matrix
+        * @param out    output matrix
+        * @param rows  number of rows
+        * @param cols  number of columns
         * @param rowwise if true, reshape by row
-        * @param k       The parallelization degree
+        * @param k        The parallelization degree
         * @return output matrix
         */
        public static MatrixBlock reshape(MatrixBlock in, MatrixBlock out, int 
rows, int cols, boolean rowwise, int k) {
@@ -939,7 +939,7 @@ public class LibMatrixReorg {
         * @return list of indexed matrix values
         */
        public static List<IndexedMatrixValue> reshape(IndexedMatrixValue in, 
DataCharacteristics mcIn,
-                                                      DataCharacteristics 
mcOut, boolean rowwise, boolean outputEmptyBlocks ) {
+                                                                               
                   DataCharacteristics mcOut, boolean rowwise, boolean 
outputEmptyBlocks ) {
                //prepare inputs
                MatrixIndexes ixIn = in.getIndexes();
                MatrixBlock mbIn = (MatrixBlock) in.getValue();
@@ -1013,7 +1013,7 @@ public class LibMatrixReorg {
                //sanity check inputs
                if( !(data.getValue() instanceof MatrixBlock && 
offset.getValue() instanceof MatrixBlock) )
                        throw new DMLRuntimeException("Unsupported input data: 
expected "+MatrixBlock.class.getName()+" but got 
"+data.getValue().getClass().getName()+" and 
"+offset.getValue().getClass().getName());
-               if(     rmRows && 
data.getValue().getNumRows()!=offset.getValue().getNumRows() 
+               if(      rmRows && 
data.getValue().getNumRows()!=offset.getValue().getNumRows() 
                        || !rmRows && 
data.getValue().getNumColumns()!=offset.getValue().getNumColumns()  ){
                        throw new DMLRuntimeException("Dimension mismatch 
between input data and offsets: ["
                                        
+data.getValue().getNumRows()+"x"+data.getValue().getNumColumns()+" vs 
"+offset.getValue().getNumRows()+"x"+offset.getValue().getNumColumns());
@@ -1090,13 +1090,13 @@ public class LibMatrixReorg {
         * CP rexpand operation (single input, single output), the classic 
example of this operation is one hot encoding of a
         * column to multiple columns.
         * 
-        * @param in     Input matrix
-        * @param ret    Output matrix
-        * @param max    Number of rows/cols of the output
+        * @param in     Input matrix
+        * @param ret   Output matrix
+        * @param max   Number of rows/cols of the output
         * @param rows   If the expansion is in rows direction
         * @param cast   If the values contained should be cast to double 
(rounded up and down)
         * @param ignore Ignore if the input contain values below zero that 
technically is incorrect input.
-        * @param k      Degree of parallelism
+        * @param k       Degree of parallelism
         * @return Output matrix rexpanded
         */
        public static MatrixBlock rexpand(MatrixBlock in, MatrixBlock ret, 
double max, boolean rows, boolean cast, boolean ignore, int k) {
@@ -1107,13 +1107,13 @@ public class LibMatrixReorg {
         * CP rexpand operation (single input, single output), the classic 
example of this operation is one hot encoding of a
         * column to multiple columns.
         * 
-        * @param in     Input matrix
-        * @param ret    Output matrix
-        * @param max    Number of rows/cols of the output
+        * @param in     Input matrix
+        * @param ret   Output matrix
+        * @param max   Number of rows/cols of the output
         * @param rows   If the expansion is in rows direction
         * @param cast   If the values contained should be cast to double 
(rounded up and down)
         * @param ignore Ignore if the input contain values below zero that 
technically is incorrect input.
-        * @param k      Degree of parallelism
+        * @param k       Degree of parallelism
         * @return Output matrix rexpanded
         */
        public static MatrixBlock rexpand(MatrixBlock in, MatrixBlock ret, int 
max, boolean rows, boolean cast, boolean ignore, int k){
@@ -1144,8 +1144,8 @@ public class LibMatrixReorg {
         * ret = table(seq(1, nrow(A)), A, w)
         * 
         * @param seqHeight A sequence vector height.
-        * @param A         The MatrixBlock vector to encode.
-        * @param w         The weight matrix to multiply on output cells.
+        * @param A              The MatrixBlock vector to encode.
+        * @param w              The weight matrix to multiply on output cells.
         * @return A new MatrixBlock with the table result.
         */
        public static MatrixBlock fusedSeqRexpand(int seqHeight, MatrixBlock A, 
double w) {
@@ -1159,12 +1159,12 @@ public class LibMatrixReorg {
         * ret = table(seq(1, nrow(A)), A, w)
         * 
         * @param seqHeight  A sequence vector height.
-        * @param A          The MatrixBlock vector to encode.
-        * @param w          The weight scalar to multiply on output cells.
-        * @param ret        The output MatrixBlock, does not have to be used, 
but depending on updateClen determine the
-        *                   output size.
+        * @param A               The MatrixBlock vector to encode.
+        * @param w               The weight scalar to multiply on output cells.
+        * @param ret           The output MatrixBlock, does not have to be 
used, but depending on updateClen determine the
+        *                                 output size.
         * @param updateClen Update clen, if set to true, ignore dimensions of 
ret, otherwise use the column dimension of
-        *                   ret.
+        *                                 ret.
         * @return A new MatrixBlock or ret.
         */
        public static MatrixBlock fusedSeqRexpand(int seqHeight, MatrixBlock A, 
double w, MatrixBlock ret,
@@ -1179,12 +1179,12 @@ public class LibMatrixReorg {
         * ret = table(seq(1, nrow(A)), A, w)
         * 
         * @param seqHeight  A sequence vector height.
-        * @param A          The MatrixBlock vector to encode.
-        * @param w          The weight matrix to multiply on output cells.
-        * @param ret        The output MatrixBlock, does not have to be used, 
but depending on updateClen determine the
-        *                   output size.
+        * @param A               The MatrixBlock vector to encode.
+        * @param w               The weight matrix to multiply on output cells.
+        * @param ret           The output MatrixBlock, does not have to be 
used, but depending on updateClen determine the
+        *                                 output size.
         * @param updateClen Update clen, if set to true, ignore dimensions of 
ret, otherwise use the column dimension of
-        *                   ret.
+        *                                 ret.
         * @param k                        Parallelization degree
         * @return A new MatrixBlock or ret.
         */
@@ -1318,7 +1318,7 @@ public class LibMatrixReorg {
        /**
         * Quick check if the input is valid for rexpand, this check does not 
guarantee that the input is valid for rexpand
         * 
-        * @param in     Input matrix block
+        * @param in     Input matrix block
         * @param ignore If zero valued cells should be ignored
         */
        public static void checkRexpand(MatrixBlock in, boolean ignore){
@@ -1330,12 +1330,12 @@ public class LibMatrixReorg {
        /**
         * MR/Spark rexpand operation (single input, multiple outputs incl 
empty blocks)
         * 
-        * @param data    Input indexed matrix block
-        * @param max     Total nrows/cols of the output
-        * @param rows    If the expansion is in rows direction
-        * @param cast    If the values contained should be cast to double 
(rounded up and down)
+        * @param data  Input indexed matrix block
+        * @param max    Total nrows/cols of the output
+        * @param rows  If the expansion is in rows direction
+        * @param cast  If the values contained should be cast to double 
(rounded up and down)
         * @param ignore  Ignore if the input contain values below zero that 
technically is incorrect input.
-        * @param blen    The block size to slice the output up into
+        * @param blen  The block size to slice the output up into
         * @param outList The output indexedMatrixValues (a list to add all the 
output blocks to / modify)
         */
        public static void rexpand(IndexedMatrixValue data, double max, boolean 
rows, boolean cast, boolean ignore, long blen, ArrayList<IndexedMatrixValue> 
outList) {
@@ -3416,7 +3416,7 @@ public class LibMatrixReorg {
        }
        
        private static void createNonZeroIndexes(DataCharacteristics mcIn, 
DataCharacteristics mcOut,
-                                                MatrixBlock in, long 
row_offset, long col_offset, boolean rowwise, HashSet<MatrixIndexes> ret) {
+                                                                               
         MatrixBlock in, long row_offset, long col_offset, boolean rowwise, 
HashSet<MatrixIndexes> ret) {
                Iterator<IJV> iter = in.getSparseBlockIterator();
                while( iter.hasNext() ) {
                        IJV cell = iter.next();
@@ -3444,7 +3444,7 @@ public class LibMatrixReorg {
        }
 
        private static void reshapeDense(MatrixBlock in, long row_offset, long 
col_offset, Map<MatrixIndexes,MatrixBlock> rix,
-                                        DataCharacteristics mcIn, 
DataCharacteristics mcOut, boolean rowwise ) {
+                                                                        
DataCharacteristics mcIn, DataCharacteristics mcOut, boolean rowwise ) {
                if( in.isEmptyBlock(false) )
                        return;
                
@@ -3480,7 +3480,7 @@ public class LibMatrixReorg {
        }
 
        private static void reshapeSparse(MatrixBlock in, long row_offset, long 
col_offset, Map<MatrixIndexes,MatrixBlock> rix,
-                                         DataCharacteristics mcIn, 
DataCharacteristics mcOut, boolean rowwise ) {
+                                                                         
DataCharacteristics mcIn, DataCharacteristics mcOut, boolean rowwise ) {
                if( in.isEmptyBlock(false) )
                        return;
                
@@ -3542,7 +3542,7 @@ public class LibMatrixReorg {
        }
        
        private static MatrixIndexes computeInBlockIndex(MatrixIndexes ixout, 
long ai, long aj,
-                                                        DataCharacteristics 
mcIn, DataCharacteristics mcOut, boolean rowwise )
+                                                                               
                         DataCharacteristics mcIn, DataCharacteristics mcOut, 
boolean rowwise )
        {
                long tempc = computeGlobalCellIndex(mcIn, ai, aj, rowwise);
                long ci = rowwise ? 
(tempc/mcOut.getCols())%mcOut.getBlocksize() : 
@@ -4517,13 +4517,13 @@ public class LibMatrixReorg {
         * https://dl.acm.org/doi/pdf/10.1145/355611.362542.
         *
         * @param matrix   The matrix whose elements are being shifted.
-        * @param moved    Boolean array tracking whether an element has 
already been moved.
-        * @param rows     The number of rows in the matrix.
+        * @param moved Boolean array tracking whether an element has already 
been moved.
+        * @param rows   The number of rows in the matrix.
         * @param maxIndex The maximum valid index in the matrix.
-        * @param count    The number of elements left to process.
+        * @param count The number of elements left to process.
         * @param workSize The length of moved.
-        * @param start    The starting index for the cycle shift.
-        * @param comp     The corresponding companion index.
+        * @param start The starting index for the cycle shift.
+        * @param comp   The corresponding companion index.
         * @return The updated count of elements remaining to shift.
         */
        private static int simultaneousCycleShift(double[] matrix, boolean[] 
moved, int rows, int maxIndex, int count,
@@ -4585,10 +4585,10 @@ public class LibMatrixReorg {
         * Performs prime factorization of a given number n. The method 
calculates the prime factors of n, their exponents,
         * powers and stores the results in the provided arrays.
         *
-        * @param n         The number to be factorized.
-        * @param primes    Array to store the unique prime factors of n.
+        * @param n              The number to be factorized.
+        * @param primes        Array to store the unique prime factors of n.
         * @param exponents Array to store the exponents of the respective 
prime factors.
-        * @param powers    Array to store the powers of the respective prime 
factors.
+        * @param powers        Array to store the powers of the respective 
prime factors.
         * @return The number of unique prime factors.
         */
        private static int primeFactorization(int n, int[] primes, int[] 
exponents, int[] powers) {
@@ -4635,4 +4635,341 @@ public class LibMatrixReorg {
                }
                return count;
        }
+
+       // TENSOR
+       /**
+        * Performs prime in-place tensor transposition for arbitrary 
permutations.
+        *
+        * @param in
+        *                      Tensor stored as MatrixBlock
+        * @param shape
+        *                      Original shape information of tensor
+        * @param perm
+        *                      Permutation of tensor
+        */
+       // (A) If permutation is split-index reducible -> reduce to 2D and use 
transposeInPlaceDenseBrenner()
+       // (B) Else -> decompose perm into adjacent swaps and apply each via 
1324 primitive from EITHOT algorithm
+       // (https://dl.acm.org/doi/10.1145/3711871)
+       // with Brenner's method instead of Catanzaro's algorithm for 
generalizability to arbitrarily large dimensions
+       // ------------------------------------------------------------
+       public static boolean transposeInPlaceTensor(MatrixBlock in, int[] 
shape, int[] perm) {
+               final int rank = shape.length;
+
+               // final shape
+               final int[] finalShape = new int[rank];
+               for (int i = 0; i < rank; i++)
+                       finalShape[i] = shape[perm[i]];
+
+               // Identity perm -> metadata only
+               boolean identity = true;
+               for (int i = 0; i < rank; i++) {
+                       if (perm[i] != i) {
+                               identity = false;
+                               break;
+                       }
+               }
+               if (identity) {
+                       restoreMetadata(in, finalShape);
+                       return true;
+               }
+
+               // (A) Split-index reducible
+               int splitIdx = findSplitIndex(perm);
+               if (splitIdx != -1) {
+                       int newRows = 1;
+                       for (int i = 0; i < splitIdx; i++)
+                               newRows *= shape[perm[i]];
+
+                       long newColsL = 1;
+                       for (int i = splitIdx; i < rank; i++)
+                               newColsL *= shape[perm[i]];
+                       int newCols = (int) newColsL;
+
+                       try {
+                               in.setNumRows(newCols);
+                               in.setNumColumns(newRows);
+                               transposeInPlaceDenseBrenner(in, 1);
+                       } finally {
+                               restoreMetadata(in, finalShape);
+                       }
+                       return true;
+               }
+
+               // (B) General path: usage of 1324 primitive
+
+               final double[] tensor = in.getDenseBlockValues();
+               int[] curShape = Arrays.copyOf(shape, rank);
+               in.getDenseBlock().setDims(curShape);
+
+               // plan adjacent swaps to realize perm
+               int[] swaps = permutationToAdjacentSwaps(rank, perm);
+               int swapCount = swaps[0];
+
+               for (int s = 1; s <= swapCount; s++) {
+                       int k = swaps[s];
+                       reshape1324(in, tensor, curShape, k);
+                       int tmp = curShape[k];
+                       curShape[k] = curShape[k + 1];
+                       curShape[k + 1] = tmp;
+                       in.getDenseBlock().setDims(curShape);
+               }
+
+               restoreMetadata(in, finalShape);
+               return true;
+       }
+
+       /**
+        * Applies a single adjacent-axis swap (k <-> k+1) to a dense tensor 
**in-place** by reducing it to a rank-4 view
+        * and calling primitive 1324.
+        *
+        * @param in
+        *                      MatrixBlock holding the dense tensor buffer 
(metadata is temporarily modified)
+        * @param a
+        *                      backing dense buffer (row-major, contiguous)
+        * @param curShape
+        *                      current logical tensor shape/order before 
applying this adjacent swap
+        * @param k
+        *                      adjacent axis index to swap (swaps axis k with 
axis k+1)
+        */
+       private static void reshape1324(MatrixBlock in, double[] a, int[] 
curShape, int k) {
+               int lastDim = curShape.length;
+
+               int left = prod(curShape, 0, k);
+               int A = curShape[k];
+               int B = curShape[k + 1];
+               int right = prod(curShape, k + 2, lastDim);
+
+               // metadata-only reshape to 4D
+               in.getDenseBlock().setDims(new int[] { left, A, B, right });
+
+               // in-place 1324 on that view
+               prim1324(a, 0, left, A, B, right);
+
+               // caller restores dims to curShape after it swaps 
curShape[k],curShape[k+1]
+       }
+
+       private static int prod(int[] shape, int start, int end) {
+               long p = 1;
+               for (int i = start; i < end; i++) {
+                       p *= shape[i];
+               }
+               return (int) p;
+       }
+
+       /**
+        * Decomposes an arbitrary permutation into a sequence of adjacent 
swaps.
+        *
+        * @param rank
+        *                      tensor rank
+        * @param perm
+        *                      target permutation (maps output axis i to input 
axis perm[i])
+        *
+        * @return swap plan array
+        */
+       private static int[] permutationToAdjacentSwaps(int rank, int[] perm) {
+               int[] order = new int[rank];
+               // original order of permutation
+               for (int i = 0; i < rank; i++)
+                       order[i] = i;
+
+               int maxSwaps = rank * (rank - 1) / 2;
+               int[] out = new int[maxSwaps + 1]; // stores swap order
+               int cnt = 0; // number of swaps needed
+
+               for (int targetPos = 0; targetPos < rank; targetPos++) {
+                       int wantedAxis = perm[targetPos];
+
+                       // index of dimension in current permutation
+                       int curPos = -1;
+                       for (int p = targetPos; p < rank; p++) {
+                               if (order[p] == wantedAxis) {
+                                       curPos = p;
+                                       break;
+                               }
+                       }
+                       if (curPos < 0)
+                               throw new IllegalArgumentException("Invalid 
perm");
+
+                       while (curPos > targetPos) {
+                               int t = order[curPos - 1];
+                               order[curPos - 1] = order[curPos];
+                               order[curPos] = t;
+
+                               out[++cnt] = curPos - 1;
+                               curPos--;
+                       }
+               }
+
+               out[0] = cnt;
+               return out;
+       }
+
+       /**
+        * Primitive {@code 1324}: swaps dimensions 2 and 3 while keeping 
dimensions 1 and 4 fixed:
+        *
+        * @param a
+        *                      dense buffer
+        * @param offset
+        *                      base offset into {a} (usually 0)
+        * @param d1
+        *                      first dimension (number of slices)
+        * @param d2
+        *                      second dimension (matrix rows)
+        * @param d3
+        *                      third dimension (matrix cols)
+        * @param d4
+        *                      fourth dimension (block length per matrix cell)
+        */
+       private static void prim1324(double[] a, int offset, int d1, int d2, 
int d3, int d4) {
+               for (int i1 = 0; i1 < d1; i1++) {
+                       int slice = d2 * d3 * d4;
+                       int base = offset + i1 * slice;
+                       transposeBlocksInPlace(a, base, d2, d3, d4);
+               }
+       }
+
+       /**
+        * In-place transpose of a matrix where each element is a contiguous 
block of length {blk}. Performs a cycle-walk
+        * permutation over block positions induced by transpose. For each 
unvisited start position, we rotate blocks along
+        * its cycle using one temporary block buffer.
+        *
+        * @param tensor
+        *                      backing dense buffer
+        * @param base
+        *                      offset of the (d2*d3*blk) region
+        * @param d2
+        *                      number of rows in the block-matrix
+        * @param d3
+        *                      number of columns in the block-matrix
+        * @param blk
+        *                      block length (number of doubles per cell), d4
+        */
+       private static void transposeBlocksInPlace(double[] tensor, int base, 
int d2, int d3, int blk) {
+               int numBlocks = d2 * d3;
+               boolean[] visited = new boolean[numBlocks];
+               double[] tmp = new double[blk]; // buffer for one block
+
+               for (int start = 0; start < numBlocks; start++) {
+                       if (visited[start])
+                               continue;
+
+                       int next = transposeBlockIndex(start, d2, d3);
+
+                       // no movement
+                       if (next == start) {
+                               visited[start] = true;
+                               continue;
+                       }
+
+                       // save start
+                       System.arraycopy(tensor, base + start * blk, tmp, 0, 
blk);
+
+                       // cycle-following
+                       int cur = start;
+                       while (true) {
+                               visited[cur] = true;
+                               int prev = inverseTransposeBlockIndex(cur, d2, 
d3);
+                               if (prev == start)
+                                       break;
+
+                               System.arraycopy(tensor, base + prev * blk, 
tensor, base + cur * blk, blk);
+                               cur = prev;
+                       }
+
+                       System.arraycopy(tmp, 0, tensor, base + cur * blk, blk);
+                       visited[cur] = true;
+               }
+       }
+
+       /**
+        * Finds the target index of a current block
+        *
+        * @param block_idx
+        *                      index of block
+        * @param m
+        *                      number of rows
+        * @param n
+        *                      number of columns
+        *
+        * @return new block idx
+        */
+       private static int transposeBlockIndex(int block_idx, int m, int n) {
+               int i = block_idx / n;
+               int j = block_idx % n;
+               return j * m + i;
+       }
+
+       /**
+        * Finds the idx of the element which moves to the current block index 
during permutation
+        *
+        * @param curr_block_idx
+        *                      index of current block
+        * @param m
+        *                      number of rows
+        * @param n
+        *                      number of columns
+        *
+        * @return source block idx (the block that moves into curr_block_idx)
+        */
+       private static int inverseTransposeBlockIndex(int curr_block_idx, int 
m, int n) {
+               int i = curr_block_idx % m;
+               int j = curr_block_idx / m;
+               return i * n + j;
+       }
+
+       /**
+        * Finds a split index for a tensor permutation that allows reduction 
of the permutation to a 2D matrix transpose.
+        * @param perm  permutation of tensor axes
+        * @return  split index {i} if reducible, otherwise {-1}
+        */
+       public static int findSplitIndex(int[] perm) {
+               if (perm == null || perm.length < 2)
+                       return -1;
+               int n = perm.length;
+
+               for (int i = 1; i < n; i++) {
+                       boolean contiguousFirst = isContiguousRange(perm, 0, i);
+                       boolean contiguousSecond = isContiguousRange(perm, i, 
n);
+
+                       if (contiguousFirst && contiguousSecond) {
+                               if (isSorted(perm, 0, i) && isSorted(perm, i, 
n)) {
+                                       return i;
+                               }
+                       }
+               }
+               return -1;
+       }
+
+       private static boolean isSorted(int[] perm, int start, int end) {
+               for (int i = start; i < end - 1; i++)
+                       if (perm[i] > perm[i + 1])
+                               return false;
+               return true;
+       }
+
+       private static boolean isContiguousRange(int[] perm, int start, int 
end) {
+               int min = perm[start], max = perm[start];
+               for (int i = start + 1; i < end; i++) {
+                       if (perm[i] < min)
+                               min = perm[i];
+                       if (perm[i] > max)
+                               max = perm[i];
+               }
+               return (max - min + 1) == (end - start);
+       }
+
+       /**
+        * Restores SystemDS matrix/tensor metadata after an in-place tensor 
permutation.
+        * @param in           matrix/tensor block whose metadata is restored
+        * @param finalShape   final tensor shape after permutation
+        */
+       private static void restoreMetadata(MatrixBlock in, int[] finalShape) {
+               in.setNumRows(finalShape[0]);
+               long totalRemaining = 1;
+               for (int i = 1; i < finalShape.length; i++)
+                       totalRemaining *= finalShape[i];
+               in.setNumColumns((int) totalRemaining);
+               if (in.getDenseBlock() != null)
+                       in.getDenseBlock().setDims(finalShape);
+       }
 }
diff --git a/src/test/java/org/apache/sysds/test/TestUtils.java 
b/src/test/java/org/apache/sysds/test/TestUtils.java
index e470dd8253..5ebc243dd4 100644
--- a/src/test/java/org/apache/sysds/test/TestUtils.java
+++ b/src/test/java/org/apache/sysds/test/TestUtils.java
@@ -3990,4 +3990,10 @@ public class TestUtils {
                        return 2;
                }
        }
+
+       public static void compareTensorValues(MatrixBlock actual, MatrixBlock 
expected, double epsilon) {
+               double[] a = actual.getDenseBlockValues();
+               double[] e = expected.getDenseBlockValues();
+               compareMatrices(e, a, epsilon);
+       }
 }
diff --git 
a/src/test/java/org/apache/sysds/test/component/matrix/libMatrixReorg/TransposeInPlaceBrennerTest.java
 
b/src/test/java/org/apache/sysds/test/component/matrix/libMatrixReorg/TransposeInPlaceBrennerTest.java
index 7b575cf37c..50e79279b6 100644
--- 
a/src/test/java/org/apache/sysds/test/component/matrix/libMatrixReorg/TransposeInPlaceBrennerTest.java
+++ 
b/src/test/java/org/apache/sysds/test/component/matrix/libMatrixReorg/TransposeInPlaceBrennerTest.java
@@ -19,16 +19,19 @@
 
 package org.apache.sysds.test.component.matrix.libMatrixReorg;
 
+import org.apache.sysds.runtime.data.DenseBlock;
 import org.apache.sysds.runtime.matrix.data.LibMatrixReorg;
 import org.apache.sysds.runtime.matrix.data.MatrixBlock;
 import org.apache.sysds.test.TestUtils;
 import org.junit.Test;
 
+import static org.junit.Assert.assertEquals;
 import static org.junit.Assert.assertThrows;
 import static org.junit.Assert.assertTrue;
 
 public class TransposeInPlaceBrennerTest {
 
+       
        @Test
        public void transposeInPlaceDenseBrennerOnePrime() {
                // 3*4-1 = 11
@@ -102,9 +105,595 @@ public class TransposeInPlaceBrennerTest {
        private void testTransposeInPlaceDense(int rows, int cols, double 
sparsity) {
                MatrixBlock X = MatrixBlock.randOperations(rows, cols, 
sparsity);
                MatrixBlock tX = LibMatrixReorg.transpose(X);
+               LibMatrixReorg.transposeInPlaceDenseBrenner(X, 1);
+               
+               
+               TestUtils.compareMatrices(X, tX, 0);
+       }
+       
+       // Tests for tensor permutations
+       @Test
+       public void testTensorPermuteSplit3D() {
+               int[] shape = {50,2,10};
+               int[] perm = {1,2,0}; 
+               testTransposeInPlaceTensor(shape, perm);
+       } 
+
+       @Test
+       public void testTensorPermuteSplit8D() {
+               int[] shape = {3,2,1,3,2,3,1,2};
+               int[] perm = {4,5,6,7,0,1,2,3};
+               testTransposeInPlaceTensor(shape, perm);
+       } 
+
+       @Test
+       public void testTensorPermuteSplit4D() {
+               int[] shape = {3,2,5,3};
+               int[] perm = {2,3,0,1}; 
+               testTransposeInPlaceTensor(shape, perm);
+       } 
+
+       @Test
+       public void testTensorPermuteSplit2D21() {
+               int[] shape = {4, 10};
+               int[] perm = {1,0}; 
+               testTransposeInPlaceTensor(shape, perm);
+       } 
+
+       //Test for primitives 
+       @Test
+       public void testTensorPermute3D213() {
+               int[] shape = {4, 2, 7};
+               int[] perm = {1,0,2}; 
+               testTransposeInPlaceTensor(shape, perm);
+       }
+ 
+       @Test
+       public void testTensorPermute3D132() {
+               int[] shape = {3, 4, 2};
+               int[] perm = {0, 2, 1}; 
+               testTransposeInPlaceTensor(shape, perm);
+       }
+       
+       @Test
+       public void testTensorPermute4D1324() {
+               int[] shape = {3, 2, 2, 3};
+               int[] perm = {0, 2, 1, 3}; 
+               testTransposeInPlaceTensor(shape, perm);
+       }
+
+       @Test
+       public void testTensorPermute4Db1324() {
+               int[] shape = {3, 4, 5, 6};
+               int[] perm = {0, 2, 1, 3}; 
+               testTransposeInPlaceTensor(shape, perm);
+       }
+
+        @Test
+       public void testTensorPermuteSplit5D() {
+               int[] shape = {2, 3, 4, 5, 6};
+               int[] perm = {2, 3, 4, 0, 1};
+               testTransposeInPlaceTensor(shape, perm);
+       }
+
+       @Test
+       public void testTensorPermuteSplit6D() {
+               int[] shape = {4, 3, 2, 5, 8, 2};
+               int[] perm = {3, 4, 5, 0, 1, 2}; 
+               testTransposeInPlaceTensor(shape, perm);
+       }
+
+       @Test
+       public void testTensorPermuteSplit5DMiddleSwap() {
+               int[] shape = {2, 6, 2, 4, 5};
+               int[] perm = {4, 3, 2, 1, 0}; 
+               testTransposeInPlaceTensor(shape, perm);
+       }
+
+       @Test
+       public void testTensorPermute5DMiddleSwapComplex() {
+               int[] shape = {2, 2, 3, 4, 2};
+               int[] perm = {0, 2, 1, 3, 4}; 
+               testTransposeInPlaceTensor(shape, perm);
+       }
+
+       @Test
+       public void testTensorPermute7Db() {
+               int[] shape = {20, 30, 15, 5, 2, 5, 2};
+               int[] perm = {0, 6, 1, 5, 4, 2, 3}; 
+               testTransposeInPlaceTensor(shape, perm);
+       }
+
+       @Test
+       public void testTensorPermute7D() {
+               int[] shape = {2, 3, 5, 5, 2, 3, 2};
+               int[] perm = {0, 6, 1, 5, 4, 2, 3}; 
+               testTransposeInPlaceTensor(shape, perm);
+       }
+
+       @Test
+       public void testTensorPermuteSplitMax2() {
+               int[] shape = {1000, 300, 100};
+               int[] perm = {2, 0, 1}; 
+               testTransposeInPlaceTensor(shape, perm);
+       }
 
+       @Test
+       public void testTensorPermuteSplitMax3() {
+               int[] shape = {8000, 4000, 2}; 
+               int[] perm = {2, 0, 1}; 
+               testTransposeInPlaceTensor(shape, perm);
+       }
+
+       @Test
+       public void testTensorPermute3DAllCases() {
+               int[] shape = {2, 3, 2}; 
+               int[] perm1 = {0,1, 2}; 
+               int[] perm2 = {0,2, 1}; 
+               int[] perm3 = {1,0,2};
+               int[] perm4 = {1,2,0}; 
+               int[] perm5 = {2,0,1}; 
+               int[] perm6 = {2,1,0}; 
+               testTransposeInPlaceTensor(shape, perm1);
+               testTransposeInPlaceTensor(shape, perm2);
+               testTransposeInPlaceTensor(shape, perm3);
+               testTransposeInPlaceTensor(shape, perm4);
+               testTransposeInPlaceTensor(shape, perm5);
+               testTransposeInPlaceTensor(shape, perm6);
+       }
+
+       @Test
+       public void testTensorPermuteSplit4Db213() {
+               int[] shape = {2, 3, 4};
+               int[] perm  = {1, 0, 2}; 
+               testTransposeInPlaceTensor(shape, perm); 
+       }
+       
+       @Test
+       public void testTensorPermuteSplit4Db132() {
+               int[] shape = {2, 3, 4};
+               int[] perm  = {0, 2, 1}; 
+               testTransposeInPlaceTensor(shape, perm); 
+       }
+
+       // Edge case tests
+       
+       // 1. Square matrices
+       @Test
+       public void transposeInPlaceDenseSquare5x5() {
+               testTransposeInPlaceDense(5, 5, 0.8);
+       }
+
+       @Test
+       public void transposeInPlaceDenseSquare100x100() {
+               testTransposeInPlaceDense(100, 100, 0.7);
+       }
+
+       @Test
+       public void testTensorPermute3DSquareDims() {
+               int[] shape = {4, 4, 4};
+               int[] perm = {2, 0, 1};
+               testTransposeInPlaceTensor(shape, perm);
+       }
+
+       // 2. Vectors (1×N and N×1)
+       @Test
+       public void transposeInPlaceDenseRowVector() {
+               testTransposeInPlaceDense(1, 50, 0.9);
+       }
+
+       @Test
+       public void transposeInPlaceDenseColVector() {
+               testTransposeInPlaceDense(50, 1, 0.9);
+       }
+
+       @Test
+       public void testTensorPermuteVectorLike() {
+               int[] shape = {1, 20};
+               int[] perm = {1, 0};
+               testTransposeInPlaceTensor(shape, perm);
+       }
+
+       // 3. Single element
+       @Test
+       public void transposeInPlaceDenseSingleElement() {
+               testTransposeInPlaceDense(1, 1, 1.0);
+       }
+
+       @Test
+       public void testTensorPermute_SingleElement() {
+               int[] shape = {1, 1, 1};
+               int[] perm = {2, 1, 0};
+               testTransposeInPlaceTensor(shape, perm);
+       }
+
+       // 4. Prime dimensions
+       @Test
+       public void transposeInPlaceDensePrime7x11() {
+               testTransposeInPlaceDense(7, 11, 0.75);
+       }
+
+       @Test
+       public void transposeInPlaceDensePrime13x17() {
+               testTransposeInPlaceDense(13, 17, 0.82);
+       }
+
+       @Test
+       public void testTensorPermuteAllPrimeDims() {
+               int[] shape = {3, 5, 7};
+               int[] perm = {1, 2, 0};
+               testTransposeInPlaceTensor(shape, perm);
+       }
+
+       // 5. Power of 2 dimensions (common in computing, just to be sure)
+       @Test
+       public void transposeInPlaceDensePowerOf264x128() {
+               testTransposeInPlaceDense(64, 128, 0.6);
+       }
+
+       @Test
+       public void transposeInPlaceDensePowerOf232x64() {
+               testTransposeInPlaceDense(32, 64, 0.85);
+       }
+
+       @Test
+       public void testTensorPermutePowerOf2Dims() {
+               int[] shape = {8, 16, 4};
+               int[] perm = {2, 1, 0};
+               testTransposeInPlaceTensor(shape, perm);
+       }
+
+       // 7. Consecutive transpose (should return to original)
+       @Test
+       public void transposeInPlaceDenseConsecutiveTwice() {
+               MatrixBlock X = MatrixBlock.randOperations(7, 13, 0.75);
+               MatrixBlock original = new MatrixBlock(X);
+               
+               LibMatrixReorg.transposeInPlaceDenseBrenner(X, 1);
                LibMatrixReorg.transposeInPlaceDenseBrenner(X, 1);
+               
+               TestUtils.compareMatrices(X, original, 0);
+       }
 
-               TestUtils.compareMatrices(X, tX, 0);
+       @Test
+       public void testTensorPermuteConsecutiveTwice() {
+               int[] shape = {3, 4, 5};
+               int[] perm = {1, 2, 0};
+               
+               MatrixBlock matrix = createDenseTensor(shape);
+               MatrixBlock original = new MatrixBlock(matrix);
+               
+               LibMatrixReorg.transposeInPlaceTensor(matrix, shape, perm);
+               // Apply reverse permutation to get back
+               int[] reversePerm = new int[perm.length];
+               for (int i = 0; i < perm.length; i++) {
+                       reversePerm[perm[i]] = i;
+               }
+               int[] newShape = new int[shape.length];
+               for (int i = 0; i < perm.length; i++) {
+                       newShape[i] = shape[perm[i]];
+               }
+               LibMatrixReorg.transposeInPlaceTensor(matrix, newShape, 
reversePerm);
+               
+               TestUtils.compareMatrices(matrix, original, 0);
+       }
+
+       // 8. Tensors with dimensions of size 1
+       @Test
+       public void testTensorPermuteWithDim1Case1() {
+               int[] shape = {1, 5, 3};
+               int[] perm = {2, 0, 1};
+               testTransposeInPlaceTensor(shape, perm);
+       }
+
+       @Test
+       public void testTensorPermuteWithDim1Case2() {
+               int[] shape = {4, 1, 2, 1};
+               int[] perm = {2, 3, 0, 1};
+               testTransposeInPlaceTensor(shape, perm);
+       }
+
+       @Test
+       public void testTensorPermuteWithDim1Case3() {
+               int[] shape = {3, 1, 4};
+               int[] perm = {1, 2, 0};
+               testTransposeInPlaceTensor(shape, perm);
+       }
+
+       // 9. Invalid permutations (negative tests)
+       // NOTE: more detailed error handling can be added in the future; 
currently these tests only check that an exception is thrown
+       @Test
+       public void testTensorPermute_InvalidPerm_OutOfRange() {
+               int[] shape = {2, 3, 4};
+               int[] perm = {0, 1, 3}; // 3 is out of range for 3D tensor
+               
+               MatrixBlock matrix = createDenseTensor(shape);
+               
+               assertThrows(Exception.class,
+                       () -> LibMatrixReorg.transposeInPlaceTensor(matrix, 
shape, perm));
+       }
+
+
+       @Test
+       public void testTensorPermuteInvalidPermWrongLength() {
+               int[] shape = {2, 3, 4};
+               int[] perm = {0, 1}; // only 2 elements but 3d tensor
+               
+               MatrixBlock matrix = createDenseTensor(shape);
+               
+               assertThrows(Exception.class,
+                       () -> LibMatrixReorg.transposeInPlaceTensor(matrix, 
shape, perm));
+       }
+
+       @Test
+       public void testTensorPermuteInvalidPermNegative() {
+               int[] shape = {2, 3, 4};
+               int[] perm = {-1, 1, 2}; // negative index
+               
+               MatrixBlock matrix = createDenseTensor(shape);
+               
+               assertThrows(Exception.class,
+                       () -> LibMatrixReorg.transposeInPlaceTensor(matrix, 
shape, perm));
+       }
+
+       // 10. Null/empty inputs
+       @Test
+       public void testTensorPermuteEmptyShape() {
+               int[] shape = {};
+               
+               assertThrows(Exception.class,
+                       () -> createDenseTensor(shape));
+       }
+
+       @Test
+       public void testTensorPermuteNullMatrix() {
+               int[] shape = {2, 3};
+               int[] perm = {1, 0};
+               
+               assertThrows(Exception.class,
+                       () -> LibMatrixReorg.transposeInPlaceTensor(null, 
shape, perm));
+       }
+
+
+       // Helper: creates a dense tensor filled with the values 0, 1, 2, ... in linear order
+       private static MatrixBlock createDenseTensor(int[] shape) {
+               long size = 1;
+               for (int s : shape)
+                       size *= s;
+       
+               if (size > Integer.MAX_VALUE)
+                       throw new IllegalArgumentException("Tensor too large: " 
+ size);
+       
+               int rows = shape[0];
+               long colsL = size / rows;
+               int cols = (int) colsL;
+       
+               MatrixBlock matrix = new MatrixBlock(rows, cols, false);
+               matrix.allocateDenseBlock();
+       
+               double[] values = matrix.getDenseBlockValues();
+               for (int i = 0; i < values.length; i++)
+                       values[i] = i;
+       
+               if (matrix.getDenseBlock() != null)
+                       matrix.getDenseBlock().setDims(shape);
+       
+               return matrix;
+       }
+
+
+       private void testTransposeInPlaceTensor(int[] shape, int[] perm) {
+               MatrixBlock matrix =createDenseTensor(shape);
+               MatrixBlock expected = permutationOutOfPlace(matrix, shape, 
perm);
+               LibMatrixReorg.transposeInPlaceTensor(matrix, shape, perm);
+               TestUtils.compareMatrices(matrix, expected, 0);
+               TestUtils.compareTensorValues(matrix, expected, 0);
+       }
+       
+       // Returns the expected matrix (computed out-of-place) for comparison
+       private MatrixBlock permutationOutOfPlace(MatrixBlock in, int[] shape, 
int[] perm) {
+               int[] newShape = new int[shape.length];
+               for(int i=0; i<perm.length; i++){
+                       newShape[i] = shape[perm[i]];
+               }
+               
+               int newRows = newShape[0];
+               long newCols = 1;
+               for(int i = 1; i < newShape.length; i++) {
+                       newCols *= newShape[i];
+               }
+
+               MatrixBlock out = new MatrixBlock(newRows, (int)newCols, false);
+               out.allocateDenseBlock();
+               
+               double[] inVal = in.getDenseBlockValues();
+               double[] outVal = out.getDenseBlockValues();
+
+               int[] originalCoords = new int[shape.length];
+               int[] permCoords = new int[shape.length];
+
+               for(int i = 0; i < inVal.length; i++) {
+                       getCoords(i, shape, originalCoords); 
+                       for(int j = 0; j < perm.length; j++) {
+                               permCoords[j] = originalCoords[perm[j]];
+                       }
+                       int outIdx = getIndex(permCoords, newShape);
+                       outVal[outIdx] = inVal[i];
+               }
+               
+               out.setNumRows(newShape[0]);
+               long cols = 1;
+               for(int i=1; i<newShape.length; i++){
+                       cols *= newShape[i];
+               }
+               out.setNumColumns((int) cols);
+               out.getDenseBlock().setDims(newShape);
+               
+               return out;
+       }
+
+       private void getCoords(int index, int[] shape, int[] originalCoords) {
+               for (int i = shape.length - 1; i >= 0; i--) {
+                       originalCoords[i] = index % shape[i];
+                       index /= shape[i];
+               }
+       }
+
+       private int getIndex(int[] coords, int[] shape) {
+               int index = 0;
+               int multiplier = 1;
+               for (int i = shape.length - 1; i >= 0; i--) {
+                       index += coords[i] * multiplier;
+                       multiplier *= shape[i];
+               }
+               return index;
+       }
+
+       //Test for correct meta-data after permutation
+       @Test
+       public void testTensorPermuteSplitShape6D() {
+               int[] shape = {2, 3, 4, 5, 6, 7};
+               int[] perm = {1, 2, 3, 4, 5, 0}; 
+               
+               long size = 1;
+               for(int s : shape) {
+                       size *= s; 
+               }
+               
+               MatrixBlock X = new MatrixBlock((int) size, 1, false);
+               X.allocateDenseBlock();
+               LibMatrixReorg.transposeInPlaceTensor(X, shape, perm);
+               testTransposeInPlaceTensorShape(X, shape, perm);
+       }
+
+       @Test
+       public void testTensorPermuteSplitShape6DMax() {
+               int[] shape = {1000, 500, 20, 2, 2, 2};
+               int[] perm = {1, 2, 3, 4, 5, 0}; 
+               
+               long size = 1;
+               for(int s : shape) {
+                       size *= s; 
+               }
+               
+               MatrixBlock X = new MatrixBlock((int) size, 1, false);
+               X.allocateDenseBlock();
+               LibMatrixReorg.transposeInPlaceTensor(X, shape, perm);
+               testTransposeInPlaceTensorShape(X, shape, perm);
+       }
+
+       @Test
+       public void testTensorPermuteSplitShape4D() {
+               int[] shape = {100, 22, 70, 90};
+               int[] perm = {1, 2, 3, 0}; 
+               
+               long size = 1;
+               for(int s : shape) {
+                       size *= s; 
+               }
+               
+               MatrixBlock X = new MatrixBlock((int) size, 1, false);
+               X.allocateDenseBlock();
+               LibMatrixReorg.transposeInPlaceTensor(X, shape, perm);
+               testTransposeInPlaceTensorShape(X, shape, perm);
+       }
+
+
+       @Test
+       public void testTensorPermuteSplitShape8D() {
+               int[] shape = {10, 22, 7, 9, 30, 6, 4, 7};
+               int[] perm = { 3, 4, 5, 6, 7, 0, 1, 2}; 
+               
+               long size = 1;
+               for(int s : shape) {
+                       size *= s; 
+               }
+               
+               MatrixBlock X = new MatrixBlock((int) size, 1, false);
+               X.allocateDenseBlock();
+               LibMatrixReorg.transposeInPlaceTensor(X, shape, perm);
+               testTransposeInPlaceTensorShape(X, shape, perm);
+       }
+
+       @Test
+       public void testTensorPermuteSplitShape5DMiddle() {
+               int[] shape = {10, 8, 5, 4, 2};
+               int[] perm = {0, 2, 1, 3, 4}; 
+               
+               long size = 1;
+               for(int s : shape) {
+                       size *= s; 
+               }
+               
+               MatrixBlock X = new MatrixBlock((int) size, 1, false);
+               X.allocateDenseBlock();
+               LibMatrixReorg.transposeInPlaceTensor(X, shape, perm);
+               testTransposeInPlaceTensorShape(X, shape, perm);
+       }
+
+       @Test
+       public void testTensorPermuteSplitShape5D() {
+               int[] shape = {2,3,5,2,8}; 
+               int[] perm = {3,4,0,1,2}; 
+               
+               long size = 1;
+               for(int s : shape) {
+                       size *= s; }
+       
+               MatrixBlock X = new MatrixBlock((int) size, 1, false);
+               X.allocateDenseBlock();
+               LibMatrixReorg.transposeInPlaceTensor(X, shape, perm);
+               testTransposeInPlaceTensorShape(X, shape, perm);
+               }
+               
+               @Test
+               public void testTensorPermuteSplitShape_2D() {
+               int[] shape = {2,3}; 
+               int[] perm = {1,0}; 
+               
+               long size = 1;
+               for(int s : shape) {
+                       size *= s; 
+               }
+               
+               MatrixBlock X = new MatrixBlock((int) size, 1, false);
+               X.allocateDenseBlock();
+               LibMatrixReorg.transposeInPlaceTensor(X, shape, perm);
+               testTransposeInPlaceTensorShape(X, shape, perm);
+       }
+               
+       private void testTransposeInPlaceTensorShape(MatrixBlock transposed_X, 
int[] originalShape, int[] perm){
+               int[] expectedShape = new int[originalShape.length];
+               for(int i = 0; i < perm.length; i++) {
+                       expectedShape[i] = originalShape[perm[i]];
+               }
+               int expectedRows = expectedShape[0];
+               long expectedCols = 1;
+               for(int i = 1; i < expectedShape.length; i++) {
+                       expectedCols *= expectedShape[i];
+               }
+
+               // MatrixBlock shape-match
+               assertEquals("Matrix Rows mismatch", expectedRows, 
transposed_X.getNumRows());
+               assertEquals("Matrix Columns mismatch", (int)expectedCols, 
transposed_X.getNumColumns());
+
+               // DenseBlock shape-match
+               int[] transposedShape = new int[originalShape.length];
+               DenseBlock dense_X = transposed_X.getDenseBlock();
+               if(dense_X != null){
+                       //Comparison of each dimension
+                       for (int i = 0; i < expectedShape.length; i++) {
+                               transposedShape[i] = dense_X.getDim(i);
+                               assertEquals("Dimension " + i + " mismatch", 
expectedShape[i], dense_X.getDim(i));
+                       }
+                       int currentExpectedSuffix = 
expectedShape[expectedShape.length - 1]; 
+                       //Comparison of suffixes
+                       for (int i = expectedShape.length - 1; i >= 1; i--) {
+                               assertEquals("Suffix product at dim " + i + " 
mismatch", currentExpectedSuffix, dense_X.getCumODims(i - 1));
+                               if(i > 1) {
+                                       currentExpectedSuffix *= 
expectedShape[i - 1];
+                               }
+                       }
+               }
        }
 }


Reply via email to