This is an automated email from the ASF dual-hosted git repository.
potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new ff43e6ebe1c Update HiveToDynamoDBOperator to support Polars (#54221)
ff43e6ebe1c is described below
commit ff43e6ebe1c42f160684bc449d89c810865cc57f
Author: Guan Ming(Wesley) Chiu <[email protected]>
AuthorDate: Tue Aug 12 19:21:32 2025 +0800
Update HiveToDynamoDBOperator to support Polars (#54221)
---
.../amazon/aws/transfers/hive_to_dynamodb.py | 12 +++++--
.../amazon/aws/transfers/test_hive_to_dynamodb.py | 42 ++++++++++++++++++++++
2 files changed, 51 insertions(+), 3 deletions(-)
diff --git a/providers/amazon/src/airflow/providers/amazon/aws/transfers/hive_to_dynamodb.py b/providers/amazon/src/airflow/providers/amazon/aws/transfers/hive_to_dynamodb.py
index 1cd5e4ce22d..b8ab98ba21d 100644
--- a/providers/amazon/src/airflow/providers/amazon/aws/transfers/hive_to_dynamodb.py
+++ b/providers/amazon/src/airflow/providers/amazon/aws/transfers/hive_to_dynamodb.py
@@ -21,7 +21,7 @@ from __future__ import annotations
import json
from collections.abc import Callable, Sequence
-from typing import TYPE_CHECKING
+from typing import TYPE_CHECKING, Literal
from airflow.providers.amazon.aws.hooks.dynamodb import DynamoDBHook
from airflow.providers.amazon.version_compat import BaseOperator
@@ -53,6 +53,7 @@ class HiveToDynamoDBOperator(BaseOperator):
:param hiveserver2_conn_id: Reference to the
:ref: `Hive Server2 thrift service connection id
<howto/connection:hiveserver2>`.
:param aws_conn_id: aws connection
+ :param df_type: DataFrame type to use ("pandas" or "polars").
"""
template_fields: Sequence[str] = ("sql",)
@@ -73,6 +74,7 @@ class HiveToDynamoDBOperator(BaseOperator):
schema: str = "default",
hiveserver2_conn_id: str = "hiveserver2_default",
aws_conn_id: str | None = "aws_default",
+ df_type: Literal["pandas", "polars"] = "pandas",
**kwargs,
) -> None:
super().__init__(**kwargs)
@@ -86,6 +88,7 @@ class HiveToDynamoDBOperator(BaseOperator):
self.schema = schema
self.hiveserver2_conn_id = hiveserver2_conn_id
self.aws_conn_id = aws_conn_id
+ self.df_type = df_type
def execute(self, context: Context):
hive = HiveServer2Hook(hiveserver2_conn_id=self.hiveserver2_conn_id)
@@ -93,7 +96,7 @@ class HiveToDynamoDBOperator(BaseOperator):
self.log.info("Extracting data from Hive")
self.log.info(self.sql)
- data = hive.get_df(self.sql, schema=self.schema, df_type="pandas")
+ data = hive.get_df(self.sql, schema=self.schema, df_type=self.df_type)
dynamodb = DynamoDBHook(
aws_conn_id=self.aws_conn_id,
table_name=self.table_name,
@@ -104,7 +107,10 @@ class HiveToDynamoDBOperator(BaseOperator):
self.log.info("Inserting rows into dynamodb")
if self.pre_process is None:
- dynamodb.write_batch_data(json.loads(data.to_json(orient="records")))
+ if self.df_type == "polars":
+ dynamodb.write_batch_data(data.to_dicts())  # type: ignore[operator]
+ elif self.df_type == "pandas":
+ dynamodb.write_batch_data(json.loads(data.to_json(orient="records")))  # type: ignore[union-attr]
else:
dynamodb.write_batch_data(
self.pre_process(data=data, args=self.pre_process_args, kwargs=self.pre_process_kwargs)
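Both branches above hand write_batch_data the same payload shape, a list of row dicts. A standalone sketch (not part of the diff) illustrating the equivalence of the pandas and Polars paths:

    import json
    import pandas as pd
    import polars as pl

    rows = [{"id": "1", "name": "sid"}]
    # pandas path: serialize to JSON records, then parse back into dicts
    pandas_payload = json.loads(pd.DataFrame(rows).to_json(orient="records"))
    # polars path: to_dicts() yields the list of row dicts directly
    polars_payload = pl.DataFrame(rows).to_dicts()
    assert pandas_payload == polars_payload == rows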
diff --git a/providers/amazon/tests/unit/amazon/aws/transfers/test_hive_to_dynamodb.py b/providers/amazon/tests/unit/amazon/aws/transfers/test_hive_to_dynamodb.py
index 82183a44434..4249ce093a6 100644
--- a/providers/amazon/tests/unit/amazon/aws/transfers/test_hive_to_dynamodb.py
+++ b/providers/amazon/tests/unit/amazon/aws/transfers/test_hive_to_dynamodb.py
@@ -22,6 +22,8 @@ import json
from unittest import mock
import pandas as pd
+import polars as pl
+import pytest
from moto import mock_aws
import airflow.providers.amazon.aws.transfers.hive_to_dynamodb
@@ -110,3 +112,43 @@ class TestHiveToDynamoDBOperator:
table = self.hook.get_conn().Table("test_airflow")
table.meta.client.get_waiter("table_exists").wait(TableName="test_airflow")
assert table.item_count == 1
+
+ @pytest.mark.parametrize("df_type", ["pandas", "polars"])
+ @mock_aws
+ def test_df_type_parameter(self, df_type):
+ if df_type == "polars" and pl is None:
+ pytest.skip("Polars not installed")
+
+ if df_type == "pandas":
+ test_df = pd.DataFrame(data=[("1", "sid")], columns=["id", "name"])
+ else:
+ test_df = pl.DataFrame({"id": ["1"], "name": ["sid"]})
+
+ with mock.patch(
+ "airflow.providers.apache.hive.hooks.hive.HiveServer2Hook.get_df",
+ return_value=test_df,
+ ) as mock_get_df:
+ self.hook.get_conn().create_table(
+ TableName="test_airflow",
+ KeySchema=[
+ {"AttributeName": "id", "KeyType": "HASH"},
+ ],
+ AttributeDefinitions=[{"AttributeName": "id", "AttributeType": "S"}],
+ ProvisionedThroughput={"ReadCapacityUnits": 10, "WriteCapacityUnits": 10},
+ )
+
+ operator = airflow.providers.amazon.aws.transfers.hive_to_dynamodb.HiveToDynamoDBOperator(
+ sql=self.sql,
+ table_name="test_airflow",
+ task_id="hive_to_dynamodb_check",
+ table_keys=["id"],
+ df_type=df_type,
+ dag=self.dag,
+ )
+
+ operator.execute(None)
+ mock_get_df.assert_called_once_with(self.sql, schema="default", df_type=df_type)
+
+ table = self.hook.get_conn().Table("test_airflow")
+ table.meta.client.get_waiter("table_exists").wait(TableName="test_airflow")
+ assert table.item_count == 1
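For reference, a minimal usage sketch of the new parameter once this change is released; the SQL, table names and connection IDs below are placeholders, not part of the commit:

    from airflow.providers.amazon.aws.transfers.hive_to_dynamodb import HiveToDynamoDBOperator

    hive_to_dynamodb = HiveToDynamoDBOperator(
        task_id="hive_to_dynamodb",
        sql="SELECT id, name FROM example_hive_table",  # placeholder query
        table_name="example_dynamodb_table",            # placeholder DynamoDB table
        table_keys=["id"],
        hiveserver2_conn_id="hiveserver2_default",
        aws_conn_id="aws_default",
        df_type="polars",  # parameter added by this commit; defaults to "pandas"
    )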