This is an automated email from the ASF dual-hosted git repository.

baunsgaard pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/systemds.git

commit e2e560a2d6382d9b290f0e8f25e41772ad6c91ec
Author: baunsgaard <[email protected]>
AuthorDate: Wed Dec 21 13:18:27 2022 +0100

    [SYSTEMDS-3272] applySchema FrameBlock parallel
    
    This commit improves the performance of applySchema through parallelization,
    from 0.8-0.9 sec to 0.169 sec on a 64k x 2k FrameBlock. Also included
    are tests with 100% test coverage of applySchema.
---
 .../sysds/runtime/frame/data/FrameBlock.java       |  30 ++---
 .../frame/data/lib/FrameLibApplySchema.java        | 106 +++++++++++++++++
 .../frame/data/lib/FrameLibDetectSchema.java       |   2 +-
 .../test/component/frame/FrameApplySchema.java     | 129 +++++++++++++++++++--
 4 files changed, 239 insertions(+), 28 deletions(-)

diff --git a/src/main/java/org/apache/sysds/runtime/frame/data/FrameBlock.java 
b/src/main/java/org/apache/sysds/runtime/frame/data/FrameBlock.java
index 7d2fbb5797..80c3508fea 100644
--- a/src/main/java/org/apache/sysds/runtime/frame/data/FrameBlock.java
+++ b/src/main/java/org/apache/sysds/runtime/frame/data/FrameBlock.java
@@ -51,12 +51,14 @@ import org.apache.sysds.common.Types.ValueType;
 import org.apache.sysds.runtime.DMLRuntimeException;
 import org.apache.sysds.runtime.codegen.CodegenUtils;
 import org.apache.sysds.runtime.controlprogram.caching.CacheBlock;
+import 
org.apache.sysds.runtime.controlprogram.parfor.stat.InfrastructureAnalyzer;
 import org.apache.sysds.runtime.controlprogram.parfor.util.IDSequence;
 import org.apache.sysds.runtime.frame.data.columns.Array;
 import org.apache.sysds.runtime.frame.data.columns.ArrayFactory;
 import org.apache.sysds.runtime.frame.data.columns.ColumnMetadata;
 import org.apache.sysds.runtime.frame.data.iterators.IteratorFactory;
 import org.apache.sysds.runtime.frame.data.lib.FrameFromMatrixBlock;
+import org.apache.sysds.runtime.frame.data.lib.FrameLibApplySchema;
 import org.apache.sysds.runtime.frame.data.lib.FrameLibDetectSchema;
 import org.apache.sysds.runtime.functionobjects.ValueComparisonFunction;
 import org.apache.sysds.runtime.instructions.cp.BooleanObject;
@@ -151,6 +153,14 @@ public class FrameBlock implements CacheBlock<FrameBlock>, 
Externalizable {
                                appendRow(data[i]);
        }
 
+       public FrameBlock(ValueType[] schema, String[] colNames, 
ColumnMetadata[] meta, Array<?>[] data ){ // builds a block from already-constructed parts; the given arrays are stored as-is, not copied
+               _numRows = data[0].size(); // row count is read from the first column, so data must hold at least one column
+               _schema = schema;
+               _colnames = colNames;
+               _colmeta = meta; 
+               _coldata = data;
+       }
+
        /**
         * Get the number of rows of the frame block.
         *
@@ -279,6 +289,10 @@ public class FrameBlock implements CacheBlock<FrameBlock>, 
Externalizable {
                return _colmeta[c];
        }
 
+       public Array<?>[] getColumns(){ // exposes the column arrays of this block
+               return _coldata; // the internal array itself is returned, not a copy
+       }
+
        public boolean isColumnMetadataDefault() {
                boolean ret = true;
                for( int j=0; j<getNumColumns() && ret; j++ )
@@ -1733,21 +1747,7 @@ public class FrameBlock implements 
CacheBlock<FrameBlock>, Externalizable {
         * @return A new FrameBlock with the schema applied.
         */
        public FrameBlock applySchema(ValueType[] schema) {
-               if(schema.length != _schema.length)
-                       throw new DMLRuntimeException(//
-                               "Invalid apply schema with different number of 
columns expected: " + _schema.length + " got: "
-                                       + schema.length);
-               FrameBlock ret = new FrameBlock();
-               final int nCol = getNumColumns();
-               ret._numRows = getNumRows();
-               ret._schema = schema;
-               ret._colnames = _colnames;
-               ret._colmeta = _colmeta;
-               ret._coldata = new Array[nCol];
-               for(int i = 0; i < nCol; i++)
-                       ret._coldata[i] = _coldata[i].changeType(schema[i]);
-               ret._msize = -1;
-               return ret;
+               return FrameLibApplySchema.applySchema(this, schema, 
InfrastructureAnalyzer.getLocalParallelism());
        }
 
        @Override
