[Mlir-commits] [mlir] 34b5482 - Support NativeCodeCall binding in rewrite pattern.

Chia-hung Duan llvmlistbot at llvm.org
Sun May 9 18:29:43 PDT 2021


Author: Chia-hung Duan
Date: 2021-05-10T09:29:27+08:00
New Revision: 34b5482b334f2a3960ef079667adb5b3df20aa7d

URL: https://github.com/llvm/llvm-project/commit/34b5482b334f2a3960ef079667adb5b3df20aa7d
DIFF: https://github.com/llvm/llvm-project/commit/34b5482b334f2a3960ef079667adb5b3df20aa7d.diff

LOG: Support NativeCodeCall binding in rewrite pattern.

We are now able to bind the result of a native function while rewriting a
pattern. In the matching pattern, values can be returned from the native
function by passing parameters that act as return-value placeholders. This
also adds the semantics of '$_self' in NativeCodeCall while matching: it will
be the operation that defines the corresponding operand.

Differential Revision: https://reviews.llvm.org/D100746

Added: 
    

Modified: 
    mlir/docs/DeclarativeRewrites.md
    mlir/include/mlir/IR/OpBase.td
    mlir/lib/TableGen/Pattern.cpp
    mlir/test/lib/Dialect/Test/TestOps.td
    mlir/test/lib/Dialect/Test/TestPatterns.cpp
    mlir/test/mlir-tblgen/pattern.mlir
    mlir/test/mlir-tblgen/rewriter-errors.td
    mlir/tools/mlir-tblgen/RewriterGen.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/docs/DeclarativeRewrites.md b/mlir/docs/DeclarativeRewrites.md
index e0aabe3f98f6e..b5ae3e83aaad7 100644
--- a/mlir/docs/DeclarativeRewrites.md
+++ b/mlir/docs/DeclarativeRewrites.md
@@ -392,26 +392,31 @@ placeholder_.
 *   `$_builder` will be replaced by the current `mlir::PatternRewriter`.
 *   `$_loc` will be replaced by the fused location or custom location (as
     determined by location directive).
-*   `$_self` will be replaced with the entity `NativeCodeCall` is attached to.
+*   `$_self` will be replaced by the defining operation in a source pattern.
 
 We have seen how `$_builder` can be used in the above; it allows us to pass a
 `mlir::Builder` (`mlir::PatternRewriter` is a subclass of `mlir::OpBuilder`,
 which is a subclass of `mlir::Builder`) to the C++ helper function to use the
 handy methods on `mlir::Builder`.
 
-`$_self` is useful when we want to write something in the form of
-`NativeCodeCall<"...">:$symbol`. For example, if we want to reverse the previous
-example and decompose the array attribute into two attributes:
+Here's an example of how to use `$_self` in a source pattern:
 
 ```tablegen
-class getNthAttr<int n> : NativeCodeCall<"$_self[" # n # "]">;
 
-def : Pat<(OneAttrOp $attr),
-          (TwoAttrOp (getNthAttr<0>:$attr), (getNthAttr<1>:$attr)>;
+def : Pat<(OneAttrOp (NativeCodeCall<"Foo($_self, &$0)"> I32Attr:$val)),
+          (TwoAttrOp $val, $val)>;
 ```
 
