tbaeder created this revision.
tbaeder added reviewers: aaron.ballman, erichkeane, tahonermann, shafik.
Herald added a project: All.
tbaeder requested review of this revision.
Herald added a project: clang.
Herald added a subscriber: cfe-commits.

@aaron.ballman I'm not sure whether this (the early return added for virtual
destructors in ByteCodeExprGen.cpp) is everything you meant when you talked
about virtual destructors.

I also wasn't sure about the added `getOverridingFunction()`: I think
`CXXRecordDecl::isDerivedFrom(Base, Paths)` (the overload taking a
`CXXBasePaths` out-parameter) is what I should be using, but that seemed
excessively hard to use.


Repository:
  rG LLVM Github Monorepo

https://reviews.llvm.org/D142630

Files:
  clang/lib/AST/Interp/ByteCodeExprGen.cpp
  clang/lib/AST/Interp/Context.cpp
  clang/lib/AST/Interp/Context.h
  clang/lib/AST/Interp/Function.h
  clang/lib/AST/Interp/Interp.h
  clang/lib/AST/Interp/InterpState.h
  clang/test/AST/Interp/records.cpp

Index: clang/test/AST/Interp/records.cpp
===================================================================
--- clang/test/AST/Interp/records.cpp
+++ clang/test/AST/Interp/records.cpp
@@ -1,8 +1,10 @@
 // RUN: %clang_cc1 -fexperimental-new-constant-interpreter -verify %s
 // RUN: %clang_cc1 -fexperimental-new-constant-interpreter -std=c++14 -verify %s
+// RUN: %clang_cc1 -fexperimental-new-constant-interpreter -std=c++20 -verify %s
 // RUN: %clang_cc1 -fexperimental-new-constant-interpreter -triple i686 -verify %s
 // RUN: %clang_cc1 -verify=ref %s
 // RUN: %clang_cc1 -verify=ref -std=c++14 %s
+// RUN: %clang_cc1 -verify=ref -std=c++20 %s
 // RUN: %clang_cc1 -verify=ref -triple i686 %s
 
 struct BoolPair {
@@ -286,6 +288,7 @@
 };
 
 namespace DeriveFailures {
+#if __cplusplus < 202002L
   struct Base { // ref-note 2{{declared here}}
     int Val;
   };
@@ -301,10 +304,12 @@
                            // ref-note {{in call to 'Derived(12)'}} \
                            // ref-note {{declared here}} \
                            // expected-error {{must be initialized by a constant expression}}
+
   static_assert(D.Val == 0, ""); // ref-error {{not an integral constant expression}} \
                                  // ref-note {{initializer of 'D' is not a constant expression}} \
                                  // expected-error {{not an integral constant expression}} \
                                  // expected-note {{read of object outside its lifetime}}
+#endif
 
   struct AnotherBase {
     int Val;
@@ -354,3 +359,82 @@
   static_assert(getS(true).a == 12, "");
   static_assert(getS(false).a == 13, "");
 };
+
+#if __cplusplus >= 202002L
+namespace VirtualCalls {
+namespace Obvious {
+
+  class A {
+  public:
+    constexpr A(){}
+    constexpr virtual int foo() {
+      return 3;
+    }
+  };
+  class B : public A {
+  public:
+    constexpr int foo() override {
+      return 6;
+    }
+  };
+
+  constexpr int getFooB(bool b) {
+    A *a;
+    A myA;
+    B myB;
+
+    if (b)
+      a = &myA;
+    else
+      a = &myB;
+
+    return a->foo();
+  }
+  static_assert(getFooB(true) == 3, "");
+  static_assert(getFooB(false) == 6, "");
+}
+
+namespace MultipleBases {
+  class A {
+  public:
+    constexpr virtual int getInt() const { return 10; }
+  };
+  class B {
+  public:
+  };
+  class C : public A, public B {
+  public:
+    constexpr int getInt() const override { return 20; }
+  };
+
+  constexpr int callGetInt(const A& a) { return a.getInt(); }
+  static_assert(callGetInt(C()) == 20, "");
+  static_assert(callGetInt(A()) == 10, "");
+}
+
+namespace Destructors {
+  class Base {
+  public:
+    int &i; // Reference, so the destructors' i-- reach test()'s local.
+    constexpr Base(int &i) : i(i) {i++;}
+    constexpr virtual ~Base() {i--;}
+  };
+
+  class Derived : public Base {
+  public:
+    constexpr Derived(int &i) : Base(i) {}
+    constexpr virtual ~Derived() {i--;}
+  };
+
+  constexpr int test() {
+    int i = 0;
+    { Derived d(i); } // Runs ~Derived(), then ~Base(), before the return.
+    return i;
+  }
+  static_assert(test() == -1, ""); // +1 (ctor), -1 (~Derived), -1 (~Base)
+}
+
+
+
+} // namespace VirtualCalls
+#endif
Index: clang/lib/AST/Interp/InterpState.h
===================================================================
--- clang/lib/AST/Interp/InterpState.h
+++ clang/lib/AST/Interp/InterpState.h
@@ -86,6 +86,9 @@
     return M ? M->getSource(F, PC) : F->getSource(PC);
   }
 
