This is an automated email from the ASF dual-hosted git repository.

liyang pushed a commit to branch kylin5
in repository https://gitbox.apache.org/repos/asf/kylin.git

commit 30c4f28a1616dc80e674c4c4f7017b6cc63114b4
Author: Yaguang Jia <jiayagu...@foxmail.com>
AuthorDate: Fri Aug 11 13:10:02 2023 +0800

    KYLIN-5786 Add a write lock when merging v3 dict
---
 .../spark/builder/v3dict/DictionaryBuilder.scala   | 63 +++++++++++++++-------
 .../builder/v3dict/GlobalDictionarySuite.scala     | 12 +++--
 2 files changed, 51 insertions(+), 24 deletions(-)
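
In short, the race being closed: each builder computes its increment against a pinned Delta snapshot version, then takes a fair write lock, re-checks that the table is still at that version, and either merges or throws so the caller's retry rebuilds the increment from the new snapshot. A minimal sketch of that pattern follows; currentVersion and doMerge are hypothetical stand-ins for illustration, not Kylin or Delta APIs:

import java.util.concurrent.locks.ReentrantReadWriteLock

object OptimisticMergeSketch {
  // A fair lock, like v3dictMergeLock in the patch: waiting builders
  // acquire the write lock in arrival order.
  private val mergeLock = new ReentrantReadWriteLock(true)

  // Hypothetical stand-ins (assumptions) for
  // DeltaLog.forTable(spark, path).snapshot.version and the Delta merge.
  def currentVersion(path: String): Long = ???
  def doMerge(path: String): Unit = ???

  def mergeAtVersion(path: String, basedOnVersion: Long): Unit = {
    mergeLock.writeLock().lock()
    try {
      // Re-check under the lock: another builder may have committed a new
      // Delta version after this increment was computed.
      val cur = currentVersion(path)
      if (cur != basedOnVersion)
        throw new IllegalStateException(
          s"dict at $path moved from version $basedOnVersion to $cur; retry")
      doMerge(path)
    } finally {
      mergeLock.writeLock().unlock()
    }
  }
}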

diff --git a/src/spark-project/engine-spark/src/main/scala/org/apache/kylin/engine/spark/builder/v3dict/DictionaryBuilder.scala b/src/spark-project/engine-spark/src/main/scala/org/apache/kylin/engine/spark/builder/v3dict/DictionaryBuilder.scala
index e47dc7107e..5fd9ae1673 100644
--- a/src/spark-project/engine-spark/src/main/scala/org/apache/kylin/engine/spark/builder/v3dict/DictionaryBuilder.scala
+++ b/src/spark-project/engine-spark/src/main/scala/org/apache/kylin/engine/spark/builder/v3dict/DictionaryBuilder.scala
@@ -40,6 +40,7 @@ import util.retry.blocking.RetryStrategy.RetryStrategyProducer
 import util.retry.blocking.{Failure, Retry, RetryStrategy, Success}
 
 import java.nio.file.Paths
+import java.util.concurrent.locks.ReentrantReadWriteLock
 import scala.collection.mutable.ListBuffer
 import scala.concurrent.duration.DurationInt
 import scala.util.control.NonFatal
@@ -48,6 +49,7 @@ object DictionaryBuilder extends Logging {
 
   implicit val retryStrategy: RetryStrategyProducer =
     RetryStrategy.randomBackOff(5.seconds, 15.seconds, maxAttempts = 20)
+  lazy val v3dictMergeLock = new ReentrantReadWriteLock(true)
 
   def buildGlobalDict(
                        project: String,
@@ -86,11 +88,12 @@ object DictionaryBuilder extends Logging {
   private def transformerDictPlan(
                                    spark: SparkSession,
                                    context: DictionaryContext,
-                                   plan: LogicalPlan): LogicalPlan = {
+                                   plan: LogicalPlan): DictPlanVersion = {
 
     val dictPath = getDictionaryPathAndCheck(context)
-    val dictTable: DeltaTable = DeltaTable.forPath(dictPath)
-    val maxOffset = dictTable.toDF.count()
+    val dictVersion = DeltaLog.forTable(spark, dictPath).snapshot.version
+    val dictTableDF = spark.read.format("delta").option("versionAsOf", dictVersion).load(dictPath)
+    val maxOffset = dictTableDF.count()
     logInfo(s"Dict $dictPath item count $maxOffset")
 
     plan match {
@@ -99,19 +102,19 @@ object DictionaryBuilder extends Logging {
         val windowSpec = org.apache.spark.sql.expressions.Window.orderBy(col(column))
         val joinCondition = createColumn(
           EqualTo(col(column).cast(StringType).expr,
-            getLogicalPlan(dictTable.toDF).output.head))
-        val filterKey = getLogicalPlan(dictTable.toDF).output.head.name
+            getLogicalPlan(dictTableDF).output.head))
+        val filterKey = getLogicalPlan(dictTableDF).output.head.name
         val antiJoinDF = getDataFrame(spark, windowChild)
           .filter(col(filterKey).isNotNull)
-          .join(dictTable.toDF,
+          .join(dictTableDF,
             joinCondition,
             "left_anti")
           .select(col(column).cast(StringType) as "dict_key",
+            (row_number().over(windowSpec) + lit(maxOffset)).cast(LongType) as "dict_value")
         logInfo(s"Dict logical plan : 
${antiJoinDF.queryExecution.logical.treeString}")
-        getLogicalPlan(antiJoinDF)
+        DictPlanVersion(getLogicalPlan(antiJoinDF), dictVersion)
 
-      case _ => plan
+      case _ => DictPlanVersion(plan, dictVersion)
     }
   }
 
@@ -166,32 +169,52 @@ object DictionaryBuilder extends Logging {
     val dictPath = getDictionaryPath(context)
     logInfo(s"Save dict values into path $dictPath.")
     try {
-      dictDF.write.mode(SaveMode.Overwrite).format("delta").save(dictPath)
+      dictDF.write.mode(SaveMode.ErrorIfExists).format("delta").save(dictPath)
     } catch {
       case e: DeltaConcurrentModificationException =>
         logWarning(s"Concurrent modifications occurred: $dictPath", e)
         throw e
       case NonFatal(e) =>
+        logWarning(s"A NoFatal exception occurs when initializing the 
dictionary, and will be retried", e)
         if (!DeltaTable.isDeltaTable(dictPath)) {
           logWarning(s"Try to delete v3dict: $dictPath", e)
           HadoopUtil.deletePath(HadoopUtil.getCurrentConfiguration, new Path(dictPath))
         }
         throw e
+      case e: Throwable => throw e
     }
   }
 
   private def mergeIncrementDict(spark: SparkSession, context: DictionaryContext, plan: LogicalPlan): Unit = {
-    val dictPlan = transformerDictPlan(spark, context, plan)
-    val incrementDictDF = getDataFrame(spark, dictPlan)
+    val incDictVersion = transformerDictPlan(spark, context, plan)
+    val incrementDictDF = getDataFrame(spark, incDictVersion.incDictPlan)
     val dictPath = getDictionaryPathAndCheck(context)
-    logInfo(s"increment build global dict $dictPath")
-    val dictTable = DeltaTable.forPath(dictPath)
-    dictTable.alias("dict")
-      .merge(incrementDictDF.alias("incre_dict"),
-        "incre_dict.dict_key = dict.dict_key " +
-          "and incre_dict.dict_value != dict.dict_value")
-      .whenNotMatched().insertAll()
-      .execute()
+    if (incrementDictDF.isEmpty) {
+      logInfo(s"Increment dict for global dict $dictPath is empty, no need to 
merge.")
+      return
+    }
+    tryMergeIncrementDict(spark, dictPath, incDictVersion.sourceDeltaVersion, incrementDictDF)
+  }
+
+  private def tryMergeIncrementDict(spark: SparkSession, dictPath: String, dictVersion: Long,
+                                    incDictDF: Dataset[Row]): Unit = {
+    v3dictMergeLock.writeLock().lock()
+    try {
+      logInfo(s"Increment build global dict $dictPath")
+      val curVersion = DeltaLog.forTable(spark, dictPath).snapshot.version
+      if (dictVersion != curVersion) {
+        logInfo(s"Cur v3dict version is $curVersion, incDict is based on 
version $curVersion, will be retry")
+        throw new KylinRuntimeException(s"Cur v3dict version is $curVersion, " 
+
+          s"incDict is based on version $curVersion, will be retry")
+      } else {
+        val dictTable = DeltaTable.forPath(dictPath)
+        dictTable.merge(incDictDF, "1 != 1")
+          .whenNotMatched().insertAll()
+          .execute()
+      }
+    } finally {
+      v3dictMergeLock.writeLock().unlock()
+    }
   }
 
   /**
@@ -324,6 +347,7 @@ object DictionaryBuilder extends Logging {
   def getDictionaryPathAndCheck(context: DictionaryContext): String = {
     val v3ditPath = getDictionaryPath(context)
     if (!DeltaTable.isDeltaTable(v3ditPath)) {
+      logWarning(s"This v3dict path: {$v3ditPath} is not a delta table.")
       throw new KylinRuntimeException(s"This v3dict path: {$v3ditPath} is not a delta table.")
     }
     v3ditPath
@@ -333,6 +357,7 @@ object DictionaryBuilder extends Logging {
     NSparkCubingUtil.convertFromDot(ref.getBackTickIdentity)
   }
 }
+case class DictPlanVersion(incDictPlan: LogicalPlan, sourceDeltaVersion: Long)
 
 class DictionaryContext(
                          val project: String,
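
The version-mismatch exception above does its job only because the merge runs under the blocking retry declared at the top of the object (RetryStrategy.randomBackOff(5.seconds, 15.seconds, maxAttempts = 20)). A hedged sketch of such a call site; buildWithRetry is illustrative, not the actual Kylin method, and assumes it sits inside DictionaryBuilder where mergeIncrementDict, logInfo and the implicit retry strategy are in scope:

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import util.retry.blocking.{Failure, Retry, Success}

def buildWithRetry(spark: SparkSession, context: DictionaryContext, plan: LogicalPlan): Unit =
  // Each attempt re-reads the latest snapshot version, so a thrown
  // KylinRuntimeException simply triggers the next back-off attempt.
  Retry(mergeIncrementDict(spark, context, plan)) match {
    case Success(_) => logInfo("v3 dict increment merged")
    case Failure(e) => throw e // all attempts exhausted; surface the failure
  }
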
diff --git a/src/spark-project/engine-spark/src/test/scala/org/apache/kylin/engine/spark/builder/v3dict/GlobalDictionarySuite.scala b/src/spark-project/engine-spark/src/test/scala/org/apache/kylin/engine/spark/builder/v3dict/GlobalDictionarySuite.scala
index c72de5c0b7..5b8d3da516 100644
--- a/src/spark-project/engine-spark/src/test/scala/org/apache/kylin/engine/spark/builder/v3dict/GlobalDictionarySuite.scala
+++ b/src/spark-project/engine-spark/src/test/scala/org/apache/kylin/engine/spark/builder/v3dict/GlobalDictionarySuite.scala
@@ -83,7 +83,7 @@ class GlobalDictionarySuite extends SparderBaseFunSuite with LocalMetadata with
     checkAnswer(originalDF, dictDF)
   }
 
-  test("KE-35145 Test Concurrent Build Dictionary") {
+  test("KE-35145 & KE-42401 Test Concurrent Build Dictionary") {
     val project = "p1"
     val dbName = "db1"
     val tableName = "t1"
@@ -101,7 +101,7 @@ class GlobalDictionarySuite extends SparderBaseFunSuite with LocalMetadata with
       ec.submit(buildDictTask)
     }
     ec.shutdown()
-    ec.awaitTermination(2, TimeUnit.MINUTES)
+    ec.awaitTermination(5, TimeUnit.MINUTES)
 
     val originalDF = spark.sql(
       """
@@ -110,8 +110,10 @@ class GlobalDictionarySuite extends SparderBaseFunSuite with LocalMetadata with
       """.stripMargin)
 
     val dictPath: String = DictionaryBuilder.getDictionaryPath(context)
-    val dictResultDF = DeltaTable.forPath(dictPath).toDF.agg(count(col("dict_key")))
-    checkAnswer(originalDF, dictResultDF)
+    val dictResultDFKey = DeltaTable.forPath(dictPath).toDF.select(countDistinct("dict_key"))
+    val dictResultDFValue = DeltaTable.forPath(dictPath).toDF.select(countDistinct("dict_value"))
+    checkAnswer(originalDF, dictResultDFKey)
+    checkAnswer(originalDF, dictResultDFValue)
   }
 
   test("KE-35145 Test the v3 dictionary with random data") {
@@ -278,7 +280,7 @@ class GlobalDictionarySuite extends SparderBaseFunSuite with LocalMetadata with
     new Runnable {
       override def run(): Unit = {
         val encodeColName: String = context.tableName + NSparkCubingUtil.SEPARATOR + context.columnName
-        val originalDF = genRandomData(spark, encodeColName, 100, 1)
+        val originalDF = genRandomData(spark, encodeColName, 20, 1)
         val dictDF = genDataWithWrapEncodeCol(context.dbName, encodeColName, originalDF)
         DeltaTable.forName("original")
           .merge(originalDF, "1 != 1")
