alexott commented on code in PR #27912:
URL: https://github.com/apache/airflow/pull/27912#discussion_r1032463169


##########
airflow/providers/common/sql/hooks/sql.py:
##########
@@ -30,6 +30,25 @@
 from airflow.version import version
 
 
+def has_scalar_return_value(sql: str | Iterable[str], return_last: bool, 
split_statements: bool):
+    """
+    Determines when scalar value should be returned.
+
+    Scalar value should be returned when:

Review Comment:
   I was thinking about it — I find the terminology a bit confusing. When we 
think about `scalar`, we think about a single value. But executing a query always 
returns an iterator of rows (IIRC). 
   
   So in the end we have the following possibilities:
   
   * a list of query results (list of lists of rows) - when we have multiple 
queries
   * a single query result (list of rows) - when we have one query or when we 
use `return_last`
   
   Maybe rename it to something like `single_result`?



##########
airflow/providers/common/sql/operators/sql.py:
##########
@@ -225,54 +230,37 @@ def __init__(
         self.split_statements = split_statements
         self.return_last = return_last
 
-    @overload
-    def _process_output(
-        self, results: Any, description: Sequence[Sequence] | None, 
scalar_results: Literal[True]
-    ) -> Any:
-        pass
-
-    @overload
-    def _process_output(
-        self, results: list[Any], description: Sequence[Sequence] | None, 
scalar_results: Literal[False]
-    ) -> Any:
-        pass
-
-    def _process_output(
-        self, results: Any | list[Any], description: Sequence[Sequence] | 
None, scalar_results: bool
-    ) -> Any:
+    def _process_output(self, results: list[Any], descriptions: 
list[Sequence[Sequence] | None]) -> list[Any]:
         """
-        Can be overridden by the subclass in case some extra processing is 
needed.
+        Processes output before it is returned by the operator.
+
+        It can be overridden by the subclass in case some extra processing is 
needed.
         The "process_output" method can override the returned output - 
augmenting or processing the
         output as needed - the output returned will be returned as execute 
return value and if
-        do_xcom_push is set to True, it will be set as XCom returned
+        do_xcom_push is set to True, it will be set as XCom returned.
 
         :param results: results in the form of list of rows.
-        :param description: as returned by ``cur.description`` in the Python 
DBAPI
-        :param scalar_results: True if result is single scalar value rather 
than list of rows
+        :param descriptions: list of descriptions returned by 
``cur.description`` in the Python DBAPI
         """
         return results
 
     def execute(self, context):
         self.log.info("Executing: %s", self.sql)
         hook = self.get_db_hook()
-        if self.do_xcom_push:
-            output = hook.run(
-                sql=self.sql,
-                autocommit=self.autocommit,
-                parameters=self.parameters,
-                handler=self.handler,
-                split_statements=self.split_statements,
-                return_last=self.return_last,
-            )
-        else:
-            output = hook.run(
-                sql=self.sql,
-                autocommit=self.autocommit,
-                parameters=self.parameters,
-                split_statements=self.split_statements,
-            )
-
-        return self._process_output(output, hook.last_description, 
hook.scalar_return_last)
+        output = hook.run(
+            sql=self.sql,
+            autocommit=self.autocommit,
+            parameters=self.parameters,
+            handler=self.handler if self.do_xcom_push else None,
+            split_statements=self.split_statements,
+            return_last=self.return_last,
+        )
+        if has_scalar_return_value(self.sql, self.return_last, 
self.split_statements):
+            # For simplicity, we pass always list as input to _process_output, 
regardless if
+            # scalar is going to be returned, and we return the first element 
of the list in this case
+            # from the list returned by _process_output
+            return self._process_output([output], hook.descriptions)[0]

Review Comment:
   Why do we return the first element? Usually we're interested in the last 
result (the original behavior of the DBSQL operator).



##########
airflow/providers/databricks/hooks/databricks_sql.py:
##########
@@ -163,38 +163,43 @@ def run(
         :param return_last: Whether to return result for only last statement 
or for all after split
         :return: return only result of the LAST SQL expression if handler was 
provided.
         """
-        self.scalar_return_last = isinstance(sql, str) and return_last
+        self.descriptions = []
         if isinstance(sql, str):
             if split_statements:
-                sql = self.split_sql_string(sql)
+                sql_list = [self.strip_sql_string(s) for s in 
self.split_sql_string(sql)]
             else:
-                sql = [self.strip_sql_string(sql)]
+                sql_list = [self.strip_sql_string(sql)]
+        else:
+            sql_list = [self.strip_sql_string(s) for s in sql]
 
-        if sql:
-            self.log.debug("Executing following statements against Databricks 
DB: %s", list(sql))
+        if sql_list:
+            self.log.debug("Executing following statements against Databricks 
DB: %s", sql_list)
         else:
             raise ValueError("List of SQL statements is empty")
 
         results = []
-        for sql_statement in sql:
+        for sql_statement in sql_list:
             # when using AAD tokens, it could expire if previous query run 
longer than token lifetime
             with closing(self.get_conn()) as conn:
                 self.set_autocommit(conn, autocommit)
 
                 with closing(conn.cursor()) as cur:
                     self._run_command(cur, sql_statement, parameters)
-
                     if handler is not None:
                         result = handler(cur)
-                        results.append(result)
-                    self.last_description = cur.description
+                        if has_scalar_return_value(sql, return_last, 
split_statements):
+                            results = [result]
+                            self.descriptions = [cur.description]
+                        else:
+                            results.append(result)
+                            self.descriptions.append(cur.description)
 
             self._sql_conn = None
 
         if handler is None:
             return None
-        elif self.scalar_return_last:
-            return results[-1]
+        if has_scalar_return_value(sql, return_last, split_statements):
+            return results[0]

Review Comment:
   Same here — we were always returning the results of the last query.



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

Reply via email to