Hello.

I'm running Scala 2.11 with Spark 2.3.0. I've encountered a problem with
mapGroupsWithState and was wondering if anyone had insight. We use Joda-Time
in a number of data structures, so we've written a custom UserDefinedType (UDT)
for Joda `DateTime`. This works well in most Dataset/DataFrame structured
streaming operations. However, when running mapGroupsWithState, we observed
that incorrect dates were being returned from the group state.

I filed a bug to help track related information:
https://issues.apache.org/jira/browse/SPARK-30986

Simple example:
1. Input A has a date D.
2. Input A updates the state in mapGroupsWithState. The date present in the state is D.
3. Input A is added again. Input A still has the correct date D, but the existing
state now has an invalid date.

Here is a simple repro:

Joda Time UDT:

/**
 * Catalyst user-defined type that stores a Joda-Time [[DateTime]] in a `LongType`
 * column as epoch milliseconds.
 *
 * Values are always rebuilt in UTC when read back, so the instant is preserved but
 * any other time zone attached to the original `DateTime` is not.
 */
private[sql] class JodaTimeUDT extends UserDefinedType[DateTime] {

  /** Scala class represented by this UDT. */
  override def userClass: Class[DateTime] = classOf[DateTime]

  /** Underlying Catalyst storage type: a long holding epoch milliseconds. */
  override def sqlType: DataType = LongType

  /** Converts a `DateTime` to its epoch-millisecond value. */
  override def serialize(obj: DateTime): Long = obj.getMillis

  /**
   * Rebuilds a UTC `DateTime` from a stored epoch-millisecond value.
   * A datum that is not a `Long` fails with a `MatchError`.
   */
  def deserialize(datum: Any): DateTime =
    datum match {
      case epochMillis: Long => new DateTime(epochMillis, DateTimeZone.UTC)
    }

  /** Returns this same instance as its nullable form. */
  private[spark] override def asNullable: JodaTimeUDT = this
}

/** Entry point for wiring [[JodaTimeUDT]] into Spark's UDT registry. */
object JodaTimeUDTRegister {

  /**
   * Registers [[JodaTimeUDT]] as the UDT for Joda `DateTime` through Spark's
   * `UDTRegistration`.
   *
   * Declared with `()` because it has a side effect. Existing Scala 2.11 callers
   * that write `JodaTimeUDTRegister.register` without parentheses still compile.
   */
  def register(): Unit =
    UDTRegistration.register(classOf[DateTime].getName, classOf[JodaTimeUDT].getName)
}


Test Leveraging Joda UDT:

/**
 * Test record carrying a Joda `DateTime` alongside a string and an int.
 * In the stateful test `i` is the grouping key; `concatFoo` concatenates `s` and sums `i`.
 * NOTE(review): `date` presumably goes through the registered JodaTimeUDT encoder — confirm.
 */
case class FooWithDate(date: DateTime, s: String, i: Int)

/**
 * Reproduces SPARK-30986: a Joda `DateTime` kept in `mapGroupsWithState` state
 * (via the registered Joda UDT) should come back with the same value that was written.
 */
@RunWith(classOf[JUnitRunner])
class TestJodaTimeUdt extends FlatSpec with Matchers with MockFactory with BeforeAndAfterAll {
  val application = this.getClass.getName
  var session: SparkSession = _

  override def beforeAll(): Unit = {
    super.beforeAll()
    // NOTE(review): presumably lets Hadoop find winutils on the test classpath (Windows) — confirm.
    System.setProperty("hadoop.home.dir", getClass.getResource("/").getPath)
    val sparkConf = new SparkConf()
      .set("spark.driver.allowMultipleContexts", "true")
      .set("spark.testing", "true")
      .set("spark.memory.fraction", "1")
      .set("spark.ui.enabled", "false")
      .set("spark.streaming.gracefulStopTimeout", "1000")
      .setAppName(application)
      .setMaster("local[*]")

    session = SparkSession.builder().config(sparkConf).getOrCreate()
    // Use a throwaway directory instead of the filesystem root, which is normally
    // not writable and should never be used as scratch space.
    session.sparkContext.setCheckpointDir(
      java.nio.file.Files.createTempDirectory("joda-udt-checkpoint").toString)
    JodaTimeUDTRegister.register
  }

