jhuber6 created this revision.
jhuber6 added reviewers: ABataev, JonChesterfield, jdoerfert, tianshilei1992.
Herald added subscribers: guansong, hiraditya, yaxunl.
Herald added a project: All.
jhuber6 requested review of this revision.
Herald added subscribers: llvm-commits, cfe-commits, sstefan1.
Herald added projects: clang, LLVM.

This patch changes the code we generate to enter a target region on the
device. This is in line with the new definition in the runtime that was
added previously. Additionally we implement this in the OpenMPIRBuilder
so that this code can be shared with Flang in the future.


Repository:
  rG LLVM Github Monorepo

https://reviews.llvm.org/D128550

Files:
  clang/lib/CodeGen/CGOpenMPRuntime.cpp
  llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
  llvm/include/llvm/Frontend/OpenMP/OMPKinds.def
  llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp

Index: llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
===================================================================
--- llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
+++ llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
@@ -791,6 +791,38 @@
   Entry->setAlignment(Align(1));
 }
 
+OpenMPIRBuilder::InsertPointTy OpenMPIRBuilder::emitTargetKernel(
+    const LocationDescription &Loc, Value *&Return, Value *Ident,
+    Value *DeviceID, Value *NumTeams, Value *NumThreads, Value *HostPtr,
+    ArrayRef<Value *> KernelArgs, ArrayRef<Value *> NoWaitArgs) {
+  if (!updateToLocation(Loc))
+    return Loc.IP;
+
+  auto *KernelArgsPtr =
+      Builder.CreateAlloca(OpenMPIRBuilder::KernelArgs, nullptr, "kernel_args");
+  for (unsigned I = 0, Size = KernelArgs.size(); I != Size; ++I) {
+    llvm::Value *Arg =
+        Builder.CreateStructGEP(OpenMPIRBuilder::KernelArgs, KernelArgsPtr, I);
+    Builder.CreateAlignedStore(
+        KernelArgs[I], Arg,
+        M.getDataLayout().getPrefTypeAlign(KernelArgs[I]->getType()));
+  }
+
+  bool HasNoWait = !NoWaitArgs.empty();
+  SmallVector<Value *> OffloadingArgs{Ident,      DeviceID, NumTeams,
+                                      NumThreads, HostPtr,  KernelArgsPtr};
+  if (HasNoWait)
+    OffloadingArgs.append(NoWaitArgs.begin(), NoWaitArgs.end());
+
+  Return = Builder.CreateCall(
+      HasNoWait
+          ? getOrCreateRuntimeFunction(M, OMPRTL___tgt_target_kernel_nowait)
+          : getOrCreateRuntimeFunction(M, OMPRTL___tgt_target_kernel),
+      OffloadingArgs);
+
+  return Builder.saveIP();
+}
+
 void OpenMPIRBuilder::emitCancelationCheckImpl(Value *CancelFlag,
                                                omp::Directive CanceledDirective,
                                                FinalizeCallbackTy ExitCB) {
Index: llvm/include/llvm/Frontend/OpenMP/OMPKinds.def
===================================================================
--- llvm/include/llvm/Frontend/OpenMP/OMPKinds.def
+++ llvm/include/llvm/Frontend/OpenMP/OMPKinds.def
@@ -88,6 +88,8 @@
 __OMP_STRUCT_TYPE(Ident, ident_t, Int32, Int32, Int32, Int32, Int8Ptr)
 __OMP_STRUCT_TYPE(OffloadEntry, __tgt_offload_entry, Int8Ptr, Int8Ptr, SizeTy,
                   Int32, Int32)
+__OMP_STRUCT_TYPE(KernelArgs, __tgt_kernel_arguments, Int32, VoidPtrPtr,
+                  VoidPtrPtr, Int64Ptr, Int64Ptr, VoidPtrPtr, VoidPtrPtr)
 __OMP_STRUCT_TYPE(AsyncInfo, __tgt_async_info, Int8Ptr)
 
 #undef __OMP_STRUCT_TYPE
@@ -412,6 +414,10 @@
 __OMP_RTL(__tgt_target_teams_nowait_mapper, false, Int32, IdentPtr, Int64,
           VoidPtr, Int32, VoidPtrPtr, VoidPtrPtr, Int64Ptr, Int64Ptr,
           VoidPtrPtr, VoidPtrPtr, Int32, Int32, Int32, VoidPtr, Int32, VoidPtr)
+__OMP_RTL(__tgt_target_kernel, false, Int32, IdentPtr, Int64, Int32, Int32,
+          VoidPtr, KernelArgsPtr)
+__OMP_RTL(__tgt_target_kernel_nowait, false, Int32, IdentPtr, Int64, Int32,
+          Int32, VoidPtr, KernelArgsPtr, Int32, VoidPtr, Int32, VoidPtr)
 __OMP_RTL(__tgt_register_requires, false, Void, Int64)
 __OMP_RTL(__tgt_target_data_begin_mapper, false, Void, IdentPtr, Int64, Int32, VoidPtrPtr,
           VoidPtrPtr, Int64Ptr, Int64Ptr, VoidPtrPtr, VoidPtrPtr)
@@ -937,6 +943,10 @@
                 ParamAttrs())
 __OMP_RTL_ATTRS(__tgt_target_teams_nowait_mapper, ForkAttrs, AttributeSet(),
                 ParamAttrs())
