beanz created this revision.
beanz added reviewers: bogner, aaron.ballman, python3kgae, pow2clk.
Herald added a subscriber: Anastasia.
Herald added a project: All.
beanz requested review of this revision.
Herald added a project: clang.

In HLSL, buffer types support array subscripting syntax for loads and
stores. This change implements the subscript operators as array accesses
on the underlying handle pointer, which allows LLVM optimization passes to
optimize resource accesses the same way as any other memory access. The
const operator returns the element by value; the non-const operator returns
a reference so that elements can be assigned.


Repository:
  rG LLVM Github Monorepo

https://reviews.llvm.org/D131268

Files:
  clang/lib/Sema/HLSLExternalSemaSource.cpp
  clang/lib/Sema/SemaType.cpp
  clang/test/AST/HLSL/RWBuffer-AST.hlsl
  clang/test/CodeGenHLSL/buffer-array-operator.hlsl

Index: clang/test/CodeGenHLSL/buffer-array-operator.hlsl
===================================================================
--- /dev/null
+++ clang/test/CodeGenHLSL/buffer-array-operator.hlsl
@@ -0,0 +1,31 @@
+// RUN: %clang_cc1 -triple dxil-pc-shadermodel6.3-library -x hlsl -emit-llvm -disable-llvm-passes -o - %s | FileCheck %s
+
+const RWBuffer<float> In;
+RWBuffer<float> Out;
+
+void fn(int Idx) {
+  Out[Idx] = In[Idx];
+}
+
+// This test verifies code generation for the subscript operators. Reading
+// from the const buffer `In` uses the const operator and writing to `Out`
+// uses the non-const operator, so both are emitted and checked below.
+
+// The const operator comes first and returns the element by value.
+// CHECK: float @"??A?$RWBuffer@M@hlsl@@QBAMI@Z"
+// CHECK: %this1 = load ptr, ptr %this.addr, align 4
+// CHECK-NEXT: %h = getelementptr inbounds %"class.hlsl::RWBuffer", ptr %this1, i32 0, i32 0
+// CHECK-NEXT: %0 = load ptr, ptr %h, align 4
+// CHECK-NEXT: %1 = load i32, ptr %Idx.addr, align 4
+// CHECK-NEXT: %arrayidx = getelementptr inbounds float, ptr %0, i32 %1
+// CHECK-NEXT: %2 = load float, ptr %arrayidx, align 4
+// CHECK-NEXT: ret float %2
+
+// The non-const operator comes next and returns a reference (a ptr in IR).
+// CHECK: ptr @"??A?$RWBuffer@M@hlsl@@QAAAAMI@Z"
+// CHECK: %this1 = load ptr, ptr %this.addr, align 4
+// CHECK-NEXT: %h = getelementptr inbounds %"class.hlsl::RWBuffer", ptr %this1, i32 0, i32 0
+// CHECK-NEXT: %0 = load ptr, ptr %h, align 4
+// CHECK-NEXT: %1 = load i32, ptr %Idx.addr, align 4
+// CHECK-NEXT: %arrayidx = getelementptr inbounds float, ptr %0, i32 %1
+// CHECK-NEXT: ret ptr %arrayidx
Index: clang/test/AST/HLSL/RWBuffer-AST.hlsl
===================================================================
--- clang/test/AST/HLSL/RWBuffer-AST.hlsl
+++ clang/test/AST/HLSL/RWBuffer-AST.hlsl
@@ -39,11 +39,30 @@
 
 // CHECK: FinalAttr 0x{{[0-9A-Fa-f]+}} <<invalid sloc>> Implicit final
 // CHECK-NEXT: HLSLResourceAttr 0x{{[0-9A-Fa-f]+}} <<invalid sloc>> Implicit UAV
