[Mlir-commits] [mlir] 2bf423b - [mlir] RewriterGen NativeCodeCall matcher with ConstantOp matcher

Rob Suderman llvmlistbot at llvm.org
Thu Oct 15 16:34:10 PDT 2020


Author: Rob Suderman
Date: 2020-10-15T16:32:20-07:00
New Revision: 2bf423b0218c9583e3a372950a34facbf93e63d3

URL: https://github.com/llvm/llvm-project/commit/2bf423b0218c9583e3a372950a34facbf93e63d3
DIFF: https://github.com/llvm/llvm-project/commit/2bf423b0218c9583e3a372950a34facbf93e63d3.diff

LOG: [mlir] RewriterGen NativeCodeCall matcher with ConstantOp matcher

Added an underlying matcher for generic constant ops. This
included a rewrite of RewriterGen to make variable use
clearer.

Differential Revision: https://reviews.llvm.org/D89161

Added: 
    mlir/test/mlir-tblgen/rewriter-errors.td

Modified: 
    mlir/include/mlir/IR/OpBase.td
    mlir/include/mlir/TableGen/Pattern.h
    mlir/lib/TableGen/Pattern.cpp
    mlir/test/lib/Dialect/Test/TestDialect.cpp
    mlir/test/lib/Dialect/Test/TestOps.td
    mlir/test/lib/Dialect/Test/TestPatterns.cpp
    mlir/test/mlir-tblgen/pattern.mlir
    mlir/tools/mlir-tblgen/RewriterGen.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/IR/OpBase.td b/mlir/include/mlir/IR/OpBase.td
index 82dc6a456f29..72b3b1ab41f5 100644
--- a/mlir/include/mlir/IR/OpBase.td
+++ b/mlir/include/mlir/IR/OpBase.td
@@ -2351,6 +2351,8 @@ class NativeCodeCall<string expr> {
   string expression = expr;
 }
 
