This is an automated email from the ASF dual-hosted git repository.

mboehm7 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/systemds.git


The following commit(s) were added to refs/heads/master by this push:
     new c29ac3e  [SYSTEMDS-2509] Fix missing binning support in spark 
transformencode
c29ac3e is described below

commit c29ac3eeb6bc9590bdc89b77bf2616ee9ca133c8
Author: Matthias Boehm <[email protected]>
AuthorDate: Fri Jan 29 19:34:56 2021 +0100

    [SYSTEMDS-2509] Fix missing binning support in spark transformencode
    
    So far, the distributed spark transformencode only supported recoding,
    dummy coding, mv imputation, and NaN omission but lacked proper support
    for binning, which - similar to recoding - required dedicated meta data
    handling for bin boundaries. This patch adds the missing support while
    preserving the general structure of the spark transformencode
    instruction.
    
    Furthermore, this patch also cleans up the incorrectly updated scripts
    for federated lmPredict.
---
 docs/site/builtins-reference.md                    |   4 +-
 ...ltiReturnParameterizedBuiltinSPInstruction.java | 121 ++++++++++++++-------
 .../sysds/runtime/transform/encode/Encoder.java    |  16 +++
 .../sysds/runtime/transform/encode/EncoderBin.java |  72 ++++++++++--
 .../runtime/transform/encode/EncoderComposite.java |  31 +++++-
 .../runtime/transform/encode/EncoderRecode.java    |   6 +-
 .../transform/TransformFrameEncodeApplyTest.java   |  32 +++---
 .../functions/federated/FederatedLmPipeline.dml    |   2 +-
 .../federated/FederatedLmPipeline4Workers.dml      |   2 +-
 .../FederatedLmPipeline4WorkersReference.dml       |   2 +-
 .../federated/FederatedLmPipelineReference.dml     |   2 +-
 11 files changed, 218 insertions(+), 72 deletions(-)

diff --git a/docs/site/builtins-reference.md b/docs/site/builtins-reference.md
index 8263c10..a84ed96 100644
--- a/docs/site/builtins-reference.md
+++ b/docs/site/builtins-reference.md
@@ -1,6 +1,6 @@
 ---
 layout: site
-title: Buildin Reference
+title: Built-in Reference
 ---
 <!--
 {% comment %}
@@ -70,7 +70,7 @@ limitations under the License.
 The DML (Declarative Machine Learning) language has built-in functions which 
enable access to both low- and high-level functions
 to support all kinds of use cases.
 
-A builtin ir either implemented on a compiler level or as DML scripts that are 
loaded at compile time.
+A builtin is either implemented on a compiler level or as DML scripts that are 
loaded at compile time.
 
 # Built-In Construction Functions
 
diff --git 
a/src/main/java/org/apache/sysds/runtime/instructions/spark/MultiReturnParameterizedBuiltinSPInstruction.java
 
b/src/main/java/org/apache/sysds/runtime/instructions/spark/MultiReturnParameterizedBuiltinSPInstruction.java
index 5dd7b7f..9dd4e6d 100644
--- 
a/src/main/java/org/apache/sysds/runtime/instructions/spark/MultiReturnParameterizedBuiltinSPInstruction.java
+++ 
b/src/main/java/org/apache/sysds/runtime/instructions/spark/MultiReturnParameterizedBuiltinSPInstruction.java
@@ -29,6 +29,7 @@ import org.apache.spark.util.AccumulatorV2;
 import org.apache.sysds.common.Types.DataType;
 import org.apache.sysds.common.Types.FileFormat;
 import org.apache.sysds.common.Types.ValueType;
+import org.apache.sysds.lops.Lop;
 import org.apache.sysds.runtime.DMLRuntimeException;
 import org.apache.sysds.runtime.controlprogram.caching.FrameObject;
 import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
@@ -50,6 +51,7 @@ import org.apache.sysds.runtime.matrix.data.MatrixIndexes;
 import org.apache.sysds.runtime.matrix.operators.Operator;
 import org.apache.sysds.runtime.meta.DataCharacteristics;
 import org.apache.sysds.runtime.transform.encode.Encoder;
