llvmbot wrote:

<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-backend-directx

Author: Steven Perron (s-perron)

<details>
<summary>Changes</summary>

Add the Gather and GatherCmp methods for Texture2D. Variations for all
components are added (Red, Green, Blue, Alpha). When targeting Vulkan, a
GatherCmp* call for any component other than 0 (Red) results in an
error, because it would lead to invalid SPIR-V.

Part of https://github.com/llvm/llvm-project/issues/175630.

Assisted by: Gemini


---

Patch is 114.77 KiB, truncated to 20.00 KiB below, full version: 
https://github.com/llvm/llvm-project/pull/183323.diff


14 Files Affected:

- (modified) clang/include/clang/Basic/Builtins.td (+12) 
- (modified) clang/include/clang/Basic/DiagnosticSemaKinds.td (+2) 
- (modified) clang/lib/CodeGen/CGHLSLBuiltins.cpp (+48) 
- (modified) clang/lib/CodeGen/CGHLSLRuntime.h (+2) 
- (modified) clang/lib/Sema/HLSLBuiltinTypeDeclBuilder.cpp (+147-4) 
- (modified) clang/lib/Sema/HLSLBuiltinTypeDeclBuilder.h (+4) 
- (modified) clang/lib/Sema/HLSLExternalSemaSource.cpp (+98-3) 
- (modified) clang/lib/Sema/SemaHLSL.cpp (+121-21) 
- (renamed) clang/test/AST/HLSL/Texture2D-scalar-AST.hlsl (+284-1) 
- (added) clang/test/AST/HLSL/Texture2D-vector-AST.hlsl (+726) 
- (added) clang/test/CodeGenHLSL/resources/Texture2D-Gather.hlsl (+183) 
- (added) clang/test/SemaHLSL/Texture2D-Gather.hlsl (+50) 
- (added) clang/test/SemaHLSL/Texture2D-GatherCmp-Vulkan.hlsl (+23) 
- (modified) llvm/include/llvm/IR/IntrinsicsDirectX.td (+12) 


``````````diff
diff --git a/clang/include/clang/Basic/Builtins.td 
b/clang/include/clang/Basic/Builtins.td
index 531c3702161f2..fb196fab125f0 100644
--- a/clang/include/clang/Basic/Builtins.td
+++ b/clang/include/clang/Basic/Builtins.td
@@ -5072,6 +5072,18 @@ def HLSLResourceSampleCmpLevelZero : 
LangBuiltin<"HLSL_LANG"> {
   let Prototype = "void(...)";
 }
 
