Copilot commented on code in PR #4362:
URL: https://github.com/apache/arrow-adbc/pull/4362#discussion_r3360865995
##########
java/driver/jni/src/main/cpp/jni_wrapper.cc:
##########
@@ -208,6 +208,79 @@ auto WithJniString(JNIEnv* env, jstring jni_string,
Callable&& callable) {
return callable(view.value);
}
+// Get the contents of a ByteBuffer. We may have to copy, so also take a
+// scratch buffer to hold the copy if needed. If possible, will use
+// DirectByteBuffer so no additional allocation is necessary.
+const uint8_t* GetJniByteBuffer(JNIEnv* env, jobject bytebuffer,
+ std::vector<uint8_t>& scratch,
+ size_t& serialized_length) {
+ jclass bb_class = RequireClass(env, "java/nio/ByteBuffer");
+ jmethodID bb_remaining = RequireMethod(env, bb_class, "remaining", "()I");
+
+ if (!env->IsInstanceOf(bytebuffer, bb_class)) {
+ ThrowJavaException(env, "java/lang/IllegalArgumentException",
+ "Argument must be a ByteBuffer");
+ return nullptr;
+ }
+
+ jint remaining = env->CallIntMethod(bytebuffer, bb_remaining);
+ if (remaining < 0) {
+ ThrowJavaException(env, "java/lang/IllegalArgumentException",
+ "ByteBuffer remaining() must be non-negative");
+ return nullptr;
+ }
+ serialized_length = static_cast<size_t>(remaining);
+
+ // fast path (if direct buffer)
+ void* buf = env->GetDirectBufferAddress(bytebuffer);
+ if (buf) {
+ return static_cast<const uint8_t*>(buf);
+ }
Review Comment:
GetDirectBufferAddress() returns the base address of the direct ByteBuffer,
but this helper ignores the buffer's current position. Since serialized_length
is based on remaining() (limit - position), this will pass the wrong bytes
whenever position != 0. Read position() as well and return the base address
advanced by position() bytes.
##########
java/driver/jni/src/main/cpp/jni_wrapper.cc:
##########
@@ -208,6 +208,79 @@ auto WithJniString(JNIEnv* env, jstring jni_string,
Callable&& callable) {
return callable(view.value);
}
+// Get the contents of a ByteBuffer. We may have to copy, so also take a
+// scratch buffer to hold the copy if needed. If possible, will use
+// DirectByteBuffer so no additional allocation is necessary.
+const uint8_t* GetJniByteBuffer(JNIEnv* env, jobject bytebuffer,
+ std::vector<uint8_t>& scratch,
+ size_t& serialized_length) {
+ jclass bb_class = RequireClass(env, "java/nio/ByteBuffer");
+ jmethodID bb_remaining = RequireMethod(env, bb_class, "remaining", "()I");
+
+ if (!env->IsInstanceOf(bytebuffer, bb_class)) {
+ ThrowJavaException(env, "java/lang/IllegalArgumentException",
+ "Argument must be a ByteBuffer");
+ return nullptr;
+ }
+
+ jint remaining = env->CallIntMethod(bytebuffer, bb_remaining);
+ if (remaining < 0) {
+ ThrowJavaException(env, "java/lang/IllegalArgumentException",
+ "ByteBuffer remaining() must be non-negative");
+ return nullptr;
+ }
+ serialized_length = static_cast<size_t>(remaining);
+
+ // fast path (if direct buffer)
+ void* buf = env->GetDirectBufferAddress(bytebuffer);
+ if (buf) {
+ return static_cast<const uint8_t*>(buf);
+ }
+
+ // middle path (copy from backing array)
+ {
+ jmethodID bb_has_array = RequireMethod(env, bb_class, "hasArray", "()Z");
+ jmethodID bb_array = RequireMethod(env, bb_class, "array", "()[B");
+ jmethodID bb_array_offset = RequireMethod(env, bb_class, "arrayOffset",
"()I");
+ jboolean has_array = env->CallBooleanMethod(bytebuffer, bb_has_array);
+ if (env->ExceptionCheck()) return nullptr;
+ if (has_array) {
+ jint array_offset = env->CallIntMethod(bytebuffer, bb_array_offset);
+ if (env->ExceptionCheck()) return nullptr;
+
+ auto array =
+ reinterpret_cast<jbyteArray>(env->CallObjectMethod(bytebuffer,
bb_array));
+ if (env->ExceptionCheck()) return nullptr;
+
+ assert(serialized_length <=
static_cast<size_t>(env->GetArrayLength(array)));
+ scratch.resize(serialized_length);
+ env->GetByteArrayRegion(array, array_offset,
static_cast<jsize>(serialized_length),
+ reinterpret_cast<jbyte*>(scratch.data()));
+ return scratch.data();
Review Comment:
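When copying from a ByteBuffer backed by an array, this uses arrayOffset()
but ignores the buffer's current position. arrayOffset() only locates buffer
index 0 within the backing array, so for positioned buffers (including
positioned slices) this will copy from the wrong location; the copy should
start at arrayOffset() + position(). The bounds assert likewise doesn't
account for offset + position.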
##########
java/driver/jni/src/main/cpp/jni_wrapper.cc:
##########
@@ -475,16 +548,33 @@
Java_org_apache_arrow_adbc_driver_jni_impl_NativeAdbc_statementExecuteQuery(
JNIEXPORT void JNICALL
Java_org_apache_arrow_adbc_driver_jni_impl_NativeAdbc_statementSetSqlQuery(
JNIEnv* env, [[maybe_unused]] jclass self, jlong handle, jstring query) {
+ struct AdbcError error = ADBC_ERROR_INIT;
+ auto* ptr = reinterpret_cast<struct
AdbcStatement*>(static_cast<uintptr_t>(handle));
try {
- struct AdbcError error = ADBC_ERROR_INIT;
- auto* ptr = reinterpret_cast<struct
AdbcStatement*>(static_cast<uintptr_t>(handle));
JniStringView query_str(env, query);
CHECK_ADBC_ERROR(AdbcStatementSetSqlQuery(ptr, query_str.value, &error),
error);
} catch (const AdbcException& e) {
e.ThrowJavaException(env);
}
}
+JNIEXPORT void JNICALL
+Java_org_apache_arrow_adbc_driver_jni_impl_NativeAdbc_statementSetSubstraitPlan(
+ JNIEnv* env, [[maybe_unused]] jclass self, jlong handle, jobject plan) {
+ struct AdbcError error = ADBC_ERROR_INIT;
+ auto* ptr = reinterpret_cast<struct
AdbcStatement*>(static_cast<uintptr_t>(handle));
+ std::vector<uint8_t> allocated_plan;
+ size_t plan_length = 0;
+ try {
+ const uint8_t* plan_ptr = GetJniByteBuffer(env, plan, allocated_plan,
plan_length);
+ assert(plan_ptr != nullptr);
+ CHECK_ADBC_ERROR(AdbcStatementSetSubstraitPlan(ptr, plan_ptr, plan_length,
&error),
+ error);
Review Comment:
GetJniByteBuffer can return nullptr after throwing a Java exception (e.g.,
wrong argument type). Asserting non-null and continuing risks passing a null
pointer into AdbcStatementSetSubstraitPlan (asserts may be compiled out),
leading to a crash. Return early when a Java exception is pending / pointer is
null.
##########
java/driver/jni/src/main/cpp/jni_wrapper.cc:
##########
@@ -208,6 +208,79 @@ auto WithJniString(JNIEnv* env, jstring jni_string,
Callable&& callable) {
return callable(view.value);
}
+// Get the contents of a ByteBuffer. We may have to copy, so also take a
+// scratch buffer to hold the copy if needed. If possible, will use
+// DirectByteBuffer so no additional allocation is necessary.
+const uint8_t* GetJniByteBuffer(JNIEnv* env, jobject bytebuffer,
+ std::vector<uint8_t>& scratch,
+ size_t& serialized_length) {
+ jclass bb_class = RequireClass(env, "java/nio/ByteBuffer");
+ jmethodID bb_remaining = RequireMethod(env, bb_class, "remaining", "()I");
+
+ if (!env->IsInstanceOf(bytebuffer, bb_class)) {
+ ThrowJavaException(env, "java/lang/IllegalArgumentException",
+ "Argument must be a ByteBuffer");
+ return nullptr;
+ }
+
+ jint remaining = env->CallIntMethod(bytebuffer, bb_remaining);
+ if (remaining < 0) {
+ ThrowJavaException(env, "java/lang/IllegalArgumentException",
+ "ByteBuffer remaining() must be non-negative");
+ return nullptr;
+ }
+ serialized_length = static_cast<size_t>(remaining);
+
+ // fast path (if direct buffer)
+ void* buf = env->GetDirectBufferAddress(bytebuffer);
+ if (buf) {
+ return static_cast<const uint8_t*>(buf);
+ }
+
+ // middle path (copy from backing array)
+ {
+ jmethodID bb_has_array = RequireMethod(env, bb_class, "hasArray", "()Z");
+ jmethodID bb_array = RequireMethod(env, bb_class, "array", "()[B");
+ jmethodID bb_array_offset = RequireMethod(env, bb_class, "arrayOffset",
"()I");
+ jboolean has_array = env->CallBooleanMethod(bytebuffer, bb_has_array);
+ if (env->ExceptionCheck()) return nullptr;
+ if (has_array) {
+ jint array_offset = env->CallIntMethod(bytebuffer, bb_array_offset);
+ if (env->ExceptionCheck()) return nullptr;
+
+ auto array =
+ reinterpret_cast<jbyteArray>(env->CallObjectMethod(bytebuffer,
bb_array));
+ if (env->ExceptionCheck()) return nullptr;
+
+ assert(serialized_length <=
static_cast<size_t>(env->GetArrayLength(array)));
+ scratch.resize(serialized_length);
+ env->GetByteArrayRegion(array, array_offset,
static_cast<jsize>(serialized_length),
+ reinterpret_cast<jbyte*>(scratch.data()));
+ return scratch.data();
+ }
+ }
+
+ // slow path (copy via invoking Java code to copy into a temp array)
+ {
+ jmethodID bb_get = RequireMethod(env, bb_class, "get",
"([B)Ljava/nio/ByteBuffer;");
+ jbyteArray temp = env->NewByteArray(static_cast<jsize>(serialized_length));
+ if (!temp) {
+ ThrowJavaException(env, "java/lang/OutOfMemoryError",
+ "Failed to allocate byte array to copy ByteBuffer");
+ return nullptr;
+ }
+
+ env->CallVoidMethod(bytebuffer, bb_get, temp);
Review Comment:
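The method ID for ByteBuffer.get(byte[]) has a non-void return type, but the
JNI call uses CallVoidMethod. Using the wrong Call*Method variant is undefined
behavior in JNI and can crash on some JVMs. Since get(byte[]) returns a
ByteBuffer, call it with CallObjectMethod and ignore the returned reference.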
##########
java/driver/jni-validation/src/test/java/org/apache/arrow/adbc/driver/jni/FlightSqlIntegrationTest.java:
##########
@@ -69,30 +83,106 @@ void afterEach() throws Exception {
@Test
void simple() throws Exception {
try (var stmt = conn.createStatement()) {
+ // smoke test; the server doesn't actually run any query and returns a
fixed result
stmt.setSqlQuery("SELECT 1 + 1 AS sum");
try (var reader = stmt.executeQuery()) {
- assertThat(reader.getReader().loadNextBatch()).isTrue();
-
assertThat(reader.getReader().getVectorSchemaRoot().getVector("sum").getObject(0))
- .isEqualTo(2L);
+ assertThat(ArrowToJava.toIntegers(reader.getReader(),
"ints")).containsExactly(5);
}
}
}
@Test
- void partitioned() throws Exception {
+ void substrait() throws Exception {
try (var stmt = conn.createStatement()) {
- stmt.setSqlQuery("SELECT 1 + 1 AS sum");
- var partitions = stmt.executePartitioned();
- assertThat(partitions.getPartitionDescriptors().size()).isEqualTo(1);
- assertThat(partitions.getAffectedRows()).isEqualTo(-1);
- // The test server doesn't give a schema.
- assertThat(partitions.getSchema()).isNull();
-
- try (var reader =
-
conn.readPartition(partitions.getPartitionDescriptors().get(0).getDescriptor()))
{
- assertThat(reader.loadNextBatch()).isTrue();
-
assertThat(reader.getVectorSchemaRoot().getVector("sum").getObject(0)).isEqualTo(2L);
+ stmt.setSubstraitPlan(ByteBuffer.wrap(new byte[] {42, 0, (byte) 129,
(byte) 255}));
+ try (var reader = stmt.executeQuery()) {
+ assertThat(ArrowToJava.toIntegers(reader.getReader(),
"substrait")).containsExactly(5);
}
}
}
+
+ @Test
+ void poll() throws Exception {
+ var expectedSchema = new Schema(List.of(Field.nullable("ints",
Types.MinorType.INT.getType())));
+ try (var stmt = conn.createStatement()) {
+ stmt.setSqlQuery("poll");
+ stmt.prepare();
+ var iter = stmt.pollPartitioned();
+ var elements = new ArrayList<AdbcStatement.PartitionResult>();
+ iter.forEachRemaining(elements::add);
+ assertThat(elements).size().isEqualTo(5);
+ assertThat(elements.get(0).getSchema()).isNull();
+ assertThat(elements.get(1).getSchema()).isNull();
+
ArrowAssertions.assertSchema(elements.get(2).getSchema()).isEqualTo(expectedSchema);
+
ArrowAssertions.assertSchema(elements.get(3).getSchema()).isEqualTo(expectedSchema);
+
ArrowAssertions.assertSchema(elements.get(4).getSchema()).isEqualTo(expectedSchema);
+
+ elements.forEach(
+ partitionResult -> {
+
assertThat(partitionResult.getPartitionDescriptors()).size().isEqualTo(1);
+ });
+ }
+ }
+
+ @Test
+ void progress() throws Exception {
+ try (var stmt = conn.createStatement()) {
+ stmt.setSqlQuery("poll");
+ stmt.prepare();
+ var iter = stmt.pollPartitioned();
+ assertThat(stmt.getProgress()).isEqualTo(0.0);
+ assertThat(stmt.getMaxProgress()).isEqualTo(1.0);
+
+ for (int i = 0; i < 5; i++) {
+ assertThat(iter.hasNext()).isTrue();
+ assertThat(stmt.getProgress()).isCloseTo(0.2 * (i + 1),
Assertions.offset(0.05));
+ assertThat(stmt.getMaxProgress()).isEqualTo(1.0);
+ assertThat(iter.next()).isNotNull();
+ }
+
+ assertThat(iter.hasNext()).isFalse();
+ }
+ }
+
+ @Test
+ void pollError() throws Exception {
+ try (var stmt = conn.createStatement()) {
+ stmt.setSqlQuery("error_poll_later");
+ stmt.prepare();
+ var iter = stmt.pollPartitioned();
+ assertThat(iter.hasNext()).isTrue();
+ assertThat(iter.next()).isNotNull();
+ assertThatThrownBy(iter::hasNext).hasMessageContaining("expected error");
+ }
+ }
+
+ @Test
+ void pollForever() throws Exception {
+ try (var stmt = conn.createStatement()) {
+ stmt.setSqlQuery("forever");
+ stmt.prepare();
+ var iter = stmt.pollPartitioned();
+ assertThat(iter.hasNext()).isTrue();
+ assertThat(iter.next()).isNotNull();
+ // XXX: rather janky. must cancel on the background; cancelling in
between statements has no
+ // effect. we should perhaps reconsider this behavior
+ var t =
+ new Thread(
+ () -> {
+ try {
+ while (stmt.getProgress() < 0.05) {
+ Thread.sleep(100);
+ }
+ System.out.println("progress: " + stmt.getProgress());
+ stmt.cancel();
+ } catch (Exception e) {
+ e.printStackTrace();
+ }
Review Comment:
This background thread calls stmt.getProgress() while the main thread is
concurrently polling the statement. Per AdbcStatement's thread-safety contract,
only cancel() is required to be thread-safe, so this can lead to flaky tests
depending on the driver implementation. Prefer waiting via sleep (or other
external signal) and only calling cancel() from the background thread.
##########
java/driver/jni/src/main/cpp/jni_wrapper.cc:
##########
@@ -1059,81 +1149,15 @@
Java_org_apache_arrow_adbc_driver_jni_impl_NativeAdbc_connectionReadPartition(
auto* conn = reinterpret_cast<struct
AdbcConnection*>(static_cast<uintptr_t>(handle));
struct ArrowArrayStream out = {};
size_t serialized_length = 0;
- const uint8_t* serialized_partition = nullptr;
std::vector<uint8_t> allocated_partition;
try {
- jclass bb_class = RequireClass(env, "java/nio/ByteBuffer");
- jmethodID bb_remaining = RequireMethod(env, bb_class, "remaining", "()I");
-
- if (!env->IsInstanceOf(partition, bb_class)) {
- ThrowJavaException(env, "java/lang/IllegalArgumentException",
- "Partition must be a ByteBuffer");
- return nullptr;
- }
- jint remaining = env->CallIntMethod(partition, bb_remaining);
- if (remaining < 0) {
- ThrowJavaException(env, "java/lang/IllegalArgumentException",
- "ByteBuffer remaining() must be non-negative");
- return nullptr;
- }
- serialized_length = static_cast<size_t>(remaining);
-
- // fast path (if direct buffer)
- void* buf = env->GetDirectBufferAddress(partition);
- if (buf) {
- serialized_partition = static_cast<const uint8_t*>(buf);
- }
-
- // middle path (backing array)
- if (!serialized_partition) {
- jmethodID bb_has_array = RequireMethod(env, bb_class, "hasArray", "()Z");
- jmethodID bb_array = RequireMethod(env, bb_class, "array", "()[B");
- jmethodID bb_array_offset = RequireMethod(env, bb_class, "arrayOffset",
"()I");
- jboolean has_array = env->CallBooleanMethod(partition, bb_has_array);
- if (env->ExceptionCheck()) return nullptr;
- if (has_array) {
- jint array_offset = env->CallIntMethod(partition, bb_array_offset);
- if (env->ExceptionCheck()) return nullptr;
-
- auto array =
- reinterpret_cast<jbyteArray>(env->CallObjectMethod(partition,
bb_array));
- if (env->ExceptionCheck()) return nullptr;
-
- assert(serialized_length <=
static_cast<size_t>(env->GetArrayLength(array)));
- allocated_partition.resize(serialized_length);
- env->GetByteArrayRegion(array, array_offset,
- static_cast<jsize>(serialized_length),
-
reinterpret_cast<jbyte*>(allocated_partition.data()));
- serialized_partition = allocated_partition.data();
- }
- }
-
- // slow path (copy)
- if (!serialized_partition) {
- jmethodID bb_get = RequireMethod(env, bb_class, "get",
"([B)Ljava/nio/ByteBuffer;");
- jbyteArray temp =
env->NewByteArray(static_cast<jsize>(serialized_length));
- if (!temp) {
- ThrowJavaException(env, "java/lang/OutOfMemoryError",
- "Failed to allocate byte array for partition");
- return nullptr;
- }
-
- env->CallVoidMethod(partition, bb_get, temp);
- if (env->ExceptionCheck()) return nullptr;
-
- allocated_partition.resize(serialized_length);
- env->GetByteArrayRegion(temp, 0, static_cast<jsize>(serialized_length),
-
reinterpret_cast<jbyte*>(allocated_partition.data()));
- serialized_partition = allocated_partition.data();
- }
-
+ const uint8_t* serialized_partition =
+ GetJniByteBuffer(env, partition, allocated_partition,
serialized_length);
assert(serialized_partition != nullptr);
-
CHECK_ADBC_ERROR(AdbcConnectionReadPartition(conn, serialized_partition,
serialized_length, &out,
&error),
Review Comment:
GetJniByteBuffer can return nullptr after throwing a Java exception (e.g.,
wrong argument type). Asserting non-null and continuing risks passing a null
pointer into AdbcConnectionReadPartition (asserts may be compiled out), leading
to a crash. Return early when a Java exception is pending / pointer is null.
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]