This is an automated email from the ASF dual-hosted git repository.

masahi pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 0b2358c2e4 [Relay] make "ToScalar" support directly obtaining 
"int64_t" (#16324)
0b2358c2e4 is described below

commit 0b2358c2e4656648f726d4a16507ef2513451ad5
Author: mawnja <190936...@qq.com>
AuthorDate: Thu Jan 11 03:21:25 2024 +0800

    [Relay] make "ToScalar" support directly obtaining "int64_t" (#16324)
    
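    On Windows (MSVC), "long double" is only 64 bits wide (the same as
    "double"), whereas on x86-64 Linux it is an extended-precision type stored
    in 128 bits. Converting an "int64_t" through "long double" can therefore
    lose precision or overflow on Windows, so let "ToScalar" return "int64_t"
    directly.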
    
    Co-authored-by: wenjian.ma <wenjian...@denglin.ai>
---
 src/relay/transforms/pattern_utils.h  | 43 ++++++++++++++++++++---------------
 src/relay/transforms/simplify_expr.cc |  2 +-
 2 files changed, 26 insertions(+), 19 deletions(-)

diff --git a/src/relay/transforms/pattern_utils.h 
b/src/relay/transforms/pattern_utils.h
index 50c2e00298..b26bd76496 100644
--- a/src/relay/transforms/pattern_utils.h
+++ b/src/relay/transforms/pattern_utils.h
@@ -468,43 +468,43 @@ inline bool IsEqualScalar(const Expr& a, const Expr& b) {
  * \param i element index
  * \return Converted scalar value, or None if conversion failed
  */
-static inline std::optional<long double> TryToScalar(const runtime::NDArray& 
array, size_t i = 0) {
+template <typename T>
+static inline std::optional<T> TryToScalar(const runtime::NDArray& array, 
size_t i = 0) {
   if (array->dtype.code == kDLInt) {
     if (array->dtype.bits == 8) {
-      return std::optional<long 
double>(reinterpret_cast<int8_t*>(array->data)[i]);
+      return std::optional<T>(reinterpret_cast<int8_t*>(array->data)[i]);
     } else if (array->dtype.bits == 16) {
-      return std::optional<long 
double>(reinterpret_cast<int16_t*>(array->data)[i]);
+      return std::optional<T>(reinterpret_cast<int16_t*>(array->data)[i]);
     } else if (array->dtype.bits == 32) {
-      return std::optional<long 
double>(reinterpret_cast<int32_t*>(array->data)[i]);
+      return std::optional<T>(reinterpret_cast<int32_t*>(array->data)[i]);
     } else if (array->dtype.bits == 64) {
-      return std::optional<long 
double>(reinterpret_cast<int64_t*>(array->data)[i]);
+      return std::optional<T>(reinterpret_cast<int64_t*>(array->data)[i]);
     }
   } else if (array->dtype.code == kDLUInt) {
     if (array->dtype.bits == 1) {  // bool
-      return std::optional<long 
double>(reinterpret_cast<uint8_t*>(array->data)[i]);
+      return std::optional<T>(reinterpret_cast<uint8_t*>(array->data)[i]);
     } else if (array->dtype.bits == 8) {
-      return std::optional<long 
double>(reinterpret_cast<uint8_t*>(array->data)[i]);
+      return std::optional<T>(reinterpret_cast<uint8_t*>(array->data)[i]);
     } else if (array->dtype.bits == 16) {
-      return std::optional<long 
double>(reinterpret_cast<uint16_t*>(array->data)[i]);
+      return std::optional<T>(reinterpret_cast<uint16_t*>(array->data)[i]);
     } else if (array->dtype.bits == 32) {
-      return std::optional<long 
double>(reinterpret_cast<uint32_t*>(array->data)[i]);
+      return std::optional<T>(reinterpret_cast<uint32_t*>(array->data)[i]);
     } else if (array->dtype.bits == 64) {
-      return std::optional<long 
double>(reinterpret_cast<uint64_t*>(array->data)[i]);
+      return std::optional<T>(reinterpret_cast<uint64_t*>(array->data)[i]);
     }
   } else if (array->dtype.code == kDLFloat) {
     if (array->dtype.bits == 16) {
-      return std::optional<long double>(
-          __extendXfYf2__<uint16_t, uint16_t, 10, float, uint32_t, 23>(
-              reinterpret_cast<uint16_t*>(array->data)[i]));
+      return std::optional<T>(__extendXfYf2__<uint16_t, uint16_t, 10, float, 
uint32_t, 23>(
+          reinterpret_cast<uint16_t*>(array->data)[i]));
     }
     if (array->dtype.bits == 32) {
-      return std::optional<long 
double>(reinterpret_cast<float*>(array->data)[i]);
+      return std::optional<T>(reinterpret_cast<float*>(array->data)[i]);
     } else if (array->dtype.bits == 64) {
-      return std::optional<long 
double>(reinterpret_cast<double*>(array->data)[i]);
+      return std::optional<T>(reinterpret_cast<double*>(array->data)[i]);
     }
   } else if (array->dtype.code == kDLBfloat) {
     if (array->dtype.bits == 16) {
-      return std::optional<long double>(__extendXfYf2__<uint16_t, uint16_t, 7, 
float, uint32_t, 23>(
+      return std::optional<T>(__extendXfYf2__<uint16_t, uint16_t, 7, float, 
uint32_t, 23>(
           reinterpret_cast<uint16_t*>(array->data)[i]));
     }
   }
@@ -517,8 +517,15 @@ static inline std::optional<long double> TryToScalar(const 
runtime::NDArray& arr
  * \param i element index
  * \return Converted scalar value
  */
+template <typename T>
+static inline T ToScalar(const runtime::NDArray& array, size_t i = 0) {
+  auto try_value = TryToScalar<T>(array, i);
+  ICHECK(try_value) << "Unknown data type: " << 
tvm::runtime::DLDataType2String(array->dtype);
+  return try_value.value();
+}
+
 static inline long double ToScalar(const runtime::NDArray& array, size_t i = 
0) {
-  auto try_value = TryToScalar(array, i);
+  auto try_value = TryToScalar<long double>(array, i);
   ICHECK(try_value) << "Unknown data type: " << 
tvm::runtime::DLDataType2String(array->dtype);
   return try_value.value();
 }
@@ -534,7 +541,7 @@ static inline Array<Integer> ToVector(const 
runtime::NDArray& array) {
   size_t len = array.Shape().front();
   Array<Integer> out;
   for (size_t i = 0; i < len; ++i) {
-    long double elem_val = ToScalar(array, i);
+    uint64_t elem_val = ToScalar<uint64_t>(array, i);
     out.push_back(Integer(IntImm(DataType::Int(32), 
static_cast<int64_t>(elem_val))));
   }
   return out;
diff --git a/src/relay/transforms/simplify_expr.cc 
b/src/relay/transforms/simplify_expr.cc
index 208c9821b6..8036d301e1 100644
--- a/src/relay/transforms/simplify_expr.cc
+++ b/src/relay/transforms/simplify_expr.cc
@@ -794,7 +794,7 @@ class EliminateIdentityRewrite : public DFPatternRewrite {
     if (!IsScalar(GetRef<Expr>(constant))) {
       return false;
     }
-    auto value = TryToScalar(constant->data, 0);
+    auto value = TryToScalar<long double>(constant->data, 0);
     if (!value) {
       // unsupported dtype
       return false;

Reply via email to