This is an automated email from the ASF dual-hosted git repository.

xxyu pushed a commit to branch kylin-on-parquet-v2
in repository https://gitbox.apache.org/repos/asf/kylin.git


The following commit(s) were added to refs/heads/kylin-on-parquet-v2 by this push:
     new f19f328  KYLIN-4760 Optimize TopN measure
f19f328 is described below

commit f19f3286a5a0122c313faa2bf3c5248aae4c726f
Author: rupengwang <wangrup...@live.cn>
AuthorDate: Wed Sep 16 13:51:10 2020 +0800

    KYLIN-4760 Optimize TopN measure
---
 .../spark/sql/udaf/NullSafeValueSerializer.scala   | 169 ++++++++++++++++
 .../scala/org/apache/spark/sql/udaf/TopN.scala     | 215 +++++++++++++++++++++
 .../kylin/engine/spark/job/CuboidAggregator.scala  |  21 +-
 3 files changed, 402 insertions(+), 3 deletions(-)

diff --git 
a/kylin-spark-project/kylin-spark-common/src/main/scala/org/apache/spark/sql/udaf/NullSafeValueSerializer.scala
 
b/kylin-spark-project/kylin-spark-common/src/main/scala/org/apache/spark/sql/udaf/NullSafeValueSerializer.scala
new file mode 100644
index 0000000..5be5626
--- /dev/null
+++ 
b/kylin-spark-project/kylin-spark-common/src/main/scala/org/apache/spark/sql/udaf/NullSafeValueSerializer.scala
@@ -0,0 +1,169 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+*/
+
+package org.apache.spark.sql.udaf
+
+import java.io.{DataInput, DataOutput}
+import java.nio.charset.StandardCharsets
+
+import org.apache.spark.unsafe.types.UTF8String
+
+/**
+ * Length-prefixed codec for a single TopN field value that also supports `null`.
+ *
+ * Every value is written as a 4-byte Int header followed by the payload produced by
+ * `serialize0`. A `null` value is written as the header
+ * [[NullSafeValueSerializer.NullLength]] with no payload. The marker is negative rather
+ * than 0 so that legitimately empty payloads (e.g. the empty string, whose header is 0)
+ * are not confused with `null` when read back.
+ */
+@SerialVersionUID(1)
+sealed trait NullSafeValueSerializer extends Serializable {
+  /**
+   * Writes `value` (possibly `null`) to `output`.
+   *
+   * @param output destination stream
+   * @param value  value to encode; `null` is written as the null marker only
+   */
+  final def serialize(output: DataOutput, value: Any): Unit = {
+    if (value == null) {
+      output.writeInt(NullSafeValueSerializer.NullLength)
+    } else {
+      serialize0(output, value)
+    }
+  }
+
+  /** Writes the length header and the payload for a non-null `value`. */
+  @inline protected def serialize0(output: DataOutput, value: Any): Unit
+
+  /**
+   * Reads one value previously written by [[serialize]].
+   *
+   * @param input source stream positioned at a length header
+   * @return `null` when the header is negative, otherwise the value decoded by `deSerialize0`
+   */
+  def deserialize(input: DataInput): Any = {
+    val length = input.readInt()
+    if (length < 0) {
+      null
+    } else {
+      deSerialize0(input, length)
+    }
+  }
+
+  /** Decodes a non-null payload whose header announced `length` bytes. */
+  @inline protected def deSerialize0(input: DataInput, length: Int): Any
+}
+
+object NullSafeValueSerializer {
+  /** Length header that marks a `null` value; real payload lengths are always >= 0. */
+  final val NullLength: Int = -1
+}
+
+/** Codec for Boolean values: header `1` followed by a single boolean byte. */
+@SerialVersionUID(1)
+class BooleanSerializer extends NullSafeValueSerializer {
+  override protected def serialize0(output: DataOutput, value: Any): Unit = {
+    val payloadSize = 1
+    output.writeInt(payloadSize)
+    output.writeBoolean(value.asInstanceOf[Boolean])
+  }
+
+  // The header has already been consumed by the caller; only the flag remains.
+  override protected def deSerialize0(input: DataInput, length: Int): Any =
+    input.readBoolean()
+}
+
+/** Codec for Byte values: a length header of one byte, then the byte itself. */
+@SerialVersionUID(1)
+class ByteSerializer extends NullSafeValueSerializer {
+  override def serialize0(output: DataOutput, value: Any): Unit = {
+    output.writeInt(java.lang.Byte.BYTES)
+    output.writeByte(value.asInstanceOf[Byte])
+  }
+
+  // `length` is implied by the type, so it is not consulted here.
+  override def deSerialize0(input: DataInput, length: Int): Any =
+    input.readByte()
+}
+
+/** Codec for Short values: header `2`, then the short. */
+@SerialVersionUID(1)
+class ShortSerializer extends NullSafeValueSerializer {
+  override def serialize0(output: DataOutput, value: Any): Unit = {
+    val header = java.lang.Short.BYTES
+    output.writeInt(header)
+    output.writeShort(value.asInstanceOf[Short])
+  }
+
+  // Fixed-width payload: the announced length is always 2 and is ignored.
+  override def deSerialize0(input: DataInput, length: Int): Any =
+    input.readShort()
+}
+
+/** Codec for Int values: header `4`, then the int. */
+@SerialVersionUID(1)
+class IntegerSerializer extends NullSafeValueSerializer {
+  override def serialize0(output: DataOutput, value: Any): Unit = {
+    output.writeInt(java.lang.Integer.BYTES)
+    output.writeInt(value.asInstanceOf[Int])
+  }
+
+  // Fixed-width payload; `length` carries no extra information.
+  override def deSerialize0(input: DataInput, length: Int): Any =
+    input.readInt()
+}
+
+/** Codec for Float values: header `4`, then the float. */
+@SerialVersionUID(1)
+class FloatSerializer extends NullSafeValueSerializer {
+  override def serialize0(output: DataOutput, value: Any): Unit = {
+    val header = java.lang.Float.BYTES
+    output.writeInt(header)
+    output.writeFloat(value.asInstanceOf[Float])
+  }
+
+  // Fixed-width payload; the caller already read the header.
+  override def deSerialize0(input: DataInput, length: Int): Any =
+    input.readFloat()
+}
+
+/** Codec for Double values: header `8`, then the double. */
+@SerialVersionUID(1)
+class DoubleSerializer extends NullSafeValueSerializer {
+  override def serialize0(output: DataOutput, value: Any): Unit = {
+    output.writeInt(java.lang.Double.BYTES)
+    output.writeDouble(value.asInstanceOf[Double])
+  }
+
+  // Fixed-width payload; `length` is always 8 and not needed.
+  override def deSerialize0(input: DataInput, length: Int): Any =
+    input.readDouble()
+}
+
+/** Codec for Long values: header `8`, then the long. */
+@SerialVersionUID(1)
+class LongSerializer extends NullSafeValueSerializer {
+  override def serialize0(output: DataOutput, value: Any): Unit = {
+    val header = java.lang.Long.BYTES
+    output.writeInt(header)
+    output.writeLong(value.asInstanceOf[Long])
+  }
+
+  // Fixed-width payload; the header is informational only.
+  override def deSerialize0(input: DataInput, length: Int): Any =
+    input.readLong()
+}
+
+/**
+ * Codec for text values: header = number of UTF-8 bytes, followed by those bytes.
+ * Values are written from their `toString` form and read back as Spark `UTF8String`.
+ */
+@SerialVersionUID(1)
+class StringSerializer extends NullSafeValueSerializer {
+  override def serialize0(output: DataOutput, value: Any): Unit = {
+    val utf8 = value.toString.getBytes(StandardCharsets.UTF_8)
+    output.writeInt(utf8.length)
+    output.write(utf8)
+  }
+
+  override def deSerialize0(input: DataInput, length: Int): Any = {
+    val buffer = Array.ofDim[Byte](length)
+    input.readFully(buffer)
+    UTF8String.fromBytes(buffer)
+  }
+}
+
+/**
+ * Codec for decimal values.
+ *
+ * Payload layout (after the length header): a 4-byte Int scale, then the UTF-8 bytes of the
+ * value's `java.math.BigDecimal#toString` form. The header holds the exact payload size.
+ * Values are returned as `scala.math.BigDecimal`.
+ */
+@SerialVersionUID(1)
+class DecimalSerializer extends NullSafeValueSerializer {
+  override def serialize0(output: DataOutput, value: Any): Unit = {
+    val decimal = toBigDecimal(value)
+    val bytes = decimal.bigDecimal.toString.getBytes(StandardCharsets.UTF_8)
+    // Header counts every payload byte: the Int scale plus the digit bytes.
+    output.writeInt(java.lang.Integer.BYTES + bytes.length)
+    // Full Int scale; a single byte would silently truncate scales above 127.
+    output.writeInt(decimal.scale)
+    output.write(bytes)
+  }
+
+  override def deSerialize0(input: DataInput, length: Int): Any = {
+    val scale = input.readInt()
+    val bytes = new Array[Byte](length - java.lang.Integer.BYTES)
+    input.readFully(bytes)
+    // Parse via java.math.BigDecimal so the digits are taken exactly, with no MathContext rounding.
+    val parsed = new java.math.BigDecimal(new String(bytes, StandardCharsets.UTF_8))
+    BigDecimal(parsed).setScale(scale)
+  }
+
+  /** Accepts Scala or Java decimals; any other value is parsed from its string form. */
+  private def toBigDecimal(value: Any): BigDecimal = value match {
+    case d: BigDecimal => d
+    case d: java.math.BigDecimal => BigDecimal(d)
+    case other => BigDecimal(new java.math.BigDecimal(other.toString))
+  }
+}
+
diff --git 
a/kylin-spark-project/kylin-spark-common/src/main/scala/org/apache/spark/sql/udaf/TopN.scala
 
