ntjohnson1 commented on code in PR #1544: URL: https://github.com/apache/datafusion-python/pull/1544#discussion_r3250913441
########## python/tests/test_pickle_expr.py: ##########

```diff
@@ -0,0 +1,157 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+"""In-process pickle round-trip tests for :class:`Expr`.
+
+Built-in functions and Python scalar UDFs travel with the pickled
+expression and do not need worker-side pre-registration. The worker
+context (:mod:`datafusion.ipc`) is only consulted for UDFs imported
+via the FFI capsule protocol.
+"""
+
+from __future__ import annotations
+
+import pickle
+
+import pyarrow as pa
+import pytest
+from datafusion import Expr, SessionContext, col, lit, udf
+from datafusion.ipc import (
+    clear_worker_ctx,
+    set_worker_ctx,
+)
+
+
[email protected](autouse=True)
+def _reset_worker_ctx():
+    """Ensure every test starts with no worker context installed."""
+    clear_worker_ctx()
+    yield
+    clear_worker_ctx()
+
+
+def _double_udf():
+    return udf(
+        lambda arr: pa.array([(v.as_py() or 0) * 2 for v in arr]),
+        [pa.int64()],
+        pa.int64(),
+        volatility="immutable",
+        name="double",
+    )
+
+
+class TestProtoRoundTrip:
+    def test_builtin_round_trip(self):
+        e = col("a") + lit(1)
+        blob = pickle.dumps(e)
+        decoded = pickle.loads(blob)  # noqa: S301
+        assert decoded.canonical_name() == e.canonical_name()
+
+    def test_to_bytes_from_bytes(self):
+        e = col("x") * lit(7)
+        blob = e.to_bytes()
+        assert isinstance(blob, bytes)
+        decoded = Expr.from_bytes(blob)
+        assert decoded.canonical_name() == e.canonical_name()
+
+    def test_explicit_ctx_used(self, ctx):
+        e = col("a") + lit(1)
+        decoded = Expr.from_bytes(e.to_bytes(), ctx=ctx)
+        assert decoded.canonical_name() == e.canonical_name()
+
+
+class TestUDFCodec:
+    """Python scalar UDFs ride inside the proto blob via the Rust codec.
+
+    No worker context needed on the receiver — the cloudpickled callable is
+    embedded in ``fun_definition`` and reconstructed automatically.
+    """
+
+    def test_udf_self_contained_blob(self):
+        e = _double_udf()(col("a"))
+        blob = pickle.dumps(e)
+        # The codec inlines the callable, so the blob is much bigger than a
```

Review Comment:
I think this is testing the thing I was asking about, but I haven't thought deeply enough about whether it actually does.
I know cloudpickle says it can serialize lambdas, but what if I instead had:

```python
from foo import double


def _double_udf():
    return udf(
        double,
        [pa.int64()],
        pa.int64(),
        volatility="immutable",
        name="double",
    )
```

Would I still be able to deserialize this on a remote worker in a Python environment without `foo`?

########## crates/core/src/codec.rs: ##########

```diff
@@ -284,3 +365,282 @@ impl PhysicalExtensionCodec for PythonPhysicalCodec {
         self.inner.try_decode_udwf(name, buf)
     }
 }
+
+// =============================================================================
+// Shared Python scalar UDF encode / decode helpers
+//
+// Both `PythonLogicalCodec` and `PythonPhysicalCodec` consult these on
+// every `try_encode_udf` / `try_decode_udf` call. Same wire format on
+// both layers — a Python `ScalarUDF` referenced inside a `LogicalPlan`
+// or an `ExecutionPlan` round-trips identically.
+// =============================================================================
+
+/// Encode a Python scalar UDF inline if `node` is one. Returns
+/// `Ok(true)` when the payload (`DFPYUDF` family prefix, version byte,
+/// cloudpickled tuple) was written and the caller should skip its
+/// inner codec. Returns `Ok(false)` for any non-Python UDF, signalling
+/// the caller to delegate to its `inner`.
+pub(crate) fn try_encode_python_scalar_udf(node: &ScalarUDF, buf: &mut Vec<u8>) -> Result<bool> {
+    let Some(py_udf) = node
+        .inner()
+        .as_any()
+        .downcast_ref::<PythonFunctionScalarUDF>()
+    else {
+        return Ok(false);
+    };
+
+    Python::attach(|py| -> Result<bool> {
+        let bytes = encode_python_scalar_udf(py, py_udf)
+            .map_err(|e| datafusion::error::DataFusionError::External(Box::new(e)))?;
+        write_wire_header(buf, PY_SCALAR_UDF_FAMILY);
+        buf.extend_from_slice(&bytes);
+        Ok(true)
+    })
+}
+
+/// Decode an inline Python scalar UDF payload. Returns `Ok(None)`
+/// when `buf` does not carry the `DFPYUDF` family prefix, signalling
+/// the caller to delegate to its `inner` codec (and eventually the
+/// `FunctionRegistry`).
+pub(crate) fn try_decode_python_scalar_udf(buf: &[u8]) -> Result<Option<Arc<ScalarUDF>>> {
+    let Some(payload) = strip_wire_header(buf, PY_SCALAR_UDF_FAMILY, "scalar UDF")? else {
+        return Ok(None);
+    };
+
+    Python::attach(|py| -> Result<Option<Arc<ScalarUDF>>> {
+        let udf = decode_python_scalar_udf(py, payload)
+            .map_err(|e| datafusion::error::DataFusionError::External(Box::new(e)))?;
+        Ok(Some(Arc::new(ScalarUDF::new_from_impl(udf))))
+    })
+}
+
+/// Build the cloudpickle payload for a `PythonFunctionScalarUDF`.
```

Review Comment:
Maybe it is captured more clearly somewhere else, but it feels like there is some nuance of the dependency on cloudpickle that isn't fully communicated here. I didn't do too much of a deep dive on it.

1. cloudpickle only works across the same version of Python (I'm not sure whether it detects a mismatch with a nice error). So your header might want to capture the source Python version to give a nicer error, and advertise the limitation that remote workers must run the same version of Python.
2. cloudpickle supports both serializing by reference (more like dill) and by value (super cool). The former needs the function installed in the environment so it can be referenced when deserialized, whereas the latter tries to capture all the necessary bits (here is where I didn't deep dive a ton). Those are fairly different mental models to support.

--
This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment.
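The version concern in point 1 could be addressed in the wire header itself. Below is a minimal Python sketch of the idea, not the PR's actual Rust header layout: the `MAGIC` constant and two-byte version framing are hypothetical, but it shows how embedding the producer's `(major, minor)` Python version lets the decoder fail fast with a clear error instead of a cryptic cloudpickle failure.

```python
import sys

MAGIC = b"DFPYUDF"  # hypothetical family prefix; the PR's real header differs


def encode_with_version(payload: bytes) -> bytes:
    """Prefix the payload with the producer's Python (major, minor)."""
    major, minor = sys.version_info[:2]
    return MAGIC + bytes([major, minor]) + payload


def decode_with_version(buf: bytes) -> bytes:
    """Strip the header, raising a clear error on a Python version mismatch."""
    if not buf.startswith(MAGIC):
        raise ValueError("not a DFPYUDF payload")
    major, minor = buf[len(MAGIC)], buf[len(MAGIC) + 1]
    if (major, minor) != sys.version_info[:2]:
        raise RuntimeError(
            f"UDF was pickled under Python {major}.{minor}; "
            f"this worker runs {sys.version_info[0]}.{sys.version_info[1]}"
        )
    return buf[len(MAGIC) + 2 :]
```

The check costs two bytes per payload and turns a hard-to-diagnose bytecode incompatibility into an actionable message.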
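The by-reference vs. by-value distinction in point 2 can be demonstrated directly. A small sketch (assumes `cloudpickle` is installed): lambdas and locally defined callables are pickled by value and deserialize without their defining module, while functions imported from an installed module are pickled by reference, so the `from foo import double` variant above would fail on a worker without `foo` unless the module is opted into by-value pickling.

```python
import math
import pickle

import cloudpickle

# By value: a lambda is fully embedded in the blob, so the receiving
# process needs nothing beyond a compatible Python interpreter.
double = lambda x: x * 2  # noqa: E731
blob = cloudpickle.dumps(double)
assert pickle.loads(blob)(21) == 42

# By reference: an importable, installed function is stored essentially
# as a module path ("math.sqrt"); the worker must be able to import it.
ref_blob = cloudpickle.dumps(math.sqrt)
assert pickle.loads(ref_blob)(9.0) == 3.0

# cloudpickle >= 2.0 can force by-value pickling for a whole module:
#   cloudpickle.register_pickle_by_value(foo)
# after which `from foo import double` would also travel self-contained.
```

Note that cloudpickle output is a standard pickle stream, so plain `pickle.loads` works on the worker side; only the producer needs the cloudpickle package.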
