This is an automated email from the ASF dual-hosted git repository.

uranusjr pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 60c49ab2df Add more accurate typing for DbApiHook.run method (#31846)
60c49ab2df is described below

commit 60c49ab2dfabaf450b80a5c7569743dd383500a6
Author: Daniel Reeves <[email protected]>
AuthorDate: Tue Jul 18 17:46:20 2023 -0400

    Add more accurate typing for DbApiHook.run method (#31846)
    
    Co-authored-by: eladkal <[email protected]>
---
 airflow/providers/apache/hive/hooks/hive.py        |  2 +-
 airflow/providers/apache/pinot/hooks/pinot.py      |  4 +-
 airflow/providers/common/sql/hooks/sql.py          | 56 ++++++++++++++++++----
 airflow/providers/common/sql/operators/sql.py      |  4 +-
 .../providers/databricks/hooks/databricks_sql.py   | 35 ++++++++++++--
 airflow/providers/exasol/hooks/exasol.py           | 42 +++++++++++++---
 airflow/providers/google/cloud/hooks/bigquery.py   |  2 +-
 .../providers/google/cloud/operators/cloud_sql.py  |  4 +-
 .../google/suite/transfers/sql_to_sheets.py        |  2 +-
 airflow/providers/neo4j/operators/neo4j.py         |  4 +-
 airflow/providers/presto/hooks/presto.py           | 28 +++--------
 airflow/providers/slack/transfers/sql_to_slack.py  |  8 ++--
 airflow/providers/snowflake/hooks/snowflake.py     | 36 ++++++++++++--
 airflow/providers/snowflake/operators/snowflake.py |  6 +--
 .../snowflake/transfers/snowflake_to_slack.py      |  4 +-
 airflow/providers/trino/hooks/trino.py             | 30 ++++--------
 16 files changed, 181 insertions(+), 86 deletions(-)

diff --git a/airflow/providers/apache/hive/hooks/hive.py 
b/airflow/providers/apache/hive/hooks/hive.py
index 89fcda466f..c67901445a 100644
--- a/airflow/providers/apache/hive/hooks/hive.py
+++ b/airflow/providers/apache/hive/hooks/hive.py
@@ -1014,7 +1014,7 @@ class HiveServer2Hook(DbApiHook):
         self.log.info("Done. Loaded a total of %s rows.", i)
 
     def get_records(
-        self, sql: str | list[str], parameters: Iterable | Mapping | None = 
None, **kwargs
+        self, sql: str | list[str], parameters: Iterable | Mapping[str, Any] | 
None = None, **kwargs
     ) -> Any:
         """
         Get a set of records from a Hive query; optionally pass a 'schema' 
kwarg to specify target schema.
diff --git a/airflow/providers/apache/pinot/hooks/pinot.py 
b/airflow/providers/apache/pinot/hooks/pinot.py
index 1053e220af..d9fd9044a2 100644
--- a/airflow/providers/apache/pinot/hooks/pinot.py
+++ b/airflow/providers/apache/pinot/hooks/pinot.py
@@ -288,7 +288,7 @@ class PinotDbApiHook(DbApiHook):
         return f"{conn_type}://{host}/{endpoint}"
 
     def get_records(
-        self, sql: str | list[str], parameters: Iterable | Mapping | None = 
None, **kwargs
+        self, sql: str | list[str], parameters: Iterable | Mapping[str, Any] | 
None = None, **kwargs
     ) -> Any:
         """
         Executes the sql and returns a set of records.
@@ -301,7 +301,7 @@ class PinotDbApiHook(DbApiHook):
             cur.execute(sql)
             return cur.fetchall()
 
-    def get_first(self, sql: str | list[str], parameters: Iterable | Mapping | 
None = None) -> Any:
+    def get_first(self, sql: str | list[str], parameters: Iterable | 
Mapping[str, Any] | None = None) -> Any:
         """
         Executes the sql and returns the first resulting row.
 
diff --git a/airflow/providers/common/sql/hooks/sql.py 
b/airflow/providers/common/sql/hooks/sql.py
index 3f2f964a0b..d444a95287 100644
--- a/airflow/providers/common/sql/hooks/sql.py
+++ b/airflow/providers/common/sql/hooks/sql.py
@@ -18,7 +18,18 @@ from __future__ import annotations
 
 from contextlib import closing
 from datetime import datetime
-from typing import TYPE_CHECKING, Any, Callable, Iterable, Mapping, Protocol, 
Sequence, cast
+from typing import (
+    TYPE_CHECKING,
+    Any,
+    Callable,
+    Iterable,
+    Mapping,
+    Protocol,
+    Sequence,
+    TypeVar,
+    cast,
+    overload,
+)
 from urllib.parse import urlparse
 
 import sqlparse
@@ -34,6 +45,9 @@ if TYPE_CHECKING:
     from airflow.providers.openlineage.sqlparser import DatabaseInfo
 
 
+T = TypeVar("T")
+
+
 def return_single_query_results(sql: str | Iterable[str], return_last: bool, 
split_statements: bool):
     """
     Determines when results of single query only should be returned.
@@ -184,7 +198,7 @@ class DbApiHook(BaseForDbApiHook):
             engine_kwargs = {}
         return create_engine(self.get_uri(), **engine_kwargs)
 
-    def get_pandas_df(self, sql, parameters=None, **kwargs):
+    def get_pandas_df(self, sql, parameters: Iterable | Mapping[str, Any] | 
None = None, **kwargs):
         """
         Executes the sql and returns a pandas dataframe.
 
@@ -204,7 +218,9 @@ class DbApiHook(BaseForDbApiHook):
         with closing(self.get_conn()) as conn:
             return psql.read_sql(sql, con=conn, params=parameters, **kwargs)
 
-    def get_pandas_df_by_chunks(self, sql, parameters=None, *, chunksize, 
**kwargs):
+    def get_pandas_df_by_chunks(
+        self, sql, parameters: Iterable | Mapping[str, Any] | None = None, *, 
chunksize: int | None, **kwargs
+    ):
         """
         Executes the sql and returns a generator.
 
@@ -228,7 +244,7 @@ class DbApiHook(BaseForDbApiHook):
     def get_records(
         self,
         sql: str | list[str],
-        parameters: Iterable | Mapping | None = None,
+        parameters: Iterable | Mapping[str, Any] | None = None,
     ) -> Any:
         """
         Executes the sql and returns a set of records.
@@ -238,7 +254,7 @@ class DbApiHook(BaseForDbApiHook):
         """
         return self.run(sql=sql, parameters=parameters, 
handler=fetch_all_handler)
 
-    def get_first(self, sql: str | list[str], parameters: Iterable | Mapping | 
None = None) -> Any:
+    def get_first(self, sql: str | list[str], parameters: Iterable | 
Mapping[str, Any] | None = None) -> Any:
         """
         Executes the sql and returns the first resulting row.
 
@@ -268,15 +284,39 @@ class DbApiHook(BaseForDbApiHook):
             return None
         return self.descriptions[-1]
 
+    @overload
+    def run(
+        self,
+        sql: str | Iterable[str],
+        autocommit: bool = ...,
+        parameters: Iterable | Mapping[str, Any] | None = ...,
+        handler: None = ...,
+        split_statements: bool = ...,
+        return_last: bool = ...,
+    ) -> None:
+        ...
+
+    @overload
+    def run(
+        self,
+        sql: str | Iterable[str],
+        autocommit: bool = ...,
+        parameters: Iterable | Mapping[str, Any] | None = ...,
+        handler: Callable[[Any], T] = ...,
+        split_statements: bool = ...,
+        return_last: bool = ...,
+    ) -> T | list[T]:
+        ...
+
     def run(
         self,
         sql: str | Iterable[str],
         autocommit: bool = False,
-        parameters: Iterable | Mapping | None = None,
-        handler: Callable | None = None,
+        parameters: Iterable | Mapping[str, Any] | None = None,
+        handler: Callable[[Any], T] | None = None,
         split_statements: bool = False,
         return_last: bool = True,
-    ) -> Any | list[Any] | None:
+    ) -> T | list[T] | None:
         """Run a command or a list of commands.
 
         Pass a list of SQL statements to the sql parameter to get them to
diff --git a/airflow/providers/common/sql/operators/sql.py 
b/airflow/providers/common/sql/operators/sql.py
index f1784f618e..bf2b38f055 100644
--- a/airflow/providers/common/sql/operators/sql.py
+++ b/airflow/providers/common/sql/operators/sql.py
@@ -762,7 +762,7 @@ class SQLCheckOperator(BaseSQLOperator):
         sql: str,
         conn_id: str | None = None,
         database: str | None = None,
-        parameters: Iterable | Mapping | None = None,
+        parameters: Iterable | Mapping[str, Any] | None = None,
         **kwargs,
     ) -> None:
         super().__init__(conn_id=conn_id, database=database, **kwargs)
@@ -1129,7 +1129,7 @@ class BranchSQLOperator(BaseSQLOperator, SkipMixin):
         follow_task_ids_if_false: list[str],
         conn_id: str = "default_conn_id",
         database: str | None = None,
-        parameters: Iterable | Mapping | None = None,
+        parameters: Iterable | Mapping[str, Any] | None = None,
         **kwargs,
     ) -> None:
         super().__init__(conn_id=conn_id, database=database, **kwargs)
