This is an automated email from the ASF dual-hosted git repository.

apitrou pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow.git


The following commit(s) were added to refs/heads/main by this push:
     new 241b99acbb GH-37411: [C++][Python] Add string -> date cast kernel (fix python scalar cast) (#38038)
241b99acbb is described below

commit 241b99acbb97df93cfa02429cc712391a37de5e3
Author: Joris Van den Bossche <[email protected]>
AuthorDate: Tue Oct 10 19:44:49 2023 +0200

    GH-37411: [C++][Python] Add string -> date cast kernel (fix python scalar cast) (#38038)
    
    ### Rationale for this change
    
    Adding `string -> date32/date64` cast kernels, which then also fixes the pyarrow scalar cast method (which was earlier refactored to rely on the general cast kernels).
    
    * Closes: #37411
    
    Authored-by: Joris Van den Bossche <[email protected]>
    Signed-off-by: Antoine Pitrou <[email protected]>
---
 .../arrow/compute/kernels/scalar_cast_numeric.cc   |  3 +-
 .../arrow/compute/kernels/scalar_cast_temporal.cc  | 49 ++++++++++++++++++
 cpp/src/arrow/compute/kernels/scalar_cast_test.cc  | 21 +++++++-
 cpp/src/arrow/util/value_parsing.h                 | 58 +++++++++++-----------
 python/pyarrow/tests/test_scalars.py               |  7 +++
 5 files changed, 106 insertions(+), 32 deletions(-)

diff --git a/cpp/src/arrow/compute/kernels/scalar_cast_numeric.cc b/cpp/src/arrow/compute/kernels/scalar_cast_numeric.cc
index a02f83351b..b054e57f04 100644
--- a/cpp/src/arrow/compute/kernels/scalar_cast_numeric.cc
+++ b/cpp/src/arrow/compute/kernels/scalar_cast_numeric.cc
@@ -286,7 +286,8 @@ struct ParseString {
 };
 
 template <typename O, typename I>
-struct CastFunctor<O, I, enable_if_base_binary<I>> {
+struct CastFunctor<
+    O, I, enable_if_t<(is_number_type<O>::value && 
is_base_binary_type<I>::value)>> {
   static Status Exec(KernelContext* ctx, const ExecSpan& batch, ExecResult* 
out) {
     return applicator::ScalarUnaryNotNull<O, I, ParseString<O>>::Exec(ctx, 
batch, out);
   }
diff --git a/cpp/src/arrow/compute/kernels/scalar_cast_temporal.cc b/cpp/src/arrow/compute/kernels/scalar_cast_temporal.cc
index 50d24ecab0..a561264391 100644
--- a/cpp/src/arrow/compute/kernels/scalar_cast_temporal.cc
+++ b/cpp/src/arrow/compute/kernels/scalar_cast_temporal.cc
@@ -30,6 +30,7 @@
 namespace arrow {
 
 using internal::ParseTimestampISO8601;
+using internal::ParseYYYY_MM_DD;
 
 namespace compute {
 namespace internal {
@@ -451,6 +452,44 @@ struct CastFunctor<TimestampType, I, 
enable_if_t<is_base_binary_type<I>::value>>
   }
 };
 
+template <typename DateType>
+struct ParseDate {
+  using value_type = typename DateType::c_type;
+
+  using duration_type =
+      typename std::conditional<std::is_same<DateType, Date32Type>::value,
+                                arrow_vendored::date::days,
+                                std::chrono::milliseconds>::type;
+
+  template <typename OutValue, typename Arg0Value>
+  OutValue Call(KernelContext* ctx, Arg0Value val, Status* st) const {
+    OutValue result = OutValue(0);
+
+    if (ARROW_PREDICT_FALSE(val.size() != 10)) {
+      *st = Status::Invalid("Failed to parse string: '", val, "' as a scalar 
of type ",
+                            
TypeTraits<DateType>::type_singleton()->ToString());
+      return result;
+    }
+
+    duration_type since_epoch;
+    if (ARROW_PREDICT_FALSE(!ParseYYYY_MM_DD(val.data(), &since_epoch))) {
+      *st = Status::Invalid("Failed to parse string: '", val, "' as a scalar 
of type ",
+                            
TypeTraits<DateType>::type_singleton()->ToString());
+    } else {
+      result = static_cast<value_type>(since_epoch.count());
+    }
+    return result;
+  }
+};
+
+template <typename O, typename I>
+struct CastFunctor<O, I,
+                   enable_if_t<(is_date_type<O>::value && 
is_string_type<I>::value)>> {
+  static Status Exec(KernelContext* ctx, const ExecSpan& batch, ExecResult* 
out) {
+    return applicator::ScalarUnaryNotNull<O, I, ParseDate<O>>::Exec(ctx, 
batch, out);
+  }
+};
+
 template <typename Type>
 void AddCrossUnitCast(CastFunction* func) {
   ScalarKernel kernel;
@@ -483,6 +522,11 @@ std::shared_ptr<CastFunction> GetDate32Cast() {
   // timestamp -> date32
   AddSimpleCast<TimestampType, Date32Type>(InputType(Type::TIMESTAMP), 
date32(),
                                            func.get());
+
+  // string -> date32
+  AddSimpleCast<StringType, Date32Type>(utf8(), date32(), func.get());
+  AddSimpleCast<LargeStringType, Date32Type>(large_utf8(), date32(), 
func.get());
+
   return func;
 }
 
@@ -500,6 +544,11 @@ std::shared_ptr<CastFunction> GetDate64Cast() {
   // timestamp -> date64
   AddSimpleCast<TimestampType, Date64Type>(InputType(Type::TIMESTAMP), 
date64(),
                                            func.get());
+
+  // string -> date64
+  AddSimpleCast<StringType, Date64Type>(utf8(), date64(), func.get());
+  AddSimpleCast<LargeStringType, Date64Type>(large_utf8(), date64(), 
func.get());
+
   return func;
 }
 
diff --git a/cpp/src/arrow/compute/kernels/scalar_cast_test.cc b/cpp/src/arrow/compute/kernels/scalar_cast_test.cc
index 57cd3ad5ed..c84125bbdd 100644
--- a/cpp/src/arrow/compute/kernels/scalar_cast_test.cc
+++ b/cpp/src/arrow/compute/kernels/scalar_cast_test.cc
@@ -211,8 +211,8 @@ TEST(Cast, CanCast) {
     ExpectCannotCast(from_base_binary, {null()});
   }
 
-  ExpectCanCast(utf8(), {timestamp(TimeUnit::MILLI)});
-  ExpectCanCast(large_utf8(), {timestamp(TimeUnit::NANO)});
+  ExpectCanCast(utf8(), {timestamp(TimeUnit::MILLI), date32(), date64()});
+  ExpectCanCast(large_utf8(), {timestamp(TimeUnit::NANO), date32(), date64()});
   ExpectCannotCast(timestamp(TimeUnit::MICRO),
                    {binary(), large_binary()});  // no formatting supported
 
@@ -2016,6 +2016,23 @@ TEST(Cast, StringToTimestamp) {
   }
 }
 
+TEST(Cast, StringToDate) {
+  for (auto string_type : {utf8(), large_utf8()}) {
+    auto strings = ArrayFromJSON(string_type, R"(["1970-01-01", null, 
"2000-02-29"])");
+
+    CheckCast(strings, ArrayFromJSON(date32(), "[0, null, 11016]"));
+    CheckCast(strings, ArrayFromJSON(date64(), "[0, null, 951782400000]"));
+
+    for (auto date_type : {date32(), date64()}) {
+      for (std::string not_ts : {"", "2012-01-xx", "2012-01-01 09:00:00"}) {
+        auto options = CastOptions::Safe(date_type);
+        CheckCastFails(ArrayFromJSON(string_type, "[\"" + not_ts + "\"]"), 
options);
+      }
+    }
+    // NOTE: YYYY-MM-DD parsing is tested comprehensively in 
value_parsing_test.cc
+  }
+}
+
 static void AssertBinaryZeroCopy(std::shared_ptr<Array> lhs, 
std::shared_ptr<Array> rhs) {
   // null bitmap and data buffers are always zero-copied
   AssertBufferSame(*lhs, *rhs, 0);
diff --git a/cpp/src/arrow/util/value_parsing.h b/cpp/src/arrow/util/value_parsing.h
index d4bbf20665..b3c711840f 100644
--- a/cpp/src/arrow/util/value_parsing.h
+++ b/cpp/src/arrow/util/value_parsing.h
@@ -443,33 +443,6 @@ namespace detail {
 
 using ts_type = TimestampType::c_type;
 
-template <typename Duration>
-static inline bool ParseYYYY_MM_DD(const char* s, Duration* since_epoch) {
-  uint16_t year = 0;
-  uint8_t month = 0;
-  uint8_t day = 0;
-  if (ARROW_PREDICT_FALSE(s[4] != '-') || ARROW_PREDICT_FALSE(s[7] != '-')) {
-    return false;
-  }
-  if (ARROW_PREDICT_FALSE(!ParseUnsigned(s + 0, 4, &year))) {
-    return false;
-  }
-  if (ARROW_PREDICT_FALSE(!ParseUnsigned(s + 5, 2, &month))) {
-    return false;
-  }
-  if (ARROW_PREDICT_FALSE(!ParseUnsigned(s + 8, 2, &day))) {
-    return false;
-  }
-  arrow_vendored::date::year_month_day ymd{arrow_vendored::date::year{year},
-                                           arrow_vendored::date::month{month},
-                                           arrow_vendored::date::day{day}};
-  if (ARROW_PREDICT_FALSE(!ymd.ok())) return false;
-
-  *since_epoch = std::chrono::duration_cast<Duration>(
-      arrow_vendored::date::sys_days{ymd}.time_since_epoch());
-  return true;
-}
-
 template <typename Duration>
 static inline bool ParseHH(const char* s, Duration* out) {
   uint8_t hours = 0;
@@ -641,6 +614,33 @@ static inline bool ParseSubSeconds(const char* s, size_t 
length, TimeUnit::type
 
 }  // namespace detail
 
+template <typename Duration>
+static inline bool ParseYYYY_MM_DD(const char* s, Duration* since_epoch) {
+  uint16_t year = 0;
+  uint8_t month = 0;
+  uint8_t day = 0;
+  if (ARROW_PREDICT_FALSE(s[4] != '-') || ARROW_PREDICT_FALSE(s[7] != '-')) {
+    return false;
+  }
+  if (ARROW_PREDICT_FALSE(!ParseUnsigned(s + 0, 4, &year))) {
+    return false;
+  }
+  if (ARROW_PREDICT_FALSE(!ParseUnsigned(s + 5, 2, &month))) {
+    return false;
+  }
+  if (ARROW_PREDICT_FALSE(!ParseUnsigned(s + 8, 2, &day))) {
+    return false;
+  }
+  arrow_vendored::date::year_month_day ymd{arrow_vendored::date::year{year},
+                                           arrow_vendored::date::month{month},
+                                           arrow_vendored::date::day{day}};
+  if (ARROW_PREDICT_FALSE(!ymd.ok())) return false;
+
+  *since_epoch = std::chrono::duration_cast<Duration>(
+      arrow_vendored::date::sys_days{ymd}.time_since_epoch());
+  return true;
+}
+
 static inline bool ParseTimestampISO8601(const char* s, size_t length,
                                          TimeUnit::type unit, 
TimestampType::c_type* out,
                                          bool* out_zone_offset_present = 
NULLPTR) {
@@ -672,7 +672,7 @@ static inline bool ParseTimestampISO8601(const char* s, 
size_t length,
   if (ARROW_PREDICT_FALSE(length < 10)) return false;
 
   seconds_type seconds_since_epoch;
-  if (ARROW_PREDICT_FALSE(!detail::ParseYYYY_MM_DD(s, &seconds_since_epoch))) {
+  if (ARROW_PREDICT_FALSE(!ParseYYYY_MM_DD(s, &seconds_since_epoch))) {
     return false;
   }
 
@@ -843,7 +843,7 @@ struct StringConverter<DATE_TYPE, 
enable_if_date<DATE_TYPE>> {
     }
 
     duration_type since_epoch;
-    if (ARROW_PREDICT_FALSE(!detail::ParseYYYY_MM_DD(s, &since_epoch))) {
+    if (ARROW_PREDICT_FALSE(!ParseYYYY_MM_DD(s, &since_epoch))) {
       return false;
     }
 
diff --git a/python/pyarrow/tests/test_scalars.py b/python/pyarrow/tests/test_scalars.py
index 1d8d77f50d..d7585d1415 100644
--- a/python/pyarrow/tests/test_scalars.py
+++ b/python/pyarrow/tests/test_scalars.py
@@ -352,6 +352,13 @@ def test_cast_int_to_float():
         int_scalar.cast(pa.float64())  # verify default is safe cast
 
 
+@pytest.mark.parametrize("typ", [pa.date32(), pa.date64()])
+def test_cast_string_to_date(typ):
+    scalar = pa.scalar('2021-01-01')
+    result = scalar.cast(typ)
+    assert result == pa.scalar(datetime.date(2021, 1, 1), type=typ)
+
+
 @pytest.mark.pandas
 def test_timestamp():
     import pandas as pd

Reply via email to