This is an automated email from the ASF dual-hosted git repository.

mboehm7 pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/systemds.git


The following commit(s) were added to refs/heads/main by this push:
     new 85fa35312c [MINOR] Fix warnings, data types, formatting of the federated backend
85fa35312c is described below

commit 85fa35312c3e536024d9a14738d10cb181e343c0
Author: Matthias Boehm <[email protected]>
AuthorDate: Sat Jun 4 19:44:13 2022 +0200

    [MINOR] Fix warnings, data types, formatting of the federated backend
---
 src/main/java/org/apache/sysds/common/Types.java   |   4 +-
 src/main/java/org/apache/sysds/hops/Hop.java       |   2 +-
 .../RewriteElementwiseMultChainOptimization.java   |   2 +-
 .../federated/FederatedLocalData.java              |   1 -
 .../paramserv/FederatedPSControlThread.java        |   1 -
 .../controlprogram/paramserv/HEParamServer.java    | 298 ++++++++++-----------
 .../paramserv/NetworkTrafficCounter.java           |  27 +-
 .../runtime/instructions/cp/PlaintextMatrix.java   |  16 +-
 .../sysds/utils/stats/ParamServStatistics.java     |   1 -
 .../paramserv/EncryptedFederatedParamservTest.java |   7 +-
 .../fedplanning/FederatedMultiplyPlanningTest.java |   1 -
 11 files changed, 179 insertions(+), 181 deletions(-)

