capfredf created this revision.
Herald added a project: All.
capfredf requested review of this revision.
Herald added a project: clang.
Herald added a subscriber: cfe-commits.

This patch enables code completion for ClangREPL. The feature is built upon two
existing Clang components: a list completer for `LineEditor` and
`SemaCodeCompletion`.

The former serves as the main entry point that triggers the latter. Because the
completion point of a compiler instance cannot be changed once it is set, a
separate incremental compiler instance is created for code completion. Apart
from the completion point, it differs from a regular incremental compiler in
the following ways:

1. It does not execute input.

2. To obtain declarations or bindings from previous input in the same REPL
session, it uses the AST context of the main compiler (the one used by the
interpreter) as an external AST source.

3. It contains a `ReplCompletionConsumer`, a subclass of
`CodeCompleteConsumer`. The consumer passes completion results from
`SemaCodeCompletion` back to the list completer for the REPL.


Repository:
  rG LLVM Github Monorepo

https://reviews.llvm.org/D154621

Files:
  clang/include/clang/Interpreter/CodeCompletion.h
  clang/include/clang/Interpreter/Interpreter.h
  clang/lib/Interpreter/CMakeLists.txt
  clang/lib/Interpreter/CodeCompletion.cpp
  clang/lib/Interpreter/ExternalSource.cpp
  clang/lib/Interpreter/ExternalSource.h
  clang/lib/Interpreter/IncrementalParser.cpp
  clang/lib/Interpreter/IncrementalParser.h
  clang/lib/Interpreter/Interpreter.cpp
  clang/tools/clang-repl/ClangRepl.cpp
  clang/unittests/Interpreter/CMakeLists.txt
  clang/unittests/Interpreter/CodeCompletionTest.cpp

