This is an automated email from the ASF dual-hosted git repository.
aglinxinyuan pushed a commit to branch xinyuan-loop-feb
in repository https://gitbox.apache.org/repos/asf/texera.git
The following commit(s) were added to refs/heads/xinyuan-loop-feb by this push:
new 6b369fec62 fix fmt
6b369fec62 is described below
commit 6b369fec62a3897271ef52228ef63e1cc96d0109
Author: Xinyuan Lin <[email protected]>
AuthorDate: Sat Apr 18 03:00:53 2026 -0700
fix fmt
---
amber/src/main/python/core/models/state.py | 54 +++++----------
amber/src/main/python/core/runnables/main_loop.py | 4 +-
.../main/python/core/runnables/network_sender.py | 4 +-
.../pythonworker/PythonProxyClient.scala | 2 +-
.../org/apache/texera/amber/core/state/State.scala | 77 ++++++++--------------
5 files changed, 48 insertions(+), 93 deletions(-)
diff --git a/amber/src/main/python/core/models/state.py b/amber/src/main/python/core/models/state.py
index 35d3c3620b..57c3cb4206 100644
--- a/amber/src/main/python/core/models/state.py
+++ b/amber/src/main/python/core/models/state.py
@@ -17,7 +17,6 @@
import base64
import json
-import pickle
from typing import Any, Dict, TypeAlias
from .schema import AttributeType, Schema
@@ -30,7 +29,6 @@ LOOP_COUNTER = "loop_counter"
_TYPE_MARKER = "__texera_type__"
_PAYLOAD_MARKER = "payload"
_BYTES_TYPE = "bytes"
-_PYTHON_PICKLE_TYPE = "python_pickle"
STATE_SCHEMA = Schema()
STATE_SCHEMA.add(SERIALIZED_STATE_CONTENT, AttributeType.STRING)
@@ -42,16 +40,13 @@ def state_uri_from_result_uri(result_uri: str) -> str:
def serialize_state(state: State) -> Tuple:
- return serialize_state_dict(state)
-
-
-def serialize_state_dict(state_dict: State) -> Tuple:
- loop_counter = int(state_dict.get(LOOP_COUNTER, 0))
- payload_dict = dict(state_dict)
- payload_dict.pop(LOOP_COUNTER, None)
+ payload = dict(state)
+ loop_counter = int(payload.pop(LOOP_COUNTER, 0))
return Tuple(
{
- SERIALIZED_STATE_CONTENT: dumps_payload(payload_dict),
+ SERIALIZED_STATE_CONTENT: json.dumps(
+ _to_json_value(payload), separators=(",", ":")
+ ),
LOOP_COUNTER: loop_counter,
},
schema=STATE_SCHEMA,
@@ -59,21 +54,12 @@ def serialize_state_dict(state_dict: State) -> Tuple:
def deserialize_state(row: Tuple) -> State:
- serialized_content = row[SERIALIZED_STATE_CONTENT] or "{}"
- state_dict = loads_payload(serialized_content)
- state_dict[LOOP_COUNTER] = int(row[LOOP_COUNTER])
- return state_dict
-
-
-def dumps_payload(payload: State) -> str:
- return json.dumps(_normalize_for_json(payload), separators=(",", ":"))
+ state = _from_json_value(json.loads(row[SERIALIZED_STATE_CONTENT] or "{}"))
+ state[LOOP_COUNTER] = int(row[LOOP_COUNTER])
+ return state
-def loads_payload(serialized_payload: str) -> State:
- return _denormalize_from_json(json.loads(serialized_payload or "{}"))
-
-
-def _normalize_for_json(value: Any) -> Any:
+def _to_json_value(value: Any) -> Any:
if value is None or isinstance(value, (bool, int, float, str)):
return value
if isinstance(value, bytes):
@@ -82,25 +68,17 @@ def _normalize_for_json(value: Any) -> Any:
_PAYLOAD_MARKER: base64.b64encode(value).decode("ascii"),
}
if isinstance(value, dict):
- return {str(key): _normalize_for_json(inner) for key, inner in value.items()}
+ return {str(key): _to_json_value(inner) for key, inner in value.items()}
if isinstance(value, (list, tuple)):
- return [_normalize_for_json(inner) for inner in value]
- return {
- _TYPE_MARKER: _PYTHON_PICKLE_TYPE,
- _PAYLOAD_MARKER: base64.b64encode(pickle.dumps(value)).decode("ascii"),
- }
+ return [_to_json_value(inner) for inner in value]
+ raise TypeError(f"State value of type {type(value).__name__} is not JSON serializable")
-def _denormalize_from_json(value: Any) -> Any:
+def _from_json_value(value: Any) -> Any:
if isinstance(value, list):
- return [_denormalize_from_json(inner) for inner in value]
+ return [_from_json_value(inner) for inner in value]
if isinstance(value, dict):
- marker = value.get(_TYPE_MARKER)
- if marker == _BYTES_TYPE:
+ if value.get(_TYPE_MARKER) == _BYTES_TYPE:
return base64.b64decode(value[_PAYLOAD_MARKER])
- if marker == _PYTHON_PICKLE_TYPE:
- return pickle.loads(base64.b64decode(value[_PAYLOAD_MARKER]))
- return {
- key: _denormalize_from_json(inner) for key, inner in value.items()
- }
+ return {key: _from_json_value(inner) for key, inner in value.items()}
return value
diff --git a/amber/src/main/python/core/runnables/main_loop.py b/amber/src/main/python/core/runnables/main_loop.py
index 910e66e22e..ffd52c70d4 100644
--- a/amber/src/main/python/core/runnables/main_loop.py
+++ b/amber/src/main/python/core/runnables/main_loop.py
@@ -42,7 +42,7 @@ from core.models.operator import LoopEndOperator, LoopStartOperator
from core.models.state import (
State,
STATE_SCHEMA,
- serialize_state_dict,
+ serialize_state,
state_uri_from_result_uri,
)
from core.runnables.data_processor import DataProcessor
@@ -115,7 +115,7 @@ class MainLoop(StoppableQueueBlockingRunnable):
del executor.state["LoopStartStateURI"]
del executor.state["LoopStartId"]
writer = DocumentFactory.create_document(uri,
STATE_SCHEMA).writer("0")
- writer.put_one(serialize_state_dict(executor.state))
+ writer.put_one(serialize_state(executor.state))
writer.close()
executor.close()
# stop the data processing thread
diff --git a/amber/src/main/python/core/runnables/network_sender.py b/amber/src/main/python/core/runnables/network_sender.py
index baf503c1f5..11824dbe68 100644
--- a/amber/src/main/python/core/runnables/network_sender.py
+++ b/amber/src/main/python/core/runnables/network_sender.py
@@ -31,7 +31,7 @@ from core.models.state import (
SERIALIZED_STATE_CONTENT,
LOOP_COUNTER,
STATE_SCHEMA,
- serialize_state_dict,
+ serialize_state,
)
from core.proxy import ProxyClient
from core.util import StoppableQueueBlockingRunnable
@@ -107,7 +107,7 @@ class NetworkSender(StoppableQueueBlockingRunnable):
data_header = PythonDataHeader(
tag=to, payload_type="State"
)
- serialized_state = serialize_state_dict(data_payload.frame)
+ serialized_state = serialize_state(data_payload.frame)
table = pa.Table.from_pydict(
{
SERIALIZED_STATE_CONTENT: [
diff --git a/amber/src/main/scala/org/apache/texera/amber/engine/architecture/pythonworker/PythonProxyClient.scala b/amber/src/main/scala/org/apache/texera/amber/engine/architecture/pythonworker/PythonProxyClient.scala
index 6ec68d3cd9..2af84db0bc 100644
--- a/amber/src/main/scala/org/apache/texera/amber/engine/architecture/pythonworker/PythonProxyClient.scala
+++ b/amber/src/main/scala/org/apache/texera/amber/engine/architecture/pythonworker/PythonProxyClient.scala
@@ -127,7 +127,7 @@ class PythonProxyClient(portNumberPromise: Promise[Int], val actorId: ActorVirtu
writeArrowStream(mutable.Queue(ArraySeq.unsafeWrapArray(frame): _*),
from, "Data")
case StateFrame(state) =>
writeArrowStream(
- mutable.Queue(State.serializeStateMap(state)),
+ mutable.Queue(State.serializeState(state)),
from,
"State"
)
diff --git a/common/workflow-core/src/main/scala/org/apache/texera/amber/core/state/State.scala b/common/workflow-core/src/main/scala/org/apache/texera/amber/core/state/State.scala
index bccff06a58..cabff0e42d 100644
--- a/common/workflow-core/src/main/scala/org/apache/texera/amber/core/state/State.scala
+++ b/common/workflow-core/src/main/scala/org/apache/texera/amber/core/state/State.scala
@@ -39,58 +39,37 @@ object State {
new Attribute(LoopCounterColumn, AttributeType.LONG)
)
- def stateUriFromResultUri(resultUri: URI): URI = {
+ def stateUriFromResultUri(resultUri: URI): URI =
new URI(resultUri.toString.replace("/result", "/state"))
- }
def serializeState(state: Map[String, Any]): Tuple = {
- serializeStateMap(state)
- }
-
- def serializeStateMap(state: Map[String, Any]): Tuple = {
val loopCounter = state.get(LoopCounterColumn).map(toLong).getOrElse(0L)
- val payloadJson = dumpsPayload(state.removed(LoopCounterColumn))
-
- Tuple
- .builder(materializationSchema)
- .addSequentially(Array(payloadJson, loopCounter))
- .build()
+ val payloadJson = objectMapper.writeValueAsString(toJsonValue(state.removed(LoopCounterColumn)))
+ Tuple.builder(materializationSchema).addSequentially(Array(payloadJson, loopCounter)).build()
}
def deserializeState(tuple: Tuple): Map[String, Any] = {
- val serializedContent =
+ val payload =
Option(tuple.getField[String](SerializedStateContentColumn)).getOrElse("{}")
- val loopCounter = toLong(tuple.getField[Any](LoopCounterColumn))
- loadsPayload(serializedContent) + (LoopCounterColumn -> loopCounter)
- }
-
- private def dumpsPayload(payload: Map[String, Any]): String = {
-
objectMapper.writeValueAsString(payload.view.mapValues(normalizeForJson).toMap)
+ val root = objectMapper.readTree(payload)
+ val state =
+ if (root == null || !root.isObject) Map.empty[String, Any]
+ else root.fields().asScala.map(entry => entry.getKey -> fromJsonValue(entry.getValue)).toMap
+ state + (LoopCounterColumn -> toLong(tuple.getField[Any](LoopCounterColumn)))
}
- private def loadsPayload(serializedPayload: String): Map[String, Any] = {
- val root = objectMapper.readTree(Option(serializedPayload).getOrElse("{}"))
- if (root == null || !root.isObject) {
- Map.empty
- } else {
- root.fields().asScala.map(entry => entry.getKey -> denormalizeFromJson(entry.getValue)).toMap
- }
- }
-
- private def normalizeForJson(value: Any): Any = {
- value match {
- case null => null
- case bytes: Array[Byte] =>
- Map(BytesTypeMarker -> BytesValue, PayloadMarker -> Base64.getEncoder.encodeToString(bytes))
- case map: Map[_, _] =>
- map.iterator.map { case (k, v) => k.toString -> normalizeForJson(v) }.toMap
- case iterable: Iterable[_] =>
- iterable.map(normalizeForJson).toList
- case other => other
- }
+ private def toJsonValue(value: Any): Any = value match {
+ case null => null
+ case bytes: Array[Byte] =>
+ Map(BytesTypeMarker -> BytesValue, PayloadMarker -> Base64.getEncoder.encodeToString(bytes))
+ case map: Map[_, _] =>
+ map.iterator.map { case (k, v) => k.toString -> toJsonValue(v) }.toMap
+ case iterable: Iterable[_] =>
+ iterable.map(toJsonValue).toList
+ case other => other
}
- private def denormalizeFromJson(node: JsonNode): Any = {
+ private def fromJsonValue(node: JsonNode): Any = {
if (node == null || node.isNull) {
null
} else if (node.isObject) {
@@ -99,10 +78,10 @@ object State {
case Some(typeNode) if typeNode.isTextual && typeNode.asText() ==
BytesValue =>
Base64.getDecoder.decode(fields(PayloadMarker).asText())
case _ =>
- fields.view.mapValues(denormalizeFromJson).toMap
+ fields.view.mapValues(fromJsonValue).toMap
}
} else if (node.isArray) {
- node.elements().asScala.map(denormalizeFromJson).toList
+ node.elements().asScala.map(fromJsonValue).toList
} else if (node.isBoolean) {
node.asBoolean()
} else if (node.isIntegralNumber) {
@@ -114,13 +93,11 @@ object State {
}
}
- private def toLong(value: Any): Long = {
- value match {
- case null => 0L
- case number: java.lang.Number => number.longValue()
- case text: String => text.toLong
- case other =>
- throw new IllegalArgumentException(s"Cannot convert $other to loop counter")
- }
+ private def toLong(value: Any): Long = value match {
+ case null => 0L
+ case number: java.lang.Number => number.longValue()
+ case text: String => text.toLong
+ case other =>
+ throw new IllegalArgumentException(s"Cannot convert $other to loop counter")
}
}