[Mlir-commits] [mlir] 203fad4 - [mlir][DialectUtils] Cleanup IndexingUtils and provide more affine variants while reusing implementations

Nicolas Vasilache llvmlistbot at llvm.org
Tue Mar 14 03:45:06 PDT 2023


Author: Nicolas Vasilache
Date: 2023-03-14T03:44:59-07:00
New Revision: 203fad476b7e9318e1c81ff39958994190866901

URL: https://github.com/llvm/llvm-project/commit/203fad476b7e9318e1c81ff39958994190866901
DIFF: https://github.com/llvm/llvm-project/commit/203fad476b7e9318e1c81ff39958994190866901.diff

LOG: [mlir][DialectUtils] Cleanup IndexingUtils and provide more affine variants while reusing implementations

Differential Revision: https://reviews.llvm.org/D145784

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Utils/IndexingUtils.h
    mlir/include/mlir/IR/AffineExpr.h
    mlir/lib/Conversion/MathToFuncs/MathToFuncs.cpp
    mlir/lib/Conversion/MathToLibm/MathToLibm.cpp
    mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp
    mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp
    mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp
    mlir/lib/Dialect/Utils/IndexingUtils.cpp
    mlir/lib/Dialect/Vector/IR/VectorOps.cpp
    mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
    mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Utils/IndexingUtils.h b/mlir/include/mlir/Dialect/Utils/IndexingUtils.h
