[Mlir-commits] [mlir] 15a6559 - [mlir][ods] ODS-level Attribute Optimizations
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Mon Apr 11 11:34:11 PDT 2022
Author: Mogball
Date: 2022-04-11T18:34:07Z
New Revision: 15a65594c1c0321671a827582fd66a22e0dbdc6e
URL: https://github.com/llvm/llvm-project/commit/15a65594c1c0321671a827582fd66a22e0dbdc6e
DIFF: https://github.com/llvm/llvm-project/commit/15a65594c1c0321671a827582fd66a22e0dbdc6e.diff
LOG: [mlir][ods] ODS-level Attribute Optimizations
This patch contains several ODS-level optimizations to attribute getters and verifiers.
1. OpAdaptors, when provided a DictionaryAttr, will instantiate an OperationName so that adaptor attribute getters can use cached identifiers.
2. Verifiers will take advantage of attributes stored in sorted order to get all required (non-optional, non-default valued, and non-derived) attributes in one pass over the attribute dictionary and verify that they are present.
3. ODS-generated attribute getters will use "subrange" lookup. Because the attributes are stored in sorted order and ODS knows which attributes are required, the number of required attributes less than and greater than each attribute can be computed. When searching for an attribute, the ends of the search range can be dropped.
Reviewed By: jpienaar
Differential Revision: https://reviews.llvm.org/D122430
Added:
Modified:
mlir/include/mlir/IR/OperationSupport.h
mlir/test/Dialect/LLVMIR/global.mlir
mlir/test/IR/attribute.mlir
mlir/test/mlir-tblgen/constraint-unique.td
mlir/test/mlir-tblgen/op-attribute.td
mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/IR/OperationSupport.h b/mlir/include/mlir/IR/OperationSupport.h
index e67dc63f1db4b..52138b79554eb 100644
--- a/mlir/include/mlir/IR/OperationSupport.h
+++ b/mlir/include/mlir/IR/OperationSupport.h
@@ -428,6 +428,23 @@ std::pair<IteratorT, bool> findAttrSorted(IteratorT first, IteratorT last,
return findAttrUnsorted(first, last, name);
}
+/// Get an attribute from a sorted range of named attributes. Returns null if
+/// the attribute was not found.
+template <typename IteratorT, typename NameT>
+Attribute getAttrFromSortedRange(IteratorT first, IteratorT last, NameT name) {
+ std::pair<IteratorT, bool> result = findAttrSorted(first, last, name);
+ return result.second ? result.first->getValue() : Attribute();
+}
+
+/// Get an attribute from a sorted range of named attributes. Returns None if
+/// the attribute was not found.
+template <typename IteratorT, typename NameT>
+Optional<NamedAttribute>
+getNamedAttrFromSortedRange(IteratorT first, IteratorT last, NameT name) {
+ std::pair<IteratorT, bool> result = findAttrSorted(first, last, name);
+ return result.second ? *result.first : Optional<NamedAttribute>();
+}
+
} // namespace impl
//===----------------------------------------------------------------------===//
@@ -447,7 +464,7 @@ class NamedAttrList {
NamedAttrList() : dictionarySorted({}, true) {}
NamedAttrList(ArrayRef<NamedAttribute> attributes);
NamedAttrList(DictionaryAttr attributes);
- NamedAttrList(const_iterator in_start, const_iterator in_end);
+ NamedAttrList(const_iterator inStart, const_iterator inEnd);
bool operator!=(const NamedAttrList &other) const {
return !(*this == other);
@@ -478,15 +495,15 @@ class NamedAttrList {
typename = std::enable_if_t<std::is_convertible<
typename std::iterator_traits<IteratorT>::iterator_category,
std::input_iterator_tag>::value>>
- void append(IteratorT in_start, IteratorT in_end) {
+ void append(IteratorT inStart, IteratorT inEnd) {
// TODO: expand to handle case where values appended are in order & after
// end of current list.
dictionarySorted.setPointerAndInt(nullptr, false);
- attrs.append(in_start, in_end);
+ attrs.append(inStart, inEnd);
}
/// Replaces the attributes with new list of attributes.
- void assign(const_iterator in_start, const_iterator in_end);
+ void assign(const_iterator inStart, const_iterator inEnd);
/// Replaces the attributes with new list of attributes.
void assign(ArrayRef<NamedAttribute> range) {
diff --git a/mlir/test/Dialect/LLVMIR/global.mlir b/mlir/test/Dialect/LLVMIR/global.mlir
index 07e5ac0f061ad..bb0f6d7704e3c 100644
--- a/mlir/test/Dialect/LLVMIR/global.mlir
+++ b/mlir/test/Dialect/LLVMIR/global.mlir
@@ -81,7 +81,7 @@ llvm.mlir.global internal constant @sectionvar("teststring") {section = ".mysec
// -----
// expected-error @+1 {{op requires attribute 'sym_name'}}
-"llvm.mlir.global"() ({}) {type = i64, constant, global_type = i64, value = 42 : i64} : () -> ()
+"llvm.mlir.global"() ({}) {linkage = "private", type = i64, constant, global_type = i64, value = 42 : i64} : () -> ()
// -----
diff --git a/mlir/test/IR/attribute.mlir b/mlir/test/IR/attribute.mlir
index 5da93acb80ffb..29235df349d25 100644
--- a/mlir/test/IR/attribute.mlir
+++ b/mlir/test/IR/attribute.mlir
@@ -230,7 +230,8 @@ func @wrong_int_attrs_type_fail() {
"test.int_attrs"() {
any_i32_attr = 5.0 : f32,
si32_attr = 7 : si32,
- ui32_attr = 6 : ui32
+ ui32_attr = 6 : ui32,
+ index_attr = 1 : index
} : () -> ()
return
}
diff --git a/mlir/test/mlir-tblgen/constraint-unique.td b/mlir/test/mlir-tblgen/constraint-unique.td
index 9fc60ebb6a4bf..aafea05250bf0 100644
--- a/mlir/test/mlir-tblgen/constraint-unique.td
+++ b/mlir/test/mlir-tblgen/constraint-unique.td
@@ -116,7 +116,7 @@ def OpC : NS_Op<"op_c"> {
/// Test that the uniqued constraints are being used.
// CHECK-LABEL: OpA::verify
-// CHECK: auto [[$B_ATTR:.*b]] = (*this)->getAttr(bAttrName());
+// CHECK: ::mlir::Attribute [[$B_ATTR:.*b]];
// CHECK: if (::mlir::failed([[$A_ATTR_CONSTRAINT]](*this, [[$B_ATTR]], "b")))
// CHECK-NEXT: return ::mlir::failure();
// CHECK: auto [[$A_VALUE_GROUP:.*]] = getODSOperands(0);
@@ -137,7 +137,7 @@ def OpC : NS_Op<"op_c"> {
/// Test that the op with the same predicates but different with descriptions
/// uses the different constraints.
// CHECK-LABEL: OpC::verify
-// CHECK: auto [[$B_ATTR:.*b]] = (*this)->getAttr(bAttrName());
+// CHECK: ::mlir::Attribute [[$B_ATTR:.*b]];
// CHECK: if (::mlir::failed([[$O_ATTR_CONSTRAINT]](*this, [[$B_ATTR]], "b")))
// CHECK-NEXT: return ::mlir::failure();
// CHECK: auto [[$A_VALUE_GROUP:.*]] = getODSOperands(0);
diff --git a/mlir/test/mlir-tblgen/op-attribute.td b/mlir/test/mlir-tblgen/op-attribute.td
index 16cc898becc7a..7058102b91e39 100644
--- a/mlir/test/mlir-tblgen/op-attribute.td
+++ b/mlir/test/mlir-tblgen/op-attribute.td
@@ -65,29 +65,40 @@ def AOp : NS_Op<"a_op", []> {
// ---
// DEF: ::mlir::LogicalResult AOpAdaptor::verify
-// DEF: auto tblgen_aAttr = odsAttrs.get("aAttr");
-// DEF-NEXT: if (!tblgen_aAttr)
-// DEF-NEXT: return emitError(loc, "'test.a_op' op ""requires attribute 'aAttr'");
+// DEF: ::mlir::Attribute tblgen_aAttr;
+// DEF-NEXT: while (true) {
+// DEF-NEXT: if (namedAttrIt == namedAttrRange.end())
+// DEF-NEXT: return emitError(loc, "'test.a_op' op ""requires attribute 'aAttr'");
+// DEF-NEXT: if (namedAttrIt->getName() == AOp::aAttrAttrName(*odsOpName)) {
+// DEF-NEXT: tblgen_aAttr = namedAttrIt->getValue();
+// DEF-NEXT: break;
+// DEF: ::mlir::Attribute tblgen_bAttr;
+// DEF-NEXT: ::mlir::Attribute tblgen_cAttr;
+// DEF-NEXT: while (true) {
+// DEF-NEXT: if (namedAttrIt == namedAttrRange.end())
+// DEF-NEXT: break;
+// DEF: if (namedAttrIt->getName() == AOp::bAttrAttrName(*odsOpName))
+// DEF-NEXT: tblgen_bAttr = namedAttrIt->getValue();
+// DEF: if (namedAttrIt->getName() == AOp::cAttrAttrName(*odsOpName))
+// DEF-NEXT: tblgen_cAttr = namedAttrIt->getValue();
// DEF: if (tblgen_aAttr && !((some-condition)))
// DEF-NEXT: return emitError(loc, "'test.a_op' op ""attribute 'aAttr' failed to satisfy constraint: some attribute kind");
-// DEF: auto tblgen_bAttr = odsAttrs.get("bAttr");
-// DEF-NEXT: if (tblgen_bAttr && !((some-condition)))
+// DEF: if (tblgen_bAttr && !((some-condition)))
// DEF-NEXT: return emitError(loc, "'test.a_op' op ""attribute 'bAttr' failed to satisfy constraint: some attribute kind");
-// DEF: auto tblgen_cAttr = odsAttrs.get("cAttr");
-// DEF-NEXT: if (tblgen_cAttr && !((some-condition)))
+// DEF: if (tblgen_cAttr && !((some-condition)))
// DEF-NEXT: return emitError(loc, "'test.a_op' op ""attribute 'cAttr' failed to satisfy constraint: some attribute kind");
// Test getter methods
// ---
// DEF: some-attr-kind AOp::aAttrAttr()
-// DEF-NEXT: (*this)->getAttr(aAttrAttrName()).cast<some-attr-kind>()
+// DEF-NEXT: ::mlir::impl::getAttrFromSortedRange((*this)->getAttrs().begin() + 0, (*this)->getAttrs().end() - 0, aAttrAttrName()).cast<some-attr-kind>()
// DEF: some-return-type AOp::aAttr() {
// DEF-NEXT: auto attr = aAttrAttr()
// DEF-NEXT: return attr.some-convert-from-storage();
// DEF: some-attr-kind AOp::bAttrAttr()
-// DEF-NEXT: return (*this)->getAttr(bAttrAttrName()).dyn_cast_or_null<some-attr-kind>()
+// DEF-NEXT: ::mlir::impl::getAttrFromSortedRange((*this)->getAttrs().begin() + 1, (*this)->getAttrs().end() - 0, bAttrAttrName()).dyn_cast_or_null<some-attr-kind>()
// DEF: some-return-type AOp::bAttr() {
// DEF-NEXT: auto attr = bAttrAttr();
// DEF-NEXT: if (!attr)
@@ -95,7 +106,7 @@ def AOp : NS_Op<"a_op", []> {
// DEF-NEXT: return attr.some-convert-from-storage();
// DEF: some-attr-kind AOp::cAttrAttr()
-// DEF-NEXT: return (*this)->getAttr(cAttrAttrName()).dyn_cast_or_null<some-attr-kind>()
+// DEF-NEXT: ::mlir::impl::getAttrFromSortedRange((*this)->getAttrs().begin() + 1, (*this)->getAttrs().end() - 0, cAttrAttrName()).dyn_cast_or_null<some-attr-kind>()
// DEF: ::llvm::Optional<some-return-type> AOp::cAttr() {
// DEF-NEXT: auto attr = cAttrAttr()
// DEF-NEXT: return attr ? ::llvm::Optional<some-return-type>(attr.some-convert-from-storage()) : (::llvm::None);
@@ -179,29 +190,29 @@ def AgetOp : Op<Test2_Dialect, "a_get_op", []> {
// ---
// DEF: ::mlir::LogicalResult AgetOpAdaptor::verify
-// DEF: auto tblgen_aAttr = odsAttrs.get("aAttr");
-// DEF-NEXT: if (!tblgen_aAttr)
-// DEF-NEXT. return emitError(loc, "'test2.a_get_op' op ""requires attribute 'aAttr'");
+// DEF: ::mlir::Attribute tblgen_aAttr;
+// DEF-NEXT: while (true)
+// DEF: ::mlir::Attribute tblgen_bAttr;
+// DEF-NEXT: ::mlir::Attribute tblgen_cAttr;
+// DEF-NEXT: while (true)
// DEF: if (tblgen_aAttr && !((some-condition)))
// DEF-NEXT: return emitError(loc, "'test2.a_get_op' op ""attribute 'aAttr' failed to satisfy constraint: some attribute kind");
-// DEF: auto tblgen_bAttr = odsAttrs.get("bAttr");
-// DEF-NEXT: if (tblgen_bAttr && !((some-condition)))
+// DEF: if (tblgen_bAttr && !((some-condition)))
// DEF-NEXT: return emitError(loc, "'test2.a_get_op' op ""attribute 'bAttr' failed to satisfy constraint: some attribute kind");
-// DEF: auto tblgen_cAttr = odsAttrs.get("cAttr");
-// DEF-NEXT: if (tblgen_cAttr && !((some-condition)))
+// DEF: if (tblgen_cAttr && !((some-condition)))
// DEF-NEXT: return emitError(loc, "'test2.a_get_op' op ""attribute 'cAttr' failed to satisfy constraint: some attribute kind");
// Test getter methods
// ---
// DEF: some-attr-kind AgetOp::getAAttrAttr()
-// DEF-NEXT: (*this)->getAttr(getAAttrAttrName()).cast<some-attr-kind>()
+// DEF-NEXT: ::mlir::impl::getAttrFromSortedRange({{.*}}).cast<some-attr-kind>()
// DEF: some-return-type AgetOp::getAAttr() {
// DEF-NEXT: auto attr = getAAttrAttr()
// DEF-NEXT: return attr.some-convert-from-storage();
// DEF: some-attr-kind AgetOp::getBAttrAttr()
-// DEF-NEXT: return (*this)->getAttr(getBAttrAttrName()).dyn_cast_or_null<some-attr-kind>()
+// DEF-NEXT: return ::mlir::impl::getAttrFromSortedRange({{.*}}).dyn_cast_or_null<some-attr-kind>()
// DEF: some-return-type AgetOp::getBAttr() {
// DEF-NEXT: auto attr = getBAttrAttr();
// DEF-NEXT: if (!attr)
@@ -209,7 +220,7 @@ def AgetOp : Op<Test2_Dialect, "a_get_op", []> {
// DEF-NEXT: return attr.some-convert-from-storage();
// DEF: some-attr-kind AgetOp::getCAttrAttr()
-// DEF-NEXT: return (*this)->getAttr(getCAttrAttrName()).dyn_cast_or_null<some-attr-kind>()
+// DEF-NEXT: return ::mlir::impl::getAttrFromSortedRange({{.*}}).dyn_cast_or_null<some-attr-kind>()
// DEF: ::llvm::Optional<some-return-type> AgetOp::getCAttr() {
// DEF-NEXT: auto attr = getCAttrAttr()
// DEF-NEXT: return attr ? ::llvm::Optional<some-return-type>(attr.some-convert-from-storage()) : (::llvm::None);
diff --git a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
index be7ff05e2c993..f4ccd3012d9da 100644
--- a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
+++ b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
@@ -26,6 +26,7 @@
#include "llvm/ADT/Sequence.h"
#include "llvm/ADT/StringExtras.h"
#include "llvm/ADT/StringSet.h"
+#include "llvm/ADT/StringSwitch.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/Signals.h"
#include "llvm/TableGen/Error.h"
@@ -43,10 +44,22 @@ static const char *const generatedArgName = "odsArg";
static const char *const odsBuilder = "odsBuilder";
static const char *const builderOpState = "odsState";
-/// Code for an Op to lookup an attribute. Uses cached identifiers.
+/// The names of the implicit attributes that contain variadic operand and
+/// result segment sizes.
+static const char *const operandSegmentAttrName = "operand_segment_sizes";
+static const char *const resultSegmentAttrName = "result_segment_sizes";
+
+/// Code for an Op to lookup an attribute. Uses cached identifiers and subrange
+/// lookup.
///
-/// {0}: The attribute's getter name.
-static const char *const opGetAttr = "(*this)->getAttr({0}AttrName())";
+/// {0}: Code snippet to get the attribute's name or identifier.
+/// {1}: The lower bound on the sorted subrange.
+/// {2}: The upper bound on the sorted subrange.
+/// {3}: Code snippet to get the array of named attributes.
+/// {4}: "Named" to get the named attribute.
+static const char *const subrangeGetAttr =
+ "::mlir::impl::get{4}AttrFromSortedRange({3}.begin() + {1}, {3}.end() - "
+ "{2}, {0})";
/// The logic to calculate the actual value range for a declared operand/result
/// of an op with variadic operands/results. Note that this logic is not for
@@ -80,16 +93,6 @@ static const char *const sameVariadicSizeValueRangeCalcCode = R"(
/// of an op with variadic operands/results. Note that this logic is assumes
/// the op has an attribute specifying the size of each operand/result segment
/// (variadic or not).
-///
-/// {0}: The name of the attribute specifying the segment sizes.
-static const char *const adapterSegmentSizeAttrInitCode = R"(
- assert(odsAttrs && "missing segment size attribute for op");
- auto sizeAttr = odsAttrs.get("{0}").cast<::mlir::DenseIntElementsAttr>();
-)";
-static const char *const opSegmentSizeAttrInitCode = R"(
- auto sizeAttr =
- (*this)->getAttr({0}AttrName()).cast<::mlir::DenseIntElementsAttr>();
-)";
static const char *const attrSizedSegmentValueRangeCalcCode = R"(
const uint32_t *sizeAttrValueIt = &*sizeAttr.value_begin<uint32_t>();
if (sizeAttr.isSplat())
@@ -100,6 +103,19 @@ static const char *const attrSizedSegmentValueRangeCalcCode = R"(
start += sizeAttrValueIt[i];
return {start, sizeAttrValueIt[index]};
)";
+/// The code snippet to initialize the sizes for the value range calculation.
+///
+/// {0}: The code to get the attribute.
+static const char *const adapterSegmentSizeAttrInitCode = R"(
+ assert(odsAttrs && "missing segment size attribute for op");
+ auto sizeAttr = {0}.cast<::mlir::DenseIntElementsAttr>();
+)";
+/// The code snippet to initialize the sizes for the value range calculation.
+///
+/// {0}: The code to get the attribute.
+static const char *const opSegmentSizeAttrInitCode = R"(
+ auto sizeAttr = {0}.cast<::mlir::DenseIntElementsAttr>();
+)";
/// The logic to calculate the actual value range for a declared operand
/// of an op with variadic of variadic operands within the OpAdaptor.
@@ -179,11 +195,31 @@ static bool canUseUnwrappedRawValue(const tblgen::Attribute &attr) {
}
namespace {
+/// Metadata on a registered attribute. Given that attributes are stored in
+/// sorted order on operations, we can use information from ODS to deduce the
+/// number of required attributes less than and greater than each attribute,
+/// allowing us to search only a subrange of the attributes in ODS-generated
+/// getters.
+struct AttributeMetadata {
+ /// The attribute name.
+ StringRef attrName;
+ /// Whether the attribute is required.
+ bool isRequired;
+ /// The ODS attribute constraint. Not present for implicit attributes.
+ Optional<Attribute> constraint;
+ /// The number of required attributes less than this attribute.
+ unsigned lowerBound = 0;
+ /// The number of required attributes greater than this attribute.
+ unsigned upperBound = 0;
+};
+
/// Helper class to select between OpAdaptor and Op code templates.
class OpOrAdaptorHelper {
public:
OpOrAdaptorHelper(const Operator &op, bool emitForOp)
- : op(op), emitForOp(emitForOp) {}
+ : op(op), emitForOp(emitForOp) {
+ computeAttrMetadata();
+ }
/// Object that wraps a functor in a stream operator for interop with
/// llvm::formatv.
@@ -208,14 +244,31 @@ class OpOrAdaptorHelper {
};
// Generate code for getting an attribute.
- Formatter getAttr(StringRef attrName) const {
+ Formatter getAttr(StringRef attrName, bool isNamed = false) const {
+ assert(attrMetadata.count(attrName) && "expected attribute metadata");
+ return [this, attrName, isNamed](raw_ostream &os) -> raw_ostream & {
+ const AttributeMetadata &attr = attrMetadata.find(attrName)->second;
+ return os << formatv(subrangeGetAttr, getAttrName(attrName),
+ attr.lowerBound, attr.upperBound, getAttrRange(),
+ isNamed ? "Named" : "");
+ };
+ }
+
+ // Generate code for getting the name of an attribute.
+ Formatter getAttrName(StringRef attrName) const {
return [this, attrName](raw_ostream &os) -> raw_ostream & {
- if (!emitForOp)
- return os << formatv("odsAttrs.get(\"{0}\")", attrName);
- return os << formatv(opGetAttr, op.getGetterName(attrName));
+ if (emitForOp)
+ return os << op.getGetterName(attrName) << "AttrName()";
+ return os << formatv("{0}::{1}AttrName(*odsOpName)", op.getCppClassName(),
+ op.getGetterName(attrName));
};
}
+ // Get the code snippet for getting the named attribute range.
+ StringRef getAttrRange() const {
+ return emitForOp ? "(*this)->getAttrs()" : "odsAttrs";
+ }
+
// Get the prefix code for emitting an error.
Formatter emitErrorPrefix() const {
return [this](raw_ostream &os) -> raw_ostream & {
@@ -254,14 +307,74 @@ class OpOrAdaptorHelper {
// Return the ODS operation wrapper.
const Operator &getOp() const { return op; }
+ // Get the attribute metadata sorted by name.
+ const llvm::MapVector<StringRef, AttributeMetadata> &getAttrMetadata() const {
+ return attrMetadata;
+ }
+
private:
+ // Compute the attribute metadata.
+ void computeAttrMetadata();
+
// The operation ODS wrapper.
const Operator &op;
// True if code is being generate for an op. False for an adaptor.
const bool emitForOp;
+
+ // The attribute metadata, mapped by name.
+ llvm::MapVector<StringRef, AttributeMetadata> attrMetadata;
+ // The number of required attributes.
+ unsigned numRequired;
};
+
} // namespace
+void OpOrAdaptorHelper::computeAttrMetadata() {
+ // Enumerate the attribute names of this op, ensuring the attribute names are
+ // unique in case implicit attributes are explicitly registered.
+ for (const NamedAttribute &namedAttr : op.getAttributes()) {
+ Attribute attr = namedAttr.attr;
+ bool isOptional =
+ attr.hasDefaultValue() || attr.isOptional() || attr.isDerivedAttr();
+ attrMetadata.insert(
+ {namedAttr.name, AttributeMetadata{namedAttr.name, !isOptional, attr}});
+ }
+ // Include key attributes from several traits as implicitly registered.
+ if (op.getTrait("::mlir::OpTrait::AttrSizedOperandSegments")) {
+ attrMetadata.insert(
+ {operandSegmentAttrName,
+ AttributeMetadata{operandSegmentAttrName, /*isRequired=*/true,
+ /*attr=*/llvm::None}});
+ }
+ if (op.getTrait("::mlir::OpTrait::AttrSizedResultSegments")) {
+ attrMetadata.insert(
+ {resultSegmentAttrName,
+ AttributeMetadata{resultSegmentAttrName, /*isRequired=*/true,
+ /*attr=*/llvm::None}});
+ }
+
+ // Store the metadata in sorted order.
+ SmallVector<AttributeMetadata> sortedAttrMetadata =
+ llvm::to_vector(llvm::make_second_range(attrMetadata.takeVector()));
+ llvm::sort(sortedAttrMetadata,
+ [](const AttributeMetadata &lhs, const AttributeMetadata &rhs) {
+ return lhs.attrName < rhs.attrName;
+ });
+
+ // Compute the subrange bounds for each attribute.
+ numRequired = 0;
+ for (AttributeMetadata &attr : sortedAttrMetadata) {
+ attr.lowerBound = numRequired;
+ numRequired += attr.isRequired;
+ };
+ for (AttributeMetadata &attr : sortedAttrMetadata)
+ attr.upperBound = numRequired - attr.lowerBound - attr.isRequired;
+
+ // Store the results back into the map.
+ for (const AttributeMetadata &attr : sortedAttrMetadata)
+ attrMetadata.insert({attr.attrName, attr});
+}
+
//===----------------------------------------------------------------------===//
// Op emitter
//===----------------------------------------------------------------------===//
@@ -438,7 +551,7 @@ class OpEmitter {
const Record &def;
// The wrapper operator class for querying information from this op.
- Operator op;
+ const Operator &op;
// The C++ code builder for this op
OpClass opClass;
@@ -448,6 +561,9 @@ class OpEmitter {
// The emitter containing all of the locally emitted verification functions.
const StaticVerifierFunctionEmitter &staticVerifierEmitter;
+
+ // Helper for emitting op code.
+ OpOrAdaptorHelper emitHelper;
};
} // namespace
@@ -476,20 +592,59 @@ static void populateSubstitutions(const OpOrAdaptorHelper &emitHelper,
}
}
+/// Generate verification on native traits requiring attributes.
+static void genNativeTraitAttrVerifier(MethodBody &body,
+ const OpOrAdaptorHelper &emitHelper) {
+ // Check that the variadic segment sizes attribute exists and contains the
+ // expected number of elements.
+ //
+ // {0}: Attribute name.
+ // {1}: Expected number of elements.
+ // {2}: "operand" or "result".
+ // {3}: Emit error prefix.
+ const char *const checkAttrSizedValueSegmentsCode = R"(
+ {
+ auto sizeAttr = tblgen_{0}.cast<::mlir::DenseIntElementsAttr>();
+ auto numElements =
+ sizeAttr.getType().cast<::mlir::ShapedType>().getNumElements();
+ if (numElements != {1})
+ return {3}"'{0}' attribute for specifying {2} segments must have {1} "
+ "elements, but got ") << numElements;
+ }
+ )";
+
+ // Verify a few traits first so that we can use getODSOperands() and
+ // getODSResults() in the rest of the verifier.
+ auto &op = emitHelper.getOp();
+ if (op.getTrait("::mlir::OpTrait::AttrSizedOperandSegments")) {
+ body << formatv(checkAttrSizedValueSegmentsCode, operandSegmentAttrName,
+ op.getNumOperands(), "operand",
+ emitHelper.emitErrorPrefix());
+ }
+ if (op.getTrait("::mlir::OpTrait::AttrSizedResultSegments")) {
+ body << formatv(checkAttrSizedValueSegmentsCode, resultSegmentAttrName,
+ op.getNumResults(), "result", emitHelper.emitErrorPrefix());
+ }
+}
+
// Generate attribute verification. If an op instance is not available, then
// attribute checks that require one will not be emitted.
+//
+// Attribute verification is performed as follows:
+//
+// 1. Verify that all required attributes are present in sorted order. This
+// ensures that we can use subrange lookup even with potentially missing
+// attributes.
+// 2. Verify native trait attributes so that other attributes may call methods
+// that depend on the validity of these attributes, e.g. segment size attributes
+// and operand or result getters.
+// 3. Verify the constraints on all present attributes.
static void genAttributeVerifier(
const OpOrAdaptorHelper &emitHelper, FmtContext &ctx, MethodBody &body,
const StaticVerifierFunctionEmitter &staticVerifierEmitter) {
- // Check that a required attribute exists.
- //
- // {0}: Attribute variable name.
- // {1}: Emit error prefix.
- // {2}: Attribute name.
- const char *const verifyRequiredAttr = R"(
- if (!{0})
- return {1}"requires attribute '{2}'");
-)";
+ if (emitHelper.getAttrMetadata().empty())
+ return;
+
// Verify the attribute if it is present. This assumes that default values
// are valid. This code snippet pastes the condition inline.
//
@@ -501,8 +656,8 @@ static void genAttributeVerifier(
// {3}: Attribute name.
// {4}: Attribute/constraint description.
const char *const verifyAttrInline = R"(
- if ({0} && !({1}))
- return {2}"attribute '{3}' failed to satisfy constraint: {4}");
+ if ({0} && !({1}))
+ return {2}"attribute '{3}' failed to satisfy constraint: {4}");
)";
// Verify the attribute using a uniqued constraint. Can only be used within
// the context of an op.
@@ -511,54 +666,128 @@ static void genAttributeVerifier(
// {1}: Attribute variable name.
// {2}: Attribute name.
const char *const verifyAttrUnique = R"(
- if (::mlir::failed({0}(*this, {1}, "{2}")))
- return ::mlir::failure();
+ if (::mlir::failed({0}(*this, {1}, "{2}")))
+ return ::mlir::failure();
)";
- for (const auto &namedAttr : emitHelper.getOp().getAttributes()) {
- const auto &attr = namedAttr.attr;
- StringRef attrName = namedAttr.name;
+ // Traverse the array until the required attribute is found. Return an error
+ // if the traversal reached the end.
+ //
+ // {0}: Code to get the name of the attribute.
+ // {1}: The emit error prefix.
+ // {2}: The name of the attribute.
+ const char *const findRequiredAttr = R"(while (true) {{
+ if (namedAttrIt == namedAttrRange.end())
+ return {1}"requires attribute '{2}'");
+ if (namedAttrIt->getName() == {0}) {{
+ tblgen_{2} = namedAttrIt->getValue();
+ break;
+ })";
+
+ // Emit a check to see if the iteration has encountered an optional attribute.
+ //
+ // {0}: Code to get the name of the attribute.
+ // {1}: The name of the attribute.
+ const char *const checkOptionalAttr = R"(
+ else if (namedAttrIt->getName() == {0}) {{
+ tblgen_{1} = namedAttrIt->getValue();
+ })";
+
+ // Emit the start of the loop for checking trailing attributes.
+ const char *const checkTrailingAttrs = R"(while (true) {
+ if (namedAttrIt == namedAttrRange.end()) {
+ break;
+ })";
+
+ // Return true if a verifier can be emitted for the attribute: it is not a
+ // derived attribute, it has a predicate, its condition is not empty, and, for
+ // adaptors, the condition does not reference the op.
+ const auto canEmitVerifier = [&](Attribute attr) {
if (attr.isDerivedAttr())
- continue;
+ return false;
+ Pred pred = attr.getPredicate();
+ if (pred.isNull())
+ return false;
+ std::string condition = pred.getCondition();
+ return !condition.empty() && (!StringRef(condition).contains("$_op") ||
+ emitHelper.isEmittingForOp());
+ };
- bool allowMissingAttr = attr.hasDefaultValue() || attr.isOptional();
- auto attrPred = attr.getPredicate();
- std::string condition = attrPred.isNull() ? "" : attrPred.getCondition();
- // If the attribute's condition needs an op but none is available, then the
- // condition cannot be emitted.
- bool canEmitCondition =
- !condition.empty() && (!StringRef(condition).contains("$_op") ||
- emitHelper.isEmittingForOp());
-
- // Prefix with `tblgen_` to avoid hiding the attribute accessor.
- std::string varName = (tblgenNamePrefix + attrName).str();
-
- // If the attribute is not required and we cannot emit the condition, then
- // there is nothing to be done.
- if (allowMissingAttr && !canEmitCondition)
- continue;
+ // Emit the verifier for the attribute.
+ const auto emitVerifier = [&](Attribute attr, StringRef attrName,
+ StringRef varName) {
+ std::string condition = attr.getPredicate().getCondition();
- body << formatv(" {\n auto {0} = {1};", varName,
- emitHelper.getAttr(attrName));
+ Optional<StringRef> constraintFn;
+ if (emitHelper.isEmittingForOp() &&
+ (constraintFn = staticVerifierEmitter.getAttrConstraintFn(attr))) {
+ body << formatv(verifyAttrUnique, *constraintFn, varName, attrName);
+ } else {
+ body << formatv(verifyAttrInline, varName,
+ tgfmt(condition, &ctx.withSelf(varName)),
+ emitHelper.emitErrorPrefix(), attrName,
+ escapeString(attr.getSummary()));
+ }
+ };
- if (!allowMissingAttr) {
- body << formatv(verifyRequiredAttr, varName, emitHelper.emitErrorPrefix(),
- attrName);
+ // Prefix variables with `tblgen_` to avoid hiding the attribute accessor.
+ const auto getVarName = [&](StringRef attrName) {
+ return (tblgenNamePrefix + attrName).str();
+ };
+
+ body.indent() << formatv("auto namedAttrRange = {0};\n",
+ emitHelper.getAttrRange());
+ body << "auto namedAttrIt = namedAttrRange.begin();\n";
+
+ // Iterate over the attributes in sorted order. Keep track of the optional
+ // attributes that may be encountered along the way.
+ SmallVector<const AttributeMetadata *> optionalAttrs;
+ for (const std::pair<StringRef, AttributeMetadata> &it :
+ emitHelper.getAttrMetadata()) {
+ const AttributeMetadata &metadata = it.second;
+ if (!metadata.isRequired) {
+ optionalAttrs.push_back(&metadata);
+ continue;
}
- if (canEmitCondition) {
- Optional<StringRef> constraintFn;
- if (emitHelper.isEmittingForOp() &&
- (constraintFn = staticVerifierEmitter.getAttrConstraintFn(attr))) {
- body << formatv(verifyAttrUnique, *constraintFn, varName, attrName);
- } else {
- body << formatv(verifyAttrInline, varName,
- tgfmt(condition, &ctx.withSelf(varName)),
- emitHelper.emitErrorPrefix(), attrName,
- escapeString(attr.getSummary()));
- }
+
+ body << formatv("::mlir::Attribute {0};\n", getVarName(it.first));
+ for (const AttributeMetadata *optional : optionalAttrs) {
+ body << formatv("::mlir::Attribute {0};\n",
+ getVarName(optional->attrName));
+ }
+ body << formatv(findRequiredAttr, emitHelper.getAttrName(it.first),
+ emitHelper.emitErrorPrefix(), it.first);
+ for (const AttributeMetadata *optional : optionalAttrs) {
+ body << formatv(checkOptionalAttr,
+ emitHelper.getAttrName(optional->attrName),
+ optional->attrName);
+ }
+ body << "\n ++namedAttrIt;\n}\n";
+ optionalAttrs.clear();
+ }
+ // Get trailing optional attributes.
+ if (!optionalAttrs.empty()) {
+ for (const AttributeMetadata *optional : optionalAttrs) {
+ body << formatv("::mlir::Attribute {0};\n",
+ getVarName(optional->attrName));
+ }
+ body << checkTrailingAttrs;
+ for (const AttributeMetadata *optional : optionalAttrs) {
+ body << formatv(checkOptionalAttr,
+ emitHelper.getAttrName(optional->attrName),
+ optional->attrName);
}
- body << " }\n";
+ body << "\n ++namedAttrIt;\n}\n";
}
+ body.unindent();
+
+ // Emit the checks for segment attributes first so that the other constraints
+ // can call operand and result getters.
+ genNativeTraitAttrVerifier(body, emitHelper);
+
+ for (const auto &namedAttr : emitHelper.getOp().getAttributes())
+ if (canEmitVerifier(namedAttr.attr))
+ emitVerifier(namedAttr.attr, namedAttr.name, getVarName(namedAttr.name));
}
/// Op extra class definitions have a `$cppClass` substitution that is to be
@@ -573,7 +802,8 @@ OpEmitter::OpEmitter(const Operator &op,
: def(op.getDef()), op(op),
opClass(op.getCppClassName(), op.getExtraClassDeclaration(),
formatExtraDefinitions(op)),
- staticVerifierEmitter(staticVerifierEmitter) {
+ staticVerifierEmitter(staticVerifierEmitter),
+ emitHelper(op, /*emitForOp=*/true) {
verifyCtx.withOp("(*this->getOperation())");
verifyCtx.addSubst("_ctxt", "this->getOperation()->getContext()");
@@ -635,26 +865,12 @@ static void errorIfPruned(size_t line, Method *m, const Twine &methodName,
op.getOperationName() + " (from line " +
Twine(line) + ")");
}
+
#define ERROR_IF_PRUNED(M, N, O) errorIfPruned(__LINE__, M, N, O)
void OpEmitter::genAttrNameGetters() {
- // A map of attribute names (including implicit attributes) registered to the
- // current operation, to the relative order in which they were registered.
- llvm::MapVector<StringRef, unsigned> attributeNames;
-
- // Enumerate the attribute names of this op, assigning each a relative
- // ordering.
- auto addAttrName = [&](StringRef name) {
- unsigned index = attributeNames.size();
- attributeNames.insert({name, index});
- };
- for (const NamedAttribute &namedAttr : op.getAttributes())
- addAttrName(namedAttr.name);
- // Include key attributes from several traits as implicitly registered.
- if (op.getTrait("::mlir::OpTrait::AttrSizedOperandSegments"))
- addAttrName("operand_segment_sizes");
- if (op.getTrait("::mlir::OpTrait::AttrSizedResultSegments"))
- addAttrName("result_segment_sizes");
+ const llvm::MapVector<StringRef, AttributeMetadata> &attributes =
+ emitHelper.getAttrMetadata();
// Emit the getAttributeNames method.
{
@@ -662,20 +878,18 @@ void OpEmitter::genAttrNameGetters() {
"::llvm::ArrayRef<::llvm::StringRef>", "getAttributeNames");
ERROR_IF_PRUNED(method, "getAttributeNames", op);
auto &body = method->body();
- if (attributeNames.empty()) {
+ if (attributes.empty()) {
body << " return {};";
- } else {
- body << " static ::llvm::StringRef attrNames[] = {";
- llvm::interleaveComma(llvm::make_first_range(attributeNames), body,
- [&](StringRef attrName) {
- body << "::llvm::StringRef(\"" << attrName
- << "\")";
- });
- body << "};\n return ::llvm::makeArrayRef(attrNames);";
+ // Nothing else to do if there are no registered attributes. Exit early.
+ return;
}
+ body << " static ::llvm::StringRef attrNames[] = {";
+ llvm::interleaveComma(llvm::make_first_range(attributes), body,
+ [&](StringRef attrName) {
+ body << "::llvm::StringRef(\"" << attrName << "\")";
+ });
+ body << "};\n return ::llvm::makeArrayRef(attrNames);";
}
- if (attributeNames.empty())
- return;
// Emit the getAttributeNameForIndex methods.
{
@@ -697,14 +911,14 @@ void OpEmitter::genAttrNameGetters() {
assert(index < {0} && "invalid attribute index");
return name.getRegisteredInfo()->getAttributeNames()[index];
)";
- method->body() << formatv(getAttrName, attributeNames.size());
+ method->body() << formatv(getAttrName, attributes.size());
}
// Generate the <attr>AttrName methods, that expose the attribute names to
// users.
const char *attrNameMethodBody = " return getAttributeNameForIndex({0});";
- for (const std::pair<StringRef, unsigned> &attrIt : attributeNames) {
- for (StringRef name : op.getGetterNames(attrIt.first)) {
+ for (auto &attrIt : llvm::enumerate(llvm::make_first_range(attributes))) {
+ for (StringRef name : op.getGetterNames(attrIt.value())) {
std::string methodName = (name + "AttrName").str();
// Generate the non-static variant.
@@ -712,7 +926,7 @@ void OpEmitter::genAttrNameGetters() {
auto *method =
opClass.addInlineMethod("::mlir::StringAttr", methodName);
ERROR_IF_PRUNED(method, methodName, op);
- method->body() << llvm::formatv(attrNameMethodBody, attrIt.second);
+ method->body() << llvm::formatv(attrNameMethodBody, attrIt.index());
}
// Generate the static variant.
@@ -722,7 +936,7 @@ void OpEmitter::genAttrNameGetters() {
MethodParameter("::mlir::OperationName", "name"));
ERROR_IF_PRUNED(method, methodName, op);
method->body() << llvm::formatv(attrNameMethodBody,
- "name, " + Twine(attrIt.second));
+ "name, " + Twine(attrIt.index()));
}
}
}
@@ -772,12 +986,13 @@ void OpEmitter::genAttrGetters() {
// Generate named accessor with Attribute return type. This is a wrapper class
// that allows referring to the attributes via accessors instead of having to
// use the string interface for better compile time verification.
- auto emitAttrWithStorageType = [&](StringRef name, Attribute attr) {
+ auto emitAttrWithStorageType = [&](StringRef name, StringRef attrName,
+ Attribute attr) {
auto *method = opClass.addMethod(attr.getStorageType(), name + "Attr");
if (!method)
return;
method->body() << formatv(
- " return {0}.{1}<{2}>();", formatv(opGetAttr, name),
+ " return {0}.{1}<{2}>();", emitHelper.getAttr(attrName),
attr.isOptional() || attr.hasDefaultValue() ? "dyn_cast_or_null"
: "cast",
attr.getStorageType());
@@ -788,7 +1003,7 @@ void OpEmitter::genAttrGetters() {
if (namedAttr.attr.isDerivedAttr()) {
emitDerivedAttr(name, namedAttr.attr);
} else {
- emitAttrWithStorageType(name, namedAttr.attr);
+ emitAttrWithStorageType(name, namedAttr.name, namedAttr.attr);
emitAttrGetterWithReturnType(fctx, opClass, op, name, namedAttr.attr);
}
}
@@ -1041,8 +1256,8 @@ void OpEmitter::genNamedOperandGetters() {
// array.
std::string attrSizeInitCode;
if (op.getTrait("::mlir::OpTrait::AttrSizedOperandSegments")) {
- std::string attr = op.getGetterName("operand_segment_sizes");
- attrSizeInitCode = formatv(opSegmentSizeAttrInitCode, attr).str();
+ attrSizeInitCode = formatv(opSegmentSizeAttrInitCode,
+ emitHelper.getAttr(operandSegmentAttrName));
}
generateNamedOperandGetters(
@@ -1073,17 +1288,17 @@ void OpEmitter::genNamedOperandSetters() {
<< " auto mutableRange = "
"::mlir::MutableOperandRange(getOperation(), "
"range.first, range.second";
- if (attrSizedOperands)
- body << ", ::mlir::MutableOperandRange::OperandSegment(" << i
- << "u, *getOperation()->getAttrDictionary().getNamed("
- << op.getGetterName("operand_segment_sizes") << "AttrName()))";
+ if (attrSizedOperands) {
+ body << formatv(
+ ", ::mlir::MutableOperandRange::OperandSegment({0}u, *{1})", i,
+ emitHelper.getAttr(operandSegmentAttrName, /*isNamed=*/true));
+ }
body << ");\n";
// If this operand is a nested variadic, we split the range into a
// MutableOperandRangeRange that provides a range over all of the
// sub-ranges.
if (operand.isVariadicOfVariadic()) {
- //
body << " return "
"mutableRange.split(*(*this)->getAttrDictionary().getNamed("
<< op.getGetterName(
@@ -1129,8 +1344,8 @@ void OpEmitter::genNamedResultGetters() {
// Build the initializer string for the result segment size attribute.
std::string attrSizeInitCode;
if (attrSizedResults) {
- std::string attr = op.getGetterName("result_segment_sizes");
- attrSizeInitCode = formatv(opSegmentSizeAttrInitCode, attr).str();
+ attrSizeInitCode = formatv(opSegmentSizeAttrInitCode,
+ emitHelper.getAttr(resultSegmentAttrName));
}
generateValueRangeStartAndEnd(
@@ -1220,7 +1435,7 @@ void OpEmitter::genNamedSuccessorGetters() {
}
}
-static bool canGenerateUnwrappedBuilder(Operator &op) {
+static bool canGenerateUnwrappedBuilder(const Operator &op) {
// If this op does not have native attributes at all, return directly to avoid
// redefining builders.
if (op.getNumNativeAttributes() == 0)
@@ -1232,7 +1447,7 @@ static bool canGenerateUnwrappedBuilder(Operator &op) {
// different from the wrapped mlir::Attribute type to avoid redefining
// builders. This checks for the op has at least one such native attribute.
for (int i = 0, e = op.getNumNativeAttributes(); i < e; ++i) {
- NamedAttribute &namedAttr = op.getAttribute(i);
+ const NamedAttribute &namedAttr = op.getAttribute(i);
if (canUseUnwrappedRawValue(namedAttr.attr)) {
canGenerate = true;
break;
@@ -1241,7 +1456,7 @@ static bool canGenerateUnwrappedBuilder(Operator &op) {
return canGenerate;
}
-static bool canInferType(Operator &op) {
+static bool canInferType(const Operator &op) {
return op.getTrait("::mlir::InferTypeOpInterface::Trait");
}
@@ -1727,7 +1942,7 @@ void OpEmitter::buildParamList(SmallVectorImpl<MethodParameter> &paramList,
const NamedAttribute &namedAttr = *arg.get<NamedAttribute *>();
const Attribute &attr = namedAttr.attr;
- // inferred attributes don't need to be added to the param list.
+ // Inferred attributes don't need to be added to the param list.
if (inferredAttributes.contains(namedAttr.name))
continue;
@@ -1774,7 +1989,7 @@ void OpEmitter::genCodeForAddingArgAndRegionForBuilder(
// Push all operands to the result.
for (int i = 0, e = op.getNumOperands(); i < e; ++i) {
std::string argName = getArgumentName(op, i);
- NamedTypeConstraint &operand = op.getOperand(i);
+ const NamedTypeConstraint &operand = op.getOperand(i);
if (operand.constraint.isVariadicOfVariadic()) {
body << " for (::mlir::ValueRange range : " << argName << ")\n "
<< builderOpState << ".addOperands(range);\n";
@@ -1800,7 +2015,7 @@ void OpEmitter::genCodeForAddingArgAndRegionForBuilder(
// If the operation has the operand segment size attribute, add it here.
if (op.getTrait("::mlir::OpTrait::AttrSizedOperandSegments")) {
- std::string sizes = op.getGetterName("operand_segment_sizes");
+ std::string sizes = op.getGetterName(operandSegmentAttrName);
body << " " << builderOpState << ".addAttribute(" << sizes << "AttrName("
<< builderOpState << ".name), "
<< "odsBuilder.getI32VectorAttr({";
@@ -2164,64 +2379,13 @@ void OpEmitter::genPrinter() {
ERROR_IF_PRUNED(method, "print", op);
}
-/// Generate verification on native traits requiring attributes.
-static void genNativeTraitAttrVerifier(MethodBody &body,
- const OpOrAdaptorHelper &emitHelper) {
- // Check that the variadic segment sizes attribute exists and contains the
- // expected number of elements.
- //
- // {0}: Attribute name.
- // {1}: Expected number of elements.
- // {2}: "operand" or "result".
- // {3}: Attribute getter call.
- // {4}: Emit error prefix.
- const char *const checkAttrSizedValueSegmentsCode = R"(
- {
- auto sizeAttr = {3}.dyn_cast<::mlir::DenseIntElementsAttr>();
- if (!sizeAttr)
- return {4}"missing segment sizes attribute '{0}'");
- auto numElements =
- sizeAttr.getType().cast<::mlir::ShapedType>().getNumElements();
- if (numElements != {1})
- return {4}"'{0}' attribute for specifying {2} segments must have {1} "
- "elements, but got ") << numElements;
- }
- )";
-
- // Verify a few traits first so that we can use getODSOperands() and
- // getODSResults() in the rest of the verifier.
- auto &op = emitHelper.getOp();
- for (auto &trait : op.getTraits()) {
- auto *t = dyn_cast<tblgen::NativeTrait>(&trait);
- if (!t)
- continue;
- std::string traitName = t->getFullyQualifiedTraitName();
- if (traitName == "::mlir::OpTrait::AttrSizedOperandSegments") {
- StringRef attrName = "operand_segment_sizes";
- body << formatv(checkAttrSizedValueSegmentsCode, attrName,
- op.getNumOperands(), "operand",
- emitHelper.getAttr(attrName),
- emitHelper.emitErrorPrefix());
- } else if (traitName == "::mlir::OpTrait::AttrSizedResultSegments") {
- StringRef attrName = "result_segment_sizes";
- body << formatv(
- checkAttrSizedValueSegmentsCode, attrName, op.getNumResults(),
- "result", emitHelper.getAttr(attrName), emitHelper.emitErrorPrefix());
- }
- }
-}
-
void OpEmitter::genVerifier() {
auto *implMethod =
opClass.addMethod("::mlir::LogicalResult", "verifyInvariantsImpl");
ERROR_IF_PRUNED(implMethod, "verifyInvariantsImpl", op);
auto &implBody = implMethod->body();
- OpOrAdaptorHelper emitHelper(op, /*isOp=*/true);
- genNativeTraitAttrVerifier(implBody, emitHelper);
-
populateSubstitutions(emitHelper, verifyCtx);
-
genAttributeVerifier(emitHelper, verifyCtx, implBody, staticVerifierEmitter);
genOperandResultVerifier(implBody, op.getOperands(), "operand");
genOperandResultVerifier(implBody, op.getResults(), "result");
@@ -2574,34 +2738,49 @@ namespace {
// getters identical to those defined in the Op.
class OpOperandAdaptorEmitter {
public:
- static void emitDecl(const Operator &op,
- StaticVerifierFunctionEmitter &staticVerifierEmitter,
- raw_ostream &os);
- static void emitDef(const Operator &op,
- StaticVerifierFunctionEmitter &staticVerifierEmitter,
- raw_ostream &os);
+ static void
+ emitDecl(const Operator &op,
+ const StaticVerifierFunctionEmitter &staticVerifierEmitter,
+ raw_ostream &os);
+ static void
+ emitDef(const Operator &op,
+ const StaticVerifierFunctionEmitter &staticVerifierEmitter,
+ raw_ostream &os);
private:
explicit OpOperandAdaptorEmitter(
- const Operator &op, StaticVerifierFunctionEmitter &staticVerifierEmitter);
+ const Operator &op,
+ const StaticVerifierFunctionEmitter &staticVerifierEmitter);
// Add verification function. This generates a verify method for the adaptor
// which verifies all the op-independent attribute constraints.
void addVerification();
+ // The operation for which to emit an adaptor.
const Operator &op;
+
+ // The generated adaptor class.
Class adaptor;
- StaticVerifierFunctionEmitter &staticVerifierEmitter;
+
+ // The emitter containing all of the locally emitted verification functions.
+ const StaticVerifierFunctionEmitter &staticVerifierEmitter;
+
+ // Helper for emitting adaptor code.
+ OpOrAdaptorHelper emitHelper;
};
} // namespace
OpOperandAdaptorEmitter::OpOperandAdaptorEmitter(
- const Operator &op, StaticVerifierFunctionEmitter &staticVerifierEmitter)
+ const Operator &op,
+ const StaticVerifierFunctionEmitter &staticVerifierEmitter)
: op(op), adaptor(op.getAdaptorName()),
- staticVerifierEmitter(staticVerifierEmitter) {
+ staticVerifierEmitter(staticVerifierEmitter),
+ emitHelper(op, /*emitForOp=*/false) {
adaptor.addField("::mlir::ValueRange", "odsOperands");
adaptor.addField("::mlir::DictionaryAttr", "odsAttrs");
adaptor.addField("::mlir::RegionRange", "odsRegions");
+ adaptor.addField("::llvm::Optional<::mlir::OperationName>", "odsOpName");
+
const auto *attrSizedOperands =
op.getTrait("::m::OpTrait::AttrSizedOperandSegments");
{
@@ -2615,14 +2794,21 @@ OpOperandAdaptorEmitter::OpOperandAdaptorEmitter(
constructor->addMemberInitializer("odsOperands", "values");
constructor->addMemberInitializer("odsAttrs", "attrs");
constructor->addMemberInitializer("odsRegions", "regions");
+
+ MethodBody &body = constructor->body();
+ body.indent() << "if (odsAttrs)\n";
+ body.indent() << formatv(
+ "odsOpName.emplace(\"{0}\", odsAttrs.getContext());\n",
+ op.getOperationName());
}
{
- auto *constructor = adaptor.addConstructor(
- MethodParameter(op.getCppClassName() + " &", "op"));
+ auto *constructor =
+ adaptor.addConstructor(MethodParameter(op.getCppClassName(), "op"));
constructor->addMemberInitializer("odsOperands", "op->getOperands()");
constructor->addMemberInitializer("odsAttrs", "op->getAttrDictionary()");
constructor->addMemberInitializer("odsRegions", "op->getRegions()");
+ constructor->addMemberInitializer("odsOpName", "op->getName()");
}
{
@@ -2630,8 +2816,11 @@ OpOperandAdaptorEmitter::OpOperandAdaptorEmitter(
ERROR_IF_PRUNED(m, "getOperands", op);
m->body() << " return odsOperands;";
}
- std::string attr = "operand_segment_sizes";
- std::string sizeAttrInit = formatv(adapterSegmentSizeAttrInitCode, attr);
+ std::string sizeAttrInit;
+ if (op.getTrait("::mlir::OpTrait::AttrSizedOperandSegments")) {
+ sizeAttrInit = formatv(adapterSegmentSizeAttrInitCode,
+ emitHelper.getAttr(operandSegmentAttrName));
+ }
generateNamedOperandGetters(op, adaptor,
/*isAdaptor=*/true, sizeAttrInit,
/*rangeType=*/"::mlir::ValueRange",
@@ -2647,15 +2836,13 @@ OpOperandAdaptorEmitter::OpOperandAdaptorEmitter(
Attribute attr) {
auto *method = adaptor.addMethod(attr.getStorageType(), emitName + "Attr");
ERROR_IF_PRUNED(method, "Adaptor::" + emitName + "Attr", op);
- auto &body = method->body();
- body << " assert(odsAttrs && \"no attributes when constructing adapter\");"
- << "\n " << attr.getStorageType() << " attr = "
- << "odsAttrs.get(\"" << name << "\").";
- if (attr.hasDefaultValue() || attr.isOptional())
- body << "dyn_cast_or_null<";
- else
- body << "cast<";
- body << attr.getStorageType() << ">();\n";
+ auto &body = method->body().indent();
+ body << "assert(odsAttrs && \"no attributes when constructing adapter\");\n"
+ << formatv("auto attr = {0}.{1}<{2}>();\n", emitHelper.getAttr(name),
+ attr.hasDefaultValue() || attr.isOptional()
+ ? "dyn_cast_or_null"
+ : "cast",
+ attr.getStorageType());
if (attr.hasDefaultValue()) {
// Use the default value if attribute is not set.
@@ -2721,24 +2908,23 @@ void OpOperandAdaptorEmitter::addVerification() {
ERROR_IF_PRUNED(method, "verify", op);
auto &body = method->body();
- OpOrAdaptorHelper emitHelper(op, /*isOp=*/false);
- genNativeTraitAttrVerifier(body, emitHelper);
FmtContext verifyCtx;
populateSubstitutions(emitHelper, verifyCtx);
-
genAttributeVerifier(emitHelper, verifyCtx, body, staticVerifierEmitter);
body << " return ::mlir::success();";
}
void OpOperandAdaptorEmitter::emitDecl(
- const Operator &op, StaticVerifierFunctionEmitter &staticVerifierEmitter,
+ const Operator &op,
+ const StaticVerifierFunctionEmitter &staticVerifierEmitter,
raw_ostream &os) {
OpOperandAdaptorEmitter(op, staticVerifierEmitter).adaptor.writeDeclTo(os);
}
void OpOperandAdaptorEmitter::emitDef(
- const Operator &op, StaticVerifierFunctionEmitter &staticVerifierEmitter,
+ const Operator &op,
+ const StaticVerifierFunctionEmitter &staticVerifierEmitter,
raw_ostream &os) {
OpOperandAdaptorEmitter(op, staticVerifierEmitter).adaptor.writeDefTo(os);
}
More information about the Mlir-commits
mailing list