llvmbot wrote:

<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-backend-directx

Author: Justin Bogner (bogner)

<details>
<summary>Changes</summary>

This adjusts the DXILOpBuilder API in three ways:
1. Remove the need to call `getOverloadTy` before creating Ops
2. Introduce `tryCreateOp` to parallel `createOp` but propagate errors
3. Introduce specialized createOp methods for each DXIL Op

This will simplify usage of the builder in upcoming changes, and also allows us
to propagate errors via DiagnosticInfo rather than using fatal errors.


---

Patch is 43.67 KiB, truncated to 20.00 KiB below, full version: 
https://github.com/llvm/llvm-project/pull/101250.diff


31 Files Affected:

- (modified) llvm/lib/Target/DirectX/DXIL.td (+29-29) 
- (modified) llvm/lib/Target/DirectX/DXILOpBuilder.cpp (+62-82) 
- (modified) llvm/lib/Target/DirectX/DXILOpBuilder.h (+27-14) 
- (modified) llvm/lib/Target/DirectX/DXILOpLowering.cpp (+12-7) 
- (modified) llvm/test/CodeGen/DirectX/acos_error.ll (+2-1) 
- (modified) llvm/test/CodeGen/DirectX/asin_error.ll (+2-1) 
- (modified) llvm/test/CodeGen/DirectX/atan_error.ll (+2-1) 
- (modified) llvm/test/CodeGen/DirectX/ceil_error.ll (+2-1) 
- (modified) llvm/test/CodeGen/DirectX/cos_error.ll (+2-1) 
- (modified) llvm/test/CodeGen/DirectX/cosh_error.ll (+2-1) 
- (modified) llvm/test/CodeGen/DirectX/dot2_error.ll (+2-1) 
- (modified) llvm/test/CodeGen/DirectX/dot3_error.ll (+2-1) 
- (modified) llvm/test/CodeGen/DirectX/dot4_error.ll (+2-1) 
- (modified) llvm/test/CodeGen/DirectX/exp2_error.ll (+2-1) 
- (modified) llvm/test/CodeGen/DirectX/flattened_thread_id_in_group_error.ll (+2-1) 
- (modified) llvm/test/CodeGen/DirectX/floor_error.ll (+2-1) 
- (modified) llvm/test/CodeGen/DirectX/frac_error.ll (+2-1) 
- (modified) llvm/test/CodeGen/DirectX/group_id_error.ll (+2-1) 
- (modified) llvm/test/CodeGen/DirectX/isinf_error.ll (+2-1) 
- (modified) llvm/test/CodeGen/DirectX/log2_error.ll (+2-1) 
- (modified) llvm/test/CodeGen/DirectX/round_error.ll (+2-1) 
- (modified) llvm/test/CodeGen/DirectX/rsqrt_error.ll (+2-1) 
- (modified) llvm/test/CodeGen/DirectX/sin_error.ll (+2-2) 
- (modified) llvm/test/CodeGen/DirectX/sinh_error.ll (+2-1) 
- (modified) llvm/test/CodeGen/DirectX/sqrt_error.ll (+2-1) 
- (modified) llvm/test/CodeGen/DirectX/tan_error.ll (+2-1) 
- (modified) llvm/test/CodeGen/DirectX/tanh_error.ll (+2-1) 
- (modified) llvm/test/CodeGen/DirectX/thread_id_error.ll (+2-1) 
- (modified) llvm/test/CodeGen/DirectX/thread_id_in_group_error.ll (+2-1) 
- (modified) llvm/test/CodeGen/DirectX/trunc_error.ll (+2-1) 
- (modified) llvm/utils/TableGen/DXILEmitter.cpp (+7-14) 


