This is an automated email from the ASF dual-hosted git repository. lindong pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/flink-ml.git
The following commit(s) were added to refs/heads/master by this push: new 46fe199 [FLINK-29168] Add Transformer for NGram 46fe199 is described below commit 46fe1995ab058250a0f6acd6383b3a1e6bd3fc0e Author: Zhipeng Zhang <zhangzhipe...@gmail.com> AuthorDate: Tue Sep 6 22:07:47 2022 +0800 [FLINK-29168] Add Transformer for NGram This closes #153. --- docs/content/docs/operators/feature/ngram.md | 155 +++++++++++++++++++++ .../flink/ml/examples/feature/NGramExample.java | 63 +++++++++ .../org/apache/flink/ml/feature/ngram/NGram.java | 101 ++++++++++++++ .../apache/flink/ml/feature/ngram/NGramParams.java | 43 ++++++ .../org/apache/flink/ml/feature/NGramTest.java | 127 +++++++++++++++++ .../pyflink/examples/ml/feature/ngram_example.py | 56 ++++++++ flink-ml-python/pyflink/ml/lib/feature/ngram.py | 74 ++++++++++ .../pyflink/ml/lib/feature/tests/test_ngram.py | 94 +++++++++++++ 8 files changed, 713 insertions(+) diff --git a/docs/content/docs/operators/feature/ngram.md b/docs/content/docs/operators/feature/ngram.md new file mode 100644 index 0000000..f6a8009 --- /dev/null +++ b/docs/content/docs/operators/feature/ngram.md @@ -0,0 +1,155 @@ +--- +title: "NGram" +weight: 1 +type: docs +aliases: +- /operators/feature/ngram.html +--- + +<!-- +Licensed to the Apache Software Foundation (ASF) under one +or more contributor license agreements. See the NOTICE file +distributed with this work for additional information +regarding copyright ownership. The ASF licenses this file +to you under the Apache License, Version 2.0 (the +"License"); you may not use this file except in compliance +with the License. You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, +software distributed under the License is distributed on an +"AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +KIND, either express or implied. See the License for the +specific language governing permissions and limitations +under the License. +--> + +## NGram +NGram converts the input string array into an array of n-grams, +where each n-gram is represented by a space-separated string of +words. If the length of the input array is less than `n`, no +n-grams are returned. + +### Input Columns + +| Param name | Type | Default | Description | +|:-----------|:---------|:----------|:--------------------| +| inputCol | String[] | `"input"` | Input string array. | + +### Output Columns + +| Param name | Type | Default | Description | +|:-----------|:---------|:-----------|:------------| +| outputCol | String[] | `"output"` | N-grams. | + +### Parameters + +| Key | Default | Type | Required | Description | +|:----------|:-----------|:--------|:---------|:-------------------------------------| +| n | `2` | Integer | no | Number of elements per n-gram (>=1). | +| inputCol | `"input"` | String | no | Input column name. | +| outputCol | `"output"` | String | no | Output column name. | + +### Examples + +{{< tabs examples >}} + +{{< tab "Java">}} + +```java +import org.apache.flink.ml.feature.ngram.NGram; +import org.apache.flink.streaming.api.datastream.DataStream; +import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment; +import org.apache.flink.table.api.Table; +import org.apache.flink.table.api.bridge.java.StreamTableEnvironment; +import org.apache.flink.types.Row; +import org.apache.flink.util.CloseableIterator; + +import java.util.Arrays; + +/** Simple program that creates an NGram instance and uses it for feature engineering. */ +public class NGramExample { + public static void main(String[] args) { + StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment(); + StreamTableEnvironment tEnv = StreamTableEnvironment.create(env); + + // Generates input data. + DataStream<Row> inputStream = + env.fromElements( + Row.of((Object) new String[0]), + Row.of((Object) new String[] {"a", "b", "c"}), + Row.of((Object) new String[] {"a", "b", "c", "d"})); + Table inputTable = tEnv.fromDataStream(inputStream).as("input"); + + // Creates an NGram object and initializes its parameters. + NGram nGram = new NGram().setN(2).setInputCol("input").setOutputCol("output"); + + // Uses the NGram object for feature transformations. + Table outputTable = nGram.transform(inputTable)[0]; + + // Extracts and displays the results. + for (CloseableIterator<Row> it = outputTable.execute().collect(); it.hasNext(); ) { + Row row = it.next(); + + String[] inputValue = (String[]) row.getField(nGram.getInputCol()); + String[] outputValue = (String[]) row.getField(nGram.getOutputCol()); + + System.out.printf( + "Input Value: %s \tOutput Value: %s\n", + Arrays.toString(inputValue), Arrays.toString(outputValue)); + } + } +} + +``` + +{{< /tab>}} + +{{< tab "Python">}} + +```python +# Simple program that creates an NGram instance and uses it for feature +# engineering. + +from pyflink.common import Types +from pyflink.datastream import StreamExecutionEnvironment +from pyflink.ml.lib.feature.ngram import NGram +from pyflink.table import StreamTableEnvironment + +env = StreamExecutionEnvironment.get_execution_environment() + +t_env = StreamTableEnvironment.create(env) + +# Generates input data. +input_data_table = t_env.from_data_stream( + env.from_collection([ + ([],), + (['a', 'b', 'c'],), + (['a', 'b', 'c', 'd'],), + ], + type_info=Types.ROW_NAMED( + ["input", ], + [Types.OBJECT_ARRAY(Types.STRING())]))) + +# Creates an NGram object and initializes its parameters. +n_gram = NGram() \ + .set_input_col('input') \ + .set_n(2) \ + .set_output_col('output') + +# Uses the NGram object for feature transformations. +output = n_gram.transform(input_data_table)[0] + +# Extracts and displays the results. +field_names = output.get_schema().get_field_names() +for result in t_env.to_data_stream(output).execute_and_collect(): + input_value = result[field_names.index(n_gram.get_input_col())] + output_value = result[field_names.index(n_gram.get_output_col())] + print('Input Value: ' + ' '.join(input_value) + '\tOutput Value: ' + str(output_value)) + +``` + +{{< /tab>}} + +{{< /tabs>}} diff --git a/flink-ml-examples/src/main/java/org/apache/flink/ml/examples/feature/NGramExample.java b/flink-ml-examples/src/main/java/org/apache/flink/ml/examples/feature/NGramExample.java new file mode 100644 index 0000000..c042cbe --- /dev/null +++ b/flink-ml-examples/src/main/java/org/apache/flink/ml/examples/feature/NGramExample.java @@ -0,0 +1,63 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.ml.examples.feature; + +import org.apache.flink.ml.feature.ngram.NGram; +import org.apache.flink.streaming.api.datastream.DataStream; +import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment; +import org.apache.flink.table.api.Table; +import org.apache.flink.table.api.bridge.java.StreamTableEnvironment; +import org.apache.flink.types.Row; +import org.apache.flink.util.CloseableIterator; + +import java.util.Arrays; + +/** Simple program that creates an NGram instance and uses it for feature engineering. */ +public class NGramExample { + public static void main(String[] args) { + StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment(); + StreamTableEnvironment tEnv = StreamTableEnvironment.create(env); + + // Generates input data. + DataStream<Row> inputStream = + env.fromElements( + Row.of((Object) new String[0]), + Row.of((Object) new String[] {"a", "b", "c"}), + Row.of((Object) new String[] {"a", "b", "c", "d"})); + Table inputTable = tEnv.fromDataStream(inputStream).as("input"); + + // Creates an NGram object and initializes its parameters. + NGram nGram = new NGram().setN(2).setInputCol("input").setOutputCol("output"); + + // Uses the NGram object for feature transformations. + Table outputTable = nGram.transform(inputTable)[0]; + + // Extracts and displays the results. + for (CloseableIterator<Row> it = outputTable.execute().collect(); it.hasNext(); ) { + Row row = it.next(); + + String[] inputValue = (String[]) row.getField(nGram.getInputCol()); + String[] outputValue = (String[]) row.getField(nGram.getOutputCol()); + + System.out.printf( + "Input Value: %s \tOutput Value: %s\n", + Arrays.toString(inputValue), Arrays.toString(outputValue)); + } + } +} diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/ngram/NGram.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/ngram/NGram.java new file mode 100644 index 0000000..f2464a8 --- /dev/null +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/ngram/NGram.java @@ -0,0 +1,101 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.ml.feature.ngram; + +import org.apache.flink.ml.api.Transformer; +import org.apache.flink.ml.param.Param; +import org.apache.flink.ml.util.ParamUtils; +import org.apache.flink.ml.util.ReadWriteUtils; +import org.apache.flink.table.api.Expressions; +import org.apache.flink.table.api.Table; +import org.apache.flink.table.api.bridge.java.StreamTableEnvironment; +import org.apache.flink.table.expressions.Expression; +import org.apache.flink.table.functions.ScalarFunction; +import org.apache.flink.util.Preconditions; + +import java.io.IOException; +import java.util.HashMap; +import java.util.Map; + +import static org.apache.flink.table.api.Expressions.$; + +/** + * A Transformer that converts the input string array into an array of n-grams, where each n-gram is + * represented by a space-separated string of words. If the length of the input array is less than + * `n`, no n-grams are returned. + * + * <p>See https://en.wikipedia.org/wiki/N-gram. + */ +public class NGram implements Transformer<NGram>, NGramParams<NGram> { + private final Map<Param<?>, Object> paramMap = new HashMap<>(); + + public NGram() { + ParamUtils.initializeMapWithDefaultValues(paramMap, this); + } + + @Override + public Table[] transform(Table... inputs) { + Preconditions.checkArgument(inputs.length == 1); + Expression nGramUdf = + Expressions.call(NGramUdf.class, $(getInputCol()), getN()).as(getOutputCol()); + Table output = inputs[0].addColumns(nGramUdf); + return new Table[] {output}; + } + + @Override + public void save(String path) throws IOException { + ReadWriteUtils.saveMetadata(this, path); + } + + @Override + public Map<Param<?>, Object> getParamMap() { + return paramMap; + } + + public static NGram load(StreamTableEnvironment tEnv, String path) throws IOException { + return ReadWriteUtils.loadStageParam(path); + } + + /** + * The main logic of {@link NGram}, which converts the input string array to an array of + * n-grams. + */ + public static class NGramUdf extends ScalarFunction { + + public String[] eval(String[] items, int n) { + int numItems = items.length; + if (n > numItems) { + return new String[0]; + } else { + String[] output = new String[numItems - n + 1]; + for (int i = 0; i < numItems - n + 1; i++) { + StringBuilder stringBuilder = new StringBuilder(); + for (int j = 0; j < n; j++) { + stringBuilder.append(items[i + j]); + stringBuilder.append(" "); + } + + output[i] = stringBuilder.deleteCharAt(stringBuilder.length() - 1).toString(); + } + + return output; + } + } + } +} diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/ngram/NGramParams.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/ngram/NGramParams.java new file mode 100644 index 0000000..57bf3a0 --- /dev/null +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/ngram/NGramParams.java @@ -0,0 +1,43 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.ml.feature.ngram; + +import org.apache.flink.ml.common.param.HasInputCol; +import org.apache.flink.ml.common.param.HasOutputCol; +import org.apache.flink.ml.param.IntParam; +import org.apache.flink.ml.param.Param; +import org.apache.flink.ml.param.ParamValidators; + +/** + * Params of {@link NGram}. + * + * @param <T> The class type of this instance. + */ +public interface NGramParams<T> extends HasInputCol<T>, HasOutputCol<T> { + Param<Integer> N = + new IntParam("n", "Number of elements per n-gram (>=1).", 2, ParamValidators.gtEq(1)); + + default int getN() { + return get(N); + } + + default T setN(int value) { + return set(N, value); + } +} diff --git a/flink-ml-lib/src/test/java/org/apache/flink/ml/feature/NGramTest.java b/flink-ml-lib/src/test/java/org/apache/flink/ml/feature/NGramTest.java new file mode 100644 index 0000000..2c9f162 --- /dev/null +++ b/flink-ml-lib/src/test/java/org/apache/flink/ml/feature/NGramTest.java @@ -0,0 +1,127 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.ml.feature; + +import org.apache.flink.api.common.restartstrategy.RestartStrategies; +import org.apache.flink.configuration.Configuration; +import org.apache.flink.ml.feature.ngram.NGram; +import org.apache.flink.ml.util.TestUtils; +import org.apache.flink.streaming.api.datastream.DataStream; +import org.apache.flink.streaming.api.environment.ExecutionCheckpointingOptions; +import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment; +import org.apache.flink.table.api.Expressions; +import org.apache.flink.table.api.Table; +import org.apache.flink.table.api.bridge.java.StreamTableEnvironment; +import org.apache.flink.test.util.AbstractTestBase; +import org.apache.flink.types.Row; + +import org.apache.commons.collections.IteratorUtils; +import org.junit.Before; +import org.junit.Test; + +import java.util.Arrays; +import java.util.Comparator; +import java.util.List; + +import static org.junit.Assert.assertArrayEquals; +import static org.junit.Assert.assertEquals; + +/** Tests {@link NGram}. */ +public class NGramTest extends AbstractTestBase { + private StreamTableEnvironment tEnv; + private StreamExecutionEnvironment env; + private Table inputDataTable; + private static final List<Row> EXPECTED_OUTPUT = + Arrays.asList( + Row.of((Object) new String[0]), + Row.of((Object) new String[] {"a b", "b c"}), + Row.of((Object) new String[] {"a b", "b c", "c d"})); + + @Before + public void before() { + Configuration config = new Configuration(); + config.set(ExecutionCheckpointingOptions.ENABLE_CHECKPOINTS_AFTER_TASKS_FINISH, true); + env = StreamExecutionEnvironment.getExecutionEnvironment(config); + env.setParallelism(4); + env.enableCheckpointing(100); + env.setRestartStrategy(RestartStrategies.noRestart()); + tEnv = StreamTableEnvironment.create(env); + List<Row> input = + Arrays.asList( + Row.of((Object) new String[0]), + Row.of((Object) new String[] {"a", "b", "c"}), + Row.of((Object) new String[] {"a", "b", "c", "d"})); + DataStream<Row> dataStream = env.fromCollection(input); + inputDataTable = tEnv.fromDataStream(dataStream).as("input"); + } + + @Test + public void testParam() { + NGram nGram = new NGram(); + assertEquals("input", nGram.getInputCol()); + assertEquals("output", nGram.getOutputCol()); + assertEquals(2, nGram.getN()); + + nGram.setInputCol("testInputCol").setOutputCol("testOutputCol").setN(5); + assertEquals("testInputCol", nGram.getInputCol()); + assertEquals("testOutputCol", nGram.getOutputCol()); + assertEquals(5, nGram.getN()); + } + + @Test + public void testOutputSchema() { + NGram nGram = new NGram(); + inputDataTable = + tEnv.fromDataStream(env.fromElements(Row.of(new String[] {""}, ""))) + .as("input", "dummyInput"); + Table output = nGram.transform(inputDataTable)[0]; + assertEquals( + Arrays.asList(nGram.getInputCol(), "dummyInput", nGram.getOutputCol()), + output.getResolvedSchema().getColumnNames()); + } + + @Test + public void testTransform() throws Exception { + NGram nGram = new NGram(); + Table output = nGram.transform(inputDataTable)[0]; + verifyOutputResult(output, nGram.getOutputCol()); + } + + @Test + public void testSaveLoadAndTransform() throws Exception { + NGram nGram = new NGram(); + NGram loadedNGram = + TestUtils.saveAndReload( + tEnv, nGram, TEMPORARY_FOLDER.newFolder().getAbsolutePath()); + Table output = loadedNGram.transform(inputDataTable)[0]; + verifyOutputResult(output, loadedNGram.getOutputCol()); + } + + private void verifyOutputResult(Table output, String outputCol) throws Exception { + DataStream<Row> dataStream = tEnv.toDataStream(output.select(Expressions.$(outputCol))); + List<Row> actualResults = IteratorUtils.toList(dataStream.executeAndCollect()); + assertEquals(EXPECTED_OUTPUT.size(), actualResults.size()); + actualResults.sort(Comparator.comparingInt(o -> ((String[]) o.getField(0)).length)); + for (int i = 0; i < EXPECTED_OUTPUT.size(); i++) { + assertArrayEquals( + (String[]) EXPECTED_OUTPUT.get(i).getField(0), + (String[]) actualResults.get(i).getField(0)); + } + } +} diff --git a/flink-ml-python/pyflink/examples/ml/feature/ngram_example.py b/flink-ml-python/pyflink/examples/ml/feature/ngram_example.py new file mode 100644 index 0000000..37f71f3 --- /dev/null +++ b/flink-ml-python/pyflink/examples/ml/feature/ngram_example.py @@ -0,0 +1,56 @@ +################################################################################ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +################################################################################ + +# Simple program that creates an NGram instance and uses it for feature +# engineering. + +from pyflink.common import Types +from pyflink.datastream import StreamExecutionEnvironment +from pyflink.ml.lib.feature.ngram import NGram +from pyflink.table import StreamTableEnvironment + +env = StreamExecutionEnvironment.get_execution_environment() + +t_env = StreamTableEnvironment.create(env) + +# Generates input data. +input_data_table = t_env.from_data_stream( + env.from_collection([ + ([],), + (['a', 'b', 'c'],), + (['a', 'b', 'c', 'd'],), + ], + type_info=Types.ROW_NAMED( + ["input", ], + [Types.OBJECT_ARRAY(Types.STRING())]))) + +# Creates an NGram object and initializes its parameters. +n_gram = NGram() \ + .set_input_col('input') \ + .set_n(2) \ + .set_output_col('output') + +# Uses the NGram object for feature transformations. +output = n_gram.transform(input_data_table)[0] + +# Extracts and displays the results. +field_names = output.get_schema().get_field_names() +for result in t_env.to_data_stream(output).execute_and_collect(): + input_value = result[field_names.index(n_gram.get_input_col())] + output_value = result[field_names.index(n_gram.get_output_col())] + print('Input Value: ' + ' '.join(input_value) + '\tOutput Value: ' + str(output_value)) diff --git a/flink-ml-python/pyflink/ml/lib/feature/ngram.py b/flink-ml-python/pyflink/ml/lib/feature/ngram.py new file mode 100644 index 0000000..10c17fe --- /dev/null +++ b/flink-ml-python/pyflink/ml/lib/feature/ngram.py @@ -0,0 +1,74 @@ +################################################################################ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +################################################################################ +import typing + +from pyflink.ml.core.param import IntParam, ParamValidators +from pyflink.ml.core.wrapper import JavaWithParams +from pyflink.ml.lib.feature.common import JavaFeatureTransformer +from pyflink.ml.lib.param import HasInputCol, HasOutputCol + + +class _NGramParams( + JavaWithParams, + HasInputCol, + HasOutputCol +): + """ + Params for :class:`NGram`. + """ + + N: IntParam = IntParam( + "n", + "Number of elements per n-gram (>=1).", + 2, + ParamValidators.gt_eq(1) + ) + + def __init__(self, java_params): + super(_NGramParams, self).__init__(java_params) + + def set_n(self, value: int): + return typing.cast(_NGramParams, self.set(self.N, value)) + + def get_n(self) -> int: + return self.get(self.N) + + @property + def n(self) -> int: + return self.get_n() + + +class NGram(JavaFeatureTransformer, _NGramParams): + """ + A Transformer that converts the input string array into an array of n-grams, + where each n-gram is represented by a space-separated string of words. If + the length of the input array is less than `n`, no n-grams are returned. + + See https://en.wikipedia.org/wiki/N-gram. + """ + + def __init__(self, java_model=None): + super(NGram, self).__init__(java_model) + + @classmethod + def _java_transformer_package_name(cls) -> str: + return "ngram" + + @classmethod + def _java_transformer_class_name(cls) -> str: + return "NGram" diff --git a/flink-ml-python/pyflink/ml/lib/feature/tests/test_ngram.py b/flink-ml-python/pyflink/ml/lib/feature/tests/test_ngram.py new file mode 100644 index 0000000..84a3f1e --- /dev/null +++ b/flink-ml-python/pyflink/ml/lib/feature/tests/test_ngram.py @@ -0,0 +1,94 @@ +################################################################################ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +################################################################################ +import os + +from pyflink.common import Types + +from pyflink.ml.lib.feature.ngram import NGram +from pyflink.ml.tests.test_utils import PyFlinkMLTestCase + + +class NGramTest(PyFlinkMLTestCase): + def setUp(self): + super(NGramTest, self).setUp() + self.input_data_table = self.t_env.from_data_stream( + self.env.from_collection([ + ([],), + (['a', 'b', 'c'],), + (['a', 'b', 'c', 'd'],), + ], + type_info=Types.ROW_NAMED( + ["input", ], + [Types.OBJECT_ARRAY(Types.STRING())]))) + + self.expected_output = [ + [], + ['a b', 'b c'], + ['a b', 'b c', 'c d'] + ] + + def test_param(self): + n_gram = NGram() + self.assertEqual('input', n_gram.input_col) + self.assertEqual('output', n_gram.output_col) + self.assertEqual(2, n_gram.n) + + n_gram.set_input_col("test_input_col") \ + .set_output_col("test_output_col") \ + .set_n(5) + + self.assertEqual('test_input_col', n_gram.input_col) + self.assertEqual('test_output_col', n_gram.output_col) + self.assertEqual(5, n_gram.n) + + def test_output_schema(self): + n_gram = NGram() + input_data_table = self.t_env.from_data_stream( + self.env.from_collection([ + ([''], ''), + ], + type_info=Types.ROW_NAMED( + ['input', 'dummy_input'], + [Types.OBJECT_ARRAY(Types.STRING()), Types.STRING()]))) + + output = n_gram.transform(input_data_table)[0] + + self.assertEqual([n_gram.input_col, 'dummy_input', + n_gram.output_col], output.get_schema().get_field_names()) + + def verify_output_result(self, output_table): + predicted_result = [result[1] for result in + self.t_env.to_data_stream(output_table).execute_and_collect()] + predicted_result.sort(key=lambda x: len(x)) + self.assertEqual(len(self.expected_output), len(predicted_result)) + + for i in range(len(self.expected_output)): + self.assertEqual(self.expected_output[i], predicted_result[i]) + + def test_transform(self): + n_gram = NGram() + output = n_gram.transform(self.input_data_table)[0] + self.verify_output_result(output) + + def test_save_load_transform(self): + n_gram = NGram() + path = os.path.join(self.temp_dir, 'test_save_load_transform_ngram') + n_gram.save(path) + n_gram = NGram.load(self.t_env, path) + output = n_gram.transform(self.input_data_table)[0] + self.verify_output_result(output)