[Mlir-commits] [mlir] eb41f9e - [mlir][Vector] Simplify code a bit. NFCI.

Benjamin Kramer llvmlistbot at llvm.org
Sat Aug 1 05:49:40 PDT 2020


Author: Benjamin Kramer
Date: 2020-08-01T14:49:19+02:00
New Revision: eb41f9edde1070d68fce4a4eb31118e0ec1ca36d

URL: https://github.com/llvm/llvm-project/commit/eb41f9edde1070d68fce4a4eb31118e0ec1ca36d
DIFF: https://github.com/llvm/llvm-project/commit/eb41f9edde1070d68fce4a4eb31118e0ec1ca36d.diff

LOG: [mlir][Vector] Simplify code a bit. NFCI.

Added: 
    

Modified: 
    mlir/lib/Dialect/Vector/VectorOps.cpp
    mlir/lib/Dialect/Vector/VectorTransforms.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Vector/VectorOps.cpp b/mlir/lib/Dialect/Vector/VectorOps.cpp
index d16c7c3d6fdb..c788d4ccb4a0 100644
--- a/mlir/lib/Dialect/Vector/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/VectorOps.cpp
@@ -184,9 +184,9 @@ static ParseResult parseContractionOp(OpAsmParser &parser,
   auto lhsType = types[0].cast<VectorType>();
   auto rhsType = types[1].cast<VectorType>();
   auto maskElementType = parser.getBuilder().getI1Type();
-  SmallVector<Type, 2> maskTypes;
-  maskTypes.push_back(VectorType::get(lhsType.getShape(), maskElementType));
-  maskTypes.push_back(VectorType::get(rhsType.getShape(), maskElementType));
+  std::array<Type, 2> maskTypes = {
+      VectorType::get(lhsType.getShape(), maskElementType),
+      VectorType::get(rhsType.getShape(), maskElementType)};
   if (parser.resolveOperands(masksInfo, maskTypes, loc, result.operands))
     return failure();
   return success();
@@ -462,12 +462,10 @@ std::vector<std::pair<int64_t, int64_t>> ContractionOp::getBatchDimMap() {
 }
 
 SmallVector<AffineMap, 4> ContractionOp::getIndexingMaps() {
-  SmallVector<AffineMap, 4> res;
-  auto mapAttrs = indexing_maps().getValue();
-  res.reserve(mapAttrs.size());
-  for (auto mapAttr : mapAttrs)
-    res.push_back(mapAttr.cast<AffineMapAttr>().getValue());
-  return res;
+  return llvm::to_vector<4>(
+      llvm::map_range(indexing_maps().getValue(), [](Attribute mapAttr) {
+        return mapAttr.cast<AffineMapAttr>().getValue();
+      }));
 }
 
 Optional<SmallVector<int64_t, 4>> ContractionOp::getShapeForUnroll() {
@@ -1854,8 +1852,7 @@ LogicalResult TransferWriteOp::fold(ArrayRef<Attribute>,
 }
 
 Optional<SmallVector<int64_t, 4>> TransferWriteOp::getShapeForUnroll() {
-  auto s = getVectorType().getShape();
-  return SmallVector<int64_t, 4>{s.begin(), s.end()};
+  return llvm::to_vector<4>(getVectorType().getShape());
 }
 
 //===----------------------------------------------------------------------===//
@@ -2014,11 +2011,8 @@ static SmallVector<int64_t, 8> extractShape(MemRefType memRefType) {
   auto vectorType = memRefType.getElementType().dyn_cast<VectorType>();
   SmallVector<int64_t, 8> res(memRefType.getShape().begin(),
                               memRefType.getShape().end());
-  if (vectorType) {
-    res.reserve(memRefType.getRank() + vectorType.getRank());
-    for (auto s : vectorType.getShape())
-      res.push_back(s);
-  }
+  if (vectorType)
+    res.append(vectorType.getShape().begin(), vectorType.getShape().end());
   return res;
 }
 

diff  --git a/mlir/lib/Dialect/Vector/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/VectorTransforms.cpp
index ab93ef406024..197b1c62274b 100644
--- a/mlir/lib/Dialect/Vector/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/VectorTransforms.cpp
@@ -1707,7 +1707,7 @@ void ContractionOpToOuterProductOpLowering::rewrite(
   auto infer = [](MapList m) { return AffineMap::inferFromExprList(m); };
   AffineExpr m, n, k;
   bindDims(rewriter.getContext(), m, n, k);
-  SmallVector<int64_t, 2> perm{1, 0};
+  static constexpr std::array<int64_t, 2> perm = {1, 0};
   auto iteratorTypes = op.iterator_types().getValue();
   SmallVector<AffineMap, 4> maps = op.getIndexingMaps();
   if (isParallelIterator(iteratorTypes[0]) &&
@@ -1911,10 +1911,10 @@ Value ContractionOpLowering::lowerParallel(vector::ContractionOp op,
   assert(lookup.hasValue() && "parallel index not listed in reduction");
   int64_t resIndex = lookup.getValue();
   // Construct new iterator types and affine map array attribute.
-  SmallVector<AffineMap, 4> lowIndexingMaps;
-  lowIndexingMaps.push_back(adjustMap(iMap[0], iterIndex, rewriter));
-  lowIndexingMaps.push_back(adjustMap(iMap[1], iterIndex, rewriter));
-  lowIndexingMaps.push_back(adjustMap(iMap[2], iterIndex, rewriter));
+  std::array<AffineMap, 3> lowIndexingMaps = {
+      adjustMap(iMap[0], iterIndex, rewriter),
+      adjustMap(iMap[1], iterIndex, rewriter),
+      adjustMap(iMap[2], iterIndex, rewriter)};
   auto lowAffine = rewriter.getAffineMapArrayAttr(lowIndexingMaps);
   auto lowIter =
       rewriter.getArrayAttr(adjustIter(op.iterator_types(), iterIndex));
@@ -1962,10 +1962,10 @@ Value ContractionOpLowering::lowerReduction(vector::ContractionOp op,
                                                 op.acc());
   }
   // Construct new iterator types and affine map array attribute.
-  SmallVector<AffineMap, 4> lowIndexingMaps;
-  lowIndexingMaps.push_back(adjustMap(iMap[0], iterIndex, rewriter));
-  lowIndexingMaps.push_back(adjustMap(iMap[1], iterIndex, rewriter));
-  lowIndexingMaps.push_back(adjustMap(iMap[2], iterIndex, rewriter));
+  std::array<AffineMap, 3> lowIndexingMaps = {
+      adjustMap(iMap[0], iterIndex, rewriter),
+      adjustMap(iMap[1], iterIndex, rewriter),
+      adjustMap(iMap[2], iterIndex, rewriter)};
   auto lowAffine = rewriter.getAffineMapArrayAttr(lowIndexingMaps);
   auto lowIter =
       rewriter.getArrayAttr(adjustIter(op.iterator_types(), iterIndex));


        


More information about the Mlir-commits mailing list