This is an automated email from the ASF dual-hosted git repository.
mboehm7 pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/systemds.git
The following commit(s) were added to refs/heads/main by this push:
new c00c89d670 [SYSTEMDS-3547] Tensor in-place permutation operations
c00c89d670 is described below
commit c00c89d670177f6173f38c0a65c9abe12eb05be2
Author: dogakarakas <[email protected]>
AuthorDate: Tue Mar 31 12:38:29 2026 +0200
[SYSTEMDS-3547] Tensor in-place permutation operations
Closes #2412.
Co-authored-by: bakiberkay <[email protected]>
---
.../sysds/runtime/matrix/data/LibMatrixReorg.java | 443 +++++++++++++--
src/test/java/org/apache/sysds/test/TestUtils.java | 6 +
.../TransposeInPlaceBrennerTest.java | 591 ++++++++++++++++++++-
3 files changed, 986 insertions(+), 54 deletions(-)
diff --git
a/src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixReorg.java
b/src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixReorg.java
index ffd7b17a20..f00e0015a8 100644
--- a/src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixReorg.java
+++ b/src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixReorg.java
@@ -111,7 +111,7 @@ public class LibMatrixReorg {
}
/////////////////////////
- // public interface //
+ // public interface //
/////////////////////////
public static boolean isSupportedReorgOperator( ReorgOperator op ) {
@@ -835,9 +835,9 @@ public class LibMatrixReorg {
* default, while R uses always a column-wise read, rowwise specifying
the write order and column-wise being the
* default.
*
- * @param in input matrix
- * @param rows number of rows
- * @param cols number of columns
+ * @param in input matrix
+ * @param rows number of rows
+ * @param cols number of columns
* @param rowwise if true, reshape by row
* @return output matrix
*/
@@ -852,10 +852,10 @@ public class LibMatrixReorg {
* default, while R uses always a column-wise read, rowwise specifying
the write order and column-wise being the
* default.
*
- * @param in input matrix
- * @param out output matrix
- * @param rows number of rows
- * @param cols number of columns
+ * @param in input matrix
+ * @param out output matrix
+ * @param rows number of rows
+ * @param cols number of columns
* @param rowwise if true, reshape by row
* @return output matrix
*/
@@ -870,12 +870,12 @@ public class LibMatrixReorg {
* default, while R uses always a column-wise read, rowwise specifying
the write order and column-wise being the
* default.
*
- * @param in input matrix
- * @param out output matrix
- * @param rows number of rows
- * @param cols number of columns
+ * @param in input matrix
+ * @param out output matrix
+ * @param rows number of rows
+ * @param cols number of columns
* @param rowwise if true, reshape by row
- * @param k The parallelization degree
+ * @param k The parallelization degree
* @return output matrix
*/
public static MatrixBlock reshape(MatrixBlock in, MatrixBlock out, int
rows, int cols, boolean rowwise, int k) {
@@ -939,7 +939,7 @@ public class LibMatrixReorg {
* @return list of indexed matrix values
*/
public static List<IndexedMatrixValue> reshape(IndexedMatrixValue in,
DataCharacteristics mcIn,
- DataCharacteristics
mcOut, boolean rowwise, boolean outputEmptyBlocks ) {
+
DataCharacteristics mcOut, boolean rowwise, boolean
outputEmptyBlocks ) {
//prepare inputs
MatrixIndexes ixIn = in.getIndexes();
MatrixBlock mbIn = (MatrixBlock) in.getValue();
@@ -1013,7 +1013,7 @@ public class LibMatrixReorg {
//sanity check inputs
if( !(data.getValue() instanceof MatrixBlock &&
offset.getValue() instanceof MatrixBlock) )
throw new DMLRuntimeException("Unsupported input data:
expected "+MatrixBlock.class.getName()+" but got
"+data.getValue().getClass().getName()+" and
"+offset.getValue().getClass().getName());
- if( rmRows &&
data.getValue().getNumRows()!=offset.getValue().getNumRows()
+ if( rmRows &&
data.getValue().getNumRows()!=offset.getValue().getNumRows()
|| !rmRows &&
data.getValue().getNumColumns()!=offset.getValue().getNumColumns() ){
throw new DMLRuntimeException("Dimension mismatch
between input data and offsets: ["
+data.getValue().getNumRows()+"x"+data.getValue().getNumColumns()+" vs
"+offset.getValue().getNumRows()+"x"+offset.getValue().getNumColumns());
@@ -1090,13 +1090,13 @@ public class LibMatrixReorg {
* CP rexpand operation (single input, single output), the classic
example of this operation is one hot encoding of a
* column to multiple columns.
*
- * @param in Input matrix
- * @param ret Output matrix
- * @param max Number of rows/cols of the output
+ * @param in Input matrix
+ * @param ret Output matrix
+ * @param max Number of rows/cols of the output
* @param rows If the expansion is in rows direction
* @param cast If the values contained should be cast to double
(rounded up and down)
* @param ignore Ignore if the input contain values below zero that
technically is incorrect input.
- * @param k Degree of parallelism
+ * @param k Degree of parallelism
* @return Output matrix rexpanded
*/
public static MatrixBlock rexpand(MatrixBlock in, MatrixBlock ret,
double max, boolean rows, boolean cast, boolean ignore, int k) {
@@ -1107,13 +1107,13 @@ public class LibMatrixReorg {
* CP rexpand operation (single input, single output), the classic
example of this operation is one hot encoding of a
* column to multiple columns.
*
- * @param in Input matrix
- * @param ret Output matrix
- * @param max Number of rows/cols of the output
+ * @param in Input matrix
+ * @param ret Output matrix
+ * @param max Number of rows/cols of the output
* @param rows If the expansion is in rows direction
* @param cast If the values contained should be cast to double
(rounded up and down)
* @param ignore Ignore if the input contain values below zero that
technically is incorrect input.
- * @param k Degree of parallelism
+ * @param k Degree of parallelism
* @return Output matrix rexpanded
*/
public static MatrixBlock rexpand(MatrixBlock in, MatrixBlock ret, int
max, boolean rows, boolean cast, boolean ignore, int k){
@@ -1144,8 +1144,8 @@ public class LibMatrixReorg {
* ret = table(seq(1, nrow(A)), A, w)
*
* @param seqHeight A sequence vector height.
- * @param A The MatrixBlock vector to encode.
- * @param w The weight matrix to multiply on output cells.
+ * @param A The MatrixBlock vector to encode.
+ * @param w The weight matrix to multiply on output cells.
* @return A new MatrixBlock with the table result.
*/
public static MatrixBlock fusedSeqRexpand(int seqHeight, MatrixBlock A,
double w) {
@@ -1159,12 +1159,12 @@ public class LibMatrixReorg {
* ret = table(seq(1, nrow(A)), A, w)
*
* @param seqHeight A sequence vector height.
- * @param A The MatrixBlock vector to encode.
- * @param w The weight scalar to multiply on output cells.
- * @param ret The output MatrixBlock, does not have to be used,
but depending on updateClen determine the
- * output size.
+ * @param A The MatrixBlock vector to encode.
+ * @param w The weight scalar to multiply on output cells.
+ * @param ret The output MatrixBlock, does not have to be
used, but depending on updateClen determine the
+ * output size.
* @param updateClen Update clen, if set to true, ignore dimensions of
ret, otherwise use the column dimension of
- * ret.
+ * ret.
* @return A new MatrixBlock or ret.
*/
public static MatrixBlock fusedSeqRexpand(int seqHeight, MatrixBlock A,
double w, MatrixBlock ret,
@@ -1179,12 +1179,12 @@ public class LibMatrixReorg {
* ret = table(seq(1, nrow(A)), A, w)
*
* @param seqHeight A sequence vector height.
- * @param A The MatrixBlock vector to encode.
- * @param w The weight matrix to multiply on output cells.
- * @param ret The output MatrixBlock, does not have to be used,
but depending on updateClen determine the
- * output size.
+ * @param A The MatrixBlock vector to encode.
+ * @param w The weight matrix to multiply on output cells.
+ * @param ret The output MatrixBlock, does not have to be
used, but depending on updateClen determine the
+ * output size.
* @param updateClen Update clen, if set to true, ignore dimensions of
ret, otherwise use the column dimension of
- * ret.
+ * ret.
* @param k Parallelization degree
* @return A new MatrixBlock or ret.
*/
@@ -1318,7 +1318,7 @@ public class LibMatrixReorg {
/**
* Quick check if the input is valid for rexpand, this check does not
guarantee that the input is valid for rexpand
*
- * @param in Input matrix block
+ * @param in Input matrix block
* @param ignore If zero valued cells should be ignored
*/
public static void checkRexpand(MatrixBlock in, boolean ignore){
@@ -1330,12 +1330,12 @@ public class LibMatrixReorg {
/**
* MR/Spark rexpand operation (single input, multiple outputs incl
empty blocks)
*
- * @param data Input indexed matrix block
- * @param max Total nrows/cols of the output
- * @param rows If the expansion is in rows direction
- * @param cast If the values contained should be cast to double
(rounded up and down)
+ * @param data Input indexed matrix block
+ * @param max Total nrows/cols of the output
+ * @param rows If the expansion is in rows direction
+ * @param cast If the values contained should be cast to double
(rounded up and down)
* @param ignore Ignore if the input contain values below zero that
technically is incorrect input.
- * @param blen The block size to slice the output up into
+ * @param blen The block size to slice the output up into
* @param outList The output indexedMatrixValues (a list to add all the
output blocks to / modify)
*/
public static void rexpand(IndexedMatrixValue data, double max, boolean
rows, boolean cast, boolean ignore, long blen, ArrayList<IndexedMatrixValue>
outList) {
@@ -3416,7 +3416,7 @@ public class LibMatrixReorg {
}
private static void createNonZeroIndexes(DataCharacteristics mcIn,
DataCharacteristics mcOut,
- MatrixBlock in, long
row_offset, long col_offset, boolean rowwise, HashSet<MatrixIndexes> ret) {
+
MatrixBlock in, long row_offset, long col_offset, boolean rowwise,
HashSet<MatrixIndexes> ret) {
Iterator<IJV> iter = in.getSparseBlockIterator();
while( iter.hasNext() ) {
IJV cell = iter.next();
@@ -3444,7 +3444,7 @@ public class LibMatrixReorg {
}
private static void reshapeDense(MatrixBlock in, long row_offset, long
col_offset, Map<MatrixIndexes,MatrixBlock> rix,
- DataCharacteristics mcIn,
DataCharacteristics mcOut, boolean rowwise ) {
+
DataCharacteristics mcIn, DataCharacteristics mcOut, boolean rowwise ) {
if( in.isEmptyBlock(false) )
return;
@@ -3480,7 +3480,7 @@ public class LibMatrixReorg {
}
private static void reshapeSparse(MatrixBlock in, long row_offset, long
col_offset, Map<MatrixIndexes,MatrixBlock> rix,
- DataCharacteristics mcIn,
DataCharacteristics mcOut, boolean rowwise ) {
+
DataCharacteristics mcIn, DataCharacteristics mcOut, boolean rowwise ) {
if( in.isEmptyBlock(false) )
return;
@@ -3542,7 +3542,7 @@ public class LibMatrixReorg {
}
private static MatrixIndexes computeInBlockIndex(MatrixIndexes ixout,
long ai, long aj,
- DataCharacteristics
mcIn, DataCharacteristics mcOut, boolean rowwise )
+
DataCharacteristics mcIn, DataCharacteristics mcOut,
boolean rowwise )
{
long tempc = computeGlobalCellIndex(mcIn, ai, aj, rowwise);
long ci = rowwise ?
(tempc/mcOut.getCols())%mcOut.getBlocksize() :
@@ -4517,13 +4517,13 @@ public class LibMatrixReorg {
* https://dl.acm.org/doi/pdf/10.1145/355611.362542.
*
* @param matrix The matrix whose elements are being shifted.
- * @param moved Boolean array tracking whether an element has
already been moved.
- * @param rows The number of rows in the matrix.
+ * @param moved Boolean array tracking whether an element has already
been moved.
+ * @param rows The number of rows in the matrix.
* @param maxIndex The maximum valid index in the matrix.
- * @param count The number of elements left to process.
+ * @param count The number of elements left to process.
* @param workSize The length of moved.
- * @param start The starting index for the cycle shift.
- * @param comp The corresponding companion index.
+ * @param start The starting index for the cycle shift.
+ * @param comp The corresponding companion index.
* @return The updated count of elements remaining to shift.
*/
private static int simultaneousCycleShift(double[] matrix, boolean[]
moved, int rows, int maxIndex, int count,
@@ -4585,10 +4585,10 @@ public class LibMatrixReorg {
* Performs prime factorization of a given number n. The method
calculates the prime factors of n, their exponents,
* powers and stores the results in the provided arrays.
*
- * @param n The number to be factorized.
- * @param primes Array to store the unique prime factors of n.
+ * @param n The number to be factorized.
+ * @param primes Array to store the unique prime factors of n.
* @param exponents Array to store the exponents of the respective
prime factors.
- * @param powers Array to store the powers of the respective prime
factors.
+ * @param powers Array to store the powers of the respective
prime factors.
* @return The number of unique prime factors.
*/
private static int primeFactorization(int n, int[] primes, int[]
exponents, int[] powers) {
@@ -4635,4 +4635,341 @@ public class LibMatrixReorg {
}
return count;
}
+
+ // TENSOR
+ /**
+ * Performs prime in-place tensor transposition for arbitrary
permutations.
+ *
+ * @param in
+ * Tensor stored as MatrixBlock
+ * @param shape
+ * Original shape information of tensor
+ * @param perm
+ * Permutation of tensor
+ */
+ // (A) If permutation is split-index reducible -> reduce to 2D and use
transposeInPlaceDenseBrenner()
+ // (B) Else -> decompose perm into adjacent swaps and apply each via
1324 primitive from EITHOT algorithm
+ // (https://dl.acm.org/doi/10.1145/3711871)
+ // with Brenner's method instead of Catanzaro's algorithm for
generalizability to arbitrary large dimensions
+ // ------------------------------------------------------------
+ public static boolean transposeInPlaceTensor(MatrixBlock in, int[]
shape, int[] perm) {
+ final int rank = shape.length;
+
+ // final shape
+ final int[] finalShape = new int[rank];
+ for (int i = 0; i < rank; i++)
+ finalShape[i] = shape[perm[i]];
+
+ // Identity perm -> metadata only
+ boolean identity = true;
+ for (int i = 0; i < rank; i++) {
+ if (perm[i] != i) {
+ identity = false;
+ break;
+ }
+ }
+ if (identity) {
+ restoreMetadata(in, finalShape);
+ return true;
+ }
+
+ // (A) Split-index reducible
+ int splitIdx = findSplitIndex(perm);
+ if (splitIdx != -1) {
+ int newRows = 1;
+ for (int i = 0; i < splitIdx; i++)
+ newRows *= shape[perm[i]];
+
+ long newColsL = 1;
+ for (int i = splitIdx; i < rank; i++)
+ newColsL *= shape[perm[i]];
+ int newCols = (int) newColsL;
+
+ try {
+ in.setNumRows(newCols);
+ in.setNumColumns(newRows);
+ transposeInPlaceDenseBrenner(in, 1);
+ } finally {
+ restoreMetadata(in, finalShape);
+ }
+ return true;
+ }
+
+ // (B) General path: usage of 1324 primitive
+
+ final double[] tensor = in.getDenseBlockValues();
+ int[] curShape = Arrays.copyOf(shape, rank);
+ in.getDenseBlock().setDims(curShape);
+
+ // plan adjacent swaps to realize perm
+ int[] swaps = permutationToAdjacentSwaps(rank, perm);
+ int swapCount = swaps[0];
+
+ for (int s = 1; s <= swapCount; s++) {
+ int k = swaps[s];
+ reshape1324(in, tensor, curShape, k);
+ int tmp = curShape[k];
+ curShape[k] = curShape[k + 1];
+ curShape[k + 1] = tmp;
+ in.getDenseBlock().setDims(curShape);
+ }
+
+ restoreMetadata(in, finalShape);
+ return true;
+ }
+
+ /**
+ * Applies a single adjacent-axis swap (k <-> k+1) to a dense tensor
**in-place** by reducing it to a rank-4 view
+ * and calling primitive 1324.
+ *
+ * @param in
+ * MatrixBlock holding the dense tensor buffer
(metadata is temporarily modified)
+ * @param a
+ * backing dense buffer (row-major, contiguous)
+ * @param curShape
+ * current logical tensor shape/order before
applying this adjacent swap
+ * @param k
+ * adjacent axis index to swap (swaps axis k with
axis k+1)
+ */
+ private static void reshape1324(MatrixBlock in, double[] a, int[]
curShape, int k) {
+ int lastDim = curShape.length;
+
+ int left = prod(curShape, 0, k);
+ int A = curShape[k];
+ int B = curShape[k + 1];
+ int right = prod(curShape, k + 2, lastDim);
+
+ // metadata-only reshape to 4D
+ in.getDenseBlock().setDims(new int[] { left, A, B, right });
+
+ // in-place 1324 on that view
+ prim1324(a, 0, left, A, B, right);
+
+ // caller restores dims to curShape after it swaps
curShape[k],curShape[k+1]
+ }
+
+ private static int prod(int[] shape, int start, int end) {
+ long p = 1;
+ for (int i = start; i < end; i++) {
+ p *= shape[i];
+ }
+ return (int) p;
+ }
+
+ /**
+ * Decomposes an arbitrary permutation into a sequence of adjacent
swaps.
+ *
+ * @param rank
+ * tensor rank
+ * @param perm
+ * target permutation (maps output axis i to input
axis perm[i])
+ *
+ * @return swap plan array
+ */
+ private static int[] permutationToAdjacentSwaps(int rank, int[] perm) {
+ int[] order = new int[rank];
+ // original order of permutation
+ for (int i = 0; i < rank; i++)
+ order[i] = i;
+
+ int maxSwaps = rank * (rank - 1) / 2;
+ int[] out = new int[maxSwaps + 1]; // stores swap order
+ int cnt = 0; // number of swaps needed
+
+ for (int targetPos = 0; targetPos < rank; targetPos++) {
+ int wantedAxis = perm[targetPos];
+
+ // index of dimension in current permutation
+ int curPos = -1;
+ for (int p = targetPos; p < rank; p++) {
+ if (order[p] == wantedAxis) {
+ curPos = p;
+ break;
+ }
+ }
+ if (curPos < 0)
+ throw new IllegalArgumentException("Invalid
perm");
+
+ while (curPos > targetPos) {
+ int t = order[curPos - 1];
+ order[curPos - 1] = order[curPos];
+ order[curPos] = t;
+
+ out[++cnt] = curPos - 1;
+ curPos--;
+ }
+ }
+
+ out[0] = cnt;
+ return out;
+ }
+
+ /**
+ * Primitive {@code 1324}: swaps dimensions 2 and 3 while keeping
dimensions 1 and 4 fixed:
+ *
+ * @param a
+ * dense buffer
+ * @param offset
+ * base offset into {a} (usually 0)
+ * @param d1
+ * first dimension (number of slices)
+ * @param d2
+ * second dimension (matrix rows)
+ * @param d3
+ * third dimension (matrix cols)
+ * @param d4
+ * fourth dimension (block length per matrix cell)
+ */
+ private static void prim1324(double[] a, int offset, int d1, int d2,
int d3, int d4) {
+ for (int i1 = 0; i1 < d1; i1++) {
+ int slice = d2 * d3 * d4;
+ int base = offset + i1 * slice;
+ transposeBlocksInPlace(a, base, d2, d3, d4);
+ }
+ }
+
+ /**
+ * In-place transpose of a matrix where each element is a contiguous
block of length {blk}. Performs a cycle-walk
+ * permutation over block positions induced by transpose. For each
unvisited start position, we rotate blocks along
+ * its cycle using one temporary block buffer.
+ *
+ * @param tensor
+ * backing dense buffer
+ * @param base
+ * offset of the (d2*d3*blk) region
+ * @param d2
+ * number of rows in the block-matrix
+ * @param d3
+ * number of columns in the block-matrix
+ * @param blk
+ * block length (number of doubles per cell), d4
+ */
+ private static void transposeBlocksInPlace(double[] tensor, int base,
int d2, int d3, int blk) {
+ int numBlocks = d2 * d3;
+ boolean[] visited = new boolean[numBlocks];
+ double[] tmp = new double[blk]; // buffer for one block
+
+ for (int start = 0; start < numBlocks; start++) {
+ if (visited[start])
+ continue;
+
+ int next = transposeBlockIndex(start, d2, d3);
+
+ // no movement
+ if (next == start) {
+ visited[start] = true;
+ continue;
+ }
+
+ // save start
+ System.arraycopy(tensor, base + start * blk, tmp, 0,
blk);
+
+ // cycle-following
+ int cur = start;
+ while (true) {
+ visited[cur] = true;
+ int prev = inverseTransposeBlockIndex(cur, d2,
d3);
+ if (prev == start)
+ break;
+
+ System.arraycopy(tensor, base + prev * blk,
tensor, base + cur * blk, blk);
+ cur = prev;
+ }
+
+ System.arraycopy(tmp, 0, tensor, base + cur * blk, blk);
+ visited[cur] = true;
+ }
+ }
+
+ /**
+ * Finds the target index of a current block
+ *
+ * @param block_idx
+ * index of block
+ * @param m
+ * number of rows
+ * @param n
+ * number of columns
+ *
+ * @return new block idx
+ */
+ private static int transposeBlockIndex(int block_idx, int m, int n) {
+ int i = block_idx / n;
+ int j = block_idx % n;
+ return j * m + i;
+ }
+
+ /**
+ * Finds the idx of the element which moves to the current block index
during permutation
+ *
+ * @param curr_block_idx
+ * index of current block
+ * @param m
+ * number of rows
+ * @param n
+ * number of columns
+ *
+ * @return source block idx (the block that moves into the current position)
+ */
+ private static int inverseTransposeBlockIndex(int curr_block_idx, int
m, int n) {
+ int i = curr_block_idx % m;
+ int j = curr_block_idx / m;
+ return i * n + j;
+ }
+
+ /**
+ * Finds a split index for a tensor permutation that allows reduction
of the permutation to a 2D matrix transpose.
+ * @param perm permutation of tensor axes
+ * @return split index {i} if reducible, otherwise {-1}
+ */
+ public static int findSplitIndex(int[] perm) {
+ if (perm == null || perm.length < 2)
+ return -1;
+ int n = perm.length;
+
+ for (int i = 1; i < n; i++) {
+ boolean contiguousFirst = isContiguousRange(perm, 0, i);
+ boolean contiguousSecond = isContiguousRange(perm, i,
n);
+
+ if (contiguousFirst && contiguousSecond) {
+ if (isSorted(perm, 0, i) && isSorted(perm, i,
n)) {
+ return i;
+ }
+ }
+ }
+ return -1;
+ }
+
+ private static boolean isSorted(int[] perm, int start, int end) {
+ for (int i = start; i < end - 1; i++)
+ if (perm[i] > perm[i + 1])
+ return false;
+ return true;
+ }
+
+ private static boolean isContiguousRange(int[] perm, int start, int
end) {
+ int min = perm[start], max = perm[start];
+ for (int i = start + 1; i < end; i++) {
+ if (perm[i] < min)
+ min = perm[i];
+ if (perm[i] > max)
+ max = perm[i];
+ }
+ return (max - min + 1) == (end - start);
+ }
+
+ /**
+ * Restores SystemDS matrix/tensor metadata after an in-place tensor
permutation.
+ * @param in matrix/tensor block whose metadata is restored
+ * @param finalShape final tensor shape after permutation
+ */
+ private static void restoreMetadata(MatrixBlock in, int[] finalShape) {
+ in.setNumRows(finalShape[0]);
+ long totalRemaining = 1;
+ for (int i = 1; i < finalShape.length; i++)
+ totalRemaining *= finalShape[i];
+ in.setNumColumns((int) totalRemaining);
+ if (in.getDenseBlock() != null)
+ in.getDenseBlock().setDims(finalShape);
+ }
}
diff --git a/src/test/java/org/apache/sysds/test/TestUtils.java
b/src/test/java/org/apache/sysds/test/TestUtils.java
index e470dd8253..5ebc243dd4 100644
--- a/src/test/java/org/apache/sysds/test/TestUtils.java
+++ b/src/test/java/org/apache/sysds/test/TestUtils.java
@@ -3990,4 +3990,10 @@ public class TestUtils {
return 2;
}
}
+
+ public static void compareTensorValues(MatrixBlock actual, MatrixBlock
expected, double epsilon) {
+ double[] a = actual.getDenseBlockValues();
+ double[] e = expected.getDenseBlockValues();
+ compareMatrices(e, a, epsilon);
+ }
}
diff --git
a/src/test/java/org/apache/sysds/test/component/matrix/libMatrixReorg/TransposeInPlaceBrennerTest.java
b/src/test/java/org/apache/sysds/test/component/matrix/libMatrixReorg/TransposeInPlaceBrennerTest.java
index 7b575cf37c..50e79279b6 100644
---
a/src/test/java/org/apache/sysds/test/component/matrix/libMatrixReorg/TransposeInPlaceBrennerTest.java
+++
b/src/test/java/org/apache/sysds/test/component/matrix/libMatrixReorg/TransposeInPlaceBrennerTest.java
@@ -19,16 +19,19 @@
package org.apache.sysds.test.component.matrix.libMatrixReorg;
+import org.apache.sysds.runtime.data.DenseBlock;
import org.apache.sysds.runtime.matrix.data.LibMatrixReorg;
import org.apache.sysds.runtime.matrix.data.MatrixBlock;
import org.apache.sysds.test.TestUtils;
import org.junit.Test;
+import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertThrows;
import static org.junit.Assert.assertTrue;
public class TransposeInPlaceBrennerTest {
+
@Test
public void transposeInPlaceDenseBrennerOnePrime() {
// 3*4-1 = 11
@@ -102,9 +105,595 @@ public class TransposeInPlaceBrennerTest {
private void testTransposeInPlaceDense(int rows, int cols, double
sparsity) {
MatrixBlock X = MatrixBlock.randOperations(rows, cols,
sparsity);
MatrixBlock tX = LibMatrixReorg.transpose(X);
+ LibMatrixReorg.transposeInPlaceDenseBrenner(X, 1);
+
+
+ TestUtils.compareMatrices(X, tX, 0);
+ }
+
+ // Tests for tensor permutations
+ @Test
+ public void testTensorPermuteSplit3D() {
+ int[] shape = {50,2,10};
+ int[] perm = {1,2,0};
+ testTransposeInPlaceTensor(shape, perm);
+ }
+
+ @Test
+ public void testTensorPermuteSplit8D() {
+ int[] shape = {3,2,1,3,2,3,1,2};
+ int[] perm = {4,5,6,7,0,1,2,3};
+ testTransposeInPlaceTensor(shape, perm);
+ }
+
+ @Test
+ public void testTensorPermuteSplit4D() {
+ int[] shape = {3,2,5,3};
+ int[] perm = {2,3,0,1};
+ testTransposeInPlaceTensor(shape, perm);
+ }
+
+ @Test
+ public void testTensorPermuteSplit2D21() {
+ int[] shape = {4, 10};
+ int[] perm = {1,0};
+ testTransposeInPlaceTensor(shape, perm);
+ }
+
+ //Test for primitives
+ @Test
+ public void testTensorPermute3D213() {
+ int[] shape = {4, 2, 7};
+ int[] perm = {1,0,2};
+ testTransposeInPlaceTensor(shape, perm);
+ }
+
+ @Test
+ public void testTensorPermute3D132() {
+ int[] shape = {3, 4, 2};
+ int[] perm = {0, 2, 1};
+ testTransposeInPlaceTensor(shape, perm);
+ }
+
+ @Test
+ public void testTensorPermute4D1324() {
+ int[] shape = {3, 2, 2, 3};
+ int[] perm = {0, 2, 1, 3};
+ testTransposeInPlaceTensor(shape, perm);
+ }
+
+ @Test
+ public void testTensorPermute4Db1324() {
+ int[] shape = {3, 4, 5, 6};
+ int[] perm = {0, 2, 1, 3};
+ testTransposeInPlaceTensor(shape, perm);
+ }
+
+ @Test
+ public void testTensorPermuteSplit5D() {
+ int[] shape = {2, 3, 4, 5, 6};
+ int[] perm = {2, 3, 4, 0, 1};
+ testTransposeInPlaceTensor(shape, perm);
+ }
+
+ @Test
+ public void testTensorPermuteSplit6D() {
+ int[] shape = {4, 3, 2, 5, 8, 2};
+ int[] perm = {3, 4, 5, 0, 1, 2};
+ testTransposeInPlaceTensor(shape, perm);
+ }
+
+ @Test
+ public void testTensorPermuteSplit5DMiddleSwap() {
+ int[] shape = {2, 6, 2, 4, 5};
+ int[] perm = {4, 3, 2, 1, 0};
+ testTransposeInPlaceTensor(shape, perm);
+ }
+
+ @Test
+ public void testTensorPermute5DMiddleSwapComplex() {
+ int[] shape = {2, 2, 3, 4, 2};
+ int[] perm = {0, 2, 1, 3, 4};
+ testTransposeInPlaceTensor(shape, perm);
+ }
+
+ @Test
+ public void testTensorPermute7Db() {
+ int[] shape = {20, 30, 15, 5, 2, 5, 2};
+ int[] perm = {0, 6, 1, 5, 4, 2, 3};
+ testTransposeInPlaceTensor(shape, perm);
+ }
+
+ @Test
+ public void testTensorPermute7D() {
+ int[] shape = {2, 3, 5, 5, 2, 3, 2};
+ int[] perm = {0, 6, 1, 5, 4, 2, 3};
+ testTransposeInPlaceTensor(shape, perm);
+ }
+
+ @Test
+ public void testTensorPermuteSplitMax2() {
+ int[] shape = {1000, 300, 100};
+ int[] perm = {2, 0, 1};
+ testTransposeInPlaceTensor(shape, perm);
+ }
+ @Test
+ public void testTensorPermuteSplitMax3() {
+ int[] shape = {8000, 4000, 2};
+ int[] perm = {2, 0, 1};
+ testTransposeInPlaceTensor(shape, perm);
+ }
+
+ @Test
+ public void testTensorPermute3DAllCases() {
+ int[] shape = {2, 3, 2};
+ int[] perm1 = {0,1, 2};
+ int[] perm2 = {0,2, 1};
+ int[] perm3 = {1,0,2};
+ int[] perm4 = {1,2,0};
+ int[] perm5 = {2,0,1};
+ int[] perm6 = {2,1,0};
+ testTransposeInPlaceTensor(shape, perm1);
+ testTransposeInPlaceTensor(shape, perm2);
+ testTransposeInPlaceTensor(shape, perm3);
+ testTransposeInPlaceTensor(shape, perm4);
+ testTransposeInPlaceTensor(shape, perm5);
+ testTransposeInPlaceTensor(shape, perm6);
+ }
+
+ @Test
+ public void testTensorPermuteSplit4Db213() {
+ int[] shape = {2, 3, 4};
+ int[] perm = {1, 0, 2};
+ testTransposeInPlaceTensor(shape, perm);
+ }
+
+ @Test
+ public void testTensorPermuteSplit4Db132() {
+ int[] shape = {2, 3, 4};
+ int[] perm = {0, 2, 1};
+ testTransposeInPlaceTensor(shape, perm);
+ }
+
+ // Edge case tests
+
+ // 1. Square matrices
+ @Test
+ public void transposeInPlaceDenseSquare5x5() {
+ testTransposeInPlaceDense(5, 5, 0.8);
+ }
+
+ @Test
+ public void transposeInPlaceDenseSquare100x100() {
+ testTransposeInPlaceDense(100, 100, 0.7);
+ }
+
+ @Test
+ public void testTensorPermute3DSquareDims() {
+ int[] shape = {4, 4, 4};
+ int[] perm = {2, 0, 1};
+ testTransposeInPlaceTensor(shape, perm);
+ }
+
+ // 2. Vectors (1×N and N×1)
+ @Test
+ public void transposeInPlaceDenseRowVector() {
+ testTransposeInPlaceDense(1, 50, 0.9);
+ }
+
+ @Test
+ public void transposeInPlaceDenseColVector() {
+ testTransposeInPlaceDense(50, 1, 0.9);
+ }
+
+ @Test
+ public void testTensorPermuteVectorLike() {
+ int[] shape = {1, 20};
+ int[] perm = {1, 0};
+ testTransposeInPlaceTensor(shape, perm);
+ }
+
+ // 3. Single element
+ @Test
+ public void transposeInPlaceDenseSingleElement() {
+ testTransposeInPlaceDense(1, 1, 1.0);
+ }
+
+ @Test
+ public void testTensorPermute_SingleElement() {
+ int[] shape = {1, 1, 1};
+ int[] perm = {2, 1, 0};
+ testTransposeInPlaceTensor(shape, perm);
+ }
+
+ // 4. Prime dimensions
+ @Test
+ public void transposeInPlaceDensePrime7x11() {
+ testTransposeInPlaceDense(7, 11, 0.75);
+ }
+
+ @Test
+ public void transposeInPlaceDensePrime13x17() {
+ testTransposeInPlaceDense(13, 17, 0.82);
+ }
+
+ @Test
+ public void testTensorPermuteAllPrimeDims() {
+ int[] shape = {3, 5, 7};
+ int[] perm = {1, 2, 0};
+ testTransposeInPlaceTensor(shape, perm);
+ }
+
+ // 5. Power of 2 dimensions (common in computing, just to be sure)
+ @Test
+ public void transposeInPlaceDensePowerOf264x128() {
+ testTransposeInPlaceDense(64, 128, 0.6);
+ }
+
+ @Test
+ public void transposeInPlaceDensePowerOf232x64() {
+ testTransposeInPlaceDense(32, 64, 0.85);
+ }
+
+ @Test
+ public void testTensorPermutePowerOf2Dims() {
+ int[] shape = {8, 16, 4};
+ int[] perm = {2, 1, 0};
+ testTransposeInPlaceTensor(shape, perm);
+ }
+
+ // 7. Consecutive transpose (should return to original)
+ @Test
+ public void transposeInPlaceDenseConsecutiveTwice() {
+ MatrixBlock X = MatrixBlock.randOperations(7, 13, 0.75);
+ MatrixBlock original = new MatrixBlock(X);
+
+ LibMatrixReorg.transposeInPlaceDenseBrenner(X, 1);
LibMatrixReorg.transposeInPlaceDenseBrenner(X, 1);
+
+ TestUtils.compareMatrices(X, original, 0);
+ }
- TestUtils.compareMatrices(X, tX, 0);
+ @Test
+ public void testTensorPermuteConsecutiveTwice() {
+ int[] shape = {3, 4, 5};
+ int[] perm = {1, 2, 0};
+
+ MatrixBlock matrix = createDenseTensor(shape);
+ MatrixBlock original = new MatrixBlock(matrix);
+
+ LibMatrixReorg.transposeInPlaceTensor(matrix, shape, perm);
+ // Apply reverse permutation to get back
+ int[] reversePerm = new int[perm.length];
+ for (int i = 0; i < perm.length; i++) {
+ reversePerm[perm[i]] = i;
+ }
+ int[] newShape = new int[shape.length];
+ for (int i = 0; i < perm.length; i++) {
+ newShape[i] = shape[perm[i]];
+ }
+ LibMatrixReorg.transposeInPlaceTensor(matrix, newShape,
reversePerm);
+
+ TestUtils.compareMatrices(matrix, original, 0);
+ }
+
+ // 8.tensors with dimension=1
+ @Test
+ public void testTensorPermuteWithDim1Case1() {
+ int[] shape = {1, 5, 3};
+ int[] perm = {2, 0, 1};
+ testTransposeInPlaceTensor(shape, perm);
+ }
+
+ @Test
+ public void testTensorPermuteWithDim1Case2() {
+ int[] shape = {4, 1, 2, 1};
+ int[] perm = {2, 3, 0, 1};
+ testTransposeInPlaceTensor(shape, perm);
+ }
+
+ @Test
+ public void testTensorPermuteWithDim1Case3() {
+ int[] shape = {3, 1, 4};
+ int[] perm = {1, 2, 0};
+ testTransposeInPlaceTensor(shape, perm);
+ }
+
+ // 9. Invalid permutations (negative tests)
+ // NOTE: more detailed error handling can be added in the future,
currently these are just checking for exceptions
+ @Test
+ public void testTensorPermute_InvalidPerm_OutOfRange() {
+ int[] shape = {2, 3, 4};
+ int[] perm = {0, 1, 3}; // 3 is out of range for 3D tensor
+
+ MatrixBlock matrix = createDenseTensor(shape);
+
+ assertThrows(Exception.class,
+ () -> LibMatrixReorg.transposeInPlaceTensor(matrix,
shape, perm));
+ }
+
+
+ @Test
+ public void testTensorPermuteInvalidPermWrongLength() {
+ int[] shape = {2, 3, 4};
+ int[] perm = {0, 1}; // only 2 elements but 3d tensor
+
+ MatrixBlock matrix = createDenseTensor(shape);
+
+ assertThrows(Exception.class,
+ () -> LibMatrixReorg.transposeInPlaceTensor(matrix,
shape, perm));
+ }
+
+ /** Expects an exception when the permutation contains a negative axis index. */
+ @Test
+ public void testTensorPermuteInvalidPermNegative() {
+     final int[] dims = {2, 3, 4};
+     final MatrixBlock tensor = createDenseTensor(dims);
+
+     // -1 is a negative index and therefore not a valid axis.
+     final int[] negativePerm = {-1, 1, 2};
+
+     assertThrows(Exception.class, () -> LibMatrixReorg.transposeInPlaceTensor(tensor, dims, negativePerm));
+ }
+
+ // 10. Null/empty inputs
+ /** Expects tensor creation itself to fail when the shape has no dimensions. */
+ @Test
+ public void testTensorPermuteEmptyShape() {
+     final int[] noDims = new int[0];
+     assertThrows(Exception.class, () -> createDenseTensor(noDims));
+ }
+
+ /** Expects an exception when a null block is passed with an otherwise valid 2D shape and permutation. */
+ @Test
+ public void testTensorPermuteNullMatrix() {
+     final int[] dims = {2, 3};
+     final int[] swapAxes = {1, 0};
+
+     assertThrows(Exception.class, () -> LibMatrixReorg.transposeInPlaceTensor(null, dims, swapAxes));
+ }
+
+
+ /**
+  * Creates a dense block for the given tensor shape and fills its value array
+  * with 0, 1, 2, ... in array order. The block has shape[0] rows and
+  * (product of all dimensions) / shape[0] columns, and the shape is set on
+  * the dense block via setDims.
+  *
+  * @param shape tensor dimensions; must be non-empty with a non-zero first dimension
+  * @return the filled block
+  * @throws IllegalArgumentException if shape is null or empty, if the first dimension is zero,
+  *         or if the total number of cells exceeds Integer.MAX_VALUE
+  */
+ private static MatrixBlock createDenseTensor(int[] shape) {
+     // Fail explicitly instead of relying on an incidental ArrayIndexOutOfBoundsException below.
+     if (shape == null || shape.length == 0)
+         throw new IllegalArgumentException("Tensor shape must have at least one dimension");
+
+     long size = 1;
+     for (int s : shape)
+         size *= s;
+
+     if (size > Integer.MAX_VALUE)
+         throw new IllegalArgumentException("Tensor too large: " + size);
+
+     int rows = shape[0];
+     // Guard the division below; a zero first dimension would otherwise raise ArithmeticException.
+     if (rows == 0)
+         throw new IllegalArgumentException("First tensor dimension must not be zero");
+     int cols = (int) (size / rows);
+
+     MatrixBlock matrix = new MatrixBlock(rows, cols, false);
+     matrix.allocateDenseBlock();
+
+     // Each cell holds its own position in the value array, so every cell is distinguishable.
+     double[] values = matrix.getDenseBlockValues();
+     for (int i = 0; i < values.length; i++)
+         values[i] = i;
+
+     if (matrix.getDenseBlock() != null)
+         matrix.getDenseBlock().setDims(shape);
+
+     return matrix;
+ }
+
+
+ /**
+  * Builds a dense tensor for the shape, computes the expected result with
+  * permutationOutOfPlace, then runs the in-place permutation on the same block
+  * and compares both with compareMatrices and compareTensorValues.
+  *
+  * @param shape tensor dimensions
+  * @param perm  axis permutation to apply
+  */
+ private void testTransposeInPlaceTensor(int[] shape, int[] perm) {
+     MatrixBlock actual = createDenseTensor(shape);
+
+     // The reference result is computed first, before the in-place call is made on the same block.
+     MatrixBlock reference = permutationOutOfPlace(actual, shape, perm);
+
+     LibMatrixReorg.transposeInPlaceTensor(actual, shape, perm);
+
+     TestUtils.compareMatrices(actual, reference, 0);
+     TestUtils.compareTensorValues(actual, reference, 0);
+ }
+
+ //returns the expected matrix (found out-of-place) for comparision
+ private MatrixBlock permutationOutOfPlace(MatrixBlock in, int[] shape,
int[] perm) {
+ int[] newShape = new int[shape.length];
+ for(int i=0; i<perm.length; i++){
+ newShape[i] = shape[perm[i]];
+ }
+
+ int newRows = newShape[0];
+ long newCols = 1;
+ for(int i = 1; i < newShape.length; i++) {
+ newCols *= newShape[i];
+ }
+
+ MatrixBlock out = new MatrixBlock(newRows, (int)newCols, false);
+ out.allocateDenseBlock();
+
+ double[] inVal = in.getDenseBlockValues();
+ double[] outVal = out.getDenseBlockValues();
+
+ int[] originalCoords = new int[shape.length];
+ int[] permCoords = new int[shape.length];
+
+ for(int i = 0; i < inVal.length; i++) {
+ getCoords(i, shape, originalCoords);
+ for(int j = 0; j < perm.length; j++) {
+ permCoords[j] = originalCoords[perm[j]];
+ }
+ int outIdx = getIndex(permCoords, newShape);
+ outVal[outIdx] = inVal[i];
+ }
+
+ out.setNumRows(newShape[0]);
+ long cols = 1;
+ for(int i=1; i<newShape.length; i++){
+ cols *= newShape[i];
+ }
+ out.setNumColumns((int) cols);
+ out.getDenseBlock().setDims(newShape);
+
+ return out;
+ }
+
+ private void getCoords(int index, int[] shape, int[] originalCoords) {
+ for (int i = shape.length - 1; i >= 0; i--) {
+ originalCoords[i] = index % shape[i];
+ index /= shape[i];
+ }
+ }
+
+ /**
+  * Encodes per-dimension coordinates into a linear index, treating the last
+  * dimension as the fastest-varying one (stride 1).
+  *
+  * @param coords one coordinate per dimension
+  * @param shape  tensor dimensions
+  * @return the linear index of the cell
+  */
+ private int getIndex(int[] coords, int[] shape) {
+     int linear = 0;
+     int stride = 1;
+     int dim = shape.length;
+     while (dim > 0) {
+         dim--;
+         linear += coords[dim] * stride;
+         stride *= shape[dim];
+     }
+     return linear;
+ }
+
+ // Tests for correct metadata after permutation
+ /**
+  * Allocates a dense (cells x 1) block for a 2x3x4x5x6x7 tensor, rotates the
+  * first axis to the end in place, and checks the resulting shape metadata.
+  */
+ @Test
+ public void testTensorPermuteSplitShape6D() {
+     final int[] shape = {2, 3, 4, 5, 6, 7};
+     final int[] perm = {1, 2, 3, 4, 5, 0};
+
+     // Total number of cells is the product of all dimensions.
+     long cells = 1;
+     for (int d = 0; d < shape.length; d++) {
+         cells *= shape[d];
+     }
+
+     MatrixBlock block = new MatrixBlock((int) cells, 1, false);
+     block.allocateDenseBlock();
+     LibMatrixReorg.transposeInPlaceTensor(block, shape, perm);
+     testTransposeInPlaceTensorShape(block, shape, perm);
+ }
+
+ /**
+  * Same check as the small 6D case, but on a 1000x500x20x2x2x2 tensor
+  * (80,000,000 cells): permute in place, then verify the shape metadata.
+  */
+ @Test
+ public void testTensorPermuteSplitShape6DMax() {
+     final int[] shape = {1000, 500, 20, 2, 2, 2};
+     final int[] perm = {1, 2, 3, 4, 5, 0};
+
+     // Total number of cells is the product of all dimensions.
+     long cells = 1;
+     for (int d = 0; d < shape.length; d++) {
+         cells *= shape[d];
+     }
+
+     MatrixBlock block = new MatrixBlock((int) cells, 1, false);
+     block.allocateDenseBlock();
+     LibMatrixReorg.transposeInPlaceTensor(block, shape, perm);
+     testTransposeInPlaceTensorShape(block, shape, perm);
+ }
+
+ /**
+  * Allocates a dense (cells x 1) block for a 100x22x70x90 tensor, rotates the
+  * first axis to the end in place, and checks the resulting shape metadata.
+  */
+ @Test
+ public void testTensorPermuteSplitShape4D() {
+     final int[] shape = {100, 22, 70, 90};
+     final int[] perm = {1, 2, 3, 0};
+
+     // Total number of cells is the product of all dimensions.
+     long cells = 1;
+     for (int d = 0; d < shape.length; d++) {
+         cells *= shape[d];
+     }
+
+     MatrixBlock block = new MatrixBlock((int) cells, 1, false);
+     block.allocateDenseBlock();
+     LibMatrixReorg.transposeInPlaceTensor(block, shape, perm);
+     testTransposeInPlaceTensorShape(block, shape, perm);
+ }
+
+
+ /**
+  * Allocates a dense (cells x 1) block for an 8D tensor, moves the first three
+  * axes to the end in place, and checks the resulting shape metadata.
+  */
+ @Test
+ public void testTensorPermuteSplitShape8D() {
+     final int[] shape = {10, 22, 7, 9, 30, 6, 4, 7};
+     final int[] perm = {3, 4, 5, 6, 7, 0, 1, 2};
+
+     // Total number of cells is the product of all dimensions.
+     long cells = 1;
+     for (int d = 0; d < shape.length; d++) {
+         cells *= shape[d];
+     }
+
+     MatrixBlock block = new MatrixBlock((int) cells, 1, false);
+     block.allocateDenseBlock();
+     LibMatrixReorg.transposeInPlaceTensor(block, shape, perm);
+     testTransposeInPlaceTensorShape(block, shape, perm);
+ }
+
+ /**
+  * Allocates a dense (cells x 1) block for a 10x8x5x4x2 tensor, swaps only
+  * axes 1 and 2 in place, and checks the resulting shape metadata.
+  */
+ @Test
+ public void testTensorPermuteSplitShape5DMiddle() {
+     final int[] shape = {10, 8, 5, 4, 2};
+     final int[] perm = {0, 2, 1, 3, 4};
+
+     // Total number of cells is the product of all dimensions.
+     long cells = 1;
+     for (int d = 0; d < shape.length; d++) {
+         cells *= shape[d];
+     }
+
+     MatrixBlock block = new MatrixBlock((int) cells, 1, false);
+     block.allocateDenseBlock();
+     LibMatrixReorg.transposeInPlaceTensor(block, shape, perm);
+     testTransposeInPlaceTensorShape(block, shape, perm);
+ }
+
+ /**
+  * Allocates a dense (cells x 1) block for a 2x3x5x2x8 tensor, moves the last
+  * two axes to the front in place, and checks the resulting shape metadata.
+  */
+ @Test
+ public void testTensorPermuteSplitShape5D() {
+     final int[] shape = {2, 3, 5, 2, 8};
+     final int[] perm = {3, 4, 0, 1, 2};
+
+     // Total number of cells is the product of all dimensions.
+     long cells = 1;
+     for (int d = 0; d < shape.length; d++) {
+         cells *= shape[d];
+     }
+
+     MatrixBlock block = new MatrixBlock((int) cells, 1, false);
+     block.allocateDenseBlock();
+     LibMatrixReorg.transposeInPlaceTensor(block, shape, perm);
+     testTransposeInPlaceTensorShape(block, shape, perm);
+ }
+
+ /**
+  * Allocates a dense (cells x 1) block for a 2x3 shape, swaps the two axes in
+  * place, and checks the resulting shape metadata.
+  */
+ @Test
+ public void testTensorPermuteSplitShape_2D() {
+     final int[] shape = {2, 3};
+     final int[] perm = {1, 0};
+
+     // Total number of cells is the product of all dimensions.
+     long cells = 1;
+     for (int d = 0; d < shape.length; d++) {
+         cells *= shape[d];
+     }
+
+     MatrixBlock block = new MatrixBlock((int) cells, 1, false);
+     block.allocateDenseBlock();
+     LibMatrixReorg.transposeInPlaceTensor(block, shape, perm);
+     testTransposeInPlaceTensorShape(block, shape, perm);
+ }
+
+ /**
+  * Verifies the shape metadata of a block after an in-place tensor permutation.
+  * The expected shape is originalShape reordered by perm. The block must report
+  * expectedShape[0] rows and the product of the remaining dimensions as columns.
+  * If a dense block is present, each of its dimensions and its cumulative
+  * suffix products (getCumODims) must match as well.
+  *
+  * @param transposed_X  block after the in-place permutation
+  * @param originalShape tensor dimensions before the permutation
+  * @param perm          axis permutation that was applied
+  */
+ private void testTransposeInPlaceTensorShape(MatrixBlock transposed_X, int[] originalShape, int[] perm){
+     int[] expectedShape = new int[originalShape.length];
+     for(int i = 0; i < perm.length; i++) {
+         expectedShape[i] = originalShape[perm[i]];
+     }
+     int expectedRows = expectedShape[0];
+     long expectedCols = 1;
+     for(int i = 1; i < expectedShape.length; i++) {
+         expectedCols *= expectedShape[i];
+     }
+
+     // MatrixBlock shape-match; columns are compared as long so an overflowing
+     // expected value cannot be hidden by truncation to int.
+     assertEquals("Matrix Rows mismatch", expectedRows, transposed_X.getNumRows());
+     assertEquals("Matrix Columns mismatch", expectedCols, (long) transposed_X.getNumColumns());
+
+     // DenseBlock shape-match
+     DenseBlock dense_X = transposed_X.getDenseBlock();
+     if(dense_X != null){
+         // Comparison of each dimension
+         for (int i = 0; i < expectedShape.length; i++) {
+             assertEquals("Dimension " + i + " mismatch", expectedShape[i], dense_X.getDim(i));
+         }
+         // Comparison of suffixes: getCumODims(i - 1) is checked against the
+         // product of expectedShape[i..last], built up from the last dimension.
+         int currentExpectedSuffix = expectedShape[expectedShape.length - 1];
+         for (int i = expectedShape.length - 1; i >= 1; i--) {
+             assertEquals("Suffix product at dim " + i + " mismatch", currentExpectedSuffix, dense_X.getCumODims(i - 1));
+             if(i > 1) {
+                 currentExpectedSuffix *= expectedShape[i - 1];
+             }
+         }
+     }
}
}