This is an automated email from the ASF dual-hosted git repository. uwe pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/arrow.git
The following commit(s) were added to refs/heads/master by this push: new e912675 ARROW-2936: [Python] Implement Table.cast for casting from one schema to another (if possible) e912675 is described below commit e91267555cc72c3cf0e5a472d3c89eefed69d6f7 Author: Krisztián Szűcs <szucs.kriszt...@gmail.com> AuthorDate: Wed Sep 12 16:33:41 2018 +0200 ARROW-2936: [Python] Implement Table.cast for casting from one schema to another (if possible) Also contains a fix for float truncation. Author: Krisztián Szűcs <szucs.kriszt...@gmail.com> Closes #2530 from kszucs/ARROW-2936 and squashes the following commits: 1d3b7ec0 <Krisztián Szűcs> unsafe cast assertion; py2 compatible tests ca44e219 <Krisztián Szűcs> apidoc 772a666f <Krisztián Szűcs> flake8 90fc3183 <Krisztián Szűcs> Table.cast implementation; fix float truncation casting rule --- cpp/src/arrow/compute/compute-test.cc | 44 +++++++++++++++---- cpp/src/arrow/compute/kernels/cast.cc | 16 +++---- python/pyarrow/table.pxi | 37 ++++++++++++++-- python/pyarrow/tests/test_array.py | 32 ++++++++++++++ python/pyarrow/tests/test_table.py | 82 +++++++++++++++++++++++++++++++++++ 5 files changed, 189 insertions(+), 22 deletions(-) diff --git a/cpp/src/arrow/compute/compute-test.cc b/cpp/src/arrow/compute/compute-test.cc index a1dfdef..233f8a6 100644 --- a/cpp/src/arrow/compute/compute-test.cc +++ b/cpp/src/arrow/compute/compute-test.cc @@ -286,20 +286,21 @@ TEST_F(TestCast, ToIntDowncastUnsafe) { } TEST_F(TestCast, FloatingPointToInt) { + // which means allow_float_truncate == false auto options = CastOptions::Safe(); vector<bool> is_valid = {true, false, true, true, true}; vector<bool> all_valid = {true, true, true, true, true}; - // float32 point to integer - vector<float> v1 = {1.5, 0, 0.5, -1.5, 5.5}; + // float32 to int32 no truncation + vector<float> v1 = {1.0, 0, 0.0, -1.0, 5.0}; vector<int32_t> e1 = {1, 0, 0, -1, 5}; CheckCase<FloatType, float, Int32Type, int32_t>(float32(), v1, is_valid, int32(), e1, options); CheckCase<FloatType, float, Int32Type, int32_t>(float32(), v1, all_valid, int32(), e1, options); - // float64 point to integer + // float64 to int32 no truncation vector<double> v2 = {1.0, 0, 0.0, -1.0, 5.0}; vector<int32_t> e2 = {1, 0, 0, -1, 5}; CheckCase<DoubleType, double, Int32Type, int32_t>(float64(), v2, is_valid, int32(), e2, @@ -307,15 +308,40 @@ TEST_F(TestCast, FloatingPointToInt) { CheckCase<DoubleType, double, Int32Type, int32_t>(float64(), v2, all_valid, int32(), e2, options); - vector<double> v3 = {1.5, 0, 0.5, -1.5, 5.5}; - vector<int32_t> e3 = {1, 0, 0, -1, 5}; - CheckFails<DoubleType>(float64(), v3, is_valid, int32(), options); - CheckFails<DoubleType>(float64(), v3, all_valid, int32(), options); + // float64 to int64 no truncation + vector<double> v3 = {1.0, 0, 0.0, -1.0, 5.0}; + vector<int64_t> e3 = {1, 0, 0, -1, 5}; + CheckCase<DoubleType, double, Int64Type, int64_t>(float64(), v3, is_valid, int64(), e3, + options); + CheckCase<DoubleType, double, Int64Type, int64_t>(float64(), v3, all_valid, int64(), e3, + options); + + // float64 to int32 truncate + vector<double> v4 = {1.5, 0, 0.5, -1.5, 5.5}; + vector<int32_t> e4 = {1, 0, 0, -1, 5}; + + options.allow_float_truncate = false; + CheckFails<DoubleType>(float64(), v4, is_valid, int32(), options); + CheckFails<DoubleType>(float64(), v4, all_valid, int32(), options); + + options.allow_float_truncate = true; + CheckCase<DoubleType, double, Int32Type, int32_t>(float64(), v4, is_valid, int32(), e4, + options); + CheckCase<DoubleType, double, Int32Type, int32_t>(float64(), v4, all_valid, int32(), e4, + options); + + // float64 to int64 truncate + vector<double> v5 = {1.5, 0, 0.5, -1.5, 5.5}; + vector<int64_t> e5 = {1, 0, 0, -1, 5}; + + options.allow_float_truncate = false; + CheckFails<DoubleType>(float64(), v5, is_valid, int64(), options); + CheckFails<DoubleType>(float64(), v5, all_valid, int64(), options); options.allow_float_truncate = true; - CheckCase<DoubleType, double, Int32Type, int32_t>(float64(), v3, is_valid, int32(), e3, + CheckCase<DoubleType, double, Int64Type, int64_t>(float64(), v5, is_valid, int64(), e5, options); - CheckCase<DoubleType, double, Int32Type, int32_t>(float64(), v3, all_valid, int32(), e3, + CheckCase<DoubleType, double, Int64Type, int64_t>(float64(), v5, all_valid, int64(), e5, options); } diff --git a/cpp/src/arrow/compute/kernels/cast.cc b/cpp/src/arrow/compute/kernels/cast.cc index 2a0479d..369ebb9 100644 --- a/cpp/src/arrow/compute/kernels/cast.cc +++ b/cpp/src/arrow/compute/kernels/cast.cc @@ -194,20 +194,16 @@ struct is_integer_downcast< }; template <typename O, typename I, typename Enable = void> -struct is_float_downcast { +struct is_float_truncate { static constexpr bool value = false; }; template <typename O, typename I> -struct is_float_downcast< +struct is_float_truncate< O, I, - typename std::enable_if<std::is_base_of<Number, O>::value && + typename std::enable_if<std::is_base_of<Integer, O>::value && std::is_base_of<FloatingPoint, I>::value>::type> { - using O_T = typename O::c_type; - using I_T = typename I::c_type; - - // Smaller output size - static constexpr bool value = !std::is_same<O, I>::value && (sizeof(O_T) < sizeof(I_T)); + static constexpr bool value = true; }; template <typename O, typename I> @@ -270,7 +266,7 @@ struct CastFunctor<O, I, }; template <typename O, typename I> -struct CastFunctor<O, I, typename std::enable_if<is_float_downcast<O, I>::value>::type> { +struct CastFunctor<O, I, typename std::enable_if<is_float_truncate<O, I>::value>::type> { void operator()(FunctionContext* ctx, const CastOptions& options, const ArrayData& input, ArrayData* output) { using in_type = typename I::c_type; @@ -316,7 +312,7 @@ struct CastFunctor<O, I, typename std::enable_if<is_float_downcast<O, I>::value> template <typename O, typename I> struct CastFunctor<O, I, typename std::enable_if<is_numeric_cast<O, I>::value && - !is_float_downcast<O, I>::value && + !is_float_truncate<O, I>::value && !is_integer_downcast<O, I>::value>::type> { void operator()(FunctionContext* ctx, const CastOptions& options, const ArrayData& input, ArrayData* output) { diff --git a/python/pyarrow/table.pxi b/python/pyarrow/table.pxi index bbf40e0..62f6803 100644 --- a/python/pyarrow/table.pxi +++ b/python/pyarrow/table.pxi @@ -638,8 +638,8 @@ cdef _schema_from_arrays(arrays, names, dict metadata, raise ValueError('Must pass names when constructing ' 'from Array objects') if len(names) != K: - raise ValueError("Length of names ({}) does not match " - "length of arrays ({})".format(len(names), K)) + raise ValueError('Length of names ({}) does not match ' + 'length of arrays ({})'.format(len(names), K)) for i in range(K): val = arrays[i] if isinstance(val, (Array, ChunkedArray)): @@ -760,7 +760,7 @@ cdef class RecordBatch: def column(self, i): """ - Select single column from record batcha + Select single column from record batch Returns ------- @@ -1078,6 +1078,37 @@ cdef class Table: return result + def cast(self, Schema target_schema, bint safe=True): + """ + Cast table values to another schema + + Parameters + ---------- + target_schema : Schema + Schema to cast to, the names and order of fields must match + safe : boolean, default True + Check for overflows or other unsafe conversions + + Returns + ------- + casted : Table + """ + cdef: + Column column, casted + Field field + list newcols = [] + + if self.schema.names != target_schema.names: + raise ValueError("Target schema's field names are not matching " + "the table's field names: {!r}, {!r}" + .format(self.schema.names, target_schema.names)) + + for column, field in zip(self.itercolumns(), target_schema): + casted = column.cast(field.type, safe=safe) + newcols.append(casted) + + return Table.from_arrays(newcols, schema=target_schema) + @classmethod def from_pandas(cls, df, Schema schema=None, bint preserve_index=True, nthreads=None, columns=None, bint safe=True): diff --git a/python/pyarrow/tests/test_array.py b/python/pyarrow/tests/test_array.py index d4b582e..0002dce 100644 --- a/python/pyarrow/tests/test_array.py +++ b/python/pyarrow/tests/test_array.py @@ -549,6 +549,38 @@ def test_cast_integers_unsafe(): _check_cast_case(case, safe=False) +def test_floating_point_truncate_safe(): + safe_cases = [ + (np.array([1.0, 2.0, 3.0], dtype='float32'), 'float32', + np.array([1, 2, 3], dtype='i4'), pa.int32()), + (np.array([1.0, 2.0, 3.0], dtype='float64'), 'float64', + np.array([1, 2, 3], dtype='i4'), pa.int32()), + (np.array([-10.0, 20.0, -30.0], dtype='float64'), 'float64', + np.array([-10, 20, -30], dtype='i4'), pa.int32()), + ] + for case in safe_cases: + _check_cast_case(case, safe=True) + + +def test_floating_point_truncate_unsafe(): + unsafe_cases = [ + (np.array([1.1, 2.2, 3.3], dtype='float32'), 'float32', + np.array([1, 2, 3], dtype='i4'), pa.int32()), + (np.array([1.1, 2.2, 3.3], dtype='float64'), 'float64', + np.array([1, 2, 3], dtype='i4'), pa.int32()), + (np.array([-10.1, 20.2, -30.3], dtype='float64'), 'float64', + np.array([-10, 20, -30], dtype='i4'), pa.int32()), + ] + for case in unsafe_cases: + # test safe casting raises + with pytest.raises(pa.ArrowInvalid, + match='Floating point value truncated'): + _check_cast_case(case, safe=True) + + # test unsafe casting truncates + _check_cast_case(case, safe=False) + + def test_cast_timestamp_unit(): # ARROW-1680 val = datetime.datetime.now() diff --git a/python/pyarrow/tests/test_table.py b/python/pyarrow/tests/test_table.py index 14609ad..f45e918 100644 --- a/python/pyarrow/tests/test_table.py +++ b/python/pyarrow/tests/test_table.py @@ -749,3 +749,85 @@ def test_table_negative_indexing(): with pytest.raises(IndexError): table[4] + + +def test_table_cast_to_incompatible_schema(): + data = [ + pa.array(range(5)), + pa.array([-10, -5, 0, 5, 10]), + ] + table = pa.Table.from_arrays(data, names=tuple('ab')) + + target_schema1 = pa.schema([ + pa.field('A', pa.int32()), + pa.field('b', pa.int16()), + ]) + target_schema2 = pa.schema([ + pa.field('a', pa.int32()), + ]) + message = ("Target schema's field names are not matching the table's " + "field names:.*") + with pytest.raises(ValueError, match=message): + table.cast(target_schema1) + with pytest.raises(ValueError, match=message): + table.cast(target_schema2) + + +def test_table_safe_casting(): + data = [ + pa.array(range(5), type=pa.int64()), + pa.array([-10, -5, 0, 5, 10], type=pa.int32()), + pa.array([1.0, 2.0, 3.0], type=pa.float64()), + pa.array(['ab', 'bc', 'cd'], type=pa.string()) + ] + table = pa.Table.from_arrays(data, names=tuple('abcd')) + + expected_data = [ + pa.array(range(5), type=pa.int32()), + pa.array([-10, -5, 0, 5, 10], type=pa.int16()), + pa.array([1, 2, 3], type=pa.int64()), + pa.array(['ab', 'bc', 'cd'], type=pa.string()) + ] + expected_table = pa.Table.from_arrays(expected_data, names=tuple('abcd')) + + target_schema = pa.schema([ + pa.field('a', pa.int32()), + pa.field('b', pa.int16()), + pa.field('c', pa.int64()), + pa.field('d', pa.string()) + ]) + casted_table = table.cast(target_schema) + + assert casted_table.equals(expected_table) + + +def test_table_unsafe_casting(): + data = [ + pa.array(range(5), type=pa.int64()), + pa.array([-10, -5, 0, 5, 10], type=pa.int32()), + pa.array([1.1, 2.2, 3.3], type=pa.float64()), + pa.array(['ab', 'bc', 'cd'], type=pa.string()) + ] + table = pa.Table.from_arrays(data, names=tuple('abcd')) + + expected_data = [ + pa.array(range(5), type=pa.int32()), + pa.array([-10, -5, 0, 5, 10], type=pa.int16()), + pa.array([1, 2, 3], type=pa.int64()), + pa.array(['ab', 'bc', 'cd'], type=pa.string()) + ] + expected_table = pa.Table.from_arrays(expected_data, names=tuple('abcd')) + + target_schema = pa.schema([ + pa.field('a', pa.int32()), + pa.field('b', pa.int16()), + pa.field('c', pa.int64()), + pa.field('d', pa.string()) + ]) + + with pytest.raises(pa.ArrowInvalid, + match='Floating point value truncated'): + table.cast(target_schema) + + casted_table = table.cast(target_schema, safe=False) + assert casted_table.equals(expected_table)