diff --git a/src/main/java/org/apache/sysds/common/Types.java 
b/src/main/java/org/apache/sysds/common/Types.java
index 3dfad3413e..a7cfa823aa 100644
--- a/src/main/java/org/apache/sysds/common/Types.java
+++ b/src/main/java/org/apache/sysds/common/Types.java
@@ -44,7 +44,9 @@ public class Types
         * Data types (tensor, matrix, scalar, frame, object, unknown).
         */
        public enum DataType {
-               TENSOR, MATRIX, SCALAR, FRAME, LIST, ENCRYPTED_CIPHER, 
ENCRYPTED_PLAIN, UNKNOWN;
+               TENSOR, MATRIX, SCALAR, FRAME, LIST, UNKNOWN,
+               //TODO remove from Data Type -> generic object
+               ENCRYPTED_CIPHER, ENCRYPTED_PLAIN;
                
                public boolean isMatrix() {
                        return this == MATRIX;
diff --git a/src/main/java/org/apache/sysds/hops/Hop.java 
b/src/main/java/org/apache/sysds/hops/Hop.java
index 7bdb5a424e..2ee317f35e 100644
--- a/src/main/java/org/apache/sysds/hops/Hop.java
+++ b/src/main/java/org/apache/sysds/hops/Hop.java
@@ -813,7 +813,7 @@ public abstract class Hop implements ParseInfo {
                                
                                break;
                        }
-                       case UNKNOWN: {
+                       default: {
                                //memory estimate always unknown
                                _outputMemEstimate = 
OptimizerUtils.DEFAULT_SIZE;
                                break;
diff --git 
a/src/main/java/org/apache/sysds/hops/rewrite/RewriteElementwiseMultChainOptimization.java
 
b/src/main/java/org/apache/sysds/hops/rewrite/RewriteElementwiseMultChainOptimization.java
index bbb8f0a161..d2244c8a7d 100644
--- 
a/src/main/java/org/apache/sysds/hops/rewrite/RewriteElementwiseMultChainOptimization.java
+++ 
b/src/main/java/org/apache/sysds/hops/rewrite/RewriteElementwiseMultChainOptimization.java
@@ -262,8 +262,8 @@ public class RewriteElementwiseMultChainOptimization 
extends HopRewriteRule {
                                        case MATRIX: orderDataType[i] = 1; 
break;
                                        case TENSOR: orderDataType[i] = 2; 
break;
                                        case FRAME:  orderDataType[i] = 3; 
break;
-                                       case UNKNOWN:orderDataType[i] = 4; 
break;
                                        case LIST:   orderDataType[i] = 5; 
break;
+                                       default:     orderDataType[i] = 4; 
break;
                                }
                }
 
diff --git 
a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedLocalData.java
 
b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedLocalData.java
index 77ffb7f847..de56a1a52e 100644
--- 
a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedLocalData.java
+++ 
b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedLocalData.java
@@ -25,7 +25,6 @@ import java.util.concurrent.Future;
 import org.apache.log4j.Logger;
 import org.apache.sysds.conf.ConfigurationManager;
 import org.apache.sysds.runtime.controlprogram.caching.CacheableData;
-import org.apache.sysds.runtime.controlprogram.parfor.stat.Timing;
 import org.apache.sysds.runtime.controlprogram.parfor.util.IDHandler;
 
 public class FederatedLocalData extends FederatedData {
diff --git 
a/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/FederatedPSControlThread.java
 
b/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/FederatedPSControlThread.java
index 54d778486a..0c984698c1 100644
--- 
a/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/FederatedPSControlThread.java
+++ 
b/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/FederatedPSControlThread.java
@@ -378,7 +378,6 @@ public class FederatedPSControlThread extends PSWorker 
implements Callable<Void>
        @Override
        public Void call() throws Exception {
                try {
-                       Timing tTotal = new Timing(true);
                        switch (_freq) {
                                case BATCH:
                                        computeWithBatchUpdates();
diff --git 
a/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/HEParamServer.java
 
b/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/HEParamServer.java
index 577bf6c820..4e873abdb6 100644
--- 
a/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/HEParamServer.java
+++ 
b/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/HEParamServer.java
@@ -29,7 +29,6 @@ import 
org.apache.sysds.runtime.controlprogram.parfor.stat.Timing;
 import org.apache.sysds.runtime.instructions.cp.CiphertextMatrix;
 import org.apache.sysds.runtime.instructions.cp.ListObject;
 import org.apache.sysds.runtime.instructions.cp.PlaintextMatrix;
-import org.apache.sysds.utils.NativeHelper;
 import org.apache.sysds.utils.stats.ParamServStatistics;
 
 import java.util.ArrayList;
@@ -43,152 +42,153 @@ import java.util.stream.IntStream;
  * This class implements Homomorphic Encryption (HE) for LocalParamServer. It 
only supports modelAvg=true.
  */
 public class HEParamServer extends LocalParamServer {
-    private int _thread_counter = 0;
-    private final List<FederatedPSControlThread> _threads;
-    private final List<Object> _result_buffer; // one per thread
-    private Object _result;
-    private final SEALServer _seal_server;
-
-    public static HEParamServer create(ListObject model, String aggFunc, 
Statement.PSUpdateType updateType,
-                                          Statement.PSFrequency freq, 
ExecutionContext ec, int workerNum, String valFunc, int numBatchesPerEpoch,
-                                          MatrixObject valFeatures, 
MatrixObject valLabels, int nbatches)
-    {
-        NativeHEHelper.initialize();
-        return new HEParamServer(model, aggFunc, updateType, freq, ec,
-                workerNum, valFunc, numBatchesPerEpoch, valFeatures, 
valLabels, nbatches);
-    }
-
-    private HEParamServer(ListObject model, String aggFunc, 
Statement.PSUpdateType updateType,
-                             Statement.PSFrequency freq, ExecutionContext ec, 
int workerNum, String valFunc, int numBatchesPerEpoch,
-                             MatrixObject valFeatures, MatrixObject valLabels, 
int nbatches)
-    {
-        super(model, aggFunc, updateType, freq, ec, workerNum, valFunc, 
numBatchesPerEpoch, valFeatures, valLabels, nbatches, true);
-
-        _seal_server = new SEALServer();
-
-        _threads = Collections.synchronizedList(new ArrayList<>(workerNum));
-        for (int i = 0; i < getNumWorkers(); i++) {
-            _threads.add(null);
-        }
-
-        _result_buffer = new ArrayList<>(workerNum);
-        resetResultBuffer();
-    }
-
-    public void registerThread(int thread_id, FederatedPSControlThread thread) 
{
-        _threads.set(thread_id, thread);
-    }
-
-    private synchronized void resetResultBuffer() {
-        _result_buffer.clear();
-        for (int i = 0; i < getNumWorkers(); i++) {
-            _result_buffer.add(null);
-        }
-    }
-
-    public byte[] generateA() {
-        return _seal_server.generateA();
-    }
-
-    public PublicKey aggregatePartialPublicKeys(PublicKey[] 
partial_public_keys) {
-        return _seal_server.aggregatePartialPublicKeys(partial_public_keys);
-    }
-
-    /**
-     * this method collects all T Objects from each worker into a list and 
then calls f once on this list to produce
-     * another T, which it returns.
-     */
-    private synchronized <T,U> U collectAndDo(int workerId, T obj, 
Function<List<T>, U> f) {
-        _result_buffer.set(workerId, obj);
-        _thread_counter++;
-
-        if (_thread_counter == getNumWorkers()) {
-            List<T> buf = _result_buffer.stream().map(x -> 
(T)x).collect(Collectors.toList());
-            _result = f.apply(buf);
-            resetResultBuffer();
-            _thread_counter = 0;
-            notifyAll();
-        } else {
-            try {
-                wait();
-            } catch (InterruptedException i) {
-                throw new RuntimeException("thread interrupted");
-            }
-        }
-
-        return (U) _result;
-    }
-
-    private CiphertextMatrix[] homomorphicAggregation(List<ListObject> 
encrypted_models) {
-        Timing tAgg = DMLScript.STATISTICS ? new Timing(true) : null;
-        CiphertextMatrix[] result = new 
CiphertextMatrix[encrypted_models.get(0).getLength()];
-        IntStream.range(0, 
encrypted_models.get(0).getLength()).forEach(matrix_idx -> {
-            CiphertextMatrix[] summands = new 
CiphertextMatrix[encrypted_models.size()];
-            for (int i = 0; i < encrypted_models.size(); i++) {
-                summands[i] = (CiphertextMatrix) 
encrypted_models.get(i).getData(matrix_idx);
-            }
-            result[matrix_idx] = _seal_server.accumulateCiphertexts(summands);;
-        });
-        if (tAgg != null) {
-            ParamServStatistics.accHEAccumulation((long)tAgg.stop());
-        }
-        return result;
-    }
-
-    private Void homomorphicAverage(CiphertextMatrix[] encrypted_sums, 
List<PlaintextMatrix[]> partial_decryptions) {
-        Timing tDecrypt = DMLScript.STATISTICS ? new Timing(true) : null;
-
-        MatrixObject[] result = new 
MatrixObject[partial_decryptions.get(0).length];
-
-        IntStream.range(0, 
partial_decryptions.get(0).length).forEach(matrix_idx -> {
-            PlaintextMatrix[] partial_plaintexts = new 
PlaintextMatrix[partial_decryptions.size()];
-            for (int i = 0; i < partial_decryptions.size(); i++) {
-                partial_plaintexts[i] = partial_decryptions.get(i)[matrix_idx];
-            }
-
-            result[matrix_idx] = 
_seal_server.average(encrypted_sums[matrix_idx], partial_plaintexts);
-        });
-
-        ListObject old_model = getResult();
-        ListObject new_model = new ListObject(old_model);
-        for (int i = 0; i < new_model.getLength(); i++) {
-            new_model.set(i, result[i]);
-        }
-
-        if (tDecrypt != null) {
-            ParamServStatistics.accHEDecryptionTime((long)tDecrypt.stop());
-        }
-
-        updateAndBroadcastModel(new_model, null);
-        return null;
-    }
-
-    // this is only to be used in push()
-    private Timing commTimer;
-    private void startCommTimer() {
-        commTimer = new Timing(true);
-    }
-    private long stopCommTimer() {
-        return (long)commTimer.stop();
-    }
-    // ---------------------------------
-
-    @Override
-    public void push(int workerID, ListObject encrypted_model) {
-        // wait for all updates and sum them homomorphically
-        CiphertextMatrix[] homomorphic_sum = collectAndDo(workerID, 
encrypted_model, x -> {
-            CiphertextMatrix[] res = this.homomorphicAggregation(x);
-            this.startCommTimer();
-            return res;
-        });
-
-        // get partial decryptions
-        PlaintextMatrix[] partial_decryption = 
_threads.get(workerID).getPartialDecryption(homomorphic_sum);
-
-        // do average and update global model
-        collectAndDo(workerID, partial_decryption, x -> {
-            ParamServStatistics.accFedNetworkTime(this.stopCommTimer());
-            return this.homomorphicAverage(homomorphic_sum, x);
-        });
-    }
+       private int _thread_counter = 0;
+       private final List<FederatedPSControlThread> _threads;
+       private final List<Object> _result_buffer; // one per thread
+       private Object _result;
+       private final SEALServer _seal_server;
+
+       public static HEParamServer create(ListObject model, String aggFunc, 
Statement.PSUpdateType updateType,
+               Statement.PSFrequency freq, ExecutionContext ec, int workerNum, 
String valFunc, int numBatchesPerEpoch,
+               MatrixObject valFeatures, MatrixObject valLabels, int nbatches)
+       {
+               NativeHEHelper.initialize();
+               return new HEParamServer(model, aggFunc, updateType, freq, ec,
+                               workerNum, valFunc, numBatchesPerEpoch, 
valFeatures, valLabels, nbatches);
+       }
+
+       private HEParamServer(ListObject model, String aggFunc, 
Statement.PSUpdateType updateType,
+               Statement.PSFrequency freq, ExecutionContext ec, int workerNum, 
String valFunc,
+               int numBatchesPerEpoch, MatrixObject valFeatures, MatrixObject 
valLabels, int nbatches)
+       {
+               super(model, aggFunc, updateType, freq, ec, workerNum, valFunc, 
numBatchesPerEpoch, valFeatures, valLabels, nbatches, true);
+
+               _seal_server = new SEALServer();
+
+               _threads = Collections.synchronizedList(new 
ArrayList<>(workerNum));
+               for (int i = 0; i < getNumWorkers(); i++) {
+                       _threads.add(null);
+               }
+
+               _result_buffer = new ArrayList<>(workerNum);
+               resetResultBuffer();
+       }
+
+       public void registerThread(int thread_id, FederatedPSControlThread 
thread) {
+               _threads.set(thread_id, thread);
+       }
+
+       private synchronized void resetResultBuffer() {
+               _result_buffer.clear();
+               for (int i = 0; i < getNumWorkers(); i++) {
+                       _result_buffer.add(null);
+               }
+       }
+
+       public byte[] generateA() {
+               return _seal_server.generateA();
+       }
+
+       public PublicKey aggregatePartialPublicKeys(PublicKey[] 
partial_public_keys) {
+               return 
_seal_server.aggregatePartialPublicKeys(partial_public_keys);
+       }
+
+       /**
+        * this method collects all T Objects from each worker into a list and 
then calls f once on this list to produce
+        * another T, which it returns.
+        */
+       @SuppressWarnings("unchecked")
+       private synchronized <T,U> U collectAndDo(int workerId, T obj, 
Function<List<T>, U> f) {
+               _result_buffer.set(workerId, obj);
+               _thread_counter++;
+
+               if (_thread_counter == getNumWorkers()) {
+                       List<T> buf = _result_buffer.stream().map(x -> 
(T)x).collect(Collectors.toList());
+                       _result = f.apply(buf);
+                       resetResultBuffer();
+                       _thread_counter = 0;
+                       notifyAll();
+               } else {
+                       try {
+                               wait();
+                       } catch (InterruptedException i) {
+                               throw new RuntimeException("thread 
interrupted");
+                       }
+               }
+
+               return (U) _result;
+       }
+
+       private CiphertextMatrix[] homomorphicAggregation(List<ListObject> 
encrypted_models) {
+               Timing tAgg = DMLScript.STATISTICS ? new Timing(true) : null;
+               CiphertextMatrix[] result = new 
CiphertextMatrix[encrypted_models.get(0).getLength()];
+               IntStream.range(0, 
encrypted_models.get(0).getLength()).forEach(matrix_idx -> {
+                       CiphertextMatrix[] summands = new 
CiphertextMatrix[encrypted_models.size()];
+                       for (int i = 0; i < encrypted_models.size(); i++) {
+                               summands[i] = (CiphertextMatrix) 
encrypted_models.get(i).getData(matrix_idx);
+                       }
+                       result[matrix_idx] = 
_seal_server.accumulateCiphertexts(summands);;
+               });
+               if (tAgg != null) {
+                       
ParamServStatistics.accHEAccumulation((long)tAgg.stop());
+               }
+               return result;
+       }
+
+       private Void homomorphicAverage(CiphertextMatrix[] encrypted_sums, 
List<PlaintextMatrix[]> partial_decryptions) {
+               Timing tDecrypt = DMLScript.STATISTICS ? new Timing(true) : 
null;
+
+               MatrixObject[] result = new 
MatrixObject[partial_decryptions.get(0).length];
+
+               IntStream.range(0, 
partial_decryptions.get(0).length).forEach(matrix_idx -> {
+                       PlaintextMatrix[] partial_plaintexts = new 
PlaintextMatrix[partial_decryptions.size()];
+                       for (int i = 0; i < partial_decryptions.size(); i++) {
+                               partial_plaintexts[i] = 
partial_decryptions.get(i)[matrix_idx];
+                       }
+
+                       result[matrix_idx] = 
_seal_server.average(encrypted_sums[matrix_idx], partial_plaintexts);
+               });
+
+               ListObject old_model = getResult();
+               ListObject new_model = new ListObject(old_model);
+               for (int i = 0; i < new_model.getLength(); i++) {
+                       new_model.set(i, result[i]);
+               }
+
+               if (tDecrypt != null) {
+                       
ParamServStatistics.accHEDecryptionTime((long)tDecrypt.stop());
+               }
+
+               updateAndBroadcastModel(new_model, null);
+               return null;
+       }
+
+       // this is only to be used in push()
+       private Timing commTimer;
+       private void startCommTimer() {
+               commTimer = new Timing(true);
+       }
+       private long stopCommTimer() {
+               return (long)commTimer.stop();
+       }
+       // ---------------------------------
+
+       @Override
+       public void push(int workerID, ListObject encrypted_model) {
+               // wait for all updates and sum them homomorphically
+               CiphertextMatrix[] homomorphic_sum = collectAndDo(workerID, 
encrypted_model, x -> {
+                       CiphertextMatrix[] res = this.homomorphicAggregation(x);
+                       this.startCommTimer();
+                       return res;
+               });
+
+               // get partial decryptions
+               PlaintextMatrix[] partial_decryption = 
_threads.get(workerID).getPartialDecryption(homomorphic_sum);
+
+               // do average and update global model
+               collectAndDo(workerID, partial_decryption, x -> {
+                       
ParamServStatistics.accFedNetworkTime(this.stopCommTimer());
+                       return this.homomorphicAverage(homomorphic_sum, x);
+               });
+       }
 }
diff --git 
a/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/NetworkTrafficCounter.java
 
b/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/NetworkTrafficCounter.java
index f823b9d3be..9c353c3258 100644
--- 
a/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/NetworkTrafficCounter.java
+++ 
b/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/NetworkTrafficCounter.java
@@ -19,24 +19,23 @@
 
 package org.apache.sysds.runtime.controlprogram.paramserv;
 
-import io.netty.channel.ChannelHandler;
 import io.netty.channel.ChannelHandlerContext;
 import io.netty.handler.traffic.ChannelTrafficShapingHandler;
 import java.util.function.BiConsumer;
 
 public class NetworkTrafficCounter extends ChannelTrafficShapingHandler {
-    private final BiConsumer<Long, Long> _fn; // (read, written) -> Void, logs 
bytes read and written
-    public NetworkTrafficCounter(BiConsumer<Long, Long> fn) {
-        // checkInterval of zero means that doAccounting will not be called
-        super( 0);
-        _fn = fn;
-    }
+       private final BiConsumer<Long, Long> _fn; // (read, written) -> Void, 
logs bytes read and written
+       public NetworkTrafficCounter(BiConsumer<Long, Long> fn) {
+               // checkInterval of zero means that doAccounting will not be 
called
+               super( 0);
+               _fn = fn;
+       }
 
-    // log bytes read/written after channel is closed
-    @Override
-    public void channelInactive(ChannelHandlerContext ctx) throws Exception {
-        _fn.accept(trafficCounter.cumulativeReadBytes(), 
trafficCounter.cumulativeWrittenBytes());
-        trafficCounter.resetCumulativeTime();
-        super.channelInactive(ctx);
-    }
+       // log bytes read/written after channel is closed
+       @Override
+       public void channelInactive(ChannelHandlerContext ctx) throws Exception 
{
+               _fn.accept(trafficCounter.cumulativeReadBytes(), 
trafficCounter.cumulativeWrittenBytes());
+               trafficCounter.resetCumulativeTime();
+               super.channelInactive(ctx);
+       }
 }
diff --git 
a/src/main/java/org/apache/sysds/runtime/instructions/cp/PlaintextMatrix.java 
b/src/main/java/org/apache/sysds/runtime/instructions/cp/PlaintextMatrix.java
index 6fe2b3814f..d36d40bc8e 100644
--- 
a/src/main/java/org/apache/sysds/runtime/instructions/cp/PlaintextMatrix.java
+++ 
b/src/main/java/org/apache/sysds/runtime/instructions/cp/PlaintextMatrix.java
@@ -26,14 +26,14 @@ import org.apache.sysds.runtime.meta.DataCharacteristics;
  * This class abstracts over an encrypted matrix of ciphertexts. It stores the 
data as opaque byte array. The layout is unspecified.
  */
 public class PlaintextMatrix extends Encrypted {
-    private static final long serialVersionUID = 5732436872261940616L;
+       private static final long serialVersionUID = 5732436872261940616L;
 
-    public PlaintextMatrix(int[] dims, DataCharacteristics dc, byte[] data) {
-        super(dims, dc, data, Types.DataType.ENCRYPTED_PLAIN);
-    }
+       public PlaintextMatrix(int[] dims, DataCharacteristics dc, byte[] data) 
{
+               super(dims, dc, data, Types.DataType.ENCRYPTED_PLAIN);
+       }
 
-    @Override
-    public String getDebugName() {
-        return "PlaintextMatrix " + getData().hashCode();
-    }
+       @Override
+       public String getDebugName() {
+               return "PlaintextMatrix " + getData().hashCode();
+       }
 }
diff --git 
a/src/main/java/org/apache/sysds/utils/stats/ParamServStatistics.java 
b/src/main/java/org/apache/sysds/utils/stats/ParamServStatistics.java
index 8eb26a1963..3edf7bb77e 100644
--- a/src/main/java/org/apache/sysds/utils/stats/ParamServStatistics.java
+++ b/src/main/java/org/apache/sysds/utils/stats/ParamServStatistics.java
@@ -21,7 +21,6 @@ package org.apache.sysds.utils.stats;
 
 import java.util.concurrent.atomic.LongAdder;
 
-import org.apache.sysds.api.DMLScript;
 import org.apache.sysds.runtime.controlprogram.parfor.stat.Timing;
 
 public class ParamServStatistics {
diff --git 
a/src/test/java/org/apache/sysds/test/functions/federated/paramserv/EncryptedFederatedParamservTest.java
 
b/src/test/java/org/apache/sysds/test/functions/federated/paramserv/EncryptedFederatedParamservTest.java
index 250358d408..ca50338e33 100644
--- 
a/src/test/java/org/apache/sysds/test/functions/federated/paramserv/EncryptedFederatedParamservTest.java
+++ 
b/src/test/java/org/apache/sysds/test/functions/federated/paramserv/EncryptedFederatedParamservTest.java
@@ -25,7 +25,6 @@ import java.util.Collection;
 import java.util.List;
 
 import org.apache.sysds.common.Types.ExecMode;
-import org.apache.sysds.hops.codegen.SpoofCompiler;
 import org.apache.sysds.runtime.controlprogram.paramserv.NativeHEHelper;
 import org.apache.sysds.runtime.privacy.PrivacyConstraint;
 import org.apache.sysds.test.AutomatedTestBase;
@@ -94,8 +93,10 @@ public class EncryptedFederatedParamservTest extends 
AutomatedTestBase {
                });
        }
 
-       public EncryptedFederatedParamservTest(String networkType, int 
numFederatedWorkers, int dataSetSize, int batch_size,
-                                                                               
  int epochs, double eta, String utype, String freq, String scheme, String 
runtime_balancing, String weighting, String data_distribution, int seed) {
+       public EncryptedFederatedParamservTest(String networkType, int 
numFederatedWorkers,
+               int dataSetSize, int batch_size, int epochs, double eta, String 
utype, String freq,
+               String scheme, String runtime_balancing, String weighting, 
String data_distribution, int seed)
+       {
                try {
                        NativeHEHelper.initialize();
                } catch (Exception e) {
diff --git 
a/src/test/java/org/apache/sysds/test/functions/privacy/fedplanning/FederatedMultiplyPlanningTest.java
 
b/src/test/java/org/apache/sysds/test/functions/privacy/fedplanning/FederatedMultiplyPlanningTest.java
index b9a3a14fd5..14c093ebe8 100644
--- 
a/src/test/java/org/apache/sysds/test/functions/privacy/fedplanning/FederatedMultiplyPlanningTest.java
+++ 
b/src/test/java/org/apache/sysds/test/functions/privacy/fedplanning/FederatedMultiplyPlanningTest.java
@@ -23,7 +23,6 @@ import org.apache.commons.logging.Log;
 import org.apache.commons.logging.LogFactory;
 import org.apache.sysds.runtime.privacy.PrivacyConstraint;
 import org.apache.sysds.runtime.privacy.PrivacyConstraint.PrivacyLevel;
-import org.junit.Ignore;
 import org.junit.Test;
 import org.junit.runner.RunWith;
 import org.junit.runners.Parameterized;

Reply via email to