From: Arthur Cohen <[email protected]>
gcc/rust/ChangeLog:
* expand/rust-derive-ord.cc (DeriveOrd::make_cmp_arms): New function.
(is_last): Likewise.
(recursive_match): Remove.
(DeriveOrd::recursive_match): New function.
(DeriveOrd::visit_struct): Add proper implementation.
(DeriveOrd::visit_union): Use trait helper.
* expand/rust-derive-ord.h: Declare these new functions.
---
gcc/rust/expand/rust-derive-ord.cc | 90 +++++++++++++++++++++++++++---
gcc/rust/expand/rust-derive-ord.h | 37 +++++++++++-
2 files changed, 116 insertions(+), 11 deletions(-)
diff --git a/gcc/rust/expand/rust-derive-ord.cc
b/gcc/rust/expand/rust-derive-ord.cc
index 7eaaa474d1b..2403e9c2a33 100644
--- a/gcc/rust/expand/rust-derive-ord.cc
+++ b/gcc/rust/expand/rust-derive-ord.cc
@@ -92,15 +92,85 @@ DeriveOrd::cmp_fn (std::unique_ptr<BlockExpr> &&block,
Identifier type_name)
builder.reference_type (ptrify (
builder.type_path (type_name.as_string ())))));
- auto function_name = ordering == Ordering::Partial ? "partial_cmp" : "cmp";
+ auto function_name = fn (ordering);
return builder.function (function_name, std::move (params),
ptrify (return_type), std::move (block));
}
+
+/* Build the two arms of one inner comparison match, as used by
+   recursive_match: the "equal" arm, and a catch-all arm binding any other
+   comparison result to the `non_eq` identifier pattern.  */
+std::pair<MatchArm, MatchArm>
+DeriveOrd::make_cmp_arms ()
+{
+  // All comparison results other than Ordering::Equal
+  auto non_equal = builder.identifier_pattern (DeriveOrd::not_equal);
+
+  std::unique_ptr<Pattern> equal = ptrify (
+    builder.path_in_expression ({"core", "cmp", "Ordering", "Equal"}, true));
+
+  // `partial_cmp` returns an `Option<Ordering>`, so when deriving PartialOrd
+  // the pattern must be wrapped in Option::Some. `cmp` returns a bare
+  // `Ordering` and needs no wrapping.
+  if (ordering == Ordering::Partial)
+    {
+      std::unique_ptr<TupleStructItems> pattern_items
+        = std::make_unique<TupleStructItemsNoRange> (vec (std::move (equal)));
+
+      equal = std::make_unique<TupleStructPattern> (
+        builder.path_in_expression (LangItem::Kind::OPTION_SOME),
+        std::move (pattern_items));
+    }
+
+  return {builder.match_arm (std::move (equal)),
+          builder.match_arm (std::move (non_equal))};
+}
+
+/* Return whether ELT is the last element of ELEMENTS.  This compares
+   addresses, so ELT must refer to an element stored in ELEMENTS itself,
+   not to a copy of one.  ELEMENTS must not be empty.  */
+template <typename T>
+inline bool
+is_last (const T &elt, const std::vector<T> &elements)
+{
+  rust_assert (!elements.empty ());
+
+  return &elt == &elements.back ();
+}
+
std::unique_ptr<Expr>
-recursive_match ()
+DeriveOrd::recursive_match (std::vector<SelfOther> &&members)
{
- return nullptr;
+  // With no members to compare (e.g. a struct without fields), the two values
+  // are always equal: `Ordering::Equal` for `cmp`, `Some (Ordering::Equal)`
+  // for `partial_cmp`. Returning nullptr here instead would hand a null tail
+  // expression to the caller's block.
+  if (members.empty ())
+    {
+      std::unique_ptr<Expr> equal = ptrify (builder.path_in_expression (
+        {"core", "cmp", "Ordering", "Equal"}, true));
+
+      if (ordering == Ordering::Partial)
+        equal = builder.call (ptrify (builder.path_in_expression (
+                                LangItem::Kind::OPTION_SOME)),
+                              vec (std::move (equal)));
+
+      return equal;
+    }
+
+  // The path to the comparison function is the same for every member
+  auto trait_name = trait (ordering);
+  auto fn_name = fn (ordering);
+
+  std::unique_ptr<Expr> final_expr = nullptr;
+
+  // Build the nested match from the inside out: start from the last member
+  // and wrap each previous member's comparison around the expression built
+  // so far, so that the first member ends up being compared first.
+  for (auto it = members.rbegin (); it != members.rend (); ++it)
+    {
+      auto &member = *it;
+
+      auto cmp_fn_path = builder.path_in_expression (
+        {"core", "cmp", trait_name, fn_name}, true);
+
+      auto cmp_call = builder.call (ptrify (cmp_fn_path),
+                                    vec (std::move (member.self_expr),
+                                         std::move (member.other_expr)));
+
+      // The last member (i.e. the first one visited) is compared with a plain
+      // call expression, whose result is the innermost value
+      if (it == members.rbegin ())
+        {
+          final_expr = std::move (cmp_call);
+          continue;
+        }
+
+      // Every other member gets a match on its comparison result: on
+      // equality, continue with the expression built so far; otherwise
+      // yield the non-equal result unchanged
+      auto match_arms = make_cmp_arms ();
+
+      auto match_cases
+        = {builder.match_case (std::move (match_arms.first),
+                               std::move (final_expr)),
+           builder.match_case (std::move (match_arms.second),
+                               builder.identifier (DeriveOrd::not_equal))};
+
+      final_expr
+        = builder.match (std::move (cmp_call), std::move (match_cases));
+    }
+
+  return final_expr;
}
// we need to do a recursive match expression for all of the fields used in a
@@ -128,10 +198,12 @@ recursive_match ()
void
DeriveOrd::visit_struct (StructStruct &item)
{
- // FIXME: Put cmp_fn call inside cmp_impl, pass a block to cmp_impl instead -
- // this avoids repeating the same parameter twice (the type name)
- expanded = cmp_impl (builder.block (), item.get_identifier (),
- item.get_generic_params ());
+  // The comparison function's body is a single expression built from one
+  // self/other pair per field of the struct
+  auto body = builder.block (
+    recursive_match (SelfOther::fields (builder, item.get_fields ())));
+
+  expanded = cmp_impl (std::move (body), item.get_identifier (),
+                       item.get_generic_params ());
}
// same as structs, but for each field index instead of each field name -
@@ -162,10 +234,10 @@ DeriveOrd::visit_enum (Enum &item)
void
DeriveOrd::visit_union (Union &item)
{
- auto trait_name = ordering == Ordering::Total ? "Ord" : "PartialOrd";
- rust_error_at (item.get_locus (), "derive(%s) cannot be used on unions",
- trait_name);
+  // Ordering traits cannot be derived on unions: report it at the union
+  rust_error_at (item.get_locus (), "derive(%s) cannot be used on unions",
+                 trait (ordering).c_str ());
}
} // namespace AST
diff --git a/gcc/rust/expand/rust-derive-ord.h
b/gcc/rust/expand/rust-derive-ord.h
index fae13261e7c..047ebfb0c01 100644
--- a/gcc/rust/expand/rust-derive-ord.h
+++ b/gcc/rust/expand/rust-derive-ord.h
@@ -20,6 +20,7 @@
#define RUST_DERIVE_ORD_H
#include "rust-ast.h"
+#include "rust-derive-cmp-common.h"
#include "rust-derive.h"
namespace Rust {
@@ -42,6 +43,22 @@ public:
Partial
};
+  /* Name of the comparison method being derived: `cmp` for total ordering,
+     `partial_cmp` for partial ordering.  Does not depend on any instance
+     state, hence static.  */
+  static std::string fn (Ordering ordering)
+  {
+    return ordering == Ordering::Total ? "cmp" : "partial_cmp";
+  }
+
+  /* Name of the trait being derived: `Ord` for total ordering, `PartialOrd`
+     for partial ordering.  Does not depend on any instance state, hence
+     static.  */
+  static std::string trait (Ordering ordering)
+  {
+    return ordering == Ordering::Total ? "Ord" : "PartialOrd";
+  }
+
DeriveOrd (Ordering ordering, location_t loc);
std::unique_ptr<Item> go (Item &item);
@@ -51,10 +68,26 @@ private:
Ordering ordering;
+ /* Identifier patterns for the non-equal match arms */
+ constexpr static const char *not_equal = "non_eq";
+
/**
* Create the recursive matching structure used when implementing the
- * comparison function on multiple sub items (fields, tuple indexes...) */
- std::unique_ptr<Expr> recursive_match ();
+ * comparison function on multiple sub items (fields, tuple indexes...)
+ */
+ std::unique_ptr<Expr> recursive_match (std::vector<SelfOther> &&members);
+
+ /**
+ * Make the match arms for one inner match in a comparison function block.
+ * This returns the "equal" match arm and the "rest" match arm, so something
+ * like `Ordering::Equal` and `non_eq` in the following match expression:
+ *
+ * match cmp(...) {
+ * Ordering::Equal => match cmp(...) { ... }
+ * non_eq => non_eq,
+ * }
+ */
+ std::pair<MatchArm, MatchArm> make_cmp_arms ();
std::unique_ptr<Item>
cmp_impl (std::unique_ptr<BlockExpr> &&fn_block, Identifier type_name,
--
2.49.0