[Mlir-commits] [mlir] 203fad4 - [mlir][DialectUtils] Cleanup IndexingUtils and provide more affine variants while reusing implementations
Nicolas Vasilache
llvmlistbot at llvm.org
Tue Mar 14 03:45:06 PDT 2023
Author: Nicolas Vasilache
Date: 2023-03-14T03:44:59-07:00
New Revision: 203fad476b7e9318e1c81ff39958994190866901
URL: https://github.com/llvm/llvm-project/commit/203fad476b7e9318e1c81ff39958994190866901
DIFF: https://github.com/llvm/llvm-project/commit/203fad476b7e9318e1c81ff39958994190866901.diff
LOG: [mlir][DialectUtils] Cleanup IndexingUtils and provide more affine variants while reusing implementations
Differential Revision: https://reviews.llvm.org/D145784
Added:
Modified:
mlir/include/mlir/Dialect/Utils/IndexingUtils.h
mlir/include/mlir/IR/AffineExpr.h
mlir/lib/Conversion/MathToFuncs/MathToFuncs.cpp
mlir/lib/Conversion/MathToLibm/MathToLibm.cpp
mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp
mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp
mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp
mlir/lib/Dialect/Utils/IndexingUtils.cpp
mlir/lib/Dialect/Vector/IR/VectorOps.cpp
mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Utils/IndexingUtils.h b/mlir/include/mlir/Dialect/Utils/IndexingUtils.h
index 64eea7eded92b..1969bd3a33121 100644
--- a/mlir/include/mlir/Dialect/Utils/IndexingUtils.h
+++ b/mlir/include/mlir/Dialect/Utils/IndexingUtils.h
@@ -23,35 +23,68 @@
namespace mlir {
class ArrayAttr;
-/// Computes and returns the linearized index of 'offsets' w.r.t. 'basis'.
-int64_t linearize(ArrayRef<int64_t> offsets, ArrayRef<int64_t> basis);
-
-/// Given the strides together with a linear index in the dimension
-/// space, returns the vector-space offsets in each dimension for a
-/// de-linearized index.
-SmallVector<int64_t> delinearize(ArrayRef<int64_t> strides,
- int64_t linearIndex);
+//===----------------------------------------------------------------------===//
+// Utils that operate on static integer values.
+//===----------------------------------------------------------------------===//
-/// Given a set of sizes, compute and return the strides (i.e. the number of
-/// linear incides to skip along the (k-1) most minor dimensions to get the next
-/// k-slice). This is also the basis that one can use to linearize an n-D offset
-/// confined to `[0 .. sizes]`.
-SmallVector<int64_t> computeStrides(ArrayRef<int64_t> sizes);
+/// Given a set of sizes, return the suffix product.
+///
+/// When applied to slicing, this is the calculation needed to derive the
+/// strides (i.e. the number of linear indices to skip along the (k-1) most
+/// minor dimensions to get the next k-slice).
+///
+/// This is the basis to linearize an n-D offset confined to `[0 ... sizes]`.
+///
+/// Assuming `sizes` is `[s0, .. sn]`, return the vector<int64_t>
+/// `[s1 * ... * sn, s2 * ... * sn, ..., sn, 1]`.
+///
+/// `sizes` elements are asserted to be non-negative.
+///
+/// Return an empty vector if `sizes` is empty.
+SmallVector<int64_t> computeSuffixProduct(ArrayRef<int64_t> sizes);
+/// Stride-flavored spelling of `computeSuffixProduct`: forwards `sizes`
+/// unchanged and returns that function's result.
+inline SmallVector<int64_t> computeStrides(ArrayRef<int64_t> sizes) {
+ return computeSuffixProduct(sizes);
+}
-/// Return a vector containing llvm::zip of v1 and v2 multiplied elementwise.
+/// Return a vector containing llvm::zip_equal(v1, v2) multiplied elementwise.
+///
+/// Return an empty vector if `v1` and `v2` are empty.
SmallVector<int64_t> computeElementwiseMul(ArrayRef<int64_t> v1,
ArrayRef<int64_t> v2);
-/// Compute and return the multi-dimensional integral ratio of `subShape` to
-/// the trailing dimensions of `shape`. This represents how many times
-/// `subShape` fits within `shape`.
-/// If integral division is not possible, return std::nullopt.
+/// Return the number of elements of basis (i.e. the max linear index),
+/// computed as the product of all the entries of `basis`.
+///
+/// `basis` elements are asserted to be non-negative.
+///
+/// Return `0` if `basis` is empty.
+int64_t computeMaxLinearIndex(ArrayRef<int64_t> basis);
+
+/// Return the linearized index of 'offsets' w.r.t. 'basis'.
+///
+/// `basis` elements are asserted to be non-negative.
+int64_t linearize(ArrayRef<int64_t> offsets, ArrayRef<int64_t> basis);
+
+/// Given the strides together with a linear index in the dimension space,
+/// return the vector-space offsets in each dimension for a de-linearized index.
+/// `strides` elements are asserted to be strictly positive.
+///
+/// Let `li = linearIndex`, assuming `strides` are `[s0, .. sn]`, return the
+/// vector of int64_t
+/// `[li / s0, (li % s0) / s1, ..., (li % s0 % .. % sn-1) / sn]`
+SmallVector<int64_t> delinearize(int64_t linearIndex,
+ ArrayRef<int64_t> strides);
+
+/// Return the multi-dimensional integral ratio of `subShape` to the trailing
+/// dimensions of `shape`. This represents how many times `subShape` fits
+/// within `shape`. If integral division is not possible, return std::nullopt.
/// The trailing `subShape.size()` entries of both shapes are assumed (and
-/// enforced) to only contain noonnegative values.
+/// enforced) to only contain non-negative values.
///
/// Examples:
/// - shapeRatio({3, 5, 8}, {2, 5, 2}) returns {3, 2, 1}.
-/// - shapeRatio({3, 8}, {2, 5, 2}) returns std::nullopt (subshape has higher
+/// - shapeRatio({3, 8}, {2, 5, 2}) returns std::nullopt (subshape has
+/// higher
/// rank).
/// - shapeRatio({42, 2, 10, 32}, {2, 5, 2}) returns {42, 1, 2, 16} which is
/// derived as {42(leading shape dim), 2/2, 10/5, 32/2}.
@@ -60,14 +93,96 @@ SmallVector<int64_t> computeElementwiseMul(ArrayRef<int64_t> v1,
std::optional<SmallVector<int64_t>>
computeShapeRatio(ArrayRef<int64_t> shape, ArrayRef<int64_t> subShape);
+//===----------------------------------------------------------------------===//
+// Utils that operate on AffineExpr.
+//===----------------------------------------------------------------------===//
+
+/// Given a set of sizes, return the suffix product.
+///
+/// When applied to slicing, this is the calculation needed to derive the
+/// strides (i.e. the number of linear indices to skip along the (k-1) most
+/// minor dimensions to get the next k-slice).
+///
+/// This is the basis to linearize an n-D offset confined to `[0 ... sizes]`.
+///
+/// Assuming `sizes` is `[s0, .. sn]`, return the vector<AffineExpr>
+/// `[s1 * ... * sn, s2 * ... * sn, ..., sn, 1]`.
+///
+/// It is the caller's responsibility to pass proper AffineExpr kind that
+/// result in valid AffineExpr (i.e. cannot multiply 2 AffineDimExpr or divide
+/// by an AffineDimExpr).
+///
+/// `sizes` elements are expected to bind to non-negative values.
+///
+/// Return an empty vector if `sizes` is empty.
+SmallVector<AffineExpr> computeSuffixProduct(ArrayRef<AffineExpr> sizes);
+/// Stride-flavored spelling of the AffineExpr `computeSuffixProduct`: forwards
+/// `sizes` unchanged and returns that function's result.
+inline SmallVector<AffineExpr> computeStrides(ArrayRef<AffineExpr> sizes) {
+ return computeSuffixProduct(sizes);
+}
+
+/// Return a vector containing llvm::zip_equal(v1, v2) multiplied elementwise.
+///
+/// It is the caller's responsibility to pass proper AffineExpr kind that
+/// result in valid AffineExpr (i.e. cannot multiply 2 AffineDimExpr or divide
+/// by an AffineDimExpr).
+///
+/// Return an empty vector if `v1` and `v2` are empty.
+SmallVector<AffineExpr> computeElementwiseMul(ArrayRef<AffineExpr> v1,
+ ArrayRef<AffineExpr> v2);
+
/// Return the number of elements of basis (i.e. the max linear index).
/// Return `0` if `basis` is empty.
-int64_t computeMaxLinearIndex(ArrayRef<int64_t> basis);
+///
+/// It is the caller's responsibility to pass proper AffineExpr kind that
+/// result in valid AffineExpr (i.e. cannot multiply 2 AffineDimExpr or divide
+/// by an AffineDimExpr).
+///
+/// `basis` elements are expected to bind to non-negative values.
+///
+/// Return the `0` AffineConstantExpr if `basis` is empty.
+AffineExpr computeMaxLinearIndex(MLIRContext *ctx, ArrayRef<AffineExpr> basis);
+
+/// Return the linearized index of 'offsets' w.r.t. 'basis'.
+///
+/// Assuming `offsets` is `[o0, .. on]` and `basis` is `[b0, .. bn]`, return the
+/// AffineExpr `o0 * b0 + .. + on * bn`.
+///
+/// It is the caller's responsibility to pass proper AffineExpr kind that result
+/// in valid AffineExpr (i.e. cannot multiply 2 AffineDimExpr or divide by an
+/// AffineDimExpr).
+///
+/// `basis` elements are expected to bind to non-negative values.
+AffineExpr linearize(MLIRContext *ctx, ArrayRef<AffineExpr> offsets,
+ ArrayRef<AffineExpr> basis);
+AffineExpr linearize(MLIRContext *ctx, ArrayRef<AffineExpr> offsets,
+ ArrayRef<int64_t> basis);
+
+/// Given the strides together with a linear index in the dimension space,
+/// return the vector-space offsets in each dimension for a de-linearized index.
+///
+/// Let `li = linearIndex`, assuming `strides` are `[s0, .. sn]`, return the
+/// vector of AffineExpr (with `/` denoting `floordiv`)
+/// `[li / s0, (li % s0) / s1, ..., (li % s0 % .. % sn-1) / sn]`
+///
+/// It is the caller's responsibility to pass proper AffineExpr kind that result
+/// in valid AffineExpr (i.e. cannot multiply 2 AffineDimExpr or divide by an
+/// AffineDimExpr).
+///
+/// `strides` elements are expected to bind to non-negative values.
+SmallVector<AffineExpr> delinearize(AffineExpr linearIndex,
+ ArrayRef<AffineExpr> strides);
+SmallVector<AffineExpr> delinearize(AffineExpr linearIndex,
+ ArrayRef<int64_t> strides);
+
+//===----------------------------------------------------------------------===//
+// Permutation utils.
+//===----------------------------------------------------------------------===//
/// Apply the permutation defined by `permutation` to `inVec`.
/// Element `i` in `inVec` is mapped to location `j = permutation[i]`.
-/// E.g.: for an input vector `inVec = ['a', 'b', 'c']` and a permutation vector
-/// `permutation = [2, 0, 1]`, this function leaves `inVec = ['c', 'a', 'b']`.
+/// E.g.: for an input vector `inVec = ['a', 'b', 'c']` and a permutation
+/// vector `permutation = [2, 0, 1]`, this function leaves `inVec = ['c', 'a',
+/// 'b']`.
template <typename T, unsigned N>
void applyPermutationToVector(SmallVector<T, N> &inVec,
ArrayRef<int64_t> permutation) {
@@ -83,18 +198,11 @@ SmallVector<int64_t> invertPermutationVector(ArrayRef<int64_t> permutation);
/// Method to check if an interchange vector is a permutation.
bool isPermutationVector(ArrayRef<int64_t> interchange);
-/// Helper that returns a subset of `arrayAttr` as a vector of int64_t.
+/// Helper to return a subset of `arrayAttr` as a vector of int64_t.
+// TODO: Port everything relevant to DenseArrayAttr and drop this util.
SmallVector<int64_t> getI64SubArray(ArrayAttr arrayAttr, unsigned dropFront = 0,
unsigned dropBack = 0);
-/// Computes and returns linearized affine expression w.r.t. `basis`.
-mlir::AffineExpr getLinearAffineExpr(ArrayRef<int64_t> basis, mlir::Builder &b);
-
-/// Given the strides in the dimension space, returns the affine expressions for
-/// vector-space offsets in each dimension for a de-linearized index.
-SmallVector<mlir::AffineExpr>
-getDelinearizedAffineExpr(ArrayRef<int64_t> strides, mlir::Builder &b);
-
} // namespace mlir
#endif // MLIR_DIALECT_UTILS_INDEXINGUTILS_H
diff --git a/mlir/include/mlir/IR/AffineExpr.h b/mlir/include/mlir/IR/AffineExpr.h
index 1d0688a38003f..74cd5e1e557d5 100644
--- a/mlir/include/mlir/IR/AffineExpr.h
+++ b/mlir/include/mlir/IR/AffineExpr.h
@@ -321,13 +321,6 @@ void bindSymbols(MLIRContext *ctx, AffineExprTy &e, AffineExprTy2 &...exprs) {
bindSymbols<N + 1, AffineExprTy2 &...>(ctx, exprs...);
}
-template <typename AffineExprTy>
-void bindSymbolsList(MLIRContext *ctx, SmallVectorImpl<AffineExprTy> &exprs) {
- int idx = 0;
- for (AffineExprTy &e : exprs)
- e = getAffineSymbolExpr(idx++, ctx);
-}
-
} // namespace detail
/// Bind a list of AffineExpr references to DimExpr at positions:
@@ -337,6 +330,13 @@ void bindDims(MLIRContext *ctx, AffineExprTy &...exprs) {
detail::bindDims<0>(ctx, exprs...);
}
+/// Bind every entry of `exprs` to the DimExpr whose position is the entry's
+/// index in the list, i.e. positions [0 .. exprs.size()).
+template <typename AffineExprTy>
+void bindDimsList(MLIRContext *ctx, MutableArrayRef<AffineExprTy> exprs) {
+  for (int pos = 0, numExprs = exprs.size(); pos < numExprs; ++pos)
+    exprs[pos] = getAffineDimExpr(pos, ctx);
+}
+
/// Bind a list of AffineExpr references to SymbolExpr at positions:
/// [0 .. sizeof...(exprs)]
template <typename... AffineExprTy>
@@ -344,6 +344,13 @@ void bindSymbols(MLIRContext *ctx, AffineExprTy &...exprs) {
detail::bindSymbols<0>(ctx, exprs...);
}
+/// Bind every entry of `exprs` to the SymbolExpr whose position is the entry's
+/// index in the list, i.e. positions [0 .. exprs.size()).
+template <typename AffineExprTy>
+void bindSymbolsList(MLIRContext *ctx, MutableArrayRef<AffineExprTy> exprs) {
+  for (int pos = 0, numExprs = exprs.size(); pos < numExprs; ++pos)
+    exprs[pos] = getAffineSymbolExpr(pos, ctx);
+}
+
} // namespace mlir
namespace llvm {
diff --git a/mlir/lib/Conversion/MathToFuncs/MathToFuncs.cpp b/mlir/lib/Conversion/MathToFuncs/MathToFuncs.cpp
index d24a54c7cbc88..c177e50e2c3c9 100644
--- a/mlir/lib/Conversion/MathToFuncs/MathToFuncs.cpp
+++ b/mlir/lib/Conversion/MathToFuncs/MathToFuncs.cpp
@@ -103,7 +103,7 @@ VecOpToScalarOp<Op>::matchAndRewrite(Op op, PatternRewriter &rewriter) const {
loc, DenseElementsAttr::get(vecType, initValueAttr));
SmallVector<int64_t> strides = computeStrides(shape);
for (int64_t linearIndex = 0; linearIndex < numElements; ++linearIndex) {
- SmallVector<int64_t> positions = delinearize(strides, linearIndex);
+ SmallVector<int64_t> positions = delinearize(linearIndex, strides);
SmallVector<Value> operands;
for (Value input : op->getOperands())
operands.push_back(
diff --git a/mlir/lib/Conversion/MathToLibm/MathToLibm.cpp b/mlir/lib/Conversion/MathToLibm/MathToLibm.cpp
index 1e5b317caa941..d6834441e92bb 100644
--- a/mlir/lib/Conversion/MathToLibm/MathToLibm.cpp
+++ b/mlir/lib/Conversion/MathToLibm/MathToLibm.cpp
@@ -89,7 +89,7 @@ VecOpToScalarOp<Op>::matchAndRewrite(Op op, PatternRewriter &rewriter) const {
vecType, FloatAttr::get(vecType.getElementType(), 0.0)));
SmallVector<int64_t> strides = computeStrides(shape);
for (auto linearIndex = 0; linearIndex < numElements; ++linearIndex) {
- SmallVector<int64_t> positions = delinearize(strides, linearIndex);
+ SmallVector<int64_t> positions = delinearize(linearIndex, strides);
SmallVector<Value> operands;
for (auto input : op->getOperands())
operands.push_back(
diff --git a/mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp b/mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp
index 0d170f985fc30..6cbb50b0425b3 100644
--- a/mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp
+++ b/mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp
@@ -134,7 +134,7 @@ handleMultidimensionalVectors(ImplicitLocOpBuilder &builder,
SmallVector<Value> results(maxIndex);
for (int64_t i = 0; i < maxIndex; ++i) {
- auto offsets = delinearize(strides, i);
+ auto offsets = delinearize(i, strides);
SmallVector<Value> extracted(expandedOperands.size());
for (const auto &tuple : llvm::enumerate(expandedOperands))
@@ -152,7 +152,7 @@ handleMultidimensionalVectors(ImplicitLocOpBuilder &builder,
for (int64_t i = 0; i < maxIndex; ++i)
result = builder.create<vector::InsertOp>(results[i], result,
- delinearize(strides, i));
+ delinearize(i, strides));
// Reshape back to the original vector shape.
return builder.create<vector::ShapeCastOp>(
diff --git a/mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp b/mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp
index 8ef16d5eeaec4..72e6a2925e83f 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp
@@ -75,7 +75,7 @@ struct SubviewFolder : public OpRewritePattern<memref::SubViewOp> {
SmallVector<OpFoldResult> values(2 * sourceRank + 1);
SmallVector<AffineExpr> symbols(2 * sourceRank + 1);
- detail::bindSymbolsList(rewriter.getContext(), symbols);
+ bindSymbolsList(rewriter.getContext(), MutableArrayRef{symbols});
AffineExpr expr = symbols.front();
values[0] = ShapedType::isDynamic(sourceOffset)
? getAsOpFoldResult(newExtractStridedMetadata.getOffset())
@@ -262,10 +262,9 @@ SmallVector<OpFoldResult> getExpandedStrides(memref::ExpandShapeOp expandShape,
auto sourceType = source.getType().cast<MemRefType>();
auto [strides, offset] = getStridesAndOffset(sourceType);
- OpFoldResult origStride =
- ShapedType::isDynamic(strides[groupId])
- ? origStrides[groupId]
- : builder.getIndexAttr(strides[groupId]);
+ OpFoldResult origStride = ShapedType::isDynamic(strides[groupId])
+ ? origStrides[groupId]
+ : builder.getIndexAttr(strides[groupId]);
// Apply the original stride to all the strides.
int64_t doneStrideIdx = 0;
diff --git a/mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp b/mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp
index 129ac41233058..73264bc76933a 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp
@@ -18,6 +18,7 @@
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/ADT/SmallBitVector.h"
@@ -54,24 +55,25 @@ resolveSourceIndicesExpandShape(Location loc, PatternRewriter &rewriter,
memref::ExpandShapeOp expandShapeOp,
ValueRange indices,
SmallVectorImpl<Value> &sourceIndices) {
- for (SmallVector<int64_t, 2> groups :
- expandShapeOp.getReassociationIndices()) {
+ MLIRContext *ctx = rewriter.getContext();
+ for (ArrayRef<int64_t> groups : expandShapeOp.getReassociationIndices()) {
assert(!groups.empty() && "association indices groups cannot be empty");
- unsigned groupSize = groups.size();
- SmallVector<int64_t> suffixProduct(groupSize);
- // Calculate suffix product of dimension sizes for all dimensions of expand
- // shape op result.
- suffixProduct[groupSize - 1] = 1;
- for (unsigned i = groupSize - 1; i > 0; i--)
- suffixProduct[i - 1] =
- suffixProduct[i] *
- expandShapeOp.getType().cast<MemRefType>().getDimSize(groups[i]);
- SmallVector<Value> dynamicIndices(groupSize);
- for (unsigned i = 0; i < groupSize; i++)
- dynamicIndices[i] = indices[groups[i]];
+ int64_t groupSize = groups.size();
+
// Construct the expression for the index value w.r.t to expand shape op
// source corresponding the indices wrt to expand shape op result.
- AffineExpr srcIndexExpr = getLinearAffineExpr(suffixProduct, rewriter);
+ SmallVector<int64_t> sizes(groupSize);
+ for (int64_t i = 0; i < groupSize; ++i)
+ sizes[i] = expandShapeOp.getResultType().getDimSize(groups[i]);
+ SmallVector<int64_t> suffixProduct = computeSuffixProduct(sizes);
+ SmallVector<AffineExpr> dims(groupSize);
+ bindDimsList(ctx, MutableArrayRef{dims});
+ AffineExpr srcIndexExpr = linearize(ctx, dims, suffixProduct);
+
+ /// Apply permutation and create AffineApplyOp.
+ SmallVector<Value> dynamicIndices(groupSize);
+ for (int64_t i = 0; i < groupSize; i++)
+ dynamicIndices[i] = indices[groups[i]];
sourceIndices.push_back(rewriter.create<AffineApplyOp>(
loc,
AffineMap::get(/*numDims=*/groupSize, /*numSymbols=*/0, srcIndexExpr),
@@ -98,35 +100,39 @@ resolveSourceIndicesCollapseShape(Location loc, PatternRewriter &rewriter,
memref::CollapseShapeOp collapseShapeOp,
ValueRange indices,
SmallVectorImpl<Value> &sourceIndices) {
- unsigned cnt = 0;
+ int64_t cnt = 0;
SmallVector<Value> tmp(indices.size());
SmallVector<Value> dynamicIndices;
- for (SmallVector<int64_t, 2> groups :
- collapseShapeOp.getReassociationIndices()) {
+ for (ArrayRef<int64_t> groups : collapseShapeOp.getReassociationIndices()) {
assert(!groups.empty() && "association indices groups cannot be empty");
dynamicIndices.push_back(indices[cnt++]);
- unsigned groupSize = groups.size();
- SmallVector<int64_t> suffixProduct(groupSize);
+ int64_t groupSize = groups.size();
+
// Calculate suffix product for all collapse op source dimension sizes.
- suffixProduct[groupSize - 1] = 1;
- for (unsigned i = groupSize - 1; i > 0; i--)
- suffixProduct[i - 1] =
- suffixProduct[i] * collapseShapeOp.getSrcType().getDimSize(groups[i]);
+ SmallVector<int64_t> sizes(groupSize);
+ for (int64_t i = 0; i < groupSize; ++i)
+ sizes[i] = collapseShapeOp.getSrcType().getDimSize(groups[i]);
+ SmallVector<int64_t> suffixProduct = computeSuffixProduct(sizes);
+
// Derive the index values along all dimensions of the source corresponding
// to the index wrt to collapsed shape op output.
- SmallVector<AffineExpr, 4> srcIndexExpr =
- getDelinearizedAffineExpr(suffixProduct, rewriter);
- for (unsigned i = 0; i < groupSize; i++)
+ auto d0 = rewriter.getAffineDimExpr(0);
+ SmallVector<AffineExpr> delinearizingExprs = delinearize(d0, suffixProduct);
+
+ // Construct the AffineApplyOp for each delinearizingExpr.
+ for (int64_t i = 0; i < groupSize; i++)
sourceIndices.push_back(rewriter.create<AffineApplyOp>(
- loc, AffineMap::get(/*numDims=*/1, /*numSymbols=*/0, srcIndexExpr[i]),
+ loc,
+ AffineMap::get(/*numDims=*/1, /*numSymbols=*/0,
+ delinearizingExprs[i]),
dynamicIndices));
dynamicIndices.clear();
}
if (collapseShapeOp.getReassociationIndices().empty()) {
auto zeroAffineMap = rewriter.getConstantAffineMap(0);
- unsigned srcRank =
+ int64_t srcRank =
collapseShapeOp.getViewSource().getType().cast<MemRefType>().getRank();
- for (unsigned i = 0; i < srcRank; i++)
+ for (int64_t i = 0; i < srcRank; i++)
sourceIndices.push_back(
rewriter.create<AffineApplyOp>(loc, zeroAffineMap, dynamicIndices));
}
@@ -157,9 +163,9 @@ resolveSourceIndicesSubView(Location loc, PatternRewriter &rewriter,
SmallVector<Value> useIndices;
// Check if this is rank-reducing case. Then for every unit-dim size add a
// zero to the indices.
- unsigned resultDim = 0;
+ int64_t resultDim = 0;
llvm::SmallBitVector unusedDims = subViewOp.getDroppedDims();
- for (auto dim : llvm::seq<unsigned>(0, subViewOp.getSourceType().getRank())) {
+ for (auto dim : llvm::seq<int64_t>(0, subViewOp.getSourceType().getRank())) {
if (unusedDims.test(dim))
useIndices.push_back(rewriter.create<arith::ConstantIndexOp>(loc, 0));
else
@@ -171,7 +177,7 @@ resolveSourceIndicesSubView(Location loc, PatternRewriter &rewriter,
for (auto index : llvm::seq<size_t>(0, mixedOffsets.size())) {
SmallVector<Value> dynamicOperands;
AffineExpr expr = rewriter.getAffineDimExpr(0);
- unsigned numSymbols = 0;
+ int64_t numSymbols = 0;
dynamicOperands.push_back(useIndices[index]);
// Multiply the stride;
@@ -378,7 +384,7 @@ LogicalResult LoadOpOfSubViewOpFolder<OpTy>::matchAndRewrite(
affineMap, indices, loadOp.getLoc(), rewriter);
indices.assign(expandedIndices.begin(), expandedIndices.end());
}
- SmallVector<Value, 4> sourceIndices;
+ SmallVector<Value> sourceIndices;
if (failed(resolveSourceIndicesSubView(loadOp.getLoc(), rewriter, subViewOp,
indices, sourceIndices)))
return failure();
@@ -424,7 +430,7 @@ LogicalResult LoadOpOfExpandShapeOpFolder<OpTy>::matchAndRewrite(
affineMap, indices, loadOp.getLoc(), rewriter);
indices.assign(expandedIndices.begin(), expandedIndices.end());
}
- SmallVector<Value, 4> sourceIndices;
+ SmallVector<Value> sourceIndices;
if (failed(resolveSourceIndicesExpandShape(
loadOp.getLoc(), rewriter, expandShapeOp, indices, sourceIndices)))
return failure();
@@ -456,7 +462,7 @@ LogicalResult LoadOpOfCollapseShapeOpFolder<OpTy>::matchAndRewrite(
affineMap, indices, loadOp.getLoc(), rewriter);
indices.assign(expandedIndices.begin(), expandedIndices.end());
}
- SmallVector<Value, 4> sourceIndices;
+ SmallVector<Value> sourceIndices;
if (failed(resolveSourceIndicesCollapseShape(
loadOp.getLoc(), rewriter, collapseShapeOp, indices, sourceIndices)))
return failure();
@@ -488,7 +494,7 @@ LogicalResult StoreOpOfSubViewOpFolder<OpTy>::matchAndRewrite(
affineMap, indices, storeOp.getLoc(), rewriter);
indices.assign(expandedIndices.begin(), expandedIndices.end());
}
- SmallVector<Value, 4> sourceIndices;
+ SmallVector<Value> sourceIndices;
if (failed(resolveSourceIndicesSubView(storeOp.getLoc(), rewriter, subViewOp,
indices, sourceIndices)))
return failure();
@@ -533,7 +539,7 @@ LogicalResult StoreOpOfExpandShapeOpFolder<OpTy>::matchAndRewrite(
affineMap, indices, storeOp.getLoc(), rewriter);
indices.assign(expandedIndices.begin(), expandedIndices.end());
}
- SmallVector<Value, 4> sourceIndices;
+ SmallVector<Value> sourceIndices;
if (failed(resolveSourceIndicesExpandShape(
storeOp.getLoc(), rewriter, expandShapeOp, indices, sourceIndices)))
return failure();
@@ -566,7 +572,7 @@ LogicalResult StoreOpOfCollapseShapeOpFolder<OpTy>::matchAndRewrite(
affineMap, indices, storeOp.getLoc(), rewriter);
indices.assign(expandedIndices.begin(), expandedIndices.end());
}
- SmallVector<Value, 4> sourceIndices;
+ SmallVector<Value> sourceIndices;
if (failed(resolveSourceIndicesCollapseShape(
storeOp.getLoc(), rewriter, collapseShapeOp, indices, sourceIndices)))
return failure();
diff --git a/mlir/lib/Dialect/Utils/IndexingUtils.cpp b/mlir/lib/Dialect/Utils/IndexingUtils.cpp
index 4563d74acfde4..eb86e0f782a78 100644
--- a/mlir/lib/Dialect/Utils/IndexingUtils.cpp
+++ b/mlir/lib/Dialect/Utils/IndexingUtils.cpp
@@ -11,27 +11,100 @@
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinAttributes.h"
+#include "mlir/IR/MLIRContext.h"
+#include "llvm/ADT/STLExtras.h"
#include <numeric>
#include <optional>
using namespace mlir;
-SmallVector<int64_t> mlir::computeStrides(ArrayRef<int64_t> sizes) {
- SmallVector<int64_t> strides(sizes.size(), 1);
+/// Shared suffix-product implementation for any `ExprType` that supports
+/// `operator*`. Every entry of the result is seeded with `unit`, so the last
+/// entry is `unit` and entry `r` is entry `r + 1` times `sizes[r + 1]`.
+/// Returns an empty vector when `sizes` is empty.
+template <typename ExprType>
+SmallVector<ExprType> computeSuffixProductImpl(ArrayRef<ExprType> sizes,
+ ExprType unit) {
+ if (sizes.empty())
+ return {};
+ SmallVector<ExprType> strides(sizes.size(), unit);
+ // Fill from the second-to-last entry backwards; the signed index lets the
+ // loop terminate (and not run at all for a single-element input).
for (int64_t r = strides.size() - 2; r >= 0; --r)
strides[r] = strides[r + 1] * sizes[r + 1];
return strides;
}
-SmallVector<int64_t> mlir::computeElementwiseMul(ArrayRef<int64_t> v1,
- ArrayRef<int64_t> v2) {
- SmallVector<int64_t> result;
- for (auto it : llvm::zip(v1, v2))
+/// Shared elementwise-product implementation for any `ExprType` that supports
+/// `operator*`: entry `i` of the result is `v1[i] * v2[i]`.
+/// Returns an empty vector when both inputs are empty.
+template <typename ExprType>
+SmallVector<ExprType> computeElementwiseMulImpl(ArrayRef<ExprType> v1,
+ ArrayRef<ExprType> v2) {
+ // Early exit if both are empty, let zip_equal fail if only 1 is empty.
+ if (v1.empty() && v2.empty())
+ return {};
+ SmallVector<ExprType> result;
+ for (auto it : llvm::zip_equal(v1, v2))
result.push_back(std::get<0>(it) * std::get<1>(it));
return result;
}
+/// Shared linearization implementation: starting from `zero`, accumulate
+/// `offsets[i] * basis[i]` for every position `i` and return the sum.
+/// `offsets` and `basis` are asserted to have the same number of entries.
+template <typename ExprType>
+ExprType linearizeImpl(ArrayRef<ExprType> offsets, ArrayRef<ExprType> basis,
+                       ExprType zero) {
+  assert(offsets.size() == basis.size());
+  ExprType sum = zero;
+  const unsigned numTerms = basis.size();
+  unsigned term = 0;
+  while (term < numTerms) {
+    sum = sum + offsets[term] * basis[term];
+    ++term;
+  }
+  return sum;
+}
+
+/// Shared delinearization implementation. For each stride, in order, the
+/// offset is `divOp(linearIndex, stride)` and `linearIndex` is then reduced to
+/// `linearIndex % stride` before moving on to the next stride.
+template <typename ExprType, typename DivOpTy>
+SmallVector<ExprType> delinearizeImpl(ExprType linearIndex,
+                                      ArrayRef<ExprType> strides,
+                                      DivOpTy divOp) {
+  SmallVector<ExprType> offsets;
+  offsets.reserve(strides.size());
+  for (const ExprType &stride : strides) {
+    // Quotient first, then carry the remainder to the next dimension.
+    offsets.push_back(divOp(linearIndex, stride));
+    linearIndex = linearIndex % stride;
+  }
+  return offsets;
+}
+
+//===----------------------------------------------------------------------===//
+// Utils that operate on static integer values.
+//===----------------------------------------------------------------------===//
+
+/// Return the suffix product of `sizes` as static integers, using `1` as the
+/// unit (so the last entry of a non-empty result is `1`).
+/// Every entry of `sizes` is asserted to be non-negative.
+SmallVector<int64_t> mlir::computeSuffixProduct(ArrayRef<int64_t> sizes) {
+  // Zero is a non-negative size and must be accepted; only negative entries
+  // violate the contract stated by the assertion message.
+  assert(llvm::all_of(sizes, [](int64_t s) { return s >= 0; }) &&
+         "sizes must be nonnegative");
+  int64_t unit = 1;
+  return ::computeSuffixProductImpl(sizes, unit);
+}
+
+/// Static integer variant: forwards `v1` and `v2` to the shared
+/// elementwise-product helper and returns its result.
+SmallVector<int64_t> mlir::computeElementwiseMul(ArrayRef<int64_t> v1,
+ ArrayRef<int64_t> v2) {
+ return computeElementwiseMulImpl(v1, v2);
+}
+
+/// Return the product of all entries of `basis`, or `0` when `basis` is
+/// empty. Every entry of `basis` is asserted to be non-negative.
+int64_t mlir::computeMaxLinearIndex(ArrayRef<int64_t> basis) {
+  // Zero is a non-negative entry and must be accepted; only negative entries
+  // violate the contract stated by the assertion message.
+  assert(llvm::all_of(basis, [](int64_t s) { return s >= 0; }) &&
+         "basis must be nonnegative");
+  if (basis.empty())
+    return 0;
+  // std::accumulate uses the type of the init value as its accumulator type.
+  // Seed it with an int64_t so partial products are not narrowed to `int`.
+  return std::accumulate(basis.begin(), basis.end(), int64_t{1},
+                         std::multiplies<int64_t>());
+}
+
+/// Return the linearized index `offsets[0] * basis[0] + ... ` of `offsets`
+/// w.r.t. `basis`, starting the sum from `0`.
+/// Every entry of `basis` is asserted to be non-negative.
+int64_t mlir::linearize(ArrayRef<int64_t> offsets, ArrayRef<int64_t> basis) {
+  // Zero is a non-negative basis entry and must be accepted; only negative
+  // entries violate the contract stated by the assertion message.
+  assert(llvm::all_of(basis, [](int64_t s) { return s >= 0; }) &&
+         "basis must be nonnegative");
+  int64_t zero = 0;
+  return linearizeImpl(offsets, basis, zero);
+}
+
+/// Return the per-dimension offsets obtained by de-linearizing `linearIndex`
+/// w.r.t. `strides`, using integer division as the division operator.
+/// Every entry of `strides` is asserted to be strictly positive.
+SmallVector<int64_t> mlir::delinearize(int64_t linearIndex,
+                                       ArrayRef<int64_t> strides) {
+  // Each stride is used as a divisor, so zero must be rejected as well as
+  // negative values; the message now says what the predicate checks.
+  assert(llvm::all_of(strides, [](int64_t s) { return s > 0; }) &&
+         "strides must be positive");
+  return delinearizeImpl(linearIndex, strides,
+                         [](int64_t e1, int64_t e2) { return e1 / e2; });
+}
+
std::optional<SmallVector<int64_t>>
mlir::computeShapeRatio(ArrayRef<int64_t> shape, ArrayRef<int64_t> subShape) {
if (shape.size() < subShape.size())
@@ -60,35 +133,67 @@ mlir::computeShapeRatio(ArrayRef<int64_t> shape, ArrayRef<int64_t> subShape) {
return SmallVector<int64_t>{result.rbegin(), result.rend()};
}
-int64_t mlir::linearize(ArrayRef<int64_t> offsets, ArrayRef<int64_t> basis) {
- assert(offsets.size() == basis.size());
- int64_t linearIndex = 0;
- for (unsigned idx = 0, e = basis.size(); idx < e; ++idx)
- linearIndex += offsets[idx] * basis[idx];
- return linearIndex;
+//===----------------------------------------------------------------------===//
+// Utils that operate on AffineExpr.
+//===----------------------------------------------------------------------===//
+
+/// AffineExpr variant of the suffix product. The unit is the constant `1`
+/// created in the context of the first entry of `sizes`, which is why the
+/// empty case returns before `sizes.front()` is read.
+SmallVector<AffineExpr> mlir::computeSuffixProduct(ArrayRef<AffineExpr> sizes) {
+ if (sizes.empty())
+ return {};
+ AffineExpr unit = getAffineConstantExpr(1, sizes.front().getContext());
+ return ::computeSuffixProductImpl(sizes, unit);
}
-llvm::SmallVector<int64_t> mlir::delinearize(ArrayRef<int64_t> sliceStrides,
- int64_t index) {
- int64_t rank = sliceStrides.size();
- SmallVector<int64_t> vectorOffsets(rank);
- for (int64_t r = 0; r < rank; ++r) {
- assert(sliceStrides[r] > 0);
- vectorOffsets[r] = index / sliceStrides[r];
- index %= sliceStrides[r];
- }
- return vectorOffsets;
+/// AffineExpr variant: forwards `v1` and `v2` to the shared
+/// elementwise-product helper and returns its result.
+SmallVector<AffineExpr> mlir::computeElementwiseMul(ArrayRef<AffineExpr> v1,
+ ArrayRef<AffineExpr> v2) {
+ return computeElementwiseMulImpl(v1, v2);
}
-int64_t mlir::computeMaxLinearIndex(ArrayRef<int64_t> basis) {
+/// AffineExpr variant: return the constant `0` built in `ctx` when `basis` is
+/// empty, otherwise the product of all entries of `basis`, accumulated
+/// starting from the constant `1` built in `ctx`.
+AffineExpr mlir::computeMaxLinearIndex(MLIRContext *ctx,
+ ArrayRef<AffineExpr> basis) {
if (basis.empty())
- return 0;
- return std::accumulate(basis.begin(), basis.end(), 1,
- std::multiplies<int64_t>());
+ return getAffineConstantExpr(0, ctx);
+ return std::accumulate(basis.begin(), basis.end(),
+ getAffineConstantExpr(1, ctx),
+ std::multiplies<AffineExpr>());
+}
+
+/// AffineExpr variant: sum `offsets[i] * basis[i]` via the shared
+/// linearization helper, starting from the constant `0` built in `ctx`.
+AffineExpr mlir::linearize(MLIRContext *ctx, ArrayRef<AffineExpr> offsets,
+ ArrayRef<AffineExpr> basis) {
+ AffineExpr zero = getAffineConstantExpr(0, ctx);
+ return linearizeImpl(offsets, basis, zero);
}
-llvm::SmallVector<int64_t>
+/// Variant taking a static integer `basis`: each basis value is wrapped into
+/// a constant AffineExpr built in `ctx`, then the AffineExpr overload of
+/// `linearize` is called with those expressions.
+AffineExpr mlir::linearize(MLIRContext *ctx, ArrayRef<AffineExpr> offsets,
+                           ArrayRef<int64_t> basis) {
+  SmallVector<AffineExpr> basisExprs;
+  basisExprs.reserve(basis.size());
+  for (int64_t value : basis)
+    basisExprs.push_back(getAffineConstantExpr(value, ctx));
+  return linearize(ctx, offsets, basisExprs);
+}
+
+/// AffineExpr variant: de-linearize `linearIndex` w.r.t. `strides` through
+/// the shared helper, using `floorDiv` as the division operator.
+SmallVector<AffineExpr> mlir::delinearize(AffineExpr linearIndex,
+                                          ArrayRef<AffineExpr> strides) {
+  auto floorDivOp = [](AffineExpr lhs, AffineExpr rhs) {
+    return lhs.floorDiv(rhs);
+  };
+  return delinearizeImpl(linearIndex, strides, floorDivOp);
+}
+
+/// Variant taking static integer `strides`: each stride is wrapped into a
+/// constant AffineExpr built in the context of `linearIndex`, then the
+/// AffineExpr overload of `delinearize` is called with those expressions.
+SmallVector<AffineExpr> mlir::delinearize(AffineExpr linearIndex,
+                                          ArrayRef<int64_t> strides) {
+  MLIRContext *ctx = linearIndex.getContext();
+  SmallVector<AffineExpr> strideExprs;
+  strideExprs.reserve(strides.size());
+  for (int64_t value : strides)
+    strideExprs.push_back(getAffineConstantExpr(value, ctx));
+  return delinearize(linearIndex, ArrayRef<AffineExpr>{strideExprs});
+}
+
+//===----------------------------------------------------------------------===//
+// Permutation utils.
+//===----------------------------------------------------------------------===//
+
+SmallVector<int64_t>
mlir::invertPermutationVector(ArrayRef<int64_t> permutation) {
+ assert(llvm::all_of(permutation, [](int64_t s) { return s >= 0; }) &&
+ "permutation must be non-negative");
SmallVector<int64_t> inversion(permutation.size());
for (const auto &pos : llvm::enumerate(permutation)) {
inversion[pos.value()] = pos.index();
@@ -97,6 +202,8 @@ mlir::invertPermutationVector(ArrayRef<int64_t> permutation) {
}
bool mlir::isPermutationVector(ArrayRef<int64_t> interchange) {
+ assert(llvm::all_of(interchange, [](int64_t s) { return s >= 0; }) &&
+ "permutation must be non-negative");
llvm::SmallDenseSet<int64_t, 4> seenVals;
for (auto val : interchange) {
if (seenVals.count(val))
@@ -106,9 +213,9 @@ bool mlir::isPermutationVector(ArrayRef<int64_t> interchange) {
return seenVals.size() == interchange.size();
}
-llvm::SmallVector<int64_t> mlir::getI64SubArray(ArrayAttr arrayAttr,
- unsigned dropFront,
- unsigned dropBack) {
+SmallVector<int64_t> mlir::getI64SubArray(ArrayAttr arrayAttr,
+ unsigned dropFront,
+ unsigned dropBack) {
assert(arrayAttr.size() > dropFront + dropBack && "Out of bounds");
auto range = arrayAttr.getAsRange<IntegerAttr>();
SmallVector<int64_t> res;
@@ -118,26 +225,3 @@ llvm::SmallVector<int64_t> mlir::getI64SubArray(ArrayAttr arrayAttr,
res.push_back((*it).getValue().getSExtValue());
return res;
}
-
-mlir::AffineExpr mlir::getLinearAffineExpr(ArrayRef<int64_t> basis,
- mlir::Builder &b) {
- AffineExpr resultExpr = b.getAffineDimExpr(0);
- resultExpr = resultExpr * basis[0];
- for (unsigned i = 1; i < basis.size(); i++)
- resultExpr = resultExpr + b.getAffineDimExpr(i) * basis[i];
- return resultExpr;
-}
-
-llvm::SmallVector<mlir::AffineExpr>
-mlir::getDelinearizedAffineExpr(mlir::ArrayRef<int64_t> strides, Builder &b) {
- AffineExpr resultExpr = b.getAffineDimExpr(0);
- int64_t rank = strides.size();
- SmallVector<AffineExpr> vectorOffsets(rank);
- vectorOffsets[0] = resultExpr.floorDiv(strides[0]);
- resultExpr = resultExpr % strides[0];
- for (unsigned i = 1; i < rank; i++) {
- vectorOffsets[i] = resultExpr.floorDiv(strides[i]);
- resultExpr = resultExpr % strides[i];
- }
- return vectorOffsets;
-}
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index cabe7d7d95563..0180c7192fb2c 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -1558,7 +1558,7 @@ static Value foldExtractFromShapeCast(ExtractOp extractOp) {
getDimReverse(shapeCastOp.getSourceVectorType(), i + destinationRank);
}
std::reverse(newStrides.begin(), newStrides.end());
- SmallVector<int64_t, 4> newPosition = delinearize(newStrides, position);
+ SmallVector<int64_t, 4> newPosition = delinearize(position, newStrides);
// OpBuilder is only used as a helper to build an I64ArrayAttr.
OpBuilder b(extractOp.getContext());
extractOp->setAttr(ExtractOp::getPositionAttrStrName(),
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
index ef6804f5511ae..9e5d787856b1f 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
@@ -457,7 +457,7 @@ class TransposeOpLowering : public OpRewritePattern<vector::TransposeOp> {
for (int64_t linearIdx = 0; linearIdx < numTransposedElements;
++linearIdx) {
- auto extractIdxs = delinearize(prunedInStrides, linearIdx);
+ auto extractIdxs = delinearize(linearIdx, prunedInStrides);
SmallVector<int64_t> insertIdxs(extractIdxs);
applyPermutationToVector(insertIdxs, prunedTransp);
Value extractOp =
@@ -588,8 +588,7 @@ class OuterProductOpLowering : public OpRewritePattern<vector::OuterProductOp> {
loc, resType, rewriter.getZeroAttr(resType));
for (int64_t d = 0, e = resType.getDimSize(0); d < e; ++d) {
auto pos = rewriter.getI64ArrayAttr(d);
- Value x =
- rewriter.create<vector::ExtractOp>(loc, op.getLhs(), pos);
+ Value x = rewriter.create<vector::ExtractOp>(loc, op.getLhs(), pos);
Value a = rewriter.create<vector::BroadcastOp>(loc, rhsType, x);
Value r = nullptr;
if (acc)
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp
index b2d00255b0d7f..2cbafe19641cc 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp
@@ -31,7 +31,7 @@ using namespace mlir::vector;
static SmallVector<int64_t> getVectorOffset(ArrayRef<int64_t> ratioStrides,
int64_t index,
ArrayRef<int64_t> targetShape) {
- return computeElementwiseMul(delinearize(ratioStrides, index), targetShape);
+ return computeElementwiseMul(delinearize(index, ratioStrides), targetShape);
}
/// A functor that accomplishes the same thing as `getVectorOffset` but
More information about the Mlir-commits mailing list