[Mlir-commits] [mlir] [mlir][spirv] Verify vector types when parsing (PR #180430)

Jakub Kuderski llvmlistbot at llvm.org
Sun Feb 8 10:18:57 PST 2026


https://github.com/kuhar created https://github.com/llvm/llvm-project/pull/180430

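Also remove test cases that are now invalid: replicated constants built from one-element vectors (`vector<1xi32>`, `vector<1xf32>`), which SPIR-V does not allow.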

Fixes: https://github.com/llvm/llvm-project/issues/180419

From 81dabe7e5c214c09a50c89d0d7dfb3cb28aa33a2 Mon Sep 17 00:00:00 2001
From: Jakub Kuderski <jakub at nod-labs.com>
Date: Sun, 8 Feb 2026 13:16:16 -0500
Subject: [PATCH] [mlir][spirv] Verify vector types when parsing

---
 mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp    | 15 ++++++++--
 mlir/test/Dialect/SPIRV/IR/types.mlir         | 10 +++++++
 .../replicated-const-composites.mlir          | 30 -------------------
 3 files changed, 23 insertions(+), 32 deletions(-)

diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp
index 7c3bfd72115e6..78f33c238d414 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp
@@ -167,11 +167,11 @@ static Type parseAndVerifyType(SPIRVDialect const &dialect,
   if (parser.parseType(type))
     return Type();
 
-  // Allow SPIR-V dialect types
+  // Allow SPIR-V dialect types.
   if (&type.getDialect() == &dialect)
     return type;
 
-  // Check other allowed types
+  // Check other allowed types.
   if (auto t = dyn_cast<FloatType>(type)) {
     // TODO: All float types are allowed for now, but this should be fixed.
   } else if (auto t = dyn_cast<IntegerType>(type)) {
@@ -186,12 +186,23 @@ static Type parseAndVerifyType(SPIRVDialect const &dialect,
       parser.emitError(typeLoc, "only 1-D vector allowed but found ") << t;
       return Type();
     }
+    if (t.getNumElements() < 2) {
+      parser.emitError(typeLoc, "SPIR-V does not allow one-element vectors");
+      return Type();
+    }
     if (t.getNumElements() > 4) {
       parser.emitError(
           typeLoc, "vector length has to be less than or equal to 4 but found ")
           << t.getNumElements();
       return Type();
     }
+    if (!isa<ScalarType>(t.getElementType())) {
+      parser.emitError(
+          typeLoc,
+          "vector element type must be a SPIR-V scalar type but found ")
+          << t.getElementType();
+      return Type();
+    }
   } else if (auto t = dyn_cast<TensorArmType>(type)) {
     if (!isa<ScalarType>(t.getElementType())) {
       parser.emitError(
diff --git a/mlir/test/Dialect/SPIRV/IR/types.mlir b/mlir/test/Dialect/SPIRV/IR/types.mlir
index 145a291343504..98509fb376acf 100644
--- a/mlir/test/Dialect/SPIRV/IR/types.mlir
+++ b/mlir/test/Dialect/SPIRV/IR/types.mlir
@@ -120,6 +120,16 @@ func.func private @unknown_storage_class(!spirv.ptr<f32, SomeStorageClass>) -> (
 
 // -----
 
+// expected-error @+1 {{SPIR-V does not allow one-element vectors}}
+func.func private @invalid_ptr_to_vector_one_element(!spirv.ptr<vector<1xi32>, Function>) -> ()
+
+// -----
+
+// expected-error @+1 {{vector element type must be a SPIR-V scalar type}}
+func.func private @invalid_ptr_to_vector_index(!spirv.ptr<vector<2xindex>, Function>) -> ()
+
+// -----
+
 //===----------------------------------------------------------------------===//
 // RuntimeArrayType
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/SPIRV/Transforms/replicated-const-composites.mlir b/mlir/test/Dialect/SPIRV/Transforms/replicated-const-composites.mlir
index 56e26eee83ff9..0a413e5036be9 100644
--- a/mlir/test/Dialect/SPIRV/Transforms/replicated-const-composites.mlir
+++ b/mlir/test/Dialect/SPIRV/Transforms/replicated-const-composites.mlir
@@ -55,18 +55,6 @@ spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader, ReplicatedCompos
     spirv.ReturnValue %0 : !spirv.array<1 x !spirv.array<2 x vector<2xi32>>>
   }
 
-  spirv.func @array_of_one_splat_array_of_vector_of_one_i32() -> !spirv.array<1 x !spirv.array<2 x vector<1xi32>>> "None" {
-    // CHECK: {{%.*}} = spirv.EXT.ConstantCompositeReplicate [dense<1> : vector<1xi32>] : !spirv.array<1 x !spirv.array<2 x vector<1xi32>
-    %cst = spirv.Constant [[dense<1> : vector<1xi32>], [dense<1> : vector<1xi32>]] : !spirv.array<1 x !spirv.array<2 x vector<1xi32>>>
-    spirv.ReturnValue %cst : !spirv.array<1 x !spirv.array<2 x vector<1xi32>>>
-  }
-
-  spirv.func @splat_array_of_array_of_one_vector_of_one_i32() -> (!spirv.array<2 x !spirv.array<1 x vector<1xi32>>>) "None" {
-    // CHECK: {{%.*}} = spirv.EXT.ConstantCompositeReplicate [dense<1> : vector<1xi32>] : !spirv.array<2 x !spirv.array<1 x vector<1xi32>>>
-    %0 = spirv.Constant [[dense<1> : vector<1xi32>], [dense<1> : vector<1xi32>]] : !spirv.array<2 x !spirv.array<1 x vector<1xi32>>>
-    spirv.ReturnValue %0 : !spirv.array<2 x !spirv.array<1 x vector<1xi32>>>
-  }
-
   spirv.func @array_of_one_array_of_one_splat_vector_of_i32() -> (!spirv.array<1 x !spirv.array<1 x vector<2xi32>>>) "None" {
     // CHECK: {{%.*}} = spirv.EXT.ConstantCompositeReplicate [1 : i32] : !spirv.array<1 x !spirv.array<1 x vector<2xi32>>>
     %0 = spirv.Constant [[dense<1> : vector<2xi32>]] : !spirv.array<1 x !spirv.array<1 x vector<2xi32>>>
@@ -133,18 +121,6 @@ spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader, ReplicatedCompos
     spirv.ReturnValue %0 : !spirv.array<1 x !spirv.array<2 x vector<2xf32>>>
   }
 
-  spirv.func @array_of_one_splat_array_of_vector_of_one_f32() -> !spirv.array<1 x !spirv.array<2 x vector<1xf32>>> "None" {
-    // CHECK: {{%.*}} = spirv.EXT.ConstantCompositeReplicate [dense<1.000000e+00> : vector<1xf32>] : !spirv.array<1 x !spirv.array<2 x vector<1xf32>
-    %cst = spirv.Constant [[dense<1.0> : vector<1xf32>], [dense<1.0> : vector<1xf32>]] : !spirv.array<1 x !spirv.array<2 x vector<1xf32>>>
-    spirv.ReturnValue %cst : !spirv.array<1 x !spirv.array<2 x vector<1xf32>>>
-  }
-
-  spirv.func @splat_array_of_array_of_one_vector_of_one_f32() -> (!spirv.array<2 x !spirv.array<1 x vector<1xf32>>>) "None" {
-    // CHECK: {{%.*}} = spirv.EXT.ConstantCompositeReplicate [dense<1.000000e+00> : vector<1xf32>] : !spirv.array<2 x !spirv.array<1 x vector<1xf32>>>
-    %0 = spirv.Constant [[dense<1.0> : vector<1xf32>], [dense<1.0> : vector<1xf32>]] : !spirv.array<2 x !spirv.array<1 x vector<1xf32>>>
-    spirv.ReturnValue %0 : !spirv.array<2 x !spirv.array<1 x vector<1xf32>>>
-  }
-
   spirv.func @array_of_one_array_of_one_splat_vector_of_f32() -> (!spirv.array<1 x !spirv.array<1 x vector<2xf32>>>) "None" {
     // CHECK: {{%.*}} = spirv.EXT.ConstantCompositeReplicate [1.000000e+00 : f32] : !spirv.array<1 x !spirv.array<1 x vector<2xf32>>>
     %0 = spirv.Constant [[dense<1.0> : vector<2xf32>]] : !spirv.array<1 x !spirv.array<1 x vector<2xf32>>>
@@ -210,12 +186,6 @@ spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader, ReplicatedCompos
     %0 = spirv.Constant [[dense<[1, 2]> : vector<2xi32>]] : !spirv.array<1 x !spirv.array<1 x vector<2xi32>>>
     spirv.ReturnValue %0 : !spirv.array<1 x !spirv.array<1 x vector<2xi32>>>
   }
-  
-  spirv.func @array_of_one_array_of_one_vector_of_one_i32() -> (!spirv.array<1 x !spirv.array<1 x vector<1xi32>>>) "None" {
-    // CHECK-NOT spirv.EXT.ConstantCompositeReplicate
-    %0 = spirv.Constant [[dense<1> : vector<1xi32>]] : !spirv.array<1 x !spirv.array<1 x vector<1xi32>>>
-    spirv.ReturnValue %0 : !spirv.array<1 x !spirv.array<1 x vector<1xi32>>>
-  }
 }
 
 // -----



More information about the Mlir-commits mailing list