[Mlir-commits] [mlir] d22965f - Reland "[mlir] Add a postprocessing parameter in Pattern"

Jian Cai llvmlistbot at llvm.org
Tue Aug 15 19:21:44 PDT 2023


Author: Jian Cai
Date: 2023-08-15T19:21:27-07:00
New Revision: d22965f0d623392fdf87e5a09a0276aa70e8dfed

URL: https://github.com/llvm/llvm-project/commit/d22965f0d623392fdf87e5a09a0276aa70e8dfed
DIFF: https://github.com/llvm/llvm-project/commit/d22965f0d623392fdf87e5a09a0276aa70e8dfed.diff

LOG: Reland "[mlir] Add a postprocessing parameter in Pattern"

This fixed a test failure that caused the rollback of the original
commit. Verified with ninja check-mlir.

Added: 
    

Modified: 
    mlir/docs/DeclarativeRewrites.md
    mlir/examples/toy/Ch3/mlir/ToyCombine.td
    mlir/include/mlir/IR/PatternBase.td
    mlir/include/mlir/TableGen/Pattern.h
    mlir/lib/TableGen/Pattern.cpp
    mlir/test/lib/Dialect/Test/TestOps.td
    mlir/tools/mlir-tblgen/RewriterGen.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/docs/DeclarativeRewrites.md b/mlir/docs/DeclarativeRewrites.md
index 6a9016a47cf463..dd996baf3cd957 100644
--- a/mlir/docs/DeclarativeRewrites.md
+++ b/mlir/docs/DeclarativeRewrites.md
@@ -59,12 +59,13 @@ features:
 ## Rule Definition
 
 The core construct for defining a rewrite rule is defined in
-[`OpBase.td`][OpBase] as
+[`PatternBase.td`][PatternBase] as
 
 ```tablegen
 class Pattern<
     dag sourcePattern, list<dag> resultPatterns,
     list<dag> additionalConstraints = [],
+    list<dag> supplementalPatterns = [],
     dag benefitsAdded = (addBenefit 0)>;
 ```
 
@@ -678,6 +679,36 @@ You can
 *   Apply constraints on multiple bound symbols (`$input` and `TwoResultOp`'s
     first result must have the same element type).
 
+### Supplying additional result patterns
+
+Sometimes we need to add additional code after the result patterns, e.g. copying
+the attributes of the source op to the result ops. These can be specified via
+the `supplementalPatterns` parameter. Similar to auxiliary patterns, they are not
+used for replacing results in the source pattern.
+
+For example, we can write
+
+```tablegen
+def GetOwner: NativeCodeCall<"$0.getOwner()">;
+
+def CopyAttrFoo: NativeCodeCallVoid<
+  "$1->setAttr($_builder.getStringAttr(\"foo\"), $0->getAttr(\"foo\"))">;
+
+def CopyAttrBar: NativeCodeCallVoid<
+  "$1->setAttr($_builder.getStringAttr(\"bar\"), $0->getAttr(\"bar\"))">;
+
+
+def : Pattern<
+  (ThreeResultOp:$src ...),
+  [(ZeroResultOp:$dest1 ...), (ThreeResultOp:$dest2 ...)],
+  [(CopyAttrFoo (GetOwner $src), $dest1),
+    (CopyAttrBar (GetOwner $src), (GetOwner $dest2))]>;
+```
+
+This will copy the attributes `foo` and `bar` of `ThreeResultOp` in the source
+pattern to `ZeroResultOp` and `ThreeResultOp` in the result patterns, respectively.
+The patterns are executed in the specified order.
+
 ### Adjusting benefits
 
 The benefit of a `Pattern` is an integer value indicating the benefit of

diff  --git a/mlir/examples/toy/Ch3/mlir/ToyCombine.td b/mlir/examples/toy/Ch3/mlir/ToyCombine.td
index 11d783150ebe1b..8bd2b442d69f2b 100644
--- a/mlir/examples/toy/Ch3/mlir/ToyCombine.td
+++ b/mlir/examples/toy/Ch3/mlir/ToyCombine.td
@@ -22,6 +22,7 @@ include "toy/Ops.td"
 /// class Pattern<
 ///    dag sourcePattern, list<dag> resultPatterns,
 ///    list<dag> additionalConstraints = [],
+///    list<dag> supplementalPatterns = [],
 ///    dag benefitsAdded = (addBenefit 0)
 /// >;
 

diff  --git a/mlir/include/mlir/IR/PatternBase.td b/mlir/include/mlir/IR/PatternBase.td
index c6ab1b5a91b58b..919fb884adb0e9 100644
--- a/mlir/include/mlir/IR/PatternBase.td
+++ b/mlir/include/mlir/IR/PatternBase.td
@@ -90,6 +90,7 @@ def addBenefit;
 // * `FiveResultOp`#3: `TwoResultOp2`#1
 // * `FiveResultOp`#4: `TwoResultOp2`#1
 class Pattern<dag source, list<dag> results, list<dag> preds = [],
