[Mlir-commits] [mlir] f3798ad - Static verifier for type/attribute in DRR
Chia-hung Duan
llvmlistbot at llvm.org
Mon Nov 8 13:40:39 PST 2021
Author: Chia-hung Duan
Date: 2021-11-08T21:34:17Z
New Revision: f3798ad5fa845771846599f3c088016e3aef800c
URL: https://github.com/llvm/llvm-project/commit/f3798ad5fa845771846599f3c088016e3aef800c
DIFF: https://github.com/llvm/llvm-project/commit/f3798ad5fa845771846599f3c088016e3aef800c.diff
LOG: Static verifier for type/attribute in DRR
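Generate static functions for matching type/attribute constraints in DRR, so
each constraint check is emitted once and shared instead of being inlined at
every match site; this reduces the memory footprint.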
Reviewed By: Mogball
Differential Revision: https://reviews.llvm.org/D110199
Added:
Modified:
mlir/include/mlir/TableGen/CodeGenHelpers.h
mlir/include/mlir/TableGen/Pattern.h
mlir/test/mlir-tblgen/rewriter-static-matcher.td
mlir/tools/mlir-tblgen/CodeGenHelpers.cpp
mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
mlir/tools/mlir-tblgen/RewriterGen.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/TableGen/CodeGenHelpers.h b/mlir/include/mlir/TableGen/CodeGenHelpers.h
index c7252a6b99e3..2dd1cf64667f 100644
--- a/mlir/include/mlir/TableGen/CodeGenHelpers.h
+++ b/mlir/include/mlir/TableGen/CodeGenHelpers.h
@@ -15,6 +15,7 @@
#include "mlir/Support/IndentedOstream.h"
#include "mlir/TableGen/Dialect.h"
+#include "mlir/TableGen/Format.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/StringExtras.h"
#include "llvm/ADT/StringRef.h"
@@ -91,8 +92,7 @@ class NamespaceEmitter {
///
class StaticVerifierFunctionEmitter {
public:
- StaticVerifierFunctionEmitter(const llvm::RecordKeeper &records,
- raw_ostream &os);
+ StaticVerifierFunctionEmitter(const llvm::RecordKeeper &records);
/// Emit the static verifier functions for `llvm::Record`s. The
/// `signatureFormat` describes the required arguments and it must have a
@@ -112,30 +112,40 @@ class StaticVerifierFunctionEmitter {
///
/// `typeArgName` is used to identify the argument that needs to check its
/// type. The constraint template will replace `$_self` with it.
- void emitFunctionsFor(StringRef signatureFormat, StringRef errorHandlerFormat,
- StringRef typeArgName, ArrayRef<llvm::Record *> opDefs,
- bool emitDecl);
+
+ /// Emit the constraint functions, wrapped in C++ namespace `cppNamespace`
+ /// unless `emitDecl` is set.
+ void emitConstraintMethodsInNamespace(StringRef signatureFormat,
+ StringRef errorHandlerFormat,
+ StringRef cppNamespace,
+ ArrayRef<const void *> constraints,
+ raw_ostream &rawOs, bool emitDecl);
+
+ /// Emit the static functions for the given type or attribute constraints.
+ void emitConstraintMethods(StringRef signatureFormat,
+ StringRef errorHandlerFormat,
+ ArrayRef<const void *> constraints,
+ raw_ostream &rawOs, bool emitDecl);
/// Get the name of the local function used for the given type constraint.
/// These functions are used for operand and result constraints and have the
/// form:
/// LogicalResult(Operation *op, Type type, StringRef valueKind,
/// unsigned valueGroupStartIndex);
- StringRef getTypeConstraintFn(const Constraint &constraint) const;
+ StringRef getConstraintFn(const Constraint &constraint) const;
+
+ /// The setter to set `self` in format context.
+ StaticVerifierFunctionEmitter &setSelf(StringRef str);
+
+ /// The setter to set `builder` in format context.
+ StaticVerifierFunctionEmitter &setBuilder(StringRef str);
private:
/// Returns a unique name to use when generating local methods.
static std::string getUniqueName(const llvm::RecordKeeper &records);
- /// Emit local methods for the type constraints used within the provided op
- /// definitions.
- void emitTypeConstraintMethods(StringRef signatureFormat,
- StringRef errorHandlerFormat,
- StringRef typeArgName,
- ArrayRef<llvm::Record *> opDefs,
- bool emitDecl);
-
- raw_indented_ostream os;
+ /// The format context used for building the verifier function.
+ FmtContext fctx;
/// A unique label for the file currently being generated. This is used to
/// ensure that the local functions have a unique name.
diff --git a/mlir/include/mlir/TableGen/Pattern.h b/mlir/include/mlir/TableGen/Pattern.h
index fdc510447d1f..834ebdace12d 100644
--- a/mlir/include/mlir/TableGen/Pattern.h
+++ b/mlir/include/mlir/TableGen/Pattern.h
@@ -113,6 +113,9 @@ class DagLeaf {
void print(raw_ostream &os) const;
private:
+ friend llvm::DenseMapInfo<DagLeaf>;
+ const void *getAsOpaquePointer() const { return def; }
+
// Returns true if the TableGen Init `def` in this DagLeaf is a DefInit and
// also a subclass of the given `superclass`.
bool isSubClassOf(StringRef superclass) const;
@@ -523,6 +526,24 @@ struct DenseMapInfo<mlir::tblgen::DagNode> {
return lhs.node == rhs.node;
}
};
+// Lets DagLeaf be a DenseMap/DenseSet key; leaves compare by their `Init *`.
+template <>
+struct DenseMapInfo<mlir::tblgen::DagLeaf> {
+ static mlir::tblgen::DagLeaf getEmptyKey() {
+ return mlir::tblgen::DagLeaf(
+ llvm::DenseMapInfo<llvm::Init *>::getEmptyKey());
+ }
+ static mlir::tblgen::DagLeaf getTombstoneKey() {
+ return mlir::tblgen::DagLeaf(
+ llvm::DenseMapInfo<llvm::Init *>::getTombstoneKey());
+ }
+ static unsigned getHashValue(mlir::tblgen::DagLeaf leaf) {
+ return llvm::hash_value(leaf.getAsOpaquePointer());
+ }
+ static bool isEqual(mlir::tblgen::DagLeaf lhs, mlir::tblgen::DagLeaf rhs) {
+ return lhs.def == rhs.def;
+ }
+};
} // end namespace llvm
#endif // MLIR_TABLEGEN_PATTERN_H_
diff --git a/mlir/test/mlir-tblgen/rewriter-static-matcher.td b/mlir/test/mlir-tblgen/rewriter-static-matcher.td
index cfd80a40fb1c..2e84e3476235 100644
--- a/mlir/test/mlir-tblgen/rewriter-static-matcher.td
+++ b/mlir/test/mlir-tblgen/rewriter-static-matcher.td
@@ -37,12 +37,16 @@ def COp : NS_Op<"c_op", []> {
// Test static matcher for duplicate DagNode
// ---
-// CHECK: static ::mlir::LogicalResult static_dag_matcher_0
+// CHECK-DAG: static ::mlir::LogicalResult [[$TYPE_CONSTRAINT:__mlir_ods_local_type_constraint.*]]({{.*::mlir::Type typeOrAttr}}
+// CHECK-DAG: static ::mlir::LogicalResult [[$ATTR_CONSTRAINT:__mlir_ods_local_type_constraint.*]]({{.*::mlir::Attribute}}
+// CHECK-DAG: static ::mlir::LogicalResult [[$DAG_MATCHER:static_dag_matcher.*]](
+// CHECK: if(failed([[$TYPE_CONSTRAINT]]
+// CHECK: if(failed([[$ATTR_CONSTRAINT]]
-// CHECK: if(failed(static_dag_matcher_0(rewriter, op1, tblgen_ops
+// CHECK: if(failed([[$DAG_MATCHER]](rewriter, op1, tblgen_ops
def : Pat<(AOp (BOp I32Attr:$attr, I32:$int)),
(AOp $int)>;
-// CHECK: if(failed(static_dag_matcher_0(rewriter, op1, tblgen_ops
+// CHECK: if(failed([[$DAG_MATCHER]](rewriter, op1, tblgen_ops
def : Pat<(COp $_, (BOp I32Attr:$attr, I32:$int)),
(COp $attr, $int)>;
diff --git a/mlir/tools/mlir-tblgen/CodeGenHelpers.cpp b/mlir/tools/mlir-tblgen/CodeGenHelpers.cpp
index 1528c56f906e..6aab5abf0d6e 100644
--- a/mlir/tools/mlir-tblgen/CodeGenHelpers.cpp
+++ b/mlir/tools/mlir-tblgen/CodeGenHelpers.cpp
@@ -12,7 +12,6 @@
//===----------------------------------------------------------------------===//
#include "mlir/TableGen/CodeGenHelpers.h"
-#include "mlir/TableGen/Format.h"
#include "mlir/TableGen/Operator.h"
#include "llvm/ADT/SetVector.h"
#include "llvm/Support/FormatVariadic.h"
@@ -24,21 +23,34 @@ using namespace mlir;
using namespace mlir::tblgen;
StaticVerifierFunctionEmitter::StaticVerifierFunctionEmitter(
- const llvm::RecordKeeper &records, raw_ostream &os)
- : os(os), uniqueOutputLabel(getUniqueName(records)) {}
+ const llvm::RecordKeeper &records)
+ : uniqueOutputLabel(getUniqueName(records)) {}
-void StaticVerifierFunctionEmitter::emitFunctionsFor(
+StaticVerifierFunctionEmitter &
+StaticVerifierFunctionEmitter::setSelf(StringRef str) {
+ fctx.withSelf(str);
+ return *this;
+}
+
+StaticVerifierFunctionEmitter &
+StaticVerifierFunctionEmitter::setBuilder(StringRef str) {
+ fctx.withBuilder(str);
+ return *this;
+}
+
+void StaticVerifierFunctionEmitter::emitConstraintMethodsInNamespace(
StringRef signatureFormat, StringRef errorHandlerFormat,
- StringRef typeArgName, ArrayRef<llvm::Record *> opDefs, bool emitDecl) {
+ StringRef cppNamespace, ArrayRef<const void *> constraints, raw_ostream &os,
+ bool emitDecl) {
llvm::Optional<NamespaceEmitter> namespaceEmitter;
if (!emitDecl)
- namespaceEmitter.emplace(os, Operator(*opDefs[0]).getCppNamespace());
+ namespaceEmitter.emplace(os, cppNamespace);
- emitTypeConstraintMethods(signatureFormat, errorHandlerFormat, typeArgName,
- opDefs, emitDecl);
+ emitConstraintMethods(signatureFormat, errorHandlerFormat, constraints, os,
+ emitDecl);
}
-StringRef StaticVerifierFunctionEmitter::getTypeConstraintFn(
+StringRef StaticVerifierFunctionEmitter::getConstraintFn(
const Constraint &constraint) const {
auto it = localTypeConstraints.find(constraint.getAsOpaquePointer());
assert(it != localTypeConstraints.end() && "expected valid constraint fn");
@@ -65,28 +77,16 @@ std::string StaticVerifierFunctionEmitter::getUniqueName(
return uniqueName;
}
-void StaticVerifierFunctionEmitter::emitTypeConstraintMethods(
+void StaticVerifierFunctionEmitter::emitConstraintMethods(
StringRef signatureFormat, StringRef errorHandlerFormat,
- StringRef typeArgName, ArrayRef<llvm::Record *> opDefs, bool emitDecl) {
- // Collect a set of all of the used type constraints within the operation
- // definitions.
- llvm::SetVector<const void *> typeConstraints;
- for (Record *def : opDefs) {
- Operator op(*def);
- for (NamedTypeConstraint &operand : op.getOperands())
- if (operand.hasPredicate())
- typeConstraints.insert(operand.constraint.getAsOpaquePointer());
- for (NamedTypeConstraint &result : op.getResults())
- if (result.hasPredicate())
- typeConstraints.insert(result.constraint.getAsOpaquePointer());
- }
+ ArrayRef<const void *> constraints, raw_ostream &rawOs, bool emitDecl) {
+ raw_indented_ostream os(rawOs);
// Record the mapping from predicate to constraint. If two constraints has the
// same predicate and constraint summary, they can share the same verification
// function.
llvm::DenseMap<Pred, const void *> predToConstraint;
- FmtContext fctx;
- for (auto it : llvm::enumerate(typeConstraints)) {
+ for (auto it : llvm::enumerate(constraints)) {
std::string name;
Constraint constraint = Constraint::getFromOpaquePointer(it.value());
Pred pred = constraint.getPredicate();
@@ -101,7 +101,7 @@ void StaticVerifierFunctionEmitter::emitTypeConstraintMethods(
// summary, otherwise we may report the wrong message while verification
// fails.
if (constraint.getSummary() == built.getSummary()) {
- name = getTypeConstraintFn(built).str();
+ name = getConstraintFn(built).str();
break;
}
++iter;
@@ -126,12 +126,11 @@ void StaticVerifierFunctionEmitter::emitTypeConstraintMethods(
continue;
os << formatv(signatureFormat.data(), name) << " {\n";
- os.indent() << "if (!("
- << tgfmt(constraint.getConditionTemplate(),
- &fctx.withSelf(typeArgName))
+ os.indent() << "if (!(" << tgfmt(constraint.getConditionTemplate(), &fctx)
<< ")) {\n";
os.indent() << "return "
- << formatv(errorHandlerFormat.data(), constraint.getSummary())
+ << formatv(errorHandlerFormat.data(),
+ escapeString(constraint.getSummary()))
<< ";\n";
os.unindent() << "}\nreturn ::mlir::success();\n";
os.unindent() << "}\n\n";
diff --git a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
index 6e88480c2fb1..73e5ef951a9a 100644
--- a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
+++ b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
@@ -2233,7 +2233,7 @@ void OpEmitter::genOperandResultVerifier(OpMethodBody &body,
continue;
// Emit a loop to check all the dynamic values in the pack.
StringRef constraintFn =
- staticVerifierEmitter.getTypeConstraintFn(value.constraint);
+ staticVerifierEmitter.getConstraintFn(value.constraint);
body << " for (::mlir::Value v : valueGroup" << staticValue.index()
<< ") {\n"
<< " if (::mlir::failed(" << constraintFn
@@ -2639,11 +2639,27 @@ static void emitOpClasses(const RecordKeeper &recordKeeper,
return;
// Generate all of the locally instantiated methods first.
- StaticVerifierFunctionEmitter staticVerifierEmitter(recordKeeper, os);
+ StaticVerifierFunctionEmitter staticVerifierEmitter(recordKeeper);
os << formatv(opCommentHeader, "Local Utility Method", "Definitions");
- staticVerifierEmitter.emitFunctionsFor(
- typeVerifierSignature, typeVerifierErrorHandler, /*typeArgName=*/"type",
- defs, emitDecl);
+ staticVerifierEmitter.setSelf("type");
+
+ // Collect a set of all of the used type constraints within the operation
+ // definitions.
+ llvm::SetVector<const void *> typeConstraints;
+ for (Record *def : defs) {
+ Operator op(*def);
+ for (NamedTypeConstraint &operand : op.getOperands())
+ if (operand.hasPredicate())
+ typeConstraints.insert(operand.constraint.getAsOpaquePointer());
+ for (NamedTypeConstraint &result : op.getResults())
+ if (result.hasPredicate())
+ typeConstraints.insert(result.constraint.getAsOpaquePointer());
+ }
+
+ staticVerifierEmitter.emitConstraintMethodsInNamespace(
+ typeVerifierSignature, typeVerifierErrorHandler,
+ Operator(*defs[0]).getCppNamespace(), typeConstraints.getArrayRef(), os,
+ emitDecl);
for (auto *def : defs) {
Operator op(*def);
diff --git a/mlir/tools/mlir-tblgen/RewriterGen.cpp b/mlir/tools/mlir-tblgen/RewriterGen.cpp
index 2d2694940850..318f006cdf45 100644
--- a/mlir/tools/mlir-tblgen/RewriterGen.cpp
+++ b/mlir/tools/mlir-tblgen/RewriterGen.cpp
@@ -20,6 +20,7 @@
#include "mlir/TableGen/Predicate.h"
#include "mlir/TableGen/Type.h"
#include "llvm/ADT/FunctionExtras.h"
+#include "llvm/ADT/SetVector.h"
#include "llvm/ADT/StringExtras.h"
#include "llvm/ADT/StringSet.h"
#include "llvm/Support/CommandLine.h"
@@ -41,6 +42,23 @@ using llvm::RecordKeeper;
#define DEBUG_TYPE "mlir-tblgen-rewritergen"
+// The signature of static type verification function
+static const char *typeVerifierSignature =
+ "static ::mlir::LogicalResult {0}(::mlir::PatternRewriter &rewriter, "
+ "::mlir::Operation *op, ::mlir::Type typeOrAttr, "
+ "::llvm::StringRef failureStr)";
+
+// The signature of static attribute verification function
+static const char *attrVerifierSignature =
+ "static ::mlir::LogicalResult {0}(::mlir::PatternRewriter &rewriter, "
+ "::mlir::Operation *op, ::mlir::Attribute typeOrAttr, "
+ "::llvm::StringRef failureStr)";
+
+// The template of error handler in static type/attribute verification function
+static const char *verifierErrorHandler =
+ "rewriter.notifyMatchFailure(op, [&](::mlir::Diagnostic &diag) {\n diag "
+ "<< failureStr << \": {0}\";\n});";
+
namespace llvm {
template <>
struct format_provider<mlir::tblgen::Pattern::IdentifierLine> {
@@ -87,6 +105,10 @@ class PatternEmitter {
// Emit C++ function call to static DAG matcher.
void emitStaticMatchCall(DagNode tree, StringRef name);
+ // Emit C++ function call to static type/attribute constraint function.
+ void emitStaticVerifierCall(StringRef funcName, StringRef opName,
+ StringRef arg, StringRef failureStr);
+
// Emits C++ statements for matching using a native code call.
void emitNativeCodeMatch(DagNode tree, StringRef name, int depth);
@@ -244,7 +266,8 @@ class PatternEmitter {
// inlining them.
class StaticMatcherHelper {
public:
- StaticMatcherHelper(RecordOperatorMap &mapper);
+ StaticMatcherHelper(const RecordKeeper &recordKeeper,
+ RecordOperatorMap &mapper);
// Determine if we should inline the match logic or delegate to a static
// function.
@@ -258,6 +281,9 @@ class StaticMatcherHelper {
return matcherNames[node];
}
+ // Get the name of static type/attribute verification function.
+ StringRef getVerifierName(Constraint constraint);
+
// Collect the `Record`s, i.e., the DRR, so that we can get the information of
// the duplicated DAGs.
void addPattern(Record *record);
@@ -265,6 +291,9 @@ class StaticMatcherHelper {
// Emit all static functions of DAG Matcher.
void populateStaticMatchers(raw_ostream &os);
+ // Emit all static functions for Constraints.
+ void populateStaticConstraintFunctions(raw_ostream &os);
+
private:
static constexpr unsigned kStaticMatcherThreshold = 1;
@@ -301,6 +330,12 @@ class StaticMatcherHelper {
// Number of static matcher generated. This is used to generate a unique name
// for each DagNode.
int staticMatcherCounter = 0;
+
+ // Leaves with a type/attr constraint; SetVector keeps emission order stable.
+ llvm::SetVector<DagLeaf> constraints;
+
+ // Static type/attribute verification function emitter.
+ StaticVerifierFunctionEmitter staticVerifierEmitter;
};
} // end anonymous namespace
@@ -395,6 +430,15 @@ void PatternEmitter::emitStaticMatchCall(DagNode tree, StringRef opName) {
os << "}\n";
}
+void PatternEmitter::emitStaticVerifierCall(StringRef funcName,
+ StringRef opName, StringRef arg,
+ StringRef failureStr) {
+ os << formatv("if(failed({0}(rewriter, {1}, {2}, {3}))) {{\n", funcName,
+ opName, arg, failureStr);
+ os.scope().os << "return ::mlir::failure();\n";
+ os << "}\n";
+}
+
// Helper function to match patterns.
void PatternEmitter::emitNativeCodeMatch(DagNode tree, StringRef opName,
int depth) {
@@ -487,14 +531,15 @@ void PatternEmitter::emitNativeCodeMatch(DagNode tree, StringRef opName,
self = argName;
else
self = formatv("{0}.getType()", argName);
- emitMatchCheck(
- opName,
- tgfmt(constraint.getConditionTemplate(), &fmtCtx.withSelf(self)),
+ StringRef verifier = staticMatcherHelper.getVerifierName(constraint);
+ emitStaticVerifierCall(
+ verifier, opName, self,
formatv("\"operand {0} of native code call '{1}' failed to satisfy "
"constraint: "
"'{2}'\"",
i, tree.getNativeCodeTemplate(),
- escapeString(constraint.getSummary())));
+ escapeString(constraint.getSummary()))
+ .str());
}
LLVM_DEBUG(llvm::dbgs() << "done emitting match for native code call\n");
@@ -626,13 +671,14 @@ void PatternEmitter::emitOperandMatch(DagNode tree, StringRef opName,
}
auto self = formatv("(*{0}.getODSOperands({1}).begin()).getType()",
opName, operandIndex);
- emitMatchCheck(
- opName,
- tgfmt(constraint.getConditionTemplate(), &fmtCtx.withSelf(self)),
- formatv("\"operand {0} of op '{1}' failed to satisfy constraint: "
- "'{2}'\"",
- operand - op.operand_begin(), op.getOperationName(),
- escapeString(constraint.getSummary())));
+ StringRef verifier = staticMatcherHelper.getVerifierName(constraint);
+ emitStaticVerifierCall(
+ verifier, opName, self.str(),
+ formatv(
+ "\"operand {0} of op '{1}' failed to satisfy constraint: '{2}'\"",
+ operandIndex, op.getOperationName(),
+ escapeString(constraint.getSummary()))
+ .str());
}
}
@@ -690,15 +736,17 @@ void PatternEmitter::emitAttributeMatch(DagNode tree, StringRef opName,
op.getOperationName(), argIndex + 1));
}
- // If a constraint is specified, we need to generate C++ statements to
- // check the constraint.
- emitMatchCheck(
- opName,
- tgfmt(matcher.getConditionTemplate(), &fmtCtx.withSelf("tblgen_attr")),
+ // If a constraint is specified, we need to generate function call to its
+ // static verifier.
+ StringRef verifier =
+ staticMatcherHelper.getVerifierName(matcher.getAsConstraint());
+ emitStaticVerifierCall(
+ verifier, opName, "tblgen_attr",
formatv("\"op '{0}' attribute '{1}' failed to satisfy constraint: "
"'{2}'\"",
op.getOperationName(), namedAttr->name,
- escapeString(matcher.getAsConstraint().getSummary())));
+ escapeString(matcher.getAsConstraint().getSummary()))
+ .str());
}
// Capture the value
@@ -1571,8 +1619,9 @@ void PatternEmitter::createAggregateLocalVarsForOpArgs(
}
}
-StaticMatcherHelper::StaticMatcherHelper(RecordOperatorMap &mapper)
- : opMap(mapper) {}
+StaticMatcherHelper::StaticMatcherHelper(const RecordKeeper &recordKeeper,
+ RecordOperatorMap &mapper)
+ : opMap(mapper), staticVerifierEmitter(recordKeeper) {}
void StaticMatcherHelper::populateStaticMatchers(raw_ostream &os) {
// PatternEmitter will use the static matcher if there's one generated. To
@@ -1592,6 +1641,31 @@ void StaticMatcherHelper::populateStaticMatchers(raw_ostream &os) {
}
}
+void StaticMatcherHelper::populateStaticConstraintFunctions(raw_ostream &os) {
+ llvm::SetVector<const void *> typeConstraints;
+ llvm::SetVector<const void *> attrConstraints;
+ for (DagLeaf leaf : constraints) {
+ if (leaf.isOperandMatcher()) {
+ typeConstraints.insert(leaf.getAsConstraint().getAsOpaquePointer());
+ } else {
+ assert(leaf.isAttrMatcher());
+ attrConstraints.insert(leaf.getAsConstraint().getAsOpaquePointer());
+ }
+ }
+ // Bind the condition templates to the `rewriter`/`typeOrAttr` parameters.
+ staticVerifierEmitter.setBuilder("rewriter").setSelf("typeOrAttr");
+ // Type and attribute verifiers differ only in the `typeOrAttr` param type.
+ staticVerifierEmitter.emitConstraintMethods(typeVerifierSignature,
+ verifierErrorHandler,
+ typeConstraints.getArrayRef(), os,
+ /*emitDecl=*/false);
+
+ staticVerifierEmitter.emitConstraintMethods(attrVerifierSignature,
+ verifierErrorHandler,
+ attrConstraints.getArrayRef(), os,
+ /*emitDecl=*/false);
+}
+
void StaticMatcherHelper::addPattern(Record *record) {
Pattern pat(record, &opMap);
@@ -1608,6 +1682,11 @@ void StaticMatcherHelper::addPattern(Record *record) {
for (unsigned i = 0, e = node.getNumArgs(); i < e; ++i)
if (DagNode sibling = node.getArgAsNestedDag(i))
dfs(sibling);
+ else {
+ DagLeaf leaf = node.getArgAsLeaf(i);
+ if (!leaf.isUnspecified())
+ constraints.insert(leaf);
+ }
topologicalOrder.push_back(std::make_pair(node, record));
};
@@ -1615,6 +1694,10 @@ void StaticMatcherHelper::addPattern(Record *record) {
dfs(pat.getSourcePattern());
}
+StringRef StaticMatcherHelper::getVerifierName(Constraint constraint) {
+ return staticVerifierEmitter.getConstraintFn(constraint);
+}
+
static void emitRewriters(const RecordKeeper &recordKeeper, raw_ostream &os) {
emitSourceFileHeader("Rewriters", os);
@@ -1625,9 +1708,10 @@ static void emitRewriters(const RecordKeeper &recordKeeper, raw_ostream &os) {
// Exam all the patterns and generate static matcher for the duplicated
// DagNode.
- StaticMatcherHelper staticMatcher(recordOpMap);
+ StaticMatcherHelper staticMatcher(recordKeeper, recordOpMap);
for (Record *p : patterns)
staticMatcher.addPattern(p);
+ staticMatcher.populateStaticConstraintFunctions(os);
staticMatcher.populateStaticMatchers(os);
std::vector<std::string> rewriterNames;
More information about the Mlir-commits
mailing list