[Mlir-commits] [mlir] [mlir] add optional type functor to call and function interfaces (PR #146979)

Gibran Essa llvmlistbot at llvm.org
Thu Jul 3 17:19:09 PDT 2025


https://github.com/gibrane created https://github.com/llvm/llvm-project/pull/146979

None

>From b4d8671055e492b689f30fa914ae747e2a2fd7ec Mon Sep 17 00:00:00 2001
From: Gibran Essa <gessa at nvidia.com>
Date: Wed, 2 Jul 2025 20:25:00 +0000
Subject: [PATCH] [mlir] add optional type functor to call and function
 interfaces

---
 mlir/include/mlir/IR/OpImplementation.h       | 22 +++---
 mlir/include/mlir/Interfaces/CallInterfaces.h | 30 +++----
 .../mlir/Interfaces/FunctionImplementation.h  | 28 ++++---
 mlir/lib/AsmParser/Parser.cpp                 | 16 ++--
 mlir/lib/IR/AsmPrinter.cpp                    | 22 +++---
 mlir/lib/Interfaces/CallInterfaces.cpp        | 79 ++++++++++++++-----
 .../lib/Interfaces/FunctionImplementation.cpp | 46 +++++++----
 .../custom-type-parse-and-print.mlir          | 26 ++++++
 mlir/test/lib/Dialect/Test/TestOpDefs.cpp     | 34 ++++++++
 mlir/test/lib/Dialect/Test/TestOps.td         | 36 +++++++++
 10 files changed, 256 insertions(+), 83 deletions(-)
 create mode 100644 mlir/test/Interfaces/FunctionOpInterface/custom-type-parse-and-print.mlir

diff --git a/mlir/include/mlir/IR/OpImplementation.h b/mlir/include/mlir/IR/OpImplementation.h
index d70aa346eaa1f..6ebb021446e15 100644
--- a/mlir/include/mlir/IR/OpImplementation.h
+++ b/mlir/include/mlir/IR/OpImplementation.h
@@ -463,10 +463,10 @@ class OpAsmPrinter : public AsmPrinter {
   /// where location printing is controlled by the standard internal option.
   /// You may pass omitType=true to not print a type, and pass an empty
   /// attribute list if you don't care for attributes.
-  virtual void printRegionArgument(BlockArgument arg,
-                                   ArrayRef<NamedAttribute> argAttrs = {},
-                                   bool omitType = false) = 0;
-
+  virtual void printRegionArgument(
+      BlockArgument arg, ArrayRef<NamedAttribute> argAttrs = {},
+      bool omitType = false,
+      function_ref<void(OpAsmPrinter &, Type)> typePrinter = nullptr) = 0;
   /// Print implementations for various things an operation contains.
   virtual void printOperand(Value value) = 0;
   virtual void printOperand(Value value, raw_ostream &os) = 0;
@@ -1701,13 +1701,17 @@ class OpAsmParser : public AsmParser {
   ///
   /// If `allowType` is false or `allowAttrs` are false then the respective
   /// parts of the grammar are not parsed.
-  virtual ParseResult parseArgument(Argument &result, bool allowType = false,
-                                    bool allowAttrs = false) = 0;
+  virtual ParseResult
+  parseArgument(Argument &result, bool allowType = false,
+                bool allowAttrs = false,
+                function_ref<ParseResult(OpAsmParser &, Type &)> typeParser =
+                    nullptr) = 0;
 
   /// Parse a single argument if present.
-  virtual OptionalParseResult
-  parseOptionalArgument(Argument &result, bool allowType = false,
-                        bool allowAttrs = false) = 0;
+  virtual OptionalParseResult parseOptionalArgument(
+      Argument &result, bool allowType = false, bool allowAttrs = false,
+      function_ref<ParseResult(OpAsmParser &, Type &)> typeParser =
+          nullptr) = 0;
 
   /// Parse zero or more arguments with a specified surrounding delimiter.
   virtual ParseResult parseArgumentList(SmallVectorImpl<Argument> &result,
diff --git a/mlir/include/mlir/Interfaces/CallInterfaces.h b/mlir/include/mlir/Interfaces/CallInterfaces.h
index 2bf3a3ca5f8a8..78aa9181931e0 100644
--- a/mlir/include/mlir/Interfaces/CallInterfaces.h
+++ b/mlir/include/mlir/Interfaces/CallInterfaces.h
@@ -45,9 +45,10 @@ Operation *resolveCallable(CallOpInterface call,
 ///   function-result-list-no-parens ::= function-result (`,` function-result)*
 ///   function-result ::= type attribute-dict?
 ///
-ParseResult
-parseFunctionResultList(OpAsmParser &parser, SmallVectorImpl<Type> &resultTypes,
-                        SmallVectorImpl<DictionaryAttr> &resultAttrs);
+ParseResult parseFunctionResultList(
+    OpAsmParser &parser, SmallVectorImpl<Type> &resultTypes,
+    SmallVectorImpl<DictionaryAttr> &resultAttrs,
+    function_ref<ParseResult(OpAsmParser &, Type &)> typeParser = nullptr);
 
 /// Parses a function signature using `parser`. This does not deal with function
 /// signatures containing SSA region arguments (to parse these signatures, use
@@ -59,12 +60,13 @@ parseFunctionResultList(OpAsmParser &parser, SmallVectorImpl<Type> &resultTypes,
 ///   no-ssa-function-arg-list  ::= no-ssa-function-arg
 ///                               (`,` no-ssa-function-arg)*
 ///   no-ssa-function-arg       ::= type attribute-dict?
-ParseResult parseFunctionSignature(OpAsmParser &parser,
-                                   SmallVectorImpl<Type> &argTypes,
-                                   SmallVectorImpl<DictionaryAttr> &argAttrs,
-                                   SmallVectorImpl<Type> &resultTypes,
-                                   SmallVectorImpl<DictionaryAttr> &resultAttrs,
-                                   bool mustParseEmptyResult = true);
+ParseResult parseFunctionSignature(
+    OpAsmParser &parser, SmallVectorImpl<Type> &argTypes,
+    SmallVectorImpl<DictionaryAttr> &argAttrs,
+    SmallVectorImpl<Type> &resultTypes,
+    SmallVectorImpl<DictionaryAttr> &resultAttrs,
+    bool mustParseEmptyResult = true,
+    function_ref<ParseResult(OpAsmParser &, Type &)> typeParser = nullptr);
 
 /// Print a function signature for a call or callable operation. If a body
 /// region is provided, the SSA arguments are printed in the signature. When
@@ -77,11 +79,11 @@ ParseResult parseFunctionSignature(OpAsmParser &parser,
 ///                            -> function-result-list
 ///   ssa-function-arg-list  ::= ssa-function-arg (`,` ssa-function-arg)*
 ///   ssa-function-arg       ::= `%`name `:` type attribute-dict?
-void printFunctionSignature(OpAsmPrinter &p, TypeRange argTypes,
-                            ArrayAttr argAttrs, bool isVariadic,
-                            TypeRange resultTypes, ArrayAttr resultAttrs,
-                            Region *body = nullptr,
-                            bool printEmptyResult = true);
+void printFunctionSignature(
+    OpAsmPrinter &p, TypeRange argTypes, ArrayAttr argAttrs, bool isVariadic,
+    TypeRange resultTypes, ArrayAttr resultAttrs, Region *body = nullptr,
+    bool printEmptyResult = true,
+    function_ref<void(OpAsmPrinter &, Type)> typePrinter = nullptr);
 
 /// Adds argument and result attributes, provided as `argAttrs` and
 /// `resultAttrs` arguments, to the list of operation attributes in `result`.
diff --git a/mlir/include/mlir/Interfaces/FunctionImplementation.h b/mlir/include/mlir/Interfaces/FunctionImplementation.h
index 374c2c534f87d..1e42e09936658 100644
--- a/mlir/include/mlir/Interfaces/FunctionImplementation.h
+++ b/mlir/include/mlir/Interfaces/FunctionImplementation.h
@@ -49,7 +49,8 @@ ParseResult parseFunctionSignatureWithArguments(
     OpAsmParser &parser, bool allowVariadic,
     SmallVectorImpl<OpAsmParser::Argument> &arguments, bool &isVariadic,
     SmallVectorImpl<Type> &resultTypes,
-    SmallVectorImpl<DictionaryAttr> &resultAttrs);
+    SmallVectorImpl<DictionaryAttr> &resultAttrs,
+    function_ref<ParseResult(OpAsmParser &, Type &)> typeParser = nullptr);
 
 /// Parser implementation for function-like operations.  Uses
 /// `funcTypeBuilder` to construct the custom function type given lists of
@@ -59,25 +60,28 @@ ParseResult parseFunctionSignatureWithArguments(
 /// whether the function is variadic.  If the builder returns a null type,
 /// `result` will not contain the `type` attribute.  The caller can then add a
 /// type, report the error or delegate the reporting to the op's verifier.
-ParseResult parseFunctionOp(OpAsmParser &parser, OperationState &result,
-                            bool allowVariadic, StringAttr typeAttrName,
-                            FuncTypeBuilder funcTypeBuilder,
-                            StringAttr argAttrsName, StringAttr resAttrsName);
+ParseResult parseFunctionOp(
+    OpAsmParser &parser, OperationState &result, bool allowVariadic,
+    StringAttr typeAttrName, FuncTypeBuilder funcTypeBuilder,
+    StringAttr argAttrsName, StringAttr resAttrsName,
+    function_ref<ParseResult(OpAsmParser &, Type &)> typeParser = nullptr);
 
 /// Printer implementation for function-like operations.
-void printFunctionOp(OpAsmPrinter &p, FunctionOpInterface op, bool isVariadic,
-                     StringRef typeAttrName, StringAttr argAttrsName,
-                     StringAttr resAttrsName);
+void printFunctionOp(
+    OpAsmPrinter &p, FunctionOpInterface op, bool isVariadic,
+    StringRef typeAttrName, StringAttr argAttrsName, StringAttr resAttrsName,
+    function_ref<void(OpAsmPrinter &, Type)> typePrinter = nullptr);
 
 /// Prints the signature of the function-like operation `op`. Assumes `op` has
 /// is a FunctionOpInterface and has passed verification.
-inline void printFunctionSignature(OpAsmPrinter &p, FunctionOpInterface op,
-                                   ArrayRef<Type> argTypes, bool isVariadic,
-                                   ArrayRef<Type> resultTypes) {
+inline void printFunctionSignature(
+    OpAsmPrinter &p, FunctionOpInterface op, ArrayRef<Type> argTypes,
+    bool isVariadic, ArrayRef<Type> resultTypes,
+    function_ref<void(OpAsmPrinter &, Type)> typePrinter = nullptr) {
   call_interface_impl::printFunctionSignature(
       p, argTypes, op.getArgAttrsAttr(), isVariadic, resultTypes,
       op.getResAttrsAttr(), &op->getRegion(0),
-      /*printEmptyResult=*/false);
+      /*printEmptyResult=*/false, typePrinter);
 }
 
 /// Prints the list of function prefixed with the "attributes" keyword. The
diff --git a/mlir/lib/AsmParser/Parser.cpp b/mlir/lib/AsmParser/Parser.cpp
index 756d3d01a4534..06282f648549f 100644
--- a/mlir/lib/AsmParser/Parser.cpp
+++ b/mlir/lib/AsmParser/Parser.cpp
@@ -1830,10 +1830,14 @@ class CustomOpAsmParser : public AsmParserImpl<OpAsmParser> {
   /// If `allowType` is false or `allowAttrs` are false then the respective
   /// parts of the grammar are not parsed.
   ParseResult parseArgument(Argument &result, bool allowType = false,
-                            bool allowAttrs = false) override {
+                            bool allowAttrs = false,
+                            function_ref<ParseResult(OpAsmParser &, Type &)>
+                                typeParser = nullptr) override {
     NamedAttrList attrs;
     if (parseOperand(result.ssaName, /*allowResultNumber=*/false) ||
-        (allowType && parseColonType(result.type)) ||
+        (allowType && !typeParser && parseColonType(result.type)) ||
+        (allowType && typeParser &&
+         (parseColon() || typeParser(*this, result.type))) ||
         (allowAttrs && parseOptionalAttrDict(attrs)) ||
         parseOptionalLocationSpecifier(result.sourceLoc))
       return failure();
@@ -1842,10 +1846,12 @@ class CustomOpAsmParser : public AsmParserImpl<OpAsmParser> {
   }
 
   /// Parse a single argument if present.
-  OptionalParseResult parseOptionalArgument(Argument &result, bool allowType,
-                                            bool allowAttrs) override {
+  OptionalParseResult parseOptionalArgument(
+      Argument &result, bool allowType, bool allowAttrs,
+      function_ref<ParseResult(OpAsmParser &, Type &)> typeParser =
+          nullptr) override {
     if (parser.getToken().is(Token::percent_identifier))
-      return parseArgument(result, allowType, allowAttrs);
+      return parseArgument(result, allowType, allowAttrs, typeParser);
     return std::nullopt;
   }
 
diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp
index f95ad290a1981..71a372369b4bf 100644
--- a/mlir/lib/IR/AsmPrinter.cpp
+++ b/mlir/lib/IR/AsmPrinter.cpp
@@ -783,9 +783,10 @@ class DummyAliasOperationPrinter : private OpAsmPrinter {
       print(&b);
   }
 
-  void printRegionArgument(BlockArgument arg, ArrayRef<NamedAttribute> argAttrs,
-                           bool omitType) override {
-    printType(arg.getType());
+  void printRegionArgument(
+      BlockArgument arg, ArrayRef<NamedAttribute> argAttrs, bool omitType,
+      function_ref<void(OpAsmPrinter &, Type)> typePrinter = nullptr) override {
+    typePrinter ? typePrinter(*this, arg.getType()) : printType(arg.getType());
     // Visit the argument location.
     if (printerFlags.shouldPrintDebugInfo())
       // TODO: Allow deferring argument locations.
@@ -3295,9 +3296,10 @@ class OperationPrinter : public AsmPrinter::Impl, private OpAsmPrinter {
   /// where location printing is controlled by the standard internal option.
   /// You may pass omitType=true to not print a type, and pass an empty
   /// attribute list if you don't care for attributes.
-  void printRegionArgument(BlockArgument arg,
-                           ArrayRef<NamedAttribute> argAttrs = {},
-                           bool omitType = false) override;
+  void printRegionArgument(
+      BlockArgument arg, ArrayRef<NamedAttribute> argAttrs = {},
+      bool omitType = false,
+      function_ref<void(OpAsmPrinter &, Type)> typePrinter = nullptr) override;
 
   /// Print the ID for the given value.
   void printOperand(Value value) override { printValueID(value); }
@@ -3545,13 +3547,13 @@ void OperationPrinter::printResourceFileMetadata(
 /// where location printing is controlled by the standard internal option.
 /// You may pass omitType=true to not print a type, and pass an empty
 /// attribute list if you don't care for attributes.
-void OperationPrinter::printRegionArgument(BlockArgument arg,
-                                           ArrayRef<NamedAttribute> argAttrs,
-                                           bool omitType) {
+void OperationPrinter::printRegionArgument(
+    BlockArgument arg, ArrayRef<NamedAttribute> argAttrs, bool omitType,
+    function_ref<void(OpAsmPrinter &, Type)> typePrinter) {
   printOperand(arg);
   if (!omitType) {
     os << ": ";
-    printType(arg.getType());
+    typePrinter ? typePrinter(*this, arg.getType()) : printType(arg.getType());
   }
   printOptionalAttrDict(argAttrs);
   // TODO: We should allow location aliases on block arguments.
diff --git a/mlir/lib/Interfaces/CallInterfaces.cpp b/mlir/lib/Interfaces/CallInterfaces.cpp
index e8ed4b339a0cb..a08338e514a08 100644
--- a/mlir/lib/Interfaces/CallInterfaces.cpp
+++ b/mlir/lib/Interfaces/CallInterfaces.cpp
@@ -15,15 +15,24 @@ using namespace mlir;
 // Argument and result attributes utilities
 //===----------------------------------------------------------------------===//
 
-static ParseResult
-parseTypeAndAttrList(OpAsmParser &parser, SmallVectorImpl<Type> &types,
-                     SmallVectorImpl<DictionaryAttr> &attrs) {
+static inline ParseResult defaultTypeParser(OpAsmParser &parser, Type &ty) {
+  return parser.parseType(ty);
+}
+
+static inline void defaultTypePrinter(OpAsmPrinter &printer, Type ty) {
+  printer << ty;
+}
+
+static ParseResult parseTypeAndAttrList(
+    OpAsmParser &parser, SmallVectorImpl<Type> &types,
+    SmallVectorImpl<DictionaryAttr> &attrs,
+    function_ref<ParseResult(OpAsmParser &, Type &)> typeParser) {
   // Parse individual function results.
   return parser.parseCommaSeparatedList([&]() -> ParseResult {
     types.emplace_back();
     attrs.emplace_back();
     NamedAttrList attrList;
-    if (parser.parseType(types.back()) ||
+    if (typeParser(parser, types.back()) ||
         parser.parseOptionalAttrDict(attrList))
       return failure();
     attrs.back() = attrList.getDictionary(parser.getContext());
@@ -33,12 +42,16 @@ parseTypeAndAttrList(OpAsmParser &parser, SmallVectorImpl<Type> &types,
 
 ParseResult call_interface_impl::parseFunctionResultList(
     OpAsmParser &parser, SmallVectorImpl<Type> &resultTypes,
-    SmallVectorImpl<DictionaryAttr> &resultAttrs) {
+    SmallVectorImpl<DictionaryAttr> &resultAttrs,
+    function_ref<ParseResult(OpAsmParser &, Type &)> typeParser) {
+  if (!typeParser)
+    typeParser = defaultTypeParser;
+
   if (failed(parser.parseOptionalLParen())) {
     // We already know that there is no `(`, so parse a type.
     // Because there is no `(`, it cannot be a function type.
     Type ty;
-    if (parser.parseType(ty))
+    if (typeParser(parser, ty))
       return failure();
     resultTypes.push_back(ty);
     resultAttrs.emplace_back();
@@ -48,7 +61,7 @@ ParseResult call_interface_impl::parseFunctionResultList(
   // Special case for an empty set of parens.
   if (succeeded(parser.parseOptionalRParen()))
     return success();
-  if (parseTypeAndAttrList(parser, resultTypes, resultAttrs))
+  if (parseTypeAndAttrList(parser, resultTypes, resultAttrs, typeParser))
     return failure();
   return parser.parseRParen();
 }
@@ -57,20 +70,24 @@ ParseResult call_interface_impl::parseFunctionSignature(
     OpAsmParser &parser, SmallVectorImpl<Type> &argTypes,
     SmallVectorImpl<DictionaryAttr> &argAttrs,
     SmallVectorImpl<Type> &resultTypes,
-    SmallVectorImpl<DictionaryAttr> &resultAttrs, bool mustParseEmptyResult) {
+    SmallVectorImpl<DictionaryAttr> &resultAttrs, bool mustParseEmptyResult,
+    function_ref<ParseResult(OpAsmParser &, Type &)> typeParser) {
+  if (!typeParser)
+    typeParser = defaultTypeParser;
+
   // Parse arguments.
   if (parser.parseLParen())
     return failure();
   if (failed(parser.parseOptionalRParen())) {
-    if (parseTypeAndAttrList(parser, argTypes, argAttrs))
+    if (parseTypeAndAttrList(parser, argTypes, argAttrs, typeParser))
       return failure();
     if (parser.parseRParen())
       return failure();
   }
   // Parse results.
   if (succeeded(parser.parseOptionalArrow()))
-    return call_interface_impl::parseFunctionResultList(parser, resultTypes,
-                                                        resultAttrs);
+    return call_interface_impl::parseFunctionResultList(
+        parser, resultTypes, resultAttrs, typeParser);
   if (mustParseEmptyResult)
     return failure();
   return success();
@@ -78,8 +95,12 @@ ParseResult call_interface_impl::parseFunctionSignature(
 
 /// Print a function result list. The provided `attrs` must either be null, or
 /// contain a set of DictionaryAttrs of the same arity as `types`.
-static void printFunctionResultList(OpAsmPrinter &p, TypeRange types,
-                                    ArrayAttr attrs) {
+static void
+printFunctionResultList(OpAsmPrinter &p, TypeRange types, ArrayAttr attrs,
+                        function_ref<void(OpAsmPrinter &, Type)> typePrinter) {
+  if (!typePrinter)
+    typePrinter = defaultTypePrinter;
+
   assert(!types.empty() && "Should not be called for empty result list.");
   assert((!attrs || attrs.size() == types.size()) &&
          "Invalid number of attributes.");
@@ -90,22 +111,41 @@ static void printFunctionResultList(OpAsmPrinter &p, TypeRange types,
   if (needsParens)
     os << '(';
   llvm::interleaveComma(llvm::seq<size_t>(0, types.size()), os, [&](size_t i) {
-    p.printType(types[i]);
+    typePrinter(p, types[i]);
     if (attrs)
       p.printOptionalAttrDict(llvm::cast<DictionaryAttr>(attrs[i]).getValue());
   });
   if (needsParens)
     os << ')';
 }
+/// Print `(inputs) -> results`, emitting every type through `typePrinter`.
+static void
+printFunctionalType(OpAsmPrinter &p, TypeRange inputs, TypeRange results,
+                    function_ref<void(OpAsmPrinter &, Type)> typePrinter) {
+  p << '(';
+  llvm::interleaveComma(inputs, p, [&](Type ty) { typePrinter(p, ty); });
+  p << ") -> ";
+  bool wrapped = !llvm::hasSingleElement(results) ||
+                 llvm::isa<FunctionType>((*results.begin()));
+  if (wrapped)
+    p << '(';
+  llvm::interleaveComma(results, p, [&](Type ty) { typePrinter(p, ty); });
+  if (wrapped)
+    p << ')';
+}
 
 void call_interface_impl::printFunctionSignature(
     OpAsmPrinter &p, TypeRange argTypes, ArrayAttr argAttrs, bool isVariadic,
     TypeRange resultTypes, ArrayAttr resultAttrs, Region *body,
-    bool printEmptyResult) {
+    bool printEmptyResult,
+    function_ref<void(OpAsmPrinter &, Type)> typePrinter) {
+  if (!typePrinter)
+    typePrinter = defaultTypePrinter;
+
   bool isExternal = !body || body->empty();
   if (!isExternal && !isVariadic && !argAttrs && !resultAttrs &&
       printEmptyResult) {
-    p.printFunctionalType(argTypes, resultTypes);
+    printFunctionalType(p, argTypes, resultTypes, typePrinter);
     return;
   }
 
@@ -118,9 +158,10 @@ void call_interface_impl::printFunctionSignature(
       ArrayRef<NamedAttribute> attrs;
       if (argAttrs)
         attrs = llvm::cast<DictionaryAttr>(argAttrs[i]).getValue();
-      p.printRegionArgument(body->getArgument(i), attrs);
+      p.printRegionArgument(body->getArgument(i), attrs, /*omitType=*/false,
+                            typePrinter);
     } else {
-      p.printType(argTypes[i]);
+      typePrinter(p, argTypes[i]);
       if (argAttrs)
         p.printOptionalAttrDict(
             llvm::cast<DictionaryAttr>(argAttrs[i]).getValue());
@@ -137,7 +178,7 @@ void call_interface_impl::printFunctionSignature(
 
   if (!resultTypes.empty()) {
     p << " -> ";
-    printFunctionResultList(p, resultTypes, resultAttrs);
+    printFunctionResultList(p, resultTypes, resultAttrs, typePrinter);
   } else if (printEmptyResult) {
     p << " -> ()";
   }
diff --git a/mlir/lib/Interfaces/FunctionImplementation.cpp b/mlir/lib/Interfaces/FunctionImplementation.cpp
index 90f32896e8181..89512e68bb22b 100644
--- a/mlir/lib/Interfaces/FunctionImplementation.cpp
+++ b/mlir/lib/Interfaces/FunctionImplementation.cpp
@@ -13,11 +13,18 @@
 
 using namespace mlir;
 
-static ParseResult
-parseFunctionArgumentList(OpAsmParser &parser, bool allowVariadic,
-                          SmallVectorImpl<OpAsmParser::Argument> &arguments,
-                          bool &isVariadic) {
+static inline ParseResult defaultTypeParser(OpAsmParser &parser, Type &ty) {
+  return parser.parseType(ty);
+}
+
+static inline void defaultTypePrinter(OpAsmPrinter &printer, Type ty) {
+  printer << ty;
+}
 
+static ParseResult parseFunctionArgumentList(
+    OpAsmParser &parser, bool allowVariadic,
+    SmallVectorImpl<OpAsmParser::Argument> &arguments, bool &isVariadic,
+    function_ref<ParseResult(OpAsmParser &, Type &)> typeParser) {
   // Parse the function arguments.  The argument list either has to consistently
   // have ssa-id's followed by types, or just be a type list.  It isn't ok to
   // sometimes have SSA ID's and sometimes not.
@@ -40,7 +47,7 @@ parseFunctionArgumentList(OpAsmParser &parser, bool allowVariadic,
         // Parse argument name if present.
         OpAsmParser::Argument argument;
         auto argPresent = parser.parseOptionalArgument(
-            argument, /*allowType=*/true, /*allowAttrs=*/true);
+            argument, /*allowType=*/true, /*allowAttrs=*/true, typeParser);
         if (argPresent.has_value()) {
           if (failed(argPresent.value()))
             return failure(); // Present but malformed.
@@ -59,7 +66,7 @@ parseFunctionArgumentList(OpAsmParser &parser, bool allowVariadic,
                                     "expected SSA identifier");
 
           NamedAttrList attrs;
-          if (parser.parseType(argument.type) ||
+          if (typeParser(parser, argument.type) ||
               parser.parseOptionalAttrDict(attrs) ||
               parser.parseOptionalLocationSpecifier(argument.sourceLoc))
             return failure();
@@ -74,19 +81,25 @@ ParseResult function_interface_impl::parseFunctionSignatureWithArguments(
     OpAsmParser &parser, bool allowVariadic,
     SmallVectorImpl<OpAsmParser::Argument> &arguments, bool &isVariadic,
     SmallVectorImpl<Type> &resultTypes,
-    SmallVectorImpl<DictionaryAttr> &resultAttrs) {
-  if (parseFunctionArgumentList(parser, allowVariadic, arguments, isVariadic))
+    SmallVectorImpl<DictionaryAttr> &resultAttrs,
+    function_ref<ParseResult(OpAsmParser &, Type &)> typeParser) {
+  if (!typeParser)
+    typeParser = defaultTypeParser;
+
+  if (parseFunctionArgumentList(parser, allowVariadic, arguments, isVariadic,
+                                typeParser))
     return failure();
   if (succeeded(parser.parseOptionalArrow()))
-    return call_interface_impl::parseFunctionResultList(parser, resultTypes,
-                                                        resultAttrs);
+    return call_interface_impl::parseFunctionResultList(
+        parser, resultTypes, resultAttrs, typeParser);
   return success();
 }
 
 ParseResult function_interface_impl::parseFunctionOp(
     OpAsmParser &parser, OperationState &result, bool allowVariadic,
     StringAttr typeAttrName, FuncTypeBuilder funcTypeBuilder,
-    StringAttr argAttrsName, StringAttr resAttrsName) {
+    StringAttr argAttrsName, StringAttr resAttrsName,
+    function_ref<ParseResult(OpAsmParser &, Type &)> typeParser) {
   SmallVector<OpAsmParser::Argument> entryArgs;
   SmallVector<DictionaryAttr> resultAttrs;
   SmallVector<Type> resultTypes;
@@ -105,7 +118,8 @@ ParseResult function_interface_impl::parseFunctionOp(
   SMLoc signatureLocation = parser.getCurrentLocation();
   bool isVariadic = false;
   if (parseFunctionSignatureWithArguments(parser, allowVariadic, entryArgs,
-                                          isVariadic, resultTypes, resultAttrs))
+                                          isVariadic, resultTypes, resultAttrs,
+                                          typeParser))
     return failure();
 
   std::string errorMessage;
@@ -174,7 +188,11 @@ void function_interface_impl::printFunctionAttributes(
 
 void function_interface_impl::printFunctionOp(
     OpAsmPrinter &p, FunctionOpInterface op, bool isVariadic,
-    StringRef typeAttrName, StringAttr argAttrsName, StringAttr resAttrsName) {
+    StringRef typeAttrName, StringAttr argAttrsName, StringAttr resAttrsName,
+    function_ref<void(OpAsmPrinter &, Type)> typePrinter) {
+  if (!typePrinter)
+    typePrinter = defaultTypePrinter;
+
   // Print the operation and the function name.
   auto funcName =
       op->getAttrOfType<StringAttr>(SymbolTable::getSymbolAttrName())
@@ -188,7 +206,7 @@ void function_interface_impl::printFunctionOp(
 
   ArrayRef<Type> argTypes = op.getArgumentTypes();
   ArrayRef<Type> resultTypes = op.getResultTypes();
-  printFunctionSignature(p, op, argTypes, isVariadic, resultTypes);
+  printFunctionSignature(p, op, argTypes, isVariadic, resultTypes, typePrinter);
   printFunctionAttributes(
       p, op, {visibilityAttrName, typeAttrName, argAttrsName, resAttrsName});
   // Print the body if this is not an external function.
diff --git a/mlir/test/Interfaces/FunctionOpInterface/custom-type-parse-and-print.mlir b/mlir/test/Interfaces/FunctionOpInterface/custom-type-parse-and-print.mlir
new file mode 100644
index 0000000000000..8a19b6327a122
--- /dev/null
+++ b/mlir/test/Interfaces/FunctionOpInterface/custom-type-parse-and-print.mlir
@@ -0,0 +1,26 @@
+// RUN: mlir-opt %s | FileCheck %s
+
+//      CHECK: test.custom_type_format_func @single_arg_no_return(
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: type:f32) {
+// CHECK-NEXT: }
+test.custom_type_format_func @single_arg_no_return(%arg0 : type:f32) {}
+
+//      CHECK: test.custom_type_format_func @multiple_arg_multiple_return(
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: type:f32
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: type:f32)
+// CHECK-SAME: -> (type:f32, type:f32) {
+// CHECK-NEXT: }
+test.custom_type_format_func @multiple_arg_multiple_return(
+    %arg0 : type:f32, %arg1 : type:f32)
+    -> (type:f32, type:f32) {}
+
+//      CHECK: test.custom_type_format_func @no_block
+// CHECK-SAME: (type:f32, type:f32)
+// CHECK-SAME: -> (type:f32, type:f32)
+test.custom_type_format_func @no_block(%arg0 : type:f32, %arg1 : type:f32)
+    -> (type:f32, type:f32)
+
+//      CHECK: test.custom_type_format_func @one_return
+// CHECK-SAME: -> type:f32
+test.custom_type_format_func @one_return()
+    -> type:f32
diff --git a/mlir/test/lib/Dialect/Test/TestOpDefs.cpp b/mlir/test/lib/Dialect/Test/TestOpDefs.cpp
index 3ab4ef2680978..030cdd4083c25 100644
--- a/mlir/test/lib/Dialect/Test/TestOpDefs.cpp
+++ b/mlir/test/lib/Dialect/Test/TestOpDefs.cpp
@@ -10,9 +10,12 @@
 #include "TestOps.h"
 #include "mlir/Dialect/Bufferization/IR/Bufferization.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/IR/OpImplementation.h"
+#include "mlir/IR/OperationSupport.h"
 #include "mlir/IR/Verifier.h"
 #include "mlir/Interfaces/FunctionImplementation.h"
 #include "mlir/Interfaces/MemorySlotInterfaces.h"
+#include "llvm/Support/LogicalResult.h"
 
 using namespace mlir;
 using namespace test;
@@ -1454,3 +1457,34 @@ test::TestCreateTensorOp::getBufferType(
   return cast<mlir::bufferization::BufferLikeType>(test::TestMemrefType::get(
       getContext(), type.getShape(), type.getElementType(), nullptr));
 }
+
+//===----------------------------------------------------------------------===//
+// CustomTypeFormatFuncOp
+//===----------------------------------------------------------------------===//
+
+ParseResult CustomTypeFormatFuncOp::parse(OpAsmParser &parser,
+                                          OperationState &result) {
+  auto buildFuncType =
+      [](Builder &builder, ArrayRef<Type> argTypes, ArrayRef<Type> results,
+         function_interface_impl::VariadicFlag,
+         std::string &) { return builder.getFunctionType(argTypes, results); };
+
+  // Each signature type is spelled `type:<type>`; any failed step is an error.
+  auto typeParser = [](OpAsmParser &p, Type &ty) -> ParseResult {
+    return failure(p.parseKeyword("type") || p.parseColon() || p.parseType(ty));
+  };
+
+  return function_interface_impl::parseFunctionOp(
+      parser, result, /*allowVariadic=*/false,
+      getFunctionTypeAttrName(result.name), buildFuncType,
+      getArgAttrsAttrName(result.name), getResAttrsAttrName(result.name),
+      typeParser);
+}
+
+void CustomTypeFormatFuncOp::print(OpAsmPrinter &p) {
+  // Mirror of the parser above: every type is emitted as `type:<type>`.
+  auto typePrinter = [](OpAsmPrinter &os, Type ty) { os << "type:" << ty; };
+  function_interface_impl::printFunctionOp(
+      p, *this, /*isVariadic=*/false, getFunctionTypeAttrName(),
+      getArgAttrsAttrName(), getResAttrsAttrName(), typePrinter);
+}
diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td
index 5320ba4ea3829..604b3f0683558 100644
--- a/mlir/test/lib/Dialect/Test/TestOps.td
+++ b/mlir/test/lib/Dialect/Test/TestOps.td
@@ -692,6 +692,42 @@ def FoldToCallOp : TEST_Op<"fold_to_call_op"> {
   let hasCanonicalizer = 1;
 }
 
+def CustomTypeFormatFuncOp
+    : TEST_Op<"custom_type_format_func", [FunctionOpInterface, NoTerminator]> {
+
+  let arguments = (ins TypeAttrOf<FunctionType>:$function_type,
+      OptionalAttr<DictArrayAttr>:$arg_attrs,
+      OptionalAttr<DictArrayAttr>:$res_attrs);
+
+  let regions = (region AnyRegion:$body);
+  let extraClassDeclaration = [{
+    //===------------------------------------------------------------------===//
+    // FunctionOpInterface Methods
+    //===------------------------------------------------------------------===//
+
+    /// Returns the region on the current operation that is callable. This test
+    /// op always returns null, so its body is never reported as a callable
+    /// region.
+    ::mlir::Region *getCallableRegion() {
+      return nullptr;
+    }
+
+    /// Returns the argument types of this function.
+    ::mlir::ArrayRef<::mlir::Type> getArgumentTypes() {
+      return getFunctionType().getInputs();
+    }
+
+    /// Returns the result types of this function.
+    ::mlir::ArrayRef<::mlir::Type> getResultTypes() {
+      return getFunctionType().getResults();
+    }
+
+    /// Returns the number of results, i.e. the size of getResultTypes().
+    unsigned getNumResults() {return getResultTypes().size();}
+  }];
+
+  let hasCustomAssemblyFormat = 1;
+}
 //===----------------------------------------------------------------------===//
 // Test Traits
 //===----------------------------------------------------------------------===//



More information about the Mlir-commits mailing list