tscottcoombes1 commented on code in PR #1534:
URL: https://github.com/apache/iceberg-python/pull/1534#discussion_r1943711166


##########
pyiceberg/table/__init__.py:
##########
@@ -1064,6 +1067,78 @@ def name_mapping(self) -> Optional[NameMapping]:
         """Return the table's field-id NameMapping."""
         return self.metadata.name_mapping()
 
+    @dataclass(frozen=True)
+    class UpsertResult:
+        """Summary the upsert operation"""
+        rows_updated: int = 0
+        rows_inserted: int = 0
+        info_msgs: Optional[str] = None
+        error_msgs: Optional[str] = None
+
+    def upsert(self, df: pa.Table, join_cols: list
+                   , when_matched_update_all: bool = True
+                   , when_not_matched_insert_all: bool = True
+                ) -> UpsertResult:
+        """
+        Shorthand API for performing an upsert to an iceberg table.
+        
+        Args:
+            df: The input dataframe to upsert with the table's data.
+            join_cols: The columns to join on.
+            when_matched_update_all: Bool indicating to update rows that are 
matched but require an update due to a value in a non-key column changing
+            when_not_matched_insert_all: Bool indicating new rows to be 
inserted that do not match any existing rows in the table
+
+        Returns: a UpsertResult class
+        """
+
+        from pyiceberg.table import upsert_util
+
+        if when_matched_update_all == False and when_not_matched_insert_all == 
False:
+            return {'rows_updated': 0, 'rows_inserted': 0, 'info_msgs': 'no 
upsert options selected...exiting'}
+            #return UpsertResult(info_msgs='no upsert options 
selected...exiting')
+
+        if upsert_util.dups_check_in_source(df, join_cols):
+
+            return {'error_msgs': 'Duplicate rows found in source dataset 
based on the key columns. No upsert executed'}
+
+        #get list of rows that exist so we don't have to load the entire 
target table
+        pred = upsert_util.get_filter_list(df, join_cols)
+        iceberg_table_trimmed = self.scan(row_filter=pred).to_arrow()
+
+        update_row_cnt = 0
+        insert_row_cnt = 0
+
+        try:
+
+            with self.transaction() as txn:
+            
+                if when_matched_update_all:
+
+                    update_recs = upsert_util.get_rows_to_update(df, 
iceberg_table_trimmed, join_cols)
+
+                    update_row_cnt = len(update_recs)
+
+                    overwrite_filter = 
upsert_util.get_filter_list(update_recs, join_cols)
+
+                    txn.overwrite(update_recs, 
overwrite_filter=overwrite_filter)    
+
+
+                if when_not_matched_insert_all:
+                    
+                    insert_recs = upsert_util.get_rows_to_insert(df, 
iceberg_table_trimmed, join_cols)
+
+                    insert_row_cnt = len(insert_recs)
+
+                    txn.append(insert_recs)
+
+            return {
+                "rows_updated": update_row_cnt,
+                "rows_inserted": insert_row_cnt
+            }
+
+        except Exception as e:

Review Comment:
   I'm not sure what this accomplishes. If there is an error in the transaction 
block, just let it propagate up the stack.



##########
tests/table/test_upsert.py:
##########
@@ -0,0 +1,327 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+from pyiceberg.catalog.sql import SqlCatalog
+from pyiceberg.catalog import Table as pyiceberg_table
+import os
+import shutil
+import pytest
+
+_TEST_NAMESPACE = "test_ns"
+
+try:
+    from datafusion import SessionContext
+except ModuleNotFoundError as e:
+    raise ModuleNotFoundError("For upsert testing, DataFusion needs to be 
installed") from e
+
+def get_test_warehouse_path():
+    curr_dir = os.path.dirname(os.path.abspath(__file__))
+    return f"{curr_dir}/warehouse"
+
+def purge_warehouse():
+    warehouse_path = get_test_warehouse_path()
+    if os.path.exists(warehouse_path):
+        shutil.rmtree(warehouse_path)
+
+def show_iceberg_table(table, ctx: SessionContext):
+    import pyarrow.dataset as ds
+    table_name = "target"
+    if ctx.table_exist(table_name):
+        ctx.deregister_table(table_name)
+    ctx.register_dataset(table_name, ds.dataset(table.scan().to_arrow()))
+    ctx.sql(f"SELECT * FROM {table_name} limit 5").show()
+
+def show_df(df, ctx: SessionContext):
+    import pyarrow.dataset as ds
+    ctx.register_dataset("df", ds.dataset(df))
+    ctx.sql("select * from df limit 10").show()
+
+def gen_source_dataset(start_row: int, end_row: int, composite_key: bool, 
add_dup: bool, ctx: SessionContext):
+
+    additional_columns = ", t.order_id + 1000 as order_line_id" if 
composite_key else ""
+
+    dup_row = f"""
+        UNION ALL
+        (
+        SELECT t.order_id {additional_columns}
+            , date '2021-01-01' as order_date, 'B' as order_type
+        from t
+        limit 1
+        )
+    """ if add_dup else ""
+
+
+    sql = f"""
+        with t as (SELECT unnest(range({start_row},{end_row+1})) as order_id)
+        SELECT t.order_id {additional_columns}
+            , date '2021-01-01' as order_date, 'B' as order_type
+        from t
+        {dup_row}
+    """
+
+    df = ctx.sql(sql).to_arrow_table()
+
+    return df
+
+def gen_target_iceberg_table_v2(start_row: int, end_row: int, composite_key: 
bool, ctx: SessionContext, catalog: SqlCatalog, namespace: str):
+
+    additional_columns = ", t.order_id + 1000 as order_line_id" if 
composite_key else ""
+
+    df = ctx.sql(f"""
+        with t as (SELECT unnest(range({start_row},{end_row+1})) as order_id)
+        SELECT t.order_id {additional_columns}
+            , date '2021-01-01' as order_date, 'A' as order_type
+        from t
+    """).to_arrow_table()
+
+    table = catalog.create_table(f"{_TEST_NAMESPACE}.target", df.schema)
+
+    table.append(df)
+
+    return table
+
+@pytest.fixture(scope="session")
+def catalog_conn():
+    warehouse_path = get_test_warehouse_path()
+    os.makedirs(warehouse_path, exist_ok=True)
+    print(warehouse_path)
+    catalog = SqlCatalog(
+        "default",
+        **{
+            "uri": f"sqlite:///:memory:",
+            "warehouse": f"file://{warehouse_path}",
+        },
+    )
+
+    catalog.create_namespace(namespace="test_ns")
+
+    yield catalog
+
+@pytest.mark.parametrize(
+    "join_cols, src_start_row, src_end_row, target_start_row, target_end_row, 
when_matched_update_all, when_not_matched_insert_all, expected_updated, 
expected_inserted",
+    [
+        (["order_id"], 1, 2, 2, 3, True, True, 1, 1), # single row
+        (["order_id"], 5001, 15000, 1, 10000, True, True, 5000, 5000), #10k 
rows
+        (["order_id"], 501, 1500, 1, 1000, True, False, 500, 0), # update only
+        (["order_id"], 501, 1500, 1, 1000, False, True, 0, 500), # insert only

Review Comment:
   `pytest.mark.parametrize` accepts an `ids=[...]` argument; use it to label the test cases instead of trailing comments.



##########
pyiceberg/table/upsert_util.py:
##########
@@ -0,0 +1,158 @@
+
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from pyarrow import Table as pyarrow_table
+import pyarrow as pa
+from pyarrow import compute as pc
+from pyiceberg import table as pyiceberg_table
+
+from pyiceberg.expressions import (
+    BooleanExpression,
+    And,
+    EqualTo,
+    Or,
+    In,
+)
+
+def get_filter_list(df: pyarrow_table, join_cols: list) -> BooleanExpression:
+
+    unique_keys = df.select(join_cols).group_by(join_cols).aggregate([])
+
+    pred = None
+
+    if len(join_cols) == 1:
+        pred = In(join_cols[0], unique_keys[0].to_pylist())
+    else:
+        pred = Or(*[
+            And(*[
+                EqualTo(col, row[col])
+                for col in join_cols
+            ])
+            for row in unique_keys.to_pylist()
+        ])
+
+    return pred
+
+def dups_check_in_source(df: pyarrow_table, join_cols: list) -> bool:
+    """
+    This function checks if there are duplicate rows in the source table based 
on the join columns.
+    It returns True if there are duplicate rows in the source table, otherwise 
it returns False.
+    """
+    # Check for duplicates in the source table
+    source_dup_count = len(
+        df.select(join_cols)
+            .group_by(join_cols)
+            .aggregate([([], "count_all")])
+            .filter(pc.field("count_all") > 1)
+    )
+    
+    return source_dup_count > 0
+
+def get_rows_to_update(source_table: pa.Table, target_table: pa.Table, 
join_cols: list) -> pa.Table:
+    
+    """
+        This function takes the source_table, trims it down to rows that match 
in both source and target.
+        It then does a scan for the non-key columns to see if any are 
mis-aligned before returning the final row set to update
+    """
+    
+    all_columns = set(source_table.column_names)
+    join_cols_set = set(join_cols)
+
+    non_key_cols = list(all_columns - join_cols_set)
+
+    
+    match_expr = None
+
+    for col in join_cols:
+        target_values = target_table.column(col).to_pylist()
+        expr = pc.field(col).isin(target_values)
+
+        if match_expr is None:
+            match_expr = expr
+        else:
+            match_expr = match_expr & expr
+
+    
+    matching_source_rows = source_table.filter(match_expr)
+
+    rows_to_update = []
+
+    for index in range(matching_source_rows.num_rows):
+        
+        source_row = matching_source_rows.slice(index, 1)
+
+        
+        target_filter = None
+
+        for col in join_cols:
+            target_value = source_row.column(col)[0].as_py()  
+            if target_filter is None:
+                target_filter = pc.field(col) == target_value
+            else:
+                target_filter = target_filter & (pc.field(col) == target_value)
+
+        matching_target_row = target_table.filter(target_filter)
+
+        if matching_target_row.num_rows > 0:
+            needs_update = False
+
+            for non_key_col in non_key_cols:
+                source_value = source_row.column(non_key_col)[0].as_py()
+                target_value = 
matching_target_row.column(non_key_col)[0].as_py()
+
+                if source_value != target_value:
+                    needs_update = True
+                    break 
+
+            if needs_update:
+                rows_to_update.append(source_row)
+
+    if rows_to_update:
+        rows_to_update_table = pa.concat_tables(rows_to_update)
+    else:
+        rows_to_update_table = pa.Table.from_arrays([], 
names=source_table.column_names)
+
+    common_columns = 
set(source_table.column_names).intersection(set(target_table.column_names))
+    rows_to_update_table = rows_to_update_table.select(list(common_columns))
+
+    return rows_to_update_table
+
+def get_rows_to_insert(source_table: pa.Table, target_table: pa.Table, 
join_cols: list) -> pa.Table:
+  
+    source_filter_expr = None
+
+    for col in join_cols:
+
+        target_values = target_table.column(col).to_pylist()
+        expr = pc.field(col).isin(target_values)
+
+        if source_filter_expr is None:
+            source_filter_expr = expr
+        else:
+            source_filter_expr = source_filter_expr & expr

Review Comment:
   ```suggestion
       masks = [
           pc.field(col).isin(target_table.column(col).to_pylist())
           for col in join_cols
       ]
       
       mask = [pa.compute.all(x) for x in zip(*masks)]
   ```



##########
tests/table/test_upsert.py:
##########
@@ -0,0 +1,327 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+from pyiceberg.catalog.sql import SqlCatalog
+from pyiceberg.catalog import Table as pyiceberg_table
+import os
+import shutil
+import pytest
+
+_TEST_NAMESPACE = "test_ns"
+
+try:
+    from datafusion import SessionContext
+except ModuleNotFoundError as e:
+    raise ModuleNotFoundError("For upsert testing, DataFusion needs to be 
installed") from e
+
+def get_test_warehouse_path():
+    curr_dir = os.path.dirname(os.path.abspath(__file__))
+    return f"{curr_dir}/warehouse"
+
+def purge_warehouse():
+    warehouse_path = get_test_warehouse_path()
+    if os.path.exists(warehouse_path):
+        shutil.rmtree(warehouse_path)
+
+def show_iceberg_table(table, ctx: SessionContext):
+    import pyarrow.dataset as ds
+    table_name = "target"
+    if ctx.table_exist(table_name):
+        ctx.deregister_table(table_name)
+    ctx.register_dataset(table_name, ds.dataset(table.scan().to_arrow()))
+    ctx.sql(f"SELECT * FROM {table_name} limit 5").show()
+
+def show_df(df, ctx: SessionContext):
+    import pyarrow.dataset as ds
+    ctx.register_dataset("df", ds.dataset(df))
+    ctx.sql("select * from df limit 10").show()
+
+def gen_source_dataset(start_row: int, end_row: int, composite_key: bool, 
add_dup: bool, ctx: SessionContext):
+
+    additional_columns = ", t.order_id + 1000 as order_line_id" if 
composite_key else ""
+
+    dup_row = f"""
+        UNION ALL
+        (
+        SELECT t.order_id {additional_columns}
+            , date '2021-01-01' as order_date, 'B' as order_type
+        from t
+        limit 1
+        )
+    """ if add_dup else ""
+
+
+    sql = f"""
+        with t as (SELECT unnest(range({start_row},{end_row+1})) as order_id)
+        SELECT t.order_id {additional_columns}
+            , date '2021-01-01' as order_date, 'B' as order_type
+        from t
+        {dup_row}
+    """
+
+    df = ctx.sql(sql).to_arrow_table()
+
+    return df
+
+def gen_target_iceberg_table_v2(start_row: int, end_row: int, composite_key: 
bool, ctx: SessionContext, catalog: SqlCatalog, namespace: str):
+
+    additional_columns = ", t.order_id + 1000 as order_line_id" if 
composite_key else ""
+
+    df = ctx.sql(f"""
+        with t as (SELECT unnest(range({start_row},{end_row+1})) as order_id)
+        SELECT t.order_id {additional_columns}
+            , date '2021-01-01' as order_date, 'A' as order_type
+        from t
+    """).to_arrow_table()
+
+    table = catalog.create_table(f"{_TEST_NAMESPACE}.target", df.schema)
+
+    table.append(df)
+
+    return table
+
+@pytest.fixture(scope="session")
+def catalog_conn():
+    warehouse_path = get_test_warehouse_path()
+    os.makedirs(warehouse_path, exist_ok=True)
+    print(warehouse_path)
+    catalog = SqlCatalog(
+        "default",
+        **{
+            "uri": f"sqlite:///:memory:",
+            "warehouse": f"file://{warehouse_path}",
+        },
+    )
+
+    catalog.create_namespace(namespace="test_ns")
+
+    yield catalog
+
+@pytest.mark.parametrize(
+    "join_cols, src_start_row, src_end_row, target_start_row, target_end_row, 
when_matched_update_all, when_not_matched_insert_all, expected_updated, 
expected_inserted",
+    [
+        (["order_id"], 1, 2, 2, 3, True, True, 1, 1), # single row
+        (["order_id"], 5001, 15000, 1, 10000, True, True, 5000, 5000), #10k 
rows
+        (["order_id"], 501, 1500, 1, 1000, True, False, 500, 0), # update only
+        (["order_id"], 501, 1500, 1, 1000, False, True, 0, 500), # insert only
+    ]
+)
+def test_merge_rows(catalog_conn, join_cols, src_start_row, src_end_row, 
target_start_row, target_end_row
+                    , when_matched_update_all, when_not_matched_insert_all, 
expected_updated, expected_inserted):
+
+    ctx = SessionContext()
+
+    catalog = catalog_conn
+
+    source_df = gen_source_dataset(src_start_row, src_end_row, False, False, 
ctx)
+    ice_table = gen_target_iceberg_table_v2(target_start_row, target_end_row, 
False, ctx, catalog, _TEST_NAMESPACE)
+    res = ice_table.upsert(df=source_df, join_cols=join_cols, 
when_matched_update_all=when_matched_update_all, 
when_not_matched_insert_all=when_not_matched_insert_all)
+
+    assert res['rows_updated'] == expected_updated, f"rows updated should be 
{expected_updated}, but got {res['rows_updated']}"
+    assert res['rows_inserted'] == expected_inserted, f"rows inserted should 
be {expected_inserted}, but got {res['rows_inserted']}"
+
+    catalog.drop_table(f"{_TEST_NAMESPACE}.target")
+
+@pytest.fixture(scope="session", autouse=True)
+def cleanup():
+    yield  # This allows other tests to run first
+    purge_warehouse()  
+
+def test_merge_scenario_skip_upd_row(catalog_conn):
+
+    """
+        tests a single insert and update; skips a row that does not need to be 
updated
+    """
+
+
+    ctx = SessionContext()
+
+    df = ctx.sql(f"""
+        select 1 as order_id, date '2021-01-01' as order_date, 'A' as 
order_type
+        union all
+        select 2 as order_id, date '2021-01-01' as order_date, 'A' as 
order_type
+    """).to_arrow_table()
+
+    catalog = catalog_conn
+    table = catalog.create_table(f"{_TEST_NAMESPACE}.target", df.schema)
+
+    table.append(df)
+
+    source_df = ctx.sql(f"""
+        select 1 as order_id, date '2021-01-01' as order_date, 'A' as 
order_type
+        union all
+        select 2 as order_id, date '2021-01-01' as order_date, 'B' as 
order_type  
+        union all 
+        select 3 as order_id, date '2021-01-01' as order_date, 'A' as 
order_type
+    """).to_arrow_table()
+
+    res = table.upsert(df=source_df, join_cols=["order_id"])
+
+    rows_updated_should_be = 1
+    rows_inserted_should_be = 1
+
+    assert res['rows_updated'] == rows_updated_should_be, f"rows updated 
should be {rows_updated_should_be}, but got {res['rows_updated']}"
+    assert res['rows_inserted'] == rows_inserted_should_be, f"rows inserted 
should be {rows_inserted_should_be}, but got {res['rows_inserted']}"
+
+    catalog.drop_table(f"{_TEST_NAMESPACE}.target")
+
+def test_merge_scenario_date_as_key(catalog_conn):
+
+    """
+        tests a single insert and update; primary key is a date column
+    """
+
+    ctx = SessionContext()
+
+    df = ctx.sql(f"""
+        select date '2021-01-01' as order_date, 'A' as order_type
+        union all
+        select date '2021-01-02' as order_date, 'A' as order_type
+    """).to_arrow_table()
+
+    catalog = catalog_conn
+    table = catalog.create_table(f"{_TEST_NAMESPACE}.target", df.schema)
+
+    table.append(df)
+
+    source_df = ctx.sql(f"""
+        select date '2021-01-01' as order_date, 'A' as order_type
+        union all
+        select date '2021-01-02' as order_date, 'B' as order_type  
+        union all 
+        select date '2021-01-03' as order_date, 'A' as order_type
+    """).to_arrow_table()
+
+    res = table.upsert(df=source_df, join_cols=["order_date"])
+
+    rows_updated_should_be = 1
+    rows_inserted_should_be = 1
+
+    assert res['rows_updated'] == rows_updated_should_be, f"rows updated 
should be {rows_updated_should_be}, but got {res['rows_updated']}"
+    assert res['rows_inserted'] == rows_inserted_should_be, f"rows inserted 
should be {rows_inserted_should_be}, but got {res['rows_inserted']}"
+
+    catalog.drop_table(f"{_TEST_NAMESPACE}.target")
+
+def test_merge_scenario_string_as_key(catalog_conn):
+
+    """
+        tests a single insert and update; primary key is a string column
+    """
+
+    ctx = SessionContext()
+
+    df = ctx.sql(f"""
+        select 'abc' as order_id, 'A' as order_type
+        union all
+        select 'def' as order_id, 'A' as order_type
+    """).to_arrow_table()
+
+    catalog = catalog_conn
+    table = catalog.create_table(f"{_TEST_NAMESPACE}.target", df.schema)
+
+    table.append(df)
+
+    source_df = ctx.sql(f"""
+        select 'abc' as order_id, 'A' as order_type
+        union all
+        select 'def' as order_id, 'B' as order_type  
+        union all 
+        select 'ghi' as order_id, 'A' as order_type
+    """).to_arrow_table()
+
+    res = table.upsert(df=source_df, join_cols=["order_id"])
+
+    rows_updated_should_be = 1
+    rows_inserted_should_be = 1
+
+    assert res['rows_updated'] == rows_updated_should_be, f"rows updated 
should be {rows_updated_should_be}, but got {res['rows_updated']}"
+    assert res['rows_inserted'] == rows_inserted_should_be, f"rows inserted 
should be {rows_inserted_should_be}, but got {res['rows_inserted']}"
+
+    catalog.drop_table(f"{_TEST_NAMESPACE}.target")
+
+def test_merge_scenario_composite_key(catalog_conn):
+
+    """
+        tests merging 200 rows with a composite key
+    """
+
+    ctx = SessionContext()
+
+    catalog = catalog_conn
+    table = gen_target_iceberg_table_v2(1, 200, True, ctx, catalog, 
_TEST_NAMESPACE)
+    source_df = gen_source_dataset(101, 300, True, False, ctx)
+    
+
+    res = table.upsert(df=source_df, join_cols=["order_id", "order_line_id"])
+
+    rows_updated_should_be = 100
+    rows_inserted_should_be = 100
+
+    assert res['rows_updated'] == rows_updated_should_be, f"rows updated 
should be {rows_updated_should_be}, but got {res['rows_updated']}"
+    assert res['rows_inserted'] == rows_inserted_should_be, f"rows inserted 
should be {rows_inserted_should_be}, but got {res['rows_inserted']}"
+
+    catalog.drop_table(f"{_TEST_NAMESPACE}.target")
+
+def test_merge_source_dups(catalog_conn):
+
+    """
+        tests duplicate rows in source
+    """
+
+    ctx = SessionContext()
+
+
+    catalog = catalog_conn
+    table = gen_target_iceberg_table_v2(1, 10, False, ctx, catalog, 
_TEST_NAMESPACE)
+    source_df = gen_source_dataset(5, 15, False, True, ctx)
+    
+    res = table.upsert(df=source_df, join_cols=["order_id"])
+
+    error_msgs = res['error_msgs']
+
+    assert 'Duplicate rows found in source dataset' in error_msgs, f"error 
message should contain 'Duplicate rows found in source dataset', but got 
{error_msgs}"
+
+    catalog.drop_table(f"{_TEST_NAMESPACE}.target")
+
+def test_key_cols_misaligned(catalog_conn):
+
+    """
+        tests join columns missing from one of the tables
+    """
+
+    ctx = SessionContext()
+
+    df = ctx.sql("select 1 as order_id, date '2021-01-01' as order_date, 'A' 
as order_type").to_arrow_table()
+
+    catalog = catalog_conn
+    table = catalog.create_table(f"{_TEST_NAMESPACE}.target", df.schema)
+
+    table.append(df)
+
+    df_src = ctx.sql("select 1 as item_id, date '2021-05-01' as order_date, 
'B' as order_type").to_arrow_table()
+
+    try:
+
+        res = table.upsert(df=df_src, join_cols=['order_id'])
+
+    except KeyError as e:
+        error_msgs = str(e)

Review Comment:
   ```suggestion
       with pytest.raises(KeyError):
           res = table.upsert(df=df_src, join_cols=['order_id'])
   ```
   
   This tests that your code raises the exception. Also, if we might throw this, we 
should handle it and raise a custom exception saying that a join column is missing.



##########
tests/table/test_upsert.py:
##########
@@ -0,0 +1,327 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+from pyiceberg.catalog.sql import SqlCatalog
+from pyiceberg.catalog import Table as pyiceberg_table
+import os
+import shutil
+import pytest
+
+_TEST_NAMESPACE = "test_ns"
+
+try:
+    from datafusion import SessionContext
+except ModuleNotFoundError as e:
+    raise ModuleNotFoundError("For upsert testing, DataFusion needs to be 
installed") from e
+
+def get_test_warehouse_path():
+    curr_dir = os.path.dirname(os.path.abspath(__file__))
+    return f"{curr_dir}/warehouse"
+
+def purge_warehouse():
+    warehouse_path = get_test_warehouse_path()
+    if os.path.exists(warehouse_path):
+        shutil.rmtree(warehouse_path)

Review Comment:
   This example both creates the catalog and handles the cleanup: 
https://github.com/apache/iceberg-python/blob/main/tests/catalog/integration_test_glue.py#L52



##########
tests/table/test_upsert.py:
##########
@@ -0,0 +1,327 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+from pyiceberg.catalog.sql import SqlCatalog
+from pyiceberg.catalog import Table as pyiceberg_table
+import os
+import shutil
+import pytest
+
+_TEST_NAMESPACE = "test_ns"
+
+try:
+    from datafusion import SessionContext
+except ModuleNotFoundError as e:
+    raise ModuleNotFoundError("For upsert testing, DataFusion needs to be 
installed") from e
+
+def get_test_warehouse_path():
+    curr_dir = os.path.dirname(os.path.abspath(__file__))
+    return f"{curr_dir}/warehouse"
+
+def purge_warehouse():
+    warehouse_path = get_test_warehouse_path()
+    if os.path.exists(warehouse_path):
+        shutil.rmtree(warehouse_path)
+
+def show_iceberg_table(table, ctx: SessionContext):
+    import pyarrow.dataset as ds
+    table_name = "target"
+    if ctx.table_exist(table_name):
+        ctx.deregister_table(table_name)
+    ctx.register_dataset(table_name, ds.dataset(table.scan().to_arrow()))
+    ctx.sql(f"SELECT * FROM {table_name} limit 5").show()
+
+def show_df(df, ctx: SessionContext):
+    import pyarrow.dataset as ds
+    ctx.register_dataset("df", ds.dataset(df))
+    ctx.sql("select * from df limit 10").show()
+
+def gen_source_dataset(start_row: int, end_row: int, composite_key: bool, 
add_dup: bool, ctx: SessionContext):
+
+    additional_columns = ", t.order_id + 1000 as order_line_id" if 
composite_key else ""
+
+    dup_row = f"""
+        UNION ALL
+        (
+        SELECT t.order_id {additional_columns}
+            , date '2021-01-01' as order_date, 'B' as order_type
+        from t
+        limit 1
+        )
+    """ if add_dup else ""
+
+
+    sql = f"""
+        with t as (SELECT unnest(range({start_row},{end_row+1})) as order_id)
+        SELECT t.order_id {additional_columns}
+            , date '2021-01-01' as order_date, 'B' as order_type
+        from t
+        {dup_row}
+    """
+
+    df = ctx.sql(sql).to_arrow_table()
+
+    return df
+
+def gen_target_iceberg_table_v2(start_row: int, end_row: int, composite_key: 
bool, ctx: SessionContext, catalog: SqlCatalog, namespace: str):
+
+    additional_columns = ", t.order_id + 1000 as order_line_id" if 
composite_key else ""
+
+    df = ctx.sql(f"""
+        with t as (SELECT unnest(range({start_row},{end_row+1})) as order_id)
+        SELECT t.order_id {additional_columns}
+            , date '2021-01-01' as order_date, 'A' as order_type
+        from t
+    """).to_arrow_table()
+
+    table = catalog.create_table(f"{_TEST_NAMESPACE}.target", df.schema)
+
+    table.append(df)
+
+    return table
+
+@pytest.fixture(scope="session")
+def catalog_conn():
+    warehouse_path = get_test_warehouse_path()
+    os.makedirs(warehouse_path, exist_ok=True)
+    print(warehouse_path)
+    catalog = SqlCatalog(
+        "default",
+        **{
+            "uri": f"sqlite:///:memory:",
+            "warehouse": f"file://{warehouse_path}",
+        },
+    )
+
+    catalog.create_namespace(namespace="test_ns")
+
+    yield catalog
+
+@pytest.mark.parametrize(
+    "join_cols, src_start_row, src_end_row, target_start_row, target_end_row, 
when_matched_update_all, when_not_matched_insert_all, expected_updated, 
expected_inserted",
+    [
+        (["order_id"], 1, 2, 2, 3, True, True, 1, 1), # single row
+        (["order_id"], 5001, 15000, 1, 10000, True, True, 5000, 5000), #10k 
rows
+        (["order_id"], 501, 1500, 1, 1000, True, False, 500, 0), # update only
+        (["order_id"], 501, 1500, 1, 1000, False, True, 0, 500), # insert only
+    ]
+)
+def test_merge_rows(catalog_conn, join_cols, src_start_row, src_end_row, 
target_start_row, target_end_row
+                    , when_matched_update_all, when_not_matched_insert_all, 
expected_updated, expected_inserted):
+
+    ctx = SessionContext()
+
+    catalog = catalog_conn
+
+    source_df = gen_source_dataset(src_start_row, src_end_row, False, False, 
ctx)
+    ice_table = gen_target_iceberg_table_v2(target_start_row, target_end_row, 
False, ctx, catalog, _TEST_NAMESPACE)
+    res = ice_table.upsert(df=source_df, join_cols=join_cols, 
when_matched_update_all=when_matched_update_all, 
when_not_matched_insert_all=when_not_matched_insert_all)
+
+    assert res['rows_updated'] == expected_updated, f"rows updated should be 
{expected_updated}, but got {res['rows_updated']}"
+    assert res['rows_inserted'] == expected_inserted, f"rows inserted should 
be {expected_inserted}, but got {res['rows_inserted']}"
+
+    catalog.drop_table(f"{_TEST_NAMESPACE}.target")
+
+@pytest.fixture(scope="session", autouse=True)
+def cleanup():
+    yield  # This allows other tests to run first
+    purge_warehouse()  
+
+def test_merge_scenario_skip_upd_row(catalog_conn):

Review Comment:
   These tests could be folded into the parametrized test above if you wanted.



##########
pyiceberg/table/upsert_util.py:
##########
@@ -0,0 +1,158 @@
+
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from pyarrow import Table as pyarrow_table
+import pyarrow as pa
+from pyarrow import compute as pc
+from pyiceberg import table as pyiceberg_table
+
+from pyiceberg.expressions import (
+    BooleanExpression,
+    And,
+    EqualTo,
+    Or,
+    In,
+)
+
+def get_filter_list(df: pyarrow_table, join_cols: list) -> BooleanExpression:
+
+    unique_keys = df.select(join_cols).group_by(join_cols).aggregate([])
+
+    pred = None
+
+    if len(join_cols) == 1:
+        pred = In(join_cols[0], unique_keys[0].to_pylist())
+    else:
+        pred = Or(*[
+            And(*[
+                EqualTo(col, row[col])
+                for col in join_cols
+            ])
+            for row in unique_keys.to_pylist()
+        ])
+
+    return pred

Review Comment:
   ```suggestion
       if len(join_cols) == 1:
           return In(join_cols[0], unique_keys[0].to_pylist())
       return Or(*[
           And(*[
               EqualTo(col, row[col])
               for col in join_cols
           ])
           for row in unique_keys.to_pylist()
       ])
   ```



##########
pyiceberg/table/__init__.py:
##########
@@ -1064,6 +1067,78 @@ def name_mapping(self) -> Optional[NameMapping]:
         """Return the table's field-id NameMapping."""
         return self.metadata.name_mapping()
 
+    @dataclass(frozen=True)
+    class UpsertResult:
+        """Summary the upsert operation"""
+        rows_updated: int = 0
+        rows_inserted: int = 0
+        info_msgs: Optional[str] = None
+        error_msgs: Optional[str] = None
+
+    def upsert(self, df: pa.Table, join_cols: list
+                   , when_matched_update_all: bool = True
+                   , when_not_matched_insert_all: bool = True
+                ) -> UpsertResult:
+        """
+        Shorthand API for performing an upsert to an iceberg table.
+        
+        Args:
+            df: The input dataframe to upsert with the table's data.
+            join_cols: The columns to join on.
+            when_matched_update_all: Bool indicating to update rows that are 
matched but require an update due to a value in a non-key column changing
+            when_not_matched_insert_all: Bool indicating new rows to be 
inserted that do not match any existing rows in the table
+
+        Returns: a UpsertResult class
+        """
+
+        from pyiceberg.table import upsert_util
+
+        if when_matched_update_all == False and when_not_matched_insert_all == 
False:
+            return {'rows_updated': 0, 'rows_inserted': 0, 'info_msgs': 'no 
upsert options selected...exiting'}
+            #return UpsertResult(info_msgs='no upsert options 
selected...exiting')
+
+        if upsert_util.dups_check_in_source(df, join_cols):

Review Comment:
   I think we should raise an exception here. Otherwise the end user has to remember to check the result for errors.
   
   ```suggestion
           if upsert_util.dups_check_in_source(df, join_cols):
               raise DuplicateRowException
   ```



##########
pyiceberg/table/upsert_util.py:
##########
@@ -0,0 +1,158 @@
+
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from pyarrow import Table as pyarrow_table
+import pyarrow as pa
+from pyarrow import compute as pc
+from pyiceberg import table as pyiceberg_table
+
+from pyiceberg.expressions import (
+    BooleanExpression,
+    And,
+    EqualTo,
+    Or,
+    In,
+)
+
+def get_filter_list(df: pyarrow_table, join_cols: list) -> BooleanExpression:
+
+    unique_keys = df.select(join_cols).group_by(join_cols).aggregate([])
+
+    pred = None
+
+    if len(join_cols) == 1:
+        pred = In(join_cols[0], unique_keys[0].to_pylist())
+    else:
+        pred = Or(*[
+            And(*[
+                EqualTo(col, row[col])
+                for col in join_cols
+            ])
+            for row in unique_keys.to_pylist()
+        ])
+
+    return pred
+
+def dups_check_in_source(df: pyarrow_table, join_cols: list) -> bool:
+    """
+    This function checks if there are duplicate rows in the source table based 
on the join columns.
+    It returns True if there are duplicate rows in the source table, otherwise 
it returns False.
+    """
+    # Check for duplicates in the source table

Review Comment:
   Remove this comment; it only restates what the docstring already says.



##########
tests/table/test_upsert.py:
##########
@@ -0,0 +1,327 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+from pyiceberg.catalog.sql import SqlCatalog
+from pyiceberg.catalog import Table as pyiceberg_table
+import os
+import shutil
+import pytest
+
+_TEST_NAMESPACE = "test_ns"
+
+try:
+    from datafusion import SessionContext
+except ModuleNotFoundError as e:
+    raise ModuleNotFoundError("For upsert testing, DataFusion needs to be 
installed") from e
+
+def get_test_warehouse_path():
+    curr_dir = os.path.dirname(os.path.abspath(__file__))
+    return f"{curr_dir}/warehouse"
+
+def purge_warehouse():
+    warehouse_path = get_test_warehouse_path()
+    if os.path.exists(warehouse_path):
+        shutil.rmtree(warehouse_path)

Review Comment:
   I think you would be better off using the in-memory catalog here:
   
   
https://github.com/apache/iceberg-python/blob/main/tests/catalog/test_base.py#L272
   
   That way your tests don't delete anything on disk, which could otherwise eventually cause some confusion.



##########
pyiceberg/table/__init__.py:
##########
@@ -1064,6 +1067,78 @@ def name_mapping(self) -> Optional[NameMapping]:
         """Return the table's field-id NameMapping."""
         return self.metadata.name_mapping()
 
+    @dataclass(frozen=True)
+    class UpsertResult:
+        """Summary the upsert operation"""
+        rows_updated: int = 0
+        rows_inserted: int = 0
+        info_msgs: Optional[str] = None
+        error_msgs: Optional[str] = None
+
+    def upsert(self, df: pa.Table, join_cols: list
+                   , when_matched_update_all: bool = True
+                   , when_not_matched_insert_all: bool = True
+                ) -> UpsertResult:
+        """
+        Shorthand API for performing an upsert to an iceberg table.
+        
+        Args:
+            df: The input dataframe to upsert with the table's data.
+            join_cols: The columns to join on.
+            when_matched_update_all: Bool indicating to update rows that are 
matched but require an update due to a value in a non-key column changing
+            when_not_matched_insert_all: Bool indicating new rows to be 
inserted that do not match any existing rows in the table
+
+        Returns: a UpsertResult class
+        """
+
+        from pyiceberg.table import upsert_util
+
+        if when_matched_update_all == False and when_not_matched_insert_all == 
False:
+            return {'rows_updated': 0, 'rows_inserted': 0, 'info_msgs': 'no 
upsert options selected...exiting'}
+            #return UpsertResult(info_msgs='no upsert options 
selected...exiting')
+
+        if upsert_util.dups_check_in_source(df, join_cols):
+
+            return {'error_msgs': 'Duplicate rows found in source dataset 
based on the key columns. No upsert executed'}
+
+        #get list of rows that exist so we don't have to load the entire 
target table
+        pred = upsert_util.get_filter_list(df, join_cols)
+        iceberg_table_trimmed = self.scan(row_filter=pred).to_arrow()
+
+        update_row_cnt = 0
+        insert_row_cnt = 0
+
+        try:
+
+            with self.transaction() as txn:
+            
+                if when_matched_update_all:
+
+                    update_recs = upsert_util.get_rows_to_update(df, 
iceberg_table_trimmed, join_cols)
+
+                    update_row_cnt = len(update_recs)
+
+                    overwrite_filter = 
upsert_util.get_filter_list(update_recs, join_cols)
+
+                    txn.overwrite(update_recs, 
overwrite_filter=overwrite_filter)    
+
+
+                if when_not_matched_insert_all:
+                    
+                    insert_recs = upsert_util.get_rows_to_insert(df, 
iceberg_table_trimmed, join_cols)
+
+                    insert_row_cnt = len(insert_recs)
+
+                    txn.append(insert_recs)
+
+            return {
+                "rows_updated": update_row_cnt,
+                "rows_inserted": insert_row_cnt
+            }

Review Comment:
   This should return an `UpsertResult`, not a dict.



##########
pyiceberg/table/upsert_util.py:
##########
@@ -0,0 +1,158 @@
+
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from pyarrow import Table as pyarrow_table
+import pyarrow as pa
+from pyarrow import compute as pc
+from pyiceberg import table as pyiceberg_table
+
+from pyiceberg.expressions import (
+    BooleanExpression,
+    And,
+    EqualTo,
+    Or,
+    In,
+)
+
+def get_filter_list(df: pyarrow_table, join_cols: list) -> BooleanExpression:
+
+    unique_keys = df.select(join_cols).group_by(join_cols).aggregate([])
+
+    pred = None
+
+    if len(join_cols) == 1:
+        pred = In(join_cols[0], unique_keys[0].to_pylist())
+    else:
+        pred = Or(*[
+            And(*[
+                EqualTo(col, row[col])
+                for col in join_cols
+            ])
+            for row in unique_keys.to_pylist()
+        ])
+
+    return pred
+
+def dups_check_in_source(df: pyarrow_table, join_cols: list) -> bool:
+    """
+    This function checks if there are duplicate rows in the source table based 
on the join columns.
+    It returns True if there are duplicate rows in the source table, otherwise 
it returns False.
+    """
+    # Check for duplicates in the source table
+    source_dup_count = len(
+        df.select(join_cols)
+            .group_by(join_cols)
+            .aggregate([([], "count_all")])
+            .filter(pc.field("count_all") > 1)
+    )
+    
+    return source_dup_count > 0
+
+def get_rows_to_update(source_table: pa.Table, target_table: pa.Table, 
join_cols: list) -> pa.Table:
+    
+    """
+        This function takes the source_table, trims it down to rows that match 
in both source and target.
+        It then does a scan for the non-key columns to see if any are 
mis-aligned before returning the final row set to update
+    """
+    
+    all_columns = set(source_table.column_names)
+    join_cols_set = set(join_cols)
+
+    non_key_cols = list(all_columns - join_cols_set)
+
+    
+    match_expr = None
+
+    for col in join_cols:
+        target_values = target_table.column(col).to_pylist()
+        expr = pc.field(col).isin(target_values)
+
+        if match_expr is None:
+            match_expr = expr
+        else:
+            match_expr = match_expr & expr

Review Comment:
   This is the same as the `get_rows_to_update` function; can we abstract it away?



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]


Reply via email to