[Mlir-commits] [mlir] 07a029c - Canonicalization for add to no-op if one of the inputs is zero

Rob Suderman llvmlistbot at llvm.org
Thu Nov 4 17:01:29 PDT 2021


Author: not-jenni
Date: 2021-11-04T16:52:47-07:00
New Revision: 07a029c0577864827bad472364145844dfaf2d24

URL: https://github.com/llvm/llvm-project/commit/07a029c0577864827bad472364145844dfaf2d24
DIFF: https://github.com/llvm/llvm-project/commit/07a029c0577864827bad472364145844dfaf2d24.diff

LOG: Canonicalization for add to no-op if one of the inputs is zero

Reviewed By: rsuderman

Differential Revision: https://reviews.llvm.org/D113207

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
    mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
    mlir/test/Dialect/Tosa/canonicalize.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
index b57e8b2fb8cb..4dc678e31ddb 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
@@ -442,6 +442,8 @@ def Tosa_AddOp : Tosa_Op<"add", [
   let results = (outs
     Tosa_Tensor:$output
   );
+
+  let hasCanonicalizer = 1;
 }
 
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
index 90146f5bc29b..07d3e6e67c21 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
@@ -289,6 +289,55 @@ void TransposeOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
   results.insert<NoOpOptimization>(context);
 }
 
+struct AddZeroOptimization : public OpRewritePattern<tosa::AddOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(tosa::AddOp op,
+                                PatternRewriter &rewriter) const override {
+    auto input1 = op.input1();
+    auto input2 = op.input2();
+
+    DenseElementsAttr input1Attr;
+    if (matchPattern(input1, m_Constant(&input1Attr)) && input1Attr.isSplat() &&
+        input2.getType() == op.getType()) {
+      if (input1Attr.getType().getElementType().isa<FloatType>() &&
+          input1Attr.getSplatValue<APFloat>().isZero()) {
+        rewriter.replaceOp(op, op.input2());
+        return success();
+      }
+
+      if (input1Attr.getType().getElementType().isa<IntegerType>() &&
+          input1Attr.getSplatValue<APInt>().isZero()) {
+        rewriter.replaceOp(op, op.input2());
+        return success();
+      }
+    }
+
+    DenseElementsAttr input2Attr;
+    if (matchPattern(input2, m_Constant(&input2Attr)) && input2Attr.isSplat() &&
+        input1.getType() == op.getType()) {
+      if (input2Attr.getType().getElementType().isa<FloatType>() &&
+          input2Attr.getSplatValue<APFloat>().isZero()) {
+        rewriter.replaceOp(op, op.input1());
+        return success();
+      }
+
+      if (input2Attr.getType().getElementType().isa<IntegerType>() &&
+          input2Attr.getSplatValue<APInt>().isZero()) {
+        rewriter.replaceOp(op, op.input1());
+        return success();
+      }
+    }
+
+    return failure();
+  }
+};
+
+void AddOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
+                                        MLIRContext *context) {
+  results.insert<AddZeroOptimization>(context);
+}
+
 //===----------------------------------------------------------------------===//
 // Operator Folders.
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/test/Dialect/Tosa/canonicalize.mlir b/mlir/test/Dialect/Tosa/canonicalize.mlir
index e6cf1a15ac67..e4614853c71e 100644
--- a/mlir/test/Dialect/Tosa/canonicalize.mlir
+++ b/mlir/test/Dialect/Tosa/canonicalize.mlir
@@ -9,6 +9,38 @@ func @argmax_nofold(%arg0: tensor<?x1xf32>) -> tensor<?x1xf32> {
 
 // -----
 
+// CHECK-LABEL: @add_zero_different_shape
+func @add_zero_different_shape(%arg0: tensor<2x3xf32>) -> tensor<4x2x3xf32> {
+  // CHECK: tosa.add
+  %zeros = "tosa.const"() {value = dense<0.0> : tensor<4x2x3xf32>} : () -> tensor<4x2x3xf32>
+  %1 = "tosa.add"(%arg0, %zeros) : (tensor<2x3xf32>, tensor<4x2x3xf32>) -> tensor<4x2x3xf32>
+  return %1 : tensor<4x2x3xf32>
+}
+
+// -----
+
+// CHECK-LABEL: @add_zero_float
+func @add_zero_float(%arg0: tensor<2x3xf32>) -> tensor<2x3xf32> {
+  // CHECK: return %arg0
+  // CHECK-NOT: tosa.add
+  %zeros = "tosa.const"() {value = dense<0.0> : tensor<2x3xf32>} : () -> tensor<2x3xf32>
+  %1 = "tosa.add"(%arg0, %zeros) : (tensor<2x3xf32>, tensor<2x3xf32>) -> tensor<2x3xf32>
+  return %1 : tensor<2x3xf32>
+}
+
+// -----
+
+// CHECK-LABEL: @add_zero_int
+func @add_zero_int(%arg0: tensor<2x3xi32>) -> tensor<2x3xi32> {
+  // CHECK: return %arg0
+  // CHECK-NOT: tosa.add
+  %zeros = "tosa.const"() {value = dense<0> : tensor<2x3xi32>} : () -> tensor<2x3xi32>
+  %1 = "tosa.add"(%arg0, %zeros) : (tensor<2x3xi32>, tensor<2x3xi32>) -> tensor<2x3xi32>
+  return %1 : tensor<2x3xi32>
+}
+
+// -----
+
 // CHECK-LABEL: @cast_fold
 func @cast_fold(%arg0: tensor<?x1xf32>) -> tensor<?x1xf32> {
   // CHECK: return %arg0


        


More information about the Mlir-commits mailing list