[Mlir-commits] [mlir] f3798ad - Static verifier for type/attribute in DRR

Chia-hung Duan llvmlistbot at llvm.org
Mon Nov 8 13:40:39 PST 2021


Author: Chia-hung Duan
Date: 2021-11-08T21:34:17Z
New Revision: f3798ad5fa845771846599f3c088016e3aef800c

URL: https://github.com/llvm/llvm-project/commit/f3798ad5fa845771846599f3c088016e3aef800c
DIFF: https://github.com/llvm/llvm-project/commit/f3798ad5fa845771846599f3c088016e3aef800c.diff

LOG: Static verifier for type/attribute in DRR

Generate static functions for matching type/attribute constraints, instead of
inlining each check, to reduce the memory footprint.

Reviewed By: Mogball

Differential Revision: https://reviews.llvm.org/D110199

Added: 
    

Modified: 
    mlir/include/mlir/TableGen/CodeGenHelpers.h
    mlir/include/mlir/TableGen/Pattern.h
    mlir/test/mlir-tblgen/rewriter-static-matcher.td
    mlir/tools/mlir-tblgen/CodeGenHelpers.cpp
    mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
    mlir/tools/mlir-tblgen/RewriterGen.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/TableGen/CodeGenHelpers.h b/mlir/include/mlir/TableGen/CodeGenHelpers.h
index c7252a6b99e3..2dd1cf64667f 100644
--- a/mlir/include/mlir/TableGen/CodeGenHelpers.h
+++ b/mlir/include/mlir/TableGen/CodeGenHelpers.h
@@ -15,6 +15,7 @@
 
 #include "mlir/Support/IndentedOstream.h"
 #include "mlir/TableGen/Dialect.h"
+#include "mlir/TableGen/Format.h"
 #include "llvm/ADT/DenseMap.h"
 #include "llvm/ADT/StringExtras.h"
 #include "llvm/ADT/StringRef.h"