+  /// Returns the Context this state refers to (Ctx).
+  Context &getContext() const { return Ctx; }
+
 private:
   /// AST Walker state.
   State &Parent;
Index: clang/lib/AST/Interp/Interp.h
===================================================================
--- clang/lib/AST/Interp/Interp.h
+++ clang/lib/AST/Interp/Interp.h
@@ -1509,6 +1509,39 @@
 
     if (S.checkingPotentialConstantExpression())
       return false;
+
+    // For a virtual call, we need to get the right function here.
+    // NOTE(review): this also dispatches qualified calls (a->A::foo()) and
+    // the base-class destructor calls emitted when destroying a record, which
+    // must not be dispatched dynamically. Deciding this when emitting the call
+    // (e.g. with a dedicated virtual-call op) would avoid that -- confirm.
+    if (Func->isVirtual()) {
+      // Our ThisPtr has the decl of the right type at this point,
+      // so we just need to find the function to call.
+      // NOTE(review): if getDeclDesc() describes the outermost declaration,
+      // a member subobject (s.b.foo()) resolves against the enclosing record
+      // and an array element against an array type (null record) -- confirm.
+      const CXXRecordDecl *DynamicDecl =
+          ThisPtr.getDeclDesc()->getType()->getAsCXXRecordDecl();
+      if (!DynamicDecl)
+        return false;
+      const CXXRecordDecl *StaticDecl =
+          cast<CXXRecordDecl>(Func->getParentDecl());
+      const CXXMethodDecl *InitialFunction =
+          cast<CXXMethodDecl>(Func->getDecl());
+      const CXXMethodDecl *Overrider = S.getContext().getOverridingFunction(
+          DynamicDecl, StaticDecl, InitialFunction);
+      // Fail the evaluation rather than continue with a null Function.
+      if (!Overrider)
+        return false;
+      if (Overrider != InitialFunction) {
+        Func = S.P.getFunction(Overrider);
+        // NOTE(review): presumably null when Overrider has not been compiled
+        // yet; should it be compiled on demand here? Fail rather than crash.
+        if (!Func)
+          return false;
+      }
+    }
   }
 
   if (!CheckCallable(S, PC, Func))
Index: clang/lib/AST/Interp/Function.h
===================================================================
--- clang/lib/AST/Interp/Function.h
+++ clang/lib/AST/Interp/Function.h
@@ -130,6 +130,13 @@
   /// Checks if the function is a constructor.
   bool isConstructor() const { return isa<CXXConstructorDecl>(F); }
 
