[Mlir-commits] [mlir] f3ece29 - [ADT] Allow specifying the size of resulting `SmallVector` in `map_to_vector`

Laszlo Kindrat llvmlistbot at llvm.org
Thu May 25 08:35:33 PDT 2023


Author: Laszlo Kindrat
Date: 2023-05-25T11:35:19-04:00
New Revision: f3ece29b4658d60a1e7656bb9e67853376d094b7

URL: https://github.com/llvm/llvm-project/commit/f3ece29b4658d60a1e7656bb9e67853376d094b7
DIFF: https://github.com/llvm/llvm-project/commit/f3ece29b4658d60a1e7656bb9e67853376d094b7.diff

LOG: [ADT] Allow specifying the size of resulting `SmallVector` in `map_to_vector`

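This patch adds an overload of the `map_to_vector` helper template that takes an explicit `Size` template parameter. It sets the number of inline (small-buffer) elements of the resulting `SmallVector`, as in `map_to_vector<8>(range, fn)`; the length of the result is still the length of the mapped range. A few call sites in MLIR are updated to illustrate and test the change.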

Differential Revision: https://reviews.llvm.org/D150601

Added: 
    

Modified: 
    llvm/include/llvm/ADT/SmallVectorExtras.h
    mlir/include/mlir/IR/AffineMap.h
    mlir/lib/IR/Builders.cpp
    mlir/lib/IR/TypeUtilities.cpp

Removed: 
    


################################################################################
diff  --git a/llvm/include/llvm/ADT/SmallVectorExtras.h b/llvm/include/llvm/ADT/SmallVectorExtras.h
index 8d5228025e0e..d5159aa0e62f 100644
--- a/llvm/include/llvm/ADT/SmallVectorExtras.h
+++ b/llvm/include/llvm/ADT/SmallVectorExtras.h
@@ -20,6 +20,11 @@
 namespace llvm {
 
 /// Map a range to a SmallVector with element types deduced from the mapping.
+template <unsigned Size, class ContainerTy, class FuncTy>
+auto map_to_vector(ContainerTy &&C, FuncTy &&F) {
+  return to_vector<Size>(
+      map_range(std::forward<ContainerTy>(C), std::forward<FuncTy>(F)));
+}
 template <class ContainerTy, class FuncTy>
 auto map_to_vector(ContainerTy &&C, FuncTy &&F) {
   return to_vector(

diff  --git a/mlir/include/mlir/IR/AffineMap.h b/mlir/include/mlir/IR/AffineMap.h
index e21dc9c950c5..01cd7183e43c 100644
--- a/mlir/include/mlir/IR/AffineMap.h
+++ b/mlir/include/mlir/IR/AffineMap.h
@@ -19,6 +19,7 @@
 #include "llvm/ADT/ArrayRef.h"
 #include "llvm/ADT/DenseMapInfo.h"
 #include "llvm/ADT/SmallBitVector.h"
+#include "llvm/ADT/SmallVectorExtras.h"
 #include <optional>
 
 namespace llvm {
@@ -226,11 +227,11 @@ class AffineMap {
   AffineMap shiftDims(unsigned shift, unsigned offset = 0) const {
     assert(offset <= getNumDims());
     return AffineMap::get(getNumDims() + shift, getNumSymbols(),
-                          llvm::to_vector<4>(llvm::map_range(
+                          llvm::map_to_vector<4>(
                               getResults(),
                               [&](AffineExpr e) {
                                 return e.shiftDims(getNumDims(), shift, offset);
-                              })),
+                              }),
                           getContext());
   }
 
@@ -238,12 +239,12 @@ class AffineMap {
   /// by symbols[offset + shift ... shift + numSymbols).
   AffineMap shiftSymbols(unsigned shift, unsigned offset = 0) const {
     return AffineMap::get(getNumDims(), getNumSymbols() + shift,
-                          llvm::to_vector<4>(llvm::map_range(
-                              getResults(),
-                              [&](AffineExpr e) {
-                                return e.shiftSymbols(getNumSymbols(), shift,
-                                                      offset);
-                              })),
+                          llvm::map_to_vector<4>(getResults(),
+                                                 [&](AffineExpr e) {
+                                                   return e.shiftSymbols(
+                                                       getNumSymbols(), shift,
+                                                       offset);
+                                                 }),
                           getContext());
   }
 

diff  --git a/mlir/lib/IR/Builders.cpp b/mlir/lib/IR/Builders.cpp
index 6cbba068fc1a..c4fad9c4b3d4 100644
--- a/mlir/lib/IR/Builders.cpp
+++ b/mlir/lib/IR/Builders.cpp
@@ -15,6 +15,7 @@
 #include "mlir/IR/IntegerSet.h"
 #include "mlir/IR/Matchers.h"
 #include "mlir/IR/SymbolTable.h"
+#include "llvm/ADT/SmallVectorExtras.h"
 #include "llvm/Support/raw_ostream.h"
 
 using namespace mlir;
@@ -261,57 +262,56 @@ ArrayAttr Builder::getArrayAttr(ArrayRef<Attribute> value) {
 }
 
 ArrayAttr Builder::getBoolArrayAttr(ArrayRef<bool> values) {
-  auto attrs = llvm::to_vector<8>(llvm::map_range(
-      values, [this](bool v) -> Attribute { return getBoolAttr(v); }));
+  auto attrs = llvm::map_to_vector<8>(
+      values, [this](bool v) -> Attribute { return getBoolAttr(v); });
   return getArrayAttr(attrs);
 }
 
 ArrayAttr Builder::getI32ArrayAttr(ArrayRef<int32_t> values) {
-  auto attrs = llvm::to_vector<8>(llvm::map_range(
-      values, [this](int32_t v) -> Attribute { return getI32IntegerAttr(v); }));
+  auto attrs = llvm::map_to_vector<8>(
+      values, [this](int32_t v) -> Attribute { return getI32IntegerAttr(v); });
   return getArrayAttr(attrs);
 }
 ArrayAttr Builder::getI64ArrayAttr(ArrayRef<int64_t> values) {
-  auto attrs = llvm::to_vector<8>(llvm::map_range(
-      values, [this](int64_t v) -> Attribute { return getI64IntegerAttr(v); }));
+  auto attrs = llvm::map_to_vector<8>(
+      values, [this](int64_t v) -> Attribute { return getI64IntegerAttr(v); });
   return getArrayAttr(attrs);
 }
 
 ArrayAttr Builder::getIndexArrayAttr(ArrayRef<int64_t> values) {
-  auto attrs = llvm::to_vector<8>(
-      llvm::map_range(values, [this](int64_t v) -> Attribute {
-        return getIntegerAttr(IndexType::get(getContext()), v);
-      }));
+  auto attrs = llvm::map_to_vector<8>(values, [this](int64_t v) -> Attribute {
+    return getIntegerAttr(IndexType::get(getContext()), v);
+  });
   return getArrayAttr(attrs);
 }
 
 ArrayAttr Builder::getF32ArrayAttr(ArrayRef<float> values) {
-  auto attrs = llvm::to_vector<8>(llvm::map_range(
-      values, [this](float v) -> Attribute { return getF32FloatAttr(v); }));
+  auto attrs = llvm::map_to_vector<8>(
+      values, [this](float v) -> Attribute { return getF32FloatAttr(v); });
   return getArrayAttr(attrs);
 }
 
 ArrayAttr Builder::getF64ArrayAttr(ArrayRef<double> values) {
-  auto attrs = llvm::to_vector<8>(llvm::map_range(
-      values, [this](double v) -> Attribute { return getF64FloatAttr(v); }));
+  auto attrs = llvm::map_to_vector<8>(
+      values, [this](double v) -> Attribute { return getF64FloatAttr(v); });
   return getArrayAttr(attrs);
 }
 
 ArrayAttr Builder::getStrArrayAttr(ArrayRef<StringRef> values) {
-  auto attrs = llvm::to_vector<8>(llvm::map_range(
-      values, [this](StringRef v) -> Attribute { return getStringAttr(v); }));
+  auto attrs = llvm::map_to_vector<8>(
+      values, [this](StringRef v) -> Attribute { return getStringAttr(v); });
   return getArrayAttr(attrs);
 }
 
 ArrayAttr Builder::getTypeArrayAttr(TypeRange values) {
-  auto attrs = llvm::to_vector<8>(llvm::map_range(
-      values, [](Type v) -> Attribute { return TypeAttr::get(v); }));
+  auto attrs = llvm::map_to_vector<8>(
+      values, [](Type v) -> Attribute { return TypeAttr::get(v); });
   return getArrayAttr(attrs);
 }
 
 ArrayAttr Builder::getAffineMapArrayAttr(ArrayRef<AffineMap> values) {
-  auto attrs = llvm::to_vector<8>(llvm::map_range(
-      values, [](AffineMap v) -> Attribute { return AffineMapAttr::get(v); }));
+  auto attrs = llvm::map_to_vector<8>(
+      values, [](AffineMap v) -> Attribute { return AffineMapAttr::get(v); });
   return getArrayAttr(attrs);
 }
 

diff  --git a/mlir/lib/IR/TypeUtilities.cpp b/mlir/lib/IR/TypeUtilities.cpp
index 7aa37cb015fc..6926bebe0221 100644
--- a/mlir/lib/IR/TypeUtilities.cpp
+++ b/mlir/lib/IR/TypeUtilities.cpp
@@ -11,13 +11,12 @@
 //===----------------------------------------------------------------------===//
 
 #include "mlir/IR/TypeUtilities.h"
-
-#include <numeric>
-
 #include "mlir/IR/Attributes.h"
 #include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/Types.h"
 #include "mlir/IR/Value.h"
+#include "llvm/ADT/SmallVectorExtras.h"
+#include <numeric>
 
 using namespace mlir;
 
@@ -119,8 +118,8 @@ LogicalResult mlir::verifyCompatibleDims(ArrayRef<int64_t> dims) {
 /// have compatible dimensions. Dimensions are compatible if all non-dynamic
 /// dims are equal. The element type does not matter.
 LogicalResult mlir::verifyCompatibleShapes(TypeRange types) {
-  auto shapedTypes = llvm::to_vector<8>(llvm::map_range(
-      types, [](auto type) { return llvm::dyn_cast<ShapedType>(type); }));
+  auto shapedTypes = llvm::map_to_vector<8>(
+      types, [](auto type) { return llvm::dyn_cast<ShapedType>(type); });
   // Return failure if some, but not all are not shaped. Return early if none
   // are shaped also.
   if (llvm::none_of(shapedTypes, [](auto t) { return t; }))
@@ -155,10 +154,10 @@ LogicalResult mlir::verifyCompatibleShapes(TypeRange types) {
 
   for (unsigned i = 0; i < firstRank; ++i) {
     // Retrieve all ranked dimensions
-    auto dims = llvm::to_vector<8>(llvm::map_range(
+    auto dims = llvm::map_to_vector<8>(
         llvm::make_filter_range(
             shapes, [&](auto shape) { return shape.getRank() >= i; }),
-        [&](auto shape) { return shape.getDimSize(i); }));
+        [&](auto shape) { return shape.getDimSize(i); });
     if (verifyCompatibleDims(dims).failed())
       return failure();
   }


        


More information about the Mlir-commits mailing list