From: Arthur Cohen <[email protected]>
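gcc/rust/ChangeLog:

	* expand/rust-derive-ord.cc (DeriveOrd::cmp_call): New function.
	(DeriveOrd::recursive_match): Use it.
	(DeriveOrd::visit_enum): Implement, comparing discriminant values with it.
	* expand/rust-derive-ord.h (DeriveOrd::cmp_call): Declare.
	(DeriveOrd::not_equal): Prefix with '#'.
	(DeriveOrd::self_discr, DeriveOrd::other_discr): New.
	(DeriveOrd::match_enum_tuple, DeriveOrd::match_enum_struct): Declare.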
---
gcc/rust/expand/rust-derive-ord.cc | 68 +++++++++++++++++++++++++-----
gcc/rust/expand/rust-derive-ord.h | 16 ++++++-
2 files changed, 73 insertions(+), 11 deletions(-)
diff --git a/gcc/rust/expand/rust-derive-ord.cc b/gcc/rust/expand/rust-derive-ord.cc
index e39c6b44ca4..1f39c94d87b 100644
--- a/gcc/rust/expand/rust-derive-ord.cc
+++ b/gcc/rust/expand/rust-derive-ord.cc
@@ -39,6 +39,17 @@ DeriveOrd::go (Item &item)
return std::move (expanded);
}
+std::unique_ptr<Expr>
+DeriveOrd::cmp_call (std::unique_ptr<Expr> &&self_expr,
+ std::unique_ptr<Expr> &&other_expr)
+{
+ auto cmp_fn_path = builder.path_in_expression (
+ {"core", "cmp", trait (ordering), fn (ordering)}, true);
+ // Build the call `core::cmp::<trait>::<fn> (self_expr, other_expr)`; no `&` is added.
+ return builder.call (ptrify (cmp_fn_path),
+ vec (std::move (self_expr), std::move (other_expr)));
+}
+
std::unique_ptr<Item>
DeriveOrd::cmp_impl (
std::unique_ptr<BlockExpr> &&fn_block, Identifier type_name,
@@ -132,18 +143,14 @@ DeriveOrd::recursive_match (std::vector<SelfOther> &&members)
{
auto &member = *it;
- auto cmp_fn_path = builder.path_in_expression (
- {"core", "cmp", trait (ordering), fn (ordering)}, true);
-
- auto cmp_call = builder.call (ptrify (cmp_fn_path),
- vec (std::move (member.self_expr),
- std::move (member.other_expr)));
+ auto call = cmp_call (std::move (member.self_expr),
+ std::move (member.other_expr));
// For the last member (so the first iterator), we just create a call
// expression
if (it == members.rbegin ())
{
- final_expr = std::move (cmp_call);
+ final_expr = std::move (call);
continue;
}
@@ -157,8 +164,7 @@ DeriveOrd::recursive_match (std::vector<SelfOther> &&members)
builder.match_case (std::move (match_arms.second),
builder.identifier (DeriveOrd::not_equal))};
- final_expr
- = builder.match (std::move (cmp_call), std::move (match_cases));
+ final_expr = builder.match (std::move (call), std::move (match_cases));
}
return final_expr;
@@ -227,7 +233,49 @@ DeriveOrd::visit_tuple (TupleStruct &item)
// }
void
DeriveOrd::visit_enum (Enum &item)
-{}
+{
+ auto cases = std::vector<MatchCase> ();
+ auto type_name = item.get_identifier ().as_string ();
+
+ auto let_sd = builder.discriminant_value (DeriveOrd::self_discr, "self");
+ auto other_sd = builder.discriminant_value (DeriveOrd::other_discr, "other");
+
+ auto discr_cmp = cmp_call (builder.identifier (DeriveOrd::self_discr),
+ builder.identifier (DeriveOrd::other_discr));
+
+ for (auto &variant : item.get_variants ())
+ {
+ auto variant_path
+ = builder.variant_path (type_name,
+ variant->get_identifier ().as_string ());
+
+ switch (variant->get_enum_item_kind ())
+ {
+ case EnumItem::Kind::Tuple:
+ case EnumItem::Kind::Struct:
+ case EnumItem::Kind::Identifier:
+ case EnumItem::Kind::Discriminant:
+ // FIXME: Identifier/Discriminant variants need nothing more, but Tuple and
+ // Struct fields are not compared yet (see match_enum_tuple/match_enum_struct).
+ break;
+ }
+ }
+
+ // Add the final (for now, only) arm: compare the discriminant values, which
+ // decides the ordering when `self` and `other` are different variants
+ cases.emplace_back (
+ builder.match_case (builder.wildcard (), std::move (discr_cmp)));
+
+ auto match
+ = builder.match (builder.tuple (vec (builder.identifier ("self"),
+ builder.identifier ("other"))),
+ std::move (cases));
+
+ expanded
+ = cmp_impl (builder.block (vec (std::move (let_sd), std::move (other_sd)),
+ std::move (match)),
+ type_name, item.get_generic_params ());
+}
void
DeriveOrd::visit_union (Union &item)
diff --git a/gcc/rust/expand/rust-derive-ord.h b/gcc/rust/expand/rust-derive-ord.h
index 047ebfb0c01..a360dd26d97 100644
--- a/gcc/rust/expand/rust-derive-ord.h
+++ b/gcc/rust/expand/rust-derive-ord.h
@@ -69,7 +69,9 @@ private:
Ordering ordering;
/* Identifier patterns for the non-equal match arms */
- constexpr static const char *not_equal = "non_eq";
+ constexpr static const char *not_equal = "#non_eq";
+ constexpr static const char *self_discr = "#self_discr";
+ constexpr static const char *other_discr = "#other_discr";
/**
* Create the recursive matching structure used when implementing the
@@ -89,6 +91,18 @@ private:
*/
std::pair<MatchArm, MatchArm> make_cmp_arms ();
+ MatchCase match_enum_tuple (PathInExpression variant_path,
+ const EnumItemTuple &variant);
+ MatchCase match_enum_struct (PathInExpression variant_path,
+ const EnumItemStruct &variant);
+
+ /**
+ * Generate a call to the proper trait function, based on the ordering, in
+ * order to compare two given expressions
+ */
+ std::unique_ptr<Expr> cmp_call (std::unique_ptr<Expr> &&self_expr,
+ std::unique_ptr<Expr> &&other_expr);
+
std::unique_ptr<Item>
cmp_impl (std::unique_ptr<BlockExpr> &&fn_block, Identifier type_name,
const std::vector<std::unique_ptr<GenericParam>> &type_generics);
--
2.49.0