[Mlir-commits] [mlir] [mlir][vector] Improve shape_cast lowering (PR #140800)
James Newling
llvmlistbot at llvm.org
Wed Jun 4 15:50:42 PDT 2025
================
@@ -21,177 +21,259 @@
#include "mlir/IR/Location.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/TypeUtilities.h"
+#include <numeric>
#define DEBUG_TYPE "vector-shape-cast-lowering"
using namespace mlir;
-using namespace mlir::vector;
-
-/// Increments n-D `indices` by `step` starting from the innermost dimension.
-static void incIdx(SmallVectorImpl<int64_t> &indices, VectorType vecType,
- int step = 1) {
- for (int dim : llvm::reverse(llvm::seq<int>(0, indices.size()))) {
- assert(indices[dim] < vecType.getDimSize(dim) &&
- "Indices are out of bound");
- indices[dim] += step;
- if (indices[dim] < vecType.getDimSize(dim))
- break;
- indices[dim] = 0;
- step = 1;
+/// Perform the inplace update
+/// rhs <- lhs + rhs
+///
+/// where `rhs` is a number expressed in mixed base `base` with most significant
+/// dimensions on the left. For example if `rhs` is {a,b,c} and `base` is
+/// {5,3,2} then `rhs` has value a*3*2 + b*2 + c.
+///
+/// Some examples where `base` is {5,3,2}:
+/// rhs = {0,0,0}, lhs = 1 --> rhs = {0,0,1}
+/// rhs = {0,0,1}, lhs = 1 --> rhs = {0,1,0}
+/// rhs = {0,0,0}, lhs = 25 --> rhs = {4,0,1}
+///
+/// Invalid:
+/// rhs = {0,0,2}, lhs = 1 : rhs not in base {5,3,2}
+///
+/// Overflows not handled correctly:
+/// rhs = {4,2,1}, lhs = 2 --> rhs = {0,0,0} (not {0,0,1})
+static void inplaceAdd(int64_t lhs, ArrayRef<int64_t> base,
+ MutableArrayRef<int64_t> rhs) {
+
+ // For dimensions in [numIndices - 1, ..., 3, 2, 1, 0]:
+ for (int dim : llvm::reverse(llvm::seq<int>(0, rhs.size()))) {
+ int64_t dimBase = base[dim];
+ assert(rhs[dim] < dimBase && "rhs not in base");
+
+ int64_t incremented = rhs[dim] + lhs;
+
+    // If the incremented value exceeds the dimension base, we must spill to the
+ // next most significant dimension and repeat (we might need to spill to
+ // more significant dimensions multiple times).
+ lhs = incremented / dimBase;
+ rhs[dim] = incremented % dimBase;
+ if (lhs == 0)
+ break;
}
}
namespace {
-/// ShapeOp n-D -> 1-D downcast serves the purpose of flattening N-D to 1-D
-/// vectors progressively. This iterates over the n-1 major dimensions of the
-/// n-D vector and performs rewrites into:
-/// vector.extract from n-D + vector.insert_strided_slice offset into 1-D
-class ShapeCastOpNDDownCastRewritePattern
- : public OpRewritePattern<vector::ShapeCastOp> {
+
+/// shape_cast is converted to a sequence of extract, extract_strided_slice,
+/// insert_strided_slice, and insert operations. The running example will be:
+///
+/// %0 = vector.shape_cast %arg0 :
+/// vector<2x2x3x4x7x11xi8> to vector<8x6x7x11xi8>
+///
+/// In this example the source and result shapes share a common suffix of 7x11.
+/// This means we can always decompose the shape_cast into extract, insert, and
+/// their strided equivalents, on vectors with shape suffix 7x11.
+///
+/// The greatest common divisor (gcd) of the first dimension preceding the
+/// common suffix is gcd(4,6) = 2. The algorithm implemented here will operate
+/// on vectors with shapes that are `multiples` of (what we define as) the
+/// 'atomic shape', 2x7x11. The atomic shape is `gcd` x `common-suffix`.
+///
+/// vector<2x2x3x4x7x11xi8> to
+/// vector<8x6x7x11xi8>
+/// | ||||
+/// | ++++------------> common suffix of 7x11
+/// +-----------------> gcd(4,6) is 2 | |
+/// | | |
+/// v v v
+/// atomic shape <----- 2x7x11
+///
+///
+///
+/// The decomposition implemented in this pattern consists of a sequence of
+/// repeated steps:
+///
+/// (1) Extract vectors from the suffix of the source.
+/// In our example this is 2x2x3x4x7x11 -> 4x7x11.
+///
+/// (2) Do extract_strided_slice down to the atomic shape.
+/// In our example this is 4x7x11 -> 2x7x11.
+///
+/// (3) Do insert_strided_slice to the suffix of the result.
+/// In our example this is 2x7x11 -> 6x7x11.
+///
+/// (4) insert these vectors into the result vector.
+/// In our example this is 6x7x11 -> 8x6x7x11.
+///
+/// These steps occur with different periods. In this example
+/// (1) occurs 12 times,
+/// (2) and (3) occur 24 times, and
+/// (4) occurs 8 times.
+///
+/// Two special cases are handled independently in this pattern
+/// (i) A shape_cast that just does leading 1 insertion/removal
+/// (ii) A shape_cast where the gcd is 1.
+///
+/// These 2 cases can have more compact IR generated by not using the generic
+/// algorithm described above.
+///
+class ShapeCastOpRewritePattern : public OpRewritePattern<vector::ShapeCastOp> {
public:
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(vector::ShapeCastOp op,
PatternRewriter &rewriter) const override {
- auto sourceVectorType = op.getSourceVectorType();
- auto resultVectorType = op.getResultVectorType();
- if (sourceVectorType.isScalable() || resultVectorType.isScalable())
- return failure();
-
- int64_t srcRank = sourceVectorType.getRank();
- int64_t resRank = resultVectorType.getRank();
- if (srcRank < 2 || resRank != 1)
- return failure();
-
- // Compute the number of 1-D vector elements involved in the reshape.
- int64_t numElts = 1;
- for (int64_t dim = 0; dim < srcRank - 1; ++dim)
- numElts *= sourceVectorType.getDimSize(dim);
-
- auto loc = op.getLoc();
- SmallVector<int64_t> srcIdx(srcRank - 1, 0);
- SmallVector<int64_t> resIdx(resRank, 0);
- int64_t extractSize = sourceVectorType.getShape().back();
- Value result = rewriter.create<ub::PoisonOp>(loc, resultVectorType);
-
- // Compute the indices of each 1-D vector element of the source extraction
- // and destination slice insertion and generate such instructions.
- for (int64_t i = 0; i < numElts; ++i) {
- if (i != 0) {
- incIdx(srcIdx, sourceVectorType, /*step=*/1);
- incIdx(resIdx, resultVectorType, /*step=*/extractSize);
- }
-
- Value extract =
- rewriter.create<vector::ExtractOp>(loc, op.getSource(), srcIdx);
- result = rewriter.create<vector::InsertStridedSliceOp>(
- loc, extract, result,
- /*offsets=*/resIdx, /*strides=*/1);
+ Location loc = op.getLoc();
+ VectorType sourceType = op.getSourceVectorType();
+ VectorType resultType = op.getResultVectorType();
+
+ if (sourceType.isScalable() || resultType.isScalable())
+ return rewriter.notifyMatchFailure(
+ op,
+ "shape_cast where vectors are scalable not handled by this pattern");
+
+ const ArrayRef<int64_t> sourceShape = sourceType.getShape();
+ const ArrayRef<int64_t> resultShape = resultType.getShape();
+ const int64_t sourceRank = sourceType.getRank();
+ const int64_t resultRank = resultType.getRank();
+ const int64_t numElms = sourceType.getNumElements();
+ const Value source = op.getSource();
+
+ // Set the first dimension (starting at the end) in the source and result
+ // respectively where the dimension sizes differ. Using the running example:
+ //
+ // dimensions: [0 1 2 3 4 5 ] [0 1 2 3 ]
+ // shapes: (2,2,3,4,7,11) -> (8,6,7,11)
+ // ^ ^
+ // | |
+ // sourceSuffixStartDim is 3 |
+ // |
+ // resultSuffixStartDim is 1
+ int64_t sourceSuffixStartDim = sourceRank - 1;
+ int64_t resultSuffixStartDim = resultRank - 1;
+ while (sourceSuffixStartDim >= 0 && resultSuffixStartDim >= 0 &&
+ (sourceType.getDimSize(sourceSuffixStartDim) ==
+ resultType.getDimSize(resultSuffixStartDim))) {
+ --sourceSuffixStartDim;
+ --resultSuffixStartDim;
}
- rewriter.replaceOp(op, result);
- return success();
- }
-};
-
-/// ShapeOp 1-D -> n-D upcast serves the purpose of unflattening n-D from 1-D
-/// vectors progressively. This iterates over the n-1 major dimension of the n-D
-/// vector and performs rewrites into:
-/// vector.extract_strided_slice from 1-D + vector.insert into n-D
-/// Note that 1-D extract_strided_slice are lowered to efficient vector.shuffle.
-class ShapeCastOpNDUpCastRewritePattern
- : public OpRewritePattern<vector::ShapeCastOp> {
-public:
- using OpRewritePattern::OpRewritePattern;
+ // This is the case (i) where there are just some leading ones to contend
+ // with in the source or result. It can be handled with a single
+ // extract/insert pair.
+ if (resultSuffixStartDim < 0 || sourceSuffixStartDim < 0) {
+ const int64_t delta = sourceRank - resultRank;
+ const int64_t sourceLeading = delta > 0 ? delta : 0;
+ const int64_t resultLeading = delta > 0 ? 0 : -delta;
+ const Value poison = rewriter.create<ub::PoisonOp>(loc, resultType);
+ const Value extracted = rewriter.create<vector::ExtractOp>(
+ loc, source, SmallVector<int64_t>(sourceLeading, 0));
+ const Value result = rewriter.create<vector::InsertOp>(
+ loc, extracted, poison, SmallVector<int64_t>(resultLeading, 0));
+ rewriter.replaceOp(op, result);
+ return success();
+ }
- LogicalResult matchAndRewrite(vector::ShapeCastOp op,
- PatternRewriter &rewriter) const override {
- auto sourceVectorType = op.getSourceVectorType();
- auto resultVectorType = op.getResultVectorType();
- if (sourceVectorType.isScalable() || resultVectorType.isScalable())
- return failure();
-
- int64_t srcRank = sourceVectorType.getRank();
- int64_t resRank = resultVectorType.getRank();
- if (srcRank != 1 || resRank < 2)
- return failure();
-
- // Compute the number of 1-D vector elements involved in the reshape.
- int64_t numElts = 1;
- for (int64_t dim = 0; dim < resRank - 1; ++dim)
- numElts *= resultVectorType.getDimSize(dim);
-
- // Compute the indices of each 1-D vector element of the source slice
- // extraction and destination insertion and generate such instructions.
- auto loc = op.getLoc();
- SmallVector<int64_t> srcIdx(srcRank, 0);
- SmallVector<int64_t> resIdx(resRank - 1, 0);
- int64_t extractSize = resultVectorType.getShape().back();
- Value result = rewriter.create<ub::PoisonOp>(loc, resultVectorType);
- for (int64_t i = 0; i < numElts; ++i) {
- if (i != 0) {
- incIdx(srcIdx, sourceVectorType, /*step=*/extractSize);
- incIdx(resIdx, resultVectorType, /*step=*/1);
+ const int64_t sourceSuffixStartDimSize =
+ sourceType.getDimSize(sourceSuffixStartDim);
+ const int64_t resultSuffixStartDimSize =
+ resultType.getDimSize(resultSuffixStartDim);
+ const int64_t greatestCommonDivisor =
+ std::gcd(sourceSuffixStartDimSize, resultSuffixStartDimSize);
+ const int64_t stridedSliceRank = sourceRank - sourceSuffixStartDim;
+ const size_t extractPeriod =
+ sourceSuffixStartDimSize / greatestCommonDivisor;
+ const size_t insertPeriod =
+ resultSuffixStartDimSize / greatestCommonDivisor;
+
+ SmallVector<int64_t> atomicShape(sourceShape.begin() + sourceSuffixStartDim,
+ sourceShape.end());
+ atomicShape[0] = greatestCommonDivisor;
+
+ const int64_t numAtomicElms = std::accumulate(
+ atomicShape.begin(), atomicShape.end(), 1, std::multiplies<int64_t>());
+ const size_t nAtomicSlices = numElms / numAtomicElms;
+
+ // This is the case (ii) where the strided dimension size is 1. More compact
+ // IR is generated in this case if we just extract and insert the elements
+ // directly. In other words, we don't use extract_strided_slice and
+ // insert_strided_slice.
+ if (greatestCommonDivisor == 1) {
----------------
newling wrote:
Yes, looks a bit better now IMO
https://github.com/llvm/llvm-project/pull/140800
More information about the Mlir-commits
mailing list