================
@@ -468,84 +468,119 @@ static Error runAOTCompile(StringRef InputFile, StringRef OutputFile,
   return createStringError(inconvertibleErrorCode(), "Unsupported arch");
 }
 
+/// Function attribute carried by sycl_external functions. isEntryPoint
+/// treats functions that have it as entry points unless only kernels are
+/// requested.
+static constexpr char AttrSYCLModuleId[] = "sycl-module-id";
+
 /// SYCL device code module split mode.
+/// Spelled "source", "kernel" and "none" on the command line
+/// (see convertStringToSplitMode).
 enum class IRSplitMode {
+  SPLIT_PER_TU,     // one module per translation unit
   SPLIT_PER_KERNEL, // one module per kernel
   SPLIT_NONE        // no splitting
 };
 
-/// Parses the value of \p -module-split-mode.
+/// Parses the value of \p --module-split-mode.
+/// Returns std::nullopt if \p S is not one of "source", "kernel" or "none".
 static std::optional<IRSplitMode> convertStringToSplitMode(StringRef S) {
   return StringSwitch<std::optional<IRSplitMode>>(S)
+      .Case("source", IRSplitMode::SPLIT_PER_TU)
       .Case("kernel", IRSplitMode::SPLIT_PER_KERNEL)
       .Case("none", IRSplitMode::SPLIT_NONE)
       .Default(std::nullopt);
 }
 
+/// Returns the --module-split-mode spelling of \p Mode (the inverse of
+/// convertStringToSplitMode).
+static StringRef splitModeToString(IRSplitMode Mode) {
+  switch (Mode) {
+  case IRSplitMode::SPLIT_PER_TU:
+    return "source";
+  case IRSplitMode::SPLIT_PER_KERNEL:
+    return "kernel";
+  case IRSplitMode::SPLIT_NONE:
+    return "none";
+  }
+  // Every IRSplitMode enumerator returns above; only an out-of-range value
+  // can reach this point.
+  llvm_unreachable("bad split mode");
+}
+
 /// Result of splitting a device module: the bitcode file path and the
 /// serialized symbol table for each device image.
 struct SplitModule {
+  /// Path of the bitcode file holding this device image.
   SmallString<256> ModuleFilePath;
+  /// Serialized symbol table of this device image.
   SmallString<0> Symbols;
 };
 
-static bool isEntryPoint(const Function &F) {
-  return !F.isDeclaration() && F.hasKernelCallingConv();
+/// Returns true if \p F should be treated as a module entry point.
+///
+/// Declarations are never entry points, and kernels (functions with a kernel
+/// calling convention) always are. When \p EmitOnlyKernelsAsEntryPoints is
+/// false, functions carrying the AttrSYCLModuleId attribute are entry points
+/// as well.
+static bool isEntryPoint(const Function &F, bool EmitOnlyKernelsAsEntryPoints) {
+  if (F.isDeclaration())
+    return false;
+  if (F.hasKernelCallingConv())
+    return true;
+  if (EmitOnlyKernelsAsEntryPoints)
+    return false;
+  // sycl_external functions carry the "sycl-module-id" attribute.
+  // This branch is not reachable while EmitOnlyKernelsAsEntryPoints is
+  // hardcoded to true (see TODO in runSYCLLink).
+  return F.hasFnAttribute(AttrSYCLModuleId);
 }
 
-/// Collect kernel names from \p M and serialize them into a symbol table.
-static SmallString<0> collectSymbols(const Module &M) {
-  SmallVector<StringRef> KernelNames;
+/// Collect entry point names from \p M and serialize them into a symbol table.
+static SmallString<0> collectSymbols(const Module &M,
----------------
bader wrote:

```suggestion
static SmallString<0> collectEntryPoints(const Module &M,
```
The doc comment states that the function collects only entry points, not all
symbols, so the function name should say the same.

https://github.com/llvm/llvm-project/pull/196435
_______________________________________________
cfe-commits mailing list
[email protected]
https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits

Reply via email to