This is an automated email from the ASF dual-hosted git repository.

aglinxinyuan pushed a commit to branch xinyuan-loop-feb
in repository https://gitbox.apache.org/repos/asf/texera.git


The following commit(s) were added to refs/heads/xinyuan-loop-feb by this push:
     new 6b369fec62 fix fmt
6b369fec62 is described below

commit 6b369fec62a3897271ef52228ef63e1cc96d0109
Author: Xinyuan Lin <[email protected]>
AuthorDate: Sat Apr 18 03:00:53 2026 -0700

    fix fmt
---
 amber/src/main/python/core/models/state.py         | 54 +++++----------
 amber/src/main/python/core/runnables/main_loop.py  |  4 +-
 .../main/python/core/runnables/network_sender.py   |  4 +-
 .../pythonworker/PythonProxyClient.scala           |  2 +-
 .../org/apache/texera/amber/core/state/State.scala | 77 ++++++++--------------
 5 files changed, 48 insertions(+), 93 deletions(-)

diff --git a/amber/src/main/python/core/models/state.py b/amber/src/main/python/core/models/state.py
index 35d3c3620b..57c3cb4206 100644
--- a/amber/src/main/python/core/models/state.py
+++ b/amber/src/main/python/core/models/state.py
@@ -17,7 +17,6 @@
 
 import base64
 import json
-import pickle
 from typing import Any, Dict, TypeAlias
 
 from .schema import AttributeType, Schema
@@ -30,7 +29,6 @@ LOOP_COUNTER = "loop_counter"
 _TYPE_MARKER = "__texera_type__"
 _PAYLOAD_MARKER = "payload"
 _BYTES_TYPE = "bytes"
-_PYTHON_PICKLE_TYPE = "python_pickle"
 
 STATE_SCHEMA = Schema()
 STATE_SCHEMA.add(SERIALIZED_STATE_CONTENT, AttributeType.STRING)
@@ -42,16 +40,13 @@ def state_uri_from_result_uri(result_uri: str) -> str:
 
 
 def serialize_state(state: State) -> Tuple:
-    return serialize_state_dict(state)
-
-
-def serialize_state_dict(state_dict: State) -> Tuple:
-    loop_counter = int(state_dict.get(LOOP_COUNTER, 0))
-    payload_dict = dict(state_dict)
-    payload_dict.pop(LOOP_COUNTER, None)
+    payload = dict(state)
+    loop_counter = int(payload.pop(LOOP_COUNTER, 0))
     return Tuple(
         {
-            SERIALIZED_STATE_CONTENT: dumps_payload(payload_dict),
+            SERIALIZED_STATE_CONTENT: json.dumps(
+                _to_json_value(payload), separators=(",", ":")
+            ),
             LOOP_COUNTER: loop_counter,
         },
         schema=STATE_SCHEMA,
@@ -59,21 +54,12 @@ def serialize_state_dict(state_dict: State) -> Tuple:
 
 
 def deserialize_state(row: Tuple) -> State:
-    serialized_content = row[SERIALIZED_STATE_CONTENT] or "{}"
-    state_dict = loads_payload(serialized_content)
-    state_dict[LOOP_COUNTER] = int(row[LOOP_COUNTER])
-    return state_dict
-
-
-def dumps_payload(payload: State) -> str:
-    return json.dumps(_normalize_for_json(payload), separators=(",", ":"))
+    state = _from_json_value(json.loads(row[SERIALIZED_STATE_CONTENT] or "{}"))
+    state[LOOP_COUNTER] = int(row[LOOP_COUNTER])
+    return state
 
 
-def loads_payload(serialized_payload: str) -> State:
-    return _denormalize_from_json(json.loads(serialized_payload or "{}"))
-
-
-def _normalize_for_json(value: Any) -> Any:
+def _to_json_value(value: Any) -> Any:
     if value is None or isinstance(value, (bool, int, float, str)):
         return value
     if isinstance(value, bytes):
@@ -82,25 +68,17 @@ def _normalize_for_json(value: Any) -> Any:
             _PAYLOAD_MARKER: base64.b64encode(value).decode("ascii"),
         }
     if isinstance(value, dict):
-        return {str(key): _normalize_for_json(inner) for key, inner in value.items()}
+        return {str(key): _to_json_value(inner) for key, inner in value.items()}
     if isinstance(value, (list, tuple)):
-        return [_normalize_for_json(inner) for inner in value]
-    return {
-        _TYPE_MARKER: _PYTHON_PICKLE_TYPE,
-        _PAYLOAD_MARKER: base64.b64encode(pickle.dumps(value)).decode("ascii"),
-    }
+        return [_to_json_value(inner) for inner in value]
+    raise TypeError(f"State value of type {type(value).__name__} is not JSON serializable")
 
 
-def _denormalize_from_json(value: Any) -> Any:
+def _from_json_value(value: Any) -> Any:
     if isinstance(value, list):