Index: clang/unittests/Interpreter/CodeCompletionTest.cpp
===================================================================
--- /dev/null
+++ clang/unittests/Interpreter/CodeCompletionTest.cpp
@@ -0,0 +1,56 @@
+#include "clang/Interpreter/CodeCompletion.h"
+#include "clang/Interpreter/Interpreter.h"
+
+#include "llvm/LineEditor/LineEditor.h"
+
+#include "clang/Frontend/CompilerInstance.h"
+#include "llvm/Support/Error.h"
+
+#include "gmock/gmock.h"
+#include "gtest/gtest.h"
+
+using namespace clang;
+namespace {
+auto CB = clang::IncrementalCompilerBuilder();
+
+std::unique_ptr<Interpreter> createInterpreter() {
+  auto CI = cantFail(CB.CreateCpp());
+  return cantFail(clang::Interpreter::create(std::move(CI)));
+}
+
+// If the setup input cannot be executed, a test is skipped (with the error as
+// the reason) instead of silently passing.
+TEST(CodeCompletionTest, Sanity) {
+  auto Interp = createInterpreter();
+  if (auto R = Interp->ParseAndExecute("int foo = 12;"))
+    GTEST_SKIP() << llvm::toString(std::move(R));
+  auto Completer = ReplListCompleter(CB, *Interp);
+  std::vector<llvm::LineEditor::Completion> comps =
+      Completer(std::string("f"), 1);
+  // ASSERT rather than EXPECT: comps[0] is read below.
+  ASSERT_EQ((size_t)2, comps.size()); // foo and float
+  EXPECT_EQ(comps[0].TypedText, std::string("oo"));
+}
+
+TEST(CodeCompletionTest, SanityNoneValid) {
+  auto Interp = createInterpreter();
+  if (auto R = Interp->ParseAndExecute("int foo = 12;"))
+    GTEST_SKIP() << llvm::toString(std::move(R));
+  auto Completer = ReplListCompleter(CB, *Interp);
+  std::vector<llvm::LineEditor::Completion> comps =
+      Completer(std::string("babanana"), 8);
+  EXPECT_EQ((size_t)0, comps.size()); // nothing starts with "babanana"
+}
+
+TEST(CodeCompletionTest, TwoDecls) {
+  auto Interp = createInterpreter();
+  if (auto R = Interp->ParseAndExecute("int application = 12;"))
+    GTEST_SKIP() << llvm::toString(std::move(R));
+  if (auto R = Interp->ParseAndExecute("int apple = 12;"))
+    GTEST_SKIP() << llvm::toString(std::move(R));
+  auto Completer = ReplListCompleter(CB, *Interp);
+  std::vector<llvm::LineEditor::Completion> comps =
+      Completer(std::string("app"), 3);
+  EXPECT_EQ((size_t)2, comps.size());
+}
+} // anonymous namespace
Index: clang/unittests/Interpreter/CMakeLists.txt
===================================================================
--- clang/unittests/Interpreter/CMakeLists.txt
+++ clang/unittests/Interpreter/CMakeLists.txt
@@ -9,6 +9,7 @@
 add_clang_unittest(ClangReplInterpreterTests
+  CodeCompletionTest.cpp
   IncrementalProcessingTest.cpp
   InterpreterTest.cpp
   )
 target_link_libraries(ClangReplInterpreterTests PUBLIC
   clangAST
Index: clang/tools/clang-repl/ClangRepl.cpp
===================================================================
--- clang/tools/clang-repl/ClangRepl.cpp
+++ clang/tools/clang-repl/ClangRepl.cpp
@@ -13,6 +13,7 @@
 #include "clang/Basic/Diagnostic.h"
 #include "clang/Frontend/CompilerInstance.h"
 #include "clang/Frontend/FrontendDiagnostic.h"
+#include "clang/Interpreter/CodeCompletion.h"
 #include "clang/Interpreter/Interpreter.h"
 
 #include "llvm/ExecutionEngine/Orc/LLJIT.h"
@@ -155,8 +156,8 @@
 
   if (OptInputs.empty()) {
     llvm::LineEditor LE("clang-repl");
-    // FIXME: Add LE.setListCompleter
+    LE.setListCompleter(clang::ReplListCompleter(CB, *Interp));
     std::string Input;
     while (std::optional<std::string> Line = LE.readLine()) {
       llvm::StringRef L = *Line;
       L = L.trim();
@@ -168,10 +169,11 @@
       }
 
       Input += L;
 
       if (Input == R"(%quit)") {
         break;
-      } else if (Input == R"(%undo)") {
+      }
+      if (Input == R"(%undo)") {
         if (auto Err = Interp->Undo()) {
           llvm::logAllUnhandledErrors(std::move(Err), llvm::errs(), "error: ");
           HasError = true;
Index: clang/lib/Interpreter/Interpreter.cpp
===================================================================
--- clang/lib/Interpreter/Interpreter.cpp
+++ clang/lib/Interpreter/Interpreter.cpp
@@ -14,6 +14,7 @@
 #include "clang/Interpreter/Interpreter.h"
 
 #include "DeviceOffload.h"
+#include "ExternalSource.h"
 #include "IncrementalExecutor.h"
 #include "IncrementalParser.h"
 
@@ -33,6 +34,7 @@
 #include "clang/Driver/Tool.h"
 #include "clang/Frontend/CompilerInstance.h"
 #include "clang/Frontend/TextDiagnosticBuffer.h"
+#include "clang/Interpreter/CodeCompletion.h"
 #include "clang/Interpreter/Value.h"
 #include "clang/Lex/PreprocessorOptions.h"
 #include "clang/Sema/Lookup.h"
@@ -237,6 +239,21 @@
                                                    *TSCtx->getContext(), Err);
 }
 
+// Constructs an interpreter used only for code completion. Completion
+// results are appended to CompResults; if ParentCI is non-null, its
+// declarations are made visible to this interpreter's parser.
+Interpreter::Interpreter(std::unique_ptr<CompilerInstance> CI, llvm::Error &Err,
+                         std::vector<CodeCompletionResult> &CompResults,
+                         const CompilerInstance *ParentCI) {
+  llvm::ErrorAsOutParameter EAO(&Err);
+  auto LLVMCtx = std::make_unique<llvm::LLVMContext>();
+  TSCtx = std::make_unique<llvm::orc::ThreadSafeContext>(std::move(LLVMCtx));
+  // The CompilerInstance takes ownership of the consumer.
+  CI->setCodeCompletionConsumer(new ReplCompletionConsumer(CompResults));
+  IncrParser = std::make_unique<IncrementalParser>(
+      *this, std::move(CI), *TSCtx->getContext(), Err, ParentCI);
+}
+
 Interpreter::~Interpreter() {
   if (IncrExecutor) {
     if (llvm::Error Err = IncrExecutor->cleanUp())
@@ -288,6 +305,34 @@
   return std::move(Interp);
 }
 
+// Builds a fresh C++ compiler instance configured for code completion and
+// wraps it in a completion-only interpreter.
+llvm::Expected<std::unique_ptr<Interpreter>>
+Interpreter::createForCodeCompletion(
+    IncrementalCompilerBuilder &CB, const CompilerInstance *ParentCI,
+    std::vector<CodeCompletionResult> &CompResults) {
+  auto CI = CB.CreateCpp();
+  if (auto Err = CI.takeError())
+    return std::move(Err);
+
+  (*CI)->getPreprocessorOpts().SingleFileParseMode = true;
+
+  (*CI)->getLangOpts().SpellChecking = false;
+  (*CI)->getLangOpts().DelayedTemplateParsing = false;
+
+  (*CI)->getFrontendOpts().CodeCompleteOpts = getClangCompleteOpts();
+
+  llvm::Error Err = llvm::Error::success();
+  // The constructor is not public, so std::make_unique cannot be used.
+  auto Interp = std::unique_ptr<Interpreter>(
+      new Interpreter(std::move(*CI), Err, CompResults, ParentCI));
+  if (Err)
+    return std::move(Err);
+
+  Interp->InitPTUSize = Interp->IncrParser->getPTUs().size();
+  return std::move(Interp);
+}
+
 llvm::Expected<std::unique_ptr<Interpreter>>
 Interpreter::createWithCUDA(std::unique_ptr<CompilerInstance> CI,
                             std::unique_ptr<CompilerInstance> DCI) {
@@ -738,6 +783,14 @@
   return Result.get();
 }
 
+std::string Interpreter::getAllInput() const { return IncrParser->AllInput; }
+
+// Intended for interpreters created by createForCodeCompletion: parses Input
+// with the completion point at (Line, Col).
+void Interpreter::CodeComplete(llvm::StringRef Input, size_t Col, size_t Line) {
+  IncrParser->ParseForCodeCompletion(Input, Col, Line);
+}
+
 // Temporary rvalue struct that need special care.
 REPL_EXTERNAL_VISIBILITY void *
 __clang_Interpreter_SetValueWithAlloc(void *This, void *OutVal,
Index: clang/lib/Interpreter/IncrementalParser.h
===================================================================
--- clang/lib/Interpreter/IncrementalParser.h
+++ clang/lib/Interpreter/IncrementalParser.h
@@ -24,7 +24,7 @@
 #include <memory>
 namespace llvm {
 class LLVMContext;
-}
+} // namespace llvm
 
 namespace clang {
 class ASTConsumer;
@@ -60,9 +60,16 @@
   IncrementalParser();
 
 public:
+  /// Input text that code completion re-parses: ReplListCompleter prepends it
+  /// (via Interpreter::getAllInput) to the line being completed.
+  // FIXME: The external AST source may make this unnecessary.
+  std::string AllInput;
+
+  /// If \p ParentCI is non-null, its AST is exposed as an external AST source.
   IncrementalParser(Interpreter &Interp,
                     std::unique_ptr<CompilerInstance> Instance,
-                    llvm::LLVMContext &LLVMCtx, llvm::Error &Err);
+                    llvm::LLVMContext &LLVMCtx, llvm::Error &Err,
+                    const CompilerInstance *ParentCI = nullptr);
   virtual ~IncrementalParser();
 
   CompilerInstance *getCI() { return CI.get(); }
@@ -72,6 +79,10 @@
   ///\returns a \c PartialTranslationUnit which holds information about the
   /// \c TranslationUnitDecl and \c llvm::Module corresponding to the input.
   virtual llvm::Expected<PartialTranslationUnit &> Parse(llvm::StringRef Input);
+
+  /// Parses \p Input with the code-completion point at (\p Line, \p Col).
+  /// Parse errors are dropped; results go to the completion consumer.
+  void ParseForCodeCompletion(llvm::StringRef Input, size_t Col, size_t Line);
 
   /// Uses the CodeGenModule mangled name cache and avoids recomputing.
   ///\returns the mangled name of a \c GD.
@@ -85,6 +96,16 @@
 
 private:
   llvm::Expected<PartialTranslationUnit &> ParseOrWrapTopLevelDecl();
+
+  /// Enters the buffer \p FID and parses it into a new PTU. \p SrcLoc is only
+  /// used for diagnostics.
+  llvm::Expected<PartialTranslationUnit &> ParseForPTU(FileID FID,
+                                                       SourceLocation SrcLoc);
+
+  /// Registers \p Input plus a trailing newline as a virtual file named
+  /// \p SourceName and returns its FileID and include location.
+  std::pair<FileID, SourceLocation> createSourceFile(llvm::StringRef SourceName,
+                                                     llvm::StringRef Input);
 };
 } // end namespace clang
 
Index: clang/lib/Interpreter/IncrementalParser.cpp
===================================================================
--- clang/lib/Interpreter/IncrementalParser.cpp
+++ clang/lib/Interpreter/IncrementalParser.cpp
@@ -11,6 +11,7 @@
 //===----------------------------------------------------------------------===//
 
 #include "IncrementalParser.h"
+#include "ExternalSource.h"
 #include "clang/AST/DeclContextInternals.h"
 #include "clang/CodeGen/BackendUtil.h"
 #include "clang/CodeGen/CodeGenAction.h"
@@ -115,10 +116,12 @@
 class IncrementalAction : public WrapperFrontendAction {
 private:
   bool IsTerminating = false;
+  const CompilerInstance *ParentCI;
 
 public:
   IncrementalAction(CompilerInstance &CI, llvm::LLVMContext &LLVMCtx,
-                    llvm::Error &Err)
+                    llvm::Error &Err,
+                    const CompilerInstance *ParentCI = nullptr)
       : WrapperFrontendAction([&]() {
           llvm::ErrorAsOutParameter EAO(&Err);
           std::unique_ptr<FrontendAction> Act;
@@ -152,7 +155,8 @@
             break;
           }
           return Act;
-        }()) {}
+        }()),
+        ParentCI(ParentCI) {}
   FrontendAction *getWrapped() const { return WrappedAction.get(); }
   TranslationUnitKind getTranslationUnitKind() override {
     return TU_Incremental;
@@ -175,6 +179,17 @@
     Preprocessor &PP = CI.getPreprocessor();
     PP.EnterMainSourceFile();
 
+    if (ParentCI) {
+      // Install an ExternalSource over the parent (interpreter) AST so that
+      // lookups in this instance also see the parent's declarations.
+      auto ExtSource = llvm::makeIntrusiveRefCnt<ExternalSource>(
+          CI.getASTContext(), CI.getFileManager(), ParentCI->getASTContext(),
+          ParentCI->getFileManager());
+      CI.getASTContext().setExternalSource(ExtSource);
+      CI.getASTContext().getTranslationUnitDecl()->setHasExternalVisibleStorage(
+          true);
+    }
+
     if (!CI.hasSema())
       CI.createSema(getTranslationUnitKind(), CompletionConsumer);
   }
