[Mlir-commits] [mlir] 07a029c - Canonicalization for add to no-op if one of the inputs is zero
Rob Suderman
llvmlistbot at llvm.org
Thu Nov 4 17:01:29 PDT 2021
Author: not-jenni
Date: 2021-11-04T16:52:47-07:00
New Revision: 07a029c0577864827bad472364145844dfaf2d24
URL: https://github.com/llvm/llvm-project/commit/07a029c0577864827bad472364145844dfaf2d24
DIFF: https://github.com/llvm/llvm-project/commit/07a029c0577864827bad472364145844dfaf2d24.diff
LOG: Canonicalization for add to no-op if one of the inputs is zero
Reviewed By: rsuderman
Differential Revision: https://reviews.llvm.org/D113207
Added:
Modified:
mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
mlir/test/Dialect/Tosa/canonicalize.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
index b57e8b2fb8cb..4dc678e31ddb 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
@@ -442,6 +442,8 @@ def Tosa_AddOp : Tosa_Op<"add", [
let results = (outs
Tosa_Tensor:$output
);
+
+ let hasCanonicalizer = 1;
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
index 90146f5bc29b..07d3e6e67c21 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
@@ -289,6 +289,55 @@ void TransposeOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
results.insert<NoOpOptimization>(context);
}
+/// Canonicalizes `tosa.add(0, x)` and `tosa.add(x, 0)` to `x`.
+///
+/// The zero operand must come from a constant whose DenseElementsAttr is a
+/// splat of zero with a float or integer element type. The surviving operand's
+/// type must equal the result type, so the fold never drops an implicit
+/// broadcast (e.g. tensor<2x3xf32> + tensor<4x2x3xf32> is left untouched).
+struct AddZeroOptimization : public OpRewritePattern<tosa::AddOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(tosa::AddOp op,
+                                PatternRewriter &rewriter) const override {
+    Value input1 = op.input1();
+    Value input2 = op.input2();
+
+    // `0 + x` -> `x`, only when `x` already has the result type.
+    if (isSplatZero(input1) && input2.getType() == op.getType()) {
+      rewriter.replaceOp(op, input2);
+      return success();
+    }
+
+    // `x + 0` -> `x`, only when `x` already has the result type.
+    if (isSplatZero(input2) && input1.getType() == op.getType()) {
+      rewriter.replaceOp(op, input1);
+      return success();
+    }
+
+    return failure();
+  }
+
+private:
+  /// Returns true if `value` is defined by a constant whose attribute is a
+  /// splat DenseElementsAttr holding zero, with a float or integer element
+  /// type. Any other element type is conservatively rejected.
+  static bool isSplatZero(Value value) {
+    DenseElementsAttr attr;
+    if (!matchPattern(value, m_Constant(&attr)) || !attr.isSplat())
+      return false;
+
+    Type elementType = attr.getType().getElementType();
+    // NOTE(review): APFloat::isZero() accepts both +0.0 and -0.0. Under
+    // IEEE-754 only -0.0 is an exact additive identity ((-0.0) + (+0.0) is
+    // +0.0), so folding with a +0.0 splat may flip the sign of a zero result.
+    if (elementType.isa<FloatType>())
+      return attr.getSplatValue<APFloat>().isZero();
+    if (elementType.isa<IntegerType>())
+      return attr.getSplatValue<APInt>().isZero();
+    return false;
+  }
+};
+
+/// Registers the tosa.add canonicalization patterns.
+void AddOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
+                                        MLIRContext *context) {
+  results.insert<AddZeroOptimization>(context);
+}
+
//===----------------------------------------------------------------------===//
// Operator Folders.
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Tosa/canonicalize.mlir b/mlir/test/Dialect/Tosa/canonicalize.mlir
index e6cf1a15ac67..e4614853c71e 100644
--- a/mlir/test/Dialect/Tosa/canonicalize.mlir
+++ b/mlir/test/Dialect/Tosa/canonicalize.mlir
@@ -9,6 +9,38 @@ func @argmax_nofold(%arg0: tensor<?x1xf32>) -> tensor<?x1xf32> {
// -----
+// CHECK-LABEL: @add_zero_different_shape
+func @add_zero_different_shape(%arg0: tensor<2x3xf32>) -> tensor<4x2x3xf32> {
+  // CHECK: tosa.add
+  %zeros = "tosa.const"() {value = dense<0.0> : tensor<4x2x3xf32>} : () -> tensor<4x2x3xf32>
+  %1 = "tosa.add"(%arg0, %zeros) : (tensor<2x3xf32>, tensor<4x2x3xf32>) -> tensor<4x2x3xf32>
+  return %1 : tensor<4x2x3xf32>
+}
+
+// -----
+
+// CHECK-LABEL: @add_zero_float
+func @add_zero_float(%arg0: tensor<2x3xf32>) -> tensor<2x3xf32> {
+  // CHECK-NOT: tosa.add
+  // CHECK: return %arg0
+  %zeros = "tosa.const"() {value = dense<0.0> : tensor<2x3xf32>} : () -> tensor<2x3xf32>
+  %1 = "tosa.add"(%arg0, %zeros) : (tensor<2x3xf32>, tensor<2x3xf32>) -> tensor<2x3xf32>
+  return %1 : tensor<2x3xf32>
+}
+
+// -----
+
+// CHECK-LABEL: @add_zero_int
+func @add_zero_int(%arg0: tensor<2x3xi32>) -> tensor<2x3xi32> {
+  // CHECK-NOT: tosa.add
+  // CHECK: return %arg0
+  %zeros = "tosa.const"() {value = dense<0> : tensor<2x3xi32>} : () -> tensor<2x3xi32>
+  %1 = "tosa.add"(%arg0, %zeros) : (tensor<2x3xi32>, tensor<2x3xi32>) -> tensor<2x3xi32>
+  return %1 : tensor<2x3xi32>
+}
+
+// -----
+
// CHECK-LABEL: @cast_fold
func @cast_fold(%arg0: tensor<?x1xf32>) -> tensor<?x1xf32> {
// CHECK: return %arg0
More information about the Mlir-commits mailing list