[Mlir-commits] [mlir] [mlir][tosa] Add a verifier for `tosa.mul` (PR #113320)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Tue Oct 22 07:19:39 PDT 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir

Author: Longsheng Mou (CoTinker)

<details>
<summary>Changes</summary>

This PR adds a verifier check for tosa.mul, requiring that the shift be 0 for float types.
Fixes #<!-- -->112716.

---
Full diff: https://github.com/llvm/llvm-project/pull/113320.diff


5 Files Affected:

- (modified) mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td (+1) 
- (modified) mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp (+4-10) 
- (modified) mlir/lib/Dialect/Tosa/IR/TosaOps.cpp (+8) 
- (modified) mlir/test/Dialect/Tosa/invalid.mlir (+9) 
- (modified) mlir/test/Dialect/Tosa/ops.mlir (+1-1) 


``````````diff
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
index 3bb5ceb0f4695b..6e7d575ac26df1 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
@@ -811,6 +811,7 @@ def Tosa_MulOp : Tosa_ElementwiseOp<"mul", [
   );
 
   let hasFolder = 1;
+  let hasVerifier = 1;
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
index c88f4db27c304e..495f1b4f10b028 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
@@ -78,16 +78,6 @@ static Value createLinalgBodyCalculationForElementwiseOp(
   if (isa<tosa::SubOp>(op) && isa<IntegerType>(elementTy))
     return rewriter.create<arith::SubIOp>(loc, resultTypes, args);
 
-  // tosa::MulOp
-  if (isa<tosa::MulOp>(op) && isa<FloatType>(elementTy)) {
-    if (dyn_cast<tosa::MulOp>(op).getShift() != 0) {
-      (void)rewriter.notifyMatchFailure(op,
-                                        "Cannot have shift value for float");
-      return nullptr;
-    }
-    return rewriter.create<arith::MulFOp>(loc, resultTypes, args);
-  }
-
   // tosa::IntDivOp
   if (isa<tosa::IntDivOp>(op) && isa<IntegerType>(elementTy))
     return rewriter.create<arith::DivSIOp>(loc, resultTypes, args);
@@ -99,6 +89,10 @@ static Value createLinalgBodyCalculationForElementwiseOp(
     return rewriter.create<arith::DivFOp>(loc, resultTypes, one, args[0]);
   }
 
+  // tosa::MulOp
+  if (isa<tosa::MulOp>(op) && isa<FloatType>(elementTy))
+    return rewriter.create<arith::MulFOp>(loc, resultTypes, args);
+
   if (isa<tosa::MulOp>(op) && isa<IntegerType>(elementTy)) {
     Value a = args[0];
     Value b = args[1];
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
index 1f3e19fe90c6db..631d3c48f2df02 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
@@ -865,6 +865,14 @@ LogicalResult tosa::SliceOp::verify() {
   return success();
 }
 
+LogicalResult tosa::MulOp::verify() {
+  Type elementTy = getInput1().getType().getElementType();
+  if (isa<FloatType>(elementTy) && getShift() != 0)
+    return emitOpError() << "require shift to be 0 for float type";
+
+  return success();
+}
+
 LogicalResult tosa::TableOp::inferReturnTypeComponents(
     MLIRContext *context, ::std::optional<Location> location,
     TableOp::Adaptor adaptor,
diff --git a/mlir/test/Dialect/Tosa/invalid.mlir b/mlir/test/Dialect/Tosa/invalid.mlir
index b9298b66643538..f1b1707a0c40d9 100644
--- a/mlir/test/Dialect/Tosa/invalid.mlir
+++ b/mlir/test/Dialect/Tosa/invalid.mlir
@@ -609,3 +609,12 @@ func.func @test_transpose_conv2d_invalid_outshape(%arg0: tensor<1x32x32x8xf32>,
   %0 = tosa.transpose_conv2d %arg0, %arg1, %arg2 {out_pad = array<i64: 0, 0, 0, 0>, out_shape = array<i64: 1, 32, 32>, stride = array<i64: 1, 1>} : (tensor<1x32x32x8xf32>, tensor<16x1x1x8xf32>, tensor<16xf32>) -> tensor<1x32x32x16xf32>
   return %0 : tensor<1x32x32x16xf32>
 }
+
+// -----
+
+// CHECK-LABEL: test_mul_invalid_shift
+func.func @test_mul_invalid_shift(%arg0: tensor<13x21x3xf32>, %arg1: tensor<13x1x3xf32>) -> tensor<13x21x3xf32> {
+  // expected-error@+1 {{'tosa.mul' op require shift to be 0 for float type}}
+  %0 = tosa.mul %arg0, %arg1 {shift = 1 : i8} : (tensor<13x21x3xf32>, tensor<13x1x3xf32>) -> tensor<13x21x3xf32>
+  return %0 : tensor<13x21x3xf32>
+}
diff --git a/mlir/test/Dialect/Tosa/ops.mlir b/mlir/test/Dialect/Tosa/ops.mlir
index a1600fd33c54b4..a756588a7cc0db 100644
--- a/mlir/test/Dialect/Tosa/ops.mlir
+++ b/mlir/test/Dialect/Tosa/ops.mlir
@@ -315,7 +315,7 @@ func.func @test_min(%arg0: tensor<13x21x3xf32>, %arg1: tensor<1x21x3xf32>) -> te
 // -----
 // CHECK-LABEL: mul
 func.func @test_mul(%arg0: tensor<13x21x3xf32>, %arg1: tensor<13x1x3xf32>) -> tensor<13x21x3xf32> {
-  %0 = tosa.mul %arg0, %arg1 {shift = 1 : i8} : (tensor<13x21x3xf32>, tensor<13x1x3xf32>) -> tensor<13x21x3xf32>
+  %0 = tosa.mul %arg0, %arg1 {shift = 0 : i8} : (tensor<13x21x3xf32>, tensor<13x1x3xf32>) -> tensor<13x21x3xf32>
   return %0 : tensor<13x21x3xf32>
 }
 

``````````

</details>


https://github.com/llvm/llvm-project/pull/113320


More information about the Mlir-commits mailing list