+__OMP_RTL_ATTRS(__tgt_target_kernel, ForkAttrs, AttributeSet(),
+                ParamAttrs())
+__OMP_RTL_ATTRS(__tgt_target_kernel_nowait, ForkAttrs, AttributeSet(),
+                ParamAttrs())
 __OMP_RTL_ATTRS(__tgt_register_requires, ForkAttrs, AttributeSet(),
                 ParamAttrs())
 __OMP_RTL_ATTRS(__tgt_target_data_begin_mapper, ForkAttrs, AttributeSet(),
Index: llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
===================================================================
--- llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
+++ llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
@@ -821,6 +821,23 @@
                                 omp::Directive CanceledDirective,
                                 FinalizeCallbackTy ExitCB = {});
 
+  /// Generate a target region entry call.
+  ///
+  /// \param Loc The location at which the request originated and is fulfilled.
+  /// \param Return Result of the emitted runtime call, returned by reference.
+  /// \param DeviceID Identifier for the device via the 'device' clause.
+  /// \param NumTeams Number of teams for the region via the 'num_teams' clause
+  ///                 or 0 if unspecified and -1 if there is no 'teams' clause.
+  /// \param NumThreads Number of threads via the 'thread_limit' clause.
+  /// \param HostPtr Pointer to the host-side pointer of the target kernel.
+  /// \param KernelArgs Array of arguments to the kernel.
+  /// \param NoWaitArgs Optional array of arguments to the nowait kernel.
+  InsertPointTy emitTargetKernel(const LocationDescription &Loc,
+                                 Value *&Return, Value *Ident, Value *DeviceID,
+                                 Value *NumTeams, Value *NumThreads,
+                                 Value *HostPtr, ArrayRef<Value *> KernelArgs,
+                                 ArrayRef<Value *> NoWaitArgs = {});
+
   /// Generate a barrier runtime call.
   ///
   /// \param Loc The location at which the request originated and is fulfilled.
Index: clang/lib/CodeGen/CGOpenMPRuntime.cpp
===================================================================
--- clang/lib/CodeGen/CGOpenMPRuntime.cpp
+++ clang/lib/CodeGen/CGOpenMPRuntime.cpp
@@ -6718,7 +6718,7 @@
       break;
     }
   } else if (DefaultNT == -1) {
-    return nullptr;
+    return llvm::ConstantInt::get(CGF.Int32Ty, -1);
   }
 
   return Bld.getInt32(DefaultNT);
@@ -10310,7 +10310,34 @@
     // Emit tripcount for the target loop-based directive.
     emitTargetNumIterationsCall(CGF, D, DeviceID, SizeEmitter);
 
-    bool HasNowait = D.hasClausesOfKind<OMPNowaitClause>();
+    // Arguments for the target kernel.
+    SmallVector<llvm::Value *> KernelArgs{
+        PointerNum,
+        InputInfo.BasePointersArray.getPointer(),
+        InputInfo.PointersArray.getPointer(),
+        InputInfo.SizesArray.getPointer(),
+        MapTypesArray,
+        MapNamesArray,
+        InputInfo.MappersArray.getPointer()};
+
+    // Arguments passed to the 'nowait' variant.
+    SmallVector<llvm::Value *> NoWaitKernelArgs{
+        CGF.Builder.getInt32(0),
+        llvm::ConstantPointerNull::get(CGM.VoidPtrTy),
+        CGF.Builder.getInt32(0),
+        llvm::ConstantPointerNull::get(CGM.VoidPtrTy),
+    };
+
+    bool HasNoWait = D.hasClausesOfKind<OMPNowaitClause>();
+    CGF.Builder.restoreIP(
+        HasNoWait
+            ? OMPBuilder.emitTargetKernel(CGF.Builder, Return, RTLoc, DeviceID,
+                                          NumTeams, NumThreads, OutlinedFnID,
+                                          KernelArgs, NoWaitKernelArgs)
+            : OMPBuilder.emitTargetKernel(CGF.Builder, Return, RTLoc, DeviceID,
+                                          NumTeams, NumThreads, OutlinedFnID,
+                                          KernelArgs));
+
     // The target region is an outlined function launched by the runtime
     // via calls __tgt_target() or __tgt_target_teams().
     //