+import org.apache.sysds.runtime.transform.encode.EncoderBin;
 import org.apache.sysds.runtime.transform.encode.EncoderComposite;
 import org.apache.sysds.runtime.transform.encode.EncoderFactory;
 import org.apache.sysds.runtime.transform.encode.EncoderMVImpute;
@@ -57,10 +59,12 @@ import 
org.apache.sysds.runtime.transform.encode.EncoderMVImpute.MVMethod;
 import org.apache.sysds.runtime.transform.encode.EncoderRecode;
 import org.apache.sysds.runtime.transform.meta.TfMetaUtils;
 import org.apache.sysds.runtime.transform.meta.TfOffsetMap;
+
 import scala.Tuple2;
 
 import java.io.IOException;
 import java.util.ArrayList;
+import java.util.Arrays;
 import java.util.HashMap;
 import java.util.HashSet;
 import java.util.Iterator;
@@ -113,14 +117,14 @@ public class MultiReturnParameterizedBuiltinSPInstruction 
extends ComputationSPI
                                in.lookup(1L).get(0).getColumnNames() : null; 
                        
                        //step 1: build transform meta data
-                       Encoder encoderBuild = 
EncoderFactory.createEncoder(spec, colnames,
-                               fo.getSchema(), (int)fo.getNumColumns(), null);
+                       EncoderComposite encoderBuild = (EncoderComposite) 
EncoderFactory
+                               .createEncoder(spec, colnames, fo.getSchema(), 
(int)fo.getNumColumns(), null);
                        
-                       MaxLongAccumulator accMax = 
registerMaxLongAccumulator(sec.getSparkContext()); 
+                       MaxLongAccumulator accMax = 
registerMaxLongAccumulator(sec.getSparkContext());
                        JavaRDD<String> rcMaps = in
                                .mapPartitionsToPair(new 
TransformEncodeBuildFunction(encoderBuild))
                                .distinct().groupByKey()
