gtbercea created this revision.
gtbercea added reviewers: ABataev, carlo.bertolli, caomhin, hfinkel, Hahnfeld.
Herald added subscribers: cfe-commits, guansong, jholewinski.

This patch handles the Clang code generation phase for the OpenMP data sharing 
infrastructure.

TODO: add a more detailed description.


Repository:
  rC Clang

https://reviews.llvm.org/D43660

Files:
  lib/CodeGen/CGDecl.cpp
  lib/CodeGen/CGOpenMPRuntime.cpp
  lib/CodeGen/CGOpenMPRuntime.h
  lib/CodeGen/CGOpenMPRuntimeNVPTX.cpp
  lib/CodeGen/CGOpenMPRuntimeNVPTX.h
  lib/CodeGen/CGStmtOpenMP.cpp
  lib/CodeGen/CodeGenFunction.cpp
  test/OpenMP/nvptx_parallel_codegen.cpp

Index: test/OpenMP/nvptx_parallel_codegen.cpp
===================================================================
--- test/OpenMP/nvptx_parallel_codegen.cpp
+++ test/OpenMP/nvptx_parallel_codegen.cpp
@@ -92,20 +92,20 @@
   //
   // CHECK: [[EXEC_PARALLEL]]
   // CHECK: [[WF1:%.+]] = load i8*, i8** [[OMP_WORK_FN]],
-  // CHECK: [[WM1:%.+]] = icmp eq i8* [[WF1]], bitcast (void (i32*, i32*)* [[PARALLEL_FN1:@.+]] to i8*)
+  // CHECK: [[WM1:%.+]] = icmp eq i8* [[WF1]], bitcast (void (i16, i32)* [[PARALLEL_FN1:@.+]]_wrapper to i8*)
   // CHECK: br i1 [[WM1]], label {{%?}}[[EXEC_PFN1:.+]], label {{%?}}[[CHECK_NEXT1:.+]]
   //
   // CHECK: [[EXEC_PFN1]]
-  // CHECK: call void [[PARALLEL_FN1]](
+  // CHECK: call void [[PARALLEL_FN1]]_wrapper(
   // CHECK: br label {{%?}}[[TERM_PARALLEL:.+]]
   //
   // CHECK: [[CHECK_NEXT1]]
   // CHECK: [[WF2:%.+]] = load i8*, i8** [[OMP_WORK_FN]],
-  // CHECK: [[WM2:%.+]] = icmp eq i8* [[WF2]], bitcast (void (i32*, i32*)* [[PARALLEL_FN2:@.+]] to i8*)
+  // CHECK: [[WM2:%.+]] = icmp eq i8* [[WF2]], bitcast (void (i16, i32)* [[PARALLEL_FN2:@.+]]_wrapper to i8*)
   // CHECK: br i1 [[WM2]], label {{%?}}[[EXEC_PFN2:.+]], label {{%?}}[[CHECK_NEXT2:.+]]
   //
   // CHECK: [[EXEC_PFN2]]
-  // CHECK: call void [[PARALLEL_FN2]](
+  // CHECK: call void [[PARALLEL_FN2]]_wrapper(
   // CHECK: br label {{%?}}[[TERM_PARALLEL:.+]]
   //
   // CHECK: [[CHECK_NEXT2]]
@@ -152,13 +152,13 @@
   // CHECK-DAG: [[MWS:%.+]] = call i32 @llvm.nvvm.read.ptx.sreg.warpsize()
   // CHECK: [[MTMP1:%.+]] = sub i32 [[MNTH]], [[MWS]]
   // CHECK: call void @__kmpc_kernel_init(i32 [[MTMP1]]
-  // CHECK: call void @__kmpc_kernel_prepare_parallel(i8* bitcast (void (i32*, i32*)* [[PARALLEL_FN1]] to i8*),
+  // CHECK: call void @__kmpc_kernel_prepare_parallel(i8* bitcast (void (i16, i32)* [[PARALLEL_FN1]]_wrapper to i8*),
   // CHECK: call void @llvm.nvvm.barrier0()
   // CHECK: call void @llvm.nvvm.barrier0()
   // CHECK: call void @__kmpc_serialized_parallel(
   // CHECK: {{call|invoke}} void [[PARALLEL_FN3:@.+]](
   // CHECK: call void @__kmpc_end_serialized_parallel(
-  // CHECK: call void @__kmpc_kernel_prepare_parallel(i8* bitcast (void (i32*, i32*)* [[PARALLEL_FN2]] to i8*),
+  // CHECK: call void @__kmpc_kernel_prepare_parallel(i8* bitcast (void (i16, i32)* [[PARALLEL_FN2]]_wrapper to i8*),
   // CHECK: call void @llvm.nvvm.barrier0()
   // CHECK: call void @llvm.nvvm.barrier0()
   // CHECK-64-DAG: load i32, i32* [[REF_A]]
@@ -203,7 +203,7 @@
   //
   // CHECK: [[AWAIT_WORK]]
   // CHECK: call void @llvm.nvvm.barrier0()
-  // CHECK: [[KPR:%.+]] = call i1 @__kmpc_kernel_parallel(i8** [[OMP_WORK_FN]]
+  // CHECK: [[KPR:%.+]] = call i1 @__kmpc_kernel_parallel(i8** [[OMP_WORK_FN]],
   // CHECK: [[KPRB:%.+]] = zext i1 [[KPR]] to i8
   // store i8 [[KPRB]], i8* [[OMP_EXEC_STATUS]], align 1
   // CHECK: [[WORK:%.+]] = load i8*, i8** [[OMP_WORK_FN]],
@@ -217,11 +217,11 @@
   //
   // CHECK: [[EXEC_PARALLEL]]
   // CHECK: [[WF:%.+]] = load i8*, i8** [[OMP_WORK_FN]],
-  // CHECK: [[WM:%.+]] = icmp eq i8* [[WF]], bitcast (void (i32*, i32*)* [[PARALLEL_FN4:@.+]] to i8*)
+  // CHECK: [[WM:%.+]] = icmp eq i8* [[WF]], bitcast (void (i16, i32)* [[PARALLEL_FN4:@.+]]_wrapper to i8*)
   // CHECK: br i1 [[WM]], label {{%?}}[[EXEC_PFN:.+]], label {{%?}}[[CHECK_NEXT:.+]]
   //
   // CHECK: [[EXEC_PFN]]