-// CHECK-NEXT: FieldDecl 0x{{[0-9A-Fa-f]+}} <<invalid sloc>> <invalid sloc> implicit h 'void *'
+// CHECK-NEXT: FieldDecl 0x{{[0-9A-Fa-f]+}} <<invalid sloc>> <invalid sloc> implicit h 'element_type *'
+// The pattern's operator[] overloads follow; the const one returns by value.
+// CHECK: CXXMethodDecl 0x{{[0-9A-Fa-f]+}} <<invalid sloc>> <invalid sloc> operator[] 'element_type (unsigned int) const'
+// CHECK-NEXT: ParmVarDecl 0x{{[0-9A-Fa-f]+}} <<invalid sloc>> <invalid sloc> Idx 'unsigned int'
+// CHECK-NEXT: CompoundStmt 0x{{[0-9A-Fa-f]+}} <<invalid sloc>>
+// CHECK-NEXT: ReturnStmt 0x{{[0-9A-Fa-f]+}} <<invalid sloc>>
+// CHECK-NEXT: ArraySubscriptExpr 0x{{[0-9A-Fa-f]+}} <<invalid sloc>> 'element_type' lvalue
+// CHECK-NEXT: MemberExpr 0x{{[0-9A-Fa-f]+}} <<invalid sloc>> 'element_type *' lvalue ->h 0x{{[0-9A-Fa-f]+}}
+// CHECK-NEXT: CXXThisExpr 0x{{[0-9A-Fa-f]+}} <<invalid sloc>> 'const RWBuffer<element_type> *' implicit this
+// CHECK-NEXT: DeclRefExpr 0x{{[0-9A-Fa-f]+}} <<invalid sloc>> 'unsigned int' ParmVar 0x{{[0-9A-Fa-f]+}} 'Idx' 'unsigned int'
+// Then the non-const overload, which returns element_type by reference.
+// CHECK-NEXT: CXXMethodDecl 0x{{[0-9A-Fa-f]+}} <<invalid sloc>> <invalid sloc> operator[] 'element_type &(unsigned int)'
+// CHECK-NEXT: ParmVarDecl 0x{{[0-9A-Fa-f]+}} <<invalid sloc>> <invalid sloc> Idx 'unsigned int'
+// CHECK-NEXT: CompoundStmt 0x{{[0-9A-Fa-f]+}} <<invalid sloc>>
+// CHECK-NEXT: ReturnStmt 0x{{[0-9A-Fa-f]+}} <<invalid sloc>>
+// CHECK-NEXT: ArraySubscriptExpr 0x{{[0-9A-Fa-f]+}} <<invalid sloc>> 'element_type' lvalue
+// CHECK-NEXT: MemberExpr 0x{{[0-9A-Fa-f]+}} <<invalid sloc>> 'element_type *' lvalue ->h 0x{{[0-9A-Fa-f]+}}
+// CHECK-NEXT: CXXThisExpr 0x{{[0-9A-Fa-f]+}} <<invalid sloc>> 'RWBuffer<element_type> *' implicit this
+// CHECK-NEXT: DeclRefExpr 0x{{[0-9A-Fa-f]+}} <<invalid sloc>> 'unsigned int' ParmVar 0x{{[0-9A-Fa-f]+}} 'Idx' 'unsigned int'
+// In the RWBuffer<float> specialization the handle becomes 'float *'.
 // CHECK: ClassTemplateSpecializationDecl 0x{{[0-9A-Fa-f]+}} <<invalid sloc>> <invalid sloc> class RWBuffer definition
 
 // CHECK: TemplateArgument type 'float'
 // CHECK-NEXT: BuiltinType 0x{{[0-9A-Fa-f]+}} 'float'
 // CHECK-NEXT: FinalAttr 0x{{[0-9A-Fa-f]+}} <<invalid sloc>> Implicit final
 // CHECK-NEXT: HLSLResourceAttr 0x{{[0-9A-Fa-f]+}} <<invalid sloc>> Implicit UAV
-// CHECK-NEXT: FieldDecl 0x{{[0-9A-Fa-f]+}} <<invalid sloc>> <invalid sloc>  implicit referenced h 'void *'
+// CHECK-NEXT: FieldDecl 0x{{[0-9A-Fa-f]+}} <<invalid sloc>> <invalid sloc>  implicit referenced h 'float *'
Index: clang/lib/Sema/SemaType.cpp
===================================================================
--- clang/lib/Sema/SemaType.cpp
+++ clang/lib/Sema/SemaType.cpp
@@ -2158,7 +2158,7 @@
     return QualType();
   }
 
