[Mlir-commits] [mlir] e7063e8 - [mlir][spirv] Enforce `SPIRV_Vector` to have rank of one (#178185)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Tue Jan 27 07:08:30 PST 2026
Author: Igor Wodiany
Date: 2026-01-27T15:08:25Z
New Revision: e7063e820637498355a184a45c42c19fa58ff2f3
URL: https://github.com/llvm/llvm-project/commit/e7063e820637498355a184a45c42c19fa58ff2f3
DIFF: https://github.com/llvm/llvm-project/commit/e7063e820637498355a184a45c42c19fa58ff2f3.diff
LOG: [mlir][spirv] Enforce `SPIRV_Vector` to have rank of one (#178185)
Currently only the vector length is enforced; however, this allows vectors of
rank > 1 to pass verification as long as the length agrees. This
change restricts `SPIRV_Vector`s to rank 1, as required by the
SPIR-V spec.
This also fixes a bug where `SPIRV_Composite` allowed higher-rank
vectors but `spirv::CompositeType` did not, leading to cast assertion failures
where the composite type was assumed.
Finally, this change adds two new common constraints that can enforce
all three: rank, length and type.
fixes #178127
Added:
Modified:
mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
mlir/include/mlir/IR/CommonTypeConstraints.td
mlir/test/Dialect/SPIRV/IR/composite-ops.mlir
mlir/test/Dialect/SPIRV/IR/gl-ops.mlir
mlir/test/Dialect/SPIRV/IR/group-ops.mlir
mlir/test/Dialect/SPIRV/IR/image-ops.mlir
mlir/test/Dialect/SPIRV/IR/intel-ext-ops.mlir
mlir/test/Dialect/SPIRV/IR/logical-ops.mlir
mlir/test/Dialect/SPIRV/IR/non-uniform-ops.mlir
mlir/test/Dialect/SPIRV/IR/structure-ops.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
index 4ea6d784dd88f..f8093d3042c50 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
@@ -4242,7 +4242,7 @@ def SPIRV_BFloat16KHR : TypeAlias<BF16, "BFloat16">;
def SPIRV_Float : FloatOfWidths<[16, 32, 64]>;
def SPIRV_Float16or32 : FloatOfWidths<[16, 32]>;
def SPIRV_AnyFloat : AnyTypeOf<[SPIRV_Float, SPIRV_BFloat16KHR]>;
-def SPIRV_Vector : VectorOfLengthAndType<[2, 3, 4, 8, 16],
+def SPIRV_Vector : VectorOfRankAndLengthAndType<[1], [2, 3, 4, 8, 16],
[SPIRV_Bool, SPIRV_Integer, SPIRV_AnyFloat]>;
// Component type check is done in the type parser for the following SPIR-V
// dialect-specific types so we use "Any" here.
@@ -4295,7 +4295,7 @@ class SPIRV_MatrixOfType<list<Type> allowedTypes> :
"Matrix">;
class SPIRV_VectorOf<Type type> :
- FixedVectorOfLengthAndType<[2, 3, 4, 8, 16], [type]>;
+ FixedVectorOfRankAndLengthAndType<[1], [2, 3, 4, 8, 16], [type]>;
class SPIRV_ScalarOrVectorOf<Type type> :
AnyTypeOf<[type, SPIRV_VectorOf<type>]>;
@@ -4314,7 +4314,7 @@ class SPIRV_MatrixOf<Type type> :
def SPIRV_ScalarOrVector : AnyTypeOf<[SPIRV_Scalar, SPIRV_Vector]>;
def SPIRV_ScalarOrVectorOrPtr : AnyTypeOf<[SPIRV_ScalarOrVector, SPIRV_AnyPtr]>;
-class SPIRV_Vec4<Type type> : VectorOfLengthAndType<[4], [type]>;
+class SPIRV_Vec4<Type type> : VectorOfRankAndLengthAndType<[1], [4], [type]>;
def SPIRV_IntVec4 : SPIRV_Vec4<SPIRV_Integer>;
def SPIRV_IOrUIVec4 : SPIRV_Vec4<SPIRV_SignlessOrUnsignedInt>;
def SPIRV_Int32Vec4 : SPIRV_Vec4<AnyI32>;
diff --git a/mlir/include/mlir/IR/CommonTypeConstraints.td b/mlir/include/mlir/IR/CommonTypeConstraints.td
index 0fb4837e528be..a49880b81e90d 100644
--- a/mlir/include/mlir/IR/CommonTypeConstraints.td
+++ b/mlir/include/mlir/IR/CommonTypeConstraints.td
@@ -649,6 +649,16 @@ class VectorOfLengthAndType<list<int> allowedLengths,
VectorOfNonZeroRankOf<allowedTypes>.summary # VectorOfLength<allowedLengths>.summary,
"::mlir::VectorType">;
+// Any vector whose rank is from the given `allowedRanks` list, whose number
+// of elements is from the given `allowedLengths` list, and whose element
+// type is from the given `allowedTypes` list.
+class VectorOfRankAndLengthAndType<list<int> allowedRanks,
+ list<int> allowedLengths,
+ list<Type> allowedTypes> : AllOfType<
+ [VectorOfRank<allowedRanks>, VectorOfNonZeroRankOf<allowedTypes>, VectorOfLength<allowedLengths>],
+ VectorOfNonZeroRankOf<allowedTypes>.summary # VectorOfLength<allowedLengths>.summary # VectorOfRank<allowedRanks>.summary,
+ "::mlir::VectorType">;
+
// Any vector where the number of elements is between
// `minLength` and `maxLength` (inclusive).
class VectorOfMinMaxLengthAndType<int minLength, int maxLength,
@@ -674,6 +684,18 @@ class FixedVectorOfLengthAndType<list<int> allowedLengths,
FixedVectorOfLength<allowedLengths>.summary,
"::mlir::VectorType">;
+// Any fixed-length vector whose rank is from the given `allowedRanks` list,
+// whose number of elements is from the given `allowedLengths` list, and whose
+// element type is from the given `allowedTypes` list.
+class FixedVectorOfRankAndLengthAndType<list<int> allowedRanks,
+ list<int> allowedLengths,
+ list<Type> allowedTypes> : AllOfType<
+ [VectorOfRank<allowedRanks>, FixedVectorOfAnyRank<allowedTypes>, FixedVectorOfLength<allowedLengths>],
+ FixedVectorOfAnyRank<allowedTypes>.summary #
+ FixedVectorOfLength<allowedLengths>.summary #
+ VectorOfRank<allowedRanks>.summary,
+ "::mlir::VectorType">;
+
// Any scalable vector where the number of elements is from the given
// `allowedLengths` list and the type is from the given `allowedTypes` list
class ScalableVectorOfLengthAndType<list<int> allowedLengths,
diff --git a/mlir/test/Dialect/SPIRV/IR/composite-ops.mlir b/mlir/test/Dialect/SPIRV/IR/composite-ops.mlir
index e71b545de11df..9323518f50373 100644
--- a/mlir/test/Dialect/SPIRV/IR/composite-ops.mlir
+++ b/mlir/test/Dialect/SPIRV/IR/composite-ops.mlir
@@ -99,6 +99,14 @@ func.func @composite_construct_vector_wrong_count(%arg0: f32, %arg1: f32, %arg2
// -----
+func.func @composite_construct_vector_rank_two(%arg0: vector<2x2xi1>, %arg1: vector<2x2xi1>) -> vector<4x2xi1> {
+ // expected-error @+1 {{op operand #0 must be variadic of void or bool or 8/16/32/64-bit integer or 16/32/64-bit float or BFloat16 or vector of bool or 8/16/32/64-bit integer or 16/32/64-bit float or BFloat16 values of length 2/3/4/8/16 of ranks 1 or any SPIR-V pointer type or any SPIR-V array type or any SPIR-V runtime array type or any SPIR-V struct type or any SPIR-V cooperative matrix type or any SPIR-V matrix type or any SPIR-V sampled image type or any SPIR-V image type or any SPIR-V tensorArm type, but got 'vector<2x2xi1>'}}
+ %0 = spirv.CompositeConstruct %arg0, %arg1 : (vector<2x2xi1>, vector<2x2xi1>) -> vector<4x2xi1>
+ return %0: vector<4x2xi1>
+}
+
+// -----
+
//===----------------------------------------------------------------------===//
// spirv.CompositeExtractOp
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/SPIRV/IR/gl-ops.mlir b/mlir/test/Dialect/SPIRV/IR/gl-ops.mlir
index bab12b183faf1..eea80ca3798a6 100644
--- a/mlir/test/Dialect/SPIRV/IR/gl-ops.mlir
+++ b/mlir/test/Dialect/SPIRV/IR/gl-ops.mlir
@@ -736,7 +736,7 @@ func.func @cross(%arg0 : vector<3xf32>, %arg1 : vector<3xf32>) {
// -----
func.func @cross_invalid_type(%arg0 : vector<3xi32>, %arg1 : vector<3xi32>) {
- // expected-error @+1 {{'spirv.GL.Cross' op operand #0 must be 16/32/64-bit float or fixed-length vector of 16/32/64-bit float values of length 2/3/4/8/16, but got 'vector<3xi32>'}}
+ // expected-error @+1 {{'spirv.GL.Cross' op operand #0 must be 16/32/64-bit float or fixed-length vector of 16/32/64-bit float values of length 2/3/4/8/16 of ranks 1, but got 'vector<3xi32>'}}
%0 = spirv.GL.Cross %arg0, %arg1 : vector<3xi32>
return
}
@@ -1126,7 +1126,7 @@ func.func @lengthvec(%arg0 : vector<3xf32>) -> () {
// -----
func.func @length_i32_in(%arg0 : i32) -> () {
- // expected-error @+1 {{op operand #0 must be 16/32/64-bit float or fixed-length vector of 16/32/64-bit float values of length 2/3/4/8/16, but got 'i32'}}
+ // expected-error @+1 {{op operand #0 must be 16/32/64-bit float or fixed-length vector of 16/32/64-bit float values of length 2/3/4/8/16 of ranks 1, but got 'i32'}}
%0 = spirv.GL.Length %arg0 : i32 -> f32
return
}
@@ -1142,7 +1142,7 @@ func.func @length_f16_in(%arg0 : f16) -> () {
// -----
func.func @length_i32vec_in(%arg0 : vector<3xi32>) -> () {
- // expected-error @+1 {{op operand #0 must be 16/32/64-bit float or fixed-length vector of 16/32/64-bit float values of length 2/3/4/8/16, but got 'vector<3xi32>'}}
+ // expected-error @+1 {{op operand #0 must be 16/32/64-bit float or fixed-length vector of 16/32/64-bit float values of length 2/3/4/8/16 of ranks 1, but got 'vector<3xi32>'}}
%0 = spirv.GL.Length %arg0 : vector<3xi32> -> f32
return
}
diff --git a/mlir/test/Dialect/SPIRV/IR/group-ops.mlir b/mlir/test/Dialect/SPIRV/IR/group-ops.mlir
index d7a4a6d92fcd3..d26bfe9185bdd 100644
--- a/mlir/test/Dialect/SPIRV/IR/group-ops.mlir
+++ b/mlir/test/Dialect/SPIRV/IR/group-ops.mlir
@@ -220,7 +220,7 @@ func.func @group_non_uniform_ballot_bit_count_wrong_scope(%value: vector<4xi32>)
// -----
func.func @group_non_uniform_ballot_bit_count_wrong_value_len(%value: vector<3xi32>) -> i32 {
- // expected-error @+1 {{operand #0 must be vector of 32-bit signless/unsigned integer values of length 4, but got 'vector<3xi32>'}}
+ // expected-error @+1 {{operand #0 must be vector of 32-bit signless/unsigned integer values of length 4 of ranks 1, but got 'vector<3xi32>'}}
%0 = spirv.GroupNonUniformBallotBitCount <Subgroup> <InclusiveScan> %value : vector<3xi32> -> i32
return %0: i32
}
@@ -228,7 +228,7 @@ func.func @group_non_uniform_ballot_bit_count_wrong_value_len(%value: vector<3xi
// -----
func.func @group_non_uniform_ballot_bit_count_wrong_value_type(%value: vector<4xi8>) -> i32 {
- // expected-error @+1 {{operand #0 must be vector of 32-bit signless/unsigned integer values of length 4, but got 'vector<4xi8>'}}
+ // expected-error @+1 {{operand #0 must be vector of 32-bit signless/unsigned integer values of length 4 of ranks 1, but got 'vector<4xi8>'}}
%0 = spirv.GroupNonUniformBallotBitCount <Subgroup> <InclusiveScan> %value : vector<4xi8> -> i32
return %0: i32
}
@@ -236,7 +236,7 @@ func.func @group_non_uniform_ballot_bit_count_wrong_value_type(%value: vector<4x
// -----
func.func @group_non_uniform_ballot_bit_count_value_sign(%value: vector<4xsi32>) -> i32 {
- // expected-error @+1 {{operand #0 must be vector of 32-bit signless/unsigned integer values of length 4, but got 'vector<4xsi32>'}}
+ // expected-error @+1 {{operand #0 must be vector of 32-bit signless/unsigned integer values of length 4 of ranks 1, but got 'vector<4xsi32>'}}
%0 = spirv.GroupNonUniformBallotBitCount <Subgroup> <InclusiveScan> %value : vector<4xsi32> -> i32
return %0: i32
}
diff --git a/mlir/test/Dialect/SPIRV/IR/image-ops.mlir b/mlir/test/Dialect/SPIRV/IR/image-ops.mlir
index 7369d719ca53d..12b5f2ce62a68 100644
--- a/mlir/test/Dialect/SPIRV/IR/image-ops.mlir
+++ b/mlir/test/Dialect/SPIRV/IR/image-ops.mlir
@@ -29,7 +29,7 @@ func.func @image_dref_gather_with_mismatch_imageoperands(%arg0 : !spirv.sampled_
// -----
func.func @image_dref_gather_error_result_type(%arg0 : !spirv.sampled_image<!spirv.image<i32, Dim2D, NoDepth, NonArrayed, SingleSampled, NoSampler, Unknown>>, %arg1 : vector<4xf32>, %arg2 : f32) -> () {
- // expected-error @+1 {{must be vector of 8/16/32/64-bit integer values of length 4 or vector of 16/32/64-bit float values of length 4}}
+ // expected-error @+1 {{op result #0 must be vector of 8/16/32/64-bit integer values of length 4 of ranks 1 or vector of 16/32/64-bit float values of length 4 of ranks 1, but got 'vector<3xi32>'}}
%0 = spirv.ImageDrefGather %arg0, %arg1, %arg2 : !spirv.sampled_image<!spirv.image<i32, Dim2D, NoDepth, NonArrayed, SingleSampled, NoSampler, Unknown>>, vector<4xf32>, f32 -> vector<3xi32>
spirv.Return
}
@@ -326,7 +326,7 @@ func.func @image_fetch_type_mismatch(%arg0: !spirv.image<f32, Dim2D, NoDepth, No
// -----
func.func @image_fetch_2d_result(%arg0: !spirv.image<f32, Dim2D, NoDepth, NonArrayed, SingleSampled, NeedSampler, Rgba8>, %arg1: vector<2xsi32>) -> () {
- // expected-error @+1 {{op result #0 must be vector of 16/32/64-bit float values of length 4 or vector of 8/16/32/64-bit integer values of length 4, but got 'vector<2xf32>'}}
+ // expected-error @+1 {{op result #0 must be vector of 16/32/64-bit float values of length 4 of ranks 1 or vector of 8/16/32/64-bit integer values of length 4 of ranks 1, but got 'vector<2xf32>'}}
%0 = spirv.ImageFetch %arg0, %arg1 : !spirv.image<f32, Dim2D, NoDepth, NonArrayed, SingleSampled, NeedSampler, Rgba8>, vector<2xsi32> -> vector<2xf32>
spirv.Return
}
@@ -334,7 +334,7 @@ func.func @image_fetch_2d_result(%arg0: !spirv.image<f32, Dim2D, NoDepth, NonArr
// -----
func.func @image_fetch_float_coords(%arg0: !spirv.image<f32, Dim2D, NoDepth, NonArrayed, SingleSampled, NeedSampler, Rgba8>, %arg1: vector<2xf32>) -> () {
- // expected-error @+1 {{op operand #1 must be 8/16/32/64-bit integer or fixed-length vector of 8/16/32/64-bit integer values of length 2/3/4/8/16, but got 'vector<2xf32>'}}
+ // expected-error @+1 {{op operand #1 must be 8/16/32/64-bit integer or fixed-length vector of 8/16/32/64-bit integer values of length 2/3/4/8/16 of ranks 1, but got 'vector<2xf32>'}}
%0 = spirv.ImageFetch %arg0, %arg1 : !spirv.image<f32, Dim2D, NoDepth, NonArrayed, SingleSampled, NeedSampler, Rgba8>, vector<2xf32> -> vector<2xf32>
spirv.Return
}
diff --git a/mlir/test/Dialect/SPIRV/IR/intel-ext-ops.mlir b/mlir/test/Dialect/SPIRV/IR/intel-ext-ops.mlir
index 2e2fb1a9df328..d124c02231161 100644
--- a/mlir/test/Dialect/SPIRV/IR/intel-ext-ops.mlir
+++ b/mlir/test/Dialect/SPIRV/IR/intel-ext-ops.mlir
@@ -21,7 +21,7 @@ spirv.func @f32_to_bf16_vec(%arg0 : vector<2xf32>) "None" {
// -----
spirv.func @f32_to_bf16_unsupported(%arg0 : f64) "None" {
- // expected-error @+1 {{operand #0 must be Float32 or fixed-length vector of Float32 values of length 2/3/4/8/16, but got}}
+ // expected-error @+1 {{operand #0 must be Float32 or fixed-length vector of Float32 values of length 2/3/4/8/16 of ranks 1, but got}}
%0 = spirv.INTEL.ConvertFToBF16 %arg0 : f64 to i16
spirv.Return
}
@@ -57,7 +57,7 @@ spirv.func @bf16_to_f32_vec(%arg0 : vector<2xi16>) "None" {
// -----
spirv.func @bf16_to_f32_unsupported(%arg0 : i16) "None" {
- // expected-error @+1 {{result #0 must be Float32 or fixed-length vector of Float32 values of length 2/3/4/8/16, but got}}
+ // expected-error @+1 {{result #0 must be Float32 or fixed-length vector of Float32 values of length 2/3/4/8/16 of ranks 1, but got}}
%0 = spirv.INTEL.ConvertBF16ToF %arg0 : i16 to f16
spirv.Return
}
@@ -93,7 +93,7 @@ spirv.func @f32_to_tf32_vec(%arg0 : vector<2xf32>) "None" {
// -----
spirv.func @f32_to_tf32_unsupported(%arg0 : f64) "None" {
- // expected-error @+1 {{op operand #0 must be Float32 or fixed-length vector of Float32 values of length 2/3/4/8/16, but got 'f64'}}
+ // expected-error @+1 {{op operand #0 must be Float32 or fixed-length vector of Float32 values of length 2/3/4/8/16 of ranks 1, but got 'f64'}}
%0 = spirv.INTEL.RoundFToTF32 %arg0 : f64 to f32
spirv.Return
}
diff --git a/mlir/test/Dialect/SPIRV/IR/logical-ops.mlir b/mlir/test/Dialect/SPIRV/IR/logical-ops.mlir
index d7f4ed05969aa..1018751cf65e0 100644
--- a/mlir/test/Dialect/SPIRV/IR/logical-ops.mlir
+++ b/mlir/test/Dialect/SPIRV/IR/logical-ops.mlir
@@ -184,7 +184,7 @@ func.func @logicalUnary(%arg0 : i1)
func.func @logicalUnary(%arg0 : i32)
{
- // expected-error @+1 {{'operand' must be bool or fixed-length vector of bool values of length 2/3/4/8/16, but got 'i32'}}
+ // expected-error @+1 {{'operand' must be bool or fixed-length vector of bool values of length 2/3/4/8/16 of ranks 1, but got 'i32'}}
%0 = spirv.LogicalNot %arg0 : i32
return
}
diff --git a/mlir/test/Dialect/SPIRV/IR/non-uniform-ops.mlir b/mlir/test/Dialect/SPIRV/IR/non-uniform-ops.mlir
index b22951f90510a..168823a6e9c2d 100644
--- a/mlir/test/Dialect/SPIRV/IR/non-uniform-ops.mlir
+++ b/mlir/test/Dialect/SPIRV/IR/non-uniform-ops.mlir
@@ -21,7 +21,7 @@ func.func @group_non_uniform_ballot(%predicate: i1) -> vector<4xi32> {
// -----
func.func @group_non_uniform_ballot(%predicate: i1) -> vector<4xsi32> {
- // expected-error @+1 {{op result #0 must be vector of 8/16/32/64-bit signless/unsigned integer values of length 4, but got 'vector<4xsi32>'}}
+ // expected-error @+1 {{op result #0 must be vector of 8/16/32/64-bit signless/unsigned integer values of length 4 of ranks 1, but got 'vector<4xsi32>'}}
%0 = spirv.GroupNonUniformBallot <Workgroup> %predicate : vector<4xsi32>
return %0: vector<4xsi32>
}
@@ -185,7 +185,7 @@ func.func @group_non_uniform_fmul_clustered_reduce(%val: vector<2xf32>) -> vecto
// -----
func.func @group_non_uniform_bf16_fmul_reduce(%val: bf16) -> bf16 {
- // expected-error @+1 {{op operand #0 must be 16/32/64-bit float or fixed-length vector of 16/32/64-bit float values of length 2/3/4/8/16, but got 'bf16'}}
+ // expected-error @+1 {{op operand #0 must be 16/32/64-bit float or fixed-length vector of 16/32/64-bit float values of length 2/3/4/8/16 of ranks 1, but got 'bf16'}}
%0 = spirv.GroupNonUniformFMul <Workgroup> <Reduce> %val : bf16 -> bf16
return %0: bf16
}
@@ -206,7 +206,7 @@ func.func @group_non_uniform_fmax_reduce(%val: f32) -> f32 {
// -----
func.func @group_non_uniform_bf16_fmax_reduce(%val: bf16) -> bf16 {
- // expected-error @+1 {{op operand #0 must be 16/32/64-bit float or fixed-length vector of 16/32/64-bit float values of length 2/3/4/8/16, but got 'bf16'}}
+ // expected-error @+1 {{op operand #0 must be 16/32/64-bit float or fixed-length vector of 16/32/64-bit float values of length 2/3/4/8/16 of ranks 1, but got 'bf16'}}
%0 = spirv.GroupNonUniformFMax <Workgroup> <Reduce> %val : bf16 -> bf16
return %0: bf16
}
@@ -511,7 +511,7 @@ func.func @group_non_uniform_bitwise_and(%val: i32) -> i32 {
// -----
func.func @group_non_uniform_bitwise_and(%val: i1) -> i1 {
- // expected-error @+1 {{operand #0 must be 8/16/32/64-bit integer or fixed-length vector of 8/16/32/64-bit integer values of length 2/3/4/8/16, but got 'i1'}}
+ // expected-error @+1 {{operand #0 must be 8/16/32/64-bit integer or fixed-length vector of 8/16/32/64-bit integer values of length 2/3/4/8/16 of ranks 1, but got 'i1'}}
%0 = spirv.GroupNonUniformBitwiseAnd <Workgroup> <Reduce> %val : i1 -> i1
return %0: i1
}
@@ -532,7 +532,7 @@ func.func @group_non_uniform_bitwise_or(%val: i32) -> i32 {
// -----
func.func @group_non_uniform_bitwise_or(%val: i1) -> i1 {
- // expected-error @+1 {{operand #0 must be 8/16/32/64-bit integer or fixed-length vector of 8/16/32/64-bit integer values of length 2/3/4/8/16, but got 'i1'}}
+ // expected-error @+1 {{operand #0 must be 8/16/32/64-bit integer or fixed-length vector of 8/16/32/64-bit integer values of length 2/3/4/8/16 of ranks 1, but got 'i1'}}
%0 = spirv.GroupNonUniformBitwiseOr <Workgroup> <Reduce> %val : i1 -> i1
return %0: i1
}
@@ -553,7 +553,7 @@ func.func @group_non_uniform_bitwise_xor(%val: i32) -> i32 {
// -----
func.func @group_non_uniform_bitwise_xor(%val: i1) -> i1 {
- // expected-error @+1 {{operand #0 must be 8/16/32/64-bit integer or fixed-length vector of 8/16/32/64-bit integer values of length 2/3/4/8/16, but got 'i1'}}
+ // expected-error @+1 {{operand #0 must be 8/16/32/64-bit integer or fixed-length vector of 8/16/32/64-bit integer values of length 2/3/4/8/16 of ranks 1, but got 'i1'}}
%0 = spirv.GroupNonUniformBitwiseXor <Workgroup> <Reduce> %val : i1 -> i1
return %0: i1
}
@@ -574,7 +574,7 @@ func.func @group_non_uniform_logical_and(%val: i1) -> i1 {
// -----
func.func @group_non_uniform_logical_and(%val: i32) -> i32 {
- // expected-error @+1 {{operand #0 must be bool or fixed-length vector of bool values of length 2/3/4/8/16, but got 'i32'}}
+ // expected-error @+1 {{operand #0 must be bool or fixed-length vector of bool values of length 2/3/4/8/16 of ranks 1, but got 'i32'}}
%0 = spirv.GroupNonUniformLogicalAnd <Workgroup> <Reduce> %val : i32 -> i32
return %0: i32
}
@@ -595,7 +595,7 @@ func.func @group_non_uniform_logical_or(%val: i1) -> i1 {
// -----
func.func @group_non_uniform_logical_or(%val: i32) -> i32 {
- // expected-error @+1 {{operand #0 must be bool or fixed-length vector of bool values of length 2/3/4/8/16, but got 'i32'}}
+ // expected-error @+1 {{operand #0 must be bool or fixed-length vector of bool values of length 2/3/4/8/16 of ranks 1, but got 'i32'}}
%0 = spirv.GroupNonUniformLogicalOr <Workgroup> <Reduce> %val : i32 -> i32
return %0: i32
}
@@ -616,7 +616,7 @@ func.func @group_non_uniform_logical_xor(%val: i1) -> i1 {
// -----
func.func @group_non_uniform_logical_xor(%val: i32) -> i32 {
- // expected-error @+1 {{operand #0 must be bool or fixed-length vector of bool values of length 2/3/4/8/16, but got 'i32'}}
+ // expected-error @+1 {{operand #0 must be bool or fixed-length vector of bool values of length 2/3/4/8/16 of ranks 1, but got 'i32'}}
%0 = spirv.GroupNonUniformLogicalXor <Workgroup> <Reduce> %val : i32 -> i32
return %0: i32
}
@@ -807,7 +807,7 @@ func.func @group_non_uniform_quad_swap(%value: vector<4xf32>) -> vector<4xf32> {
func.func @group_non_uniform_quad_swap(%value: !spirv.array<3 x i32>) -> !spirv.array<3 x i32> {
%dir = spirv.Constant 0 : i32
- // expected-error @+1 {{op operand #0 must be 16/32/64-bit float or fixed-length vector of 16/32/64-bit float values of length 2/3/4/8/16 or 8/16/32/64-bit integer or fixed-length vector of 8/16/32/64-bit integer values of length 2/3/4/8/16 or bool or fixed-length vector of bool values of length 2/3/4/8/16, but got '!spirv.array<3 x i32>'}}
+ // expected-error @+1 {{op operand #0 must be 16/32/64-bit float or fixed-length vector of 16/32/64-bit float values of length 2/3/4/8/16 of ranks 1 or 8/16/32/64-bit integer or fixed-length vector of 8/16/32/64-bit integer values of length 2/3/4/8/16 of ranks 1 or bool or fixed-length vector of bool values of length 2/3/4/8/16 of ranks 1, but got '!spirv.array<3 x i32>'}}
%0 = spirv.GroupNonUniformQuadSwap <Device> %value %dir : !spirv.array<3 x i32>, i32
return %0: !spirv.array<3 x i32>
}
diff --git a/mlir/test/Dialect/SPIRV/IR/structure-ops.mlir b/mlir/test/Dialect/SPIRV/IR/structure-ops.mlir
index 20bb4eace370b..2c5dc8b9f3b0f 100644
--- a/mlir/test/Dialect/SPIRV/IR/structure-ops.mlir
+++ b/mlir/test/Dialect/SPIRV/IR/structure-ops.mlir
@@ -170,7 +170,7 @@ func.func @coop_matrix_const_wrong_type() -> () {
//===----------------------------------------------------------------------===//
func.func @ccr_result_not_composite() -> () {
- // expected-error @+1 {{op result #0 must be vector of bool or 8/16/32/64-bit integer or 16/32/64-bit float or BFloat16 values of length 2/3/4/8/16 or any SPIR-V array type or any SPIR-V runtime array type or any SPIR-V struct type or any SPIR-V cooperative matrix type or any SPIR-V matrix type or any SPIR-V tensorArm type, but got 'i32'}}
+ // expected-error @+1 {{op result #0 must be vector of bool or 8/16/32/64-bit integer or 16/32/64-bit float or BFloat16 values of length 2/3/4/8/16 of ranks 1 or any SPIR-V array type or any SPIR-V runtime array type or any SPIR-V struct type or any SPIR-V cooperative matrix type or any SPIR-V matrix type or any SPIR-V tensorArm type, but got 'i32'}}
%0 = spirv.EXT.ConstantCompositeReplicate [1 : i32] : i32
return
}
More information about the Mlir-commits
mailing list