-  // CHECK: call void [[PARALLEL_FN4]](
+  // CHECK: call void [[PARALLEL_FN4]]_wrapper(
   // CHECK: br label {{%?}}[[TERM_PARALLEL:.+]]
   //
   // CHECK: [[CHECK_NEXT]]
@@ -283,7 +283,7 @@
   // CHECK: br i1 [[CMP]], label {{%?}}[[IF_THEN:.+]], label {{%?}}[[IF_ELSE:.+]]
   //
   // CHECK: [[IF_THEN]]
-  // CHECK: call void @__kmpc_kernel_prepare_parallel(i8* bitcast (void (i32*, i32*)* [[PARALLEL_FN4]] to i8*),
+  // CHECK: call void @__kmpc_kernel_prepare_parallel(i8* bitcast (void (i16, i32)* [[PARALLEL_FN4]]_wrapper to i8*),
   // CHECK: call void @llvm.nvvm.barrier0()
   // CHECK: call void @llvm.nvvm.barrier0()
   // CHECK: br label {{%?}}[[IF_END:.+]]
Index: lib/CodeGen/CodeGenFunction.cpp
===================================================================
--- lib/CodeGen/CodeGenFunction.cpp
+++ lib/CodeGen/CodeGenFunction.cpp
@@ -1058,6 +1058,11 @@
   EmitStartEHSpec(CurCodeDecl);
 
   PrologueCleanupDepth = EHStack.stable_begin();
+
+  // Emit OpenMP specific initialization of the device functions.
+  if (getLangOpts().OpenMP && CurCodeDecl)
+    CGM.getOpenMPRuntime().emitFunctionProlog(*this, CurCodeDecl);
+
   EmitFunctionProlog(*CurFnInfo, CurFn, Args);
 
   if (D && isa<CXXMethodDecl>(D) && cast<CXXMethodDecl>(D)->isInstance()) {
Index: lib/CodeGen/CGStmtOpenMP.cpp
===================================================================
--- lib/CodeGen/CGStmtOpenMP.cpp
+++ lib/CodeGen/CGStmtOpenMP.cpp
@@ -589,6 +589,7 @@
                             /*RegisterCastedArgsOnly=*/true,
                             CapturedStmtInfo->getHelperName());
   CodeGenFunction WrapperCGF(CGM, /*suppressNewContext=*/true);
+  WrapperCGF.CapturedStmtInfo = CapturedStmtInfo;
   Args.clear();
   LocalAddrs.clear();
   VLASizes.clear();
Index: lib/CodeGen/CGOpenMPRuntimeNVPTX.h
===================================================================
--- lib/CodeGen/CGOpenMPRuntimeNVPTX.h
+++ lib/CodeGen/CGOpenMPRuntimeNVPTX.h
@@ -288,6 +288,14 @@
       CodeGenFunction &CGF, SourceLocation Loc, llvm::Value *OutlinedFn,
       ArrayRef<llvm::Value *> Args = llvm::None) const override;
 
+  /// Emits OpenMP-specific function prolog.
+  /// Required for device constructs.
+  void emitFunctionProlog(CodeGenFunction &CGF, const Decl *D) override;
+
+  /// Gets the OpenMP-specific address of the local variable.
+  Address getAddressOfLocalVariable(CodeGenFunction &CGF,
+                                    const VarDecl *VD) override;
+
   /// Target codegen is specialized based on two programming models: the
   /// 'generic' fork-join model of OpenMP, and a more GPU efficient 'spmd'
   /// model for constructs like 'target parallel' that support it.
@@ -299,12 +307,39 @@
     Unknown,
   };
 
+  /// Cleans up references to the objects in finished function.
+  ///
+  void functionFinished(CodeGenFunction &CGF) override;
+
 private:
   // Track the execution mode when codegening directives within a target
   // region. The appropriate mode (generic/spmd) is set on entry to the
   // target region and used by containing directives such as 'parallel'
   // to emit optimized code.
   ExecutionMode CurrentExecutionMode;
+
+  /// Map between an outlined function and its wrapper.
+  llvm::DenseMap<llvm::Function *, llvm::Function *> WrapperFunctionsMap;
+
+  /// Emit function which wraps the outlined parallel region
+  /// and controls the parameters which are passed to this function.
+  /// The wrapper ensures that the outlined function is called
+  /// with the correct arguments when data is shared.
+  llvm::Function *
+  createDataSharingWrapper(llvm::Function *OutlinedParallelFn,
+      const OMPExecutableDirective &D);
+
+  /// The map of local variables to their addresses in the global memory.
+  typedef llvm::MapVector<
+      const Decl *, std::pair<const FieldDecl *, Address>> DeclToAddrMapTy;
+  /// Maps the function to the list of the globalized variables with their
+  /// addresses.
+  llvm::DenseMap<
+      llvm::Function *, std::pair<const RecordDecl *, DeclToAddrMapTy>>
+          FunctionGlobalizedDecls;
+  /// Map from function to global record pointer.
+  llvm::DenseMap<llvm::Function *, llvm::Value *>
+      FunctionToGlobalRecPtr;
 };
 
 } // CodeGen namespace.
Index: lib/CodeGen/CGOpenMPRuntimeNVPTX.cpp
===================================================================
--- lib/CodeGen/CGOpenMPRuntimeNVPTX.cpp
+++ lib/CodeGen/CGOpenMPRuntimeNVPTX.cpp
@@ -13,9 +13,11 @@
 //===----------------------------------------------------------------------===//
 
 #include "CGOpenMPRuntimeNVPTX.h"
-#include "clang/AST/DeclOpenMP.h"
 #include "CodeGenFunction.h"
+#include "clang/AST/DeclOpenMP.h"
 #include "clang/AST/StmtOpenMP.h"
+#include "clang/AST/StmtVisitor.h"
+#include "llvm/ADT/SmallPtrSet.h"
 
 using namespace clang;
 using namespace CodeGen;
