[Mlir-commits] [mlir] Reland "[mlir][spirv] Fix UpdateVCEPass to deduce the correct set of capabilities" (PR #151502)

Davide Grohmann llvmlistbot at llvm.org
Fri Aug 1 06:23:28 PDT 2025


https://github.com/davidegrohmann updated https://github.com/llvm/llvm-project/pull/151502

>From 8fa18ae27d7d94646c59790532c7f687e84cbc88 Mon Sep 17 00:00:00 2001
From: Davide Grohmann <davide.grohmann at arm.com>
Date: Tue, 29 Jul 2025 10:23:34 +0200
Subject: [PATCH 1/5] Reland "[mlir][spirv] Fix UpdateVCEPass to deduce the
 correct set of capabilities"

This relands PR #151108.

The original PR caused sanitizer builds to fail. The issues are now
resolved in this new patch.

Original commit message:
> When deducing capabilities, implied capabilities are not considered,
which causes generation of incorrect SPIR-V modules. This commit fixes
that by pulling all the implied ones into the capability set.

Signed-off-by: Davide Grohmann <davide.grohmann at arm.com>
Change-Id: Ia30149fb35bbf0071010cb7bc92b86d2e5b6a6af
---
 .../SPIRV/Transforms/UpdateVCEPass.cpp        | 10 ++++++
 .../SPIRV/Transforms/vce-deduction.mlir       | 32 +++++++++----------
 2 files changed, 26 insertions(+), 16 deletions(-)

diff --git a/mlir/lib/Dialect/SPIRV/Transforms/UpdateVCEPass.cpp b/mlir/lib/Dialect/SPIRV/Transforms/UpdateVCEPass.cpp
index a53d0a7e4c44b..d6d52947e57d8 100644
--- a/mlir/lib/Dialect/SPIRV/Transforms/UpdateVCEPass.cpp
+++ b/mlir/lib/Dialect/SPIRV/Transforms/UpdateVCEPass.cpp
@@ -95,6 +95,14 @@ static LogicalResult checkAndUpdateCapabilityRequirements(
   return success();
 }
 
+static void addAllImpliedCapabilities(SetVector<spirv::Capability> &caps) {
+  for (unsigned i = 0; i < caps.size(); ++i) {
+    spirv::Capability cap = caps[i];
+    ArrayRef<spirv::Capability> impliedCaps = getDirectImpliedCapabilities(cap);
+    caps.insert_range(impliedCaps);
+  }
+}
+
 void UpdateVCEPass::runOnOperation() {
   spirv::ModuleOp module = getOperation();
 
@@ -168,6 +176,8 @@ void UpdateVCEPass::runOnOperation() {
         return WalkResult::interrupt();
     }
 
+    addAllImpliedCapabilities(deducedCapabilities);
+
     return WalkResult::advance();
   });
 
diff --git a/mlir/test/Dialect/SPIRV/Transforms/vce-deduction.mlir b/mlir/test/Dialect/SPIRV/Transforms/vce-deduction.mlir
index 8d7f3da4007cd..444c649a0c188 100644
--- a/mlir/test/Dialect/SPIRV/Transforms/vce-deduction.mlir
+++ b/mlir/test/Dialect/SPIRV/Transforms/vce-deduction.mlir
@@ -7,7 +7,7 @@
 // Test deducing minimal version.
 // spirv.IAdd is available from v1.0.
 
