[Mlir-commits] [mlir] 82140ad - [mlir] Add method to populate default attributes

Jacques Pienaar llvmlistbot at llvm.org
Fri Jul 8 11:31:18 PDT 2022


Author: Jacques Pienaar
Date: 2022-07-08T11:31:13-07:00
New Revision: 82140ad72814f5544e68643c64e528fa5b734fad

URL: https://github.com/llvm/llvm-project/commit/82140ad72814f5544e68643c64e528fa5b734fad
DIFF: https://github.com/llvm/llvm-project/commit/82140ad72814f5544e68643c64e528fa5b734fad.diff

LOG: [mlir] Add method to populate default attributes

Previously default attributes were only usable by way of the ODS generated
accessors, but this was undesirable as
1. The ODS getters could construct an Attribute on each get request;
2. For non-C++ uses this would require either duplicating some of the default
   attribute generation or generating additional bindings to generate methods;
3. Accessing op.getAttr("foo") and op.getFoo() would return different results.
Generate a method to populate default attributes that can be used to address
these.

This merely adds the facility; it is not yet employed by default on any path.

Differential Revision: https://reviews.llvm.org/D128962

Added: 
    

Modified: 
    mlir/include/mlir/IR/ExtensibleDialect.h
    mlir/include/mlir/IR/OpDefinition.h
    mlir/include/mlir/IR/Operation.h
    mlir/include/mlir/IR/OperationSupport.h
    mlir/lib/IR/ExtensibleDialect.cpp
    mlir/lib/IR/MLIRContext.cpp
    mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
    mlir/unittests/IR/OperationSupportTest.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/IR/ExtensibleDialect.h b/mlir/include/mlir/IR/ExtensibleDialect.h
index ee83ef5576df9..65f9f190f57ec 100644
--- a/mlir/include/mlir/IR/ExtensibleDialect.h
+++ b/mlir/include/mlir/IR/ExtensibleDialect.h
@@ -431,6 +431,7 @@ class DynamicOpDefinition {
   OperationName::PrintAssemblyFn printFn;
   OperationName::FoldHookFn foldHookFn;
   OperationName::GetCanonicalizationPatternsFn getCanonicalizationPatternsFn;
+  OperationName::PopulateDefaultAttrsFn getPopulateDefaultAttrsFn;
 
   friend ExtensibleDialect;
 };

