kbeedkar commented on code in PR #551:
URL: https://github.com/apache/incubator-wayang/pull/551#discussion_r2057213215


##########
wayang-fl/src/main/java/org/server/FLServer.java:
##########
@@ -84,6 +84,8 @@ public void 
handleSendPlanHyperParametersMessage(SendPlanHyperparametersMessage
         for(ActorRef client : active_clients.keySet()){
             if(!active_clients.get(client)) continue;
             active_client_count++;
+            // remove this line later
+            client_hyperparams.put("inputFileUrl", 
"file:/Users/vedantaneogi/Downloads/higgs_part"+active_client_count+".txt");

Review Comment:
   Remove these hardcoded paths!



##########
wayang-fl/src/main/java/org/client/FLClient.java:
##########
@@ -93,14 +93,18 @@ private void 
handlePlanHyperparametersMessage(PlanHyperparametersMessage msg) {
     }
 
     private void buildPlan(Object operand){
-
-        Operator op = planFunction.apply(operand, planBuilder, hyperparams);
+        System.out.println(hyperparams.get("inputFileUrl"));
+        JavaPlanBuilder newPlanBuilder = new JavaPlanBuilder(wayangContext)
+                .withJobName(client.getName()+"-job")
+                .withUdfJarOf(FLClient.class);
+        Operator op = planFunction.apply(operand, newPlanBuilder, hyperparams);
 //        System.out.println(op);
         Class classType = 
op.getOutput(0).getType().getDataUnitType().getTypeClass();
         LocalCallbackSink<?> sink = 
LocalCallbackSink.createCollectingSink(collector, classType);
         op.connectTo(0, sink, 0);
         plan = new WayangPlan(sink);
-//        System.out.println(plan);

Review Comment:
   Remove these System.out messages and use a logger instead.



##########
wayang-fl/src/main/java/org/temp/test.java:
##########
@@ -0,0 +1,78 @@
+package org.temp;
+
+import org.apache.wayang.api.DataQuantaBuilder;
+import org.apache.wayang.api.JavaPlanBuilder;
+import org.apache.wayang.core.api.Configuration;
+import org.apache.wayang.core.api.WayangContext;
+import org.apache.wayang.java.Java;
+import org.apache.wayang.ml4all.abstraction.api.Compute;
+import org.apache.wayang.ml4all.abstraction.api.Sample;
+import org.apache.wayang.ml4all.abstraction.api.Transform;
+import org.apache.wayang.ml4all.abstraction.plan.ML4allModel;
+import org.apache.wayang.ml4all.abstraction.plan.wrappers.ComputeWrapper;
+import 
org.apache.wayang.ml4all.abstraction.plan.wrappers.TransformPerPartitionWrapper;
+import org.apache.wayang.ml4all.algorithms.sgd.ComputeLogisticGradient;
+import org.apache.wayang.ml4all.algorithms.sgd.SGDSample;
+import org.client.FLClient;
+
+import java.util.ArrayList;
+import java.util.Collections;
+import java.util.List;
+
+public class test {
+    public static void main(String args[]){
+
+        WayangContext wayangContext = new WayangContext(new 
Configuration()).withPlugin(Java.basicPlugin());
+        JavaPlanBuilder pb = new JavaPlanBuilder(wayangContext)
+                .withJobName("test-client"+"-job")
+                .withUdfJarOf(FLClient.class);
+//        List<Double> weights = new ArrayList<>(Collections.nCopies(29, 
0.0));;
+        double[] weights = new double[29];
+        String inputFileUrl = 
"file:/Users/vedantaneogi/Downloads/higgs_part1.txt";
+        int datasetSize = 29;
+        ML4allModel model = new ML4allModel();
+        model.put("weights", weights);
+        ArrayList<ML4allModel> broadcastModel = new ArrayList<>(1);
+        broadcastModel.add(model);
+        // Step 1: Define ML operators
+        Sample sampleOp = new SGDSample();
+        Transform transformOp = new LibSVMTransform(29);
+        Compute computeOp = new ComputeLogisticGradient();
+
+        // Step 2: Create weight DataQuanta
+        var weightsBuilder = pb
+                .loadCollection(broadcastModel)
+                .withName("model");
+
+        // Step 3: Load dataset and apply transform
+        DataQuantaBuilder transformBuilder = (DataQuantaBuilder) pb
+                .readTextFile(inputFileUrl)
+                .withName("source")
+                .mapPartitions(new TransformPerPartitionWrapper(transformOp))
+                .withName("transform");
+
+
+//            Collection<?> parsedData = transformBuilder.collect();
+//            for (Object row : parsedData) {
+//                System.out.println(row);
+//            }
+
+        // Step 4: Sample, compute gradient, and broadcast weights
+        DataQuantaBuilder result = (DataQuantaBuilder) transformBuilder
+                .sample(sampleOp.sampleSize())
+                .withSampleMethod(sampleOp.sampleMethod())
+                .withDatasetSize(datasetSize)
+                .map(new ComputeWrapper<>(computeOp))
+                .withBroadcast(weightsBuilder, "model");
+//
+//        System.out.println(result.collect());
+//        Collection<?> output = result.collect();
+        for (Object o : result.collect()) {
+            System.out.println("Type: " + o.getClass().getName());
+            System.out.println("Value: " + o);
+            for (Object idk : (double[]) o){
+                System.out.println(idk);

Review Comment:
   Remove these hardcoded System.out messages. Use a logger with an appropriate
log level instead.



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

Reply via email to