lindong28 commented on a change in pull request #10: URL: https://github.com/apache/flink-ml/pull/10#discussion_r743412222
########## File path: flink-ml-api/src/test/java/org/apache/flink/ml/api/core/StageTest.java ########## @@ -0,0 +1,345 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.ml.api.core; + +import org.apache.flink.ml.param.BooleanParam; +import org.apache.flink.ml.param.DoubleArrayParam; +import org.apache.flink.ml.param.DoubleParam; +import org.apache.flink.ml.param.FloatArrayParam; +import org.apache.flink.ml.param.FloatParam; +import org.apache.flink.ml.param.IntArrayParam; +import org.apache.flink.ml.param.IntParam; +import org.apache.flink.ml.param.LongArrayParam; +import org.apache.flink.ml.param.LongParam; +import org.apache.flink.ml.param.Param; +import org.apache.flink.ml.param.ParamValidator; +import org.apache.flink.ml.param.ParamValidators; +import org.apache.flink.ml.param.StringArrayParam; +import org.apache.flink.ml.param.StringParam; +import org.apache.flink.ml.param.WithParams; +import org.apache.flink.ml.util.ParamUtils; +import org.apache.flink.ml.util.ReadWriteUtils; + +import org.junit.Assert; +import org.junit.Test; + +import java.io.IOException; +import java.nio.file.Files; +import java.nio.file.Paths; +import java.util.Collections; +import java.util.HashMap; 
+import java.util.Map; + +/** Tests the behavior of Stage and WithParams. */ +public class StageTest { + + // A WithParams subclass which has one parameter for each pre-defined parameter type. + private interface MyParams<T> extends WithParams<T> { + Param<Boolean> BOOLEAN_PARAM = new BooleanParam("booleanParam", "Description", false); + + Param<Integer> INT_PARAM = + new IntParam("intParam", "Description", 1, ParamValidators.lt(100)); + + Param<Long> LONG_PARAM = + new LongParam("longParam", "Description", 2L, ParamValidators.lt(100)); + + Param<Float> FLOAT_PARAM = + new FloatParam("floatParam", "Description", 3.0f, ParamValidators.lt(100)); + + Param<Double> DOUBLE_PARAM = + new DoubleParam("doubleParam", "Description", 4.0, ParamValidators.lt(100)); + + Param<String> STRING_PARAM = new StringParam("stringParam", "Description", "5"); + + Param<Integer[]> INT_ARRAY_PARAM = + new IntArrayParam("intArrayParam", "Description", new Integer[] {6, 7}); + + Param<Long[]> LONG_ARRAY_PARAM = + new LongArrayParam( + "longArrayParam", + "Description", + new Long[] {8L, 9L}, + ParamValidators.alwaysTrue()); + + Param<Float[]> FLOAT_ARRAY_PARAM = + new FloatArrayParam("floatArrayParam", "Description", new Float[] {10.0f, 11.0f}); + + Param<Double[]> DOUBLE_ARRAY_PARAM = + new DoubleArrayParam( + "doubleArrayParam", + "Description", + new Double[] {12.0, 13.0}, + ParamValidators.alwaysTrue()); + + Param<String[]> STRING_ARRAY_PARAM = + new StringArrayParam("stringArrayParam", "Description", new String[] {"14", "15"}); + } + + // A Stage subclass which inherits all parameters from MyParams and defines an extra parameter. 
+ private static class MyStage implements Stage<MyStage>, MyParams<MyStage> { + private final Map<Param<?>, Object> paramMap = new HashMap<>(); + + public final Param<Integer> extraIntParam = + new IntParam("extraIntParam", "Description", 20, ParamValidators.alwaysTrue()); + + public MyStage() { + ParamUtils.initializeMapWithDefaultValues(paramMap, this); + } + + @Override + public Map<Param<?>, Object> getParamMap() { + return paramMap; + } + + @Override + public void save(String path) throws IOException { + ReadWriteUtils.saveMetadata(this, path); + } + + public static MyStage load(String path) throws IOException { + return ReadWriteUtils.loadStageParam(path); + } + } + + // A Stage subclass without the static load() method. + private static class MyStageWithoutLoad implements Stage<MyStage>, MyParams<MyStage> { + private final Map<Param<?>, Object> paramMap = new HashMap<>(); + + public MyStageWithoutLoad() { + ParamUtils.initializeMapWithDefaultValues(paramMap, this); + } + + @Override + public void save(String path) throws IOException { + ReadWriteUtils.saveMetadata(this, path); + } + + @Override + public Map<Param<?>, Object> getParamMap() { + return paramMap; + } + } + + // Asserts that m1 and m2 are equivalent. + private static void assertParamMapEquals(Map<Param<?>, Object> m1, Map<Param<?>, Object> m2) { + Assert.assertTrue(m1 != null && m2 != null); + Assert.assertEquals(m1.size(), m2.size()); + + for (Map.Entry<Param<?>, Object> entry : m1.entrySet()) { + Assert.assertTrue(m2.containsKey(entry.getKey())); + Object v1 = entry.getValue(); + Object v2 = m2.get(entry.getKey()); + if (v1.getClass().isArray() && v2.getClass().isArray()) { + Assert.assertArrayEquals((Object[]) v1, (Object[]) v2); + } else { + Assert.assertEquals(v1, v2); + } + } + } + + // Saves and loads the given stage. And verifies that the loaded stage has same parameter values + // as the original stage. 
+ private static Stage<?> validateStageReadWrite( + Stage<?> stage, Map<String, Object> paramOverrides) throws IOException { + for (Map.Entry<String, Object> entry : paramOverrides.entrySet()) { + Param<?> param = stage.getParam(entry.getKey()); + ReadWriteUtils.setStageParam(stage, param, entry.getValue()); + } + + String tempDir = Files.createTempDirectory("").toString(); + String path = Paths.get(tempDir, "test").toString(); + stage.save(path); + try { + stage.save(path); + Assert.fail("Expected IOException"); + } catch (IOException e) { + // This is expected. + } Review comment: Hmm... are you suggesting to add something like the code below? In my opinion, adding a dedicated method that is only used in `validateStageSaveLoad` is a bit overkill. The logic seems pretty simple here. I can add this if you prefer. ``` private static void validateExceptionOnSecondSave(Stage<?> stage, String path) throws IOException { stage.save(path); try { stage.save(path); Assert.fail("Expected IOException"); } catch (IOException e) { // This is expected. } } ``` -- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. To unsubscribe, e-mail: [email protected] For queries about this service, please contact Infrastructure at: [email protected]
