This is an automated email from the ASF dual-hosted git repository.

quinnj pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-julia.git


The following commit(s) were added to refs/heads/main by this push:
     new a3f6da7  refactor Arrow.write to support incremental writes (#277)
a3f6da7 is described below

commit a3f6da7c1d59f8321315f4955a3b45c48d38aab4
Author: Ben Baumgold <[email protected]>
AuthorDate: Sat Apr 9 13:33:17 2022 -0400

    refactor Arrow.write to support incremental writes (#277)
    
    * refactor Arrow.write to support incremental writes
    
    * bump julia compat due to dependency on interpolation in Base.Threads.@spawn
    
    * PR feedback
    
    * add Arrow.Writer-specific tests and in-code/manual documentation
    
    Co-authored-by: Ben Baumgold <[email protected]>
---
 .github/workflows/ci.yml     |   4 +-
 Project.toml                 |   2 +-
 docs/src/manual.md           |   4 +
 src/arraytypes/arraytypes.jl |   3 +
 src/write.jl                 | 302 ++++++++++++++++++++++++++++---------------
 test/runtests.jl             |  18 +++
 6 files changed, 229 insertions(+), 104 deletions(-)

diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml
index b8a0a18..4a98314 100644
--- a/.github/workflows/ci.yml
+++ b/.github/workflows/ci.yml
@@ -73,7 +73,7 @@ jobs:
             dir: './src/ArrowTypes'
         version:
           - '1.0'
-          - '1.3'
+          - '1.4'
           - '1' # automatically expands to the latest stable 1.x release of Julia
           - 'nightly'
         os:
@@ -87,7 +87,7 @@ jobs:
             version: '1.0'
           - pkg:
               name: ArrowTypes.jl
-            version: '1.3'
+            version: '1.4'
     steps:
       - uses: actions/checkout@v2
       - uses: julia-actions/setup-julia@v1
diff --git a/Project.toml b/Project.toml
index f5784f4..5004d76 100644
--- a/Project.toml
+++ b/Project.toml
@@ -45,7 +45,7 @@ PooledArrays = "0.5, 1.0"
 SentinelArrays = "1"
 Tables = "1.1"
 TimeZones = "1"
-julia = "1.3"
+julia = "1.4"
 
 [extras]
 CategoricalArrays = "324d7699-5711-5eae-9e2f-1d82baa6b597"
diff --git a/docs/src/manual.md b/docs/src/manual.md
index 9db4403..3cfd6e6 100644
--- a/docs/src/manual.md
+++ b/docs/src/manual.md
@@ -229,6 +229,10 @@ csv_parts = Tables.partitioner(CSV.File, csv_files)
 Arrow.write(io, csv_parts)
 ```
 
+### `Arrow.Writer`
+
+With `Arrow.Writer`, you instantiate an `Arrow.Writer` object, write sources using it, and then close it.  This allows for incremental writes to the same sink.  It is similar to `Arrow.append`, but without having to close and re-open the sink between writes and without being limited to the IPC stream format.
+
 ### Multithreaded writing
 
 By default, `Arrow.write` will use multiple threads to write multiple
diff --git a/src/arraytypes/arraytypes.jl b/src/arraytypes/arraytypes.jl
index 3b3273e..88d38f8 100644
--- a/src/arraytypes/arraytypes.jl
+++ b/src/arraytypes/arraytypes.jl
@@ -70,6 +70,9 @@ end
 _normalizemeta(::Nothing) = nothing
 _normalizemeta(meta) = toidict(String(k) => String(v) for (k, v) in meta)
 
+_normalizecolmeta(::Nothing) = nothing
+_normalizecolmeta(colmeta) = toidict(Symbol(k) => toidict(String(v1) => 
String(v2) for (v1, v2) in v) for (k, v) in colmeta)
+
 function _arrowtypemeta(::Nothing, n, m)
     return toidict(("ARROW:extension:name" => n, "ARROW:extension:metadata" => 
m))
 end
diff --git a/src/write.jl b/src/write.jl
index 1fd28aa..49f519f 100644
--- a/src/write.jl
+++ b/src/write.jl
@@ -53,129 +53,244 @@ function write end
 
 write(io_or_file; kw...) = x -> write(io_or_file, x; kw...)
 
-function write(file_path, tbl; metadata=getmetadata(tbl), colmetadata=nothing, 
largelists::Bool=false, compress::Union{Nothing, Symbol, LZ4FrameCompressor, 
ZstdCompressor}=nothing, denseunions::Bool=true, dictencode::Bool=false, 
dictencodenested::Bool=false, alignment::Int=8, 
maxdepth::Int=DEFAULT_MAX_DEPTH, ntasks=Inf, file::Bool=true)
-    open(file_path, "w") do io
-        write(io, tbl, file, largelists, compress, denseunions, dictencode, 
dictencodenested, alignment, maxdepth, ntasks, metadata, colmetadata)
+function write(file_path, tbl; kwargs...)
+    open(Writer, file_path; file=true, kwargs...) do writer
+        write(writer, tbl)
     end
-    return file_path
+    file_path
 end
 
-function write(io::IO, tbl; metadata=getmetadata(tbl), colmetadata=nothing, 
largelists::Bool=false, compress::Union{Nothing, Symbol, LZ4FrameCompressor, 
ZstdCompressor}=nothing, denseunions::Bool=true, dictencode::Bool=false, 
dictencodenested::Bool=false, alignment::Int=8, 
maxdepth::Int=DEFAULT_MAX_DEPTH, ntasks=Inf, file::Bool=false)
-    return write(io, tbl, file, largelists, compress, denseunions, dictencode, 
dictencodenested, alignment, maxdepth, ntasks, metadata, colmetadata)
+struct Message
+    msgflatbuf
+    columns
+    bodylen
+    isrecordbatch::Bool
+    blockmsg::Bool
+    headerType
 end
 
-function write(io, source, writetofile, largelists, compress, denseunions, 
dictencode, dictencodenested, alignment, maxdepth, ntasks, meta, colmeta)
+struct Block
+    offset::Int64
+    metaDataLength::Int32
+    bodyLength::Int64
+end
+
+"""
+    Arrow.Writer{T<:IO}
+
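+An object that can be used to incrementally write Arrow partitions to a single sink; call `close` when finished so the end-of-stream marker or file footer is written.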
+
+# Examples
+```julia
+julia> writer = open(Arrow.Writer, tempname())
+
+julia> partition1 = (col1 = [1, 2], col2 = ["A", "B"])
+(col1 = [1, 2], col2 = ["A", "B"])
+
+julia> Arrow.write(writer, partition1)
+
+julia> partition2 = (col1 = [3, 4], col2 = ["C", "D"])
+(col1 = [3, 4], col2 = ["C", "D"])
+
+julia> Arrow.write(writer, partition2)
+
+julia> close(writer)
+```
+
+It's also possible to automatically close the Writer using a do-block:
+
+```julia
+julia> open(Arrow.Writer, tempname()) do writer
+           partition1 = (col1 = [1, 2], col2 = ["A", "B"])
+           Arrow.write(writer, partition1)
+           partition2 = (col1 = [3, 4], col2 = ["C", "D"])
+           Arrow.write(writer, partition2)
+       end
+```
+"""
+mutable struct Writer{T<:IO}
+    io::T
+    closeio::Bool
+    
compress::Union{Nothing,LZ4FrameCompressor,Vector{LZ4FrameCompressor},ZstdCompressor,Vector{ZstdCompressor}}
+    writetofile::Bool
+    largelists::Bool
+    denseunions::Bool
+    dictencode::Bool
+    dictencodenested::Bool
+    threaded::Bool
+    alignment::Int32
+    maxdepth::Int64
+    meta::Union{Nothing,Base.ImmutableDict{String,String}}
+    
colmeta::Union{Nothing,Base.ImmutableDict{Symbol,Base.ImmutableDict{String,String}}}
+    msgs::OrderedChannel{Message}
+    schema::Ref{Tables.Schema}
+    firstcols::Ref{Any}
+    dictencodings::Dict{Int64, Any}
+    blocks::NTuple{2, Vector{Block}}
+    task::Task
+    anyerror::Threads.Atomic{Bool}
+    errorref::Ref{Any}
+    partition_count::Int32
+    isclosed::Bool
+end
+
+function Base.open(::Type{Writer}, io::T, 
compress::Union{Nothing,LZ4FrameCompressor,<:AbstractVector{LZ4FrameCompressor},ZstdCompressor,<:AbstractVector{ZstdCompressor}},
 writetofile::Bool, largelists::Bool, denseunions::Bool, dictencode::Bool, 
dictencodenested::Bool, alignment::Integer, maxdepth::Integer, ntasks::Integer, 
meta::Union{Nothing,Any}, colmeta::Union{Nothing,Any}, closeio::Bool) where 
{T<:IO}
     if ntasks < 1
         throw(ArgumentError("ntasks keyword argument must be > 0; pass 
`ntasks=1` to disable multithreaded writing"))
     end
-    if compress === :lz4
-        compress = LZ4_FRAME_COMPRESSOR
-    elseif compress === :zstd
-        compress = ZSTD_COMPRESSOR
-    elseif compress isa Symbol
-        throw(ArgumentError("unsupported compress keyword argument value: 
$compress. Valid values include `:lz4` or `:zstd`"))
-    end
-    # TODO: we're probably not threadsafe if user passes own single compressor 
instance + ntasks > 1
-    # if ntasks > 1 && compres !== nothing && !(compress isa Vector)
-    #     compress = Threads.resize_nthreads!([compress])
-    # end
-    if writetofile
-        @debug 1 "starting write of arrow formatted file"
-        Base.write(io, "ARROW1\0\0")
-    end
     msgs = OrderedChannel{Message}(ntasks)
-    # build messages
-    sch = Ref{Tables.Schema}()
+    schema = Ref{Tables.Schema}()
     firstcols = Ref{Any}()
     dictencodings = Dict{Int64, Any}() # Lockable{DictEncoding}
     blocks = (Block[], Block[])
     # start message writing from channel
     threaded = ntasks > 1
-    tsk = threaded ? (Threads.@spawn for msg in msgs
-        Base.write(io, msg, blocks, sch, alignment)
+    task = threaded ? (Threads.@spawn for msg in msgs
+        Base.write(io, msg, blocks, schema, alignment)
     end) : (@async for msg in msgs
-        Base.write(io, msg, blocks, sch, alignment)
+        Base.write(io, msg, blocks, schema, alignment)
     end)
     anyerror = Threads.Atomic{Bool}(false)
     errorref = Ref{Any}()
-    @sync for (i, tbl) in enumerate(Tables.partitions(source))
-        if anyerror[]
-            @error "error writing arrow data on partition = $(errorref[][3])" 
exception=(errorref[][1], errorref[][2])
-            error("fatal error writing arrow data")
-        end
-        @debug 1 "processing table partition i = $i"
+    meta = _normalizemeta(meta)
+    colmeta = _normalizecolmeta(colmeta)
+    return Writer{T}(io, closeio, compress, writetofile, largelists, 
denseunions, dictencode, dictencodenested, threaded, alignment, maxdepth, meta, 
colmeta, msgs, schema, firstcols, dictencodings, blocks, task, anyerror, 
errorref, 1, false)
+end
+
+function Base.open(::Type{Writer}, io::IO, compress::Symbol, args...)
+    compressor = if compress === :lz4
+        LZ4_FRAME_COMPRESSOR
+    elseif compress === :zstd
+        ZSTD_COMPRESSOR
+    else
+        throw(ArgumentError("unsupported compress keyword argument value: 
$compress. Valid values include `:lz4` or `:zstd`"))
+    end
+    open(Writer, io, compressor, args...)
+end
+
+function Base.open(::Type{Writer}, io::IO; 
compress::Union{Nothing,Symbol,LZ4FrameCompressor,<:AbstractVector{LZ4FrameCompressor},ZstdCompressor,<:AbstractVector{ZstdCompressor}}=nothing,
 file::Bool=true, largelists::Bool=false, denseunions::Bool=true, 
dictencode::Bool=false, dictencodenested::Bool=false, alignment::Integer=8, 
maxdepth::Integer=DEFAULT_MAX_DEPTH, ntasks::Integer=typemax(Int32), 
metadata::Union{Nothing,Any}=nothing, colmetadata::Union{Nothing,Any}=nothing, 
closeio::Bool = false)
+    open(Writer, io, compress, file, largelists, denseunions, dictencode, 
dictencodenested, alignment, maxdepth, ntasks, metadata, colmetadata, closeio)
+end
+
+Base.open(::Type{Writer}, file_path; kwargs...) = open(Writer, open(file_path, 
"w"); kwargs..., closeio=true)
+
+function check_errors(writer::Writer)
+    if writer.anyerror[]
+        errorref = writer.errorref[]
+        @error "error writing arrow data on partition = $(errorref[3])" 
exception=(errorref[1], errorref[2])
+        error("fatal error writing arrow data")
+    end
+end
+
+function write(writer::Writer, source)
+    @sync for tbl in Tables.partitions(source)
+        check_errors(writer)
+        @debug 1 "processing table partition $(writer.partition_count)"
         tblcols = Tables.columns(tbl)
-        if i == 1
-            cols = toarrowtable(tblcols, dictencodings, largelists, compress, 
denseunions, dictencode, dictencodenested, maxdepth, meta, colmeta)
-            sch[] = Tables.schema(cols)
-            firstcols[] = cols
-            put!(msgs, makeschemamsg(sch[], cols), i)
-            if !isempty(dictencodings)
-                des = sort!(collect(dictencodings); by=x->x.first, rev=true)
+        if !isassigned(writer.firstcols)
+            if writer.writetofile
+                @debug 1 "starting write of arrow formatted file"
+                Base.write(writer.io, "ARROW1\0\0")
+            end
+            meta = isnothing(writer.meta) ? getmetadata(source) : writer.meta
+            cols = toarrowtable(tblcols, writer.dictencodings, 
writer.largelists, writer.compress, writer.denseunions, writer.dictencode, 
writer.dictencodenested, writer.maxdepth, meta, writer.colmeta)
+            writer.schema[] = Tables.schema(cols)
+            writer.firstcols[] = cols
+            put!(writer.msgs, makeschemamsg(writer.schema[], cols), 
writer.partition_count)
+            if !isempty(writer.dictencodings)
+                des = sort!(collect(writer.dictencodings); by=x->x.first, 
rev=true)
                 for (id, delock) in des
                     # assign dict encoding ids
                     de = delock.x
                     dictsch = Tables.Schema((:col,), (eltype(de.data),))
-                    put!(msgs, makedictionarybatchmsg(dictsch, (col=de.data,), 
id, false, alignment), i)
+                    dictbatchmsg = makedictionarybatchmsg(dictsch, 
(col=de.data,), id, false, writer.alignment)
+                    put!(writer.msgs, dictbatchmsg, writer.partition_count)
                 end
             end
-            put!(msgs, makerecordbatchmsg(sch[], cols, alignment), i, true)
+            recbatchmsg = makerecordbatchmsg(writer.schema[], cols, 
writer.alignment)
+            put!(writer.msgs, recbatchmsg, writer.partition_count, true)
         else
-            if threaded
-                Threads.@spawn process_partition(tblcols, dictencodings, 
largelists, compress, denseunions, dictencode, dictencodenested, maxdepth, 
msgs, alignment, i, sch, errorref, anyerror, meta, colmeta)
+            if writer.threaded
+                Threads.@spawn process_partition(tblcols, 
writer.dictencodings, writer.largelists, writer.compress, writer.denseunions, 
writer.dictencode, writer.dictencodenested, writer.maxdepth, writer.msgs, 
writer.alignment, $(writer.partition_count), writer.schema, writer.errorref, 
writer.anyerror, writer.meta, writer.colmeta)
             else
-                @async process_partition(tblcols, dictencodings, largelists, 
compress, denseunions, dictencode, dictencodenested, maxdepth, msgs, alignment, 
i, sch, errorref, anyerror, meta, colmeta)
+                @async process_partition(tblcols, writer.dictencodings, 
writer.largelists, writer.compress, writer.denseunions, writer.dictencode, 
writer.dictencodenested, writer.maxdepth, writer.msgs, writer.alignment, 
$(writer.partition_count), writer.schema, writer.errorref, writer.anyerror, 
writer.meta, writer.colmeta)
             end
         end
+        writer.partition_count += 1
     end
-    if anyerror[]
-        @error "error writing arrow data on partition = $(errorref[][3])" 
exception=(errorref[][1], errorref[][2])
-        error("fatal error writing arrow data")
-    end
+    check_errors(writer)
+    return
+end
+
+function Base.close(writer::Writer)
+    writer.isclosed && return
     # close our message-writing channel, no further put!-ing is allowed
-    close(msgs)
+    close(writer.msgs)
     # now wait for our message-writing task to finish writing
-    wait(tsk)
+    !istaskfailed(writer.task) && wait(writer.task)
+    if (!isassigned(writer.schema) || !isassigned(writer.firstcols))
+        writer.closeio && close(writer.io)
+        writer.isclosed = true
+        return
+    end
     # write empty message
-    if !writetofile
-        Base.write(io, Message(UInt8[], nothing, 0, true, false, Meta.Schema), 
blocks, sch, alignment)
+    if !writer.writetofile
+        msg = Message(UInt8[], nothing, 0, true, false, Meta.Schema)
+        Base.write(writer.io, msg, writer.blocks, writer.schema, 
writer.alignment)
+        writer.closeio && close(writer.io)
+        writer.isclosed = true
+        return
     end
-    if writetofile
-        b = FlatBuffers.Builder(1024)
-        schfoot = makeschema(b, sch[], firstcols[])
-        if !isempty(blocks[1])
-            N = length(blocks[1])
-            Meta.footerStartRecordBatchesVector(b, N)
-            for blk in Iterators.reverse(blocks[1])
-                Meta.createBlock(b, blk.offset, blk.metaDataLength, 
blk.bodyLength)
-            end
-            recordbatches = FlatBuffers.endvector!(b, N)
-        else
-            recordbatches = FlatBuffers.UOffsetT(0)
+    b = FlatBuffers.Builder(1024)
+    schfoot = makeschema(b, writer.schema[], writer.firstcols[])
+    recordbatches = if !isempty(writer.blocks[1])
+        N = length(writer.blocks[1])
+        Meta.footerStartRecordBatchesVector(b, N)
+        for blk in Iterators.reverse(writer.blocks[1])
+            Meta.createBlock(b, blk.offset, blk.metaDataLength, blk.bodyLength)
         end
-        if !isempty(blocks[2])
-            N = length(blocks[2])
-            Meta.footerStartDictionariesVector(b, N)
-            for blk in Iterators.reverse(blocks[2])
-                Meta.createBlock(b, blk.offset, blk.metaDataLength, 
blk.bodyLength)
-            end
-            dicts = FlatBuffers.endvector!(b, N)
-        else
-            dicts = FlatBuffers.UOffsetT(0)
+        FlatBuffers.endvector!(b, N)
+    else
+        FlatBuffers.UOffsetT(0)
+    end
+    dicts = if !isempty(writer.blocks[2])
+        N = length(writer.blocks[2])
+        Meta.footerStartDictionariesVector(b, N)
+        for blk in Iterators.reverse(writer.blocks[2])
+            Meta.createBlock(b, blk.offset, blk.metaDataLength, blk.bodyLength)
         end
-        Meta.footerStart(b)
-        Meta.footerAddVersion(b, Meta.MetadataVersions.V4)
-        Meta.footerAddSchema(b, schfoot)
-        Meta.footerAddDictionaries(b, dicts)
-        Meta.footerAddRecordBatches(b, recordbatches)
-        foot = Meta.footerEnd(b)
-        FlatBuffers.finish!(b, foot)
-        footer = FlatBuffers.finishedbytes(b)
-        Base.write(io, footer)
-        Base.write(io, Int32(length(footer)))
-        Base.write(io, "ARROW1")
+        FlatBuffers.endvector!(b, N)
+    else
+        FlatBuffers.UOffsetT(0)
+    end
+    Meta.footerStart(b)
+    Meta.footerAddVersion(b, Meta.MetadataVersions.V4)
+    Meta.footerAddSchema(b, schfoot)
+    Meta.footerAddDictionaries(b, dicts)
+    Meta.footerAddRecordBatches(b, recordbatches)
+    foot = Meta.footerEnd(b)
+    FlatBuffers.finish!(b, foot)
+    footer = FlatBuffers.finishedbytes(b)
+    Base.write(writer.io, footer)
+    Base.write(writer.io, Int32(length(footer)))
+    Base.write(writer.io, "ARROW1")
+    writer.closeio && close(writer.io)
+    writer.isclosed = true
+    nothing
+end
+
+function write(io::IO, tbl; kwargs...)
+    open(Writer, io; file=false, kwargs...) do writer
+        write(writer, tbl)
+    end
+    io
+end
+
+function write(io, source, writetofile, largelists, compress, denseunions, 
dictencode, dictencodenested, alignment, maxdepth, ntasks, meta, colmeta)
+    open(Writer, io, compress, writetofile, largelists, denseunions, 
dictencode, dictencodenested, alignment, maxdepth, ntasks, meta, colmeta) do 
writer
+        write(writer, source)
     end
-    return io
+    io
 end
 
 function process_partition(cols, dictencodings, largelists, compress, 
denseunions, dictencode, dictencodenested, maxdepth, msgs, alignment, i, sch, 
errorref, anyerror, meta, colmeta)
@@ -229,21 +344,6 @@ Tables.schema(x::ToArrowTable) = x.sch
 Tables.columnnames(x::ToArrowTable) = x.sch.names
 Tables.getcolumn(x::ToArrowTable, i::Int) = x.cols[i]
 
-struct Message
-    msgflatbuf
-    columns
-    bodylen
-    isrecordbatch::Bool
-    blockmsg::Bool
-    headerType
-end
-
-struct Block
-    offset::Int64
-    metaDataLength::Int32
-    bodyLength::Int64
-end
-
 function Base.write(io::IO, msg::Message, blocks, sch, alignment)
     metalen = padding(length(msg.msgflatbuf), alignment)
     @debug 1 "writing message: metalen = $metalen, bodylen = $(msg.bodylen), 
isrecordbatch = $(msg.isrecordbatch), headerType = $(msg.headerType)"
diff --git a/test/runtests.jl b/test/runtests.jl
index a59156c..4bd6c16 100644
--- a/test/runtests.jl
+++ b/test/runtests.jl
@@ -509,6 +509,24 @@ t2 = Arrow.Table(Arrow.tobuffer(t1))
 t3 = Arrow.Table(Arrow.tobuffer(t2))
 @test t3.x == t1.x
 
+@testset "Writer" begin
+    io = IOBuffer()
+    writer = open(Arrow.Writer, io)
+    a = 1:26
+    b = 'A':'Z'
+    partitionsize = 10
+    iter_a = Iterators.partition(a, partitionsize)
+    iter_b = Iterators.partition(b, partitionsize)
+    for (part_a, part_b) in zip(iter_a, iter_b)
+        Arrow.write(writer, (a = part_a, b = part_b))
+    end
+    close(writer)
+    seekstart(io)
+    table = Arrow.Table(io)
+    @test table.a == collect(a)
+    @test table.b == collect(b)
+end
+
 end # @testset "misc"
 
 end

Reply via email to