================ @@ -1126,6 +1133,185 @@ void OpenMPIRBuilder::emitCancelationCheckImpl(Value *CancelFlag, Builder.SetInsertPoint(NonCancellationBlock, NonCancellationBlock->begin()); } +// Callback used to create OpenMP runtime calls to support +// omp parallel clause for the device. +// We need to use this callback to replace call to the OutlinedFn in OuterFn +// by the call to the OpenMP DeviceRTL runtime function (kmpc_parallel_51) +static void +targetParallelCallback(OpenMPIRBuilder *OMPIRBuilder, Function &OutlinedFn, + Function *OuterFn, Value *Ident, Value *IfCondition, + Value *NumThreads, Instruction *PrivTID, + AllocaInst *PrivTIDAddr, Value *ThreadID, + const SmallVector<Instruction *, 4> &ToBeDeleted) { + // Add some known attributes. + Module &M = OMPIRBuilder->M; + IRBuilder<> &Builder = OMPIRBuilder->Builder; + OutlinedFn.addParamAttr(0, Attribute::NoAlias); + OutlinedFn.addParamAttr(1, Attribute::NoAlias); + OutlinedFn.addParamAttr(0, Attribute::NoUndef); + OutlinedFn.addParamAttr(1, Attribute::NoUndef); + OutlinedFn.addFnAttr(Attribute::NoUnwind); + + assert(OutlinedFn.arg_size() >= 2 && + "Expected at least tid and bounded tid as arguments"); + unsigned NumCapturedVars = OutlinedFn.arg_size() - /* tid & bounded tid */ 2; + + CallInst *CI = cast<CallInst>(OutlinedFn.user_back()); + assert(CI && "Expected call instruction to outlined function"); + CI->getParent()->setName("omp_parallel"); + // Replace direct call to the outlined function by the call to + // __kmpc_parallel_51 + Builder.SetInsertPoint(CI); + + // Build call __kmpc_parallel_51 + auto PtrTy = Type::getInt8PtrTy(M.getContext()); + Value *Void = ConstantPointerNull::get(PtrTy); + // Add alloca for kernel args. Put this instruction at the beginning + // of the function. 
+ OpenMPIRBuilder ::InsertPointTy CurrentIP = Builder.saveIP(); + Builder.SetInsertPoint(&OuterFn->front(), + OuterFn->front().getFirstInsertionPt()); + AllocaInst *ArgsAlloca = + Builder.CreateAlloca(ArrayType::get(PtrTy, NumCapturedVars)); + Value *Args = + Builder.CreatePointerCast(ArgsAlloca, Type::getInt8PtrTy(M.getContext())); + Builder.restoreIP(CurrentIP); + // Store captured vars which are used by kmpc_parallel_51 + if (NumCapturedVars) { + for (unsigned Idx = 0; Idx < NumCapturedVars; Idx++) { + Value *V = *(CI->arg_begin() + 2 + Idx); + Value *StoreAddress = Builder.CreateConstInBoundsGEP2_64( + ArrayType::get(PtrTy, NumCapturedVars), Args, 0, Idx); + Builder.CreateStore(V, StoreAddress); + } + } + Value *Cond = IfCondition ? Builder.CreateSExtOrTrunc( + IfCondition, Type::getInt32Ty(M.getContext())) + : Builder.getInt32(1); + Value *Parallel51CallArgs[] = { + /* identifier*/ Ident, + /* global thread num*/ ThreadID, + /* if expression */ Cond, NumThreads ? NumThreads : Builder.getInt32(-1), ---------------- jdoerfert wrote:
Either provide comments for all arguments or for none; commenting only some of them is confusing. https://github.com/llvm/llvm-project/pull/67000 _______________________________________________ cfe-commits mailing list cfe-commits@lists.llvm.org https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits