[clang] [llvm] [HLSL] Implement elementwise firstbitlow builtin (PR #116858)

Chris B via cfe-commits cfe-commits at lists.llvm.org
Tue Dec 17 09:11:19 PST 2024


================
@@ -3139,136 +3151,269 @@ Register SPIRVInstructionSelector::buildPointerToResource(
   return AcReg;
 }
 
-bool SPIRVInstructionSelector::selectFirstBitHigh16(Register ResVReg,
-                                                    const SPIRVType *ResType,
-                                                    MachineInstr &I,
-                                                    bool IsSigned) const {
-  unsigned Opcode = IsSigned ? SPIRV::OpSConvert : SPIRV::OpUConvert;
-  // zero or sign extend
+bool SPIRVInstructionSelector::selectFirstBitSet16(
+    Register ResVReg, const SPIRVType *ResType, MachineInstr &I,
+    unsigned ExtendOpcode, unsigned BitSetOpcode) const {
   Register ExtReg = MRI->createVirtualRegister(GR.getRegClass(ResType));
-  bool Result =
-      selectOpWithSrcs(ExtReg, ResType, I, {I.getOperand(2).getReg()}, Opcode);
-  return Result && selectFirstBitHigh32(ResVReg, ResType, I, ExtReg, IsSigned);
+  bool Result = selectOpWithSrcs(ExtReg, ResType, I, {I.getOperand(2).getReg()},
+                                 ExtendOpcode);
+
+  return Result &&
+         selectFirstBitSet32(ResVReg, ResType, I, ExtReg, BitSetOpcode);
 }
 
-bool SPIRVInstructionSelector::selectFirstBitHigh32(Register ResVReg,
-                                                    const SPIRVType *ResType,
-                                                    MachineInstr &I,
-                                                    Register SrcReg,
-                                                    bool IsSigned) const {
-  unsigned Opcode = IsSigned ? GL::FindSMsb : GL::FindUMsb;
+bool SPIRVInstructionSelector::selectFirstBitSet32(
+    Register ResVReg, const SPIRVType *ResType, MachineInstr &I,
+    Register SrcReg, unsigned BitSetOpcode) const {
   return BuildMI(*I.getParent(), I, I.getDebugLoc(), TII.get(SPIRV::OpExtInst))
       .addDef(ResVReg)
       .addUse(GR.getSPIRVTypeID(ResType))
       .addImm(static_cast<uint32_t>(SPIRV::InstructionSet::GLSL_std_450))
-      .addImm(Opcode)
+      .addImm(BitSetOpcode)
       .addUse(SrcReg)
       .constrainAllUses(TII, TRI, RBI);
 }
 
-bool SPIRVInstructionSelector::selectFirstBitHigh64(Register ResVReg,
-                                                    const SPIRVType *ResType,
-                                                    MachineInstr &I,
-                                                    bool IsSigned) const {
-  Register OpReg = I.getOperand(2).getReg();
-  // 1. split our int64 into 2 pieces using a bitcast
-  unsigned count = GR.getScalarOrVectorComponentCount(ResType);
-  SPIRVType *baseType = GR.retrieveScalarOrVectorIntType(ResType);
+bool SPIRVInstructionSelector::selectFirstBitSet64Overflow(
+    Register ResVReg, const SPIRVType *ResType, MachineInstr &I,
+    Register SrcReg, unsigned BitSetOpcode, bool SwapPrimarySide) const {
+
+  unsigned ComponentCount = GR.getScalarOrVectorComponentCount(ResType);
+  SPIRVType *BaseType = GR.retrieveScalarOrVectorIntType(ResType);
+  bool ZeroAsNull = STI.isOpenCLEnv();
+  Register ConstIntZero =
+      GR.getOrCreateConstInt(0, I, BaseType, TII, ZeroAsNull);
+  unsigned LeftComponentCount = ComponentCount / 2;
+  unsigned RightComponentCount = ComponentCount - LeftComponentCount;
+  bool LeftIsVector = LeftComponentCount > 1;
+
+  // Split SrcReg in half into 2 smaller vector registers
+  // (e.g. i64x4 -> i64x2, i64x2)
   MachineIRBuilder MIRBuilder(I);
-  SPIRVType *postCastT =
-      GR.getOrCreateSPIRVVectorType(baseType, 2 * count, MIRBuilder);
-  Register bitcastReg = MRI->createVirtualRegister(GR.getRegClass(postCastT));
+  SPIRVType *OpType = GR.getOrCreateSPIRVIntegerType(64, MIRBuilder);
+  SPIRVType *LeftOpType;
+  SPIRVType *LeftResType;
+  if (LeftIsVector) {
+    LeftOpType =
+        GR.getOrCreateSPIRVVectorType(OpType, LeftComponentCount, MIRBuilder);
+    LeftResType =
+        GR.getOrCreateSPIRVVectorType(BaseType, LeftComponentCount, MIRBuilder);
+  } else {
+    LeftOpType = OpType;
+    LeftResType = BaseType;
+  }
+
+  SPIRVType *RightOpType =
+      GR.getOrCreateSPIRVVectorType(OpType, RightComponentCount, MIRBuilder);
+  SPIRVType *RightResType =
+      GR.getOrCreateSPIRVVectorType(BaseType, RightComponentCount, MIRBuilder);
+
+  Register LeftSideIn = MRI->createVirtualRegister(GR.getRegClass(LeftOpType));
+  Register RightSideIn =
+      MRI->createVirtualRegister(GR.getRegClass(RightOpType));
+
+  bool Result;
+
+  // Extract the left half from the SrcReg into LeftSideIn
+  // accounting for the special case when it only has one element
+  if (LeftIsVector) {
+    auto MIB = BuildMI(*I.getParent(), I, I.getDebugLoc(),
+                       TII.get(SPIRV::OpVectorShuffle))
+                   .addDef(LeftSideIn)
+                   .addUse(GR.getSPIRVTypeID(LeftOpType))
+                   .addUse(SrcReg)
+                   // Per the spec, repeat the vector if only one vec is needed
+                   .addUse(SrcReg);
+
+    for (unsigned J = 0; J < LeftComponentCount; J++) {
+      MIB.addImm(J);
+    }
+
+    Result = MIB.constrainAllUses(TII, TRI, RBI);
+  } else {
+    Result = selectOpWithSrcs(LeftSideIn, LeftOpType, I, {SrcReg, ConstIntZero},
+                              SPIRV::OpVectorExtractDynamic);
+  }
+
+  // Extract the right half from the SrcReg into RightSideIn.
+  // Right will always be a vector: this path only runs for ComponentCount > 2,
+  // so RightComponentCount >= 2 (e.g. ComponentCount == 3 gives Left = 1).
+  auto MIB = BuildMI(*I.getParent(), I, I.getDebugLoc(),
+                     TII.get(SPIRV::OpVectorShuffle))
+                 .addDef(RightSideIn)
+                 .addUse(GR.getSPIRVTypeID(RightOpType))
+                 .addUse(SrcReg)
+                 // Per the spec, repeat the vector if only one vec is needed
+                 .addUse(SrcReg);
+
+  for (unsigned J = LeftComponentCount; J < ComponentCount; J++) {
+    MIB.addImm(J);
+  }
+
+  Result = Result && MIB.constrainAllUses(TII, TRI, RBI);
+
+  // Recursively call selectFirstBitSet64 on the 2 halves
+  Register LeftSideOut =
+      MRI->createVirtualRegister(GR.getRegClass(LeftResType));
+  Register RightSideOut =
+      MRI->createVirtualRegister(GR.getRegClass(RightResType));
+  Result =
+      Result && selectFirstBitSet64(LeftSideOut, LeftResType, I, LeftSideIn,
+                                    BitSetOpcode, SwapPrimarySide);
+  Result =
+      Result && selectFirstBitSet64(RightSideOut, RightResType, I, RightSideIn,
+                                    BitSetOpcode, SwapPrimarySide);
+
+  // Join the two resulting registers back into the return type
+  // (e.g. i32x2, i32x2 -> i32x4)
+  return Result &&
+         selectOpWithSrcs(ResVReg, ResType, I, {LeftSideOut, RightSideOut},
+                          SPIRV::OpCompositeConstruct);
+}
+
+bool SPIRVInstructionSelector::selectFirstBitSet64(
+    Register ResVReg, const SPIRVType *ResType, MachineInstr &I,
+    Register SrcReg, unsigned BitSetOpcode, bool SwapPrimarySide) const {
+  unsigned ComponentCount = GR.getScalarOrVectorComponentCount(ResType);
+  SPIRVType *BaseType = GR.retrieveScalarOrVectorIntType(ResType);
+  bool ZeroAsNull = STI.isOpenCLEnv();
+  Register ConstIntZero =
+      GR.getOrCreateConstInt(0, I, BaseType, TII, ZeroAsNull);
+  Register ConstIntOne =
+      GR.getOrCreateConstInt(1, I, BaseType, TII, ZeroAsNull);
+
+  // SPIR-V doesn't support vectors with more than 4 components. Since the
+  // algorithm below bitcasts i64xN -> i32x(2N), an i64x4 would need an illegal
+  // i32x8, so it only operates on vectors with 2 or fewer components. Larger
+  // vectors are split in half, handled recursively, then recombined.
+  if (ComponentCount > 2) {
+    return selectFirstBitSet64Overflow(ResVReg, ResType, I, SrcReg,
+                                       BitSetOpcode, SwapPrimarySide);
+  }
+
+  // 1. Split int64 into 2 pieces using a bitcast
+  MachineIRBuilder MIRBuilder(I);
+  SPIRVType *PostCastType =
+      GR.getOrCreateSPIRVVectorType(BaseType, 2 * ComponentCount, MIRBuilder);
+  Register BitcastReg =
+      MRI->createVirtualRegister(GR.getRegClass(PostCastType));
   bool Result =
-      selectOpWithSrcs(bitcastReg, postCastT, I, {OpReg}, SPIRV::OpBitcast);
+      selectOpWithSrcs(BitcastReg, PostCastType, I, {SrcReg}, SPIRV::OpBitcast);
 
-  // 2. call firstbithigh
-  Register FBHReg = MRI->createVirtualRegister(GR.getRegClass(postCastT));
-  Result &= selectFirstBitHigh32(FBHReg, postCastT, I, bitcastReg, IsSigned);
+  // 2. Find the first set bit from the primary side for all the pieces in #1
+  Register FBSReg = MRI->createVirtualRegister(GR.getRegClass(PostCastType));
+  Result = Result && selectFirstBitSet32(FBSReg, PostCastType, I, BitcastReg,
+                                         BitSetOpcode);
 
-  // 3. split result vector into high bits and low bits
+  // 3. Split result vector into high bits and low bits
   Register HighReg = MRI->createVirtualRegister(GR.getRegClass(ResType));
   Register LowReg = MRI->createVirtualRegister(GR.getRegClass(ResType));
 
-  bool ZeroAsNull = STI.isOpenCLEnv();
-  bool isScalarRes = ResType->getOpcode() != SPIRV::OpTypeVector;
-  if (isScalarRes) {
+  bool IsScalarRes = ResType->getOpcode() != SPIRV::OpTypeVector;
+  if (IsScalarRes) {
     // if scalar do a vector extract
-    Result &= selectOpWithSrcs(
-        HighReg, ResType, I,
-        {FBHReg, GR.getOrCreateConstInt(0, I, ResType, TII, ZeroAsNull)},
-        SPIRV::OpVectorExtractDynamic);
-    Result &= selectOpWithSrcs(
-        LowReg, ResType, I,
-        {FBHReg, GR.getOrCreateConstInt(1, I, ResType, TII, ZeroAsNull)},
-        SPIRV::OpVectorExtractDynamic);
-  } else { // vector case do a shufflevector
+    Result =
+        Result && selectOpWithSrcs(HighReg, ResType, I, {FBSReg, ConstIntZero},
+                                   SPIRV::OpVectorExtractDynamic);
+    Result =
+        Result && selectOpWithSrcs(LowReg, ResType, I, {FBSReg, ConstIntOne},
+                                   SPIRV::OpVectorExtractDynamic);
+  } else {
+    // if vector do a shufflevector
     auto MIB = BuildMI(*I.getParent(), I, I.getDebugLoc(),
                        TII.get(SPIRV::OpVectorShuffle))
                    .addDef(HighReg)
                    .addUse(GR.getSPIRVTypeID(ResType))
-                   .addUse(FBHReg)
-                   .addUse(FBHReg);
-    // ^^ this vector will not be selected from; could be empty
-    unsigned j;
-    for (j = 0; j < count * 2; j += 2) {
-      MIB.addImm(j);
+                   .addUse(FBSReg)
+                   // Per the spec, repeat the vector if only one vec is needed
+                   .addUse(FBSReg);
+
+    // high bits are stored in even indexes. Extract them from FBSReg
+    for (unsigned J = 0; J < ComponentCount * 2; J += 2) {
+      MIB.addImm(J);
     }
-    Result &= MIB.constrainAllUses(TII, TRI, RBI);
+    Result = Result && MIB.constrainAllUses(TII, TRI, RBI);
 
-    // get low bits
     MIB = BuildMI(*I.getParent(), I, I.getDebugLoc(),
                   TII.get(SPIRV::OpVectorShuffle))
               .addDef(LowReg)
               .addUse(GR.getSPIRVTypeID(ResType))
-              .addUse(FBHReg)
-              .addUse(FBHReg);
-    // ^^ this vector will not be selected from; could be empty
-    for (j = 1; j < count * 2; j += 2) {
-      MIB.addImm(j);
+              .addUse(FBSReg)
+              // Per the spec, repeat the vector if only one vec is needed
+              .addUse(FBSReg);
+
+    // low bits are stored in odd indexes. Extract them from FBSReg
+    for (unsigned J = 1; J < ComponentCount * 2; J += 2) {
+      MIB.addImm(J);
     }
-    Result &= MIB.constrainAllUses(TII, TRI, RBI);
+    Result = Result && MIB.constrainAllUses(TII, TRI, RBI);
   }
 
-  // 4. check if result of each top 32 bits is == -1
+  // 4. Check the result. When the primary half's result is -1, use the
+  // secondary half's result; otherwise use the primary half's.
   SPIRVType *BoolType = GR.getOrCreateSPIRVBoolType(I, TII);
   Register NegOneReg;
   Register Reg0;
   Register Reg32;
-  unsigned selectOp;
-  unsigned addOp;
-  if (isScalarRes) {
+  unsigned SelectOp;
+  unsigned AddOp;
+
+  if (IsScalarRes) {
     NegOneReg =
         GR.getOrCreateConstInt((unsigned)-1, I, ResType, TII, ZeroAsNull);
     Reg0 = GR.getOrCreateConstInt(0, I, ResType, TII, ZeroAsNull);
     Reg32 = GR.getOrCreateConstInt(32, I, ResType, TII, ZeroAsNull);
-    selectOp = SPIRV::OpSelectSISCond;
-    addOp = SPIRV::OpIAddS;
+    SelectOp = SPIRV::OpSelectSISCond;
+    AddOp = SPIRV::OpIAddS;
   } else {
-    BoolType = GR.getOrCreateSPIRVVectorType(BoolType, count, MIRBuilder);
+    BoolType =
+        GR.getOrCreateSPIRVVectorType(BoolType, ComponentCount, MIRBuilder);
     NegOneReg =
         GR.getOrCreateConstVector((unsigned)-1, I, ResType, TII, ZeroAsNull);
     Reg0 = GR.getOrCreateConstVector(0, I, ResType, TII, ZeroAsNull);
     Reg32 = GR.getOrCreateConstVector(32, I, ResType, TII, ZeroAsNull);
-    selectOp = SPIRV::OpSelectVIVCond;
-    addOp = SPIRV::OpIAddV;
+    SelectOp = SPIRV::OpSelectVIVCond;
+    AddOp = SPIRV::OpIAddV;
+  }
+
+  Register PrimaryReg;
+  Register SecondaryReg;
+  Register PrimaryShiftReg;
+  Register SecondaryShiftReg;
+
+  // By default the emitted opcodes check for the set bit from the MSB side.
+  // Setting SwapPrimarySide checks for the set bit from the LSB side instead.
+  if (SwapPrimarySide) {
----------------
llvm-beanz wrote:

nit: I would either initialize these with ternary selections, or assign them the default (unswapped) order and use a single `if` to swap them.

https://github.com/llvm/llvm-project/pull/116858


More information about the cfe-commits mailing list