This is an automated email from the ASF dual-hosted git repository. jincheng pushed a commit to branch release-1.9 in repository https://gitbox.apache.org/repos/asf/flink.git
The following commit(s) were added to refs/heads/release-1.9 by this push: new 39f4fc2 [FLINK-13368][python] Add Configuration Class for Python Table API to Align with Java. 39f4fc2 is described below commit 39f4fc28fddcd3b26058ab8f7bdfc1717b63b91f Author: Wei Zhong <weizhong0...@gmail.com> AuthorDate: Mon Jul 22 20:36:58 2019 +0800 [FLINK-13368][python] Add Configuration Class for Python Table API to Align with Java. This closes #9199 --- flink-python/pyflink/common/__init__.py | 2 + flink-python/pyflink/common/configuration.py | 254 +++++++++++++++++++++ .../pyflink/common/tests/test_configuration.py | 165 +++++++++++++ flink-python/pyflink/table/table_config.py | 20 ++ .../pyflink/table/tests/test_table_config.py | 97 ++++++++ .../table/tests/test_table_config_completeness.py | 2 +- .../table/tests/test_table_environment_api.py | 38 --- 7 files changed, 539 insertions(+), 39 deletions(-) diff --git a/flink-python/pyflink/common/__init__.py b/flink-python/pyflink/common/__init__.py index ceef3c8..ca27df7 100644 --- a/flink-python/pyflink/common/__init__.py +++ b/flink-python/pyflink/common/__init__.py @@ -22,12 +22,14 @@ Important classes used by both Flink Streaming and Batch API: - :class:`ExecutionConfig`: A config to define the behavior of the program execution. 
""" +from pyflink.common.configuration import Configuration from pyflink.common.execution_config import ExecutionConfig from pyflink.common.execution_mode import ExecutionMode from pyflink.common.input_dependency_constraint import InputDependencyConstraint from pyflink.common.restart_strategy import RestartStrategies, RestartStrategyConfiguration __all__ = [ + 'Configuration', 'ExecutionConfig', 'ExecutionMode', 'InputDependencyConstraint', diff --git a/flink-python/pyflink/common/configuration.py b/flink-python/pyflink/common/configuration.py new file mode 100644 index 0000000..0adc463 --- /dev/null +++ b/flink-python/pyflink/common/configuration.py @@ -0,0 +1,254 @@ +################################################################################ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +################################################################################ +from pyflink.java_gateway import get_gateway + + +class Configuration: + """ + Lightweight configuration object which stores key/value pairs. + """ + + def __init__(self, other=None, j_configuration=None): + """ + Creates a new configuration. 
+ + :param other: Optional, if this parameter exists, creates a new configuration with a + copy of the given configuration. + :type other: Configuration + :param j_configuration: Optional, the py4j java configuration object, if this parameter + exists, creates a wrapper for it. + :type j_configuration: py4j.java_gateway.JavaObject + """ + if j_configuration is not None: + self._j_configuration = j_configuration + else: + gateway = get_gateway() + JConfiguration = gateway.jvm.org.apache.flink.configuration.Configuration + if other is not None: + self._j_configuration = JConfiguration(other._j_configuration) + else: + self._j_configuration = JConfiguration() + + def get_string(self, key, default_value): + """ + Returns the value associated with the given key as a string. + + :param key: The key pointing to the associated value. + :type key: str + :param default_value: The default value which is returned in case there is no value + associated with the given key. + :type default_value: str + :return: The (default) value associated with the given key. + :rtype: str + """ + return self._j_configuration.getString(key, default_value) + + def set_string(self, key, value): + """ + Adds the given key/value pair to the configuration object. + + :param key: The key of the key/value pair to be added. + :type key: str + :param value: The value of the key/value pair to be added. + :type value: str + """ + self._j_configuration.setString(key, value) + + def get_integer(self, key, default_value): + """ + Returns the value associated with the given key as an integer. + + :param key: The key pointing to the associated value. + :type key: str + :param default_value: The default value which is returned in case there is no value + associated with the given key. + :type default_value: int + :return: The (default) value associated with the given key. 
+ :rtype: int + """ + return self._j_configuration.getLong(key, default_value) + + def set_integer(self, key, value): + """ + Adds the given key/value pair to the configuration object. + + :param key: The key of the key/value pair to be added. + :type key: str + :param value: The value of the key/value pair to be added. + :type value: int + """ + self._j_configuration.setLong(key, value) + + def get_boolean(self, key, default_value): + """ + Returns the value associated with the given key as a boolean. + + :param key: The key pointing to the associated value. + :type key: str + :param default_value: The default value which is returned in case there is no value + associated with the given key. + :type default_value: bool + :return: The (default) value associated with the given key. + :rtype: bool + """ + return self._j_configuration.getBoolean(key, default_value) + + def set_boolean(self, key, value): + """ + Adds the given key/value pair to the configuration object. + + :param key: The key of the key/value pair to be added. + :type key: str + :param value: The value of the key/value pair to be added. + :type value: bool + """ + self._j_configuration.setBoolean(key, value) + + def get_float(self, key, default_value): + """ + Returns the value associated with the given key as a float. + + :param key: The key pointing to the associated value. + :type key: str + :param default_value: The default value which is returned in case there is no value + associated with the given key. + :type default_value: float + :return: The (default) value associated with the given key. + :rtype: float + """ + return self._j_configuration.getDouble(key, float(default_value)) + + def set_float(self, key, value): + """ + Adds the given key/value pair to the configuration object. + + :param key: The key of the key/value pair to be added. + :type key: str + :param value: The value of the key/value pair to be added.
+ :type value: float + """ + self._j_configuration.setDouble(key, float(value)) + + def get_bytearray(self, key, default_value): + """ + Returns the value associated with the given key as a byte array. + + :param key: The key pointing to the associated value. + :type key: str + :param default_value: The default value which is returned in case there is no value + associated with the given key. + :type default_value: bytearray + :return: The (default) value associated with the given key. + :rtype: bytearray + """ + return bytearray(self._j_configuration.getBytes(key, default_value)) + + def set_bytearray(self, key, value): + """ + Adds the given byte array to the configuration object. + + :param key: The key under which the bytes are added. + :type key: str + :param value: The byte array to be added. + :type value: bytearray + """ + self._j_configuration.setBytes(key, value) + + def key_set(self): + """ + Returns the keys of all key/value pairs stored inside this configuration object. + + :return: The keys of all key/value pairs stored inside this configuration object. + :rtype: set + """ + return set(self._j_configuration.keySet()) + + def add_all_to_dict(self, target_dict): + """ + Adds all entries in this configuration to the given dict. + + :param target_dict: The dict to be updated. + :type target_dict: dict + """ + properties = get_gateway().jvm.java.util.Properties() + self._j_configuration.addAllToProperties(properties) + target_dict.update(properties) + + def add_all(self, other, prefix=None): + """ + Adds all entries from the given configuration into this configuration. The keys are + prepended with the given prefix if one is given. + + :param other: The configuration whose entries are added to this configuration. + :type other: Configuration + :param prefix: Optional, the prefix to prepend.
+ :type prefix: str + """ + if prefix is None: + self._j_configuration.addAll(other._j_configuration) + else: + self._j_configuration.addAll(other._j_configuration, prefix) + + def contains_key(self, key): + """ + Checks whether there is an entry with the specified key. + + :param key: Key of entry. + :type key: str + :return: True if the key is stored, false otherwise. + :rtype: bool + """ + return self._j_configuration.containsKey(key) + + def to_dict(self): + """ + Converts the configuration into a dict representation of string key-value pairs. + + :return: Dict representation of the configuration. + :rtype: dict[str, str] + """ + return dict(self._j_configuration.toMap()) + + def remove_config(self, key): + """ + Removes given config key from the configuration. + + :param key: The config key to remove. + :type key: str + :return: True if config has been removed, false otherwise. + :rtype: bool + """ + gateway = get_gateway() + JConfigOptions = gateway.jvm.org.apache.flink.configuration.ConfigOptions + config_option = JConfigOptions.key(key).noDefaultValue() + return self._j_configuration.removeConfig(config_option) + + def __deepcopy__(self, memodict=None): + return Configuration(j_configuration=self._j_configuration.clone()) + + def __hash__(self): + return self._j_configuration.hashCode() + + def __eq__(self, other): + if isinstance(other, Configuration): + return self._j_configuration.equals(other._j_configuration) + else: + return False + + def __str__(self): + return self._j_configuration.toString() diff --git a/flink-python/pyflink/common/tests/test_configuration.py b/flink-python/pyflink/common/tests/test_configuration.py new file mode 100644 index 0000000..ab2fd27 --- /dev/null +++ b/flink-python/pyflink/common/tests/test_configuration.py @@ -0,0 +1,165 @@ +################################################################################ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements.
See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +################################################################################ +from copy import deepcopy + +from pyflink.common import Configuration +from pyflink.testing.test_case_utils import PyFlinkTestCase + + +class ConfigurationTests(PyFlinkTestCase): + + def test_init(self): + conf = Configuration() + + self.assertEqual(conf.to_dict(), dict()) + + conf.set_string("k1", "v1") + conf2 = Configuration(conf) + + self.assertEqual(conf2.to_dict(), {"k1": "v1"}) + + def test_getters_and_setters(self): + conf = Configuration() + + conf.set_string("str", "v1") + conf.set_integer("int", 2) + conf.set_boolean("bool", True) + conf.set_float("float", 0.5) + conf.set_bytearray("bytearray", bytearray([1, 2, 3])) + + str_value = conf.get_string("str", "") + int_value = conf.get_integer("int", 0) + bool_value = conf.get_boolean("bool", False) + float_value = conf.get_float("float", 0) + bytearray_value = conf.get_bytearray("bytearray", bytearray()) + + self.assertEqual(str_value, "v1") + self.assertEqual(int_value, 2) + self.assertEqual(bool_value, True) + self.assertEqual(float_value, 0.5) + self.assertEqual(bytearray_value, bytearray([1, 2, 3])) + + def test_key_set(self): + conf = Configuration() + + conf.set_string("k1", "v1") + conf.set_string("k2", "v2") + conf.set_string("k3", "v3") + 
key_set = conf.key_set() + + self.assertEqual(key_set, {"k1", "k2", "k3"}) + + def test_add_all_to_dict(self): + conf = Configuration() + + conf.set_string("k1", "v1") + conf.set_integer("k2", 1) + conf.set_float("k3", 1.2) + conf.set_boolean("k4", True) + conf.set_bytearray("k5", bytearray([1, 2, 3])) + target_dict = dict() + conf.add_all_to_dict(target_dict) + + self.assertEqual(target_dict, {"k1": "v1", + "k2": 1, + "k3": 1.2, + "k4": True, + "k5": bytearray([1, 2, 3])}) + + def test_add_all(self): + conf = Configuration() + conf.set_string("k1", "v1") + conf2 = Configuration() + + conf2.add_all(conf) + value1 = conf2.get_string("k1", "") + + self.assertEqual(value1, "v1") + + conf2.add_all(conf, "conf_") + value2 = conf2.get_string("conf_k1", "") + + self.assertEqual(value2, "v1") + + def test_deepcopy(self): + conf = Configuration() + conf.set_string("k1", "v1") + + conf2 = deepcopy(conf) + + self.assertEqual(conf2, conf) + + conf2.set_string("k1", "v2") + + self.assertNotEqual(conf2, conf) + + def test_contains_key(self): + conf = Configuration() + conf.set_string("k1", "v1") + + contains_k1 = conf.contains_key("k1") + contains_k2 = conf.contains_key("k2") + + self.assertTrue(contains_k1) + self.assertFalse(contains_k2) + + def test_to_dict(self): + conf = Configuration() + conf.set_string("k1", "v1") + conf.set_integer("k2", 1) + conf.set_float("k3", 1.2) + conf.set_boolean("k4", True) + + target_dict = conf.to_dict() + + self.assertEqual(target_dict, {"k1": "v1", "k2": "1", "k3": "1.2", "k4": "true"}) + + def test_remove_config(self): + conf = Configuration() + conf.set_string("k1", "v1") + conf.set_integer("k2", 1) + + self.assertTrue(conf.contains_key("k1")) + self.assertTrue(conf.contains_key("k2")) + + self.assertTrue(conf.remove_config("k1")) + self.assertFalse(conf.remove_config("k1")) + + self.assertFalse(conf.contains_key("k1")) + + conf.remove_config("k2") + + self.assertFalse(conf.contains_key("k2")) + + def test_hash_equal_str(self): + conf = 
Configuration() + conf2 = Configuration() + + conf.set_string("k1", "v1") + conf.set_integer("k2", 1) + conf2.set_string("k1", "v1") + + self.assertNotEqual(hash(conf), hash(conf2)) + self.assertNotEqual(conf, conf2) + + conf2.set_integer("k2", 1) + + self.assertEqual(hash(conf), hash(conf2)) + self.assertEqual(conf, conf2) + + self.assertEqual(str(conf), "{k1=v1, k2=1}") diff --git a/flink-python/pyflink/table/table_config.py b/flink-python/pyflink/table/table_config.py index 62b00ba..3a84fca 100644 --- a/flink-python/pyflink/table/table_config.py +++ b/flink-python/pyflink/table/table_config.py @@ -18,6 +18,8 @@ import sys from py4j.compat import long + +from pyflink.common import Configuration from pyflink.java_gateway import get_gateway __all__ = ['TableConfig'] @@ -227,6 +229,24 @@ class TableConfig(object): rounding_mode = j_math_context.getRoundingMode().name() return precision, rounding_mode + def get_configuration(self): + """ + Returns all key/value configuration. + + :return: All key/value configuration. + :rtype: Configuration + """ + return Configuration(j_configuration=self._j_table_config.getConfiguration()) + + def add_configuration(self, configuration): + """ + Adds the given key/value configuration. + + :param configuration: The given key/value configuration. + :type configuration: Configuration + """ + self._j_table_config.addConfiguration(configuration._j_configuration) + @staticmethod def get_default(): """ diff --git a/flink-python/pyflink/table/tests/test_table_config.py b/flink-python/pyflink/table/tests/test_table_config.py new file mode 100644 index 0000000..7393eb2 --- /dev/null +++ b/flink-python/pyflink/table/tests/test_table_config.py @@ -0,0 +1,97 @@ +################################################################################ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. 
See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +################################################################################ +import datetime + +from pyflink.common import Configuration +from pyflink.table import TableConfig +from pyflink.testing.test_case_utils import PyFlinkTestCase + + +class TableConfigTests(PyFlinkTestCase): + + def test_get_set_idle_state_retention_time(self): + table_config = TableConfig.get_default() + + table_config.set_idle_state_retention_time( + datetime.timedelta(days=1), datetime.timedelta(days=2)) + + self.assertEqual(2 * 24 * 3600 * 1000, table_config.get_max_idle_state_retention_time()) + self.assertEqual(24 * 3600 * 1000, table_config.get_min_idle_state_retention_time()) + + def test_get_set_decimal_context(self): + table_config = TableConfig.get_default() + + table_config.set_decimal_context(20, "UNNECESSARY") + self.assertEqual((20, "UNNECESSARY"), table_config.get_decimal_context()) + table_config.set_decimal_context(20, "HALF_EVEN") + self.assertEqual((20, "HALF_EVEN"), table_config.get_decimal_context()) + table_config.set_decimal_context(20, "HALF_DOWN") + self.assertEqual((20, "HALF_DOWN"), table_config.get_decimal_context()) + table_config.set_decimal_context(20, "HALF_UP") + self.assertEqual((20, "HALF_UP"), table_config.get_decimal_context()) + table_config.set_decimal_context(20, "FLOOR") 
+ self.assertEqual((20, "FLOOR"), table_config.get_decimal_context()) + table_config.set_decimal_context(20, "CEILING") + self.assertEqual((20, "CEILING"), table_config.get_decimal_context()) + table_config.set_decimal_context(20, "DOWN") + self.assertEqual((20, "DOWN"), table_config.get_decimal_context()) + table_config.set_decimal_context(20, "UP") + self.assertEqual((20, "UP"), table_config.get_decimal_context()) + + def test_get_set_local_timezone(self): + table_config = TableConfig.get_default() + + table_config.set_local_timezone("Asia/Shanghai") + timezone = table_config.get_local_timezone() + + self.assertEqual(timezone, "Asia/Shanghai") + + def test_get_set_max_generated_code_length(self): + table_config = TableConfig.get_default() + + table_config.set_max_generated_code_length(32000) + max_generated_code_length = table_config.get_max_generated_code_length() + + self.assertEqual(max_generated_code_length, 32000) + + def test_get_set_null_check(self): + table_config = TableConfig.get_default() + + null_check = table_config.get_null_check() + self.assertTrue(null_check) + + table_config.set_null_check(False) + null_check = table_config.get_null_check() + + self.assertFalse(null_check) + + def test_get_configuration(self): + table_config = TableConfig.get_default() + + table_config.get_configuration().set_string("k1", "v1") + + self.assertEqual(table_config.get_configuration().get_string("k1", ""), "v1") + + def test_add_configuration(self): + table_config = TableConfig.get_default() + configuration = Configuration() + configuration.set_string("k1", "v1") + + table_config.add_configuration(configuration) + + self.assertEqual(table_config.get_configuration().get_string("k1", ""), "v1") diff --git a/flink-python/pyflink/table/tests/test_table_config_completeness.py b/flink-python/pyflink/table/tests/test_table_config_completeness.py index db7eefa..0326c76 100644 --- a/flink-python/pyflink/table/tests/test_table_config_completeness.py +++ 
b/flink-python/pyflink/table/tests/test_table_config_completeness.py @@ -39,7 +39,7 @@ class TableConfigCompletenessTests(PythonAPICompletenessTestCase, unittest.TestC @classmethod def excluded_methods(cls): # internal interfaces, no need to expose to users. - return {'getPlannerConfig', 'setPlannerConfig', 'addConfiguration', 'getConfiguration'} + return {'getPlannerConfig', 'setPlannerConfig'} @classmethod def java_method_name(cls, python_method_name): diff --git a/flink-python/pyflink/table/tests/test_table_environment_api.py b/flink-python/pyflink/table/tests/test_table_environment_api.py index ef787f8..3a2ac8b 100644 --- a/flink-python/pyflink/table/tests/test_table_environment_api.py +++ b/flink-python/pyflink/table/tests/test_table_environment_api.py @@ -15,7 +15,6 @@ # # See the License for the specific language governing permissions and # # limitations under the License. ################################################################################ -import datetime import os from py4j.compat import unicode @@ -170,32 +169,6 @@ class StreamTableEnvironmentTests(PyFlinkStreamTableTestCase): expected = ['1,Hi,Hello', '2,Hello,Hello'] self.assert_equals(actual, expected) - def test_table_config(self): - - table_config = TableConfig.get_default() - table_config.set_idle_state_retention_time( - datetime.timedelta(days=1), datetime.timedelta(days=2)) - - self.assertEqual(2 * 24 * 3600 * 1000, table_config.get_max_idle_state_retention_time()) - self.assertEqual(24 * 3600 * 1000, table_config.get_min_idle_state_retention_time()) - - table_config.set_decimal_context(20, "UNNECESSARY") - self.assertEqual((20, "UNNECESSARY"), table_config.get_decimal_context()) - table_config.set_decimal_context(20, "HALF_EVEN") - self.assertEqual((20, "HALF_EVEN"), table_config.get_decimal_context()) - table_config.set_decimal_context(20, "HALF_DOWN") - self.assertEqual((20, "HALF_DOWN"), table_config.get_decimal_context()) - table_config.set_decimal_context(20, "HALF_UP") - 
self.assertEqual((20, "HALF_UP"), table_config.get_decimal_context()) - table_config.set_decimal_context(20, "FLOOR") - self.assertEqual((20, "FLOOR"), table_config.get_decimal_context()) - table_config.set_decimal_context(20, "CEILING") - self.assertEqual((20, "CEILING"), table_config.get_decimal_context()) - table_config.set_decimal_context(20, "DOWN") - self.assertEqual((20, "DOWN"), table_config.get_decimal_context()) - table_config.set_decimal_context(20, "UP") - self.assertEqual((20, "UP"), table_config.get_decimal_context()) - def test_create_table_environment(self): table_config = TableConfig() table_config.set_max_generated_code_length(32000) @@ -260,17 +233,6 @@ class BatchTableEnvironmentTests(PyFlinkBatchTableTestCase): with self.assertRaises(TableException): t_env.explain(extended=True) - def test_table_config(self): - - table_config = TableConfig() - table_config.set_local_timezone("Asia/Shanghai") - table_config.set_max_generated_code_length(64000) - table_config.set_null_check(True) - - self.assertTrue(table_config.get_null_check()) - self.assertEqual(table_config.get_max_generated_code_length(), 64000) - self.assertEqual(table_config.get_local_timezone(), "Asia/Shanghai") - def test_create_table_environment(self): table_config = TableConfig() table_config.set_max_generated_code_length(32000)