[Mlir-commits] [mlir] 65c9907 - [mlir][ods] Enable emitting getter/setter prefix

Jacques Pienaar llvmlistbot at llvm.org
Thu Oct 14 15:59:04 PDT 2021


Author: Jacques Pienaar
Date: 2021-10-14T15:58:44-07:00
New Revision: 65c9907c809a275e57bd925d1eda5a743a462d20

URL: https://github.com/llvm/llvm-project/commit/65c9907c809a275e57bd925d1eda5a743a462d20
DIFF: https://github.com/llvm/llvm-project/commit/65c9907c809a275e57bd925d1eda5a743a462d20.diff

LOG: [mlir][ods] Enable emitting getter/setter prefix

Allow emitting a get/set prefix for accessors generated for ops. If
enabled, the argument/result/region name is converted from snake_case
to UpperCamel and the prefix is added. The option also allows
generating both the current "raw" method and the prefixed one, to
make it easier to stage changes.

The option is added on the dialect and currently defaults to the existing
raw behavior. The staging period, during which both forms are generated,
is expected to be short-lived, so the implementation is optimized for
keeping the changes local/less invasive (it just generates two functions
with the same body for each accessor - most of these internally call a
helper function anyway). Generation can be optimized if needed.

I'm unsure about OpAdaptor classes: there every accessor is a get method
(an adaptor is a named view into raw data structures), so a prefix adds little.

This starts by emitting the raw-only form (the current behavior) by
default; one can then opt in to raw & prefixed, and finally to prefixed
only. The default in OpBase will switch to prefixed-only to be consistent
with the MLIR style guide, and the option may be removed later. (Allowing
the prefix itself to be specified was considered, but the current
discussion favors keeping this limited, so that was not done.)

Also add more explicit checking for pruned functions, to avoid emitting a
body where no function was added (and thus avoid dereferencing a nullptr)
during op def/decl generation.

See https://bugs.llvm.org/show_bug.cgi?id=51916 for further discussion.

Differential Revision: https://reviews.llvm.org/D111033

Added: 
    

Modified: 
    mlir/include/mlir/IR/OpBase.td
    mlir/include/mlir/TableGen/Dialect.h
    mlir/lib/TableGen/Dialect.cpp
    mlir/test/mlir-tblgen/op-attribute.td
    mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/IR/OpBase.td b/mlir/include/mlir/IR/OpBase.td
index 403f4b7029424..d92c0a80a54fe 100644
--- a/mlir/include/mlir/IR/OpBase.td
+++ b/mlir/include/mlir/IR/OpBase.td
@@ -237,6 +237,11 @@ def IsTupleTypePred : CPred<"$_self.isa<::mlir::TupleType>()">;
 // Dialect definitions
 //===----------------------------------------------------------------------===//
 