@@ -70,7 +72,18 @@
   /// index, int32_t width, int32_t reduce))
   OMPRTL_NVPTX__kmpc_teams_reduce_nowait,
   /// \brief Call to __kmpc_nvptx_end_reduce_nowait(int32_t global_tid);
-  OMPRTL_NVPTX__kmpc_end_reduce_nowait
+  OMPRTL_NVPTX__kmpc_end_reduce_nowait,
+  /// \brief Call to void* __kmpc_data_sharing_push_stack(size_t size);
+  OMPRTL_NVPTX__kmpc_data_sharing_push_stack,
+  /// \brief Call to void __kmpc_data_sharing_pop_stack(void *a);
+  OMPRTL_NVPTX__kmpc_data_sharing_pop_stack,
+  /// \brief Call to void __kmpc_begin_sharing_variables(void ***args,
+  /// size_t n_args);
+  OMPRTL_NVPTX__kmpc_begin_sharing_variables,
+  /// \brief Call to void __kmpc_end_sharing_variables();
+  OMPRTL_NVPTX__kmpc_end_sharing_variables,
+  /// \brief Call to void __kmpc_get_shared_variables(void ***GlobalArgs);
+  OMPRTL_NVPTX__kmpc_get_shared_variables,
 };
 
 /// Pre(post)-action for different OpenMP constructs specialized for NVPTX.
@@ -149,6 +162,246 @@
   /// barrier.
   NB_Parallel = 1,
 };
+
+/// Get the list of variables that can escape their declaration context.
+class CheckVarsEscapingDeclContext final
+    : public ConstStmtVisitor<CheckVarsEscapingDeclContext> {
+  CodeGenFunction &CGF;
+  llvm::SetVector<const ValueDecl *> EscapedDecls;
+  llvm::SmallPtrSet<const ValueDecl *, 4> IgnoredDecls;
+  bool AllEscaped = false;
+  RecordDecl *GlobalizedRD = nullptr;
+  llvm::SmallDenseMap<const ValueDecl *, const FieldDecl *> MappedDeclsFields;
+
+  /// Marks \p VD as escaped unless ignored, captured, or not a variable.
+  void markAsEscaped(const ValueDecl *VD) {
+    const auto *Var = dyn_cast<VarDecl>(VD);
+    if (Var && !IgnoredDecls.count(VD) &&
+        !(CGF.CapturedStmtInfo && CGF.CapturedStmtInfo->lookup(Var)))
+      EscapedDecls.insert(VD);
+  }
+
+  void VisitValueDecl(const ValueDecl *VD) {
+    if (VD->getType()->isLValueReferenceType()) {
+      markAsEscaped(VD);
+      if (const auto *VarD = dyn_cast<VarDecl>(VD)) {
+        if (!isa<ParmVarDecl>(VarD) && VarD->hasInit()) {
+          const bool SavedAllEscaped = AllEscaped;
+          AllEscaped = true;
+          Visit(VarD->getInit());
+          AllEscaped = SavedAllEscaped;
+        }
+      }
+    }
+  }
+  void VisitOpenMPCapturedStmt(const CapturedStmt *S) {
+    if (!S)
+      return;
+    for (const auto &C : S->captures()) {
+      if (C.capturesVariable() && !C.capturesVariableByCopy()) {
+        const ValueDecl *VD = C.getCapturedVar();
+        markAsEscaped(VD);
+        if (isa<OMPCapturedExprDecl>(VD))
+          VisitValueDecl(VD);
+      }
+    }
+  }
+
+  typedef std::pair<CharUnits /*Align*/, const ValueDecl *> VarsDataTy;
+  static bool stable_sort_comparator(const VarsDataTy P1,
+                                     const VarsDataTy P2) {
+    return P1.first > P2.first;
+  }
+
+  void buildRecordForGlobalizedVars() {
+    assert(!GlobalizedRD &&
+           "Record for globalized variables is built already.");
+    if (EscapedDecls.empty())
+      return;
+    ASTContext &C = CGF.getContext();
+    SmallVector<VarsDataTy, 4> GlobalizedVars;
+    for (const auto *D : EscapedDecls)
+      GlobalizedVars.emplace_back(C.getDeclAlign(D), D);
+    std::stable_sort(GlobalizedVars.begin(), GlobalizedVars.end(),
+                     stable_sort_comparator);
+    // Build struct _globalized_locals_ty {
+    //         /*  globalized vars  */
+    //       };
+    GlobalizedRD = C.buildImplicitRecord("_globalized_locals_ty");
+    GlobalizedRD->startDefinition();
+    for (const auto &Pair : GlobalizedVars) {
+      const ValueDecl *VD = Pair.second;
+      QualType Type = VD->getType();
+      if (Type->isLValueReferenceType())
+        Type = C.getPointerType(Type.getNonReferenceType());
+      else
+        Type = Type.getNonReferenceType();
+      SourceLocation Loc = VD->getLocation();
+      auto *Field = FieldDecl::Create(
+          C, GlobalizedRD, Loc, Loc, VD->getIdentifier(), Type,
+          C.getTrivialTypeSourceInfo(Type, SourceLocation()),
+          /*BW=*/nullptr, /*Mutable=*/false,
+          /*InitStyle=*/ICIS_NoInit);
+      Field->setAccess(AS_public);
+      GlobalizedRD->addDecl(Field);
+      if (VD->hasAttrs()) {
+        for (specific_attr_iterator<AlignedAttr> I(VD->getAttrs().begin()),
+             E(VD->getAttrs().end());
+             I != E; ++I)
+          Field->addAttr(*I);
+      }
+      MappedDeclsFields.try_emplace(VD, Field);
+    }
+    GlobalizedRD->completeDefinition();
+  }
+
+public:
+  CheckVarsEscapingDeclContext(CodeGenFunction &CGF,
+                               ArrayRef<const ValueDecl *> IgnoredDecls)
+      : CGF(CGF), IgnoredDecls(IgnoredDecls.begin(), IgnoredDecls.end()) {}
+  virtual ~CheckVarsEscapingDeclContext() = default;
+  void VisitDeclStmt(const DeclStmt *S) {
+    if (!S)
+      return;
+    for (const auto *D : S->decls())
+      if (const auto *VD = dyn_cast_or_null<ValueDecl>(D))
+        VisitValueDecl(VD);
+  }
+  void VisitOMPExecutableDirective(const OMPExecutableDirective *D) {
+    if (!D)
+      return;
+    if (D->hasAssociatedStmt()) {
+      if (const auto *S =
+              dyn_cast_or_null<CapturedStmt>(D->getAssociatedStmt()))
+        VisitOpenMPCapturedStmt(S);
+    }
+  }
+  void VisitCapturedStmt(const CapturedStmt *S) {
+    if (!S)
+      return;
+    for (const auto &C : S->captures()) {
+      if (C.capturesVariable() && !C.capturesVariableByCopy()) {
+        const ValueDecl *VD = C.getCapturedVar();
+        markAsEscaped(VD);
+        if (isa<OMPCapturedExprDecl>(VD))
+          VisitValueDecl(VD);
+      }
+    }
+  }
+  void VisitLambdaExpr(const LambdaExpr *E) {
+    if (!E)
+      return;
+    for (const auto &C : E->captures()) {
+      if (C.capturesVariable()) {
+        if (C.getCaptureKind() == LCK_ByRef) {
+          const ValueDecl *VD = C.getCapturedVar();
+          markAsEscaped(VD);
+          if (E->isInitCapture(&C) || isa<OMPCapturedExprDecl>(VD))
+            VisitValueDecl(VD);
+        }
+      }
+    }
+  }
+  void VisitBlockExpr(const BlockExpr *E) {
+    if (!E)
+      return;
+    for (const auto &C : E->getBlockDecl()->captures()) {
+      if (C.isByRef()) {
+        const VarDecl *VD = C.getVariable();
+        markAsEscaped(VD);
+        if (isa<OMPCapturedExprDecl>(VD) || VD->isInitCapture())
+          VisitValueDecl(VD);
+      }
+    }
+  }
+  void VisitCallExpr(const CallExpr *E) {
+    if (!E)
+      return;
+    for (const Expr *Arg : E->arguments()) {
+      if (!Arg)
+        continue;
+      if (Arg->isLValue()) {
+        const bool SavedAllEscaped = AllEscaped;
+        AllEscaped = true;
+        Visit(Arg);
+        AllEscaped = SavedAllEscaped;
+      } else
+        Visit(Arg);
+    }
+    Visit(E->getCallee());
+  }
+  void VisitDeclRefExpr(const DeclRefExpr *E) {
+    if (!E)
+      return;
+    const ValueDecl *VD = E->getDecl();
+    if (AllEscaped)
+      markAsEscaped(VD);
+    if (isa<OMPCapturedExprDecl>(VD))
+      VisitValueDecl(VD);
+    else if (const auto *VarD = dyn_cast<VarDecl>(VD))
+      if (VarD->isInitCapture())
+        VisitValueDecl(VD);
+  }
+  void VisitUnaryOperator(const UnaryOperator *E) {
+    if (!E)
+      return;
+    if (E->getOpcode() == UO_AddrOf) {
+      const bool SavedAllEscaped = AllEscaped;
+      AllEscaped = true;
+      Visit(E->getSubExpr());
+      AllEscaped = SavedAllEscaped;
+    } else
+      Visit(E->getSubExpr());
+  }
+  void VisitImplicitCastExpr(const ImplicitCastExpr *E) {
+    if (!E)
+      return;
+    if (E->getCastKind() == CK_ArrayToPointerDecay) {
+      const bool SavedAllEscaped = AllEscaped;
+      AllEscaped = true;
+      Visit(E->getSubExpr());
+      AllEscaped = SavedAllEscaped;
+    } else
+      Visit(E->getSubExpr());
+  }
+  void VisitExpr(const Expr *E) {
+    if (!E)
+      return;
+    bool SavedAllEscaped = AllEscaped;
+    if (!E->isLValue())
+      AllEscaped = false;
+    for (const auto *Child : E->children())
+      if (Child)
+        Visit(Child);
+    AllEscaped = SavedAllEscaped;
+  }
+  void VisitStmt(const Stmt *S) {
+    if (!S)
+      return;
+    for (const auto *Child : S->children())
+      if (Child)
+        Visit(Child);
+  }
+
+  const RecordDecl *getGlobalizedRecord() {
+    if (!GlobalizedRD)
+      buildRecordForGlobalizedVars();
+    return GlobalizedRD;
+  }
+
+  const FieldDecl *getFieldForGlobalizedVar(const ValueDecl *VD) const {
+    assert(GlobalizedRD &&
+           "Record for globalized variables must be generated already.");
+    auto I = MappedDeclsFields.find(VD);
+    if (I == MappedDeclsFields.end())
+      return nullptr;
+    return I->getSecond();
+  }
+
+  ArrayRef<const ValueDecl *> getEscapedDecls() const {
+    return EscapedDecls.getArrayRef();
+  }
+};
 } // anonymous namespace
 
 /// Get the GPU warp size.
