This is an automated email from the ASF dual-hosted git repository.

lindong pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/flink-ml.git
The following commit(s) were added to refs/heads/master by this push:
     new f9f8021  [FLINK-28502] Add Transformer for RegexTokenizer
f9f8021 is described below

commit f9f802125d604f0155221804237fd4140e239602
Author: Zhipeng Zhang <zhangzhipe...@gmail.com>
AuthorDate: Wed Jul 20 17:11:43 2022 +0800

    [FLINK-28502] Add Transformer for RegexTokenizer

    This closes #130.
---
 .../ml/examples/feature/RegexTokenizerExample.java |  64 +++++++
 .../ml/feature/regextokenizer/RegexTokenizer.java  | 128 ++++++++++++++
 .../regextokenizer/RegexTokenizerParams.java       |  78 +++++++++
 .../flink/ml/feature/RegexTokenizerTest.java       | 194 +++++++++++++++++++++
 .../examples/ml/feature/regextokenizer_example.py  |  60 +++++++
 .../pyflink/ml/lib/feature/regextokenizer.py       | 129 ++++++++++++++
 .../ml/lib/feature/tests/test_regextokenizer.py    |  93 ++++++++++
 7 files changed, 746 insertions(+)

diff --git a/flink-ml-examples/src/main/java/org/apache/flink/ml/examples/feature/RegexTokenizerExample.java b/flink-ml-examples/src/main/java/org/apache/flink/ml/examples/feature/RegexTokenizerExample.java
new file mode 100644
index 0000000..4a743bc
--- /dev/null
+++ b/flink-ml-examples/src/main/java/org/apache/flink/ml/examples/feature/RegexTokenizerExample.java
@@ -0,0 +1,64 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.examples.feature;
+
+import org.apache.flink.ml.feature.regextokenizer.RegexTokenizer;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.CloseableIterator;
+
+import java.util.Arrays;
+
+/** Simple program that creates a RegexTokenizer instance and uses it for feature engineering. */
+public class RegexTokenizerExample {
+    public static void main(String[] args) {
+        StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment();
+        StreamTableEnvironment tEnv = StreamTableEnvironment.create(env);
+
+        // Generates input data.
+        DataStream<Row> inputStream =
+                env.fromElements(Row.of("Test for tokenization."), Row.of("Te,st. punct"));
+        Table inputTable = tEnv.fromDataStream(inputStream).as("input");
+
+        // Creates a RegexTokenizer object and initializes its parameters.
+        // Gaps are disabled because the pattern below describes the tokens themselves
+        // rather than the separators between them.
+        RegexTokenizer regexTokenizer =
+                new RegexTokenizer()
+                        .setInputCol("input")
+                        .setOutputCol("output")
+                        .setGaps(false)
+                        .setPattern("\\w+|\\p{Punct}");
+
+        // Uses the RegexTokenizer object for feature transformations.
+        Table outputTable = regexTokenizer.transform(inputTable)[0];
+
+        // Extracts and displays the results.
+        for (CloseableIterator<Row> it = outputTable.execute().collect(); it.hasNext(); ) {
+            Row row = it.next();
+
+            String inputValue = (String) row.getField(regexTokenizer.getInputCol());
+            String[] outputValues = (String[]) row.getField(regexTokenizer.getOutputCol());
+
+            System.out.printf(
+                    "Input Value: %s \tOutput Values: %s\n",
+                    inputValue, Arrays.toString(outputValues));
+        }
+    }
+}
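Since the example disables gaps, the pattern matches tokens rather than separators, and the program should print one line per input row, in some order. The lines below are an illustration derived from the expected tokens in RegexTokenizerTest further down, not captured program output:

    Input Value: Test for tokenization.    Output Values: [test, for, tokenization, .]
    Input Value: Te,st. punct              Output Values: [te, ,, st, ., punct]
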
diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/regextokenizer/RegexTokenizer.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/regextokenizer/RegexTokenizer.java
new file mode 100644
index 0000000..6f5f860
--- /dev/null
+++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/regextokenizer/RegexTokenizer.java
@@ -0,0 +1,128 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.feature.regextokenizer;
+
+import org.apache.flink.ml.api.Transformer;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.table.api.Expressions;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.expressions.Expression;
+import org.apache.flink.table.functions.ScalarFunction;
+import org.apache.flink.util.Preconditions;
+
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.regex.Matcher;
+import java.util.regex.Pattern;
+
+import static org.apache.flink.table.api.Expressions.$;
+
+/**
+ * A Transformer that optionally converts the input string to lowercase and then splits it into
+ * tokens based on a regex pattern. It provides two options to extract tokens:
+ *
+ * <ul>
+ *   <li>if "gaps" is true: uses the provided pattern to split the input string.
+ *   <li>else: repeatedly matches the regex (the provided pattern) against the input string.
+ * </ul>
+ *
+ * <p>Moreover, it provides parameters to filter out tokens shorter than a minimal length and to
+ * control the lowercase conversion. The output for each input string is an array of strings,
+ * which may be empty.
+ */
+public class RegexTokenizer
+        implements Transformer<RegexTokenizer>, RegexTokenizerParams<RegexTokenizer> {
+    private final Map<Param<?>, Object> paramMap = new HashMap<>();
+
+    public RegexTokenizer() {
+        ParamUtils.initializeMapWithDefaultValues(paramMap, this);
+    }
+
+    @Override
+    public Table[] transform(Table... inputs) {
+        Preconditions.checkArgument(inputs.length == 1);
+        Expression tokenizerUdf =
+                Expressions.call(
+                                RegexTokenizerUdf.class,
+                                $(getInputCol()),
+                                getPattern(),
+                                getGaps(),
+                                getToLowercase(),
+                                getMinTokenLength())
+                        .as(getOutputCol());
+        Table output = inputs[0].addColumns(tokenizerUdf);
+        return new Table[] {output};
+    }
+
+    @Override
+    public void save(String path) throws IOException {
+        ReadWriteUtils.saveMetadata(this, path);
+    }
+
+    @Override
+    public Map<Param<?>, Object> getParamMap() {
+        return paramMap;
+    }
+
+    public static RegexTokenizer load(StreamTableEnvironment tEnv, String path) throws IOException {
+        return ReadWriteUtils.loadStageParam(path);
+    }
+
+    /**
+     * The main logic of {@link RegexTokenizer}, which converts the input string to an array of
+     * tokens.
+     */
+    public static class RegexTokenizerUdf extends ScalarFunction {
+
+        public String[] eval(
+                String input,
+                String pattern,
+                boolean gaps,
+                boolean toLowercase,
+                int minTokenLength) {
+            Pattern regPattern = Pattern.compile(pattern);
+            input = toLowercase ? input.toLowerCase() : input;
+
+            List<String> tokens = new ArrayList<>();
+            if (gaps) {
+                String[] tokenArray = regPattern.split(input);
+                for (String token : tokenArray) {
+                    if (token.length() >= minTokenLength) {
+                        tokens.add(token);
+                    }
+                }
+            } else {
+                Matcher matcher = regPattern.matcher(input);
+                while (matcher.find()) {
+                    String token = matcher.group();
+                    if (token.length() >= minTokenLength) {
+                        tokens.add(token);
+                    }
+                }
+            }
+
+            return tokens.toArray(new String[0]);
+        }
+    }
+}
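Since RegexTokenizerUdf is a public ScalarFunction with a plain eval method, the two extraction modes can be sanity-checked without running a Flink job. A minimal sketch (the demo class name is illustrative; the expected arrays match the cases in RegexTokenizerTest below):

    import java.util.Arrays;

    import org.apache.flink.ml.feature.regextokenizer.RegexTokenizer.RegexTokenizerUdf;

    public class RegexTokenizerUdfDemo {
        public static void main(String[] args) {
            RegexTokenizerUdf udf = new RegexTokenizerUdf();

            // gaps = true: the pattern acts as a separator, so "\\s+" splits on whitespace runs.
            String[] byGaps = udf.eval("Te,st. punct", "\\s+", true, true, 1);
            System.out.println(Arrays.toString(byGaps)); // [te,st., punct]

            // gaps = false: the pattern describes the tokens themselves.
            String[] byMatches = udf.eval("Te,st. punct", "\\w+|\\p{Punct}", false, true, 1);
            System.out.println(Arrays.toString(byMatches)); // [te, ,, st, ., punct]
        }
    }
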
diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/regextokenizer/RegexTokenizerParams.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/regextokenizer/RegexTokenizerParams.java
new file mode 100644
index 0000000..029069b
--- /dev/null
+++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/regextokenizer/RegexTokenizerParams.java
@@ -0,0 +1,78 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.feature.regextokenizer;
+
+import org.apache.flink.ml.common.param.HasInputCol;
+import org.apache.flink.ml.common.param.HasOutputCol;
+import org.apache.flink.ml.param.BooleanParam;
+import org.apache.flink.ml.param.IntParam;
+import org.apache.flink.ml.param.ParamValidators;
+import org.apache.flink.ml.param.StringParam;
+
+/**
+ * Params for {@link RegexTokenizer}.
+ *
+ * @param <T> The class type of this instance.
+ */
+public interface RegexTokenizerParams<T> extends HasInputCol<T>, HasOutputCol<T> {
+    IntParam MIN_TOKEN_LENGTH =
+            new IntParam("minTokenLength", "Minimum token length", 1, ParamValidators.gtEq(0));
+
+    BooleanParam GAPS = new BooleanParam("gaps", "Set regex to match gaps or tokens", true);
+
+    StringParam PATTERN = new StringParam("pattern", "Regex pattern used for tokenizing", "\\s+");
+
+    BooleanParam TO_LOWERCASE =
+            new BooleanParam(
+                    "toLowercase",
+                    "Whether to convert all characters to lowercase before tokenizing",
+                    true);
+
+    default T setMinTokenLength(int value) {
+        return set(MIN_TOKEN_LENGTH, value);
+    }
+
+    default int getMinTokenLength() {
+        return get(MIN_TOKEN_LENGTH);
+    }
+
+    default T setGaps(boolean value) {
+        return set(GAPS, value);
+    }
+
+    default boolean getGaps() {
+        return get(GAPS);
+    }
+
+    default T setPattern(String pattern) {
+        return set(PATTERN, pattern);
+    }
+
+    default String getPattern() {
+        return get(PATTERN);
+    }
+
+    default T setToLowercase(boolean toLowercase) {
+        return set(TO_LOWERCASE, toLowercase);
+    }
+
+    default boolean getToLowercase() {
+        return get(TO_LOWERCASE);
+    }
+}
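One subtlety behind the minTokenLength default of 1: Pattern.split can emit empty leading tokens, and the length filter in the UDF is what drops them. A small standalone illustration using plain java.util.regex (identifiers here are demo-only):

    import java.util.Arrays;
    import java.util.regex.Pattern;

    public class MinTokenLengthDemo {
        public static void main(String[] args) {
            // A leading separator makes Pattern.split emit an empty first element...
            String[] raw = Pattern.compile("\\s+").split(" flink ml");
            System.out.println(Arrays.toString(raw)); // [, flink, ml]

            // ...which a length filter like the UDF's (token.length() >= minTokenLength,
            // with minTokenLength defaulting to 1) removes.
            String[] filtered =
                    Arrays.stream(raw).filter(t -> t.length() >= 1).toArray(String[]::new);
            System.out.println(Arrays.toString(filtered)); // [flink, ml]
        }
    }
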
diff --git a/flink-ml-lib/src/test/java/org/apache/flink/ml/feature/RegexTokenizerTest.java b/flink-ml-lib/src/test/java/org/apache/flink/ml/feature/RegexTokenizerTest.java
new file mode 100644
index 0000000..af5381d
--- /dev/null
+++ b/flink-ml-lib/src/test/java/org/apache/flink/ml/feature/RegexTokenizerTest.java
@@ -0,0 +1,194 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.feature;
+
+import org.apache.flink.api.common.restartstrategy.RestartStrategies;
+import org.apache.flink.configuration.Configuration;
+import org.apache.flink.ml.feature.regextokenizer.RegexTokenizer;
+import org.apache.flink.ml.feature.regextokenizer.RegexTokenizerParams;
+import org.apache.flink.ml.util.TestUtils;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.environment.ExecutionCheckpointingOptions;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.table.api.Expressions;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.test.util.AbstractTestBase;
+import org.apache.flink.types.Row;
+
+import org.apache.commons.collections.IteratorUtils;
+import org.junit.Before;
+import org.junit.Test;
+
+import java.util.Arrays;
+import java.util.Comparator;
+import java.util.List;
+
+import static org.junit.Assert.assertArrayEquals;
+import static org.junit.Assert.assertEquals;
+
+/** Tests {@link RegexTokenizer}. */
+public class RegexTokenizerTest extends AbstractTestBase {
+    private StreamTableEnvironment tEnv;
+    private StreamExecutionEnvironment env;
+    private Table inputDataTable;
+
+    private static final List<Row> INPUT =
+            Arrays.asList(Row.of("Test for tokenization."), Row.of("Te,st. punct"));
+
+    @Before
+    public void before() {
+        Configuration config = new Configuration();
+        config.set(ExecutionCheckpointingOptions.ENABLE_CHECKPOINTS_AFTER_TASKS_FINISH, true);
+        env = StreamExecutionEnvironment.getExecutionEnvironment(config);
+        env.setParallelism(4);
+        env.enableCheckpointing(100);
+        env.setRestartStrategy(RestartStrategies.noRestart());
+        tEnv = StreamTableEnvironment.create(env);
+        DataStream<Row> dataStream = env.fromCollection(INPUT);
+        inputDataTable = tEnv.fromDataStream(dataStream).as("input");
+    }
+
+    @Test
+    public void testParam() {
+        RegexTokenizer regexTokenizer = new RegexTokenizer();
+        assertEquals("input", regexTokenizer.getInputCol());
+        assertEquals("output", regexTokenizer.getOutputCol());
+        assertEquals(1, regexTokenizer.getMinTokenLength());
+        assertEquals(true, regexTokenizer.getGaps());
+        assertEquals("\\s+", regexTokenizer.getPattern());
+        assertEquals(true, regexTokenizer.getToLowercase());
+
+        regexTokenizer
+                .setInputCol("testInputCol")
+                .setOutputCol("testOutputCol")
+                .setMinTokenLength(3)
+                .setGaps(false)
+                .setPattern("\\s")
+                .setToLowercase(false);
+
+        assertEquals("testInputCol", regexTokenizer.getInputCol());
+        assertEquals("testOutputCol", regexTokenizer.getOutputCol());
+        assertEquals(3, regexTokenizer.getMinTokenLength());
+        assertEquals(false, regexTokenizer.getGaps());
+        assertEquals("\\s", regexTokenizer.getPattern());
+        assertEquals(false, regexTokenizer.getToLowercase());
+    }
+
+    @Test
+    public void testOutputSchema() {
+        RegexTokenizer regexTokenizer = new RegexTokenizer();
+        inputDataTable =
+                tEnv.fromDataStream(env.fromElements(Row.of("", ""))).as("input", "dummyInput");
+        Table output = regexTokenizer.transform(inputDataTable)[0];
+        assertEquals(
+                Arrays.asList(
+                        regexTokenizer.getInputCol(), "dummyInput", regexTokenizer.getOutputCol()),
+                output.getResolvedSchema().getColumnNames());
+    }
+
+    @Test
+    public void testTransform() throws Exception {
+        List<Row> expectedRows;
+        int minTokenLength = RegexTokenizerParams.MIN_TOKEN_LENGTH.defaultValue;
+        boolean gaps = RegexTokenizerParams.GAPS.defaultValue;
+        String pattern = RegexTokenizerParams.PATTERN.defaultValue;
+        boolean toLowercase = RegexTokenizerParams.TO_LOWERCASE.defaultValue;
+
+        // default option.
+        expectedRows =
+                Arrays.asList(
+                        Row.of((Object) new String[] {"test", "for", "tokenization."}),
+                        Row.of((Object) new String[] {"te,st.", "punct"}));
+        checkTransform(minTokenLength, gaps, pattern, toLowercase, expectedRows);
+
+        // default option except toLowercase = false.
+        expectedRows =
+                Arrays.asList(
+                        Row.of((Object) new String[] {"Test", "for", "tokenization."}),
+                        Row.of((Object) new String[] {"Te,st.", "punct"}));
+        toLowercase = false;
+        checkTransform(minTokenLength, gaps, pattern, toLowercase, expectedRows);
+
+        // default option except gaps = false, pattern = "\\w+|\\p{Punct}".
+        expectedRows =
+                Arrays.asList(
+                        Row.of((Object) new String[] {"test", "for", "tokenization", "."}),
+                        Row.of((Object) new String[] {"te", ",", "st", ".", "punct"}));
+        gaps = false;
+        pattern = "\\w+|\\p{Punct}";
+        toLowercase = true;
+        checkTransform(minTokenLength, gaps, pattern, toLowercase, expectedRows);
+
+        // default option except gaps = false, minTokenLength = 3, pattern = "\\w+|\\p{Punct}".
+        gaps = false;
+        minTokenLength = 3;
+        pattern = "\\w+|\\p{Punct}";
+        expectedRows =
+                Arrays.asList(
+                        Row.of((Object) new String[] {"test", "for", "tokenization"}),
+                        Row.of((Object) new String[] {"punct"}));
+        checkTransform(minTokenLength, gaps, pattern, toLowercase, expectedRows);
+    }
+
+    @Test
+    public void testSaveLoadAndTransform() throws Exception {
+        RegexTokenizer regexTokenizer = new RegexTokenizer();
+        List<Row> expectedRows =
+                Arrays.asList(
+                        Row.of((Object) new String[] {"test", "for", "tokenization."}),
+                        Row.of((Object) new String[] {"te,st.", "punct"}));
+        RegexTokenizer loadedRegexTokenizer =
+                TestUtils.saveAndReload(
+                        tEnv, regexTokenizer, TEMPORARY_FOLDER.newFolder().getAbsolutePath());
+        Table output = loadedRegexTokenizer.transform(inputDataTable)[0];
+        verifyOutputResult(output, loadedRegexTokenizer.getOutputCol(), expectedRows);
+    }
+
+    private void checkTransform(
+            int minTokenLength,
+            boolean gaps,
+            String pattern,
+            boolean toLowercase,
+            List<Row> expectedOutput)
+            throws Exception {
+        RegexTokenizer regexTokenizer =
+                new RegexTokenizer()
+                        .setMinTokenLength(minTokenLength)
+                        .setGaps(gaps)
+                        .setPattern(pattern)
+                        .setToLowercase(toLowercase);
+        Table output = regexTokenizer.transform(inputDataTable)[0];
+        verifyOutputResult(output, regexTokenizer.getOutputCol(), expectedOutput);
+    }
+
+    private void verifyOutputResult(Table output, String outputCol, List<Row> expectedOutput)
+            throws Exception {
+        DataStream<Row> dataStream = tEnv.toDataStream(output.select(Expressions.$(outputCol)));
+        List<Row> results = IteratorUtils.toList(dataStream.executeAndCollect());
+        assertEquals(expectedOutput.size(), results.size());
+        // Row order is nondeterministic with parallel execution, so both lists are sorted by
+        // the first token's hash code to obtain a deterministic order before comparing.
+        results.sort(Comparator.comparingInt(o -> ((String[]) o.getField(0))[0].hashCode()));
+        expectedOutput.sort(Comparator.comparingInt(o -> ((String[]) o.getField(0))[0].hashCode()));
+        for (int i = 0; i < expectedOutput.size(); i++) {
+            assertArrayEquals(
+                    (String[]) results.get(i).getField(0),
+                    (String[]) expectedOutput.get(i).getField(0));
+        }
+    }
+}
diff --git a/flink-ml-python/pyflink/examples/ml/feature/regextokenizer_example.py b/flink-ml-python/pyflink/examples/ml/feature/regextokenizer_example.py
new file mode 100644
index 0000000..0a5b2a4
--- /dev/null
+++ b/flink-ml-python/pyflink/examples/ml/feature/regextokenizer_example.py
@@ -0,0 +1,60 @@
+################################################################################
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+################################################################################

+# Simple program that creates a RegexTokenizer instance and uses it for feature
+# engineering.
+#
+# Before executing this program, please make sure you have followed Flink ML's
+# quick start guideline to set up Flink ML and Flink environment. The guideline
+# can be found at
+#
+# https://nightlies.apache.org/flink/flink-ml-docs-master/docs/try-flink-ml/quick-start/
+
+from pyflink.common import Types
+from pyflink.datastream import StreamExecutionEnvironment
+from pyflink.ml.lib.feature.regextokenizer import RegexTokenizer
+from pyflink.table import StreamTableEnvironment
+
+env = StreamExecutionEnvironment.get_execution_environment()
+
+t_env = StreamTableEnvironment.create(env)
+
+# Generates input data.
+input_data_table = t_env.from_data_stream(
+    env.from_collection([
+        ('Test for tokenization.',),
+        ('Te,st. punct',),
+    ],
+        type_info=Types.ROW_NAMED(
+            ['input'],
+            [Types.STRING()])))
+
+# Creates a RegexTokenizer object and initializes its parameters.
+regex_tokenizer = RegexTokenizer() \
+    .set_input_col("input") \
+    .set_output_col("output")
+
+# Uses the RegexTokenizer object for feature transformations.
+output = regex_tokenizer.transform(input_data_table)[0]
+
+# Extracts and displays the results.
+field_names = output.get_schema().get_field_names()
+for result in t_env.to_data_stream(output).execute_and_collect():
+    input_value = result[field_names.index(regex_tokenizer.get_input_col())]
+    output_value = result[field_names.index(regex_tokenizer.get_output_col())]
+    print('Input Value: ' + str(input_value) + '\tOutput Values: ' + str(output_value))
diff --git a/flink-ml-python/pyflink/ml/lib/feature/regextokenizer.py b/flink-ml-python/pyflink/ml/lib/feature/regextokenizer.py
new file mode 100644
index 0000000..5f7e66a
--- /dev/null
+++ b/flink-ml-python/pyflink/ml/lib/feature/regextokenizer.py
@@ -0,0 +1,129 @@
+################################################################################
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+################################################################################
+
+import typing
+
+from pyflink.ml.core.param import IntParam, BooleanParam, StringParam, ParamValidators
+from pyflink.ml.core.wrapper import JavaWithParams
+from pyflink.ml.lib.feature.common import JavaFeatureTransformer
+from pyflink.ml.lib.param import HasInputCol, HasOutputCol
+
+
+class _RegexTokenizerParams(
+    JavaWithParams,
+    HasInputCol,
+    HasOutputCol
+):
+    """
+    Params for :class:`RegexTokenizer`.
+ """ + + MIN_TOKEN_LENGTH: IntParam = IntParam( + "min_token_length", + "Minimum token length", + 1, + ParamValidators.gt_eq(0) + ) + + GAPS: BooleanParam = BooleanParam( + "gaps", + "Set regex to match gaps or tokens", + True + ) + + PATTERN: StringParam = StringParam( + "pattern", + "Regex pattern used for tokenizing", + "\\s+" + ) + + TO_LOWERCASE: BooleanParam = BooleanParam( + "to_lowercase", + "Whether to convert all characters to lowercase before tokenizing", + True + ) + + def __init__(self, java_params): + super(_RegexTokenizerParams, self).__init__(java_params) + + def set_min_token_length(self, value: int): + return typing.cast(_RegexTokenizerParams, self.set(self.MIN_TOKEN_LENGTH, value)) + + def get_min_token_length(self) -> int: + return self.get(self.MIN_TOKEN_LENGTH) + + def set_gaps(self, value: bool): + return typing.cast(_RegexTokenizerParams, self.set(self.GAPS, value)) + + def get_gaps(self) -> bool: + return self.get(self.GAPS) + + def set_pattern(self, value: str): + return typing.cast(_RegexTokenizerParams, self.set(self.PATTERN, value)) + + def get_pattern(self) -> str: + return self.get(self.PATTERN) + + def set_to_lowercase(self, value: bool): + return typing.cast(_RegexTokenizerParams, self.set(self.TO_LOWERCASE, value)) + + def get_to_lowertcase(self) -> bool: + return self.get(self.TO_LOWERCASE) + + @property + def min_token_length(self) -> int: + return self.get_min_token_length() + + @property + def gaps(self) -> bool: + return self.get_gaps() + + @property + def pattern(self) -> str: + return self.get_pattern() + + @property + def to_lowercase(self): + return self.get_to_lowertcase() + + +class RegexTokenizer(JavaFeatureTransformer, _RegexTokenizerParams): + """ + A Transformer which converts the input string to lowercase and then splits it by white spaces + based on regex. It provides two options to extract tokens: + + <ul> + <li>if "gaps" is true: uses the provided pattern to split the input string. + <li>else: repeatedly matches the regex (the provided pattern) with the input string. + </ul> + + Moreover, it provides parameters to filter tokens with a minimal length and converts input to + lowercase. The output of each input string is an array of strings that can be empty. + + """ + + def __init__(self, java_model=None): + super(RegexTokenizer, self).__init__(java_model) + + @classmethod + def _java_transformer_package_name(cls) -> str: + return "regextokenizer" + + @classmethod + def _java_transformer_class_name(cls) -> str: + return "RegexTokenizer" diff --git a/flink-ml-python/pyflink/ml/lib/feature/tests/test_regextokenizer.py b/flink-ml-python/pyflink/ml/lib/feature/tests/test_regextokenizer.py new file mode 100644 index 0000000..7ef1bff --- /dev/null +++ b/flink-ml-python/pyflink/ml/lib/feature/tests/test_regextokenizer.py @@ -0,0 +1,93 @@ +################################################################################ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+################################################################################
+
+import os
+
+from pyflink.common import Types
+
+from pyflink.ml.lib.feature.regextokenizer import RegexTokenizer
+from pyflink.ml.tests.test_utils import PyFlinkMLTestCase
+
+
+class RegexTokenizerTest(PyFlinkMLTestCase):
+    def setUp(self):
+        super(RegexTokenizerTest, self).setUp()
+        self.input_data_table = self.t_env.from_data_stream(
+            self.env.from_collection([
+                ('Test for tokenization.',),
+                ('Te,st. punct',),
+            ],
+                type_info=Types.ROW_NAMED(
+                    ['input'],
+                    [Types.STRING()])))
+        self.expected_output = [
+            ['test', 'for', 'tokenization.'],
+            ['te,st.', 'punct']
+        ]
+
+    def test_param(self):
+        regex_tokenizer = RegexTokenizer()
+        self.assertEqual('input', regex_tokenizer.input_col)
+        self.assertEqual('output', regex_tokenizer.output_col)
+        self.assertEqual(1, regex_tokenizer.min_token_length)
+        self.assertEqual(True, regex_tokenizer.gaps)
+        self.assertEqual('\\s+', regex_tokenizer.pattern)
+        self.assertEqual(True, regex_tokenizer.to_lowercase)
+
+        regex_tokenizer \
+            .set_input_col("testInputCol") \
+            .set_output_col("testOutputCol") \
+            .set_min_token_length(3) \
+            .set_gaps(False) \
+            .set_pattern("\\s") \
+            .set_to_lowercase(False)
+
+        self.assertEqual('testInputCol', regex_tokenizer.input_col)
+        self.assertEqual('testOutputCol', regex_tokenizer.output_col)
+        self.assertEqual(3, regex_tokenizer.min_token_length)
+        self.assertEqual(False, regex_tokenizer.gaps)
+        self.assertEqual('\\s', regex_tokenizer.pattern)
+        self.assertEqual(False, regex_tokenizer.to_lowercase)
+
+    def test_output_schema(self):
+        regex_tokenizer = RegexTokenizer()
+        input_data_table = self.t_env.from_data_stream(
+            self.env.from_collection([
+                ('', ''),
+            ],
+                type_info=Types.ROW_NAMED(
+                    ['input', 'dummy_input'],
+                    [Types.STRING(), Types.STRING()])))
+        output = regex_tokenizer.transform(input_data_table)[0]
+
+        self.assertEqual(
+            [regex_tokenizer.input_col, 'dummy_input', regex_tokenizer.output_col],
+            output.get_schema().get_field_names())
+
+    def test_save_load_transform(self):
+        regex_tokenizer = RegexTokenizer()
+        path = os.path.join(self.temp_dir, 'test_save_load_transform_regextokenizer')
+        regex_tokenizer.save(path)
+        regex_tokenizer = RegexTokenizer.load(self.t_env, path)
+        output_table = regex_tokenizer.transform(self.input_data_table)[0]
+        predicted_results = [result[1] for result in
+                             self.t_env.to_data_stream(output_table).execute_and_collect()]
+
+        predicted_results.sort(key=lambda x: x[0])
+        self.expected_output.sort(key=lambda x: x[0])
+        self.assertEqual(self.expected_output, predicted_results)