-  if (getLangOpts().HLSL) {
+  if (getLangOpts().HLSL && Loc.isValid()) {
     Diag(Loc, diag::err_hlsl_pointers_unsupported) << 0;
     return QualType();
   }
@@ -2228,7 +2228,7 @@
     return QualType();
   }
 
-  if (getLangOpts().HLSL) {
+  if (getLangOpts().HLSL && Loc.isValid()) {
     Diag(Loc, diag::err_hlsl_pointers_unsupported) << 1;
     return QualType();
   }
@@ -2992,7 +2992,7 @@
     return QualType();
   }
 
-  if (getLangOpts().HLSL) {
+  if (getLangOpts().HLSL && Loc.isValid()) {
     Diag(Loc, diag::err_hlsl_pointers_unsupported) << 0;
     return QualType();
   }
Index: clang/lib/Sema/HLSLExternalSemaSource.cpp
===================================================================
--- clang/lib/Sema/HLSLExternalSemaSource.cpp
+++ clang/lib/Sema/HLSLExternalSemaSource.cpp
@@ -104,7 +104,14 @@
 
   BuiltinTypeDeclBuilder &
   addHandleMember(AccessSpecifier Access = AccessSpecifier::AS_private) {
-    return addMemberVariable("h", Record->getASTContext().VoidPtrTy, Access);
+    QualType Ty = Record->getASTContext().VoidPtrTy;
+    // Templated records point at their first (type) template parameter.
+    if (Template)
+      if (auto *TTD = dyn_cast<TemplateTypeParmDecl>(
+              Template->getTemplateParameters()->getParam(0)))
+        Ty = Record->getASTContext().getPointerType(
+            QualType(TTD->getTypeForDecl(), 0));
+    return addMemberVariable("h", Ty, Access);
   }
 
   BuiltinTypeDeclBuilder &
@@ -158,15 +165,24 @@
         lookupBuiltinFunction(AST, S, "__builtin_hlsl_create_handle");
 
     Expr *RCExpr = emitResourceClassExpr(AST, RC);
-    CallExpr *Call =
-        CallExpr::Create(AST, Fn, {RCExpr}, AST.VoidPtrTy, VK_PRValue,
-                         SourceLocation(), FPOptionsOverride());
+    Expr *Call = CallExpr::Create(AST, Fn, {RCExpr}, AST.VoidPtrTy, VK_PRValue,
+                                  SourceLocation(), FPOptionsOverride());
 
     CXXThisExpr *This = new (AST)
         CXXThisExpr(SourceLocation(), Constructor->getThisType(), true);
-    MemberExpr *Handle = MemberExpr::CreateImplicit(
-        AST, This, true, Fields["h"], Fields["h"]->getType(), VK_LValue,
-        OK_Ordinary);
+    Expr *Handle = MemberExpr::CreateImplicit(AST, This, true, Fields["h"],
+                                              Fields["h"]->getType(), VK_LValue,
+                                              OK_Ordinary);
+
+    // If the handle isn't a void pointer, cast the builtin result to the
+    // correct type.
+    if (Handle->getType().getCanonicalType() != AST.VoidPtrTy) {
+      Call = CXXReinterpretCastExpr::Create(
+          AST, Handle->getType(), VK_PRValue, CK_Dependent, Call, nullptr,
+          AST.getTrivialTypeSourceInfo(Handle->getType(), SourceLocation()),
+          SourceLocation(), SourceLocation(), SourceRange());
+    }
+
     BinaryOperator *Assign = BinaryOperator::Create(
         AST, Handle, Call, BO_Assign, Handle->getType(), VK_LValue, OK_Ordinary,
         SourceLocation(), FPOptionsOverride());
@@ -179,6 +195,84 @@
     return *this;
   }
 