diff --git 
a/src/main/java/org/apache/sysds/runtime/frame/data/lib/FrameLibApplySchema.java
 
b/src/main/java/org/apache/sysds/runtime/frame/data/lib/FrameLibApplySchema.java
new file mode 100644
index 0000000000..57c79a49a9
--- /dev/null
+++ 
b/src/main/java/org/apache/sysds/runtime/frame/data/lib/FrameLibApplySchema.java
@@ -0,0 +1,106 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.runtime.frame.data.lib;
+
+import java.util.concurrent.ExecutionException;
+import java.util.concurrent.ExecutorService;
+import java.util.stream.IntStream;
+
+import org.apache.commons.logging.Log;
+import org.apache.commons.logging.LogFactory;
+import org.apache.sysds.common.Types.ValueType;
+import org.apache.sysds.runtime.DMLRuntimeException;
+import org.apache.sysds.runtime.frame.data.FrameBlock;
+import org.apache.sysds.runtime.frame.data.columns.Array;
+import org.apache.sysds.runtime.frame.data.columns.ColumnMetadata;
+import org.apache.sysds.runtime.util.CommonThreadPool;
+
+/**
+ * Library class that applies a target schema to a FrameBlock by converting every column to the requested value
+ * type, optionally processing the columns in parallel.
+ */
+public class FrameLibApplySchema {
+
+       protected static final Log LOG = LogFactory.getLog(FrameLibApplySchema.class.getName());
+
+       /** The input block the schema is applied to. */
+       private final FrameBlock fb;
+       /** The target value type of each column. */
+       private final ValueType[] schema;
+       /** The number of columns of the input block. */
+       private final int nCol;
+       /** The columns of the input block, as returned by fb.getColumns(). */
+       private final Array<?>[] columnsIn;
+       /** The converted columns, one slot per input column, filled by the apply methods. */
+       private final Array<?>[] columnsOut;
+
+       /**
+        * Method to create a new FrameBlock where the given schema is applied, k is parallelization degree.
+        * 
+        * @param fb     The input block to apply schema to
+        * @param schema The schema to apply
+        * @param k      The parallelization degree
+        * @return A new FrameBlock holding the columns returned by changeType for each input column.
+        * @throws DMLRuntimeException If the schema length differs from the number of columns in fb, or if the
+        *                             parallel conversion fails or is interrupted.
+        */
+       public static FrameBlock applySchema(FrameBlock fb, ValueType[] schema, int k) {
+               return new FrameLibApplySchema(fb, schema).apply(k);
+       }
+
+       /**
+        * Capture the input block and target schema, and allocate the output column slots.
+        * 
+        * @param fb     The input block to apply schema to
+        * @param schema The schema to apply
+        */
+       private FrameLibApplySchema(FrameBlock fb, ValueType[] schema) {
+               this.fb = fb;
+               this.schema = schema;
+               // validate before touching the columns, so a mismatching schema fails fast
+               verifySize();
+               nCol = fb.getNumColumns();
+               columnsIn = fb.getColumns();
+               columnsOut = new Array<?>[nCol];
+       }
+
+       /**
+        * Convert all columns and assemble the result block.
+        * 
+        * @param k The parallelization degree
+        * @return A FrameBlock with the new schema and converted columns, reusing the column names and metadata
+        *         obtained from the input block.
+        */
+       private FrameBlock apply(int k) {
+               // a single thread or a single column gains nothing from the thread pool
+               if(k <= 1 || nCol == 1)
+                       applySingleThread();
+               else
+                       applyMultiThread(k);
+
+               final String[] colNames = fb.getColumnNames(false);
+               final ColumnMetadata[] meta = fb.getColumnMetadata();
+               return new FrameBlock(schema, colNames, meta, columnsOut);
+       }
+
+       /** Convert the columns one after another on the calling thread. */
+       private void applySingleThread() {
+               for(int i = 0; i < nCol; i++)
+                       columnsOut[i] = columnsIn[i].changeType(schema[i]);
+       }
+
+       /**
+        * Convert the columns through a parallel stream started from inside a pool task.
+        * 
+        * @param k The parallelization degree requested from the thread pool
+        * @throws DMLRuntimeException If a column conversion fails or the calling thread is interrupted while waiting
+        */
+       private void applyMultiThread(int k) {
+               final ExecutorService pool = CommonThreadPool.get(k);
+               try {
+                       // each index writes only its own slot of columnsOut; get() blocks until all columns are done
+                       pool.submit(() -> {
+                               IntStream.rangeClosed(0, nCol - 1).parallel() // parallel columns
+                                       .forEach(x -> columnsOut[x] = columnsIn[x].changeType(schema[x]));
+                       }).get();
+               }
+               catch(InterruptedException e) {
+                       // restore the interrupt flag so callers further up can still observe the interruption
+                       Thread.currentThread().interrupt();
+                       throw new DMLRuntimeException("Interrupted while applying schema in parallel", e);
+               }
+               catch(ExecutionException e) {
+                       throw new DMLRuntimeException("Failed to apply schema in parallel", e);
+               }
+               finally {
+                       // release the pool on every path, including unexpected runtime exceptions
+                       pool.shutdown();
+               }
+       }
+
+       /**
+        * Verify that the target schema has exactly one entry per column of the input block.
+        * 
+        * @throws DMLRuntimeException If the lengths differ
+        */
+       private void verifySize() {
+               if(schema.length != fb.getSchema().length)
+                       throw new DMLRuntimeException(//
+                               "Invalid apply schema with different number of columns expected: " + fb.getSchema().length + " got: "
+                                       + schema.length);
+       }
+}
diff --git 
a/src/main/java/org/apache/sysds/runtime/frame/data/lib/FrameLibDetectSchema.java
 
