capfredf updated this revision to Diff 537766.
capfredf edited the summary of this revision.
capfredf added a comment.

changes per discussions


Repository:
  rG LLVM Github Monorepo

CHANGES SINCE LAST ACTION
  https://reviews.llvm.org/D154382/new/

https://reviews.llvm.org/D154382

Files:
  clang/include/clang/Interpreter/CodeCompletion.h
  clang/include/clang/Interpreter/Interpreter.h
  clang/lib/Interpreter/CMakeLists.txt
  clang/lib/Interpreter/CodeCompletion.cpp
  clang/lib/Interpreter/ExternalSource.cpp
  clang/lib/Interpreter/ExternalSource.h
  clang/lib/Interpreter/IncrementalParser.cpp
  clang/lib/Interpreter/IncrementalParser.h
  clang/lib/Interpreter/Interpreter.cpp
  clang/tools/clang-repl/ClangRepl.cpp
  clang/unittests/Interpreter/CMakeLists.txt
  clang/unittests/Interpreter/CodeCompletionTest.cpp

Index: clang/unittests/Interpreter/CodeCompletionTest.cpp
===================================================================
--- /dev/null
+++ clang/unittests/Interpreter/CodeCompletionTest.cpp
@@ -0,0 +1,61 @@
+#include "clang/Interpreter/CodeCompletion.h"
+#include "clang/Interpreter/Interpreter.h"
+
+#include "llvm/LineEditor/LineEditor.h"
+
+#include "clang/Frontend/CompilerInstance.h"
+#include "llvm/Support/Error.h"
+
+#include "gmock/gmock.h"
+#include "gtest/gtest.h"
+
+using namespace clang;
+namespace {
+clang::IncrementalCompilerBuilder CB; // Shared with the completers below.
+/// Build a fresh C++ interpreter from the file-wide builder \c CB.
+std::unique_ptr<Interpreter> createInterpreter() {
+  std::unique_ptr<CompilerInstance> CI = cantFail(CB.CreateCpp());
+  return cantFail(Interpreter::create(std::move(CI)));
+}
+
+TEST(CodeCompletionTest, Sanity) {
+  auto Interp = createInterpreter();
+  if (auto R = Interp->ParseAndExecute("int foo = 12;")) {
+    // Don't silently pass when setup fails; report it as a skip.
+    GTEST_SKIP() << llvm::toString(std::move(R));
+  }
+  auto Completer = ReplListCompleter(CB, *Interp);
+  std::vector<llvm::LineEditor::Completion> comps =
+      Completer(std::string("f"), 1);
+  ASSERT_EQ((size_t)2, comps.size()); // foo and float
+  EXPECT_EQ(comps[0].TypedText, std::string("oo"));
+}
+
+TEST(CodeCompletionTest, SanityNoneValid) {
+  auto Interp = createInterpreter();
+  if (auto R = Interp->ParseAndExecute("int foo = 12;")) {
+    // Don't silently pass when setup fails; report it as a skip.
+    GTEST_SKIP() << llvm::toString(std::move(R));
+  }
+  auto Completer = ReplListCompleter(CB, *Interp);
+  std::vector<llvm::LineEditor::Completion> comps =
+      Completer(std::string("babanana"), 8);
+  EXPECT_EQ((size_t)0, comps.size()); // nothing starts with "babanana"
+}
+
+TEST(CodeCompletionTest, TwoDecls) {
+  auto Interp = createInterpreter();
+  if (auto R = Interp->ParseAndExecute("int application = 12;")) {
+    // Don't silently pass when setup fails; report it as a skip.
+    GTEST_SKIP() << llvm::toString(std::move(R));
+  }
+  if (auto R = Interp->ParseAndExecute("int apple = 12;")) {
+    // Don't silently pass when setup fails; report it as a skip.
+    GTEST_SKIP() << llvm::toString(std::move(R));
+  }
+  auto Completer = ReplListCompleter(CB, *Interp);
+  std::vector<llvm::LineEditor::Completion> comps =
+      Completer(std::string("app"), 3);
+  EXPECT_EQ((size_t)2, comps.size()); // application and apple
+}
+} // anonymous namespace
Index: clang/unittests/Interpreter/CMakeLists.txt
===================================================================
--- clang/unittests/Interpreter/CMakeLists.txt
+++ clang/unittests/Interpreter/CMakeLists.txt
@@ -9,6 +9,7 @@
 add_clang_unittest(ClangReplInterpreterTests
   IncrementalProcessingTest.cpp
   InterpreterTest.cpp
+  CodeCompletionTest.cpp
   )
 target_link_libraries(ClangReplInterpreterTests PUBLIC
   clangAST
Index: clang/tools/clang-repl/ClangRepl.cpp
===================================================================
--- clang/tools/clang-repl/ClangRepl.cpp
+++ clang/tools/clang-repl/ClangRepl.cpp
@@ -13,6 +13,7 @@
 #include "clang/Basic/Diagnostic.h"
 #include "clang/Frontend/CompilerInstance.h"
 #include "clang/Frontend/FrontendDiagnostic.h"
+#include "clang/Interpreter/CodeCompletion.h"
 #include "clang/Interpreter/Interpreter.h"
 
 #include "llvm/ExecutionEngine/Orc/LLJIT.h"
@@ -155,8 +156,8 @@
 
   if (OptInputs.empty()) {
     llvm::LineEditor LE("clang-repl");
-    // FIXME: Add LE.setListCompleter
     std::string Input;
+    LE.setListCompleter(clang::ReplListCompleter(CB, *Interp));
     while (std::optional<std::string> Line = LE.readLine()) {
       llvm::StringRef L = *Line;
       L = L.trim();
@@ -168,10 +169,10 @@
       }
 
       Input += L;
-
       if (Input == R"(%quit)") {
         break;
-      } else if (Input == R"(%undo)") {
+      }
+      if (Input == R"(%undo)") {
         if (auto Err = Interp->Undo()) {
           llvm::logAllUnhandledErrors(std::move(Err), llvm::errs(), "error: ");
           HasError = true;
Index: clang/lib/Interpreter/Interpreter.cpp
===================================================================
--- clang/lib/Interpreter/Interpreter.cpp
+++ clang/lib/Interpreter/Interpreter.cpp
@@ -14,6 +14,7 @@
 #include "clang/Interpreter/Interpreter.h"
 
 #include "DeviceOffload.h"
+#include "ExternalSource.h"
 #include "IncrementalExecutor.h"
 #include "IncrementalParser.h"
 
@@ -33,6 +34,7 @@
 #include "clang/Driver/Tool.h"
 #include "clang/Frontend/CompilerInstance.h"
 #include "clang/Frontend/TextDiagnosticBuffer.h"
+#include "clang/Interpreter/CodeCompletion.h"
 #include "clang/Interpreter/Value.h"
 #include "clang/Lex/PreprocessorOptions.h"
 #include "clang/Sema/Lookup.h"
@@ -127,7 +129,6 @@
 
   Clang->getFrontendOpts().DisableFree = false;
   Clang->getCodeGenOpts().DisableFree = false;
-
   return std::move(Clang);
 }
 
@@ -237,6 +238,18 @@
                                                    *TSCtx->getContext(), Err);
 }
 
+Interpreter::Interpreter(std::unique_ptr<CompilerInstance> CI, llvm::Error &Err,
+                         std::vector<CodeCompletionResult> &CompResults,
+                         const CompilerInstance *ParentCI) {
+  llvm::ErrorAsOutParameter EAO(&Err);
+  auto LLVMCtx = std::make_unique<llvm::LLVMContext>();
+  TSCtx = std::make_unique<llvm::orc::ThreadSafeContext>(std::move(LLVMCtx));
+  // ReplCompletionConsumer fills CompResults; CI takes ownership of it.
+  CI->setCodeCompletionConsumer(new ReplCompletionConsumer(CompResults));
+  IncrParser = std::make_unique<IncrementalParser>(
+      *this, std::move(CI), *TSCtx->getContext(), Err, ParentCI);
+}
+
 Interpreter::~Interpreter() {
   if (IncrExecutor) {
     if (llvm::Error Err = IncrExecutor->cleanUp())
@@ -288,6 +301,34 @@
   return std::move(Interp);
 }
 
+llvm::Expected<std::unique_ptr<Interpreter>>
+Interpreter::createForCodeCompletion(
+    IncrementalCompilerBuilder &CB, const CompilerInstance *ParentCI,
+    std::vector<CodeCompletionResult> &CompResults) {
+  llvm::Expected<std::unique_ptr<CompilerInstance>> CIOrErr = CB.CreateCpp();
+  if (!CIOrErr)
+    return CIOrErr.takeError();
+  CompilerInstance &NewCI = **CIOrErr;
+
+  // Configure the instance for completion: single-file parse mode, no
+  // spell-checking, no delayed template parsing, REPL completion options.
+  NewCI.getPreprocessorOpts().SingleFileParseMode = true;
+  auto &LangOpts = NewCI.getLangOpts();
+  LangOpts.SpellChecking = false;
+  LangOpts.DelayedTemplateParsing = false;
+  NewCI.getFrontendOpts().CodeCompleteOpts = getClangCompleteOpts();
+
+  llvm::Error Err = llvm::Error::success();
+  std::unique_ptr<Interpreter> Interp(
+      new Interpreter(std::move(*CIOrErr), Err, CompResults, ParentCI));
+  if (Err)
+    return std::move(Err);
+
+  // Remember how many PTUs exist once the interpreter has been set up.
+  Interp->InitPTUSize = Interp->IncrParser->getPTUs().size();
+  return std::move(Interp);
+}
+
 llvm::Expected<std::unique_ptr<Interpreter>>
 Interpreter::createWithCUDA(std::unique_ptr<CompilerInstance> CI,
                             std::unique_ptr<CompilerInstance> DCI) {
@@ -738,6 +779,12 @@
   return Result.get();
 }
 
+std::string Interpreter::getAllInput() const { return IncrParser->AllInput; }
+/// Parse \p Input for completion, with the completion point at (Line, Col).
+void Interpreter::CodeComplete(llvm::StringRef Input, size_t Col, size_t Line) {
+  IncrParser->ParseForCodeCompletion(Input, Col, Line);
+}
+
 // Temporary rvalue struct that need special care.
 REPL_EXTERNAL_VISIBILITY void *
 __clang_Interpreter_SetValueWithAlloc(void *This, void *OutVal,
Index: clang/lib/Interpreter/IncrementalParser.h
===================================================================
--- clang/lib/Interpreter/IncrementalParser.h
+++ clang/lib/Interpreter/IncrementalParser.h
@@ -24,7 +24,7 @@
 #include <memory>
 namespace llvm {
 class LLVMContext;
-}
+} // namespace llvm
 
 namespace clang {
 class ASTConsumer;
@@ -60,9 +60,14 @@
   IncrementalParser();
 
 public:
+  // This is not necessarily needed.
+  // We can probably use external source to replace `AllInput`
+  std::string AllInput;
+
   IncrementalParser(Interpreter &Interp,
                     std::unique_ptr<CompilerInstance> Instance,
-                    llvm::LLVMContext &LLVMCtx, llvm::Error &Err);
+                    llvm::LLVMContext &LLVMCtx, llvm::Error &Err,
+                    const CompilerInstance *ParentCI = nullptr);
   virtual ~IncrementalParser();
 
   CompilerInstance *getCI() { return CI.get(); }
@@ -72,6 +77,7 @@
   ///\returns a \c PartialTranslationUnit which holds information about the
   /// \c TranslationUnitDecl and \c llvm::Module corresponding to the input.
   virtual llvm::Expected<PartialTranslationUnit &> Parse(llvm::StringRef Input);
+  void ParseForCodeCompletion(llvm::StringRef Input, size_t Col, size_t Line);
 
   /// Uses the CodeGenModule mangled name cache and avoids recomputing.
   ///\returns the mangled name of a \c GD.
@@ -85,6 +91,12 @@
 
 private:
   llvm::Expected<PartialTranslationUnit &> ParseOrWrapTopLevelDecl();
+
+  llvm::Expected<PartialTranslationUnit &> ParseForPTU(FileID FID,
+                                                       SourceLocation SrcLoc);
+
+  std::pair<FileID, SourceLocation> createSourceFile(llvm::StringRef SourceName,
+                                                     llvm::StringRef Input);
 };
 } // end namespace clang
 
