This is an automated email from the ASF dual-hosted git repository. kaxilnaik pushed a commit to branch v1-10-test in repository https://gitbox.apache.org/repos/asf/airflow.git
commit 9e8d5660d6fb29b6161835bcd9b0d2230467fab5 Author: Kaxil Naik <[email protected]> AuthorDate: Sat Aug 15 16:01:33 2020 +0100 Webserver: Sanitize values passed to origin param (#10334) (cherry-picked from 5c2bb7b0b0e717b11f093910b443243330ad93ca) --- airflow/www/views.py | 37 +++++++++++++++++++++++++++---------- airflow/www_rbac/views.py | 37 +++++++++++++++++++++++++++---------- tests/www/test_views.py | 23 +++++++++++++++++++++++ tests/www_rbac/test_views.py | 16 ++++++++++++++++ 4 files changed, 93 insertions(+), 20 deletions(-) diff --git a/airflow/www/views.py b/airflow/www/views.py index b496e72..6087356 100644 --- a/airflow/www/views.py +++ b/airflow/www/views.py @@ -54,7 +54,7 @@ from past.builtins import basestring from pygments import highlight, lexers import six from pygments.formatters.html import HtmlFormatter -from six.moves.urllib.parse import quote, unquote +from six.moves.urllib.parse import quote, unquote, urlparse from sqlalchemy import or_, desc, and_, union_all from wtforms import ( @@ -328,6 +328,23 @@ def get_chart_height(dag): return 600 + len(dag.tasks) * 10 +def get_safe_url(url): + """Given a user-supplied URL, ensure it points to our web server""" + try: + valid_schemes = ['http', 'https', ''] + valid_netlocs = [request.host, ''] + + parsed = urlparse(url) + if parsed.scheme in valid_schemes and parsed.netloc in valid_netlocs: + return url + except Exception as e: # pylint: disable=broad-except + log.debug("Error validating value in origin parameter passed to URL: %s", url) + log.debug("Error: %s", e) + pass + + return "/admin/" + + def get_date_time_num_runs_dag_runs_form_data(request, session, dag): dttm = request.args.get('execution_date') if dttm: @@ -1108,7 +1125,7 @@ class Airflow(AirflowViewMixin, BaseView): def run(self): dag_id = request.form.get('dag_id') task_id = request.form.get('task_id') - origin = request.form.get('origin') + origin = get_safe_url(request.form.get('origin')) dag = dagbag.get_dag(dag_id) task = dag.get_task(task_id) @@ -1179,7 +1196,7 @@ class Airflow(AirflowViewMixin, BaseView): from airflow.exceptions import DagNotFound, DagFileExists dag_id = request.values.get('dag_id') - origin = request.values.get('origin') or "/admin/" + origin = get_safe_url(request.values.get('origin')) try: delete_dag.delete_dag(dag_id) @@ -1203,7 +1220,7 @@ class Airflow(AirflowViewMixin, BaseView): @provide_session def trigger(self, session=None): dag_id = request.values.get('dag_id') - origin = request.values.get('origin') or "/admin/" + origin = get_safe_url(request.values.get('origin')) if request.method == 'GET': return self.render( @@ -1304,7 +1321,7 @@ class Airflow(AirflowViewMixin, BaseView): def clear(self): dag_id = request.form.get('dag_id') task_id = request.form.get('task_id') - origin = request.form.get('origin') + origin = get_safe_url(request.form.get('origin')) dag = dagbag.get_dag(dag_id) execution_date = request.form.get('execution_date') @@ -1334,7 +1351,7 @@ class Airflow(AirflowViewMixin, BaseView): @wwwutils.notify_owner def dagrun_clear(self): dag_id = request.form.get('dag_id') - origin = request.form.get('origin') + origin = get_safe_url(request.form.get('origin')) execution_date = request.form.get('execution_date') confirmed = request.form.get('confirmed') == "true" @@ -1437,7 +1454,7 @@ class Airflow(AirflowViewMixin, BaseView): dag_id = request.form.get('dag_id') execution_date = request.form.get('execution_date') confirmed = request.form.get('confirmed') == 'true' - origin = request.form.get('origin') + origin = get_safe_url(request.form.get('origin')) return self._mark_dagrun_state_as_failed(dag_id, execution_date, confirmed, origin) @@ -1449,7 +1466,7 @@ class Airflow(AirflowViewMixin, BaseView): dag_id = request.form.get('dag_id') execution_date = request.form.get('execution_date') confirmed = request.form.get('confirmed') == 'true' - origin = request.form.get('origin') + origin = get_safe_url(request.form.get('origin')) return self._mark_dagrun_state_as_success(dag_id, execution_date, confirmed, origin) @@ -1502,7 +1519,7 @@ class Airflow(AirflowViewMixin, BaseView): def failed(self): dag_id = request.form.get('dag_id') task_id = request.form.get('task_id') - origin = request.form.get('origin') + origin = get_safe_url(request.form.get('origin')) execution_date = request.form.get('execution_date') confirmed = request.form.get('confirmed') == "true" @@ -1522,7 +1539,7 @@ class Airflow(AirflowViewMixin, BaseView): def success(self): dag_id = request.form.get('dag_id') task_id = request.form.get('task_id') - origin = request.form.get('origin') + origin = get_safe_url(request.form.get('origin')) execution_date = request.form.get('execution_date') confirmed = request.form.get('confirmed') == "true" diff --git a/airflow/www_rbac/views.py b/airflow/www_rbac/views.py index f098b25..9d46d03 100644 --- a/airflow/www_rbac/views.py +++ b/airflow/www_rbac/views.py @@ -31,7 +31,7 @@ from datetime import timedelta from urllib.parse import unquote import six -from six.moves.urllib.parse import quote +from six.moves.urllib.parse import quote, urlparse import pendulum import sqlalchemy as sqla @@ -89,6 +89,23 @@ else: dagbag = models.DagBag(os.devnull, include_examples=False) +def get_safe_url(url): + """Given a user-supplied URL, ensure it points to our web server""" + try: + valid_schemes = ['http', 'https', ''] + valid_netlocs = [request.host, ''] + + parsed = urlparse(url) + if parsed.scheme in valid_schemes and parsed.netloc in valid_netlocs: + return url + except Exception as e: # pylint: disable=broad-except + logging.debug("Error validating value in origin parameter passed to URL: %s", url) + logging.debug("Error: %s", e) + pass + + return url_for('Airflow.index') + + def get_date_time_num_runs_dag_runs_form_data(request, session, dag): dttm = request.args.get('execution_date') if dttm: @@ -930,7 +947,7 @@ class Airflow(AirflowBaseView): def run(self): dag_id = request.form.get('dag_id') task_id = request.form.get('task_id') - origin = request.form.get('origin') + origin = get_safe_url(request.form.get('origin')) dag = dagbag.get_dag(dag_id) task = dag.get_task(task_id) @@ -1000,7 +1017,7 @@ class Airflow(AirflowBaseView): from airflow.exceptions import DagNotFound, DagFileExists dag_id = request.values.get('dag_id') - origin = request.values.get('origin') or url_for('Airflow.index') + origin = get_safe_url(request.values.get('origin')) try: delete_dag.delete_dag(dag_id) @@ -1027,7 +1044,7 @@ class Airflow(AirflowBaseView): def trigger(self, session=None): dag_id = request.values.get('dag_id') - origin = request.values.get('origin') or url_for('Airflow.index') + origin = get_safe_url(request.values.get('origin')) if request.method == 'GET': return self.render_template( @@ -1128,7 +1145,7 @@ class Airflow(AirflowBaseView): def clear(self): dag_id = request.form.get('dag_id') task_id = request.form.get('task_id') - origin = request.form.get('origin') + origin = get_safe_url(request.form.get('origin')) dag = dagbag.get_dag(dag_id) execution_date = request.form.get('execution_date') @@ -1158,7 +1175,7 @@ class Airflow(AirflowBaseView): @action_logging def dagrun_clear(self): dag_id = request.form.get('dag_id') - origin = request.form.get('origin') + origin = get_safe_url(request.form.get('origin')) execution_date = request.form.get('execution_date') confirmed = request.form.get('confirmed') == "true" @@ -1280,7 +1297,7 @@ class Airflow(AirflowBaseView): dag_id = request.form.get('dag_id') execution_date = request.form.get('execution_date') confirmed = request.form.get('confirmed') == 'true' - origin = request.form.get('origin') + origin = get_safe_url(request.form.get('origin')) return self._mark_dagrun_state_as_failed(dag_id, execution_date, confirmed, origin) @@ -1292,7 +1309,7 @@ class Airflow(AirflowBaseView): dag_id = request.form.get('dag_id') execution_date = request.form.get('execution_date') confirmed = request.form.get('confirmed') == 'true' - origin = request.form.get('origin') + origin = get_safe_url(request.form.get('origin')) return self._mark_dagrun_state_as_success(dag_id, execution_date, confirmed, origin) @@ -1345,7 +1362,7 @@ class Airflow(AirflowBaseView): def failed(self): dag_id = request.form.get('dag_id') task_id = request.form.get('task_id') - origin = request.form.get('origin') + origin = get_safe_url(request.form.get('origin')) execution_date = request.form.get('execution_date') confirmed = request.form.get('confirmed') == "true" @@ -1365,7 +1382,7 @@ class Airflow(AirflowBaseView): def success(self): dag_id = request.form.get('dag_id') task_id = request.form.get('task_id') - origin = request.form.get('origin') + origin = get_safe_url(request.form.get('origin')) execution_date = request.form.get('execution_date') confirmed = request.form.get('confirmed') == "true" diff --git a/tests/www/test_views.py b/tests/www/test_views.py index ac71ebb..438830c 100644 --- a/tests/www/test_views.py +++ b/tests/www/test_views.py @@ -37,6 +37,7 @@ from flask._compat import PY2 from airflow.operators.bash_operator import BashOperator from airflow.utils import timezone from airflow.utils.db import create_session +from parameterized import parameterized from tests.compat import mock from six.moves.urllib.parse import quote_plus @@ -1115,6 +1116,28 @@ class TestTriggerDag(unittest.TestCase): 'Triggered example_bash_operator, it should start any moment now.', response.data.decode('utf-8')) + @parameterized.expand([ + ("javascript:alert(1)", "/admin/"), + ("http://google.com", "/admin/"), + ( + "%2Fadmin%2Fairflow%2Ftree%3Fdag_id%3Dexample_bash_operator&dag_id=example_bash_operator", + "/admin/airflow/tree?dag_id=example_bash_operator" + ), + ( + "%2Fadmin%2Fairflow%2Fgraph%3Fdag_id%3Dexample_bash_operator&dag_id=example_bash_operator", + "/admin/airflow/graph?dag_id=example_bash_operator" + ), + ("", ""), + ]) + def test_trigger_dag_form_origin_url(self, test_origin, expected_origin): + test_dag_id = "example_bash_operator" + response = self.app.get( + '/admin/airflow/trigger?dag_id={}&origin={}'.format(test_dag_id, test_origin)) + self.assertIn( + '<button class="btn" onclick="location.href = \'{}\'; return false">'.format( + expected_origin), + response.data.decode('utf-8')) + class HelpersTest(unittest.TestCase): @classmethod diff --git a/tests/www_rbac/test_views.py b/tests/www_rbac/test_views.py index 33a8338..4e06b57 100644 --- a/tests/www_rbac/test_views.py +++ b/tests/www_rbac/test_views.py @@ -2244,6 +2244,22 @@ class TestTriggerDag(TestBase): self.check_content_in_response( 'Triggered example_bash_operator, it should start any moment now.', response) + @parameterized.expand([ + ("javascript:alert(1)", "/home"), + ("http://google.com", "/home"), + ("%2Ftree%3Fdag_id%3Dexample_bash_operator", "/tree?dag_id=example_bash_operator"), + ("%2Fgraph%3Fdag_id%3Dexample_bash_operator", "/graph?dag_id=example_bash_operator"), + ("", ""), + ]) + def test_trigger_dag_form_origin_url(self, test_origin, expected_origin): + test_dag_id = "example_bash_operator" + + resp = self.client.get('trigger?dag_id={}&origin={}'.format(test_dag_id, test_origin)) + self.check_content_in_response( + '<button class="btn" onclick="location.href = \'{}\'; return false">'.format( + expected_origin), + resp) + @mock.patch('airflow.www_rbac.views.dagbag.get_dag') def test_trigger_endpoint_uses_existing_dagbag(self, mock_get_dag): """
