[llvm] [SPIRV] Add support for arbitrary-precision integers larger than 64 bits in SPIR-V backend (PR #161270)
via llvm-commits
llvm-commits at lists.llvm.org
Mon Oct 6 15:56:44 PDT 2025
https://github.com/YixingZhang007 updated https://github.com/llvm/llvm-project/pull/161270
>From 4c1d90c01082af32ee28a04edd3a0fc58a4ac1b9 Mon Sep 17 00:00:00 2001
From: "Zhang, Yixing" <yixing.zhang at intel.com>
Date: Sat, 27 Sep 2025 08:58:37 -0700
Subject: [PATCH 1/3] Add support for arbitrary-precision integers with bit
widths larger than 64 bits in the SPIR-V backend
---
.../SPIRV/MCTargetDesc/SPIRVInstPrinter.cpp | 18 ++++++---
llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp | 17 ++++----
llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h | 4 +-
.../Target/SPIRV/SPIRVInstructionSelector.cpp | 39 +++++++++----------
llvm/lib/Target/SPIRV/SPIRVUtils.cpp | 11 +++---
.../SPV_INTEL_arbitrary_precision_integers.ll | 4 ++
6 files changed, 52 insertions(+), 41 deletions(-)
diff --git a/llvm/lib/Target/SPIRV/MCTargetDesc/SPIRVInstPrinter.cpp b/llvm/lib/Target/SPIRV/MCTargetDesc/SPIRVInstPrinter.cpp
index 776208bd3e693..dff9f699ebd6f 100644
--- a/llvm/lib/Target/SPIRV/MCTargetDesc/SPIRVInstPrinter.cpp
+++ b/llvm/lib/Target/SPIRV/MCTargetDesc/SPIRVInstPrinter.cpp
@@ -50,18 +50,24 @@ void SPIRVInstPrinter::printOpConstantVarOps(const MCInst *MI,
unsigned IsBitwidth16 = MI->getFlags() & SPIRV::INST_PRINTER_WIDTH16;
const unsigned NumVarOps = MI->getNumOperands() - StartIndex;
- assert((NumVarOps == 1 || NumVarOps == 2) &&
+ // we support integer up to 1024 bits
+ assert((NumVarOps <= 1024) &&
"Unsupported number of bits for literal variable");
O << ' ';
- uint64_t Imm = MI->getOperand(StartIndex).getImm();
-
- // Handle 64 bit literals.
- if (NumVarOps == 2) {
- Imm |= (MI->getOperand(StartIndex + 1).getImm() << 32);
+ // Handle arbitrary number of 32-bit words for the literal value.
+ if (MI->getOpcode() == SPIRV::OpConstantI){
+ APInt Val(NumVarOps * 32, 0);
+ for (unsigned i = 0; i < NumVarOps; ++i) {
+ Val |= (APInt(NumVarOps * 32, MI->getOperand(StartIndex + i).getImm()) << (i * 32));
+ }
+ O << Val;
+ return;
}
+ uint64_t Imm = MI->getOperand(StartIndex).getImm();
+
// Format and print float values.
if (MI->getOpcode() == SPIRV::OpConstantF && IsBitwidth16 == 0) {
APFloat FP = NumVarOps == 1 ? APFloat(APInt(32, Imm).bitsToFloat())
diff --git a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
index 115766ce886c7..05b3371e97cdc 100644
--- a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
@@ -149,7 +149,7 @@ SPIRVType *SPIRVGlobalRegistry::getOpTypeBool(MachineIRBuilder &MIRBuilder) {
}
unsigned SPIRVGlobalRegistry::adjustOpTypeIntWidth(unsigned Width) const {
- if (Width > 64)
+ if (Width > 1024)
report_fatal_error("Unsupported integer width!");
const SPIRVSubtarget &ST = cast<SPIRVSubtarget>(CurMF->getSubtarget());
if (ST.canUseExtension(
@@ -343,7 +343,7 @@ Register SPIRVGlobalRegistry::createConstFP(const ConstantFP *CF,
return Res;
}
-Register SPIRVGlobalRegistry::getOrCreateConstInt(uint64_t Val, MachineInstr &I,
+Register SPIRVGlobalRegistry::getOrCreateConstInt(APInt Val, MachineInstr &I,
SPIRVType *SpvType,
const SPIRVInstrInfo &TII,
bool ZeroAsNull) {
@@ -353,10 +353,11 @@ Register SPIRVGlobalRegistry::getOrCreateConstInt(uint64_t Val, MachineInstr &I,
if (MI && (MI->getOpcode() == SPIRV::OpConstantNull ||
MI->getOpcode() == SPIRV::OpConstantI))
return MI->getOperand(0).getReg();
- return createConstInt(CI, I, SpvType, TII, ZeroAsNull);
+ return createConstInt(CI, Val, I, SpvType, TII, ZeroAsNull);
}
-Register SPIRVGlobalRegistry::createConstInt(const ConstantInt *CI,
+Register SPIRVGlobalRegistry::createConstInt(const Constant *CI,
+ APInt Val,
MachineInstr &I,
SPIRVType *SpvType,
const SPIRVInstrInfo &TII,
@@ -374,15 +375,15 @@ Register SPIRVGlobalRegistry::createConstInt(const ConstantInt *CI,
MachineInstrBuilder MIB;
if (BitWidth == 1) {
MIB = MIRBuilder
- .buildInstr(CI->isZero() ? SPIRV::OpConstantFalse
+ .buildInstr(Val.isZero() ? SPIRV::OpConstantFalse
: SPIRV::OpConstantTrue)
.addDef(Res)
.addUse(getSPIRVTypeID(SpvType));
- } else if (!CI->isZero() || !ZeroAsNull) {
+ } else if (!Val.isZero() || !ZeroAsNull) {
MIB = MIRBuilder.buildInstr(SPIRV::OpConstantI)
.addDef(Res)
.addUse(getSPIRVTypeID(SpvType));
- addNumImm(APInt(BitWidth, CI->getZExtValue()), MIB);
+ addNumImm(Val, MIB);
} else {
MIB = MIRBuilder.buildInstr(SPIRV::OpConstantNull)
.addDef(Res)
@@ -491,7 +492,7 @@ Register SPIRVGlobalRegistry::getOrCreateBaseRegister(
}
assert(Type->getOpcode() == SPIRV::OpTypeInt);
SPIRVType *SpvBaseType = getOrCreateSPIRVIntegerType(BitWidth, I, TII);
- return getOrCreateConstInt(Val->getUniqueInteger().getZExtValue(), I,
+ return getOrCreateConstInt(APInt(BitWidth, Val->getUniqueInteger().getZExtValue()), I,
SpvBaseType, TII, ZeroAsNull);
}
diff --git a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h
index a648defa0a888..ee217f81fb416 100644
--- a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h
+++ b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h
@@ -515,10 +515,10 @@ class SPIRVGlobalRegistry : public SPIRVIRMapping {
Register buildConstantInt(uint64_t Val, MachineIRBuilder &MIRBuilder,
SPIRVType *SpvType, bool EmitIR,
bool ZeroAsNull = true);
- Register getOrCreateConstInt(uint64_t Val, MachineInstr &I,
+ Register getOrCreateConstInt(APInt Val, MachineInstr &I,
SPIRVType *SpvType, const SPIRVInstrInfo &TII,
bool ZeroAsNull = true);
- Register createConstInt(const ConstantInt *CI, MachineInstr &I,
+ Register createConstInt(const Constant *CI, APInt Val, MachineInstr &I,
SPIRVType *SpvType, const SPIRVInstrInfo &TII,
bool ZeroAsNull);
Register getOrCreateConstFP(APFloat Val, MachineInstr &I, SPIRVType *SpvType,
diff --git a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
index 1aadd9df189a8..3e5566945ec0b 100644
--- a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
@@ -2252,8 +2252,8 @@ bool SPIRVInstructionSelector::selectDot4AddPackedExpansion(
.addDef(AElt)
.addUse(GR.getSPIRVTypeID(ResType))
.addUse(X)
- .addUse(GR.getOrCreateConstInt(i * 8, I, EltType, TII, ZeroAsNull))
- .addUse(GR.getOrCreateConstInt(8, I, EltType, TII, ZeroAsNull))
+ .addUse(GR.getOrCreateConstInt(APInt(8, i * 8), I, EltType, TII, ZeroAsNull))
+ .addUse(GR.getOrCreateConstInt(APInt(8, 8), I, EltType, TII, ZeroAsNull))
.constrainAllUses(TII, TRI, RBI);
// B[i]
@@ -2263,8 +2263,8 @@ bool SPIRVInstructionSelector::selectDot4AddPackedExpansion(
.addDef(BElt)
.addUse(GR.getSPIRVTypeID(ResType))
.addUse(Y)
- .addUse(GR.getOrCreateConstInt(i * 8, I, EltType, TII, ZeroAsNull))
- .addUse(GR.getOrCreateConstInt(8, I, EltType, TII, ZeroAsNull))
+ .addUse(GR.getOrCreateConstInt(APInt(8, i * 8), I, EltType, TII, ZeroAsNull))
+ .addUse(GR.getOrCreateConstInt(APInt(8, 8), I, EltType, TII, ZeroAsNull))
.constrainAllUses(TII, TRI, RBI);
// A[i] * B[i]
@@ -2283,8 +2283,8 @@ bool SPIRVInstructionSelector::selectDot4AddPackedExpansion(
.addDef(MaskMul)
.addUse(GR.getSPIRVTypeID(ResType))
.addUse(Mul)
- .addUse(GR.getOrCreateConstInt(0, I, EltType, TII, ZeroAsNull))
- .addUse(GR.getOrCreateConstInt(8, I, EltType, TII, ZeroAsNull))
+ .addUse(GR.getOrCreateConstInt(APInt(8, 0), I, EltType, TII, ZeroAsNull))
+ .addUse(GR.getOrCreateConstInt(APInt(8, 8), I, EltType, TII, ZeroAsNull))
.constrainAllUses(TII, TRI, RBI);
// Acc = Acc + A[i] * B[i]
@@ -2381,7 +2381,7 @@ bool SPIRVInstructionSelector::selectWaveOpInst(Register ResVReg,
auto BMI = BuildMI(BB, I, I.getDebugLoc(), TII.get(Opcode))
.addDef(ResVReg)
.addUse(GR.getSPIRVTypeID(ResType))
- .addUse(GR.getOrCreateConstInt(SPIRV::Scope::Subgroup, I,
+ .addUse(GR.getOrCreateConstInt(APInt(32, SPIRV::Scope::Subgroup), I,
IntTy, TII, !STI.isShader()));
for (unsigned J = 2; J < I.getNumOperands(); J++) {
@@ -2405,7 +2405,7 @@ bool SPIRVInstructionSelector::selectWaveActiveCountBits(
TII.get(SPIRV::OpGroupNonUniformBallotBitCount))
.addDef(ResVReg)
.addUse(GR.getSPIRVTypeID(ResType))
- .addUse(GR.getOrCreateConstInt(SPIRV::Scope::Subgroup, I, IntTy,
+ .addUse(GR.getOrCreateConstInt(APInt(32, SPIRV::Scope::Subgroup), I, IntTy,
TII, !STI.isShader()))
.addImm(SPIRV::GroupOperation::Reduce)
.addUse(BallotReg)
@@ -2436,7 +2436,7 @@ bool SPIRVInstructionSelector::selectWaveReduceMax(Register ResVReg,
return BuildMI(BB, I, I.getDebugLoc(), TII.get(Opcode))
.addDef(ResVReg)
.addUse(GR.getSPIRVTypeID(ResType))
- .addUse(GR.getOrCreateConstInt(SPIRV::Scope::Subgroup, I, IntTy, TII,
+ .addUse(GR.getOrCreateConstInt(APInt(32, SPIRV::Scope::Subgroup), I, IntTy, TII,
!STI.isShader()))
.addImm(SPIRV::GroupOperation::Reduce)
.addUse(I.getOperand(2).getReg())
@@ -2463,7 +2463,7 @@ bool SPIRVInstructionSelector::selectWaveReduceSum(Register ResVReg,
return BuildMI(BB, I, I.getDebugLoc(), TII.get(Opcode))
.addDef(ResVReg)
.addUse(GR.getSPIRVTypeID(ResType))
- .addUse(GR.getOrCreateConstInt(SPIRV::Scope::Subgroup, I, IntTy, TII,
+ .addUse(GR.getOrCreateConstInt(APInt(32, SPIRV::Scope::Subgroup), I, IntTy, TII,
!STI.isShader()))
.addImm(SPIRV::GroupOperation::Reduce)
.addUse(I.getOperand(2).getReg());
@@ -2689,7 +2689,7 @@ Register SPIRVInstructionSelector::buildZerosVal(const SPIRVType *ResType,
bool ZeroAsNull = !STI.isShader();
if (ResType->getOpcode() == SPIRV::OpTypeVector)
return GR.getOrCreateConstVector(0UL, I, ResType, TII, ZeroAsNull);
- return GR.getOrCreateConstInt(0, I, ResType, TII, ZeroAsNull);
+ return GR.getOrCreateConstInt(APInt(GR.getScalarOrVectorBitWidth(ResType), 0), I, ResType, TII, ZeroAsNull);
}
Register SPIRVInstructionSelector::buildZerosValF(const SPIRVType *ResType,
@@ -2720,7 +2720,7 @@ Register SPIRVInstructionSelector::buildOnesVal(bool AllOnes,
AllOnes ? APInt::getAllOnes(BitWidth) : APInt::getOneBitSet(BitWidth, 0);
if (ResType->getOpcode() == SPIRV::OpTypeVector)
return GR.getOrCreateConstVector(One.getZExtValue(), I, ResType, TII);
- return GR.getOrCreateConstInt(One.getZExtValue(), I, ResType, TII);
+ return GR.getOrCreateConstInt(APInt(GR.getScalarOrVectorBitWidth(ResType), One.getZExtValue()), I, ResType, TII);
}
bool SPIRVInstructionSelector::selectSelect(Register ResVReg,
@@ -2939,8 +2939,7 @@ bool SPIRVInstructionSelector::selectConst(Register ResVReg,
Reg = GR.getOrCreateConstFP(I.getOperand(1).getFPImm()->getValue(), I,
ResType, TII, !STI.isShader());
} else {
- Reg = GR.getOrCreateConstInt(I.getOperand(1).getCImm()->getZExtValue(), I,
- ResType, TII, !STI.isShader());
+ Reg = GR.getOrCreateConstInt(I.getOperand(1).getCImm()->getValue(), I, ResType, TII, !STI.isShader());
}
return Reg == ResVReg ? true : BuildCOPY(ResVReg, Reg, I);
}
@@ -3765,7 +3764,7 @@ bool SPIRVInstructionSelector::selectFirstBitSet64Overflow(
bool ZeroAsNull = !STI.isShader();
Register FinalElemReg = MRI->createVirtualRegister(GR.getRegClass(I64Type));
Register ConstIntLastIdx = GR.getOrCreateConstInt(
- ComponentCount - 1, I, BaseType, TII, ZeroAsNull);
+ APInt(GR.getScalarOrVectorBitWidth(BaseType), ComponentCount - 1), I, BaseType, TII, ZeroAsNull);
if (!selectOpWithSrcs(FinalElemReg, I64Type, I, {SrcReg, ConstIntLastIdx},
SPIRV::OpVectorExtractDynamic))
@@ -3794,9 +3793,9 @@ bool SPIRVInstructionSelector::selectFirstBitSet64(
SPIRVType *BaseType = GR.retrieveScalarOrVectorIntType(ResType);
bool ZeroAsNull = !STI.isShader();
Register ConstIntZero =
- GR.getOrCreateConstInt(0, I, BaseType, TII, ZeroAsNull);
+ GR.getOrCreateConstInt(APInt(GR.getScalarOrVectorBitWidth(BaseType), 0), I, BaseType, TII, ZeroAsNull);
Register ConstIntOne =
- GR.getOrCreateConstInt(1, I, BaseType, TII, ZeroAsNull);
+ GR.getOrCreateConstInt(APInt(GR.getScalarOrVectorBitWidth(BaseType), 1), I, BaseType, TII, ZeroAsNull);
// SPIRV doesn't support vectors with more than 4 components. Since the
// algoritm below converts i64 -> i32x2 and i64x4 -> i32x8 it can only
@@ -3881,9 +3880,9 @@ bool SPIRVInstructionSelector::selectFirstBitSet64(
if (IsScalarRes) {
NegOneReg =
- GR.getOrCreateConstInt((unsigned)-1, I, ResType, TII, ZeroAsNull);
- Reg0 = GR.getOrCreateConstInt(0, I, ResType, TII, ZeroAsNull);
- Reg32 = GR.getOrCreateConstInt(32, I, ResType, TII, ZeroAsNull);
+ GR.getOrCreateConstInt(APInt(GR.getScalarOrVectorBitWidth(ResType), (unsigned)-1), I, ResType, TII, ZeroAsNull);
+ Reg0 = GR.getOrCreateConstInt(APInt(GR.getScalarOrVectorBitWidth(ResType), 0), I, ResType, TII, ZeroAsNull);
+ Reg32 = GR.getOrCreateConstInt(APInt(GR.getScalarOrVectorBitWidth(ResType), 32), I, ResType, TII, ZeroAsNull);
SelectOp = SPIRV::OpSelectSISCond;
AddOp = SPIRV::OpIAddS;
} else {
diff --git a/llvm/lib/Target/SPIRV/SPIRVUtils.cpp b/llvm/lib/Target/SPIRV/SPIRVUtils.cpp
index 820e56b362edc..e409234a83568 100644
--- a/llvm/lib/Target/SPIRV/SPIRVUtils.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVUtils.cpp
@@ -100,11 +100,12 @@ void addNumImm(const APInt &Imm, MachineInstrBuilder &MIB) {
if (Bitwidth == 16)
MIB.getInstr()->setAsmPrinterFlag(SPIRV::ASM_PRINTER_WIDTH16);
return;
- } else if (Bitwidth <= 64) {
- uint64_t FullImm = Imm.getZExtValue();
- uint32_t LowBits = FullImm & 0xffffffff;
- uint32_t HighBits = (FullImm >> 32) & 0xffffffff;
- MIB.addImm(LowBits).addImm(HighBits);
+ } else if (Bitwidth <= 1024) {
+ unsigned NumWords = (Bitwidth + 31) / 32;
+ for (unsigned i = 0; i < NumWords; ++i) {
+ uint32_t Word = Imm.extractBits(32, i * 32).getZExtValue();
+ MIB.addImm(Word);
+ }
return;
}
report_fatal_error("Unsupported constant bitwidth");
diff --git a/llvm/test/CodeGen/SPIRV/extensions/SPV_INTEL_arbitrary_precision_integers.ll b/llvm/test/CodeGen/SPIRV/extensions/SPV_INTEL_arbitrary_precision_integers.ll
index 41d4b58ed1157..17ba9b044842c 100644
--- a/llvm/test/CodeGen/SPIRV/extensions/SPV_INTEL_arbitrary_precision_integers.ll
+++ b/llvm/test/CodeGen/SPIRV/extensions/SPV_INTEL_arbitrary_precision_integers.ll
@@ -8,6 +8,10 @@ define i13 @getConstantI13() {
ret i13 42
}
+define i96 @getConstantI96() {
+ ret i96 18446744073709551620
+}
+
;; Capabilities:
; CHECK-DAG: OpExtension "SPV_INTEL_arbitrary_precision_integers"
; CHECK-DAG: OpCapability ArbitraryPrecisionIntegersINTEL
>From 8578a56ff0bd41fc78ed0702f62f1c9d3233968a Mon Sep 17 00:00:00 2001
From: "Zhang, Yixing" <yixing.zhang at intel.com>
Date: Mon, 6 Oct 2025 15:22:56 -0700
Subject: [PATCH 2/3] Extend the arbitrary-precision integer test with i96 and i160 constant checks
---
.../SPV_INTEL_arbitrary_precision_integers.ll | 20 ++++++++++++++++++-
1 file changed, 19 insertions(+), 1 deletion(-)
diff --git a/llvm/test/CodeGen/SPIRV/extensions/SPV_INTEL_arbitrary_precision_integers.ll b/llvm/test/CodeGen/SPIRV/extensions/SPV_INTEL_arbitrary_precision_integers.ll
index 17ba9b044842c..003a900c73770 100644
--- a/llvm/test/CodeGen/SPIRV/extensions/SPV_INTEL_arbitrary_precision_integers.ll
+++ b/llvm/test/CodeGen/SPIRV/extensions/SPV_INTEL_arbitrary_precision_integers.ll
@@ -10,7 +10,11 @@ define i13 @getConstantI13() {
define i96 @getConstantI96() {
ret i96 18446744073709551620
-}
+}
+
+define i160 @getConstantI160() {
+ ret i160 3363637389930338837376336738763689377839373638
+}
;; Capabilities:
; CHECK-DAG: OpExtension "SPV_INTEL_arbitrary_precision_integers"
@@ -21,14 +25,20 @@ define i96 @getConstantI96() {
;; Names:
; CHECK-DAG: OpName %[[#GET_I6:]] "getConstantI6"
; CHECK-DAG: OpName %[[#GET_I13:]] "getConstantI13"
+; CHECK-DAG: OpName %[[#GET_I96:]] "getConstantI96"
+; CHECK-DAG: OpName %[[#GET_I160:]] "getConstantI160"
; CHECK-NOT: DAG-FENCE
;; Types and Constants:
; CHECK-DAG: %[[#I6:]] = OpTypeInt 6 0
; CHECK-DAG: %[[#I13:]] = OpTypeInt 13 0
+; CHECK-DAG: %[[#I96:]] = OpTypeInt 96 0
+; CHECK-DAG: %[[#I160:]] = OpTypeInt 160 0
; CHECK-DAG: %[[#CST_I6:]] = OpConstant %[[#I6]] 2
; CHECK-DAG: %[[#CST_I13:]] = OpConstant %[[#I13]] 42
+; CHECK-DAG: %[[#CST_I96:]] = OpConstant %[[#I96]] 18446744073709551620
+; CHECK-DAG: %[[#CST_I160:]] = OpConstant %[[#I160]] 3363637389930338837376336738763689377839373638
; CHECK: %[[#GET_I6]] = OpFunction %[[#I6]]
; CHECK: OpReturnValue %[[#CST_I6]]
@@ -37,3 +47,11 @@ define i96 @getConstantI96() {
; CHECK: %[[#GET_I13]] = OpFunction %[[#I13]]
; CHECK: OpReturnValue %[[#CST_I13]]
; CHECK: OpFunctionEnd
+
+; CHECK: %[[#GET_I96]] = OpFunction %[[#I96]]
+; CHECK: OpReturnValue %[[#CST_I96]]
+; CHECK: OpFunctionEnd
+
+; CHECK: %[[#GET_I160]] = OpFunction %[[#I160]]
+; CHECK: OpReturnValue %[[#CST_I160]]
+; CHECK: OpFunctionEnd
\ No newline at end of file
>From fd4a78f7119aca99afd7ac924231b10ddfffb28c Mon Sep 17 00:00:00 2001
From: "Zhang, Yixing" <yixing.zhang at intel.com>
Date: Mon, 6 Oct 2025 15:56:33 -0700
Subject: [PATCH 3/3] Code cleanup: reformat and correct the literal word-count assertion
---
.../SPIRV/MCTargetDesc/SPIRVInstPrinter.cpp | 9 +-
llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp | 8 +-
llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h | 4 +-
.../Target/SPIRV/SPIRVInstructionSelector.cpp | 117 ++++++++++--------
.../SPV_INTEL_arbitrary_precision_integers.ll | 2 +-
5 files changed, 79 insertions(+), 61 deletions(-)
diff --git a/llvm/lib/Target/SPIRV/MCTargetDesc/SPIRVInstPrinter.cpp b/llvm/lib/Target/SPIRV/MCTargetDesc/SPIRVInstPrinter.cpp
index dff9f699ebd6f..9529e18da21c7 100644
--- a/llvm/lib/Target/SPIRV/MCTargetDesc/SPIRVInstPrinter.cpp
+++ b/llvm/lib/Target/SPIRV/MCTargetDesc/SPIRVInstPrinter.cpp
@@ -50,17 +50,18 @@ void SPIRVInstPrinter::printOpConstantVarOps(const MCInst *MI,
unsigned IsBitwidth16 = MI->getFlags() & SPIRV::INST_PRINTER_WIDTH16;
const unsigned NumVarOps = MI->getNumOperands() - StartIndex;
- // we support integer up to 1024 bits
- assert((NumVarOps <= 1024) &&
+ // We support integer up to 1024 bits
+ assert((NumVarOps <= 32) &&
"Unsupported number of bits for literal variable");
O << ' ';
// Handle arbitrary number of 32-bit words for the literal value.
- if (MI->getOpcode() == SPIRV::OpConstantI){
+ if (MI->getOpcode() == SPIRV::OpConstantI) {
APInt Val(NumVarOps * 32, 0);
for (unsigned i = 0; i < NumVarOps; ++i) {
- Val |= (APInt(NumVarOps * 32, MI->getOperand(StartIndex + i).getImm()) << (i * 32));
+ Val |= (APInt(NumVarOps * 32, MI->getOperand(StartIndex + i).getImm())
+ << (i * 32));
}
O << Val;
return;
diff --git a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
index 05b3371e97cdc..c4f8565f39b84 100644
--- a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
@@ -356,8 +356,7 @@ Register SPIRVGlobalRegistry::getOrCreateConstInt(APInt Val, MachineInstr &I,
return createConstInt(CI, Val, I, SpvType, TII, ZeroAsNull);
}
-Register SPIRVGlobalRegistry::createConstInt(const Constant *CI,
- APInt Val,
+Register SPIRVGlobalRegistry::createConstInt(const Constant *CI, APInt Val,
MachineInstr &I,
SPIRVType *SpvType,
const SPIRVInstrInfo &TII,
@@ -492,8 +491,9 @@ Register SPIRVGlobalRegistry::getOrCreateBaseRegister(
}
assert(Type->getOpcode() == SPIRV::OpTypeInt);
SPIRVType *SpvBaseType = getOrCreateSPIRVIntegerType(BitWidth, I, TII);
- return getOrCreateConstInt(APInt(BitWidth, Val->getUniqueInteger().getZExtValue()), I,
- SpvBaseType, TII, ZeroAsNull);
+ return getOrCreateConstInt(
+ APInt(BitWidth, Val->getUniqueInteger().getZExtValue()), I, SpvBaseType,
+ TII, ZeroAsNull);
}
Register SPIRVGlobalRegistry::getOrCreateCompositeOrNull(
diff --git a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h
index ee217f81fb416..9cb7d982c3fc2 100644
--- a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h
+++ b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h
@@ -515,8 +515,8 @@ class SPIRVGlobalRegistry : public SPIRVIRMapping {
Register buildConstantInt(uint64_t Val, MachineIRBuilder &MIRBuilder,
SPIRVType *SpvType, bool EmitIR,
bool ZeroAsNull = true);
- Register getOrCreateConstInt(APInt Val, MachineInstr &I,
- SPIRVType *SpvType, const SPIRVInstrInfo &TII,
+ Register getOrCreateConstInt(APInt Val, MachineInstr &I, SPIRVType *SpvType,
+ const SPIRVInstrInfo &TII,
bool ZeroAsNull = true);
Register createConstInt(const Constant *CI, APInt Val, MachineInstr &I,
SPIRVType *SpvType, const SPIRVInstrInfo &TII,
diff --git a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
index 3e5566945ec0b..f82ddbc8990b6 100644
--- a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
@@ -2247,33 +2247,36 @@ bool SPIRVInstructionSelector::selectDot4AddPackedExpansion(
for (unsigned i = 0; i < 4; i++) {
// A[i]
Register AElt = MRI->createVirtualRegister(&SPIRV::IDRegClass);
- Result &=
- BuildMI(BB, I, I.getDebugLoc(), TII.get(ExtractOp))
- .addDef(AElt)
- .addUse(GR.getSPIRVTypeID(ResType))
- .addUse(X)
- .addUse(GR.getOrCreateConstInt(APInt(8, i * 8), I, EltType, TII, ZeroAsNull))
- .addUse(GR.getOrCreateConstInt(APInt(8, 8), I, EltType, TII, ZeroAsNull))
- .constrainAllUses(TII, TRI, RBI);
+ Result &= BuildMI(BB, I, I.getDebugLoc(), TII.get(ExtractOp))
+ .addDef(AElt)
+ .addUse(GR.getSPIRVTypeID(ResType))
+ .addUse(X)
+ .addUse(GR.getOrCreateConstInt(APInt(8, i * 8), I, EltType,
+ TII, ZeroAsNull))
+ .addUse(GR.getOrCreateConstInt(APInt(8, 8), I, EltType, TII,
+ ZeroAsNull))
+ .constrainAllUses(TII, TRI, RBI);
// B[i]
- Register BElt = MRI->createVirtualRegister(&SPIRV::IDRegClass);
- Result &=
- BuildMI(BB, I, I.getDebugLoc(), TII.get(ExtractOp))
- .addDef(BElt)
- .addUse(GR.getSPIRVTypeID(ResType))
- .addUse(Y)
- .addUse(GR.getOrCreateConstInt(APInt(8, i * 8), I, EltType, TII, ZeroAsNull))
- .addUse(GR.getOrCreateConstInt(APInt(8, 8), I, EltType, TII, ZeroAsNull))
- .constrainAllUses(TII, TRI, RBI);
+ Result &= BuildMI(BB, I, I.getDebugLoc(), TII.get(ExtractOp))
+ .addDef(BElt)
+ .addUse(GR.getSPIRVTypeID(ResType))
+ .addUse(Y)
+ .addUse(GR.getOrCreateConstInt(APInt(8, i * 8), I, EltType,
+ TII, ZeroAsNull))
+ .addUse(GR.getOrCreateConstInt(APInt(8, 8), I, EltType, TII,
+ ZeroAsNull))
+ .constrainAllUses(TII, TRI, RBI);
// A[i] * B[i]
- Register Mul = MRI->createVirtualRegister(&SPIRV::IDRegClass);
- Result &= BuildMI(BB, I, I.getDebugLoc(), TII.get(SPIRV::OpIMulS))
- .addDef(Mul)
+ Result &= BuildMI(BB, I, I.getDebugLoc(), TII.get(ExtractOp))
+ .addDef(MaskMul)
.addUse(GR.getSPIRVTypeID(ResType))
- .addUse(AElt)
- .addUse(BElt)
+ .addUse(Mul)
+ .addUse(GR.getOrCreateConstInt(APInt(8, 0), I, EltType, TII,
+ ZeroAsNull))
+ .addUse(GR.getOrCreateConstInt(APInt(8, 8), I, EltType, TII,
+ ZeroAsNull))
.constrainAllUses(TII, TRI, RBI);
// Discard 24 highest-bits so that stored i32 register is i8 equivalent
@@ -2378,11 +2381,12 @@ bool SPIRVInstructionSelector::selectWaveOpInst(Register ResVReg,
MachineBasicBlock &BB = *I.getParent();
SPIRVType *IntTy = GR.getOrCreateSPIRVIntegerType(32, I, TII);
- auto BMI = BuildMI(BB, I, I.getDebugLoc(), TII.get(Opcode))
- .addDef(ResVReg)
- .addUse(GR.getSPIRVTypeID(ResType))
- .addUse(GR.getOrCreateConstInt(APInt(32, SPIRV::Scope::Subgroup), I,
- IntTy, TII, !STI.isShader()));
+ auto BMI =
+ BuildMI(BB, I, I.getDebugLoc(), TII.get(Opcode))
+ .addDef(ResVReg)
+ .addUse(GR.getSPIRVTypeID(ResType))
+ .addUse(GR.getOrCreateConstInt(APInt(32, SPIRV::Scope::Subgroup), I,
+ IntTy, TII, !STI.isShader()));
for (unsigned J = 2; J < I.getNumOperands(); J++) {
BMI.addUse(I.getOperand(J).getReg());
@@ -2401,15 +2405,16 @@ bool SPIRVInstructionSelector::selectWaveActiveCountBits(
SPIRV::OpGroupNonUniformBallot);
MachineBasicBlock &BB = *I.getParent();
- Result &= BuildMI(BB, I, I.getDebugLoc(),
- TII.get(SPIRV::OpGroupNonUniformBallotBitCount))
- .addDef(ResVReg)
- .addUse(GR.getSPIRVTypeID(ResType))
- .addUse(GR.getOrCreateConstInt(APInt(32, SPIRV::Scope::Subgroup), I, IntTy,
- TII, !STI.isShader()))
- .addImm(SPIRV::GroupOperation::Reduce)
- .addUse(BallotReg)
- .constrainAllUses(TII, TRI, RBI);
+ Result &=
+ BuildMI(BB, I, I.getDebugLoc(),
+ TII.get(SPIRV::OpGroupNonUniformBallotBitCount))
+ .addDef(ResVReg)
+ .addUse(GR.getSPIRVTypeID(ResType))
+ .addUse(GR.getOrCreateConstInt(APInt(32, SPIRV::Scope::Subgroup), I,
+ IntTy, TII, !STI.isShader()))
+ .addImm(SPIRV::GroupOperation::Reduce)
+ .addUse(BallotReg)
+ .constrainAllUses(TII, TRI, RBI);
return Result;
}
@@ -2436,8 +2441,8 @@ bool SPIRVInstructionSelector::selectWaveReduceMax(Register ResVReg,
return BuildMI(BB, I, I.getDebugLoc(), TII.get(Opcode))
.addDef(ResVReg)
.addUse(GR.getSPIRVTypeID(ResType))
- .addUse(GR.getOrCreateConstInt(APInt(32, SPIRV::Scope::Subgroup), I, IntTy, TII,
- !STI.isShader()))
+ .addUse(GR.getOrCreateConstInt(APInt(32, SPIRV::Scope::Subgroup), I,
+ IntTy, TII, !STI.isShader()))
.addImm(SPIRV::GroupOperation::Reduce)
.addUse(I.getOperand(2).getReg())
.constrainAllUses(TII, TRI, RBI);
@@ -2463,8 +2468,8 @@ bool SPIRVInstructionSelector::selectWaveReduceSum(Register ResVReg,
return BuildMI(BB, I, I.getDebugLoc(), TII.get(Opcode))
.addDef(ResVReg)
.addUse(GR.getSPIRVTypeID(ResType))
- .addUse(GR.getOrCreateConstInt(APInt(32, SPIRV::Scope::Subgroup), I, IntTy, TII,
- !STI.isShader()))
+ .addUse(GR.getOrCreateConstInt(APInt(32, SPIRV::Scope::Subgroup), I,
+ IntTy, TII, !STI.isShader()))
.addImm(SPIRV::GroupOperation::Reduce)
.addUse(I.getOperand(2).getReg());
}
@@ -2689,7 +2694,8 @@ Register SPIRVInstructionSelector::buildZerosVal(const SPIRVType *ResType,
bool ZeroAsNull = !STI.isShader();
if (ResType->getOpcode() == SPIRV::OpTypeVector)
return GR.getOrCreateConstVector(0UL, I, ResType, TII, ZeroAsNull);
- return GR.getOrCreateConstInt(APInt(GR.getScalarOrVectorBitWidth(ResType), 0), I, ResType, TII, ZeroAsNull);
+ return GR.getOrCreateConstInt(APInt(GR.getScalarOrVectorBitWidth(ResType), 0),
+ I, ResType, TII, ZeroAsNull);
}
Register SPIRVInstructionSelector::buildZerosValF(const SPIRVType *ResType,
@@ -2720,7 +2726,9 @@ Register SPIRVInstructionSelector::buildOnesVal(bool AllOnes,
AllOnes ? APInt::getAllOnes(BitWidth) : APInt::getOneBitSet(BitWidth, 0);
if (ResType->getOpcode() == SPIRV::OpTypeVector)
return GR.getOrCreateConstVector(One.getZExtValue(), I, ResType, TII);
- return GR.getOrCreateConstInt(APInt(GR.getScalarOrVectorBitWidth(ResType), One.getZExtValue()), I, ResType, TII);
+ return GR.getOrCreateConstInt(
+ APInt(GR.getScalarOrVectorBitWidth(ResType), One.getZExtValue()), I,
+ ResType, TII);
}
bool SPIRVInstructionSelector::selectSelect(Register ResVReg,
@@ -2939,7 +2947,8 @@ bool SPIRVInstructionSelector::selectConst(Register ResVReg,
Reg = GR.getOrCreateConstFP(I.getOperand(1).getFPImm()->getValue(), I,
ResType, TII, !STI.isShader());
} else {
- Reg = GR.getOrCreateConstInt(I.getOperand(1).getCImm()->getValue(), I, ResType, TII, !STI.isShader());
+ Reg = GR.getOrCreateConstInt(I.getOperand(1).getCImm()->getValue(), I,
+ ResType, TII, !STI.isShader());
}
return Reg == ResVReg ? true : BuildCOPY(ResVReg, Reg, I);
}
@@ -3764,7 +3773,8 @@ bool SPIRVInstructionSelector::selectFirstBitSet64Overflow(
bool ZeroAsNull = !STI.isShader();
Register FinalElemReg = MRI->createVirtualRegister(GR.getRegClass(I64Type));
Register ConstIntLastIdx = GR.getOrCreateConstInt(
- APInt(GR.getScalarOrVectorBitWidth(BaseType), ComponentCount - 1), I, BaseType, TII, ZeroAsNull);
+ APInt(GR.getScalarOrVectorBitWidth(BaseType), ComponentCount - 1), I,
+ BaseType, TII, ZeroAsNull);
if (!selectOpWithSrcs(FinalElemReg, I64Type, I, {SrcReg, ConstIntLastIdx},
SPIRV::OpVectorExtractDynamic))
@@ -3793,9 +3803,11 @@ bool SPIRVInstructionSelector::selectFirstBitSet64(
SPIRVType *BaseType = GR.retrieveScalarOrVectorIntType(ResType);
bool ZeroAsNull = !STI.isShader();
Register ConstIntZero =
- GR.getOrCreateConstInt(APInt(GR.getScalarOrVectorBitWidth(BaseType), 0), I, BaseType, TII, ZeroAsNull);
+ GR.getOrCreateConstInt(APInt(GR.getScalarOrVectorBitWidth(BaseType), 0),
+ I, BaseType, TII, ZeroAsNull);
Register ConstIntOne =
- GR.getOrCreateConstInt(APInt(GR.getScalarOrVectorBitWidth(BaseType), 1), I, BaseType, TII, ZeroAsNull);
+ GR.getOrCreateConstInt(APInt(GR.getScalarOrVectorBitWidth(BaseType), 1),
+ I, BaseType, TII, ZeroAsNull);
// SPIRV doesn't support vectors with more than 4 components. Since the
// algoritm below converts i64 -> i32x2 and i64x4 -> i32x8 it can only
@@ -3879,10 +3891,15 @@ bool SPIRVInstructionSelector::selectFirstBitSet64(
unsigned AddOp;
if (IsScalarRes) {
- NegOneReg =
- GR.getOrCreateConstInt(APInt(GR.getScalarOrVectorBitWidth(ResType), (unsigned)-1), I, ResType, TII, ZeroAsNull);
- Reg0 = GR.getOrCreateConstInt(APInt(GR.getScalarOrVectorBitWidth(ResType), 0), I, ResType, TII, ZeroAsNull);
- Reg32 = GR.getOrCreateConstInt(APInt(GR.getScalarOrVectorBitWidth(ResType), 32), I, ResType, TII, ZeroAsNull);
+ NegOneReg = GR.getOrCreateConstInt(
+ APInt(GR.getScalarOrVectorBitWidth(ResType), (unsigned)-1), I, ResType,
+ TII, ZeroAsNull);
+ Reg0 =
+ GR.getOrCreateConstInt(APInt(GR.getScalarOrVectorBitWidth(ResType), 0),
+ I, ResType, TII, ZeroAsNull);
+ Reg32 =
+ GR.getOrCreateConstInt(APInt(GR.getScalarOrVectorBitWidth(ResType), 32),
+ I, ResType, TII, ZeroAsNull);
SelectOp = SPIRV::OpSelectSISCond;
AddOp = SPIRV::OpIAddS;
} else {
diff --git a/llvm/test/CodeGen/SPIRV/extensions/SPV_INTEL_arbitrary_precision_integers.ll b/llvm/test/CodeGen/SPIRV/extensions/SPV_INTEL_arbitrary_precision_integers.ll
index 003a900c73770..23681d660dd20 100644
--- a/llvm/test/CodeGen/SPIRV/extensions/SPV_INTEL_arbitrary_precision_integers.ll
+++ b/llvm/test/CodeGen/SPIRV/extensions/SPV_INTEL_arbitrary_precision_integers.ll
@@ -54,4 +54,4 @@ define i160 @getConstantI160() {
; CHECK: %[[#GET_I160]] = OpFunction %[[#I160]]
; CHECK: OpReturnValue %[[#CST_I160]]
-; CHECK: OpFunctionEnd
\ No newline at end of file
+; CHECK: OpFunctionEnd
More information about the llvm-commits
mailing list