[Mlir-commits] [mlir] 6874726 - [PatternMatching] Add convenience insert method to OwningRewritePatternList. NFC.
Chris Lattner
llvmlistbot at llvm.org
Mon Mar 22 11:18:29 PDT 2021
Author: Chris Lattner
Date: 2021-03-22T11:18:21-07:00
New Revision: 6874726610cc2f9eea7fa828c8585bf84969f9c3
URL: https://github.com/llvm/llvm-project/commit/6874726610cc2f9eea7fa828c8585bf84969f9c3
DIFF: https://github.com/llvm/llvm-project/commit/6874726610cc2f9eea7fa828c8585bf84969f9c3.diff
LOG: [PatternMatching] Add convenience insert method to OwningRewritePatternList. NFC.
This allows adding a C function pointer as a matchAndRewrite-style pattern, which
is a very common case. This patch also adopts it in ExpandTanh to show how it removes
a level of nesting.
We could allow C++ lambdas here, but that doesn't work as well with type inference
in the common case: the op type cannot be deduced from a lambda, so instead of:
patterns.insert(convertTanhOp);
you would need to specify:
patterns.insert<math::TanhOp>(convertTanhOp);
which is boilerplate-heavy. Needing to capture state (the main reason to reach for a
lambda) is very uncommon, so we choose to require clients to define their own structs
and use the non-convenience method when they need to do so.
Differential Revision: https://reviews.llvm.org/D99039
Added:
Modified:
mlir/docs/Tutorials/QuickstartRewrites.md
mlir/include/mlir/IR/PatternMatch.h
mlir/lib/Dialect/Math/Transforms/ExpandTanh.cpp
Removed:
################################################################################
diff --git a/mlir/docs/Tutorials/QuickstartRewrites.md b/mlir/docs/Tutorials/QuickstartRewrites.md
index 0e560e8c6f6d..3dea430826ae 100644
--- a/mlir/docs/Tutorials/QuickstartRewrites.md
+++ b/mlir/docs/Tutorials/QuickstartRewrites.md
@@ -189,7 +189,7 @@ struct ConvertTFLeakyRelu : public RewritePattern {
: RewritePattern("tf.LeakyRelu", 1, context) {}
LogicalResult matchAndRewrite(Operation *op,
- PatternRewriter &rewriter) const override {
+ PatternRewriter &rewriter) const override {
rewriter.replaceOpWithNewOp<TFL::LeakyReluOp>(
op, op->getResult(0).getType(), op->getOperand(0),
/*alpha=*/op->getAttrOfType<FloatAttr>("alpha"));
@@ -202,6 +202,19 @@ In the C++ rewrite the static benefit of the rewrite pattern is specified at
construction. While in the pattern generator a simple heuristic is currently
employed based around the number of ops matched and replaced.
+In the case where you have a registered op and want to use a benefit of 1, you
+can even define the pattern as a C function:
+
+```c++
+static LogicalResult
+convertTFLeakyRelu(TFLeakyReluOp op, PatternRewriter &rewriter) {
+ rewriter.replaceOpWithNewOp<TFL::LeakyReluOp>(
+ op, op->getResult(0).getType(), op->getOperand(0),
+ /*alpha=*/op->getAttrOfType<FloatAttr>("alpha"));
+ return success();
+}
+```
+
The above rule did not capture the matching operands/attributes, but in general
the `match` function in a multi-step rewrite may populate and return a
`PatternState` (or class derived from one) to pass information extracted during
diff --git a/mlir/include/mlir/IR/PatternMatch.h b/mlir/include/mlir/IR/PatternMatch.h
index bc49103786da..aac321dece61 100644
--- a/mlir/include/mlir/IR/PatternMatch.h
+++ b/mlir/include/mlir/IR/PatternMatch.h
@@ -790,6 +790,27 @@ class OwningRewritePatternList {
return *this;
}
+ // Add a matchAndRewrite style pattern represented as a C function pointer.
+ template <typename OpType>
+ OwningRewritePatternList &
+ insert(LogicalResult (*implFn)(OpType, PatternRewriter &rewriter)) {
+ struct FnPattern final : public OpRewritePattern<OpType> {
+ FnPattern(LogicalResult (*implFn)(OpType, PatternRewriter &rewriter),
+ MLIRContext *context)
+ : OpRewritePattern<OpType>(context), implFn(implFn) {}
+
+ LogicalResult matchAndRewrite(OpType op,
+ PatternRewriter &rewriter) const override {
+ return implFn(op, rewriter);
+ }
+
+ private:
+ LogicalResult (*implFn)(OpType, PatternRewriter &rewriter);
+ };
+ insert(std::make_unique<FnPattern>(std::move(implFn), getContext()));
+ return *this;
+ }
+
private:
/// Add an instance of the pattern type 'T'. Return a reference to `this` for
/// chaining insertions.
diff --git a/mlir/lib/Dialect/Math/Transforms/ExpandTanh.cpp b/mlir/lib/Dialect/Math/Transforms/ExpandTanh.cpp
index d61dc3136477..c795ad55a356 100644
--- a/mlir/lib/Dialect/Math/Transforms/ExpandTanh.cpp
+++ b/mlir/lib/Dialect/Math/Transforms/ExpandTanh.cpp
@@ -15,51 +15,42 @@
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/IR/Builders.h"
#include "mlir/Transforms/DialectConversion.h"
-
using namespace mlir;
-namespace {
/// Expands tanh op into
/// 1) 1-exp^{-2x} / 1+exp^{-2x}, if x => 0
/// 2) exp^{2x}-1 / exp^{2x}+1 , if x < 0
-struct TanhOpConverter : public OpRewritePattern<math::TanhOp> {
-public:
- using OpRewritePattern::OpRewritePattern;
-
- LogicalResult matchAndRewrite(math::TanhOp op,
- PatternRewriter &rewriter) const final {
- auto floatType = op.operand().getType();
- Location loc = op.getLoc();
- auto floatOne = rewriter.getFloatAttr(floatType, 1.0);
- auto floatTwo = rewriter.getFloatAttr(floatType, 2.0);
- Value one = rewriter.create<ConstantOp>(loc, floatOne);
- Value two = rewriter.create<ConstantOp>(loc, floatTwo);
- Value doubledX = rewriter.create<MulFOp>(loc, op.operand(), two);
-
- // Case 1: tanh(x) = 1-exp^{-2x} / 1+exp^{-2x}
- Value negDoubledX = rewriter.create<NegFOp>(loc, doubledX);
- Value exp2x = rewriter.create<math::ExpOp>(loc, negDoubledX);
- Value dividend = rewriter.create<SubFOp>(loc, one, exp2x);
- Value divisor = rewriter.create<AddFOp>(loc, one, exp2x);
- Value positiveRes = rewriter.create<DivFOp>(loc, dividend, divisor);
-
- // Case 2: tanh(x) = exp^{2x}-1 / exp^{2x}+1
- exp2x = rewriter.create<math::ExpOp>(loc, doubledX);
- dividend = rewriter.create<SubFOp>(loc, exp2x, one);
- divisor = rewriter.create<AddFOp>(loc, exp2x, one);
- Value negativeRes = rewriter.create<DivFOp>(loc, dividend, divisor);
-
- // tanh(x) = x >= 0 ? positiveRes : negativeRes
- auto floatZero = rewriter.getFloatAttr(floatType, 0.0);
- Value zero = rewriter.create<ConstantOp>(loc, floatZero);
- Value cmpRes =
- rewriter.create<CmpFOp>(loc, CmpFPredicate::OGE, op.operand(), zero);
- rewriter.replaceOpWithNewOp<SelectOp>(op, cmpRes, positiveRes, negativeRes);
- return success();
- }
-};
-} // namespace
+static LogicalResult convertTanhOp(math::TanhOp op, PatternRewriter &rewriter) {
+ auto floatType = op.operand().getType();
+ Location loc = op.getLoc();
+ auto floatOne = rewriter.getFloatAttr(floatType, 1.0);
+ auto floatTwo = rewriter.getFloatAttr(floatType, 2.0);
+ Value one = rewriter.create<ConstantOp>(loc, floatOne);
+ Value two = rewriter.create<ConstantOp>(loc, floatTwo);
+ Value doubledX = rewriter.create<MulFOp>(loc, op.operand(), two);
+
+ // Case 1: tanh(x) = 1-exp^{-2x} / 1+exp^{-2x}
+ Value negDoubledX = rewriter.create<NegFOp>(loc, doubledX);
+ Value exp2x = rewriter.create<math::ExpOp>(loc, negDoubledX);
+ Value dividend = rewriter.create<SubFOp>(loc, one, exp2x);
+ Value divisor = rewriter.create<AddFOp>(loc, one, exp2x);
+ Value positiveRes = rewriter.create<DivFOp>(loc, dividend, divisor);
+
+ // Case 2: tanh(x) = exp^{2x}-1 / exp^{2x}+1
+ exp2x = rewriter.create<math::ExpOp>(loc, doubledX);
+ dividend = rewriter.create<SubFOp>(loc, exp2x, one);
+ divisor = rewriter.create<AddFOp>(loc, exp2x, one);
+ Value negativeRes = rewriter.create<DivFOp>(loc, dividend, divisor);
+
+ // tanh(x) = x >= 0 ? positiveRes : negativeRes
+ auto floatZero = rewriter.getFloatAttr(floatType, 0.0);
+ Value zero = rewriter.create<ConstantOp>(loc, floatZero);
+ Value cmpRes =
+ rewriter.create<CmpFOp>(loc, CmpFPredicate::OGE, op.operand(), zero);
+ rewriter.replaceOpWithNewOp<SelectOp>(op, cmpRes, positiveRes, negativeRes);
+ return success();
+}
void mlir::populateExpandTanhPattern(OwningRewritePatternList &patterns) {
- patterns.insert<TanhOpConverter>(patterns.getContext());
+ patterns.insert(convertTanhOp);
}
More information about the Mlir-commits
mailing list