[Mlir-commits] [mlir] b8186b3 - [mlir][ods] Unique attribute, successor, region constraints
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Thu Nov 11 17:04:13 PST 2021
Author: Mogball
Date: 2021-11-12T01:04:08Z
New Revision: b8186b313c5926bf67155311987d976c7cde7a1a
URL: https://github.com/llvm/llvm-project/commit/b8186b313c5926bf67155311987d976c7cde7a1a
DIFF: https://github.com/llvm/llvm-project/commit/b8186b313c5926bf67155311987d976c7cde7a1a.diff
LOG: [mlir][ods] Unique attribute, successor, region constraints
With `-Os` turned on, results in 2-5% binary size reduction
(depends on the original binary). Without it, the binary size
is essentially unchanged.
Depends on D113128
Differential Revision: https://reviews.llvm.org/D113331
Added:
mlir/test/mlir-tblgen/constraint-unique.td
Modified:
mlir/include/mlir/TableGen/Attribute.h
mlir/include/mlir/TableGen/CodeGenHelpers.h
mlir/include/mlir/TableGen/Constraint.h
mlir/include/mlir/TableGen/Predicate.h
mlir/include/mlir/TableGen/Type.h
mlir/lib/TableGen/Attribute.cpp
mlir/lib/TableGen/Constraint.cpp
mlir/lib/TableGen/Type.cpp
mlir/test/mlir-tblgen/predicate.td
mlir/test/mlir-tblgen/rewriter-static-matcher.td
mlir/tools/mlir-tblgen/CodeGenHelpers.cpp
mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
mlir/tools/mlir-tblgen/RewriterGen.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/TableGen/Attribute.h b/mlir/include/mlir/TableGen/Attribute.h
index 579b93846b1c1..f8397128f03a0 100644
--- a/mlir/include/mlir/TableGen/Attribute.h
+++ b/mlir/include/mlir/TableGen/Attribute.h
@@ -32,7 +32,7 @@ class Type;
// in TableGen.
class AttrConstraint : public Constraint {
public:
- explicit AttrConstraint(const llvm::Record *record);
+ using Constraint::Constraint;
static bool classof(const Constraint *c) { return c->getKind() == CK_Attr; }
diff --git a/mlir/include/mlir/TableGen/CodeGenHelpers.h b/mlir/include/mlir/TableGen/CodeGenHelpers.h
index 2dd1cf64667fd..14af7d3380e48 100644
--- a/mlir/include/mlir/TableGen/CodeGenHelpers.h
+++ b/mlir/include/mlir/TableGen/CodeGenHelpers.h
@@ -13,10 +13,10 @@
#ifndef MLIR_TABLEGEN_CODEGENHELPERS_H
#define MLIR_TABLEGEN_CODEGENHELPERS_H
-#include "mlir/Support/IndentedOstream.h"
#include "mlir/TableGen/Dialect.h"
#include "mlir/TableGen/Format.h"
#include "llvm/ADT/DenseMap.h"
+#include "llvm/ADT/MapVector.h"
#include "llvm/ADT/StringExtras.h"
#include "llvm/ADT/StringRef.h"
@@ -26,8 +26,8 @@ class RecordKeeper;
namespace mlir {
namespace tblgen {
-
class Constraint;
+class DagLeaf;
// Simple RAII helper for defining ifdef-undef-endif scopes.
class IfDefScope {
@@ -92,68 +92,128 @@ class NamespaceEmitter {
///
class StaticVerifierFunctionEmitter {
public:
- StaticVerifierFunctionEmitter(const llvm::RecordKeeper &records);
-
- /// Emit the static verifier functions for `llvm::Record`s. The
- /// `signatureFormat` describes the required arguments and it must have a
- /// placeholder for function name.
- /// Example,
- /// const char *typeVerifierSignature =
- /// "static ::mlir::LogicalResult {0}(::mlir::Operation *op, ::mlir::Type"
- /// " type, ::llvm::StringRef valueKind, unsigned valueGroupStartIndex)";
+ StaticVerifierFunctionEmitter(raw_ostream &os,
+ const llvm::RecordKeeper &records);
+
+ /// Collect and unique all compatible type, attribute, successor, and region
+ /// constraints from the operations in the file and emit them at the top of
+ /// the generated file.
///
- /// `errorHandlerFormat` describes the error message to return. It may have a
- /// placeholder for the summary of Constraint and bring more information for
- /// the error message.
- /// Example,
- /// const char *typeVerifierErrorHandler =
- /// " op->emitOpError(valueKind) << \" #\" << valueGroupStartIndex << "
- /// "\" must be {0}, but got \" << type";
+ /// Constraints that do not meet the restriction that they can only reference
+ /// `$_self` and `$_op` are not uniqued.
+ void emitOpConstraints(ArrayRef<llvm::Record *> opDefs, bool emitDecl);
+
+ /// Unique all compatible type and attribute constraints from a pattern file
+ /// and emit them at the top of the generated file.
///
- /// `typeArgName` is used to identify the argument that needs to check its
- /// type. The constraint template will replace `$_self` with it.
-
- /// This is the helper to generate the constraint functions from op
- /// definitions.
- void emitConstraintMethodsInNamespace(StringRef signatureFormat,
- StringRef errorHandlerFormat,
- StringRef cppNamespace,
- ArrayRef<const void *> constraints,
- raw_ostream &rawOs, bool emitDecl);
-
- /// Emit the static functions for the giving type constraints.
- void emitConstraintMethods(StringRef signatureFormat,
- StringRef errorHandlerFormat,
- ArrayRef<const void *> constraints,
- raw_ostream &rawOs, bool emitDecl);
-
- /// Get the name of the local function used for the given type constraint.
+ /// Constraints that do not meet the restriction that they can only reference
+ /// `$_self`, `$_op`, and `$_builder` are not uniqued.
+ void emitPatternConstraints(const DenseSet<DagLeaf> &constraints);
+
+ /// Get the name of the static function used for the given type constraint.
/// These functions are used for operand and result constraints and have the
/// form:
+ ///
/// LogicalResult(Operation *op, Type type, StringRef valueKind,
- /// unsigned valueGroupStartIndex);
- StringRef getConstraintFn(const Constraint &constraint) const;
+ /// unsigned valueIndex);
+ ///
+ /// Pattern constraints have the form:
+ ///
+ /// LogicalResult(PatternRewriter &rewriter, Operation *op, Type type,
+ /// StringRef failureStr);
+ ///
+ StringRef getTypeConstraintFn(const Constraint &constraint) const;
+
+ /// Get the name of the static function used for the given attribute
+ /// constraint. These functions are in the form:
+ ///
+ /// LogicalResult(Operation *op, Attribute attr, StringRef attrName);
+ ///
+ /// If a uniqued constraint was not found, this function returns None. The
+ /// uniqued constraints cannot be used in the context of an OpAdaptor.
+ ///
+ /// Pattern constraints have the form:
+ ///
+ /// LogicalResult(PatternRewriter &rewriter, Operation *op, Attribute attr,
+ /// StringRef failureStr);
+ ///
+ Optional<StringRef> getAttrConstraintFn(const Constraint &constraint) const;
- /// The setter to set `self` in format context.
- StaticVerifierFunctionEmitter &setSelf(StringRef str);
+ /// Get the name of the static function used for the given successor
+ /// constraint. These functions are in the form:
+ ///
+ /// LogicalResult(Operation *op, Block *successor, StringRef successorName,
+ /// unsigned successorIndex);
+ ///
+ StringRef getSuccessorConstraintFn(const Constraint &constraint) const;
- /// The setter to set `builder` in format context.
- StaticVerifierFunctionEmitter &setBuilder(StringRef str);
+ /// Get the name of the static function used for the given region constraint.
+ /// These functions are in the form:
+ ///
+ /// LogicalResult(Operation *op, Region &region, StringRef regionName,
+ /// unsigned regionIndex);
+ ///
+ /// The region name may be empty.
+ StringRef getRegionConstraintFn(const Constraint &constraint) const;
private:
- /// Returns a unique name to use when generating local methods.
- static std::string getUniqueName(const llvm::RecordKeeper &records);
-
- /// The format context used for building the verifier function.
- FmtContext fctx;
+ /// Emit static type constraint functions.
+ void emitTypeConstraints();
+ /// Emit static attribute constraint functions.
+ void emitAttrConstraints();
+ /// Emit static successor constraint functions.
+ void emitSuccessorConstraints();
+ /// Emit static region constraint functions.
+ void emitRegionConstraints();
+
+ /// Emit pattern constraints.
+ void emitPatternConstraints();
+
+ /// Collect and unique all the constraints used by operations.
+ void collectOpConstraints(ArrayRef<llvm::Record *> opDefs);
+ /// Collect and unique all pattern constraints.
+ void collectPatternConstraints(const DenseSet<DagLeaf> &constraints);
+
+ /// The output stream.
+ raw_ostream &os;
/// A unique label for the file currently being generated. This is used to
- /// ensure that the local functions have a unique name.
+ /// ensure that the static functions have a unique name.
std::string uniqueOutputLabel;
- /// A set of functions implementing type constraints, used for operand and
- /// result verification.
- llvm::DenseMap<const void *, std::string> localTypeConstraints;
+ /// Unique constraints by their predicate and summary. Constraints that share
+ /// the same predicate may have different descriptions; ensure that the
+ /// correct error message is reported when verification fails.
+ struct ConstraintUniquer {
+ static Constraint getEmptyKey();
+ static Constraint getTombstoneKey();
+ static unsigned getHashValue(Constraint constraint);
+ static bool isEqual(Constraint lhs, Constraint rhs);
+ };
+ /// Use a MapVector to ensure that functions are generated deterministically.
+ using ConstraintMap =
+ llvm::MapVector<Constraint, std::string,
+ llvm::DenseMap<Constraint, unsigned, ConstraintUniquer>>;
+
+ /// A generic function to emit constraints
+ void emitConstraints(const ConstraintMap &constraints, StringRef selfName,
+ const char *const codeTemplate);
+
+ /// Assign a unique name to a unique constraint.
+ std::string getUniqueName(StringRef kind, unsigned index);
+ /// Unique a constraint in the map.
+ void collectConstraint(ConstraintMap &map, StringRef kind,
+ Constraint constraint);
+
+ /// The set of type constraints used for operand and result verification in
+ /// the current file.
+ ConstraintMap typeConstraints;
+ /// The set of attribute constraints used in the current file.
+ ConstraintMap attrConstraints;
+ /// The set of successor constraints used in the current file.
+ ConstraintMap successorConstraints;
+ /// The set of region constraints used in the current file.
+ ConstraintMap regionConstraints;
};
// Escape a string using C++ encoding. E.g. foo"bar -> foo\x22bar.
diff --git a/mlir/include/mlir/TableGen/Constraint.h b/mlir/include/mlir/TableGen/Constraint.h
index 5ecf326da1e79..ebb3d2955ddf6 100644
--- a/mlir/include/mlir/TableGen/Constraint.h
+++ b/mlir/include/mlir/TableGen/Constraint.h
@@ -29,8 +29,15 @@ namespace tblgen {
// TableGen.
class Constraint {
public:
+ // Constraint kind
+ enum Kind { CK_Attr, CK_Region, CK_Successor, CK_Type, CK_Uncategorized };
+
+ // Create a constraint with a TableGen definition and a kind.
+ Constraint(const llvm::Record *record, Kind kind) : def(record), kind(kind) {}
+ // Create a constraint with a TableGen definition, and infer the kind.
Constraint(const llvm::Record *record);
+ /// Constraints are pointer-comparable.
bool operator==(const Constraint &that) { return def == that.def; }
bool operator!=(const Constraint &that) { return def != that.def; }
@@ -47,24 +54,9 @@ class Constraint {
// description is not provided, returns the TableGen def name.
StringRef getSummary() const;
- // Constraint kind
- enum Kind { CK_Attr, CK_Region, CK_Successor, CK_Type, CK_Uncategorized };
-
Kind getKind() const { return kind; }
- /// Get an opaque pointer to the constraint.
- const void *getAsOpaquePointer() const { return def; }
- /// Construct a constraint from the opaque pointer representation.
- static Constraint getFromOpaquePointer(const void *ptr) {
- return Constraint(reinterpret_cast<const llvm::Record *>(ptr));
- }
-
- // Return the underlying def.
- const llvm::Record *getDef() const { return def; }
-
protected:
- Constraint(Kind kind, const llvm::Record *record);
-
// The TableGen definition of this constraint.
const llvm::Record *def;
diff --git a/mlir/include/mlir/TableGen/Predicate.h b/mlir/include/mlir/TableGen/Predicate.h
index e23c751e52d09..eb3fe53856951 100644
--- a/mlir/include/mlir/TableGen/Predicate.h
+++ b/mlir/include/mlir/TableGen/Predicate.h
@@ -53,15 +53,21 @@ class Pred {
// record of type CombinedPred.
bool isCombined() const;
+ // Get the location of the predicate.
+ ArrayRef<llvm::SMLoc> getLoc() const;
+
// Records are pointer-comparable.
bool operator==(const Pred &other) const { return def == other.def; }
- // Get the location of the predicate.
- ArrayRef<llvm::SMLoc> getLoc() const;
+ // Return true if the predicate is not null.
+ operator bool() const { return def; }
-protected:
- friend llvm::DenseMapInfo<Pred>;
+ // Hash a predicate by its pointer value.
+ friend llvm::hash_code hash_value(Pred pred) {
+ return llvm::hash_value(pred.def);
+ }
+protected:
// The TableGen definition of this predicate.
const llvm::Record *def;
};
@@ -119,18 +125,4 @@ class ConcatPred : public CombinedPred {
} // end namespace tblgen
} // end namespace mlir
-namespace llvm {
-template <>
-struct DenseMapInfo<mlir::tblgen::Pred> {
- static mlir::tblgen::Pred getEmptyKey() { return mlir::tblgen::Pred(); }
- static mlir::tblgen::Pred getTombstoneKey() { return mlir::tblgen::Pred(); }
- static unsigned getHashValue(mlir::tblgen::Pred pred) {
- return llvm::hash_value(pred.def);
- }
- static bool isEqual(mlir::tblgen::Pred lhs, mlir::tblgen::Pred rhs) {
- return lhs == rhs;
- }
-};
-} // end namespace llvm
-
#endif // MLIR_TABLEGEN_PREDICATE_H_
diff --git a/mlir/include/mlir/TableGen/Type.h b/mlir/include/mlir/TableGen/Type.h
index c996adabdcff1..7fc892690a25f 100644
--- a/mlir/include/mlir/TableGen/Type.h
+++ b/mlir/include/mlir/TableGen/Type.h
@@ -29,8 +29,9 @@ namespace tblgen {
// TableGen.
class TypeConstraint : public Constraint {
public:
- explicit TypeConstraint(const llvm::Record *record);
- explicit TypeConstraint(const llvm::DefInit *init);
+ using Constraint::Constraint;
+
+ TypeConstraint(const llvm::DefInit *record);
static bool classof(const Constraint *c) { return c->getKind() == CK_Type; }
diff --git a/mlir/lib/TableGen/Attribute.cpp b/mlir/lib/TableGen/Attribute.cpp
index 9664eb91b2351..5bc618c5e2094 100644
--- a/mlir/lib/TableGen/Attribute.cpp
+++ b/mlir/lib/TableGen/Attribute.cpp
@@ -31,12 +31,6 @@ static StringRef getValueAsString(const Init *init) {
return {};
}
-AttrConstraint::AttrConstraint(const Record *record)
- : Constraint(Constraint::CK_Attr, record) {
- assert(isSubClassOf("AttrConstraint") &&
- "must be subclass of TableGen 'AttrConstraint' class");
-}
-
bool AttrConstraint::isSubClassOf(StringRef className) const {
return def->isSubClassOf(className);
}
diff --git a/mlir/lib/TableGen/Constraint.cpp b/mlir/lib/TableGen/Constraint.cpp
index f0dac0bccd3f0..759e28fbc903a 100644
--- a/mlir/lib/TableGen/Constraint.cpp
+++ b/mlir/lib/TableGen/Constraint.cpp
@@ -17,10 +17,11 @@ using namespace mlir;
using namespace mlir::tblgen;
Constraint::Constraint(const llvm::Record *record)
- : def(record), kind(CK_Uncategorized) {
+ : Constraint(record, CK_Uncategorized) {
// Look through OpVariable's to their constraint.
if (def->isSubClassOf("OpVariable"))
def = def->getValueAsDef("constraint");
+
if (def->isSubClassOf("TypeConstraint")) {
kind = CK_Type;
} else if (def->isSubClassOf("AttrConstraint")) {
@@ -34,13 +35,6 @@ Constraint::Constraint(const llvm::Record *record)
}
}
-Constraint::Constraint(Kind kind, const llvm::Record *record)
- : def(record), kind(kind) {
- // Look through OpVariable's to their constraint.
- if (def->isSubClassOf("OpVariable"))
- def = def->getValueAsDef("constraint");
-}
-
Pred Constraint::getPredicate() const {
auto *val = def->getValue("predicate");
diff --git a/mlir/lib/TableGen/Type.cpp b/mlir/lib/TableGen/Type.cpp
index 6691bb88c7680..601440ebc6435 100644
--- a/mlir/lib/TableGen/Type.cpp
+++ b/mlir/lib/TableGen/Type.cpp
@@ -19,12 +19,6 @@
using namespace mlir;
using namespace mlir::tblgen;
-TypeConstraint::TypeConstraint(const llvm::Record *record)
- : Constraint(Constraint::CK_Type, record) {
- assert(def->isSubClassOf("TypeConstraint") &&
- "must be subclass of TableGen 'TypeConstraint' class");
-}
-
TypeConstraint::TypeConstraint(const llvm::DefInit *init)
: TypeConstraint(init->getDef()) {}
diff --git a/mlir/test/mlir-tblgen/constraint-unique.td b/mlir/test/mlir-tblgen/constraint-unique.td
new file mode 100644
index 0000000000000..a2ee633f5c010
--- /dev/null
+++ b/mlir/test/mlir-tblgen/constraint-unique.td
@@ -0,0 +1,156 @@
+// RUN: mlir-tblgen -gen-op-defs -I %S/../../include %s | FileCheck %s
+
+include "mlir/IR/OpBase.td"
+
+def Test_Dialect : Dialect {
+ let name = "test";
+}
+
+class NS_Op<string mnemonic, list<OpTrait> traits = []> :
+ Op<Test_Dialect, mnemonic, traits>;
+
+/// Test unique'ing of type, attribute, successor, and region constraints.
+
+def ATypePred : CPred<"typePred($_self, $_op)">;
+def AType : Type<ATypePred, "a type">;
+def OtherType : Type<ATypePred, "another type">;
+
+def AnAttrPred : CPred<"attrPred($_self, $_op)">;
+def AnAttr : Attr<AnAttrPred, "an attribute">;
+def OtherAttr : Attr<AnAttrPred, "another attribute">;
+
+def ASuccessorPred : CPred<"successorPred($_self, $_op)">;
+def ASuccessor : Successor<ASuccessorPred, "a successor">;
+def OtherSuccessor : Successor<ASuccessorPred, "another successor">;
+
+def ARegionPred : CPred<"regionPred($_self, $_op)">;
+def ARegion : Region<ARegionPred, "a region">;
+def OtherRegion : Region<ARegionPred, "another region">;
+
+// OpA and OpB have the same type, attribute, successor, and region constraints.
+
+def OpA : NS_Op<"op_a"> {
+ let arguments = (ins AType:$a, AnAttr:$b);
+ let results = (outs AType:$ret);
+ let successors = (successor ASuccessor:$c);
+ let regions = (region ARegion:$d);
+}
+
+def OpB : NS_Op<"op_b"> {
+ let arguments = (ins AType:$a, AnAttr:$b);
+ let successors = (successor ASuccessor:$c);
+ let regions = (region ARegion:$d);
+}
+
+// OpC has the same type, attribute, successor, and region predicates but has
+// different descriptions for them.
+
+def OpC : NS_Op<"op_c"> {
+ let arguments = (ins OtherType:$a, OtherAttr:$b);
+ let results = (outs OtherType:$ret);
+ let successors = (successor OtherSuccessor:$c);
+ let regions = (region OtherRegion:$d);
+}
+
+/// Test that a type constraint was generated.
+// CHECK: static ::mlir::LogicalResult [[$A_TYPE_CONSTRAINT:__mlir_ods_local_type_constraint.*]](
+// CHECK: if (!((typePred(type, *op)))) {
+// CHECK-NEXT: return op->emitOpError(valueKind) << " #" << valueIndex
+// CHECK-NEXT: << " must be a type, but got " << type;
+
+/// Test that duplicate type constraint was not generated.
+// CHECK-NOT: << " must be a type, but got " << type;
+
+/// Test that a type constraint with a different description was generated.
+// CHECK: static ::mlir::LogicalResult [[$O_TYPE_CONSTRAINT:__mlir_ods_local_type_constraint.*]](
+// CHECK: if (!((typePred(type, *op)))) {
+// CHECK-NEXT: return op->emitOpError(valueKind) << " #" << valueIndex
+// CHECK-NEXT: << " must be another type, but got " << type;
+
+/// Test that an attribute constraint was generated.
+// CHECK: static ::mlir::LogicalResult [[$A_ATTR_CONSTRAINT:__mlir_ods_local_attr_constraint.*]](
+// CHECK: if (attr && !((attrPred(attr, *op)))) {
+// CHECK-NEXT: return op->emitOpError("attribute '") << attrName
+// CHECK-NEXT: << "' failed to satisfy constraint: an attribute";
+
+/// Test that duplicate attribute constraint was not generated.
+// CHECK-NOT: << "' failed to satisfy constraint: an attribute";
+
+/// Test that an attribute constraint with a different description was generated.
+// CHECK: static ::mlir::LogicalResult [[$O_ATTR_CONSTRAINT:__mlir_ods_local_attr_constraint.*]](
+// CHECK: if (attr && !((attrPred(attr, *op)))) {
+// CHECK-NEXT: return op->emitOpError("attribute '") << attrName
+// CHECK-NEXT: << "' failed to satisfy constraint: another attribute";
+
+/// Test that a successor constraint was generated.
+// CHECK: static ::mlir::LogicalResult [[$A_SUCCESSOR_CONSTRAINT:__mlir_ods_local_successor_constraint.*]](
+// CHECK: if (!((successorPred(successor, *op)))) {
+// CHECK-NEXT: return op->emitOpError("successor #") << successorIndex << " ('"
+// CHECK-NEXT: << successorName << ")' failed to verify constraint: a successor";
+
+/// Test that duplicate successor constraint was not generated.
+// CHECK-NOT: << successorName << ")' failed to verify constraint: a successor";
+
+/// Test that a successor constraint with a different description was generated.
+// CHECK: static ::mlir::LogicalResult [[$O_SUCCESSOR_CONSTRAINT:__mlir_ods_local_successor_constraint.*]](
+// CHECK: if (!((successorPred(successor, *op)))) {
+// CHECK-NEXT: return op->emitOpError("successor #") << successorIndex << " ('"
+// CHECK-NEXT: << successorName << ")' failed to verify constraint: another successor";
+
+/// Test that a region constraint was generated.
+// CHECK: static ::mlir::LogicalResult [[$A_REGION_CONSTRAINT:__mlir_ods_local_region_constraint.*]](
+// CHECK: if (!((regionPred(region, *op)))) {
+// CHECK-NEXT: return op->emitOpError("region #") << regionIndex
+// CHECK-NEXT: << (regionName.empty() ? " " : " ('" + regionName + "') ")
+// CHECK-NEXT: << "failed to verify constraint: a region";
+
+/// Test that duplicate region constraint was not generated.
+// CHECK-NOT: << "failed to verify constraint: a region";
+
+/// Test that a region constraint with a different description was generated.
+// CHECK: static ::mlir::LogicalResult [[$O_REGION_CONSTRAINT:__mlir_ods_local_region_constraint.*]](
+// CHECK: if (!((regionPred(region, *op)))) {
+// CHECK-NEXT: return op->emitOpError("region #") << regionIndex
+// CHECK-NEXT: << (regionName.empty() ? " " : " ('" + regionName + "') ")
+// CHECK-NEXT: << "failed to verify constraint: another region";
+
+/// Test that the uniqued constraints are being used.
+// CHECK-LABEL: OpA::verify
+// CHECK: auto [[$B_ATTR:.*b]] = (*this)->getAttr(bAttrName());
+// CHECK: if (::mlir::failed([[$A_ATTR_CONSTRAINT]](*this, [[$B_ATTR]], "b")))
+// CHECK-NEXT: return ::mlir::failure();
+// CHECK: auto [[$A_VALUE_GROUP:.*]] = getODSOperands(0);
+// CHECK: for (auto [[$A_VALUE:.*]] : [[$A_VALUE_GROUP]])
+// CHECK-NEXT: if (::mlir::failed([[$A_TYPE_CONSTRAINT]](*this, [[$A_VALUE]].getType(), "operand", index++)))
+// CHECK-NEXT: return ::mlir::failure();
+// CHECK: auto [[$RET_VALUE_GROUP:.*]] = getODSResults(0);
+// CHECK: for (auto [[$RET_VALUE:.*]] : [[$RET_VALUE_GROUP]])
+// CHECK-NEXT: if (::mlir::failed([[$A_TYPE_CONSTRAINT]](*this, [[$RET_VALUE]].getType(), "result", index++)))
+// CHECK-NEXT: return ::mlir::failure();
+// CHECK: for (auto &region : ::llvm::makeMutableArrayRef((*this)->getRegion(0)))
+// CHECK-NEXT: if (::mlir::failed([[$A_REGION_CONSTRAINT]](*this, region, "d", index++)))
+// CHECK-NEXT: return ::mlir::failure();
+// CHECK: for (auto *successor : ::llvm::makeMutableArrayRef(c()))
+// CHECK-NEXT: if (::mlir::failed([[$A_SUCCESSOR_CONSTRAINT]](*this, successor, "c", index++)))
+// CHECK-NEXT: return ::mlir::failure();
+
+/// Test that the op with the same predicates but different descriptions
+/// uses the different constraints.
+// CHECK-LABEL: OpC::verify
+// CHECK: auto [[$B_ATTR:.*b]] = (*this)->getAttr(bAttrName());
+// CHECK: if (::mlir::failed([[$O_ATTR_CONSTRAINT]](*this, [[$B_ATTR]], "b")))
+// CHECK-NEXT: return ::mlir::failure();
+// CHECK: auto [[$A_VALUE_GROUP:.*]] = getODSOperands(0);
+// CHECK: for (auto [[$A_VALUE:.*]] : [[$A_VALUE_GROUP]])
+// CHECK-NEXT: if (::mlir::failed([[$O_TYPE_CONSTRAINT]](*this, [[$A_VALUE]].getType(), "operand", index++)))
+// CHECK-NEXT: return ::mlir::failure();
+// CHECK: auto [[$RET_VALUE_GROUP:.*]] = getODSResults(0);
+// CHECK: for (auto [[$RET_VALUE:.*]] : [[$RET_VALUE_GROUP]])
+// CHECK-NEXT: if (::mlir::failed([[$O_TYPE_CONSTRAINT]](*this, [[$RET_VALUE]].getType(), "result", index++)))
+// CHECK-NEXT: return ::mlir::failure();
+// CHECK: for (auto &region : ::llvm::makeMutableArrayRef((*this)->getRegion(0)))
+// CHECK-NEXT: if (::mlir::failed([[$O_REGION_CONSTRAINT]](*this, region, "d", index++)))
+// CHECK-NEXT: return ::mlir::failure();
+// CHECK: for (auto *successor : ::llvm::makeMutableArrayRef(c()))
+// CHECK-NEXT: if (::mlir::failed([[$O_SUCCESSOR_CONSTRAINT]](*this, successor, "c", index++)))
+// CHECK-NEXT: return ::mlir::failure();
diff --git a/mlir/test/mlir-tblgen/predicate.td b/mlir/test/mlir-tblgen/predicate.td
index dcfdd65e72581..ad41866e8c0b8 100644
--- a/mlir/test/mlir-tblgen/predicate.td
+++ b/mlir/test/mlir-tblgen/predicate.td
@@ -17,24 +17,28 @@ def OpA : NS_Op<"op_for_CPred_containing_multiple_same_placeholder", []> {
}
// CHECK: static ::mlir::LogicalResult [[$INTEGER_FLOAT_CONSTRAINT:__mlir_ods_local_type_constraint.*]](
-// CHECK-NEXT: if (!((type.isInteger(32) || type.isF32()))) {
-// CHECK-NEXT: return op->emitOpError(valueKind) << " #" << valueGroupStartIndex << " must be 32-bit integer or floating-point type, but got " << type;
+// CHECK: if (!((type.isInteger(32) || type.isF32()))) {
+// CHECK-NEXT: return op->emitOpError(valueKind) << " #" << valueIndex
+// CHECK-NEXT: << " must be 32-bit integer or floating-point type, but got " << type;
// Check there is no verifier with same predicate generated.
// CHECK-NOT: if (!((type.isInteger(32) || type.isF32()))) {
-// CHECK-NOT: return op->emitOpError(valueKind) << " #" << valueGroupStartIndex << " must be 32-bit integer or floating-point type, but got " << type;
+// CHECK-NOT: return op->emitOpError(valueKind) << " #" << valueIndex
+// CHECK-NOT. << " must be 32-bit integer or floating-point type, but got " << type;
// CHECK: static ::mlir::LogicalResult [[$TENSOR_CONSTRAINT:__mlir_ods_local_type_constraint.*]](
-// CHECK-NEXT: if (!(((type.isa<::mlir::TensorType>())) && ([](::mlir::Type elementType) { return (true); }(type.cast<::mlir::ShapedType>().getElementType())))) {
-// CHECK-NEXT: return op->emitOpError(valueKind) << " #" << valueGroupStartIndex << " must be tensor of any type values, but got " << type;
+// CHECK: if (!(((type.isa<::mlir::TensorType>())) && ([](::mlir::Type elementType) { return (true); }(type.cast<::mlir::ShapedType>().getElementType())))) {
+// CHECK-NEXT: return op->emitOpError(valueKind) << " #" << valueIndex
+// CHECK-NEXT: << " must be tensor of any type values, but got " << type;
// CHECK: static ::mlir::LogicalResult [[$TENSOR_INTEGER_FLOAT_CONSTRAINT:__mlir_ods_local_type_constraint.*]](
-// CHECK-NEXT: if (!(((type.isa<::mlir::TensorType>())) && ([](::mlir::Type elementType) { return ((elementType.isF32())) || ((elementType.isSignlessInteger(32))); }(type.cast<::mlir::ShapedType>().getElementType())))) {
-// CHECK-NEXT: return op->emitOpError(valueKind) << " #" << valueGroupStartIndex << " must be tensor of 32-bit float or 32-bit signless integer values, but got " << type;
+// CHECK: if (!(((type.isa<::mlir::TensorType>())) && ([](::mlir::Type elementType) { return ((elementType.isF32())) || ((elementType.isSignlessInteger(32))); }(type.cast<::mlir::ShapedType>().getElementType())))) {
+// CHECK-NEXT: return op->emitOpError(valueKind) << " #" << valueIndex
+// CHECK-NEXT: << " must be tensor of 32-bit float or 32-bit signless integer values, but got " << type;
// CHECK-LABEL: OpA::verify
// CHECK: auto valueGroup0 = getODSOperands(0);
-// CHECK: for (::mlir::Value v : valueGroup0) {
+// CHECK: for (auto v : valueGroup0) {
// CHECK: if (::mlir::failed([[$INTEGER_FLOAT_CONSTRAINT]]
def OpB : NS_Op<"op_for_And_PredOpTrait", [
@@ -109,7 +113,7 @@ def OpK : NS_Op<"op_for_AnyTensorOf", []> {
// CHECK-LABEL: OpK::verify
// CHECK: auto valueGroup0 = getODSOperands(0);
-// CHECK: for (::mlir::Value v : valueGroup0) {
+// CHECK: for (auto v : valueGroup0) {
// CHECK: if (::mlir::failed([[$TENSOR_INTEGER_FLOAT_CONSTRAINT]]
def OpL : NS_Op<"op_for_StringEscaping", []> {
diff --git a/mlir/test/mlir-tblgen/rewriter-static-matcher.td b/mlir/test/mlir-tblgen/rewriter-static-matcher.td
index 2e84e3476235b..b343f90c928a9 100644
--- a/mlir/test/mlir-tblgen/rewriter-static-matcher.td
+++ b/mlir/test/mlir-tblgen/rewriter-static-matcher.td
@@ -37,11 +37,13 @@ def COp : NS_Op<"c_op", []> {
// Test static matcher for duplicate DagNode
// ---
-// CHECK-DAG: static ::mlir::LogicalResult [[$TYPE_CONSTRAINT:__mlir_ods_local_type_constraint.*]]({{.*::mlir::Type typeOrAttr}}
-// CHECK-DAG: static ::mlir::LogicalResult [[$ATTR_CONSTRAINT:__mlir_ods_local_type_constraint.*]]({{.*::mlir::Attribute}}
-// CHECK-DAG: static ::mlir::LogicalResult [[$DAG_MATCHER:static_dag_matcher.*]](
-// CHECK: if(failed([[$TYPE_CONSTRAINT]]
+// CHECK: static ::mlir::LogicalResult [[$TYPE_CONSTRAINT:__mlir_ods_local_type_constraint.*]](
+// CHECK-NEXT: {{.*::mlir::Type type}}
+// CHECK: static ::mlir::LogicalResult [[$ATTR_CONSTRAINT:__mlir_ods_local_attr_constraint.*]](
+// CHECK-NEXT: {{.*::mlir::Attribute attr}}
+// CHECK: static ::mlir::LogicalResult [[$DAG_MATCHER:static_dag_matcher.*]](
// CHECK: if(failed([[$ATTR_CONSTRAINT]]
+// CHECK: if(failed([[$TYPE_CONSTRAINT]]
// CHECK: if(failed([[$DAG_MATCHER]](rewriter, op1, tblgen_ops
def : Pat<(AOp (BOp I32Attr:$attr, I32:$int)),
diff --git a/mlir/tools/mlir-tblgen/CodeGenHelpers.cpp b/mlir/tools/mlir-tblgen/CodeGenHelpers.cpp
index 6aab5abf0d6e7..e5ac14e8f7a2a 100644
--- a/mlir/tools/mlir-tblgen/CodeGenHelpers.cpp
+++ b/mlir/tools/mlir-tblgen/CodeGenHelpers.cpp
@@ -13,6 +13,7 @@
#include "mlir/TableGen/CodeGenHelpers.h"
#include "mlir/TableGen/Operator.h"
+#include "mlir/TableGen/Pattern.h"
#include "llvm/ADT/SetVector.h"
#include "llvm/Support/FormatVariadic.h"
#include "llvm/Support/Path.h"
@@ -22,43 +23,9 @@ using namespace llvm;
using namespace mlir;
using namespace mlir::tblgen;
-StaticVerifierFunctionEmitter::StaticVerifierFunctionEmitter(
- const llvm::RecordKeeper &records)
- : uniqueOutputLabel(getUniqueName(records)) {}
-
-StaticVerifierFunctionEmitter &
-StaticVerifierFunctionEmitter::setSelf(StringRef str) {
- fctx.withSelf(str);
- return *this;
-}
-
-StaticVerifierFunctionEmitter &
-StaticVerifierFunctionEmitter::setBuilder(StringRef str) {
- fctx.withBuilder(str);
- return *this;
-}
-
-void StaticVerifierFunctionEmitter::emitConstraintMethodsInNamespace(
- StringRef signatureFormat, StringRef errorHandlerFormat,
- StringRef cppNamespace, ArrayRef<const void *> constraints, raw_ostream &os,
- bool emitDecl) {
- llvm::Optional<NamespaceEmitter> namespaceEmitter;
- if (!emitDecl)
- namespaceEmitter.emplace(os, cppNamespace);
-
- emitConstraintMethods(signatureFormat, errorHandlerFormat, constraints, os,
- emitDecl);
-}
-
-StringRef StaticVerifierFunctionEmitter::getConstraintFn(
- const Constraint &constraint) const {
- auto it = localTypeConstraints.find(constraint.getAsOpaquePointer());
- assert(it != localTypeConstraints.end() && "expected valid constraint fn");
- return it->second;
-}
-
-std::string StaticVerifierFunctionEmitter::getUniqueName(
- const llvm::RecordKeeper &records) {
+/// Generate a unique label based on the current file name to prevent name
+/// collisions if multiple generated files are included at once.
+static std::string getUniqueOutputLabel(const llvm::RecordKeeper &records) {
// Use the input file name when generating a unique name.
std::string inputFilename = records.getInputFilename();
@@ -77,66 +44,306 @@ std::string StaticVerifierFunctionEmitter::getUniqueName(
return uniqueName;
}
-void StaticVerifierFunctionEmitter::emitConstraintMethods(
- StringRef signatureFormat, StringRef errorHandlerFormat,
- ArrayRef<const void *> constraints, raw_ostream &rawOs, bool emitDecl) {
- raw_indented_ostream os(rawOs);
-
- // Record the mapping from predicate to constraint. If two constraints has the
- // same predicate and constraint summary, they can share the same verification
- // function.
- llvm::DenseMap<Pred, const void *> predToConstraint;
- for (auto it : llvm::enumerate(constraints)) {
- std::string name;
- Constraint constraint = Constraint::getFromOpaquePointer(it.value());
- Pred pred = constraint.getPredicate();
- auto iter = predToConstraint.find(pred);
- if (iter != predToConstraint.end()) {
- do {
- Constraint built = Constraint::getFromOpaquePointer(iter->second);
- // We may have the
diff erent constraints but have the same predicate,
- // for example, ConstraintA and Variadic<ConstraintA>, note that
- // Variadic<> doesn't introduce new predicate. In this case, we can
- // share the same predicate function if they also have consistent
- // summary, otherwise we may report the wrong message while verification
- // fails.
- if (constraint.getSummary() == built.getSummary()) {
- name = getConstraintFn(built).str();
- break;
- }
- ++iter;
- } while (iter != predToConstraint.end() && iter->first == pred);
- }
+StaticVerifierFunctionEmitter::StaticVerifierFunctionEmitter(
+    raw_ostream &os, const llvm::RecordKeeper &records)
+    : os(os), uniqueOutputLabel(getUniqueOutputLabel(records)) {}
+
+/// Collect the type, attribute, successor, and region constraints used by the
+/// ops in `opDefs` and, unless only declarations are requested, write one
+/// static verifier function per unique constraint into the C++ namespace of
+/// the first op. Constraints are collected even when `emitDecl` is set so that
+/// the constraint-function getters can still resolve names.
+void StaticVerifierFunctionEmitter::emitOpConstraints(
+    ArrayRef<llvm::Record *> opDefs, bool emitDecl) {
+  collectOpConstraints(opDefs);
+  // Nothing is emitted for declarations. With no ops there is also no
+  // namespace to emit into, and `opDefs[0]` below would be out of bounds.
+  if (emitDecl || opDefs.empty())
+    return;
+
+  NamespaceEmitter namespaceEmitter(os, Operator(*opDefs[0]).getCppNamespace());
+  emitTypeConstraints();
+  emitAttrConstraints();
+  emitSuccessorConstraints();
+  emitRegionConstraints();
+}
+
+/// Collect the type and attribute constraints referenced by the DRR leaves in
+/// `constraints` and emit a rewriter-based matcher function for each unique
+/// one.
+void StaticVerifierFunctionEmitter::emitPatternConstraints(
+    const DenseSet<DagLeaf> &constraints) {
+  collectPatternConstraints(constraints);
+  emitPatternConstraints();
+}
+
+//===----------------------------------------------------------------------===//
+// Constraint Getters
+
+/// Return the name of the static verifier function generated for the type
+/// constraint `constraint`. The constraint must already have been collected.
+StringRef StaticVerifierFunctionEmitter::getTypeConstraintFn(
+    const Constraint &constraint) const {
+  auto found = typeConstraints.find(constraint);
+  assert(found != typeConstraints.end() &&
+         "expected to find a type constraint");
+  return found->second;
+}
+
+/// Return the name of the static verifier function generated for the
+/// attribute constraint `constraint`. Not every attribute constraint can be
+/// uniqued, so an empty Optional is returned when none was generated.
+Optional<StringRef> StaticVerifierFunctionEmitter::getAttrConstraintFn(
+    const Constraint &constraint) const {
+  auto found = attrConstraints.find(constraint);
+  if (found == attrConstraints.end())
+    return Optional<StringRef>();
+  return StringRef(found->second);
+}
+
+/// Return the name of the static verifier function generated for the
+/// successor constraint `constraint`. The constraint must already have been
+/// collected.
+StringRef StaticVerifierFunctionEmitter::getSuccessorConstraintFn(
+    const Constraint &constraint) const {
+  auto it = successorConstraints.find(constraint);
+  assert(it != successorConstraints.end() &&
+         "expected to find a successor constraint");
+  return it->second;
+}
+
+/// Return the name of the static verifier function generated for the region
+/// constraint `constraint`. The constraint must already have been collected.
+StringRef StaticVerifierFunctionEmitter::getRegionConstraintFn(
+    const Constraint &constraint) const {
+  auto it = regionConstraints.find(constraint);
+  assert(it != regionConstraints.end() &&
+         "expected to find a region constraint");
+  return it->second;
+}
+
+//===----------------------------------------------------------------------===//
+// Constraint Emission
+
+/// Code templates for emitting type, attribute, successor, and region
+/// constraints. Each of these templates requires the following arguments:
+///
+/// {0}: The unique constraint name.
+/// {1}: The constraint code.
+/// {2}: The constraint description.
+
+/// Code for a type constraint. These may be called on the type of either
+/// operands or results.
+///
+/// The generated function receives the op, the type to check, and a
+/// `valueKind`/`valueIndex` pair that is only used to build the error message
+/// (`<valueKind> #<valueIndex> must be ...`).
+static const char *const typeConstraintCode = R"(
+static ::mlir::LogicalResult {0}(
+ ::mlir::Operation *op, ::mlir::Type type, ::llvm::StringRef valueKind,
+ unsigned valueIndex) {
+ if (!({1})) {
+ return op->emitOpError(valueKind) << " #" << valueIndex
+ << " must be {2}, but got " << type;
+ }
+ return ::mlir::success();
+}
+)";
+
+/// Code for an attribute constraint. These may be called from ops only.
+/// Only attribute constraints whose condition references nothing other than
+/// `$_self` and `$_op` are uniqued with this template (see
+/// `canUniqueAttrConstraint`).
+///
+/// A null `attr` is accepted: the condition is only evaluated when the
+/// attribute is present, and presence of required attributes is checked
+/// separately.
+///
+/// TODO: Unique constraints for adaptors. However, most Adaptor::verify
+/// functions are stripped anyway.
+static const char *const attrConstraintCode = R"(
+static ::mlir::LogicalResult {0}(
+ ::mlir::Operation *op, ::mlir::Attribute attr, ::llvm::StringRef attrName) {
+ if (attr && !({1})) {
+ return op->emitOpError("attribute '") << attrName
+ << "' failed to satisfy constraint: {2}";
+ }
+ return ::mlir::success();
+}
+)";
- if (!name.empty()) {
- localTypeConstraints.try_emplace(it.value(), name);
- continue;
+/// Code for a successor constraint. The generated function receives the op,
+/// the successor block, and the successor's name and index, which are only
+/// used for the error message (`successor #<index> ('<name>') failed ...`).
+static const char *const successorConstraintCode = R"(
+static ::mlir::LogicalResult {0}(
+    ::mlir::Operation *op, ::mlir::Block *successor,
+    ::llvm::StringRef successorName, unsigned successorIndex) {
+  if (!({1})) {
+    return op->emitOpError("successor #") << successorIndex << " ('"
+        << successorName << "') failed to verify constraint: {2}";
+  }
+  return ::mlir::success();
+}
+)";
+
+/// Code for a region constraint. Callers will need to pass in the region's name
+/// and index for emitting an error message; an empty name is left out of the
+/// message.
+static const char *const regionConstraintCode = R"(
+static ::mlir::LogicalResult {0}(
+    ::mlir::Operation *op, ::mlir::Region &region, ::llvm::StringRef regionName,
+    unsigned regionIndex) {
+  if (!({1})) {
+    return op->emitOpError("region #") << regionIndex
+        << (regionName.empty() ? " " : " ('" + regionName + "') ")
+        << "failed to verify constraint: {2}";
+  }
+  return ::mlir::success();
+}
+)";
+
+/// Code for a pattern type or attribute constraint, used by DRR static
+/// matchers. A failing condition is reported through `notifyMatchFailure`,
+/// with the caller-supplied `failureStr` prefixed to the constraint summary.
+///
+/// {3}: "Type type" or "Attribute attr". The declared parameter name must match
+///      what `$_self` expands to when formatting {1}.
+static const char *const patternAttrOrTypeConstraintCode = R"(
+static ::mlir::LogicalResult {0}(
+ ::mlir::PatternRewriter &rewriter, ::mlir::Operation *op, ::mlir::{3},
+ ::llvm::StringRef failureStr) {
+ if (!({1})) {
+ return rewriter.notifyMatchFailure(op, [&](::mlir::Diagnostic &diag) {
+ diag << failureStr << ": {2}";
+ });
+ }
+ return ::mlir::success();
+}
+)";
+
+/// Emit one static verifier function per entry of `constraints` using
+/// `codeTemplate`. In each constraint's condition, `$_self` expands to
+/// `selfName` (the checked parameter declared by the template) and `$_op`
+/// expands to `*op`.
+void StaticVerifierFunctionEmitter::emitConstraints(
+    const ConstraintMap &constraints, StringRef selfName,
+    const char *const codeTemplate) {
+  FmtContext ctx;
+  ctx.withOp("*op").withSelf(selfName);
+  for (auto &it : constraints) {
+    // The summary is pasted into a C++ string literal in the generated code,
+    // so quotes and backslashes in it must be escaped.
+    os << formatv(codeTemplate, it.second,
+                  tgfmt(it.first.getConditionTemplate(), &ctx),
+                  escapeString(it.first.getSummary()));
+  }
+}
+
+/// Emit verifier functions for the uniqued operand/result type constraints;
+/// `$_self` names the template's `type` parameter.
+void StaticVerifierFunctionEmitter::emitTypeConstraints() {
+  const StringRef selfName = "type";
+  emitConstraints(typeConstraints, selfName, typeConstraintCode);
+}
+
+/// Emit verifier functions for the uniqued attribute constraints; `$_self`
+/// names the template's `attr` parameter.
+void StaticVerifierFunctionEmitter::emitAttrConstraints() {
+  const StringRef selfName = "attr";
+  emitConstraints(attrConstraints, selfName, attrConstraintCode);
+}
+
+/// Emit verifier functions for the uniqued successor constraints; `$_self`
+/// names the template's `successor` parameter.
+void StaticVerifierFunctionEmitter::emitSuccessorConstraints() {
+  const StringRef selfName = "successor";
+  emitConstraints(successorConstraints, selfName, successorConstraintCode);
+}
+
+/// Emit verifier functions for the uniqued region constraints; `$_self`
+/// names the template's `region` parameter.
+void StaticVerifierFunctionEmitter::emitRegionConstraints() {
+  const StringRef selfName = "region";
+  emitConstraints(regionConstraints, selfName, regionConstraintCode);
+}
+
+/// Emit a `PatternRewriter`-based matcher function for every uniqued type and
+/// attribute constraint. `$_op` expands to `*op` and `$_builder` to `rewriter`.
+void StaticVerifierFunctionEmitter::emitPatternConstraints() {
+  FmtContext ctx;
+  ctx.withOp("*op").withBuilder("rewriter");
+  // `selfName` must match the parameter name declared in `paramDecl`.
+  const auto emitAll = [&](const ConstraintMap &constraints,
+                           StringRef selfName, StringRef paramDecl) {
+    ctx.withSelf(selfName);
+    for (auto &it : constraints) {
+      // The summary lands inside a C++ string literal; escape it.
+      os << formatv(patternAttrOrTypeConstraintCode, it.second,
+                    tgfmt(it.first.getConditionTemplate(), &ctx),
+                    escapeString(it.first.getSummary()), paramDecl);
+    }
+  };
+  emitAll(typeConstraints, "type", "Type type");
+  emitAll(attrConstraints, "attr", "Attribute attr");
+}
+
+//===----------------------------------------------------------------------===//
+// Constraint Uniquing
+
+using RecordDenseMapInfo = llvm::DenseMapInfo<const llvm::Record *>;
+
+/// The sentinel keys wrap the `Record *` sentinels of `RecordDenseMapInfo`.
+/// The hash and equality functions below never query their predicate or
+/// summary.
+Constraint StaticVerifierFunctionEmitter::ConstraintUniquer::getEmptyKey() {
+  const llvm::Record *emptyRecord = RecordDenseMapInfo::getEmptyKey();
+  return Constraint(emptyRecord, Constraint::CK_Uncategorized);
+}
+
+Constraint StaticVerifierFunctionEmitter::ConstraintUniquer::getTombstoneKey() {
+  const llvm::Record *tombstoneRecord = RecordDenseMapInfo::getTombstoneKey();
+  return Constraint(tombstoneRecord, Constraint::CK_Uncategorized);
+}
+
+/// Hash non-sentinel constraints by predicate and summary so that distinct
+/// records with identical predicate and summary share one verifier function.
+unsigned StaticVerifierFunctionEmitter::ConstraintUniquer::getHashValue(
+    Constraint constraint) {
+  const bool isEmpty = constraint == getEmptyKey();
+  if (isEmpty || constraint == getTombstoneKey()) {
+    const llvm::Record *sentinel = isEmpty
+                                       ? RecordDenseMapInfo::getEmptyKey()
+                                       : RecordDenseMapInfo::getTombstoneKey();
+    return RecordDenseMapInfo::getHashValue(sentinel);
+  }
+  return llvm::hash_combine(constraint.getPredicate(), constraint.getSummary());
+}
+
+/// Constraints are equal when they are the same constraint, or when neither is
+/// a sentinel and both have the same predicate and summary.
+bool StaticVerifierFunctionEmitter::ConstraintUniquer::isEqual(Constraint lhs,
+                                                               Constraint rhs) {
+  if (lhs == rhs)
+    return true;
+  // Taken by value: `Constraint::operator==` may not be const-qualified.
+  const auto isSentinel = [](Constraint c) {
+    return c == getEmptyKey() || c == getTombstoneKey();
+  };
+  if (isSentinel(lhs) || isSentinel(rhs))
+    return false;
+  return lhs.getPredicate() == rhs.getPredicate() &&
+         lhs.getSummary() == rhs.getSummary();
+}
+
+/// Return whether `attr`'s constraint can be extracted into a standalone
+/// function. Its condition may reference only `$_self` and `$_op`. Operands
+/// and results would need calls to `getODSOperands` or `getODSResults`, and
+/// references to other attributes are tricky because ops use cached
+/// identifiers. Any other placeholder shows up as `<no-subst-found>` after
+/// formatting, which is what this check looks for.
+static bool canUniqueAttrConstraint(Attribute attr) {
+  FmtContext ctx;
+  ctx.withSelf("attr").withOp("*op");
+  const std::string rendered =
+      tgfmt(attr.getConditionTemplate(), &ctx).str();
+  return StringRef(rendered).find("<no-subst-found>") == StringRef::npos;
+}
+
+/// Build an obscure, file-unique function name of the form
+/// `__mlir_ods_local_<kind>_constraint_<label><index>`.
+std::string StaticVerifierFunctionEmitter::getUniqueName(StringRef kind,
+                                                         unsigned index) {
+  return formatv("__mlir_ods_local_{0}_constraint_{1}{2}", kind,
+                 uniqueOutputLabel, index)
+      .str();
+}
+
+/// Register `constraint` in `map` under a fresh name unless an equivalent
+/// constraint (per `ConstraintUniquer`) is already present.
+void StaticVerifierFunctionEmitter::collectConstraint(ConstraintMap &map,
+                                                      StringRef kind,
+                                                      Constraint constraint) {
+  if (map.find(constraint) != map.end())
+    return;
+  map.insert({constraint, getUniqueName(kind, map.size())});
+}
+
+/// Collect the uniqueable constraints of every op in `opDefs`:
+/// - operand and result type constraints that have a predicate;
+/// - attribute constraints that have a predicate and pass
+///   `canUniqueAttrConstraint`;
+/// - successor and region constraints that have a predicate.
+void StaticVerifierFunctionEmitter::collectOpConstraints(
+    ArrayRef<Record *> opDefs) {
+  const auto collectTypeConstraints = [&](Operator::value_range values) {
+    for (const NamedTypeConstraint &value : values)
+      if (value.hasPredicate())
+        collectConstraint(typeConstraints, "type", value.constraint);
+  };
+
+  for (Record *def : opDefs) {
+    Operator op(*def);
+    // Collect type constraints.
+    collectTypeConstraints(op.getOperands());
+    collectTypeConstraints(op.getResults());
+    // Collect attribute constraints. Those that cannot be uniqued are
+    // verified inline by the op verifier instead.
+    for (const NamedAttribute &namedAttr : op.getAttributes()) {
+      if (!namedAttr.attr.getPredicate().isNull() &&
+          canUniqueAttrConstraint(namedAttr.attr))
+        collectConstraint(attrConstraints, "attr", namedAttr.attr);
+    }
+    // Collect successor constraints.
+    for (const NamedSuccessor &successor : op.getSuccessors()) {
+      if (!successor.constraint.getPredicate().isNull()) {
+        collectConstraint(successorConstraints, "successor",
+                          successor.constraint);
+      }
}
+    // Collect region constraints.
+    for (const NamedRegion &region : op.getRegions())
+      if (!region.constraint.getPredicate().isNull())
+        collectConstraint(regionConstraints, "region", region.constraint);
+  }
+}
- // Generate an obscure and unique name for this type constraint.
- name = (Twine("__mlir_ods_local_type_constraint_") + uniqueOutputLabel +
- Twine(it.index()))
- .str();
- predToConstraint.insert(
- std::make_pair(constraint.getPredicate(), it.value()));
- localTypeConstraints.try_emplace(it.value(), name);
-
- // Only generate the methods if we are generating definitions.
- if (emitDecl)
- continue;
-
- os << formatv(signatureFormat.data(), name) << " {\n";
- os.indent() << "if (!(" << tgfmt(constraint.getConditionTemplate(), &fctx)
- << ")) {\n";
- os.indent() << "return "
- << formatv(errorHandlerFormat.data(),
- escapeString(constraint.getSummary()))
- << ";\n";
- os.unindent() << "}\nreturn ::mlir::success();\n";
- os.unindent() << "}\n\n";
+/// Collect the type constraints of operand matchers and the attribute
+/// constraints of attribute matchers among the DRR leaves in `constraints`.
+void StaticVerifierFunctionEmitter::collectPatternConstraints(
+    const DenseSet<DagLeaf> &constraints) {
+  for (const DagLeaf &leaf : constraints) {
+    assert(leaf.isOperandMatcher() || leaf.isAttrMatcher());
+    if (leaf.isOperandMatcher())
+      collectConstraint(typeConstraints, "type", leaf.getAsConstraint());
+    else
+      collectConstraint(attrConstraints, "attr", leaf.getAsConstraint());
}
}
+//===----------------------------------------------------------------------===//
+// Public Utility Functions
+//===----------------------------------------------------------------------===//
+
std::string mlir::tblgen::escapeString(StringRef value) {
std::string ret;
llvm::raw_string_ostream os(ret);
diff --git a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
index 8bfebb060629d..b60d8a2c98dd8 100644
--- a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
+++ b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
@@ -127,14 +127,6 @@ static const char *const valueRangeReturnCode = R"(
std::next({0}, valueRange.first + valueRange.second)};
)";
-static const char *const typeVerifierSignature =
- "static ::mlir::LogicalResult {0}(::mlir::Operation *op, ::mlir::Type "
- "type, ::llvm::StringRef valueKind, unsigned valueGroupStartIndex)";
-
-static const char *const typeVerifierErrorHandler =
- " op->emitOpError(valueKind) << \" #\" << valueGroupStartIndex << \" must "
- "be {0}, but got \" << type";
-
static const char *const opCommentHeader = R"(
//===----------------------------------------------------------------------===//
// {0} {1}
@@ -477,29 +469,42 @@ static void populateSubstitutions(const OpOrAdaptorHelper &emitHelper,
// Generate attribute verification. If an op instance is not available, then
// attribute checks that require one will not be emitted.
-static void genAttributeVerifier(const OpOrAdaptorHelper &emitHelper,
- FmtContext &ctx, OpMethodBody &body) {
+static void genAttributeVerifier(
+ const OpOrAdaptorHelper &emitHelper, FmtContext &ctx, OpMethodBody &body,
+ const StaticVerifierFunctionEmitter &staticVerifierEmitter) {
// Check that a required attribute exists.
//
// {0}: Attribute variable name.
// {1}: Emit error prefix.
// {2}: Attribute name.
- const char *const checkRequiredAttr = R"(
+ const char *const verifyRequiredAttr = R"(
if (!{0})
return {1}"requires attribute '{2}'");
- )";
- // Check the condition on an attribute if it is required. This assumes that
- // default values are valid.
+)";
+ // Verify the attribute if it is present. This assumes that default values
+ // are valid. This code snippet pastes the condition inline.
+ //
// TODO: verify the default value is valid (perhaps in debug mode only).
//
// {0}: Attribute variable name.
// {1}: Attribute condition code.
// {2}: Emit error prefix.
- // {3}: Attribute/constraint description.
- const char *const checkAttrCondition = R"(
+ // {3}: Attribute name.
+ // {4}: Attribute/constraint description.
+ const char *const verifyAttrInline = R"(
if ({0} && !({1}))
return {2}"attribute '{3}' failed to satisfy constraint: {4}");
- )";
+)";
+ // Verify the attribute using a uniqued constraint. Can only be used within
+ // the context of an op.
+ //
+ // {0}: Unique constraint name.
+ // {1}: Attribute variable name.
+ // {2}: Attribute name.
+ const char *const verifyAttrUnique = R"(
+ if (::mlir::failed({0}(*this, {1}, "{2}")))
+ return ::mlir::failure();
+)";
for (const auto &namedAttr : emitHelper.getOp().getAttributes()) {
const auto &attr = namedAttr.attr;
@@ -513,7 +518,8 @@ static void genAttributeVerifier(const OpOrAdaptorHelper &emitHelper,
// If the attribute's condition needs an op but none is available, then the
// condition cannot be emitted.
bool canEmitCondition =
- !StringRef(condition).contains("$_op") || emitHelper.isEmittingForOp();
+ !condition.empty() && (!StringRef(condition).contains("$_op") ||
+ emitHelper.isEmittingForOp());
// Prefix with `tblgen_` to avoid hiding the attribute accessor.
Twine varName = tblgenNamePrefix + attrName;
@@ -527,16 +533,22 @@ static void genAttributeVerifier(const OpOrAdaptorHelper &emitHelper,
emitHelper.getAttr(attrName));
if (!allowMissingAttr) {
- body << formatv(checkRequiredAttr, varName, emitHelper.emitErrorPrefix(),
+ body << formatv(verifyRequiredAttr, varName, emitHelper.emitErrorPrefix(),
attrName);
}
if (canEmitCondition) {
- body << formatv(checkAttrCondition, varName,
- tgfmt(condition, &ctx.withSelf(varName)),
- emitHelper.emitErrorPrefix(), attrName,
- escapeString(attr.getSummary()));
+ Optional<StringRef> constraintFn;
+ if (emitHelper.isEmittingForOp() &&
+ (constraintFn = staticVerifierEmitter.getAttrConstraintFn(attr))) {
+ body << formatv(verifyAttrUnique, *constraintFn, varName, attrName);
+ } else {
+ body << formatv(verifyAttrInline, varName,
+ tgfmt(condition, &ctx.withSelf(varName)),
+ emitHelper.emitErrorPrefix(), attrName,
+ escapeString(attr.getSummary()));
+ }
}
- body << "}\n";
+ body << " }\n";
}
}
@@ -2209,7 +2221,7 @@ void OpEmitter::genVerifier() {
bool hasCustomVerify = stringInit && !stringInit->getValue().empty();
populateSubstitutions(emitHelper, verifyCtx);
- genAttributeVerifier(emitHelper, verifyCtx, body);
+ genAttributeVerifier(emitHelper, verifyCtx, body, staticVerifierEmitter);
genOperandResultVerifier(body, op.getOperands(), "operand");
genOperandResultVerifier(body, op.getResults(), "result");
@@ -2238,10 +2250,38 @@ void OpEmitter::genVerifier() {
void OpEmitter::genOperandResultVerifier(OpMethodBody &body,
Operator::value_range values,
StringRef valueKind) {
+ // Check that an optional value is at most 1 element.
+ //
+ // {0}: Value index.
+ // {1}: "operand" or "result"
+ const char *const verifyOptional = R"(
+ if (valueGroup{0}.size() > 1) {
+ return emitOpError("{1} group starting at #") << index
+ << " requires 0 or 1 element, but found " << valueGroup{0}.size();
+ }
+)";
+ // Check the types of a range of values.
+ //
+ // {0}: Value index.
+ // {1}: Type constraint function.
+ // {2}: "operand" or "result"
+ const char *const verifyValues = R"(
+ for (auto v : valueGroup{0}) {
+ if (::mlir::failed({1}(*this, v.getType(), "{2}", index++)))
+ return ::mlir::failure();
+ }
+)";
+
+ const auto canSkip = [](const NamedTypeConstraint &value) {
+ return !value.hasPredicate() && !value.isOptional() &&
+ !value.isVariadicOfVariadic();
+ };
+ if (values.empty() || llvm::all_of(values, canSkip))
+ return;
+
FmtContext fctx;
- body << " {\n";
- body << " unsigned index = 0; (void)index;\n";
+ body << " {\n unsigned index = 0; (void)index;\n";
for (auto staticValue : llvm::enumerate(values)) {
const NamedTypeConstraint &value = staticValue.value();
@@ -2259,11 +2299,7 @@ void OpEmitter::genOperandResultVerifier(OpMethodBody &body,
// If the constraint is optional check that the value group has at most 1
// value.
if (isOptional) {
- body << formatv(" if (valueGroup{0}.size() > 1)\n"
- " return emitOpError(\"{1} group starting at #\") "
- "<< index << \" requires 0 or 1 element, but found \" << "
- "valueGroup{0}.size();\n",
- staticValue.index(), valueKind);
+ body << formatv(verifyOptional, staticValue.index(), valueKind);
} else if (isVariadicOfVariadic) {
body << formatv(
" if (::mlir::failed(::mlir::OpTrait::impl::verifyValueSizeAttr("
@@ -2278,93 +2314,89 @@ void OpEmitter::genOperandResultVerifier(OpMethodBody &body,
continue;
// Emit a loop to check all the dynamic values in the pack.
StringRef constraintFn =
- staticVerifierEmitter.getConstraintFn(value.constraint);
- body << " for (::mlir::Value v : valueGroup" << staticValue.index()
- << ") {\n"
- << " if (::mlir::failed(" << constraintFn
- << "(getOperation(), v.getType(), \"" << valueKind << "\", index)))\n"
- << " return ::mlir::failure();\n"
- << " ++index;\n"
- << " }\n";
+ staticVerifierEmitter.getTypeConstraintFn(value.constraint);
+ body << formatv(verifyValues, staticValue.index(), constraintFn, valueKind);
}
body << " }\n";
}
+/// Emit verification of the op's region constraints. Every region whose
+/// constraint has a predicate is checked by calling its uniqued region
+/// constraint function. Nothing is emitted when no region has a predicate.
void OpEmitter::genRegionVerifier(OpMethodBody &body) {
+ /// Code to verify a region.
+ ///
+ /// {0}: Getter for the regions.
+ /// {1}: The region constraint function.
+ /// {2}: The region's name.
+ const char *const verifyRegion = R"(
+ for (auto &region : {0})
+ if (::mlir::failed({1}(*this, region, "{2}", index++)))
+ return ::mlir::failure();
+)";
+ /// Get a single region.
+ ///
+ /// {0}: The region's index.
+ const char *const getSingleRegion =
+ "::llvm::makeMutableArrayRef((*this)->getRegion({0}))";
+
// If we have no regions, there is nothing more to do.
- unsigned numRegions = op.getNumRegions();
- if (numRegions == 0)
+ const auto canSkip = [](const NamedRegion &region) {
+ return region.constraint.getPredicate().isNull();
+ };
+ auto regions = op.getRegions();
+ // Also skip when no region carries a constraint, so that no empty
+ // verification scope is emitted.
+ if (regions.empty() || llvm::all_of(regions, canSkip))
return;
- body << "{\n";
- body << " unsigned index = 0; (void)index;\n";
-
- for (unsigned i = 0; i < numRegions; ++i) {
- const auto &region = op.getRegion(i);
- if (region.constraint.getPredicate().isNull())
+ body << " {\n unsigned index = 0; (void)index;\n";
+ for (auto it : llvm::enumerate(regions)) {
+ const auto &region = it.value();
+ if (canSkip(region))
continue;
- body << " for (::mlir::Region &region : ";
- body << formatv(region.isVariadic()
- ? "{0}()"
- : "::mlir::MutableArrayRef<::mlir::Region>((*this)"
- "->getRegion({1}))",
- op.getGetterName(region.name), i);
- body << ") {\n";
- auto constraint = tgfmt(region.constraint.getConditionTemplate(),
- &verifyCtx.withSelf("region"))
- .str();
-
- body << formatv(" (void)region;\n"
- " if (!({0})) {\n "
- "return emitOpError(\"region #\") << index << \" {1}"
- "failed to "
- "verify constraint: {2}\";\n }\n",
- constraint,
- region.name.empty() ? "" : "('" + region.name + "') ",
- region.constraint.getSummary())
- << " ++index;\n"
- << " }\n";
+ auto getRegion = region.isVariadic()
+ ? formatv("{0}()", op.getGetterName(region.name)).str()
+ : formatv(getSingleRegion, it.index()).str();
+ auto constraintFn =
+ staticVerifierEmitter.getRegionConstraintFn(region.constraint);
+ body << formatv(verifyRegion, getRegion, constraintFn, region.name);
}
body << " }\n";
}
+/// Emit verification of the op's successor constraints. Every successor whose
+/// constraint has a predicate is checked by calling its uniqued successor
+/// constraint function. Nothing is emitted when no successor has a predicate.
void OpEmitter::genSuccessorVerifier(OpMethodBody &body) {
+ /// Code to verify a successor.
+ ///
+ /// {0}: Range of successors to verify.
+ /// {1}: The successor constraint function.
+ /// {2}: The successor's name.
+ const char *const verifySuccessor = R"(
+ for (auto *successor : {0})
+ if (::mlir::failed({1}(*this, successor, "{2}", index++)))
+ return ::mlir::failure();
+)";
+ /// Get a single successor.
+ ///
+ /// {0}: The successor's name.
+ ///
+ /// NOTE(review): `{0}()` yields a prvalue `Block *`, while the
+ /// single-element `makeMutableArrayRef` overload takes an lvalue reference.
+ /// Confirm this compiles for a constrained non-variadic successor.
+ const char *const getSingleSuccessor = "::llvm::makeMutableArrayRef({0}())";
+
// If we have no successors, there is nothing more to do.
- unsigned numSuccessors = op.getNumSuccessors();
- if (numSuccessors == 0)
+ const auto canSkip = [](const NamedSuccessor &successor) {
+ return successor.constraint.getPredicate().isNull();
+ };
+ auto successors = op.getSuccessors();
+ // Also skip when no successor carries a constraint, so that no empty
+ // verification scope is emitted.
+ if (successors.empty() || llvm::all_of(successors, canSkip))
return;
- body << "{\n";
- body << " unsigned index = 0; (void)index;\n";
+ body << " {\n unsigned index = 0; (void)index;\n";
- for (unsigned i = 0; i < numSuccessors; ++i) {
- const auto &successor = op.getSuccessor(i);
- if (successor.constraint.getPredicate().isNull())
+ for (const NamedSuccessor &successor : successors) {
+ if (canSkip(successor))
continue;
- if (successor.isVariadic()) {
- body << formatv(" for (::mlir::Block *successor : {0}()) {\n",
- successor.name);
- } else {
- body << " {\n";
- body << formatv(" ::mlir::Block *successor = {0}();\n",
- successor.name);
- }
- auto constraint = tgfmt(successor.constraint.getConditionTemplate(),
- &verifyCtx.withSelf("successor"))
- .str();
-
- body << formatv(" (void)successor;\n"
- " if (!({0})) {\n "
- "return emitOpError(\"successor #\") << index << \"('{1}') "
- "failed to "
- "verify constraint: {2}\";\n }\n",
- constraint, successor.name,
- successor.constraint.getSummary())
- << " ++index;\n"
- << " }\n";
+ auto getSuccessor =
+ formatv(successor.isVariadic() ? "{0}()" : getSingleSuccessor,
+ successor.name)
+ .str();
+ auto constraintFn =
+ staticVerifierEmitter.getSuccessorConstraintFn(successor.constraint);
+ body << formatv(verifySuccessor, getSuccessor, constraintFn,
+ successor.name);
}
body << " }\n";
}
@@ -2504,11 +2536,16 @@ namespace {
// getters identical to those defined in the Op.
class OpOperandAdaptorEmitter {
public:
- static void emitDecl(const Operator &op, raw_ostream &os);
- static void emitDef(const Operator &op, raw_ostream &os);
+ static void emitDecl(const Operator &op,
+ StaticVerifierFunctionEmitter &staticVerifierEmitter,
+ raw_ostream &os);
+ static void emitDef(const Operator &op,
+ StaticVerifierFunctionEmitter &staticVerifierEmitter,
+ raw_ostream &os);
private:
- explicit OpOperandAdaptorEmitter(const Operator &op);
+ explicit OpOperandAdaptorEmitter(
+ const Operator &op, StaticVerifierFunctionEmitter &staticVerifierEmitter);
// Add verification function. This generates a verify method for the adaptor
// which verifies all the op-independent attribute constraints.
@@ -2516,11 +2553,14 @@ class OpOperandAdaptorEmitter {
const Operator &op;
Class adaptor;
+ StaticVerifierFunctionEmitter &staticVerifierEmitter;
};
} // end anonymous namespace
-OpOperandAdaptorEmitter::OpOperandAdaptorEmitter(const Operator &op)
- : op(op), adaptor(op.getAdaptorName()) {
+OpOperandAdaptorEmitter::OpOperandAdaptorEmitter(
+ const Operator &op, StaticVerifierFunctionEmitter &staticVerifierEmitter)
+ : op(op), adaptor(op.getAdaptorName()),
+ staticVerifierEmitter(staticVerifierEmitter) {
adaptor.newField("::mlir::ValueRange", "odsOperands");
adaptor.newField("::mlir::DictionaryAttr", "odsAttrs");
adaptor.newField("::mlir::RegionRange", "odsRegions");
@@ -2644,17 +2684,21 @@ void OpOperandAdaptorEmitter::addVerification() {
FmtContext verifyCtx;
populateSubstitutions(emitHelper, verifyCtx);
- genAttributeVerifier(emitHelper, verifyCtx, body);
+ genAttributeVerifier(emitHelper, verifyCtx, body, staticVerifierEmitter);
body << " return ::mlir::success();";
}
-void OpOperandAdaptorEmitter::emitDecl(const Operator &op, raw_ostream &os) {
- OpOperandAdaptorEmitter(op).adaptor.writeDeclTo(os);
+/// Write the declaration of `op`'s adaptor class to `os`. The static verifier
+/// emitter is forwarded so the adaptor's attribute verification can use it.
+void OpOperandAdaptorEmitter::emitDecl(
+ const Operator &op, StaticVerifierFunctionEmitter &staticVerifierEmitter,
+ raw_ostream &os) {
+ OpOperandAdaptorEmitter(op, staticVerifierEmitter).adaptor.writeDeclTo(os);
}
-void OpOperandAdaptorEmitter::emitDef(const Operator &op, raw_ostream &os) {
- OpOperandAdaptorEmitter(op).adaptor.writeDefTo(os);
+/// Write the definition of `op`'s adaptor class to `os`. The static verifier
+/// emitter is forwarded so the adaptor's attribute verification can use it.
+void OpOperandAdaptorEmitter::emitDef(
+ const Operator &op, StaticVerifierFunctionEmitter &staticVerifierEmitter,
+ raw_ostream &os) {
+ OpOperandAdaptorEmitter(op, staticVerifierEmitter).adaptor.writeDefTo(os);
}
// Emits the opcode enum and op classes.
@@ -2679,27 +2723,9 @@ static void emitOpClasses(const RecordKeeper &recordKeeper,
return;
// Generate all of the locally instantiated methods first.
- StaticVerifierFunctionEmitter staticVerifierEmitter(recordKeeper);
+ StaticVerifierFunctionEmitter staticVerifierEmitter(os, recordKeeper);
os << formatv(opCommentHeader, "Local Utility Method", "Definitions");
- staticVerifierEmitter.setSelf("type");
-
- // Collect a set of all of the used type constraints within the operation
- // definitions.
- llvm::SetVector<const void *> typeConstraints;
- for (Record *def : defs) {
- Operator op(*def);
- for (NamedTypeConstraint &operand : op.getOperands())
- if (operand.hasPredicate())
- typeConstraints.insert(operand.constraint.getAsOpaquePointer());
- for (NamedTypeConstraint &result : op.getResults())
- if (result.hasPredicate())
- typeConstraints.insert(result.constraint.getAsOpaquePointer());
- }
-
- staticVerifierEmitter.emitConstraintMethodsInNamespace(
- typeVerifierSignature, typeVerifierErrorHandler,
- Operator(*defs[0]).getCppNamespace(), typeConstraints.getArrayRef(), os,
- emitDecl);
+ staticVerifierEmitter.emitOpConstraints(defs, emitDecl);
for (auto *def : defs) {
Operator op(*def);
@@ -2708,7 +2734,7 @@ static void emitOpClasses(const RecordKeeper &recordKeeper,
NamespaceEmitter emitter(os, op.getCppNamespace());
os << formatv(opCommentHeader, op.getQualCppClassName(),
"declarations");
- OpOperandAdaptorEmitter::emitDecl(op, os);
+ OpOperandAdaptorEmitter::emitDecl(op, staticVerifierEmitter, os);
OpEmitter::emitDecl(op, os, staticVerifierEmitter);
}
// Emit the TypeID explicit specialization to have a single definition.
@@ -2719,7 +2745,7 @@ static void emitOpClasses(const RecordKeeper &recordKeeper,
{
NamespaceEmitter emitter(os, op.getCppNamespace());
os << formatv(opCommentHeader, op.getQualCppClassName(), "definitions");
- OpOperandAdaptorEmitter::emitDef(op, os);
+ OpOperandAdaptorEmitter::emitDef(op, staticVerifierEmitter, os);
OpEmitter::emitDef(op, os, staticVerifierEmitter);
}
// Emit the TypeID explicit specialization to have a single definition.
diff --git a/mlir/tools/mlir-tblgen/RewriterGen.cpp b/mlir/tools/mlir-tblgen/RewriterGen.cpp
index 886fd1d07ac45..bc74503f709e9 100644
--- a/mlir/tools/mlir-tblgen/RewriterGen.cpp
+++ b/mlir/tools/mlir-tblgen/RewriterGen.cpp
@@ -42,23 +42,6 @@ using llvm::RecordKeeper;
#define DEBUG_TYPE "mlir-tblgen-rewritergen"
-// The signature of static type verification function
-static const char *typeVerifierSignature =
- "static ::mlir::LogicalResult {0}(::mlir::PatternRewriter &rewriter, "
- "::mlir::Operation *op, ::mlir::Type typeOrAttr, "
- "::llvm::StringRef failureStr)";
-
-// The signature of static attribute verification function
-static const char *attrVerifierSignature =
- "static ::mlir::LogicalResult {0}(::mlir::PatternRewriter &rewriter, "
- "::mlir::Operation *op, ::mlir::Attribute typeOrAttr, "
- "::llvm::StringRef failureStr)";
-
-// The template of error handler in static type/attribute verification function
-static const char *verifierErrorHandler =
- "rewriter.notifyMatchFailure(op, [&](::mlir::Diagnostic &diag) {\n diag "
- "<< failureStr << \": {0}\";\n});";
-
namespace llvm {
template <>
struct format_provider<mlir::tblgen::Pattern::IdentifierLine> {
@@ -273,7 +256,7 @@ class PatternEmitter {
// inlining them.
class StaticMatcherHelper {
public:
- StaticMatcherHelper(const RecordKeeper &recordKeeper,
+ StaticMatcherHelper(raw_ostream &os, const RecordKeeper &recordKeeper,
RecordOperatorMap &mapper);
// Determine if we should inline the match logic or delegate to a static
@@ -289,7 +272,7 @@ class StaticMatcherHelper {
}
// Get the name of static type/attribute verification function.
- StringRef getVerifierName(Constraint constraint);
+ StringRef getVerifierName(DagLeaf leaf);
// Collect the `Record`s, i.e., the DRR, so that we can get the information of
// the duplicated DAGs.
@@ -541,7 +524,7 @@ void PatternEmitter::emitNativeCodeMatch(DagNode tree, StringRef opName,
self = argName;
else
self = formatv("{0}.getType()", argName);
- StringRef verifier = staticMatcherHelper.getVerifierName(constraint);
+ StringRef verifier = staticMatcherHelper.getVerifierName(leaf);
emitStaticVerifierCall(
verifier, opName, self,
formatv("\"operand {0} of native code call '{1}' failed to satisfy "
@@ -684,7 +667,7 @@ void PatternEmitter::emitOperandMatch(DagNode tree, StringRef opName,
PrintFatalError(loc, error);
}
auto self = formatv("(*{0}.begin()).getType()", operandName);
- StringRef verifier = staticMatcherHelper.getVerifierName(constraint);
+ StringRef verifier = staticMatcherHelper.getVerifierName(operandMatcher);
emitStaticVerifierCall(
verifier, opName, self.str(),
formatv(
@@ -809,8 +792,7 @@ void PatternEmitter::emitAttributeMatch(DagNode tree, StringRef opName,
// If a constraint is specified, we need to generate function call to its
// static verifier.
- StringRef verifier =
- staticMatcherHelper.getVerifierName(matcher.getAsConstraint());
+ StringRef verifier = staticMatcherHelper.getVerifierName(matcher);
emitStaticVerifierCall(
verifier, opName, "tblgen_attr",
formatv("\"op '{0}' attribute '{1}' failed to satisfy constraint: "
@@ -1690,9 +1672,10 @@ void PatternEmitter::createAggregateLocalVarsForOpArgs(
}
}
-StaticMatcherHelper::StaticMatcherHelper(const RecordKeeper &recordKeeper,
+StaticMatcherHelper::StaticMatcherHelper(raw_ostream &os,
+ const RecordKeeper &recordKeeper,
RecordOperatorMap &mapper)
- : opMap(mapper), staticVerifierEmitter(recordKeeper) {}
+ : opMap(mapper), staticVerifierEmitter(os, recordKeeper) {}
void StaticMatcherHelper::populateStaticMatchers(raw_ostream &os) {
// PatternEmitter will use the static matcher if there's one generated. To
@@ -1713,28 +1696,7 @@ void StaticMatcherHelper::populateStaticMatchers(raw_ostream &os) {
}
void StaticMatcherHelper::populateStaticConstraintFunctions(raw_ostream &os) {
- llvm::SetVector<const void *> typeConstraints;
- llvm::SetVector<const void *> attrConstraints;
- for (DagLeaf leaf : constraints) {
- if (leaf.isOperandMatcher()) {
- typeConstraints.insert(leaf.getAsConstraint().getAsOpaquePointer());
- } else {
- assert(leaf.isAttrMatcher());
- attrConstraints.insert(leaf.getAsConstraint().getAsOpaquePointer());
- }
- }
-
- staticVerifierEmitter.setBuilder("rewriter").setSelf("typeOrAttr");
-
- staticVerifierEmitter.emitConstraintMethods(typeVerifierSignature,
- verifierErrorHandler,
- typeConstraints.getArrayRef(), os,
- /*emitDecl=*/false);
-
- staticVerifierEmitter.emitConstraintMethods(attrVerifierSignature,
- verifierErrorHandler,
- attrConstraints.getArrayRef(), os,
- /*emitDecl=*/false);
+ staticVerifierEmitter.emitPatternConstraints(constraints);
}
void StaticMatcherHelper::addPattern(Record *record) {
@@ -1765,8 +1727,15 @@ void StaticMatcherHelper::addPattern(Record *record) {
dfs(pat.getSourcePattern());
}
-StringRef StaticMatcherHelper::getVerifierName(Constraint constraint) {
- return staticVerifierEmitter.getConstraintFn(constraint);
+StringRef StaticMatcherHelper::getVerifierName(DagLeaf leaf) {
+ if (leaf.isAttrMatcher()) {
+ Optional<StringRef> constraint =
+ staticVerifierEmitter.getAttrConstraintFn(leaf.getAsConstraint());
+ assert(constraint.hasValue() && "attribute constraint was not uniqued");
+ return *constraint;
+ }
+ assert(leaf.isOperandMatcher());
+ return staticVerifierEmitter.getTypeConstraintFn(leaf.getAsConstraint());
}
static void emitRewriters(const RecordKeeper &recordKeeper, raw_ostream &os) {
@@ -1779,7 +1748,7 @@ static void emitRewriters(const RecordKeeper &recordKeeper, raw_ostream &os) {
// Exam all the patterns and generate static matcher for the duplicated
// DagNode.
- StaticMatcherHelper staticMatcher(recordKeeper, recordOpMap);
+ StaticMatcherHelper staticMatcher(os, recordKeeper, recordOpMap);
for (Record *p : patterns)
staticMatcher.addPattern(p);
staticMatcher.populateStaticConstraintFunctions(os);
More information about the Mlir-commits
mailing list