[Mlir-commits] [mlir] [mlir][tosa] Add error_if checks for Mul Op (PR #135075)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Wed Apr 9 13:00:57 PDT 2025


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir

Author: Tai Ly (Tai78641)

<details>
<summary>Changes</summary>

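This adds error_if validation checks for the TOSA Mul op. When the shift operand is a constant, the shift must satisfy 0 <= shift <= 63 for int32 inputs and must be 0 for all other input element types.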


---
Full diff: https://github.com/llvm/llvm-project/pull/135075.diff


2 Files Affected:

- (modified) mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp (+34-1) 
- (modified) mlir/test/Dialect/Tosa/error_if_check.mlir (+30) 


``````````diff
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
index 28e562c813eb3..11eb0d969d78b 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
@@ -979,8 +979,41 @@ bool checkErrorIfResize(Operation *op) {
   return true;
 }
 
+bool checkErrorIfMul(Operation *op) {
+  auto mul = dyn_cast<tosa::MulOp>(op);
+  if (!mul)
+    return true;
+
+  // REQUIRE(0 <= shift && shift <= 63);
+  // REQUIRE(is_same<in_t,int32_t>() || shift == 0);
+  ElementsAttr shift_elem;
+  if (!matchPattern(mul.getShift(), m_Constant(&shift_elem))) {
+    return true;
+  }
+  int32_t shift = shift_elem.getValues<IntegerAttr>()[0].getInt();
+  auto inputElemType = getElementTypeOrSelf(mul.getInput1());
+  if (inputElemType.isInteger(32)) {
+    // 0 <= shift <= 63 for int32_t type
+    if (shift < 0 || shift > 63) {
+      op->emitOpError() << "requires 0 <= shift && shift <= 63, but got: "
+                        << shift;
+      return false;
+    }
+  } else {
+    // shift must be 0 for all other types
+    if (shift != 0) {
+      op->emitOpError() << "requires shift = 0 for all input data types that "
+                           "are not int32_t, but got: "
+                        << shift;
+      return false;
+    }
+  }
+
+  return true;
+}
+
 LogicalResult TosaValidation::applyErrorIfCheck(Operation *op) {
-  if (!checkErrorIfResize(op))
+  if (!checkErrorIfResize(op) || !checkErrorIfMul(op))
     return failure();
   return success();
 }
diff --git a/mlir/test/Dialect/Tosa/error_if_check.mlir b/mlir/test/Dialect/Tosa/error_if_check.mlir
index ce3ad04ea68ca..f7ca0faa8bc9e 100644
--- a/mlir/test/Dialect/Tosa/error_if_check.mlir
+++ b/mlir/test/Dialect/Tosa/error_if_check.mlir
@@ -83,3 +83,33 @@ func.func @test_resize_invalid_boarder_x(%arg0: tensor<1x8x8x8xf32>) -> tensor<?
   %1 = tosa.resize %arg0, %scale, %offset, %border { mode = "BILINEAR" } : (tensor<1x8x8x8xf32>, !tosa.shape<4>, !tosa.shape<2>, !tosa.shape<2>) -> tensor<?x?x?x?xf32>
   return %1 : tensor<?x?x?x?xf32>
 }
+
+// -----
+
+// CHECK-LABEL: test_mul_negative_shift
+func.func @test_mul_negative_shift(%arg0: tensor<1x8x8x8xi32>, %arg1: tensor<1x8x8x8xi32>) -> tensor<1x8x8x8xi32> {
+  %shift = "tosa.const" () { values = dense<-1> : tensor<1xi8> } : () -> tensor<1xi8>
+  // expected-error@+1 {{'tosa.mul' op requires 0 <= shift && shift <= 63, but got: -1}}
+  %mul = tosa.mul %arg0, %arg1, %shift : (tensor<1x8x8x8xi32>, tensor<1x8x8x8xi32>, tensor<1xi8>) -> tensor<1x8x8x8xi32>
+  return %mul : tensor<1x8x8x8xi32>
+}
+
+// -----
+
+// CHECK-LABEL: test_mul_too_big_shift
+func.func @test_mul_too_big_shift(%arg0: tensor<1x8x8x8xi32>, %arg1: tensor<1x8x8x8xi32>) -> tensor<1x8x8x8xi32> {
+  %shift = "tosa.const" () { values = dense<64> : tensor<1xi8> } : () -> tensor<1xi8>
+  // expected-error@+1 {{'tosa.mul' op requires 0 <= shift && shift <= 63, but got: 64}}
+  %mul = tosa.mul %arg0, %arg1, %shift : (tensor<1x8x8x8xi32>, tensor<1x8x8x8xi32>, tensor<1xi8>) -> tensor<1x8x8x8xi32>
+  return %mul : tensor<1x8x8x8xi32>
+}
+
+// -----
+
+// CHECK-LABEL: test_mul_non_zero_shift
+func.func @test_mul_non_zero_shift(%arg0: tensor<1x8x8x8xi16>, %arg1: tensor<1x8x8x8xi16>) -> tensor<1x8x8x8xi32> {
+  %shift = "tosa.const" () { values = dense<1> : tensor<1xi8> } : () -> tensor<1xi8>
+  // expected-error@+1 {{'tosa.mul' op requires shift = 0 for all input data types that are not int32_t, but got: 1}}
+  %mul = tosa.mul %arg0, %arg1, %shift : (tensor<1x8x8x8xi16>, tensor<1x8x8x8xi16>, tensor<1xi8>) -> tensor<1x8x8x8xi32>
+  return %mul : tensor<1x8x8x8xi32>
+}

``````````

</details>


https://github.com/llvm/llvm-project/pull/135075


More information about the Mlir-commits mailing list