================
@@ -2817,6 +2833,98 @@ bool SPIRVInstructionSelector::selectWaveActiveCountBits(
return true;
}
+unsigned getVectorSizeOrOne(SPIRVTypeInst Type) {
+ // Element count of Type: any type that is not an OpTypeVector counts as 1.
+ if (Type->getOpcode() != SPIRV::OpTypeVector)
+ return 1;
+
+ // Operand(2) of the OpTypeVector instruction is the vector size, read here
+ // as an immediate (e.g. the 4 in `OpTypeVector %elem, 4`).
+ return Type->getOperand(2).getImm();
+}
+
+bool SPIRVInstructionSelector::selectWaveActiveAllEqual(Register ResVReg,
+ SPIRVTypeInst ResType,
+ MachineInstr &I) const
{
+ MachineBasicBlock &BB = *I.getParent();
+ const DebugLoc &DL = I.getDebugLoc();
+
+ // Input to the intrinsic
+ Register InputReg = I.getOperand(2).getReg();
+ SPIRVTypeInst InputType = GR.getSPIRVTypeForVReg(InputReg);
+
+ // Determine if input is vector
+ unsigned NumElems = getVectorSizeOrOne(InputType);
+ bool IsVector = NumElems > 1;
+
+ // Determine element types
+ SPIRVTypeInst ElemInputType = InputType;
+ SPIRVTypeInst ElemBoolType = ResType;
+ if (IsVector) {
+ ElemInputType = GR.getSPIRVTypeForVReg(InputType->getOperand(1).getReg());
+ ElemBoolType = GR.getSPIRVTypeForVReg(ResType->getOperand(1).getReg());
+ }
+
+ // Subgroup scope constant
+ SPIRVTypeInst IntTy = GR.getOrCreateSPIRVIntegerType(32, I, TII);
+ Register ScopeConst = GR.getOrCreateConstInt(SPIRV::Scope::Subgroup, I,
IntTy,
+ TII, !STI.isShader());
+
+ // === Scalar case ===
+ if (!IsVector) {
+ BuildMI(BB, I, DL, TII.get(SPIRV::OpGroupNonUniformAllEqual))
+ .addDef(ResVReg)
+ .addUse(GR.getSPIRVTypeID(ElemBoolType))
+ .addUse(ScopeConst)
+ .addUse(InputReg)
+ .constrainAllUses(TII, TRI, RBI);
+ return true;
+ }
+
+ // === Vector case ===
+ SmallVector<Register, 4> ElementResults;
+ ElementResults.reserve(NumElems);
+
+ for (unsigned Idx = 0; Idx < NumElems; ++Idx) {
+ // Extract element
+ Register ElemInput = InputReg;
+ Register Extracted =
+ MRI->createVirtualRegister(GR.getRegClass(ElemInputType));
+
+ BuildMI(BB, I, DL, TII.get(SPIRV::OpCompositeExtract))
+ .addDef(Extracted)
+ .addUse(GR.getSPIRVTypeID(ElemInputType))
+ .addUse(InputReg)
+ .addImm(Idx)
+ .constrainAllUses(TII, TRI, RBI);
+
+ ElemInput = Extracted;
+
+ // Emit per-element AllEqual
+ Register ElemResult =
+ MRI->createVirtualRegister(GR.getRegClass(ElemBoolType));
+
+ BuildMI(BB, I, DL, TII.get(SPIRV::OpGroupNonUniformAllEqual))
+ .addDef(ElemResult)
+ .addUse(GR.getSPIRVTypeID(ElemBoolType))
+ .addUse(ScopeConst)
+ .addUse(ElemInput)
+ .constrainAllUses(TII, TRI, RBI);
----------------
bob80905 wrote:
Yeah, this one won't work. The register number trips things up.
```
# .---command stderr------------
# |
# | # After InstructionSelect
# | # Machine code for function test_vhalf: IsSSA, TracksLiveness, Legalized,
Selected
# |
# | bb.1.entry:
# | %2:type = OpTypeVector %1:type, 4
# | %4:type = OpTypeBool
# | %5:type = OpTypeVector %4:type, 4
# | %6:type = OpTypeFunction %5:type, %2:type
# | %10:type = OpTypeInt 32, 0
# | %12:iid = OpConstantI %10:type, 3
# | %1:type = OpTypeFloat 16
# | OpName %0:vfid, 2019910262, 29296
# | %3:iid = OpFunction %5:type, 0, %6:type
# | %0:vfid = OpFunctionParameter %2:type
# | OpName %3:iid, 1953719668, 1634235999, 26220
# | OpDecorate %3:iid, 41, 1953719668, 1634235999, 26220, 0
# | %13:fid = OpCompositeExtract %1:type, %0:vfid, 0
# | %14:iid = OpGroupNonUniformAllEqual %4:type, %12:iid, %0:vfid
# | OpReturnValue %9:id
# |
# | # End machine code for function test_vhalf.
# |
# | *** Bad machine code: Reading virtual register without a def ***
# | - function: test_vhalf
# | - basic block: %bb.1 entry (0x173173f39c0)
# | - instruction: OpReturnValue %9:id
# | - operand 0: %9:id
# | LLVM ERROR: Found 1 machine code errors.
```
The way it is currently done, Extracted is assigned to ElemInput, which is what
we want. By calling this function, we don't have control of which register gets
added to the use, and the wrong register number gets generated. Extracted's
assignment to ElemInput properly updates the register number.
https://github.com/llvm/llvm-project/pull/183634
_______________________________________________
cfe-commits mailing list
[email protected]
https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits