This is an automated email from the ASF dual-hosted git repository.
lidavidm pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow.git
The following commit(s) were added to refs/heads/main by this push:
new e6b05a8e58 GH-43483: [Java][C++] Support more CsvFragmentScanOptions
in JNI call (#43482)
e6b05a8e58 is described below
commit e6b05a8e58ad5c20f6daa2fa64488b8efec51993
Author: Jin Chengcheng <[email protected]>
AuthorDate: Tue Aug 6 10:12:56 2024 +0800
GH-43483: [Java][C++] Support more CsvFragmentScanOptions in JNI call
(#43482)
### Rationale for this change
Support more CSV fragment scan options
### What changes are included in this PR?
Implement nearly all of the CsvFragmentScanOptions supported by the C++ code.
### Are these changes tested?
Yes, by newly added tests and the existing unit tests.
### Are there any user-facing changes?
No.
* GitHub Issue: #28866
* GitHub Issue: #43483
Authored-by: Chengcheng Jin <[email protected]>
Signed-off-by: David Li <[email protected]>
---
java/dataset/src/main/cpp/jni_wrapper.cc | 130 ++++++++++++---
.../scanner/csv/CsvFragmentScanOptions.java | 7 +
.../arrow/dataset/TestFragmentScanOptions.java | 159 +++++++++++++++++++++
3 files changed, 277 insertions(+), 19 deletions(-)
diff --git a/java/dataset/src/main/cpp/jni_wrapper.cc b/java/dataset/src/main/cpp/jni_wrapper.cc
index f324f87d6c..63b8dd73f4 100644
--- a/java/dataset/src/main/cpp/jni_wrapper.cc
+++ b/java/dataset/src/main/cpp/jni_wrapper.cc
@@ -368,29 +368,121 @@ std::shared_ptr<arrow::Buffer> LoadArrowBufferFromByteBuffer(JNIEnv* env, jobjec
 
 inline bool ParseBool(const std::string& value) { return value == "true" ? true : false; }
 
+/// \brief Return the single character held in `value`.
+///
+/// Reports an error through JniThrow if `value` is not exactly one character long.
+inline char ParseChar(const std::string& key, const std::string& value) {
+  if (value.size() != 1) {
+    JniThrow("Option " + key + " should be a char, but is " + value);
+  }
+  return value.at(0);
+}
+
 /// \brief Construct FragmentScanOptions from config map
 #ifdef ARROW_CSV
-arrow::Result<std::shared_ptr<arrow::dataset::FragmentScanOptions>>
-ToCsvFragmentScanOptions(const std::unordered_map<std::string, std::string>& configs) {
+
+/// \brief Apply one ConvertOptions setting.
+///
+/// Returns false if `key` is not a supported ConvertOptions field.
+bool SetCsvConvertOptions(arrow::csv::ConvertOptions& options, const std::string& key,
+                          const std::string& value) {
+  if (key == "column_types") {
+    // `value` is the address of an ArrowSchema (C data interface). std::stoll is used
+    // because `long` (std::stol) is only 32 bits wide on LLP64 platforms like Windows.
+    int64_t schema_address = std::stoll(value);
+    ArrowSchema* c_schema = reinterpret_cast<ArrowSchema*>(schema_address);
+    auto schema = JniGetOrThrow(arrow::ImportSchema(c_schema));
+    auto& column_types = options.column_types;
+    for (const auto& field : schema->fields()) {
+      column_types[field->name()] = field->type();
+    }
+  } else if (key == "strings_can_be_null") {
+    options.strings_can_be_null = ParseBool(value);
+  } else if (key == "check_utf8") {
+    options.check_utf8 = ParseBool(value);
+  } else if (key == "null_values") {
+    // Multi-valued (std::vector) options currently accept a single value only.
+    options.null_values = {value};
+  } else if (key == "true_values") {
+    options.true_values = {value};
+  } else if (key == "false_values") {
+    options.false_values = {value};
+  } else if (key == "quoted_strings_can_be_null") {
+    options.quoted_strings_can_be_null = ParseBool(value);
+  } else if (key == "auto_dict_encode") {
+    options.auto_dict_encode = ParseBool(value);
+  } else if (key == "auto_dict_max_cardinality") {
+    options.auto_dict_max_cardinality = std::stoi(value);
+  } else if (key == "decimal_point") {
+    options.decimal_point = ParseChar(key, value);
+  } else if (key == "include_missing_columns") {
+    options.include_missing_columns = ParseBool(value);
+  } else {
+    return false;
+  }
+  return true;
+}
+
+/// \brief Apply one ParseOptions setting.
+///
+/// Returns false if `key` is not a supported ParseOptions field.
+bool SetCsvParseOptions(arrow::csv::ParseOptions& options, const std::string& key,
+                        const std::string& value) {
+  if (key == "delimiter") {
+    options.delimiter = ParseChar(key, value);
+  } else if (key == "quoting") {
+    options.quoting = ParseBool(value);
+  } else if (key == "quote_char") {
+    options.quote_char = ParseChar(key, value);
+  } else if (key == "double_quote") {
+    options.double_quote = ParseBool(value);
+  } else if (key == "escaping") {
+    options.escaping = ParseBool(value);
+  } else if (key == "escape_char") {
+    options.escape_char = ParseChar(key, value);
+  } else if (key == "newlines_in_values") {
+    options.newlines_in_values = ParseBool(value);
+  } else if (key == "ignore_empty_lines") {
+    options.ignore_empty_lines = ParseBool(value);
+  } else {
+    return false;
+  }
+  return true;
+}
+
+/// \brief Apply one ReadOptions setting.
+///
+/// Returns false if `key` is not a supported ReadOptions field.
+bool SetCsvReadOptions(arrow::csv::ReadOptions& options, const std::string& key,
+                       const std::string& value) {
+  if (key == "use_threads") {
+    options.use_threads = ParseBool(value);
+  } else if (key == "block_size") {
+    options.block_size = std::stoi(value);
+  } else if (key == "skip_rows") {
+    options.skip_rows = std::stoi(value);
+  } else if (key == "skip_rows_after_names") {
+    options.skip_rows_after_names = std::stoi(value);
+  } else if (key == "autogenerate_column_names") {
+    options.autogenerate_column_names = ParseBool(value);
+  } else {
+    return false;
+  }
+  return true;
+}
+
+std::shared_ptr<arrow::dataset::FragmentScanOptions> ToCsvFragmentScanOptions(
+    const std::unordered_map<std::string, std::string>& configs) {
   std::shared_ptr<arrow::dataset::CsvFragmentScanOptions> options =
       std::make_shared<arrow::dataset::CsvFragmentScanOptions>();
-  for (auto const& [key, value] : configs) {
-    if (key == "delimiter") {
-      options->parse_options.delimiter = value.data()[0];
-    } else if (key == "quoting") {
-      options->parse_options.quoting = ParseBool(value);
-    } else if (key == "column_types") {
-      int64_t schema_address = std::stol(value);
-      ArrowSchema* c_schema = reinterpret_cast<ArrowSchema*>(schema_address);
-      ARROW_ASSIGN_OR_RAISE(auto schema, arrow::ImportSchema(c_schema));
-      auto& column_types = options->convert_options.column_types;
-      for (auto field : schema->fields()) {
-        column_types[field->name()] = field->type();
-      }
-    } else if (key == "strings_can_be_null") {
-      options->convert_options.strings_can_be_null = ParseBool(value);
-    } else {
-      return arrow::Status::Invalid("Config " + key + " is not supported.");
+  for (const auto& [key, value] : configs) {
+    // Offer the key to the parse, convert and read setters in turn; reject it if none
+    // of them recognizes it.
+    bool set_valid = SetCsvParseOptions(options->parse_options, key, value) ||
+                     SetCsvConvertOptions(options->convert_options, key, value) ||
+                     SetCsvReadOptions(options->read_options, key, value);
+    if (!set_valid) {
+      JniThrow("Config " + key + " is not supported.");
     }
   }
   return options;
diff --git a/java/dataset/src/main/java/org/apache/arrow/dataset/scanner/csv/CsvFragmentScanOptions.java b/java/dataset/src/main/java/org/apache/arrow/dataset/scanner/csv/CsvFragmentScanOptions.java
index 39271b5f06..dddc36d387 100644
--- a/java/dataset/src/main/java/org/apache/arrow/dataset/scanner/csv/CsvFragmentScanOptions.java
+++ b/java/dataset/src/main/java/org/apache/arrow/dataset/scanner/csv/CsvFragmentScanOptions.java
@@ -32,6 +32,13 @@ public class CsvFragmentScanOptions implements FragmentScanOptions {
    * CSV scan options, map to CPP struct CsvFragmentScanOptions. The key in config map is the field
    * name of mapping cpp struct
    *
+   * <p>Currently, multi-valued options (which are std::vector values in C++) only support having a
+   * single value set. For example, for the null_values option, only one string can be set as the
+   * null value.
+   *
+   * <p>A boolean option is true only if its value is exactly "true"; any other value, including
+   * "True", is treated as false. A character option must be exactly one character long.
+   *
    * @param convertOptions similar to CsvFragmentScanOptions#convert_options in CPP, the ArrowSchema
    *     represents column_types, convert data option such as null value recognition.
    * @param readOptions similar to CsvFragmentScanOptions#read_options in CPP, specify how to read
diff --git a/java/dataset/src/test/java/org/apache/arrow/dataset/TestFragmentScanOptions.java b/java/dataset/src/test/java/org/apache/arrow/dataset/TestFragmentScanOptions.java
index 9787e8308e..d598190528 100644
--- a/java/dataset/src/test/java/org/apache/arrow/dataset/TestFragmentScanOptions.java
+++ b/java/dataset/src/test/java/org/apache/arrow/dataset/TestFragmentScanOptions.java
@@ -18,10 +18,13 @@ package org.apache.arrow.dataset;
 
 import static org.hamcrest.MatcherAssert.assertThat;
 import static org.junit.jupiter.api.Assertions.assertEquals;
+import static org.junit.jupiter.api.Assertions.assertNotNull;
+import static org.junit.jupiter.api.Assertions.assertThrows;
 
 import com.google.common.collect.ImmutableMap;
 import java.util.Arrays;
 import java.util.Collections;
+import java.util.Map;
 import java.util.Optional;
 import org.apache.arrow.c.ArrowSchema;
 import org.apache.arrow.c.CDataDictionaryProvider;
@@ -42,6 +45,7 @@ import org.apache.arrow.vector.ipc.ArrowReader;
 import org.apache.arrow.vector.types.pojo.ArrowType;
 import org.apache.arrow.vector.types.pojo.Field;
 import org.apache.arrow.vector.types.pojo.Schema;
+import org.apache.arrow.vector.util.Text;
 import org.hamcrest.collection.IsIterableContainingInOrder;
 import org.junit.jupiter.api.Test;
 
@@ -165,4 +169,159 @@ public class TestFragmentScanOptions {
       assertEquals(3, rowCount);
     }
   }
+
+  @Test
+  public void testCsvReadParseAndReadOptions() throws Exception {
+    final Schema schema =
+        new Schema(
+            Arrays.asList(
+                Field.nullable("Id", new ArrowType.Int(64, true)),
+                Field.nullable("Name", new ArrowType.Utf8()),
+                Field.nullable("Language", new ArrowType.Utf8())),
+            null);
+    String path = "file://" + getClass().getResource("/").getPath() + "/data/student.csv";
+    BufferAllocator allocator = new RootAllocator(Long.MAX_VALUE);
+    // ";" splits each line into three columns; the first row after the header is skipped.
+    CsvFragmentScanOptions fragmentScanOptions =
+        new CsvFragmentScanOptions(
+            new CsvConvertOptions(ImmutableMap.of()),
+            ImmutableMap.of("skip_rows_after_names", "1"),
+            ImmutableMap.of("delimiter", ";"));
+    ScanOptions options =
+        new ScanOptions.Builder(/*batchSize*/ 32768)
+            .columns(Optional.empty())
+            .fragmentScanOptions(fragmentScanOptions)
+            .build();
+    try (DatasetFactory datasetFactory =
+            new FileSystemDatasetFactory(
+                allocator,
+                NativeMemoryPool.getDefault(),
+                FileFormat.CSV,
+                path,
+                Optional.of(fragmentScanOptions));
+        Dataset dataset = datasetFactory.finish();
+        Scanner scanner = dataset.newScan(options);
+        ArrowReader reader = scanner.scanBatches()) {
+
+      assertEquals(schema.getFields(), reader.getVectorSchemaRoot().getSchema().getFields());
+      int rowCount = 0;
+      while (reader.loadNextBatch()) {
+        final ValueIterableVector<Text> nameVector =
+            (ValueIterableVector<Text>) reader.getVectorSchemaRoot().getVector("Name");
+        assertThat(
+            nameVector.getValueIterable(),
+            IsIterableContainingInOrder.contains(new Text("Peter"), new Text("Celin")));
+        rowCount += reader.getVectorSchemaRoot().getRowCount();
+      }
+      assertEquals(2, rowCount);
+    }
+  }
+
+  @Test
+  public void testCsvReadOtherOptions() throws Exception {
+    String path = "file://" + getClass().getResource("/").getPath() + "/data/student.csv";
+    BufferAllocator allocator = new RootAllocator(Long.MAX_VALUE);
+    Map<String, String> convertOption =
+        ImmutableMap.of(
+            "check_utf8",
+            "true",
+            "null_values",
+            "NULL",
+            "true_values",
+            "True",
+            "false_values",
+            "False",
+            "quoted_strings_can_be_null",
+            "true",
+            "auto_dict_encode",
+            "false",
+            "auto_dict_max_cardinality",
+            "3456",
+            "decimal_point",
+            ".",
+            "include_missing_columns",
+            "false");
+    Map<String, String> readOption =
+        ImmutableMap.of(
+            "use_threads",
+            "true",
+            "block_size",
+            "1024",
+            "skip_rows",
+            "12",
+            "skip_rows_after_names",
+            "12",
+            "autogenerate_column_names",
+            "false");
+    Map<String, String> parseOption =
+        ImmutableMap.of(
+            "delimiter",
+            ".",
+            "quoting",
+            "true",
+            "quote_char",
+            "'",
+            "double_quote",
+            "false", // boolean options only recognize lowercase "true"
+            "escaping",
+            "true",
+            "escape_char",
+            "v",
+            "newlines_in_values",
+            "false",
+            "ignore_empty_lines",
+            "true");
+    CsvFragmentScanOptions fragmentScanOptions =
+        new CsvFragmentScanOptions(new CsvConvertOptions(convertOption), readOption, parseOption);
+    ScanOptions options =
+        new ScanOptions.Builder(/*batchSize*/ 32768)
+            .columns(Optional.empty())
+            .fragmentScanOptions(fragmentScanOptions)
+            .build();
+    try (DatasetFactory datasetFactory =
+            new FileSystemDatasetFactory(
+                allocator, NativeMemoryPool.getDefault(), FileFormat.CSV, path);
+        Dataset dataset = datasetFactory.finish();
+        Scanner scanner = dataset.newScan(options)) {
+      assertNotNull(scanner);
+    }
+  }
+
+  @Test
+  public void testCsvInvalidOption() throws Exception {
+    String path = "file://" + getClass().getResource("/").getPath() + "/data/student.csv";
+    BufferAllocator allocator = new RootAllocator(Long.MAX_VALUE);
+    Map<String, String> convertOption = ImmutableMap.of("not_exists_key_check_utf8", "true");
+    CsvFragmentScanOptions fragmentScanOptions =
+        new CsvFragmentScanOptions(
+            new CsvConvertOptions(convertOption), ImmutableMap.of(), ImmutableMap.of());
+    ScanOptions options =
+        new ScanOptions.Builder(/*batchSize*/ 32768)
+            .columns(Optional.empty())
+            .fragmentScanOptions(fragmentScanOptions)
+            .build();
+    try (DatasetFactory datasetFactory =
+            new FileSystemDatasetFactory(
+                allocator, NativeMemoryPool.getDefault(), FileFormat.CSV, path);
+        Dataset dataset = datasetFactory.finish()) {
+      assertThrows(RuntimeException.class, () -> dataset.newScan(options));
+    }
+
+    CsvFragmentScanOptions fragmentScanOptionsFaultValue =
+        new CsvFragmentScanOptions(
+            new CsvConvertOptions(ImmutableMap.of()),
+            ImmutableMap.of(), // no unsupported keys: only the malformed escape_char can fail
+            ImmutableMap.of("escape_char", "vbvb"));
+    ScanOptions optionsFault =
+        new ScanOptions.Builder(/*batchSize*/ 32768)
+            .columns(Optional.empty())
+            .fragmentScanOptions(fragmentScanOptionsFaultValue)
+            .build();
+    try (DatasetFactory datasetFactory =
+            new FileSystemDatasetFactory(
+                allocator, NativeMemoryPool.getDefault(), FileFormat.CSV, path);
+        Dataset dataset = datasetFactory.finish()) {
+      assertThrows(RuntimeException.class, () -> dataset.newScan(optionsFault));
+    }
+  }
 }