On Tue, 29 Sep 2020 at 14:20, Jonathan Wakely <jwak...@redhat.com> wrote:
> I think this is what we want:
>
>    template<typename _Tp, typename... _Types>
>      constexpr inline bool __same_types = (is_same_v<_Tp, _Types> && ...);
>
> is_same_v is very cheap, it uses the built-in directly, so you don't
> need to instantiate any class templates at all.
>
> >+
> >+  template <unsigned long int _Idx, class _Visitor, class _Variant>
>
> typename not class please.
>
> >+    decltype(auto) __check_visitor_result(_Visitor&& __vis,
>
> New line after the decltype(auto) please, not in the middle of the
> parameter list.

Aye.
diff --git a/libstdc++-v3/include/std/variant b/libstdc++-v3/include/std/variant
index dd8847cf829..6f647d622c4 100644
--- a/libstdc++-v3/include/std/variant
+++ b/libstdc++-v3/include/std/variant
@@ -182,7 +182,8 @@ namespace __variant
   // used for raw visitation with indices passed in
   struct __variant_idx_cookie { using type = __variant_idx_cookie; };
   // Used to enable deduction (and same-type checking) for std::visit:
-  template<typename> struct __deduce_visit_result { };
+  // _Tp is the deduced result type, exposed as the nested type.
+  template<typename _Tp> struct __deduce_visit_result { using type = _Tp; };
 
   // Visit variants that might be valueless.
   template<typename _Visitor, typename... _Variants>
@@ -1017,7 +1017,29 @@ namespace __variant
 
       static constexpr auto
       _S_apply()
-      { return _Array_type{&__visit_invoke}; }
+      {
+	// A non-deduced _Result_type need not have a nested ::type, so the
+	// return-type check must only name it when the result is deduced.
+	if constexpr (_Array_type::__result_is_deduced::value)
+	  {
+	    using _Visit_result
+	      = decltype(__visit_invoke(std::declval<_Visitor>(),
+					std::declval<_Variants>()...));
+	    constexpr bool __visit_ret_type_mismatch
+	      = !is_same_v<typename _Result_type::type, _Visit_result>;
+	    static_assert(!__visit_ret_type_mismatch,
+			  "std::visit requires the visitor to have the same "
+			  "return type for all alternatives of a variant");
+	    // Return a dummy instead of taking &__visit_invoke, presumably
+	    // so that the failed assertion is not followed by more errors.
+	    if constexpr (__visit_ret_type_mismatch)
+	      return __nonesuch{};
+	    else
+	      return _Array_type{&__visit_invoke};
+	  }
+	else
+	  return _Array_type{&__visit_invoke};
+      }
     };
 
   template<typename _Result_type, typename _Visitor, typename... _Variants>
@@ -1692,6 +1707,32 @@ namespace __variant
 			   std::forward<_Variants>(__variants)...);
     }
 
+  // True if every type in _Types is the same as _Tp.
+  template<typename _Tp, typename... _Types>
+    constexpr inline bool __same_types = (is_same_v<_Tp, _Types> && ...);
+
+  // Invoke the visitor on alternative _Idx of the variant, with the same
+  // INVOKE semantics std::visit uses (so e.g. pointers to members work).
+  // Intended for unevaluated use, to name that call's return type.
+  template<size_t _Idx, typename _Visitor, typename _Variant>
+    decltype(auto)
+    __check_visitor_result(_Visitor&& __vis, _Variant&& __variant)
+    {
+      return std::__invoke(std::forward<_Visitor>(__vis),
+			   std::get<_Idx>(std::forward<_Variant>(__variant)));
+    }
+
+  // True if the visitor returns the same type for every alternative
+  // _Idxs... of the variant.
+  template<typename _Visitor, typename _Variant, size_t... _Idxs>
+    constexpr bool
+    __check_visitor_results(std::index_sequence<_Idxs...>)
+    {
+      return __same_types<decltype(__check_visitor_result<_Idxs>(
+	std::declval<_Visitor>(),
+	std::declval<_Variant>()))...>;
+    }
+
   template<typename _Visitor, typename... _Variants>
     constexpr decltype(auto)
     visit(_Visitor&& __visitor, _Variants&&... __variants)
@@ -1704,8 +1739,31 @@ namespace __variant
 
       using _Tag = __detail::__variant::__deduce_visit_result<_Result_type>;
 
-      return std::__do_visit<_Tag>(std::forward<_Visitor>(__visitor),
-				   std::forward<_Variants>(__variants)...);
+      // With a single variant, check up front that the visitor returns the
+      // same type for every alternative, so a mismatch is diagnosed by the
+      // static_assert below without instantiating std::__do_visit.
+      if constexpr (sizeof...(_Variants) == 1)
+	{
+	  constexpr bool __visit_rettypes_match =
+	    __check_visitor_results<_Visitor, _Variants...>(
+	      std::make_index_sequence<
+		std::variant_size<remove_reference_t<_Variants>...>::value>());
+	  if constexpr (!__visit_rettypes_match)
+	    {
+	      static_assert(__visit_rettypes_match,
+			    "std::visit requires the visitor to have the same "
+			    "return type for all alternatives of a variant");
+	      return;
+	    }
+	  else
+	    return std::__do_visit<_Tag>(
+	      std::forward<_Visitor>(__visitor),
+	      std::forward<_Variants>(__variants)...);
+	}
+      else
+	return std::__do_visit<_Tag>(
+	  std::forward<_Visitor>(__visitor),
+	  std::forward<_Variants>(__variants)...);
     }
 
 #if __cplusplus > 201703L

Reply via email to