[Mlir-commits] [mlir] f3ece29 - [ADT] Allow specifying the size of resulting `SmallVector` in `map_to_vector`
Laszlo Kindrat
llvmlistbot at llvm.org
Thu May 25 08:35:33 PDT 2023
Author: Laszlo Kindrat
Date: 2023-05-25T11:35:19-04:00
New Revision: f3ece29b4658d60a1e7656bb9e67853376d094b7
URL: https://github.com/llvm/llvm-project/commit/f3ece29b4658d60a1e7656bb9e67853376d094b7
DIFF: https://github.com/llvm/llvm-project/commit/f3ece29b4658d60a1e7656bb9e67853376d094b7.diff
LOG: [ADT] Allow specifying the size of resulting `SmallVector` in `map_to_vector`
This patch adds an overload for the `map_to_vector` helper template, exposing a template parameter that sets the inline element count (the `N` in `SmallVector<T, N>`) of the resulting `SmallVector`. A few call sites in mlir are updated to illustrate and test the change.
Differential Revision: https://reviews.llvm.org/D150601
Added:
Modified:
llvm/include/llvm/ADT/SmallVectorExtras.h
mlir/include/mlir/IR/AffineMap.h
mlir/lib/IR/Builders.cpp
mlir/lib/IR/TypeUtilities.cpp
Removed:
################################################################################
diff --git a/llvm/include/llvm/ADT/SmallVectorExtras.h b/llvm/include/llvm/ADT/SmallVectorExtras.h
index 8d5228025e0e..d5159aa0e62f 100644
--- a/llvm/include/llvm/ADT/SmallVectorExtras.h
+++ b/llvm/include/llvm/ADT/SmallVectorExtras.h
@@ -20,6 +20,11 @@
namespace llvm {
/// Map a range to a SmallVector with element types deduced from the mapping.
+template <unsigned Size, class ContainerTy, class FuncTy>
+auto map_to_vector(ContainerTy &&C, FuncTy &&F) {
+ return to_vector<Size>(
+ map_range(std::forward<ContainerTy>(C), std::forward<FuncTy>(F)));
+}
template <class ContainerTy, class FuncTy>
auto map_to_vector(ContainerTy &&C, FuncTy &&F) {
return to_vector(
diff --git a/mlir/include/mlir/IR/AffineMap.h b/mlir/include/mlir/IR/AffineMap.h
index e21dc9c950c5..01cd7183e43c 100644
--- a/mlir/include/mlir/IR/AffineMap.h
+++ b/mlir/include/mlir/IR/AffineMap.h
@@ -19,6 +19,7 @@
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/DenseMapInfo.h"
#include "llvm/ADT/SmallBitVector.h"
+#include "llvm/ADT/SmallVectorExtras.h"
#include <optional>
namespace llvm {
@@ -226,11 +227,11 @@ class AffineMap {
AffineMap shiftDims(unsigned shift, unsigned offset = 0) const {
assert(offset <= getNumDims());
return AffineMap::get(getNumDims() + shift, getNumSymbols(),
- llvm::to_vector<4>(llvm::map_range(
+ llvm::map_to_vector<4>(
getResults(),
[&](AffineExpr e) {
return e.shiftDims(getNumDims(), shift, offset);
- })),
+ }),
getContext());
}
@@ -238,12 +239,12 @@ class AffineMap {
/// by symbols[offset + shift ... shift + numSymbols).
AffineMap shiftSymbols(unsigned shift, unsigned offset = 0) const {
return AffineMap::get(getNumDims(), getNumSymbols() + shift,
- llvm::to_vector<4>(llvm::map_range(
- getResults(),
- [&](AffineExpr e) {
- return e.shiftSymbols(getNumSymbols(), shift,
- offset);
- })),
+ llvm::map_to_vector<4>(getResults(),
+ [&](AffineExpr e) {
+ return e.shiftSymbols(
+ getNumSymbols(), shift,
+ offset);
+ }),
getContext());
}
diff --git a/mlir/lib/IR/Builders.cpp b/mlir/lib/IR/Builders.cpp
index 6cbba068fc1a..c4fad9c4b3d4 100644
--- a/mlir/lib/IR/Builders.cpp
+++ b/mlir/lib/IR/Builders.cpp
@@ -15,6 +15,7 @@
#include "mlir/IR/IntegerSet.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/SymbolTable.h"
+#include "llvm/ADT/SmallVectorExtras.h"
#include "llvm/Support/raw_ostream.h"
using namespace mlir;
@@ -261,57 +262,56 @@ ArrayAttr Builder::getArrayAttr(ArrayRef<Attribute> value) {
}
ArrayAttr Builder::getBoolArrayAttr(ArrayRef<bool> values) {
- auto attrs = llvm::to_vector<8>(llvm::map_range(
- values, [this](bool v) -> Attribute { return getBoolAttr(v); }));
+ auto attrs = llvm::map_to_vector<8>(
+ values, [this](bool v) -> Attribute { return getBoolAttr(v); });
return getArrayAttr(attrs);
}
ArrayAttr Builder::getI32ArrayAttr(ArrayRef<int32_t> values) {
- auto attrs = llvm::to_vector<8>(llvm::map_range(
- values, [this](int32_t v) -> Attribute { return getI32IntegerAttr(v); }));
+ auto attrs = llvm::map_to_vector<8>(
+ values, [this](int32_t v) -> Attribute { return getI32IntegerAttr(v); });
return getArrayAttr(attrs);
}
ArrayAttr Builder::getI64ArrayAttr(ArrayRef<int64_t> values) {
- auto attrs = llvm::to_vector<8>(llvm::map_range(
- values, [this](int64_t v) -> Attribute { return getI64IntegerAttr(v); }));
+ auto attrs = llvm::map_to_vector<8>(
+ values, [this](int64_t v) -> Attribute { return getI64IntegerAttr(v); });
return getArrayAttr(attrs);
}
ArrayAttr Builder::getIndexArrayAttr(ArrayRef<int64_t> values) {
- auto attrs = llvm::to_vector<8>(
- llvm::map_range(values, [this](int64_t v) -> Attribute {
- return getIntegerAttr(IndexType::get(getContext()), v);
- }));
+ auto attrs = llvm::map_to_vector<8>(values, [this](int64_t v) -> Attribute {
+ return getIntegerAttr(IndexType::get(getContext()), v);
+ });
return getArrayAttr(attrs);
}
ArrayAttr Builder::getF32ArrayAttr(ArrayRef<float> values) {
- auto attrs = llvm::to_vector<8>(llvm::map_range(
- values, [this](float v) -> Attribute { return getF32FloatAttr(v); }));
+ auto attrs = llvm::map_to_vector<8>(
+ values, [this](float v) -> Attribute { return getF32FloatAttr(v); });
return getArrayAttr(attrs);
}
ArrayAttr Builder::getF64ArrayAttr(ArrayRef<double> values) {
- auto attrs = llvm::to_vector<8>(llvm::map_range(
- values, [this](double v) -> Attribute { return getF64FloatAttr(v); }));
+ auto attrs = llvm::map_to_vector<8>(
+ values, [this](double v) -> Attribute { return getF64FloatAttr(v); });
return getArrayAttr(attrs);
}
ArrayAttr Builder::getStrArrayAttr(ArrayRef<StringRef> values) {
- auto attrs = llvm::to_vector<8>(llvm::map_range(
- values, [this](StringRef v) -> Attribute { return getStringAttr(v); }));
+ auto attrs = llvm::map_to_vector<8>(
+ values, [this](StringRef v) -> Attribute { return getStringAttr(v); });
return getArrayAttr(attrs);
}
ArrayAttr Builder::getTypeArrayAttr(TypeRange values) {
- auto attrs = llvm::to_vector<8>(llvm::map_range(
- values, [](Type v) -> Attribute { return TypeAttr::get(v); }));
+ auto attrs = llvm::map_to_vector<8>(
+ values, [](Type v) -> Attribute { return TypeAttr::get(v); });
return getArrayAttr(attrs);
}
ArrayAttr Builder::getAffineMapArrayAttr(ArrayRef<AffineMap> values) {
- auto attrs = llvm::to_vector<8>(llvm::map_range(
- values, [](AffineMap v) -> Attribute { return AffineMapAttr::get(v); }));
+ auto attrs = llvm::map_to_vector<8>(
+ values, [](AffineMap v) -> Attribute { return AffineMapAttr::get(v); });
return getArrayAttr(attrs);
}
diff --git a/mlir/lib/IR/TypeUtilities.cpp b/mlir/lib/IR/TypeUtilities.cpp
index 7aa37cb015fc..6926bebe0221 100644
--- a/mlir/lib/IR/TypeUtilities.cpp
+++ b/mlir/lib/IR/TypeUtilities.cpp
@@ -11,13 +11,12 @@
//===----------------------------------------------------------------------===//
#include "mlir/IR/TypeUtilities.h"
-
-#include <numeric>
-
#include "mlir/IR/Attributes.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Types.h"
#include "mlir/IR/Value.h"
+#include "llvm/ADT/SmallVectorExtras.h"
+#include <numeric>
using namespace mlir;
@@ -119,8 +118,8 @@ LogicalResult mlir::verifyCompatibleDims(ArrayRef<int64_t> dims) {
/// have compatible dimensions. Dimensions are compatible if all non-dynamic
/// dims are equal. The element type does not matter.
LogicalResult mlir::verifyCompatibleShapes(TypeRange types) {
- auto shapedTypes = llvm::to_vector<8>(llvm::map_range(
- types, [](auto type) { return llvm::dyn_cast<ShapedType>(type); }));
+ auto shapedTypes = llvm::map_to_vector<8>(
+ types, [](auto type) { return llvm::dyn_cast<ShapedType>(type); });
// Return failure if some, but not all are not shaped. Return early if none
// are shaped also.
if (llvm::none_of(shapedTypes, [](auto t) { return t; }))
@@ -155,10 +154,10 @@ LogicalResult mlir::verifyCompatibleShapes(TypeRange types) {
for (unsigned i = 0; i < firstRank; ++i) {
// Retrieve all ranked dimensions
- auto dims = llvm::to_vector<8>(llvm::map_range(
+ auto dims = llvm::map_to_vector<8>(
llvm::make_filter_range(
shapes, [&](auto shape) { return shape.getRank() >= i; }),
- [&](auto shape) { return shape.getDimSize(i); }));
+ [&](auto shape) { return shape.getDimSize(i); });
if (verifyCompatibleDims(dims).failed())
return failure();
}
More information about the Mlir-commits mailing list