[clang] [llvm] [HLSL] implement elementwise firstbithigh hlsl builtin (PR #111082)
Farzon Lotfi via cfe-commits
cfe-commits at lists.llvm.org
Tue Oct 22 10:45:25 PDT 2024
================
@@ -2626,6 +2671,148 @@ Register SPIRVInstructionSelector::buildPointerToResource(
MIRBuilder);
}
+bool SPIRVInstructionSelector::selectFirstBitHigh16(Register ResVReg,
+ const SPIRVType *ResType,
+ MachineInstr &I,
+ bool IsSigned) const { // converts operand 2 of I to ResType, then defers to selectFirstBitHigh32
+ unsigned Opcode = IsSigned ? SPIRV::OpSConvert : SPIRV::OpUConvert;
+ // zero or sign extend: OpSConvert when IsSigned, OpUConvert otherwise; the converted value lands in ExtReg
+ Register ExtReg = MRI->createVirtualRegister(GR.getRegClass(ResType));
+ bool Result =
+ selectUnOpWithSrc(ExtReg, ResType, I, I.getOperand(2).getReg(), Opcode);
+ return Result & selectFirstBitHigh32(ResVReg, ResType, I, ExtReg, IsSigned); // bitwise &, not &&: the 32-bit step is emitted even if the convert reported failure
+}
+
+bool SPIRVInstructionSelector::selectFirstBitHigh32(Register ResVReg,
+ const SPIRVType *ResType,
+ MachineInstr &I,
+ Register SrcReg,
+ bool IsSigned) const { // emits one GLSL.std.450 OpExtInst on SrcReg, defining ResVReg with type ResType
+ unsigned Opcode = IsSigned ? GL::FindSMsb : GL::FindUMsb; // signed vs. unsigned variant of the extended instruction
+ return BuildMI(*I.getParent(), I, I.getDebugLoc(), TII.get(SPIRV::OpExtInst)) // inserted before I, reusing I's debug location
+ .addDef(ResVReg)
+ .addUse(GR.getSPIRVTypeID(ResType))
+ .addImm(static_cast<uint32_t>(SPIRV::InstructionSet::GLSL_std_450))
+ .addImm(Opcode)
+ .addUse(SrcReg)
+ .constrainAllUses(TII, TRI, RBI); // the constrainAllUses result is the function's return value
+}
+
+bool SPIRVInstructionSelector::selectFirstBitHigh64(Register ResVReg,
+ const SPIRVType *ResType,
+ MachineInstr &I,
+ bool IsSigned) const { // 64-bit path: bitcast operand 2 into twice as many lanes, search each lane, then merge each lane pair into one result
+ Register OpReg = I.getOperand(2).getReg();
+ // 1. split our int64 into 2 pieces using a bitcast to a vector of 2 * count elements of ResType's integer element type
+ unsigned count = GR.getScalarOrVectorComponentCount(ResType);
+ SPIRVType *baseType = GR.retrieveScalarOrVectorIntType(ResType);
+ MachineIRBuilder MIRBuilder(I);
+ SPIRVType *postCastT =
+ GR.getOrCreateSPIRVVectorType(baseType, 2 * count, MIRBuilder);
+ Register bitcastReg = MRI->createVirtualRegister(GR.getRegClass(postCastT));
+ bool Result =
+ selectUnOpWithSrc(bitcastReg, postCastT, I, OpReg, SPIRV::OpBitcast);
+
+ // 2. call firstbithigh on the whole bitcast vector via selectFirstBitHigh32; FBHReg has the same 2 * count element type
+ Register FBHReg = MRI->createVirtualRegister(GR.getRegClass(postCastT));
+ Result &= selectFirstBitHigh32(FBHReg, postCastT, I, bitcastReg, IsSigned);
+
+ // 3. check if result of each top 32 bits is == -1
+ // split result vector into vector of high bits and vector of low bits
+ // get high bits (the first shuffle below, which picks the even elements of FBHReg)
+ // if ResType is a scalar we need a vector anyways because our code
+ // operates on vectors, even vectors of length one.
+ SPIRVType *VResType = ResType;
+ bool isScalarRes = ResType->getOpcode() != SPIRV::OpTypeVector;
+ if (isScalarRes)
+ VResType = GR.getOrCreateSPIRVVectorType(ResType, count, MIRBuilder);
+ // count should be one. NOTE(review): confirm that a one-component vector type is acceptable to the SPIR-V backend/validator
+
+ Register HighReg = MRI->createVirtualRegister(GR.getRegClass(VResType));
+ auto MIB =
+ BuildMI(*I.getParent(), I, I.getDebugLoc(),
+ TII.get(SPIRV::OpVectorShuffle))
+ .addDef(HighReg)
+ .addUse(GR.getSPIRVTypeID(VResType))
+ .addUse(FBHReg)
+ .addUse(
+ FBHReg); // this vector will not be selected from; could be empty
+ unsigned i;
+ for (i = 0; i < count * 2; i += 2) { // even elements 0, 2, ...; NOTE(review): confirm these are the upper halves after the bitcast
+ MIB.addImm(i);
+ }
+ Result &= MIB.constrainAllUses(TII, TRI, RBI);
+
+ // get low bits: same shuffle shape as above, but picking the odd elements 1, 3, ...
+ Register LowReg = MRI->createVirtualRegister(GR.getRegClass(VResType));
+ MIB =
+ BuildMI(*I.getParent(), I, I.getDebugLoc(),
+ TII.get(SPIRV::OpVectorShuffle))
+ .addDef(LowReg)
+ .addUse(GR.getSPIRVTypeID(VResType))
+ .addUse(FBHReg)
+ .addUse(
+ FBHReg); // this vector will not be selected from; could be empty
+ for (i = 1; i < count * 2; i += 2) {
+ MIB.addImm(i);
+ }
+ Result &= MIB.constrainAllUses(TII, TRI, RBI);
+
+ SPIRVType *BoolType = GR.getOrCreateSPIRVVectorType(
+ GR.getOrCreateSPIRVBoolType(I, TII), count, MIRBuilder);
+ // check if the high bits are == -1; the comparison is element-wise against a constant vector of -1
+ Register NegOneReg = GR.getOrCreateConstVector(-1, I, VResType, TII);
+ // true if -1: BReg holds the OpIEqual result of HighReg against NegOneReg
+ Register BReg = MRI->createVirtualRegister(GR.getRegClass(BoolType));
+ Result &= selectNAryOpWithSrcs(BReg, BoolType, I, {HighReg, NegOneReg},
+ SPIRV::OpIEqual);
+
+ // Select low bits if true in BReg, otherwise high bits
+ Register TmpReg = MRI->createVirtualRegister(GR.getRegClass(VResType));
+ Result &= selectNAryOpWithSrcs(TmpReg, VResType, I, {BReg, LowReg, HighReg},
+ SPIRV::OpSelectVIVCond);
+
+ // Add 32 for high bits, 0 for low bits: ValReg is the per-element offset, chosen with the same BReg condition as TmpReg
+ Register ValReg = MRI->createVirtualRegister(GR.getRegClass(VResType));
+ bool ZeroAsNull = STI.isOpenCLEnv();
+ Register Reg32 = GR.getOrCreateConstVector(32, I, VResType, TII, ZeroAsNull);
+ Register Reg0 = GR.getOrCreateConstVector(0, I, VResType, TII, ZeroAsNull);
+ Result &= selectNAryOpWithSrcs(ValReg, VResType, I, {BReg, Reg0, Reg32},
+ SPIRV::OpSelectVIVCond);
+
+ Register AddReg = ResVReg; // a vector result is written straight into ResVReg by the add below
+ if (isScalarRes)
+ AddReg = MRI->createVirtualRegister(GR.getRegClass(VResType)); // a scalar result goes through a temporary vector register first
+ Result &= selectNAryOpWithSrcs(AddReg, VResType, I, {ValReg, TmpReg},
+ SPIRV::OpIAddV);
+
+ // convert result back to scalar if necessary, by extracting element 0 of AddReg into ResVReg
+ if (!isScalarRes)
+ return Result;
+ else
+ return Result & selectNAryOpWithSrcs(
+ ResVReg, ResType, I,
+ {AddReg, GR.getOrCreateConstInt(0, I, ResType, TII)},
+ SPIRV::OpVectorExtractDynamic);
+}
+
+bool SPIRVInstructionSelector::selectFirstBitHigh(Register ResVReg,
+ const SPIRVType *ResType,
+ MachineInstr &I,
+ bool IsSigned) const {
+ // FindUMsb intrinsic only supports 32 bit integers
+ Register OpReg = I.getOperand(2).getReg();
+ SPIRVType *OpType = GR.getSPIRVTypeForVReg(OpReg);
+ unsigned bitWidth = GR.getScalarOrVectorBitWidth(OpType);
+
+ if (bitWidth == 16)
----------------
farzonl wrote:
Maybe this would be better as a switch statement. Then we can be explicit about the 64-bit case, and if the bit width isn't one of these three (16, 32, or 64) we can report an error.
https://github.com/llvm/llvm-project/pull/111082
More information about the cfe-commits
mailing list