@@ -299,6 +552,7 @@
   EntryFunctionState EST;
   WorkerFunctionState WST(CGM);
   Work.clear();
+  WrapperFunctionsMap.clear();
 
   // Emit target region as a standalone region.
   class NVPTXPrePostActionTy : public PrePostActionTy {
@@ -362,10 +616,58 @@
                          Bld.getInt16(/*RequiresOMPRuntime=*/1)};
   CGF.EmitRuntimeCall(
       createNVPTXRuntimeFunction(OMPRTL_NVPTX__kmpc_kernel_init), Args);
+  const auto I = FunctionGlobalizedDecls.find(CGF.CurFn);
+  if (I == FunctionGlobalizedDecls.end())
+    return;
+  const RecordDecl *GlobalizedVarsRecord = I->getSecond().first;
+  QualType RecTy = CGM.getContext().getRecordType(GlobalizedVarsRecord);
+
+  // Recover pointer to this function's global record. The runtime will
+  // handle the specifics of the allocation of the memory.
+  // Use actual memory size of the record including the padding
+  // for alignment purposes.
+  unsigned Alignment =
+      CGM.getContext().getTypeAlignInChars(RecTy).getQuantity();
+  unsigned GlobalRecordSize =
+      CGM.getContext().getTypeSizeInChars(RecTy).getQuantity();
+  GlobalRecordSize = llvm::alignTo(GlobalRecordSize, Alignment);
+  llvm::Value *GlobalRecordSizeArg[] = {
+      llvm::ConstantInt::get(CGM.SizeTy, GlobalRecordSize)};
+  llvm::Value *GlobalRecValue = CGF.EmitRuntimeCall(
+      createNVPTXRuntimeFunction(
+          OMPRTL_NVPTX__kmpc_data_sharing_push_stack),
+      GlobalRecordSizeArg);
+  auto GlobalRecCastAddr = Bld.CreatePointerBitCastOrAddrSpaceCast(
+      GlobalRecValue, CGF.ConvertTypeForMem(RecTy)->getPointerTo());
+  FunctionToGlobalRecPtr.try_emplace(CGF.CurFn, GlobalRecValue);
+
+  // Emit the "global alloca" which is a GEP from the global declaration record
+  // using the pointer returned by the runtime.
+  LValue Base =
+      CGF.MakeNaturalAlignPointeeAddrLValue(GlobalRecCastAddr, RecTy);
+  auto &Res = I->getSecond().second;
+  for (auto &Rec : Res) {
+    const FieldDecl *FD = Rec.second.first;
+    LValue VarAddr = CGF.EmitLValueForField(Base, FD);
+    Rec.second.second = VarAddr.getAddress();
+  }
 }
 
 void CGOpenMPRuntimeNVPTX::emitGenericEntryFooter(CodeGenFunction &CGF,
                                                   EntryFunctionState &EST) {
+  const auto I = FunctionGlobalizedDecls.find(CGF.CurFn);
+  if (I != FunctionGlobalizedDecls.end()) {
+    if (!CGF.HaveInsertPoint())
+      return;
+    auto I = FunctionToGlobalRecPtr.find(CGF.CurFn);
+    if (I != FunctionToGlobalRecPtr.end()) {
+      llvm::Value *Args[] = {I->getSecond()};
+      CGF.EmitRuntimeCall(
+          createNVPTXRuntimeFunction(
+              OMPRTL_NVPTX__kmpc_data_sharing_pop_stack), Args);
+    }
+  }
+
   if (!EST.ExitBB)
     EST.ExitBB = CGF.createBasicBlock(".exit");
 
@@ -554,14 +856,12 @@
     // Execute this outlined function.
     CGF.EmitBlock(ExecuteFNBB);
 
-    // Insert call to work function.
-    // FIXME: Pass arguments to outlined function from master thread.
-    auto *Fn = cast<llvm::Function>(W);
-    Address ZeroAddr =
-        CGF.CreateDefaultAlignTempAlloca(CGF.Int32Ty, /*Name=*/".zero.addr");
-    CGF.InitTempAlloca(ZeroAddr, CGF.Builder.getInt32(/*C=*/0));
-    llvm::Value *FnArgs[] = {ZeroAddr.getPointer(), ZeroAddr.getPointer()};
-    emitCall(CGF, Fn, FnArgs);
+    // Insert call to work function via shared wrapper. The shared
+    // wrapper takes two arguments:
+    //   - the parallelism level;
+    //   - the master thread ID;
+    emitCall(CGF, W, {Bld.getInt16(/*ParallelLevel=*/0),
+        getMasterThreadID(CGF)});
 
     // Go to end of parallel region.
     CGF.EmitBranch(TerminateBB);
@@ -630,8 +930,7 @@
   case OMPRTL_NVPTX__kmpc_kernel_prepare_parallel: {
     /// Build void __kmpc_kernel_prepare_parallel(
     /// void *outlined_function, int16_t IsOMPRuntimeInitialized);
-    llvm::Type *TypeParams[] = {CGM.Int8PtrTy,
-                                CGM.Int16Ty};
+    llvm::Type *TypeParams[] = {CGM.Int8PtrTy, CGM.Int16Ty};
     llvm::FunctionType *FnTy =
         llvm::FunctionType::get(CGM.VoidTy, TypeParams, /*isVarArg*/ false);
     RTLFn = CGM.CreateRuntimeFunction(FnTy, "__kmpc_kernel_prepare_parallel");
@@ -769,6 +1068,48 @@
         FnTy, /*Name=*/"__kmpc_nvptx_end_reduce_nowait");
     break;
   }
