[Mlir-commits] [mlir] [mlir][spirv] Verify vector types when parsing (PR #180430)
Jakub Kuderski
llvmlistbot at llvm.org
Sun Feb 8 10:18:57 PST 2026
https://github.com/kuhar created https://github.com/llvm/llvm-project/pull/180430
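Also remove the now-invalid test cases in replicated-const-composites.mlir that used one-element vectors (e.g. vector<1xi32>, vector<1xf32>), which the parser now rejects.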
Fixes: https://github.com/llvm/llvm-project/issues/180419
>From 81dabe7e5c214c09a50c89d0d7dfb3cb28aa33a2 Mon Sep 17 00:00:00 2001
From: Jakub Kuderski <jakub at nod-labs.com>
Date: Sun, 8 Feb 2026 13:16:16 -0500
Subject: [PATCH] [mlir][spirv] Verify vector types when parsing
---
mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp | 15 ++++++++--
mlir/test/Dialect/SPIRV/IR/types.mlir | 10 +++++++
.../replicated-const-composites.mlir | 30 -------------------
3 files changed, 23 insertions(+), 32 deletions(-)
diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp
index 7c3bfd72115e6..78f33c238d414 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp
@@ -167,11 +167,11 @@ static Type parseAndVerifyType(SPIRVDialect const &dialect,
if (parser.parseType(type))
return Type();
- // Allow SPIR-V dialect types
+ // Allow SPIR-V dialect types.
if (&type.getDialect() == &dialect)
return type;
- // Check other allowed types
+ // Check other allowed types.
if (auto t = dyn_cast<FloatType>(type)) {
// TODO: All float types are allowed for now, but this should be fixed.
} else if (auto t = dyn_cast<IntegerType>(type)) {
@@ -186,12 +186,23 @@ static Type parseAndVerifyType(SPIRVDialect const &dialect,
parser.emitError(typeLoc, "only 1-D vector allowed but found ") << t;
return Type();
}
+ if (t.getNumElements() < 2) {
+ parser.emitError(typeLoc, "SPIR-V does not allow one-element vectors");
+ return Type();
+ }
if (t.getNumElements() > 4) {
parser.emitError(
typeLoc, "vector length has to be less than or equal to 4 but found ")
<< t.getNumElements();
return Type();
}
+ if (!isa<ScalarType>(t.getElementType())) {
+ parser.emitError(
+ typeLoc,
+ "vector element type must be a SPIR-V scalar type but found ")
+ << t.getElementType();
+ return Type();
+ }
} else if (auto t = dyn_cast<TensorArmType>(type)) {
if (!isa<ScalarType>(t.getElementType())) {
parser.emitError(
diff --git a/mlir/test/Dialect/SPIRV/IR/types.mlir b/mlir/test/Dialect/SPIRV/IR/types.mlir
index 145a291343504..98509fb376acf 100644
--- a/mlir/test/Dialect/SPIRV/IR/types.mlir
+++ b/mlir/test/Dialect/SPIRV/IR/types.mlir
@@ -120,6 +120,16 @@ func.func private @unknown_storage_class(!spirv.ptr<f32, SomeStorageClass>) -> (
// -----
+// expected-error @+1 {{SPIR-V does not allow one-element vectors}}
+func.func private @invalid_ptr_to_vector_one_element(!spirv.ptr<vector<1xi32>, SomeStorageClass>) -> ()
+
+// -----
+
+// expected-error @+1 {{vector element type must be a SPIR-V scalar type}}
+func.func private @invalid_ptr_to_vector_index(!spirv.ptr<vector<2xindex>, SomeStorageClass>) -> ()
+
+// -----
+
//===----------------------------------------------------------------------===//
// RuntimeArrayType
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/SPIRV/Transforms/replicated-const-composites.mlir b/mlir/test/Dialect/SPIRV/Transforms/replicated-const-composites.mlir
index 56e26eee83ff9..0a413e5036be9 100644
--- a/mlir/test/Dialect/SPIRV/Transforms/replicated-const-composites.mlir
+++ b/mlir/test/Dialect/SPIRV/Transforms/replicated-const-composites.mlir
@@ -55,18 +55,6 @@ spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader, ReplicatedCompos
spirv.ReturnValue %0 : !spirv.array<1 x !spirv.array<2 x vector<2xi32>>>
}
- spirv.func @array_of_one_splat_array_of_vector_of_one_i32() -> !spirv.array<1 x !spirv.array<2 x vector<1xi32>>> "None" {
- // CHECK: {{%.*}} = spirv.EXT.ConstantCompositeReplicate [dense<1> : vector<1xi32>] : !spirv.array<1 x !spirv.array<2 x vector<1xi32>
- %cst = spirv.Constant [[dense<1> : vector<1xi32>], [dense<1> : vector<1xi32>]] : !spirv.array<1 x !spirv.array<2 x vector<1xi32>>>
- spirv.ReturnValue %cst : !spirv.array<1 x !spirv.array<2 x vector<1xi32>>>
- }
-
- spirv.func @splat_array_of_array_of_one_vector_of_one_i32() -> (!spirv.array<2 x !spirv.array<1 x vector<1xi32>>>) "None" {
- // CHECK: {{%.*}} = spirv.EXT.ConstantCompositeReplicate [dense<1> : vector<1xi32>] : !spirv.array<2 x !spirv.array<1 x vector<1xi32>>>
- %0 = spirv.Constant [[dense<1> : vector<1xi32>], [dense<1> : vector<1xi32>]] : !spirv.array<2 x !spirv.array<1 x vector<1xi32>>>
- spirv.ReturnValue %0 : !spirv.array<2 x !spirv.array<1 x vector<1xi32>>>
- }
-
spirv.func @array_of_one_array_of_one_splat_vector_of_i32() -> (!spirv.array<1 x !spirv.array<1 x vector<2xi32>>>) "None" {
// CHECK: {{%.*}} = spirv.EXT.ConstantCompositeReplicate [1 : i32] : !spirv.array<1 x !spirv.array<1 x vector<2xi32>>>
%0 = spirv.Constant [[dense<1> : vector<2xi32>]] : !spirv.array<1 x !spirv.array<1 x vector<2xi32>>>
@@ -133,18 +121,6 @@ spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader, ReplicatedCompos
spirv.ReturnValue %0 : !spirv.array<1 x !spirv.array<2 x vector<2xf32>>>
}
- spirv.func @array_of_one_splat_array_of_vector_of_one_f32() -> !spirv.array<1 x !spirv.array<2 x vector<1xf32>>> "None" {
- // CHECK: {{%.*}} = spirv.EXT.ConstantCompositeReplicate [dense<1.000000e+00> : vector<1xf32>] : !spirv.array<1 x !spirv.array<2 x vector<1xf32>
- %cst = spirv.Constant [[dense<1.0> : vector<1xf32>], [dense<1.0> : vector<1xf32>]] : !spirv.array<1 x !spirv.array<2 x vector<1xf32>>>
- spirv.ReturnValue %cst : !spirv.array<1 x !spirv.array<2 x vector<1xf32>>>
- }
-
- spirv.func @splat_array_of_array_of_one_vector_of_one_f32() -> (!spirv.array<2 x !spirv.array<1 x vector<1xf32>>>) "None" {
- // CHECK: {{%.*}} = spirv.EXT.ConstantCompositeReplicate [dense<1.000000e+00> : vector<1xf32>] : !spirv.array<2 x !spirv.array<1 x vector<1xf32>>>
- %0 = spirv.Constant [[dense<1.0> : vector<1xf32>], [dense<1.0> : vector<1xf32>]] : !spirv.array<2 x !spirv.array<1 x vector<1xf32>>>
- spirv.ReturnValue %0 : !spirv.array<2 x !spirv.array<1 x vector<1xf32>>>
- }
-
spirv.func @array_of_one_array_of_one_splat_vector_of_f32() -> (!spirv.array<1 x !spirv.array<1 x vector<2xf32>>>) "None" {
// CHECK: {{%.*}} = spirv.EXT.ConstantCompositeReplicate [1.000000e+00 : f32] : !spirv.array<1 x !spirv.array<1 x vector<2xf32>>>
%0 = spirv.Constant [[dense<1.0> : vector<2xf32>]] : !spirv.array<1 x !spirv.array<1 x vector<2xf32>>>
@@ -210,12 +186,6 @@ spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader, ReplicatedCompos
%0 = spirv.Constant [[dense<[1, 2]> : vector<2xi32>]] : !spirv.array<1 x !spirv.array<1 x vector<2xi32>>>
spirv.ReturnValue %0 : !spirv.array<1 x !spirv.array<1 x vector<2xi32>>>
}
-
- spirv.func @array_of_one_array_of_one_vector_of_one_i32() -> (!spirv.array<1 x !spirv.array<1 x vector<1xi32>>>) "None" {
- // CHECK-NOT spirv.EXT.ConstantCompositeReplicate
- %0 = spirv.Constant [[dense<1> : vector<1xi32>]] : !spirv.array<1 x !spirv.array<1 x vector<1xi32>>>
- spirv.ReturnValue %0 : !spirv.array<1 x !spirv.array<1 x vector<1xi32>>>
- }
}
// -----
More information about the Mlir-commits mailing list