This is an automated email from the ASF dual-hosted git repository.
kou pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow.git
The following commit(s) were added to refs/heads/main by this push:
new 85684fe4af GH-43169: [Swift] Add StructArray to ArrowReader (#43335)
85684fe4af is described below
commit 85684fe4af1e35233f3ac921ed45b95202cda562
Author: abandy <[email protected]>
AuthorDate: Thu Jul 25 15:49:52 2024 -0400
GH-43169: [Swift] Add StructArray to ArrowReader (#43335)
### Rationale for this change
Structs have been added for Swift, but the ArrowReader does not currently
support them. This PR adds that support to the ArrowReader.
### What changes are included in this PR?
Adds StructArray support to ArrowReader.
### Are these changes tested?
Not directly in this PR. The next PR, for the ArrowWriter, will include a
test for reading and writing Structs.
* GitHub Issue: #43169
Authored-by: Alva Bandy <[email protected]>
Signed-off-by: Sutou Kouhei <[email protected]>
---
swift/Arrow/Sources/Arrow/ArrowCImporter.swift | 3 +-
swift/Arrow/Sources/Arrow/ArrowReader.swift | 199 +++++++++++++++-------
swift/Arrow/Sources/Arrow/ArrowReaderHelper.swift | 59 ++++++-
swift/Arrow/Tests/ArrowTests/ArrayTests.swift | 2 +-
4 files changed, 194 insertions(+), 69 deletions(-)
diff --git a/swift/Arrow/Sources/Arrow/ArrowCImporter.swift
b/swift/Arrow/Sources/Arrow/ArrowCImporter.swift
index f55077ef3d..e65d78d730 100644
--- a/swift/Arrow/Sources/Arrow/ArrowCImporter.swift
+++ b/swift/Arrow/Sources/Arrow/ArrowCImporter.swift
@@ -153,7 +153,8 @@ public class ArrowCImporter {
}
}
- switch makeArrayHolder(arrowField, buffers: arrowBuffers, nullCount:
nullCount) {
+ switch makeArrayHolder(arrowField, buffers: arrowBuffers,
+ nullCount: nullCount, children: nil, rbLength:
0) {
case .success(let holder):
return .success(ImportArrayHolder(holder, cArrayPtr: cArrayPtr))
case .failure(let err):
diff --git a/swift/Arrow/Sources/Arrow/ArrowReader.swift
b/swift/Arrow/Sources/Arrow/ArrowReader.swift
index 237f22dc97..ae187e22ee 100644
--- a/swift/Arrow/Sources/Arrow/ArrowReader.swift
+++ b/swift/Arrow/Sources/Arrow/ArrowReader.swift
@@ -21,14 +21,46 @@ import Foundation
let FILEMARKER = "ARROW1"
let CONTINUATIONMARKER = -1
-public class ArrowReader {
- private struct DataLoadInfo {
+public class ArrowReader { // swiftlint:disable:this type_body_length
+ private class RecordBatchData {
+ let schema: org_apache_arrow_flatbuf_Schema
let recordBatch: org_apache_arrow_flatbuf_RecordBatch
- let field: org_apache_arrow_flatbuf_Field
- let nodeIndex: Int32
- let bufferIndex: Int32
+ private var fieldIndex: Int32 = 0
+ private var nodeIndex: Int32 = 0
+ private var bufferIndex: Int32 = 0
+ init(_ recordBatch: org_apache_arrow_flatbuf_RecordBatch,
+ schema: org_apache_arrow_flatbuf_Schema) {
+ self.recordBatch = recordBatch
+ self.schema = schema
+ }
+
+ func nextNode() -> org_apache_arrow_flatbuf_FieldNode? {
+ if nodeIndex >= self.recordBatch.nodesCount {return nil}
+ defer {nodeIndex += 1}
+ return self.recordBatch.nodes(at: nodeIndex)
+ }
+
+ func nextBuffer() -> org_apache_arrow_flatbuf_Buffer? {
+ if bufferIndex >= self.recordBatch.buffersCount {return nil}
+ defer {bufferIndex += 1}
+ return self.recordBatch.buffers(at: bufferIndex)
+ }
+
+ func nextField() -> org_apache_arrow_flatbuf_Field? {
+ if fieldIndex >= self.schema.fieldsCount {return nil}
+ defer {fieldIndex += 1}
+ return self.schema.fields(at: fieldIndex)
+ }
+
+ func isDone() -> Bool {
+ return nodeIndex >= self.recordBatch.nodesCount
+ }
+ }
+
+ private struct DataLoadInfo {
let fileData: Data
let messageOffset: Int64
+ var batchData: RecordBatchData
}
public class ArrowReaderResult {
@@ -54,49 +86,104 @@ public class ArrowReader {
return .success(builder.finish())
}
- private func loadPrimitiveData(_ loadInfo: DataLoadInfo) ->
Result<ArrowArrayHolder, ArrowError> {
- do {
- let node = loadInfo.recordBatch.nodes(at: loadInfo.nodeIndex)!
- let nullLength = UInt(ceil(Double(node.length) / 8))
- try validateBufferIndex(loadInfo.recordBatch, index:
loadInfo.bufferIndex)
- let nullBuffer = loadInfo.recordBatch.buffers(at:
loadInfo.bufferIndex)!
- let arrowNullBuffer = makeBuffer(nullBuffer, fileData:
loadInfo.fileData,
- length: nullLength,
messageOffset: loadInfo.messageOffset)
- try validateBufferIndex(loadInfo.recordBatch, index:
loadInfo.bufferIndex + 1)
- let valueBuffer = loadInfo.recordBatch.buffers(at:
loadInfo.bufferIndex + 1)!
- let arrowValueBuffer = makeBuffer(valueBuffer, fileData:
loadInfo.fileData,
- length: UInt(node.length),
messageOffset: loadInfo.messageOffset)
- return makeArrayHolder(loadInfo.field, buffers: [arrowNullBuffer,
arrowValueBuffer],
- nullCount: UInt(node.nullCount))
- } catch let error as ArrowError {
- return .failure(error)
- } catch {
- return .failure(.unknownError("\(error)"))
+ private func loadStructData(_ loadInfo: DataLoadInfo,
+ field: org_apache_arrow_flatbuf_Field)
+ -> Result<ArrowArrayHolder, ArrowError> {
+ guard let node = loadInfo.batchData.nextNode() else {
+ return .failure(.invalid("Node not found"))
+ }
+
+ guard let nullBuffer = loadInfo.batchData.nextBuffer() else {
+ return .failure(.invalid("Null buffer not found"))
+ }
+
+ let nullLength = UInt(ceil(Double(node.length) / 8))
+ let arrowNullBuffer = makeBuffer(nullBuffer, fileData:
loadInfo.fileData,
+ length: nullLength, messageOffset:
loadInfo.messageOffset)
+ var children = [ArrowData]()
+ for index in 0..<field.childrenCount {
+ let childField = field.children(at: index)!
+ switch loadField(loadInfo, field: childField) {
+ case .success(let holder):
+ children.append(holder.array.arrowData)
+ case .failure(let error):
+ return .failure(error)
+ }
}
+
+ return makeArrayHolder(field, buffers: [arrowNullBuffer],
+ nullCount: UInt(node.nullCount), children:
children,
+ rbLength:
UInt(loadInfo.batchData.recordBatch.length))
}
- private func loadVariableData(_ loadInfo: DataLoadInfo) ->
Result<ArrowArrayHolder, ArrowError> {
- let node = loadInfo.recordBatch.nodes(at: loadInfo.nodeIndex)!
- do {
- let nullLength = UInt(ceil(Double(node.length) / 8))
- try validateBufferIndex(loadInfo.recordBatch, index:
loadInfo.bufferIndex)
- let nullBuffer = loadInfo.recordBatch.buffers(at:
loadInfo.bufferIndex)!
- let arrowNullBuffer = makeBuffer(nullBuffer, fileData:
loadInfo.fileData,
- length: nullLength,
messageOffset: loadInfo.messageOffset)
- try validateBufferIndex(loadInfo.recordBatch, index:
loadInfo.bufferIndex + 1)
- let offsetBuffer = loadInfo.recordBatch.buffers(at:
loadInfo.bufferIndex + 1)!
- let arrowOffsetBuffer = makeBuffer(offsetBuffer, fileData:
loadInfo.fileData,
- length: UInt(node.length),
messageOffset: loadInfo.messageOffset)
- try validateBufferIndex(loadInfo.recordBatch, index:
loadInfo.bufferIndex + 2)
- let valueBuffer = loadInfo.recordBatch.buffers(at:
loadInfo.bufferIndex + 2)!
- let arrowValueBuffer = makeBuffer(valueBuffer, fileData:
loadInfo.fileData,
- length: UInt(node.length),
messageOffset: loadInfo.messageOffset)
- return makeArrayHolder(loadInfo.field, buffers: [arrowNullBuffer,
arrowOffsetBuffer, arrowValueBuffer],
- nullCount: UInt(node.nullCount))
- } catch let error as ArrowError {
- return .failure(error)
- } catch {
- return .failure(.unknownError("\(error)"))
+ private func loadPrimitiveData(
+ _ loadInfo: DataLoadInfo,
+ field: org_apache_arrow_flatbuf_Field)
+ -> Result<ArrowArrayHolder, ArrowError> {
+ guard let node = loadInfo.batchData.nextNode() else {
+ return .failure(.invalid("Node not found"))
+ }
+
+ guard let nullBuffer = loadInfo.batchData.nextBuffer() else {
+ return .failure(.invalid("Null buffer not found"))
+ }
+
+ guard let valueBuffer = loadInfo.batchData.nextBuffer() else {
+ return .failure(.invalid("Value buffer not found"))
+ }
+
+ let nullLength = UInt(ceil(Double(node.length) / 8))
+ let arrowNullBuffer = makeBuffer(nullBuffer, fileData:
loadInfo.fileData,
+ length: nullLength, messageOffset:
loadInfo.messageOffset)
+ let arrowValueBuffer = makeBuffer(valueBuffer, fileData:
loadInfo.fileData,
+ length: UInt(node.length),
messageOffset: loadInfo.messageOffset)
+ return makeArrayHolder(field, buffers: [arrowNullBuffer,
arrowValueBuffer],
+ nullCount: UInt(node.nullCount), children: nil,
+ rbLength:
UInt(loadInfo.batchData.recordBatch.length))
+ }
+
+ private func loadVariableData(
+ _ loadInfo: DataLoadInfo,
+ field: org_apache_arrow_flatbuf_Field)
+ -> Result<ArrowArrayHolder, ArrowError> {
+ guard let node = loadInfo.batchData.nextNode() else {
+ return .failure(.invalid("Node not found"))
+ }
+
+ guard let nullBuffer = loadInfo.batchData.nextBuffer() else {
+ return .failure(.invalid("Null buffer not found"))
+ }
+
+ guard let offsetBuffer = loadInfo.batchData.nextBuffer() else {
+ return .failure(.invalid("Offset buffer not found"))
+ }
+
+ guard let valueBuffer = loadInfo.batchData.nextBuffer() else {
+ return .failure(.invalid("Value buffer not found"))
+ }
+
+ let nullLength = UInt(ceil(Double(node.length) / 8))
+ let arrowNullBuffer = makeBuffer(nullBuffer, fileData:
loadInfo.fileData,
+ length: nullLength, messageOffset:
loadInfo.messageOffset)
+ let arrowOffsetBuffer = makeBuffer(offsetBuffer, fileData:
loadInfo.fileData,
+ length: UInt(node.length),
messageOffset: loadInfo.messageOffset)
+ let arrowValueBuffer = makeBuffer(valueBuffer, fileData:
loadInfo.fileData,
+ length: UInt(node.length),
messageOffset: loadInfo.messageOffset)
+ return makeArrayHolder(field, buffers: [arrowNullBuffer,
arrowOffsetBuffer, arrowValueBuffer],
+ nullCount: UInt(node.nullCount), children: nil,
+ rbLength:
UInt(loadInfo.batchData.recordBatch.length))
+ }
+
+ private func loadField(
+ _ loadInfo: DataLoadInfo,
+ field: org_apache_arrow_flatbuf_Field)
+ -> Result<ArrowArrayHolder, ArrowError> {
+ if isNestedType(field.typeType) {
+ return loadStructData(loadInfo, field: field)
+ } else if isFixedPrimitive(field.typeType) {
+ return loadPrimitiveData(loadInfo, field: field)
+ } else {
+ return loadVariableData(loadInfo, field: field)
}
}
@@ -107,23 +194,17 @@ public class ArrowReader {
data: Data,
messageEndOffset: Int64
) -> Result<RecordBatch, ArrowError> {
- let nodesCount = recordBatch.nodesCount
- var bufferIndex: Int32 = 0
var columns: [ArrowArrayHolder] = []
- for nodeIndex in 0 ..< nodesCount {
- let field = schema.fields(at: nodeIndex)!
- let loadInfo = DataLoadInfo(recordBatch: recordBatch, field: field,
- nodeIndex: nodeIndex, bufferIndex:
bufferIndex,
- fileData: data, messageOffset:
messageEndOffset)
- var result: Result<ArrowArrayHolder, ArrowError>
- if isFixedPrimitive(field.typeType) {
- result = loadPrimitiveData(loadInfo)
- bufferIndex += 2
- } else {
- result = loadVariableData(loadInfo)
- bufferIndex += 3
+ let batchData = RecordBatchData(recordBatch, schema: schema)
+ let loadInfo = DataLoadInfo(fileData: data,
+ messageOffset: messageEndOffset,
+ batchData: batchData)
+ while !batchData.isDone() {
+ guard let field = batchData.nextField() else {
+ return .failure(.invalid("Field not found"))
}
+ let result = loadField(loadInfo, field: field)
switch result {
case .success(let holder):
columns.append(holder)
diff --git a/swift/Arrow/Sources/Arrow/ArrowReaderHelper.swift
b/swift/Arrow/Sources/Arrow/ArrowReaderHelper.swift
index 22c0672b27..48c6fd8550 100644
--- a/swift/Arrow/Sources/Arrow/ArrowReaderHelper.swift
+++ b/swift/Arrow/Sources/Arrow/ArrowReaderHelper.swift
@@ -117,19 +117,42 @@ private func makeFixedHolder<T>(
}
}
+ func makeStructHolder(
+ _ field: ArrowField,
+ buffers: [ArrowBuffer],
+ nullCount: UInt,
+ children: [ArrowData],
+ rbLength: UInt
+) -> Result<ArrowArrayHolder, ArrowError> {
+ do {
+ let arrowData = try ArrowData(field.type,
+ buffers: buffers, children: children,
+ nullCount: nullCount, length: rbLength)
+ return .success(ArrowArrayHolderImpl(try StructArray(arrowData)))
+ } catch let error as ArrowError {
+ return .failure(error)
+ } catch {
+ return .failure(.unknownError("\(error)"))
+ }
+}
+
func makeArrayHolder(
_ field: org_apache_arrow_flatbuf_Field,
buffers: [ArrowBuffer],
- nullCount: UInt
+ nullCount: UInt,
+ children: [ArrowData]?,
+ rbLength: UInt
) -> Result<ArrowArrayHolder, ArrowError> {
let arrowField = fromProto(field: field)
- return makeArrayHolder(arrowField, buffers: buffers, nullCount: nullCount)
+ return makeArrayHolder(arrowField, buffers: buffers, nullCount: nullCount,
children: children, rbLength: rbLength)
}
func makeArrayHolder( // swiftlint:disable:this cyclomatic_complexity
_ field: ArrowField,
buffers: [ArrowBuffer],
- nullCount: UInt
+ nullCount: UInt,
+ children: [ArrowData]?,
+ rbLength: UInt
) -> Result<ArrowArrayHolder, ArrowError> {
let typeId = field.type.id
switch typeId {
@@ -159,12 +182,12 @@ func makeArrayHolder( // swiftlint:disable:this
cyclomatic_complexity
return makeStringHolder(buffers, nullCount: nullCount)
case .binary:
return makeBinaryHolder(buffers, nullCount: nullCount)
- case .date32:
+ case .date32, .date64:
return makeDateHolder(field, buffers: buffers, nullCount: nullCount)
- case .time32:
- return makeTimeHolder(field, buffers: buffers, nullCount: nullCount)
- case .time64:
+ case .time32, .time64:
return makeTimeHolder(field, buffers: buffers, nullCount: nullCount)
+ case .strct:
+ return makeStructHolder(field, buffers: buffers, nullCount: nullCount,
children: children!, rbLength: rbLength)
default:
return .failure(.unknownType("Type \(typeId) currently not supported"))
}
@@ -187,7 +210,16 @@ func isFixedPrimitive(_ type:
org_apache_arrow_flatbuf_Type_) -> Bool {
}
}
-func findArrowType( // swiftlint:disable:this cyclomatic_complexity
+func isNestedType(_ type: org_apache_arrow_flatbuf_Type_) -> Bool {
+ switch type {
+ case .struct_:
+ return true
+ default:
+ return false
+ }
+}
+
+func findArrowType( // swiftlint:disable:this cyclomatic_complexity
function_body_length
_ field: org_apache_arrow_flatbuf_Field) -> ArrowType {
let type = field.typeType
switch type {
@@ -229,6 +261,17 @@ func findArrowType( // swiftlint:disable:this
cyclomatic_complexity
}
return ArrowTypeTime64(timeType.unit == .microsecond ? .microseconds :
.nanoseconds)
+ case .struct_:
+ _ = field.type(type: org_apache_arrow_flatbuf_Struct_.self)!
+ var fields = [ArrowField]()
+ for index in 0..<field.childrenCount {
+ let childField = field.children(at: index)!
+ let childType = findArrowType(childField)
+ fields.append(
+ ArrowField(childField.name ?? "", type: childType, isNullable:
childField.nullable))
+ }
+
+ return ArrowNestedType(ArrowType.ArrowStruct, fields: fields)
default:
return ArrowType(ArrowType.ArrowUnknown)
}
diff --git a/swift/Arrow/Tests/ArrowTests/ArrayTests.swift
b/swift/Arrow/Tests/ArrowTests/ArrayTests.swift
index bfd7492064..d793aa11dc 100644
--- a/swift/Arrow/Tests/ArrowTests/ArrayTests.swift
+++ b/swift/Arrow/Tests/ArrowTests/ArrayTests.swift
@@ -279,7 +279,7 @@ final class ArrayTests: XCTestCase { //
swiftlint:disable:this type_body_length
ArrowBuffer(length: 0, capacity: 0,
rawPointer:
UnsafeMutableRawPointer.allocate(byteCount: 0, alignment: .zero))]
let field = ArrowField("", type: checkType, isNullable: true)
- switch makeArrayHolder(field, buffers: buffers, nullCount: 0) {
+ switch makeArrayHolder(field, buffers: buffers, nullCount: 0,
children: nil, rbLength: 0) {
case .success(let holder):
XCTAssertEqual(holder.type.id, checkType.id)
case .failure(let err):