This is an automated email from the ASF dual-hosted git repository.

mboehm7 pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/systemds.git


The following commit(s) were added to refs/heads/main by this push:
     new de7e9a0bc4 [SYSTEMDS-3547] Tensor permutation operations
de7e9a0bc4 is described below

commit de7e9a0bc46bcc8ad366373bb25d230c99aaaf20
Author: BeceHQ <[email protected]>
AuthorDate: Tue Mar 31 17:18:26 2026 +0200

    [SYSTEMDS-3547] Tensor permutation operations
    
    Closes #2426.
---
 .../sysds/runtime/matrix/data/LibMatrixReorg.java  | 335 ++++++++++++++++
 .../matrix/libMatrixReorg/PermuteTest.java         | 444 +++++++++++++++++++++
 .../TransposeInPlaceBrennerTest.java               |   6 +-
 .../component/tensor/TransposeLinDataTest.java     | 197 +++++++++
 4 files changed, 979 insertions(+), 3 deletions(-)

diff --git 
a/src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixReorg.java 
b/src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixReorg.java
index f00e0015a8..040a4e1dcb 100644
--- a/src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixReorg.java
+++ b/src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixReorg.java
@@ -4972,4 +4972,339 @@ public class LibMatrixReorg {
                if (in.getDenseBlock() != null)
                        in.getDenseBlock().setDims(finalShape);
        }
