[Mlir-commits] [mlir] e4e31e1 - [mlir][OpGen] Cache Identifiers for known attribute names in AbstractOperation.

River Riddle llvmlistbot at llvm.org
Tue Jun 22 12:56:58 PDT 2021


Author: River Riddle
Date: 2021-06-22T19:56:05Z
New Revision: e4e31e19bb87d154a2da1f5479f778c13a120b3c

URL: https://github.com/llvm/llvm-project/commit/e4e31e19bb87d154a2da1f5479f778c13a120b3c
DIFF: https://github.com/llvm/llvm-project/commit/e4e31e19bb87d154a2da1f5479f778c13a120b3c.diff

LOG: [mlir][OpGen] Cache Identifiers for known attribute names in AbstractOperation.

Operations currently rely on the string name of attributes during attribute lookup/removal/replacement, in build methods, and more. This unfortunately means that some of the most-used APIs in MLIR require string comparisons, additional hashing (plus mutex locking) to construct Identifiers, and more. This revision remedies the problem by caching identifiers for all of the attributes of the operation in its corresponding AbstractOperation. Just updating the autogenerated usages brings up to a 15% reduction in compile time, greatly reducing the cost of interacting with the attributes of an operation. This number can grow even higher as we use these methods in handwritten C++ code.

These cached identifiers are exposed via `<attr-name>AttrName` methods on the derived operation class. Moving forward, users should generally prefer these methods over raw strings whenever an attribute name is needed.

Differential Revision: https://reviews.llvm.org/D104167

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Affine/IR/AffineOps.h
    mlir/include/mlir/Dialect/MemRef/IR/MemRef.h
    mlir/include/mlir/IR/OperationSupport.h
    mlir/lib/IR/MLIRContext.cpp
    mlir/lib/TableGen/OpClass.cpp
    mlir/test/mlir-tblgen/op-attribute.td
    mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
    mlir/unittests/Interfaces/DataLayoutInterfacesTest.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Affine/IR/AffineOps.h b/mlir/include/mlir/Dialect/Affine/IR/AffineOps.h