+def ConstantLikeMatcher : NativeCodeCall<"success(matchPattern($0->getResult(0), m_Constant(&$1)))">;
+
 //===----------------------------------------------------------------------===//
 // Rewrite directives
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/include/mlir/TableGen/Pattern.h b/mlir/include/mlir/TableGen/Pattern.h
index 4fc2ae762a66..98c5d9b18f5d 100644
--- a/mlir/include/mlir/TableGen/Pattern.h
+++ b/mlir/include/mlir/TableGen/Pattern.h
@@ -252,6 +252,9 @@ class SymbolInfoMap {
     static SymbolInfo getAttr(const Operator *op, int index) {
       return SymbolInfo(op, Kind::Attr, index);
     }
+    static SymbolInfo getAttr() {
+      return SymbolInfo(nullptr, Kind::Attr, llvm::None);
+    }
     static SymbolInfo getOperand(const Operator *op, int index) {
       return SymbolInfo(op, Kind::Operand, index);
     }
@@ -319,6 +322,10 @@ class SymbolInfoMap {
   // is already bound.
   bool bindValue(StringRef symbol);
 
+  // Registers the given `symbol` as bound to an attr. Returns false if `symbol`
+  // is already bound.
+  bool bindAttr(StringRef symbol);
+
   // Returns true if the given `symbol` is bound.
   bool contains(StringRef symbol) const;
 
@@ -421,6 +428,9 @@ class Pattern {
   std::vector<IdentifierLine> getLocation() const;
 
 private:
+  // Helper to verify a variable binding; fatal error if `result` is false.
+  void verifyBind(bool result, StringRef symbolName);
+
   // Recursively collects all bound symbols inside the DAG tree rooted
   // at `tree` and updates the given `infoMap`.
   void collectBoundSymbols(DagNode tree, SymbolInfoMap &infoMap,

diff  --git a/mlir/lib/TableGen/Pattern.cpp b/mlir/lib/TableGen/Pattern.cpp
index 448f70359bd0..7044677fad36 100644
--- a/mlir/lib/TableGen/Pattern.cpp
+++ b/mlir/lib/TableGen/Pattern.cpp
@@ -216,9 +216,13 @@ std::string SymbolInfoMap::SymbolInfo::getVarDecl(StringRef name) const {
   LLVM_DEBUG(llvm::dbgs() << "getVarDecl for '" << name << "': ");
   switch (kind) {
   case Kind::Attr: {
-    auto type =
-        op->getArg(*argIndex).get<NamedAttribute *>()->attr.getStorageType();
-    return std::string(formatv("{0} {1};\n", type, name));
+    if (op) {
+      auto type =
+          op->getArg(*argIndex).get<NamedAttribute *>()->attr.getStorageType();
+      return std::string(formatv("{0} {1};\n", type, name));
+    }
+    // TODO(suderman): Use a more exact type when available.
+    return std::string(formatv("Attribute {0};\n", name));
   }
   case Kind::Operand: {
     // Use operand range for captured operands (to support potential variadic
@@ -394,6 +398,11 @@ bool SymbolInfoMap::bindValue(StringRef symbol) {
   return symbolInfoMap.count(inserted->first) == 1;
 }
 
+bool SymbolInfoMap::bindAttr(StringRef symbol) {
+  auto inserted = symbolInfoMap.emplace(symbol, SymbolInfo::getAttr());
+  return symbolInfoMap.count(inserted->first) == 1;
+}
+
 bool SymbolInfoMap::contains(StringRef symbol) const {
   return find(symbol) != symbolInfoMap.end();
 }
@@ -558,15 +567,15 @@ std::vector<AppliedConstraint> Pattern::getConstraints() const {
   for (auto it : *listInit) {
     auto *dagInit = dyn_cast<llvm::DagInit>(it);
     if (!dagInit)
-      PrintFatalError(def.getLoc(), "all elements in Pattern multi-entity "
-                                    "constraints should be DAG nodes");
+      PrintFatalError(&def, "all elements in Pattern multi-entity "
+                            "constraints should be DAG nodes");
 
     std::vector<std::string> entities;
     entities.reserve(dagInit->arg_size());
     for (auto *argName : dagInit->getArgNames()) {
       if (!argName) {
         PrintFatalError(
-            def.getLoc(),
+            &def,
             "operands to additional constraints can only be symbol references");
       }
       entities.push_back(std::string(argName->getValue()));
@@ -584,7 +593,7 @@ int Pattern::getBenefit() const {
   int initBenefit = getSourcePattern().getNumOps();
   llvm::DagInit *delta = def.getValueAsDag("benefitDelta");
   if (delta->getNumArgs() != 1 || !isa<llvm::IntInit>(delta->getArg(0))) {
-    PrintFatalError(def.getLoc(),
+    PrintFatalError(&def,
                     "The 'addBenefit' takes and only takes one integer value");
   }
   return initBenefit + dyn_cast<llvm::IntInit>(delta->getArg(0))->getValue();
@@ -603,64 +612,120 @@ std::vector<Pattern::IdentifierLine> Pattern::getLocation() const {
   return result;
 }
 
+void Pattern::verifyBind(bool result, StringRef symbolName) {
+  if (!result) {
+    auto err = formatv("symbol '{0}' bound more than once", symbolName);
+    PrintFatalError(&def, err);
+  }
+}
+
 void Pattern::collectBoundSymbols(DagNode tree, SymbolInfoMap &infoMap,
                                   bool isSrcPattern) {
   auto treeName = tree.getSymbol();
-  if (!tree.isOperation()) {
+  auto numTreeArgs = tree.getNumArgs();
+
+  if (tree.isNativeCodeCall()) {
     if (!treeName.empty()) {
       PrintFatalError(
-          def.getLoc(),
-          formatv("binding symbol '{0}' to non-operation unsupported right now",
-                  treeName));
+          &def,
+          formatv(
+              "binding symbol '{0}' to native code call unsupported right now",
+              treeName));
     }
-    return;
-  }
 
-  auto &op = getDialectOp(tree);
-  auto numOpArgs = op.getNumArgs();
-  auto numTreeArgs = tree.getNumArgs();
-
-  // The pattern might have the last argument specifying the location.
-  bool hasLocDirective = false;
-  if (numTreeArgs != 0) {
-    if (auto lastArg = tree.getArgAsNestedDag(numTreeArgs - 1))
-      hasLocDirective = lastArg.isLocationDirective();
-  }
+    for (int i = 0; i != numTreeArgs; ++i) {
+      if (auto treeArg = tree.getArgAsNestedDag(i)) {
+        // This DAG node argument is a DAG node itself. Go inside recursively.
+        collectBoundSymbols(treeArg, infoMap, isSrcPattern);
+        continue;
+      }
 
-  if (numOpArgs != numTreeArgs - hasLocDirective) {
-    auto err = formatv("op '{0}' argument number mismatch: "
-                       "{1} in pattern vs. {2} in definition",
-                       op.getOperationName(), numTreeArgs, numOpArgs);
-    PrintFatalError(def.getLoc(), err);
-  }
+      if (!isSrcPattern)
+        continue;
 
-  // The name attached to the DAG node's operator is for representing the
-  // results generated from this op. It should be remembered as bound results.
-  if (!treeName.empty()) {
-    LLVM_DEBUG(llvm::dbgs()
-               << "found symbol bound to op result: " << treeName << '\n');
-    if (!infoMap.bindOpResult(treeName, op))
-      PrintFatalError(def.getLoc(),
-                      formatv("symbol '{0}' bound more than once", treeName));
-  }
-
-  for (int i = 0; i != numTreeArgs; ++i) {
-    if (auto treeArg = tree.getArgAsNestedDag(i)) {
-      // This DAG node argument is a DAG node itself. Go inside recursively.
-      collectBoundSymbols(treeArg, infoMap, isSrcPattern);
-    } else if (isSrcPattern) {
-      // We can only bind symbols to op arguments in source pattern. Those
+      // We can only bind symbols to arguments in source pattern. Those
       // symbols are referenced in result patterns.
       auto treeArgName = tree.getArgName(i);
+
       // `$_` is a special symbol meaning ignore the current argument.
       if (!treeArgName.empty() && treeArgName != "_") {
-        LLVM_DEBUG(llvm::dbgs() << "found symbol bound to op argument: "
-                                << treeArgName << '\n');
-        if (!infoMap.bindOpArgument(treeArgName, op, i)) {
-          auto err = formatv("symbol '{0}' bound more than once", treeArgName);
-          PrintFatalError(def.getLoc(), err);
+        if (tree.isNestedDagArg(i)) {
+          auto err = formatv("cannot bind '{0}' for nested native call arg",
+                             treeArgName);
+          PrintFatalError(&def, err);
         }
+
+        DagLeaf leaf = tree.getArgAsLeaf(i);
+        auto constraint = leaf.getAsConstraint();
+        bool isAttr = leaf.isAttrMatcher() || leaf.isEnumAttrCase() ||
+                      leaf.isConstantAttr() ||
+                      constraint.getKind() == Constraint::Kind::CK_Attr;
+
+        if (isAttr) {
+          verifyBind(infoMap.bindAttr(treeArgName), treeArgName);
+          continue;
+        }
+
+        verifyBind(infoMap.bindValue(treeArgName), treeArgName);
       }
     }
+
+    return;
+  }
+
+  if (tree.isOperation()) {
+    auto &op = getDialectOp(tree);
+    auto numOpArgs = op.getNumArgs();
+
+    // The pattern might have the last argument specifying the location.
+    bool hasLocDirective = false;
+    if (numTreeArgs != 0) {
+      if (auto lastArg = tree.getArgAsNestedDag(numTreeArgs - 1))
+        hasLocDirective = lastArg.isLocationDirective();
+    }
+
+    if (numOpArgs != numTreeArgs - hasLocDirective) {
+      auto err = formatv("op '{0}' argument number mismatch: "
+                         "{1} in pattern vs. {2} in definition",
+                         op.getOperationName(), numTreeArgs, numOpArgs);
+      PrintFatalError(&def, err);
+    }
+
+    // The name attached to the DAG node's operator is for representing the
+    // results generated from this op. It should be remembered as bound results.
+    if (!treeName.empty()) {
+      LLVM_DEBUG(llvm::dbgs()
+                 << "found symbol bound to op result: " << treeName << '\n');
+      verifyBind(infoMap.bindOpResult(treeName, op), treeName);
+    }
+
+    for (int i = 0; i != numTreeArgs; ++i) {
+      if (auto treeArg = tree.getArgAsNestedDag(i)) {
+        // This DAG node argument is a DAG node itself. Go inside recursively.
+        collectBoundSymbols(treeArg, infoMap, isSrcPattern);
+        continue;
+      }
+
+      if (isSrcPattern) {
+        // We can only bind symbols to op arguments in source pattern. Those
+        // symbols are referenced in result patterns.
+        auto treeArgName = tree.getArgName(i);
+        // `$_` is a special symbol meaning ignore the current argument.
+        if (!treeArgName.empty() && treeArgName != "_") {
+          LLVM_DEBUG(llvm::dbgs() << "found symbol bound to op argument: "
+                                  << treeArgName << '\n');
+          verifyBind(infoMap.bindOpArgument(treeArgName, op, i), treeArgName);
+        }
+      }
+    }
+    return;
+  }
+
+  if (!treeName.empty()) {
+    PrintFatalError(
+        &def, formatv("binding symbol '{0}' to non-operation/native code call "
+                      "unsupported right now",
+                      treeName));
   }
+  return;
 }

diff  --git a/mlir/test/lib/Dialect/Test/TestDialect.cpp b/mlir/test/lib/Dialect/Test/TestDialect.cpp
index 3bfb82495ce1..d34e997644a5 100644
--- a/mlir/test/lib/Dialect/Test/TestDialect.cpp
+++ b/mlir/test/lib/Dialect/Test/TestDialect.cpp
@@ -615,6 +615,10 @@ OpFoldResult TestOpWithRegionFold::fold(ArrayRef<Attribute> operands) {
   return operand();
 }
 
+OpFoldResult TestOpConstant::fold(ArrayRef<Attribute> operands) {
+  return getValue();
+}
+
 LogicalResult TestOpWithVariadicResultsAndFolder::fold(
     ArrayRef<Attribute> operands, SmallVectorImpl<OpFoldResult> &results) {
   for (Value input : this->operands()) {

diff  --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td
index aef39a9e19fe..fcc677361dcc 100644
--- a/mlir/test/lib/Dialect/Test/TestOps.td
+++ b/mlir/test/lib/Dialect/Test/TestOps.td
@@ -799,6 +799,22 @@ def TestOpWithRegionPattern : TEST_Op<"op_with_region_pattern"> {
   let hasCanonicalizer = 1;
 }
 
+def TestOpConstant : TEST_Op<"constant", [ConstantLike, NoSideEffect]> {
+  let arguments = (ins AnyAttr:$value);
+  let results = (outs AnyType);
+  let extraClassDeclaration = [{
+    Attribute getValue() { return getAttr("value"); }
+  }];
+
+  let hasFolder = 1;
+}
+
+def OpR : TEST_Op<"op_r">, Arguments<(ins AnyInteger, AnyInteger)>, Results<(outs AnyInteger)>;
+def OpS : TEST_Op<"op_s">, Arguments<(ins AnyInteger, AnyAttr:$value)>, Results<(outs AnyInteger)>;
+
+def : Pat<(OpR $input1, (ConstantLikeMatcher I32Attr:$input2)),
+          (OpS:$unused $input1, $input2)>;
+
 // Op for testing trivial removal via folding of op with inner ops and no uses.
 def TestOpWithRegionFoldNoSideEffect : TEST_Op<
     "op_with_region_fold_no_side_effect", [NoSideEffect]> {

diff  --git a/mlir/test/lib/Dialect/Test/TestPatterns.cpp b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
index 32d618d9008e..282d31065549 100644
--- a/mlir/test/lib/Dialect/Test/TestPatterns.cpp
+++ b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
@@ -9,6 +9,7 @@
 #include "TestDialect.h"
 #include "mlir/Dialect/StandardOps/IR/Ops.h"
 #include "mlir/Dialect/StandardOps/Transforms/FuncConversions.h"
+#include "mlir/IR/Matchers.h"
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/Pass/Pass.h"
 #include "mlir/Transforms/DialectConversion.h"

diff  --git a/mlir/test/mlir-tblgen/pattern.mlir b/mlir/test/mlir-tblgen/pattern.mlir
index 5986be6240f9..616e116cb170 100644
--- a/mlir/test/mlir-tblgen/pattern.mlir
+++ b/mlir/test/mlir-tblgen/pattern.mlir
@@ -248,6 +248,58 @@ func @verifyUnitAttr() -> (i32, i32) {
   return %0, %1 : i32, i32
 }
 
+//===----------------------------------------------------------------------===//
+// Test Constant Matching
+//===----------------------------------------------------------------------===//
+
+// CHECK-LABEL: testConstOp
+func @testConstOp() -> (i32) {
+  // CHECK-NEXT: [[C0:%.+]] = constant 1
+  %0 = "test.constant"() {value = 1 : i32} : () -> i32
+
+  // CHECK-NEXT: return [[C0]]
+  return %0 : i32
+}
+
+// CHECK-LABEL: testConstOpUsed
+func @testConstOpUsed() -> (i32) {
+  // CHECK-NEXT: [[C0:%.+]] = constant 1
+  %0 = "test.constant"() {value = 1 : i32} : () -> i32
+
+  // CHECK-NEXT: [[V0:%.+]] = "test.op_s"([[C0]])
+  %1 = "test.op_s"(%0) {value = 1 : i32} : (i32) -> i32
+
+  // CHECK-NEXT: return [[V0]]
+  return %1 : i32
+}
+
+// CHECK-LABEL: testConstOpReplaced
+func @testConstOpReplaced() -> (i32) {
+  // CHECK-NEXT: [[C0:%.+]] = constant 1
+  %0 = "test.constant"() {value = 1 : i32} : () -> i32
+  %1 = "test.constant"() {value = 2 : i32} : () -> i32
+
+  // CHECK: [[V0:%.+]] = "test.op_s"([[C0]]) {value = 2 : i32}
+  %2 = "test.op_r"(%0, %1) : (i32, i32) -> i32
+
+  // CHECK: [[V0]]
+  return %2 : i32
+}
+// CHECK-LABEL: testConstOpMatchFailure
+func @testConstOpMatchFailure() -> (i64) {
+  // CHECK-DAG: [[C0:%.+]] = constant 1
+  %0 = "test.constant"() {value = 1 : i64} : () -> i64
+
+  // CHECK-DAG: [[C1:%.+]] = constant 2
+  %1 = "test.constant"() {value = 2 : i64} : () -> i64
+
+  // CHECK: [[V0:%.+]] = "test.op_r"([[C0]], [[C1]])
+  %2 = "test.op_r"(%0, %1) : (i64, i64) -> i64
+
+  // CHECK: [[V0]]
+  return %2 : i64
+}
+
 //===----------------------------------------------------------------------===//
 // Test Enum Attributes
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/test/mlir-tblgen/rewriter-errors.td b/mlir/test/mlir-tblgen/rewriter-errors.td
new file mode 100644
index 000000000000..eeb049482b88
--- /dev/null
+++ b/mlir/test/mlir-tblgen/rewriter-errors.td
@@ -0,0 +1,29 @@
+// RUN: not mlir-tblgen -gen-rewriters -I %S/../../include -DERROR1 %s 2>&1 | FileCheck --check-prefix=ERROR1 %s
+// RUN: not mlir-tblgen -gen-rewriters -I %S/../../include -DERROR2 %s 2>&1 | FileCheck --check-prefix=ERROR2 %s
+
+include "mlir/IR/OpBase.td"
+
+// Check using the dialect name as the namespace
+def A_Dialect : Dialect {
+  let name = "a";
+}
+
+class A_Op<string mnemonic, list<OpTrait> traits = []> :
+    Op<A_Dialect, mnemonic, traits>;
+
+def OpA : A_Op<"op_a">, Arguments<(ins AnyInteger, AnyInteger)>, Results<(outs AnyInteger)>;
+def OpB : A_Op<"op_b">, Arguments<(ins AnyInteger, AnyAttr:$value)>, Results<(outs AnyInteger)>;
+
+#ifdef ERROR1
+def NativeMatcher : NativeCodeCall<"success(nativeCall($0, $1))">;
+// ERROR1: [[@LINE+1]]:1: error: binding symbol 'error' to native code call unsupported right now
+def : Pat<(OpA (NativeMatcher:$error $val), AnyI32Attr:$arg),
+          (OpB $val, $arg)>;
+#endif
+
+#ifdef ERROR2
+def NativeMatcher : NativeCodeCall<"success(nativeCall($0, $1))">;
+// ERROR2: [[@LINE+1]]:1: error: Matching nested tree in NativeCodecall not support for 
+def : Pat<(OpA (NativeMatcher (OpB $val, $unused)), AnyI32Attr:$arg),
+          (OpB $val, $arg)>;
+#endif

diff  --git a/mlir/tools/mlir-tblgen/RewriterGen.cpp b/mlir/tools/mlir-tblgen/RewriterGen.cpp
index 7bff3e3b40b6..5521eea38252 100644
--- a/mlir/tools/mlir-tblgen/RewriterGen.cpp
+++ b/mlir/tools/mlir-tblgen/RewriterGen.cpp
@@ -63,7 +63,7 @@ class PatternEmitter {
 
 private:
   // Emits the code for matching ops.
-  void emitMatchLogic(DagNode tree);
+  void emitMatchLogic(DagNode tree, StringRef opName);
 
   // Emits the code for rewriting ops.
   void emitRewriteLogic();
@@ -72,26 +72,34 @@ class PatternEmitter {
   // Match utilities
   //===--------------------------------------------------------------------===//
 
+  // Emits C++ statements for matching the DAG structure.
+  void emitMatch(DagNode tree, StringRef name, int depth);
+
+  // Emits C++ statements for matching using a native code call.
+  void emitNativeCodeMatch(DagNode tree, StringRef name, int depth);
+
   // Emits C++ statements for matching the op constrained by the given DAG
-  // `tree`.
-  void emitOpMatch(DagNode tree, int depth);
+  // `tree` returning the op's variable name.
+  void emitOpMatch(DagNode tree, StringRef opName, int depth);
 
   // Emits C++ statements for matching the `argIndex`-th argument of the given
   // DAG `tree` as an operand.
-  void emitOperandMatch(DagNode tree, int argIndex, int depth);
+  void emitOperandMatch(DagNode tree, StringRef opName, int argIndex,
+                        int depth);
 
   // Emits C++ statements for matching the `argIndex`-th argument of the given
   // DAG `tree` as an attribute.
-  void emitAttributeMatch(DagNode tree, int argIndex, int depth);
+  void emitAttributeMatch(DagNode tree, StringRef opName, int argIndex,
+                          int depth);
 
   // Emits C++ for checking a match with a corresponding match failure
   // diagnostic.
-  void emitMatchCheck(int depth, const FmtObjectBase &matchFmt,
+  void emitMatchCheck(StringRef opName, const FmtObjectBase &matchFmt,
                       const llvm::formatv_object_base &failureFmt);
 
   // Emits C++ for checking a match with a corresponding match failure
   // diagnostics.
-  void emitMatchCheck(int depth, const std::string &matchStr,
+  void emitMatchCheck(StringRef opName, const std::string &matchStr,
                       const std::string &failureStr);
 
   //===--------------------------------------------------------------------===//
@@ -113,7 +121,7 @@ class PatternEmitter {
 
   // Emits the C++ statement to replace the matched DAG with a value built via
   // calling native C++ code.
-  std::string handleReplaceWithNativeCodeCall(DagNode resultTree);
+  std::string handleReplaceWithNativeCodeCall(DagNode resultTree, int depth);
 
   // Returns the symbol of the old value serving as the replacement.
   StringRef handleReplaceWithValue(DagNode tree);
@@ -140,12 +148,13 @@ class PatternEmitter {
 
   // Emits the concrete arguments used to call an op's builder.
   void supplyValuesForOpArgs(DagNode node,
-                             const ChildNodeIndexNameMap &childNodeNames);
+                             const ChildNodeIndexNameMap &childNodeNames,
+                             int depth);
 
   // Emits the local variables for holding all values as a whole and all named
   // attributes as a whole to be used for creating an op.
   void createAggregateLocalVarsForOpArgs(
-      DagNode node, const ChildNodeIndexNameMap &childNodeNames);
+      DagNode node, const ChildNodeIndexNameMap &childNodeNames, int depth);
 
   // Returns the C++ expression to construct a constant attribute of the given
   // `value` for the given attribute kind `attr`.
@@ -218,21 +227,114 @@ std::string PatternEmitter::handleConstantAttr(Attribute attr,
 }
 
 // Helper function to match patterns.
-void PatternEmitter::emitOpMatch(DagNode tree, int depth) {
+void PatternEmitter::emitMatch(DagNode tree, StringRef name, int depth) {
+  if (tree.isNativeCodeCall()) {
+    emitNativeCodeMatch(tree, name, depth);
+    return;
+  }
+
+  if (tree.isOperation()) {
+    emitOpMatch(tree, name, depth);
+    return;
+  }
+
+  PrintFatalError(loc, "encountered non-op, non-NativeCodeCall match.");
+}
+
+// Helper function to match patterns.
+void PatternEmitter::emitNativeCodeMatch(DagNode tree, StringRef opName,
+                                         int depth) {
+  LLVM_DEBUG(llvm::dbgs() << "handle NativeCodeCall matcher pattern: ");
+  LLVM_DEBUG(tree.print(llvm::dbgs()));
+  LLVM_DEBUG(llvm::dbgs() << '\n');
+
+  // TODO(suderman): iterate through arguments, determine their types, output
+  // names.
+  SmallVector<std::string, 8> capture(8);
+  if (tree.getNumArgs() > 8) {
+    PrintFatalError(loc,
+                    "unsupported NativeCodeCall matcher argument numbers: " +
+                        Twine(tree.getNumArgs()));
+  }
+
+  raw_indented_ostream::DelimitedScope scope(os);
+
+  for (int i = 0, e = tree.getNumArgs(); i != e; ++i) {
+    std::string argName = formatv("arg{0}_{1}", depth, i);
+    if (DagNode argTree = tree.getArgAsNestedDag(i)) {
+      os << "Value " << argName << ";\n";
+    } else {
+      auto leaf = tree.getArgAsLeaf(i);
+      if (leaf.isAttrMatcher() || leaf.isConstantAttr()) {
+        os << "Attribute " << argName << ";\n";
+      } else if (leaf.isOperandMatcher()) {
+        os << "Operation " << argName << ";\n";
+      }
+    }
+
+    capture[i] = std::move(argName);
+  }
+
+  bool hasLocationDirective;
+  std::string locToUse;
+  std::tie(hasLocationDirective, locToUse) = getLocation(tree);
+
+  auto fmt = tree.getNativeCodeTemplate();
+  auto nativeCodeCall = std::string(tgfmt(
+      fmt, &fmtCtx.addSubst("_loc", locToUse), opName, capture[0], capture[1],
+      capture[2], capture[3], capture[4], capture[5], capture[6], capture[7]));
+
+  os << "if (failed(" << nativeCodeCall << ")) return failure();\n";
+
+  for (int i = 0, e = tree.getNumArgs(); i != e; ++i) {
+    auto name = tree.getArgName(i);
+    if (!name.empty() && name != "_") {
+      os << formatv("{0} = {1};\n", name, capture[i]);
+    }
+  }
+
+  for (int i = 0, e = tree.getNumArgs(); i != e; ++i) {
+    std::string argName = capture[i];
+
+    // Nested DAG arguments are not supported yet; reject them.
+    if (DagNode argTree = tree.getArgAsNestedDag(i)) {
+      PrintFatalError(
+          loc, formatv("Matching nested tree in NativeCodecall not support for "
+                       "{0} as arg {1}",
+                       argName, i));
+    }
+
+    DagLeaf leaf = tree.getArgAsLeaf(i);
+    auto constraint = leaf.getAsConstraint();
+
+    auto self = formatv("{0}", argName);
+    emitMatchCheck(
+        opName,
+        tgfmt(constraint.getConditionTemplate(), &fmtCtx.withSelf(self)),
+        formatv("\"operand {0} of native code call '{1}' failed to satisfy "
+                "constraint: "
+                "'{2}'\"",
+                i, tree.getNativeCodeTemplate(), constraint.getDescription()));
+  }
+
+  LLVM_DEBUG(llvm::dbgs() << "done emitting match for native code call\n");
+}
+
+// Helper function to match patterns.
+void PatternEmitter::emitOpMatch(DagNode tree, StringRef opName, int depth) {
   Operator &op = tree.getDialectOp(opMap);
   LLVM_DEBUG(llvm::dbgs() << "start emitting match for op '"
                           << op.getOperationName() << "' at depth " << depth
                           << '\n');
 
-  int indent = 4 + 2 * depth;
-  os.indent(indent) << formatv(
-      "auto castedOp{0} = ::llvm::dyn_cast_or_null<{1}>(op{0}); "
-      "(void)castedOp{0};\n",
-      depth, op.getQualCppClassName());
+  std::string castedName = formatv("castedOp{0}", depth);
+  os << formatv("auto {0} = ::llvm::dyn_cast_or_null<{2}>({1}); "
+                "(void){0};\n",
+                castedName, opName, op.getQualCppClassName());
   // Skip the operand matching at depth 0 as the pattern rewriter already does.
   if (depth != 0) {
     // Skip if there is no defining operation (e.g., arguments to function).
-    os << formatv("if (!castedOp{0})\n  return failure();\n", depth);
+    os << formatv("if (!{0}) return failure();\n", castedName);
   }
   if (tree.getNumArgs() != op.getNumArgs()) {
     PrintFatalError(loc, formatv("op '{0}' argument number mismatch: {1} in "
@@ -244,10 +346,11 @@ void PatternEmitter::emitOpMatch(DagNode tree, int depth) {
   // If the operand's name is set, set to that variable.
   auto name = tree.getSymbol();
   if (!name.empty())
-    os << formatv("{0} = castedOp{1};\n", name, depth);
+    os << formatv("{0} = {1};\n", name, castedName);
 
   for (int i = 0, e = tree.getNumArgs(); i != e; ++i) {
     auto opArg = op.getArg(i);
+    std::string argName = formatv("op{0}", depth + 1);
 
     // Handle nested DAG construct first
     if (DagNode argTree = tree.getArgAsNestedDag(i)) {
@@ -262,20 +365,20 @@ void PatternEmitter::emitOpMatch(DagNode tree, int depth) {
       os << "{\n";
 
       os.indent() << formatv(
-          "auto *op{0} = "
-          "(*castedOp{1}.getODSOperands({2}).begin()).getDefiningOp();\n",
-          depth + 1, depth, i);
-      emitOpMatch(argTree, depth + 1);
-      os << formatv("tblgen_ops[{0}] = op{1};\n", ++opCounter, depth + 1);
+          "auto *{0} = "
+          "(*{1}.getODSOperands({2}).begin()).getDefiningOp();\n",
+          argName, castedName, i);
+      emitMatch(argTree, argName, depth + 1);
+      os << formatv("tblgen_ops[{0}] = {1};\n", ++opCounter, argName);
       os.unindent() << "}\n";
       continue;
     }
 
     // Next handle DAG leaf: operand or attribute
     if (opArg.is<NamedTypeConstraint *>()) {
-      emitOperandMatch(tree, i, depth);
+      emitOperandMatch(tree, castedName, i, depth);
     } else if (opArg.is<NamedAttribute *>()) {
-      emitAttributeMatch(tree, i, depth);
+      emitAttributeMatch(tree, opName, i, depth);
     } else {
       PrintFatalError(loc, "unhandled case when matching op");
     }
@@ -285,7 +388,8 @@ void PatternEmitter::emitOpMatch(DagNode tree, int depth) {
                           << '\n');
 }
 
-void PatternEmitter::emitOperandMatch(DagNode tree, int argIndex, int depth) {
+void PatternEmitter::emitOperandMatch(DagNode tree, StringRef opName,
+                                      int argIndex, int depth) {
   Operator &op = tree.getDialectOp(opMap);
   auto *operand = op.getArg(argIndex).get<NamedTypeConstraint *>();
   auto matcher = tree.getArgAsLeaf(argIndex);
@@ -309,11 +413,10 @@ void PatternEmitter::emitOperandMatch(DagNode tree, int argIndex, int depth) {
             op.getOperationName(), argIndex);
         PrintFatalError(loc, error);
       }
-      auto self =
-          formatv("(*castedOp{0}.getODSOperands({1}).begin()).getType()", depth,
-                  argIndex);
+      auto self = formatv("(*{0}.getODSOperands({1}).begin()).getType()",
+                          opName, argIndex);
       emitMatchCheck(
-          depth,
+          opName,
           tgfmt(constraint.getConditionTemplate(), &fmtCtx.withSelf(self)),
           formatv("\"operand {0} of op '{1}' failed to satisfy constraint: "
                   "'{2}'\"",
@@ -333,21 +436,22 @@ void PatternEmitter::emitOperandMatch(DagNode tree, int argIndex, int depth) {
         [](const Argument &arg) { return arg.is<NamedAttribute *>(); });
 
     auto res = symbolInfoMap.findBoundSymbol(name, op, argIndex);
-    os << formatv("{0} = castedOp{1}.getODSOperands({2});\n",
-                  res->second.getVarName(name), depth, argIndex - numPrevAttrs);
+    os << formatv("{0} = {1}.getODSOperands({2});\n",
+                  res->second.getVarName(name), opName,
+                  argIndex - numPrevAttrs);
   }
 }
 
-void PatternEmitter::emitAttributeMatch(DagNode tree, int argIndex, int depth) {
+void PatternEmitter::emitAttributeMatch(DagNode tree, StringRef opName,
+                                        int argIndex, int depth) {
   Operator &op = tree.getDialectOp(opMap);
   auto *namedAttr = op.getArg(argIndex).get<NamedAttribute *>();
   const auto &attr = namedAttr->attr;
 
   os << "{\n";
-  os.indent() << formatv(
-      "auto tblgen_attr = op{0}->getAttrOfType<{1}>(\"{2}\"); "
-      "(void)tblgen_attr;\n",
-      depth, attr.getStorageType(), namedAttr->name);
+  os.indent() << formatv("auto tblgen_attr = {0}->getAttrOfType<{1}>(\"{2}\");"
+                         "(void)tblgen_attr;\n",
+                         opName, attr.getStorageType(), namedAttr->name);
 
   // TODO: This should use getter method to avoid duplication.
   if (attr.hasDefaultValue()) {
@@ -360,7 +464,7 @@ void PatternEmitter::emitAttributeMatch(DagNode tree, int argIndex, int depth) {
     // should just capture a mlir::Attribute() to signal the missing state.
     // That is precisely what getAttr() returns on missing attributes.
   } else {
-    emitMatchCheck(depth, tgfmt("tblgen_attr", &fmtCtx),
+    emitMatchCheck(opName, tgfmt("tblgen_attr", &fmtCtx),
                    formatv("\"expected op '{0}' to have attribute '{1}' "
                            "of type '{2}'\"",
                            op.getOperationName(), namedAttr->name,
@@ -378,7 +482,7 @@ void PatternEmitter::emitAttributeMatch(DagNode tree, int argIndex, int depth) {
     // If a constraint is specified, we need to generate C++ statements to
     // check the constraint.
     emitMatchCheck(
-        depth,
+        opName,
         tgfmt(matcher.getConditionTemplate(), &fmtCtx.withSelf("tblgen_attr")),
         formatv("\"op '{0}' attribute '{1}' failed to satisfy constraint: "
                 "{2}\"",
@@ -397,24 +501,25 @@ void PatternEmitter::emitAttributeMatch(DagNode tree, int argIndex, int depth) {
 }
 
 void PatternEmitter::emitMatchCheck(
-    int depth, const FmtObjectBase &matchFmt,
+    StringRef opName, const FmtObjectBase &matchFmt,
     const llvm::formatv_object_base &failureFmt) {
-  emitMatchCheck(depth, matchFmt.str(), failureFmt.str());
+  emitMatchCheck(opName, matchFmt.str(), failureFmt.str());
 }
 
-void PatternEmitter::emitMatchCheck(int depth, const std::string &matchStr,
+void PatternEmitter::emitMatchCheck(StringRef opName,
+                                    const std::string &matchStr,
                                     const std::string &failureStr) {
+
   os << "if (!(" << matchStr << "))";
-  os.scope("{\n", "\n}\n").os
-      << "return rewriter.notifyMatchFailure(op" << depth
-      << ", [&](::mlir::Diagnostic &diag) {\n  diag << " << failureStr
-      << ";\n});";
+  os.scope("{\n", "\n}\n").os << "return rewriter.notifyMatchFailure(" << opName
+                              << ", [&](::mlir::Diagnostic &diag) {\n  diag << "
+                              << failureStr << ";\n});";
 }
 
-void PatternEmitter::emitMatchLogic(DagNode tree) {
+void PatternEmitter::emitMatchLogic(DagNode tree, StringRef opName) {
   LLVM_DEBUG(llvm::dbgs() << "--- start emitting match logic ---\n");
   int depth = 0;
-  emitOpMatch(tree, depth);
+  emitMatch(tree, opName, depth);
 
   for (auto &appliedConstraint : pattern.getConstraints()) {
     auto &constraint = appliedConstraint.constraint;
@@ -425,7 +530,7 @@ void PatternEmitter::emitMatchLogic(DagNode tree) {
       auto self = formatv("({0}.getType())",
                           symbolInfoMap.getValueAndRangeUse(entities.front()));
       emitMatchCheck(
-          depth, tgfmt(condition, &fmtCtx.withSelf(self.str())),
+          opName, tgfmt(condition, &fmtCtx.withSelf(self.str())),
           formatv("\"value entity '{0}' failed to satisfy constraint: {1}\"",
                   entities.front(), constraint.getDescription()));
 
@@ -447,7 +552,7 @@ void PatternEmitter::emitMatchLogic(DagNode tree) {
         self = symbolInfoMap.getValueAndRangeUse(self);
       for (; i < 4; ++i)
         names.push_back("<unused>");
-      emitMatchCheck(depth,
+      emitMatchCheck(opName,
                      tgfmt(condition, &fmtCtx.withSelf(self), names[0],
                            names[1], names[2], names[3]),
                      formatv("\"entities '{0}' failed to satisfy constraint: "
@@ -471,7 +576,7 @@ void PatternEmitter::emitMatchLogic(DagNode tree) {
     for (++startRange; startRange != endRange; ++startRange) {
       auto secondOperand = startRange->second.getVarName(symbolInfoIt->first);
       emitMatchCheck(
-          depth,
+          opName,
           formatv("*{0}.begin() == *{1}.begin()", firstOperand, secondOperand),
           formatv("\"Operands '{0}' and '{1}' must be equal\"", firstOperand,
                   secondOperand));
@@ -567,7 +672,7 @@ void PatternEmitter::emit(StringRef rewriteName) {
 
       os << "// Match\n";
       os << "tblgen_ops[0] = op0;\n";
-      emitMatchLogic(sourceTree);
+      emitMatchLogic(sourceTree, "op0");
 
       os << "\n// Rewrite\n";
       emitRewriteLogic();
@@ -681,7 +786,7 @@ std::string PatternEmitter::handleResultPattern(DagNode resultTree,
   }
 
   if (resultTree.isNativeCodeCall()) {
-    auto symbol = handleReplaceWithNativeCodeCall(resultTree);
+    auto symbol = handleReplaceWithNativeCodeCall(resultTree, depth);
     symbolInfoMap.bindValue(symbol);
     return symbol;
   }
@@ -798,7 +903,8 @@ std::string PatternEmitter::handleOpArgument(DagLeaf leaf,
   PrintFatalError(loc, "unhandled case when rewriting op");
 }
 
-std::string PatternEmitter::handleReplaceWithNativeCodeCall(DagNode tree) {
+std::string PatternEmitter::handleReplaceWithNativeCodeCall(DagNode tree,
+                                                            int depth) {
   LLVM_DEBUG(llvm::dbgs() << "handle NativeCodeCall pattern: ");
   LLVM_DEBUG(tree.print(llvm::dbgs()));
   LLVM_DEBUG(llvm::dbgs() << '\n');
@@ -807,15 +913,20 @@ std::string PatternEmitter::handleReplaceWithNativeCodeCall(DagNode tree) {
   // TODO: replace formatv arguments with the exact specified args.
   SmallVector<std::string, 8> attrs(8);
   if (tree.getNumArgs() > 8) {
-    PrintFatalError(loc, "unsupported NativeCodeCall argument numbers: " +
-                             Twine(tree.getNumArgs()));
+    PrintFatalError(loc,
+                    "unsupported NativeCodeCall replace argument numbers: " +
+                        Twine(tree.getNumArgs()));
   }
   bool hasLocationDirective;
   std::string locToUse;
   std::tie(hasLocationDirective, locToUse) = getLocation(tree);
 
   for (int i = 0, e = tree.getNumArgs() - hasLocationDirective; i != e; ++i) {
-    attrs[i] = handleOpArgument(tree.getArgAsLeaf(i), tree.getArgName(i));
+    if (tree.isNestedDagArg(i)) {
+      attrs[i] = handleResultPattern(tree.getArgAsNestedDag(i), i, depth + 1);
+    } else {
+      attrs[i] = handleOpArgument(tree.getArgAsLeaf(i), tree.getArgName(i));
+    }
     LLVM_DEBUG(llvm::dbgs() << "NativeCodeCall argument #" << i
                             << " replacement: " << attrs[i] << "\n");
   }
@@ -924,7 +1035,7 @@ std::string PatternEmitter::handleOpCreation(DagNode tree, int resultIndex,
     // create the ops.
 
     // First prepare local variables for op arguments used in builder call.
-    createAggregateLocalVarsForOpArgs(tree, childNodeNames);
+    createAggregateLocalVarsForOpArgs(tree, childNodeNames, depth);
 
     // Then create the op.
     os.scope("", "\n}\n").os << formatv(
@@ -948,7 +1059,7 @@ std::string PatternEmitter::handleOpCreation(DagNode tree, int resultIndex,
 
     os.scope().os << formatv("{0} = rewriter.create<{1}>({2}", valuePackName,
                              resultOp.getQualCppClassName(), locToUse);
-    supplyValuesForOpArgs(tree, childNodeNames);
+    supplyValuesForOpArgs(tree, childNodeNames, depth);
     os << "\n  );\n}\n";
     return resultValue;
   }
@@ -959,7 +1070,7 @@ std::string PatternEmitter::handleOpCreation(DagNode tree, int resultIndex,
   // here.
 
   // First prepare local variables for op arguments used in builder call.
-  createAggregateLocalVarsForOpArgs(tree, childNodeNames);
+  createAggregateLocalVarsForOpArgs(tree, childNodeNames, depth);
 
   // Then prepare the result types. We need to specify the types for all
   // results.
@@ -1037,7 +1148,7 @@ void PatternEmitter::createSeparateLocalVarsForOpArgs(
 }
 
 void PatternEmitter::supplyValuesForOpArgs(
-    DagNode node, const ChildNodeIndexNameMap &childNodeNames) {
+    DagNode node, const ChildNodeIndexNameMap &childNodeNames, int depth) {
   Operator &resultOp = node.getDialectOp(opMap);
   for (int argIndex = 0, numOpArgs = resultOp.getNumArgs();
        argIndex != numOpArgs; ++argIndex) {
@@ -1060,7 +1171,7 @@ void PatternEmitter::supplyValuesForOpArgs(
         PrintFatalError(loc, "only NativeCodeCall allowed in nested dag node "
                              "for creating attribute");
       os << formatv("/*{0}=*/{1}", opArgName,
-                    handleReplaceWithNativeCodeCall(subTree));
+                    handleReplaceWithNativeCodeCall(subTree, depth));
     } else {
       auto leaf = node.getArgAsLeaf(argIndex);
       // The argument in the result DAG pattern.
@@ -1080,7 +1191,7 @@ void PatternEmitter::supplyValuesForOpArgs(
 }
 
 void PatternEmitter::createAggregateLocalVarsForOpArgs(
-    DagNode node, const ChildNodeIndexNameMap &childNodeNames) {
+    DagNode node, const ChildNodeIndexNameMap &childNodeNames, int depth) {
   Operator &resultOp = node.getDialectOp(opMap);
 
   auto scope = os.scope();
@@ -1102,7 +1213,7 @@ void PatternEmitter::createAggregateLocalVarsForOpArgs(
           PrintFatalError(loc, "only NativeCodeCall allowed in nested dag node "
                                "for creating attribute");
         os << formatv(addAttrCmd, opArgName,
-                      handleReplaceWithNativeCodeCall(subTree));
+                      handleReplaceWithNativeCodeCall(subTree, depth + 1));
       } else {
         auto leaf = node.getArgAsLeaf(argIndex);
         // The argument in the result DAG pattern.


        


More information about the Mlir-commits mailing list