This is an automated email from the ASF dual-hosted git repository. lindong pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/flink-ml.git
commit 341df450831e4c426ff4f8049af8dc52fc0bb598 Author: zhangzp <zhangzhipe...@gmail.com> AuthorDate: Mon Jun 20 09:50:59 2022 +0800 [FLINK-27877] Reduce the length of the operator chain for generating input table --- .../common/DenseVectorArrayGenerator.java | 114 +++++------------- .../datagenerator/common/DenseVectorGenerator.java | 103 +++++----------- .../datagenerator/common/InputTableGenerator.java | 66 ++++++++++ .../common/LabeledPointWithWeightGenerator.java | 134 +++++++-------------- .../datagenerator/common/RowGenerator.java | 77 ++++++++++++ .../flink/ml/benchmark/DataGeneratorTest.java | 45 ++++--- 6 files changed, 275 insertions(+), 264 deletions(-) diff --git a/flink-ml-benchmark/src/main/java/org/apache/flink/ml/benchmark/datagenerator/common/DenseVectorArrayGenerator.java b/flink-ml-benchmark/src/main/java/org/apache/flink/ml/benchmark/datagenerator/common/DenseVectorArrayGenerator.java index 0f8b82f..c1b3a21 100644 --- a/flink-ml-benchmark/src/main/java/org/apache/flink/ml/benchmark/datagenerator/common/DenseVectorArrayGenerator.java +++ b/flink-ml-benchmark/src/main/java/org/apache/flink/ml/benchmark/datagenerator/common/DenseVectorArrayGenerator.java @@ -18,100 +18,50 @@ package org.apache.flink.ml.benchmark.datagenerator.common; -import org.apache.flink.api.common.functions.RichMapFunction; -import org.apache.flink.api.common.typeinfo.BasicTypeInfo; -import org.apache.flink.api.java.tuple.Tuple2; -import org.apache.flink.configuration.Configuration; -import org.apache.flink.ml.benchmark.datagenerator.InputDataGenerator; +import org.apache.flink.api.common.typeinfo.TypeInformation; +import org.apache.flink.api.java.typeutils.RowTypeInfo; import org.apache.flink.ml.benchmark.datagenerator.param.HasArraySize; import org.apache.flink.ml.benchmark.datagenerator.param.HasVectorDim; -import org.apache.flink.ml.common.datastream.TableUtils; import org.apache.flink.ml.linalg.DenseVector; -import org.apache.flink.ml.param.Param; -import 
org.apache.flink.ml.util.ParamUtils; -import org.apache.flink.streaming.api.datastream.DataStream; -import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment; -import org.apache.flink.table.api.DataTypes; -import org.apache.flink.table.api.Schema; -import org.apache.flink.table.api.Table; -import org.apache.flink.table.api.bridge.java.StreamTableEnvironment; -import org.apache.flink.util.NumberSequenceIterator; +import org.apache.flink.types.Row; import org.apache.flink.util.Preconditions; -import java.util.HashMap; -import java.util.Map; -import java.util.Random; - /** A DataGenerator which creates a table of DenseVector array. */ -public class DenseVectorArrayGenerator - implements InputDataGenerator<DenseVectorArrayGenerator>, - HasArraySize<DenseVectorArrayGenerator>, +public class DenseVectorArrayGenerator extends InputTableGenerator<DenseVectorArrayGenerator> + implements HasArraySize<DenseVectorArrayGenerator>, HasVectorDim<DenseVectorArrayGenerator> { - private final Map<Param<?>, Object> paramMap = new HashMap<>(); - - public DenseVectorArrayGenerator() { - ParamUtils.initializeMapWithDefaultValues(paramMap, this); - } @Override - public Table[] getData(StreamTableEnvironment tEnv) { - StreamExecutionEnvironment env = TableUtils.getExecutionEnvironment(tEnv); - - DataStream<DenseVector[]> dataStream = - env.fromParallelCollection( - new NumberSequenceIterator(1L, getNumValues()), - BasicTypeInfo.LONG_TYPE_INFO) - .map( - new GenerateRandomContinuousVectorArrayFunction( - getSeed(), getVectorDim(), getArraySize())); - - Schema schema = Schema.newBuilder().column("f0", DataTypes.of(DenseVector[].class)).build(); - Table dataTable = tEnv.fromDataStream(dataStream, schema); - if (getColNames() != null) { - Preconditions.checkState(getColNames().length == 1); - Preconditions.checkState(getColNames()[0].length == 1); - dataTable = dataTable.as(getColNames()[0][0]); - } - - return new Table[] {dataTable}; - } + protected RowGenerator[] 
getRowGenerators() { + String[][] columnNames = getColNames(); + Preconditions.checkState(columnNames.length == 1); + Preconditions.checkState(columnNames[0].length == 1); + int arraySize = getArraySize(); + int vectorDim = getVectorDim(); - private static class GenerateRandomContinuousVectorArrayFunction - extends RichMapFunction<Long, DenseVector[]> { - private final int vectorDim; - private final long initSeed; - private final int arraySize; - private Random random; - - private GenerateRandomContinuousVectorArrayFunction( - long initSeed, int vectorDim, int arraySize) { - this.vectorDim = vectorDim; - this.initSeed = initSeed; - this.arraySize = arraySize; - } - - @Override - public void open(Configuration parameters) throws Exception { - super.open(parameters); - int index = getRuntimeContext().getIndexOfThisSubtask(); - random = new Random(Tuple2.of(initSeed, index).hashCode()); - } + return new RowGenerator[] { + new RowGenerator(getNumValues(), getSeed()) { + @Override + protected Row nextRow() { + DenseVector[] result = new DenseVector[arraySize]; + for (int i = 0; i < arraySize; i++) { + result[i] = new DenseVector(vectorDim); + for (int j = 0; j < vectorDim; j++) { + result[i].values[j] = random.nextDouble(); + } + } + Row row = new Row(1); + row.setField(0, result); + return row; + } - @Override - public DenseVector[] map(Long value) { - DenseVector[] result = new DenseVector[arraySize]; - for (int i = 0; i < arraySize; i++) { - result[i] = new DenseVector(vectorDim); - for (int j = 0; j < vectorDim; j++) { - result[i].values[j] = random.nextDouble(); + @Override + protected RowTypeInfo getRowTypeInfo() { + return new RowTypeInfo( + new TypeInformation[] {TypeInformation.of(DenseVector[].class)}, + columnNames[0]); } } - return result; - } - } - - @Override - public Map<Param<?>, Object> getParamMap() { - return paramMap; + }; } } diff --git a/flink-ml-benchmark/src/main/java/org/apache/flink/ml/benchmark/datagenerator/common/DenseVectorGenerator.java 
b/flink-ml-benchmark/src/main/java/org/apache/flink/ml/benchmark/datagenerator/common/DenseVectorGenerator.java index 4117261..10eae84 100644 --- a/flink-ml-benchmark/src/main/java/org/apache/flink/ml/benchmark/datagenerator/common/DenseVectorGenerator.java +++ b/flink-ml-benchmark/src/main/java/org/apache/flink/ml/benchmark/datagenerator/common/DenseVectorGenerator.java @@ -18,86 +18,43 @@ package org.apache.flink.ml.benchmark.datagenerator.common; -import org.apache.flink.api.common.functions.RichMapFunction; -import org.apache.flink.api.common.typeinfo.BasicTypeInfo; -import org.apache.flink.api.java.tuple.Tuple2; -import org.apache.flink.configuration.Configuration; -import org.apache.flink.ml.benchmark.datagenerator.InputDataGenerator; +import org.apache.flink.api.common.typeinfo.TypeInformation; +import org.apache.flink.api.java.typeutils.RowTypeInfo; import org.apache.flink.ml.benchmark.datagenerator.param.HasVectorDim; -import org.apache.flink.ml.common.datastream.TableUtils; -import org.apache.flink.ml.linalg.DenseVector; import org.apache.flink.ml.linalg.Vectors; -import org.apache.flink.ml.param.Param; -import org.apache.flink.ml.util.ParamUtils; -import org.apache.flink.streaming.api.datastream.DataStream; -import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment; -import org.apache.flink.table.api.Table; -import org.apache.flink.table.api.bridge.java.StreamTableEnvironment; -import org.apache.flink.util.NumberSequenceIterator; +import org.apache.flink.ml.linalg.typeinfo.DenseVectorTypeInfo; +import org.apache.flink.types.Row; import org.apache.flink.util.Preconditions; -import java.util.HashMap; -import java.util.Map; -import java.util.Random; - /** A DataGenerator which creates a table of DenseVector. 
*/ -public class DenseVectorGenerator - implements InputDataGenerator<DenseVectorGenerator>, HasVectorDim<DenseVectorGenerator> { - private final Map<Param<?>, Object> paramMap = new HashMap<>(); - - public DenseVectorGenerator() { - ParamUtils.initializeMapWithDefaultValues(paramMap, this); - } +public class DenseVectorGenerator extends InputTableGenerator<DenseVectorGenerator> + implements HasVectorDim<DenseVectorGenerator> { @Override - public Table[] getData(StreamTableEnvironment tEnv) { - StreamExecutionEnvironment env = TableUtils.getExecutionEnvironment(tEnv); - - DataStream<DenseVector> dataStream = - env.fromParallelCollection( - new NumberSequenceIterator(1L, getNumValues()), - BasicTypeInfo.LONG_TYPE_INFO) - .map(new RandomDenseVectorGenerator(getSeed(), getVectorDim())); - - Table dataTable = tEnv.fromDataStream(dataStream); - if (getColNames() != null) { - Preconditions.checkState(getColNames().length == 1); - Preconditions.checkState(getColNames()[0].length == 1); - dataTable = dataTable.as(getColNames()[0][0]); - } - - return new Table[] {dataTable}; - } - - private static class RandomDenseVectorGenerator extends RichMapFunction<Long, DenseVector> { - private final int vectorDim; - private final long initSeed; - private Random random; - - private RandomDenseVectorGenerator(long initSeed, int vectorDim) { - this.vectorDim = vectorDim; - this.initSeed = initSeed; - } - - @Override - public void open(Configuration parameters) throws Exception { - super.open(parameters); - int index = getRuntimeContext().getIndexOfThisSubtask(); - random = new Random(Tuple2.of(initSeed, index).hashCode()); - } - - @Override - public DenseVector map(Long value) { - double[] values = new double[vectorDim]; - for (int i = 0; i < vectorDim; i++) { - values[i] = random.nextDouble(); + public RowGenerator[] getRowGenerators() { + String[][] columnNames = getColNames(); + Preconditions.checkState(columnNames.length == 1); + Preconditions.checkState(columnNames[0].length == 1); 
+ int vectorDim = getVectorDim(); + + return new RowGenerator[] { + new RowGenerator(getNumValues(), getSeed()) { + + @Override + protected Row nextRow() { + double[] values = new double[vectorDim]; + for (int i = 0; i < values.length; i++) { + values[i] = random.nextDouble(); + } + return Row.of(Vectors.dense(values)); + } + + @Override + protected RowTypeInfo getRowTypeInfo() { + return new RowTypeInfo( + new TypeInformation[] {DenseVectorTypeInfo.INSTANCE}, columnNames[0]); + } } - return Vectors.dense(values); - } - } - - @Override - public Map<Param<?>, Object> getParamMap() { - return paramMap; + }; } } diff --git a/flink-ml-benchmark/src/main/java/org/apache/flink/ml/benchmark/datagenerator/common/InputTableGenerator.java b/flink-ml-benchmark/src/main/java/org/apache/flink/ml/benchmark/datagenerator/common/InputTableGenerator.java new file mode 100644 index 0000000..dd673a7 --- /dev/null +++ b/flink-ml-benchmark/src/main/java/org/apache/flink/ml/benchmark/datagenerator/common/InputTableGenerator.java @@ -0,0 +1,66 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.flink.ml.benchmark.datagenerator.common; + +import org.apache.flink.ml.benchmark.datagenerator.InputDataGenerator; +import org.apache.flink.ml.common.datastream.TableUtils; +import org.apache.flink.ml.param.Param; +import org.apache.flink.ml.util.ParamUtils; +import org.apache.flink.streaming.api.datastream.DataStream; +import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment; +import org.apache.flink.table.api.Table; +import org.apache.flink.table.api.bridge.java.StreamTableEnvironment; +import org.apache.flink.types.Row; + +import java.util.HashMap; +import java.util.Map; + +/** Base class for generating data as input table arrays. */ +public abstract class InputTableGenerator<T extends InputTableGenerator<T>> + implements InputDataGenerator<T> { + protected final Map<Param<?>, Object> paramMap = new HashMap<>(); + + public InputTableGenerator() { + ParamUtils.initializeMapWithDefaultValues(paramMap, this); + } + + @Override + public final Table[] getData(StreamTableEnvironment tEnv) { + StreamExecutionEnvironment env = TableUtils.getExecutionEnvironment(tEnv); + + RowGenerator[] rowGenerators = getRowGenerators(); + Table[] dataTables = new Table[rowGenerators.length]; + for (int i = 0; i < rowGenerators.length; i++) { + DataStream<Row> dataStream = + env.addSource(rowGenerators[i], "sourceOp-" + i) + .returns(rowGenerators[i].getRowTypeInfo()); + dataTables[i] = tEnv.fromDataStream(dataStream); + } + + return dataTables; + } + + /** Gets generators for all input tables. 
*/ + protected abstract RowGenerator[] getRowGenerators(); + + @Override + public final Map<Param<?>, Object> getParamMap() { + return paramMap; + } +} diff --git a/flink-ml-benchmark/src/main/java/org/apache/flink/ml/benchmark/datagenerator/common/LabeledPointWithWeightGenerator.java b/flink-ml-benchmark/src/main/java/org/apache/flink/ml/benchmark/datagenerator/common/LabeledPointWithWeightGenerator.java index 0e11071..dff9f07 100644 --- a/flink-ml-benchmark/src/main/java/org/apache/flink/ml/benchmark/datagenerator/common/LabeledPointWithWeightGenerator.java +++ b/flink-ml-benchmark/src/main/java/org/apache/flink/ml/benchmark/datagenerator/common/LabeledPointWithWeightGenerator.java @@ -18,34 +18,19 @@ package org.apache.flink.ml.benchmark.datagenerator.common; -import org.apache.flink.api.common.functions.RichMapFunction; -import org.apache.flink.api.common.typeinfo.BasicTypeInfo; import org.apache.flink.api.common.typeinfo.TypeInformation; import org.apache.flink.api.common.typeinfo.Types; -import org.apache.flink.api.java.tuple.Tuple2; import org.apache.flink.api.java.typeutils.RowTypeInfo; -import org.apache.flink.configuration.Configuration; -import org.apache.flink.ml.benchmark.datagenerator.InputDataGenerator; import org.apache.flink.ml.benchmark.datagenerator.param.HasVectorDim; -import org.apache.flink.ml.common.datastream.TableUtils; import org.apache.flink.ml.linalg.Vectors; import org.apache.flink.ml.linalg.typeinfo.DenseVectorTypeInfo; import org.apache.flink.ml.param.IntParam; import org.apache.flink.ml.param.Param; import org.apache.flink.ml.param.ParamValidators; import org.apache.flink.ml.util.ParamUtils; -import org.apache.flink.streaming.api.datastream.DataStream; -import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment; -import org.apache.flink.table.api.Table; -import org.apache.flink.table.api.bridge.java.StreamTableEnvironment; import org.apache.flink.types.Row; -import org.apache.flink.util.NumberSequenceIterator; 
import org.apache.flink.util.Preconditions; -import java.util.HashMap; -import java.util.Map; -import java.util.Random; - /** * A DataGenerator which creates a table of features, label and weight. * @@ -58,8 +43,8 @@ import java.util.Random; * </ul> */ public class LabeledPointWithWeightGenerator - implements InputDataGenerator<LabeledPointWithWeightGenerator>, - HasVectorDim<LabeledPointWithWeightGenerator> { + extends InputTableGenerator<LabeledPointWithWeightGenerator> + implements HasVectorDim<LabeledPointWithWeightGenerator> { public static final Param<Integer> FEATURE_ARITY = new IntParam( @@ -79,8 +64,6 @@ public class LabeledPointWithWeightGenerator 2, ParamValidators.gtEq(0)); - private final Map<Param<?>, Object> paramMap = new HashMap<>(); - public LabeledPointWithWeightGenerator() { ParamUtils.initializeMapWithDefaultValues(paramMap, this); } @@ -102,79 +85,46 @@ public class LabeledPointWithWeightGenerator } @Override - public Table[] getData(StreamTableEnvironment tEnv) { - Preconditions.checkState(getColNames().length == 1); - Preconditions.checkState(getColNames()[0].length == 3); - - StreamExecutionEnvironment env = TableUtils.getExecutionEnvironment(tEnv); - - DataStream<Row> dataStream = - env.fromParallelCollection( - new NumberSequenceIterator(1L, getNumValues()), - BasicTypeInfo.LONG_TYPE_INFO) - .map( - new RandomLabeledPointWithWeightGenerator( - getSeed(), - getVectorDim(), - getFeatureArity(), - getLabelArity()), - new RowTypeInfo( - new TypeInformation[] { - DenseVectorTypeInfo.INSTANCE, Types.DOUBLE, Types.DOUBLE - }, - getColNames()[0])); - - Table dataTable = tEnv.fromDataStream(dataStream); - - return new Table[] {dataTable}; - } - - @Override - public Map<Param<?>, Object> getParamMap() { - return paramMap; - } - - private static class RandomLabeledPointWithWeightGenerator extends RichMapFunction<Long, Row> { - private final long initSeed; - private final int vectorDim; - private final int featureArity; - private final int 
labelArity; - private Random random; - - private RandomLabeledPointWithWeightGenerator( - long initSeed, int vectorDim, int featureArity, int labelArity) { - this.initSeed = initSeed; - this.vectorDim = vectorDim; - this.featureArity = featureArity; - this.labelArity = labelArity; - } - - @Override - public void open(Configuration parameters) throws Exception { - super.open(parameters); - int index = getRuntimeContext().getIndexOfThisSubtask(); - random = new Random(Tuple2.of(initSeed, index).hashCode()); - } - - @Override - public Row map(Long ignored) { - double[] features = new double[vectorDim]; - for (int i = 0; i < vectorDim; i++) { - features[i] = getValue(featureArity); - } - - double label = getValue(labelArity); - - double weight = random.nextDouble(); - - return Row.of(Vectors.dense(features), label, weight); - } - - private double getValue(int arity) { - if (arity > 0) { - return random.nextInt(arity); + protected RowGenerator[] getRowGenerators() { + String[][] colNames = getColNames(); + Preconditions.checkState(colNames.length == 1); + Preconditions.checkState(colNames[0].length == 3); + int vectorDim = getVectorDim(); + int featureArity = getFeatureArity(); + int labelArity = getLabelArity(); + + return new RowGenerator[] { + new RowGenerator(getNumValues(), getSeed()) { + @Override + protected Row nextRow() { + double[] features = new double[vectorDim]; + for (int i = 0; i < vectorDim; i++) { + features[i] = getValue(featureArity); + } + + double label = getValue(labelArity); + + double weight = random.nextDouble(); + + return Row.of(Vectors.dense(features), label, weight); + } + + @Override + protected RowTypeInfo getRowTypeInfo() { + return new RowTypeInfo( + new TypeInformation[] { + DenseVectorTypeInfo.INSTANCE, Types.DOUBLE, Types.DOUBLE + }, + colNames[0]); + } + + private double getValue(int arity) { + if (arity > 0) { + return random.nextInt(arity); + } + return random.nextDouble(); + } } - return random.nextDouble(); - } + }; } } diff 
--git a/flink-ml-benchmark/src/main/java/org/apache/flink/ml/benchmark/datagenerator/common/RowGenerator.java b/flink-ml-benchmark/src/main/java/org/apache/flink/ml/benchmark/datagenerator/common/RowGenerator.java new file mode 100644 index 0000000..55fe526 --- /dev/null +++ b/flink-ml-benchmark/src/main/java/org/apache/flink/ml/benchmark/datagenerator/common/RowGenerator.java @@ -0,0 +1,77 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.ml.benchmark.datagenerator.common; + +import org.apache.flink.api.java.tuple.Tuple2; +import org.apache.flink.api.java.typeutils.RowTypeInfo; +import org.apache.flink.configuration.Configuration; +import org.apache.flink.streaming.api.functions.source.RichParallelSourceFunction; +import org.apache.flink.types.Row; + +import java.util.Random; + +/** A parallel source to generate user-defined rows. */ +public abstract class RowGenerator extends RichParallelSourceFunction<Row> { + /** Random instance to generate data. */ + protected Random random; + /** Number of values to generate in total. */ + private final long numValues; + /** The init seed to generate data. */ + private final long initSeed; + /** Number of values to generate in this local task.
*/ + private long numValuesOnThisTask; + /** Whether this source is still running. */ + private volatile boolean isRunning = true; + + public RowGenerator(long numValues, long initSeed) { + this.numValues = numValues; + this.initSeed = initSeed; + } + + @Override + public final void open(Configuration parameters) throws Exception { + super.open(parameters); + int taskIdx = getRuntimeContext().getIndexOfThisSubtask(); + int numTasks = getRuntimeContext().getNumberOfParallelSubtasks(); + random = new Random(Tuple2.of(initSeed, taskIdx).hashCode()); + long div = numValues / numTasks; + long mod = numValues % numTasks; + numValuesOnThisTask = mod > taskIdx ? div + 1 : div; + } + + @Override + public final void run(SourceContext<Row> ctx) throws Exception { + long cnt = 0; + while (isRunning && cnt < numValuesOnThisTask) { + ctx.collect(nextRow()); + cnt++; + } + } + + @Override + public final void cancel() { + isRunning = false; + } + + /** Generates a new data point. */ + protected abstract Row nextRow(); + + /** Returns the output type information for this generator. 
*/ + protected abstract RowTypeInfo getRowTypeInfo(); +} diff --git a/flink-ml-benchmark/src/test/java/org/apache/flink/ml/benchmark/DataGeneratorTest.java b/flink-ml-benchmark/src/test/java/org/apache/flink/ml/benchmark/DataGeneratorTest.java index 937d25b..7d2883a 100644 --- a/flink-ml-benchmark/src/test/java/org/apache/flink/ml/benchmark/DataGeneratorTest.java +++ b/flink-ml-benchmark/src/test/java/org/apache/flink/ml/benchmark/DataGeneratorTest.java @@ -18,17 +18,22 @@ package org.apache.flink.ml.benchmark; +import org.apache.flink.api.common.restartstrategy.RestartStrategies; +import org.apache.flink.configuration.Configuration; import org.apache.flink.ml.benchmark.datagenerator.common.DenseVectorArrayGenerator; import org.apache.flink.ml.benchmark.datagenerator.common.DenseVectorGenerator; import org.apache.flink.ml.benchmark.datagenerator.common.LabeledPointWithWeightGenerator; import org.apache.flink.ml.linalg.DenseVector; -import org.apache.flink.streaming.api.datastream.DataStream; +import org.apache.flink.streaming.api.environment.ExecutionCheckpointingOptions; import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment; import org.apache.flink.table.api.bridge.java.StreamTableEnvironment; import org.apache.flink.types.Row; import org.apache.flink.util.CloseableIterator; +import org.junit.Before; +import org.junit.Rule; import org.junit.Test; +import org.junit.rules.TemporaryFolder; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertNotNull; @@ -36,11 +41,22 @@ import static org.junit.Assert.assertTrue; /** Tests data generators. 
*/ public class DataGeneratorTest { + @Rule public final TemporaryFolder tempFolder = new TemporaryFolder(); + private StreamTableEnvironment tEnv; + + @Before + public void before() { + Configuration config = new Configuration(); + config.set(ExecutionCheckpointingOptions.ENABLE_CHECKPOINTS_AFTER_TASKS_FINISH, true); + StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment(config); + env.setParallelism(4); + env.enableCheckpointing(100); + env.setRestartStrategy(RestartStrategies.noRestart()); + tEnv = StreamTableEnvironment.create(env); + } + @Test public void testDenseVectorGenerator() { - StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment(); - StreamTableEnvironment tEnv = StreamTableEnvironment.create(env); - DenseVectorGenerator generator = new DenseVectorGenerator() .setColNames(new String[] {"denseVector"}) @@ -51,7 +67,7 @@ public class DataGeneratorTest { for (CloseableIterator<Row> it = generator.getData(tEnv)[0].execute().collect(); it.hasNext(); ) { Row row = it.next(); - assertEquals(row.getArity(), 1); + assertEquals(1, row.getArity()); DenseVector vector = (DenseVector) row.getField(generator.getColNames()[0][0]); assertNotNull(vector); assertEquals(vector.size(), generator.getVectorDim()); @@ -61,10 +77,7 @@ public class DataGeneratorTest { } @Test - public void testDenseVectorArrayGenerator() throws Exception { - StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment(); - StreamTableEnvironment tEnv = StreamTableEnvironment.create(env); - + public void testDenseVectorArrayGenerator() { DenseVectorArrayGenerator generator = new DenseVectorArrayGenerator() .setColNames(new String[] {"denseVectors"}) @@ -72,12 +85,13 @@ public class DataGeneratorTest { .setVectorDim(10) .setArraySize(20); - DataStream<DenseVector[]> stream = - tEnv.toDataStream(generator.getData(tEnv)[0], DenseVector[].class); - int count = 0; - for (CloseableIterator<DenseVector[]> it = 
stream.executeAndCollect(); it.hasNext(); ) { - DenseVector[] vectors = it.next(); + for (CloseableIterator<Row> it = generator.getData(tEnv)[0].execute().collect(); + it.hasNext(); ) { + Row row = it.next(); + assertEquals(1, row.getArity()); + DenseVector[] vectors = (DenseVector[]) row.getField(generator.getColNames()[0][0]); + assertNotNull(vectors); assertEquals(generator.getArraySize(), vectors.length); for (DenseVector vector : vectors) { assertEquals(vector.size(), generator.getVectorDim()); @@ -89,9 +103,6 @@ public class DataGeneratorTest { @Test public void testLabeledPointWithWeightGenerator() { - StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment(); - StreamTableEnvironment tEnv = StreamTableEnvironment.create(env); - String featuresCol = "features"; String labelCol = "label"; String weightCol = "weight";