[Mlir-commits] [mlir] fbc253f - [mlir] FunctionOpInterface: make get/setFunctionType interface methods

Jeff Niu llvmlistbot at llvm.org
Thu Dec 8 11:32:46 PST 2022


Author: Jeff Niu
Date: 2022-12-08T11:32:27-08:00
New Revision: fbc253fe81da4e1d6bfa2519e01e03f21d8c40a8

URL: https://github.com/llvm/llvm-project/commit/fbc253fe81da4e1d6bfa2519e01e03f21d8c40a8
DIFF: https://github.com/llvm/llvm-project/commit/fbc253fe81da4e1d6bfa2519e01e03f21d8c40a8.diff

LOG: [mlir] FunctionOpInterface: make get/setFunctionType interface methods

This patch removes the requirement that implementors of FunctionOpInterface
store their type in an attribute named `function_type`. Instead, the type is
provided through two interface methods, `getFunctionType` and
`setFunctionTypeAttr` (the setter takes a generic `TypeAttr` because functions
may use different concrete function types). ODS should implement both methods
automatically for ops that define a `$function_type` attribute.

This also allows implementors of FunctionOpInterface to, for example,
materialize their function type on demand instead of storing it in an attribute.

Importantly, the function implementation helpers (e.g. `parseFunctionOp` and
`printFunctionOp`) still work with a type attribute when parsing and printing
functions; callers now pass the name of that attribute explicitly.

Reviewed By: rriddle, lattner

Differential Revision: https://reviews.llvm.org/D139447

Added: 
    

Modified: 
    mlir/examples/toy/Ch2/mlir/Dialect.cpp
    mlir/examples/toy/Ch3/mlir/Dialect.cpp
    mlir/examples/toy/Ch4/mlir/Dialect.cpp
    mlir/examples/toy/Ch5/mlir/Dialect.cpp
    mlir/examples/toy/Ch6/mlir/Dialect.cpp
    mlir/examples/toy/Ch7/mlir/Dialect.cpp
    mlir/include/mlir/IR/FunctionImplementation.h
    mlir/include/mlir/IR/FunctionInterfaces.h
    mlir/include/mlir/IR/FunctionInterfaces.td
    mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp
    mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp
    mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp
    mlir/lib/Dialect/Async/IR/Async.cpp
    mlir/lib/Dialect/Func/IR/FuncOps.cpp
    mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
    mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
    mlir/lib/Dialect/MLProgram/IR/MLProgramOps.cpp
    mlir/lib/Dialect/PDLInterp/IR/PDLInterp.cpp
    mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
    mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
    mlir/lib/Dialect/Shape/IR/Shape.cpp
    mlir/lib/IR/FunctionImplementation.cpp
    mlir/lib/IR/FunctionInterfaces.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/examples/toy/Ch2/mlir/Dialect.cpp b/mlir/examples/toy/Ch2/mlir/Dialect.cpp