diff --git a/airflow/providers/databricks/hooks/databricks_sql.py 
b/airflow/providers/databricks/hooks/databricks_sql.py
index 674219e1c8..31c816f39b 100644
--- a/airflow/providers/databricks/hooks/databricks_sql.py
+++ b/airflow/providers/databricks/hooks/databricks_sql.py
@@ -18,7 +18,7 @@ from __future__ import annotations
 
 from contextlib import closing
 from copy import copy
-from typing import Any, Callable, Iterable, Mapping
+from typing import Any, Callable, Iterable, Mapping, TypeVar, overload
 
 from databricks import sql  # type: ignore[attr-defined]
 from databricks.sql.client import Connection  # type: ignore[attr-defined]
@@ -30,6 +30,9 @@ from airflow.providers.databricks.hooks.databricks_base 
import BaseDatabricksHoo
 LIST_SQL_ENDPOINTS_ENDPOINT = ("GET", "api/2.0/sql/endpoints")
 
 
+T = TypeVar("T")
+
+
 class DatabricksSqlHook(BaseDatabricksHook, DbApiHook):
     """Hook to interact with Databricks SQL.
 
@@ -138,15 +141,39 @@ class DatabricksSqlHook(BaseDatabricksHook, DbApiHook):
             )
         return self._sql_conn
 
+    @overload
+    def run(
+        self,
+        sql: str | Iterable[str],
+        autocommit: bool = ...,
+        parameters: Iterable | Mapping[str, Any] | None = ...,
+        handler: None = ...,
+        split_statements: bool = ...,
+        return_last: bool = ...,
+    ) -> None:
+        ...
+
+    @overload
+    def run(
+        self,
+        sql: str | Iterable[str],
+        autocommit: bool = ...,
+        parameters: Iterable | Mapping[str, Any] | None = ...,
+        handler: Callable[[Any], T] = ...,
+        split_statements: bool = ...,
+        return_last: bool = ...,
+    ) -> T | list[T]:
+        ...
+
     def run(
         self,
         sql: str | Iterable[str],
         autocommit: bool = False,
-        parameters: Iterable | Mapping | None = None,
-        handler: Callable | None = None,
+        parameters: Iterable | Mapping[str, Any] | None = None,
+        handler: Callable[[Any], T] | None = None,
         split_statements: bool = True,
         return_last: bool = True,
-    ) -> Any | list[Any] | None:
+    ) -> T | list[T] | None:
         """Runs a command or a list of commands.
 
         Pass a list of SQL statements to the SQL parameter to get them to
diff --git a/airflow/providers/exasol/hooks/exasol.py 
b/airflow/providers/exasol/hooks/exasol.py
index ac45a7915f..0911e39604 100644
--- a/airflow/providers/exasol/hooks/exasol.py
+++ b/airflow/providers/exasol/hooks/exasol.py
@@ -18,7 +18,7 @@
 from __future__ import annotations
 
 from contextlib import closing
-from typing import Any, Callable, Iterable, Mapping, Sequence
+from typing import Any, Callable, Iterable, Mapping, Sequence, TypeVar, 
overload
 
 import pandas as pd
 import pyexasol
@@ -26,6 +26,8 @@ from pyexasol import ExaConnection, ExaStatement
 
 from airflow.providers.common.sql.hooks.sql import DbApiHook, 
return_single_query_results
 
+T = TypeVar("T")
+
 
 class ExasolHook(DbApiHook):
     """Interact with Exasol.
@@ -66,7 +68,9 @@ class ExasolHook(DbApiHook):
         conn = pyexasol.connect(**conn_args)
         return conn
 
-    def get_pandas_df(self, sql: str, parameters: dict | None = None, 
**kwargs) -> pd.DataFrame:
+    def get_pandas_df(
+        self, sql, parameters: Iterable | Mapping[str, Any] | None = None, 
**kwargs
+    ) -> pd.DataFrame:
         """Execute the SQL and return a Pandas dataframe.
 
         :param sql: The sql statement to be executed (str) or a list of
@@ -83,7 +87,7 @@ class ExasolHook(DbApiHook):
     def get_records(
         self,
         sql: str | list[str],
-        parameters: Iterable | Mapping | None = None,
+        parameters: Iterable | Mapping[str, Any] | None = None,
     ) -> list[dict | tuple[Any, ...]]:
         """Execute the SQL and return a set of records.
 
@@ -95,7 +99,7 @@ class ExasolHook(DbApiHook):
             with closing(conn.execute(sql, parameters)) as cur:
                 return cur.fetchall()
 
-    def get_first(self, sql: str | list[str], parameters: Iterable | Mapping | 
None = None) -> Any:
+    def get_first(self, sql: str | list[str], parameters: Iterable | 
Mapping[str, Any] | None = None) -> Any:
         """Execute the SQL and return the first resulting row.
 
         :param sql: the sql statement to be executed (str) or a list of
@@ -157,15 +161,39 @@ class ExasolHook(DbApiHook):
             )
         return cols
 
+    @overload
+    def run(
+        self,
+        sql: str | Iterable[str],
+        autocommit: bool = ...,
+        parameters: Iterable | Mapping[str, Any] | None = ...,
+        handler: None = ...,
+        split_statements: bool = ...,
+        return_last: bool = ...,
+    ) -> None:
+        ...
+
+    @overload
+    def run(
+        self,
+        sql: str | Iterable[str],
+        autocommit: bool = ...,
+        parameters: Iterable | Mapping[str, Any] | None = ...,
+        handler: Callable[[Any], T] = ...,
+        split_statements: bool = ...,
+        return_last: bool = ...,
+    ) -> T | list[T]:
+        ...
+
     def run(
         self,
         sql: str | Iterable[str],
         autocommit: bool = False,
-        parameters: Iterable | Mapping | None = None,
-        handler: Callable | None = None,
+        parameters: Iterable | Mapping[str, Any] | None = None,
+        handler: Callable[[Any], T] | None = None,
         split_statements: bool = False,
         return_last: bool = True,
-    ) -> Any | list[Any] | None:
+    ) -> T | list[T] | None:
         """Run a command or a list of commands.
 
         Pass a list of SQL statements to the SQL parameter to get them to
diff --git a/airflow/providers/google/cloud/hooks/bigquery.py 
b/airflow/providers/google/cloud/hooks/bigquery.py
index f76fd1b1e1..01c00046d1 100644
--- a/airflow/providers/google/cloud/hooks/bigquery.py
+++ b/airflow/providers/google/cloud/hooks/bigquery.py
@@ -241,7 +241,7 @@ class BigQueryHook(GoogleBaseHook, DbApiHook):
     def get_pandas_df(
         self,
         sql: str,
-        parameters: Iterable | Mapping | None = None,
+        parameters: Iterable | Mapping[str, Any] | None = None,
         dialect: str | None = None,
         **kwargs,
     ) -> DataFrame:
diff --git a/airflow/providers/google/cloud/operators/cloud_sql.py 
b/airflow/providers/google/cloud/operators/cloud_sql.py
index c60a5bc09a..61275dbc29 100644
--- a/airflow/providers/google/cloud/operators/cloud_sql.py
+++ b/airflow/providers/google/cloud/operators/cloud_sql.py
@@ -18,7 +18,7 @@
 """This module contains Google Cloud SQL operators."""
 from __future__ import annotations
 
-from typing import TYPE_CHECKING, Iterable, Mapping, Sequence
+from typing import TYPE_CHECKING, Any, Iterable, Mapping, Sequence
 
 from googleapiclient.errors import HttpError
 
@@ -1189,7 +1189,7 @@ class 
CloudSQLExecuteQueryOperator(GoogleCloudBaseOperator):
         *,
         sql: str | Iterable[str],
         autocommit: bool = False,
-        parameters: Iterable | Mapping | None = None,
+        parameters: Iterable | Mapping[str, Any] | None = None,
         gcp_conn_id: str = "google_cloud_default",
         gcp_cloudsql_conn_id: str = "google_cloud_sql_default",
         sql_proxy_binary_path: str | None = None,
diff --git a/airflow/providers/google/suite/transfers/sql_to_sheets.py 
b/airflow/providers/google/suite/transfers/sql_to_sheets.py
index f8ee694408..f38d9b230a 100644
--- a/airflow/providers/google/suite/transfers/sql_to_sheets.py
+++ b/airflow/providers/google/suite/transfers/sql_to_sheets.py
@@ -68,7 +68,7 @@ class SQLToGoogleSheetsOperator(BaseSQLOperator):
         sql: str,
         spreadsheet_id: str,
         sql_conn_id: str,
-        parameters: Iterable | Mapping | None = None,
+        parameters: Iterable | Mapping[str, Any] | None = None,
         database: str | None = None,
         spreadsheet_range: str = "Sheet1",
         gcp_conn_id: str = "google_cloud_default",
diff --git a/airflow/providers/neo4j/operators/neo4j.py 
b/airflow/providers/neo4j/operators/neo4j.py
index 3c52784230..6cd36f9472 100644
--- a/airflow/providers/neo4j/operators/neo4j.py
+++ b/airflow/providers/neo4j/operators/neo4j.py
@@ -17,7 +17,7 @@
 # under the License.
 from __future__ import annotations
 
-from typing import TYPE_CHECKING, Iterable, Mapping, Sequence
+from typing import TYPE_CHECKING, Any, Iterable, Mapping, Sequence
 
 from airflow.models import BaseOperator
 from airflow.providers.neo4j.hooks.neo4j import Neo4jHook
@@ -46,7 +46,7 @@ class Neo4jOperator(BaseOperator):
         *,
         sql: str,
         neo4j_conn_id: str = "neo4j_default",
-        parameters: Iterable | Mapping | None = None,
+        parameters: Iterable | Mapping[str, Any] | None = None,
         **kwargs,
     ) -> None:
         super().__init__(**kwargs)
diff --git a/airflow/providers/presto/hooks/presto.py 
b/airflow/providers/presto/hooks/presto.py
index 477b4cd797..7ce9021807 100644
--- a/airflow/providers/presto/hooks/presto.py
+++ b/airflow/providers/presto/hooks/presto.py
@@ -19,7 +19,7 @@ from __future__ import annotations
 
 import json
 import os
-from typing import Any, Callable, Iterable, Mapping
+from typing import Any, Iterable, Mapping, TypeVar
 
 import prestodb
 from prestodb.exceptions import DatabaseError
@@ -31,6 +31,8 @@ from airflow.models import Connection
 from airflow.providers.common.sql.hooks.sql import DbApiHook
 from airflow.utils.operator_helpers import AIRFLOW_VAR_NAME_FORMAT_MAPPING, 
DEFAULT_FORMAT_PREFIX
 
+T = TypeVar("T")
+
 
 def generate_presto_client_info() -> str:
     """Return json string with dag_id, task_id, execution_date and 