  override def afterAll(): Unit = {
    try {
      // beforeAll may have failed before the session was created.
      if (session != null) session.stop()
    } finally {
      super.afterAll()
    }
  }

  it should "work correctly for a streaming input with stateful transformation" in {
    val date = new DateTime(2020, 1, 2, 3, 4, 5, 6, DateTimeZone.UTC)
    val sqlContext = session.sqlContext
    import sqlContext.implicits._

    // Key 1 appears once; key 3 appears twice and must be folded together in state.
    val input = List(
      FooWithDate(date, "Foo", 1),
      FooWithDate(date, "Foo", 3),
      FooWithDate(date, "Foo", 3))
    val streamInput: MemoryStream[FooWithDate] =
      new MemoryStream[FooWithDate](42, session.sqlContext)
    streamInput.addData(input)
    val ds: Dataset[FooWithDate] = streamInput.toDS()

    val mapGroupsWithStateFunction
      : (Int, Iterator[FooWithDate], GroupState[FooWithDate]) => FooWithDate =
      TestJodaTimeUdt.updateFooState
    val result: Dataset[FooWithDate] = ds
      .groupByKey(x => x.i)
      .mapGroupsWithState(GroupStateTimeout.ProcessingTimeTimeout())(mapGroupsWithStateFunction)
    val writeTo = "random_table_name"

    result.writeStream
      .outputMode(OutputMode.Update)
      .format("memory")
      .queryName(writeTo)
      .trigger(Trigger.Once())
      .start()
      .awaitTermination()
    val combinedResults: Array[FooWithDate] =
      session.sql(sqlText = s"select * from $writeTo").as[FooWithDate].collect()
    // Every output row must still carry the original date.
    val expected = Array(FooWithDate(date, "Foo", 1), FooWithDate(date, "FooFoo", 6))
    combinedResults should contain theSameElementsAs(expected)
  }
}

/** Companion holding the stateful update logic used by the streaming test. */
object TestJodaTimeUdt {

  /**
   * Folds a group's new inputs into its stored state.
   *
   * On timeout, the stored value is returned one last time and the state is removed.
   * Otherwise, the existing state (or, for a new key, the first input) is combined with
   * the remaining inputs via [[concatFoo]], stored back with a one-minute timeout,
   * and returned.
   *
   * @param id     the grouping key (not used by the update itself)
   * @param inputs rows for this key in the current trigger
   * @param state  per-key state managed by Spark
   * @return the updated value, or on timeout the last stored value
   */
  def updateFooState(
      id: Int,
      inputs: Iterator[FooWithDate],
      state: GroupState[FooWithDate]): FooWithDate = {
    if (state.hasTimedOut) {
      // Capture the value before removing it: after remove(), getOption is None,
      // so reading it afterwards would throw NoSuchElementException.
      val lastValue = state.get
      state.remove()
      lastValue
    } else {
      val inputsSeq: Seq[FooWithDate] = inputs.toList
      // For a new key, seed from the first input. Assumes Spark only passes an empty
      // iterator on timeout (handled above) — confirm.
      val (startingState, toProcess) = state.getOption match {
        case Some(previous) => (previous, inputsSeq)
        case None           => (inputsSeq.head, inputsSeq.tail)
      }
      val updatedFoo = toProcess.foldLeft(startingState)(concatFoo)

      state.update(updatedFoo)
      state.setTimeoutDuration("1 minute")
      updatedFoo
    }
  }

  /** Combines two records: keeps `b`'s date, concatenates `s`, and sums `i`. */
  def concatFoo(a: FooWithDate, b: FooWithDate): FooWithDate =
    FooWithDate(b.date, a.s + b.s, a.i + b.i)
}


The test output shows the invalid date:

org.scalatest.exceptions.TestFailedException:
Array(FooWithDate(2021-02-02T19:26:23.374Z,Foo,1),
FooWithDate(2021-02-02T19:26:23.374Z,FooFoo,6)) did not contain the same
elements as
Array(FooWithDate(2020-01-02T03:04:05.006Z,Foo,1),
FooWithDate(2020-01-02T03:04:05.006Z,FooFoo,6))

Is this something folks have encountered before?

Thank you,

Bryan Jeffrey

Reply via email to