This is an automated email from the ASF dual-hosted git repository.
husseinawala pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 6d182beec6 Use a single statement with multiple contexts instead of nested statements in providers (#33768)
6d182beec6 is described below
commit 6d182beec6e86b372c37fb164a31c2f8811d8c03
Author: Hussein Awala <[email protected]>
AuthorDate: Sat Aug 26 12:55:22 2023 +0200
Use a single statement with multiple contexts instead of nested statements in providers (#33768)
---
airflow/providers/apache/hive/hooks/hive.py | 124 ++++++++++-----------
.../apache/hive/transfers/mysql_to_hive.py | 29 +++--
airflow/providers/apache/pig/hooks/pig.py | 65 ++++++-----
airflow/providers/dbt/cloud/hooks/dbt.py | 15 +--
airflow/providers/exasol/hooks/exasol.py | 10 +-
airflow/providers/google/cloud/hooks/gcs.py | 7 +-
.../google/cloud/hooks/kubernetes_engine.py | 58 +++++-----
airflow/providers/microsoft/azure/hooks/asb.py | 53 +++++----
.../providers/mysql/transfers/vertica_to_mysql.py | 59 +++++-----
airflow/providers/postgres/hooks/postgres.py | 10 +-
.../providers/snowflake/hooks/snowflake_sql_api.py | 11 +-
11 files changed, 213 insertions(+), 228 deletions(-)
diff --git a/airflow/providers/apache/hive/hooks/hive.py
b/airflow/providers/apache/hive/hooks/hive.py
index 5b0c91083a..773ea4af7d 100644
--- a/airflow/providers/apache/hive/hooks/hive.py
+++ b/airflow/providers/apache/hive/hooks/hive.py
@@ -236,58 +236,55 @@ class HiveCliHook(BaseHook):
if schema:
hql = f"USE {schema};\n{hql}"
- with TemporaryDirectory(prefix="airflow_hiveop_") as tmp_dir:
- with NamedTemporaryFile(dir=tmp_dir) as f:
- hql += "\n"
- f.write(hql.encode("UTF-8"))
- f.flush()
- hive_cmd = self._prepare_cli_cmd()
- env_context = get_context_from_env_var()
- # Only extend the hive_conf if it is defined.
- if hive_conf:
- env_context.update(hive_conf)
- hive_conf_params = self._prepare_hiveconf(env_context)
- if self.mapred_queue:
- hive_conf_params.extend(
- [
- "-hiveconf",
- f"mapreduce.job.queuename={self.mapred_queue}",
- "-hiveconf",
- f"mapred.job.queue.name={self.mapred_queue}",
- "-hiveconf",
- f"tez.queue.name={self.mapred_queue}",
- ]
- )
+ with TemporaryDirectory(prefix="airflow_hiveop_") as tmp_dir,
NamedTemporaryFile(dir=tmp_dir) as f:
+ hql += "\n"
+ f.write(hql.encode("UTF-8"))
+ f.flush()
+ hive_cmd = self._prepare_cli_cmd()
+ env_context = get_context_from_env_var()
+ # Only extend the hive_conf if it is defined.
+ if hive_conf:
+ env_context.update(hive_conf)
+ hive_conf_params = self._prepare_hiveconf(env_context)
+ if self.mapred_queue:
+ hive_conf_params.extend(
+ [
+ "-hiveconf",
+ f"mapreduce.job.queuename={self.mapred_queue}",
+ "-hiveconf",
+ f"mapred.job.queue.name={self.mapred_queue}",
+ "-hiveconf",
+ f"tez.queue.name={self.mapred_queue}",
+ ]
+ )
- if self.mapred_queue_priority:
- hive_conf_params.extend(
- ["-hiveconf",
f"mapreduce.job.priority={self.mapred_queue_priority}"]
- )
+ if self.mapred_queue_priority:
+ hive_conf_params.extend(["-hiveconf",
f"mapreduce.job.priority={self.mapred_queue_priority}"])
- if self.mapred_job_name:
- hive_conf_params.extend(["-hiveconf",
f"mapred.job.name={self.mapred_job_name}"])
+ if self.mapred_job_name:
+ hive_conf_params.extend(["-hiveconf",
f"mapred.job.name={self.mapred_job_name}"])
- hive_cmd.extend(hive_conf_params)
- hive_cmd.extend(["-f", f.name])
+ hive_cmd.extend(hive_conf_params)
+ hive_cmd.extend(["-f", f.name])
+ if verbose:
+ self.log.info("%s", " ".join(hive_cmd))
+ sub_process: Any = subprocess.Popen(
+ hive_cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT,
cwd=tmp_dir, close_fds=True
+ )
+ self.sub_process = sub_process
+ stdout = ""
+ for line in iter(sub_process.stdout.readline, b""):
+ line = line.decode()
+ stdout += line
if verbose:
- self.log.info("%s", " ".join(hive_cmd))
- sub_process: Any = subprocess.Popen(
- hive_cmd, stdout=subprocess.PIPE,
stderr=subprocess.STDOUT, cwd=tmp_dir, close_fds=True
- )
- self.sub_process = sub_process
- stdout = ""
- for line in iter(sub_process.stdout.readline, b""):
- line = line.decode()
- stdout += line
- if verbose:
- self.log.info(line.strip())
- sub_process.wait()
+ self.log.info(line.strip())
+ sub_process.wait()
- if sub_process.returncode:
- raise AirflowException(stdout)
+ if sub_process.returncode:
+ raise AirflowException(stdout)
- return stdout
+ return stdout
def test_hql(self, hql: str) -> None:
"""Test an hql statement using the hive cli and EXPLAIN."""
@@ -376,25 +373,26 @@ class HiveCliHook(BaseHook):
if pandas_kwargs is None:
pandas_kwargs = {}
- with TemporaryDirectory(prefix="airflow_hiveop_") as tmp_dir:
- with NamedTemporaryFile(dir=tmp_dir, mode="w") as f:
- if field_dict is None:
- field_dict = _infer_field_types_from_df(df)
-
- df.to_csv(
- path_or_buf=f,
- sep=delimiter,
- header=False,
- index=False,
- encoding=encoding,
- date_format="%Y-%m-%d %H:%M:%S",
- **pandas_kwargs,
- )
- f.flush()
+ with TemporaryDirectory(prefix="airflow_hiveop_") as tmp_dir,
NamedTemporaryFile(
+ dir=tmp_dir, mode="w"
+ ) as f:
+ if field_dict is None:
+ field_dict = _infer_field_types_from_df(df)
+
+ df.to_csv(
+ path_or_buf=f,
+ sep=delimiter,
+ header=False,
+ index=False,
+ encoding=encoding,
+ date_format="%Y-%m-%d %H:%M:%S",
+ **pandas_kwargs,
+ )
+ f.flush()
- return self.load_file(
- filepath=f.name, table=table, delimiter=delimiter,
field_dict=field_dict, **kwargs
- )
+ return self.load_file(
+ filepath=f.name, table=table, delimiter=delimiter,
field_dict=field_dict, **kwargs
+ )
def load_file(
self,
diff --git a/airflow/providers/apache/hive/transfers/mysql_to_hive.py
b/airflow/providers/apache/hive/transfers/mysql_to_hive.py
index ee1cb082bc..bd7876efd4 100644
--- a/airflow/providers/apache/hive/transfers/mysql_to_hive.py
+++ b/airflow/providers/apache/hive/transfers/mysql_to_hive.py
@@ -136,21 +136,20 @@ class MySqlToHiveOperator(BaseOperator):
mysql = MySqlHook(mysql_conn_id=self.mysql_conn_id)
self.log.info("Dumping MySQL query results to local file")
with NamedTemporaryFile(mode="w", encoding="utf-8") as f:
- with closing(mysql.get_conn()) as conn:
- with closing(conn.cursor()) as cursor:
- cursor.execute(self.sql)
- csv_writer = csv.writer(
- f,
- delimiter=self.delimiter,
- quoting=self.quoting,
- quotechar=self.quotechar if self.quoting !=
csv.QUOTE_NONE else None,
- escapechar=self.escapechar,
- )
- field_dict = {}
- if cursor.description is not None:
- for field in cursor.description:
- field_dict[field[0]] = self.type_map(field[1])
- csv_writer.writerows(cursor) # type: ignore[arg-type]
+ with closing(mysql.get_conn()) as conn, closing(conn.cursor()) as
cursor:
+ cursor.execute(self.sql)
+ csv_writer = csv.writer(
+ f,
+ delimiter=self.delimiter,
+ quoting=self.quoting,
+ quotechar=self.quotechar if self.quoting != csv.QUOTE_NONE
else None,
+ escapechar=self.escapechar,
+ )
+ field_dict = {}
+ if cursor.description is not None:
+ for field in cursor.description:
+ field_dict[field[0]] = self.type_map(field[1])
+ csv_writer.writerows(cursor) # type: ignore[arg-type]
f.flush()
self.log.info("Loading file into Hive")
hive.load_file(
diff --git a/airflow/providers/apache/pig/hooks/pig.py
b/airflow/providers/apache/pig/hooks/pig.py
index 71c39536d3..31e6006de3 100644
--- a/airflow/providers/apache/pig/hooks/pig.py
+++ b/airflow/providers/apache/pig/hooks/pig.py
@@ -64,41 +64,40 @@ class PigCliHook(BaseHook):
>>> ("hdfs://" in result)
True
"""
- with TemporaryDirectory(prefix="airflow_pigop_") as tmp_dir:
- with NamedTemporaryFile(dir=tmp_dir) as f:
- f.write(pig.encode("utf-8"))
- f.flush()
- fname = f.name
- pig_bin = "pig"
- cmd_extra: list[str] = []
-
- pig_cmd = [pig_bin]
-
- if self.pig_properties:
- pig_cmd.extend(self.pig_properties)
- if pig_opts:
- pig_opts_list = pig_opts.split()
- pig_cmd.extend(pig_opts_list)
+ with TemporaryDirectory(prefix="airflow_pigop_") as tmp_dir,
NamedTemporaryFile(dir=tmp_dir) as f:
+ f.write(pig.encode("utf-8"))
+ f.flush()
+ fname = f.name
+ pig_bin = "pig"
+ cmd_extra: list[str] = []
+
+ pig_cmd = [pig_bin]
+
+ if self.pig_properties:
+ pig_cmd.extend(self.pig_properties)
+ if pig_opts:
+ pig_opts_list = pig_opts.split()
+ pig_cmd.extend(pig_opts_list)
+
+ pig_cmd.extend(["-f", fname] + cmd_extra)
+
+ if verbose:
+ self.log.info("%s", " ".join(pig_cmd))
+ sub_process: Any = subprocess.Popen(
+ pig_cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT,
cwd=tmp_dir, close_fds=True
+ )
+ self.sub_process = sub_process
+ stdout = ""
+ for line in iter(sub_process.stdout.readline, b""):
+ stdout += line.decode("utf-8")
+ if verbose:
+ self.log.info(line.strip())
+ sub_process.wait()
- pig_cmd.extend(["-f", fname] + cmd_extra)
+ if sub_process.returncode:
+ raise AirflowException(stdout)
- if verbose:
- self.log.info("%s", " ".join(pig_cmd))
- sub_process: Any = subprocess.Popen(
- pig_cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT,
cwd=tmp_dir, close_fds=True
- )
- self.sub_process = sub_process
- stdout = ""
- for line in iter(sub_process.stdout.readline, b""):
- stdout += line.decode("utf-8")
- if verbose:
- self.log.info(line.strip())
- sub_process.wait()
-
- if sub_process.returncode:
- raise AirflowException(stdout)
-
- return stdout
+ return stdout
def kill(self) -> None:
"""Kill Pig job."""
diff --git a/airflow/providers/dbt/cloud/hooks/dbt.py
b/airflow/providers/dbt/cloud/hooks/dbt.py
index 4a9785da3e..72446efecc 100644
--- a/airflow/providers/dbt/cloud/hooks/dbt.py
+++ b/airflow/providers/dbt/cloud/hooks/dbt.py
@@ -234,13 +234,14 @@ class DbtCloudHook(HttpHook):
endpoint = f"{account_id}/runs/{run_id}/"
headers, tenant = await self.get_headers_tenants_from_connection()
url, params = self.get_request_url_params(tenant, endpoint,
include_related)
- async with aiohttp.ClientSession(headers=headers) as session:
- async with session.get(url, params=params) as response:
- try:
- response.raise_for_status()
- return await response.json()
- except ClientResponseError as e:
- raise AirflowException(str(e.status) + ":" + e.message)
+ async with aiohttp.ClientSession(headers=headers) as session,
session.get(
+ url, params=params
+ ) as response:
+ try:
+ response.raise_for_status()
+ return await response.json()
+ except ClientResponseError as e:
+ raise AirflowException(f"{e.status}:{e.message}")
async def get_job_status(
self, run_id: int, account_id: int | None = None, include_related:
list[str] | None = None
diff --git a/airflow/providers/exasol/hooks/exasol.py
b/airflow/providers/exasol/hooks/exasol.py
index ed71205ebc..ffadf46072 100644
--- a/airflow/providers/exasol/hooks/exasol.py
+++ b/airflow/providers/exasol/hooks/exasol.py
@@ -97,9 +97,8 @@ class ExasolHook(DbApiHook):
sql statements to execute
:param parameters: The parameters to render the SQL query with.
"""
- with closing(self.get_conn()) as conn:
- with closing(conn.execute(sql, parameters)) as cur:
- return cur.fetchall()
+ with closing(self.get_conn()) as conn, closing(conn.execute(sql,
parameters)) as cur:
+ return cur.fetchall()
def get_first(self, sql: str | list[str], parameters: Iterable |
Mapping[str, Any] | None = None) -> Any:
"""Execute the SQL and return the first resulting row.
@@ -108,9 +107,8 @@ class ExasolHook(DbApiHook):
sql statements to execute
:param parameters: The parameters to render the SQL query with.
"""
- with closing(self.get_conn()) as conn:
- with closing(conn.execute(sql, parameters)) as cur:
- return cur.fetchone()
+ with closing(self.get_conn()) as conn, closing(conn.execute(sql,
parameters)) as cur:
+ return cur.fetchone()
def export_to_file(
self,
diff --git a/airflow/providers/google/cloud/hooks/gcs.py
b/airflow/providers/google/cloud/hooks/gcs.py
index f489c9200b..72c555bbda 100644
--- a/airflow/providers/google/cloud/hooks/gcs.py
+++ b/airflow/providers/google/cloud/hooks/gcs.py
@@ -550,10 +550,9 @@ class GCSHook(GoogleBaseHook):
if gzip:
filename_gz = filename + ".gz"
- with open(filename, "rb") as f_in:
- with gz.open(filename_gz, "wb") as f_out:
- shutil.copyfileobj(f_in, f_out)
- filename = filename_gz
+ with open(filename, "rb") as f_in, gz.open(filename_gz, "wb")
as f_out:
+ shutil.copyfileobj(f_in, f_out)
+ filename = filename_gz
_call_with_retry(
partial(blob.upload_from_filename, filename=filename,
content_type=mime_type, timeout=timeout)
diff --git a/airflow/providers/google/cloud/hooks/kubernetes_engine.py
b/airflow/providers/google/cloud/hooks/kubernetes_engine.py
index 00df2b9b28..1837613ea0 100644
--- a/airflow/providers/google/cloud/hooks/kubernetes_engine.py
+++ b/airflow/providers/google/cloud/hooks/kubernetes_engine.py
@@ -493,19 +493,18 @@ class GKEPodAsyncHook(GoogleBaseAsyncHook):
:param name: Name of the pod.
:param namespace: Name of the pod's namespace.
"""
- async with Token(scopes=self.scopes) as token:
- async with self.get_conn(token) as connection:
- try:
- v1_api = async_client.CoreV1Api(connection)
- await v1_api.delete_namespaced_pod(
- name=name,
- namespace=namespace,
- body=client.V1DeleteOptions(),
- )
- except async_client.ApiException as e:
- # If the pod is already deleted
- if e.status != 404:
- raise
+ async with Token(scopes=self.scopes) as token, self.get_conn(token) as
connection:
+ try:
+ v1_api = async_client.CoreV1Api(connection)
+ await v1_api.delete_namespaced_pod(
+ name=name,
+ namespace=namespace,
+ body=client.V1DeleteOptions(),
+ )
+ except async_client.ApiException as e:
+ # If the pod is already deleted
+ if e.status != 404:
+ raise
async def read_logs(self, name: str, namespace: str):
"""Read logs inside the pod while starting containers inside.
@@ -518,20 +517,19 @@ class GKEPodAsyncHook(GoogleBaseAsyncHook):
:param name: Name of the pod.
:param namespace: Name of the pod's namespace.
"""
- async with Token(scopes=self.scopes) as token:
- async with self.get_conn(token) as connection:
- try:
- v1_api = async_client.CoreV1Api(connection)
- logs = await v1_api.read_namespaced_pod_log(
- name=name,
- namespace=namespace,
- follow=False,
- timestamps=True,
- )
- logs = logs.splitlines()
- for line in logs:
- self.log.info("Container logs from %s", line)
- return logs
- except HTTPError:
- self.log.exception("There was an error reading the
kubernetes API.")
- raise
+ async with Token(scopes=self.scopes) as token, self.get_conn(token) as
connection:
+ try:
+ v1_api = async_client.CoreV1Api(connection)
+ logs = await v1_api.read_namespaced_pod_log(
+ name=name,
+ namespace=namespace,
+ follow=False,
+ timestamps=True,
+ )
+ logs = logs.splitlines()
+ for line in logs:
+ self.log.info("Container logs from %s", line)
+ return logs
+ except HTTPError:
+ self.log.exception("There was an error reading the kubernetes
API.")
+ raise
diff --git a/airflow/providers/microsoft/azure/hooks/asb.py
b/airflow/providers/microsoft/azure/hooks/asb.py
index 7001db46f9..80273d6f96 100644
--- a/airflow/providers/microsoft/azure/hooks/asb.py
+++ b/airflow/providers/microsoft/azure/hooks/asb.py
@@ -215,19 +215,18 @@ class MessageHook(BaseAzureServiceBusHook):
raise ValueError("Messages list cannot be empty.")
with self.get_conn() as service_bus_client,
service_bus_client.get_queue_sender(
queue_name=queue_name
- ) as sender:
- with sender:
- if isinstance(messages, str):
- if not batch_message_flag:
- msg = ServiceBusMessage(messages)
- sender.send_messages(msg)
- else:
- self.send_batch_message(sender, [messages])
+ ) as sender, sender:
+ if isinstance(messages, str):
+ if not batch_message_flag:
+ msg = ServiceBusMessage(messages)
+ sender.send_messages(msg)
else:
- if not batch_message_flag:
- self.send_list_messages(sender, messages)
- else:
- self.send_batch_message(sender, messages)
+ self.send_batch_message(sender, [messages])
+ else:
+ if not batch_message_flag:
+ self.send_list_messages(sender, messages)
+ else:
+ self.send_batch_message(sender, messages)
@staticmethod
def send_list_messages(sender: ServiceBusSender, messages: list[str]):
@@ -256,14 +255,13 @@ class MessageHook(BaseAzureServiceBusHook):
with self.get_conn() as service_bus_client,
service_bus_client.get_queue_receiver(
queue_name=queue_name
- ) as receiver:
- with receiver:
- received_msgs = receiver.receive_messages(
- max_message_count=max_message_count,
max_wait_time=max_wait_time
- )
- for msg in received_msgs:
- self.log.info(msg)
- receiver.complete_message(msg)
+ ) as receiver, receiver:
+ received_msgs = receiver.receive_messages(
+ max_message_count=max_message_count,
max_wait_time=max_wait_time
+ )
+ for msg in received_msgs:
+ self.log.info(msg)
+ receiver.complete_message(msg)
def receive_subscription_message(
self,
@@ -293,11 +291,10 @@ class MessageHook(BaseAzureServiceBusHook):
raise TypeError("Topic name cannot be None.")
with self.get_conn() as service_bus_client,
service_bus_client.get_subscription_receiver(
topic_name, subscription_name
- ) as subscription_receiver:
- with subscription_receiver:
- received_msgs = subscription_receiver.receive_messages(
- max_message_count=max_message_count,
max_wait_time=max_wait_time
- )
- for msg in received_msgs:
- self.log.info(msg)
- subscription_receiver.complete_message(msg)
+ ) as subscription_receiver, subscription_receiver:
+ received_msgs = subscription_receiver.receive_messages(
+ max_message_count=max_message_count,
max_wait_time=max_wait_time
+ )
+ for msg in received_msgs:
+ self.log.info(msg)
+ subscription_receiver.complete_message(msg)
diff --git a/airflow/providers/mysql/transfers/vertica_to_mysql.py
b/airflow/providers/mysql/transfers/vertica_to_mysql.py
index 16be186fc7..fd196315d7 100644
--- a/airflow/providers/mysql/transfers/vertica_to_mysql.py
+++ b/airflow/providers/mysql/transfers/vertica_to_mysql.py
@@ -99,17 +99,16 @@ class VerticaToMySqlOperator(BaseOperator):
self.log.info("Done")
def _non_bulk_load_transfer(self, mysql, vertica):
- with closing(vertica.get_conn()) as conn:
- with closing(conn.cursor()) as cursor:
- cursor.execute(self.sql)
- selected_columns = [d.name for d in cursor.description]
- self.log.info("Selecting rows from Vertica...")
- self.log.info(self.sql)
+ with closing(vertica.get_conn()) as conn, closing(conn.cursor()) as
cursor:
+ cursor.execute(self.sql)
+ selected_columns = [d.name for d in cursor.description]
+ self.log.info("Selecting rows from Vertica...")
+ self.log.info(self.sql)
- result = cursor.fetchall()
- count = len(result)
+ result = cursor.fetchall()
+ count = len(result)
- self.log.info("Selected rows from Vertica %s", count)
+ self.log.info("Selected rows from Vertica %s", count)
self._run_preoperator(mysql)
try:
self.log.info("Inserting rows into MySQL...")
@@ -121,31 +120,29 @@ class VerticaToMySqlOperator(BaseOperator):
def _bulk_load_transfer(self, mysql, vertica):
count = 0
- with closing(vertica.get_conn()) as conn:
- with closing(conn.cursor()) as cursor:
- cursor.execute(self.sql)
- selected_columns = [d.name for d in cursor.description]
- with NamedTemporaryFile("w", encoding="utf-8") as tmpfile:
- self.log.info("Selecting rows from Vertica to local file
%s...", tmpfile.name)
- self.log.info(self.sql)
-
- csv_writer = csv.writer(tmpfile, delimiter="\t")
- for row in cursor.iterate():
- csv_writer.writerow(row)
- count += 1
-
- tmpfile.flush()
+ with closing(vertica.get_conn()) as conn, closing(conn.cursor()) as
cursor:
+ cursor.execute(self.sql)
+ selected_columns = [d.name for d in cursor.description]
+ with NamedTemporaryFile("w", encoding="utf-8") as tmpfile:
+ self.log.info("Selecting rows from Vertica to local file
%s...", tmpfile.name)
+ self.log.info(self.sql)
+
+ csv_writer = csv.writer(tmpfile, delimiter="\t")
+ for row in cursor.iterate():
+ csv_writer.writerow(row)
+ count += 1
+
+ tmpfile.flush()
self._run_preoperator(mysql)
try:
self.log.info("Bulk inserting rows into MySQL...")
- with closing(mysql.get_conn()) as conn:
- with closing(conn.cursor()) as cursor:
- cursor.execute(
- f"LOAD DATA LOCAL INFILE '{tmpfile.name}' "
- f"INTO TABLE {self.mysql_table} "
- f"LINES TERMINATED BY '\r\n' ({',
'.join(selected_columns)})"
- )
- conn.commit()
+ with closing(mysql.get_conn()) as conn, closing(conn.cursor()) as
cursor:
+ cursor.execute(
+ f"LOAD DATA LOCAL INFILE '{tmpfile.name}' "
+ f"INTO TABLE {self.mysql_table} "
+ f"LINES TERMINATED BY '\r\n' ({',
'.join(selected_columns)})"
+ )
+ conn.commit()
tmpfile.close()
self.log.info("Inserted rows into MySQL %s", count)
except (MySQLdb.Error, MySQLdb.Warning):
diff --git a/airflow/providers/postgres/hooks/postgres.py
b/airflow/providers/postgres/hooks/postgres.py
index b6b214e990..95e7be94cb 100644
--- a/airflow/providers/postgres/hooks/postgres.py
+++ b/airflow/providers/postgres/hooks/postgres.py
@@ -170,12 +170,10 @@ class PostgresHook(DbApiHook):
with open(filename, "w"):
pass
- with open(filename, "r+") as file:
- with closing(self.get_conn()) as conn:
- with closing(conn.cursor()) as cur:
- cur.copy_expert(sql, file)
- file.truncate(file.tell())
- conn.commit()
+ with open(filename, "r+") as file, closing(self.get_conn()) as conn,
closing(conn.cursor()) as cur:
+ cur.copy_expert(sql, file)
+ file.truncate(file.tell())
+ conn.commit()
def get_uri(self) -> str:
"""Extract the URI from the connection.
diff --git a/airflow/providers/snowflake/hooks/snowflake_sql_api.py
b/airflow/providers/snowflake/hooks/snowflake_sql_api.py
index b10a3c670f..49cedf115a 100644
--- a/airflow/providers/snowflake/hooks/snowflake_sql_api.py
+++ b/airflow/providers/snowflake/hooks/snowflake_sql_api.py
@@ -271,8 +271,9 @@ class SnowflakeSqlApiHook(SnowflakeHook):
"""
self.log.info("Retrieving status for query id %s", query_id)
header, params, url = self.get_request_url_header_params(query_id)
- async with aiohttp.ClientSession(headers=header) as session:
- async with session.get(url, params=params) as response:
- status_code = response.status
- resp = await response.json()
- return self._process_response(status_code, resp)
+ async with aiohttp.ClientSession(headers=header) as session,
session.get(
+ url, params=params
+ ) as response:
+ status_code = response.status
+ resp = await response.json()
+ return self._process_response(status_code, resp)