[llvm] [SPIRV] Add support for arbitrary-precision integers larger than 64 bits in SPIR-V backend (PR #161270)

via llvm-commits llvm-commits at lists.llvm.org
Mon Oct 6 15:56:44 PDT 2025


https://github.com/YixingZhang007 updated https://github.com/llvm/llvm-project/pull/161270

>From 4c1d90c01082af32ee28a04edd3a0fc58a4ac1b9 Mon Sep 17 00:00:00 2001
From: "Zhang, Yixing" <yixing.zhang at intel.com>
Date: Sat, 27 Sep 2025 08:58:37 -0700
Subject: [PATCH 1/3] Add support for arbitrary-precision integers with bit
 widths larger than 64 bits in the SPIR-V backend

---
 .../SPIRV/MCTargetDesc/SPIRVInstPrinter.cpp   | 18 ++++++---
 llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp | 17 ++++----
 llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h   |  4 +-
 .../Target/SPIRV/SPIRVInstructionSelector.cpp | 39 +++++++++----------
 llvm/lib/Target/SPIRV/SPIRVUtils.cpp          | 11 +++---
 .../SPV_INTEL_arbitrary_precision_integers.ll |  4 ++
 6 files changed, 52 insertions(+), 41 deletions(-)

diff --git a/llvm/lib/Target/SPIRV/MCTargetDesc/SPIRVInstPrinter.cpp b/llvm/lib/Target/SPIRV/MCTargetDesc/SPIRVInstPrinter.cpp
index 776208bd3e693..dff9f699ebd6f 100644
--- a/llvm/lib/Target/SPIRV/MCTargetDesc/SPIRVInstPrinter.cpp
+++ b/llvm/lib/Target/SPIRV/MCTargetDesc/SPIRVInstPrinter.cpp
@@ -50,18 +50,24 @@ void SPIRVInstPrinter::printOpConstantVarOps(const MCInst *MI,
   unsigned IsBitwidth16 = MI->getFlags() & SPIRV::INST_PRINTER_WIDTH16;
   const unsigned NumVarOps = MI->getNumOperands() - StartIndex;
 
-  assert((NumVarOps == 1 || NumVarOps == 2) &&
+  // we support integer up to 1024 bits
+  assert((NumVarOps <= 1024) &&
          "Unsupported number of bits for literal variable");
 
   O << ' ';
 
-  uint64_t Imm = MI->getOperand(StartIndex).getImm();
-
-  // Handle 64 bit literals.
-  if (NumVarOps == 2) {
-    Imm |= (MI->getOperand(StartIndex + 1).getImm() << 32);
+  // Handle arbitrary number of 32-bit words for the literal value.
+  if (MI->getOpcode() == SPIRV::OpConstantI){
+    APInt Val(NumVarOps * 32, 0);
+    for (unsigned i = 0; i < NumVarOps; ++i) {
+      Val |= (APInt(NumVarOps * 32, MI->getOperand(StartIndex + i).getImm()) << (i * 32));
+    }
+    O << Val;
+    return;
   }
 
+  uint64_t Imm = MI->getOperand(StartIndex).getImm();
+
   // Format and print float values.
   if (MI->getOpcode() == SPIRV::OpConstantF && IsBitwidth16 == 0) {
     APFloat FP = NumVarOps == 1 ? APFloat(APInt(32, Imm).bitsToFloat())
diff --git a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
index 115766ce886c7..05b3371e97cdc 100644
--- a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
@@ -149,7 +149,7 @@ SPIRVType *SPIRVGlobalRegistry::getOpTypeBool(MachineIRBuilder &MIRBuilder) {
 }
 
 unsigned SPIRVGlobalRegistry::adjustOpTypeIntWidth(unsigned Width) const {
-  if (Width > 64)
+  if (Width > 1024)
     report_fatal_error("Unsupported integer width!");
   const SPIRVSubtarget &ST = cast<SPIRVSubtarget>(CurMF->getSubtarget());
   if (ST.canUseExtension(
@@ -343,7 +343,7 @@ Register SPIRVGlobalRegistry::createConstFP(const ConstantFP *CF,
   return Res;
 }
 
-Register SPIRVGlobalRegistry::getOrCreateConstInt(uint64_t Val, MachineInstr &I,
+Register SPIRVGlobalRegistry::getOrCreateConstInt(APInt Val, MachineInstr &I,
                                                   SPIRVType *SpvType,
                                                   const SPIRVInstrInfo &TII,
                                                   bool ZeroAsNull) {
@@ -353,10 +353,11 @@ Register SPIRVGlobalRegistry::getOrCreateConstInt(uint64_t Val, MachineInstr &I,
   if (MI && (MI->getOpcode() == SPIRV::OpConstantNull ||
              MI->getOpcode() == SPIRV::OpConstantI))
     return MI->getOperand(0).getReg();
-  return createConstInt(CI, I, SpvType, TII, ZeroAsNull);
+  return createConstInt(CI, Val, I, SpvType, TII, ZeroAsNull);
 }
 
-Register SPIRVGlobalRegistry::createConstInt(const ConstantInt *CI,
+Register SPIRVGlobalRegistry::createConstInt(const Constant *CI,
+                                             APInt Val,
                                              MachineInstr &I,
                                              SPIRVType *SpvType,
                                              const SPIRVInstrInfo &TII,
@@ -374,15 +375,15 @@ Register SPIRVGlobalRegistry::createConstInt(const ConstantInt *CI,
         MachineInstrBuilder MIB;
         if (BitWidth == 1) {
           MIB = MIRBuilder
-                    .buildInstr(CI->isZero() ? SPIRV::OpConstantFalse
+                    .buildInstr(Val.isZero() ? SPIRV::OpConstantFalse
                                              : SPIRV::OpConstantTrue)
                     .addDef(Res)
                     .addUse(getSPIRVTypeID(SpvType));
-        } else if (!CI->isZero() || !ZeroAsNull) {
+        } else if (!Val.isZero() || !ZeroAsNull) {
           MIB = MIRBuilder.buildInstr(SPIRV::OpConstantI)
                     .addDef(Res)
                     .addUse(getSPIRVTypeID(SpvType));
-          addNumImm(APInt(BitWidth, CI->getZExtValue()), MIB);
+          addNumImm(Val, MIB);
         } else {
           MIB = MIRBuilder.buildInstr(SPIRV::OpConstantNull)
                     .addDef(Res)
@@ -491,7 +492,7 @@ Register SPIRVGlobalRegistry::getOrCreateBaseRegister(
   }
   assert(Type->getOpcode() == SPIRV::OpTypeInt);
   SPIRVType *SpvBaseType = getOrCreateSPIRVIntegerType(BitWidth, I, TII);
-  return getOrCreateConstInt(Val->getUniqueInteger().getZExtValue(), I,
+  return getOrCreateConstInt(APInt(BitWidth, Val->getUniqueInteger().getZExtValue()), I,
                              SpvBaseType, TII, ZeroAsNull);
 }
 
diff --git a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h
index a648defa0a888..ee217f81fb416 100644
--- a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h
+++ b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h
@@ -515,10 +515,10 @@ class SPIRVGlobalRegistry : public SPIRVIRMapping {
   Register buildConstantInt(uint64_t Val, MachineIRBuilder &MIRBuilder,
                             SPIRVType *SpvType, bool EmitIR,
                             bool ZeroAsNull = true);
-  Register getOrCreateConstInt(uint64_t Val, MachineInstr &I,
+  Register getOrCreateConstInt(APInt Val, MachineInstr &I,
                                SPIRVType *SpvType, const SPIRVInstrInfo &TII,
                                bool ZeroAsNull = true);
-  Register createConstInt(const ConstantInt *CI, MachineInstr &I,
+  Register createConstInt(const Constant *CI, APInt Val, MachineInstr &I,
                           SPIRVType *SpvType, const SPIRVInstrInfo &TII,
                           bool ZeroAsNull);
   Register getOrCreateConstFP(APFloat Val, MachineInstr &I, SPIRVType *SpvType,
diff --git a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
index 1aadd9df189a8..3e5566945ec0b 100644
--- a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
@@ -2252,8 +2252,8 @@ bool SPIRVInstructionSelector::selectDot4AddPackedExpansion(
             .addDef(AElt)
             .addUse(GR.getSPIRVTypeID(ResType))
             .addUse(X)
-            .addUse(GR.getOrCreateConstInt(i * 8, I, EltType, TII, ZeroAsNull))
-            .addUse(GR.getOrCreateConstInt(8, I, EltType, TII, ZeroAsNull))
+            .addUse(GR.getOrCreateConstInt(APInt(8, i * 8), I, EltType, TII, ZeroAsNull))
+            .addUse(GR.getOrCreateConstInt(APInt(8, 8), I, EltType, TII, ZeroAsNull))
             .constrainAllUses(TII, TRI, RBI);
 
     // B[i]
@@ -2263,8 +2263,8 @@ bool SPIRVInstructionSelector::selectDot4AddPackedExpansion(
             .addDef(BElt)
             .addUse(GR.getSPIRVTypeID(ResType))
             .addUse(Y)
-            .addUse(GR.getOrCreateConstInt(i * 8, I, EltType, TII, ZeroAsNull))
-            .addUse(GR.getOrCreateConstInt(8, I, EltType, TII, ZeroAsNull))
+            .addUse(GR.getOrCreateConstInt(APInt(8, i * 8), I, EltType, TII, ZeroAsNull))
+            .addUse(GR.getOrCreateConstInt(APInt(8, 8), I, EltType, TII, ZeroAsNull))
             .constrainAllUses(TII, TRI, RBI);
 
     // A[i] * B[i]
@@ -2283,8 +2283,8 @@ bool SPIRVInstructionSelector::selectDot4AddPackedExpansion(
             .addDef(MaskMul)
             .addUse(GR.getSPIRVTypeID(ResType))
             .addUse(Mul)
-            .addUse(GR.getOrCreateConstInt(0, I, EltType, TII, ZeroAsNull))
-            .addUse(GR.getOrCreateConstInt(8, I, EltType, TII, ZeroAsNull))
+            .addUse(GR.getOrCreateConstInt(APInt(8, 0), I, EltType, TII, ZeroAsNull))
+            .addUse(GR.getOrCreateConstInt(APInt(8, 8), I, EltType, TII, ZeroAsNull))
             .constrainAllUses(TII, TRI, RBI);
 
     // Acc = Acc + A[i] * B[i]
@@ -2381,7 +2381,7 @@ bool SPIRVInstructionSelector::selectWaveOpInst(Register ResVReg,
   auto BMI = BuildMI(BB, I, I.getDebugLoc(), TII.get(Opcode))
                  .addDef(ResVReg)
                  .addUse(GR.getSPIRVTypeID(ResType))
-                 .addUse(GR.getOrCreateConstInt(SPIRV::Scope::Subgroup, I,
+                 .addUse(GR.getOrCreateConstInt(APInt(32, SPIRV::Scope::Subgroup), I,
                                                 IntTy, TII, !STI.isShader()));
 
   for (unsigned J = 2; J < I.getNumOperands(); J++) {
@@ -2405,7 +2405,7 @@ bool SPIRVInstructionSelector::selectWaveActiveCountBits(
                     TII.get(SPIRV::OpGroupNonUniformBallotBitCount))
                 .addDef(ResVReg)
                 .addUse(GR.getSPIRVTypeID(ResType))
-                .addUse(GR.getOrCreateConstInt(SPIRV::Scope::Subgroup, I, IntTy,
+                .addUse(GR.getOrCreateConstInt(APInt(32, SPIRV::Scope::Subgroup), I, IntTy,
                                                TII, !STI.isShader()))
                 .addImm(SPIRV::GroupOperation::Reduce)
                 .addUse(BallotReg)
@@ -2436,7 +2436,7 @@ bool SPIRVInstructionSelector::selectWaveReduceMax(Register ResVReg,
   return BuildMI(BB, I, I.getDebugLoc(), TII.get(Opcode))
       .addDef(ResVReg)
       .addUse(GR.getSPIRVTypeID(ResType))
-      .addUse(GR.getOrCreateConstInt(SPIRV::Scope::Subgroup, I, IntTy, TII,
+      .addUse(GR.getOrCreateConstInt(APInt(32, SPIRV::Scope::Subgroup), I, IntTy, TII,
                                      !STI.isShader()))
       .addImm(SPIRV::GroupOperation::Reduce)
       .addUse(I.getOperand(2).getReg())
@@ -2463,7 +2463,7 @@ bool SPIRVInstructionSelector::selectWaveReduceSum(Register ResVReg,
   return BuildMI(BB, I, I.getDebugLoc(), TII.get(Opcode))
       .addDef(ResVReg)
       .addUse(GR.getSPIRVTypeID(ResType))
-      .addUse(GR.getOrCreateConstInt(SPIRV::Scope::Subgroup, I, IntTy, TII,
+      .addUse(GR.getOrCreateConstInt(APInt(32, SPIRV::Scope::Subgroup), I, IntTy, TII,
                                      !STI.isShader()))
       .addImm(SPIRV::GroupOperation::Reduce)
       .addUse(I.getOperand(2).getReg());
@@ -2689,7 +2689,7 @@ Register SPIRVInstructionSelector::buildZerosVal(const SPIRVType *ResType,
   bool ZeroAsNull = !STI.isShader();
   if (ResType->getOpcode() == SPIRV::OpTypeVector)
     return GR.getOrCreateConstVector(0UL, I, ResType, TII, ZeroAsNull);
-  return GR.getOrCreateConstInt(0, I, ResType, TII, ZeroAsNull);
+  return GR.getOrCreateConstInt(APInt(GR.getScalarOrVectorBitWidth(ResType), 0), I, ResType, TII, ZeroAsNull);
 }
 
 Register SPIRVInstructionSelector::buildZerosValF(const SPIRVType *ResType,
@@ -2720,7 +2720,7 @@ Register SPIRVInstructionSelector::buildOnesVal(bool AllOnes,
       AllOnes ? APInt::getAllOnes(BitWidth) : APInt::getOneBitSet(BitWidth, 0);
   if (ResType->getOpcode() == SPIRV::OpTypeVector)
     return GR.getOrCreateConstVector(One.getZExtValue(), I, ResType, TII);
-  return GR.getOrCreateConstInt(One.getZExtValue(), I, ResType, TII);
+  return GR.getOrCreateConstInt(APInt(GR.getScalarOrVectorBitWidth(ResType), One.getZExtValue()), I, ResType, TII);
 }
 
 bool SPIRVInstructionSelector::selectSelect(Register ResVReg,
@@ -2939,8 +2939,7 @@ bool SPIRVInstructionSelector::selectConst(Register ResVReg,
     Reg = GR.getOrCreateConstFP(I.getOperand(1).getFPImm()->getValue(), I,
                                 ResType, TII, !STI.isShader());
   } else {
-    Reg = GR.getOrCreateConstInt(I.getOperand(1).getCImm()->getZExtValue(), I,
-                                 ResType, TII, !STI.isShader());
+    Reg = GR.getOrCreateConstInt(I.getOperand(1).getCImm()->getValue(), I, ResType, TII, !STI.isShader());
   }
   return Reg == ResVReg ? true : BuildCOPY(ResVReg, Reg, I);
 }
@@ -3765,7 +3764,7 @@ bool SPIRVInstructionSelector::selectFirstBitSet64Overflow(
     bool ZeroAsNull = !STI.isShader();
     Register FinalElemReg = MRI->createVirtualRegister(GR.getRegClass(I64Type));
     Register ConstIntLastIdx = GR.getOrCreateConstInt(
-        ComponentCount - 1, I, BaseType, TII, ZeroAsNull);
+        APInt(GR.getScalarOrVectorBitWidth(BaseType), ComponentCount - 1), I, BaseType, TII, ZeroAsNull);
 
     if (!selectOpWithSrcs(FinalElemReg, I64Type, I, {SrcReg, ConstIntLastIdx},
                           SPIRV::OpVectorExtractDynamic))
@@ -3794,9 +3793,9 @@ bool SPIRVInstructionSelector::selectFirstBitSet64(
   SPIRVType *BaseType = GR.retrieveScalarOrVectorIntType(ResType);
   bool ZeroAsNull = !STI.isShader();
   Register ConstIntZero =
-      GR.getOrCreateConstInt(0, I, BaseType, TII, ZeroAsNull);
+      GR.getOrCreateConstInt(APInt(GR.getScalarOrVectorBitWidth(BaseType), 0), I, BaseType, TII, ZeroAsNull);
   Register ConstIntOne =
-      GR.getOrCreateConstInt(1, I, BaseType, TII, ZeroAsNull);
+      GR.getOrCreateConstInt(APInt(GR.getScalarOrVectorBitWidth(BaseType), 1), I, BaseType, TII, ZeroAsNull);
 
   // SPIRV doesn't support vectors with more than 4 components. Since the
   // algoritm below converts i64 -> i32x2 and i64x4 -> i32x8 it can only
@@ -3881,9 +3880,9 @@ bool SPIRVInstructionSelector::selectFirstBitSet64(
 
   if (IsScalarRes) {
     NegOneReg =
-        GR.getOrCreateConstInt((unsigned)-1, I, ResType, TII, ZeroAsNull);
-    Reg0 = GR.getOrCreateConstInt(0, I, ResType, TII, ZeroAsNull);
-    Reg32 = GR.getOrCreateConstInt(32, I, ResType, TII, ZeroAsNull);
+        GR.getOrCreateConstInt(APInt(GR.getScalarOrVectorBitWidth(ResType), (unsigned)-1), I, ResType, TII, ZeroAsNull);
+    Reg0 = GR.getOrCreateConstInt(APInt(GR.getScalarOrVectorBitWidth(ResType), 0), I, ResType, TII, ZeroAsNull);
+    Reg32 = GR.getOrCreateConstInt(APInt(GR.getScalarOrVectorBitWidth(ResType), 32), I, ResType, TII, ZeroAsNull);
     SelectOp = SPIRV::OpSelectSISCond;
     AddOp = SPIRV::OpIAddS;
   } else {
diff --git a/llvm/lib/Target/SPIRV/SPIRVUtils.cpp b/llvm/lib/Target/SPIRV/SPIRVUtils.cpp
index 820e56b362edc..e409234a83568 100644
--- a/llvm/lib/Target/SPIRV/SPIRVUtils.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVUtils.cpp
@@ -100,11 +100,12 @@ void addNumImm(const APInt &Imm, MachineInstrBuilder &MIB) {
     if (Bitwidth == 16)
       MIB.getInstr()->setAsmPrinterFlag(SPIRV::ASM_PRINTER_WIDTH16);
     return;
-  } else if (Bitwidth <= 64) {
-    uint64_t FullImm = Imm.getZExtValue();
-    uint32_t LowBits = FullImm & 0xffffffff;
-    uint32_t HighBits = (FullImm >> 32) & 0xffffffff;
-    MIB.addImm(LowBits).addImm(HighBits);
+  } else if (Bitwidth <= 1024) {
+    unsigned NumWords = (Bitwidth + 31) / 32;
+    for (unsigned i = 0; i < NumWords; ++i) {
+      uint32_t Word = Imm.extractBits(32, i * 32).getZExtValue();
+      MIB.addImm(Word);
+    }
     return;
   }
   report_fatal_error("Unsupported constant bitwidth");
diff --git a/llvm/test/CodeGen/SPIRV/extensions/SPV_INTEL_arbitrary_precision_integers.ll b/llvm/test/CodeGen/SPIRV/extensions/SPV_INTEL_arbitrary_precision_integers.ll
index 41d4b58ed1157..17ba9b044842c 100644
--- a/llvm/test/CodeGen/SPIRV/extensions/SPV_INTEL_arbitrary_precision_integers.ll
+++ b/llvm/test/CodeGen/SPIRV/extensions/SPV_INTEL_arbitrary_precision_integers.ll
@@ -8,6 +8,10 @@ define i13 @getConstantI13() {
   ret i13 42
 }
 
+define i96 @getConstantI96() {
+  ret i96 18446744073709551620
+} 
+
 ;; Capabilities:
 ; CHECK-DAG: OpExtension "SPV_INTEL_arbitrary_precision_integers"
 ; CHECK-DAG: OpCapability ArbitraryPrecisionIntegersINTEL

>From 8578a56ff0bd41fc78ed0702f62f1c9d3233968a Mon Sep 17 00:00:00 2001
From: "Zhang, Yixing" <yixing.zhang at intel.com>
Date: Mon, 6 Oct 2025 15:22:56 -0700
Subject: [PATCH 2/3] Update the test: add i96 and i160 constant cases

---
 .../SPV_INTEL_arbitrary_precision_integers.ll | 20 ++++++++++++++++++-
 1 file changed, 19 insertions(+), 1 deletion(-)

diff --git a/llvm/test/CodeGen/SPIRV/extensions/SPV_INTEL_arbitrary_precision_integers.ll b/llvm/test/CodeGen/SPIRV/extensions/SPV_INTEL_arbitrary_precision_integers.ll
index 17ba9b044842c..003a900c73770 100644
--- a/llvm/test/CodeGen/SPIRV/extensions/SPV_INTEL_arbitrary_precision_integers.ll
+++ b/llvm/test/CodeGen/SPIRV/extensions/SPV_INTEL_arbitrary_precision_integers.ll
@@ -10,7 +10,11 @@ define i13 @getConstantI13() {
 
 define i96 @getConstantI96() {
   ret i96 18446744073709551620
-} 
+}
+
+define i160 @getConstantI160() {
+  ret i160 3363637389930338837376336738763689377839373638
+}
 
 ;; Capabilities:
 ; CHECK-DAG: OpExtension "SPV_INTEL_arbitrary_precision_integers"
@@ -21,14 +25,20 @@ define i96 @getConstantI96() {
 ;; Names:
 ; CHECK-DAG: OpName %[[#GET_I6:]] "getConstantI6"
 ; CHECK-DAG: OpName %[[#GET_I13:]] "getConstantI13"
+; CHECK-DAG: OpName %[[#GET_I96:]] "getConstantI96"
+; CHECK-DAG: OpName %[[#GET_I160:]] "getConstantI160"
 
 ; CHECK-NOT: DAG-FENCE
 
 ;; Types and Constants:
 ; CHECK-DAG: %[[#I6:]] = OpTypeInt 6 0
 ; CHECK-DAG: %[[#I13:]] = OpTypeInt 13 0
+; CHECK-DAG: %[[#I96:]] = OpTypeInt 96 0
+; CHECK-DAG: %[[#I160:]] = OpTypeInt 160 0
 ; CHECK-DAG: %[[#CST_I6:]] = OpConstant %[[#I6]] 2
 ; CHECK-DAG: %[[#CST_I13:]] = OpConstant %[[#I13]] 42
+; CHECK-DAG: %[[#CST_I96:]] = OpConstant %[[#I96]] 18446744073709551620
+; CHECK-DAG: %[[#CST_I160:]] = OpConstant %[[#I160]] 3363637389930338837376336738763689377839373638
 
 ; CHECK: %[[#GET_I6]] = OpFunction %[[#I6]]
 ; CHECK: OpReturnValue %[[#CST_I6]]
@@ -37,3 +47,11 @@ define i96 @getConstantI96() {
 ; CHECK: %[[#GET_I13]] = OpFunction %[[#I13]]
 ; CHECK: OpReturnValue %[[#CST_I13]]
 ; CHECK: OpFunctionEnd
+
+; CHECK: %[[#GET_I96]] = OpFunction %[[#I96]]
+; CHECK: OpReturnValue %[[#CST_I96]]
+; CHECK: OpFunctionEnd
+
+; CHECK: %[[#GET_I160]] = OpFunction %[[#I160]]
+; CHECK: OpReturnValue %[[#CST_I160]]
+; CHECK: OpFunctionEnd
\ No newline at end of file

>From fd4a78f7119aca99afd7ac924231b10ddfffb28c Mon Sep 17 00:00:00 2001
From: "Zhang, Yixing" <yixing.zhang at intel.com>
Date: Mon, 6 Oct 2025 15:56:33 -0700
Subject: [PATCH 3/3] Code clean-up: reformat and limit the literal assertion to 32 words (1024 bits)

---
 .../SPIRV/MCTargetDesc/SPIRVInstPrinter.cpp   |   9 +-
 llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp |   8 +-
 llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h   |   4 +-
 .../Target/SPIRV/SPIRVInstructionSelector.cpp | 117 ++++++++++--------
 .../SPV_INTEL_arbitrary_precision_integers.ll |   2 +-
 5 files changed, 79 insertions(+), 61 deletions(-)

diff --git a/llvm/lib/Target/SPIRV/MCTargetDesc/SPIRVInstPrinter.cpp b/llvm/lib/Target/SPIRV/MCTargetDesc/SPIRVInstPrinter.cpp
index dff9f699ebd6f..9529e18da21c7 100644
--- a/llvm/lib/Target/SPIRV/MCTargetDesc/SPIRVInstPrinter.cpp
+++ b/llvm/lib/Target/SPIRV/MCTargetDesc/SPIRVInstPrinter.cpp
@@ -50,17 +50,18 @@ void SPIRVInstPrinter::printOpConstantVarOps(const MCInst *MI,
   unsigned IsBitwidth16 = MI->getFlags() & SPIRV::INST_PRINTER_WIDTH16;
   const unsigned NumVarOps = MI->getNumOperands() - StartIndex;
 
-  // we support integer up to 1024 bits
-  assert((NumVarOps <= 1024) &&
+  // We support integer up to 1024 bits
+  assert((NumVarOps <= 32) &&
          "Unsupported number of bits for literal variable");
 
   O << ' ';
 
   // Handle arbitrary number of 32-bit words for the literal value.
-  if (MI->getOpcode() == SPIRV::OpConstantI){
+  if (MI->getOpcode() == SPIRV::OpConstantI) {
     APInt Val(NumVarOps * 32, 0);
     for (unsigned i = 0; i < NumVarOps; ++i) {
-      Val |= (APInt(NumVarOps * 32, MI->getOperand(StartIndex + i).getImm()) << (i * 32));
+      Val |= (APInt(NumVarOps * 32, MI->getOperand(StartIndex + i).getImm())
+              << (i * 32));
     }
     O << Val;
     return;
diff --git a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
index 05b3371e97cdc..c4f8565f39b84 100644
--- a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
@@ -356,8 +356,7 @@ Register SPIRVGlobalRegistry::getOrCreateConstInt(APInt Val, MachineInstr &I,
   return createConstInt(CI, Val, I, SpvType, TII, ZeroAsNull);
 }
 
-Register SPIRVGlobalRegistry::createConstInt(const Constant *CI,
-                                             APInt Val,
+Register SPIRVGlobalRegistry::createConstInt(const Constant *CI, APInt Val,
                                              MachineInstr &I,
                                              SPIRVType *SpvType,
                                              const SPIRVInstrInfo &TII,
@@ -492,8 +491,9 @@ Register SPIRVGlobalRegistry::getOrCreateBaseRegister(
   }
   assert(Type->getOpcode() == SPIRV::OpTypeInt);
   SPIRVType *SpvBaseType = getOrCreateSPIRVIntegerType(BitWidth, I, TII);
-  return getOrCreateConstInt(APInt(BitWidth, Val->getUniqueInteger().getZExtValue()), I,
-                             SpvBaseType, TII, ZeroAsNull);
+  return getOrCreateConstInt(
+      APInt(BitWidth, Val->getUniqueInteger().getZExtValue()), I, SpvBaseType,
+      TII, ZeroAsNull);
 }
 
 Register SPIRVGlobalRegistry::getOrCreateCompositeOrNull(
diff --git a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h
index ee217f81fb416..9cb7d982c3fc2 100644
--- a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h
+++ b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h
@@ -515,8 +515,8 @@ class SPIRVGlobalRegistry : public SPIRVIRMapping {
   Register buildConstantInt(uint64_t Val, MachineIRBuilder &MIRBuilder,
                             SPIRVType *SpvType, bool EmitIR,
                             bool ZeroAsNull = true);
-  Register getOrCreateConstInt(APInt Val, MachineInstr &I,
-                               SPIRVType *SpvType, const SPIRVInstrInfo &TII,
+  Register getOrCreateConstInt(APInt Val, MachineInstr &I, SPIRVType *SpvType,
+                               const SPIRVInstrInfo &TII,
                                bool ZeroAsNull = true);
   Register createConstInt(const Constant *CI, APInt Val, MachineInstr &I,
                           SPIRVType *SpvType, const SPIRVInstrInfo &TII,
diff --git a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
index 3e5566945ec0b..f82ddbc8990b6 100644
--- a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
@@ -2247,33 +2247,36 @@ bool SPIRVInstructionSelector::selectDot4AddPackedExpansion(
   for (unsigned i = 0; i < 4; i++) {
     // A[i]
     Register AElt = MRI->createVirtualRegister(&SPIRV::IDRegClass);
-    Result &=
-        BuildMI(BB, I, I.getDebugLoc(), TII.get(ExtractOp))
-            .addDef(AElt)
-            .addUse(GR.getSPIRVTypeID(ResType))
-            .addUse(X)
-            .addUse(GR.getOrCreateConstInt(APInt(8, i * 8), I, EltType, TII, ZeroAsNull))
-            .addUse(GR.getOrCreateConstInt(APInt(8, 8), I, EltType, TII, ZeroAsNull))
-            .constrainAllUses(TII, TRI, RBI);
+    Result &= BuildMI(BB, I, I.getDebugLoc(), TII.get(ExtractOp))
+                  .addDef(AElt)
+                  .addUse(GR.getSPIRVTypeID(ResType))
+                  .addUse(X)
+                  .addUse(GR.getOrCreateConstInt(APInt(8, i * 8), I, EltType,
+                                                 TII, ZeroAsNull))
+                  .addUse(GR.getOrCreateConstInt(APInt(8, 8), I, EltType, TII,
+                                                 ZeroAsNull))
+                  .constrainAllUses(TII, TRI, RBI);
 
     // B[i]
-    Register BElt = MRI->createVirtualRegister(&SPIRV::IDRegClass);
-    Result &=
-        BuildMI(BB, I, I.getDebugLoc(), TII.get(ExtractOp))
-            .addDef(BElt)
-            .addUse(GR.getSPIRVTypeID(ResType))
-            .addUse(Y)
-            .addUse(GR.getOrCreateConstInt(APInt(8, i * 8), I, EltType, TII, ZeroAsNull))
-            .addUse(GR.getOrCreateConstInt(APInt(8, 8), I, EltType, TII, ZeroAsNull))
-            .constrainAllUses(TII, TRI, RBI);
+    Result &= BuildMI(BB, I, I.getDebugLoc(), TII.get(ExtractOp))
+                  .addDef(BElt)
+                  .addUse(GR.getSPIRVTypeID(ResType))
+                  .addUse(Y)
+                  .addUse(GR.getOrCreateConstInt(APInt(8, i * 8), I, EltType,
+                                                 TII, ZeroAsNull))
+                  .addUse(GR.getOrCreateConstInt(APInt(8, 8), I, EltType, TII,
+                                                 ZeroAsNull))
+                  .constrainAllUses(TII, TRI, RBI);
 
     // A[i] * B[i]
-    Register Mul = MRI->createVirtualRegister(&SPIRV::IDRegClass);
-    Result &= BuildMI(BB, I, I.getDebugLoc(), TII.get(SPIRV::OpIMulS))
-                  .addDef(Mul)
+    Result &= BuildMI(BB, I, I.getDebugLoc(), TII.get(ExtractOp))
+                  .addDef(MaskMul)
                   .addUse(GR.getSPIRVTypeID(ResType))
-                  .addUse(AElt)
-                  .addUse(BElt)
+                  .addUse(Mul)
+                  .addUse(GR.getOrCreateConstInt(APInt(8, 0), I, EltType, TII,
+                                                 ZeroAsNull))
+                  .addUse(GR.getOrCreateConstInt(APInt(8, 8), I, EltType, TII,
+                                                 ZeroAsNull))
                   .constrainAllUses(TII, TRI, RBI);
 
     // Discard 24 highest-bits so that stored i32 register is i8 equivalent
@@ -2378,11 +2381,12 @@ bool SPIRVInstructionSelector::selectWaveOpInst(Register ResVReg,
   MachineBasicBlock &BB = *I.getParent();
   SPIRVType *IntTy = GR.getOrCreateSPIRVIntegerType(32, I, TII);
 
-  auto BMI = BuildMI(BB, I, I.getDebugLoc(), TII.get(Opcode))
-                 .addDef(ResVReg)
-                 .addUse(GR.getSPIRVTypeID(ResType))
-                 .addUse(GR.getOrCreateConstInt(APInt(32, SPIRV::Scope::Subgroup), I,
-                                                IntTy, TII, !STI.isShader()));
+  auto BMI =
+      BuildMI(BB, I, I.getDebugLoc(), TII.get(Opcode))
+          .addDef(ResVReg)
+          .addUse(GR.getSPIRVTypeID(ResType))
+          .addUse(GR.getOrCreateConstInt(APInt(32, SPIRV::Scope::Subgroup), I,
+                                         IntTy, TII, !STI.isShader()));
 
   for (unsigned J = 2; J < I.getNumOperands(); J++) {
     BMI.addUse(I.getOperand(J).getReg());
@@ -2401,15 +2405,16 @@ bool SPIRVInstructionSelector::selectWaveActiveCountBits(
                                  SPIRV::OpGroupNonUniformBallot);
 
   MachineBasicBlock &BB = *I.getParent();
-  Result &= BuildMI(BB, I, I.getDebugLoc(),
-                    TII.get(SPIRV::OpGroupNonUniformBallotBitCount))
-                .addDef(ResVReg)
-                .addUse(GR.getSPIRVTypeID(ResType))
-                .addUse(GR.getOrCreateConstInt(APInt(32, SPIRV::Scope::Subgroup), I, IntTy,
-                                               TII, !STI.isShader()))
-                .addImm(SPIRV::GroupOperation::Reduce)
-                .addUse(BallotReg)
-                .constrainAllUses(TII, TRI, RBI);
+  Result &=
+      BuildMI(BB, I, I.getDebugLoc(),
+              TII.get(SPIRV::OpGroupNonUniformBallotBitCount))
+          .addDef(ResVReg)
+          .addUse(GR.getSPIRVTypeID(ResType))
+          .addUse(GR.getOrCreateConstInt(APInt(32, SPIRV::Scope::Subgroup), I,
+                                         IntTy, TII, !STI.isShader()))
+          .addImm(SPIRV::GroupOperation::Reduce)
+          .addUse(BallotReg)
+          .constrainAllUses(TII, TRI, RBI);
 
   return Result;
 }
@@ -2436,8 +2441,8 @@ bool SPIRVInstructionSelector::selectWaveReduceMax(Register ResVReg,
   return BuildMI(BB, I, I.getDebugLoc(), TII.get(Opcode))
       .addDef(ResVReg)
       .addUse(GR.getSPIRVTypeID(ResType))
-      .addUse(GR.getOrCreateConstInt(APInt(32, SPIRV::Scope::Subgroup), I, IntTy, TII,
-                                     !STI.isShader()))
+      .addUse(GR.getOrCreateConstInt(APInt(32, SPIRV::Scope::Subgroup), I,
+                                     IntTy, TII, !STI.isShader()))
       .addImm(SPIRV::GroupOperation::Reduce)
       .addUse(I.getOperand(2).getReg())
       .constrainAllUses(TII, TRI, RBI);
@@ -2463,8 +2468,8 @@ bool SPIRVInstructionSelector::selectWaveReduceSum(Register ResVReg,
   return BuildMI(BB, I, I.getDebugLoc(), TII.get(Opcode))
       .addDef(ResVReg)
       .addUse(GR.getSPIRVTypeID(ResType))
-      .addUse(GR.getOrCreateConstInt(APInt(32, SPIRV::Scope::Subgroup), I, IntTy, TII,
-                                     !STI.isShader()))
+      .addUse(GR.getOrCreateConstInt(APInt(32, SPIRV::Scope::Subgroup), I,
+                                     IntTy, TII, !STI.isShader()))
       .addImm(SPIRV::GroupOperation::Reduce)
       .addUse(I.getOperand(2).getReg());
 }
@@ -2689,7 +2694,8 @@ Register SPIRVInstructionSelector::buildZerosVal(const SPIRVType *ResType,
   bool ZeroAsNull = !STI.isShader();
   if (ResType->getOpcode() == SPIRV::OpTypeVector)
     return GR.getOrCreateConstVector(0UL, I, ResType, TII, ZeroAsNull);
-  return GR.getOrCreateConstInt(APInt(GR.getScalarOrVectorBitWidth(ResType), 0), I, ResType, TII, ZeroAsNull);
+  return GR.getOrCreateConstInt(APInt(GR.getScalarOrVectorBitWidth(ResType), 0),
+                                I, ResType, TII, ZeroAsNull);
 }
 
 Register SPIRVInstructionSelector::buildZerosValF(const SPIRVType *ResType,
@@ -2720,7 +2726,9 @@ Register SPIRVInstructionSelector::buildOnesVal(bool AllOnes,
       AllOnes ? APInt::getAllOnes(BitWidth) : APInt::getOneBitSet(BitWidth, 0);
   if (ResType->getOpcode() == SPIRV::OpTypeVector)
     return GR.getOrCreateConstVector(One.getZExtValue(), I, ResType, TII);
-  return GR.getOrCreateConstInt(APInt(GR.getScalarOrVectorBitWidth(ResType), One.getZExtValue()), I, ResType, TII);
+  return GR.getOrCreateConstInt(
+      APInt(GR.getScalarOrVectorBitWidth(ResType), One.getZExtValue()), I,
+      ResType, TII);
 }
 
 bool SPIRVInstructionSelector::selectSelect(Register ResVReg,
@@ -2939,7 +2947,8 @@ bool SPIRVInstructionSelector::selectConst(Register ResVReg,
     Reg = GR.getOrCreateConstFP(I.getOperand(1).getFPImm()->getValue(), I,
                                 ResType, TII, !STI.isShader());
   } else {
-    Reg = GR.getOrCreateConstInt(I.getOperand(1).getCImm()->getValue(), I, ResType, TII, !STI.isShader());
+    Reg = GR.getOrCreateConstInt(I.getOperand(1).getCImm()->getValue(), I,
+                                 ResType, TII, !STI.isShader());
   }
   return Reg == ResVReg ? true : BuildCOPY(ResVReg, Reg, I);
 }
@@ -3764,7 +3773,8 @@ bool SPIRVInstructionSelector::selectFirstBitSet64Overflow(
     bool ZeroAsNull = !STI.isShader();
     Register FinalElemReg = MRI->createVirtualRegister(GR.getRegClass(I64Type));
     Register ConstIntLastIdx = GR.getOrCreateConstInt(
-        APInt(GR.getScalarOrVectorBitWidth(BaseType), ComponentCount - 1), I, BaseType, TII, ZeroAsNull);
+        APInt(GR.getScalarOrVectorBitWidth(BaseType), ComponentCount - 1), I,
+        BaseType, TII, ZeroAsNull);
 
     if (!selectOpWithSrcs(FinalElemReg, I64Type, I, {SrcReg, ConstIntLastIdx},
                           SPIRV::OpVectorExtractDynamic))
@@ -3793,9 +3803,11 @@ bool SPIRVInstructionSelector::selectFirstBitSet64(
   SPIRVType *BaseType = GR.retrieveScalarOrVectorIntType(ResType);
   bool ZeroAsNull = !STI.isShader();
   Register ConstIntZero =
-      GR.getOrCreateConstInt(APInt(GR.getScalarOrVectorBitWidth(BaseType), 0), I, BaseType, TII, ZeroAsNull);
+      GR.getOrCreateConstInt(APInt(GR.getScalarOrVectorBitWidth(BaseType), 0),
+                             I, BaseType, TII, ZeroAsNull);
   Register ConstIntOne =
-      GR.getOrCreateConstInt(APInt(GR.getScalarOrVectorBitWidth(BaseType), 1), I, BaseType, TII, ZeroAsNull);
+      GR.getOrCreateConstInt(APInt(GR.getScalarOrVectorBitWidth(BaseType), 1),
+                             I, BaseType, TII, ZeroAsNull);
 
   // SPIRV doesn't support vectors with more than 4 components. Since the
   // algoritm below converts i64 -> i32x2 and i64x4 -> i32x8 it can only
@@ -3879,10 +3891,15 @@ bool SPIRVInstructionSelector::selectFirstBitSet64(
   unsigned AddOp;
 
   if (IsScalarRes) {
-    NegOneReg =
-        GR.getOrCreateConstInt(APInt(GR.getScalarOrVectorBitWidth(ResType), (unsigned)-1), I, ResType, TII, ZeroAsNull);
-    Reg0 = GR.getOrCreateConstInt(APInt(GR.getScalarOrVectorBitWidth(ResType), 0), I, ResType, TII, ZeroAsNull);
-    Reg32 = GR.getOrCreateConstInt(APInt(GR.getScalarOrVectorBitWidth(ResType), 32), I, ResType, TII, ZeroAsNull);
+    NegOneReg = GR.getOrCreateConstInt(
+        APInt(GR.getScalarOrVectorBitWidth(ResType), (unsigned)-1), I, ResType,
+        TII, ZeroAsNull);
+    Reg0 =
+        GR.getOrCreateConstInt(APInt(GR.getScalarOrVectorBitWidth(ResType), 0),
+                               I, ResType, TII, ZeroAsNull);
+    Reg32 =
+        GR.getOrCreateConstInt(APInt(GR.getScalarOrVectorBitWidth(ResType), 32),
+                               I, ResType, TII, ZeroAsNull);
     SelectOp = SPIRV::OpSelectSISCond;
     AddOp = SPIRV::OpIAddS;
   } else {
diff --git a/llvm/test/CodeGen/SPIRV/extensions/SPV_INTEL_arbitrary_precision_integers.ll b/llvm/test/CodeGen/SPIRV/extensions/SPV_INTEL_arbitrary_precision_integers.ll
index 003a900c73770..23681d660dd20 100644
--- a/llvm/test/CodeGen/SPIRV/extensions/SPV_INTEL_arbitrary_precision_integers.ll
+++ b/llvm/test/CodeGen/SPIRV/extensions/SPV_INTEL_arbitrary_precision_integers.ll
@@ -54,4 +54,4 @@ define i160 @getConstantI160() {
 
 ; CHECK: %[[#GET_I160]] = OpFunction %[[#I160]]
 ; CHECK: OpReturnValue %[[#CST_I160]]
-; CHECK: OpFunctionEnd
\ No newline at end of file
+; CHECK: OpFunctionEnd



More information about the llvm-commits mailing list