``````````diff
diff --git a/llvm/lib/Target/DirectX/DXIL.td b/llvm/lib/Target/DirectX/DXIL.td
index a66f5b6470934..67015cff78a79 100644
--- a/llvm/lib/Target/DirectX/DXIL.td
+++ b/llvm/lib/Target/DirectX/DXIL.td
@@ -318,7 +318,7 @@ class DXILOp<int opcode, DXILOpClass opclass> {
 def Abs :  DXILOp<6, unary> {
   let Doc = "Returns the absolute value of the input.";
   let LLVMIntrinsic = int_fabs;
-  let arguments = [LLVMMatchType<0>];
+  let arguments = [overloadTy];
   let result = overloadTy;
   let overloads = [Overloads<DXIL1_0, [halfTy, floatTy, doubleTy]>];
   let stages = [Stages<DXIL1_0, [all_stages]>];
@@ -338,7 +338,7 @@ def IsInf :  DXILOp<9, isSpecialFloat> {
 def Cos :  DXILOp<12, unary> {
   let Doc = "Returns cosine(theta) for theta in radians.";
   let LLVMIntrinsic = int_cos;
-  let arguments = [LLVMMatchType<0>];
+  let arguments = [overloadTy];
   let result = overloadTy;
   let overloads = [Overloads<DXIL1_0, [halfTy, floatTy]>];
   let stages = [Stages<DXIL1_0, [all_stages]>];
@@ -348,7 +348,7 @@ def Cos :  DXILOp<12, unary> {
 def Sin :  DXILOp<13, unary> {
   let Doc = "Returns sine(theta) for theta in radians.";
   let LLVMIntrinsic = int_sin;
-  let arguments = [LLVMMatchType<0>];
+  let arguments = [overloadTy];
   let result = overloadTy;
   let overloads = [Overloads<DXIL1_0, [halfTy, floatTy]>];
   let stages = [Stages<DXIL1_0, [all_stages]>];
@@ -358,7 +358,7 @@ def Sin :  DXILOp<13, unary> {
 def Tan :  DXILOp<14, unary> {
   let Doc = "Returns tangent(theta) for theta in radians.";
   let LLVMIntrinsic = int_tan;
-  let arguments = [LLVMMatchType<0>];
+  let arguments = [overloadTy];
   let result = overloadTy;
   let overloads = [Overloads<DXIL1_0, [halfTy, floatTy]>];
   let stages = [Stages<DXIL1_0, [all_stages]>];
@@ -368,7 +368,7 @@ def Tan :  DXILOp<14, unary> {
 def ACos :  DXILOp<15, unary> {
   let Doc = "Returns the arccosine of the specified value.";
   let LLVMIntrinsic = int_acos;
-  let arguments = [LLVMMatchType<0>];
+  let arguments = [overloadTy];
   let result = overloadTy;
   let overloads = [Overloads<DXIL1_0, [halfTy, floatTy]>];
   let stages = [Stages<DXIL1_0, [all_stages]>];
@@ -378,7 +378,7 @@ def ACos :  DXILOp<15, unary> {
 def ASin :  DXILOp<16, unary> {
   let Doc = "Returns the arcsine of the specified value.";
   let LLVMIntrinsic = int_asin;
-  let arguments = [LLVMMatchType<0>];
+  let arguments = [overloadTy];
   let result = overloadTy;
   let overloads = [Overloads<DXIL1_0, [halfTy, floatTy]>];
   let stages = [Stages<DXIL1_0, [all_stages]>];
@@ -388,7 +388,7 @@ def ASin :  DXILOp<16, unary> {
 def ATan :  DXILOp<17, unary> {
   let Doc = "Returns the arctangent of the specified value.";
   let LLVMIntrinsic = int_atan;
-  let arguments = [LLVMMatchType<0>];
+  let arguments = [overloadTy];
   let result = overloadTy;
   let overloads = [Overloads<DXIL1_0, [halfTy, floatTy]>];
   let stages = [Stages<DXIL1_0, [all_stages]>];
@@ -398,7 +398,7 @@ def ATan :  DXILOp<17, unary> {
 def HCos :  DXILOp<18, unary> {
   let Doc = "Returns the hyperbolic cosine of the specified value.";
   let LLVMIntrinsic = int_cosh;
-  let arguments = [LLVMMatchType<0>];
+  let arguments = [overloadTy];
   let result = overloadTy;
   let overloads = [Overloads<DXIL1_0, [halfTy, floatTy]>];
   let stages = [Stages<DXIL1_0, [all_stages]>];
@@ -408,7 +408,7 @@ def HCos :  DXILOp<18, unary> {
 def HSin :  DXILOp<19, unary> {
   let Doc = "Returns the hyperbolic sine of the specified value.";
   let LLVMIntrinsic = int_sinh;
-  let arguments = [LLVMMatchType<0>];
+  let arguments = [overloadTy];
   let result = overloadTy;
   let overloads = [Overloads<DXIL1_0, [halfTy, floatTy]>];
   let stages = [Stages<DXIL1_0, [all_stages]>];
@@ -418,7 +418,7 @@ def HSin :  DXILOp<19, unary> {
 def HTan :  DXILOp<20, unary> {
   let Doc = "Returns the hyperbolic tan of the specified value.";
   let LLVMIntrinsic = int_tanh;
-  let arguments = [LLVMMatchType<0>];
+  let arguments = [overloadTy];
   let result = overloadTy;
   let overloads = [Overloads<DXIL1_0, [halfTy, floatTy]>];
   let stages = [Stages<DXIL1_0, [all_stages]>];
@@ -429,7 +429,7 @@ def Exp2 :  DXILOp<21, unary> {
   let Doc = "Returns the base 2 exponential, or 2**x, of the specified value. "
             "exp2(x) = 2**x.";
   let LLVMIntrinsic = int_exp2;
-  let arguments = [LLVMMatchType<0>];
+  let arguments = [overloadTy];
   let result = overloadTy;
   let overloads = [Overloads<DXIL1_0, [halfTy, floatTy]>];
   let stages = [Stages<DXIL1_0, [all_stages]>];
@@ -440,7 +440,7 @@ def Frac :  DXILOp<22, unary> {
   let Doc = "Returns a fraction from 0 to 1 that represents the decimal part "
             "of the input.";
   let LLVMIntrinsic = int_dx_frac;
-  let arguments = [LLVMMatchType<0>];
+  let arguments = [overloadTy];
   let result = overloadTy;
   let overloads = [Overloads<DXIL1_0, [halfTy, floatTy]>];
   let stages = [Stages<DXIL1_0, [all_stages]>];
@@ -450,7 +450,7 @@ def Frac :  DXILOp<22, unary> {
 def Log2 :  DXILOp<23, unary> {
   let Doc = "Returns the base-2 logarithm of the specified value.";
   let LLVMIntrinsic = int_log2;
-  let arguments = [LLVMMatchType<0>];
+  let arguments = [overloadTy];
   let result = overloadTy;
   let overloads = [Overloads<DXIL1_0, [halfTy, floatTy]>];
   let stages = [Stages<DXIL1_0, [all_stages]>];
@@ -461,7 +461,7 @@ def Sqrt :  DXILOp<24, unary> {
   let Doc = "Returns the square root of the specified floating-point value, "
             "per component.";
   let LLVMIntrinsic = int_sqrt;
-  let arguments = [LLVMMatchType<0>];
+  let arguments = [overloadTy];
   let result = overloadTy;
   let overloads = [Overloads<DXIL1_0, [halfTy, floatTy]>];
   let stages = [Stages<DXIL1_0, [all_stages]>];
@@ -472,7 +472,7 @@ def RSqrt :  DXILOp<25, unary> {
   let Doc = "Returns the reciprocal of the square root of the specified value. "
             "rsqrt(x) = 1 / sqrt(x).";
   let LLVMIntrinsic = int_dx_rsqrt;
-  let arguments = [LLVMMatchType<0>];
+  let arguments = [overloadTy];
   let result = overloadTy;
   let overloads = [Overloads<DXIL1_0, [halfTy, floatTy]>];
   let stages = [Stages<DXIL1_0, [all_stages]>];
@@ -483,7 +483,7 @@ def Round :  DXILOp<26, unary> {
   let Doc = "Returns the input rounded to the nearest integer within a "
             "floating-point type.";
   let LLVMIntrinsic = int_roundeven;
-  let arguments = [LLVMMatchType<0>];
+  let arguments = [overloadTy];
   let result = overloadTy;
   let overloads = [Overloads<DXIL1_0, [halfTy, floatTy]>];
   let stages = [Stages<DXIL1_0, [all_stages]>];
@@ -494,7 +494,7 @@ def Floor :  DXILOp<27, unary> {
   let Doc =
       "Returns the largest integer that is less than or equal to the input.";
   let LLVMIntrinsic = int_floor;
-  let arguments = [LLVMMatchType<0>];
+  let arguments = [overloadTy];
   let result = overloadTy;
   let overloads = [Overloads<DXIL1_0, [halfTy, floatTy]>];
   let stages = [Stages<DXIL1_0, [all_stages]>];
@@ -505,7 +505,7 @@ def Ceil :  DXILOp<28, unary> {
   let Doc = "Returns the smallest integer that is greater than or equal to the "
             "input.";
   let LLVMIntrinsic = int_ceil;
-  let arguments = [LLVMMatchType<0>];
+  let arguments = [overloadTy];
   let result = overloadTy;
   let overloads = [Overloads<DXIL1_0, [halfTy, floatTy]>];
   let stages = [Stages<DXIL1_0, [all_stages]>];
@@ -515,7 +515,7 @@ def Ceil :  DXILOp<28, unary> {
 def Trunc :  DXILOp<29, unary> {
   let Doc = "Returns the specified value truncated to the integer component.";
   let LLVMIntrinsic = int_trunc;
-  let arguments = [LLVMMatchType<0>];
+  let arguments = [overloadTy];
   let result = overloadTy;
   let overloads = [Overloads<DXIL1_0, [halfTy, floatTy]>];
   let stages = [Stages<DXIL1_0, [all_stages]>];
@@ -525,7 +525,7 @@ def Trunc :  DXILOp<29, unary> {
 def Rbits :  DXILOp<30, unary> {
   let Doc = "Returns the specified value with its bits reversed.";
   let LLVMIntrinsic = int_bitreverse;
-  let arguments = [LLVMMatchType<0>];
+  let arguments = [overloadTy];
   let result = overloadTy;
   let overloads =
       [Overloads<DXIL1_0, [i16Ty, i32Ty, i64Ty]>];
@@ -536,7 +536,7 @@ def Rbits :  DXILOp<30, unary> {
 def FMax :  DXILOp<35, binary> {
   let Doc = "Float maximum. FMax(a,b) = a > b ? a : b";
   let LLVMIntrinsic = int_maxnum;
-  let arguments = [LLVMMatchType<0>, LLVMMatchType<0>];
+  let arguments = [overloadTy, overloadTy];
   let result = overloadTy;
   let overloads =
       [Overloads<DXIL1_0, [halfTy, floatTy, doubleTy]>];
@@ -547,7 +547,7 @@ def FMax :  DXILOp<35, binary> {
 def FMin :  DXILOp<36, binary> {
   let Doc = "Float minimum. FMin(a,b) = a < b ? a : b";
   let LLVMIntrinsic = int_minnum;
-  let arguments = [LLVMMatchType<0>, LLVMMatchType<0>];
+  let arguments = [overloadTy, overloadTy];
   let result = overloadTy;
   let overloads =
       [Overloads<DXIL1_0, [halfTy, floatTy, doubleTy]>];
@@ -558,7 +558,7 @@ def FMin :  DXILOp<36, binary> {
 def SMax :  DXILOp<37, binary> {
   let Doc = "Signed integer maximum. SMax(a,b) = a > b ? a : b";
   let LLVMIntrinsic = int_smax;
-  let arguments = [LLVMMatchType<0>, LLVMMatchType<0>];
+  let arguments = [overloadTy, overloadTy];
   let result = overloadTy;
   let overloads =
       [Overloads<DXIL1_0, [i16Ty, i32Ty, i64Ty]>];
@@ -569,7 +569,7 @@ def SMax :  DXILOp<37, binary> {
 def SMin :  DXILOp<38, binary> {
   let Doc = "Signed integer minimum. SMin(a,b) = a < b ? a : b";
   let LLVMIntrinsic = int_smin;
-  let arguments = [LLVMMatchType<0>, LLVMMatchType<0>];
+  let arguments = [overloadTy, overloadTy];
   let result = overloadTy;
   let overloads =
       [Overloads<DXIL1_0, [i16Ty, i32Ty, i64Ty]>];
@@ -580,7 +580,7 @@ def SMin :  DXILOp<38, binary> {
 def UMax :  DXILOp<39, binary> {
   let Doc = "Unsigned integer maximum. UMax(a,b) = a > b ? a : b";
   let LLVMIntrinsic = int_umax;
-  let arguments = [LLVMMatchType<0>, LLVMMatchType<0>];
+  let arguments = [overloadTy, overloadTy];
   let result = overloadTy;
   let overloads =
       [Overloads<DXIL1_0, [i16Ty, i32Ty, i64Ty]>];
@@ -591,7 +591,7 @@ def UMax :  DXILOp<39, binary> {
 def UMin :  DXILOp<40, binary> {
   let Doc = "Unsigned integer minimum. UMin(a,b) = a < b ? a : b";
   let LLVMIntrinsic = int_umin;
-  let arguments = [LLVMMatchType<0>, LLVMMatchType<0>];
+  let arguments = [overloadTy, overloadTy];
   let result = overloadTy;
   let overloads =
       [Overloads<DXIL1_0, [i16Ty, i32Ty, i64Ty]>];
@@ -603,7 +603,7 @@ def FMad :  DXILOp<46, tertiary> {
   let Doc = "Floating point arithmetic multiply/add operation. fmad(m,a,b) = m "
             "* a + b.";
   let LLVMIntrinsic = int_fmuladd;
-  let arguments = [LLVMMatchType<0>, LLVMMatchType<0>, LLVMMatchType<0>];
+  let arguments = [overloadTy, overloadTy, overloadTy];
   let result = overloadTy;
   let overloads =
       [Overloads<DXIL1_0, [halfTy, floatTy, doubleTy]>];
@@ -615,7 +615,7 @@ def IMad :  DXILOp<48, tertiary> {
   let Doc = "Signed integer arithmetic multiply/add operation. imad(m,a,b) = m "
             "* a + b.";
   let LLVMIntrinsic = int_dx_imad;
-  let arguments = [LLVMMatchType<0>, LLVMMatchType<0>, LLVMMatchType<0>];
+  let arguments = [overloadTy, overloadTy, overloadTy];
   let result = overloadTy;
   let overloads =
       [Overloads<DXIL1_0, [i16Ty, i32Ty, i64Ty]>];
@@ -627,7 +627,7 @@ def UMad :  DXILOp<49, tertiary> {
   let Doc = "Unsigned integer arithmetic multiply/add operation. umad(m,a, = m "
             "* a + b.";
   let LLVMIntrinsic = int_dx_umad;
-  let arguments = [LLVMMatchType<0>, LLVMMatchType<0>, LLVMMatchType<0>];
+  let arguments = [overloadTy, overloadTy, overloadTy];
   let result = overloadTy;
   let overloads =
       [Overloads<DXIL1_0, [i16Ty, i32Ty, i64Ty]>];
diff --git a/llvm/lib/Target/DirectX/DXILOpBuilder.cpp b/llvm/lib/Target/DirectX/DXILOpBuilder.cpp
index a03701be743c7..d43ac1119ff48 100644
--- a/llvm/lib/Target/DirectX/DXILOpBuilder.cpp
+++ b/llvm/lib/Target/DirectX/DXILOpBuilder.cpp
@@ -208,8 +208,8 @@ static StructType *getHandleType(LLVMContext &Ctx) {
                                Ctx);
 }
 
-static Type *getTypeFromParameterKind(ParameterKind Kind, Type *OverloadTy) {
-  auto &Ctx = OverloadTy->getContext();
+static Type *getTypeFromParameterKind(ParameterKind Kind, LLVMContext &Ctx,
+                                      Type *OverloadTy) {
   switch (Kind) {
   case ParameterKind::Void:
     return Type::getVoidTy(Ctx);
@@ -289,24 +289,25 @@ static ShaderKind getShaderKindEnum(Triple::EnvironmentType EnvType) {
 ///               its specification in DXIL.td.
 /// \param OverloadTy Return type to be used to construct DXIL function type.
 static FunctionType *getDXILOpFunctionType(const OpCodeProperty *Prop,
-                                           Type *ReturnTy, Type *OverloadTy) {
+                                           LLVMContext &Context,
+                                           Type *OverloadTy) {
   SmallVector<Type *> ArgTys;
 
   const ParameterKind *ParamKinds = getOpCodeParameterKind(*Prop);
 
-  // Add ReturnTy as return type of the function
-  ArgTys.emplace_back(ReturnTy);
+  assert(Prop->NumOfParameters && "No return type?");
+  // Add return type of the function
+  Type *ReturnTy = getTypeFromParameterKind(ParamKinds[0], Context, OverloadTy);
 
   // Add DXIL Opcode value type viz., Int32 as first argument
-  ArgTys.emplace_back(Type::getInt32Ty(OverloadTy->getContext()));
+  ArgTys.emplace_back(Type::getInt32Ty(Context));
 
   // Add DXIL Operation parameter types as specified in DXIL properties
-  for (unsigned I = 0; I < Prop->NumOfParameters; ++I) {
+  for (unsigned I = 1; I < Prop->NumOfParameters; ++I) {
     ParameterKind Kind = ParamKinds[I];
-    ArgTys.emplace_back(getTypeFromParameterKind(Kind, OverloadTy));
+    ArgTys.emplace_back(getTypeFromParameterKind(Kind, Context, OverloadTy));
   }
-  return FunctionType::get(
-      ArgTys[0], ArrayRef<Type *>(&ArgTys[1], ArgTys.size() - 1), false);
+  return FunctionType::get(ReturnTy, ArgTys, /*isVarArg=*/false);
 }
 
 /// Get index of the property from PropList valid for the most recent
@@ -347,107 +348,86 @@ DXILOpBuilder::DXILOpBuilder(Module &M, IRBuilderBase &B) : M(M), B(B) {
   }
 }
 
-CallInst *DXILOpBuilder::createDXILOpCall(dxil::OpCode OpCode, Type *ReturnTy,
-                                          Type *OverloadTy,
-                                          SmallVector<Value *> Args) {
+static Error makeOpError(dxil::OpCode OpCode, Twine Msg) {
+  return make_error<StringError>(
+      Twine("Cannot create ") + getOpCodeName(OpCode) + " operation: " + Msg,
+      inconvertibleErrorCode());
+}
 
+Expected<CallInst *> DXILOpBuilder::tryCreateOp(dxil::OpCode OpCode,
+                                                ArrayRef<Value *> Args) {
   const OpCodeProperty *Prop = getOpCodeProperty(OpCode);
+
+  Type *OverloadTy = nullptr;
+  if (Prop->OverloadParamIndex > 0) {
+    // The index counts including the return type
+    unsigned ArgIndex = Prop->OverloadParamIndex - 1;
+    if (static_cast<unsigned>(ArgIndex) >= Args.size())
+      return makeOpError(OpCode, "Wrong number of arguments");
+    OverloadTy = Args[ArgIndex]->getType();
+  }
+  FunctionType *DXILOpFT =
+      getDXILOpFunctionType(Prop, M.getContext(), OverloadTy);
+
   std::optional<size_t> OlIndexOrErr =
       getPropIndex(ArrayRef(Prop->Overloads), DXILVersion);
-  if (!OlIndexOrErr.has_value()) {
-    report_fatal_error(Twine(getOpCodeName(OpCode)) +
-                           ": No valid overloads found for DXIL Version - " +
-                           DXILVersion.getAsString(),
-                       /*gen_crash_diag*/ false);
-  }
+  if (!OlIndexOrErr.has_value())
+    return makeOpError(OpCode, Twine("No valid overloads for DXIL version ") +
+                                   DXILVersion.getAsString());
+
   uint16_t ValidTyMask = Prop->Overloads[*OlIndexOrErr].ValidTys;
 
-  OverloadKind Kind = getOverloadKind(OverloadTy);
+  // If we don't have an overload type, use the function's return type. This is
+  // a bit of a hack, but it's necessary to get the type suffix on unoverloaded
+  // DXIL ops correct, like `dx.op.threadId.i32`.
+  OverloadKind Kind =
+      getOverloadKind(OverloadTy ? OverloadTy : DXILOpFT->getReturnType());
 
   // Check if the operation supports overload types and OverloadTy is valid
   // per the specified types for the operation
   if ((ValidTyMask != OverloadKind::UNDEFINED) &&
-      (ValidTyMask & (uint16_t)Kind) == 0) {
-    report_fatal_error(Twine("Invalid Overload Type for DXIL operation - ") +
-                           getOpCodeName(OpCode),
-                       /* gen_crash_diag=*/false);
-  }
+      (ValidTyMask & (uint16_t)Kind) == 0)
+    return makeOpError(OpCode, "Invalid overload type");
 
   // Perform necessary checks to ensure Opcode is valid in the targeted shader
   // kind
   std::optional<size_t> StIndexOrErr =
       getPropIndex(ArrayRef(Prop->Stages), DXILVersion);
-  if (!StIndexOrErr.has_value()) {
-    report_fatal_error(Twine(getOpCodeName(OpCode)) +
-                           ": No valid stages found for DXIL Version - " +
-                           DXILVersion.getAsString(),
-                       /*gen_crash_diag*/ false);
-  }
+  if (!StIndexOrErr.has_value())
+    return makeOpError(OpCode, Twine("No valid stage for DXIL version ") +
+                                   DXILVersion.getAsString());
+
   uint16_t ValidShaderKindMask = Prop->Stages[*StIndexOrErr].ValidStages;
 
   // Ensure valid shader stage properties are specified
-  if (ValidShaderKindMask == ShaderKind::removed) {
-    report_fatal_error(
-        Twine(DXILVersion.getAsString()) +
-            ": Unsupported Target Shader Stage for DXIL operation - " +
-            getOpCodeName(OpCode),
-        /*gen_crash_diag*/ false);
-  }
+  if (ValidShaderKindMask == ShaderKind::removed)
+    return makeOpError(OpCode, "Operation has been removed");
 
   // Shader stage need not be validated since getShaderKindEnum() fails
   // for unknown shader stage.
 
   // Verify the target shader stage is valid for the DXIL operation
   ShaderKind ModuleStagekind = getShaderKindEnum(ShaderStage);
-  if (!(ValidShaderKindMask & ModuleStagekind)) {
-    auto ShaderEnvStr = Triple::getEnvironmentTypeName(ShaderStage);
-    report_fatal_error(Twine(ShaderEnvStr) +
-                           " : Invalid Shader Stage for DXIL operation - " +
-                           getOpCodeName(OpCode) + " for DXIL Version " +
-                           DXILVersion.getAsString(),
-                       /*gen_crash_diag*/ false);
-  }
+  if (!(ValidShaderKindMask & ModuleStagekind))
+    return makeOpError(OpCode, "Invalid stage");
 
   std::string DXILFnName = constructOverloadName(Kind, OverloadTy, *Prop);
-  FunctionCallee DXILFn;
-  // Get the function with name DXILFnName, if one exists
-  if (auto *Func = M.getFunction(DXILFnName)) {
-    DXILFn = FunctionCallee(Func);
-  } else {
-    // Construct and add a function with name DXILFnName
-    FunctionType *DXILOpFT = getDXILOpFunctionType(Prop, ReturnTy, OverloadTy);
-    DXILFn = M.getOrInsertFunction(DXILFnName, DXILOpFT);
-  }
+  FunctionCallee DXILFn = M.getOrInsertFunction(DXILFnName, DXILOpFT);
 
-  return B.CreateCall(DXILFn, Args);
-}
-
-Type *DXILOpBuilder::getOverloadTy(dxil::OpCode OpCode, FunctionType *FT) {
+  // We need to inject the opcode as the first argument.
+  SmallVector<Value *> OpArgs;
+  OpArgs.push_back(B.getInt32(llvm::to_underlying(OpCode)));
+  OpArgs.append(Args.begin(), Args.end());
 
-  const OpCodeProperty *Prop = getOpCodeProperty(OpCode);
-  // If DXIL Op has no overload parameter, just return the
-  // precise return type specified.
-  if (Prop->OverloadParamIndex < 0) {
-    return FT->getReturnType();
-  }
-
-  // Consider FT->getReturnType() as default overload type, unless
-  // Prop->OverloadParamIndex != 0.
-  Type *OverloadType = FT->getReturnType();
-  if (Prop->OverloadParamIndex != 0) {
-    // Skip Return Type.
-    OverloadType = FT->getParamType(Prop->OverloadParamIndex - 1);
-  }
+  return B.CreateCall(DXILFn, OpArgs);
+}
 
-  const ParameterKind *ParamKinds = getOpCodeParameterKind(*Prop);
-  auto Kind = ParamKinds[Prop->OverloadParamIndex];
-  // For ResRet and CBufferRet, OverloadTy is in field of StructType.
-  if (Kind == ParameterKind::CBufferRet ||
-      Kind == ParameterKind::ResourceRet) {
-    auto *ST = cast<StructType>(Ove...
[truncated]

``````````

</details>


https://github.com/llvm/llvm-project/pull/101250
_______________________________________________
llvm-branch-commits mailing list
llvm-branch-commits@lists.llvm.org
https://lists.llvm.org/cgi-bin/mailman/listinfo/llvm-branch-commits

Reply via email to