[Mlir-commits] [mlir] 831041b - [mlir][vector] Cleanup VectorUnroll and create a generic tile iteration utility

Christopher Bate llvmlistbot at llvm.org
Thu Sep 14 19:34:54 PDT 2023


Author: Christopher Bate
Date: 2023-09-14T20:34:44-06:00
New Revision: 831041be797b099b4e3805db368bacb1d1abab5d

URL: https://github.com/llvm/llvm-project/commit/831041be797b099b4e3805db368bacb1d1abab5d
DIFF: https://github.com/llvm/llvm-project/commit/831041be797b099b4e3805db368bacb1d1abab5d.diff

LOG: [mlir][vector] Cleanup VectorUnroll and create a generic tile iteration utility

This change refactors some of the utilities used to unroll larger vector
computations into smaller vector computations. In fact, the indexing
computations used here are rather generic and are useful in other dialects or
downstream projects. Therefore, a utility for iterating over all possible tile
offsets for a particular pair of static (shape, tiled shape) is introduced in
IndexingUtils and replaces the existing computations in the vector unrolling
transformations. This builds off of the refactoring of IndexingUtils introduced
in 203fad476b7e.

Reviewed By: nicolasvasilache

Differential Revision: https://reviews.llvm.org/D150000

Added: 
    mlir/unittests/Dialect/Utils/IndexingUtilsTest.cpp

Modified: 
    mlir/include/mlir/Dialect/Utils/IndexingUtils.h
    mlir/include/mlir/IR/AffineExpr.h
    mlir/lib/Dialect/Utils/IndexingUtils.cpp
    mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp
    mlir/lib/IR/AffineExpr.cpp
    mlir/unittests/Dialect/Utils/CMakeLists.txt

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Utils/IndexingUtils.h b/mlir/include/mlir/Dialect/Utils/IndexingUtils.h
index cb8419374c43e3a..f51a8b28b7548ed 100644
--- a/mlir/include/mlir/Dialect/Utils/IndexingUtils.h
+++ b/mlir/include/mlir/Dialect/Utils/IndexingUtils.h
@@ -18,7 +18,9 @@
 #include "mlir/Support/LLVM.h"
 #include "llvm/ADT/ArrayRef.h"
 #include "llvm/ADT/SmallVector.h"
+#include "llvm/ADT/iterator.h"
 #include <optional>
