[Mlir-commits] [mlir] 7fe2294 - [mlir][ods] Allow specifying return types of builders

Jeff Niu llvmlistbot at llvm.org
Fri Jul 15 18:00:42 PDT 2022


Author: Jeff Niu
Date: 2022-07-15T18:00:35-07:00
New Revision: 7fe2294e474bfcc4f9eee7828738d145836f9144

URL: https://github.com/llvm/llvm-project/commit/7fe2294e474bfcc4f9eee7828738d145836f9144
DIFF: https://github.com/llvm/llvm-project/commit/7fe2294e474bfcc4f9eee7828738d145836f9144.diff

LOG: [mlir][ods] Allow specifying return types of builders

This patch allows custom attribute and type builders to return
something other than the C++ type of the attribute or type.

This is useful for attributes or types that may perform extra work during
construction (e.g. canonicalization) that could result in a different
kind of attribute or type being returned.

Reviewed By: rriddle, lattner

Differential Revision: https://reviews.llvm.org/D129792

Added: 
    

Modified: 
    mlir/docs/AttributesAndTypes.md
    mlir/include/mlir/IR/AttrTypeBase.td
    mlir/include/mlir/IR/OpImplementation.h
    mlir/include/mlir/TableGen/AttrOrTypeDef.h
    mlir/lib/TableGen/AttrOrTypeDef.cpp
    mlir/test/lib/Dialect/Test/TestAttrDefs.td
    mlir/test/mlir-tblgen/attr-or-type-format.td
    mlir/test/mlir-tblgen/attrdefs.td
    mlir/test/mlir-tblgen/testdialect-attrdefs.mlir
    mlir/tools/mlir-tblgen/AttrOrTypeDefGen.cpp
    mlir/tools/mlir-tblgen/AttrOrTypeFormatGen.cpp

Removed: 
    


################################################################################
diff --git a/mlir/docs/AttributesAndTypes.md b/mlir/docs/AttributesAndTypes.md
index e1ad13952fe0b..e7cd74c9315af 100644
--- a/mlir/docs/AttributesAndTypes.md
+++ b/mlir/docs/AttributesAndTypes.md
@@ -347,6 +347,7 @@ def MyType : ... {
       // its arguments.
       return Base::get(typeParam.getContext(), ...);
     }]>,