-        return [_denormalize_from_json(inner) for inner in value]
+        return [_from_json_value(inner) for inner in value]
     if isinstance(value, dict):
-        marker = value.get(_TYPE_MARKER)
-        if marker == _BYTES_TYPE:
+        if value.get(_TYPE_MARKER) == _BYTES_TYPE:
             return base64.b64decode(value[_PAYLOAD_MARKER])
-        if marker == _PYTHON_PICKLE_TYPE:
-            return pickle.loads(base64.b64decode(value[_PAYLOAD_MARKER]))
-        return {
-            key: _denormalize_from_json(inner) for key, inner in value.items()
-        }
+        return {key: _from_json_value(inner) for key, inner in value.items()}
     return value
diff --git a/amber/src/main/python/core/runnables/main_loop.py b/amber/src/main/python/core/runnables/main_loop.py
index 910e66e22e..ffd52c70d4 100644
--- a/amber/src/main/python/core/runnables/main_loop.py
+++ b/amber/src/main/python/core/runnables/main_loop.py
@@ -42,7 +42,7 @@ from core.models.operator import LoopEndOperator, LoopStartOperator
 from core.models.state import (
     State,
     STATE_SCHEMA,
-    serialize_state_dict,
+    serialize_state,
     state_uri_from_result_uri,
 )
 from core.runnables.data_processor import DataProcessor
@@ -115,7 +115,7 @@ class MainLoop(StoppableQueueBlockingRunnable):
             del executor.state["LoopStartStateURI"]
             del executor.state["LoopStartId"]
             writer = DocumentFactory.create_document(uri, STATE_SCHEMA).writer("0")
-            writer.put_one(serialize_state_dict(executor.state))
+            writer.put_one(serialize_state(executor.state))
             writer.close()
         executor.close()
         # stop the data processing thread
diff --git a/amber/src/main/python/core/runnables/network_sender.py b/amber/src/main/python/core/runnables/network_sender.py
index baf503c1f5..11824dbe68 100644
--- a/amber/src/main/python/core/runnables/network_sender.py
+++ b/amber/src/main/python/core/runnables/network_sender.py
@@ -31,7 +31,7 @@ from core.models.state import (
     SERIALIZED_STATE_CONTENT,
     LOOP_COUNTER,
     STATE_SCHEMA,
-    serialize_state_dict,
+    serialize_state,
 )
 from core.proxy import ProxyClient
 from core.util import StoppableQueueBlockingRunnable
@@ -107,7 +107,7 @@ class NetworkSender(StoppableQueueBlockingRunnable):
             data_header = PythonDataHeader(
                 tag=to, payload_type="State"
             )
