[clang] [llvm] [HLSL] implement elementwise firstbithigh hlsl builtin (PR #111082)

Farzon Lotfi via cfe-commits cfe-commits at lists.llvm.org
Tue Oct 22 10:45:25 PDT 2024


================
@@ -2626,6 +2671,148 @@ Register SPIRVInstructionSelector::buildPointerToResource(
                                                  MIRBuilder);
 }
 
+bool SPIRVInstructionSelector::selectFirstBitHigh16(Register ResVReg,
+                                                    const SPIRVType *ResType,
+                                                    MachineInstr &I,
+                                                    bool IsSigned) const {
+  unsigned Opcode = IsSigned ? SPIRV::OpSConvert : SPIRV::OpUConvert;
+  // Sign- or zero-extend operand 2 of I into ExtReg, typed as ResType.
+  Register ExtReg = MRI->createVirtualRegister(GR.getRegClass(ResType));
+  bool Result = // '&' below does not short-circuit: both steps always run
+      selectUnOpWithSrc(ExtReg, ResType, I, I.getOperand(2).getReg(), Opcode);
+  return Result & selectFirstBitHigh32(ResVReg, ResType, I, ExtReg, IsSigned);
+}
+
+bool SPIRVInstructionSelector::selectFirstBitHigh32(Register ResVReg,
+                                                    const SPIRVType *ResType,
+                                                    MachineInstr &I,
+                                                    Register SrcReg,
+                                                    bool IsSigned) const {
+  unsigned Opcode = IsSigned ? GL::FindSMsb : GL::FindUMsb; // by signedness
+  return BuildMI(*I.getParent(), I, I.getDebugLoc(), TII.get(SPIRV::OpExtInst))
+      .addDef(ResVReg)                    // register defined by the OpExtInst
+      .addUse(GR.getSPIRVTypeID(ResType)) // type id looked up for ResType
+      .addImm(static_cast<uint32_t>(SPIRV::InstructionSet::GLSL_std_450))
+      .addImm(Opcode)                     // FindSMsb or FindUMsb, chosen above
+      .addUse(SrcReg)                     // operand handed to the instruction
+      .constrainAllUses(TII, TRI, RBI);   // this bool is the return value
+}
+
+bool SPIRVInstructionSelector::selectFirstBitHigh64(Register ResVReg,
+                                                    const SPIRVType *ResType,
+                                                    MachineInstr &I,
+                                                    bool IsSigned) const {
+  Register OpReg = I.getOperand(2).getReg(); // input: operand 2 of I
+  // 1. Bitcast the operand to a vector with twice as many components.
+  unsigned count = GR.getScalarOrVectorComponentCount(ResType);
+  SPIRVType *baseType = GR.retrieveScalarOrVectorIntType(ResType);
+  MachineIRBuilder MIRBuilder(I);
+  SPIRVType *postCastT =
+      GR.getOrCreateSPIRVVectorType(baseType, 2 * count, MIRBuilder);
+  Register bitcastReg = MRI->createVirtualRegister(GR.getRegClass(postCastT));
+  bool Result =
+      selectUnOpWithSrc(bitcastReg, postCastT, I, OpReg, SPIRV::OpBitcast);

+  // 2. Apply the 32-bit firstbithigh selection to the bitcast vector.
+  Register FBHReg = MRI->createVirtualRegister(GR.getRegClass(postCastT));
+  Result &= selectFirstBitHigh32(FBHReg, postCastT, I, bitcastReg, IsSigned);

+  // 3. Split FBHReg into a "high" and a "low" result vector, then test each
+  // high result against -1 to decide which of the two feeds the answer.
+  // "High" = even components of FBHReg; NOTE(review): confirm bitcast order.
+  // If ResType is a scalar we still need a vector type, because the code
+  // below operates on vectors, even vectors of length one.
+  SPIRVType *VResType = ResType;
+  bool isScalarRes = ResType->getOpcode() != SPIRV::OpTypeVector;
+  if (isScalarRes)
+    VResType = GR.getOrCreateSPIRVVectorType(ResType, count, MIRBuilder);
+  // NOTE(review): count is presumably 1 for a scalar ResType; confirm.

+  Register HighReg = MRI->createVirtualRegister(GR.getRegClass(VResType));
+  auto MIB =
+      BuildMI(*I.getParent(), I, I.getDebugLoc(),
+              TII.get(SPIRV::OpVectorShuffle))
+          .addDef(HighReg)
+          .addUse(GR.getSPIRVTypeID(VResType))
+          .addUse(FBHReg)
+          .addUse(
+              FBHReg); // second source; indices below never select from it
+  unsigned i;
+  for (i = 0; i < count * 2; i += 2) {
+    MIB.addImm(i); // even indices: 0, 2, ...
+  }
+  Result &= MIB.constrainAllUses(TII, TRI, RBI);

+  // Low results: the same shuffle with the odd indices (1, 3, ...).
+  Register LowReg = MRI->createVirtualRegister(GR.getRegClass(VResType));
+  MIB =
+      BuildMI(*I.getParent(), I, I.getDebugLoc(),
+              TII.get(SPIRV::OpVectorShuffle))
+          .addDef(LowReg)
+          .addUse(GR.getSPIRVTypeID(VResType))
+          .addUse(FBHReg)
+          .addUse(
+              FBHReg); // second source; indices below never select from it
+  for (i = 1; i < count * 2; i += 2) {
+    MIB.addImm(i); // odd indices: 1, 3, ...
+  }
+  Result &= MIB.constrainAllUses(TII, TRI, RBI);

+  SPIRVType *BoolType = GR.getOrCreateSPIRVVectorType(
+      GR.getOrCreateSPIRVBoolType(I, TII), count, MIRBuilder);
+  // Compare the high results against a constant vector of -1.
+  Register NegOneReg = GR.getOrCreateConstVector(-1, I, VResType, TII);
+  // BReg is true wherever the high result equals -1.
+  Register BReg = MRI->createVirtualRegister(GR.getRegClass(BoolType));
+  Result &= selectNAryOpWithSrcs(BReg, BoolType, I, {HighReg, NegOneReg},
+                                 SPIRV::OpIEqual);

+  // Where BReg is true take the low result, otherwise the high result.
+  Register TmpReg = MRI->createVirtualRegister(GR.getRegClass(VResType));
+  Result &= selectNAryOpWithSrcs(TmpReg, VResType, I, {BReg, LowReg, HighReg},
+                                 SPIRV::OpSelectVIVCond);

+  // Offset: 0 where the low result was chosen, 32 where the high one was.
+  Register ValReg = MRI->createVirtualRegister(GR.getRegClass(VResType));
+  bool ZeroAsNull = STI.isOpenCLEnv();
+  Register Reg32 = GR.getOrCreateConstVector(32, I, VResType, TII, ZeroAsNull);
+  Register Reg0 = GR.getOrCreateConstVector(0, I, VResType, TII, ZeroAsNull);
+  Result &= selectNAryOpWithSrcs(ValReg, VResType, I, {BReg, Reg0, Reg32},
+                                 SPIRV::OpSelectVIVCond);

+  Register AddReg = ResVReg; // vector case: the sum is written to ResVReg
+  if (isScalarRes)
+    AddReg = MRI->createVirtualRegister(GR.getRegClass(VResType));
+  Result &= selectNAryOpWithSrcs(AddReg, VResType, I, {ValReg, TmpReg},
+                                 SPIRV::OpIAddV);

+  // For a scalar ResType, extract element 0 of the vector sum into ResVReg.
+  if (!isScalarRes)
+    return Result;
+  else
+    return Result & selectNAryOpWithSrcs(
+                        ResVReg, ResType, I,
+                        {AddReg, GR.getOrCreateConstInt(0, I, ResType, TII)},
+                        SPIRV::OpVectorExtractDynamic);
+}
+
+bool SPIRVInstructionSelector::selectFirstBitHigh(Register ResVReg,
+                                                  const SPIRVType *ResType,
+                                                  MachineInstr &I,
+                                                  bool IsSigned) const {
+  // FindUMsb intrinsic only supports 32 bit integers
+  Register OpReg = I.getOperand(2).getReg();
+  SPIRVType *OpType = GR.getSPIRVTypeForVReg(OpReg);
+  unsigned bitWidth = GR.getScalarOrVectorBitWidth(OpType);
+
+  if (bitWidth == 16)
----------------
farzonl wrote:

Maybe this would be better as a switch statement. Then we can be explicit about the 64-bit case, and if the bit width isn't one of these three we can report an error.

https://github.com/llvm/llvm-project/pull/111082


More information about the cfe-commits mailing list