[Mlir-commits] [mlir] cf98e82 - Revert "[mlir] FunctionOpInterface: make get/setFunctionType interface methods"

David Spickett llvmlistbot at llvm.org
Fri Dec 9 07:37:07 PST 2022


Author: David Spickett
Date: 2022-12-09T15:36:48Z
New Revision: cf98e8273c5f1c84afcfdd9bcea486ec22f26768

URL: https://github.com/llvm/llvm-project/commit/cf98e8273c5f1c84afcfdd9bcea486ec22f26768
DIFF: https://github.com/llvm/llvm-project/commit/cf98e8273c5f1c84afcfdd9bcea486ec22f26768.diff

LOG: Revert "[mlir] FunctionOpInterface: make get/setFunctionType interface methods"
and "[mlir] Fix examples build"

This reverts commits fbc253fe81da4e1d6bfa2519e01e03f21d8c40a8 and
96cf183bccd7d1c3083f169a89a6af1f263b3aae.

Both were missed in the first revert, f3379feabe38fd3711b13ffcf6de4aab03b7ccdc.

Added: 
    

Modified: 
    mlir/examples/toy/Ch2/mlir/Dialect.cpp
    mlir/examples/toy/Ch3/mlir/Dialect.cpp
    mlir/examples/toy/Ch4/mlir/Dialect.cpp
    mlir/examples/toy/Ch5/mlir/Dialect.cpp
    mlir/examples/toy/Ch6/mlir/Dialect.cpp
    mlir/examples/toy/Ch7/mlir/Dialect.cpp
    mlir/include/mlir/IR/FunctionImplementation.h
    mlir/include/mlir/IR/FunctionInterfaces.h
    mlir/include/mlir/IR/FunctionInterfaces.td
    mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp
    mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp
    mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp
    mlir/lib/Dialect/Async/IR/Async.cpp
    mlir/lib/Dialect/Func/IR/FuncOps.cpp
    mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
    mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
    mlir/lib/Dialect/MLProgram/IR/MLProgramOps.cpp
    mlir/lib/Dialect/PDLInterp/IR/PDLInterp.cpp
    mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
    mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
    mlir/lib/Dialect/Shape/IR/Shape.cpp
    mlir/lib/IR/FunctionImplementation.cpp
    mlir/lib/IR/FunctionInterfaces.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/examples/toy/Ch2/mlir/Dialect.cpp b/mlir/examples/toy/Ch2/mlir/Dialect.cpp