-            serialized_state = serialize_state_dict(data_payload.frame)
+            serialized_state = serialize_state(data_payload.frame)
             table = pa.Table.from_pydict(
                 {
                     SERIALIZED_STATE_CONTENT: [
diff --git a/amber/src/main/scala/org/apache/texera/amber/engine/architecture/pythonworker/PythonProxyClient.scala b/amber/src/main/scala/org/apache/texera/amber/engine/architecture/pythonworker/PythonProxyClient.scala
index 6ec68d3cd9..2af84db0bc 100644
--- a/amber/src/main/scala/org/apache/texera/amber/engine/architecture/pythonworker/PythonProxyClient.scala
+++ b/amber/src/main/scala/org/apache/texera/amber/engine/architecture/pythonworker/PythonProxyClient.scala
@@ -127,7 +127,7 @@ class PythonProxyClient(portNumberPromise: Promise[Int], val actorId: ActorVirtu
         writeArrowStream(mutable.Queue(ArraySeq.unsafeWrapArray(frame): _*), from, "Data")
       case StateFrame(state) =>
         writeArrowStream(
-          mutable.Queue(State.serializeStateMap(state)),
+          mutable.Queue(State.serializeState(state)),
           from,
           "State"
         )
diff --git a/common/workflow-core/src/main/scala/org/apache/texera/amber/core/state/State.scala b/common/workflow-core/src/main/scala/org/apache/texera/amber/core/state/State.scala
index bccff06a58..cabff0e42d 100644
--- a/common/workflow-core/src/main/scala/org/apache/texera/amber/core/state/State.scala
+++ b/common/workflow-core/src/main/scala/org/apache/texera/amber/core/state/State.scala
@@ -39,58 +39,37 @@ object State {
     new Attribute(LoopCounterColumn, AttributeType.LONG)
   )
 
-  def stateUriFromResultUri(resultUri: URI): URI = {
+  def stateUriFromResultUri(resultUri: URI): URI =
     new URI(resultUri.toString.replace("/result", "/state"))
-  }
 
   def serializeState(state: Map[String, Any]): Tuple = {
-    serializeStateMap(state)
-  }
-
-  def serializeStateMap(state: Map[String, Any]): Tuple = {
     val loopCounter = state.get(LoopCounterColumn).map(toLong).getOrElse(0L)
-    val payloadJson = dumpsPayload(state.removed(LoopCounterColumn))
-
-    Tuple
-      .builder(materializationSchema)
-      .addSequentially(Array(payloadJson, loopCounter))
-      .build()
+    val payloadJson = objectMapper.writeValueAsString(toJsonValue(state.removed(LoopCounterColumn)))
+    Tuple.builder(materializationSchema).addSequentially(Array(payloadJson, loopCounter)).build()
   }
 
   def deserializeState(tuple: Tuple): Map[String, Any] = {
-    val serializedContent =
+    val payload =
       Option(tuple.getField[String](SerializedStateContentColumn)).getOrElse("{}")
-    val loopCounter = toLong(tuple.getField[Any](LoopCounterColumn))
-    loadsPayload(serializedContent) + (LoopCounterColumn -> loopCounter)
-  }
-
-  private def dumpsPayload(payload: Map[String, Any]): String = {
-    objectMapper.writeValueAsString(payload.view.mapValues(normalizeForJson).toMap)
+    val root = objectMapper.readTree(payload)
+    val state =
+      if (root == null || !root.isObject) Map.empty[String, Any]
+      else root.fields().asScala.map(entry => entry.getKey -> fromJsonValue(entry.getValue)).toMap
+    state + (LoopCounterColumn -> toLong(tuple.getField[Any](LoopCounterColumn)))
   }
 
-  private def loadsPayload(serializedPayload: String): Map[String, Any] = {
-    val root = objectMapper.readTree(Option(serializedPayload).getOrElse("{}"))
-    if (root == null || !root.isObject) {
-      Map.empty
-    } else {
-      root.fields().asScala.map(entry => entry.getKey -> denormalizeFromJson(entry.getValue)).toMap
-    }
-  }
-
-  private def normalizeForJson(value: Any): Any = {
-    value match {
-      case null                  => null
-      case bytes: Array[Byte]    =>
-        Map(BytesTypeMarker -> BytesValue, PayloadMarker -> Base64.getEncoder.encodeToString(bytes))
-      case map: Map[_, _]        =>
-        map.iterator.map { case (k, v) => k.toString -> normalizeForJson(v) }.toMap
-      case iterable: Iterable[_] =>
-        iterable.map(normalizeForJson).toList
-      case other                 => other
-    }
+  private def toJsonValue(value: Any): Any = value match {
+    case null                  => null
+    case bytes: Array[Byte]    =>
+      Map(BytesTypeMarker -> BytesValue, PayloadMarker -> Base64.getEncoder.encodeToString(bytes))
+    case map: Map[_, _]        =>
+      map.iterator.map { case (k, v) => k.toString -> toJsonValue(v) }.toMap
+    case iterable: Iterable[_] =>
+      iterable.map(toJsonValue).toList
+    case other                 => other
   }
 
-  private def denormalizeFromJson(node: JsonNode): Any = {
+  private def fromJsonValue(node: JsonNode): Any = {
     if (node == null || node.isNull) {
       null
     } else if (node.isObject) {
@@ -99,10 +78,10 @@ object State {
         case Some(typeNode) if typeNode.isTextual && typeNode.asText() == BytesValue =>
           Base64.getDecoder.decode(fields(PayloadMarker).asText())
         case _ =>
-          fields.view.mapValues(denormalizeFromJson).toMap
+          fields.view.mapValues(fromJsonValue).toMap
       }
     } else if (node.isArray) {
-      node.elements().asScala.map(denormalizeFromJson).toList
+      node.elements().asScala.map(fromJsonValue).toList
     } else if (node.isBoolean) {
       node.asBoolean()
     } else if (node.isIntegralNumber) {
@@ -114,13 +93,11 @@ object State {
     }
   }
 
-  private def toLong(value: Any): Long = {
-    value match {
-      case null                     => 0L
-      case number: java.lang.Number => number.longValue()
-      case text: String             => text.toLong
-      case other =>
-        throw new IllegalArgumentException(s"Cannot convert $other to loop counter")
-    }
+  private def toLong(value: Any): Long = value match {
+    case null                     => 0L
+    case number: java.lang.Number => number.longValue()
+    case text: String             => text.toLong
+    case other                    =>
+      throw new IllegalArgumentException(s"Cannot convert $other to loop counter")
   }
 }

Reply via email to