This is an automated email from the ASF dual-hosted git repository.
lidavidm pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-adbc.git
The following commit(s) were added to refs/heads/main by this push:
new 13d2759fb feat(java/driver/jni): implement incremental execution,
Substrait (#4362)
13d2759fb is described below
commit 13d2759fb97712a824db612bc38c91558abf7ad8
Author: David Li <[email protected]>
AuthorDate: Mon Jun 8 17:29:01 2026 -0700
feat(java/driver/jni): implement incremental execution, Substrait (#4362)
Also, fix an incidental bug in the Go driver FFI layer.
Closes #3257.
---
.github/workflows/java.yml | 2 +-
go/adbc/driver/flightsql/cmd/testserver/main.go | 32 +++-
go/adbc/pkg/_tmpl/driver.go.tmpl | 4 +-
go/adbc/pkg/flightsql/driver.go | 4 +-
go/adbc/pkg/panicdummy/driver.go | 4 +-
.../adbc/driver/jni/FlightSqlIntegrationTest.java | 122 ++++++++++++--
...st.java => FlightSqlSqliteIntegrationTest.java} | 3 +-
java/driver/jni/src/main/cpp/jni_wrapper.cc | 182 +++++++++++++--------
.../apache/arrow/adbc/driver/jni/JniDriver.java | 6 +
.../apache/arrow/adbc/driver/jni/JniStatement.java | 59 +++++++
.../arrow/adbc/driver/jni/impl/JniLoader.java | 5 +
.../arrow/adbc/driver/jni/impl/NativeAdbc.java | 2 +
12 files changed, 333 insertions(+), 92 deletions(-)
diff --git a/.github/workflows/java.yml b/.github/workflows/java.yml
index f2a3f0a20..d2ac4b9cf 100644
--- a/.github/workflows/java.yml
+++ b/.github/workflows/java.yml
@@ -233,7 +233,7 @@ jobs:
- name: Start Dependencies
if: matrix.os == 'Linux' && matrix.arch == 'amd64'
run: |
- docker compose up --detach --wait flightsql-sqlite-test mssql-test
postgres-test
+ docker compose up --detach --wait flightsql-test
flightsql-sqlite-test mssql-test postgres-test
cat .env | grep -v -e '^#' | grep -e '^ADBC_' | awk NF | sed
's/"//g' | tee -a $GITHUB_ENV
- name: Download thirdparty driver
diff --git a/go/adbc/driver/flightsql/cmd/testserver/main.go
b/go/adbc/driver/flightsql/cmd/testserver/main.go
index 577bf3983..5a38b3e11 100644
--- a/go/adbc/driver/flightsql/cmd/testserver/main.go
+++ b/go/adbc/driver/flightsql/cmd/testserver/main.go
@@ -38,6 +38,7 @@ import (
"math/big"
"net"
"os"
+ "slices"
"strconv"
"strings"
"sync"
@@ -79,6 +80,9 @@ var recordedHeadersSchema = arrow.NewSchema([]arrow.Field{
{Name: "value", Type: arrow.BinaryTypes.String, Nullable: false},
}, nil)
+// Not a real plan, just ensuring binary values get passed through correctly
+var substraitPlan = []byte{42, 0, 129, 255}
+
func StatusWithDetail(code codes.Code, message string, details
...proto.Message) error {
p := status.New(code, message).Proto()
// Have to do this by hand because gRPC uses deprecated proto import
@@ -184,6 +188,22 @@ func (srv *ExampleServer) GetFlightInfoStatement(ctx
context.Context, cmd flight
}, nil
}
+func (srv *ExampleServer) GetFlightInfoSubstraitPlan(ctx context.Context, cmd
flightsql.StatementSubstraitPlan, desc *flight.FlightDescriptor)
(*flight.FlightInfo, error) {
+ srv.recordHeaders(ctx, "GetFlightInfoSubstraitPlan")
+ log.Printf("GetFlightInfoSubstraitPlan: %v", cmd.GetPlan().Plan)
+ ticket, err := flightsql.CreateStatementQueryTicket(cmd.GetPlan().Plan)
+ if err != nil {
+ return nil, err
+ }
+
+ return &flight.FlightInfo{
+ Endpoint: []*flight.FlightEndpoint{{Ticket:
&flight.Ticket{Ticket: ticket}}},
+ FlightDescriptor: desc,
+ TotalRecords: -1,
+ TotalBytes: -1,
+ }, nil
+}
+
func (srv *ExampleServer) PollFlightInfo(ctx context.Context, desc
*flight.FlightDescriptor) (*flight.PollInfo, error) {
srv.mu.Lock()
defer srv.mu.Unlock()
@@ -406,8 +426,16 @@ func (srv *ExampleServer) DoGetPreparedStatement(ctx
context.Context, cmd flight
}
func (srv *ExampleServer) DoGetStatement(ctx context.Context, cmd
flightsql.StatementQueryTicket) (schema *arrow.Schema, out <-chan
flight.StreamChunk, err error) {
- schema = arrow.NewSchema([]arrow.Field{{Name: "ints", Type:
arrow.PrimitiveTypes.Int32, Nullable: true}}, nil)
- rec, _, err := array.RecordFromJSON(memory.DefaultAllocator, schema,
strings.NewReader(`[{"ints": 5}]`))
+ log.Printf("DoGetStatement: %v", cmd.GetStatementHandle())
+
+ var rec arrow.RecordBatch
+ if slices.Equal(cmd.GetStatementHandle(), substraitPlan) {
+ schema = arrow.NewSchema([]arrow.Field{{Name: "substrait",
Type: arrow.PrimitiveTypes.Int32, Nullable: true}}, nil)
+ rec, _, err = array.RecordFromJSON(memory.DefaultAllocator,
schema, strings.NewReader(`[{"substrait": 5}]`))
+ } else {
+ schema = arrow.NewSchema([]arrow.Field{{Name: "ints", Type:
arrow.PrimitiveTypes.Int32, Nullable: true}}, nil)
+ rec, _, err = array.RecordFromJSON(memory.DefaultAllocator,
schema, strings.NewReader(`[{"ints": 5}]`))
+ }
ch := make(chan flight.StreamChunk)
go func() {
diff --git a/go/adbc/pkg/_tmpl/driver.go.tmpl b/go/adbc/pkg/_tmpl/driver.go.tmpl
index 05b2b11a3..673e73265 100644
--- a/go/adbc/pkg/_tmpl/driver.go.tmpl
+++ b/go/adbc/pkg/_tmpl/driver.go.tmpl
@@ -1618,7 +1618,9 @@ func {{.Prefix}}StatementSetSubstraitPlan(stmt
*C.struct_AdbcStatement, plan *C.
return C.ADBC_STATUS_INVALID_STATE
}
- return C.AdbcStatusCode(errToAdbcErr(err,
st.stmt.SetSubstraitPlan(fromCArr[byte](plan, int(length)))))
+ // XXX: must copy here or the C side may invalidate *plan
+ goPlan := C.GoBytes(unsafe.Pointer(plan), C.int(length))
+ return C.AdbcStatusCode(errToAdbcErr(err,
st.stmt.SetSubstraitPlan(goPlan)))
}
//export {{.Prefix}}StatementBind
diff --git a/go/adbc/pkg/flightsql/driver.go b/go/adbc/pkg/flightsql/driver.go
index 3cecf4f79..46506c415 100644
--- a/go/adbc/pkg/flightsql/driver.go
+++ b/go/adbc/pkg/flightsql/driver.go
@@ -1621,7 +1621,9 @@ func FlightSQLStatementSetSubstraitPlan(stmt
*C.struct_AdbcStatement, plan *C.cu
return C.ADBC_STATUS_INVALID_STATE
}
- return C.AdbcStatusCode(errToAdbcErr(err,
st.stmt.SetSubstraitPlan(fromCArr[byte](plan, int(length)))))
+ // XXX: must copy here or the C side may invalidate *plan
+ goPlan := C.GoBytes(unsafe.Pointer(plan), C.int(length))
+ return C.AdbcStatusCode(errToAdbcErr(err,
st.stmt.SetSubstraitPlan(goPlan)))
}
//export FlightSQLStatementBind
diff --git a/go/adbc/pkg/panicdummy/driver.go b/go/adbc/pkg/panicdummy/driver.go
index 07bbc4ac6..2deea3f88 100644
--- a/go/adbc/pkg/panicdummy/driver.go
+++ b/go/adbc/pkg/panicdummy/driver.go
@@ -1621,7 +1621,9 @@ func PanicDummyStatementSetSubstraitPlan(stmt
*C.struct_AdbcStatement, plan *C.c
return C.ADBC_STATUS_INVALID_STATE
}
- return C.AdbcStatusCode(errToAdbcErr(err,
st.stmt.SetSubstraitPlan(fromCArr[byte](plan, int(length)))))
+ // XXX: must copy here or the C side may invalidate *plan
+ goPlan := C.GoBytes(unsafe.Pointer(plan), C.int(length))
+ return C.AdbcStatusCode(errToAdbcErr(err,
st.stmt.SetSubstraitPlan(goPlan)))
}
//export PanicDummyStatementBind
diff --git
a/java/driver/jni-validation/src/test/java/org/apache/arrow/adbc/driver/jni/FlightSqlIntegrationTest.java
b/java/driver/jni-validation/src/test/java/org/apache/arrow/adbc/driver/jni/FlightSqlIntegrationTest.java
index b5ab1c7ee..c63bf8dd0 100644
---
a/java/driver/jni-validation/src/test/java/org/apache/arrow/adbc/driver/jni/FlightSqlIntegrationTest.java
+++
b/java/driver/jni-validation/src/test/java/org/apache/arrow/adbc/driver/jni/FlightSqlIntegrationTest.java
@@ -18,22 +18,34 @@
package org.apache.arrow.adbc.driver.jni;
import static org.assertj.core.api.Assertions.assertThat;
+import static org.assertj.core.api.Assertions.assertThatThrownBy;
+import java.nio.ByteBuffer;
+import java.util.ArrayList;
import java.util.HashMap;
+import java.util.List;
import java.util.Map;
import org.apache.arrow.adbc.core.AdbcConnection;
import org.apache.arrow.adbc.core.AdbcDatabase;
import org.apache.arrow.adbc.core.AdbcDriver;
+import org.apache.arrow.adbc.core.AdbcStatement;
+import org.apache.arrow.adbc.driver.testsuite.ArrowAssertions;
+import org.apache.arrow.adbc.driver.testsuite.ArrowToJava;
import org.apache.arrow.memory.BufferAllocator;
import org.apache.arrow.memory.RootAllocator;
+import org.apache.arrow.vector.types.Types;
+import org.apache.arrow.vector.types.pojo.Field;
+import org.apache.arrow.vector.types.pojo.Schema;
+import org.assertj.core.api.Assertions;
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.Assumptions;
import org.junit.jupiter.api.BeforeAll;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
+/** Integration tests against the mock test server. */
public class FlightSqlIntegrationTest {
- public static final String URI_ENV = "ADBC_SQLITE_FLIGHTSQL_URI";
+ public static final String URI_ENV = "ADBC_TEST_FLIGHTSQL_URI";
static String URI = System.getenv(URI_ENV);
BufferAllocator allocator;
@@ -55,6 +67,8 @@ public class FlightSqlIntegrationTest {
Map<String, Object> parameters = new HashMap<>();
JniDriver.PARAM_DRIVER.set(parameters, "adbc_driver_flightsql");
AdbcDriver.PARAM_URI.set(parameters, URI);
+ parameters.put("adbc.flight.sql.client_option.tls_skip_verify", "true");
+ parameters.put("adbc.flight.sql.authorization_header", "Basic
dXNlcjpwYXNzd29yZA==");
db = driver.open(parameters);
conn = db.connect();
}
@@ -69,30 +83,106 @@ public class FlightSqlIntegrationTest {
@Test
void simple() throws Exception {
try (var stmt = conn.createStatement()) {
+ // smoke test; the server doesn't actually run any query and returns a
fixed result
stmt.setSqlQuery("SELECT 1 + 1 AS sum");
try (var reader = stmt.executeQuery()) {
- assertThat(reader.getReader().loadNextBatch()).isTrue();
-
assertThat(reader.getReader().getVectorSchemaRoot().getVector("sum").getObject(0))
- .isEqualTo(2L);
+ assertThat(ArrowToJava.toIntegers(reader.getReader(),
"ints")).containsExactly(5);
}
}
}
@Test
- void partitioned() throws Exception {
+ void substrait() throws Exception {
try (var stmt = conn.createStatement()) {
- stmt.setSqlQuery("SELECT 1 + 1 AS sum");
- var partitions = stmt.executePartitioned();
- assertThat(partitions.getPartitionDescriptors().size()).isEqualTo(1);
- assertThat(partitions.getAffectedRows()).isEqualTo(-1);
- // The test server doesn't give a schema.
- assertThat(partitions.getSchema()).isNull();
-
- try (var reader =
-
conn.readPartition(partitions.getPartitionDescriptors().get(0).getDescriptor()))
{
- assertThat(reader.loadNextBatch()).isTrue();
-
assertThat(reader.getVectorSchemaRoot().getVector("sum").getObject(0)).isEqualTo(2L);
+ stmt.setSubstraitPlan(ByteBuffer.wrap(new byte[] {42, 0, (byte) 129,
(byte) 255}));
+ try (var reader = stmt.executeQuery()) {
+ assertThat(ArrowToJava.toIntegers(reader.getReader(),
"substrait")).containsExactly(5);
}
}
}
+
+ @Test
+ void poll() throws Exception {
+ var expectedSchema = new Schema(List.of(Field.nullable("ints",
Types.MinorType.INT.getType())));
+ try (var stmt = conn.createStatement()) {
+ stmt.setSqlQuery("poll");
+ stmt.prepare();
+ var iter = stmt.pollPartitioned();
+ var elements = new ArrayList<AdbcStatement.PartitionResult>();
+ iter.forEachRemaining(elements::add);
+ assertThat(elements).size().isEqualTo(5);
+ assertThat(elements.get(0).getSchema()).isNull();
+ assertThat(elements.get(1).getSchema()).isNull();
+
ArrowAssertions.assertSchema(elements.get(2).getSchema()).isEqualTo(expectedSchema);
+
ArrowAssertions.assertSchema(elements.get(3).getSchema()).isEqualTo(expectedSchema);
+
ArrowAssertions.assertSchema(elements.get(4).getSchema()).isEqualTo(expectedSchema);
+
+ elements.forEach(
+ partitionResult -> {
+
assertThat(partitionResult.getPartitionDescriptors()).size().isEqualTo(1);
+ });
+ }
+ }
+
+ @Test
+ void progress() throws Exception {
+ try (var stmt = conn.createStatement()) {
+ stmt.setSqlQuery("poll");
+ stmt.prepare();
+ var iter = stmt.pollPartitioned();
+ assertThat(stmt.getProgress()).isEqualTo(0.0);
+ assertThat(stmt.getMaxProgress()).isEqualTo(1.0);
+
+ for (int i = 0; i < 5; i++) {
+ assertThat(iter.hasNext()).isTrue();
+ assertThat(stmt.getProgress()).isCloseTo(0.2 * (i + 1),
Assertions.offset(0.05));
+ assertThat(stmt.getMaxProgress()).isEqualTo(1.0);
+ assertThat(iter.next()).isNotNull();
+ }
+
+ assertThat(iter.hasNext()).isFalse();
+ }
+ }
+
+ @Test
+ void pollError() throws Exception {
+ try (var stmt = conn.createStatement()) {
+ stmt.setSqlQuery("error_poll_later");
+ stmt.prepare();
+ var iter = stmt.pollPartitioned();
+ assertThat(iter.hasNext()).isTrue();
+ assertThat(iter.next()).isNotNull();
+ assertThatThrownBy(iter::hasNext).hasMessageContaining("expected error");
+ }
+ }
+
+ @Test
+ void pollForever() throws Exception {
+ try (var stmt = conn.createStatement()) {
+ stmt.setSqlQuery("forever");
+ stmt.prepare();
+ var iter = stmt.pollPartitioned();
+ assertThat(iter.hasNext()).isTrue();
+ assertThat(iter.next()).isNotNull();
+ // XXX: rather janky. must cancel from a background thread; cancelling in
between statements has no
+ // effect. we should perhaps reconsider this behavior
+ var t =
+ new Thread(
+ () -> {
+ try {
+ while (stmt.getProgress() < 0.05) {
+ Thread.sleep(100);
+ }
+ System.out.println("progress: " + stmt.getProgress());
+ stmt.cancel();
+ } catch (Exception e) {
+ e.printStackTrace();
+ }
+ });
+ t.start();
+
+ assertThatThrownBy(iter::hasNext).hasMessageContaining("context
canceled");
+ t.join(10000);
+ }
+ }
}
diff --git
a/java/driver/jni-validation/src/test/java/org/apache/arrow/adbc/driver/jni/FlightSqlIntegrationTest.java
b/java/driver/jni-validation/src/test/java/org/apache/arrow/adbc/driver/jni/FlightSqlSqliteIntegrationTest.java
similarity index 97%
copy from
java/driver/jni-validation/src/test/java/org/apache/arrow/adbc/driver/jni/FlightSqlIntegrationTest.java
copy to
java/driver/jni-validation/src/test/java/org/apache/arrow/adbc/driver/jni/FlightSqlSqliteIntegrationTest.java
index b5ab1c7ee..e73f7bb6a 100644
---
a/java/driver/jni-validation/src/test/java/org/apache/arrow/adbc/driver/jni/FlightSqlIntegrationTest.java
+++
b/java/driver/jni-validation/src/test/java/org/apache/arrow/adbc/driver/jni/FlightSqlSqliteIntegrationTest.java
@@ -32,7 +32,8 @@ import org.junit.jupiter.api.BeforeAll;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
-public class FlightSqlIntegrationTest {
+/** Integration tests against the SQLite server. */
+public class FlightSqlSqliteIntegrationTest {
public static final String URI_ENV = "ADBC_SQLITE_FLIGHTSQL_URI";
static String URI = System.getenv(URI_ENV);
diff --git a/java/driver/jni/src/main/cpp/jni_wrapper.cc
b/java/driver/jni/src/main/cpp/jni_wrapper.cc
index d1f8309b3..7a8a73d53 100644
--- a/java/driver/jni/src/main/cpp/jni_wrapper.cc
+++ b/java/driver/jni/src/main/cpp/jni_wrapper.cc
@@ -208,6 +208,93 @@ auto WithJniString(JNIEnv* env, jstring jni_string,
Callable&& callable) {
return callable(view.value);
}
+// Get the contents of a ByteBuffer. We may have to copy, so also take a
+// scratch buffer to hold the copy if needed. If possible, will use
+// DirectByteBuffer so no additional allocation is necessary.
+const uint8_t* GetJniByteBuffer(JNIEnv* env, jobject bytebuffer,
+ std::vector<uint8_t>& scratch,
+ size_t& serialized_length) {
+ jclass bb_class = RequireClass(env, "java/nio/ByteBuffer");
+ if (!env->IsInstanceOf(bytebuffer, bb_class)) {
+ ThrowJavaException(env, "java/lang/IllegalArgumentException",
+ "Argument must be a ByteBuffer");
+ return nullptr;
+ }
+
+ jmethodID bb_remaining = RequireMethod(env, bb_class, "remaining", "()I");
+ jmethodID bb_position = RequireMethod(env, bb_class, "position", "()I");
+
+ jint remaining = env->CallIntMethod(bytebuffer, bb_remaining);
+ if (remaining < 0) {
+ ThrowJavaException(env, "java/lang/IllegalArgumentException",
+ "ByteBuffer remaining() must be non-negative");
+ return nullptr;
+ } else if (env->ExceptionCheck()) {
+ return nullptr;
+ }
+ serialized_length = static_cast<size_t>(remaining);
+
+ jint position = env->CallIntMethod(bytebuffer, bb_position);
+ if (position < 0) {
+ ThrowJavaException(env, "java/lang/IllegalArgumentException",
+ "ByteBuffer position() must be non-negative");
+ return nullptr;
+ } else if (env->ExceptionCheck()) {
+ return nullptr;
+ }
+ const size_t offset = static_cast<size_t>(position);
+
+ // fast path (if direct buffer)
+ void* buf = env->GetDirectBufferAddress(bytebuffer);
+ if (buf) {
+ return static_cast<const uint8_t*>(buf) + offset;
+ }
+
+ // middle path (copy from backing array)
+ {
+ jmethodID bb_has_array = RequireMethod(env, bb_class, "hasArray", "()Z");
+ jmethodID bb_array = RequireMethod(env, bb_class, "array", "()[B");
+ jmethodID bb_array_offset = RequireMethod(env, bb_class, "arrayOffset",
"()I");
+ jboolean has_array = env->CallBooleanMethod(bytebuffer, bb_has_array);
+ if (env->ExceptionCheck()) return nullptr;
+ if (has_array) {
+ jint array_offset = env->CallIntMethod(bytebuffer, bb_array_offset);
+ if (env->ExceptionCheck()) return nullptr;
+
+ auto array =
+ reinterpret_cast<jbyteArray>(env->CallObjectMethod(bytebuffer,
bb_array));
+ if (env->ExceptionCheck()) return nullptr;
+
+ assert(serialized_length <=
static_cast<size_t>(env->GetArrayLength(array)));
+ scratch.resize(serialized_length);
+ env->GetByteArrayRegion(array, array_offset + position,
+ static_cast<jsize>(serialized_length),
+ reinterpret_cast<jbyte*>(scratch.data()));
+ return scratch.data();
+ }
+ }
+
+ // slow path (copy via invoking Java code to copy into a temp array)
+ {
+ jmethodID bb_get = RequireMethod(env, bb_class, "get",
"([B)Ljava/nio/ByteBuffer;");
+ jbyteArray temp = env->NewByteArray(static_cast<jsize>(serialized_length));
+ if (!temp) {
+ ThrowJavaException(env, "java/lang/OutOfMemoryError",
+ "Failed to allocate byte array to copy ByteBuffer");
+ return nullptr;
+ }
+
+ env->CallObjectMethod(bytebuffer, bb_get, temp);
+ if (env->ExceptionCheck()) return nullptr;
+
+ scratch.resize(serialized_length);
+ env->GetByteArrayRegion(temp, 0, static_cast<jsize>(serialized_length),
+ reinterpret_cast<jbyte*>(scratch.data()));
+ return scratch.data();
+ }
+ // not reachable
+}
+
} // namespace
extern "C" {
@@ -475,9 +562,9 @@
Java_org_apache_arrow_adbc_driver_jni_impl_NativeAdbc_statementExecuteQuery(
JNIEXPORT void JNICALL
Java_org_apache_arrow_adbc_driver_jni_impl_NativeAdbc_statementSetSqlQuery(
JNIEnv* env, [[maybe_unused]] jclass self, jlong handle, jstring query) {
+ struct AdbcError error = ADBC_ERROR_INIT;
+ auto* ptr = reinterpret_cast<struct
AdbcStatement*>(static_cast<uintptr_t>(handle));
try {
- struct AdbcError error = ADBC_ERROR_INIT;
- auto* ptr = reinterpret_cast<struct
AdbcStatement*>(static_cast<uintptr_t>(handle));
JniStringView query_str(env, query);
CHECK_ADBC_ERROR(AdbcStatementSetSqlQuery(ptr, query_str.value, &error),
error);
} catch (const AdbcException& e) {
@@ -485,6 +572,26 @@
Java_org_apache_arrow_adbc_driver_jni_impl_NativeAdbc_statementSetSqlQuery(
}
}
+JNIEXPORT void JNICALL
+Java_org_apache_arrow_adbc_driver_jni_impl_NativeAdbc_statementSetSubstraitPlan(
+ JNIEnv* env, [[maybe_unused]] jclass self, jlong handle, jobject plan) {
+ struct AdbcError error = ADBC_ERROR_INIT;
+ auto* ptr = reinterpret_cast<struct
AdbcStatement*>(static_cast<uintptr_t>(handle));
+ std::vector<uint8_t> allocated_plan;
+ size_t plan_length = 0;
+ try {
+ const uint8_t* plan_ptr = GetJniByteBuffer(env, plan, allocated_plan,
plan_length);
+ if (!plan_ptr || env->ExceptionCheck()) {
+ return; // GetJniByteBuffer failed
+ }
+ assert(plan_ptr != nullptr);
+ CHECK_ADBC_ERROR(AdbcStatementSetSubstraitPlan(ptr, plan_ptr, plan_length,
&error),
+ error);
+ } catch (const AdbcException& e) {
+ e.ThrowJavaException(env);
+ }
+}
+
JNIEXPORT void JNICALL
Java_org_apache_arrow_adbc_driver_jni_impl_NativeAdbc_statementBind(
JNIEnv* env, [[maybe_unused]] jclass self, jlong handle, jlong values,
jlong schema) {
@@ -1059,81 +1166,18 @@
Java_org_apache_arrow_adbc_driver_jni_impl_NativeAdbc_connectionReadPartition(
auto* conn = reinterpret_cast<struct
AdbcConnection*>(static_cast<uintptr_t>(handle));
struct ArrowArrayStream out = {};
size_t serialized_length = 0;
- const uint8_t* serialized_partition = nullptr;
std::vector<uint8_t> allocated_partition;
try {
- jclass bb_class = RequireClass(env, "java/nio/ByteBuffer");
- jmethodID bb_remaining = RequireMethod(env, bb_class, "remaining", "()I");
-
- if (!env->IsInstanceOf(partition, bb_class)) {
- ThrowJavaException(env, "java/lang/IllegalArgumentException",
- "Partition must be a ByteBuffer");
- return nullptr;
- }
- jint remaining = env->CallIntMethod(partition, bb_remaining);
- if (remaining < 0) {
- ThrowJavaException(env, "java/lang/IllegalArgumentException",
- "ByteBuffer remaining() must be non-negative");
- return nullptr;
+ const uint8_t* serialized_partition =
+ GetJniByteBuffer(env, partition, allocated_partition,
serialized_length);
+ if (!serialized_partition || env->ExceptionCheck()) {
+ return nullptr; // GetJniByteBuffer failed
}
- serialized_length = static_cast<size_t>(remaining);
-
- // fast path (if direct buffer)
- void* buf = env->GetDirectBufferAddress(partition);
- if (buf) {
- serialized_partition = static_cast<const uint8_t*>(buf);
- }
-
- // middle path (backing array)
- if (!serialized_partition) {
- jmethodID bb_has_array = RequireMethod(env, bb_class, "hasArray", "()Z");
- jmethodID bb_array = RequireMethod(env, bb_class, "array", "()[B");
- jmethodID bb_array_offset = RequireMethod(env, bb_class, "arrayOffset",
"()I");
- jboolean has_array = env->CallBooleanMethod(partition, bb_has_array);
- if (env->ExceptionCheck()) return nullptr;
- if (has_array) {
- jint array_offset = env->CallIntMethod(partition, bb_array_offset);
- if (env->ExceptionCheck()) return nullptr;
-
- auto array =
- reinterpret_cast<jbyteArray>(env->CallObjectMethod(partition,
bb_array));
- if (env->ExceptionCheck()) return nullptr;
-
- assert(serialized_length <=
static_cast<size_t>(env->GetArrayLength(array)));
- allocated_partition.resize(serialized_length);
- env->GetByteArrayRegion(array, array_offset,
- static_cast<jsize>(serialized_length),
-
reinterpret_cast<jbyte*>(allocated_partition.data()));
- serialized_partition = allocated_partition.data();
- }
- }
-
- // slow path (copy)
- if (!serialized_partition) {
- jmethodID bb_get = RequireMethod(env, bb_class, "get",
"([B)Ljava/nio/ByteBuffer;");
- jbyteArray temp =
env->NewByteArray(static_cast<jsize>(serialized_length));
- if (!temp) {
- ThrowJavaException(env, "java/lang/OutOfMemoryError",
- "Failed to allocate byte array for partition");
- return nullptr;
- }
-
- env->CallVoidMethod(partition, bb_get, temp);
- if (env->ExceptionCheck()) return nullptr;
-
- allocated_partition.resize(serialized_length);
- env->GetByteArrayRegion(temp, 0, static_cast<jsize>(serialized_length),
-
reinterpret_cast<jbyte*>(allocated_partition.data()));
- serialized_partition = allocated_partition.data();
- }
-
assert(serialized_partition != nullptr);
-
CHECK_ADBC_ERROR(AdbcConnectionReadPartition(conn, serialized_partition,
serialized_length, &out,
&error),
error);
-
return MakeNativeQueryResult(env, -1, &out);
} catch (const AdbcException& e) {
e.ThrowJavaException(env);
diff --git
a/java/driver/jni/src/main/java/org/apache/arrow/adbc/driver/jni/JniDriver.java
b/java/driver/jni/src/main/java/org/apache/arrow/adbc/driver/jni/JniDriver.java
index f118f8b08..9f33e0193 100644
---
a/java/driver/jni/src/main/java/org/apache/arrow/adbc/driver/jni/JniDriver.java
+++
b/java/driver/jni/src/main/java/org/apache/arrow/adbc/driver/jni/JniDriver.java
@@ -63,6 +63,12 @@ public class JniDriver implements AdbcDriver {
static final String ISOLATION_LEVEL_SNAPSHOT =
"adbc.connection.transaction.isolation.snapshot";
static final String ISOLATION_LEVEL_SERIALIZABLE =
"adbc.connection.transaction.isolation.serializable";
+ static final TypedKey<Double> PROGRESS =
+ new TypedKey<>("adbc.statement.exec.progress", Double.class);
+ static final TypedKey<Double> MAX_PROGRESS =
+ new TypedKey<>("adbc.statement.exec.max_progress", Double.class);
+ static final TypedKey<Boolean> INCREMENTAL =
+ new TypedKey<>("adbc.statement.exec.incremental", Boolean.class);
static final TypedKey<String> CURRENT_CATALOG =
new TypedKey<>("adbc.connection.catalog", String.class);
diff --git
a/java/driver/jni/src/main/java/org/apache/arrow/adbc/driver/jni/JniStatement.java
b/java/driver/jni/src/main/java/org/apache/arrow/adbc/driver/jni/JniStatement.java
index d3bc16ebc..0e5da15b2 100644
---
a/java/driver/jni/src/main/java/org/apache/arrow/adbc/driver/jni/JniStatement.java
+++
b/java/driver/jni/src/main/java/org/apache/arrow/adbc/driver/jni/JniStatement.java
@@ -18,6 +18,9 @@
package org.apache.arrow.adbc.driver.jni;
import java.io.IOException;
+import java.nio.ByteBuffer;
+import java.util.Iterator;
+import java.util.NoSuchElementException;
import org.apache.arrow.adbc.core.AdbcException;
import org.apache.arrow.adbc.core.AdbcStatement;
import org.apache.arrow.adbc.core.TypedKey;
@@ -57,6 +60,11 @@ public class JniStatement implements AdbcStatement {
JniLoader.INSTANCE.statementSetSqlQuery(handle, query);
}
+ @Override
+ public void setSubstraitPlan(ByteBuffer plan) throws AdbcException {
+ JniLoader.INSTANCE.statementSetSubstraitPlan(handle, plan);
+ }
+
@Override
public void bind(VectorSchemaRoot root) throws AdbcException {
clearBind();
@@ -86,6 +94,7 @@ public class JniStatement implements AdbcStatement {
// The C Data export takes ownership of the data at bind time and ignores
subsequent
// client changes to the bound root. Defer the export until execution so we
capture
// the final state of the VectorSchemaRoot.
+
private void exportBind() throws AdbcException {
if (bindRoot != null) {
try (final ArrowArray batch = ArrowArray.allocateNew(allocator);
@@ -143,6 +152,22 @@ public class JniStatement implements AdbcStatement {
JniLoader.INSTANCE.statementPrepare(handle);
}
+ public double getProgress() throws AdbcException {
+ return getOption(JniDriver.PROGRESS);
+ }
+
+ @Override
+ public double getMaxProgress() throws AdbcException {
+ return getOption(JniDriver.MAX_PROGRESS);
+ }
+
+ @Override
+ public Iterator<PartitionResult> pollPartitioned() throws AdbcException {
+ exportBind();
+ setOption(JniDriver.INCREMENTAL, true);
+ return new PartitionPollIterator();
+ }
+
@Override
public void close() throws AdbcException {
try {
@@ -206,4 +231,38 @@ public class JniStatement implements AdbcStatement {
"[jni] unsupported statement option type " + value.getClass());
}
}
+
+ public final class PartitionPollIterator implements
Iterator<PartitionResult> {
+ @Nullable PartitionResult nextResult;
+ boolean finished = false;
+
+ @Override
+ public boolean hasNext() {
+ if (finished) return false;
+ if (nextResult == null) {
+ NativePartitionResult result;
+ try {
+ result = JniLoader.INSTANCE.statementExecutePartitions(handle);
+ } catch (AdbcException e) {
+ throw new RuntimeException(e);
+ }
+ nextResult = result.importResult(allocator);
+ }
+ if (nextResult.getPartitionDescriptors().isEmpty()) {
+ finished = true;
+ return false;
+ }
+ return true;
+ }
+
+ @Override
+ public PartitionResult next() {
+ if (!hasNext()) {
+ throw new NoSuchElementException("No more partitions");
+ }
+ PartitionResult result = nextResult;
+ nextResult = null;
+ return result;
+ }
+ }
}
diff --git
a/java/driver/jni/src/main/java/org/apache/arrow/adbc/driver/jni/impl/JniLoader.java
b/java/driver/jni/src/main/java/org/apache/arrow/adbc/driver/jni/impl/JniLoader.java
index 500429a88..3174655c5 100644
---
a/java/driver/jni/src/main/java/org/apache/arrow/adbc/driver/jni/impl/JniLoader.java
+++
b/java/driver/jni/src/main/java/org/apache/arrow/adbc/driver/jni/impl/JniLoader.java
@@ -102,6 +102,11 @@ public enum JniLoader {
NativeAdbc.statementSetSqlQuery(statement.getStatementHandle(), query);
}
+ public void statementSetSubstraitPlan(NativeStatementHandle statement,
ByteBuffer plan)
+ throws AdbcException {
+ NativeAdbc.statementSetSubstraitPlan(statement.getStatementHandle(), plan);
+ }
+
public void statementBind(NativeStatementHandle statement, ArrowArray batch,
ArrowSchema schema)
throws AdbcException {
NativeAdbc.statementBind(
diff --git
a/java/driver/jni/src/main/java/org/apache/arrow/adbc/driver/jni/impl/NativeAdbc.java
b/java/driver/jni/src/main/java/org/apache/arrow/adbc/driver/jni/impl/NativeAdbc.java
index a7f2b4fbf..c2059532f 100644
---
a/java/driver/jni/src/main/java/org/apache/arrow/adbc/driver/jni/impl/NativeAdbc.java
+++
b/java/driver/jni/src/main/java/org/apache/arrow/adbc/driver/jni/impl/NativeAdbc.java
@@ -59,6 +59,8 @@ class NativeAdbc {
static native void statementSetSqlQuery(long handle, String query) throws
AdbcException;
+ static native void statementSetSubstraitPlan(long handle, ByteBuffer plan)
throws AdbcException;
+
static native byte[] statementGetOptionBytes(long handle, String key) throws
AdbcException;
static native double statementGetOptionDouble(long handle, String key)
throws AdbcException;