[Mlir-commits] [mlir] [mlir] add optional type functor to call and function interfaces (PR #146979)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Tue Jul 8 13:54:49 PDT 2025
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir-core
@llvm/pr-subscribers-mlir
Author: Gibran Essa (gibrane)
<details>
<summary>Changes</summary>
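Adds optional type parsing and printing functors to `call_interface_impl` and `function_interface_impl`, and to `OpAsmParser::parseArgument`/`parseOptionalArgument` and `OpAsmPrinter::printRegionArgument`.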
---
Patch is 31.77 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/146979.diff
10 Files Affected:
- (modified) mlir/include/mlir/IR/OpImplementation.h (+15-9)
- (modified) mlir/include/mlir/Interfaces/CallInterfaces.h (+24-14)
- (modified) mlir/include/mlir/Interfaces/FunctionImplementation.h (+22-12)
- (modified) mlir/lib/AsmParser/Parser.cpp (+11-5)
- (modified) mlir/lib/IR/AsmPrinter.cpp (+12-10)
- (modified) mlir/lib/Interfaces/CallInterfaces.cpp (+60-19)
- (modified) mlir/lib/Interfaces/FunctionImplementation.cpp (+32-14)
- (added) mlir/test/Interfaces/FunctionOpInterface/custom-type-parse-and-print.mlir (+26)
- (modified) mlir/test/lib/Dialect/Test/TestOpDefs.cpp (+34)
- (modified) mlir/test/lib/Dialect/Test/TestOps.td (+36)
``````````diff
diff --git a/mlir/include/mlir/IR/OpImplementation.h b/mlir/include/mlir/IR/OpImplementation.h
index d70aa346eaa1f..5232a31ff5c77 100644
--- a/mlir/include/mlir/IR/OpImplementation.h
+++ b/mlir/include/mlir/IR/OpImplementation.h
@@ -463,10 +463,11 @@ class OpAsmPrinter : public AsmPrinter {
/// where location printing is controlled by the standard internal option.
/// You may pass omitType=true to not print a type, and pass an empty
/// attribute list if you don't care for attributes.
- virtual void printRegionArgument(BlockArgument arg,
- ArrayRef<NamedAttribute> argAttrs = {},
- bool omitType = false) = 0;
-
+ /// If `typePrinter` is non-null, it is called with this printer and the
+ /// argument's type in place of the default type printing.
+ virtual void printRegionArgument(
+ BlockArgument arg, ArrayRef<NamedAttribute> argAttrs = {},
+ bool omitType = false,
+ function_ref<void(OpAsmPrinter &, Type)> typePrinter = nullptr) = 0;
/// Print implementations for various things an operation contains.
virtual void printOperand(Value value) = 0;
virtual void printOperand(Value value, raw_ostream &os) = 0;
@@ -1701,13 +1702,18 @@ class OpAsmParser : public AsmParser {
///
/// If `allowType` is false or `allowAttrs` are false then the respective
/// parts of the grammar are not parsed.
- virtual ParseResult parseArgument(Argument &result, bool allowType = false,
- bool allowAttrs = false) = 0;
+ /// If `allowType` is true and `typeParser` is non-null, `typeParser` parses
+ /// the type after the `:` has been consumed, instead of the default parser.
+ virtual ParseResult
+ parseArgument(Argument &result, bool allowType = false,
+ bool allowAttrs = false,
+ function_ref<ParseResult(OpAsmParser &, Type &)> typeParser =
+ nullptr) = 0;
/// Parse a single argument if present.
+ /// `allowType`, `allowAttrs` and `typeParser` behave as in `parseArgument`.
- virtual OptionalParseResult
- parseOptionalArgument(Argument &result, bool allowType = false,
- bool allowAttrs = false) = 0;
+ virtual OptionalParseResult parseOptionalArgument(
+ Argument &result, bool allowType = false, bool allowAttrs = false,
+ function_ref<ParseResult(OpAsmParser &, Type &)> typeParser =
+ nullptr) = 0;
/// Parse zero or more arguments with a specified surrounding delimiter.
virtual ParseResult parseArgumentList(SmallVectorImpl<Argument> &result,
diff --git a/mlir/include/mlir/Interfaces/CallInterfaces.h b/mlir/include/mlir/Interfaces/CallInterfaces.h
index 2bf3a3ca5f8a8..66f0287471da5 100644
--- a/mlir/include/mlir/Interfaces/CallInterfaces.h
+++ b/mlir/include/mlir/Interfaces/CallInterfaces.h
@@ -37,6 +37,8 @@ Operation *resolveCallable(CallOpInterface call,
SymbolTableCollection *symbolTable = nullptr);
/// Parse a function or call result list.
+/// If `typeParser` is non-null, it is used to parse every result type instead
+/// of `OpAsmParser::parseType`.
///
/// function-result-list ::= function-result-list-parens
/// | non-function-type
@@ -45,31 +47,39 @@ Operation *resolveCallable(CallOpInterface call,
/// function-result-list-no-parens ::= function-result (`,` function-result)*
/// function-result ::= type attribute-dict?
///
-ParseResult
-parseFunctionResultList(OpAsmParser &parser, SmallVectorImpl<Type> &resultTypes,
- SmallVectorImpl<DictionaryAttr> &resultAttrs);
+ParseResult parseFunctionResultList(
+ OpAsmParser &parser, SmallVectorImpl<Type> &resultTypes,
+ SmallVectorImpl<DictionaryAttr> &resultAttrs,
+ function_ref<ParseResult(OpAsmParser &, Type &)> typeParser = nullptr);
/// Parses a function signature using `parser`. This does not deal with function
/// signatures containing SSA region arguments (to parse these signatures, use
/// function_interface_impl::parseFunctionSignature). When
/// `mustParseEmptyResult`, `-> ()` is expected when there is no result type.
+/// If `typeParser` is non-null, it is used to parse every argument and result
+/// type instead of `OpAsmParser::parseType`.
///
/// no-ssa-function-signature ::= `(` no-ssa-function-arg-list `)`
/// -> function-result-list
/// no-ssa-function-arg-list ::= no-ssa-function-arg
/// (`,` no-ssa-function-arg)*
/// no-ssa-function-arg ::= type attribute-dict?
-ParseResult parseFunctionSignature(OpAsmParser &parser,
- SmallVectorImpl<Type> &argTypes,
- SmallVectorImpl<DictionaryAttr> &argAttrs,
- SmallVectorImpl<Type> &resultTypes,
- SmallVectorImpl<DictionaryAttr> &resultAttrs,
- bool mustParseEmptyResult = true);
+ParseResult parseFunctionSignature(
+ OpAsmParser &parser, SmallVectorImpl<Type> &argTypes,
+ SmallVectorImpl<DictionaryAttr> &argAttrs,
+ SmallVectorImpl<Type> &resultTypes,
+ SmallVectorImpl<DictionaryAttr> &resultAttrs,
+ bool mustParseEmptyResult = true,
+ function_ref<ParseResult(OpAsmParser &, Type &)> typeParser = nullptr);
/// Print a function signature for a call or callable operation. If a body
/// region is provided, the SSA arguments are printed in the signature. When
/// `printEmptyResult` is false, `-> function-result-list` is omitted when
/// `resultTypes` is empty.
+/// A non-null `typePrinter` overrides the default type printing.
///
/// function-signature ::= ssa-function-signature
/// | no-ssa-function-signature
@@ -77,11 +87,11 @@ ParseResult parseFunctionSignature(OpAsmParser &parser,
/// -> function-result-list
/// ssa-function-arg-list ::= ssa-function-arg (`,` ssa-function-arg)*
/// ssa-function-arg ::= `%`name `:` type attribute-dict?
-void printFunctionSignature(OpAsmPrinter &p, TypeRange argTypes,
- ArrayAttr argAttrs, bool isVariadic,
- TypeRange resultTypes, ArrayAttr resultAttrs,
- Region *body = nullptr,
- bool printEmptyResult = true);
+void printFunctionSignature(
+ OpAsmPrinter &p, TypeRange argTypes, ArrayAttr argAttrs, bool isVariadic,
+ TypeRange resultTypes, ArrayAttr resultAttrs, Region *body = nullptr,
+ bool printEmptyResult = true,
+ function_ref<void(OpAsmPrinter &, Type)> typePrinter = nullptr);
/// Adds argument and result attributes, provided as `argAttrs` and
/// `resultAttrs` arguments, to the list of operation attributes in `result`.
diff --git a/mlir/include/mlir/Interfaces/FunctionImplementation.h b/mlir/include/mlir/Interfaces/FunctionImplementation.h
index 374c2c534f87d..de89a6bc0d50a 100644
--- a/mlir/include/mlir/Interfaces/FunctionImplementation.h
+++ b/mlir/include/mlir/Interfaces/FunctionImplementation.h
@@ -45,11 +45,14 @@ using FuncTypeBuilder = function_ref<Type(
/// indicates whether functions with variadic arguments are supported. The
/// trailing arguments are populated by this function with names, types,
/// attributes and locations of the arguments and those of the results.
+/// A non-null `typeParser` overrides the default type parsing.
ParseResult parseFunctionSignatureWithArguments(
OpAsmParser &parser, bool allowVariadic,
SmallVectorImpl<OpAsmParser::Argument> &arguments, bool &isVariadic,
SmallVectorImpl<Type> &resultTypes,
- SmallVectorImpl<DictionaryAttr> &resultAttrs);
+ SmallVectorImpl<DictionaryAttr> &resultAttrs,
+ function_ref<ParseResult(OpAsmParser &, Type &)> typeParser = nullptr);
/// Parser implementation for function-like operations. Uses
/// `funcTypeBuilder` to construct the custom function type given lists of
@@ -59,25 +62,32 @@ ParseResult parseFunctionSignatureWithArguments(
/// whether the function is variadic. If the builder returns a null type,
/// `result` will not contain the `type` attribute. The caller can then add a
/// type, report the error or delegate the reporting to the op's verifier.
-ParseResult parseFunctionOp(OpAsmParser &parser, OperationState &result,
- bool allowVariadic, StringAttr typeAttrName,
- FuncTypeBuilder funcTypeBuilder,
- StringAttr argAttrsName, StringAttr resAttrsName);
+/// A non-null `typeParser` overrides the default type parsing.
+ParseResult parseFunctionOp(
+ OpAsmParser &parser, OperationState &result, bool allowVariadic,
+ StringAttr typeAttrName, FuncTypeBuilder funcTypeBuilder,
+ StringAttr argAttrsName, StringAttr resAttrsName,
+ function_ref<ParseResult(OpAsmParser &, Type &)> typeParser = nullptr);
/// Printer implementation for function-like operations.
-void printFunctionOp(OpAsmPrinter &p, FunctionOpInterface op, bool isVariadic,
- StringRef typeAttrName, StringAttr argAttrsName,
- StringAttr resAttrsName);
+/// A non-null `typePrinter` overrides the default type printing.
+void printFunctionOp(
+ OpAsmPrinter &p, FunctionOpInterface op, bool isVariadic,
+ StringRef typeAttrName, StringAttr argAttrsName, StringAttr resAttrsName,
+ function_ref<void(OpAsmPrinter &, Type)> typePrinter = nullptr);
/// Prints the signature of the function-like operation `op`. Assumes `op` has
/// is a FunctionOpInterface and has passed verification.
+/// A non-null `typePrinter` is forwarded to
+/// `call_interface_impl::printFunctionSignature`.
-inline void printFunctionSignature(OpAsmPrinter &p, FunctionOpInterface op,
- ArrayRef<Type> argTypes, bool isVariadic,
- ArrayRef<Type> resultTypes) {
+inline void printFunctionSignature(
+ OpAsmPrinter &p, FunctionOpInterface op, ArrayRef<Type> argTypes,
+ bool isVariadic, ArrayRef<Type> resultTypes,
+ function_ref<void(OpAsmPrinter &, Type)> typePrinter = nullptr) {
call_interface_impl::printFunctionSignature(
p, argTypes, op.getArgAttrsAttr(), isVariadic, resultTypes,
op.getResAttrsAttr(), &op->getRegion(0),
- /*printEmptyResult=*/false);
+ /*printEmptyResult=*/false, typePrinter);
}
/// Prints the list of function prefixed with the "attributes" keyword.
diff --git a/mlir/lib/AsmParser/Parser.cpp b/mlir/lib/AsmParser/Parser.cpp
index 756d3d01a4534..06282f648549f 100644
--- a/mlir/lib/AsmParser/Parser.cpp
+++ b/mlir/lib/AsmParser/Parser.cpp
@@ -1830,10 +1830,14 @@ class CustomOpAsmParser : public AsmParserImpl<OpAsmParser> {
/// If `allowType` is false or `allowAttrs` are false then the respective
/// parts of the grammar are not parsed.
+ /// When `allowType` is true and `typeParser` is non-null, the `:` is still
+ /// parsed here and `typeParser` replaces the default type parser.
ParseResult parseArgument(Argument &result, bool allowType = false,
- bool allowAttrs = false) override {
+ bool allowAttrs = false,
+ function_ref<ParseResult(OpAsmParser &, Type &)>
+ typeParser = nullptr) override {
NamedAttrList attrs;
if (parseOperand(result.ssaName, /*allowResultNumber=*/false) ||
- (allowType && parseColonType(result.type)) ||
+ (allowType && !typeParser && parseColonType(result.type)) ||
+ (allowType && typeParser &&
+ (parseColon() || typeParser(*this, result.type))) ||
(allowAttrs && parseOptionalAttrDict(attrs)) ||
parseOptionalLocationSpecifier(result.sourceLoc))
return failure();
@@ -1842,10 +1846,12 @@ class CustomOpAsmParser : public AsmParserImpl<OpAsmParser> {
}
/// Parse a single argument if present.
+ /// `typeParser` is forwarded unchanged to `parseArgument`.
- OptionalParseResult parseOptionalArgument(Argument &result, bool allowType,
- bool allowAttrs) override {
+ OptionalParseResult parseOptionalArgument(
+ Argument &result, bool allowType, bool allowAttrs,
+ function_ref<ParseResult(OpAsmParser &, Type &)> typeParser =
+ nullptr) override {
if (parser.getToken().is(Token::percent_identifier))
- return parseArgument(result, allowType, allowAttrs);
+ return parseArgument(result, allowType, allowAttrs, typeParser);
return std::nullopt;
}
diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp
index f95ad290a1981..71a372369b4bf 100644
--- a/mlir/lib/IR/AsmPrinter.cpp
+++ b/mlir/lib/IR/AsmPrinter.cpp
@@ -783,9 +783,10 @@ class DummyAliasOperationPrinter : private OpAsmPrinter {
print(&b);
}
- void printRegionArgument(BlockArgument arg, ArrayRef<NamedAttribute> argAttrs,
- bool omitType) override {
- printType(arg.getType());
+ void printRegionArgument(
+ BlockArgument arg, ArrayRef<NamedAttribute> argAttrs, bool omitType,
+ function_ref<void(OpAsmPrinter &, Type)> typePrinter = nullptr) override {
+ // Hand `*this` to a custom `typePrinter` so whatever it prints still flows
+ // through this printer. NOTE(review): like the previous code, this ignores
+ // `omitType`; confirm that visiting the type unconditionally is intended.
+ typePrinter ? typePrinter(*this, arg.getType()) : printType(arg.getType());
// Visit the argument location.
if (printerFlags.shouldPrintDebugInfo())
// TODO: Allow deferring argument locations.
@@ -3295,9 +3296,10 @@ class OperationPrinter : public AsmPrinter::Impl, private OpAsmPrinter {
/// where location printing is controlled by the standard internal option.
/// You may pass omitType=true to not print a type, and pass an empty
/// attribute list if you don't care for attributes.
- void printRegionArgument(BlockArgument arg,
- ArrayRef<NamedAttribute> argAttrs = {},
- bool omitType = false) override;
+ void printRegionArgument(
+ BlockArgument arg, ArrayRef<NamedAttribute> argAttrs = {},
+ bool omitType = false,
+ function_ref<void(OpAsmPrinter &, Type)> typePrinter = nullptr) override;
/// Print the ID for the given value.
void printOperand(Value value) override { printValueID(value); }
@@ -3545,13 +3547,13 @@ void OperationPrinter::printResourceFileMetadata(
/// where location printing is controlled by the standard internal option.
/// You may pass omitType=true to not print a type, and pass an empty
/// attribute list if you don't care for attributes.
+/// A non-null `typePrinter` is used instead of `printType` for the argument
+/// type; it is not called when `omitType` is true.
-void OperationPrinter::printRegionArgument(BlockArgument arg,
- ArrayRef<NamedAttribute> argAttrs,
- bool omitType) {
+void OperationPrinter::printRegionArgument(
+ BlockArgument arg, ArrayRef<NamedAttribute> argAttrs, bool omitType,
+ function_ref<void(OpAsmPrinter &, Type)> typePrinter) {
printOperand(arg);
if (!omitType) {
os << ": ";
- printType(arg.getType());
+ typePrinter ? typePrinter(*this, arg.getType()) : printType(arg.getType());
}
printOptionalAttrDict(argAttrs);
// TODO: We should allow location aliases on block arguments.
diff --git a/mlir/lib/Interfaces/CallInterfaces.cpp b/mlir/lib/Interfaces/CallInterfaces.cpp
index e8ed4b339a0cb..a08338e514a08 100644
--- a/mlir/lib/Interfaces/CallInterfaces.cpp
+++ b/mlir/lib/Interfaces/CallInterfaces.cpp
@@ -15,15 +15,24 @@ using namespace mlir;
// Argument and result attributes utilities
//===----------------------------------------------------------------------===//
-static ParseResult
-parseTypeAndAttrList(OpAsmParser &parser, SmallVectorImpl<Type> &types,
- SmallVectorImpl<DictionaryAttr> &attrs) {
+/// Type parser used when callers pass a null `typeParser`.
+static inline ParseResult defaultTypeParser(OpAsmParser &parser, Type &ty) {
+ return parser.parseType(ty);
+}
+
+/// Type printer used when callers pass a null `typePrinter`.
+static inline void defaultTypePrinter(OpAsmPrinter &printer, Type ty) {
+ printer << ty;
+}
+
+/// Parse a comma-separated list of `type attribute-dict?` entries into `types`
+/// and `attrs`. `typeParser` is invoked unconditionally and must be non-null;
+/// `parseFunctionResultList` and `parseFunctionSignature` substitute
+/// `defaultTypeParser` for a null parser before calling this.
+static ParseResult parseTypeAndAttrList(
+ OpAsmParser &parser, SmallVectorImpl<Type> &types,
+ SmallVectorImpl<DictionaryAttr> &attrs,
+ function_ref<ParseResult(OpAsmParser &, Type &)> typeParser) {
// Parse individual function results.
return parser.parseCommaSeparatedList([&]() -> ParseResult {
types.emplace_back();
attrs.emplace_back();
NamedAttrList attrList;
- if (parser.parseType(types.back()) ||
+ if (typeParser(parser, types.back()) ||
parser.parseOptionalAttrDict(attrList))
return failure();
attrs.back() = attrList.getDictionary(parser.getContext());
@@ -33,12 +42,16 @@ parseTypeAndAttrList(OpAsmParser &parser, SmallVectorImpl<Type> &types,
ParseResult call_interface_impl::parseFunctionResultList(
OpAsmParser &parser, SmallVectorImpl<Type> &resultTypes,
- SmallVectorImpl<DictionaryAttr> &resultAttrs) {
+ SmallVectorImpl<DictionaryAttr> &resultAttrs,
+ function_ref<ParseResult(OpAsmParser &, Type &)> typeParser) {
+ if (!typeParser)
+ typeParser = defaultTypeParser;
+
if (failed(parser.parseOptionalLParen())) {
// We already know that there is no `(`, so parse a type.
// Because there is no `(`, it cannot be a function type.
Type ty;
- if (parser.parseType(ty))
+ if (typeParser(parser, ty))
return failure();
resultTypes.push_back(ty);
resultAttrs.emplace_back();
@@ -48,7 +61,7 @@ ParseResult call_interface_impl::parseFunctionResultList(
// Special case for an empty set of parens.
if (succeeded(parser.parseOptionalRParen()))
return success();
- if (parseTypeAndAttrList(parser, resultTypes, resultAttrs))
+ if (parseTypeAndAttrList(parser, resultTypes, resultAttrs, typeParser))
return failure();
return parser.parseRParen();
}
@@ -57,20 +70,24 @@ ParseResult call_interface_impl::parseFunctionSignature(
OpAsmParser &parser, SmallVectorImpl<Type> &argTypes,
SmallVectorImpl<DictionaryAttr> &argAttrs,
SmallVectorImpl<Type> &resultTypes,
- SmallVectorImpl<DictionaryAttr> &resultAttrs, bool mustParseEmptyResult) {
+ SmallVectorImpl<DictionaryAttr> &resultAttrs, bool mustParseEmptyResult,
+ function_ref<ParseResult(OpAsmParser &, Type &)> typeParser) {
+ if (!typeParser)
+ typeParser = defaultTypeParser;
+
// Parse arguments.
if (parser.parseLParen())
return failure();
if (failed(parser.parseOptionalRParen())) {
- if (parseTypeAndAttrList(parser, argTypes, argAttrs))
+ if (parseTypeAndAttrList(parser, argTypes, argAttrs, typeParser))
return failure();
if (parser.parseRParen())
return failure();
}
// Parse results.
if (succeeded(parser.parseOptionalArrow()))
- return call_interface_impl::parseFunctionResultList(parser, resultTypes,
- resultAttrs);
+ return call_interface_impl::parseFunctionResultList(
+ parser, resultTypes, resultAttrs, typeParser);
if (mustParseEmptyResult)
return failure();
return success();
@@ -78,8 +95,12 @@ ParseResult call_interface_impl::parseFunctionSignature(
/// Print a function result list. The provided `attrs` must either be null, or
/// contain a set of DictionaryAttrs of the same arity as `types`.
+/// A null `typePrinter` falls back to `defaultTypePrinter`.
-static void printFunctionResultList(OpAsmPrinter &p, TypeRange types,
- ArrayAttr attrs) {
+static void
+printFunctionResultList(OpAsmPrinter &p, TypeRange types, ArrayAttr attrs,
+ function_ref<void(OpAsmPrinter &, Type)> typePrinter) {
+ if (!typePrinter)
+ typePrinter = defaultTypePrinter;
+
assert(!types.empty() && "Should not be called for empty result list.");
assert((!attrs || attrs.size() == types.size()) &&
"Invalid number of attributes.");
@@ -90,22 +111,41 @@ static void printFunctionResultList(OpAsmPrinter &p, TypeRange types,
if (needsParens)
os << '(';
llvm::interleaveComma(llvm::seq<size_t>(0, types.size()), os, [&](size_t i) {
- p.printType(types[i]);
+ typePrinter(p, types[i]);
if (attrs)
p.printOptionalAttrDict(llvm::cast<DictionaryAttr>(attrs[i]).getValue());
});
if (needsParens)
os << ')';
}
+/// Print `(inputs) -> results`, using `typePrinter` for every type. A null
+/// `typePrinter` falls back to `defaultTypePrinter`. The ` -> ` separator is
+/// required by the function-signature grammar
+/// (`(` arg-list `)` `->` function-result-list). The results are wrapped in
+/// parentheses unless there is exactly one result and it is not a FunctionType.
+static void
+printFunctionalType(OpAsmPrinter &p, TypeRange inputs, TypeRange results,
+ function_ref<void(OpAsmPrinter &, Type)> typePrinter) {
+ if (!typePrinter)
+ typePrinter = defaultTypePrinter;
+
+ raw_ostream &os = p.getStream();
+ os << '(';
+ llvm::interleaveComma(inputs, os, [&](Type ty) { typePrinter(p, ty); });
+ os << ')';
+
+ // Without the arrow this would emit e.g. `(i32)i64`, which does not
+ // round-trip through parseFunctionSignature: it expects `->` before results.
+ os << " -> ";
+ bool wrapped = !llvm::hasSingleElement(results) ||
+ llvm::isa<FunctionType>(*results.begin());
+ if (wrapped)
+ os << '(';
+ llvm::interleaveComma(results, os, [&](Type ty) { typePrinter(p, ty); });
+ if (wrapped)
+ os << ')';
+}
+
void call_interface_impl::printFunctionSignature(
OpAsmPrinter &p, TypeRange argTypes, ArrayAttr argAttrs, bool isVariadic,
T...
[truncated]
``````````
</details>
https://github.com/llvm/llvm-project/pull/146979
More information about the Mlir-commits
mailing list