+               
+       private static long[] getStridesForPermutation(int[] dims) {
+               long[] strides = new long[dims.length];
+               long stride = 1;
+               for( int i = dims.length - 1; i >= 0; i-- ) {
+                       strides[i] = stride;
+                       stride *= dims[i];
+               }
+               return strides;
+       }
+
+       /**
+        * Single-threaded convenience overload of
+        * {@link #permute(MatrixBlock, int[], int[], int)}.
+        *
+        * @param in     input block holding the tensor cells
+        * @param inDims shape of the input tensor
+        * @param perm   axis permutation
+        * @return block holding the permuted tensor
+        */
+       public static MatrixBlock permute(MatrixBlock in, int[] inDims, int[] perm) {
+               final int singleThreaded = 1;
+               return permute(in, inDims, perm, singleThreaded);
+       }
+
+       public static MatrixBlock permute(MatrixBlock in, int[] inDims, int[] 
perm, int k) {
+               int rank = inDims.length;
+               
+               boolean isIdentity = true;
+               for( int i = 0; i < rank; i++ ) {
+                       if( perm[i] != i ) {
+                               isIdentity = false;
+                               break;
+                       }
+               }
+
+               if( isIdentity ) {
+                       return new MatrixBlock(in);
+               }
+
+               int[] outDims = new int[rank];
+               for( int i = 0; i < rank; i++ ) { 
+                       outDims[i] = inDims[perm[i]];
+               }
+
+               long length = 1;
+               for( int d : outDims ) {
+                       length *= d;
+               }
+
+               MatrixBlock out = new MatrixBlock(1, (int)length, false);
+               out.allocateDenseBlock();
+
+               DenseBlock inDB = in.getDenseBlock();
+               DenseBlock outDB = out.getDenseBlock();
+
+               long[] inStrides = getStridesForPermutation(inDims);
+               long[] outStrides = getStridesForPermutation(outDims);
+               
+               long[] permutedStrides = new long[rank];
+               for( int i = 0; i < rank; i++ ) {
+                       permutedStrides[i] = outStrides[perm[i]];
+               }
+
+               boolean useParallel = (k > 1 || k == -1) && length >= 
PAR_NUMCELL_THRESHOLD;
+               int numThreads = k == -1 ? 
Runtime.getRuntime().availableProcessors() : k;
+
+               if( inDB.numBlocks() == 1 && outDB.numBlocks() == 1 ) {
+                       double[] inData = inDB.valuesAt(0);
+                       double[] outData = outDB.valuesAt(0);
+                       
+                       if( useParallel && rank > 0 ) {
+                               permuteSingleBlockParallel(inData, outData, 
inDims, inStrides, 
+                                       permutedStrides, numThreads, length);
+                       } else {
+                               permuteSingleBlock(inData, outData, inDims, 
inStrides, 
+                                       permutedStrides, 0, 0, 0);
+                       }
+               } else {
+                       if( useParallel && rank > 0 ) {
+                               permuteMultiBlockParallel(inDB, outDB, inDims, 
inStrides, 
+                                       permutedStrides, numThreads, length);
+                       } else {
+                               permuteMultiBlock(inDB, outDB, inDims, 
inStrides, 
+                                       permutedStrides, 0, 0L, 0L);
+                       }
+               }
+               return out;
+       }
+
+       /**
+        * Recursive single-array kernel: walks the input axes from {@code dim}
+        * inwards and copies every cell to its permuted output position.
+        *
+        * @param inData          input cells
+        * @param outData         output cells
+        * @param inDims          shape of the input tensor
+        * @param inStrides       input stride per input axis
+        * @param permutedStrides output stride per input axis
+        * @param dim             input axis handled by this call
+        * @param inOffset        input offset accumulated over the outer axes
+        * @param outOffset       output offset accumulated over the outer axes
+        */
+       private static void permuteSingleBlock(
+               double[] inData, double[] outData,
+               int[] inDims, long[] inStrides, long[] permutedStrides,
+               int dim, int inOffset, int outOffset)
+       {
+               final int extent = inDims[dim];
+
+               if( dim != inDims.length - 1 ) {
+                       // outer axis: descend once per index, in ascending index order
+                       final long srcStep = inStrides[dim];
+                       final long dstStep = permutedStrides[dim];
+                       for( int i = 0; i < extent; i++ ) {
+                               permuteSingleBlock(inData, outData, inDims, inStrides, permutedStrides,
+                                       dim + 1,
+                                       inOffset + (int)(i * srcStep),
+                                       outOffset + (int)(i * dstStep));
+                       }
+                       return;
+               }
+
+               // innermost axis: contiguous copy if it stays innermost, strided copy otherwise
+               final int dstStride = (int) permutedStrides[dim];
+               if( dstStride != 1 ) {
+                       transposeRow(inData, outData, inOffset, outOffset, dstStride, extent);
+               } else {
+                       System.arraycopy(inData, inOffset, outData, outOffset, extent);
+               }
+       }
+
+       /**
+        * Multi-threaded driver for the single-array case: splits the linear cell
+        * range into contiguous chunks and runs one {@link PermuteSingleBlockTask}
+        * per chunk.
+        *
+        * @param inData          input cells
+        * @param outData         output cells
+        * @param inDims          shape of the input tensor
+        * @param inStrides       input stride per input axis
+        * @param permutedStrides output stride per input axis
+        * @param k               maximum number of tasks
+        * @param totalElements   number of cells to move
+        */
+       private static void permuteSingleBlockParallel(
+                       double[] inData, double[] outData,
+                       int[] inDims, long[] inStrides, long[] permutedStrides,
+                       int k, long totalElements) {
+
+               // even ceil-split across k tasks, but never fewer than 1024 cells per task
+               final long evenShare = (totalElements + k - 1) / k;
+               final long chunk = Math.max(1024, evenShare);
+               final long chunksNeeded = (totalElements + chunk - 1) / chunk;
+               final int actualThreads = (int) Math.min(k, chunksNeeded);
+
+               final ExecutorService pool = CommonThreadPool.get(actualThreads);
+               try {
+                       final ArrayList<PermuteSingleBlockTask> tasks = new ArrayList<>();
+                       long from = 0;
+                       for( int t = 0; t < actualThreads && from < totalElements; t++, from += chunk ) {
+                               final long to = Math.min(from + chunk, totalElements);
+                               tasks.add(new PermuteSingleBlockTask(inData, outData, inDims,
+                                       inStrides, permutedStrides, from, to));
+                       }
+
+                       // wait for every task; get() rethrows a task failure
+                       for( Future<Object> pending : pool.invokeAll(tasks) ) {
+                               pending.get();
+                       }
+               } catch (Exception ex) {
+                       throw new DMLRuntimeException(ex);
+               } finally {
+                       pool.shutdown();
+               }
+       }
+
+       /**
+        * Recursive kernel for dense blocks that report more than one block:
+        * walks the input axes from {@code dim} inwards and copies cell by cell,
+        * translating absolute offsets into (block index, offset in block) pairs.
+        *
+        * @param inDB            input dense block
+        * @param outDB           output dense block
+        * @param inDims          shape of the input tensor
+        * @param inStrides       input stride per input axis
+        * @param permutedStrides output stride per input axis
+        * @param dim             input axis handled by this call
+        * @param inOffset        absolute input offset accumulated over the outer axes
+        * @param outOffset       absolute output offset accumulated over the outer axes
+        */
+       private static void permuteMultiBlock(
+               DenseBlock inDB, DenseBlock outDB,
+               int[] inDims, long[] inStrides, long[] permutedStrides,
+               int dim, long inOffset, long outOffset) {
+
+               final int extent = inDims[dim];
+               final long srcStep = inStrides[dim];
+               final long dstStep = permutedStrides[dim];
+
+               if( dim != inDims.length - 1 ) {
+                       // outer axis: descend once per index, in ascending index order
+                       for( int i = 0; i < extent; i++ ) {
+                               permuteMultiBlock(inDB, outDB, inDims, inStrides, permutedStrides,
+                                       dim + 1,
+                                       inOffset + i * srcStep,
+                                       outOffset + i * dstStep);
+                       }
+                       return;
+               }
+
+               // innermost axis: move one cell per iteration
+               final int srcBlockSize = inDB.blockSize();
+               final int dstBlockSize = outDB.blockSize();
+               long srcPos = inOffset;
+               long dstPos = outOffset;
+               for( int i = 0; i < extent; i++, srcPos += srcStep, dstPos += dstStep ) {
+                       final int srcBlock = (int) (srcPos / srcBlockSize);
+                       final int srcRel = (int) (srcPos % srcBlockSize);
+                       final int dstBlock = (int) (dstPos / dstBlockSize);
+                       final int dstRel = (int) (dstPos % dstBlockSize);
+
+                       final double[] srcCells = inDB.valuesAt(srcBlock);
+                       final double[] dstCells = outDB.valuesAt(dstBlock);
+
+                       // cells whose block is missing or whose offset lies outside it are skipped
+                       if( srcCells == null || dstCells == null
+                               || srcRel >= srcCells.length || dstRel >= dstCells.length ) {
+                               continue;
+                       }
+                       dstCells[dstRel] = srcCells[srcRel];
+               }
+       }
+
+       /**
+        * Multi-threaded driver for the multi-block case: splits the linear cell
+        * range into contiguous chunks and runs one {@link PermuteMultiBlockTask}
+        * per chunk.
+        *
+        * @param inDB            input dense block
+        * @param outDB           output dense block
+        * @param inDims          shape of the input tensor
+        * @param inStrides       input stride per input axis
+        * @param permutedStrides output stride per input axis
+        * @param k               maximum number of tasks
+        * @param totalElements   number of cells to move
+        */
+       private static void permuteMultiBlockParallel(
+                       DenseBlock inDB, DenseBlock outDB,
+                       int[] inDims, long[] inStrides, long[] permutedStrides,
+                       int k, long totalElements) {
+
+               // even ceil-split across k tasks, but never fewer than 1024 cells per task
+               final long evenShare = (totalElements + k - 1) / k;
+               final long chunk = Math.max(1024, evenShare);
+               final long chunksNeeded = (totalElements + chunk - 1) / chunk;
+               final int actualThreads = (int) Math.min(k, chunksNeeded);
+
+               final ExecutorService pool = CommonThreadPool.get(actualThreads);
+               try {
+                       final ArrayList<PermuteMultiBlockTask> tasks = new ArrayList<>();
+                       long from = 0;
+                       for( int t = 0; t < actualThreads && from < totalElements; t++, from += chunk ) {
+                               final long to = Math.min(from + chunk, totalElements);
+                               tasks.add(new PermuteMultiBlockTask(inDB, outDB, inDims,
+                                       inStrides, permutedStrides, from, to));
+                       }
+
+                       // wait for every task; get() rethrows a task failure
+                       for( Future<Object> pending : pool.invokeAll(tasks) ) {
+                               pending.get();
+                       }
+               } catch (Exception ex) {
+                       throw new DMLRuntimeException(ex);
+               } finally {
+                       pool.shutdown();
+               }
+       }
+
+       /**
+        * Task that permutes the cells with linear input positions in
+        * {@code [start, end)} between two plain arrays. Each position is
+        * decomposed into per-axis coordinates via the input strides and
+        * re-linearized with the permuted output strides.
+        */
+       private static class PermuteSingleBlockTask implements Callable<Object> {
+               //TODO call single-threaded kernel for block
+
+               private final double[] inData;
+               private final double[] outData;
+               private final int[] inDims;
+               private final long[] inStrides;
+               private final long[] permutedStrides;
+               private final long start;
+               private final long end;
+
+               protected PermuteSingleBlockTask(double[] inData, double[] outData,
+                               int[] inDims, long[] inStrides, long[] permutedStrides,
+                               long start, long end) {
+                       this.inData = inData;
+                       this.outData = outData;
+                       this.inDims = inDims;
+                       this.inStrides = inStrides;
+                       this.permutedStrides = permutedStrides;
+                       this.start = start;
+                       this.end = end;
+               }
+
+               @Override
+               public Object call() {
+                       final int rank = inDims.length;
+                       long idx = start;
+                       while( idx < end ) {
+                               long src = 0;
+                               long dst = 0;
+                               long rest = idx;
+                               for( int d = 0; d < rank; d++ ) {
+                                       final long stride = inStrides[d];
+                                       final long coord = rest / stride;
+                                       rest %= stride;
+                                       src += coord * stride;
+                                       dst += coord * permutedStrides[d];
+                               }
+                               outData[(int)dst] = inData[(int)src];
+                               idx++;
+                       }
+                       return null;
+               }
+       }
+
+       private static class PermuteMultiBlockTask implements Callable<Object> {
+               //TODO call single-threaded kernel for block
+               
+               private final DenseBlock inDB;
+               private final DenseBlock outDB;
+               private final int[] inDims;
+               private final long[] inStrides;
+               private final long[] permutedStrides;
+               private final long start;
+               private final long end;
+               
+               protected PermuteMultiBlockTask(DenseBlock inDB, DenseBlock 
outDB,
+                               int[] inDims, long[] inStrides, long[] 
permutedStrides,
+                               long start, long end) {
+                       this.inDB = inDB;
+                       this.outDB = outDB;
+                       this.inDims = inDims;
+                       this.inStrides = inStrides;
+                       this.permutedStrides = permutedStrides;
+                       this.start = start;
+                       this.end = end;
+               }
+               
+               @Override
+               public Object call() {
+                       int inBlockSize = inDB.blockSize();
+                       int outBlockSize = outDB.blockSize();
+                       
+                       for( long idx = start; idx < end; idx++ ) {
+                               long inIdx = 0;
+                               long outIdx = 0;
+                               long remaining = idx;
+                               
+                               for( int d = 0; d < inDims.length; d++ ) {
+                                       long coord = remaining / inStrides[d];
+                                       remaining = remaining % inStrides[d];
+                                       inIdx += coord * inStrides[d];
+                                       outIdx += coord * permutedStrides[d];
+                               }
+                               
+                               int inBlockIdx = (int) (inIdx / inBlockSize);
+                               int inRelIdx = (int) (inIdx % inBlockSize);
+                               
+                               int outBlockIdx = (int) (outIdx / outBlockSize);
+                               int outRelIdx = (int) (outIdx % outBlockSize);
+                               
+                               double[] inArr = inDB.valuesAt(inBlockIdx);
+                               double[] outArr = outDB.valuesAt(outBlockIdx);
+                               
+                               if( inArr != null && outArr != null && 
+                                       inRelIdx < inArr.length && outRelIdx < 
outArr.length ) {
+                                       outArr[outRelIdx] = inArr[inRelIdx];
+                               }
+                       }
+                       return null;
+               }
+       }
 }