+  case OMPRTL_NVPTX__kmpc_data_sharing_push_stack: {
+    // Build void *__kmpc_data_sharing_push_stack(size_t size);
+    llvm::Type *TypeParams[] = {CGM.SizeTy};
+    llvm::FunctionType *FnTy =
+        llvm::FunctionType::get(CGM.VoidPtrTy, TypeParams, /*isVarArg=*/false);
+    RTLFn = CGM.CreateRuntimeFunction(
+        FnTy, /*Name=*/"__kmpc_data_sharing_push_stack");
+    break;
+  }
+  case OMPRTL_NVPTX__kmpc_data_sharing_pop_stack: {
+    // Build void __kmpc_data_sharing_pop_stack(void *a);
+    llvm::Type *TypeParams[] = {CGM.VoidPtrTy};
+    llvm::FunctionType *FnTy =
+        llvm::FunctionType::get(CGM.VoidTy, TypeParams, /*isVarArg=*/false);
+    RTLFn = CGM.CreateRuntimeFunction(
+        FnTy, /*Name=*/"__kmpc_data_sharing_pop_stack");
+    break;
+  }
+  case OMPRTL_NVPTX__kmpc_begin_sharing_variables: {
+    /// Build void __kmpc_begin_sharing_variables(void ***args,
+    /// size_t n_args);
+    llvm::Type *TypeParams[] = {CGM.Int8PtrPtrTy->getPointerTo(), CGM.SizeTy};
+    llvm::FunctionType *FnTy =
+        llvm::FunctionType::get(CGM.VoidTy, TypeParams, /*isVarArg*/ false);
+    RTLFn = CGM.CreateRuntimeFunction(FnTy, "__kmpc_begin_sharing_variables");
+    break;
+  }
+  case OMPRTL_NVPTX__kmpc_end_sharing_variables: {
+    /// Build void __kmpc_end_sharing_variables();
+    llvm::FunctionType *FnTy =
+        llvm::FunctionType::get(CGM.VoidTy, {}, /*isVarArg*/ false);
+    RTLFn = CGM.CreateRuntimeFunction(FnTy, "__kmpc_end_sharing_variables");
+    break;
+  }
+  case OMPRTL_NVPTX__kmpc_get_shared_variables: {
+    /// Build void __kmpc_get_shared_variables(void ***GlobalArgs);
+    llvm::Type *TypeParams[] = {CGM.Int8PtrPtrTy->getPointerTo()};
+    llvm::FunctionType *FnTy =
+        llvm::FunctionType::get(CGM.VoidTy, TypeParams, /*isVarArg*/ false);
+    RTLFn = CGM.CreateRuntimeFunction(FnTy, "__kmpc_get_shared_variables");
+    break;
+  }
   }
   return RTLFn;
 }
