This is an automated email from the ASF dual-hosted git repository.

ash pushed a commit to branch v1-10-stable
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/v1-10-stable by this push:
     new fb90c75  Make airflow upgrade_check a command from a separate dist (#12397)
fb90c75 is described below

commit fb90c75f027e1c0f110ecff8da9022f8e1180f84
Author: Ash Berlin-Taylor <[email protected]>
AuthorDate: Wed Nov 18 13:23:49 2020 +0000

    Make airflow upgrade_check a command from a separate dist (#12397)
    
    By doing it this way we can add new features to upgrade_check, such as
    adding a flag to make changes automatically, without having to
    release a new version of Airflow.
---
 airflow/bin/cli.py              | 29 ++++++++++++++---------------
 airflow/upgrade/checker.py      | 31 +++++++++++++++++++++++++++++++
 setup.py                        |  2 +-
 tests/upgrade/test_formattes.py |  5 +++--
 4 files changed, 49 insertions(+), 18 deletions(-)

diff --git a/airflow/bin/cli.py b/airflow/bin/cli.py
index cc8372a..1fb3b88 100644
--- a/airflow/bin/cli.py
+++ b/airflow/bin/cli.py
@@ -72,8 +72,6 @@ from airflow.models import (
 )
 from airflow.ti_deps.dep_context import (DepContext, SCHEDULER_QUEUED_DEPS)
 from airflow.typing_compat import Protocol
-from airflow.upgrade.checker import check_upgrade
-from airflow.upgrade.formatters import (ConsoleFormatter, JSONFormatter)
 from airflow.utils import cli as cli_utils, db
 from airflow.utils.dot_renderer import render_dag
 from airflow.utils.net import get_hostname
@@ -2278,16 +2276,9 @@ def info(args):
 
 
 def upgrade_check(args):
-    if args.save:
-        filename = args.save
-        if not filename.lower().endswith(".json"):
-            print("Only JSON files are supported", file=sys.stderr)
-        formatter = JSONFormatter(args.save)
-    else:
-        formatter = ConsoleFormatter()
-    all_problems = check_upgrade(formatter)
-    if all_problems:
-        sys.exit(1)
+    sys.exit("""
+Please install apache-airflow-upgrade-check distribution from PyPI to perform upgrade checks
+""")
 
 
 class Arg(object):
@@ -3034,9 +3025,10 @@ class CLIFactory(object):
         },
         {
             'name': 'upgrade_check',
-            'help': 'Check if you can upgrade to the new version.',
+            'help': 'Check if you can safely upgrade to the new version.',
             'func': upgrade_check,
-            'args': ('save', ),
+            'from_module': 'airflow.upgrade.checker',
+            'args': (),
         },
     )
     subparsers_dict = {sp['func'].__name__: sp for sp in subparsers}
@@ -3060,6 +3052,14 @@ class CLIFactory(object):
         for sub in subparser_list:
             sub = cls.subparsers_dict[sub]
             sp = subparsers.add_parser(sub['func'].__name__, help=sub['help'])
+            sp.set_defaults(func=sub['func'])
+            if 'from_module' in sub:
+                try:
+                    mod = importlib.import_module(sub['from_module'])
+                    mod.register_arguments(sp)
+                    continue
+                except ImportError:
+                    pass
             for arg in sub['args']:
                 if 'dag_id' in arg and dag_parser:
                     continue
@@ -3068,7 +3068,6 @@ class CLIFactory(object):
                     f: v
                     for f, v in vars(arg).items() if f != 'flags' and v}
                 sp.add_argument(*arg.flags, **kwargs)
-            sp.set_defaults(func=sub['func'])
         return parser
 
 
diff --git a/airflow/upgrade/checker.py b/airflow/upgrade/checker.py
index af01413..0d495da 100644
--- a/airflow/upgrade/checker.py
+++ b/airflow/upgrade/checker.py
@@ -16,6 +16,8 @@
 # under the License.
 
 from __future__ import absolute_import
+import argparse
+import sys
 from typing import List
 
 from airflow.upgrade.formatters import BaseFormatter
@@ -36,3 +38,32 @@ def check_upgrade(formatter):
         formatter.on_next_rule_status(rule_status)
     formatter.end_checking(all_rule_statuses)
     return all_rule_statuses
+
+
+def register_arguments(subparser):
+    subparser.add_argument(
+        "-s", "--save",
+        help="Saves the result to the indicated file. The file format is determined by the file extension."
+    )
+    subparser.set_defaults(func=run)
+
+
+def run(args):
+    from airflow.upgrade.formatters import (ConsoleFormatter, JSONFormatter)
+    if args.save:
+        filename = args.save
+        if not filename.lower().endswith(".json"):
+            exit("Only JSON files are supported")
+        formatter = JSONFormatter(args.save)
+    else:
+        formatter = ConsoleFormatter()
+    all_problems = check_upgrade(formatter)
+    if all_problems:
+        sys.exit(1)
+
+
+def __main__():
+    parser = argparse.ArgumentParser()
+    register_arguments(parser)
+    args = parser.parse_args()
+    args.func(args)
diff --git a/setup.py b/setup.py
index 64bedb7..b2a4977 100644
--- a/setup.py
+++ b/setup.py
@@ -652,7 +652,7 @@ def do_setup():
         long_description_content_type='text/markdown',
         license='Apache License 2.0',
         version=version,
-        packages=find_packages(exclude=['tests*']),
+        packages=find_packages(exclude=['tests*', 'airflow.upgrade*']),
         package_data={
             '': ['airflow/alembic.ini', "airflow/git_version", "*.ipynb",
                  "airflow/providers/cncf/kubernetes/example_dags/*.yaml"],
diff --git a/tests/upgrade/test_formattes.py b/tests/upgrade/test_formattes.py
index 0fc8f13..70c2f67 100644
--- a/tests/upgrade/test_formattes.py
+++ b/tests/upgrade/test_formattes.py
@@ -46,9 +46,10 @@ class TestJSONFormatter:
             }
         ]
         parser = cli.CLIFactory.get_parser()
-        with NamedTemporaryFile("w+") as temp:
+        with NamedTemporaryFile("w+", suffix=".json") as temp:
             with pytest.raises(SystemExit):
-                cli.upgrade_check(parser.parse_args(['upgrade_check', '-s', temp.name]))
+                args = parser.parse_args(['upgrade_check', '-s', temp.name])
+                args.func(args)
             content = temp.read()
 
         assert json.loads(content) == expected

Reply via email to