[Mlir-commits] [mlir] [mlir][spirv] Fix UpdateVCEPass to deduce the correct set of capabilities (PR #151108)
Davide Grohmann
llvmlistbot at llvm.org
Wed Jul 30 03:54:46 PDT 2025
https://github.com/davidegrohmann updated https://github.com/llvm/llvm-project/pull/151108
>From b5156d6735280ddb074558582c1fd11b6004a3c8 Mon Sep 17 00:00:00 2001
From: Davide Grohmann <davide.grohmann at arm.com>
Date: Tue, 29 Jul 2025 10:23:34 +0200
Subject: [PATCH 1/3] [mlir][spirv] Fix UpdateVCEPass to deduce the correct set
of capabilities
When deducing capabilities, implied capabilities are not considered,
which causes incorrect SPIR-V modules to be generated. This commit fixes
that by adding all implied capabilities to the deduced capability set.
Signed-off-by: Davide Grohmann <davide.grohmann at arm.com>
Change-Id: Ia30149fb35bbf0071010cb7bc92b86d2e5b6a6af
---
.../SPIRV/Transforms/UpdateVCEPass.cpp | 12 ++++++++
.../SPIRV/Transforms/vce-deduction.mlir | 30 +++++++++----------
2 files changed, 27 insertions(+), 15 deletions(-)
diff --git a/mlir/lib/Dialect/SPIRV/Transforms/UpdateVCEPass.cpp b/mlir/lib/Dialect/SPIRV/Transforms/UpdateVCEPass.cpp
index 6a9b951ca61d6..da316b98c2b20 100644
--- a/mlir/lib/Dialect/SPIRV/Transforms/UpdateVCEPass.cpp
+++ b/mlir/lib/Dialect/SPIRV/Transforms/UpdateVCEPass.cpp
@@ -95,6 +95,16 @@ static LogicalResult checkAndUpdateCapabilityRequirements(
return success();
}
+static SetVector<spirv::Capability>
+withImpliedCapabilities(SetVector<spirv::Capability> &caps) {
+ SetVector<spirv::Capability> allCaps(caps.begin(), caps.end());
+ for (auto cap : caps) {
+ ArrayRef<spirv::Capability> directCaps = getDirectImpliedCapabilities(cap);
+ allCaps.insert(directCaps.begin(), directCaps.end());
+ }
+ return allCaps;
+}
+
void UpdateVCEPass::runOnOperation() {
spirv::ModuleOp module = getOperation();
@@ -168,6 +178,8 @@ void UpdateVCEPass::runOnOperation() {
return WalkResult::interrupt();
}
+ deducedCapabilities = withImpliedCapabilities(deducedCapabilities);
+
return WalkResult::advance();
});
diff --git a/mlir/test/Dialect/SPIRV/Transforms/vce-deduction.mlir b/mlir/test/Dialect/SPIRV/Transforms/vce-deduction.mlir
index 2b237665ffc4a..b536b8e4003f9 100644
--- a/mlir/test/Dialect/SPIRV/Transforms/vce-deduction.mlir
+++ b/mlir/test/Dialect/SPIRV/Transforms/vce-deduction.mlir
@@ -7,7 +7,7 @@
// Test deducing minimal version.
// spirv.IAdd is available from v1.0.
-// CHECK: requires #spirv.vce<v1.0, [Shader], []>
+// CHECK: requires #spirv.vce<v1.0, [Shader, Matrix], []>
spirv.module Logical GLSL450 attributes {
spirv.target_env = #spirv.target_env<
#spirv.vce<v1.5, [Shader], []>, #spirv.resource_limits<>>
@@ -21,7 +21,7 @@ spirv.module Logical GLSL450 attributes {
// Test deducing minimal version.
// spirv.GroupNonUniformBallot is available since v1.3.
-// CHECK: requires #spirv.vce<v1.3, [GroupNonUniformBallot, Shader], []>
+// CHECK: requires #spirv.vce<v1.3, [GroupNonUniformBallot, GroupNonUniform, Shader, Matrix], []>
spirv.module Logical GLSL450 attributes {
spirv.target_env = #spirv.target_env<
#spirv.vce<v1.5, [Shader, GroupNonUniformBallot], []>, #spirv.resource_limits<>>
@@ -32,7 +32,7 @@ spirv.module Logical GLSL450 attributes {
}
}
-// CHECK: requires #spirv.vce<v1.4, [Shader], []>
+// CHECK: requires #spirv.vce<v1.4, [Shader, Matrix], []>
spirv.module Logical GLSL450 attributes {
spirv.target_env = #spirv.target_env<#spirv.vce<v1.6, [Shader], []>, #spirv.resource_limits<>>
} {
@@ -48,7 +48,7 @@ spirv.module Logical GLSL450 attributes {
// Test minimal capabilities.
-// CHECK: requires #spirv.vce<v1.0, [Shader], []>
+// CHECK: requires #spirv.vce<v1.0, [Shader, Matrix], []>
spirv.module Logical GLSL450 attributes {
spirv.target_env = #spirv.target_env<
#spirv.vce<v1.0, [Shader, Float16, Float64, Int16, Int64, VariablePointers], []>, #spirv.resource_limits<>>
@@ -61,7 +61,7 @@ spirv.module Logical GLSL450 attributes {
// Test Physical Storage Buffers are deduced correctly.
-// CHECK: spirv.module PhysicalStorageBuffer64 GLSL450 requires #spirv.vce<v1.0, [PhysicalStorageBufferAddresses, Shader], [SPV_EXT_physical_storage_buffer]>
+// CHECK: spirv.module PhysicalStorageBuffer64 GLSL450 requires #spirv.vce<v1.0, [PhysicalStorageBufferAddresses, Shader, Matrix], [SPV_EXT_physical_storage_buffer]>
spirv.module PhysicalStorageBuffer64 GLSL450 attributes {
spirv.target_env = #spirv.target_env<
#spirv.vce<v1.0, [Shader, PhysicalStorageBufferAddresses], [SPV_EXT_physical_storage_buffer]>, #spirv.resource_limits<>>
@@ -74,7 +74,7 @@ spirv.module PhysicalStorageBuffer64 GLSL450 attributes {
// Test deducing implied capability.
// AtomicStorage implies Shader.
-// CHECK: requires #spirv.vce<v1.0, [Shader], []>
+// CHECK: requires #spirv.vce<v1.0, [Shader, Matrix], []>
spirv.module Logical GLSL450 attributes {
spirv.target_env = #spirv.target_env<
#spirv.vce<v1.0, [AtomicStorage], []>, #spirv.resource_limits<>>
@@ -95,7 +95,7 @@ spirv.module Logical GLSL450 attributes {
// * GroupNonUniformArithmetic
// * GroupNonUniformBallot
-// CHECK: requires #spirv.vce<v1.3, [GroupNonUniformArithmetic, Shader], []>
+// CHECK: requires #spirv.vce<v1.3, [GroupNonUniformArithmetic, GroupNonUniform, Shader, Matrix], []>
spirv.module Logical GLSL450 attributes {
spirv.target_env = #spirv.target_env<
#spirv.vce<v1.3, [Shader, GroupNonUniformArithmetic], []>, #spirv.resource_limits<>>
@@ -106,7 +106,7 @@ spirv.module Logical GLSL450 attributes {
}
}
-// CHECK: requires #spirv.vce<v1.3, [GroupNonUniformClustered, GroupNonUniformBallot, Shader], []>
+// CHECK: requires #spirv.vce<v1.3, [GroupNonUniformClustered, GroupNonUniformBallot, GroupNonUniform, Shader, Matrix], []>
spirv.module Logical GLSL450 attributes {
spirv.target_env = #spirv.target_env<
#spirv.vce<v1.3, [Shader, GroupNonUniformClustered, GroupNonUniformBallot], []>, #spirv.resource_limits<>>
@@ -120,7 +120,7 @@ spirv.module Logical GLSL450 attributes {
// Test type required capabilities
// Using 8-bit integers in non-interface storage class requires Int8.
-// CHECK: requires #spirv.vce<v1.0, [Int8, Shader], []>
+// CHECK: requires #spirv.vce<v1.0, [Int8, Shader, Matrix], []>
spirv.module Logical GLSL450 attributes {
spirv.target_env = #spirv.target_env<
#spirv.vce<v1.3, [Shader, Int8], []>, #spirv.resource_limits<>>
@@ -132,7 +132,7 @@ spirv.module Logical GLSL450 attributes {
}
// Using 16-bit floats in non-interface storage class requires Float16.
-// CHECK: requires #spirv.vce<v1.0, [Float16, Shader], []>
+// CHECK: requires #spirv.vce<v1.0, [Float16, Shader, Matrix], []>
spirv.module Logical GLSL450 attributes {
spirv.target_env = #spirv.target_env<
#spirv.vce<v1.3, [Shader, Float16], []>, #spirv.resource_limits<>>
@@ -144,7 +144,7 @@ spirv.module Logical GLSL450 attributes {
}
// Using 16-element vectors requires Vector16.
-// CHECK: requires #spirv.vce<v1.0, [Vector16, Shader], []>
+// CHECK: requires #spirv.vce<v1.0, [Vector16, Kernel, Shader, Matrix], []>
spirv.module Logical GLSL450 attributes {
spirv.target_env = #spirv.target_env<
#spirv.vce<v1.3, [Shader, Vector16], []>, #spirv.resource_limits<>>
@@ -162,7 +162,7 @@ spirv.module Logical GLSL450 attributes {
// Test deducing minimal extensions.
// spirv.KHR.SubgroupBallot requires the SPV_KHR_shader_ballot extension.
-// CHECK: requires #spirv.vce<v1.0, [SubgroupBallotKHR, Shader], [SPV_KHR_shader_ballot]>
+// CHECK: requires #spirv.vce<v1.0, [SubgroupBallotKHR, Shader, Matrix], [SPV_KHR_shader_ballot]>
spirv.module Logical GLSL450 attributes {
spirv.target_env = #spirv.target_env<
#spirv.vce<v1.0, [Shader, SubgroupBallotKHR],
@@ -193,7 +193,7 @@ spirv.module Logical Vulkan attributes {
// Using 8-bit integers in interface storage class requires additional
// extensions and capabilities.
-// CHECK: requires #spirv.vce<v1.0, [StorageBuffer16BitAccess, Shader, Int16], [SPV_KHR_16bit_storage, SPV_KHR_storage_buffer_storage_class]>
+// CHECK: requires #spirv.vce<v1.0, [StorageBuffer16BitAccess, Shader, Int16, Matrix], [SPV_KHR_16bit_storage, SPV_KHR_storage_buffer_storage_class]>
spirv.module Logical GLSL450 attributes {
spirv.target_env = #spirv.target_env<
#spirv.vce<v1.3, [Shader, StorageBuffer16BitAccess, Int16], []>, #spirv.resource_limits<>>
@@ -208,7 +208,7 @@ spirv.module Logical GLSL450 attributes {
// Complicated nested types
// * Buffer requires ImageBuffer or SampledBuffer.
// * Rg32f requires StorageImageExtendedFormats.
-// CHECK: requires #spirv.vce<v1.0, [UniformAndStorageBuffer8BitAccess, StorageUniform16, Int64, Shader, ImageBuffer, StorageImageExtendedFormats], [SPV_KHR_8bit_storage, SPV_KHR_16bit_storage]>
+// CHECK: requires #spirv.vce<v1.0, [UniformAndStorageBuffer8BitAccess, StorageUniform16, Int64, Shader, StorageBuffer8BitAccess, StorageBuffer16BitAccess, Matrix, ImageBuffer, StorageImageExtendedFormats, SampledBuffer], [SPV_KHR_8bit_storage, SPV_KHR_16bit_storage]>
spirv.module Logical GLSL450 attributes {
spirv.target_env = #spirv.target_env<
#spirv.vce<v1.5, [Shader, UniformAndStorageBuffer8BitAccess, StorageBuffer16BitAccess, StorageUniform16, Int16, Int64, ImageBuffer, StorageImageExtendedFormats], []>,
@@ -219,7 +219,7 @@ spirv.module Logical GLSL450 attributes {
}
// Using bfloat16 requires BFloat16TypeKHR capability and SPV_KHR_bfloat16 extension.
-// CHECK: requires #spirv.vce<v1.0, [StorageBuffer16BitAccess, Shader, BFloat16TypeKHR], [SPV_KHR_bfloat16, SPV_KHR_16bit_storage, SPV_KHR_storage_buffer_storage_class]>
+// CHECK: requires #spirv.vce<v1.0, [StorageBuffer16BitAccess, Shader, BFloat16TypeKHR, Matrix], [SPV_KHR_bfloat16, SPV_KHR_16bit_storage, SPV_KHR_storage_buffer_storage_class]>
spirv.module Logical GLSL450 attributes {
spirv.target_env = #spirv.target_env<
#spirv.vce<v1.0, [Shader, StorageBuffer16BitAccess, BFloat16TypeKHR], [SPV_KHR_bfloat16, SPV_KHR_16bit_storage, SPV_KHR_storage_buffer_storage_class]>,
>From 02dac82f88b8ae590942d53cb850e163b45b50be Mon Sep 17 00:00:00 2001
From: Davide Grohmann <davide.grohmann at arm.com>
Date: Wed, 30 Jul 2025 10:19:00 +0200
Subject: [PATCH 2/3] Resolve code review comments
Signed-off-by: Davide Grohmann <davide.grohmann at arm.com>
Change-Id: Ib58ef4d1d24e395678c9527abdd7e96a9b1df9eb
---
.../SPIRV/Transforms/UpdateVCEPass.cpp | 17 +++++++++-----
.../SPIRV/Transforms/vce-deduction.mlir | 22 +++++++++----------
2 files changed, 22 insertions(+), 17 deletions(-)
diff --git a/mlir/lib/Dialect/SPIRV/Transforms/UpdateVCEPass.cpp b/mlir/lib/Dialect/SPIRV/Transforms/UpdateVCEPass.cpp
index da316b98c2b20..ae79c39c29b46 100644
--- a/mlir/lib/Dialect/SPIRV/Transforms/UpdateVCEPass.cpp
+++ b/mlir/lib/Dialect/SPIRV/Transforms/UpdateVCEPass.cpp
@@ -96,11 +96,16 @@ static LogicalResult checkAndUpdateCapabilityRequirements(
}
static SetVector<spirv::Capability>
-withImpliedCapabilities(SetVector<spirv::Capability> &caps) {
- SetVector<spirv::Capability> allCaps(caps.begin(), caps.end());
- for (auto cap : caps) {
- ArrayRef<spirv::Capability> directCaps = getDirectImpliedCapabilities(cap);
- allCaps.insert(directCaps.begin(), directCaps.end());
+addAllImpliedCapabilities(SetVector<spirv::Capability> &caps) {
+ SetVector<spirv::Capability> allCaps;
+ while (!caps.empty()) {
+ spirv::Capability cap = caps.pop_back_val();
+ allCaps.insert(cap);
+ ArrayRef<spirv::Capability> impliedCaps = getDirectImpliedCapabilities(cap);
+ for (spirv::Capability impliedCap : impliedCaps) {
+ if (!allCaps.contains(impliedCap))
+ caps.insert(impliedCap);
+ }
}
return allCaps;
}
@@ -178,7 +183,7 @@ void UpdateVCEPass::runOnOperation() {
return WalkResult::interrupt();
}
- deducedCapabilities = withImpliedCapabilities(deducedCapabilities);
+ deducedCapabilities = addAllImpliedCapabilities(deducedCapabilities);
return WalkResult::advance();
});
diff --git a/mlir/test/Dialect/SPIRV/Transforms/vce-deduction.mlir b/mlir/test/Dialect/SPIRV/Transforms/vce-deduction.mlir
index b536b8e4003f9..9410435bbea99 100644
--- a/mlir/test/Dialect/SPIRV/Transforms/vce-deduction.mlir
+++ b/mlir/test/Dialect/SPIRV/Transforms/vce-deduction.mlir
@@ -21,7 +21,7 @@ spirv.module Logical GLSL450 attributes {
// Test deducing minimal version.
// spirv.GroupNonUniformBallot is available since v1.3.
-// CHECK: requires #spirv.vce<v1.3, [GroupNonUniformBallot, GroupNonUniform, Shader, Matrix], []>
+// CHECK: requires #spirv.vce<v1.3, [Shader, Matrix, GroupNonUniform, GroupNonUniformBallot], []>
spirv.module Logical GLSL450 attributes {
spirv.target_env = #spirv.target_env<
#spirv.vce<v1.5, [Shader, GroupNonUniformBallot], []>, #spirv.resource_limits<>>
@@ -61,7 +61,7 @@ spirv.module Logical GLSL450 attributes {
// Test Physical Storage Buffers are deduced correctly.
-// CHECK: spirv.module PhysicalStorageBuffer64 GLSL450 requires #spirv.vce<v1.0, [PhysicalStorageBufferAddresses, Shader, Matrix], [SPV_EXT_physical_storage_buffer]>
+// CHECK: spirv.module PhysicalStorageBuffer64 GLSL450 requires #spirv.vce<v1.0, [Shader, Matrix, PhysicalStorageBufferAddresses], [SPV_EXT_physical_storage_buffer]>
spirv.module PhysicalStorageBuffer64 GLSL450 attributes {
spirv.target_env = #spirv.target_env<
#spirv.vce<v1.0, [Shader, PhysicalStorageBufferAddresses], [SPV_EXT_physical_storage_buffer]>, #spirv.resource_limits<>>
@@ -95,7 +95,7 @@ spirv.module Logical GLSL450 attributes {
// * GroupNonUniformArithmetic
// * GroupNonUniformBallot
-// CHECK: requires #spirv.vce<v1.3, [GroupNonUniformArithmetic, GroupNonUniform, Shader, Matrix], []>
+// CHECK: requires #spirv.vce<v1.3, [Shader, Matrix, GroupNonUniform, GroupNonUniformArithmetic], []>
spirv.module Logical GLSL450 attributes {
spirv.target_env = #spirv.target_env<
#spirv.vce<v1.3, [Shader, GroupNonUniformArithmetic], []>, #spirv.resource_limits<>>
@@ -106,7 +106,7 @@ spirv.module Logical GLSL450 attributes {
}
}
-// CHECK: requires #spirv.vce<v1.3, [GroupNonUniformClustered, GroupNonUniformBallot, GroupNonUniform, Shader, Matrix], []>
+// CHECK: requires #spirv.vce<v1.3, [Shader, Matrix, GroupNonUniformClustered, GroupNonUniform, GroupNonUniformBallot], []>
spirv.module Logical GLSL450 attributes {
spirv.target_env = #spirv.target_env<
#spirv.vce<v1.3, [Shader, GroupNonUniformClustered, GroupNonUniformBallot], []>, #spirv.resource_limits<>>
@@ -120,7 +120,7 @@ spirv.module Logical GLSL450 attributes {
// Test type required capabilities
// Using 8-bit integers in non-interface storage class requires Int8.
-// CHECK: requires #spirv.vce<v1.0, [Int8, Shader, Matrix], []>
+// CHECK: requires #spirv.vce<v1.0, [Shader, Matrix, Int8], []>
spirv.module Logical GLSL450 attributes {
spirv.target_env = #spirv.target_env<
#spirv.vce<v1.3, [Shader, Int8], []>, #spirv.resource_limits<>>
@@ -132,7 +132,7 @@ spirv.module Logical GLSL450 attributes {
}
// Using 16-bit floats in non-interface storage class requires Float16.
-// CHECK: requires #spirv.vce<v1.0, [Float16, Shader, Matrix], []>
+// CHECK: requires #spirv.vce<v1.0, [Shader, Matrix, Float16], []>
spirv.module Logical GLSL450 attributes {
spirv.target_env = #spirv.target_env<
#spirv.vce<v1.3, [Shader, Float16], []>, #spirv.resource_limits<>>
@@ -144,7 +144,7 @@ spirv.module Logical GLSL450 attributes {
}
// Using 16-element vectors requires Vector16.
-// CHECK: requires #spirv.vce<v1.0, [Vector16, Kernel, Shader, Matrix], []>
+// CHECK: requires #spirv.vce<v1.0, [Shader, Matrix, Kernel, Vector16], []>
spirv.module Logical GLSL450 attributes {
spirv.target_env = #spirv.target_env<
#spirv.vce<v1.3, [Shader, Vector16], []>, #spirv.resource_limits<>>
@@ -162,7 +162,7 @@ spirv.module Logical GLSL450 attributes {
// Test deducing minimal extensions.
// spirv.KHR.SubgroupBallot requires the SPV_KHR_shader_ballot extension.
-// CHECK: requires #spirv.vce<v1.0, [SubgroupBallotKHR, Shader, Matrix], [SPV_KHR_shader_ballot]>
+// CHECK: requires #spirv.vce<v1.0, [Shader, Matrix, SubgroupBallotKHR], [SPV_KHR_shader_ballot]>
spirv.module Logical GLSL450 attributes {
spirv.target_env = #spirv.target_env<
#spirv.vce<v1.0, [Shader, SubgroupBallotKHR],
@@ -193,7 +193,7 @@ spirv.module Logical Vulkan attributes {
// Using 8-bit integers in interface storage class requires additional
// extensions and capabilities.
-// CHECK: requires #spirv.vce<v1.0, [StorageBuffer16BitAccess, Shader, Int16, Matrix], [SPV_KHR_16bit_storage, SPV_KHR_storage_buffer_storage_class]>
+// CHECK: requires #spirv.vce<v1.0, [Int16, Shader, Matrix, StorageBuffer16BitAccess], [SPV_KHR_16bit_storage, SPV_KHR_storage_buffer_storage_class]>
spirv.module Logical GLSL450 attributes {
spirv.target_env = #spirv.target_env<
#spirv.vce<v1.3, [Shader, StorageBuffer16BitAccess, Int16], []>, #spirv.resource_limits<>>
@@ -208,7 +208,7 @@ spirv.module Logical GLSL450 attributes {
// Complicated nested types
// * Buffer requires ImageBuffer or SampledBuffer.
// * Rg32f requires StorageImageExtendedFormats.
-// CHECK: requires #spirv.vce<v1.0, [UniformAndStorageBuffer8BitAccess, StorageUniform16, Int64, Shader, StorageBuffer8BitAccess, StorageBuffer16BitAccess, Matrix, ImageBuffer, StorageImageExtendedFormats, SampledBuffer], [SPV_KHR_8bit_storage, SPV_KHR_16bit_storage]>
+// CHECK: requires #spirv.vce<v1.0, [Shader, Matrix, Int64, StorageUniform16, StorageBuffer16BitAccess, UniformAndStorageBuffer8BitAccess, StorageBuffer8BitAccess, SampledBuffer, ImageBuffer, StorageImageExtendedFormats], [SPV_KHR_8bit_storage, SPV_KHR_16bit_storage]>
spirv.module Logical GLSL450 attributes {
spirv.target_env = #spirv.target_env<
#spirv.vce<v1.5, [Shader, UniformAndStorageBuffer8BitAccess, StorageBuffer16BitAccess, StorageUniform16, Int16, Int64, ImageBuffer, StorageImageExtendedFormats], []>,
@@ -219,7 +219,7 @@ spirv.module Logical GLSL450 attributes {
}
// Using bfloat16 requires BFloat16TypeKHR capability and SPV_KHR_bfloat16 extension.
-// CHECK: requires #spirv.vce<v1.0, [StorageBuffer16BitAccess, Shader, BFloat16TypeKHR, Matrix], [SPV_KHR_bfloat16, SPV_KHR_16bit_storage, SPV_KHR_storage_buffer_storage_class]>
+// CHECK: requires #spirv.vce<v1.0, [StorageBuffer16BitAccess, Matrix, Shader, BFloat16TypeKHR], [SPV_KHR_bfloat16, SPV_KHR_16bit_storage, SPV_KHR_storage_buffer_storage_class]>
spirv.module Logical GLSL450 attributes {
spirv.target_env = #spirv.target_env<
#spirv.vce<v1.0, [Shader, StorageBuffer16BitAccess, BFloat16TypeKHR], [SPV_KHR_bfloat16, SPV_KHR_16bit_storage, SPV_KHR_storage_buffer_storage_class]>,
>From f0e913952a9673146ff5ed9a442e530917e11c69 Mon Sep 17 00:00:00 2001
From: Davide Grohmann <davide.grohmann at arm.com>
Date: Wed, 30 Jul 2025 12:53:17 +0200
Subject: [PATCH 3/3] More improvements from code review
Signed-off-by: Davide Grohmann <davide.grohmann at arm.com>
Change-Id: I34150644e4bcf559597b3c3b3dbb668e5c828faf
---
.../SPIRV/Transforms/UpdateVCEPass.cpp | 20 ++++++--------------
.../SPIRV/Transforms/vce-deduction.mlir | 22 +++++++++++-----------
2 files changed, 17 insertions(+), 25 deletions(-)
diff --git a/mlir/lib/Dialect/SPIRV/Transforms/UpdateVCEPass.cpp b/mlir/lib/Dialect/SPIRV/Transforms/UpdateVCEPass.cpp
index ae79c39c29b46..9b1c84ee66156 100644
--- a/mlir/lib/Dialect/SPIRV/Transforms/UpdateVCEPass.cpp
+++ b/mlir/lib/Dialect/SPIRV/Transforms/UpdateVCEPass.cpp
@@ -95,19 +95,11 @@ static LogicalResult checkAndUpdateCapabilityRequirements(
   return success();
 }
 
-static SetVector<spirv::Capability>
-addAllImpliedCapabilities(SetVector<spirv::Capability> &caps) {
-  SetVector<spirv::Capability> allCaps;
-  while (!caps.empty()) {
-    spirv::Capability cap = caps.pop_back_val();
-    allCaps.insert(cap);
-    ArrayRef<spirv::Capability> impliedCaps = getDirectImpliedCapabilities(cap);
-    for (spirv::Capability impliedCap : impliedCaps) {
-      if (!allCaps.contains(impliedCap))
-        caps.insert(impliedCap);
-    }
-  }
-  return allCaps;
+/// Appends every capability transitively implied by `caps` to `caps`.
+static void addAllImpliedCapabilities(SetVector<spirv::Capability> &caps) {
+  // Index loop: insert_range() may reallocate and invalidate iterators.
+  for (size_t i = 0; i < caps.size(); ++i)
+    caps.insert_range(getDirectImpliedCapabilities(caps[i]));
 }
void UpdateVCEPass::runOnOperation() {
@@ -183,7 +175,7 @@ void UpdateVCEPass::runOnOperation() {
return WalkResult::interrupt();
}
- deducedCapabilities = addAllImpliedCapabilities(deducedCapabilities);
+ addAllImpliedCapabilities(deducedCapabilities);
return WalkResult::advance();
});
diff --git a/mlir/test/Dialect/SPIRV/Transforms/vce-deduction.mlir b/mlir/test/Dialect/SPIRV/Transforms/vce-deduction.mlir
index 9410435bbea99..b536b8e4003f9 100644
--- a/mlir/test/Dialect/SPIRV/Transforms/vce-deduction.mlir
+++ b/mlir/test/Dialect/SPIRV/Transforms/vce-deduction.mlir
@@ -21,7 +21,7 @@ spirv.module Logical GLSL450 attributes {
// Test deducing minimal version.
// spirv.GroupNonUniformBallot is available since v1.3.
-// CHECK: requires #spirv.vce<v1.3, [Shader, Matrix, GroupNonUniform, GroupNonUniformBallot], []>
+// CHECK: requires #spirv.vce<v1.3, [GroupNonUniformBallot, GroupNonUniform, Shader, Matrix], []>
spirv.module Logical GLSL450 attributes {
spirv.target_env = #spirv.target_env<
#spirv.vce<v1.5, [Shader, GroupNonUniformBallot], []>, #spirv.resource_limits<>>
@@ -61,7 +61,7 @@ spirv.module Logical GLSL450 attributes {
// Test Physical Storage Buffers are deduced correctly.
-// CHECK: spirv.module PhysicalStorageBuffer64 GLSL450 requires #spirv.vce<v1.0, [Shader, Matrix, PhysicalStorageBufferAddresses], [SPV_EXT_physical_storage_buffer]>
+// CHECK: spirv.module PhysicalStorageBuffer64 GLSL450 requires #spirv.vce<v1.0, [PhysicalStorageBufferAddresses, Shader, Matrix], [SPV_EXT_physical_storage_buffer]>
spirv.module PhysicalStorageBuffer64 GLSL450 attributes {
spirv.target_env = #spirv.target_env<
#spirv.vce<v1.0, [Shader, PhysicalStorageBufferAddresses], [SPV_EXT_physical_storage_buffer]>, #spirv.resource_limits<>>
@@ -95,7 +95,7 @@ spirv.module Logical GLSL450 attributes {
// * GroupNonUniformArithmetic
// * GroupNonUniformBallot
-// CHECK: requires #spirv.vce<v1.3, [Shader, Matrix, GroupNonUniform, GroupNonUniformArithmetic], []>
+// CHECK: requires #spirv.vce<v1.3, [GroupNonUniformArithmetic, GroupNonUniform, Shader, Matrix], []>
spirv.module Logical GLSL450 attributes {
spirv.target_env = #spirv.target_env<
#spirv.vce<v1.3, [Shader, GroupNonUniformArithmetic], []>, #spirv.resource_limits<>>
@@ -106,7 +106,7 @@ spirv.module Logical GLSL450 attributes {
}
}
-// CHECK: requires #spirv.vce<v1.3, [Shader, Matrix, GroupNonUniformClustered, GroupNonUniform, GroupNonUniformBallot], []>
+// CHECK: requires #spirv.vce<v1.3, [GroupNonUniformClustered, GroupNonUniformBallot, GroupNonUniform, Shader, Matrix], []>
spirv.module Logical GLSL450 attributes {
spirv.target_env = #spirv.target_env<
#spirv.vce<v1.3, [Shader, GroupNonUniformClustered, GroupNonUniformBallot], []>, #spirv.resource_limits<>>
@@ -120,7 +120,7 @@ spirv.module Logical GLSL450 attributes {
// Test type required capabilities
// Using 8-bit integers in non-interface storage class requires Int8.
-// CHECK: requires #spirv.vce<v1.0, [Shader, Matrix, Int8], []>
+// CHECK: requires #spirv.vce<v1.0, [Int8, Shader, Matrix], []>
spirv.module Logical GLSL450 attributes {
spirv.target_env = #spirv.target_env<
#spirv.vce<v1.3, [Shader, Int8], []>, #spirv.resource_limits<>>
@@ -132,7 +132,7 @@ spirv.module Logical GLSL450 attributes {
}
// Using 16-bit floats in non-interface storage class requires Float16.
-// CHECK: requires #spirv.vce<v1.0, [Shader, Matrix, Float16], []>
+// CHECK: requires #spirv.vce<v1.0, [Float16, Shader, Matrix], []>
spirv.module Logical GLSL450 attributes {
spirv.target_env = #spirv.target_env<
#spirv.vce<v1.3, [Shader, Float16], []>, #spirv.resource_limits<>>
@@ -144,7 +144,7 @@ spirv.module Logical GLSL450 attributes {
}
// Using 16-element vectors requires Vector16.
-// CHECK: requires #spirv.vce<v1.0, [Shader, Matrix, Kernel, Vector16], []>
+// CHECK: requires #spirv.vce<v1.0, [Vector16, Kernel, Shader, Matrix], []>
spirv.module Logical GLSL450 attributes {
spirv.target_env = #spirv.target_env<
#spirv.vce<v1.3, [Shader, Vector16], []>, #spirv.resource_limits<>>
@@ -162,7 +162,7 @@ spirv.module Logical GLSL450 attributes {
// Test deducing minimal extensions.
// spirv.KHR.SubgroupBallot requires the SPV_KHR_shader_ballot extension.
-// CHECK: requires #spirv.vce<v1.0, [Shader, Matrix, SubgroupBallotKHR], [SPV_KHR_shader_ballot]>
+// CHECK: requires #spirv.vce<v1.0, [SubgroupBallotKHR, Shader, Matrix], [SPV_KHR_shader_ballot]>
spirv.module Logical GLSL450 attributes {
spirv.target_env = #spirv.target_env<
#spirv.vce<v1.0, [Shader, SubgroupBallotKHR],
@@ -193,7 +193,7 @@ spirv.module Logical Vulkan attributes {
// Using 8-bit integers in interface storage class requires additional
// extensions and capabilities.
-// CHECK: requires #spirv.vce<v1.0, [Int16, Shader, Matrix, StorageBuffer16BitAccess], [SPV_KHR_16bit_storage, SPV_KHR_storage_buffer_storage_class]>
+// CHECK: requires #spirv.vce<v1.0, [StorageBuffer16BitAccess, Shader, Int16, Matrix], [SPV_KHR_16bit_storage, SPV_KHR_storage_buffer_storage_class]>
spirv.module Logical GLSL450 attributes {
spirv.target_env = #spirv.target_env<
#spirv.vce<v1.3, [Shader, StorageBuffer16BitAccess, Int16], []>, #spirv.resource_limits<>>
@@ -208,7 +208,7 @@ spirv.module Logical GLSL450 attributes {
// Complicated nested types
// * Buffer requires ImageBuffer or SampledBuffer.
// * Rg32f requires StorageImageExtendedFormats.
-// CHECK: requires #spirv.vce<v1.0, [Shader, Matrix, Int64, StorageUniform16, StorageBuffer16BitAccess, UniformAndStorageBuffer8BitAccess, StorageBuffer8BitAccess, SampledBuffer, ImageBuffer, StorageImageExtendedFormats], [SPV_KHR_8bit_storage, SPV_KHR_16bit_storage]>
+// CHECK: requires #spirv.vce<v1.0, [UniformAndStorageBuffer8BitAccess, StorageUniform16, Int64, Shader, StorageBuffer8BitAccess, StorageBuffer16BitAccess, Matrix, ImageBuffer, StorageImageExtendedFormats, SampledBuffer], [SPV_KHR_8bit_storage, SPV_KHR_16bit_storage]>
spirv.module Logical GLSL450 attributes {
spirv.target_env = #spirv.target_env<
#spirv.vce<v1.5, [Shader, UniformAndStorageBuffer8BitAccess, StorageBuffer16BitAccess, StorageUniform16, Int16, Int64, ImageBuffer, StorageImageExtendedFormats], []>,
@@ -219,7 +219,7 @@ spirv.module Logical GLSL450 attributes {
}
// Using bfloat16 requires BFloat16TypeKHR capability and SPV_KHR_bfloat16 extension.
-// CHECK: requires #spirv.vce<v1.0, [StorageBuffer16BitAccess, Matrix, Shader, BFloat16TypeKHR], [SPV_KHR_bfloat16, SPV_KHR_16bit_storage, SPV_KHR_storage_buffer_storage_class]>
+// CHECK: requires #spirv.vce<v1.0, [StorageBuffer16BitAccess, Shader, BFloat16TypeKHR, Matrix], [SPV_KHR_bfloat16, SPV_KHR_16bit_storage, SPV_KHR_storage_buffer_storage_class]>
spirv.module Logical GLSL450 attributes {
spirv.target_env = #spirv.target_env<
#spirv.vce<v1.0, [Shader, StorageBuffer16BitAccess, BFloat16TypeKHR], [SPV_KHR_bfloat16, SPV_KHR_16bit_storage, SPV_KHR_storage_buffer_storage_class]>,
More information about the Mlir-commits
mailing list