This is an automated email from the ASF dual-hosted git repository. zkaoudi pushed a commit to branch revert-496-tf in repository https://gitbox.apache.org/repos/asf/incubator-wayang.git
commit d782b4233b17bdde2a14bad1a1dd38b8e82eee5d Author: Zoi Kaoudi <[email protected]> AuthorDate: Fri Jan 31 16:03:24 2025 +0100 Revert "update tensorflow java to 1.0.0" --- wayang-api/wayang-api-scala-java/pom.xml | 2 +- wayang-platforms/wayang-tensorflow/pom.xml | 2 +- .../wayang/tensorflow/model/TensorflowModel.java | 17 +++-- .../tensorflow/model/TensorflowModelTest.java | 80 ---------------------- .../wayang/tests/TensorflowIntegrationIT.java | 3 +- .../org/apache/wayang/tests/TensorflowIrisIT.java | 10 +-- 6 files changed, 20 insertions(+), 94 deletions(-) diff --git a/wayang-api/wayang-api-scala-java/pom.xml b/wayang-api/wayang-api-scala-java/pom.xml index 38f2a870..12c9b395 100644 --- a/wayang-api/wayang-api-scala-java/pom.xml +++ b/wayang-api/wayang-api-scala-java/pom.xml @@ -34,7 +34,7 @@ <properties> <java-module-name>org.apache.wayang.api</java-module-name> - <tensorflow.version>1.0.0-rc.2</tensorflow.version> + <tensorflow.version>0.4.2</tensorflow.version> </properties> <dependencyManagement> diff --git a/wayang-platforms/wayang-tensorflow/pom.xml b/wayang-platforms/wayang-tensorflow/pom.xml index 66a018d5..d752be1b 100644 --- a/wayang-platforms/wayang-tensorflow/pom.xml +++ b/wayang-platforms/wayang-tensorflow/pom.xml @@ -37,7 +37,7 @@ <maven.compiler.source>11</maven.compiler.source> <maven.compiler.target>11</maven.compiler.target> <wayang.version>0.7.1</wayang.version> - <tensorflow.version>1.0.0-rc.2</tensorflow.version> + <tensorflow.version>0.4.2</tensorflow.version> </properties> <dependencies> diff --git a/wayang-platforms/wayang-tensorflow/src/main/java/org/apache/wayang/tensorflow/model/TensorflowModel.java b/wayang-platforms/wayang-tensorflow/src/main/java/org/apache/wayang/tensorflow/model/TensorflowModel.java index e881f95d..80113ffb 100644 --- a/wayang-platforms/wayang-tensorflow/src/main/java/org/apache/wayang/tensorflow/model/TensorflowModel.java +++ b/wayang-platforms/wayang-tensorflow/src/main/java/org/apache/wayang/tensorflow/model/TensorflowModel.java @@ -22,7 +22,10 @@ import org.apache.wayang.basic.model.DLModel; import org.apache.wayang.basic.model.op.Input; import org.apache.wayang.basic.model.op.Op; import org.apache.wayang.basic.model.optimizer.Optimizer; -import org.tensorflow.*; +import org.tensorflow.Graph; +import org.tensorflow.Operand; +import org.tensorflow.Session; +import org.tensorflow.Tensor; import org.tensorflow.ndarray.*; import org.tensorflow.ndarray.index.Indices; import org.tensorflow.op.Ops; @@ -119,12 +122,14 @@ public class TensorflowModel extends DLModel implements AutoCloseable { if (accuracyCalculation != null) { runner.fetch(accuracyCalculation.getName()); } - try (Result ret = runner.run()) { - TFloat32 loss = (TFloat32) ret.get(0); + List<Tensor> ret = runner.run(); + try (TFloat32 loss = (TFloat32) ret.get(0)) { System.out.printf("[epoch %d, batch %d] loss: %f ", i + 1, start / batchSize + 1, loss.getFloat()); - - TFloat32 acc = (TFloat32) ret.get(1); - System.out.printf("accuracy: %f ", acc.getFloat()); + } + if (accuracyCalculation != null) { + try (TFloat32 acc = (TFloat32) ret.get(1)) { + System.out.printf("accuracy: %f ", acc.getFloat()); + } } System.out.println(); } diff --git a/wayang-platforms/wayang-tensorflow/src/test/java/org/apache/wayang/tensorflow/model/TensorflowModelTest.java b/wayang-platforms/wayang-tensorflow/src/test/java/org/apache/wayang/tensorflow/model/TensorflowModelTest.java deleted file mode 100644 index 14f0354b..00000000 --- a/wayang-platforms/wayang-tensorflow/src/test/java/org/apache/wayang/tensorflow/model/TensorflowModelTest.java +++ /dev/null @@ -1,80 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.wayang.tensorflow.model; - -import org.apache.wayang.basic.model.DLModel; -import org.apache.wayang.basic.model.op.*; -import org.apache.wayang.basic.model.op.nn.CrossEntropyLoss; -import org.apache.wayang.basic.model.op.nn.Linear; -import org.apache.wayang.basic.model.op.nn.Sigmoid; -import org.apache.wayang.basic.model.optimizer.GradientDescent; -import org.apache.wayang.basic.model.optimizer.Optimizer; -import org.junit.Test; -import org.tensorflow.ndarray.FloatNdArray; -import org.tensorflow.ndarray.IntNdArray; -import org.tensorflow.ndarray.NdArrays; -import org.tensorflow.ndarray.Shape; -import org.tensorflow.op.Ops; -import org.tensorflow.types.TFloat32; -import org.tensorflow.types.TInt32; -public class TensorflowModelTest { - @Test - public void test() { - FloatNdArray x = NdArrays.ofFloats(Shape.of(6, 4)) - .set(NdArrays.vectorOf(5.1f, 3.5f, 1.4f, 0.2f), 0) - .set(NdArrays.vectorOf(4.9f, 3.0f, 1.4f, 0.2f), 1) - .set(NdArrays.vectorOf(6.9f, 3.1f, 4.9f, 1.5f), 2) - .set(NdArrays.vectorOf(5.5f, 2.3f, 4.0f, 1.3f), 3) - .set(NdArrays.vectorOf(5.8f, 2.7f, 5.1f, 1.9f), 4) - .set(NdArrays.vectorOf(6.7f, 3.3f, 5.7f, 2.5f), 5) - ; - IntNdArray y = NdArrays.vectorOf(0, 0, 1, 1, 2, 2); - Op l1 = new Linear(4, 64, true); - Op s1 = new Sigmoid(); - Op l2 = new Linear(64, 3, true); - s1.with(l1.with(new Input(Input.Type.FEATURES))); - l2.with(s1); - DLModel model = new DLModel(l2); - Op criterion = new CrossEntropyLoss(3); - criterion.with( - new Input(Input.Type.PREDICTED, Op.DType.FLOAT32), - new Input(Input.Type.LABEL, Op.DType.INT32) - ); - Op acc = new Mean(0); - acc.with(new Cast(Op.DType.FLOAT32).with(new Eq().with( - new ArgMax(1).with(new Input(Input.Type.PREDICTED, Op.DType.FLOAT32)), - new Input(Input.Type.LABEL, Op.DType.INT32) - ))); - Optimizer optimizer = new GradientDescent(0.02f); - try (TensorflowModel tfModel = new TensorflowModel(model, criterion, optimizer, acc)) { - System.out.println(tfModel.getOut().getName()); - tfModel.train(x, y, 100, 6); - TFloat32 predicted = tfModel.predict(x); - Ops tf = Ops.create(); - org.tensorflow.op.math.ArgMax<TInt32> argMax = tf.math.argMax(tf.constantOf(predicted), tf.constant(1), TInt32.class); - final TInt32 tensor = argMax.asTensor(); - System.out.print("[ "); - for (int i = 0; i < tensor.shape().size(0); i++) { - System.out.print(tensor.getInt(i) + " "); - } - System.out.println("]"); - } - System.out.println(); - } -} \ No newline at end of file diff --git a/wayang-tests-integration/src/test/java/org/apache/wayang/tests/TensorflowIntegrationIT.java b/wayang-tests-integration/src/test/java/org/apache/wayang/tests/TensorflowIntegrationIT.java index b60224a1..3a7f0364 100644 --- a/wayang-tests-integration/src/test/java/org/apache/wayang/tests/TensorflowIntegrationIT.java +++ b/wayang-tests-integration/src/test/java/org/apache/wayang/tests/TensorflowIntegrationIT.java @@ -30,6 +30,7 @@ import org.apache.wayang.core.api.WayangContext; import org.apache.wayang.core.plan.wayangplan.WayangPlan; import org.apache.wayang.java.Java; import org.apache.wayang.tensorflow.Tensorflow; +import org.junit.Ignore; import org.junit.Test; import java.util.ArrayList; @@ -67,7 +68,7 @@ public class TensorflowIntegrationIT { public static String[] LABELS = new String[]{"Iris-setosa", "Iris-versicolor", "Iris-virginica"}; - @Test + @Ignore public void test() { /* training features */ CollectionSource<float[]> trainXSource = new CollectionSource<>(trainX, float[].class); diff --git a/wayang-tests-integration/src/test/java/org/apache/wayang/tests/TensorflowIrisIT.java b/wayang-tests-integration/src/test/java/org/apache/wayang/tests/TensorflowIrisIT.java index e511df82..eb428ffe 100644 --- a/wayang-tests-integration/src/test/java/org/apache/wayang/tests/TensorflowIrisIT.java +++ b/wayang-tests-integration/src/test/java/org/apache/wayang/tests/TensorflowIrisIT.java @@ -24,22 +24,22 @@ import org.apache.wayang.basic.model.op.nn.CrossEntropyLoss; import org.apache.wayang.basic.model.op.nn.Linear; import org.apache.wayang.basic.model.op.nn.Sigmoid; import org.apache.wayang.basic.model.optimizer.Adam; +import org.apache.wayang.basic.model.optimizer.GradientDescent; import org.apache.wayang.basic.model.optimizer.Optimizer; import org.apache.wayang.basic.operators.*; import org.apache.wayang.core.api.WayangContext; import org.apache.wayang.core.plan.wayangplan.Operator; import org.apache.wayang.core.plan.wayangplan.WayangPlan; import org.apache.wayang.core.util.Tuple; +import org.apache.wayang.core.util.WayangCollections; import org.apache.wayang.java.Java; import org.apache.wayang.tensorflow.Tensorflow; +import org.junit.Ignore; import org.junit.Test; import java.net.URI; import java.net.URISyntaxException; -import java.util.ArrayList; -import java.util.List; -import java.util.Map; -import java.util.Random; +import java.util.*; /** * Test the Tensorflow integration with Wayang. @@ -56,7 +56,7 @@ public class TensorflowIrisIT { "Iris-virginica", 2 ); - @Test + @Ignore public void test() { final Tuple<Operator, Operator> trainSource = fileOperation(TRAIN_PATH, true); final Tuple<Operator, Operator> testSource = fileOperation(TEST_PATH, false);
