o-nikolas commented on code in PR #55694:
URL: https://github.com/apache/airflow/pull/55694#discussion_r2356101420


##########
providers/amazon/src/airflow/providers/amazon/aws/hooks/athena_sql.py:
##########
@@ -65,8 +65,41 @@ class AthenaSQLHook(AwsBaseHook, DbApiHook):
     hook_name = "Amazon Athena"
     supports_autocommit = True
 
-    def __init__(self, athena_conn_id: str = default_conn_name, *args, 
**kwargs) -> None:
-        super().__init__(*args, **kwargs)
+    def __init__(
+        self,
+        athena_conn_id: str | None = None,  # keep positional compatibility
+        *,
+        s3_staging_dir: str | None = None,
+        work_group: str | None = None,
+        driver: str | None = None,
+        aws_domain: str | None = None,
+        session_kwargs: dict | None = None,
+        config_kwargs: dict | None = None,
+        role_arn: str | None = None,
+        assume_role_method: str | None = None,
+        assume_role_kwargs: dict | None = None,
+        aws_session_token: str | None = None,
+        endpoint_url: str | None = None,
+        **kwargs,
+    ) -> None:
+        # prefer explicit arg; fall back to kwargs; finally default
+        if athena_conn_id is None:
+            athena_conn_id = kwargs.pop("athena_conn_id", 
self.default_conn_name)
+        else:
+            kwargs.pop("athena_conn_id", None)  # avoid conflicts
+        super().__init__(**kwargs)
+        # Store explicit params on self

Review Comment:
   These types of comments (that often come from AI) aren't very useful in the 
final code.



##########
providers/amazon/tests/unit/amazon/aws/hooks/test_athena_sql.py:
##########
@@ -153,3 +163,123 @@ def test_conn_id_override_setter(self):
         hook = AthenaSQLHook(athena_conn_id=AWS_ATHENA_CONN_ID, 
aws_conn_id=AWS_CONN_ID)
         assert hook.athena_conn_id == AWS_ATHENA_CONN_ID
         assert hook.aws_conn_id == AWS_CONN_ID
+
+    def test_hook_params_handling(self):
+        """Test that hook_params are properly handled and don't cause 
TypeError."""
+        # Test that hook_params with Athena-specific parameters don't cause 
errors

Review Comment:
   You don't need both of these



##########
providers/amazon/tests/unit/amazon/aws/hooks/test_athena_sql.py:
##########
@@ -153,3 +163,123 @@ def test_conn_id_override_setter(self):
         hook = AthenaSQLHook(athena_conn_id=AWS_ATHENA_CONN_ID, 
aws_conn_id=AWS_CONN_ID)
         assert hook.athena_conn_id == AWS_ATHENA_CONN_ID
         assert hook.aws_conn_id == AWS_CONN_ID
+
+    def test_hook_params_handling(self):
+        """Test that hook_params are properly handled and don't cause 
TypeError."""
+        # Test that hook_params with Athena-specific parameters don't cause 
errors
+        hook = AthenaSQLHook(
+            athena_conn_id="test_conn",
+            s3_staging_dir="s3://test-bucket/staging/",
+            work_group="test-workgroup",
+            driver="rest",
+            aws_domain="amazonaws.com",
+            session_kwargs={"profile_name": "test"},
+            config_kwargs={"retries": {"max_attempts": 5}},
+            role_arn="arn:aws:iam::123456789012:role/test-role",
+            assume_role_method="assume_role",
+            assume_role_kwargs={"RoleSessionName": "airflow-test"},
+            aws_session_token="test-token",
+            endpoint_url="https://athena.us-east-1.amazonaws.com",
+        )
+
+        # Verify that the parameters were extracted correctly
+        assert hook.s3_staging_dir == "s3://test-bucket/staging/"
+        assert hook.work_group == "test-workgroup"
+        assert hook.driver == "rest"
+        assert hook.aws_domain == "amazonaws.com"
+        assert hook.session_kwargs == {"profile_name": "test"}
+        assert hook.config_kwargs == {"retries": {"max_attempts": 5}}
+        assert hook.role_arn == "arn:aws:iam::123456789012:role/test-role"
+        assert hook.assume_role_method == "assume_role"
+        assert hook.assume_role_kwargs == {"RoleSessionName": "airflow-test"}
+        assert hook.aws_session_token == "test-token"
+        assert hook.endpoint_url == "https://athena.us-east-1.amazonaws.com"
+
+    
@mock.patch("airflow.providers.amazon.aws.hooks.athena_sql.pyathena.connect")
+    
@mock.patch("airflow.providers.amazon.aws.hooks.athena_sql.AthenaSQLHook.get_session")
+    def test_get_conn_with_hook_params(self, mock_get_session, mock_connect):
+        """Test that get_conn uses hook_params when provided."""
+        # Create hook with hook_params
+        hook = AthenaSQLHook(
+            athena_conn_id="test_conn",
+            s3_staging_dir="s3://test-bucket/staging/",
+            work_group="test-workgroup",
+        )
+
+        # Mock the connection
+        conn = Connection(
+            conn_type="athena",
+            schema="test_schema",
+            extra={"region_name": "us-east-1"},
+        )
+        hook.get_connection = mock.Mock(return_value=conn)
+
+        # Call get_conn
+        hook.get_conn()

Review Comment:
   I love comments just as much as the next person (if not more even), but 
these kinds of comments are really unnecessary. 



##########
providers/amazon/src/airflow/providers/amazon/aws/hooks/athena_sql.py:
##########
@@ -65,8 +65,41 @@ class AthenaSQLHook(AwsBaseHook, DbApiHook):
     hook_name = "Amazon Athena"
     supports_autocommit = True
 
