This is an automated email from the ASF dual-hosted git repository. zhangzp pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/flink-ml.git
commit daaa87646710be215b85e4ac2a277a4ad3f4065f
Author: yunfengzhou-hub <[email protected]>
AuthorDate: Mon Nov 4 14:10:55 2024 +0800

    [FLINK-36653] Fix OnlineLogisticRegressionModel updating logic
---
 .../OnlineLogisticRegressionModel.java             |  6 +++-
 .../OnlineLogisticRegressionTest.java              | 34 +++++++++++++++++-----
 2 files changed, 31 insertions(+), 9 deletions(-)

diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/OnlineLogisticRegressionModel.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/OnlineLogisticRegressionModel.java
index 36d1555e..c38efd5f 100644
--- a/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/OnlineLogisticRegressionModel.java
+++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/OnlineLogisticRegressionModel.java
@@ -145,6 +145,10 @@ public class OnlineLogisticRegressionModel
                 LogisticRegressionModelData modelData = streamRecord.getValue();
                 coefficient = modelData.coefficient;
                 modelDataVersion = modelData.modelVersion;
+                servable =
+                        new LogisticRegressionModelServable(
+                                new LogisticRegressionModelData(coefficient, modelDataVersion));
+                ParamUtils.updateExistingParams(servable, params);
                 for (Row dataPoint : bufferedPointsState.get()) {
                     processElement(new StreamRecord<>(dataPoint));
                 }
@@ -160,7 +164,7 @@ public class OnlineLogisticRegressionModel
             if (servable == null) {
                 servable =
                         new LogisticRegressionModelServable(
-                                new LogisticRegressionModelData(coefficient, 0L));
+                                new LogisticRegressionModelData(coefficient, modelDataVersion));
                 ParamUtils.updateExistingParams(servable, params);
             }
             Vector features = (Vector) dataPoint.getField(servable.getFeaturesCol());
diff --git a/flink-ml-lib/src/test/java/org/apache/flink/ml/classification/OnlineLogisticRegressionTest.java b/flink-ml-lib/src/test/java/org/apache/flink/ml/classification/OnlineLogisticRegressionTest.java
index cac9473c..a0bb5e83 100644
--- a/flink-ml-lib/src/test/java/org/apache/flink/ml/classification/OnlineLogisticRegressionTest.java
+++ b/flink-ml-lib/src/test/java/org/apache/flink/ml/classification/OnlineLogisticRegressionTest.java
@@ -61,9 +61,14 @@ import org.junit.BeforeClass;
 import org.junit.Rule;
 import org.junit.Test;
 import org.junit.rules.TemporaryFolder;
+import org.junit.runner.RunWith;
+import org.junit.runners.Parameterized;
+import org.junit.runners.Parameterized.Parameter;
+import org.junit.runners.Parameterized.Parameters;
 
 import java.util.ArrayList;
 import java.util.Arrays;
+import java.util.Collection;
 import java.util.List;
 import java.util.concurrent.ExecutionException;
 import java.util.concurrent.TimeUnit;
@@ -72,6 +77,7 @@ import java.util.concurrent.TimeoutException;
 import static org.apache.flink.ml.classification.logisticregression.OnlineLogisticRegressionModel.MODEL_DATA_VERSION_GAUGE_KEY;
 
 /** Tests {@link OnlineLogisticRegression} and {@link OnlineLogisticRegressionModel}. */
+@RunWith(Parameterized.class)
 public class OnlineLogisticRegressionTest extends TestLogger {
     @Rule public final TemporaryFolder tempFolder = new TemporaryFolder();
 
@@ -142,10 +148,21 @@ public class OnlineLogisticRegressionTest extends TestLogger {
                 Row.of(Vectors.sparse(10, new int[] {5, 8, 9}, ONE_ARRAY), 1.)
             };
 
-    private static final int defaultParallelism = 4;
+    @Parameters
+    public static Collection<Object[]> data() {
+        return Arrays.asList(new Object[][] {{1}, {4}});
+    }
+
+    @Parameter public int defaultParallelism;
     private static final int numTaskManagers = 2;
     private static final int numSlotsPerTaskManager = 2;
-
+    private static final Configuration config =
+            new Configuration() {
+                {
+                    set(RestOptions.BIND_PORT, "18081-19091");
+                    set(ExecutionCheckpointingOptions.ENABLE_CHECKPOINTS_AFTER_TASKS_FINISH, true);
+                }
+            };
     private long currentModelDataVersion;
 
     private InMemorySourceFunction<Row> trainDenseSource;
@@ -170,9 +187,6 @@ public class OnlineLogisticRegressionTest extends TestLogger {
 
     @BeforeClass
     public static void beforeClass() throws Exception {
-        Configuration config = new Configuration();
-        config.set(RestOptions.BIND_PORT, "18081-19091");
-        config.set(ExecutionCheckpointingOptions.ENABLE_CHECKPOINTS_AFTER_TASKS_FINISH, true);
         reporter = InMemoryReporter.create();
         reporter.addToConfiguration(config);
 
@@ -184,17 +198,17 @@ public class OnlineLogisticRegressionTest extends TestLogger {
                                 .setNumSlotsPerTaskManager(numSlotsPerTaskManager)
                                 .build());
         miniCluster.start();
+    }
 
+    @Before
+    public void before() throws Exception {
         env = StreamExecutionEnvironment.getExecutionEnvironment(config);
         env.getConfig().enableObjectReuse();
         env.setParallelism(defaultParallelism);
         env.enableCheckpointing(100);
         env.setRestartStrategy(RestartStrategies.noRestart());
         tEnv = StreamTableEnvironment.create(env);
-    }
 
-    @Before
-    public void before() throws Exception {
         currentModelDataVersion = 0;
 
         trainDenseSource = new InMemorySourceFunction<>();
@@ -562,6 +576,10 @@ public class OnlineLogisticRegressionTest extends TestLogger {
 
     @Test
     public void testBatchSizeLessThanParallelism() {
+        if (defaultParallelism < 2) {
+            return;
+        }
+
         try {
             new OnlineLogisticRegression()
                     .setInitialModelData(initDenseModel)