+  /// Returns the class this function is a member of, or nullptr if it is
+  /// not a C++ method.
+  const CXXRecordDecl *getParentDecl() const {
+    const auto *Method = dyn_cast<CXXMethodDecl>(F);
+    return Method ? Method->getParent() : nullptr;
+  }
+
   /// Checks if the function is fully done compiling.
   bool isFullyCompiled() const { return IsFullyCompiled; }
 
Index: clang/lib/AST/Interp/Context.h
===================================================================
--- clang/lib/AST/Interp/Context.h
+++ clang/lib/AST/Interp/Context.h
@@ -61,6 +61,13 @@
   /// Classifies an expression.
   std::optional<PrimType> classify(QualType T) const;
 
+  /// Finds the method overriding \p InitialFunction (declared in
+  /// \p StaticDecl) for an object whose dynamic type is \p DynamicDecl.
+  const CXXMethodDecl *
+  getOverridingFunction(const CXXRecordDecl *DynamicDecl,
+                        const CXXRecordDecl *StaticDecl,
+                        const CXXMethodDecl *InitialFunction);
+
 private:
   /// Runs a function.
   bool Run(State &Parent, Function *Func, APValue &Result);
Index: clang/lib/AST/Interp/Context.cpp
===================================================================
--- clang/lib/AST/Interp/Context.cpp
+++ clang/lib/AST/Interp/Context.cpp
@@ -152,3 +152,41 @@
   });
   return false;
 }
+
+// TODO: Virtual bases?
+/// Finds the method overriding \p InitialFunction (a member of
+/// \p StaticDecl) for an object whose dynamic type is \p DynamicDecl.
+///
+/// Walks from \p DynamicDecl towards \p StaticDecl, always stepping to the
+/// base class that leads to \p StaticDecl, and returns the first
+/// corresponding method found. Returns nullptr if \p DynamicDecl is null or
+/// is neither \p StaticDecl nor derived from it.
+const CXXMethodDecl *
+Context::getOverridingFunction(const CXXRecordDecl *DynamicDecl,
+                               const CXXRecordDecl *StaticDecl,
+                               const CXXMethodDecl *InitialFunction) {
+  const CXXRecordDecl *CurRecord = DynamicDecl;
+  while (CurRecord) {
+    if (const CXXMethodDecl *Overrider =
+            InitialFunction->getCorrespondingMethodDeclaredInClass(
+                CurRecord, /*MayBeBase=*/false))
+      return Overrider;
+
+    // Step to the base class on the path to StaticDecl; if no base leads
+    // there, Next stays null and the walk ends.
+    const CXXRecordDecl *Next = nullptr;
+    for (const CXXBaseSpecifier &Spec : CurRecord->bases()) {
+      const CXXRecordDecl *Base = Spec.getType()->getAsCXXRecordDecl();
+      if (Base && (Base == StaticDecl || Base->isDerivedFrom(StaticDecl))) {
+        Next = Base;
+        break;
+      }
+    }
+    CurRecord = Next;
+  }
+
+  // The walk ran out of bases before finding the method: DynamicDecl is
+  // not StaticDecl and does not derive from it.
+  return nullptr;
+}
Index: clang/lib/AST/Interp/ByteCodeExprGen.cpp
===================================================================
--- clang/lib/AST/Interp/ByteCodeExprGen.cpp
+++ clang/lib/AST/Interp/ByteCodeExprGen.cpp
@@ -2047,6 +2047,12 @@
       if (!this->emitCall(DtorFunc, SourceInfo{}))
         return false;
     }
+    // NOTE(review): for a virtual destructor this returns before the loop
+    // over R->bases() below. Presumably that avoids base destructor calls
+    // being re-dispatched to the most-derived destructor by Call's virtual
+    // dispatch -- confirm the bases are still destroyed elsewhere.
+    if (Dtor->isVirtual())
+      return this->emitPopPtr(SourceInfo{});
   }
 
   for (const Record::Base &Base : llvm::reverse(R->bases())) {
_______________________________________________
cfe-commits mailing list
cfe-commits@lists.llvm.org
https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits

Reply via email to