[Mlir-commits] [mlir] cdb0623 - [mlir][tosa] Add tosa.mul by one canonicalization
Rob Suderman
llvmlistbot at llvm.org
Mon Nov 15 14:53:16 PST 2021
Author: not-jenni
Date: 2021-11-15T14:52:16-08:00
New Revision: cdb0623ad82751c87d9c7d4b2c6eaf0c5ccf3719
URL: https://github.com/llvm/llvm-project/commit/cdb0623ad82751c87d9c7d4b2c6eaf0c5ccf3719
DIFF: https://github.com/llvm/llvm-project/commit/cdb0623ad82751c87d9c7d4b2c6eaf0c5ccf3719.diff
LOG: [mlir][tosa] Add tosa.mul by one canonicalization
A tosa.mul whose operand is a splat constant of one can be replaced by its other operand during canonicalization. This removes the redundant multiplication.
Differential Revision: https://reviews.llvm.org/D113807
Added:
Modified:
mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
mlir/test/Dialect/Tosa/canonicalize.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
index bdc7ac13e675a..0e6c0d2560e88 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
@@ -760,6 +760,8 @@ def Tosa_MulOp : Tosa_Op<"mul", [
let results = (outs
Tosa_Tensor:$output
);
+
+ let hasCanonicalizer = 1;
}
//===----------------------------------------------------------------------===//
@@ -1413,7 +1415,7 @@ def Tosa_PadOp : Tosa_Op<"pad", [
Tosa_RankedTensor:$output
);
- let builders = [Tosa_PadOpQuantInfoBuilder,
+ let builders = [Tosa_PadOpQuantInfoBuilder,
Tosa_ExplicitValuePadOpQuantInfoBuilder];
}
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
index 2a435476e5bef..e32be42a31015 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
@@ -339,6 +339,55 @@ void AddOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
results.insert<AddZeroOptimization>(context);
}
+/// Removes a `tosa.mul` when one operand is a splat constant acting as the
+/// multiplicative identity, replacing the op with the other operand.
+///
+/// The replacement only happens when the surviving operand already has the
+/// op's result type, so multiplies that broadcast (e.g. 2x3 * 4x2x3) are kept.
+/// The left operand is tried as the constant first, then the right operand.
+struct MulOneOptimization : public OpRewritePattern<tosa::MulOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(tosa::MulOp op,
+                                PatternRewriter &rewriter) const override {
+    Value input1 = op.input1();
+    Value input2 = op.input2();
+
+    // one * input2 -> input2
+    if (isMulIdentity(op, input1) && input2.getType() == op.getType()) {
+      rewriter.replaceOp(op, input2);
+      return success();
+    }
+
+    // input1 * one -> input1
+    if (isMulIdentity(op, input2) && input1.getType() == op.getType()) {
+      rewriter.replaceOp(op, input1);
+      return success();
+    }
+
+    return failure();
+  }
+
+private:
+  /// Returns true if `operand` is a splat constant that leaves the other
+  /// operand of `op` unchanged: exactly 1.0 for float element types, or 1 for
+  /// integer element types when `op` applies no shift.
+  static bool isMulIdentity(tosa::MulOp op, Value operand) {
+    DenseElementsAttr attr;
+    if (!matchPattern(operand, m_Constant(&attr)) || !attr.isSplat())
+      return false;
+
+    Type elementTy = attr.getType().getElementType();
+    if (elementTy.isa<FloatType>())
+      return attr.getSplatValue<APFloat>().isExactlyValue(1);
+
+    // For integer inputs tosa.mul right-shifts the product by the `shift`
+    // attribute, so multiplying by 1 is only an identity when shift == 0.
+    if (elementTy.isa<IntegerType>())
+      return op.shift() == 0 && matchPattern(operand, m_One());
+
+    return false;
+  }
+};
+
+/// Registers every tosa.mul canonicalization pattern into `results`.
+void MulOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
+                                        MLIRContext *context) {
+  results.add<MulOneOptimization>(context);
+}
+
//===----------------------------------------------------------------------===//
// Operator Folders.
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Tosa/canonicalize.mlir b/mlir/test/Dialect/Tosa/canonicalize.mlir
index e4614853c71e4..11583eef7c867 100644
--- a/mlir/test/Dialect/Tosa/canonicalize.mlir
+++ b/mlir/test/Dialect/Tosa/canonicalize.mlir
@@ -78,6 +78,38 @@ func @concat_fold_cast(%arg0: tensor<?x1xf32>) -> tensor<?x?xf32> {
// -----
+// CHECK-LABEL: @mul_one_different_shape
+func @mul_one_different_shape(%arg0: tensor<2x3xf32>) -> tensor<4x2x3xf32> {
+  // %arg0's type differs from the result type (the multiply broadcasts), so
+  // it cannot stand in for the op and tosa.mul must survive canonicalization.
+  // CHECK: tosa.mul
+  %ones = "tosa.const"() {value = dense<1.0> : tensor<4x2x3xf32>} : () -> tensor<4x2x3xf32>
+  %1 = "tosa.mul"(%arg0, %ones) {shift = 0 : i32} : (tensor<2x3xf32>, tensor<4x2x3xf32>) -> tensor<4x2x3xf32>
+  return %1 : tensor<4x2x3xf32>
+}
+
+// -----
+
+// CHECK-LABEL: @mul_one_float
+func @mul_one_float(%arg0: tensor<2x3xf32>) -> tensor<2x3xf32> {
+  // The CHECK-NOT must precede the return check: FileCheck only searches for
+  // it between the label match and the next positive match.
+  // CHECK-NOT: tosa.mul
+  // CHECK: return %arg0
+  %ones = "tosa.const"() {value = dense<1.0> : tensor<2x3xf32>} : () -> tensor<2x3xf32>
+  %1 = "tosa.mul"(%arg0, %ones) {shift = 0 : i32} : (tensor<2x3xf32>, tensor<2x3xf32>) -> tensor<2x3xf32>
+  return %1 : tensor<2x3xf32>
+}
+
+// -----
+
+// CHECK-LABEL: @mul_one_int
+func @mul_one_int(%arg0: tensor<2x3xi32>) -> tensor<2x3xi32> {
+  // The CHECK-NOT must precede the return check: FileCheck only searches for
+  // it between the label match and the next positive match.
+  // CHECK-NOT: tosa.mul
+  // CHECK: return %arg0
+  %ones = "tosa.const"() {value = dense<1> : tensor<2x3xi32>} : () -> tensor<2x3xi32>
+  %1 = "tosa.mul"(%arg0, %ones) {shift = 0 : i32} : (tensor<2x3xi32>, tensor<2x3xi32>) -> tensor<2x3xi32>
+  return %1 : tensor<2x3xi32>
+}
+
+// -----
+
// CHECK-LABEL: @reduce_all_fold
func @reduce_all_fold(%arg0: tensor<?x1xf32>) -> tensor<?x1xf32> {
// CHECK: return %arg0
More information about the Mlir-commits mailing list