+  list<dag> supplemental_results = [],
   dag benefitAdded = (addBenefit 0)> {
   dag sourcePattern = source;
   // Result patterns. Each result pattern is expected to replace one result
@@ -103,6 +104,11 @@ class Pattern<dag source, list<dag> results, list<dag> preds = [],
   // matched in source pattern and places further constraints on them as a
   // whole.
   list<dag> constraints = preds;
+  // Optional patterns that are executed after the result patterns. Similar to
+  // auxiliary patterns, they are not used for replacement. These patterns can
+  // be used to invoke additional code after the result patterns, e.g. copy
+  // the attributes from the source op to the result ops.
+  list<dag> supplementalPatterns = supplemental_results;
   // The delta value added to the default benefit value. The default value is
   // the number of ops in the source pattern. The rule with the highest final
   // benefit value will be applied first if there are multiple rules matches.
@@ -112,8 +118,9 @@ class Pattern<dag source, list<dag> results, list<dag> preds = [],
 
 // Form of a pattern which produces a single result.
 class Pat<dag pattern, dag result, list<dag> preds = [],
+  list<dag> supplemental_results = [],
   dag benefitAdded = (addBenefit 0)> :
-  Pattern<pattern, [result], preds, benefitAdded>;
+  Pattern<pattern, [result], preds, supplemental_results, benefitAdded>;
 
 // Native code call wrapper. This allows invoking an arbitrary C++ expression
 // to create an op operand/attribute or replace an op result.

diff  --git a/mlir/include/mlir/TableGen/Pattern.h b/mlir/include/mlir/TableGen/Pattern.h
index b17932095f9620..4511ba7dd833ef 100644
--- a/mlir/include/mlir/TableGen/Pattern.h
+++ b/mlir/include/mlir/TableGen/Pattern.h
@@ -482,6 +482,14 @@ class Pattern {
   // Returns the constraints.
   std::vector<AppliedConstraint> getConstraints() const;
 
+  // Returns the number of supplemental auxiliary patterns generated by applying
+  // this rewrite rule.
+  int getNumSupplementalPatterns() const;
+
+  // Returns the DAG tree root node of the `index`-th supplemental result
+  // pattern.
+  DagNode getSupplementalPattern(unsigned index) const;
+
   // Returns the benefit score of the pattern.
   int getBenefit() const;
 

diff  --git a/mlir/lib/TableGen/Pattern.cpp b/mlir/lib/TableGen/Pattern.cpp
index e8625b2e6b7102..d9e1d6c7f89152 100644
--- a/mlir/lib/TableGen/Pattern.cpp
+++ b/mlir/lib/TableGen/Pattern.cpp
@@ -675,6 +675,16 @@ std::vector<AppliedConstraint> Pattern::getConstraints() const {
   return ret;
 }
 
+int Pattern::getNumSupplementalPatterns() const {
+  auto *results = def.getValueAsListInit("supplementalPatterns");
+  return results->size();
+}
+
+DagNode Pattern::getSupplementalPattern(unsigned index) const {
+  auto *results = def.getValueAsListInit("supplementalPatterns");
+  return DagNode(cast<llvm::DagInit>(results->getElement(index)));
+}
+
 int Pattern::getBenefit() const {
   // The initial benefit value is a heuristic with number of ops in the source
   // pattern.

diff  --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td
index af86b480db9ff8..11bcb8b0bae84c 100644
--- a/mlir/test/lib/Dialect/Test/TestOps.td
+++ b/mlir/test/lib/Dialect/Test/TestOps.td
@@ -985,7 +985,7 @@ def OpF : TEST_Op<"op_f">, Arguments<(ins I32)>, Results<(outs I32)>;
 def OpG : TEST_Op<"op_g">, Arguments<(ins I32)>, Results<(outs I32)>;
 // Verify that bumping benefit results in selecting different op.
 def : Pat<(OpD $input), (OpE $input)>;
-def : Pat<(OpD $input), (OpF $input), [], (addBenefit 10)>;
+def : Pat<(OpD $input), (OpF $input), [], [], (addBenefit 10)>;
 // Verify that patterns with more source nodes are selected before those with fewer.
 def : Pat<(OpG $input), (OpB $input, ConstantAttr<I32Attr, "20">:$attr)>;
 def : Pat<(OpG (OpG $input)), (OpB $input, ConstantAttr<I32Attr, "34">:$attr)>;
@@ -1803,7 +1803,7 @@ def : Pat<(ILLegalOpB), (LegalOpA Test_LegalizerEnum_Failure)>;
 def : Pat<(ILLegalOpC), (ILLegalOpD)>;
 def : Pat<(ILLegalOpD), (LegalOpA Test_LegalizerEnum_Failure)>;
 
-def : Pat<(ILLegalOpC), (ILLegalOpE), [], (addBenefit 10)>;
+def : Pat<(ILLegalOpC), (ILLegalOpE), [], [], (addBenefit 10)>;
 def : Pat<(ILLegalOpE), (LegalOpA Test_LegalizerEnum_Success)>;
 
 // Check that patterns use the most up-to-date value when being replaced.

diff  --git a/mlir/tools/mlir-tblgen/RewriterGen.cpp b/mlir/tools/mlir-tblgen/RewriterGen.cpp
index 9cd4414bf99d7e..8b5ef5c6e01829 100644
--- a/mlir/tools/mlir-tblgen/RewriterGen.cpp
+++ b/mlir/tools/mlir-tblgen/RewriterGen.cpp
@@ -1105,6 +1105,17 @@ void PatternEmitter::emitRewriteLogic() {
     os << "\nrewriter.replaceOp(op0, tblgen_repl_values);\n";
   }
 
+  // Process supplemental patterns.
+  int numSupplementalPatterns = pattern.getNumSupplementalPatterns();
+  for (int i = 0, offset = -numSupplementalPatterns;
+       i < numSupplementalPatterns; ++i) {
+    DagNode resultTree = pattern.getSupplementalPattern(i);
+    auto val = handleResultPattern(resultTree, offset++, 0);
+    if (resultTree.isNativeCodeCall() &&
+        resultTree.getNumReturnsOfNativeCode() == 0)
+      os << val << ";\n";
+  }
+
   LLVM_DEBUG(llvm::dbgs() << "--- done emitting rewrite logic ---\n");
 }
 


        


More information about the Mlir-commits mailing list