diff --git 
a/src/test/java/org/apache/sysds/test/component/matrix/libMatrixReorg/PermuteTest.java
 
b/src/test/java/org/apache/sysds/test/component/matrix/libMatrixReorg/PermuteTest.java
new file mode 100644
index 0000000000..60810b9228
--- /dev/null
+++ 
b/src/test/java/org/apache/sysds/test/component/matrix/libMatrixReorg/PermuteTest.java
@@ -0,0 +1,444 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.test.component.matrix.libMatrixReorg;
+
+import org.junit.Assert;
+import org.junit.Ignore;
+import org.junit.Test;
+import org.apache.sysds.runtime.matrix.data.MatrixBlock;
+import org.apache.sysds.runtime.matrix.data.LibMatrixReorg;
+import org.apache.sysds.runtime.data.DenseBlock;
+import org.mockito.Mockito;
+import java.util.Arrays;
+
+public class PermuteTest {
+
+       /** Swaps the first two axes of a 2x3x4 tensor and spot-checks a few cells. */
+       @Test
+       public void testBasicPermute() {
+               final int[] dims = {2, 3, 4};
+               final MatrixBlock input = generateMatrixBlock(dims);
+
+               // the generated block holds all 2*3*4 cells
+               Assert.assertEquals(24, input.getNumRows() * input.getNumColumns());
+
+               // cell (1,2,3) is the last cell, cell (0,0,0) the first
+               final double[] inValues = input.getDenseBlockValues();
+               Assert.assertEquals(23.0, inValues[1 * 4 * 3 + 2 * 4 + 3], 0.001);
+               Assert.assertEquals(0.0, inValues[0 * 4 * 3 + 0 * 4 + 0], 0.001);
+
+               final int[] axes = {1, 0, 2};
+               final MatrixBlock result = LibMatrixReorg.permute(input, dims, axes);
+
+               final double[] outValues = result.getDenseBlockValues();
+               Assert.assertEquals(24, outValues.length);
+               Assert.assertEquals(4.0, outValues[8], 0.001);
+               Assert.assertEquals(15.0, outValues[7], 0.001);
+       }
+
+       /** Swapping both axes of a 10x5 tensor. */
+       @Test
+       public void testPermute2DTranspose() {
+               final int[] dims = {10, 5};
+               final int[] axes = {1, 0};
+               final MatrixBlock input = generateMatrixBlock(dims);
+
+               final MatrixBlock result = LibMatrixReorg.permute(input, dims, axes);
+               verifyPermutation(input, result, dims, axes);
+       }
+
+       /** Swapping the two leading axes of a 2x3x4 tensor. */
+       @Test
+       public void testPermute3DSimple() {
+               final int[] dims = {2, 3, 4};
+               final int[] axes = {1, 0, 2};
+               final MatrixBlock input = generateMatrixBlock(dims);
+
+               final MatrixBlock result = LibMatrixReorg.permute(input, dims, axes);
+               verifyPermutation(input, result, dims, axes);
+       }
+
+       /** The identity permutation on a 5x5x5 tensor. */
+       @Test
+       public void testPermute3DIdentity() {
+               final int[] dims = {5, 5, 5};
+               final int[] axes = {0, 1, 2};
+               final MatrixBlock input = generateMatrixBlock(dims);
+
+               final MatrixBlock result = LibMatrixReorg.permute(input, dims, axes);
+               verifyPermutation(input, result, dims, axes);
+       }
+
+       /** Reversing the axis order of a 2x3x4x5 tensor. */
+       @Test
+       public void testPermute4DReverse() {
+               final int[] dims = {2, 3, 4, 5};
+               final int[] axes = {3, 2, 1, 0};
+               final MatrixBlock input = generateMatrixBlock(dims);
+
+               final MatrixBlock result = LibMatrixReorg.permute(input, dims, axes);
+               verifyPermutation(input, result, dims, axes);
+       }
+
+       /** A rank-6 tensor with every extent equal to 2. */
+       @Test
+       public void testPermuteHighRank() {
+               final int[] dims = {2, 2, 2, 2, 2, 2};
+               final int[] axes = {5, 0, 4, 1, 3, 2};
+               final MatrixBlock input = generateMatrixBlock(dims);
+
+               final MatrixBlock result = LibMatrixReorg.permute(input, dims, axes);
+               verifyPermutation(input, result, dims, axes);
+       }
+
+       /**
+        * Makes the input dense block report two blocks via a Mockito spy, then
+        * checks the result against a freshly generated, unmocked input.
+        */
+       @Test
+       public void testLargeBlockLogicMocked() {
+               final int[] dims = {10, 10, 10};
+               final int[] axes = {2, 0, 1};
+
+               final MatrixBlock input = generateMatrixBlock(dims);
+               final DenseBlock realBlock = input.getDenseBlock();
+               final DenseBlock spiedBlock = Mockito.spy(realBlock);
+               Mockito.when(spiedBlock.numBlocks()).thenReturn(2);
+               input.setDenseBlock(spiedBlock);
+
+               final MatrixBlock result = LibMatrixReorg.permute(input, dims, axes);
+
+               final MatrixBlock reference = generateMatrixBlock(dims);
+               verifyPermutation(reference, result, dims, axes);
+       }
+
+       /**
+        * Makes the input dense block report five blocks via a Mockito spy, then
+        * checks the result against a freshly generated, unmocked input.
+        */
+       @Test
+       public void testLargeBlockLogicMockedInputAndOutput() {
+               final int[] dims = {4, 4, 4};
+               final int[] axes = {2, 1, 0};
+
+               final MatrixBlock input = generateMatrixBlock(dims);
+               final DenseBlock spiedBlock = Mockito.spy(input.getDenseBlock());
+               Mockito.when(spiedBlock.numBlocks()).thenReturn(5);
+               input.setDenseBlock(spiedBlock);
+
+               final MatrixBlock result = LibMatrixReorg.permute(input, dims, axes);
+
+               final MatrixBlock reference = generateMatrixBlock(dims);
+               verifyPermutation(reference, result, dims, axes);
+       }
+
+       /** Permutes a 100x100x100 tensor with k = -1 (all available processors). */
+       @Test
+       public void testPermute3DParallel() {
+               final int[] dims = {100, 100, 100};
+               final int[] axes = {2, 0, 1};
+               final MatrixBlock input = generateMatrixBlock(dims);
+
+               final MatrixBlock result = LibMatrixReorg.permute(input, dims, axes, -1);
+               verifyPermutation(input, result, dims, axes);
+       }
+
+       /**
+        * Manual benchmark (ignored by default): times one run with k = 1 and
+        * one with k = -1 on a cubic tensor, verifies both results, and
+        * expects the multi-threaded run to be faster.
+        */
+       @Test
+       @Ignore
+       public void testPerformanceSingleVsMultiThreaded() {
+               final int size = 100;
+               final int[] dims = {size, size, size};
+               final int[] axes = {2, 0, 1};
+               final MatrixBlock input = generateMatrixBlock(dims);
+
+               final long singleBegin = System.nanoTime();
+               final MatrixBlock outSingle = LibMatrixReorg.permute(input, dims, axes, 1);
+               final long timeSingle = System.nanoTime() - singleBegin;
+
+               final long multiBegin = System.nanoTime();
+               final MatrixBlock outMulti = LibMatrixReorg.permute(input, dims, axes, -1);
+               final long timeMulti = System.nanoTime() - multiBegin;
+
+               verifyPermutation(input, outSingle, dims, axes);
+               verifyPermutation(input, outMulti, dims, axes);
+
+               final String speedup = String.format("%.2fx", (double)timeSingle / timeMulti);
+               System.out.println("Large Matrix (" + size + "x" + size + "x" + size + "):");
+               System.out.println("Single-threaded: " + timeSingle / 1_000_000 + " ms");
+               System.out.println("Multi-threaded: " + timeMulti / 1_000_000 + " ms");
+               System.out.println("Speedup: " + speedup);
+
+               Assert.assertTrue("Multi-threaded should be faster for large matrices", timeMulti < timeSingle);
+       }
+
+       /**
+        * Manual benchmark (ignored by default): times one run with k = 1 and
+        * one with k = -1 on a 1 x 10000 x 10000 tensor and expects the
+        * multi-threaded run to be faster. The results are not verified here.
+        */
+       @Test
+       @Ignore
+       public void testPerformanceLargeMatrixSingleVsMulti() {
+               int[] shape = {1, 10000, 10000};
+               int[] perm = {0, 2, 1};
+
+               MatrixBlock in = generateMatrixBlock(shape);
+
+               // only the elapsed time matters, so the results are discarded
+               long startSingle = System.nanoTime();
+               LibMatrixReorg.permute(in, shape, perm, 1);
+               long timeSingle = System.nanoTime() - startSingle;
+
+               long startMulti = System.nanoTime();
+               LibMatrixReorg.permute(in, shape, perm, -1);
+               long timeMulti = System.nanoTime() - startMulti;
+
+               // print the dimensions from the shape so the banner cannot drift from the data
+               System.out.println("Large Matrix (" + shape[0] + "x" + shape[1] + "x" + shape[2] + "):");
+               System.out.println("Single-threaded: " + timeSingle / 1_000_000 + " ms");
+               System.out.println("Multi-threaded: " + timeMulti / 1_000_000 + " ms");
+               System.out.println("Speedup: " + String.format("%.2fx", (double)timeSingle / timeMulti));
+
+               Assert.assertTrue("Multi-threaded should be faster for large matrices", timeMulti < timeSingle);
+       }
+
+       /**
+        * Manual benchmark (ignored by default): times permute with axes {1, 0}
+        * against LibMatrixReorg.transpose on a square matrix, prints both
+        * timings, and checks the permute result cell by cell against the input.
+        */
+       @SuppressWarnings("unused")
+       @Test
+       @Ignore
+       public void testPerformancePermuteVsNativeTranspose() {
+               final int size = 1000;
+               final MatrixBlock input = new MatrixBlock(size, size, false);
+               input.allocateDenseBlock();
+
+               // cell (r, c) holds its own row-major position r * size + c
+               final double[] cells = input.getDenseBlockValues();
+               for (int r = 0; r < size; r++) {
+                       final int rowStart = r * size;
+                       for (int c = 0; c < size; c++) {
+                               cells[rowStart + c] = rowStart + c;
+                       }
+               }
+
+               final int[] dims = {size, size};
+               final int[] axes = {1, 0};
+
+               final long permuteBegin = System.nanoTime();
+               final MatrixBlock outPermute = LibMatrixReorg.permute(input, dims, axes, -1);
+               final long timePermute = System.nanoTime() - permuteBegin;
+
+               final long transposeBegin = System.nanoTime();
+               final MatrixBlock outTranspose = LibMatrixReorg.transpose(input);
+               final long timeTranspose = System.nanoTime() - transposeBegin;
+
+               final String ratio = String.format("%.2fx", (double)timePermute / timeTranspose);
+               System.out.println("Transpose Performance (" + size + "x" + size + "):");
+               System.out.println("Permute function: " + timePermute / 1_000_000 + " ms");
+               System.out.println("Native transpose: " + timeTranspose / 1_000_000 + " ms");
+               System.out.println("Ratio: " + ratio);
+
+               // output cell (i, j) must equal input cell (j, i)
+               final double[] permuted = outPermute.getDenseBlockValues();
+               for (int i = 0; i < size; i++) {
+                       for (int j = 0; j < size; j++) {
+                               final double expected = input.get(j, i);
+                               final double actual = permuted[i * size + j];
+                               Assert.assertEquals("Mismatch at (" + i + "," + j + ")", expected, actual, 0.0001);
+                       }
+               }
+       }
+
+       /** A 1x1x1 tensor with reversed axes. */
+       @Test
+       public void testEdgeCaseSingleElement() {
+               final int[] dims = {1, 1, 1};
+               final int[] axes = {2, 1, 0};
+               final MatrixBlock input = generateMatrixBlock(dims);
+
+               final MatrixBlock result = LibMatrixReorg.permute(input, dims, axes);
+               verifyPermutation(input, result, dims, axes);
+       }
+
+       /** A 5x1x10 tensor, i.e. one axis of extent 1. */
+       @Test
+       public void testEdgeCaseOneDimensionOne() {
+               final int[] dims = {5, 1, 10};
+               final int[] axes = {2, 0, 1};
+               final MatrixBlock input = generateMatrixBlock(dims);
+
+               final MatrixBlock result = LibMatrixReorg.permute(input, dims, axes);
+               verifyPermutation(input, result, dims, axes);
+       }
+
+       /** A 1x1x100 tensor, i.e. two axes of extent 1. */
+       @Test
+       public void testEdgeCaseTwoDimensionsOne() {
+               final int[] dims = {1, 1, 100};
+               final int[] axes = {2, 1, 0};
+               final MatrixBlock input = generateMatrixBlock(dims);
+
+               final MatrixBlock result = LibMatrixReorg.permute(input, dims, axes);
+               verifyPermutation(input, result, dims, axes);
+       }
+
+       /**
+        * Applies two permutations back to back, feeding the shape produced by
+        * the first into the second, and verifies each step against its own input.
+        */
+       @Test
+       public void testConsecutivePermutations() {
+               int[] shape = {3, 4, 5};
+               int[] perm1 = {1, 0, 2};
+               int[] perm2 = {2, 0, 1};
+
+               MatrixBlock in = generateMatrixBlock(shape);
+               MatrixBlock temp = LibMatrixReorg.permute(in, shape, perm1);
+
+               // check the intermediate result too, so a wrong first step cannot go unnoticed
+               verifyPermutation(in, temp, shape, perm1);
+
+               // shape after the first permutation: axis i takes the extent of axis perm1[i]
+               int[] tempShape = {shape[perm1[0]], shape[perm1[1]], shape[perm1[2]]};
+               MatrixBlock out = LibMatrixReorg.permute(temp, tempShape, perm2);
+
+               verifyPermutation(temp, out, tempShape, perm2);
+       }
+
+       /**
+        * Runs the same permutation with k = 1, 2, 4 and 8 and expects the
+        * multi-threaded results to match the k = 1 result cell by cell.
+        */
+       @Test
+       public void testDifferentThreadCounts() {
+               final int[] dims = {50, 50, 50};
+               final int[] axes = {2, 0, 1};
+               final MatrixBlock input = generateMatrixBlock(dims);
+
+               final MatrixBlock byOne = LibMatrixReorg.permute(input, dims, axes, 1);
+               final MatrixBlock byTwo = LibMatrixReorg.permute(input, dims, axes, 2);
+               final MatrixBlock byFour = LibMatrixReorg.permute(input, dims, axes, 4);
+               final MatrixBlock byEight = LibMatrixReorg.permute(input, dims, axes, 8);
+
+               final double[] reference = byOne.getDenseBlockValues();
+               final double[] twoValues = byTwo.getDenseBlockValues();
+               final double[] fourValues = byFour.getDenseBlockValues();
+               final double[] eightValues = byEight.getDenseBlockValues();
+
+               for (int pos = 0; pos < reference.length; pos++) {
+                       final double expected = reference[pos];
+                       Assert.assertEquals(expected, twoValues[pos], 0.0001);
+                       Assert.assertEquals(expected, fourValues[pos], 0.0001);
+                       Assert.assertEquals(expected, eightValues[pos], 0.0001);
+               }
+       }
+
+       /** A cyclic shift of all four axes of a 3x4x5x2 tensor. */
+       @Test
+       public void testPermuteAllDimensionsCyclic() {
+               final int[] dims = {3, 4, 5, 2};
+               final int[] axes = {1, 2, 3, 0};
+               final MatrixBlock input = generateMatrixBlock(dims);
+
+               final MatrixBlock result = LibMatrixReorg.permute(input, dims, axes);
+               verifyPermutation(input, result, dims, axes);
+       }
+
+       /** A 7x11x13 tensor, i.e. three distinct prime extents. */
+       @Test
+       public void testPermuteNonContiguousStrides() {
+               final int[] dims = {7, 11, 13};
+               final int[] axes = {2, 0, 1};
+               final MatrixBlock input = generateMatrixBlock(dims);
+
+               final MatrixBlock result = LibMatrixReorg.permute(input, dims, axes);
+               verifyPermutation(input, result, dims, axes);
+       }
+
+       /** Swapping both axes of a 17x19 tensor (prime extents). */
+       @Test
+       public void testPermuteLargePrimeStrides() {
+               final int[] dims = {17, 19};
+               final int[] axes = {1, 0};
+               final MatrixBlock input = generateMatrixBlock(dims);
+
+               final MatrixBlock result = LibMatrixReorg.permute(input, dims, axes);
+               verifyPermutation(input, result, dims, axes);
+       }
+
+       /**
+        * Builds a 1 x N block with allocated dense storage, N being the
+        * product of all entries of shape, and writes the value i into
+        * position i of its dense value array.
+        *
+        * @param shape tensor dimensions; only their product is used here
+        * @return the filled block
+        */
+       private MatrixBlock generateMatrixBlock(int[] shape) {
+               long cells = 1;
+               for (int k = 0; k < shape.length; k++)
+                       cells *= shape[k];
+
+               MatrixBlock block = new MatrixBlock(1, (int)cells, false);
+               block.allocateDenseBlock();
+               // every slot of the backing array receives its own index
+               Arrays.setAll(block.getDenseBlockValues(), pos -> pos);
+               return block;
+       }
+
+       /**
+        * Checks that out is the axis permutation of in. Both blocks are read
+        * as row-major linearised tensors; for every output coordinate c the
+        * value out[c] must match in[c'] with c'[perm[d]] = c[d], within an
+        * absolute tolerance of 0.0001. Fails the test on the first mismatch.
+        *
+        * @param in      block holding the source tensor of shape inShape
+        * @param out     permuted block under test
+        * @param inShape dimensions of the source tensor
+        * @param perm    axis order; output axis d corresponds to input axis perm[d]
+        */
+       private void verifyPermutation(MatrixBlock in, MatrixBlock out, int[] inShape, int[] perm) {
+               double[] inData = toLinearArray(in);
+               double[] outData = toLinearArray(out);
+
+               int rank = inShape.length;
+               int[] outShape = new int[rank];
+               for (int i = 0; i < rank; i++)
+                       outShape[i] = inShape[perm[i]];
+
+               long[] outStrides = getStrides(outShape);
+               long[] inStrides = getStrides(inShape);
+
+               long len = 1;
+               for (int d : outShape) len *= d;
+
+               for (long i = 0; i < len; i++) {
+                       // split the linear output index into per-axis coordinates
+                       int[] outCoords = new int[rank];
+                       long temp = i;
+                       for (int d = 0; d < rank; d++) {
+                               outCoords[d] = (int)(temp / outStrides[d]);
+                               temp = temp % outStrides[d];
+                       }
+
+                       // route each output coordinate back to its source axis
+                       int[] inCoords = new int[rank];
+                       for (int d = 0; d < rank; d++) {
+                               inCoords[perm[d]] = outCoords[d];
+                       }
+
+                       long inIndex = 0;
+                       for (int d = 0; d < rank; d++) {
+                               inIndex += inCoords[d] * inStrides[d];
+                       }
+
+                       double expectedValue = inData[(int)inIndex];
+                       double actualValue = outData[(int)i];
+
+                       if (Math.abs(expectedValue - actualValue) > 0.0001) {
+                               Assert.fail("Mismatch at linear output index " + i +
+                                       ". Output coords " + Arrays.toString(outCoords) +
+                                       ". Input coords " + Arrays.toString(inCoords) +
+                                       ". Expected " + expectedValue + " but got " + actualValue);
+                       }
+               }
+       }
+
+       /**
+        * Copies the dense values of mb, block after block, into one array of
+        * rows * columns cells. A block without dense storage yields all zeros.
+        *
+        * @param mb block to flatten
+        * @return the cell values in storage order
+        */
+       private double[] toLinearArray(MatrixBlock mb) {
+               double[] data = new double[(int)(mb.getNumRows() * mb.getNumColumns())];
+               DenseBlock db = mb.getDenseBlock();
+               if (db == null)
+                       return data;
+
+               int offset = 0;
+               for (int b = 0; b < db.numBlocks() && offset < data.length; b++) {
+                       // Per the DenseBlock API, size(b) is the cell count of block b,
+                       // whereas blockSize() is a row count; using the latter as a cell
+                       // count copies only a prefix and leaves the check comparing zeros.
+                       int n = Math.min(db.size(b), data.length - offset);
+                       System.arraycopy(db.valuesAt(b), 0, data, offset, n);
+                       offset += n;
+               }
+               return data;
+       }
+
+       private long[] getStrides(int[] dims) {
+               long[] strides = new long[dims.length];
+               long stride = 1;
+               for (int i = dims.length - 1; i >= 0; i--) {
+                       strides[i] = stride;
+                       stride *= dims[i];
+               }
+               return strides;
+       }
+}
\ No newline at end of file
diff --git 
a/src/test/java/org/apache/sysds/test/component/matrix/libMatrixReorg/TransposeInPlaceBrennerTest.java
 
