[Mlir-commits] [mlir] 519eef3 - [mlir][tosa] Add a verifier for `tosa.mul` (#113320)
llvmlistbot at llvm.org
Tue Oct 22 14:34:08 PDT 2024
Author: Longsheng Mou
Date: 2024-10-22T22:34:04+01:00
New Revision: 519eef3bdc3c17ac9b59933187e1d7bdc6c2729d
URL: https://github.com/llvm/llvm-project/commit/519eef3bdc3c17ac9b59933187e1d7bdc6c2729d
DIFF: https://github.com/llvm/llvm-project/commit/519eef3bdc3c17ac9b59933187e1d7bdc6c2729d.diff
LOG: [mlir][tosa] Add a verifier for `tosa.mul` (#113320)
This PR adds a verifier check for tosa.mul, requiring that the shift be
0 for float types.
Fixes #112716.
Added:
Modified:
mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
mlir/test/Dialect/Tosa/invalid.mlir
mlir/test/Dialect/Tosa/ops.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
index 3bb5ceb0f4695b..6e7d575ac26df1 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
@@ -811,6 +811,7 @@ def Tosa_MulOp : Tosa_ElementwiseOp<"mul", [
);
let hasFolder = 1;
+ let hasVerifier = 1;
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
index c88f4db27c304e..495f1b4f10b028 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
@@ -78,16 +78,6 @@ static Value createLinalgBodyCalculationForElementwiseOp(
if (isa<tosa::SubOp>(op) && isa<IntegerType>(elementTy))
return rewriter.create<arith::SubIOp>(loc, resultTypes, args);
- // tosa::MulOp
- if (isa<tosa::MulOp>(op) && isa<FloatType>(elementTy)) {
- if (dyn_cast<tosa::MulOp>(op).getShift() != 0) {
- (void)rewriter.notifyMatchFailure(op,
- "Cannot have shift value for float");
- return nullptr;
- }
- return rewriter.create<arith::MulFOp>(loc, resultTypes, args);
- }
-
// tosa::IntDivOp
if (isa<tosa::IntDivOp>(op) && isa<IntegerType>(elementTy))
return rewriter.create<arith::DivSIOp>(loc, resultTypes, args);
@@ -99,6 +89,10 @@ static Value createLinalgBodyCalculationForElementwiseOp(
return rewriter.create<arith::DivFOp>(loc, resultTypes, one, args[0]);
}
+ // tosa::MulOp
+ if (isa<tosa::MulOp>(op) && isa<FloatType>(elementTy))
+ return rewriter.create<arith::MulFOp>(loc, resultTypes, args);
+
if (isa<tosa::MulOp>(op) && isa<IntegerType>(elementTy)) {
Value a = args[0];
Value b = args[1];
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
index 1f3e19fe90c6db..631d3c48f2df02 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
@@ -865,6 +865,14 @@ LogicalResult tosa::SliceOp::verify() {
return success();
}
+// Verifies tosa.mul: when input1 has a floating-point element type, the
+// shift attribute must be 0; otherwise an op error is emitted.
+LogicalResult tosa::MulOp::verify() {
+ // NOTE(review): only input1's element type is inspected; presumably input2
+ // and the result share it via the op's type constraints -- confirm.
+ Type elementTy = getInput1().getType().getElementType();
+ if (isa<FloatType>(elementTy) && getShift() != 0)
+ return emitOpError() << "require shift to be 0 for float type";
+
+ return success();
+}
+
LogicalResult tosa::TableOp::inferReturnTypeComponents(
MLIRContext *context, ::std::optional<Location> location,
TableOp::Adaptor adaptor,
diff --git a/mlir/test/Dialect/Tosa/invalid.mlir b/mlir/test/Dialect/Tosa/invalid.mlir
index b9298b66643538..f1b1707a0c40d9 100644
--- a/mlir/test/Dialect/Tosa/invalid.mlir
+++ b/mlir/test/Dialect/Tosa/invalid.mlir
@@ -609,3 +609,12 @@ func.func @test_transpose_conv2d_invalid_outshape(%arg0: tensor<1x32x32x8xf32>,
%0 = tosa.transpose_conv2d %arg0, %arg1, %arg2 {out_pad = array<i64: 0, 0, 0, 0>, out_shape = array<i64: 1, 32, 32>, stride = array<i64: 1, 1>} : (tensor<1x32x32x8xf32>, tensor<16x1x1x8xf32>, tensor<16xf32>) -> tensor<1x32x32x16xf32>
return %0 : tensor<1x32x32x16xf32>
}
+
+// -----
+
+// CHECK-LABEL: test_mul_invalid_shift
+// The tosa.mul verifier must reject a non-zero shift when the operands are
+// floating point.
+func.func @test_mul_invalid_shift(%arg0: tensor<13x21x3xf32>, %arg1: tensor<13x1x3xf32>) -> tensor<13x21x3xf32> {
+ // `@+1` anchors the expected diagnostic to the tosa.mul on the next line.
+ // expected-error@+1 {{'tosa.mul' op require shift to be 0 for float type}}
+ %0 = tosa.mul %arg0, %arg1 {shift = 1 : i8} : (tensor<13x21x3xf32>, tensor<13x1x3xf32>) -> tensor<13x21x3xf32>
+ return %0 : tensor<13x21x3xf32>
+}
diff --git a/mlir/test/Dialect/Tosa/ops.mlir b/mlir/test/Dialect/Tosa/ops.mlir
index a1600fd33c54b4..a756588a7cc0db 100644
--- a/mlir/test/Dialect/Tosa/ops.mlir
+++ b/mlir/test/Dialect/Tosa/ops.mlir
@@ -315,7 +315,7 @@ func.func @test_min(%arg0: tensor<13x21x3xf32>, %arg1: tensor<1x21x3xf32>) -> te
// -----
// CHECK-LABEL: mul
func.func @test_mul(%arg0: tensor<13x21x3xf32>, %arg1: tensor<13x1x3xf32>) -> tensor<13x21x3xf32> {
+ // shift must be 0 here: the tosa.mul verifier rejects a non-zero shift for float operands.
- %0 = tosa.mul %arg0, %arg1 {shift = 1 : i8} : (tensor<13x21x3xf32>, tensor<13x1x3xf32>) -> tensor<13x21x3xf32>
+ %0 = tosa.mul %arg0, %arg1 {shift = 0 : i8} : (tensor<13x21x3xf32>, tensor<13x1x3xf32>) -> tensor<13x21x3xf32>
return %0 : tensor<13x21x3xf32>
}
More information about the Mlir-commits mailing list