+  BuiltinTypeDeclBuilder &addArraySubscriptOperators() {
+    // Declare the const overload first, then the non-const one.
+    return addArraySubscriptOperator(/*IsConst=*/true)
+        .addArraySubscriptOperator(/*IsConst=*/false);
+  }
+
+  // Adds a public operator[](unsigned int Idx) whose body returns h[Idx].
+  // IsConst picks a const method returning a copy vs. a mutable reference.
+  BuiltinTypeDeclBuilder &addArraySubscriptOperator(bool IsConst) {
+    assert(Fields.count("h") > 0 &&
+           "Subscript operator must be added after the handle.");
+
+    FieldDecl *Handle = Fields["h"];
+    ASTContext &AST = Record->getASTContext();
+
+    assert(Handle->getType().getCanonicalType() != AST.VoidPtrTy &&
+           "Not yet supported for void pointer handles.");
+
+    QualType ElemTy =
+        QualType(Handle->getType()->getPointeeOrArrayElementType(), 0);
+    QualType ReturnTy = ElemTy;
+    FunctionProtoType::ExtProtoInfo ExtInfo;
+
+    // The const overload returns elements by value; the non-const overload
+    // returns an lvalue reference so that elements are assignable.
+    if (IsConst)
+      ExtInfo.TypeQuals.addConst();
+    else
+      ReturnTy = AST.getLValueReferenceType(ReturnTy);
+
+    QualType MethodTy =
+        AST.getFunctionType(ReturnTy, {AST.UnsignedIntTy}, ExtInfo);
+    auto *TSInfo = AST.getTrivialTypeSourceInfo(MethodTy, SourceLocation());
+    auto *MethodDecl = CXXMethodDecl::Create(
+        AST, Record, SourceLocation(),
+        DeclarationNameInfo(
+            AST.DeclarationNames.getCXXOperatorName(OO_Subscript),
+            SourceLocation()),
+        MethodTy, TSInfo, SC_None, false, false, ConstexprSpecKind::Unspecified,
+        SourceLocation());
+
+    IdentifierInfo &II = AST.Idents.get("Idx", tok::TokenKind::identifier);
+    // Parameters are owned by their method, as for Sema's implicit members.
+    auto *IdxParam = ParmVarDecl::Create(
+        AST, MethodDecl, SourceLocation(), SourceLocation(), &II,
+        AST.UnsignedIntTy,
+        AST.getTrivialTypeSourceInfo(AST.UnsignedIntTy, SourceLocation()),
+        SC_None, nullptr);
+    MethodDecl->setParams({IdxParam});
+
+    // Also record the parameter in the prototype's TypeLoc (TSInfo).
+    auto FnProtoLoc = TSInfo->getTypeLoc().getAs<FunctionProtoTypeLoc>();
+    FnProtoLoc.setParam(0, IdxParam);
+
+    CXXThisExpr *This = new (AST)
+        CXXThisExpr(SourceLocation(), MethodDecl->getThisType(), true);
+    Expr *HandleAccess = MemberExpr::CreateImplicit(
+        AST, This, true, Handle, Handle->getType(), VK_LValue, OK_Ordinary);
+    Expr *IndexExpr = DeclRefExpr::Create(
+        AST, NestedNameSpecifierLoc(), SourceLocation(), IdxParam, false,
+        DeclarationNameInfo(IdxParam->getDeclName(), SourceLocation()),
+        AST.UnsignedIntTy, VK_PRValue);
+
+    // TODO: the by-value const overload lacks an LValueToRValue cast here.
+    Expr *Array =
+        new (AST) ArraySubscriptExpr(HandleAccess, IndexExpr, ElemTy, VK_LValue,
+                                     OK_Ordinary, SourceLocation());
+    Stmt *Return = ReturnStmt::Create(AST, SourceLocation(), Array, nullptr);
+    MethodDecl->setBody(CompoundStmt::Create(AST, {Return}, FPOptionsOverride(),
+                                             SourceLocation(),
+                                             SourceLocation()));
+    MethodDecl->setLexicalDeclContext(Record);
+    MethodDecl->setAccess(AccessSpecifier::AS_public);
+    Record->addDecl(MethodDecl);
+
+    return *this;
+  }
+
   BuiltinTypeDeclBuilder &startDefinition() {
     Record->startDefinition();
     return *this;
@@ -368,6 +462,7 @@
   BuiltinTypeDeclBuilder(Record)
       .addHandleMember()
       .addDefaultHandleConstructor(*SemaPtr, ResourceClass::UAV)
+      .addArraySubscriptOperators()
       .annotateResourceClass(HLSLResourceAttr::UAV)
       .completeDefinition();
 }
_______________________________________________
cfe-commits mailing list
cfe-commits@lists.llvm.org
https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits

Reply via email to