try_number."""
@@ -136,7 +138,7 @@ class PrestoHook(DbApiHook):
     def get_records(
         self,
         sql: str | list[str] = "",
-        parameters: Iterable | Mapping | None = None,
+        parameters: Iterable | Mapping[str, Any] | None = None,
     ) -> Any:
         if not isinstance(sql, str):
             raise ValueError(f"The sql in Presto Hook must be a string and is 
{sql}!")
@@ -145,7 +147,9 @@ class PrestoHook(DbApiHook):
         except DatabaseError as e:
             raise PrestoException(e)
 
-    def get_first(self, sql: str | list[str] = "", parameters: Iterable | 
Mapping | None = None) -> Any:
+    def get_first(
+        self, sql: str | list[str] = "", parameters: Iterable | Mapping[str, 
Any] | None = None
+    ) -> Any:
         if not isinstance(sql, str):
             raise ValueError(f"The sql in Presto Hook must be a string and is 
{sql}!")
         try:
@@ -170,24 +174,6 @@ class PrestoHook(DbApiHook):
             df = pandas.DataFrame(**kwargs)
         return df
 
-    def run(
-        self,
-        sql: str | Iterable[str],
-        autocommit: bool = False,
-        parameters: Iterable | Mapping | None = None,
-        handler: Callable | None = None,
-        split_statements: bool = False,
-        return_last: bool = True,
-    ) -> Any | list[Any] | None:
-        return super().run(
-            sql=sql,
-            autocommit=autocommit,
-            parameters=parameters,
-            handler=handler,
-            split_statements=split_statements,
-            return_last=return_last,
-        )
-
     def insert_rows(
         self,
         table: str,
diff --git a/airflow/providers/slack/transfers/sql_to_slack.py 
b/airflow/providers/slack/transfers/sql_to_slack.py
index caac3eb7d1..97017c80d7 100644
--- a/airflow/providers/slack/transfers/sql_to_slack.py
+++ b/airflow/providers/slack/transfers/sql_to_slack.py
@@ -17,7 +17,7 @@
 from __future__ import annotations
 
 from tempfile import NamedTemporaryFile
-from typing import TYPE_CHECKING, Iterable, Mapping, Sequence
+from typing import TYPE_CHECKING, Any, Iterable, Mapping, Sequence
 
 from pandas import DataFrame
 from tabulate import tabulate
@@ -51,7 +51,7 @@ class BaseSqlToSlackOperator(BaseOperator):
         sql: str,
         sql_conn_id: str,
         sql_hook_params: dict | None = None,
-        parameters: Iterable | Mapping | None = None,
+        parameters: Iterable | Mapping[str, Any] | None = None,
         **kwargs,
     ):
         super().__init__(**kwargs)
@@ -125,7 +125,7 @@ class SqlToSlackOperator(BaseSqlToSlackOperator):
         slack_channel: str | None = None,
         slack_message: str,
         results_df_name: str = "results_df",
-        parameters: Iterable | Mapping | None = None,
+        parameters: Iterable | Mapping[str, Any] | None = None,
         **kwargs,
     ) -> None:
 
@@ -246,7 +246,7 @@ class SqlToSlackApiFileOperator(BaseSqlToSlackOperator):
         sql: str,
         sql_conn_id: str,
         sql_hook_params: dict | None = None,
-        parameters: Iterable | Mapping | None = None,
+        parameters: Iterable | Mapping[str, Any] | None = None,
         slack_conn_id: str,
         slack_filename: str,
         slack_channels: str | Sequence[str] | None = None,
diff --git a/airflow/providers/snowflake/hooks/snowflake.py 
b/airflow/providers/snowflake/hooks/snowflake.py
index 59199cf8cd..76b39be6ac 100644
--- a/airflow/providers/snowflake/hooks/snowflake.py
+++ b/airflow/providers/snowflake/hooks/snowflake.py
@@ -22,7 +22,7 @@ from contextlib import closing, contextmanager
 from functools import wraps
 from io import StringIO
 from pathlib import Path
-from typing import Any, Callable, Iterable, Mapping
+from typing import Any, Callable, Iterable, Mapping, TypeVar, overload
 
 from cryptography.hazmat.backends import default_backend
 from cryptography.hazmat.primitives import serialization
@@ -35,6 +35,8 @@ from airflow import AirflowException
 from airflow.providers.common.sql.hooks.sql import DbApiHook, 
return_single_query_results
 from airflow.utils.strings import to_boolean
 
+T = TypeVar("T")
+
 
 def _try_to_boolean(value: Any):
     if isinstance(value, (str, type(None))):
@@ -321,16 +323,42 @@ class SnowflakeHook(DbApiHook):
     def get_autocommit(self, conn):
         return getattr(conn, "autocommit_mode", False)
 
+    @overload
+    def run(
+        self,
+        sql: str | Iterable[str],
+        autocommit: bool = ...,
+        parameters: Iterable | Mapping[str, Any] | None = ...,
+        handler: None = ...,
+        split_statements: bool = ...,
+        return_last: bool = ...,
+        return_dictionaries: bool = ...,
+    ) -> None:
+        ...
+
+    @overload
+    def run(
+        self,
+        sql: str | Iterable[str],
+        autocommit: bool = ...,
+        parameters: Iterable | Mapping[str, Any] | None = ...,
+        handler: Callable[[Any], T] = ...,
+        split_statements: bool = ...,
+        return_last: bool = ...,
+        return_dictionaries: bool = ...,
+    ) -> T | list[T]:
+        ...
+
     def run(
         self,
         sql: str | Iterable[str],
         autocommit: bool = False,
-        parameters: Iterable | Mapping | None = None,
-        handler: Callable | None = None,
+        parameters: Iterable | Mapping[str, Any] | None = None,
+        handler: Callable[[Any], T] | None = None,
         split_statements: bool = True,
         return_last: bool = True,
         return_dictionaries: bool = False,
-    ) -> Any | list[Any] | None:
+    ) -> T | list[T] | None:
         """Runs a command or a list of commands.
 
         Pass a list of SQL statements to the SQL parameter to get them to
