This is an automated email from the ASF dual-hosted git repository.
zeroshade pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-go.git
The following commit(s) were added to refs/heads/main by this push:
new 8598fb3 feat(arrow/compute): support some float16 casts (#430)
8598fb3 is described below
commit 8598fb3bd433b713424c4390fc9b785fc33fa5ee
Author: Matt Topol <[email protected]>
AuthorDate: Mon Jul 7 15:24:33 2025 -0400
feat(arrow/compute): support some float16 casts (#430)
closes #424
### Rationale for this change
Support casting float16 arrays to and from integer, float32, and float64 arrays, and casting boolean, string/binary, and decimal arrays to float16.
### What changes are included in this PR?
New casting kernels are implemented for `cast_float`, `cast_half_float`, and
`cast_int32`.
### Are these changes tested?
Yes. Unit tests have been added to cover the new casts.
### Are there any user-facing changes?
The only user-facing change is that `float16.Num` is no longer a struct
with a `bits` member; it is now a defined type.
As a result, any use of `float16.Num{}` must be replaced with
`float16.Num(0)`.
---
arrow/compute/cast_test.go | 8 +-
arrow/compute/internal/kernels/cast_numeric.go | 46 ++++++++-
arrow/compute/internal/kernels/helpers.go | 6 +-
arrow/compute/internal/kernels/numeric_cast.go | 133 ++++++++++++++++++++++---
arrow/float16/float16.go | 7 +-
arrow/float16/float16_test.go | 2 +-
6 files changed, 181 insertions(+), 21 deletions(-)
diff --git a/arrow/compute/cast_test.go b/arrow/compute/cast_test.go
index 4e5f0a5..370ced1 100644
--- a/arrow/compute/cast_test.go
+++ b/arrow/compute/cast_test.go
@@ -574,7 +574,7 @@ func (c *CastSuite) TestToIntDowncastUnsafe() {
}
func (c *CastSuite) TestFloatingToInt() {
- for _, from := range []arrow.DataType{arrow.PrimitiveTypes.Float32,
arrow.PrimitiveTypes.Float64} {
+ for _, from := range []arrow.DataType{arrow.PrimitiveTypes.Float32,
arrow.PrimitiveTypes.Float64, arrow.FixedWidthTypes.Float16} {
for _, to := range []arrow.DataType{arrow.PrimitiveTypes.Int32,
arrow.PrimitiveTypes.Int64} {
// float to int no truncation
c.checkCast(from, to, `[1.0, null, 0.0, -1.0, 5.0]`,
`[1, null, 0, -1, 5]`)
@@ -590,6 +590,12 @@ func (c *CastSuite) TestFloatingToInt() {
}
}
+func (c *CastSuite) TestFloat16ToFloating() {
+ for _, to := range []arrow.DataType{arrow.PrimitiveTypes.Float32,
arrow.PrimitiveTypes.Float64} {
+ c.checkCast(arrow.FixedWidthTypes.Float16, to, `[1.5, null,
0.0, -1.5, 5.5]`, `[1.5, null, 0.0, -1.5, 5.5]`)
+ }
+}
+
func (c *CastSuite) TestIntToFloating() {
for _, from := range []arrow.DataType{arrow.PrimitiveTypes.Uint32,
arrow.PrimitiveTypes.Int32} {
two24 := `[16777216, 16777217]`
diff --git a/arrow/compute/internal/kernels/cast_numeric.go
b/arrow/compute/internal/kernels/cast_numeric.go
index a177259..6969d82 100644
--- a/arrow/compute/internal/kernels/cast_numeric.go
+++ b/arrow/compute/internal/kernels/cast_numeric.go
@@ -22,6 +22,7 @@ import (
"unsafe"
"github.com/apache/arrow-go/v18/arrow"
+ "github.com/apache/arrow-go/v18/arrow/float16"
)
var castNumericUnsafe func(itype, otype arrow.Type, in, out []byte, len int) =
castNumericGo
@@ -32,7 +33,19 @@ func DoStaticCast[InT, OutT numeric](in []InT, out []OutT) {
}
}
-func reinterpret[T numeric](b []byte, len int) (res []T) {
+func DoFloat16Cast[InT numeric](in []InT, out []float16.Num) {
+ for i, v := range in {
+ out[i] = float16.New(float32(v))
+ }
+}
+
+func DoFloat16CastToNumber[OutT numeric](in []float16.Num, out []OutT) {
+ for i, v := range in {
+ out[i] = OutT(v.Float32())
+ }
+}
+
+func reinterpret[T numeric | float16.Num](b []byte, len int) (res []T) {
return unsafe.Slice((*T)(unsafe.Pointer(&b[0])), len)
}
@@ -54,6 +67,8 @@ func castNumberToNumberUnsafeImpl[T numeric](outT arrow.Type,
in []T, out []byte
DoStaticCast(in, reinterpret[int64](out, len(in)))
case arrow.UINT64:
DoStaticCast(in, reinterpret[uint64](out, len(in)))
+ case arrow.FLOAT16:
+ DoFloat16Cast(in, reinterpret[float16.Num](out, len(in)))
case arrow.FLOAT32:
DoStaticCast(in, reinterpret[float32](out, len(in)))
case arrow.FLOAT64:
@@ -61,6 +76,33 @@ func castNumberToNumberUnsafeImpl[T numeric](outT
arrow.Type, in []T, out []byte
}
}
+func castFloat16ToNumberUnsafeImpl(outT arrow.Type, in []float16.Num, out
[]byte) {
+ switch outT {
+ case arrow.INT8:
+ DoFloat16CastToNumber(in, reinterpret[int8](out, len(in)))
+ case arrow.UINT8:
+ DoFloat16CastToNumber(in, reinterpret[uint8](out, len(in)))
+ case arrow.INT16:
+ DoFloat16CastToNumber(in, reinterpret[int16](out, len(in)))
+ case arrow.UINT16:
+ DoFloat16CastToNumber(in, reinterpret[uint16](out, len(in)))
+ case arrow.INT32:
+ DoFloat16CastToNumber(in, reinterpret[int32](out, len(in)))
+ case arrow.UINT32:
+ DoFloat16CastToNumber(in, reinterpret[uint32](out, len(in)))
+ case arrow.INT64:
+ DoFloat16CastToNumber(in, reinterpret[int64](out, len(in)))
+ case arrow.UINT64:
+ DoFloat16CastToNumber(in, reinterpret[uint64](out, len(in)))
+ case arrow.FLOAT16:
+ copy(reinterpret[float16.Num](out, len(in)), in)
+ case arrow.FLOAT32:
+ DoFloat16CastToNumber(in, reinterpret[float32](out, len(in)))
+ case arrow.FLOAT64:
+ DoFloat16CastToNumber(in, reinterpret[float64](out, len(in)))
+ }
+}
+
func castNumericGo(itype, otype arrow.Type, in, out []byte, len int) {
switch itype {
case arrow.INT8:
@@ -79,6 +121,8 @@ func castNumericGo(itype, otype arrow.Type, in, out []byte,
len int) {
castNumberToNumberUnsafeImpl(otype, reinterpret[int64](in,
len), out)
case arrow.UINT64:
castNumberToNumberUnsafeImpl(otype, reinterpret[uint64](in,
len), out)
+ case arrow.FLOAT16:
+ castFloat16ToNumberUnsafeImpl(otype,
reinterpret[float16.Num](in, len), out)
case arrow.FLOAT32:
castNumberToNumberUnsafeImpl(otype, reinterpret[float32](in,
len), out)
case arrow.FLOAT64:
diff --git a/arrow/compute/internal/kernels/helpers.go
b/arrow/compute/internal/kernels/helpers.go
index 4a9ead1..ef5f0bb 100644
--- a/arrow/compute/internal/kernels/helpers.go
+++ b/arrow/compute/internal/kernels/helpers.go
@@ -695,7 +695,11 @@ func castNumberToNumberUnsafe(in, out *exec.ArraySpan) {
inputOffset := in.Type.(arrow.FixedWidthDataType).Bytes() *
int(in.Offset)
outputOffset := out.Type.(arrow.FixedWidthDataType).Bytes() *
int(out.Offset)
- castNumericUnsafe(in.Type.ID(), out.Type.ID(),
in.Buffers[1].Buf[inputOffset:], out.Buffers[1].Buf[outputOffset:], int(in.Len))
+ if in.Type.ID() == arrow.FLOAT16 || out.Type.ID() == arrow.FLOAT16 {
+ castNumericGo(in.Type.ID(), out.Type.ID(),
in.Buffers[1].Buf[inputOffset:], out.Buffers[1].Buf[outputOffset:], int(in.Len))
+ } else {
+ castNumericUnsafe(in.Type.ID(), out.Type.ID(),
in.Buffers[1].Buf[inputOffset:], out.Buffers[1].Buf[outputOffset:], int(in.Len))
+ }
}
func MaxDecimalDigitsForInt(id arrow.Type) (int32, error) {
diff --git a/arrow/compute/internal/kernels/numeric_cast.go
b/arrow/compute/internal/kernels/numeric_cast.go
index 1e76709..7681b02 100644
--- a/arrow/compute/internal/kernels/numeric_cast.go
+++ b/arrow/compute/internal/kernels/numeric_cast.go
@@ -28,6 +28,7 @@ import (
"github.com/apache/arrow-go/v18/arrow/compute/exec"
"github.com/apache/arrow-go/v18/arrow/decimal128"
"github.com/apache/arrow-go/v18/arrow/decimal256"
+ "github.com/apache/arrow-go/v18/arrow/float16"
"github.com/apache/arrow-go/v18/arrow/internal/debug"
"github.com/apache/arrow-go/v18/internal/bitutils"
"golang.org/x/exp/constraints"
@@ -506,6 +507,27 @@ func CastFloat64ToDecimal(ctx *exec.KernelCtx, batch
*exec.ExecSpan, out *exec.E
return executor(ctx, batch, out)
}
+func CastDecimalToFloat16(ctx *exec.KernelCtx, batch *exec.ExecSpan, out
*exec.ExecResult) error {
+ var (
+ executor exec.ArrayKernelExec
+ )
+
+ switch dt := batch.Values[0].Array.Type.(type) {
+ case *arrow.Decimal128Type:
+ scale := dt.Scale
+ executor = ScalarUnaryNotNull(func(_ *exec.KernelCtx, v
decimal128.Num, err *error) float16.Num {
+ return float16.New(v.ToFloat32(scale))
+ })
+ case *arrow.Decimal256Type:
+ scale := dt.Scale
+ executor = ScalarUnaryNotNull(func(_ *exec.KernelCtx, v
decimal256.Num, err *error) float16.Num {
+ return float16.New(v.ToFloat32(scale))
+ })
+ }
+
+ return executor(ctx, batch, out)
+}
+
func CastDecimalToFloating[OutT constraints.Float](ctx *exec.KernelCtx, batch
*exec.ExecSpan, out *exec.ExecResult) error {
var (
executor exec.ArrayKernelExec
@@ -543,13 +565,49 @@ func boolToNum[T numeric](_ *exec.KernelCtx, in []byte,
out []T) error {
return nil
}
-func checkFloatTrunc[InT constraints.Float, OutT arrow.IntType |
arrow.UintType](in, out *exec.ArraySpan) error {
- wasTrunc := func(out OutT, in InT) bool {
- return InT(out) != in
+func boolToFloat16(_ *exec.KernelCtx, in []byte, out []float16.Num) error {
+ var (
+ zero float16.Num
+ one = float16.New(1)
+ )
+
+ for i := range out {
+ if bitutil.BitIsSet(in, i) {
+ out[i] = one
+ } else {
+ out[i] = zero
+ }
}
- wasTruncMaybeNull := func(out OutT, in InT, isValid bool) bool {
- return isValid && (InT(out) != in)
+ return nil
+}
+
+func wasTrunc[InT constraints.Float | float16.Num, OutT arrow.IntType |
arrow.UintType](out OutT, in InT) bool {
+ switch v := any(in).(type) {
+ case float16.Num:
+ return float16.New(float32(out)) != v
+ case float32:
+ return float32(out) != v
+ case float64:
+ return float64(out) != v
+ default:
+ return false
+ }
+}
+
+func wasTruncMaybeNull[InT constraints.Float | float16.Num, OutT arrow.IntType
| arrow.UintType](out OutT, in InT, isValid bool) bool {
+ switch v := any(in).(type) {
+ case float16.Num:
+ return isValid && (float16.New(float32(out)) != v)
+ case float32:
+ return isValid && (float32(out) != v)
+ case float64:
+ return isValid && (float64(out) != v)
+ default:
+ return false
}
+}
+
+func checkFloatTrunc[InT constraints.Float | float16.Num, OutT arrow.IntType |
arrow.UintType](in, out *exec.ArraySpan) error {
getError := func(val InT) error {
return fmt.Errorf("%w: float value %f was truncated converting
to %s",
arrow.ErrInvalid, val, out.Type)
@@ -598,7 +656,7 @@ func checkFloatTrunc[InT constraints.Float, OutT
arrow.IntType | arrow.UintType]
return nil
}
-func checkFloatToIntTruncImpl[T constraints.Float](in, out *exec.ArraySpan)
error {
+func checkFloatToIntTruncImpl[T constraints.Float | float16.Num](in, out
*exec.ArraySpan) error {
switch out.Type.ID() {
case arrow.INT8:
return checkFloatTrunc[T, int8](in, out)
@@ -623,6 +681,8 @@ func checkFloatToIntTruncImpl[T constraints.Float](in, out
*exec.ArraySpan) erro
func checkFloatToIntTrunc(in, out *exec.ArraySpan) error {
switch in.Type.ID() {
+ case arrow.FLOAT16:
+ return checkFloatToIntTruncImpl[float16.Num](in, out)
case arrow.FLOAT32:
return checkFloatToIntTruncImpl[float32](in, out)
case arrow.FLOAT64:
@@ -729,6 +789,26 @@ func getParseStringExec[OffsetT int32 | int64](out
arrow.Type) exec.ArrayKernelE
panic("invalid type for getParseStringExec")
}
+func addFloat16Casts(outTy arrow.DataType, kernels []exec.ScalarKernel)
[]exec.ScalarKernel {
+ kernels = append(kernels, GetCommonCastKernels(outTy.ID(),
exec.NewOutputType(outTy))...)
+
+ kernels = append(kernels, exec.NewScalarKernel(
+
[]exec.InputType{exec.NewExactInput(arrow.FixedWidthTypes.Boolean)},
+ exec.NewOutputType(outTy), ScalarUnaryBoolArg(boolToFloat16),
nil))
+
+ for _, inTy := range []arrow.DataType{arrow.BinaryTypes.Binary,
arrow.BinaryTypes.String} {
+ kernels = append(kernels, exec.NewScalarKernel(
+ []exec.InputType{exec.NewExactInput(inTy)},
exec.NewOutputType(outTy),
+ getParseStringExec[int32](outTy.ID()), nil))
+ }
+ for _, inTy := range []arrow.DataType{arrow.BinaryTypes.LargeBinary,
arrow.BinaryTypes.LargeString} {
+ kernels = append(kernels, exec.NewScalarKernel(
+ []exec.InputType{exec.NewExactInput(inTy)},
exec.NewOutputType(outTy),
+ getParseStringExec[int64](outTy.ID()), nil))
+ }
+ return kernels
+}
+
func addCommonNumberCasts[T numeric](outTy arrow.DataType, kernels
[]exec.ScalarKernel) []exec.ScalarKernel {
kernels = append(kernels, GetCommonCastKernels(outTy.ID(),
exec.NewOutputType(outTy))...)
@@ -759,7 +839,7 @@ func GetCastToInteger[T arrow.IntType |
arrow.UintType](outType arrow.DataType)
CastIntToInt, nil))
}
- for _, inTy := range floatingTypes {
+ for _, inTy := range append(floatingTypes,
arrow.FixedWidthTypes.Float16) {
kernels = append(kernels, exec.NewScalarKernel(
[]exec.InputType{exec.NewExactInput(inTy)}, output,
CastFloatingToInteger, nil))
@@ -775,7 +855,7 @@ func GetCastToInteger[T arrow.IntType |
arrow.UintType](outType arrow.DataType)
return kernels
}
-func GetCastToFloating[T constraints.Float](outType arrow.DataType)
[]exec.ScalarKernel {
+func GetCastToFloating[T constraints.Float | float16.Num](outType
arrow.DataType) []exec.ScalarKernel {
kernels := make([]exec.ScalarKernel, 0)
output := exec.NewOutputType(outType)
@@ -785,19 +865,40 @@ func GetCastToFloating[T constraints.Float](outType
arrow.DataType) []exec.Scala
CastIntegerToFloating, nil))
}
- for _, inTy := range floatingTypes {
+ for _, inTy := range append(floatingTypes,
arrow.FixedWidthTypes.Float16) {
kernels = append(kernels, exec.NewScalarKernel(
[]exec.InputType{exec.NewExactInput(inTy)}, output,
CastFloatingToFloating, nil))
}
- kernels = addCommonNumberCasts[T](outType, kernels)
- kernels = append(kernels, exec.NewScalarKernel(
- []exec.InputType{exec.NewIDInput(arrow.DECIMAL128)}, output,
- CastDecimalToFloating[T], nil))
- kernels = append(kernels, exec.NewScalarKernel(
- []exec.InputType{exec.NewIDInput(arrow.DECIMAL256)}, output,
- CastDecimalToFloating[T], nil))
+ var z T
+ switch any(z).(type) {
+ case float16.Num:
+ kernels = addFloat16Casts(outType, kernels)
+ kernels = append(kernels, exec.NewScalarKernel(
+ []exec.InputType{exec.NewIDInput(arrow.DECIMAL128)},
output,
+ CastDecimalToFloat16, nil))
+ kernels = append(kernels, exec.NewScalarKernel(
+ []exec.InputType{exec.NewIDInput(arrow.DECIMAL256)},
output,
+ CastDecimalToFloat16, nil))
+ case float32:
+ kernels = addCommonNumberCasts[float32](outType, kernels)
+ kernels = append(kernels, exec.NewScalarKernel(
+ []exec.InputType{exec.NewIDInput(arrow.DECIMAL128)},
output,
+ CastDecimalToFloating[float32], nil))
+ kernels = append(kernels, exec.NewScalarKernel(
+ []exec.InputType{exec.NewIDInput(arrow.DECIMAL256)},
output,
+ CastDecimalToFloating[float32], nil))
+ case float64:
+ kernels = addCommonNumberCasts[float64](outType, kernels)
+ kernels = append(kernels, exec.NewScalarKernel(
+ []exec.InputType{exec.NewIDInput(arrow.DECIMAL128)},
output,
+ CastDecimalToFloating[float64], nil))
+ kernels = append(kernels, exec.NewScalarKernel(
+ []exec.InputType{exec.NewIDInput(arrow.DECIMAL256)},
output,
+ CastDecimalToFloating[float64], nil))
+ }
+
return kernels
}
diff --git a/arrow/float16/float16.go b/arrow/float16/float16.go
index 0aa4df8..f6c276b 100644
--- a/arrow/float16/float16.go
+++ b/arrow/float16/float16.go
@@ -18,6 +18,7 @@ package float16
import (
"encoding/binary"
+ "fmt"
"math"
"strconv"
)
@@ -58,6 +59,10 @@ func New(f float32) Num {
return Num{bits: (sn << 15) | uint16(res<<10) | fc}
}
+func (f Num) Format(s fmt.State, verb rune) {
+ fmt.Fprintf(s, fmt.FormatString(s, verb), f.Float32())
+}
+
func (f Num) Float32() float32 {
sn := uint32((f.bits >> 15) & 0x1)
exp := (f.bits >> 10) & 0x1f
@@ -179,7 +184,7 @@ func (n Num) IsInf() bool { return (n.bits & 0x7c00) ==
0x7c00 }
func (n Num) IsZero() bool { return (n.bits & 0x7fff) == 0 }
-func (f Num) Uint16() uint16 { return f.bits }
+func (f Num) Uint16() uint16 { return uint16(f.bits) }
func (f Num) String() string { return
strconv.FormatFloat(float64(f.Float32()), 'g', -1, 32) }
func Inf() Num { return Num{bits: 0x7c00} }
diff --git a/arrow/float16/float16_test.go b/arrow/float16/float16_test.go
index cfde440..9857eda 100644
--- a/arrow/float16/float16_test.go
+++ b/arrow/float16/float16_test.go
@@ -38,7 +38,7 @@ func TestFloat16(t *testing.T) {
f := k.Float32()
assert.Equal(t, v, f, "float32 values should be the same")
i := New(v)
- assert.Equal(t, k.bits, i.bits, "float16 values should be the
same")
+ assert.Equal(t, k, i, "float16 values should be the same")
assert.Equal(t, k.Uint16(), i.Uint16(), "float16 values should
be the same")
assert.Equal(t, k.String(), fmt.Sprintf("%v", v), "string
representation differ")
}