b/kylin-spark-project/kylin-spark-common/src/main/scala/org/apache/spark/sql/udaf/TopN.scala
new file mode 100644
index 0000000..ad1713f
--- /dev/null
+++ 
b/kylin-spark-project/kylin-spark-common/src/main/scala/org/apache/spark/sql/udaf/TopN.scala
@@ -0,0 +1,215 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+*/
+
+package org.apache.spark.sql.udaf
+
+import com.esotericsoftware.kryo.KryoException
+import com.esotericsoftware.kryo.io.{Input, KryoDataInput, KryoDataOutput, 
Output}
+import org.apache.kylin.measure.topn.TopNCounter
+import org.apache.spark.internal.Logging
+import org.apache.spark.sql.catalyst.InternalRow
+import 
org.apache.spark.sql.catalyst.expressions.aggregate.{ImperativeAggregate, 
TypedImperativeAggregate}
+import org.apache.spark.sql.catalyst.expressions.{Expression, UnsafeRow}
+import org.apache.spark.sql.catalyst.util.ArrayData
+import org.apache.spark.sql.types._
+
+import scala.collection.JavaConverters._
+
+/**
+ * Common implementation of the TopN aggregate, backed by a Kylin [[TopNCounter]].
+ *
+ * `internalSchema` describes one counter entry: field 0 is the measure and the remaining
+ * fields are the dimensions that form the item key. Inside the aggregation buffer, string
+ * dimensions are kept as `java.lang.String` (including after [[deserialize]]), so items built
+ * by `update` and items read back from a serialized buffer compare equal. They are converted
+ * to Spark's `UTF8String` only when the result row is built in [[eval]].
+ *
+ * @param precision TopN precision; the counter capacity is `precision * TopNCounter.EXTRA_SPACE_RATE`
+ */
+@SerialVersionUID(1)
+sealed abstract class BaseTopN(precision: Int,
+                               internalSchema: StructType,
+                               mutableAggBufferOffset: Int = 0,
+                               inputAggBufferOffset: Int = 0
+                              ) extends TypedImperativeAggregate[TopNCounter[Seq[Any]]] with Serializable with Logging {
+  /** One codec per schema field: the Double measure first, then one per dimension type. */
+  lazy val serializers: Seq[NullSafeValueSerializer] =
+    Seq(new DoubleSerializer) ++ internalSchema.drop(1).map(_.dataType).map {
+      case BooleanType => new BooleanSerializer
+      case ByteType => new ByteSerializer
+      case ShortType => new ShortSerializer
+      case IntegerType => new IntegerSerializer
+      case FloatType => new FloatSerializer
+      case LongType => new LongSerializer
+      case DoubleType => new DoubleSerializer
+      // Timestamps and dates are serialized through their internal Long / Int values.
+      case TimestampType => new LongSerializer
+      case DateType => new IntegerSerializer
+      case StringType => new StringSerializer
+      case dt => throw new UnsupportedOperationException("Unsupported TopN dimension type: " + dt)
+    }
+
+  override def createAggregationBuffer(): TopNCounter[Seq[Any]] = new TopNCounter[Seq[Any]](precision * TopNCounter.EXTRA_SPACE_RATE)
+
+  override def merge(buffer: TopNCounter[Seq[Any]], input: TopNCounter[Seq[Any]]): TopNCounter[Seq[Any]] = {
+    input.getCounterList.asScala.foreach { c =>
+      buffer.offer(c.getItem, c.getCount)
+    }
+    buffer
+  }
+
+  /** Produces an array of `(measure, dim struct)` rows, after `sortAndRetain` on the buffer. */
+  override def eval(buffer: TopNCounter[Seq[Any]]): Any = {
+    buffer.sortAndRetain()
+    val seq = buffer.getCounterList.asScala.map { entry =>
+      // Spark rows must carry StringType values as UTF8String, never as java.lang.String.
+      val dims = entry.getItem.map {
+        case s: String => org.apache.spark.unsafe.types.UTF8String.fromString(s)
+        case other => other
+      }
+      InternalRow(entry.getCount, InternalRow(dims: _*))
+    }
+    ArrayData.toArrayData(seq)
+  }
+
+  // Reusable scratch buffer for serialize(); allocated lazily and doubled on overflow.
+  var array: Array[Byte] = _
+  var output: Output = _
+
+  /**
+   * Serializes the counter (after `sortAndRetain`) as an Int entry count followed, per entry,
+   * by the count and each dimension value encoded with [[serializers]].
+   *
+   * @return the encoded bytes, or an empty array for a `null` counter
+   */
+  override def serialize(topNCounter: TopNCounter[Seq[Any]]): Array[Byte] = {
+    try {
+      if (topNCounter != null) {
+        if (array == null) {
+          array = new Array[Byte](1024 * 1024)
+          output = new Output(array)
+        }
+        output.clear()
+        val out = new KryoDataOutput(output)
+        topNCounter.sortAndRetain()
+        val counters = topNCounter.getCounterList
+        out.writeInt(counters.size())
+        counters.asScala.foreach { counter =>
+          val values = Seq(counter.getCount) ++ counter.getItem
+          values.zip(serializers).foreach { case (value, ser) =>
+            ser.serialize(out, value)
+          }
+        }
+        val i = output.position()
+        output.close()
+        array.slice(0, i)
+      } else {
+        Array.empty[Byte]
+      }
+    } catch {
+      // The fixed-size Output reports "Buffer overflow" when full: grow the buffer and retry.
+      // Guard against a null message; any other exception propagates unchanged.
+      case th: KryoException if Option(th.getMessage).exists(_.contains("Buffer overflow")) =>
+        logWarning(s"Resize buffer size to ${array.length * 2}")
+        array = new Array[Byte](array.length * 2)
+        output.setBuffer(array)
+        serialize(topNCounter)
+    }
+  }
+
+  /** Rebuilds a counter from bytes written by [[serialize]]; empty bytes give an empty counter. */
+  override def deserialize(bytes: Array[Byte]): TopNCounter[Seq[Any]] = {
+    val topNCounter = new TopNCounter[Seq[Any]](precision * TopNCounter.EXTRA_SPACE_RATE)
+    if (bytes.nonEmpty) {
+      val in = new KryoDataInput(new Input(bytes))
+      val size = in.readInt()
+      for (_ <- 0 until size) {
+        val values = serializers.map(_.deserialize(in))
+        // StringSerializer yields UTF8String; convert back to java.lang.String so these items
+        // match (and merge with) items offered by update().
+        val item = values.drop(1).map {
+          case s: org.apache.spark.unsafe.types.UTF8String => s.toString
+          case other => other
+        }
+        if (values.head == null) {
+          topNCounter.offer(item, null)
+        } else {
+          topNCounter.offer(item, values.head.asInstanceOf[Double])
+        }
+      }
+    }
+    topNCounter
+  }
+
+  override def nullable: Boolean = false
+
+  override def dataType: DataType = ArrayType(
+    StructType(Seq(
+      StructField("measure", DoubleType),
+      StructField("dim", StructType(internalSchema.fields.drop(1)))
+    )))
+}
+
+/**
+ * TopN aggregate that re-aggregates an existing TopN column (the reuse-layout path).
+ *
+ * `child` evaluates to an array whose elements have the shape of [[innerType]]:
+ * `struct<measure: double, dim: struct<...dimensions...>>`. Each element is offered to the
+ * buffer again.
+ */
+case class ReuseTopN(
+                      precision: Int,
+                      internalSchema: StructType,
+                      child: Expression,
+                      mutableAggBufferOffset: Int = 0,
+                      inputAggBufferOffset: Int = 0)
+  extends BaseTopN(precision, internalSchema, mutableAggBufferOffset, inputAggBufferOffset) {
+  val dimType: StructType = StructType(internalSchema.fields.drop(1))
+  val innerType: StructType = StructType(Seq(
+    StructField("measure", DoubleType),
+    StructField("dim", dimType)
+  ))
+
+  override def update(buffer: TopNCounter[Seq[Any]], input: InternalRow): TopNCounter[Seq[Any]] = {
+    val entries = child.eval(input).asInstanceOf[ArrayData]
+    // A null TopN value has no entries to contribute.
+    if (entries != null) {
+      // Read as InternalRow (not UnsafeRow) so both unsafe and generic array data are accepted.
+      entries.toArray[InternalRow](innerType).foreach { data =>
+        val dims = data.get(1, dimType).asInstanceOf[InternalRow]
+        val item = dimType.fields.map(_.dataType).zipWithIndex.map {
+          // Strings are copied out as java.lang.String, the same form EncodeTopN offers.
+          case (StringType, index) =>
+            if (dims.isNullAt(index)) null else dims.getString(index)
+          case (dataType, index) =>
+            dims.get(index, dataType)
+        }.toSeq
+        // Keep a null measure as null (EncodeTopN offers null too) instead of reading it as 0.0.
+        if (data.isNullAt(0)) {
+          buffer.offer(item, null)
+        } else {
+          buffer.offer(item, data.getDouble(0))
+        }
+      }
+    }
+    buffer
+  }
+
+  override def withNewMutableAggBufferOffset(newMutableAggBufferOffset: Int): ImperativeAggregate =
+    copy(mutableAggBufferOffset = newMutableAggBufferOffset)
+
+  override def withNewInputAggBufferOffset(newInputAggBufferOffset: Int): ImperativeAggregate =
+    copy(inputAggBufferOffset = newInputAggBufferOffset)
+
+  override def children: Seq[Expression] = child :: Nil
+}
+
+/**
+ * TopN aggregate computed from raw columns.
+ *
+ * `measure` supplies the value counted for each item; the values of `dimensions`
+ * (in order) form the item key.
+ */
+case class EncodeTopN(
+                       precision: Int,
+                       internalSchema: StructType,
+                       measure: Expression,
+                       dimensions: Seq[Expression],
+                       mutableAggBufferOffset: Int = 0,
+                       inputAggBufferOffset: Int = 0)
+  extends BaseTopN(precision, internalSchema, mutableAggBufferOffset, inputAggBufferOffset) {
+  override def update(counter: TopNCounter[Seq[Any]], input: InternalRow): TopNCounter[Seq[Any]] = {
+    val measureValue = measure.eval(input)
+    val item = dimensions.map(dimension => evalDimension(dimension, input))
+    // The measure is converted to Double through its string form; null stays null.
+    if (measureValue != null) {
+      counter.offer(item, measureValue.toString.toDouble)
+    } else {
+      counter.offer(item, null)
+    }
+    counter
+  }
+
+  /** Evaluates one dimension; non-null string values are materialised as java.lang.String. */
+  private def evalDimension(dimension: Expression, input: InternalRow): Any = {
+    val value = dimension.eval(input)
+    dimension.dataType match {
+      case _: StringType if value != null => value.toString
+      case _ => value
+    }
+  }
+
+  override def withNewMutableAggBufferOffset(newMutableAggBufferOffset: Int): ImperativeAggregate =
+    copy(mutableAggBufferOffset = newMutableAggBufferOffset)
+
+  override def withNewInputAggBufferOffset(newInputAggBufferOffset: Int): ImperativeAggregate =
+    copy(inputAggBufferOffset = newInputAggBufferOffset)
+
+  override def children: Seq[Expression] = Seq(measure) ++ dimensions
+}
diff --git 
a/kylin-spark-project/kylin-spark-engine/src/main/scala/org/apache/kylin/engine/spark/job/CuboidAggregator.scala
 