diff --git a/airflow/providers/snowflake/operators/snowflake.py 
b/airflow/providers/snowflake/operators/snowflake.py
index 29b57e54eb..1de82218ff 100644
--- a/airflow/providers/snowflake/operators/snowflake.py
+++ b/airflow/providers/snowflake/operators/snowflake.py
@@ -201,7 +201,7 @@ class SnowflakeCheckOperator(SQLCheckOperator):
         *,
         sql: str,
         snowflake_conn_id: str = "snowflake_default",
-        parameters: Iterable | Mapping | None = None,
+        parameters: Iterable | Mapping[str, Any] | None = None,
         autocommit: bool = True,
         do_xcom_push: bool = True,
         warehouse: str | None = None,
@@ -266,7 +266,7 @@ class SnowflakeValueCheckOperator(SQLValueCheckOperator):
         pass_value: Any,
         tolerance: Any = None,
         snowflake_conn_id: str = "snowflake_default",
-        parameters: Iterable | Mapping | None = None,
+        parameters: Iterable | Mapping[str, Any] | None = None,
         autocommit: bool = True,
         do_xcom_push: bool = True,
         warehouse: str | None = None,
@@ -341,7 +341,7 @@ class 
SnowflakeIntervalCheckOperator(SQLIntervalCheckOperator):
         date_filter_column: str = "ds",
         days_back: SupportsAbs[int] = -7,
         snowflake_conn_id: str = "snowflake_default",
