[Mlir-commits] [mlir] [mlir] add optional type functor to call and function interfaces (PR #146979)

Gibran Essa llvmlistbot at llvm.org
Tue Jul 8 13:52:50 PDT 2025


https://github.com/gibrane updated https://github.com/llvm/llvm-project/pull/146979

>From ef05b4214273241697ddf72177487cd924933b40 Mon Sep 17 00:00:00 2001
From: Gibran Essa <gessa at nvidia.com>
Date: Wed, 2 Jul 2025 20:25:00 +0000
Subject: [PATCH] [mlir] add optional type functor to call and function
 interfaces

---
 mlir/include/mlir/IR/OpImplementation.h       | 24 +++---
 mlir/include/mlir/Interfaces/CallInterfaces.h | 38 +++++----
 .../mlir/Interfaces/FunctionImplementation.h  | 34 +++++---
 mlir/lib/AsmParser/Parser.cpp                 | 16 ++--
 mlir/lib/IR/AsmPrinter.cpp                    | 22 +++---
 mlir/lib/Interfaces/CallInterfaces.cpp        | 79 ++++++++++++++-----
 .../lib/Interfaces/FunctionImplementation.cpp | 46 +++++++----
 .../custom-type-parse-and-print.mlir          | 26 ++++++
 mlir/test/lib/Dialect/Test/TestOpDefs.cpp     | 34 ++++++++
 mlir/test/lib/Dialect/Test/TestOps.td         | 36 +++++++++
 10 files changed, 272 insertions(+), 83 deletions(-)
 create mode 100644 mlir/test/Interfaces/FunctionOpInterface/custom-type-parse-and-print.mlir

diff --git a/mlir/include/mlir/IR/OpImplementation.h b/mlir/include/mlir/IR/OpImplementation.h
index d70aa346eaa1f..5232a31ff5c77 100644
--- a/mlir/include/mlir/IR/OpImplementation.h
+++ b/mlir/include/mlir/IR/OpImplementation.h
@@ -463,10 +463,12 @@ class OpAsmPrinter : public AsmPrinter {
   /// where location printing is controlled by the standard internal option.
   /// You may pass omitType=true to not print a type, and pass an empty
   /// attribute list if you don't care for attributes.
-  virtual void printRegionArgument(BlockArgument arg,
-                                   ArrayRef<NamedAttribute> argAttrs = {},
-                                   bool omitType = false) = 0;
+  /// A non-null `typePrinter` is used in place of printType for the type.
+  virtual void printRegionArgument(
+      BlockArgument arg, ArrayRef<NamedAttribute> argAttrs = {},
+      bool omitType = false,
+      function_ref<void(OpAsmPrinter &, Type)> typePrinter = nullptr) = 0;
 
   /// Print implementations for various things an operation contains.
   virtual void printOperand(Value value) = 0;
   virtual void printOperand(Value value, raw_ostream &os) = 0;
@@ -1701,13 +1703,18 @@ class OpAsmParser : public AsmParser {
   ///
   /// If `allowType` is false or `allowAttrs` are false then the respective
   /// parts of the grammar are not parsed.
-  virtual ParseResult parseArgument(Argument &result, bool allowType = false,
-                                    bool allowAttrs = false) = 0;
+  /// A non-null `typeParser` overrides how the type after `:` is parsed.
+  virtual ParseResult
+  parseArgument(Argument &result, bool allowType = false,
+                bool allowAttrs = false,
+                function_ref<ParseResult(OpAsmParser &, Type &)> typeParser =
+                    nullptr) = 0;
 
   /// Parse a single argument if present.
-  virtual OptionalParseResult
-  parseOptionalArgument(Argument &result, bool allowType = false,
-                        bool allowAttrs = false) = 0;
+  virtual OptionalParseResult parseOptionalArgument(
+      Argument &result, bool allowType = false, bool allowAttrs = false,
+      function_ref<ParseResult(OpAsmParser &, Type &)> typeParser =
+          nullptr) = 0;
 
   /// Parse zero or more arguments with a specified surrounding delimiter.
   virtual ParseResult parseArgumentList(SmallVectorImpl<Argument> &result,
diff --git a/mlir/include/mlir/Interfaces/CallInterfaces.h b/mlir/include/mlir/Interfaces/CallInterfaces.h
index 2bf3a3ca5f8a8..66f0287471da5 100644
--- a/mlir/include/mlir/Interfaces/CallInterfaces.h
+++ b/mlir/include/mlir/Interfaces/CallInterfaces.h
@@ -37,6 +37,8 @@ Operation *resolveCallable(CallOpInterface call,
                            SymbolTableCollection *symbolTable = nullptr);
 
 /// Parse a function or call result list.
+/// A non-null `typeParser` replaces the default parsing of each result type
+/// (OpAsmParser::parseType).
 ///
 ///   function-result-list ::= function-result-list-parens
 ///                          | non-function-type
@@ -45,31 +47,37 @@ Operation *resolveCallable(CallOpInterface call,
 ///   function-result-list-no-parens ::= function-result (`,` function-result)*
 ///   function-result ::= type attribute-dict?
 ///
-ParseResult
-parseFunctionResultList(OpAsmParser &parser, SmallVectorImpl<Type> &resultTypes,
-                        SmallVectorImpl<DictionaryAttr> &resultAttrs);
+ParseResult parseFunctionResultList(
+    OpAsmParser &parser, SmallVectorImpl<Type> &resultTypes,
+    SmallVectorImpl<DictionaryAttr> &resultAttrs,
+    function_ref<ParseResult(OpAsmParser &, Type &)> typeParser = nullptr);
 
 /// Parses a function signature using `parser`. This does not deal with function
 /// signatures containing SSA region arguments (to parse these signatures, use
 /// function_interface_impl::parseFunctionSignature). When
 /// `mustParseEmptyResult`, `-> ()` is expected when there is no result type.
+/// A non-null `typeParser` replaces the default parsing of every argument and
+/// result type (OpAsmParser::parseType).
 ///
 ///   no-ssa-function-signature ::= `(` no-ssa-function-arg-list `)`
 ///                               -> function-result-list
 ///   no-ssa-function-arg-list  ::= no-ssa-function-arg
 ///                               (`,` no-ssa-function-arg)*
 ///   no-ssa-function-arg       ::= type attribute-dict?
-ParseResult parseFunctionSignature(OpAsmParser &parser,
-                                   SmallVectorImpl<Type> &argTypes,
-                                   SmallVectorImpl<DictionaryAttr> &argAttrs,
-                                   SmallVectorImpl<Type> &resultTypes,
-                                   SmallVectorImpl<DictionaryAttr> &resultAttrs,
-                                   bool mustParseEmptyResult = true);
+ParseResult parseFunctionSignature(
+    OpAsmParser &parser, SmallVectorImpl<Type> &argTypes,
+    SmallVectorImpl<DictionaryAttr> &argAttrs,
+    SmallVectorImpl<Type> &resultTypes,
+    SmallVectorImpl<DictionaryAttr> &resultAttrs,
+    bool mustParseEmptyResult = true,
+    function_ref<ParseResult(OpAsmParser &, Type &)> typeParser = nullptr);
 
 /// Print a function signature for a call or callable operation. If a body
 /// region is provided, the SSA arguments are printed in the signature. When
 /// `printEmptyResult` is false, `-> function-result-list` is omitted when
 /// `resultTypes` is empty.
+/// A non-null `typePrinter` replaces the default printing of every argument
+/// and result type.
 ///
 ///   function-signature     ::= ssa-function-signature
 ///                            | no-ssa-function-signature
@@ -77,11 +85,11 @@ ParseResult parseFunctionSignature(OpAsmParser &parser,
 ///                            -> function-result-list
 ///   ssa-function-arg-list  ::= ssa-function-arg (`,` ssa-function-arg)*
 ///   ssa-function-arg       ::= `%`name `:` type attribute-dict?
-void printFunctionSignature(OpAsmPrinter &p, TypeRange argTypes,
-                            ArrayAttr argAttrs, bool isVariadic,
-                            TypeRange resultTypes, ArrayAttr resultAttrs,
-                            Region *body = nullptr,
-                            bool printEmptyResult = true);
+void printFunctionSignature(
+    OpAsmPrinter &p, TypeRange argTypes, ArrayAttr argAttrs, bool isVariadic,
+    TypeRange resultTypes, ArrayAttr resultAttrs, Region *body = nullptr,
+    bool printEmptyResult = true,
+    function_ref<void(OpAsmPrinter &, Type)> typePrinter = nullptr);
 
 /// Adds argument and result attributes, provided as `argAttrs` and
 /// `resultAttrs` arguments, to the list of operation attributes in `result`.
diff --git a/mlir/include/mlir/Interfaces/FunctionImplementation.h b/mlir/include/mlir/Interfaces/FunctionImplementation.h
index 374c2c534f87d..de89a6bc0d50a 100644
--- a/mlir/include/mlir/Interfaces/FunctionImplementation.h
+++ b/mlir/include/mlir/Interfaces/FunctionImplementation.h
@@ -45,11 +45,14 @@ using FuncTypeBuilder = function_ref<Type(
 /// indicates whether functions with variadic arguments are supported. The
 /// trailing arguments are populated by this function with names, types,
 /// attributes and locations of the arguments and those of the results.
+/// A non-null `typeParser` replaces the default parsing of every argument and
+/// result type.
 ParseResult parseFunctionSignatureWithArguments(
     OpAsmParser &parser, bool allowVariadic,
     SmallVectorImpl<OpAsmParser::Argument> &arguments, bool &isVariadic,
     SmallVectorImpl<Type> &resultTypes,
-    SmallVectorImpl<DictionaryAttr> &resultAttrs);
+    SmallVectorImpl<DictionaryAttr> &resultAttrs,
+    function_ref<ParseResult(OpAsmParser &, Type &)> typeParser = nullptr);
 
 /// Parser implementation for function-like operations.  Uses
 /// `funcTypeBuilder` to construct the custom function type given lists of
@@ -59,25 +62,32 @@ ParseResult parseFunctionSignatureWithArguments(
 /// whether the function is variadic.  If the builder returns a null type,
 /// `result` will not contain the `type` attribute.  The caller can then add a
 /// type, report the error or delegate the reporting to the op's verifier.
-ParseResult parseFunctionOp(OpAsmParser &parser, OperationState &result,
-                            bool allowVariadic, StringAttr typeAttrName,
-                            FuncTypeBuilder funcTypeBuilder,
-                            StringAttr argAttrsName, StringAttr resAttrsName);
+/// A non-null `typeParser` replaces the default parsing of every argument and
+/// result type.
+ParseResult parseFunctionOp(
+    OpAsmParser &parser, OperationState &result, bool allowVariadic,
+    StringAttr typeAttrName, FuncTypeBuilder funcTypeBuilder,
+    StringAttr argAttrsName, StringAttr resAttrsName,
+    function_ref<ParseResult(OpAsmParser &, Type &)> typeParser = nullptr);
 
 /// Printer implementation for function-like operations.
-void printFunctionOp(OpAsmPrinter &p, FunctionOpInterface op, bool isVariadic,
-                     StringRef typeAttrName, StringAttr argAttrsName,
-                     StringAttr resAttrsName);
+/// A non-null `typePrinter` replaces the default printing of every argument
+/// and result type.
+void printFunctionOp(
+    OpAsmPrinter &p, FunctionOpInterface op, bool isVariadic,
+    StringRef typeAttrName, StringAttr argAttrsName, StringAttr resAttrsName,
+    function_ref<void(OpAsmPrinter &, Type)> typePrinter = nullptr);
 
 /// Prints the signature of the function-like operation `op`. Assumes `op` has
 /// is a FunctionOpInterface and has passed verification.
-inline void printFunctionSignature(OpAsmPrinter &p, FunctionOpInterface op,
-                                   ArrayRef<Type> argTypes, bool isVariadic,
-                                   ArrayRef<Type> resultTypes) {
+inline void printFunctionSignature(
+    OpAsmPrinter &p, FunctionOpInterface op, ArrayRef<Type> argTypes,
+    bool isVariadic, ArrayRef<Type> resultTypes,
+    function_ref<void(OpAsmPrinter &, Type)> typePrinter = nullptr) {
   call_interface_impl::printFunctionSignature(
       p, argTypes, op.getArgAttrsAttr(), isVariadic, resultTypes,
       op.getResAttrsAttr(), &op->getRegion(0),
-      /*printEmptyResult=*/false);
+      /*printEmptyResult=*/false, typePrinter);
 }
 
 /// Prints the list of function prefixed with the "attributes" keyword. The
diff --git a/mlir/lib/AsmParser/Parser.cpp b/mlir/lib/AsmParser/Parser.cpp
index 756d3d01a4534..06282f648549f 100644
--- a/mlir/lib/AsmParser/Parser.cpp
+++ b/mlir/lib/AsmParser/Parser.cpp
@@ -1830,10 +1830,15 @@ class CustomOpAsmParser : public AsmParserImpl<OpAsmParser> {
   /// If `allowType` is false or `allowAttrs` are false then the respective
   /// parts of the grammar are not parsed.
   ParseResult parseArgument(Argument &result, bool allowType = false,
-                            bool allowAttrs = false) override {
+                            bool allowAttrs = false,
+                            function_ref<ParseResult(OpAsmParser &, Type &)>
+                                typeParser = nullptr) override {
     NamedAttrList attrs;
+    // A non-null `typeParser` replaces parseType for the type after the `:`.
     if (parseOperand(result.ssaName, /*allowResultNumber=*/false) ||
-        (allowType && parseColonType(result.type)) ||
+        (allowType &&
+         (parseColon() || (typeParser ? typeParser(*this, result.type)
+                                      : parseType(result.type)))) ||
         (allowAttrs && parseOptionalAttrDict(attrs)) ||
         parseOptionalLocationSpecifier(result.sourceLoc))
       return failure();
@@ -1842,10 +1847,12 @@ class CustomOpAsmParser : public AsmParserImpl<OpAsmParser> {
   }
 
   /// Parse a single argument if present.
-  OptionalParseResult parseOptionalArgument(Argument &result, bool allowType,
-                                            bool allowAttrs) override {
+  OptionalParseResult parseOptionalArgument(
+      Argument &result, bool allowType, bool allowAttrs,
+      function_ref<ParseResult(OpAsmParser &, Type &)> typeParser =
+          nullptr) override {
     if (parser.getToken().is(Token::percent_identifier))
-      return parseArgument(result, allowType, allowAttrs);
+      return parseArgument(result, allowType, allowAttrs, typeParser);
     return std::nullopt;
   }
 
diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp
index f95ad290a1981..71a372369b4bf 100644
--- a/mlir/lib/IR/AsmPrinter.cpp
+++ b/mlir/lib/IR/AsmPrinter.cpp
@@ -783,9 +783,13 @@ class DummyAliasOperationPrinter : private OpAsmPrinter {
       print(&b);
   }
 
-  void printRegionArgument(BlockArgument arg, ArrayRef<NamedAttribute> argAttrs,
-                           bool omitType) override {
-    printType(arg.getType());
+  void printRegionArgument(
+      BlockArgument arg, ArrayRef<NamedAttribute> argAttrs, bool omitType,
+      function_ref<void(OpAsmPrinter &, Type)> typePrinter = nullptr) override {
+    if (typePrinter)
+      typePrinter(*this, arg.getType());
+    else
+      printType(arg.getType());
     // Visit the argument location.
     if (printerFlags.shouldPrintDebugInfo())
       // TODO: Allow deferring argument locations.
@@ -3295,9 +3299,11 @@ class OperationPrinter : public AsmPrinter::Impl, private OpAsmPrinter {
   /// where location printing is controlled by the standard internal option.
   /// You may pass omitType=true to not print a type, and pass an empty
   /// attribute list if you don't care for attributes.
-  void printRegionArgument(BlockArgument arg,
-                           ArrayRef<NamedAttribute> argAttrs = {},
-                           bool omitType = false) override;
+  /// A non-null `typePrinter` is used in place of printType for the type.
+  void printRegionArgument(
+      BlockArgument arg, ArrayRef<NamedAttribute> argAttrs = {},
+      bool omitType = false,
+      function_ref<void(OpAsmPrinter &, Type)> typePrinter = nullptr) override;
 
   /// Print the ID for the given value.
   void printOperand(Value value) override { printValueID(value); }
@@ -3545,13 +3551,16 @@ void OperationPrinter::printResourceFileMetadata(
 /// where location printing is controlled by the standard internal option.
 /// You may pass omitType=true to not print a type, and pass an empty
 /// attribute list if you don't care for attributes.
-void OperationPrinter::printRegionArgument(BlockArgument arg,
-                                           ArrayRef<NamedAttribute> argAttrs,
-                                           bool omitType) {
+void OperationPrinter::printRegionArgument(
+    BlockArgument arg, ArrayRef<NamedAttribute> argAttrs, bool omitType,
+    function_ref<void(OpAsmPrinter &, Type)> typePrinter) {
   printOperand(arg);
   if (!omitType) {
     os << ": ";
-    printType(arg.getType());
+    if (typePrinter)
+      typePrinter(*this, arg.getType());
+    else
+      printType(arg.getType());
   }
   printOptionalAttrDict(argAttrs);
   // TODO: We should allow location aliases on block arguments.
diff --git a/mlir/lib/Interfaces/CallInterfaces.cpp b/mlir/lib/Interfaces/CallInterfaces.cpp
index e8ed4b339a0cb..a08338e514a08 100644
--- a/mlir/lib/Interfaces/CallInterfaces.cpp
+++ b/mlir/lib/Interfaces/CallInterfaces.cpp
@@ -15,15 +15,26 @@ using namespace mlir;
 // Argument and result attributes utilities
 //===----------------------------------------------------------------------===//
 
-static ParseResult
-parseTypeAndAttrList(OpAsmParser &parser, SmallVectorImpl<Type> &types,
-                     SmallVectorImpl<DictionaryAttr> &attrs) {
+/// Type parser used when callers do not supply one: OpAsmParser::parseType.
+static ParseResult defaultTypeParser(OpAsmParser &parser, Type &ty) {
+  return parser.parseType(ty);
+}
+
+/// Type printer used when callers do not supply one: streams the type.
+static void defaultTypePrinter(OpAsmPrinter &printer, Type ty) {
+  printer << ty;
+}
+
+static ParseResult parseTypeAndAttrList(
+    OpAsmParser &parser, SmallVectorImpl<Type> &types,
+    SmallVectorImpl<DictionaryAttr> &attrs,
+    function_ref<ParseResult(OpAsmParser &, Type &)> typeParser) {
   // Parse individual function results.
   return parser.parseCommaSeparatedList([&]() -> ParseResult {
     types.emplace_back();
     attrs.emplace_back();
     NamedAttrList attrList;
-    if (parser.parseType(types.back()) ||
+    if (typeParser(parser, types.back()) ||
         parser.parseOptionalAttrDict(attrList))
       return failure();
     attrs.back() = attrList.getDictionary(parser.getContext());
@@ -33,12 +44,16 @@ parseTypeAndAttrList(OpAsmParser &parser, SmallVectorImpl<Type> &types,
 
 ParseResult call_interface_impl::parseFunctionResultList(
     OpAsmParser &parser, SmallVectorImpl<Type> &resultTypes,
-    SmallVectorImpl<DictionaryAttr> &resultAttrs) {
+    SmallVectorImpl<DictionaryAttr> &resultAttrs,
+    function_ref<ParseResult(OpAsmParser &, Type &)> typeParser) {
+  if (!typeParser)
+    typeParser = defaultTypeParser;
+
   if (failed(parser.parseOptionalLParen())) {
     // We already know that there is no `(`, so parse a type.
     // Because there is no `(`, it cannot be a function type.
     Type ty;
-    if (parser.parseType(ty))
+    if (typeParser(parser, ty))
       return failure();
     resultTypes.push_back(ty);
     resultAttrs.emplace_back();
@@ -48,7 +63,7 @@ ParseResult call_interface_impl::parseFunctionResultList(
   // Special case for an empty set of parens.
   if (succeeded(parser.parseOptionalRParen()))
     return success();
-  if (parseTypeAndAttrList(parser, resultTypes, resultAttrs))
+  if (parseTypeAndAttrList(parser, resultTypes, resultAttrs, typeParser))
     return failure();
   return parser.parseRParen();
 }
@@ -57,20 +72,24 @@ ParseResult call_interface_impl::parseFunctionSignature(
     OpAsmParser &parser, SmallVectorImpl<Type> &argTypes,
     SmallVectorImpl<DictionaryAttr> &argAttrs,
     SmallVectorImpl<Type> &resultTypes,
-    SmallVectorImpl<DictionaryAttr> &resultAttrs, bool mustParseEmptyResult) {
+    SmallVectorImpl<DictionaryAttr> &resultAttrs, bool mustParseEmptyResult,
+    function_ref<ParseResult(OpAsmParser &, Type &)> typeParser) {
+  if (!typeParser)
+    typeParser = defaultTypeParser;
+
   // Parse arguments.
   if (parser.parseLParen())
     return failure();
   if (failed(parser.parseOptionalRParen())) {
-    if (parseTypeAndAttrList(parser, argTypes, argAttrs))
+    if (parseTypeAndAttrList(parser, argTypes, argAttrs, typeParser))
       return failure();
     if (parser.parseRParen())
       return failure();
   }
   // Parse results.
   if (succeeded(parser.parseOptionalArrow()))
-    return call_interface_impl::parseFunctionResultList(parser, resultTypes,
-                                                        resultAttrs);
+    return call_interface_impl::parseFunctionResultList(
+        parser, resultTypes, resultAttrs, typeParser);
   if (mustParseEmptyResult)
     return failure();
   return success();
@@ -78,8 +97,12 @@ ParseResult call_interface_impl::parseFunctionSignature(
 
 /// Print a function result list. The provided `attrs` must either be null, or
 /// contain a set of DictionaryAttrs of the same arity as `types`.
-static void printFunctionResultList(OpAsmPrinter &p, TypeRange types,
-                                    ArrayAttr attrs) {
+static void
+printFunctionResultList(OpAsmPrinter &p, TypeRange types, ArrayAttr attrs,
+                        function_ref<void(OpAsmPrinter &, Type)> typePrinter) {
+  if (!typePrinter)
+    typePrinter = defaultTypePrinter;
+
   assert(!types.empty() && "Should not be called for empty result list.");
   assert((!attrs || attrs.size() == types.size()) &&
          "Invalid number of attributes.");
@@ -90,22 +113,45 @@ static void printFunctionResultList(OpAsmPrinter &p, TypeRange types,
   if (needsParens)
     os << '(';
   llvm::interleaveComma(llvm::seq<size_t>(0, types.size()), os, [&](size_t i) {
-    p.printType(types[i]);
+    typePrinter(p, types[i]);
     if (attrs)
       p.printOptionalAttrDict(llvm::cast<DictionaryAttr>(attrs[i]).getValue());
   });
   if (needsParens)
     os << ')';
 }
+
+/// Print `(inputs) -> results`, routing every type through `typePrinter`.
+/// The result list is parenthesized unless it consists of exactly one
+/// non-function type.
+static void
+printFunctionalType(OpAsmPrinter &p, TypeRange inputs, TypeRange results,
+                    function_ref<void(OpAsmPrinter &, Type)> typePrinter) {
+  p << '(';
+  llvm::interleaveComma(inputs, p, [&](Type ty) { typePrinter(p, ty); });
+  p << ") -> ";
+
+  bool wrapped = !llvm::hasSingleElement(results) ||
+                 llvm::isa<FunctionType>((*results.begin()));
+  if (wrapped)
+    p << '(';
+  llvm::interleaveComma(results, p, [&](Type ty) { typePrinter(p, ty); });
+  if (wrapped)
+    p << ')';
+}
 
 void call_interface_impl::printFunctionSignature(
     OpAsmPrinter &p, TypeRange argTypes, ArrayAttr argAttrs, bool isVariadic,
     TypeRange resultTypes, ArrayAttr resultAttrs, Region *body,
-    bool printEmptyResult) {
+    bool printEmptyResult,
+    function_ref<void(OpAsmPrinter &, Type)> typePrinter) {
+  if (!typePrinter)
+    typePrinter = defaultTypePrinter;
+
   bool isExternal = !body || body->empty();
   if (!isExternal && !isVariadic && !argAttrs && !resultAttrs &&
       printEmptyResult) {
-    p.printFunctionalType(argTypes, resultTypes);
+    printFunctionalType(p, argTypes, resultTypes, typePrinter);
     return;
   }
 
@@ -118,9 +164,10 @@ void call_interface_impl::printFunctionSignature(
       ArrayRef<NamedAttribute> attrs;
       if (argAttrs)
         attrs = llvm::cast<DictionaryAttr>(argAttrs[i]).getValue();
-      p.printRegionArgument(body->getArgument(i), attrs);
+      p.printRegionArgument(body->getArgument(i), attrs, /*omitType=*/false,
+                            typePrinter);
     } else {
-      p.printType(argTypes[i]);
+      typePrinter(p, argTypes[i]);
       if (argAttrs)
         p.printOptionalAttrDict(
             llvm::cast<DictionaryAttr>(argAttrs[i]).getValue());
@@ -137,7 +184,7 @@ void call_interface_impl::printFunctionSignature(
 
   if (!resultTypes.empty()) {
     p << " -> ";
-    printFunctionResultList(p, resultTypes, resultAttrs);
+    printFunctionResultList(p, resultTypes, resultAttrs, typePrinter);
   } else if (printEmptyResult) {
     p << " -> ()";
   }
diff --git a/mlir/lib/Interfaces/FunctionImplementation.cpp b/mlir/lib/Interfaces/FunctionImplementation.cpp
index 90f32896e8181..89512e68bb22b 100644
--- a/mlir/lib/Interfaces/FunctionImplementation.cpp
+++ b/mlir/lib/Interfaces/FunctionImplementation.cpp
@@ -13,11 +13,20 @@
 
 using namespace mlir;
 
-static ParseResult
-parseFunctionArgumentList(OpAsmParser &parser, bool allowVariadic,
-                          SmallVectorImpl<OpAsmParser::Argument> &arguments,
-                          bool &isVariadic) {
+/// Type parser used when callers do not supply one: OpAsmParser::parseType.
+static ParseResult defaultTypeParser(OpAsmParser &parser, Type &ty) {
+  return parser.parseType(ty);
+}
+
+/// Type printer used when callers do not supply one: streams the type.
+static void defaultTypePrinter(OpAsmPrinter &printer, Type ty) {
+  printer << ty;
+}
 
+static ParseResult parseFunctionArgumentList(
+    OpAsmParser &parser, bool allowVariadic,
+    SmallVectorImpl<OpAsmParser::Argument> &arguments, bool &isVariadic,
+    function_ref<ParseResult(OpAsmParser &, Type &)> typeParser) {
   // Parse the function arguments.  The argument list either has to consistently
   // have ssa-id's followed by types, or just be a type list.  It isn't ok to
   // sometimes have SSA ID's and sometimes not.
@@ -40,7 +49,7 @@ parseFunctionArgumentList(OpAsmParser &parser, bool allowVariadic,
         // Parse argument name if present.
         OpAsmParser::Argument argument;
         auto argPresent = parser.parseOptionalArgument(
-            argument, /*allowType=*/true, /*allowAttrs=*/true);
+            argument, /*allowType=*/true, /*allowAttrs=*/true, typeParser);
         if (argPresent.has_value()) {
           if (failed(argPresent.value()))
             return failure(); // Present but malformed.
@@ -59,7 +68,7 @@ parseFunctionArgumentList(OpAsmParser &parser, bool allowVariadic,
                                     "expected SSA identifier");
 
           NamedAttrList attrs;
-          if (parser.parseType(argument.type) ||
+          if (typeParser(parser, argument.type) ||
               parser.parseOptionalAttrDict(attrs) ||
               parser.parseOptionalLocationSpecifier(argument.sourceLoc))
             return failure();
@@ -74,19 +83,25 @@ ParseResult function_interface_impl::parseFunctionSignatureWithArguments(
     OpAsmParser &parser, bool allowVariadic,
     SmallVectorImpl<OpAsmParser::Argument> &arguments, bool &isVariadic,
     SmallVectorImpl<Type> &resultTypes,
-    SmallVectorImpl<DictionaryAttr> &resultAttrs) {
-  if (parseFunctionArgumentList(parser, allowVariadic, arguments, isVariadic))
+    SmallVectorImpl<DictionaryAttr> &resultAttrs,
+    function_ref<ParseResult(OpAsmParser &, Type &)> typeParser) {
+  if (!typeParser)
+    typeParser = defaultTypeParser;
+
+  if (parseFunctionArgumentList(parser, allowVariadic, arguments, isVariadic,
+                                typeParser))
     return failure();
   if (succeeded(parser.parseOptionalArrow()))
-    return call_interface_impl::parseFunctionResultList(parser, resultTypes,
-                                                        resultAttrs);
+    return call_interface_impl::parseFunctionResultList(
+        parser, resultTypes, resultAttrs, typeParser);
   return success();
 }
 
 ParseResult function_interface_impl::parseFunctionOp(
     OpAsmParser &parser, OperationState &result, bool allowVariadic,
     StringAttr typeAttrName, FuncTypeBuilder funcTypeBuilder,
-    StringAttr argAttrsName, StringAttr resAttrsName) {
+    StringAttr argAttrsName, StringAttr resAttrsName,
+    function_ref<ParseResult(OpAsmParser &, Type &)> typeParser) {
   SmallVector<OpAsmParser::Argument> entryArgs;
   SmallVector<DictionaryAttr> resultAttrs;
   SmallVector<Type> resultTypes;
@@ -105,7 +120,8 @@ ParseResult function_interface_impl::parseFunctionOp(
   SMLoc signatureLocation = parser.getCurrentLocation();
   bool isVariadic = false;
   if (parseFunctionSignatureWithArguments(parser, allowVariadic, entryArgs,
-                                          isVariadic, resultTypes, resultAttrs))
+                                          isVariadic, resultTypes, resultAttrs,
+                                          typeParser))
     return failure();
 
   std::string errorMessage;
@@ -174,7 +190,11 @@ void function_interface_impl::printFunctionAttributes(
 
 void function_interface_impl::printFunctionOp(
     OpAsmPrinter &p, FunctionOpInterface op, bool isVariadic,
-    StringRef typeAttrName, StringAttr argAttrsName, StringAttr resAttrsName) {
+    StringRef typeAttrName, StringAttr argAttrsName, StringAttr resAttrsName,
+    function_ref<void(OpAsmPrinter &, Type)> typePrinter) {
+  if (!typePrinter)
+    typePrinter = defaultTypePrinter;
+
   // Print the operation and the function name.
   auto funcName =
       op->getAttrOfType<StringAttr>(SymbolTable::getSymbolAttrName())
@@ -188,7 +208,7 @@ void function_interface_impl::printFunctionOp(
 
   ArrayRef<Type> argTypes = op.getArgumentTypes();
   ArrayRef<Type> resultTypes = op.getResultTypes();
-  printFunctionSignature(p, op, argTypes, isVariadic, resultTypes);
+  printFunctionSignature(p, op, argTypes, isVariadic, resultTypes, typePrinter);
   printFunctionAttributes(
       p, op, {visibilityAttrName, typeAttrName, argAttrsName, resAttrsName});
   // Print the body if this is not an external function.
diff --git a/mlir/test/Interfaces/FunctionOpInterface/custom-type-parse-and-print.mlir b/mlir/test/Interfaces/FunctionOpInterface/custom-type-parse-and-print.mlir
new file mode 100644
index 0000000000000..8a19b6327a122
--- /dev/null
+++ b/mlir/test/Interfaces/FunctionOpInterface/custom-type-parse-and-print.mlir
@@ -0,0 +1,36 @@
+// RUN: mlir-opt %s | FileCheck %s
+
+// Round-trip tests for FunctionOpInterface parsing/printing with a custom
+// type parser/printer: every signature type is written as `type:<type>`.
+
+//      CHECK: test.custom_type_format_func @single_arg_no_return(
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: type:f32) {
+// CHECK-NEXT: }
+test.custom_type_format_func @single_arg_no_return(%arg0 : type:f32) {}
+
+//      CHECK: test.custom_type_format_func @multiple_arg_multiple_return(
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: type:f32
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: type:f32)
+// CHECK-SAME: -> (type:f32, type:f32) {
+// CHECK-NEXT: }
+test.custom_type_format_func @multiple_arg_multiple_return(
+    %arg0 : type:f32, %arg1 : type:f32)
+    -> (type:f32, type:f32) {}
+
+//      CHECK: test.custom_type_format_func @no_block
+// CHECK-SAME: (type:f32, type:f32)
+// CHECK-SAME: -> (type:f32, type:f32)
+test.custom_type_format_func @no_block(%arg0 : type:f32, %arg1 : type:f32)
+    -> (type:f32, type:f32)
+
+//      CHECK: test.custom_type_format_func @one_return
+// CHECK-SAME: -> type:f32
+test.custom_type_format_func @one_return()
+    -> type:f32
+
+//      CHECK: test.custom_type_format_func @arg_and_result_attrs(
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: type:f32 {test.arg_attr})
+// CHECK-SAME: -> (type:f32 {test.res_attr}) {
+// CHECK-NEXT: }
+test.custom_type_format_func @arg_and_result_attrs(
+    %arg0 : type:f32 {test.arg_attr}) -> (type:f32 {test.res_attr}) {}
diff --git a/mlir/test/lib/Dialect/Test/TestOpDefs.cpp b/mlir/test/lib/Dialect/Test/TestOpDefs.cpp
index 3ab4ef2680978..030cdd4083c25 100644
--- a/mlir/test/lib/Dialect/Test/TestOpDefs.cpp
+++ b/mlir/test/lib/Dialect/Test/TestOpDefs.cpp
@@ -10,9 +10,12 @@
 #include "TestOps.h"
 #include "mlir/Dialect/Bufferization/IR/Bufferization.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/IR/OpImplementation.h"
+#include "mlir/IR/OperationSupport.h"
 #include "mlir/IR/Verifier.h"
 #include "mlir/Interfaces/FunctionImplementation.h"
 #include "mlir/Interfaces/MemorySlotInterfaces.h"
+#include "llvm/Support/LogicalResult.h"
 
 using namespace mlir;
 using namespace test;
@@ -1454,3 +1457,38 @@ test::TestCreateTensorOp::getBufferType(
   return cast<mlir::bufferization::BufferLikeType>(test::TestMemrefType::get(
       getContext(), type.getShape(), type.getElementType(), nullptr));
 }
+
+//===----------------------------------------------------------------------===//
+// CustomTypeFormatFuncOp
+//===----------------------------------------------------------------------===//
+
+ParseResult CustomTypeFormatFuncOp::parse(OpAsmParser &parser,
+                                          OperationState &result) {
+  auto buildFuncType =
+      [](Builder &builder, ArrayRef<Type> argTypes, ArrayRef<Type> results,
+         function_interface_impl::VariadicFlag,
+         std::string &) { return builder.getFunctionType(argTypes, results); };
+
+  // Every argument and result type is spelled `type:<type>`.
+  auto typeParser = [](OpAsmParser &p, Type &ty) -> ParseResult {
+    if (p.parseKeyword("type") || p.parseColon())
+      return failure();
+    return p.parseType(ty);
+  };
+
+  return function_interface_impl::parseFunctionOp(
+      parser, result, /*allowVariadic=*/false,
+      getFunctionTypeAttrName(result.name), buildFuncType,
+      getArgAttrsAttrName(result.name), getResAttrsAttrName(result.name),
+      typeParser);
+}
+
+void CustomTypeFormatFuncOp::print(OpAsmPrinter &p) {
+  // Inverse of the parser's `typeParser`: prefix every type with `type:`.
+  auto typePrinter = [](OpAsmPrinter &printer, Type ty) {
+    printer << "type:" << ty;
+  };
+  function_interface_impl::printFunctionOp(
+      p, *this, /*isVariadic=*/false, getFunctionTypeAttrName(),
+      getArgAttrsAttrName(), getResAttrsAttrName(), typePrinter);
+}
diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td
index 535f5e9b4a15d..e10ac79e5d1fc 100644
--- a/mlir/test/lib/Dialect/Test/TestOps.td
+++ b/mlir/test/lib/Dialect/Test/TestOps.td
@@ -692,6 +692,45 @@ def FoldToCallOp : TEST_Op<"fold_to_call_op"> {
   let hasCanonicalizer = 1;
 }
 
+// Function-like op whose custom assembly format spells every argument and
+// result type as `type:<type>` (see CustomTypeFormatFuncOp::parse/print).
+def CustomTypeFormatFuncOp
+    : TEST_Op<"custom_type_format_func", [FunctionOpInterface, NoTerminator]> {
+
+  let arguments = (ins TypeAttrOf<FunctionType>:$function_type,
+      OptionalAttr<DictArrayAttr>:$arg_attrs,
+      OptionalAttr<DictArrayAttr>:$res_attrs);
+
+  let regions = (region AnyRegion:$body);
+  let extraClassDeclaration = [{
+    //===------------------------------------------------------------------===//
+    // FunctionOpInterface Methods
+    //===------------------------------------------------------------------===//
+
+    /// Returns the region on the current operation that is callable. This may
+    /// return null in the case of an external callable object, e.g. an external
+    /// function.
+    ::mlir::Region *getCallableRegion() {
+      return isExternal() ? nullptr : &getBody();
+    }
+
+    /// Returns the argument types of this function.
+    ::mlir::ArrayRef<::mlir::Type> getArgumentTypes() {
+      return getFunctionType().getInputs();
+    }
+
+    /// Returns the result types of this function.
+    ::mlir::ArrayRef<::mlir::Type> getResultTypes() {
+      return getFunctionType().getResults();
+    }
+
+    /// Returns the number of results of this function.
+    unsigned getNumResults() { return getResultTypes().size(); }
+  }];
+
+  let hasCustomAssemblyFormat = 1;
+}
+
 //===----------------------------------------------------------------------===//
 // Test Traits
 //===----------------------------------------------------------------------===//



More information about the Mlir-commits mailing list