@@ -859,8 +1200,16 @@
 llvm::Value *CGOpenMPRuntimeNVPTX::emitParallelOutlinedFunction(
     const OMPExecutableDirective &D, const VarDecl *ThreadIDVar,
     OpenMPDirectiveKind InnermostKind, const RegionCodeGenTy &CodeGen) {
-  return CGOpenMPRuntime::emitParallelOutlinedFunction(D, ThreadIDVar,
-      InnermostKind, CodeGen);
+  auto *OutlinedFun = cast<llvm::Function>(
+      CGOpenMPRuntime::emitParallelOutlinedFunction(
+          D, ThreadIDVar, InnermostKind, CodeGen));
+  if (!isInSpmdExecutionMode()) {
+    llvm::Function *WrapperFun =
+        createDataSharingWrapper(OutlinedFun, D);
+    WrapperFunctionsMap[OutlinedFun] = WrapperFun;
+  }
+
+  return OutlinedFun;
 }
 
 llvm::Value *CGOpenMPRuntimeNVPTX::emitTeamsOutlinedFunction(
@@ -912,16 +1261,61 @@
     CodeGenFunction &CGF, SourceLocation Loc, llvm::Value *OutlinedFn,
     ArrayRef<llvm::Value *> CapturedVars, const Expr *IfCond) {
   llvm::Function *Fn = cast<llvm::Function>(OutlinedFn);
+  llvm::Function *WFn = WrapperFunctionsMap[Fn];
+
+  assert(WFn && "Wrapper function does not exist!");
 
-  auto &&L0ParallelGen = [this, Fn](CodeGenFunction &CGF, PrePostActionTy &) {
+  // Use internal linkage so the outlined function can be inlined at its call site.
+  Fn->setLinkage(llvm::GlobalValue::InternalLinkage);
+
+  auto &&L0ParallelGen = [this, WFn, &CapturedVars](CodeGenFunction &CGF,
+                                                    PrePostActionTy &) {
     CGBuilderTy &Bld = CGF.Builder;
 
-    // TODO: Optimize runtime initialization.
-    llvm::Value *Args[] = {Bld.CreateBitOrPointerCast(Fn, CGM.Int8PtrTy),
-                           /*RequiresOMPRuntime=*/Bld.getInt16(1)};
-    CGF.EmitRuntimeCall(
-        createNVPTXRuntimeFunction(OMPRTL_NVPTX__kmpc_kernel_prepare_parallel),
-        Args);
+    llvm::Value *ID = Bld.CreateBitOrPointerCast(WFn, CGM.Int8PtrTy);
+
+    // There's something to share.
+    if (!CapturedVars.empty()) {
+      // Temporary that receives the list of references to shared arguments.
+      Address SharedArgs =
+          CGF.CreateDefaultAlignTempAlloca(CGF.VoidPtrPtrTy,
+              "shared_arg_refs");
+      llvm::Value *SharedArgsPtr = SharedArgs.getPointer();
+
+      // Handle the sharing of variables using global memory.
+      // Prepare for parallel region. Indicate the outlined function.
+      llvm::Value *Args[] = {ID, /*RequiresOMPRuntime=*/Bld.getInt16(1)};
+      CGF.EmitRuntimeCall(
+          createNVPTXRuntimeFunction(OMPRTL_NVPTX__kmpc_kernel_prepare_parallel),
+          Args);
+
+      llvm::Value *DataSharingArgs[] = {SharedArgsPtr,
+          llvm::ConstantInt::get(CGM.SizeTy, CapturedVars.size())};
+      CGF.EmitRuntimeCall(
+          createNVPTXRuntimeFunction(OMPRTL_NVPTX__kmpc_begin_sharing_variables),
+          DataSharingArgs);
+
+      // Store variable address in a list of references to pass to workers.
+      unsigned Idx = 0;
+      ASTContext &Ctx = CGF.getContext();
+      for (llvm::Value *V : CapturedVars) {
+        Address Dst = Bld.CreateConstInBoundsGEP(
+            CGF.EmitLoadOfPointer(SharedArgs,
+            Ctx.getPointerType(
+                Ctx.getPointerType(Ctx.VoidPtrTy)).castAs<PointerType>()),
+            Idx, CGF.getPointerSize());
+        llvm::Value *PtrV = Bld.CreateBitCast(V, CGF.VoidPtrTy);
+        CGF.EmitStoreOfScalar(PtrV, Dst, /*Volatile=*/false,
+            Ctx.getPointerType(Ctx.VoidPtrTy));
+        Idx++;
+      }
+    } else {
+      // TODO: Optimize runtime initialization and pass in correct value.
+      llvm::Value *Args[] = {ID, /*RequiresOMPRuntime=*/Bld.getInt16(1)};
+      CGF.EmitRuntimeCall(
+          createNVPTXRuntimeFunction(OMPRTL_NVPTX__kmpc_kernel_prepare_parallel),
+          Args);
+    }
 
     // Activate workers. This barrier is used by the master to signal
     // work for the workers.
@@ -935,8 +1329,12 @@
     // The master waits at this barrier until all workers are done.
     syncCTAThreads(CGF);
 
+    if (!CapturedVars.empty())
+      CGF.EmitRuntimeCall(
+          createNVPTXRuntimeFunction(OMPRTL_NVPTX__kmpc_end_sharing_variables));
+
     // Remember for post-processing in worker loop.
-    Work.emplace_back(Fn);
+    Work.emplace_back(WFn);
   };
 
   auto *RTLoc = emitUpdateLocation(CGF, Loc);
@@ -2340,3 +2738,154 @@
   }
   CGOpenMPRuntime::emitOutlinedFunctionCall(CGF, Loc, OutlinedFn, TargetArgs);
 }
