[Mlir-commits] [mlir] 6874726 - [PatternMatching] Add convenience insert method to OwningRewritePatternList. NFC.

Chris Lattner llvmlistbot at llvm.org
Mon Mar 22 11:18:29 PDT 2021


Author: Chris Lattner
Date: 2021-03-22T11:18:21-07:00
New Revision: 6874726610cc2f9eea7fa828c8585bf84969f9c3

URL: https://github.com/llvm/llvm-project/commit/6874726610cc2f9eea7fa828c8585bf84969f9c3
DIFF: https://github.com/llvm/llvm-project/commit/6874726610cc2f9eea7fa828c8585bf84969f9c3.diff

LOG: [PatternMatching] Add convenience insert method to OwningRewritePatternList. NFC.

This allows adding a plain C function pointer as a matchAndRewrite-style pattern,
which is a very common case.  ExpandTanh is converted to use it, to show how it
removes a level of nesting.

We could also allow C++ lambdas here, but that doesn't work as well with type
inference in the common case: the op type would have to be spelled out
explicitly.  Instead of:

  patterns.insert(convertTanhOp);

you would need to specify:

  patterns.insert<math::TanhOp>(convertTanhOp);

which is boilerplate'y.  Capturing state in a pattern function is very uncommon,
so we choose to require clients to define their own structs and use the
non-convenience method when they need to do so.

Differential Revision: https://reviews.llvm.org/D99039

Added: 
    

Modified: 
    mlir/docs/Tutorials/QuickstartRewrites.md
    mlir/include/mlir/IR/PatternMatch.h
    mlir/lib/Dialect/Math/Transforms/ExpandTanh.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/docs/Tutorials/QuickstartRewrites.md b/mlir/docs/Tutorials/QuickstartRewrites.md
