[llvm] [NFC][SPIRV] Re-work extension parsing (PR #171826)

Alex Voicu via llvm-commits llvm-commits at lists.llvm.org
Thu Dec 11 05:11:56 PST 2025


https://github.com/AlexVlx created https://github.com/llvm/llvm-project/pull/171826

This changes the extension parsing mechanism underpinning `--spirv-ext` to be more explicit about what it is doing and to no longer rely on a sort. More specifically, we partition the extension tokens into those being enabled (prefixed with `+`) and all others, and then handle the two resulting ranges individually.

>From d63fa6f3a04f17ce99073e9f1f70dfbdf7ee1c42 Mon Sep 17 00:00:00 2001
From: Alex Voicu <alexandru.voicu at amd.com>
Date: Fri, 5 Dec 2025 03:29:18 +0000
Subject: [PATCH 1/5] Add partial support for `int128`.

---
 .../SPIRV/MCTargetDesc/SPIRVInstPrinter.cpp   |  8 +++
 llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp | 23 +++----
 llvm/lib/Target/SPIRV/SPIRVLegalizerInfo.cpp  |  3 +-
 llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp | 11 +++
 llvm/lib/Target/SPIRV/SPIRVUtils.cpp          |  7 ++
 .../i128-addsub.ll                            | 67 +++++++++++++++++++
 .../i128-arith.ll                             | 27 ++++++++
 .../i128-switch-lower.ll                      | 27 ++++++++
 8 files changed, 158 insertions(+), 15 deletions(-)
 create mode 100644 llvm/test/CodeGen/SPIRV/extensions/SPV_ALTERA_arbitrary_precision_integers/i128-addsub.ll
 create mode 100644 llvm/test/CodeGen/SPIRV/extensions/SPV_ALTERA_arbitrary_precision_integers/i128-arith.ll
 create mode 100644 llvm/test/CodeGen/SPIRV/extensions/SPV_ALTERA_arbitrary_precision_integers/i128-switch-lower.ll

diff --git a/llvm/lib/Target/SPIRV/MCTargetDesc/SPIRVInstPrinter.cpp b/llvm/lib/Target/SPIRV/MCTargetDesc/SPIRVInstPrinter.cpp
index 62f5e47c5ea3b..42de884840dc9 100644
--- a/llvm/lib/Target/SPIRV/MCTargetDesc/SPIRVInstPrinter.cpp
+++ b/llvm/lib/Target/SPIRV/MCTargetDesc/SPIRVInstPrinter.cpp
@@ -50,6 +50,14 @@ void SPIRVInstPrinter::printOpConstantVarOps(const MCInst *MI,
   unsigned IsBitwidth16 = MI->getFlags() & SPIRV::INST_PRINTER_WIDTH16;
   const unsigned NumVarOps = MI->getNumOperands() - StartIndex;
 
+  if (MI->getOpcode() == SPIRV::OpConstantI && NumVarOps > 2) {
+    // SPV_ALTERA_arbitrary_precision_integers allows for integer widths greater
+    // than 64, which will be encoded via multiple operands.
+    for (unsigned I = StartIndex; I != MI->getNumOperands(); ++I)
+      O << ' ' << MI->getOperand(I).getImm();
+    return;
+  }
+
   assert((NumVarOps == 1 || NumVarOps == 2) &&
          "Unsupported number of bits for literal variable");
 
diff --git a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
index 0fb44052527f0..693795d09dacd 100644
--- a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
@@ -151,22 +151,17 @@ SPIRVType *SPIRVGlobalRegistry::getOpTypeBool(MachineIRBuilder &MIRBuilder) {
 }
 
 unsigned SPIRVGlobalRegistry::adjustOpTypeIntWidth(unsigned Width) const {
-  if (Width > 64)
-    report_fatal_error("Unsupported integer width!");
-  const SPIRVSubtarget &ST = cast<SPIRVSubtarget>(CurMF->getSubtarget());
-  if (ST.canUseExtension(
-          SPIRV::Extension::SPV_ALTERA_arbitrary_precision_integers) ||
-      ST.canUseExtension(SPIRV::Extension::SPV_INTEL_int4))
-    return Width;
   if (Width <= 8)
-    Width = 8;
+    return 8;
   else if (Width <= 16)
-    Width = 16;
+    return 16;
   else if (Width <= 32)
-    Width = 32;
-  else
-    Width = 64;
-  return Width;
+    return 32;
+  else if (Width <= 64)
+    return 64;
+  else if (Width <= 128)
+    return 128;
+  reportFatalUsageError("Unsupported Integer width!");
 }
 
 SPIRVType *SPIRVGlobalRegistry::getOpTypeInt(unsigned Width,
@@ -413,7 +408,7 @@ Register SPIRVGlobalRegistry::createConstInt(const ConstantInt *CI,
           MIB = MIRBuilder.buildInstr(SPIRV::OpConstantI)
                     .addDef(Res)
                     .addUse(getSPIRVTypeID(SpvType));
-          addNumImm(APInt(BitWidth, CI->getZExtValue()), MIB);
+          addNumImm(CI->getValue(), MIB);
         } else {
           MIB = MIRBuilder.buildInstr(SPIRV::OpConstantNull)
                     .addDef(Res)
diff --git a/llvm/lib/Target/SPIRV/SPIRVLegalizerInfo.cpp b/llvm/lib/Target/SPIRV/SPIRVLegalizerInfo.cpp
index b5912c27316c9..38d4f7d706083 100644
--- a/llvm/lib/Target/SPIRV/SPIRVLegalizerInfo.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVLegalizerInfo.cpp
@@ -48,6 +48,7 @@ SPIRVLegalizerInfo::SPIRVLegalizerInfo(const SPIRVSubtarget &ST) {
   const LLT s16 = LLT::scalar(16);
   const LLT s32 = LLT::scalar(32);
   const LLT s64 = LLT::scalar(64);
+  const LLT s128 = LLT::scalar(128);
 
   const LLT v16s64 = LLT::fixed_vector(16, 64);
   const LLT v16s32 = LLT::fixed_vector(16, 32);
@@ -307,7 +308,7 @@ SPIRVLegalizerInfo::SPIRVLegalizerInfo(const SPIRVSubtarget &ST) {
           typeInSet(1, allPtrsScalarsAndVectors)));
 
   getActionDefinitionsBuilder({G_IMPLICIT_DEF, G_FREEZE})
-      .legalFor({s1})
+      .legalFor({s1, s128})
       .legalFor(allFloatAndIntScalarsAndPtrs)
       .legalFor(allowedVectorTypes)
       .moreElementsToNextPow2(0)
diff --git a/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp b/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp
index 2feb73d8dedfa..832139f7d7354 100644
--- a/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp
@@ -1395,6 +1395,17 @@ void addInstrRequirements(const MachineInstr &MI,
       Reqs.addCapability(SPIRV::Capability::Int16);
     else if (BitWidth == 8)
       Reqs.addCapability(SPIRV::Capability::Int8);
+    else {
+      if (!ST.canUseExtension(
+              SPIRV::Extension::SPV_ALTERA_arbitrary_precision_integers))
+        reportFatalUsageError(
+            "OpTypeInt type with a width other than 8, 16, 32 or 64 bits "
+            "requires the following SPIR-V extension: "
+            "SPV_ALTERA_arbitrary_precision_integers");
+      Reqs.addExtension(
+              SPIRV::Extension::SPV_ALTERA_arbitrary_precision_integers);
+      Reqs.addCapability(SPIRV::Capability::ArbitraryPrecisionIntegersALTERA);
+    }
     break;
   }
   case SPIRV::OpDot: {
diff --git a/llvm/lib/Target/SPIRV/SPIRVUtils.cpp b/llvm/lib/Target/SPIRV/SPIRVUtils.cpp
index 7fdb0fafa3719..6feae31c8370f 100644
--- a/llvm/lib/Target/SPIRV/SPIRVUtils.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVUtils.cpp
@@ -108,6 +108,13 @@ void addNumImm(const APInt &Imm, MachineInstrBuilder &MIB) {
     // Asm Printer needs this info to print 64-bit operands correctly
     MIB.getInstr()->setAsmPrinterFlag(SPIRV::ASM_PRINTER_WIDTH64);
     return;
+  } else if (Bitwidth <= 128) {
+    uint32_t LowBits = Imm.getRawData()[0] & 0xffffffff;
+    uint32_t MidBits0 = (Imm.getRawData()[0] >> 32) & 0xffffffff;
+    uint32_t MidBits1 = Imm.getRawData()[1] & 0xffffffff;
+    uint32_t HighBits = (Imm.getRawData()[1] >> 32) & 0xffffffff;
+    MIB.addImm(LowBits).addImm(MidBits0).addImm(MidBits1).addImm(HighBits);
+    return;
   }
   report_fatal_error("Unsupported constant bitwidth");
 }
diff --git a/llvm/test/CodeGen/SPIRV/extensions/SPV_ALTERA_arbitrary_precision_integers/i128-addsub.ll b/llvm/test/CodeGen/SPIRV/extensions/SPV_ALTERA_arbitrary_precision_integers/i128-addsub.ll
new file mode 100644
index 0000000000000..c90ffdd17996c
--- /dev/null
+++ b/llvm/test/CodeGen/SPIRV/extensions/SPV_ALTERA_arbitrary_precision_integers/i128-addsub.ll
@@ -0,0 +1,67 @@
+; RUN: not llc -O0 -mtriple=spirv64-unknown-unknown %s -o %t.spvt 2>&1 | FileCheck %s --check-prefix=CHECK-ERROR
+
+; RUN: llc -O0 -verify-machineinstrs -mtriple=spirv64-unknown-unknown --spirv-ext=+SPV_ALTERA_arbitrary_precision_integers %s -o - | FileCheck %s
+; TODO: %if spirv-tools %{ llc -O0 -mtriple=spirv64-unknown-unknown --spirv-ext=+SPV_ALTERA_arbitrary_precision_integers %s -o - -filetype=obj | spirv-val %}
+
+; CHECK-ERROR: LLVM ERROR: OpTypeInt type with a width other than 8, 16, 32 or 64 bits requires the following SPIR-V extension: SPV_ALTERA_arbitrary_precision_integers
+
+; CHECK: OpCapability ArbitraryPrecisionIntegersALTERA
+; CHECK: OpExtension "SPV_ALTERA_arbitrary_precision_integers"
+; CHECK: OpName %[[#TestAdd:]] "test_add"
+; CHECK: OpName %[[#TestSub:]] "test_sub"
+; CHECK: %[[#Int128Ty:]] = OpTypeInt 128 0
+; CHECK: %[[#Const64Int128:]] = OpConstant %[[#Int128Ty]] 64 0 0 0
+
+; CHECK: %[[#TestAdd]] = OpFunction
+define spir_func void @test_add(i64 %AL, i64 %AH, i64 %BL, i64 %BH, ptr %RL, ptr %RH) {
+entry:
+; CHECK: {{.*}} = OpUConvert %[[#Int128Ty]]
+; CHECK: {{.*}} = OpUConvert %[[#Int128Ty]]
+; CHECK: {{.*}} = OpShiftLeftLogical %[[#Int128Ty]] {{%[0-9]+}} %[[#Const64Int128]]
+; CHECK: {{.*}} = OpBitwiseOr %[[#Int128Ty]]
+; CHECK: {{.*}} = OpUConvert %[[#Int128Ty]]
+; CHECK: {{.*}} = OpIAdd %[[#Int128Ty]]
+	%tmp1 = zext i64 %AL to i128
+	%tmp23 = zext i64 %AH to i128
+	%tmp4 = shl i128 %tmp23, 64
+	%tmp5 = or i128 %tmp4, %tmp1
+	%tmp67 = zext i64 %BL to i128
+	%tmp89 = zext i64 %BH to i128
+	%tmp11 = shl i128 %tmp89, 64
+	%tmp12 = or i128 %tmp11, %tmp67
+	%tmp15 = add i128 %tmp12, %tmp5
+	%tmp1617 = trunc i128 %tmp15 to i64
+	store i64 %tmp1617, ptr %RL
+	%tmp21 = lshr i128 %tmp15, 64
+	%tmp2122 = trunc i128 %tmp21 to i64
+	store i64 %tmp2122, ptr %RH
+	ret void
+; CHECK: OpFunctionEnd
+}
+
+; CHECK: %[[#TestSub]] = OpFunction
+define spir_func void @test_sub(i64 %AL, i64 %AH, i64 %BL, i64 %BH, ptr %RL, ptr %RH) {
+entry:
+; CHECK: {{.*}} = OpUConvert %[[#Int128Ty]]
+; CHECK: {{.*}} = OpUConvert %[[#Int128Ty]]
+; CHECK: {{.*}} = OpShiftLeftLogical %[[#Int128Ty]] {{%[0-9]+}} %[[#Const64Int128]]
+; CHECK: {{.*}} = OpBitwiseOr %[[#Int128Ty]]
+; CHECK: {{.*}} = OpUConvert %[[#Int128Ty]]
+; CHECK: {{.*}} = OpISub %[[#Int128Ty]]
+	%tmp1 = zext i64 %AL to i128
+	%tmp23 = zext i64 %AH to i128
+	%tmp4 = shl i128 %tmp23, 64
+	%tmp5 = or i128 %tmp4, %tmp1
+	%tmp67 = zext i64 %BL to i128
+	%tmp89 = zext i64 %BH to i128
+	%tmp11 = shl i128 %tmp89, 64
+	%tmp12 = or i128 %tmp11, %tmp67
+	%tmp15 = sub i128 %tmp5, %tmp12
+	%tmp1617 = trunc i128 %tmp15 to i64
+	store i64 %tmp1617, ptr %RL
+	%tmp21 = lshr i128 %tmp15, 64
+	%tmp2122 = trunc i128 %tmp21 to i64
+	store i64 %tmp2122, ptr %RH
+	ret void
+; CHECK: OpFunctionEnd
+}
diff --git a/llvm/test/CodeGen/SPIRV/extensions/SPV_ALTERA_arbitrary_precision_integers/i128-arith.ll b/llvm/test/CodeGen/SPIRV/extensions/SPV_ALTERA_arbitrary_precision_integers/i128-arith.ll
new file mode 100644
index 0000000000000..d1de5df499566
--- /dev/null
+++ b/llvm/test/CodeGen/SPIRV/extensions/SPV_ALTERA_arbitrary_precision_integers/i128-arith.ll
@@ -0,0 +1,27 @@
+; RUN: not llc -O0 -mtriple=spirv64-unknown-unknown %s -o %t.spvt 2>&1 | FileCheck %s --check-prefix=CHECK-ERROR
+
+; RUN: llc -O0 -verify-machineinstrs -mtriple=spirv64-unknown-unknown --spirv-ext=+SPV_ALTERA_arbitrary_precision_integers %s -o - | FileCheck %s
+; TODO: %if spirv-tools %{ llc -O0 -mtriple=spirv64-unknown-unknown --spirv-ext=+SPV_ALTERA_arbitrary_precision_integers %s -o - -filetype=obj | spirv-val %}
+
+; CHECK-ERROR: LLVM ERROR: OpTypeInt type with a width other than 8, 16, 32 or 64 bits requires the following SPIR-V extension: SPV_ALTERA_arbitrary_precision_integers
+
+; CHECK: OpCapability ArbitraryPrecisionIntegersALTERA
+; CHECK: OpExtension "SPV_ALTERA_arbitrary_precision_integers"
+; CHECK: OpName %[[#Foo:]] "foo"
+; CHECK: %[[#Int128Ty:]] = OpTypeInt 128 0
+
+; CHECK: %[[#Foo]] = OpFunction
+define i64 @foo(i64 %x, i64 %y, i32 %amt) {
+; CHECK: {{.*}} = OpUConvert %[[#Int128Ty]]
+; CHECK: {{.*}} = OpSConvert %[[#Int128Ty]]
+; CHECK: {{.*}} = OpBitwiseOr %[[#Int128Ty]]
+; CHECK: {{.*}} = OpUConvert %[[#Int128Ty]]
+; CHECK: {{.*}} = OpShiftRightLogical %[[#Int128Ty]]
+  %tmp0 = zext i64 %x to i128
+  %tmp1 = sext i64 %y to i128
+  %tmp2 = or i128 %tmp0, %tmp1
+  %tmp7 = zext i32 13 to i128
+  %tmp3 = lshr i128 %tmp2, %tmp7
+  %tmp4 = trunc i128 %tmp3 to i64
+  ret i64 %tmp4
+}
diff --git a/llvm/test/CodeGen/SPIRV/extensions/SPV_ALTERA_arbitrary_precision_integers/i128-switch-lower.ll b/llvm/test/CodeGen/SPIRV/extensions/SPV_ALTERA_arbitrary_precision_integers/i128-switch-lower.ll
new file mode 100644
index 0000000000000..669e9362605a5
--- /dev/null
+++ b/llvm/test/CodeGen/SPIRV/extensions/SPV_ALTERA_arbitrary_precision_integers/i128-switch-lower.ll
@@ -0,0 +1,27 @@
+; RUN: not llc -O0 -mtriple=spirv64-unknown-unknown %s -o %t.spvt 2>&1 | FileCheck %s --check-prefix=CHECK-ERROR
+
+; RUN: llc -O0 -verify-machineinstrs -mtriple=spirv64-unknown-unknown --spirv-ext=+SPV_ALTERA_arbitrary_precision_integers %s -o - | FileCheck %s
+; TODO: %if spirv-tools %{ llc -O0 -mtriple=spirv64-unknown-unknown --spirv-ext=+SPV_ALTERA_arbitrary_precision_integers %s -o - -filetype=obj | spirv-val %}
+
+; CHECK-ERROR: LLVM ERROR: OpTypeInt type with a width other than 8, 16, 32 or 64 bits requires the following SPIR-V extension: SPV_ALTERA_arbitrary_precision_integers
+
+; CHECK: OpCapability ArbitraryPrecisionIntegersALTERA
+; CHECK: OpExtension "SPV_ALTERA_arbitrary_precision_integers"
+; CHECK: OpName %[[#Test:]] "test"
+; CHECK: OpName %[[#Exit:]] "exit"
+; CHECK: %[[#Int128Ty:]] = OpTypeInt 128 0
+; CHECK: %[[#UndefInt128:]] = OpUndef %[[#Int128Ty]]
+
+; CHECK: %[[#Test]] = OpFunction
+define void @test() {
+entry:
+; CHECK: OpSwitch %[[#UndefInt128]] %[[#Exit]] 0 0 3 0 %[[#Exit]] 0 0 5 0 %[[#Exit]] 0 0 4 0 %[[#Exit]] 0 0 8 0 %[[#Exit]]
+  switch i128 poison, label %exit [
+    i128 55340232221128654848, label %exit
+    i128 92233720368547758080, label %exit
+    i128 73786976294838206464, label %exit
+    i128 147573952589676412928, label %exit
+  ]
+exit:
+  unreachable
+}

>From ec9c0e7588792b928f9278d07d35fbf30831a191 Mon Sep 17 00:00:00 2001
From: Alex Voicu <alexandru.voicu at amd.com>
Date: Fri, 5 Dec 2025 13:16:44 +0000
Subject: [PATCH 2/5] Fix bug.

---
 llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp b/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp
index 832139f7d7354..b6a48e3b6c9ee 100644
--- a/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp
@@ -1395,7 +1395,7 @@ void addInstrRequirements(const MachineInstr &MI,
       Reqs.addCapability(SPIRV::Capability::Int16);
     else if (BitWidth == 8)
       Reqs.addCapability(SPIRV::Capability::Int8);
-    else {
+    else if (BitWidth != 32) {
       if (!ST.canUseExtension(
               SPIRV::Extension::SPV_ALTERA_arbitrary_precision_integers))
         reportFatalUsageError(
@@ -1403,7 +1403,7 @@ void addInstrRequirements(const MachineInstr &MI,
             "requires the following SPIR-V extension: "
             "SPV_ALTERA_arbitrary_precision_integers");
       Reqs.addExtension(
-              SPIRV::Extension::SPV_ALTERA_arbitrary_precision_integers);
+          SPIRV::Extension::SPV_ALTERA_arbitrary_precision_integers);
       Reqs.addCapability(SPIRV::Capability::ArbitraryPrecisionIntegersALTERA);
     }
     break;

>From 75d378b509442d24fd4a1f32f0050c992159ee67 Mon Sep 17 00:00:00 2001
From: Alex Voicu <alexandru.voicu at amd.com>
Date: Sat, 6 Dec 2025 01:22:12 +0000
Subject: [PATCH 3/5] Fix more bugs.

---
 llvm/lib/Target/SPIRV/SPIRVCommandLine.cpp           | 12 +++++++++++-
 llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp        |  5 +++++
 llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp        |  6 +++++-
 .../extensions/enable-all-extensions-but-one.ll      |  1 +
 4 files changed, 22 insertions(+), 2 deletions(-)

diff --git a/llvm/lib/Target/SPIRV/SPIRVCommandLine.cpp b/llvm/lib/Target/SPIRV/SPIRVCommandLine.cpp
index d2a8fddc5d8e4..4731da2fc1fd1 100644
--- a/llvm/lib/Target/SPIRV/SPIRVCommandLine.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVCommandLine.cpp
@@ -176,7 +176,17 @@ bool SPIRVExtensionsParser::parse(cl::Option &O, StringRef ArgName,
                                   std::set<SPIRV::Extension::Extension> &Vals) {
   SmallVector<StringRef, 10> Tokens;
   ArgValue.split(Tokens, ",", -1, false);
-  std::sort(Tokens.begin(), Tokens.end());
+  std::sort(Tokens.begin(), Tokens.end(), [](auto &&LHS, auto &&RHS) {
+    // We want to ensure that we handle "all" first, to ensure that any
+    // subsequent disablement actually behaves as expected i.e. given
+    // --spv-ext=all,-foo, we first enable all and then disable foo; this should
+    // be revisited and simplified.
+    if (LHS == "all")
+      return true;
+    if (RHS == "all")
+      return false;
+    return !(RHS < LHS);
+  });
 
   std::set<SPIRV::Extension::Extension> EnabledExtensions;
 
diff --git a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
index 693795d09dacd..5d96a67500dff 100644
--- a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
@@ -151,6 +151,11 @@ SPIRVType *SPIRVGlobalRegistry::getOpTypeBool(MachineIRBuilder &MIRBuilder) {
 }
 
 unsigned SPIRVGlobalRegistry::adjustOpTypeIntWidth(unsigned Width) const {
+  const SPIRVSubtarget &ST = cast<SPIRVSubtarget>(CurMF->getSubtarget());
+  if (ST.canUseExtension(
+          SPIRV::Extension::SPV_ALTERA_arbitrary_precision_integers) ||
+      (Width == 4 && ST.canUseExtension(SPIRV::Extension::SPV_INTEL_int4)))
+    return Width;
   if (Width <= 8)
     return 8;
   else if (Width <= 16)
diff --git a/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp b/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp
index c107cd9904420..7717f18d14a34 100644
--- a/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp
@@ -1436,7 +1436,11 @@ void addInstrRequirements(const MachineInstr &MI,
       Reqs.addCapability(SPIRV::Capability::Int16);
     else if (BitWidth == 8)
       Reqs.addCapability(SPIRV::Capability::Int8);
-    else if (BitWidth != 32) {
+    else if (BitWidth == 4 &&
+             ST.canUseExtension(SPIRV::Extension::SPV_INTEL_int4)) {
+      Reqs.addExtension(SPIRV::Extension::SPV_INTEL_int4);
+      Reqs.addCapability(SPIRV::Capability::Int4TypeINTEL);
+    } else if (BitWidth != 32) {
       if (!ST.canUseExtension(
               SPIRV::Extension::SPV_ALTERA_arbitrary_precision_integers))
         reportFatalUsageError(
diff --git a/llvm/test/CodeGen/SPIRV/extensions/enable-all-extensions-but-one.ll b/llvm/test/CodeGen/SPIRV/extensions/enable-all-extensions-but-one.ll
index 5ddfc85702540..ecf4807a1d5fc 100644
--- a/llvm/test/CodeGen/SPIRV/extensions/enable-all-extensions-but-one.ll
+++ b/llvm/test/CodeGen/SPIRV/extensions/enable-all-extensions-but-one.ll
@@ -1,4 +1,5 @@
 ; RUN: llc -verify-machineinstrs -O0 -mtriple=spirv32-unknown-unknown --spirv-ext=all,-SPV_ALTERA_arbitrary_precision_integers %s -o - | FileCheck %s
+; RUN: llc -verify-machineinstrs -O0 -mtriple=spirv32-unknown-unknown --spirv-ext=-SPV_ALTERA_arbitrary_precision_integers,all %s -o - | FileCheck %s
 ; RUN: llc -verify-machineinstrs -O0 -mtriple=spirv32-unknown-unknown --spirv-ext=KHR %s -o - | FileCheck %s
 ; RUN: llc -verify-machineinstrs -O0 -mtriple=spirv32-unknown-unknown --spirv-ext=khr %s -o - | FileCheck %s
 

>From 2e6927f5cd8271a85f19685e2eabbe7458d1456d Mon Sep 17 00:00:00 2001
From: Alex Voicu <alexandru.voicu at amd.com>
Date: Tue, 9 Dec 2025 15:10:52 +0000
Subject: [PATCH 4/5] Use llvm::sort

---
 llvm/lib/Target/SPIRV/SPIRVCommandLine.cpp | 9 +++++++--
 1 file changed, 7 insertions(+), 2 deletions(-)

diff --git a/llvm/lib/Target/SPIRV/SPIRVCommandLine.cpp b/llvm/lib/Target/SPIRV/SPIRVCommandLine.cpp
index 4731da2fc1fd1..42edad255ce82 100644
--- a/llvm/lib/Target/SPIRV/SPIRVCommandLine.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVCommandLine.cpp
@@ -13,9 +13,14 @@
 
 #include "SPIRVCommandLine.h"
 #include "MCTargetDesc/SPIRVBaseInfo.h"
+#include "llvm/ADT/STLExtras.h"
 #include "llvm/TargetParser/Triple.h"
-#include <algorithm>
+
+#include <functional>
 #include <map>
+#include <string>
+#include <utility>
+#include <vector>
 
 #define DEBUG_TYPE "spirv-commandline"
 
@@ -176,7 +181,7 @@ bool SPIRVExtensionsParser::parse(cl::Option &O, StringRef ArgName,
                                   std::set<SPIRV::Extension::Extension> &Vals) {
   SmallVector<StringRef, 10> Tokens;
   ArgValue.split(Tokens, ",", -1, false);
-  std::sort(Tokens.begin(), Tokens.end(), [](auto &&LHS, auto &&RHS) {
+  llvm::sort(Tokens, [](auto &&LHS, auto &&RHS) {
     // We want to ensure that we handle "all" first, to ensure that any
     // subsequent disablement actually behaves as expected i.e. given
     // --spv-ext=all,-foo, we first enable all and then disable foo; this should

>From 80ee3f69bff9edbd670e8f0bbfd8b9204a2df962 Mon Sep 17 00:00:00 2001
From: Alex Voicu <alexandru.voicu at amd.com>
Date: Thu, 11 Dec 2025 13:04:08 +0000
Subject: [PATCH 5/5] We don't need a sort to handle extension parsing.

---
 llvm/lib/Target/SPIRV/SPIRVCommandLine.cpp | 63 +++++++++++-----------
 1 file changed, 30 insertions(+), 33 deletions(-)

diff --git a/llvm/lib/Target/SPIRV/SPIRVCommandLine.cpp b/llvm/lib/Target/SPIRV/SPIRVCommandLine.cpp
index 42edad255ce82..04c54f9b0e53d 100644
--- a/llvm/lib/Target/SPIRV/SPIRVCommandLine.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVCommandLine.cpp
@@ -17,7 +17,9 @@
 #include "llvm/TargetParser/Triple.h"
 
 #include <functional>
+#include <iterator>
 #include <map>
+#include <set>
 #include <string>
 #include <utility>
 #include <vector>
@@ -26,7 +28,7 @@
 
 using namespace llvm;
 
-static const std::map<std::string, SPIRV::Extension::Extension, std::less<>>
+static const std::map<StringRef, SPIRV::Extension::Extension>
     SPIRVExtensionMap = {
         {"SPV_EXT_shader_atomic_float_add",
          SPIRV::Extension::Extension::SPV_EXT_shader_atomic_float_add},
@@ -181,57 +183,52 @@ bool SPIRVExtensionsParser::parse(cl::Option &O, StringRef ArgName,
                                   std::set<SPIRV::Extension::Extension> &Vals) {
   SmallVector<StringRef, 10> Tokens;
   ArgValue.split(Tokens, ",", -1, false);
-  llvm::sort(Tokens, [](auto &&LHS, auto &&RHS) {
-    // We want to ensure that we handle "all" first, to ensure that any
-    // subsequent disablement actually behaves as expected i.e. given
-    // --spv-ext=all,-foo, we first enable all and then disable foo; this should
-    // be revisited and simplified.
-    if (LHS == "all")
-      return true;
-    if (RHS == "all")
-      return false;
-    return !(RHS < LHS);
-  });
 
   std::set<SPIRV::Extension::Extension> EnabledExtensions;
 
-  for (const auto &Token : Tokens) {
-    if (Token == "all") {
-      for (const auto &[ExtensionName, ExtensionEnum] : SPIRVExtensionMap)
-        EnabledExtensions.insert(ExtensionEnum);
+  auto M = partition(Tokens, [](auto &&T) { return T.starts_with('+'); });
+
+  if (std::any_of(M, Tokens.end(), [](auto &&T) { return T == "all"; }))
+    copy(make_second_range(SPIRVExtensionMap), std::inserter(Vals, Vals.end()));
+
+  for (auto &&Token : make_range(Tokens.begin(), M)) {
+    StringRef ExtensionName = Token.substr(1);
+    auto NameValuePair = SPIRVExtensionMap.find(ExtensionName);
 
+    if (NameValuePair == SPIRVExtensionMap.end())
+      return O.error("Unknown SPIR-V extension: " + Token.str());
+
+    EnabledExtensions.insert(NameValuePair->second);
+  }
+
+  for (auto &&Token : make_range(M, Tokens.end())) {
+    if (Token == "all")
       continue;
-    }
 
     if (Token.size() == 3 && Token.upper() == "KHR") {
       for (const auto &[ExtensionName, ExtensionEnum] : SPIRVExtensionMap)
         if (StringRef(ExtensionName).starts_with("SPV_KHR_"))
-          EnabledExtensions.insert(ExtensionEnum);
+          Vals.insert(ExtensionEnum);
       continue;
     }
 
     if (Token.empty() || (!Token.starts_with("+") && !Token.starts_with("-")))
-      return O.error("Invalid extension list format: " + Token.str());
+      return O.error("Invalid extension list format: " + Token);
 
-    StringRef ExtensionName = Token.substr(1);
-    auto NameValuePair = SPIRVExtensionMap.find(ExtensionName);
+    auto NameValuePair = SPIRVExtensionMap.find(Token.substr(1));
 
-    if (NameValuePair == SPIRVExtensionMap.end())
+    if (NameValuePair == SPIRVExtensionMap.cend())
       return O.error("Unknown SPIR-V extension: " + Token.str());
+    if (EnabledExtensions.count(NameValuePair->second))
+      return O.error(
+          "Extension cannot be allowed and disallowed at the same time: " +
+          NameValuePair->first);
 
-    if (Token.starts_with("+")) {
-      EnabledExtensions.insert(NameValuePair->second);
-    } else if (EnabledExtensions.count(NameValuePair->second)) {
-      if (llvm::is_contained(Tokens, "+" + ExtensionName.str()))
-        return O.error(
-            "Extension cannot be allowed and disallowed at the same time: " +
-            ExtensionName.str());
-
-      EnabledExtensions.erase(NameValuePair->second);
-    }
+    Vals.erase(NameValuePair->second);
   }
 
-  Vals = std::move(EnabledExtensions);
+  Vals.insert(EnabledExtensions.cbegin(), EnabledExtensions.cend());
+
   return false;
 }
 



More information about the llvm-commits mailing list