This is an automated email from the ASF dual-hosted git repository. masahi pushed a commit to branch main in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push: new 0b2358c2e4 [Relay] make "ToScalar" support directly obtaining "int64_t" (#16324) 0b2358c2e4 is described below commit 0b2358c2e4656648f726d4a16507ef2513451ad5 Author: mawnja <190936...@qq.com> AuthorDate: Thu Jan 11 03:21:25 2024 +0800 [Relay] make "ToScalar" support directly obtaining "int64_t" (#16324) Because "long double" on Windows is only 64 bits wide (identical to "double"), unlike Linux where it occupies 128 bits and has a wider mantissa, round-tripping an "int64_t" value through "long double" can lose precision or overflow there; this change lets callers obtain an "int64_t" directly instead. Co-authored-by: wenjian.ma <wenjian...@denglin.ai> --- src/relay/transforms/pattern_utils.h | 43 ++++++++++++++++++++--------------------- src/relay/transforms/simplify_expr.cc | 2 +- 2 files changed, 26 insertions(+), 19 deletions(-) diff --git a/src/relay/transforms/pattern_utils.h b/src/relay/transforms/pattern_utils.h index 50c2e00298..b26bd76496 100644 --- a/src/relay/transforms/pattern_utils.h +++ b/src/relay/transforms/pattern_utils.h @@ -468,43 +468,43 @@ inline bool IsEqualScalar(const Expr& a, const Expr& b) { * \param i element index * \return Converted scalar value, or None if conversion failed */ -static inline std::optional<long double> TryToScalar(const runtime::NDArray& array, size_t i = 0) { +template <typename T> +static inline std::optional<T> TryToScalar(const runtime::NDArray& array, size_t i = 0) { if (array->dtype.code == kDLInt) { if (array->dtype.bits == 8) { - return std::optional<long double>(reinterpret_cast<int8_t*>(array->data)[i]); + return std::optional<T>(reinterpret_cast<int8_t*>(array->data)[i]); } else if (array->dtype.bits == 16) { - return std::optional<long double>(reinterpret_cast<int16_t*>(array->data)[i]); + return std::optional<T>(reinterpret_cast<int16_t*>(array->data)[i]); } else if (array->dtype.bits == 32) { - return std::optional<long double>(reinterpret_cast<int32_t*>(array->data)[i]); + return std::optional<T>(reinterpret_cast<int32_t*>(array->data)[i]); } else if (array->dtype.bits == 64) { - return std::optional<long 
double>(reinterpret_cast<int64_t*>(array->data)[i]); + return std::optional<T>(reinterpret_cast<int64_t*>(array->data)[i]); } } else if (array->dtype.code == kDLUInt) { if (array->dtype.bits == 1) { // bool - return std::optional<long double>(reinterpret_cast<uint8_t*>(array->data)[i]); + return std::optional<T>(reinterpret_cast<uint8_t*>(array->data)[i]); } else if (array->dtype.bits == 8) { - return std::optional<long double>(reinterpret_cast<uint8_t*>(array->data)[i]); + return std::optional<T>(reinterpret_cast<uint8_t*>(array->data)[i]); } else if (array->dtype.bits == 16) { - return std::optional<long double>(reinterpret_cast<uint16_t*>(array->data)[i]); + return std::optional<T>(reinterpret_cast<uint16_t*>(array->data)[i]); } else if (array->dtype.bits == 32) { - return std::optional<long double>(reinterpret_cast<uint32_t*>(array->data)[i]); + return std::optional<T>(reinterpret_cast<uint32_t*>(array->data)[i]); } else if (array->dtype.bits == 64) { - return std::optional<long double>(reinterpret_cast<uint64_t*>(array->data)[i]); + return std::optional<T>(reinterpret_cast<uint64_t*>(array->data)[i]); } } else if (array->dtype.code == kDLFloat) { if (array->dtype.bits == 16) { - return std::optional<long double>( - __extendXfYf2__<uint16_t, uint16_t, 10, float, uint32_t, 23>( - reinterpret_cast<uint16_t*>(array->data)[i])); + return std::optional<T>(__extendXfYf2__<uint16_t, uint16_t, 10, float, uint32_t, 23>( + reinterpret_cast<uint16_t*>(array->data)[i])); } if (array->dtype.bits == 32) { - return std::optional<long double>(reinterpret_cast<float*>(array->data)[i]); + return std::optional<T>(reinterpret_cast<float*>(array->data)[i]); } else if (array->dtype.bits == 64) { - return std::optional<long double>(reinterpret_cast<double*>(array->data)[i]); + return std::optional<T>(reinterpret_cast<double*>(array->data)[i]); } } else if (array->dtype.code == kDLBfloat) { if (array->dtype.bits == 16) { - return std::optional<long double>(__extendXfYf2__<uint16_t, 
uint16_t, 7, float, uint32_t, 23>( + return std::optional<T>(__extendXfYf2__<uint16_t, uint16_t, 7, float, uint32_t, 23>( reinterpret_cast<uint16_t*>(array->data)[i])); } } @@ -517,8 +517,15 @@ static inline std::optional<long double> TryToScalar(const runtime::NDArray& arr * \param i element index * \return Converted scalar value */ +template <typename T> +static inline T ToScalar(const runtime::NDArray& array, size_t i = 0) { + auto try_value = TryToScalar<T>(array, i); + ICHECK(try_value) << "Unknown data type: " << tvm::runtime::DLDataType2String(array->dtype); + return try_value.value(); +} + static inline long double ToScalar(const runtime::NDArray& array, size_t i = 0) { - auto try_value = TryToScalar(array, i); + auto try_value = TryToScalar<long double>(array, i); ICHECK(try_value) << "Unknown data type: " << tvm::runtime::DLDataType2String(array->dtype); return try_value.value(); } @@ -534,7 +541,7 @@ static inline Array<Integer> ToVector(const runtime::NDArray& array) { size_t len = array.Shape().front(); Array<Integer> out; for (size_t i = 0; i < len; ++i) { - long double elem_val = ToScalar(array, i); + uint64_t elem_val = ToScalar<uint64_t>(array, i); out.push_back(Integer(IntImm(DataType::Int(32), static_cast<int64_t>(elem_val)))); } return out; diff --git a/src/relay/transforms/simplify_expr.cc b/src/relay/transforms/simplify_expr.cc index 208c9821b6..8036d301e1 100644 --- a/src/relay/transforms/simplify_expr.cc +++ b/src/relay/transforms/simplify_expr.cc @@ -794,7 +794,7 @@ class EliminateIdentityRewrite : public DFPatternRewrite { if (!IsScalar(GetRef<Expr>(constant))) { return false; } - auto value = TryToScalar(constant->data, 0); + auto value = TryToScalar<long double>(constant->data, 0); if (!value) { // unsupported dtype return false;