Index: clang/lib/Interpreter/IncrementalParser.cpp
===================================================================
--- clang/lib/Interpreter/IncrementalParser.cpp
+++ clang/lib/Interpreter/IncrementalParser.cpp
@@ -11,6 +11,7 @@
 //===----------------------------------------------------------------------===//
 
 #include "IncrementalParser.h"
+#include "ExternalSource.h"
 #include "clang/AST/DeclContextInternals.h"
 #include "clang/CodeGen/BackendUtil.h"
 #include "clang/CodeGen/CodeGenAction.h"
@@ -115,10 +116,12 @@
 class IncrementalAction : public WrapperFrontendAction {
 private:
   bool IsTerminating = false;
+  const CompilerInstance *ParentCI;
 
 public:
   IncrementalAction(CompilerInstance &CI, llvm::LLVMContext &LLVMCtx,
-                    llvm::Error &Err)
+                    llvm::Error &Err,
+                    const CompilerInstance *ParentCI = nullptr)
       : WrapperFrontendAction([&]() {
           llvm::ErrorAsOutParameter EAO(&Err);
           std::unique_ptr<FrontendAction> Act;
@@ -152,7 +155,8 @@
             break;
           }
           return Act;
-        }()) {}
+        }()),
+        ParentCI(ParentCI) {}
   FrontendAction *getWrapped() const { return WrappedAction.get(); }
   TranslationUnitKind getTranslationUnitKind() override {
     return TU_Incremental;
@@ -175,6 +179,17 @@
     Preprocessor &PP = CI.getPreprocessor();
     PP.EnterMainSourceFile();
 
+    if (ParentCI) {
+      ExternalSource *myExternalSource = new ExternalSource(
+          CI.getASTContext(), CI.getFileManager(), ParentCI->getASTContext(),
+          ParentCI->getFileManager());
+      llvm::IntrusiveRefCntPtr<ExternalASTSource> astContextExternalSource(
+          myExternalSource);
+      CI.getASTContext().setExternalSource(astContextExternalSource);
+      CI.getASTContext().getTranslationUnitDecl()->setHasExternalVisibleStorage(
+          true);
+    }
+
     if (!CI.hasSema())
       CI.createSema(getTranslationUnitKind(), CompletionConsumer);
   }