-        parameters: Iterable | Mapping | None = None,
+        parameters: Iterable | Mapping[str, Any] | None = None,
         autocommit: bool = True,
         do_xcom_push: bool = True,
         warehouse: str | None = None,
diff --git a/airflow/providers/snowflake/transfers/snowflake_to_slack.py 
b/airflow/providers/snowflake/transfers/snowflake_to_slack.py
index 50d38b9471..8e818f4a4b 100644
--- a/airflow/providers/snowflake/transfers/snowflake_to_slack.py
+++ b/airflow/providers/snowflake/transfers/snowflake_to_slack.py
@@ -17,7 +17,7 @@
 from __future__ import annotations
 
 import warnings
-from typing import Iterable, Mapping, Sequence
+from typing import Any, Iterable, Mapping, Sequence
 
 from airflow.exceptions import AirflowProviderDeprecationWarning
 from airflow.providers.slack.transfers.sql_to_slack import SqlToSlackOperator
@@ -68,7 +68,7 @@ class SnowflakeToSlackOperator(SqlToSlackOperator):
         snowflake_conn_id: str = "snowflake_default",
         slack_conn_id: str = "slack_default",
         results_df_name: str = "results_df",
-        parameters: Iterable | Mapping | None = None,
+        parameters: Iterable | Mapping[str, Any] | None = None,
         warehouse: str | None = None,
         database: str | None = None,
         schema: str | None = None,
