mistercrunch closed pull request #4351: use enum.Enum to rewrite querystatus — URL: https://github.com/apache/incubator-superset/pull/4351
This is a PR merged from a forked repository. As GitHub hides the original diff on merge, it is displayed below for the sake of provenance: As this is a foreign pull request (from a fork), the diff is supplied below (as it won't show otherwise due to GitHub magic): diff --git a/superset/connectors/sqla/models.py b/superset/connectors/sqla/models.py index 6ccddbe79c..b42f060315 100644 --- a/superset/connectors/sqla/models.py +++ b/superset/connectors/sqla/models.py @@ -42,11 +42,11 @@ def query(self, query_obj): qry = qry.filter(Annotation.layer_id == query_obj['filter'][0]['val']) qry = qry.filter(Annotation.start_dttm >= query_obj['from_dttm']) qry = qry.filter(Annotation.end_dttm <= query_obj['to_dttm']) - status = QueryStatus.SUCCESS + status = QueryStatus.SUCCESS.value try: df = pd.read_sql_query(qry.statement, db.engine) except Exception as e: - status = QueryStatus.FAILED + status = QueryStatus.FAILED.value logging.exception(e) error_message = ( utils.error_msg_from_exception(e)) @@ -679,13 +679,13 @@ def _get_top_groups(self, df, dimensions): def query(self, query_obj): qry_start_dttm = datetime.now() sql = self.get_query_str(query_obj) - status = QueryStatus.SUCCESS + status = QueryStatus.SUCCESS.value error_message = None df = None try: df = self.database.get_df(sql, self.schema) except Exception as e: - status = QueryStatus.FAILED + status = QueryStatus.FAILED.value logging.exception(e) error_message = ( self.database.db_engine_spec.extract_error_message(e)) diff --git a/superset/db_engine_specs.py b/superset/db_engine_specs.py index d26f633bbd..1d73fdd626 100644 --- a/superset/db_engine_specs.py +++ b/superset/db_engine_specs.py @@ -585,7 +585,7 @@ def handle_cursor(cls, cursor, query, session): stats = polled.get('stats', {}) query = session.query(type(query)).filter_by(id=query.id).one() - if query.status == QueryStatus.STOPPED: + if query.status == QueryStatus.STOPPED.value: cursor.cancel() break @@ -914,7 +914,7 @@ def handle_cursor(cls, cursor, 
query, session): job_id = None while polled.operationState in unfinished_states: query = session.query(type(query)).filter_by(id=query.id).one() - if query.status == QueryStatus.STOPPED: + if query.status == QueryStatus.STOPPED.value: cursor.cancel() break diff --git a/superset/models/helpers.py b/superset/models/helpers.py index 948cf0d49e..5e60068dbc 100644 --- a/superset/models/helpers.py +++ b/superset/models/helpers.py @@ -292,7 +292,7 @@ def __init__( # noqa df, query, duration, - status=QueryStatus.SUCCESS, + status=QueryStatus.SUCCESS.value, error_message=None): self.df = df self.query = query diff --git a/superset/models/sql_lab.py b/superset/models/sql_lab.py index 44b692b915..53069b3c9c 100644 --- a/superset/models/sql_lab.py +++ b/superset/models/sql_lab.py @@ -35,7 +35,7 @@ class Query(Model): # Store the tmp table into the DB only if the user asks for it. tmp_table_name = Column(String(256)) user_id = Column(Integer, ForeignKey('ab_user.id'), nullable=True) - status = Column(String(16), default=QueryStatus.PENDING) + status = Column(String(16), default=QueryStatus.PENDING.value) tab_name = Column(String(256)) sql_editor_id = Column(String(256)) schema = Column(String(256)) diff --git a/superset/sql_lab.py b/superset/sql_lab.py index 63225f3e2d..f2faf8ae6f 100644 --- a/superset/sql_lab.py +++ b/superset/sql_lab.py @@ -100,7 +100,7 @@ def get_sql_results( sesh = get_session(not ctask.request.called_directly) query = get_query(query_id, sesh) query.error_message = str(e) - query.status = QueryStatus.FAILED + query.status = QueryStatus.FAILED.value query.tmp_table_name = None sesh.commit() raise @@ -127,7 +127,7 @@ def handle_error(msg): resolutions at: {}'.format(msg, troubleshooting_link) \ if troubleshooting_link else msg query.error_message = msg - query.status = QueryStatus.FAILED + query.status = QueryStatus.FAILED.value query.tmp_table_name = None session.commit() payload.update({ @@ -150,7 +150,6 @@ def handle_error(msg): return handle_error( 
'Only `SELECT` statements can be used with the CREATE TABLE ' 'feature.') - return if not query.tmp_table_name: start_dttm = datetime.fromtimestamp(query.start_time) query.tmp_table_name = 'tmp_{}_table_{}'.format( @@ -173,7 +172,7 @@ def handle_error(msg): return handle_error(msg) query.executed_sql = executed_sql - query.status = QueryStatus.RUNNING + query.status = QueryStatus.RUNNING.value query.start_running_time = utils.now_as_float() session.merge(query) session.commit() @@ -215,7 +214,7 @@ def handle_error(msg): conn.commit() conn.close() - if query.status == utils.QueryStatus.STOPPED: + if query.status == utils.QueryStatus.STOPPED.value: return json.dumps( { 'query_id': query.id, @@ -232,7 +231,7 @@ def handle_error(msg): query.rows = cdf.size query.progress = 100 - query.status = QueryStatus.SUCCESS + query.status = QueryStatus.SUCCESS.value if query.select_as_cta: query.select_sql = '{}'.format( database.select_star( diff --git a/superset/utils.py b/superset/utils.py index 8224843213..9c61ead450 100644 --- a/superset/utils.py +++ b/superset/utils.py @@ -11,6 +11,7 @@ from email.mime.multipart import MIMEMultipart from email.mime.text import MIMEText from email.utils import formatdate +from enum import Enum, unique import functools import json import logging @@ -532,7 +533,8 @@ def ping_connection(connection, branch): connection.should_close_with_result = save_should_close_with_result -class QueryStatus(object): +@unique +class QueryStatus(Enum): """Enum-type class for query statuses""" STOPPED = 'stopped' diff --git a/superset/views/core.py b/superset/views/core.py index ec4cce1fb1..e31aa5eb3a 100755 --- a/superset/views/core.py +++ b/superset/views/core.py @@ -1041,7 +1041,7 @@ def generate_json(self, datasource_type, datasource_id, form_data, return json_error_response(utils.error_msg_from_exception(e)) status = 200 - if payload.get('status') == QueryStatus.FAILED: + if payload.get('status') == QueryStatus.FAILED.value: status = 400 return 
json_success(viz_obj.json_dumps(payload), status=status) @@ -1086,7 +1086,7 @@ def annotation_json(self, layer_id): logging.exception(e) return json_error_response(utils.error_msg_from_exception(e)) status = 200 - if payload.get('status') == QueryStatus.FAILED: + if payload.get('status') == QueryStatus.FAILED.value: status = 400 return json_success(viz_obj.json_dumps(payload), status=status) @@ -2214,7 +2214,7 @@ def stop_query(self): db.session.query(Query) .filter_by(client_id=client_id).one() ) - query.status = utils.QueryStatus.STOPPED + query.status = utils.QueryStatus.STOPPED.value db.session.commit() except Exception: pass @@ -2261,7 +2261,7 @@ def sql_json(self): select_as_cta=request.form.get('select_as_cta') == 'true', start_time=utils.now_as_float(), tab_name=request.form.get('tab'), - status=QueryStatus.PENDING if async else QueryStatus.RUNNING, + status=QueryStatus.PENDING.value if async else QueryStatus.RUNNING.value, sql_editor_id=request.form.get('sql_editor_id'), tmp_table_name=tmp_table_name, user_id=int(g.user.get_id()), @@ -2292,7 +2292,7 @@ def sql_json(self): 'Tell your administrator to verify the availability of ' 'the message queue.' ) - query.status = QueryStatus.FAILED + query.status = QueryStatus.FAILED.value query.error_message = msg session.commit() return json_error_response('{}'.format(msg)) @@ -2320,7 +2320,7 @@ def sql_json(self): except Exception as e: logging.exception(e) return json_error_response('{}'.format(e)) - if data.get('status') == QueryStatus.FAILED: + if data.get('status') == QueryStatus.FAILED.value: return json_error_response(payload=data) return json_success(payload) diff --git a/superset/viz.py b/superset/viz.py index bb0bcf604d..d0bcc5a16b 100644 --- a/superset/viz.py +++ b/superset/viz.py @@ -111,7 +111,7 @@ def get_df(self, query_obj=None): # If the datetime format is unix, the parse will use the corresponding # parsing logic. 
if df is None or df.empty: - self.status = utils.QueryStatus.FAILED + self.status = utils.QueryStatus.FAILED.value if not self.error_message: self.error_message = 'No data.' return pd.DataFrame() @@ -278,7 +278,7 @@ def get_payload(self, force=False): logging.exception(e) if not self.error_message: self.error_message = str(e) - self.status = utils.QueryStatus.FAILED + self.status = utils.QueryStatus.FAILED.value data = None stacktrace = traceback.format_exc() @@ -286,7 +286,7 @@ def get_payload(self, force=False): data and cache_key and cache and - self.status != utils.QueryStatus.FAILED): + self.status != utils.QueryStatus.FAILED.value): cached_dttm = datetime.utcnow().isoformat().split('.')[0] try: cache_value = self.json_dumps({ diff --git a/tests/celery_tests.py b/tests/celery_tests.py index 591e793945..31fca878cc 100644 --- a/tests/celery_tests.py +++ b/tests/celery_tests.py @@ -187,7 +187,7 @@ def test_run_sync_query_cta(self): "SELECT name FROM ab_permission WHERE name='{}'".format(perm_name)) result2 = self.run_sql( db_id, sql_where, '2', tmp_table='tmp_table_2', cta='true') - self.assertEqual(QueryStatus.SUCCESS, result2['query']['state']) + self.assertEqual(QueryStatus.SUCCESS.value, result2['query']['state']) self.assertEqual([], result2['data']) self.assertEqual([], result2['columns']) query2 = self.get_query_by_id(result2['query']['serverId']) @@ -203,12 +203,12 @@ def test_run_sync_query_cta_no_data(self): sql_empty_result = 'SELECT * FROM ab_user WHERE id=666' result3 = self.run_sql( db_id, sql_empty_result, '3', tmp_table='tmp_table_3', cta='true') - self.assertEqual(QueryStatus.SUCCESS, result3['query']['state']) + self.assertEqual(QueryStatus.SUCCESS.value, result3['query']['state']) self.assertEqual([], result3['data']) self.assertEqual([], result3['columns']) query3 = self.get_query_by_id(result3['query']['serverId']) - self.assertEqual(QueryStatus.SUCCESS, query3.status) + self.assertEqual(QueryStatus.SUCCESS.value, query3.status) def 
test_run_async_query(self): main_db = self.get_main_database(db.session) @@ -218,15 +218,15 @@ def test_run_async_query(self): main_db.id, sql_where, '4', async='true', tmp_table='tmp_async_1', cta='true') assert result['query']['state'] in ( - QueryStatus.PENDING, QueryStatus.RUNNING, QueryStatus.SUCCESS) + QueryStatus.PENDING.value, QueryStatus.RUNNING.value, QueryStatus.SUCCESS.value) time.sleep(1) query = self.get_query_by_id(result['query']['serverId']) df = pd.read_sql_query(query.select_sql, con=eng) - self.assertEqual(QueryStatus.SUCCESS, query.status) + self.assertEqual(QueryStatus.SUCCESS.value, query.status) self.assertEqual([{'name': 'Admin'}], df.to_dict(orient='records')) - self.assertEqual(QueryStatus.SUCCESS, query.status) + self.assertEqual(QueryStatus.SUCCESS.value, query.status) self.assertTrue('FROM tmp_async_1' in query.select_sql) self.assertTrue('LIMIT 666' in query.select_sql) self.assertEqual( diff --git a/tests/viz_tests.py b/tests/viz_tests.py index abf29adb62..16addab28b 100644 --- a/tests/viz_tests.py +++ b/tests/viz_tests.py @@ -48,7 +48,7 @@ def test_get_df_returns_empty_df(self): self.assertEqual(type(result), pd.DataFrame) self.assertTrue(result.empty) self.assertEqual(test_viz.error_message, 'No data.') - self.assertEqual(test_viz.status, utils.QueryStatus.FAILED) + self.assertEqual(test_viz.status, utils.QueryStatus.FAILED.value) def test_get_df_handles_dttm_col(self): datasource = Mock() ---------------------------------------------------------------- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. For queries about this service, please contact Infrastructure at: us...@infra.apache.org With regards, Apache Git Services