[Mlir-commits] [mlir] [mlir] add optional type functor to call and function interfaces (PR #146979)
Gibran Essa
llvmlistbot at llvm.org
Thu Jul 3 17:19:09 PDT 2025
https://github.com/gibrane created https://github.com/llvm/llvm-project/pull/146979
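Adds an optional custom type parser/printer functor to the call and function interface signature helpers, so function-like ops can use a custom syntax for the types in their signatures (e.g. `type:f32`).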
>From b4d8671055e492b689f30fa914ae747e2a2fd7ec Mon Sep 17 00:00:00 2001
From: Gibran Essa <gessa at nvidia.com>
Date: Wed, 2 Jul 2025 20:25:00 +0000
Subject: [PATCH] [mlir] add optional type functor to call and function
interfaces
---
mlir/include/mlir/IR/OpImplementation.h | 22 +++---
mlir/include/mlir/Interfaces/CallInterfaces.h | 30 +++----
.../mlir/Interfaces/FunctionImplementation.h | 28 ++++---
mlir/lib/AsmParser/Parser.cpp | 16 ++--
mlir/lib/IR/AsmPrinter.cpp | 22 +++---
mlir/lib/Interfaces/CallInterfaces.cpp | 79 ++++++++++++++-----
.../lib/Interfaces/FunctionImplementation.cpp | 46 +++++++----
.../custom-type-parse-and-print.mlir | 26 ++++++
mlir/test/lib/Dialect/Test/TestOpDefs.cpp | 34 ++++++++
mlir/test/lib/Dialect/Test/TestOps.td | 36 +++++++++
10 files changed, 256 insertions(+), 83 deletions(-)
create mode 100644 mlir/test/Interfaces/FunctionOpInterface/custom-type-parse-and-print.mlir
diff --git a/mlir/include/mlir/IR/OpImplementation.h b/mlir/include/mlir/IR/OpImplementation.h
index d70aa346eaa1f..6ebb021446e15 100644
--- a/mlir/include/mlir/IR/OpImplementation.h
+++ b/mlir/include/mlir/IR/OpImplementation.h
@@ -463,10 +463,13 @@ class OpAsmPrinter : public AsmPrinter {
/// where location printing is controlled by the standard internal option.
/// You may pass omitType=true to not print a type, and pass an empty
/// attribute list if you don't care for attributes.
- virtual void printRegionArgument(BlockArgument arg,
- ArrayRef<NamedAttribute> argAttrs = {},
- bool omitType = false) = 0;
+ /// If `typePrinter` is non-null, it is used to print the argument type
+ /// instead of the default type printer.
+ virtual void printRegionArgument(
+ BlockArgument arg, ArrayRef<NamedAttribute> argAttrs = {},
+ bool omitType = false,
+ function_ref<void(OpAsmPrinter &, Type)> typePrinter = nullptr) = 0;

/// Print implementations for various things an operation contains.
virtual void printOperand(Value value) = 0;
virtual void printOperand(Value value, raw_ostream &os) = 0;
@@ -1701,13 +1704,19 @@ class OpAsmParser : public AsmParser {
///
/// If `allowType` is false or `allowAttrs` are false then the respective
/// parts of the grammar are not parsed.
- virtual ParseResult parseArgument(Argument &result, bool allowType = false,
- bool allowAttrs = false) = 0;
+ /// When `allowType` is true and `typeParser` is non-null, `typeParser` parses
+ /// the type that follows the `:` instead of the default type parser.
+ virtual ParseResult
+ parseArgument(Argument &result, bool allowType = false,
+ bool allowAttrs = false,
+ function_ref<ParseResult(OpAsmParser &, Type &)> typeParser =
+ nullptr) = 0;
/// Parse a single argument if present.
- virtual OptionalParseResult
- parseOptionalArgument(Argument &result, bool allowType = false,
- bool allowAttrs = false) = 0;
+ virtual OptionalParseResult parseOptionalArgument(
+ Argument &result, bool allowType = false, bool allowAttrs = false,
+ function_ref<ParseResult(OpAsmParser &, Type &)> typeParser =
+ nullptr) = 0;
/// Parse zero or more arguments with a specified surrounding delimiter.
virtual ParseResult parseArgumentList(SmallVectorImpl<Argument> &result,
diff --git a/mlir/include/mlir/Interfaces/CallInterfaces.h b/mlir/include/mlir/Interfaces/CallInterfaces.h
index 2bf3a3ca5f8a8..78aa9181931e0 100644
--- a/mlir/include/mlir/Interfaces/CallInterfaces.h
+++ b/mlir/include/mlir/Interfaces/CallInterfaces.h
@@ -45,9 +45,12 @@ Operation *resolveCallable(CallOpInterface call,
/// function-result-list-no-parens ::= function-result (`,` function-result)*
/// function-result ::= type attribute-dict?
///
-ParseResult
-parseFunctionResultList(OpAsmParser &parser, SmallVectorImpl<Type> &resultTypes,
- SmallVectorImpl<DictionaryAttr> &resultAttrs);
+/// If `typeParser` is non-null, it is used to parse each result type instead
+/// of `OpAsmParser::parseType`.
+ParseResult parseFunctionResultList(
+ OpAsmParser &parser, SmallVectorImpl<Type> &resultTypes,
+ SmallVectorImpl<DictionaryAttr> &resultAttrs,
+ function_ref<ParseResult(OpAsmParser &, Type &)> typeParser = nullptr);
/// Parses a function signature using `parser`. This does not deal with function
/// signatures containing SSA region arguments (to parse these signatures, use
@@ -59,12 +62,16 @@ parseFunctionResultList(OpAsmParser &parser, SmallVectorImpl<Type> &resultTypes,
/// no-ssa-function-arg-list ::= no-ssa-function-arg
/// (`,` no-ssa-function-arg)*
/// no-ssa-function-arg ::= type attribute-dict?
-ParseResult parseFunctionSignature(OpAsmParser &parser,
- SmallVectorImpl<Type> &argTypes,
- SmallVectorImpl<DictionaryAttr> &argAttrs,
- SmallVectorImpl<Type> &resultTypes,
- SmallVectorImpl<DictionaryAttr> &resultAttrs,
- bool mustParseEmptyResult = true);
+///
+/// If `typeParser` is non-null, it is used to parse every argument and result
+/// type instead of `OpAsmParser::parseType`.
+ParseResult parseFunctionSignature(
+ OpAsmParser &parser, SmallVectorImpl<Type> &argTypes,
+ SmallVectorImpl<DictionaryAttr> &argAttrs,
+ SmallVectorImpl<Type> &resultTypes,
+ SmallVectorImpl<DictionaryAttr> &resultAttrs,
+ bool mustParseEmptyResult = true,
+ function_ref<ParseResult(OpAsmParser &, Type &)> typeParser = nullptr);
/// Print a function signature for a call or callable operation. If a body
/// region is provided, the SSA arguments are printed in the signature. When
@@ -77,11 +84,14 @@ ParseResult parseFunctionSignature(OpAsmParser &parser,
/// -> function-result-list
/// ssa-function-arg-list ::= ssa-function-arg (`,` ssa-function-arg)*
/// ssa-function-arg ::= `%`name `:` type attribute-dict?
-void printFunctionSignature(OpAsmPrinter &p, TypeRange argTypes,
- ArrayAttr argAttrs, bool isVariadic,
- TypeRange resultTypes, ArrayAttr resultAttrs,
- Region *body = nullptr,
- bool printEmptyResult = true);
+///
+/// If `typePrinter` is non-null, it is used to print every argument and result
+/// type instead of the default type printer.
+void printFunctionSignature(
+ OpAsmPrinter &p, TypeRange argTypes, ArrayAttr argAttrs, bool isVariadic,
+ TypeRange resultTypes, ArrayAttr resultAttrs, Region *body = nullptr,
+ bool printEmptyResult = true,
+ function_ref<void(OpAsmPrinter &, Type)> typePrinter = nullptr);
/// Adds argument and result attributes, provided as `argAttrs` and
/// `resultAttrs` arguments, to the list of operation attributes in `result`.
diff --git a/mlir/include/mlir/Interfaces/FunctionImplementation.h b/mlir/include/mlir/Interfaces/FunctionImplementation.h
index 374c2c534f87d..1e42e09936658 100644
--- a/mlir/include/mlir/Interfaces/FunctionImplementation.h
+++ b/mlir/include/mlir/Interfaces/FunctionImplementation.h
@@ -49,7 +49,8 @@ ParseResult parseFunctionSignatureWithArguments(
OpAsmParser &parser, bool allowVariadic,
SmallVectorImpl<OpAsmParser::Argument> &arguments, bool &isVariadic,
SmallVectorImpl<Type> &resultTypes,
- SmallVectorImpl<DictionaryAttr> &resultAttrs);
+ SmallVectorImpl<DictionaryAttr> &resultAttrs,
+ function_ref<ParseResult(OpAsmParser &, Type &)> typeParser = nullptr);
/// Parser implementation for function-like operations. Uses
/// `funcTypeBuilder` to construct the custom function type given lists of
@@ -59,25 +60,34 @@ ParseResult parseFunctionSignatureWithArguments(
/// whether the function is variadic. If the builder returns a null type,
/// `result` will not contain the `type` attribute. The caller can then add a
/// type, report the error or delegate the reporting to the op's verifier.
-ParseResult parseFunctionOp(OpAsmParser &parser, OperationState &result,
- bool allowVariadic, StringAttr typeAttrName,
- FuncTypeBuilder funcTypeBuilder,
- StringAttr argAttrsName, StringAttr resAttrsName);
+/// If `typeParser` is non-null, it is used to parse every argument and result
+/// type in the signature instead of `OpAsmParser::parseType`.
+ParseResult parseFunctionOp(
+ OpAsmParser &parser, OperationState &result, bool allowVariadic,
+ StringAttr typeAttrName, FuncTypeBuilder funcTypeBuilder,
+ StringAttr argAttrsName, StringAttr resAttrsName,
+ function_ref<ParseResult(OpAsmParser &, Type &)> typeParser = nullptr);
/// Printer implementation for function-like operations.
-void printFunctionOp(OpAsmPrinter &p, FunctionOpInterface op, bool isVariadic,
- StringRef typeAttrName, StringAttr argAttrsName,
- StringAttr resAttrsName);
+/// If `typePrinter` is non-null, it is used to print every argument and
+/// result type in the signature instead of the default type printer.
+void printFunctionOp(
+ OpAsmPrinter &p, FunctionOpInterface op, bool isVariadic,
+ StringRef typeAttrName, StringAttr argAttrsName, StringAttr resAttrsName,
+ function_ref<void(OpAsmPrinter &, Type)> typePrinter = nullptr);
/// Prints the signature of the function-like operation `op`. Assumes `op` has
/// is a FunctionOpInterface and has passed verification.
-inline void printFunctionSignature(OpAsmPrinter &p, FunctionOpInterface op,
- ArrayRef<Type> argTypes, bool isVariadic,
- ArrayRef<Type> resultTypes) {
+/// If `typePrinter` is non-null, it is forwarded to
+/// `call_interface_impl::printFunctionSignature` to print each type.
+inline void printFunctionSignature(
+ OpAsmPrinter &p, FunctionOpInterface op, ArrayRef<Type> argTypes,
+ bool isVariadic, ArrayRef<Type> resultTypes,
+ function_ref<void(OpAsmPrinter &, Type)> typePrinter = nullptr) {
call_interface_impl::printFunctionSignature(
p, argTypes, op.getArgAttrsAttr(), isVariadic, resultTypes,
op.getResAttrsAttr(), &op->getRegion(0),
- /*printEmptyResult=*/false);
+ /*printEmptyResult=*/false, typePrinter);
}
/// Prints the list of function prefixed with the "attributes" keyword. The
diff --git a/mlir/lib/AsmParser/Parser.cpp b/mlir/lib/AsmParser/Parser.cpp
index 756d3d01a4534..06282f648549f 100644
--- a/mlir/lib/AsmParser/Parser.cpp
+++ b/mlir/lib/AsmParser/Parser.cpp
@@ -1830,10 +1830,14 @@ class CustomOpAsmParser : public AsmParserImpl<OpAsmParser> {
/// If `allowType` is false or `allowAttrs` are false then the respective
/// parts of the grammar are not parsed.
ParseResult parseArgument(Argument &result, bool allowType = false,
- bool allowAttrs = false) override {
+ bool allowAttrs = false,
+ function_ref<ParseResult(OpAsmParser &, Type &)>
+ typeParser = nullptr) override {
NamedAttrList attrs;
if (parseOperand(result.ssaName, /*allowResultNumber=*/false) ||
- (allowType && parseColonType(result.type)) ||
+ (allowType &&
+ (parseColon() || (typeParser ? typeParser(*this, result.type)
+ : parseType(result.type)))) ||
(allowAttrs && parseOptionalAttrDict(attrs)) ||
parseOptionalLocationSpecifier(result.sourceLoc))
return failure();
@@ -1842,10 +1846,12 @@ class CustomOpAsmParser : public AsmParserImpl<OpAsmParser> {
}
/// Parse a single argument if present.
- OptionalParseResult parseOptionalArgument(Argument &result, bool allowType,
- bool allowAttrs) override {
+ OptionalParseResult parseOptionalArgument(
+ Argument &result, bool allowType, bool allowAttrs,
+ function_ref<ParseResult(OpAsmParser &, Type &)> typeParser =
+ nullptr) override {
if (parser.getToken().is(Token::percent_identifier))
- return parseArgument(result, allowType, allowAttrs);
+ return parseArgument(result, allowType, allowAttrs, typeParser);
return std::nullopt;
}
diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp
index f95ad290a1981..71a372369b4bf 100644
--- a/mlir/lib/IR/AsmPrinter.cpp
+++ b/mlir/lib/IR/AsmPrinter.cpp
@@ -783,9 +783,13 @@ class DummyAliasOperationPrinter : private OpAsmPrinter {
print(&b);
}
- void printRegionArgument(BlockArgument arg, ArrayRef<NamedAttribute> argAttrs,
- bool omitType) override {
- printType(arg.getType());
+ void printRegionArgument(
+ BlockArgument arg, ArrayRef<NamedAttribute> argAttrs, bool omitType,
+ function_ref<void(OpAsmPrinter &, Type)> typePrinter = nullptr) override {
+ if (typePrinter)
+ typePrinter(*this, arg.getType());
+ else
+ printType(arg.getType());
// Visit the argument location.
if (printerFlags.shouldPrintDebugInfo())
// TODO: Allow deferring argument locations.
@@ -3295,9 +3299,10 @@ class OperationPrinter : public AsmPrinter::Impl, private OpAsmPrinter {
/// where location printing is controlled by the standard internal option.
/// You may pass omitType=true to not print a type, and pass an empty
/// attribute list if you don't care for attributes.
- void printRegionArgument(BlockArgument arg,
- ArrayRef<NamedAttribute> argAttrs = {},
- bool omitType = false) override;
+ void printRegionArgument(
+ BlockArgument arg, ArrayRef<NamedAttribute> argAttrs = {},
+ bool omitType = false,
+ function_ref<void(OpAsmPrinter &, Type)> typePrinter = nullptr) override;
/// Print the ID for the given value.
void printOperand(Value value) override { printValueID(value); }
@@ -3545,13 +3550,16 @@ void OperationPrinter::printResourceFileMetadata(
/// where location printing is controlled by the standard internal option.
/// You may pass omitType=true to not print a type, and pass an empty
/// attribute list if you don't care for attributes.
-void OperationPrinter::printRegionArgument(BlockArgument arg,
- ArrayRef<NamedAttribute> argAttrs,
- bool omitType) {
+void OperationPrinter::printRegionArgument(
+ BlockArgument arg, ArrayRef<NamedAttribute> argAttrs, bool omitType,
+ function_ref<void(OpAsmPrinter &, Type)> typePrinter) {
printOperand(arg);
if (!omitType) {
os << ": ";
- printType(arg.getType());
+ if (typePrinter)
+ typePrinter(*this, arg.getType());
+ else
+ printType(arg.getType());
}
printOptionalAttrDict(argAttrs);
// TODO: We should allow location aliases on block arguments.
diff --git a/mlir/lib/Interfaces/CallInterfaces.cpp b/mlir/lib/Interfaces/CallInterfaces.cpp
index e8ed4b339a0cb..a08338e514a08 100644
--- a/mlir/lib/Interfaces/CallInterfaces.cpp
+++ b/mlir/lib/Interfaces/CallInterfaces.cpp
@@ -15,15 +15,24 @@ using namespace mlir;
// Argument and result attributes utilities
//===----------------------------------------------------------------------===//
-static ParseResult
-parseTypeAndAttrList(OpAsmParser &parser, SmallVectorImpl<Type> &types,
- SmallVectorImpl<DictionaryAttr> &attrs) {
+static ParseResult defaultTypeParser(OpAsmParser &parser, Type &ty) {
+ return parser.parseType(ty);
+}
+
+static void defaultTypePrinter(OpAsmPrinter &printer, Type ty) {
+ printer << ty;
+}
+
+static ParseResult parseTypeAndAttrList(
+ OpAsmParser &parser, SmallVectorImpl<Type> &types,
+ SmallVectorImpl<DictionaryAttr> &attrs,
+ function_ref<ParseResult(OpAsmParser &, Type &)> typeParser) {
// Parse individual function results.
return parser.parseCommaSeparatedList([&]() -> ParseResult {
types.emplace_back();
attrs.emplace_back();
NamedAttrList attrList;
- if (parser.parseType(types.back()) ||
+ if (typeParser(parser, types.back()) ||
parser.parseOptionalAttrDict(attrList))
return failure();
attrs.back() = attrList.getDictionary(parser.getContext());
@@ -33,12 +42,16 @@ parseTypeAndAttrList(OpAsmParser &parser, SmallVectorImpl<Type> &types,
ParseResult call_interface_impl::parseFunctionResultList(
OpAsmParser &parser, SmallVectorImpl<Type> &resultTypes,
- SmallVectorImpl<DictionaryAttr> &resultAttrs) {
+ SmallVectorImpl<DictionaryAttr> &resultAttrs,
+ function_ref<ParseResult(OpAsmParser &, Type &)> typeParser) {
+ if (!typeParser)
+ typeParser = defaultTypeParser;
+
if (failed(parser.parseOptionalLParen())) {
// We already know that there is no `(`, so parse a type.
// Because there is no `(`, it cannot be a function type.
Type ty;
- if (parser.parseType(ty))
+ if (typeParser(parser, ty))
return failure();
resultTypes.push_back(ty);
resultAttrs.emplace_back();
@@ -48,7 +61,7 @@ ParseResult call_interface_impl::parseFunctionResultList(
// Special case for an empty set of parens.
if (succeeded(parser.parseOptionalRParen()))
return success();
- if (parseTypeAndAttrList(parser, resultTypes, resultAttrs))
+ if (parseTypeAndAttrList(parser, resultTypes, resultAttrs, typeParser))
return failure();
return parser.parseRParen();
}
@@ -57,20 +70,24 @@ ParseResult call_interface_impl::parseFunctionSignature(
OpAsmParser &parser, SmallVectorImpl<Type> &argTypes,
SmallVectorImpl<DictionaryAttr> &argAttrs,
SmallVectorImpl<Type> &resultTypes,
- SmallVectorImpl<DictionaryAttr> &resultAttrs, bool mustParseEmptyResult) {
+ SmallVectorImpl<DictionaryAttr> &resultAttrs, bool mustParseEmptyResult,
+ function_ref<ParseResult(OpAsmParser &, Type &)> typeParser) {
+ if (!typeParser)
+ typeParser = defaultTypeParser;
+
// Parse arguments.
if (parser.parseLParen())
return failure();
if (failed(parser.parseOptionalRParen())) {
- if (parseTypeAndAttrList(parser, argTypes, argAttrs))
+ if (parseTypeAndAttrList(parser, argTypes, argAttrs, typeParser))
return failure();
if (parser.parseRParen())
return failure();
}
// Parse results.
if (succeeded(parser.parseOptionalArrow()))
- return call_interface_impl::parseFunctionResultList(parser, resultTypes,
- resultAttrs);
+ return call_interface_impl::parseFunctionResultList(
+ parser, resultTypes, resultAttrs, typeParser);
if (mustParseEmptyResult)
return failure();
return success();
@@ -78,8 +95,12 @@ ParseResult call_interface_impl::parseFunctionSignature(
/// Print a function result list. The provided `attrs` must either be null, or
/// contain a set of DictionaryAttrs of the same arity as `types`.
-static void printFunctionResultList(OpAsmPrinter &p, TypeRange types,
- ArrayAttr attrs) {
+static void
+printFunctionResultList(OpAsmPrinter &p, TypeRange types, ArrayAttr attrs,
+ function_ref<void(OpAsmPrinter &, Type)> typePrinter) {
+ if (!typePrinter)
+ typePrinter = defaultTypePrinter;
+
assert(!types.empty() && "Should not be called for empty result list.");
assert((!attrs || attrs.size() == types.size()) &&
"Invalid number of attributes.");
@@ -90,22 +111,45 @@ static void printFunctionResultList(OpAsmPrinter &p, TypeRange types,
if (needsParens)
os << '(';
llvm::interleaveComma(llvm::seq<size_t>(0, types.size()), os, [&](size_t i) {
- p.printType(types[i]);
+ typePrinter(p, types[i]);
if (attrs)
p.printOptionalAttrDict(llvm::cast<DictionaryAttr>(attrs[i]).getValue());
});
if (needsParens)
os << ')';
}
+/// Print `(inputs) -> results`, formatting every type with `typePrinter`. The
+/// result list is parenthesized unless it is a single non-function type; this
+/// is meant to match the `OpAsmPrinter::printFunctionalType` call it replaces.
+static void
+printFunctionalType(OpAsmPrinter &p, TypeRange inputs, TypeRange results,
+ function_ref<void(OpAsmPrinter &, Type)> typePrinter) {
+ p << '(';
+ llvm::interleaveComma(inputs, p, [&](Type ty) { typePrinter(p, ty); });
+ p << ')';
+
+ p << " -> ";
+ bool wrapped = !llvm::hasSingleElement(results) ||
+ llvm::isa<FunctionType>(*results.begin());
+ if (wrapped)
+ p << '(';
+ llvm::interleaveComma(results, p, [&](Type ty) { typePrinter(p, ty); });
+ if (wrapped)
+ p << ')';
+}
void call_interface_impl::printFunctionSignature(
OpAsmPrinter &p, TypeRange argTypes, ArrayAttr argAttrs, bool isVariadic,
TypeRange resultTypes, ArrayAttr resultAttrs, Region *body,
- bool printEmptyResult) {
+ bool printEmptyResult,
+ function_ref<void(OpAsmPrinter &, Type)> typePrinter) {
+ if (!typePrinter)
+ typePrinter = defaultTypePrinter;
+
bool isExternal = !body || body->empty();
if (!isExternal && !isVariadic && !argAttrs && !resultAttrs &&
printEmptyResult) {
- p.printFunctionalType(argTypes, resultTypes);
+ printFunctionalType(p, argTypes, resultTypes, typePrinter);
return;
}
@@ -118,9 +162,10 @@ void call_interface_impl::printFunctionSignature(
ArrayRef<NamedAttribute> attrs;
if (argAttrs)
attrs = llvm::cast<DictionaryAttr>(argAttrs[i]).getValue();
- p.printRegionArgument(body->getArgument(i), attrs);
+ p.printRegionArgument(body->getArgument(i), attrs, /*omitType=*/false,
+ typePrinter);
} else {
- p.printType(argTypes[i]);
+ typePrinter(p, argTypes[i]);
if (argAttrs)
p.printOptionalAttrDict(
llvm::cast<DictionaryAttr>(argAttrs[i]).getValue());
@@ -137,7 +182,7 @@ void call_interface_impl::printFunctionSignature(
if (!resultTypes.empty()) {
p << " -> ";
- printFunctionResultList(p, resultTypes, resultAttrs);
+ printFunctionResultList(p, resultTypes, resultAttrs, typePrinter);
} else if (printEmptyResult) {
p << " -> ()";
}
diff --git a/mlir/lib/Interfaces/FunctionImplementation.cpp b/mlir/lib/Interfaces/FunctionImplementation.cpp
index 90f32896e8181..89512e68bb22b 100644
--- a/mlir/lib/Interfaces/FunctionImplementation.cpp
+++ b/mlir/lib/Interfaces/FunctionImplementation.cpp
@@ -13,11 +13,21 @@
using namespace mlir;
-static ParseResult
-parseFunctionArgumentList(OpAsmParser &parser, bool allowVariadic,
- SmallVectorImpl<OpAsmParser::Argument> &arguments,
- bool &isVariadic) {
+static ParseResult defaultTypeParser(OpAsmParser &parser, Type &ty) {
+ return parser.parseType(ty);
+}
+
+static void defaultTypePrinter(OpAsmPrinter &printer, Type ty) {
+ printer << ty;
+}
+
+/// Parses a function's argument list (SSA-named arguments or a bare type
+/// list). `typeParser` is called directly and must be non-null.
+static ParseResult parseFunctionArgumentList(
+ OpAsmParser &parser, bool allowVariadic,
+ SmallVectorImpl<OpAsmParser::Argument> &arguments, bool &isVariadic,
+ function_ref<ParseResult(OpAsmParser &, Type &)> typeParser) {
// Parse the function arguments. The argument list either has to consistently
// have ssa-id's followed by types, or just be a type list. It isn't ok to
// sometimes have SSA ID's and sometimes not.
@@ -40,7 +50,7 @@ parseFunctionArgumentList(OpAsmParser &parser, bool allowVariadic,
// Parse argument name if present.
OpAsmParser::Argument argument;
auto argPresent = parser.parseOptionalArgument(
- argument, /*allowType=*/true, /*allowAttrs=*/true);
+ argument, /*allowType=*/true, /*allowAttrs=*/true, typeParser);
if (argPresent.has_value()) {
if (failed(argPresent.value()))
return failure(); // Present but malformed.
@@ -59,7 +69,7 @@ parseFunctionArgumentList(OpAsmParser &parser, bool allowVariadic,
"expected SSA identifier");
NamedAttrList attrs;
- if (parser.parseType(argument.type) ||
+ if (typeParser(parser, argument.type) ||
parser.parseOptionalAttrDict(attrs) ||
parser.parseOptionalLocationSpecifier(argument.sourceLoc))
return failure();
@@ -74,19 +84,25 @@ ParseResult function_interface_impl::parseFunctionSignatureWithArguments(
OpAsmParser &parser, bool allowVariadic,
SmallVectorImpl<OpAsmParser::Argument> &arguments, bool &isVariadic,
SmallVectorImpl<Type> &resultTypes,
- SmallVectorImpl<DictionaryAttr> &resultAttrs) {
- if (parseFunctionArgumentList(parser, allowVariadic, arguments, isVariadic))
+ SmallVectorImpl<DictionaryAttr> &resultAttrs,
+ function_ref<ParseResult(OpAsmParser &, Type &)> typeParser) {
+ if (!typeParser)
+ typeParser = defaultTypeParser;
+
+ if (parseFunctionArgumentList(parser, allowVariadic, arguments, isVariadic,
+ typeParser))
return failure();
if (succeeded(parser.parseOptionalArrow()))
- return call_interface_impl::parseFunctionResultList(parser, resultTypes,
- resultAttrs);
+ return call_interface_impl::parseFunctionResultList(
+ parser, resultTypes, resultAttrs, typeParser);
return success();
}
ParseResult function_interface_impl::parseFunctionOp(
OpAsmParser &parser, OperationState &result, bool allowVariadic,
StringAttr typeAttrName, FuncTypeBuilder funcTypeBuilder,
- StringAttr argAttrsName, StringAttr resAttrsName) {
+ StringAttr argAttrsName, StringAttr resAttrsName,
+ function_ref<ParseResult(OpAsmParser &, Type &)> typeParser) {
SmallVector<OpAsmParser::Argument> entryArgs;
SmallVector<DictionaryAttr> resultAttrs;
SmallVector<Type> resultTypes;
@@ -105,7 +121,8 @@ ParseResult function_interface_impl::parseFunctionOp(
SMLoc signatureLocation = parser.getCurrentLocation();
bool isVariadic = false;
if (parseFunctionSignatureWithArguments(parser, allowVariadic, entryArgs,
- isVariadic, resultTypes, resultAttrs))
+ isVariadic, resultTypes, resultAttrs,
+ typeParser))
return failure();
std::string errorMessage;
@@ -174,7 +191,11 @@ void function_interface_impl::printFunctionAttributes(
void function_interface_impl::printFunctionOp(
OpAsmPrinter &p, FunctionOpInterface op, bool isVariadic,
- StringRef typeAttrName, StringAttr argAttrsName, StringAttr resAttrsName) {
+ StringRef typeAttrName, StringAttr argAttrsName, StringAttr resAttrsName,
+ function_ref<void(OpAsmPrinter &, Type)> typePrinter) {
+ if (!typePrinter)
+ typePrinter = defaultTypePrinter;
+
// Print the operation and the function name.
auto funcName =
op->getAttrOfType<StringAttr>(SymbolTable::getSymbolAttrName())
@@ -188,7 +209,7 @@ void function_interface_impl::printFunctionOp(
ArrayRef<Type> argTypes = op.getArgumentTypes();
ArrayRef<Type> resultTypes = op.getResultTypes();
- printFunctionSignature(p, op, argTypes, isVariadic, resultTypes);
+ printFunctionSignature(p, op, argTypes, isVariadic, resultTypes, typePrinter);
printFunctionAttributes(
p, op, {visibilityAttrName, typeAttrName, argAttrsName, resAttrsName});
// Print the body if this is not an external function.
diff --git a/mlir/test/Interfaces/FunctionOpInterface/custom-type-parse-and-print.mlir b/mlir/test/Interfaces/FunctionOpInterface/custom-type-parse-and-print.mlir
new file mode 100644
index 0000000000000..8a19b6327a122
--- /dev/null
+++ b/mlir/test/Interfaces/FunctionOpInterface/custom-type-parse-and-print.mlir
@@ -0,0 +1,33 @@
+// RUN: mlir-opt %s | FileCheck %s
+
+// CHECK: test.custom_type_format_func @single_arg_no_return(
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: type:f32) {
+// CHECK-NEXT: }
+test.custom_type_format_func @single_arg_no_return(%arg0 : type:f32) {}
+
+// CHECK: test.custom_type_format_func @multiple_arg_multiple_return(
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: type:f32
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: type:f32)
+// CHECK-SAME: -> (type:f32, type:f32) {
+// CHECK-NEXT: }
+test.custom_type_format_func @multiple_arg_multiple_return(
+ %arg0 : type:f32, %arg1 : type:f32)
+ -> (type:f32, type:f32) {}
+
+// CHECK: test.custom_type_format_func @no_block
+// CHECK-SAME: (type:f32, type:f32)
+// CHECK-SAME: -> (type:f32, type:f32)
+test.custom_type_format_func @no_block(%arg0 : type:f32, %arg1 : type:f32)
+ -> (type:f32, type:f32)
+
+// CHECK: test.custom_type_format_func @one_return
+// CHECK-SAME: -> type:f32
+test.custom_type_format_func @one_return()
+ -> type:f32
+
+// CHECK: test.custom_type_format_func @arg_and_result_attrs(
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: type:f32 {test.arg_attr})
+// CHECK-SAME: -> (type:f32 {test.res_attr}) {
+// CHECK-NEXT: }
+test.custom_type_format_func @arg_and_result_attrs(
+ %arg0 : type:f32 {test.arg_attr}) -> (type:f32 {test.res_attr}) {}
diff --git a/mlir/test/lib/Dialect/Test/TestOpDefs.cpp b/mlir/test/lib/Dialect/Test/TestOpDefs.cpp
index 3ab4ef2680978..030cdd4083c25 100644
--- a/mlir/test/lib/Dialect/Test/TestOpDefs.cpp
+++ b/mlir/test/lib/Dialect/Test/TestOpDefs.cpp
@@ -10,9 +10,12 @@
#include "TestOps.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/IR/OpImplementation.h"
+#include "mlir/IR/OperationSupport.h"
#include "mlir/IR/Verifier.h"
#include "mlir/Interfaces/FunctionImplementation.h"
#include "mlir/Interfaces/MemorySlotInterfaces.h"
+#include "llvm/Support/LogicalResult.h"
using namespace mlir;
using namespace test;
@@ -1454,3 +1457,39 @@ test::TestCreateTensorOp::getBufferType(
return cast<mlir::bufferization::BufferLikeType>(test::TestMemrefType::get(
getContext(), type.getShape(), type.getElementType(), nullptr));
}
+
+//===----------------------------------------------------------------------===//
+// CustomTypeFormatFuncOp
+//===----------------------------------------------------------------------===//
+
+// Every argument and result type of this op is written with a `type:` prefix
+// (e.g. `type:f32`) to exercise the optional type parser/printer hooks of
+// function_interface_impl.
+ParseResult CustomTypeFormatFuncOp::parse(OpAsmParser &parser,
+ OperationState &result) {
+ auto buildFuncType =
+ [](Builder &builder, ArrayRef<Type> argTypes, ArrayRef<Type> results,
+ function_interface_impl::VariadicFlag,
+ std::string &) { return builder.getFunctionType(argTypes, results); };
+
+ auto typeParser = [](OpAsmParser &p, Type &ty) -> ParseResult {
+ if (p.parseKeyword("type") || p.parseColon())
+ return failure();
+ return p.parseType(ty);
+ };
+
+ return function_interface_impl::parseFunctionOp(
+ parser, result, /*allowVariadic=*/false,
+ getFunctionTypeAttrName(result.name), buildFuncType,
+ getArgAttrsAttrName(result.name), getResAttrsAttrName(result.name),
+ typeParser);
+}
+
+void CustomTypeFormatFuncOp::print(OpAsmPrinter &p) {
+ auto typePrinter = [](OpAsmPrinter &printer, Type ty) {
+ printer << "type:" << ty;
+ };
+ function_interface_impl::printFunctionOp(
+ p, *this, /*isVariadic=*/false, getFunctionTypeAttrName(),
+ getArgAttrsAttrName(), getResAttrsAttrName(), typePrinter);
+}
diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td
index 5320ba4ea3829..604b3f0683558 100644
--- a/mlir/test/lib/Dialect/Test/TestOps.td
+++ b/mlir/test/lib/Dialect/Test/TestOps.td
@@ -692,6 +692,42 @@ def FoldToCallOp : TEST_Op<"fold_to_call_op"> {
let hasCanonicalizer = 1;
}
+def CustomTypeFormatFuncOp
+ : TEST_Op<"custom_type_format_func", [FunctionOpInterface, NoTerminator]> {
+
+ let arguments = (ins TypeAttrOf<FunctionType>:$function_type,
+ OptionalAttr<DictArrayAttr>:$arg_attrs,
+ OptionalAttr<DictArrayAttr>:$res_attrs);
+
+ let regions = (region AnyRegion:$body);
+ let extraClassDeclaration = [{
+ //===------------------------------------------------------------------===//
+ // FunctionOpInterface Methods
+ //===------------------------------------------------------------------===//
+
+ /// Returns the region on the current operation that is callable. This may
+ /// return null in the case of an external callable object, e.g. an external
+ /// function.
+ ::mlir::Region *getCallableRegion() {
+ return getBody().empty() ? nullptr : &getBody();
+ }
+
+ /// Returns the argument types of this function.
+ ::mlir::ArrayRef<::mlir::Type> getArgumentTypes() {
+ return getFunctionType().getInputs();
+ }
+
+ /// Returns the result types of this function.
+ ::mlir::ArrayRef<::mlir::Type> getResultTypes() {
+ return getFunctionType().getResults();
+ }
+
+ /// Returns the number of results of this function.
+ unsigned getNumResults() {return getResultTypes().size();}
+ }];
+
+ let hasCustomAssemblyFormat = 1;
+}
//===----------------------------------------------------------------------===//
// Test Traits
//===----------------------------------------------------------------------===//
More information about the Mlir-commits
mailing list