[Mlir-commits] [mlir] cf98e82 - Revert "[mlir] FunctionOpInterface: make get/setFunctionType interface methods"
David Spickett
llvmlistbot at llvm.org
Fri Dec 9 07:37:07 PST 2022
Author: David Spickett
Date: 2022-12-09T15:36:48Z
New Revision: cf98e8273c5f1c84afcfdd9bcea486ec22f26768
URL: https://github.com/llvm/llvm-project/commit/cf98e8273c5f1c84afcfdd9bcea486ec22f26768
DIFF: https://github.com/llvm/llvm-project/commit/cf98e8273c5f1c84afcfdd9bcea486ec22f26768.diff
LOG: Revert "[mlir] FunctionOpInterface: make get/setFunctionType interface methods"
and "[mlir] Fix examples build"
This reverts commits fbc253fe81da4e1d6bfa2519e01e03f21d8c40a8 and
96cf183bccd7d1c3083f169a89a6af1f263b3aae,
which I missed in the first revert in f3379feabe38fd3711b13ffcf6de4aab03b7ccdc.
Added:
Modified:
mlir/examples/toy/Ch2/mlir/Dialect.cpp
mlir/examples/toy/Ch3/mlir/Dialect.cpp
mlir/examples/toy/Ch4/mlir/Dialect.cpp
mlir/examples/toy/Ch5/mlir/Dialect.cpp
mlir/examples/toy/Ch6/mlir/Dialect.cpp
mlir/examples/toy/Ch7/mlir/Dialect.cpp
mlir/include/mlir/IR/FunctionImplementation.h
mlir/include/mlir/IR/FunctionInterfaces.h
mlir/include/mlir/IR/FunctionInterfaces.td
mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp
mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp
mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp
mlir/lib/Dialect/Async/IR/Async.cpp
mlir/lib/Dialect/Func/IR/FuncOps.cpp
mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
mlir/lib/Dialect/MLProgram/IR/MLProgramOps.cpp
mlir/lib/Dialect/PDLInterp/IR/PDLInterp.cpp
mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
mlir/lib/Dialect/Shape/IR/Shape.cpp
mlir/lib/IR/FunctionImplementation.cpp
mlir/lib/IR/FunctionInterfaces.cpp
Removed:
################################################################################
diff --git a/mlir/examples/toy/Ch2/mlir/Dialect.cpp b/mlir/examples/toy/Ch2/mlir/Dialect.cpp
index 8c36fafc2f001..dbc1efb3d06be 100644
--- a/mlir/examples/toy/Ch2/mlir/Dialect.cpp
+++ b/mlir/examples/toy/Ch2/mlir/Dialect.cpp
@@ -211,9 +211,7 @@ mlir::ParseResult FuncOp::parse(mlir::OpAsmParser &parser,
std::string &) { return builder.getFunctionType(argTypes, results); };
return mlir::function_interface_impl::parseFunctionOp(
- parser, result, /*allowVariadic=*/false,
- getFunctionTypeAttrName(result.name), buildFuncType,
- getArgAttrsAttrName(result.name), getResAttrsAttrName(result.name));
+ parser, result, /*allowVariadic=*/false, buildFuncType);
}
void FuncOp::print(mlir::OpAsmPrinter &p) {
diff --git a/mlir/examples/toy/Ch3/mlir/Dialect.cpp b/mlir/examples/toy/Ch3/mlir/Dialect.cpp
index 6bf140487420f..50e2dfc7f4a3e 100644
--- a/mlir/examples/toy/Ch3/mlir/Dialect.cpp
+++ b/mlir/examples/toy/Ch3/mlir/Dialect.cpp
@@ -198,9 +198,7 @@ mlir::ParseResult FuncOp::parse(mlir::OpAsmParser &parser,
std::string &) { return builder.getFunctionType(argTypes, results); };
return mlir::function_interface_impl::parseFunctionOp(
- parser, result, /*allowVariadic=*/false,
- getFunctionTypeAttrName(result.name), buildFuncType,
- getArgAttrsAttrName(result.name), getResAttrsAttrName(result.name));
+ parser, result, /*allowVariadic=*/false, buildFuncType);
}
void FuncOp::print(mlir::OpAsmPrinter &p) {
diff --git a/mlir/examples/toy/Ch4/mlir/Dialect.cpp b/mlir/examples/toy/Ch4/mlir/Dialect.cpp
index 8343a1cb5fbc3..0a6195b12d5d4 100644
--- a/mlir/examples/toy/Ch4/mlir/Dialect.cpp
+++ b/mlir/examples/toy/Ch4/mlir/Dialect.cpp
@@ -287,9 +287,7 @@ mlir::ParseResult FuncOp::parse(mlir::OpAsmParser &parser,
std::string &) { return builder.getFunctionType(argTypes, results); };
return mlir::function_interface_impl::parseFunctionOp(
- parser, result, /*allowVariadic=*/false,
- getFunctionTypeAttrName(result.name), buildFuncType,
- getArgAttrsAttrName(result.name), getResAttrsAttrName(result.name));
+ parser, result, /*allowVariadic=*/false, buildFuncType);
}
void FuncOp::print(mlir::OpAsmPrinter &p) {
diff --git a/mlir/examples/toy/Ch5/mlir/Dialect.cpp b/mlir/examples/toy/Ch5/mlir/Dialect.cpp
index dde12f51c351e..f236a1ffe0e5a 100644
--- a/mlir/examples/toy/Ch5/mlir/Dialect.cpp
+++ b/mlir/examples/toy/Ch5/mlir/Dialect.cpp
@@ -287,9 +287,7 @@ mlir::ParseResult FuncOp::parse(mlir::OpAsmParser &parser,
std::string &) { return builder.getFunctionType(argTypes, results); };
return mlir::function_interface_impl::parseFunctionOp(
- parser, result, /*allowVariadic=*/false,
- getFunctionTypeAttrName(result.name), buildFuncType,
- getArgAttrsAttrName(result.name), getResAttrsAttrName(result.name));
+ parser, result, /*allowVariadic=*/false, buildFuncType);
}
void FuncOp::print(mlir::OpAsmPrinter &p) {
diff --git a/mlir/examples/toy/Ch6/mlir/Dialect.cpp b/mlir/examples/toy/Ch6/mlir/Dialect.cpp
index dde12f51c351e..f236a1ffe0e5a 100644
--- a/mlir/examples/toy/Ch6/mlir/Dialect.cpp
+++ b/mlir/examples/toy/Ch6/mlir/Dialect.cpp
@@ -287,9 +287,7 @@ mlir::ParseResult FuncOp::parse(mlir::OpAsmParser &parser,
std::string &) { return builder.getFunctionType(argTypes, results); };
return mlir::function_interface_impl::parseFunctionOp(
- parser, result, /*allowVariadic=*/false,
- getFunctionTypeAttrName(result.name), buildFuncType,
- getArgAttrsAttrName(result.name), getResAttrsAttrName(result.name));
+ parser, result, /*allowVariadic=*/false, buildFuncType);
}
void FuncOp::print(mlir::OpAsmPrinter &p) {
diff --git a/mlir/examples/toy/Ch7/mlir/Dialect.cpp b/mlir/examples/toy/Ch7/mlir/Dialect.cpp
index 3413e57d37b30..cc66a5d44b5f4 100644
--- a/mlir/examples/toy/Ch7/mlir/Dialect.cpp
+++ b/mlir/examples/toy/Ch7/mlir/Dialect.cpp
@@ -314,9 +314,7 @@ mlir::ParseResult FuncOp::parse(mlir::OpAsmParser &parser,
std::string &) { return builder.getFunctionType(argTypes, results); };
return mlir::function_interface_impl::parseFunctionOp(
- parser, result, /*allowVariadic=*/false,
- getFunctionTypeAttrName(result.name), buildFuncType,
- getArgAttrsAttrName(result.name), getResAttrsAttrName(result.name));
+ parser, result, /*allowVariadic=*/false, buildFuncType);
}
void FuncOp::print(mlir::OpAsmPrinter &p) {
diff --git a/mlir/include/mlir/IR/FunctionImplementation.h b/mlir/include/mlir/IR/FunctionImplementation.h
index f4c0cc03050fe..5265f781d1a77 100644
--- a/mlir/include/mlir/IR/FunctionImplementation.h
+++ b/mlir/include/mlir/IR/FunctionImplementation.h
@@ -69,19 +69,17 @@ Type getFunctionType(Builder &builder, ArrayRef<OpAsmParser::Argument> argAttrs,
/// Parser implementation for function-like operations. Uses
/// `funcTypeBuilder` to construct the custom function type given lists of
-/// input and output types. The parser sets the `typeAttrName` attribute to the
-/// resulting function type. If `allowVariadic` is set, the parser will accept
+/// input and output types. If `allowVariadic` is set, the parser will accept
/// trailing ellipsis in the function signature and indicate to the builder
/// whether the function is variadic. If the builder returns a null type,
/// `result` will not contain the `type` attribute. The caller can then add a
/// type, report the error or delegate the reporting to the op's verifier.
ParseResult parseFunctionOp(OpAsmParser &parser, OperationState &result,
- bool allowVariadic, StringAttr typeAttrName,
+ bool allowVariadic,
FuncTypeBuilder funcTypeBuilder);
/// Printer implementation for function-like operations.
-void printFunctionOp(OpAsmPrinter &p, FunctionOpInterface op, bool isVariadic,
- StringRef typeAttrName);
+void printFunctionOp(OpAsmPrinter &p, FunctionOpInterface op, bool isVariadic);
/// Prints the signature of the function-like operation `op`. Assumes `op` has
/// is a FunctionOpInterface and has passed verification.
@@ -94,7 +92,8 @@ void printFunctionSignature(OpAsmPrinter &p, Operation *op,
/// function-like operation internally are not printed. Nothing is printed
/// if all attributes are elided. Assumes `op` is a FunctionOpInterface and
/// has passed verification.
-void printFunctionAttributes(OpAsmPrinter &p, Operation *op,
+void printFunctionAttributes(OpAsmPrinter &p, Operation *op, unsigned numInputs,
+ unsigned numResults,
ArrayRef<StringRef> elided = {});
} // namespace function_interface_impl
diff --git a/mlir/include/mlir/IR/FunctionInterfaces.h b/mlir/include/mlir/IR/FunctionInterfaces.h
index bc2ec4751c582..23fd884d97f14 100644
--- a/mlir/include/mlir/IR/FunctionInterfaces.h
+++ b/mlir/include/mlir/IR/FunctionInterfaces.h
@@ -22,10 +22,12 @@
#include "llvm/ADT/SmallString.h"
namespace mlir {
-class FunctionOpInterface;
namespace function_interface_impl {
+/// Return the name of the attribute used for function types.
+inline StringRef getTypeAttrName() { return "function_type"; }
+
/// Return the name of the attribute used for function argument attributes.
inline StringRef getArgDictAttrName() { return "arg_attrs"; }
@@ -70,29 +72,28 @@ inline ArrayRef<NamedAttribute> getResultAttrs(Operation *op, unsigned index) {
}
/// Insert the specified arguments and update the function type attribute.
-void insertFunctionArguments(FunctionOpInterface op,
- ArrayRef<unsigned> argIndices, TypeRange argTypes,
+void insertFunctionArguments(Operation *op, ArrayRef<unsigned> argIndices,
+ TypeRange argTypes,
ArrayRef<DictionaryAttr> argAttrs,
ArrayRef<Location> argLocs,
unsigned originalNumArgs, Type newType);
/// Insert the specified results and update the function type attribute.
-void insertFunctionResults(FunctionOpInterface op,
- ArrayRef<unsigned> resultIndices,
+void insertFunctionResults(Operation *op, ArrayRef<unsigned> resultIndices,
TypeRange resultTypes,
ArrayRef<DictionaryAttr> resultAttrs,
unsigned originalNumResults, Type newType);
/// Erase the specified arguments and update the function type attribute.
-void eraseFunctionArguments(FunctionOpInterface op, const BitVector &argIndices,
+void eraseFunctionArguments(Operation *op, const BitVector &argIndices,
Type newType);
/// Erase the specified results and update the function type attribute.
-void eraseFunctionResults(FunctionOpInterface op,
- const BitVector &resultIndices, Type newType);
+void eraseFunctionResults(Operation *op, const BitVector &resultIndices,
+ Type newType);
/// Set a FunctionOpInterface operation's type signature.
-void setFunctionType(FunctionOpInterface op, Type newType);
+void setFunctionType(Operation *op, Type newType);
/// Insert a set of `newTypes` into `oldTypes` at the given `indices`. If any
/// types are inserted, `storage` is used to hold the new type list. The new
@@ -206,6 +207,10 @@ Attribute removeResultAttr(ConcreteType op, unsigned index, StringAttr name) {
/// method on FunctionOpInterface::Trait.
template <typename ConcreteOp>
LogicalResult verifyTrait(ConcreteOp op) {
+ if (!op.getFunctionTypeAttr())
+ return op.emitOpError("requires a type attribute '")
+ << function_interface_impl::getTypeAttrName() << '\'';
+
if (failed(op.verifyType()))
return failure();
diff --git a/mlir/include/mlir/IR/FunctionInterfaces.td b/mlir/include/mlir/IR/FunctionInterfaces.td
index e86057aa7ec2f..c56129ea895d9 100644
--- a/mlir/include/mlir/IR/FunctionInterfaces.td
+++ b/mlir/include/mlir/IR/FunctionInterfaces.td
@@ -49,16 +49,6 @@ def FunctionOpInterface : OpInterface<"FunctionOpInterface"> {
for each of the function results.
}];
let methods = [
- InterfaceMethod<[{
- Returns the type of the function.
- }],
- "::mlir::Type", "getFunctionType">,
- InterfaceMethod<[{
- Set the type of the function. This method should perform an unsafe
- modification to the function type; it should not update argument or
- result attributes.
- }],
- "void", "setFunctionTypeAttr", (ins "::mlir::TypeAttr":$type)>,
InterfaceMethod<[{
Returns the function argument types based exclusively on
the type (to allow for this method may be called on function
@@ -149,7 +139,7 @@ def FunctionOpInterface : OpInterface<"FunctionOpInterface"> {
ArrayRef<NamedAttribute> attrs, TypeRange inputTypes) {
state.addAttribute(SymbolTable::getSymbolAttrName(),
builder.getStringAttr(name));
- state.addAttribute(ConcreteOp::getFunctionTypeAttrName(state.name),
+ state.addAttribute(function_interface_impl::getTypeAttrName(),
TypeAttr::get(type));
state.attributes.append(attrs.begin(), attrs.end());
@@ -254,6 +244,11 @@ def FunctionOpInterface : OpInterface<"FunctionOpInterface"> {
// the derived operation, which should already have these defined
// (via ODS).
+ /// Returns the name of the attribute used for function types.
+ static StringRef getTypeAttrName() {
+ return function_interface_impl::getTypeAttrName();
+ }
+
/// Returns the name of the attribute used for function argument attributes.
static StringRef getArgDictAttrName() {
return function_interface_impl::getArgDictAttrName();
@@ -264,6 +259,15 @@ def FunctionOpInterface : OpInterface<"FunctionOpInterface"> {
return function_interface_impl::getResultDictAttrName();
}
+ /// Return the attribute containing the type of this function.
+ TypeAttr getFunctionTypeAttr() {
+ return this->getOperation()->template getAttrOfType<TypeAttr>(
+ getTypeAttrName());
+ }
+
+ /// Return the type of this function.
+ Type getFunctionType() { return getFunctionTypeAttr().getValue(); }
+
//===------------------------------------------------------------------===//
// Argument and Result Handling
//===------------------------------------------------------------------===//
diff --git a/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp b/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp
index 9f522aaa49f92..d0e82de839c0c 100644
--- a/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp
+++ b/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp
@@ -59,11 +59,12 @@ using namespace mlir;
/// Only retain those attributes that are not constructed by
/// `LLVMFuncOp::build`. If `filterArgAttrs` is set, also filter out argument
/// attributes.
-static void filterFuncAttributes(func::FuncOp func, bool filterArgAndResAttrs,
+static void filterFuncAttributes(ArrayRef<NamedAttribute> attrs,
+ bool filterArgAndResAttrs,
SmallVectorImpl<NamedAttribute> &result) {
- for (const NamedAttribute &attr : func->getAttrs()) {
+ for (const auto &attr : attrs) {
if (attr.getName() == SymbolTable::getSymbolAttrName() ||
- attr.getName() == func.getFunctionTypeAttrName() ||
+ attr.getName() == FunctionOpInterface::getTypeAttrName() ||
attr.getName() == "func.varargs" ||
(filterArgAndResAttrs &&
(attr.getName() == FunctionOpInterface::getArgDictAttrName() ||
@@ -137,7 +138,8 @@ static void wrapForExternalCallers(OpBuilder &rewriter, Location loc,
LLVM::LLVMFuncOp newFuncOp) {
auto type = funcOp.getFunctionType();
SmallVector<NamedAttribute, 4> attributes;
- filterFuncAttributes(funcOp, /*filterArgAndResAttrs=*/false, attributes);
+ filterFuncAttributes(funcOp->getAttrs(), /*filterArgAndResAttrs=*/false,
+ attributes);
auto [wrapperFuncType, resultIsNowArg] =
typeConverter.convertFunctionTypeCWrapper(type);
if (resultIsNowArg)
@@ -202,7 +204,8 @@ static void wrapExternalFunction(OpBuilder &builder, Location loc,
assert(wrapperType && "unexpected type conversion failure");
SmallVector<NamedAttribute, 4> attributes;
- filterFuncAttributes(funcOp, /*filterArgAndResAttrs=*/false, attributes);
+ filterFuncAttributes(funcOp->getAttrs(), /*filterArgAndResAttrs=*/false,
+ attributes);
if (resultIsNowArg)
prependResAttrsToArgAttrs(builder, attributes, funcOp.getNumArguments());
@@ -301,7 +304,8 @@ struct FuncOpConversionBase : public ConvertOpToLLVMPattern<func::FuncOp> {
// Propagate argument/result attributes to all converted arguments/result
// obtained after converting a given original argument/result.
SmallVector<NamedAttribute, 4> attributes;
- filterFuncAttributes(funcOp, /*filterArgAndResAttrs=*/true, attributes);
+ filterFuncAttributes(funcOp->getAttrs(), /*filterArgAndResAttrs=*/true,
+ attributes);
if (ArrayAttr resAttrDicts = funcOp.getAllResultAttrs()) {
assert(!resAttrDicts.empty() && "expected array to be non-empty");
auto newResAttrDicts =
diff --git a/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp b/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp
index 48effe24f674e..85001d54d093d 100644
--- a/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp
+++ b/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp
@@ -60,7 +60,7 @@ GPUFuncOpLowering::matchAndRewrite(gpu::GPUFuncOp gpuFuncOp, OpAdaptor adaptor,
SmallVector<NamedAttribute, 4> attributes;
for (const auto &attr : gpuFuncOp->getAttrs()) {
if (attr.getName() == SymbolTable::getSymbolAttrName() ||
- attr.getName() == gpuFuncOp.getFunctionTypeAttrName() ||
+ attr.getName() == FunctionOpInterface::getTypeAttrName() ||
attr.getName() == gpu::GPUFuncOp::getNumWorkgroupAttributionsAttrName())
continue;
attributes.push_back(attr);
diff --git a/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp b/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp
index 2a8389598f36a..119b1d3dea91e 100644
--- a/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp
+++ b/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp
@@ -226,7 +226,7 @@ lowerAsEntryFunction(gpu::GPUFuncOp funcOp, TypeConverter &typeConverter,
rewriter.getFunctionType(signatureConverter.getConvertedTypes(),
std::nullopt));
for (const auto &namedAttr : funcOp->getAttrs()) {
- if (namedAttr.getName() == funcOp.getFunctionTypeAttrName() ||
+ if (namedAttr.getName() == FunctionOpInterface::getTypeAttrName() ||
namedAttr.getName() == SymbolTable::getSymbolAttrName())
continue;
newFuncOp->setAttr(namedAttr.getName(), namedAttr.getValue());
diff --git a/mlir/lib/Dialect/Async/IR/Async.cpp b/mlir/lib/Dialect/Async/IR/Async.cpp
index 064bf525db238..e0772b4dd90bc 100644
--- a/mlir/lib/Dialect/Async/IR/Async.cpp
+++ b/mlir/lib/Dialect/Async/IR/Async.cpp
@@ -332,7 +332,8 @@ void FuncOp::build(OpBuilder &builder, OperationState &state, StringRef name,
ArrayRef<DictionaryAttr> argAttrs) {
state.addAttribute(SymbolTable::getSymbolAttrName(),
builder.getStringAttr(name));
- state.addAttribute(getFunctionTypeAttrName(state.name), TypeAttr::get(type));
+ state.addAttribute(FunctionOpInterface::getTypeAttrName(),
+ TypeAttr::get(type));
state.attributes.append(attrs.begin(), attrs.end());
state.addRegion();
@@ -351,13 +352,11 @@ ParseResult FuncOp::parse(OpAsmParser &parser, OperationState &result) {
std::string &) { return builder.getFunctionType(argTypes, results); };
return function_interface_impl::parseFunctionOp(
- parser, result, /*allowVariadic=*/false,
- getFunctionTypeAttrName(result.name), buildFuncType);
+ parser, result, /*allowVariadic=*/false, buildFuncType);
}
void FuncOp::print(OpAsmPrinter &p) {
- function_interface_impl::printFunctionOp(p, *this, /*isVariadic=*/false,
- getFunctionTypeAttrName());
+ function_interface_impl::printFunctionOp(p, *this, /*isVariadic=*/false);
}
/// Check that the result type of async.func is not void and must be
diff --git a/mlir/lib/Dialect/Func/IR/FuncOps.cpp b/mlir/lib/Dialect/Func/IR/FuncOps.cpp
index fc9bd115e2223..961cf2eb36e35 100644
--- a/mlir/lib/Dialect/Func/IR/FuncOps.cpp
+++ b/mlir/lib/Dialect/Func/IR/FuncOps.cpp
@@ -244,7 +244,8 @@ void FuncOp::build(OpBuilder &builder, OperationState &state, StringRef name,
ArrayRef<DictionaryAttr> argAttrs) {
state.addAttribute(SymbolTable::getSymbolAttrName(),
builder.getStringAttr(name));
- state.addAttribute(getFunctionTypeAttrName(state.name), TypeAttr::get(type));
+ state.addAttribute(FunctionOpInterface::getTypeAttrName(),
+ TypeAttr::get(type));
state.attributes.append(attrs.begin(), attrs.end());
state.addRegion();
@@ -262,13 +263,11 @@ ParseResult FuncOp::parse(OpAsmParser &parser, OperationState &result) {
std::string &) { return builder.getFunctionType(argTypes, results); };
return function_interface_impl::parseFunctionOp(
- parser, result, /*allowVariadic=*/false,
- getFunctionTypeAttrName(result.name), buildFuncType);
+ parser, result, /*allowVariadic=*/false, buildFuncType);
}
void FuncOp::print(OpAsmPrinter &p) {
- function_interface_impl::printFunctionOp(p, *this, /*isVariadic=*/false,
- getFunctionTypeAttrName());
+ function_interface_impl::printFunctionOp(p, *this, /*isVariadic=*/false);
}
/// Clone the internal blocks from this function into dest and all attributes
diff --git a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
index 80db6461ecc55..7f73a651d0e9b 100644
--- a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
+++ b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
@@ -859,8 +859,7 @@ void GPUFuncOp::build(OpBuilder &builder, OperationState &result,
ArrayRef<NamedAttribute> attrs) {
result.addAttribute(SymbolTable::getSymbolAttrName(),
builder.getStringAttr(name));
- result.addAttribute(getFunctionTypeAttrName(result.name),
- TypeAttr::get(type));
+ result.addAttribute(getTypeAttrName(), TypeAttr::get(type));
result.addAttribute(getNumWorkgroupAttributionsAttrName(),
builder.getI64IntegerAttr(workgroupAttributions.size()));
result.addAttributes(attrs);
@@ -931,8 +930,7 @@ ParseResult GPUFuncOp::parse(OpAsmParser &parser, OperationState &result) {
for (auto &arg : entryArgs)
argTypes.push_back(arg.type);
auto type = builder.getFunctionType(argTypes, resultTypes);
- result.addAttribute(getFunctionTypeAttrName(result.name),
- TypeAttr::get(type));
+ result.addAttribute(GPUFuncOp::getTypeAttrName(), TypeAttr::get(type));
function_interface_impl::addArgAndResultAttrs(builder, result, entryArgs,
resultAttrs);
@@ -994,14 +992,19 @@ void GPUFuncOp::print(OpAsmPrinter &p) {
p << ' ' << getKernelKeyword();
function_interface_impl::printFunctionAttributes(
- p, *this,
+ p, *this, type.getNumInputs(), type.getNumResults(),
{getNumWorkgroupAttributionsAttrName(),
- GPUDialect::getKernelFuncAttrName(), getFunctionTypeAttrName()});
+ GPUDialect::getKernelFuncAttrName()});
p << ' ';
p.printRegion(getBody(), /*printEntryBlockArgs=*/false);
}
LogicalResult GPUFuncOp::verifyType() {
+ Type type = getFunctionTypeAttr().getValue();
+ if (!type.isa<FunctionType>())
+ return emitOpError("requires '" + getTypeAttrName() +
+ "' attribute of function type");
+
if (isKernel() && getFunctionType().getNumResults() != 0)
return emitOpError() << "expected void return type for kernel function";
diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
index 028f25ec3c1d2..1087bcf5d32f4 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
@@ -2090,7 +2090,7 @@ ParseResult LLVMFuncOp::parse(OpAsmParser &parser, OperationState &result) {
function_interface_impl::VariadicFlag(isVariadic));
if (!type)
return failure();
- result.addAttribute(getFunctionTypeAttrName(result.name),
+ result.addAttribute(FunctionOpInterface::getTypeAttrName(),
TypeAttr::get(type));
if (failed(parser.parseOptionalAttrDictWithKeyword(result.attributes)))
@@ -2130,8 +2130,8 @@ void LLVMFuncOp::print(OpAsmPrinter &p) {
function_interface_impl::printFunctionSignature(p, *this, argTypes,
isVarArg(), resTypes);
function_interface_impl::printFunctionAttributes(
- p, *this,
- {getFunctionTypeAttrName(), getLinkageAttrName(), getCConvAttrName()});
+ p, *this, argTypes.size(), resTypes.size(),
+ {getLinkageAttrName(), getCConvAttrName()});
// Print the body if this is not an external function.
Region &body = getBody();
diff --git a/mlir/lib/Dialect/MLProgram/IR/MLProgramOps.cpp b/mlir/lib/Dialect/MLProgram/IR/MLProgramOps.cpp
index 27c613088df56..2f1e4b93a6ac3 100644
--- a/mlir/lib/Dialect/MLProgram/IR/MLProgramOps.cpp
+++ b/mlir/lib/Dialect/MLProgram/IR/MLProgramOps.cpp
@@ -152,13 +152,11 @@ ParseResult FuncOp::parse(OpAsmParser &parser, OperationState &result) {
std::string &) { return builder.getFunctionType(argTypes, results); };
return function_interface_impl::parseFunctionOp(
- parser, result, /*allowVariadic=*/false,
- getFunctionTypeAttrName(result.name), buildFuncType);
+ parser, result, /*allowVariadic=*/false, buildFuncType);
}
void FuncOp::print(OpAsmPrinter &p) {
- function_interface_impl::printFunctionOp(p, *this, /*isVariadic=*/false,
- getFunctionTypeAttrName());
+ function_interface_impl::printFunctionOp(p, *this, /*isVariadic=*/false);
}
//===----------------------------------------------------------------------===//
@@ -315,13 +313,11 @@ ParseResult SubgraphOp::parse(OpAsmParser &parser, OperationState &result) {
std::string &) { return builder.getFunctionType(argTypes, results); };
return function_interface_impl::parseFunctionOp(
- parser, result, /*allowVariadic=*/false,
- getFunctionTypeAttrName(result.name), buildFuncType);
+ parser, result, /*allowVariadic=*/false, buildFuncType);
}
void SubgraphOp::print(OpAsmPrinter &p) {
- function_interface_impl::printFunctionOp(p, *this, /*isVariadic=*/false,
- getFunctionTypeAttrName());
+ function_interface_impl::printFunctionOp(p, *this, /*isVariadic=*/false);
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/PDLInterp/IR/PDLInterp.cpp b/mlir/lib/Dialect/PDLInterp/IR/PDLInterp.cpp
index 28fc4dbf97aea..e8a61ef4c6a4d 100644
--- a/mlir/lib/Dialect/PDLInterp/IR/PDLInterp.cpp
+++ b/mlir/lib/Dialect/PDLInterp/IR/PDLInterp.cpp
@@ -220,13 +220,11 @@ ParseResult FuncOp::parse(OpAsmParser &parser, OperationState &result) {
std::string &) { return builder.getFunctionType(argTypes, results); };
return function_interface_impl::parseFunctionOp(
- parser, result, /*allowVariadic=*/false,
- getFunctionTypeAttrName(result.name), buildFuncType);
+ parser, result, /*allowVariadic=*/false, buildFuncType);
}
void FuncOp::print(OpAsmPrinter &p) {
- function_interface_impl::printFunctionOp(p, *this, /*isVariadic=*/false,
- getFunctionTypeAttrName());
+ function_interface_impl::printFunctionOp(p, *this, /*isVariadic=*/false);
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
index 3ce3913f2814b..52ad8ad5fe7c7 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
@@ -2382,7 +2382,7 @@ ParseResult spirv::FuncOp::parse(OpAsmParser &parser, OperationState &result) {
for (auto &arg : entryArgs)
argTypes.push_back(arg.type);
auto fnType = builder.getFunctionType(argTypes, resultTypes);
- result.addAttribute(getFunctionTypeAttrName(result.name),
+ result.addAttribute(FunctionOpInterface::getTypeAttrName(),
TypeAttr::get(fnType));
// Parse the optional function control keyword.
@@ -2417,9 +2417,8 @@ void spirv::FuncOp::print(OpAsmPrinter &printer) {
printer << " \"" << spirv::stringifyFunctionControl(getFunctionControl())
<< "\"";
function_interface_impl::printFunctionAttributes(
- printer, *this,
- {spirv::attributeName<spirv::FunctionControl>(),
- getFunctionTypeAttrName(), getFunctionControlAttrName()});
+ printer, *this, fnType.getNumInputs(), fnType.getNumResults(),
+ {spirv::attributeName<spirv::FunctionControl>()});
// Print the body if this is not an external function.
Region &body = this->getBody();
@@ -2431,6 +2430,10 @@ void spirv::FuncOp::print(OpAsmPrinter &printer) {
}
LogicalResult spirv::FuncOp::verifyType() {
+ auto type = getFunctionTypeAttr().getValue();
+ if (!type.isa<FunctionType>())
+ return emitOpError("requires '" + getTypeAttrName() +
+ "' attribute of function type");
if (getFunctionType().getNumResults() > 1)
return emitOpError("cannot have more than one result");
return success();
@@ -2470,7 +2473,7 @@ void spirv::FuncOp::build(OpBuilder &builder, OperationState &state,
ArrayRef<NamedAttribute> attrs) {
state.addAttribute(SymbolTable::getSymbolAttrName(),
builder.getStringAttr(name));
- state.addAttribute(getFunctionTypeAttrName(state.name), TypeAttr::get(type));
+ state.addAttribute(getTypeAttrName(), TypeAttr::get(type));
state.addAttribute(spirv::attributeName<spirv::FunctionControl>(),
builder.getAttr<spirv::FunctionControlAttr>(control));
state.attributes.append(attrs.begin(), attrs.end());
diff --git a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
index 62e3a3de0f11b..2772c0150cda7 100644
--- a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
+++ b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
@@ -531,7 +531,7 @@ FuncOpConversion::matchAndRewrite(func::FuncOp funcOp, OpAdaptor adaptor,
// Copy over all attributes other than the function name and type.
for (const auto &namedAttr : funcOp->getAttrs()) {
- if (namedAttr.getName() != funcOp.getFunctionTypeAttrName() &&
+ if (namedAttr.getName() != FunctionOpInterface::getTypeAttrName() &&
namedAttr.getName() != SymbolTable::getSymbolAttrName())
newFuncOp->setAttr(namedAttr.getName(), namedAttr.getValue());
}
diff --git a/mlir/lib/Dialect/Shape/IR/Shape.cpp b/mlir/lib/Dialect/Shape/IR/Shape.cpp
index 30c5f56f13886..8c89ec8bba6cf 100644
--- a/mlir/lib/Dialect/Shape/IR/Shape.cpp
+++ b/mlir/lib/Dialect/Shape/IR/Shape.cpp
@@ -1311,13 +1311,11 @@ ParseResult FuncOp::parse(OpAsmParser &parser, OperationState &result) {
std::string &) { return builder.getFunctionType(argTypes, results); };
return function_interface_impl::parseFunctionOp(
- parser, result, /*allowVariadic=*/false,
- getFunctionTypeAttrName(result.name), buildFuncType);
+ parser, result, /*allowVariadic=*/false, buildFuncType);
}
void FuncOp::print(OpAsmPrinter &p) {
- function_interface_impl::printFunctionOp(p, *this, /*isVariadic=*/false,
- getFunctionTypeAttrName());
+ function_interface_impl::printFunctionOp(p, *this, /*isVariadic=*/false);
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/IR/FunctionImplementation.cpp b/mlir/lib/IR/FunctionImplementation.cpp
index af692befb0fde..9481e4ae8175b 100644
--- a/mlir/lib/IR/FunctionImplementation.cpp
+++ b/mlir/lib/IR/FunctionImplementation.cpp
@@ -163,7 +163,7 @@ void mlir::function_interface_impl::addArgAndResultAttrs(
ParseResult mlir::function_interface_impl::parseFunctionOp(
OpAsmParser &parser, OperationState &result, bool allowVariadic,
- StringAttr typeAttrName, FuncTypeBuilder funcTypeBuilder) {
+ FuncTypeBuilder funcTypeBuilder) {
SmallVector<OpAsmParser::Argument> entryArgs;
SmallVector<DictionaryAttr> resultAttrs;
SmallVector<Type> resultTypes;
@@ -197,7 +197,7 @@ ParseResult mlir::function_interface_impl::parseFunctionOp(
<< "failed to construct function type"
<< (errorMessage.empty() ? "" : ": ") << errorMessage;
}
- result.addAttribute(typeAttrName, TypeAttr::get(type));
+ result.addAttribute(getTypeAttrName(), TypeAttr::get(type));
// If function attributes are present, parse them.
NamedAttrList parsedAttributes;
@@ -209,7 +209,7 @@ ParseResult mlir::function_interface_impl::parseFunctionOp(
// dictionary.
for (StringRef disallowed :
{SymbolTable::getVisibilityAttrName(), SymbolTable::getSymbolAttrName(),
- typeAttrName.getValue()}) {
+ getTypeAttrName()}) {
if (parsedAttributes.get(disallowed))
return parser.emitError(attributeDictLocation, "'")
<< disallowed
@@ -301,11 +301,12 @@ void mlir::function_interface_impl::printFunctionSignature(
}
void mlir::function_interface_impl::printFunctionAttributes(
- OpAsmPrinter &p, Operation *op, ArrayRef<StringRef> elided) {
+ OpAsmPrinter &p, Operation *op, unsigned numInputs, unsigned numResults,
+ ArrayRef<StringRef> elided) {
// Print out function attributes, if present.
- SmallVector<StringRef, 2> ignoredAttrs = {SymbolTable::getSymbolAttrName(),
- getArgDictAttrName(),
- getResultDictAttrName()};
+ SmallVector<StringRef, 2> ignoredAttrs = {
+ ::mlir::SymbolTable::getSymbolAttrName(), getTypeAttrName(),
+ getArgDictAttrName(), getResultDictAttrName()};
ignoredAttrs.append(elided.begin(), elided.end());
p.printOptionalAttrDictWithKeyword(op->getAttrs(), ignoredAttrs);
@@ -313,8 +314,7 @@ void mlir::function_interface_impl::printFunctionAttributes(
void mlir::function_interface_impl::printFunctionOp(OpAsmPrinter &p,
FunctionOpInterface op,
- bool isVariadic,
- StringRef typeAttrName) {
+ bool isVariadic) {
// Print the operation and the function name.
auto funcName =
op->getAttrOfType<StringAttr>(SymbolTable::getSymbolAttrName())
@@ -329,7 +329,8 @@ void mlir::function_interface_impl::printFunctionOp(OpAsmPrinter &p,
ArrayRef<Type> argTypes = op.getArgumentTypes();
ArrayRef<Type> resultTypes = op.getResultTypes();
printFunctionSignature(p, op, argTypes, isVariadic, resultTypes);
- printFunctionAttributes(p, op, {visibilityAttrName, typeAttrName});
+ printFunctionAttributes(p, op, argTypes.size(), resultTypes.size(),
+ {visibilityAttrName});
// Print the body if this is not an external function.
Region &body = op->getRegion(0);
if (!body.empty()) {
diff --git a/mlir/lib/IR/FunctionInterfaces.cpp b/mlir/lib/IR/FunctionInterfaces.cpp
index 9ba830366056c..3331aefc76d13 100644
--- a/mlir/lib/IR/FunctionInterfaces.cpp
+++ b/mlir/lib/IR/FunctionInterfaces.cpp
@@ -112,7 +112,7 @@ void mlir::function_interface_impl::setAllResultAttrDicts(
}
void mlir::function_interface_impl::insertFunctionArguments(
- FunctionOpInterface op, ArrayRef<unsigned> argIndices, TypeRange argTypes,
+ Operation *op, ArrayRef<unsigned> argIndices, TypeRange argTypes,
ArrayRef<DictionaryAttr> argAttrs, ArrayRef<Location> argLocs,
unsigned originalNumArgs, Type newType) {
assert(argIndices.size() == argTypes.size());
@@ -152,15 +152,15 @@ void mlir::function_interface_impl::insertFunctionArguments(
}
// Update the function type and any entry block arguments.
- op.setFunctionTypeAttr(TypeAttr::get(newType));
+ op->setAttr(getTypeAttrName(), TypeAttr::get(newType));
for (unsigned i = 0, e = argIndices.size(); i < e; ++i)
entry.insertArgument(argIndices[i] + i, argTypes[i], argLocs[i]);
}
void mlir::function_interface_impl::insertFunctionResults(
- FunctionOpInterface op, ArrayRef<unsigned> resultIndices,
- TypeRange resultTypes, ArrayRef<DictionaryAttr> resultAttrs,
- unsigned originalNumResults, Type newType) {
+ Operation *op, ArrayRef<unsigned> resultIndices, TypeRange resultTypes,
+ ArrayRef<DictionaryAttr> resultAttrs, unsigned originalNumResults,
+ Type newType) {
assert(resultIndices.size() == resultTypes.size());
assert(resultIndices.size() == resultAttrs.size() || resultAttrs.empty());
if (resultIndices.empty())
@@ -196,11 +196,11 @@ void mlir::function_interface_impl::insertFunctionResults(
}
// Update the function type.
- op.setFunctionTypeAttr(TypeAttr::get(newType));
+ op->setAttr(getTypeAttrName(), TypeAttr::get(newType));
}
void mlir::function_interface_impl::eraseFunctionArguments(
- FunctionOpInterface op, const BitVector &argIndices, Type newType) {
+ Operation *op, const BitVector &argIndices, Type newType) {
// There are 3 things that need to be updated:
// - Function type.
// - Arg attrs.
@@ -218,12 +218,12 @@ void mlir::function_interface_impl::eraseFunctionArguments(
}
// Update the function type and any entry block arguments.
- op.setFunctionTypeAttr(TypeAttr::get(newType));
+ op->setAttr(getTypeAttrName(), TypeAttr::get(newType));
entry.eraseArguments(argIndices);
}
void mlir::function_interface_impl::eraseFunctionResults(
- FunctionOpInterface op, const BitVector &resultIndices, Type newType) {
+ Operation *op, const BitVector &resultIndices, Type newType) {
// There are 2 things that need to be updated:
// - Function type.
// - Result attrs.
@@ -239,7 +239,7 @@ void mlir::function_interface_impl::eraseFunctionResults(
}
// Update the function type.
- op.setFunctionTypeAttr(TypeAttr::get(newType));
+ op->setAttr(getTypeAttrName(), TypeAttr::get(newType));
}
TypeRange mlir::function_interface_impl::insertTypesInto(
@@ -276,13 +276,14 @@ TypeRange mlir::function_interface_impl::filterTypesOut(
// Function type signature.
//===----------------------------------------------------------------------===//
-void mlir::function_interface_impl::setFunctionType(FunctionOpInterface op,
+void mlir::function_interface_impl::setFunctionType(Operation *op,
Type newType) {
- unsigned oldNumArgs = op.getNumArguments();
- unsigned oldNumResults = op.getNumResults();
- op.setFunctionTypeAttr(TypeAttr::get(newType));
- unsigned newNumArgs = op.getNumArguments();
- unsigned newNumResults = op.getNumResults();
+ FunctionOpInterface funcOp = cast<FunctionOpInterface>(op);
+ unsigned oldNumArgs = funcOp.getNumArguments();
+ unsigned oldNumResults = funcOp.getNumResults();
+ op->setAttr(getTypeAttrName(), TypeAttr::get(newType));
+ unsigned newNumArgs = funcOp.getNumArguments();
+ unsigned newNumResults = funcOp.getNumResults();
// Functor used to update the argument and result attributes of the function.
auto updateAttrFn = [&](StringRef attrName, unsigned oldCount,
More information about the Mlir-commits mailing list