This is an automated email from the ASF dual-hosted git repository.

liyang pushed a commit to branch kylin5
in repository https://gitbox.apache.org/repos/asf/kylin.git

commit dd16079c42e4f6eb446945ec42980891934a20c6
Author: Zhiting Guo <35057824+fre...@users.noreply.github.com>
AuthorDate: Fri Aug 11 13:15:12 2023 +0800

    KYLIN-5787 Use t-digest as spark percentile_approx function
---
 .../org/apache/kylin/common/KylinConfigBase.java   |  6 ++++
 .../measure/percentile/PercentileCounter.java      |  1 +
 .../java/org/apache/kylin/util/ExecAndComp.java    | 42 +++++++++++++---------
 .../kylin/query/engine/AsyncQueryApplication.java  |  5 +++
 .../query/engine/AsyncQueryApplicationTest.java    |  3 ++
 .../kylin/query/runtime/plan/AggregatePlan.scala   | 11 ++++--
 .../scala/org/apache/spark/sql/KapFunctions.scala  |  8 +++--
 .../scala/org/apache/spark/sql/SparderEnv.scala    |  3 ++
 .../org/apache/spark/sql/udf/UdfManager.scala      |  9 ++---
 .../org/apache/spark/sql/udaf/Percentile.scala     |  5 +++
 10 files changed, 67 insertions(+), 26 deletions(-)

diff --git a/src/core-common/src/main/java/org/apache/kylin/common/KylinConfigBase.java b/src/core-common/src/main/java/org/apache/kylin/common/KylinConfigBase.java
index ff7bd4c7c3..5bfe8c262a 100644
--- a/src/core-common/src/main/java/org/apache/kylin/common/KylinConfigBase.java
+++ b/src/core-common/src/main/java/org/apache/kylin/common/KylinConfigBase.java
@@ -1914,6 +1914,12 @@ public abstract class KylinConfigBase implements Serializable {
         return Boolean.parseBoolean(this.getOptional("kylin.query.replace-dynamic-params-enabled", FALSE));
     }
 
+
+    public String getPercentileApproxAlgorithm() {
+        // Valid values: t-digest
+        return this.getOptional("kylin.query.percentile-approx-algorithm", "");
+    }
+
     public boolean isCollectUnionInOrder() {
         return Boolean.parseBoolean(this.getOptional("kylin.query.collect-union-in-order", TRUE));
     }
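
The getter above introduces a query-level switch. Going by the key name and its "Valid values" comment, it would presumably be enabled in kylin.properties as follows; the empty default keeps Spark's built-in percentile_approx behaviour:

    kylin.query.percentile-approx-algorithm=t-digest
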
diff --git a/src/core-metadata/src/main/java/org/apache/kylin/measure/percentile/PercentileCounter.java b/src/core-metadata/src/main/java/org/apache/kylin/measure/percentile/PercentileCounter.java
index 823b47ba38..11851d48c6 100644
--- a/src/core-metadata/src/main/java/org/apache/kylin/measure/percentile/PercentileCounter.java
+++ b/src/core-metadata/src/main/java/org/apache/kylin/measure/percentile/PercentileCounter.java
@@ -25,6 +25,7 @@ import com.tdunning.math.stats.AVLTreeDigest;
 import com.tdunning.math.stats.TDigest;
 
 public class PercentileCounter implements Serializable {
+    public static final int DEFAULT_PERCENTILE_ACCURACY = 100;
     private static final double INVALID_QUANTILE_RATIO = -1;
 
     double compression;
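
For context, the new constant is the compression value handed to the underlying t-digest structure. A minimal, purely illustrative sketch with the bundled t-digest library (the data and variable names are made up):

    import com.tdunning.math.stats.TDigest

    val digest = TDigest.createAvlTreeDigest(100) // PercentileCounter.DEFAULT_PERCENTILE_ACCURACY
    (1 to 1000).foreach(i => digest.add(i.toDouble))
    val approxMedian = digest.quantile(0.5)
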
diff --git a/src/kylin-it/src/test/java/org/apache/kylin/util/ExecAndComp.java b/src/kylin-it/src/test/java/org/apache/kylin/util/ExecAndComp.java
index 34126ab7f4..d089e9e34a 100644
--- a/src/kylin-it/src/test/java/org/apache/kylin/util/ExecAndComp.java
+++ b/src/kylin-it/src/test/java/org/apache/kylin/util/ExecAndComp.java
@@ -135,27 +135,37 @@ public class ExecAndComp {
 
     @SneakyThrows
     public static QueryResult queryWithSpark(String prj, String originSql, String joinType, String sqlPath) {
+        return queryWithSpark(prj, originSql, joinType, sqlPath, true);
+    }
+
+    @SneakyThrows
+    public static QueryResult queryWithSpark(String prj, String originSql, String joinType, String sqlPath,
+            boolean withCache) {
         int index = sqlPath.lastIndexOf('/');
         String resultFilePath = "";
         String schemaFilePath = "";
-        if (index > 0) {
-            resultFilePath = sqlPath.substring(0, index) + "/result-" + joinType + sqlPath.substring(index) + ".json";
-            schemaFilePath = sqlPath.substring(0, index) + "/result-" + joinType + sqlPath.substring(index) + ".schema";
-        }
+        if (withCache) {
+            if (index > 0) {
+                resultFilePath = sqlPath.substring(0, index) + "/result-" + joinType + sqlPath.substring(index)
+                        + ".json";
+                schemaFilePath = sqlPath.substring(0, index) + "/result-" + joinType + sqlPath.substring(index)
+                        + ".schema";
+            }
 
-        // query with cache
-        try {
-            if (index > 0 && Files.exists(Paths.get(resultFilePath)) && Files.exists(Paths.get(schemaFilePath))) {
-                StructType schema = StructType.fromDDL(new String(Files.readAllBytes(Paths.get(schemaFilePath))));
-                List<StructField> structs = Arrays.stream(schema.fields())
-                        .map(SparderTypeUtil::convertSparkFieldToJavaField).collect(Collectors.toList());
-                Dataset<Row> ds = SparderEnv.getSparkSession().read().schema(schema).json(resultFilePath);
-                val dsIter = ds.toIterator();
-                Iterable<List<String>> listIter = SparkSqlClient.readPushDownResultRow(dsIter._1(), false);
-                return new QueryResult(Lists.newArrayList(listIter), (int) dsIter._2(), structs);
+            // query with cache
+            try {
+                if (index > 0 && Files.exists(Paths.get(resultFilePath)) && Files.exists(Paths.get(schemaFilePath))) {
+                    StructType schema = StructType.fromDDL(new String(Files.readAllBytes(Paths.get(schemaFilePath))));
+                    List<StructField> structs = Arrays.stream(schema.fields())
+                            .map(SparderTypeUtil::convertSparkFieldToJavaField).collect(Collectors.toList());
+                    Dataset<Row> ds = SparderEnv.getSparkSession().read().schema(schema).json(resultFilePath);
+                    val dsIter = ds.toIterator();
+                    Iterable<List<String>> listIter = SparkSqlClient.readPushDownResultRow(dsIter._1(), false);
+                    return new QueryResult(Lists.newArrayList(listIter), (int) dsIter._2(), structs);
+                }
+            } catch (Exception e) {
+                log.warn("try to use cache failed, compare with spark {}", sqlPath, e);
             }
-        } catch (Exception e) {
-            log.warn("try to use cache failed, compare with spark {}", sqlPath, e);
         }
         // query with spark and cache result
         return queryWithSpark(prj, originSql, joinType, sqlPath, resultFilePath, schemaFilePath);
diff --git a/src/query/src/main/java/org/apache/kylin/query/engine/AsyncQueryApplication.java b/src/query/src/main/java/org/apache/kylin/query/engine/AsyncQueryApplication.java
index ab6d943ffc..b1cf0a6571 100644
--- a/src/query/src/main/java/org/apache/kylin/query/engine/AsyncQueryApplication.java
+++ b/src/query/src/main/java/org/apache/kylin/query/engine/AsyncQueryApplication.java
@@ -41,6 +41,8 @@ import org.apache.kylin.metadata.query.RDBMSQueryHistoryDAO;
 import org.apache.kylin.metadata.query.util.QueryHistoryUtil;
 import org.apache.kylin.query.util.AsyncQueryUtil;
 import org.apache.kylin.query.util.QueryParams;
+import org.apache.spark.sql.KapFunctions;
+import org.apache.spark.sql.udf.UdfManager;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
@@ -67,6 +69,9 @@ public class AsyncQueryApplication extends SparkApplication {
         QueryContext queryContext = null;
         QueryParams queryParams = null;
         try {
+            if (getConfig().getPercentileApproxAlgorithm().equalsIgnoreCase("t-digest")) {
+                UdfManager.register(getSparkSession(), KapFunctions.percentileFunction());
+            }
             queryContext = JsonUtil.readValue(getParam(P_QUERY_CONTEXT), QueryContext.class);
             QueryContext.set(queryContext);
             QueryMetricsContext.start(queryContext.getQueryId(), "");
diff --git a/src/query/src/test/java/org/apache/kylin/query/engine/AsyncQueryApplicationTest.java b/src/query/src/test/java/org/apache/kylin/query/engine/AsyncQueryApplicationTest.java
index 79b4570c99..24c435817a 100644
--- a/src/query/src/test/java/org/apache/kylin/query/engine/AsyncQueryApplicationTest.java
+++ b/src/query/src/test/java/org/apache/kylin/query/engine/AsyncQueryApplicationTest.java
@@ -29,6 +29,8 @@ import static org.mockito.Mockito.mockStatic;
 import static org.mockito.Mockito.spy;
 import static org.mockito.Mockito.when;
 
+import java.util.Properties;
+
 import org.apache.kylin.common.KylinConfig;
 import org.apache.kylin.common.QueryContext;
 import org.apache.kylin.metadata.query.QueryMetricsContext;
@@ -50,6 +52,7 @@ public class AsyncQueryApplicationTest {
     @Before
     public void setUp() throws Exception {
         asyncQueryApplication = spy(new AsyncQueryApplication());
+        doReturn(KylinConfig.createKylinConfig(new Properties())).when(asyncQueryApplication).getConfig();
     }
 
     @Test
diff --git a/src/spark-project/sparder/src/main/scala/org/apache/kylin/query/runtime/plan/AggregatePlan.scala b/src/spark-project/sparder/src/main/scala/org/apache/kylin/query/runtime/plan/AggregatePlan.scala
index 714b42d15d..a44b338cbb 100644
--- a/src/spark-project/sparder/src/main/scala/org/apache/kylin/query/runtime/plan/AggregatePlan.scala
+++ b/src/spark-project/sparder/src/main/scala/org/apache/kylin/query/runtime/plan/AggregatePlan.scala
@@ -22,6 +22,7 @@ import org.apache.calcite.rex.RexLiteral
 import org.apache.calcite.sql.SqlKind
 import org.apache.kylin.common.KylinConfig
 import org.apache.kylin.engine.spark.utils.LogEx
+import org.apache.kylin.measure.percentile.PercentileCounter
 import org.apache.kylin.metadata.model.FunctionDesc
 import org.apache.kylin.query.relnode.{KapAggregateRel, KapProjectRel, KylinAggregateCall, OLAPAggregateRel}
 import org.apache.kylin.query.util.RuntimeHelper
@@ -213,9 +214,13 @@ object AggregatePlan extends LogEx {
                 val accuracyArg = if (call.getArgList.size() < 3) { None } else { Some(projectRel.getChildExps.get(call.getArgList.get(2))) }
                 (percentageArg, accuracyArg) match {
                   case (percentageLitRex: RexLiteral, accuracyArgLitRex: Option[RexLiteral]) =>
-                    val percentage = percentageLitRex.getValue
-                    val accuracy = accuracyArgLitRex.map(arg => arg.getValue).getOrElse(ApproximatePercentile.DEFAULT_PERCENTILE_ACCURACY)
-                    percentile_approx(col(argNames.head), lit(percentage), lit(accuracy)).alias(aggName)
+                    if (KylinConfig.getInstanceFromEnv.getPercentileApproxAlgorithm.equalsIgnoreCase("t-digest")) {
+                      KapFunctions.k_percentile(columnName.head, columnName(1), PercentileCounter.DEFAULT_PERCENTILE_ACCURACY).alias(aggName)
+                    } else {
+                      val percentage = percentageLitRex.getValue
+                      val accuracy = accuracyArgLitRex.map(arg => arg.getValue).getOrElse(ApproximatePercentile.DEFAULT_PERCENTILE_ACCURACY)
+                      percentile_approx(col(argNames.head), lit(percentage), lit(accuracy)).alias(aggName)
+                    }
                 }
               case _ =>
                 throw new UnsupportedOperationException(s"Invalid percentile_approx parameters, " +
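
A rough DataFrame-level sketch of the two branches above, assuming kylin.query.percentile-approx-algorithm=t-digest for the first one; measureCol and percentageCol stand in for the columns resolved from the aggregate call and are not names from this patch:

    import org.apache.spark.sql.KapFunctions
    import org.apache.spark.sql.functions.{col, lit, percentile_approx}

    val measureCol = col("price")   // hypothetical measure column
    val percentageCol = lit(0.5)    // hypothetical percentage literal
    // t-digest path: Kylin's Percentile UDAF, compression fixed at 100
    val tDigestAgg = KapFunctions.k_percentile(measureCol, percentageCol, 100).alias("p50")
    // default path: Spark's built-in ApproximatePercentile
    val sparkAgg = percentile_approx(measureCol, lit(0.5), lit(10000)).alias("p50")
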
diff --git a/src/spark-project/sparder/src/main/scala/org/apache/spark/sql/KapFunctions.scala b/src/spark-project/sparder/src/main/scala/org/apache/spark/sql/KapFunctions.scala
index 352dfcb977..77cce256e1 100644
--- a/src/spark-project/sparder/src/main/scala/org/apache/spark/sql/KapFunctions.scala
+++ b/src/spark-project/sparder/src/main/scala/org/apache/spark/sql/KapFunctions.scala
@@ -23,9 +23,9 @@ import org.apache.spark.sql.catalyst.expressions.ExpressionUtils.expression
 import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateFunction
 import org.apache.spark.sql.catalyst.expressions.codegen.Block.BlockHelper
 import org.apache.spark.sql.catalyst.expressions.codegen.{CodeGenerator, CodegenContext, ExprCode}
-import org.apache.spark.sql.catalyst.expressions.{ApproxCountDistinctDecode, CeilDateTime, DictEncode, DictEncodeV3, EmptyRow, Expression, ExpressionInfo, FloorDateTime, ImplicitCastInputTypes, In, KapAddMonths, KapSubtractMonths, Like, Literal, PercentileDecode, PreciseCountDistinctDecode, RLike, RoundBase, SplitPart, Sum0, SumLCDecode, TimestampAdd, TimestampDiff, Truncate}
-import org.apache.spark.sql.types._
-import org.apache.spark.sql.udaf._
+import org.apache.spark.sql.catalyst.expressions.{ApproxCountDistinctDecode, CeilDateTime, DictEncode, DictEncodeV3, EmptyRow, Expression, ExpressionInfo, ExpressionUtils, FloorDateTime, ImplicitCastInputTypes, In, KapAddMonths, KapSubtractMonths, Like, Literal, PercentileDecode, PreciseCountDistinctDecode, RLike, RoundBase, SplitPart, Sum0, SumLCDecode, TimestampAdd, TimestampDiff, Truncate}
+import org.apache.spark.sql.types.{ArrayType, BinaryType, ByteType, DataType, DecimalType, DoubleType, FloatType, IntegerType, LongType, ShortType, StringType}
+import org.apache.spark.sql.udaf.{ApproxCountDistinct, IntersectCount, Percentile, PreciseBitmapBuildBase64Decode, PreciseBitmapBuildBase64WithIndex, PreciseBitmapBuildPushDown, PreciseCardinality, PreciseCountDistinct, PreciseCountDistinctAndArray, PreciseCountDistinctAndValue, ReusePreciseCountDistinct, ReuseSumLC}
 
 object KapFunctions {
 
@@ -272,6 +272,8 @@ object KapFunctions {
     FunctionEntity(expression[PercentileDecode]("percentile_decode")),
     FunctionEntity(expression[PreciseBitmapBuildPushDown]("bitmap_build"))
   )
+
+  val percentileFunction: FunctionEntity = FunctionEntity(ExpressionUtils.expression[Percentile]("percentile_approx"))
 }
 
 case class FunctionEntity(name: FunctionIdentifier,
diff --git a/src/spark-project/sparder/src/main/scala/org/apache/spark/sql/SparderEnv.scala b/src/spark-project/sparder/src/main/scala/org/apache/spark/sql/SparderEnv.scala
index 8961a52708..4b1382e6f3 100644
--- a/src/spark-project/sparder/src/main/scala/org/apache/spark/sql/SparderEnv.scala
+++ b/src/spark-project/sparder/src/main/scala/org/apache/spark/sql/SparderEnv.scala
@@ -239,6 +239,9 @@ object SparderEnv extends Logging {
       }
 
       injectExtensions(sparkSession.extensions)
+      if (KylinConfig.getInstanceFromEnv.getPercentileApproxAlgorithm.equalsIgnoreCase("t-digest")) {
+        UdfManager.register(sparkSession, KapFunctions.percentileFunction)
+      }
       spark = sparkSession
       logInfo("Spark context started successfully with stack trace:")
       logInfo(Thread.currentThread().getStackTrace.mkString("\n"))
diff --git a/src/spark-project/sparder/src/main/scala/org/apache/spark/sql/udf/UdfManager.scala b/src/spark-project/sparder/src/main/scala/org/apache/spark/sql/udf/UdfManager.scala
index c38ff23a6e..759575d7cb 100644
--- a/src/spark-project/sparder/src/main/scala/org/apache/spark/sql/udf/UdfManager.scala
+++ b/src/spark-project/sparder/src/main/scala/org/apache/spark/sql/udf/UdfManager.scala
@@ -18,17 +18,13 @@
 
 package org.apache.spark.sql.udf
 
-import java.nio.ByteBuffer
 import java.util.concurrent.TimeUnit
 import java.util.concurrent.atomic.AtomicReference
 
-import com.esotericsoftware.kryo.io.{Input, KryoDataInput}
 import org.apache.kylin.guava30.shaded.common.cache.{Cache, CacheBuilder, RemovalListener, RemovalNotification}
-import org.apache.kylin.measure.hllc.HLLCounter
 import org.apache.kylin.metadata.datatype.DataType
 import org.apache.spark.internal.Logging
 import org.apache.spark.sql.{FunctionEntity, KapFunctions, SparkSession}
-import org.roaringbitmap.longlong.Roaring64NavigableMap
 
 class UdfManager(sparkSession: SparkSession) extends Logging {
   private var udfCache: Cache[String, String] = _
@@ -87,4 +83,9 @@ object UdfManager {
     defaultManager.get().doRegister(dataType, func)
   }
 
+  def register(sparkSession: SparkSession, func: FunctionEntity): Unit = {
+    sparkSession.sessionState.functionRegistry.registerFunction(func.name,
+      func.info, func.builder)
+  }
+
 }
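
The new overload is what SparderEnv and AsyncQueryApplication call when the t-digest algorithm is configured. A minimal usage sketch, assuming an already-started SparkSession named spark and a hypothetical table prices:

    import org.apache.spark.sql.KapFunctions
    import org.apache.spark.sql.udf.UdfManager

    // re-binds the SQL name "percentile_approx" to Kylin's t-digest backed Percentile UDAF
    UdfManager.register(spark, KapFunctions.percentileFunction)
    spark.sql("SELECT percentile_approx(price, 0.5) FROM prices").show()
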
diff --git a/src/spark-project/spark-common/src/main/scala/org/apache/spark/sql/udaf/Percentile.scala b/src/spark-project/spark-common/src/main/scala/org/apache/spark/sql/udaf/Percentile.scala
index 6bfd178e23..a9940e3163 100644
--- a/src/spark-project/spark-common/src/main/scala/org/apache/spark/sql/udaf/Percentile.scala
+++ b/src/spark-project/spark-common/src/main/scala/org/apache/spark/sql/udaf/Percentile.scala
@@ -36,6 +36,11 @@ case class Percentile(aggColumn: Expression,
                       inputAggBufferOffset: Int = 0)
   extends TypedImperativeAggregate[PercentileCounter] with Serializable with Logging {
 
+  // used by spark pushDown
+  def this(aggColumn: Expression, quantile: Expression) {
+    this(aggColumn, PercentileCounter.DEFAULT_PERCENTILE_ACCURACY, Some(quantile), DoubleType)
+
   override def children: Seq[Expression] = quantile match {
     case None => aggColumn :: Nil
     case Some(q) => aggColumn :: q :: Nil
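
The auxiliary constructor matters because percentileFunction is built with ExpressionUtils.expression[Percentile], which, like Spark's FunctionRegistry helpers, presumably picks a constructor by the SQL call's argument count, so a two-argument percentile_approx(c, 0.5) needs an (Expression, Expression) constructor. A small illustrative construction (the column name is made up):

    import org.apache.spark.sql.functions.{col, lit}
    import org.apache.spark.sql.udaf.Percentile

    // compression falls back to PercentileCounter.DEFAULT_PERCENTILE_ACCURACY (100)
    // and the result type to DoubleType, per the constructor above
    val agg = new Percentile(col("price").expr, lit(0.5).expr)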