b/src/test/java/org/apache/sysds/test/component/matrix/libMatrixReorg/TransposeInPlaceBrennerTest.java
index 50e79279b6..45457eae6e 100644
--- 
a/src/test/java/org/apache/sysds/test/component/matrix/libMatrixReorg/TransposeInPlaceBrennerTest.java
+++ 
b/src/test/java/org/apache/sysds/test/component/matrix/libMatrixReorg/TransposeInPlaceBrennerTest.java
@@ -301,7 +301,7 @@ public class TransposeInPlaceBrennerTest {
        }
 
        @Test
-       public void testTensorPermute_SingleElement() {
+       public void testTensorPermuteSingleElement() {
                int[] shape = {1, 1, 1};
                int[] perm = {2, 1, 0};
                testTransposeInPlaceTensor(shape, perm);
@@ -403,7 +403,7 @@ public class TransposeInPlaceBrennerTest {
        // 9. Invalid permutations (negative tests)
        // NOTE: more detailed error handling can be added in the future, 
currently these are just checking for exceptions
        @Test
-       public void testTensorPermute_InvalidPerm_OutOfRange() {
+       public void testTensorPermuteInvalidPermOutOfRange() {
                int[] shape = {2, 3, 4};
                int[] perm = {0, 1, 3}; // 3 is out of range for 3D tensor
                
@@ -647,7 +647,7 @@ public class TransposeInPlaceBrennerTest {
                }
                
                @Test
-               public void testTensorPermuteSplitShape_2D() {
+               public void testTensorPermuteSplitShape2D() {
                int[] shape = {2,3}; 
                int[] perm = {1,0}; 
                
diff --git 
a/src/test/java/org/apache/sysds/test/component/tensor/TransposeLinDataTest.java
 
b/src/test/java/org/apache/sysds/test/component/tensor/TransposeLinDataTest.java
new file mode 100644
index 0000000000..d7e13a8b56
--- /dev/null
+++ 
b/src/test/java/org/apache/sysds/test/component/tensor/TransposeLinDataTest.java
@@ -0,0 +1,197 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.test.component.tensor;
+
+import org.junit.Assert;
+import org.junit.Test;
+
+import org.apache.sysds.common.Types.ValueType;
+import org.apache.sysds.runtime.data.TensorBlock;
+ import java.util.Arrays;
+
+public class TransposeLinDataTest {
+
+    @Test
+    public void Testrightelem(){
+        int[] shape = {2, 3, 4};
+        TensorBlock tensor = TensorUtils.createArangeTensor(shape);
+
+        Assert.assertArrayEquals(new int[]{2, 3, 4}, tensor.getDims()); 
+        Assert.assertEquals(0.0, tensor.get(new int[]{0, 0, 0}));
+        Assert.assertEquals(23.0, tensor.get(new int[]{1, 2, 3}));
+        Assert.assertEquals(6.0, tensor.get(new int[]{0, 1, 2}));
+        Assert.assertEquals(12.0, tensor.get(new int[]{1, 0, 0}));
+        printTensor(tensor);
+
+
+        int[] permutation = {1, 0, 2};
+        TensorBlock outTensor = PermuteIt.permute(tensor, permutation); 
+        printTensor(outTensor); 
+
+        Assert.assertArrayEquals(new int[]{3, 2, 4}, outTensor.getDims()); 
+        Assert.assertEquals(0.0, outTensor.get(new int[]{0,0,0})); 
+        Assert.assertEquals(23.0, outTensor.get(new int[]{2, 1, 3})); 
+        Assert.assertEquals(12.0, outTensor.get(new int[]{0, 1, 0})); 
+        Assert.assertEquals(17.0, outTensor.get(new int[]{1, 1, 1})); 
+        
+
+        int[] second_permutation = {2, 1, 0}; 
+        TensorBlock perm2Block = PermuteIt.permute(tensor, 
second_permutation); 
+        printTensor(perm2Block); 
+
+        Assert.assertArrayEquals(new int[]{4, 3, 2}, perm2Block.getDims()); 
+        Assert.assertEquals(0.0, perm2Block.get(new int[]{0, 0, 0}));
+        Assert.assertEquals(12.0, perm2Block.get(new int[]{0, 0, 1})); 
+        Assert.assertEquals(11.0, perm2Block.get(new int[]{3, 2, 0})); 
+        Assert.assertEquals(23.0, perm2Block.get(new int[]{3, 2, 1})); 
+        
+    }
+
+    
+
+
+    public class TensorUtils {
+
+        public static TensorBlock createArangeTensor(int[] shape) {
+            TensorBlock tb = new TensorBlock(ValueType.FP64, shape);
+            tb.allocateBlock();
+            double[] counter = { 0.0 };
+            int[] currentIndices = new int[shape.length];
+            
+            fillRecursively(tb, shape, 0, currentIndices, counter);
+            
+            return tb;
+        }
+
+        private static void fillRecursively(TensorBlock tb, int[] shape, int 
dim, int[] currentIndices, double[] counter) {
+            if (dim == shape.length) {
+                tb.set(currentIndices, counter[0]);
+                counter[0]++; 
+                return;
+            }
+
+            for (int i = 0; i < shape[dim]; i++) {
+                currentIndices[dim] = i;
+
+                fillRecursively(tb, shape, dim + 1, currentIndices, counter);
+            }
+        }
+    }
+
+
+
+    public class PermuteIt {
+
+
+        public static TensorBlock permute(TensorBlock tensor, int[] 
permute_dims) { 
+
+            int anz_dims = tensor.getNumDims(); 
+            int[] dims = tensor.getDims();
+            ValueType tensorType = tensor.getValueType();
+
+            int[] out_shape = new int[anz_dims]; 
+
+            for (int idx = 0; idx < anz_dims; idx++){
+                out_shape[idx] = dims[permute_dims[idx]];
+            }
+
+            TensorBlock outTensor = new TensorBlock(tensorType, out_shape); 
+            outTensor.allocateBlock();
+
+            int[] inIndex = new int[anz_dims]; 
+            int[] outIndex = new int[anz_dims]; 
+
+            rekursion(tensor, outTensor, permute_dims, dims, 0, inIndex, 
outIndex); 
+            return outTensor; 
+        }   
+
+        public static void rekursion(TensorBlock inTensor, 
+                                     TensorBlock outTensor, 
+                                     int[] permutation, 
+                                     int[] inShape, 
+                                     int dim, 
+                                     int[] inIndex, 
+                                     int[]outIndex
+                                     ){
+
+            if (dim == inShape.length) {
+                for(int idx = 0; idx < permutation.length; idx++){
+                    outIndex[idx] = inIndex[permutation[idx]]; 
+                }
+                double val = (double) inTensor.get(inIndex); 
+                outTensor.set(outIndex, val); 
+                return; 
+            }
+
+            for(int idx = 0; idx < inShape[dim]; idx++){
+                inIndex[dim] = idx; 
+                rekursion(inTensor, outTensor, permutation, inShape, dim+1, 
inIndex, outIndex);
+            }
+            
+        }
+
+    }
+   
+
+    public static void printTensor(TensorBlock tb) {
+        StringBuilder sb = new StringBuilder();
+        int[] shape = tb.getDims();
+        int[] currentIndices = new int[shape.length];
+        
+        sb.append("Tensor(").append(Arrays.toString(shape)).append("):\n");
+        printRecursive(tb, shape, 0, currentIndices, sb, 0);
+        
+        System.out.println(sb.toString());
+    }
+
+    private static void printRecursive(TensorBlock tb, int[] shape, int dim, 
int[] indices, StringBuilder sb, int indent) {
+        for (int k = 0; k < indent; k++) sb.append(" ");
+
+        sb.append("[");
+
+        if (dim == shape.length - 1) {
+            for (int i = 0; i < shape[dim]; i++) {
+                indices[dim] = i;
+                double val = (double) tb.get(indices); 
+                sb.append(String.format("%.1f", val)); 
+                if (i < shape[dim] - 1) sb.append(", ");
+            }
+            sb.append("]");
+        } 
+
+        else {
+            sb.append("\n");
+            for (int i = 0; i < shape[dim]; i++) {
+                indices[dim] = i;
+                printRecursive(tb, shape, dim + 1, indices, sb, indent + 2);
+                
+                if (i < shape[dim] - 1) {
+                    sb.append(",");
+                    sb.append("\n"); 
+                    if (shape.length - dim > 2) sb.append("\n"); 
+                }
+            }
+            sb.append("\n"); 
+            for (int k = 0; k < indent; k++) sb.append(" ");
+            sb.append("]");
+        }
+    }
+
+}
\ No newline at end of file

Reply via email to