diff  --git a/mlir/include/mlir/IR/OpDefinition.h b/mlir/include/mlir/IR/OpDefinition.h
index c98993a2cb93b..81b0603fa8995 100644
--- a/mlir/include/mlir/IR/OpDefinition.h
+++ b/mlir/include/mlir/IR/OpDefinition.h
@@ -182,6 +182,10 @@ class OpState {
   static void getCanonicalizationPatterns(RewritePatternSet &results,
                                           MLIRContext *context) {}
 
+  /// This hook populates any unset default attrs.
+  static void populateDefaultAttrs(const RegisteredOperationName &,
+                                   NamedAttrList &) {}
+
 protected:
   /// If the concrete type didn't implement a custom verifier hook, just fall
   /// back to this one which accepts everything.
@@ -1869,6 +1873,10 @@ class Op : public OpState, public Traits<ConcreteType>... {
     OpState::printOpName(op, p, defaultDialect);
     return cast<ConcreteType>(op).print(p);
   }
+  /// Implementation of `PopulateDefaultAttrsFn` OperationName hook.
+  static OperationName::PopulateDefaultAttrsFn getPopulateDefaultAttrsFn() {
+    return ConcreteType::populateDefaultAttrs;
+  }
   /// Implementation of `VerifyInvariantsFn` OperationName hook.
   static LogicalResult verifyInvariants(Operation *op) {
     static_assert(hasNoDataMembers(),

diff  --git a/mlir/include/mlir/IR/Operation.h b/mlir/include/mlir/IR/Operation.h
index d6a231b1941e7..70509bdb2c515 100644
--- a/mlir/include/mlir/IR/Operation.h
+++ b/mlir/include/mlir/IR/Operation.h
@@ -467,6 +467,15 @@ class alignas(8) Operation final
     setAttrs(attrs.getDictionary(getContext()));
   }
 
+  /// Sets default attributes on unset attributes.
+  void populateDefaultAttrs() {
+    if (auto registered = getRegisteredInfo()) {
+      NamedAttrList attrs(getAttrDictionary());
+      registered->populateDefaultAttrs(attrs);
+      setAttrs(attrs.getDictionary(getContext()));
+    }
+  }
+
   //===--------------------------------------------------------------------===//
   // Blocks
   //===--------------------------------------------------------------------===//

diff  --git a/mlir/include/mlir/IR/OperationSupport.h b/mlir/include/mlir/IR/OperationSupport.h
index de09ca58a4092..2c480d6ca52da 100644
--- a/mlir/include/mlir/IR/OperationSupport.h
+++ b/mlir/include/mlir/IR/OperationSupport.h
@@ -36,6 +36,7 @@ class Dialect;
 class DictionaryAttr;
 class ElementsAttr;
 class MutableOperandRangeRange;
+class NamedAttrList;
 class Operation;
 struct OperationState;
 class OpAsmParser;
@@ -69,6 +70,10 @@ class OperationName {
   using HasTraitFn = llvm::unique_function<bool(TypeID) const>;
   using ParseAssemblyFn =
       llvm::unique_function<ParseResult(OpAsmParser &, OperationState &) const>;
+  // Note: RegisteredOperationName is passed as reference here as the derived
+  // class is defined below.
+  using PopulateDefaultAttrsFn = llvm::unique_function<void(
+      const RegisteredOperationName &, NamedAttrList &) const>;
   using PrintAssemblyFn =
       llvm::unique_function<void(Operation *, OpAsmPrinter &, StringRef) const>;
   using VerifyInvariantsFn =
@@ -112,6 +117,7 @@ class OperationName {
     GetCanonicalizationPatternsFn getCanonicalizationPatternsFn;
     HasTraitFn hasTraitFn;
     ParseAssemblyFn parseAssemblyFn;
+    PopulateDefaultAttrsFn populateDefaultAttrsFn;
     PrintAssemblyFn printAssemblyFn;
     VerifyInvariantsFn verifyInvariantsFn;
     VerifyRegionInvariantsFn verifyRegionInvariantsFn;
@@ -254,7 +260,8 @@ class RegisteredOperationName : public OperationName {
            T::getParseAssemblyFn(), T::getPrintAssemblyFn(),
            T::getVerifyInvariantsFn(), T::getVerifyRegionInvariantsFn(),
            T::getFoldHookFn(), T::getGetCanonicalizationPatternsFn(),
-           T::getInterfaceMap(), T::getHasTraitFn(), T::getAttributeNames());
+           T::getInterfaceMap(), T::getHasTraitFn(), T::getAttributeNames(),
+           T::getPopulateDefaultAttrsFn());
   }
   /// The use of this method is in general discouraged in favor of
   /// 'insert<CustomOp>(dialect)'.
@@ -266,7 +273,8 @@ class RegisteredOperationName : public OperationName {
          FoldHookFn &&foldHook,
          GetCanonicalizationPatternsFn &&getCanonicalizationPatterns,
          detail::InterfaceMap &&interfaceMap, HasTraitFn &&hasTrait,
-         ArrayRef<StringRef> attrNames);
+         ArrayRef<StringRef> attrNames,
+         PopulateDefaultAttrsFn &&populateDefaultAttrs);
 
   /// Return the dialect this operation is registered to.
   Dialect &getDialect() const { return *impl->dialect; }
@@ -364,6 +372,10 @@ class RegisteredOperationName : public OperationName {
     return impl->attributeNames;
   }
 
+  /// This hook implements the method to populate default attributes that are
+  /// unset.
+  void populateDefaultAttrs(NamedAttrList &attrs) const;
+
   /// Represent the operation name as an opaque pointer. (Used to support
   /// PointerLikeTypeTraits).
   static RegisteredOperationName getFromOpaquePointer(const void *pointer) {

diff  --git a/mlir/lib/IR/ExtensibleDialect.cpp b/mlir/lib/IR/ExtensibleDialect.cpp
index 3e96b83031d29..0dcc971ca2e5a 100644
--- a/mlir/lib/IR/ExtensibleDialect.cpp
+++ b/mlir/lib/IR/ExtensibleDialect.cpp
@@ -447,7 +447,8 @@ void ExtensibleDialect::registerDynamicOp(
       std::move(op->printFn), std::move(op->verifyFn),
       std::move(op->verifyRegionFn), std::move(op->foldHookFn),
       std::move(op->getCanonicalizationPatternsFn),
-      detail::InterfaceMap::get<>(), std::move(hasTraitFn), {});
+      detail::InterfaceMap::get<>(), std::move(hasTraitFn), {},
+      std::move(op->getPopulateDefaultAttrsFn));
 }
 
 bool ExtensibleDialect::classof(const Dialect *dialect) {

diff  --git a/mlir/lib/IR/MLIRContext.cpp b/mlir/lib/IR/MLIRContext.cpp
index 2a84362635ac6..273faa89b826c 100644
--- a/mlir/lib/IR/MLIRContext.cpp
+++ b/mlir/lib/IR/MLIRContext.cpp
@@ -707,6 +707,10 @@ RegisteredOperationName::parseAssembly(OpAsmParser &parser,
   return impl->parseAssemblyFn(parser, result);
 }
 
+void RegisteredOperationName::populateDefaultAttrs(NamedAttrList &attrs) const {
+  impl->populateDefaultAttrsFn(*this, attrs);
+}
+
 void RegisteredOperationName::insert(
     StringRef name, Dialect &dialect, TypeID typeID,
     ParseAssemblyFn &&parseAssembly, PrintAssemblyFn &&printAssembly,
@@ -714,7 +718,8 @@ void RegisteredOperationName::insert(
     VerifyRegionInvariantsFn &&verifyRegionInvariants, FoldHookFn &&foldHook,
     GetCanonicalizationPatternsFn &&getCanonicalizationPatterns,
     detail::InterfaceMap &&interfaceMap, HasTraitFn &&hasTrait,
-    ArrayRef<StringRef> attrNames) {
+    ArrayRef<StringRef> attrNames,
+    PopulateDefaultAttrsFn &&populateDefaultAttrs) {
   MLIRContext *ctx = dialect.getContext();
   auto &ctxImpl = ctx->getImpl();
   assert(ctxImpl.multiThreadedExecutionContext == 0 &&
@@ -769,6 +774,7 @@ void RegisteredOperationName::insert(
   impl.verifyInvariantsFn = std::move(verifyInvariants);
   impl.verifyRegionInvariantsFn = std::move(verifyRegionInvariants);
   impl.attributeNames = cachedAttrNames;
+  impl.populateDefaultAttrsFn = std::move(populateDefaultAttrs);
 }
 
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
index dd3487bb2b0cd..3330fdf3c28a4 100644
--- a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
+++ b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
@@ -430,6 +430,9 @@ class OpEmitter {
   // Generates getters for named successors.
   void genNamedSuccessorGetters();
 
+  // Generates the method to populate default attributes.
+  void genPopulateDefaultAttributes();
+
   // Generates builder methods for the operation.
   void genBuilder();
 
@@ -823,6 +826,7 @@ OpEmitter::OpEmitter(const Operator &op,
   genAttrSetters();
   genOptionalAttrRemovers();
   genBuilder();
+  genPopulateDefaultAttributes();
   genParser();
   genPrinter();
   genVerifier();
@@ -1587,6 +1591,45 @@ void OpEmitter::genUseOperandAsResultTypeCollectiveParamBuilder() {
        << llvm::join(resultTypes, ", ") << "});\n\n";
 }
 
+void OpEmitter::genPopulateDefaultAttributes() {
+  // All done if no attributes have default values.
+  if (llvm::all_of(op.getAttributes(), [](const NamedAttribute &named) {
+        return !named.attr.hasDefaultValue();
+      }))
+    return;
+
+  SmallVector<MethodParameter> paramList;
+  paramList.emplace_back("const ::mlir::RegisteredOperationName &", "opName");
+  paramList.emplace_back("::mlir::NamedAttrList &", "attributes");
+  auto *m = opClass.addStaticMethod("void", "populateDefaultAttrs", paramList);
+  ERROR_IF_PRUNED(m, "populateDefaultAttrs", op);
+  auto &body = m->body();
+  body.indent();
+
+  // Set default attributes that are unset.
+  body << "auto attrNames = opName.getAttributeNames();\n";
+  body << "::mlir::Builder " << odsBuilder
+       << "(attrNames.front().getContext());\n";
+  StringMap<int> attrIndex;
+  for (const auto &it : llvm::enumerate(emitHelper.getAttrMetadata())) {
+    attrIndex[it.value().first] = it.index();
+  }
+  for (const NamedAttribute &namedAttr : op.getAttributes()) {
+    auto &attr = namedAttr.attr;
+    if (!attr.hasDefaultValue())
+      continue;
+    auto index = attrIndex[namedAttr.name];
+    body << "if (!attributes.get(attrNames[" << index << "])) {\n";
+    FmtContext fctx;
+    fctx.withBuilder(odsBuilder);
+    std::string defaultValue = std::string(
+        tgfmt(attr.getConstBuilderTemplate(), &fctx, attr.getDefaultValue()));
+    body.indent() << formatv(" attributes.append(attrNames[{0}], {1});\n",
+                             index, defaultValue);
+    body.unindent() << "}\n";
+  }
+}
+
 void OpEmitter::genInferredTypeCollectiveParamBuilder() {
   SmallVector<MethodParameter> paramList;
   paramList.emplace_back("::mlir::OpBuilder &", "odsBuilder");
@@ -1869,7 +1912,7 @@ void OpEmitter::buildParamList(SmallVectorImpl<MethodParameter> &paramList,
   auto numResults = op.getNumResults();
   resultTypeNames.reserve(numResults);
 
-  paramList.emplace_back("::mlir::OpBuilder &", "odsBuilder");
+  paramList.emplace_back("::mlir::OpBuilder &", odsBuilder);
   paramList.emplace_back("::mlir::OperationState &", builderOpState);
 
   switch (typeParamKind) {
@@ -2879,7 +2922,7 @@ OpOperandAdaptorEmitter::OpOperandAdaptorEmitter(
           tgfmt(attr.getConstBuilderTemplate(), &fctx, attr.getDefaultValue()));
       body << "  if (!attr)\n    attr = " << defaultValue << ";\n";
     }
-    body << "  return attr;\n";
+    body << "return attr;\n";
   };
 
   {

diff  --git a/mlir/unittests/IR/OperationSupportTest.cpp b/mlir/unittests/IR/OperationSupportTest.cpp
index 2511a5d3b6bfd..b8cbc6d1e6c0c 100644
--- a/mlir/unittests/IR/OperationSupportTest.cpp
+++ b/mlir/unittests/IR/OperationSupportTest.cpp
@@ -7,6 +7,7 @@
 //===----------------------------------------------------------------------===//
 
 #include "mlir/IR/OperationSupport.h"
+#include "../../test/lib/Dialect/Test/TestDialect.h"
 #include "mlir/IR/Builders.h"
 #include "mlir/IR/BuiltinTypes.h"
 #include "llvm/ADT/BitVector.h"
@@ -271,4 +272,22 @@ TEST(NamedAttrListTest, TestAppendAssign) {
   attrs.assign({});
   ASSERT_TRUE(attrs.empty());
 }
+
+TEST(OperandStorageTest, PopulateDefaultAttrs) {
+  MLIRContext context;
+  context.getOrLoadDialect<test::TestDialect>();
+  Builder builder(&context);
+
+  OpBuilder b(&context);
+  auto req1 = b.getI32IntegerAttr(10);
+  auto req2 = b.getI32IntegerAttr(60);
+  Operation *op = b.create<test::OpAttrMatch1>(b.getUnknownLoc(), req1, nullptr,
+                                               nullptr, req2);
+  EXPECT_EQ(op->getAttr("default_valued_attr"), nullptr);
+  op->populateDefaultAttrs();
+  auto opt = op->getAttr("default_valued_attr");
+  EXPECT_NE(opt, nullptr) << *op;
+
+  op->destroy();
+}
 } // namespace


        


More information about the Mlir-commits mailing list