justinpakzad commented on code in PR #67045:
URL: https://github.com/apache/airflow/pull/67045#discussion_r3253624632


##########
providers/postgres/src/airflow/providers/postgres/hooks/postgres.py:
##########
@@ -720,3 +726,137 @@ def insert_rows(
 
         self.log.info("Done loading. Loaded a total of %s rows into %s", nb_rows, table)
         return None
+
+    def _generate_upsert_sql(
+        self,
+        table: str,
+        values: tuple[Any, ...] | list[Any],
+        target_fields: list[str],
+        conflict_fields: list[str],
+        update_fields: list[str] | None = None,
+        **kwargs,

Review Comment:
   I don't think we need `**kwargs` here, since nothing consumes them.



##########
providers/postgres/src/airflow/providers/postgres/hooks/postgres.py:
##########
@@ -720,3 +726,137 @@ def insert_rows(
 
         self.log.info("Done loading. Loaded a total of %s rows into %s", nb_rows, table)
         return None
+
+    def _generate_upsert_sql(
+        self,
+        table: str,
+        values: tuple[Any, ...] | list[Any],

Review Comment:
   Do we need to pass `values` in here? It's only used to produce the right
   number of placeholders, and I think that can be done with
   `len(target_fields)`.
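
A minimal sketch of what this suggestion could look like, written as a standalone function rather than a hook method so it can run on its own. The name `generate_upsert_sql` and the sample table and column names are illustrative assumptions, not the PR's final code; the SQL shape follows the diff above.

```python
from __future__ import annotations


def generate_upsert_sql(
    table: str,
    target_fields: list[str],
    conflict_fields: list[str],
    update_fields: list[str] | None = None,
) -> str:
    """Build an INSERT ... ON CONFLICT statement without needing row values."""
    # One placeholder per target column, so no row has to be passed in.
    placeholders = ", ".join(["%s"] * len(target_fields))
    columns = ", ".join(target_fields)
    conflict_clause = ", ".join(conflict_fields)

    sql = f"INSERT INTO {table} ({columns}) VALUES ({placeholders}) "

    if update_fields:
        update_clause = ", ".join(f"{field} = EXCLUDED.{field}" for field in update_fields)
        sql += f"ON CONFLICT ({conflict_clause}) DO UPDATE SET {update_clause}"
    else:
        sql += f"ON CONFLICT ({conflict_clause}) DO NOTHING"

    return sql


# Hypothetical table and columns, for illustration only.
print(generate_upsert_sql("users", ["id", "name"], ["id"], ["name"]))
```

This assumes every row has exactly one value per entry in `target_fields`, which the `INSERT` column list already requires.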



##########
providers/postgres/src/airflow/providers/postgres/hooks/postgres.py:
##########
@@ -720,3 +726,137 @@ def insert_rows(
 
         self.log.info("Done loading. Loaded a total of %s rows into %s", nb_rows, table)
         return None
+
+    def _generate_upsert_sql(
+        self,
+        table: str,
+        values: tuple[Any, ...] | list[Any],
+        target_fields: list[str],
+        conflict_fields: list[str],
+        update_fields: list[str] | None = None,
+        **kwargs,
+    ) -> str:
+        """
+        Generate PostgreSQL UPSERT SQL using ``ON CONFLICT``.
+
+        :param table: Name of target table.
+        :param values: Row values used for placeholder generation.
+        :param target_fields: Non-empty column names used in the ``INSERT`` statement.
+        :param conflict_fields: Non-empty column names used in the ``ON CONFLICT`` clause.
+        :param update_fields: Columns to update on conflict. If omitted or empty,
+            ``DO NOTHING`` is used.
+        """
+        placeholders = ", ".join(["%s"] * len(values))
+
+        columns = ", ".join(target_fields)
+
+        conflict_clause = ", ".join(conflict_fields)
+
+        sql = f"INSERT INTO {table} ({columns}) VALUES ({placeholders}) "
+
+        if update_fields:
+            update_clause = ", ".join(f"{field} = EXCLUDED.{field}" for field in update_fields)
+
+            sql += f"ON CONFLICT ({conflict_clause}) DO UPDATE SET {update_clause}"
+        else:
+            sql += f"ON CONFLICT ({conflict_clause}) DO NOTHING"
+
+        return sql.strip()
+
+    def upsert_rows(
+        self,
+        table: str,
+        rows: Iterable[tuple[Any, ...]],
+        target_fields: list[str],
+        conflict_fields: list[str],
+        update_fields: list[str] | None = None,
+        commit_every: int = 1000,
+        *,
+        fast_executemany: bool = False,
+        autocommit: bool = False,
+    ) -> None:
+        """
+        Upsert rows into a PostgreSQL table using ``ON CONFLICT``.
+
+        :param table: Name of the target table.
+        :param rows: Rows to upsert.
+        :param target_fields: Non-empty column names used in the ``INSERT`` statement.
+        :param conflict_fields: Non-empty column names used in the ``ON CONFLICT`` clause.
+        :param update_fields: Columns updated on conflict. If omitted or empty,
+            conflicting rows are ignored via ``DO NOTHING``.
+        :param commit_every: Maximum number of rows per transaction. Default value is 1000.
+        :param fast_executemany: Use ``psycopg2.extras.execute_batch`` for improved
+            batch performance.
+        :param autocommit: Connection autocommit setting.
+        """
+        if not target_fields or any(not field for field in target_fields):
+            raise ValueError("target_fields must be provided and must not be empty.")
+
+        if not conflict_fields or any(not field for field in conflict_fields):
+            raise ValueError("conflict_fields must be provided and must not be empty.")
+
+        rows = iter(rows)
+
+        nb_rows = 0
+        sql = None
+
+        with self._create_autocommit_connection(autocommit) as conn:
+            conn.commit()
+
+            with closing(conn.cursor()) as cur:
+                for chunked_rows in chunked(rows, commit_every):
+                    values = [self._serialize_cells(row, conn) for row in chunked_rows]
+
+                    if not values:
+                        continue
+
+                    sql = self._generate_upsert_sql(
+                        table=table,
+                        values=values[0],
+                        target_fields=target_fields,
+                        conflict_fields=conflict_fields,
+                        update_fields=update_fields,
+                    )

Review Comment:
   Related to my comment above: if we don't need to pass `values` (it's only
   used for the number of placeholders), then we also don't need to regenerate
   the same SQL string on every chunk. It can be built once, outside the loop.
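
A rough sketch of the hoisting, with the database calls replaced by a plain callable so it runs without a connection. `upsert_in_chunks`, `execute_many`, and the local `chunked` helper are hypothetical stand-ins for the hook's own loop, cursor call, and chunking utility.

```python
from __future__ import annotations

from itertools import islice
from typing import Any, Callable, Iterable, Iterator


def chunked(rows: Iterable[Any], size: int) -> Iterator[list[Any]]:
    """Yield lists of at most ``size`` items (stand-in for the hook's chunking helper)."""
    it = iter(rows)
    while chunk := list(islice(it, size)):
        yield chunk


def upsert_in_chunks(
    rows: Iterable[tuple[Any, ...]],
    sql: str,
    execute_many: Callable[[str, list[tuple[Any, ...]]], None],
    commit_every: int = 1000,
) -> int:
    """Run the same prebuilt statement for every chunk and return the row count."""
    nb_rows = 0
    for chunk in chunked(rows, commit_every):
        # The statement was built once by the caller; only the parameters change.
        execute_many(sql, chunk)
        nb_rows += len(chunk)
    return nb_rows


# Usage: the SQL string is created once, before the loop starts.
sql = "INSERT INTO t (id) VALUES (%s) ON CONFLICT (id) DO NOTHING"
calls: list[tuple[str, list[tuple[Any, ...]]]] = []
total = upsert_in_chunks([(i,) for i in range(5)], sql, lambda s, v: calls.append((s, v)), 2)
```

Because the statement depends only on the field lists, every chunk reuses the same string, and the returned count shows whether anything was written.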



##########
providers/postgres/src/airflow/providers/postgres/hooks/postgres.py:
##########
@@ -720,3 +726,137 @@ def insert_rows(
 
         self.log.info("Done loading. Loaded a total of %s rows into %s", nb_rows, table)
         return None
+
+    def _generate_upsert_sql(
+        self,
+        table: str,
+        values: tuple[Any, ...] | list[Any],
+        target_fields: list[str],
+        conflict_fields: list[str],
+        update_fields: list[str] | None = None,
+        **kwargs,
+    ) -> str:
+        """
+        Generate PostgreSQL UPSERT SQL using ``ON CONFLICT``.
+
+        :param table: Name of target table.
+        :param values: Row values used for placeholder generation.
+        :param target_fields: Non-empty column names used in the ``INSERT`` statement.
+        :param conflict_fields: Non-empty column names used in the ``ON CONFLICT`` clause.
+        :param update_fields: Columns to update on conflict. If omitted or empty,
+            ``DO NOTHING`` is used.
+        """
+        placeholders = ", ".join(["%s"] * len(values))
+
+        columns = ", ".join(target_fields)
+
+        conflict_clause = ", ".join(conflict_fields)
+
+        sql = f"INSERT INTO {table} ({columns}) VALUES ({placeholders}) "
+
+        if update_fields:
+            update_clause = ", ".join(f"{field} = EXCLUDED.{field}" for field in update_fields)
+
+            sql += f"ON CONFLICT ({conflict_clause}) DO UPDATE SET {update_clause}"
+        else:
+            sql += f"ON CONFLICT ({conflict_clause}) DO NOTHING"
+
+        return sql.strip()
+
+    def upsert_rows(
+        self,
+        table: str,
+        rows: Iterable[tuple[Any, ...]],
+        target_fields: list[str],
+        conflict_fields: list[str],
+        update_fields: list[str] | None = None,
+        commit_every: int = 1000,
+        *,
+        fast_executemany: bool = False,
+        autocommit: bool = False,
+    ) -> None:
+        """
+        Upsert rows into a PostgreSQL table using ``ON CONFLICT``.
+
+        :param table: Name of the target table.
+        :param rows: Rows to upsert.
+        :param target_fields: Non-empty column names used in the ``INSERT`` statement.
+        :param conflict_fields: Non-empty column names used in the ``ON CONFLICT`` clause.
+        :param update_fields: Columns updated on conflict. If omitted or empty,
+            conflicting rows are ignored via ``DO NOTHING``.
+        :param commit_every: Maximum number of rows per transaction. Default value is 1000.
+        :param fast_executemany: Use ``psycopg2.extras.execute_batch`` for improved
+            batch performance.
+        :param autocommit: Connection autocommit setting.
+        """
+        if not target_fields or any(not field for field in target_fields):
+            raise ValueError("target_fields must be provided and must not be empty.")
+
+        if not conflict_fields or any(not field for field in conflict_fields):
+            raise ValueError("conflict_fields must be provided and must not be empty.")
+
+        rows = iter(rows)
+
+        nb_rows = 0
+        sql = None
+
+        with self._create_autocommit_connection(autocommit) as conn:
+            conn.commit()
+
+            with closing(conn.cursor()) as cur:
+                for chunked_rows in chunked(rows, commit_every):
+                    values = [self._serialize_cells(row, conn) for row in chunked_rows]
+
+                    if not values:
+                        continue
+
+                    sql = self._generate_upsert_sql(
+                        table=table,
+                        values=values[0],
+                        target_fields=target_fields,
+                        conflict_fields=conflict_fields,
+                        update_fields=update_fields,
+                    )
+
+                    self.log.debug("Generated sql: %s", sql)
+
+                    try:
+                        if fast_executemany:
+                            # execute_batch reduces round trips by batching parameter sets.
+                            execute_batch(
+                                cur,
+                                sql,
+                                values,
+                                page_size=commit_every,
+                            )
+                        else:
+                            cur.executemany(sql, values)
+
+                    except Exception:
+                        self.log.error("Generated sql: %s", sql)
+                        self.log.error("Parameters: %s", values)
+                        raise
+
+                    conn.commit()
+
+                    nb_rows += len(values)
+
+                    self.log.info(
+                        "Upserted %s rows into %s so far",
+                        nb_rows,
+                        table,
+                    )
+
+        if sql:

Review Comment:
   Also related to my comment above: if we construct the query once outside the
   loop, this check needs to be updated, because `sql` would then always be
   set. It could be `if nb_rows > 0`.
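
To illustrate why the guard would change: once the statement is built before the loop, it is a non-empty string even when no rows arrive, so only the row count can tell the two cases apart. The helper name `completion_message` and the message wording below are assumptions for illustration; the code guarded by `if sql:` is not shown in the diff.

```python
from __future__ import annotations


def completion_message(nb_rows: int, table: str) -> str | None:
    """Return a completion message only when at least one row was processed."""
    # A prebuilt SQL string is always truthy, so the row count is the reliable signal.
    if nb_rows > 0:
        return f"Done loading. Upserted a total of {nb_rows} rows into {table}"
    return None
```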



##########
providers/postgres/src/airflow/providers/postgres/hooks/postgres.py:
##########
@@ -720,3 +726,137 @@ def insert_rows(
 
         self.log.info("Done loading. Loaded a total of %s rows into %s", nb_rows, table)
         return None
+
+    def _generate_upsert_sql(
+        self,
+        table: str,
+        values: tuple[Any, ...] | list[Any],
+        target_fields: list[str],
+        conflict_fields: list[str],
+        update_fields: list[str] | None = None,
+        **kwargs,
+    ) -> str:
+        """
+        Generate PostgreSQL UPSERT SQL using ``ON CONFLICT``.
+
+        :param table: Name of target table.
+        :param values: Row values used for placeholder generation.
+        :param target_fields: Non-empty column names used in the ``INSERT`` statement.
+        :param conflict_fields: Non-empty column names used in the ``ON CONFLICT`` clause.
+        :param update_fields: Columns to update on conflict. If omitted or empty,
+            ``DO NOTHING`` is used.
+        """
+        placeholders = ", ".join(["%s"] * len(values))
+
+        columns = ", ".join(target_fields)
+
+        conflict_clause = ", ".join(conflict_fields)
+
+        sql = f"INSERT INTO {table} ({columns}) VALUES ({placeholders}) "
+
+        if update_fields:
+            update_clause = ", ".join(f"{field} = EXCLUDED.{field}" for field in update_fields)
+
+            sql += f"ON CONFLICT ({conflict_clause}) DO UPDATE SET {update_clause}"
+        else:
+            sql += f"ON CONFLICT ({conflict_clause}) DO NOTHING"
+
+        return sql.strip()
+
+    def upsert_rows(
+        self,
+        table: str,
+        rows: Iterable[tuple[Any, ...]],
+        target_fields: list[str],
+        conflict_fields: list[str],
+        update_fields: list[str] | None = None,
+        commit_every: int = 1000,
+        *,
+        fast_executemany: bool = False,
+        autocommit: bool = False,
+    ) -> None:
+        """
+        Upsert rows into a PostgreSQL table using ``ON CONFLICT``.
+
+        :param table: Name of the target table.
+        :param rows: Rows to upsert.
+        :param target_fields: Non-empty column names used in the ``INSERT`` statement.
+        :param conflict_fields: Non-empty column names used in the ``ON CONFLICT`` clause.
+        :param update_fields: Columns updated on conflict. If omitted or empty,
+            conflicting rows are ignored via ``DO NOTHING``.
+        :param commit_every: Maximum number of rows per transaction. Default value is 1000.
+        :param fast_executemany: Use ``psycopg2.extras.execute_batch`` for improved
+            batch performance.
+        :param autocommit: Connection autocommit setting.
+        """
+        if not target_fields or any(not field for field in target_fields):
+            raise ValueError("target_fields must be provided and must not be empty.")
+
+        if not conflict_fields or any(not field for field in conflict_fields):
+            raise ValueError("conflict_fields must be provided and must not be empty.")
+
+        rows = iter(rows)
+
+        nb_rows = 0
+        sql = None
+
+        with self._create_autocommit_connection(autocommit) as conn:
+            conn.commit()
+
+            with closing(conn.cursor()) as cur:
+                for chunked_rows in chunked(rows, commit_every):
+                    values = [self._serialize_cells(row, conn) for row in chunked_rows]
+
+                    if not values:
+                        continue
+
+                    sql = self._generate_upsert_sql(
+                        table=table,
+                        values=values[0],
+                        target_fields=target_fields,
+                        conflict_fields=conflict_fields,
+                        update_fields=update_fields,
+                    )
+
+                    self.log.debug("Generated sql: %s", sql)
+
+                    try:
+                        if fast_executemany:
+                            # execute_batch reduces round trips by batching parameter sets.
+                            execute_batch(

Review Comment:
   Do we need a guard here, since psycopg3 does not support `execute_batch`?
   Maybe use the `USE_PSYCOPG3` constant that's used in other parts of the code,
   and either log a warning and fall back to `cur.executemany` or raise an
   error.
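
One possible shape for the warn-and-fall-back variant, sketched so it runs without either driver installed. `run_upsert_batch`, `use_psycopg3`, and `batch_fn` are hypothetical names; in the hook, the flag would presumably come from `USE_PSYCOPG3` and the batch function would be `psycopg2.extras.execute_batch`.

```python
from __future__ import annotations

import logging
from typing import Any, Callable, Sequence

log = logging.getLogger(__name__)


def run_upsert_batch(
    cur: Any,
    sql: str,
    values: Sequence[Sequence[Any]],
    *,
    fast_executemany: bool,
    use_psycopg3: bool,
    batch_fn: Callable[..., Any] | None = None,
    page_size: int = 1000,
) -> str:
    """Execute one chunk and return the name of the path that was taken."""
    if fast_executemany and use_psycopg3:
        # execute_batch is a psycopg2.extras helper, so warn and fall back.
        log.warning("fast_executemany is not supported with psycopg3; using executemany instead.")
        fast_executemany = False

    if fast_executemany and batch_fn is not None:
        # With psycopg2 this would be execute_batch(cur, sql, values, page_size=...).
        batch_fn(cur, sql, values, page_size=page_size)
        return "execute_batch"

    cur.executemany(sql, values)
    return "executemany"
```

Raising a `ValueError` in the first branch instead of warning would be the stricter alternative; which one fits better likely depends on how the hook handles unsupported options elsewhere.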


