[Mlir-commits] [mlir] cdb0623 - [mlir][tosa] Add tosa.mul by one canonicalization

Rob Suderman llvmlistbot at llvm.org
Mon Nov 15 14:53:16 PST 2021


Author: not-jenni
Date: 2021-11-15T14:52:16-08:00
New Revision: cdb0623ad82751c87d9c7d4b2c6eaf0c5ccf3719

URL: https://github.com/llvm/llvm-project/commit/cdb0623ad82751c87d9c7d4b2c6eaf0c5ccf3719
DIFF: https://github.com/llvm/llvm-project/commit/cdb0623ad82751c87d9c7d4b2c6eaf0c5ccf3719.diff

LOG: [mlir][tosa] Add tosa.mul by one canonicalization

Multiplication by a splat constant of one can be removed during canonicalization. This optimizes away unneeded operations.

Differential Revision: https://reviews.llvm.org/D113807

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
    mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
    mlir/test/Dialect/Tosa/canonicalize.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
index bdc7ac13e675a..0e6c0d2560e88 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
@@ -760,6 +760,8 @@ def Tosa_MulOp : Tosa_Op<"mul", [
   let results = (outs
     Tosa_Tensor:$output
   );
+
+  let hasCanonicalizer = 1;
 }
 
 //===----------------------------------------------------------------------===//
@@ -1413,7 +1415,7 @@ def Tosa_PadOp : Tosa_Op<"pad", [
     Tosa_RankedTensor:$output
   );
 
-  let builders = [Tosa_PadOpQuantInfoBuilder, 
+  let builders = [Tosa_PadOpQuantInfoBuilder,
                   Tosa_ExplicitValuePadOpQuantInfoBuilder];
 }
 

diff  --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
index 2a435476e5bef..e32be42a31015 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
@@ -339,6 +339,55 @@ void AddOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
   results.insert<AddZeroOptimization>(context);
 }
 
+// Canonicalizes `x * 1` and `1 * x` to `x` when `1` is a splat constant.
+//
+// The rewrite only fires when `x` already has the result type, so an
+// implicit broadcast to a larger result shape is never dropped.
+struct MulOneOptimization : public OpRewritePattern<tosa::MulOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  // Returns true if `value` is a splat constant equal to one. Integer
+  // tosa.mul right-shifts its product by `shift` (per the TOSA spec), so an
+  // integer splat of one is only the multiplicative identity for shift == 0.
+  static bool isSplatOne(Value value, uint32_t shift) {
+    DenseElementsAttr attr;
+    if (!matchPattern(value, m_Constant(&attr)) || !attr.isSplat())
+      return false;
+    Type elementTy = attr.getType().getElementType();
+    if (elementTy.isa<FloatType>())
+      return attr.getSplatValue<APFloat>().isExactlyValue(1);
+    if (elementTy.isa<IntegerType>())
+      return shift == 0 && matchPattern(value, m_One());
+    return false;
+  }
+
+  LogicalResult matchAndRewrite(tosa::MulOp op,
+                                PatternRewriter &rewriter) const override {
+    Value input1 = op.input1();
+    Value input2 = op.input2();
+    uint32_t shift = op.shift();
+
+    // 1 * x -> x
+    if (input2.getType() == op.getType() && isSplatOne(input1, shift)) {
+      rewriter.replaceOp(op, input2);
+      return success();
+    }
+
+    // x * 1 -> x
+    if (input1.getType() == op.getType() && isSplatOne(input2, shift)) {
+      rewriter.replaceOp(op, input1);
+      return success();
+    }
+
+    return failure();
+  }
+};
+
+void MulOp::getCanonicalizationPatterns(
+    OwningRewritePatternList &results, MLIRContext *context) {
+  results.add<MulOneOptimization>(context); // Drop multiplies by one.
+}
+
 //===----------------------------------------------------------------------===//
 // Operator Folders.
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/test/Dialect/Tosa/canonicalize.mlir b/mlir/test/Dialect/Tosa/canonicalize.mlir
index e4614853c71e4..11583eef7c867 100644
--- a/mlir/test/Dialect/Tosa/canonicalize.mlir
+++ b/mlir/test/Dialect/Tosa/canonicalize.mlir
@@ -78,6 +78,38 @@ func @concat_fold_cast(%arg0: tensor<?x1xf32>) -> tensor<?x?xf32> {
 
 // -----
 
+// CHECK-LABEL: @mul_one_different_shape
+func @mul_one_different_shape(%arg0: tensor<2x3xf32>) -> tensor<4x2x3xf32> {
+  // CHECK: tosa.mul
+  %ones = "tosa.const"() {value = dense<1.0> : tensor<4x2x3xf32>} : () -> tensor<4x2x3xf32>
+  %1 = "tosa.mul"(%arg0, %ones) {shift = 0 : i32} : (tensor<2x3xf32>, tensor<4x2x3xf32>) -> tensor<4x2x3xf32>
+  return %1 : tensor<4x2x3xf32>
+}
+
+// -----
+
+// CHECK-LABEL: @mul_one_float
+func @mul_one_float(%arg0: tensor<2x3xf32>) -> tensor<2x3xf32> {
+  // CHECK-NOT: tosa.mul
+  // CHECK: return %arg0
+  %ones = "tosa.const"() {value = dense<1.0> : tensor<2x3xf32>} : () -> tensor<2x3xf32>
+  %1 = "tosa.mul"(%arg0, %ones) {shift = 0 : i32} : (tensor<2x3xf32>, tensor<2x3xf32>) -> tensor<2x3xf32>
+  return %1 : tensor<2x3xf32>
+}
+
+// -----
+
+// CHECK-LABEL: @mul_one_int
+func @mul_one_int(%arg0: tensor<2x3xi32>) -> tensor<2x3xi32> {
+  // CHECK-NOT: tosa.mul
+  // CHECK: return %arg0
+  %ones = "tosa.const"() {value = dense<1> : tensor<2x3xi32>} : () -> tensor<2x3xi32>
+  %1 = "tosa.mul"(%arg0, %ones) {shift = 0 : i32} : (tensor<2x3xi32>, tensor<2x3xi32>) -> tensor<2x3xi32>
+  return %1 : tensor<2x3xi32>
+}
+
+// -----
+
 // CHECK-LABEL: @reduce_all_fold
 func @reduce_all_fold(%arg0: tensor<?x1xf32>) -> tensor<?x1xf32> {
   // CHECK: return %arg0


        


More information about the Mlir-commits mailing list