@@ -91,8 +92,7 @@ class NamespaceEmitter {
 ///
 class StaticVerifierFunctionEmitter {
 public:
-  StaticVerifierFunctionEmitter(const llvm::RecordKeeper &records,
-                                raw_ostream &os);
+  StaticVerifierFunctionEmitter(const llvm::RecordKeeper &records);
 
   /// Emit the static verifier functions for `llvm::Record`s. The
   /// `signatureFormat` describes the required arguments and it must have a
@@ -112,30 +112,40 @@ class StaticVerifierFunctionEmitter {
   ///
   /// `typeArgName` is used to identify the argument that needs to check its
   /// type. The constraint template will replace `$_self` with it.
-  void emitFunctionsFor(StringRef signatureFormat, StringRef errorHandlerFormat,
-                        StringRef typeArgName, ArrayRef<llvm::Record *> opDefs,
-                        bool emitDecl);
+
+  /// This is the helper to generate the constraint functions from op
+  /// definitions.
+  void emitConstraintMethodsInNamespace(StringRef signatureFormat,
+                                        StringRef errorHandlerFormat,
+                                        StringRef cppNamespace,
+                                        ArrayRef<const void *> constraints,
+                                        raw_ostream &rawOs, bool emitDecl);
+
+  /// Emit the static functions for the given type constraints.
+  void emitConstraintMethods(StringRef signatureFormat,
+                             StringRef errorHandlerFormat,
+                             ArrayRef<const void *> constraints,
+                             raw_ostream &rawOs, bool emitDecl);
 
   /// Get the name of the local function used for the given type constraint.
   /// These functions are used for operand and result constraints and have the
   /// form:
   ///   LogicalResult(Operation *op, Type type, StringRef valueKind,
   ///                 unsigned valueGroupStartIndex);
-  StringRef getTypeConstraintFn(const Constraint &constraint) const;
+  StringRef getConstraintFn(const Constraint &constraint) const;
+
+  /// The setter to set `self` in format context.
+  StaticVerifierFunctionEmitter &setSelf(StringRef str);
+
+  /// The setter to set `builder` in format context.
+  StaticVerifierFunctionEmitter &setBuilder(StringRef str);
 
 private:
   /// Returns a unique name to use when generating local methods.
   static std::string getUniqueName(const llvm::RecordKeeper &records);
 
-  /// Emit local methods for the type constraints used within the provided op
-  /// definitions.
-  void emitTypeConstraintMethods(StringRef signatureFormat,
-                                 StringRef errorHandlerFormat,
-                                 StringRef typeArgName,
-                                 ArrayRef<llvm::Record *> opDefs,
-                                 bool emitDecl);
-
-  raw_indented_ostream os;
+  /// The format context used for building the verifier function.
+  FmtContext fctx;
 
   /// A unique label for the file currently being generated. This is used to
   /// ensure that the local functions have a unique name.

diff  --git a/mlir/include/mlir/TableGen/Pattern.h b/mlir/include/mlir/TableGen/Pattern.h
index fdc510447d1f..834ebdace12d 100644
--- a/mlir/include/mlir/TableGen/Pattern.h
+++ b/mlir/include/mlir/TableGen/Pattern.h
@@ -113,6 +113,9 @@ class DagLeaf {
   void print(raw_ostream &os) const;
 
 private:
+  friend llvm::DenseMapInfo<DagLeaf>;
+  const void *getAsOpaquePointer() const { return def; }
+
   // Returns true if the TableGen Init `def` in this DagLeaf is a DefInit and
   // also a subclass of the given `superclass`.
   bool isSubClassOf(StringRef superclass) const;
@@ -523,6 +526,24 @@ struct DenseMapInfo<mlir::tblgen::DagNode> {
     return lhs.node == rhs.node;
   }
 };
+
+template <>
+struct DenseMapInfo<mlir::tblgen::DagLeaf> {
+  static mlir::tblgen::DagLeaf getEmptyKey() {
+    return mlir::tblgen::DagLeaf(
+        llvm::DenseMapInfo<llvm::Init *>::getEmptyKey());
+  }
+  static mlir::tblgen::DagLeaf getTombstoneKey() {
+    return mlir::tblgen::DagLeaf(
+        llvm::DenseMapInfo<llvm::Init *>::getTombstoneKey());
+  }
+  static unsigned getHashValue(mlir::tblgen::DagLeaf leaf) {
+    return llvm::hash_value(leaf.getAsOpaquePointer());
+  }
+  static bool isEqual(mlir::tblgen::DagLeaf lhs, mlir::tblgen::DagLeaf rhs) {
+    return lhs.def == rhs.def;
+  }
+};
 } // end namespace llvm
 
 #endif // MLIR_TABLEGEN_PATTERN_H_

diff  --git a/mlir/test/mlir-tblgen/rewriter-static-matcher.td b/mlir/test/mlir-tblgen/rewriter-static-matcher.td
index cfd80a40fb1c..2e84e3476235 100644
--- a/mlir/test/mlir-tblgen/rewriter-static-matcher.td
+++ b/mlir/test/mlir-tblgen/rewriter-static-matcher.td
@@ -37,12 +37,16 @@ def COp : NS_Op<"c_op", []> {
 // Test static matcher for duplicate DagNode
 // ---
 
-// CHECK: static ::mlir::LogicalResult static_dag_matcher_0
+// CHECK-DAG: static ::mlir::LogicalResult [[$TYPE_CONSTRAINT:__mlir_ods_local_type_constraint.*]]({{.*::mlir::Type typeOrAttr}}
+// CHECK-DAG: static ::mlir::LogicalResult [[$ATTR_CONSTRAINT:__mlir_ods_local_type_constraint.*]]({{.*::mlir::Attribute}}
+// CHECK-DAG: static ::mlir::LogicalResult [[$DAG_MATCHER:static_dag_matcher.*]](
+// CHECK: if(failed([[$TYPE_CONSTRAINT]]
+// CHECK: if(failed([[$ATTR_CONSTRAINT]]
 
-// CHECK: if(failed(static_dag_matcher_0(rewriter, op1, tblgen_ops
+// CHECK: if(failed([[$DAG_MATCHER]](rewriter, op1, tblgen_ops
 def : Pat<(AOp (BOp I32Attr:$attr, I32:$int)),
           (AOp $int)>;
 
-// CHECK: if(failed(static_dag_matcher_0(rewriter, op1, tblgen_ops
+// CHECK: if(failed([[$DAG_MATCHER]](rewriter, op1, tblgen_ops
 def : Pat<(COp $_, (BOp I32Attr:$attr, I32:$int)),
           (COp $attr, $int)>;

diff  --git a/mlir/tools/mlir-tblgen/CodeGenHelpers.cpp b/mlir/tools/mlir-tblgen/CodeGenHelpers.cpp
index 1528c56f906e..6aab5abf0d6e 100644
--- a/mlir/tools/mlir-tblgen/CodeGenHelpers.cpp
+++ b/mlir/tools/mlir-tblgen/CodeGenHelpers.cpp
@@ -12,7 +12,6 @@
 //===----------------------------------------------------------------------===//
 
 #include "mlir/TableGen/CodeGenHelpers.h"
-#include "mlir/TableGen/Format.h"
 #include "mlir/TableGen/Operator.h"
 #include "llvm/ADT/SetVector.h"
 #include "llvm/Support/FormatVariadic.h"
@@ -24,21 +23,34 @@ using namespace mlir;
 using namespace mlir::tblgen;
 
 StaticVerifierFunctionEmitter::StaticVerifierFunctionEmitter(
-    const llvm::RecordKeeper &records, raw_ostream &os)
-    : os(os), uniqueOutputLabel(getUniqueName(records)) {}
+    const llvm::RecordKeeper &records)
+    : uniqueOutputLabel(getUniqueName(records)) {}
 
-void StaticVerifierFunctionEmitter::emitFunctionsFor(
+StaticVerifierFunctionEmitter &
+StaticVerifierFunctionEmitter::setSelf(StringRef str) {
+  fctx.withSelf(str);
+  return *this;
+}
+
+StaticVerifierFunctionEmitter &
+StaticVerifierFunctionEmitter::setBuilder(StringRef str) {
+  fctx.withBuilder(str);
+  return *this;
+}
+
+void StaticVerifierFunctionEmitter::emitConstraintMethodsInNamespace(
     StringRef signatureFormat, StringRef errorHandlerFormat,
-    StringRef typeArgName, ArrayRef<llvm::Record *> opDefs, bool emitDecl) {
+    StringRef cppNamespace, ArrayRef<const void *> constraints, raw_ostream &os,
+    bool emitDecl) {
   llvm::Optional<NamespaceEmitter> namespaceEmitter;
   if (!emitDecl)
-    namespaceEmitter.emplace(os, Operator(*opDefs[0]).getCppNamespace());
+    namespaceEmitter.emplace(os, cppNamespace);
 
-  emitTypeConstraintMethods(signatureFormat, errorHandlerFormat, typeArgName,
-                            opDefs, emitDecl);
+  emitConstraintMethods(signatureFormat, errorHandlerFormat, constraints, os,
+                        emitDecl);
 }
 
-StringRef StaticVerifierFunctionEmitter::getTypeConstraintFn(
+StringRef StaticVerifierFunctionEmitter::getConstraintFn(
     const Constraint &constraint) const {
   auto it = localTypeConstraints.find(constraint.getAsOpaquePointer());
   assert(it != localTypeConstraints.end() && "expected valid constraint fn");
@@ -65,28 +77,16 @@ std::string StaticVerifierFunctionEmitter::getUniqueName(
   return uniqueName;
 }
 
-void StaticVerifierFunctionEmitter::emitTypeConstraintMethods(
+void StaticVerifierFunctionEmitter::emitConstraintMethods(
     StringRef signatureFormat, StringRef errorHandlerFormat,
-    StringRef typeArgName, ArrayRef<llvm::Record *> opDefs, bool emitDecl) {
-  // Collect a set of all of the used type constraints within the operation
-  // definitions.
-  llvm::SetVector<const void *> typeConstraints;
-  for (Record *def : opDefs) {
-    Operator op(*def);
-    for (NamedTypeConstraint &operand : op.getOperands())
-      if (operand.hasPredicate())
-        typeConstraints.insert(operand.constraint.getAsOpaquePointer());
-    for (NamedTypeConstraint &result : op.getResults())
-      if (result.hasPredicate())
-        typeConstraints.insert(result.constraint.getAsOpaquePointer());
-  }
+    ArrayRef<const void *> constraints, raw_ostream &rawOs, bool emitDecl) {
+  raw_indented_ostream os(rawOs);
 
   // Record the mapping from predicate to constraint. If two constraints has the
   // same predicate and constraint summary, they can share the same verification
   // function.
   llvm::DenseMap<Pred, const void *> predToConstraint;
-  FmtContext fctx;
-  for (auto it : llvm::enumerate(typeConstraints)) {
+  for (auto it : llvm::enumerate(constraints)) {
     std::string name;
     Constraint constraint = Constraint::getFromOpaquePointer(it.value());
     Pred pred = constraint.getPredicate();
@@ -101,7 +101,7 @@ void StaticVerifierFunctionEmitter::emitTypeConstraintMethods(
         // summary, otherwise we may report the wrong message while verification
         // fails.
         if (constraint.getSummary() == built.getSummary()) {
-          name = getTypeConstraintFn(built).str();
+          name = getConstraintFn(built).str();
           break;
         }
         ++iter;
@@ -126,12 +126,11 @@ void StaticVerifierFunctionEmitter::emitTypeConstraintMethods(
       continue;
 
     os << formatv(signatureFormat.data(), name) << " {\n";
-    os.indent() << "if (!("
-                << tgfmt(constraint.getConditionTemplate(),
-                         &fctx.withSelf(typeArgName))
+    os.indent() << "if (!(" << tgfmt(constraint.getConditionTemplate(), &fctx)
                 << ")) {\n";
     os.indent() << "return "
-                << formatv(errorHandlerFormat.data(), constraint.getSummary())
+                << formatv(errorHandlerFormat.data(),
+                           escapeString(constraint.getSummary()))
                 << ";\n";
     os.unindent() << "}\nreturn ::mlir::success();\n";
     os.unindent() << "}\n\n";

diff  --git a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
index 6e88480c2fb1..73e5ef951a9a 100644
--- a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
+++ b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
@@ -2233,7 +2233,7 @@ void OpEmitter::genOperandResultVerifier(OpMethodBody &body,
       continue;
     // Emit a loop to check all the dynamic values in the pack.
     StringRef constraintFn =
-        staticVerifierEmitter.getTypeConstraintFn(value.constraint);
+        staticVerifierEmitter.getConstraintFn(value.constraint);
     body << "    for (::mlir::Value v : valueGroup" << staticValue.index()
          << ") {\n"
          << "      if (::mlir::failed(" << constraintFn
@@ -2639,11 +2639,27 @@ static void emitOpClasses(const RecordKeeper &recordKeeper,
     return;
 
   // Generate all of the locally instantiated methods first.
-  StaticVerifierFunctionEmitter staticVerifierEmitter(recordKeeper, os);
+  StaticVerifierFunctionEmitter staticVerifierEmitter(recordKeeper);
   os << formatv(opCommentHeader, "Local Utility Method", "Definitions");
-  staticVerifierEmitter.emitFunctionsFor(
-      typeVerifierSignature, typeVerifierErrorHandler, /*typeArgName=*/"type",
-      defs, emitDecl);
+  staticVerifierEmitter.setSelf("type");
+
+  // Collect a set of all of the used type constraints within the operation
+  // definitions.
+  llvm::SetVector<const void *> typeConstraints;
+  for (Record *def : defs) {
+    Operator op(*def);
+    for (NamedTypeConstraint &operand : op.getOperands())
+      if (operand.hasPredicate())
+        typeConstraints.insert(operand.constraint.getAsOpaquePointer());
+    for (NamedTypeConstraint &result : op.getResults())
+      if (result.hasPredicate())
+        typeConstraints.insert(result.constraint.getAsOpaquePointer());
+  }
+
+  staticVerifierEmitter.emitConstraintMethodsInNamespace(
+      typeVerifierSignature, typeVerifierErrorHandler,
+      Operator(*defs[0]).getCppNamespace(), typeConstraints.getArrayRef(), os,
+      emitDecl);
 
   for (auto *def : defs) {
     Operator op(*def);

diff  --git a/mlir/tools/mlir-tblgen/RewriterGen.cpp b/mlir/tools/mlir-tblgen/RewriterGen.cpp
index 2d2694940850..318f006cdf45 100644
--- a/mlir/tools/mlir-tblgen/RewriterGen.cpp
+++ b/mlir/tools/mlir-tblgen/RewriterGen.cpp
@@ -20,6 +20,7 @@
 #include "mlir/TableGen/Predicate.h"
 #include "mlir/TableGen/Type.h"
 #include "llvm/ADT/FunctionExtras.h"
+#include "llvm/ADT/SetVector.h"
 #include "llvm/ADT/StringExtras.h"
 #include "llvm/ADT/StringSet.h"
 #include "llvm/Support/CommandLine.h"
@@ -41,6 +42,23 @@ using llvm::RecordKeeper;
 
 #define DEBUG_TYPE "mlir-tblgen-rewritergen"
 
+// The signature of static type verification function
+static const char *typeVerifierSignature =
+    "static ::mlir::LogicalResult {0}(::mlir::PatternRewriter &rewriter, "
+    "::mlir::Operation *op, ::mlir::Type typeOrAttr, "
+    "::llvm::StringRef failureStr)";
+
+// The signature of static attribute verification function
+static const char *attrVerifierSignature =
+    "static ::mlir::LogicalResult {0}(::mlir::PatternRewriter &rewriter, "
+    "::mlir::Operation *op, ::mlir::Attribute typeOrAttr, "
+    "::llvm::StringRef failureStr)";
+
+// The template of error handler in static type/attribute verification function
+static const char *verifierErrorHandler =
+    "rewriter.notifyMatchFailure(op, [&](::mlir::Diagnostic &diag) {\n  diag "
+    "<< failureStr <<  \": {0}\";\n});";
+
 namespace llvm {
 template <>
 struct format_provider<mlir::tblgen::Pattern::IdentifierLine> {
@@ -87,6 +105,10 @@ class PatternEmitter {
   // Emit C++ function call to static DAG matcher.
   void emitStaticMatchCall(DagNode tree, StringRef name);
 
+  // Emit C++ function call to static type/attribute constraint function.
+  void emitStaticVerifierCall(StringRef funcName, StringRef opName,
+                              StringRef arg, StringRef failureStr);
+
   // Emits C++ statements for matching using a native code call.
   void emitNativeCodeMatch(DagNode tree, StringRef name, int depth);
 
@@ -244,7 +266,8 @@ class PatternEmitter {
 // inlining them.
 class StaticMatcherHelper {
 public:
-  StaticMatcherHelper(RecordOperatorMap &mapper);
+  StaticMatcherHelper(const RecordKeeper &recordKeeper,
+                      RecordOperatorMap &mapper);
 
   // Determine if we should inline the match logic or delegate to a static
   // function.
@@ -258,6 +281,9 @@ class StaticMatcherHelper {
     return matcherNames[node];
   }
 
+  // Get the name of static type/attribute verification function.
+  StringRef getVerifierName(Constraint constraint);
+
   // Collect the `Record`s, i.e., the DRR, so that we can get the information of
   // the duplicated DAGs.
   void addPattern(Record *record);
@@ -265,6 +291,9 @@ class StaticMatcherHelper {
   // Emit all static functions of DAG Matcher.
   void populateStaticMatchers(raw_ostream &os);
 
+  // Emit all static functions for Constraints.
+  void populateStaticConstraintFunctions(raw_ostream &os);
+
 private:
   static constexpr unsigned kStaticMatcherThreshold = 1;
 
@@ -301,6 +330,12 @@ class StaticMatcherHelper {
   // Number of static matcher generated. This is used to generate a unique name
   // for each DagNode.
   int staticMatcherCounter = 0;
+
+  // The DagLeaf which contains type or attr constraint.
+  DenseSet<DagLeaf> constraints;
+
+  // Static type/attribute verification function emitter.
+  StaticVerifierFunctionEmitter staticVerifierEmitter;
 };
 
 } // end anonymous namespace
@@ -395,6 +430,15 @@ void PatternEmitter::emitStaticMatchCall(DagNode tree, StringRef opName) {
   os << "}\n";
 }
 
+void PatternEmitter::emitStaticVerifierCall(StringRef funcName,
+                                            StringRef opName, StringRef arg,
+                                            StringRef failureStr) {
+  os << formatv("if(failed({0}(rewriter, {1}, {2}, {3}))) {{\n", funcName,
+                opName, arg, failureStr);
+  os.scope().os << "return ::mlir::failure();\n";
+  os << "}\n";
+}
+
 // Helper function to match patterns.
 void PatternEmitter::emitNativeCodeMatch(DagNode tree, StringRef opName,
                                          int depth) {
@@ -487,14 +531,15 @@ void PatternEmitter::emitNativeCodeMatch(DagNode tree, StringRef opName,
       self = argName;
     else
       self = formatv("{0}.getType()", argName);
-    emitMatchCheck(
-        opName,
-        tgfmt(constraint.getConditionTemplate(), &fmtCtx.withSelf(self)),
+    StringRef verifier = staticMatcherHelper.getVerifierName(constraint);
+    emitStaticVerifierCall(
+        verifier, opName, self,
         formatv("\"operand {0} of native code call '{1}' failed to satisfy "
                 "constraint: "
                 "'{2}'\"",
                 i, tree.getNativeCodeTemplate(),
-                escapeString(constraint.getSummary())));
+                escapeString(constraint.getSummary()))
+            .str());
   }
 
   LLVM_DEBUG(llvm::dbgs() << "done emitting match for native code call\n");
@@ -626,13 +671,14 @@ void PatternEmitter::emitOperandMatch(DagNode tree, StringRef opName,
       }
       auto self = formatv("(*{0}.getODSOperands({1}).begin()).getType()",
                           opName, operandIndex);
-      emitMatchCheck(
-          opName,
-          tgfmt(constraint.getConditionTemplate(), &fmtCtx.withSelf(self)),
-          formatv("\"operand {0} of op '{1}' failed to satisfy constraint: "
-                  "'{2}'\"",
-                  operand - op.operand_begin(), op.getOperationName(),
-                  escapeString(constraint.getSummary())));
+      StringRef verifier = staticMatcherHelper.getVerifierName(constraint);
+      emitStaticVerifierCall(
+          verifier, opName, self.str(),
+          formatv(
+              "\"operand {0} of op '{1}' failed to satisfy constraint: '{2}'\"",
+              operandIndex, op.getOperationName(),
+              escapeString(constraint.getSummary()))
+              .str());
     }
   }
 
@@ -690,15 +736,17 @@ void PatternEmitter::emitAttributeMatch(DagNode tree, StringRef opName,
                        op.getOperationName(), argIndex + 1));
     }
 
-    // If a constraint is specified, we need to generate C++ statements to
-    // check the constraint.
-    emitMatchCheck(
-        opName,
-        tgfmt(matcher.getConditionTemplate(), &fmtCtx.withSelf("tblgen_attr")),
+    // If a constraint is specified, we need to generate function call to its
+    // static verifier.
+    StringRef verifier =
+        staticMatcherHelper.getVerifierName(matcher.getAsConstraint());
+    emitStaticVerifierCall(
+        verifier, opName, "tblgen_attr",
         formatv("\"op '{0}' attribute '{1}' failed to satisfy constraint: "
                 "'{2}'\"",
                 op.getOperationName(), namedAttr->name,
-                escapeString(matcher.getAsConstraint().getSummary())));
+                escapeString(matcher.getAsConstraint().getSummary()))
+            .str());
   }
 
   // Capture the value
@@ -1571,8 +1619,9 @@ void PatternEmitter::createAggregateLocalVarsForOpArgs(
   }
 }
 
-StaticMatcherHelper::StaticMatcherHelper(RecordOperatorMap &mapper)
-    : opMap(mapper) {}
+StaticMatcherHelper::StaticMatcherHelper(const RecordKeeper &recordKeeper,
+                                         RecordOperatorMap &mapper)
+    : opMap(mapper), staticVerifierEmitter(recordKeeper) {}
 
 void StaticMatcherHelper::populateStaticMatchers(raw_ostream &os) {
   // PatternEmitter will use the static matcher if there's one generated. To
@@ -1592,6 +1641,31 @@ void StaticMatcherHelper::populateStaticMatchers(raw_ostream &os) {
   }
 }
 
+void StaticMatcherHelper::populateStaticConstraintFunctions(raw_ostream &os) {
+  llvm::SetVector<const void *> typeConstraints;
+  llvm::SetVector<const void *> attrConstraints;
+  for (DagLeaf leaf : constraints) {
+    if (leaf.isOperandMatcher()) {
+      typeConstraints.insert(leaf.getAsConstraint().getAsOpaquePointer());
+    } else {
+      assert(leaf.isAttrMatcher());
+      attrConstraints.insert(leaf.getAsConstraint().getAsOpaquePointer());
+    }
+  }
+
+  staticVerifierEmitter.setBuilder("rewriter").setSelf("typeOrAttr");
+
+  staticVerifierEmitter.emitConstraintMethods(typeVerifierSignature,
+                                              verifierErrorHandler,
+                                              typeConstraints.getArrayRef(), os,
+                                              /*emitDecl=*/false);
+
+  staticVerifierEmitter.emitConstraintMethods(attrVerifierSignature,
+                                              verifierErrorHandler,
+                                              attrConstraints.getArrayRef(), os,
+                                              /*emitDecl=*/false);
+}
+
 void StaticMatcherHelper::addPattern(Record *record) {
   Pattern pat(record, &opMap);
 
@@ -1608,6 +1682,11 @@ void StaticMatcherHelper::addPattern(Record *record) {
     for (unsigned i = 0, e = node.getNumArgs(); i < e; ++i)
       if (DagNode sibling = node.getArgAsNestedDag(i))
         dfs(sibling);
+      else {
+        DagLeaf leaf = node.getArgAsLeaf(i);
+        if (!leaf.isUnspecified())
+          constraints.insert(leaf);
+      }
 
     topologicalOrder.push_back(std::make_pair(node, record));
   };
@@ -1615,6 +1694,10 @@ void StaticMatcherHelper::addPattern(Record *record) {
   dfs(pat.getSourcePattern());
 }
 
+StringRef StaticMatcherHelper::getVerifierName(Constraint constraint) {
+  return staticVerifierEmitter.getConstraintFn(constraint);
+}
+
 static void emitRewriters(const RecordKeeper &recordKeeper, raw_ostream &os) {
   emitSourceFileHeader("Rewriters", os);
 
@@ -1625,9 +1708,10 @@ static void emitRewriters(const RecordKeeper &recordKeeper, raw_ostream &os) {
 
   // Exam all the patterns and generate static matcher for the duplicated
   // DagNode.
-  StaticMatcherHelper staticMatcher(recordOpMap);
+  StaticMatcherHelper staticMatcher(recordKeeper, recordOpMap);
   for (Record *p : patterns)
     staticMatcher.addPattern(p);
+  staticMatcher.populateStaticConstraintFunctions(os);
   staticMatcher.populateStaticMatchers(os);
 
   std::vector<std::string> rewriterNames;


        


More information about the Mlir-commits mailing list