+    TypeBuilder<(ins "int":$intParam), [{}], "IntegerType">,
   ];
 }
 ```
@@ -461,6 +462,28 @@ the builder used `TypeBuilderWithInferredContext` implies that the context
 parameter is not necessary as it can be inferred from the arguments to the
 builder.
 
+The fifth builder will generate the declaration of a builder method with a
+custom return type, like:
+
+```tablegen
+  let builders = [
+    TypeBuilder<(ins "int":$intParam), [{}], "IntegerType">,
+  ]
+```
+
+```c++
+class MyType : /*...*/ {
+  /*...*/
+  static IntegerType get(::mlir::MLIRContext *context, int intParam);
+
+};
+```
+
+This generates a builder declaration the same as the first three examples, but
+the return type of the builder is user-specified instead of the attribute or
+type class. This is useful for defining builders of attributes and types that
+may fold or canonicalize on construction.
+
 ### Parsing and Printing
 
 If a mnemonic was specified, the `hasCustomAssemblyFormat` and `assemblyFormat`

diff --git a/mlir/include/mlir/IR/AttrTypeBase.td b/mlir/include/mlir/IR/AttrTypeBase.td
index be8fed544ae1b..d72cb110650c7 100644
--- a/mlir/include/mlir/IR/AttrTypeBase.td
+++ b/mlir/include/mlir/IR/AttrTypeBase.td
@@ -96,30 +96,38 @@ class PredTypeTrait<string descr, Pred pred> : PredTrait<descr, pred>;
 // This is necessary because the `body` is also used to generate `getChecked`
 // methods, which have a different underlying `Base::get*` call.
 //
-class AttrOrTypeBuilder<dag parameters, code bodyCode = ""> {
+class AttrOrTypeBuilder<dag parameters, code bodyCode = "",
+                        string returnTypeStr = ""> {
   dag dagParams = parameters;
   code body = bodyCode;
 
+  // Change the return type of the builder. By default, it is the type of the
+  // attribute or type.
+  string returnType = returnTypeStr;
+
   // The context parameter can be inferred from one of the other parameters and
   // is not implicitly added to the parameter list.
   bit hasInferredContextParam = 0;
 }
-class AttrBuilder<dag parameters, code bodyCode = "">
-  : AttrOrTypeBuilder<parameters, bodyCode>;
-class TypeBuilder<dag parameters, code bodyCode = "">
-  : AttrOrTypeBuilder<parameters, bodyCode>;
+class AttrBuilder<dag parameters, code bodyCode = "", string returnType = "">
+  : AttrOrTypeBuilder<parameters, bodyCode, returnType>;
+class TypeBuilder<dag parameters, code bodyCode = "", string returnType = "">
+  : AttrOrTypeBuilder<parameters, bodyCode, returnType>;
 
 // A class of AttrOrTypeBuilder that is able to infer the MLIRContext parameter
 // from one of the other builder parameters. Instances of this builder do not
 // have `MLIRContext *` implicitly added to the parameter list.
-class AttrOrTypeBuilderWithInferredContext<dag parameters, code bodyCode = "">
-  : TypeBuilder<parameters, bodyCode> {
+class AttrOrTypeBuilderWithInferredContext<dag parameters, code bodyCode = "",
+                                           string returnType = "">
+  : TypeBuilder<parameters, bodyCode, returnType> {
   let hasInferredContextParam = 1;
 }
-class AttrBuilderWithInferredContext<dag parameters, code bodyCode = "">
-  : AttrOrTypeBuilderWithInferredContext<parameters, bodyCode>;
-class TypeBuilderWithInferredContext<dag parameters, code bodyCode = "">
-  : AttrOrTypeBuilderWithInferredContext<parameters, bodyCode>;
+class AttrBuilderWithInferredContext<dag parameters, code bodyCode = "",
+                                     string returnType = "">
+  : AttrOrTypeBuilderWithInferredContext<parameters, bodyCode, returnType>;
+class TypeBuilderWithInferredContext<dag parameters, code bodyCode = "",
+                                     string returnType = "">
+  : AttrOrTypeBuilderWithInferredContext<parameters, bodyCode, returnType>;
 
 //===----------------------------------------------------------------------===//
 // Definitions

diff --git a/mlir/include/mlir/IR/OpImplementation.h b/mlir/include/mlir/IR/OpImplementation.h
index bfb8adda2871f..a69e3b4a4e074 100644
--- a/mlir/include/mlir/IR/OpImplementation.h
+++ b/mlir/include/mlir/IR/OpImplementation.h
@@ -792,14 +792,14 @@ class AsmParser {
   /// unlike `OpBuilder::getType`, this method does not implicitly insert a
   /// context parameter.
   template <typename T, typename... ParamsT>
-  T getChecked(SMLoc loc, ParamsT &&...params) {
+  auto getChecked(SMLoc loc, ParamsT &&...params) {
     return T::getChecked([&] { return emitError(loc); },
                          std::forward<ParamsT>(params)...);
   }
   /// A variant of `getChecked` that uses the result of `getNameLoc` to emit
   /// errors.
   template <typename T, typename... ParamsT>
-  T getChecked(ParamsT &&...params) {
+  auto getChecked(ParamsT &&...params) {
     return T::getChecked([&] { return emitError(getNameLoc()); },
                          std::forward<ParamsT>(params)...);
   }

diff --git a/mlir/include/mlir/TableGen/AttrOrTypeDef.h b/mlir/include/mlir/TableGen/AttrOrTypeDef.h
index 043d63794bdf0..d4bab72f485a0 100644
--- a/mlir/include/mlir/TableGen/AttrOrTypeDef.h
+++ b/mlir/include/mlir/TableGen/AttrOrTypeDef.h
@@ -37,6 +37,9 @@ class AttrOrTypeBuilder : public Builder {
 public:
   using Builder::Builder;
 
+  /// Returns an optional builder return type.
+  Optional<StringRef> getReturnType() const;
+
   /// Returns true if this builder is able to infer the MLIRContext parameter.
   bool hasInferredContextParameter() const;
 };

diff --git a/mlir/lib/TableGen/AttrOrTypeDef.cpp b/mlir/lib/TableGen/AttrOrTypeDef.cpp
index ddcec88a4dbef..d399c06d2f9c7 100644
--- a/mlir/lib/TableGen/AttrOrTypeDef.cpp
+++ b/mlir/lib/TableGen/AttrOrTypeDef.cpp
@@ -20,7 +20,11 @@ using namespace mlir::tblgen;
 // AttrOrTypeBuilder
 //===----------------------------------------------------------------------===//
 
-/// Returns true if this builder is able to infer the MLIRContext parameter.
+Optional<StringRef> AttrOrTypeBuilder::getReturnType() const {
+  Optional<StringRef> type = def->getValueAsOptionalString("returnType");
+  return type && !type->empty() ? type : llvm::None;
+}
+
 bool AttrOrTypeBuilder::hasInferredContextParameter() const {
   return def->getValueAsBit("hasInferredContextParam");
 }

diff --git a/mlir/test/lib/Dialect/Test/TestAttrDefs.td b/mlir/test/lib/Dialect/Test/TestAttrDefs.td
index 9fb0efa3d72b3..10c82b74282df 100644
--- a/mlir/test/lib/Dialect/Test/TestAttrDefs.td
+++ b/mlir/test/lib/Dialect/Test/TestAttrDefs.td
@@ -223,6 +223,19 @@ def TestAttrSelfTypeParameterFormat
   let assemblyFormat = "`<` $a `>`";
 }
 
+// Test overriding attribute builders with a custom builder.
+def TestOverrideBuilderAttr : Test_Attr<"TestOverrideBuilder"> {
+  let mnemonic = "override_builder";
+  let parameters = (ins "int":$a);
+  let assemblyFormat = "`<` $a `>`";
+
+  let skipDefaultBuilders = 1;
+  let genVerifyDecl = 1;
+  let builders = [AttrBuilder<(ins "int":$a), [{
+    return ::mlir::IntegerAttr::get(::mlir::IndexType::get($_ctxt), a);
+  }], "::mlir::Attribute">];
+}
+
 // Test simple extern 1D vector using ElementsAttrInterface.
 def TestExtern1DI64ElementsAttr : Test_Attr<"TestExtern1DI64Elements", [
     ElementsAttrInterface

diff --git a/mlir/test/mlir-tblgen/attr-or-type-format.td b/mlir/test/mlir-tblgen/attr-or-type-format.td
index 751e9c1b33e3b..05b48671b5bae 100644
--- a/mlir/test/mlir-tblgen/attr-or-type-format.td
+++ b/mlir/test/mlir-tblgen/attr-or-type-format.td
@@ -55,8 +55,8 @@ def TypeParamB : TypeParameter<"TestParamD", "a type param D"> {
 // ATTR:   if (odsParser.parseRParen())
 // ATTR:     return {};
 // ATTR:   return TestAAttr::get(odsParser.getContext(),
-// ATTR:                         (*_result_value),
-// ATTR:                         (*_result_complex));
+// ATTR:                         IntegerAttr((*_result_value)),
+// ATTR:                         TestParamA((*_result_complex)));
 // ATTR: }
 
 // ATTR: void TestAAttr::print(::mlir::AsmPrinter &odsPrinter) const {
@@ -114,8 +114,8 @@ def AttrA : TestAttr<"TestA"> {
 // ATTR:       return {};
 // ATTR:   }
 // ATTR:   return TestBAttr::get(odsParser.getContext(),
-// ATTR:                         (*_result_v0),
-// ATTR:                         (*_result_v1));
+// ATTR:                         TestParamA((*_result_v0)),
+// ATTR:                         TestParamB((*_result_v1)));
 // ATTR: }
 
 // ATTR: void TestBAttr::print(::mlir::AsmPrinter &odsPrinter) const {
@@ -151,8 +151,8 @@ def AttrB : TestAttr<"TestB"> {
 // ATTR:   if (::mlir::failed(_result_v1))
 // ATTR:     return {};
 // ATTR:   return TestFAttr::get(odsParser.getContext(),
-// ATTR:     (*_result_v0),
-// ATTR:     (*_result_v1));
+// ATTR:     int((*_result_v0)),
+// ATTR:     int((*_result_v1)));
 // ATTR: }
 
 def AttrC : TestAttr<"TestF"> {
@@ -278,10 +278,10 @@ def TypeA : TestType<"TestC"> {
 // TYPE:   if (::mlir::failed(_result_v3))
 // TYPE:     return {};
 // TYPE:   return TestDType::get(odsParser.getContext(),
-// TYPE:                         (*_result_v0),
-// TYPE:                         (*_result_v1),
-// TYPE:                         (*_result_v2),
-// TYPE:                         (*_result_v3));
+// TYPE:                         TestParamC((*_result_v0)),
+// TYPE:                         TestParamD((*_result_v1)),
+// TYPE:                         TestParamC((*_result_v2)),
+// TYPE:                         TestParamD((*_result_v3)));
 // TYPE: }
 
 // TYPE: void TestDType::print(::mlir::AsmPrinter &odsPrinter) const {
@@ -369,10 +369,10 @@ def TypeB : TestType<"TestD"> {
 // TYPE:       return {};
 // TYPE:   }
 // TYPE:   return TestEType::get(odsParser.getContext(),
-// TYPE:     (*_result_v0),
-// TYPE:     (*_result_v1),
-// TYPE:     (*_result_v2),
-// TYPE:     (*_result_v3));
+// TYPE:     IntegerAttr((*_result_v0)),
+// TYPE:     IntegerAttr((*_result_v1)),
+// TYPE:     IntegerAttr((*_result_v2)),
+// TYPE:     IntegerAttr((*_result_v3)));
 // TYPE: }
 
 // TYPE: void TestEType::print(::mlir::AsmPrinter &odsPrinter) const {

diff --git a/mlir/test/mlir-tblgen/attrdefs.td b/mlir/test/mlir-tblgen/attrdefs.td
index 50dba45c63d0b..8de12eb2b3006 100644
--- a/mlir/test/mlir-tblgen/attrdefs.td
+++ b/mlir/test/mlir-tblgen/attrdefs.td
@@ -31,9 +31,9 @@ include "mlir/IR/OpBase.td"
 // DEF-NEXT: .Case(::test::IndexAttr::getMnemonic()
 // DEF-NEXT:   value = ::test::IndexAttr::parse(parser, type);
 // DEF-NEXT:   return ::mlir::success(!!value);
-// DEF: .Default([&](llvm::StringRef keyword, 
+// DEF: .Default([&](llvm::StringRef keyword,
 // DEF-NEXT:   *mnemonic = keyword;
-// DEF-NEXT:   return llvm::None; 
+// DEF-NEXT:   return llvm::None;
 
 def Test_Dialect: Dialect {
 // DECL-NOT: TestDialect
@@ -148,3 +148,13 @@ def F_ParamWithAccessorTypeAttr : TestAttr<"ParamWithAccessorType"> {
 // DEF: ParamWithAccessorTypeAttrStorage
 // DEF: ParamWithAccessorTypeAttrStorage(std::string param)
 // DEF: StringRef ParamWithAccessorTypeAttr::getParam()
+
+def G_BuilderWithReturnTypeAttr : TestAttr<"BuilderWithReturnType"> {
+  let parameters = (ins "int":$a);
+  let genVerifyDecl = 1;
+  let builders = [AttrBuilder<(ins), [{ return {}; }], "::mlir::Attribute">];
+}
+
+// DECL-LABEL: class BuilderWithReturnTypeAttr
+// DECL: ::mlir::Attribute get(
+// DECL: ::mlir::Attribute getChecked(

diff --git a/mlir/test/mlir-tblgen/testdialect-attrdefs.mlir b/mlir/test/mlir-tblgen/testdialect-attrdefs.mlir
index 4ccf3bd624094..ee92ea06a208c 100644
--- a/mlir/test/mlir-tblgen/testdialect-attrdefs.mlir
+++ b/mlir/test/mlir-tblgen/testdialect-attrdefs.mlir
@@ -13,3 +13,9 @@ func.func private @compoundA() attributes {foo = #test.cmpnd_a<1, !test.smpla, [
 // CHECK-LABEL: @qualifiedAttr()
 // CHECK-SAME: #test.cmpnd_nested_outer_qual<i #test.cmpnd_nested_inner<42 <1, !test.smpla, [5, 6]>>>
 func.func private @qualifiedAttr() attributes {foo = #test.cmpnd_nested_outer_qual<i #test.cmpnd_nested_inner<42 <1, !test.smpla, [5, 6]>>>}
+
+// CHECK-LABEL: @overriddenAttr
+// CHECK-SAME: foo = 5 : index
+func.func private @overriddenAttr() attributes {
+  foo = #test.override_builder<5>
+}

diff --git a/mlir/tools/mlir-tblgen/AttrOrTypeDefGen.cpp b/mlir/tools/mlir-tblgen/AttrOrTypeDefGen.cpp
index db443e9ac0cff..ef2bd4102b6f0 100644
--- a/mlir/tools/mlir-tblgen/AttrOrTypeDefGen.cpp
+++ b/mlir/tools/mlir-tblgen/AttrOrTypeDefGen.cpp
@@ -348,7 +348,10 @@ getCustomBuilderParams(std::initializer_list<MethodParameter> prefix,
 void DefGen::emitCustomBuilder(const AttrOrTypeBuilder &builder) {
   // Don't emit a body if there isn't one.
   auto props = builder.getBody() ? Method::Static : Method::StaticDeclaration;
-  Method *m = defCls.addMethod(def.getCppClassName(), "get", props,
+  StringRef returnType = def.getCppClassName();
+  if (Optional<StringRef> builderReturnType = builder.getReturnType())
+    returnType = *builderReturnType;
+  Method *m = defCls.addMethod(returnType, "get", props,
                                getCustomBuilderParams({}, builder));
   if (!builder.getBody())
     return;
@@ -373,8 +376,11 @@ static std::string replaceInStr(std::string str, StringRef from, StringRef to) {
 void DefGen::emitCheckedCustomBuilder(const AttrOrTypeBuilder &builder) {
   // Don't emit a body if there isn't one.
   auto props = builder.getBody() ? Method::Static : Method::StaticDeclaration;
+  StringRef returnType = def.getCppClassName();
+  if (Optional<StringRef> builderReturnType = builder.getReturnType())
+    returnType = *builderReturnType;
   Method *m = defCls.addMethod(
-      def.getCppClassName(), "getChecked", props,
+      returnType, "getChecked", props,
       getCustomBuilderParams(
           {{"::llvm::function_ref<::mlir::InFlightDiagnostic()>", "emitError"}},
           builder));

diff --git a/mlir/tools/mlir-tblgen/AttrOrTypeFormatGen.cpp b/mlir/tools/mlir-tblgen/AttrOrTypeFormatGen.cpp
index 19ab7f3eb72ce..8321861738d08 100644
--- a/mlir/tools/mlir-tblgen/AttrOrTypeFormatGen.cpp
+++ b/mlir/tools/mlir-tblgen/AttrOrTypeFormatGen.cpp
@@ -311,7 +311,9 @@ void DefFormat::genParser(MethodBody &os) {
     } else {
       selfOs << formatv("(*_result_{0})", param.getName());
     }
-    os << tgfmt(param.getConvertFromStorage(), &ctx.withSelf(selfOs.str()));
+    os << param.getCppType() << "("
+       << tgfmt(param.getConvertFromStorage(), &ctx.withSelf(selfOs.str()))
+       << ")";
   }
   os << ");";
 }


        


More information about the Mlir-commits mailing list