@@ -206,10 +221,11 @@
 IncrementalParser::IncrementalParser(Interpreter &Interp,
                                      std::unique_ptr<CompilerInstance> Instance,
                                      llvm::LLVMContext &LLVMCtx,
-                                     llvm::Error &Err)
+                                     llvm::Error &Err,
+                                     const CompilerInstance *ParentCI)
     : CI(std::move(Instance)) {
   llvm::ErrorAsOutParameter EAO(&Err);
-  Act = std::make_unique<IncrementalAction>(*CI, LLVMCtx, Err);
+  Act = std::make_unique<IncrementalAction>(*CI, LLVMCtx, Err, ParentCI);
   if (Err)
     return;
   CI->ExecuteAction(*Act);
@@ -305,22 +321,44 @@
   return LastPTU;
 }
 
-llvm::Expected<PartialTranslationUnit &>
-IncrementalParser::Parse(llvm::StringRef input) {
+void IncrementalParser::ParseForCodeCompletion(llvm::StringRef input,
+                                               size_t Col, size_t Line) {
   Preprocessor &PP = CI->getPreprocessor();
   assert(PP.isIncrementalProcessingEnabled() && "Not in incremental mode!?");
 
-  std::ostringstream SourceName;
-  SourceName << "input_line_" << InputCount++;
+  // Each completion request uses its own compiler instance (see
+  // ReplListCompleter), so a fixed buffer name cannot collide.
+  auto [FID, SrcLoc] = createSourceFile("input_line_[Completion]", input);
+  auto FE = CI->getSourceManager().getFileEntryRefForID(FID);
+  if (!FE)
+    return;
+
+  // SetCodeCompletionPoint returns true on failure.
+  if (PP.SetCodeCompletionPoint(*FE, Line, Col))
+    return;
+
+  // SrcLoc only used for diags.
+  if (PP.EnterSourceFile(FID, /*DirLookup=*/nullptr, SrcLoc))
+    return;
+
+  // Only the results delivered to the completion consumer matter; errors
+  // from parsing the input are dropped on purpose.
+  auto PTU = ParseOrWrapTopLevelDecl();
+  if (!PTU)
+    llvm::consumeError(PTU.takeError());
+}
 
+std::pair<FileID, SourceLocation>
+IncrementalParser::createSourceFile(llvm::StringRef SourceName,
+                                    llvm::StringRef Input) {
   // Create an uninitialized memory buffer, copy code in and append "\n"
-  size_t InputSize = input.size(); // don't include trailing 0
+  size_t InputSize = Input.size(); // don't include trailing 0
   // MemBuffer size should *not* include terminating zero
   std::unique_ptr<llvm::MemoryBuffer> MB(
       llvm::WritableMemoryBuffer::getNewUninitMemBuffer(InputSize + 1,
                                                         SourceName.str()));
   char *MBStart = const_cast<char *>(MB->getBufferStart());
-  memcpy(MBStart, input.data(), InputSize);
+  memcpy(MBStart, Input.data(), InputSize);
   MBStart[InputSize] = '\n';
 
   SourceManager &SM = CI->getSourceManager();
@@ -330,18 +368,43 @@
   SourceLocation NewLoc = SM.getLocForStartOfFile(SM.getMainFileID());
 
   // Create FileID for the current buffer.
-  FileID FID = SM.createFileID(std::move(MB), SrcMgr::C_User, /*LoadedID=*/0,
-                               /*LoadedOffset=*/0, NewLoc);
+  // The buffer is registered as a virtual file so that it has a file entry,
+  // which SetCodeCompletionPoint needs. Size the entry to match the buffer,
+  // which includes the appended '\n'.
+  const clang::FileEntry *FE = SM.getFileManager().getVirtualFile(
+      SourceName.str(), InputSize + 1, 0 /* mod time*/);
+  SM.overrideFileContents(FE, std::move(MB));
+  FileID FID = SM.createFileID(FE, NewLoc, SrcMgr::C_User);
+  return {FID, NewLoc};
+}
+
+llvm::Expected<PartialTranslationUnit &>
+IncrementalParser::ParseForPTU(FileID FID, SourceLocation SrcLoc) {
+  Preprocessor &PP = CI->getPreprocessor();
+  assert(PP.isIncrementalProcessingEnabled() && "Not in incremental mode!?");
 
-  // NewLoc only used for diags.
-  if (PP.EnterSourceFile(FID, /*DirLookup=*/nullptr, NewLoc))
+  // SrcLoc only used for diags.
+  if (PP.EnterSourceFile(FID, /*DirLookup=*/nullptr, SrcLoc))
     return llvm::make_error<llvm::StringError>("Parsing failed. "
                                                "Cannot enter source file.",
                                                std::error_code());
 
   auto PTU = ParseOrWrapTopLevelDecl();
   if (!PTU)
     return PTU.takeError();
+  return *PTU;
+}
+
+llvm::Expected<PartialTranslationUnit &>
+IncrementalParser::Parse(llvm::StringRef input) {
+  Preprocessor &PP = CI->getPreprocessor();
+  std::ostringstream SourceName;
+  SourceName << "input_line_" << InputCount++;
+
+  auto [FID, SrcLoc] = createSourceFile(SourceName.str(), input);
+  auto PTU = ParseForPTU(FID, SrcLoc);
+  if (!PTU)
+    return PTU.takeError();
 
   if (PP.getLangOpts().DelayedTemplateParsing) {
     // Microsoft-specific:
Index: clang/lib/Interpreter/ExternalSource.h
===================================================================
--- /dev/null
+++ clang/lib/Interpreter/ExternalSource.h
@@ -0,0 +1,53 @@
+//==----- ExternalSource.h - External AST Source for Code Completion ------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file defines components that make declarations parsed and executed by
+// the interpreter visible to the context where code completion is being
+// triggered.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef LLVM_CLANG_LIB_INTERPRETER_EXTERNALSOURCE_H
+#define LLVM_CLANG_LIB_INTERPRETER_EXTERNALSOURCE_H
+
+#include "clang/AST/ExternalASTSource.h"
+
+#include <memory>
+
+namespace clang {
+class ASTContext;
+class FileManager;
+class ASTImporter;
+
+/// Makes the declarations of a parent AST (the interpreter's) visible in a
+/// child AST (the code-completion instance's) by importing them with an
+/// ASTImporter.
+class ExternalSource : public clang::ExternalASTSource {
+  ASTContext &ChildASTCtxt;
+  TranslationUnitDecl *ChildTUDeclCtxt;
+  ASTContext &ParentASTCtxt;
+  TranslationUnitDecl *ParentTUDeclCtxt;
+
+  std::unique_ptr<ASTImporter> Importer;
+
+public:
+  ExternalSource(ASTContext &ChildASTCtxt, FileManager &ChildFM,
+                 ASTContext &ParentASTCtxt, FileManager &ParentFM);
+
+  /// Reports whether the parent translation unit declares \p Name.
+  bool FindExternalVisibleDeclsByName(const DeclContext *DC,
+                                      DeclarationName Name) override;
+
+  /// Imports the parent's named top-level declarations into
+  /// \p ChildDeclContext, which must be the child translation unit.
+  void
+  completeVisibleDeclsMap(const clang::DeclContext *ChildDeclContext) override;
+};
+} // namespace clang
+
+#endif // LLVM_CLANG_LIB_INTERPRETER_EXTERNALSOURCE_H
Index: clang/lib/Interpreter/ExternalSource.cpp
===================================================================
--- /dev/null
+++ clang/lib/Interpreter/ExternalSource.cpp
@@ -0,0 +1,81 @@
+//===--- ExternalSource.cpp - External AST Source for Code Completion ----===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// The file implements classes that make declarations parsed and executed by the
+// interpreter visible to the context where code completion is being triggered.
+//
+//===----------------------------------------------------------------------===//
+
+#include "ExternalSource.h"
+#include "clang/AST/ASTImporter.h"
+#include "clang/AST/DeclarationName.h"
+#include "clang/Basic/IdentifierTable.h"
+
+#include <memory>
+
+namespace clang {
+ExternalSource::ExternalSource(ASTContext &ChildASTCtxt, FileManager &ChildFM,
+                               ASTContext &ParentASTCtxt, FileManager &ParentFM)
+    : ChildASTCtxt(ChildASTCtxt),
+      ChildTUDeclCtxt(ChildASTCtxt.getTranslationUnitDecl()),
+      ParentASTCtxt(ParentASTCtxt),
+      ParentTUDeclCtxt(ParentASTCtxt.getTranslationUnitDecl()),
+      Importer(std::make_unique<ASTImporter>(ChildASTCtxt, ChildFM,
+                                             ParentASTCtxt, ParentFM,
+                                             /*MinimalImport=*/true)) {}
+
+bool ExternalSource::FindExternalVisibleDeclsByName(const DeclContext *DC,
+                                                    DeclarationName Name) {
+  // Non-identifier names (operators, constructors, ...) are not looked up in
+  // the parent. find() is used instead of get() so that a lookup from the
+  // child never inserts a new identifier into the parent's identifier table.
+  const IdentifierInfo *II = Name.getAsIdentifierInfo();
+  if (!II)
+    return false;
+
+  IdentifierTable &ParentIdTable = ParentASTCtxt.Idents;
+  auto It = ParentIdTable.find(II->getName());
+  if (It == ParentIdTable.end())
+    return false;
+
+  return !ParentTUDeclCtxt->lookup(DeclarationName(It->getValue())).empty();
+}
+
+void ExternalSource::completeVisibleDeclsMap(
+    const DeclContext *ChildDeclContext) {
+  assert(ChildDeclContext && ChildDeclContext == ChildTUDeclCtxt &&
+         "No child decl context!");
+
+  if (!ChildDeclContext->hasExternalVisibleStorage())
+    return;
+
+  // Walk every redeclaration of the parent translation unit and import each
+  // named declaration it contains into the child context.
+  for (auto *DeclCtxt = ParentTUDeclCtxt; DeclCtxt != nullptr;
+       DeclCtxt = DeclCtxt->getPreviousDecl()) {
+    for (Decl *D : DeclCtxt->decls()) {
+      auto *ND = llvm::dyn_cast<NamedDecl>(D);
+      if (!ND)
+        continue;
+
+      auto DeclOrErr = Importer->Import(ND);
+      if (!DeclOrErr) {
+        // A declaration that cannot be imported is simply not offered.
+        llvm::consumeError(DeclOrErr.takeError());
+        continue;
+      }
+
+      if (auto *ImportedND = llvm::dyn_cast<NamedDecl>(*DeclOrErr))
+        SetExternalVisibleDeclsForName(ChildDeclContext,
+                                       ImportedND->getDeclName(), ImportedND);
+    }
+  }
+  ChildDeclContext->setHasExternalLexicalStorage(false);
+}
+
+} // namespace clang
Index: clang/lib/Interpreter/CodeCompletion.cpp
===================================================================
--- /dev/null
+++ clang/lib/Interpreter/CodeCompletion.cpp
@@ -0,0 +1,130 @@
+//===------ CodeCompletion.cpp - Code Completion for ClangRepl -------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements the classes which perform code completion at the REPL.
+//
+//===----------------------------------------------------------------------===//
+
+#include "clang/Interpreter/CodeCompletion.h"
+#include "clang/Basic/CharInfo.h"
+#include "clang/Frontend/CompilerInstance.h"
+#include "clang/Interpreter/Interpreter.h"
+#include "clang/Lex/PreprocessorOptions.h"
+#include "clang/Sema/CodeCompleteOptions.h"
+#include "clang/Sema/Sema.h"
+
+#include <algorithm>
+
+namespace clang {
+
+clang::CodeCompleteOptions getClangCompleteOpts() {
+  clang::CodeCompleteOptions Opts;
+  Opts.IncludeCodePatterns = true;
+  Opts.IncludeMacros = true;
+  Opts.IncludeGlobals = true;
+  Opts.IncludeBriefComments = true;
+  return Opts;
+}
+
+void ReplCompletionConsumer::ProcessCodeCompleteResults(
+    class Sema &S, CodeCompletionContext Context,
+    CodeCompletionResult *InResults, unsigned NumResults) {
+  // Keep declarations that have an identifier name, and keywords; other
+  // result kinds are not offered by the REPL completer.
+  for (unsigned I = 0; I < NumResults; ++I) {
+    const CodeCompletionResult &Result = InResults[I];
+    switch (Result.Kind) {
+    case CodeCompletionResult::RK_Declaration:
+      if (Result.Declaration->getIdentifier())
+        Results.push_back(Result);
+      break;
+    case CodeCompletionResult::RK_Keyword:
+      Results.push_back(Result);
+      break;
+    default:
+      break;
+    }
+  }
+}
+
+std::vector<StringRef> ReplListCompleter::toCodeCompleteStrings(
+    const std::vector<CodeCompletionResult> &Results) const {
+  std::vector<StringRef> CompletionStrings;
+  CompletionStrings.reserve(Results.size());
+  for (const CodeCompletionResult &Res : Results) {
+    switch (Res.Kind) {
+    case CodeCompletionResult::RK_Declaration:
+      if (const auto *ID = Res.Declaration->getIdentifier())
+        CompletionStrings.push_back(ID->getName());
+      break;
+    case CodeCompletionResult::RK_Keyword:
+      CompletionStrings.push_back(Res.Keyword);
+      break;
+    default:
+      break;
+    }
+  }
+  return CompletionStrings;
+}
+
+std::vector<llvm::LineEditor::Completion>
+ReplListCompleter::operator()(llvm::StringRef Buffer, size_t Pos) const {
+  std::vector<llvm::LineEditor::Completion> Comps;
+  // Interp's completion consumer holds a reference to Results, so Results is
+  // declared before (and therefore destroyed after) Interp.
+  std::vector<CodeCompletionResult> Results;
+  auto Interp = Interpreter::createForCodeCompletion(
+      CB, MainInterp.getCompilerInstance(), Results);
+
+  if (auto Err = Interp.takeError()) {
+    // Log the error and return an empty list.
+    llvm::logAllUnhandledErrors(std::move(Err), llvm::errs(), "error: ");
+    return Comps;
+  }
+
+  // We need to wrap our input because we need `Sema::CodeCompleteOrdinaryName`
+  // to work on code from the REPL in a statement completion context. By
+  // default, `Sema::CodeCompleteOrdinaryName` thinks the input is a regular c++
+  // file. For example,
+  // ```
+  // clang-repl> int foo = 42;
+  // clang-repl> f_
+  // ```
+  //
+  // `Sema::CodeCompleteOrdinaryName` treats the code as
+  //
+  // ```
+  // int foo = 42;
+  // f_
+  // ```
+  //
+  // Since top-level expressions are not supported, `foo` should not be an
+  // option. But in a REPL session, we should be allowed to use `foo` to make a
+  // statement like `foo + 84;`.
+  std::string AllCodeText =
+      MainInterp.getAllInput() + "\nvoid dummy(){\n" + Buffer.str() + "}";
+
+  // The completion point is on the last line, Pos characters into it.
+  auto Lines = std::count(AllCodeText.begin(), AllCodeText.end(), '\n') + 1;
+  (*Interp)->CodeComplete(AllCodeText, Pos + 1, Lines);
+
+  // Complete the identifier that ends at the cursor: scan back from Pos over
+  // identifier characters, so that e.g. "foo(ba" completes "ba".
+  llvm::StringRef BeforeCursor = Buffer.take_front(Pos);
+  size_t Start = BeforeCursor.size();
+  while (Start > 0 && isAsciiIdentifierContinue(BeforeCursor[Start - 1]))
+    --Start;
+  llvm::StringRef Prefix = BeforeCursor.drop_front(Start);
+
+  for (StringRef C : toCodeCompleteStrings(Results)) {
+    if (C.startswith(Prefix))
+      Comps.emplace_back(C.substr(Prefix.size()).str(), C.str());
+  }
+  return Comps;
+}
+} // namespace clang
Index: clang/lib/Interpreter/CMakeLists.txt
===================================================================
--- clang/lib/Interpreter/CMakeLists.txt
+++ clang/lib/Interpreter/CMakeLists.txt
@@ -12,7 +12,9 @@
   )
 
 add_clang_library(clangInterpreter
+  CodeCompletion.cpp
   DeviceOffload.cpp
+  ExternalSource.cpp
   IncrementalExecutor.cpp
   IncrementalParser.cpp
   Interpreter.cpp
Index: clang/include/clang/Interpreter/Interpreter.h
===================================================================
--- clang/include/clang/Interpreter/Interpreter.h
+++ clang/include/clang/Interpreter/Interpreter.h
@@ -35,9 +35,10 @@
 
 namespace clang {
 
+class CodeCompletionResult;
 class CompilerInstance;
 class IncrementalExecutor;
 class IncrementalParser;
 
 /// Create a pre-configured \c CompilerInstance for incremental processing.
 class IncrementalCompilerBuilder {
@@ -80,8 +81,12 @@
 
   // An optional parser for CUDA offloading
   std::unique_ptr<IncrementalParser> DeviceParser;
 
   Interpreter(std::unique_ptr<CompilerInstance> CI, llvm::Error &Err);
+  // Code-completion-only interpreter; see createForCodeCompletion.
+  Interpreter(std::unique_ptr<CompilerInstance> CI, llvm::Error &Err,
+              std::vector<CodeCompletionResult> &CompResults,
+              const CompilerInstance *ParentCI = nullptr);
 
   llvm::Error CreateExecutor();
   unsigned InitPTUSize = 0;
@@ -93,13 +98,25 @@
 
 public:
   ~Interpreter();
+
   static llvm::Expected<std::unique_ptr<Interpreter>>
   create(std::unique_ptr<CompilerInstance> CI);
+
   static llvm::Expected<std::unique_ptr<Interpreter>>
   createWithCUDA(std::unique_ptr<CompilerInstance> CI,
                  std::unique_ptr<CompilerInstance> DCI);
+
+  /// Creates an interpreter used only for code completion. Results are
+  /// appended to \p CompResults; if \p ParentCI is non-null, its declarations
+  /// are visible during completion.
+  static llvm::Expected<std::unique_ptr<Interpreter>>
+  createForCodeCompletion(IncrementalCompilerBuilder &CB,
+                          const CompilerInstance *ParentCI,
+                          std::vector<CodeCompletionResult> &CompResults);
+
   const ASTContext &getASTContext() const;
   ASTContext &getASTContext();
+  /// Runs code completion on \p Input at 1-based (\p Line, \p Col).
+  void CodeComplete(llvm::StringRef Input, size_t Col, size_t Line = 1);
   const CompilerInstance *getCompilerInstance() const;
   llvm::Expected<llvm::orc::LLJIT &> getExecutionEngine();
 
@@ -136,6 +153,10 @@
 
   Expr *SynthesizeExpr(Expr *E);
 
+  /// Returns the incremental parser's AllInput text, which code completion
+  /// prepends to the line being completed.
+  std::string getAllInput() const;
+
 private:
   size_t getEffectivePTUSize() const;
 
Index: clang/include/clang/Interpreter/CodeCompletion.h
===================================================================
--- /dev/null
+++ clang/include/clang/Interpreter/CodeCompletion.h
@@ -0,0 +1,73 @@
+//===------ CodeCompletion.h - Code Completion for ClangRepl -------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file defines the classes which perform code completion at the REPL.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef LLVM_CLANG_INTERPRETER_CODE_COMPLETION_H
+#define LLVM_CLANG_INTERPRETER_CODE_COMPLETION_H
+#include "clang/Sema/CodeCompleteConsumer.h"
+#include "llvm/LineEditor/LineEditor.h"
+
+#include <memory>
+#include <vector>
+
+namespace clang {
+class Interpreter;
+class IncrementalCompilerBuilder;
+
+/// Returns the code-completion options used by the REPL (code patterns,
+/// macros, globals and brief comments enabled).
+clang::CodeCompleteOptions getClangCompleteOpts();
+
+/// Collects the completion results that Sema reports (named declarations
+/// and keywords) into a vector owned by the caller.
+class ReplCompletionConsumer : public CodeCompleteConsumer {
+public:
+  explicit ReplCompletionConsumer(std::vector<CodeCompletionResult> &Results)
+      : CodeCompleteConsumer(getClangCompleteOpts()),
+        CCAllocator(std::make_shared<GlobalCodeCompletionAllocator>()),
+        CCTUInfo(CCAllocator), Results(Results) {}
+  void ProcessCodeCompleteResults(class Sema &S, CodeCompletionContext Context,
+                                  CodeCompletionResult *InResults,
+                                  unsigned NumResults) final;
+
+  clang::CodeCompletionAllocator &getAllocator() override {
+    return *CCAllocator;
+  }
+
+  clang::CodeCompletionTUInfo &getCodeCompletionTUInfo() override {
+    return CCTUInfo;
+  }
+
+private:
+  std::shared_ptr<GlobalCodeCompletionAllocator> CCAllocator;
+  CodeCompletionTUInfo CCTUInfo;
+  /// Destination for results; must outlive this consumer.
+  std::vector<CodeCompletionResult> &Results;
+};
+
+/// A \c llvm::LineEditor list completer for the REPL. Each call completes
+/// the current line in a fresh code-completion interpreter that can see the
+/// declarations of \c MainInterp.
+struct ReplListCompleter {
+  IncrementalCompilerBuilder &CB;
+  Interpreter &MainInterp;
+  ReplListCompleter(IncrementalCompilerBuilder &CB, Interpreter &Interp)
+      : CB(CB), MainInterp(Interp) {}
+  std::vector<llvm::LineEditor::Completion> operator()(llvm::StringRef Buffer,
+                                                       size_t Pos) const;
+
+private:
+  std::vector<StringRef>
+  toCodeCompleteStrings(const std::vector<CodeCompletionResult> &Results) const;
+};
+
+} // namespace clang
+#endif // LLVM_CLANG_INTERPRETER_CODE_COMPLETION_H
_______________________________________________
cfe-commits mailing list
cfe-commits@lists.llvm.org
https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits

Reply via email to