[Mlir-commits] [mlir] 7a69a9d - [NFC][mlir] VectorUtils / IndexingUtils simplifications and cleanups
Nicolas Vasilache
llvmlistbot at llvm.org
Tue Nov 22 23:42:34 PST 2022
Author: Nicolas Vasilache
Date: 2022-11-22T23:42:29-08:00
New Revision: 7a69a9d7aee2c5ccbf160b08356a0fc6acbebf75
URL: https://github.com/llvm/llvm-project/commit/7a69a9d7aee2c5ccbf160b08356a0fc6acbebf75
DIFF: https://github.com/llvm/llvm-project/commit/7a69a9d7aee2c5ccbf160b08356a0fc6acbebf75.diff
LOG: [NFC][mlir] VectorUtils / IndexingUtils simplifications and cleanups
This revision refactors and cleans up vector-, shape-, and indexing-related infrastructure into more reusable APIs.
Differential Revision: https://reviews.llvm.org/D138501
Added:
Modified:
mlir/include/mlir/Dialect/Utils/IndexingUtils.h
mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h
mlir/lib/Conversion/MathToFuncs/MathToFuncs.cpp
mlir/lib/Conversion/MathToLibm/MathToLibm.cpp
mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp
mlir/lib/Dialect/Utils/IndexingUtils.cpp
mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp
mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp
mlir/test/lib/Dialect/Affine/TestVectorizationUtils.cpp
mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
mlir/test/lib/Interfaces/TilingInterface/TestTilingInterface.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Utils/IndexingUtils.h b/mlir/include/mlir/Dialect/Utils/IndexingUtils.h
index c1857462d4c67..ee1d4550e953c 100644
--- a/mlir/include/mlir/Dialect/Utils/IndexingUtils.h
+++ b/mlir/include/mlir/Dialect/Utils/IndexingUtils.h
@@ -28,8 +28,39 @@ int64_t linearize(ArrayRef<int64_t> offsets, ArrayRef<int64_t> basis);
/// Given the strides together with a linear index in the dimension
/// space, returns the vector-space offsets in each dimension for a
/// de-linearized index.
-SmallVector<int64_t, 4> delinearize(ArrayRef<int64_t> strides,
- int64_t linearIndex);
+SmallVector<int64_t> delinearize(ArrayRef<int64_t> strides,
+ int64_t linearIndex);
+
+/// Given a set of sizes, compute and return the strides (i.e. the number of
+/// linear indices to skip along the (k-1) most minor dimensions to get the
+/// next k-slice). This is also the basis that one can use to linearize an n-D
+/// offset confined to `[0 .. sizes]`.
+SmallVector<int64_t> computeStrides(ArrayRef<int64_t> sizes);
+
+/// Return a vector containing the elementwise product of `v1` and `v2`.
+SmallVector<int64_t> computeElementwiseMul(ArrayRef<int64_t> v1,
+ ArrayRef<int64_t> v2);
+
+/// Compute and return the multi-dimensional integral ratio of `subShape` to
+/// the trailing dimensions of `shape`. This represents how many times
+/// `subShape` fits within `shape`.
+/// If integral division is not possible, return None.
+/// The trailing `subShape.size()` entries of both shapes are assumed (and
+/// enforced) to only contain positive values.
+///
+/// Examples:
+/// - computeShapeRatio({3, 4, 5, 8}, {2, 5, 2}) returns {3, 2, 1, 4}.
+/// - computeShapeRatio({3, 8}, {2, 5, 2}) returns None (subShape has higher rank).
+/// - computeShapeRatio({42, 2, 10, 32}, {2, 5, 2}) returns {42, 1, 2, 16} which is
+/// derived as {42 (leading shape dim), 2/2, 10/5, 32/2}.
+/// - computeShapeRatio({42, 2, 11, 32}, {2, 5, 2}) returns None which is
+/// derived as {42 (leading shape dim), 2/2, 11/5 (not divisible), 32/2}.
+Optional<SmallVector<int64_t>> computeShapeRatio(ArrayRef<int64_t> shape,
+ ArrayRef<int64_t> subShape);
+
+/// Return the number of elements addressed by `basis` (i.e. the product of
+/// its entries, an exclusive bound on the linear index).
+/// Return `0` if `basis` is empty.
+int64_t computeMaxLinearIndex(ArrayRef<int64_t> basis);
/// Apply the permutation defined by `permutation` to `inVec`.
/// Element `i` in `inVec` is mapped to location `j = permutation[i]`.
@@ -45,16 +76,15 @@ void applyPermutationToVector(SmallVector<T, N> &inVec,
}
/// Helper that returns a subset of `arrayAttr` as a vector of int64_t.
-SmallVector<int64_t, 4> getI64SubArray(ArrayAttr arrayAttr,
- unsigned dropFront = 0,
- unsigned dropBack = 0);
+SmallVector<int64_t> getI64SubArray(ArrayAttr arrayAttr, unsigned dropFront = 0,
+ unsigned dropBack = 0);
/// Computes and returns linearized affine expression w.r.t. `basis`.
mlir::AffineExpr getLinearAffineExpr(ArrayRef<int64_t> basis, mlir::Builder &b);
/// Given the strides in the dimension space, returns the affine expressions for
/// vector-space offsets in each dimension for a de-linearized index.
-SmallVector<mlir::AffineExpr, 4>
+SmallVector<mlir::AffineExpr>
getDelinearizedAffineExpr(ArrayRef<int64_t> strides, mlir::Builder &b);
} // namespace mlir
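For orientation, here is a minimal sketch (not part of the commit) of how the consolidated IndexingUtils helpers compose; it relies only on the declarations above:

  #include "mlir/Dialect/Utils/IndexingUtils.h"
  #include <cassert>

  void indexingUtilsSketch() {
    using namespace mlir;
    // Row-major strides of sizes {2, 3, 4} are {12, 4, 1}.
    SmallVector<int64_t> strides = computeStrides({2, 3, 4});
    // Linear index 17 de-linearizes to offsets {1, 1, 1}...
    SmallVector<int64_t> offsets = delinearize(strides, 17);
    // ...and linearize inverts it w.r.t. the same basis.
    assert(linearize(offsets, strides) == 17);
  }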
diff --git a/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h b/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
index 15756617792ac..0ab03085f5ae1 100644
--- a/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
+++ b/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
@@ -111,7 +111,7 @@ struct UnrollVectorOptions {
}
using NativeShapeFnType =
- std::function<Optional<SmallVector<int64_t, 4>>(Operation *op)>;
+ std::function<Optional<SmallVector<int64_t>>(Operation *op)>;
/// Function that returns the shape of the vector to unroll to for a given
/// operation. The unrolling is aborted if the function returns `llvm::None`.
NativeShapeFnType nativeShape = nullptr;
@@ -122,8 +122,8 @@ struct UnrollVectorOptions {
/// Set the native shape to use for unrolling.
UnrollVectorOptions &setNativeShape(ArrayRef<int64_t> shape) {
- SmallVector<int64_t, 4> tsShape(shape.begin(), shape.end());
- nativeShape = [=](Operation *) -> Optional<SmallVector<int64_t, 4>> {
+ SmallVector<int64_t> tsShape(shape.begin(), shape.end());
+ nativeShape = [=](Operation *) -> Optional<SmallVector<int64_t>> {
return tsShape;
};
return *this;
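As a usage sketch (a hypothetical configuration; only members shown in this header are used), the options can carry either a fixed native shape or a per-op callback:

  vector::UnrollVectorOptions options;
  // Fixed native shape for every matched op.
  options.setNativeShape({2, 2});
  // Or install a per-op callback on the public member, mirroring the
  // test-pass usage further below.
  options.nativeShape = [](Operation *op) -> Optional<SmallVector<int64_t>> {
    if (!isa<vector::ContractionOp>(op))
      return llvm::None;
    return SmallVector<int64_t>(3, 2);
  };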
diff --git a/mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h b/mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h
index b5e6bc1ae5747..3c9f1f41419ad 100644
--- a/mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h
+++ b/mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h
@@ -36,43 +36,6 @@ namespace vector {
Value createOrFoldDimOp(OpBuilder &b, Location loc, Value source, int64_t dim);
} // namespace vector
-/// Return the number of elements of basis, `0` if empty.
-int64_t computeMaxLinearIndex(ArrayRef<int64_t> basis);
-
-/// Given the shape and sizes of a vector, returns the corresponding
-/// strides for each dimension.
-/// TODO: needs better doc of how it is used.
-SmallVector<int64_t, 4> computeStrides(ArrayRef<int64_t> shape,
- ArrayRef<int64_t> sizes);
-
-/// Given the target sizes of a vector, together with vector-space offsets,
-/// returns the element-space offsets for each dimension.
-SmallVector<int64_t, 4>
-computeElementOffsetsFromVectorSliceOffsets(ArrayRef<int64_t> sizes,
- ArrayRef<int64_t> vectorOffsets);
-
-/// Computes and returns the multi-dimensional ratio of `superShape` to
-/// `subShape`. This is calculated by performing a traversal from minor to major
-/// dimensions (i.e. in reverse shape order). If integral division is not
-/// possible, returns None.
-/// The ArrayRefs are assumed (and enforced) to only contain > 1 values.
-/// This constraint comes from the fact that they are meant to be used with
-/// VectorTypes, for which the property holds by construction.
-///
-/// Examples:
-/// - shapeRatio({3, 4, 5, 8}, {2, 5, 2}) returns {3, 2, 1, 4}
-/// - shapeRatio({3, 4, 4, 8}, {2, 5, 2}) returns None
-/// - shapeRatio({1, 2, 10, 32}, {2, 5, 2}) returns {1, 1, 2, 16}
-Optional<SmallVector<int64_t, 4>> shapeRatio(ArrayRef<int64_t> superShape,
- ArrayRef<int64_t> subShape);
-
-/// Computes and returns the multi-dimensional ratio of the shapes of
-/// `superVector` to `subVector`. If integral division is not possible, returns
-/// None.
-/// Assumes and enforces that the VectorTypes have the same elemental type.
-Optional<SmallVector<int64_t, 4>> shapeRatio(VectorType superVectorType,
- VectorType subVectorType);
-
/// Constructs a permutation map of invariant memref indices to vector
/// dimension.
///
diff --git a/mlir/lib/Conversion/MathToFuncs/MathToFuncs.cpp b/mlir/lib/Conversion/MathToFuncs/MathToFuncs.cpp
index a0eee716df4a9..d88e7795f0de4 100644
--- a/mlir/lib/Conversion/MathToFuncs/MathToFuncs.cpp
+++ b/mlir/lib/Conversion/MathToFuncs/MathToFuncs.cpp
@@ -80,8 +80,7 @@ VecOpToScalarOp<Op>::matchAndRewrite(Op op, PatternRewriter &rewriter) const {
Value result = rewriter.create<arith::ConstantOp>(
loc, DenseElementsAttr::get(
vecType, IntegerAttr::get(vecType.getElementType(), 0)));
- SmallVector<int64_t> ones(shape.size(), 1);
- SmallVector<int64_t> strides = computeStrides(shape, ones);
+ SmallVector<int64_t> strides = computeStrides(shape);
for (int64_t linearIndex = 0; linearIndex < numElements; ++linearIndex) {
SmallVector<int64_t> positions = delinearize(strides, linearIndex);
SmallVector<Value> operands;
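The removed two-argument form computed ceilDiv(shape[r], sizes[r]) per dimension, so with all-ones sizes it reduced to the shape itself; the single-argument computeStrides(shape) is therefore an exact drop-in here. A worked sketch of the traversal, with hypothetical numbers:

  // shape = {2, 3} -> strides = {3, 1}; linear indices 0..5 visit
  // positions (0,0), (0,1), (0,2), (1,0), (1,1), (1,2) in row-major order.
  SmallVector<int64_t> strides = computeStrides({2, 3});
  for (int64_t linearIndex = 0; linearIndex < 6; ++linearIndex) {
    SmallVector<int64_t> positions = delinearize(strides, linearIndex);
    // positions[0] == linearIndex / 3, positions[1] == linearIndex % 3.
  }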
diff --git a/mlir/lib/Conversion/MathToLibm/MathToLibm.cpp b/mlir/lib/Conversion/MathToLibm/MathToLibm.cpp
index a412dbd0c6e03..d40666d6608c5 100644
--- a/mlir/lib/Conversion/MathToLibm/MathToLibm.cpp
+++ b/mlir/lib/Conversion/MathToLibm/MathToLibm.cpp
@@ -79,8 +79,7 @@ VecOpToScalarOp<Op>::matchAndRewrite(Op op, PatternRewriter &rewriter) const {
Value result = rewriter.create<arith::ConstantOp>(
loc, DenseElementsAttr::get(
vecType, FloatAttr::get(vecType.getElementType(), 0.0)));
- SmallVector<int64_t> ones(shape.size(), 1);
- SmallVector<int64_t> strides = computeStrides(shape, ones);
+ SmallVector<int64_t> strides = computeStrides(shape);
for (auto linearIndex = 0; linearIndex < numElements; ++linearIndex) {
SmallVector<int64_t> positions = delinearize(strides, linearIndex);
SmallVector<Value> operands;
diff --git a/mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp b/mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp
index 6413d50948962..1f004b1723a7a 100644
--- a/mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp
+++ b/mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp
@@ -127,15 +127,13 @@ handleMultidimensionalVectors(ImplicitLocOpBuilder &builder,
// Iterate over all outer dimensions of the compute shape vector type.
auto iterationDims = ArrayRef<int64_t>(expandedShape).drop_back();
- int64_t maxLinearIndex = computeMaxLinearIndex(iterationDims);
-
- SmallVector<int64_t> ones(iterationDims.size(), 1);
- auto strides = computeStrides(iterationDims, ones);
+ int64_t maxIndex = computeMaxLinearIndex(iterationDims);
+ auto strides = computeStrides(iterationDims);
// Compute results for each one dimensional vector.
- SmallVector<Value> results(maxLinearIndex);
+ SmallVector<Value> results(maxIndex);
- for (int64_t i = 0; i < maxLinearIndex; ++i) {
+ for (int64_t i = 0; i < maxIndex; ++i) {
auto offsets = delinearize(strides, i);
SmallVector<Value> extracted(expandedOperands.size());
@@ -152,7 +150,7 @@ handleMultidimensionalVectors(ImplicitLocOpBuilder &builder,
Value result = builder.create<arith::ConstantOp>(
resultExpandedType, builder.getZeroAttr(resultExpandedType));
- for (int64_t i = 0; i < maxLinearIndex; ++i)
+ for (int64_t i = 0; i < maxIndex; ++i)
result = builder.create<vector::InsertOp>(results[i], result,
delinearize(strides, i));
diff --git a/mlir/lib/Dialect/Utils/IndexingUtils.cpp b/mlir/lib/Dialect/Utils/IndexingUtils.cpp
index b9901b7af0ff1..1c7a89cd755f3 100644
--- a/mlir/lib/Dialect/Utils/IndexingUtils.cpp
+++ b/mlir/lib/Dialect/Utils/IndexingUtils.cpp
@@ -12,6 +12,53 @@
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinAttributes.h"
+#include <numeric>
+
+using namespace mlir;
+
+SmallVector<int64_t> mlir::computeStrides(ArrayRef<int64_t> sizes) {
+ SmallVector<int64_t> strides(sizes.size(), 1);
+ for (int64_t r = strides.size() - 2; r >= 0; --r)
+ strides[r] = strides[r + 1] * sizes[r + 1];
+ return strides;
+}
+
+SmallVector<int64_t> mlir::computeElementwiseMul(ArrayRef<int64_t> v1,
+ ArrayRef<int64_t> v2) {
+ SmallVector<int64_t> result;
+ for (auto it : llvm::zip(v1, v2))
+ result.push_back(std::get<0>(it) * std::get<1>(it));
+ return result;
+}
+
+Optional<SmallVector<int64_t>>
+mlir::computeShapeRatio(ArrayRef<int64_t> shape, ArrayRef<int64_t> subShape) {
+ if (shape.size() < subShape.size())
+ return None;
+ assert(llvm::all_of(shape, [](int64_t s) { return s > 0; }) &&
+ "shape entries must be strictly positive");
+ assert(llvm::all_of(subShape, [](int64_t s) { return s > 0; }) &&
+ "subShape entries must be strictly positive");
+
+ // Starting from the end, compute the integer divisors.
+ std::vector<int64_t> result;
+ result.reserve(shape.size());
+ for (auto [size, subSize] :
+ llvm::zip(llvm::reverse(shape), llvm::reverse(subShape))) {
+ // If integral division does not occur, return and let the caller decide.
+ if (size % subSize != 0)
+ return None;
+ result.push_back(size / subSize);
+ }
+ // At this point we computed the ratio (in reverse) for the common size.
+ // Fill with the remaining entries from the shape (still in reverse).
+ int commonSize = subShape.size();
+ std::copy(shape.rbegin() + commonSize, shape.rend(),
+ std::back_inserter(result));
+ // Reverse again to get it back in the proper order and return.
+ return SmallVector<int64_t>{result.rbegin(), result.rend()};
+}
+
int64_t mlir::linearize(ArrayRef<int64_t> offsets, ArrayRef<int64_t> basis) {
assert(offsets.size() == basis.size());
int64_t linearIndex = 0;
@@ -20,10 +67,10 @@ int64_t mlir::linearize(ArrayRef<int64_t> offsets, ArrayRef<int64_t> basis) {
return linearIndex;
}
-llvm::SmallVector<int64_t, 4> mlir::delinearize(ArrayRef<int64_t> sliceStrides,
- int64_t index) {
+llvm::SmallVector<int64_t> mlir::delinearize(ArrayRef<int64_t> sliceStrides,
+ int64_t index) {
int64_t rank = sliceStrides.size();
- SmallVector<int64_t, 4> vectorOffsets(rank);
+ SmallVector<int64_t> vectorOffsets(rank);
for (int64_t r = 0; r < rank; ++r) {
assert(sliceStrides[r] > 0);
vectorOffsets[r] = index / sliceStrides[r];
@@ -32,12 +79,19 @@ llvm::SmallVector<int64_t, 4> mlir::delinearize(ArrayRef<int64_t> sliceStrides,
return vectorOffsets;
}
-llvm::SmallVector<int64_t, 4> mlir::getI64SubArray(ArrayAttr arrayAttr,
- unsigned dropFront,
- unsigned dropBack) {
+int64_t mlir::computeMaxLinearIndex(ArrayRef<int64_t> basis) {
+ if (basis.empty())
+ return 0;
+ return std::accumulate(basis.begin(), basis.end(), int64_t(1),
+ std::multiplies<int64_t>());
+}
+
+llvm::SmallVector<int64_t> mlir::getI64SubArray(ArrayAttr arrayAttr,
+ unsigned dropFront,
+ unsigned dropBack) {
assert(arrayAttr.size() > dropFront + dropBack && "Out of bounds");
auto range = arrayAttr.getAsRange<IntegerAttr>();
- SmallVector<int64_t, 4> res;
+ SmallVector<int64_t> res;
res.reserve(arrayAttr.size() - dropFront - dropBack);
for (auto it = range.begin() + dropFront, eit = range.end() - dropBack;
it != eit; ++it)
@@ -54,11 +108,11 @@ mlir::AffineExpr mlir::getLinearAffineExpr(ArrayRef<int64_t> basis,
return resultExpr;
}
-llvm::SmallVector<mlir::AffineExpr, 4>
+llvm::SmallVector<mlir::AffineExpr>
mlir::getDelinearizedAffineExpr(mlir::ArrayRef<int64_t> strides, Builder &b) {
AffineExpr resultExpr = b.getAffineDimExpr(0);
int64_t rank = strides.size();
- SmallVector<AffineExpr, 4> vectorOffsets(rank);
+ SmallVector<AffineExpr> vectorOffsets(rank);
vectorOffsets[0] = resultExpr.floorDiv(strides[0]);
resultExpr = resultExpr % strides[0];
for (unsigned i = 1; i < rank; i++) {
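A unit-test style sketch (not part of the commit) exercising the documented behavior of the moved helpers:

  #include "mlir/Dialect/Utils/IndexingUtils.h"
  #include <cassert>

  void shapeRatioSketch() {
    using namespace mlir;
    // {2, 5, 2} fits the trailing dims of {42, 2, 10, 32} a total of
    // {42, 1, 2, 16} times, i.e. 1344 slices.
    Optional<SmallVector<int64_t>> ratio =
        computeShapeRatio({42, 2, 10, 32}, {2, 5, 2});
    assert(ratio && computeMaxLinearIndex(*ratio) == 42 * 2 * 16);
    // 11 is not divisible by 5: no integral ratio exists.
    assert(!computeShapeRatio({42, 2, 11, 32}, {2, 5, 2}));
  }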
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
index 0ce46c9bd685e..c41e66325c011 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
@@ -54,9 +54,9 @@ static Optional<int64_t> getResultIndex(AffineMap map, int64_t index) {
}
// Helper to construct iterator types with one index removed.
-static SmallVector<Attribute, 4> adjustIter(ArrayAttr iteratorTypes,
- int64_t index) {
- SmallVector<Attribute, 4> results;
+static SmallVector<Attribute> adjustIter(ArrayAttr iteratorTypes,
+ int64_t index) {
+ SmallVector<Attribute> results;
for (const auto &it : llvm::enumerate(iteratorTypes)) {
int64_t idx = it.index();
if (idx == index)
@@ -70,7 +70,7 @@ static SmallVector<Attribute, 4> adjustIter(ArrayAttr iteratorTypes,
static AffineMap adjustMap(AffineMap map, int64_t index,
PatternRewriter &rewriter) {
auto *ctx = rewriter.getContext();
- SmallVector<AffineExpr, 4> results;
+ SmallVector<AffineExpr> results;
for (int64_t i = 0, e = map.getNumResults(); i < e; ++i) {
int64_t idx = map.getDimPosition(i);
if (idx == index)
@@ -140,7 +140,7 @@ static Value reshapeStore(Location loc, Value val, Value result,
}
template <typename IntType>
-static SmallVector<IntType, 4> extractVector(ArrayAttr arrayAttr) {
+static SmallVector<IntType> extractVector(ArrayAttr arrayAttr) {
return llvm::to_vector<4>(llvm::map_range(
arrayAttr.getAsRange<IntegerAttr>(),
[](IntegerAttr attr) { return static_cast<IntType>(attr.getInt()); }));
@@ -399,7 +399,7 @@ class TransposeOpLowering : public OpRewritePattern<vector::TransposeOp> {
VectorType resType = op.getResultType();
// Set up convenience transposition table.
- SmallVector<int64_t, 4> transp;
+ SmallVector<int64_t> transp;
for (auto attr : op.getTransp())
transp.push_back(attr.cast<IntegerAttr>().getInt());
@@ -430,12 +430,11 @@ class TransposeOpLowering : public OpRewritePattern<vector::TransposeOp> {
// in vector form to improve performance. Therefore, we prune those
// dimensions from the shape/transpose data structures used to generate the
// extract/insert ops.
- SmallVector<int64_t, 4> prunedTransp;
+ SmallVector<int64_t> prunedTransp;
pruneNonTransposedDims(transp, prunedTransp);
size_t numPrunedDims = transp.size() - prunedTransp.size();
auto prunedInShape = inputType.getShape().drop_back(numPrunedDims);
- SmallVector<int64_t, 4> ones(prunedInShape.size(), 1);
- auto prunedInStrides = computeStrides(prunedInShape, ones);
+ auto prunedInStrides = computeStrides(prunedInShape);
// Generates the extract/insert operations for every scalar/vector element
// of the leftmost transposed dimensions. We traverse every transpose
@@ -448,7 +447,7 @@ class TransposeOpLowering : public OpRewritePattern<vector::TransposeOp> {
for (int64_t linearIdx = 0; linearIdx < numTransposedElements;
++linearIdx) {
auto extractIdxs = delinearize(prunedInStrides, linearIdx);
- SmallVector<int64_t, 4> insertIdxs(extractIdxs);
+ SmallVector<int64_t> insertIdxs(extractIdxs);
applyPermutationToVector(insertIdxs, prunedTransp);
Value extractOp =
rewriter.create<vector::ExtractOp>(loc, input, extractIdxs);
@@ -488,7 +487,7 @@ class TransposeOp2DToShuffleLowering
if (srcType.getRank() != 2)
return rewriter.notifyMatchFailure(op, "Not a 2D transpose");
- SmallVector<int64_t, 4> transp;
+ SmallVector<int64_t> transp;
for (auto attr : op.getTransp())
transp.push_back(attr.cast<IntegerAttr>().getInt());
if (transp[0] != 1 && transp[1] != 0)
@@ -685,8 +684,8 @@ struct ContractOpToElementwise
bool isInt = contractOp.getLhsType().getElementType().isIntOrIndex();
newLhs = rewriter.create<vector::TransposeOp>(loc, newLhs, lhsTranspose);
newRhs = rewriter.create<vector::TransposeOp>(loc, newRhs, rhsTranspose);
- SmallVector<int64_t, 4> lhsOffsets(lhsReductionDims.size(), 0);
- SmallVector<int64_t, 4> rhsOffsets(rhsReductionDims.size(), 0);
+ SmallVector<int64_t> lhsOffsets(lhsReductionDims.size(), 0);
+ SmallVector<int64_t> rhsOffsets(rhsReductionDims.size(), 0);
newLhs = rewriter.create<vector::ExtractOp>(
loc, newLhs, rewriter.getI64ArrayAttr(lhsOffsets));
newRhs = rewriter.create<vector::ExtractOp>(
@@ -752,7 +751,7 @@ class ConstantMaskOpLowering : public OpRewritePattern<vector::ConstantMaskOp> {
if (rank == 1) {
// Express constant 1-D case in explicit vector form:
// [T,..,T,F,..,F].
- SmallVector<bool, 4> values(dstType.getDimSize(0));
+ SmallVector<bool> values(dstType.getDimSize(0));
for (int64_t d = 0; d < trueDim; d++)
values[d] = true;
rewriter.replaceOpWithNewOp<arith::ConstantOp>(
@@ -762,7 +761,7 @@ class ConstantMaskOpLowering : public OpRewritePattern<vector::ConstantMaskOp> {
VectorType lowType =
VectorType::get(dstType.getShape().drop_front(), eltType);
- SmallVector<int64_t, 4> newDimSizes;
+ SmallVector<int64_t> newDimSizes;
for (int64_t r = 1; r < rank; r++)
newDimSizes.push_back(dimSizes[r].cast<IntegerAttr>().getInt());
Value trueVal = rewriter.create<vector::ConstantMaskOp>(
@@ -931,8 +930,8 @@ class ShapeCastOpRewritePattern : public OpRewritePattern<vector::ShapeCastOp> {
// x[0,1,0] = y[0,2]
// etc., incrementing the two index vectors "row-major"
// within the source and result shape.
- SmallVector<int64_t, 4> srcIdx(srcRank);
- SmallVector<int64_t, 4> resIdx(resRank);
+ SmallVector<int64_t> srcIdx(srcRank);
+ SmallVector<int64_t> resIdx(resRank);
Value result = rewriter.create<arith::ConstantOp>(
loc, resultVectorType, rewriter.getZeroAttr(resultVectorType));
for (int64_t i = 0; i < numElts; i++) {
@@ -948,7 +947,7 @@ class ShapeCastOpRewritePattern : public OpRewritePattern<vector::ShapeCastOp> {
}
private:
- static void incIdx(SmallVector<int64_t, 4> &idx, VectorType tp, int64_t r) {
+ static void incIdx(SmallVector<int64_t> &idx, VectorType tp, int64_t r) {
assert(0 <= r && r < tp.getRank());
if (++idx[r] == tp.getDimSize(r)) {
idx[r] = 0;
@@ -1039,7 +1038,7 @@ struct CombineContractABTranspose final
LogicalResult matchAndRewrite(vector::ContractionOp contractOp,
PatternRewriter &rewriter) const override {
- SmallVector<AffineMap, 4> maps =
+ SmallVector<AffineMap> maps =
llvm::to_vector<4>(contractOp.getIndexingMapsArray());
Value lhs = contractOp.getLhs();
Value rhs = contractOp.getRhs();
@@ -1169,7 +1168,7 @@ struct CombineContractBroadcast
LogicalResult matchAndRewrite(vector::ContractionOp contractOp,
PatternRewriter &rewriter) const override {
- SmallVector<AffineMap, 4> maps =
+ SmallVector<AffineMap> maps =
llvm::to_vector<4>(contractOp.getIndexingMapsArray());
Value lhs = contractOp.getLhs();
Value rhs = contractOp.getRhs();
@@ -1234,7 +1233,7 @@ struct CombineContractBroadcast
for (auto &m : maps)
m = compressDims(m, unusedDimsBitVector);
// Compute the combined iterators.
- SmallVector<Attribute, 4> iterators;
+ SmallVector<Attribute> iterators;
for (unsigned i = 0; i < unusedDimsBitVector.size(); ++i) {
if (!unusedDimsBitVector.test(i))
iterators.push_back(contractOp.getIteratorTypes().getValue()[i]);
@@ -1328,7 +1327,7 @@ struct ReorderElementwiseOpsOnTranspose final
// Make sure all operands are transpose/constant ops and collect their
// transposition maps.
- SmallVector<ArrayAttr, 4> transposeMaps;
+ SmallVector<ArrayAttr> transposeMaps;
transposeMaps.reserve(op->getNumOperands());
// Record the initial type before transposition. We'll use its shape later.
// Any type will do here as we will check all transpose maps are the same.
@@ -1350,7 +1349,7 @@ struct ReorderElementwiseOpsOnTranspose final
if (!llvm::all_equal(transposeMaps))
return rewriter.notifyMatchFailure(op, "different transpose map");
- SmallVector<Value, 4> srcValues;
+ SmallVector<Value> srcValues;
srcValues.reserve(op->getNumOperands());
// If there are constant operands, we need to insert inverse transposes for
@@ -1724,7 +1723,7 @@ ContractionOpToDotLowering::matchAndRewrite(vector::ContractionOp op,
auto infer = [](MapList m) { return AffineMap::inferFromExprList(m); };
AffineExpr m, n, k;
bindDims(rewriter.getContext(), m, n, k);
- SmallVector<AffineMap, 4> maps = op.getIndexingMapsArray();
+ SmallVector<AffineMap> maps = op.getIndexingMapsArray();
//
// In the following we wish to make the reduction dimension innermost so we
// can load vectors and just fmul + reduce into a scalar.
@@ -1940,7 +1939,7 @@ ContractionOpLowering::lowerParallel(vector::ContractionOp op, int64_t lhsIndex,
VectorType rhsType = op.getRhsType();
VectorType resType = op.getResultType().cast<VectorType>();
// Find the iterator type index and result index.
- SmallVector<AffineMap, 4> iMap = op.getIndexingMapsArray();
+ SmallVector<AffineMap> iMap = op.getIndexingMapsArray();
int64_t iterIndex = -1;
int64_t dimSize = -1;
if (lhsIndex >= 0) {
@@ -2011,7 +2010,7 @@ ContractionOpLowering::lowerReduction(vector::ContractionOp op,
bool isInt = resType.isa<IntegerType>();
// Use iterator index 0.
int64_t iterIndex = 0;
- SmallVector<AffineMap, 4> iMap = op.getIndexingMapsArray();
+ SmallVector<AffineMap> iMap = op.getIndexingMapsArray();
Optional<int64_t> lookupLhs = getResultIndex(iMap[0], iterIndex);
Optional<int64_t> lookupRhs = getResultIndex(iMap[1], iterIndex);
if (!lookupLhs.has_value())
@@ -2087,7 +2086,7 @@ struct TransferReadToVectorLoadLowering
if (maxTransferRank && read.getVectorType().getRank() > *maxTransferRank)
return failure();
- SmallVector<unsigned, 4> broadcastedDims;
+ SmallVector<unsigned> broadcastedDims;
// Permutations are handled by VectorToSCF or
// populateVectorTransferPermutationMapLoweringPatterns.
// We let the 0-d corner case pass-through as it is supported.
@@ -2106,8 +2105,8 @@ struct TransferReadToVectorLoadLowering
// If there is broadcasting involved then we first load the unbroadcasted
// vector, and then broadcast it with `vector.broadcast`.
ArrayRef<int64_t> vectorShape = read.getVectorType().getShape();
- SmallVector<int64_t, 4> unbroadcastedVectorShape(vectorShape.begin(),
- vectorShape.end());
+ SmallVector<int64_t> unbroadcastedVectorShape(vectorShape.begin(),
+ vectorShape.end());
for (unsigned i : broadcastedDims)
unbroadcastedVectorShape[i] = 1;
VectorType unbroadcastedVectorType = VectorType::get(
@@ -2286,7 +2285,7 @@ struct TransferWriteToVectorStoreLowering
};
// Returns the values in `arrayAttr` as an integer vector.
-static SmallVector<int64_t, 4> getIntValueVector(ArrayAttr arrayAttr) {
+static SmallVector<int64_t> getIntValueVector(ArrayAttr arrayAttr) {
return llvm::to_vector<4>(
llvm::map_range(arrayAttr.getAsRange<IntegerAttr>(),
[](IntegerAttr attr) { return attr.getInt(); }));
@@ -2410,7 +2409,7 @@ struct BubbleDownBitCastForStridedSliceExtract
// dimension's offset given we are extracting from less elements now.
ArrayAttr newOffsets = extractOp.getOffsets();
if (newOffsets.size() == rank) {
- SmallVector<int64_t, 4> offsets = getIntValueVector(newOffsets);
+ SmallVector<int64_t> offsets = getIntValueVector(newOffsets);
if (offsets.back() % expandRatio != 0)
return failure();
offsets.back() = offsets.back() / expandRatio;
@@ -2420,14 +2419,14 @@ struct BubbleDownBitCastForStridedSliceExtract
// Similarly for sizes.
ArrayAttr newSizes = extractOp.getSizes();
if (newSizes.size() == rank) {
- SmallVector<int64_t, 4> sizes = getIntValueVector(newSizes);
+ SmallVector<int64_t> sizes = getIntValueVector(newSizes);
if (sizes.back() % expandRatio != 0)
return failure();
sizes.back() = sizes.back() / expandRatio;
newSizes = rewriter.getI64ArrayAttr(sizes);
}
- SmallVector<int64_t, 4> dims =
+ SmallVector<int64_t> dims =
llvm::to_vector<4>(extractOp.getType().cast<VectorType>().getShape());
dims.back() = dims.back() / expandRatio;
VectorType newExtractType =
@@ -2500,13 +2499,13 @@ struct BubbleUpBitCastForStridedSliceInsert
ArrayAttr newOffsets = insertOp.getOffsets();
assert(newOffsets.size() == rank);
- SmallVector<int64_t, 4> offsets = getIntValueVector(newOffsets);
+ SmallVector<int64_t> offsets = getIntValueVector(newOffsets);
if (offsets.back() % shrinkRatio != 0)
return failure();
offsets.back() = offsets.back() / shrinkRatio;
newOffsets = rewriter.getI64ArrayAttr(offsets);
- SmallVector<int64_t, 4> srcDims =
+ SmallVector<int64_t> srcDims =
llvm::to_vector<4>(insertOp.getSourceVectorType().getShape());
srcDims.back() = srcDims.back() / shrinkRatio;
VectorType newCastSrcType =
@@ -2515,7 +2514,7 @@ struct BubbleUpBitCastForStridedSliceInsert
auto newCastSrcOp = rewriter.create<vector::BitCastOp>(
bitcastOp.getLoc(), newCastSrcType, insertOp.getSource());
- SmallVector<int64_t, 4> dstDims =
+ SmallVector<int64_t> dstDims =
llvm::to_vector<4>(insertOp.getDestVectorType().getShape());
dstDims.back() = dstDims.back() / shrinkRatio;
VectorType newCastDstType =
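The transpose lowering above now drives its extract/insert traversal entirely off these helpers; a standalone sketch with hypothetical shapes:

  // Enumerate extract/insert index pairs for a 2x3 -> 3x2 transpose.
  SmallVector<int64_t> transp = {1, 0};
  SmallVector<int64_t> strides = computeStrides({2, 3}); // {3, 1}
  for (int64_t linearIdx = 0; linearIdx < 6; ++linearIdx) {
    SmallVector<int64_t> extractIdxs = delinearize(strides, linearIdx);
    SmallVector<int64_t> insertIdxs(extractIdxs);
    // Element i moves to position transp[i]; e.g. linearIdx == 1
    // extracts at (0, 1) and inserts at (1, 0).
    applyPermutationToVector(insertIdxs, transp);
  }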
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp
index 8c5d5144304e5..5d07fb3a925b4 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp
@@ -27,24 +27,19 @@ using namespace mlir::vector;
/// During unrolling from `originalShape` to `targetShape` return the offset for
/// the slice `index`.
-static SmallVector<int64_t, 4> getVectorOffset(ArrayRef<int64_t> originalShape,
- ArrayRef<int64_t> targetShape,
- int64_t index) {
- SmallVector<int64_t, 4> dstSliceStrides =
- computeStrides(originalShape, targetShape);
- SmallVector<int64_t, 4> vectorOffsets = delinearize(dstSliceStrides, index);
- SmallVector<int64_t, 4> elementOffsets =
- computeElementOffsetsFromVectorSliceOffsets(targetShape, vectorOffsets);
- return elementOffsets;
+static SmallVector<int64_t> getVectorOffset(ArrayRef<int64_t> ratioStrides,
+ int64_t index,
+ ArrayRef<int64_t> targetShape) {
+ return computeElementwiseMul(delinearize(ratioStrides, index), targetShape);
}
-/// A functor that accomplishes the same thing as `getVectorOffset` but allows
-/// for reordering the traversal of the dimensions. The order of traversal is
-/// given in "for loop order" (outer to inner).
+/// A functor that accomplishes the same thing as `getVectorOffset` but
+/// allows for reordering the traversal of the dimensions. The order of
+/// traversal is given in "for loop order" (outer to inner).
namespace {
class DecomposeShapeIterator {
private:
- SmallVector<int64_t, 4> vectorShape;
+ SmallVector<int64_t> vectorShape;
SmallVector<int64_t> loopOrder;
SmallVector<int64_t> sliceStrides;
int64_t maxIndexVal{1};
@@ -56,15 +51,15 @@ class DecomposeShapeIterator {
: vectorShape(targetShape.begin(), targetShape.end()),
loopOrder(loopOrder.begin(), loopOrder.end()),
sliceStrides(originalShape.size()) {
- assert(originalShape.size() == targetShape.size());
- assert(loopOrder.size() == targetShape.size());
+ assert(originalShape.size() >= targetShape.size());
+ assert(loopOrder.size() == originalShape.size());
// Compute the count for each dimension.
- SmallVector<int64_t> sliceDimCounts(originalShape.size());
- for (unsigned r = 0; r < originalShape.size(); ++r) {
- sliceDimCounts[r] = ceilDiv(originalShape[r], targetShape[r]);
- maxIndexVal *= sliceDimCounts[r];
- }
+ auto maybeShapeRatio = computeShapeRatio(originalShape, targetShape);
+ assert(maybeShapeRatio && "Shape does not evenly divide");
+ // When ranks differ, `computeShapeRatio` behaves as if `targetShape` were
+ // padded with leading 1s, so `sliceDimCounts` matches the original rank.
+ SmallVector<int64_t> sliceDimCounts = *maybeShapeRatio;
+ maxIndexVal = computeMaxLinearIndex(sliceDimCounts);
// Reversing "loop order" gives dimensions from fastest varying to slowest
// varying (smallest stride to largest stride).
@@ -95,7 +90,7 @@ class DecomposeShapeIterator {
SmallVector<int64_t> getVectorOffset(int64_t index) const {
SmallVector<int64_t> vectorOffsets = delinearize(index);
SmallVector<int64_t> elementOffsets =
- computeElementOffsetsFromVectorSliceOffsets(vectorShape, vectorOffsets);
+ computeElementwiseMul(vectorShape, vectorOffsets);
return elementOffsets;
}
};
@@ -139,7 +134,7 @@ static Operation *cloneOpWithOperandsAndTypes(OpBuilder &builder, Location loc,
/// Return the target shape for unrolling for the given `op`. Return llvm::None
/// if the op shouldn't be or cannot be unrolled.
-static Optional<SmallVector<int64_t, 4>>
+static Optional<SmallVector<int64_t>>
getTargetShape(const vector::UnrollVectorOptions &options, Operation *op) {
if (options.filterConstraint && failed(options.filterConstraint(op)))
return llvm::None;
@@ -152,10 +147,10 @@ getTargetShape(const vector::UnrollVectorOptions &options, Operation *op) {
auto maybeUnrollShape = unrollableVectorOp.getShapeForUnroll();
if (!maybeUnrollShape)
return llvm::None;
- Optional<SmallVector<int64_t, 4>> targetShape = options.nativeShape(op);
+ Optional<SmallVector<int64_t>> targetShape = options.nativeShape(op);
if (!targetShape)
return llvm::None;
- auto maybeShapeRatio = shapeRatio(*maybeUnrollShape, *targetShape);
+ auto maybeShapeRatio = computeShapeRatio(*maybeUnrollShape, *targetShape);
if (!maybeShapeRatio ||
llvm::all_of(*maybeShapeRatio, [](int64_t v) { return v == 1; }))
return llvm::None;
@@ -197,7 +192,7 @@ struct UnrollTransferReadPattern
if (!targetShape)
return failure();
auto sourceVectorType = readOp.getVectorType();
- SmallVector<int64_t, 4> strides(targetShape->size(), 1);
+ SmallVector<int64_t> strides(targetShape->size(), 1);
Location loc = readOp.getLoc();
ArrayRef<int64_t> originalSize = readOp.getVectorType().getShape();
@@ -206,17 +201,16 @@ struct UnrollTransferReadPattern
loc, sourceVectorType, rewriter.getZeroAttr(sourceVectorType));
auto targetType =
VectorType::get(*targetShape, sourceVectorType.getElementType());
- SmallVector<Value, 4> originalIndices(readOp.getIndices().begin(),
- readOp.getIndices().end());
+ SmallVector<Value> originalIndices(readOp.getIndices().begin(),
+ readOp.getIndices().end());
SmallVector<int64_t> loopOrder =
getUnrollOrder(originalSize.size(), readOp, options);
DecomposeShapeIterator indexToOffsets(originalSize, *targetShape,
loopOrder);
for (int64_t i = 0; i < indexToOffsets.maxIndex(); i++) {
- SmallVector<int64_t, 4> elementOffsets =
- indexToOffsets.getVectorOffset(i);
- SmallVector<Value, 4> indices =
+ SmallVector<int64_t> elementOffsets = indexToOffsets.getVectorOffset(i);
+ SmallVector<Value> indices =
sliceTransferIndices(elementOffsets, originalIndices,
readOp.getPermutationMap(), loc, rewriter);
auto slicedRead = rewriter.create<vector::TransferReadOp>(
@@ -255,11 +249,11 @@ struct UnrollTransferWritePattern
if (!targetShape)
return failure();
auto sourceVectorType = writeOp.getVectorType();
- SmallVector<int64_t, 4> strides(targetShape->size(), 1);
+ SmallVector<int64_t> strides(targetShape->size(), 1);
Location loc = writeOp.getLoc();
ArrayRef<int64_t> originalSize = sourceVectorType.getShape();
- SmallVector<Value, 4> originalIndices(writeOp.getIndices().begin(),
- writeOp.getIndices().end());
+ SmallVector<Value> originalIndices(writeOp.getIndices().begin(),
+ writeOp.getIndices().end());
SmallVector<int64_t> loopOrder =
getUnrollOrder(originalSize.size(), writeOp, options);
@@ -267,11 +261,10 @@ struct UnrollTransferWritePattern
loopOrder);
Value resultTensor;
for (int64_t i = 0; i < indexToOffsets.maxIndex(); i++) {
- SmallVector<int64_t, 4> elementOffsets =
- indexToOffsets.getVectorOffset(i);
+ SmallVector<int64_t> elementOffsets = indexToOffsets.getVectorOffset(i);
Value slicedVector = rewriter.create<vector::ExtractStridedSliceOp>(
loc, writeOp.getVector(), elementOffsets, *targetShape, strides);
- SmallVector<Value, 4> indices =
+ SmallVector<Value> indices =
sliceTransferIndices(elementOffsets, originalIndices,
writeOp.getPermutationMap(), loc, rewriter);
Operation *slicedWrite = rewriter.create<vector::TransferWriteOp>(
@@ -321,7 +314,7 @@ struct UnrollContractionPattern
if (!targetShape)
return failure();
auto dstVecType = contractOp.getResultType().cast<VectorType>();
- SmallVector<int64_t, 4> originalSize = *contractOp.getShapeForUnroll();
+ SmallVector<int64_t> originalSize = *contractOp.getShapeForUnroll();
Location loc = contractOp.getLoc();
unsigned accIndex = vector::ContractionOp::getAccOperandIndex();
@@ -337,16 +330,16 @@ struct UnrollContractionPattern
loopOrder);
const int64_t sliceCount = indexToOffsets.maxIndex();
for (int64_t i = 0; i < sliceCount; i++) {
- SmallVector<int64_t, 4> offsets = indexToOffsets.getVectorOffset(i);
- SmallVector<Value, 4> slicesOperands(contractOp.getNumOperands());
+ SmallVector<int64_t> offsets = indexToOffsets.getVectorOffset(i);
+ SmallVector<Value> slicesOperands(contractOp.getNumOperands());
- // Helper to coompute the new shape of each operand and extract the slice.
+ // Helper to compute the new shape of each operand and extract the slice.
auto extractOperand = [&](unsigned index, Value operand,
AffineMap permutationMap,
ArrayRef<int64_t> operandOffets) {
SmallVector<int64_t> operandShape = applyPermutationMap(
permutationMap, ArrayRef<int64_t>(*targetShape));
- SmallVector<int64_t, 4> operandStrides(operandOffets.size(), 1);
+ SmallVector<int64_t> operandStrides(operandOffets.size(), 1);
slicesOperands[index] = rewriter.create<vector::ExtractStridedSliceOp>(
loc, operand, operandOffets, operandShape, operandStrides);
};
@@ -420,12 +413,12 @@ struct UnrollMultiReductionPattern
LogicalResult matchAndRewrite(vector::MultiDimReductionOp reductionOp,
PatternRewriter &rewriter) const override {
- Optional<SmallVector<int64_t, 4>> targetShape =
+ Optional<SmallVector<int64_t>> targetShape =
getTargetShape(options, reductionOp);
if (!targetShape)
return failure();
- SmallVector<int64_t, 4> originalSize = *reductionOp.getShapeForUnroll();
- SmallVector<int64_t, 4> ratio = *shapeRatio(originalSize, *targetShape);
+ SmallVector<int64_t> originalSize = *reductionOp.getShapeForUnroll();
+ SmallVector<int64_t> ratio = *computeShapeRatio(originalSize, *targetShape);
llvm::MapVector<
SmallVector<int64_t>, Value,
llvm::DenseMap<SmallVector<int64_t>, unsigned, OffsetMapInfo>>
@@ -433,12 +426,16 @@ struct UnrollMultiReductionPattern
// Compute shape ratio of 'shape' and 'sizes'.
int64_t sliceCount = computeMaxLinearIndex(ratio);
Location loc = reductionOp.getLoc();
+
+ // Strides of the ratio; used to delinearize each slice index into an
+ // offset expressed in multiples of `targetShape`.
+ auto ratioStrides = computeStrides(ratio);
for (int64_t i = 0; i < sliceCount; i++) {
- SmallVector<int64_t, 4> offsets =
- getVectorOffset(originalSize, *targetShape, i);
+ SmallVector<int64_t> offsets =
+ getVectorOffset(ratioStrides, i, *targetShape);
SmallVector<Value> operands;
- SmallVector<int64_t, 4> operandStrides(offsets.size(), 1);
+ SmallVector<int64_t> operandStrides(offsets.size(), 1);
Value slicedOperand = rewriter.create<vector::ExtractStridedSliceOp>(
loc, reductionOp.getSource(), offsets, *targetShape, operandStrides);
operands.push_back(slicedOperand);
@@ -451,7 +448,7 @@ struct UnrollMultiReductionPattern
}
}
Value acc;
- SmallVector<int64_t, 4> accStrides(destOffset.size(), 1);
+ SmallVector<int64_t> accStrides(destOffset.size(), 1);
// If a version of the accumulator has already been computed, use it
// otherwise extract the first version from the original operand.
auto accIt = accCache.find(destOffset);
@@ -500,21 +497,25 @@ struct UnrollElementwisePattern : public RewritePattern {
if (!targetShape)
return failure();
auto dstVecType = op->getResult(0).getType().cast<VectorType>();
- SmallVector<int64_t, 4> originalSize =
+ SmallVector<int64_t> originalSize =
*cast<VectorUnrollOpInterface>(op).getShapeForUnroll();
- SmallVector<int64_t, 4> ratio = *shapeRatio(originalSize, *targetShape);
+ SmallVector<int64_t> ratio = *computeShapeRatio(originalSize, *targetShape);
int64_t sliceCount = computeMaxLinearIndex(ratio);
Location loc = op->getLoc();
// Prepare the result vector.
Value result = rewriter.create<arith::ConstantOp>(
loc, dstVecType, rewriter.getZeroAttr(dstVecType));
- SmallVector<int64_t, 4> strides(targetShape->size(), 1);
+ SmallVector<int64_t> strides(targetShape->size(), 1);
VectorType newVecType =
VectorType::get(*targetShape, dstVecType.getElementType());
+
+ // Strides of the ratio; used to delinearize each slice index into an
+ // offset expressed in multiples of `targetShape`.
+ auto ratioStrides = computeStrides(ratio);
for (int64_t i = 0; i < sliceCount; i++) {
- SmallVector<int64_t, 4> offsets =
- getVectorOffset(originalSize, *targetShape, i);
- SmallVector<Value, 4> extractOperands;
+ SmallVector<int64_t> offsets =
+ getVectorOffset(ratioStrides, i, *targetShape);
+ SmallVector<Value> extractOperands;
for (OpOperand &operand : op->getOpOperands()) {
auto vecType = operand.get().getType().template dyn_cast<VectorType>();
if (!vecType) {
@@ -547,19 +548,24 @@ struct UnrollReductionPattern : public OpRewritePattern<vector::ReductionOp> {
LogicalResult matchAndRewrite(vector::ReductionOp reductionOp,
PatternRewriter &rewriter) const override {
- Optional<SmallVector<int64_t, 4>> targetShape =
+ Optional<SmallVector<int64_t>> targetShape =
getTargetShape(options, reductionOp);
if (!targetShape)
return failure();
SmallVector<int64_t> originalSize = *reductionOp.getShapeForUnroll();
- int64_t ratio = (*shapeRatio(originalSize, *targetShape))[0];
+ auto ratio = *computeShapeRatio(originalSize, *targetShape);
+ int64_t sliceCount = ratio[0];
// Create unrolled vector reduction.
Location loc = reductionOp.getLoc();
Value accumulator = nullptr;
- for (int64_t i = 0; i < ratio; ++i) {
+
+ // Strides of the ratio; used to delinearize each slice index into an
+ // offset expressed in multiples of `targetShape`.
+ auto ratioStrides = computeStrides(ratio);
+ for (int64_t i = 0; i < sliceCount; ++i) {
SmallVector<int64_t> offsets =
- getVectorOffset(originalSize, *targetShape, i);
+ getVectorOffset(ratioStrides, i, *targetShape);
SmallVector<int64_t> strides(offsets.size(), 1);
Value slicedOperand = rewriter.create<vector::ExtractStridedSliceOp>(
loc, reductionOp.getVector(), offsets, *targetShape, strides);
@@ -600,21 +606,25 @@ struct UnrollTranposePattern : public OpRewritePattern<vector::TransposeOp> {
if (!targetShape)
return failure();
auto originalVectorType = tranposeOp.getResultType();
- SmallVector<int64_t, 4> strides(targetShape->size(), 1);
+ SmallVector<int64_t> strides(targetShape->size(), 1);
Location loc = tranposeOp.getLoc();
ArrayRef<int64_t> originalSize = originalVectorType.getShape();
- SmallVector<int64_t, 4> ratio = *shapeRatio(originalSize, *targetShape);
+ SmallVector<int64_t> ratio = *computeShapeRatio(originalSize, *targetShape);
int64_t sliceCount = computeMaxLinearIndex(ratio);
// Prepare the result vector.
Value result = rewriter.create<arith::ConstantOp>(
loc, originalVectorType, rewriter.getZeroAttr(originalVectorType));
SmallVector<int64_t> permutation;
tranposeOp.getTransp(permutation);
+
+ // Strides of the ratio; used to delinearize each slice index into an
+ // offset expressed in multiples of `targetShape`.
+ auto ratioStrides = computeStrides(ratio);
for (int64_t i = 0; i < sliceCount; i++) {
- SmallVector<int64_t, 4> elementOffsets =
- getVectorOffset(originalSize, *targetShape, i);
- SmallVector<int64_t, 4> permutedOffsets(elementOffsets.size());
- SmallVector<int64_t, 4> permutedShape(elementOffsets.size());
+ SmallVector<int64_t> elementOffsets =
+ getVectorOffset(ratioStrides, i, *targetShape);
+ SmallVector<int64_t> permutedOffsets(elementOffsets.size());
+ SmallVector<int64_t> permutedShape(elementOffsets.size());
// Compute the source offsets and shape.
for (auto &indices : llvm::enumerate(permutation)) {
permutedOffsets[indices.value()] = elementOffsets[indices.index()];
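The recurring computation in the unroll patterns above delinearizes the slice index in the ratio basis, then scales elementwise by the target shape; with hypothetical shapes:

  // Offsets for unrolling a 4x6 vector into 2x3 tiles.
  SmallVector<int64_t> targetShape = {2, 3};
  SmallVector<int64_t> ratio = *computeShapeRatio({4, 6}, targetShape); // {2, 2}
  SmallVector<int64_t> ratioStrides = computeStrides(ratio);            // {2, 1}
  for (int64_t i = 0; i < computeMaxLinearIndex(ratio); ++i) {
    // i: 0 -> {0, 0}, 1 -> {0, 3}, 2 -> {2, 0}, 3 -> {2, 3}.
    SmallVector<int64_t> offsets =
        computeElementwiseMul(delinearize(ratioStrides, i), targetShape);
  }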
diff --git a/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp b/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp
index 4e7ed8db19f49..8e3cd03f956cd 100644
--- a/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp
+++ b/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp
@@ -18,6 +18,7 @@
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/IntegerSet.h"
@@ -25,7 +26,6 @@
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Support/MathExtras.h"
-#include <numeric>
#include "llvm/ADT/DenseSet.h"
#include "llvm/ADT/SetVector.h"
@@ -43,78 +43,6 @@ Value mlir::vector::createOrFoldDimOp(OpBuilder &b, Location loc, Value source,
llvm_unreachable("Expected MemRefType or TensorType");
}
-/// Return the number of elements of basis, `0` if empty.
-int64_t mlir::computeMaxLinearIndex(ArrayRef<int64_t> basis) {
- if (basis.empty())
- return 0;
- return std::accumulate(basis.begin(), basis.end(), 1,
- std::multiplies<int64_t>());
-}
-
-SmallVector<int64_t, 4> mlir::computeStrides(ArrayRef<int64_t> shape,
- ArrayRef<int64_t> sizes) {
- int64_t rank = shape.size();
- // Compute the count for each dimension.
- SmallVector<int64_t, 4> sliceDimCounts(rank);
- for (int64_t r = 0; r < rank; ++r)
- sliceDimCounts[r] = ceilDiv(shape[r], sizes[r]);
- // Use that to compute the slice stride for each dimension.
- SmallVector<int64_t, 4> sliceStrides(rank);
- sliceStrides[rank - 1] = 1;
- for (int64_t r = rank - 2; r >= 0; --r)
- sliceStrides[r] = sliceStrides[r + 1] * sliceDimCounts[r + 1];
- return sliceStrides;
-}
-
-SmallVector<int64_t, 4> mlir::computeElementOffsetsFromVectorSliceOffsets(
- ArrayRef<int64_t> sizes, ArrayRef<int64_t> vectorOffsets) {
- SmallVector<int64_t, 4> result;
- for (auto it : llvm::zip(vectorOffsets, sizes))
- result.push_back(std::get<0>(it) * std::get<1>(it));
- return result;
-}
-
-Optional<SmallVector<int64_t, 4>> mlir::shapeRatio(ArrayRef<int64_t> superShape,
- ArrayRef<int64_t> subShape) {
- if (superShape.size() < subShape.size()) {
- return None;
- }
-
- // Starting from the end, compute the integer divisors.
- std::vector<int64_t> result;
- result.reserve(superShape.size());
- for (auto [superSize, subSize] :
- llvm::zip(llvm::reverse(superShape), llvm::reverse(subShape))) {
- assert(superSize > 0 && "superSize must be > 0");
- assert(subSize > 0 && "subSize must be > 0");
-
- // If integral division does not occur, return and let the caller decide.
- if (superSize % subSize != 0)
- return None;
- result.push_back(superSize / subSize);
- }
-
- // At this point we computed the ratio (in reverse) for the common
- // size. Fill with the remaining entries from the super-vector shape (still in
- // reverse).
- int commonSize = subShape.size();
- std::copy(superShape.rbegin() + commonSize, superShape.rend(),
- std::back_inserter(result));
-
- assert(result.size() == superShape.size() &&
- "super to sub shape ratio is not of the same size as the super rank");
-
- // Reverse again to get it back in the proper order and return.
- return SmallVector<int64_t, 4>{result.rbegin(), result.rend()};
-}
-
-Optional<SmallVector<int64_t, 4>> mlir::shapeRatio(VectorType superVectorType,
- VectorType subVectorType) {
- assert(superVectorType.getElementType() == subVectorType.getElementType() &&
- "vector types must be of the same elemental type");
- return shapeRatio(superVectorType.getShape(), subVectorType.getShape());
-}
-
/// Constructs a permutation map from memref indices to vector dimension.
///
/// The implementation uses the knowledge of the mapping of enclosing loop to
@@ -144,8 +72,8 @@ static AffineMap makePermutationMap(
return AffineMap();
MLIRContext *context =
enclosingLoopToVectorDim.begin()->getFirst()->getContext();
- SmallVector<AffineExpr, 4> perm(enclosingLoopToVectorDim.size(),
- getAffineConstantExpr(0, context));
+ SmallVector<AffineExpr> perm(enclosingLoopToVectorDim.size(),
+ getAffineConstantExpr(0, context));
for (auto kvp : enclosingLoopToVectorDim) {
assert(kvp.second < perm.size());
@@ -252,7 +180,8 @@ bool matcher::operatesOnSuperVectorsOf(Operation &op,
}
// Get the ratio.
- auto ratio = shapeRatio(superVectorType, subVectorType);
+ auto ratio =
+ computeShapeRatio(superVectorType.getShape(), subVectorType.getShape());
// Sanity check.
assert((ratio || !mustDivide) &&
diff --git a/mlir/test/lib/Dialect/Affine/TestVectorizationUtils.cpp b/mlir/test/lib/Dialect/Affine/TestVectorizationUtils.cpp
index 86cb5cea8a63f..61428bbf7091f 100644
--- a/mlir/test/lib/Dialect/Affine/TestVectorizationUtils.cpp
+++ b/mlir/test/lib/Dialect/Affine/TestVectorizationUtils.cpp
@@ -17,6 +17,7 @@
#include "mlir/Dialect/Affine/LoopUtils.h"
#include "mlir/Dialect/Affine/Utils.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Dialect/Vector/Utils/VectorUtils.h"
#include "mlir/IR/Builders.h"
@@ -126,7 +127,8 @@ void VectorizerTestPass::testVectorShapeRatio(llvm::raw_ostream &outs) {
// purpose of this test. If we need to test more intricate behavior in the
// future we can always extend.
auto superVectorType = opInst->getResult(0).getType().cast<VectorType>();
- auto ratio = shapeRatio(superVectorType, subVectorType);
+ auto ratio =
+ computeShapeRatio(superVectorType.getShape(), subVectorType.getShape());
if (!ratio) {
opInst->emitRemark("NOT MATCHED");
} else {
diff --git a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
index 6b9afe3b42e71..1bd40e7cde7e3 100644
--- a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
+++ b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
@@ -72,11 +72,11 @@ struct TestVectorToVectorLowering
private:
// Return the target shape based on op type.
- static Optional<SmallVector<int64_t, 4>> getShape(Operation *op) {
+ static Optional<SmallVector<int64_t>> getShape(Operation *op) {
if (isa<arith::AddFOp, arith::SelectOp, arith::CmpFOp>(op))
- return SmallVector<int64_t, 4>(2, 2);
+ return SmallVector<int64_t>(2, 2);
if (isa<vector::ContractionOp>(op))
- return SmallVector<int64_t, 4>(3, 2);
+ return SmallVector<int64_t>(3, 2);
// For transfer ops, just propagate the shape coming from
// InsertStridedSlices/ExtractStridedSlices.
if (auto readOp = dyn_cast<vector::TransferReadOp>(op)) {
@@ -90,15 +90,15 @@ struct TestVectorToVectorLowering
return llvm::None;
dstVec = vecType;
}
- return SmallVector<int64_t, 4>(dstVec.getShape().begin(),
- dstVec.getShape().end());
+ return SmallVector<int64_t>(dstVec.getShape().begin(),
+ dstVec.getShape().end());
}
if (auto writeOp = dyn_cast<vector::TransferWriteOp>(op)) {
auto insert = writeOp.getVector().getDefiningOp<InsertStridedSliceOp>();
if (!insert)
return llvm::None;
ArrayRef<int64_t> shape = insert.getSourceVectorType().getShape();
- return SmallVector<int64_t, 4>(shape.begin(), shape.end());
+ return SmallVector<int64_t>(shape.begin(), shape.end());
}
return llvm::None;
}
@@ -314,10 +314,10 @@ struct TestVectorUnrollingPatterns
if (unrollBasedOnType) {
UnrollVectorOptions::NativeShapeFnType nativeShapeFn =
- [](Operation *op) -> Optional<SmallVector<int64_t, 4>> {
+ [](Operation *op) -> Optional<SmallVector<int64_t>> {
vector::ContractionOp contractOp = cast<vector::ContractionOp>(op);
- SmallVector<int64_t, 4> nativeShape(
- contractOp.getIteratorTypes().size(), 4);
+ SmallVector<int64_t> nativeShape(contractOp.getIteratorTypes().size(),
+ 4);
Type lhsType = contractOp.getLhsType().getElementType();
nativeShape[nativeShape.size() - 1] = lhsType.isF16() ? 4 : 2;
return nativeShape;
@@ -339,12 +339,11 @@ struct TestVectorUnrollingPatterns
}
populateVectorUnrollPatterns(patterns, opts);
} else {
- auto nativeShapeFn =
- [](Operation *op) -> Optional<SmallVector<int64_t, 4>> {
+ auto nativeShapeFn = [](Operation *op) -> Optional<SmallVector<int64_t>> {
auto contractOp = dyn_cast<ContractionOp>(op);
if (!contractOp)
return None;
- return SmallVector<int64_t, 4>(contractOp.getIteratorTypes().size(), 2);
+ return SmallVector<int64_t>(contractOp.getIteratorTypes().size(), 2);
};
populateVectorUnrollPatterns(patterns,
UnrollVectorOptions()
diff --git a/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterface.cpp b/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterface.cpp
index eabc46ccda2b5..290c9c5777611 100644
--- a/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterface.cpp
+++ b/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterface.cpp
@@ -61,7 +61,6 @@ struct LinalgTransformationFilter {
LogicalResult checkAndNotify(PatternRewriter &rewriter, Operation *op) const;
void replaceLinalgTransformationFilter(PatternRewriter &rewriter,
Operation *op) const;
- bool hasReplacementFilter(Operation *op) const;
LinalgTransformationFilter &addFilter(const FilterFunction &f) {
if (f)
@@ -100,15 +99,6 @@ LinalgTransformationFilter::LinalgTransformationFilter(
: matchDisjunction(matchDisjunction.begin(), matchDisjunction.end()),
replacement(replacement), matchByDefault(false) {}
-LinalgTransformationFilter::LinalgTransformationFilter(
- const FilterFunction &f, ArrayRef<StringAttr> matchDisjunction,
- Optional<StringAttr> replacement)
- : matchDisjunction(matchDisjunction.begin(), matchDisjunction.end()),
- replacement(replacement), matchByDefault(false) {
- if (f)
- filters.push_back(f);
-}
-
LogicalResult
LinalgTransformationFilter::checkAndNotify(PatternRewriter &rewriter,
Operation *op) const {
@@ -150,13 +140,6 @@ void LinalgTransformationFilter::replaceLinalgTransformationFilter(
op->removeAttr(rewriter.getStringAttr(kLinalgTransformMarker));
}
-bool LinalgTransformationFilter::hasReplacementFilter(Operation *op) const {
- if (!replacement)
- return false;
- auto attr = op->getAttr(kLinalgTransformMarker).dyn_cast<StringAttr>();
- return attr && attr == *replacement;
-}
-
/// Pattern for testing `TileUsingSCFForOp` pattern (that tiles operations using
/// the `TilingInterface` with `scf.for` ops for iterating over the tiles) while
/// using a `filter` to avoid recursive application.