[Mlir-commits] [mlir] [mlir][tosa] Add error_if checks for Mul Op (PR #135075)

Tai Ly llvmlistbot at llvm.org
Wed Apr 9 13:00:24 PDT 2025


https://github.com/Tai78641 created https://github.com/llvm/llvm-project/pull/135075

This adds error_if validation checks for the Mul op.


>From e267a1ee61d2fbc3a3fa362211d1cf7f69c2636f Mon Sep 17 00:00:00 2001
From: Tai Ly <tai.ly at arm.com>
Date: Mon, 7 Apr 2025 20:50:46 +0000
Subject: [PATCH] [mlir][tosa] Add error_if checks for Mul Op

This adds error_if validation checks for the Mul op.

Signed-off-by: Tai Ly <tai.ly at arm.com>
Change-Id: Iff40d52c63e2edb31f29b4dff6db2348a87a0b35
---
 .../Tosa/Transforms/TosaValidation.cpp        | 35 ++++++++++++++++++-
 mlir/test/Dialect/Tosa/error_if_check.mlir    | 30 ++++++++++++++++
 2 files changed, 64 insertions(+), 1 deletion(-)

diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
index 28e562c813eb3..11eb0d969d78b 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
@@ -979,8 +979,41 @@ bool checkErrorIfResize(Operation *op) {
   return true;
 }
 
+/// Validates the constant shift operand of tosa.mul against the spec:
+///   REQUIRE(0 <= shift && shift <= 63);
+///   REQUIRE(is_same<in_t,int32_t>() || shift == 0);
+/// Emits an op error and returns false on violation; otherwise true.
+bool checkErrorIfMul(Operation *op) {
+  auto mulOp = dyn_cast<tosa::MulOp>(op);
+  if (!mulOp)
+    return true;
+
+  ElementsAttr shiftAttr;
+  if (!matchPattern(mulOp.getShift(), m_Constant(&shiftAttr)))
+    return true; // Non-constant shift: nothing to check statically.
+
+  const int32_t shift = shiftAttr.getValues<IntegerAttr>()[0].getInt();
+  const bool isInt32 = getElementTypeOrSelf(mulOp.getInput1()).isInteger(32);
+
+  // int32_t inputs accept any shift in [0, 63].
+  if (isInt32 && (shift < 0 || shift > 63)) {
+    op->emitOpError() << "requires 0 <= shift && shift <= 63, but got: "
+                      << shift;
+    return false;
+  }
+
+  // All other input types require a zero shift.
+  if (!isInt32 && shift != 0) {
+    op->emitOpError() << "requires shift = 0 for all input data types that "
+                         "are not int32_t, but got: "
+                      << shift;
+    return false;
+  }
+  return true;
+}
+
 LogicalResult TosaValidation::applyErrorIfCheck(Operation *op) {
-  if (!checkErrorIfResize(op))
+  if (!(checkErrorIfResize(op) && checkErrorIfMul(op)))
     return failure();
   return success();
 }
diff --git a/mlir/test/Dialect/Tosa/error_if_check.mlir b/mlir/test/Dialect/Tosa/error_if_check.mlir
index ce3ad04ea68ca..f7ca0faa8bc9e 100644
--- a/mlir/test/Dialect/Tosa/error_if_check.mlir
+++ b/mlir/test/Dialect/Tosa/error_if_check.mlir
@@ -83,3 +83,33 @@ func.func @test_resize_invalid_boarder_x(%arg0: tensor<1x8x8x8xf32>) -> tensor<?
   %1 = tosa.resize %arg0, %scale, %offset, %border { mode = "BILINEAR" } : (tensor<1x8x8x8xf32>, !tosa.shape<4>, !tosa.shape<2>, !tosa.shape<2>) -> tensor<?x?x?x?xf32>
   return %1 : tensor<?x?x?x?xf32>
 }
+
+// -----
+
+// CHECK-LABEL: test_mul_negative_shift
+func.func @test_mul_negative_shift(%arg0: tensor<1x8x8x8xi32>, %arg1: tensor<1x8x8x8xi32>) -> tensor<1x8x8x8xi32> {
+  %shift = "tosa.const" () { values = dense<-1> : tensor<1xi8> } : () -> tensor<1xi8>
+  // expected-error@+1 {{'tosa.mul' op requires 0 <= shift && shift <= 63, but got: -1}}
+  %mul = tosa.mul %arg0, %arg1, %shift : (tensor<1x8x8x8xi32>, tensor<1x8x8x8xi32>, tensor<1xi8>) -> tensor<1x8x8x8xi32>
+  return %mul : tensor<1x8x8x8xi32>
+}
+
+// -----
+
+// CHECK-LABEL: test_mul_too_big_shift
+func.func @test_mul_too_big_shift(%arg0: tensor<1x8x8x8xi32>, %arg1: tensor<1x8x8x8xi32>) -> tensor<1x8x8x8xi32> {
+  %shift = "tosa.const" () { values = dense<64> : tensor<1xi8> } : () -> tensor<1xi8>
+  // expected-error@+1 {{'tosa.mul' op requires 0 <= shift && shift <= 63, but got: 64}}
+  %mul = tosa.mul %arg0, %arg1, %shift : (tensor<1x8x8x8xi32>, tensor<1x8x8x8xi32>, tensor<1xi8>) -> tensor<1x8x8x8xi32>
+  return %mul : tensor<1x8x8x8xi32>
+}
+
+// -----
+
+// CHECK-LABEL: test_mul_non_zero_shift
+func.func @test_mul_non_zero_shift(%arg0: tensor<1x8x8x8xi16>, %arg1: tensor<1x8x8x8xi16>) -> tensor<1x8x8x8xi32> {
+  %shift = "tosa.const" () { values = dense<1> : tensor<1xi8> } : () -> tensor<1xi8>
+  // expected-error@+1 {{'tosa.mul' op requires shift = 0 for all input data types that are not int32_t, but got: 1}}
+  %mul = tosa.mul %arg0, %arg1, %shift : (tensor<1x8x8x8xi16>, tensor<1x8x8x8xi16>, tensor<1xi8>) -> tensor<1x8x8x8xi32>
+  return %mul : tensor<1x8x8x8xi32>
+}



More information about the Mlir-commits mailing list