[Mlir-commits] [mlir] [mlir][tosa] Improve matmul verifier to check shape information (PR #191300)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Thu Apr 9 14:22:00 PDT 2026
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir
@llvm/pr-subscribers-mlir-tosa
Author: Luke Hutton (lhutton1)
<details>
<summary>Changes</summary>
Updates the matmul verifier to check that the input and output shapes are valid.
Also adds some tests for verifier failures which were previously not covered.
---
Full diff: https://github.com/llvm/llvm-project/pull/191300.diff
3 Files Affected:
- (modified) mlir/lib/Dialect/Tosa/IR/TosaOps.cpp (+50-23)
- (modified) mlir/test/Dialect/Tosa/invalid.mlir (-40)
- (modified) mlir/test/Dialect/Tosa/verifier.mlir (+137)
``````````diff
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
index f3413ac6d0d29..546201484e8cb 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
@@ -1879,24 +1879,14 @@ LogicalResult tosa::MatMulOp::inferReturnTypeComponents(
}
LogicalResult MatMulOp::verify() {
- auto aType = llvm::dyn_cast<ShapedType>(getA().getType());
- auto bType = llvm::dyn_cast<ShapedType>(getB().getType());
+ const ShapeAdaptor aShape(getA().getType());
+ const ShapeAdaptor bShape(getB().getType());
+ const Type aElementType = aShape.getElementType();
+ const Type bElementType = bShape.getElementType();
- // Must be shaped tensor types
- if (!aType)
- return emitOpError("expect a shaped tensor for input a, got ")
- << getA().getType();
-
- if (!bType)
- return emitOpError("expect a shaped tensor for input b, got ")
- << getB().getType();
-
- auto aElementType = aType.getElementType();
- auto bElementType = bType.getElementType();
-
- auto aQuantizedEType =
+ const auto aQuantizedEType =
llvm::dyn_cast<quant::UniformQuantizedType>(aElementType);
- auto bQuantizedEType =
+ const auto bQuantizedEType =
llvm::dyn_cast<quant::UniformQuantizedType>(bElementType);
if (aQuantizedEType || bQuantizedEType) {
@@ -1915,21 +1905,19 @@ LogicalResult MatMulOp::verify() {
}
// check a_zp and b_zp
- auto aEType = getStorageElementTypeOrSelf(aType);
+ auto aEType = getStorageElementTypeOrSelf(aElementType);
auto aZpEType = getStorageElementTypeOrSelf(getAZp().getType());
- if (aEType != aZpEType) {
+ if (aEType != aZpEType)
return emitOpError("expect input a and a_zp have the same "
"element type, got ")
<< aEType << " and " << aZpEType;
- }
- auto bEType = getStorageElementTypeOrSelf(bType);
- auto bZpEType = getStorageElementTypeOrSelf(getBZp().getType());
- if (bEType != bZpEType) {
+ const Type bEType = getStorageElementTypeOrSelf(bElementType);
+ const Type bZpEType = getStorageElementTypeOrSelf(getBZp().getType());
+ if (bEType != bZpEType)
return emitOpError("expect input b and b_zp have the same "
"element type, got ")
<< bEType << " and " << bZpEType;
- }
FailureOr<int64_t> maybeAZp = getAZeroPoint();
if (succeeded(maybeAZp) && verifyAZeroPoint(*maybeAZp).failed())
@@ -1939,6 +1927,45 @@ LogicalResult MatMulOp::verify() {
if (succeeded(maybeBZp) && verifyBZeroPoint(*maybeBZp).failed())
return failure();
+ // Verify input/output shapes
+ int64_t N = ShapedType::kDynamic;
+ int64_t H = ShapedType::kDynamic;
+ int64_t W = ShapedType::kDynamic;
+ int64_t C = ShapedType::kDynamic;
+
+ if (aShape.hasRank()) {
+ N = aShape.getDimSize(0);
+ H = aShape.getDimSize(1);
+ C = aShape.getDimSize(2);
+ }
+
+ if (bShape.hasRank()) {
+ if (failed(tryUpdateDimOrFailure(*this, N, bShape.getDimSize(0), "b",
+ "batch")) ||
+ failed(tryUpdateDimOrFailure(*this, C, bShape.getDimSize(1), "b",
+ "channels")))
+ return failure();
+ W = bShape.getDimSize(2);
+ }
+
+ const SmallVector<int64_t, 3> expectedOutputShape = {N, H, W};
+ const auto outputType = cast<ShapedType>(getResult().getType());
+ if (outputType.hasRank() &&
+ failed(
+ verifyCompatibleShape(outputType.getShape(), expectedOutputShape))) {
+ InFlightDiagnostic opError = emitOpError("expected output shape ");
+ auto stringifyDim = [&](int64_t d) {
+ if (ShapedType::isDynamic(d))
+ opError << "?";
+ else
+ opError << d;
+ };
+ llvm::interleaveComma(outputType.getShape(), opError, stringifyDim);
+ opError << " to be compatible with expected output shape ";
+ llvm::interleaveComma(expectedOutputShape, opError, stringifyDim);
+ return opError;
+ }
+
return success();
}
diff --git a/mlir/test/Dialect/Tosa/invalid.mlir b/mlir/test/Dialect/Tosa/invalid.mlir
index b7334fb4246a7..e18cc40b78e8b 100644
--- a/mlir/test/Dialect/Tosa/invalid.mlir
+++ b/mlir/test/Dialect/Tosa/invalid.mlir
@@ -1698,46 +1698,6 @@ func.func @test_error_double_round_without_scale32(%arg0: tensor<1xi8>) -> tenso
return %0 : tensor<1xi16>
}
-// -----
-// CHECK-LABEL: test_matmul_a_zp_same_element_type
-func.func @test_matmul_a_zp_same_element_type(%arg0: tensor<1x14x19xf32>, %arg1: tensor<1x19x28xf32>) -> tensor<1x14x28xf32> {
-%azp0 = "tosa.const"() <{values = dense<0.0> : tensor<1xf16>}> : () -> tensor<1xf16>
-%bzp0 = "tosa.const"() <{values = dense<0.0> : tensor<1xf32>}> : () -> tensor<1xf32>
-// expected-error@+1 {{'tosa.matmul' op expect input a and a_zp have the same element type, got 'f32' and 'f16'}}
-%0 = tosa.matmul %arg0, %arg1, %azp0, %bzp0 : (tensor<1x14x19xf32>, tensor<1x19x28xf32>, tensor<1xf16>, tensor<1xf32>) -> tensor<1x14x28xf32>
- return %0 : tensor<1x14x28xf32>
-}
-
-// -----
-// CHECK-LABEL: test_matmul_b_zp_same_element_type
-func.func @test_matmul_b_zp_same_element_type(%arg0: tensor<1x14x19xf32>, %arg1: tensor<1x19x28xf32>) -> tensor<1x14x28xf32> {
-%azp0 = "tosa.const"() <{values = dense<0.0> : tensor<1xf32>}> : () -> tensor<1xf32>
-%bzp0 = "tosa.const"() <{values = dense<0.0> : tensor<1xf16>}> : () -> tensor<1xf16>
-// expected-error@+1 {{'tosa.matmul' op expect input b and b_zp have the same element type, got 'f32' and 'f16'}}
-%0 = tosa.matmul %arg0, %arg1, %azp0, %bzp0 : (tensor<1x14x19xf32>, tensor<1x19x28xf32>, tensor<1xf32>, tensor<1xf16>) -> tensor<1x14x28xf32>
- return %0 : tensor<1x14x28xf32>
-}
-
-// -----
-// CHECK-LABEL: test_matmul_a_zp_non_zero
-func.func @test_matmul_a_zp_non_zero(%arg0: tensor<1x14x19xf32>, %arg1: tensor<1x19x28xf32>) -> tensor<1x14x28xf32> {
-%azp0 = "tosa.const"() <{values = dense<1.0> : tensor<1xf32>}> : () -> tensor<1xf32>
-%bzp0 = "tosa.const"() <{values = dense<0.0> : tensor<1xf32>}> : () -> tensor<1xf32>
-// expected-error@+1 {{'tosa.matmul' op a zero point must be zero for non-int8 integer types}}
-%0 = tosa.matmul %arg0, %arg1, %azp0, %bzp0 : (tensor<1x14x19xf32>, tensor<1x19x28xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<1x14x28xf32>
- return %0 : tensor<1x14x28xf32>
-}
-
-// -----
-// CHECK-LABEL: test_matmul_b_zp_non_zero
-func.func @test_matmul_b_zp_non_zero(%arg0: tensor<1x14x19xf32>, %arg1: tensor<1x19x28xf32>) -> tensor<1x14x28xf32> {
-%azp0 = "tosa.const"() <{values = dense<0.0> : tensor<1xf32>}> : () -> tensor<1xf32>
-%bzp0 = "tosa.const"() <{values = dense<-1.0> : tensor<1xf32>}> : () -> tensor<1xf32>
-// expected-error@+1 {{'tosa.matmul' op b zero point must be zero for non-int8 integer types}}
-%0 = tosa.matmul %arg0, %arg1, %azp0, %bzp0 : (tensor<1x14x19xf32>, tensor<1x19x28xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<1x14x28xf32>
- return %0 : tensor<1x14x28xf32>
-}
-
// -----
// CHECK-LABEL: test_negate_same_element_type
diff --git a/mlir/test/Dialect/Tosa/verifier.mlir b/mlir/test/Dialect/Tosa/verifier.mlir
index 80d5bca039909..931659fb435cc 100644
--- a/mlir/test/Dialect/Tosa/verifier.mlir
+++ b/mlir/test/Dialect/Tosa/verifier.mlir
@@ -1147,6 +1147,143 @@ func.func @scatter_invalid_K_W(%arg0 : tensor<2x4x5xi32>, %arg1 : tensor<2x6xi32
// -----
+func.func @test_matmul_output_batch_mismatch(%arg0: tensor<2x3x4xf32>, %arg1: tensor<5x4x6xf32>) -> tensor<2x3x6xf32> {
+ %azp0 = "tosa.const"() {values = dense<0.0> : tensor<1xf32>} : () -> tensor<1xf32>
+ %bzp0 = "tosa.const"() {values = dense<0.0> : tensor<1xf32>} : () -> tensor<1xf32>
+ // expected-error@+1 {{'tosa.matmul' op expected batch of b to match size 2, got 5}}
+ %0 = tosa.matmul %arg0, %arg1, %azp0, %bzp0 : (tensor<2x3x4xf32>, tensor<5x4x6xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<2x3x6xf32>
+ return %0 : tensor<2x3x6xf32>
+}
+
+// -----
+
+func.func @test_matmul_output_channel_mismatch(%arg0: tensor<2x3x4xf32>, %arg1: tensor<2x7x6xf32>) -> tensor<2x3x6xf32> {
+ %azp0 = "tosa.const"() {values = dense<0.0> : tensor<1xf32>} : () -> tensor<1xf32>
+ %bzp0 = "tosa.const"() {values = dense<0.0> : tensor<1xf32>} : () -> tensor<1xf32>
+ // expected-error@+1 {{'tosa.matmul' op expected channels of b to match size 4, got 7}}
+ %0 = tosa.matmul %arg0, %arg1, %azp0, %bzp0 : (tensor<2x3x4xf32>, tensor<2x7x6xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<2x3x6xf32>
+ return %0 : tensor<2x3x6xf32>
+}
+
+// -----
+
+func.func @test_matmul_output_shape_mismatch(%arg0: tensor<2x3x4xf32>, %arg1: tensor<2x4x6xf32>) -> tensor<2x5x6xf32> {
+ %azp0 = "tosa.const"() {values = dense<0.0> : tensor<1xf32>} : () -> tensor<1xf32>
+ %bzp0 = "tosa.const"() {values = dense<0.0> : tensor<1xf32>} : () -> tensor<1xf32>
+ // expected-error@+1 {{'tosa.matmul' op expected output shape 2, 5, 6 to be compatible with expected output shape 2, 3, 6}}
+ %0 = tosa.matmul %arg0, %arg1, %azp0, %bzp0 : (tensor<2x3x4xf32>, tensor<2x4x6xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<2x5x6xf32>
+ return %0 : tensor<2x5x6xf32>
+}
+
+// -----
+
+
+func.func @test_matmul_dynamic_batch_mismatch(%arg0: tensor<2x?x4xf32>, %arg1: tensor<5x4x6xf32>) -> tensor<2x?x6xf32> {
+ %azp0 = "tosa.const"() {values = dense<0.0> : tensor<1xf32>} : () -> tensor<1xf32>
+ %bzp0 = "tosa.const"() {values = dense<0.0> : tensor<1xf32>} : () -> tensor<1xf32>
+ // expected-error@+1 {{'tosa.matmul' op expected batch of b to match size 2, got 5}}
+ %0 = tosa.matmul %arg0, %arg1, %azp0, %bzp0 : (tensor<2x?x4xf32>, tensor<5x4x6xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<2x?x6xf32>
+ return %0 : tensor<2x?x6xf32>
+}
+
+// -----
+
+func.func @test_matmul_dynamic_channel_mismatch(%arg0: tensor<?x3x4xf32>, %arg1: tensor<?x7x6xf32>) -> tensor<?x3x6xf32> {
+ %azp0 = "tosa.const"() {values = dense<0.0> : tensor<1xf32>} : () -> tensor<1xf32>
+ %bzp0 = "tosa.const"() {values = dense<0.0> : tensor<1xf32>} : () -> tensor<1xf32>
+ // expected-error@+1 {{'tosa.matmul' op expected channels of b to match size 4, got 7}}
+ %0 = tosa.matmul %arg0, %arg1, %azp0, %bzp0 : (tensor<?x3x4xf32>, tensor<?x7x6xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<?x3x6xf32>
+ return %0 : tensor<?x3x6xf32>
+}
+
+// -----
+
+func.func @test_matmul_dynamic_output_shape_mismatch(%arg0: tensor<?x3x4xf32>, %arg1: tensor<2x4x6xf32>) -> tensor<5x3x6xf32> {
+ %azp0 = "tosa.const"() {values = dense<0.0> : tensor<1xf32>} : () -> tensor<1xf32>
+ %bzp0 = "tosa.const"() {values = dense<0.0> : tensor<1xf32>} : () -> tensor<1xf32>
+ // expected-error@+1 {{'tosa.matmul' op expected output shape 5, 3, 6 to be compatible with expected output shape 2, 3, 6}}
+ %0 = tosa.matmul %arg0, %arg1, %azp0, %bzp0 : (tensor<?x3x4xf32>, tensor<2x4x6xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<5x3x6xf32>
+ return %0 : tensor<5x3x6xf32>
+}
+
+// -----
+
+
+func.func @test_matmul_unranked_b_output_shape_mismatch(%arg0: tensor<2x3x4xf32>, %arg1: tensor<*xf32>) -> tensor<2x5x?xf32> {
+ %azp0 = "tosa.const"() {values = dense<0.0> : tensor<1xf32>} : () -> tensor<1xf32>
+ %bzp0 = "tosa.const"() {values = dense<0.0> : tensor<1xf32>} : () -> tensor<1xf32>
+ // expected-error@+1 {{'tosa.matmul' op expected output shape 2, 5, ? to be compatible with expected output shape 2, 3, ?}}
+ %0 = tosa.matmul %arg0, %arg1, %azp0, %bzp0 : (tensor<2x3x4xf32>, tensor<*xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<2x5x?xf32>
+ return %0 : tensor<2x5x?xf32>
+}
+
+// -----
+
+
+func.func @test_matmul_quantized_mixed_operands(%arg0: tensor<2x3x4x!quant.uniform<i8:f32, 0.125>>, %arg1: tensor<2x4x6xf32>) -> tensor<2x3x6xi32> {
+ %azp0 = "tosa.const"() {values = dense<0> : tensor<1xi8>} : () -> tensor<1xi8>
+ %bzp0 = "tosa.const"() {values = dense<0.0> : tensor<1xf32>} : () -> tensor<1xf32>
+ // expected-error@+1 {{'tosa.matmul' op expect operands to be both quantized or both not quantized, got '!quant.uniform<i8:f32, 1.250000e-01>' and 'f32'}}
+ %0 = tosa.matmul %arg0, %arg1, %azp0, %bzp0 : (tensor<2x3x4x!quant.uniform<i8:f32, 0.125>>, tensor<2x4x6xf32>, tensor<1xi8>, tensor<1xf32>) -> tensor<2x3x6xi32>
+ return %0 : tensor<2x3x6xi32>
+}
+
+// -----
+
+func.func @test_matmul_quantized_width_mismatch(%arg0: tensor<2x3x4x!quant.uniform<i8:f32, 0.125>>, %arg1: tensor<2x4x6x!quant.uniform<i16:f32, 0.125>>) -> tensor<2x3x6xi32> {
+ %azp0 = "tosa.const"() {values = dense<0> : tensor<1xi8>} : () -> tensor<1xi8>
+ %bzp0 = "tosa.const"() {values = dense<0> : tensor<1xi16>} : () -> tensor<1xi16>
+ // expected-error@+1 {{'tosa.matmul' op expect quantized operands to have same widths, got 8 and 16}}
+ %0 = tosa.matmul %arg0, %arg1, %azp0, %bzp0 : (tensor<2x3x4x!quant.uniform<i8:f32, 0.125>>, tensor<2x4x6x!quant.uniform<i16:f32, 0.125>>, tensor<1xi8>, tensor<1xi16>) -> tensor<2x3x6xi32>
+ return %0 : tensor<2x3x6xi32>
+}
+
+// -----
+
+// CHECK-LABEL: test_matmul_a_zp_same_element_type
+func.func @test_matmul_a_zp_same_element_type(%arg0: tensor<1x14x19xf32>, %arg1: tensor<1x19x28xf32>) -> tensor<1x14x28xf32> {
+%azp0 = "tosa.const"() <{values = dense<0.0> : tensor<1xf16>}> : () -> tensor<1xf16>
+%bzp0 = "tosa.const"() <{values = dense<0.0> : tensor<1xf32>}> : () -> tensor<1xf32>
+// expected-error@+1 {{'tosa.matmul' op expect input a and a_zp have the same element type, got 'f32' and 'f16'}}
+%0 = tosa.matmul %arg0, %arg1, %azp0, %bzp0 : (tensor<1x14x19xf32>, tensor<1x19x28xf32>, tensor<1xf16>, tensor<1xf32>) -> tensor<1x14x28xf32>
+ return %0 : tensor<1x14x28xf32>
+}
+
+// -----
+
+// CHECK-LABEL: test_matmul_b_zp_same_element_type
+func.func @test_matmul_b_zp_same_element_type(%arg0: tensor<1x14x19xf32>, %arg1: tensor<1x19x28xf32>) -> tensor<1x14x28xf32> {
+%azp0 = "tosa.const"() <{values = dense<0.0> : tensor<1xf32>}> : () -> tensor<1xf32>
+%bzp0 = "tosa.const"() <{values = dense<0.0> : tensor<1xf16>}> : () -> tensor<1xf16>
+// expected-error@+1 {{'tosa.matmul' op expect input b and b_zp have the same element type, got 'f32' and 'f16'}}
+%0 = tosa.matmul %arg0, %arg1, %azp0, %bzp0 : (tensor<1x14x19xf32>, tensor<1x19x28xf32>, tensor<1xf32>, tensor<1xf16>) -> tensor<1x14x28xf32>
+ return %0 : tensor<1x14x28xf32>
+}
+
+// -----
+
+// CHECK-LABEL: test_matmul_a_zp_non_zero
+func.func @test_matmul_a_zp_non_zero(%arg0: tensor<1x14x19xf32>, %arg1: tensor<1x19x28xf32>) -> tensor<1x14x28xf32> {
+%azp0 = "tosa.const"() <{values = dense<1.0> : tensor<1xf32>}> : () -> tensor<1xf32>
+%bzp0 = "tosa.const"() <{values = dense<0.0> : tensor<1xf32>}> : () -> tensor<1xf32>
+// expected-error@+1 {{'tosa.matmul' op a zero point must be zero for non-int8 integer types}}
+%0 = tosa.matmul %arg0, %arg1, %azp0, %bzp0 : (tensor<1x14x19xf32>, tensor<1x19x28xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<1x14x28xf32>
+ return %0 : tensor<1x14x28xf32>
+}
+
+// -----
+
+// CHECK-LABEL: test_matmul_b_zp_non_zero
+func.func @test_matmul_b_zp_non_zero(%arg0: tensor<1x14x19xf32>, %arg1: tensor<1x19x28xf32>) -> tensor<1x14x28xf32> {
+%azp0 = "tosa.const"() <{values = dense<0.0> : tensor<1xf32>}> : () -> tensor<1xf32>
+%bzp0 = "tosa.const"() <{values = dense<-1.0> : tensor<1xf32>}> : () -> tensor<1xf32>
+// expected-error@+1 {{'tosa.matmul' op b zero point must be zero for non-int8 integer types}}
+%0 = tosa.matmul %arg0, %arg1, %azp0, %bzp0 : (tensor<1x14x19xf32>, tensor<1x19x28xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<1x14x28xf32>
+ return %0 : tensor<1x14x28xf32>
+}
+
+// -----
+
func.func @test_matmul_t_block_scaled_data_mismatch(%arg0: tensor<4x8x32xf8E4M3FN>, %arg1: tensor<4x8x1xf8E8M0FNU>, %arg2: tensor<4x16x32xf8E5M2>, %arg3: tensor<4x16x1xf8E8M0FNU>) -> tensor<4x8x16xf32> {
// expected-error@+1 {{'tosa.matmul_t_block_scaled' op expect A_data and B_data to have same element type, got 'f8E4M3FN' and 'f8E5M2'}}
%0 = tosa.matmul_t_block_scaled %arg0, %arg1, %arg2, %arg3 {block_size = #tosa.block_size<BLOCK_SIZE_32> : i32} : (tensor<4x8x32xf8E4M3FN>, tensor<4x8x1xf8E8M0FNU>, tensor<4x16x32xf8E5M2>, tensor<4x16x1xf8E8M0FNU>) -> tensor<4x8x16xf32>
``````````
</details>
https://github.com/llvm/llvm-project/pull/191300
More information about the Mlir-commits
mailing list