[Mlir-commits] [mlir] [mlir][vector] Improve shape_cast lowering (PR #140800)

James Newling llvmlistbot at llvm.org
Wed Jun 4 15:50:42 PDT 2025


================
@@ -21,177 +21,259 @@
 #include "mlir/IR/Location.h"
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/IR/TypeUtilities.h"
+#include <numeric>
 
 #define DEBUG_TYPE "vector-shape-cast-lowering"
 
 using namespace mlir;
-using namespace mlir::vector;
-
-/// Increments n-D `indices` by `step` starting from the innermost dimension.
-static void incIdx(SmallVectorImpl<int64_t> &indices, VectorType vecType,
-                   int step = 1) {
-  for (int dim : llvm::reverse(llvm::seq<int>(0, indices.size()))) {
-    assert(indices[dim] < vecType.getDimSize(dim) &&
-           "Indices are out of bound");
-    indices[dim] += step;
-    if (indices[dim] < vecType.getDimSize(dim))
-      break;
 
-    indices[dim] = 0;
-    step = 1;
+/// Perform the inplace update
+///    rhs <- lhs + rhs
+///
+/// where `rhs` is a number expressed in mixed base `base` with most significant
+/// dimensions on the left. For example if `rhs` is {a,b,c} and `base` is
+/// {5,3,2} then `rhs` has value a*3*2 + b*2 + c.
+///
+/// Some examples where `base` is {5,3,2}:
+/// rhs = {0,0,0}, lhs = 1  --> rhs = {0,0,1}
+/// rhs = {0,0,1}, lhs = 1  --> rhs = {0,1,0}
+/// rhs = {0,0,0}, lhs = 25 --> rhs = {4,0,1}
+///
+/// Invalid:
+/// rhs = {0,0,2}, lhs = 1 : rhs not in base {5,3,2}
+///
+/// Overflows not handled correctly:
+/// rhs = {4,2,1}, lhs = 2 --> rhs = {0,0,0} (not {0,0,1})
+static void inplaceAdd(int64_t lhs, ArrayRef<int64_t> base,
+                       MutableArrayRef<int64_t> rhs) {
+
+  // For dimensions in [rhs.size() - 1, ..., 3, 2, 1, 0]:
+  for (int dim : llvm::reverse(llvm::seq<int>(0, rhs.size()))) {
+    int64_t dimBase = base[dim];
+    assert(rhs[dim] < dimBase && "rhs not in base");
+
+    int64_t incremented = rhs[dim] + lhs;
+
+    // If the incremented value exceeds the dimension base, we must spill to the
+    // next most significant dimension and repeat (we might need to spill to
+    // more significant dimensions multiple times).
+    lhs = incremented / dimBase;
+    rhs[dim] = incremented % dimBase;
+    if (lhs == 0)
+      break;
   }
 }
 
 namespace {
-/// ShapeOp n-D -> 1-D downcast serves the purpose of flattening N-D to 1-D
-/// vectors progressively. This iterates over the n-1 major dimensions of the
-/// n-D vector and performs rewrites into:
-///   vector.extract from n-D + vector.insert_strided_slice offset into 1-D
-class ShapeCastOpNDDownCastRewritePattern
-    : public OpRewritePattern<vector::ShapeCastOp> {
+
+/// shape_cast is converted to a sequence of extract, extract_strided_slice,
+/// insert_strided_slice, and insert operations. The running example will be:
+///
+/// %0 = vector.shape_cast %arg0 :
+///         vector<2x2x3x4x7x11xi8> to vector<8x6x7x11xi8>
+///
+/// In this example the source and result shapes share a common suffix of 7x11.
+/// This means we can always decompose the shape_cast into extract, insert, and
+/// their strided equivalents, on vectors with shape suffix 7x11.
+///
+/// The greatest common divisor (gcd) of the first dimensions preceding the
+/// common suffix is gcd(4,6) = 2. The algorithm implemented here will operate
+/// on vectors with shapes that are `multiples` of (what we define as) the
+/// 'atomic shape', 2x7x11. The atomic shape is `gcd` x `common-suffix`.
+///
+///         vector<2x2x3x4x7x11xi8> to
+///             vector<8x6x7x11xi8>
+///                      | ||||
+///                      | ++++------------> common suffix of 7x11
+///                      +----------------->    gcd(4,6) is 2 | |
+///                                                         | | |
+///                                                         v v v
+///                                 atomic shape   <-----   2x7x11
+///
+///
+///
+/// The decomposition implemented in this pattern consists of a sequence of
+/// repeated steps:
+///
+///  (1) Extract vectors from the suffix of the source.
+///      In our example this is 2x2x3x4x7x11 -> 4x7x11.
+///
+///  (2) Do extract_strided_slice down to the atomic shape.
+///      In our example this is 4x7x11 -> 2x7x11.
+///
+///  (3) Do insert_strided_slice to the suffix of the result.
+///      In our example this is 2x7x11 -> 6x7x11.
+///
+///  (4) insert these vectors into the result vector.
+///      In our example this is 6x7x11 -> 8x6x7x11.
+///
+/// These steps occur with different periods. In this example
+///  (1) occurs 12 times,
+///  (2) and (3) occur 24 times, and
+///  (4) occurs 8 times.
+///
+/// Two special cases are handled independently in this pattern
+///  (i) A shape_cast that just does leading 1 insertion/removal
+/// (ii) A shape_cast where the gcd is 1.
+///
+/// These 2 cases can have more compact IR generated by not using the generic
+/// algorithm described above.
+///
+class ShapeCastOpRewritePattern : public OpRewritePattern<vector::ShapeCastOp> {
 public:
   using OpRewritePattern::OpRewritePattern;
 
   LogicalResult matchAndRewrite(vector::ShapeCastOp op,
                                 PatternRewriter &rewriter) const override {
-    auto sourceVectorType = op.getSourceVectorType();
-    auto resultVectorType = op.getResultVectorType();
-    if (sourceVectorType.isScalable() || resultVectorType.isScalable())
-      return failure();
-
-    int64_t srcRank = sourceVectorType.getRank();
-    int64_t resRank = resultVectorType.getRank();
-    if (srcRank < 2 || resRank != 1)
-      return failure();
-
-    // Compute the number of 1-D vector elements involved in the reshape.
-    int64_t numElts = 1;
-    for (int64_t dim = 0; dim < srcRank - 1; ++dim)
-      numElts *= sourceVectorType.getDimSize(dim);
-
-    auto loc = op.getLoc();
-    SmallVector<int64_t> srcIdx(srcRank - 1, 0);
-    SmallVector<int64_t> resIdx(resRank, 0);
-    int64_t extractSize = sourceVectorType.getShape().back();
-    Value result = rewriter.create<ub::PoisonOp>(loc, resultVectorType);
-
-    // Compute the indices of each 1-D vector element of the source extraction
-    // and destination slice insertion and generate such instructions.
-    for (int64_t i = 0; i < numElts; ++i) {
-      if (i != 0) {
-        incIdx(srcIdx, sourceVectorType, /*step=*/1);
-        incIdx(resIdx, resultVectorType, /*step=*/extractSize);
-      }
-
-      Value extract =
-          rewriter.create<vector::ExtractOp>(loc, op.getSource(), srcIdx);
-      result = rewriter.create<vector::InsertStridedSliceOp>(
-          loc, extract, result,
-          /*offsets=*/resIdx, /*strides=*/1);
+    Location loc = op.getLoc();
+    VectorType sourceType = op.getSourceVectorType();
+    VectorType resultType = op.getResultVectorType();
+
+    if (sourceType.isScalable() || resultType.isScalable())
+      return rewriter.notifyMatchFailure(
+          op,
+          "shape_cast where vectors are scalable not handled by this pattern");
+
+    const ArrayRef<int64_t> sourceShape = sourceType.getShape();
+    const ArrayRef<int64_t> resultShape = resultType.getShape();
+    const int64_t sourceRank = sourceType.getRank();
+    const int64_t resultRank = resultType.getRank();
+    const int64_t numElms = sourceType.getNumElements();
+    const Value source = op.getSource();
+
+    // Set the first dimension (starting at the end) in the source and result
+    // respectively where the dimension sizes differ. Using the running example:
+    //
+    //  dimensions:  [0 1 2 3 4 5 ]    [0 1 2 3 ]
+    //  shapes:      (2,2,3,4,7,11) -> (8,6,7,11)
+    //                      ^             ^
+    //                      |             |
+    //        sourceSuffixStartDim is 3   |
+    //                                    |
+    //                               resultSuffixStartDim is 1
+    int64_t sourceSuffixStartDim = sourceRank - 1;
+    int64_t resultSuffixStartDim = resultRank - 1;
+    while (sourceSuffixStartDim >= 0 && resultSuffixStartDim >= 0 &&
+           (sourceType.getDimSize(sourceSuffixStartDim) ==
+            resultType.getDimSize(resultSuffixStartDim))) {
+      --sourceSuffixStartDim;
+      --resultSuffixStartDim;
     }
 
-    rewriter.replaceOp(op, result);
-    return success();
-  }
-};
-
-/// ShapeOp 1-D -> n-D upcast serves the purpose of unflattening n-D from 1-D
-/// vectors progressively. This iterates over the n-1 major dimension of the n-D
-/// vector and performs rewrites into:
-///   vector.extract_strided_slice from 1-D + vector.insert into n-D
-/// Note that 1-D extract_strided_slice are lowered to efficient vector.shuffle.
-class ShapeCastOpNDUpCastRewritePattern
-    : public OpRewritePattern<vector::ShapeCastOp> {
-public:
-  using OpRewritePattern::OpRewritePattern;
+    // This is the case (i) where there are just some leading ones to contend
+    // with in the source or result. It can be handled with a single
+    // extract/insert pair.
+    if (resultSuffixStartDim < 0 || sourceSuffixStartDim < 0) {
+      const int64_t delta = sourceRank - resultRank;
+      const int64_t sourceLeading = delta > 0 ? delta : 0;
+      const int64_t resultLeading = delta > 0 ? 0 : -delta;
+      const Value poison = rewriter.create<ub::PoisonOp>(loc, resultType);
+      const Value extracted = rewriter.create<vector::ExtractOp>(
+          loc, source, SmallVector<int64_t>(sourceLeading, 0));
+      const Value result = rewriter.create<vector::InsertOp>(
+          loc, extracted, poison, SmallVector<int64_t>(resultLeading, 0));
+      rewriter.replaceOp(op, result);
+      return success();
+    }
 
-  LogicalResult matchAndRewrite(vector::ShapeCastOp op,
-                                PatternRewriter &rewriter) const override {
-    auto sourceVectorType = op.getSourceVectorType();
-    auto resultVectorType = op.getResultVectorType();
-    if (sourceVectorType.isScalable() || resultVectorType.isScalable())
-      return failure();
-
-    int64_t srcRank = sourceVectorType.getRank();
-    int64_t resRank = resultVectorType.getRank();
-    if (srcRank != 1 || resRank < 2)
-      return failure();
-
-    // Compute the number of 1-D vector elements involved in the reshape.
-    int64_t numElts = 1;
-    for (int64_t dim = 0; dim < resRank - 1; ++dim)
-      numElts *= resultVectorType.getDimSize(dim);
-
-    // Compute the indices of each 1-D vector element of the source slice
-    // extraction and destination insertion and generate such instructions.
-    auto loc = op.getLoc();
-    SmallVector<int64_t> srcIdx(srcRank, 0);
-    SmallVector<int64_t> resIdx(resRank - 1, 0);
-    int64_t extractSize = resultVectorType.getShape().back();
-    Value result = rewriter.create<ub::PoisonOp>(loc, resultVectorType);
-    for (int64_t i = 0; i < numElts; ++i) {
-      if (i != 0) {
-        incIdx(srcIdx, sourceVectorType, /*step=*/extractSize);
-        incIdx(resIdx, resultVectorType, /*step=*/1);
+    const int64_t sourceSuffixStartDimSize =
+        sourceType.getDimSize(sourceSuffixStartDim);
+    const int64_t resultSuffixStartDimSize =
+        resultType.getDimSize(resultSuffixStartDim);
+    const int64_t greatestCommonDivisor =
+        std::gcd(sourceSuffixStartDimSize, resultSuffixStartDimSize);
+    const int64_t stridedSliceRank = sourceRank - sourceSuffixStartDim;
+    const size_t extractPeriod =
+        sourceSuffixStartDimSize / greatestCommonDivisor;
+    const size_t insertPeriod =
+        resultSuffixStartDimSize / greatestCommonDivisor;
+
+    SmallVector<int64_t> atomicShape(sourceShape.begin() + sourceSuffixStartDim,
+                                     sourceShape.end());
+    atomicShape[0] = greatestCommonDivisor;
+
+    const int64_t numAtomicElms = std::accumulate(
+        atomicShape.begin(), atomicShape.end(), 1, std::multiplies<int64_t>());
+    const size_t nAtomicSlices = numElms / numAtomicElms;
+
+    // This is the case (ii) where the strided dimension size is 1. More compact
+    // IR is generated in this case if we just extract and insert the elements
+    // directly. In other words, we don't use extract_strided_slice and
+    // insert_strided_slice.
+    if (greatestCommonDivisor == 1) {
----------------
newling wrote:

Yes, looks a bit better now IMO

https://github.com/llvm/llvm-project/pull/140800


More information about the Mlir-commits mailing list