@@ -206,10 +221,11 @@
 IncrementalParser::IncrementalParser(Interpreter &Interp,
                                      std::unique_ptr<CompilerInstance> Instance,
                                      llvm::LLVMContext &LLVMCtx,
-                                     llvm::Error &Err)
+                                     llvm::Error &Err,
+                                     const CompilerInstance *ParentCI)
     : CI(std::move(Instance)) {
   llvm::ErrorAsOutParameter EAO(&Err);
-  Act = std::make_unique<IncrementalAction>(*CI, LLVMCtx, Err);
+  Act = std::make_unique<IncrementalAction>(*CI, LLVMCtx, Err, ParentCI);
   if (Err)
     return;
   CI->ExecuteAction(*Act);
@@ -305,22 +321,49 @@
   return LastPTU;
 }
 
-llvm::Expected<PartialTranslationUnit &>
-IncrementalParser::Parse(llvm::StringRef input) {
+void IncrementalParser::ParseForCodeCompletion(llvm::StringRef input,
+                                               size_t Col, size_t Line) {
   Preprocessor &PP = CI->getPreprocessor();
   assert(PP.isIncrementalProcessingEnabled() && "Not in incremental mode!?");
 
   std::ostringstream SourceName;
-  SourceName << "input_line_" << InputCount++;
+  SourceName << "input_line_[Completion]";
+
+  // Back the completion buffer with a virtual file and look up its entry.
+  auto [FID, SrcLoc] = createSourceFile(SourceName.str(), input);
+  auto FE = CI->getSourceManager().getFileEntryRefForID(FID);
+  // Without a file entry there is nowhere to set the completion point.
+  if (!FE)
+    return;
+
+  // Mark (Line, Col) of the buffer as the code-completion point.
+  PP.SetCodeCompletionPoint(*FE, Line, Col);
+
+  // SrcLoc is only used for diagnostics.
+  if (PP.EnterSourceFile(FID, /*DirLookup=*/nullptr, SrcLoc))
+    return;
+
+  // Completion is best-effort: the buffer normally ends in unfinished code,
+  // so a failed parse is expected and deliberately not reported. Candidates
+  // reach the caller through the code-completion consumer, not through the
+  // returned PTU.
+  auto PTU = ParseOrWrapTopLevelDecl();
 
+  if (!PTU)
+    consumeError(PTU.takeError());
+}
+
+std::pair<FileID, SourceLocation>
+IncrementalParser::createSourceFile(llvm::StringRef SourceName,
+                                    llvm::StringRef Input) {
   // Create an uninitialized memory buffer, copy code in and append "\n"
-  size_t InputSize = input.size(); // don't include trailing 0
+  size_t InputSize = Input.size(); // don't include trailing 0
   // MemBuffer size should *not* include terminating zero
   std::unique_ptr<llvm::MemoryBuffer> MB(
       llvm::WritableMemoryBuffer::getNewUninitMemBuffer(InputSize + 1,
                                                         SourceName.str()));
   char *MBStart = const_cast<char *>(MB->getBufferStart());
-  memcpy(MBStart, input.data(), InputSize);
+  memcpy(MBStart, Input.data(), InputSize);
   MBStart[InputSize] = '\n';
 
   SourceManager &SM = CI->getSourceManager();
@@ -330,18 +373,46 @@
   SourceLocation NewLoc = SM.getLocForStartOfFile(SM.getMainFileID());
 
   // Create FileID for the current buffer.
-  FileID FID = SM.createFileID(std::move(MB), SrcMgr::C_User, /*LoadedID=*/0,
-                               /*LoadedOffset=*/0, NewLoc);
+  // FileID FID = SM.createFileID(std::move(MB), SrcMgr::C_User, /*LoadedID=*/0,
+  //                              /*LoadedOffset=*/0, NewLoc);
+
+  const clang::FileEntry *FE = SM.getFileManager().getVirtualFile(
+      SourceName.str(), InputSize, 0 /* mod time*/);
+  SM.overrideFileContents(FE, std::move(MB));
+  FileID FID = SM.createFileID(FE, NewLoc, SrcMgr::C_User);
+  return {FID, NewLoc};
+}
+
+llvm::Expected<PartialTranslationUnit &>
+IncrementalParser::ParseForPTU(FileID FID, SourceLocation SrcLoc) {
+  // Enter the already-created file FID and parse it into a new PTU.
+  Preprocessor &PP = CI->getPreprocessor();
+  assert(PP.isIncrementalProcessingEnabled() && "Not in incremental mode!?");
 
-  // NewLoc only used for diags.
-  if (PP.EnterSourceFile(FID, /*DirLookup=*/nullptr, NewLoc))
+  // SrcLoc is only used for diagnostics.
+  if (PP.EnterSourceFile(FID, /*DirLookup=*/nullptr, SrcLoc))
     return llvm::make_error<llvm::StringError>("Parsing failed. "
                                                "Cannot enter source file.",
                                                std::error_code());
 
   auto PTU = ParseOrWrapTopLevelDecl();
   if (!PTU)
     return PTU.takeError();
+  return *PTU;
+}
+
+llvm::Expected<PartialTranslationUnit &>
+IncrementalParser::Parse(llvm::StringRef input) {
+  Preprocessor &PP = CI->getPreprocessor();
+  std::ostringstream SourceName;
+  SourceName << "input_line_" << InputCount++;
+
+  auto [FID, SrcLoc] = createSourceFile(SourceName.str(), input);
+  auto PTU = ParseForPTU(FID, SrcLoc);
+
+  if (!PTU) {
+    return std::move(PTU.takeError());
+  }
 
   if (PP.getLangOpts().DelayedTemplateParsing) {
     // Microsoft-specific:
Index: clang/lib/Interpreter/ExternalSource.h
===================================================================
--- /dev/null
+++ clang/lib/Interpreter/ExternalSource.h
@@ -0,0 +1,38 @@
+//==----- ExternalSource.h - External AST Source for Code Completion ------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file defines components that make declarations parsed and executed by
+// the interpreter visible to the context where code completion is being
+// triggered.
+//
+//===----------------------------------------------------------------------===//
+#ifndef LLVM_CLANG_LIB_INTERPRETER_EXTERNALSOURCE_H
+#define LLVM_CLANG_LIB_INTERPRETER_EXTERNALSOURCE_H
+#include "clang/AST/ExternalASTSource.h"
+
+namespace clang {
+class ASTContext;
+class FileManager;
+class ASTImporter;
+
+/// Exposes the parent TU's declarations to the child TU via an ASTImporter.
+class ExternalSource : public clang::ExternalASTSource {
+  ASTContext &ChildASTCtxt;
+  TranslationUnitDecl *ChildTUDeclCtxt;
+  ASTContext &ParentASTCtxt;
+  TranslationUnitDecl *ParentTUDeclCtxt;
+  std::unique_ptr<ASTImporter> Importer;
+public:
+  ExternalSource(ASTContext &ChildASTCtxt, FileManager &ChildFM,
+                 ASTContext &ParentASTCtxt, FileManager &ParentFM);
+  bool FindExternalVisibleDeclsByName(const DeclContext *DC,
+                                      DeclarationName Name) override;
+  void completeVisibleDeclsMap(const DeclContext *ChildDeclContext) override;
+};
+} // namespace clang
+#endif // LLVM_CLANG_LIB_INTERPRETER_EXTERNALSOURCE_H
Index: clang/lib/Interpreter/ExternalSource.cpp
===================================================================
--- /dev/null
+++ clang/lib/Interpreter/ExternalSource.cpp
@@ -0,0 +1,77 @@
+//===--- ExternalSource.cpp - External AST Source for Code Completion ----===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// The file implements classes that make declarations parsed and executed by the
+// interpreter visible to the context where code completion is being triggered.
+//
+//===----------------------------------------------------------------------===//
+
+#include "ExternalSource.h"
+#include "clang/AST/ASTImporter.h"
+#include "clang/AST/DeclarationName.h"
+#include "clang/Basic/IdentifierTable.h"
+
+namespace clang {
+ExternalSource::ExternalSource(ASTContext &ChildASTCtxt, FileManager &ChildFM,
+                               ASTContext &ParentASTCtxt, FileManager &ParentFM)
+    : ChildASTCtxt(ChildASTCtxt),
+      ChildTUDeclCtxt(ChildASTCtxt.getTranslationUnitDecl()),
+      ParentASTCtxt(ParentASTCtxt),
+      ParentTUDeclCtxt(ParentASTCtxt.getTranslationUnitDecl()),
+      // Imports go from the parent (From) context into the child (To) one,
+      // with minimal import enabled.
+      Importer(std::make_unique<ASTImporter>(ChildASTCtxt, ChildFM,
+                                             ParentASTCtxt, ParentFM,
+                                             /*MinimalImport=*/true)) {}
+
+bool ExternalSource::FindExternalVisibleDeclsByName(const DeclContext *DC,
+                                                    DeclarationName Name) {
+  // Map Name into the parent's identifier table. Use find() rather than get()
+  // so that looking up an unknown name does not add it to the parent.
+  IdentifierTable &ParentIdTable = ParentASTCtxt.Idents;
+  auto It = ParentIdTable.find(Name.getAsString());
+  if (It == ParentIdTable.end())
+    return false;
+
+  // Only report whether the parent TU declares the name; the declarations
+  // themselves are imported by completeVisibleDeclsMap().
+  DeclContext::lookup_result Found =
+      ParentTUDeclCtxt->lookup(DeclarationName(It->second));
+  return !Found.empty();
+}
+
+void ExternalSource::completeVisibleDeclsMap(
+    const DeclContext *ChildDeclContext) {
+  assert(ChildDeclContext && ChildDeclContext == ChildTUDeclCtxt &&
+         "No child decl context!");
+  if (!ChildDeclContext->hasExternalVisibleStorage())
+    return;
+
+  // Import each parent TU decl's NamedDecls and expose them in the child TU.
+  for (auto *ParentTU = ParentTUDeclCtxt; ParentTU;
+       ParentTU = ParentTU->getPreviousDecl()) {
+    for (Decl *D : ParentTU->decls()) {
+      auto *ND = llvm::dyn_cast<NamedDecl>(D);
+      if (!ND)
+        continue;
+      llvm::Expected<Decl *> Imported = Importer->Import(ND);
+      if (!Imported) {
+        // Best-effort: a declaration that fails to import is skipped.
+        llvm::consumeError(Imported.takeError());
+        continue;
+      }
+      if (auto *ImportedND = llvm::dyn_cast<NamedDecl>(*Imported))
+        SetExternalVisibleDeclsForName(ChildDeclContext,
+                                       ImportedND->getDeclName(), ImportedND);
+    }
+    // NOTE(review): clears lexical storage though the guard checks visible.
+    ChildDeclContext->setHasExternalLexicalStorage(false);
+  }
+}
+
+} // namespace clang
Index: clang/lib/Interpreter/CodeCompletion.cpp
===================================================================
--- /dev/null
+++ clang/lib/Interpreter/CodeCompletion.cpp
@@ -0,0 +1,127 @@
+//===------ CodeCompletion.cpp - Code Completion for ClangRepl -------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements the classes which performs code completion at the REPL.
+//
+//===----------------------------------------------------------------------===//
+
+#include "clang/Interpreter/CodeCompletion.h"
+#include "clang/Frontend/CompilerInstance.h"
+#include "clang/Interpreter/Interpreter.h"
+#include "clang/Lex/PreprocessorOptions.h"
+#include "clang/Sema/CodeCompleteOptions.h"
+#include "clang/Sema/Sema.h"
+
+namespace clang {
+
+clang::CodeCompleteOptions getClangCompleteOpts() {
+  clang::CodeCompleteOptions Opts;
+  // Start from the defaults and additionally offer globals, macros, code
+  // patterns and brief documentation comments.
+  Opts.IncludeGlobals = Opts.IncludeMacros = true;
+  Opts.IncludeCodePatterns = Opts.IncludeBriefComments = true;
+  return Opts;
+}
+
+void ReplCompletionConsumer::ProcessCodeCompleteResults(
+    class Sema &S, CodeCompletionContext Context,
+    CodeCompletionResult *InResults, unsigned NumResults) {
+  // Keep only the results the REPL completer can turn into text: keywords and
+  // declarations whose name is a plain identifier. All other result kinds are
+  // ignored.
+  for (unsigned Idx = 0; Idx != NumResults; ++Idx) {
+    const CodeCompletionResult &Candidate = InResults[Idx];
+
+    bool Wanted = false;
+    if (Candidate.Kind == CodeCompletionResult::RK_Keyword)
+      Wanted = true;
+    else if (Candidate.Kind == CodeCompletionResult::RK_Declaration)
+      Wanted = Candidate.Declaration->getIdentifier() != nullptr;
+
+    if (Wanted)
+      Results.push_back(Candidate);
+  }
+}
+
+std::vector<StringRef> ReplListCompleter::toCodeCompleteStrings(
+    const std::vector<CodeCompletionResult> &Results) const {
+  // Turn each keyword / identifier-named declaration result into its text.
+  std::vector<StringRef> CompletionStrings;
+  CompletionStrings.reserve(Results.size());
+
+  // Iterate by const reference so that no CodeCompletionResult is copied.
+  for (const CodeCompletionResult &Res : Results) {
+    if (Res.Kind == CodeCompletionResult::RK_Keyword) {
+      CompletionStrings.push_back(Res.Keyword);
+    } else if (Res.Kind == CodeCompletionResult::RK_Declaration) {
+      // Only declarations named by a plain identifier produce a string.
+      if (auto *ID = Res.Declaration->getIdentifier())
+        CompletionStrings.push_back(ID->getName());
+    }
+  }
+
+  return CompletionStrings;
+}
+
+std::vector<llvm::LineEditor::Completion>
+ReplListCompleter::operator()(llvm::StringRef Buffer, size_t Pos) const {
+  std::vector<llvm::LineEditor::Completion> Comps;
+  std::vector<CodeCompletionResult> Results;
+  auto Interp = Interpreter::createForCodeCompletion(
+      CB, MainInterp.getCompilerInstance(), Results);
+
+  if (auto Err = Interp.takeError()) {
+    // Log the error and return an empty list of completions.
+    llvm::logAllUnhandledErrors(std::move(Err), llvm::errs(), "error: ");
+    return Comps;
+  }
+
+  std::string AllCodeText =
+      MainInterp.getAllInput() + "\nvoid dummy(){\n" + Buffer.str() + "}";
+
+  // We need to wrap our input because we need `Sema::CodeCompleteOrdinaryName`
+  // to work on code from the REPL in a statement completion context. By
+  // default, `Sema::CodeCompleteOrdinaryName` thinks the input is a regular c++
+  // file. For example,
+  // ```
+  // clang-repl> int foo = 42;
+  // clang-repl> f_
+  // ```
+  //
+  // `Sema::CodeCompleteOrdinaryName` treats the code as
+  //
+  // ```
+  // int foo = 42;
+  // f_
+  // ```
+  //
+  // Since top-level expressions are not supported, `foo` should not be an
+  // option. But in a REPL session, we should be allowed to use `foo` to make a
+  // statement like `foo + 84;`.
+
+  auto Lines = std::count(AllCodeText.begin(), AllCodeText.end(), '\n') + 1;
+
+  (*Interp)->CodeComplete(AllCodeText, Pos + 1, Lines);
+
+  // Filter candidates by the word being completed: the trailing run of
+  // identifier characters before the cursor (e.g. "f" in "foo(f").
+  llvm::StringRef BeforeCursor = Buffer.take_front(Pos);
+  size_t WordStart = BeforeCursor.find_last_not_of(
+      "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789_");
+  llvm::StringRef Prefix = WordStart == llvm::StringRef::npos
+                               ? BeforeCursor
+                               : BeforeCursor.substr(WordStart + 1);
+
+  for (llvm::StringRef Candidate : toCodeCompleteStrings(Results)) {
+    if (Candidate.startswith(Prefix))
+      Comps.push_back(llvm::LineEditor::Completion(
+          Candidate.substr(Prefix.size()).str(), Candidate.str()));
+  }
+  return Comps;
+}
+} // namespace clang
Index: clang/lib/Interpreter/CMakeLists.txt
===================================================================
--- clang/lib/Interpreter/CMakeLists.txt
+++ clang/lib/Interpreter/CMakeLists.txt
@@ -12,7 +12,9 @@
   )
 
 add_clang_library(clangInterpreter
+  CodeCompletion.cpp
   DeviceOffload.cpp
+  ExternalSource.cpp
   IncrementalExecutor.cpp
   IncrementalParser.cpp
   Interpreter.cpp
Index: clang/include/clang/Interpreter/Interpreter.h
===================================================================
--- clang/include/clang/Interpreter/Interpreter.h
+++ clang/include/clang/Interpreter/Interpreter.h
@@ -35,9 +35,11 @@
 
 namespace clang {
 
+class CodeCompletionResult;
 class CompilerInstance;
 class IncrementalExecutor;
 class IncrementalParser;
+class ReplCompletionConsumer;
 
 /// Create a pre-configured \c CompilerInstance for incremental processing.
 class IncrementalCompilerBuilder {
@@ -80,8 +82,12 @@
 
   // An optional parser for CUDA offloading
   std::unique_ptr<IncrementalParser> DeviceParser;
+  std::unique_ptr<ReplCompletionConsumer> CConsumer;
 
   Interpreter(std::unique_ptr<CompilerInstance> CI, llvm::Error &Err);
+  Interpreter(std::unique_ptr<CompilerInstance> CI, llvm::Error &Err,
+              std::vector<CodeCompletionResult> &CompResults,
+              const CompilerInstance *ParentCI = nullptr);
 
   llvm::Error CreateExecutor();
   unsigned InitPTUSize = 0;
@@ -93,13 +99,22 @@
 
 public:
   ~Interpreter();
+
   static llvm::Expected<std::unique_ptr<Interpreter>>
   create(std::unique_ptr<CompilerInstance> CI);
+
   static llvm::Expected<std::unique_ptr<Interpreter>>
   createWithCUDA(std::unique_ptr<CompilerInstance> CI,
                  std::unique_ptr<CompilerInstance> DCI);
+
+  static llvm::Expected<std::unique_ptr<Interpreter>>
+  createForCodeCompletion(IncrementalCompilerBuilder &CB,
+                          const CompilerInstance *ParentCI,
+                          std::vector<CodeCompletionResult> &CompResults);
+
   const ASTContext &getASTContext() const;
   ASTContext &getASTContext();
+  void CodeComplete(llvm::StringRef Input, size_t Col, size_t Line = 1);
   const CompilerInstance *getCompilerInstance() const;
   llvm::Expected<llvm::orc::LLJIT &> getExecutionEngine();
 
@@ -136,6 +151,8 @@
 
   Expr *SynthesizeExpr(Expr *E);
 
+  std::string getAllInput() const;
+
 private:
   size_t getEffectivePTUSize() const;
 
Index: clang/include/clang/Interpreter/CodeCompletion.h
===================================================================
--- /dev/null
+++ clang/include/clang/Interpreter/CodeCompletion.h
@@ -0,0 +1,62 @@
+//===------ CodeCompletion.h - Code Completion for ClangRepl -------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file defines the classes which performs code completion at the REPL.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef LLVM_CLANG_INTERPRETER_CODE_COMPLETION_H
+#define LLVM_CLANG_INTERPRETER_CODE_COMPLETION_H
+#include "clang/Sema/CodeCompleteConsumer.h"
+#include "llvm/LineEditor/LineEditor.h"
+
+namespace clang {
+class Interpreter;
+class IncrementalCompilerBuilder;
+
+clang::CodeCompleteOptions getClangCompleteOpts();
+/// Appends keyword/named-declaration results to a caller-owned vector.
+class ReplCompletionConsumer : public CodeCompleteConsumer {
+public:
+  explicit ReplCompletionConsumer(std::vector<CodeCompletionResult> &Results)
+      : CodeCompleteConsumer(getClangCompleteOpts()),
+        CCAllocator(std::make_shared<GlobalCodeCompletionAllocator>()),
+        CCTUInfo(CCAllocator), Results(Results) {}
+  void ProcessCodeCompleteResults(class Sema &S, CodeCompletionContext Context,
+                                  CodeCompletionResult *InResults,
+                                  unsigned NumResults) final;
+
+  clang::CodeCompletionAllocator &getAllocator() override {
+    return *CCAllocator;
+  }
+
+  clang::CodeCompletionTUInfo &getCodeCompletionTUInfo() override {
+    return CCTUInfo;
+  }
+
+private:
+  std::shared_ptr<GlobalCodeCompletionAllocator> CCAllocator;
+  CodeCompletionTUInfo CCTUInfo;
+  std::vector<CodeCompletionResult> &Results;
+};
+/// LineEditor list completer; each call uses a fresh completion interpreter.
+struct ReplListCompleter {
+  IncrementalCompilerBuilder &CB;
+  Interpreter &MainInterp;
+  ReplListCompleter(IncrementalCompilerBuilder &CB, Interpreter &Interp)
+      : CB(CB), MainInterp(Interp) {}
+  std::vector<llvm::LineEditor::Completion> operator()(llvm::StringRef Buffer,
+                                                       size_t Pos) const;
+
+private:
+  std::vector<StringRef>
+  toCodeCompleteStrings(const std::vector<CodeCompletionResult> &Results) const;
+};
+
+} // namespace clang
+#endif // LLVM_CLANG_INTERPRETER_CODE_COMPLETION_H
_______________________________________________
cfe-commits mailing list
cfe-commits@lists.llvm.org
https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits

Reply via email to