-// CHECK: requires #spirv.vce<v1.0, [Shader], []>
+// CHECK: requires #spirv.vce<v1.0, [Shader, Matrix], []>
 spirv.module Logical GLSL450 attributes {
   spirv.target_env = #spirv.target_env<
     #spirv.vce<v1.5, [Shader], []>, #spirv.resource_limits<>>
@@ -21,7 +21,7 @@ spirv.module Logical GLSL450 attributes {
 // Test deducing minimal version.
 // spirv.GroupNonUniformBallot is available since v1.3.
 
-// CHECK: requires #spirv.vce<v1.3, [GroupNonUniformBallot, Shader], []>
+// CHECK: requires #spirv.vce<v1.3, [GroupNonUniformBallot, GroupNonUniform, Shader, Matrix], []>
 spirv.module Logical GLSL450 attributes {
   spirv.target_env = #spirv.target_env<
     #spirv.vce<v1.5, [Shader, GroupNonUniformBallot], []>, #spirv.resource_limits<>>
@@ -32,7 +32,7 @@ spirv.module Logical GLSL450 attributes {
   }
 }
 
-// CHECK: requires #spirv.vce<v1.4, [Shader], []>
+// CHECK: requires #spirv.vce<v1.4, [Shader, Matrix], []>
 spirv.module Logical GLSL450 attributes {
   spirv.target_env = #spirv.target_env<#spirv.vce<v1.6, [Shader], []>, #spirv.resource_limits<>>
 } {
@@ -48,7 +48,7 @@ spirv.module Logical GLSL450 attributes {
 
 // Test minimal capabilities.
 
-// CHECK: requires #spirv.vce<v1.0, [Shader], []>
+// CHECK: requires #spirv.vce<v1.0, [Shader, Matrix], []>
 spirv.module Logical GLSL450 attributes {
   spirv.target_env = #spirv.target_env<
     #spirv.vce<v1.0, [Shader, Float16, Float64, Int16, Int64, VariablePointers], []>, #spirv.resource_limits<>>
@@ -61,10 +61,10 @@ spirv.module Logical GLSL450 attributes {
 
 // Test Physical Storage Buffers are deduced correctly.
 
-// CHECK: spirv.module PhysicalStorageBuffer64 GLSL450 requires #spirv.vce<v1.0, [PhysicalStorageBufferAddresses, Shader], [SPV_EXT_physical_storage_buffer]>
+// CHECK: spirv.module PhysicalStorageBuffer64 GLSL450 requires #spirv.vce<v1.0, [PhysicalStorageBufferAddresses, Shader, Matrix], [SPV_EXT_physical_storage_buffer]>
 spirv.module PhysicalStorageBuffer64 GLSL450 attributes {
   spirv.target_env = #spirv.target_env<
-    #spirv.vce<v1.0, [Shader, PhysicalStorageBufferAddresses], [SPV_EXT_physical_storage_buffer]>, #spirv.resource_limits<>>
+    #spirv.vce<v1.0, [PhysicalStorageBufferAddresses], [SPV_EXT_physical_storage_buffer]>, #spirv.resource_limits<>>
 } {
   spirv.func @physical_ptr(%val : !spirv.ptr<f32, PhysicalStorageBuffer> { spirv.decoration = #spirv.decoration<Aliased> }) "None" {
     spirv.Return
@@ -74,7 +74,7 @@ spirv.module PhysicalStorageBuffer64 GLSL450 attributes {
 // Test deducing implied capability.
 // AtomicStorage implies Shader.
 
-// CHECK: requires #spirv.vce<v1.0, [Shader], []>
+// CHECK: requires #spirv.vce<v1.0, [Shader, Matrix], []>
 spirv.module Logical GLSL450 attributes {
   spirv.target_env = #spirv.target_env<
     #spirv.vce<v1.0, [AtomicStorage], []>, #spirv.resource_limits<>>
@@ -95,7 +95,7 @@ spirv.module Logical GLSL450 attributes {
 // * GroupNonUniformArithmetic
 // * GroupNonUniformBallot
 
-// CHECK: requires #spirv.vce<v1.3, [GroupNonUniformArithmetic, Shader], []>
+// CHECK: requires #spirv.vce<v1.3, [GroupNonUniformArithmetic, GroupNonUniform, Shader, Matrix], []>
 spirv.module Logical GLSL450 attributes {
   spirv.target_env = #spirv.target_env<
     #spirv.vce<v1.3, [Shader, GroupNonUniformArithmetic], []>, #spirv.resource_limits<>>
@@ -106,7 +106,7 @@ spirv.module Logical GLSL450 attributes {
   }
 }
 
-// CHECK: requires #spirv.vce<v1.3, [GroupNonUniformClustered, GroupNonUniformBallot, Shader], []>
+// CHECK: requires #spirv.vce<v1.3, [GroupNonUniformClustered, GroupNonUniformBallot, GroupNonUniform, Shader, Matrix], []>
 spirv.module Logical GLSL450 attributes {
   spirv.target_env = #spirv.target_env<
     #spirv.vce<v1.3, [Shader, GroupNonUniformClustered, GroupNonUniformBallot], []>, #spirv.resource_limits<>>
@@ -120,7 +120,7 @@ spirv.module Logical GLSL450 attributes {
 // Test type required capabilities
 
 // Using 8-bit integers in non-interface storage class requires Int8.
-// CHECK: requires #spirv.vce<v1.0, [Int8, Shader], []>
+// CHECK: requires #spirv.vce<v1.0, [Int8, Shader, Matrix], []>
 spirv.module Logical GLSL450 attributes {
   spirv.target_env = #spirv.target_env<
     #spirv.vce<v1.3, [Shader, Int8], []>, #spirv.resource_limits<>>
@@ -132,7 +132,7 @@ spirv.module Logical GLSL450 attributes {
 }
 
 // Using 16-bit floats in non-interface storage class requires Float16.
-// CHECK: requires #spirv.vce<v1.0, [Float16, Shader], []>
+// CHECK: requires #spirv.vce<v1.0, [Float16, Shader, Matrix], []>
 spirv.module Logical GLSL450 attributes {
   spirv.target_env = #spirv.target_env<
     #spirv.vce<v1.3, [Shader, Float16], []>, #spirv.resource_limits<>>
@@ -144,7 +144,7 @@ spirv.module Logical GLSL450 attributes {
 }
 
 // Using 16-element vectors requires Vector16.
-// CHECK: requires #spirv.vce<v1.0, [Vector16, Shader], []>
+// CHECK: requires #spirv.vce<v1.0, [Vector16, Kernel, Shader, Matrix], []>
 spirv.module Logical GLSL450 attributes {
   spirv.target_env = #spirv.target_env<
     #spirv.vce<v1.3, [Shader, Vector16], []>, #spirv.resource_limits<>>
@@ -162,7 +162,7 @@ spirv.module Logical GLSL450 attributes {
 // Test deducing minimal extensions.
 // spirv.KHR.SubgroupBallot requires the SPV_KHR_shader_ballot extension.
 
-// CHECK: requires #spirv.vce<v1.0, [SubgroupBallotKHR, Shader], [SPV_KHR_shader_ballot]>
+// CHECK: requires #spirv.vce<v1.0, [SubgroupBallotKHR, Shader, Matrix], [SPV_KHR_shader_ballot]>
 spirv.module Logical GLSL450 attributes {
   spirv.target_env = #spirv.target_env<
     #spirv.vce<v1.0, [Shader, SubgroupBallotKHR],
@@ -193,7 +193,7 @@ spirv.module Logical Vulkan attributes {
 
 // Using 8-bit integers in interface storage class requires additional
 // extensions and capabilities.
-// CHECK: requires #spirv.vce<v1.0, [StorageBuffer16BitAccess, Shader, Int16], [SPV_KHR_16bit_storage, SPV_KHR_storage_buffer_storage_class]>
+// CHECK: requires #spirv.vce<v1.0, [StorageBuffer16BitAccess, Shader, Int16, Matrix], [SPV_KHR_16bit_storage, SPV_KHR_storage_buffer_storage_class]>
 spirv.module Logical GLSL450 attributes {
   spirv.target_env = #spirv.target_env<
     #spirv.vce<v1.3, [Shader, StorageBuffer16BitAccess, Int16], []>, #spirv.resource_limits<>>
@@ -208,7 +208,7 @@ spirv.module Logical GLSL450 attributes {
 // Complicated nested types
 // * Buffer requires ImageBuffer or SampledBuffer.
 // * Rg32f requires StorageImageExtendedFormats.
-// CHECK: requires #spirv.vce<v1.0, [UniformAndStorageBuffer8BitAccess, StorageUniform16, Int64, Shader, ImageBuffer, StorageImageExtendedFormats], [SPV_KHR_8bit_storage, SPV_KHR_16bit_storage]>
+// CHECK: requires #spirv.vce<v1.0, [UniformAndStorageBuffer8BitAccess, StorageUniform16, Int64, Shader, StorageBuffer8BitAccess, StorageBuffer16BitAccess, Matrix, ImageBuffer, StorageImageExtendedFormats, SampledBuffer], [SPV_KHR_8bit_storage, SPV_KHR_16bit_storage]>
 spirv.module Logical GLSL450 attributes {
   spirv.target_env = #spirv.target_env<
     #spirv.vce<v1.5, [Shader, UniformAndStorageBuffer8BitAccess, StorageBuffer16BitAccess, StorageUniform16, Int16, Int64, ImageBuffer, StorageImageExtendedFormats], []>,
@@ -219,7 +219,7 @@ spirv.module Logical GLSL450 attributes {
 }
 
 // Using bfloat16 requires BFloat16TypeKHR capability and SPV_KHR_bfloat16 extension.
-// CHECK: requires #spirv.vce<v1.0, [StorageBuffer16BitAccess, Shader, BFloat16TypeKHR], [SPV_KHR_bfloat16, SPV_KHR_16bit_storage, SPV_KHR_storage_buffer_storage_class]>
+// CHECK: requires #spirv.vce<v1.0, [StorageBuffer16BitAccess, Shader, BFloat16TypeKHR, Matrix], [SPV_KHR_bfloat16, SPV_KHR_16bit_storage, SPV_KHR_storage_buffer_storage_class]>
 spirv.module Logical GLSL450 attributes {
   spirv.target_env = #spirv.target_env<
     #spirv.vce<v1.0, [Shader, StorageBuffer16BitAccess, BFloat16TypeKHR], [SPV_KHR_bfloat16, SPV_KHR_16bit_storage, SPV_KHR_storage_buffer_storage_class]>,

>From 2e9669385857c3755ec02f6eefed466058c85b0a Mon Sep 17 00:00:00 2001
From: Davide Grohmann <davide.grohmann at arm.com>
Date: Fri, 1 Aug 2025 12:02:15 +0200
Subject: [PATCH 2/5] Use two SetVectors when computing implied capabilities

Signed-off-by: Davide Grohmann <davide.grohmann at arm.com>
Change-Id: Id169f0ed656a3728cf3667cb35417d8a5d7c90e7
---
 .../SPIRV/Transforms/UpdateVCEPass.cpp        | 19 +++++++++++-----
 .../SPIRV/Transforms/vce-deduction.mlir       | 22 +++++++++----------
 2 files changed, 24 insertions(+), 17 deletions(-)

diff --git a/mlir/lib/Dialect/SPIRV/Transforms/UpdateVCEPass.cpp b/mlir/lib/Dialect/SPIRV/Transforms/UpdateVCEPass.cpp
index d6d52947e57d8..e583de4d58b9f 100644
--- a/mlir/lib/Dialect/SPIRV/Transforms/UpdateVCEPass.cpp
+++ b/mlir/lib/Dialect/SPIRV/Transforms/UpdateVCEPass.cpp
@@ -95,12 +95,19 @@ static LogicalResult checkAndUpdateCapabilityRequirements(
   return success();
 }
 
-static void addAllImpliedCapabilities(SetVector<spirv::Capability> &caps) {
-  for (unsigned i = 0; i < caps.size(); ++i) {
-    spirv::Capability cap = caps[i];
+static SetVector<spirv::Capability>
+addAllImpliedCapabilities(SetVector<spirv::Capability> &caps) {
+  SetVector<spirv::Capability> allCaps;
+  while (!caps.empty()) {
+    spirv::Capability cap = caps.pop_back_val();
+    allCaps.insert(cap);
     ArrayRef<spirv::Capability> impliedCaps = getDirectImpliedCapabilities(cap);
-    caps.insert_range(impliedCaps);
+    for (spirv::Capability impliedCap : impliedCaps) {
+      if (!allCaps.contains(impliedCap))
+        caps.insert(impliedCap);
+    }
   }
+  return allCaps;
 }
 
 void UpdateVCEPass::runOnOperation() {
@@ -176,14 +183,14 @@ void UpdateVCEPass::runOnOperation() {
         return WalkResult::interrupt();
     }
 
-    addAllImpliedCapabilities(deducedCapabilities);
-
     return WalkResult::advance();
   });
 
   if (walkResult.wasInterrupted())
     return signalPassFailure();
 
+  deducedCapabilities = addAllImpliedCapabilities(deducedCapabilities);
+
   // Update min version requirement for capabilities after deducing them.
   for (spirv::Capability cap : deducedCapabilities) {
     if (std::optional<spirv::Version> minVersion = spirv::getMinVersion(cap)) {
diff --git a/mlir/test/Dialect/SPIRV/Transforms/vce-deduction.mlir b/mlir/test/Dialect/SPIRV/Transforms/vce-deduction.mlir
index 444c649a0c188..06925aaeb8127 100644
--- a/mlir/test/Dialect/SPIRV/Transforms/vce-deduction.mlir
+++ b/mlir/test/Dialect/SPIRV/Transforms/vce-deduction.mlir
@@ -21,7 +21,7 @@ spirv.module Logical GLSL450 attributes {
 // Test deducing minimal version.
 // spirv.GroupNonUniformBallot is available since v1.3.
 
-// CHECK: requires #spirv.vce<v1.3, [GroupNonUniformBallot, GroupNonUniform, Shader, Matrix], []>
+// CHECK: requires #spirv.vce<v1.3, [Shader, Matrix, GroupNonUniform, GroupNonUniformBallot], []>
 spirv.module Logical GLSL450 attributes {
   spirv.target_env = #spirv.target_env<
     #spirv.vce<v1.5, [Shader, GroupNonUniformBallot], []>, #spirv.resource_limits<>>
@@ -61,7 +61,7 @@ spirv.module Logical GLSL450 attributes {
 
 // Test Physical Storage Buffers are deduced correctly.
 
-// CHECK: spirv.module PhysicalStorageBuffer64 GLSL450 requires #spirv.vce<v1.0, [PhysicalStorageBufferAddresses, Shader, Matrix], [SPV_EXT_physical_storage_buffer]>
+// CHECK: spirv.module PhysicalStorageBuffer64 GLSL450 requires #spirv.vce<v1.0, [Shader, Matrix, PhysicalStorageBufferAddresses], [SPV_EXT_physical_storage_buffer]>
 spirv.module PhysicalStorageBuffer64 GLSL450 attributes {
   spirv.target_env = #spirv.target_env<
     #spirv.vce<v1.0, [PhysicalStorageBufferAddresses], [SPV_EXT_physical_storage_buffer]>, #spirv.resource_limits<>>
@@ -95,7 +95,7 @@ spirv.module Logical GLSL450 attributes {
 // * GroupNonUniformArithmetic
 // * GroupNonUniformBallot
 
-// CHECK: requires #spirv.vce<v1.3, [GroupNonUniformArithmetic, GroupNonUniform, Shader, Matrix], []>
+// CHECK: requires #spirv.vce<v1.3, [Shader, Matrix, GroupNonUniform, GroupNonUniformArithmetic], []>
 spirv.module Logical GLSL450 attributes {
   spirv.target_env = #spirv.target_env<
     #spirv.vce<v1.3, [Shader, GroupNonUniformArithmetic], []>, #spirv.resource_limits<>>
@@ -106,7 +106,7 @@ spirv.module Logical GLSL450 attributes {
   }
 }
 
-// CHECK: requires #spirv.vce<v1.3, [GroupNonUniformClustered, GroupNonUniformBallot, GroupNonUniform, Shader, Matrix], []>
+// CHECK: requires #spirv.vce<v1.3, [Shader, Matrix, GroupNonUniformClustered, GroupNonUniform, GroupNonUniformBallot], []>
 spirv.module Logical GLSL450 attributes {
   spirv.target_env = #spirv.target_env<
     #spirv.vce<v1.3, [Shader, GroupNonUniformClustered, GroupNonUniformBallot], []>, #spirv.resource_limits<>>
@@ -120,7 +120,7 @@ spirv.module Logical GLSL450 attributes {
 // Test type required capabilities
 
 // Using 8-bit integers in non-interface storage class requires Int8.
-// CHECK: requires #spirv.vce<v1.0, [Int8, Shader, Matrix], []>
+// CHECK: requires #spirv.vce<v1.0, [Shader, Matrix, Int8], []>
 spirv.module Logical GLSL450 attributes {
   spirv.target_env = #spirv.target_env<
     #spirv.vce<v1.3, [Shader, Int8], []>, #spirv.resource_limits<>>
@@ -132,7 +132,7 @@ spirv.module Logical GLSL450 attributes {
 }
 
 // Using 16-bit floats in non-interface storage class requires Float16.
-// CHECK: requires #spirv.vce<v1.0, [Float16, Shader, Matrix], []>
+// CHECK: requires #spirv.vce<v1.0, [Shader, Matrix, Float16], []>
 spirv.module Logical GLSL450 attributes {
   spirv.target_env = #spirv.target_env<
     #spirv.vce<v1.3, [Shader, Float16], []>, #spirv.resource_limits<>>
@@ -144,7 +144,7 @@ spirv.module Logical GLSL450 attributes {
 }
 
 // Using 16-element vectors requires Vector16.
-// CHECK: requires #spirv.vce<v1.0, [Vector16, Kernel, Shader, Matrix], []>
+// CHECK: requires #spirv.vce<v1.0, [Shader, Matrix, Kernel, Vector16], []>
 spirv.module Logical GLSL450 attributes {
   spirv.target_env = #spirv.target_env<
     #spirv.vce<v1.3, [Shader, Vector16], []>, #spirv.resource_limits<>>
@@ -162,7 +162,7 @@ spirv.module Logical GLSL450 attributes {
 // Test deducing minimal extensions.
 // spirv.KHR.SubgroupBallot requires the SPV_KHR_shader_ballot extension.
 
-// CHECK: requires #spirv.vce<v1.0, [SubgroupBallotKHR, Shader, Matrix], [SPV_KHR_shader_ballot]>
+// CHECK: requires #spirv.vce<v1.0, [Shader, Matrix, SubgroupBallotKHR], [SPV_KHR_shader_ballot]>
 spirv.module Logical GLSL450 attributes {
   spirv.target_env = #spirv.target_env<
     #spirv.vce<v1.0, [Shader, SubgroupBallotKHR],
@@ -193,7 +193,7 @@ spirv.module Logical Vulkan attributes {
 
 // Using 8-bit integers in interface storage class requires additional
 // extensions and capabilities.
-// CHECK: requires #spirv.vce<v1.0, [StorageBuffer16BitAccess, Shader, Int16, Matrix], [SPV_KHR_16bit_storage, SPV_KHR_storage_buffer_storage_class]>
+// CHECK: requires #spirv.vce<v1.0, [Int16, Shader, Matrix, StorageBuffer16BitAccess], [SPV_KHR_16bit_storage, SPV_KHR_storage_buffer_storage_class]>
 spirv.module Logical GLSL450 attributes {
   spirv.target_env = #spirv.target_env<
     #spirv.vce<v1.3, [Shader, StorageBuffer16BitAccess, Int16], []>, #spirv.resource_limits<>>
@@ -208,7 +208,7 @@ spirv.module Logical GLSL450 attributes {
 // Complicated nested types
 // * Buffer requires ImageBuffer or SampledBuffer.
 // * Rg32f requires StorageImageExtendedFormats.
-// CHECK: requires #spirv.vce<v1.0, [UniformAndStorageBuffer8BitAccess, StorageUniform16, Int64, Shader, StorageBuffer8BitAccess, StorageBuffer16BitAccess, Matrix, ImageBuffer, StorageImageExtendedFormats, SampledBuffer], [SPV_KHR_8bit_storage, SPV_KHR_16bit_storage]>
+// CHECK: requires #spirv.vce<v1.0, [Shader, Matrix, Int64, StorageUniform16, StorageBuffer16BitAccess, UniformAndStorageBuffer8BitAccess, StorageBuffer8BitAccess, SampledBuffer, ImageBuffer, StorageImageExtendedFormats], [SPV_KHR_8bit_storage, SPV_KHR_16bit_storage]>
 spirv.module Logical GLSL450 attributes {
   spirv.target_env = #spirv.target_env<
     #spirv.vce<v1.5, [Shader, UniformAndStorageBuffer8BitAccess, StorageBuffer16BitAccess, StorageUniform16, Int16, Int64, ImageBuffer, StorageImageExtendedFormats], []>,
@@ -219,7 +219,7 @@ spirv.module Logical GLSL450 attributes {
 }
 
 // Using bfloat16 requires BFloat16TypeKHR capability and SPV_KHR_bfloat16 extension.
-// CHECK: requires #spirv.vce<v1.0, [StorageBuffer16BitAccess, Shader, BFloat16TypeKHR, Matrix], [SPV_KHR_bfloat16, SPV_KHR_16bit_storage, SPV_KHR_storage_buffer_storage_class]>
+// CHECK: requires #spirv.vce<v1.0, [StorageBuffer16BitAccess, Matrix, Shader, BFloat16TypeKHR], [SPV_KHR_bfloat16, SPV_KHR_16bit_storage, SPV_KHR_storage_buffer_storage_class]>
 spirv.module Logical GLSL450 attributes {
   spirv.target_env = #spirv.target_env<
     #spirv.vce<v1.0, [Shader, StorageBuffer16BitAccess, BFloat16TypeKHR], [SPV_KHR_bfloat16, SPV_KHR_16bit_storage, SPV_KHR_storage_buffer_storage_class]>,

>From 1e0ab8828c71a94e754eeedb006faaa1a97f986c Mon Sep 17 00:00:00 2001
From: Davide Grohmann <davide.grohmann at arm.com>
Date: Fri, 1 Aug 2025 13:23:39 +0200
Subject: [PATCH 3/5] Fix tests

Signed-off-by: Davide Grohmann <davide.grohmann at arm.com>
Change-Id: I81f717f9e6d1c16ccddbf533c36e786233131774
---
 .../test/Dialect/SPIRV/Transforms/vce-deduction.mlir | 12 ++++++------
 1 file changed, 6 insertions(+), 6 deletions(-)

diff --git a/mlir/test/Dialect/SPIRV/Transforms/vce-deduction.mlir b/mlir/test/Dialect/SPIRV/Transforms/vce-deduction.mlir
index 06925aaeb8127..58094b3c783d6 100644
--- a/mlir/test/Dialect/SPIRV/Transforms/vce-deduction.mlir
+++ b/mlir/test/Dialect/SPIRV/Transforms/vce-deduction.mlir
@@ -21,7 +21,7 @@ spirv.module Logical GLSL450 attributes {
 // Test deducing minimal version.
 // spirv.GroupNonUniformBallot is available since v1.3.
 
-// CHECK: requires #spirv.vce<v1.3, [Shader, Matrix, GroupNonUniform, GroupNonUniformBallot], []>
+// CHECK: requires #spirv.vce<v1.3, [Shader, Matrix, GroupNonUniformBallot, GroupNonUniform], []>
 spirv.module Logical GLSL450 attributes {
   spirv.target_env = #spirv.target_env<
     #spirv.vce<v1.5, [Shader, GroupNonUniformBallot], []>, #spirv.resource_limits<>>
@@ -95,7 +95,7 @@ spirv.module Logical GLSL450 attributes {
 // * GroupNonUniformArithmetic
 // * GroupNonUniformBallot
 
-// CHECK: requires #spirv.vce<v1.3, [Shader, Matrix, GroupNonUniform, GroupNonUniformArithmetic], []>
+// CHECK: requires #spirv.vce<v1.3, [Shader, Matrix, GroupNonUniformArithmetic, GroupNonUniform], []>
 spirv.module Logical GLSL450 attributes {
   spirv.target_env = #spirv.target_env<
     #spirv.vce<v1.3, [Shader, GroupNonUniformArithmetic], []>, #spirv.resource_limits<>>
@@ -106,7 +106,7 @@ spirv.module Logical GLSL450 attributes {
   }
 }
 
-// CHECK: requires #spirv.vce<v1.3, [Shader, Matrix, GroupNonUniformClustered, GroupNonUniform, GroupNonUniformBallot], []>
+// CHECK: requires #spirv.vce<v1.3, [Shader, Matrix, GroupNonUniformBallot, GroupNonUniform, GroupNonUniformClustered], []>
 spirv.module Logical GLSL450 attributes {
   spirv.target_env = #spirv.target_env<
     #spirv.vce<v1.3, [Shader, GroupNonUniformClustered, GroupNonUniformBallot], []>, #spirv.resource_limits<>>
@@ -144,7 +144,7 @@ spirv.module Logical GLSL450 attributes {
 }
 
 // Using 16-element vectors requires Vector16.
-// CHECK: requires #spirv.vce<v1.0, [Shader, Matrix, Kernel, Vector16], []>
+// CHECK: requires #spirv.vce<v1.0, [Shader, Matrix, Vector16, Kernel], []>
 spirv.module Logical GLSL450 attributes {
   spirv.target_env = #spirv.target_env<
     #spirv.vce<v1.3, [Shader, Vector16], []>, #spirv.resource_limits<>>
@@ -208,7 +208,7 @@ spirv.module Logical GLSL450 attributes {
 // Complicated nested types
 // * Buffer requires ImageBuffer or SampledBuffer.
 // * Rg32f requires StorageImageExtendedFormats.
-// CHECK: requires #spirv.vce<v1.0, [Shader, Matrix, Int64, StorageUniform16, StorageBuffer16BitAccess, UniformAndStorageBuffer8BitAccess, StorageBuffer8BitAccess, SampledBuffer, ImageBuffer, StorageImageExtendedFormats], [SPV_KHR_8bit_storage, SPV_KHR_16bit_storage]>
+// CHECK: requires #spirv.vce<v1.0, [StorageImageExtendedFormats, ImageBuffer, SampledBuffer, Shader, Matrix, Int64, StorageUniform16, StorageBuffer16BitAccess, UniformAndStorageBuffer8BitAccess, StorageBuffer8BitAccess], [SPV_KHR_8bit_storage, SPV_KHR_16bit_storage]>
 spirv.module Logical GLSL450 attributes {
   spirv.target_env = #spirv.target_env<
     #spirv.vce<v1.5, [Shader, UniformAndStorageBuffer8BitAccess, StorageBuffer16BitAccess, StorageUniform16, Int16, Int64, ImageBuffer, StorageImageExtendedFormats], []>,
@@ -219,7 +219,7 @@ spirv.module Logical GLSL450 attributes {
 }
 
 // Using bfloat16 requires BFloat16TypeKHR capability and SPV_KHR_bfloat16 extension.
-// CHECK: requires #spirv.vce<v1.0, [StorageBuffer16BitAccess, Matrix, Shader, BFloat16TypeKHR], [SPV_KHR_bfloat16, SPV_KHR_16bit_storage, SPV_KHR_storage_buffer_storage_class]>
+// CHECK: requires #spirv.vce<v1.0, [BFloat16TypeKHR, Shader, Matrix, StorageBuffer16BitAccess], [SPV_KHR_bfloat16, SPV_KHR_16bit_storage, SPV_KHR_storage_buffer_storage_class]>
 spirv.module Logical GLSL450 attributes {
   spirv.target_env = #spirv.target_env<
     #spirv.vce<v1.0, [Shader, StorageBuffer16BitAccess, BFloat16TypeKHR], [SPV_KHR_bfloat16, SPV_KHR_16bit_storage, SPV_KHR_storage_buffer_storage_class]>,

>From dd70ae9bd56f6d0f1bc85587133a6036ce9ead3f Mon Sep 17 00:00:00 2001
From: Davide Grohmann <davide.grohmann at arm.com>
Date: Fri, 1 Aug 2025 14:57:57 +0200
Subject: [PATCH 4/5] Resolve review comments

Signed-off-by: Davide Grohmann <davide.grohmann at arm.com>
Change-Id: Icebb3419e792b94ea49d193c712ccba685592619
---
 .../SPIRV/Transforms/UpdateVCEPass.cpp        | 23 ++++++++-----------
 .../SPIRV/Transforms/vce-deduction.mlir       | 22 +++++++++---------
 2 files changed, 21 insertions(+), 24 deletions(-)

diff --git a/mlir/lib/Dialect/SPIRV/Transforms/UpdateVCEPass.cpp b/mlir/lib/Dialect/SPIRV/Transforms/UpdateVCEPass.cpp
index e583de4d58b9f..83b827860730a 100644
--- a/mlir/lib/Dialect/SPIRV/Transforms/UpdateVCEPass.cpp
+++ b/mlir/lib/Dialect/SPIRV/Transforms/UpdateVCEPass.cpp
@@ -95,19 +95,16 @@ static LogicalResult checkAndUpdateCapabilityRequirements(
   return success();
 }
 
-static SetVector<spirv::Capability>
-addAllImpliedCapabilities(SetVector<spirv::Capability> &caps) {
-  SetVector<spirv::Capability> allCaps;
-  while (!caps.empty()) {
-    spirv::Capability cap = caps.pop_back_val();
-    allCaps.insert(cap);
-    ArrayRef<spirv::Capability> impliedCaps = getDirectImpliedCapabilities(cap);
-    for (spirv::Capability impliedCap : impliedCaps) {
-      if (!allCaps.contains(impliedCap))
-        caps.insert(impliedCap);
-    }
+static void addAllImpliedCapabilities(SetVector<spirv::Capability> &caps) {
+  size_t old_size{0};
+  while (caps.size() > old_size) {
+    old_size = caps.size();
+    SetVector<spirv::Capability> tmp;
+    for (spirv::Capability cap : caps)
+      tmp.insert_range(getDirectImpliedCapabilities(cap));
+
+    caps.insert_range(tmp);
   }
-  return allCaps;
 }
 
 void UpdateVCEPass::runOnOperation() {
@@ -189,7 +186,7 @@ void UpdateVCEPass::runOnOperation() {
   if (walkResult.wasInterrupted())
     return signalPassFailure();
 
-  deducedCapabilities = addAllImpliedCapabilities(deducedCapabilities);
+  addAllImpliedCapabilities(deducedCapabilities);
 
   // Update min version requirement for capabilities after deducing them.
   for (spirv::Capability cap : deducedCapabilities) {
diff --git a/mlir/test/Dialect/SPIRV/Transforms/vce-deduction.mlir b/mlir/test/Dialect/SPIRV/Transforms/vce-deduction.mlir
index 58094b3c783d6..4e534a30ad516 100644
--- a/mlir/test/Dialect/SPIRV/Transforms/vce-deduction.mlir
+++ b/mlir/test/Dialect/SPIRV/Transforms/vce-deduction.mlir
@@ -21,7 +21,7 @@ spirv.module Logical GLSL450 attributes {
 // Test deducing minimal version.
 // spirv.GroupNonUniformBallot is available since v1.3.
 
-// CHECK: requires #spirv.vce<v1.3, [Shader, Matrix, GroupNonUniformBallot, GroupNonUniform], []>
+// CHECK: requires #spirv.vce<v1.3, [GroupNonUniformBallot, Shader, GroupNonUniform, Matrix], []>
 spirv.module Logical GLSL450 attributes {
   spirv.target_env = #spirv.target_env<
     #spirv.vce<v1.5, [Shader, GroupNonUniformBallot], []>, #spirv.resource_limits<>>
@@ -61,7 +61,7 @@ spirv.module Logical GLSL450 attributes {
 
 // Test Physical Storage Buffers are deduced correctly.
 
-// CHECK: spirv.module PhysicalStorageBuffer64 GLSL450 requires #spirv.vce<v1.0, [Shader, Matrix, PhysicalStorageBufferAddresses], [SPV_EXT_physical_storage_buffer]>
+// CHECK: spirv.module PhysicalStorageBuffer64 GLSL450 requires #spirv.vce<v1.0, [PhysicalStorageBufferAddresses, Shader, Matrix], [SPV_EXT_physical_storage_buffer]>
 spirv.module PhysicalStorageBuffer64 GLSL450 attributes {
   spirv.target_env = #spirv.target_env<
     #spirv.vce<v1.0, [PhysicalStorageBufferAddresses], [SPV_EXT_physical_storage_buffer]>, #spirv.resource_limits<>>
@@ -95,7 +95,7 @@ spirv.module Logical GLSL450 attributes {
 // * GroupNonUniformArithmetic
 // * GroupNonUniformBallot
 
-// CHECK: requires #spirv.vce<v1.3, [Shader, Matrix, GroupNonUniformArithmetic, GroupNonUniform], []>
+// CHECK: requires #spirv.vce<v1.3, [GroupNonUniformArithmetic, Shader, GroupNonUniform, Matrix], []>
 spirv.module Logical GLSL450 attributes {
   spirv.target_env = #spirv.target_env<
     #spirv.vce<v1.3, [Shader, GroupNonUniformArithmetic], []>, #spirv.resource_limits<>>
@@ -106,7 +106,7 @@ spirv.module Logical GLSL450 attributes {
   }
 }
 
-// CHECK: requires #spirv.vce<v1.3, [Shader, Matrix, GroupNonUniformBallot, GroupNonUniform, GroupNonUniformClustered], []>
+// CHECK: requires #spirv.vce<v1.3, [GroupNonUniformClustered, GroupNonUniformBallot, Shader, GroupNonUniform, Matrix], []>
 spirv.module Logical GLSL450 attributes {
   spirv.target_env = #spirv.target_env<
     #spirv.vce<v1.3, [Shader, GroupNonUniformClustered, GroupNonUniformBallot], []>, #spirv.resource_limits<>>
@@ -120,7 +120,7 @@ spirv.module Logical GLSL450 attributes {
 // Test type required capabilities
 
 // Using 8-bit integers in non-interface storage class requires Int8.
-// CHECK: requires #spirv.vce<v1.0, [Shader, Matrix, Int8], []>
+// CHECK: requires #spirv.vce<v1.0, [Int8, Shader, Matrix], []>
 spirv.module Logical GLSL450 attributes {
   spirv.target_env = #spirv.target_env<
     #spirv.vce<v1.3, [Shader, Int8], []>, #spirv.resource_limits<>>
@@ -132,7 +132,7 @@ spirv.module Logical GLSL450 attributes {
 }
 
 // Using 16-bit floats in non-interface storage class requires Float16.
-// CHECK: requires #spirv.vce<v1.0, [Shader, Matrix, Float16], []>
+// CHECK: requires #spirv.vce<v1.0, [Float16, Shader, Matrix], []>
 spirv.module Logical GLSL450 attributes {
   spirv.target_env = #spirv.target_env<
     #spirv.vce<v1.3, [Shader, Float16], []>, #spirv.resource_limits<>>
@@ -144,7 +144,7 @@ spirv.module Logical GLSL450 attributes {
 }
 
 // Using 16-element vectors requires Vector16.
-// CHECK: requires #spirv.vce<v1.0, [Shader, Matrix, Vector16, Kernel], []>
+// CHECK: requires #spirv.vce<v1.0, [Vector16, Shader, Kernel, Matrix], []>
 spirv.module Logical GLSL450 attributes {
   spirv.target_env = #spirv.target_env<
     #spirv.vce<v1.3, [Shader, Vector16], []>, #spirv.resource_limits<>>
@@ -162,7 +162,7 @@ spirv.module Logical GLSL450 attributes {
 // Test deducing minimal extensions.
 // spirv.KHR.SubgroupBallot requires the SPV_KHR_shader_ballot extension.
 
-// CHECK: requires #spirv.vce<v1.0, [Shader, Matrix, SubgroupBallotKHR], [SPV_KHR_shader_ballot]>
+// CHECK: requires #spirv.vce<v1.0, [SubgroupBallotKHR, Shader, Matrix], [SPV_KHR_shader_ballot]>
 spirv.module Logical GLSL450 attributes {
   spirv.target_env = #spirv.target_env<
     #spirv.vce<v1.0, [Shader, SubgroupBallotKHR],
@@ -193,7 +193,7 @@ spirv.module Logical Vulkan attributes {
 
 // Using 8-bit integers in interface storage class requires additional
 // extensions and capabilities.
-// CHECK: requires #spirv.vce<v1.0, [Int16, Shader, Matrix, StorageBuffer16BitAccess], [SPV_KHR_16bit_storage, SPV_KHR_storage_buffer_storage_class]>
+// CHECK: requires #spirv.vce<v1.0, [StorageBuffer16BitAccess, Shader, Int16, Matrix], [SPV_KHR_16bit_storage, SPV_KHR_storage_buffer_storage_class]>
 spirv.module Logical GLSL450 attributes {
   spirv.target_env = #spirv.target_env<
     #spirv.vce<v1.3, [Shader, StorageBuffer16BitAccess, Int16], []>, #spirv.resource_limits<>>
@@ -208,7 +208,7 @@ spirv.module Logical GLSL450 attributes {
 // Complicated nested types
 // * Buffer requires ImageBuffer or SampledBuffer.
 // * Rg32f requires StorageImageExtendedFormats.
-// CHECK: requires #spirv.vce<v1.0, [StorageImageExtendedFormats, ImageBuffer, SampledBuffer, Shader, Matrix, Int64, StorageUniform16, StorageBuffer16BitAccess, UniformAndStorageBuffer8BitAccess, StorageBuffer8BitAccess], [SPV_KHR_8bit_storage, SPV_KHR_16bit_storage]>
+// CHECK: requires #spirv.vce<v1.0, [UniformAndStorageBuffer8BitAccess, StorageUniform16, Int64, Shader, ImageBuffer, StorageImageExtendedFormats, StorageBuffer8BitAccess, StorageBuffer16BitAccess, Matrix, SampledBuffer], [SPV_KHR_8bit_storage, SPV_KHR_16bit_storage]>
 spirv.module Logical GLSL450 attributes {
   spirv.target_env = #spirv.target_env<
     #spirv.vce<v1.5, [Shader, UniformAndStorageBuffer8BitAccess, StorageBuffer16BitAccess, StorageUniform16, Int16, Int64, ImageBuffer, StorageImageExtendedFormats], []>,
@@ -219,7 +219,7 @@ spirv.module Logical GLSL450 attributes {
 }
 
 // Using bfloat16 requires BFloat16TypeKHR capability and SPV_KHR_bfloat16 extension.
-// CHECK: requires #spirv.vce<v1.0, [BFloat16TypeKHR, Shader, Matrix, StorageBuffer16BitAccess], [SPV_KHR_bfloat16, SPV_KHR_16bit_storage, SPV_KHR_storage_buffer_storage_class]>
+// CHECK: requires #spirv.vce<v1.0, [StorageBuffer16BitAccess, Shader, BFloat16TypeKHR, Matrix], [SPV_KHR_bfloat16, SPV_KHR_16bit_storage, SPV_KHR_storage_buffer_storage_class]>
 spirv.module Logical GLSL450 attributes {
   spirv.target_env = #spirv.target_env<
     #spirv.vce<v1.0, [Shader, StorageBuffer16BitAccess, BFloat16TypeKHR], [SPV_KHR_bfloat16, SPV_KHR_16bit_storage, SPV_KHR_storage_buffer_storage_class]>,

>From efac44894d1daad85a460d11d41a33493f2ebf87 Mon Sep 17 00:00:00 2001
From: Davide Grohmann <davide.grohmann at arm.com>
Date: Fri, 1 Aug 2025 15:22:08 +0200
Subject: [PATCH 5/5] Use getRecursiveImpliedCapabilities to simplify the code

Signed-off-by: Davide Grohmann <davide.grohmann at arm.com>
Change-Id: Ida9dfa4d9ee912b99d922932e58aa70c6b6f6865
---
 mlir/lib/Dialect/SPIRV/Transforms/UpdateVCEPass.cpp | 13 ++++---------
 1 file changed, 4 insertions(+), 9 deletions(-)

diff --git a/mlir/lib/Dialect/SPIRV/Transforms/UpdateVCEPass.cpp b/mlir/lib/Dialect/SPIRV/Transforms/UpdateVCEPass.cpp
index 83b827860730a..9d47d255bc15f 100644
--- a/mlir/lib/Dialect/SPIRV/Transforms/UpdateVCEPass.cpp
+++ b/mlir/lib/Dialect/SPIRV/Transforms/UpdateVCEPass.cpp
@@ -96,15 +96,10 @@ static LogicalResult checkAndUpdateCapabilityRequirements(
 }
 
 static void addAllImpliedCapabilities(SetVector<spirv::Capability> &caps) {
-  size_t old_size{0};
-  while (caps.size() > old_size) {
-    old_size = caps.size();
-    SetVector<spirv::Capability> tmp;
-    for (spirv::Capability cap : caps)
-      tmp.insert_range(getDirectImpliedCapabilities(cap));
-
-    caps.insert_range(tmp);
-  }
+  SetVector<spirv::Capability> tmp;
+  for (spirv::Capability cap : caps)
+    tmp.insert_range(getRecursiveImpliedCapabilities(cap));
+  caps.insert_range(tmp);
 }
 
 void UpdateVCEPass::runOnOperation() {



More information about the Mlir-commits mailing list