index 64eea7eded92b..1969bd3a33121 100644
--- a/mlir/include/mlir/Dialect/Utils/IndexingUtils.h
+++ b/mlir/include/mlir/Dialect/Utils/IndexingUtils.h
@@ -23,35 +23,68 @@
 namespace mlir {
 class ArrayAttr;
 
-/// Computes and returns the linearized index of 'offsets' w.r.t. 'basis'.
-int64_t linearize(ArrayRef<int64_t> offsets, ArrayRef<int64_t> basis);
-
-/// Given the strides together with a linear index in the dimension
-/// space, returns the vector-space offsets in each dimension for a
-/// de-linearized index.
-SmallVector<int64_t> delinearize(ArrayRef<int64_t> strides,
-                                 int64_t linearIndex);
+//===----------------------------------------------------------------------===//
+// Utils that operate on static integer values.
+//===----------------------------------------------------------------------===//
 
-/// Given a set of sizes, compute and return the strides (i.e. the number of
-/// linear incides to skip along the (k-1) most minor dimensions to get the next
-/// k-slice). This is also the basis that one can use to linearize an n-D offset
-/// confined to `[0 .. sizes]`.
-SmallVector<int64_t> computeStrides(ArrayRef<int64_t> sizes);
+/// Given a set of sizes, return the suffix product.
+///
+/// When applied to slicing, this is the calculation needed to derive the
+/// strides (i.e. the number of linear indices to skip along the (k-1) most
+/// minor dimensions to get the next k-slice).
+///
+/// This is the basis to linearize an n-D offset confined to `[0 ... sizes]`.
+///
+/// Assuming `sizes` is `[s0, .. sn]`, return the vector<int64_t>
+///   `[s1 * ... * sn, s2 * ... * sn, ..., sn, 1]`.
+///
+/// `sizes` elements are asserted to be non-negative.
+///
+/// Return an empty vector if `sizes` is empty.
+SmallVector<int64_t> computeSuffixProduct(ArrayRef<int64_t> sizes);
+inline SmallVector<int64_t> computeStrides(ArrayRef<int64_t> sizes) {
+  return computeSuffixProduct(sizes);
+}
 
-/// Return a vector containing llvm::zip of v1 and v2 multiplied elementwise.
+/// Return a vector containing llvm::zip_equal(v1, v2) multiplied elementwise.
+///
+/// Return an empty vector if `v1` and `v2` are empty.
 SmallVector<int64_t> computeElementwiseMul(ArrayRef<int64_t> v1,
                                            ArrayRef<int64_t> v2);
 
-/// Compute and return the multi-dimensional integral ratio of `subShape` to
-/// the trailing dimensions of `shape`. This represents how many times
-/// `subShape` fits within `shape`.
-/// If integral division is not possible, return std::nullopt.
+/// Return the number of elements of basis (i.e. the max linear index),
+/// computed as the product of all `basis` elements.
+///
+/// `basis` elements are asserted to be non-negative.
+///
+/// Return `0` if `basis` is empty.
+int64_t computeMaxLinearIndex(ArrayRef<int64_t> basis);
+
+/// Return the linearized index of 'offsets' w.r.t. 'basis'.
+///
+/// `basis` elements are asserted to be non-negative.
+int64_t linearize(ArrayRef<int64_t> offsets, ArrayRef<int64_t> basis);
+
+/// Given the strides together with a linear index in the dimension space,
+/// return the vector-space offsets in each dimension for a de-linearized index.
+/// `strides` elements are asserted to be positive.
+///
+/// Let `li = linearIndex`, assuming `strides` are `[s0, .. sn]`, return the
+/// vector of int64_t
+///   `[li / s0, (li % s0) / s1, ..., (li % s0 % .. % sn-1) / sn]`
+SmallVector<int64_t> delinearize(int64_t linearIndex,
+                                 ArrayRef<int64_t> strides);
+
+/// Return the multi-dimensional integral ratio of `subShape` to the trailing
+/// dimensions of `shape`. This represents how many times `subShape` fits
+/// within `shape`. If integral division is not possible, return std::nullopt.
 /// The trailing `subShape.size()` entries of both shapes are assumed (and
-/// enforced) to only contain noonnegative values.
+/// enforced) to only contain non-negative values.
 ///
 /// Examples:
 ///   - shapeRatio({3, 5, 8}, {2, 5, 2}) returns {3, 2, 1}.
-///   - shapeRatio({3, 8}, {2, 5, 2}) returns std::nullopt (subshape has higher
+///   - shapeRatio({3, 8}, {2, 5, 2}) returns std::nullopt
+///     (subshape has higher
 ///     rank).
 ///   - shapeRatio({42, 2, 10, 32}, {2, 5, 2}) returns {42, 1, 2, 16} which is
 ///     derived as {42(leading shape dim), 2/2, 10/5, 32/2}.
@@ -60,14 +93,96 @@ SmallVector<int64_t> computeElementwiseMul(ArrayRef<int64_t> v1,
 std::optional<SmallVector<int64_t>>
 computeShapeRatio(ArrayRef<int64_t> shape, ArrayRef<int64_t> subShape);
 
+//===----------------------------------------------------------------------===//
+// Utils that operate on AffineExpr.
+//===----------------------------------------------------------------------===//
+
+/// Given a set of sizes, return the suffix product.
+///
+/// When applied to slicing, this is the calculation needed to derive the
+/// strides (i.e. the number of linear indices to skip along the (k-1) most
+/// minor dimensions to get the next k-slice).
+///
+/// This is the basis to linearize an n-D offset confined to `[0 ... sizes]`.
+///
+/// Assuming `sizes` is `[s0, .. sn]`, return the vector<AffineExpr>
+///   `[s1 * ... * sn, s2 * ... * sn, ..., sn, 1]`.
+///
+/// It is the caller's responsibility to pass proper AffineExpr kinds that
+/// result in valid AffineExpr (i.e. cannot multiply 2 AffineDimExpr or divide
+/// by an AffineDimExpr).
+///
+/// `sizes` elements are expected to bind to non-negative values.
+///
+/// Return an empty vector if `sizes` is empty.
+SmallVector<AffineExpr> computeSuffixProduct(ArrayRef<AffineExpr> sizes);
+inline SmallVector<AffineExpr> computeStrides(ArrayRef<AffineExpr> sizes) {
+  return computeSuffixProduct(sizes);
+}
+
+/// Return a vector containing llvm::zip_equal(v1, v2) multiplied elementwise.
+///
+/// It is the caller's responsibility to pass proper AffineExpr kind that
+/// result in valid AffineExpr (i.e. cannot multiply 2 AffineDimExpr or divide
+/// by an AffineDimExpr).
+///
+/// Return an empty vector if `v1` and `v2` are empty.
+SmallVector<AffineExpr> computeElementwiseMul(ArrayRef<AffineExpr> v1,
+                                              ArrayRef<AffineExpr> v2);
+
 /// Return the number of elements of basis (i.e. the max linear index).
 /// Return `0` if `basis` is empty.
-int64_t computeMaxLinearIndex(ArrayRef<int64_t> basis);
+///
+/// It is the caller's responsibility to pass proper AffineExpr kinds that
+/// result in valid AffineExpr (i.e. cannot multiply 2 AffineDimExpr or divide
+/// by an AffineDimExpr).
+///
+/// `basis` elements are expected to bind to non-negative values.
+///
+/// Return the `0` AffineConstantExpr if `basis` is empty.
+AffineExpr computeMaxLinearIndex(MLIRContext *ctx, ArrayRef<AffineExpr> basis);
+
+/// Return the linearized index of 'offsets' w.r.t. 'basis'.
+///
+/// Assuming `offsets` is `[o0, .. on]` and `basis` is `[b0, .. bn]`, return the
+/// AffineExpr `o0 * b0 + .. + on * bn`.
+///
+/// It is the caller's responsibility to pass proper AffineExpr kinds that
+/// result in valid AffineExpr (i.e. cannot multiply 2 AffineDimExpr or divide
+/// by an AffineDimExpr).
+///
+/// `basis` elements are expected to bind to non-negative values.
+AffineExpr linearize(MLIRContext *ctx, ArrayRef<AffineExpr> offsets,
+                     ArrayRef<AffineExpr> basis);
+AffineExpr linearize(MLIRContext *ctx, ArrayRef<AffineExpr> offsets,
+                     ArrayRef<int64_t> basis);
+
+/// Given the strides together with a linear index in the dimension space,
+/// return the vector-space offsets in each dimension for a de-linearized index.
+///
+/// Let `li = linearIndex`, assuming `strides` are `[s0, .. sn]`, return the
+/// vector of AffineExpr
+///   `[li / s0, (li % s0) / s1, ..., (li % s0 % .. % sn-1) / sn]`
+///
+/// It is the caller's responsibility to pass proper AffineExpr kinds that
+/// result in valid AffineExpr (i.e. cannot multiply 2 AffineDimExpr or divide
+/// by an AffineDimExpr).
+///
+/// `strides` elements are expected to bind to positive values.
+SmallVector<AffineExpr> delinearize(AffineExpr linearIndex,
+                                    ArrayRef<AffineExpr> strides);
+SmallVector<AffineExpr> delinearize(AffineExpr linearIndex,
+                                    ArrayRef<int64_t> strides);
+
+//===----------------------------------------------------------------------===//
+// Permutation utils.
+//===----------------------------------------------------------------------===//
 
 /// Apply the permutation defined by `permutation` to `inVec`.
 /// Element `i` in `inVec` is mapped to location `j = permutation[i]`.
-/// E.g.: for an input vector `inVec = ['a', 'b', 'c']` and a permutation vector
-/// `permutation = [2, 0, 1]`, this function leaves `inVec = ['c', 'a', 'b']`.
+/// E.g.: for an input vector `inVec = ['a', 'b', 'c']` and a permutation
+/// vector `permutation = [2, 0, 1]`, this function leaves
+/// `inVec = ['c', 'a', 'b']`.
 template <typename T, unsigned N>
 void applyPermutationToVector(SmallVector<T, N> &inVec,
                               ArrayRef<int64_t> permutation) {
@@ -83,18 +198,11 @@ SmallVector<int64_t> invertPermutationVector(ArrayRef<int64_t> permutation);
 /// Method to check if an interchange vector is a permutation.
 bool isPermutationVector(ArrayRef<int64_t> interchange);
 
-/// Helper that returns a subset of `arrayAttr` as a vector of int64_t.
+/// Helper to return a subset of `arrayAttr` as a vector of int64_t.
+// TODO: Port everything relevant to DenseArrayAttr and drop this util.
 SmallVector<int64_t> getI64SubArray(ArrayAttr arrayAttr, unsigned dropFront = 0,
                                     unsigned dropBack = 0);
 
-/// Computes and returns linearized affine expression w.r.t. `basis`.
-mlir::AffineExpr getLinearAffineExpr(ArrayRef<int64_t> basis, mlir::Builder &b);
-
-/// Given the strides in the dimension space, returns the affine expressions for
-/// vector-space offsets in each dimension for a de-linearized index.
-SmallVector<mlir::AffineExpr>
-getDelinearizedAffineExpr(ArrayRef<int64_t> strides, mlir::Builder &b);
-
 } // namespace mlir
 
 #endif // MLIR_DIALECT_UTILS_INDEXINGUTILS_H

diff  --git a/mlir/include/mlir/IR/AffineExpr.h b/mlir/include/mlir/IR/AffineExpr.h
index 1d0688a38003f..74cd5e1e557d5 100644
--- a/mlir/include/mlir/IR/AffineExpr.h
+++ b/mlir/include/mlir/IR/AffineExpr.h
@@ -321,13 +321,6 @@ void bindSymbols(MLIRContext *ctx, AffineExprTy &e, AffineExprTy2 &...exprs) {
   bindSymbols<N + 1, AffineExprTy2 &...>(ctx, exprs...);
 }
 
-template <typename AffineExprTy>
-void bindSymbolsList(MLIRContext *ctx, SmallVectorImpl<AffineExprTy> &exprs) {
-  int idx = 0;
-  for (AffineExprTy &e : exprs)
-    e = getAffineSymbolExpr(idx++, ctx);
-}
-
 } // namespace detail
 
 /// Bind a list of AffineExpr references to DimExpr at positions:
@@ -337,6 +330,13 @@ void bindDims(MLIRContext *ctx, AffineExprTy &...exprs) {
   detail::bindDims<0>(ctx, exprs...);
 }
 
+template <typename AffineExprTy>
+void bindDimsList(MLIRContext *ctx, MutableArrayRef<AffineExprTy> exprs) {
+  int idx = 0;
+  for (AffineExprTy &e : exprs)
+    e = getAffineDimExpr(idx++, ctx);
+}
+
 /// Bind a list of AffineExpr references to SymbolExpr at positions:
 ///   [0 .. sizeof...(exprs)]
 template <typename... AffineExprTy>
@@ -344,6 +344,13 @@ void bindSymbols(MLIRContext *ctx, AffineExprTy &...exprs) {
   detail::bindSymbols<0>(ctx, exprs...);
 }
 
+template <typename AffineExprTy>
+void bindSymbolsList(MLIRContext *ctx, MutableArrayRef<AffineExprTy> exprs) {
+  int idx = 0;
+  for (AffineExprTy &e : exprs)
+    e = getAffineSymbolExpr(idx++, ctx);
+}
+
 } // namespace mlir
 
 namespace llvm {

diff  --git a/mlir/lib/Conversion/MathToFuncs/MathToFuncs.cpp b/mlir/lib/Conversion/MathToFuncs/MathToFuncs.cpp
index d24a54c7cbc88..c177e50e2c3c9 100644
--- a/mlir/lib/Conversion/MathToFuncs/MathToFuncs.cpp
+++ b/mlir/lib/Conversion/MathToFuncs/MathToFuncs.cpp
@@ -103,7 +103,7 @@ VecOpToScalarOp<Op>::matchAndRewrite(Op op, PatternRewriter &rewriter) const {
       loc, DenseElementsAttr::get(vecType, initValueAttr));
   SmallVector<int64_t> strides = computeStrides(shape);
   for (int64_t linearIndex = 0; linearIndex < numElements; ++linearIndex) {
-    SmallVector<int64_t> positions = delinearize(strides, linearIndex);
+    SmallVector<int64_t> positions = delinearize(linearIndex, strides);
     SmallVector<Value> operands;
     for (Value input : op->getOperands())
       operands.push_back(

diff  --git a/mlir/lib/Conversion/MathToLibm/MathToLibm.cpp b/mlir/lib/Conversion/MathToLibm/MathToLibm.cpp
index 1e5b317caa941..d6834441e92bb 100644
--- a/mlir/lib/Conversion/MathToLibm/MathToLibm.cpp
+++ b/mlir/lib/Conversion/MathToLibm/MathToLibm.cpp
@@ -89,7 +89,7 @@ VecOpToScalarOp<Op>::matchAndRewrite(Op op, PatternRewriter &rewriter) const {
                vecType, FloatAttr::get(vecType.getElementType(), 0.0)));
   SmallVector<int64_t> strides = computeStrides(shape);
   for (auto linearIndex = 0; linearIndex < numElements; ++linearIndex) {
-    SmallVector<int64_t> positions = delinearize(strides, linearIndex);
+    SmallVector<int64_t> positions = delinearize(linearIndex, strides);
     SmallVector<Value> operands;
     for (auto input : op->getOperands())
       operands.push_back(

diff  --git a/mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp b/mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp
index 0d170f985fc30..6cbb50b0425b3 100644
--- a/mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp
+++ b/mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp
@@ -134,7 +134,7 @@ handleMultidimensionalVectors(ImplicitLocOpBuilder &builder,
   SmallVector<Value> results(maxIndex);
 
   for (int64_t i = 0; i < maxIndex; ++i) {
-    auto offsets = delinearize(strides, i);
+    auto offsets = delinearize(i, strides);
 
     SmallVector<Value> extracted(expandedOperands.size());
     for (const auto &tuple : llvm::enumerate(expandedOperands))
@@ -152,7 +152,7 @@ handleMultidimensionalVectors(ImplicitLocOpBuilder &builder,
 
   for (int64_t i = 0; i < maxIndex; ++i)
     result = builder.create<vector::InsertOp>(results[i], result,
-                                              delinearize(strides, i));
+                                              delinearize(i, strides));
 
   // Reshape back to the original vector shape.
   return builder.create<vector::ShapeCastOp>(

diff  --git a/mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp b/mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp
index 8ef16d5eeaec4..72e6a2925e83f 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp
@@ -75,7 +75,7 @@ struct SubviewFolder : public OpRewritePattern<memref::SubViewOp> {
     SmallVector<OpFoldResult> values(2 * sourceRank + 1);
     SmallVector<AffineExpr> symbols(2 * sourceRank + 1);
 
-    detail::bindSymbolsList(rewriter.getContext(), symbols);
+    bindSymbolsList(rewriter.getContext(), MutableArrayRef{symbols});
     AffineExpr expr = symbols.front();
     values[0] = ShapedType::isDynamic(sourceOffset)
                     ? getAsOpFoldResult(newExtractStridedMetadata.getOffset())
@@ -262,10 +262,9 @@ SmallVector<OpFoldResult> getExpandedStrides(memref::ExpandShapeOp expandShape,
   auto sourceType = source.getType().cast<MemRefType>();
   auto [strides, offset] = getStridesAndOffset(sourceType);
 
-  OpFoldResult origStride =
-      ShapedType::isDynamic(strides[groupId])
-          ? origStrides[groupId]
-          : builder.getIndexAttr(strides[groupId]);
+  OpFoldResult origStride = ShapedType::isDynamic(strides[groupId])
+                                ? origStrides[groupId]
+                                : builder.getIndexAttr(strides[groupId]);
 
   // Apply the original stride to all the strides.
   int64_t doneStrideIdx = 0;

diff  --git a/mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp b/mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp
index 129ac41233058..73264bc76933a 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp
@@ -18,6 +18,7 @@
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/Dialect/Utils/IndexingUtils.h"
 #include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/IR/AffineExpr.h"
 #include "mlir/IR/BuiltinTypes.h"
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
 #include "llvm/ADT/SmallBitVector.h"
@@ -54,24 +55,25 @@ resolveSourceIndicesExpandShape(Location loc, PatternRewriter &rewriter,
                                 memref::ExpandShapeOp expandShapeOp,
                                 ValueRange indices,
                                 SmallVectorImpl<Value> &sourceIndices) {
-  for (SmallVector<int64_t, 2> groups :
-       expandShapeOp.getReassociationIndices()) {
+  MLIRContext *ctx = rewriter.getContext();
+  for (ArrayRef<int64_t> groups : expandShapeOp.getReassociationIndices()) {
     assert(!groups.empty() && "association indices groups cannot be empty");
-    unsigned groupSize = groups.size();
-    SmallVector<int64_t> suffixProduct(groupSize);
-    // Calculate suffix product of dimension sizes for all dimensions of expand
-    // shape op result.
-    suffixProduct[groupSize - 1] = 1;
-    for (unsigned i = groupSize - 1; i > 0; i--)
-      suffixProduct[i - 1] =
-          suffixProduct[i] *
-          expandShapeOp.getType().cast<MemRefType>().getDimSize(groups[i]);
-    SmallVector<Value> dynamicIndices(groupSize);
-    for (unsigned i = 0; i < groupSize; i++)
-      dynamicIndices[i] = indices[groups[i]];
+    int64_t groupSize = groups.size();
+
     // Construct the expression for the index value w.r.t to expand shape op
     // source corresponding the indices wrt to expand shape op result.
-    AffineExpr srcIndexExpr = getLinearAffineExpr(suffixProduct, rewriter);
+    SmallVector<int64_t> sizes(groupSize);
+    for (int64_t i = 0; i < groupSize; ++i)
+      sizes[i] = expandShapeOp.getResultType().getDimSize(groups[i]);
+    SmallVector<int64_t> suffixProduct = computeSuffixProduct(sizes);
+    SmallVector<AffineExpr> dims(groupSize);
+    bindDimsList(ctx, MutableArrayRef{dims});
+    AffineExpr srcIndexExpr = linearize(ctx, dims, suffixProduct);
+
+    // Gather the indices of this group and create the AffineApplyOp.
+    SmallVector<Value> dynamicIndices(groupSize);
+    for (int64_t i = 0; i < groupSize; i++)
+      dynamicIndices[i] = indices[groups[i]];
     sourceIndices.push_back(rewriter.create<AffineApplyOp>(
         loc,
         AffineMap::get(/*numDims=*/groupSize, /*numSymbols=*/0, srcIndexExpr),
@@ -98,35 +100,39 @@ resolveSourceIndicesCollapseShape(Location loc, PatternRewriter &rewriter,
                                   memref::CollapseShapeOp collapseShapeOp,
                                   ValueRange indices,
                                   SmallVectorImpl<Value> &sourceIndices) {
-  unsigned cnt = 0;
+  int64_t cnt = 0;
   SmallVector<Value> tmp(indices.size());
   SmallVector<Value> dynamicIndices;
-  for (SmallVector<int64_t, 2> groups :
-       collapseShapeOp.getReassociationIndices()) {
+  for (ArrayRef<int64_t> groups : collapseShapeOp.getReassociationIndices()) {
     assert(!groups.empty() && "association indices groups cannot be empty");
     dynamicIndices.push_back(indices[cnt++]);
-    unsigned groupSize = groups.size();
-    SmallVector<int64_t> suffixProduct(groupSize);
+    int64_t groupSize = groups.size();
+
     // Calculate suffix product for all collapse op source dimension sizes.
-    suffixProduct[groupSize - 1] = 1;
-    for (unsigned i = groupSize - 1; i > 0; i--)
-      suffixProduct[i - 1] =
-          suffixProduct[i] * collapseShapeOp.getSrcType().getDimSize(groups[i]);
+    SmallVector<int64_t> sizes(groupSize);
+    for (int64_t i = 0; i < groupSize; ++i)
+      sizes[i] = collapseShapeOp.getSrcType().getDimSize(groups[i]);
+    SmallVector<int64_t> suffixProduct = computeSuffixProduct(sizes);
+
     // Derive the index values along all dimensions of the source corresponding
     // to the index wrt to collapsed shape op output.
-    SmallVector<AffineExpr, 4> srcIndexExpr =
-        getDelinearizedAffineExpr(suffixProduct, rewriter);
-    for (unsigned i = 0; i < groupSize; i++)
+    auto d0 = rewriter.getAffineDimExpr(0);
+    SmallVector<AffineExpr> delinearizingExprs = delinearize(d0, suffixProduct);
+
+    // Construct the AffineApplyOp for each delinearizingExpr.
+    for (int64_t i = 0; i < groupSize; i++)
       sourceIndices.push_back(rewriter.create<AffineApplyOp>(
-          loc, AffineMap::get(/*numDims=*/1, /*numSymbols=*/0, srcIndexExpr[i]),
+          loc,
+          AffineMap::get(/*numDims=*/1, /*numSymbols=*/0,
+                         delinearizingExprs[i]),
           dynamicIndices));
     dynamicIndices.clear();
   }
   if (collapseShapeOp.getReassociationIndices().empty()) {
     auto zeroAffineMap = rewriter.getConstantAffineMap(0);
-    unsigned srcRank =
+    int64_t srcRank =
         collapseShapeOp.getViewSource().getType().cast<MemRefType>().getRank();
-    for (unsigned i = 0; i < srcRank; i++)
+    for (int64_t i = 0; i < srcRank; i++)
       sourceIndices.push_back(
           rewriter.create<AffineApplyOp>(loc, zeroAffineMap, dynamicIndices));
   }
@@ -157,9 +163,9 @@ resolveSourceIndicesSubView(Location loc, PatternRewriter &rewriter,
   SmallVector<Value> useIndices;
   // Check if this is rank-reducing case. Then for every unit-dim size add a
   // zero to the indices.
-  unsigned resultDim = 0;
+  int64_t resultDim = 0;
   llvm::SmallBitVector unusedDims = subViewOp.getDroppedDims();
-  for (auto dim : llvm::seq<unsigned>(0, subViewOp.getSourceType().getRank())) {
+  for (auto dim : llvm::seq<int64_t>(0, subViewOp.getSourceType().getRank())) {
     if (unusedDims.test(dim))
       useIndices.push_back(rewriter.create<arith::ConstantIndexOp>(loc, 0));
     else
@@ -171,7 +177,7 @@ resolveSourceIndicesSubView(Location loc, PatternRewriter &rewriter,
   for (auto index : llvm::seq<size_t>(0, mixedOffsets.size())) {
     SmallVector<Value> dynamicOperands;
     AffineExpr expr = rewriter.getAffineDimExpr(0);
-    unsigned numSymbols = 0;
+    int64_t numSymbols = 0;
     dynamicOperands.push_back(useIndices[index]);
 
     // Multiply the stride;
@@ -378,7 +384,7 @@ LogicalResult LoadOpOfSubViewOpFolder<OpTy>::matchAndRewrite(
         affineMap, indices, loadOp.getLoc(), rewriter);
     indices.assign(expandedIndices.begin(), expandedIndices.end());
   }
-  SmallVector<Value, 4> sourceIndices;
+  SmallVector<Value> sourceIndices;
   if (failed(resolveSourceIndicesSubView(loadOp.getLoc(), rewriter, subViewOp,
                                          indices, sourceIndices)))
     return failure();
@@ -424,7 +430,7 @@ LogicalResult LoadOpOfExpandShapeOpFolder<OpTy>::matchAndRewrite(
         affineMap, indices, loadOp.getLoc(), rewriter);
     indices.assign(expandedIndices.begin(), expandedIndices.end());
   }
-  SmallVector<Value, 4> sourceIndices;
+  SmallVector<Value> sourceIndices;
   if (failed(resolveSourceIndicesExpandShape(
           loadOp.getLoc(), rewriter, expandShapeOp, indices, sourceIndices)))
     return failure();
@@ -456,7 +462,7 @@ LogicalResult LoadOpOfCollapseShapeOpFolder<OpTy>::matchAndRewrite(
         affineMap, indices, loadOp.getLoc(), rewriter);
     indices.assign(expandedIndices.begin(), expandedIndices.end());
   }
-  SmallVector<Value, 4> sourceIndices;
+  SmallVector<Value> sourceIndices;
   if (failed(resolveSourceIndicesCollapseShape(
           loadOp.getLoc(), rewriter, collapseShapeOp, indices, sourceIndices)))
     return failure();
@@ -488,7 +494,7 @@ LogicalResult StoreOpOfSubViewOpFolder<OpTy>::matchAndRewrite(
         affineMap, indices, storeOp.getLoc(), rewriter);
     indices.assign(expandedIndices.begin(), expandedIndices.end());
   }
-  SmallVector<Value, 4> sourceIndices;
+  SmallVector<Value> sourceIndices;
   if (failed(resolveSourceIndicesSubView(storeOp.getLoc(), rewriter, subViewOp,
                                          indices, sourceIndices)))
     return failure();
@@ -533,7 +539,7 @@ LogicalResult StoreOpOfExpandShapeOpFolder<OpTy>::matchAndRewrite(
         affineMap, indices, storeOp.getLoc(), rewriter);
     indices.assign(expandedIndices.begin(), expandedIndices.end());
   }
-  SmallVector<Value, 4> sourceIndices;
+  SmallVector<Value> sourceIndices;
   if (failed(resolveSourceIndicesExpandShape(
           storeOp.getLoc(), rewriter, expandShapeOp, indices, sourceIndices)))
     return failure();
@@ -566,7 +572,7 @@ LogicalResult StoreOpOfCollapseShapeOpFolder<OpTy>::matchAndRewrite(
         affineMap, indices, storeOp.getLoc(), rewriter);
     indices.assign(expandedIndices.begin(), expandedIndices.end());
   }
-  SmallVector<Value, 4> sourceIndices;
+  SmallVector<Value> sourceIndices;
   if (failed(resolveSourceIndicesCollapseShape(
           storeOp.getLoc(), rewriter, collapseShapeOp, indices, sourceIndices)))
     return failure();

diff  --git a/mlir/lib/Dialect/Utils/IndexingUtils.cpp b/mlir/lib/Dialect/Utils/IndexingUtils.cpp
index 4563d74acfde4..eb86e0f782a78 100644
--- a/mlir/lib/Dialect/Utils/IndexingUtils.cpp
+++ b/mlir/lib/Dialect/Utils/IndexingUtils.cpp
@@ -11,27 +11,100 @@
 #include "mlir/IR/AffineExpr.h"
 #include "mlir/IR/Builders.h"
 #include "mlir/IR/BuiltinAttributes.h"
+#include "mlir/IR/MLIRContext.h"
+#include "llvm/ADT/STLExtras.h"
 
 #include <numeric>
 #include <optional>
 
 using namespace mlir;
 
-SmallVector<int64_t> mlir::computeStrides(ArrayRef<int64_t> sizes) {
-  SmallVector<int64_t> strides(sizes.size(), 1);
+template <typename ExprType>
+SmallVector<ExprType> computeSuffixProductImpl(ArrayRef<ExprType> sizes,
+                                               ExprType unit) {
+  if (sizes.empty())
+    return {};
+  SmallVector<ExprType> strides(sizes.size(), unit);
   for (int64_t r = strides.size() - 2; r >= 0; --r)
     strides[r] = strides[r + 1] * sizes[r + 1];
   return strides;
 }
 
-SmallVector<int64_t> mlir::computeElementwiseMul(ArrayRef<int64_t> v1,
-                                                 ArrayRef<int64_t> v2) {
-  SmallVector<int64_t> result;
-  for (auto it : llvm::zip(v1, v2))
+template <typename ExprType>
+SmallVector<ExprType> computeElementwiseMulImpl(ArrayRef<ExprType> v1,
+                                                ArrayRef<ExprType> v2) {
+  // Early exit if both are empty, let zip_equal fail if only 1 is empty.
+  if (v1.empty() && v2.empty())
+    return {};
+  SmallVector<ExprType> result;
+  for (auto it : llvm::zip_equal(v1, v2))
     result.push_back(std::get<0>(it) * std::get<1>(it));
   return result;
 }
 
+template <typename ExprType>
+ExprType linearizeImpl(ArrayRef<ExprType> offsets, ArrayRef<ExprType> basis,
+                       ExprType zero) {
+  assert(offsets.size() == basis.size());
+  ExprType linearIndex = zero;
+  for (unsigned idx = 0, e = basis.size(); idx < e; ++idx)
+    linearIndex = linearIndex + offsets[idx] * basis[idx];
+  return linearIndex;
+}
+
+template <typename ExprType, typename DivOpTy>
+SmallVector<ExprType> delinearizeImpl(ExprType linearIndex,
+                                      ArrayRef<ExprType> strides,
+                                      DivOpTy divOp) {
+  int64_t rank = strides.size();
+  SmallVector<ExprType> offsets(rank);
+  for (int64_t r = 0; r < rank; ++r) {
+    offsets[r] = divOp(linearIndex, strides[r]);
+    linearIndex = linearIndex % strides[r];
+  }
+  return offsets;
+}
+
+//===----------------------------------------------------------------------===//
+// Utils that operate on static integer values.
+//===----------------------------------------------------------------------===//
+
+SmallVector<int64_t> mlir::computeSuffixProduct(ArrayRef<int64_t> sizes) {
+  assert(llvm::all_of(sizes, [](int64_t s) { return s >= 0; }) &&
+         "sizes must be nonnegative");
+  int64_t unit = 1;
+  return ::computeSuffixProductImpl(sizes, unit);
+}
+
+SmallVector<int64_t> mlir::computeElementwiseMul(ArrayRef<int64_t> v1,
+                                                 ArrayRef<int64_t> v2) {
+  return computeElementwiseMulImpl(v1, v2);
+}
+
+int64_t mlir::computeMaxLinearIndex(ArrayRef<int64_t> basis) {
+  assert(llvm::all_of(basis, [](int64_t s) { return s >= 0; }) &&
+         "basis must be nonnegative");
+  if (basis.empty())
+    return 0;
+  return std::accumulate(basis.begin(), basis.end(), int64_t(1),
+                         std::multiplies<int64_t>());
+}
+
+int64_t mlir::linearize(ArrayRef<int64_t> offsets, ArrayRef<int64_t> basis) {
+  assert(llvm::all_of(basis, [](int64_t s) { return s >= 0; }) &&
+         "basis must be nonnegative");
+  int64_t zero = 0;
+  return linearizeImpl(offsets, basis, zero);
+}
+
+SmallVector<int64_t> mlir::delinearize(int64_t linearIndex,
+                                       ArrayRef<int64_t> strides) {
+  assert(llvm::all_of(strides, [](int64_t s) { return s > 0; }) &&
+         "strides must be positive");
+  return delinearizeImpl(linearIndex, strides,
+                         [](int64_t e1, int64_t e2) { return e1 / e2; });
+}
+
 std::optional<SmallVector<int64_t>>
 mlir::computeShapeRatio(ArrayRef<int64_t> shape, ArrayRef<int64_t> subShape) {
   if (shape.size() < subShape.size())
@@ -60,35 +133,67 @@ mlir::computeShapeRatio(ArrayRef<int64_t> shape, ArrayRef<int64_t> subShape) {
   return SmallVector<int64_t>{result.rbegin(), result.rend()};
 }
 
-int64_t mlir::linearize(ArrayRef<int64_t> offsets, ArrayRef<int64_t> basis) {
-  assert(offsets.size() == basis.size());
-  int64_t linearIndex = 0;
-  for (unsigned idx = 0, e = basis.size(); idx < e; ++idx)
-    linearIndex += offsets[idx] * basis[idx];
-  return linearIndex;
+//===----------------------------------------------------------------------===//
+// Utils that operate on AffineExpr.
+//===----------------------------------------------------------------------===//
+
+SmallVector<AffineExpr> mlir::computeSuffixProduct(ArrayRef<AffineExpr> sizes) {
+  if (sizes.empty()) // No element to borrow an MLIRContext from.
+    return {};
+  AffineExpr unit = getAffineConstantExpr(1, sizes.front().getContext());
+  return ::computeSuffixProductImpl(sizes, unit);
 }
 
-llvm::SmallVector<int64_t> mlir::delinearize(ArrayRef<int64_t> sliceStrides,
-                                             int64_t index) {
-  int64_t rank = sliceStrides.size();
-  SmallVector<int64_t> vectorOffsets(rank);
-  for (int64_t r = 0; r < rank; ++r) {
-    assert(sliceStrides[r] > 0);
-    vectorOffsets[r] = index / sliceStrides[r];
-    index %= sliceStrides[r];
-  }
-  return vectorOffsets;
+SmallVector<AffineExpr> mlir::computeElementwiseMul(ArrayRef<AffineExpr> v1,
+                                                    ArrayRef<AffineExpr> v2) {
+  return computeElementwiseMulImpl(v1, v2); // Shared with int64_t variant.
 }
 
-int64_t mlir::computeMaxLinearIndex(ArrayRef<int64_t> basis) {
+AffineExpr mlir::computeMaxLinearIndex(MLIRContext *ctx,
+                                       ArrayRef<AffineExpr> basis) {
   if (basis.empty())
-    return 0;
-  return std::accumulate(basis.begin(), basis.end(), 1,
-                         std::multiplies<int64_t>());
+    return getAffineConstantExpr(0, ctx); // As in the int64_t overload.
+  return std::accumulate(basis.begin(), basis.end(),
+                         getAffineConstantExpr(1, ctx), // Empty product.
+                         std::multiplies<AffineExpr>());
+}
+
+AffineExpr mlir::linearize(MLIRContext *ctx, ArrayRef<AffineExpr> offsets,
+                           ArrayRef<AffineExpr> basis) {
+  AffineExpr zero = getAffineConstantExpr(0, ctx);
+  return linearizeImpl(offsets, basis, zero); // Shared with int64_t overload.
 }
 
-llvm::SmallVector<int64_t>
+AffineExpr mlir::linearize(MLIRContext *ctx, ArrayRef<AffineExpr> offsets,
+                           ArrayRef<int64_t> basis) {
+  SmallVector<AffineExpr> basisExprs = llvm::to_vector(llvm::map_range(
+      basis, [ctx](int64_t v) { return getAffineConstantExpr(v, ctx); }));
+  return linearize(ctx, offsets, basisExprs); // AffineExpr-basis overload.
+}
+
+SmallVector<AffineExpr> mlir::delinearize(AffineExpr linearIndex,
+                                          ArrayRef<AffineExpr> strides) {
+  return delinearizeImpl(
+      linearIndex, strides, // floorDiv (int64_t overload: '/').
+      [](AffineExpr e1, AffineExpr e2) { return e1.floorDiv(e2); });
+}
+
+SmallVector<AffineExpr> mlir::delinearize(AffineExpr linearIndex,
+                                          ArrayRef<int64_t> strides) {
+  MLIRContext *ctx = linearIndex.getContext(); // Context for the constants.
+  SmallVector<AffineExpr> strideExprs = llvm::to_vector(llvm::map_range(
+      strides, [ctx](int64_t v) { return getAffineConstantExpr(v, ctx); }));
+  return delinearize(linearIndex, ArrayRef<AffineExpr>{strideExprs});
+}
+
+//===----------------------------------------------------------------------===//
+// Permutation utils.
+//===----------------------------------------------------------------------===//
+
+SmallVector<int64_t>
 mlir::invertPermutationVector(ArrayRef<int64_t> permutation) {
+  assert(llvm::all_of(permutation, [](int64_t s) { return s >= 0; }) &&
+         "permutation must be non-negative");
   SmallVector<int64_t> inversion(permutation.size());
   for (const auto &pos : llvm::enumerate(permutation)) {
     inversion[pos.value()] = pos.index();
@@ -97,6 +202,8 @@ mlir::invertPermutationVector(ArrayRef<int64_t> permutation) {
 }
 
 bool mlir::isPermutationVector(ArrayRef<int64_t> interchange) {
+  assert(llvm::all_of(interchange, [](int64_t s) { return s >= 0; }) &&
+         "permutation must be non-negative");
   llvm::SmallDenseSet<int64_t, 4> seenVals;
   for (auto val : interchange) {
     if (seenVals.count(val))
@@ -106,9 +213,9 @@ bool mlir::isPermutationVector(ArrayRef<int64_t> interchange) {
   return seenVals.size() == interchange.size();
 }
 
-llvm::SmallVector<int64_t> mlir::getI64SubArray(ArrayAttr arrayAttr,
-                                                unsigned dropFront,
-                                                unsigned dropBack) {
+SmallVector<int64_t> mlir::getI64SubArray(ArrayAttr arrayAttr,
+                                          unsigned dropFront,
+                                          unsigned dropBack) {
   assert(arrayAttr.size() > dropFront + dropBack && "Out of bounds");
   auto range = arrayAttr.getAsRange<IntegerAttr>();
   SmallVector<int64_t> res;
@@ -118,26 +225,3 @@ llvm::SmallVector<int64_t> mlir::getI64SubArray(ArrayAttr arrayAttr,
     res.push_back((*it).getValue().getSExtValue());
   return res;
 }
-
-mlir::AffineExpr mlir::getLinearAffineExpr(ArrayRef<int64_t> basis,
-                                           mlir::Builder &b) {
-  AffineExpr resultExpr = b.getAffineDimExpr(0);
-  resultExpr = resultExpr * basis[0];
-  for (unsigned i = 1; i < basis.size(); i++)
-    resultExpr = resultExpr + b.getAffineDimExpr(i) * basis[i];
-  return resultExpr;
-}
-
-llvm::SmallVector<mlir::AffineExpr>
-mlir::getDelinearizedAffineExpr(mlir::ArrayRef<int64_t> strides, Builder &b) {
-  AffineExpr resultExpr = b.getAffineDimExpr(0);
-  int64_t rank = strides.size();
-  SmallVector<AffineExpr> vectorOffsets(rank);
-  vectorOffsets[0] = resultExpr.floorDiv(strides[0]);
-  resultExpr = resultExpr % strides[0];
-  for (unsigned i = 1; i < rank; i++) {
-    vectorOffsets[i] = resultExpr.floorDiv(strides[i]);
-    resultExpr = resultExpr % strides[i];
-  }
-  return vectorOffsets;
-}

diff  --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index cabe7d7d95563..0180c7192fb2c 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -1558,7 +1558,7 @@ static Value foldExtractFromShapeCast(ExtractOp extractOp) {
         getDimReverse(shapeCastOp.getSourceVectorType(), i + destinationRank);
   }
   std::reverse(newStrides.begin(), newStrides.end());
-  SmallVector<int64_t, 4> newPosition = delinearize(newStrides, position);
+  SmallVector<int64_t, 4> newPosition = delinearize(position, newStrides);
   // OpBuilder is only used as a helper to build an I64ArrayAttr.
   OpBuilder b(extractOp.getContext());
   extractOp->setAttr(ExtractOp::getPositionAttrStrName(),

diff  --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
index ef6804f5511ae..9e5d787856b1f 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
@@ -457,7 +457,7 @@ class TransposeOpLowering : public OpRewritePattern<vector::TransposeOp> {
 
     for (int64_t linearIdx = 0; linearIdx < numTransposedElements;
          ++linearIdx) {
-      auto extractIdxs = delinearize(prunedInStrides, linearIdx);
+      auto extractIdxs = delinearize(linearIdx, prunedInStrides);
       SmallVector<int64_t> insertIdxs(extractIdxs);
       applyPermutationToVector(insertIdxs, prunedTransp);
       Value extractOp =
@@ -588,8 +588,7 @@ class OuterProductOpLowering : public OpRewritePattern<vector::OuterProductOp> {
         loc, resType, rewriter.getZeroAttr(resType));
     for (int64_t d = 0, e = resType.getDimSize(0); d < e; ++d) {
       auto pos = rewriter.getI64ArrayAttr(d);
-      Value x =
-          rewriter.create<vector::ExtractOp>(loc, op.getLhs(), pos);
+      Value x = rewriter.create<vector::ExtractOp>(loc, op.getLhs(), pos);
       Value a = rewriter.create<vector::BroadcastOp>(loc, rhsType, x);
       Value r = nullptr;
       if (acc)

diff  --git a/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp
index b2d00255b0d7f..2cbafe19641cc 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp
@@ -31,7 +31,7 @@ using namespace mlir::vector;
 static SmallVector<int64_t> getVectorOffset(ArrayRef<int64_t> ratioStrides,
                                             int64_t index,
                                             ArrayRef<int64_t> targetShape) {
-  return computeElementwiseMul(delinearize(ratioStrides, index), targetShape);
+  return computeElementwiseMul(delinearize(index, ratioStrides), targetShape); // Per-dim index scaled by targetShape.
 }
 
 /// A functor that accomplishes the same thing as `getVectorOffset` but


        


More information about the Mlir-commits mailing list