index 164e8b3b83eed..1199d1bb272e2 100644
--- a/mlir/include/mlir/Dialect/Affine/IR/AffineOps.h
+++ b/mlir/include/mlir/Dialect/Affine/IR/AffineOps.h
@@ -77,6 +77,7 @@ class AffineDmaStartOp
                 AffineMapAccessInterface::Trait> {
 public:
   using Op::Op;
+  static ArrayRef<StringRef> getAttributeNames() { return {}; }
 
   static void build(OpBuilder &builder, OperationState &result, Value srcMemRef,
                     AffineMap srcMap, ValueRange srcIndices, Value destMemRef,
@@ -268,6 +269,7 @@ class AffineDmaWaitOp
                 AffineMapAccessInterface::Trait> {
 public:
   using Op::Op;
+  static ArrayRef<StringRef> getAttributeNames() { return {}; }
 
   static void build(OpBuilder &builder, OperationState &result, Value tagMemRef,
                     AffineMap tagMap, ValueRange tagIndices, Value numElements);

diff  --git a/mlir/include/mlir/Dialect/MemRef/IR/MemRef.h b/mlir/include/mlir/Dialect/MemRef/IR/MemRef.h
index 20ff5df914250..757c9e9a37e54 100644
--- a/mlir/include/mlir/Dialect/MemRef/IR/MemRef.h
+++ b/mlir/include/mlir/Dialect/MemRef/IR/MemRef.h
@@ -94,6 +94,7 @@ class DmaStartOp
     : public Op<DmaStartOp, OpTrait::VariadicOperands, OpTrait::ZeroResult> {
 public:
   using Op::Op;
+  static ArrayRef<StringRef> getAttributeNames() { return {}; }
 
   static void build(OpBuilder &builder, OperationState &result, Value srcMemRef,
                     ValueRange srcIndices, Value destMemRef,
@@ -217,6 +218,7 @@ class DmaWaitOp
     : public Op<DmaWaitOp, OpTrait::VariadicOperands, OpTrait::ZeroResult> {
 public:
   using Op::Op;
+  static ArrayRef<StringRef> getAttributeNames() { return {}; }
 
   static void build(OpBuilder &builder, OperationState &result, Value tagMemRef,
                     ValueRange tagIndices, Value numElements);

diff  --git a/mlir/include/mlir/IR/OperationSupport.h b/mlir/include/mlir/IR/OperationSupport.h
index 898454a67de7f..43be1ab67d9ba 100644
--- a/mlir/include/mlir/IR/OperationSupport.h
+++ b/mlir/include/mlir/IR/OperationSupport.h
@@ -174,7 +174,7 @@ class AbstractOperation {
            T::getParseAssemblyFn(), T::getPrintAssemblyFn(),
            T::getVerifyInvariantsFn(), T::getFoldHookFn(),
            T::getGetCanonicalizationPatternsFn(), T::getInterfaceMap(),
-           T::getHasTraitFn());
+           T::getHasTraitFn(), T::getAttributeNames());
   }
 
   /// Register a new operation in a Dialect object.
@@ -185,7 +185,26 @@ class AbstractOperation {
          ParseAssemblyFn &&parseAssembly, PrintAssemblyFn &&printAssembly,
          VerifyInvariantsFn &&verifyInvariants, FoldHookFn &&foldHook,
          GetCanonicalizationPatternsFn &&getCanonicalizationPatterns,
-         detail::InterfaceMap &&interfaceMap, HasTraitFn &&hasTrait);
+         detail::InterfaceMap &&interfaceMap, HasTraitFn &&hasTrait,
+         ArrayRef<StringRef> attrNames);
+
+  /// Return the list of cached attribute names registered to this operation.
+  /// The order of attributes cached here is unique to each type of operation,
+  /// and the interpretation of this attribute list should generally be driven
+  /// by the respective operation. In many cases, this caching removes the need
+  /// to use the raw string name of a known attribute.
+  ///
+  /// For example the ODS generator, with an op defining the following
+  /// attributes:
+  ///
+  ///   let arguments = (ins I32Attr:$attr1, I32Attr:$attr2);
+  ///
+  /// ... may produce an order here of ["attr1", "attr2"]. This allows for the
+  /// ODS generator to directly access the cached name for a known attribute,
+  /// greatly simplifying the cost and complexity of attribute usage produced by
+  /// the generator.
+  ///
+  ArrayRef<Identifier> getAttributeNames() const { return attributeNames; }
 
 private:
   AbstractOperation(StringRef name, Dialect &dialect, TypeID typeID,
@@ -194,7 +213,8 @@ class AbstractOperation {
                     VerifyInvariantsFn &&verifyInvariants,
                     FoldHookFn &&foldHook,
                     GetCanonicalizationPatternsFn &&getCanonicalizationPatterns,
-                    detail::InterfaceMap &&interfaceMap, HasTraitFn &&hasTrait);
+                    detail::InterfaceMap &&interfaceMap, HasTraitFn &&hasTrait,
+                    ArrayRef<Identifier> attrNames);
 
   /// Give Op access to lookupMutable.
   template <typename ConcreteType, template <typename T> class... Traits>
@@ -215,6 +235,11 @@ class AbstractOperation {
   ParseAssemblyFn parseAssemblyFn;
   PrintAssemblyFn printAssemblyFn;
   VerifyInvariantsFn verifyInvariantsFn;
+
+  /// A list of attribute names registered to this operation in identifier form.
+  /// This allows for operation classes to use identifiers for attribute
+  /// lookup/creation/etc., as opposed to strings.
+  ArrayRef<Identifier> attributeNames;
 };
 
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/lib/IR/MLIRContext.cpp b/mlir/lib/IR/MLIRContext.cpp
index ab12e5ff56693..1041f1c9f5033 100644
--- a/mlir/lib/IR/MLIRContext.cpp
+++ b/mlir/lib/IR/MLIRContext.cpp
@@ -710,17 +710,30 @@ void AbstractOperation::insert(
     ParseAssemblyFn &&parseAssembly, PrintAssemblyFn &&printAssembly,
     VerifyInvariantsFn &&verifyInvariants, FoldHookFn &&foldHook,
     GetCanonicalizationPatternsFn &&getCanonicalizationPatterns,
-    detail::InterfaceMap &&interfaceMap, HasTraitFn &&hasTrait) {
-  AbstractOperation opInfo(name, dialect, typeID, std::move(parseAssembly),
-                           std::move(printAssembly),
-                           std::move(verifyInvariants), std::move(foldHook),
-                           std::move(getCanonicalizationPatterns),
-                           std::move(interfaceMap), std::move(hasTrait));
-
-  auto &impl = dialect.getContext()->getImpl();
+    detail::InterfaceMap &&interfaceMap, HasTraitFn &&hasTrait,
+    ArrayRef<StringRef> attrNames) {
+  MLIRContext *ctx = dialect.getContext();
+  auto &impl = ctx->getImpl();
   assert(impl.multiThreadedExecutionContext == 0 &&
          "Registering a new operation kind while in a multi-threaded execution "
          "context");
+
+  // Register the attribute names of this operation.
+  MutableArrayRef<Identifier> cachedAttrNames;
+  if (!attrNames.empty()) {
+    cachedAttrNames = MutableArrayRef<Identifier>(
+        impl.identifierAllocator.Allocate<Identifier>(attrNames.size()),
+        attrNames.size());
+    for (unsigned i : llvm::seq<unsigned>(0, attrNames.size()))
+      new (&cachedAttrNames[i]) Identifier(Identifier::get(attrNames[i], ctx));
+  }
+
+  // Register the information for this operation.
+  AbstractOperation opInfo(
+      name, dialect, typeID, std::move(parseAssembly), std::move(printAssembly),
+      std::move(verifyInvariants), std::move(foldHook),
+      std::move(getCanonicalizationPatterns), std::move(interfaceMap),
+      std::move(hasTrait), cachedAttrNames);
   if (!impl.registeredOperations.insert({name, std::move(opInfo)}).second) {
     llvm::errs() << "error: operation named '" << name
                  << "' is already registered.\n";
@@ -733,7 +746,8 @@ AbstractOperation::AbstractOperation(
     ParseAssemblyFn &&parseAssembly, PrintAssemblyFn &&printAssembly,
     VerifyInvariantsFn &&verifyInvariants, FoldHookFn &&foldHook,
     GetCanonicalizationPatternsFn &&getCanonicalizationPatterns,
-    detail::InterfaceMap &&interfaceMap, HasTraitFn &&hasTrait)
+    detail::InterfaceMap &&interfaceMap, HasTraitFn &&hasTrait,
+    ArrayRef<Identifier> attrNames)
     : name(Identifier::get(name, dialect.getContext())), dialect(dialect),
       typeID(typeID), interfaceMap(std::move(interfaceMap)),
       foldHookFn(std::move(foldHook)),
@@ -741,7 +755,8 @@ AbstractOperation::AbstractOperation(
       hasTraitFn(std::move(hasTrait)),
       parseAssemblyFn(std::move(parseAssembly)),
       printAssemblyFn(std::move(printAssembly)),
-      verifyInvariantsFn(std::move(verifyInvariants)) {}
+      verifyInvariantsFn(std::move(verifyInvariants)),
+      attributeNames(attrNames) {}
 
 //===----------------------------------------------------------------------===//
 // AbstractType

diff  --git a/mlir/lib/TableGen/OpClass.cpp b/mlir/lib/TableGen/OpClass.cpp
index ae20a17f5ca4c..d1453bacdd05f 100644
--- a/mlir/lib/TableGen/OpClass.cpp
+++ b/mlir/lib/TableGen/OpClass.cpp
@@ -201,15 +201,15 @@ void OpMethod::writeDeclTo(raw_ostream &os) const {
   os.indent(2);
   if (isStatic())
     os << "static ";
-  if (properties & MP_Constexpr)
+  if ((properties & MP_Constexpr) == MP_Constexpr)
     os << "constexpr ";
   methodSignature.writeDeclTo(os);
-  if (!isInline())
+  if (!isInline()) {
     os << ";";
-  else {
+  } else {
     os << " {\n";
-    methodBody.writeTo(os);
-    os << "}";
+    methodBody.writeTo(os.indent(2));
+    os.indent(2) << "}";
   }
 }
 

diff  --git a/mlir/test/mlir-tblgen/op-attribute.td b/mlir/test/mlir-tblgen/op-attribute.td
index 987f417d867ce..cef8cfcb913cd 100644
--- a/mlir/test/mlir-tblgen/op-attribute.td
+++ b/mlir/test/mlir-tblgen/op-attribute.td
@@ -28,6 +28,31 @@ def AOp : NS_Op<"a_op", []> {
   );
 }
 
+// DECL-LABEL: AOp declarations
+
+// Test attribute name methods
+// ---
+
+// DECL:      static ::llvm::ArrayRef<::llvm::StringRef> getAttributeNames()
+// DECL-NEXT:   static ::llvm::StringRef attrNames[] =
+// DECL-SAME:     {::llvm::StringRef("aAttr"), ::llvm::StringRef("bAttr"), ::llvm::StringRef("cAttr")};
+// DECL-NEXT:   return attrNames;
+
+// DECL:      ::mlir::Identifier aAttrAttrName()
+// DECL-NEXT:      return getAttributeNameForIndex(0);
+// DECL:      ::mlir::Identifier aAttrAttrName(::mlir::OperationName name)
+// DECL-NEXT:      return getAttributeNameForIndex(name, 0);
+
+// DECL:      ::mlir::Identifier bAttrAttrName()
+// DECL-NEXT:      return getAttributeNameForIndex(1);
+// DECL:      ::mlir::Identifier bAttrAttrName(::mlir::OperationName name)
+// DECL-NEXT:      return getAttributeNameForIndex(name, 1);
+
+// DECL:      ::mlir::Identifier cAttrAttrName()
+// DECL-NEXT:      return getAttributeNameForIndex(2);
+// DECL:      ::mlir::Identifier cAttrAttrName(::mlir::OperationName name)
+// DECL-NEXT:      return getAttributeNameForIndex(name, 2);
+
 // DEF-LABEL: AOp definitions
 
 // Test verify method
@@ -48,13 +73,13 @@ def AOp : NS_Op<"a_op", []> {
 // ---
 
 // DEF:      some-attr-kind AOp::aAttrAttr()
-// DEF-NEXT:   (*this)->getAttr("aAttr").template cast<some-attr-kind>()
+// DEF-NEXT:   (*this)->getAttr(aAttrAttrName()).template cast<some-attr-kind>()
 // DEF:      some-return-type AOp::aAttr() {
 // DEF-NEXT:   auto attr = aAttrAttr()
 // DEF-NEXT:   return attr.some-convert-from-storage();
 
 // DEF:      some-attr-kind AOp::bAttrAttr()
-// DEF-NEXT:   return (*this)->getAttr("bAttr").template dyn_cast_or_null<some-attr-kind>()
+// DEF-NEXT:   return (*this)->getAttr(bAttrAttrName()).template dyn_cast_or_null<some-attr-kind>()
 // DEF:      some-return-type AOp::bAttr() {
 // DEF-NEXT:   auto attr = bAttrAttr();
 // DEF-NEXT:   if (!attr)
@@ -62,7 +87,7 @@ def AOp : NS_Op<"a_op", []> {
 // DEF-NEXT:   return attr.some-convert-from-storage();
 
 // DEF:      some-attr-kind AOp::cAttrAttr()
-// DEF-NEXT:   return (*this)->getAttr("cAttr").template dyn_cast_or_null<some-attr-kind>()
+// DEF-NEXT:   return (*this)->getAttr(cAttrAttrName()).template dyn_cast_or_null<some-attr-kind>()
 // DEF:      ::llvm::Optional<some-return-type> AOp::cAttr() {
 // DEF-NEXT:   auto attr = cAttrAttr()
 // DEF-NEXT:   return attr ? ::llvm::Optional<some-return-type>(attr.some-convert-from-storage()) : (::llvm::None);
@@ -71,30 +96,30 @@ def AOp : NS_Op<"a_op", []> {
 // ---
 
 // DEF:      void AOp::aAttrAttr(some-attr-kind attr) {
-// DEF-NEXT:   (*this)->setAttr("aAttr", attr);
+// DEF-NEXT:   (*this)->setAttr(aAttrAttrName(), attr);
 // DEF:      void AOp::bAttrAttr(some-attr-kind attr) {
-// DEF-NEXT:   (*this)->setAttr("bAttr", attr);
+// DEF-NEXT:   (*this)->setAttr(bAttrAttrName(), attr);
 // DEF:      void AOp::cAttrAttr(some-attr-kind attr) {
-// DEF-NEXT:   (*this)->setAttr("cAttr", attr);
+// DEF-NEXT:   (*this)->setAttr(cAttrAttrName(), attr);
 
 // Test remove methods
 // ---
 
 // DEF: ::mlir::Attribute AOp::removeCAttrAttr() {
-// DEF-NEXT: return (*this)->removeAttr("cAttr");
+// DEF-NEXT: return (*this)->removeAttr(cAttrAttrName());
 
 // Test build methods
 // ---
 
 // DEF:      void AOp::build(
-// DEF:        odsState.addAttribute("aAttr", aAttr);
-// DEF:        odsState.addAttribute("bAttr", bAttr);
+// DEF:        odsState.addAttribute(aAttrAttrName(odsState.name), aAttr);
+// DEF:        odsState.addAttribute(bAttrAttrName(odsState.name), bAttr);
 // DEF:        if (cAttr) {
-// DEF-NEXT:     odsState.addAttribute("cAttr", cAttr);
+// DEF-NEXT:     odsState.addAttribute(cAttrAttrName(odsState.name), cAttr);
 
 // DEF:      void AOp::build(
 // DEF:        some-return-type aAttr, some-return-type bAttr, /*optional*/some-attr-kind cAttr
-// DEF:        odsState.addAttribute("aAttr", some-const-builder-call(odsBuilder, aAttr));
+// DEF:        odsState.addAttribute(aAttrAttrName(odsState.name), some-const-builder-call(odsBuilder, aAttr));
 
 // DEF:      void AOp::build(
 // DEF:        ::llvm::ArrayRef<::mlir::NamedAttribute> attributes
@@ -205,8 +230,8 @@ def DOp : NS_Op<"d_op", []> {
 // DECL: static void build({{.*}}, uint32_t i32_attr, ::llvm::APFloat f64_attr, ::llvm::StringRef str_attr, bool bool_attr, ::SomeI32Enum enum_attr, uint32_t dv_i32_attr, ::llvm::APFloat dv_f64_attr, ::llvm::StringRef dv_str_attr = "abc", bool dv_bool_attr = true, ::SomeI32Enum dv_enum_attr = ::SomeI32Enum::case5)
 
 // DEF-LABEL: DOp definitions
-// DEF: odsState.addAttribute("str_attr", odsBuilder.getStringAttr(str_attr));
-// DEF: odsState.addAttribute("dv_str_attr", odsBuilder.getStringAttr(dv_str_attr));
+// DEF: odsState.addAttribute(str_attrAttrName(odsState.name), odsBuilder.getStringAttr(str_attr));
+// DEF: odsState.addAttribute(dv_str_attrAttrName(odsState.name), odsBuilder.getStringAttr(dv_str_attr));
 
 // Test derived type attr.
 // ---
@@ -272,7 +297,7 @@ def UnitAttrOp : NS_Op<"unit_attr_op", []> {
 // DEF:   return {{.*}} != nullptr
 
 // DEF: ::mlir::Attribute UnitAttrOp::removeAttrAttr() {
-// DEF-NEXT:   (*this)->removeAttr("attr");
+// DEF-NEXT:   (*this)->removeAttr(attrAttrName());
 
 // DEF: build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, /*optional*/::mlir::UnitAttr attr)
 

diff  --git a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
index cded9afe9dfdb..ea68fe5af0cd3 100644
--- a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
+++ b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
@@ -21,6 +21,7 @@
 #include "mlir/TableGen/Operator.h"
 #include "mlir/TableGen/SideEffects.h"
 #include "mlir/TableGen/Trait.h"
+#include "llvm/ADT/MapVector.h"
 #include "llvm/ADT/Sequence.h"
 #include "llvm/ADT/StringExtras.h"
 #include "llvm/Support/Path.h"
@@ -79,7 +80,7 @@ const char *adapterSegmentSizeAttrInitCode = R"(
   auto sizeAttr = odsAttrs.get("{0}").cast<::mlir::DenseIntElementsAttr>();
 )";
 const char *opSegmentSizeAttrInitCode = R"(
-  auto sizeAttr = (*this)->getAttr("{0}").cast<::mlir::DenseIntElementsAttr>();
+  auto sizeAttr = (*this)->getAttr({0}).cast<::mlir::DenseIntElementsAttr>();
 )";
 const char *attrSizedSegmentValueRangeCalcCode = R"(
   auto sizeAttrValues = sizeAttr.getValues<uint32_t>();
@@ -306,6 +307,13 @@ class OpEmitter {
   void emitDecl(raw_ostream &os);
   void emitDef(raw_ostream &os);
 
+  // Generate methods for accessing the attribute names of this operation.
+  void genAttrNameGetters();
+
+  // Return the index of the given attribute name. This is a relative ordering
+  // for this name, used in attribute getters.
+  unsigned getAttrNameIndex(StringRef attrName) const;
+
   // Generates the OpAsmOpInterface for this operation if possible.
   void genOpAsmInterface();
 
@@ -460,6 +468,10 @@ class OpEmitter {
 
   // The emitter containing all of the locally emitted verification functions.
   const StaticVerifierFunctionEmitter &staticVerifierEmitter;
+
+  // A map of attribute names (including implicit attributes) registered to the
+  // current operation, to the relative order in which they were registered.
+  llvm::MapVector<StringRef, unsigned> attributeNames;
 };
 } // end anonymous namespace
 
@@ -585,6 +597,7 @@ OpEmitter::OpEmitter(const Operator &op,
 
   // Generate C++ code for various op methods. The order here determines the
   // methods in the generated file.
+  genAttrNameGetters();
   genOpAsmInterface();
   genOpNameGetter();
   genNamedOperandGetters();
@@ -623,18 +636,103 @@ void OpEmitter::emitDecl(raw_ostream &os) { opClass.writeDeclTo(os); }
 
 void OpEmitter::emitDef(raw_ostream &os) { opClass.writeDefTo(os); }
 
+void OpEmitter::genAttrNameGetters() {
+  // Enumerate the attribute names of this op, assigning each a relative
+  // ordering.
+  auto addAttrName = [&](StringRef name) {
+    unsigned index = attributeNames.size();
+    attributeNames.insert({name, index});
+  };
+  for (const NamedAttribute &namedAttr : op.getAttributes())
+    addAttrName(namedAttr.name);
+  // Include key attributes from several traits as implicitly registered.
+  if (op.getTrait("::mlir::OpTrait::AttrSizedOperandSegments"))
+    addAttrName("operand_segment_sizes");
+  if (op.getTrait("::mlir::OpTrait::AttrSizedResultSegments"))
+    addAttrName("result_segment_sizes");
+
+  // Emit the getAttributeNames method.
+  {
+    auto *method = opClass.addMethodAndPrune(
+        "::llvm::ArrayRef<::llvm::StringRef>", "getAttributeNames",
+        OpMethod::Property(OpMethod::MP_Static | OpMethod::MP_Inline));
+    auto &body = method->body();
+    if (attributeNames.empty()) {
+      body << "  return {};";
+    } else {
+      body << "  static ::llvm::StringRef attrNames[] = {";
+      llvm::interleaveComma(llvm::make_first_range(attributeNames), body,
+                            [&](StringRef attrName) {
+                              body << "::llvm::StringRef(\"" << attrName
+                                   << "\")";
+                            });
+      body << "};\n  return attrNames;";
+    }
+  }
+
+  // Emit the getAttributeNameForIndex methods.
+  {
+    auto *method = opClass.addMethodAndPrune(
+        "::mlir::Identifier", "getAttributeNameForIndex",
+        OpMethod::Property(OpMethod::MP_Private | OpMethod::MP_Inline),
+        "unsigned", "index");
+    method->body()
+        << "  return getAttributeNameForIndex((*this)->getName(), index);";
+  }
+  {
+    auto *method = opClass.addMethodAndPrune(
+        "::mlir::Identifier", "getAttributeNameForIndex",
+        OpMethod::Property(OpMethod::MP_Private | OpMethod::MP_Inline |
+                           OpMethod::MP_Static),
+        "::mlir::OperationName name, unsigned index");
+    method->body() << "assert(index < " << attributeNames.size()
+                   << " && \"invalid attribute index\");\n"
+                      "  return name.getAbstractOperation()"
+                      "->getAttributeNames()[index];";
+  }
+
+  // Generate the <attr>AttrName methods, that expose the attribute names to
+  // users.
+  const char *attrNameMethodBody = "  return getAttributeNameForIndex({0});";
+  for (const std::pair<StringRef, unsigned> &attrIt : attributeNames) {
+    std::string methodName = (attrIt.first + "AttrName").str();
+
+    // Generate the non-static variant.
+    {
+      auto *method =
+          opClass.addMethodAndPrune("::mlir::Identifier", methodName,
+                                    OpMethod::Property(OpMethod::MP_Inline));
+      method->body() << llvm::formatv(attrNameMethodBody, attrIt.second).str();
+    }
+
+    // Generate the static variant.
+    {
+      auto *method = opClass.addMethodAndPrune(
+          "::mlir::Identifier", methodName,
+          OpMethod::Property(OpMethod::MP_Inline | OpMethod::MP_Static),
+          "::mlir::OperationName", "name");
+      method->body() << llvm::formatv(attrNameMethodBody,
+                                      "name, " + Twine(attrIt.second))
+                            .str();
+    }
+  }
+}
+
+unsigned OpEmitter::getAttrNameIndex(StringRef attrName) const {
+  auto it = attributeNames.find(attrName);
+  assert(it != attributeNames.end() && "expected attribute name to have been "
+                                       "registered in genAttrNameGetters");
+  return it->second;
+}
+
 void OpEmitter::genAttrGetters() {
   FmtContext fctx;
   fctx.withBuilder("::mlir::Builder((*this)->getContext())");
 
-  Dialect opDialect = op.getDialect();
   // Emit the derived attribute body.
   auto emitDerivedAttr = [&](StringRef name, Attribute attr) {
-    auto *method = opClass.addMethodAndPrune(attr.getReturnType(), name);
-    if (!method)
-      return;
-    auto &body = method->body();
-    body << "  " << attr.getDerivedCodeBody() << "\n";
+    if (auto *method = opClass.addMethodAndPrune(attr.getReturnType(), name))
+      method->body() << "  " << attr.getDerivedCodeBody() << "\n";
   };
 
   // Emit with return type specified.
@@ -667,7 +765,7 @@ void OpEmitter::genAttrGetters() {
     if (!method)
       return;
     auto &body = method->body();
-    body << "  return (*this)->getAttr(\"" << name << "\").template ";
+    body << "  return (*this)->getAttr(" << name << "AttrName()).template ";
     if (attr.isOptional() || attr.hasDefaultValue())
       body << "dyn_cast_or_null<";
     else
@@ -675,14 +773,12 @@ void OpEmitter::genAttrGetters() {
     body << attr.getStorageType() << ">();";
   };
 
-  for (auto &namedAttr : op.getAttributes()) {
-    const auto &name = namedAttr.name;
-    const auto &attr = namedAttr.attr;
-    if (attr.isDerivedAttr()) {
-      emitDerivedAttr(name, attr);
+  for (const NamedAttribute &namedAttr : op.getAttributes()) {
+    if (namedAttr.attr.isDerivedAttr()) {
+      emitDerivedAttr(namedAttr.name, namedAttr.attr);
     } else {
-      emitAttrWithStorageType(name, attr);
-      emitAttrWithReturnType(name, attr);
+      emitAttrWithStorageType(namedAttr.name, namedAttr.attr);
+      emitAttrWithReturnType(namedAttr.name, namedAttr.attr);
     }
   }
 
@@ -739,8 +835,7 @@ void OpEmitter::genAttrGetters() {
           derivedAttrs, body,
           [&](const NamedAttribute &namedAttr) {
             auto tmpl = namedAttr.attr.getConvertFromStorageCall();
-            body << "    {::mlir::Identifier::get(\"" << namedAttr.name
-                 << "\", ctx),\n"
+            body << "    {" << namedAttr.name << "AttrName(),\n"
                  << tgfmt(tmpl, &fctx.withSelf(namedAttr.name + "()")
                                      .withBuilder("odsBuilder")
                                      .addSubst("_ctx", "ctx"))
@@ -759,18 +854,13 @@ void OpEmitter::genAttrSetters() {
   auto emitAttrWithStorageType = [&](StringRef name, Attribute attr) {
     auto *method = opClass.addMethodAndPrune("void", (name + "Attr").str(),
                                              attr.getStorageType(), "attr");
-    if (!method)
-      return;
-    auto &body = method->body();
-    body << "  (*this)->setAttr(\"" << name << "\", attr);";
+    if (method)
+      method->body() << "  (*this)->setAttr(" << name << "AttrName(), attr);";
   };
 
-  for (auto &namedAttr : op.getAttributes()) {
-    const auto &name = namedAttr.name;
-    const auto &attr = namedAttr.attr;
-    if (!attr.isDerivedAttr())
-      emitAttrWithStorageType(name, attr);
-  }
+  for (const NamedAttribute &namedAttr : op.getAttributes())
+    if (!namedAttr.attr.isDerivedAttr())
+      emitAttrWithStorageType(namedAttr.name, namedAttr.attr);
 }
 
 void OpEmitter::genOptionalAttrRemovers() {
@@ -783,16 +873,12 @@ void OpEmitter::genOptionalAttrRemovers() {
         "::mlir::Attribute", ("remove" + upperInitial + suffix + "Attr").str());
     if (!method)
       return;
-    auto &body = method->body();
-    body << "  return (*this)->removeAttr(\"" << name << "\");";
+    method->body() << "  return (*this)->removeAttr(" << name << "AttrName());";
   };
 
-  for (const auto &namedAttr : op.getAttributes()) {
-    const auto &name = namedAttr.name;
-    const auto &attr = namedAttr.attr;
-    if (attr.isOptional())
-      emitRemoveAttr(name);
-  }
+  for (const NamedAttribute &namedAttr : op.getAttributes())
+    if (namedAttr.attr.isOptional())
+      emitRemoveAttr(namedAttr.name);
 }
 
 // Generates the code to compute the start and end index of an operand or result
@@ -904,10 +990,18 @@ static void generateNamedOperandGetters(const Operator &op, Class &opClass,
 }
 
 void OpEmitter::genNamedOperandGetters() {
+  // Build the code snippet used for initializing the operand_segment_sizes
+  // array.
+  std::string attrSizeInitCode;
+  if (op.getTrait("::mlir::OpTrait::AttrSizedOperandSegments")) {
+    attrSizeInitCode =
+        formatv(opSegmentSizeAttrInitCode, "operand_segment_sizesAttrName()")
+            .str();
+  }
+
   generateNamedOperandGetters(
       op, opClass,
-      /*sizeAttrInit=*/
-      formatv(opSegmentSizeAttrInitCode, "operand_segment_sizes").str(),
+      /*sizeAttrInit=*/attrSizeInitCode,
       /*rangeType=*/"::mlir::Operation::operand_range",
       /*rangeBeginCall=*/"getOperation()->operand_begin()",
       /*rangeSizeCall=*/"getOperation()->getNumOperands()",
@@ -930,7 +1024,7 @@ void OpEmitter::genNamedOperandSetters() {
     if (attrSizedOperands)
       body << ", ::mlir::MutableOperandRange::OperandSegment(" << i
            << "u, *getOperation()->getAttrDictionary().getNamed("
-              "\"operand_segment_sizes\"))";
+              "operand_segment_sizesAttrName()))";
     body << ");\n";
   }
 }
@@ -964,11 +1058,18 @@ void OpEmitter::genNamedResultGetters() {
                     "'SameVariadicResultSize' traits");
   }
 
+  // Build the initializer string for the result segment size attribute.
+  std::string attrSizeInitCode;
+  if (attrSizedResults) {
+    attrSizeInitCode =
+        formatv(opSegmentSizeAttrInitCode, "result_segment_sizesAttrName()")
+            .str();
+  }
+
   generateValueRangeStartAndEnd(
       opClass, "getODSResultIndexAndLength", numVariadicResults,
       numNormalResults, "getOperation()->getNumResults()", attrSizedResults,
-      formatv(opSegmentSizeAttrInitCode, "result_segment_sizes").str(),
-      op.getResults());
+      attrSizeInitCode, op.getResults());
 
   auto *m = opClass.addMethodAndPrune("::mlir::Operation::result_range",
                                       "getODSResults", "unsigned", "index");
@@ -1298,8 +1399,11 @@ void OpEmitter::genUseAttrAsResultTypeBuilder() {
   std::string resultType;
   const auto &namedAttr = op.getAttribute(0);
 
-  body << "  for (auto attr : attributes) {\n";
-  body << "    if (attr.first != \"" << namedAttr.name << "\") continue;\n";
+  body << "  auto attrName = " << namedAttr.name << "AttrName("
+       << builderOpState
+       << ".name);\n"
+          "  for (auto attr : attributes) {\n"
+          "    if (attr.first != attrName) continue;\n";
   if (namedAttr.attr.isTypeAttr()) {
     resultType = "attr.second.cast<::mlir::TypeAttr>().getValue()";
   } else {
@@ -1596,8 +1700,9 @@ void OpEmitter::genCodeForAddingArgAndRegionForBuilder(OpMethodBody &body,
   // If the operation has the operand segment size attribute, add it here.
   if (op.getTrait("::mlir::OpTrait::AttrSizedOperandSegments")) {
     body << "  " << builderOpState
-         << ".addAttribute(\"operand_segment_sizes\", "
-            "odsBuilder.getI32VectorAttr({";
+         << ".addAttribute(operand_segment_sizesAttrName(" << builderOpState
+         << ".name), "
+         << "odsBuilder.getI32VectorAttr({";
     interleaveComma(llvm::seq<int>(0, op.getNumOperands()), body, [&](int i) {
       if (op.getOperand(i).isOptional())
         body << "(" << getArgumentName(op, i) << " ? 1 : 0)";
@@ -1614,9 +1719,9 @@ void OpEmitter::genCodeForAddingArgAndRegionForBuilder(OpMethodBody &body,
     auto &attr = namedAttr.attr;
     if (!attr.isDerivedAttr()) {
       bool emitNotNullCheck = attr.isOptional();
-      if (emitNotNullCheck) {
+      if (emitNotNullCheck)
         body << formatv("  if ({0}) ", namedAttr.name) << "{\n";
-      }
+
       if (isRawValueAttr && canUseUnwrappedRawValue(attr)) {
         // If this is a raw value, then we need to wrap it in an Attribute
         // instance.
@@ -1635,15 +1740,14 @@ void OpEmitter::genCodeForAddingArgAndRegionForBuilder(OpMethodBody &body,
 
         std::string value =
             std::string(tgfmt(builderTemplate, &fctx, namedAttr.name));
-        body << formatv("  {0}.addAttribute(\"{1}\", {2});\n", builderOpState,
-                        namedAttr.name, value);
+        body << formatv("  {0}.addAttribute({1}AttrName({0}.name), {2});\n",
+                        builderOpState, namedAttr.name, value);
       } else {
-        body << formatv("  {0}.addAttribute(\"{1}\", {1});\n", builderOpState,
-                        namedAttr.name);
+        body << formatv("  {0}.addAttribute({1}AttrName({0}.name), {1});\n",
+                        builderOpState, namedAttr.name);
       }
-      if (emitNotNullCheck) {
+      if (emitNotNullCheck)
         body << "  }\n";
-      }
     }
   }
 

diff  --git a/mlir/unittests/Interfaces/DataLayoutInterfacesTest.cpp b/mlir/unittests/Interfaces/DataLayoutInterfacesTest.cpp
index 287839120ab26..ebf3153844e4a 100644
--- a/mlir/unittests/Interfaces/DataLayoutInterfacesTest.cpp
+++ b/mlir/unittests/Interfaces/DataLayoutInterfacesTest.cpp
@@ -114,6 +114,7 @@ struct TypeNoLayout : public Type::TypeBase<TypeNoLayout, Type, TypeStorage> {
 /// types itself.
 struct OpWithLayout : public Op<OpWithLayout, DataLayoutOpInterface::Trait> {
   using Op::Op;
+  static ArrayRef<StringRef> getAttributeNames() { return {}; }
 
   static StringRef getOperationName() { return "dltest.op_with_layout"; }
 
@@ -156,6 +157,7 @@ struct OpWithLayout : public Op<OpWithLayout, DataLayoutOpInterface::Trait> {
 struct OpWith7BitByte
     : public Op<OpWith7BitByte, DataLayoutOpInterface::Trait> {
   using Op::Op;
+  static ArrayRef<StringRef> getAttributeNames() { return {}; }
 
   static StringRef getOperationName() { return "dltest.op_with_7bit_byte"; }
 


        


More information about the Mlir-commits mailing list