index 8c36fafc2f001..dbc1efb3d06be 100644
--- a/mlir/examples/toy/Ch2/mlir/Dialect.cpp
+++ b/mlir/examples/toy/Ch2/mlir/Dialect.cpp
@@ -211,9 +211,7 @@ mlir::ParseResult FuncOp::parse(mlir::OpAsmParser &parser,
          std::string &) { return builder.getFunctionType(argTypes, results); };
 
   return mlir::function_interface_impl::parseFunctionOp(
-      parser, result, /*allowVariadic=*/false,
-      getFunctionTypeAttrName(result.name), buildFuncType,
-      getArgAttrsAttrName(result.name), getResAttrsAttrName(result.name));
+      parser, result, /*allowVariadic=*/false, buildFuncType);
 }
 
 void FuncOp::print(mlir::OpAsmPrinter &p) {

diff  --git a/mlir/examples/toy/Ch3/mlir/Dialect.cpp b/mlir/examples/toy/Ch3/mlir/Dialect.cpp
index 6bf140487420f..50e2dfc7f4a3e 100644
--- a/mlir/examples/toy/Ch3/mlir/Dialect.cpp
+++ b/mlir/examples/toy/Ch3/mlir/Dialect.cpp
@@ -198,9 +198,7 @@ mlir::ParseResult FuncOp::parse(mlir::OpAsmParser &parser,
          std::string &) { return builder.getFunctionType(argTypes, results); };
 
   return mlir::function_interface_impl::parseFunctionOp(
-      parser, result, /*allowVariadic=*/false,
-      getFunctionTypeAttrName(result.name), buildFuncType,
-      getArgAttrsAttrName(result.name), getResAttrsAttrName(result.name));
+      parser, result, /*allowVariadic=*/false, buildFuncType);
 }
 
 void FuncOp::print(mlir::OpAsmPrinter &p) {

diff  --git a/mlir/examples/toy/Ch4/mlir/Dialect.cpp b/mlir/examples/toy/Ch4/mlir/Dialect.cpp
index 8343a1cb5fbc3..0a6195b12d5d4 100644
--- a/mlir/examples/toy/Ch4/mlir/Dialect.cpp
+++ b/mlir/examples/toy/Ch4/mlir/Dialect.cpp
@@ -287,9 +287,7 @@ mlir::ParseResult FuncOp::parse(mlir::OpAsmParser &parser,
          std::string &) { return builder.getFunctionType(argTypes, results); };
 
   return mlir::function_interface_impl::parseFunctionOp(
-      parser, result, /*allowVariadic=*/false,
-      getFunctionTypeAttrName(result.name), buildFuncType,
-      getArgAttrsAttrName(result.name), getResAttrsAttrName(result.name));
+      parser, result, /*allowVariadic=*/false, buildFuncType);
 }
 
 void FuncOp::print(mlir::OpAsmPrinter &p) {

diff  --git a/mlir/examples/toy/Ch5/mlir/Dialect.cpp b/mlir/examples/toy/Ch5/mlir/Dialect.cpp
index dde12f51c351e..f236a1ffe0e5a 100644
--- a/mlir/examples/toy/Ch5/mlir/Dialect.cpp
+++ b/mlir/examples/toy/Ch5/mlir/Dialect.cpp
@@ -287,9 +287,7 @@ mlir::ParseResult FuncOp::parse(mlir::OpAsmParser &parser,
          std::string &) { return builder.getFunctionType(argTypes, results); };
 
   return mlir::function_interface_impl::parseFunctionOp(
-      parser, result, /*allowVariadic=*/false,
-      getFunctionTypeAttrName(result.name), buildFuncType,
-      getArgAttrsAttrName(result.name), getResAttrsAttrName(result.name));
+      parser, result, /*allowVariadic=*/false, buildFuncType);
 }
 
 void FuncOp::print(mlir::OpAsmPrinter &p) {

diff  --git a/mlir/examples/toy/Ch6/mlir/Dialect.cpp b/mlir/examples/toy/Ch6/mlir/Dialect.cpp
index dde12f51c351e..f236a1ffe0e5a 100644
--- a/mlir/examples/toy/Ch6/mlir/Dialect.cpp
+++ b/mlir/examples/toy/Ch6/mlir/Dialect.cpp
@@ -287,9 +287,7 @@ mlir::ParseResult FuncOp::parse(mlir::OpAsmParser &parser,
          std::string &) { return builder.getFunctionType(argTypes, results); };
 
   return mlir::function_interface_impl::parseFunctionOp(
-      parser, result, /*allowVariadic=*/false,
-      getFunctionTypeAttrName(result.name), buildFuncType,
-      getArgAttrsAttrName(result.name), getResAttrsAttrName(result.name));
+      parser, result, /*allowVariadic=*/false, buildFuncType);
 }
 
 void FuncOp::print(mlir::OpAsmPrinter &p) {

diff  --git a/mlir/examples/toy/Ch7/mlir/Dialect.cpp b/mlir/examples/toy/Ch7/mlir/Dialect.cpp
index 3413e57d37b30..cc66a5d44b5f4 100644
--- a/mlir/examples/toy/Ch7/mlir/Dialect.cpp
+++ b/mlir/examples/toy/Ch7/mlir/Dialect.cpp
@@ -314,9 +314,7 @@ mlir::ParseResult FuncOp::parse(mlir::OpAsmParser &parser,
          std::string &) { return builder.getFunctionType(argTypes, results); };
 
   return mlir::function_interface_impl::parseFunctionOp(
-      parser, result, /*allowVariadic=*/false,
-      getFunctionTypeAttrName(result.name), buildFuncType,
-      getArgAttrsAttrName(result.name), getResAttrsAttrName(result.name));
+      parser, result, /*allowVariadic=*/false, buildFuncType);
 }
 
 void FuncOp::print(mlir::OpAsmPrinter &p) {

diff  --git a/mlir/include/mlir/IR/FunctionImplementation.h b/mlir/include/mlir/IR/FunctionImplementation.h
index f4c0cc03050fe..5265f781d1a77 100644
--- a/mlir/include/mlir/IR/FunctionImplementation.h
+++ b/mlir/include/mlir/IR/FunctionImplementation.h
@@ -69,19 +69,17 @@ Type getFunctionType(Builder &builder, ArrayRef<OpAsmParser::Argument> argAttrs,
 
 /// Parser implementation for function-like operations.  Uses
 /// `funcTypeBuilder` to construct the custom function type given lists of
-/// input and output types. The parser sets the `typeAttrName` attribute to the
-/// resulting function type. If `allowVariadic` is set, the parser will accept
+/// input and output types.  If `allowVariadic` is set, the parser will accept
 /// trailing ellipsis in the function signature and indicate to the builder
 /// whether the function is variadic.  If the builder returns a null type,
 /// `result` will not contain the `type` attribute.  The caller can then add a
 /// type, report the error or delegate the reporting to the op's verifier.
 ParseResult parseFunctionOp(OpAsmParser &parser, OperationState &result,
-                            bool allowVariadic, StringAttr typeAttrName,
+                            bool allowVariadic,
                             FuncTypeBuilder funcTypeBuilder);
 
 /// Printer implementation for function-like operations.
-void printFunctionOp(OpAsmPrinter &p, FunctionOpInterface op, bool isVariadic,
-                     StringRef typeAttrName);
+void printFunctionOp(OpAsmPrinter &p, FunctionOpInterface op, bool isVariadic);
 
 /// Prints the signature of the function-like operation `op`. Assumes `op` has
 /// is a FunctionOpInterface and has passed verification.
@@ -94,7 +92,8 @@ void printFunctionSignature(OpAsmPrinter &p, Operation *op,
 /// function-like operation internally are not printed. Nothing is printed
 /// if all attributes are elided. Assumes `op` is a FunctionOpInterface and
 /// has passed verification.
-void printFunctionAttributes(OpAsmPrinter &p, Operation *op,
+void printFunctionAttributes(OpAsmPrinter &p, Operation *op, unsigned numInputs,
+                             unsigned numResults,
                              ArrayRef<StringRef> elided = {});
 
 } // namespace function_interface_impl

diff  --git a/mlir/include/mlir/IR/FunctionInterfaces.h b/mlir/include/mlir/IR/FunctionInterfaces.h
index bc2ec4751c582..23fd884d97f14 100644
--- a/mlir/include/mlir/IR/FunctionInterfaces.h
+++ b/mlir/include/mlir/IR/FunctionInterfaces.h
@@ -22,10 +22,12 @@
 #include "llvm/ADT/SmallString.h"
 
 namespace mlir {
-class FunctionOpInterface;
 
 namespace function_interface_impl {
 
+/// Return the name of the attribute used for function types.
+inline StringRef getTypeAttrName() { return "function_type"; }
+
 /// Return the name of the attribute used for function argument attributes.
 inline StringRef getArgDictAttrName() { return "arg_attrs"; }
 
@@ -70,29 +72,28 @@ inline ArrayRef<NamedAttribute> getResultAttrs(Operation *op, unsigned index) {
 }
 
 /// Insert the specified arguments and update the function type attribute.
-void insertFunctionArguments(FunctionOpInterface op,
-                             ArrayRef<unsigned> argIndices, TypeRange argTypes,
+void insertFunctionArguments(Operation *op, ArrayRef<unsigned> argIndices,
+                             TypeRange argTypes,
                              ArrayRef<DictionaryAttr> argAttrs,
                              ArrayRef<Location> argLocs,
                              unsigned originalNumArgs, Type newType);
 
 /// Insert the specified results and update the function type attribute.
-void insertFunctionResults(FunctionOpInterface op,
-                           ArrayRef<unsigned> resultIndices,
+void insertFunctionResults(Operation *op, ArrayRef<unsigned> resultIndices,
                            TypeRange resultTypes,
                            ArrayRef<DictionaryAttr> resultAttrs,
                            unsigned originalNumResults, Type newType);
 
 /// Erase the specified arguments and update the function type attribute.
-void eraseFunctionArguments(FunctionOpInterface op, const BitVector &argIndices,
+void eraseFunctionArguments(Operation *op, const BitVector &argIndices,
                             Type newType);
 
 /// Erase the specified results and update the function type attribute.
-void eraseFunctionResults(FunctionOpInterface op,
-                          const BitVector &resultIndices, Type newType);
+void eraseFunctionResults(Operation *op, const BitVector &resultIndices,
+                          Type newType);
 
 /// Set a FunctionOpInterface operation's type signature.
-void setFunctionType(FunctionOpInterface op, Type newType);
+void setFunctionType(Operation *op, Type newType);
 
 /// Insert a set of `newTypes` into `oldTypes` at the given `indices`. If any
 /// types are inserted, `storage` is used to hold the new type list. The new
@@ -206,6 +207,10 @@ Attribute removeResultAttr(ConcreteType op, unsigned index, StringAttr name) {
 /// method on FunctionOpInterface::Trait.
 template <typename ConcreteOp>
 LogicalResult verifyTrait(ConcreteOp op) {
+  if (!op.getFunctionTypeAttr())
+    return op.emitOpError("requires a type attribute '")
+           << function_interface_impl::getTypeAttrName() << '\'';
+
   if (failed(op.verifyType()))
     return failure();
 

diff  --git a/mlir/include/mlir/IR/FunctionInterfaces.td b/mlir/include/mlir/IR/FunctionInterfaces.td
index e86057aa7ec2f..c56129ea895d9 100644
--- a/mlir/include/mlir/IR/FunctionInterfaces.td
+++ b/mlir/include/mlir/IR/FunctionInterfaces.td
@@ -49,16 +49,6 @@ def FunctionOpInterface : OpInterface<"FunctionOpInterface"> {
           for each of the function results.
   }];
   let methods = [
-    InterfaceMethod<[{
-      Returns the type of the function.
-    }],
-    "::mlir::Type", "getFunctionType">,
-    InterfaceMethod<[{
-      Set the type of the function. This method should perform an unsafe
-      modification to the function type; it should not update argument or
-      result attributes.
-    }],
-    "void", "setFunctionTypeAttr", (ins "::mlir::TypeAttr":$type)>,
     InterfaceMethod<[{
       Returns the function argument types based exclusively on
       the type (to allow for this method may be called on function
@@ -149,7 +139,7 @@ def FunctionOpInterface : OpInterface<"FunctionOpInterface"> {
         ArrayRef<NamedAttribute> attrs, TypeRange inputTypes) {
       state.addAttribute(SymbolTable::getSymbolAttrName(),
                         builder.getStringAttr(name));
-      state.addAttribute(ConcreteOp::getFunctionTypeAttrName(state.name),
+      state.addAttribute(function_interface_impl::getTypeAttrName(),
                         TypeAttr::get(type));
       state.attributes.append(attrs.begin(), attrs.end());
 
@@ -254,6 +244,11 @@ def FunctionOpInterface : OpInterface<"FunctionOpInterface"> {
     // the derived operation, which should already have these defined
     // (via ODS).
 
+    /// Returns the name of the attribute used for function types.
+    static StringRef getTypeAttrName() {
+      return function_interface_impl::getTypeAttrName();
+    }
+
     /// Returns the name of the attribute used for function argument attributes.
     static StringRef getArgDictAttrName() {
       return function_interface_impl::getArgDictAttrName();
@@ -264,6 +259,15 @@ def FunctionOpInterface : OpInterface<"FunctionOpInterface"> {
       return function_interface_impl::getResultDictAttrName();
     }
 
+    /// Return the attribute containing the type of this function.
+    TypeAttr getFunctionTypeAttr() {
+      return this->getOperation()->template getAttrOfType<TypeAttr>(
+          getTypeAttrName());
+    }
+
+    /// Return the type of this function.
+    Type getFunctionType() { return getFunctionTypeAttr().getValue(); }
+
     //===------------------------------------------------------------------===//
     // Argument and Result Handling
     //===------------------------------------------------------------------===//

diff  --git a/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp b/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp
index 9f522aaa49f92..d0e82de839c0c 100644
--- a/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp
+++ b/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp
@@ -59,11 +59,12 @@ using namespace mlir;
 /// Only retain those attributes that are not constructed by
 /// `LLVMFuncOp::build`. If `filterArgAttrs` is set, also filter out argument
 /// attributes.
-static void filterFuncAttributes(func::FuncOp func, bool filterArgAndResAttrs,
+static void filterFuncAttributes(ArrayRef<NamedAttribute> attrs,
+                                 bool filterArgAndResAttrs,
                                  SmallVectorImpl<NamedAttribute> &result) {
-  for (const NamedAttribute &attr : func->getAttrs()) {
+  for (const auto &attr : attrs) {
     if (attr.getName() == SymbolTable::getSymbolAttrName() ||
-        attr.getName() == func.getFunctionTypeAttrName() ||
+        attr.getName() == FunctionOpInterface::getTypeAttrName() ||
         attr.getName() == "func.varargs" ||
         (filterArgAndResAttrs &&
          (attr.getName() == FunctionOpInterface::getArgDictAttrName() ||
@@ -137,7 +138,8 @@ static void wrapForExternalCallers(OpBuilder &rewriter, Location loc,
                                    LLVM::LLVMFuncOp newFuncOp) {
   auto type = funcOp.getFunctionType();
   SmallVector<NamedAttribute, 4> attributes;
-  filterFuncAttributes(funcOp, /*filterArgAndResAttrs=*/false, attributes);
+  filterFuncAttributes(funcOp->getAttrs(), /*filterArgAndResAttrs=*/false,
+                       attributes);
   auto [wrapperFuncType, resultIsNowArg] =
       typeConverter.convertFunctionTypeCWrapper(type);
   if (resultIsNowArg)
@@ -202,7 +204,8 @@ static void wrapExternalFunction(OpBuilder &builder, Location loc,
   assert(wrapperType && "unexpected type conversion failure");
 
   SmallVector<NamedAttribute, 4> attributes;
-  filterFuncAttributes(funcOp, /*filterArgAndResAttrs=*/false, attributes);
+  filterFuncAttributes(funcOp->getAttrs(), /*filterArgAndResAttrs=*/false,
+                       attributes);
 
   if (resultIsNowArg)
     prependResAttrsToArgAttrs(builder, attributes, funcOp.getNumArguments());
@@ -301,7 +304,8 @@ struct FuncOpConversionBase : public ConvertOpToLLVMPattern<func::FuncOp> {
     // Propagate argument/result attributes to all converted arguments/result
     // obtained after converting a given original argument/result.
     SmallVector<NamedAttribute, 4> attributes;
-    filterFuncAttributes(funcOp, /*filterArgAndResAttrs=*/true, attributes);
+    filterFuncAttributes(funcOp->getAttrs(), /*filterArgAndResAttrs=*/true,
+                         attributes);
     if (ArrayAttr resAttrDicts = funcOp.getAllResultAttrs()) {
       assert(!resAttrDicts.empty() && "expected array to be non-empty");
       auto newResAttrDicts =

diff  --git a/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp b/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp
index 48effe24f674e..85001d54d093d 100644
--- a/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp
+++ b/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp
@@ -60,7 +60,7 @@ GPUFuncOpLowering::matchAndRewrite(gpu::GPUFuncOp gpuFuncOp, OpAdaptor adaptor,
   SmallVector<NamedAttribute, 4> attributes;
   for (const auto &attr : gpuFuncOp->getAttrs()) {
     if (attr.getName() == SymbolTable::getSymbolAttrName() ||
-        attr.getName() == gpuFuncOp.getFunctionTypeAttrName() ||
+        attr.getName() == FunctionOpInterface::getTypeAttrName() ||
         attr.getName() == gpu::GPUFuncOp::getNumWorkgroupAttributionsAttrName())
       continue;
     attributes.push_back(attr);

diff  --git a/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp b/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp
index 2a8389598f36a..119b1d3dea91e 100644
--- a/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp
+++ b/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp
@@ -226,7 +226,7 @@ lowerAsEntryFunction(gpu::GPUFuncOp funcOp, TypeConverter &typeConverter,
       rewriter.getFunctionType(signatureConverter.getConvertedTypes(),
                                std::nullopt));
   for (const auto &namedAttr : funcOp->getAttrs()) {
-    if (namedAttr.getName() == funcOp.getFunctionTypeAttrName() ||
+    if (namedAttr.getName() == FunctionOpInterface::getTypeAttrName() ||
         namedAttr.getName() == SymbolTable::getSymbolAttrName())
       continue;
     newFuncOp->setAttr(namedAttr.getName(), namedAttr.getValue());

diff  --git a/mlir/lib/Dialect/Async/IR/Async.cpp b/mlir/lib/Dialect/Async/IR/Async.cpp
index 064bf525db238..e0772b4dd90bc 100644
--- a/mlir/lib/Dialect/Async/IR/Async.cpp
+++ b/mlir/lib/Dialect/Async/IR/Async.cpp
@@ -332,7 +332,8 @@ void FuncOp::build(OpBuilder &builder, OperationState &state, StringRef name,
                    ArrayRef<DictionaryAttr> argAttrs) {
   state.addAttribute(SymbolTable::getSymbolAttrName(),
                      builder.getStringAttr(name));
-  state.addAttribute(getFunctionTypeAttrName(state.name), TypeAttr::get(type));
+  state.addAttribute(FunctionOpInterface::getTypeAttrName(),
+                     TypeAttr::get(type));
 
   state.attributes.append(attrs.begin(), attrs.end());
   state.addRegion();
@@ -351,13 +352,11 @@ ParseResult FuncOp::parse(OpAsmParser &parser, OperationState &result) {
          std::string &) { return builder.getFunctionType(argTypes, results); };
 
   return function_interface_impl::parseFunctionOp(
-      parser, result, /*allowVariadic=*/false,
-      getFunctionTypeAttrName(result.name), buildFuncType);
+      parser, result, /*allowVariadic=*/false, buildFuncType);
 }
 
 void FuncOp::print(OpAsmPrinter &p) {
-  function_interface_impl::printFunctionOp(p, *this, /*isVariadic=*/false,
-                                           getFunctionTypeAttrName());
+  function_interface_impl::printFunctionOp(p, *this, /*isVariadic=*/false);
 }
 
 /// Check that the result type of async.func is not void and must be

diff  --git a/mlir/lib/Dialect/Func/IR/FuncOps.cpp b/mlir/lib/Dialect/Func/IR/FuncOps.cpp
index fc9bd115e2223..961cf2eb36e35 100644
--- a/mlir/lib/Dialect/Func/IR/FuncOps.cpp
+++ b/mlir/lib/Dialect/Func/IR/FuncOps.cpp
@@ -244,7 +244,8 @@ void FuncOp::build(OpBuilder &builder, OperationState &state, StringRef name,
                    ArrayRef<DictionaryAttr> argAttrs) {
   state.addAttribute(SymbolTable::getSymbolAttrName(),
                      builder.getStringAttr(name));
-  state.addAttribute(getFunctionTypeAttrName(state.name), TypeAttr::get(type));
+  state.addAttribute(FunctionOpInterface::getTypeAttrName(),
+                     TypeAttr::get(type));
   state.attributes.append(attrs.begin(), attrs.end());
   state.addRegion();
 
@@ -262,13 +263,11 @@ ParseResult FuncOp::parse(OpAsmParser &parser, OperationState &result) {
          std::string &) { return builder.getFunctionType(argTypes, results); };
 
   return function_interface_impl::parseFunctionOp(
-      parser, result, /*allowVariadic=*/false,
-      getFunctionTypeAttrName(result.name), buildFuncType);
+      parser, result, /*allowVariadic=*/false, buildFuncType);
 }
 
 void FuncOp::print(OpAsmPrinter &p) {
-  function_interface_impl::printFunctionOp(p, *this, /*isVariadic=*/false,
-                                           getFunctionTypeAttrName());
+  function_interface_impl::printFunctionOp(p, *this, /*isVariadic=*/false);
 }
 
 /// Clone the internal blocks from this function into dest and all attributes

diff  --git a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
index 80db6461ecc55..7f73a651d0e9b 100644
--- a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
+++ b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
@@ -859,8 +859,7 @@ void GPUFuncOp::build(OpBuilder &builder, OperationState &result,
                       ArrayRef<NamedAttribute> attrs) {
   result.addAttribute(SymbolTable::getSymbolAttrName(),
                       builder.getStringAttr(name));
-  result.addAttribute(getFunctionTypeAttrName(result.name),
-                      TypeAttr::get(type));
+  result.addAttribute(getTypeAttrName(), TypeAttr::get(type));
   result.addAttribute(getNumWorkgroupAttributionsAttrName(),
                       builder.getI64IntegerAttr(workgroupAttributions.size()));
   result.addAttributes(attrs);
@@ -931,8 +930,7 @@ ParseResult GPUFuncOp::parse(OpAsmParser &parser, OperationState &result) {
   for (auto &arg : entryArgs)
     argTypes.push_back(arg.type);
   auto type = builder.getFunctionType(argTypes, resultTypes);
-  result.addAttribute(getFunctionTypeAttrName(result.name),
-                      TypeAttr::get(type));
+  result.addAttribute(GPUFuncOp::getTypeAttrName(), TypeAttr::get(type));
 
   function_interface_impl::addArgAndResultAttrs(builder, result, entryArgs,
                                                 resultAttrs);
@@ -994,14 +992,19 @@ void GPUFuncOp::print(OpAsmPrinter &p) {
     p << ' ' << getKernelKeyword();
 
   function_interface_impl::printFunctionAttributes(
-      p, *this,
+      p, *this, type.getNumInputs(), type.getNumResults(),
       {getNumWorkgroupAttributionsAttrName(),
-       GPUDialect::getKernelFuncAttrName(), getFunctionTypeAttrName()});
+       GPUDialect::getKernelFuncAttrName()});
   p << ' ';
   p.printRegion(getBody(), /*printEntryBlockArgs=*/false);
 }
 
 LogicalResult GPUFuncOp::verifyType() {
+  Type type = getFunctionTypeAttr().getValue();
+  if (!type.isa<FunctionType>())
+    return emitOpError("requires '" + getTypeAttrName() +
+                       "' attribute of function type");
+
   if (isKernel() && getFunctionType().getNumResults() != 0)
     return emitOpError() << "expected void return type for kernel function";
 

diff  --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
index 028f25ec3c1d2..1087bcf5d32f4 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
@@ -2090,7 +2090,7 @@ ParseResult LLVMFuncOp::parse(OpAsmParser &parser, OperationState &result) {
                             function_interface_impl::VariadicFlag(isVariadic));
   if (!type)
     return failure();
-  result.addAttribute(getFunctionTypeAttrName(result.name),
+  result.addAttribute(FunctionOpInterface::getTypeAttrName(),
                       TypeAttr::get(type));
 
   if (failed(parser.parseOptionalAttrDictWithKeyword(result.attributes)))
@@ -2130,8 +2130,8 @@ void LLVMFuncOp::print(OpAsmPrinter &p) {
   function_interface_impl::printFunctionSignature(p, *this, argTypes,
                                                   isVarArg(), resTypes);
   function_interface_impl::printFunctionAttributes(
-      p, *this,
-      {getFunctionTypeAttrName(), getLinkageAttrName(), getCConvAttrName()});
+      p, *this, argTypes.size(), resTypes.size(),
+      {getLinkageAttrName(), getCConvAttrName()});
 
   // Print the body if this is not an external function.
   Region &body = getBody();

diff  --git a/mlir/lib/Dialect/MLProgram/IR/MLProgramOps.cpp b/mlir/lib/Dialect/MLProgram/IR/MLProgramOps.cpp
index 27c613088df56..2f1e4b93a6ac3 100644
--- a/mlir/lib/Dialect/MLProgram/IR/MLProgramOps.cpp
+++ b/mlir/lib/Dialect/MLProgram/IR/MLProgramOps.cpp
@@ -152,13 +152,11 @@ ParseResult FuncOp::parse(OpAsmParser &parser, OperationState &result) {
          std::string &) { return builder.getFunctionType(argTypes, results); };
 
   return function_interface_impl::parseFunctionOp(
-      parser, result, /*allowVariadic=*/false,
-      getFunctionTypeAttrName(result.name), buildFuncType);
+      parser, result, /*allowVariadic=*/false, buildFuncType);
 }
 
 void FuncOp::print(OpAsmPrinter &p) {
-  function_interface_impl::printFunctionOp(p, *this, /*isVariadic=*/false,
-                                           getFunctionTypeAttrName());
+  function_interface_impl::printFunctionOp(p, *this, /*isVariadic=*/false);
 }
 
 //===----------------------------------------------------------------------===//
@@ -315,13 +313,11 @@ ParseResult SubgraphOp::parse(OpAsmParser &parser, OperationState &result) {
          std::string &) { return builder.getFunctionType(argTypes, results); };
 
   return function_interface_impl::parseFunctionOp(
-      parser, result, /*allowVariadic=*/false,
-      getFunctionTypeAttrName(result.name), buildFuncType);
+      parser, result, /*allowVariadic=*/false, buildFuncType);
 }
 
 void SubgraphOp::print(OpAsmPrinter &p) {
-  function_interface_impl::printFunctionOp(p, *this, /*isVariadic=*/false,
-                                           getFunctionTypeAttrName());
+  function_interface_impl::printFunctionOp(p, *this, /*isVariadic=*/false);
 }
 
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/lib/Dialect/PDLInterp/IR/PDLInterp.cpp b/mlir/lib/Dialect/PDLInterp/IR/PDLInterp.cpp
index 28fc4dbf97aea..e8a61ef4c6a4d 100644
--- a/mlir/lib/Dialect/PDLInterp/IR/PDLInterp.cpp
+++ b/mlir/lib/Dialect/PDLInterp/IR/PDLInterp.cpp
@@ -220,13 +220,11 @@ ParseResult FuncOp::parse(OpAsmParser &parser, OperationState &result) {
          std::string &) { return builder.getFunctionType(argTypes, results); };
 
   return function_interface_impl::parseFunctionOp(
-      parser, result, /*allowVariadic=*/false,
-      getFunctionTypeAttrName(result.name), buildFuncType);
+      parser, result, /*allowVariadic=*/false, buildFuncType);
 }
 
 void FuncOp::print(OpAsmPrinter &p) {
-  function_interface_impl::printFunctionOp(p, *this, /*isVariadic=*/false,
-                                           getFunctionTypeAttrName());
+  function_interface_impl::printFunctionOp(p, *this, /*isVariadic=*/false);
 }
 
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
index 3ce3913f2814b..52ad8ad5fe7c7 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
@@ -2382,7 +2382,7 @@ ParseResult spirv::FuncOp::parse(OpAsmParser &parser, OperationState &result) {
   for (auto &arg : entryArgs)
     argTypes.push_back(arg.type);
   auto fnType = builder.getFunctionType(argTypes, resultTypes);
-  result.addAttribute(getFunctionTypeAttrName(result.name),
+  result.addAttribute(FunctionOpInterface::getTypeAttrName(),
                       TypeAttr::get(fnType));
 
   // Parse the optional function control keyword.
@@ -2417,9 +2417,8 @@ void spirv::FuncOp::print(OpAsmPrinter &printer) {
   printer << " \"" << spirv::stringifyFunctionControl(getFunctionControl())
           << "\"";
   function_interface_impl::printFunctionAttributes(
-      printer, *this,
-      {spirv::attributeName<spirv::FunctionControl>(),
-       getFunctionTypeAttrName(), getFunctionControlAttrName()});
+      printer, *this, fnType.getNumInputs(), fnType.getNumResults(),
+      {spirv::attributeName<spirv::FunctionControl>()});
 
   // Print the body if this is not an external function.
   Region &body = this->getBody();
@@ -2431,6 +2430,10 @@ void spirv::FuncOp::print(OpAsmPrinter &printer) {
 }
 
 LogicalResult spirv::FuncOp::verifyType() {
+  auto type = getFunctionTypeAttr().getValue();
+  if (!type.isa<FunctionType>())
+    return emitOpError("requires '" + getTypeAttrName() +
+                       "' attribute of function type");
   if (getFunctionType().getNumResults() > 1)
     return emitOpError("cannot have more than one result");
   return success();
@@ -2470,7 +2473,7 @@ void spirv::FuncOp::build(OpBuilder &builder, OperationState &state,
                           ArrayRef<NamedAttribute> attrs) {
   state.addAttribute(SymbolTable::getSymbolAttrName(),
                      builder.getStringAttr(name));
-  state.addAttribute(getFunctionTypeAttrName(state.name), TypeAttr::get(type));
+  state.addAttribute(getTypeAttrName(), TypeAttr::get(type));
   state.addAttribute(spirv::attributeName<spirv::FunctionControl>(),
                      builder.getAttr<spirv::FunctionControlAttr>(control));
   state.attributes.append(attrs.begin(), attrs.end());

diff  --git a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
index 62e3a3de0f11b..2772c0150cda7 100644
--- a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
+++ b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
@@ -531,7 +531,7 @@ FuncOpConversion::matchAndRewrite(func::FuncOp funcOp, OpAdaptor adaptor,
 
   // Copy over all attributes other than the function name and type.
   for (const auto &namedAttr : funcOp->getAttrs()) {
-    if (namedAttr.getName() != funcOp.getFunctionTypeAttrName() &&
+    if (namedAttr.getName() != FunctionOpInterface::getTypeAttrName() &&
         namedAttr.getName() != SymbolTable::getSymbolAttrName())
       newFuncOp->setAttr(namedAttr.getName(), namedAttr.getValue());
   }

diff  --git a/mlir/lib/Dialect/Shape/IR/Shape.cpp b/mlir/lib/Dialect/Shape/IR/Shape.cpp
index 30c5f56f13886..8c89ec8bba6cf 100644
--- a/mlir/lib/Dialect/Shape/IR/Shape.cpp
+++ b/mlir/lib/Dialect/Shape/IR/Shape.cpp
@@ -1311,13 +1311,11 @@ ParseResult FuncOp::parse(OpAsmParser &parser, OperationState &result) {
          std::string &) { return builder.getFunctionType(argTypes, results); };
 
   return function_interface_impl::parseFunctionOp(
-      parser, result, /*allowVariadic=*/false,
-      getFunctionTypeAttrName(result.name), buildFuncType);
+      parser, result, /*allowVariadic=*/false, buildFuncType);
 }
 
 void FuncOp::print(OpAsmPrinter &p) {
-  function_interface_impl::printFunctionOp(p, *this, /*isVariadic=*/false,
-                                           getFunctionTypeAttrName());
+  function_interface_impl::printFunctionOp(p, *this, /*isVariadic=*/false);
 }
 
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/lib/IR/FunctionImplementation.cpp b/mlir/lib/IR/FunctionImplementation.cpp
index af692befb0fde..9481e4ae8175b 100644
--- a/mlir/lib/IR/FunctionImplementation.cpp
+++ b/mlir/lib/IR/FunctionImplementation.cpp
@@ -163,7 +163,7 @@ void mlir::function_interface_impl::addArgAndResultAttrs(
 
 ParseResult mlir::function_interface_impl::parseFunctionOp(
     OpAsmParser &parser, OperationState &result, bool allowVariadic,
-    StringAttr typeAttrName, FuncTypeBuilder funcTypeBuilder) {
+    FuncTypeBuilder funcTypeBuilder) {
   SmallVector<OpAsmParser::Argument> entryArgs;
   SmallVector<DictionaryAttr> resultAttrs;
   SmallVector<Type> resultTypes;
@@ -197,7 +197,7 @@ ParseResult mlir::function_interface_impl::parseFunctionOp(
            << "failed to construct function type"
            << (errorMessage.empty() ? "" : ": ") << errorMessage;
   }
-  result.addAttribute(typeAttrName, TypeAttr::get(type));
+  result.addAttribute(getTypeAttrName(), TypeAttr::get(type));
 
   // If function attributes are present, parse them.
   NamedAttrList parsedAttributes;
@@ -209,7 +209,7 @@ ParseResult mlir::function_interface_impl::parseFunctionOp(
   // dictionary.
   for (StringRef disallowed :
        {SymbolTable::getVisibilityAttrName(), SymbolTable::getSymbolAttrName(),
-        typeAttrName.getValue()}) {
+        getTypeAttrName()}) {
     if (parsedAttributes.get(disallowed))
       return parser.emitError(attributeDictLocation, "'")
              << disallowed
@@ -301,11 +301,12 @@ void mlir::function_interface_impl::printFunctionSignature(
 }
 
 void mlir::function_interface_impl::printFunctionAttributes(
-    OpAsmPrinter &p, Operation *op, ArrayRef<StringRef> elided) {
+    OpAsmPrinter &p, Operation *op, unsigned numInputs, unsigned numResults,
+    ArrayRef<StringRef> elided) {
   // Print out function attributes, if present.
-  SmallVector<StringRef, 2> ignoredAttrs = {SymbolTable::getSymbolAttrName(),
-                                            getArgDictAttrName(),
-                                            getResultDictAttrName()};
+  SmallVector<StringRef, 2> ignoredAttrs = {
+      ::mlir::SymbolTable::getSymbolAttrName(), getTypeAttrName(),
+      getArgDictAttrName(), getResultDictAttrName()};
   ignoredAttrs.append(elided.begin(), elided.end());
 
   p.printOptionalAttrDictWithKeyword(op->getAttrs(), ignoredAttrs);
@@ -313,8 +314,7 @@ void mlir::function_interface_impl::printFunctionAttributes(
 
 void mlir::function_interface_impl::printFunctionOp(OpAsmPrinter &p,
                                                     FunctionOpInterface op,
-                                                    bool isVariadic,
-                                                    StringRef typeAttrName) {
+                                                    bool isVariadic) {
   // Print the operation and the function name.
   auto funcName =
       op->getAttrOfType<StringAttr>(SymbolTable::getSymbolAttrName())
@@ -329,7 +329,8 @@ void mlir::function_interface_impl::printFunctionOp(OpAsmPrinter &p,
   ArrayRef<Type> argTypes = op.getArgumentTypes();
   ArrayRef<Type> resultTypes = op.getResultTypes();
   printFunctionSignature(p, op, argTypes, isVariadic, resultTypes);
-  printFunctionAttributes(p, op, {visibilityAttrName, typeAttrName});
+  printFunctionAttributes(p, op, argTypes.size(), resultTypes.size(),
+                          {visibilityAttrName});
   // Print the body if this is not an external function.
   Region &body = op->getRegion(0);
   if (!body.empty()) {

diff  --git a/mlir/lib/IR/FunctionInterfaces.cpp b/mlir/lib/IR/FunctionInterfaces.cpp
index 9ba830366056c..3331aefc76d13 100644
--- a/mlir/lib/IR/FunctionInterfaces.cpp
+++ b/mlir/lib/IR/FunctionInterfaces.cpp
@@ -112,7 +112,7 @@ void mlir::function_interface_impl::setAllResultAttrDicts(
 }
 
 void mlir::function_interface_impl::insertFunctionArguments(
-    FunctionOpInterface op, ArrayRef<unsigned> argIndices, TypeRange argTypes,
+    Operation *op, ArrayRef<unsigned> argIndices, TypeRange argTypes,
     ArrayRef<DictionaryAttr> argAttrs, ArrayRef<Location> argLocs,
     unsigned originalNumArgs, Type newType) {
   assert(argIndices.size() == argTypes.size());
@@ -152,15 +152,15 @@ void mlir::function_interface_impl::insertFunctionArguments(
   }
 
   // Update the function type and any entry block arguments.
-  op.setFunctionTypeAttr(TypeAttr::get(newType));
+  op->setAttr(getTypeAttrName(), TypeAttr::get(newType));
   for (unsigned i = 0, e = argIndices.size(); i < e; ++i)
     entry.insertArgument(argIndices[i] + i, argTypes[i], argLocs[i]);
 }
 
 void mlir::function_interface_impl::insertFunctionResults(
-    FunctionOpInterface op, ArrayRef<unsigned> resultIndices,
-    TypeRange resultTypes, ArrayRef<DictionaryAttr> resultAttrs,
-    unsigned originalNumResults, Type newType) {
+    Operation *op, ArrayRef<unsigned> resultIndices, TypeRange resultTypes,
+    ArrayRef<DictionaryAttr> resultAttrs, unsigned originalNumResults,
+    Type newType) {
   assert(resultIndices.size() == resultTypes.size());
   assert(resultIndices.size() == resultAttrs.size() || resultAttrs.empty());
   if (resultIndices.empty())
@@ -196,11 +196,11 @@ void mlir::function_interface_impl::insertFunctionResults(
   }
 
   // Update the function type.
-  op.setFunctionTypeAttr(TypeAttr::get(newType));
+  op->setAttr(getTypeAttrName(), TypeAttr::get(newType));
 }
 
 void mlir::function_interface_impl::eraseFunctionArguments(
-    FunctionOpInterface op, const BitVector &argIndices, Type newType) {
+    Operation *op, const BitVector &argIndices, Type newType) {
   // There are 3 things that need to be updated:
   // - Function type.
   // - Arg attrs.
@@ -218,12 +218,12 @@ void mlir::function_interface_impl::eraseFunctionArguments(
   }
 
   // Update the function type and any entry block arguments.
-  op.setFunctionTypeAttr(TypeAttr::get(newType));
+  op->setAttr(getTypeAttrName(), TypeAttr::get(newType));
   entry.eraseArguments(argIndices);
 }
 
 void mlir::function_interface_impl::eraseFunctionResults(
-    FunctionOpInterface op, const BitVector &resultIndices, Type newType) {
+    Operation *op, const BitVector &resultIndices, Type newType) {
   // There are 2 things that need to be updated:
   // - Function type.
   // - Result attrs.
@@ -239,7 +239,7 @@ void mlir::function_interface_impl::eraseFunctionResults(
   }
 
   // Update the function type.
-  op.setFunctionTypeAttr(TypeAttr::get(newType));
+  op->setAttr(getTypeAttrName(), TypeAttr::get(newType));
 }
 
 TypeRange mlir::function_interface_impl::insertTypesInto(
@@ -276,13 +276,14 @@ TypeRange mlir::function_interface_impl::filterTypesOut(
 // Function type signature.
 //===----------------------------------------------------------------------===//
 
-void mlir::function_interface_impl::setFunctionType(FunctionOpInterface op,
+void mlir::function_interface_impl::setFunctionType(Operation *op,
                                                     Type newType) {
-  unsigned oldNumArgs = op.getNumArguments();
-  unsigned oldNumResults = op.getNumResults();
-  op.setFunctionTypeAttr(TypeAttr::get(newType));
-  unsigned newNumArgs = op.getNumArguments();
-  unsigned newNumResults = op.getNumResults();
+  FunctionOpInterface funcOp = cast<FunctionOpInterface>(op);
+  unsigned oldNumArgs = funcOp.getNumArguments();
+  unsigned oldNumResults = funcOp.getNumResults();
+  op->setAttr(getTypeAttrName(), TypeAttr::get(newType));
+  unsigned newNumArgs = funcOp.getNumArguments();
+  unsigned newNumResults = funcOp.getNumResults();
 
   // Functor used to update the argument and result attributes of the function.
   auto updateAttrFn = [&](StringRef attrName, unsigned oldCount,


        


More information about the Mlir-commits mailing list