This is an automated email from the ASF dual-hosted git repository.

colinlee pushed a commit to branch support_arrow_struct
in repository https://gitbox.apache.org/repos/asf/tsfile.git

commit 014580bd8bb9be34d24219becac1e2b196570b2a Author: ColinLee <[email protected]> AuthorDate: Mon Mar 9 09:29:45 2026 +0800 support arrow write. --- cpp/src/cwrapper/arrow_c.cc | 215 +++++++++++++++ cpp/src/cwrapper/tsfile_cwrapper.cc | 23 +- cpp/src/cwrapper/tsfile_cwrapper.h | 5 + cpp/src/writer/tsfile_writer.cc | 8 + cpp/src/writer/tsfile_writer.h | 2 + python/tests/bench_write_arrow_vs_dataframe.py | 230 ++++++++++++++++ python/tests/test_write_arrow.py | 368 +++++++++++++++++++++++++ python/tsfile/tsfile_cpp.pxd | 6 + python/tsfile/tsfile_table_writer.py | 11 + python/tsfile/tsfile_writer.pyx | 37 +++ 10 files changed, 903 insertions(+), 2 deletions(-) diff --git a/cpp/src/cwrapper/arrow_c.cc b/cpp/src/cwrapper/arrow_c.cc index e7f6c2de5..40fdcb639 100644 --- a/cpp/src/cwrapper/arrow_c.cc +++ b/cpp/src/cwrapper/arrow_c.cc @@ -23,6 +23,7 @@ #include <vector> #include "common/allocator/alloc_base.h" +#include "common/tablet.h" #include "common/tsblock/tsblock.h" #include "common/tsblock/tuple_desc.h" #include "common/tsblock/vector/vector.h" @@ -758,4 +759,218 @@ int TsBlockToArrowStruct(common::TsBlock& tsblock, ArrowArray* out_array, return common::E_OK; } +// Convert days since Unix epoch back to YYYYMMDD integer format +static int32_t DaysSinceEpochToYYYYMMDD(int32_t days) { + std::tm epoch = {}; + epoch.tm_year = 70; + epoch.tm_mon = 0; + epoch.tm_mday = 1; + epoch.tm_hour = 12; + epoch.tm_isdst = -1; + time_t epoch_t = mktime(&epoch); + time_t target_t = epoch_t + static_cast<time_t>(days) * 24 * 60 * 60; + std::tm* d = localtime(&target_t); + return (d->tm_year + 1900) * 10000 + (d->tm_mon + 1) * 100 + d->tm_mday; +} + +// Check if Arrow row is valid (non-null) based on validity bitmap +static bool ArrowIsValid(const ArrowArray* arr, int64_t row) { + if (arr->null_count == 0 || arr->buffers[0] == nullptr) return true; + int64_t bit_idx = arr->offset + row; + const uint8_t* bitmap = static_cast<const uint8_t*>(arr->buffers[0]); + return (bitmap[bit_idx 
/ 8] >> (bit_idx % 8)) & 1; +} + +// Map Arrow format string to TSDataType +static common::TSDataType ArrowFormatToDataType(const char* format) { + if (strcmp(format, "b") == 0) return common::BOOLEAN; + if (strcmp(format, "i") == 0) return common::INT32; + if (strcmp(format, "l") == 0) return common::INT64; + if (strcmp(format, "tsn:") == 0) return common::TIMESTAMP; + if (strcmp(format, "f") == 0) return common::FLOAT; + if (strcmp(format, "g") == 0) return common::DOUBLE; + if (strcmp(format, "u") == 0) return common::TEXT; + if (strcmp(format, "tdD") == 0) return common::DATE; + return common::INVALID_DATATYPE; +} + +// Convert Arrow C Data Interface struct array to storage::Tablet. +// The timestamp column (format "tsn:") is used as tablet timestamps; +// all other columns become tablet data columns. +// reg_schema: optional registered TableSchema; when provided its column types +// are used in the Tablet (so they match the writer's registered schema +// exactly). Arrow format strings are still used to decode the actual buffers. 
+int ArrowStructToTablet(const char* table_name, const ArrowArray* in_array, + const ArrowSchema* in_schema, + const storage::TableSchema* reg_schema, + storage::Tablet** out_tablet) { + if (!in_array || !in_schema || !out_tablet) return common::E_INVALID_ARG; + if (strcmp(in_schema->format, "+s") != 0) return common::E_INVALID_ARG; + + int64_t n_rows = in_array->length; + int64_t n_cols = in_schema->n_children; + if (n_rows <= 0 || n_cols == 0) return common::E_INVALID_ARG; + + int time_col_idx = -1; + std::vector<std::string> col_names; + // col_types: types for Tablet schema (from reg_schema when available) + std::vector<common::TSDataType> col_types; + // read_modes: how to decode Arrow buffers (from Arrow format string) + std::vector<common::TSDataType> read_modes; + std::vector<int> data_col_indices; + + // Cache reg_schema data types once to avoid repeated calls + std::vector<common::TSDataType> reg_data_types; + if (reg_schema) { + reg_data_types = reg_schema->get_data_types(); + } + + for (int64_t i = 0; i < n_cols; i++) { + const ArrowSchema* child = in_schema->children[i]; + common::TSDataType read_mode = ArrowFormatToDataType(child->format); + if (read_mode == common::INVALID_DATATYPE) + return common::E_TYPE_NOT_SUPPORTED; + if (read_mode == common::TIMESTAMP) { + time_col_idx = static_cast<int>(i); + } else { + std::string col_name = child->name ? child->name : ""; + common::TSDataType col_type = read_mode; + if (reg_schema) { + int reg_idx = const_cast<storage::TableSchema*>(reg_schema) + ->find_column_index(col_name); + if (reg_idx >= 0 && + reg_idx < static_cast<int>(reg_data_types.size())) { + col_type = reg_data_types[reg_idx]; + } + } + col_names.emplace_back(std::move(col_name)); + col_types.push_back(col_type); + read_modes.push_back(read_mode); + data_col_indices.push_back(static_cast<int>(i)); + } + } + + if (col_names.empty()) return common::E_INVALID_ARG; + + std::string tname = table_name ? 
table_name : ""; + auto* tablet = new storage::Tablet(tname, &col_names, &col_types, + static_cast<int>(n_rows)); + if (tablet->err_code_ != common::E_OK) { + int err = tablet->err_code_; + delete tablet; + return err; + } + + // Fill timestamps from the time column + if (time_col_idx >= 0) { + const ArrowArray* ts_arr = in_array->children[time_col_idx]; + const int64_t* ts_buf = static_cast<const int64_t*>(ts_arr->buffers[1]); + int64_t off = ts_arr->offset; + for (int64_t r = 0; r < n_rows; r++) { + if (ArrowIsValid(ts_arr, r)) + tablet->add_timestamp(static_cast<uint32_t>(r), + ts_buf[off + r]); + } + } + + // Fill data columns from Arrow children (use read_modes to decode buffers) + for (size_t ci = 0; ci < data_col_indices.size(); ci++) { + const ArrowArray* col_arr = in_array->children[data_col_indices[ci]]; + common::TSDataType dtype = read_modes[ci]; + uint32_t tcol = static_cast<uint32_t>(ci); + int64_t off = col_arr->offset; + + switch (dtype) { + case common::BOOLEAN: { + // Arrow boolean: bit-packed in buffers[1] + const uint8_t* vals = + static_cast<const uint8_t*>(col_arr->buffers[1]); + for (int64_t r = 0; r < n_rows; r++) { + if (!ArrowIsValid(col_arr, r)) continue; + int64_t bit = off + r; + bool v = (vals[bit / 8] >> (bit % 8)) & 1; + tablet->add_value<bool>(static_cast<uint32_t>(r), tcol, v); + } + break; + } + case common::INT32: { + const int32_t* vals = + static_cast<const int32_t*>(col_arr->buffers[1]); + for (int64_t r = 0; r < n_rows; r++) { + if (ArrowIsValid(col_arr, r)) + tablet->add_value<int32_t>(static_cast<uint32_t>(r), + tcol, vals[off + r]); + } + break; + } + case common::INT64: { + const int64_t* vals = + static_cast<const int64_t*>(col_arr->buffers[1]); + for (int64_t r = 0; r < n_rows; r++) { + if (ArrowIsValid(col_arr, r)) + tablet->add_value<int64_t>(static_cast<uint32_t>(r), + tcol, vals[off + r]); + } + break; + } + case common::FLOAT: { + const float* vals = + static_cast<const float*>(col_arr->buffers[1]); + for (int64_t 
r = 0; r < n_rows; r++) { + if (ArrowIsValid(col_arr, r)) + tablet->add_value<float>(static_cast<uint32_t>(r), tcol, + vals[off + r]); + } + break; + } + case common::DOUBLE: { + const double* vals = + static_cast<const double*>(col_arr->buffers[1]); + for (int64_t r = 0; r < n_rows; r++) { + if (ArrowIsValid(col_arr, r)) + tablet->add_value<double>(static_cast<uint32_t>(r), + tcol, vals[off + r]); + } + break; + } + case common::DATE: { + // Arrow stores date as int32 days-since-epoch; convert to + // YYYYMMDD + const int32_t* vals = + static_cast<const int32_t*>(col_arr->buffers[1]); + for (int64_t r = 0; r < n_rows; r++) { + if (!ArrowIsValid(col_arr, r)) continue; + int32_t yyyymmdd = DaysSinceEpochToYYYYMMDD(vals[off + r]); + tablet->add_value<int32_t>(static_cast<uint32_t>(r), tcol, + yyyymmdd); + } + break; + } + case common::TEXT: + case common::STRING: { + // Arrow UTF-8 string: buffers[1]=int32 offsets, buffers[2]=char + // data + const int32_t* offsets = + static_cast<const int32_t*>(col_arr->buffers[1]); + const char* data = + static_cast<const char*>(col_arr->buffers[2]); + for (int64_t r = 0; r < n_rows; r++) { + if (!ArrowIsValid(col_arr, r)) continue; + int32_t start = offsets[off + r]; + int32_t len = offsets[off + r + 1] - start; + tablet->add_value(static_cast<uint32_t>(r), tcol, + common::String(data + start, len)); + } + break; + } + default: + delete tablet; + return common::E_TYPE_NOT_SUPPORTED; + } + } + + *out_tablet = tablet; + return common::E_OK; +} + } // namespace arrow diff --git a/cpp/src/cwrapper/tsfile_cwrapper.cc b/cpp/src/cwrapper/tsfile_cwrapper.cc index 298f27f0a..779156945 100644 --- a/cpp/src/cwrapper/tsfile_cwrapper.cc +++ b/cpp/src/cwrapper/tsfile_cwrapper.cc @@ -33,11 +33,15 @@ #include "reader/tsfile_reader.h" #include "writer/tsfile_writer.h" -// Forward declaration for arrow namespace function (defined in arrow_c.cc) +// Forward declarations for arrow namespace functions (defined in arrow_c.cc) namespace arrow { int 
TsBlockToArrowStruct(common::TsBlock& tsblock, ArrowArray* out_array, ArrowSchema* out_schema); -} +int ArrowStructToTablet(const char* table_name, const ArrowArray* in_array, + const ArrowSchema* in_schema, + const storage::TableSchema* reg_schema, + storage::Tablet** out_tablet); +} // namespace arrow #ifdef __cplusplus extern "C" { @@ -795,6 +799,21 @@ ERRNO _tsfile_writer_write_table(TsFileWriter writer, Tablet tablet) { return w->write_table(*tbl); } +ERRNO _tsfile_writer_write_arrow_table(TsFileWriter writer, + const char* table_name, + ArrowArray* array, ArrowSchema* schema) { + auto* w = static_cast<storage::TsFileWriter*>(writer); + std::shared_ptr<storage::TableSchema> reg_schema = + w->get_table_schema(table_name ? std::string(table_name) : ""); + storage::Tablet* tablet = nullptr; + int ret = arrow::ArrowStructToTablet(table_name, array, schema, + reg_schema.get(), &tablet); + if (ret != common::E_OK) return ret; + ret = w->write_table(*tablet); + delete tablet; + return ret; +} + ERRNO _tsfile_writer_write_ts_record(TsFileWriter writer, TsRecord data) { auto* w = static_cast<storage::TsFileWriter*>(writer); const storage::TsRecord* record = static_cast<storage::TsRecord*>(data); diff --git a/cpp/src/cwrapper/tsfile_cwrapper.h b/cpp/src/cwrapper/tsfile_cwrapper.h index b04e32c26..8d7b79d52 100644 --- a/cpp/src/cwrapper/tsfile_cwrapper.h +++ b/cpp/src/cwrapper/tsfile_cwrapper.h @@ -718,6 +718,11 @@ ERRNO _tsfile_writer_write_tablet(TsFileWriter writer, Tablet tablet); // Write a tablet into a table. ERRNO _tsfile_writer_write_table(TsFileWriter writer, Tablet tablet); +// Write Arrow C Data Interface batch into a table (Arrow -> Tablet -> write). +ERRNO _tsfile_writer_write_arrow_table(TsFileWriter writer, + const char* table_name, + ArrowArray* array, ArrowSchema* schema); + // Write a row record into a device. 
ERRNO _tsfile_writer_write_ts_record(TsFileWriter writer, TsRecord record); diff --git a/cpp/src/writer/tsfile_writer.cc b/cpp/src/writer/tsfile_writer.cc index 2c2e46b97..1693a6647 100644 --- a/cpp/src/writer/tsfile_writer.cc +++ b/cpp/src/writer/tsfile_writer.cc @@ -336,6 +336,14 @@ int TsFileWriter::do_check_and_prepare_tablet(Tablet& tablet) { return common::E_OK; } +std::shared_ptr<TableSchema> TsFileWriter::get_table_schema( + const std::string& table_name) const { + auto& schema_map = io_writer_->get_schema()->table_schema_map_; + auto it = schema_map.find(table_name); + if (it == schema_map.end()) return nullptr; + return it->second; +} + template <typename MeasurementNamesGetter> int TsFileWriter::do_check_schema( std::shared_ptr<IDeviceID> device_id, diff --git a/cpp/src/writer/tsfile_writer.h b/cpp/src/writer/tsfile_writer.h index e80a1232b..106a41dce 100644 --- a/cpp/src/writer/tsfile_writer.h +++ b/cpp/src/writer/tsfile_writer.h @@ -90,6 +90,8 @@ class TsFileWriter { TableSchemasMapIter; DeviceSchemasMap* get_schema_group_map() { return &schemas_; } + std::shared_ptr<TableSchema> get_table_schema( + const std::string& table_name) const; int64_t calculate_mem_size_for_all_group(); int check_memory_size_and_may_flush_chunks(); /* diff --git a/python/tests/bench_write_arrow_vs_dataframe.py b/python/tests/bench_write_arrow_vs_dataframe.py new file mode 100644 index 000000000..c2f9bedcd --- /dev/null +++ b/python/tests/bench_write_arrow_vs_dataframe.py @@ -0,0 +1,230 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +""" +Benchmark: write_arrow_batch vs write_dataframe. + +Compares write throughput (rows/s) for: + - Arrow path : write_arrow_batch(pa.RecordBatch) + - DataFrame path: write_dataframe(pd.DataFrame) + +Run: + python -m pytest tests/bench_write_arrow_vs_dataframe.py -v -s + python tests/bench_write_arrow_vs_dataframe.py [row_count [batch_size]] +""" + +import os +import sys +import time + +import numpy as np +import pandas as pd +import pyarrow as pa +import pytest + +from tsfile import ( + ColumnCategory, + ColumnSchema, + TableSchema, + TSDataType, + TsFileTableWriter, +) + +# --------------------------------------------------------------------------- +# Config +# --------------------------------------------------------------------------- + +DEFAULT_ROW_COUNT = 100_000 +DEFAULT_BATCH_SIZE = 8_192 +DEFAULT_ROUNDS = 3 + +TABLE_NAME = "bench_table" +BENCH_FILE = "bench_write_arrow.tsfile" + +SCHEMA = TableSchema(TABLE_NAME, [ + ColumnSchema("device", TSDataType.STRING, ColumnCategory.TAG), + ColumnSchema("v_i64", TSDataType.INT64, ColumnCategory.FIELD), + ColumnSchema("v_f64", TSDataType.DOUBLE, ColumnCategory.FIELD), + ColumnSchema("v_bool", TSDataType.BOOLEAN, ColumnCategory.FIELD), + ColumnSchema("v_str", TSDataType.STRING, ColumnCategory.FIELD), +]) + + +# --------------------------------------------------------------------------- +# Data generation +# --------------------------------------------------------------------------- + +def _make_numpy_data(row_count: int): + ts = np.arange(row_count, dtype="int64") + v_i64 = np.arange(row_count, 
dtype="int64") + v_f64 = np.arange(row_count, dtype="float64") * 1.5 + v_bool = (np.arange(row_count) % 2 == 0) + v_str = [f"s{i}" for i in range(row_count)] + device = ["device0"] * row_count + return ts, device, v_i64, v_f64, v_bool, v_str + + +def _make_arrow_batches(row_count: int, batch_size: int): + ts, device, v_i64, v_f64, v_bool, v_str = _make_numpy_data(row_count) + batches = [] + for start in range(0, row_count, batch_size): + end = min(start + batch_size, row_count) + batches.append(pa.record_batch({ + "time": pa.array(ts[start:end], type=pa.timestamp("ns")), + "device": pa.array(device[start:end], type=pa.string()), + "v_i64": pa.array(v_i64[start:end], type=pa.int64()), + "v_f64": pa.array(v_f64[start:end], type=pa.float64()), + "v_bool": pa.array(v_bool[start:end], type=pa.bool_()), + "v_str": pa.array(v_str[start:end], type=pa.string()), + })) + return batches + + +def _make_dataframe_chunks(row_count: int, batch_size: int): + ts, device, v_i64, v_f64, v_bool, v_str = _make_numpy_data(row_count) + chunks = [] + for start in range(0, row_count, batch_size): + end = min(start + batch_size, row_count) + chunks.append(pd.DataFrame({ + "time": pd.Series(ts[start:end], dtype="int64"), + "device": device[start:end], + "v_i64": pd.Series(v_i64[start:end], dtype="int64"), + "v_f64": pd.Series(v_f64[start:end], dtype="float64"), + "v_bool": pd.Series(v_bool[start:end], dtype="bool"), + "v_str": v_str[start:end], + })) + return chunks + + +# --------------------------------------------------------------------------- +# Benchmark runners +# --------------------------------------------------------------------------- + +def _write_arrow(file_path: str, batches): + schema = TableSchema(TABLE_NAME, [ + ColumnSchema("device", TSDataType.STRING, ColumnCategory.TAG), + ColumnSchema("v_i64", TSDataType.INT64, ColumnCategory.FIELD), + ColumnSchema("v_f64", TSDataType.DOUBLE, ColumnCategory.FIELD), + ColumnSchema("v_bool", TSDataType.BOOLEAN, ColumnCategory.FIELD), + 
ColumnSchema("v_str", TSDataType.STRING, ColumnCategory.FIELD), + ]) + with TsFileTableWriter(file_path, schema) as w: + for batch in batches: + w.write_arrow_batch(batch) + + +def _write_dataframe(file_path: str, chunks): + schema = TableSchema(TABLE_NAME, [ + ColumnSchema("device", TSDataType.STRING, ColumnCategory.TAG), + ColumnSchema("v_i64", TSDataType.INT64, ColumnCategory.FIELD), + ColumnSchema("v_f64", TSDataType.DOUBLE, ColumnCategory.FIELD), + ColumnSchema("v_bool", TSDataType.BOOLEAN, ColumnCategory.FIELD), + ColumnSchema("v_str", TSDataType.STRING, ColumnCategory.FIELD), + ]) + with TsFileTableWriter(file_path, schema) as w: + for chunk in chunks: + w.write_dataframe(chunk) + + +def _run_timed(label: str, func, *args, rounds: int = DEFAULT_ROUNDS, row_count: int = 0): + times = [] + for _ in range(rounds): + if os.path.exists(BENCH_FILE): + os.remove(BENCH_FILE) + t0 = time.perf_counter() + func(BENCH_FILE, *args) + times.append(time.perf_counter() - t0) + avg = sum(times) / len(times) + best = min(times) + rps = row_count / avg if avg > 0 else 0 + print(f" {label:42s} avg={avg:.3f}s best={best:.3f}s {rps:>10.0f} rows/s") + return avg + + +# --------------------------------------------------------------------------- +# Main benchmark +# --------------------------------------------------------------------------- + +def run_benchmark( + row_count: int = DEFAULT_ROW_COUNT, + batch_size: int = DEFAULT_BATCH_SIZE, + rounds: int = DEFAULT_ROUNDS, +): + print() + print(f"=== write benchmark: {row_count:,} rows, batch_size={batch_size}, rounds={rounds} ===") + + # Pre-build data once (exclude data-preparation time from timing) + arrow_batches = _make_arrow_batches(row_count, batch_size) + df_chunks = _make_dataframe_chunks(row_count, batch_size) + + df_avg = _run_timed( + "write_dataframe", + _write_dataframe, df_chunks, + rounds=rounds, row_count=row_count, + ) + arrow_avg = _run_timed( + "write_arrow_batch", + _write_arrow, arrow_batches, + rounds=rounds, 
row_count=row_count, + ) + + print() + if arrow_avg > 0 and df_avg > 0: + ratio = df_avg / arrow_avg + if ratio >= 1.0: + print(f" Arrow is {ratio:.2f}x faster than DataFrame") + else: + print(f" DataFrame is {1/ratio:.2f}x faster than Arrow") + print() + + if os.path.exists(BENCH_FILE): + os.remove(BENCH_FILE) + + return df_avg, arrow_avg + + +# --------------------------------------------------------------------------- +# Pytest entry points +# --------------------------------------------------------------------------- + +def test_bench_write_arrow_small(): + """Quick sanity check with small data (5 k rows).""" + run_benchmark(row_count=5_000, batch_size=1_024, rounds=2) + + +def test_bench_write_arrow_default(): + """Default benchmark (100 k rows).""" + run_benchmark( + row_count=DEFAULT_ROW_COUNT, + batch_size=DEFAULT_BATCH_SIZE, + rounds=DEFAULT_ROUNDS, + ) + + +def test_bench_write_arrow_large(): + """Large benchmark (1 M rows).""" + run_benchmark(row_count=10_000_000, batch_size=32_384, rounds=3) + + +# --------------------------------------------------------------------------- +# Script entry point +# --------------------------------------------------------------------------- + +if __name__ == "__main__": + row_count = int(sys.argv[1]) if len(sys.argv) > 1 else DEFAULT_ROW_COUNT + batch_size = int(sys.argv[2]) if len(sys.argv) > 2 else DEFAULT_BATCH_SIZE + run_benchmark(row_count=row_count, batch_size=batch_size) diff --git a/python/tests/test_write_arrow.py b/python/tests/test_write_arrow.py new file mode 100644 index 000000000..19c8abc2e --- /dev/null +++ b/python/tests/test_write_arrow.py @@ -0,0 +1,368 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. 
The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +""" +Tests for write_arrow_batch: write PyArrow RecordBatch/Table to tsfile +and verify correctness by reading back. +""" + +import os +from datetime import date + +import numpy as np +import pandas as pd +import pytest + +pa = pytest.importorskip("pyarrow", reason="pyarrow is not installed") + +from tsfile import ColumnCategory, ColumnSchema, TableSchema, TSDataType, TsFileReader +from tsfile import TsFileTableWriter + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + +def _make_schema(table_name, extra_cols): + """Build a TableSchema with a string TAG 'device' plus the given field cols.""" + return TableSchema( + table_name, + [ColumnSchema("device", TSDataType.STRING, ColumnCategory.TAG)] + extra_cols, + ) + + +def _read_all_arrow(file_path, table_name, columns, start=0, end=10**18, batch_size=4096): + """Read all rows from file via read_arrow_batch and return as a pa.Table.""" + reader = TsFileReader(file_path) + rs = reader.query_table_batch( + table_name=table_name, + column_names=columns, + start_time=start, + end_time=end, + batch_size=batch_size, + ) + batches = [] + while True: + batch = rs.read_arrow_batch() + if batch is None: + break + batches.append(batch) + rs.close() + reader.close() + if not batches: + return pa.table({}) + return pa.concat_tables(batches) 
+ + +# --------------------------------------------------------------------------- +# Basic write + read-back +# --------------------------------------------------------------------------- + +def test_write_arrow_basic(): + """Write 1 000 rows via write_arrow_batch and verify count + values.""" + path = "test_write_arrow_basic.tsfile" + table_name = "t" + n = 1000 + + schema = _make_schema(table_name, [ + ColumnSchema("value1", TSDataType.INT64, ColumnCategory.FIELD), + ColumnSchema("value2", TSDataType.DOUBLE, ColumnCategory.FIELD), + ]) + + batch = pa.record_batch({ + "time": pa.array(np.arange(n, dtype="int64"), type=pa.timestamp("ns")), + "device": pa.array([f"d{i}" for i in range(n)], type=pa.string()), + "value1": pa.array(np.arange(n, dtype="int64"), type=pa.int64()), + "value2": pa.array(np.arange(n, dtype="float64") * 1.5, type=pa.float64()), + }) + + try: + if os.path.exists(path): + os.remove(path) + with TsFileTableWriter(path, schema) as w: + w.write_arrow_batch(batch) + + result = _read_all_arrow(path, table_name, ["device", "value1", "value2"]) + assert len(result) == n + + df = result.to_pandas().sort_values("time").reset_index(drop=True) + assert list(df["value1"]) == list(range(n)) + assert all(abs(df["value2"].iloc[i] - i * 1.5) < 1e-9 for i in range(n)) + finally: + if os.path.exists(path): + os.remove(path) + + +# --------------------------------------------------------------------------- +# pa.Table input +# --------------------------------------------------------------------------- + +def test_write_arrow_from_table(): + """write_arrow_batch should accept pa.Table (multi-chunk) as well.""" + path = "test_write_arrow_from_table.tsfile" + table_name = "t" + n = 500 + + schema = _make_schema(table_name, [ + ColumnSchema("v", TSDataType.INT32, ColumnCategory.FIELD), + ]) + + tbl = pa.table({ + "time": pa.array(np.arange(n, dtype="int64"), type=pa.timestamp("ns")), + "device": pa.array(["dev"] * n, type=pa.string()), + "v": pa.array(np.arange(n, 
dtype="int32"), type=pa.int32()), + }) + + try: + if os.path.exists(path): + os.remove(path) + with TsFileTableWriter(path, schema) as w: + w.write_arrow_batch(tbl) + + result = _read_all_arrow(path, table_name, ["device", "v"]) + assert len(result) == n + df = result.to_pandas().sort_values("time").reset_index(drop=True) + assert list(df["v"]) == list(range(n)) + finally: + if os.path.exists(path): + os.remove(path) + + +# --------------------------------------------------------------------------- +# Multiple batches +# --------------------------------------------------------------------------- + +def test_write_arrow_multiple_batches(): + """Write several batches sequentially and verify the total row count.""" + path = "test_write_arrow_multi.tsfile" + table_name = "t" + rows_per_batch = 300 + num_batches = 4 + total = rows_per_batch * num_batches + + schema = _make_schema(table_name, [ + ColumnSchema("v", TSDataType.INT64, ColumnCategory.FIELD), + ]) + + try: + if os.path.exists(path): + os.remove(path) + with TsFileTableWriter(path, schema) as w: + for b in range(num_batches): + start_ts = b * rows_per_batch + batch = pa.record_batch({ + "time": pa.array( + np.arange(start_ts, start_ts + rows_per_batch, dtype="int64"), + type=pa.timestamp("ns")), + "device": pa.array(["dev"] * rows_per_batch, type=pa.string()), + "v": pa.array( + np.arange(start_ts, start_ts + rows_per_batch, dtype="int64"), + type=pa.int64()), + }) + w.write_arrow_batch(batch) + + result = _read_all_arrow(path, table_name, ["device", "v"]) + assert len(result) == total + finally: + if os.path.exists(path): + os.remove(path) + + +# --------------------------------------------------------------------------- +# All supported data types +# --------------------------------------------------------------------------- + +def test_write_arrow_all_datatypes(): + """Write every supported data type and verify values read back correctly.""" + path = "test_write_arrow_all_types.tsfile" + table_name = "t" + 
n = 200 + + schema = TableSchema(table_name, [ + ColumnSchema("tag", TSDataType.STRING, ColumnCategory.TAG), + ColumnSchema("bool_col", TSDataType.BOOLEAN, ColumnCategory.FIELD), + ColumnSchema("int32_col", TSDataType.INT32, ColumnCategory.FIELD), + ColumnSchema("int64_col", TSDataType.INT64, ColumnCategory.FIELD), + ColumnSchema("float_col", TSDataType.FLOAT, ColumnCategory.FIELD), + ColumnSchema("double_col", TSDataType.DOUBLE, ColumnCategory.FIELD), + ColumnSchema("str_col", TSDataType.STRING, ColumnCategory.FIELD), + ColumnSchema("date_col", TSDataType.DATE, ColumnCategory.FIELD), + ]) + + dates_days = [ + (date(2025, 1, (i % 28) + 1) - date(1970, 1, 1)).days for i in range(n) + ] + + batch = pa.record_batch({ + "time": pa.array(np.arange(n, dtype="int64"), type=pa.timestamp("ns")), + "tag": pa.array([f"dev{i}" for i in range(n)], type=pa.string()), + "bool_col": pa.array([i % 2 == 0 for i in range(n)], type=pa.bool_()), + "int32_col": pa.array(np.arange(n, dtype="int32"), type=pa.int32()), + "int64_col": pa.array(np.arange(n, dtype="int64") * 10, type=pa.int64()), + "float_col": pa.array(np.arange(n, dtype="float32") * 0.5, type=pa.float32()), + "double_col": pa.array(np.arange(n, dtype="float64") * 1.1, type=pa.float64()), + "str_col": pa.array([f"s{i}" for i in range(n)], type=pa.string()), + "date_col": pa.array(dates_days, type=pa.date32()), + }) + + try: + if os.path.exists(path): + os.remove(path) + with TsFileTableWriter(path, schema) as w: + w.write_arrow_batch(batch) + + result = _read_all_arrow( + path, table_name, + ["tag", "bool_col", "int32_col", "int64_col", + "float_col", "double_col", "str_col", "date_col"], + ) + assert len(result) == n + df = result.to_pandas().sort_values("time").reset_index(drop=True) + + for col in ["tag", "bool_col", "int32_col", "int64_col", + "float_col", "double_col", "str_col", "date_col"]: + assert col in df.columns, f"Column '{col}' missing from result" + + assert list(df["int32_col"]) == list(range(n)) + assert 
list(df["int64_col"]) == [i * 10 for i in range(n)]
+            for i in range(n):
+                assert df["bool_col"].iloc[i] == (i % 2 == 0)
+                assert abs(df["double_col"].iloc[i] - i * 1.1) < 1e-9
+                assert df["str_col"].iloc[i] == f"s{i}"
+    finally:
+        if os.path.exists(path):
+            os.remove(path)
+
+
+# ---------------------------------------------------------------------------
+# Parity with write_dataframe
+# ---------------------------------------------------------------------------
+
+def test_write_arrow_parity_with_dataframe():
+    """Data written via write_arrow_batch must match data written via write_dataframe."""
+    arrow_path = "test_write_arrow_parity_arrow.tsfile"
+    df_path = "test_write_arrow_parity_df.tsfile"
+    table_name = "t"
+    n = 500
+
+    schema_arrow = TableSchema(table_name, [
+        ColumnSchema("device", TSDataType.STRING, ColumnCategory.TAG),
+        ColumnSchema("v_i32", TSDataType.INT32, ColumnCategory.FIELD),
+        ColumnSchema("v_f64", TSDataType.DOUBLE, ColumnCategory.FIELD),
+        ColumnSchema("v_bool", TSDataType.BOOLEAN, ColumnCategory.FIELD),
+        ColumnSchema("v_str", TSDataType.STRING, ColumnCategory.FIELD),
+    ])
+    schema_df = TableSchema(table_name, [
+        ColumnSchema("device", TSDataType.STRING, ColumnCategory.TAG),
+        ColumnSchema("v_i32", TSDataType.INT32, ColumnCategory.FIELD),
+        ColumnSchema("v_f64", TSDataType.DOUBLE, ColumnCategory.FIELD),
+        ColumnSchema("v_bool", TSDataType.BOOLEAN, ColumnCategory.FIELD),
+        ColumnSchema("v_str", TSDataType.STRING, ColumnCategory.FIELD),
+    ])
+
+    timestamps = np.arange(n, dtype="int64")
+    v_i32 = np.arange(n, dtype="int32")
+    v_f64 = np.arange(n, dtype="float64") * 2.5
+    v_bool = np.array([i % 3 == 0 for i in range(n)])
+    v_str = [f"row{i}" for i in range(n)]
+    device = ["dev"] * n
+
+    batch = pa.record_batch({
+        "time": pa.array(timestamps, type=pa.timestamp("ns")),
+        "device": pa.array(device, type=pa.string()),
+        "v_i32": pa.array(v_i32, type=pa.int32()),
+        "v_f64": pa.array(v_f64, type=pa.float64()),
+        "v_bool": pa.array(v_bool, type=pa.bool_()),
+        "v_str": pa.array(v_str, type=pa.string()),
+    })
+
+    dataframe = pd.DataFrame({
+        "time": pd.Series(timestamps, dtype="int64"),
+        "device": device,
+        "v_i32": pd.Series(v_i32, dtype="int32"),
+        "v_f64": pd.Series(v_f64, dtype="float64"),
+        "v_bool": pd.Series(v_bool, dtype="bool"),
+        "v_str": v_str,
+    })
+
+    cols = ["device", "v_i32", "v_f64", "v_bool", "v_str"]
+
+    try:
+        for p in (arrow_path, df_path):
+            if os.path.exists(p):
+                os.remove(p)
+
+        with TsFileTableWriter(arrow_path, schema_arrow) as w:
+            w.write_arrow_batch(batch)
+        with TsFileTableWriter(df_path, schema_df) as w:
+            w.write_dataframe(dataframe)
+
+        result_arrow = _read_all_arrow(arrow_path, table_name, cols).to_pandas()
+        result_df = _read_all_arrow(df_path, table_name, cols).to_pandas()
+
+        result_arrow = result_arrow.sort_values("time").reset_index(drop=True)
+        result_df = result_df.sort_values("time").reset_index(drop=True)
+
+        assert len(result_arrow) == len(result_df) == n
+
+        assert list(result_arrow["v_i32"]) == list(result_df["v_i32"])
+        assert list(result_arrow["v_str"]) == list(result_df["v_str"])
+        assert list(result_arrow["v_bool"]) == list(result_df["v_bool"])
+        for i in range(n):
+            assert abs(result_arrow["v_f64"].iloc[i] - result_df["v_f64"].iloc[i]) < 1e-9
+    finally:
+        for p in (arrow_path, df_path):
+            if os.path.exists(p):
+                os.remove(p)
+
+
+# ---------------------------------------------------------------------------
+# Large batch
+# ---------------------------------------------------------------------------
+
+def test_write_arrow_large_batch():
+    """Write a single large batch (100 k rows) and verify row count."""
+    path = "test_write_arrow_large.tsfile"
+    table_name = "t"
+    n = 100_000
+
+    schema = _make_schema(table_name, [
+        ColumnSchema("v", TSDataType.DOUBLE, ColumnCategory.FIELD),
+    ])
+
+    batch = pa.record_batch({
+        "time": pa.array(np.arange(n, dtype="int64"), type=pa.timestamp("ns")),
+        "device": pa.array(["d"] * n, type=pa.string()),
+        "v": pa.array(np.random.rand(n), type=pa.float64()),
+    })
+
+    try:
+        if os.path.exists(path):
+            os.remove(path)
+        with TsFileTableWriter(path, schema) as w:
+            w.write_arrow_batch(batch)
+
+        result = _read_all_arrow(path, table_name, ["device", "v"], batch_size=8192)
+        assert len(result) == n
+    finally:
+        if os.path.exists(path):
+            os.remove(path)
+
+
+if __name__ == "__main__":
+    os.chdir(os.path.dirname(os.path.abspath(__file__)))
+    pytest.main([__file__, "-v", "-s"])
diff --git a/python/tsfile/tsfile_cpp.pxd b/python/tsfile/tsfile_cpp.pxd
index e82092ed6..021e14e23 100644
--- a/python/tsfile/tsfile_cpp.pxd
+++ b/python/tsfile/tsfile_cpp.pxd
@@ -248,6 +248,12 @@ cdef extern from "./tsfile_cwrapper.h":
                                  ArrowArray* out_array,
                                  ArrowSchema* out_schema);
 
+    # Arrow batch writing function
+    ErrorCode _tsfile_writer_write_arrow_table(TsFileWriter writer,
+                                               const char* table_name,
+                                               ArrowArray* array,
+                                               ArrowSchema* schema);
+
 
 cdef extern from "./common/config/config.h" namespace "common":
diff --git a/python/tsfile/tsfile_table_writer.py b/python/tsfile/tsfile_table_writer.py
index a8f7805d3..1db3d94db 100644
--- a/python/tsfile/tsfile_table_writer.py
+++ b/python/tsfile/tsfile_table_writer.py
@@ -182,6 +182,17 @@ class TsFileTableWriter:
         self.writer.write_dataframe(self.tableSchema.get_table_name(), dataframe,
                                     self.tableSchema)
 
+    def write_arrow_batch(self, data):
+        """
+        Write a PyArrow RecordBatch or Table into tsfile using Arrow C Data
+        Interface for efficient batch writing without Python-level row loops.
+        :param data: pyarrow.RecordBatch or pyarrow.Table. Must include a
+            timestamp-typed column (pa.timestamp) which is used as the row
+            timestamps. All other columns must match the registered schema.
+        :return: no return value.
+        """
+        self.writer.write_arrow_batch(self.tableSchema.get_table_name(), data)
+
     def close(self):
         """
         Close TsFileTableWriter and will flush data automatically.
diff --git a/python/tsfile/tsfile_writer.pyx b/python/tsfile/tsfile_writer.pyx
index 4826ef72d..e1d494a19 100644
--- a/python/tsfile/tsfile_writer.pyx
+++ b/python/tsfile/tsfile_writer.pyx
@@ -16,11 +16,14 @@
 # under the License.
 #
 import pandas
+import pyarrow as pa
 from tsfile.row_record import RowRecord
 from tsfile.schema import TableSchema as TableSchemaPy
 from tsfile.schema import TimeseriesSchema as TimeseriesSchemaPy, DeviceSchema as DeviceSchemaPy
 from tsfile.tablet import Tablet as TabletPy
+from libc.string cimport memset
+from libc.stdint cimport uintptr_t
 from .tsfile_cpp cimport *
 from .tsfile_py_cpp cimport *
 
@@ -122,6 +125,40 @@
         finally:
             free_c_tablet(ctablet)
 
+    def write_arrow_batch(self, table_name: str, data):
+        """
+        Write an Arrow RecordBatch or Table into tsfile using Arrow C Data
+        Interface for efficient batch writing without Python-level row loops.
+        table_name: target table name (must be registered)
+        data: pyarrow.RecordBatch or pyarrow.Table
+        """
+        if isinstance(data, pa.Table):
+            data = data.combine_chunks().to_batches()
+            if not data:
+                return
+            data = data[0]
+
+        cdef ArrowArray arrow_array
+        cdef ArrowSchema arrow_schema
+        cdef ErrorCode errno
+        memset(&arrow_array, 0, sizeof(ArrowArray))
+        memset(&arrow_schema, 0, sizeof(ArrowSchema))
+
+        cdef uintptr_t array_ptr = <uintptr_t>&arrow_array
+        cdef uintptr_t schema_ptr = <uintptr_t>&arrow_schema
+        data._export_to_c(array_ptr, schema_ptr)
+
+        cdef bytes tname = table_name.lower().encode('utf-8')
+        try:
+            errno = _tsfile_writer_write_arrow_table(
+                self.writer, tname, &arrow_array, &arrow_schema)
+            check_error(errno)
+        finally:
+            if arrow_array.release != NULL:
+                arrow_array.release(&arrow_array)
+            if arrow_schema.release != NULL:
+                arrow_schema.release(&arrow_schema)
+
     cpdef close(self):
         """
         Flush data and Close tsfile writer.
