[Mlir-commits] [mlir] [mlir][tosa] Add a verifier for `tosa.mul` (PR #113320)

Longsheng Mou llvmlistbot at llvm.org
Tue Oct 22 07:19:03 PDT 2024


https://github.com/CoTinker created https://github.com/llvm/llvm-project/pull/113320

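This PR adds a verifier to `tosa.mul` that requires the `shift` attribute to be 0 for floating-point element types, and removes the now-redundant shift check from the TOSA-to-Linalg lowering.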
Fixes #112716.

>From 33c1c10ae79d9f2bf8e6b75bf580981eb5dbb7fd Mon Sep 17 00:00:00 2001
From: Longsheng Mou <longshengmou at gmail.com>
Date: Tue, 22 Oct 2024 22:13:53 +0800
Subject: [PATCH] [mlir][tosa] Add a verifier for `tosa.mul`

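This PR adds a verifier to `tosa.mul` that requires the `shift` attribute
to be 0 for floating-point element types, and removes the now-redundant
shift check from the TOSA-to-Linalg lowering.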
---
 mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td      |  1 +
 mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp | 14 ++++----------
 mlir/lib/Dialect/Tosa/IR/TosaOps.cpp              |  8 ++++++++
 mlir/test/Dialect/Tosa/invalid.mlir               |  9 +++++++++
 mlir/test/Dialect/Tosa/ops.mlir                   |  2 +-
 5 files changed, 23 insertions(+), 11 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
index 3bb5ceb0f4695b..6e7d575ac26df1 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
@@ -811,6 +811,7 @@ def Tosa_MulOp : Tosa_ElementwiseOp<"mul", [
   );
 
   let hasFolder = 1;
+  let hasVerifier = 1;
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
index c88f4db27c304e..495f1b4f10b028 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
@@ -78,16 +78,6 @@ static Value createLinalgBodyCalculationForElementwiseOp(
   if (isa<tosa::SubOp>(op) && isa<IntegerType>(elementTy))
     return rewriter.create<arith::SubIOp>(loc, resultTypes, args);
 
-  // tosa::MulOp
-  if (isa<tosa::MulOp>(op) && isa<FloatType>(elementTy)) {
-    if (dyn_cast<tosa::MulOp>(op).getShift() != 0) {
-      (void)rewriter.notifyMatchFailure(op,
-                                        "Cannot have shift value for float");
-      return nullptr;
-    }
-    return rewriter.create<arith::MulFOp>(loc, resultTypes, args);
-  }
-
   // tosa::IntDivOp
   if (isa<tosa::IntDivOp>(op) && isa<IntegerType>(elementTy))
     return rewriter.create<arith::DivSIOp>(loc, resultTypes, args);
@@ -99,6 +89,10 @@ static Value createLinalgBodyCalculationForElementwiseOp(
     return rewriter.create<arith::DivFOp>(loc, resultTypes, one, args[0]);
   }
 
+  // tosa::MulOp
+  if (isa<tosa::MulOp>(op) && isa<FloatType>(elementTy))
+    return rewriter.create<arith::MulFOp>(loc, resultTypes, args);
+
   if (isa<tosa::MulOp>(op) && isa<IntegerType>(elementTy)) {
     Value a = args[0];
     Value b = args[1];
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
index 1f3e19fe90c6db..631d3c48f2df02 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
@@ -865,6 +865,14 @@ LogicalResult tosa::SliceOp::verify() {
   return success();
 }
 
+LogicalResult tosa::MulOp::verify() {
+  // Float element types (checked on input1) only accept shift == 0.
+  auto inputElemTy = getInput1().getType().getElementType();
+  if (!isa<FloatType>(inputElemTy) || getShift() == 0)
+    return success();
+  return emitOpError() << "require shift to be 0 for float type";
+}
+
 LogicalResult tosa::TableOp::inferReturnTypeComponents(
     MLIRContext *context, ::std::optional<Location> location,
     TableOp::Adaptor adaptor,
diff --git a/mlir/test/Dialect/Tosa/invalid.mlir b/mlir/test/Dialect/Tosa/invalid.mlir
index b9298b66643538..f1b1707a0c40d9 100644
--- a/mlir/test/Dialect/Tosa/invalid.mlir
+++ b/mlir/test/Dialect/Tosa/invalid.mlir
@@ -609,3 +609,12 @@ func.func @test_transpose_conv2d_invalid_outshape(%arg0: tensor<1x32x32x8xf32>,
   %0 = tosa.transpose_conv2d %arg0, %arg1, %arg2 {out_pad = array<i64: 0, 0, 0, 0>, out_shape = array<i64: 1, 32, 32>, stride = array<i64: 1, 1>} : (tensor<1x32x32x8xf32>, tensor<16x1x1x8xf32>, tensor<16xf32>) -> tensor<1x32x32x16xf32>
   return %0 : tensor<1x32x32x16xf32>
 }
+
+// -----
+
+// CHECK-LABEL: test_mul_invalid_shift
+func.func @test_mul_invalid_shift(%arg0: tensor<13x21x3xf32>, %arg1: tensor<13x1x3xf32>) -> tensor<13x21x3xf32> {
+  // expected-error@+1 {{'tosa.mul' op require shift to be 0 for float type}}
+  %0 = tosa.mul %arg0, %arg1 {shift = 1 : i8} : (tensor<13x21x3xf32>, tensor<13x1x3xf32>) -> tensor<13x21x3xf32>  // non-zero shift on f32 operands must be rejected
+  return %0 : tensor<13x21x3xf32>
+}
diff --git a/mlir/test/Dialect/Tosa/ops.mlir b/mlir/test/Dialect/Tosa/ops.mlir
index a1600fd33c54b4..a756588a7cc0db 100644
--- a/mlir/test/Dialect/Tosa/ops.mlir
+++ b/mlir/test/Dialect/Tosa/ops.mlir
@@ -315,7 +315,7 @@ func.func @test_min(%arg0: tensor<13x21x3xf32>, %arg1: tensor<1x21x3xf32>) -> te
 // -----
 // CHECK-LABEL: mul
 func.func @test_mul(%arg0: tensor<13x21x3xf32>, %arg1: tensor<13x1x3xf32>) -> tensor<13x21x3xf32> {
-  %0 = tosa.mul %arg0, %arg1 {shift = 1 : i8} : (tensor<13x21x3xf32>, tensor<13x1x3xf32>) -> tensor<13x21x3xf32>
+  %0 = tosa.mul %arg0, %arg1 {shift = 0 : i8} : (tensor<13x21x3xf32>, tensor<13x1x3xf32>) -> tensor<13x21x3xf32>  // shift must be 0 for float operands (enforced by MulOp::verify)
   return %0 : tensor<13x21x3xf32>
 }
 



More information about the Mlir-commits mailing list