This is an automated email from the ASF dual-hosted git repository.

lidavidm pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-adbc.git


The following commit(s) were added to refs/heads/main by this push:
     new 13d2759fb feat(java/driver/jni): implement incremental execution, Substrait (#4362)
13d2759fb is described below

commit 13d2759fb97712a824db612bc38c91558abf7ad8
Author: David Li <[email protected]>
AuthorDate: Mon Jun 8 17:29:01 2026 -0700

    feat(java/driver/jni): implement incremental execution, Substrait (#4362)
    
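    Also, fix an incidental bug in the Go driver FFI layer: StatementSetSubstraitPlan
    now copies the plan bytes with C.GoBytes instead of aliasing C-owned memory,
    which the C caller may invalidate after the call returns.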
    
    Closes #3257.
---
 .github/workflows/java.yml                         |   2 +-
 go/adbc/driver/flightsql/cmd/testserver/main.go    |  32 +++-
 go/adbc/pkg/_tmpl/driver.go.tmpl                   |   4 +-
 go/adbc/pkg/flightsql/driver.go                    |   4 +-
 go/adbc/pkg/panicdummy/driver.go                   |   4 +-
 .../adbc/driver/jni/FlightSqlIntegrationTest.java  | 122 ++++++++++++--
 ...st.java => FlightSqlSqliteIntegrationTest.java} |   3 +-
 java/driver/jni/src/main/cpp/jni_wrapper.cc        | 182 +++++++++++++--------
 .../apache/arrow/adbc/driver/jni/JniDriver.java    |   6 +
 .../apache/arrow/adbc/driver/jni/JniStatement.java |  59 +++++++
 .../arrow/adbc/driver/jni/impl/JniLoader.java      |   5 +
 .../arrow/adbc/driver/jni/impl/NativeAdbc.java     |   2 +
 12 files changed, 333 insertions(+), 92 deletions(-)

diff --git a/.github/workflows/java.yml b/.github/workflows/java.yml
index f2a3f0a20..d2ac4b9cf 100644
--- a/.github/workflows/java.yml
+++ b/.github/workflows/java.yml
@@ -233,7 +233,7 @@ jobs:
       - name: Start Dependencies
         if: matrix.os == 'Linux' && matrix.arch == 'amd64'
         run: |
-          docker compose up --detach --wait flightsql-sqlite-test mssql-test 
postgres-test
+          docker compose up --detach --wait flightsql-test 
flightsql-sqlite-test mssql-test postgres-test
           cat .env | grep -v -e '^#' | grep -e '^ADBC_' | awk NF | sed 
's/"//g' | tee -a $GITHUB_ENV
 
       - name: Download thirdparty driver
diff --git a/go/adbc/driver/flightsql/cmd/testserver/main.go 
b/go/adbc/driver/flightsql/cmd/testserver/main.go
index 577bf3983..5a38b3e11 100644
--- a/go/adbc/driver/flightsql/cmd/testserver/main.go
+++ b/go/adbc/driver/flightsql/cmd/testserver/main.go
@@ -38,6 +38,7 @@ import (
        "math/big"
        "net"
        "os"
+       "slices"
        "strconv"
        "strings"
        "sync"
@@ -79,6 +80,9 @@ var recordedHeadersSchema = arrow.NewSchema([]arrow.Field{
        {Name: "value", Type: arrow.BinaryTypes.String, Nullable: false},
 }, nil)
 
+// Not a real plan, just ensuring binary values get passed through correctly
+var substraitPlan = []byte{42, 0, 129, 255}
+
 func StatusWithDetail(code codes.Code, message string, details 
...proto.Message) error {
        p := status.New(code, message).Proto()
        // Have to do this by hand because gRPC uses deprecated proto import
@@ -184,6 +188,22 @@ func (srv *ExampleServer) GetFlightInfoStatement(ctx 
context.Context, cmd flight
        }, nil
 }
 
+func (srv *ExampleServer) GetFlightInfoSubstraitPlan(ctx context.Context, cmd 
flightsql.StatementSubstraitPlan, desc *flight.FlightDescriptor) 
(*flight.FlightInfo, error) {
+       srv.recordHeaders(ctx, "GetFlightInfoSubstraitPlan")
+       log.Printf("GetFlightInfoSubstraitPlan: %v", cmd.GetPlan().Plan)
+       ticket, err := flightsql.CreateStatementQueryTicket(cmd.GetPlan().Plan)
+       if err != nil {
+               return nil, err
+       }
+
+       return &flight.FlightInfo{
+               Endpoint:         []*flight.FlightEndpoint{{Ticket: 
&flight.Ticket{Ticket: ticket}}},
+               FlightDescriptor: desc,
+               TotalRecords:     -1,
+               TotalBytes:       -1,
+       }, nil
+}
+
 func (srv *ExampleServer) PollFlightInfo(ctx context.Context, desc 
*flight.FlightDescriptor) (*flight.PollInfo, error) {
        srv.mu.Lock()
        defer srv.mu.Unlock()
@@ -406,8 +426,16 @@ func (srv *ExampleServer) DoGetPreparedStatement(ctx 
context.Context, cmd flight
 }
 
 func (srv *ExampleServer) DoGetStatement(ctx context.Context, cmd 
flightsql.StatementQueryTicket) (schema *arrow.Schema, out <-chan 
flight.StreamChunk, err error) {
-       schema = arrow.NewSchema([]arrow.Field{{Name: "ints", Type: 
arrow.PrimitiveTypes.Int32, Nullable: true}}, nil)
-       rec, _, err := array.RecordFromJSON(memory.DefaultAllocator, schema, 
strings.NewReader(`[{"ints": 5}]`))
+       log.Printf("DoGetStatement: %v", cmd.GetStatementHandle())
+
+       var rec arrow.RecordBatch
+       if slices.Equal(cmd.GetStatementHandle(), substraitPlan) {
+               schema = arrow.NewSchema([]arrow.Field{{Name: "substrait", 
Type: arrow.PrimitiveTypes.Int32, Nullable: true}}, nil)
+               rec, _, err = array.RecordFromJSON(memory.DefaultAllocator, 
schema, strings.NewReader(`[{"substrait": 5}]`))
+       } else {
+               schema = arrow.NewSchema([]arrow.Field{{Name: "ints", Type: 
arrow.PrimitiveTypes.Int32, Nullable: true}}, nil)
+               rec, _, err = array.RecordFromJSON(memory.DefaultAllocator, 
schema, strings.NewReader(`[{"ints": 5}]`))
+       }
 
        ch := make(chan flight.StreamChunk)
        go func() {
diff --git a/go/adbc/pkg/_tmpl/driver.go.tmpl b/go/adbc/pkg/_tmpl/driver.go.tmpl
index 05b2b11a3..673e73265 100644
--- a/go/adbc/pkg/_tmpl/driver.go.tmpl
+++ b/go/adbc/pkg/_tmpl/driver.go.tmpl
@@ -1618,7 +1618,9 @@ func {{.Prefix}}StatementSetSubstraitPlan(stmt 
*C.struct_AdbcStatement, plan *C.
                return C.ADBC_STATUS_INVALID_STATE
        }
 
-       return C.AdbcStatusCode(errToAdbcErr(err, 
st.stmt.SetSubstraitPlan(fromCArr[byte](plan, int(length)))))
+       // XXX: must copy here or the C side may invalidate *plan
+       goPlan := C.GoBytes(unsafe.Pointer(plan), C.int(length))
+       return C.AdbcStatusCode(errToAdbcErr(err, 
st.stmt.SetSubstraitPlan(goPlan)))
 }
 
 //export {{.Prefix}}StatementBind
diff --git a/go/adbc/pkg/flightsql/driver.go b/go/adbc/pkg/flightsql/driver.go
index 3cecf4f79..46506c415 100644
--- a/go/adbc/pkg/flightsql/driver.go
+++ b/go/adbc/pkg/flightsql/driver.go
@@ -1621,7 +1621,9 @@ func FlightSQLStatementSetSubstraitPlan(stmt 
*C.struct_AdbcStatement, plan *C.cu
                return C.ADBC_STATUS_INVALID_STATE
        }
 
-       return C.AdbcStatusCode(errToAdbcErr(err, 
st.stmt.SetSubstraitPlan(fromCArr[byte](plan, int(length)))))
+       // XXX: must copy here or the C side may invalidate *plan
+       goPlan := C.GoBytes(unsafe.Pointer(plan), C.int(length))
+       return C.AdbcStatusCode(errToAdbcErr(err, 
st.stmt.SetSubstraitPlan(goPlan)))
 }
 
 //export FlightSQLStatementBind
diff --git a/go/adbc/pkg/panicdummy/driver.go b/go/adbc/pkg/panicdummy/driver.go
index 07bbc4ac6..2deea3f88 100644
--- a/go/adbc/pkg/panicdummy/driver.go
+++ b/go/adbc/pkg/panicdummy/driver.go
@@ -1621,7 +1621,9 @@ func PanicDummyStatementSetSubstraitPlan(stmt 
*C.struct_AdbcStatement, plan *C.c
                return C.ADBC_STATUS_INVALID_STATE
        }
 
-       return C.AdbcStatusCode(errToAdbcErr(err, 
st.stmt.SetSubstraitPlan(fromCArr[byte](plan, int(length)))))
+       // XXX: must copy here or the C side may invalidate *plan
+       goPlan := C.GoBytes(unsafe.Pointer(plan), C.int(length))
+       return C.AdbcStatusCode(errToAdbcErr(err, 
st.stmt.SetSubstraitPlan(goPlan)))
 }
 
 //export PanicDummyStatementBind
diff --git 
a/java/driver/jni-validation/src/test/java/org/apache/arrow/adbc/driver/jni/FlightSqlIntegrationTest.java
 
b/java/driver/jni-validation/src/test/java/org/apache/arrow/adbc/driver/jni/FlightSqlIntegrationTest.java
index b5ab1c7ee..c63bf8dd0 100644
--- 
a/java/driver/jni-validation/src/test/java/org/apache/arrow/adbc/driver/jni/FlightSqlIntegrationTest.java
+++ 
b/java/driver/jni-validation/src/test/java/org/apache/arrow/adbc/driver/jni/FlightSqlIntegrationTest.java
@@ -18,22 +18,34 @@
 package org.apache.arrow.adbc.driver.jni;
 
 import static org.assertj.core.api.Assertions.assertThat;
+import static org.assertj.core.api.Assertions.assertThatThrownBy;
 
+import java.nio.ByteBuffer;
+import java.util.ArrayList;
 import java.util.HashMap;
+import java.util.List;
 import java.util.Map;
 import org.apache.arrow.adbc.core.AdbcConnection;
 import org.apache.arrow.adbc.core.AdbcDatabase;
 import org.apache.arrow.adbc.core.AdbcDriver;
+import org.apache.arrow.adbc.core.AdbcStatement;
+import org.apache.arrow.adbc.driver.testsuite.ArrowAssertions;
+import org.apache.arrow.adbc.driver.testsuite.ArrowToJava;
 import org.apache.arrow.memory.BufferAllocator;
 import org.apache.arrow.memory.RootAllocator;
+import org.apache.arrow.vector.types.Types;
+import org.apache.arrow.vector.types.pojo.Field;
+import org.apache.arrow.vector.types.pojo.Schema;
+import org.assertj.core.api.Assertions;
 import org.junit.jupiter.api.AfterEach;
 import org.junit.jupiter.api.Assumptions;
 import org.junit.jupiter.api.BeforeAll;
 import org.junit.jupiter.api.BeforeEach;
 import org.junit.jupiter.api.Test;
 
+/** Integration tests against the mock test server. */
 public class FlightSqlIntegrationTest {
-  public static final String URI_ENV = "ADBC_SQLITE_FLIGHTSQL_URI";
+  public static final String URI_ENV = "ADBC_TEST_FLIGHTSQL_URI";
   static String URI = System.getenv(URI_ENV);
 
   BufferAllocator allocator;
@@ -55,6 +67,8 @@ public class FlightSqlIntegrationTest {
     Map<String, Object> parameters = new HashMap<>();
     JniDriver.PARAM_DRIVER.set(parameters, "adbc_driver_flightsql");
     AdbcDriver.PARAM_URI.set(parameters, URI);
+    parameters.put("adbc.flight.sql.client_option.tls_skip_verify", "true");
+    parameters.put("adbc.flight.sql.authorization_header", "Basic 
dXNlcjpwYXNzd29yZA==");
     db = driver.open(parameters);
     conn = db.connect();
   }
@@ -69,30 +83,106 @@ public class FlightSqlIntegrationTest {
   @Test
   void simple() throws Exception {
     try (var stmt = conn.createStatement()) {
+      // smoke test; the server doesn't actually run any query and returns a 
fixed result
       stmt.setSqlQuery("SELECT 1 + 1 AS sum");
       try (var reader = stmt.executeQuery()) {
-        assertThat(reader.getReader().loadNextBatch()).isTrue();
-        
assertThat(reader.getReader().getVectorSchemaRoot().getVector("sum").getObject(0))
-            .isEqualTo(2L);
+        assertThat(ArrowToJava.toIntegers(reader.getReader(), 
"ints")).containsExactly(5);
       }
     }
   }
 
   @Test
-  void partitioned() throws Exception {
+  void substrait() throws Exception {
     try (var stmt = conn.createStatement()) {
-      stmt.setSqlQuery("SELECT 1 + 1 AS sum");
-      var partitions = stmt.executePartitioned();
-      assertThat(partitions.getPartitionDescriptors().size()).isEqualTo(1);
-      assertThat(partitions.getAffectedRows()).isEqualTo(-1);
-      // The test server doesn't give a schema.
-      assertThat(partitions.getSchema()).isNull();
-
-      try (var reader =
-          
conn.readPartition(partitions.getPartitionDescriptors().get(0).getDescriptor()))
 {
-        assertThat(reader.loadNextBatch()).isTrue();
-        
assertThat(reader.getVectorSchemaRoot().getVector("sum").getObject(0)).isEqualTo(2L);
+      stmt.setSubstraitPlan(ByteBuffer.wrap(new byte[] {42, 0, (byte) 129, 
(byte) 255}));
+      try (var reader = stmt.executeQuery()) {
+        assertThat(ArrowToJava.toIntegers(reader.getReader(), 
"substrait")).containsExactly(5);
       }
     }
   }
+
+  @Test
+  void poll() throws Exception {
+    var expectedSchema = new Schema(List.of(Field.nullable("ints", 
Types.MinorType.INT.getType())));
+    try (var stmt = conn.createStatement()) {
+      stmt.setSqlQuery("poll");
+      stmt.prepare();
+      var iter = stmt.pollPartitioned();
+      var elements = new ArrayList<AdbcStatement.PartitionResult>();
+      iter.forEachRemaining(elements::add);
+      assertThat(elements).size().isEqualTo(5);
+      assertThat(elements.get(0).getSchema()).isNull();
+      assertThat(elements.get(1).getSchema()).isNull();
+      
ArrowAssertions.assertSchema(elements.get(2).getSchema()).isEqualTo(expectedSchema);
+      
ArrowAssertions.assertSchema(elements.get(3).getSchema()).isEqualTo(expectedSchema);
+      
ArrowAssertions.assertSchema(elements.get(4).getSchema()).isEqualTo(expectedSchema);
+
+      elements.forEach(
+          partitionResult -> {
+            
assertThat(partitionResult.getPartitionDescriptors()).size().isEqualTo(1);
+          });
+    }
+  }
+
+  @Test
+  void progress() throws Exception {
+    try (var stmt = conn.createStatement()) {
+      stmt.setSqlQuery("poll");
+      stmt.prepare();
+      var iter = stmt.pollPartitioned();
+      assertThat(stmt.getProgress()).isEqualTo(0.0);
+      assertThat(stmt.getMaxProgress()).isEqualTo(1.0);
+
+      for (int i = 0; i < 5; i++) {
+        assertThat(iter.hasNext()).isTrue();
+        assertThat(stmt.getProgress()).isCloseTo(0.2 * (i + 1), 
Assertions.offset(0.05));
+        assertThat(stmt.getMaxProgress()).isEqualTo(1.0);
+        assertThat(iter.next()).isNotNull();
+      }
+
+      assertThat(iter.hasNext()).isFalse();
+    }
+  }
+
+  @Test
+  void pollError() throws Exception {
+    try (var stmt = conn.createStatement()) {
+      stmt.setSqlQuery("error_poll_later");
+      stmt.prepare();
+      var iter = stmt.pollPartitioned();
+      assertThat(iter.hasNext()).isTrue();
+      assertThat(iter.next()).isNotNull();
+      assertThatThrownBy(iter::hasNext).hasMessageContaining("expected error");
+    }
+  }
+
+  @Test
+  void pollForever() throws Exception {
+    try (var stmt = conn.createStatement()) {
+      stmt.setSqlQuery("forever");
+      stmt.prepare();
+      var iter = stmt.pollPartitioned();
+      assertThat(iter.hasNext()).isTrue();
+      assertThat(iter.next()).isNotNull();
+      // XXX: rather janky. must cancel on the background; cancelling in 
between statements has no
+      // effect. we should perhaps reconsider this behavior
+      var t =
+          new Thread(
+              () -> {
+                try {
+                  while (stmt.getProgress() < 0.05) {
+                    Thread.sleep(100);
+                  }
+                  System.out.println("progress: " + stmt.getProgress());
+                  stmt.cancel();
+                } catch (Exception e) {
+                  e.printStackTrace();
+                }
+              });
+      t.start();
+
+      assertThatThrownBy(iter::hasNext).hasMessageContaining("context 
canceled");
+      t.join(10000);
+    }
+  }
 }
diff --git 
a/java/driver/jni-validation/src/test/java/org/apache/arrow/adbc/driver/jni/FlightSqlIntegrationTest.java
 
b/java/driver/jni-validation/src/test/java/org/apache/arrow/adbc/driver/jni/FlightSqlSqliteIntegrationTest.java
similarity index 97%
copy from 
java/driver/jni-validation/src/test/java/org/apache/arrow/adbc/driver/jni/FlightSqlIntegrationTest.java
copy to 
java/driver/jni-validation/src/test/java/org/apache/arrow/adbc/driver/jni/FlightSqlSqliteIntegrationTest.java
index b5ab1c7ee..e73f7bb6a 100644
--- 
a/java/driver/jni-validation/src/test/java/org/apache/arrow/adbc/driver/jni/FlightSqlIntegrationTest.java
+++ 
b/java/driver/jni-validation/src/test/java/org/apache/arrow/adbc/driver/jni/FlightSqlSqliteIntegrationTest.java
@@ -32,7 +32,8 @@ import org.junit.jupiter.api.BeforeAll;
 import org.junit.jupiter.api.BeforeEach;
 import org.junit.jupiter.api.Test;
 
-public class FlightSqlIntegrationTest {
+/** Integration tests against the SQLite server. */
+public class FlightSqlSqliteIntegrationTest {
   public static final String URI_ENV = "ADBC_SQLITE_FLIGHTSQL_URI";
   static String URI = System.getenv(URI_ENV);
 
diff --git a/java/driver/jni/src/main/cpp/jni_wrapper.cc 
b/java/driver/jni/src/main/cpp/jni_wrapper.cc
index d1f8309b3..7a8a73d53 100644
--- a/java/driver/jni/src/main/cpp/jni_wrapper.cc
+++ b/java/driver/jni/src/main/cpp/jni_wrapper.cc
@@ -208,6 +208,93 @@ auto WithJniString(JNIEnv* env, jstring jni_string, 
Callable&& callable) {
   return callable(view.value);
 }
 
+// Get the contents of a ByteBuffer. We may have to copy, so also take a
+// scratch buffer to hold the copy if needed. If possible, will use
+// DirectByteBuffer so no additional allocation is necessary.
+const uint8_t* GetJniByteBuffer(JNIEnv* env, jobject bytebuffer,
+                                std::vector<uint8_t>& scratch,
+                                size_t& serialized_length) {
+  jclass bb_class = RequireClass(env, "java/nio/ByteBuffer");
+  if (!env->IsInstanceOf(bytebuffer, bb_class)) {
+    ThrowJavaException(env, "java/lang/IllegalArgumentException",
+                       "Argument must be a ByteBuffer");
+    return nullptr;
+  }
+
+  jmethodID bb_remaining = RequireMethod(env, bb_class, "remaining", "()I");
+  jmethodID bb_position = RequireMethod(env, bb_class, "position", "()I");
+
+  jint remaining = env->CallIntMethod(bytebuffer, bb_remaining);
+  if (remaining < 0) {
+    ThrowJavaException(env, "java/lang/IllegalArgumentException",
+                       "ByteBuffer remaining() must be non-negative");
+    return nullptr;
+  } else if (env->ExceptionCheck()) {
+    return nullptr;
+  }
+  serialized_length = static_cast<size_t>(remaining);
+
+  jint position = env->CallIntMethod(bytebuffer, bb_position);
+  if (position < 0) {
+    ThrowJavaException(env, "java/lang/IllegalArgumentException",
+                       "ByteBuffer position() must be non-negative");
+    return nullptr;
+  } else if (env->ExceptionCheck()) {
+    return nullptr;
+  }
+  const size_t offset = static_cast<size_t>(position);
+
+  // fast path (if direct buffer)
+  void* buf = env->GetDirectBufferAddress(bytebuffer);
+  if (buf) {
+    return static_cast<const uint8_t*>(buf) + offset;
+  }
+
+  // middle path (copy from backing array)
+  {
+    jmethodID bb_has_array = RequireMethod(env, bb_class, "hasArray", "()Z");
+    jmethodID bb_array = RequireMethod(env, bb_class, "array", "()[B");
+    jmethodID bb_array_offset = RequireMethod(env, bb_class, "arrayOffset", 
"()I");
+    jboolean has_array = env->CallBooleanMethod(bytebuffer, bb_has_array);
+    if (env->ExceptionCheck()) return nullptr;
+    if (has_array) {
+      jint array_offset = env->CallIntMethod(bytebuffer, bb_array_offset);
+      if (env->ExceptionCheck()) return nullptr;
+
+      auto array =
+          reinterpret_cast<jbyteArray>(env->CallObjectMethod(bytebuffer, 
bb_array));
+      if (env->ExceptionCheck()) return nullptr;
+
+      assert(serialized_length <= 
static_cast<size_t>(env->GetArrayLength(array)));
+      scratch.resize(serialized_length);
+      env->GetByteArrayRegion(array, array_offset + position,
+                              static_cast<jsize>(serialized_length),
+                              reinterpret_cast<jbyte*>(scratch.data()));
+      return scratch.data();
+    }
+  }
+
+  // slow path (copy via invoking Java code to copy into a temp array)
+  {
+    jmethodID bb_get = RequireMethod(env, bb_class, "get", 
"([B)Ljava/nio/ByteBuffer;");
+    jbyteArray temp = env->NewByteArray(static_cast<jsize>(serialized_length));
+    if (!temp) {
+      ThrowJavaException(env, "java/lang/OutOfMemoryError",
+                         "Failed to allocate byte array to copy ByteBuffer");
+      return nullptr;
+    }
+
+    env->CallObjectMethod(bytebuffer, bb_get, temp);
+    if (env->ExceptionCheck()) return nullptr;
+
+    scratch.resize(serialized_length);
+    env->GetByteArrayRegion(temp, 0, static_cast<jsize>(serialized_length),
+                            reinterpret_cast<jbyte*>(scratch.data()));
+    return scratch.data();
+  }
+  // not reachable
+}
+
 }  // namespace
 
 extern "C" {
@@ -475,9 +562,9 @@ 
Java_org_apache_arrow_adbc_driver_jni_impl_NativeAdbc_statementExecuteQuery(
 JNIEXPORT void JNICALL
 Java_org_apache_arrow_adbc_driver_jni_impl_NativeAdbc_statementSetSqlQuery(
     JNIEnv* env, [[maybe_unused]] jclass self, jlong handle, jstring query) {
+  struct AdbcError error = ADBC_ERROR_INIT;
+  auto* ptr = reinterpret_cast<struct 
AdbcStatement*>(static_cast<uintptr_t>(handle));
   try {
-    struct AdbcError error = ADBC_ERROR_INIT;
-    auto* ptr = reinterpret_cast<struct 
AdbcStatement*>(static_cast<uintptr_t>(handle));
     JniStringView query_str(env, query);
     CHECK_ADBC_ERROR(AdbcStatementSetSqlQuery(ptr, query_str.value, &error), 
error);
   } catch (const AdbcException& e) {
@@ -485,6 +572,26 @@ 
Java_org_apache_arrow_adbc_driver_jni_impl_NativeAdbc_statementSetSqlQuery(
   }
 }
 
+JNIEXPORT void JNICALL
+Java_org_apache_arrow_adbc_driver_jni_impl_NativeAdbc_statementSetSubstraitPlan(
+    JNIEnv* env, [[maybe_unused]] jclass self, jlong handle, jobject plan) {
+  struct AdbcError error = ADBC_ERROR_INIT;
+  auto* ptr = reinterpret_cast<struct 
AdbcStatement*>(static_cast<uintptr_t>(handle));
+  std::vector<uint8_t> allocated_plan;
+  size_t plan_length = 0;
+  try {
+    const uint8_t* plan_ptr = GetJniByteBuffer(env, plan, allocated_plan, 
plan_length);
+    if (!plan_ptr || env->ExceptionCheck()) {
+      return;  // GetJniByteBuffer failed
+    }
+    assert(plan_ptr != nullptr);
+    CHECK_ADBC_ERROR(AdbcStatementSetSubstraitPlan(ptr, plan_ptr, plan_length, 
&error),
+                     error);
+  } catch (const AdbcException& e) {
+    e.ThrowJavaException(env);
+  }
+}
+
 JNIEXPORT void JNICALL
 Java_org_apache_arrow_adbc_driver_jni_impl_NativeAdbc_statementBind(
     JNIEnv* env, [[maybe_unused]] jclass self, jlong handle, jlong values, 
jlong schema) {
@@ -1059,81 +1166,18 @@ 
Java_org_apache_arrow_adbc_driver_jni_impl_NativeAdbc_connectionReadPartition(
   auto* conn = reinterpret_cast<struct 
AdbcConnection*>(static_cast<uintptr_t>(handle));
   struct ArrowArrayStream out = {};
   size_t serialized_length = 0;
-  const uint8_t* serialized_partition = nullptr;
   std::vector<uint8_t> allocated_partition;
 
   try {
-    jclass bb_class = RequireClass(env, "java/nio/ByteBuffer");
-    jmethodID bb_remaining = RequireMethod(env, bb_class, "remaining", "()I");
-
-    if (!env->IsInstanceOf(partition, bb_class)) {
-      ThrowJavaException(env, "java/lang/IllegalArgumentException",
-                         "Partition must be a ByteBuffer");
-      return nullptr;
-    }
-    jint remaining = env->CallIntMethod(partition, bb_remaining);
-    if (remaining < 0) {
-      ThrowJavaException(env, "java/lang/IllegalArgumentException",
-                         "ByteBuffer remaining() must be non-negative");
-      return nullptr;
+    const uint8_t* serialized_partition =
+        GetJniByteBuffer(env, partition, allocated_partition, 
serialized_length);
+    if (!serialized_partition || env->ExceptionCheck()) {
+      return nullptr;  // GetJniByteBuffer failed
     }
-    serialized_length = static_cast<size_t>(remaining);
-
-    // fast path (if direct buffer)
-    void* buf = env->GetDirectBufferAddress(partition);
-    if (buf) {
-      serialized_partition = static_cast<const uint8_t*>(buf);
-    }
-
-    // middle path (backing array)
-    if (!serialized_partition) {
-      jmethodID bb_has_array = RequireMethod(env, bb_class, "hasArray", "()Z");
-      jmethodID bb_array = RequireMethod(env, bb_class, "array", "()[B");
-      jmethodID bb_array_offset = RequireMethod(env, bb_class, "arrayOffset", 
"()I");
-      jboolean has_array = env->CallBooleanMethod(partition, bb_has_array);
-      if (env->ExceptionCheck()) return nullptr;
-      if (has_array) {
-        jint array_offset = env->CallIntMethod(partition, bb_array_offset);
-        if (env->ExceptionCheck()) return nullptr;
-
-        auto array =
-            reinterpret_cast<jbyteArray>(env->CallObjectMethod(partition, 
bb_array));
-        if (env->ExceptionCheck()) return nullptr;
-
-        assert(serialized_length <= 
static_cast<size_t>(env->GetArrayLength(array)));
-        allocated_partition.resize(serialized_length);
-        env->GetByteArrayRegion(array, array_offset,
-                                static_cast<jsize>(serialized_length),
-                                
reinterpret_cast<jbyte*>(allocated_partition.data()));
-        serialized_partition = allocated_partition.data();
-      }
-    }
-
-    // slow path (copy)
-    if (!serialized_partition) {
-      jmethodID bb_get = RequireMethod(env, bb_class, "get", 
"([B)Ljava/nio/ByteBuffer;");
-      jbyteArray temp = 
env->NewByteArray(static_cast<jsize>(serialized_length));
-      if (!temp) {
-        ThrowJavaException(env, "java/lang/OutOfMemoryError",
-                           "Failed to allocate byte array for partition");
-        return nullptr;
-      }
-
-      env->CallVoidMethod(partition, bb_get, temp);
-      if (env->ExceptionCheck()) return nullptr;
-
-      allocated_partition.resize(serialized_length);
-      env->GetByteArrayRegion(temp, 0, static_cast<jsize>(serialized_length),
-                              
reinterpret_cast<jbyte*>(allocated_partition.data()));
-      serialized_partition = allocated_partition.data();
-    }
-
     assert(serialized_partition != nullptr);
-
     CHECK_ADBC_ERROR(AdbcConnectionReadPartition(conn, serialized_partition,
                                                  serialized_length, &out, 
&error),
                      error);
-
     return MakeNativeQueryResult(env, -1, &out);
   } catch (const AdbcException& e) {
     e.ThrowJavaException(env);
diff --git 
a/java/driver/jni/src/main/java/org/apache/arrow/adbc/driver/jni/JniDriver.java 
b/java/driver/jni/src/main/java/org/apache/arrow/adbc/driver/jni/JniDriver.java
index f118f8b08..9f33e0193 100644
--- 
a/java/driver/jni/src/main/java/org/apache/arrow/adbc/driver/jni/JniDriver.java
+++ 
b/java/driver/jni/src/main/java/org/apache/arrow/adbc/driver/jni/JniDriver.java
@@ -63,6 +63,12 @@ public class JniDriver implements AdbcDriver {
   static final String ISOLATION_LEVEL_SNAPSHOT = 
"adbc.connection.transaction.isolation.snapshot";
   static final String ISOLATION_LEVEL_SERIALIZABLE =
       "adbc.connection.transaction.isolation.serializable";
+  static final TypedKey<Double> PROGRESS =
+      new TypedKey<>("adbc.statement.exec.progress", Double.class);
+  static final TypedKey<Double> MAX_PROGRESS =
+      new TypedKey<>("adbc.statement.exec.max_progress", Double.class);
+  static final TypedKey<Boolean> INCREMENTAL =
+      new TypedKey<>("adbc.statement.exec.incremental", Boolean.class);
 
   static final TypedKey<String> CURRENT_CATALOG =
       new TypedKey<>("adbc.connection.catalog", String.class);
diff --git 
a/java/driver/jni/src/main/java/org/apache/arrow/adbc/driver/jni/JniStatement.java
 
b/java/driver/jni/src/main/java/org/apache/arrow/adbc/driver/jni/JniStatement.java
index d3bc16ebc..0e5da15b2 100644
--- 
a/java/driver/jni/src/main/java/org/apache/arrow/adbc/driver/jni/JniStatement.java
+++ 
b/java/driver/jni/src/main/java/org/apache/arrow/adbc/driver/jni/JniStatement.java
@@ -18,6 +18,9 @@
 package org.apache.arrow.adbc.driver.jni;
 
 import java.io.IOException;
+import java.nio.ByteBuffer;
+import java.util.Iterator;
+import java.util.NoSuchElementException;
 import org.apache.arrow.adbc.core.AdbcException;
 import org.apache.arrow.adbc.core.AdbcStatement;
 import org.apache.arrow.adbc.core.TypedKey;
@@ -57,6 +60,11 @@ public class JniStatement implements AdbcStatement {
     JniLoader.INSTANCE.statementSetSqlQuery(handle, query);
   }
 
+  @Override
+  public void setSubstraitPlan(ByteBuffer plan) throws AdbcException {
+    JniLoader.INSTANCE.statementSetSubstraitPlan(handle, plan);
+  }
+
   @Override
   public void bind(VectorSchemaRoot root) throws AdbcException {
     clearBind();
@@ -86,6 +94,7 @@ public class JniStatement implements AdbcStatement {
   // The C Data export takes ownership of the data at bind time and ignores 
subsequent
   // client changes to the bound root. Defer the export until execution so we 
capture
   // the final state of the VectorSchemaRoot.
+
   private void exportBind() throws AdbcException {
     if (bindRoot != null) {
       try (final ArrowArray batch = ArrowArray.allocateNew(allocator);
@@ -143,6 +152,22 @@ public class JniStatement implements AdbcStatement {
     JniLoader.INSTANCE.statementPrepare(handle);
   }
 
+  public double getProgress() throws AdbcException {
+    return getOption(JniDriver.PROGRESS);
+  }
+
+  @Override
+  public double getMaxProgress() throws AdbcException {
+    return getOption(JniDriver.MAX_PROGRESS);
+  }
+
+  @Override
+  public Iterator<PartitionResult> pollPartitioned() throws AdbcException {
+    exportBind();
+    setOption(JniDriver.INCREMENTAL, true);
+    return new PartitionPollIterator();
+  }
+
   @Override
   public void close() throws AdbcException {
     try {
@@ -206,4 +231,38 @@ public class JniStatement implements AdbcStatement {
           "[jni] unsupported statement option type " + value.getClass());
     }
   }
+
+  public final class PartitionPollIterator implements 
Iterator<PartitionResult> {
+    @Nullable PartitionResult nextResult;
+    boolean finished = false;
+
+    @Override
+    public boolean hasNext() {
+      if (finished) return false;
+      if (nextResult == null) {
+        NativePartitionResult result;
+        try {
+          result = JniLoader.INSTANCE.statementExecutePartitions(handle);
+        } catch (AdbcException e) {
+          throw new RuntimeException(e);
+        }
+        nextResult = result.importResult(allocator);
+      }
+      if (nextResult.getPartitionDescriptors().isEmpty()) {
+        finished = true;
+        return false;
+      }
+      return true;
+    }
+
+    @Override
+    public PartitionResult next() {
+      if (!hasNext()) {
+        throw new NoSuchElementException("No more partitions");
+      }
+      PartitionResult result = nextResult;
+      nextResult = null;
+      return result;
+    }
+  }
 }
diff --git 
a/java/driver/jni/src/main/java/org/apache/arrow/adbc/driver/jni/impl/JniLoader.java
 
b/java/driver/jni/src/main/java/org/apache/arrow/adbc/driver/jni/impl/JniLoader.java
index 500429a88..3174655c5 100644
--- 
a/java/driver/jni/src/main/java/org/apache/arrow/adbc/driver/jni/impl/JniLoader.java
+++ 
b/java/driver/jni/src/main/java/org/apache/arrow/adbc/driver/jni/impl/JniLoader.java
@@ -102,6 +102,11 @@ public enum JniLoader {
     NativeAdbc.statementSetSqlQuery(statement.getStatementHandle(), query);
   }
 
+  public void statementSetSubstraitPlan(NativeStatementHandle statement, 
ByteBuffer plan)
+      throws AdbcException {
+    NativeAdbc.statementSetSubstraitPlan(statement.getStatementHandle(), plan);
+  }
+
   public void statementBind(NativeStatementHandle statement, ArrowArray batch, 
ArrowSchema schema)
       throws AdbcException {
     NativeAdbc.statementBind(
diff --git 
a/java/driver/jni/src/main/java/org/apache/arrow/adbc/driver/jni/impl/NativeAdbc.java
 
b/java/driver/jni/src/main/java/org/apache/arrow/adbc/driver/jni/impl/NativeAdbc.java
index a7f2b4fbf..c2059532f 100644
--- 
a/java/driver/jni/src/main/java/org/apache/arrow/adbc/driver/jni/impl/NativeAdbc.java
+++ 
b/java/driver/jni/src/main/java/org/apache/arrow/adbc/driver/jni/impl/NativeAdbc.java
@@ -59,6 +59,8 @@ class NativeAdbc {
 
   static native void statementSetSqlQuery(long handle, String query) throws 
AdbcException;
 
+  static native void statementSetSubstraitPlan(long handle, ByteBuffer plan) 
throws AdbcException;
+
   static native byte[] statementGetOptionBytes(long handle, String key) throws 
AdbcException;
 
   static native double statementGetOptionDouble(long handle, String key) 
throws AdbcException;

Reply via email to