saiislam updated this revision to Diff 378630.
saiislam marked 4 inline comments as done.
saiislam added a comment.

1. Renamed the option from `path` to `--nvlink-command`.
2. Command-line arguments are now parsed with the `llvm::cl` API (`cl::ParseCommandLineOptions`) instead of by hand.


Repository:
  rG LLVM Github Monorepo

CHANGES SINCE LAST ACTION
  https://reviews.llvm.org/D111488/new/

https://reviews.llvm.org/D111488

Files:
  clang/lib/Driver/ToolChains/Cuda.cpp
  clang/tools/clang-nvlink-wrapper/ClangNvlinkWrapper.cpp

Index: clang/tools/clang-nvlink-wrapper/ClangNvlinkWrapper.cpp
===================================================================
--- clang/tools/clang-nvlink-wrapper/ClangNvlinkWrapper.cpp
+++ clang/tools/clang-nvlink-wrapper/ClangNvlinkWrapper.cpp
@@ -25,6 +25,7 @@
 /// 2. nvlink -o a.out-openmp-nvptx64 /tmp/a.cubin /tmp/b.cubin
 //===---------------------------------------------------------------------===//
 
+#include "clang/Basic/Version.h"
 #include "llvm/Object/Archive.h"
 #include "llvm/Support/CommandLine.h"
 #include "llvm/Support/Errc.h"
@@ -41,6 +42,19 @@
 
 static cl::opt<bool> Help("h", cl::desc("Alias for -help"), cl::Hidden);
 
+// Mark all our options with this category; everything else (except for -help)
+// will be hidden.
+static cl::OptionCategory
+    ClangNvlinkWrapperCategory("clang-nvlink-wrapper options");
+
+// Explicit path of the nvlink binary to run. When left empty, main() falls
+// back to looking up "nvlink" in PATH.
+static cl::opt<std::string> NvlinkUserPath(
+    "nvlink-command", cl::value_desc("path"),
+    cl::desc("Path of the nvlink binary (default: search for 'nvlink' in PATH)"),
+    cl::cat(ClangNvlinkWrapperCategory));
+
+// Sink for every argument not matched by a registered option. These are not
+// interpreted here: main() scans them for ".a" archives (whose cubins are
+// extracted) and forwards everything else to nvlink unchanged.
+static cl::list<std::string>
+    NVArgs(cl::Sink, cl::desc("<options to be passed to nvlink>..."));
+
 static Error runNVLink(std::string NVLinkPath,
                        SmallVectorImpl<std::string> &Args) {
   std::vector<StringRef> NVLArgs;
@@ -119,8 +133,20 @@
   return Error::success();
 }
 
+// Version printer registered with cl::SetVersionPrinter in main(); prints the
+// full Clang tool version string for this tool followed by a newline.
+static void PrintVersion(raw_ostream &OS) {
+  // Use this tool's own name (previously copy-pasted as clang-offload-bundler).
+  OS << clang::getClangToolFullVersion("clang-nvlink-wrapper") << '\n';
+}
+
 int main(int argc, const char **argv) {
   sys::PrintStackTraceOnErrorSignal(argv[0]);
+  cl::SetVersionPrinter(PrintVersion);
+  cl::HideUnrelatedOptions(ClangNvlinkWrapperCategory);
+  cl::ParseCommandLineOptions(
+      argc, argv,
+      "A wrapper tool over nvlink program. It transparently passes every \n"
+      "input option and objects to nvlink except archive files and path of \n"
+      "nvlink binary. It reads each input archive file to extract archived \n"
+      "cubin files as temporary files.\n");
 
   if (Help) {
     cl::PrintHelpMessage();
@@ -132,12 +158,7 @@
     exit(1);
   };
 
-  ErrorOr<std::string> NvlinkPath = sys::findProgramByName("nvlink");
-  if (!NvlinkPath) {
-    reportError(createStringError(NvlinkPath.getError(),
-                                  "unable to find 'nvlink' in path"));
-  }
-
+  std::string NvlinkPath;
   SmallVector<const char *, 0> Argv(argv, argv + argc);
   SmallVector<std::string, 0> ArgvSubst;
   SmallVector<std::string, 0> TmpFiles;
@@ -145,8 +166,8 @@
   StringSaver Saver(Alloc);
   cl::ExpandResponseFiles(Saver, cl::TokenizeGNUCommandLine, Argv);
 
-  for (size_t i = 1; i < Argv.size(); ++i) {
-    std::string Arg = Argv[i];
+  for (size_t i = 0; i < NVArgs.size(); ++i) {
+    std::string Arg = NVArgs[i];
     if (sys::path::extension(Arg) == ".a") {
       if (Error Err = extractArchiveFiles(Arg, ArgvSubst, TmpFiles))
         reportError(std::move(Err));
@@ -155,7 +176,19 @@
     }
   }
 
-  if (Error Err = runNVLink(NvlinkPath.get(), ArgvSubst))
+  NvlinkPath = NvlinkUserPath;
+
+  // If user hasn't specified nvlink binary then search it in PATH
+  if (NvlinkPath.empty()) {
+    ErrorOr<std::string> NvlinkPathErr = sys::findProgramByName("nvlink");
+    if (!NvlinkPathErr) {
+      reportError(createStringError(NvlinkPathErr.getError(),
+                                    "unable to find 'nvlink' in path"));
+    }
+    NvlinkPath = NvlinkPathErr.get();
+  }
+
+  if (Error Err = runNVLink(NvlinkPath, ArgvSubst))
     reportError(std::move(Err));
   if (Error Err = cleanupTmpFiles(TmpFiles))
     reportError(std::move(Err));
Index: clang/lib/Driver/ToolChains/Cuda.cpp
===================================================================
--- clang/lib/Driver/ToolChains/Cuda.cpp
+++ clang/lib/Driver/ToolChains/Cuda.cpp
@@ -613,6 +613,12 @@
   AddStaticDeviceLibsLinking(C, *this, JA, Inputs, Args, CmdArgs, "nvptx", GPUArch,
                       false, false);
 
+  // Find nvlink and pass it as "--nvlink-command=" argument of clang-nvlink-wrapper.
+  auto NvlinkBin = getToolChain().GetProgramPath("nvlink");
+  const char *NvlinkPath =
+      Args.MakeArgString(Twine("--nvlink-command=" + NvlinkBin));
+  CmdArgs.push_back(NvlinkPath);
+
   const char *Exec =
       Args.MakeArgString(getToolChain().GetProgramPath("clang-nvlink-wrapper"));
   C.addCommand(std::make_unique<Command>(
_______________________________________________
cfe-commits mailing list
cfe-commits@lists.llvm.org
https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits

Reply via email to