[Mlir-commits] [mlir] e3cd80e - [mlir:Function][NFC] Use BitVector instead of ArrayRef for indices when removing arguments/results
River Riddle
llvmlistbot at llvm.org
Wed Jan 26 21:38:03 PST 2022
Author: River Riddle
Date: 2022-01-26T21:37:22-08:00
New Revision: e3cd80ea9f0ac0d04f537feb70d8f9a1c7875863
URL: https://github.com/llvm/llvm-project/commit/e3cd80ea9f0ac0d04f537feb70d8f9a1c7875863
DIFF: https://github.com/llvm/llvm-project/commit/e3cd80ea9f0ac0d04f537feb70d8f9a1c7875863.diff
LOG: [mlir:Function][NFC] Use BitVector instead of ArrayRef for indices when removing arguments/results
We already convert to BitVector internally, and other APIs (namely Operation::eraseOperands)
already use BitVector as well. Switching over provides a common format between
APIs and also reduces the number of format conversions necessary.
Fixes #53325
Differential Revision: https://reviews.llvm.org/D118083
Added:
Modified:
mlir/include/mlir/IR/BuiltinTypes.h
mlir/include/mlir/IR/BuiltinTypes.td
mlir/include/mlir/IR/FunctionInterfaces.h
mlir/include/mlir/IR/FunctionInterfaces.td
mlir/lib/Dialect/Bufferization/Transforms/BufferResultsToOutParams.cpp
mlir/lib/IR/BuiltinTypes.cpp
mlir/lib/IR/FunctionInterfaces.cpp
mlir/test/lib/IR/TestFunc.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/IR/BuiltinTypes.h b/mlir/include/mlir/IR/BuiltinTypes.h
index e087e6bf55d8c..b03d9ea9f575d 100644
--- a/mlir/include/mlir/IR/BuiltinTypes.h
+++ b/mlir/include/mlir/IR/BuiltinTypes.h
@@ -13,6 +13,7 @@
#include "SubElementInterfaces.h"
namespace llvm {
+class BitVector;
struct fltSemantics;
} // namespace llvm
diff --git a/mlir/include/mlir/IR/BuiltinTypes.td b/mlir/include/mlir/IR/BuiltinTypes.td
index b6f90f92190a8..a78413548ba82 100644
--- a/mlir/include/mlir/IR/BuiltinTypes.td
+++ b/mlir/include/mlir/IR/BuiltinTypes.td
@@ -166,8 +166,8 @@ def Builtin_Function : Builtin_Type<"Function", [
TypeRange resultTypes);
/// Returns a new function type without the specified arguments and results.
- FunctionType getWithoutArgsAndResults(ArrayRef<unsigned> argIndices,
- ArrayRef<unsigned> resultIndices);
+ FunctionType getWithoutArgsAndResults(const llvm::BitVector &argIndices,
+ const llvm::BitVector &resultIndices);
}];
}
diff --git a/mlir/include/mlir/IR/FunctionInterfaces.h b/mlir/include/mlir/IR/FunctionInterfaces.h
index b81abceef6149..b6d6a9515ffe0 100644
--- a/mlir/include/mlir/IR/FunctionInterfaces.h
+++ b/mlir/include/mlir/IR/FunctionInterfaces.h
@@ -16,6 +16,7 @@
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/OpDefinition.h"
+#include "llvm/ADT/BitVector.h"
#include "llvm/ADT/SmallString.h"
namespace mlir {
@@ -82,12 +83,12 @@ void insertFunctionResults(Operation *op, ArrayRef<unsigned> resultIndices,
unsigned originalNumResults, Type newType);
/// Erase the specified arguments and update the function type attribute.
-void eraseFunctionArguments(Operation *op, ArrayRef<unsigned> argIndices,
- unsigned originalNumArgs, Type newType);
+void eraseFunctionArguments(Operation *op, const llvm::BitVector &argIndices,
+ Type newType);
/// Erase the specified results and update the function type attribute.
-void eraseFunctionResults(Operation *op, ArrayRef<unsigned> resultIndices,
- unsigned originalNumResults, Type newType);
+void eraseFunctionResults(Operation *op, const llvm::BitVector &resultIndices,
+ Type newType);
/// Set a FunctionOpInterface operation's type signature.
void setFunctionType(Operation *op, Type newType);
@@ -100,7 +101,7 @@ TypeRange insertTypesInto(TypeRange oldTypes, ArrayRef<unsigned> indices,
/// Filters out any elements referenced by `indices`. If any types are removed,
/// `storage` is used to hold the new type list. Returns the new type list.
-TypeRange filterTypesOut(TypeRange types, ArrayRef<unsigned> indices,
+TypeRange filterTypesOut(TypeRange types, const llvm::BitVector &indices,
SmallVectorImpl<Type> &storage);
//===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/IR/FunctionInterfaces.td b/mlir/include/mlir/IR/FunctionInterfaces.td
index 20c7d7bbd51b7..124f6594081f7 100644
--- a/mlir/include/mlir/IR/FunctionInterfaces.td
+++ b/mlir/include/mlir/IR/FunctionInterfaces.td
@@ -280,27 +280,31 @@ def FunctionOpInterface : OpInterface<"FunctionOpInterface"> {
}
/// Erase a single argument at `argIndex`.
- void eraseArgument(unsigned argIndex) { eraseArguments({argIndex}); }
+ void eraseArgument(unsigned argIndex) {
+ llvm::BitVector argsToErase($_op.getNumArguments());
+ argsToErase.set(argIndex);
+ eraseArguments(argsToErase);
+ }
/// Erases the arguments listed in `argIndices`.
- /// `argIndices` is allowed to have duplicates and can be in any order.
- void eraseArguments(ArrayRef<unsigned> argIndices) {
- unsigned originalNumArgs = $_op.getNumArguments();
- Type newType = $_op.getTypeWithoutArgsAndResults(argIndices, {});
- function_interface_impl::eraseFunctionArguments(this->getOperation(), argIndices,
- originalNumArgs, newType);
+ void eraseArguments(const llvm::BitVector &argIndices) {
+ Type newType = $_op.getTypeWithoutArgs(argIndices);
+ function_interface_impl::eraseFunctionArguments(
+ this->getOperation(), argIndices, newType);
}
/// Erase a single result at `resultIndex`.
- void eraseResult(unsigned resultIndex) { eraseResults({resultIndex}); }
+ void eraseResult(unsigned resultIndex) {
+ llvm::BitVector resultsToErase($_op.getNumResults());
+ resultsToErase.set(resultIndex);
+ eraseResults(resultsToErase);
+ }
/// Erases the results listed in `resultIndices`.
- /// `resultIndices` is allowed to have duplicates and can be in any order.
- void eraseResults(ArrayRef<unsigned> resultIndices) {
- unsigned originalNumResults = $_op.getNumResults();
- Type newType = $_op.getTypeWithoutArgsAndResults({}, resultIndices);
+ void eraseResults(const llvm::BitVector &resultIndices) {
+ Type newType = $_op.getTypeWithoutResults(resultIndices);
function_interface_impl::eraseFunctionResults(
- this->getOperation(), resultIndices, originalNumResults, newType);
+ this->getOperation(), resultIndices, newType);
}
/// Return the type of this function with the specified arguments and
@@ -320,10 +324,9 @@ def FunctionOpInterface : OpInterface<"FunctionOpInterface"> {
/// Return the type of this function without the specified arguments and
/// results. This is used to update the function's signature in the
- /// `eraseArguments` and `eraseResults` methods. The arrays of indices are
- /// allowed to have duplicates and can be in any order.
+ /// `eraseArguments` and `eraseResults` methods.
Type getTypeWithoutArgsAndResults(
- ArrayRef<unsigned> argIndices, ArrayRef<unsigned> resultIndices) {
+ const llvm::BitVector &argIndices, const llvm::BitVector &resultIndices) {
SmallVector<Type> argStorage, resultStorage;
TypeRange newArgTypes = function_interface_impl::filterTypesOut(
$_op.getArgumentTypes(), argIndices, argStorage);
@@ -331,6 +334,18 @@ def FunctionOpInterface : OpInterface<"FunctionOpInterface"> {
$_op.getResultTypes(), resultIndices, resultStorage);
return $_op.cloneTypeWith(newArgTypes, newResultTypes);
}
+ Type getTypeWithoutArgs(const llvm::BitVector &argIndices) {
+ SmallVector<Type> argStorage;
+ TypeRange newArgTypes = function_interface_impl::filterTypesOut(
+ $_op.getArgumentTypes(), argIndices, argStorage);
+ return $_op.cloneTypeWith(newArgTypes, $_op.getResultTypes());
+ }
+ Type getTypeWithoutResults(const llvm::BitVector &resultIndices) {
+ SmallVector<Type> resultStorage;
+ TypeRange newResultTypes = function_interface_impl::filterTypesOut(
+ $_op.getResultTypes(), resultIndices, resultStorage);
+ return $_op.cloneTypeWith($_op.getArgumentTypes(), newResultTypes);
+ }
//===------------------------------------------------------------------===//
// Argument Attributes
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/BufferResultsToOutParams.cpp b/mlir/lib/Dialect/Bufferization/Transforms/BufferResultsToOutParams.cpp
index 08780db5f94df..585e873b41881 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/BufferResultsToOutParams.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/BufferResultsToOutParams.cpp
@@ -24,10 +24,10 @@ static void updateFuncOp(FuncOp func,
// Collect information about the results will become appended arguments.
SmallVector<Type, 6> erasedResultTypes;
- SmallVector<unsigned, 6> erasedResultIndices;
+ llvm::BitVector erasedResultIndices(functionType.getNumResults());
for (const auto &resultType : llvm::enumerate(functionType.getResults())) {
if (resultType.value().isa<BaseMemRefType>()) {
- erasedResultIndices.push_back(resultType.index());
+ erasedResultIndices.set(resultType.index());
erasedResultTypes.push_back(resultType.value());
}
}
@@ -40,9 +40,11 @@ static void updateFuncOp(FuncOp func,
func.setType(newFunctionType);
// Transfer the result attributes to arg attributes.
- for (int i = 0, e = erasedResultTypes.size(); i < e; i++)
+ auto erasedIndicesIt = erasedResultIndices.set_bits_begin();
+ for (int i = 0, e = erasedResultTypes.size(); i < e; ++i, ++erasedIndicesIt) {
func.setArgAttrs(functionType.getNumInputs() + i,
- func.getResultAttrs(erasedResultIndices[i]));
+ func.getResultAttrs(*erasedIndicesIt));
+ }
// Erase the results.
func.eraseResults(erasedResultIndices);
diff --git a/mlir/lib/IR/BuiltinTypes.cpp b/mlir/lib/IR/BuiltinTypes.cpp
index 46166f16f1686..63a5962803900 100644
--- a/mlir/lib/IR/BuiltinTypes.cpp
+++ b/mlir/lib/IR/BuiltinTypes.cpp
@@ -172,8 +172,8 @@ FunctionType FunctionType::getWithArgsAndResults(
/// Returns a new function type without the specified arguments and results.
FunctionType
-FunctionType::getWithoutArgsAndResults(ArrayRef<unsigned> argIndices,
- ArrayRef<unsigned> resultIndices) {
+FunctionType::getWithoutArgsAndResults(const llvm::BitVector &argIndices,
+ const llvm::BitVector &resultIndices) {
SmallVector<Type> argStorage, resultStorage;
TypeRange newArgTypes = function_interface_impl::filterTypesOut(
getInputs(), argIndices, argStorage);
diff --git a/mlir/lib/IR/FunctionInterfaces.cpp b/mlir/lib/IR/FunctionInterfaces.cpp
index 07da5ce1716f1..4f31c59f3f69e 100644
--- a/mlir/lib/IR/FunctionInterfaces.cpp
+++ b/mlir/lib/IR/FunctionInterfaces.cpp
@@ -7,26 +7,9 @@
//===----------------------------------------------------------------------===//
#include "mlir/IR/FunctionInterfaces.h"
-#include "mlir/Support/LLVM.h"
-#include "llvm/ADT/BitVector.h"
using namespace mlir;
-/// Helper to call a callback once on each index in the range
-/// [0, `totalIndices`), *except* for the indices given in `indices`.
-/// `indices` is allowed to have duplicates and can be in any order.
-inline static void iterateIndicesExcept(unsigned totalIndices,
- ArrayRef<unsigned> indices,
- function_ref<void(unsigned)> callback) {
- llvm::BitVector skipIndices(totalIndices);
- for (unsigned i : indices)
- skipIndices.set(i);
-
- for (unsigned i = 0; i < totalIndices; ++i)
- if (!skipIndices.test(i))
- callback(i);
-}
-
//===----------------------------------------------------------------------===//
// Tablegen Interface Definitions
//===----------------------------------------------------------------------===//
@@ -217,8 +200,7 @@ void mlir::function_interface_impl::insertFunctionResults(
}
void mlir::function_interface_impl::eraseFunctionArguments(
- Operation *op, ArrayRef<unsigned> argIndices, unsigned originalNumArgs,
- Type newType) {
+ Operation *op, const llvm::BitVector &argIndices, Type newType) {
// There are 3 things that need to be updated:
// - Function type.
// - Arg attrs.
@@ -229,9 +211,9 @@ void mlir::function_interface_impl::eraseFunctionArguments(
if (auto argAttrs = op->getAttrOfType<ArrayAttr>(getArgDictAttrName())) {
SmallVector<DictionaryAttr, 4> newArgAttrs;
newArgAttrs.reserve(argAttrs.size());
- iterateIndicesExcept(originalNumArgs, argIndices, [&](unsigned i) {
- newArgAttrs.emplace_back(argAttrs[i].cast<DictionaryAttr>());
- });
+ for (unsigned i = 0, e = argIndices.size(); i < e; ++i)
+ if (!argIndices[i])
+ newArgAttrs.emplace_back(argAttrs[i].cast<DictionaryAttr>());
setAllArgAttrDicts(op, newArgAttrs);
}
@@ -241,8 +223,7 @@ void mlir::function_interface_impl::eraseFunctionArguments(
}
void mlir::function_interface_impl::eraseFunctionResults(
- Operation *op, ArrayRef<unsigned> resultIndices,
- unsigned originalNumResults, Type newType) {
+ Operation *op, const llvm::BitVector &resultIndices, Type newType) {
// There are 2 things that need to be updated:
// - Function type.
// - Result attrs.
@@ -251,9 +232,9 @@ void mlir::function_interface_impl::eraseFunctionResults(
if (auto resAttrs = op->getAttrOfType<ArrayAttr>(getResultDictAttrName())) {
SmallVector<DictionaryAttr, 4> newResultAttrs;
newResultAttrs.reserve(resAttrs.size());
- iterateIndicesExcept(originalNumResults, resultIndices, [&](unsigned i) {
- newResultAttrs.emplace_back(resAttrs[i].cast<DictionaryAttr>());
- });
+ for (unsigned i = 0, e = resultIndices.size(); i < e; ++i)
+ if (!resultIndices[i])
+ newResultAttrs.emplace_back(resAttrs[i].cast<DictionaryAttr>());
setAllResultAttrDicts(op, newResultAttrs);
}
@@ -282,12 +263,14 @@ TypeRange mlir::function_interface_impl::insertTypesInto(
TypeRange
mlir::function_interface_impl::filterTypesOut(TypeRange types,
- ArrayRef<unsigned> indices,
+ const llvm::BitVector &indices,
SmallVectorImpl<Type> &storage) {
- if (indices.empty())
+ if (indices.none())
return types;
- iterateIndicesExcept(types.size(), indices,
- [&](unsigned i) { storage.emplace_back(types[i]); });
+
+ for (unsigned i = 0, e = types.size(); i < e; ++i)
+ if (!indices[i])
+ storage.emplace_back(types[i]);
return storage;
}
diff --git a/mlir/test/lib/IR/TestFunc.cpp b/mlir/test/lib/IR/TestFunc.cpp
index dee9f8a5b2e52..0b4ce1d05992e 100644
--- a/mlir/test/lib/IR/TestFunc.cpp
+++ b/mlir/test/lib/IR/TestFunc.cpp
@@ -87,18 +87,10 @@ struct TestFuncEraseArg
auto module = getOperation();
for (FuncOp func : module.getOps<FuncOp>()) {
- SmallVector<unsigned, 4> indicesToErase;
- for (auto argIndex : llvm::seq<int>(0, func.getNumArguments())) {
- if (func.getArgAttr(argIndex, "test.erase_this_arg")) {
- // Push back twice to test that duplicate arg indices are handled
- // correctly.
- indicesToErase.push_back(argIndex);
- indicesToErase.push_back(argIndex);
- }
- }
- // Reverse the order to test that unsorted index lists are handled
- // correctly.
- std::reverse(indicesToErase.begin(), indicesToErase.end());
+ llvm::BitVector indicesToErase(func.getNumArguments());
+ for (auto argIndex : llvm::seq<int>(0, func.getNumArguments()))
+ if (func.getArgAttr(argIndex, "test.erase_this_arg"))
+ indicesToErase.set(argIndex);
func.eraseArguments(indicesToErase);
}
}
@@ -115,18 +107,10 @@ struct TestFuncEraseResult
auto module = getOperation();
for (FuncOp func : module.getOps<FuncOp>()) {
- SmallVector<unsigned, 4> indicesToErase;
- for (auto resultIndex : llvm::seq<int>(0, func.getNumResults())) {
- if (func.getResultAttr(resultIndex, "test.erase_this_result")) {
- // Push back twice to test that duplicate indices are handled
- // correctly.
- indicesToErase.push_back(resultIndex);
- indicesToErase.push_back(resultIndex);
- }
- }
- // Reverse the order to test that unsorted index lists are handled
- // correctly.
- std::reverse(indicesToErase.begin(), indicesToErase.end());
+ llvm::BitVector indicesToErase(func.getNumResults());
+ for (auto resultIndex : llvm::seq<int>(0, func.getNumResults()))
+ if (func.getResultAttr(resultIndex, "test.erase_this_result"))
+ indicesToErase.set(resultIndex);
func.eraseResults(indicesToErase);
}
}
More information about the Mlir-commits mailing list