+// "Enum" values for emitAccessorPrefix of Dialect.
+defvar kEmitAccessorPrefix_Raw = 0;      // Don't emit any getter/setter prefix.
+defvar kEmitAccessorPrefix_Prefixed = 1; // Only emit with getter/setter prefix.
+defvar kEmitAccessorPrefix_Both = 2;     // Emit without and with prefix.
+
 class Dialect {
   // The name of the dialect.
   string name = ?;
@@ -290,6 +295,17 @@ class Dialect {
 
   // If this dialect overrides the hook for canonicalization patterns.
   bit hasCanonicalizer = 0;
+
+  // Whether to emit accessors raw (with no prefix or name changes), only
+  // with a `get`/`set` prefix and an UpperCamel suffix, or both. Takes one
+  // of the kEmitAccessorPrefix_* values.
+  //
+  // If a prefix is emitted, the attribute/operand's name is converted from
+  // snake_case to UpperCamel (which leaves UpperCamel names unchanged and
+  // also converts lowerCamel to UpperCamel) and is then prefixed with `get`
+  // or `set` depending on whether the accessor is a getter or a setter. In
+  // `Both` mode the raw accessor is emitted alongside the prefixed one.
+  int emitAccessorPrefix = kEmitAccessorPrefix_Raw;
 }
 
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/include/mlir/TableGen/Dialect.h b/mlir/include/mlir/TableGen/Dialect.h
index 4c5af8eba7d14..2de0d9b0406eb 100644
--- a/mlir/include/mlir/TableGen/Dialect.h
+++ b/mlir/include/mlir/TableGen/Dialect.h
@@ -85,6 +85,10 @@ class Dialect {
   // Returns whether the dialect is defined.
   explicit operator bool() const { return def != nullptr; }
 
+  // Returns how the accessors should be prefixed in dialect.
+  enum class EmitPrefix { Raw = 0, Prefixed = 1, Both = 2 };
+  EmitPrefix getEmitAccessorPrefix() const;
+
 private:
   const llvm::Record *def;
   std::vector<StringRef> dependentDialects;

diff  --git a/mlir/lib/TableGen/Dialect.cpp b/mlir/lib/TableGen/Dialect.cpp
index 59e7593c9425a..7b5e89a7e6c98 100644
--- a/mlir/lib/TableGen/Dialect.cpp
+++ b/mlir/lib/TableGen/Dialect.cpp
@@ -11,6 +11,7 @@
 //===----------------------------------------------------------------------===//
 
 #include "mlir/TableGen/Dialect.h"
+#include "llvm/TableGen/Error.h"
 #include "llvm/TableGen/Record.h"
 
 using namespace mlir;
@@ -89,6 +90,13 @@ bool Dialect::hasOperationInterfaceFallback() const {
   return def->getValueAsBit("hasOperationInterfaceFallback");
 }
 
+Dialect::EmitPrefix Dialect::getEmitAccessorPrefix() const {
+  int prefix = def->getValueAsInt("emitAccessorPrefix");
+  if (prefix < 0 || prefix > static_cast<int>(EmitPrefix::Both))
+    PrintFatalError(def->getLoc(), "Invalid accessor prefix value");
+  return static_cast<EmitPrefix>(prefix);
+}
+
 bool Dialect::operator==(const Dialect &other) const {
   return def == other.def;
 }

diff  --git a/mlir/test/mlir-tblgen/op-attribute.td b/mlir/test/mlir-tblgen/op-attribute.td
index 1656b42cbf3cb..4ed715d5b8e6c 100644
--- a/mlir/test/mlir-tblgen/op-attribute.td
+++ b/mlir/test/mlir-tblgen/op-attribute.td
@@ -130,6 +130,118 @@ def AOp : NS_Op<"a_op", []> {
 // DEF:        ::llvm::ArrayRef<::mlir::NamedAttribute> attributes
 // DEF:      odsState.addAttributes(attributes);
 
+// Test the above but with prefix.
+
+def Test2_Dialect : Dialect {
+  let name = "test2";
+  let cppNamespace = "foobar2";
+  let emitAccessorPrefix = kEmitAccessorPrefix_Prefixed;
+}
+def AgetOp : Op<Test2_Dialect, "a_get_op", []> {
+  let arguments = (ins
+      SomeAttr:$aAttr,
+      DefaultValuedAttr<SomeAttr, "4.2">:$bAttr,
+      OptionalAttr<SomeAttr>:$cAttr
+  );
+}
+
+// DECL-LABEL: AgetOp declarations
+
+// Test attribute name methods
+// ---
+
+// DECL:      static ::llvm::ArrayRef<::llvm::StringRef> getAttributeNames()
+// DECL-NEXT:   static ::llvm::StringRef attrNames[] =
+// DECL-SAME:     {::llvm::StringRef("aAttr"), ::llvm::StringRef("bAttr"), ::llvm::StringRef("cAttr")};
+// DECL-NEXT:   return ::llvm::makeArrayRef(attrNames);
+
+// DECL:      ::mlir::Identifier getAAttrAttrName()
+// DECL-NEXT:      return getAttributeNameForIndex(0);
+// DECL:      ::mlir::Identifier getAAttrAttrName(::mlir::OperationName name)
+// DECL-NEXT:      return getAttributeNameForIndex(name, 0);
+
+// DECL:      ::mlir::Identifier getBAttrAttrName()
+// DECL-NEXT:      return getAttributeNameForIndex(1);
+// DECL:      ::mlir::Identifier getBAttrAttrName(::mlir::OperationName name)
+// DECL-NEXT:      return getAttributeNameForIndex(name, 1);
+
+// DECL:      ::mlir::Identifier getCAttrAttrName()
+// DECL-NEXT:      return getAttributeNameForIndex(2);
+// DECL:      ::mlir::Identifier getCAttrAttrName(::mlir::OperationName name)
+// DECL-NEXT:      return getAttributeNameForIndex(name, 2);
+
+// DEF-LABEL: AgetOp definitions
+
+// Test verify method
+// ---
+
+// DEF:      ::mlir::LogicalResult AgetOpAdaptor::verify
+// DEF:      auto tblgen_aAttr = odsAttrs.get("aAttr");
+// DEF-NEXT: if (!tblgen_aAttr) return emitError(loc, "'test2.a_get_op' op ""requires attribute 'aAttr'");
+// DEF:        if (!((some-condition))) return emitError(loc, "'test2.a_get_op' op ""attribute 'aAttr' failed to satisfy constraint: some attribute kind");
+// DEF:      auto tblgen_bAttr = odsAttrs.get("bAttr");
+// DEF-NEXT: if (tblgen_bAttr) {
+// DEF-NEXT:   if (!((some-condition))) return emitError(loc, "'test2.a_get_op' op ""attribute 'bAttr' failed to satisfy constraint: some attribute kind");
+// DEF:      auto tblgen_cAttr = odsAttrs.get("cAttr");
+// DEF-NEXT: if (tblgen_cAttr) {
+// DEF-NEXT:   if (!((some-condition))) return emitError(loc, "'test2.a_get_op' op ""attribute 'cAttr' failed to satisfy constraint: some attribute kind");
+
+// Test getter methods
+// ---
+
+// DEF:      some-attr-kind AgetOp::getAAttrAttr()
+// DEF-NEXT:   (*this)->getAttr(getAAttrAttrName()).template cast<some-attr-kind>()
+// DEF:      some-return-type AgetOp::getAAttr() {
+// DEF-NEXT:   auto attr = getAAttrAttr()
+// DEF-NEXT:   return attr.some-convert-from-storage();
+
+// DEF:      some-attr-kind AgetOp::getBAttrAttr()
+// DEF-NEXT:   return (*this)->getAttr(getBAttrAttrName()).template dyn_cast_or_null<some-attr-kind>()
+// DEF:      some-return-type AgetOp::getBAttr() {
+// DEF-NEXT:   auto attr = getBAttrAttr();
+// DEF-NEXT:   if (!attr)
+// DEF-NEXT:       return some-const-builder-call(::mlir::Builder((*this)->getContext()), 4.2).some-convert-from-storage();
+// DEF-NEXT:   return attr.some-convert-from-storage();
+
+// DEF:      some-attr-kind AgetOp::getCAttrAttr()
+// DEF-NEXT:   return (*this)->getAttr(getCAttrAttrName()).template dyn_cast_or_null<some-attr-kind>()
+// DEF:      ::llvm::Optional<some-return-type> AgetOp::getCAttr() {
+// DEF-NEXT:   auto attr = getCAttrAttr()
+// DEF-NEXT:   return attr ? ::llvm::Optional<some-return-type>(attr.some-convert-from-storage()) : (::llvm::None);
+
+// Test setter methods
+// ---
+
+// DEF:      void AgetOp::setAAttrAttr(some-attr-kind attr) {
+// DEF-NEXT:   (*this)->setAttr(getAAttrAttrName(), attr);
+// DEF:      void AgetOp::setBAttrAttr(some-attr-kind attr) {
+// DEF-NEXT:   (*this)->setAttr(getBAttrAttrName(), attr);
+// DEF:      void AgetOp::setCAttrAttr(some-attr-kind attr) {
+// DEF-NEXT:   (*this)->setAttr(getCAttrAttrName(), attr);
+
+// Test remove methods
+// ---
+
+// DEF: ::mlir::Attribute AgetOp::removeCAttrAttr() {
+// DEF-NEXT: return (*this)->removeAttr(getCAttrAttrName());
+
+// Test build methods
+// ---
+
+// DEF:      void AgetOp::build(
+// DEF:        odsState.addAttribute(getAAttrAttrName(odsState.name), aAttr);
+// DEF:        odsState.addAttribute(getBAttrAttrName(odsState.name), bAttr);
+// DEF:        if (cAttr) {
+// DEF-NEXT:     odsState.addAttribute(getCAttrAttrName(odsState.name), cAttr);
+
+// DEF:      void AgetOp::build(
+// DEF:        some-return-type aAttr, some-return-type bAttr, /*optional*/some-attr-kind cAttr
+// DEF:        odsState.addAttribute(getAAttrAttrName(odsState.name), some-const-builder-call(odsBuilder, aAttr));
+
+// DEF:      void AgetOp::build(
+// DEF:        ::llvm::ArrayRef<::mlir::NamedAttribute> attributes
+// DEF:      odsState.addAttributes(attributes);
+
 def SomeTypeAttr : TypeAttrBase<"SomeType", "some type attribute">;
 
 def BOp : NS_Op<"b_op", []> {

diff  --git a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
index 2c4cd7a27024c..8c78a3bac7714 100644
--- a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
+++ b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
@@ -25,6 +25,7 @@
 #include "llvm/ADT/Sequence.h"
 #include "llvm/ADT/StringExtras.h"
 #include "llvm/ADT/StringSet.h"
+#include "llvm/Support/Debug.h"
 #include "llvm/Support/Signals.h"
 #include "llvm/TableGen/Error.h"
 #include "llvm/TableGen/Record.h"
@@ -524,6 +525,73 @@ void OpEmitter::emitDecl(raw_ostream &os) { opClass.writeDeclTo(os); }
 
 void OpEmitter::emitDef(raw_ostream &os) { opClass.writeDefTo(os); }
 
+// Returns the accessor name(s) for `name`; the prefixed form, if any, first.
+static SmallVector<std::string, 2>
+getGetterOrSetterNames(bool isGetter, const Operator &op, StringRef name) {
+  Dialect::EmitPrefix prefixType = op.getDialect().getEmitAccessorPrefix();
+  std::string prefix;
+  if (prefixType != Dialect::EmitPrefix::Raw)
+    prefix = isGetter ? "get" : "set";
+
+  SmallVector<std::string, 2> names;
+  bool rawToo = prefixType == Dialect::EmitPrefix::Both;
+
+  auto skip = [&](StringRef newName) {
+    bool shouldSkip = newName == "getOperands";
+    if (!shouldSkip)
+      return false;
+
+    // This note could be avoided when the final generated function would
+    // have been identical. But op definitions should preferably avoid the
+    // generic name rather than rely on it to get a more specialized type.
+    PrintNote(op.getLoc(),
+              "Skipping generation of prefixed accessor `" + newName +
+                  "` as it overlaps with default one; generating raw form (`" +
+                  name + "`) still");
+    return true;
+  };
+
+  if (!prefix.empty()) {
+    names.push_back(prefix + convertToCamelFromSnakeCase(name, true));
+    // For now, fall back to the raw name on overlap with a default accessor.
+    if (skip(names.back())) {
+      rawToo = true;
+      names.clear();
+    } else {
+      LLVM_DEBUG(llvm::errs() << "WITH_GETTER(\"" << op.getQualCppClassName()
+                              << "::" << names.back() << "\");\n"
+                              << "WITH_GETTER(\"" << op.getQualCppClassName()
+                              << "Adaptor::" << names.back() << "\");\n";);
+    }
+  }
+  // Append the raw name: no prefix configured, Both mode, or a skip above.
+  if (prefix.empty() || rawToo)
+    names.push_back(name.str());
+  return names;
+}
+static SmallVector<std::string, 2> getGetterNames(const Operator &op,
+                                                  StringRef name) {
+  return getGetterOrSetterNames(/*isGetter=*/true, op, name);
+}
+static std::string getGetterName(const Operator &op, StringRef name) {
+  return getGetterOrSetterNames(/*isGetter=*/true, op, name).front();
+}
+static SmallVector<std::string, 2> getSetterNames(const Operator &op,
+                                                  StringRef name) {
+  return getGetterOrSetterNames(/*isGetter=*/false, op, name);
+}
+// Fatal error if `m` is null (method not added); `line` locates the caller.
+static void errorIfPruned(size_t line, OpMethod *m, const Twine &methodName,
+                          const Operator &op) {
+  if (m)
+    return;
+  PrintFatalError(op.getLoc(), "Unexpected overlap when generating `" +
+                                   methodName + "` for " +
+                                   op.getOperationName() + " (from line " +
+                                   Twine(line) + ")");
+}
+#define ERROR_IF_PRUNED(M, N, O) errorIfPruned(__LINE__, M, N, O)
+
 void OpEmitter::genAttrNameGetters() {
   // Enumerate the attribute names of this op, assigning each a relative
   // ordering.
@@ -544,6 +612,7 @@ void OpEmitter::genAttrNameGetters() {
     auto *method = opClass.addMethodAndPrune(
         "::llvm::ArrayRef<::llvm::StringRef>", "getAttributeNames",
         OpMethod::Property(OpMethod::MP_Static | OpMethod::MP_Inline));
+    ERROR_IF_PRUNED(method, "getAttributeNames", op);
     auto &body = method->body();
     if (attributeNames.empty()) {
       body << "  return {};";
@@ -566,6 +635,7 @@ void OpEmitter::genAttrNameGetters() {
         "::mlir::Identifier", "getAttributeNameForIndex",
         OpMethod::Property(OpMethod::MP_Private | OpMethod::MP_Inline),
         "unsigned", "index");
+    ERROR_IF_PRUNED(method, "getAttributeNameForIndex", op);
     method->body()
         << "  return getAttributeNameForIndex((*this)->getName(), index);";
   }
@@ -575,6 +645,7 @@ void OpEmitter::genAttrNameGetters() {
         OpMethod::Property(OpMethod::MP_Private | OpMethod::MP_Inline |
                            OpMethod::MP_Static),
         "::mlir::OperationName name, unsigned index");
+    ERROR_IF_PRUNED(method, "getAttributeNameForIndex", op);
     method->body() << "assert(index < " << attributeNames.size()
                    << " && \"invalid attribute index\");\n"
                       "  return name.getAbstractOperation()"
@@ -585,25 +656,30 @@ void OpEmitter::genAttrNameGetters() {
   // users.
   const char *attrNameMethodBody = "  return getAttributeNameForIndex({0});";
   for (const std::pair<StringRef, unsigned> &attrIt : attributeNames) {
-    std::string methodName = (attrIt.first + "AttrName").str();
+    for (StringRef name : getGetterNames(op, attrIt.first)) {
+      std::string methodName = (name + "AttrName").str();
 
-    // Generate the non-static variant.
-    {
-      auto *method =
-          opClass.addMethodAndPrune("::mlir::Identifier", methodName,
-                                    OpMethod::Property(OpMethod::MP_Inline));
-      method->body() << llvm::formatv(attrNameMethodBody, attrIt.second).str();
-    }
+      // Generate the non-static variant.
+      {
+        auto *method =
+            opClass.addMethodAndPrune("::mlir::Identifier", methodName,
+                                      OpMethod::Property(OpMethod::MP_Inline));
+        ERROR_IF_PRUNED(method, methodName, op);
+        method->body()
+            << llvm::formatv(attrNameMethodBody, attrIt.second).str();
+      }
 
-    // Generate the static variant.
-    {
-      auto *method = opClass.addMethodAndPrune(
-          "::mlir::Identifier", methodName,
-          OpMethod::Property(OpMethod::MP_Inline | OpMethod::MP_Static),
-          "::mlir::OperationName", "name");
-      method->body() << llvm::formatv(attrNameMethodBody,
-                                      "name, " + Twine(attrIt.second))
-                            .str();
+      // Generate the static variant.
+      {
+        auto *method = opClass.addMethodAndPrune(
+            "::mlir::Identifier", methodName,
+            OpMethod::Property(OpMethod::MP_Inline | OpMethod::MP_Static),
+            "::mlir::OperationName", "name");
+        ERROR_IF_PRUNED(method, methodName, op);
+        method->body() << llvm::formatv(attrNameMethodBody,
+                                        "name, " + Twine(attrIt.second))
+                              .str();
+      }
     }
   }
 }
@@ -621,6 +697,7 @@ void OpEmitter::genAttrGetters() {
   // Emit with return type specified.
   auto emitAttrWithReturnType = [&](StringRef name, Attribute attr) {
     auto *method = opClass.addMethodAndPrune(attr.getReturnType(), name);
+    ERROR_IF_PRUNED(method, name, op);
     auto &body = method->body();
     body << "  auto attr = " << name << "Attr();\n";
     if (attr.hasDefaultValue()) {
@@ -639,9 +716,9 @@ void OpEmitter::genAttrGetters() {
          << ";\n";
   };
 
-  // Generate raw named accessor type. This is a wrapper class that allows
-  // referring to the attributes via accessors instead of having to use
-  // the string interface for better compile time verification.
+  // Generate named accessor with Attribute return type. This is a wrapper class
+  // that allows referring to the attributes via accessors instead of having to
+  // use the string interface for better compile time verification.
   auto emitAttrWithStorageType = [&](StringRef name, Attribute attr) {
     auto *method =
         opClass.addMethodAndPrune(attr.getStorageType(), (name + "Attr").str());
@@ -657,11 +734,13 @@ void OpEmitter::genAttrGetters() {
   };
 
   for (const NamedAttribute &namedAttr : op.getAttributes()) {
-    if (namedAttr.attr.isDerivedAttr()) {
-      emitDerivedAttr(namedAttr.name, namedAttr.attr);
-    } else {
-      emitAttrWithStorageType(namedAttr.name, namedAttr.attr);
-      emitAttrWithReturnType(namedAttr.name, namedAttr.attr);
+    for (StringRef name : getGetterNames(op, namedAttr.name)) {
+      if (namedAttr.attr.isDerivedAttr()) {
+        emitDerivedAttr(name, namedAttr.attr);
+      } else {
+        emitAttrWithStorageType(name, namedAttr.attr);
+        emitAttrWithReturnType(name, namedAttr.attr);
+      }
     }
   }
 
@@ -678,6 +757,7 @@ void OpEmitter::genAttrGetters() {
       auto *method = opClass.addMethodAndPrune("bool", "isDerivedAttribute",
                                                OpMethod::MP_Static,
                                                "::llvm::StringRef", "name");
+      ERROR_IF_PRUNED(method, "isDerivedAttribute", op);
       auto &body = method->body();
       for (auto namedAttr : derivedAttrs)
         body << "  if (name == \"" << namedAttr.name << "\") return true;\n";
@@ -687,6 +767,7 @@ void OpEmitter::genAttrGetters() {
     {
       auto *method = opClass.addMethodAndPrune("::mlir::DictionaryAttr",
                                                "materializeDerivedAttributes");
+      ERROR_IF_PRUNED(method, "materializeDerivedAttributes", op);
       auto &body = method->body();
 
       auto nonMaterializable =
@@ -734,16 +815,22 @@ void OpEmitter::genAttrSetters() {
   // Generate raw named setter type. This is a wrapper class that allows setting
   // to the attributes via setters instead of having to use the string interface
   // for better compile time verification.
-  auto emitAttrWithStorageType = [&](StringRef name, Attribute attr) {
-    auto *method = opClass.addMethodAndPrune("void", (name + "Attr").str(),
-                                             attr.getStorageType(), "attr");
+  auto emitAttrWithStorageType = [&](StringRef setterName, StringRef getterName,
+                                     Attribute attr) {
+    auto *method = opClass.addMethodAndPrune(
+        "void", (setterName + "Attr").str(), attr.getStorageType(), "attr");
     if (method)
-      method->body() << "  (*this)->setAttr(" << name << "AttrName(), attr);";
+      method->body() << "  (*this)->setAttr(" << getterName
+                     << "AttrName(), attr);";
   };
 
-  for (const NamedAttribute &namedAttr : op.getAttributes())
+  for (const NamedAttribute &namedAttr : op.getAttributes()) {
     if (!namedAttr.attr.isDerivedAttr())
-      emitAttrWithStorageType(namedAttr.name, namedAttr.attr);
+      for (auto names : llvm::zip(getSetterNames(op, namedAttr.name),
+                                  getGetterNames(op, namedAttr.name)))
+        emitAttrWithStorageType(std::get<0>(names), std::get<1>(names),
+                                namedAttr.attr);
+  }
 }
 
 void OpEmitter::genOptionalAttrRemovers() {
@@ -756,7 +843,8 @@ void OpEmitter::genOptionalAttrRemovers() {
         "::mlir::Attribute", ("remove" + upperInitial + suffix + "Attr").str());
     if (!method)
       return;
-    method->body() << "  return (*this)->removeAttr(" << name << "AttrName());";
+    method->body() << "  return (*this)->removeAttr(" << getGetterName(op, name)
+                   << "AttrName());";
   };
 
   for (const NamedAttribute &namedAttr : op.getAttributes())
@@ -846,6 +934,7 @@ static void generateNamedOperandGetters(const Operator &op, Class &opClass,
 
   auto *m = opClass.addMethodAndPrune(rangeType, "getODSOperands", "unsigned",
                                       "index");
+  ERROR_IF_PRUNED(m, "getODSOperands", op);
   auto &body = m->body();
   body << formatv(valueRangeReturnCode, rangeBeginCall,
                   "getODSOperandIndexAndLength(index)");
@@ -856,31 +945,38 @@ static void generateNamedOperandGetters(const Operator &op, Class &opClass,
     const auto &operand = op.getOperand(i);
     if (operand.name.empty())
       continue;
-    if (operand.isOptional()) {
-      m = opClass.addMethodAndPrune("::mlir::Value", operand.name);
-      m->body()
-          << "  auto operands = getODSOperands(" << i << ");\n"
-          << "  return operands.empty() ? ::mlir::Value() : *operands.begin();";
-    } else if (operand.isVariadicOfVariadic()) {
-      StringRef segmentAttr =
-          operand.constraint.getVariadicOfVariadicSegmentSizeAttr();
-      if (isAdaptor) {
-        m = opClass.addMethodAndPrune("::llvm::SmallVector<::mlir::ValueRange>",
-                                      operand.name);
-        m->body() << llvm::formatv(variadicOfVariadicAdaptorCalcCode,
-                                   segmentAttr, i);
-        continue;
-      }
+    for (StringRef name : getGetterNames(op, operand.name)) {
+      if (operand.isOptional()) {
+        m = opClass.addMethodAndPrune("::mlir::Value", name);
+        ERROR_IF_PRUNED(m, name, op);
+        m->body() << "  auto operands = getODSOperands(" << i << ");\n"
+                  << "  return operands.empty() ? ::mlir::Value() : "
+                     "*operands.begin();";
+      } else if (operand.isVariadicOfVariadic()) {
+        StringRef segmentAttr =
+            operand.constraint.getVariadicOfVariadicSegmentSizeAttr();
+        if (isAdaptor) {
+          m = opClass.addMethodAndPrune(
+              "::llvm::SmallVector<::mlir::ValueRange>", name);
+          ERROR_IF_PRUNED(m, name, op);
+          m->body() << llvm::formatv(variadicOfVariadicAdaptorCalcCode,
+                                     segmentAttr, i);
+          continue;
+        }
 
-      m = opClass.addMethodAndPrune("::mlir::OperandRangeRange", operand.name);
-      m->body() << "  return getODSOperands(" << i << ").split(" << segmentAttr
-                << "Attr());";
-    } else if (operand.isVariadic()) {
-      m = opClass.addMethodAndPrune(rangeType, operand.name);
-      m->body() << "  return getODSOperands(" << i << ");";
-    } else {
-      m = opClass.addMethodAndPrune("::mlir::Value", operand.name);
-      m->body() << "  return *getODSOperands(" << i << ").begin();";
+        m = opClass.addMethodAndPrune("::mlir::OperandRangeRange", name);
+        ERROR_IF_PRUNED(m, name, op);
+        m->body() << "  return getODSOperands(" << i << ").split("
+                  << segmentAttr << "Attr());";
+      } else if (operand.isVariadic()) {
+        m = opClass.addMethodAndPrune(rangeType, name);
+        ERROR_IF_PRUNED(m, name, op);
+        m->body() << "  return getODSOperands(" << i << ");";
+      } else {
+        m = opClass.addMethodAndPrune("::mlir::Value", name);
+        ERROR_IF_PRUNED(m, name, op);
+        m->body() << "  return *getODSOperands(" << i << ").begin();";
+      }
     }
   }
 }
@@ -912,31 +1008,38 @@ void OpEmitter::genNamedOperandSetters() {
     const auto &operand = op.getOperand(i);
     if (operand.name.empty())
       continue;
-    auto *m = opClass.addMethodAndPrune(operand.isVariadicOfVariadic()
-                                            ? "::mlir::MutableOperandRangeRange"
-                                            : "::mlir::MutableOperandRange",
-                                        (operand.name + "Mutable").str());
-    auto &body = m->body();
-    body << "  auto range = getODSOperandIndexAndLength(" << i << ");\n"
-         << "  auto mutableRange = ::mlir::MutableOperandRange(getOperation(), "
-            "range.first, range.second";
-    if (attrSizedOperands)
-      body << ", ::mlir::MutableOperandRange::OperandSegment(" << i
-           << "u, *getOperation()->getAttrDictionary().getNamed("
-              "operand_segment_sizesAttrName()))";
-    body << ");\n";
-
-    // If this operand is a nested variadic, we split the range into a
-    // MutableOperandRangeRange that provides a range over all of the
-    // sub-ranges.
-    if (operand.isVariadicOfVariadic()) {
-      body << "  return "
-              "mutableRange.split(*(*this)->getAttrDictionary().getNamed("
-           << operand.constraint.getVariadicOfVariadicSegmentSizeAttr()
-           << "AttrName()));\n";
-    } else {
-      // Otherwise, we use the full range directly.
-      body << "  return mutableRange;\n";
+    for (StringRef name : getGetterNames(op, operand.name)) {
+      auto *m = opClass.addMethodAndPrune(
+          operand.isVariadicOfVariadic() ? "::mlir::MutableOperandRangeRange"
+                                         : "::mlir::MutableOperandRange",
+          (name + "Mutable").str());
+      ERROR_IF_PRUNED(m, name, op);
+      auto &body = m->body();
+      body << "  auto range = getODSOperandIndexAndLength(" << i << ");\n"
+           << "  auto mutableRange = "
+              "::mlir::MutableOperandRange(getOperation(), "
+              "range.first, range.second";
+      if (attrSizedOperands)
+        body << ", ::mlir::MutableOperandRange::OperandSegment(" << i
+             << "u, *getOperation()->getAttrDictionary().getNamed("
+                "operand_segment_sizesAttrName()))";
+      body << ");\n";
+
+      // If this operand is a nested variadic, we split the range into a
+      // MutableOperandRangeRange that provides a range over all of the
+      // sub-ranges.
+      if (operand.isVariadicOfVariadic()) {
+        //
+        body << "  return "
+                "mutableRange.split(*(*this)->getAttrDictionary().getNamed("
+             << getGetterName(
+                    op,
+                    operand.constraint.getVariadicOfVariadicSegmentSizeAttr())
+             << "AttrName()));\n";
+      } else {
+        // Otherwise, we use the full range directly.
+        body << "  return mutableRange;\n";
+      }
     }
   }
 }
@@ -985,6 +1088,7 @@ void OpEmitter::genNamedResultGetters() {
 
   auto *m = opClass.addMethodAndPrune("::mlir::Operation::result_range",
                                       "getODSResults", "unsigned", "index");
+  ERROR_IF_PRUNED(m, "getODSResults", op);
   m->body() << formatv(valueRangeReturnCode, "getOperation()->result_begin()",
                        "getODSResultIndexAndLength(index)");
 
@@ -992,18 +1096,22 @@ void OpEmitter::genNamedResultGetters() {
     const auto &result = op.getResult(i);
     if (result.name.empty())
       continue;
-    if (result.isOptional()) {
-      m = opClass.addMethodAndPrune("::mlir::Value", result.name);
-      m->body()
-          << "  auto results = getODSResults(" << i << ");\n"
-          << "  return results.empty() ? ::mlir::Value() : *results.begin();";
-    } else if (result.isVariadic()) {
-      m = opClass.addMethodAndPrune("::mlir::Operation::result_range",
-                                    result.name);
-      m->body() << "  return getODSResults(" << i << ");";
-    } else {
-      m = opClass.addMethodAndPrune("::mlir::Value", result.name);
-      m->body() << "  return *getODSResults(" << i << ").begin();";
+    for (StringRef name : getGetterNames(op, result.name)) {
+      if (result.isOptional()) {
+        m = opClass.addMethodAndPrune("::mlir::Value", name);
+        ERROR_IF_PRUNED(m, name, op);
+        m->body()
+            << "  auto results = getODSResults(" << i << ");\n"
+            << "  return results.empty() ? ::mlir::Value() : *results.begin();";
+      } else if (result.isVariadic()) {
+        m = opClass.addMethodAndPrune("::mlir::Operation::result_range", name);
+        ERROR_IF_PRUNED(m, name, op);
+        m->body() << "  return getODSResults(" << i << ");";
+      } else {
+        m = opClass.addMethodAndPrune("::mlir::Value", name);
+        ERROR_IF_PRUNED(m, name, op);
+        m->body() << "  return *getODSResults(" << i << ").begin();";
+      }
     }
   }
 }
@@ -1015,17 +1123,21 @@ void OpEmitter::genNamedRegionGetters() {
     if (region.name.empty())
       continue;
 
-    // Generate the accessors for a variadic region.
-    if (region.isVariadic()) {
-      auto *m = opClass.addMethodAndPrune(
-          "::mlir::MutableArrayRef<::mlir::Region>", region.name);
-      m->body() << formatv("  return (*this)->getRegions().drop_front({0});",
-                           i);
-      continue;
-    }
+    for (StringRef name : getGetterNames(op, region.name)) {
+      // Generate the accessors for a variadic region.
+      if (region.isVariadic()) {
+        auto *m = opClass.addMethodAndPrune(
+            "::mlir::MutableArrayRef<::mlir::Region>", name);
+        ERROR_IF_PRUNED(m, name, op);
+        m->body() << formatv("  return (*this)->getRegions().drop_front({0});",
+                             i);
+        continue;
+      }
 
-    auto *m = opClass.addMethodAndPrune("::mlir::Region &", region.name);
-    m->body() << formatv("  return (*this)->getRegion({0});", i);
+      auto *m = opClass.addMethodAndPrune("::mlir::Region &", name);
+      ERROR_IF_PRUNED(m, name, op);
+      m->body() << formatv("  return (*this)->getRegion({0});", i);
+    }
   }
 }
 
@@ -1036,19 +1148,22 @@ void OpEmitter::genNamedSuccessorGetters() {
     if (successor.name.empty())
       continue;
 
-    // Generate the accessors for a variadic successor list.
-    if (successor.isVariadic()) {
-      auto *m =
-          opClass.addMethodAndPrune("::mlir::SuccessorRange", successor.name);
-      m->body() << formatv(
-          "  return {std::next((*this)->successor_begin(), {0}), "
-          "(*this)->successor_end()};",
-          i);
-      continue;
-    }
+    for (StringRef name : getGetterNames(op, successor.name)) {
+      // Generate the accessors for a variadic successor list.
+      if (successor.isVariadic()) {
+        auto *m = opClass.addMethodAndPrune("::mlir::SuccessorRange", name);
+        ERROR_IF_PRUNED(m, name, op);
+        m->body() << formatv(
+            "  return {std::next((*this)->successor_begin(), {0}), "
+            "(*this)->successor_end()};",
+            i);
+        continue;
+      }
 
-    auto *m = opClass.addMethodAndPrune("::mlir::Block *", successor.name);
-    m->body() << formatv("  return (*this)->getSuccessor({0});", i);
+      auto *m = opClass.addMethodAndPrune("::mlir::Block *", name);
+      ERROR_IF_PRUNED(m, name, op);
+      m->body() << formatv("  return (*this)->getSuccessor({0});", i);
+    }
   }
 }
 
@@ -1315,8 +1430,8 @@ void OpEmitter::genUseAttrAsResultTypeBuilder() {
   std::string resultType;
   const auto &namedAttr = op.getAttribute(0);
 
-  body << "  auto attrName = " << namedAttr.name << "AttrName("
-       << builderOpState
+  body << "  auto attrName = " << getGetterName(op, namedAttr.name)
+       << "AttrName(" << builderOpState
        << ".name);\n"
           "  for (auto attr : attributes) {\n"
           "    if (attr.first != attrName) continue;\n";
@@ -1379,6 +1494,8 @@ void OpEmitter::genBuilder() {
         body ? OpMethod::MP_Static : OpMethod::MP_StaticDeclaration;
     auto *method =
         opClass.addMethodAndPrune("void", "build", properties, paramStr);
+    if (body)
+      ERROR_IF_PRUNED(method, "build", op);
 
     FmtContext fctx;
     fctx.withBuilder(odsBuilder);
@@ -1634,7 +1751,8 @@ void OpEmitter::genCodeForAddingArgAndRegionForBuilder(
            << "    for (::mlir::ValueRange range : " << argName << ")\n"
            << "      rangeSegments.push_back(range.size());\n"
            << "    " << builderOpState << ".addAttribute("
-           << operand.constraint.getVariadicOfVariadicSegmentSizeAttr()
+           << getGetterName(
+                  op, operand.constraint.getVariadicOfVariadicSegmentSizeAttr())
            << "AttrName(" << builderOpState << ".name), " << odsBuilder
            << ".getI32TensorAttr(rangeSegments));"
            << "  }\n";
@@ -1703,10 +1821,11 @@ void OpEmitter::genCodeForAddingArgAndRegionForBuilder(
       std::string value =
           std::string(tgfmt(builderTemplate, &fctx, namedAttr.name));
       body << formatv("  {0}.addAttribute({1}AttrName({0}.name), {2});\n",
-                      builderOpState, namedAttr.name, value);
+                      builderOpState, getGetterName(op, namedAttr.name), value);
     } else {
-      body << formatv("  {0}.addAttribute({1}AttrName({0}.name), {1});\n",
-                      builderOpState, namedAttr.name);
+      body << formatv("  {0}.addAttribute({1}AttrName({0}.name), {2});\n",
+                      builderOpState, getGetterName(op, namedAttr.name),
+                      namedAttr.name);
     }
     if (emitNotNullCheck)
       body << "  }\n";
@@ -1736,9 +1855,10 @@ void OpEmitter::genCanonicalizerDecls() {
     SmallVector<OpMethodParameter, 2> paramList;
     paramList.emplace_back(op.getCppClassName(), "op");
     paramList.emplace_back("::mlir::PatternRewriter &", "rewriter");
-    opClass.addMethodAndPrune("::mlir::LogicalResult", "canonicalize",
-                              OpMethod::MP_StaticDeclaration,
-                              std::move(paramList));
+    auto *m = opClass.addMethodAndPrune("::mlir::LogicalResult", "canonicalize",
+                                        OpMethod::MP_StaticDeclaration,
+                                        std::move(paramList));
+    ERROR_IF_PRUNED(m, "canonicalize", op);
   }
 
   // We get a prototype for 'getCanonicalizationPatterns' if requested directly
@@ -1761,8 +1881,10 @@ void OpEmitter::genCanonicalizerDecls() {
       "void", "getCanonicalizationPatterns", kind, std::move(paramList));
 
   // If synthesizing the method, fill it it.
-  if (hasBody)
+  if (hasBody) {
+    ERROR_IF_PRUNED(method, "getCanonicalizationPatterns", op);
     method->body() << "  results.add(canonicalize);\n";
+  }
 }
 
 void OpEmitter::genFolderDecls() {
@@ -1771,16 +1893,19 @@ void OpEmitter::genFolderDecls() {
 
   if (def.getValueAsBit("hasFolder")) {
     if (hasSingleResult) {
-      opClass.addMethodAndPrune(
+      auto *m = opClass.addMethodAndPrune(
           "::mlir::OpFoldResult", "fold", OpMethod::MP_Declaration,
           "::llvm::ArrayRef<::mlir::Attribute>", "operands");
+      ERROR_IF_PRUNED(m, "operands", op);
     } else {
       SmallVector<OpMethodParameter, 2> paramList;
       paramList.emplace_back("::llvm::ArrayRef<::mlir::Attribute>", "operands");
       paramList.emplace_back("::llvm::SmallVectorImpl<::mlir::OpFoldResult> &",
                              "results");
-      opClass.addMethodAndPrune("::mlir::LogicalResult", "fold",
-                                OpMethod::MP_Declaration, std::move(paramList));
+      auto *m = opClass.addMethodAndPrune("::mlir::LogicalResult", "fold",
+                                          OpMethod::MP_Declaration,
+                                          std::move(paramList));
+      ERROR_IF_PRUNED(m, "fold", op);
     }
   }
 }
@@ -1803,7 +1928,9 @@ void OpEmitter::genOpInterfaceMethods(const tblgen::InterfaceTrait *opTrait) {
     if (method.getDefaultImplementation() &&
         !alwaysDeclaredMethods.count(method.getName()))
       continue;
-    genOpInterfaceMethod(method);
+    // Interface methods are allowed to overlap with existing methods, so don't
+    // check if pruned.
+    (void)genOpInterfaceMethod(method);
   }
 }
 
@@ -1895,6 +2022,7 @@ void OpEmitter::genSideEffectInterfaceMethods() {
                            .str();
     auto *getEffects =
         opClass.addMethodAndPrune("void", "getEffects", type, "effects");
+    ERROR_IF_PRUNED(getEffects, "getEffects", op);
     auto &body = getEffects->body();
 
     // Add effect instances for each of the locations marked on the operation.
@@ -1944,6 +2072,7 @@ void OpEmitter::genTypeInterfaceMethods() {
     assert(0 && "unable to find inferReturnTypes interface method");
     return nullptr;
   }();
+  ERROR_IF_PRUNED(method, "inferReturnTypes", op);
   auto &body = method->body();
   body << "  inferredReturnTypes.resize(" << op.getNumResults() << ");\n";
 
@@ -1989,6 +2118,7 @@ void OpEmitter::genParser() {
   auto *method =
       opClass.addMethodAndPrune("::mlir::ParseResult", "parse",
                                 OpMethod::MP_Static, std::move(paramList));
+  ERROR_IF_PRUNED(method, "parse", op);
 
   FmtContext fctx;
   fctx.addSubst("cppClass", opClass.getClassName());
@@ -2007,6 +2137,7 @@ void OpEmitter::genPrinter() {
 
   auto *method =
       opClass.addMethodAndPrune("void", "print", "::mlir::OpAsmPrinter &", "p");
+  ERROR_IF_PRUNED(method, "print", op);
   FmtContext fctx;
   fctx.addSubst("cppClass", opClass.getClassName());
   auto printer = stringInit->getValue().ltrim().rtrim(" \t\v\f\r");
@@ -2015,6 +2146,7 @@ void OpEmitter::genPrinter() {
 
 void OpEmitter::genVerifier() {
   auto *method = opClass.addMethodAndPrune("::mlir::LogicalResult", "verify");
+  ERROR_IF_PRUNED(method, "verify", op);
   auto &body = method->body();
   body << "  if (::mlir::failed(" << op.getAdaptorName()
        << "(*this).verify((*this)->getLoc()))) "
@@ -2274,6 +2406,7 @@ void OpEmitter::genOpNameGetter() {
   auto *method = opClass.addMethodAndPrune(
       "::llvm::StringLiteral", "getOperationName",
       OpMethod::Property(OpMethod::MP_Static | OpMethod::MP_Constexpr));
+  ERROR_IF_PRUNED(method, "getOperationName", op);
   method->body() << "  return ::llvm::StringLiteral(\"" << op.getOperationName()
                  << "\");";
 }
@@ -2301,6 +2434,7 @@ void OpEmitter::genOpAsmInterface() {
   // Generate the right accessor for the number of results.
   auto *method = opClass.addMethodAndPrune(
       "void", "getAsmResultNames", "::mlir::OpAsmSetValueNameFn", "setNameFn");
+  ERROR_IF_PRUNED(method, "getAsmResultNames", op);
   auto &body = method->body();
   for (int i = 0; i != numResults; ++i) {
     body << "  auto resultGroup" << i << " = getODSResults(" << i << ");\n"
@@ -2365,6 +2499,7 @@ OpOperandAdaptorEmitter::OpOperandAdaptorEmitter(const Operator &op)
 
   {
     auto *m = adaptor.addMethodAndPrune("::mlir::ValueRange", "getOperands");
+    ERROR_IF_PRUNED(m, "getOperands", op);
     m->body() << "  return odsOperands;";
   }
   std::string sizeAttrInit =
@@ -2380,7 +2515,9 @@ OpOperandAdaptorEmitter::OpOperandAdaptorEmitter(const Operator &op)
   fctx.withBuilder("::mlir::Builder(odsAttrs.getContext())");
 
   auto emitAttr = [&](StringRef name, Attribute attr) {
-    auto &body = adaptor.addMethodAndPrune(attr.getStorageType(), name)->body();
+    auto *method = adaptor.addMethodAndPrune(attr.getStorageType(), name);
+    ERROR_IF_PRUNED(method, "Adaptor::" + name, op);
+    auto &body = method->body();
     body << "  assert(odsAttrs && \"no attributes when constructing adapter\");"
          << "\n  " << attr.getStorageType() << " attr = "
          << "odsAttrs.get(\"" << name << "\").";
@@ -2404,6 +2541,7 @@ OpOperandAdaptorEmitter::OpOperandAdaptorEmitter(const Operator &op)
   {
     auto *m =
         adaptor.addMethodAndPrune("::mlir::DictionaryAttr", "getAttributes");
+    ERROR_IF_PRUNED(m, "Adaptor::getAttributes", op);
     m->body() << "  return odsAttrs;";
   }
   for (auto &namedAttr : op.getAttributes()) {
@@ -2416,6 +2554,7 @@ OpOperandAdaptorEmitter::OpOperandAdaptorEmitter(const Operator &op)
   unsigned numRegions = op.getNumRegions();
   if (numRegions > 0) {
     auto *m = adaptor.addMethodAndPrune("::mlir::RegionRange", "getRegions");
+    ERROR_IF_PRUNED(m, "Adaptor::getRegions", op);
     m->body() << "  return odsRegions;";
   }
   for (unsigned i = 0; i < numRegions; ++i) {
@@ -2426,11 +2565,13 @@ OpOperandAdaptorEmitter::OpOperandAdaptorEmitter(const Operator &op)
     // Generate the accessors for a variadic region.
     if (region.isVariadic()) {
       auto *m = adaptor.addMethodAndPrune("::mlir::RegionRange", region.name);
+      ERROR_IF_PRUNED(m, "Adaptor::" + region.name, op);
       m->body() << formatv("  return odsRegions.drop_front({0});", i);
       continue;
     }
 
     auto *m = adaptor.addMethodAndPrune("::mlir::Region &", region.name);
+    ERROR_IF_PRUNED(m, "Adaptor::" + region.name, op);
     m->body() << formatv("  return *odsRegions[{0}];", i);
   }
 
@@ -2441,6 +2582,7 @@ OpOperandAdaptorEmitter::OpOperandAdaptorEmitter(const Operator &op)
 void OpOperandAdaptorEmitter::addVerification() {
   auto *method = adaptor.addMethodAndPrune("::mlir::LogicalResult", "verify",
                                            "::mlir::Location", "loc");
+  ERROR_IF_PRUNED(method, "verify", op);
   auto &body = method->body();
 
   const char *checkAttrSizedValueSegmentsCode = R"(


        


More information about the Mlir-commits mailing list