-    def __init__(self, athena_conn_id: str = default_conn_name, *args, 
**kwargs) -> None:
-        super().__init__(*args, **kwargs)
+    def __init__(
+        self,
+        athena_conn_id: str | None = None,  # keep positional compatibility
+        *,
+        s3_staging_dir: str | None = None,
+        work_group: str | None = None,
+        driver: str | None = None,
+        aws_domain: str | None = None,
+        session_kwargs: dict | None = None,
+        config_kwargs: dict | None = None,
+        role_arn: str | None = None,
+        assume_role_method: str | None = None,
+        assume_role_kwargs: dict | None = None,
+        aws_session_token: str | None = None,
+        endpoint_url: str | None = None,
+        **kwargs,
+    ) -> None:
+        # prefer explicit arg; fall back to kwargs; finally default
+        if athena_conn_id is None:
+            athena_conn_id = kwargs.pop("athena_conn_id", 
self.default_conn_name)
+        else:
+            kwargs.pop("athena_conn_id", None)  # avoid conflicts
+        super().__init__(**kwargs)
+        # Store explicit params on self

Review Comment:
   Btw, why did we flip back to setting all of these on the init constructor vs 
fetching them from kwargs? Didn't you say in a previous message that you were 
going to keep the kwargs pop?



##########
providers/amazon/tests/unit/amazon/aws/hooks/test_athena_sql.py:
##########
@@ -153,3 +163,123 @@ def test_conn_id_override_setter(self):
         hook = AthenaSQLHook(athena_conn_id=AWS_ATHENA_CONN_ID, 
aws_conn_id=AWS_CONN_ID)
         assert hook.athena_conn_id == AWS_ATHENA_CONN_ID
         assert hook.aws_conn_id == AWS_CONN_ID
+
+    def test_hook_params_handling(self):
+        """Test that hook_params are properly handled and don't cause 
TypeError."""
+        # Test that hook_params with Athena-specific parameters don't cause 
errors
+        hook = AthenaSQLHook(
+            athena_conn_id="test_conn",
+            s3_staging_dir="s3://test-bucket/staging/",
+            work_group="test-workgroup",
+            driver="rest",
+            aws_domain="amazonaws.com",
+            session_kwargs={"profile_name": "test"},
+            config_kwargs={"retries": {"max_attempts": 5}},
+            role_arn="arn:aws:iam::123456789012:role/test-role",
+            assume_role_method="assume_role",
+            assume_role_kwargs={"RoleSessionName": "airflow-test"},
+            aws_session_token="test-token",
+            endpoint_url="https://athena.us-east-1.amazonaws.com",
+        )
+
+        # Verify that the parameters were extracted correctly
+        assert hook.s3_staging_dir == "s3://test-bucket/staging/"
+        assert hook.work_group == "test-workgroup"
+        assert hook.driver == "rest"
+        assert hook.aws_domain == "amazonaws.com"
+        assert hook.session_kwargs == {"profile_name": "test"}
+        assert hook.config_kwargs == {"retries": {"max_attempts": 5}}
+        assert hook.role_arn == "arn:aws:iam::123456789012:role/test-role"
+        assert hook.assume_role_method == "assume_role"
+        assert hook.assume_role_kwargs == {"RoleSessionName": "airflow-test"}
+        assert hook.aws_session_token == "test-token"
+        assert hook.endpoint_url == "https://athena.us-east-1.amazonaws.com"
+
+    
@mock.patch("airflow.providers.amazon.aws.hooks.athena_sql.pyathena.connect")
+    
@mock.patch("airflow.providers.amazon.aws.hooks.athena_sql.AthenaSQLHook.get_session")
+    def test_get_conn_with_hook_params(self, mock_get_session, mock_connect):
+        """Test that get_conn uses hook_params when provided."""
+        # Create hook with hook_params
+        hook = AthenaSQLHook(
+            athena_conn_id="test_conn",
+            s3_staging_dir="s3://test-bucket/staging/",
+            work_group="test-workgroup",
+        )
+
+        # Mock the connection
+        conn = Connection(
+            conn_type="athena",
+            schema="test_schema",
+            extra={"region_name": "us-east-1"},
+        )
+        hook.get_connection = mock.Mock(return_value=conn)
+
+        # Call get_conn
+        hook.get_conn()
+
+        # Verify that pyathena.connect was called with hook_params
+        mock_connect.assert_called_once()
+        call_args = mock_connect.call_args[1]  # Get keyword arguments
+        assert call_args["s3_staging_dir"] == "s3://test-bucket/staging/"
+        assert call_args["work_group"] == "test-workgroup"
+
+    def test_sql_value_check_operator_compatibility(self):
+        """Test that AthenaSQLHook works with SQLValueCheckOperator."""
+        from airflow.providers.common.sql.operators.sql import 
SQLValueCheckOperator
+
+        # Mock Athena connection with s3_staging_dir in extra
+        athena_conn = Connection(
+            conn_id="athena_conn",
+            conn_type="athena",
+            description="Connection to a Athena API",
+            schema="athena_sql_schema1",
+            extra={"s3_staging_dir": "s3://mybucket/athena/", "region_name": 
"eu-west-1"},
+        )
+
+        with patch(f"{BASEHOOK_PATCH_PATH}.get_connection", 
return_value=athena_conn):
+            # This should NOT raise TypeError: AwsGenericHook.__init__() got 
an unexpected keyword argument 's3_staging_dir'

Review Comment:
   This is a great comment 🙂 



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

Reply via email to