[Mlir-commits] [mlir] 2e36e0d - [MLIR] Move eraseArguments and eraseResults to FunctionLike
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Tue Nov 3 15:53:53 PST 2020
Author: mikeurbach
Date: 2020-11-03T16:53:46-07:00
New Revision: 2e36e0dad52b07f0e856f939d530d47bbe8a74ac
URL: https://github.com/llvm/llvm-project/commit/2e36e0dad52b07f0e856f939d530d47bbe8a74ac
DIFF: https://github.com/llvm/llvm-project/commit/2e36e0dad52b07f0e856f939d530d47bbe8a74ac.diff
LOG: [MLIR] Move eraseArguments and eraseResults to FunctionLike
Previously, they were only defined for `FuncOp`.
To support this, `FunctionLike` needs a way to get an updated type
from the concrete operation. This adds a new hook for that purpose,
called `getTypeWithoutArgsAndResults`.
For now, `FunctionLike` continues to assume the type is
`FunctionType`, and concrete operations that use another type can hide
the `getType`, `setType`, and `getTypeWithoutArgsAndResults` methods.
Reviewed By: rriddle
Differential Revision: https://reviews.llvm.org/D90363
Added:
mlir/lib/IR/FunctionSupport.cpp
Modified:
mlir/docs/Traits.md
mlir/include/mlir/IR/Block.h
mlir/include/mlir/IR/Function.h
mlir/include/mlir/IR/FunctionSupport.h
mlir/include/mlir/IR/Types.h
mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
mlir/lib/IR/Block.cpp
mlir/lib/IR/CMakeLists.txt
mlir/lib/IR/Function.cpp
mlir/lib/IR/Types.cpp
Removed:
################################################################################
diff --git a/mlir/docs/Traits.md b/mlir/docs/Traits.md
index 50c67c7efa25..c6db640f7f5f 100644
--- a/mlir/docs/Traits.md
+++ b/mlir/docs/Traits.md
@@ -255,11 +255,17 @@ particular:
- they can have argument and result attributes that are stored in dictionary
attributes on the operation itself.
-This trait does *NOT* provide type support for the functions, meaning that
-concrete Ops must handle the type of the declared or defined function.
-`getTypeAttrName()` is a convenience function that returns the name of the
-attribute that can be used to store the function type, but the trait makes no
-assumption based on it.
+This trait provides limited type support for the declared or defined functions.
+The convenience function `getTypeAttrName()` returns the name of an attribute
+that can be used to store the function type. In addition, this trait provides
+`getType` and `setType` helpers to store a `FunctionType` in the attribute named
+by `getTypeAttrName()`.
+
+In general, this trait assumes concrete ops use `FunctionType` under the hood.
+If this is not the case, in order to use the function type support, concrete ops
+must define the following methods, using the same name, to hide the ones defined
+for `FunctionType`: `addBodyBlock`, `getType`, `getTypeWithoutArgsAndResults`
+and `setType`.
### HasParent
diff --git a/mlir/include/mlir/IR/Block.h b/mlir/include/mlir/IR/Block.h
index ca2523050b24..3e867976cc32 100644
--- a/mlir/include/mlir/IR/Block.h
+++ b/mlir/include/mlir/IR/Block.h
@@ -16,6 +16,10 @@
#include "mlir/IR/BlockSupport.h"
#include "mlir/IR/Visitors.h"
+namespace llvm {
+class BitVector;
+} // end namespace llvm
+
namespace mlir {
class TypeRange;
template <typename ValueRangeT> class ValueTypeRange;
@@ -98,6 +102,13 @@ class Block : public IRObjectWithUseList<BlockOperand>,
/// Erase the argument at 'index' and remove it from the argument list.
void eraseArgument(unsigned index);
+ /// Erases the arguments listed in `argIndices` and removes them from the
+ /// argument list.
+ /// `argIndices` is allowed to have duplicates and can be in any order; each
+ /// index refers to an argument position before any erasure takes place.
+ void eraseArguments(ArrayRef<unsigned> argIndices);
+ /// Erases the arguments that have their corresponding bit set in
+ /// `eraseIndices` and removes them from the argument list.
+ /// Bit `i` refers to argument position `i` before any erasure; only bits in
+ /// the range [0, getNumArguments()) are consulted.
+ void eraseArguments(llvm::BitVector eraseIndices);
unsigned getNumArguments() { return arguments.size(); }
BlockArgument getArgument(unsigned i) { return arguments[i]; }
diff --git a/mlir/include/mlir/IR/Function.h b/mlir/include/mlir/IR/Function.h
index 5c57b754828c..3ea546cd52a3 100644
--- a/mlir/include/mlir/IR/Function.h
+++ b/mlir/include/mlir/IR/Function.h
@@ -59,18 +59,6 @@ class FuncOp
void print(OpAsmPrinter &p);
LogicalResult verify();
- /// Erase a single argument at `argIndex`.
- void eraseArgument(unsigned argIndex) { eraseArguments({argIndex}); }
- /// Erases the arguments listed in `argIndices`.
- /// `argIndices` is allowed to have duplicates and can be in any order.
- void eraseArguments(ArrayRef<unsigned> argIndices);
-
- /// Erase a single result at `resultIndex`.
- void eraseResult(unsigned resultIndex) { eraseResults({resultIndex}); }
- /// Erases the results listed in `resultIndices`.
- /// `resultIndices` is allowed to have duplicates and can be in any order.
- void eraseResults(ArrayRef<unsigned> resultIndices);
-
/// Create a deep copy of this function and all of its blocks, remapping
/// any operands that use values outside of the function using the map that is
/// provided (leaving them alone if no entry is present). If the mapper
diff --git a/mlir/include/mlir/IR/FunctionSupport.h b/mlir/include/mlir/IR/FunctionSupport.h
index 7756761c2a52..524b58e37210 100644
--- a/mlir/include/mlir/IR/FunctionSupport.h
+++ b/mlir/include/mlir/IR/FunctionSupport.h
@@ -71,6 +71,14 @@ inline ArrayRef<NamedAttribute> getResultAttrs(Operation *op, unsigned index) {
return resultDict ? resultDict.getValue() : llvm::None;
}
+/// Erase the arguments of the function-like `op` listed in `argIndices` (which
+/// may contain duplicates and be in any order) and update the function type
+/// attribute.
+/// `originalNumArgs` is the number of function arguments before erasure, and
+/// `newType` is the already-computed updated function type, which is stored
+/// as a TypeAttr under `getTypeAttrName()`. The argument attribute
+/// dictionaries are shifted to line up with the remaining arguments, and the
+/// corresponding arguments of the entry block are erased.
+void eraseFunctionArguments(Operation *op, ArrayRef<unsigned> argIndices,
+ unsigned originalNumArgs, Type newType);
+
+/// Erase the results of the function-like `op` listed in `resultIndices`
+/// (which may contain duplicates and be in any order) and update the function
+/// type attribute.
+/// `originalNumResults` is the number of function results before erasure, and
+/// `newType` is the already-computed updated function type, which is stored
+/// as a TypeAttr under `getTypeAttrName()`. The result attribute dictionaries
+/// are shifted to line up with the remaining results.
+void eraseFunctionResults(Operation *op, ArrayRef<unsigned> resultIndices,
+ unsigned originalNumResults, Type newType);
+
} // namespace impl
namespace OpTrait {
@@ -84,12 +92,21 @@ namespace OpTrait {
/// arguments;
/// - they can have argument attributes that are stored in a dictionary
/// attribute on the Op itself.
-/// This trait does *NOT* provide type support for the functions, meaning that
-/// concrete Ops must handle the type of the declared or defined function.
-/// `getTypeAttrName()` is a convenience function that returns the name of the
-/// attribute that can be used to store the function type, but the trait makes
-/// no assumption based on it.
///
+/// This trait provides limited type support for the declared or defined
+/// functions. The convenience function `getTypeAttrName()` returns the name of
+/// an attribute that can be used to store the function type. In addition, this
+/// trait provides `getType` and `setType` helpers to store a `FunctionType` in
+/// the attribute named by `getTypeAttrName()`.
+///
+/// In general, this trait assumes concrete ops use `FunctionType` under the
+/// hood. If this is not the case, in order to use the function type support,
+/// concrete ops must define the following methods, using the same name, to hide
+/// the ones defined for `FunctionType`: `addBodyBlock`, `getType`,
+/// `getTypeWithoutArgsAndResults` and `setType`.
+///
+/// Besides the requirements above, concrete ops interact with this trait
+/// through the following functions and hooks:
/// - Concrete ops *must* define a member function `getNumFuncArguments()` that
/// returns the number of function arguments based exclusively on type (so
/// that it can be called on function declarations).
@@ -183,6 +200,19 @@ class FunctionLike : public OpTrait::TraitBase<ConcreteType, FunctionLike> {
return getTypeAttr().getValue().template cast<FunctionType>();
}
+  /// Computes the function type that results from dropping the arguments at
+  /// `argIndices` and the results at `resultIndices` from the current type.
+  /// Both index lists may contain duplicates and may be given in any order.
+  /// `eraseArguments` and `eraseResults` use this to compute the updated
+  /// signature.
+  ///
+  /// Concrete ops whose function type is not a `FunctionType` must provide
+  /// their own method with this exact name to hide this implementation.
+  FunctionType getTypeWithoutArgsAndResults(ArrayRef<unsigned> argIndices,
+                                            ArrayRef<unsigned> resultIndices) {
+    FunctionType currentType = getType();
+    return currentType.getWithoutArgsAndResults(argIndices, resultIndices);
+  }
+
bool isTypeAttrValid() {
auto typeAttr = getTypeAttr();
if (!typeAttr)
@@ -204,7 +234,7 @@ class FunctionLike : public OpTrait::TraitBase<ConcreteType, FunctionLike> {
void setType(FunctionType newType);
//===--------------------------------------------------------------------===//
- // Argument Handling
+ // Argument and Result Handling
//===--------------------------------------------------------------------===//
using BlockArgListType = Region::BlockArgListType;
@@ -229,6 +259,30 @@ class FunctionLike : public OpTrait::TraitBase<ConcreteType, FunctionLike> {
return getBody().getArgumentTypes();
}
+  /// Erase a single argument at `argIndex`.
+  void eraseArgument(unsigned argIndex) { eraseArguments({argIndex}); }
+
+  /// Erases the arguments listed in `argIndices`, updating the function type,
+  /// the argument attributes and the entry block arguments.
+  /// `argIndices` is allowed to have duplicates and can be in any order.
+  void eraseArguments(ArrayRef<unsigned> argIndices) {
+    unsigned originalNumArgs = getNumArguments();
+    // Dispatch through the concrete op (CRTP) so that an op which hides
+    // `getTypeWithoutArgsAndResults` (because it does not use `FunctionType`)
+    // gets its own version called. An unqualified call here would always bind
+    // to the FunctionLike member, which casts the type attribute to
+    // `FunctionType`.
+    Type newType =
+        static_cast<ConcreteType *>(this)->getTypeWithoutArgsAndResults(
+            argIndices, /*resultIndices=*/{});
+    ::mlir::impl::eraseFunctionArguments(this->getOperation(), argIndices,
+                                         originalNumArgs, newType);
+  }
+
+  /// Erase a single result at `resultIndex`.
+  void eraseResult(unsigned resultIndex) { eraseResults({resultIndex}); }
+
+  /// Erases the results listed in `resultIndices`, updating the function type
+  /// and the result attributes.
+  /// `resultIndices` is allowed to have duplicates and can be in any order.
+  void eraseResults(ArrayRef<unsigned> resultIndices) {
+    unsigned originalNumResults = getNumResults();
+    // Dispatch through the concrete op (CRTP) so that an op which hides
+    // `getTypeWithoutArgsAndResults` is honored; an unqualified call would
+    // always bind to the FunctionLike member, which assumes `FunctionType`.
+    Type newType =
+        static_cast<ConcreteType *>(this)->getTypeWithoutArgsAndResults(
+            /*argIndices=*/{}, resultIndices);
+    ::mlir::impl::eraseFunctionResults(this->getOperation(), resultIndices,
+                                       originalNumResults, newType);
+  }
+
//===--------------------------------------------------------------------===//
// Argument Attributes
//===--------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/IR/Types.h b/mlir/include/mlir/IR/Types.h
index ad7e436068bc..09f49d8f7b4f 100644
--- a/mlir/include/mlir/IR/Types.h
+++ b/mlir/include/mlir/IR/Types.h
@@ -238,15 +238,19 @@ class FunctionType
static FunctionType get(TypeRange inputs, TypeRange results,
MLIRContext *context);
- // Input types.
+ /// Input types.
unsigned getNumInputs() const;
Type getInput(unsigned i) const { return getInputs()[i]; }
ArrayRef<Type> getInputs() const;
- // Result types.
+ /// Result types.
unsigned getNumResults() const;
Type getResult(unsigned i) const { return getResults()[i]; }
ArrayRef<Type> getResults() const;
+
+ /// Returns a new function type without the specified arguments and results.
+ /// `argIndices` are positions in the inputs and `resultIndices` are positions
+ /// in the results; either list may contain duplicates, be in any order, or be
+ /// empty (in which case the corresponding types are kept unchanged).
+ FunctionType getWithoutArgsAndResults(ArrayRef<unsigned> argIndices,
+ ArrayRef<unsigned> resultIndices);
};
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp b/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
index b2c649be9d92..2696c4b52d3a 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
@@ -25,8 +25,6 @@
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/Debug.h"
-#include <set>
-
#define DEBUG_TYPE "linalg-drop-unit-dims"
using namespace mlir;
@@ -166,9 +164,8 @@ LogicalResult replaceBlockArgForUnitDimLoops<IndexedGenericOp>(
for (unsigned unitDimLoop : unitDims) {
entryBlock->getArgument(unitDimLoop).replaceAllUsesWith(zero);
}
- std::set<unsigned> orderedUnitDims(unitDims.begin(), unitDims.end());
- for (unsigned i : llvm::reverse(orderedUnitDims))
- entryBlock->eraseArgument(i);
+ SmallVector<unsigned, 8> unitDimsToErase(unitDims.begin(), unitDims.end());
+ entryBlock->eraseArguments(unitDimsToErase);
return success();
}
diff --git a/mlir/lib/IR/Block.cpp b/mlir/lib/IR/Block.cpp
index e039b41ae4b7..b9ddabb80800 100644
--- a/mlir/lib/IR/Block.cpp
+++ b/mlir/lib/IR/Block.cpp
@@ -9,6 +9,7 @@
#include "mlir/IR/Block.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/Operation.h"
+#include "llvm/ADT/BitVector.h"
using namespace mlir;
//===----------------------------------------------------------------------===//
@@ -176,6 +177,22 @@ void Block::eraseArgument(unsigned index) {
arguments.erase(arguments.begin() + index);
}
+/// Erases the arguments at the given positions. The index list may be
+/// unordered and may repeat positions; it is folded into a bit mask sized to
+/// the current argument count and handed to the mask-based overload.
+void Block::eraseArguments(ArrayRef<unsigned> argIndices) {
+  llvm::BitVector mask(getNumArguments());
+  for (unsigned position : argIndices)
+    mask.set(position);
+  eraseArguments(mask);
+}
+
+/// Erases every argument whose position has its bit set in `eraseIndices`.
+void Block::eraseArguments(llvm::BitVector eraseIndices) {
+  // Walk from the last position to the first: erasing a higher position never
+  // shifts a lower one, so the remaining positions stay valid.
+  for (unsigned remaining = getNumArguments(); remaining != 0; --remaining) {
+    unsigned argNo = remaining - 1;
+    if (eraseIndices.test(argNo))
+      eraseArgument(argNo);
+  }
+}
+
/// Insert one value to the given position of the argument list. The existing
/// arguments are shifted. The block is expected not to have predecessors.
BlockArgument Block::insertArgument(args_iterator it, Type type) {
diff --git a/mlir/lib/IR/CMakeLists.txt b/mlir/lib/IR/CMakeLists.txt
index 553408f6fb36..1305e1156490 100644
--- a/mlir/lib/IR/CMakeLists.txt
+++ b/mlir/lib/IR/CMakeLists.txt
@@ -10,6 +10,7 @@ add_mlir_library(MLIRIR
Dominance.cpp
Function.cpp
FunctionImplementation.cpp
+ FunctionSupport.cpp
IntegerSet.cpp
Location.cpp
MLIRContext.cpp
diff --git a/mlir/lib/IR/Function.cpp b/mlir/lib/IR/Function.cpp
index fb525d86912d..03378f21f638 100644
--- a/mlir/lib/IR/Function.cpp
+++ b/mlir/lib/IR/Function.cpp
@@ -98,65 +98,6 @@ LogicalResult FuncOp::verify() {
return success();
}
-void FuncOp::eraseArguments(ArrayRef<unsigned> argIndices) {
- auto oldType = getType();
- int originalNumArgs = oldType.getNumInputs();
- llvm::BitVector eraseIndices(originalNumArgs);
- for (auto index : argIndices)
- eraseIndices.set(index);
- auto shouldEraseArg = [&](int i) { return eraseIndices.test(i); };
-
- // There are 3 things that need to be updated:
- // - Function type.
- // - Arg attrs.
- // - Block arguments of entry block.
-
- // Update the function type and arg attrs.
- SmallVector<Type, 4> newInputTypes;
- SmallVector<MutableDictionaryAttr, 4> newArgAttrs;
- for (int i = 0; i < originalNumArgs; i++) {
- if (shouldEraseArg(i))
- continue;
- newInputTypes.emplace_back(oldType.getInput(i));
- newArgAttrs.emplace_back(getArgAttrDict(i));
- }
- setType(FunctionType::get(newInputTypes, oldType.getResults(), getContext()));
- setAllArgAttrs(newArgAttrs);
-
- // Update the entry block's arguments.
- // We do this in reverse so that we erase later indices before earlier
- // indices, to avoid shifting the later indices.
- Block &entry = front();
- for (int i = 0; i < originalNumArgs; i++)
- if (shouldEraseArg(originalNumArgs - i - 1))
- entry.eraseArgument(originalNumArgs - i - 1);
-}
-
-void FuncOp::eraseResults(ArrayRef<unsigned> resultIndices) {
- auto oldType = getType();
- int originalNumResults = oldType.getNumResults();
- llvm::BitVector eraseIndices(originalNumResults);
- for (auto index : resultIndices)
- eraseIndices.set(index);
- auto shouldEraseResult = [&](int i) { return eraseIndices.test(i); };
-
- // There are 2 things that need to be updated:
- // - Function type.
- // - Result attrs.
-
- // Update the function type and result attrs.
- SmallVector<Type, 4> newResultTypes;
- SmallVector<MutableDictionaryAttr, 4> newResultAttrs;
- for (int i = 0; i < originalNumResults; i++) {
- if (shouldEraseResult(i))
- continue;
- newResultTypes.emplace_back(oldType.getResult(i));
- newResultAttrs.emplace_back(getResultAttrDict(i));
- }
- setType(FunctionType::get(oldType.getInputs(), newResultTypes, getContext()));
- setAllResultAttrs(newResultAttrs);
-}
-
/// Clone the internal blocks from this function into dest and all attributes
/// from this function to dest.
void FuncOp::cloneInto(FuncOp dest, BlockAndValueMapping &mapper) {
diff --git a/mlir/lib/IR/FunctionSupport.cpp b/mlir/lib/IR/FunctionSupport.cpp
new file mode 100644
index 000000000000..259c465d7ba5
--- /dev/null
+++ b/mlir/lib/IR/FunctionSupport.cpp
@@ -0,0 +1,103 @@
+//===- FunctionSupport.cpp - Utility types for function-like ops ----------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/IR/FunctionSupport.h"
+#include "mlir/Support/LLVM.h"
+#include "llvm/ADT/BitVector.h"
+
+using namespace mlir;
+
+/// Helper to call a callback once on each index in the range
+/// [0, `totalIndices`), *except* for the indices given in `indices`.
+/// `indices` is allowed to have duplicates and can be in any order.
+///
+/// Declared `static` (internal linkage): Types.cpp defines a helper with the
+/// same name and signature, and with external `inline` linkage both
+/// definitions would have to stay token-for-token identical to avoid an ODR
+/// violation that the linker would not diagnose.
+static void iterateIndicesExcept(unsigned totalIndices,
+                                 ArrayRef<unsigned> indices,
+                                 function_ref<void(unsigned)> callback) {
+  llvm::BitVector skipIndices(totalIndices);
+  for (unsigned i : indices)
+    skipIndices.set(i);
+
+  for (unsigned i = 0; i < totalIndices; ++i)
+    if (!skipIndices.test(i))
+      callback(i);
+}
+
+//===----------------------------------------------------------------------===//
+// Function Arguments and Results.
+//===----------------------------------------------------------------------===//
+
+/// Erases the arguments at `argIndices` from the function-like `op`: stores
+/// `newType` as the function type attribute, shifts the argument attribute
+/// dictionaries to the remaining positions, and erases the matching entry
+/// block arguments when the function has a body.
+void mlir::impl::eraseFunctionArguments(Operation *op,
+                                        ArrayRef<unsigned> argIndices,
+                                        unsigned originalNumArgs,
+                                        Type newType) {
+  // There are 3 things that need to be updated:
+  // - Function type.
+  // - Arg attrs.
+  // - Block arguments of entry block (only if the function has a body).
+  SmallString<8> nameBuf;
+
+  // Collect the attrs of the arguments that are kept, in their new order.
+  SmallVector<MutableDictionaryAttr, 4> newArgAttrs;
+  iterateIndicesExcept(originalNumArgs, argIndices, [&](unsigned i) {
+    newArgAttrs.emplace_back(getArgAttrDict(op, i));
+  });
+
+  // Remove the trailing arg attrs that no longer correspond to an argument.
+  for (unsigned i = newArgAttrs.size(), e = originalNumArgs; i < e; ++i)
+    op->removeAttr(getArgAttrName(i, nameBuf));
+
+  // Set the function type.
+  op->setAttr(getTypeAttrName(), TypeAttr::get(newType));
+
+  // Set the new arg attrs, or remove them if empty.
+  for (unsigned i = 0, e = newArgAttrs.size(); i != e; ++i) {
+    auto nameAttr = getArgAttrName(i, nameBuf);
+    auto argAttr = newArgAttrs[i];
+    if (argAttr.empty())
+      op->removeAttr(nameAttr);
+    else
+      op->setAttr(nameAttr, argAttr.getDictionary(op->getContext()));
+  }
+
+  // Update the entry block's arguments. External function declarations have
+  // an empty body region, so there is no entry block to update; calling
+  // `front()` on the empty region would be invalid.
+  Region &body = op->getRegion(0);
+  if (!body.empty())
+    body.front().eraseArguments(argIndices);
+}
+
+/// Erases the results at `resultIndices` from the function-like `op`: stores
+/// `newType` as the function type attribute and shifts the result attribute
+/// dictionaries to the remaining positions.
+void mlir::impl::eraseFunctionResults(Operation *op,
+                                      ArrayRef<unsigned> resultIndices,
+                                      unsigned originalNumResults,
+                                      Type newType) {
+  // Gather the dictionaries of the results that survive, in their new order.
+  SmallVector<MutableDictionaryAttr, 4> keptResultAttrs;
+  iterateIndicesExcept(originalNumResults, resultIndices, [&](unsigned idx) {
+    keptResultAttrs.emplace_back(getResultAttrDict(op, idx));
+  });
+  unsigned numKept = keptResultAttrs.size();
+
+  SmallString<8> attrNameStorage;
+
+  // Result slots at or beyond `numKept` no longer exist; drop their attrs.
+  for (unsigned idx = numKept; idx < originalNumResults; ++idx)
+    op->removeAttr(getResultAttrName(idx, attrNameStorage));
+
+  op->setAttr(getTypeAttrName(), TypeAttr::get(newType));
+
+  // Re-home each surviving dictionary at its new position; an empty
+  // dictionary is represented by the attribute being absent.
+  unsigned position = 0;
+  for (MutableDictionaryAttr &attrs : keptResultAttrs) {
+    auto attrName = getResultAttrName(position++, attrNameStorage);
+    if (!attrs.empty())
+      op->setAttr(attrName, attrs.getDictionary(op->getContext()));
+    else
+      op->removeAttr(attrName);
+  }
+}
diff --git a/mlir/lib/IR/Types.cpp b/mlir/lib/IR/Types.cpp
index cdcd6a9c6ea5..f3c1c2c11247 100644
--- a/mlir/lib/IR/Types.cpp
+++ b/mlir/lib/IR/Types.cpp
@@ -10,6 +10,8 @@
#include "TypeDetail.h"
#include "mlir/IR/Diagnostics.h"
#include "mlir/IR/Dialect.h"
+#include "mlir/Support/LLVM.h"
+#include "llvm/ADT/BitVector.h"
#include "llvm/ADT/Twine.h"
using namespace mlir;
@@ -46,6 +48,48 @@ ArrayRef<Type> FunctionType::getResults() const {
return getImpl()->getResults();
}
+/// Helper to call a callback once on each index in the range
+/// [0, `totalIndices`), *except* for the indices given in `indices`.
+/// `indices` is allowed to have duplicates and can be in any order.
+///
+/// Declared `static` (internal linkage): FunctionSupport.cpp defines a helper
+/// with the same name and signature, and with external `inline` linkage both
+/// definitions would have to stay token-for-token identical to avoid an ODR
+/// violation that the linker would not diagnose.
+static void iterateIndicesExcept(unsigned totalIndices,
+                                 ArrayRef<unsigned> indices,
+                                 function_ref<void(unsigned)> callback) {
+  llvm::BitVector skipIndices(totalIndices);
+  for (unsigned i : indices)
+    skipIndices.set(i);
+
+  for (unsigned i = 0; i < totalIndices; ++i)
+    if (!skipIndices.test(i))
+      callback(i);
+}
+
+/// Returns a new function type without the specified arguments and results.
+/// An empty index list leaves the corresponding type list untouched.
+FunctionType
+FunctionType::getWithoutArgsAndResults(ArrayRef<unsigned> argIndices,
+                                       ArrayRef<unsigned> resultIndices) {
+  // Yields `types` minus the positions in `dropped`, materializing into
+  // `storage` only when something actually has to be removed.
+  auto filterTypes = [](ArrayRef<Type> types, ArrayRef<unsigned> dropped,
+                        SmallVectorImpl<Type> &storage) -> ArrayRef<Type> {
+    if (dropped.empty())
+      return types;
+    iterateIndicesExcept(types.size(), dropped,
+                         [&](unsigned i) { storage.push_back(types[i]); });
+    return storage;
+  };
+
+  SmallVector<Type, 4> inputStorage;
+  SmallVector<Type, 4> resultStorage;
+  ArrayRef<Type> keptInputs =
+      filterTypes(getInputs(), argIndices, inputStorage);
+  ArrayRef<Type> keptResults =
+      filterTypes(getResults(), resultIndices, resultStorage);
+  return get(keptInputs, keptResults, getContext());
+}
+
//===----------------------------------------------------------------------===//
// OpaqueType
//===----------------------------------------------------------------------===//
More information about the Mlir-commits
mailing list