This is an automated email from the ASF dual-hosted git repository.

lzljs3620320 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/paimon.git


The following commit(s) were added to refs/heads/master by this push:
     new dc7c03ec5a [spark] AbstractSparkInternalRow supports fallback equals and hashCode (#5101)
dc7c03ec5a is described below

commit dc7c03ec5a94b00f0107cb8d89e53bad6cc03b55
Author: Xiduo You <[email protected]>
AuthorDate: Sat Feb 22 22:00:38 2025 +0800

    [spark] AbstractSparkInternalRow supports fallback equals and hashCode (#5101)
---
 .../java/org/apache/paimon/data/GenericMap.java    |   4 +
 .../org/apache/paimon/utils/InternalRowUtils.java  | 131 +++++++++++++++++++++
 .../apache/paimon/utils/InternalRowUtilsTest.java  |  78 ++++++++++++
 .../paimon/spark/AbstractSparkInternalRow.java     |  18 ++-
 4 files changed, 229 insertions(+), 2 deletions(-)

diff --git a/paimon-common/src/main/java/org/apache/paimon/data/GenericMap.java b/paimon-common/src/main/java/org/apache/paimon/data/GenericMap.java
index 0b196c0757..0e07e80a5f 100644
--- a/paimon-common/src/main/java/org/apache/paimon/data/GenericMap.java
+++ b/paimon-common/src/main/java/org/apache/paimon/data/GenericMap.java
@@ -64,6 +64,10 @@ public final class GenericMap implements InternalMap, Serializable {
         return map.get(key);
     }
 
+    public boolean contains(Object key) {
+        return map.containsKey(key);
+    }
+
     @Override
     public int size() {
         return map.size();
diff --git a/paimon-common/src/main/java/org/apache/paimon/utils/InternalRowUtils.java b/paimon-common/src/main/java/org/apache/paimon/utils/InternalRowUtils.java
index bd46bae631..b052690f0f 100644
--- a/paimon-common/src/main/java/org/apache/paimon/utils/InternalRowUtils.java
+++ b/paimon-common/src/main/java/org/apache/paimon/utils/InternalRowUtils.java
@@ -48,6 +48,7 @@ import javax.annotation.Nullable;
 import java.math.BigDecimal;
 import java.math.RoundingMode;
 import java.util.ArrayList;
+import java.util.Arrays;
 import java.util.HashMap;
 import java.util.List;
 import java.util.Map;
@@ -55,6 +56,127 @@ import java.util.Map;
 /** Utils for {@link InternalRow} structures. */
 public class InternalRowUtils {
 
+    public static boolean equals(Object data1, Object data2, DataType 
dataType) {
+        if ((data1 == null) != (data2 == null)) {
+            return false;
+        }
+        if (data1 != null) {
+            if (data1 instanceof InternalRow) {
+                RowType rowType = (RowType) dataType;
+                int len = rowType.getFieldCount();
+                for (int i = 0; i < len; i++) {
+                    Object value1 = get((InternalRow) data1, i, 
rowType.getTypeAt(i));
+                    Object value2 = get((InternalRow) data2, i, 
rowType.getTypeAt(i));
+                    if (!equals(value1, value2, rowType.getTypeAt(i))) {
+                        return false;
+                    }
+                }
+            } else if (data1 instanceof InternalArray) {
+                if (((InternalArray) data1).size() != ((InternalArray) 
data2).size()) {
+                    return false;
+                }
+                ArrayType arrayType = (ArrayType) dataType;
+                for (int i = 0; i < ((InternalArray) data1).size(); i++) {
+                    Object value1 = get((InternalArray) data1, i, 
arrayType.getElementType());
+                    Object value2 = get((InternalArray) data2, i, 
arrayType.getElementType());
+                    if (!equals(value1, value2, arrayType.getElementType())) {
+                        return false;
+                    }
+                }
+            } else if (data1 instanceof InternalMap) {
+                if (((InternalMap) data1).size() != ((InternalMap) 
data2).size()) {
+                    return false;
+                }
+                MapType mapType = (MapType) dataType;
+                GenericMap map1;
+                GenericMap map2;
+                if (data1 instanceof GenericMap) {
+                    map1 = (GenericMap) data1;
+                    map2 = (GenericMap) data2;
+                } else {
+                    map1 =
+                            copyToGenericMap(
+                                    (InternalMap) data1,
+                                    mapType.getKeyType(),
+                                    mapType.getValueType());
+                    map2 =
+                            copyToGenericMap(
+                                    (InternalMap) data2,
+                                    mapType.getKeyType(),
+                                    mapType.getValueType());
+                }
+                InternalArray keyArray1 = map1.keyArray();
+                for (int i = 0; i < map1.size(); i++) {
+                    Object key = get(keyArray1, i, mapType.getKeyType());
+                    if (!map2.contains(key)
+                            || !equals(map1.get(key), map2.get(key), 
mapType.getValueType())) {
+                        return false;
+                    }
+                }
+            } else if (data1 instanceof byte[]) {
+                if (!java.util.Arrays.equals((byte[]) data1, (byte[]) data2)) {
+                    return false;
+                }
+            } else if (data1 instanceof Float && java.lang.Float.isNaN((Float) 
data1)) {
+                if (!java.lang.Float.isNaN((Float) data2)) {
+                    return false;
+                }
+            } else if (data1 instanceof Double && 
java.lang.Double.isNaN((Double) data1)) {
+                if (!java.lang.Double.isNaN((Double) data2)) {
+                    return false;
+                }
+            } else {
+                if (!data1.equals(data2)) {
+                    return false;
+                }
+            }
+        }
+        return true;
+    }
+
+    public static int hash(Object data, DataType dataType) {
+        if (data == null) {
+            return 0;
+        }
+        int result = 0;
+        if (data instanceof InternalRow) {
+            RowType rowType = (RowType) dataType;
+            int len = rowType.getFieldCount();
+            for (int i = 0; i < len; i++) {
+                Object v = get((InternalRow) data, i, rowType.getTypeAt(i));
+                result = 37 * result + hash(v, rowType.getTypeAt(i));
+            }
+        } else if (data instanceof InternalArray) {
+            ArrayType arrayType = (ArrayType) dataType;
+            int len = ((InternalArray) data).size();
+            for (int i = 0; i < len; i++) {
+                Object v = get((InternalArray) data, i, 
arrayType.getElementType());
+                result = 37 * result + hash(v, arrayType.getElementType());
+            }
+        } else if (data instanceof InternalMap) {
+            MapType mapType = (MapType) dataType;
+            GenericMap map;
+            if (data instanceof GenericMap) {
+                map = (GenericMap) data;
+            } else {
+                map =
+                        copyToGenericMap(
+                                (InternalMap) data, mapType.getKeyType(), 
mapType.getValueType());
+            }
+            InternalArray keyArray = map.keyArray();
+            for (int i = 0; i < map.size(); i++) {
+                Object key = get(keyArray, i, mapType.getKeyType());
+                result = 37 * result + hash(key, mapType.getKeyType());
+                result = 37 * result + hash(map.get(key), 
mapType.getValueType());
+            }
+        } else if (data instanceof byte[]) {
+            result = Arrays.hashCode((byte[]) data);
+        } else {
+            result = data.hashCode();
+        }
+        return result;
+    }
+
     public static InternalRow copyInternalRow(InternalRow row, RowType 
rowType) {
         if (row instanceof BinaryRow) {
             return ((BinaryRow) row).copy();
@@ -117,6 +239,11 @@ public class InternalRowUtils {
             return ((BinaryMap) map).copy();
         }
 
+        return copyToGenericMap(map, keyType, valueType);
+    }
+
+    private static GenericMap copyToGenericMap(
+            InternalMap map, DataType keyType, DataType valueType) {
         Map<Object, Object> javaMap = new HashMap<>();
         InternalArray keys = map.keyArray();
         InternalArray values = map.valueArray();
@@ -145,6 +272,10 @@ public class InternalRowUtils {
                 return copyMap(
                         (InternalMap) o, ((MultisetType) 
type).getElementType(), new IntType());
             }
+        } else if (o instanceof byte[]) {
+            byte[] copy = new byte[((byte[]) o).length];
+            System.arraycopy(((byte[]) o), 0, copy, 0, ((byte[]) o).length);
+            return copy;
         } else if (o instanceof Decimal) {
             return ((Decimal) o).copy();
         }
diff --git a/paimon-common/src/test/java/org/apache/paimon/utils/InternalRowUtilsTest.java b/paimon-common/src/test/java/org/apache/paimon/utils/InternalRowUtilsTest.java
index ea3bd98cfe..70d32c928c 100644
--- a/paimon-common/src/test/java/org/apache/paimon/utils/InternalRowUtilsTest.java
+++ b/paimon-common/src/test/java/org/apache/paimon/utils/InternalRowUtilsTest.java
@@ -21,6 +21,9 @@ package org.apache.paimon.utils;
 import org.apache.paimon.data.BinaryRow;
 import org.apache.paimon.data.BinaryString;
 import org.apache.paimon.data.Decimal;
+import org.apache.paimon.data.GenericArray;
+import org.apache.paimon.data.GenericMap;
+import org.apache.paimon.data.GenericRow;
 import org.apache.paimon.data.InternalRow;
 import org.apache.paimon.data.Timestamp;
 import org.apache.paimon.data.serializer.InternalRowSerializer;
@@ -37,6 +40,8 @@ import org.junit.jupiter.api.Test;
 
 import java.math.BigDecimal;
 import java.time.LocalDateTime;
+import java.util.HashMap;
+import java.util.Map;
 
 import static org.assertj.core.api.Assertions.assertThat;
 
@@ -52,6 +57,7 @@ public class InternalRowUtilsTest {
                     .field("intArray", 
DataTypes.ARRAY(DataTypes.INT()).nullable())
                     .field("char", DataTypes.CHAR(10).notNull())
                     .field("varchar", DataTypes.VARCHAR(10).notNull())
+                    .field("binary", DataTypes.BINARY(10).notNull())
                     .field("boolean", DataTypes.BOOLEAN().nullable())
                     .field("tinyint", DataTypes.TINYINT())
                     .field("smallint", DataTypes.SMALLINT())
@@ -144,4 +150,76 @@ public class InternalRowUtilsTest {
                                 DataTypeRoot.VARCHAR))
                 .isLessThan(0);
     }
+
+    @Test
+    public void testEqualsAndHashCode() {
+        for (int i = 0; i < 10; i++) {
+            GenericRow row1 = (GenericRow) rowDataGenerator.next();
+            GenericRow row2 = (GenericRow) 
InternalRowUtils.copyInternalRow(row1, ROW_TYPE);
+            GenericRow row3 = (GenericRow) rowDataGenerator.next();
+            assertThat(InternalRowUtils.equals(row1, row2, ROW_TYPE)).isTrue();
+            assertThat(InternalRowUtils.equals(row1, row3, 
ROW_TYPE)).isFalse();
+
+            assertThat(InternalRowUtils.hash(row1, ROW_TYPE))
+                    .isEqualTo(InternalRowUtils.hash(row2, ROW_TYPE));
+            assertThat(InternalRowUtils.hash(row1, ROW_TYPE))
+                    .isNotEqualTo(InternalRowUtils.hash(row3, ROW_TYPE));
+        }
+
+        RowType rowType =
+                RowType.builder()
+                        .field("f1", DataTypes.DOUBLE())
+                        .field("f2", DataTypes.FLOAT())
+                        .field("f3", DataTypes.BINARY(3))
+                        .field("f4", DataTypes.STRING())
+                        .field("f5", 
DataTypes.ARRAY(DataTypes.ROW(DataTypes.INT())))
+                        .field(
+                                "f6",
+                                DataTypes.MAP(DataTypes.STRING(), 
DataTypes.ROW(DataTypes.INT())))
+                        .field("f7", DataTypes.ROW(DataTypes.INT()))
+                        .build();
+        GenericRow row1 = new GenericRow(7);
+        row1.setField(0, Double.NaN);
+        row1.setField(1, Float.NaN);
+        row1.setField(2, "abc".getBytes());
+        row1.setField(3, null);
+        row1.setField(4, new GenericArray(new GenericRow[] {GenericRow.of(1), 
GenericRow.of(10)}));
+        Map<BinaryString, InternalRow> map = new HashMap<>();
+        map.put(BinaryString.fromString("a"), GenericRow.of(1));
+        map.put(BinaryString.fromString("b"), GenericRow.of(2));
+        row1.setField(5, new GenericMap(map));
+        row1.setField(6, GenericRow.of(1));
+        GenericRow row2 = (GenericRow) InternalRowUtils.copyInternalRow(row1, 
rowType);
+        assertThat(InternalRowUtils.equals(row1, row2, rowType)).isTrue();
+        assertThat(InternalRowUtils.hash(row1, rowType))
+                .isEqualTo(InternalRowUtils.hash(row2, rowType));
+    }
+
+    @Test
+    public void testEqualsAndHashCodeNegativeCase() {
+        // different array len
+        RowType rowType = RowType.builder().field("f1", 
DataTypes.ARRAY(DataTypes.INT())).build();
+        GenericRow rowWithArray1 = new GenericRow(1);
+        rowWithArray1.setField(
+                0, new GenericArray(new GenericRow[] {GenericRow.of(1), 
GenericRow.of(10)}));
+        GenericRow rowWithArray2 = new GenericRow(1);
+        rowWithArray2.setField(0, new GenericArray(new GenericRow[] 
{GenericRow.of(1)}));
+        assertThat(InternalRowUtils.equals(rowWithArray1, rowWithArray2, 
rowType)).isFalse();
+
+        // different map len
+        RowType rowType2 =
+                RowType.builder()
+                        .field("f1", DataTypes.MAP(DataTypes.STRING(), 
DataTypes.INT()))
+                        .build();
+        Map<BinaryString, InternalRow> map1 = new HashMap<>();
+        map1.put(BinaryString.fromString("a"), GenericRow.of(1));
+        map1.put(BinaryString.fromString("b"), GenericRow.of(2));
+        GenericRow rowWithMap1 = new GenericRow(1);
+        rowWithMap1.setField(0, new GenericMap(map1));
+        Map<BinaryString, InternalRow> map2 = new HashMap<>();
+        map2.put(BinaryString.fromString("a"), GenericRow.of(1));
+        GenericRow rowWithMap2 = new GenericRow(1);
+        rowWithMap2.setField(0, new GenericMap(map2));
+        assertThat(InternalRowUtils.equals(rowWithMap1, rowWithMap2, 
rowType2)).isFalse();
+    }
 }
diff --git a/paimon-spark/paimon-spark-common/src/main/java/org/apache/paimon/spark/AbstractSparkInternalRow.java b/paimon-spark/paimon-spark-common/src/main/java/org/apache/paimon/spark/AbstractSparkInternalRow.java
index 28604a6d62..283077430e 100644
--- a/paimon-spark/paimon-spark-common/src/main/java/org/apache/paimon/spark/AbstractSparkInternalRow.java
+++ b/paimon-spark/paimon-spark-common/src/main/java/org/apache/paimon/spark/AbstractSparkInternalRow.java
@@ -25,6 +25,7 @@ import org.apache.paimon.types.BigIntType;
 import org.apache.paimon.types.DataType;
 import org.apache.paimon.types.DataTypeChecks;
 import org.apache.paimon.types.RowType;
+import org.apache.paimon.utils.InternalRowUtils;
 
 import org.apache.spark.sql.catalyst.util.ArrayData;
 import org.apache.spark.sql.catalyst.util.MapData;
@@ -251,11 +252,24 @@ public abstract class AbstractSparkInternalRow extends SparkInternalRow {
             return false;
         }
         AbstractSparkInternalRow that = (AbstractSparkInternalRow) o;
-        return Objects.equals(rowType, that.rowType) && Objects.equals(row, 
that.row);
+        if (Objects.equals(rowType, that.rowType)) {
+            try {
+                return Objects.equals(row, that.row);
+            } catch (Exception e) {
+                // The underlying row may not support equals or hashcode, 
e.g., `ProjectedRow`,
+                // to be safe, we fallback to a slow performance version.
+                return InternalRowUtils.equals(row, that.row, rowType);
+            }
+        }
+        return false;
     }
 
     @Override
     public int hashCode() {
-        return Objects.hash(rowType, row);
+        try {
+            return Objects.hash(rowType, row);
+        } catch (Exception e) {
+            return InternalRowUtils.hash(row, rowType);
+        }
     }
 }

Reply via email to