This is an automated email from the ASF dual-hosted git repository.

blue pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/iceberg.git


The following commit(s) were added to refs/heads/master by this push:
     new fa7b49181e Spark 3.4: Fix expression to SQL with transforms (#8257)
fa7b49181e is described below

commit fa7b49181e9b5895b96c287db24e52d036429a97
Author: Xianyang Liu <[email protected]>
AuthorDate: Tue Aug 15 07:04:42 2023 +0800

    Spark 3.4: Fix expression to SQL with transforms (#8257)
    
    Co-authored-by: xianyangliu <[email protected]>
---
 .../java/org/apache/iceberg/spark/Spark3Util.java  | 41 +++++++++++++-------
 .../org/apache/iceberg/spark/TestSpark3Util.java   | 45 ++++++++++++++++++++++
 2 files changed, 72 insertions(+), 14 deletions(-)

diff --git a/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/Spark3Util.java b/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/Spark3Util.java
index ad4e2d16b7..bbd7986b26 100644
--- a/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/Spark3Util.java
+++ b/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/Spark3Util.java
@@ -44,6 +44,8 @@ import org.apache.iceberg.expressions.BoundPredicate;
 import org.apache.iceberg.expressions.ExpressionVisitors;
 import org.apache.iceberg.expressions.Term;
 import org.apache.iceberg.expressions.UnboundPredicate;
+import org.apache.iceberg.expressions.UnboundTerm;
+import org.apache.iceberg.expressions.UnboundTransform;
 import org.apache.iceberg.expressions.Zorder;
 import org.apache.iceberg.relocated.com.google.common.base.Joiner;
 import org.apache.iceberg.relocated.com.google.common.base.Preconditions;
@@ -626,38 +628,49 @@ public class Spark3Util {
     public <T> String predicate(UnboundPredicate<T> pred) {
       switch (pred.op()) {
         case IS_NULL:
-          return pred.ref().name() + " IS NULL";
+          return sqlString(pred.term()) + " IS NULL";
         case NOT_NULL:
-          return pred.ref().name() + " IS NOT NULL";
+          return sqlString(pred.term()) + " IS NOT NULL";
         case IS_NAN:
-          return "is_nan(" + pred.ref().name() + ")";
+          return "is_nan(" + sqlString(pred.term()) + ")";
         case NOT_NAN:
-          return "not_nan(" + pred.ref().name() + ")";
+          return "not_nan(" + sqlString(pred.term()) + ")";
         case LT:
-          return pred.ref().name() + " < " + sqlString(pred.literal());
+          return sqlString(pred.term()) + " < " + sqlString(pred.literal());
         case LT_EQ:
-          return pred.ref().name() + " <= " + sqlString(pred.literal());
+          return sqlString(pred.term()) + " <= " + sqlString(pred.literal());
         case GT:
-          return pred.ref().name() + " > " + sqlString(pred.literal());
+          return sqlString(pred.term()) + " > " + sqlString(pred.literal());
         case GT_EQ:
-          return pred.ref().name() + " >= " + sqlString(pred.literal());
+          return sqlString(pred.term()) + " >= " + sqlString(pred.literal());
         case EQ:
-          return pred.ref().name() + " = " + sqlString(pred.literal());
+          return sqlString(pred.term()) + " = " + sqlString(pred.literal());
         case NOT_EQ:
-          return pred.ref().name() + " != " + sqlString(pred.literal());
+          return sqlString(pred.term()) + " != " + sqlString(pred.literal());
         case STARTS_WITH:
-          return pred.ref().name() + " LIKE '" + pred.literal().value() + "%'";
+          return sqlString(pred.term()) + " LIKE '" + pred.literal().value() + "%'";
         case NOT_STARTS_WITH:
-          return pred.ref().name() + " NOT LIKE '" + pred.literal().value() + "%'";
+          return sqlString(pred.term()) + " NOT LIKE '" + pred.literal().value() + "%'";
         case IN:
-          return pred.ref().name() + " IN (" + sqlString(pred.literals()) + ")";
+          return sqlString(pred.term()) + " IN (" + sqlString(pred.literals()) + ")";
         case NOT_IN:
-          return pred.ref().name() + " NOT IN (" + sqlString(pred.literals()) + ")";
+          return sqlString(pred.term()) + " NOT IN (" + sqlString(pred.literals()) + ")";
         default:
           throw new UnsupportedOperationException("Cannot convert predicate to SQL: " + pred);
       }
     }
 
+    private static <T> String sqlString(UnboundTerm<T> term) {
+      if (term instanceof org.apache.iceberg.expressions.NamedReference) {
+        return term.ref().name();
+      } else if (term instanceof UnboundTransform) {
+        UnboundTransform<?, ?> transform = (UnboundTransform<?, ?>) term;
+        return transform.transform().toString() + "(" + transform.ref().name() + ")";
+      } else {
+        throw new UnsupportedOperationException("Cannot convert term to SQL: " + term);
+      }
+    }
+
     private static <T> String sqlString(List<org.apache.iceberg.expressions.Literal<T>> literals) {
       return literals.stream()
           .map(DescribeExpressionVisitor::sqlString)
diff --git a/spark/v3.4/spark/src/test/java/org/apache/iceberg/spark/TestSpark3Util.java b/spark/v3.4/spark/src/test/java/org/apache/iceberg/spark/TestSpark3Util.java
index 96dc2c29eb..ce11f0c05f 100644
--- a/spark/v3.4/spark/src/test/java/org/apache/iceberg/spark/TestSpark3Util.java
+++ b/spark/v3.4/spark/src/test/java/org/apache/iceberg/spark/TestSpark3Util.java
@@ -20,6 +20,20 @@ package org.apache.iceberg.spark;
 
 import static org.apache.iceberg.NullOrder.NULLS_FIRST;
 import static org.apache.iceberg.NullOrder.NULLS_LAST;
+import static org.apache.iceberg.expressions.Expressions.and;
+import static org.apache.iceberg.expressions.Expressions.bucket;
+import static org.apache.iceberg.expressions.Expressions.day;
+import static org.apache.iceberg.expressions.Expressions.equal;
+import static org.apache.iceberg.expressions.Expressions.greaterThan;
+import static org.apache.iceberg.expressions.Expressions.greaterThanOrEqual;
+import static org.apache.iceberg.expressions.Expressions.hour;
+import static org.apache.iceberg.expressions.Expressions.in;
+import static org.apache.iceberg.expressions.Expressions.lessThan;
+import static org.apache.iceberg.expressions.Expressions.lessThanOrEqual;
+import static org.apache.iceberg.expressions.Expressions.month;
+import static org.apache.iceberg.expressions.Expressions.notIn;
+import static org.apache.iceberg.expressions.Expressions.truncate;
+import static org.apache.iceberg.expressions.Expressions.year;
 import static org.apache.iceberg.types.Types.NestedField.optional;
 import static org.apache.iceberg.types.Types.NestedField.required;
 
@@ -29,7 +43,9 @@ import org.apache.iceberg.SortOrder;
 import org.apache.iceberg.SortOrderParser;
 import org.apache.iceberg.Table;
 import org.apache.iceberg.catalog.Catalog;
+import org.apache.iceberg.expressions.Expression;
 import org.apache.iceberg.types.Types;
+import org.assertj.core.api.Assertions;
 import org.junit.Assert;
 import org.junit.Test;
 
@@ -122,6 +138,35 @@ public class TestSpark3Util extends SparkTestBase {
         "Should retrieve underlying catalog class", catalog instanceof CachingCatalog);
   }
 
+  @Test
+  public void testDescribeExpression() {
+    Expression refExpression = equal("id", 1);
+    Assertions.assertThat(Spark3Util.describe(refExpression)).isEqualTo("id = 1");
+
+    Expression yearExpression = greaterThan(year("ts"), 10);
+    Assertions.assertThat(Spark3Util.describe(yearExpression)).isEqualTo("year(ts) > 10");
+
+    Expression monthExpression = greaterThanOrEqual(month("ts"), 10);
+    Assertions.assertThat(Spark3Util.describe(monthExpression)).isEqualTo("month(ts) >= 10");
+
+    Expression dayExpression = lessThan(day("ts"), 10);
+    Assertions.assertThat(Spark3Util.describe(dayExpression)).isEqualTo("day(ts) < 10");
+
+    Expression hourExpression = lessThanOrEqual(hour("ts"), 10);
+    Assertions.assertThat(Spark3Util.describe(hourExpression)).isEqualTo("hour(ts) <= 10");
+
+    Expression bucketExpression = in(bucket("id", 5), 3);
+    Assertions.assertThat(Spark3Util.describe(bucketExpression)).isEqualTo("bucket[5](id) IN (3)");
+
+    Expression truncateExpression = notIn(truncate("name", 3), "abc");
+    Assertions.assertThat(Spark3Util.describe(truncateExpression))
+        .isEqualTo("truncate[3](name) NOT IN ('abc')");
+
+    Expression andExpression = and(refExpression, yearExpression);
+    Assertions.assertThat(Spark3Util.describe(andExpression))
+        .isEqualTo("(id = 1 AND year(ts) > 10)");
+  }
+
   private SortOrder buildSortOrder(String transform, Schema schema, int sourceId) {
     String jsonString =
         "{\n"

Reply via email to