-In the above, `$_self` is substituted by the attribute bound by `$attr`, which
-is `OneAttrOp`'s array attribute.
+In the above, `$_self` is substituted by the defining operation of the first
+operand of `OneAttrOp`. Note that binding a name to a `NativeCodeCall` in the
+source pattern is not supported. To carry values back from the helper function,
+put the names (the constraint is optional) in the parameter list; they will be
+bound to variables of the corresponding type. These variables must be passed
+either by reference or by pointer so that the matched value can be returned. In
+the same example, `$val` is bound to a variable of type `Attribute` (because of
+`I32Attr`), and since the template passes `&$0`, the second parameter of `Foo()`
+is an `Attribute*` (with plain `$0` it would be `Attribute&`). Names with
+attribute constraints are captured as `Attribute`s; all others as `Value`s.
 
 Positional placeholders will be substituted by the `dag` object parameters at
 the `NativeCodeCall` use site. For example, if we define `SomeCall :

diff  --git a/mlir/include/mlir/IR/OpBase.td b/mlir/include/mlir/IR/OpBase.td
index 11038ef12d2ad..9436602fef57b 100644
--- a/mlir/include/mlir/IR/OpBase.td
+++ b/mlir/include/mlir/IR/OpBase.td
@@ -2530,9 +2530,9 @@ class Pat<dag pattern, dag result, list<dag> preds = [],
 // the wrapped expression can take special placeholders listed below:
 //
 // * `$_builder` will be replaced by the current `mlir::PatternRewriter`.
-// * `$_self` will be replaced with the entity this transformer is attached to.
-//   E.g., with the definition `def transform : NativeCodeCall<"$_self...">`,
-//   `$_self` in `transform:$attr` will be replaced by the value for `$attr`.
+// * `$_self` will be replaced by the defining operation in a source pattern.
+//   E.g., in `(OneArgOp (NativeCodeCall<"Foo($_self, &$0)"> I32Attr:$a))`,
+//   `$_self` is replaced with the defining op of OneArgOp's first operand.
 //
 // If used as a DAG node, i.e., `(NativeCodeCall<"..."> <arg0>, ..., <argN>)`,
 // then positional placeholders are also supported; placeholder `$N` in the
@@ -2542,7 +2542,7 @@ class NativeCodeCall<string expr> {
   string expression = expr;
 }
 
-def ConstantLikeMatcher : NativeCodeCall<"success(matchPattern($0->getResult(0), m_Constant(&$1)))">;
+def ConstantLikeMatcher : NativeCodeCall<"success(matchPattern($_self->getResult(0), m_Constant(&$0)))">;
 
 //===----------------------------------------------------------------------===//
 // Rewrite directives

diff  --git a/mlir/lib/TableGen/Pattern.cpp b/mlir/lib/TableGen/Pattern.cpp
index 0dae7ff718837..d3bd6f7662bff 100644
--- a/mlir/lib/TableGen/Pattern.cpp
+++ b/mlir/lib/TableGen/Pattern.cpp
@@ -232,7 +232,7 @@ std::string SymbolInfoMap::SymbolInfo::getVarDecl(StringRef name) const {
                 getVarName(name)));
   }
   case Kind::Value: {
-    return std::string(formatv("::llvm::ArrayRef<::mlir::Value> {0};\n", name));
+    return std::string(formatv("::mlir::Value {0};\n", name));
   }
   case Kind::Result: {
     // Use the op itself for captured results.
@@ -626,11 +626,16 @@ void Pattern::collectBoundSymbols(DagNode tree, SymbolInfoMap &infoMap,
 
   if (tree.isNativeCodeCall()) {
     if (!treeName.empty()) {
-      PrintFatalError(
-          &def,
-          formatv(
-              "binding symbol '{0}' to native code call unsupported right now",
-              treeName));
+      if (!isSrcPattern) {
+        LLVM_DEBUG(llvm::dbgs() << "found symbol bound to NativeCodeCall: "
+                                << treeName << '\n');
+        verifyBind(infoMap.bindValue(treeName), treeName);
+      } else {
+        PrintFatalError(&def,
+                        formatv("binding symbol '{0}' to NativecodeCall in "
+                                "MatchPattern is not supported",
+                                treeName));
+      }
     }
 
     for (int i = 0; i != numTreeArgs; ++i) {
@@ -649,24 +654,27 @@ void Pattern::collectBoundSymbols(DagNode tree, SymbolInfoMap &infoMap,
 
       // `$_` is a special symbol meaning ignore the current argument.
       if (!treeArgName.empty() && treeArgName != "_") {
-        if (tree.isNestedDagArg(i)) {
-          auto err = formatv("cannot bind '{0}' for nested native call arg",
-                             treeArgName);
-          PrintFatalError(&def, err);
-        }
-
         DagLeaf leaf = tree.getArgAsLeaf(i);
-        auto constraint = leaf.getAsConstraint();
-        bool isAttr = leaf.isAttrMatcher() || leaf.isEnumAttrCase() ||
-                      leaf.isConstantAttr() ||
-                      constraint.getKind() == Constraint::Kind::CK_Attr;
-
-        if (isAttr) {
-          verifyBind(infoMap.bindAttr(treeArgName), treeArgName);
-          continue;
-        }
 
-        verifyBind(infoMap.bindValue(treeArgName), treeArgName);
+        // In (NativeCodeCall<"Foo($_self, $0, $1, $2)"> I8Attr:$a, I8:$b, $c),
+        if (leaf.isUnspecified()) {
+          // This is case of $c, a Value without any constraints.
+          verifyBind(infoMap.bindValue(treeArgName), treeArgName);
+        } else {
+          auto constraint = leaf.getAsConstraint();
+          bool isAttr = leaf.isAttrMatcher() || leaf.isEnumAttrCase() ||
+                        leaf.isConstantAttr() ||
+                        constraint.getKind() == Constraint::Kind::CK_Attr;
+
+          if (isAttr) {
+            // This is case of $a, a binding to a certain attribute.
+            verifyBind(infoMap.bindAttr(treeArgName), treeArgName);
+            continue;
+          }
+
+          // This is case of $b, a binding to a certain type.
+          verifyBind(infoMap.bindValue(treeArgName), treeArgName);
+        }
       }
     }
 

diff  --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td
index 312287177bf6e..b0c2fe45ed681 100644
--- a/mlir/test/lib/Dialect/Test/TestOps.td
+++ b/mlir/test/lib/Dialect/Test/TestOps.td
@@ -837,6 +837,20 @@ def : Pattern<(OpNativeCodeCall3 $input),
               [(NativeCodeCall<"createOpI($_builder, $_loc, $0)"> $input),
                (OpK)]>;
 
+def OpNativeCodeCall4 : TEST_Op<"native_code_call4"> {
+  let arguments = (ins AnyType:$input1);
+  let results = (outs I32:$output1, I32:$output2);
+}
+def OpNativeCodeCall5 : TEST_Op<"native_code_call5"> {
+  let arguments = (ins I32:$input1, I32:$input2);
+  let results = (outs I32:$output1, I32:$output2);
+}
+
+def GetFirstI32Result : NativeCodeCall<"success(getFirstI32Result($_self, $0))">;
+def BindNativeCodeCallResult : NativeCodeCall<"bindNativeCodeCallResult($0)">;
+def : Pat<(OpNativeCodeCall4 (GetFirstI32Result $ret)),
+          (OpNativeCodeCall5 (BindNativeCodeCallResult:$native $ret), $native)>;
+
 // Test AllAttrConstraintsOf.
 def OpAllAttrConstraint1 : TEST_Op<"all_attr_constraint_of1"> {
   let arguments = (ins I64ArrayAttr:$attr);

diff  --git a/mlir/test/lib/Dialect/Test/TestPatterns.cpp b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
index 6ddb6c916737e..b319257dbe060 100644
--- a/mlir/test/lib/Dialect/Test/TestPatterns.cpp
+++ b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
@@ -35,6 +35,15 @@ static void handleNoResultOp(PatternRewriter &rewriter,
                                     op.operand());
 }
 
+static bool getFirstI32Result(Operation *op, Value &value) {
+  if (!Type(op->getResult(0).getType()).isSignlessInteger(32))
+    return false;
+  value = op->getResult(0);
+  return true;
+}
+
+static Value bindNativeCodeCallResult(Value value) { return value; }
+
 // Test that natives calls are only called once during rewrites.
 // OpM_Test will return Pi, increased by 1 for each subsequent calls.
 // This let us check the number of times OpM_Test was called by inspecting

diff  --git a/mlir/test/mlir-tblgen/pattern.mlir b/mlir/test/mlir-tblgen/pattern.mlir
index 0425cf819e60a..6918f319198ca 100644
--- a/mlir/test/mlir-tblgen/pattern.mlir
+++ b/mlir/test/mlir-tblgen/pattern.mlir
@@ -88,6 +88,20 @@ func @verifyAuxiliaryNativeCodeCall(%arg0: i32) -> (i32) {
   return %0 : i32
 }
 
+// CHECK-LABEL: verifyNativeCodeCallBinding
+func @verifyNativeCodeCallBinding(%arg0 : i32) -> (i32) {
+  %0 = "test.op_k"() : () -> (i32)
+  // CHECK: %[[A:.*]], %[[B:.*]] = "test.native_code_call5"(%1, %1) : (i32, i32) -> (i32, i32)
+  %1, %2 = "test.native_code_call4"(%0) : (i32) -> (i32, i32)
+  %3 = "test.constant"() {value = 1 : i8} : () -> i8
+  // %3 is i8 so it'll fail at GetFirstI32Result match. The operation should
+  // keep the same form.
+  // CHECK: %{{.*}}, %{{.*}} = "test.native_code_call4"({{%.*}}) : (i8) -> (i32, i32)
+  %4, %5 = "test.native_code_call4"(%3) : (i8) -> (i32, i32)
+  // CHECK: return %[[A]]
+  return %1 : i32
+}
+
 // CHECK-LABEL: verifyAllAttrConstraintOf
 func @verifyAllAttrConstraintOf() -> (i32, i32, i32) {
   // CHECK: "test.all_attr_constraint_of2"

diff  --git a/mlir/test/mlir-tblgen/rewriter-errors.td b/mlir/test/mlir-tblgen/rewriter-errors.td
index eeb049482b886..60e4710688e75 100644
--- a/mlir/test/mlir-tblgen/rewriter-errors.td
+++ b/mlir/test/mlir-tblgen/rewriter-errors.td
@@ -1,5 +1,6 @@
 // RUN: not mlir-tblgen -gen-rewriters -I %S/../../include -DERROR1 %s 2>&1 | FileCheck --check-prefix=ERROR1 %s
 // RUN: not mlir-tblgen -gen-rewriters -I %S/../../include -DERROR2 %s 2>&1 | FileCheck --check-prefix=ERROR2 %s
+// RUN: not mlir-tblgen -gen-rewriters -I %S/../../include -DERROR3 %s 2>&1 | FileCheck --check-prefix=ERROR3 %s
 
 include "mlir/IR/OpBase.td"
 
@@ -16,14 +17,21 @@ def OpB : A_Op<"op_b">, Arguments<(ins AnyInteger, AnyAttr:$value)>, Results<(ou
 
 #ifdef ERROR1
 def NativeMatcher : NativeCodeCall<"success(nativeCall($0, $1))">;
-// ERROR1: [[@LINE+1]]:1: error: binding symbol 'error' to native code call unsupported right now
-def : Pat<(OpA (NativeMatcher:$error $val), AnyI32Attr:$arg),
+// ERROR1: [[@LINE+1]]:1: error: NativeCodeCall must have $_self as argument for passing the defining Operation
+def : Pat<(OpA (NativeMatcher $val), AnyI32Attr:$arg),
           (OpB $val, $arg)>;
 #endif
 
 #ifdef ERROR2
-def NativeMatcher : NativeCodeCall<"success(nativeCall($0, $1))">;
-// ERROR2: [[@LINE+1]]:1: error: Matching nested tree in NativeCodecall not support for 
+def NativeMatcher : NativeCodeCall<"success(nativeCall($_self, &$0))">;
+// ERROR2: [[@LINE+1]]:1: error: binding symbol 'error' to NativecodeCall in MatchPattern is not supported
+def : Pat<(OpA (NativeMatcher:$error $val), AnyI32Attr:$arg),
+          (OpB $val, $arg)>;
+#endif
+
+#ifdef ERROR3
+def NativeMatcher : NativeCodeCall<"success(nativeCall($_self, $0, $1))">;
+// ERROR3: [[@LINE+1]]:1: error: Matching nested tree in NativeCodecall not support for
 def : Pat<(OpA (NativeMatcher (OpB $val, $unused)), AnyI32Attr:$arg),
           (OpB $val, $arg)>;
 #endif

diff  --git a/mlir/tools/mlir-tblgen/RewriterGen.cpp b/mlir/tools/mlir-tblgen/RewriterGen.cpp
index 28889de1ea607..e0112af6b5b03 100644
--- a/mlir/tools/mlir-tblgen/RewriterGen.cpp
+++ b/mlir/tools/mlir-tblgen/RewriterGen.cpp
@@ -252,7 +252,6 @@ void PatternEmitter::emitNativeCodeMatch(DagNode tree, StringRef opName,
   // TODO(suderman): iterate through arguments, determine their types, output
   // names.
   SmallVector<std::string, 8> capture;
-  capture.push_back(opName.str());
 
   raw_indented_ostream::DelimitedScope scope(os);
 
@@ -265,8 +264,8 @@ void PatternEmitter::emitNativeCodeMatch(DagNode tree, StringRef opName,
       auto leaf = tree.getArgAsLeaf(i);
       if (leaf.isAttrMatcher() || leaf.isConstantAttr()) {
         os << "Attribute " << argName << ";\n";
-      } else if (leaf.isOperandMatcher()) {
-        os << "Operation " << argName << ";\n";
+      } else {
+        os << "Value " << argName << ";\n";
       }
     }
 
@@ -278,20 +277,25 @@ void PatternEmitter::emitNativeCodeMatch(DagNode tree, StringRef opName,
   std::tie(hasLocationDirective, locToUse) = getLocation(tree);
 
   auto fmt = tree.getNativeCodeTemplate();
-  auto nativeCodeCall =
-      std::string(tgfmt(fmt, &fmtCtx.addSubst("_loc", locToUse), capture));
+  if (fmt.count("$_self") != 1) {
+    PrintFatalError(loc, "NativeCodeCall must have $_self as argument for "
+                         "passing the defining Operation");
+  }
+
+  auto nativeCodeCall = std::string(tgfmt(
+      fmt, &fmtCtx.addSubst("_loc", locToUse).withSelf(opName.str()), capture));
 
   os << "if (failed(" << nativeCodeCall << ")) return ::mlir::failure();\n";
 
   for (int i = 0, e = tree.getNumArgs(); i != e; ++i) {
     auto name = tree.getArgName(i);
     if (!name.empty() && name != "_") {
-      os << formatv("{0} = {1};\n", name, capture[i + 1]);
+      os << formatv("{0} = {1};\n", name, capture[i]);
     }
   }
 
   for (int i = 0, e = tree.getNumArgs(); i != e; ++i) {
-    std::string argName = capture[i + 1];
+    std::string argName = capture[i];
 
     // Handle nested DAG construct first
     if (DagNode argTree = tree.getArgAsNestedDag(i)) {
@@ -302,9 +306,18 @@ void PatternEmitter::emitNativeCodeMatch(DagNode tree, StringRef opName,
     }
 
     DagLeaf leaf = tree.getArgAsLeaf(i);
+
+    // The parameter for native function doesn't bind any constraints.
+    if (leaf.isUnspecified())
+      continue;
+
     auto constraint = leaf.getAsConstraint();
 
-    auto self = formatv("{0}", argName);
+    std::string self;
+    if (leaf.isAttrMatcher() || leaf.isConstantAttr())
+      self = argName;
+    else
+      self = formatv("{0}.getType()", argName);
     emitMatchCheck(
         opName,
         tgfmt(constraint.getConditionTemplate(), &fmtCtx.withSelf(self)),
@@ -362,6 +375,7 @@ void PatternEmitter::emitOpMatch(DagNode tree, StringRef opName, int depth) {
       os << "{\n";
 
       // Attributes don't count for getODSOperands.
+      // TODO: Operand is a Value, check if we should remove `getDefiningOp()`.
       os.indent() << formatv(
           "auto *{0} = "
           "(*{1}.getODSOperands({2}).begin()).getDefiningOp();\n",
@@ -929,7 +943,13 @@ std::string PatternEmitter::handleReplaceWithNativeCodeCall(DagNode tree,
                             << " replacement: " << attrs[i] << "\n");
   }
 
-  return std::string(tgfmt(fmt, &fmtCtx.addSubst("_loc", locToUse), attrs));
+  std::string symbol = tgfmt(fmt, &fmtCtx.addSubst("_loc", locToUse), attrs);
+  if (!tree.getSymbol().empty()) {
+    os << formatv("auto {0} = {1};\n", tree.getSymbol(), symbol);
+    symbol = tree.getSymbol().str();
+  }
+
+  return symbol;
 }
 
 int PatternEmitter::getNodeValueCount(DagNode node) {


        


More information about the Mlir-commits mailing list