b/src/main/java/org/apache/sysds/runtime/frame/data/lib/FrameLibDetectSchema.java
index 8a37ac8f4d..3219617f27 100644
--- 
a/src/main/java/org/apache/sysds/runtime/frame/data/lib/FrameLibDetectSchema.java
+++ 
b/src/main/java/org/apache/sysds/runtime/frame/data/lib/FrameLibDetectSchema.java
@@ -34,7 +34,7 @@ import org.apache.sysds.runtime.util.CommonThreadPool;
 import org.apache.sysds.runtime.util.UtilFunctions;
 
 public final class FrameLibDetectSchema {
-       // private static final Log LOG = 
LogFactory.getLog(FrameBlock.class.getName());
+       // private static final Log LOG = 
LogFactory.getLog(FrameLibDetectSchema.class.getName());
 
        private FrameLibDetectSchema() {
                // private constructor
diff --git 
a/src/test/java/org/apache/sysds/test/component/frame/FrameApplySchema.java 
b/src/test/java/org/apache/sysds/test/component/frame/FrameApplySchema.java
index a843d9f916..fb0262daa3 100644
--- a/src/test/java/org/apache/sysds/test/component/frame/FrameApplySchema.java
+++ b/src/test/java/org/apache/sysds/test/component/frame/FrameApplySchema.java
@@ -25,37 +25,142 @@ import static org.junit.Assert.fail;
 import java.util.Random;
 
 import org.apache.sysds.common.Types.ValueType;
+import org.apache.sysds.runtime.DMLRuntimeException;
 import org.apache.sysds.runtime.frame.data.FrameBlock;
-import org.apache.sysds.runtime.frame.data.columns.BooleanArray;
+import org.apache.sysds.runtime.frame.data.lib.FrameLibApplySchema;
 import org.junit.Test;
 
 public class FrameApplySchema {
 
+       @Test
+       public void testApplySchemaStringToBoolean() {
+               try {
+
+                       FrameBlock fb = genStringContainingBoolean(10, 2);
+                       ValueType[] schema = new ValueType[] 
{ValueType.BOOLEAN, ValueType.BOOLEAN};
+                       FrameBlock ret = fb.applySchema(schema);
+                       assertTrue(ret.getColumn(0).getValueType() == 
ValueType.BOOLEAN);
+                       assertTrue(ret.getColumn(1).getValueType() == 
ValueType.BOOLEAN);
+               }
+               catch(Exception e) {
+                       e.printStackTrace();
+                       fail(e.getMessage());
+               }
+       }
 
        @Test
-       public void testApplySchema(){
-               try{
+       public void testApplySchemaStringToInt() {
+               try {
+                       FrameBlock fb = genStringContainingInteger(10, 2);
+                       ValueType[] schema = new ValueType[] {ValueType.INT32, 
ValueType.INT32};
+                       FrameBlock ret = fb.applySchema(schema);
+                       assertTrue(ret.getColumn(0).getValueType() == 
ValueType.INT32);
+                       assertTrue(ret.getColumn(1).getValueType() == 
ValueType.INT32);
+               }
+               catch(Exception e) {
+                       e.printStackTrace();
+                       fail(e.getMessage());
+               }
+       }
 
-                       FrameBlock fb = genBoolean(10, 2);
-                       ValueType[] schema = new 
ValueType[]{ValueType.BOOLEAN,ValueType.BOOLEAN};
+       @Test
+       public void testApplySchemaStringToIntSingleCol() {
+               try {
+                       FrameBlock fb = genStringContainingInteger(10, 1);
+                       ValueType[] schema = new ValueType[] {ValueType.INT32};
                        FrameBlock ret = fb.applySchema(schema);
-                       assertTrue(ret.getColumn(0) instanceof BooleanArray);
-                       assertTrue(ret.getColumn(1) instanceof BooleanArray);
+                       assertTrue(ret.getColumn(0).getValueType() == 
ValueType.INT32);
+               }
+               catch(Exception e) {
+                       e.printStackTrace();
+                       fail(e.getMessage());
+               }
+       }
+
+       @Test
+       public void testApplySchemaStringToIntDirectCallSingleThread() {
+               try {
+                       FrameBlock fb = genStringContainingInteger(10, 3);
+                       ValueType[] schema = new ValueType[] {ValueType.INT32, 
ValueType.INT32, ValueType.INT32};
+                       FrameBlock ret = FrameLibApplySchema.applySchema(fb, 
schema, 1);
+                       for(int i = 0; i < ret.getNumColumns(); i++)
+                               assertTrue(ret.getColumn(i).getValueType() == 
ValueType.INT32);
+
+               }
+               catch(Exception e) {
+                       e.printStackTrace();
+                       fail(e.getMessage());
+               }
+       }
+
+       @Test
+       public void testApplySchemaStringToIntDirectCallMultiThread() {
+               try {
+                       FrameBlock fb = genStringContainingInteger(10, 3);
+                       ValueType[] schema = new ValueType[] {ValueType.INT32, 
ValueType.INT32, ValueType.INT32};
+                       FrameBlock ret = FrameLibApplySchema.applySchema(fb, 
schema, 3);
+                       for(int i = 0; i < ret.getNumColumns(); i++)
+                               assertTrue(ret.getColumn(i).getValueType() == 
ValueType.INT32);
+
+               }
+               catch(Exception e) {
+                       e.printStackTrace();
+                       fail(e.getMessage());
                }
-               catch(Exception e){
+       }
+
+
+       @Test
+       public void testApplySchemaStringToIntDirectCallMultiThreadSingleCol() {
+               try {
+                       FrameBlock fb = genStringContainingInteger(10, 1);
+                       ValueType[] schema = new ValueType[] {ValueType.INT32};
+                       FrameBlock ret = FrameLibApplySchema.applySchema(fb, 
schema, 3);
+                       for(int i = 0; i < ret.getNumColumns(); i++)
+                               assertTrue(ret.getColumn(i).getValueType() == 
ValueType.INT32);
+
+               }
+               catch(Exception e) {
                        e.printStackTrace();
                        fail(e.getMessage());
                }
        }
 
-       private FrameBlock genBoolean(int row, int col){
+       @Test(expected = DMLRuntimeException.class)
+       public void testInvalidInput() {
+               FrameBlock fb = genStringContainingInteger(10, 10);
+               ValueType[] schema = new ValueType[] {ValueType.INT32, 
ValueType.INT32, ValueType.INT32};
+               FrameLibApplySchema.applySchema(fb, schema, 3);
+       }
+
+       @Test(expected = DMLRuntimeException.class)
+       public void testInvalidInput2() {
+               FrameBlock fb = genStringContainingInteger(10, 3);
+               ValueType[] schema = new ValueType[] {ValueType.UNKNOWN, 
ValueType.INT32, ValueType.INT32};
+               FrameLibApplySchema.applySchema(fb, schema, 3);
+       }
+
+       private FrameBlock genStringContainingInteger(int row, int col) {
                FrameBlock ret = new FrameBlock();
                Random r = new Random(31);
-               for(int c = 0; c < col; c ++){
+               for(int c = 0; c < col; c++) {
                        String[] column = new String[row];
-                       for(int i = 0; i < row; i ++)
+                       for(int i = 0; i < row; i++)
+                               column[i] = "" + r.nextInt();
+
+                       ret.appendColumn(column);
+               }
+               return ret;
+       }
+
+       private FrameBlock genStringContainingBoolean(int row, int col) {
+               FrameBlock ret = new FrameBlock();
+               Random r = new Random(31);
+               for(int c = 0; c < col; c++) {
+                       String[] column = new String[row];
+                       for(int i = 0; i < row; i++)
                                column[i] = "" + r.nextBoolean();
-                       
+
                        ret.appendColumn(column);
                }
                return ret;

Reply via email to