This is an automated email from the ASF dual-hosted git repository.

potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new ff43e6ebe1c Update HiveToDynamoDBOperator to support Polars (#54221)
ff43e6ebe1c is described below

commit ff43e6ebe1c42f160684bc449d89c810865cc57f
Author: Guan Ming(Wesley) Chiu <[email protected]>
AuthorDate: Tue Aug 12 19:21:32 2025 +0800

    Update HiveToDynamoDBOperator to support Polars (#54221)
---
 .../amazon/aws/transfers/hive_to_dynamodb.py       | 12 +++++--
 .../amazon/aws/transfers/test_hive_to_dynamodb.py  | 42 ++++++++++++++++++++++
 2 files changed, 51 insertions(+), 3 deletions(-)

diff --git a/providers/amazon/src/airflow/providers/amazon/aws/transfers/hive_to_dynamodb.py b/providers/amazon/src/airflow/providers/amazon/aws/transfers/hive_to_dynamodb.py
index 1cd5e4ce22d..b8ab98ba21d 100644
--- a/providers/amazon/src/airflow/providers/amazon/aws/transfers/hive_to_dynamodb.py
+++ b/providers/amazon/src/airflow/providers/amazon/aws/transfers/hive_to_dynamodb.py
@@ -21,7 +21,7 @@ from __future__ import annotations
 
 import json
 from collections.abc import Callable, Sequence
-from typing import TYPE_CHECKING
+from typing import TYPE_CHECKING, Literal
 
 from airflow.providers.amazon.aws.hooks.dynamodb import DynamoDBHook
 from airflow.providers.amazon.version_compat import BaseOperator
@@ -53,6 +53,7 @@ class HiveToDynamoDBOperator(BaseOperator):
     :param hiveserver2_conn_id: Reference to the
         :ref: `Hive Server2 thrift service connection id 
<howto/connection:hiveserver2>`.
     :param aws_conn_id: aws connection
+    :param df_type: DataFrame type to use ("pandas" or "polars").
     """
 
     template_fields: Sequence[str] = ("sql",)
@@ -73,6 +74,7 @@ class HiveToDynamoDBOperator(BaseOperator):
         schema: str = "default",
         hiveserver2_conn_id: str = "hiveserver2_default",
         aws_conn_id: str | None = "aws_default",
+        df_type: Literal["pandas", "polars"] = "pandas",
         **kwargs,
     ) -> None:
         super().__init__(**kwargs)
@@ -86,6 +88,7 @@ class HiveToDynamoDBOperator(BaseOperator):
         self.schema = schema
         self.hiveserver2_conn_id = hiveserver2_conn_id
         self.aws_conn_id = aws_conn_id
+        self.df_type = df_type
 
     def execute(self, context: Context):
         hive = HiveServer2Hook(hiveserver2_conn_id=self.hiveserver2_conn_id)
@@ -93,7 +96,7 @@ class HiveToDynamoDBOperator(BaseOperator):
         self.log.info("Extracting data from Hive")
         self.log.info(self.sql)
 
-        data = hive.get_df(self.sql, schema=self.schema, df_type="pandas")
+        data = hive.get_df(self.sql, schema=self.schema, df_type=self.df_type)
         dynamodb = DynamoDBHook(
             aws_conn_id=self.aws_conn_id,
             table_name=self.table_name,
@@ -104,7 +107,10 @@ class HiveToDynamoDBOperator(BaseOperator):
         self.log.info("Inserting rows into dynamodb")
 
         if self.pre_process is None:
-            
dynamodb.write_batch_data(json.loads(data.to_json(orient="records")))
+            if self.df_type == "polars":
+                dynamodb.write_batch_data(data.to_dicts())  # 
type:ignore[operator]
+            elif self.df_type == "pandas":
+                
dynamodb.write_batch_data(json.loads(data.to_json(orient="records")))  # 
type:ignore[union-attr]
         else:
             dynamodb.write_batch_data(
                 self.pre_process(data=data, args=self.pre_process_args, 
kwargs=self.pre_process_kwargs)
diff --git a/providers/amazon/tests/unit/amazon/aws/transfers/test_hive_to_dynamodb.py b/providers/amazon/tests/unit/amazon/aws/transfers/test_hive_to_dynamodb.py
index 82183a44434..4249ce093a6 100644
--- a/providers/amazon/tests/unit/amazon/aws/transfers/test_hive_to_dynamodb.py
+++ b/providers/amazon/tests/unit/amazon/aws/transfers/test_hive_to_dynamodb.py
@@ -22,6 +22,8 @@ import json
 from unittest import mock
 
 import pandas as pd
+import polars as pl
+import pytest
 from moto import mock_aws
 
 import airflow.providers.amazon.aws.transfers.hive_to_dynamodb
@@ -110,3 +112,43 @@ class TestHiveToDynamoDBOperator:
         table = self.hook.get_conn().Table("test_airflow")
         
table.meta.client.get_waiter("table_exists").wait(TableName="test_airflow")
         assert table.item_count == 1
+
+    @pytest.mark.parametrize("df_type", ["pandas", "polars"])  # run once per supported DataFrame backend
+    @mock_aws
+    def test_df_type_parameter(self, df_type):  # checks df_type reaches get_df and the frame's row is written to DynamoDB
+        if df_type == "polars" and pl is None:  # NOTE(review): likely unreachable - `pl` comes from an unconditional top-level `import polars as pl`; consider pytest.importorskip("polars")
+            pytest.skip("Polars not installed")
+
+        if df_type == "pandas":  # build a one-row frame of the requested type
+            test_df = pd.DataFrame(data=[("1", "sid")], columns=["id", "name"])
+        else:
+            test_df = pl.DataFrame({"id": ["1"], "name": ["sid"]})
+
+        with mock.patch(  # stub HiveServer2Hook.get_df so the operator receives test_df instead of querying Hive
+            "airflow.providers.apache.hive.hooks.hive.HiveServer2Hook.get_df",
+            return_value=test_df,
+        ) as mock_get_df:
+            self.hook.get_conn().create_table(  # target table in moto's mocked DynamoDB (enabled by @mock_aws)
+                TableName="test_airflow",
+                KeySchema=[
+                    {"AttributeName": "id", "KeyType": "HASH"},
+                ],
+                AttributeDefinitions=[{"AttributeName": "id", "AttributeType": 
"S"}],
+                ProvisionedThroughput={"ReadCapacityUnits": 10, 
"WriteCapacityUnits": 10},
+            )
+
+            operator = 
airflow.providers.amazon.aws.transfers.hive_to_dynamodb.HiveToDynamoDBOperator(
+                sql=self.sql,
+                table_name="test_airflow",
+                task_id="hive_to_dynamodb_check",
+                table_keys=["id"],
+                df_type=df_type,
+                dag=self.dag,
+            )
+
+            operator.execute(None)
+            mock_get_df.assert_called_once_with(self.sql, schema="default", 
df_type=df_type)  # df_type must be forwarded to the hook unchanged
+
+            table = self.hook.get_conn().Table("test_airflow")
+            
table.meta.client.get_waiter("table_exists").wait(TableName="test_airflow")
+            assert table.item_count == 1  # the single row from test_df was written

Reply via email to