This is an automated email from the ASF dual-hosted git repository.
uranusjr pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 60c49ab2df Add more accurate typing for DbApiHook.run method (#31846)
60c49ab2df is described below
commit 60c49ab2dfabaf450b80a5c7569743dd383500a6
Author: Daniel Reeves <[email protected]>
AuthorDate: Tue Jul 18 17:46:20 2023 -0400
Add more accurate typing for DbApiHook.run method (#31846)
Co-authored-by: eladkal <[email protected]>
---
airflow/providers/apache/hive/hooks/hive.py | 2 +-
airflow/providers/apache/pinot/hooks/pinot.py | 4 +-
airflow/providers/common/sql/hooks/sql.py | 56 ++++++++++++++++++----
airflow/providers/common/sql/operators/sql.py | 4 +-
.../providers/databricks/hooks/databricks_sql.py | 35 ++++++++++++--
airflow/providers/exasol/hooks/exasol.py | 42 +++++++++++++---
airflow/providers/google/cloud/hooks/bigquery.py | 2 +-
.../providers/google/cloud/operators/cloud_sql.py | 4 +-
.../google/suite/transfers/sql_to_sheets.py | 2 +-
airflow/providers/neo4j/operators/neo4j.py | 4 +-
airflow/providers/presto/hooks/presto.py | 28 +++--------
airflow/providers/slack/transfers/sql_to_slack.py | 8 ++--
airflow/providers/snowflake/hooks/snowflake.py | 36 ++++++++++++--
airflow/providers/snowflake/operators/snowflake.py | 6 +--
.../snowflake/transfers/snowflake_to_slack.py | 4 +-
airflow/providers/trino/hooks/trino.py | 30 ++++--------
16 files changed, 181 insertions(+), 86 deletions(-)
diff --git a/airflow/providers/apache/hive/hooks/hive.py
b/airflow/providers/apache/hive/hooks/hive.py
index 89fcda466f..c67901445a 100644
--- a/airflow/providers/apache/hive/hooks/hive.py
+++ b/airflow/providers/apache/hive/hooks/hive.py
@@ -1014,7 +1014,7 @@ class HiveServer2Hook(DbApiHook):
self.log.info("Done. Loaded a total of %s rows.", i)
def get_records(
- self, sql: str | list[str], parameters: Iterable | Mapping | None =
None, **kwargs
+ self, sql: str | list[str], parameters: Iterable | Mapping[str, Any] |
None = None, **kwargs
) -> Any:
"""
Get a set of records from a Hive query; optionally pass a 'schema'
kwarg to specify target schema.
diff --git a/airflow/providers/apache/pinot/hooks/pinot.py
b/airflow/providers/apache/pinot/hooks/pinot.py
index 1053e220af..d9fd9044a2 100644
--- a/airflow/providers/apache/pinot/hooks/pinot.py
+++ b/airflow/providers/apache/pinot/hooks/pinot.py
@@ -288,7 +288,7 @@ class PinotDbApiHook(DbApiHook):
return f"{conn_type}://{host}/{endpoint}"
def get_records(
- self, sql: str | list[str], parameters: Iterable | Mapping | None =
None, **kwargs
+ self, sql: str | list[str], parameters: Iterable | Mapping[str, Any] |
None = None, **kwargs
) -> Any:
"""
Executes the sql and returns a set of records.
@@ -301,7 +301,7 @@ class PinotDbApiHook(DbApiHook):
cur.execute(sql)
return cur.fetchall()
- def get_first(self, sql: str | list[str], parameters: Iterable | Mapping |
None = None) -> Any:
+ def get_first(self, sql: str | list[str], parameters: Iterable |
Mapping[str, Any] | None = None) -> Any:
"""
Executes the sql and returns the first resulting row.
diff --git a/airflow/providers/common/sql/hooks/sql.py
b/airflow/providers/common/sql/hooks/sql.py
index 3f2f964a0b..d444a95287 100644
--- a/airflow/providers/common/sql/hooks/sql.py
+++ b/airflow/providers/common/sql/hooks/sql.py
@@ -18,7 +18,18 @@ from __future__ import annotations
from contextlib import closing
from datetime import datetime
-from typing import TYPE_CHECKING, Any, Callable, Iterable, Mapping, Protocol,
Sequence, cast
+from typing import (
+ TYPE_CHECKING,
+ Any,
+ Callable,
+ Iterable,
+ Mapping,
+ Protocol,
+ Sequence,
+ TypeVar,
+ cast,
+ overload,
+)
from urllib.parse import urlparse
import sqlparse
@@ -34,6 +45,9 @@ if TYPE_CHECKING:
from airflow.providers.openlineage.sqlparser import DatabaseInfo
+T = TypeVar("T")
+
+
def return_single_query_results(sql: str | Iterable[str], return_last: bool,
split_statements: bool):
"""
Determines when results of single query only should be returned.
@@ -184,7 +198,7 @@ class DbApiHook(BaseForDbApiHook):
engine_kwargs = {}
return create_engine(self.get_uri(), **engine_kwargs)
- def get_pandas_df(self, sql, parameters=None, **kwargs):
+ def get_pandas_df(self, sql, parameters: Iterable | Mapping[str, Any] |
None = None, **kwargs):
"""
Executes the sql and returns a pandas dataframe.
@@ -204,7 +218,9 @@ class DbApiHook(BaseForDbApiHook):
with closing(self.get_conn()) as conn:
return psql.read_sql(sql, con=conn, params=parameters, **kwargs)
- def get_pandas_df_by_chunks(self, sql, parameters=None, *, chunksize,
**kwargs):
+ def get_pandas_df_by_chunks(
+ self, sql, parameters: Iterable | Mapping[str, Any] | None = None, *,
chunksize: int | None, **kwargs
+ ):
"""
Executes the sql and returns a generator.
@@ -228,7 +244,7 @@ class DbApiHook(BaseForDbApiHook):
def get_records(
self,
sql: str | list[str],
- parameters: Iterable | Mapping | None = None,
+ parameters: Iterable | Mapping[str, Any] | None = None,
) -> Any:
"""
Executes the sql and returns a set of records.
@@ -238,7 +254,7 @@ class DbApiHook(BaseForDbApiHook):
"""
return self.run(sql=sql, parameters=parameters,
handler=fetch_all_handler)
- def get_first(self, sql: str | list[str], parameters: Iterable | Mapping |
None = None) -> Any:
+ def get_first(self, sql: str | list[str], parameters: Iterable |
Mapping[str, Any] | None = None) -> Any:
"""
Executes the sql and returns the first resulting row.
@@ -268,15 +284,39 @@ class DbApiHook(BaseForDbApiHook):
return None
return self.descriptions[-1]
+ @overload
+ def run(
+ self,
+ sql: str | Iterable[str],
+ autocommit: bool = ...,
+ parameters: Iterable | Mapping[str, Any] | None = ...,
+ handler: None = ...,
+ split_statements: bool = ...,
+ return_last: bool = ...,
+ ) -> None:
+ ...
+
+ @overload
+ def run(
+ self,
+ sql: str | Iterable[str],
+ autocommit: bool = ...,
+ parameters: Iterable | Mapping[str, Any] | None = ...,
+ handler: Callable[[Any], T] = ...,
+ split_statements: bool = ...,
+ return_last: bool = ...,
+ ) -> T | list[T]:
+ ...
+
def run(
self,
sql: str | Iterable[str],
autocommit: bool = False,
- parameters: Iterable | Mapping | None = None,
- handler: Callable | None = None,
+ parameters: Iterable | Mapping[str, Any] | None = None,
+ handler: Callable[[Any], T] | None = None,
split_statements: bool = False,
return_last: bool = True,
- ) -> Any | list[Any] | None:
+ ) -> T | list[T] | None:
"""Run a command or a list of commands.
Pass a list of SQL statements to the sql parameter to get them to
diff --git a/airflow/providers/common/sql/operators/sql.py
b/airflow/providers/common/sql/operators/sql.py
index f1784f618e..bf2b38f055 100644
--- a/airflow/providers/common/sql/operators/sql.py
+++ b/airflow/providers/common/sql/operators/sql.py
@@ -762,7 +762,7 @@ class SQLCheckOperator(BaseSQLOperator):
sql: str,
conn_id: str | None = None,
database: str | None = None,
- parameters: Iterable | Mapping | None = None,
+ parameters: Iterable | Mapping[str, Any] | None = None,
**kwargs,
) -> None:
super().__init__(conn_id=conn_id, database=database, **kwargs)
@@ -1129,7 +1129,7 @@ class BranchSQLOperator(BaseSQLOperator, SkipMixin):
follow_task_ids_if_false: list[str],
conn_id: str = "default_conn_id",
database: str | None = None,
- parameters: Iterable | Mapping | None = None,
+ parameters: Iterable | Mapping[str, Any] | None = None,
**kwargs,
) -> None:
super().__init__(conn_id=conn_id, database=database, **kwargs)
diff --git a/airflow/providers/databricks/hooks/databricks_sql.py
b/airflow/providers/databricks/hooks/databricks_sql.py
index 674219e1c8..31c816f39b 100644
--- a/airflow/providers/databricks/hooks/databricks_sql.py
+++ b/airflow/providers/databricks/hooks/databricks_sql.py
@@ -18,7 +18,7 @@ from __future__ import annotations
from contextlib import closing
from copy import copy
-from typing import Any, Callable, Iterable, Mapping
+from typing import Any, Callable, Iterable, Mapping, TypeVar, overload
from databricks import sql # type: ignore[attr-defined]
from databricks.sql.client import Connection # type: ignore[attr-defined]
@@ -30,6 +30,9 @@ from airflow.providers.databricks.hooks.databricks_base
import BaseDatabricksHoo
LIST_SQL_ENDPOINTS_ENDPOINT = ("GET", "api/2.0/sql/endpoints")
+T = TypeVar("T")
+
+
class DatabricksSqlHook(BaseDatabricksHook, DbApiHook):
"""Hook to interact with Databricks SQL.
@@ -138,15 +141,39 @@ class DatabricksSqlHook(BaseDatabricksHook, DbApiHook):
)
return self._sql_conn
+ @overload
+ def run(
+ self,
+ sql: str | Iterable[str],
+ autocommit: bool = ...,
+ parameters: Iterable | Mapping[str, Any] | None = ...,
+ handler: None = ...,
+ split_statements: bool = ...,
+ return_last: bool = ...,
+ ) -> None:
+ ...
+
+ @overload
+ def run(
+ self,
+ sql: str | Iterable[str],
+ autocommit: bool = ...,
+ parameters: Iterable | Mapping[str, Any] | None = ...,
+ handler: Callable[[Any], T] = ...,
+ split_statements: bool = ...,
+ return_last: bool = ...,
+ ) -> T | list[T]:
+ ...
+
def run(
self,
sql: str | Iterable[str],
autocommit: bool = False,
- parameters: Iterable | Mapping | None = None,
- handler: Callable | None = None,
+ parameters: Iterable | Mapping[str, Any] | None = None,
+ handler: Callable[[Any], T] | None = None,
split_statements: bool = True,
return_last: bool = True,
- ) -> Any | list[Any] | None:
+ ) -> T | list[T] | None:
"""Runs a command or a list of commands.
Pass a list of SQL statements to the SQL parameter to get them to
diff --git a/airflow/providers/exasol/hooks/exasol.py
b/airflow/providers/exasol/hooks/exasol.py
index ac45a7915f..0911e39604 100644
--- a/airflow/providers/exasol/hooks/exasol.py
+++ b/airflow/providers/exasol/hooks/exasol.py
@@ -18,7 +18,7 @@
from __future__ import annotations
from contextlib import closing
-from typing import Any, Callable, Iterable, Mapping, Sequence
+from typing import Any, Callable, Iterable, Mapping, Sequence, TypeVar,
overload
import pandas as pd
import pyexasol
@@ -26,6 +26,8 @@ from pyexasol import ExaConnection, ExaStatement
from airflow.providers.common.sql.hooks.sql import DbApiHook,
return_single_query_results
+T = TypeVar("T")
+
class ExasolHook(DbApiHook):
"""Interact with Exasol.
@@ -66,7 +68,9 @@ class ExasolHook(DbApiHook):
conn = pyexasol.connect(**conn_args)
return conn
- def get_pandas_df(self, sql: str, parameters: dict | None = None,
**kwargs) -> pd.DataFrame:
+ def get_pandas_df(
+ self, sql, parameters: Iterable | Mapping[str, Any] | None = None,
**kwargs
+ ) -> pd.DataFrame:
"""Execute the SQL and return a Pandas dataframe.
:param sql: The sql statement to be executed (str) or a list of
@@ -83,7 +87,7 @@ class ExasolHook(DbApiHook):
def get_records(
self,
sql: str | list[str],
- parameters: Iterable | Mapping | None = None,
+ parameters: Iterable | Mapping[str, Any] | None = None,
) -> list[dict | tuple[Any, ...]]:
"""Execute the SQL and return a set of records.
@@ -95,7 +99,7 @@ class ExasolHook(DbApiHook):
with closing(conn.execute(sql, parameters)) as cur:
return cur.fetchall()
- def get_first(self, sql: str | list[str], parameters: Iterable | Mapping |
None = None) -> Any:
+ def get_first(self, sql: str | list[str], parameters: Iterable |
Mapping[str, Any] | None = None) -> Any:
"""Execute the SQL and return the first resulting row.
:param sql: the sql statement to be executed (str) or a list of
@@ -157,15 +161,39 @@ class ExasolHook(DbApiHook):
)
return cols
+ @overload
+ def run(
+ self,
+ sql: str | Iterable[str],
+ autocommit: bool = ...,
+ parameters: Iterable | Mapping[str, Any] | None = ...,
+ handler: None = ...,
+ split_statements: bool = ...,
+ return_last: bool = ...,
+ ) -> None:
+ ...
+
+ @overload
+ def run(
+ self,
+ sql: str | Iterable[str],
+ autocommit: bool = ...,
+ parameters: Iterable | Mapping[str, Any] | None = ...,
+ handler: Callable[[Any], T] = ...,
+ split_statements: bool = ...,
+ return_last: bool = ...,
+ ) -> T | list[T]:
+ ...
+
def run(
self,
sql: str | Iterable[str],
autocommit: bool = False,
- parameters: Iterable | Mapping | None = None,
- handler: Callable | None = None,
+ parameters: Iterable | Mapping[str, Any] | None = None,
+ handler: Callable[[Any], T] | None = None,
split_statements: bool = False,
return_last: bool = True,
- ) -> Any | list[Any] | None:
+ ) -> T | list[T] | None:
"""Run a command or a list of commands.
Pass a list of SQL statements to the SQL parameter to get them to
diff --git a/airflow/providers/google/cloud/hooks/bigquery.py
b/airflow/providers/google/cloud/hooks/bigquery.py
index f76fd1b1e1..01c00046d1 100644
--- a/airflow/providers/google/cloud/hooks/bigquery.py
+++ b/airflow/providers/google/cloud/hooks/bigquery.py
@@ -241,7 +241,7 @@ class BigQueryHook(GoogleBaseHook, DbApiHook):
def get_pandas_df(
self,
sql: str,
- parameters: Iterable | Mapping | None = None,
+ parameters: Iterable | Mapping[str, Any] | None = None,
dialect: str | None = None,
**kwargs,
) -> DataFrame:
diff --git a/airflow/providers/google/cloud/operators/cloud_sql.py
b/airflow/providers/google/cloud/operators/cloud_sql.py
index c60a5bc09a..61275dbc29 100644
--- a/airflow/providers/google/cloud/operators/cloud_sql.py
+++ b/airflow/providers/google/cloud/operators/cloud_sql.py
@@ -18,7 +18,7 @@
"""This module contains Google Cloud SQL operators."""
from __future__ import annotations
-from typing import TYPE_CHECKING, Iterable, Mapping, Sequence
+from typing import TYPE_CHECKING, Any, Iterable, Mapping, Sequence
from googleapiclient.errors import HttpError
@@ -1189,7 +1189,7 @@ class
CloudSQLExecuteQueryOperator(GoogleCloudBaseOperator):
*,
sql: str | Iterable[str],
autocommit: bool = False,
- parameters: Iterable | Mapping | None = None,
+ parameters: Iterable | Mapping[str, Any] | None = None,
gcp_conn_id: str = "google_cloud_default",
gcp_cloudsql_conn_id: str = "google_cloud_sql_default",
sql_proxy_binary_path: str | None = None,
diff --git a/airflow/providers/google/suite/transfers/sql_to_sheets.py
b/airflow/providers/google/suite/transfers/sql_to_sheets.py
index f8ee694408..f38d9b230a 100644
--- a/airflow/providers/google/suite/transfers/sql_to_sheets.py
+++ b/airflow/providers/google/suite/transfers/sql_to_sheets.py
@@ -68,7 +68,7 @@ class SQLToGoogleSheetsOperator(BaseSQLOperator):
sql: str,
spreadsheet_id: str,
sql_conn_id: str,
- parameters: Iterable | Mapping | None = None,
+ parameters: Iterable | Mapping[str, Any] | None = None,
database: str | None = None,
spreadsheet_range: str = "Sheet1",
gcp_conn_id: str = "google_cloud_default",
diff --git a/airflow/providers/neo4j/operators/neo4j.py
b/airflow/providers/neo4j/operators/neo4j.py
index 3c52784230..6cd36f9472 100644
--- a/airflow/providers/neo4j/operators/neo4j.py
+++ b/airflow/providers/neo4j/operators/neo4j.py
@@ -17,7 +17,7 @@
# under the License.
from __future__ import annotations
-from typing import TYPE_CHECKING, Iterable, Mapping, Sequence
+from typing import TYPE_CHECKING, Any, Iterable, Mapping, Sequence
from airflow.models import BaseOperator
from airflow.providers.neo4j.hooks.neo4j import Neo4jHook
@@ -46,7 +46,7 @@ class Neo4jOperator(BaseOperator):
*,
sql: str,
neo4j_conn_id: str = "neo4j_default",
- parameters: Iterable | Mapping | None = None,
+ parameters: Iterable | Mapping[str, Any] | None = None,
**kwargs,
) -> None:
super().__init__(**kwargs)
diff --git a/airflow/providers/presto/hooks/presto.py
b/airflow/providers/presto/hooks/presto.py
index 477b4cd797..7ce9021807 100644
--- a/airflow/providers/presto/hooks/presto.py
+++ b/airflow/providers/presto/hooks/presto.py
@@ -19,7 +19,7 @@ from __future__ import annotations
import json
import os
-from typing import Any, Callable, Iterable, Mapping
+from typing import Any, Iterable, Mapping, TypeVar
import prestodb
from prestodb.exceptions import DatabaseError
@@ -31,6 +31,8 @@ from airflow.models import Connection
from airflow.providers.common.sql.hooks.sql import DbApiHook
from airflow.utils.operator_helpers import AIRFLOW_VAR_NAME_FORMAT_MAPPING,
DEFAULT_FORMAT_PREFIX
+T = TypeVar("T")
+
def generate_presto_client_info() -> str:
"""Return json string with dag_id, task_id, execution_date and
try_number."""
@@ -136,7 +138,7 @@ class PrestoHook(DbApiHook):
def get_records(
self,
sql: str | list[str] = "",
- parameters: Iterable | Mapping | None = None,
+ parameters: Iterable | Mapping[str, Any] | None = None,
) -> Any:
if not isinstance(sql, str):
raise ValueError(f"The sql in Presto Hook must be a string and is
{sql}!")
@@ -145,7 +147,9 @@ class PrestoHook(DbApiHook):
except DatabaseError as e:
raise PrestoException(e)
- def get_first(self, sql: str | list[str] = "", parameters: Iterable |
Mapping | None = None) -> Any:
+ def get_first(
+ self, sql: str | list[str] = "", parameters: Iterable | Mapping[str,
Any] | None = None
+ ) -> Any:
if not isinstance(sql, str):
raise ValueError(f"The sql in Presto Hook must be a string and is
{sql}!")
try:
@@ -170,24 +174,6 @@ class PrestoHook(DbApiHook):
df = pandas.DataFrame(**kwargs)
return df
- def run(
- self,
- sql: str | Iterable[str],
- autocommit: bool = False,
- parameters: Iterable | Mapping | None = None,
- handler: Callable | None = None,
- split_statements: bool = False,
- return_last: bool = True,
- ) -> Any | list[Any] | None:
- return super().run(
- sql=sql,
- autocommit=autocommit,
- parameters=parameters,
- handler=handler,
- split_statements=split_statements,
- return_last=return_last,
- )
-
def insert_rows(
self,
table: str,
diff --git a/airflow/providers/slack/transfers/sql_to_slack.py
b/airflow/providers/slack/transfers/sql_to_slack.py
index caac3eb7d1..97017c80d7 100644
--- a/airflow/providers/slack/transfers/sql_to_slack.py
+++ b/airflow/providers/slack/transfers/sql_to_slack.py
@@ -17,7 +17,7 @@
from __future__ import annotations
from tempfile import NamedTemporaryFile
-from typing import TYPE_CHECKING, Iterable, Mapping, Sequence
+from typing import TYPE_CHECKING, Any, Iterable, Mapping, Sequence
from pandas import DataFrame
from tabulate import tabulate
@@ -51,7 +51,7 @@ class BaseSqlToSlackOperator(BaseOperator):
sql: str,
sql_conn_id: str,
sql_hook_params: dict | None = None,
- parameters: Iterable | Mapping | None = None,
+ parameters: Iterable | Mapping[str, Any] | None = None,
**kwargs,
):
super().__init__(**kwargs)
@@ -125,7 +125,7 @@ class SqlToSlackOperator(BaseSqlToSlackOperator):
slack_channel: str | None = None,
slack_message: str,
results_df_name: str = "results_df",
- parameters: Iterable | Mapping | None = None,
+ parameters: Iterable | Mapping[str, Any] | None = None,
**kwargs,
) -> None:
@@ -246,7 +246,7 @@ class SqlToSlackApiFileOperator(BaseSqlToSlackOperator):
sql: str,
sql_conn_id: str,
sql_hook_params: dict | None = None,
- parameters: Iterable | Mapping | None = None,
+ parameters: Iterable | Mapping[str, Any] | None = None,
slack_conn_id: str,
slack_filename: str,
slack_channels: str | Sequence[str] | None = None,
diff --git a/airflow/providers/snowflake/hooks/snowflake.py
b/airflow/providers/snowflake/hooks/snowflake.py
index 59199cf8cd..76b39be6ac 100644
--- a/airflow/providers/snowflake/hooks/snowflake.py
+++ b/airflow/providers/snowflake/hooks/snowflake.py
@@ -22,7 +22,7 @@ from contextlib import closing, contextmanager
from functools import wraps
from io import StringIO
from pathlib import Path
-from typing import Any, Callable, Iterable, Mapping
+from typing import Any, Callable, Iterable, Mapping, TypeVar, overload
from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives import serialization
@@ -35,6 +35,8 @@ from airflow import AirflowException
from airflow.providers.common.sql.hooks.sql import DbApiHook,
return_single_query_results
from airflow.utils.strings import to_boolean
+T = TypeVar("T")
+
def _try_to_boolean(value: Any):
if isinstance(value, (str, type(None))):
@@ -321,16 +323,42 @@ class SnowflakeHook(DbApiHook):
def get_autocommit(self, conn):
return getattr(conn, "autocommit_mode", False)
+ @overload
+ def run(
+ self,
+ sql: str | Iterable[str],
+ autocommit: bool = ...,
+ parameters: Iterable | Mapping[str, Any] | None = ...,
+ handler: None = ...,
+ split_statements: bool = ...,
+ return_last: bool = ...,
+ return_dictionaries: bool = ...,
+ ) -> None:
+ ...
+
+ @overload
+ def run(
+ self,
+ sql: str | Iterable[str],
+ autocommit: bool = ...,
+ parameters: Iterable | Mapping[str, Any] | None = ...,
+ handler: Callable[[Any], T] = ...,
+ split_statements: bool = ...,
+ return_last: bool = ...,
+ return_dictionaries: bool = ...,
+ ) -> T | list[T]:
+ ...
+
def run(
self,
sql: str | Iterable[str],
autocommit: bool = False,
- parameters: Iterable | Mapping | None = None,
- handler: Callable | None = None,
+ parameters: Iterable | Mapping[str, Any] | None = None,
+ handler: Callable[[Any], T] | None = None,
split_statements: bool = True,
return_last: bool = True,
return_dictionaries: bool = False,
- ) -> Any | list[Any] | None:
+ ) -> T | list[T] | None:
"""Runs a command or a list of commands.
Pass a list of SQL statements to the SQL parameter to get them to
diff --git a/airflow/providers/snowflake/operators/snowflake.py
b/airflow/providers/snowflake/operators/snowflake.py
index 29b57e54eb..1de82218ff 100644
--- a/airflow/providers/snowflake/operators/snowflake.py
+++ b/airflow/providers/snowflake/operators/snowflake.py
@@ -201,7 +201,7 @@ class SnowflakeCheckOperator(SQLCheckOperator):
*,
sql: str,
snowflake_conn_id: str = "snowflake_default",
- parameters: Iterable | Mapping | None = None,
+ parameters: Iterable | Mapping[str, Any] | None = None,
autocommit: bool = True,
do_xcom_push: bool = True,
warehouse: str | None = None,
@@ -266,7 +266,7 @@ class SnowflakeValueCheckOperator(SQLValueCheckOperator):
pass_value: Any,
tolerance: Any = None,
snowflake_conn_id: str = "snowflake_default",
- parameters: Iterable | Mapping | None = None,
+ parameters: Iterable | Mapping[str, Any] | None = None,
autocommit: bool = True,
do_xcom_push: bool = True,
warehouse: str | None = None,
@@ -341,7 +341,7 @@ class
SnowflakeIntervalCheckOperator(SQLIntervalCheckOperator):
date_filter_column: str = "ds",
days_back: SupportsAbs[int] = -7,
snowflake_conn_id: str = "snowflake_default",
- parameters: Iterable | Mapping | None = None,
+ parameters: Iterable | Mapping[str, Any] | None = None,
autocommit: bool = True,
do_xcom_push: bool = True,
warehouse: str | None = None,
diff --git a/airflow/providers/snowflake/transfers/snowflake_to_slack.py
b/airflow/providers/snowflake/transfers/snowflake_to_slack.py
index 50d38b9471..8e818f4a4b 100644
--- a/airflow/providers/snowflake/transfers/snowflake_to_slack.py
+++ b/airflow/providers/snowflake/transfers/snowflake_to_slack.py
@@ -17,7 +17,7 @@
from __future__ import annotations
import warnings
-from typing import Iterable, Mapping, Sequence
+from typing import Any, Iterable, Mapping, Sequence
from airflow.exceptions import AirflowProviderDeprecationWarning
from airflow.providers.slack.transfers.sql_to_slack import SqlToSlackOperator
@@ -68,7 +68,7 @@ class SnowflakeToSlackOperator(SqlToSlackOperator):
snowflake_conn_id: str = "snowflake_default",
slack_conn_id: str = "slack_default",
results_df_name: str = "results_df",
- parameters: Iterable | Mapping | None = None,
+ parameters: Iterable | Mapping[str, Any] | None = None,
warehouse: str | None = None,
database: str | None = None,
schema: str | None = None,
diff --git a/airflow/providers/trino/hooks/trino.py
b/airflow/providers/trino/hooks/trino.py
index 67ee432ca3..5144978dab 100644
--- a/airflow/providers/trino/hooks/trino.py
+++ b/airflow/providers/trino/hooks/trino.py
@@ -19,7 +19,7 @@ from __future__ import annotations
import json
import os
-from typing import Any, Callable, Iterable, Mapping
+from typing import Any, Iterable, Mapping, TypeVar
import trino
from trino.exceptions import DatabaseError
@@ -31,6 +31,8 @@ from airflow.models import Connection
from airflow.providers.common.sql.hooks.sql import DbApiHook
from airflow.utils.operator_helpers import AIRFLOW_VAR_NAME_FORMAT_MAPPING,
DEFAULT_FORMAT_PREFIX
+T = TypeVar("T")
+
def generate_trino_client_info() -> str:
"""Return json string with dag_id, task_id, execution_date and
try_number."""
@@ -154,7 +156,7 @@ class TrinoHook(DbApiHook):
def get_records(
self,
sql: str | list[str] = "",
- parameters: Iterable | Mapping | None = None,
+ parameters: Iterable | Mapping[str, Any] | None = None,
) -> Any:
if not isinstance(sql, str):
raise ValueError(f"The sql in Trino Hook must be a string and is
{sql}!")
@@ -163,7 +165,9 @@ class TrinoHook(DbApiHook):
except DatabaseError as e:
raise TrinoException(e)
- def get_first(self, sql: str | list[str] = "", parameters: Iterable |
Mapping | None = None) -> Any:
+ def get_first(
+ self, sql: str | list[str] = "", parameters: Iterable | Mapping[str,
Any] | None = None
+ ) -> Any:
if not isinstance(sql, str):
raise ValueError(f"The sql in Trino Hook must be a string and is
{sql}!")
try:
@@ -172,7 +176,7 @@ class TrinoHook(DbApiHook):
raise TrinoException(e)
def get_pandas_df(
- self, sql: str = "", parameters: Iterable | Mapping | None = None,
**kwargs
+ self, sql: str = "", parameters: Iterable | Mapping[str, Any] | None =
None, **kwargs
): # type: ignore[override]
import pandas
@@ -190,24 +194,6 @@ class TrinoHook(DbApiHook):
df = pandas.DataFrame(**kwargs)
return df
- def run(
- self,
- sql: str | Iterable[str],
- autocommit: bool = False,
- parameters: Iterable | Mapping | None = None,
- handler: Callable | None = None,
- split_statements: bool = False,
- return_last: bool = True,
- ) -> Any | list[Any] | None:
- return super().run(
- sql=sql,
- autocommit=autocommit,
- parameters=parameters,
- handler=handler,
- split_statements=split_statements,
- return_last=return_last,
- )
-
def insert_rows(
self,
table: str,