-                               .flatMap(new 
TransformEncodeGroupFunction(accMax));
+                               .flatMap(new 
TransformEncodeGroupFunction(encoderBuild, accMax));
                        if( containsMVImputeEncoder(encoderBuild) ) {
                                EncoderMVImpute mva = 
getMVImputeEncoder(encoderBuild);
                                rcMaps = rcMaps.union(
@@ -241,34 +245,47 @@ public class MultiReturnParameterizedBuiltinSPInstruction 
extends ComputationSPI
        {
                private static final long serialVersionUID = 
6336375833412029279L;
 
-               private EncoderRecode _raEncoder = null;
+               private EncoderComposite _encoder = null;
                
-               public TransformEncodeBuildFunction(Encoder encoder) {
-                       for( Encoder cEncoder : 
((EncoderComposite)encoder).getEncoders() )
-                               if( cEncoder instanceof EncoderRecode )
-                                       _raEncoder = (EncoderRecode)cEncoder;
+               public TransformEncodeBuildFunction(EncoderComposite encoder) {
+                       _encoder = encoder;
                }
                
                @Override
                public Iterator<Tuple2<Integer, Object>> 
call(Iterator<Tuple2<Long, FrameBlock>> iter)
                        throws Exception 
                {
-                       //build meta data (e.g., recode maps)
-                       if( _raEncoder != null ) {
-                               _raEncoder.prepareBuildPartial();
-                               while( iter.hasNext() )
-                                       
_raEncoder.buildPartial(iter.next()._2());
-                       }
+                       //build meta data (e.g., recoding recode maps and 
binning min/max)
+                       _encoder.prepareBuildPartial();
+                       while( iter.hasNext() )
+                               _encoder.buildPartial(iter.next()._2());
                        
-                       //output recode maps as columnID - token pairs
+                       //encoder-specific outputs
+                       EncoderRecode raEncoder = 
(EncoderRecode)_encoder.getEncoder(EncoderRecode.class);
+                       EncoderBin baEncoder = 
(EncoderBin)_encoder.getEncoder(EncoderBin.class);
                        ArrayList<Tuple2<Integer,Object>> ret = new 
ArrayList<>();
-                       HashMap<Integer,HashSet<Object>> tmp = 
_raEncoder.getCPRecodeMapsPartial();
-                       for( Entry<Integer,HashSet<Object>> e1 : tmp.entrySet() 
)
-                               for( Object token : e1.getValue() )
-                                       ret.add(new Tuple2<>(e1.getKey(), 
token));
-                       if( _raEncoder != null )
-                               _raEncoder.getCPRecodeMapsPartial().clear();
-               
+                       
+                       //output recode maps as columnID - token pairs
+                       if( raEncoder != null ) {
+                               HashMap<Integer,HashSet<Object>> tmp = 
raEncoder.getCPRecodeMapsPartial();
+                               for( Entry<Integer,HashSet<Object>> e1 : 
tmp.entrySet() )
+                                       for( Object token : e1.getValue() )
+                                               ret.add(new 
Tuple2<>(e1.getKey(), token));
+                               if( raEncoder != null )
+                                       
raEncoder.getCPRecodeMapsPartial().clear();
+                       }
+                       
+                       //output binning column min/max as columnID - min/max 
pairs
+                       if( baEncoder != null ) {
+                               int[] colIDs = baEncoder.getColList();
+                               double[] colMins = baEncoder.getColMins();
+                               double[] colMaxs = baEncoder.getColMaxs();
+                               for(int j=0; j<colIDs.length; j++) {
+                                       ret.add(new Tuple2<>(colIDs[j], 
String.valueOf(colMins[j])));
+                                       ret.add(new Tuple2<>(colIDs[j], 
String.valueOf(colMaxs[j])));
+                               }
+                       }
+                       
                        return ret.iterator();
                }
        }
@@ -285,9 +302,11 @@ public class MultiReturnParameterizedBuiltinSPInstruction 
extends ComputationSPI
        {
                private static final long serialVersionUID = 
-1034187226023517119L;
 
-               private MaxLongAccumulator _accMax = null;
+               private final EncoderComposite _encoder;
+               private final MaxLongAccumulator _accMax;
                
-               public TransformEncodeGroupFunction( MaxLongAccumulator accMax 
) {
+               public TransformEncodeGroupFunction(EncoderComposite encoder, 
MaxLongAccumulator accMax) {
+                       _encoder = encoder;
                        _accMax = accMax;
                }
                
@@ -295,22 +314,48 @@ public class MultiReturnParameterizedBuiltinSPInstruction 
extends ComputationSPI
                public Iterator<String> call(Tuple2<Integer, Iterable<Object>> 
arg0)
                        throws Exception 
                {
-                       String colID = String.valueOf(arg0._1());
+                       String scolID = String.valueOf(arg0._1());
+                       int colID = Integer.parseInt(scolID);
                        Iterator<Object> iter = arg0._2().iterator();
-                       
                        ArrayList<String> ret = new ArrayList<>();
-                       StringBuilder sb = new StringBuilder();
+                       
                        long rowID = 1;
-                       while( iter.hasNext() ) {
-                               sb.append(rowID);
-                               sb.append(' ');
-                               sb.append(colID);
-                               sb.append(' ');
-                               sb.append(EncoderRecode.constructRecodeMapEntry(
-                                               iter.next().toString(), rowID));
-                               ret.add(sb.toString());
-                               sb.setLength(0); 
-                               rowID++;
+                       StringBuilder sb = new StringBuilder();
+                       
+                       //handle recode maps
+                       if( _encoder.isEncoder(colID, EncoderRecode.class) ) {
+                               while( iter.hasNext() ) {
+                                       sb.append(rowID).append(' 
').append(scolID).append(' ');
+                                       
sb.append(EncoderRecode.constructRecodeMapEntry(iter.next().toString(), rowID));
+                                       ret.add(sb.toString());
+                                       sb.setLength(0); 
+                                       rowID++;
+                               }
+                       }
+                       //handle bin boundaries
+                       else if( _encoder.isEncoder(colID, EncoderBin.class) ) {
+                               EncoderBin baEncoder = 
(EncoderBin)_encoder.getEncoder(EncoderBin.class);
+                               double min = Double.MAX_VALUE;
+                               double max = -Double.MAX_VALUE;
+                               while( iter.hasNext() ) {
+                                       double value = 
Double.parseDouble(iter.next().toString());
+                                       min = Math.min(min, value);
+                                       max = Math.max(max, value);
+                               }
+                               int j = 
Arrays.binarySearch(baEncoder.getColList(), colID);
+                               baEncoder.computeBins(j, min, max);
+                               double[] binMins = baEncoder.getBinMins(j);
+                               double[] binMaxs = baEncoder.getBinMaxs(j);
+                               for(int i=0; i<binMins.length; i++) {
+                                       sb.append(rowID).append(' 
').append(scolID).append(' ');
+                                       
sb.append(binMins[i]).append(Lop.DATATYPE_PREFIX).append(binMaxs[i]);
+                                       ret.add(sb.toString());
+                                       sb.setLength(0);
+                                       rowID++;
+                               }
+                       }
+                       else {
+                               throw new DMLRuntimeException("Unsupported 
metadata output for encoder: \n"+_encoder);
                        }
                        _accMax.add(rowID-1);
                        
diff --git 
a/src/main/java/org/apache/sysds/runtime/transform/encode/Encoder.java 
b/src/main/java/org/apache/sysds/runtime/transform/encode/Encoder.java
index 0758620..784e4d6 100644
--- a/src/main/java/org/apache/sysds/runtime/transform/encode/Encoder.java
+++ b/src/main/java/org/apache/sysds/runtime/transform/encode/Encoder.java
@@ -119,6 +119,22 @@ public abstract class Encoder implements Serializable
        public abstract void build(FrameBlock in);
        
        /**
+        * Allocates internal data structures for partial build.
+        */
+       public void prepareBuildPartial() {
+               //do nothing
+       }
+       
+       /**
+        * Partial build of internal data structures (e.g., in distributed 
spark operations).
+        * 
+        * @param in input frame block
+        */
+       public void buildPartial(FrameBlock in) {
+               //do nothing
+       }
+       
+       /**
         * Encode input data blockwise according to existing transform meta
         * data (transform apply).
         * 
diff --git 
a/src/main/java/org/apache/sysds/runtime/transform/encode/EncoderBin.java 
b/src/main/java/org/apache/sysds/runtime/transform/encode/EncoderBin.java
index 4caee9b..cbbeb67 100644
--- a/src/main/java/org/apache/sysds/runtime/transform/encode/EncoderBin.java
+++ b/src/main/java/org/apache/sysds/runtime/transform/encode/EncoderBin.java
@@ -50,10 +50,14 @@ public class EncoderBin extends Encoder
        protected int[] _numBins = null;
        
        //frame transform-apply attributes
+       // a) column bin boundaries
        //TODO binMins is redundant and could be removed
        private double[][] _binMins = null;
        private double[][] _binMaxs = null;
-
+       // b) column min/max (for partial build)
+       private double[] _colMins = null;
+       private double[] _colMaxs = null;
+       
        public EncoderBin(JSONObject parsedSpec, String[] colnames, int clen, 
int minCol, int maxCol)
                throws JSONException, IOException 
        {
@@ -91,6 +95,22 @@ public class EncoderBin extends Encoder
                _binMaxs = binMaxs;
        }
        
+       public double[] getColMins() {
+               return _colMins;
+       }
+       
+       public double[] getColMaxs() {
+               return _colMaxs;
+       }
+       
+       public double[] getBinMins(int j) {
+               return _binMins[j];
+       }
+       
+       public double[] getBinMaxs(int j) {
+               return _binMaxs[j];
+       }
+       
        @Override
        public MatrixBlock encode(FrameBlock in, MatrixBlock out) {
                build(in);
@@ -101,9 +121,6 @@ public class EncoderBin extends Encoder
        public void build(FrameBlock in) {
                if ( !isApplicable() )
                        return;
-               // initialize internal transformation metadata
-               _binMins = new double[_colList.length][];
-               _binMaxs = new double[_colList.length][];
                
                // derive bin boundaries from min/max per column
                for(int j=0; j <_colList.length; j++) {
@@ -116,12 +133,49 @@ public class EncoderBin extends Encoder
                                min = Math.min(min, inVal);
                                max = Math.max(max, inVal);
                        }
-                       _binMins[j] = new double[_numBins[j]];
-                       _binMaxs[j] = new double[_numBins[j]];
-                       for(int i=0; i<_numBins[j]; i++) {
-                               _binMins[j][i] = min + i*(max-min)/_numBins[j];
-                               _binMaxs[j][i] = min + 
(i+1)*(max-min)/_numBins[j];
+                       computeBins(j, min, max);
+               }
+       }
+       
+       public void computeBins(int j, double min, double max) {
+               // ensure allocated internal transformation metadata
+               if( _binMins == null || _binMaxs == null ) {
+                       _binMins = new double[_colList.length][];
+                       _binMaxs = new double[_colList.length][];
+               }
+               _binMins[j] = new double[_numBins[j]];
+               _binMaxs[j] = new double[_numBins[j]];
+               for(int i=0; i<_numBins[j]; i++) {
+                       _binMins[j][i] = min + i*(max-min)/_numBins[j];
+                       _binMaxs[j][i] = min + (i+1)*(max-min)/_numBins[j];
+               }
+       }
+       
+       public void prepareBuildPartial() {
+               //ensure allocated min/max arrays
+               if( _colMins == null ) {
+                       _colMins = new double[_colList.length];
+                       _colMaxs = new double[_colList.length];
+               }
+       }
+       
+       public void buildPartial(FrameBlock in) {
+               if ( !isApplicable() )
+                       return;
+               
+               // derive bin boundaries from min/max per column
+               for(int j=0; j <_colList.length; j++) {
+                       double min = Double.POSITIVE_INFINITY;
+                       double max = Double.NEGATIVE_INFINITY;
+                       int colID = _colList[j];
+                       for( int i=0; i<in.getNumRows(); i++ ) {
+                               double inVal = UtilFunctions.objectToDouble(
+                                       in.getSchema()[colID-1], in.get(i, 
colID-1));
+                               min = Math.min(min, inVal);
+                               max = Math.max(max, inVal);
                        }
+                       _colMins[j] = min;
+                       _colMaxs[j] = max;
                }
        }
        
diff --git 
a/src/main/java/org/apache/sysds/runtime/transform/encode/EncoderComposite.java 
b/src/main/java/org/apache/sysds/runtime/transform/encode/EncoderComposite.java
index cc59932..f4923d6 100644
--- 
a/src/main/java/org/apache/sysds/runtime/transform/encode/EncoderComposite.java
+++ 
b/src/main/java/org/apache/sysds/runtime/transform/encode/EncoderComposite.java
@@ -23,6 +23,7 @@ import java.util.ArrayList;
 import java.util.Arrays;
 import java.util.List;
 
+import org.apache.commons.lang3.ArrayUtils;
 import org.apache.sysds.common.Types.ValueType;
 import org.apache.sysds.runtime.DMLRuntimeException;
 import org.apache.sysds.runtime.matrix.data.FrameBlock;
@@ -59,6 +60,22 @@ public class EncoderComposite extends Encoder
                return _encoders;
        }
        
+       public Encoder getEncoder(Class<?> type) {
+               for( Encoder encoder : _encoders ) {
+                       if( encoder.getClass().equals(type) )
+                               return encoder;
+               }
+               return null;
+       }
+       
+       public boolean isEncoder(int colID, Class<?> type) {
+               for( Encoder encoder : _encoders ) {
+                       if( encoder.getClass().equals(type) )
+                               return 
ArrayUtils.contains(encoder.getColList(), colID);
+               }
+               return false;
+       }
+       
        @Override
        public MatrixBlock encode(FrameBlock in, MatrixBlock out) {
                try {
@@ -91,7 +108,19 @@ public class EncoderComposite extends Encoder
                        encoder.build(in);
        }
        
-       @Override 
+       @Override
+       public void prepareBuildPartial() {
+               for( Encoder encoder : _encoders )
+                       encoder.prepareBuildPartial();
+       }
+       
+       @Override
+       public void buildPartial(FrameBlock in) {
+               for( Encoder encoder : _encoders )
+                       encoder.buildPartial(in);
+       }
+       
+       @Override
        public MatrixBlock apply(FrameBlock in, MatrixBlock out) {
                try {
                        for( Encoder encoder : _encoders )
diff --git 
a/src/main/java/org/apache/sysds/runtime/transform/encode/EncoderRecode.java 
b/src/main/java/org/apache/sysds/runtime/transform/encode/EncoderRecode.java
index 1dc7bf2..ada0dff 100644
--- a/src/main/java/org/apache/sysds/runtime/transform/encode/EncoderRecode.java
+++ b/src/main/java/org/apache/sysds/runtime/transform/encode/EncoderRecode.java
@@ -37,7 +37,7 @@ import org.apache.sysds.runtime.matrix.data.MatrixBlock;
 import org.apache.sysds.runtime.transform.TfUtils.TfMethod;
 import org.apache.sysds.runtime.transform.meta.TfMetaUtils;
 
-public class EncoderRecode extends Encoder 
+public class EncoderRecode extends Encoder
 {
        private static final long serialVersionUID = 8213163881283341874L;
        
@@ -139,13 +139,15 @@ public class EncoderRecode extends Encoder
        protected void putCode(HashMap<String,Long> map, String key) {
                map.put(key, Long.valueOf(map.size()+1));
        }
-       
+
+       @Override
        public void prepareBuildPartial() {
                //ensure allocated partial recode map
                if( _rcdMapsPart == null )
                        _rcdMapsPart = new HashMap<>();
        }
 
+       @Override
        public void buildPartial(FrameBlock in) {
                if( !isApplicable() )
                        return;
diff --git 
a/src/test/java/org/apache/sysds/test/functions/transform/TransformFrameEncodeApplyTest.java
 
b/src/test/java/org/apache/sysds/test/functions/transform/TransformFrameEncodeApplyTest.java
index 92b0b2d..18f5fd7 100644
--- 
a/src/test/java/org/apache/sysds/test/functions/transform/TransformFrameEncodeApplyTest.java
+++ 
b/src/test/java/org/apache/sysds/test/functions/transform/TransformFrameEncodeApplyTest.java
@@ -133,10 +133,10 @@ public class TransformFrameEncodeApplyTest extends 
AutomatedTestBase
                runTransformTest(ExecMode.SINGLE_NODE, "csv", 
TransformType.BIN, false);
        }
        
-//     @Test
-//     public void testHomesBinningIDsSparkCSV() {
-//             runTransformTest(ExecMode.SPARK, "csv", TransformType.BIN, 
false);
-//     }
+       @Test
+       public void testHomesBinningIDsSparkCSV() {
+               runTransformTest(ExecMode.SPARK, "csv", TransformType.BIN, 
false);
+       }
        
        @Test
        public void testHomesBinningIDsHybridCSV() {
@@ -148,10 +148,10 @@ public class TransformFrameEncodeApplyTest extends 
AutomatedTestBase
                runTransformTest(ExecMode.SINGLE_NODE, "csv", 
TransformType.BIN_DUMMY, false);
        }
 
-//     @Test
-//     public void testHomesBinningDummyIDsSparkCSV() {
-//             runTransformTest(ExecMode.SPARK, "csv", 
TransformType.BIN_DUMMY, false);
-//     }
+       @Test
+       public void testHomesBinningDummyIDsSparkCSV() {
+               runTransformTest(ExecMode.SPARK, "csv", 
TransformType.BIN_DUMMY, false);
+       }
        
        @Test
        public void testHomesBinningDummyIDsHybridCSV() {
@@ -238,10 +238,10 @@ public class TransformFrameEncodeApplyTest extends 
AutomatedTestBase
                runTransformTest(ExecMode.SINGLE_NODE, "csv", 
TransformType.BIN, true);
        }
        
-//     @Test
-//     public void testHomesBinningColnamesSparkCSV() {
-//             runTransformTest(ExecMode.SPARK, "csv", TransformType.BIN, 
true);
-//     }
+       @Test
+       public void testHomesBinningColnamesSparkCSV() {
+               runTransformTest(ExecMode.SPARK, "csv", TransformType.BIN, 
true);
+       }
        
        @Test
        public void testHomesBinningColnamesHybridCSV() {
@@ -253,10 +253,10 @@ public class TransformFrameEncodeApplyTest extends 
AutomatedTestBase
                runTransformTest(ExecMode.SINGLE_NODE, "csv", 
TransformType.BIN_DUMMY, true);
        }
        
-//     @Test
-//     public void testHomesBinningDummyColnamesSparkCSV() {
-//             runTransformTest(ExecMode.SPARK, "csv", 
TransformType.BIN_DUMMY, true);
-//     }
+       @Test
+       public void testHomesBinningDummyColnamesSparkCSV() {
+               runTransformTest(ExecMode.SPARK, "csv", 
TransformType.BIN_DUMMY, true);
+       }
        
        @Test
        public void testHomesBinningDummyColnamesHybridCSV() {
diff --git a/src/test/scripts/functions/federated/FederatedLmPipeline.dml 
b/src/test/scripts/functions/federated/FederatedLmPipeline.dml
index fdad81c..957e630 100644
--- a/src/test/scripts/functions/federated/FederatedLmPipeline.dml
+++ b/src/test/scripts/functions/federated/FederatedLmPipeline.dml
@@ -47,7 +47,7 @@ X = scale(X=X, center=TRUE, scale=TRUE);
 B = lm(X=Xtrain, y=ytrain, icpt=1, reg=1e-3, tol=1e-9, verbose=TRUE)
 
 # model evaluation on test split
-yhat = lmPredict(X=Xtest, B=B, icpt=1, ytest=ytest verbose=TRUE);
+yhat = lmPredict(X=Xtest, B=B, icpt=1, ytest=ytest, verbose=TRUE);
 
 # write trained model and meta data
 write(B, $out)
diff --git 
a/src/test/scripts/functions/federated/FederatedLmPipeline4Workers.dml 
b/src/test/scripts/functions/federated/FederatedLmPipeline4Workers.dml
index dce7015..0dfdefb 100644
--- a/src/test/scripts/functions/federated/FederatedLmPipeline4Workers.dml
+++ b/src/test/scripts/functions/federated/FederatedLmPipeline4Workers.dml
@@ -49,7 +49,7 @@ X = scale(X=X, center=TRUE, scale=TRUE);
 B = lm(X=Xtrain, y=ytrain, icpt=1, reg=1e-3, tol=1e-9, verbose=TRUE)
 
 # model evaluation on test split
-yhat = lmPredict(X=Xtest, B=B, icpt=1, ytest=ytest verbose=TRUE);
+yhat = lmPredict(X=Xtest, B=B, icpt=1, ytest=ytest, verbose=TRUE);
 
 # write trained model and meta data
 write(B, $out)
diff --git 
a/src/test/scripts/functions/federated/FederatedLmPipeline4WorkersReference.dml 
b/src/test/scripts/functions/federated/FederatedLmPipeline4WorkersReference.dml
index 318f441..62e1642 100644
--- 
a/src/test/scripts/functions/federated/FederatedLmPipeline4WorkersReference.dml
+++ 
b/src/test/scripts/functions/federated/FederatedLmPipeline4WorkersReference.dml
@@ -47,7 +47,7 @@ X = scale(X=X, center=TRUE, scale=TRUE);
 B = lm(X=Xtrain, y=ytrain, icpt=1, reg=1e-3, tol=1e-9, verbose=TRUE)
 
 # model evaluation on test split
-yhat = lmPredict(X=Xtest, B=B, icpt=1, ytest=ytest verbose=TRUE);
+yhat = lmPredict(X=Xtest, B=B, icpt=1, ytest=ytest, verbose=TRUE);
 
 # write trained model and meta data
 write(B, $7)
diff --git 
a/src/test/scripts/functions/federated/FederatedLmPipelineReference.dml 
b/src/test/scripts/functions/federated/FederatedLmPipelineReference.dml
index 1fe5c21..21ee463 100644
--- a/src/test/scripts/functions/federated/FederatedLmPipelineReference.dml
+++ b/src/test/scripts/functions/federated/FederatedLmPipelineReference.dml
@@ -47,7 +47,7 @@ X = scale(X=X, center=TRUE, scale=TRUE);
 B = lm(X=Xtrain, y=ytrain, icpt=1, reg=1e-3, tol=1e-9, verbose=TRUE)
 
 # model evaluation on test split
-yhat = lmPredict(X=Xtest, B=B, icpt=1, ytest=ytest verbose=TRUE);
+yhat = lmPredict(X=Xtest, B=B, icpt=1, ytest=ytest, verbose=TRUE);
 
 # write trained model and meta data
 write(B, $7)

Reply via email to