[Mlir-commits] [mlir] [mlir] add optional type functor to call and function interfaces (PR #146979)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Tue Jul 8 13:54:49 PDT 2025


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir-core

@llvm/pr-subscribers-mlir

Author: Gibran Essa (gibrane)

<details>
<summary>Changes</summary>

Adds optional type parsing and printing functors to `call_interface_impl` and `function_interface_impl`.

---

Patch is 31.77 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/146979.diff


10 Files Affected:

- (modified) mlir/include/mlir/IR/OpImplementation.h (+15-9) 
- (modified) mlir/include/mlir/Interfaces/CallInterfaces.h (+24-14) 
- (modified) mlir/include/mlir/Interfaces/FunctionImplementation.h (+22-12) 
- (modified) mlir/lib/AsmParser/Parser.cpp (+11-5) 
- (modified) mlir/lib/IR/AsmPrinter.cpp (+12-10) 
- (modified) mlir/lib/Interfaces/CallInterfaces.cpp (+60-19) 
- (modified) mlir/lib/Interfaces/FunctionImplementation.cpp (+32-14) 
- (added) mlir/test/Interfaces/FunctionOpInterface/custom-type-parse-and-print.mlir (+26) 
- (modified) mlir/test/lib/Dialect/Test/TestOpDefs.cpp (+34) 
- (modified) mlir/test/lib/Dialect/Test/TestOps.td (+36) 


``````````diff
diff --git a/mlir/include/mlir/IR/OpImplementation.h b/mlir/include/mlir/IR/OpImplementation.h
index d70aa346eaa1f..5232a31ff5c77 100644
--- a/mlir/include/mlir/IR/OpImplementation.h
+++ b/mlir/include/mlir/IR/OpImplementation.h
@@ -463,10 +463,12 @@ class OpAsmPrinter : public AsmPrinter {
   /// where location printing is controlled by the standard internal option.
   /// You may pass omitType=true to not print a type, and pass an empty
   /// attribute list if you don't care for attributes.
-  virtual void printRegionArgument(BlockArgument arg,
-                                   ArrayRef<NamedAttribute> argAttrs = {},
-                                   bool omitType = false) = 0;
+  /// You can override default type printing behavior with the typePrinter arg.
+  virtual void printRegionArgument(
+      BlockArgument arg, ArrayRef<NamedAttribute> argAttrs = {},
+      bool omitType = false,
+      function_ref<void(OpAsmPrinter &, Type)> typePrinter = nullptr) = 0;
 
   /// Print implementations for various things an operation contains.
   virtual void printOperand(Value value) = 0;
   virtual void printOperand(Value value, raw_ostream &os) = 0;
@@ -1701,13 +1703,18 @@ class OpAsmParser : public AsmParser {
   ///
   /// If `allowType` is false or `allowAttrs` are false then the respective
   /// parts of the grammar are not parsed.
-  virtual ParseResult parseArgument(Argument &result, bool allowType = false,
-                                    bool allowAttrs = false) = 0;
+  /// You can override default type parsing behavior with the typeParser arg.
+  virtual ParseResult
+  parseArgument(Argument &result, bool allowType = false,
+                bool allowAttrs = false,
+                function_ref<ParseResult(OpAsmParser &, Type &)> typeParser =
+                    nullptr) = 0;
 
   /// Parse a single argument if present.
-  virtual OptionalParseResult
-  parseOptionalArgument(Argument &result, bool allowType = false,
-                        bool allowAttrs = false) = 0;
+  virtual OptionalParseResult parseOptionalArgument(
+      Argument &result, bool allowType = false, bool allowAttrs = false,
+      function_ref<ParseResult(OpAsmParser &, Type &)> typeParser =
+          nullptr) = 0;
 
   /// Parse zero or more arguments with a specified surrounding delimiter.
   virtual ParseResult parseArgumentList(SmallVectorImpl<Argument> &result,
diff --git a/mlir/include/mlir/Interfaces/CallInterfaces.h b/mlir/include/mlir/Interfaces/CallInterfaces.h
index 2bf3a3ca5f8a8..66f0287471da5 100644
--- a/mlir/include/mlir/Interfaces/CallInterfaces.h
+++ b/mlir/include/mlir/Interfaces/CallInterfaces.h
@@ -37,6 +37,8 @@ Operation *resolveCallable(CallOpInterface call,
                            SymbolTableCollection *symbolTable = nullptr);
 
 /// Parse a function or call result list.
+/// You can override the default type parsing behavior using the typeParser
+/// parameter.
 ///
 ///   function-result-list ::= function-result-list-parens
 ///                          | non-function-type
@@ -45,31 +47,37 @@ Operation *resolveCallable(CallOpInterface call,
 ///   function-result-list-no-parens ::= function-result (`,` function-result)*
 ///   function-result ::= type attribute-dict?
 ///
-ParseResult
-parseFunctionResultList(OpAsmParser &parser, SmallVectorImpl<Type> &resultTypes,
-                        SmallVectorImpl<DictionaryAttr> &resultAttrs);
+ParseResult parseFunctionResultList(
+    OpAsmParser &parser, SmallVectorImpl<Type> &resultTypes,
+    SmallVectorImpl<DictionaryAttr> &resultAttrs,
+    function_ref<ParseResult(OpAsmParser &, Type &)> typeParser = nullptr);
 
 /// Parses a function signature using `parser`. This does not deal with function
 /// signatures containing SSA region arguments (to parse these signatures, use
 /// function_interface_impl::parseFunctionSignature). When
 /// `mustParseEmptyResult`, `-> ()` is expected when there is no result type.
+/// You can override the default type parsing behavior using the typeParser
+/// parameter.
 ///
 ///   no-ssa-function-signature ::= `(` no-ssa-function-arg-list `)`
 ///                               -> function-result-list
 ///   no-ssa-function-arg-list  ::= no-ssa-function-arg
 ///                               (`,` no-ssa-function-arg)*
 ///   no-ssa-function-arg       ::= type attribute-dict?
-ParseResult parseFunctionSignature(OpAsmParser &parser,
-                                   SmallVectorImpl<Type> &argTypes,
-                                   SmallVectorImpl<DictionaryAttr> &argAttrs,
-                                   SmallVectorImpl<Type> &resultTypes,
-                                   SmallVectorImpl<DictionaryAttr> &resultAttrs,
-                                   bool mustParseEmptyResult = true);
+ParseResult parseFunctionSignature(
+    OpAsmParser &parser, SmallVectorImpl<Type> &argTypes,
+    SmallVectorImpl<DictionaryAttr> &argAttrs,
+    SmallVectorImpl<Type> &resultTypes,
+    SmallVectorImpl<DictionaryAttr> &resultAttrs,
+    bool mustParseEmptyResult = true,
+    function_ref<ParseResult(OpAsmParser &, Type &)> typeParser = nullptr);
 
 /// Print a function signature for a call or callable operation. If a body
 /// region is provided, the SSA arguments are printed in the signature. When
 /// `printEmptyResult` is false, `-> function-result-list` is omitted when
 /// `resultTypes` is empty.
+/// You can override the default type printing behavior using the typePrinter
+/// parameter.
 ///
 ///   function-signature     ::= ssa-function-signature
 ///                            | no-ssa-function-signature
@@ -77,11 +85,11 @@ ParseResult parseFunctionSignature(OpAsmParser &parser,
 ///                            -> function-result-list
 ///   ssa-function-arg-list  ::= ssa-function-arg (`,` ssa-function-arg)*
 ///   ssa-function-arg       ::= `%`name `:` type attribute-dict?
-void printFunctionSignature(OpAsmPrinter &p, TypeRange argTypes,
-                            ArrayAttr argAttrs, bool isVariadic,
-                            TypeRange resultTypes, ArrayAttr resultAttrs,
-                            Region *body = nullptr,
-                            bool printEmptyResult = true);
+void printFunctionSignature(
+    OpAsmPrinter &p, TypeRange argTypes, ArrayAttr argAttrs, bool isVariadic,
+    TypeRange resultTypes, ArrayAttr resultAttrs, Region *body = nullptr,
+    bool printEmptyResult = true,
+    function_ref<void(OpAsmPrinter &, Type)> typePrinter = nullptr);
 
 /// Adds argument and result attributes, provided as `argAttrs` and
 /// `resultAttrs` arguments, to the list of operation attributes in `result`.
diff --git a/mlir/include/mlir/Interfaces/FunctionImplementation.h b/mlir/include/mlir/Interfaces/FunctionImplementation.h
index 374c2c534f87d..de89a6bc0d50a 100644
--- a/mlir/include/mlir/Interfaces/FunctionImplementation.h
+++ b/mlir/include/mlir/Interfaces/FunctionImplementation.h
@@ -45,11 +45,14 @@ using FuncTypeBuilder = function_ref<Type(
 /// indicates whether functions with variadic arguments are supported. The
 /// trailing arguments are populated by this function with names, types,
 /// attributes and locations of the arguments and those of the results.
+/// You can override the default type parsing behavior using the typeParser
+/// parameter.
 ParseResult parseFunctionSignatureWithArguments(
     OpAsmParser &parser, bool allowVariadic,
     SmallVectorImpl<OpAsmParser::Argument> &arguments, bool &isVariadic,
     SmallVectorImpl<Type> &resultTypes,
-    SmallVectorImpl<DictionaryAttr> &resultAttrs);
+    SmallVectorImpl<DictionaryAttr> &resultAttrs,
+    function_ref<ParseResult(OpAsmParser &, Type &)> typeParser = nullptr);
 
 /// Parser implementation for function-like operations.  Uses
 /// `funcTypeBuilder` to construct the custom function type given lists of
@@ -59,25 +62,32 @@ ParseResult parseFunctionSignatureWithArguments(
 /// whether the function is variadic.  If the builder returns a null type,
 /// `result` will not contain the `type` attribute.  The caller can then add a
 /// type, report the error or delegate the reporting to the op's verifier.
-ParseResult parseFunctionOp(OpAsmParser &parser, OperationState &result,
-                            bool allowVariadic, StringAttr typeAttrName,
-                            FuncTypeBuilder funcTypeBuilder,
-                            StringAttr argAttrsName, StringAttr resAttrsName);
+/// You can override the default type parsing behavior using the typeParser
+/// parameter.
+ParseResult parseFunctionOp(
+    OpAsmParser &parser, OperationState &result, bool allowVariadic,
+    StringAttr typeAttrName, FuncTypeBuilder funcTypeBuilder,
+    StringAttr argAttrsName, StringAttr resAttrsName,
+    function_ref<ParseResult(OpAsmParser &, Type &)> typeParser = nullptr);
 
 /// Printer implementation for function-like operations.
-void printFunctionOp(OpAsmPrinter &p, FunctionOpInterface op, bool isVariadic,
-                     StringRef typeAttrName, StringAttr argAttrsName,
-                     StringAttr resAttrsName);
+/// You can override the default type printing behavior using the typePrinter
+/// parameter.
+void printFunctionOp(
+    OpAsmPrinter &p, FunctionOpInterface op, bool isVariadic,
+    StringRef typeAttrName, StringAttr argAttrsName, StringAttr resAttrsName,
+    function_ref<void(OpAsmPrinter &, Type)> typePrinter = nullptr);
 
 /// Prints the signature of the function-like operation `op`. Assumes `op` has
 /// is a FunctionOpInterface and has passed verification.
-inline void printFunctionSignature(OpAsmPrinter &p, FunctionOpInterface op,
-                                   ArrayRef<Type> argTypes, bool isVariadic,
-                                   ArrayRef<Type> resultTypes) {
+inline void printFunctionSignature(
+    OpAsmPrinter &p, FunctionOpInterface op, ArrayRef<Type> argTypes,
+    bool isVariadic, ArrayRef<Type> resultTypes,
+    function_ref<void(OpAsmPrinter &, Type)> typePrinter = nullptr) {
   call_interface_impl::printFunctionSignature(
       p, argTypes, op.getArgAttrsAttr(), isVariadic, resultTypes,
       op.getResAttrsAttr(), &op->getRegion(0),
-      /*printEmptyResult=*/false);
+      /*printEmptyResult=*/false, typePrinter);
 }
 
 /// Prints the list of function prefixed with the "attributes" keyword. The
diff --git a/mlir/lib/AsmParser/Parser.cpp b/mlir/lib/AsmParser/Parser.cpp
index 756d3d01a4534..06282f648549f 100644
--- a/mlir/lib/AsmParser/Parser.cpp
+++ b/mlir/lib/AsmParser/Parser.cpp
@@ -1830,10 +1830,14 @@ class CustomOpAsmParser : public AsmParserImpl<OpAsmParser> {
   /// If `allowType` is false or `allowAttrs` are false then the respective
   /// parts of the grammar are not parsed.
   ParseResult parseArgument(Argument &result, bool allowType = false,
-                            bool allowAttrs = false) override {
+                            bool allowAttrs = false,
+                            function_ref<ParseResult(OpAsmParser &, Type &)>
+                                typeParser = nullptr) override {
     NamedAttrList attrs;
     if (parseOperand(result.ssaName, /*allowResultNumber=*/false) ||
-        (allowType && parseColonType(result.type)) ||
+        (allowType && !typeParser && parseColonType(result.type)) ||
+        (allowType && typeParser &&
+         (parseColon() || typeParser(*this, result.type))) ||
         (allowAttrs && parseOptionalAttrDict(attrs)) ||
         parseOptionalLocationSpecifier(result.sourceLoc))
       return failure();
@@ -1842,10 +1846,12 @@ class CustomOpAsmParser : public AsmParserImpl<OpAsmParser> {
   }
 
   /// Parse a single argument if present.
-  OptionalParseResult parseOptionalArgument(Argument &result, bool allowType,
-                                            bool allowAttrs) override {
+  OptionalParseResult parseOptionalArgument(
+      Argument &result, bool allowType, bool allowAttrs,
+      function_ref<ParseResult(OpAsmParser &, Type &)> typeParser =
+          nullptr) override {
     if (parser.getToken().is(Token::percent_identifier))
-      return parseArgument(result, allowType, allowAttrs);
+      return parseArgument(result, allowType, allowAttrs, typeParser);
     return std::nullopt;
   }
 
diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp
index f95ad290a1981..71a372369b4bf 100644
--- a/mlir/lib/IR/AsmPrinter.cpp
+++ b/mlir/lib/IR/AsmPrinter.cpp
@@ -783,9 +783,10 @@ class DummyAliasOperationPrinter : private OpAsmPrinter {
       print(&b);
   }
 
-  void printRegionArgument(BlockArgument arg, ArrayRef<NamedAttribute> argAttrs,
-                           bool omitType) override {
-    printType(arg.getType());
+  void printRegionArgument(
+      BlockArgument arg, ArrayRef<NamedAttribute> argAttrs, bool omitType,
+      function_ref<void(OpAsmPrinter &, Type)> typePrinter = nullptr) override {
+    typePrinter ? typePrinter(*this, arg.getType()) : printType(arg.getType());
     // Visit the argument location.
     if (printerFlags.shouldPrintDebugInfo())
       // TODO: Allow deferring argument locations.
@@ -3295,9 +3296,10 @@ class OperationPrinter : public AsmPrinter::Impl, private OpAsmPrinter {
   /// where location printing is controlled by the standard internal option.
   /// You may pass omitType=true to not print a type, and pass an empty
   /// attribute list if you don't care for attributes.
-  void printRegionArgument(BlockArgument arg,
-                           ArrayRef<NamedAttribute> argAttrs = {},
-                           bool omitType = false) override;
+  void printRegionArgument(
+      BlockArgument arg, ArrayRef<NamedAttribute> argAttrs = {},
+      bool omitType = false,
+      function_ref<void(OpAsmPrinter &, Type)> typePrinter = nullptr) override;
 
   /// Print the ID for the given value.
   void printOperand(Value value) override { printValueID(value); }
@@ -3545,13 +3547,13 @@ void OperationPrinter::printResourceFileMetadata(
 /// where location printing is controlled by the standard internal option.
 /// You may pass omitType=true to not print a type, and pass an empty
 /// attribute list if you don't care for attributes.
-void OperationPrinter::printRegionArgument(BlockArgument arg,
-                                           ArrayRef<NamedAttribute> argAttrs,
-                                           bool omitType) {
+void OperationPrinter::printRegionArgument(
+    BlockArgument arg, ArrayRef<NamedAttribute> argAttrs, bool omitType,
+    function_ref<void(OpAsmPrinter &, Type)> typePrinter) {
   printOperand(arg);
   if (!omitType) {
     os << ": ";
-    printType(arg.getType());
+    typePrinter ? typePrinter(*this, arg.getType()) : printType(arg.getType());
   }
   printOptionalAttrDict(argAttrs);
   // TODO: We should allow location aliases on block arguments.
diff --git a/mlir/lib/Interfaces/CallInterfaces.cpp b/mlir/lib/Interfaces/CallInterfaces.cpp
index e8ed4b339a0cb..a08338e514a08 100644
--- a/mlir/lib/Interfaces/CallInterfaces.cpp
+++ b/mlir/lib/Interfaces/CallInterfaces.cpp
@@ -15,15 +15,24 @@ using namespace mlir;
 // Argument and result attributes utilities
 //===----------------------------------------------------------------------===//
 
-static ParseResult
-parseTypeAndAttrList(OpAsmParser &parser, SmallVectorImpl<Type> &types,
-                     SmallVectorImpl<DictionaryAttr> &attrs) {
+static inline ParseResult defaultTypeParser(OpAsmParser &parser, Type &ty) {
+  return parser.parseType(ty);
+}
+
+static inline void defaultTypePrinter(OpAsmPrinter &printer, Type ty) {
+  printer << ty;
+}
+
+static ParseResult parseTypeAndAttrList(
+    OpAsmParser &parser, SmallVectorImpl<Type> &types,
+    SmallVectorImpl<DictionaryAttr> &attrs,
+    function_ref<ParseResult(OpAsmParser &, Type &)> typeParser) {
   // Parse individual function results.
   return parser.parseCommaSeparatedList([&]() -> ParseResult {
     types.emplace_back();
     attrs.emplace_back();
     NamedAttrList attrList;
-    if (parser.parseType(types.back()) ||
+    if (typeParser(parser, types.back()) ||
         parser.parseOptionalAttrDict(attrList))
       return failure();
     attrs.back() = attrList.getDictionary(parser.getContext());
@@ -33,12 +42,16 @@ parseTypeAndAttrList(OpAsmParser &parser, SmallVectorImpl<Type> &types,
 
 ParseResult call_interface_impl::parseFunctionResultList(
     OpAsmParser &parser, SmallVectorImpl<Type> &resultTypes,
-    SmallVectorImpl<DictionaryAttr> &resultAttrs) {
+    SmallVectorImpl<DictionaryAttr> &resultAttrs,
+    function_ref<ParseResult(OpAsmParser &, Type &)> typeParser) {
+  if (!typeParser)
+    typeParser = defaultTypeParser;
+
   if (failed(parser.parseOptionalLParen())) {
     // We already know that there is no `(`, so parse a type.
     // Because there is no `(`, it cannot be a function type.
     Type ty;
-    if (parser.parseType(ty))
+    if (typeParser(parser, ty))
       return failure();
     resultTypes.push_back(ty);
     resultAttrs.emplace_back();
@@ -48,7 +61,7 @@ ParseResult call_interface_impl::parseFunctionResultList(
   // Special case for an empty set of parens.
   if (succeeded(parser.parseOptionalRParen()))
     return success();
-  if (parseTypeAndAttrList(parser, resultTypes, resultAttrs))
+  if (parseTypeAndAttrList(parser, resultTypes, resultAttrs, typeParser))
     return failure();
   return parser.parseRParen();
 }
@@ -57,20 +70,24 @@ ParseResult call_interface_impl::parseFunctionSignature(
     OpAsmParser &parser, SmallVectorImpl<Type> &argTypes,
     SmallVectorImpl<DictionaryAttr> &argAttrs,
     SmallVectorImpl<Type> &resultTypes,
-    SmallVectorImpl<DictionaryAttr> &resultAttrs, bool mustParseEmptyResult) {
+    SmallVectorImpl<DictionaryAttr> &resultAttrs, bool mustParseEmptyResult,
+    function_ref<ParseResult(OpAsmParser &, Type &)> typeParser) {
+  if (!typeParser)
+    typeParser = defaultTypeParser;
+
   // Parse arguments.
   if (parser.parseLParen())
     return failure();
   if (failed(parser.parseOptionalRParen())) {
-    if (parseTypeAndAttrList(parser, argTypes, argAttrs))
+    if (parseTypeAndAttrList(parser, argTypes, argAttrs, typeParser))
       return failure();
     if (parser.parseRParen())
       return failure();
   }
   // Parse results.
   if (succeeded(parser.parseOptionalArrow()))
-    return call_interface_impl::parseFunctionResultList(parser, resultTypes,
-                                                        resultAttrs);
+    return call_interface_impl::parseFunctionResultList(
+        parser, resultTypes, resultAttrs, typeParser);
   if (mustParseEmptyResult)
     return failure();
   return success();
@@ -78,8 +95,12 @@ ParseResult call_interface_impl::parseFunctionSignature(
 
 /// Print a function result list. The provided `attrs` must either be null, or
 /// contain a set of DictionaryAttrs of the same arity as `types`.
-static void printFunctionResultList(OpAsmPrinter &p, TypeRange types,
-                                    ArrayAttr attrs) {
+static void
+printFunctionResultList(OpAsmPrinter &p, TypeRange types, ArrayAttr attrs,
+                        function_ref<void(OpAsmPrinter &, Type)> typePrinter) {
+  if (!typePrinter)
+    typePrinter = defaultTypePrinter;
+
   assert(!types.empty() && "Should not be called for empty result list.");
   assert((!attrs || attrs.size() == types.size()) &&
          "Invalid number of attributes.");
@@ -90,22 +111,41 @@ static void printFunctionResultList(OpAsmPrinter &p, TypeRange types,
   if (needsParens)
     os << '(';
   llvm::interleaveComma(llvm::seq<size_t>(0, types.size()), os, [&](size_t i) {
-    p.printType(types[i]);
+    typePrinter(p, types[i]);
     if (attrs)
       p.printOptionalAttrDict(llvm::cast<DictionaryAttr>(attrs[i]).getValue());
   });
   if (needsParens)
     os << ')';
 }
+
+/// Print a function type as `(inputs) -> results`, emitting each element type
+/// through `typePrinter` (or the default printer when none is given). The
+/// result list is parenthesized unless it holds exactly one type that is not
+/// itself a function type, matching the builtin function-type syntax.
+static void
+printFunctionalType(OpAsmPrinter &p, TypeRange inputs, TypeRange results,
+                    function_ref<void(OpAsmPrinter &, Type)> typePrinter) {
+  // Calling an empty function_ref is UB; fall back like
+  // printFunctionResultList does.
+  if (!typePrinter)
+    typePrinter = defaultTypePrinter;
+
+  p << '(';
+  llvm::interleaveComma(inputs, p, [&](Type ty) { typePrinter(p, ty); });
+  // The arrow separates inputs from results; without it the output cannot be
+  // parsed back as a function type.
+  p << ") -> ";
+
+  // `hasSingleElement` is checked first, so dereferencing `begin()` is safe.
+  bool wrapped = !llvm::hasSingleElement(results) ||
+                 llvm::isa<FunctionType>(*results.begin());
+  if (wrapped)
+    p << '(';
+  llvm::interleaveComma(results, p, [&](Type ty) { typePrinter(p, ty); });
+  if (wrapped)
+    p << ')';
+}
 
 void call_interface_impl::printFunctionSignature(
     OpAsmPrinter &p, TypeRange argTypes, ArrayAttr argAttrs, bool isVariadic,
     T...
[truncated]

``````````

</details>


https://github.com/llvm/llvm-project/pull/146979


More information about the Mlir-commits mailing list