xuechendi commented on a change in pull request #32717: URL: https://github.com/apache/spark/pull/32717#discussion_r646351018
########## File path: sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/CachedBatchSerializerSuite.scala ########## @@ -143,3 +143,109 @@ class CachedBatchSerializerSuite extends QueryTest with SharedSparkSession { } } } + +object DummyAllocator { + private var allocated: Long = 0 + def alloc(size: Long): Unit = synchronized { + allocated += size + } + def release(size: Long): Unit = synchronized { + allocated -= size + } + def getAllocatedMemory: Long = synchronized { + allocated + } +} + +case class RefCountedCachedBatch( + numRows: Int, + stats: InternalRow, + size: Long, + cachedBatch: CachedBatch) extends SimpleMetricsCachedBatch with AutoCloseable { + DummyAllocator.alloc(size) + var allocated_size: Long = size + override def close(): Unit = synchronized { + DummyAllocator.release(allocated_size) + allocated_size = 0 + } + override def sizeInBytes: Long = allocated_size +} + +class RefCountedTestCachedBatchSerializer extends DefaultCachedBatchSerializer { + + override def convertInternalRowToCachedBatch( + input: RDD[InternalRow], + schema: Seq[Attribute], + storageLevel: StorageLevel, + conf: SQLConf): RDD[CachedBatch] = { + val batchSize = conf.columnBatchSize + val useCompression = conf.useCompression + val cachedBatchRdd = convertForCacheInternal(input, schema, batchSize, useCompression) + cachedBatchRdd.mapPartitionsInternal { cachedBatchIter => + cachedBatchIter.map(cachedBatch => { + val actualCachedBatch = cachedBatch.asInstanceOf[DefaultCachedBatch] + new RefCountedCachedBatch( + actualCachedBatch.numRows, + actualCachedBatch.stats, + 100, + cachedBatch) + }) + } + } + + override def convertCachedBatchToInternalRow( + input: RDD[CachedBatch], + cacheAttributes: Seq[Attribute], + selectedAttributes: Seq[Attribute], + conf: SQLConf): RDD[InternalRow] = { + val actualCachedBatchIter = input.mapPartitionsInternal { cachedBatchIter => + cachedBatchIter.map(_.asInstanceOf[RefCountedCachedBatch].cachedBatch) + } + 
+ super.convertCachedBatchToInternalRow( + actualCachedBatchIter, + cacheAttributes, + selectedAttributes, + conf) + } + + override def supportsColumnarOutput(schema: StructType): Boolean = false + override def supportsColumnarInput(schema: Seq[Attribute]): Boolean = false +} + +class RefCountedTestCachedBatchSerializerSuite extends QueryTest with SharedSparkSession { + import testImplicits._ + + override protected def sparkConf: SparkConf = { + super.sparkConf.set( + StaticSQLConf.SPARK_CACHE_SERIALIZER.key, + classOf[RefCountedTestCachedBatchSerializer].getName) + } + + protected override def beforeAll(): Unit = { + super.beforeAll() + clearSerializer() + } + + protected override def afterAll(): Unit = { + clearSerializer() + super.afterAll() + } + + test("SPARK-35396: Manual Release objects stored in InMemoryRelation when clearCache called") { + val df = spark.range(1, 100).selectExpr("id % 10 as id") + .rdd.map(id => Tuple1(s"str_$id")).toDF("i") + val cached = df.cache() + // count triggers the caching action. It should not throw. + cached.count() + + // Make sure, the DataFrame is indeed cached. + assert(spark.sharedState.cacheManager.lookupCachedData(cached).nonEmpty) + assert(DummyAllocator.getAllocatedMemory > 0) + + // Drop the cache. + cached.unpersist() Review comment: I see. I changed it to use blocking = true and verified locally: if I add a 1-second wait in the close function, only blocking = true lets the test pass. -- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. For queries about this service, please contact Infrastructure at: us...@infra.apache.org --------------------------------------------------------------------- To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org For additional commands, e-mail: reviews-h...@spark.apache.org