This is an automated email from the ASF dual-hosted git repository.
mboehm7 pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/systemds.git
The following commit(s) were added to refs/heads/main by this push:
new de7e9a0bc4 [SYSTEMDS-3547] Tensor permutation operations
de7e9a0bc4 is described below
commit de7e9a0bc46bcc8ad366373bb25d230c99aaaf20
Author: BeceHQ <[email protected]>
AuthorDate: Tue Mar 31 17:18:26 2026 +0200
[SYSTEMDS-3547] Tensor permutation operations
Closes #2426.
---
.../sysds/runtime/matrix/data/LibMatrixReorg.java | 335 ++++++++++++++++
.../matrix/libMatrixReorg/PermuteTest.java | 444 +++++++++++++++++++++
.../TransposeInPlaceBrennerTest.java | 6 +-
.../component/tensor/TransposeLinDataTest.java | 197 +++++++++
4 files changed, 979 insertions(+), 3 deletions(-)
diff --git
a/src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixReorg.java
b/src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixReorg.java
index f00e0015a8..040a4e1dcb 100644
--- a/src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixReorg.java
+++ b/src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixReorg.java
@@ -4972,4 +4972,339 @@ public class LibMatrixReorg {
if (in.getDenseBlock() != null)
in.getDenseBlock().setDims(finalShape);
}
+
+ /**
+  * Computes the strides of a tensor whose cells are linearized with the
+  * last dimension contiguous: the stride of a dimension is the product of
+  * the sizes of all dimensions to its right.
+  *
+  * @param dims sizes of the tensor dimensions
+  * @return one stride per dimension, in the same order as dims
+  */
+ private static long[] getStridesForPermutation(int[] dims) {
+     final int rank = dims.length;
+     final long[] result = new long[rank];
+     if( rank == 0 )
+         return result;
+     // innermost dimension moves one cell at a time; each outer stride is
+     // derived from its right-hand neighbour
+     result[rank - 1] = 1;
+     for( int axis = rank - 2; axis >= 0; axis-- )
+         result[axis] = result[axis + 1] * dims[axis + 1];
+     return result;
+ }
+
+ /**
+  * Permutes the dimensions of a dense tensor using a single thread.
+  * Delegates to the four-argument overload with k=1.
+  *
+  * @param in     block holding the linearized tensor cells
+  * @param inDims sizes of the input dimensions
+  * @param perm   permutation; output dimension i takes input dimension perm[i]
+  * @return block holding the permuted cells
+  */
+ public static MatrixBlock permute(MatrixBlock in, int[] inDims, int[] perm) {
+     return permute(in, inDims, perm, 1);
+ }
+
+ /**
+  * Permutes the dimensions of a dense tensor whose cells are linearized with
+  * the last dimension contiguous inside a matrix block. The identity
+  * permutation returns a block copy-constructed from the input; any other
+  * permutation returns a newly allocated dense 1 x N block, where N is the
+  * product of all dimension sizes.
+  *
+  * @param in     block holding the linearized tensor cells
+  * @param inDims sizes of the input dimensions
+  * @param perm   permutation; output dimension i takes input dimension perm[i]
+  * @param k      number of threads, or -1 for all available processors;
+  *               threads are only used once N reaches PAR_NUMCELL_THRESHOLD
+  * @return block holding the permuted cells
+  * @throws DMLRuntimeException if perm is not a permutation of 0..rank-1,
+  *         if the shape is invalid or N exceeds the int range, or if the
+  *         input holds non-zero values but has no dense block
+  */
+ public static MatrixBlock permute(MatrixBlock in, int[] inDims, int[] perm, int k) {
+     int rank = inDims.length;
+     if( perm.length != rank )
+         throw new DMLRuntimeException("Permutation length " + perm.length
+             + " does not match tensor rank " + rank + ".");
+
+     // validate that perm is a true permutation and detect the identity
+     boolean[] seen = new boolean[rank];
+     boolean isIdentity = true;
+     for( int i = 0; i < rank; i++ ) {
+         int p = perm[i];
+         if( p < 0 || p >= rank || seen[p] )
+             throw new DMLRuntimeException("Invalid permutation entry " + p
+                 + " at position " + i + " for tensor rank " + rank + ".");
+         seen[p] = true;
+         isIdentity &= (p == i);
+     }
+
+     if( isIdentity ) {
+         return new MatrixBlock(in);
+     }
+
+     int[] outDims = new int[rank];
+     for( int i = 0; i < rank; i++ ) {
+         outDims[i] = inDims[perm[i]];
+     }
+
+     // the output is allocated as 1 x length, so length must fit into an int;
+     // checking after every factor also keeps the long product from wrapping
+     long length = 1;
+     for( int d : outDims ) {
+         if( d < 0 )
+             throw new DMLRuntimeException("Invalid negative tensor dimension " + d + ".");
+         length *= d;
+         if( length > Integer.MAX_VALUE )
+             throw new DMLRuntimeException("Tensor too large for permutation: "
+                 + "number of cells exceeds " + Integer.MAX_VALUE + ".");
+     }
+
+     MatrixBlock out = new MatrixBlock(1, (int)length, false);
+     out.allocateDenseBlock();
+
+     DenseBlock inDB = in.getDenseBlock();
+     DenseBlock outDB = out.getDenseBlock();
+     if( inDB == null ) {
+         // no dense values to move: only acceptable if the input is all zeros
+         if( in.getNonZeros() > 0 )
+             throw new DMLRuntimeException("Tensor permutation requires a dense input block.");
+         return out;
+     }
+
+     long[] inStrides = getStridesForPermutation(inDims);
+     long[] outStrides = getStridesForPermutation(outDims);
+
+     // output coordinate j equals input coordinate perm[j]; hence a step along
+     // input dimension perm[j] advances the output position by outStrides[j]
+     long[] permutedStrides = new long[rank];
+     for( int j = 0; j < rank; j++ ) {
+         permutedStrides[perm[j]] = outStrides[j];
+     }
+
+     boolean useParallel = (k > 1 || k == -1) && length >= PAR_NUMCELL_THRESHOLD;
+     int numThreads = k == -1 ? Runtime.getRuntime().availableProcessors() : k;
+
+     if( inDB.numBlocks() == 1 && outDB.numBlocks() == 1 ) {
+         double[] inData = inDB.valuesAt(0);
+         double[] outData = outDB.valuesAt(0);
+
+         if( useParallel && rank > 0 ) {
+             permuteSingleBlockParallel(inData, outData, inDims, inStrides,
+                 permutedStrides, numThreads, length);
+         } else {
+             permuteSingleBlock(inData, outData, inDims, inStrides,
+                 permutedStrides, 0, 0, 0);
+         }
+     } else {
+         if( useParallel && rank > 0 ) {
+             permuteMultiBlockParallel(inDB, outDB, inDims, inStrides,
+                 permutedStrides, numThreads, length);
+         } else {
+             permuteMultiBlock(inDB, outDB, inDims, inStrides,
+                 permutedStrides, 0, 0L, 0L);
+         }
+     }
+
+     // a permutation only moves cells, so the number of non-zeros is unchanged
+     out.setNonZeros(in.getNonZeros());
+     return out;
+ }
+
+ /**
+  * Recursive single-threaded permutation kernel for inputs and outputs that
+  * each live in one contiguous array. Every recursion level walks one input
+  * dimension; the last dimension is copied as a run of cells.
+  *
+  * @param inData          linearized input cells
+  * @param outData         linearized output cells (written)
+  * @param inDims          sizes of the input dimensions
+  * @param inStrides       stride of each input dimension in inData
+  * @param permutedStrides stride in outData for a step along each input dimension
+  * @param dim             input dimension handled by this call
+  * @param inOffset        position in inData fixed by the outer dimensions
+  * @param outOffset       position in outData fixed by the outer dimensions
+  */
+ private static void permuteSingleBlock(
+     double[] inData, double[] outData,
+     int[] inDims, long[] inStrides, long[] permutedStrides,
+     int dim, int inOffset, int outOffset)
+ {
+     final int extent = inDims[dim];
+
+     if( dim != inDims.length - 1 ) {
+         // outer dimension: descend once per index, in ascending order
+         final long srcStep = inStrides[dim];
+         final long dstStep = permutedStrides[dim];
+         for( int pos = 0; pos < extent; pos++ ) {
+             permuteSingleBlock(inData, outData, inDims, inStrides, permutedStrides,
+                 dim + 1,
+                 inOffset + (int)(pos * srcStep),
+                 outOffset + (int)(pos * dstStep));
+         }
+         return;
+     }
+
+     // innermost dimension: bulk copy when the target run is contiguous too,
+     // otherwise hand the run to transposeRow together with the target stride
+     final int dstStride = (int) permutedStrides[dim];
+     if( dstStride != 1 )
+         transposeRow(inData, outData, inOffset, outOffset, dstStride, extent);
+     else
+         System.arraycopy(inData, inOffset, outData, outOffset, extent);
+ }
+
+ /**
+  * Multi-threaded permutation for single-array inputs and outputs. The
+  * linear input range [0, totalElements) is cut into consecutive chunks of at
+  * least 1024 cells, each processed by one PermuteSingleBlockTask.
+  *
+  * @param inData          linearized input cells
+  * @param outData         linearized output cells (written)
+  * @param inDims          sizes of the input dimensions
+  * @param inStrides       stride of each input dimension in inData
+  * @param permutedStrides stride in outData for a step along each input dimension
+  * @param k               maximum number of tasks/threads
+  * @param totalElements   number of cells to move
+  */
+ private static void permuteSingleBlockParallel(
+     double[] inData, double[] outData,
+     int[] inDims, long[] inStrides, long[] permutedStrides,
+     int k, long totalElements) {
+
+     final long chunk = Math.max(1024, (totalElements + k - 1) / k);
+     final int numTasks = (int) Math.min(k, (totalElements + chunk - 1) / chunk);
+
+     final ExecutorService pool = CommonThreadPool.get(numTasks);
+     try {
+         final ArrayList<PermuteSingleBlockTask> tasks = new ArrayList<>();
+         long lower = 0;
+         for( int t = 0; t < numTasks && lower < totalElements; t++, lower += chunk ) {
+             final long upper = Math.min(lower + chunk, totalElements);
+             tasks.add(new PermuteSingleBlockTask(inData, outData, inDims,
+                 inStrides, permutedStrides, lower, upper));
+         }
+
+         // wait for all chunks and surface the first task failure
+         for( Future<Object> pending : pool.invokeAll(tasks) )
+             pending.get();
+     }
+     catch (Exception ex) {
+         throw new DMLRuntimeException(ex);
+     }
+     finally {
+         pool.shutdown();
+     }
+ }
+
+ /**
+  * Recursive single-threaded permutation kernel for dense blocks that may be
+  * split into several arrays. Absolute cell positions are translated into a
+  * (block index, position in block) pair for every cell.
+  *
+  * @param inDB            input dense block
+  * @param outDB           output dense block (written)
+  * @param inDims          sizes of the input dimensions
+  * @param inStrides       stride of each input dimension in the linearized input
+  * @param permutedStrides stride in the linearized output for a step along each input dimension
+  * @param dim             input dimension handled by this call
+  * @param inOffset        absolute input position fixed by the outer dimensions
+  * @param outOffset       absolute output position fixed by the outer dimensions
+  */
+ private static void permuteMultiBlock(
+     DenseBlock inDB, DenseBlock outDB,
+     int[] inDims, long[] inStrides, long[] permutedStrides,
+     int dim, long inOffset, long outOffset) {
+
+     final int dimSize = inDims[dim];
+     final long inStep = inStrides[dim];
+     final long outStep = permutedStrides[dim];
+
+     if( dim < inDims.length - 1 ) {
+         // outer dimension: descend once per index
+         for( int i = 0; i < dimSize; i++ ) {
+             permuteMultiBlock(inDB, outDB, inDims, inStrides, permutedStrides,
+                 dim + 1, inOffset + i * inStep, outOffset + i * outStep);
+         }
+         return;
+     }
+
+     // Number of CELLS per block (all blocks but the last share the length of
+     // block 0). blockSize() must not be used here: it reports rows per block.
+     final long inBlockLen = inDB.size(0);
+     final long outBlockLen = outDB.size(0);
+
+     for( int i = 0; i < dimSize; i++ ) {
+         final long inAbs = inOffset + i * inStep;
+         final long outAbs = outOffset + i * outStep;
+
+         final double[] inArr = inDB.valuesAt((int) (inAbs / inBlockLen));
+         final double[] outArr = outDB.valuesAt((int) (outAbs / outBlockLen));
+
+         // no range guard: a position outside the block is a logic error and
+         // must fail loudly instead of silently dropping the cell
+         outArr[(int) (outAbs % outBlockLen)] = inArr[(int) (inAbs % inBlockLen)];
+     }
+ }
+
+ /**
+  * Multi-threaded permutation for dense blocks that may be split into
+  * several arrays. The linear input range [0, totalElements) is cut into
+  * consecutive chunks of at least 1024 cells, each processed by one
+  * PermuteMultiBlockTask.
+  *
+  * @param inDB            input dense block
+  * @param outDB           output dense block (written)
+  * @param inDims          sizes of the input dimensions
+  * @param inStrides       stride of each input dimension in the linearized input
+  * @param permutedStrides stride in the linearized output for a step along each input dimension
+  * @param k               maximum number of tasks/threads
+  * @param totalElements   number of cells to move
+  */
+ private static void permuteMultiBlockParallel(
+     DenseBlock inDB, DenseBlock outDB,
+     int[] inDims, long[] inStrides, long[] permutedStrides,
+     int k, long totalElements) {
+
+     final long span = Math.max(1024, (totalElements + k - 1) / k);
+     final int taskCount = (int) Math.min(k, (totalElements + span - 1) / span);
+
+     final ExecutorService pool = CommonThreadPool.get(taskCount);
+     try {
+         final ArrayList<PermuteMultiBlockTask> tasks = new ArrayList<>();
+         int t = 0;
+         long from = 0;
+         while( t < taskCount && from < totalElements ) {
+             final long to = Math.min(from + span, totalElements);
+             tasks.add(new PermuteMultiBlockTask(inDB, outDB, inDims,
+                 inStrides, permutedStrides, from, to));
+             from += span;
+             t++;
+         }
+
+         // block until every chunk is done; get() rethrows task failures
+         for( Future<Object> result : pool.invokeAll(tasks) )
+             result.get();
+     }
+     catch (Exception ex) {
+         throw new DMLRuntimeException(ex);
+     }
+     finally {
+         pool.shutdown();
+     }
+ }
+
+ /**
+  * Task that permutes the cells with linear input positions in [start, end)
+  * for inputs and outputs that each live in one contiguous array.
+  */
+ private static class PermuteSingleBlockTask implements Callable<Object> {
+     //TODO call single-threaded kernel for block
+
+     private final double[] inData;
+     private final double[] outData;
+     private final int[] inDims;
+     private final long[] inStrides;
+     private final long[] permutedStrides;
+     private final long start;
+     private final long end;
+
+     protected PermuteSingleBlockTask(double[] inData, double[] outData,
+         int[] inDims, long[] inStrides, long[] permutedStrides,
+         long start, long end) {
+         this.inData = inData;
+         this.outData = outData;
+         this.inDims = inDims;
+         this.inStrides = inStrides;
+         this.permutedStrides = permutedStrides;
+         this.start = start;
+         this.end = end;
+     }
+
+     /**
+      * Decomposes every linear position into per-dimension coordinates and
+      * recombines them with the output strides to find the target cell.
+      *
+      * @return always null
+      */
+     @Override
+     public Object call() {
+         final int rank = inDims.length;
+         long cell = start;
+         while( cell < end ) {
+             long src = 0;
+             long dst = 0;
+             long rest = cell;
+             for( int axis = 0; axis < rank; axis++ ) {
+                 final long stride = inStrides[axis];
+                 final long pos = rest / stride;
+                 // same value as rest % stride
+                 rest -= pos * stride;
+                 src += pos * stride;
+                 dst += pos * permutedStrides[axis];
+             }
+             outData[(int) dst] = inData[(int) src];
+             cell++;
+         }
+         return null;
+     }
+ }
+
+ /**
+  * Task that permutes the cells with linear input positions in [start, end)
+  * for dense blocks that may be split into several arrays.
+  */
+ private static class PermuteMultiBlockTask implements Callable<Object> {
+     //TODO call single-threaded kernel for block
+
+     private final DenseBlock inDB;
+     private final DenseBlock outDB;
+     private final int[] inDims;
+     private final long[] inStrides;
+     private final long[] permutedStrides;
+     private final long start;
+     private final long end;
+
+     protected PermuteMultiBlockTask(DenseBlock inDB, DenseBlock outDB,
+         int[] inDims, long[] inStrides, long[] permutedStrides,
+         long start, long end) {
+         this.inDB = inDB;
+         this.outDB = outDB;
+         this.inDims = inDims;
+         this.inStrides = inStrides;
+         this.permutedStrides = permutedStrides;
+         this.start = start;
+         this.end = end;
+     }
+
+     /**
+      * Decomposes every linear position into per-dimension coordinates,
+      * recombines them with the output strides, and translates both absolute
+      * positions into (block index, position in block) pairs.
+      *
+      * @return always null
+      */
+     @Override
+     public Object call() {
+         // Number of CELLS per block (all blocks but the last share the length
+         // of block 0). blockSize() must not be used here: it reports rows.
+         final long inBlockLen = inDB.size(0);
+         final long outBlockLen = outDB.size(0);
+
+         for( long idx = start; idx < end; idx++ ) {
+             long inIdx = 0;
+             long outIdx = 0;
+             long remaining = idx;
+
+             for( int d = 0; d < inDims.length; d++ ) {
+                 long coord = remaining / inStrides[d];
+                 remaining = remaining % inStrides[d];
+                 inIdx += coord * inStrides[d];
+                 outIdx += coord * permutedStrides[d];
+             }
+
+             final double[] inArr = inDB.valuesAt((int) (inIdx / inBlockLen));
+             final double[] outArr = outDB.valuesAt((int) (outIdx / outBlockLen));
+
+             // no range guard: a position outside the block is a logic error
+             // and must fail loudly instead of silently dropping the cell
+             outArr[(int) (outIdx % outBlockLen)] = inArr[(int) (inIdx % inBlockLen)];
+         }
+         return null;
+     }
+ }
}
diff --git
a/src/test/java/org/apache/sysds/test/component/matrix/libMatrixReorg/PermuteTest.java
b/src/test/java/org/apache/sysds/test/component/matrix/libMatrixReorg/PermuteTest.java
new file mode 100644
index 0000000000..60810b9228
--- /dev/null
+++
b/src/test/java/org/apache/sysds/test/component/matrix/libMatrixReorg/PermuteTest.java
@@ -0,0 +1,444 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.test.component.matrix.libMatrixReorg;
+
+import org.junit.Assert;
+import org.junit.Ignore;
+import org.junit.Test;
+import org.apache.sysds.runtime.matrix.data.MatrixBlock;
+import org.apache.sysds.runtime.matrix.data.LibMatrixReorg;
+import org.apache.sysds.runtime.data.DenseBlock;
+import org.mockito.Mockito;
+import java.util.Arrays;
+
+public class PermuteTest {
+
+ @Test
+ public void testBasicPermute() {
+ int[] shape = {2, 3, 4};
+ MatrixBlock tensor = generateMatrixBlock(shape);
+
+ Assert.assertEquals(24, tensor.getNumRows() *
tensor.getNumColumns());
+
+ double[] data = tensor.getDenseBlockValues();
+ Assert.assertEquals(23.0, data[1 * 4 * 3 + 2 * 4 + 3], 0.001);
+ Assert.assertEquals(0.0, data[0 * 4 * 3 + 0 * 4 + 0], 0.001);
+
+ int[] permutation = {1, 0, 2};
+ MatrixBlock outTensor = LibMatrixReorg.permute(tensor, shape,
permutation);
+
+ double[] outData = outTensor.getDenseBlockValues();
+ Assert.assertEquals(24, outData.length);
+ Assert.assertEquals(4.0, outData[8], 0.001);
+ Assert.assertEquals(15.0, outData[7], 0.001);
+ }
+
+ @Test
+ public void testPermute2DTranspose() {
+ int[] shape = {10, 5};
+ int[] perm = {1, 0};
+
+ MatrixBlock in = generateMatrixBlock(shape);
+ MatrixBlock out = LibMatrixReorg.permute(in, shape, perm);
+
+ verifyPermutation(in, out, shape, perm);
+ }
+
+ @Test
+ public void testPermute3DSimple() {
+ int[] shape = {2, 3, 4};
+ int[] perm = {1, 0, 2};
+
+ MatrixBlock in = generateMatrixBlock(shape);
+ MatrixBlock out = LibMatrixReorg.permute(in, shape, perm);
+
+ verifyPermutation(in, out, shape, perm);
+ }
+
+ @Test
+ public void testPermute3DIdentity() {
+ int[] shape = {5, 5, 5};
+ int[] perm = {0, 1, 2};
+
+ MatrixBlock in = generateMatrixBlock(shape);
+ MatrixBlock out = LibMatrixReorg.permute(in, shape, perm);
+
+ verifyPermutation(in, out, shape, perm);
+ }
+
+ @Test
+ public void testPermute4DReverse() {
+ int[] shape = {2, 3, 4, 5};
+ int[] perm = {3, 2, 1, 0};
+
+ MatrixBlock in = generateMatrixBlock(shape);
+ MatrixBlock out = LibMatrixReorg.permute(in, shape, perm);
+
+ verifyPermutation(in, out, shape, perm);
+ }
+
+ @Test
+ public void testPermuteHighRank() {
+ int[] shape = {2, 2, 2, 2, 2, 2};
+ int[] perm = {5, 0, 4, 1, 3, 2};
+
+ MatrixBlock in = generateMatrixBlock(shape);
+ MatrixBlock out = LibMatrixReorg.permute(in, shape, perm);
+
+ verifyPermutation(in, out, shape, perm);
+ }
+
+ @Test
+ public void testLargeBlockLogicMocked() {
+ int[] shape = {10, 10, 10};
+ int[] perm = {2, 0, 1};
+
+ MatrixBlock in = generateMatrixBlock(shape);
+ DenseBlock originalDB = in.getDenseBlock();
+ DenseBlock spyDB = Mockito.spy(originalDB);
+ Mockito.when(spyDB.numBlocks()).thenReturn(2);
+ in.setDenseBlock(spyDB);
+
+ MatrixBlock out = LibMatrixReorg.permute(in, shape, perm);
+
+ MatrixBlock originalIn = generateMatrixBlock(shape);
+ verifyPermutation(originalIn, out, shape, perm);
+ }
+
+ @Test
+ public void testLargeBlockLogicMockedInputAndOutput() {
+ int[] shape = {4, 4, 4};
+ int[] perm = {2, 1, 0};
+
+ MatrixBlock in = generateMatrixBlock(shape);
+ DenseBlock spyIn = Mockito.spy(in.getDenseBlock());
+ Mockito.when(spyIn.numBlocks()).thenReturn(5);
+ in.setDenseBlock(spyIn);
+
+ MatrixBlock out = LibMatrixReorg.permute(in, shape, perm);
+
+ MatrixBlock originalIn = generateMatrixBlock(shape);
+ verifyPermutation(originalIn, out, shape, perm);
+ }
+
+ @Test
+ public void testPermute3DParallel() {
+ int[] shape = {100, 100, 100};
+ int[] perm = {2, 0, 1};
+
+ MatrixBlock in = generateMatrixBlock(shape);
+ MatrixBlock out = LibMatrixReorg.permute(in, shape, perm, -1);
+
+ verifyPermutation(in, out, shape, perm);
+ }
+
+ @Test
+ @Ignore
+ public void testPerformanceSingleVsMultiThreaded() {
+ int size = 100;
+ int[] shape = {size, size, size};
+ int[] perm = {2, 0, 1};
+
+ MatrixBlock in = generateMatrixBlock(shape);
+
+ long startSingle = System.nanoTime();
+ MatrixBlock outSingle = LibMatrixReorg.permute(in, shape, perm,
1);
+ long timeSingle = System.nanoTime() - startSingle;
+
+ long startMulti = System.nanoTime();
+ MatrixBlock outMulti = LibMatrixReorg.permute(in, shape, perm,
-1);
+ long timeMulti = System.nanoTime() - startMulti;
+
+ verifyPermutation(in, outSingle, shape, perm);
+ verifyPermutation(in, outMulti, shape, perm);
+
+ System.out.println("Large Matrix (" + size + "x" + size + "x" +
size + "):");
+ System.out.println("Single-threaded: " + timeSingle / 1_000_000
+ " ms");
+ System.out.println("Multi-threaded: " + timeMulti / 1_000_000 +
" ms");
+ System.out.println("Speedup: " + String.format("%.2fx",
(double)timeSingle / timeMulti));
+
+ Assert.assertTrue("Multi-threaded should be faster for large
matrices", timeMulti < timeSingle);
+ }
+
+ /**
+  * Manual benchmark (ignored by default): times a single-threaded against a
+  * multi-threaded permutation of a 1 x 10000 x 10000 tensor and prints both.
+  */
+ @SuppressWarnings("unused")
+ @Test
+ @Ignore
+ public void testPerformanceLargeMatrixSingleVsMulti() {
+     int[] shape = {1, 10000, 10000};
+     int[] perm = {0, 2, 1};
+
+     MatrixBlock in = generateMatrixBlock(shape);
+
+     long startSingle = System.nanoTime();
+     MatrixBlock outSingle = LibMatrixReorg.permute(in, shape, perm, 1);
+     long timeSingle = System.nanoTime() - startSingle;
+
+     long startMulti = System.nanoTime();
+     MatrixBlock outMulti = LibMatrixReorg.permute(in, shape, perm, -1);
+     long timeMulti = System.nanoTime() - startMulti;
+
+     // derive the label from the shape so it cannot drift from the data
+     System.out.println("Large Matrix (" + shape[0] + "x" + shape[1] + "x" + shape[2] + "):");
+     System.out.println("Single-threaded: " + timeSingle / 1_000_000 + " ms");
+     System.out.println("Multi-threaded: " + timeMulti / 1_000_000 + " ms");
+     System.out.println("Speedup: " + String.format("%.2fx", (double)timeSingle / timeMulti));
+
+     Assert.assertTrue("Multi-threaded should be faster for large matrices", timeMulti < timeSingle);
+ }
+
+ @SuppressWarnings("unused")
+ @Test
+ @Ignore
+ public void testPerformancePermuteVsNativeTranspose() {
+ int size = 1000;
+ MatrixBlock in = new MatrixBlock(size, size, false);
+ in.allocateDenseBlock();
+ double[] data = in.getDenseBlockValues();
+ for (int i = 0; i < size; i++) {
+ for (int j = 0; j < size; j++) {
+ data[i * size + j] = i * size + j;
+ }
+ }
+
+ int[] shape = {size, size};
+ int[] perm = {1, 0};
+
+ long startPermute = System.nanoTime();
+ MatrixBlock outPermute = LibMatrixReorg.permute(in, shape,
perm, -1);
+ long timePermute = System.nanoTime() - startPermute;
+
+ long startTranspose = System.nanoTime();
+ MatrixBlock outTranspose = LibMatrixReorg.transpose(in);
+ long timeTranspose = System.nanoTime() - startTranspose;
+
+ System.out.println("Transpose Performance (" + size + "x" +
size + "):");
+ System.out.println("Permute function: " + timePermute /
1_000_000 + " ms");
+ System.out.println("Native transpose: " + timeTranspose /
1_000_000 + " ms");
+ System.out.println("Ratio: " + String.format("%.2fx",
(double)timePermute / timeTranspose));
+
+ double[] permuteData = outPermute.getDenseBlockValues();
+
+ for (int i = 0; i < size; i++) {
+ for (int j = 0; j < size; j++) {
+ double expected = in.get(j, i);
+ double actual = permuteData[i * size + j];
+ Assert.assertEquals("Mismatch at (" + i + "," +
j + ")", expected, actual, 0.0001);
+ }
+ }
+ }
+
+ @Test
+ public void testEdgeCaseSingleElement() {
+ int[] shape = {1, 1, 1};
+ int[] perm = {2, 1, 0};
+
+ MatrixBlock in = generateMatrixBlock(shape);
+ MatrixBlock out = LibMatrixReorg.permute(in, shape, perm);
+
+ verifyPermutation(in, out, shape, perm);
+ }
+
+ @Test
+ public void testEdgeCaseOneDimensionOne() {
+ int[] shape = {5, 1, 10};
+ int[] perm = {2, 0, 1};
+
+ MatrixBlock in = generateMatrixBlock(shape);
+ MatrixBlock out = LibMatrixReorg.permute(in, shape, perm);
+
+ verifyPermutation(in, out, shape, perm);
+ }
+
+ @Test
+ public void testEdgeCaseTwoDimensionsOne() {
+ int[] shape = {1, 1, 100};
+ int[] perm = {2, 1, 0};
+
+ MatrixBlock in = generateMatrixBlock(shape);
+ MatrixBlock out = LibMatrixReorg.permute(in, shape, perm);
+
+ verifyPermutation(in, out, shape, perm);
+ }
+
+ /**
+  * Applies two permutations back to back and verifies each step against its
+  * own input, so a defect in either step is reported.
+  */
+ @Test
+ public void testConsecutivePermutations() {
+     int[] shape = {3, 4, 5};
+     int[] perm1 = {1, 0, 2};
+     int[] perm2 = {2, 0, 1};
+
+     MatrixBlock in = generateMatrixBlock(shape);
+     MatrixBlock temp = LibMatrixReorg.permute(in, shape, perm1);
+     verifyPermutation(in, temp, shape, perm1);
+
+     // shape of the intermediate tensor is the input of the second step
+     int[] tempShape = {shape[perm1[0]], shape[perm1[1]], shape[perm1[2]]};
+     MatrixBlock out = LibMatrixReorg.permute(temp, tempShape, perm2);
+     verifyPermutation(temp, out, tempShape, perm2);
+ }
+
+ @Test
+ public void testDifferentThreadCounts() {
+ int[] shape = {50, 50, 50};
+ int[] perm = {2, 0, 1};
+
+ MatrixBlock in = generateMatrixBlock(shape);
+
+ MatrixBlock out1 = LibMatrixReorg.permute(in, shape, perm, 1);
+ MatrixBlock out2 = LibMatrixReorg.permute(in, shape, perm, 2);
+ MatrixBlock out4 = LibMatrixReorg.permute(in, shape, perm, 4);
+ MatrixBlock out8 = LibMatrixReorg.permute(in, shape, perm, 8);
+
+ double[] data1 = out1.getDenseBlockValues();
+ double[] data2 = out2.getDenseBlockValues();
+ double[] data4 = out4.getDenseBlockValues();
+ double[] data8 = out8.getDenseBlockValues();
+
+ for (int i = 0; i < data1.length; i++) {
+ Assert.assertEquals(data1[i], data2[i], 0.0001);
+ Assert.assertEquals(data1[i], data4[i], 0.0001);
+ Assert.assertEquals(data1[i], data8[i], 0.0001);
+ }
+ }
+
+ @Test
+ public void testPermuteAllDimensionsCyclic() {
+ int[] shape = {3, 4, 5, 2};
+ int[] perm = {1, 2, 3, 0};
+
+ MatrixBlock in = generateMatrixBlock(shape);
+ MatrixBlock out = LibMatrixReorg.permute(in, shape, perm);
+
+ verifyPermutation(in, out, shape, perm);
+ }
+
+ @Test
+ public void testPermuteNonContiguousStrides() {
+ int[] shape = {7, 11, 13};
+ int[] perm = {2, 0, 1};
+
+ MatrixBlock in = generateMatrixBlock(shape);
+ MatrixBlock out = LibMatrixReorg.permute(in, shape, perm);
+
+ verifyPermutation(in, out, shape, perm);
+ }
+
+ @Test
+ public void testPermuteLargePrimeStrides() {
+ int[] shape = {17, 19};
+ int[] perm = {1, 0};
+
+ MatrixBlock in = generateMatrixBlock(shape);
+ MatrixBlock out = LibMatrixReorg.permute(in, shape, perm);
+
+ verifyPermutation(in, out, shape, perm);
+ }
+
+ /**
+  * Creates a dense 1 x N block (N = product of the shape) whose cell at
+  * linear position i holds the value i.
+  *
+  * @param shape sizes of the tensor dimensions
+  * @return the filled block
+  */
+ private MatrixBlock generateMatrixBlock(int[] shape) {
+     long len = 1;
+     for (int d : shape) len *= d;
+
+     MatrixBlock mb = new MatrixBlock(1, (int)len, false);
+     mb.allocateDenseBlock();
+     double[] data = mb.getDenseBlockValues();
+     for (int i = 0; i < data.length; i++) {
+         data[i] = (double) i;
+     }
+     // values were written straight into the array, so the block's non-zero
+     // count has to be brought back in sync explicitly
+     mb.recomputeNonZeros();
+     return mb;
+ }
+
+ /**
+  * Checks every output cell against the input cell it must originate from:
+  * output coordinate d equals input coordinate perm[d].
+  *
+  * @param in      input block of the permutation
+  * @param out     result block to verify
+  * @param inShape sizes of the input dimensions
+  * @param perm    permutation that was applied
+  */
+ private void verifyPermutation(MatrixBlock in, MatrixBlock out, int[] inShape, int[] perm) {
+     double[] inData = linearize(in);
+     double[] outData = linearize(out);
+
+     int rank = inShape.length;
+     int[] outShape = new int[rank];
+     for (int i = 0; i < rank; i++)
+         outShape[i] = inShape[perm[i]];
+
+     long[] outStrides = getStrides(outShape);
+     long[] inStrides = getStrides(inShape);
+
+     long len = 1;
+     for (int d : outShape) len *= d;
+
+     Assert.assertEquals("Unexpected number of output cells", len, outData.length);
+
+     for (long i = 0; i < len; i++) {
+         // linear output position -> output coordinates
+         int[] outCoords = new int[rank];
+         long temp = i;
+         for (int d = 0; d < rank; d++) {
+             outCoords[d] = (int)(temp / outStrides[d]);
+             temp = temp % outStrides[d];
+         }
+
+         // output coordinates -> input coordinates -> linear input position
+         int[] inCoords = new int[rank];
+         for (int d = 0; d < rank; d++) {
+             inCoords[perm[d]] = outCoords[d];
+         }
+
+         long inIndex = 0;
+         for (int d = 0; d < rank; d++) {
+             inIndex += inCoords[d] * inStrides[d];
+         }
+
+         double expectedValue = inData[(int)inIndex];
+         double actualValue = outData[(int)i];
+
+         if (Math.abs(expectedValue - actualValue) > 0.0001) {
+             Assert.fail("Mismatch at linear output index " + i +
+                 ". Output coords " + Arrays.toString(outCoords) +
+                 ". Input coords " + Arrays.toString(inCoords) +
+                 ". Expected " + expectedValue + " but got " + actualValue);
+         }
+     }
+ }
+
+ /**
+  * Copies all cells of a block into one array, block after block. A block
+  * without dense data yields an all-zero array.
+  */
+ private static double[] linearize(MatrixBlock mb) {
+     double[] ret = new double[(int)((long)mb.getNumRows() * mb.getNumColumns())];
+     DenseBlock db = mb.getDenseBlock();
+     if (db == null)
+         return ret;
+
+     // size(b) is the number of cells in block b; blockSize() would only be
+     // the number of rows and must not be used as a copy length
+     int offset = 0;
+     for (int b = 0; b < db.numBlocks() && offset < ret.length; b++) {
+         int count = Math.min(db.size(b), ret.length - offset);
+         System.arraycopy(db.valuesAt(b), 0, ret, offset, count);
+         offset += count;
+     }
+     return ret;
+ }
+
+ /**
+  * Computes strides for a linearization with the last dimension contiguous.
+  *
+  * @param dims sizes of the tensor dimensions
+  * @return one stride per dimension
+  */
+ private long[] getStrides(int[] dims) {
+     long[] strides = new long[dims.length];
+     long stride = 1;
+     for (int i = dims.length - 1; i >= 0; i--) {
+         strides[i] = stride;
+         stride *= dims[i];
+     }
+     return strides;
+ }
+}
\ No newline at end of file
diff --git
a/src/test/java/org/apache/sysds/test/component/matrix/libMatrixReorg/TransposeInPlaceBrennerTest.java
b/src/test/java/org/apache/sysds/test/component/matrix/libMatrixReorg/TransposeInPlaceBrennerTest.java
index 50e79279b6..45457eae6e 100644
---
a/src/test/java/org/apache/sysds/test/component/matrix/libMatrixReorg/TransposeInPlaceBrennerTest.java
+++
b/src/test/java/org/apache/sysds/test/component/matrix/libMatrixReorg/TransposeInPlaceBrennerTest.java
@@ -301,7 +301,7 @@ public class TransposeInPlaceBrennerTest {
}
@Test
- public void testTensorPermute_SingleElement() {
+ public void testTensorPermuteSingleElement() {
int[] shape = {1, 1, 1};
int[] perm = {2, 1, 0};
testTransposeInPlaceTensor(shape, perm);
@@ -403,7 +403,7 @@ public class TransposeInPlaceBrennerTest {
// 9. Invalid permutations (negative tests)
// NOTE: more detailed error handling can be added in the future,
currently these are just checking for exceptions
@Test
- public void testTensorPermute_InvalidPerm_OutOfRange() {
+ public void testTensorPermuteInvalidPermOutOfRange() {
int[] shape = {2, 3, 4};
int[] perm = {0, 1, 3}; // 3 is out of range for 3D tensor
@@ -647,7 +647,7 @@ public class TransposeInPlaceBrennerTest {
}
@Test
- public void testTensorPermuteSplitShape_2D() {
+ public void testTensorPermuteSplitShape2D() {
int[] shape = {2,3};
int[] perm = {1,0};
diff --git
a/src/test/java/org/apache/sysds/test/component/tensor/TransposeLinDataTest.java
b/src/test/java/org/apache/sysds/test/component/tensor/TransposeLinDataTest.java
new file mode 100644
index 0000000000..d7e13a8b56
--- /dev/null
+++
b/src/test/java/org/apache/sysds/test/component/tensor/TransposeLinDataTest.java
@@ -0,0 +1,197 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.test.component.tensor;
+
+import org.junit.Assert;
+import org.junit.Test;
+
+import org.apache.sysds.common.Types.ValueType;
+import org.apache.sysds.runtime.data.TensorBlock;
+ import java.util.Arrays;
+
+public class TransposeLinDataTest {
+
+ @Test
+ public void Testrightelem(){
+ int[] shape = {2, 3, 4};
+ TensorBlock tensor = TensorUtils.createArangeTensor(shape);
+
+ Assert.assertArrayEquals(new int[]{2, 3, 4}, tensor.getDims());
+ Assert.assertEquals(0.0, tensor.get(new int[]{0, 0, 0}));
+ Assert.assertEquals(23.0, tensor.get(new int[]{1, 2, 3}));
+ Assert.assertEquals(6.0, tensor.get(new int[]{0, 1, 2}));
+ Assert.assertEquals(12.0, tensor.get(new int[]{1, 0, 0}));
+ printTensor(tensor);
+
+
+ int[] permutation = {1, 0, 2};
+ TensorBlock outTensor = PermuteIt.permute(tensor, permutation);
+ printTensor(outTensor);
+
+ Assert.assertArrayEquals(new int[]{3, 2, 4}, outTensor.getDims());
+ Assert.assertEquals(0.0, outTensor.get(new int[]{0,0,0}));
+ Assert.assertEquals(23.0, outTensor.get(new int[]{2, 1, 3}));
+ Assert.assertEquals(12.0, outTensor.get(new int[]{0, 1, 0}));
+ Assert.assertEquals(17.0, outTensor.get(new int[]{1, 1, 1}));
+
+
+ int[] second_permutation = {2, 1, 0};
+ TensorBlock perm2Block = PermuteIt.permute(tensor,
second_permutation);
+ printTensor(perm2Block);
+
+ Assert.assertArrayEquals(new int[]{4, 3, 2}, perm2Block.getDims());
+ Assert.assertEquals(0.0, perm2Block.get(new int[]{0, 0, 0}));
+ Assert.assertEquals(12.0, perm2Block.get(new int[]{0, 0, 1}));
+ Assert.assertEquals(11.0, perm2Block.get(new int[]{3, 2, 0}));
+ Assert.assertEquals(23.0, perm2Block.get(new int[]{3, 2, 1}));
+
+ }
+
+
+
+
+ /**
+  * Test fixture helpers. Declared as a static nested class: it only holds
+  * static methods and needs no reference to an enclosing test instance.
+  */
+ public static class TensorUtils {
+
+     /**
+      * Creates an FP64 tensor of the given shape whose cells hold 0, 1, 2, ...
+      * in the order the indices are enumerated with the last dimension
+      * varying fastest.
+      *
+      * @param shape sizes of the tensor dimensions
+      * @return the filled tensor
+      */
+     public static TensorBlock createArangeTensor(int[] shape) {
+         TensorBlock tb = new TensorBlock(ValueType.FP64, shape);
+         tb.allocateBlock();
+         // single-element array acts as a mutable counter shared by the recursion
+         double[] counter = { 0.0 };
+         int[] currentIndices = new int[shape.length];
+
+         fillRecursively(tb, shape, 0, currentIndices, counter);
+
+         return tb;
+     }
+
+     /** Enumerates all index tuples from dimension dim onwards and writes the running counter. */
+     private static void fillRecursively(TensorBlock tb, int[] shape, int dim, int[] currentIndices, double[] counter) {
+         if (dim == shape.length) {
+             tb.set(currentIndices, counter[0]);
+             counter[0]++;
+             return;
+         }
+
+         for (int i = 0; i < shape[dim]; i++) {
+             currentIndices[dim] = i;
+             fillRecursively(tb, shape, dim + 1, currentIndices, counter);
+         }
+     }
+ }
+
+
+
+ /**
+  * Straightforward element-wise reference implementation of a tensor
+  * permutation. Declared as a static nested class: it only holds static
+  * methods and needs no reference to an enclosing test instance.
+  */
+ public static class PermuteIt {
+
+     /**
+      * Returns a new tensor whose dimension i is dimension permute_dims[i]
+      * of the input.
+      *
+      * @param tensor       input tensor
+      * @param permute_dims permutation of the dimension indices
+      * @return the permuted tensor
+      * @throws IllegalArgumentException if the permutation length differs from the tensor rank
+      */
+     public static TensorBlock permute(TensorBlock tensor, int[] permute_dims) {
+         int numDims = tensor.getNumDims();
+         if (permute_dims.length != numDims) {
+             throw new IllegalArgumentException("Permutation length " + permute_dims.length
+                 + " does not match tensor rank " + numDims);
+         }
+         int[] dims = tensor.getDims();
+         ValueType tensorType = tensor.getValueType();
+
+         int[] outShape = new int[numDims];
+         for (int idx = 0; idx < numDims; idx++) {
+             outShape[idx] = dims[permute_dims[idx]];
+         }
+
+         TensorBlock outTensor = new TensorBlock(tensorType, outShape);
+         outTensor.allocateBlock();
+
+         // index buffers reused across the whole recursion
+         int[] inIndex = new int[numDims];
+         int[] outIndex = new int[numDims];
+
+         rekursion(tensor, outTensor, permute_dims, dims, 0, inIndex, outIndex);
+         return outTensor;
+     }
+
+     /**
+      * Recursive worker: enumerates every input index tuple from dimension
+      * dim onwards and copies the cell to its permuted output position.
+      *
+      * @param inTensor    input tensor
+      * @param outTensor   output tensor (written)
+      * @param permutation permutation of the dimension indices
+      * @param inShape     sizes of the input dimensions
+      * @param dim         dimension enumerated by this call
+      * @param inIndex     current input index tuple (scratch buffer)
+      * @param outIndex    current output index tuple (scratch buffer)
+      */
+     public static void rekursion(TensorBlock inTensor,
+         TensorBlock outTensor,
+         int[] permutation,
+         int[] inShape,
+         int dim,
+         int[] inIndex,
+         int[] outIndex)
+     {
+         if (dim == inShape.length) {
+             // all coordinates fixed: output coordinate idx is input coordinate permutation[idx]
+             for (int idx = 0; idx < permutation.length; idx++) {
+                 outIndex[idx] = inIndex[permutation[idx]];
+             }
+             double val = (double) inTensor.get(inIndex);
+             outTensor.set(outIndex, val);
+             return;
+         }
+
+         for (int idx = 0; idx < inShape[dim]; idx++) {
+             inIndex[dim] = idx;
+             rekursion(inTensor, outTensor, permutation, inShape, dim + 1, inIndex, outIndex);
+         }
+     }
+ }
+
+
+ public static void printTensor(TensorBlock tb) {
+ StringBuilder sb = new StringBuilder();
+ int[] shape = tb.getDims();
+ int[] currentIndices = new int[shape.length];
+
+ sb.append("Tensor(").append(Arrays.toString(shape)).append("):\n");
+ printRecursive(tb, shape, 0, currentIndices, sb, 0);
+
+ System.out.println(sb.toString());
+ }
+
+ private static void printRecursive(TensorBlock tb, int[] shape, int dim,
int[] indices, StringBuilder sb, int indent) {
+ for (int k = 0; k < indent; k++) sb.append(" ");
+
+ sb.append("[");
+
+ if (dim == shape.length - 1) {
+ for (int i = 0; i < shape[dim]; i++) {
+ indices[dim] = i;
+ double val = (double) tb.get(indices);
+ sb.append(String.format("%.1f", val));
+ if (i < shape[dim] - 1) sb.append(", ");
+ }
+ sb.append("]");
+ }
+
+ else {
+ sb.append("\n");
+ for (int i = 0; i < shape[dim]; i++) {
+ indices[dim] = i;
+ printRecursive(tb, shape, dim + 1, indices, sb, indent + 2);
+
+ if (i < shape[dim] - 1) {
+ sb.append(",");
+ sb.append("\n");
+ if (shape.length - dim > 2) sb.append("\n");
+ }
+ }
+ sb.append("\n");
+ for (int k = 0; k < indent; k++) sb.append(" ");
+ sb.append("]");
+ }
+ }
+
+}
\ No newline at end of file