[Mlir-commits] [mlir] 0259669 - [mlir] Add a postprocessing parameter in Pattern
Jian Cai
llvmlistbot at llvm.org
Tue Aug 15 14:10:04 PDT 2023
Author: Jian Cai
Date: 2023-08-15T14:08:31-07:00
New Revision: 02596693fac55f550e85620f5184547c80c8f930
URL: https://github.com/llvm/llvm-project/commit/02596693fac55f550e85620f5184547c80c8f930
DIFF: https://github.com/llvm/llvm-project/commit/02596693fac55f550e85620f5184547c80c8f930.diff
LOG: [mlir] Add a postprocessing parameter in Pattern
This adds a parameter, supplementalPatterns, to the TableGen class Pattern for
postprocessing code. For example, this can be used to ensure ops are
placed on the correct device by copying the attributes that decide
device placement in the TensorFlow dialect, which prevents performance
regressions.
Reviewed By: jpienaar
Differential Revision: https://reviews.llvm.org/D157032
Added:
Modified:
mlir/docs/DeclarativeRewrites.md
mlir/include/mlir/IR/PatternBase.td
mlir/include/mlir/TableGen/Pattern.h
mlir/lib/TableGen/Pattern.cpp
mlir/tools/mlir-tblgen/RewriterGen.cpp
Removed:
################################################################################
diff --git a/mlir/docs/DeclarativeRewrites.md b/mlir/docs/DeclarativeRewrites.md
index 6a9016a47cf463..d068b96142a40e 100644
--- a/mlir/docs/DeclarativeRewrites.md
+++ b/mlir/docs/DeclarativeRewrites.md
@@ -59,12 +59,13 @@ features:
## Rule Definition
The core construct for defining a rewrite rule is defined in
-[`OpBase.td`][OpBase] as
+[`PatternBase.td`][PatternBase] as
```tablegen
class Pattern<
dag sourcePattern, list<dag> resultPatterns,
list<dag> additionalConstraints = [],
+ list<dag> supplementalPatterns = [],
dag benefitsAdded = (addBenefit 0)>;
```
@@ -678,6 +679,36 @@ You can
* Apply constraints on multiple bound symbols (`$input` and `TwoResultOp`'s
first result must have the same element type).
+### Supplying additional result patterns
+
+Sometimes we need to add additional code after the result patterns, e.g. copying
+the attributes of the source op to the result ops. These can be specified via
+the `supplementalPatterns` parameter, which comes after the additional constraints.
+Similar to auxiliary patterns, they are not used to replace results in the source pattern.
+
+For example, we can write
+
+```tablegen
+def GetOwner: NativeCodeCall<"$0.getOwner()">;
+
+def CopyAttrFoo: NativeCodeCallVoid<
+ "$1->setAttr($_builder.getStringAttr(\"foo\"), $0->getAttr(\"foo\"))">;
+
+def CopyAttrBar: NativeCodeCallVoid<
+ "$1->setAttr($_builder.getStringAttr(\"bar\"), $0->getAttr(\"bar\"))">;
+
+
+def : Pattern<
+ (ThreeResultOp:$src ...),
+ [(ZeroResultOp:$dest1 ...), (ThreeResultOp:$dest2 ...)], [],
+ [(CopyAttrFoo (GetOwner $src), $dest1),
+ (CopyAttrBar (GetOwner $src), (GetOwner $dest2))]>;
+```
+
+This will copy the attributes `foo` and `bar` of `ThreeResultOp` in the source
+pattern to `ZeroResultOp` and `ThreeResultOp` in the result patterns, respectively.
+The patterns are executed in the order they are specified.
+
### Adjusting benefits
The benefit of a `Pattern` is an integer value indicating the benefit of
diff --git a/mlir/include/mlir/IR/PatternBase.td b/mlir/include/mlir/IR/PatternBase.td
index c6ab1b5a91b58b..919fb884adb0e9 100644
--- a/mlir/include/mlir/IR/PatternBase.td
+++ b/mlir/include/mlir/IR/PatternBase.td
@@ -90,6 +90,7 @@ def addBenefit;
// * `FiveResultOp`#3: `TwoResultOp2`#1
// * `FiveResultOp`#4: `TwoResultOp2`#1
class Pattern<dag source, list<dag> results, list<dag> preds = [],
+ list<dag> supplemental_results = [],
dag benefitAdded = (addBenefit 0)> {
dag sourcePattern = source;
// Result patterns. Each result pattern is expected to replace one result
@@ -103,6 +104,11 @@ class Pattern<dag source, list<dag> results, list<dag> preds = [],
// matched in source pattern and places further constraints on them as a
// whole.
list<dag> constraints = preds;
+ // Optional patterns that are executed after the result patterns. Similar to
+ // auxiliary patterns, they are not used for replacement. These patterns can
+ // be used to invoke additional code after the result patterns, e.g. copy
+ // the attributes from the source op to the result ops.
+ list<dag> supplementalPatterns = supplemental_results;
// The delta value added to the default benefit value. The default value is
// the number of ops in the source pattern. The rule with the highest final
// benefit value will be applied first if there are multiple rules matches.
@@ -112,8 +118,9 @@ class Pattern<dag source, list<dag> results, list<dag> preds = [],
// Form of a pattern which produces a single result.
class Pat<dag pattern, dag result, list<dag> preds = [],
+ list<dag> supplemental_results = [],
dag benefitAdded = (addBenefit 0)> :
- Pattern<pattern, [result], preds, benefitAdded>;
+ Pattern<pattern, [result], preds, supplemental_results, benefitAdded>;
// Native code call wrapper. This allows invoking an arbitrary C++ expression
// to create an op operand/attribute or replace an op result.
diff --git a/mlir/include/mlir/TableGen/Pattern.h b/mlir/include/mlir/TableGen/Pattern.h
index b17932095f9620..4511ba7dd833ef 100644
--- a/mlir/include/mlir/TableGen/Pattern.h
+++ b/mlir/include/mlir/TableGen/Pattern.h
@@ -482,6 +482,14 @@ class Pattern {
// Returns the constraints.
std::vector<AppliedConstraint> getConstraints() const;
+ // Returns the number of supplemental patterns. They are emitted after the
+ // result patterns and are not used to replace the source pattern's results.
+ int getNumSupplementalPatterns() const;
+
+ // Returns the DAG tree root node of the `index`-th supplemental pattern.
+ // `index` is expected to be in [0, getNumSupplementalPatterns()).
+ DagNode getSupplementalPattern(unsigned index) const;
+
// Returns the benefit score of the pattern.
int getBenefit() const;
diff --git a/mlir/lib/TableGen/Pattern.cpp b/mlir/lib/TableGen/Pattern.cpp
index e8625b2e6b7102..d9e1d6c7f89152 100644
--- a/mlir/lib/TableGen/Pattern.cpp
+++ b/mlir/lib/TableGen/Pattern.cpp
@@ -675,6 +675,16 @@ std::vector<AppliedConstraint> Pattern::getConstraints() const {
return ret;
}
+int Pattern::getNumSupplementalPatterns() const {
+ const auto *supplemental = def.getValueAsListInit("supplementalPatterns");
+ return static_cast<int>(supplemental->size());
+}
+
+DagNode Pattern::getSupplementalPattern(unsigned index) const {
+ const auto *supplemental = def.getValueAsListInit("supplementalPatterns");
+ return DagNode(llvm::cast<llvm::DagInit>(supplemental->getElement(index)));
+}
+
int Pattern::getBenefit() const {
// The initial benefit value is a heuristic with number of ops in the source
// pattern.
diff --git a/mlir/tools/mlir-tblgen/RewriterGen.cpp b/mlir/tools/mlir-tblgen/RewriterGen.cpp
index 9cd4414bf99d7e..8b5ef5c6e01829 100644
--- a/mlir/tools/mlir-tblgen/RewriterGen.cpp
+++ b/mlir/tools/mlir-tblgen/RewriterGen.cpp
@@ -1105,6 +1105,17 @@ void PatternEmitter::emitRewriteLogic() {
os << "\nrewriter.replaceOp(op0, tblgen_repl_values);\n";
}
+ // Process supplemtal patterns.
+ int numSupplementalPatterns = pattern.getNumSupplementalPatterns();
+ for (int i = 0, offset = -numSupplementalPatterns;
+ i < numSupplementalPatterns; ++i) {
+ DagNode resultTree = pattern.getSupplementalPattern(i);
+ auto val = handleResultPattern(resultTree, offset++, 0);
+ if (resultTree.isNativeCodeCall() &&
+ resultTree.getNumReturnsOfNativeCode() == 0)
+ os << val << ";\n";
+ }
+
LLVM_DEBUG(llvm::dbgs() << "--- done emitting rewrite logic ---\n");
}
More information about the Mlir-commits
mailing list