Copilot commented on code in PR #4362:
URL: https://github.com/apache/arrow-adbc/pull/4362#discussion_r3360865995


##########
java/driver/jni/src/main/cpp/jni_wrapper.cc:
##########
@@ -208,6 +208,79 @@ auto WithJniString(JNIEnv* env, jstring jni_string, 
Callable&& callable) {
   return callable(view.value);
 }
 
+// Get the contents of a ByteBuffer. We may have to copy, so also take a
+// scratch buffer to hold the copy if needed. If possible, will use
+// DirectByteBuffer so no additional allocation is necessary.
+const uint8_t* GetJniByteBuffer(JNIEnv* env, jobject bytebuffer,
+                                std::vector<uint8_t>& scratch,
+                                size_t& serialized_length) {
+  jclass bb_class = RequireClass(env, "java/nio/ByteBuffer");
+  jmethodID bb_remaining = RequireMethod(env, bb_class, "remaining", "()I");
+
+  if (!env->IsInstanceOf(bytebuffer, bb_class)) {
+    ThrowJavaException(env, "java/lang/IllegalArgumentException",
+                       "Argument must be a ByteBuffer");
+    return nullptr;
+  }
+
+  jint remaining = env->CallIntMethod(bytebuffer, bb_remaining);
+  if (remaining < 0) {
+    ThrowJavaException(env, "java/lang/IllegalArgumentException",
+                       "ByteBuffer remaining() must be non-negative");
+    return nullptr;
+  }
+  serialized_length = static_cast<size_t>(remaining);
+
+  // fast path (if direct buffer)
+  void* buf = env->GetDirectBufferAddress(bytebuffer);
+  if (buf) {
+    return static_cast<const uint8_t*>(buf);
+  }

Review Comment:
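   GetDirectBufferAddress() returns the base address of the direct ByteBuffer, 
but this helper ignores the buffer's current position. Since serialized_length 
is based on remaining() (limit - position), this will pass the wrong bytes 
whenever position != 0. Call position() and return the base address advanced 
by that many bytes, so that the pointer and serialized_length describe the same 
window of the buffer.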



##########
java/driver/jni/src/main/cpp/jni_wrapper.cc:
##########
@@ -208,6 +208,79 @@ auto WithJniString(JNIEnv* env, jstring jni_string, 
Callable&& callable) {
   return callable(view.value);
 }
 
+// Get the contents of a ByteBuffer. We may have to copy, so also take a
+// scratch buffer to hold the copy if needed. If possible, will use
+// DirectByteBuffer so no additional allocation is necessary.
+const uint8_t* GetJniByteBuffer(JNIEnv* env, jobject bytebuffer,
+                                std::vector<uint8_t>& scratch,
+                                size_t& serialized_length) {
+  jclass bb_class = RequireClass(env, "java/nio/ByteBuffer");
+  jmethodID bb_remaining = RequireMethod(env, bb_class, "remaining", "()I");
+
+  if (!env->IsInstanceOf(bytebuffer, bb_class)) {
+    ThrowJavaException(env, "java/lang/IllegalArgumentException",
+                       "Argument must be a ByteBuffer");
+    return nullptr;
+  }
+
+  jint remaining = env->CallIntMethod(bytebuffer, bb_remaining);
+  if (remaining < 0) {
+    ThrowJavaException(env, "java/lang/IllegalArgumentException",
+                       "ByteBuffer remaining() must be non-negative");
+    return nullptr;
+  }
+  serialized_length = static_cast<size_t>(remaining);
+
+  // fast path (if direct buffer)
+  void* buf = env->GetDirectBufferAddress(bytebuffer);
+  if (buf) {
+    return static_cast<const uint8_t*>(buf);
+  }
+
+  // middle path (copy from backing array)
+  {
+    jmethodID bb_has_array = RequireMethod(env, bb_class, "hasArray", "()Z");
+    jmethodID bb_array = RequireMethod(env, bb_class, "array", "()[B");
+    jmethodID bb_array_offset = RequireMethod(env, bb_class, "arrayOffset", 
"()I");
+    jboolean has_array = env->CallBooleanMethod(bytebuffer, bb_has_array);
+    if (env->ExceptionCheck()) return nullptr;
+    if (has_array) {
+      jint array_offset = env->CallIntMethod(bytebuffer, bb_array_offset);
+      if (env->ExceptionCheck()) return nullptr;
+
+      auto array =
+          reinterpret_cast<jbyteArray>(env->CallObjectMethod(bytebuffer, 
bb_array));
+      if (env->ExceptionCheck()) return nullptr;
+
+      assert(serialized_length <= 
static_cast<size_t>(env->GetArrayLength(array)));
+      scratch.resize(serialized_length);
+      env->GetByteArrayRegion(array, array_offset, 
static_cast<jsize>(serialized_length),
+                              reinterpret_cast<jbyte*>(scratch.data()));
+      return scratch.data();

Review Comment:
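   When copying from a ByteBuffer backed by an array, this uses arrayOffset() 
but ignores the buffer's current position. Whenever position() != 0 (for 
example, a buffer created with ByteBuffer.wrap(array, offset, length), which 
sets position to offset, or a sliced buffer whose position has since advanced), 
this will copy from the wrong location, and the bounds assert doesn't account 
for offset+position. The copy should start at arrayOffset() + position(), and 
the bounds check should verify that arrayOffset() + position() + remaining() 
does not exceed the array length. That check should be a real check rather 
than an assert, since asserts may be compiled out.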



##########
java/driver/jni/src/main/cpp/jni_wrapper.cc:
##########
@@ -475,16 +548,33 @@ 
Java_org_apache_arrow_adbc_driver_jni_impl_NativeAdbc_statementExecuteQuery(
 JNIEXPORT void JNICALL
 Java_org_apache_arrow_adbc_driver_jni_impl_NativeAdbc_statementSetSqlQuery(
     JNIEnv* env, [[maybe_unused]] jclass self, jlong handle, jstring query) {
+  struct AdbcError error = ADBC_ERROR_INIT;
+  auto* ptr = reinterpret_cast<struct 
AdbcStatement*>(static_cast<uintptr_t>(handle));
   try {
-    struct AdbcError error = ADBC_ERROR_INIT;
-    auto* ptr = reinterpret_cast<struct 
AdbcStatement*>(static_cast<uintptr_t>(handle));
     JniStringView query_str(env, query);
     CHECK_ADBC_ERROR(AdbcStatementSetSqlQuery(ptr, query_str.value, &error), 
error);
   } catch (const AdbcException& e) {
     e.ThrowJavaException(env);
   }
 }
 
+JNIEXPORT void JNICALL
+Java_org_apache_arrow_adbc_driver_jni_impl_NativeAdbc_statementSetSubstraitPlan(
+    JNIEnv* env, [[maybe_unused]] jclass self, jlong handle, jobject plan) {
+  struct AdbcError error = ADBC_ERROR_INIT;
+  auto* ptr = reinterpret_cast<struct 
AdbcStatement*>(static_cast<uintptr_t>(handle));
+  std::vector<uint8_t> allocated_plan;
+  size_t plan_length = 0;
+  try {
+    const uint8_t* plan_ptr = GetJniByteBuffer(env, plan, allocated_plan, 
plan_length);
+    assert(plan_ptr != nullptr);
+    CHECK_ADBC_ERROR(AdbcStatementSetSubstraitPlan(ptr, plan_ptr, plan_length, 
&error),
+                     error);

Review Comment:
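   GetJniByteBuffer can return nullptr after throwing a Java exception (e.g., 
wrong argument type). Asserting non-null and continuing risks passing a null 
pointer into AdbcStatementSetSubstraitPlan (asserts may be compiled out), 
leading to a crash. Return early when a Java exception is pending or the 
pointer is null. For example, replace the assert with 
`if (plan_ptr == nullptr) return;`. The Java exception is already set and will 
be raised in the caller when the native method returns.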



##########
java/driver/jni/src/main/cpp/jni_wrapper.cc:
##########
@@ -208,6 +208,79 @@ auto WithJniString(JNIEnv* env, jstring jni_string, 
Callable&& callable) {
   return callable(view.value);
 }
 
+// Get the contents of a ByteBuffer. We may have to copy, so also take a
+// scratch buffer to hold the copy if needed. If possible, will use
+// DirectByteBuffer so no additional allocation is necessary.
+const uint8_t* GetJniByteBuffer(JNIEnv* env, jobject bytebuffer,
+                                std::vector<uint8_t>& scratch,
+                                size_t& serialized_length) {
+  jclass bb_class = RequireClass(env, "java/nio/ByteBuffer");
+  jmethodID bb_remaining = RequireMethod(env, bb_class, "remaining", "()I");
+
+  if (!env->IsInstanceOf(bytebuffer, bb_class)) {
+    ThrowJavaException(env, "java/lang/IllegalArgumentException",
+                       "Argument must be a ByteBuffer");
+    return nullptr;
+  }
+
+  jint remaining = env->CallIntMethod(bytebuffer, bb_remaining);
+  if (remaining < 0) {
+    ThrowJavaException(env, "java/lang/IllegalArgumentException",
+                       "ByteBuffer remaining() must be non-negative");
+    return nullptr;
+  }
+  serialized_length = static_cast<size_t>(remaining);
+
+  // fast path (if direct buffer)
+  void* buf = env->GetDirectBufferAddress(bytebuffer);
+  if (buf) {
+    return static_cast<const uint8_t*>(buf);
+  }
+
+  // middle path (copy from backing array)
+  {
+    jmethodID bb_has_array = RequireMethod(env, bb_class, "hasArray", "()Z");
+    jmethodID bb_array = RequireMethod(env, bb_class, "array", "()[B");
+    jmethodID bb_array_offset = RequireMethod(env, bb_class, "arrayOffset", 
"()I");
+    jboolean has_array = env->CallBooleanMethod(bytebuffer, bb_has_array);
+    if (env->ExceptionCheck()) return nullptr;
+    if (has_array) {
+      jint array_offset = env->CallIntMethod(bytebuffer, bb_array_offset);
+      if (env->ExceptionCheck()) return nullptr;
+
+      auto array =
+          reinterpret_cast<jbyteArray>(env->CallObjectMethod(bytebuffer, 
bb_array));
+      if (env->ExceptionCheck()) return nullptr;
+
+      assert(serialized_length <= 
static_cast<size_t>(env->GetArrayLength(array)));
+      scratch.resize(serialized_length);
+      env->GetByteArrayRegion(array, array_offset, 
static_cast<jsize>(serialized_length),
+                              reinterpret_cast<jbyte*>(scratch.data()));
+      return scratch.data();
+    }
+  }
+
+  // slow path (copy via invoking Java code to copy into a temp array)
+  {
+    jmethodID bb_get = RequireMethod(env, bb_class, "get", 
"([B)Ljava/nio/ByteBuffer;");
+    jbyteArray temp = env->NewByteArray(static_cast<jsize>(serialized_length));
+    if (!temp) {
+      ThrowJavaException(env, "java/lang/OutOfMemoryError",
+                         "Failed to allocate byte array to copy ByteBuffer");
+      return nullptr;
+    }
+
+    env->CallVoidMethod(bytebuffer, bb_get, temp);

Review Comment:
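   The method ID for ByteBuffer.get(byte[]) has a non-void return type (it 
returns the ByteBuffer itself), but the JNI call uses CallVoidMethod. Using the 
wrong Call*Method variant is undefined behavior in JNI and can crash on some 
JVMs. Use CallObjectMethod instead, and release the returned local reference 
with DeleteLocalRef.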



##########
java/driver/jni-validation/src/test/java/org/apache/arrow/adbc/driver/jni/FlightSqlIntegrationTest.java:
##########
@@ -69,30 +83,106 @@ void afterEach() throws Exception {
   @Test
   void simple() throws Exception {
     try (var stmt = conn.createStatement()) {
+      // smoke test; the server doesn't actually run any query and returns a 
fixed result
       stmt.setSqlQuery("SELECT 1 + 1 AS sum");
       try (var reader = stmt.executeQuery()) {
-        assertThat(reader.getReader().loadNextBatch()).isTrue();
-        
assertThat(reader.getReader().getVectorSchemaRoot().getVector("sum").getObject(0))
-            .isEqualTo(2L);
+        assertThat(ArrowToJava.toIntegers(reader.getReader(), 
"ints")).containsExactly(5);
       }
     }
   }
 
   @Test
-  void partitioned() throws Exception {
+  void substrait() throws Exception {
     try (var stmt = conn.createStatement()) {
-      stmt.setSqlQuery("SELECT 1 + 1 AS sum");
-      var partitions = stmt.executePartitioned();
-      assertThat(partitions.getPartitionDescriptors().size()).isEqualTo(1);
-      assertThat(partitions.getAffectedRows()).isEqualTo(-1);
-      // The test server doesn't give a schema.
-      assertThat(partitions.getSchema()).isNull();
-
-      try (var reader =
-          
conn.readPartition(partitions.getPartitionDescriptors().get(0).getDescriptor()))
 {
-        assertThat(reader.loadNextBatch()).isTrue();
-        
assertThat(reader.getVectorSchemaRoot().getVector("sum").getObject(0)).isEqualTo(2L);
+      stmt.setSubstraitPlan(ByteBuffer.wrap(new byte[] {42, 0, (byte) 129, 
(byte) 255}));
+      try (var reader = stmt.executeQuery()) {
+        assertThat(ArrowToJava.toIntegers(reader.getReader(), 
"substrait")).containsExactly(5);
       }
     }
   }
+
+  @Test
+  void poll() throws Exception {
+    var expectedSchema = new Schema(List.of(Field.nullable("ints", 
Types.MinorType.INT.getType())));
+    try (var stmt = conn.createStatement()) {
+      stmt.setSqlQuery("poll");
+      stmt.prepare();
+      var iter = stmt.pollPartitioned();
+      var elements = new ArrayList<AdbcStatement.PartitionResult>();
+      iter.forEachRemaining(elements::add);
+      assertThat(elements).size().isEqualTo(5);
+      assertThat(elements.get(0).getSchema()).isNull();
+      assertThat(elements.get(1).getSchema()).isNull();
+      
ArrowAssertions.assertSchema(elements.get(2).getSchema()).isEqualTo(expectedSchema);
+      
ArrowAssertions.assertSchema(elements.get(3).getSchema()).isEqualTo(expectedSchema);
+      
ArrowAssertions.assertSchema(elements.get(4).getSchema()).isEqualTo(expectedSchema);
+
+      elements.forEach(
+          partitionResult -> {
+            
assertThat(partitionResult.getPartitionDescriptors()).size().isEqualTo(1);
+          });
+    }
+  }
+
+  @Test
+  void progress() throws Exception {
+    try (var stmt = conn.createStatement()) {
+      stmt.setSqlQuery("poll");
+      stmt.prepare();
+      var iter = stmt.pollPartitioned();
+      assertThat(stmt.getProgress()).isEqualTo(0.0);
+      assertThat(stmt.getMaxProgress()).isEqualTo(1.0);
+
+      for (int i = 0; i < 5; i++) {
+        assertThat(iter.hasNext()).isTrue();
+        assertThat(stmt.getProgress()).isCloseTo(0.2 * (i + 1), 
Assertions.offset(0.05));
+        assertThat(stmt.getMaxProgress()).isEqualTo(1.0);
+        assertThat(iter.next()).isNotNull();
+      }
+
+      assertThat(iter.hasNext()).isFalse();
+    }
+  }
+
+  @Test
+  void pollError() throws Exception {
+    try (var stmt = conn.createStatement()) {
+      stmt.setSqlQuery("error_poll_later");
+      stmt.prepare();
+      var iter = stmt.pollPartitioned();
+      assertThat(iter.hasNext()).isTrue();
+      assertThat(iter.next()).isNotNull();
+      assertThatThrownBy(iter::hasNext).hasMessageContaining("expected error");
+    }
+  }
+
+  @Test
+  void pollForever() throws Exception {
+    try (var stmt = conn.createStatement()) {
+      stmt.setSqlQuery("forever");
+      stmt.prepare();
+      var iter = stmt.pollPartitioned();
+      assertThat(iter.hasNext()).isTrue();
+      assertThat(iter.next()).isNotNull();
+      // XXX: rather janky. must cancel on the background; cancelling in 
between statements has no
+      // effect. we should perhaps reconsider this behavior
+      var t =
+          new Thread(
+              () -> {
+                try {
+                  while (stmt.getProgress() < 0.05) {
+                    Thread.sleep(100);
+                  }
+                  System.out.println("progress: " + stmt.getProgress());
+                  stmt.cancel();
+                } catch (Exception e) {
+                  e.printStackTrace();
+                }

Review Comment:
   This background thread calls stmt.getProgress() while the main thread is 
concurrently polling the statement. Per AdbcStatement's thread-safety contract, 
only cancel() is required to be thread-safe, so this can lead to flaky tests 
depending on the driver implementation. Prefer waiting via sleep (or other 
external signal) and only calling cancel() from the background thread.



##########
java/driver/jni/src/main/cpp/jni_wrapper.cc:
##########
@@ -1059,81 +1149,15 @@ 
Java_org_apache_arrow_adbc_driver_jni_impl_NativeAdbc_connectionReadPartition(
   auto* conn = reinterpret_cast<struct 
AdbcConnection*>(static_cast<uintptr_t>(handle));
   struct ArrowArrayStream out = {};
   size_t serialized_length = 0;
-  const uint8_t* serialized_partition = nullptr;
   std::vector<uint8_t> allocated_partition;
 
   try {
-    jclass bb_class = RequireClass(env, "java/nio/ByteBuffer");
-    jmethodID bb_remaining = RequireMethod(env, bb_class, "remaining", "()I");
-
-    if (!env->IsInstanceOf(partition, bb_class)) {
-      ThrowJavaException(env, "java/lang/IllegalArgumentException",
-                         "Partition must be a ByteBuffer");
-      return nullptr;
-    }
-    jint remaining = env->CallIntMethod(partition, bb_remaining);
-    if (remaining < 0) {
-      ThrowJavaException(env, "java/lang/IllegalArgumentException",
-                         "ByteBuffer remaining() must be non-negative");
-      return nullptr;
-    }
-    serialized_length = static_cast<size_t>(remaining);
-
-    // fast path (if direct buffer)
-    void* buf = env->GetDirectBufferAddress(partition);
-    if (buf) {
-      serialized_partition = static_cast<const uint8_t*>(buf);
-    }
-
-    // middle path (backing array)
-    if (!serialized_partition) {
-      jmethodID bb_has_array = RequireMethod(env, bb_class, "hasArray", "()Z");
-      jmethodID bb_array = RequireMethod(env, bb_class, "array", "()[B");
-      jmethodID bb_array_offset = RequireMethod(env, bb_class, "arrayOffset", 
"()I");
-      jboolean has_array = env->CallBooleanMethod(partition, bb_has_array);
-      if (env->ExceptionCheck()) return nullptr;
-      if (has_array) {
-        jint array_offset = env->CallIntMethod(partition, bb_array_offset);
-        if (env->ExceptionCheck()) return nullptr;
-
-        auto array =
-            reinterpret_cast<jbyteArray>(env->CallObjectMethod(partition, 
bb_array));
-        if (env->ExceptionCheck()) return nullptr;
-
-        assert(serialized_length <= 
static_cast<size_t>(env->GetArrayLength(array)));
-        allocated_partition.resize(serialized_length);
-        env->GetByteArrayRegion(array, array_offset,
-                                static_cast<jsize>(serialized_length),
-                                
reinterpret_cast<jbyte*>(allocated_partition.data()));
-        serialized_partition = allocated_partition.data();
-      }
-    }
-
-    // slow path (copy)
-    if (!serialized_partition) {
-      jmethodID bb_get = RequireMethod(env, bb_class, "get", 
"([B)Ljava/nio/ByteBuffer;");
-      jbyteArray temp = 
env->NewByteArray(static_cast<jsize>(serialized_length));
-      if (!temp) {
-        ThrowJavaException(env, "java/lang/OutOfMemoryError",
-                           "Failed to allocate byte array for partition");
-        return nullptr;
-      }
-
-      env->CallVoidMethod(partition, bb_get, temp);
-      if (env->ExceptionCheck()) return nullptr;
-
-      allocated_partition.resize(serialized_length);
-      env->GetByteArrayRegion(temp, 0, static_cast<jsize>(serialized_length),
-                              
reinterpret_cast<jbyte*>(allocated_partition.data()));
-      serialized_partition = allocated_partition.data();
-    }
-
+    const uint8_t* serialized_partition =
+        GetJniByteBuffer(env, partition, allocated_partition, 
serialized_length);
     assert(serialized_partition != nullptr);
-
     CHECK_ADBC_ERROR(AdbcConnectionReadPartition(conn, serialized_partition,
                                                  serialized_length, &out, 
&error),

Review Comment:
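   GetJniByteBuffer can return nullptr after throwing a Java exception (e.g., 
wrong argument type). Asserting non-null and continuing risks passing a null 
pointer into AdbcConnectionReadPartition (asserts may be compiled out), leading 
to a crash. Return early when a Java exception is pending or the pointer is 
null. For example, replace the assert with 
`if (serialized_partition == nullptr) return nullptr;` so that the pending Java 
exception propagates to the caller.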



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

Reply via email to