zeroshade commented on code in PR #771: URL: https://github.com/apache/arrow-go/pull/771#discussion_r3112887126
########## arrow/array/arreflect/reflect.go: ########## @@ -0,0 +1,542 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package arreflect + +import ( + "errors" + "fmt" + "reflect" + "sort" + "strconv" + "strings" + "sync" + + "github.com/apache/arrow-go/v18/arrow" + "github.com/apache/arrow-go/v18/arrow/array" + "github.com/apache/arrow-go/v18/arrow/memory" +) + +var ( + ErrUnsupportedType = errors.New("arreflect: unsupported type") + ErrTypeMismatch = errors.New("arreflect: type mismatch") +) + +type tagOpts struct { + Name string + Skip bool + Dict bool + ListView bool + REE bool + DecimalPrecision int32 + DecimalScale int32 + HasDecimalOpts bool + Temporal string // "timestamp" (default), "date32", "date64", "time32", "time64" +} + +type fieldMeta struct { + Name string + Index []int + Type reflect.Type + Nullable bool + Opts tagOpts +} + +func parseTag(tag string) tagOpts { + if tag == "-" { + return tagOpts{Skip: true} + } + + var name, rest string + if idx := strings.Index(tag, ","); idx >= 0 { + name = tag[:idx] + rest = tag[idx+1:] + } else { + name = tag + rest = "" + } + + opts := tagOpts{Name: name} + + if rest == "" { + return opts + } + + parseOptions(&opts, rest) + return opts +} + +func 
splitTagTokens(rest string) []string { + var tokens []string + depth := 0 + start := 0 + for i := 0; i < len(rest); i++ { + switch rest[i] { + case '(': + depth++ + case ')': + depth-- + case ',': + if depth == 0 { + tokens = append(tokens, strings.TrimSpace(rest[start:i])) + start = i + 1 + } + } + } + if start < len(rest) { + tokens = append(tokens, strings.TrimSpace(rest[start:])) + } + return tokens +} + +func parseOptions(opts *tagOpts, rest string) { + for _, token := range splitTagTokens(rest) { + if strings.HasPrefix(token, "decimal(") && strings.HasSuffix(token, ")") { + parseDecimalOpt(opts, token) + continue + } + switch token { + case "dict": + opts.Dict = true + case "listview": + opts.ListView = true + case "ree": + opts.REE = true + case "date32", "date64", "time32", "time64", "timestamp": + opts.Temporal = token + } + } +} + +func parseDecimalOpt(opts *tagOpts, token string) { + inner := strings.TrimPrefix(token, "decimal(") + inner = strings.TrimSuffix(inner, ")") + parts := strings.SplitN(inner, ",", 2) + if len(parts) == 2 { + p, errP := strconv.ParseInt(strings.TrimSpace(parts[0]), 10, 32) + s, errS := strconv.ParseInt(strings.TrimSpace(parts[1]), 10, 32) + if errP == nil && errS == nil { + opts.HasDecimalOpts = true + opts.DecimalPrecision = int32(p) + opts.DecimalScale = int32(s) + } + } +} + +type bfsEntry struct { + t reflect.Type + index []int + depth int +} + +type candidate struct { + meta fieldMeta + depth int + tagged bool + order int +} + +type resolvedField struct { + meta fieldMeta + order int +} + +func collectFieldCandidates(t reflect.Type) map[string][]candidate { + nameMap := make(map[string][]candidate) + orderCounter := 0 + + queue := []bfsEntry{{t: t, index: nil, depth: 0}} + visited := make(map[reflect.Type]bool) + + for len(queue) > 0 { + entry := queue[0] + queue = queue[1:] + + st := entry.t + for st.Kind() == reflect.Ptr { + st = st.Elem() + } + if st.Kind() != reflect.Struct { + continue + } + + if visited[st] { + 
continue + } + if entry.depth > 0 { + visited[st] = true + } + + for i := 0; i < st.NumField(); i++ { + sf := st.Field(i) + + fullIndex := make([]int, len(entry.index)+1) + copy(fullIndex, entry.index) + fullIndex[len(entry.index)] = i + + if !sf.IsExported() && !sf.Anonymous { + continue + } + + tagVal, hasTag := sf.Tag.Lookup("arrow") + var opts tagOpts + if hasTag { + opts = parseTag(tagVal) + } + + if opts.Skip { + continue + } + + arrowName := opts.Name + if arrowName == "" { + arrowName = sf.Name + } + + if sf.Anonymous && !hasTag { + ft := sf.Type + for ft.Kind() == reflect.Ptr { + ft = ft.Elem() + } + if ft.Kind() == reflect.Struct { + queue = append(queue, bfsEntry{ + t: ft, + index: fullIndex, + depth: entry.depth + 1, + }) + continue + } + } + + nullable := sf.Type.Kind() == reflect.Ptr + tagged := hasTag && opts.Name != "" + + meta := fieldMeta{ + Name: arrowName, + Index: fullIndex, + Type: sf.Type, + Nullable: nullable, + Opts: opts, + } + + existingCands := nameMap[arrowName] + order := orderCounter + if len(existingCands) > 0 { + order = existingCands[0].order + } else { + orderCounter++ + } + + nameMap[arrowName] = append(existingCands, candidate{ + meta: meta, + depth: entry.depth, + tagged: tagged, + order: order, + }) + } + } + + return nameMap +} + +func resolveFieldCandidates(nameMap map[string][]candidate) []fieldMeta { + resolved := make([]resolvedField, 0, len(nameMap)) + for _, candidates := range nameMap { + minDepth := candidates[0].depth + for _, c := range candidates[1:] { + if c.depth < minDepth { + minDepth = c.depth + } + } + + var atMin []candidate + for _, c := range candidates { + if c.depth == minDepth { + atMin = append(atMin, c) + } + } + + var winner *candidate + if len(atMin) == 1 { + winner = &atMin[0] + } else { + var tagged []candidate + for _, c := range atMin { + if c.tagged { + tagged = append(tagged, c) + } + } + if len(tagged) == 1 { + winner = &tagged[0] + } + } + + if winner != nil { + resolved = append(resolved, 
resolvedField{meta: winner.meta, order: winner.order}) + } + } + + sort.Slice(resolved, func(i, j int) bool { + return resolved[i].order < resolved[j].order + }) + + result := make([]fieldMeta, len(resolved)) + for i, r := range resolved { + result[i] = r.meta + } + return result +} + +func getStructFields(t reflect.Type) []fieldMeta { + for t.Kind() == reflect.Ptr { + t = t.Elem() + } + + if t.Kind() != reflect.Struct { + return nil + } + + return resolveFieldCandidates(collectFieldCandidates(t)) +} + +var structFieldCache sync.Map + +func cachedStructFields(t reflect.Type) []fieldMeta { + for t.Kind() == reflect.Ptr { + t = t.Elem() + } + + if v, ok := structFieldCache.Load(t); ok { + return v.([]fieldMeta) + } + + fields := getStructFields(t) + v, _ := structFieldCache.LoadOrStore(t, fields) + return v.([]fieldMeta) +} + +func At[T any](arr arrow.Array, i int) (T, error) { + var result T + v := reflect.ValueOf(&result).Elem() + if err := setValue(v, arr, i); err != nil { + var zero T + return zero, err + } + return result, nil +} + +func ToSlice[T any](arr arrow.Array) ([]T, error) { + n := arr.Len() + result := make([]T, n) + for i := 0; i < n; i++ { + v := reflect.ValueOf(&result[i]).Elem() + if err := setValue(v, arr, i); err != nil { + return nil, fmt.Errorf("index %d: %w", i, err) + } + } + return result, nil +} + +// Option configures encoding behavior for [FromSlice] and [RecordFromSlice]. +type Option func(*tagOpts) + +// WithDict requests dictionary encoding for the top-level array. +func WithDict() Option { return func(o *tagOpts) { o.Dict = true } } + +// WithListView requests ListView encoding instead of List for slice types. +func WithListView() Option { return func(o *tagOpts) { o.ListView = true } } + +// WithREE requests run-end encoding for the top-level array. +func WithREE() Option { return func(o *tagOpts) { o.REE = true } } + +// WithDecimal sets the precision and scale for decimal types. 
+func WithDecimal(precision, scale int32) Option { + return func(o *tagOpts) { + o.DecimalPrecision = precision + o.DecimalScale = scale + o.HasDecimalOpts = true + } +} + +// WithTemporal overrides the Arrow temporal encoding for time.Time slices. +// Valid values: "date32", "date64", "time32", "time64", "timestamp" (default). +// Equivalent to tagging a struct field with arrow:",date32" etc. +// Invalid values cause FromSlice to return an error. +func WithTemporal(temporal string) Option { + return func(o *tagOpts) { o.Temporal = temporal } +} + +func validateTemporalOpt(temporal string) error { + switch temporal { + case "", "timestamp", "date32", "date64", "time32", "time64": + return nil + default: + return fmt.Errorf("arreflect: invalid WithTemporal value %q; valid values are date32, date64, time32, time64, timestamp: %w", temporal, ErrUnsupportedType) + } +} + +func buildEmptyTyped(goType reflect.Type, opts tagOpts, mem memory.Allocator) (arrow.Array, error) { + dt, err := inferArrowType(goType) + if err != nil { + return nil, err + } + derefType := goType + for derefType.Kind() == reflect.Ptr { + derefType = derefType.Elem() + } + dt = applyDecimalOpts(dt, derefType, opts) + dt = applyTemporalOpts(dt, derefType, opts) + if opts.ListView { + if derefType.Kind() != reflect.Slice || derefType == typeOfByteSlice { + return nil, fmt.Errorf("arreflect: WithListView requires a slice-of-slices element type, got %s: %w", goType, ErrUnsupportedType) + } + innerElem := derefType.Elem() + for innerElem.Kind() == reflect.Ptr { + innerElem = innerElem.Elem() + } + innerDT, err := inferArrowType(innerElem) + if err != nil { + return nil, err + } + dt = arrow.ListViewOf(innerDT) + } + if opts.Dict { + if err := validateDictValueType(dt); err != nil { + return nil, err + } + dt = &arrow.DictionaryType{IndexType: arrow.PrimitiveTypes.Int32, ValueType: dt} + } else if opts.REE { + dt = arrow.RunEndEncodedOf(arrow.PrimitiveTypes.Int32, dt) + } + b := array.NewBuilder(mem, dt) 
+ defer b.Release() + return b.NewArray(), nil +} + +func FromSlice[T any](vals []T, mem memory.Allocator, opts ...Option) (arrow.Array, error) { + if mem == nil { + mem = memory.DefaultAllocator + } + var tOpts tagOpts + for _, o := range opts { + o(&tOpts) + } + if err := validateTemporalOpt(tOpts.Temporal); err != nil { + return nil, err + } + // "timestamp" is excluded: it is a no-op for non-time.Time types via applyTemporalOpts. + if tOpts.Temporal != "" && tOpts.Temporal != "timestamp" { + goType := reflect.TypeFor[T]() + deref := goType + for deref.Kind() == reflect.Ptr { + deref = deref.Elem() + } + if deref != typeOfTime { + return nil, fmt.Errorf("arreflect: WithTemporal requires a time.Time element type, got %s: %w", deref, ErrUnsupportedType) + } + } + if len(vals) == 0 { + return buildEmptyTyped(reflect.TypeFor[T](), tOpts, mem) + } + sv := reflect.ValueOf(vals) + return buildArray(sv, tOpts, mem) +} Review Comment: Good catch. Added `validateOptions` that rejects conflicting combinations (Dict+REE, ListView+REE, Dict+ListView). Now `FromSlice` returns an error if more than one encoding option is specified. ########## arrow/array/arreflect/reflect_infer.go: ########## @@ -0,0 +1,428 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and +// limitations under the License. + +package arreflect + +import ( + "fmt" + "reflect" + "time" + "unicode" + + "github.com/apache/arrow-go/v18/arrow" + "github.com/apache/arrow-go/v18/arrow/decimal" + "github.com/apache/arrow-go/v18/arrow/decimal128" + "github.com/apache/arrow-go/v18/arrow/decimal256" +) + +var ( + typeOfTime = reflect.TypeOf(time.Time{}) + typeOfDuration = reflect.TypeOf(time.Duration(0)) + typeOfDec32 = reflect.TypeOf(decimal.Decimal32(0)) + typeOfDec64 = reflect.TypeOf(decimal.Decimal64(0)) + typeOfDec128 = reflect.TypeOf(decimal128.Num{}) + typeOfDec256 = reflect.TypeOf(decimal256.Num{}) + typeOfByteSlice = reflect.TypeOf([]byte{}) + typeOfInt = reflect.TypeOf(int(0)) + typeOfUint = reflect.TypeOf(uint(0)) + typeOfInt8 = reflect.TypeOf(int8(0)) + typeOfInt16 = reflect.TypeOf(int16(0)) + typeOfInt32 = reflect.TypeOf(int32(0)) + typeOfInt64 = reflect.TypeOf(int64(0)) + typeOfUint8 = reflect.TypeOf(uint8(0)) + typeOfUint16 = reflect.TypeOf(uint16(0)) + typeOfUint32 = reflect.TypeOf(uint32(0)) + typeOfUint64 = reflect.TypeOf(uint64(0)) + typeOfFloat32 = reflect.TypeOf(float32(0)) + typeOfFloat64 = reflect.TypeOf(float64(0)) + typeOfBool = reflect.TypeOf(false) + typeOfString = reflect.TypeOf("") +) + +const ( + dec32DefaultPrecision int32 = 9 + dec64DefaultPrecision int32 = 18 + dec128DefaultPrecision int32 = 38 + dec256DefaultPrecision int32 = 76 +) + +type listElemTyper interface{ Elem() arrow.DataType } + +func inferPrimitiveArrowType(t reflect.Type) (arrow.DataType, error) { + for t.Kind() == reflect.Ptr { + t = t.Elem() + } + + switch t { + case typeOfInt8: + return arrow.PrimitiveTypes.Int8, nil + case typeOfInt16: + return arrow.PrimitiveTypes.Int16, nil + case typeOfInt32: + return arrow.PrimitiveTypes.Int32, nil + case typeOfInt64: + return arrow.PrimitiveTypes.Int64, nil + case typeOfInt: + return arrow.PrimitiveTypes.Int64, nil + case typeOfUint8: + return 
arrow.PrimitiveTypes.Uint8, nil + case typeOfUint16: + return arrow.PrimitiveTypes.Uint16, nil + case typeOfUint32: + return arrow.PrimitiveTypes.Uint32, nil + case typeOfUint64: + return arrow.PrimitiveTypes.Uint64, nil + case typeOfUint: + return arrow.PrimitiveTypes.Uint64, nil + case typeOfFloat32: + return arrow.PrimitiveTypes.Float32, nil + case typeOfFloat64: + return arrow.PrimitiveTypes.Float64, nil + case typeOfBool: + return arrow.FixedWidthTypes.Boolean, nil + case typeOfString: + return arrow.BinaryTypes.String, nil + case typeOfByteSlice: + return arrow.BinaryTypes.Binary, nil + case typeOfTime: + return &arrow.TimestampType{Unit: arrow.Nanosecond, TimeZone: "UTC"}, nil + case typeOfDuration: + return &arrow.DurationType{Unit: arrow.Nanosecond}, nil + case typeOfDec128: + return &arrow.Decimal128Type{Precision: dec128DefaultPrecision, Scale: 0}, nil + case typeOfDec32: + return &arrow.Decimal32Type{Precision: dec32DefaultPrecision, Scale: 0}, nil + case typeOfDec64: + return &arrow.Decimal64Type{Precision: dec64DefaultPrecision, Scale: 0}, nil + case typeOfDec256: + return &arrow.Decimal256Type{Precision: dec256DefaultPrecision, Scale: 0}, nil + default: + return nil, fmt.Errorf("unsupported Go type for Arrow inference %v: %w", t, ErrUnsupportedType) + } +} + +func inferArrowType(t reflect.Type) (arrow.DataType, error) { + for t.Kind() == reflect.Ptr { + t = t.Elem() + } + + if t == typeOfByteSlice { + return arrow.BinaryTypes.Binary, nil + } + + switch t.Kind() { + case reflect.Slice: + elemDT, err := inferArrowType(t.Elem()) + if err != nil { + return nil, err + } + return arrow.ListOf(elemDT), nil + + case reflect.Array: + elemDT, err := inferArrowType(t.Elem()) + if err != nil { + return nil, err + } + return arrow.FixedSizeListOf(int32(t.Len()), elemDT), nil + + case reflect.Map: + keyDT, err := inferArrowType(t.Key()) + if err != nil { + return nil, err + } + valDT, err := inferArrowType(t.Elem()) + if err != nil { + return nil, err + } + return 
arrow.MapOf(keyDT, valDT), nil + + case reflect.Struct: + return inferStructType(t) + + default: + return inferPrimitiveArrowType(t) + } +} + +func applyDecimalOpts(dt arrow.DataType, origType reflect.Type, opts tagOpts) arrow.DataType { + if !opts.HasDecimalOpts { + return dt + } + prec, scale := opts.DecimalPrecision, opts.DecimalScale + switch origType { + case typeOfDec128: + return &arrow.Decimal128Type{Precision: prec, Scale: scale} + case typeOfDec256: + return &arrow.Decimal256Type{Precision: prec, Scale: scale} + case typeOfDec32: + return &arrow.Decimal32Type{Precision: prec, Scale: scale} + case typeOfDec64: + return &arrow.Decimal64Type{Precision: prec, Scale: scale} + } + return dt +} + +func applyTemporalOpts(dt arrow.DataType, origType reflect.Type, opts tagOpts) arrow.DataType { + if origType != typeOfTime || opts.Temporal == "" || opts.Temporal == "timestamp" { + return dt + } + switch opts.Temporal { + case "date32": + return arrow.FixedWidthTypes.Date32 + case "date64": + return arrow.FixedWidthTypes.Date64 + case "time32": + return &arrow.Time32Type{Unit: arrow.Millisecond} + case "time64": + return &arrow.Time64Type{Unit: arrow.Nanosecond} + } + return dt +} + +func applyEncodingOpts(dt arrow.DataType, fm fieldMeta) (arrow.DataType, error) { + switch { + case fm.Opts.Dict: + if err := validateDictValueType(dt); err != nil { + return nil, fmt.Errorf("arreflect: dict tag on field %q: %w", fm.Name, err) + } + return &arrow.DictionaryType{IndexType: arrow.PrimitiveTypes.Int32, ValueType: dt}, nil + case fm.Opts.ListView: + lt, ok := dt.(*arrow.ListType) + if !ok { + return nil, fmt.Errorf("arreflect: listview tag on field %q requires a slice type, got %v", fm.Name, dt) + } + return arrow.ListViewOf(lt.Elem()), nil + case fm.Opts.REE: + return nil, fmt.Errorf("arreflect: ree tag on struct field %q is not supported; use ree at top-level via FromSlice", fm.Name) + } + return dt, nil +} + +func inferStructType(t reflect.Type) (*arrow.StructType, error) 
{ + for t.Kind() == reflect.Ptr { + t = t.Elem() + } + if t.Kind() != reflect.Struct { + return nil, fmt.Errorf("arreflect: expected struct, got %v", t) + } + + fields := cachedStructFields(t) + arrowFields := make([]arrow.Field, 0, len(fields)) + + for _, fm := range fields { + origType := fm.Type + for origType.Kind() == reflect.Ptr { + origType = origType.Elem() + } + + dt, err := inferArrowType(fm.Type) + if err != nil { + return nil, fmt.Errorf("struct field %q: %w", fm.Name, err) + } + + dt = applyDecimalOpts(dt, origType, fm.Opts) + dt = applyTemporalOpts(dt, origType, fm.Opts) + dt, err = applyEncodingOpts(dt, fm) + if err != nil { + return nil, err + } + + arrowFields = append(arrowFields, arrow.Field{ + Name: fm.Name, + Type: dt, + Nullable: fm.Nullable, + }) + } + + return arrow.StructOf(arrowFields...), nil +} + +// InferSchema infers an *arrow.Schema from a Go struct type T. +// T must be a struct type; returns an error otherwise. +// For column-level Arrow type inspection, use [InferType]. +// Field names come from arrow struct tags or Go field names. +// Pointer fields are marked Nullable=true. +func InferSchema[T any]() (*arrow.Schema, error) { + t := reflect.TypeFor[T]() + for t.Kind() == reflect.Ptr { + t = t.Elem() + } + if t.Kind() != reflect.Struct { + return nil, fmt.Errorf("arreflect: InferSchema requires a struct type T, got %v", t) + } + st, err := inferStructType(t) + if err != nil { + return nil, err + } + fields := make([]arrow.Field, st.NumFields()) + for i := 0; i < st.NumFields(); i++ { + fields[i] = st.Field(i) + } + return arrow.NewSchema(fields, nil), nil +} + +// InferType infers the Arrow DataType for a Go type T. +// For struct types, [InferSchema] is preferred when the result will be used with +// arrow.Record or array.NewRecord; InferType returns an arrow.DataType that would +// require an additional cast to *arrow.StructType. 
+func InferType[T any]() (arrow.DataType, error) { + t := reflect.TypeFor[T]() + return inferArrowType(t) +} + +// InferGoType returns the Go reflect.Type corresponding to the given Arrow DataType. +// For STRUCT types it constructs an anonymous struct type at runtime using +// [reflect.StructOf]; field names are exported (capitalised) with the original +// Arrow field name preserved in an arrow struct tag. Nullable Arrow fields +// (field.Nullable == true) become pointer types (*T). +// For DICTIONARY and RUN_END_ENCODED types it returns the Go type of the +// value/encoded type respectively (dictionaries are resolved transparently). +func InferGoType(dt arrow.DataType) (reflect.Type, error) { + switch dt.ID() { + case arrow.INT8: + return typeOfInt8, nil + case arrow.INT16: + return typeOfInt16, nil + case arrow.INT32: + return typeOfInt32, nil + case arrow.INT64: + return typeOfInt64, nil + case arrow.UINT8: + return typeOfUint8, nil + case arrow.UINT16: + return typeOfUint16, nil + case arrow.UINT32: + return typeOfUint32, nil + case arrow.UINT64: + return typeOfUint64, nil + case arrow.FLOAT32: + return typeOfFloat32, nil + case arrow.FLOAT64: + return typeOfFloat64, nil + case arrow.BOOL: + return typeOfBool, nil + case arrow.STRING, arrow.LARGE_STRING: + return typeOfString, nil + case arrow.BINARY, arrow.LARGE_BINARY: + return typeOfByteSlice, nil + case arrow.TIMESTAMP, arrow.DATE32, arrow.DATE64, arrow.TIME32, arrow.TIME64: + return typeOfTime, nil + case arrow.DURATION: + return typeOfDuration, nil + case arrow.DECIMAL128: + return typeOfDec128, nil + case arrow.DECIMAL256: + return typeOfDec256, nil + case arrow.DECIMAL32: + return typeOfDec32, nil + case arrow.DECIMAL64: + return typeOfDec64, nil + + case arrow.LIST, arrow.LARGE_LIST, arrow.LIST_VIEW, arrow.LARGE_LIST_VIEW: + ll, ok := dt.(listElemTyper) + if !ok { + return nil, fmt.Errorf("unsupported Arrow type for Go inference: %v: %w", dt, ErrUnsupportedType) + } + elemDT := ll.Elem() + elemType, 
err := InferGoType(elemDT) + if err != nil { + return nil, err + } + return reflect.SliceOf(elemType), nil + + case arrow.FIXED_SIZE_LIST: + fsl := dt.(*arrow.FixedSizeListType) + elemType, err := InferGoType(fsl.Elem()) + if err != nil { + return nil, err + } + return reflect.ArrayOf(int(fsl.Len()), elemType), nil + + case arrow.MAP: + mt := dt.(*arrow.MapType) + keyType, err := InferGoType(mt.KeyType()) + if err != nil { + return nil, err + } + if !keyType.Comparable() { + return nil, fmt.Errorf("arreflect: InferGoType: MAP key type %v is not comparable in Go: %w", mt.KeyType(), ErrUnsupportedType) + } + valType, err := InferGoType(mt.ItemField().Type) + if err != nil { + return nil, err + } + return reflect.MapOf(keyType, valType), nil + + case arrow.STRUCT: + return inferGoStructType(dt.(*arrow.StructType)) + + case arrow.DICTIONARY: + return InferGoType(dt.(*arrow.DictionaryType).ValueType) + + case arrow.RUN_END_ENCODED: + return InferGoType(dt.(*arrow.RunEndEncodedType).Encoded()) + + default: + return nil, fmt.Errorf("unsupported Arrow type for Go inference: %v: %w", dt, ErrUnsupportedType) + } +} + +func exportedFieldName(name string, index int) (string, error) { + if len(name) == 0 { + return fmt.Sprintf("Field%d", index), nil + } + runes := []rune(name) + runes[0] = unicode.ToUpper(runes[0]) + for j, r := range runes { + if j == 0 { + if !unicode.IsLetter(r) { + return "", fmt.Errorf("arreflect: InferGoType: field name %q produces invalid Go identifier: %w", name, ErrUnsupportedType) + } + } else if !unicode.IsLetter(r) && !unicode.IsDigit(r) && r != '_' { + return "", fmt.Errorf("arreflect: InferGoType: field name %q produces invalid Go identifier: %w", name, ErrUnsupportedType) + } + } + return string(runes), nil Review Comment: Fixed. `exportedFieldName` now prefixes names starting with non-letter characters (like `_` or digits) with `X` to produce a valid exported Go identifier (e.g. `_id` → `X_id`, `1st` → `X1st`). 
The original Arrow field name is preserved in the struct tag for correct round-trip mapping. ########## arrow/array/arreflect/reflect_arrow_to_go.go: ########## @@ -0,0 +1,430 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package arreflect + +import ( + "fmt" + "reflect" + "time" + + "github.com/apache/arrow-go/v18/arrow" + "github.com/apache/arrow-go/v18/arrow/array" +) + +func assertArray[T any](arr arrow.Array) (*T, error) { + a, ok := any(arr).(*T) + if !ok { + var zero T + return nil, fmt.Errorf("expected *%T, got %T: %w", zero, arr, ErrTypeMismatch) + } + return a, nil +} + +func isIntKind(k reflect.Kind) bool { + return k == reflect.Int || k == reflect.Int8 || k == reflect.Int16 || + k == reflect.Int32 || k == reflect.Int64 +} + +func isUintKind(k reflect.Kind) bool { + return k == reflect.Uint || k == reflect.Uint8 || k == reflect.Uint16 || + k == reflect.Uint32 || k == reflect.Uint64 || k == reflect.Uintptr +} + +func isFloatKind(k reflect.Kind) bool { return k == reflect.Float32 || k == reflect.Float64 } + +func setValue(v reflect.Value, arr arrow.Array, i int) error { + if arr.IsNull(i) { + v.Set(reflect.Zero(v.Type())) + return nil + } + if v.Kind() == reflect.Ptr { + 
v.Set(reflect.New(v.Type().Elem())) Review Comment: Fixed. `setValue` now loops `for v.Kind() == reflect.Ptr` allocating each level, correctly handling `**T` and deeper pointer chains. ########## arrow/array/arreflect/reflect_arrow_to_go.go: ########## @@ -0,0 +1,430 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +package arreflect + +import ( + "fmt" + "reflect" + "time" + + "github.com/apache/arrow-go/v18/arrow" + "github.com/apache/arrow-go/v18/arrow/array" +) + +func assertArray[T any](arr arrow.Array) (*T, error) { + a, ok := any(arr).(*T) + if !ok { + var zero T + return nil, fmt.Errorf("expected *%T, got %T: %w", zero, arr, ErrTypeMismatch) + } + return a, nil +} + +func isIntKind(k reflect.Kind) bool { + return k == reflect.Int || k == reflect.Int8 || k == reflect.Int16 || + k == reflect.Int32 || k == reflect.Int64 +} + +func isUintKind(k reflect.Kind) bool { + return k == reflect.Uint || k == reflect.Uint8 || k == reflect.Uint16 || + k == reflect.Uint32 || k == reflect.Uint64 || k == reflect.Uintptr +} + +func isFloatKind(k reflect.Kind) bool { return k == reflect.Float32 || k == reflect.Float64 } + +func setValue(v reflect.Value, arr arrow.Array, i int) error { + if arr.IsNull(i) { + v.Set(reflect.Zero(v.Type())) + return nil + } + if v.Kind() == reflect.Ptr { + v.Set(reflect.New(v.Type().Elem())) + v = v.Elem() + } + + switch arr.DataType().ID() { + case arrow.BOOL: + a, err := assertArray[array.Boolean](arr) + if err != nil { + return err + } + if v.Kind() != reflect.Bool { + return fmt.Errorf("cannot set bool into %s: %w", v.Type(), ErrTypeMismatch) + } + v.SetBool(a.Value(i)) + + case arrow.INT8, arrow.INT16, arrow.INT32, arrow.INT64, + arrow.UINT8, arrow.UINT16, arrow.UINT32, arrow.UINT64, + arrow.FLOAT32, arrow.FLOAT64: + return setPrimitiveValue(v, arr, i) + + case arrow.STRING, arrow.LARGE_STRING: + type stringer interface{ Value(int) string } + a, ok := arr.(stringer) + if !ok { + return fmt.Errorf("expected string array, got %T: %w", arr, ErrTypeMismatch) + } + if v.Kind() != reflect.String { + return fmt.Errorf("cannot set string into %s: %w", v.Type(), ErrTypeMismatch) + } + v.SetString(a.Value(i)) + + case arrow.BINARY, arrow.LARGE_BINARY: + type byter interface{ Value(int) []byte } + a, ok := arr.(byter) + if !ok { + return fmt.Errorf("expected 
binary array, got %T: %w", arr, ErrTypeMismatch) + } + if v.Kind() != reflect.Slice || v.Type().Elem().Kind() != reflect.Uint8 { + return fmt.Errorf("cannot set []byte into %s: %w", v.Type(), ErrTypeMismatch) + } + v.SetBytes(a.Value(i)) + Review Comment: Fixed. Strings are now copied via `strings.Clone` and `[]byte` values are copied into a fresh slice. This ensures the returned Go values own their data and remain valid after the Arrow array is released, regardless of allocator. ########## arrow/array/arreflect/reflect_go_to_arrow.go: ########## @@ -0,0 +1,781 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +package arreflect + +import ( + "fmt" + "reflect" + "time" + + "github.com/apache/arrow-go/v18/arrow" + "github.com/apache/arrow-go/v18/arrow/array" + "github.com/apache/arrow-go/v18/arrow/decimal" + "github.com/apache/arrow-go/v18/arrow/decimal128" + "github.com/apache/arrow-go/v18/arrow/decimal256" + "github.com/apache/arrow-go/v18/arrow/memory" +) + +func buildArray(vals reflect.Value, opts tagOpts, mem memory.Allocator) (arrow.Array, error) { + if vals.Kind() != reflect.Slice { + return nil, fmt.Errorf("arreflect: expected slice, got %v", vals.Kind()) + } + + elemType := vals.Type().Elem() + for elemType.Kind() == reflect.Ptr { + elemType = elemType.Elem() + } + + if opts.Dict { + return buildDictionaryArray(vals, mem) + } + if opts.REE { + return buildRunEndEncodedArray(vals, opts, mem) + } + if opts.ListView { + if elemType.Kind() != reflect.Slice || elemType == typeOfByteSlice { + return nil, fmt.Errorf("arreflect: WithListView requires a slice-of-slices element type, got %s: %w", elemType, ErrUnsupportedType) + } + return buildListViewArray(vals, mem) + } + + switch elemType { + case typeOfDec32, typeOfDec64, typeOfDec128, typeOfDec256: + return buildDecimalArray(vals, opts, mem) + } + + switch elemType.Kind() { + case reflect.Slice: + if elemType == typeOfByteSlice { + return buildPrimitiveArray(vals, mem) + } + return buildListArray(vals, mem) + + case reflect.Array: + return buildFixedSizeListArray(vals, mem) + + case reflect.Map: + return buildMapArray(vals, mem) + + case reflect.Struct: + switch elemType { + case typeOfTime: + return buildTemporalArray(vals, opts, mem) + default: + return buildStructArray(vals, mem) + } + + default: + return buildPrimitiveArray(vals, mem) + } +} + +func buildPrimitiveArray(vals reflect.Value, mem memory.Allocator) (arrow.Array, error) { + elemType, isPtr := derefSliceElem(vals) + + dt, err := inferArrowType(elemType) + if err != nil { + return nil, err + } + + b := array.NewBuilder(mem, dt) + defer b.Release() + 
b.Reserve(vals.Len()) + + if err := iterSlice(vals, isPtr, b.AppendNull, func(v reflect.Value) error { + return appendValue(b, v) + }); err != nil { + return nil, err + } + return b.NewArray(), nil +} + +func timeOfDayNanos(t time.Time) int64 { + t = t.UTC() + midnight := time.Date(t.Year(), t.Month(), t.Day(), 0, 0, 0, 0, time.UTC) + return t.Sub(midnight).Nanoseconds() +} + +func asTime(v reflect.Value) (time.Time, error) { + t, ok := reflect.TypeAssert[time.Time](v) + if !ok { + return time.Time{}, fmt.Errorf("expected time.Time, got %s: %w", v.Type(), ErrTypeMismatch) + } + return t, nil +} + +func asDuration(v reflect.Value) (time.Duration, error) { + d, ok := reflect.TypeAssert[time.Duration](v) + if !ok { + return 0, fmt.Errorf("expected time.Duration, got %s: %w", v.Type(), ErrTypeMismatch) + } + return d, nil +} + +func derefSliceElem(vals reflect.Value) (elemType reflect.Type, isPtr bool) { + elemType = vals.Type().Elem() + isPtr = elemType.Kind() == reflect.Ptr + for elemType.Kind() == reflect.Ptr { + elemType = elemType.Elem() + } + return +} + +func iterSlice(vals reflect.Value, isPtr bool, appendNull func(), appendVal func(reflect.Value) error) error { + for i := 0; i < vals.Len(); i++ { + v := vals.Index(i) + if isPtr { + wasNull := false + for v.Kind() == reflect.Ptr { + if v.IsNil() { + appendNull() + wasNull = true + break + } + v = v.Elem() + } + if wasNull { + continue + } + } + if err := appendVal(v); err != nil { + return err + } + } + return nil +} + +func inferListElemDT(vals reflect.Value) (elemDT arrow.DataType, err error) { + outerSliceType, _ := derefSliceElem(vals) + innerElemType := outerSliceType.Elem() + for innerElemType.Kind() == reflect.Ptr { + innerElemType = innerElemType.Elem() + } + elemDT, err = inferArrowType(innerElemType) + return +} + +func temporalBuilder(opts tagOpts, mem memory.Allocator) array.Builder { + switch opts.Temporal { + case "date32": + return array.NewDate32Builder(mem) + case "date64": + return 
array.NewDate64Builder(mem) + case "time32": + return array.NewTime32Builder(mem, &arrow.Time32Type{Unit: arrow.Millisecond}) + case "time64": + return array.NewTime64Builder(mem, &arrow.Time64Type{Unit: arrow.Nanosecond}) + default: + return array.NewTimestampBuilder(mem, &arrow.TimestampType{Unit: arrow.Nanosecond, TimeZone: "UTC"}) + } +} + +func buildTemporalArray(vals reflect.Value, opts tagOpts, mem memory.Allocator) (arrow.Array, error) { + elemType, isPtr := derefSliceElem(vals) + if elemType != typeOfTime { + return nil, fmt.Errorf("unsupported temporal type %v: %w", elemType, ErrUnsupportedType) + } + b := temporalBuilder(opts, mem) + defer b.Release() + b.Reserve(vals.Len()) + if err := iterSlice(vals, isPtr, b.AppendNull, func(v reflect.Value) error { + return appendTemporalValue(b, v) + }); err != nil { + return nil, err + } + return b.NewArray(), nil +} + +func decimalPrecisionScale(opts tagOpts, defaultPrec int32) (precision, scale int32) { + if opts.HasDecimalOpts { + return opts.DecimalPrecision, opts.DecimalScale + } + return defaultPrec, 0 +} + +func buildDecimalArray(vals reflect.Value, opts tagOpts, mem memory.Allocator) (arrow.Array, error) { + elemType, isPtr := derefSliceElem(vals) + + var b array.Builder + switch elemType { + case typeOfDec128: + p, s := decimalPrecisionScale(opts, dec128DefaultPrecision) + b = array.NewDecimal128Builder(mem, &arrow.Decimal128Type{Precision: p, Scale: s}) + case typeOfDec256: + p, s := decimalPrecisionScale(opts, dec256DefaultPrecision) + b = array.NewDecimal256Builder(mem, &arrow.Decimal256Type{Precision: p, Scale: s}) + case typeOfDec32: + p, s := decimalPrecisionScale(opts, dec32DefaultPrecision) + b = array.NewDecimal32Builder(mem, &arrow.Decimal32Type{Precision: p, Scale: s}) + case typeOfDec64: + p, s := decimalPrecisionScale(opts, dec64DefaultPrecision) + b = array.NewDecimal64Builder(mem, &arrow.Decimal64Type{Precision: p, Scale: s}) + default: + return nil, fmt.Errorf("unsupported decimal type %v: 
%w", elemType, ErrUnsupportedType) + } + defer b.Release() + b.Reserve(vals.Len()) + if err := iterSlice(vals, isPtr, b.AppendNull, func(v reflect.Value) error { + return appendDecimalValue(b, v) + }); err != nil { + return nil, err + } + return b.NewArray(), nil +} + +func appendStructFields(sb *array.StructBuilder, v reflect.Value, fields []fieldMeta) error { + sb.Append(true) + for fi, fm := range fields { + if err := appendValue(sb.FieldBuilder(fi), v.FieldByIndex(fm.Index)); err != nil { + return fmt.Errorf("struct field %q: %w", fm.Name, err) + } + } + return nil Review Comment: Fixed. Both `appendStructFields` and `setStructValue` now use a `fieldByIndexSafe` helper that walks the index path with nil checks at each pointer level. If an embedded pointer is nil, the Go→Arrow path appends null and the Arrow→Go path leaves the field at its zero value. ########## arrow/array/arreflect/reflect_arrow_to_go.go: ########## @@ -0,0 +1,430 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +package arreflect + +import ( + "fmt" + "reflect" + "time" + + "github.com/apache/arrow-go/v18/arrow" + "github.com/apache/arrow-go/v18/arrow/array" +) + +func assertArray[T any](arr arrow.Array) (*T, error) { + a, ok := any(arr).(*T) + if !ok { + var zero T + return nil, fmt.Errorf("expected *%T, got %T: %w", zero, arr, ErrTypeMismatch) + } + return a, nil +} + +func isIntKind(k reflect.Kind) bool { + return k == reflect.Int || k == reflect.Int8 || k == reflect.Int16 || + k == reflect.Int32 || k == reflect.Int64 +} + +func isUintKind(k reflect.Kind) bool { + return k == reflect.Uint || k == reflect.Uint8 || k == reflect.Uint16 || + k == reflect.Uint32 || k == reflect.Uint64 || k == reflect.Uintptr +} + +func isFloatKind(k reflect.Kind) bool { return k == reflect.Float32 || k == reflect.Float64 } + +func setValue(v reflect.Value, arr arrow.Array, i int) error { + if arr.IsNull(i) { + v.Set(reflect.Zero(v.Type())) + return nil + } + if v.Kind() == reflect.Ptr { + v.Set(reflect.New(v.Type().Elem())) + v = v.Elem() + } + + switch arr.DataType().ID() { + case arrow.BOOL: + a, err := assertArray[array.Boolean](arr) + if err != nil { + return err + } + if v.Kind() != reflect.Bool { + return fmt.Errorf("cannot set bool into %s: %w", v.Type(), ErrTypeMismatch) + } + v.SetBool(a.Value(i)) + + case arrow.INT8, arrow.INT16, arrow.INT32, arrow.INT64, + arrow.UINT8, arrow.UINT16, arrow.UINT32, arrow.UINT64, + arrow.FLOAT32, arrow.FLOAT64: + return setPrimitiveValue(v, arr, i) + + case arrow.STRING, arrow.LARGE_STRING: + type stringer interface{ Value(int) string } + a, ok := arr.(stringer) + if !ok { + return fmt.Errorf("expected string array, got %T: %w", arr, ErrTypeMismatch) + } + if v.Kind() != reflect.String { + return fmt.Errorf("cannot set string into %s: %w", v.Type(), ErrTypeMismatch) + } + v.SetString(a.Value(i)) + + case arrow.BINARY, arrow.LARGE_BINARY: + type byter interface{ Value(int) []byte } + a, ok := arr.(byter) + if !ok { + return fmt.Errorf("expected 
binary array, got %T: %w", arr, ErrTypeMismatch) + } + if v.Kind() != reflect.Slice || v.Type().Elem().Kind() != reflect.Uint8 { + return fmt.Errorf("cannot set []byte into %s: %w", v.Type(), ErrTypeMismatch) + } + v.SetBytes(a.Value(i)) + + case arrow.TIMESTAMP, arrow.DATE32, arrow.DATE64, + arrow.TIME32, arrow.TIME64, arrow.DURATION: + return setTemporalValue(v, arr, i) + + case arrow.DECIMAL128, arrow.DECIMAL256, arrow.DECIMAL32, arrow.DECIMAL64: + return setDecimalValue(v, arr, i) + + case arrow.STRUCT: + a, err := assertArray[array.Struct](arr) + if err != nil { + return err + } + return setStructValue(v, a, i) + + case arrow.LIST, arrow.LARGE_LIST, arrow.LIST_VIEW, arrow.LARGE_LIST_VIEW: + a, ok := arr.(array.ListLike) + if !ok { + return fmt.Errorf("expected ListLike, got %T: %w", arr, ErrTypeMismatch) + } + return setListValue(v, a, i) + + case arrow.MAP: + a, err := assertArray[array.Map](arr) + if err != nil { + return err + } + return setMapValue(v, a, i) + + case arrow.FIXED_SIZE_LIST: + a, err := assertArray[array.FixedSizeList](arr) + if err != nil { + return err + } + return setFixedSizeListValue(v, a, i) + + case arrow.DICTIONARY: + a, err := assertArray[array.Dictionary](arr) + if err != nil { + return err + } + return setDictionaryValue(v, a, i) + + case arrow.RUN_END_ENCODED: + a, err := assertArray[array.RunEndEncoded](arr) + if err != nil { + return err + } + return setRunEndEncodedValue(v, a, i) + + default: + return fmt.Errorf("unsupported Arrow type %v for reflection: %w", arr.DataType(), ErrUnsupportedType) + } + return nil +} + +func setPrimitiveValue(v reflect.Value, arr arrow.Array, i int) error { + switch arr.DataType().ID() { + case arrow.INT8: + if !isIntKind(v.Kind()) { + return fmt.Errorf("cannot set int8 into %s: %w", v.Type(), ErrTypeMismatch) + } + v.SetInt(int64(arr.(*array.Int8).Value(i))) + case arrow.INT16: + if !isIntKind(v.Kind()) { + return fmt.Errorf("cannot set int16 into %s: %w", v.Type(), ErrTypeMismatch) + } + 
v.SetInt(int64(arr.(*array.Int16).Value(i))) + case arrow.INT32: + if !isIntKind(v.Kind()) { + return fmt.Errorf("cannot set int32 into %s: %w", v.Type(), ErrTypeMismatch) + } + v.SetInt(int64(arr.(*array.Int32).Value(i))) + case arrow.INT64: + if !isIntKind(v.Kind()) { + return fmt.Errorf("cannot set int64 into %s: %w", v.Type(), ErrTypeMismatch) + } + v.SetInt(arr.(*array.Int64).Value(i)) + case arrow.UINT8: + if !isUintKind(v.Kind()) { + return fmt.Errorf("cannot set uint8 into %s: %w", v.Type(), ErrTypeMismatch) + } + v.SetUint(uint64(arr.(*array.Uint8).Value(i))) + case arrow.UINT16: + if !isUintKind(v.Kind()) { + return fmt.Errorf("cannot set uint16 into %s: %w", v.Type(), ErrTypeMismatch) + } + v.SetUint(uint64(arr.(*array.Uint16).Value(i))) + case arrow.UINT32: + if !isUintKind(v.Kind()) { + return fmt.Errorf("cannot set uint32 into %s: %w", v.Type(), ErrTypeMismatch) + } + v.SetUint(uint64(arr.(*array.Uint32).Value(i))) + case arrow.UINT64: + if !isUintKind(v.Kind()) { + return fmt.Errorf("cannot set uint64 into %s: %w", v.Type(), ErrTypeMismatch) + } + v.SetUint(arr.(*array.Uint64).Value(i)) + case arrow.FLOAT32: + if !isFloatKind(v.Kind()) { + return fmt.Errorf("cannot set float32 into %s: %w", v.Type(), ErrTypeMismatch) + } + v.SetFloat(float64(arr.(*array.Float32).Value(i))) + case arrow.FLOAT64: + if !isFloatKind(v.Kind()) { + return fmt.Errorf("cannot set float64 into %s: %w", v.Type(), ErrTypeMismatch) + } + v.SetFloat(arr.(*array.Float64).Value(i)) + default: + return fmt.Errorf("unsupported primitive type %v: %w", arr.DataType(), ErrUnsupportedType) + } + return nil +} + +func setTime(v reflect.Value, t time.Time) error { + if v.Type() != typeOfTime { + return fmt.Errorf("cannot set time.Time into %s: %w", v.Type(), ErrTypeMismatch) + } + v.Set(reflect.ValueOf(t)) + return nil +} + +func setTemporalValue(v reflect.Value, arr arrow.Array, i int) error { + switch arr.DataType().ID() { + case arrow.TIMESTAMP: + a, err := 
assertArray[array.Timestamp](arr) + if err != nil { + return err + } + unit := arr.DataType().(*arrow.TimestampType).Unit + return setTime(v, a.Value(i).ToTime(unit)) + + case arrow.DATE32: + a, err := assertArray[array.Date32](arr) + if err != nil { + return err + } + return setTime(v, a.Value(i).ToTime()) + + case arrow.DATE64: + a, err := assertArray[array.Date64](arr) + if err != nil { + return err + } + return setTime(v, a.Value(i).ToTime()) + + case arrow.TIME32: + a, err := assertArray[array.Time32](arr) + if err != nil { + return err + } + unit := arr.DataType().(*arrow.Time32Type).Unit + return setTime(v, a.Value(i).ToTime(unit)) + + case arrow.TIME64: + a, err := assertArray[array.Time64](arr) + if err != nil { + return err + } + unit := arr.DataType().(*arrow.Time64Type).Unit + return setTime(v, a.Value(i).ToTime(unit)) + + case arrow.DURATION: + a, err := assertArray[array.Duration](arr) + if err != nil { + return err + } + if v.Type() != typeOfDuration { + return fmt.Errorf("cannot set time.Duration into %s: %w", v.Type(), ErrTypeMismatch) + } + unit := arr.DataType().(*arrow.DurationType).Unit + dur := time.Duration(a.Value(i)) * unit.Multiplier() + v.Set(reflect.ValueOf(dur)) + + default: + return fmt.Errorf("unsupported temporal type %v: %w", arr.DataType(), ErrUnsupportedType) + } + return nil +} + +func setDecimalValue(v reflect.Value, arr arrow.Array, i int) error { + switch arr.DataType().ID() { + case arrow.DECIMAL128: + a, err := assertArray[array.Decimal128](arr) + if err != nil { + return err + } + if v.Type() != typeOfDec128 { + return fmt.Errorf("cannot set decimal128.Num into %s: %w", v.Type(), ErrTypeMismatch) + } + num := a.Value(i) + v.Set(reflect.ValueOf(num)) + + case arrow.DECIMAL256: + a, err := assertArray[array.Decimal256](arr) + if err != nil { + return err + } + if v.Type() != typeOfDec256 { + return fmt.Errorf("cannot set decimal256.Num into %s: %w", v.Type(), ErrTypeMismatch) + } + num := a.Value(i) + 
v.Set(reflect.ValueOf(num)) + + case arrow.DECIMAL32: + a, err := assertArray[array.Decimal32](arr) + if err != nil { + return err + } + if v.Type() != typeOfDec32 { + return fmt.Errorf("cannot set decimal.Decimal32 into %s: %w", v.Type(), ErrTypeMismatch) + } + v.Set(reflect.ValueOf(a.Value(i))) + + case arrow.DECIMAL64: + a, err := assertArray[array.Decimal64](arr) + if err != nil { + return err + } + if v.Type() != typeOfDec64 { + return fmt.Errorf("cannot set decimal.Decimal64 into %s: %w", v.Type(), ErrTypeMismatch) + } + v.Set(reflect.ValueOf(a.Value(i))) + + default: + return fmt.Errorf("unsupported decimal type %v: %w", arr.DataType(), ErrUnsupportedType) + } + return nil +} + +func setStructValue(v reflect.Value, sa *array.Struct, i int) error { + if v.Kind() != reflect.Struct { + return fmt.Errorf("cannot set struct into %s: %w", v.Type(), ErrTypeMismatch) + } + + fields := cachedStructFields(v.Type()) + st := sa.DataType().(*arrow.StructType) + + for _, fm := range fields { + arrowIdx, found := st.FieldIdx(fm.Name) + if !found { + continue + } + if err := setValue(v.FieldByIndex(fm.Index), sa.Field(arrowIdx), i); err != nil { + return fmt.Errorf("arreflect: field %q: %w", fm.Name, err) + } + } Review Comment: Fixed alongside the `appendStructFields` issue. Introduced `fieldByIndexSafe` which walks the index path manually with nil checks at each pointer dereference, preventing the panic from nil embedded pointers. -- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. To unsubscribe, e-mail: github-unsubscribe@arrow.apache.org For queries about this service, please contact Infrastructure at: users@infra.apache.org
