[Mlir-commits] [mlir] 7fe2294 - [mlir][ods] Allow specifying return types of builders
Jeff Niu
llvmlistbot at llvm.org
Fri Jul 15 18:00:42 PDT 2022
Author: Jeff Niu
Date: 2022-07-15T18:00:35-07:00
New Revision: 7fe2294e474bfcc4f9eee7828738d145836f9144
URL: https://github.com/llvm/llvm-project/commit/7fe2294e474bfcc4f9eee7828738d145836f9144
DIFF: https://github.com/llvm/llvm-project/commit/7fe2294e474bfcc4f9eee7828738d145836f9144.diff
LOG: [mlir][ods] Allow specifying return types of builders
This patch allows custom attribute and type builders to return
something other than the C++ type of the attribute or type.
This is useful for attributes or types that may perform extra work during
construction (e.g. canonicalization) that could result in a different
kind of attribute or type being returned.
Reviewed By: rriddle, lattner
Differential Revision: https://reviews.llvm.org/D129792
Added:
Modified:
mlir/docs/AttributesAndTypes.md
mlir/include/mlir/IR/AttrTypeBase.td
mlir/include/mlir/IR/OpImplementation.h
mlir/include/mlir/TableGen/AttrOrTypeDef.h
mlir/lib/TableGen/AttrOrTypeDef.cpp
mlir/test/lib/Dialect/Test/TestAttrDefs.td
mlir/test/mlir-tblgen/attr-or-type-format.td
mlir/test/mlir-tblgen/attrdefs.td
mlir/test/mlir-tblgen/testdialect-attrdefs.mlir
mlir/tools/mlir-tblgen/AttrOrTypeDefGen.cpp
mlir/tools/mlir-tblgen/AttrOrTypeFormatGen.cpp
Removed:
################################################################################
diff --git a/mlir/docs/AttributesAndTypes.md b/mlir/docs/AttributesAndTypes.md
index e1ad13952fe0b..e7cd74c9315af 100644
--- a/mlir/docs/AttributesAndTypes.md
+++ b/mlir/docs/AttributesAndTypes.md
@@ -347,6 +347,7 @@ def MyType : ... {
// its arguments.
return Base::get(typeParam.getContext(), ...);
}]>,
+ TypeBuilder<(ins "int":$intParam), [{}], "IntegerType">,
];
}
```
@@ -461,6 +462,28 @@ the builder used `TypeBuilderWithInferredContext` implies that the context
parameter is not necessary as it can be inferred from the arguments to the
builder.
+The fifth builder will generate the declaration of a builder method with a
+custom return type, like:
+
+```tablegen
+ let builders = [
+ TypeBuilder<(ins "int":$intParam), [{}], "IntegerType">,
+ ];
+```
+
+```c++
+class MyType : /*...*/ {
+ /*...*/
+ static IntegerType get(::mlir::MLIRContext *context, int intParam);
+
+};
+```
+
+This generates a builder declaration identical to those of the first three
+examples, except that the return type of the builder is user-specified instead
+of being the attribute or type class. This is useful for defining builders of
+attributes and types that may fold or canonicalize on construction, and can
+therefore return a different kind of attribute or type.
+
### Parsing and Printing
If a mnemonic was specified, the `hasCustomAssemblyFormat` and `assemblyFormat`
diff --git a/mlir/include/mlir/IR/AttrTypeBase.td b/mlir/include/mlir/IR/AttrTypeBase.td
index be8fed544ae1b..d72cb110650c7 100644
--- a/mlir/include/mlir/IR/AttrTypeBase.td
+++ b/mlir/include/mlir/IR/AttrTypeBase.td
@@ -96,30 +96,38 @@ class PredTypeTrait<string descr, Pred pred> : PredTrait<descr, pred>;
// This is necessary because the `body` is also used to generate `getChecked`
// methods, which have a different underlying `Base::get*` call.
//
-class AttrOrTypeBuilder<dag parameters, code bodyCode = ""> {
+class AttrOrTypeBuilder<dag parameters, code bodyCode = "",
+ string returnTypeStr = ""> {
dag dagParams = parameters;
code body = bodyCode;
+ // Change the return type of the builder. By default, it is the type of the
+ // attribute or type.
+ string returnType = returnTypeStr;
+
// The context parameter can be inferred from one of the other parameters and
// is not implicitly added to the parameter list.
bit hasInferredContextParam = 0;
}
-class AttrBuilder<dag parameters, code bodyCode = "">
- : AttrOrTypeBuilder<parameters, bodyCode>;
-class TypeBuilder<dag parameters, code bodyCode = "">
- : AttrOrTypeBuilder<parameters, bodyCode>;
+class AttrBuilder<dag parameters, code bodyCode = "", string returnType = "">
+ : AttrOrTypeBuilder<parameters, bodyCode, returnType>;
+class TypeBuilder<dag parameters, code bodyCode = "", string returnType = "">
+ : AttrOrTypeBuilder<parameters, bodyCode, returnType>;
// A class of AttrOrTypeBuilder that is able to infer the MLIRContext parameter
// from one of the other builder parameters. Instances of this builder do not
// have `MLIRContext *` implicitly added to the parameter list.
-class AttrOrTypeBuilderWithInferredContext<dag parameters, code bodyCode = "">
- : TypeBuilder<parameters, bodyCode> {
+class AttrOrTypeBuilderWithInferredContext<dag parameters, code bodyCode = "",
+ string returnType = "">
+ : TypeBuilder<parameters, bodyCode, returnType> {
let hasInferredContextParam = 1;
}
-class AttrBuilderWithInferredContext<dag parameters, code bodyCode = "">
- : AttrOrTypeBuilderWithInferredContext<parameters, bodyCode>;
-class TypeBuilderWithInferredContext<dag parameters, code bodyCode = "">
- : AttrOrTypeBuilderWithInferredContext<parameters, bodyCode>;
+class AttrBuilderWithInferredContext<dag parameters, code bodyCode = "",
+ string returnType = "">
+ : AttrOrTypeBuilderWithInferredContext<parameters, bodyCode, returnType>;
+class TypeBuilderWithInferredContext<dag parameters, code bodyCode = "",
+ string returnType = "">
+ : AttrOrTypeBuilderWithInferredContext<parameters, bodyCode, returnType>;
//===----------------------------------------------------------------------===//
// Definitions
diff --git a/mlir/include/mlir/IR/OpImplementation.h b/mlir/include/mlir/IR/OpImplementation.h
index bfb8adda2871f..a69e3b4a4e074 100644
--- a/mlir/include/mlir/IR/OpImplementation.h
+++ b/mlir/include/mlir/IR/OpImplementation.h
@@ -792,14 +792,14 @@ class AsmParser {
/// unlike `OpBuilder::getType`, this method does not implicitly insert a
/// context parameter.
template <typename T, typename... ParamsT>
- T getChecked(SMLoc loc, ParamsT &&...params) {
+ auto getChecked(SMLoc loc, ParamsT &&...params) {
return T::getChecked([&] { return emitError(loc); },
std::forward<ParamsT>(params)...);
}
/// A variant of `getChecked` that uses the result of `getNameLoc` to emit
/// errors.
template <typename T, typename... ParamsT>
- T getChecked(ParamsT &&...params) {
+ auto getChecked(ParamsT &&...params) {
return T::getChecked([&] { return emitError(getNameLoc()); },
std::forward<ParamsT>(params)...);
}
diff --git a/mlir/include/mlir/TableGen/AttrOrTypeDef.h b/mlir/include/mlir/TableGen/AttrOrTypeDef.h
index 043d63794bdf0..d4bab72f485a0 100644
--- a/mlir/include/mlir/TableGen/AttrOrTypeDef.h
+++ b/mlir/include/mlir/TableGen/AttrOrTypeDef.h
@@ -37,6 +37,9 @@ class AttrOrTypeBuilder : public Builder {
public:
using Builder::Builder;
+ /// Returns an optional builder return type.
+ Optional<StringRef> getReturnType() const;
+
/// Returns true if this builder is able to infer the MLIRContext parameter.
bool hasInferredContextParameter() const;
};
diff --git a/mlir/lib/TableGen/AttrOrTypeDef.cpp b/mlir/lib/TableGen/AttrOrTypeDef.cpp
index ddcec88a4dbef..d399c06d2f9c7 100644
--- a/mlir/lib/TableGen/AttrOrTypeDef.cpp
+++ b/mlir/lib/TableGen/AttrOrTypeDef.cpp
@@ -20,7 +20,11 @@ using namespace mlir::tblgen;
// AttrOrTypeBuilder
//===----------------------------------------------------------------------===//
-/// Returns true if this builder is able to infer the MLIRContext parameter.
+Optional<StringRef> AttrOrTypeBuilder::getReturnType() const {
+ Optional<StringRef> type = def->getValueAsOptionalString("returnType");
+ return type && !type->empty() ? type : llvm::None;
+}
+
bool AttrOrTypeBuilder::hasInferredContextParameter() const {
return def->getValueAsBit("hasInferredContextParam");
}
diff --git a/mlir/test/lib/Dialect/Test/TestAttrDefs.td b/mlir/test/lib/Dialect/Test/TestAttrDefs.td
index 9fb0efa3d72b3..10c82b74282df 100644
--- a/mlir/test/lib/Dialect/Test/TestAttrDefs.td
+++ b/mlir/test/lib/Dialect/Test/TestAttrDefs.td
@@ -223,6 +223,19 @@ def TestAttrSelfTypeParameterFormat
let assemblyFormat = "`<` $a `>`";
}
+// Test overriding attribute builders with a custom builder.
+def TestOverrideBuilderAttr : Test_Attr<"TestOverrideBuilder"> {
+ let mnemonic = "override_builder";
+ let parameters = (ins "int":$a);
+ let assemblyFormat = "`<` $a `>`";
+
+ let skipDefaultBuilders = 1;
+ let genVerifyDecl = 1;
+ let builders = [AttrBuilder<(ins "int":$a), [{
+ return ::mlir::IntegerAttr::get(::mlir::IndexType::get($_ctxt), a);
+ }], "::mlir::Attribute">];
+}
+
// Test simple extern 1D vector using ElementsAttrInterface.
def TestExtern1DI64ElementsAttr : Test_Attr<"TestExtern1DI64Elements", [
ElementsAttrInterface
diff --git a/mlir/test/mlir-tblgen/attr-or-type-format.td b/mlir/test/mlir-tblgen/attr-or-type-format.td
index 751e9c1b33e3b..05b48671b5bae 100644
--- a/mlir/test/mlir-tblgen/attr-or-type-format.td
+++ b/mlir/test/mlir-tblgen/attr-or-type-format.td
@@ -55,8 +55,8 @@ def TypeParamB : TypeParameter<"TestParamD", "a type param D"> {
// ATTR: if (odsParser.parseRParen())
// ATTR: return {};
// ATTR: return TestAAttr::get(odsParser.getContext(),
-// ATTR: (*_result_value),
-// ATTR: (*_result_complex));
+// ATTR: IntegerAttr((*_result_value)),
+// ATTR: TestParamA((*_result_complex)));
// ATTR: }
// ATTR: void TestAAttr::print(::mlir::AsmPrinter &odsPrinter) const {
@@ -114,8 +114,8 @@ def AttrA : TestAttr<"TestA"> {
// ATTR: return {};
// ATTR: }
// ATTR: return TestBAttr::get(odsParser.getContext(),
-// ATTR: (*_result_v0),
-// ATTR: (*_result_v1));
+// ATTR: TestParamA((*_result_v0)),
+// ATTR: TestParamB((*_result_v1)));
// ATTR: }
// ATTR: void TestBAttr::print(::mlir::AsmPrinter &odsPrinter) const {
@@ -151,8 +151,8 @@ def AttrB : TestAttr<"TestB"> {
// ATTR: if (::mlir::failed(_result_v1))
// ATTR: return {};
// ATTR: return TestFAttr::get(odsParser.getContext(),
-// ATTR: (*_result_v0),
-// ATTR: (*_result_v1));
+// ATTR: int((*_result_v0)),
+// ATTR: int((*_result_v1)));
// ATTR: }
def AttrC : TestAttr<"TestF"> {
@@ -278,10 +278,10 @@ def TypeA : TestType<"TestC"> {
// TYPE: if (::mlir::failed(_result_v3))
// TYPE: return {};
// TYPE: return TestDType::get(odsParser.getContext(),
-// TYPE: (*_result_v0),
-// TYPE: (*_result_v1),
-// TYPE: (*_result_v2),
-// TYPE: (*_result_v3));
+// TYPE: TestParamC((*_result_v0)),
+// TYPE: TestParamD((*_result_v1)),
+// TYPE: TestParamC((*_result_v2)),
+// TYPE: TestParamD((*_result_v3)));
// TYPE: }
// TYPE: void TestDType::print(::mlir::AsmPrinter &odsPrinter) const {
@@ -369,10 +369,10 @@ def TypeB : TestType<"TestD"> {
// TYPE: return {};
// TYPE: }
// TYPE: return TestEType::get(odsParser.getContext(),
-// TYPE: (*_result_v0),
-// TYPE: (*_result_v1),
-// TYPE: (*_result_v2),
-// TYPE: (*_result_v3));
+// TYPE: IntegerAttr((*_result_v0)),
+// TYPE: IntegerAttr((*_result_v1)),
+// TYPE: IntegerAttr((*_result_v2)),
+// TYPE: IntegerAttr((*_result_v3)));
// TYPE: }
// TYPE: void TestEType::print(::mlir::AsmPrinter &odsPrinter) const {
diff --git a/mlir/test/mlir-tblgen/attrdefs.td b/mlir/test/mlir-tblgen/attrdefs.td
index 50dba45c63d0b..8de12eb2b3006 100644
--- a/mlir/test/mlir-tblgen/attrdefs.td
+++ b/mlir/test/mlir-tblgen/attrdefs.td
@@ -31,9 +31,9 @@ include "mlir/IR/OpBase.td"
// DEF-NEXT: .Case(::test::IndexAttr::getMnemonic()
// DEF-NEXT: value = ::test::IndexAttr::parse(parser, type);
// DEF-NEXT: return ::mlir::success(!!value);
-// DEF: .Default([&](llvm::StringRef keyword,
+// DEF: .Default([&](llvm::StringRef keyword,
// DEF-NEXT: *mnemonic = keyword;
-// DEF-NEXT: return llvm::None;
+// DEF-NEXT: return llvm::None;
def Test_Dialect: Dialect {
// DECL-NOT: TestDialect
@@ -148,3 +148,13 @@ def F_ParamWithAccessorTypeAttr : TestAttr<"ParamWithAccessorType"> {
// DEF: ParamWithAccessorTypeAttrStorage
// DEF: ParamWithAccessorTypeAttrStorage(std::string param)
// DEF: StringRef ParamWithAccessorTypeAttr::getParam()
+
+def G_BuilderWithReturnTypeAttr : TestAttr<"BuilderWithReturnType"> {
+ let parameters = (ins "int":$a);
+ let genVerifyDecl = 1;
+ let builders = [AttrBuilder<(ins), [{ return {}; }], "::mlir::Attribute">];
+}
+
+// DECL-LABEL: class BuilderWithReturnTypeAttr
+// DECL: ::mlir::Attribute get(
+// DECL: ::mlir::Attribute getChecked(
diff --git a/mlir/test/mlir-tblgen/testdialect-attrdefs.mlir b/mlir/test/mlir-tblgen/testdialect-attrdefs.mlir
index 4ccf3bd624094..ee92ea06a208c 100644
--- a/mlir/test/mlir-tblgen/testdialect-attrdefs.mlir
+++ b/mlir/test/mlir-tblgen/testdialect-attrdefs.mlir
@@ -13,3 +13,9 @@ func.func private @compoundA() attributes {foo = #test.cmpnd_a<1, !test.smpla, [
// CHECK-LABEL: @qualifiedAttr()
// CHECK-SAME: #test.cmpnd_nested_outer_qual<i #test.cmpnd_nested_inner<42 <1, !test.smpla, [5, 6]>>>
func.func private @qualifiedAttr() attributes {foo = #test.cmpnd_nested_outer_qual<i #test.cmpnd_nested_inner<42 <1, !test.smpla, [5, 6]>>>}
+
+// CHECK-LABEL: @overriddenAttr
+// CHECK-SAME: foo = 5 : index
+func.func private @overriddenAttr() attributes {
+ foo = #test.override_builder<5>
+}
diff --git a/mlir/tools/mlir-tblgen/AttrOrTypeDefGen.cpp b/mlir/tools/mlir-tblgen/AttrOrTypeDefGen.cpp
index db443e9ac0cff..ef2bd4102b6f0 100644
--- a/mlir/tools/mlir-tblgen/AttrOrTypeDefGen.cpp
+++ b/mlir/tools/mlir-tblgen/AttrOrTypeDefGen.cpp
@@ -348,7 +348,10 @@ getCustomBuilderParams(std::initializer_list<MethodParameter> prefix,
void DefGen::emitCustomBuilder(const AttrOrTypeBuilder &builder) {
// Don't emit a body if there isn't one.
auto props = builder.getBody() ? Method::Static : Method::StaticDeclaration;
- Method *m = defCls.addMethod(def.getCppClassName(), "get", props,
+ StringRef returnType = def.getCppClassName();
+ if (Optional<StringRef> builderReturnType = builder.getReturnType())
+ returnType = *builderReturnType;
+ Method *m = defCls.addMethod(returnType, "get", props,
getCustomBuilderParams({}, builder));
if (!builder.getBody())
return;
@@ -373,8 +376,11 @@ static std::string replaceInStr(std::string str, StringRef from, StringRef to) {
void DefGen::emitCheckedCustomBuilder(const AttrOrTypeBuilder &builder) {
// Don't emit a body if there isn't one.
auto props = builder.getBody() ? Method::Static : Method::StaticDeclaration;
+ StringRef returnType = def.getCppClassName();
+ if (Optional<StringRef> builderReturnType = builder.getReturnType())
+ returnType = *builderReturnType;
Method *m = defCls.addMethod(
- def.getCppClassName(), "getChecked", props,
+ returnType, "getChecked", props,
getCustomBuilderParams(
{{"::llvm::function_ref<::mlir::InFlightDiagnostic()>", "emitError"}},
builder));
diff --git a/mlir/tools/mlir-tblgen/AttrOrTypeFormatGen.cpp b/mlir/tools/mlir-tblgen/AttrOrTypeFormatGen.cpp
index 19ab7f3eb72ce..8321861738d08 100644
--- a/mlir/tools/mlir-tblgen/AttrOrTypeFormatGen.cpp
+++ b/mlir/tools/mlir-tblgen/AttrOrTypeFormatGen.cpp
@@ -311,7 +311,9 @@ void DefFormat::genParser(MethodBody &os) {
} else {
selfOs << formatv("(*_result_{0})", param.getName());
}
- os << tgfmt(param.getConvertFromStorage(), &ctx.withSelf(selfOs.str()));
+ os << param.getCppType() << "("
+ << tgfmt(param.getConvertFromStorage(), &ctx.withSelf(selfOs.str()))
+ << ")";
}
os << ");";
}
More information about the Mlir-commits
mailing list