+
+/// Emit function which wraps the outlined parallel region
+/// and controls the arguments which are passed to this function.
+/// The wrapper ensures that the outlined function is called
+/// with the correct arguments when data is shared.
+llvm::Function *CGOpenMPRuntimeNVPTX::createDataSharingWrapper(
+    llvm::Function *OutlinedParallelFn, const OMPExecutableDirective &D) {
+  ASTContext &Ctx = CGM.getContext();
+  const auto &CS = *cast<CapturedStmt>(D.getAssociatedStmt());
+
+  // Create a function that takes the parallel level and the source thread ID.
+  FunctionArgList WrapperArgs;
+  QualType Int16QTy =
+      Ctx.getIntTypeForBitwidth(/*DestWidth=*/16, /*Signed=*/false);
+  QualType Int32QTy =
+      Ctx.getIntTypeForBitwidth(/*DestWidth=*/32, /*Signed=*/false);
+  ImplicitParamDecl ParallelLevelArg(Ctx, Int16QTy, ImplicitParamDecl::Other);
+  ImplicitParamDecl WrapperArg(Ctx, Int32QTy, ImplicitParamDecl::Other);
+  WrapperArgs.emplace_back(&ParallelLevelArg);
+  WrapperArgs.emplace_back(&WrapperArg);
+
+  auto &CGFI =
+      CGM.getTypes().arrangeBuiltinFunctionDeclaration(Ctx.VoidTy, WrapperArgs);
+
+  auto *Fn = llvm::Function::Create(
+      CGM.getTypes().GetFunctionType(CGFI), llvm::GlobalValue::InternalLinkage,
+      OutlinedParallelFn->getName() + "_wrapper", &CGM.getModule());
+  CGM.SetInternalFunctionAttributes(/*D=*/nullptr, Fn, CGFI);
+  Fn->setLinkage(llvm::GlobalValue::InternalLinkage);
+
+  CodeGenFunction CGF(CGM, /*suppressNewContext=*/true);
+  CGF.StartFunction(GlobalDecl(), Ctx.VoidTy, Fn, CGFI, WrapperArgs);
+
+  const auto *RD = CS.getCapturedRecordDecl();
+  auto CurField = RD->field_begin();
+
+  // Get the array of arguments.
+  SmallVector<llvm::Value *, 8> Args;
+
+  // TODO: support SIMD and pass actual values
+  Args.emplace_back(llvm::ConstantPointerNull::get(
+      CGM.Int32Ty->getPointerTo()));
+  Args.emplace_back(llvm::ConstantPointerNull::get(
+      CGM.Int32Ty->getPointerTo()));
+
+  CGBuilderTy &Bld = CGF.Builder;
+  auto CI = CS.capture_begin();
+
+  // Use global memory for data sharing.
+  // Handle passing of global args to workers.
+  Address GlobalArgs =
+      CGF.CreateDefaultAlignTempAlloca(CGF.VoidPtrPtrTy,
+          "global_args");
+  llvm::Value *GlobalArgsPtr = GlobalArgs.getPointer();
+  llvm::Value *DataSharingArgs[] = {GlobalArgsPtr};
+  CGF.EmitRuntimeCall(
+      createNVPTXRuntimeFunction(OMPRTL_NVPTX__kmpc_get_shared_variables),
+      DataSharingArgs);
+
+  // Retrieve the shared variables from the list of references returned
+  // by the runtime. Pass the variables to the outlined function.
+  ASTContext &CGFContext = CGF.getContext();
+  for (unsigned I = 0, E = CS.capture_size(); I < E; ++I, ++CI, ++CurField) {
+    QualType ElemTy = CurField->getType();
+    // If this is a capture by copy the element type has to be the pointer to
+    // the data.
+    if (CI->capturesVariableByCopy())
+      ElemTy = CGFContext.getPointerType(ElemTy);
+
+    Address Src = Bld.CreateConstInBoundsGEP(
+        CGF.EmitLoadOfPointer(GlobalArgs,
+        CGFContext.getPointerType(
+            CGFContext.getPointerType(
+                CGFContext.VoidPtrTy)).castAs<PointerType>()),
+        I, CGF.getPointerSize());
+    Address TypedAddress = Bld.CreateBitCast(
+        Src, CGF.ConvertTypeForMem(CGFContext.getPointerType(ElemTy)));
+    llvm::Value *Arg = CGF.EmitLoadOfScalar(TypedAddress,
+        /*Volatile=*/false, CGFContext.getPointerType(ElemTy),
+        CurField->getInnerLocStart());
+    Args.emplace_back(Arg);
+  }
+
+  emitCall(CGF, OutlinedParallelFn, Args);
+  CGF.FinishFunction();
+  return Fn;
+}
+
+void CGOpenMPRuntimeNVPTX::emitFunctionProlog(CodeGenFunction &CGF,
+                                              const Decl *D) {
+  assert(D && "Expected function or captured|block decl.");
+  assert(FunctionGlobalizedDecls.count(CGF.CurFn) == 0 &&
+         "Function is registered already.");
+  SmallVector<const ValueDecl *, 4> IgnoredDecls;
+  const Stmt *Body = nullptr;
+  if (const auto *FD = dyn_cast<FunctionDecl>(D)) {
+    Body = FD->getBody();
+  } else if (const auto *BD = dyn_cast<BlockDecl>(D)) {
+    Body = BD->getBody();
+  } else if (const auto *CD = dyn_cast<CapturedDecl>(D)) {
+    Body = CD->getBody();
+    if (CGF.CapturedStmtInfo->getKind() == CR_OpenMP) {
+      const auto *CS = dyn_cast<CapturedStmt>(Body);
+      const auto *LastCS = CS;
+      while (CS && CS->getCapturedRegionKind() == CR_OpenMP) {
+        Body = CS->getCapturedStmt();
+        LastCS = CS;
+        CS = dyn_cast<CapturedStmt>(Body);
+      }
+      if (LastCS) {
+        IgnoredDecls.reserve(LastCS->capture_size());
+        for (const auto &Capture : LastCS->captures())
+          if (Capture.capturesVariable())
+            IgnoredDecls.emplace_back(Capture.getCapturedVar());
+      }
+    }
+  }
+  if (!Body)
+    return;
+  CheckVarsEscapingDeclContext VarChecker(CGF, IgnoredDecls);
+  VarChecker.Visit(Body);
+  const RecordDecl *GlobalizedVarsRecord =
+      VarChecker.getGlobalizedRecord();
+  if (!GlobalizedVarsRecord)
+    return;
+  auto &Res = FunctionGlobalizedDecls.try_emplace(CGF.CurFn,
+      GlobalizedVarsRecord,
+          DeclToAddrMapTy()).first->getSecond().second;
+  for (const ValueDecl *VD : VarChecker.getEscapedDecls()) {
+    const FieldDecl *FD = VarChecker.getFieldForGlobalizedVar(VD);
+    Res.insert(std::make_pair(VD,
+        std::make_pair(FD, Address::invalid())));
+  }
+}
+
+Address CGOpenMPRuntimeNVPTX::getAddressOfLocalVariable(
+    CodeGenFunction &CGF, const VarDecl *VD) {
+  auto I = FunctionGlobalizedDecls.find(CGF.CurFn);
+  if (I == FunctionGlobalizedDecls.end())
+    return Address::invalid();
+  auto VDI = I->getSecond().second.find(VD);
+  if (VDI == I->getSecond().second.end())
+    return Address::invalid();
+  return VDI->second.second;
+}
+
+void CGOpenMPRuntimeNVPTX::functionFinished(CodeGenFunction &CGF) {
+  FunctionToGlobalRecPtr.erase(CGF.CurFn);
+  FunctionGlobalizedDecls.erase(CGF.CurFn);
+  CGOpenMPRuntime::functionFinished(CGF);
+}
Index: lib/CodeGen/CGOpenMPRuntime.h
===================================================================
--- lib/CodeGen/CGOpenMPRuntime.h
+++ lib/CodeGen/CGOpenMPRuntime.h
@@ -675,7 +675,7 @@
 
   /// \brief Cleans up references to the objects in finished function.
   ///
