lidavidm commented on code in PR #771:
URL: https://github.com/apache/arrow-go/pull/771#discussion_r3120926995


##########
arrow/array/arreflect/reflect.go:
##########
@@ -0,0 +1,586 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package arreflect
+
+import (
+       "errors"
+       "fmt"
+       "reflect"
+       "sort"
+       "strconv"
+       "strings"
+       "sync"
+
+       "github.com/apache/arrow-go/v18/arrow"
+       "github.com/apache/arrow-go/v18/arrow/array"
+       "github.com/apache/arrow-go/v18/arrow/memory"
+)
+
+var (
+       ErrUnsupportedType = errors.New("arreflect: unsupported type")
+       ErrTypeMismatch    = errors.New("arreflect: type mismatch")
+)
+
+type tagOpts struct {
+       Name             string
+       Skip             bool
+       Dict             bool
+       ListView         bool
+       REE              bool
+       DecimalPrecision int32
+       DecimalScale     int32
+       HasDecimalOpts   bool
+       Temporal         string // "timestamp" (default), "date32", "date64", "time32", "time64"
+       DecimalParseErr  string // diagnostic set when decimal(p,s) tag fails to parse; surfaced by validateOptions
+}
+
+type fieldMeta struct {
+       Name     string
+       Index    []int
+       Type     reflect.Type
+       Nullable bool
+       Opts     tagOpts
+}
+
+func parseTag(tag string) tagOpts {
+       if tag == "-" {
+               return tagOpts{Skip: true}
+       }
+
+       var name, rest string
+       if idx := strings.Index(tag, ","); idx >= 0 {
+               name = tag[:idx]
+               rest = tag[idx+1:]
+       } else {
+               name = tag
+               rest = ""
+       }
+
+       opts := tagOpts{Name: name}
+
+       if rest == "" {
+               return opts
+       }
+
+       parseOptions(&opts, rest)
+       return opts
+}
+
+func splitTagTokens(rest string) []string {
+       var tokens []string
+       depth := 0
+       start := 0
+       for i := 0; i < len(rest); i++ {
+               switch rest[i] {
+               case '(':
+                       depth++
+               case ')':
+                       depth--
+               case ',':
+                       if depth == 0 {
+                               tokens = append(tokens, strings.TrimSpace(rest[start:i]))
+                               start = i + 1
+                       }
+               }
+       }
+       if start < len(rest) {
+               tokens = append(tokens, strings.TrimSpace(rest[start:]))
+       }
+       return tokens
+}
+
+func parseOptions(opts *tagOpts, rest string) {
+       for _, token := range splitTagTokens(rest) {
+               if strings.HasPrefix(token, "decimal(") && strings.HasSuffix(token, ")") {
+                       parseDecimalOpt(opts, token)
+                       continue
+               }
+               switch token {
+               case "dict":
+                       opts.Dict = true
+               case "listview":
+                       opts.ListView = true
+               case "ree":
+                       opts.REE = true
+               case "date32", "date64", "time32", "time64", "timestamp":
+                       opts.Temporal = token
+               }

Review Comment:
   Should we return an error on unknown tokens? DecimalParseErr could be generalized into a tag-parse error field that reports them too.



##########
arrow/array/arreflect/reflect.go:
##########
@@ -0,0 +1,586 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package arreflect
+
+import (
+       "errors"
+       "fmt"
+       "reflect"
+       "sort"
+       "strconv"
+       "strings"
+       "sync"
+
+       "github.com/apache/arrow-go/v18/arrow"
+       "github.com/apache/arrow-go/v18/arrow/array"
+       "github.com/apache/arrow-go/v18/arrow/memory"
+)
+
+var (
+       ErrUnsupportedType = errors.New("arreflect: unsupported type")
+       ErrTypeMismatch    = errors.New("arreflect: type mismatch")
+)
+
+type tagOpts struct {
+       Name             string
+       Skip             bool
+       Dict             bool
+       ListView         bool
+       REE              bool
+       DecimalPrecision int32
+       DecimalScale     int32
+       HasDecimalOpts   bool
+       Temporal         string // "timestamp" (default), "date32", "date64", "time32", "time64"
+       DecimalParseErr  string // diagnostic set when decimal(p,s) tag fails to parse; surfaced by validateOptions
+}
+
+type fieldMeta struct {
+       Name     string
+       Index    []int
+       Type     reflect.Type
+       Nullable bool
+       Opts     tagOpts
+}
+
+func parseTag(tag string) tagOpts {
+       if tag == "-" {
+               return tagOpts{Skip: true}
+       }
+
+       var name, rest string
+       if idx := strings.Index(tag, ","); idx >= 0 {
+               name = tag[:idx]
+               rest = tag[idx+1:]
+       } else {
+               name = tag
+               rest = ""
+       }
+
+       opts := tagOpts{Name: name}
+
+       if rest == "" {
+               return opts
+       }
+
+       parseOptions(&opts, rest)
+       return opts
+}
+
+func splitTagTokens(rest string) []string {
+       var tokens []string
+       depth := 0
+       start := 0
+       for i := 0; i < len(rest); i++ {
+               switch rest[i] {
+               case '(':
+                       depth++
+               case ')':
+                       depth--
+               case ',':
+                       if depth == 0 {
+                               tokens = append(tokens, strings.TrimSpace(rest[start:i]))
+                               start = i + 1
+                       }
+               }
+       }
+       if start < len(rest) {
+               tokens = append(tokens, strings.TrimSpace(rest[start:]))
+       }
+       return tokens
+}
+
+func parseOptions(opts *tagOpts, rest string) {
+       for _, token := range splitTagTokens(rest) {
+               if strings.HasPrefix(token, "decimal(") && strings.HasSuffix(token, ")") {
+                       parseDecimalOpt(opts, token)
+                       continue
+               }
+               switch token {
+               case "dict":
+                       opts.Dict = true
+               case "listview":
+                       opts.ListView = true
+               case "ree":
+                       opts.REE = true
+               case "date32", "date64", "time32", "time64", "timestamp":
+                       opts.Temporal = token
+               }
+       }
+}
+
+func parseDecimalOpt(opts *tagOpts, token string) {
+       inner := strings.TrimPrefix(token, "decimal(")
+       inner = strings.TrimSuffix(inner, ")")
+       parts := strings.SplitN(inner, ",", 2)
+       if len(parts) != 2 {
+               opts.DecimalParseErr = fmt.Sprintf("invalid decimal tag %q: expected decimal(precision,scale)", token)
+               return
+       }
+       p, errP := strconv.ParseInt(strings.TrimSpace(parts[0]), 10, 32)
+       if errP != nil {
+               opts.DecimalParseErr = fmt.Sprintf("invalid decimal tag %q: precision %q is not an integer", token, strings.TrimSpace(parts[0]))
+               return
+       }
+       s, errS := strconv.ParseInt(strings.TrimSpace(parts[1]), 10, 32)
+       if errS != nil {
+               opts.DecimalParseErr = fmt.Sprintf("invalid decimal tag %q: scale %q is not an integer", token, strings.TrimSpace(parts[1]))
+               return
+       }
+       opts.HasDecimalOpts = true
+       opts.DecimalPrecision = int32(p)
+       opts.DecimalScale = int32(s)
+}
+
+type bfsEntry struct {
+       t     reflect.Type
+       index []int
+       depth int
+}
+
+type candidate struct {
+       meta   fieldMeta
+       depth  int
+       tagged bool
+       order  int
+}
+
+type resolvedField struct {
+       meta  fieldMeta
+       order int
+}
+
+func collectFieldCandidates(t reflect.Type) map[string][]candidate {
+       nameMap := make(map[string][]candidate)
+       orderCounter := 0
+
+       queue := []bfsEntry{{t: t, index: nil, depth: 0}}
+       visited := make(map[reflect.Type]bool)
+
+       for len(queue) > 0 {
+               entry := queue[0]
+               queue = queue[1:]
+
+               st := entry.t
+               for st.Kind() == reflect.Ptr {
+                       st = st.Elem()
+               }
+               if st.Kind() != reflect.Struct {
+                       continue
+               }
+
+               if visited[st] {
+                       continue
+               }
+               if entry.depth > 0 {
+                       visited[st] = true
+               }
+
+               for i := 0; i < st.NumField(); i++ {
+                       sf := st.Field(i)
+
+                       fullIndex := make([]int, len(entry.index)+1)
+                       copy(fullIndex, entry.index)
+                       fullIndex[len(entry.index)] = i
+
+                       if !sf.IsExported() && !sf.Anonymous {
+                               continue
+                       }
+
+                       tagVal, hasTag := sf.Tag.Lookup("arrow")
+                       var opts tagOpts
+                       if hasTag {
+                               opts = parseTag(tagVal)
+                       }
+
+                       if opts.Skip {
+                               continue
+                       }
+
+                       arrowName := opts.Name
+                       if arrowName == "" {
+                               arrowName = sf.Name
+                       }
+
+                       if sf.Anonymous && !hasTag {
+                               ft := sf.Type
+                               for ft.Kind() == reflect.Ptr {
+                                       ft = ft.Elem()
+                               }
+                               if ft.Kind() == reflect.Struct {
+                                       queue = append(queue, bfsEntry{
+                                               t:     ft,
+                                               index: fullIndex,
+                                               depth: entry.depth + 1,
+                                       })
+                                       continue
+                               }
+                       }
+
+                       nullable := sf.Type.Kind() == reflect.Ptr
+                       tagged := hasTag && opts.Name != ""
+
+                       meta := fieldMeta{
+                               Name:     arrowName,
+                               Index:    fullIndex,
+                               Type:     sf.Type,
+                               Nullable: nullable,
+                               Opts:     opts,
+                       }
+
+                       existingCands := nameMap[arrowName]
+                       order := orderCounter
+                       if len(existingCands) > 0 {
+                               order = existingCands[0].order
+                       } else {
+                               orderCounter++
+                       }
+
+                       nameMap[arrowName] = append(existingCands, candidate{
+                               meta:   meta,
+                               depth:  entry.depth,
+                               tagged: tagged,
+                               order:  order,
+                       })
+               }
+       }
+
+       return nameMap
+}
+
+func resolveFieldCandidates(nameMap map[string][]candidate) []fieldMeta {
+       resolved := make([]resolvedField, 0, len(nameMap))
+       for _, candidates := range nameMap {
+               minDepth := candidates[0].depth
+               for _, c := range candidates[1:] {
+                       if c.depth < minDepth {
+                               minDepth = c.depth
+                       }
+               }
+
+               var atMin []candidate
+               for _, c := range candidates {
+                       if c.depth == minDepth {
+                               atMin = append(atMin, c)
+                       }
+               }
+
+               var winner *candidate
+               if len(atMin) == 1 {
+                       winner = &atMin[0]
+               } else {
+                       var tagged []candidate
+                       for _, c := range atMin {
+                               if c.tagged {
+                                       tagged = append(tagged, c)
+                               }
+                       }
+                       if len(tagged) == 1 {
+                               winner = &tagged[0]
+                       }
+               }
+
+               if winner != nil {
+                       resolved = append(resolved, resolvedField{meta: winner.meta, order: winner.order})
+               }
+       }
+
+       sort.Slice(resolved, func(i, j int) bool {
+               return resolved[i].order < resolved[j].order
+       })
+
+       result := make([]fieldMeta, len(resolved))
+       for i, r := range resolved {
+               result[i] = r.meta
+       }
+       return result
+}
+
+func getStructFields(t reflect.Type) []fieldMeta {
+       for t.Kind() == reflect.Ptr {
+               t = t.Elem()
+       }
+
+       if t.Kind() != reflect.Struct {
+               return nil
+       }
+
+       return resolveFieldCandidates(collectFieldCandidates(t))
+}
+
+var structFieldCache sync.Map
+
+func cachedStructFields(t reflect.Type) []fieldMeta {
+       for t.Kind() == reflect.Ptr {
+               t = t.Elem()
+       }
+
+       if v, ok := structFieldCache.Load(t); ok {
+               return v.([]fieldMeta)
+       }
+
+       fields := getStructFields(t)
+       v, _ := structFieldCache.LoadOrStore(t, fields)
+       return v.([]fieldMeta)
+}
+
+func fieldByIndexSafe(v reflect.Value, index []int) (reflect.Value, bool) {
+       for _, idx := range index {
+               if v.Kind() == reflect.Ptr {
+                       if v.IsNil() {
+                               return reflect.Value{}, false
+                       }
+                       v = v.Elem()
+               }
+               v = v.Field(idx)
+       }
+       return v, true
+}
+
+func At[T any](arr arrow.Array, i int) (T, error) {
+       var result T
+       v := reflect.ValueOf(&result).Elem()
+       if err := setValue(v, arr, i); err != nil {
+               var zero T
+               return zero, err
+       }
+       return result, nil
+}
+
+func ToSlice[T any](arr arrow.Array) ([]T, error) {
+       n := arr.Len()
+       result := make([]T, n)
+       for i := 0; i < n; i++ {
+               v := reflect.ValueOf(&result[i]).Elem()
+               if err := setValue(v, arr, i); err != nil {
+                       return nil, fmt.Errorf("index %d: %w", i, err)
+               }
+       }
+       return result, nil
+}
+
+// Option configures encoding behavior for [FromSlice] and [RecordFromSlice].
+type Option func(*tagOpts)
+
+// WithDict requests dictionary encoding for the top-level array.
+func WithDict() Option { return func(o *tagOpts) { o.Dict = true } }
+
+// WithListView requests ListView encoding instead of List for slice types.
+func WithListView() Option { return func(o *tagOpts) { o.ListView = true } }

Review Comment:
   Do we need similar options for StringView? Could this instead be a single option to use View types (and likewise one for Large types)?



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

Reply via email to