+def HLSLResourceGather : LangBuiltin<"HLSL_LANG"> {
+  let Spellings = ["__builtin_hlsl_resource_gather"];
+  let Attributes = [NoThrow];
+  let Prototype = "void(...)";
+}
+
+def HLSLResourceGatherCmp : LangBuiltin<"HLSL_LANG"> {
+  let Spellings = ["__builtin_hlsl_resource_gather_cmp"];
+  let Attributes = [NoThrow];
+  let Prototype = "void(...)";
+}
+
 def HLSLResourceUninitializedHandle : LangBuiltin<"HLSL_LANG"> {
   let Spellings = ["__builtin_hlsl_resource_uninitializedhandle"];
   let Attributes = [NoThrow];
diff --git a/clang/include/clang/Basic/DiagnosticSemaKinds.td 
b/clang/include/clang/Basic/DiagnosticSemaKinds.td
index 8a3b9de19ad32..f2012ea2e0529 100644
--- a/clang/include/clang/Basic/DiagnosticSemaKinds.td
+++ b/clang/include/clang/Basic/DiagnosticSemaKinds.td
@@ -13595,6 +13595,8 @@ def err_hlsl_push_constant_unique
 def err_hlsl_samplecmp_requires_float
     : Error<"'SampleCmp' and 'SampleCmpLevelZero' require resource to contain "
             "a floating point type">;
+def err_hlsl_gathercmp_invalid_component
+    : Error<"gatherCmp%select{Red|Green|Blue|Alpha}0 operations on the Vulkan 
target are not supported; only GatherCmp and GatherCmpRed are allowed">;
 
 // Layout randomization diagnostics.
 def err_non_designated_init_used : Error<
diff --git a/clang/lib/CodeGen/CGHLSLBuiltins.cpp 
b/clang/lib/CodeGen/CGHLSLBuiltins.cpp
index 70891eac39425..621c726ab3c60 100644
--- a/clang/lib/CodeGen/CGHLSLBuiltins.cpp
+++ b/clang/lib/CodeGen/CGHLSLBuiltins.cpp
@@ -662,6 +662,54 @@ Value *CodeGenFunction::EmitHLSLBuiltinExpr(unsigned 
BuiltinID,
     return Builder.CreateIntrinsic(
         RetTy, CGM.getHLSLRuntime().getSampleCmpLevelZeroIntrinsic(), Args);
   }
+  case Builtin::BI__builtin_hlsl_resource_gather: {
+    Value *HandleOp = EmitScalarExpr(E->getArg(0));
+    Value *SamplerOp = EmitScalarExpr(E->getArg(1));
+    Value *CoordOp = EmitScalarExpr(E->getArg(2));
+    Value *ComponentOp = EmitScalarExpr(E->getArg(3));
+    if (ComponentOp->getType() != Builder.getInt32Ty())
+      ComponentOp = Builder.CreateIntCast(ComponentOp, Builder.getInt32Ty(),
+                                          /*isSigned=*/false);
+
+    SmallVector<Value *, 5> Args;
+    Args.push_back(HandleOp);
+    Args.push_back(SamplerOp);
+    Args.push_back(CoordOp);
+    Args.push_back(ComponentOp);
+    Args.push_back(emitHlslOffset(*this, E, 4));
+
+    llvm::Type *RetTy = ConvertType(E->getType());
+    return Builder.CreateIntrinsic(
+        RetTy, CGM.getHLSLRuntime().getGatherIntrinsic(), Args);
+  }
+  case Builtin::BI__builtin_hlsl_resource_gather_cmp: {
+    Value *HandleOp = EmitScalarExpr(E->getArg(0));
+    Value *SamplerOp = EmitScalarExpr(E->getArg(1));
+    Value *CoordOp = EmitScalarExpr(E->getArg(2));
+    Value *CompareOp = EmitScalarExpr(E->getArg(3));
+    if (CompareOp->getType() != Builder.getFloatTy())
+      CompareOp = Builder.CreateFPCast(CompareOp, Builder.getFloatTy());
+
+    SmallVector<Value *, 6> Args;
+    Args.push_back(HandleOp);
+    Args.push_back(SamplerOp);
+    Args.push_back(CoordOp);
+    Args.push_back(CompareOp);
+
+    if (CGM.getTarget().getTriple().isDXIL()) {
+      Value *ComponentOp = EmitScalarExpr(E->getArg(4));
+      if (ComponentOp->getType() != Builder.getInt32Ty())
+        ComponentOp = Builder.CreateIntCast(ComponentOp, Builder.getInt32Ty(),
+                                            /*isSigned=*/false);
+      Args.push_back(ComponentOp);
+    }
+
+    Args.push_back(emitHlslOffset(*this, E, 5));
+
+    llvm::Type *RetTy = ConvertType(E->getType());
+    return Builder.CreateIntrinsic(
+        RetTy, CGM.getHLSLRuntime().getGatherCmpIntrinsic(), Args);
+  }
   case Builtin::BI__builtin_hlsl_resource_load_with_status:
   case Builtin::BI__builtin_hlsl_resource_load_with_status_typed: {
     Value *HandleOp = EmitScalarExpr(E->getArg(0));
diff --git a/clang/lib/CodeGen/CGHLSLRuntime.h 
b/clang/lib/CodeGen/CGHLSLRuntime.h
index dbbc887353cec..7b9bdae645540 100644
--- a/clang/lib/CodeGen/CGHLSLRuntime.h
+++ b/clang/lib/CodeGen/CGHLSLRuntime.h
@@ -174,6 +174,8 @@ class CGHLSLRuntime {
   GENERATE_HLSL_INTRINSIC_FUNCTION(SampleCmpClamp, resource_samplecmp_clamp)
   GENERATE_HLSL_INTRINSIC_FUNCTION(SampleCmpLevelZero,
                                    resource_samplecmplevelzero)
+  GENERATE_HLSL_INTRINSIC_FUNCTION(Gather, resource_gather)
+  GENERATE_HLSL_INTRINSIC_FUNCTION(GatherCmp, resource_gather_cmp)
   GENERATE_HLSL_INTRINSIC_FUNCTION(CreateHandleFromBinding,
                                    resource_handlefrombinding)
   GENERATE_HLSL_INTRINSIC_FUNCTION(CreateHandleFromImplicitBinding,
diff --git a/clang/lib/Sema/HLSLBuiltinTypeDeclBuilder.cpp 
b/clang/lib/Sema/HLSLBuiltinTypeDeclBuilder.cpp
index 4ef54cf49412f..b0aaa563c81f7 100644
--- a/clang/lib/Sema/HLSLBuiltinTypeDeclBuilder.cpp
+++ b/clang/lib/Sema/HLSLBuiltinTypeDeclBuilder.cpp
@@ -16,6 +16,7 @@
 #include "clang/AST/Attr.h"
 #include "clang/AST/Decl.h"
 #include "clang/AST/DeclCXX.h"
+#include "clang/AST/DeclTemplate.h"
 #include "clang/AST/Expr.h"
 #include "clang/AST/HLSLResource.h"
 #include "clang/AST/Stmt.h"
@@ -1304,7 +1305,7 @@ BuiltinTypeDeclBuilder &
 BuiltinTypeDeclBuilder::addSampleMethods(ResourceDimension Dim) {
   assert(!Record->isCompleteDefinition() && "record is already complete");
   ASTContext &AST = Record->getASTContext();
-  QualType ReturnType = getFirstTemplateTypeParam();
+  QualType ReturnType = getHandleElementType();
   QualType SamplerStateType =
       lookupBuiltinType(SemaRef, "SamplerState", Record->getDeclContext());
   uint32_t VecSize = getResourceDimensions(Dim);
@@ -1352,7 +1353,7 @@ BuiltinTypeDeclBuilder &
 BuiltinTypeDeclBuilder::addSampleBiasMethods(ResourceDimension Dim) {
   assert(!Record->isCompleteDefinition() && "record is already complete");
   ASTContext &AST = Record->getASTContext();
-  QualType ReturnType = getFirstTemplateTypeParam();
+  QualType ReturnType = getHandleElementType();
   QualType SamplerStateType =
       lookupBuiltinType(SemaRef, "SamplerState", Record->getDeclContext());
   uint32_t VecSize = getResourceDimensions(Dim);
@@ -1404,7 +1405,7 @@ BuiltinTypeDeclBuilder &
 BuiltinTypeDeclBuilder::addSampleGradMethods(ResourceDimension Dim) {
   assert(!Record->isCompleteDefinition() && "record is already complete");
   ASTContext &AST = Record->getASTContext();
-  QualType ReturnType = getFirstTemplateTypeParam();
+  QualType ReturnType = getHandleElementType();
   QualType SamplerStateType =
       lookupBuiltinType(SemaRef, "SamplerState", Record->getDeclContext());
   uint32_t VecSize = getResourceDimensions(Dim);
@@ -1461,7 +1462,7 @@ BuiltinTypeDeclBuilder &
 BuiltinTypeDeclBuilder::addSampleLevelMethods(ResourceDimension Dim) {
   assert(!Record->isCompleteDefinition() && "record is already complete");
   ASTContext &AST = Record->getASTContext();
-  QualType ReturnType = getFirstTemplateTypeParam();
+  QualType ReturnType = getHandleElementType();
   QualType SamplerStateType =
       lookupBuiltinType(SemaRef, "SamplerState", Record->getDeclContext());
   uint32_t VecSize = getResourceDimensions(Dim);
@@ -1588,6 +1589,133 @@ 
BuiltinTypeDeclBuilder::addSampleCmpLevelZeroMethods(ResourceDimension Dim) {
       .finalize();
 }
 
+QualType BuiltinTypeDeclBuilder::getGatherReturnType() {
+  ASTContext &AST = SemaRef.getASTContext();
+  QualType T = getHandleElementType();
+  if (T.isNull())
+    return QualType();
+
+  if (const auto *VT = T->getAs<VectorType>())
+    T = VT->getElementType();
+  else if (const auto *DT = T->getAs<DependentSizedExtVectorType>())
+    T = DT->getElementType();
+
+  return AST.getExtVectorType(T, 4);
+}
+
+BuiltinTypeDeclBuilder &
+BuiltinTypeDeclBuilder::addGatherMethods(ResourceDimension Dim) {
+  assert(!Record->isCompleteDefinition() && "record is already complete");
+  ASTContext &AST = Record->getASTContext();
+  QualType ReturnType = getGatherReturnType();
+
+  QualType SamplerStateType =
+      lookupBuiltinType(SemaRef, "SamplerState", Record->getDeclContext());
+  uint32_t VecSize = getResourceDimensions(Dim);
+  QualType LocationTy = AST.FloatTy;
+  QualType Float2Ty = AST.getExtVectorType(LocationTy, VecSize);
+  QualType IntTy = AST.IntTy;
+  QualType OffsetTy = AST.getExtVectorType(IntTy, VecSize);
+  using PH = BuiltinTypeMethodBuilder::PlaceHolder;
+
+  // Overloads for Gather, GatherRed, GatherGreen, GatherBlue, GatherAlpha
+  struct GatherVariant {
+    const char *Name;
+    int Component;
+  };
+  GatherVariant Variants[] = {{"Gather", 0},
+                              {"GatherRed", 0},
+                              {"GatherGreen", 1},
+                              {"GatherBlue", 2},
+                              {"GatherAlpha", 3}};
+
+  for (const auto &V : Variants) {
+    // ret GatherVariant(SamplerState s, float2 location)
+    BuiltinTypeMethodBuilder(*this, V.Name, ReturnType)
+        .addParam("Sampler", SamplerStateType)
+        .addParam("Location", Float2Ty)
+        .accessHandleFieldOnResource(PH::_0)
+        .callBuiltin("__builtin_hlsl_resource_gather", ReturnType, PH::Handle,
+                     PH::LastStmt, PH::_1,
+                     getConstantUnsignedIntExpr(V.Component))
+        .returnValue(PH::LastStmt)
+        .finalize();
+
+    // ret GatherVariant(SamplerState s, float2 location, int2 offset)
+    BuiltinTypeMethodBuilder(*this, V.Name, ReturnType)
+        .addParam("Sampler", SamplerStateType)
+        .addParam("Location", Float2Ty)
+        .addParam("Offset", OffsetTy)
+        .accessHandleFieldOnResource(PH::_0)
+        .callBuiltin("__builtin_hlsl_resource_gather", ReturnType, PH::Handle,
+                     PH::LastStmt, PH::_1,
+                     getConstantUnsignedIntExpr(V.Component), PH::_2)
+        .returnValue(PH::LastStmt)
+        .finalize();
+  }
+
+  return *this;
+}
+
+BuiltinTypeDeclBuilder &
+BuiltinTypeDeclBuilder::addGatherCmpMethods(ResourceDimension Dim) {
+  assert(!Record->isCompleteDefinition() && "record is already complete");
+  ASTContext &AST = Record->getASTContext();
+  QualType ReturnType = AST.getExtVectorType(AST.FloatTy, 4);
+
+  QualType SamplerComparisonStateType = lookupBuiltinType(
+      SemaRef, "SamplerComparisonState", Record->getDeclContext());
+  uint32_t VecSize = getResourceDimensions(Dim);
+  QualType FloatTy = AST.FloatTy;
+  QualType Float2Ty = AST.getExtVectorType(FloatTy, VecSize);
+  QualType IntTy = AST.IntTy;
+  QualType Int2Ty = AST.getExtVectorType(IntTy, VecSize);
+  using PH = BuiltinTypeMethodBuilder::PlaceHolder;
+
+  // Overloads for GatherCmp, GatherCmpRed, GatherCmpGreen, GatherCmpBlue,
+  // GatherCmpAlpha
+  struct GatherVariant {
+    const char *Name;
+    int Component;
+  };
+  GatherVariant Variants[] = {{"GatherCmp", 0},
+                              {"GatherCmpRed", 0},
+                              {"GatherCmpGreen", 1},
+                              {"GatherCmpBlue", 2},
+                              {"GatherCmpAlpha", 3}};
+
+  for (const auto &V : Variants) {
+    // ret GatherCmpVariant(SamplerComparisonState s, float2 location, float
+    // compare_value)
+    BuiltinTypeMethodBuilder(*this, V.Name, ReturnType)
+        .addParam("Sampler", SamplerComparisonStateType)
+        .addParam("Location", Float2Ty)
+        .addParam("CompareValue", FloatTy)
+        .accessHandleFieldOnResource(PH::_0)
+        .callBuiltin("__builtin_hlsl_resource_gather_cmp", ReturnType,
+                     PH::Handle, PH::LastStmt, PH::_1, PH::_2,
+                     getConstantUnsignedIntExpr(V.Component))
+        .returnValue(PH::LastStmt)
+        .finalize();
+
+    // ret GatherCmpVariant(SamplerComparisonState s, float2 location, float
+    // compare_value, int2 offset)
+    BuiltinTypeMethodBuilder(*this, V.Name, ReturnType)
+        .addParam("Sampler", SamplerComparisonStateType)
+        .addParam("Location", Float2Ty)
+        .addParam("CompareValue", FloatTy)
+        .addParam("Offset", Int2Ty)
+        .accessHandleFieldOnResource(PH::_0)
+        .callBuiltin("__builtin_hlsl_resource_gather_cmp", ReturnType,
+                     PH::Handle, PH::LastStmt, PH::_1, PH::_2,
+                     getConstantUnsignedIntExpr(V.Component), PH::_3)
+        .returnValue(PH::LastStmt)
+        .finalize();
+  }
+
+  return *this;
+}
+
 FieldDecl *BuiltinTypeDeclBuilder::getResourceHandleField() const {
   auto I = Fields.find("__handle");
   assert(I != Fields.end() &&
@@ -1616,6 +1744,14 @@ QualType 
BuiltinTypeDeclBuilder::getFirstTemplateTypeParam() {
 QualType BuiltinTypeDeclBuilder::getHandleElementType() {
   if (Template)
     return getFirstTemplateTypeParam();
+
+  if (auto *PartialSpec =
+          dyn_cast<ClassTemplatePartialSpecializationDecl>(Record)) {
+    const auto &Args = PartialSpec->getTemplateArgs();
+    if (Args.size() > 0 && Args[0].getKind() == TemplateArgument::Type)
+      return Args[0].getAsType();
+  }
+
   // TODO: Should we default to VoidTy? Using `i8` is arguably ambiguous.
   return SemaRef.getASTContext().Char8Ty;
 }
@@ -1642,6 +1778,13 @@ Expr *BuiltinTypeDeclBuilder::getConstantIntExpr(int 
value) {
       SourceLocation());
 }
 
+Expr *BuiltinTypeDeclBuilder::getConstantUnsignedIntExpr(unsigned value) {
+  ASTContext &AST = SemaRef.getASTContext();
+  return IntegerLiteral::Create(
+      AST, llvm::APInt(AST.getTypeSize(AST.UnsignedIntTy), value),
+      AST.UnsignedIntTy, SourceLocation());
+}
+
 BuiltinTypeDeclBuilder &
 BuiltinTypeDeclBuilder::addSimpleTemplateParams(ArrayRef<StringRef> Names,
                                                 ConceptDecl *CD = nullptr) {
diff --git a/clang/lib/Sema/HLSLBuiltinTypeDeclBuilder.h 
b/clang/lib/Sema/HLSLBuiltinTypeDeclBuilder.h
index fcb61731c5416..c27ff30c6ff73 100644
--- a/clang/lib/Sema/HLSLBuiltinTypeDeclBuilder.h
+++ b/clang/lib/Sema/HLSLBuiltinTypeDeclBuilder.h
@@ -100,6 +100,8 @@ class BuiltinTypeDeclBuilder {
   BuiltinTypeDeclBuilder &addSampleLevelMethods(ResourceDimension Dim);
   BuiltinTypeDeclBuilder &addSampleCmpMethods(ResourceDimension Dim);
   BuiltinTypeDeclBuilder &addSampleCmpLevelZeroMethods(ResourceDimension Dim);
+  BuiltinTypeDeclBuilder &addGatherMethods(ResourceDimension Dim);
+  BuiltinTypeDeclBuilder &addGatherCmpMethods(ResourceDimension Dim);
   BuiltinTypeDeclBuilder &addIncrementCounterMethod();
   BuiltinTypeDeclBuilder &addDecrementCounterMethod();
   BuiltinTypeDeclBuilder &addHandleAccessFunction(DeclarationName &Name,
@@ -132,11 +134,13 @@ class BuiltinTypeDeclBuilder {
   BuiltinTypeDeclBuilder &
   addCounterHandleMember(ResourceClass RC, bool IsROV, bool RawBuffer,
                          AccessSpecifier Access = AccessSpecifier::AS_private);
+  QualType getGatherReturnType();
   FieldDecl *getResourceHandleField() const;
   FieldDecl *getResourceCounterHandleField() const;
   QualType getFirstTemplateTypeParam();
   QualType getHandleElementType();
   Expr *getConstantIntExpr(int value);
+  Expr *getConstantUnsignedIntExpr(unsigned value);
   HLSLAttributedResourceType::Attributes getResourceAttrs() const;
 };
 
diff --git a/clang/lib/Sema/HLSLExternalSemaSource.cpp 
b/clang/lib/Sema/HLSLExternalSemaSource.cpp
index 662627901539a..788a129ec5390 100644
--- a/clang/lib/Sema/HLSLExternalSemaSource.cpp
+++ b/clang/lib/Sema/HLSLExternalSemaSource.cpp
@@ -15,12 +15,14 @@
 #include "clang/AST/Attr.h"
 #include "clang/AST/Decl.h"
 #include "clang/AST/DeclCXX.h"
+#include "clang/AST/DeclTemplate.h"
 #include "clang/AST/Expr.h"
 #include "clang/AST/Type.h"
 #include "clang/Basic/SourceLocation.h"
 #include "clang/Sema/Lookup.h"
 #include "clang/Sema/Sema.h"
 #include "clang/Sema/SemaHLSL.h"
+#include "clang/Sema/TemplateDeduction.h"
 #include "llvm/ADT/SmallVector.h"
 
 using namespace clang;
@@ -265,7 +267,72 @@ static BuiltinTypeDeclBuilder 
setupTextureType(CXXRecordDecl *Decl, Sema &S,
       .addSampleGradMethods(Dim)
       .addSampleLevelMethods(Dim)
       .addSampleCmpMethods(Dim)
-      .addSampleCmpLevelZeroMethods(Dim);
+      .addSampleCmpLevelZeroMethods(Dim)
+      .addGatherMethods(Dim)
+      .addGatherCmpMethods(Dim);
+}
+
+// Add a partial specialization for a template. The `TextureTemplate` is
+// `Texture<element_type>`, and it will be specialized for vectors:
+// `Texture<vector<element_type, element_count>>`.
+static ClassTemplatePartialSpecializationDecl *
+addVectorTexturePartialSpecialization(Sema &S, NamespaceDecl *HLSLNamespace,
+                                      ClassTemplateDecl *TextureTemplate) {
+  ASTContext &AST = S.getASTContext();
+
+  // Create the template parameters: element_type and element_count.
+  auto *ElementType = TemplateTypeParmDecl::Create(
+      AST, HLSLNamespace, SourceLocation(), SourceLocation(), 0, 0,
+      &AST.Idents.get("element_type"), false, false);
+  auto *ElementCount = NonTypeTemplateParmDecl::Create(
+      AST, HLSLNamespace, SourceLocation(), SourceLocation(), 0, 1,
+      &AST.Idents.get("element_count"), AST.IntTy, false,
+      AST.getTrivialTypeSourceInfo(AST.IntTy));
+
+  auto *TemplateParams = TemplateParameterList::Create(
+      AST, SourceLocation(), SourceLocation(), {ElementType, ElementCount},
+      SourceLocation(), nullptr);
+
+  // Create the dependent vector type: vector<element_type, element_count>.
+  QualType VectorType = AST.getDependentSizedExtVectorType(
+      AST.getTemplateTypeParmType(0, 0, false, ElementType),
+      DeclRefExpr::Create(
+          AST, NestedNameSpecifierLoc(), SourceLocation(), ElementCount, false,
+          DeclarationNameInfo(ElementCount->getDeclName(), SourceLocation()),
+          AST.IntTy, VK_LValue),
+      SourceLocation());
+
+  // Create the partial specialization declaration.
+  QualType CanonInjectedTST =
+      AST.getCanonicalType(AST.getTemplateSpecializationType(
+          ElaboratedTypeKeyword::Class, TemplateName(TextureTemplate),
+          {TemplateArgument(VectorType)}, {}));
+
+  auto *PartialSpec = ClassTemplatePartialSpecializationDecl::Create(
+      AST, TagDecl::TagKind::Class, HLSLNamespace, SourceLocation(),
+      SourceLocation(), TemplateParams, TextureTemplate,
+      {TemplateArgument(VectorType)},
+      CanQualType::CreateUnsafe(CanonInjectedTST), nullptr);
+
+  // Set the template arguments as written.
+  TemplateArgument Arg(VectorType);
+  TemplateArgumentLoc ArgLoc =
+      S.getTrivialTemplateArgumentLoc(Arg, QualType(), SourceLocation());
+  TemplateArgumentListInfo ArgsInfo =
+      TemplateArgumentListInfo(SourceLocation(), SourceLocation());
+  ArgsInfo.addArgument(ArgLoc);
+  PartialSpec->setTemplateArgsAsWritten(
+      ASTTemplateArgumentListInfo::Create(AST, ArgsInfo));
+
+  PartialSpec->setImplicit(true);
+  PartialSpec->setLexicalDeclContext(HLSLNamespace);
+  PartialSpec->setHasExternalLexicalStorage();
+
+  // Add the partial specialization to the namespace and the class template.
+  HLSLNamespace->addDecl(PartialSpec);
+  TextureTemplate->AddPartialSpecialization(PartialSpec, nullptr);
+
+  return PartialSpec;
 }
 
 // This function is responsible for constructing the constraint expression for
@@ -548,11 +615,20 @@ void 
HLSLExternalSemaSource::defineHLSLTypesWithForwardDeclarations() {
   Decl = BuiltinTypeDeclBuilder(*SemaPtr, HLSLNamespace, "Texture2D")
              .addSimpleTemplateParams({"element_type"}, TypedBufferConcept)
              .finalizeForwardDeclaration();
+
   onCompletion(Decl, [this](CXXRecordDecl *Decl) {
     setupTextureType(Decl, *SemaPtr, ResourceClass::SRV, /*IsROV=*/false,
                      ResourceDimension::Dim2D)
         .completeDefinition();
   });
+
+  auto *PartialSpec = addVectorTexturePartialSpecialization(
+      *SemaPtr, HLSLNamespace, Decl->getDescribedClassTemplate());
+  onCompletion(PartialSpec, [this](CXXRecordDecl *Decl) {
+    setupTextureType(Decl, *SemaPtr, ResourceClass::SRV, /*IsROV=*/false,
+                     ResourceDimension::Dim2D)
+        .completeDefinition();
+  });
 }
 
 void HLSLExternalSemaSource::onCompletion(CXXRecordDecl *Record,
@@ -568,8 +644,27 @@ void HLSLExternalSemaSource::CompleteType(TagDecl *Tag) {
 
   // ...
[truncated]

``````````

</details>


https://github.com/llvm/llvm-project/pull/183323
_______________________________________________
cfe-commits mailing list
[email protected]
https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits

Reply via email to