This is an automated email from the ASF dual-hosted git repository. dimberman pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/master by this push: new 7cadb63 Import connections from a file (#15177) 7cadb63 is described below commit 7cadb63d38900f581b5d81011a1de534fe713c3a Author: natanweinberger <naweinberger+git...@gmail.com> AuthorDate: Mon Apr 5 13:55:14 2021 -0400 Import connections from a file (#15177) * Add connections import CLI command * Add tests for CLI connections import * Add connections import overwrite test When a connections file contains collisions with existing connections, skip them and print a message to stdout indicating that the connection was not imported. * Resolve lint errors --- airflow/cli/cli_parser.py | 11 ++ airflow/cli/commands/connection_command.py | 41 ++++++ tests/cli/commands/test_connection_command.py | 173 ++++++++++++++++++++++++++ 3 files changed, 225 insertions(+) diff --git a/airflow/cli/cli_parser.py b/airflow/cli/cli_parser.py index b5c709f..0bf92cc 100644 --- a/airflow/cli/cli_parser.py +++ b/airflow/cli/cli_parser.py @@ -602,6 +602,7 @@ ARG_CONN_EXPORT = Arg( ARG_CONN_EXPORT_FORMAT = Arg( ('--format',), help='Format of the connections data in file', type=str, choices=['json', 'yaml', 'env'] ) +ARG_CONN_IMPORT = Arg(("file",), help="Import connections from a file") # providers ARG_PROVIDER_NAME = Arg( @@ -1200,6 +1201,16 @@ CONNECTIONS_COMMANDS = ( ARG_CONN_EXPORT_FORMAT, ), ), + ActionCommand( + name='import', + help='Import connections from a file', + description=( + "Connections can be imported from the output of the export command.\n" + "The filetype must be json, yaml or env and will be automatically inferred." 
+ ), + func=lazy_load_command('airflow.cli.commands.connection_command.connections_import'), + args=(ARG_CONN_IMPORT,), + ), ) PROVIDERS_COMMANDS = ( ActionCommand( diff --git a/airflow/cli/commands/connection_command.py b/airflow/cli/commands/connection_command.py index f35ed36..6e45e2a 100644 --- a/airflow/cli/commands/connection_command.py +++ b/airflow/cli/commands/connection_command.py @@ -29,6 +29,7 @@ from airflow.cli.simple_table import AirflowConsole from airflow.exceptions import AirflowNotFoundException from airflow.hooks.base import BaseHook from airflow.models import Connection +from airflow.secrets.local_filesystem import _create_connection, load_connections_dict from airflow.utils import cli as cli_utils from airflow.utils.cli import suppress_logs_and_warning from airflow.utils.session import create_session @@ -234,3 +235,43 @@ def connections_delete(args): else: session.delete(to_delete) print(f"Successfully deleted connection with `conn_id`={to_delete.conn_id}") + + +@cli_utils.action_logging +def connections_import(args): + """Imports connections from a given file""" + if os.path.exists(args.file): + _import_helper(args.file) + else: + raise SystemExit("Missing connections file.") + + +def _import_helper(file_path): + """Helps import connections from a file""" + connections_dict = load_connections_dict(file_path) + with create_session() as session: + for conn_id, conn_values in connections_dict.items(): + if session.query(Connection).filter(Connection.conn_id == conn_id).first(): + print(f'Could not import connection {conn_id}: connection already exists.') + continue + + allowed_fields = [ + 'extra', + 'description', + 'conn_id', + 'login', + 'conn_type', + 'host', + 'password', + 'schema', + 'port', + 'uri', + 'extra_dejson', + ] + filtered_connection_values = { + key: value for key, value in conn_values.items() if key in allowed_fields + } + connection = _create_connection(conn_id, filtered_connection_values) + session.add(connection) + 
session.commit() + print(f'Imported connection {conn_id}') diff --git a/tests/cli/commands/test_connection_command.py b/tests/cli/commands/test_connection_command.py index c81ff81..cf27941 100644 --- a/tests/cli/commands/test_connection_command.py +++ b/tests/cli/commands/test_connection_command.py @@ -27,6 +27,7 @@ from parameterized import parameterized from airflow.cli import cli_parser from airflow.cli.commands import connection_command +from airflow.exceptions import AirflowException from airflow.models import Connection from airflow.utils.db import merge_conn from airflow.utils.session import create_session, provide_session @@ -716,3 +717,175 @@ class TestCliDeleteConnections(unittest.TestCase): # Attempt to delete a non-existing connection with pytest.raises(SystemExit, match=r"Did not find a connection with `conn_id`=fake"): connection_command.connections_delete(self.parser.parse_args(["connections", "delete", "fake"])) + + +class TestCliImportConnections(unittest.TestCase): + @classmethod + def setUpClass(cls): + cls.parser = cli_parser.get_parser() + clear_db_connections(add_default_connections_back=False) + + @classmethod + def tearDownClass(cls): + clear_db_connections() + + @mock.patch('os.path.exists') + def test_cli_connections_import_should_return_error_if_file_does_not_exist(self, mock_exists): + mock_exists.return_value = False + filepath = '/does/not/exist.json' + with pytest.raises(SystemExit, match=r"Missing connections file."): + connection_command.connections_import(self.parser.parse_args(["connections", "import", filepath])) + + @parameterized.expand( + [ + ("sample.jso",), + ("sample.yml",), + ("sample.environ",), + ] + ) + @mock.patch('os.path.exists') + def test_cli_connections_import_should_return_error_if_file_format_is_invalid( + self, filepath, mock_exists + ): + mock_exists.return_value = True + with pytest.raises( + AirflowException, + match=r"Unsupported file format. 
The file must have the extension .env or .json or .yaml", + ): + connection_command.connections_import(self.parser.parse_args(["connections", "import", filepath])) + + @mock.patch('airflow.cli.commands.connection_command.load_connections_dict') + @mock.patch('os.path.exists') + def test_cli_connections_import_should_load_connections(self, mock_exists, mock_load_connections_dict): + mock_exists.return_value = True + + # Sample connections to import + expected_connections = { + "new0": { + "conn_type": "postgres", + "description": "new0 description", + "host": "host", + "is_encrypted": False, + "is_extra_encrypted": False, + "login": "airflow", + "port": 5432, + "schema": "airflow", + }, + "new1": { + "conn_type": "mysql", + "description": "new1 description", + "host": "host", + "is_encrypted": False, + "is_extra_encrypted": False, + "login": "airflow", + "port": 3306, + "schema": "airflow", + }, + } + + # We're not testing the behavior of load_connections_dict, assume successfully reads JSON, YAML or env + mock_load_connections_dict.return_value = expected_connections + + connection_command.connections_import( + self.parser.parse_args(["connections", "import", 'sample.json']) + ) + + # Verify that the imported connections match the expected, sample connections + with create_session() as session: + current_conns = session.query(Connection).all() + + comparable_attrs = [ + "conn_type", + "description", + "host", + "is_encrypted", + "is_extra_encrypted", + "login", + "port", + "schema", + ] + + current_conns_as_dicts = { + current_conn.conn_id: {attr: getattr(current_conn, attr) for attr in comparable_attrs} + for current_conn in current_conns + } + assert expected_connections == current_conns_as_dicts + + @provide_session + @mock.patch('airflow.cli.commands.connection_command.load_connections_dict') + @mock.patch('os.path.exists') + def test_cli_connections_import_should_not_overwrite_existing_connections( + self, mock_exists, mock_load_connections_dict, session=None 
+ ): + mock_exists.return_value = True + + # Add a pre-existing connection "new1" + merge_conn( + Connection( + conn_id="new1", + conn_type="mysql", + description="mysql description", + host="mysql", + login="root", + password="", + schema="airflow", + ), + session=session, + ) + + # Sample connections to import, including a collision with "new1" + expected_connections = { + "new0": { + "conn_type": "postgres", + "description": "new0 description", + "host": "host", + "is_encrypted": False, + "is_extra_encrypted": False, + "login": "airflow", + "port": 5432, + "schema": "airflow", + }, + "new1": { + "conn_type": "mysql", + "description": "new1 description", + "host": "host", + "is_encrypted": False, + "is_extra_encrypted": False, + "login": "airflow", + "port": 3306, + "schema": "airflow", + }, + } + + # We're not testing the behavior of load_connections_dict, assume successfully reads JSON, YAML or env + mock_load_connections_dict.return_value = expected_connections + + with redirect_stdout(io.StringIO()) as stdout: + connection_command.connections_import( + self.parser.parse_args(["connections", "import", 'sample.json']) + ) + + assert 'Could not import connection new1: connection already exists.' in stdout.getvalue() + + # Verify that the imported connections match the expected, sample connections + current_conns = session.query(Connection).all() + + comparable_attrs = [ + "conn_type", + "description", + "host", + "is_encrypted", + "is_extra_encrypted", + "login", + "port", + "schema", + ] + + current_conns_as_dicts = { + current_conn.conn_id: {attr: getattr(current_conn, attr) for attr in comparable_attrs} + for current_conn in current_conns + } + assert current_conns_as_dicts['new0'] == expected_connections['new0'] + + # The existing connection's description should not have changed + assert current_conns_as_dicts['new1']['description'] == 'new1 description'