zhztheplayer commented on a change in pull request #7030: URL: https://github.com/apache/arrow/pull/7030#discussion_r429553399
########## File path: cpp/src/jni/dataset/jni_wrapper.cpp ########## @@ -0,0 +1,517 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include <arrow/dataset/api.h> +#include <arrow/dataset/file_base.h> +#include <arrow/filesystem/hdfs.h> +#include <arrow/filesystem/localfs.h> +#include <arrow/io/api.h> +#include <arrow/ipc/api.h> +#include <arrow/util/iterator.h> + +#include "org_apache_arrow_dataset_file_JniWrapper.h" +#include "org_apache_arrow_dataset_jni_JniWrapper.h" + +static jclass illegal_access_exception_class; +static jclass illegal_argument_exception_class; +static jclass runtime_exception_class; + +static jclass record_batch_handle_class; +static jclass record_batch_handle_field_class; +static jclass record_batch_handle_buffer_class; + +static jmethodID record_batch_handle_constructor; +static jmethodID record_batch_handle_field_constructor; +static jmethodID record_batch_handle_buffer_constructor; + +static jint JNI_VERSION = JNI_VERSION_1_6; + +class JniPendingException : public std::runtime_error { + public: + explicit JniPendingException(const std::string& arg) : runtime_error(arg) {} +}; + +void ThrowPendingException(const std::string& message) { + throw JniPendingException(message); +} + 
+template <typename T> +T JniGetOrThrow(arrow::Result<T> result) { + if (!result.status().ok()) { + ThrowPendingException(result.status().message()); + } + return std::move(result).ValueOrDie(); +} + +void JniAssertOkOrThrow(arrow::Status status) { + if (!status.ok()) { + ThrowPendingException(status.message()); + } +} + +void JniThrow(std::string message) { ThrowPendingException(message); } + +#define JNI_METHOD_START try { +// macro ended + +#define JNI_METHOD_END(fallback_expr) \ + } \ + catch (JniPendingException & e) { \ + env->ThrowNew(runtime_exception_class, e.what()); \ + return fallback_expr; \ + } +// macro ended + +jclass CreateGlobalClassReference(JNIEnv* env, const char* class_name) { + jclass local_class = env->FindClass(class_name); + jclass global_class = (jclass)env->NewGlobalRef(local_class); + env->DeleteLocalRef(local_class); + return global_class; +} + +arrow::Result<jmethodID> GetMethodID(JNIEnv* env, jclass this_class, const char* name, + const char* sig) { + jmethodID ret = env->GetMethodID(this_class, name, sig); + if (ret == nullptr) { + std::string error_message = "Unable to find method " + std::string(name) + + " within signature" + std::string(sig); + return arrow::Status::Invalid(error_message); + } + return ret; +} + +jint JNI_OnLoad(JavaVM* vm, void* reserved) { + JNIEnv* env; + if (vm->GetEnv(reinterpret_cast<void**>(&env), JNI_VERSION) != JNI_OK) { + return JNI_ERR; + } + JNI_METHOD_START + illegal_access_exception_class = + CreateGlobalClassReference(env, "Ljava/lang/IllegalAccessException;"); + illegal_argument_exception_class = + CreateGlobalClassReference(env, "Ljava/lang/IllegalArgumentException;"); + runtime_exception_class = + CreateGlobalClassReference(env, "Ljava/lang/RuntimeException;"); + + record_batch_handle_class = + CreateGlobalClassReference(env, + "Lorg/apache/arrow/" + "dataset/jni/NativeRecordBatchHandle;"); + record_batch_handle_field_class = + CreateGlobalClassReference(env, + "Lorg/apache/arrow/" + 
"dataset/jni/NativeRecordBatchHandle$Field;"); + record_batch_handle_buffer_class = + CreateGlobalClassReference(env, + "Lorg/apache/arrow/" + "dataset/jni/NativeRecordBatchHandle$Buffer;"); + + record_batch_handle_constructor = + JniGetOrThrow(GetMethodID(env, record_batch_handle_class, "<init>", + "(J[Lorg/apache/arrow/dataset/" + "jni/NativeRecordBatchHandle$Field;" + "[Lorg/apache/arrow/dataset/" + "jni/NativeRecordBatchHandle$Buffer;)V")); + record_batch_handle_field_constructor = + JniGetOrThrow(GetMethodID(env, record_batch_handle_field_class, "<init>", "(JJ)V")); + record_batch_handle_buffer_constructor = JniGetOrThrow( + GetMethodID(env, record_batch_handle_buffer_class, "<init>", "(JJJJ)V")); + + return JNI_VERSION; + JNI_METHOD_END(JNI_ERR) +} + +void JNI_OnUnload(JavaVM* vm, void* reserved) { + JNIEnv* env; + vm->GetEnv(reinterpret_cast<void**>(&env), JNI_VERSION); + env->DeleteGlobalRef(illegal_access_exception_class); + env->DeleteGlobalRef(illegal_argument_exception_class); + env->DeleteGlobalRef(runtime_exception_class); + env->DeleteGlobalRef(record_batch_handle_class); + env->DeleteGlobalRef(record_batch_handle_field_class); + env->DeleteGlobalRef(record_batch_handle_buffer_class); +} + +std::shared_ptr<arrow::Schema> SchemaFromColumnNames( + const std::shared_ptr<arrow::Schema>& input, + const std::vector<std::string>& column_names) { + std::vector<std::shared_ptr<arrow::Field>> columns; + for (const auto& name : column_names) { + columns.push_back(input->GetFieldByName(name)); + } + return std::make_shared<arrow::Schema>(columns); +} + +arrow::Result<std::shared_ptr<arrow::dataset::FileFormat>> GetFileFormat(jint id) { + switch (id) { + case 0: + return std::make_shared<arrow::dataset::ParquetFileFormat>(); + default: + std::string error_message = "illegal file format id: " + std::to_string(id); + return arrow::Status::Invalid(error_message); + } +} + +arrow::Result<std::shared_ptr<arrow::fs::FileSystem>> GetFileSystem( + jint id, std::string 
path, std::string* out_path) { + switch (id) { + case 0: + *out_path = path; + return std::make_shared<arrow::fs::LocalFileSystem>(); + case 1: { + ARROW_ASSIGN_OR_RAISE(std::shared_ptr<arrow::fs::FileSystem> ret, + arrow::fs::FileSystemFromUri(path, out_path)) + return ret; + } + default: + std::string error_message = "illegal file system id: " + std::to_string(id); + return arrow::Status::Invalid(error_message); + } +} + +std::string JStringToCString(JNIEnv* env, jstring string) { + if (string == nullptr) { + return std::string(); + } + jboolean copied; + int32_t length = env->GetStringUTFLength(string); + const char* chars = env->GetStringUTFChars(string, &copied); + std::string str = std::string(chars, length); + // fixme calling ReleaseStringUTFChars if memory leak faced + return str; +} + +std::vector<std::string> ToStringVector(JNIEnv* env, jobjectArray& str_array) { + int length = env->GetArrayLength(str_array); + std::vector<std::string> vector; + for (int i = 0; i < length; i++) { + auto string = (jstring)(env->GetObjectArrayElement(str_array, i)); + vector.push_back(JStringToCString(env, string)); + } + return vector; +} + +template <typename T> +jlong CreateNativeRef(std::shared_ptr<T> t) { + std::shared_ptr<T>* retained_ptr = new std::shared_ptr<T>(t); + return reinterpret_cast<jlong>(retained_ptr); +} + +template <typename T> +std::shared_ptr<T> RetrieveNativeInstance(jlong ref) { + std::shared_ptr<T>* retrieved_ptr = reinterpret_cast<std::shared_ptr<T>*>(ref); + return *retrieved_ptr; +} + +template <typename T> +void ReleaseNativeRef(jlong ref) { + std::shared_ptr<T>* retrieved_ptr = reinterpret_cast<std::shared_ptr<T>*>(ref); + delete retrieved_ptr; +} + +arrow::Result<jbyteArray> ToSchemaByteArray(JNIEnv* env, + std::shared_ptr<arrow::Schema> schema) { + ARROW_ASSIGN_OR_RAISE( + std::shared_ptr<arrow::Buffer> buffer, + arrow::ipc::SerializeSchema(*schema, nullptr, arrow::default_memory_pool())) + + jbyteArray out = 
env->NewByteArray(buffer->size()); + auto src = reinterpret_cast<const jbyte*>(buffer->data()); + env->SetByteArrayRegion(out, 0, buffer->size(), src); + return out; +} + +arrow::Result<std::shared_ptr<arrow::Schema>> FromSchemaByteArray( + JNIEnv* env, jbyteArray schemaBytes) { + arrow::ipc::DictionaryMemo in_memo; + int schemaBytes_len = env->GetArrayLength(schemaBytes); + jbyte* schemaBytes_data = env->GetByteArrayElements(schemaBytes, nullptr); + auto serialized_schema = std::make_shared<arrow::Buffer>( + reinterpret_cast<uint8_t*>(schemaBytes_data), schemaBytes_len); + arrow::io::BufferReader buf_reader(serialized_schema); + ARROW_ASSIGN_OR_RAISE(std::shared_ptr<arrow::Schema> schema, + arrow::ipc::ReadSchema(&buf_reader, &in_memo)) + env->ReleaseByteArrayElements(schemaBytes, schemaBytes_data, JNI_ABORT); + return schema; +} + +/* + * Class: org_apache_arrow_dataset_jni_JniWrapper + * Method: closeDatasetFactory + * Signature: (J)V + */ +JNIEXPORT void JNICALL Java_org_apache_arrow_dataset_jni_JniWrapper_closeDatasetFactory( + JNIEnv* env, jobject, jlong id) { + JNI_METHOD_START + ReleaseNativeRef<arrow::dataset::DatasetFactory>(id); + JNI_METHOD_END() +} + +/* + * Class: org_apache_arrow_dataset_jni_JniWrapper + * Method: inspectSchema + * Signature: (J)[B + */ +JNIEXPORT jbyteArray JNICALL Java_org_apache_arrow_dataset_jni_JniWrapper_inspectSchema( + JNIEnv* env, jobject, jlong dataset_factor_id) { + JNI_METHOD_START + std::shared_ptr<arrow::dataset::DatasetFactory> d = + RetrieveNativeInstance<arrow::dataset::DatasetFactory>(dataset_factor_id); + std::shared_ptr<arrow::Schema> schema = JniGetOrThrow(d->Inspect()); + return JniGetOrThrow(ToSchemaByteArray(env, schema)); + JNI_METHOD_END(nullptr) +} + +/* + * Class: org_apache_arrow_dataset_jni_JniWrapper + * Method: createDataset + * Signature: (J[B)J + */ +JNIEXPORT jlong JNICALL Java_org_apache_arrow_dataset_jni_JniWrapper_createDataset( + JNIEnv* env, jobject, jlong dataset_factory_id, jbyteArray 
schema_bytes) { + JNI_METHOD_START + std::shared_ptr<arrow::dataset::DatasetFactory> d = + RetrieveNativeInstance<arrow::dataset::DatasetFactory>(dataset_factory_id); + std::shared_ptr<arrow::Schema> schema; + schema = JniGetOrThrow(FromSchemaByteArray(env, schema_bytes)); + std::shared_ptr<arrow::dataset::Dataset> dataset = JniGetOrThrow(d->Finish(schema)); + return CreateNativeRef(dataset); + JNI_METHOD_END(-1L) +} + +/* + * Class: org_apache_arrow_dataset_jni_JniWrapper + * Method: closeDataset + * Signature: (J)V + */ +JNIEXPORT void JNICALL Java_org_apache_arrow_dataset_jni_JniWrapper_closeDataset( + JNIEnv* env, jobject, jlong id) { + JNI_METHOD_START + ReleaseNativeRef<arrow::dataset::Dataset>(id); + JNI_METHOD_END() +} + +/// \class DisposableScannerAdaptor +/// \brief An adaptor that iterates over a Scanner instance then returns RecordBatches +/// directly. +/// +/// This lessens the complexity of the JNI bridge to make sure it to be easier to +/// maintain. On Java-side, NativeScanner can only produces a single NativeScanTask +/// instance during its whole lifecycle. Each task stands for a DisposableScannerAdaptor +/// instance through JNI bridge. +/// +class DisposableScannerAdaptor { Review comment: So let's just remove this and add a method `Iterator<ScanTask> Scanner.scan()` to the JNI bridge? I have the same concern as you, and for me this code is temporary. I think we wanted to keep the JNI bridge simple enough in the first implementation so that, if the C++ API gets refactored again in the future, we can save the time otherwise spent changing the JNI code to keep the Java side working. ---------------------------------------------------------------- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. For queries about this service, please contact Infrastructure at: us...@infra.apache.org