[clang] [llvm] [HLSL] Implement elementwise firstbitlow builtin (PR #116858)

Ashley Coleman via llvm-commits llvm-commits at lists.llvm.org
Fri Nov 22 12:23:41 PST 2024


================
@@ -3158,6 +3172,166 @@ bool SPIRVInstructionSelector::selectFirstBitHigh(Register ResVReg,
   }
 }
 
+bool SPIRVInstructionSelector::selectFirstBitLow16(Register ResVReg,
+                                                   const SPIRVType *ResType,
+                                                   MachineInstr &I) const {
+  // Handles 16-bit sources: OpUConvert zero-extends operand 2 (treated as an
+  // unsigned i16) to an unsigned i32, which leaves the low bits untouched, so
+  // the 32-bit lowering below finds the same first set bit from the LSB side.
+  Register ExtReg = MRI->createVirtualRegister(GR.getRegClass(ResType));
+  bool Result = selectNAryOpWithSrcs( // false here skips the 32-bit step
+      ExtReg, ResType, I, {I.getOperand(2).getReg()}, SPIRV::OpUConvert);
+  return Result && selectFirstBitLow32(ResVReg, ResType, I, ExtReg);
+}
+
+bool SPIRVInstructionSelector::selectFirstBitLow32(Register ResVReg,
+                                                   const SPIRVType *ResType,
+                                                   MachineInstr &I,
+                                                   Register SrcReg) const {
+  return BuildMI(*I.getParent(), I, I.getDebugLoc(), TII.get(SPIRV::OpExtInst))
+      .addDef(ResVReg)                    // register defined by the OpExtInst
+      .addUse(GR.getSPIRVTypeID(ResType)) // type id of that result
+      .addImm(static_cast<uint32_t>(SPIRV::InstructionSet::GLSL_std_450))
+      .addImm(GL::FindILsb)               // op from the set chosen above
+      .addUse(SrcReg)                     // value scanned for its first set bit
+      .constrainAllUses(TII, TRI, RBI);   // its bool is what we return
+}
+
+bool SPIRVInstructionSelector::selectFirstBitLow64(Register ResVReg,
+                                                   const SPIRVType *ResType,
+                                                   MachineInstr &I) const {
+  Register OpReg = I.getOperand(2).getReg();
+
+  // 1. Split int64 into 2 pieces using a bitcast
+  unsigned ComponentCount = GR.getScalarOrVectorComponentCount(ResType);
+  SPIRVType *BaseType = GR.retrieveScalarOrVectorIntType(ResType);
+  MachineIRBuilder MIRBuilder(I);
+  SPIRVType *PostCastType =
+      GR.getOrCreateSPIRVVectorType(BaseType, 2 * ComponentCount, MIRBuilder);
+  Register BitcastReg =
+      MRI->createVirtualRegister(GR.getRegClass(PostCastType));
+  bool Result =
+      selectUnOpWithSrc(BitcastReg, PostCastType, I, OpReg, SPIRV::OpBitcast);
+
+  // 2. Find the first set bit from the LSB side for all the pieces in #1
+  Register FBLReg = MRI->createVirtualRegister(GR.getRegClass(PostCastType));
+  Result = Result && selectFirstBitLow32(FBLReg, PostCastType, I, BitcastReg);
+
+  // 3. Split result vector into high bits and low bits
+  Register HighReg = MRI->createVirtualRegister(GR.getRegClass(ResType));
+  Register LowReg = MRI->createVirtualRegister(GR.getRegClass(ResType));
+
+  bool ZeroAsNull = STI.isOpenCLEnv();
+  bool IsScalarRes = ResType->getOpcode() != SPIRV::OpTypeVector;
+  if (IsScalarRes) {
+    // if scalar do a vector extract
+    Result =
+        Result &&
+        selectNAryOpWithSrcs(
+            HighReg, ResType, I,
+            {FBLReg, GR.getOrCreateConstInt(0, I, ResType, TII, ZeroAsNull)},
+            SPIRV::OpVectorExtractDynamic);
+    Result =
+        Result &&
+        selectNAryOpWithSrcs(
+            LowReg, ResType, I,
+            {FBLReg, GR.getOrCreateConstInt(1, I, ResType, TII, ZeroAsNull)},
+            SPIRV::OpVectorExtractDynamic);
+  } else {
+    // if vector do a shufflevector
+    auto MIB = BuildMI(*I.getParent(), I, I.getDebugLoc(),
+                       TII.get(SPIRV::OpVectorShuffle))
+                   .addDef(HighReg)
+                   .addUse(GR.getSPIRVTypeID(ResType))
+                   .addUse(FBLReg)
+                   // Per the spec, repeat the vector if only one vec is needed
+                   .addUse(FBLReg);
+
+    // high bits are stored in even indexes. Extract them from FBLReg
+    for (unsigned j = 0; j < ComponentCount * 2; j += 2) {
+      MIB.addImm(j);
+    }
+    Result = Result && MIB.constrainAllUses(TII, TRI, RBI);
+
+    MIB = BuildMI(*I.getParent(), I, I.getDebugLoc(),
+                  TII.get(SPIRV::OpVectorShuffle))
+              .addDef(LowReg)
+              .addUse(GR.getSPIRVTypeID(ResType))
+              .addUse(FBLReg)
+              // Per the spec, repeat the vector if only one vec is needed
+              .addUse(FBLReg);
+
+    // low bits are stored in odd indexes. Extract them from FBLReg
+    for (unsigned j = 1; j < ComponentCount * 2; j += 2) {
+      MIB.addImm(j);
+    }
+    Result = Result && MIB.constrainAllUses(TII, TRI, RBI);
+  }
+
+  // 4. Check the result. When low bits == -1 use high, otherwise use low
----------------
V-FEXrt wrote:

I couldn't figure out how to make the swapping work, since the shared function is what actually assigns High/Low, but I added a flag to the shared function to swap the registers. Let me know if that's good enough or if there is something I'm missing!

https://github.com/llvm/llvm-project/pull/116858


More information about the llvm-commits mailing list