This is an automated email from the ASF dual-hosted git repository.

zeroshade pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow.git


The following commit(s) were added to refs/heads/master by this push:
     new a4cc441e0f ARROW-17390: [Go] Add union scalar types (#13860)
a4cc441e0f is described below

commit a4cc441e0faf47184847bdd205cea8c2a046a384
Author: Matt Topol <[email protected]>
AuthorDate: Thu Aug 11 16:49:27 2022 -0400

    ARROW-17390: [Go] Add union scalar types (#13860)
    
    Authored-by: Matt Topol <[email protected]>
    Signed-off-by: Matt Topol <[email protected]>
---
 go/arrow/scalar/nested.go      | 192 +++++++++++++++++++++++++++++
 go/arrow/scalar/scalar.go      |  81 +++++++++---
 go/arrow/scalar/scalar_test.go | 273 +++++++++++++++++++++++++++++++++++++++++
 3 files changed, 531 insertions(+), 15 deletions(-)

diff --git a/go/arrow/scalar/nested.go b/go/arrow/scalar/nested.go
index 756e383f5a..2d106e5071 100644
--- a/go/arrow/scalar/nested.go
+++ b/go/arrow/scalar/nested.go
@@ -520,3 +520,195 @@ func (s *Dictionary) GetEncodedValue() (Scalar, error) {
 func (s *Dictionary) value() interface{} {
        return s.Value.Index.value()
 }
+
+type Union interface {
+       Scalar
+       ChildValue() Scalar
+       Release()
+}
+
+type SparseUnion struct {
+       scalar
+
+       TypeCode arrow.UnionTypeCode
+       Value    []Scalar
+       ChildID  int
+}
+
+func (s *SparseUnion) equals(rhs Scalar) bool {
+       right := rhs.(*SparseUnion)
+       return Equals(s.ChildValue(), right.ChildValue())
+}
+
+func (s *SparseUnion) value() interface{} { return s.ChildValue() }
+
+func (s *SparseUnion) String() string {
+       dt := s.Type.(*arrow.SparseUnionType)
+       val := s.ChildValue()
+       return "union{" + dt.Fields()[dt.ChildIDs()[s.TypeCode]].String() + " = " + val.String() + "}"
+}
+
+func (s *SparseUnion) Release() {
+       for _, v := range s.Value {
+               if v, ok := v.(Releasable); ok {
+                       v.Release()
+               }
+       }
+}
+
+func (s *SparseUnion) Validate() (err error) {
+       dt := s.Type.(*arrow.SparseUnionType)
+       if len(dt.Fields()) != len(s.Value) {
+               return fmt.Errorf("sparse union scalar value had %d fields but type has %d fields", len(dt.Fields()), len(s.Value))
+       }
+
+       if s.TypeCode < 0 || int(s.TypeCode) >= len(dt.ChildIDs()) || dt.ChildIDs()[s.TypeCode] == arrow.InvalidUnionChildID {
+               return fmt.Errorf("%s scalar has invalid type code %d", dt, s.TypeCode)
+       }
+
+       for i, f := range dt.Fields() {
+               v := s.Value[i]
+               if !arrow.TypeEqual(f.Type, v.DataType()) {
+                       return fmt.Errorf("%s value for field %s had incorrect type of %s", dt, f, v.DataType())
+               }
+               if err = v.Validate(); err != nil {
+                       return err
+               }
+       }
+       return
+}
+
+func (s *SparseUnion) ValidateFull() (err error) {
+       dt := s.Type.(*arrow.SparseUnionType)
+       if len(dt.Fields()) != len(s.Value) {
+               return fmt.Errorf("sparse union scalar value had %d fields but type has %d fields", len(dt.Fields()), len(s.Value))
+       }
+
+       if s.TypeCode < 0 || int(s.TypeCode) >= len(dt.ChildIDs()) || dt.ChildIDs()[s.TypeCode] == arrow.InvalidUnionChildID {
+               return fmt.Errorf("%s scalar has invalid type code %d", dt, s.TypeCode)
+       }
+
+       for i, f := range dt.Fields() {
+               v := s.Value[i]
+               if !arrow.TypeEqual(f.Type, v.DataType()) {
+                       return fmt.Errorf("%s value for field %s had incorrect type of %s", dt, f, v.DataType())
+               }
+               if err = v.ValidateFull(); err != nil {
+                       return err
+               }
+       }
+       return
+}
+
+func (s *SparseUnion) CastTo(to arrow.DataType) (Scalar, error) {
+       if !s.Valid {
+               return MakeNullScalar(to), nil
+       }
+
+       switch to.ID() {
+       case arrow.STRING:
+               return NewStringScalar(s.String()), nil
+       case arrow.LARGE_STRING:
+               return NewLargeStringScalar(s.String()), nil
+       }
+
+       return nil, fmt.Errorf("cannot cast non-nil union to type other than string")
+}
+
+func (s *SparseUnion) ChildValue() Scalar { return s.Value[s.ChildID] }
+
+func NewSparseUnionScalar(val []Scalar, code arrow.UnionTypeCode, dt *arrow.SparseUnionType) *SparseUnion {
+       ret := &SparseUnion{
+               scalar:   scalar{dt, true},
+               TypeCode: code,
+               Value:    val,
+               ChildID:  dt.ChildIDs()[code],
+       }
+       ret.Valid = ret.Value[ret.ChildID].IsValid()
+       return ret
+}
+
+func NewSparseUnionScalarFromValue(val Scalar, idx int, dt *arrow.SparseUnionType) *SparseUnion {
+       code := dt.TypeCodes()[idx]
+       values := make([]Scalar, len(dt.Fields()))
+       for i, f := range dt.Fields() {
+               if i == idx {
+                       values[i] = val
+               } else {
+                       values[i] = MakeNullScalar(f.Type)
+               }
+       }
+       return NewSparseUnionScalar(values, code, dt)
+}
+
+type DenseUnion struct {
+       scalar
+
+       TypeCode arrow.UnionTypeCode
+       Value    Scalar
+}
+
+func (s *DenseUnion) equals(rhs Scalar) bool {
+       right := rhs.(*DenseUnion)
+       return Equals(s.Value, right.Value)
+}
+
+func (s *DenseUnion) value() interface{} { return s.ChildValue() }
+
+func (s *DenseUnion) String() string {
+       dt := s.Type.(*arrow.DenseUnionType)
+       return "union{" + dt.Fields()[dt.ChildIDs()[s.TypeCode]].String() + " = " + s.Value.String() + "}"
+}
+
+func (s *DenseUnion) Release() {
+       if v, ok := s.Value.(Releasable); ok {
+               v.Release()
+       }
+}
+
+func (s *DenseUnion) Validate() (err error) {
+       dt := s.Type.(*arrow.DenseUnionType)
+       if s.TypeCode < 0 || int(s.TypeCode) >= len(dt.ChildIDs()) || dt.ChildIDs()[s.TypeCode] == arrow.InvalidUnionChildID {
+               return fmt.Errorf("%s scalar has invalid type code %d", dt, s.TypeCode)
+       }
+       fieldType := dt.Fields()[dt.ChildIDs()[s.TypeCode]].Type
+       if !arrow.TypeEqual(fieldType, s.Value.DataType()) {
+               return fmt.Errorf("%s scalar with type code %d should have an underlying value of type %s, got %s",
+                       s.Type, s.TypeCode, fieldType, s.Value.DataType())
+       }
+       return s.Value.Validate()
+}
+
+func (s *DenseUnion) ValidateFull() error {
+       dt := s.Type.(*arrow.DenseUnionType)
+       if s.TypeCode < 0 || int(s.TypeCode) >= len(dt.ChildIDs()) || dt.ChildIDs()[s.TypeCode] == arrow.InvalidUnionChildID {
+               return fmt.Errorf("%s scalar has invalid type code %d", dt, s.TypeCode)
+       }
+       fieldType := dt.Fields()[dt.ChildIDs()[s.TypeCode]].Type
+       if !arrow.TypeEqual(fieldType, s.Value.DataType()) {
+               return fmt.Errorf("%s scalar with type code %d should have an underlying value of type %s, got %s",
+                       s.Type, s.TypeCode, fieldType, s.Value.DataType())
+       }
+       return s.Value.ValidateFull()
+}
+
+func (s *DenseUnion) CastTo(to arrow.DataType) (Scalar, error) {
+       if !s.Valid {
+               return MakeNullScalar(to), nil
+       }
+
+       switch to.ID() {
+       case arrow.STRING:
+               return NewStringScalar(s.String()), nil
+       case arrow.LARGE_STRING:
+               return NewLargeStringScalar(s.String()), nil
+       }
+
+       return nil, fmt.Errorf("cannot cast non-nil union to type other than string")
+}
+
+func (s *DenseUnion) ChildValue() Scalar { return s.Value }
+
+func NewDenseUnionScalar(v Scalar, code arrow.UnionTypeCode, dt *arrow.DenseUnionType) *DenseUnion {
+       return &DenseUnion{scalar: scalar{dt, v.IsValid()}, TypeCode: code, Value: v}
+}
diff --git a/go/arrow/scalar/scalar.go b/go/arrow/scalar/scalar.go
index 5edc98584b..7ae8b03473 100644
--- a/go/arrow/scalar/scalar.go
+++ b/go/arrow/scalar/scalar.go
@@ -466,10 +466,6 @@ func MakeNullScalar(dt arrow.DataType) Scalar {
        return makeNullFn[byte(dt.ID()&0x3f)](dt)
 }
 
-func unsupportedScalarType(dt arrow.DataType) Scalar {
-       panic("unsupported scalar data type: " + dt.ID().String())
-}
-
 func invalidScalarType(dt arrow.DataType) Scalar {
        panic("invalid scalar type: " + dt.ID().String())
 }
@@ -516,17 +512,33 @@ func init() {
                arrow.DECIMAL128:              func(dt arrow.DataType) Scalar { return &Decimal128{scalar: scalar{dt, false}} },
                arrow.LIST:                    func(dt arrow.DataType) Scalar { return &List{scalar: scalar{dt, false}} },
                arrow.STRUCT:                  func(dt arrow.DataType) Scalar { return &Struct{scalar: scalar{dt, false}} },
-               arrow.SPARSE_UNION:            unsupportedScalarType,
-               arrow.DENSE_UNION:             unsupportedScalarType,
-               arrow.DICTIONARY:              func(dt arrow.DataType) Scalar { return NewNullDictScalar(dt) },
-               arrow.LARGE_STRING:            func(dt arrow.DataType) Scalar { return &LargeString{&String{&Binary{scalar: scalar{dt, false}}}} },
-               arrow.LARGE_BINARY:            func(dt arrow.DataType) Scalar { return &LargeBinary{&Binary{scalar: scalar{dt, false}}} },
-               arrow.LARGE_LIST:              func(dt arrow.DataType) Scalar { return &LargeList{&List{scalar: scalar{dt, false}}} },
-               arrow.DECIMAL256:              func(dt arrow.DataType) Scalar { return &Decimal256{scalar: scalar{dt, false}} },
-               arrow.MAP:                     func(dt arrow.DataType) Scalar { return &Map{&List{scalar: scalar{dt, false}}} },
-               arrow.EXTENSION:               func(dt arrow.DataType) Scalar { return &Extension{scalar: scalar{dt, false}} },
-               arrow.FIXED_SIZE_LIST:         func(dt arrow.DataType) Scalar { return &FixedSizeList{&List{scalar: scalar{dt, false}}} },
-               arrow.DURATION:                func(dt arrow.DataType) Scalar { return &Duration{scalar: scalar{dt, false}} },
+               arrow.SPARSE_UNION: func(dt arrow.DataType) Scalar {
+                       typ := dt.(*arrow.SparseUnionType)
+                       if len(typ.Fields()) == 0 {
+                               panic("cannot make scalar of empty union type")
+                       }
+                       values := make([]Scalar, len(typ.Fields()))
+                       for i, f := range typ.Fields() {
+                               values[i] = MakeNullScalar(f.Type)
+                       }
+                       return NewSparseUnionScalar(values, typ.TypeCodes()[0], typ)
+               },
+               arrow.DENSE_UNION: func(dt arrow.DataType) Scalar {
+                       typ := dt.(*arrow.DenseUnionType)
+                       if len(typ.Fields()) == 0 {
+                               panic("cannot make scalar of empty union type")
+                       }
+                       return NewDenseUnionScalar(MakeNullScalar(typ.Fields()[0].Type), typ.TypeCodes()[0], typ)
+               },
+               arrow.DICTIONARY:      func(dt arrow.DataType) Scalar { return NewNullDictScalar(dt) },
+               arrow.LARGE_STRING:    func(dt arrow.DataType) Scalar { return &LargeString{&String{&Binary{scalar: scalar{dt, false}}}} },
+               arrow.LARGE_BINARY:    func(dt arrow.DataType) Scalar { return &LargeBinary{&Binary{scalar: scalar{dt, false}}} },
+               arrow.LARGE_LIST:      func(dt arrow.DataType) Scalar { return &LargeList{&List{scalar: scalar{dt, false}}} },
+               arrow.DECIMAL256:      func(dt arrow.DataType) Scalar { return &Decimal256{scalar: scalar{dt, false}} },
+               arrow.MAP:             func(dt arrow.DataType) Scalar { return &Map{&List{scalar: scalar{dt, false}}} },
+               arrow.EXTENSION:       func(dt arrow.DataType) Scalar { return &Extension{scalar: scalar{dt, false}} },
+               arrow.FIXED_SIZE_LIST: func(dt arrow.DataType) Scalar { return &FixedSizeList{&List{scalar: scalar{dt, false}}} },
+               arrow.DURATION:        func(dt arrow.DataType) Scalar { return &Duration{scalar: scalar{dt, false}} },
                // invalid data types to fill out array size 2^6 - 1
                63: invalidScalarType,
        }
@@ -646,6 +658,39 @@ func GetScalar(arr arrow.Array, idx int) (Scalar, error) {
                scalar.Value.Dict = arr.Dictionary()
                scalar.Value.Dict.Retain()
                return scalar, nil
+       case *array.SparseUnion:
+               var err error
+               typeCode := arr.TypeCode(idx)
+               children := make([]Scalar, arr.NumFields())
+               defer func() {
+                       if err != nil {
+                               for _, c := range children {
+                                       if c == nil {
+                                               break
+                                       }
+
+                                       if v, ok := c.(Releasable); ok {
+                                               v.Release()
+                                       }
+                               }
+                       }
+               }()
+
+               for i := range arr.UnionType().Fields() {
+                       if children[i], err = GetScalar(arr.Field(i), idx); err != nil {
+                               return nil, err
+                       }
+               }
+               return NewSparseUnionScalar(children, typeCode, arr.UnionType().(*arrow.SparseUnionType)), nil
+       case *array.DenseUnion:
+               typeCode := arr.TypeCode(idx)
+               child := arr.Field(arr.ChildID(idx))
+               offset := arr.ValueOffset(idx)
+               value, err := GetScalar(child, int(offset))
+               if err != nil {
+                       return nil, err
+               }
+               return NewDenseUnionScalar(value, typeCode, arr.UnionType().(*arrow.DenseUnionType)), nil
        }
 
        return nil, fmt.Errorf("cannot create scalar from array of type %s", arr.DataType())
@@ -902,6 +947,12 @@ func Hash(seed maphash.Seed, s Scalar) uint64 {
                return valueHash(s.Value.Days) & valueHash(s.Value.Milliseconds)
        case *MonthDayNanoInterval:
                return valueHash(s.Value.Months) & valueHash(s.Value.Days) & valueHash(s.Value.Nanoseconds)
+       case *SparseUnion:
+               // typecode is ignored when comparing for equality, so don't hash it either
+               out ^= Hash(seed, s.Value[s.ChildID])
+       case *DenseUnion:
+               // typecode is ignored when comparing equality, so don't hash it either
+               out ^= Hash(seed, s.Value)
        case PrimitiveScalar:
                h.Write(s.Data())
                hash()
diff --git a/go/arrow/scalar/scalar_test.go b/go/arrow/scalar/scalar_test.go
index 22f3bee20c..7b05cf4568 100644
--- a/go/arrow/scalar/scalar_test.go
+++ b/go/arrow/scalar/scalar_test.go
@@ -1143,3 +1143,276 @@ func TestDictionaryScalarValidateErrors(t *testing.T) {
                assert.Error(t, invalid.ValidateFull())
        }
 }
+
+func checkGetValidUnionScalar(t *testing.T, arr arrow.Array, idx int, expected, expectedValue scalar.Scalar) {
+       s, err := scalar.GetScalar(arr, idx)
+       assert.NoError(t, err)
+       assert.NoError(t, s.ValidateFull())
+       assert.True(t, scalar.Equals(expected, s))
+
+       assert.True(t, s.IsValid())
+       assert.True(t, scalar.Equals(s.(scalar.Union).ChildValue(), expectedValue), s, expectedValue)
+}
+
+func checkGetNullUnionScalar(t *testing.T, arr arrow.Array, idx int) {
+       s, err := scalar.GetScalar(arr, idx)
+       assert.NoError(t, err)
+       assert.True(t, scalar.Equals(scalar.MakeNullScalar(arr.DataType()), s))
+       assert.False(t, s.IsValid())
+       assert.False(t, s.(scalar.Union).ChildValue().IsValid())
+}
+
+func makeSparseUnionScalar(ty *arrow.SparseUnionType, val scalar.Scalar, idx int) scalar.Scalar {
+       return scalar.NewSparseUnionScalarFromValue(val, idx, ty)
+}
+
+func makeDenseUnionScalar(ty *arrow.DenseUnionType, val scalar.Scalar, idx int) scalar.Scalar {
+       return scalar.NewDenseUnionScalar(val, ty.TypeCodes()[idx], ty)
+}
+
+func makeSpecificNullScalar(dt arrow.UnionType, idx int) scalar.Scalar {
+       switch dt.Mode() {
+       case arrow.SparseMode:
+               values := make([]scalar.Scalar, len(dt.Fields()))
+               for i, f := range dt.Fields() {
+                       values[i] = scalar.MakeNullScalar(f.Type)
+               }
+               return scalar.NewSparseUnionScalar(values, dt.TypeCodes()[idx], dt.(*arrow.SparseUnionType))
+       case arrow.DenseMode:
+               code := dt.TypeCodes()[idx]
+               value := scalar.MakeNullScalar(dt.Fields()[idx].Type)
+               return scalar.NewDenseUnionScalar(value, code, dt.(*arrow.DenseUnionType))
+       }
+       return nil
+}
+
+type UnionScalarSuite struct {
+       suite.Suite
+
+       mode                                            arrow.UnionMode
+       dt                                              arrow.DataType
+       unionType                                       arrow.UnionType
+       alpha, beta, two, three                         scalar.Scalar
+       unionAlpha, unionBeta, unionTwo, unionThree     scalar.Scalar
+       unionOtherTwo, unionStringNull, unionNumberNull scalar.Scalar
+}
+
+func (s *UnionScalarSuite) scalarFromValue(idx int, val scalar.Scalar) scalar.Scalar {
+       switch s.mode {
+       case arrow.SparseMode:
+               return makeSparseUnionScalar(s.dt.(*arrow.SparseUnionType), val, idx)
+       case arrow.DenseMode:
+               return makeDenseUnionScalar(s.dt.(*arrow.DenseUnionType), val, idx)
+       }
+       return nil
+}
+
+func (s *UnionScalarSuite) specificNull(idx int) scalar.Scalar {
+       return makeSpecificNullScalar(s.unionType, idx)
+}
+
+func (s *UnionScalarSuite) SetupTest() {
+       s.dt = arrow.UnionOf(s.mode, []arrow.Field{
+               {Name: "string", Type: arrow.BinaryTypes.String, Nullable: true},
+               {Name: "number", Type: arrow.PrimitiveTypes.Uint64, Nullable: true},
+               {Name: "other_number", Type: arrow.PrimitiveTypes.Uint64, Nullable: true},
+       }, []arrow.UnionTypeCode{3, 42, 43})
+
+       s.unionType = s.dt.(arrow.UnionType)
+
+       s.alpha = scalar.MakeScalar("alpha")
+       s.beta = scalar.MakeScalar("beta")
+       s.two = scalar.MakeScalar(uint64(2))
+       s.three = scalar.MakeScalar(uint64(3))
+
+       s.unionAlpha = s.scalarFromValue(0, s.alpha)
+       s.unionBeta = s.scalarFromValue(0, s.beta)
+       s.unionTwo = s.scalarFromValue(1, s.two)
+       s.unionOtherTwo = s.scalarFromValue(2, s.two)
+       s.unionThree = s.scalarFromValue(1, s.three)
+       s.unionStringNull = s.specificNull(0)
+       s.unionNumberNull = s.specificNull(1)
+}
+
+func (s *UnionScalarSuite) TestValidate() {
+       s.NoError(s.unionAlpha.ValidateFull())
+       s.NoError(s.unionAlpha.Validate())
+       s.NoError(s.unionBeta.ValidateFull())
+       s.NoError(s.unionBeta.Validate())
+       s.NoError(s.unionTwo.ValidateFull())
+       s.NoError(s.unionTwo.Validate())
+       s.NoError(s.unionOtherTwo.ValidateFull())
+       s.NoError(s.unionOtherTwo.Validate())
+       s.NoError(s.unionThree.ValidateFull())
+       s.NoError(s.unionThree.Validate())
+       s.NoError(s.unionStringNull.ValidateFull())
+       s.NoError(s.unionStringNull.Validate())
+       s.NoError(s.unionNumberNull.ValidateFull())
+       s.NoError(s.unionNumberNull.Validate())
+}
+
+func (s *UnionScalarSuite) setTypeCode(sc scalar.Scalar, c arrow.UnionTypeCode) {
+       switch sc := sc.(type) {
+       case *scalar.SparseUnion:
+               sc.TypeCode = c
+       case *scalar.DenseUnion:
+               sc.TypeCode = c
+       }
+}
+
+func (s *UnionScalarSuite) setIsValid(sc scalar.Scalar, v bool) {
+       switch sc := sc.(type) {
+       case *scalar.SparseUnion:
+               sc.Valid = v
+       case *scalar.DenseUnion:
+               sc.Valid = v
+       }
+}
+
+func (s *UnionScalarSuite) TestValidateErrors() {
+       // type code doesn't exist
+       sc := s.scalarFromValue(0, s.alpha)
+
+       // invalid type code
+       s.setTypeCode(sc, 0)
+       s.Error(sc.Validate())
+       s.Error(sc.ValidateFull())
+
+       s.setIsValid(sc, false)
+       s.Error(sc.Validate())
+       s.Error(sc.ValidateFull())
+
+       s.setTypeCode(sc, -42)
+       s.setIsValid(sc, true)
+       s.Error(sc.Validate())
+       s.Error(sc.ValidateFull())
+
+       s.setIsValid(sc, false)
+       s.Error(sc.Validate())
+       s.Error(sc.ValidateFull())
+
+       // type code doesn't correspond to child type
+       if sc, ok := sc.(*scalar.DenseUnion); ok {
+               sc.TypeCode = 42
+               sc.Valid = true
+               s.Error(sc.Validate())
+               s.Error(sc.ValidateFull())
+
+               sc = s.scalarFromValue(2, s.two).(*scalar.DenseUnion)
+               sc.TypeCode = 3
+               s.Error(sc.Validate())
+               s.Error(sc.ValidateFull())
+       }
+
+       // underlying value has invalid utf8
+       sc = s.scalarFromValue(0, scalar.NewStringScalar("\xff"))
+       s.NoError(sc.Validate())
+       s.Error(sc.ValidateFull())
+}
+
+func (s *UnionScalarSuite) TestEquals() {
+       // differing values
+       s.False(scalar.Equals(s.unionAlpha, s.unionBeta))
+       s.False(scalar.Equals(s.unionTwo, s.unionThree))
+       // differing validities
+       s.False(scalar.Equals(s.unionAlpha, s.unionStringNull))
+       // differing types
+       s.False(scalar.Equals(s.unionAlpha, s.unionTwo))
+       s.False(scalar.Equals(s.unionAlpha, s.unionOtherTwo))
+       // type codes don't count when comparing union scalars: the underlying
+       // values are identical even though their provenance is different
+       s.True(scalar.Equals(s.unionTwo, s.unionOtherTwo))
+       s.True(scalar.Equals(s.unionStringNull, s.unionNumberNull))
+}
+
+func (s *UnionScalarSuite) TestMakeNullScalar() {
+       sc := scalar.MakeNullScalar(s.dt)
+       s.True(arrow.TypeEqual(s.dt, sc.DataType()))
+       s.False(sc.IsValid())
+
+       // the first child field is chosen arbitrarily for the purposes of
+       // making a null scalar
+       switch s.mode {
+       case arrow.DenseMode:
+               asDense := sc.(*scalar.DenseUnion)
+               s.EqualValues(3, asDense.TypeCode)
+               s.False(asDense.Value.IsValid())
+       case arrow.SparseMode:
+               asSparse := sc.(*scalar.SparseUnion)
+               s.EqualValues(3, asSparse.TypeCode)
+               s.False(asSparse.Value[asSparse.ChildID].IsValid())
+       }
+}
+
+type SparseUnionSuite struct {
+       UnionScalarSuite
+}
+
+func (s *SparseUnionSuite) SetupSuite() {
+       s.mode = arrow.SparseMode
+}
+
+func (s *SparseUnionSuite) TestGetScalar() {
+       mem := memory.NewCheckedAllocator(memory.DefaultAllocator)
+       defer mem.AssertSize(s.T(), 0)
+
+       children := make([]arrow.Array, 3)
+       children[0], _, _ = array.FromJSON(mem, arrow.BinaryTypes.String, strings.NewReader(`["alpha", "", "beta", null, "gamma"]`))
+       defer children[0].Release()
+       children[1], _, _ = array.FromJSON(mem, arrow.PrimitiveTypes.Uint64, strings.NewReader(`[1, 2, 11, 22, null]`))
+       defer children[1].Release()
+       children[2], _, _ = array.FromJSON(mem, arrow.PrimitiveTypes.Uint64, strings.NewReader(`[100, 101, 102, 103, 104]`))
+       defer children[2].Release()
+
+       typeIDs, _, _ := array.FromJSON(mem, arrow.PrimitiveTypes.Int8, strings.NewReader(`[3, 42, 3, 3, 42]`))
+       defer typeIDs.Release()
+
+       arr := array.NewSparseUnion(s.dt.(*arrow.SparseUnionType), 5, children, typeIDs.Data().Buffers()[1], 0)
+       defer arr.Release()
+
+       checkGetValidUnionScalar(s.T(), arr, 0, s.unionAlpha, s.alpha)
+       checkGetValidUnionScalar(s.T(), arr, 1, s.unionTwo, s.two)
+       checkGetValidUnionScalar(s.T(), arr, 2, s.unionBeta, s.beta)
+       checkGetNullUnionScalar(s.T(), arr, 3)
+       checkGetNullUnionScalar(s.T(), arr, 4)
+}
+
+type DenseUnionSuite struct {
+       UnionScalarSuite
+}
+
+func (s *DenseUnionSuite) SetupSuite() {
+       s.mode = arrow.DenseMode
+}
+
+func (s *DenseUnionSuite) TestGetScalar() {
+       mem := memory.NewCheckedAllocator(memory.DefaultAllocator)
+       defer mem.AssertSize(s.T(), 0)
+
+       children := make([]arrow.Array, 3)
+       children[0], _, _ = array.FromJSON(mem, arrow.BinaryTypes.String, strings.NewReader(`["alpha", "beta", null]`))
+       defer children[0].Release()
+       children[1], _, _ = array.FromJSON(mem, arrow.PrimitiveTypes.Uint64, strings.NewReader(`[2, 3]`))
+       defer children[1].Release()
+       children[2], _, _ = array.FromJSON(mem, arrow.PrimitiveTypes.Uint64, strings.NewReader(`[]`))
+       defer children[2].Release()
+
+       typeIDs, _, _ := array.FromJSON(mem, arrow.PrimitiveTypes.Int8, strings.NewReader(`[3, 42, 3, 3, 42]`))
+       defer typeIDs.Release()
+       offsets, _, _ := array.FromJSON(mem, arrow.PrimitiveTypes.Int32, strings.NewReader(`[0, 0, 1, 2, 1]`))
+       defer offsets.Release()
+
+       arr := array.NewDenseUnion(s.dt.(*arrow.DenseUnionType), 5, children, typeIDs.Data().Buffers()[1], offsets.Data().Buffers()[1], 0)
+       defer arr.Release()
+
+       checkGetValidUnionScalar(s.T(), arr, 0, s.unionAlpha, s.alpha)
+       checkGetValidUnionScalar(s.T(), arr, 1, s.unionTwo, s.two)
+       checkGetValidUnionScalar(s.T(), arr, 2, s.unionBeta, s.beta)
+       checkGetNullUnionScalar(s.T(), arr, 3)
+       checkGetValidUnionScalar(s.T(), arr, 4, s.unionThree, s.three)
+}
+
+func TestUnionScalars(t *testing.T) {
+       suite.Run(t, new(SparseUnionSuite))
+       suite.Run(t, new(DenseUnionSuite))
+}

Reply via email to