[Mlir-commits] [mlir] e7063e8 - [mlir][spirv] Enforce `SPIRV_Vector` to have rank of one (#178185)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Tue Jan 27 07:08:30 PST 2026


Author: Igor Wodiany
Date: 2026-01-27T15:08:25Z
New Revision: e7063e820637498355a184a45c42c19fa58ff2f3

URL: https://github.com/llvm/llvm-project/commit/e7063e820637498355a184a45c42c19fa58ff2f3
DIFF: https://github.com/llvm/llvm-project/commit/e7063e820637498355a184a45c42c19fa58ff2f3.diff

LOG: [mlir][spirv] Enforce `SPIRV_Vector` to have rank of one (#178185)

Currently, only the vector length is enforced; however, this allows vectors
of rank >1 to pass verification as long as their total number of elements is
one of the allowed lengths (e.g. `vector<2x2xi1>` has 4 elements). This change
restricts `SPIRV_Vector`s to rank 1, as required by the SPIR-V spec.

This also fixes a bug where `SPIRV_Composite` allowed higher-rank vectors but
`spirv::CompositeType` did not, leading to cast assertion failures wherever
the value was assumed to be a composite type.

Finally, this change adds two new common type constraints,
`VectorOfRankAndLengthAndType` and `FixedVectorOfRankAndLengthAndType`, that
enforce all three properties: rank, length, and element type.

Fixes #178127.

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
    mlir/include/mlir/IR/CommonTypeConstraints.td
    mlir/test/Dialect/SPIRV/IR/composite-ops.mlir
    mlir/test/Dialect/SPIRV/IR/gl-ops.mlir
    mlir/test/Dialect/SPIRV/IR/group-ops.mlir
    mlir/test/Dialect/SPIRV/IR/image-ops.mlir
    mlir/test/Dialect/SPIRV/IR/intel-ext-ops.mlir
    mlir/test/Dialect/SPIRV/IR/logical-ops.mlir
    mlir/test/Dialect/SPIRV/IR/non-uniform-ops.mlir
    mlir/test/Dialect/SPIRV/IR/structure-ops.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
index 4ea6d784dd88f..f8093d3042c50 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
@@ -4242,7 +4242,7 @@ def SPIRV_BFloat16KHR : TypeAlias<BF16, "BFloat16">;
 def SPIRV_Float : FloatOfWidths<[16, 32, 64]>;
 def SPIRV_Float16or32 : FloatOfWidths<[16, 32]>;
 def SPIRV_AnyFloat : AnyTypeOf<[SPIRV_Float, SPIRV_BFloat16KHR]>;
-def SPIRV_Vector : VectorOfLengthAndType<[2, 3, 4, 8, 16],
+def SPIRV_Vector : VectorOfRankAndLengthAndType<[1], [2, 3, 4, 8, 16],
                                        [SPIRV_Bool, SPIRV_Integer, SPIRV_AnyFloat]>;
 // Component type check is done in the type parser for the following SPIR-V
 // dialect-specific types so we use "Any" here.
@@ -4295,7 +4295,7 @@ class SPIRV_MatrixOfType<list<Type> allowedTypes> :
     "Matrix">;
 
 class SPIRV_VectorOf<Type type> :
-    FixedVectorOfLengthAndType<[2, 3, 4, 8, 16], [type]>;
+    FixedVectorOfRankAndLengthAndType<[1], [2, 3, 4, 8, 16], [type]>;
 
 class SPIRV_ScalarOrVectorOf<Type type> :
     AnyTypeOf<[type, SPIRV_VectorOf<type>]>;
@@ -4314,7 +4314,7 @@ class SPIRV_MatrixOf<Type type> :
 def SPIRV_ScalarOrVector : AnyTypeOf<[SPIRV_Scalar, SPIRV_Vector]>;
 def SPIRV_ScalarOrVectorOrPtr : AnyTypeOf<[SPIRV_ScalarOrVector, SPIRV_AnyPtr]>;
 
-class SPIRV_Vec4<Type type> : VectorOfLengthAndType<[4], [type]>;
+class SPIRV_Vec4<Type type> : VectorOfRankAndLengthAndType<[1], [4], [type]>;
 def SPIRV_IntVec4 : SPIRV_Vec4<SPIRV_Integer>;
 def SPIRV_IOrUIVec4 : SPIRV_Vec4<SPIRV_SignlessOrUnsignedInt>;
 def SPIRV_Int32Vec4 : SPIRV_Vec4<AnyI32>;

diff  --git a/mlir/include/mlir/IR/CommonTypeConstraints.td b/mlir/include/mlir/IR/CommonTypeConstraints.td
index 0fb4837e528be..a49880b81e90d 100644
--- a/mlir/include/mlir/IR/CommonTypeConstraints.td
+++ b/mlir/include/mlir/IR/CommonTypeConstraints.td
@@ -649,6 +649,16 @@ class VectorOfLengthAndType<list<int> allowedLengths,
    VectorOfNonZeroRankOf<allowedTypes>.summary # VectorOfLength<allowedLengths>.summary,
   "::mlir::VectorType">;
 
+// Any vector where the number of elements is from the given
+// `allowedLengths` list, the type is from the given `allowedTypes`
+// list, and the rank is from the given `allowedRanks` list.
+class VectorOfRankAndLengthAndType<list<int> allowedRanks,
+                                   list<int> allowedLengths,
+                                   list<Type> allowedTypes> : AllOfType<
+  [VectorOfRank<allowedRanks>, VectorOfNonZeroRankOf<allowedTypes>, VectorOfLength<allowedLengths>],
+   VectorOfNonZeroRankOf<allowedTypes>.summary # VectorOfLength<allowedLengths>.summary # VectorOfRank<allowedRanks>.summary,
+  "::mlir::VectorType">;
+
 // Any vector where the number of elements is between
 // `minLength` and `maxLength` (inclusive).
 class VectorOfMinMaxLengthAndType<int minLength, int maxLength,
@@ -674,6 +684,18 @@ class FixedVectorOfLengthAndType<list<int> allowedLengths,
   FixedVectorOfLength<allowedLengths>.summary,
   "::mlir::VectorType">;
 
+// Any fixed-length vector where the number of elements is from the given
+// `allowedLengths` list, the type is from the given `allowedTypes` list,
+// and the rank is from the given `allowedRanks` list.
+class FixedVectorOfRankAndLengthAndType<list<int> allowedRanks,
+                                        list<int> allowedLengths,
+                                        list<Type> allowedTypes> : AllOfType<
+  [VectorOfRank<allowedRanks>, FixedVectorOfAnyRank<allowedTypes>, FixedVectorOfLength<allowedLengths>],
+  FixedVectorOfAnyRank<allowedTypes>.summary #
+  FixedVectorOfLength<allowedLengths>.summary #
+  VectorOfRank<allowedRanks>.summary,
+  "::mlir::VectorType">;
+
 // Any scalable vector where the number of elements is from the given
 // `allowedLengths` list and the type is from the given `allowedTypes` list
 class ScalableVectorOfLengthAndType<list<int> allowedLengths,

diff  --git a/mlir/test/Dialect/SPIRV/IR/composite-ops.mlir b/mlir/test/Dialect/SPIRV/IR/composite-ops.mlir
index e71b545de11df..9323518f50373 100644
--- a/mlir/test/Dialect/SPIRV/IR/composite-ops.mlir
+++ b/mlir/test/Dialect/SPIRV/IR/composite-ops.mlir
@@ -99,6 +99,14 @@ func.func @composite_construct_vector_wrong_count(%arg0: f32, %arg1: f32, %arg2
 
 // -----
 
+func.func @composite_construct_vector_rank_two(%arg0: vector<2x2xi1>, %arg1: vector<2x2xi1>) -> vector<4x2xi1> {
+  // expected-error @+1 {{op operand #0 must be variadic of void or bool or 8/16/32/64-bit integer or 16/32/64-bit float or BFloat16 or vector of bool or 8/16/32/64-bit integer or 16/32/64-bit float or BFloat16 values of length 2/3/4/8/16 of ranks 1 or any SPIR-V pointer type or any SPIR-V array type or any SPIR-V runtime array type or any SPIR-V struct type or any SPIR-V cooperative matrix type or any SPIR-V matrix type or any SPIR-V sampled image type or any SPIR-V image type or any SPIR-V tensorArm type, but got 'vector<2x2xi1>'}}
+  %0 = spirv.CompositeConstruct %arg0, %arg1 : (vector<2x2xi1>, vector<2x2xi1>) -> vector<4x2xi1>
+  return %0: vector<4x2xi1>
+}
+
+// -----
+
 //===----------------------------------------------------------------------===//
 // spirv.CompositeExtractOp
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/test/Dialect/SPIRV/IR/gl-ops.mlir b/mlir/test/Dialect/SPIRV/IR/gl-ops.mlir
index bab12b183faf1..eea80ca3798a6 100644
--- a/mlir/test/Dialect/SPIRV/IR/gl-ops.mlir
+++ b/mlir/test/Dialect/SPIRV/IR/gl-ops.mlir
@@ -736,7 +736,7 @@ func.func @cross(%arg0 : vector<3xf32>, %arg1 : vector<3xf32>) {
 // -----
 
 func.func @cross_invalid_type(%arg0 : vector<3xi32>, %arg1 : vector<3xi32>) {
-  // expected-error @+1 {{'spirv.GL.Cross' op operand #0 must be 16/32/64-bit float or fixed-length vector of 16/32/64-bit float values of length 2/3/4/8/16, but got 'vector<3xi32>'}}
+  // expected-error @+1 {{'spirv.GL.Cross' op operand #0 must be 16/32/64-bit float or fixed-length vector of 16/32/64-bit float values of length 2/3/4/8/16 of ranks 1, but got 'vector<3xi32>'}}
   %0 = spirv.GL.Cross %arg0, %arg1 : vector<3xi32>
   return
 }
@@ -1126,7 +1126,7 @@ func.func @lengthvec(%arg0 : vector<3xf32>) -> () {
 // -----
 
 func.func @length_i32_in(%arg0 : i32) -> () {
-  // expected-error @+1 {{op operand #0 must be 16/32/64-bit float or fixed-length vector of 16/32/64-bit float values of length 2/3/4/8/16, but got 'i32'}}
+  // expected-error @+1 {{op operand #0 must be 16/32/64-bit float or fixed-length vector of 16/32/64-bit float values of length 2/3/4/8/16 of ranks 1, but got 'i32'}}
   %0 = spirv.GL.Length %arg0 : i32 -> f32
   return
 }
@@ -1142,7 +1142,7 @@ func.func @length_f16_in(%arg0 : f16) -> () {
 // -----
 
 func.func @length_i32vec_in(%arg0 : vector<3xi32>) -> () {
-  // expected-error @+1 {{op operand #0 must be 16/32/64-bit float or fixed-length vector of 16/32/64-bit float values of length 2/3/4/8/16, but got 'vector<3xi32>'}}
+  // expected-error @+1 {{op operand #0 must be 16/32/64-bit float or fixed-length vector of 16/32/64-bit float values of length 2/3/4/8/16 of ranks 1, but got 'vector<3xi32>'}}
   %0 = spirv.GL.Length %arg0 : vector<3xi32> -> f32
   return
 }

diff  --git a/mlir/test/Dialect/SPIRV/IR/group-ops.mlir b/mlir/test/Dialect/SPIRV/IR/group-ops.mlir
index d7a4a6d92fcd3..d26bfe9185bdd 100644
--- a/mlir/test/Dialect/SPIRV/IR/group-ops.mlir
+++ b/mlir/test/Dialect/SPIRV/IR/group-ops.mlir
@@ -220,7 +220,7 @@ func.func @group_non_uniform_ballot_bit_count_wrong_scope(%value: vector<4xi32>)
 // -----
 
 func.func @group_non_uniform_ballot_bit_count_wrong_value_len(%value: vector<3xi32>) -> i32 {
-  // expected-error @+1 {{operand #0 must be vector of 32-bit signless/unsigned integer values of length 4, but got 'vector<3xi32>'}}
+  // expected-error @+1 {{operand #0 must be vector of 32-bit signless/unsigned integer values of length 4 of ranks 1, but got 'vector<3xi32>'}}
   %0 = spirv.GroupNonUniformBallotBitCount <Subgroup> <InclusiveScan> %value : vector<3xi32> -> i32
   return %0: i32
 }
@@ -228,7 +228,7 @@ func.func @group_non_uniform_ballot_bit_count_wrong_value_len(%value: vector<3xi
 // -----
 
 func.func @group_non_uniform_ballot_bit_count_wrong_value_type(%value: vector<4xi8>) -> i32 {
-  // expected-error @+1 {{operand #0 must be vector of 32-bit signless/unsigned integer values of length 4, but got 'vector<4xi8>'}}
+  // expected-error @+1 {{operand #0 must be vector of 32-bit signless/unsigned integer values of length 4 of ranks 1, but got 'vector<4xi8>'}}
   %0 = spirv.GroupNonUniformBallotBitCount <Subgroup> <InclusiveScan> %value : vector<4xi8> -> i32
   return %0: i32
 }
@@ -236,7 +236,7 @@ func.func @group_non_uniform_ballot_bit_count_wrong_value_type(%value: vector<4x
 // -----
 
 func.func @group_non_uniform_ballot_bit_count_value_sign(%value: vector<4xsi32>) -> i32 {
-  // expected-error @+1 {{operand #0 must be vector of 32-bit signless/unsigned integer values of length 4, but got 'vector<4xsi32>'}}
+  // expected-error @+1 {{operand #0 must be vector of 32-bit signless/unsigned integer values of length 4 of ranks 1, but got 'vector<4xsi32>'}}
   %0 = spirv.GroupNonUniformBallotBitCount <Subgroup> <InclusiveScan> %value : vector<4xsi32> -> i32
   return %0: i32
 }

diff  --git a/mlir/test/Dialect/SPIRV/IR/image-ops.mlir b/mlir/test/Dialect/SPIRV/IR/image-ops.mlir
index 7369d719ca53d..12b5f2ce62a68 100644
--- a/mlir/test/Dialect/SPIRV/IR/image-ops.mlir
+++ b/mlir/test/Dialect/SPIRV/IR/image-ops.mlir
@@ -29,7 +29,7 @@ func.func @image_dref_gather_with_mismatch_imageoperands(%arg0 : !spirv.sampled_
 // -----
 
 func.func @image_dref_gather_error_result_type(%arg0 : !spirv.sampled_image<!spirv.image<i32, Dim2D, NoDepth, NonArrayed, SingleSampled, NoSampler, Unknown>>, %arg1 : vector<4xf32>, %arg2 : f32) -> () {
-  // expected-error @+1 {{must be vector of 8/16/32/64-bit integer values of length 4 or vector of 16/32/64-bit float values of length 4}}
+  // expected-error @+1 {{op result #0 must be vector of 8/16/32/64-bit integer values of length 4 of ranks 1 or vector of 16/32/64-bit float values of length 4 of ranks 1, but got 'vector<3xi32>'}}
   %0 = spirv.ImageDrefGather %arg0, %arg1, %arg2 : !spirv.sampled_image<!spirv.image<i32, Dim2D, NoDepth, NonArrayed, SingleSampled, NoSampler, Unknown>>, vector<4xf32>, f32 -> vector<3xi32>
   spirv.Return
 }
@@ -326,7 +326,7 @@ func.func @image_fetch_type_mismatch(%arg0: !spirv.image<f32, Dim2D, NoDepth, No
 // -----
 
 func.func @image_fetch_2d_result(%arg0: !spirv.image<f32, Dim2D, NoDepth, NonArrayed, SingleSampled, NeedSampler, Rgba8>, %arg1: vector<2xsi32>) -> () {
-  // expected-error @+1 {{op result #0 must be vector of 16/32/64-bit float values of length 4 or vector of 8/16/32/64-bit integer values of length 4, but got 'vector<2xf32>'}}
+  // expected-error @+1 {{op result #0 must be vector of 16/32/64-bit float values of length 4 of ranks 1 or vector of 8/16/32/64-bit integer values of length 4 of ranks 1, but got 'vector<2xf32>'}}
   %0 = spirv.ImageFetch %arg0, %arg1 : !spirv.image<f32, Dim2D, NoDepth, NonArrayed, SingleSampled, NeedSampler, Rgba8>, vector<2xsi32> -> vector<2xf32>
   spirv.Return
 }
@@ -334,7 +334,7 @@ func.func @image_fetch_2d_result(%arg0: !spirv.image<f32, Dim2D, NoDepth, NonArr
 // -----
 
 func.func @image_fetch_float_coords(%arg0: !spirv.image<f32, Dim2D, NoDepth, NonArrayed, SingleSampled, NeedSampler, Rgba8>, %arg1: vector<2xf32>) -> () {
-  // expected-error @+1 {{op operand #1 must be 8/16/32/64-bit integer or fixed-length vector of 8/16/32/64-bit integer values of length 2/3/4/8/16, but got 'vector<2xf32>'}}
+  // expected-error @+1 {{op operand #1 must be 8/16/32/64-bit integer or fixed-length vector of 8/16/32/64-bit integer values of length 2/3/4/8/16 of ranks 1, but got 'vector<2xf32>'}}
   %0 = spirv.ImageFetch %arg0, %arg1 : !spirv.image<f32, Dim2D, NoDepth, NonArrayed, SingleSampled, NeedSampler, Rgba8>, vector<2xf32> -> vector<2xf32>
   spirv.Return
 }

diff  --git a/mlir/test/Dialect/SPIRV/IR/intel-ext-ops.mlir b/mlir/test/Dialect/SPIRV/IR/intel-ext-ops.mlir
index 2e2fb1a9df328..d124c02231161 100644
--- a/mlir/test/Dialect/SPIRV/IR/intel-ext-ops.mlir
+++ b/mlir/test/Dialect/SPIRV/IR/intel-ext-ops.mlir
@@ -21,7 +21,7 @@ spirv.func @f32_to_bf16_vec(%arg0 : vector<2xf32>) "None" {
 // -----
 
 spirv.func @f32_to_bf16_unsupported(%arg0 : f64) "None" {
-  // expected-error @+1 {{operand #0 must be Float32 or fixed-length vector of Float32 values of length 2/3/4/8/16, but got}}
+  // expected-error @+1 {{operand #0 must be Float32 or fixed-length vector of Float32 values of length 2/3/4/8/16 of ranks 1, but got}}
   %0 = spirv.INTEL.ConvertFToBF16 %arg0 : f64 to i16
   spirv.Return
 }
@@ -57,7 +57,7 @@ spirv.func @bf16_to_f32_vec(%arg0 : vector<2xi16>) "None" {
 // -----
 
 spirv.func @bf16_to_f32_unsupported(%arg0 : i16) "None" {
-  // expected-error @+1 {{result #0 must be Float32 or fixed-length vector of Float32 values of length 2/3/4/8/16, but got}}
+  // expected-error @+1 {{result #0 must be Float32 or fixed-length vector of Float32 values of length 2/3/4/8/16 of ranks 1, but got}}
   %0 = spirv.INTEL.ConvertBF16ToF %arg0 : i16 to f16
   spirv.Return
 }
@@ -93,7 +93,7 @@ spirv.func @f32_to_tf32_vec(%arg0 : vector<2xf32>) "None" {
 // -----
 
 spirv.func @f32_to_tf32_unsupported(%arg0 : f64) "None" {
-  // expected-error @+1 {{op operand #0 must be Float32 or fixed-length vector of Float32 values of length 2/3/4/8/16, but got 'f64'}}
+  // expected-error @+1 {{op operand #0 must be Float32 or fixed-length vector of Float32 values of length 2/3/4/8/16 of ranks 1, but got 'f64'}}
   %0 = spirv.INTEL.RoundFToTF32 %arg0 : f64 to f32
   spirv.Return
 }

diff  --git a/mlir/test/Dialect/SPIRV/IR/logical-ops.mlir b/mlir/test/Dialect/SPIRV/IR/logical-ops.mlir
index d7f4ed05969aa..1018751cf65e0 100644
--- a/mlir/test/Dialect/SPIRV/IR/logical-ops.mlir
+++ b/mlir/test/Dialect/SPIRV/IR/logical-ops.mlir
@@ -184,7 +184,7 @@ func.func @logicalUnary(%arg0 : i1)
 
 func.func @logicalUnary(%arg0 : i32)
 {
-  // expected-error @+1 {{'operand' must be bool or fixed-length vector of bool values of length 2/3/4/8/16, but got 'i32'}}
+  // expected-error @+1 {{'operand' must be bool or fixed-length vector of bool values of length 2/3/4/8/16 of ranks 1, but got 'i32'}}
   %0 = spirv.LogicalNot %arg0 : i32
   return
 }

diff  --git a/mlir/test/Dialect/SPIRV/IR/non-uniform-ops.mlir b/mlir/test/Dialect/SPIRV/IR/non-uniform-ops.mlir
index b22951f90510a..168823a6e9c2d 100644
--- a/mlir/test/Dialect/SPIRV/IR/non-uniform-ops.mlir
+++ b/mlir/test/Dialect/SPIRV/IR/non-uniform-ops.mlir
@@ -21,7 +21,7 @@ func.func @group_non_uniform_ballot(%predicate: i1) -> vector<4xi32> {
 // -----
 
 func.func @group_non_uniform_ballot(%predicate: i1) -> vector<4xsi32> {
-  // expected-error @+1 {{op result #0 must be vector of 8/16/32/64-bit signless/unsigned integer values of length 4, but got 'vector<4xsi32>'}}
+  // expected-error @+1 {{op result #0 must be vector of 8/16/32/64-bit signless/unsigned integer values of length 4 of ranks 1, but got 'vector<4xsi32>'}}
   %0 = spirv.GroupNonUniformBallot <Workgroup> %predicate : vector<4xsi32>
   return %0: vector<4xsi32>
 }
@@ -185,7 +185,7 @@ func.func @group_non_uniform_fmul_clustered_reduce(%val: vector<2xf32>) -> vecto
 // -----
 
 func.func @group_non_uniform_bf16_fmul_reduce(%val: bf16) -> bf16 {
-  // expected-error @+1 {{op operand #0 must be 16/32/64-bit float or fixed-length vector of 16/32/64-bit float values of length 2/3/4/8/16, but got 'bf16'}}
+  // expected-error @+1 {{op operand #0 must be 16/32/64-bit float or fixed-length vector of 16/32/64-bit float values of length 2/3/4/8/16 of ranks 1, but got 'bf16'}}
   %0 = spirv.GroupNonUniformFMul <Workgroup> <Reduce> %val : bf16 -> bf16
   return %0: bf16
 }
@@ -206,7 +206,7 @@ func.func @group_non_uniform_fmax_reduce(%val: f32) -> f32 {
 // -----
 
 func.func @group_non_uniform_bf16_fmax_reduce(%val: bf16) -> bf16 {
-  // expected-error @+1 {{op operand #0 must be 16/32/64-bit float or fixed-length vector of 16/32/64-bit float values of length 2/3/4/8/16, but got 'bf16'}}
+  // expected-error @+1 {{op operand #0 must be 16/32/64-bit float or fixed-length vector of 16/32/64-bit float values of length 2/3/4/8/16 of ranks 1, but got 'bf16'}}
   %0 = spirv.GroupNonUniformFMax <Workgroup> <Reduce> %val : bf16 -> bf16
   return %0: bf16
 }
@@ -511,7 +511,7 @@ func.func @group_non_uniform_bitwise_and(%val: i32) -> i32 {
 // -----
 
 func.func @group_non_uniform_bitwise_and(%val: i1) -> i1 {
-  // expected-error @+1 {{operand #0 must be 8/16/32/64-bit integer or fixed-length vector of 8/16/32/64-bit integer values of length 2/3/4/8/16, but got 'i1'}}
+  // expected-error @+1 {{operand #0 must be 8/16/32/64-bit integer or fixed-length vector of 8/16/32/64-bit integer values of length 2/3/4/8/16 of ranks 1, but got 'i1'}}
   %0 = spirv.GroupNonUniformBitwiseAnd <Workgroup> <Reduce> %val : i1 -> i1
   return %0: i1
 }
@@ -532,7 +532,7 @@ func.func @group_non_uniform_bitwise_or(%val: i32) -> i32 {
 // -----
 
 func.func @group_non_uniform_bitwise_or(%val: i1) -> i1 {
-  // expected-error @+1 {{operand #0 must be 8/16/32/64-bit integer or fixed-length vector of 8/16/32/64-bit integer values of length 2/3/4/8/16, but got 'i1'}}
+  // expected-error @+1 {{operand #0 must be 8/16/32/64-bit integer or fixed-length vector of 8/16/32/64-bit integer values of length 2/3/4/8/16 of ranks 1, but got 'i1'}}
   %0 = spirv.GroupNonUniformBitwiseOr <Workgroup> <Reduce> %val : i1 -> i1
   return %0: i1
 }
@@ -553,7 +553,7 @@ func.func @group_non_uniform_bitwise_xor(%val: i32) -> i32 {
 // -----
 
 func.func @group_non_uniform_bitwise_xor(%val: i1) -> i1 {
-  // expected-error @+1 {{operand #0 must be 8/16/32/64-bit integer or fixed-length vector of 8/16/32/64-bit integer values of length 2/3/4/8/16, but got 'i1'}}
+  // expected-error @+1 {{operand #0 must be 8/16/32/64-bit integer or fixed-length vector of 8/16/32/64-bit integer values of length 2/3/4/8/16 of ranks 1, but got 'i1'}}
   %0 = spirv.GroupNonUniformBitwiseXor <Workgroup> <Reduce> %val : i1 -> i1
   return %0: i1
 }
@@ -574,7 +574,7 @@ func.func @group_non_uniform_logical_and(%val: i1) -> i1 {
 // -----
 
 func.func @group_non_uniform_logical_and(%val: i32) -> i32 {
-  // expected-error @+1 {{operand #0 must be bool or fixed-length vector of bool values of length 2/3/4/8/16, but got 'i32'}}
+  // expected-error @+1 {{operand #0 must be bool or fixed-length vector of bool values of length 2/3/4/8/16 of ranks 1, but got 'i32'}}
   %0 = spirv.GroupNonUniformLogicalAnd <Workgroup> <Reduce> %val : i32 -> i32
   return %0: i32
 }
@@ -595,7 +595,7 @@ func.func @group_non_uniform_logical_or(%val: i1) -> i1 {
 // -----
 
 func.func @group_non_uniform_logical_or(%val: i32) -> i32 {
-  // expected-error @+1 {{operand #0 must be bool or fixed-length vector of bool values of length 2/3/4/8/16, but got 'i32'}}
+  // expected-error @+1 {{operand #0 must be bool or fixed-length vector of bool values of length 2/3/4/8/16 of ranks 1, but got 'i32'}}
   %0 = spirv.GroupNonUniformLogicalOr <Workgroup> <Reduce> %val : i32 -> i32
   return %0: i32
 }
@@ -616,7 +616,7 @@ func.func @group_non_uniform_logical_xor(%val: i1) -> i1 {
 // -----
 
 func.func @group_non_uniform_logical_xor(%val: i32) -> i32 {
-  // expected-error @+1 {{operand #0 must be bool or fixed-length vector of bool values of length 2/3/4/8/16, but got 'i32'}}
+  // expected-error @+1 {{operand #0 must be bool or fixed-length vector of bool values of length 2/3/4/8/16 of ranks 1, but got 'i32'}}
   %0 = spirv.GroupNonUniformLogicalXor <Workgroup> <Reduce> %val : i32 -> i32
   return %0: i32
 }
@@ -807,7 +807,7 @@ func.func @group_non_uniform_quad_swap(%value: vector<4xf32>) -> vector<4xf32> {
 
 func.func @group_non_uniform_quad_swap(%value: !spirv.array<3 x i32>) -> !spirv.array<3 x i32> {
   %dir = spirv.Constant 0 : i32
-  // expected-error @+1 {{op operand #0 must be 16/32/64-bit float or fixed-length vector of 16/32/64-bit float values of length 2/3/4/8/16 or 8/16/32/64-bit integer or fixed-length vector of 8/16/32/64-bit integer values of length 2/3/4/8/16 or bool or fixed-length vector of bool values of length 2/3/4/8/16, but got '!spirv.array<3 x i32>'}}
+  // expected-error @+1 {{op operand #0 must be 16/32/64-bit float or fixed-length vector of 16/32/64-bit float values of length 2/3/4/8/16 of ranks 1 or 8/16/32/64-bit integer or fixed-length vector of 8/16/32/64-bit integer values of length 2/3/4/8/16 of ranks 1 or bool or fixed-length vector of bool values of length 2/3/4/8/16 of ranks 1, but got '!spirv.array<3 x i32>'}}
   %0 = spirv.GroupNonUniformQuadSwap <Device> %value %dir : !spirv.array<3 x i32>, i32
   return %0: !spirv.array<3 x i32>
 }

diff  --git a/mlir/test/Dialect/SPIRV/IR/structure-ops.mlir b/mlir/test/Dialect/SPIRV/IR/structure-ops.mlir
index 20bb4eace370b..2c5dc8b9f3b0f 100644
--- a/mlir/test/Dialect/SPIRV/IR/structure-ops.mlir
+++ b/mlir/test/Dialect/SPIRV/IR/structure-ops.mlir
@@ -170,7 +170,7 @@ func.func @coop_matrix_const_wrong_type() -> () {
 //===----------------------------------------------------------------------===//
 
 func.func @ccr_result_not_composite() -> () {
-  // expected-error @+1 {{op result #0 must be vector of bool or 8/16/32/64-bit integer or 16/32/64-bit float or BFloat16 values of length 2/3/4/8/16 or any SPIR-V array type or any SPIR-V runtime array type or any SPIR-V struct type or any SPIR-V cooperative matrix type or any SPIR-V matrix type or any SPIR-V tensorArm type, but got 'i32'}}
+  // expected-error @+1 {{op result #0 must be vector of bool or 8/16/32/64-bit integer or 16/32/64-bit float or BFloat16 values of length 2/3/4/8/16 of ranks 1 or any SPIR-V array type or any SPIR-V runtime array type or any SPIR-V struct type or any SPIR-V cooperative matrix type or any SPIR-V matrix type or any SPIR-V tensorArm type, but got 'i32'}}
   %0 = spirv.EXT.ConstantCompositeReplicate [1 : i32] : i32
   return
 }


        


More information about the Mlir-commits mailing list