[Mlir-commits] [mlir] 7a69a9d - [NFC][mlir] VectorUtils / IndexingUtils simplifications and cleanups

Nicolas Vasilache llvmlistbot at llvm.org
Tue Nov 22 23:42:34 PST 2022


Author: Nicolas Vasilache
Date: 2022-11-22T23:42:29-08:00
New Revision: 7a69a9d7aee2c5ccbf160b08356a0fc6acbebf75

URL: https://github.com/llvm/llvm-project/commit/7a69a9d7aee2c5ccbf160b08356a0fc6acbebf75
DIFF: https://github.com/llvm/llvm-project/commit/7a69a9d7aee2c5ccbf160b08356a0fc6acbebf75.diff

LOG: [NFC][mlir] VectorUtils / IndexingUtils simplifications and cleanups

This revision refactors and cleans up a bunch of vector, shape, and indexing infrastructure into more reusable APIs.

Differential Revision: https://reviews.llvm.org/D138501

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Utils/IndexingUtils.h
    mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
    mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h
    mlir/lib/Conversion/MathToFuncs/MathToFuncs.cpp
    mlir/lib/Conversion/MathToLibm/MathToLibm.cpp
    mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp
    mlir/lib/Dialect/Utils/IndexingUtils.cpp
    mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
    mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp
    mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp
    mlir/test/lib/Dialect/Affine/TestVectorizationUtils.cpp
    mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
    mlir/test/lib/Interfaces/TilingInterface/TestTilingInterface.cpp

Removed: 
    


################################################################################
diff --git a/mlir/include/mlir/Dialect/Utils/IndexingUtils.h b/mlir/include/mlir/Dialect/Utils/IndexingUtils.h
index c1857462d4c67..ee1d4550e953c 100644
--- a/mlir/include/mlir/Dialect/Utils/IndexingUtils.h
+++ b/mlir/include/mlir/Dialect/Utils/IndexingUtils.h
@@ -28,8 +28,39 @@ int64_t linearize(ArrayRef<int64_t> offsets, ArrayRef<int64_t> basis);
 /// Given the strides together with a linear index in the dimension
 /// space, returns the vector-space offsets in each dimension for a
 /// de-linearized index.
-SmallVector<int64_t, 4> delinearize(ArrayRef<int64_t> strides,
-                                    int64_t linearIndex);
+SmallVector<int64_t> delinearize(ArrayRef<int64_t> strides,
+                                 int64_t linearIndex);
+
+/// Given a set of sizes, compute and return the strides (i.e. the number of
+/// linear indices to skip along the (k-1) most minor dimensions to get the next
+/// k-slice). This is also the basis that one can use to linearize an n-D offset
+/// confined to `[0 .. sizes]`.
+SmallVector<int64_t> computeStrides(ArrayRef<int64_t> sizes);
+
+/// Return the elementwise product of `v1` and `v2`, i.e. zip them and
+/// multiply the corresponding entries.
+SmallVector<int64_t> computeElementwiseMul(ArrayRef<int64_t> v1,
+                                           ArrayRef<int64_t> v2);
+
+/// Compute and return the multi-dimensional integral ratio of `subShape` to
+/// the trailing dimensions of `shape`. This represents how many times
+/// `subShape` fits within `shape`.
+/// If integral division is not possible, return None.
+/// Both shapes are assumed (and enforced) to only contain strictly positive
+/// values.
+///
+/// Examples:
+///   - computeShapeRatio({3, 4, 5, 8}, {2, 5, 2}) returns {3, 2, 1, 4}.
+///   - computeShapeRatio({3, 8}, {2, 5, 2}) returns None (subShape has
+///     higher rank).
+///   - computeShapeRatio({42, 2, 10, 32}, {2, 5, 2}) returns {42, 1, 2, 16},
+///     derived as {42(leading shape dim), 2/2, 10/5, 32/2}.
+///   - computeShapeRatio({42, 2, 11, 32}, {2, 5, 2}) returns None, which is
+///     derived as {42(leading shape dim), 2/2, 11/5(not divisible), 32/2}.
+Optional<SmallVector<int64_t>> computeShapeRatio(ArrayRef<int64_t> shape,
+                                                 ArrayRef<int64_t> subShape);
+
+/// Return the product of the elements of `basis` (i.e. the number of linear
+/// indices, which is one past the maximum linear index). Return `0` if
+/// `basis` is empty.
+int64_t computeMaxLinearIndex(ArrayRef<int64_t> basis);
 
 /// Apply the permutation defined by `permutation` to `inVec`.
 /// Element `i` in `inVec` is mapped to location `j = permutation[i]`.
@@ -45,16 +76,15 @@ void applyPermutationToVector(SmallVector<T, N> &inVec,
 }
 
 /// Helper that returns a subset of `arrayAttr` as a vector of int64_t.
-SmallVector<int64_t, 4> getI64SubArray(ArrayAttr arrayAttr,
-                                       unsigned dropFront = 0,
-                                       unsigned dropBack = 0);
+SmallVector<int64_t> getI64SubArray(ArrayAttr arrayAttr, unsigned dropFront = 0,
+                                    unsigned dropBack = 0);
 
 /// Computes and returns linearized affine expression w.r.t. `basis`.
 mlir::AffineExpr getLinearAffineExpr(ArrayRef<int64_t> basis, mlir::Builder &b);
 
 /// Given the strides in the dimension space, returns the affine expressions for
 /// vector-space offsets in each dimension for a de-linearized index.
-SmallVector<mlir::AffineExpr, 4>
+SmallVector<mlir::AffineExpr>
 getDelinearizedAffineExpr(ArrayRef<int64_t> strides, mlir::Builder &b);
 
 } // namespace mlir

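To make the consolidated indexing API concrete, here is a minimal standalone
sketch (plain C++, no MLIR dependency) that mirrors the documented row-major
semantics of computeStrides and delinearize. The helper names are local
stand-ins for illustration, not the MLIR declarations above.

#include <cassert>
#include <cstdint>
#include <vector>

// Mirrors computeStrides: row-major strides, innermost stride is 1.
std::vector<int64_t> strideSketch(const std::vector<int64_t> &sizes) {
  std::vector<int64_t> strides(sizes.size(), 1);
  for (int64_t r = static_cast<int64_t>(strides.size()) - 2; r >= 0; --r)
    strides[r] = strides[r + 1] * sizes[r + 1];
  return strides;
}

// Mirrors delinearize: split a linear index into per-dimension offsets.
std::vector<int64_t> delinearizeSketch(const std::vector<int64_t> &strides,
                                       int64_t linearIndex) {
  std::vector<int64_t> offsets(strides.size());
  for (size_t r = 0; r < strides.size(); ++r) {
    offsets[r] = linearIndex / strides[r];
    linearIndex %= strides[r];
  }
  return offsets;
}

int main() {
  // sizes {2, 5, 2} -> strides {10, 2, 1}.
  auto strides = strideSketch({2, 5, 2});
  assert((strides == std::vector<int64_t>{10, 2, 1}));
  // 13 = 1*10 + 1*2 + 1*1 -> offsets {1, 1, 1}.
  assert((delinearizeSketch(strides, 13) == std::vector<int64_t>{1, 1, 1}));
  return 0;
}
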
diff --git a/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h b/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
index 15756617792ac..0ab03085f5ae1 100644
--- a/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
+++ b/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
@@ -111,7 +111,7 @@ struct UnrollVectorOptions {
   }
 
   using NativeShapeFnType =
-      std::function<Optional<SmallVector<int64_t, 4>>(Operation *op)>;
+      std::function<Optional<SmallVector<int64_t>>(Operation *op)>;
   /// Function that returns the shape of the vector to unroll to for a given
   /// operation. The unrolling is aborted if the function returns `llvm::None`.
   NativeShapeFnType nativeShape = nullptr;
@@ -122,8 +122,8 @@ struct UnrollVectorOptions {
 
   /// Set the native shape to use for unrolling.
   UnrollVectorOptions &setNativeShape(ArrayRef<int64_t> shape) {
-    SmallVector<int64_t, 4> tsShape(shape.begin(), shape.end());
-    nativeShape = [=](Operation *) -> Optional<SmallVector<int64_t, 4>> {
+    SmallVector<int64_t> tsShape(shape.begin(), shape.end());
+    nativeShape = [=](Operation *) -> Optional<SmallVector<int64_t>> {
       return tsShape;
     };
     return *this;

diff --git a/mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h b/mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h
index b5e6bc1ae5747..3c9f1f41419ad 100644
--- a/mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h
+++ b/mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h
@@ -36,43 +36,6 @@ namespace vector {
 Value createOrFoldDimOp(OpBuilder &b, Location loc, Value source, int64_t dim);
 } // namespace vector
 
-/// Return the number of elements of basis, `0` if empty.
-int64_t computeMaxLinearIndex(ArrayRef<int64_t> basis);
-
-/// Given the shape and sizes of a vector, returns the corresponding
-/// strides for each dimension.
-/// TODO: needs better doc of how it is used.
-SmallVector<int64_t, 4> computeStrides(ArrayRef<int64_t> shape,
-                                       ArrayRef<int64_t> sizes);
-
-/// Given the target sizes of a vector, together with vector-space offsets,
-/// returns the element-space offsets for each dimension.
-SmallVector<int64_t, 4>
-computeElementOffsetsFromVectorSliceOffsets(ArrayRef<int64_t> sizes,
-                                            ArrayRef<int64_t> vectorOffsets);
-
-/// Computes and returns the multi-dimensional ratio of `superShape` to
-/// `subShape`. This is calculated by performing a traversal from minor to major
-/// dimensions (i.e. in reverse shape order). If integral division is not
-/// possible, returns None.
-/// The ArrayRefs are assumed (and enforced) to only contain > 1 values.
-/// This constraint comes from the fact that they are meant to be used with
-/// VectorTypes, for which the property holds by construction.
-///
-/// Examples:
-///   - shapeRatio({3, 4, 5, 8}, {2, 5, 2}) returns {3, 2, 1, 4}
-///   - shapeRatio({3, 4, 4, 8}, {2, 5, 2}) returns None
-///   - shapeRatio({1, 2, 10, 32}, {2, 5, 2}) returns {1, 1, 2, 16}
-Optional<SmallVector<int64_t, 4>> shapeRatio(ArrayRef<int64_t> superShape,
-                                             ArrayRef<int64_t> subShape);
-
-/// Computes and returns the multi-dimensional ratio of the shapes of
-/// `superVector` to `subVector`. If integral division is not possible, returns
-/// None.
-/// Assumes and enforces that the VectorTypes have the same elemental type.
-Optional<SmallVector<int64_t, 4>> shapeRatio(VectorType superVectorType,
-                                             VectorType subVectorType);
-
 /// Constructs a permutation map of invariant memref indices to vector
 /// dimension.
 ///

diff --git a/mlir/lib/Conversion/MathToFuncs/MathToFuncs.cpp b/mlir/lib/Conversion/MathToFuncs/MathToFuncs.cpp
index a0eee716df4a9..d88e7795f0de4 100644
--- a/mlir/lib/Conversion/MathToFuncs/MathToFuncs.cpp
+++ b/mlir/lib/Conversion/MathToFuncs/MathToFuncs.cpp
@@ -80,8 +80,7 @@ VecOpToScalarOp<Op>::matchAndRewrite(Op op, PatternRewriter &rewriter) const {
   Value result = rewriter.create<arith::ConstantOp>(
       loc, DenseElementsAttr::get(
                vecType, IntegerAttr::get(vecType.getElementType(), 0)));
-  SmallVector<int64_t> ones(shape.size(), 1);
-  SmallVector<int64_t> strides = computeStrides(shape, ones);
+  SmallVector<int64_t> strides = computeStrides(shape);
   for (int64_t linearIndex = 0; linearIndex < numElements; ++linearIndex) {
     SmallVector<int64_t> positions = delinearize(strides, linearIndex);
     SmallVector<Value> operands;

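The hunk above (and the identical one in MathToLibm below) is the standard
scalarization walk: compute the strides of the vector shape once, then recover
each element's n-D position from its linear index. A self-contained sketch of
that walk, using a hypothetical {2, 3} shape:

#include <cstdint>
#include <cstdio>
#include <vector>

int main() {
  std::vector<int64_t> shape = {2, 3};
  // computeStrides(shape) -> {3, 1}.
  std::vector<int64_t> strides(shape.size(), 1);
  for (int64_t r = static_cast<int64_t>(shape.size()) - 2; r >= 0; --r)
    strides[r] = strides[r + 1] * shape[r + 1];
  int64_t numElements = shape[0] * shape[1];
  for (int64_t linearIndex = 0; linearIndex < numElements; ++linearIndex) {
    // delinearize(strides, linearIndex) -> the element's n-D position.
    int64_t rem = linearIndex;
    std::vector<int64_t> positions(shape.size());
    for (size_t r = 0; r < strides.size(); ++r) {
      positions[r] = rem / strides[r];
      rem %= strides[r];
    }
    // Prints 0->[0,0], 1->[0,1], 2->[0,2], 3->[1,0], 4->[1,1], 5->[1,2].
    std::printf("%lld -> [%lld, %lld]\n", (long long)linearIndex,
                (long long)positions[0], (long long)positions[1]);
  }
  return 0;
}
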
diff --git a/mlir/lib/Conversion/MathToLibm/MathToLibm.cpp b/mlir/lib/Conversion/MathToLibm/MathToLibm.cpp
index a412dbd0c6e03..d40666d6608c5 100644
--- a/mlir/lib/Conversion/MathToLibm/MathToLibm.cpp
+++ b/mlir/lib/Conversion/MathToLibm/MathToLibm.cpp
@@ -79,8 +79,7 @@ VecOpToScalarOp<Op>::matchAndRewrite(Op op, PatternRewriter &rewriter) const {
   Value result = rewriter.create<arith::ConstantOp>(
       loc, DenseElementsAttr::get(
                vecType, FloatAttr::get(vecType.getElementType(), 0.0)));
-  SmallVector<int64_t> ones(shape.size(), 1);
-  SmallVector<int64_t> strides = computeStrides(shape, ones);
+  SmallVector<int64_t> strides = computeStrides(shape);
   for (auto linearIndex = 0; linearIndex < numElements; ++linearIndex) {
     SmallVector<int64_t> positions = delinearize(strides, linearIndex);
     SmallVector<Value> operands;

diff --git a/mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp b/mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp
index 6413d50948962..1f004b1723a7a 100644
--- a/mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp
+++ b/mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp
@@ -127,15 +127,13 @@ handleMultidimensionalVectors(ImplicitLocOpBuilder &builder,
 
   // Iterate over all outer dimensions of the compute shape vector type.
   auto iterationDims = ArrayRef<int64_t>(expandedShape).drop_back();
-  int64_t maxLinearIndex = computeMaxLinearIndex(iterationDims);
-
-  SmallVector<int64_t> ones(iterationDims.size(), 1);
-  auto strides = computeStrides(iterationDims, ones);
+  int64_t maxIndex = computeMaxLinearIndex(iterationDims);
+  auto strides = computeStrides(iterationDims);
 
   // Compute results for each one dimensional vector.
-  SmallVector<Value> results(maxLinearIndex);
+  SmallVector<Value> results(maxIndex);
 
-  for (int64_t i = 0; i < maxLinearIndex; ++i) {
+  for (int64_t i = 0; i < maxIndex; ++i) {
     auto offsets = delinearize(strides, i);
 
     SmallVector<Value> extracted(expandedOperands.size());
@@ -152,7 +150,7 @@ handleMultidimensionalVectors(ImplicitLocOpBuilder &builder,
   Value result = builder.create<arith::ConstantOp>(
       resultExpandedType, builder.getZeroAttr(resultExpandedType));
 
-  for (int64_t i = 0; i < maxLinearIndex; ++i)
+  for (int64_t i = 0; i < maxIndex; ++i)
     result = builder.create<vector::InsertOp>(results[i], result,
                                               delinearize(strides, i));
 

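handleMultidimensionalVectors iterates over every outer dimension of the
expanded shape and treats each linear index as naming one innermost 1-D slice.
A standalone sketch of just that index math, for a hypothetical 2x3x8 expanded
shape:

#include <cassert>
#include <cstdint>
#include <vector>

int main() {
  std::vector<int64_t> expandedShape = {2, 3, 8};
  // iterationDims = drop_back(expandedShape) = {2, 3}.
  std::vector<int64_t> iterationDims(expandedShape.begin(),
                                     expandedShape.end() - 1);
  // computeMaxLinearIndex: product of the iteration dims -> 6 slices.
  int64_t maxIndex = 1;
  for (int64_t d : iterationDims)
    maxIndex *= d;
  assert(maxIndex == 6);
  // computeStrides(iterationDims) = {3, 1}; delinearize(strides, i) yields
  // the outer-dimension offsets fed to vector::ExtractOp / vector::InsertOp.
  std::vector<int64_t> strides = {3, 1};
  int64_t i = 4;
  std::vector<int64_t> offsets = {i / strides[0],
                                  (i % strides[0]) / strides[1]};
  assert((offsets == std::vector<int64_t>{1, 1})); // slice [1][1] of the 6
  return 0;
}
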
diff --git a/mlir/lib/Dialect/Utils/IndexingUtils.cpp b/mlir/lib/Dialect/Utils/IndexingUtils.cpp
index b9901b7af0ff1..1c7a89cd755f3 100644
--- a/mlir/lib/Dialect/Utils/IndexingUtils.cpp
+++ b/mlir/lib/Dialect/Utils/IndexingUtils.cpp
@@ -12,6 +12,53 @@
 #include "mlir/IR/Builders.h"
 #include "mlir/IR/BuiltinAttributes.h"
 
+#include <numeric>
+
+using namespace mlir;
+
+SmallVector<int64_t> mlir::computeStrides(ArrayRef<int64_t> sizes) {
+  SmallVector<int64_t> strides(sizes.size(), 1);
+  for (int64_t r = strides.size() - 2; r >= 0; --r)
+    strides[r] = strides[r + 1] * sizes[r + 1];
+  return strides;
+}
+
+SmallVector<int64_t> mlir::computeElementwiseMul(ArrayRef<int64_t> v1,
+                                                 ArrayRef<int64_t> v2) {
+  SmallVector<int64_t> result;
+  for (auto it : llvm::zip(v1, v2))
+    result.push_back(std::get<0>(it) * std::get<1>(it));
+  return result;
+}
+
+Optional<SmallVector<int64_t>>
+mlir::computeShapeRatio(ArrayRef<int64_t> shape, ArrayRef<int64_t> subShape) {
+  if (shape.size() < subShape.size())
+    return None;
+  assert(llvm::all_of(shape, [](int64_t s) { return s > 0; }) &&
+         "shape must be strictly positive");
+  assert(llvm::all_of(subShape, [](int64_t s) { return s > 0; }) &&
+         "subShape must be strictly positive");
+
+  // Starting from the end, compute the integer divisors.
+  std::vector<int64_t> result;
+  result.reserve(shape.size());
+  for (auto [size, subSize] :
+       llvm::zip(llvm::reverse(shape), llvm::reverse(subShape))) {
+    // If integral division does not occur, return and let the caller decide.
+    if (size % subSize != 0)
+      return None;
+    result.push_back(size / subSize);
+  }
+  // At this point we computed the ratio (in reverse) for the common size.
+  // Fill with the remaining entries from the shape (still in reverse).
+  int commonSize = subShape.size();
+  std::copy(shape.rbegin() + commonSize, shape.rend(),
+            std::back_inserter(result));
+  // Reverse again to get it back in the proper order and return.
+  return SmallVector<int64_t>{result.rbegin(), result.rend()};
+}
+
 int64_t mlir::linearize(ArrayRef<int64_t> offsets, ArrayRef<int64_t> basis) {
   assert(offsets.size() == basis.size());
   int64_t linearIndex = 0;
@@ -20,10 +67,10 @@ int64_t mlir::linearize(ArrayRef<int64_t> offsets, ArrayRef<int64_t> basis) {
   return linearIndex;
 }
 
-llvm::SmallVector<int64_t, 4> mlir::delinearize(ArrayRef<int64_t> sliceStrides,
-                                                int64_t index) {
+llvm::SmallVector<int64_t> mlir::delinearize(ArrayRef<int64_t> sliceStrides,
+                                             int64_t index) {
   int64_t rank = sliceStrides.size();
-  SmallVector<int64_t, 4> vectorOffsets(rank);
+  SmallVector<int64_t> vectorOffsets(rank);
   for (int64_t r = 0; r < rank; ++r) {
     assert(sliceStrides[r] > 0);
     vectorOffsets[r] = index / sliceStrides[r];
@@ -32,12 +79,19 @@ llvm::SmallVector<int64_t, 4> mlir::delinearize(ArrayRef<int64_t> sliceStrides,
   return vectorOffsets;
 }
 
-llvm::SmallVector<int64_t, 4> mlir::getI64SubArray(ArrayAttr arrayAttr,
-                                                   unsigned dropFront,
-                                                   unsigned dropBack) {
+int64_t mlir::computeMaxLinearIndex(ArrayRef<int64_t> basis) {
+  if (basis.empty())
+    return 0;
+  return std::accumulate(basis.begin(), basis.end(), int64_t(1),
+                         std::multiplies<int64_t>());
+}
+
+llvm::SmallVector<int64_t> mlir::getI64SubArray(ArrayAttr arrayAttr,
+                                                unsigned dropFront,
+                                                unsigned dropBack) {
   assert(arrayAttr.size() > dropFront + dropBack && "Out of bounds");
   auto range = arrayAttr.getAsRange<IntegerAttr>();
-  SmallVector<int64_t, 4> res;
+  SmallVector<int64_t> res;
   res.reserve(arrayAttr.size() - dropFront - dropBack);
   for (auto it = range.begin() + dropFront, eit = range.end() - dropBack;
        it != eit; ++it)
@@ -54,11 +108,11 @@ mlir::AffineExpr mlir::getLinearAffineExpr(ArrayRef<int64_t> basis,
   return resultExpr;
 }
 
-llvm::SmallVector<mlir::AffineExpr, 4>
+llvm::SmallVector<mlir::AffineExpr>
 mlir::getDelinearizedAffineExpr(mlir::ArrayRef<int64_t> strides, Builder &b) {
   AffineExpr resultExpr = b.getAffineDimExpr(0);
   int64_t rank = strides.size();
-  SmallVector<AffineExpr, 4> vectorOffsets(rank);
+  SmallVector<AffineExpr> vectorOffsets(rank);
   vectorOffsets[0] = resultExpr.floorDiv(strides[0]);
   resultExpr = resultExpr % strides[0];
   for (unsigned i = 1; i < rank; i++) {

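A standalone sketch of the computeShapeRatio contract implemented above,
written forward (leading dimensions first) instead of in reverse; the
assertions reproduce the examples from the header doc comment.

#include <cassert>
#include <cstdint>
#include <optional>
#include <vector>

std::optional<std::vector<int64_t>>
shapeRatioSketch(const std::vector<int64_t> &shape,
                 const std::vector<int64_t> &subShape) {
  if (shape.size() < subShape.size())
    return std::nullopt;
  size_t lead = shape.size() - subShape.size();
  // Leading dimensions not covered by subShape are carried through.
  std::vector<int64_t> result(shape.begin(), shape.begin() + lead);
  for (size_t i = 0; i < subShape.size(); ++i) {
    if (shape[lead + i] % subShape[i] != 0)
      return std::nullopt; // Not an integral ratio.
    result.push_back(shape[lead + i] / subShape[i]);
  }
  return result;
}

int main() {
  assert((*shapeRatioSketch({3, 4, 5, 8}, {2, 5, 2}) ==
          std::vector<int64_t>{3, 2, 1, 4}));
  assert(!shapeRatioSketch({3, 8}, {2, 5, 2})); // subShape has higher rank
  assert((*shapeRatioSketch({42, 2, 10, 32}, {2, 5, 2}) ==
          std::vector<int64_t>{42, 1, 2, 16}));
  assert(!shapeRatioSketch({42, 2, 11, 32}, {2, 5, 2})); // 11 % 5 != 0
  return 0;
}
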
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
index 0ce46c9bd685e..c41e66325c011 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
@@ -54,9 +54,9 @@ static Optional<int64_t> getResultIndex(AffineMap map, int64_t index) {
 }
 
 // Helper to construct iterator types with one index removed.
-static SmallVector<Attribute, 4> adjustIter(ArrayAttr iteratorTypes,
-                                            int64_t index) {
-  SmallVector<Attribute, 4> results;
+static SmallVector<Attribute> adjustIter(ArrayAttr iteratorTypes,
+                                         int64_t index) {
+  SmallVector<Attribute> results;
   for (const auto &it : llvm::enumerate(iteratorTypes)) {
     int64_t idx = it.index();
     if (idx == index)
@@ -70,7 +70,7 @@ static SmallVector<Attribute, 4> adjustIter(ArrayAttr iteratorTypes,
 static AffineMap adjustMap(AffineMap map, int64_t index,
                            PatternRewriter &rewriter) {
   auto *ctx = rewriter.getContext();
-  SmallVector<AffineExpr, 4> results;
+  SmallVector<AffineExpr> results;
   for (int64_t i = 0, e = map.getNumResults(); i < e; ++i) {
     int64_t idx = map.getDimPosition(i);
     if (idx == index)
@@ -140,7 +140,7 @@ static Value reshapeStore(Location loc, Value val, Value result,
 }
 
 template <typename IntType>
-static SmallVector<IntType, 4> extractVector(ArrayAttr arrayAttr) {
+static SmallVector<IntType> extractVector(ArrayAttr arrayAttr) {
   return llvm::to_vector<4>(llvm::map_range(
       arrayAttr.getAsRange<IntegerAttr>(),
       [](IntegerAttr attr) { return static_cast<IntType>(attr.getInt()); }));
@@ -399,7 +399,7 @@ class TransposeOpLowering : public OpRewritePattern<vector::TransposeOp> {
     VectorType resType = op.getResultType();
 
     // Set up convenience transposition table.
-    SmallVector<int64_t, 4> transp;
+    SmallVector<int64_t> transp;
     for (auto attr : op.getTransp())
       transp.push_back(attr.cast<IntegerAttr>().getInt());
 
@@ -430,12 +430,11 @@ class TransposeOpLowering : public OpRewritePattern<vector::TransposeOp> {
     // in vector form to improve performance. Therefore, we prune those
     // dimensions from the shape/transpose data structures used to generate the
     // extract/insert ops.
-    SmallVector<int64_t, 4> prunedTransp;
+    SmallVector<int64_t> prunedTransp;
     pruneNonTransposedDims(transp, prunedTransp);
     size_t numPrunedDims = transp.size() - prunedTransp.size();
     auto prunedInShape = inputType.getShape().drop_back(numPrunedDims);
-    SmallVector<int64_t, 4> ones(prunedInShape.size(), 1);
-    auto prunedInStrides = computeStrides(prunedInShape, ones);
+    auto prunedInStrides = computeStrides(prunedInShape);
 
     // Generates the extract/insert operations for every scalar/vector element
     // of the leftmost transposed dimensions. We traverse every transpose
@@ -448,7 +447,7 @@ class TransposeOpLowering : public OpRewritePattern<vector::TransposeOp> {
     for (int64_t linearIdx = 0; linearIdx < numTransposedElements;
          ++linearIdx) {
       auto extractIdxs = delinearize(prunedInStrides, linearIdx);
-      SmallVector<int64_t, 4> insertIdxs(extractIdxs);
+      SmallVector<int64_t> insertIdxs(extractIdxs);
       applyPermutationToVector(insertIdxs, prunedTransp);
       Value extractOp =
           rewriter.create<vector::ExtractOp>(loc, input, extractIdxs);
@@ -488,7 +487,7 @@ class TransposeOp2DToShuffleLowering
     if (srcType.getRank() != 2)
       return rewriter.notifyMatchFailure(op, "Not a 2D transpose");
 
-    SmallVector<int64_t, 4> transp;
+    SmallVector<int64_t> transp;
     for (auto attr : op.getTransp())
       transp.push_back(attr.cast<IntegerAttr>().getInt());
     if (transp[0] != 1 && transp[1] != 0)
@@ -685,8 +684,8 @@ struct ContractOpToElementwise
     bool isInt = contractOp.getLhsType().getElementType().isIntOrIndex();
     newLhs = rewriter.create<vector::TransposeOp>(loc, newLhs, lhsTranspose);
     newRhs = rewriter.create<vector::TransposeOp>(loc, newRhs, rhsTranspose);
-    SmallVector<int64_t, 4> lhsOffsets(lhsReductionDims.size(), 0);
-    SmallVector<int64_t, 4> rhsOffsets(rhsReductionDims.size(), 0);
+    SmallVector<int64_t> lhsOffsets(lhsReductionDims.size(), 0);
+    SmallVector<int64_t> rhsOffsets(rhsReductionDims.size(), 0);
     newLhs = rewriter.create<vector::ExtractOp>(
         loc, newLhs, rewriter.getI64ArrayAttr(lhsOffsets));
     newRhs = rewriter.create<vector::ExtractOp>(
@@ -752,7 +751,7 @@ class ConstantMaskOpLowering : public OpRewritePattern<vector::ConstantMaskOp> {
     if (rank == 1) {
       // Express constant 1-D case in explicit vector form:
       //   [T,..,T,F,..,F].
-      SmallVector<bool, 4> values(dstType.getDimSize(0));
+      SmallVector<bool> values(dstType.getDimSize(0));
       for (int64_t d = 0; d < trueDim; d++)
         values[d] = true;
       rewriter.replaceOpWithNewOp<arith::ConstantOp>(
@@ -762,7 +761,7 @@ class ConstantMaskOpLowering : public OpRewritePattern<vector::ConstantMaskOp> {
 
     VectorType lowType =
         VectorType::get(dstType.getShape().drop_front(), eltType);
-    SmallVector<int64_t, 4> newDimSizes;
+    SmallVector<int64_t> newDimSizes;
     for (int64_t r = 1; r < rank; r++)
       newDimSizes.push_back(dimSizes[r].cast<IntegerAttr>().getInt());
     Value trueVal = rewriter.create<vector::ConstantMaskOp>(
@@ -931,8 +930,8 @@ class ShapeCastOpRewritePattern : public OpRewritePattern<vector::ShapeCastOp> {
     //    x[0,1,0] = y[0,2]
     // etc., incrementing the two index vectors "row-major"
     // within the source and result shape.
-    SmallVector<int64_t, 4> srcIdx(srcRank);
-    SmallVector<int64_t, 4> resIdx(resRank);
+    SmallVector<int64_t> srcIdx(srcRank);
+    SmallVector<int64_t> resIdx(resRank);
     Value result = rewriter.create<arith::ConstantOp>(
         loc, resultVectorType, rewriter.getZeroAttr(resultVectorType));
     for (int64_t i = 0; i < numElts; i++) {
@@ -948,7 +947,7 @@ class ShapeCastOpRewritePattern : public OpRewritePattern<vector::ShapeCastOp> {
   }
 
 private:
-  static void incIdx(SmallVector<int64_t, 4> &idx, VectorType tp, int64_t r) {
+  static void incIdx(SmallVector<int64_t> &idx, VectorType tp, int64_t r) {
     assert(0 <= r && r < tp.getRank());
     if (++idx[r] == tp.getDimSize(r)) {
       idx[r] = 0;
@@ -1039,7 +1038,7 @@ struct CombineContractABTranspose final
 
   LogicalResult matchAndRewrite(vector::ContractionOp contractOp,
                                 PatternRewriter &rewriter) const override {
-    SmallVector<AffineMap, 4> maps =
+    SmallVector<AffineMap> maps =
         llvm::to_vector<4>(contractOp.getIndexingMapsArray());
     Value lhs = contractOp.getLhs();
     Value rhs = contractOp.getRhs();
@@ -1169,7 +1168,7 @@ struct CombineContractBroadcast
 
   LogicalResult matchAndRewrite(vector::ContractionOp contractOp,
                                 PatternRewriter &rewriter) const override {
-    SmallVector<AffineMap, 4> maps =
+    SmallVector<AffineMap> maps =
         llvm::to_vector<4>(contractOp.getIndexingMapsArray());
     Value lhs = contractOp.getLhs();
     Value rhs = contractOp.getRhs();
@@ -1234,7 +1233,7 @@ struct CombineContractBroadcast
     for (auto &m : maps)
       m = compressDims(m, unusedDimsBitVector);
     // Compute the combined iterators.
-    SmallVector<Attribute, 4> iterators;
+    SmallVector<Attribute> iterators;
     for (unsigned i = 0; i < unusedDimsBitVector.size(); ++i) {
       if (!unusedDimsBitVector.test(i))
         iterators.push_back(contractOp.getIteratorTypes().getValue()[i]);
@@ -1328,7 +1327,7 @@ struct ReorderElementwiseOpsOnTranspose final
 
     // Make sure all operands are transpose/constant ops and collect their
     // transposition maps.
-    SmallVector<ArrayAttr, 4> transposeMaps;
+    SmallVector<ArrayAttr> transposeMaps;
     transposeMaps.reserve(op->getNumOperands());
     // Record the initial type before transposition. We'll use its shape later.
     // Any type will do here as we will check all transpose maps are the same.
@@ -1350,7 +1349,7 @@ struct ReorderElementwiseOpsOnTranspose final
     if (!llvm::all_equal(transposeMaps))
       return rewriter.notifyMatchFailure(op, "different transpose map");
 
-    SmallVector<Value, 4> srcValues;
+    SmallVector<Value> srcValues;
     srcValues.reserve(op->getNumOperands());
 
     // If there are constant operands, we need to insert inverse transposes for
@@ -1724,7 +1723,7 @@ ContractionOpToDotLowering::matchAndRewrite(vector::ContractionOp op,
   auto infer = [](MapList m) { return AffineMap::inferFromExprList(m); };
   AffineExpr m, n, k;
   bindDims(rewriter.getContext(), m, n, k);
-  SmallVector<AffineMap, 4> maps = op.getIndexingMapsArray();
+  SmallVector<AffineMap> maps = op.getIndexingMapsArray();
   //
   // In the following we wish to make the reduction dimension innermost so we
   // can load vectors and just fmul + reduce into a scalar.
@@ -1940,7 +1939,7 @@ ContractionOpLowering::lowerParallel(vector::ContractionOp op, int64_t lhsIndex,
   VectorType rhsType = op.getRhsType();
   VectorType resType = op.getResultType().cast<VectorType>();
   // Find the iterator type index and result index.
-  SmallVector<AffineMap, 4> iMap = op.getIndexingMapsArray();
+  SmallVector<AffineMap> iMap = op.getIndexingMapsArray();
   int64_t iterIndex = -1;
   int64_t dimSize = -1;
   if (lhsIndex >= 0) {
@@ -2011,7 +2010,7 @@ ContractionOpLowering::lowerReduction(vector::ContractionOp op,
   bool isInt = resType.isa<IntegerType>();
   // Use iterator index 0.
   int64_t iterIndex = 0;
-  SmallVector<AffineMap, 4> iMap = op.getIndexingMapsArray();
+  SmallVector<AffineMap> iMap = op.getIndexingMapsArray();
   Optional<int64_t> lookupLhs = getResultIndex(iMap[0], iterIndex);
   Optional<int64_t> lookupRhs = getResultIndex(iMap[1], iterIndex);
   if (!lookupLhs.has_value())
@@ -2087,7 +2086,7 @@ struct TransferReadToVectorLoadLowering
     if (maxTransferRank && read.getVectorType().getRank() > *maxTransferRank)
       return failure();
 
-    SmallVector<unsigned, 4> broadcastedDims;
+    SmallVector<unsigned> broadcastedDims;
     // Permutations are handled by VectorToSCF or
     // populateVectorTransferPermutationMapLoweringPatterns.
     // We let the 0-d corner case pass-through as it is supported.
@@ -2106,8 +2105,8 @@ struct TransferReadToVectorLoadLowering
     // If there is broadcasting involved then we first load the unbroadcasted
     // vector, and then broadcast it with `vector.broadcast`.
     ArrayRef<int64_t> vectorShape = read.getVectorType().getShape();
-    SmallVector<int64_t, 4> unbroadcastedVectorShape(vectorShape.begin(),
-                                                     vectorShape.end());
+    SmallVector<int64_t> unbroadcastedVectorShape(vectorShape.begin(),
+                                                  vectorShape.end());
     for (unsigned i : broadcastedDims)
       unbroadcastedVectorShape[i] = 1;
     VectorType unbroadcastedVectorType = VectorType::get(
@@ -2286,7 +2285,7 @@ struct TransferWriteToVectorStoreLowering
 };
 
 // Returns the values in `arrayAttr` as an integer vector.
-static SmallVector<int64_t, 4> getIntValueVector(ArrayAttr arrayAttr) {
+static SmallVector<int64_t> getIntValueVector(ArrayAttr arrayAttr) {
   return llvm::to_vector<4>(
       llvm::map_range(arrayAttr.getAsRange<IntegerAttr>(),
                       [](IntegerAttr attr) { return attr.getInt(); }));
@@ -2410,7 +2409,7 @@ struct BubbleDownBitCastForStridedSliceExtract
     // dimension's offset given we are extracting from less elements now.
     ArrayAttr newOffsets = extractOp.getOffsets();
     if (newOffsets.size() == rank) {
-      SmallVector<int64_t, 4> offsets = getIntValueVector(newOffsets);
+      SmallVector<int64_t> offsets = getIntValueVector(newOffsets);
       if (offsets.back() % expandRatio != 0)
         return failure();
       offsets.back() = offsets.back() / expandRatio;
@@ -2420,14 +2419,14 @@ struct BubbleDownBitCastForStridedSliceExtract
     // Similarly for sizes.
     ArrayAttr newSizes = extractOp.getSizes();
     if (newSizes.size() == rank) {
-      SmallVector<int64_t, 4> sizes = getIntValueVector(newSizes);
+      SmallVector<int64_t> sizes = getIntValueVector(newSizes);
       if (sizes.back() % expandRatio != 0)
         return failure();
       sizes.back() = sizes.back() / expandRatio;
       newSizes = rewriter.getI64ArrayAttr(sizes);
     }
 
-    SmallVector<int64_t, 4> dims =
+    SmallVector<int64_t> dims =
         llvm::to_vector<4>(extractOp.getType().cast<VectorType>().getShape());
     dims.back() = dims.back() / expandRatio;
     VectorType newExtractType =
@@ -2500,13 +2499,13 @@ struct BubbleUpBitCastForStridedSliceInsert
 
     ArrayAttr newOffsets = insertOp.getOffsets();
     assert(newOffsets.size() == rank);
-    SmallVector<int64_t, 4> offsets = getIntValueVector(newOffsets);
+    SmallVector<int64_t> offsets = getIntValueVector(newOffsets);
     if (offsets.back() % shrinkRatio != 0)
       return failure();
     offsets.back() = offsets.back() / shrinkRatio;
     newOffsets = rewriter.getI64ArrayAttr(offsets);
 
-    SmallVector<int64_t, 4> srcDims =
+    SmallVector<int64_t> srcDims =
         llvm::to_vector<4>(insertOp.getSourceVectorType().getShape());
     srcDims.back() = srcDims.back() / shrinkRatio;
     VectorType newCastSrcType =
@@ -2515,7 +2514,7 @@ struct BubbleUpBitCastForStridedSliceInsert
     auto newCastSrcOp = rewriter.create<vector::BitCastOp>(
         bitcastOp.getLoc(), newCastSrcType, insertOp.getSource());
 
-    SmallVector<int64_t, 4> dstDims =
+    SmallVector<int64_t> dstDims =
         llvm::to_vector<4>(insertOp.getDestVectorType().getShape());
     dstDims.back() = dstDims.back() / shrinkRatio;
     VectorType newCastDstType =

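The TransposeOpLowering changes above keep the same extract/insert scheme, now
driven by computeStrides/delinearize. A standalone sketch of the index
computation for a 2-D transpose of a <2x3> input with transp = {1, 0},
assuming the convention that result position r takes the extract index
selected by transp[r]:

#include <cstdint>
#include <cstdio>
#include <vector>

int main() {
  std::vector<int64_t> transp = {1, 0};  // the transposition to apply
  std::vector<int64_t> strides = {3, 1}; // computeStrides({2, 3})
  for (int64_t linearIdx = 0; linearIdx < 6; ++linearIdx) {
    // delinearize(prunedInStrides, linearIdx) -> the source element position.
    int64_t rem = linearIdx;
    std::vector<int64_t> extractIdxs(2);
    for (size_t r = 0; r < strides.size(); ++r) {
      extractIdxs[r] = rem / strides[r];
      rem %= strides[r];
    }
    // applyPermutationToVector(insertIdxs, transp), under the convention
    // assumed above: input [i, j] lands at output [j, i].
    std::vector<int64_t> insertIdxs = {extractIdxs[transp[0]],
                                       extractIdxs[transp[1]]};
    std::printf("extract [%lld, %lld] -> insert [%lld, %lld]\n",
                (long long)extractIdxs[0], (long long)extractIdxs[1],
                (long long)insertIdxs[0], (long long)insertIdxs[1]);
  }
  return 0;
}
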
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp
index 8c5d5144304e5..5d07fb3a925b4 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp
@@ -27,24 +27,19 @@ using namespace mlir::vector;
 
 /// During unrolling from `originalShape` to `targetShape` return the offset for
 /// the slice `index`.
-static SmallVector<int64_t, 4> getVectorOffset(ArrayRef<int64_t> originalShape,
-                                               ArrayRef<int64_t> targetShape,
-                                               int64_t index) {
-  SmallVector<int64_t, 4> dstSliceStrides =
-      computeStrides(originalShape, targetShape);
-  SmallVector<int64_t, 4> vectorOffsets = delinearize(dstSliceStrides, index);
-  SmallVector<int64_t, 4> elementOffsets =
-      computeElementOffsetsFromVectorSliceOffsets(targetShape, vectorOffsets);
-  return elementOffsets;
+static SmallVector<int64_t> getVectorOffset(ArrayRef<int64_t> ratioStrides,
+                                            int64_t index,
+                                            ArrayRef<int64_t> targetShape) {
+  return computeElementwiseMul(delinearize(ratioStrides, index), targetShape);
 }
 
-/// A functor that accomplishes the same thing as `getVectorOffset` but allows
-/// for reordering the traversal of the dimensions. The order of traversal is
-/// given in "for loop order" (outer to inner).
+/// A functor that accomplishes the same thing as `getVectorOffset` but
+/// allows for reordering the traversal of the dimensions. The order of
+/// traversal is given in "for loop order" (outer to inner).
 namespace {
 class DecomposeShapeIterator {
 private:
-  SmallVector<int64_t, 4> vectorShape;
+  SmallVector<int64_t> vectorShape;
   SmallVector<int64_t> loopOrder;
   SmallVector<int64_t> sliceStrides;
   int64_t maxIndexVal{1};
@@ -56,15 +51,15 @@ class DecomposeShapeIterator {
       : vectorShape(targetShape.begin(), targetShape.end()),
         loopOrder(loopOrder.begin(), loopOrder.end()),
         sliceStrides(originalShape.size()) {
-    assert(originalShape.size() == targetShape.size());
-    assert(loopOrder.size() == targetShape.size());
+    assert(originalShape.size() >= targetShape.size());
+    assert(loopOrder.size() == originalShape.size());
 
     // Compute the count for each dimension.
-    SmallVector<int64_t> sliceDimCounts(originalShape.size());
-    for (unsigned r = 0; r < originalShape.size(); ++r) {
-      sliceDimCounts[r] = ceilDiv(originalShape[r], targetShape[r]);
-      maxIndexVal *= sliceDimCounts[r];
-    }
+    auto maybeShapeRatio = computeShapeRatio(originalShape, targetShape);
+    assert(maybeShapeRatio && "Shape does not evenly divide");
+    // computeShapeRatio carries leading dimensions not covered by
+    // `targetShape` through verbatim, so the slice counts have full rank.
+    SmallVector<int64_t> sliceDimCounts = *maybeShapeRatio;
+    maxIndexVal = computeMaxLinearIndex(sliceDimCounts);
 
     // Reversing "loop order" gives dimensions from fastest varying to slowest
     // varying (smallest stride to largest stride).
@@ -95,7 +90,7 @@ class DecomposeShapeIterator {
   SmallVector<int64_t> getVectorOffset(int64_t index) const {
     SmallVector<int64_t> vectorOffsets = delinearize(index);
     SmallVector<int64_t> elementOffsets =
-        computeElementOffsetsFromVectorSliceOffsets(vectorShape, vectorOffsets);
+        computeElementwiseMul(vectorShape, vectorOffsets);
     return elementOffsets;
   }
 };
@@ -139,7 +134,7 @@ static Operation *cloneOpWithOperandsAndTypes(OpBuilder &builder, Location loc,
 
 /// Return the target shape for unrolling for the given `op`. Return llvm::None
 /// if the op shouldn't be or cannot be unrolled.
-static Optional<SmallVector<int64_t, 4>>
+static Optional<SmallVector<int64_t>>
 getTargetShape(const vector::UnrollVectorOptions &options, Operation *op) {
   if (options.filterConstraint && failed(options.filterConstraint(op)))
     return llvm::None;
@@ -152,10 +147,10 @@ getTargetShape(const vector::UnrollVectorOptions &options, Operation *op) {
   auto maybeUnrollShape = unrollableVectorOp.getShapeForUnroll();
   if (!maybeUnrollShape)
     return llvm::None;
-  Optional<SmallVector<int64_t, 4>> targetShape = options.nativeShape(op);
+  Optional<SmallVector<int64_t>> targetShape = options.nativeShape(op);
   if (!targetShape)
     return llvm::None;
-  auto maybeShapeRatio = shapeRatio(*maybeUnrollShape, *targetShape);
+  auto maybeShapeRatio = computeShapeRatio(*maybeUnrollShape, *targetShape);
   if (!maybeShapeRatio ||
       llvm::all_of(*maybeShapeRatio, [](int64_t v) { return v == 1; }))
     return llvm::None;
@@ -197,7 +192,7 @@ struct UnrollTransferReadPattern
     if (!targetShape)
       return failure();
     auto sourceVectorType = readOp.getVectorType();
-    SmallVector<int64_t, 4> strides(targetShape->size(), 1);
+    SmallVector<int64_t> strides(targetShape->size(), 1);
     Location loc = readOp.getLoc();
     ArrayRef<int64_t> originalSize = readOp.getVectorType().getShape();
 
@@ -206,17 +201,16 @@ struct UnrollTransferReadPattern
         loc, sourceVectorType, rewriter.getZeroAttr(sourceVectorType));
     auto targetType =
         VectorType::get(*targetShape, sourceVectorType.getElementType());
-    SmallVector<Value, 4> originalIndices(readOp.getIndices().begin(),
-                                          readOp.getIndices().end());
+    SmallVector<Value> originalIndices(readOp.getIndices().begin(),
+                                       readOp.getIndices().end());
 
     SmallVector<int64_t> loopOrder =
         getUnrollOrder(originalSize.size(), readOp, options);
     DecomposeShapeIterator indexToOffsets(originalSize, *targetShape,
                                           loopOrder);
     for (int64_t i = 0; i < indexToOffsets.maxIndex(); i++) {
-      SmallVector<int64_t, 4> elementOffsets =
-          indexToOffsets.getVectorOffset(i);
-      SmallVector<Value, 4> indices =
+      SmallVector<int64_t> elementOffsets = indexToOffsets.getVectorOffset(i);
+      SmallVector<Value> indices =
           sliceTransferIndices(elementOffsets, originalIndices,
                                readOp.getPermutationMap(), loc, rewriter);
       auto slicedRead = rewriter.create<vector::TransferReadOp>(
@@ -255,11 +249,11 @@ struct UnrollTransferWritePattern
     if (!targetShape)
       return failure();
     auto sourceVectorType = writeOp.getVectorType();
-    SmallVector<int64_t, 4> strides(targetShape->size(), 1);
+    SmallVector<int64_t> strides(targetShape->size(), 1);
     Location loc = writeOp.getLoc();
     ArrayRef<int64_t> originalSize = sourceVectorType.getShape();
-    SmallVector<Value, 4> originalIndices(writeOp.getIndices().begin(),
-                                          writeOp.getIndices().end());
+    SmallVector<Value> originalIndices(writeOp.getIndices().begin(),
+                                       writeOp.getIndices().end());
 
     SmallVector<int64_t> loopOrder =
         getUnrollOrder(originalSize.size(), writeOp, options);
@@ -267,11 +261,10 @@ struct UnrollTransferWritePattern
                                           loopOrder);
     Value resultTensor;
     for (int64_t i = 0; i < indexToOffsets.maxIndex(); i++) {
-      SmallVector<int64_t, 4> elementOffsets =
-          indexToOffsets.getVectorOffset(i);
+      SmallVector<int64_t> elementOffsets = indexToOffsets.getVectorOffset(i);
       Value slicedVector = rewriter.create<vector::ExtractStridedSliceOp>(
           loc, writeOp.getVector(), elementOffsets, *targetShape, strides);
-      SmallVector<Value, 4> indices =
+      SmallVector<Value> indices =
           sliceTransferIndices(elementOffsets, originalIndices,
                                writeOp.getPermutationMap(), loc, rewriter);
       Operation *slicedWrite = rewriter.create<vector::TransferWriteOp>(
@@ -321,7 +314,7 @@ struct UnrollContractionPattern
     if (!targetShape)
       return failure();
     auto dstVecType = contractOp.getResultType().cast<VectorType>();
-    SmallVector<int64_t, 4> originalSize = *contractOp.getShapeForUnroll();
+    SmallVector<int64_t> originalSize = *contractOp.getShapeForUnroll();
 
     Location loc = contractOp.getLoc();
     unsigned accIndex = vector::ContractionOp::getAccOperandIndex();
@@ -337,16 +330,16 @@ struct UnrollContractionPattern
                                           loopOrder);
     const int64_t sliceCount = indexToOffsets.maxIndex();
     for (int64_t i = 0; i < sliceCount; i++) {
-      SmallVector<int64_t, 4> offsets = indexToOffsets.getVectorOffset(i);
-      SmallVector<Value, 4> slicesOperands(contractOp.getNumOperands());
+      SmallVector<int64_t> offsets = indexToOffsets.getVectorOffset(i);
+      SmallVector<Value> slicesOperands(contractOp.getNumOperands());
 
-      // Helper to coompute the new shape of each operand and extract the slice.
+      // Helper to compute the new shape of each operand and extract the slice.
       auto extractOperand = [&](unsigned index, Value operand,
                                 AffineMap permutationMap,
                                 ArrayRef<int64_t> operandOffets) {
         SmallVector<int64_t> operandShape = applyPermutationMap(
             permutationMap, ArrayRef<int64_t>(*targetShape));
-        SmallVector<int64_t, 4> operandStrides(operandOffets.size(), 1);
+        SmallVector<int64_t> operandStrides(operandOffets.size(), 1);
         slicesOperands[index] = rewriter.create<vector::ExtractStridedSliceOp>(
             loc, operand, operandOffets, operandShape, operandStrides);
       };
@@ -420,12 +413,12 @@ struct UnrollMultiReductionPattern
 
   LogicalResult matchAndRewrite(vector::MultiDimReductionOp reductionOp,
                                 PatternRewriter &rewriter) const override {
-    Optional<SmallVector<int64_t, 4>> targetShape =
+    Optional<SmallVector<int64_t>> targetShape =
         getTargetShape(options, reductionOp);
     if (!targetShape)
       return failure();
-    SmallVector<int64_t, 4> originalSize = *reductionOp.getShapeForUnroll();
-    SmallVector<int64_t, 4> ratio = *shapeRatio(originalSize, *targetShape);
+    SmallVector<int64_t> originalSize = *reductionOp.getShapeForUnroll();
+    SmallVector<int64_t> ratio = *computeShapeRatio(originalSize, *targetShape);
     llvm::MapVector<
         SmallVector<int64_t>, Value,
         llvm::DenseMap<SmallVector<int64_t>, unsigned, OffsetMapInfo>>
@@ -433,12 +426,16 @@ struct UnrollMultiReductionPattern
     // Compute shape ratio of 'shape' and 'sizes'.
     int64_t sliceCount = computeMaxLinearIndex(ratio);
     Location loc = reductionOp.getLoc();
+
+    // Strides of the ratio: delinearizing a slice index against these gives
+    // the slice coordinates, which scale by `targetShape` to element offsets.
+    auto ratioStrides = computeStrides(ratio);
     for (int64_t i = 0; i < sliceCount; i++) {
-      SmallVector<int64_t, 4> offsets =
-          getVectorOffset(originalSize, *targetShape, i);
+      SmallVector<int64_t> offsets =
+          getVectorOffset(ratioStrides, i, *targetShape);
 
       SmallVector<Value> operands;
-      SmallVector<int64_t, 4> operandStrides(offsets.size(), 1);
+      SmallVector<int64_t> operandStrides(offsets.size(), 1);
       Value slicedOperand = rewriter.create<vector::ExtractStridedSliceOp>(
           loc, reductionOp.getSource(), offsets, *targetShape, operandStrides);
       operands.push_back(slicedOperand);
@@ -451,7 +448,7 @@ struct UnrollMultiReductionPattern
         }
       }
       Value acc;
-      SmallVector<int64_t, 4> accStrides(destOffset.size(), 1);
+      SmallVector<int64_t> accStrides(destOffset.size(), 1);
       // If a version of the accumulator has already been computed, use it
       // otherwise extract the first version from the original operand.
       auto accIt = accCache.find(destOffset);
@@ -500,21 +497,25 @@ struct UnrollElementwisePattern : public RewritePattern {
     if (!targetShape)
       return failure();
     auto dstVecType = op->getResult(0).getType().cast<VectorType>();
-    SmallVector<int64_t, 4> originalSize =
+    SmallVector<int64_t> originalSize =
         *cast<VectorUnrollOpInterface>(op).getShapeForUnroll();
-    SmallVector<int64_t, 4> ratio = *shapeRatio(originalSize, *targetShape);
+    SmallVector<int64_t> ratio = *computeShapeRatio(originalSize, *targetShape);
     int64_t sliceCount = computeMaxLinearIndex(ratio);
     Location loc = op->getLoc();
     // Prepare the result vector.
     Value result = rewriter.create<arith::ConstantOp>(
         loc, dstVecType, rewriter.getZeroAttr(dstVecType));
-    SmallVector<int64_t, 4> strides(targetShape->size(), 1);
+    SmallVector<int64_t> strides(targetShape->size(), 1);
     VectorType newVecType =
         VectorType::get(*targetShape, dstVecType.getElementType());
+
+    // Strides of the ratio: delinearizing a slice index against these gives
+    // the slice coordinates, which scale by `targetShape` to element offsets.
+    auto ratioStrides = computeStrides(ratio);
     for (int64_t i = 0; i < sliceCount; i++) {
-      SmallVector<int64_t, 4> offsets =
-          getVectorOffset(originalSize, *targetShape, i);
-      SmallVector<Value, 4> extractOperands;
+      SmallVector<int64_t> offsets =
+          getVectorOffset(ratioStrides, i, *targetShape);
+      SmallVector<Value> extractOperands;
       for (OpOperand &operand : op->getOpOperands()) {
         auto vecType = operand.get().getType().template dyn_cast<VectorType>();
         if (!vecType) {
@@ -547,19 +548,24 @@ struct UnrollReductionPattern : public OpRewritePattern<vector::ReductionOp> {
 
   LogicalResult matchAndRewrite(vector::ReductionOp reductionOp,
                                 PatternRewriter &rewriter) const override {
-    Optional<SmallVector<int64_t, 4>> targetShape =
+    Optional<SmallVector<int64_t>> targetShape =
         getTargetShape(options, reductionOp);
     if (!targetShape)
       return failure();
     SmallVector<int64_t> originalSize = *reductionOp.getShapeForUnroll();
-    int64_t ratio = (*shapeRatio(originalSize, *targetShape))[0];
+    auto ratio = *computeShapeRatio(originalSize, *targetShape);
+    int64_t sliceCount = ratio[0];
 
     // Create unrolled vector reduction.
     Location loc = reductionOp.getLoc();
     Value accumulator = nullptr;
-    for (int64_t i = 0; i < ratio; ++i) {
+
+    // Strides of the ratio: delinearizing a slice index against these gives
+    // the slice coordinates, which scale by `targetShape` to element offsets.
+    auto ratioStrides = computeStrides(ratio);
+    for (int64_t i = 0; i < sliceCount; ++i) {
       SmallVector<int64_t> offsets =
-          getVectorOffset(originalSize, *targetShape, i);
+          getVectorOffset(ratioStrides, i, *targetShape);
       SmallVector<int64_t> strides(offsets.size(), 1);
       Value slicedOperand = rewriter.create<vector::ExtractStridedSliceOp>(
           loc, reductionOp.getVector(), offsets, *targetShape, strides);
@@ -600,21 +606,25 @@ struct UnrollTranposePattern : public OpRewritePattern<vector::TransposeOp> {
     if (!targetShape)
       return failure();
     auto originalVectorType = tranposeOp.getResultType();
-    SmallVector<int64_t, 4> strides(targetShape->size(), 1);
+    SmallVector<int64_t> strides(targetShape->size(), 1);
     Location loc = tranposeOp.getLoc();
     ArrayRef<int64_t> originalSize = originalVectorType.getShape();
-    SmallVector<int64_t, 4> ratio = *shapeRatio(originalSize, *targetShape);
+    SmallVector<int64_t> ratio = *computeShapeRatio(originalSize, *targetShape);
     int64_t sliceCount = computeMaxLinearIndex(ratio);
     // Prepare the result vector;
     Value result = rewriter.create<arith::ConstantOp>(
         loc, originalVectorType, rewriter.getZeroAttr(originalVectorType));
     SmallVector<int64_t> permutation;
     tranposeOp.getTransp(permutation);
+
+    // Strides of the ratio: delinearizing a slice index against these gives
+    // the slice coordinates, which scale by `targetShape` to element offsets.
+    auto ratioStrides = computeStrides(ratio);
     for (int64_t i = 0; i < sliceCount; i++) {
-      SmallVector<int64_t, 4> elementOffsets =
-          getVectorOffset(originalSize, *targetShape, i);
-      SmallVector<int64_t, 4> permutedOffsets(elementOffsets.size());
-      SmallVector<int64_t, 4> permutedShape(elementOffsets.size());
+      SmallVector<int64_t> elementOffsets =
+          getVectorOffset(ratioStrides, i, *targetShape);
+      SmallVector<int64_t> permutedOffsets(elementOffsets.size());
+      SmallVector<int64_t> permutedShape(elementOffsets.size());
       // Compute the source offsets and shape.
       for (auto &indices : llvm::enumerate(permutation)) {
         permutedOffsets[indices.value()] = elementOffsets[indices.index()];

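The refactored getVectorOffset above reduces to
computeElementwiseMul(delinearize(ratioStrides, index), targetShape). A
standalone numeric sketch for unrolling a 4x6 vector into 2x3 slices, where
ratio = computeShapeRatio({4, 6}, {2, 3}) = {2, 2} and ratioStrides =
computeStrides(ratio) = {2, 1}:

#include <cassert>
#include <cstdint>
#include <vector>

int main() {
  std::vector<int64_t> targetShape = {2, 3};
  std::vector<int64_t> ratioStrides = {2, 1};
  for (int64_t index = 0; index < 4; ++index) {
    int64_t rem = index;
    std::vector<int64_t> sliceOffsets(2), elementOffsets(2);
    for (size_t r = 0; r < 2; ++r) {
      // delinearize: which slice along dimension r ...
      sliceOffsets[r] = rem / ratioStrides[r];
      rem %= ratioStrides[r];
      // ... times the slice extent = the element offset of that slice.
      elementOffsets[r] = sliceOffsets[r] * targetShape[r];
    }
    // The four slices start at {0, 0}, {0, 3}, {2, 0}, {2, 3}.
    assert(elementOffsets[0] == (index / 2) * 2);
    assert(elementOffsets[1] == (index % 2) * 3);
  }
  return 0;
}
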
diff --git a/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp b/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp
index 4e7ed8db19f49..8e3cd03f956cd 100644
--- a/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp
+++ b/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp
@@ -18,6 +18,7 @@
 #include "mlir/Dialect/Func/IR/FuncOps.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Dialect/Utils/IndexingUtils.h"
 #include "mlir/Dialect/Vector/IR/VectorOps.h"
 #include "mlir/IR/Builders.h"
 #include "mlir/IR/IntegerSet.h"
@@ -25,7 +26,6 @@
 #include "mlir/IR/TypeUtilities.h"
 #include "mlir/Support/LLVM.h"
 #include "mlir/Support/MathExtras.h"
-#include <numeric>
 
 #include "llvm/ADT/DenseSet.h"
 #include "llvm/ADT/SetVector.h"
@@ -43,78 +43,6 @@ Value mlir::vector::createOrFoldDimOp(OpBuilder &b, Location loc, Value source,
   llvm_unreachable("Expected MemRefType or TensorType");
 }
 
-/// Return the number of elements of basis, `0` if empty.
-int64_t mlir::computeMaxLinearIndex(ArrayRef<int64_t> basis) {
-  if (basis.empty())
-    return 0;
-  return std::accumulate(basis.begin(), basis.end(), 1,
-                         std::multiplies<int64_t>());
-}
-
-SmallVector<int64_t, 4> mlir::computeStrides(ArrayRef<int64_t> shape,
-                                             ArrayRef<int64_t> sizes) {
-  int64_t rank = shape.size();
-  // Compute the count for each dimension.
-  SmallVector<int64_t, 4> sliceDimCounts(rank);
-  for (int64_t r = 0; r < rank; ++r)
-    sliceDimCounts[r] = ceilDiv(shape[r], sizes[r]);
-  // Use that to compute the slice stride for each dimension.
-  SmallVector<int64_t, 4> sliceStrides(rank);
-  sliceStrides[rank - 1] = 1;
-  for (int64_t r = rank - 2; r >= 0; --r)
-    sliceStrides[r] = sliceStrides[r + 1] * sliceDimCounts[r + 1];
-  return sliceStrides;
-}
-
-SmallVector<int64_t, 4> mlir::computeElementOffsetsFromVectorSliceOffsets(
-    ArrayRef<int64_t> sizes, ArrayRef<int64_t> vectorOffsets) {
-  SmallVector<int64_t, 4> result;
-  for (auto it : llvm::zip(vectorOffsets, sizes))
-    result.push_back(std::get<0>(it) * std::get<1>(it));
-  return result;
-}
-
-Optional<SmallVector<int64_t, 4>> mlir::shapeRatio(ArrayRef<int64_t> superShape,
-                                                   ArrayRef<int64_t> subShape) {
-  if (superShape.size() < subShape.size()) {
-    return None;
-  }
-
-  // Starting from the end, compute the integer divisors.
-  std::vector<int64_t> result;
-  result.reserve(superShape.size());
-  for (auto [superSize, subSize] :
-       llvm::zip(llvm::reverse(superShape), llvm::reverse(subShape))) {
-    assert(superSize > 0 && "superSize must be > 0");
-    assert(subSize > 0 && "subSize must be > 0");
-
-    // If integral division does not occur, return and let the caller decide.
-    if (superSize % subSize != 0)
-      return None;
-    result.push_back(superSize / subSize);
-  }
-
-  // At this point we computed the ratio (in reverse) for the common
-  // size. Fill with the remaining entries from the super-vector shape (still in
-  // reverse).
-  int commonSize = subShape.size();
-  std::copy(superShape.rbegin() + commonSize, superShape.rend(),
-            std::back_inserter(result));
-
-  assert(result.size() == superShape.size() &&
-         "super to sub shape ratio is not of the same size as the super rank");
-
-  // Reverse again to get it back in the proper order and return.
-  return SmallVector<int64_t, 4>{result.rbegin(), result.rend()};
-}
-
-Optional<SmallVector<int64_t, 4>> mlir::shapeRatio(VectorType superVectorType,
-                                                   VectorType subVectorType) {
-  assert(superVectorType.getElementType() == subVectorType.getElementType() &&
-         "vector types must be of the same elemental type");
-  return shapeRatio(superVectorType.getShape(), subVectorType.getShape());
-}
-
 /// Constructs a permutation map from memref indices to vector dimension.
 ///
 /// The implementation uses the knowledge of the mapping of enclosing loop to
@@ -144,8 +72,8 @@ static AffineMap makePermutationMap(
     return AffineMap();
   MLIRContext *context =
       enclosingLoopToVectorDim.begin()->getFirst()->getContext();
-  SmallVector<AffineExpr, 4> perm(enclosingLoopToVectorDim.size(),
-                                  getAffineConstantExpr(0, context));
+  SmallVector<AffineExpr> perm(enclosingLoopToVectorDim.size(),
+                               getAffineConstantExpr(0, context));
 
   for (auto kvp : enclosingLoopToVectorDim) {
     assert(kvp.second < perm.size());
@@ -252,7 +180,8 @@ bool matcher::operatesOnSuperVectorsOf(Operation &op,
   }
 
   // Get the ratio.
-  auto ratio = shapeRatio(superVectorType, subVectorType);
+  auto ratio =
+      computeShapeRatio(superVectorType.getShape(), subVectorType.getShape());
 
   // Sanity check.
   assert((ratio || !mustDivide) &&

diff --git a/mlir/test/lib/Dialect/Affine/TestVectorizationUtils.cpp b/mlir/test/lib/Dialect/Affine/TestVectorizationUtils.cpp
index 86cb5cea8a63f..61428bbf7091f 100644
--- a/mlir/test/lib/Dialect/Affine/TestVectorizationUtils.cpp
+++ b/mlir/test/lib/Dialect/Affine/TestVectorizationUtils.cpp
@@ -17,6 +17,7 @@
 #include "mlir/Dialect/Affine/LoopUtils.h"
 #include "mlir/Dialect/Affine/Utils.h"
 #include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/Utils/IndexingUtils.h"
 #include "mlir/Dialect/Vector/IR/VectorOps.h"
 #include "mlir/Dialect/Vector/Utils/VectorUtils.h"
 #include "mlir/IR/Builders.h"
@@ -126,7 +127,8 @@ void VectorizerTestPass::testVectorShapeRatio(llvm::raw_ostream &outs) {
     // purpose of this test. If we need to test more intricate behavior in the
     // future we can always extend.
     auto superVectorType = opInst->getResult(0).getType().cast<VectorType>();
-    auto ratio = shapeRatio(superVectorType, subVectorType);
+    auto ratio =
+        computeShapeRatio(superVectorType.getShape(), subVectorType.getShape());
     if (!ratio) {
       opInst->emitRemark("NOT MATCHED");
     } else {

diff --git a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
index 6b9afe3b42e71..1bd40e7cde7e3 100644
--- a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
+++ b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
@@ -72,11 +72,11 @@ struct TestVectorToVectorLowering
 
 private:
   // Return the target shape based on op type.
-  static Optional<SmallVector<int64_t, 4>> getShape(Operation *op) {
+  static Optional<SmallVector<int64_t>> getShape(Operation *op) {
     if (isa<arith::AddFOp, arith::SelectOp, arith::CmpFOp>(op))
-      return SmallVector<int64_t, 4>(2, 2);
+      return SmallVector<int64_t>(2, 2);
     if (isa<vector::ContractionOp>(op))
-      return SmallVector<int64_t, 4>(3, 2);
+      return SmallVector<int64_t>(3, 2);
     // For transfer ops, just propagate the shape coming from
     // InsertStridedSlices/ExtractStridedSlices.
     if (auto readOp = dyn_cast<vector::TransferReadOp>(op)) {
@@ -90,15 +90,15 @@ struct TestVectorToVectorLowering
           return llvm::None;
         dstVec = vecType;
       }
-      return SmallVector<int64_t, 4>(dstVec.getShape().begin(),
-                                     dstVec.getShape().end());
+      return SmallVector<int64_t>(dstVec.getShape().begin(),
+                                  dstVec.getShape().end());
     }
     if (auto writeOp = dyn_cast<vector::TransferWriteOp>(op)) {
       auto insert = writeOp.getVector().getDefiningOp<InsertStridedSliceOp>();
       if (!insert)
         return llvm::None;
       ArrayRef<int64_t> shape = insert.getSourceVectorType().getShape();
-      return SmallVector<int64_t, 4>(shape.begin(), shape.end());
+      return SmallVector<int64_t>(shape.begin(), shape.end());
     }
     return llvm::None;
   }
@@ -314,10 +314,10 @@ struct TestVectorUnrollingPatterns
 
     if (unrollBasedOnType) {
       UnrollVectorOptions::NativeShapeFnType nativeShapeFn =
-          [](Operation *op) -> Optional<SmallVector<int64_t, 4>> {
+          [](Operation *op) -> Optional<SmallVector<int64_t>> {
         vector::ContractionOp contractOp = cast<vector::ContractionOp>(op);
-        SmallVector<int64_t, 4> nativeShape(
-            contractOp.getIteratorTypes().size(), 4);
+        SmallVector<int64_t> nativeShape(contractOp.getIteratorTypes().size(),
+                                         4);
         Type lhsType = contractOp.getLhsType().getElementType();
         nativeShape[nativeShape.size() - 1] = lhsType.isF16() ? 4 : 2;
         return nativeShape;
@@ -339,12 +339,11 @@ struct TestVectorUnrollingPatterns
       }
       populateVectorUnrollPatterns(patterns, opts);
     } else {
-      auto nativeShapeFn =
-          [](Operation *op) -> Optional<SmallVector<int64_t, 4>> {
+      auto nativeShapeFn = [](Operation *op) -> Optional<SmallVector<int64_t>> {
         auto contractOp = dyn_cast<ContractionOp>(op);
         if (!contractOp)
           return None;
-        return SmallVector<int64_t, 4>(contractOp.getIteratorTypes().size(), 2);
+        return SmallVector<int64_t>(contractOp.getIteratorTypes().size(), 2);
       };
       populateVectorUnrollPatterns(patterns,
                                    UnrollVectorOptions()

diff --git a/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterface.cpp b/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterface.cpp
index eabc46ccda2b5..290c9c5777611 100644
--- a/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterface.cpp
+++ b/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterface.cpp
@@ -61,7 +61,6 @@ struct LinalgTransformationFilter {
   LogicalResult checkAndNotify(PatternRewriter &rewriter, Operation *op) const;
   void replaceLinalgTransformationFilter(PatternRewriter &rewriter,
                                          Operation *op) const;
-  bool hasReplacementFilter(Operation *op) const;
 
   LinalgTransformationFilter &addFilter(const FilterFunction &f) {
     if (f)
@@ -100,15 +99,6 @@ LinalgTransformationFilter::LinalgTransformationFilter(
     : matchDisjunction(matchDisjunction.begin(), matchDisjunction.end()),
       replacement(replacement), matchByDefault(false) {}
 
-LinalgTransformationFilter::LinalgTransformationFilter(
-    const FilterFunction &f, ArrayRef<StringAttr> matchDisjunction,
-    Optional<StringAttr> replacement)
-    : matchDisjunction(matchDisjunction.begin(), matchDisjunction.end()),
-      replacement(replacement), matchByDefault(false) {
-  if (f)
-    filters.push_back(f);
-}
-
 LogicalResult
 LinalgTransformationFilter::checkAndNotify(PatternRewriter &rewriter,
                                            Operation *op) const {
@@ -150,13 +140,6 @@ void LinalgTransformationFilter::replaceLinalgTransformationFilter(
     op->removeAttr(rewriter.getStringAttr(kLinalgTransformMarker));
 }
 
-bool LinalgTransformationFilter::hasReplacementFilter(Operation *op) const {
-  if (!replacement)
-    return false;
-  auto attr = op->getAttr(kLinalgTransformMarker).dyn_cast<StringAttr>();
-  return attr && attr == *replacement;
-}
-
 /// Pattern for testing `TileUsingSCFForOp` pattern (that tiles operations using
 /// the `TilingInterface` with `scf.for` ops for iterating over the tiles) while
 /// using a `filter` to avoid recursive application.