@@ -10337,69 +10364,6 @@
     // In contrast, on the NVPTX target, the implementation of
     // __tgt_target_teams() launches a GPU kernel with the requested number
     // of teams and threads so no additional calls to the runtime are required.
-    if (NumTeams) {
-      // If we have NumTeams defined this means that we have an enclosed teams
-      // region. Therefore we also expect to have NumThreads defined. These two
-      // values should be defined in the presence of a teams directive,
-      // regardless of having any clauses associated. If the user is using teams
-      // but no clauses, these two values will be the default that should be
-      // passed to the runtime library - a 32-bit integer with the value zero.
-      assert(NumThreads && "Thread limit expression should be available along "
-                           "with number of teams.");
-      SmallVector<llvm::Value *> OffloadingArgs = {
-          RTLoc,
-          DeviceID,
-          OutlinedFnID,
-          PointerNum,
-          InputInfo.BasePointersArray.getPointer(),
-          InputInfo.PointersArray.getPointer(),
-          InputInfo.SizesArray.getPointer(),
-          MapTypesArray,
-          MapNamesArray,
-          InputInfo.MappersArray.getPointer(),
-          NumTeams,
-          NumThreads};
-      if (HasNowait) {
-        // Add int32_t depNum = 0, void *depList = nullptr, int32_t
-        // noAliasDepNum = 0, void *noAliasDepList = nullptr.
-        OffloadingArgs.push_back(CGF.Builder.getInt32(0));
-        OffloadingArgs.push_back(llvm::ConstantPointerNull::get(CGM.VoidPtrTy));
-        OffloadingArgs.push_back(CGF.Builder.getInt32(0));
-        OffloadingArgs.push_back(llvm::ConstantPointerNull::get(CGM.VoidPtrTy));
-      }
-      Return = CGF.EmitRuntimeCall(
-          OMPBuilder.getOrCreateRuntimeFunction(
-              CGM.getModule(), HasNowait
-                                   ? OMPRTL___tgt_target_teams_nowait_mapper
-                                   : OMPRTL___tgt_target_teams_mapper),
-          OffloadingArgs);
-    } else {
-      SmallVector<llvm::Value *> OffloadingArgs = {
-          RTLoc,
-          DeviceID,
-          OutlinedFnID,
-          PointerNum,
-          InputInfo.BasePointersArray.getPointer(),
-          InputInfo.PointersArray.getPointer(),
-          InputInfo.SizesArray.getPointer(),
-          MapTypesArray,
-          MapNamesArray,
-          InputInfo.MappersArray.getPointer()};
-      if (HasNowait) {
-        // Add int32_t depNum = 0, void *depList = nullptr, int32_t
-        // noAliasDepNum = 0, void *noAliasDepList = nullptr.
-        OffloadingArgs.push_back(CGF.Builder.getInt32(0));
-        OffloadingArgs.push_back(llvm::ConstantPointerNull::get(CGM.VoidPtrTy));
-        OffloadingArgs.push_back(CGF.Builder.getInt32(0));
-        OffloadingArgs.push_back(llvm::ConstantPointerNull::get(CGM.VoidPtrTy));
-      }
-      Return = CGF.EmitRuntimeCall(
-          OMPBuilder.getOrCreateRuntimeFunction(
-              CGM.getModule(), HasNowait ? OMPRTL___tgt_target_nowait_mapper
-                                         : OMPRTL___tgt_target_mapper),
-          OffloadingArgs);
-    }
-
     // Check the error code and execute the host version if required.
     llvm::BasicBlock *OffloadFailedBlock =
         CGF.createBasicBlock("omp_offload.failed");
_______________________________________________
cfe-commits mailing list
cfe-commits@lists.llvm.org
https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits

Reply via email to