josh-fell commented on code in PR #36953:
URL: https://github.com/apache/airflow/pull/36953#discussion_r1470489935
########## airflow/providers/teradata/example_dags/example_teradata_operator.py: ##########
@@ -0,0 +1,132 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""
+Example Airflow DAG to show basic CRUD operation on teradata database using TeradataOperator
+
+This DAG assumes Airflow Connection with connection id `teradata_default` already exists in locally.
+It shows how to run queries as tasks in airflow dags using TeradataOperator..
+"""
+from __future__ import annotations
+
+from datetime import datetime
+
+import pytest
+
+from airflow import DAG
+from airflow.models.baseoperator import chain
+
+try:
+    from airflow.providers.teradata.operators.teradata import TeradataOperator
+except ImportError:
+    pytest.skip("Teradata provider apache-airflow-provider-teradata not available", allow_module_level=True)

Review Comment:
Just curious, why would `pytest`-related logic be in example DAGs/system tests?

########## docs/apache-airflow-providers-teradata/connections/teradata.rst: ##########
@@ -0,0 +1,82 @@
+.. _howto/connection:teradata:
+
+Teradata Connection
+======================
+The Teradata connection type enables integrations with Teradata.
+
+Configuring the Connection
+--------------------------
+Host (required)
+    The host to connect to.
+
+Database (optional)
+    Specify the name of the database to connect to.
+
+Login (required)
+    Specify the user name to connect.
+
+Password (required)
+    Specify the password to connect.
+
+Extra (optional)
+    Specify the extra parameters (as json dictionary) that can be used in Teradata
+    connection. The following parameters out of the standard python parameters
+    are supported:
+
+    * ``tmode`` - Specifies the transaction mode.Possible values are DEFAULT (the default), ANSI, or TERA
+    * ``sslmode`` - This option specifies the mode for connections to the database.
+      There are six modes:
+      ``disable``, ``allow``, ``prefer``, ``require``, ``verify-ca``, ``verify-full``.
+    * ``sslca`` - This parameter specifies the file name of a PEM file that contains
+      Certificate Authority (CA) certificates for use with sslmode values VERIFY-CA or VERIFY-FULL.
+    * ``sslcapath`` - This parameter specifies the TLS cipher for HTTPS/TLS connections.
+    * ``sslcipher`` - This parameter specifies the name of a file containing SSL
+      certificate authority (CA) certificate(s).
+    * ``sslcrc`` - This parameter controls TLS certificate revocation checking for
+      HTTPS/TLS connections when sslmode is VERIFY-FULL.
+    * ``sslprotocol`` - Specifies the TLS protocol for HTTPS/TLS connections.
+
+    More details on all Teradata parameters supported can be found in
+    `Teradata documentation <https://github.com/Teradata/python-driver?tab=readme-ov-file#connection-parameters>`_.
+
+    Example "extras" field:
+
+    .. code-block:: json
+
+        {
+          "tmode": "TERA",
+          "sslmode": "verify-ca",
+          "sslcert": "/tmp/client-cert.pem",
+          "sslca": "/tmp/server-ca.pem",
+          "sslkey": "/tmp/client-key.pem"
+        }

Review Comment:
This would be a _much_ better example as a placeholder for the `extra` field in the connection form, IMO, than what is proposed. Could you make that placeholder update, please?

########## airflow/providers/teradata/operators/teradata.py: ##########
@@ -0,0 +1,67 @@
+from __future__ import annotations
+
+from typing import Sequence
+
+from airflow.providers.common.sql.operators.sql import SQLExecuteQueryOperator
+from airflow.providers.teradata.hooks.teradata import TeradataHook
+
+
+class TeradataOperator(SQLExecuteQueryOperator):
+    """
+    General Teradata Operator to execute queries on Teradata Database.
+
+    Executes sql statements in the Teradata SQL Database using teradatasql jdbc driver
+    .. seealso::
+        For more information on how to use this operator, take a look at the guide:
+        :ref:`howto/operator:TeradataOperator`
+    :param sql: the SQL query to be executed as a single string, or a list of str (sql statements)
+    :param conn_id: reference to a predefined database
+    :param autocommit: if True, each command is automatically committed.(default value: False)
+    :param parameters: (optional) the parameters to render the SQL query with.
+    """
+
+    template_fields: Sequence[str] = (
+        "parameters",
+        "sql",
+    )
+    template_ext: Sequence[str] = (".sql",)
+    template_fields_renderers = {"sql": "sql"}
+    ui_color = "#e07c24"
+
+    def __init__(
+        self,
+        conn_id: str = TeradataHook.default_conn_name,
+        host: str | None = None,
+        schema: str | None = None,
+        login: str | None = None,
+        password: str | None = None,

Review Comment:
These are missing from the operator's docstring. Can you add them, please? That said, passing `password` as an argument seems like an antipattern. Would it be better to use a proper connection manager here?

########## airflow/providers/teradata/example_dags/example_teradata_operator.py: ##########
@@ -0,0 +1,132 @@
+#

Review Comment:
+1. Since this is all net-new, it would be better to make them system tests straight away.

########## airflow/providers/teradata/operators/teradata.py: ##########
@@ -0,0 +1,67 @@
+class TeradataOperator(SQLExecuteQueryOperator):
+    """
+    General Teradata Operator to execute queries on Teradata Database.
+
+    Executes sql statements in the Teradata SQL Database using teradatasql jdbc driver
+    .. seealso::
+        For more information on how to use this operator, take a look at the guide:
+        :ref:`howto/operator:TeradataOperator`

Review Comment:
```suggestion

    .. seealso::
        For more information on how to use this operator, take a look at the guide:
        :ref:`howto/operator:TeradataOperator`

```

########## airflow/providers/teradata/hooks/teradata.py: ##########
@@ -0,0 +1,210 @@
+"""A Airflow Hook for interacting with Teradata SQL Server."""
+from __future__ import annotations
+
+from typing import TYPE_CHECKING, Any, TypeVar
+
+import sqlalchemy
+import teradatasql
+from teradatasql import TeradataConnection
+
+from airflow.providers.common.sql.hooks.sql import DbApiHook
+
+T = TypeVar("T")
+if TYPE_CHECKING:
+    from airflow.models.connection import Connection
+
+
+class TeradataHook(DbApiHook):
+    """General hook for interacting with Teradata SQL Database.
+
+    This module contains basic APIs to connect to and interact with Teradata SQL Database. It uses teradatasql
+    client internally as a database driver for connecting to Teradata database. The config parameters like
+    Teradata DB Server URL, username, password and database name are fetched from the predefined connection
+    config connection_id. It raises an airflow error if the given connection id doesn't exist.
+
+    See :doc:` docs/apache-airflow-providers-teradata/connections/teradata.rst` for full documentation.
+
+    :param args: passed to DbApiHook
+    :param kwargs: passed to DbApiHook
+
+
+    Usage Help:
+
+    >>> tdh = TeradataHook()
+    >>> sql = "SELECT top 1 _airbyte_ab_id from airbyte_td._airbyte_raw_Sales;"
+    >>> tdh.get_records(sql)
+    [[61ad1d63-3efd-4da4-9904-a4489cc3a520]]
+
+    """
+
+    # Override to provide the connection name.
+    conn_name_attr = "teradata_conn_id"
+
+    # Override to have a default connection id for a particular dbHook
+    default_conn_name = "teradata_default"
+
+    # Override if this db supports autocommit.
+    supports_autocommit = True
+
+    # Override this for hook to have a custom name in the UI selection
+    conn_type = "teradata"
+
+    # Override hook name to give descriptive name for hook
+    hook_name = "Teradata"
+
+    # Override with the Teradata specific placeholder parameter string used for insert queries
+    placeholder: str = "?"
+
+    # Override SQL query to be used for testing database connection
+    _test_connection_sql = "select 1"
+
+    def __init__(
+        self,
+        *args,
+        database: str | None = None,

Review Comment:
Can you add this parameter to the docstring, please? Otherwise, it will not appear in the Python API docs.

########## tests/providers/teradata/hooks/test_teradata.py: ##########
@@ -0,0 +1,272 @@
+from __future__ import annotations
+
+import json
+from datetime import datetime
+from unittest import mock
+
+import pytest
+
+from airflow.models import Connection
+
+try:
+    from airflow.providers.teradata.hooks.teradata import TeradataHook
+except ImportError:
+    pytest.skip(
+        "Airflow Provider for Teradata not available, unable to import dependency "
+        "airflow.providers.teradata.hooks.teradata.TeradataHook",
+        allow_module_level=True,
+    )

Review Comment:
Is this relevant for the System Test Dashboard? The provider would be available in CI in this repo.

########## airflow/providers/teradata/transfers/teradata_to_teradata.py: ##########
@@ -0,0 +1,91 @@
+from __future__ import annotations
+
+from typing import TYPE_CHECKING, Sequence
+
+from airflow.models import BaseOperator
+from airflow.providers.teradata.hooks.teradata import TeradataHook
+
+if TYPE_CHECKING:
+    from airflow.utils.context import Context
+
+
+class TeradataToTeradataOperator(BaseOperator):
+    """
+    Moves data from Teradata source database to Teradata destination database.
+
+    .. seealso::
+        For more information on how to use this operator, take a look at the guide:
+        :ref:`howto/operator:TeradataToTeradataOperator`
+
+    :param dest_teradata_conn_id: destination Teradata connection.
+    :param destination_table: destination table to insert rows.
+    :param source_teradata_conn_id: :ref:`Source Teradata connection <howto/connection:Teradata>`.
+    :param sql: SQL query to execute against the source Teradata database
+    :param sql_params: Parameters to use in sql query.
+    :param rows_chunk: number of rows per chunk to commit.
+    """
+
+    template_fields: Sequence[str] = (
+        "sql",
+        "sql_params",
+    )
+    template_ext: Sequence[str] = (".sql",)
+    template_fields_renderers = {"sql": "sql", "sql_params": "py"}
+    ui_color = "#e07c24"
+
+    def __init__(
+        self,
+        *,
+        dest_teradata_conn_id: str,
+        destination_table: str,
+        source_teradata_conn_id: str,
+        sql: str,
+        sql_params: dict | None = None,
+        rows_chunk: int = 5000,
+        **kwargs,
+    ) -> None:
+        super().__init__(**kwargs)
+        if sql_params is None:
+            sql_params = {}
+        self.dest_teradata_conn_id = dest_teradata_conn_id
+        self.destination_table = destination_table
+        self.source_teradata_conn_id = source_teradata_conn_id
+        self.sql = sql
+        self.sql_params = sql_params
+        self.rows_chunk = rows_chunk
+
+    def _execute(self, src_hook, dest_hook, context) -> None:
+        with src_hook.get_conn() as src_conn:
+            cursor = src_conn.cursor()
+            cursor.execute(self.sql, self.sql_params)
+            target_fields = [field[0] for field in cursor.description]
+            rows_total = 0
+            for rows in iter(lambda: cursor.fetchmany(self.rows_chunk), []):
+                dest_hook.bulk_insert_rows(
+                    self.destination_table, rows, target_fields=target_fields, commit_every=self.rows_chunk
+                )
+                rows_total += len(rows)

Review Comment:
I might be missing it, but is `rows_total` used anywhere? Seems useful to have it in a logging statement.
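For the `rows_total` point, a minimal sketch of what the logging could look like, written as a standalone function so it runs without Airflow. `transfer_in_chunks` and `FakeCursor` are hypothetical names for illustration only; `fetchmany` stands in for the source cursor and `insert_chunk` for `dest_hook.bulk_insert_rows`. Inside the operator's `_execute`, the module-level `log.info(...)` calls would presumably become `self.log.info(...)`.

```python
import logging
from typing import Callable, Sequence

log = logging.getLogger(__name__)


def transfer_in_chunks(
    fetchmany: Callable[[int], list],
    insert_chunk: Callable[[list], None],
    rows_chunk: int = 5000,
    destination_table: str = "target_table",
) -> int:
    """Copy rows chunk by chunk and log the running total (hypothetical helper)."""
    rows_total = 0
    # Same sentinel-based loop as `_execute`: stop once fetchmany() returns [].
    for rows in iter(lambda: fetchmany(rows_chunk), []):
        insert_chunk(rows)
        rows_total += len(rows)
        log.info("Inserted %s rows into %s so far", rows_total, destination_table)
    # Final summary so the total shows up in the task log even for small transfers.
    log.info("Finished: %s rows inserted into %s", rows_total, destination_table)
    return rows_total


class FakeCursor:
    """Stand-in for a DB-API cursor; only implements fetchmany()."""

    def __init__(self, data: Sequence[tuple]):
        self._data = list(data)
        self._pos = 0

    def fetchmany(self, size: int) -> list:
        chunk = self._data[self._pos : self._pos + size]
        self._pos += size
        return chunk


if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)
    cursor = FakeCursor([(i, f"name-{i}") for i in range(12)])
    inserted: list = []
    total = transfer_in_chunks(cursor.fetchmany, inserted.extend, rows_chunk=5)
    print(total)  # 12
```

Logging once per chunk gives progress on long transfers; if that is too noisy, the final summary line alone would already make `rows_total` useful.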
########## pyproject.toml: ##########
@@ -533,6 +533,7 @@ alibaba = [
   "oss2>=2.14.0",
 ]
 amazon = [
+  "PyAthena>=3.0.10",

Review Comment:
Is this a related change?

########## airflow/providers/teradata/hooks/teradata.py: ##########
@@ -0,0 +1,210 @@
+    def __init__(
+        self,
+        *args,
+        database: str | None = None,
+        **kwargs,
+    ) -> None:
+        super().__init__(*args, schema=database, **kwargs)
+
+    def get_conn(self) -> TeradataConnection:
+        """Creates and returns a Teradata Connection object using teradatasql client.
+
+        Establishes connection to a Teradata SQL database using config corresponding to teradata_conn_id.
+
+        .. note:: By default it connects to the database via the teradatasql library.
+            But you can also choose the mysql-connector-python library which lets you connect through ssl
+            without any further ssl parameters required.
+
+        :return: a mysql connection object
+        """
+        teradata_conn_config: dict = self._get_conn_config_teradatasql()
+        teradata_conn = teradatasql.connect(**teradata_conn_config)
+        return teradata_conn
+
+    def bulk_insert_rows(
+        self,
+        table: str,
+        rows: list[tuple],
+        target_fields: list[str] | None = None,
+        commit_every: int = 5000,
+    ):
+        """A bulk insert of records for Teradata SQL Database.
+
+        This uses prepared statements via `executemany()`. For best performance,
+        pass in `rows` as an iterator.
+
+        :param table: target Teradata database table, use dot notation to target a
+            specific database
+        :param rows: the rows to insert into the table
+        :param target_fields: the names of the columns to fill in the table, default None.
+            If None, each rows should have some order as table columns name
+        :param commit_every: the maximum number of rows to insert in one transaction
+            Default 5000. Set greater than 0. Set 1 to insert each row in each transaction
+        """
+        if not rows:
+            raise ValueError("parameter rows could not be None or empty iterable")
+        conn = self.get_conn()
+        if self.supports_autocommit:
+            self.set_autocommit(conn, False)
+        cursor = conn.cursor()
+        cursor.fast_executemany = True
+        values_base = target_fields if target_fields else rows[0]
+        prepared_stm = "INSERT INTO {tablename} {columns} VALUES ({values})".format(
+            tablename=table,
+            columns="({})".format(", ".join(target_fields)) if target_fields else "",
+            values=", ".join("?" for i in range(1, len(values_base) + 1)),
+        )
+        row_count = 0
+        # Chunk the rows
+        row_chunk = []
+        for row in rows:
+            row_chunk.append(row)
+            row_count += 1
+            if row_count % commit_every == 0:
+                cursor.executemany(prepared_stm, row_chunk)
+                conn.commit()  # type: ignore[attr-defined]
+                # Empty chunk
+                row_chunk = []
+        # Commit the leftover chunk
+        if len(row_chunk) > 0:
+            cursor.executemany(prepared_stm, row_chunk)
+            conn.commit()  # type: ignore[attr-defined]
+        self.log.info("[%s] inserted %s rows", table, row_count)
+        cursor.close()
+        conn.close()  # type: ignore[attr-defined]
+
+    def _get_conn_config_teradatasql(self) -> dict[str, Any]:
+        """Returns set of config params required for connecting to Teradata DB using teradatasql client."""
+        conn: Connection = self.get_connection(getattr(self, self.conn_name_attr))
+        conn_config = {
+            "host": conn.host or "localhost",
+            "dbs_port": conn.port or "1025",
+            "database": conn.schema or "",
+            "user": conn.login or "dbc",
+            "password": conn.password or "dbc",
+        }
+
+        if conn.extra_dejson.get("tmode", False):
+            conn_config["tmode"] = conn.extra_dejson["tmode"]
+
+        # Handling SSL connection parameters
+
+        if conn.extra_dejson.get("sslmode", False):
+            conn_config["sslmode"] = conn.extra_dejson["sslmode"]
+            if "verify" in conn_config["sslmode"]:
+                if conn.extra_dejson.get("sslca", False):
+                    conn_config["sslca"] = conn.extra_dejson["sslca"]
+                if conn.extra_dejson.get("sslcapath", False):
+                    conn_config["sslcapath"] = conn.extra_dejson["sslcapath"]
+        if conn.extra_dejson.get("sslcipher", False):
+            conn_config["sslcipher"] = conn.extra_dejson["sslcipher"]
+        if conn.extra_dejson.get("sslcrc", False):
+            conn_config["sslcrc"] = conn.extra_dejson["sslcrc"]
+        if conn.extra_dejson.get("sslprotocol", False):
+            conn_config["sslprotocol"] = conn.extra_dejson["sslprotocol"]
+
+        return conn_config
+
+    def get_sqlalchemy_engine(self, engine_kwargs=None):
+        """Returns a connection object using sqlalchemy."""
+        conn: Connection = self.get_connection(getattr(self, self.conn_name_attr))
+        link = f"teradatasql://{conn.login}:{conn.password}@{conn.host}"
+        connection = sqlalchemy.create_engine(link)
+        return connection
+
+    @staticmethod
+    def get_ui_field_behaviour() -> dict:
+        """Returns custom field behaviour."""
+        import json
+
+        return {
+            "hidden_fields": ["port"],
+            "relabeling": {
+                "host": "Database Server URL",
+                "schema": "Database Name",
+                "login": "Username",
+                "password": "Password",

Review Comment:
```suggestion
```
In the Airflow UI's Connection form, the `password` field is labeled as "Password" by default.

########## tests/providers/teradata/operators/test_teradata.py: ##########
@@ -0,0 +1,128 @@
+from __future__ import annotations
+
+from unittest import mock
+from unittest.mock import MagicMock, Mock
+
+import pytest
+
+from airflow.models.dag import DAG
+from airflow.providers.common.sql.hooks.sql import fetch_all_handler
+from airflow.utils import timezone
+
+try:
+    from airflow.providers.teradata.hooks.teradata import TeradataHook
+    from airflow.providers.teradata.operators.teradata import TeradataOperator
+except ImportError:
+    pytest.skip("Teradata not available", allow_module_level=True)
+
+from airflow.exceptions import AirflowException
+
+DEFAULT_DATE = timezone.datetime(2015, 1, 1)
+TEST_DAG_ID = "unit_test_dag"
+
+
+class TestTeradataOperator:
+    def setup_method(self):
+        args = {"owner": "airflow", "start_date": DEFAULT_DATE}
+        dag = DAG(TEST_DAG_ID, default_args=args)
+        self.dag = dag
+
+    @mock.patch("airflow.providers.common.sql.operators.sql.SQLExecuteQueryOperator.get_db_hook")
+    def test_get_hook_from_conn(self, mock_get_db_hook):
+        """
+        :class:`~.MsSqlOperator` should use the hook returned by :meth:`airflow.models.Connection.get_hook`
+        if one is returned.
+
+        This behavior is necessary in order to support usage of :class:`~.OdbcHook` with this operator.
+
+        Specifically we verify here that :meth:`~.MsSqlOperator.get_hook` returns the hook returned from a
+        call of ``get_hook`` on the object returned from :meth:`~.BaseHook.get_connection`.

Review Comment:
Can you update this docstring as needed? The references to other providers don't seem relevant.

########## airflow/providers/teradata/hooks/teradata.py: ##########
@@ -0,0 +1,210 @@
+    @staticmethod
+    def get_ui_field_behaviour() -> dict:
+        """Returns custom field behaviour."""
+        import json
+
+        return {
+            "hidden_fields": ["port"],
+            "relabeling": {
+                "host": "Database Server URL",
+                "schema": "Database Name",
+                "login": "Username",
+                "password": "Password",
+            },
+            "placeholders": {
+                "extra": json.dumps({"example_parameter": "parameter"}, indent=4),

Review Comment:
There is a nice example in the connection doc that would be a great placeholder here instead IMO.

########## docs/apache-airflow-providers-teradata/index.rst: ##########
@@ -0,0 +1,134 @@
+``apache-airflow-providers-teradata``
+=====================================
+
+
+.. toctree::
+    :hidden:
+    :maxdepth: 1
+    :caption: Basics
+
+    Home <self>
+    Changelog <changelog>
+    Security <security>
+
+.. toctree::
+    :hidden:
+    :maxdepth: 1
+    :caption: Guides
+
+    Connection types <connections/teradata>
+    Operators <operators/index>
+
+.. toctree::
+    :hidden:
+    :maxdepth: 1
+    :caption: References
+
+    Python API <_api/airflow/providers/teradata/index>
+
+.. toctree::
+    :hidden:
+    :maxdepth: 1
+    :caption: System tests
+
+    System Tests <_api/tests/system/providers/teradata/index>
+    System Tests Dashboard <https://teradata.github.io/airflow/index.html>
+
+.. toctree::
+    :hidden:
+    :maxdepth: 1
+    :caption: Resources
+
+    Example DAGs <https://github.com/apache/airflow/tree/providers-teradata/|version|/airflow/providers/teradata/example_dags>
+    PyPI Repository <https://pypi.org/project/apache-airflow-providers-teradata/>
+    Installing from sources <installing-providers-from-sources>
+
+.. THE REMAINDER OF THE FILE IS AUTOMATICALLY GENERATED. IT WILL BE OVERWRITTEN AT RELEASE TIME!
+
+
+.. toctree::
+    :hidden:
+    :maxdepth: 1
+    :caption: Commits
+
+    Detailed list of commits <commits>
+
+
+Package apache-airflow-providers-teradata
+------------------------------------------------------
+
+`Teradata <https://www.teradata.com/>`__
+
+
+Release: 1.0.0
+
+Provider package
+----------------
+
+This is a provider package for ``teradata`` provider. All classes for this provider package
+are in ``airflow.providers.teradata`` python package.
+ +Installation +------------ + +You can install this package on top of an existing Airflow 2 installation (see ``Requirements`` below) +for the minimum Airflow version supported) via +``pip install apache-airflow-providers-teradata`` + +Requirements +------------ + +The minimum Apache Airflow version supported by this provider package is ``2.6.0``. + +======================================= ================== +PIP package Version required +======================================= ================== +``apache-airflow`` ``>=2.6.0`` +``apache-airflow-providers-common-sql`` ``>=1.3.1`` Review Comment: The `teradata...` packages are missing here. Can you align this with the `provider.yaml` dependencies please?
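For reference on the `bulk_insert_rows` hunk: below is a minimal sketch of the chunked-commit behaviour its docstring describes. The docstring recommends passing `rows` as an iterator, but `rows[0]` in the hunk only works on sequences. This sketch therefore pulls chunks with `itertools.islice`, so lists and generators both work. It uses the standard-library `sqlite3` module purely as a stand-in for `teradatasql`, on the assumption that any DB-API driver using `?` (qmark) placeholders behaves the same way here. The function is illustrative only and is not the provider's API.

```python
import sqlite3
from itertools import islice
from typing import Iterable, Optional, Sequence


def bulk_insert_rows(
    conn,
    table: str,
    rows: Iterable[tuple],
    target_fields: Optional[Sequence[str]] = None,
    commit_every: int = 5000,
) -> int:
    """Insert ``rows`` into ``table`` in chunks, committing once per chunk.

    ``conn`` is assumed to be a DB-API connection whose driver uses ``?`` placeholders.
    Returns the number of rows inserted.
    """
    if rows is None:
        raise ValueError("parameter rows could not be None or empty iterable")
    if commit_every < 1:
        raise ValueError("commit_every must be greater than 0")
    row_iter = iter(rows)
    cursor = conn.cursor()
    prepared_stm = None
    row_count = 0
    try:
        while True:
            # Pull at most commit_every rows; works for lists and generators alike.
            chunk = list(islice(row_iter, commit_every))
            if not chunk:
                break
            if prepared_stm is None:
                # Column count comes from target_fields, else from the first row.
                width = len(target_fields) if target_fields else len(chunk[0])
                columns = f" ({', '.join(target_fields)})" if target_fields else ""
                values = ", ".join("?" for _ in range(width))
                prepared_stm = f"INSERT INTO {table}{columns} VALUES ({values})"
            cursor.executemany(prepared_stm, chunk)
            conn.commit()  # one transaction per chunk
            row_count += len(chunk)
    finally:
        cursor.close()
    if row_count == 0:
        raise ValueError("parameter rows could not be None or empty iterable")
    return row_count
```

With `commit_every=5` and 12 rows, this sketch runs three `executemany`/`commit` rounds (5, 5 and 2 rows). That matches the "maximum number of rows to insert in one transaction" wording in the docstring.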
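On the extras placeholder suggestion: the example from the connection doc is not quoted in this thread. As a hypothetical sketch, the placeholder could be built from keys that `_get_conn_config_teradatasql` actually reads from `conn.extra_dejson` (`tmode`, `sslmode`, `sslca`, ...) instead of `{"example_parameter": "parameter"}`. `EXAMPLE_EXTRA` and its values are illustrative assumptions and should be checked against the teradatasql driver documentation. The sketch uses a module-level function rather than the hook's `@staticmethod` so it runs on its own.

```python
import json

# Hypothetical placeholder: keys are ones _get_conn_config_teradatasql reads
# from conn.extra_dejson; the values are illustrative only.
EXAMPLE_EXTRA = {
    "tmode": "TERA",
    # Lowercase so the hook's case-sensitive `"verify" in sslmode` check matches.
    "sslmode": "verify-ca",
    "sslca": "/path/to/ca.pem",
}


def get_ui_field_behaviour() -> dict:
    """Return custom field behaviour with a more realistic extras placeholder."""
    return {
        "hidden_fields": ["port"],
        "relabeling": {
            "host": "Database Server URL",
            "schema": "Database Name",
            "login": "Username",
            "password": "Password",
        },
        "placeholders": {
            "extra": json.dumps(EXAMPLE_EXTRA, indent=4),
        },
    }
```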