+#include <utility>
 
 namespace mlir {
 class ArrayAttr;
@@ -195,6 +197,23 @@ SmallVector<AffineExpr> delinearize(AffineExpr linearIndex,
 // Permutation utils.
 //===----------------------------------------------------------------------===//
 
+/// Returns a new vector `result` with `result[i] = input[permutation[i]]`.
+template <typename T>
+SmallVector<T> applyPermutation(ArrayRef<T> input,
+                                ArrayRef<int64_t> permutation) {
+  assert(input.size() == permutation.size() &&
+         "expected input rank to equal permutation rank");
+  SmallVector<T> result;
+  for (size_t i = 0, e = input.size(); i < e; ++i)
+    result.push_back(input[permutation[i]]);
+  return result;
+}
+
+/// Overload for `SmallVector` inputs; forwards to the `ArrayRef` version.
+template <typename T>
+SmallVector<T> applyPermutation(const SmallVectorImpl<T> &input,
+                                ArrayRef<int64_t> permutation) {
+  return applyPermutation(ArrayRef<T>(input), permutation);
+}
+
 /// Apply the permutation defined by `permutation` to `inVec`.
 /// Element `i` in `inVec` is mapped to location `j = permutation[i]`.
 /// E.g.: for an input vector `inVec = ['a', 'b', 'c']` and a permutation
@@ -203,10 +222,7 @@ SmallVector<AffineExpr> delinearize(AffineExpr linearIndex,
 template <typename T, unsigned N>
 void applyPermutationToVector(SmallVector<T, N> &inVec,
                               ArrayRef<int64_t> permutation) {
-  SmallVector<T, N> auxVec(inVec.size());
-  for (const auto &en : enumerate(permutation))
-    auxVec[en.index()] = inVec[en.value()];
-  inVec = auxVec;
+  // Gather out of place (new inVec[i] = old inVec[permutation[i]]), then
+  // assign the permuted copy back to `inVec`.
+  inVec = applyPermutation(inVec, permutation);
 }
 
 /// Helper method to apply to inverse a permutation.
@@ -239,6 +255,138 @@ std::pair<AffineExpr, SmallVector<OpFoldResult>>
 computeLinearIndex(OpFoldResult sourceOffset, ArrayRef<OpFoldResult> strides,
                    ArrayRef<OpFoldResult> indices);
 
+//===----------------------------------------------------------------------===//
+// Utilities for decomposing larger shapes
+//===----------------------------------------------------------------------===//
+
+namespace detail {
+/// Holds the parameters (tile shape, loop order and derived strides) that
+/// TileOffsetRangeIterator uses to map a linear tile index to tile offsets.
+class TileOffsetRangeImpl {
+public:
+  TileOffsetRangeImpl(ArrayRef<int64_t> shape, ArrayRef<int64_t> tileShape,
+                      ArrayRef<int64_t> loopOrder);
+
+  /// Returns the number of tiles, i.e. one past the largest valid linear
+  /// index.
+  int64_t getMaxLinearIndex() const { return maxLinearIndex; }
+
+  /// Returns the element offsets of the tile with linear index `linearIndex`.
+  SmallVector<int64_t> getStaticTileOffsets(int64_t linearIndex) const;
+
+  /// Symbolic variant of getStaticTileOffsets for an affine `linearIndex`.
+  SmallVector<AffineExpr> getDynamicTileOffsets(AffineExpr linearIndex) const;
+
+  /// Dispatches to getStaticTileOffsets for `int64_t` indices and to
+  /// getDynamicTileOffsets otherwise.
+  template <typename T>
+  SmallVector<T> getTileOffsets(T linearIndex) const {
+    if constexpr (std::is_same_v<T, int64_t>)
+      return getStaticTileOffsets(linearIndex);
+    else
+      return getDynamicTileOffsets(linearIndex);
+  }
+
+private:
+  /// The sub-shape that divides the larger outer shape, left-padded with 1s
+  /// to the rank of that outer shape.
+  SmallVector<int64_t> tileShape;
+  /// The inverse permutation to the `loopOrder` permutation provided in the
+  /// constructor.
+  SmallVector<int64_t> inverseLoopOrder;
+  /// The strides for the basis 'div(shape, tileShape)' permuted by `loopOrder`.
+  SmallVector<int64_t> sliceStrides;
+  /// The number of tiles in the iteration space given by basis
+  /// 'div(shape, tileShape)'; valid linear indices are [0, maxLinearIndex).
+  int64_t maxLinearIndex;
+};
+
+/// The STL-style forward iterator implementation for StaticTileOffsetRange.
+/// `ElementType` is the type of the linear index and of the produced offsets;
+/// `int64_t` selects the static offset computation, any other type is
+/// forwarded to the dynamic (AffineExpr) one.
+template <typename ElementType>
+class TileOffsetRangeIterator
+    : public llvm::iterator_facade_base<TileOffsetRangeIterator<ElementType>,
+                                        std::forward_iterator_tag,
+                                        SmallVector<ElementType>> {
+public:
+  TileOffsetRangeIterator(const TileOffsetRangeImpl &params, ElementType index)
+      : params(params), index(index) {}
+
+  /// Prefix increment returns `*this`, as required of forward iterators.
+  TileOffsetRangeIterator &operator++() { return *this += 1; }
+  TileOffsetRangeIterator operator++(int) {
+    TileOffsetRangeIterator copy = *this;
+    ++*this;
+    return copy;
+  }
+
+  bool operator==(const TileOffsetRangeIterator &other) const {
+    return index == other.index;
+  }
+  bool operator!=(const TileOffsetRangeIterator &other) const {
+    return index != other.index;
+  }
+
+  /// Returns the offsets of the current tile (computed on demand, by value).
+  SmallVector<ElementType> operator*() const {
+    return params.getTileOffsets(index);
+  }
+  TileOffsetRangeIterator &operator+=(int64_t offset) {
+    index = index + offset;
+    return *this;
+  }
+
+private:
+  // Neither member is `const`, so iterators stay copy-assignable.
+  TileOffsetRangeImpl params;
+  // Stored as `ElementType` (not `int64_t`) so that `operator*` dispatches to
+  // the matching static/dynamic offset computation.
+  ElementType index;
+};
+} // namespace detail
+
+/// A range that iterates over the offsets of all tiles of size `tileShape`
+/// within the larger shape `shape`, in the order given by `loopOrder`.
+/// `loopOrder` lists the dimensions of `shape` from the outermost loop
+/// (slowest changing) to the innermost loop (fastest changing).
+///
+/// For example, for `shape = {10, 20, 30}`, `tileShape = {5, 10, 15}`, and
+/// `loopOrder = {2, 0, 1}`, iterating over this range yields the offsets:
+///
+/// ```
+/// {0, 0,  0}, {0, 10,  0}, {5, 0,  0}, {5, 10,  0},
+/// {0, 0, 15}, {0, 10, 15}, {5, 0, 15}, {5, 10, 15}
+/// ```
+///
+/// This is useful in contexts where a vector computation over a larger shape
+/// needs to be unrolled to a set of operations on subsets of the original
+/// operands, such as during the "vector unrolling" transformations.
+///
+/// The rank of `tileShape` must be less than or equal to the rank of `shape`.
+/// If `tileShape` has a smaller rank, its elements correspond to the trailing
+/// dimensions of `shape` and it is effectively left-padded with 1s, i.e. each
+/// leading dimension of `shape` is tiled with a tile size of 1. E.g. tiling
+/// `{4, 8}` by `{4}` visits the 8 offsets `{i, 0}` and `{i, 4}`, i in [0, 4).
+class StaticTileOffsetRange {
+public:
+  using IteratorTy = detail::TileOffsetRangeIterator<int64_t>;
+  using ParamsTy = detail::TileOffsetRangeImpl;
+
+  /// Create the range with an explicit `loopOrder`, listed from the outermost
+  /// to the innermost loop.
+  StaticTileOffsetRange(ArrayRef<int64_t> shape, ArrayRef<int64_t> tileShape,
+                        ArrayRef<int64_t> loopOrder)
+      : params(shape, tileShape, loopOrder) {
+    assert(shape.size() >= tileShape.size());
+    assert(loopOrder.size() == shape.size());
+  }
+
+  /// Create the range with identity loop order.
+  StaticTileOffsetRange(ArrayRef<int64_t> shape, ArrayRef<int64_t> tileShape)
+      : StaticTileOffsetRange(
+            shape, tileShape,
+            llvm::to_vector(llvm::seq<int64_t>(0, shape.size()))) {}
+
+  /// Iterator to the first tile (linear index 0).
+  IteratorTy begin() const { return IteratorTy(params, 0); }
+  /// Past-the-end iterator (linear index equal to the number of tiles).
+  IteratorTy end() const {
+    return IteratorTy(params, params.getMaxLinearIndex());
+  }
+
+  /// Returns the total number of tiles that fit in the larger shape.
+  size_t size() const { return params.getMaxLinearIndex(); }
+
+private:
+  /// Tiling parameters; each iterator returned by begin()/end() copies them.
+  const ParamsTy params;
+};
 } // namespace mlir
 
 #endif // MLIR_DIALECT_UTILS_INDEXINGUTILS_H

diff  --git a/mlir/include/mlir/IR/AffineExpr.h b/mlir/include/mlir/IR/AffineExpr.h
index 352c5873c0832fc..8ced8770591ee8c 100644
--- a/mlir/include/mlir/IR/AffineExpr.h
+++ b/mlir/include/mlir/IR/AffineExpr.h
@@ -17,6 +17,7 @@
 #include "mlir/Support/LLVM.h"
 #include "llvm/ADT/DenseMapInfo.h"
 #include "llvm/ADT/Hashing.h"
+#include "llvm/ADT/SmallVector.h"
 #include "llvm/Support/Casting.h"
 #include <functional>
 #include <type_traits>
@@ -250,6 +251,8 @@ inline AffineExpr operator-(int64_t val, AffineExpr expr) {
 AffineExpr getAffineDimExpr(unsigned position, MLIRContext *context);
 AffineExpr getAffineSymbolExpr(unsigned position, MLIRContext *context);
 AffineExpr getAffineConstantExpr(int64_t constant, MLIRContext *context);
+/// Returns one constant affine expression per element of `constants`, each
+/// created in `context` via getAffineConstantExpr.
+SmallVector<AffineExpr> getAffineConstantExprs(ArrayRef<int64_t> constants,
+                                               MLIRContext *context);
 AffineExpr getAffineBinaryOpExpr(AffineExprKind kind, AffineExpr lhs,
                                  AffineExpr rhs);
 

diff  --git a/mlir/lib/Dialect/Utils/IndexingUtils.cpp b/mlir/lib/Dialect/Utils/IndexingUtils.cpp
index 8afa51a20345578..f4e29539214b4b6 100644
--- a/mlir/lib/Dialect/Utils/IndexingUtils.cpp
+++ b/mlir/lib/Dialect/Utils/IndexingUtils.cpp
@@ -181,9 +181,8 @@ AffineExpr mlir::linearize(MLIRContext *ctx, ArrayRef<AffineExpr> offsets,
 
 AffineExpr mlir::linearize(MLIRContext *ctx, ArrayRef<AffineExpr> offsets,
                            ArrayRef<int64_t> basis) {
-  SmallVector<AffineExpr> basisExprs = llvm::to_vector(llvm::map_range(
-      basis, [ctx](int64_t v) { return getAffineConstantExpr(v, ctx); }));
-  return linearize(ctx, offsets, basisExprs);
+  // Lift `basis` to constant AffineExprs; defer to the AffineExpr overload.
+  return linearize(ctx, offsets, getAffineConstantExprs(basis, ctx));
 }
 
 SmallVector<AffineExpr> mlir::delinearize(AffineExpr linearIndex,
@@ -196,9 +195,7 @@ SmallVector<AffineExpr> mlir::delinearize(AffineExpr linearIndex,
 SmallVector<AffineExpr> mlir::delinearize(AffineExpr linearIndex,
                                           ArrayRef<int64_t> strides) {
   MLIRContext *ctx = linearIndex.getContext();
-  SmallVector<AffineExpr> basisExprs = llvm::to_vector(llvm::map_range(
-      strides, [ctx](int64_t v) { return getAffineConstantExpr(v, ctx); }));
-  return delinearize(linearIndex, ArrayRef<AffineExpr>{basisExprs});
+  // Lift `strides` to constant AffineExprs; defer to the AffineExpr overload.
+  return delinearize(linearIndex, getAffineConstantExprs(strides, ctx));
 }
 
 //===----------------------------------------------------------------------===//
@@ -302,3 +299,56 @@ mlir::computeLinearIndex(OpFoldResult sourceOffset,
 
   return {expr, values};
 }
+
+//===----------------------------------------------------------------------===//
+// TileOffsetRange
+//===----------------------------------------------------------------------===//
+
+/// Returns `tileShape` left-padded with 1s so that it has `paddedSize`
+/// elements. A `tileShape` that already has `paddedSize` elements is returned
+/// unchanged.
+static SmallVector<int64_t> padTileShapeToSize(ArrayRef<int64_t> tileShape,
+                                               unsigned paddedSize) {
+  assert(tileShape.size() <= paddedSize &&
+         "expected tileShape to be <= paddedSize");
+  // Leading dimensions not covered by `tileShape` get a tile size of 1. When
+  // no padding is needed this prefix is simply empty.
+  SmallVector<int64_t> result(paddedSize - tileShape.size(), 1);
+  llvm::append_range(result, tileShape);
+  return result;
+}
+
+/// Precomputes the padded tile shape, the inverse loop order, the tile count
+/// and the loop-ordered strides used to delinearize tile indices.
+mlir::detail::TileOffsetRangeImpl::TileOffsetRangeImpl(
+    ArrayRef<int64_t> shape, ArrayRef<int64_t> tileShape,
+    ArrayRef<int64_t> loopOrder)
+    : tileShape(padTileShapeToSize(tileShape, shape.size())) {
+  // Validate `loopOrder` before it is inverted or used to permute the ratio.
+  assert(isPermutationVector(loopOrder) && loopOrder.size() == shape.size() &&
+         "expected loop order to be a permutation of rank equal to outer "
+         "shape");
+
+  // Divide the shape by the tile shape. Note: `tileShape` here names the
+  // (unpadded) constructor argument, not the padded member.
+  std::optional<SmallVector<int64_t>> shapeRatio =
+      mlir::computeShapeRatio(shape, tileShape);
+  assert(shapeRatio && shapeRatio->size() == shape.size() &&
+         "target shape does not evenly divide the original shape");
+
+  inverseLoopOrder = invertPermutationVector(loopOrder);
+  maxLinearIndex = mlir::computeMaxLinearIndex(*shapeRatio);
+  // Strides are computed over the ratio permuted into loop order, so that the
+  // innermost loop (last entry of `loopOrder`) gets the smallest stride.
+  mlir::applyPermutationToVector(*shapeRatio, loopOrder);
+  sliceStrides = mlir::computeStrides(*shapeRatio);
+}
+
+/// Maps `linearIndex` to the element offsets of the corresponding tile.
+SmallVector<int64_t> mlir::detail::TileOffsetRangeImpl::getStaticTileOffsets(
+    int64_t linearIndex) const {
+  // Delinearize in loop order, permute back to the dimension order of the
+  // original shape, then scale tile coordinates by the tile sizes.
+  SmallVector<int64_t> loopCoords = delinearize(linearIndex, sliceStrides);
+  SmallVector<int64_t> dimCoords =
+      applyPermutation(loopCoords, inverseLoopOrder);
+  return computeElementwiseMul(dimCoords, tileShape);
+}
+
+/// Symbolic counterpart of getStaticTileOffsets: the same delinearize,
+/// permute and scale steps, expressed with AffineExprs.
+SmallVector<AffineExpr>
+mlir::detail::TileOffsetRangeImpl::getDynamicTileOffsets(
+    AffineExpr linearIndex) const {
+  SmallVector<AffineExpr> loopCoords = delinearize(linearIndex, sliceStrides);
+  SmallVector<AffineExpr> dimCoords =
+      applyPermutation(loopCoords, inverseLoopOrder);
+  SmallVector<AffineExpr> tileSizes =
+      getAffineConstantExprs(tileShape, linearIndex.getContext());
+  return mlir::computeElementwiseMul(dimCoords, tileSizes);
+}

diff  --git a/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp
index 1a8f78031ba4f87..6a45231eb80bcea 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp
@@ -29,77 +29,6 @@
 using namespace mlir;
 using namespace mlir::vector;
 
-/// During unrolling from `originalShape` to `targetShape` return the offset for
-/// the slice `index`.
-static SmallVector<int64_t> getVectorOffset(ArrayRef<int64_t> ratioStrides,
-                                            int64_t index,
-                                            ArrayRef<int64_t> targetShape) {
-  return computeElementwiseMul(delinearize(index, ratioStrides), targetShape);
-}
-
-/// A functor that accomplishes the same thing as `getVectorOffset` but
-/// allows for reordering the traversal of the dimensions. The order of
-/// traversal is given in "for loop order" (outer to inner).
-namespace {
-class DecomposeShapeIterator {
-private:
-  SmallVector<int64_t> vectorShape;
-  SmallVector<int64_t> loopOrder;
-  SmallVector<int64_t> sliceStrides;
-  int64_t maxIndexVal{1};
-
-public:
-  DecomposeShapeIterator(ArrayRef<int64_t> originalShape,
-                         ArrayRef<int64_t> targetShape,
-                         ArrayRef<int64_t> loopOrder)
-      : vectorShape(targetShape.begin(), targetShape.end()),
-        loopOrder(loopOrder.begin(), loopOrder.end()),
-        sliceStrides(originalShape.size()) {
-    assert(originalShape.size() >= targetShape.size());
-    assert(loopOrder.size() == originalShape.size());
-
-    // Compute the count for each dimension.
-    auto maybeShapeRatio = computeShapeRatio(originalShape, targetShape);
-    assert(maybeShapeRatio && "Shape does not evenly divide");
-    // Pad `sliceDimCounts` with leading 1s so that all sizes match.
-    SmallVector<int64_t> sliceDimCounts = *maybeShapeRatio;
-    maxIndexVal = computeMaxLinearIndex(sliceDimCounts);
-
-    // Reversing "loop order" gives dimensions from fastest varying to slowest
-    // varying (smallest stride to largest stride).
-    int64_t accum = 1;
-    for (auto idx : llvm::reverse(loopOrder)) {
-      sliceStrides[idx] = accum;
-      accum *= sliceDimCounts[idx];
-    }
-  }
-
-  // Turn the linear index into a d-tuple based on units of vectors of size
-  // `vectorShape`. The linear index is assumed to represent traversal of the
-  // dimensions based on `order`.
-  SmallVector<int64_t> delinearize(int64_t index) const {
-    // Traverse in for loop order (largest stride to smallest stride).
-    SmallVector<int64_t> vectorOffsets(sliceStrides.size());
-    for (auto idx : loopOrder) {
-      vectorOffsets[idx] = index / sliceStrides[idx];
-      index %= sliceStrides[idx];
-    }
-    return vectorOffsets;
-  }
-
-  int64_t maxIndex() const { return maxIndexVal; }
-
-  /// Return the offset within d-tuple based on the ordering given by
-  /// `loopOrder`.
-  SmallVector<int64_t> getVectorOffset(int64_t index) const {
-    SmallVector<int64_t> vectorOffsets = delinearize(index);
-    SmallVector<int64_t> elementOffsets =
-        computeElementwiseMul(vectorShape, vectorOffsets);
-    return elementOffsets;
-  }
-};
-} // namespace
-
 /// Compute the indices of the slice `index` for a tranfer op.
 static SmallVector<Value> sliceTransferIndices(ArrayRef<int64_t> elementOffsets,
                                                ArrayRef<Value> indices,
@@ -232,13 +161,10 @@ struct UnrollTransferReadPattern
         VectorType::get(*targetShape, sourceVectorType.getElementType());
     SmallVector<Value> originalIndices(readOp.getIndices().begin(),
                                        readOp.getIndices().end());
-
     SmallVector<int64_t> loopOrder =
         getUnrollOrder(originalSize.size(), readOp, options);
-    DecomposeShapeIterator indexToOffsets(originalSize, *targetShape,
-                                          loopOrder);
-    for (int64_t i = 0; i < indexToOffsets.maxIndex(); i++) {
-      SmallVector<int64_t> elementOffsets = indexToOffsets.getVectorOffset(i);
+    for (SmallVector<int64_t> elementOffsets :
+         StaticTileOffsetRange(originalSize, *targetShape, loopOrder)) {
       SmallVector<Value> indices =
           sliceTransferIndices(elementOffsets, originalIndices,
                                readOp.getPermutationMap(), loc, rewriter);
@@ -283,14 +209,11 @@ struct UnrollTransferWritePattern
     ArrayRef<int64_t> originalSize = sourceVectorType.getShape();
     SmallVector<Value> originalIndices(writeOp.getIndices().begin(),
                                        writeOp.getIndices().end());
-
     SmallVector<int64_t> loopOrder =
         getUnrollOrder(originalSize.size(), writeOp, options);
-    DecomposeShapeIterator indexToOffsets(originalSize, *targetShape,
-                                          loopOrder);
     Value resultTensor;
-    for (int64_t i = 0; i < indexToOffsets.maxIndex(); i++) {
-      SmallVector<int64_t> elementOffsets = indexToOffsets.getVectorOffset(i);
+    for (SmallVector<int64_t> elementOffsets :
+         StaticTileOffsetRange(originalSize, *targetShape, loopOrder)) {
       Value slicedVector = rewriter.create<vector::ExtractStridedSliceOp>(
           loc, writeOp.getVector(), elementOffsets, *targetShape, strides);
       SmallVector<Value> indices =
@@ -355,11 +278,9 @@ struct UnrollContractionPattern
 
     SmallVector<int64_t> loopOrder = getUnrollOrder(
         contractOp.getIteratorTypes().size(), contractOp, options);
-    DecomposeShapeIterator indexToOffsets(originalSize, *targetShape,
-                                          loopOrder);
-    const int64_t sliceCount = indexToOffsets.maxIndex();
-    for (int64_t i = 0; i < sliceCount; i++) {
-      SmallVector<int64_t> offsets = indexToOffsets.getVectorOffset(i);
+
+    for (SmallVector<int64_t> offsets :
+         StaticTileOffsetRange(originalSize, *targetShape, loopOrder)) {
       SmallVector<Value> slicesOperands(contractOp.getNumOperands());
 
       // Helper to compute the new shape of each operand and extract the slice.
@@ -439,22 +360,16 @@ struct UnrollMultiReductionPattern
     if (!targetShape)
       return failure();
     SmallVector<int64_t> originalSize = *reductionOp.getShapeForUnroll();
-    SmallVector<int64_t> ratio = *computeShapeRatio(originalSize, *targetShape);
     llvm::MapVector<
         SmallVector<int64_t>, Value,
         llvm::DenseMap<SmallVector<int64_t>, unsigned, OffsetMapInfo>>
         accCache;
-    // Compute shape ratio of 'shape' and 'sizes'.
-    int64_t sliceCount = computeMaxLinearIndex(ratio);
     Location loc = reductionOp.getLoc();
 
     // Stride of the ratios, this gives us the offsets of sliceCount in a basis
     // of multiples of the targetShape.
-    auto ratioStrides = computeStrides(ratio);
-    for (int64_t i = 0; i < sliceCount; i++) {
-      SmallVector<int64_t> offsets =
-          getVectorOffset(ratioStrides, i, *targetShape);
-
+    for (SmallVector<int64_t> offsets :
+         StaticTileOffsetRange(originalSize, *targetShape)) {
       SmallVector<Value> operands;
       SmallVector<int64_t> operandStrides(offsets.size(), 1);
       Value slicedOperand = rewriter.create<vector::ExtractStridedSliceOp>(
@@ -520,8 +435,6 @@ struct UnrollElementwisePattern : public RewritePattern {
     auto dstVecType = cast<VectorType>(op->getResult(0).getType());
     SmallVector<int64_t> originalSize =
         *cast<VectorUnrollOpInterface>(op).getShapeForUnroll();
-    SmallVector<int64_t> ratio = *computeShapeRatio(originalSize, *targetShape);
-    int64_t sliceCount = computeMaxLinearIndex(ratio);
     Location loc = op->getLoc();
     // Prepare the result vector.
     Value result = rewriter.create<arith::ConstantOp>(
@@ -530,12 +443,9 @@ struct UnrollElementwisePattern : public RewritePattern {
     VectorType newVecType =
         VectorType::get(*targetShape, dstVecType.getElementType());
 
-    // Stride of the ratios, this gives us the offsets of sliceCount in a basis
-    // of multiples of the targetShape.
-    auto ratioStrides = computeStrides(ratio);
-    for (int64_t i = 0; i < sliceCount; i++) {
-      SmallVector<int64_t> offsets =
-          getVectorOffset(ratioStrides, i, *targetShape);
+    // Create the unrolled computation.
+    for (SmallVector<int64_t> offsets :
+         StaticTileOffsetRange(originalSize, *targetShape)) {
       SmallVector<Value> extractOperands;
       for (OpOperand &operand : op->getOpOperands()) {
         auto vecType = dyn_cast<VectorType>(operand.get().getType());
@@ -574,19 +484,12 @@ struct UnrollReductionPattern : public OpRewritePattern<vector::ReductionOp> {
     if (!targetShape)
       return failure();
     SmallVector<int64_t> originalSize = *reductionOp.getShapeForUnroll();
-    auto ratio = *computeShapeRatio(originalSize, *targetShape);
-    int64_t sliceCount = ratio[0];
 
     // Create unrolled vector reduction.
     Location loc = reductionOp.getLoc();
     Value accumulator = nullptr;
-
-    // Stride of the ratios, this gives us the offsets of sliceCount in a basis
-    // of multiples of the targetShape.
-    auto ratioStrides = computeStrides(ratio);
-    for (int64_t i = 0; i < sliceCount; ++i) {
-      SmallVector<int64_t> offsets =
-          getVectorOffset(ratioStrides, i, *targetShape);
+    for (SmallVector<int64_t> offsets :
+         StaticTileOffsetRange(originalSize, *targetShape)) {
       SmallVector<int64_t> strides(offsets.size(), 1);
       Value slicedOperand = rewriter.create<vector::ExtractStridedSliceOp>(
           loc, reductionOp.getVector(), offsets, *targetShape, strides);
@@ -630,20 +533,16 @@ struct UnrollTransposePattern : public OpRewritePattern<vector::TransposeOp> {
     SmallVector<int64_t> strides(targetShape->size(), 1);
     Location loc = transposeOp.getLoc();
     ArrayRef<int64_t> originalSize = originalVectorType.getShape();
-    SmallVector<int64_t> ratio = *computeShapeRatio(originalSize, *targetShape);
-    int64_t sliceCount = computeMaxLinearIndex(ratio);
+
     // Prepare the result vector;
     Value result = rewriter.create<arith::ConstantOp>(
         loc, originalVectorType, rewriter.getZeroAttr(originalVectorType));
     SmallVector<int64_t> permutation;
     transposeOp.getTransp(permutation);
 
-    // Stride of the ratios, this gives us the offsets of sliceCount in a basis
-    // of multiples of the targetShape.
-    auto ratioStrides = computeStrides(ratio);
-    for (int64_t i = 0; i < sliceCount; i++) {
-      SmallVector<int64_t> elementOffsets =
-          getVectorOffset(ratioStrides, i, *targetShape);
+    // Unroll the computation.
+    for (SmallVector<int64_t> elementOffsets :
+         StaticTileOffsetRange(originalSize, *targetShape)) {
       SmallVector<int64_t> permutedOffsets(elementOffsets.size());
       SmallVector<int64_t> permutedShape(elementOffsets.size());
       // Compute the source offsets and shape.
@@ -694,13 +593,11 @@ struct UnrollGatherPattern : public OpRewritePattern<vector::GatherOp> {
 
     SmallVector<int64_t> loopOrder =
         getUnrollOrder(originalSize.size(), gatherOp, options);
-    DecomposeShapeIterator indexToOffsets(originalSize, *targetShape,
-                                          loopOrder);
-    for (int64_t i = 0, e = indexToOffsets.maxIndex(); i < e; ++i) {
+    for (SmallVector<int64_t> elementOffsets :
+         StaticTileOffsetRange(originalSize, *targetShape, loopOrder)) {
       // To get the unrolled gather, extract the same slice based on the
       // decomposed shape from each of the index, mask, and pass-through
       // vectors.
-      SmallVector<int64_t> elementOffsets = indexToOffsets.getVectorOffset(i);
       Value indexSubVec = rewriter.create<vector::ExtractStridedSliceOp>(
           loc, gatherOp.getIndexVec(), elementOffsets, *targetShape, strides);
       Value maskSubVec = rewriter.create<vector::ExtractStridedSliceOp>(

diff  --git a/mlir/lib/IR/AffineExpr.cpp b/mlir/lib/IR/AffineExpr.cpp
index 8564bacedd21c38..7eccbca4e6e7a1a 100644
--- a/mlir/lib/IR/AffineExpr.cpp
+++ b/mlir/lib/IR/AffineExpr.cpp
@@ -533,6 +533,14 @@ AffineExpr mlir::getAffineConstantExpr(int64_t constant, MLIRContext *context) {
   return uniquer.get<AffineConstantExprStorage>(assignCtx, constant);
 }
 
+/// Builds one constant affine expression per entry of `constants`, in order,
+/// each obtained from getAffineConstantExpr in `context`.
+SmallVector<AffineExpr>
+mlir::getAffineConstantExprs(ArrayRef<int64_t> constants,
+                             MLIRContext *context) {
+  SmallVector<AffineExpr> exprs;
+  exprs.reserve(constants.size());
+  for (int64_t value : constants)
+    exprs.push_back(getAffineConstantExpr(value, context));
+  return exprs;
+}
+
 /// Simplify add expression. Return nullptr if it can't be simplified.
 static AffineExpr simplifyAdd(AffineExpr lhs, AffineExpr rhs) {
   auto lhsConst = lhs.dyn_cast<AffineConstantExpr>();

diff  --git a/mlir/unittests/Dialect/Utils/CMakeLists.txt b/mlir/unittests/Dialect/Utils/CMakeLists.txt
index d75b6936e69ca72..116e094e7706ab2 100644
--- a/mlir/unittests/Dialect/Utils/CMakeLists.txt
+++ b/mlir/unittests/Dialect/Utils/CMakeLists.txt
@@ -1,5 +1,6 @@
 add_mlir_unittest(MLIRDialectUtilsTests
   StructuredOpsUtilsTest.cpp
+  IndexingUtilsTest.cpp
 )
 target_link_libraries(MLIRDialectUtilsTests
   PRIVATE

diff  --git a/mlir/unittests/Dialect/Utils/IndexingUtilsTest.cpp b/mlir/unittests/Dialect/Utils/IndexingUtilsTest.cpp
new file mode 100644
index 000000000000000..4b68e4214e7718b
--- /dev/null
+++ b/mlir/unittests/Dialect/Utils/IndexingUtilsTest.cpp
@@ -0,0 +1,71 @@
+//===- IndexingUtilsTest.cpp - IndexingUtils unit tests -------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Utils/IndexingUtils.h"
+#include "llvm/ADT/STLExtras.h"
+#include "gtest/gtest.h"
+
+using namespace mlir;
+
+TEST(StaticTileOffsetRange, checkIteratorCanonicalOrder) {
+  // Tile <4x8> by <2x4> with canonical row-major order.
+  std::vector<SmallVector<int64_t>> expected = {{0, 0}, {0, 4}, {2, 0}, {2, 4}};
+  // Count visited tiles so that a short or empty range cannot pass vacuously.
+  size_t numTiles = 0;
+  for (auto [idx, tileOffset] :
+       llvm::enumerate(StaticTileOffsetRange({4, 8}, {2, 4}, {0, 1}))) {
+    ASSERT_LT(idx, expected.size());
+    EXPECT_EQ(tileOffset, expected[idx]);
+    ++numTiles;
+  }
+  EXPECT_EQ(numTiles, expected.size());
+
+  // Check the constructor for default order and test use with zip iterator.
+  // `llvm::zip` stops at the shorter range, so count the pairs as well.
+  size_t numPairs = 0;
+  for (auto [tileOffset, tileOffsetDefault] :
+       llvm::zip(StaticTileOffsetRange({4, 8}, {2, 4}, {0, 1}),
+                 StaticTileOffsetRange({4, 8}, {2, 4}))) {
+    EXPECT_EQ(tileOffset, tileOffsetDefault);
+    ++numPairs;
+  }
+  EXPECT_EQ(numPairs, expected.size());
+}
+
+TEST(StaticTileOffsetRange, checkIteratorRowMajorOrder) {
+  // Tile <4x8> by <2x4> with loop order {1, 0}: dim 1 is the outer loop and
+  // dim 0 the inner one, so (despite the test name) traversal is column-major.
+  std::vector<SmallVector<int64_t>> expected = {{0, 0}, {2, 0}, {0, 4}, {2, 4}};
+  size_t numTiles = 0;
+  for (auto [idx, tileOffset] :
+       llvm::enumerate(StaticTileOffsetRange({4, 8}, {2, 4}, {1, 0}))) {
+    ASSERT_LT(idx, expected.size());
+    EXPECT_EQ(tileOffset, expected[idx]);
+    ++numTiles;
+  }
+  EXPECT_EQ(numTiles, expected.size());
+}
+
+TEST(StaticTileOffsetRange, checkLeadingOneFill) {
+  // Tile <4x8> by <4>. A smaller tile shape gets right-aligned to the shape
+  // and the leading dimension is tiled with size 1: 4 x 2 = 8 tiles.
+  size_t numTiles = 0;
+  for (auto [idx, tileOffset] :
+       llvm::enumerate(StaticTileOffsetRange({4, 8}, {4}))) {
+    SmallVector<int64_t> expected = {static_cast<int64_t>(idx) / 2,
+                                     static_cast<int64_t>(idx) % 2 * 4};
+    EXPECT_EQ(tileOffset, expected);
+    ++numTiles;
+  }
+  EXPECT_EQ(numTiles, 8u);
+
+  // Tile <1x4x8> by <4> with loop order {2, 1, 0}: 1 x 4 x 2 = 8 tiles, and
+  // dim 1 varies fastest since dim 0 has a single tile.
+  numTiles = 0;
+  for (auto [idx, tileOffset] :
+       llvm::enumerate(StaticTileOffsetRange({1, 4, 8}, {4}, {2, 1, 0}))) {
+    SmallVector<int64_t> expected = {0, static_cast<int64_t>(idx) % 4,
+                                     (static_cast<int64_t>(idx) / 4) * 4};
+    EXPECT_EQ(tileOffset, expected);
+    ++numTiles;
+  }
+  EXPECT_EQ(numTiles, 8u);
+}
+
+TEST(StaticTileOffsetRange, checkIterator3DPermutation) {
+  // Tile <8x4x2> by <4x2x1> with permutation [1, 0, 2]: 2 x 2 x 2 = 8 tiles.
+  size_t numTiles = 0;
+  for (auto [idx, tileOffset] : llvm::enumerate(
+           StaticTileOffsetRange({8, 4, 2}, {4, 2, 1}, {1, 0, 2}))) {
+    SmallVector<int64_t> expected = {((static_cast<int64_t>(idx) / 2) % 2) * 4,
+                                     ((static_cast<int64_t>(idx) / 4) % 2) * 2,
+                                     static_cast<int64_t>(idx) % 2};
+    EXPECT_EQ(tileOffset, expected);
+    ++numTiles;
+  }
+  EXPECT_EQ(numTiles, 8u);
+
+  // Tile <10x20x30> by <5x10x15> with permutation [2, 0, 1]: 8 tiles.
+  numTiles = 0;
+  for (auto [idx, tileOffset] : llvm::enumerate(
+           StaticTileOffsetRange({10, 20, 30}, {5, 10, 15}, {2, 0, 1}))) {
+    SmallVector<int64_t> expected = {((static_cast<int64_t>(idx) / 2) % 2) * 5,
+                                     (static_cast<int64_t>(idx) % 2) * 10,
+                                     (static_cast<int64_t>(idx) / 4) % 2 * 15};
+    EXPECT_EQ(tileOffset, expected);
+    ++numTiles;
+  }
+  EXPECT_EQ(numTiles, 8u);
+}


        


More information about the Mlir-commits mailing list