llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT--> @llvm/pr-subscribers-clang Author: Andy Kaylor (andykaylor) <details> <summary>Changes</summary> This adds handling for comparison expressions involving pointers to member functions. --- Full diff: https://github.com/llvm/llvm-project/pull/176029.diff 4 Files Affected: - (modified) clang/lib/CIR/Dialect/Transforms/CXXABILowering.cpp (+10-6) - (modified) clang/lib/CIR/Dialect/Transforms/TargetLowering/CIRCXXABI.h (+4) - (modified) clang/lib/CIR/Dialect/Transforms/TargetLowering/LowerItaniumCXXABI.cpp (+51) - (added) clang/test/CIR/CodeGen/pointer-to-member-func-cmp.cpp (+121) ``````````diff diff --git a/clang/lib/CIR/Dialect/Transforms/CXXABILowering.cpp b/clang/lib/CIR/Dialect/Transforms/CXXABILowering.cpp index 145f8574893f4..dbe656ac011d8 100644 --- a/clang/lib/CIR/Dialect/Transforms/CXXABILowering.cpp +++ b/clang/lib/CIR/Dialect/Transforms/CXXABILowering.cpp @@ -159,12 +159,16 @@ mlir::LogicalResult CIRCmpOpABILowering::matchAndRewrite( cir::CmpOp op, OpAdaptor adaptor, mlir::ConversionPatternRewriter &rewriter) const { mlir::Type type = op.getLhs().getType(); - assert((mlir::isa<cir::DataMemberType>(type)) && - "input to cmp in ABI lowering must be a data member"); - - assert(!cir::MissingFeatures::methodType()); - mlir::Value loweredResult = lowerModule->getCXXABI().lowerDataMemberCmp( - op, adaptor.getLhs(), adaptor.getRhs(), rewriter); + assert((mlir::isa<cir::DataMemberType, cir::MethodType>(type)) && + "input to cmp in ABI lowering must be a data member or method"); + + mlir::Value loweredResult; + if (mlir::isa<cir::DataMemberType>(type)) + loweredResult = lowerModule->getCXXABI().lowerDataMemberCmp( + op, adaptor.getLhs(), adaptor.getRhs(), rewriter); + else + loweredResult = lowerModule->getCXXABI().lowerMethodCmp( + op, adaptor.getLhs(), adaptor.getRhs(), rewriter); rewriter.replaceOp(op, loweredResult); return mlir::success(); diff --git a/clang/lib/CIR/Dialect/Transforms/TargetLowering/CIRCXXABI.h 
b/clang/lib/CIR/Dialect/Transforms/TargetLowering/CIRCXXABI.h index 0dedfa7221f5f..f4d608cdbad03 100644 --- a/clang/lib/CIR/Dialect/Transforms/TargetLowering/CIRCXXABI.h +++ b/clang/lib/CIR/Dialect/Transforms/TargetLowering/CIRCXXABI.h @@ -81,6 +81,10 @@ class CIRCXXABI { virtual mlir::Value lowerDataMemberCmp(cir::CmpOp op, mlir::Value loweredLhs, mlir::Value loweredRhs, mlir::OpBuilder &builder) const = 0; + + virtual mlir::Value lowerMethodCmp(cir::CmpOp op, mlir::Value loweredLhs, + mlir::Value loweredRhs, + mlir::OpBuilder &builder) const = 0; }; /// Creates an Itanium-family ABI. diff --git a/clang/lib/CIR/Dialect/Transforms/TargetLowering/LowerItaniumCXXABI.cpp b/clang/lib/CIR/Dialect/Transforms/TargetLowering/LowerItaniumCXXABI.cpp index d944fa3294684..0850368d62718 100644 --- a/clang/lib/CIR/Dialect/Transforms/TargetLowering/LowerItaniumCXXABI.cpp +++ b/clang/lib/CIR/Dialect/Transforms/TargetLowering/LowerItaniumCXXABI.cpp @@ -67,6 +67,10 @@ class LowerItaniumCXXABI : public CIRCXXABI { mlir::Value lowerDataMemberCmp(cir::CmpOp op, mlir::Value loweredLhs, mlir::Value loweredRhs, mlir::OpBuilder &builder) const override; + + mlir::Value lowerMethodCmp(cir::CmpOp op, mlir::Value loweredLhs, + mlir::Value loweredRhs, + mlir::OpBuilder &builder) const override; }; } // namespace @@ -249,4 +253,51 @@ LowerItaniumCXXABI::lowerDataMemberCmp(cir::CmpOp op, mlir::Value loweredLhs, loweredRhs); } +mlir::Value LowerItaniumCXXABI::lowerMethodCmp(cir::CmpOp op, + mlir::Value loweredLhs, + mlir::Value loweredRhs, + mlir::OpBuilder &builder) const { + assert(op.getKind() == cir::CmpOpKind::eq || + op.getKind() == cir::CmpOpKind::ne); + + cir::IntType ptrdiffCIRTy = getPtrDiffCIRTy(lm); + mlir::Value ptrdiffZero = cir::ConstantOp::create( + builder, op.getLoc(), cir::IntAttr::get(ptrdiffCIRTy, 0)); + mlir::Location loc = op.getLoc(); + + mlir::Value lhsPtrField = + cir::ExtractMemberOp::create(builder, loc, ptrdiffCIRTy, loweredLhs, 0); + mlir::Value rhsPtrField = + 
cir::ExtractMemberOp::create(builder, loc, ptrdiffCIRTy, loweredRhs, 0); + mlir::Value ptrCmp = + cir::CmpOp::create(builder, loc, op.getKind(), lhsPtrField, rhsPtrField); + mlir::Value ptrCmpToNull = + cir::CmpOp::create(builder, loc, op.getKind(), lhsPtrField, ptrdiffZero); + + mlir::Value lhsAdjField = + cir::ExtractMemberOp::create(builder, loc, ptrdiffCIRTy, loweredLhs, 1); + mlir::Value rhsAdjField = + cir::ExtractMemberOp::create(builder, loc, ptrdiffCIRTy, loweredRhs, 1); + mlir::Value adjCmp = + cir::CmpOp::create(builder, loc, op.getKind(), lhsAdjField, rhsAdjField); + + auto create_and = [&](mlir::Value lhs, mlir::Value rhs) { + return cir::BinOp::create(builder, loc, cir::BinOpKind::And, lhs, rhs); + }; + auto create_or = [&](mlir::Value lhs, mlir::Value rhs) { + return cir::BinOp::create(builder, loc, cir::BinOpKind::Or, lhs, rhs); + }; + + mlir::Value result; + if (op.getKind() == cir::CmpOpKind::eq) { + // (lhs.ptr == null || lhs.adj == rhs.adj) && lhs.ptr == rhs.ptr + result = create_and(ptrCmp, create_or(ptrCmpToNull, adjCmp)); + } else { + // (lhs.ptr != null && lhs.adj != rhs.adj) || lhs.ptr != rhs.ptr + result = create_or(ptrCmp, create_and(ptrCmpToNull, adjCmp)); + } + + return result; +} + } // namespace cir diff --git a/clang/test/CIR/CodeGen/pointer-to-member-func-cmp.cpp b/clang/test/CIR/CodeGen/pointer-to-member-func-cmp.cpp new file mode 100644 index 0000000000000..f00d227e97fe7 --- /dev/null +++ b/clang/test/CIR/CodeGen/pointer-to-member-func-cmp.cpp @@ -0,0 +1,121 @@ +// RUN: %clang_cc1 -triple x86_64-unknown-linux-gnu -std=c++17 -fclangir -emit-cir -mmlir -mlir-print-ir-before=cir-cxxabi-lowering %s -o %t.cir 2> %t-before.cir +// RUN: FileCheck --check-prefix=CIR-BEFORE --input-file=%t-before.cir %s +// RUN: FileCheck --check-prefix=CIR-AFTER --input-file=%t.cir %s +// RUN: %clang_cc1 -triple x86_64-unknown-linux-gnu -std=c++17 -fclangir -emit-llvm %s -o %t-cir.ll +// RUN: FileCheck --input-file=%t-cir.ll --check-prefix=LLVM %s +// 
RUN: %clang_cc1 -triple x86_64-unknown-linux-gnu -std=c++17 -emit-llvm %s -o %t.ll +// RUN: FileCheck --check-prefix=OGCG --input-file=%t.ll %s + +struct Foo { + void m1(int); + virtual void m2(int); + virtual void m3(int); +}; + +bool cmp_eq(void (Foo::*lhs)(int), void (Foo::*rhs)(int)) { + return lhs == rhs; +} + +// CIR-BEFORE: cir.func {{.*}} @_Z6cmp_eqM3FooFviES1_ +// CIR-BEFORE: %[[LHS:.*]] = cir.load{{.*}} %0 : !cir.ptr<!cir.method<!cir.func<(!s32i)> in !rec_Foo>> +// CIR-BEFORE: %[[RHS:.*]] = cir.load{{.*}} %1 : !cir.ptr<!cir.method<!cir.func<(!s32i)> in !rec_Foo>> +// CIR-BEFORE: %[[CMP:.*]] = cir.cmp(eq, %[[LHS]], %[[RHS]]) : !cir.method<!cir.func<(!s32i)> in !rec_Foo>, !cir.bool +// CIR-BEFORE: cir.store %[[CMP]], %{{.*}} : !cir.bool, !cir.ptr<!cir.bool> + +// CIR-AFTER: @_Z6cmp_eqM3FooFviES1_ +// CIR-AFTER: %[[LHS:.*]] = cir.load{{.*}} %0 : !cir.ptr<!rec_anon_struct>, !rec_anon_struct +// CIR-AFTER: %[[RHS:.*]] = cir.load{{.*}} %1 : !cir.ptr<!rec_anon_struct>, !rec_anon_struct +// CIR-AFTER: %[[NULL:.*]] = cir.const #cir.int<0> : !s64i +// CIR-AFTER: %[[LHS_PTR:.*]] = cir.extract_member %[[LHS]][0] : !rec_anon_struct -> !s64i +// CIR-AFTER: %[[RHS_PTR:.*]] = cir.extract_member %[[RHS]][0] : !rec_anon_struct -> !s64i +// CIR-AFTER: %[[PTR_CMP:.*]] = cir.cmp(eq, %[[LHS_PTR]], %[[RHS_PTR]]) : !s64i, !cir.bool +// CIR-AFTER: %[[PTR_NULL:.*]] = cir.cmp(eq, %[[LHS_PTR]], %[[NULL]]) : !s64i, !cir.bool +// CIR-AFTER: %[[LHS_ADJ:.*]] = cir.extract_member %[[LHS]][1] : !rec_anon_struct -> !s64i +// CIR-AFTER: %[[RHS_ADJ:.*]] = cir.extract_member %[[RHS]][1] : !rec_anon_struct -> !s64i +// CIR-AFTER: %[[ADJ_CMP:.*]] = cir.cmp(eq, %[[LHS_ADJ]], %[[RHS_ADJ]]) : !s64i, !cir.bool +// CIR-AFTER: %[[TMP:.*]] = cir.binop(or, %[[PTR_NULL]], %[[ADJ_CMP]]) : !cir.bool +// CIR-AFTER: %[[RESULT:.*]] = cir.binop(and, %[[PTR_CMP]], %[[TMP]]) : !cir.bool + +// LLVM: define {{.*}} i1 @_Z6cmp_eqM3FooFviES1_ +// LLVM: %[[LHS:.*]] = load { i64, i64 }, ptr %{{.+}} +// LLVM: 
%[[RHS:.*]] = load { i64, i64 }, ptr %{{.+}} +// LLVM: %[[LHS_PTR:.*]] = extractvalue { i64, i64 } %[[LHS]], 0 +// LLVM: %[[RHS_PTR:.*]] = extractvalue { i64, i64 } %[[RHS]], 0 +// LLVM: %[[PTR_CMP:.*]] = icmp eq i64 %[[LHS_PTR]], %[[RHS_PTR]] +// LLVM: %[[PTR_NULL:.*]] = icmp eq i64 %[[LHS_PTR]], 0 +// LLVM: %[[LHS_ADJ:.*]] = extractvalue { i64, i64 } %[[LHS]], 1 +// LLVM: %[[RHS_ADJ:.*]] = extractvalue { i64, i64 } %[[RHS]], 1 +// LLVM: %[[ADJ_CMP:.*]] = icmp eq i64 %[[LHS_ADJ]], %[[RHS_ADJ]] +// LLVM: %[[TMP:.*]] = or i1 %[[PTR_NULL]], %[[ADJ_CMP]] +// LLVM: %[[RESULT:.*]] = and i1 %[[PTR_CMP]], %[[TMP]] + +// OGCG: define {{.*}} i1 @_Z6cmp_eqM3FooFviES1_ +// OGCG: %[[LHS_TMP:.*]] = alloca { i64, i64 } +// OGCG: %[[RHS_TMP:.*]] = alloca { i64, i64 } +// OGCG: %[[LHS_ADDR:.*]] = alloca { i64, i64 } +// OGCG: %[[RHS_ADDR:.*]] = alloca { i64, i64 } +// OGCG: %[[LHS:.*]] = load { i64, i64 }, ptr %[[LHS_ADDR]] +// OGCG: %[[RHS:.*]] = load { i64, i64 }, ptr %[[RHS_ADDR]] +// OGCG: %[[LHS_PTR:.*]] = extractvalue { i64, i64 } %[[LHS]], 0 +// OGCG: %[[RHS_PTR:.*]] = extractvalue { i64, i64 } %[[RHS]], 0 +// OGCG: %[[PTR_CMP:.*]] = icmp eq i64 %[[LHS_PTR]], %[[RHS_PTR]] +// OGCG: %[[PTR_NULL:.*]] = icmp eq i64 %[[LHS_PTR]], 0 +// OGCG: %[[LHS_ADJ:.*]] = extractvalue { i64, i64 } %[[LHS]], 1 +// OGCG: %[[RHS_ADJ:.*]] = extractvalue { i64, i64 } %[[RHS]], 1 +// OGCG: %[[ADJ_CMP:.*]] = icmp eq i64 %[[LHS_ADJ]], %[[RHS_ADJ]] +// OGCG: %[[TMP:.*]] = or i1 %[[PTR_NULL]], %[[ADJ_CMP]] +// OGCG: %[[RESULT:.*]] = and i1 %[[PTR_CMP]], %[[TMP]] + +bool cmp_ne(void (Foo::*lhs)(int), void (Foo::*rhs)(int)) { + return lhs != rhs; +} + +// CIR-BEFORE: cir.func {{.*}} @_Z6cmp_neM3FooFviES1_ +// CIR-BEFORE: %[[LHS:.*]] = cir.load{{.*}} %0 : !cir.ptr<!cir.method<!cir.func<(!s32i)> in !rec_Foo>> +// CIR-BEFORE: %[[RHS:.*]] = cir.load{{.*}} %1 : !cir.ptr<!cir.method<!cir.func<(!s32i)> in !rec_Foo>> +// CIR-BEFORE: %[[CMP:.*]] = cir.cmp(ne, %[[LHS]], %[[RHS]]) : !cir.method<!cir.func<(!s32i)> 
in !rec_Foo>, !cir.bool +// CIR-BEFORE: cir.store %[[CMP]], %{{.*}} : !cir.bool, !cir.ptr<!cir.bool> + +// CIR-AFTER: cir.func {{.*}} @_Z6cmp_neM3FooFviES1_ +// CIR-AFTER: %[[LHS:.*]] = cir.load{{.*}} %0 : !cir.ptr<!rec_anon_struct>, !rec_anon_struct +// CIR-AFTER: %[[RHS:.*]] = cir.load{{.*}} %1 : !cir.ptr<!rec_anon_struct>, !rec_anon_struct +// CIR-AFTER: %[[NULL:.*]] = cir.const #cir.int<0> : !s64i +// CIR-AFTER: %[[LHS_PTR:.*]] = cir.extract_member %[[LHS]][0] : !rec_anon_struct -> !s64i +// CIR-AFTER: %[[RHS_PTR:.*]] = cir.extract_member %[[RHS]][0] : !rec_anon_struct -> !s64i +// CIR-AFTER: %[[PTR_CMP:.*]] = cir.cmp(ne, %[[LHS_PTR]], %[[RHS_PTR]]) : !s64i, !cir.bool +// CIR-AFTER: %[[PTR_NULL:.*]] = cir.cmp(ne, %[[LHS_PTR]], %[[NULL]]) : !s64i, !cir.bool +// CIR-AFTER: %[[LHS_ADJ:.*]] = cir.extract_member %[[LHS]][1] : !rec_anon_struct -> !s64i +// CIR-AFTER: %[[RHS_ADJ:.*]] = cir.extract_member %[[RHS]][1] : !rec_anon_struct -> !s64i +// CIR-AFTER: %[[ADJ_CMP:.*]] = cir.cmp(ne, %[[LHS_ADJ]], %[[RHS_ADJ]]) : !s64i, !cir.bool +// CIR-AFTER: %[[TMP:.*]] = cir.binop(and, %[[PTR_NULL]], %[[ADJ_CMP]]) : !cir.bool +// CIR-AFTER: %[[RESULT:.*]] = cir.binop(or, %[[PTR_CMP]], %[[TMP]]) : !cir.bool + +// LLVM: define {{.*}} i1 @_Z6cmp_neM3FooFviES1_ +// LLVM: %[[LHS:.*]] = load { i64, i64 }, ptr %{{.*}} +// LLVM: %[[RHS:.*]] = load { i64, i64 }, ptr %{{.*}} +// LLVM: %[[LHS_PTR:.*]] = extractvalue { i64, i64 } %[[LHS]], 0 +// LLVM: %[[RHS_PTR:.*]] = extractvalue { i64, i64 } %[[RHS]], 0 +// LLVM: %[[PTR_CMP:.*]] = icmp ne i64 %[[LHS_PTR]], %[[RHS_PTR]] +// LLVM: %[[PTR_NULL:.*]] = icmp ne i64 %[[LHS_PTR]], 0 +// LLVM: %[[LHS_ADJ:.*]] = extractvalue { i64, i64 } %[[LHS]], 1 +// LLVM: %[[RHS_ADJ:.*]] = extractvalue { i64, i64 } %[[RHS]], 1 +// LLVM: %[[ADJ_CMP:.*]] = icmp ne i64 %[[LHS_ADJ]], %[[RHS_ADJ]] +// LLVM: %[[TMP:.*]] = and i1 %[[PTR_NULL]], %[[ADJ_CMP]] +// LLVM: %[[RESULT:.*]] = or i1 %[[PTR_CMP]], %[[TMP]] + +// OGCG: define {{.*}} i1 @_Z6cmp_neM3FooFviES1_ 
+// OGCG: %[[LHS_TMP:.*]] = alloca { i64, i64 } +// OGCG: %[[RHS_TMP:.*]] = alloca { i64, i64 } +// OGCG: %[[LHS_ADDR:.*]] = alloca { i64, i64 } +// OGCG: %[[RHS_ADDR:.*]] = alloca { i64, i64 } +// OGCG: %[[LHS:.*]] = load { i64, i64 }, ptr %[[LHS_ADDR]] +// OGCG: %[[RHS:.*]] = load { i64, i64 }, ptr %[[RHS_ADDR]] +// OGCG: %[[LHS_PTR:.*]] = extractvalue { i64, i64 } %[[LHS]], 0 +// OGCG: %[[RHS_PTR:.*]] = extractvalue { i64, i64 } %[[RHS]], 0 +// OGCG: %[[PTR_CMP:.*]] = icmp ne i64 %[[LHS_PTR]], %[[RHS_PTR]] +// OGCG: %[[PTR_NULL:.*]] = icmp ne i64 %[[LHS_PTR]], 0 +// OGCG: %[[LHS_ADJ:.*]] = extractvalue { i64, i64 } %[[LHS]], 1 +// OGCG: %[[RHS_ADJ:.*]] = extractvalue { i64, i64 } %[[RHS]], 1 +// OGCG: %[[ADJ_CMP:.*]] = icmp ne i64 %[[LHS_ADJ]], %[[RHS_ADJ]] +// OGCG: %[[TMP:.*]] = and i1 %[[PTR_NULL]], %[[ADJ_CMP]] +// OGCG: %[[RESULT:.*]] = or i1 %[[PTR_CMP]], %[[TMP]] `````````` </details> https://github.com/llvm/llvm-project/pull/176029 _______________________________________________ cfe-commits mailing list [email protected] https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits
