[Mlir-commits] [mlir] 8066f22 - [MLIR] Add argument insertion helpers for FunctionLike
Fabian Schuiki
llvmlistbot at llvm.org
Thu Jul 1 00:22:32 PDT 2021
Author: Fabian Schuiki
Date: 2021-07-01T09:18:57+02:00
New Revision: 8066f22c4663d9ee6c763d9108c89448e5c19848
URL: https://github.com/llvm/llvm-project/commit/8066f22c4663d9ee6c763d9108c89448e5c19848
DIFF: https://github.com/llvm/llvm-project/commit/8066f22c4663d9ee6c763d9108c89448e5c19848.diff
LOG: [MLIR] Add argument insertion helpers for FunctionLike
Add helpers to facilitate adding arguments and results to operations
that implement the `FunctionLike` trait. These operations already have a
convenient argument and result *erasure* mechanism, but a corresponding
utility for insertion is missing. This introduces such a utility.
Added:
mlir/test/IR/test-func-insert-arg.mlir
mlir/test/IR/test-func-insert-result.mlir
Modified:
mlir/include/mlir/IR/BuiltinTypes.td
mlir/include/mlir/IR/FunctionSupport.h
mlir/lib/IR/BuiltinTypes.cpp
mlir/lib/IR/FunctionSupport.cpp
mlir/test/lib/IR/TestFunc.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/IR/BuiltinTypes.td b/mlir/include/mlir/IR/BuiltinTypes.td
index 4edf72667bd6e..edbd1ea2ae912 100644
--- a/mlir/include/mlir/IR/BuiltinTypes.td
+++ b/mlir/include/mlir/IR/BuiltinTypes.td
@@ -154,6 +154,13 @@ def Builtin_Function : Builtin_Type<"Function", [
unsigned getNumResults() const;
Type getResult(unsigned i) const { return getResults()[i]; }
+ /// Returns a new function type with the specified arguments and results
+ /// inserted. `argTypes`/`resultTypes` must be as long as
+ /// `argIndices`/`resultIndices`. Each index is the position in the current
+ /// inputs (or results) before which the corresponding type is placed; an
+ /// index equal to the current size appends. Index arrays must be sorted in
+ /// increasing order, and types sharing an index keep their listed order.
+ FunctionType getWithArgsAndResults(ArrayRef<unsigned> argIndices,
+ TypeRange argTypes,
+ ArrayRef<unsigned> resultIndices,
+ TypeRange resultTypes);
+
/// Returns a new function type without the specified arguments and results.
FunctionType getWithoutArgsAndResults(ArrayRef<unsigned> argIndices,
ArrayRef<unsigned> resultIndices);
diff --git a/mlir/include/mlir/IR/FunctionSupport.h b/mlir/include/mlir/IR/FunctionSupport.h
index f27c857a175a6..c7ee9429d583e 100644
--- a/mlir/include/mlir/IR/FunctionSupport.h
+++ b/mlir/include/mlir/IR/FunctionSupport.h
@@ -68,6 +68,19 @@ inline ArrayRef<NamedAttribute> getResultAttrs(Operation *op, unsigned index) {
return resultDict ? resultDict.getValue() : llvm::None;
}
+/// Insert the specified arguments and update the function type attribute.
+///
+/// `argIndices` give, for each new argument, its position in the argument
+/// list as it was before insertion; they are expected to be sorted (the
+/// `FunctionLike::insertArguments` wrapper requires this). `argTypes` must be
+/// as long as `argIndices`; `argAttrs` and `argLocs` must either match that
+/// length or be empty. `originalNumArgs` is the argument count before
+/// insertion, and `newType` is the already-updated function type, which is
+/// stored as the op's type attribute. Entry block arguments are inserted too.
+void insertFunctionArguments(Operation *op, ArrayRef<unsigned> argIndices,
+ TypeRange argTypes,
+ ArrayRef<DictionaryAttr> argAttrs,
+ ArrayRef<Optional<Location>> argLocs,
+ unsigned originalNumArgs, Type newType);
+
+/// Insert the specified results and update the function type attribute.
+///
+/// `resultIndices` give, for each new result, its position in the result list
+/// as it was before insertion; they are expected to be sorted. `resultTypes`
+/// must be as long as `resultIndices`; `resultAttrs` must either match that
+/// length or be empty. `originalNumResults` is the result count before
+/// insertion, and `newType` is the already-updated function type, which is
+/// stored as the op's type attribute.
+void insertFunctionResults(Operation *op, ArrayRef<unsigned> resultIndices,
+ TypeRange resultTypes,
+ ArrayRef<DictionaryAttr> resultAttrs,
+ unsigned originalNumResults, Type newType);
+
/// Erase the specified arguments and update the function type attribute.
void eraseFunctionArguments(Operation *op, ArrayRef<unsigned> argIndices,
unsigned originalNumArgs, Type newType);
@@ -208,6 +221,22 @@ class FunctionLike : public OpTrait::TraitBase<ConcreteType, FunctionLike> {
return function_like_impl::getFunctionType(this->getOperation());
}
+  /// Computes this function's signature after placing `argTypes` at
+  /// `argIndices` among the inputs and `resultTypes` at `resultIndices` among
+  /// the results. The operation itself is left untouched; `insertArguments`
+  /// and `insertResults` use the returned type as the new signature. Both index
+  /// arrays must be sorted in increasing order.
+  ///
+  /// This default operates on FunctionType. A concrete op whose function type
+  /// is something other than FunctionType must declare a method with this
+  /// same name that hides this one.
+  FunctionType getTypeWithArgsAndResults(ArrayRef<unsigned> argIndices,
+                                         TypeRange argTypes,
+                                         ArrayRef<unsigned> resultIndices,
+                                         TypeRange resultTypes) {
+    auto currentType = getType();
+    return currentType.getWithArgsAndResults(argIndices, argTypes,
+                                             resultIndices, resultTypes);
+  }
+
/// Return the type of this function without the specified arguments and
/// results. This is used to update the function's signature in the
/// `eraseArguments` and `eraseResults` methods. The arrays of indices are
@@ -267,6 +296,48 @@ class FunctionLike : public OpTrait::TraitBase<ConcreteType, FunctionLike> {
return getBody().getArgumentTypes();
}
+  /// Single-argument convenience form of `insertArguments`: adds an argument
+  /// of type `argType` at position `argIndex`, with attribute dictionary
+  /// `argAttrs` and optional location `argLoc`.
+  void insertArgument(unsigned argIndex, Type argType, DictionaryAttr argAttrs,
+                      Optional<Location> argLoc = {}) {
+    // Each one-element list views the corresponding parameter, which stays
+    // alive for the duration of the call.
+    insertArguments(ArrayRef<unsigned>(argIndex),
+                    TypeRange(ArrayRef<Type>(argType)),
+                    ArrayRef<DictionaryAttr>(argAttrs),
+                    ArrayRef<Optional<Location>>(argLoc));
+  }
+
+  /// Inserts arguments with the listed types, attributes, and locations at the
+  /// listed indices. `argIndices` must be sorted and refer to positions in the
+  /// argument list as it was before this call. Arguments are inserted in the
+  /// order they are listed, such that arguments with identical index will
+  /// appear in the same order that they were listed here. `argAttrs` and
+  /// `argLocs` may be empty, in which case the new arguments get no attribute
+  /// dictionary and no explicit location.
+  ///
+  /// The new signature is obtained from the concrete op's
+  /// `getTypeWithArgsAndResults`, so ops that provide their own version of that
+  /// method (e.g. ops not backed by FunctionType) are honored.
+  void insertArguments(ArrayRef<unsigned> argIndices, TypeRange argTypes,
+                       ArrayRef<DictionaryAttr> argAttrs,
+                       ArrayRef<Optional<Location>> argLocs) {
+    unsigned originalNumArgs = getNumArguments();
+    // An unqualified call here would always bind to this trait's
+    // FunctionType-based method, bypassing any hiding definition in the
+    // concrete op; dispatch through ConcreteType instead.
+    Type newType =
+        static_cast<ConcreteType *>(this)->getTypeWithArgsAndResults(
+            argIndices, argTypes, /*resultIndices=*/{}, /*resultTypes=*/{});
+    function_like_impl::insertFunctionArguments(
+        this->getOperation(), argIndices, argTypes, argAttrs, argLocs,
+        originalNumArgs, newType);
+  }
+
+  /// Insert a single result of type `resultType` with attribute dictionary
+  /// `resultAttrs` at `resultIndex`. Convenience form of `insertResults`.
+  void insertResult(unsigned resultIndex, Type resultType,
+                    DictionaryAttr resultAttrs) {
+    insertResults({resultIndex}, {resultType}, {resultAttrs});
+  }
+
+  /// Inserts results with the listed types and attribute dictionaries at the
+  /// listed indices. `resultIndices` must be sorted and refer to positions in
+  /// the result list as it was before this call. Results are inserted in the
+  /// order they are listed, such that results with identical index will appear
+  /// in the same order that they were listed here. `resultAttrs` may be empty,
+  /// in which case the new results get no attribute dictionary.
+  ///
+  /// The new signature is obtained from the concrete op's
+  /// `getTypeWithArgsAndResults`, so ops that provide their own version of that
+  /// method (e.g. ops not backed by FunctionType) are honored.
+  void insertResults(ArrayRef<unsigned> resultIndices, TypeRange resultTypes,
+                     ArrayRef<DictionaryAttr> resultAttrs) {
+    unsigned originalNumResults = getNumResults();
+    // An unqualified call here would always bind to this trait's
+    // FunctionType-based method, bypassing any hiding definition in the
+    // concrete op; dispatch through ConcreteType instead.
+    Type newType =
+        static_cast<ConcreteType *>(this)->getTypeWithArgsAndResults(
+            /*argIndices=*/{}, /*argTypes=*/{}, resultIndices, resultTypes);
+    function_like_impl::insertFunctionResults(
+        this->getOperation(), resultIndices, resultTypes, resultAttrs,
+        originalNumResults, newType);
+  }
+
/// Erase a single argument at `argIndex`.
void eraseArgument(unsigned argIndex) { eraseArguments({argIndex}); }
diff --git a/mlir/lib/IR/BuiltinTypes.cpp b/mlir/lib/IR/BuiltinTypes.cpp
index d5fd1eadbb69f..f350596384a90 100644
--- a/mlir/lib/IR/BuiltinTypes.cpp
+++ b/mlir/lib/IR/BuiltinTypes.cpp
@@ -172,6 +172,45 @@ inline void iterateIndicesExcept(unsigned totalIndices,
callback(i);
}
+/// Returns `types` with each entry of `newTypes` placed before the element at
+/// the corresponding position in `indices`. Indices refer to positions in the
+/// original `types`, must be sorted, and may repeat (an index equal to
+/// `types.size()` appends); entries sharing an index keep their listed order.
+/// When something is inserted the result is built in `storage`, which must
+/// therefore outlive the returned range; otherwise `types` is returned as is.
+static ArrayRef<Type> insertTypesAt(ArrayRef<Type> types,
+                                    ArrayRef<unsigned> indices,
+                                    TypeRange newTypes,
+                                    SmallVectorImpl<Type> &storage) {
+  assert(indices.size() == newTypes.size());
+  if (indices.empty())
+    return types;
+  // Unsorted or out-of-range indices would form reversed / out-of-bounds
+  // iterator ranges below, which is undefined behavior.
+  assert(llvm::is_sorted(indices) && "insertion indices must be sorted");
+  assert(indices.back() <= types.size() && "insertion index out of range");
+
+  storage.reserve(types.size() + indices.size());
+  unsigned copiedUpTo = 0;
+  for (auto it : llvm::zip(indices, newTypes)) {
+    unsigned index = std::get<0>(it);
+    storage.append(types.begin() + copiedUpTo, types.begin() + index);
+    storage.push_back(std::get<1>(it));
+    copiedUpTo = index;
+  }
+  storage.append(types.begin() + copiedUpTo, types.end());
+  return storage;
+}
+
+/// Returns a new function type with the specified arguments and results
+/// inserted. Indices refer to positions in the current inputs/results and must
+/// be sorted; types sharing an index appear in the order they are listed.
+FunctionType FunctionType::getWithArgsAndResults(
+    ArrayRef<unsigned> argIndices, TypeRange argTypes,
+    ArrayRef<unsigned> resultIndices, TypeRange resultTypes) {
+  assert(argIndices.size() == argTypes.size());
+  assert(resultIndices.size() == resultTypes.size());
+
+  SmallVector<Type, 4> inputStorage, resultStorage;
+  ArrayRef<Type> newInputTypes =
+      insertTypesAt(getInputs(), argIndices, argTypes, inputStorage);
+  ArrayRef<Type> newResultTypes =
+      insertTypesAt(getResults(), resultIndices, resultTypes, resultStorage);
+  return FunctionType::get(getContext(), newInputTypes, newResultTypes);
+}
+
/// Returns a new function type without the specified arguments and results.
FunctionType
FunctionType::getWithoutArgsAndResults(ArrayRef<unsigned> argIndices,
diff --git a/mlir/lib/IR/FunctionSupport.cpp b/mlir/lib/IR/FunctionSupport.cpp
index b8a0ebc3f4a53..4f6f76cfbcfb9 100644
--- a/mlir/lib/IR/FunctionSupport.cpp
+++ b/mlir/lib/IR/FunctionSupport.cpp
@@ -121,6 +121,95 @@ void mlir::function_like_impl::setAllResultAttrDicts(
llvm::to_vector<8>(wrappedAttrs));
}
+void mlir::function_like_impl::insertFunctionArguments(
+    Operation *op, ArrayRef<unsigned> argIndices, TypeRange argTypes,
+    ArrayRef<DictionaryAttr> argAttrs, ArrayRef<Optional<Location>> argLocs,
+    unsigned originalNumArgs, Type newType) {
+  assert(argIndices.size() == argTypes.size());
+  assert(argIndices.size() == argAttrs.size() || argAttrs.empty());
+  assert(argIndices.size() == argLocs.size() || argLocs.empty());
+  if (argIndices.empty())
+    return;
+  // Indices are positions in the original argument list; both the attribute
+  // merge and the entry block update below rely on them being non-decreasing.
+  assert(llvm::is_sorted(argIndices) && "argument indices must be sorted");
+  assert(argIndices.back() <= originalNumArgs &&
+         "argument index out of range");
+
+  // There are 3 things that need to be updated:
+  // - Function type.
+  // - Arg attrs.
+  // - Block arguments of the entry block, if the function has a body.
+
+  // Update the argument attributes of the function. The attribute array is
+  // only (re)built if one already exists or attributes were supplied.
+  auto oldArgAttrs = op->getAttrOfType<ArrayAttr>(getArgDictAttrName());
+  if (oldArgAttrs || !argAttrs.empty()) {
+    SmallVector<DictionaryAttr, 4> newArgAttrs;
+    newArgAttrs.reserve(originalNumArgs + argIndices.size());
+    unsigned oldIdx = 0;
+    // Copy the original dictionaries in [oldIdx, untilIdx), or null
+    // placeholders when the op had no argument attribute array.
+    auto migrate = [&](unsigned untilIdx) {
+      if (!oldArgAttrs) {
+        newArgAttrs.resize(newArgAttrs.size() + untilIdx - oldIdx);
+      } else {
+        auto oldArgAttrRange = oldArgAttrs.getAsRange<DictionaryAttr>();
+        newArgAttrs.append(oldArgAttrRange.begin() + oldIdx,
+                           oldArgAttrRange.begin() + untilIdx);
+      }
+      oldIdx = untilIdx;
+    };
+    for (unsigned i = 0, e = argIndices.size(); i < e; ++i) {
+      migrate(argIndices[i]);
+      newArgAttrs.push_back(argAttrs.empty() ? DictionaryAttr{} : argAttrs[i]);
+    }
+    migrate(originalNumArgs);
+    setAllArgAttrDicts(op, newArgAttrs);
+  }
+
+  // Update the function type.
+  op->setAttr(getTypeAttrName(), TypeAttr::get(newType));
+
+  // A function without a body has no entry block whose arguments need to
+  // change.
+  Region &body = op->getRegion(0);
+  if (body.empty())
+    return;
+
+  // Each index refers to the original argument list, but every earlier
+  // insertion shifts later arguments by one. Offsetting by `i` (the number of
+  // arguments already inserted) yields the same order that
+  // `FunctionType::getWithArgsAndResults` produces for these indices.
+  Block &entry = body.front();
+  for (unsigned i = 0, e = argIndices.size(); i < e; ++i)
+    entry.insertArgument(argIndices[i] + i, argTypes[i],
+                         argLocs.empty() ? Optional<Location>{} : argLocs[i]);
+}
+
+void mlir::function_like_impl::insertFunctionResults(
+    Operation *op, ArrayRef<unsigned> resultIndices, TypeRange resultTypes,
+    ArrayRef<DictionaryAttr> resultAttrs, unsigned originalNumResults,
+    Type newType) {
+  assert(resultIndices.size() == resultTypes.size());
+  assert(resultIndices.size() == resultAttrs.size() || resultAttrs.empty());
+  if (resultIndices.empty())
+    return;
+
+  // Two pieces of state are updated: the per-result attribute dictionaries
+  // and the function type attribute.
+
+  // The result attribute array is rebuilt only if the op already has one or
+  // the caller supplied dictionaries for the new results.
+  ArrayAttr existingAttrs =
+      op->getAttrOfType<ArrayAttr>(getResultDictAttrName());
+  if (existingAttrs || !resultAttrs.empty()) {
+    SmallVector<DictionaryAttr, 4> mergedAttrs;
+    mergedAttrs.reserve(originalNumResults + resultIndices.size());
+
+    // Appends the dictionaries of original results [first, last); when the op
+    // had no attribute array, null placeholders are appended instead.
+    auto copyOriginal = [&](unsigned first, unsigned last) {
+      if (!existingAttrs) {
+        mergedAttrs.resize(mergedAttrs.size() + last - first);
+        return;
+      }
+      auto originalRange = existingAttrs.getAsRange<DictionaryAttr>();
+      mergedAttrs.append(originalRange.begin() + first,
+                         originalRange.begin() + last);
+    };
+
+    unsigned nextOriginal = 0;
+    for (unsigned pos = 0, count = resultIndices.size(); pos != count; ++pos) {
+      unsigned insertAt = resultIndices[pos];
+      copyOriginal(nextOriginal, insertAt);
+      nextOriginal = insertAt;
+      mergedAttrs.push_back(resultAttrs.empty() ? DictionaryAttr()
+                                                : resultAttrs[pos]);
+    }
+    copyOriginal(nextOriginal, originalNumResults);
+    setAllResultAttrDicts(op, mergedAttrs);
+  }
+
+  op->setAttr(getTypeAttrName(), TypeAttr::get(newType));
+}
+
void mlir::function_like_impl::eraseFunctionArguments(
Operation *op, ArrayRef<unsigned> argIndices, unsigned originalNumArgs,
Type newType) {
diff --git a/mlir/test/IR/test-func-insert-arg.mlir b/mlir/test/IR/test-func-insert-arg.mlir
new file mode 100644
index 0000000000000..2de6c666d0d31
--- /dev/null
+++ b/mlir/test/IR/test-func-insert-arg.mlir
@@ -0,0 +1,49 @@
+// RUN: mlir-opt %s -test-func-insert-arg -split-input-file | FileCheck %s
+
+// CHECK: func @f(%arg0: f32 {test.A})
+func @f() attributes {test.insert_args = [
+ [0, f32, {test.A}]]} {
+ return
+}
+
+// -----
+
+// CHECK: func @f(%arg0: f32 {test.A}, %arg1: i32 {test.B})
+func @f(%arg0: i32 {test.B}) attributes {test.insert_args = [
+ [0, f32, {test.A}]]} {
+ return
+}
+
+// -----
+
+// CHECK: func @f(%arg0: f32 {test.A}, %arg1: i32 {test.B})
+func @f(%arg0: f32 {test.A}) attributes {test.insert_args = [
+ [1, i32, {test.B}]]} {
+ return
+}
+
+// -----
+
+// CHECK: func @f(%arg0: f32 {test.A}, %arg1: i32 {test.B}, %arg2: i64 {test.C})
+func @f(%arg0: f32 {test.A}, %arg1: i64 {test.C}) attributes {test.insert_args = [
+ [1, i32, {test.B}]]} {
+ return
+}
+
+// -----
+
+// CHECK: func @f(%arg0: f32 {test.A}, %arg1: i32 {test.B}, %arg2: i64 {test.C})
+func @f(%arg0: i32 {test.B}) attributes {test.insert_args = [
+ [0, f32, {test.A}],
+ [1, i64, {test.C}]]} {
+ return
+}
+
+// -----
+
+// CHECK: func @f(%arg0: f32 {test.A}, %arg1: i32 {test.B}, %arg2: i64 {test.C})
+func @f(%arg0: i64 {test.C}) attributes {test.insert_args = [
+ [0, f32, {test.A}],
+ [0, i32, {test.B}]]} {
+ return
+}
diff --git a/mlir/test/IR/test-func-insert-result.mlir b/mlir/test/IR/test-func-insert-result.mlir
new file mode 100644
index 0000000000000..129fff4b56fe9
--- /dev/null
+++ b/mlir/test/IR/test-func-insert-result.mlir
@@ -0,0 +1,37 @@
+// RUN: mlir-opt %s -test-func-insert-result -split-input-file | FileCheck %s
+
+// CHECK: func private @f() -> (f32 {test.A})
+func private @f() attributes {test.insert_results = [
+ [0, f32, {test.A}]]}
+
+// -----
+
+// CHECK: func private @f() -> (f32 {test.A}, i32 {test.B})
+func private @f() -> (i32 {test.B}) attributes {test.insert_results = [
+ [0, f32, {test.A}]]}
+
+// -----
+
+// CHECK: func private @f() -> (f32 {test.A}, i32 {test.B})
+func private @f() -> (f32 {test.A}) attributes {test.insert_results = [
+ [1, i32, {test.B}]]}
+
+// -----
+
+// CHECK: func private @f() -> (f32 {test.A}, i32 {test.B}, i64 {test.C})
+func private @f() -> (f32 {test.A}, i64 {test.C}) attributes {test.insert_results = [
+ [1, i32, {test.B}]]}
+
+// -----
+
+// CHECK: func private @f() -> (f32 {test.A}, i32 {test.B}, i64 {test.C})
+func private @f() -> (i32 {test.B}) attributes {test.insert_results = [
+ [0, f32, {test.A}],
+ [1, i64, {test.C}]]}
+
+// -----
+
+// CHECK: func private @f() -> (f32 {test.A}, i32 {test.B}, i64 {test.C})
+func private @f() -> (i64 {test.C}) attributes {test.insert_results = [
+ [0, f32, {test.A}],
+ [0, i32, {test.B}]]}
diff --git a/mlir/test/lib/IR/TestFunc.cpp b/mlir/test/lib/IR/TestFunc.cpp
index 4f2b45628ed99..a2d0c796c66b5 100644
--- a/mlir/test/lib/IR/TestFunc.cpp
+++ b/mlir/test/lib/IR/TestFunc.cpp
@@ -12,6 +12,72 @@
using namespace mlir;
namespace {
+/// Test pass exercising FuncOp::insertArguments. For every function carrying
+/// a non-empty `test.insert_args` array attribute, that attribute is removed
+/// and the arguments it describes are inserted. Each entry is `[index, type]`,
+/// optionally followed by an attribute dictionary (an empty dictionary is used
+/// when absent) and then a location.
+struct TestFuncInsertArg
+    : public PassWrapper<TestFuncInsertArg, OperationPass<ModuleOp>> {
+  StringRef getArgument() const final { return "test-func-insert-arg"; }
+  StringRef getDescription() const final { return "Test inserting func args."; }
+  void runOnOperation() override {
+    ModuleOp module = getOperation();
+    for (FuncOp func : module.getOps<FuncOp>()) {
+      ArrayAttr specs = func->getAttrOfType<ArrayAttr>("test.insert_args");
+      if (!specs || specs.empty())
+        continue;
+
+      SmallVector<unsigned, 4> indices;
+      SmallVector<Type, 4> types;
+      SmallVector<DictionaryAttr, 4> attrs;
+      SmallVector<Optional<Location>, 4> locs;
+      for (ArrayAttr spec : specs.getAsRange<ArrayAttr>()) {
+        indices.push_back(
+            spec[0].cast<IntegerAttr>().getValue().getZExtValue());
+        types.push_back(spec[1].cast<TypeAttr>().getValue());
+        if (spec.size() > 2)
+          attrs.push_back(spec[2].cast<DictionaryAttr>());
+        else
+          attrs.push_back(DictionaryAttr::get(&getContext()));
+        if (spec.size() > 3)
+          locs.push_back(Optional<Location>(spec[3].cast<LocationAttr>()));
+        else
+          locs.push_back(Optional<Location>());
+      }
+      func->removeAttr("test.insert_args");
+      func.insertArguments(indices, types, attrs, locs);
+    }
+  }
+};
+
+/// Test pass exercising FuncOp::insertResults. For every function carrying a
+/// non-empty `test.insert_results` array attribute, that attribute is removed
+/// and the results it describes are inserted. Each entry is `[index, type]`,
+/// optionally followed by an attribute dictionary (an empty dictionary is used
+/// when absent).
+struct TestFuncInsertResult
+    : public PassWrapper<TestFuncInsertResult, OperationPass<ModuleOp>> {
+  StringRef getArgument() const final { return "test-func-insert-result"; }
+  StringRef getDescription() const final {
+    return "Test inserting func results.";
+  }
+  void runOnOperation() override {
+    ModuleOp module = getOperation();
+    for (FuncOp func : module.getOps<FuncOp>()) {
+      ArrayAttr specs = func->getAttrOfType<ArrayAttr>("test.insert_results");
+      if (!specs || specs.empty())
+        continue;
+
+      SmallVector<unsigned, 4> indices;
+      SmallVector<Type, 4> types;
+      SmallVector<DictionaryAttr, 4> attrs;
+      for (ArrayAttr spec : specs.getAsRange<ArrayAttr>()) {
+        indices.push_back(
+            spec[0].cast<IntegerAttr>().getValue().getZExtValue());
+        types.push_back(spec[1].cast<TypeAttr>().getValue());
+        if (spec.size() > 2)
+          attrs.push_back(spec[2].cast<DictionaryAttr>());
+        else
+          attrs.push_back(DictionaryAttr::get(&getContext()));
+      }
+      func->removeAttr("test.insert_results");
+      func.insertResults(indices, types, attrs);
+    }
+  }
+};
+
/// This is a test pass for verifying FuncOp's eraseArgument method.
struct TestFuncEraseArg
: public PassWrapper<TestFuncEraseArg, OperationPass<ModuleOp>> {
@@ -51,18 +117,15 @@ struct TestFuncEraseResult
for (FuncOp func : module.getOps<FuncOp>()) {
SmallVector<unsigned, 4> indicesToErase;
for (auto resultIndex : llvm::seq<int>(0, func.getNumResults())) {
- if (func.getResultAttr(resultIndex, "test.erase_this_"
- "result")) {
- // Push back twice to test
- // that duplicate indices
- // are handled correctly.
+ if (func.getResultAttr(resultIndex, "test.erase_this_result")) {
+ // Push back twice to test that duplicate indices are handled
+ // correctly.
indicesToErase.push_back(resultIndex);
indicesToErase.push_back(resultIndex);
}
}
- // Reverse the order to test
- // that unsorted index lists are
- // handled correctly.
+ // Reverse the order to test that unsorted index lists are handled
+ // correctly.
std::reverse(indicesToErase.begin(), indicesToErase.end());
func.eraseResults(indicesToErase);
}
@@ -90,6 +153,10 @@ struct TestFuncSetType
namespace mlir {
void registerTestFunc() {
+ PassRegistration<TestFuncInsertArg>();
+
+ PassRegistration<TestFuncInsertResult>();
+
PassRegistration<TestFuncEraseArg>();
PassRegistration<TestFuncEraseResult>();
More information about the Mlir-commits
mailing list