This is an automated email from the ASF dual-hosted git repository. kaxilnaik pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/master by this push: new 427a4a8 Replace get accessible dag ids (#11027) 427a4a8 is described below commit 427a4a8f01c414ab571578bb6b8fbe5a8c6b32ef Author: James Timmins <ja...@astronomer.io> AuthorDate: Thu Oct 1 09:37:00 2020 -0700 Replace get accessible dag ids (#11027) --- airflow/www/security.py | 64 +++++++++++++++++------------------ airflow/www/views.py | 21 ++++++------ tests/www/test_security.py | 84 ++++++++++++++++++++++++++++------------------ tests/www/test_views.py | 2 +- 4 files changed, 93 insertions(+), 78 deletions(-) diff --git a/airflow/www/security.py b/airflow/www/security.py index 355ccf0..20686b7 100644 --- a/airflow/www/security.py +++ b/airflow/www/security.py @@ -22,7 +22,9 @@ from typing import Set from flask import current_app, g from flask_appbuilder.security.sqla import models as sqla_models from flask_appbuilder.security.sqla.manager import SecurityManager +from flask_appbuilder.security.sqla.models import PermissionView, Role, User from sqlalchemy import and_, or_ +from sqlalchemy.orm import joinedload from airflow import models from airflow.exceptions import AirflowException @@ -41,7 +43,9 @@ EXISTING_ROLES = { CAN_CREATE = 'can_create' CAN_READ = 'can_read' +CAN_DAG_READ = 'can_dag_read' CAN_EDIT = 'can_edit' +CAN_DAG_EDIT = 'can_dag_edit' CAN_DELETE = 'can_delete' @@ -276,60 +280,54 @@ class AirflowSecurityManager(SecurityManager, LoggingMixin): def get_readable_dags(self, user): """Gets the DAGs readable by authenticated user.""" - return self.get_accessible_dags(CAN_READ, user) + return self.get_accessible_dags([CAN_READ, CAN_DAG_READ], user) def get_editable_dags(self, user): """Gets the DAGs editable by authenticated user.""" - return self.get_accessible_dags(CAN_EDIT, user) + return self.get_accessible_dags([CAN_EDIT, CAN_DAG_EDIT], user) - def get_readable_dag_ids(self, user): + def get_readable_dag_ids(self, user) -> Set[str]: """Gets the DAG IDs readable by authenticated user.""" - return [dag.dag_id for dag in self.get_readable_dags(user)] + return set(dag.dag_id for dag in self.get_readable_dags(user)) - def get_editable_dag_ids(self, user): + def get_editable_dag_ids(self, user) -> Set[str]: """Gets the DAG IDs editable by authenticated user.""" - return [dag.dag_id for dag in self.get_editable_dags(user)] + return set(dag.dag_id for dag in self.get_editable_dags(user)) + + def get_accessible_dag_ids(self, user) -> Set[str]: + """Gets the DAG IDs editable or readable by authenticated user.""" + accessible_dags = self.get_accessible_dags([CAN_EDIT, CAN_DAG_EDIT, CAN_READ, CAN_DAG_READ], user) + return set(dag.dag_id for dag in accessible_dags) @provide_session - def get_accessible_dags(self, user_action, user, session=None): + def get_accessible_dags(self, user_actions, user, session=None): """Generic function to get readable or writable DAGs for authenticated user.""" if user.is_anonymous: return set() + user_query = ( + session.query(User) + .options( + joinedload(User.roles) + .subqueryload(Role.permissions) + .options(joinedload(PermissionView.permission), joinedload(PermissionView.view_menu)) + ) + .filter(User.id == user.id) + .first() + ) resources = set() - for role in user.roles: + for role in user_query.roles: for permission in role.permissions: resource = permission.view_menu.name action = permission.permission.name - if action == user_action: + if action in user_actions: resources.add(resource) - if 'Dag' in resources: + + if bool({'Dag', 'all_dags'}.intersection(resources)): return session.query(DagModel) return session.query(DagModel).filter(DagModel.dag_id.in_(resources)) - def get_accessible_dag_ids(self, username=None) -> Set[str]: - """ - Return a set of dags that user has access to(either read or write). - - :param username: Name of the user. - :return: A set of dag ids that the user could access. - """ - if not username: - username = g.user - - if username.is_anonymous or 'Public' in username.roles: - # return an empty set if the role is public - return set() - - roles = {role.name for role in username.roles} - if {'Admin', 'Viewer', 'User', 'Op'} & roles: - return self.DAG_VMS - - user_perms_views = self.get_all_permissions_views() - # return a set of all dags that the user could access - return {view for perm, view in user_perms_views if perm in self.DAG_PERMS} - def has_access(self, permission, view_name, user=None) -> bool: """ Verify whether a given user could perform certain permission @@ -414,7 +412,7 @@ class AirflowSecurityManager(SecurityManager, LoggingMixin): def _merge_perm(self, permission_name, view_menu_name): """ - Add the new permission , view_menu to ab_permission_view_role if not exists. + Add the new (permission, view_menu) to assoc_permissionview_role if it doesn't exist. It will add the related entry to ab_permission and ab_view_menu two meta tables as well. diff --git a/airflow/www/views.py b/airflow/www/views.py index b6b1978..95a949b 100644 --- a/airflow/www/views.py +++ b/airflow/www/views.py @@ -33,8 +33,8 @@ import lazy_object_proxy import nvd3 import sqlalchemy as sqla from flask import ( - Markup, Response, current_app, escape, flash, jsonify, make_response, redirect, render_template, request, - session as flask_session, url_for, + Markup, Response, current_app, escape, flash, g, jsonify, make_response, redirect, render_template, + request, session as flask_session, url_for, ) from flask_appbuilder import BaseView, ModelView, expose, has_access, permission_name from flask_appbuilder.actions import action @@ -442,7 +442,7 @@ class Airflow(AirflowBaseView): # noqa: D101 pylint: disable=too-many-public-m end = start + dags_per_page # Get all the dag id the user could access - filter_dag_ids = current_app.appbuilder.sm.get_accessible_dag_ids() + filter_dag_ids = current_app.appbuilder.sm.get_accessible_dag_ids(g.user) with create_session() as session: # read orm_dags from the db @@ -543,7 +543,7 @@ class Airflow(AirflowBaseView): # noqa: D101 pylint: disable=too-many-public-m """Dag statistics.""" dr = models.DagRun - allowed_dag_ids = current_app.appbuilder.sm.get_accessible_dag_ids() + allowed_dag_ids = current_app.appbuilder.sm.get_accessible_dag_ids(g.user) if 'all_dags' in allowed_dag_ids: allowed_dag_ids = [dag_id for dag_id, in session.query(models.DagModel.dag_id)] @@ -588,7 +588,7 @@ class Airflow(AirflowBaseView): # noqa: D101 pylint: disable=too-many-public-m @provide_session def task_stats(self, session=None): """Task Statistics""" - allowed_dag_ids = current_app.appbuilder.sm.get_accessible_dag_ids() + allowed_dag_ids = current_app.appbuilder.sm.get_accessible_dag_ids(g.user) if not allowed_dag_ids: return wwwutils.json_response({}) @@ -702,7 +702,7 @@ class Airflow(AirflowBaseView): # noqa: D101 pylint: disable=too-many-public-m @provide_session def last_dagruns(self, session=None): """Last DAG runs""" - allowed_dag_ids = current_app.appbuilder.sm.get_accessible_dag_ids() + allowed_dag_ids = current_app.appbuilder.sm.get_accessible_dag_ids(g.user) if 'all_dags' in allowed_dag_ids: allowed_dag_ids = [dag_id for dag_id, in session.query(models.DagModel.dag_id)] @@ -1385,7 +1385,7 @@ class Airflow(AirflowBaseView): # noqa: D101 pylint: disable=too-many-public-m @provide_session def blocked(self, session=None): """Mark Dag Blocked.""" - allowed_dag_ids = current_app.appbuilder.sm.get_accessible_dag_ids() + allowed_dag_ids = current_app.appbuilder.sm.get_accessible_dag_ids(g.user) if 'all_dags' in allowed_dag_ids: allowed_dag_ids = [dag_id for dag_id, in session.query(models.DagModel.dag_id)] @@ -2287,7 +2287,6 @@ class Airflow(AirflowBaseView): # noqa: D101 pylint: disable=too-many-public-m return response task = dag.get_task(task_id) - try: url = task.get_extra_links(dttm, link_name) except ValueError as err: @@ -2416,7 +2415,7 @@ class DagFilter(BaseFilter): def apply(self, query, func): # noqa pylint: disable=redefined-outer-name,unused-argument if current_app.appbuilder.sm.has_all_dags_access(): return query - filter_dag_ids = current_app.appbuilder.sm.get_accessible_dag_ids() + filter_dag_ids = current_app.appbuilder.sm.get_accessible_dag_ids(g.user) return query.filter(self.model.dag_id.in_(filter_dag_ids)) @@ -3136,9 +3135,9 @@ class DagModelView(AirflowModelView): dag_ids_query = dag_ids_query.filter(DagModel.is_paused) owners_query = owners_query.filter(DagModel.is_paused) - filter_dag_ids = current_app.appbuilder.sm.get_accessible_dag_ids() + filter_dag_ids = current_app.appbuilder.sm.get_accessible_dag_ids(g.user) # pylint: disable=no-member - if 'all_dags' not in filter_dag_ids: + if not bool({'all_dags', 'Dag'}.intersection(filter_dag_ids)): dag_ids_query = dag_ids_query.filter(DagModel.dag_id.in_(filter_dag_ids)) owners_query = owners_query.filter(DagModel.dag_id.in_(filter_dag_ids)) # pylint: enable=no-member diff --git a/tests/www/test_security.py b/tests/www/test_security.py index 2399dca..fc7f57a 100644 --- a/tests/www/test_security.py +++ b/tests/www/test_security.py @@ -20,15 +20,17 @@ import logging import unittest from unittest import mock -from flask import Flask -from flask_appbuilder import SQLA, AppBuilder, Model, expose, has_access +from flask_appbuilder import SQLA, Model, expose, has_access from flask_appbuilder.security.sqla import models as sqla_models from flask_appbuilder.views import BaseView, ModelView from sqlalchemy import Column, Date, Float, Integer, String +from airflow import settings from airflow.exceptions import AirflowException -from airflow.www.security import AirflowSecurityManager +from airflow.models import DagModel +from airflow.www import app as application from airflow.www.utils import CustomSQLAInterface +from tests.test_utils.db import clear_db_runs from tests.test_utils.mock_security_manager import MockSecurityManager READ_WRITE = {'can_dag_read', 'can_dag_edit'} @@ -66,22 +68,24 @@ class SomeBaseView(BaseView): class TestSecurity(unittest.TestCase): + @classmethod + def setUpClass(cls): + settings.configure_orm() + cls.session = settings.Session + cls.app = application.create_app(testing=True) + cls.appbuilder = cls.app.appbuilder # pylint: disable=no-member + cls.app.config['WTF_CSRF_ENABLED'] = False + cls.security_manager = cls.appbuilder.sm + cls.role_admin = cls.security_manager.find_role('Admin') + cls.user = cls.appbuilder.sm.add_user( + 'admin', 'admin', 'user', 'ad...@fab.org', cls.role_admin, 'general' + ) + def setUp(self): - self.app = Flask(__name__) - self.app.config['SQLALCHEMY_DATABASE_URI'] = 'sqlite:///' - self.app.config['SECRET_KEY'] = 'secret_key' - self.app.config['CSRF_ENABLED'] = False - self.app.config['WTF_CSRF_ENABLED'] = False self.db = SQLA(self.app) - self.appbuilder = AppBuilder(self.app, - self.db.session, - security_manager_class=AirflowSecurityManager) - self.security_manager = self.appbuilder.sm self.appbuilder.add_view(SomeBaseView, "SomeBaseView", category="BaseViews") self.appbuilder.add_view(SomeModelView, "SomeModelView", category="ModelViews") - role_admin = self.security_manager.find_role('Admin') - self.user = self.appbuilder.sm.add_user('admin', 'admin', 'user', 'ad...@fab.org', - role_admin, 'general') + log.debug("Complete setup!") def expect_user_is_in_role(self, user, rolename): @@ -112,13 +116,14 @@ class TestSecurity(unittest.TestCase): self.user) def tearDown(self): + clear_db_runs() self.appbuilder = None self.app = None self.db = None log.debug("Complete teardown!") def test_init_role_baseview(self): - role_name = 'MyRole1' + role_name = 'MyRole3' role_perms = ['can_some_action'] role_vms = ['SomeBaseView'] self.security_manager.init_role(role_name, role_vms, role_perms) @@ -159,7 +164,7 @@ class TestSecurity(unittest.TestCase): @mock.patch('airflow.www.security.AirflowSecurityManager.get_user_roles') def test_get_all_permissions_views(self, mock_get_user_roles): - role_name = 'MyRole1' + role_name = 'MyRole5' role_perms = ['can_some_action'] role_vms = ['SomeBaseView'] self.security_manager.init_role(role_name, role_vms, role_perms) @@ -174,23 +179,27 @@ class TestSecurity(unittest.TestCase): self.assertEqual(len(self.security_manager .get_all_permissions_views()), 0) - @mock.patch('airflow.www.security.AirflowSecurityManager' - '.get_all_permissions_views') - @mock.patch('airflow.www.security.AirflowSecurityManager' - '.get_user_roles') - def test_get_accessible_dag_ids(self, mock_get_user_roles, - mock_get_all_permissions_views): - user = mock.MagicMock() + def test_get_accessible_dag_ids(self): role_name = 'MyRole1' - role_perms = ['can_dag_read'] - role_vms = ['dag_id'] - self.security_manager.init_role(role_name, role_vms, role_perms) + permission_action = ['can_dag_read'] + dag_id = 'dag_id' + username = "Mr. User" + self.security_manager.init_role(role_name, [], []) + self.security_manager.sync_perm_for_dag( # type: ignore # pylint: disable=no-member + dag_id, access_control={role_name: permission_action} + ) role = self.security_manager.find_role(role_name) - user.roles = [role] - user.is_anonymous = False - mock_get_all_permissions_views.return_value = {('can_dag_read', 'dag_id')} - - mock_get_user_roles.return_value = [role] + user = self.security_manager.add_user( + username=username, + first_name=username, + last_name=username, + email=f"{username}@fab.org", + role=role, + password=username, + ) + dag_model = DagModel(dag_id="dag_id", fileloc="/tmp/dag_.py", schedule_interval="2 2 * * *") + self.session.add(dag_model) + self.session.commit() self.assertEqual(self.security_manager .get_accessible_dag_ids(user), {'dag_id'}) @@ -235,8 +244,17 @@ class TestSecurity(unittest.TestCase): 'can_varimport', # a real permission, but not a member of DAG_PERMS 'can_eat_pudding', # clearly not a real permission ] + username = "Mrs. User" + user = self.security_manager.add_user( + username=username, + first_name=username, + last_name=username, + email=f"{username}@fab.org", + role=self.role_admin, + password=username, + ) for permission in invalid_permissions: - self.expect_user_is_in_role(self.user, rolename='team-a') + self.expect_user_is_in_role(user, rolename='team-a') with self.assertRaises(AirflowException) as context: self.security_manager.sync_perm_for_dag( 'access_control_test', diff --git a/tests/www/test_views.py b/tests/www/test_views.py index a1b412e..761208e 100644 --- a/tests/www/test_views.py +++ b/tests/www/test_views.py @@ -441,7 +441,7 @@ class TestAirflowBaseViews(TestBase): state=State.RUNNING) def test_index(self): - with assert_queries_count(40): + with assert_queries_count(43): resp = self.client.get('/', follow_redirects=True) self.check_content_in_response('DAGs', resp)