index 0e560e8c6f6d..3dea430826ae 100644
--- a/mlir/docs/Tutorials/QuickstartRewrites.md
+++ b/mlir/docs/Tutorials/QuickstartRewrites.md
@@ -189,7 +189,7 @@ struct ConvertTFLeakyRelu : public RewritePattern {
       : RewritePattern("tf.LeakyRelu", 1, context) {}
 
   LogicalResult matchAndRewrite(Operation *op,
-                                     PatternRewriter &rewriter) const override {
+                                PatternRewriter &rewriter) const override {
     rewriter.replaceOpWithNewOp<TFL::LeakyReluOp>(
         op, op->getResult(0).getType(), op->getOperand(0),
         /*alpha=*/op->getAttrOfType<FloatAttr>("alpha"));
@@ -202,6 +202,19 @@ In the C++ rewrite the static benefit of the rewrite pattern is specified at
 construction. While in the pattern generator a simple heuristic is currently
 employed based around the number of ops matched and replaced.
 
+In the case where you have a registered op and want to use a benefit of 1, you
+can even define the pattern as a C function:
+
+```c++
+static LogicalResult
+convertTFLeakyRelu(TFLeakyReluOp op, PatternRewriter &rewriter) {
+  rewriter.replaceOpWithNewOp<TFL::LeakyReluOp>(
+      op, op->getResult(0).getType(), op->getOperand(0),
+      /*alpha=*/op->getAttrOfType<FloatAttr>("alpha"));
+  return success();
+}
+```
+
 The above rule did not capture the matching operands/attributes, but in general
 the `match` function in a multi-step rewrite may populate and return a
 `PatternState` (or class derived from one) to pass information extracted during

diff  --git a/mlir/include/mlir/IR/PatternMatch.h b/mlir/include/mlir/IR/PatternMatch.h
index bc49103786da..aac321dece61 100644
--- a/mlir/include/mlir/IR/PatternMatch.h
+++ b/mlir/include/mlir/IR/PatternMatch.h
@@ -790,6 +790,27 @@ class OwningRewritePatternList {
     return *this;
   }
 
+  // Add a matchAndRewrite style pattern represented as a C function pointer.
+  template <typename OpType>
+  OwningRewritePatternList &
+  insert(LogicalResult (*implFn)(OpType, PatternRewriter &rewriter)) {
+    struct FnPattern final : public OpRewritePattern<OpType> {
+      FnPattern(LogicalResult (*implFn)(OpType, PatternRewriter &rewriter),
+                MLIRContext *context)
+          : OpRewritePattern<OpType>(context), implFn(implFn) {}
+
+      LogicalResult matchAndRewrite(OpType op,
+                                    PatternRewriter &rewriter) const override {
+        return implFn(op, rewriter);
+      }
+
+    private:
+      LogicalResult (*implFn)(OpType, PatternRewriter &rewriter);
+    };
+    insert(std::make_unique<FnPattern>(std::move(implFn), getContext()));
+    return *this;
+  }
+
 private:
   /// Add an instance of the pattern type 'T'. Return a reference to `this` for
   /// chaining insertions.

diff  --git a/mlir/lib/Dialect/Math/Transforms/ExpandTanh.cpp b/mlir/lib/Dialect/Math/Transforms/ExpandTanh.cpp
index d61dc3136477..c795ad55a356 100644
--- a/mlir/lib/Dialect/Math/Transforms/ExpandTanh.cpp
+++ b/mlir/lib/Dialect/Math/Transforms/ExpandTanh.cpp
@@ -15,51 +15,42 @@
 #include "mlir/Dialect/StandardOps/IR/Ops.h"
 #include "mlir/IR/Builders.h"
 #include "mlir/Transforms/DialectConversion.h"
-
 using namespace mlir;
 
-namespace {
 /// Expands tanh op into
 ///   1) 1-exp^{-2x} / 1+exp^{-2x}, if x => 0
 ///   2) exp^{2x}-1 / exp^{2x}+1  , if x < 0
-struct TanhOpConverter : public OpRewritePattern<math::TanhOp> {
-public:
-  using OpRewritePattern::OpRewritePattern;
-
-  LogicalResult matchAndRewrite(math::TanhOp op,
-                                PatternRewriter &rewriter) const final {
-    auto floatType = op.operand().getType();
-    Location loc = op.getLoc();
-    auto floatOne = rewriter.getFloatAttr(floatType, 1.0);
-    auto floatTwo = rewriter.getFloatAttr(floatType, 2.0);
-    Value one = rewriter.create<ConstantOp>(loc, floatOne);
-    Value two = rewriter.create<ConstantOp>(loc, floatTwo);
-    Value doubledX = rewriter.create<MulFOp>(loc, op.operand(), two);
-
-    // Case 1: tanh(x) = 1-exp^{-2x} / 1+exp^{-2x}
-    Value negDoubledX = rewriter.create<NegFOp>(loc, doubledX);
-    Value exp2x = rewriter.create<math::ExpOp>(loc, negDoubledX);
-    Value dividend = rewriter.create<SubFOp>(loc, one, exp2x);
-    Value divisor = rewriter.create<AddFOp>(loc, one, exp2x);
-    Value positiveRes = rewriter.create<DivFOp>(loc, dividend, divisor);
-
-    // Case 2: tanh(x) = exp^{2x}-1 / exp^{2x}+1
-    exp2x = rewriter.create<math::ExpOp>(loc, doubledX);
-    dividend = rewriter.create<SubFOp>(loc, exp2x, one);
-    divisor = rewriter.create<AddFOp>(loc, exp2x, one);
-    Value negativeRes = rewriter.create<DivFOp>(loc, dividend, divisor);
-
-    // tanh(x) = x >= 0 ? positiveRes : negativeRes
-    auto floatZero = rewriter.getFloatAttr(floatType, 0.0);
-    Value zero = rewriter.create<ConstantOp>(loc, floatZero);
-    Value cmpRes =
-        rewriter.create<CmpFOp>(loc, CmpFPredicate::OGE, op.operand(), zero);
-    rewriter.replaceOpWithNewOp<SelectOp>(op, cmpRes, positiveRes, negativeRes);
-    return success();
-  }
-};
-} // namespace
+static LogicalResult convertTanhOp(math::TanhOp op, PatternRewriter &rewriter) {
+  auto floatType = op.operand().getType();
+  Location loc = op.getLoc();
+  auto floatOne = rewriter.getFloatAttr(floatType, 1.0);
+  auto floatTwo = rewriter.getFloatAttr(floatType, 2.0);
+  Value one = rewriter.create<ConstantOp>(loc, floatOne);
+  Value two = rewriter.create<ConstantOp>(loc, floatTwo);
+  Value doubledX = rewriter.create<MulFOp>(loc, op.operand(), two);
+
+  // Case 1: tanh(x) = 1-exp^{-2x} / 1+exp^{-2x}
+  Value negDoubledX = rewriter.create<NegFOp>(loc, doubledX);
+  Value exp2x = rewriter.create<math::ExpOp>(loc, negDoubledX);
+  Value dividend = rewriter.create<SubFOp>(loc, one, exp2x);
+  Value divisor = rewriter.create<AddFOp>(loc, one, exp2x);
+  Value positiveRes = rewriter.create<DivFOp>(loc, dividend, divisor);
+
+  // Case 2: tanh(x) = exp^{2x}-1 / exp^{2x}+1
+  exp2x = rewriter.create<math::ExpOp>(loc, doubledX);
+  dividend = rewriter.create<SubFOp>(loc, exp2x, one);
+  divisor = rewriter.create<AddFOp>(loc, exp2x, one);
+  Value negativeRes = rewriter.create<DivFOp>(loc, dividend, divisor);
+
+  // tanh(x) = x >= 0 ? positiveRes : negativeRes
+  auto floatZero = rewriter.getFloatAttr(floatType, 0.0);
+  Value zero = rewriter.create<ConstantOp>(loc, floatZero);
+  Value cmpRes =
+      rewriter.create<CmpFOp>(loc, CmpFPredicate::OGE, op.operand(), zero);
+  rewriter.replaceOpWithNewOp<SelectOp>(op, cmpRes, positiveRes, negativeRes);
+  return success();
+}
 
 void mlir::populateExpandTanhPattern(OwningRewritePatternList &patterns) {
-  patterns.insert<TanhOpConverter>(patterns.getContext());
+  patterns.insert(convertTanhOp);
 }


        


More information about the Mlir-commits mailing list