-  void functionFinished(CodeGenFunction &CGF);
+  virtual void functionFinished(CodeGenFunction &CGF);
 
   /// \brief Emits code for parallel or serial call of the \a OutlinedFn with
   /// variables captured in a record which address is stored in \a
@@ -1362,6 +1362,14 @@
   emitOutlinedFunctionCall(CodeGenFunction &CGF, SourceLocation Loc,
                            llvm::Value *OutlinedFn,
                            ArrayRef<llvm::Value *> Args = llvm::None) const;
+
+  /// Emits OpenMP-specific function prolog for \a D (required for device
+  /// constructs). The base implementation is a no-op.
+  virtual void emitFunctionProlog(CodeGenFunction &CGF, const Decl *D) {}
+
+  /// Gets the OpenMP-specific address of \a VD (invalid Address if none).
+  virtual Address getAddressOfLocalVariable(CodeGenFunction &CGF,
+                                            const VarDecl *VD);
 };
 
 /// Class supports emissionof SIMD-only code.
Index: lib/CodeGen/CGOpenMPRuntime.cpp
===================================================================
--- lib/CodeGen/CGOpenMPRuntime.cpp
+++ lib/CodeGen/CGOpenMPRuntime.cpp
@@ -8019,6 +8019,11 @@
   return CGF.GetAddrOfLocalVar(NativeParam);
 }
 
+Address CGOpenMPRuntime::getAddressOfLocalVariable(CodeGenFunction &CGF,
+                                                   const VarDecl *VD) {
+  return Address::invalid(); // Base runtime supplies no address for VD.
+}
+
 llvm::Value *CGOpenMPSIMDRuntime::emitParallelOutlinedFunction(
     const OMPExecutableDirective &D, const VarDecl *ThreadIDVar,
     OpenMPDirectiveKind InnermostKind, const RegionCodeGenTy &CodeGen) {
Index: lib/CodeGen/CGDecl.cpp
===================================================================
--- lib/CodeGen/CGDecl.cpp
+++ lib/CodeGen/CGDecl.cpp
@@ -1016,9 +1016,17 @@
     }
 
     // A normal fixed sized variable becomes an alloca in the entry block,
-    // unless it's an NRVO variable.
-
-    if (NRVO) {
+    // unless:
+    // - it's an NRVO variable.
+    // - we are compiling OpenMP and it's an OpenMP local variable.
+
+    Address OpenMPLocalAddr =
+        getLangOpts().OpenMP
+            ? CGM.getOpenMPRuntime().getAddressOfLocalVariable(*this, &D)
+            : Address::invalid();
+    if (getLangOpts().OpenMP && OpenMPLocalAddr.isValid()) {
+      address = OpenMPLocalAddr;
+    } else if (NRVO) {
       // The named return value optimization: allocate this variable in the
       // return slot, so that we can elide the copy when returning this
       // variable (C++0x [class.copy]p34).
@@ -1820,9 +1828,18 @@
         pushDestroy(QualType::DK_cxx_destructor, DeclPtr, Ty);
     }
   } else {
-    // Otherwise, create a temporary to hold the value.
-    DeclPtr = CreateMemTemp(Ty, getContext().getDeclAlign(&D),
-                            D.getName() + ".addr");
+    // Check if the parameter address is controlled by OpenMP runtime.
+    Address OpenMPLocalAddr =
+        getLangOpts().OpenMP
+            ? CGM.getOpenMPRuntime().getAddressOfLocalVariable(*this, &D)
+            : Address::invalid();
+    if (getLangOpts().OpenMP && OpenMPLocalAddr.isValid()) {
+      DeclPtr = OpenMPLocalAddr;
+    } else {
+      // Otherwise, create a temporary to hold the value.
+      DeclPtr = CreateMemTemp(Ty, getContext().getDeclAlign(&D),
+                              D.getName() + ".addr");
+    }
     DoStore = true;
   }
 
_______________________________________________
cfe-commits mailing list
cfe-commits@lists.llvm.org
http://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits

Reply via email to