[Mlir-commits] [mlir] [mlir][tosa] Add error_if checks for Mul Op (PR #135075)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Wed Apr 9 13:00:57 PDT 2025
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir
Author: Tai Ly (Tai78641)
<details>
<summary>Changes</summary>
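This adds ERROR_IF validation checks for the TOSA Mul op: when the shift operand is a constant, it must be in [0, 63] for int32 inputs and must be 0 for all other input types.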
---
Full diff: https://github.com/llvm/llvm-project/pull/135075.diff
2 Files Affected:
- (modified) mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp (+34-1)
- (modified) mlir/test/Dialect/Tosa/error_if_check.mlir (+30)
``````````diff
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
index 28e562c813eb3..11eb0d969d78b 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
@@ -979,8 +979,41 @@ bool checkErrorIfResize(Operation *op) {
return true;
}
+// Validates the TOSA ERROR_IF/REQUIRE conditions on the shift operand of
+// tosa.mul. Returns true when `op` is not a MulOp, when the shift is not a
+// (non-empty) constant, or when the constraints hold; otherwise emits an op
+// error and returns false.
+bool checkErrorIfMul(Operation *op) {
+  auto mul = dyn_cast<tosa::MulOp>(op);
+  if (!mul)
+    return true;
+
+  // Spec pseudocode:
+  //   REQUIRE(0 <= shift && shift <= 63);
+  //   REQUIRE(is_same<in_t,int32_t>() || shift == 0);
+  // Only a compile-time constant shift can be inspected; a non-constant or
+  // empty shift operand is not checked by this function.
+  ElementsAttr shiftElem;
+  if (!matchPattern(mul.getShift(), m_Constant(&shiftElem)) ||
+      shiftElem.getNumElements() == 0)
+    return true;
+  // Keep the full 64-bit value returned by getInt() so an out-of-range
+  // constant cannot be truncated into the accepted range before the check.
+  int64_t shift = shiftElem.getValues<IntegerAttr>()[0].getInt();
+  auto inputElemType = getElementTypeOrSelf(mul.getInput1());
+  if (inputElemType.isInteger(32)) {
+    // int32_t inputs: the shift must lie in [0, 63].
+    if (shift < 0 || shift > 63) {
+      op->emitOpError() << "requires 0 <= shift && shift <= 63, but got: "
+                        << shift;
+      return false;
+    }
+  } else if (shift != 0) {
+    // Every other input type: the shift must be exactly 0.
+    op->emitOpError() << "requires shift = 0 for all input data types that "
+                         "are not int32_t, but got: "
+                      << shift;
+    return false;
+  }
+
+  return true;
+}
+
+// Runs the ERROR_IF checks on `op` and returns failure() as soon as one of
+// them rejects it; the checks are chained with `||`, so later checks are
+// skipped once an earlier one fails.
LogicalResult TosaValidation::applyErrorIfCheck(Operation *op) {
- if (!checkErrorIfResize(op))
+ if (!checkErrorIfResize(op) || !checkErrorIfMul(op))
return failure();
return success();
}
diff --git a/mlir/test/Dialect/Tosa/error_if_check.mlir b/mlir/test/Dialect/Tosa/error_if_check.mlir
index ce3ad04ea68ca..f7ca0faa8bc9e 100644
--- a/mlir/test/Dialect/Tosa/error_if_check.mlir
+++ b/mlir/test/Dialect/Tosa/error_if_check.mlir
@@ -83,3 +83,33 @@ func.func @test_resize_invalid_boarder_x(%arg0: tensor<1x8x8x8xf32>) -> tensor<?
%1 = tosa.resize %arg0, %scale, %offset, %border { mode = "BILINEAR" } : (tensor<1x8x8x8xf32>, !tosa.shape<4>, !tosa.shape<2>, !tosa.shape<2>) -> tensor<?x?x?x?xf32>
return %1 : tensor<?x?x?x?xf32>
}
+
+// -----
+
+// CHECK-LABEL: test_mul_negative_shift
+// A negative shift must be rejected for i32 inputs.
+func.func @test_mul_negative_shift(%arg0: tensor<1x8x8x8xi32>, %arg1: tensor<1x8x8x8xi32>) -> tensor<1x8x8x8xi32> {
+  %shift = "tosa.const" () { values = dense<-1> : tensor<1xi8> } : () -> tensor<1xi8>
+  // expected-error@+1 {{'tosa.mul' op requires 0 <= shift && shift <= 63, but got: -1}}
+  %mul = tosa.mul %arg0, %arg1, %shift : (tensor<1x8x8x8xi32>, tensor<1x8x8x8xi32>, tensor<1xi8>) -> tensor<1x8x8x8xi32>
+  return %mul : tensor<1x8x8x8xi32>
+}
+
+// -----
+
+// CHECK-LABEL: test_mul_too_big_shift
+// A shift of 64 exceeds the allowed maximum of 63 for i32 inputs.
+func.func @test_mul_too_big_shift(%arg0: tensor<1x8x8x8xi32>, %arg1: tensor<1x8x8x8xi32>) -> tensor<1x8x8x8xi32> {
+  %shift = "tosa.const" () { values = dense<64> : tensor<1xi8> } : () -> tensor<1xi8>
+  // expected-error@+1 {{'tosa.mul' op requires 0 <= shift && shift <= 63, but got: 64}}
+  %mul = tosa.mul %arg0, %arg1, %shift : (tensor<1x8x8x8xi32>, tensor<1x8x8x8xi32>, tensor<1xi8>) -> tensor<1x8x8x8xi32>
+  return %mul : tensor<1x8x8x8xi32>
+}
+
+// -----
+
+// CHECK-LABEL: test_mul_non_zero_shift
+// Non-i32 (here i16) inputs require the shift to be exactly 0.
+func.func @test_mul_non_zero_shift(%arg0: tensor<1x8x8x8xi16>, %arg1: tensor<1x8x8x8xi16>) -> tensor<1x8x8x8xi32> {
+  %shift = "tosa.const" () { values = dense<1> : tensor<1xi8> } : () -> tensor<1xi8>
+  // expected-error@+1 {{'tosa.mul' op requires shift = 0 for all input data types that are not int32_t, but got: 1}}
+  %mul = tosa.mul %arg0, %arg1, %shift : (tensor<1x8x8x8xi16>, tensor<1x8x8x8xi16>, tensor<1xi8>) -> tensor<1x8x8x8xi32>
+  return %mul : tensor<1x8x8x8xi32>
+}
``````````
</details>
https://github.com/llvm/llvm-project/pull/135075
More information about the Mlir-commits
mailing list