This is an automated email from the ASF dual-hosted git repository.

kou pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-flight-sql-postgresql.git


The following commit(s) were added to refs/heads/main by this push:
     new 4e30c2a  Use bulk append (#150)
4e30c2a is described below

commit 4e30c2ab39f9468e3ab742a8a69ad909e9e85539
Author: Sutou Kouhei <[email protected]>
AuthorDate: Sat Oct 28 17:10:42 2023 +0900

    Use bulk append (#150)
    
    Closes GH-149
---
 src/afs.cc | 445 ++++++++++++++++++++++++++++++++++++++++++++++---------------
 1 file changed, 340 insertions(+), 105 deletions(-)

diff --git a/src/afs.cc b/src/afs.cc
index 44f8d43..aa2df7a 100644
--- a/src/afs.cc
+++ b/src/afs.cc
@@ -970,79 +970,298 @@ class ArrowPGValueConverter : public arrow::ArrayVisitor 
{
        Datum& datum_;
 };
 
-class PGArrowValueConverter : public arrow::ArrayVisitor {
+class ArrowArrayBuilderBase {
    public:
-       explicit PGArrowValueConverter(Form_pg_attribute attribute) : 
attribute_(attribute) {}
-
-       arrow::Result<std::shared_ptr<arrow::DataType>> convert_type() const
-       {
-               switch (attribute_->atttypid)
-               {
-                       case INT2OID:
-                               return arrow::int16();
-                       case INT4OID:
-                               return arrow::int32();
-                       case INT8OID:
-                               return arrow::int64();
-                       case FLOAT4OID:
-                               return arrow::float32();
-                       case FLOAT8OID:
-                               return arrow::float64();
-                       case VARCHAROID:
-                       case TEXTOID:
-                               return arrow::utf8();
-                       case BYTEAOID:
-                               return arrow::binary();
-                       case TIMESTAMPOID:
-                               return arrow::timestamp(arrow::TimeUnit::MICRO);
-                       default:
-                               return 
arrow::Status::NotImplemented("Unsupported PostgreSQL type: ",
-                                                                    
attribute_->atttypid);
-               }
-       }
-
-       arrow::Status convert_value(arrow::ArrayBuilder* builder, Datum datum) 
const
-       {
-               switch (attribute_->atttypid)
-               {
-                       case INT2OID:
-                               return 
static_cast<arrow::Int16Builder*>(builder)->Append(
-                                       DatumGetInt16(datum));
-                       case INT4OID:
-                               return 
static_cast<arrow::Int32Builder*>(builder)->Append(
-                                       DatumGetInt32(datum));
-                       case INT8OID:
-                               return 
static_cast<arrow::Int64Builder*>(builder)->Append(
-                                       DatumGetInt64(datum));
-                       case FLOAT4OID:
-                               return 
static_cast<arrow::FloatBuilder*>(builder)->Append(
-                                       DatumGetFloat4(datum));
-                       case FLOAT8OID:
-                               return 
static_cast<arrow::DoubleBuilder*>(builder)->Append(
-                                       DatumGetFloat8(datum));
-                       case VARCHAROID:
-                       case TEXTOID:
-                               return 
static_cast<arrow::StringBuilder*>(builder)->Append(
-                                       VARDATA_ANY(datum), 
VARSIZE_ANY_EXHDR(datum));
-                       case BYTEAOID:
-                               return 
static_cast<arrow::BinaryBuilder*>(builder)->Append(
-                                       VARDATA_ANY(datum), 
VARSIZE_ANY_EXHDR(datum));
-                       case TIMESTAMPOID:
-                               // Arrow uses UNIX epoch (1970-01-01) but 
PostgreSQL
-                               // uses 2000-01-01.
-                               return 
static_cast<arrow::TimestampBuilder*>(builder)->Append(
-                                       DatumGetTimestamp(datum) +
-                                       (POSTGRES_EPOCH_JDATE - 
UNIX_EPOCH_JDATE) * USECS_PER_DAY);
-                       default:
-                               return 
arrow::Status::NotImplemented("Unsupported PostgreSQL type: ",
-                                                                    
attribute_->atttypid);
+       static arrow::Result<std::unique_ptr<ArrowArrayBuilderBase>> make(
+               Form_pg_attribute attribute, int iAttribute, 
arrow::ArrayBuilder* builder);
+
+       static arrow::Result<std::shared_ptr<arrow::DataType>> arrow_type(
+               Form_pg_attribute attribute);
+
+       ArrowArrayBuilderBase(Form_pg_attribute attribute, int iAttribute)
+               : attribute_(attribute), iAttribute_(iAttribute)
+       {
+       }
+
+       virtual ~ArrowArrayBuilderBase() = default;
+
+       arrow::Status build(uint64_t iTuple, uint64_t iTupleEnd)
+       {
+               if (attribute_->attnotnull)
+               {
+                       return build_not_null(iTuple, iTupleEnd);
+               }
+               else
+               {
+                       return build_may_null(iTuple, iTupleEnd);
                }
        }
 
-   private:
+   protected:
        Form_pg_attribute attribute_;
+       int iAttribute_;
+
+       virtual arrow::Status build_not_null(uint64_t iTuple, uint64_t 
iTupleEnd) = 0;
+       virtual arrow::Status build_may_null(uint64_t iTuple, uint64_t 
iTupleEnd) = 0;
+};
+
+template <typename ArrowType, typename Enable = void>
+class ArrowArrayBuilder;
+
+template <typename ArrowType>
+class ArrowArrayBuilder<ArrowType, arrow::enable_if_has_c_type<ArrowType>>
+       : public ArrowArrayBuilderBase {
+   public:
+       ArrowArrayBuilder(Form_pg_attribute attribute,
+                         int iAttribute,
+                         arrow::ArrayBuilder* builder)
+               : ArrowArrayBuilderBase(attribute, iAttribute),
+                 builder_(static_cast<BuilderType*>(builder)),
+                 values_(),
+                 validBytes_()
+       {
+       }
+
+   private:
+       using CType = typename arrow::TypeTraits<ArrowType>::CType;
+       using BuilderType = typename arrow::TypeTraits<ArrowType>::BuilderType;
+
+       BuilderType* builder_;
+       std::vector<CType> values_;
+       std::vector<uint8_t> validBytes_;
+
+       arrow::Status build_not_null(uint64_t iTuple, uint64_t iTupleEnd)
+       {
+               values_.resize(iTupleEnd - iTuple);
+               for (; iTuple < iTupleEnd; iTuple++)
+               {
+                       bool isNull;
+                       auto datum = SPI_getbinval(SPI_tuptable->vals[iTuple],
+                                                  SPI_tuptable->tupdesc,
+                                                  iAttribute_ + 1,
+                                                  &isNull);
+                       values_[iTuple] = convert_value<ArrowType>(datum);
+               }
+               return builder_->AppendValues(values_.data(), values_.size());
+       }
+
+       arrow::Status build_may_null(uint64_t iTuple, uint64_t iTupleEnd)
+       {
+               bool haveNull = false;
+               values_.resize(iTupleEnd - iTuple);
+               validBytes_.resize(iTupleEnd - iTuple);
+               for (; iTuple < iTupleEnd; iTuple++)
+               {
+                       bool isNull;
+                       auto datum = SPI_getbinval(SPI_tuptable->vals[iTuple],
+                                                  SPI_tuptable->tupdesc,
+                                                  iAttribute_ + 1,
+                                                  &isNull);
+                       if (isNull)
+                       {
+                               haveNull = true;
+                               validBytes_[iTuple] = 0;
+                       }
+                       else
+                       {
+                               validBytes_[iTuple] = 1;
+                               values_[iTuple] = 
convert_value<ArrowType>(datum);
+                       }
+               }
+               if (haveNull)
+               {
+                       return builder_->AppendValues(
+                               values_.data(), values_.size(), 
validBytes_.data());
+               }
+               else
+               {
+                       return builder_->AppendValues(values_.data(), 
values_.size());
+               }
+       }
+
+       template <typename TargetArrowType>
+       std::enable_if_t<std::is_same_v<TargetArrowType, arrow::Int16Type>, 
int16_t>
+       convert_value(Datum datum)
+       {
+               return DatumGetInt16(datum);
+       }
+
+       template <typename TargetArrowType>
+       std::enable_if_t<std::is_same_v<TargetArrowType, arrow::Int32Type>, 
int32_t>
+       convert_value(Datum datum)
+       {
+               return DatumGetInt32(datum);
+       }
+
+       template <typename TargetArrowType>
+       std::enable_if_t<std::is_same_v<TargetArrowType, arrow::Int64Type>, 
int64_t>
+       convert_value(Datum datum)
+       {
+               return DatumGetInt64(datum);
+       }
+
+       template <typename TargetArrowType>
+       std::enable_if_t<std::is_same_v<TargetArrowType, arrow::FloatType>, 
float>
+       convert_value(Datum datum)
+       {
+               return DatumGetFloat4(datum);
+       }
+
+       template <typename TargetArrowType>
+       std::enable_if_t<std::is_same_v<TargetArrowType, arrow::DoubleType>, 
double>
+       convert_value(Datum datum)
+       {
+               return DatumGetFloat8(datum);
+       }
+
+       template <typename TargetArrowType>
+       std::enable_if_t<std::is_same_v<TargetArrowType, arrow::TimestampType>, 
int64_t>
+       convert_value(Datum datum)
+       {
+               // Arrow uses UNIX epoch (1970-01-01) but PostgreSQL
+               // uses 2000-01-01.
+               return DatumGetTimestamp(datum) +
+                      ((POSTGRES_EPOCH_JDATE - UNIX_EPOCH_JDATE) * 
USECS_PER_DAY);
+       }
 };
 
+template <typename ArrowType>
+class ArrowArrayBuilder<ArrowType, arrow::enable_if_base_binary<ArrowType>>
+       : public ArrowArrayBuilderBase {
+   public:
+       ArrowArrayBuilder(Form_pg_attribute attribute,
+                         int iAttribute,
+                         arrow::ArrayBuilder* builder)
+               : ArrowArrayBuilderBase(attribute, iAttribute),
+                 builder_(static_cast<BuilderType*>(builder)),
+                 values_(),
+                 validBytes_()
+       {
+       }
+
+   private:
+       using BuilderType = typename arrow::TypeTraits<ArrowType>::BuilderType;
+
+       BuilderType* builder_;
+       std::vector<std::string> values_;
+       std::vector<uint8_t> validBytes_;
+
+       arrow::Status build_not_null(uint64_t iTuple, uint64_t iTupleEnd)
+       {
+               values_.resize(iTupleEnd - iTuple);
+               for (; iTuple < iTupleEnd; iTuple++)
+               {
+                       bool isNull;
+                       auto datum = SPI_getbinval(SPI_tuptable->vals[iTuple],
+                                                  SPI_tuptable->tupdesc,
+                                                  iAttribute_ + 1,
+                                                  &isNull);
+                       values_[iTuple] = std::string(VARDATA_ANY(datum), 
VARSIZE_ANY_EXHDR(datum));
+               }
+               return builder_->AppendValues(values_);
+       }
+
+       arrow::Status build_may_null(uint64_t iTuple, uint64_t iTupleEnd)
+       {
+               bool haveNull = false;
+               values_.resize(iTupleEnd - iTuple);
+               validBytes_.resize(iTupleEnd - iTuple);
+               for (; iTuple < iTupleEnd; iTuple++)
+               {
+                       bool isNull;
+                       auto datum = SPI_getbinval(SPI_tuptable->vals[iTuple],
+                                                  SPI_tuptable->tupdesc,
+                                                  iAttribute_ + 1,
+                                                  &isNull);
+                       if (isNull)
+                       {
+                               haveNull = true;
+                               validBytes_[iTuple] = 0;
+                       }
+                       else
+                       {
+                               validBytes_[iTuple] = 1;
+                               values_[iTuple] =
+                                       std::string(VARDATA_ANY(datum), 
VARSIZE_ANY_EXHDR(datum));
+                       }
+               }
+               if (haveNull)
+               {
+                       return builder_->AppendValues(values_, 
validBytes_.data());
+               }
+               else
+               {
+                       return builder_->AppendValues(values_);
+               }
+       }
+};
+
+arrow::Result<std::unique_ptr<ArrowArrayBuilderBase>>
+ArrowArrayBuilderBase::make(Form_pg_attribute attribute,
+                            int iAttribute,
+                            arrow::ArrayBuilder* builder)
+{
+       switch (attribute->atttypid)
+       {
+               case INT2OID:
+                       return 
std::make_unique<ArrowArrayBuilder<arrow::Int16Type>>(
+                               attribute, iAttribute, builder);
+               case INT4OID:
+                       return 
std::make_unique<ArrowArrayBuilder<arrow::Int32Type>>(
+                               attribute, iAttribute, builder);
+               case INT8OID:
+                       return 
std::make_unique<ArrowArrayBuilder<arrow::Int64Type>>(
+                               attribute, iAttribute, builder);
+               case FLOAT4OID:
+                       return 
std::make_unique<ArrowArrayBuilder<arrow::FloatType>>(
+                               attribute, iAttribute, builder);
+               case FLOAT8OID:
+                       return 
std::make_unique<ArrowArrayBuilder<arrow::DoubleType>>(
+                               attribute, iAttribute, builder);
+               case VARCHAROID:
+               case TEXTOID:
+                       return 
std::make_unique<ArrowArrayBuilder<arrow::StringType>>(
+                               attribute, iAttribute, builder);
+               case BYTEAOID:
+                       return 
std::make_unique<ArrowArrayBuilder<arrow::BinaryType>>(
+                               attribute, iAttribute, builder);
+               case TIMESTAMPOID:
+                       return 
std::make_unique<ArrowArrayBuilder<arrow::TimestampType>>(
+                               attribute, iAttribute, builder);
+               default:
+                       return arrow::Status::NotImplemented("Unsupported 
PostgreSQL type: ",
+                                                            
attribute->atttypid);
+       }
+}
+
+arrow::Result<std::shared_ptr<arrow::DataType>>
+ArrowArrayBuilderBase::arrow_type(Form_pg_attribute attribute)
+{
+       switch (attribute->atttypid)
+       {
+               case INT2OID:
+                       return arrow::int16();
+               case INT4OID:
+                       return arrow::int32();
+               case INT8OID:
+                       return arrow::int64();
+               case FLOAT4OID:
+                       return arrow::float32();
+               case FLOAT8OID:
+                       return arrow::float64();
+               case VARCHAROID:
+               case TEXTOID:
+                       return arrow::utf8();
+               case BYTEAOID:
+                       return arrow::binary();
+               case TIMESTAMPOID:
+                       return arrow::timestamp(arrow::TimeUnit::MICRO);
+               default:
+                       return arrow::Status::NotImplemented("Unsupported 
PostgreSQL type: ",
+                                                            
attribute->atttypid);
+       }
+}
+
 class PreparedStatement {
    public:
        explicit PreparedStatement(std::string query)
@@ -1570,14 +1789,12 @@ class Executor : public WorkerProcessor {
        arrow::Status write(const char* tag)
        {
                SharedRingBufferOutputStream output(this, session_);
-               std::vector<PGArrowValueConverter> converters;
                std::vector<std::shared_ptr<arrow::Field>> fields;
                for (int i = 0; i < SPI_tuptable->tupdesc->natts; ++i)
                {
                        auto attribute = TupleDescAttr(SPI_tuptable->tupdesc, 
i);
-                       converters.emplace_back(attribute);
-                       const auto& converter = converters[converters.size() - 
1];
-                       ARROW_ASSIGN_OR_RAISE(auto type, 
converter.convert_type());
+                       ARROW_ASSIGN_OR_RAISE(auto type,
+                                             
ArrowArrayBuilderBase::arrow_type(attribute));
                        fields.push_back(arrow::field(
                                NameStr(attribute->attname), std::move(type), 
!attribute->attnotnull));
                }
@@ -1585,6 +1802,16 @@ class Executor : public WorkerProcessor {
                ARROW_ASSIGN_OR_RAISE(
                        auto builder,
                        arrow::RecordBatchBuilder::Make(schema, 
arrow::default_memory_pool()));
+               std::vector<std::unique_ptr<ArrowArrayBuilderBase>> builders;
+               for (int i = 0; i < SPI_tuptable->tupdesc->natts; ++i)
+               {
+                       auto attribute = TupleDescAttr(SPI_tuptable->tupdesc, 
i);
+                       ARROW_ASSIGN_OR_RAISE(
+                               auto array_builder,
+                               ArrowArrayBuilderBase::make(attribute, i, 
builder->GetField(i)));
+                       builders.push_back(std::move(array_builder));
+               }
+
                auto options = arrow::ipc::IpcWriteOptions::Defaults();
                options.emit_dictionary_deltas = true;
 
@@ -1601,64 +1828,72 @@ class Executor : public WorkerProcessor {
                // Write another stream format data with record batches.
                ARROW_ASSIGN_OR_RAISE(writer,
                                      arrow::ipc::MakeStreamWriter(&output, 
schema, options));
-               bool needLastFlush = false;
-               for (uint64_t iTuple = 0; iTuple < SPI_processed; ++iTuple)
+               uint64_t iTuple = 0;
+               for (; iTuple < SPI_processed; iTuple += MaxNRowsPerRecordBatch)
                {
-                       P("%s: %s: %s: write: data: record batch: %d/%d",
+                       uint64_t iTupleEnd = iTuple + MaxNRowsPerRecordBatch;
+                       if (iTupleEnd >= SPI_processed)
+                       {
+                               iTupleEnd = SPI_processed;
+                       }
+                       P("%s: %s: %s: write: data: record batch: %" PRIu64 
"/%" PRIu64,
                          Tag,
                          tag_,
                          tag,
                          iTuple,
-                         SPI_processed);
+                         iTupleEnd);
                        for (int iAttribute = 0; iAttribute < 
SPI_tuptable->tupdesc->natts;
                             ++iAttribute)
                        {
-                               P("%s: %s: %s: write: data: record batch: 
%d/%d: %d/%d",
+                               P("%s: %s: %s: write: data: record batch: %" 
PRIu64 "/%" PRIu64 ": %d/%d",
                                  Tag,
                                  tag_,
                                  tag,
                                  iTuple,
-                                 SPI_processed,
+                                 iTupleEnd,
                                  iAttribute,
                                  SPI_tuptable->tupdesc->natts);
-                               bool isNull;
-                               auto datum = 
SPI_getbinval(SPI_tuptable->vals[iTuple],
-                                                          
SPI_tuptable->tupdesc,
-                                                          iAttribute + 1,
-                                                          &isNull);
-                               auto arrayBuilder = 
builder->GetField(iAttribute);
-                               if (isNull)
-                               {
-                                       
ARROW_RETURN_NOT_OK(arrayBuilder->AppendNull());
-                               }
-                               else
-                               {
-                                       ARROW_RETURN_NOT_OK(
-                                               
converters[iAttribute].convert_value(arrayBuilder, datum));
-                               }
+                               
ARROW_RETURN_NOT_OK(builders[iAttribute]->build(iTuple, iTupleEnd));
                        }
+                       ARROW_ASSIGN_OR_RAISE(recordBatch, builder->Flush());
+                       P("%s: %s: %s: write: data: WriteRecordBatch: %" PRIu64 
"/%" PRIu64,
+                         Tag,
+                         tag_,
+                         tag,
+                         iTuple,
+                         iTupleEnd);
+                       
ARROW_RETURN_NOT_OK(writer->WriteRecordBatch(*recordBatch));
+               }
 
-                       if (((iTuple + 1) % MaxNRowsPerRecordBatch) == 0)
+               if (iTuple < SPI_processed)
+               {
+                       P("%s: %s: %s: write: data: record batch: last: %" 
PRIu64 "/%" PRIu64,
+                         Tag,
+                         tag_,
+                         tag,
+                         iTuple,
+                         SPI_processed);
+                       for (int iAttribute = 0; iAttribute < 
SPI_tuptable->tupdesc->natts;
+                            ++iAttribute)
                        {
-                               ARROW_ASSIGN_OR_RAISE(recordBatch, 
builder->Flush());
-                               P("%s: %s: %s: write: data: WriteRecordBatch: 
%d/%d",
+                               P("%s: %s: %s: write: data: record batch: last: 
%" PRIu64 "/%" PRIu64
+                                 ": %d/%d",
                                  Tag,
                                  tag_,
                                  tag,
                                  iTuple,
-                                 SPI_processed);
-                               
ARROW_RETURN_NOT_OK(writer->WriteRecordBatch(*recordBatch));
-                               needLastFlush = false;
-                       }
-                       else
-                       {
-                               needLastFlush = true;
+                                 SPI_processed,
+                                 iAttribute,
+                                 SPI_tuptable->tupdesc->natts);
+                               
ARROW_RETURN_NOT_OK(builders[iAttribute]->build(iTuple, SPI_processed));
                        }
-               }
-               if (needLastFlush)
-               {
                        ARROW_ASSIGN_OR_RAISE(recordBatch, builder->Flush());
-                       P("%s: %s: %s: write: data: WriteRecordBatch", Tag, 
tag_, tag);
+                       P("%s: %s: %s: write: data: WriteRecordBatch: last: %" 
PRIu64 "/%" PRIu64,
+                         Tag,
+                         tag_,
+                         tag,
+                         iTuple,
+                         SPI_processed);
                        
ARROW_RETURN_NOT_OK(writer->WriteRecordBatch(*recordBatch));
                }
                P("%s: %s: %s, write: data: Close", Tag, tag_, tag);

Reply via email to