diff --git a/airflow/providers/trino/hooks/trino.py 
b/airflow/providers/trino/hooks/trino.py
index 67ee432ca3..5144978dab 100644
--- a/airflow/providers/trino/hooks/trino.py
+++ b/airflow/providers/trino/hooks/trino.py
@@ -19,7 +19,7 @@ from __future__ import annotations
 
 import json
 import os
-from typing import Any, Callable, Iterable, Mapping
+from typing import Any, Iterable, Mapping, TypeVar
 
 import trino
 from trino.exceptions import DatabaseError
@@ -31,6 +31,8 @@ from airflow.models import Connection
 from airflow.providers.common.sql.hooks.sql import DbApiHook
 from airflow.utils.operator_helpers import AIRFLOW_VAR_NAME_FORMAT_MAPPING, 
DEFAULT_FORMAT_PREFIX
 
+T = TypeVar("T")
+
 
 def generate_trino_client_info() -> str:
     """Return json string with dag_id, task_id, execution_date and 
try_number."""
@@ -154,7 +156,7 @@ class TrinoHook(DbApiHook):
     def get_records(
         self,
         sql: str | list[str] = "",
-        parameters: Iterable | Mapping | None = None,
+        parameters: Iterable | Mapping[str, Any] | None = None,
     ) -> Any:
         if not isinstance(sql, str):
             raise ValueError(f"The sql in Trino Hook must be a string and is 
{sql}!")
@@ -163,7 +165,9 @@ class TrinoHook(DbApiHook):
         except DatabaseError as e:
             raise TrinoException(e)
 
-    def get_first(self, sql: str | list[str] = "", parameters: Iterable | 
Mapping | None = None) -> Any:
+    def get_first(
+        self, sql: str | list[str] = "", parameters: Iterable | Mapping[str, 
Any] | None = None
+    ) -> Any:
         if not isinstance(sql, str):
             raise ValueError(f"The sql in Trino Hook must be a string and is 
{sql}!")
         try:
@@ -172,7 +176,7 @@ class TrinoHook(DbApiHook):
             raise TrinoException(e)
 
     def get_pandas_df(
-        self, sql: str = "", parameters: Iterable | Mapping | None = None, 
**kwargs
+        self, sql: str = "", parameters: Iterable | Mapping[str, Any] | None = 
None, **kwargs
     ):  # type: ignore[override]
         import pandas
 
@@ -190,24 +194,6 @@ class TrinoHook(DbApiHook):
             df = pandas.DataFrame(**kwargs)
         return df
 
-    def run(
-        self,
-        sql: str | Iterable[str],
-        autocommit: bool = False,
-        parameters: Iterable | Mapping | None = None,
-        handler: Callable | None = None,
-        split_statements: bool = False,
-        return_last: bool = True,
-    ) -> Any | list[Any] | None:
-        return super().run(
-            sql=sql,
-            autocommit=autocommit,
-            parameters=parameters,
-            handler=handler,
-            split_statements=split_statements,
-            return_last=return_last,
-        )
-
     def insert_rows(
         self,
         table: str,

Reply via email to