[llvm] [Spirv][HLSL] Add OpAll lowering and float vec support (PR #87952)

Vyacheslav Levytskyy via llvm-commits llvm-commits at lists.llvm.org
Mon Apr 8 04:15:06 PDT 2024


================
@@ -1155,6 +1160,62 @@ static unsigned getBoolCmpOpcode(unsigned PredNum) {
   }
 }
 
+bool SPIRVInstructionSelector::selectAll(Register ResVReg,
+                                         const SPIRVType *ResType,
+                                         MachineInstr &I) const {
+  assert(I.getNumOperands() == 3);
+  assert(I.getOperand(2).isReg());
+  MachineBasicBlock &BB = *I.getParent();
+  Register InputRegister = I.getOperand(2).getReg();
+  SPIRVType *InputType = GR.getSPIRVTypeForVReg(InputRegister);
+  bool IsBoolTy = GR.isScalarOrVectorOfType(InputRegister, SPIRV::OpTypeBool);
+  bool IsVectorTy = InputType->getOpcode() == SPIRV::OpTypeVector;
+  if (IsBoolTy && !IsVectorTy) {
+    assert(ResVReg == I.getOperand(0).getReg());
+    return BuildMI(*I.getParent(), I, I.getDebugLoc(),
+                   TII.get(TargetOpcode::COPY))
+        .addDef(ResVReg)
+        .addUse(InputRegister)
+        .constrainAllUses(TII, TRI, RBI);
+  }
+
+  bool IsFloatTy = GR.isScalarOrVectorOfType(InputRegister, SPIRV::OpTypeFloat);
+  unsigned SpirvNotEqualId =
+      IsFloatTy ? SPIRV::OpFOrdNotEqual : SPIRV::OpINotEqual;
+  SPIRVType *SpvBoolScalarTy = GR.getOrCreateSPIRVBoolType(I, TII);
+  SPIRVType *SpvBoolTy = SpvBoolScalarTy;
+  Register NotEqualReg = ResVReg;
+
+  if (IsVectorTy) {
+    NotEqualReg = MRI->createVirtualRegister(&SPIRV::IDRegClass);
+    const unsigned NumElts = InputType->getOperand(2).getImm();
+    SpvBoolTy = GR.getOrCreateSPIRVVectorType(SpvBoolTy, NumElts, I, TII);
+  }
+
+  if (!IsBoolTy) {
+    Register ConstCompositeZeroReg =
+        IsFloatTy ? buildZerosValF(InputType, I) : buildZerosVal(InputType, I);
+
+    BuildMI(BB, I, I.getDebugLoc(), TII.get(SpirvNotEqualId))
+        .addDef(NotEqualReg)
+        .addUse(GR.getSPIRVTypeID(SpvBoolTy))
+        .addUse(InputRegister)
+        .addUse(ConstCompositeZeroReg)
+        .constrainAllUses(TII, TRI, RBI);
+  } else {
+    NotEqualReg = InputRegister;
----------------
VyacheslavLevytskyy wrote:

If the input is a vector of bool, we probably shouldn't call `NotEqualReg = MRI->createVirtualRegister(&SPIRV::IDRegClass);` earlier (line 1190), since that allocates a new virtual register only for it to be overwritten here.

https://github.com/llvm/llvm-project/pull/87952


More information about the llvm-commits mailing list