This is an automated email from the ASF dual-hosted git repository.

ephraimanierobi pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 0e1c106d7c Fix updating variables during variable imports (#33932)
0e1c106d7c is described below

commit 0e1c106d7cd0703125528a691088e42e17c99929
Author: Ephraim Anierobi <splendidzig...@gmail.com>
AuthorDate: Fri Sep 1 15:51:20 2023 +0100

    Fix updating variables during variable imports (#33932)
    
    * Fix updating variables during variable imports
    
    Variable imports should be able to create only new variables without updating
    already existing ones: add an option to overwrite (the default), fail, or skip
    when an imported key already exists.
    
    * Apply suggestions from code review
    
    Co-authored-by: Tzu-ping Chung <uranu...@gmail.com>
    
    * Use flag for variable import in cli and UI
    
    * apply suggestions from code review
    
    * Update airflow/cli/commands/variable_command.py
    
    Co-authored-by: Tzu-ping Chung <uranu...@gmail.com>
    
    * fixup! Update airflow/cli/commands/variable_command.py
    
    ---------
    
    Co-authored-by: Tzu-ping Chung <uranu...@gmail.com>
---
 airflow/cli/cli_config.py                        |  8 +++-
 airflow/cli/commands/variable_command.py         | 20 +++++++++-
 airflow/www/templates/airflow/variable_list.html | 12 ++++++
 airflow/www/views.py                             | 26 ++++++++++++-
 tests/cli/commands/test_variable_command.py      | 18 +++++++++
 tests/www/views/test_views_variable.py           | 49 ++++++++++++++++++++++++
 6 files changed, 129 insertions(+), 4 deletions(-)

diff --git a/airflow/cli/cli_config.py b/airflow/cli/cli_config.py
index fadf988a4a..e8e096ee6f 100644
--- a/airflow/cli/cli_config.py
+++ b/airflow/cli/cli_config.py
@@ -557,6 +557,12 @@ ARG_VAR_EXPORT = Arg(
     help="Export all variables to JSON file",
     type=argparse.FileType("w", encoding="UTF-8"),
 )
+ARG_VAR_ACTION_ON_EXISTING_KEY = Arg(
+    ("-a", "--action-on-existing-key"),
+    help="Action to take if an imported variable key already exists: overwrite (default), fail, or skip.",
+    default="overwrite",
+    choices=("overwrite", "fail", "skip"),
+)
 
 # kerberos
 ARG_PRINCIPAL = Arg(("principal",), help="kerberos principal", nargs="?")
@@ -1454,7 +1460,7 @@ VARIABLES_COMMANDS = (
         name="import",
         help="Import variables",
         func=lazy_load_command("airflow.cli.commands.variable_command.variables_import"),
-        args=(ARG_VAR_IMPORT, ARG_VERBOSE),
+        args=(ARG_VAR_IMPORT, ARG_VAR_ACTION_ON_EXISTING_KEY, ARG_VERBOSE),
     ),
     ActionCommand(
         name="export",
diff --git a/airflow/cli/commands/variable_command.py b/airflow/cli/commands/variable_command.py
index db0e92de28..bc09d0208d 100644
--- a/airflow/cli/commands/variable_command.py
+++ b/airflow/cli/commands/variable_command.py
@@ -31,7 +31,7 @@ from airflow.models import Variable
 from airflow.utils import cli as cli_utils
 from airflow.utils.cli import suppress_logs_and_warning
 from airflow.utils.providers_configuration_loader import providers_configuration_loaded
-from airflow.utils.session import create_session
+from airflow.utils.session import create_session, provide_session
 
 
 @suppress_logs_and_warning
@@ -76,7 +76,8 @@ def variables_delete(args):
 
 @cli_utils.action_cli
 @providers_configuration_loaded
-def variables_import(args):
+@provide_session
+def variables_import(args, session):
     """Import variables from a given file."""
     if not os.path.exists(args.file):
         raise SystemExit("Missing variables file.")
@@ -86,7 +87,17 @@ def variables_import(args):
         except JSONDecodeError:
             raise SystemExit("Invalid variables file.")
     suc_count = fail_count = 0
+    skipped = set()
+    action_on_existing = args.action_on_existing_key
+    existing_keys = set()
+    if action_on_existing != "overwrite":
+        existing_keys = set(session.scalars(select(Variable.key).where(Variable.key.in_(var_json))))
+    if action_on_existing == "fail" and existing_keys:
+        raise SystemExit(f"Failed. These keys: {sorted(existing_keys)} already exist.")
     for k, v in var_json.items():
+        if action_on_existing == "skip" and k in existing_keys:
+            skipped.add(k)
+            continue
         try:
             Variable.set(k, v, serialize_json=not isinstance(v, str))
         except Exception as e:
@@ -97,6 +108,11 @@ def variables_import(args):
     print(f"{suc_count} of {len(var_json)} variables successfully updated.")
     if fail_count:
         print(f"{fail_count} variable(s) failed to be updated.")
+    if skipped:
+        print(
+            f"The variables with these keys: {sorted(skipped)} "
+            "were skipped because they already exist"
+        )
 
 
 @providers_configuration_loaded
diff --git a/airflow/www/templates/airflow/variable_list.html b/airflow/www/templates/airflow/variable_list.html
index fe2b4182fc..bd2171bdc1 100644
--- a/airflow/www/templates/airflow/variable_list.html
+++ b/airflow/www/templates/airflow/variable_list.html
@@ -29,6 +29,18 @@
       <div class="form-group">
         <input class="form-control-file" type="file" name="file">
       </div>
+      <div class="form-group form-check">
+        <input type="radio" class="form-check-input" name="action_if_exists" id="action_if_exists_overwrite" value="overwrite" checked/>
+        <label class="form-check-label" for="action_if_exists_overwrite">Overwrite if exists</label>
+      </div>
+      <div class="form-group form-check">
+        <input type="radio" class="form-check-input" name="action_if_exists" id="action_if_exists_fail" value="fail"/>
+        <label class="form-check-label" for="action_if_exists_fail">Fail if exists</label>
+      </div>
+      <div class="form-group form-check">
+        <input type="radio" class="form-check-input" name="action_if_exists" id="action_if_exists_skip" value="skip"/>
+        <label class="form-check-label" for="action_if_exists_skip">Skip if exists</label>
+      </div>
       <button type="submit" class="btn">
         <span class="material-icons">cloud_upload</span>
         Import Variables
diff --git a/airflow/www/views.py b/airflow/www/views.py
index d3ecae5044..6b11e99b72 100644
--- a/airflow/www/views.py
+++ b/airflow/www/views.py
@@ -5137,17 +5137,34 @@ class VariableModelView(AirflowModelView):
     @expose("/varimport", methods=["POST"])
     @auth.has_access([(permissions.ACTION_CAN_CREATE, permissions.RESOURCE_VARIABLE)])
     @action_logging(event=f"{permissions.RESOURCE_VARIABLE.lower()}.varimport")
-    def varimport(self):
+    @provide_session
+    def varimport(self, session):
         """Import variables."""
         try:
             variable_dict = json.loads(request.files["file"].read())
+            action_on_existing = request.form.get("action_if_exists", "overwrite").lower()
         except Exception:
             self.update_redirect()
             flash("Missing file or syntax error.", "error")
             return redirect(self.get_redirect())
         else:
+            existing_keys = set()
+            if action_on_existing != "overwrite":
+                existing_keys = set(
+                    session.scalars(select(models.Variable.key).where(models.Variable.key.in_(variable_dict)))
+                )
+            if action_on_existing == "fail" and existing_keys:
+                failed_repr = ", ".join(repr(k) for k in sorted(existing_keys))
+                flash(f"Failed. The variables with these keys: {failed_repr} already exist.", "error")
+                logging.error("Failed. The variables with these keys: %s already exist.", failed_repr)
+                return redirect(location=request.referrer)
+            skipped = set()
             suc_count = fail_count = 0
             for k, v in variable_dict.items():
+                if action_on_existing == "skip" and k in existing_keys:
+                    logging.warning("Variable: %s already exists, skipping.", k)
+                    skipped.add(k)
+                    continue
                 try:
                     models.Variable.set(k, v, serialize_json=not isinstance(v, str))
                 except Exception as exc:
@@ -5158,6 +5175,13 @@ class VariableModelView(AirflowModelView):
             flash(f"{suc_count} variable(s) successfully updated.")
             if fail_count:
                 flash(f"{fail_count} variable(s) failed to be updated.", "error")
+            if skipped:
+                skipped_repr = ", ".join(repr(k) for k in sorted(skipped))
+                flash(
+                    f"The variables with these keys: {skipped_repr} were skipped "
+                    "because they already exist",
+                    "warning",
+                )
             self.update_redirect()
             return redirect(self.get_redirect())
 
diff --git a/tests/cli/commands/test_variable_command.py b/tests/cli/commands/test_variable_command.py
index b07cfbba8c..93f0fab156 100644
--- a/tests/cli/commands/test_variable_command.py
+++ b/tests/cli/commands/test_variable_command.py
@@ -106,6 +106,24 @@ class TestCliVariables:
         assert Variable.get("false", deserialize_json=True) is False
         assert Variable.get("null", deserialize_json=True) is None
 
+        # test variable import skip existing
+        # set variable list to ["airflow"] and have it skip during import
+        variable_command.variables_set(self.parser.parse_args(["variables", "set", "list", '["airflow"]']))
+        variable_command.variables_import(
+            self.parser.parse_args(
+                ["variables", "import", "variables_types.json", "--action-on-existing-key", "skip"]
+            )
+        )
+        assert ["airflow"] == Variable.get("list", deserialize_json=True)  # should not be overwritten
+
+        # test variable import fails on existing when action is set to fail
+        with pytest.raises(SystemExit, match="already exist"):
+            variable_command.variables_import(
+                self.parser.parse_args(
+                    ["variables", "import", "variables_types.json", "--action-on-existing-key", "fail"]
+                )
+            )
+
         os.remove("variables_types.json")
 
     def test_variables_list(self):
diff --git a/tests/www/views/test_views_variable.py b/tests/www/views/test_views_variable.py
index aca8e2aeee..22ee898c9a 100644
--- a/tests/www/views/test_views_variable.py
+++ b/tests/www/views/test_views_variable.py
@@ -127,6 +127,55 @@ def test_import_variables_success(session, admin_client):
     _check_last_log(session, dag_id=None, event="variables.varimport", execution_date=None)
 
 
+def test_import_variables_override_existing_variables_if_set(session, admin_client, caplog):
+    assert session.query(Variable).count() == 0
+    Variable.set("str_key", "str_value")
+    content = '{"str_key": "str_value", "int_key": 60}'  # str_key already exists
+    bytes_content = io.BytesIO(bytes(content, encoding="utf-8"))
+
+    resp = admin_client.post(
+        "/variable/varimport",
+        data={"file": (bytes_content, "test.json"), "action_if_exists": "overwrite"},
+        follow_redirects=True,
+    )
+    check_content_in_response("2 variable(s) successfully updated.", resp)
+    _check_last_log(session, dag_id=None, event="variables.varimport", execution_date=None)
+
+
+def test_import_variables_skips_update_if_set(session, admin_client, caplog):
+    assert session.query(Variable).count() == 0
+    Variable.set("str_key", "str_value")
+    content = '{"str_key": "str_value", "int_key": 60}'  # str_key already exists
+    bytes_content = io.BytesIO(bytes(content, encoding="utf-8"))
+
+    resp = admin_client.post(
+        "/variable/varimport",
+        data={"file": (bytes_content, "test.json"), "action_if_exists": "skip"},
+        follow_redirects=True,
+    )
+    check_content_in_response("1 variable(s) successfully updated.", resp)
+
+    check_content_in_response(
+        "The variables with these keys: &#39;str_key&#39; were skipped because they already exist", resp
+    )
+    _check_last_log(session, dag_id=None, event="variables.varimport", execution_date=None)
+    assert "Variable: str_key already exists, skipping." in caplog.text
+
+
+def test_import_variables_fails_if_action_if_exists_is_fail(session, admin_client, caplog):
+    assert session.query(Variable).count() == 0
+    Variable.set("str_key", "str_value")
+    content = '{"str_key": "str_value", "int_key": 60}'  # str_key already exists
+    bytes_content = io.BytesIO(bytes(content, encoding="utf-8"))
+
+    admin_client.post(
+        "/variable/varimport",
+        data={"file": (bytes_content, "test.json"), "action_if_exists": "fail"},
+        follow_redirects=True,
+    )
+    assert "Failed. The variables with these keys: 'str_key' already exist." in caplog.text
+
+
 def test_import_variables_anon(session, app):
     assert session.query(Variable).count() == 0
 

Reply via email to