index dbc1efb3d06be..ac12c5c0ac512 100644
--- a/mlir/examples/toy/Ch2/mlir/Dialect.cpp
+++ b/mlir/examples/toy/Ch2/mlir/Dialect.cpp
@@ -211,7 +211,8 @@ mlir::ParseResult FuncOp::parse(mlir::OpAsmParser &parser,
          std::string &) { return builder.getFunctionType(argTypes, results); };
 
   return mlir::function_interface_impl::parseFunctionOp(
-      parser, result, /*allowVariadic=*/false, buildFuncType);
+      parser, result, /*allowVariadic=*/false,
+      getFunctionTypeAttrName(result.name), buildFuncType);
 }
 
 void FuncOp::print(mlir::OpAsmPrinter &p) {

diff  --git a/mlir/examples/toy/Ch3/mlir/Dialect.cpp b/mlir/examples/toy/Ch3/mlir/Dialect.cpp
index 50e2dfc7f4a3e..75cb57ebd7b17 100644
--- a/mlir/examples/toy/Ch3/mlir/Dialect.cpp
+++ b/mlir/examples/toy/Ch3/mlir/Dialect.cpp
@@ -198,7 +198,8 @@ mlir::ParseResult FuncOp::parse(mlir::OpAsmParser &parser,
          std::string &) { return builder.getFunctionType(argTypes, results); };
 
   return mlir::function_interface_impl::parseFunctionOp(
-      parser, result, /*allowVariadic=*/false, buildFuncType);
+      parser, result, /*allowVariadic=*/false,
+      getFunctionTypeAttrName(result.name), buildFuncType);
 }
 
 void FuncOp::print(mlir::OpAsmPrinter &p) {

diff  --git a/mlir/examples/toy/Ch4/mlir/Dialect.cpp b/mlir/examples/toy/Ch4/mlir/Dialect.cpp
index 0a6195b12d5d4..2d5a369630a0d 100644
--- a/mlir/examples/toy/Ch4/mlir/Dialect.cpp
+++ b/mlir/examples/toy/Ch4/mlir/Dialect.cpp
@@ -287,7 +287,8 @@ mlir::ParseResult FuncOp::parse(mlir::OpAsmParser &parser,
          std::string &) { return builder.getFunctionType(argTypes, results); };
 
   return mlir::function_interface_impl::parseFunctionOp(
-      parser, result, /*allowVariadic=*/false, buildFuncType);
+      parser, result, /*allowVariadic=*/false,
+      getFunctionTypeAttrName(result.name), buildFuncType);
 }
 
 void FuncOp::print(mlir::OpAsmPrinter &p) {

diff  --git a/mlir/examples/toy/Ch5/mlir/Dialect.cpp b/mlir/examples/toy/Ch5/mlir/Dialect.cpp
index f236a1ffe0e5a..280bf3122fbd5 100644
--- a/mlir/examples/toy/Ch5/mlir/Dialect.cpp
+++ b/mlir/examples/toy/Ch5/mlir/Dialect.cpp
@@ -287,7 +287,8 @@ mlir::ParseResult FuncOp::parse(mlir::OpAsmParser &parser,
          std::string &) { return builder.getFunctionType(argTypes, results); };
 
   return mlir::function_interface_impl::parseFunctionOp(
-      parser, result, /*allowVariadic=*/false, buildFuncType);
+      parser, result, /*allowVariadic=*/false,
+      getFunctionTypeAttrName(result.name), buildFuncType);
 }
 
 void FuncOp::print(mlir::OpAsmPrinter &p) {

diff  --git a/mlir/examples/toy/Ch6/mlir/Dialect.cpp b/mlir/examples/toy/Ch6/mlir/Dialect.cpp
index f236a1ffe0e5a..280bf3122fbd5 100644
--- a/mlir/examples/toy/Ch6/mlir/Dialect.cpp
+++ b/mlir/examples/toy/Ch6/mlir/Dialect.cpp
@@ -287,7 +287,8 @@ mlir::ParseResult FuncOp::parse(mlir::OpAsmParser &parser,
          std::string &) { return builder.getFunctionType(argTypes, results); };
 
   return mlir::function_interface_impl::parseFunctionOp(
-      parser, result, /*allowVariadic=*/false, buildFuncType);
+      parser, result, /*allowVariadic=*/false,
+      getFunctionTypeAttrName(result.name), buildFuncType);
 }
 
 void FuncOp::print(mlir::OpAsmPrinter &p) {

diff  --git a/mlir/examples/toy/Ch7/mlir/Dialect.cpp b/mlir/examples/toy/Ch7/mlir/Dialect.cpp
index cc66a5d44b5f4..b0d213027f4c9 100644
--- a/mlir/examples/toy/Ch7/mlir/Dialect.cpp
+++ b/mlir/examples/toy/Ch7/mlir/Dialect.cpp
@@ -314,7 +314,8 @@ mlir::ParseResult FuncOp::parse(mlir::OpAsmParser &parser,
          std::string &) { return builder.getFunctionType(argTypes, results); };
 
   return mlir::function_interface_impl::parseFunctionOp(
-      parser, result, /*allowVariadic=*/false, buildFuncType);
+      parser, result, /*allowVariadic=*/false,
+      getFunctionTypeAttrName(result.name), buildFuncType);
 }
 
 void FuncOp::print(mlir::OpAsmPrinter &p) {

diff  --git a/mlir/include/mlir/IR/FunctionImplementation.h b/mlir/include/mlir/IR/FunctionImplementation.h
index 5265f781d1a77..f4c0cc03050fe 100644
--- a/mlir/include/mlir/IR/FunctionImplementation.h
+++ b/mlir/include/mlir/IR/FunctionImplementation.h
@@ -69,17 +69,19 @@ Type getFunctionType(Builder &builder, ArrayRef<OpAsmParser::Argument> argAttrs,
 
 /// Parser implementation for function-like operations.  Uses
 /// `funcTypeBuilder` to construct the custom function type given lists of
-/// input and output types.  If `allowVariadic` is set, the parser will accept
+/// input and output types. The parser sets the `typeAttrName` attribute to the
+/// resulting function type. If `allowVariadic` is set, the parser will accept
 /// trailing ellipsis in the function signature and indicate to the builder
 /// whether the function is variadic.  If the builder returns a null type,
 /// `result` will not contain the `type` attribute.  The caller can then add a
 /// type, report the error or delegate the reporting to the op's verifier.
 ParseResult parseFunctionOp(OpAsmParser &parser, OperationState &result,
-                            bool allowVariadic,
+                            bool allowVariadic, StringAttr typeAttrName,
                             FuncTypeBuilder funcTypeBuilder);
 
 /// Printer implementation for function-like operations.
-void printFunctionOp(OpAsmPrinter &p, FunctionOpInterface op, bool isVariadic);
+void printFunctionOp(OpAsmPrinter &p, FunctionOpInterface op, bool isVariadic,
+                     StringRef typeAttrName);
 
 /// Prints the signature of the function-like operation `op`. Assumes `op` has
 /// is a FunctionOpInterface and has passed verification.
@@ -92,8 +94,7 @@ void printFunctionSignature(OpAsmPrinter &p, Operation *op,
 /// function-like operation internally are not printed. Nothing is printed
 /// if all attributes are elided. Assumes `op` is a FunctionOpInterface and
 /// has passed verification.
-void printFunctionAttributes(OpAsmPrinter &p, Operation *op, unsigned numInputs,
-                             unsigned numResults,
+void printFunctionAttributes(OpAsmPrinter &p, Operation *op,
                              ArrayRef<StringRef> elided = {});
 
 } // namespace function_interface_impl

diff  --git a/mlir/include/mlir/IR/FunctionInterfaces.h b/mlir/include/mlir/IR/FunctionInterfaces.h
index 23fd884d97f14..bc2ec4751c582 100644
--- a/mlir/include/mlir/IR/FunctionInterfaces.h
+++ b/mlir/include/mlir/IR/FunctionInterfaces.h
@@ -22,12 +22,10 @@
 #include "llvm/ADT/SmallString.h"
 
 namespace mlir {
+class FunctionOpInterface;
 
 namespace function_interface_impl {
 
-/// Return the name of the attribute used for function types.
-inline StringRef getTypeAttrName() { return "function_type"; }
-
 /// Return the name of the attribute used for function argument attributes.
 inline StringRef getArgDictAttrName() { return "arg_attrs"; }
 
@@ -72,28 +70,29 @@ inline ArrayRef<NamedAttribute> getResultAttrs(Operation *op, unsigned index) {
 }
 
 /// Insert the specified arguments and update the function type attribute.
-void insertFunctionArguments(Operation *op, ArrayRef<unsigned> argIndices,
-                             TypeRange argTypes,
+void insertFunctionArguments(FunctionOpInterface op,
+                             ArrayRef<unsigned> argIndices, TypeRange argTypes,
                              ArrayRef<DictionaryAttr> argAttrs,
                              ArrayRef<Location> argLocs,
                              unsigned originalNumArgs, Type newType);
 
 /// Insert the specified results and update the function type attribute.
-void insertFunctionResults(Operation *op, ArrayRef<unsigned> resultIndices,
+void insertFunctionResults(FunctionOpInterface op,
+                           ArrayRef<unsigned> resultIndices,
                            TypeRange resultTypes,
                            ArrayRef<DictionaryAttr> resultAttrs,
                            unsigned originalNumResults, Type newType);
 
 /// Erase the specified arguments and update the function type attribute.
-void eraseFunctionArguments(Operation *op, const BitVector &argIndices,
+void eraseFunctionArguments(FunctionOpInterface op, const BitVector &argIndices,
                             Type newType);
 
 /// Erase the specified results and update the function type attribute.
-void eraseFunctionResults(Operation *op, const BitVector &resultIndices,
-                          Type newType);
+void eraseFunctionResults(FunctionOpInterface op,
+                          const BitVector &resultIndices, Type newType);
 
 /// Set a FunctionOpInterface operation's type signature.
-void setFunctionType(Operation *op, Type newType);
+void setFunctionType(FunctionOpInterface op, Type newType);
 
 /// Insert a set of `newTypes` into `oldTypes` at the given `indices`. If any
 /// types are inserted, `storage` is used to hold the new type list. The new
@@ -207,10 +206,6 @@ Attribute removeResultAttr(ConcreteType op, unsigned index, StringAttr name) {
 /// method on FunctionOpInterface::Trait.
 template <typename ConcreteOp>
 LogicalResult verifyTrait(ConcreteOp op) {
-  if (!op.getFunctionTypeAttr())
-    return op.emitOpError("requires a type attribute '")
-           << function_interface_impl::getTypeAttrName() << '\'';
-
   if (failed(op.verifyType()))
     return failure();
 

diff  --git a/mlir/include/mlir/IR/FunctionInterfaces.td b/mlir/include/mlir/IR/FunctionInterfaces.td
index c56129ea895d9..e86057aa7ec2f 100644
--- a/mlir/include/mlir/IR/FunctionInterfaces.td
+++ b/mlir/include/mlir/IR/FunctionInterfaces.td
@@ -49,6 +49,16 @@ def FunctionOpInterface : OpInterface<"FunctionOpInterface"> {
           for each of the function results.
   }];
   let methods = [
+    InterfaceMethod<[{
+      Returns the type of the function.
+    }],
+    "::mlir::Type", "getFunctionType">,
+    InterfaceMethod<[{
+      Set the type of the function. This method should perform an unsafe
+      modification to the function type; it should not update argument or
+      result attributes.
+    }],
+    "void", "setFunctionTypeAttr", (ins "::mlir::TypeAttr":$type)>,
     InterfaceMethod<[{
       Returns the function argument types based exclusively on
       the type (to allow for this method may be called on function
@@ -139,7 +149,7 @@ def FunctionOpInterface : OpInterface<"FunctionOpInterface"> {
         ArrayRef<NamedAttribute> attrs, TypeRange inputTypes) {
       state.addAttribute(SymbolTable::getSymbolAttrName(),
                         builder.getStringAttr(name));
-      state.addAttribute(function_interface_impl::getTypeAttrName(),
+      state.addAttribute(ConcreteOp::getFunctionTypeAttrName(state.name),
                         TypeAttr::get(type));
       state.attributes.append(attrs.begin(), attrs.end());
 
@@ -244,11 +254,6 @@ def FunctionOpInterface : OpInterface<"FunctionOpInterface"> {
     // the derived operation, which should already have these defined
     // (via ODS).
 
-    /// Returns the name of the attribute used for function types.
-    static StringRef getTypeAttrName() {
-      return function_interface_impl::getTypeAttrName();
-    }
-
     /// Returns the name of the attribute used for function argument attributes.
     static StringRef getArgDictAttrName() {
       return function_interface_impl::getArgDictAttrName();
@@ -259,15 +264,6 @@ def FunctionOpInterface : OpInterface<"FunctionOpInterface"> {
       return function_interface_impl::getResultDictAttrName();
     }
 
-    /// Return the attribute containing the type of this function.
-    TypeAttr getFunctionTypeAttr() {
-      return this->getOperation()->template getAttrOfType<TypeAttr>(
-          getTypeAttrName());
-    }
-
-    /// Return the type of this function.
-    Type getFunctionType() { return getFunctionTypeAttr().getValue(); }
-
     //===------------------------------------------------------------------===//
     // Argument and Result Handling
     //===------------------------------------------------------------------===//

diff  --git a/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp b/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp
index d0e82de839c0c..9f522aaa49f92 100644
--- a/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp
+++ b/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp
@@ -59,12 +59,11 @@ using namespace mlir;
 /// Only retain those attributes that are not constructed by
 /// `LLVMFuncOp::build`. If `filterArgAttrs` is set, also filter out argument
 /// attributes.
-static void filterFuncAttributes(ArrayRef<NamedAttribute> attrs,
-                                 bool filterArgAndResAttrs,
+static void filterFuncAttributes(func::FuncOp func, bool filterArgAndResAttrs,
                                  SmallVectorImpl<NamedAttribute> &result) {
-  for (const auto &attr : attrs) {
+  for (const NamedAttribute &attr : func->getAttrs()) {
     if (attr.getName() == SymbolTable::getSymbolAttrName() ||
-        attr.getName() == FunctionOpInterface::getTypeAttrName() ||
+        attr.getName() == func.getFunctionTypeAttrName() ||
         attr.getName() == "func.varargs" ||
         (filterArgAndResAttrs &&
          (attr.getName() == FunctionOpInterface::getArgDictAttrName() ||
@@ -138,8 +137,7 @@ static void wrapForExternalCallers(OpBuilder &rewriter, Location loc,
                                    LLVM::LLVMFuncOp newFuncOp) {
   auto type = funcOp.getFunctionType();
   SmallVector<NamedAttribute, 4> attributes;
-  filterFuncAttributes(funcOp->getAttrs(), /*filterArgAndResAttrs=*/false,
-                       attributes);
+  filterFuncAttributes(funcOp, /*filterArgAndResAttrs=*/false, attributes);
   auto [wrapperFuncType, resultIsNowArg] =
       typeConverter.convertFunctionTypeCWrapper(type);
   if (resultIsNowArg)
@@ -204,8 +202,7 @@ static void wrapExternalFunction(OpBuilder &builder, Location loc,
   assert(wrapperType && "unexpected type conversion failure");
 
   SmallVector<NamedAttribute, 4> attributes;
-  filterFuncAttributes(funcOp->getAttrs(), /*filterArgAndResAttrs=*/false,
-                       attributes);
+  filterFuncAttributes(funcOp, /*filterArgAndResAttrs=*/false, attributes);
 
   if (resultIsNowArg)
     prependResAttrsToArgAttrs(builder, attributes, funcOp.getNumArguments());
@@ -304,8 +301,7 @@ struct FuncOpConversionBase : public ConvertOpToLLVMPattern<func::FuncOp> {
     // Propagate argument/result attributes to all converted arguments/result
     // obtained after converting a given original argument/result.
     SmallVector<NamedAttribute, 4> attributes;
-    filterFuncAttributes(funcOp->getAttrs(), /*filterArgAndResAttrs=*/true,
-                         attributes);
+    filterFuncAttributes(funcOp, /*filterArgAndResAttrs=*/true, attributes);
     if (ArrayAttr resAttrDicts = funcOp.getAllResultAttrs()) {
       assert(!resAttrDicts.empty() && "expected array to be non-empty");
       auto newResAttrDicts =

diff  --git a/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp b/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp
index 85001d54d093d..48effe24f674e 100644
--- a/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp
+++ b/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp
@@ -60,7 +60,7 @@ GPUFuncOpLowering::matchAndRewrite(gpu::GPUFuncOp gpuFuncOp, OpAdaptor adaptor,
   SmallVector<NamedAttribute, 4> attributes;
   for (const auto &attr : gpuFuncOp->getAttrs()) {
     if (attr.getName() == SymbolTable::getSymbolAttrName() ||
-        attr.getName() == FunctionOpInterface::getTypeAttrName() ||
+        attr.getName() == gpuFuncOp.getFunctionTypeAttrName() ||
         attr.getName() == gpu::GPUFuncOp::getNumWorkgroupAttributionsAttrName())
       continue;
     attributes.push_back(attr);

diff  --git a/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp b/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp
index 119b1d3dea91e..2a8389598f36a 100644
--- a/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp
+++ b/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp
@@ -226,7 +226,7 @@ lowerAsEntryFunction(gpu::GPUFuncOp funcOp, TypeConverter &typeConverter,
       rewriter.getFunctionType(signatureConverter.getConvertedTypes(),
                                std::nullopt));
   for (const auto &namedAttr : funcOp->getAttrs()) {
-    if (namedAttr.getName() == FunctionOpInterface::getTypeAttrName() ||
+    if (namedAttr.getName() == funcOp.getFunctionTypeAttrName() ||
         namedAttr.getName() == SymbolTable::getSymbolAttrName())
       continue;
     newFuncOp->setAttr(namedAttr.getName(), namedAttr.getValue());

diff  --git a/mlir/lib/Dialect/Async/IR/Async.cpp b/mlir/lib/Dialect/Async/IR/Async.cpp
index e0772b4dd90bc..064bf525db238 100644
--- a/mlir/lib/Dialect/Async/IR/Async.cpp
+++ b/mlir/lib/Dialect/Async/IR/Async.cpp
@@ -332,8 +332,7 @@ void FuncOp::build(OpBuilder &builder, OperationState &state, StringRef name,
                    ArrayRef<DictionaryAttr> argAttrs) {
   state.addAttribute(SymbolTable::getSymbolAttrName(),
                      builder.getStringAttr(name));
-  state.addAttribute(FunctionOpInterface::getTypeAttrName(),
-                     TypeAttr::get(type));
+  state.addAttribute(getFunctionTypeAttrName(state.name), TypeAttr::get(type));
 
   state.attributes.append(attrs.begin(), attrs.end());
   state.addRegion();
@@ -352,11 +351,13 @@ ParseResult FuncOp::parse(OpAsmParser &parser, OperationState &result) {
          std::string &) { return builder.getFunctionType(argTypes, results); };
 
   return function_interface_impl::parseFunctionOp(
-      parser, result, /*allowVariadic=*/false, buildFuncType);
+      parser, result, /*allowVariadic=*/false,
+      getFunctionTypeAttrName(result.name), buildFuncType);
 }
 
 void FuncOp::print(OpAsmPrinter &p) {
-  function_interface_impl::printFunctionOp(p, *this, /*isVariadic=*/false);
+  function_interface_impl::printFunctionOp(p, *this, /*isVariadic=*/false,
+                                           getFunctionTypeAttrName());
 }
 
 /// Check that the result type of async.func is not void and must be

diff  --git a/mlir/lib/Dialect/Func/IR/FuncOps.cpp b/mlir/lib/Dialect/Func/IR/FuncOps.cpp
index 961cf2eb36e35..fc9bd115e2223 100644
--- a/mlir/lib/Dialect/Func/IR/FuncOps.cpp
+++ b/mlir/lib/Dialect/Func/IR/FuncOps.cpp
@@ -244,8 +244,7 @@ void FuncOp::build(OpBuilder &builder, OperationState &state, StringRef name,
                    ArrayRef<DictionaryAttr> argAttrs) {
   state.addAttribute(SymbolTable::getSymbolAttrName(),
                      builder.getStringAttr(name));
-  state.addAttribute(FunctionOpInterface::getTypeAttrName(),
-                     TypeAttr::get(type));
+  state.addAttribute(getFunctionTypeAttrName(state.name), TypeAttr::get(type));
   state.attributes.append(attrs.begin(), attrs.end());
   state.addRegion();
 
@@ -263,11 +262,13 @@ ParseResult FuncOp::parse(OpAsmParser &parser, OperationState &result) {
          std::string &) { return builder.getFunctionType(argTypes, results); };
 
   return function_interface_impl::parseFunctionOp(
-      parser, result, /*allowVariadic=*/false, buildFuncType);
+      parser, result, /*allowVariadic=*/false,
+      getFunctionTypeAttrName(result.name), buildFuncType);
 }
 
 void FuncOp::print(OpAsmPrinter &p) {
-  function_interface_impl::printFunctionOp(p, *this, /*isVariadic=*/false);
+  function_interface_impl::printFunctionOp(p, *this, /*isVariadic=*/false,
+                                           getFunctionTypeAttrName());
 }
 
 /// Clone the internal blocks from this function into dest and all attributes

diff  --git a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
index 7f73a651d0e9b..80db6461ecc55 100644
--- a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
+++ b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
@@ -859,7 +859,8 @@ void GPUFuncOp::build(OpBuilder &builder, OperationState &result,
                       ArrayRef<NamedAttribute> attrs) {
   result.addAttribute(SymbolTable::getSymbolAttrName(),
                       builder.getStringAttr(name));
-  result.addAttribute(getTypeAttrName(), TypeAttr::get(type));
+  result.addAttribute(getFunctionTypeAttrName(result.name),
+                      TypeAttr::get(type));
   result.addAttribute(getNumWorkgroupAttributionsAttrName(),
                       builder.getI64IntegerAttr(workgroupAttributions.size()));
   result.addAttributes(attrs);
@@ -930,7 +931,8 @@ ParseResult GPUFuncOp::parse(OpAsmParser &parser, OperationState &result) {
   for (auto &arg : entryArgs)
     argTypes.push_back(arg.type);
   auto type = builder.getFunctionType(argTypes, resultTypes);
-  result.addAttribute(GPUFuncOp::getTypeAttrName(), TypeAttr::get(type));
+  result.addAttribute(getFunctionTypeAttrName(result.name),
+                      TypeAttr::get(type));
 
   function_interface_impl::addArgAndResultAttrs(builder, result, entryArgs,
                                                 resultAttrs);
@@ -992,19 +994,14 @@ void GPUFuncOp::print(OpAsmPrinter &p) {
     p << ' ' << getKernelKeyword();
 
   function_interface_impl::printFunctionAttributes(
-      p, *this, type.getNumInputs(), type.getNumResults(),
+      p, *this,
       {getNumWorkgroupAttributionsAttrName(),
-       GPUDialect::getKernelFuncAttrName()});
+       GPUDialect::getKernelFuncAttrName(), getFunctionTypeAttrName()});
   p << ' ';
   p.printRegion(getBody(), /*printEntryBlockArgs=*/false);
 }
 
 LogicalResult GPUFuncOp::verifyType() {
-  Type type = getFunctionTypeAttr().getValue();
-  if (!type.isa<FunctionType>())
-    return emitOpError("requires '" + getTypeAttrName() +
-                       "' attribute of function type");
-
   if (isKernel() && getFunctionType().getNumResults() != 0)
     return emitOpError() << "expected void return type for kernel function";
 

diff  --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
index f114acdbd0f86..6b428a171d79e 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
@@ -2090,7 +2090,7 @@ ParseResult LLVMFuncOp::parse(OpAsmParser &parser, OperationState &result) {
                             function_interface_impl::VariadicFlag(isVariadic));
   if (!type)
     return failure();
-  result.addAttribute(FunctionOpInterface::getTypeAttrName(),
+  result.addAttribute(getFunctionTypeAttrName(result.name),
                       TypeAttr::get(type));
 
   if (failed(parser.parseOptionalAttrDictWithKeyword(result.attributes)))
@@ -2130,8 +2130,8 @@ void LLVMFuncOp::print(OpAsmPrinter &p) {
   function_interface_impl::printFunctionSignature(p, *this, argTypes,
                                                   isVarArg(), resTypes);
   function_interface_impl::printFunctionAttributes(
-      p, *this, argTypes.size(), resTypes.size(),
-      {getLinkageAttrName(), getCConvAttrName()});
+      p, *this,
+      {getFunctionTypeAttrName(), getLinkageAttrName(), getCConvAttrName()});
 
   // Print the body if this is not an external function.
   Region &body = getBody();

diff  --git a/mlir/lib/Dialect/MLProgram/IR/MLProgramOps.cpp b/mlir/lib/Dialect/MLProgram/IR/MLProgramOps.cpp
index 2f1e4b93a6ac3..27c613088df56 100644
--- a/mlir/lib/Dialect/MLProgram/IR/MLProgramOps.cpp
+++ b/mlir/lib/Dialect/MLProgram/IR/MLProgramOps.cpp
@@ -152,11 +152,13 @@ ParseResult FuncOp::parse(OpAsmParser &parser, OperationState &result) {
          std::string &) { return builder.getFunctionType(argTypes, results); };
 
   return function_interface_impl::parseFunctionOp(
-      parser, result, /*allowVariadic=*/false, buildFuncType);
+      parser, result, /*allowVariadic=*/false,
+      getFunctionTypeAttrName(result.name), buildFuncType);
 }
 
 void FuncOp::print(OpAsmPrinter &p) {
-  function_interface_impl::printFunctionOp(p, *this, /*isVariadic=*/false);
+  function_interface_impl::printFunctionOp(p, *this, /*isVariadic=*/false,
+                                           getFunctionTypeAttrName());
 }
 
 //===----------------------------------------------------------------------===//
@@ -313,11 +315,13 @@ ParseResult SubgraphOp::parse(OpAsmParser &parser, OperationState &result) {
          std::string &) { return builder.getFunctionType(argTypes, results); };
 
   return function_interface_impl::parseFunctionOp(
-      parser, result, /*allowVariadic=*/false, buildFuncType);
+      parser, result, /*allowVariadic=*/false,
+      getFunctionTypeAttrName(result.name), buildFuncType);
 }
 
 void SubgraphOp::print(OpAsmPrinter &p) {
-  function_interface_impl::printFunctionOp(p, *this, /*isVariadic=*/false);
+  function_interface_impl::printFunctionOp(p, *this, /*isVariadic=*/false,
+                                           getFunctionTypeAttrName());
 }
 
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/lib/Dialect/PDLInterp/IR/PDLInterp.cpp b/mlir/lib/Dialect/PDLInterp/IR/PDLInterp.cpp
index e8a61ef4c6a4d..28fc4dbf97aea 100644
--- a/mlir/lib/Dialect/PDLInterp/IR/PDLInterp.cpp
+++ b/mlir/lib/Dialect/PDLInterp/IR/PDLInterp.cpp
@@ -220,11 +220,13 @@ ParseResult FuncOp::parse(OpAsmParser &parser, OperationState &result) {
          std::string &) { return builder.getFunctionType(argTypes, results); };
 
   return function_interface_impl::parseFunctionOp(
-      parser, result, /*allowVariadic=*/false, buildFuncType);
+      parser, result, /*allowVariadic=*/false,
+      getFunctionTypeAttrName(result.name), buildFuncType);
 }
 
 void FuncOp::print(OpAsmPrinter &p) {
-  function_interface_impl::printFunctionOp(p, *this, /*isVariadic=*/false);
+  function_interface_impl::printFunctionOp(p, *this, /*isVariadic=*/false,
+                                           getFunctionTypeAttrName());
 }
 
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
index 52ad8ad5fe7c7..3ce3913f2814b 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
@@ -2382,7 +2382,7 @@ ParseResult spirv::FuncOp::parse(OpAsmParser &parser, OperationState &result) {
   for (auto &arg : entryArgs)
     argTypes.push_back(arg.type);
   auto fnType = builder.getFunctionType(argTypes, resultTypes);
-  result.addAttribute(FunctionOpInterface::getTypeAttrName(),
+  result.addAttribute(getFunctionTypeAttrName(result.name),
                       TypeAttr::get(fnType));
 
   // Parse the optional function control keyword.
@@ -2417,8 +2417,9 @@ void spirv::FuncOp::print(OpAsmPrinter &printer) {
   printer << " \"" << spirv::stringifyFunctionControl(getFunctionControl())
           << "\"";
   function_interface_impl::printFunctionAttributes(
-      printer, *this, fnType.getNumInputs(), fnType.getNumResults(),
-      {spirv::attributeName<spirv::FunctionControl>()});
+      printer, *this,
+      {spirv::attributeName<spirv::FunctionControl>(),
+       getFunctionTypeAttrName(), getFunctionControlAttrName()});
 
   // Print the body if this is not an external function.
   Region &body = this->getBody();
@@ -2430,10 +2431,6 @@ void spirv::FuncOp::print(OpAsmPrinter &printer) {
 }
 
 LogicalResult spirv::FuncOp::verifyType() {
-  auto type = getFunctionTypeAttr().getValue();
-  if (!type.isa<FunctionType>())
-    return emitOpError("requires '" + getTypeAttrName() +
-                       "' attribute of function type");
   if (getFunctionType().getNumResults() > 1)
     return emitOpError("cannot have more than one result");
   return success();
@@ -2473,7 +2470,7 @@ void spirv::FuncOp::build(OpBuilder &builder, OperationState &state,
                           ArrayRef<NamedAttribute> attrs) {
   state.addAttribute(SymbolTable::getSymbolAttrName(),
                      builder.getStringAttr(name));
-  state.addAttribute(getTypeAttrName(), TypeAttr::get(type));
+  state.addAttribute(getFunctionTypeAttrName(state.name), TypeAttr::get(type));
   state.addAttribute(spirv::attributeName<spirv::FunctionControl>(),
                      builder.getAttr<spirv::FunctionControlAttr>(control));
   state.attributes.append(attrs.begin(), attrs.end());

diff  --git a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
index 2772c0150cda7..62e3a3de0f11b 100644
--- a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
+++ b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
@@ -531,7 +531,7 @@ FuncOpConversion::matchAndRewrite(func::FuncOp funcOp, OpAdaptor adaptor,
 
   // Copy over all attributes other than the function name and type.
   for (const auto &namedAttr : funcOp->getAttrs()) {
-    if (namedAttr.getName() != FunctionOpInterface::getTypeAttrName() &&
+    if (namedAttr.getName() != funcOp.getFunctionTypeAttrName() &&
         namedAttr.getName() != SymbolTable::getSymbolAttrName())
       newFuncOp->setAttr(namedAttr.getName(), namedAttr.getValue());
   }

diff  --git a/mlir/lib/Dialect/Shape/IR/Shape.cpp b/mlir/lib/Dialect/Shape/IR/Shape.cpp
index 8c89ec8bba6cf..30c5f56f13886 100644
--- a/mlir/lib/Dialect/Shape/IR/Shape.cpp
+++ b/mlir/lib/Dialect/Shape/IR/Shape.cpp
@@ -1311,11 +1311,13 @@ ParseResult FuncOp::parse(OpAsmParser &parser, OperationState &result) {
          std::string &) { return builder.getFunctionType(argTypes, results); };
 
   return function_interface_impl::parseFunctionOp(
-      parser, result, /*allowVariadic=*/false, buildFuncType);
+      parser, result, /*allowVariadic=*/false,
+      getFunctionTypeAttrName(result.name), buildFuncType);
 }
 
 void FuncOp::print(OpAsmPrinter &p) {
-  function_interface_impl::printFunctionOp(p, *this, /*isVariadic=*/false);
+  function_interface_impl::printFunctionOp(p, *this, /*isVariadic=*/false,
+                                           getFunctionTypeAttrName());
 }
 
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/lib/IR/FunctionImplementation.cpp b/mlir/lib/IR/FunctionImplementation.cpp
index 9481e4ae8175b..af692befb0fde 100644
--- a/mlir/lib/IR/FunctionImplementation.cpp
+++ b/mlir/lib/IR/FunctionImplementation.cpp
@@ -163,7 +163,7 @@ void mlir::function_interface_impl::addArgAndResultAttrs(
 
 ParseResult mlir::function_interface_impl::parseFunctionOp(
     OpAsmParser &parser, OperationState &result, bool allowVariadic,
-    FuncTypeBuilder funcTypeBuilder) {
+    StringAttr typeAttrName, FuncTypeBuilder funcTypeBuilder) {
   SmallVector<OpAsmParser::Argument> entryArgs;
   SmallVector<DictionaryAttr> resultAttrs;
   SmallVector<Type> resultTypes;
@@ -197,7 +197,7 @@ ParseResult mlir::function_interface_impl::parseFunctionOp(
            << "failed to construct function type"
            << (errorMessage.empty() ? "" : ": ") << errorMessage;
   }
-  result.addAttribute(getTypeAttrName(), TypeAttr::get(type));
+  result.addAttribute(typeAttrName, TypeAttr::get(type));
 
   // If function attributes are present, parse them.
   NamedAttrList parsedAttributes;
@@ -209,7 +209,7 @@ ParseResult mlir::function_interface_impl::parseFunctionOp(
   // dictionary.
   for (StringRef disallowed :
        {SymbolTable::getVisibilityAttrName(), SymbolTable::getSymbolAttrName(),
-        getTypeAttrName()}) {
+        typeAttrName.getValue()}) {
     if (parsedAttributes.get(disallowed))
       return parser.emitError(attributeDictLocation, "'")
              << disallowed
@@ -301,12 +301,11 @@ void mlir::function_interface_impl::printFunctionSignature(
 }
 
 void mlir::function_interface_impl::printFunctionAttributes(
-    OpAsmPrinter &p, Operation *op, unsigned numInputs, unsigned numResults,
-    ArrayRef<StringRef> elided) {
+    OpAsmPrinter &p, Operation *op, ArrayRef<StringRef> elided) {
   // Print out function attributes, if present.
-  SmallVector<StringRef, 2> ignoredAttrs = {
-      ::mlir::SymbolTable::getSymbolAttrName(), getTypeAttrName(),
-      getArgDictAttrName(), getResultDictAttrName()};
+  SmallVector<StringRef, 2> ignoredAttrs = {SymbolTable::getSymbolAttrName(),
+                                            getArgDictAttrName(),
+                                            getResultDictAttrName()};
   ignoredAttrs.append(elided.begin(), elided.end());
 
   p.printOptionalAttrDictWithKeyword(op->getAttrs(), ignoredAttrs);
@@ -314,7 +313,8 @@ void mlir::function_interface_impl::printFunctionAttributes(
 
 void mlir::function_interface_impl::printFunctionOp(OpAsmPrinter &p,
                                                     FunctionOpInterface op,
-                                                    bool isVariadic) {
+                                                    bool isVariadic,
+                                                    StringRef typeAttrName) {
   // Print the operation and the function name.
   auto funcName =
       op->getAttrOfType<StringAttr>(SymbolTable::getSymbolAttrName())
@@ -329,8 +329,7 @@ void mlir::function_interface_impl::printFunctionOp(OpAsmPrinter &p,
   ArrayRef<Type> argTypes = op.getArgumentTypes();
   ArrayRef<Type> resultTypes = op.getResultTypes();
   printFunctionSignature(p, op, argTypes, isVariadic, resultTypes);
-  printFunctionAttributes(p, op, argTypes.size(), resultTypes.size(),
-                          {visibilityAttrName});
+  printFunctionAttributes(p, op, {visibilityAttrName, typeAttrName});
   // Print the body if this is not an external function.
   Region &body = op->getRegion(0);
   if (!body.empty()) {

diff  --git a/mlir/lib/IR/FunctionInterfaces.cpp b/mlir/lib/IR/FunctionInterfaces.cpp
index 3331aefc76d13..9ba830366056c 100644
--- a/mlir/lib/IR/FunctionInterfaces.cpp
+++ b/mlir/lib/IR/FunctionInterfaces.cpp
@@ -112,7 +112,7 @@ void mlir::function_interface_impl::setAllResultAttrDicts(
 }
 
 void mlir::function_interface_impl::insertFunctionArguments(
-    Operation *op, ArrayRef<unsigned> argIndices, TypeRange argTypes,
+    FunctionOpInterface op, ArrayRef<unsigned> argIndices, TypeRange argTypes,
     ArrayRef<DictionaryAttr> argAttrs, ArrayRef<Location> argLocs,
     unsigned originalNumArgs, Type newType) {
   assert(argIndices.size() == argTypes.size());
@@ -152,15 +152,15 @@ void mlir::function_interface_impl::insertFunctionArguments(
   }
 
   // Update the function type and any entry block arguments.
-  op->setAttr(getTypeAttrName(), TypeAttr::get(newType));
+  op.setFunctionTypeAttr(TypeAttr::get(newType));
   for (unsigned i = 0, e = argIndices.size(); i < e; ++i)
     entry.insertArgument(argIndices[i] + i, argTypes[i], argLocs[i]);
 }
 
 void mlir::function_interface_impl::insertFunctionResults(
-    Operation *op, ArrayRef<unsigned> resultIndices, TypeRange resultTypes,
-    ArrayRef<DictionaryAttr> resultAttrs, unsigned originalNumResults,
-    Type newType) {
+    FunctionOpInterface op, ArrayRef<unsigned> resultIndices,
+    TypeRange resultTypes, ArrayRef<DictionaryAttr> resultAttrs,
+    unsigned originalNumResults, Type newType) {
   assert(resultIndices.size() == resultTypes.size());
   assert(resultIndices.size() == resultAttrs.size() || resultAttrs.empty());
   if (resultIndices.empty())
@@ -196,11 +196,11 @@ void mlir::function_interface_impl::insertFunctionResults(
   }
 
   // Update the function type.
-  op->setAttr(getTypeAttrName(), TypeAttr::get(newType));
+  op.setFunctionTypeAttr(TypeAttr::get(newType));
 }
 
 void mlir::function_interface_impl::eraseFunctionArguments(
-    Operation *op, const BitVector &argIndices, Type newType) {
+    FunctionOpInterface op, const BitVector &argIndices, Type newType) {
   // There are 3 things that need to be updated:
   // - Function type.
   // - Arg attrs.
@@ -218,12 +218,12 @@ void mlir::function_interface_impl::eraseFunctionArguments(
   }
 
   // Update the function type and any entry block arguments.
-  op->setAttr(getTypeAttrName(), TypeAttr::get(newType));
+  op.setFunctionTypeAttr(TypeAttr::get(newType));
   entry.eraseArguments(argIndices);
 }
 
 void mlir::function_interface_impl::eraseFunctionResults(
-    Operation *op, const BitVector &resultIndices, Type newType) {
+    FunctionOpInterface op, const BitVector &resultIndices, Type newType) {
   // There are 2 things that need to be updated:
   // - Function type.
   // - Result attrs.
@@ -239,7 +239,7 @@ void mlir::function_interface_impl::eraseFunctionResults(
   }
 
   // Update the function type.
-  op->setAttr(getTypeAttrName(), TypeAttr::get(newType));
+  op.setFunctionTypeAttr(TypeAttr::get(newType));
 }
 
 TypeRange mlir::function_interface_impl::insertTypesInto(
@@ -276,14 +276,13 @@ TypeRange mlir::function_interface_impl::filterTypesOut(
 // Function type signature.
 //===----------------------------------------------------------------------===//
 
-void mlir::function_interface_impl::setFunctionType(Operation *op,
+void mlir::function_interface_impl::setFunctionType(FunctionOpInterface op,
                                                     Type newType) {
-  FunctionOpInterface funcOp = cast<FunctionOpInterface>(op);
-  unsigned oldNumArgs = funcOp.getNumArguments();
-  unsigned oldNumResults = funcOp.getNumResults();
-  op->setAttr(getTypeAttrName(), TypeAttr::get(newType));
-  unsigned newNumArgs = funcOp.getNumArguments();
-  unsigned newNumResults = funcOp.getNumResults();
+  unsigned oldNumArgs = op.getNumArguments();
+  unsigned oldNumResults = op.getNumResults();
+  op.setFunctionTypeAttr(TypeAttr::get(newType));
+  unsigned newNumArgs = op.getNumArguments();
+  unsigned newNumResults = op.getNumResults();
 
   // Functor used to update the argument and result attributes of the function.
   auto updateAttrFn = [&](StringRef attrName, unsigned oldCount,


        


More information about the Mlir-commits mailing list