b/kylin-spark-project/kylin-spark-engine/src/main/scala/org/apache/kylin/engine/spark/job/CuboidAggregator.scala
index 1da2af5..54afadf 100644
--- 
a/kylin-spark-project/kylin-spark-engine/src/main/scala/org/apache/kylin/engine/spark/job/CuboidAggregator.scala
+++ 
b/kylin-spark-project/kylin-spark-engine/src/main/scala/org/apache/kylin/engine/spark/job/CuboidAggregator.scala
@@ -98,9 +98,24 @@ object CuboidAggregator {
             new Column(cdAggregate.toAggregateExpression()).as(id.toString)
           }
         case "TOP_N" =>
-          val schema: StructType = constructTopNSchema(measure.pra)
-          val udfName = 
UdfManager.register(measure.returnType.toKylinDataType, measure.expression, 
schema, !reuseLayout)
-          callUDF(udfName, columns: _*).as(id.toString)
+          // Uses new TopN aggregate function
+          // located in 
kylin-spark-project/kylin-spark-common/src/main/scala/org/apache/spark/sql/udaf/TopN.scala
+          val schema = StructType(measure.pra.map { col =>
+            val dateType = col.dataType
+            if (col == measure) {
+              StructField(s"MEASURE_${col.columnName}", dateType)
+            } else {
+              StructField(s"DIMENSION_${col.columnName}", dateType)
+            }
+          })
+
+          if (reuseLayout) {
+            new Column(ReuseTopN(measure.returnType.precision, schema, 
columns.head.expr)
+              .toAggregateExpression()).as(id.toString)
+          } else {
+            new Column(EncodeTopN(measure.returnType.precision, schema, 
columns.head.expr, columns.drop(1).map(_.expr))
+              .toAggregateExpression()).as(id.toString)
+          }
         case "PERCENTILE_APPROX" =>
           val udfName = 
UdfManager.register(measure.returnType.toKylinDataType, measure.expression, 
null, !reuseLayout)
           if (!reuseLayout) {

Reply via email to