[Mlir-commits] [mlir] [mlir][tosa] Improve matmul verifier to check shape information (PR #191300)
Luke Hutton
llvmlistbot at llvm.org
Thu Apr 9 14:21:24 PDT 2026
https://github.com/lhutton1 created https://github.com/llvm/llvm-project/pull/191300
Updates the matmul verifier to check that the input and output shapes are valid.
Also adds tests for verifier failures that were previously not covered, and moves the existing matmul zero-point tests from invalid.mlir to verifier.mlir.
>From 9c3acc1e12bfcb4accfca52979be15c8a94e80b0 Mon Sep 17 00:00:00 2001
From: Luke Hutton <luke.hutton at arm.com>
Date: Thu, 9 Apr 2026 21:13:01 +0000
Subject: [PATCH] [mlir][tosa] Improve matmul verifier to check shape
information
Updates the matmul verifier to check that the input and output
shapes are valid.
Also adds tests for verifier failures that were previously not
covered, and moves the existing matmul zero-point tests from
invalid.mlir to verifier.mlir.
Change-Id: I9c8a0991e76489f3b73258bdd446d746b53e8934
---
mlir/lib/Dialect/Tosa/IR/TosaOps.cpp | 73 +++++++++-----
mlir/test/Dialect/Tosa/invalid.mlir | 40 --------
mlir/test/Dialect/Tosa/verifier.mlir | 137 +++++++++++++++++++++++++++
3 files changed, 187 insertions(+), 63 deletions(-)
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
index f3413ac6d0d29..546201484e8cb 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
@@ -1879,24 +1879,14 @@ LogicalResult tosa::MatMulOp::inferReturnTypeComponents(
}
LogicalResult MatMulOp::verify() {
- auto aType = llvm::dyn_cast<ShapedType>(getA().getType());
- auto bType = llvm::dyn_cast<ShapedType>(getB().getType());
+ const ShapeAdaptor aShape(getA().getType());
+ const ShapeAdaptor bShape(getB().getType());
+ const Type aElementType = aShape.getElementType();
+ const Type bElementType = bShape.getElementType();
- // Must be shaped tensor types
- if (!aType)
- return emitOpError("expect a shaped tensor for input a, got ")
- << getA().getType();
-
- if (!bType)
- return emitOpError("expect a shaped tensor for input b, got ")
- << getB().getType();
-
- auto aElementType = aType.getElementType();
- auto bElementType = bType.getElementType();
-
- auto aQuantizedEType =
+ const auto aQuantizedEType =
llvm::dyn_cast<quant::UniformQuantizedType>(aElementType);
- auto bQuantizedEType =
+ const auto bQuantizedEType =
llvm::dyn_cast<quant::UniformQuantizedType>(bElementType);
if (aQuantizedEType || bQuantizedEType) {
@@ -1915,21 +1905,19 @@ LogicalResult MatMulOp::verify() {
}
// check a_zp and b_zp
- auto aEType = getStorageElementTypeOrSelf(aType);
+ auto aEType = getStorageElementTypeOrSelf(aElementType);
auto aZpEType = getStorageElementTypeOrSelf(getAZp().getType());
- if (aEType != aZpEType) {
+ if (aEType != aZpEType)
return emitOpError("expect input a and a_zp have the same "
"element type, got ")
<< aEType << " and " << aZpEType;
- }
- auto bEType = getStorageElementTypeOrSelf(bType);
- auto bZpEType = getStorageElementTypeOrSelf(getBZp().getType());
- if (bEType != bZpEType) {
+ const Type bEType = getStorageElementTypeOrSelf(bElementType);
+ const Type bZpEType = getStorageElementTypeOrSelf(getBZp().getType());
+ if (bEType != bZpEType)
return emitOpError("expect input b and b_zp have the same "
"element type, got ")
<< bEType << " and " << bZpEType;
- }
FailureOr<int64_t> maybeAZp = getAZeroPoint();
if (succeeded(maybeAZp) && verifyAZeroPoint(*maybeAZp).failed())
@@ -1939,6 +1927,45 @@ LogicalResult MatMulOp::verify() {
if (succeeded(maybeBZp) && verifyBZeroPoint(*maybeBZp).failed())
return failure();
+ // Verify input/output shapes
+ int64_t N = ShapedType::kDynamic;
+ int64_t H = ShapedType::kDynamic;
+ int64_t W = ShapedType::kDynamic;
+ int64_t C = ShapedType::kDynamic;
+
+ if (aShape.hasRank()) {
+ N = aShape.getDimSize(0);
+ H = aShape.getDimSize(1);
+ C = aShape.getDimSize(2);
+ }
+
+ if (bShape.hasRank()) {
+ if (failed(tryUpdateDimOrFailure(*this, N, bShape.getDimSize(0), "b",
+ "batch")) ||
+ failed(tryUpdateDimOrFailure(*this, C, bShape.getDimSize(1), "b",
+ "channels")))
+ return failure();
+ W = bShape.getDimSize(2);
+ }
+
+ const SmallVector<int64_t, 3> expectedOutputShape = {N, H, W};
+ const auto outputType = cast<ShapedType>(getResult().getType());
+ if (outputType.hasRank() &&
+ failed(
+ verifyCompatibleShape(outputType.getShape(), expectedOutputShape))) {
+ InFlightDiagnostic opError = emitOpError("expected output shape ");
+ auto stringifyDim = [&](int64_t d) {
+ if (ShapedType::isDynamic(d))
+ opError << "?";
+ else
+ opError << d;
+ };
+ llvm::interleaveComma(outputType.getShape(), opError, stringifyDim);
+ opError << " to be compatible with expected output shape ";
+ llvm::interleaveComma(expectedOutputShape, opError, stringifyDim);
+ return opError;
+ }
+
return success();
}
diff --git a/mlir/test/Dialect/Tosa/invalid.mlir b/mlir/test/Dialect/Tosa/invalid.mlir
index b7334fb4246a7..e18cc40b78e8b 100644
--- a/mlir/test/Dialect/Tosa/invalid.mlir
+++ b/mlir/test/Dialect/Tosa/invalid.mlir
@@ -1698,46 +1698,6 @@ func.func @test_error_double_round_without_scale32(%arg0: tensor<1xi8>) -> tenso
return %0 : tensor<1xi16>
}
-// -----
-// CHECK-LABEL: test_matmul_a_zp_same_element_type
-func.func @test_matmul_a_zp_same_element_type(%arg0: tensor<1x14x19xf32>, %arg1: tensor<1x19x28xf32>) -> tensor<1x14x28xf32> {
-%azp0 = "tosa.const"() <{values = dense<0.0> : tensor<1xf16>}> : () -> tensor<1xf16>
-%bzp0 = "tosa.const"() <{values = dense<0.0> : tensor<1xf32>}> : () -> tensor<1xf32>
-// expected-error at +1 {{'tosa.matmul' op expect input a and a_zp have the same element type, got 'f32' and 'f16'}}
-%0 = tosa.matmul %arg0, %arg1, %azp0, %bzp0 : (tensor<1x14x19xf32>, tensor<1x19x28xf32>, tensor<1xf16>, tensor<1xf32>) -> tensor<1x14x28xf32>
- return %0 : tensor<1x14x28xf32>
-}
-
-// -----
-// CHECK-LABEL: test_matmul_b_zp_same_element_type
-func.func @test_matmul_b_zp_same_element_type(%arg0: tensor<1x14x19xf32>, %arg1: tensor<1x19x28xf32>) -> tensor<1x14x28xf32> {
-%azp0 = "tosa.const"() <{values = dense<0.0> : tensor<1xf32>}> : () -> tensor<1xf32>
-%bzp0 = "tosa.const"() <{values = dense<0.0> : tensor<1xf16>}> : () -> tensor<1xf16>
-// expected-error at +1 {{'tosa.matmul' op expect input b and b_zp have the same element type, got 'f32' and 'f16'}}
-%0 = tosa.matmul %arg0, %arg1, %azp0, %bzp0 : (tensor<1x14x19xf32>, tensor<1x19x28xf32>, tensor<1xf32>, tensor<1xf16>) -> tensor<1x14x28xf32>
- return %0 : tensor<1x14x28xf32>
-}
-
-// -----
-// CHECK-LABEL: test_matmul_a_zp_non_zero
-func.func @test_matmul_a_zp_non_zero(%arg0: tensor<1x14x19xf32>, %arg1: tensor<1x19x28xf32>) -> tensor<1x14x28xf32> {
-%azp0 = "tosa.const"() <{values = dense<1.0> : tensor<1xf32>}> : () -> tensor<1xf32>
-%bzp0 = "tosa.const"() <{values = dense<0.0> : tensor<1xf32>}> : () -> tensor<1xf32>
-// expected-error at +1 {{'tosa.matmul' op a zero point must be zero for non-int8 integer types}}
-%0 = tosa.matmul %arg0, %arg1, %azp0, %bzp0 : (tensor<1x14x19xf32>, tensor<1x19x28xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<1x14x28xf32>
- return %0 : tensor<1x14x28xf32>
-}
-
-// -----
-// CHECK-LABEL: test_matmul_b_zp_non_zero
-func.func @test_matmul_b_zp_non_zero(%arg0: tensor<1x14x19xf32>, %arg1: tensor<1x19x28xf32>) -> tensor<1x14x28xf32> {
-%azp0 = "tosa.const"() <{values = dense<0.0> : tensor<1xf32>}> : () -> tensor<1xf32>
-%bzp0 = "tosa.const"() <{values = dense<-1.0> : tensor<1xf32>}> : () -> tensor<1xf32>
-// expected-error at +1 {{'tosa.matmul' op b zero point must be zero for non-int8 integer types}}
-%0 = tosa.matmul %arg0, %arg1, %azp0, %bzp0 : (tensor<1x14x19xf32>, tensor<1x19x28xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<1x14x28xf32>
- return %0 : tensor<1x14x28xf32>
-}
-
// -----
// CHECK-LABEL: test_negate_same_element_type
diff --git a/mlir/test/Dialect/Tosa/verifier.mlir b/mlir/test/Dialect/Tosa/verifier.mlir
index 80d5bca039909..931659fb435cc 100644
--- a/mlir/test/Dialect/Tosa/verifier.mlir
+++ b/mlir/test/Dialect/Tosa/verifier.mlir
@@ -1147,6 +1147,143 @@ func.func @scatter_invalid_K_W(%arg0 : tensor<2x4x5xi32>, %arg1 : tensor<2x6xi32
// -----
+func.func @test_matmul_output_batch_mismatch(%arg0: tensor<2x3x4xf32>, %arg1: tensor<5x4x6xf32>) -> tensor<2x3x6xf32> {
+ %azp0 = "tosa.const"() {values = dense<0.0> : tensor<1xf32>} : () -> tensor<1xf32>
+ %bzp0 = "tosa.const"() {values = dense<0.0> : tensor<1xf32>} : () -> tensor<1xf32>
+ // expected-error at +1 {{'tosa.matmul' op expected batch of b to match size 2, got 5}}
+ %0 = tosa.matmul %arg0, %arg1, %azp0, %bzp0 : (tensor<2x3x4xf32>, tensor<5x4x6xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<2x3x6xf32>
+ return %0 : tensor<2x3x6xf32>
+}
+
+// -----
+
+func.func @test_matmul_output_channel_mismatch(%arg0: tensor<2x3x4xf32>, %arg1: tensor<2x7x6xf32>) -> tensor<2x3x6xf32> {
+ %azp0 = "tosa.const"() {values = dense<0.0> : tensor<1xf32>} : () -> tensor<1xf32>
+ %bzp0 = "tosa.const"() {values = dense<0.0> : tensor<1xf32>} : () -> tensor<1xf32>
+ // expected-error at +1 {{'tosa.matmul' op expected channels of b to match size 4, got 7}}
+ %0 = tosa.matmul %arg0, %arg1, %azp0, %bzp0 : (tensor<2x3x4xf32>, tensor<2x7x6xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<2x3x6xf32>
+ return %0 : tensor<2x3x6xf32>
+}
+
+// -----
+
+func.func @test_matmul_output_shape_mismatch(%arg0: tensor<2x3x4xf32>, %arg1: tensor<2x4x6xf32>) -> tensor<2x5x6xf32> {
+ %azp0 = "tosa.const"() {values = dense<0.0> : tensor<1xf32>} : () -> tensor<1xf32>
+ %bzp0 = "tosa.const"() {values = dense<0.0> : tensor<1xf32>} : () -> tensor<1xf32>
+ // expected-error at +1 {{'tosa.matmul' op expected output shape 2, 5, 6 to be compatible with expected output shape 2, 3, 6}}
+ %0 = tosa.matmul %arg0, %arg1, %azp0, %bzp0 : (tensor<2x3x4xf32>, tensor<2x4x6xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<2x5x6xf32>
+ return %0 : tensor<2x5x6xf32>
+}
+
+// -----
+
+
+func.func @test_matmul_dynamic_batch_mismatch(%arg0: tensor<2x?x4xf32>, %arg1: tensor<5x4x6xf32>) -> tensor<2x?x6xf32> {
+ %azp0 = "tosa.const"() {values = dense<0.0> : tensor<1xf32>} : () -> tensor<1xf32>
+ %bzp0 = "tosa.const"() {values = dense<0.0> : tensor<1xf32>} : () -> tensor<1xf32>
+ // expected-error at +1 {{'tosa.matmul' op expected batch of b to match size 2, got 5}}
+ %0 = tosa.matmul %arg0, %arg1, %azp0, %bzp0 : (tensor<2x?x4xf32>, tensor<5x4x6xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<2x?x6xf32>
+ return %0 : tensor<2x?x6xf32>
+}
+
+// -----
+
+func.func @test_matmul_dynamic_channel_mismatch(%arg0: tensor<?x3x4xf32>, %arg1: tensor<?x7x6xf32>) -> tensor<?x3x6xf32> {
+ %azp0 = "tosa.const"() {values = dense<0.0> : tensor<1xf32>} : () -> tensor<1xf32>
+ %bzp0 = "tosa.const"() {values = dense<0.0> : tensor<1xf32>} : () -> tensor<1xf32>
+ // expected-error at +1 {{'tosa.matmul' op expected channels of b to match size 4, got 7}}
+ %0 = tosa.matmul %arg0, %arg1, %azp0, %bzp0 : (tensor<?x3x4xf32>, tensor<?x7x6xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<?x3x6xf32>
+ return %0 : tensor<?x3x6xf32>
+}
+
+// -----
+
+func.func @test_matmul_dynamic_output_shape_mismatch(%arg0: tensor<?x3x4xf32>, %arg1: tensor<2x4x6xf32>) -> tensor<5x3x6xf32> {
+ %azp0 = "tosa.const"() {values = dense<0.0> : tensor<1xf32>} : () -> tensor<1xf32>
+ %bzp0 = "tosa.const"() {values = dense<0.0> : tensor<1xf32>} : () -> tensor<1xf32>
+ // expected-error at +1 {{'tosa.matmul' op expected output shape 5, 3, 6 to be compatible with expected output shape 2, 3, 6}}
+ %0 = tosa.matmul %arg0, %arg1, %azp0, %bzp0 : (tensor<?x3x4xf32>, tensor<2x4x6xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<5x3x6xf32>
+ return %0 : tensor<5x3x6xf32>
+}
+
+// -----
+
+
+func.func @test_matmul_unranked_b_output_shape_mismatch(%arg0: tensor<2x3x4xf32>, %arg1: tensor<*xf32>) -> tensor<2x5x?xf32> {
+ %azp0 = "tosa.const"() {values = dense<0.0> : tensor<1xf32>} : () -> tensor<1xf32>
+ %bzp0 = "tosa.const"() {values = dense<0.0> : tensor<1xf32>} : () -> tensor<1xf32>
+ // expected-error at +1 {{'tosa.matmul' op expected output shape 2, 5, ? to be compatible with expected output shape 2, 3, ?}}
+ %0 = tosa.matmul %arg0, %arg1, %azp0, %bzp0 : (tensor<2x3x4xf32>, tensor<*xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<2x5x?xf32>
+ return %0 : tensor<2x5x?xf32>
+}
+
+// -----
+
+
+func.func @test_matmul_quantized_mixed_operands(%arg0: tensor<2x3x4x!quant.uniform<i8:f32, 0.125>>, %arg1: tensor<2x4x6xf32>) -> tensor<2x3x6xi32> {
+ %azp0 = "tosa.const"() {values = dense<0> : tensor<1xi8>} : () -> tensor<1xi8>
+ %bzp0 = "tosa.const"() {values = dense<0.0> : tensor<1xf32>} : () -> tensor<1xf32>
+ // expected-error at +1 {{'tosa.matmul' op expect operands to be both quantized or both not quantized, got '!quant.uniform<i8:f32, 1.250000e-01>' and 'f32'}}
+ %0 = tosa.matmul %arg0, %arg1, %azp0, %bzp0 : (tensor<2x3x4x!quant.uniform<i8:f32, 0.125>>, tensor<2x4x6xf32>, tensor<1xi8>, tensor<1xf32>) -> tensor<2x3x6xi32>
+ return %0 : tensor<2x3x6xi32>
+}
+
+// -----
+
+func.func @test_matmul_quantized_width_mismatch(%arg0: tensor<2x3x4x!quant.uniform<i8:f32, 0.125>>, %arg1: tensor<2x4x6x!quant.uniform<i16:f32, 0.125>>) -> tensor<2x3x6xi32> {
+ %azp0 = "tosa.const"() {values = dense<0> : tensor<1xi8>} : () -> tensor<1xi8>
+ %bzp0 = "tosa.const"() {values = dense<0> : tensor<1xi16>} : () -> tensor<1xi16>
+ // expected-error at +1 {{'tosa.matmul' op expect quantized operands to have same widths, got 8 and 16}}
+ %0 = tosa.matmul %arg0, %arg1, %azp0, %bzp0 : (tensor<2x3x4x!quant.uniform<i8:f32, 0.125>>, tensor<2x4x6x!quant.uniform<i16:f32, 0.125>>, tensor<1xi8>, tensor<1xi16>) -> tensor<2x3x6xi32>
+ return %0 : tensor<2x3x6xi32>
+}
+
+// -----
+
+// CHECK-LABEL: test_matmul_a_zp_same_element_type
+func.func @test_matmul_a_zp_same_element_type(%arg0: tensor<1x14x19xf32>, %arg1: tensor<1x19x28xf32>) -> tensor<1x14x28xf32> {
+%azp0 = "tosa.const"() <{values = dense<0.0> : tensor<1xf16>}> : () -> tensor<1xf16>
+%bzp0 = "tosa.const"() <{values = dense<0.0> : tensor<1xf32>}> : () -> tensor<1xf32>
+// expected-error at +1 {{'tosa.matmul' op expect input a and a_zp have the same element type, got 'f32' and 'f16'}}
+%0 = tosa.matmul %arg0, %arg1, %azp0, %bzp0 : (tensor<1x14x19xf32>, tensor<1x19x28xf32>, tensor<1xf16>, tensor<1xf32>) -> tensor<1x14x28xf32>
+ return %0 : tensor<1x14x28xf32>
+}
+
+// -----
+
+// CHECK-LABEL: test_matmul_b_zp_same_element_type
+func.func @test_matmul_b_zp_same_element_type(%arg0: tensor<1x14x19xf32>, %arg1: tensor<1x19x28xf32>) -> tensor<1x14x28xf32> {
+%azp0 = "tosa.const"() <{values = dense<0.0> : tensor<1xf32>}> : () -> tensor<1xf32>
+%bzp0 = "tosa.const"() <{values = dense<0.0> : tensor<1xf16>}> : () -> tensor<1xf16>
+// expected-error at +1 {{'tosa.matmul' op expect input b and b_zp have the same element type, got 'f32' and 'f16'}}
+%0 = tosa.matmul %arg0, %arg1, %azp0, %bzp0 : (tensor<1x14x19xf32>, tensor<1x19x28xf32>, tensor<1xf32>, tensor<1xf16>) -> tensor<1x14x28xf32>
+ return %0 : tensor<1x14x28xf32>
+}
+
+// -----
+
+// CHECK-LABEL: test_matmul_a_zp_non_zero
+func.func @test_matmul_a_zp_non_zero(%arg0: tensor<1x14x19xf32>, %arg1: tensor<1x19x28xf32>) -> tensor<1x14x28xf32> {
+%azp0 = "tosa.const"() <{values = dense<1.0> : tensor<1xf32>}> : () -> tensor<1xf32>
+%bzp0 = "tosa.const"() <{values = dense<0.0> : tensor<1xf32>}> : () -> tensor<1xf32>
+// expected-error at +1 {{'tosa.matmul' op a zero point must be zero for non-int8 integer types}}
+%0 = tosa.matmul %arg0, %arg1, %azp0, %bzp0 : (tensor<1x14x19xf32>, tensor<1x19x28xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<1x14x28xf32>
+ return %0 : tensor<1x14x28xf32>
+}
+
+// -----
+
+// CHECK-LABEL: test_matmul_b_zp_non_zero
+func.func @test_matmul_b_zp_non_zero(%arg0: tensor<1x14x19xf32>, %arg1: tensor<1x19x28xf32>) -> tensor<1x14x28xf32> {
+%azp0 = "tosa.const"() <{values = dense<0.0> : tensor<1xf32>}> : () -> tensor<1xf32>
+%bzp0 = "tosa.const"() <{values = dense<-1.0> : tensor<1xf32>}> : () -> tensor<1xf32>
+// expected-error at +1 {{'tosa.matmul' op b zero point must be zero for non-int8 integer types}}
+%0 = tosa.matmul %arg0, %arg1, %azp0, %bzp0 : (tensor<1x14x19xf32>, tensor<1x19x28xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<1x14x28xf32>
+ return %0 : tensor<1x14x28xf32>
+}
+
+// -----
+
func.func @test_matmul_t_block_scaled_data_mismatch(%arg0: tensor<4x8x32xf8E4M3FN>, %arg1: tensor<4x8x1xf8E8M0FNU>, %arg2: tensor<4x16x32xf8E5M2>, %arg3: tensor<4x16x1xf8E8M0FNU>) -> tensor<4x8x16xf32> {
// expected-error at +1 {{'tosa.matmul_t_block_scaled' op expect A_data and B_data to have same element type, got 'f8E4M3FN' and 'f8E5M2'}}
%0 = tosa.matmul_t_block_scaled %arg0, %arg1, %arg2, %arg3 {block_size = #tosa.block_size<BLOCK_SIZE_32> : i32} : (tensor<4x8x32xf8E4M3FN>, tensor<4x8x1xf8E8M0FNU>, tensor<4x16x32xf8E5M2>, tensor<4x16x1xf8E8M0FNU>) -> tensor<4x8x16xf32>
More information about the Mlir-commits
mailing list