Author: Aaron En Ye Shi Date: 2020-11-03T18:40:26Z New Revision: ca5b31502c828f8e7160a77f54a5a131dc298005
URL: https://github.com/llvm/llvm-project/commit/ca5b31502c828f8e7160a77f54a5a131dc298005 DIFF: https://github.com/llvm/llvm-project/commit/ca5b31502c828f8e7160a77f54a5a131dc298005.diff LOG: [HIP] Math Headers to use type promotion Similar to libcxx implementation of cmath function overloads, use type promotion templates to determine return types of multi-argument math functions. Fixes: SWDEV-256825 Reviewed By: tra, yaxunl Differential Revision: https://reviews.llvm.org/D90409 Added: Modified: clang/lib/Headers/__clang_hip_cmath.h Removed: ################################################################################ diff --git a/clang/lib/Headers/__clang_hip_cmath.h b/clang/lib/Headers/__clang_hip_cmath.h index fea799ead32f..00519a9795bc 100644 --- a/clang/lib/Headers/__clang_hip_cmath.h +++ b/clang/lib/Headers/__clang_hip_cmath.h @@ -16,6 +16,8 @@ #if defined(__cplusplus) #include <limits> +#include <type_traits> +#include <utility> #endif #include <limits.h> #include <stdint.h> @@ -205,6 +207,72 @@ template <bool __B, class __T = void> struct __hip_enable_if {}; template <class __T> struct __hip_enable_if<true, __T> { typedef __T type; }; +// decltype is only available in C++11 and above. 
+#if __cplusplus >= 201103L +// __hip_promote +namespace __hip { + +template <class _Tp> struct __numeric_type { + static void __test(...); + static _Float16 __test(_Float16); + static float __test(float); + static double __test(char); + static double __test(int); + static double __test(unsigned); + static double __test(long); + static double __test(unsigned long); + static double __test(long long); + static double __test(unsigned long long); + static double __test(double); + + typedef decltype(__test(std::declval<_Tp>())) type; + static const bool value = !std::is_same<type, void>::value; +}; + +template <> struct __numeric_type<void> { static const bool value = true; }; + +template <class _A1, class _A2 = void, class _A3 = void, + bool = __numeric_type<_A1>::value &&__numeric_type<_A2>::value + &&__numeric_type<_A3>::value> +class __promote_imp { +public: + static const bool value = false; +}; + +template <class _A1, class _A2, class _A3> +class __promote_imp<_A1, _A2, _A3, true> { +private: + typedef typename __promote_imp<_A1>::type __type1; + typedef typename __promote_imp<_A2>::type __type2; + typedef typename __promote_imp<_A3>::type __type3; + +public: + typedef decltype(__type1() + __type2() + __type3()) type; + static const bool value = true; +}; + +template <class _A1, class _A2> class __promote_imp<_A1, _A2, void, true> { +private: + typedef typename __promote_imp<_A1>::type __type1; + typedef typename __promote_imp<_A2>::type __type2; + +public: + typedef decltype(__type1() + __type2()) type; + static const bool value = true; +}; + +template <class _A1> class __promote_imp<_A1, void, void, true> { +public: + typedef typename __numeric_type<_A1>::type type; + static const bool value = true; +}; + +template <class _A1, class _A2 = void, class _A3 = void> +class __promote : public __promote_imp<_A1, _A2, _A3> {}; + +} // namespace __hip +#endif //__cplusplus >= 201103L + // __HIP_OVERLOAD1 is used to resolve function calls with integer argument to // 
avoid compilation error due to ambiguity. e.g. floor(5) is resolved with // floor(double). @@ -219,6 +287,18 @@ template <class __T> struct __hip_enable_if<true, __T> { typedef __T type; }; // __HIP_OVERLOAD2 is used to resolve function calls with mixed float/double // or integer argument to avoid compilation error due to ambiguity. e.g. // max(5.0f, 6.0) is resolved with max(double, double). +#if __cplusplus >= 201103L +#define __HIP_OVERLOAD2(__retty, __fn) \ + template <typename __T1, typename __T2> \ + __DEVICE__ typename __hip_enable_if< \ + std::numeric_limits<__T1>::is_specialized && \ + std::numeric_limits<__T2>::is_specialized, \ + typename __hip::__promote<__T1, __T2>::type>::type \ + __fn(__T1 __x, __T2 __y) { \ + typedef typename __hip::__promote<__T1, __T2>::type __result_type; \ + return __fn((__result_type)__x, (__result_type)__y); \ + } +#else #define __HIP_OVERLOAD2(__retty, __fn) \ template <typename __T1, typename __T2> \ __DEVICE__ \ @@ -228,6 +308,7 @@ template <class __T> struct __hip_enable_if<true, __T> { typedef __T type; }; __fn(__T1 __x, __T2 __y) { \ return __fn((double)__x, (double)__y); \ } +#endif __HIP_OVERLOAD1(double, abs) __HIP_OVERLOAD1(double, acos) @@ -296,6 +377,18 @@ __HIP_OVERLOAD2(double, max) __HIP_OVERLOAD2(double, min) // Additional Overloads that don't quite match HIP_OVERLOAD. 
+#if __cplusplus >= 201103L +template <typename __T1, typename __T2, typename __T3> +__DEVICE__ typename __hip_enable_if< + std::numeric_limits<__T1>::is_specialized && + std::numeric_limits<__T2>::is_specialized && + std::numeric_limits<__T3>::is_specialized, + typename __hip::__promote<__T1, __T2, __T3>::type>::type +fma(__T1 __x, __T2 __y, __T3 __z) { + typedef typename __hip::__promote<__T1, __T2, __T3>::type __result_type; + return ::fma((__result_type)__x, (__result_type)__y, (__result_type)__z); +} +#else template <typename __T1, typename __T2, typename __T3> __DEVICE__ typename __hip_enable_if<std::numeric_limits<__T1>::is_specialized && @@ -305,6 +398,7 @@ __DEVICE__ fma(__T1 __x, __T2 __y, __T3 __z) { return ::fma((double)__x, (double)__y, (double)__z); } +#endif template <typename __T> __DEVICE__ @@ -327,6 +421,17 @@ __DEVICE__ return ::modf((double)__x, __exp); } +#if __cplusplus >= 201103L +template <typename __T1, typename __T2> +__DEVICE__ + typename __hip_enable_if<std::numeric_limits<__T1>::is_specialized && + std::numeric_limits<__T2>::is_specialized, + typename __hip::__promote<__T1, __T2>::type>::type + remquo(__T1 __x, __T2 __y, int *__quo) { + typedef typename __hip::__promote<__T1, __T2>::type __result_type; + return ::remquo((__result_type)__x, (__result_type)__y, __quo); +} +#else template <typename __T1, typename __T2> __DEVICE__ typename __hip_enable_if<std::numeric_limits<__T1>::is_specialized && @@ -335,6 +440,7 @@ __DEVICE__ remquo(__T1 __x, __T2 __y, int *__quo) { return ::remquo((double)__x, (double)__y, __quo); } +#endif template <typename __T> __DEVICE__ _______________________________________________ cfe-commits mailing list cfe-commits@lists.llvm.org https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits