llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT--> @llvm/pr-subscribers-hlsl Author: Chris B (llvm-beanz) <details> <summary>Changes</summary> In HLSL, function parameters are passed by value, including array parameters. This change introduces a new AST node to represent array temporary expressions. They behave as lvalues referring to temporary arrays and decay to pointers for overload resolution and code generation. The behavior of HLSL function calls is documented in the [draft language specification](https://microsoft.github.io/hlsl-specs/specs/hlsl.pdf) under the Expr.Post.Call heading. Additionally, the design of this implementation approach is documented in [Clang's documentation](https://clang.llvm.org/docs/HLSL/FunctionCalls.html). --- Full diff: https://github.com/llvm/llvm-project/pull/79382.diff 20 Files Affected: - (modified) clang/include/clang/AST/Expr.h (+38) - (modified) clang/include/clang/AST/RecursiveASTVisitor.h (+2) - (modified) clang/include/clang/Basic/StmtNodes.td (+3) - (modified) clang/lib/AST/Expr.cpp (+11) - (modified) clang/lib/AST/ExprClassification.cpp (+1) - (modified) clang/lib/AST/ExprConstant.cpp (+1) - (modified) clang/lib/AST/ItaniumMangle.cpp (+4) - (modified) clang/lib/AST/StmtPrinter.cpp (+4) - (modified) clang/lib/AST/StmtProfile.cpp (+9) - (modified) clang/lib/CodeGen/CGExpr.cpp (+1) - (modified) clang/lib/CodeGen/CGExprAgg.cpp (+6) - (modified) clang/lib/Sema/SemaExceptionSpec.cpp (+1) - (modified) clang/lib/Sema/SemaInit.cpp (+5) - (modified) clang/lib/Sema/TreeTransform.h (+13) - (modified) clang/lib/Serialization/ASTReaderStmt.cpp (+8) - (modified) clang/lib/Serialization/ASTWriterStmt.cpp (+8) - (modified) clang/lib/StaticAnalyzer/Core/ExprEngine.cpp (+2-1) - (added) clang/test/CodeGenHLSL/ArrayTemporary.hlsl (+49) - (added) clang/test/SemaHLSL/ArrayTemporary.hlsl (+46) - (modified) clang/tools/libclang/CXCursor.cpp (+1) ``````````diff diff --git a/clang/include/clang/AST/Expr.h b/clang/include/clang/AST/Expr.h index 59f0aee2c0cedd..c28747e7a3796f 
100644 --- a/clang/include/clang/AST/Expr.h +++ b/clang/include/clang/AST/Expr.h @@ -6651,6 +6651,44 @@ class RecoveryExpr final : public Expr, friend class ASTStmtWriter; }; +/// HLSLArrayTemporaryExpr - In HLSL, default parameter passing is by value, +/// including for arrays. This AST node represents a materialized temporary of a +/// constant size array. +class HLSLArrayTemporaryExpr : public Expr { + Expr *SourceExpr; + + HLSLArrayTemporaryExpr(Expr *S) + : Expr(HLSLArrayTemporaryExprClass, S->getType(), VK_LValue, OK_Ordinary), + SourceExpr(S) {} + + HLSLArrayTemporaryExpr(EmptyShell Empty) + : Expr(HLSLArrayTemporaryExprClass, Empty), SourceExpr(nullptr) {} + +public: + static HLSLArrayTemporaryExpr *Create(const ASTContext &Ctx, Expr *S); + static HLSLArrayTemporaryExpr *CreateEmpty(const ASTContext &Ctx); + + const Expr *getSourceExpr() const { return SourceExpr; } + Expr *getSourceExpr() { return SourceExpr; } + void setSourceExpr(Expr *S) { SourceExpr = S; } + + SourceLocation getBeginLoc() const { return SourceExpr->getBeginLoc(); } + + SourceLocation getEndLoc() const { return SourceExpr->getEndLoc(); } + + static bool classof(const Stmt *T) { + return T->getStmtClass() == HLSLArrayTemporaryExprClass; + } + + // Iterators + child_range children() { + return child_range(child_iterator(), child_iterator()); + } + const_child_range children() const { + return const_child_range(const_child_iterator(), const_child_iterator()); + } +}; + } // end namespace clang #endif // LLVM_CLANG_AST_EXPR_H diff --git a/clang/include/clang/AST/RecursiveASTVisitor.h b/clang/include/clang/AST/RecursiveASTVisitor.h index 2aee6a947141b6..f7f8bbbb05cf20 100644 --- a/clang/include/clang/AST/RecursiveASTVisitor.h +++ b/clang/include/clang/AST/RecursiveASTVisitor.h @@ -3171,6 +3171,8 @@ DEF_TRAVERSE_STMT(OMPTargetParallelGenericLoopDirective, DEF_TRAVERSE_STMT(OMPErrorDirective, { TRY_TO(TraverseOMPExecutableDirective(S)); }) +DEF_TRAVERSE_STMT(HLSLArrayTemporaryExpr, {}) + // 
OpenMP clauses. template <typename Derived> bool RecursiveASTVisitor<Derived>::TraverseOMPClause(OMPClause *C) { diff --git a/clang/include/clang/Basic/StmtNodes.td b/clang/include/clang/Basic/StmtNodes.td index cec301dfca2817..3d14d33a14b878 100644 --- a/clang/include/clang/Basic/StmtNodes.td +++ b/clang/include/clang/Basic/StmtNodes.td @@ -295,3 +295,6 @@ def OMPTargetTeamsGenericLoopDirective : StmtNode<OMPLoopDirective>; def OMPParallelGenericLoopDirective : StmtNode<OMPLoopDirective>; def OMPTargetParallelGenericLoopDirective : StmtNode<OMPLoopDirective>; def OMPErrorDirective : StmtNode<OMPExecutableDirective>; + +// HLSL Extensions +def HLSLArrayTemporaryExpr : StmtNode<Expr>; diff --git a/clang/lib/AST/Expr.cpp b/clang/lib/AST/Expr.cpp index f1efa98e175edf..5903e2763dbe59 100644 --- a/clang/lib/AST/Expr.cpp +++ b/clang/lib/AST/Expr.cpp @@ -3569,6 +3569,7 @@ bool Expr::HasSideEffects(const ASTContext &Ctx, case ConceptSpecializationExprClass: case RequiresExprClass: case SYCLUniqueStableNameExprClass: + case HLSLArrayTemporaryExprClass: // These never have a side-effect. 
return false; @@ -5227,3 +5228,13 @@ OMPIteratorExpr *OMPIteratorExpr::CreateEmpty(const ASTContext &Context, alignof(OMPIteratorExpr)); return new (Mem) OMPIteratorExpr(EmptyShell(), NumIterators); } + +HLSLArrayTemporaryExpr * +HLSLArrayTemporaryExpr::Create(const ASTContext &Ctx, Expr *Base) { + return new (Ctx) HLSLArrayTemporaryExpr(Base); +} + +HLSLArrayTemporaryExpr * +HLSLArrayTemporaryExpr::CreateEmpty(const ASTContext &Ctx) { + return new (Ctx) HLSLArrayTemporaryExpr(EmptyShell()); +} diff --git a/clang/lib/AST/ExprClassification.cpp b/clang/lib/AST/ExprClassification.cpp index ffa7c6802ea6e1..f07f2756c3f81d 100644 --- a/clang/lib/AST/ExprClassification.cpp +++ b/clang/lib/AST/ExprClassification.cpp @@ -148,6 +148,7 @@ static Cl::Kinds ClassifyInternal(ASTContext &Ctx, const Expr *E) { case Expr::OMPArraySectionExprClass: case Expr::OMPArrayShapingExprClass: case Expr::OMPIteratorExprClass: + case Expr::HLSLArrayTemporaryExprClass: return Cl::CL_LValue; // C99 6.5.2.5p5 says that compound literals are lvalues. 
diff --git a/clang/lib/AST/ExprConstant.cpp b/clang/lib/AST/ExprConstant.cpp index f1d07d022b2584..ef474983b4b7dd 100644 --- a/clang/lib/AST/ExprConstant.cpp +++ b/clang/lib/AST/ExprConstant.cpp @@ -16044,6 +16044,7 @@ static ICEDiag CheckICE(const Expr* E, const ASTContext &Ctx) { case Expr::CoyieldExprClass: case Expr::SYCLUniqueStableNameExprClass: case Expr::CXXParenListInitExprClass: + case Expr::HLSLArrayTemporaryExprClass: return ICEDiag(IK_NotICE, E->getBeginLoc()); case Expr::InitListExprClass: { diff --git a/clang/lib/AST/ItaniumMangle.cpp b/clang/lib/AST/ItaniumMangle.cpp index 40b1e086ddd0c6..7da53e71719c14 100644 --- a/clang/lib/AST/ItaniumMangle.cpp +++ b/clang/lib/AST/ItaniumMangle.cpp @@ -4701,6 +4701,10 @@ void CXXNameMangler::mangleExpression(const Expr *E, unsigned Arity, E = cast<ConstantExpr>(E)->getSubExpr(); goto recurse; + case Expr::HLSLArrayTemporaryExprClass: + E = cast<HLSLArrayTemporaryExpr>(E)->getSourceExpr(); + goto recurse; + // FIXME: invent manglings for all these. 
case Expr::BlockExprClass: case Expr::ChooseExprClass: diff --git a/clang/lib/AST/StmtPrinter.cpp b/clang/lib/AST/StmtPrinter.cpp index 9b5bf436c13be9..dd2a4f8be403da 100644 --- a/clang/lib/AST/StmtPrinter.cpp +++ b/clang/lib/AST/StmtPrinter.cpp @@ -2749,6 +2749,10 @@ void StmtPrinter::VisitAsTypeExpr(AsTypeExpr *Node) { OS << ")"; } +void StmtPrinter::VisitHLSLArrayTemporaryExpr(HLSLArrayTemporaryExpr *Node) { + PrintExpr(Node->getSourceExpr()); +} + //===----------------------------------------------------------------------===// // Stmt method implementations //===----------------------------------------------------------------------===// diff --git a/clang/lib/AST/StmtProfile.cpp b/clang/lib/AST/StmtProfile.cpp index dd0838edab7b3f..0c5acb596f1179 100644 --- a/clang/lib/AST/StmtProfile.cpp +++ b/clang/lib/AST/StmtProfile.cpp @@ -2433,6 +2433,15 @@ void StmtProfiler::VisitTemplateArgument(const TemplateArgument &Arg) { } } +//===----------------------------------------------------------------------===// +// HLSL AST Nodes +//===----------------------------------------------------------------------===// + +void StmtProfiler::VisitHLSLArrayTemporaryExpr( + const HLSLArrayTemporaryExpr *S) { + VisitExpr(S); +} + void Stmt::Profile(llvm::FoldingSetNodeID &ID, const ASTContext &Context, bool Canonical, bool ProfileLambdaExpr) const { StmtProfilerWithPointers Profiler(ID, Context, Canonical, ProfileLambdaExpr); diff --git a/clang/lib/CodeGen/CGExpr.cpp b/clang/lib/CodeGen/CGExpr.cpp index c5f6b6d3a99f0b..cd7ad03208654a 100644 --- a/clang/lib/CodeGen/CGExpr.cpp +++ b/clang/lib/CodeGen/CGExpr.cpp @@ -1590,6 +1590,7 @@ LValue CodeGenFunction::EmitLValueHelper(const Expr *E, case Expr::CXXUuidofExprClass: return EmitCXXUuidofLValue(cast<CXXUuidofExpr>(E)); case Expr::LambdaExprClass: + case Expr::HLSLArrayTemporaryExprClass: return EmitAggExprToLValue(E); case Expr::ExprWithCleanupsClass: { diff --git a/clang/lib/CodeGen/CGExprAgg.cpp b/clang/lib/CodeGen/CGExprAgg.cpp 
index 810b28f25fa18b..57471da27133d9 100644 --- a/clang/lib/CodeGen/CGExprAgg.cpp +++ b/clang/lib/CodeGen/CGExprAgg.cpp @@ -235,6 +235,8 @@ class AggExprEmitter : public StmtVisitor<AggExprEmitter> { RValue Res = CGF.EmitAtomicExpr(E); EmitFinalDestCopy(E->getType(), Res); } + + void VisitHLSLArrayTemporaryExpr(HLSLArrayTemporaryExpr *E); }; } // end anonymous namespace. @@ -1923,6 +1925,10 @@ void AggExprEmitter::VisitDesignatedInitUpdateExpr(DesignatedInitUpdateExpr *E) VisitInitListExpr(E->getUpdater()); } +void AggExprEmitter::VisitHLSLArrayTemporaryExpr(HLSLArrayTemporaryExpr *E) { + Visit(E->getSourceExpr()); +} + //===----------------------------------------------------------------------===// // Entry Points into this File //===----------------------------------------------------------------------===// diff --git a/clang/lib/Sema/SemaExceptionSpec.cpp b/clang/lib/Sema/SemaExceptionSpec.cpp index 75730ea888afb4..0e4c11bf5b2614 100644 --- a/clang/lib/Sema/SemaExceptionSpec.cpp +++ b/clang/lib/Sema/SemaExceptionSpec.cpp @@ -1414,6 +1414,7 @@ CanThrowResult Sema::canThrow(const Stmt *S) { case Expr::SourceLocExprClass: case Expr::ConceptSpecializationExprClass: case Expr::RequiresExprClass: + case Expr::HLSLArrayTemporaryExprClass: // These expressions can never throw. 
return CT_Cannot; diff --git a/clang/lib/Sema/SemaInit.cpp b/clang/lib/Sema/SemaInit.cpp index 457fa377355a97..54f990f50d8576 100644 --- a/clang/lib/Sema/SemaInit.cpp +++ b/clang/lib/Sema/SemaInit.cpp @@ -10524,6 +10524,11 @@ Sema::PerformCopyInitialization(const InitializedEntity &Entity, Expr *InitE = Init.get(); assert(InitE && "No initialization expression?"); + if (LangOpts.HLSL) + if (auto AdjTy = dyn_cast<DecayedType>(Entity.getType())) + if (AdjTy->getOriginalType()->isConstantArrayType()) + InitE = HLSLArrayTemporaryExpr::Create(getASTContext(), InitE); + if (EqualLoc.isInvalid()) EqualLoc = InitE->getBeginLoc(); diff --git a/clang/lib/Sema/TreeTransform.h b/clang/lib/Sema/TreeTransform.h index e55e752b9cc354..3772e9a61f5f97 100644 --- a/clang/lib/Sema/TreeTransform.h +++ b/clang/lib/Sema/TreeTransform.h @@ -15461,6 +15461,19 @@ TreeTransform<Derived>::TransformCapturedStmt(CapturedStmt *S) { return getSema().ActOnCapturedRegionEnd(Body.get()); } +template <typename Derived> +ExprResult TreeTransform<Derived>::TransformHLSLArrayTemporaryExpr( + HLSLArrayTemporaryExpr *E) { + ExprResult SrcExpr = getDerived().TransformExpr(E->getSourceExpr()); + if (SrcExpr.isInvalid()) + return ExprError(); + + if (!getDerived().AlwaysRebuild() && SrcExpr.get() == E->getSourceExpr()) + return E; + + return HLSLArrayTemporaryExpr::Create(getSema().Context, SrcExpr.get()); +} + } // end namespace clang #endif // LLVM_CLANG_LIB_SEMA_TREETRANSFORM_H diff --git a/clang/lib/Serialization/ASTReaderStmt.cpp b/clang/lib/Serialization/ASTReaderStmt.cpp index 85ecfa1a1a0bf2..2203a0add1f648 100644 --- a/clang/lib/Serialization/ASTReaderStmt.cpp +++ b/clang/lib/Serialization/ASTReaderStmt.cpp @@ -2776,6 +2776,14 @@ void ASTStmtReader::VisitOMPTargetParallelGenericLoopDirective( VisitOMPLoopDirective(D); } +//===----------------------------------------------------------------------===// +// HLSL AST Nodes +//===----------------------------------------------------------------------===// 
+ +void ASTStmtReader::VisitHLSLArrayTemporaryExpr(HLSLArrayTemporaryExpr *S) { + VisitExpr(S); +} + //===----------------------------------------------------------------------===// // ASTReader Implementation //===----------------------------------------------------------------------===// diff --git a/clang/lib/Serialization/ASTWriterStmt.cpp b/clang/lib/Serialization/ASTWriterStmt.cpp index e5836f5dcbe955..9951426e2f2d81 100644 --- a/clang/lib/Serialization/ASTWriterStmt.cpp +++ b/clang/lib/Serialization/ASTWriterStmt.cpp @@ -2825,6 +2825,14 @@ void ASTStmtWriter::VisitOMPTargetParallelGenericLoopDirective( Code = serialization::STMT_OMP_TARGET_PARALLEL_GENERIC_LOOP_DIRECTIVE; } +//===----------------------------------------------------------------------===// +// HLSL AST Nodes +//===----------------------------------------------------------------------===// + +void ASTStmtWriter::VisitHLSLArrayTemporaryExpr(HLSLArrayTemporaryExpr *S) { + VisitExpr(S); +} + //===----------------------------------------------------------------------===// // ASTWriter Implementation //===----------------------------------------------------------------------===// diff --git a/clang/lib/StaticAnalyzer/Core/ExprEngine.cpp b/clang/lib/StaticAnalyzer/Core/ExprEngine.cpp index 24e91a22fd6884..39cdec788648fb 100644 --- a/clang/lib/StaticAnalyzer/Core/ExprEngine.cpp +++ b/clang/lib/StaticAnalyzer/Core/ExprEngine.cpp @@ -1821,7 +1821,8 @@ void ExprEngine::Visit(const Stmt *S, ExplodedNode *Pred, case Stmt::OMPTargetParallelGenericLoopDirectiveClass: case Stmt::CapturedStmtClass: case Stmt::OMPUnrollDirectiveClass: - case Stmt::OMPMetaDirectiveClass: { + case Stmt::OMPMetaDirectiveClass: + case Stmt::HLSLArrayTemporaryExprClass: { const ExplodedNode *node = Bldr.generateSink(S, Pred, Pred->getState()); Engine.addAbortedBlock(node, currBldrCtx->getBlock()); break; diff --git a/clang/test/CodeGenHLSL/ArrayTemporary.hlsl b/clang/test/CodeGenHLSL/ArrayTemporary.hlsl new file mode 100644 index 
00000000000000..411e3182f45809 --- /dev/null +++ b/clang/test/CodeGenHLSL/ArrayTemporary.hlsl @@ -0,0 +1,49 @@ +// RUN: %clang_cc1 -triple dxil-pc-shadermodel6.3-library -emit-llvm -disable-llvm-passes -o - %s | Filecheck %s + +void fn(float x[2]) { } + +// CHECK-LABEL: define void {{.*}}call{{.*}} +// CHECK: [[Arr:%.*]] = alloca [2 x float] +// CHECK: [[Tmp:%.*]] = alloca [2 x float] +// CHECK: call void @llvm.memset.p0.i32(ptr align 4 [[Arr]], i8 0, i32 8, i1 false) +// CHECK: call void @llvm.memcpy.p0.p0.i32(ptr align 4 [[Tmp]], ptr align 4 [[Arr]], i32 8, i1 false) +// CHECK: [[Decay:%.*]] = getelementptr inbounds [2 x float], ptr [[Tmp]], i32 0, i32 0 +// CHECK: call void {{.*}}fn{{.*}}(ptr noundef [[Decay]]) +void call() { + float Arr[2] = {0, 0}; + fn(Arr); +} + +struct Obj { + float V; + int X; +}; + +void fn2(Obj O[4]) { } + +// CHECK-LABEL: define void {{.*}}call2{{.*}} +// CHECK: [[Arr:%.*]] = alloca [4 x %struct.Obj] +// CHECK: [[Tmp:%.*]] = alloca [4 x %struct.Obj] +// CHECK: call void @llvm.memset.p0.i32(ptr align 4 [[Arr]], i8 0, i32 32, i1 false) +// CHECK: call void @llvm.memcpy.p0.p0.i32(ptr align 4 [[Tmp]], ptr align 4 [[Arr]], i32 32, i1 false) +// CHECK: [[Decay:%.*]] = getelementptr inbounds [4 x %struct.Obj], ptr [[Tmp]], i32 0, i32 0 +// CHECK: call void {{.*}}fn2{{.*}}(ptr noundef [[Decay]]) +void call2() { + Obj Arr[4] = {}; + fn2(Arr); +} + + +void fn3(float x[2][2]) { } + +// CHECK-LABEL: define void {{.*}}call3{{.*}} +// CHECK: [[Arr:%.*]] = alloca [2 x [2 x float]] +// CHECK: [[Tmp:%.*]] = alloca [2 x [2 x float]] +// CHECK: call void @llvm.memcpy.p0.p0.i32(ptr align 4 [[Arr]], ptr align 4 {{.*}}, i32 16, i1 false) +// CHECK: call void @llvm.memcpy.p0.p0.i32(ptr align 4 [[Tmp]], ptr align 4 [[Arr]], i32 16, i1 false) +// CHECK: [[Decay:%.*]] = getelementptr inbounds [2 x [2 x float]], ptr [[Tmp]], i32 0, i32 0 +// CHECK: call void {{.*}}fn3{{.*}}(ptr noundef [[Decay]]) +void call3() { + float Arr[2][2] = {{0, 0}, {1,1}}; + fn3(Arr); +} 
diff --git a/clang/test/SemaHLSL/ArrayTemporary.hlsl b/clang/test/SemaHLSL/ArrayTemporary.hlsl new file mode 100644 index 00000000000000..32f4fc0f9abdca --- /dev/null +++ b/clang/test/SemaHLSL/ArrayTemporary.hlsl @@ -0,0 +1,46 @@ +// RUN: %clang_cc1 -triple dxil-pc-shadermodel6.3-library -ast-dump %s | Filecheck %s + +void fn(float x[2]) { } + +// CHECK: CallExpr {{.*}} 'void' +// CHECK-NEXT: ImplicitCastExpr {{.*}} 'void (*)(float *)' <FunctionToPointerDecay> +// CHECK-NEXT: DeclRefExpr {{.*}} 'void (float *)' lvalue Function {{.*}} 'fn' 'void (float *)' +// CHECK-NEXT: ImplicitCastExpr {{.*}} 'float *' <ArrayToPointerDecay> +// CHECK-NEXT: HLSLArrayTemporaryExpr {{.*}} 'float[2]' lvalue + +void call() { + float Arr[2] = {0, 0}; + fn(Arr); +} + +struct Obj { + float V; + int X; +}; + +void fn2(Obj O[4]) { } + +// CHECK: CallExpr {{.*}} 'void' +// CHECK-NEXT: ImplicitCastExpr {{.*}} 'void (*)(Obj *)' <FunctionToPointerDecay> +// CHECK-NEXT: DeclRefExpr {{.*}} 'void (Obj *)' lvalue Function {{.*}} 'fn2' 'void (Obj *)' +// CHECK-NEXT: ImplicitCastExpr {{.*}} 'Obj *' <ArrayToPointerDecay> +// CHECK-NEXT: HLSLArrayTemporaryExpr {{.*}} 'Obj[4]' lvalue + +void call2() { + Obj Arr[4] = {}; + fn2(Arr); +} + + +void fn3(float x[2][2]) { } + +// CHECK: CallExpr {{.*}} 'void' +// CHECK-NEXT: ImplicitCastExpr {{.*}} 'void (*)(float (*)[2])' <FunctionToPointerDecay> +// CHECK-NEXT: DeclRefExpr {{.*}} 'void (float (*)[2])' lvalue Function {{.*}} 'fn3' 'void (float (*)[2])' +// CHECK-NEXT: ImplicitCastExpr {{.*}} 'float (*)[2]' <ArrayToPointerDecay> +// CHECK-NEXT: HLSLArrayTemporaryExpr {{.*}} 'float[2][2]' lvalue + +void call3() { + float Arr[2][2] = {{0, 0}, {1,1}}; + fn3(Arr); +} diff --git a/clang/tools/libclang/CXCursor.cpp b/clang/tools/libclang/CXCursor.cpp index 978adac5521aaa..43b088c573c0a4 100644 --- a/clang/tools/libclang/CXCursor.cpp +++ b/clang/tools/libclang/CXCursor.cpp @@ -335,6 +335,7 @@ CXCursor cxcursor::MakeCXCursor(const Stmt *S, const Decl *Parent, case 
Stmt::ObjCSubscriptRefExprClass: case Stmt::RecoveryExprClass: case Stmt::SYCLUniqueStableNameExprClass: + case Stmt::HLSLArrayTemporaryExprClass: K = CXCursor_UnexposedExpr; break; `````````` </details> https://github.com/llvm/llvm-project/pull/79382 _______________________________________________ cfe-commits mailing list cfe-commits@lists.llvm.org https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits