[Mlir-commits] [mlir] [mlir][vector] Improve shape_cast lowering (PR #140800)

James Newling llvmlistbot at llvm.org
Tue May 20 13:54:00 PDT 2025


https://github.com/newling created https://github.com/llvm/llvm-project/pull/140800

Before this PR, vector.shape_casts with rank>1 source and result were lowered to _elementwise_ extracts/inserts, so that a shape_cast on a vector with N elements would always require N extracts/inserts. While this is necessary in the worst-case scenario, it is sometimes possible to use fewer, larger extracts/inserts. Specifically, slices whose shape is the largest common suffix of the source and result shapes can be extracted/inserted as a unit. For example:
```mlir
%0 = vector.shape_cast %arg0 : vector<10x2x3xf32> to vector<2x5x2x3xf32>
```
Before this PR, this op would be lowered to 60 extract/insert pairs with extracts of the form 
`vector.extract %arg0 [a, b, c] : f32 from vector<10x2x3xf32>` 
but with this PR it is 10 extract/insert pairs with extracts of the form 
`vector.extract %arg0 [a] : vector<2x3xf32> from vector<10x2x3xf32>`. 

This case was first mentioned here: https://github.com/llvm/llvm-project/pull/138777#issuecomment-2874151059

>From ff92faaf649ddc93e422da2e32a8590149f9f319 Mon Sep 17 00:00:00 2001
From: James Newling <james.newling at gmail.com>
Date: Tue, 20 May 2025 13:07:23 -0700
Subject: [PATCH] extract as large a chunk as possible in shape_cast lowering

---
 .../Transforms/LowerVectorShapeCast.cpp       | 80 +++++++++++--------
 ...vector-shape-cast-lowering-transforms.mlir | 51 ++++++++++++
 2 files changed, 99 insertions(+), 32 deletions(-)

diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorShapeCast.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorShapeCast.cpp
index 23324a007377e..d0085bffca23c 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorShapeCast.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorShapeCast.cpp
@@ -28,17 +28,20 @@ using namespace mlir;
 using namespace mlir::vector;
 
 /// Increments n-D `indices` by `step` starting from the innermost dimension.
-static void incIdx(SmallVectorImpl<int64_t> &indices, VectorType vecType,
+static void incIdx(MutableArrayRef<int64_t> indices, ArrayRef<int64_t> shape,
                    int step = 1) {
   for (int dim : llvm::reverse(llvm::seq<int>(0, indices.size()))) {
-    assert(indices[dim] < vecType.getDimSize(dim) &&
-           "Indices are out of bound");
+    int64_t dimSize = shape[dim];
+    assert(indices[dim] < dimSize && "Indices are out of bound");
+
     indices[dim] += step;
-    if (indices[dim] < vecType.getDimSize(dim))
+
+    int64_t spill = indices[dim] / dimSize;
+    if (spill == 0)
       break;
 
-    indices[dim] = 0;
-    step = 1;
+    indices[dim] %= dimSize;
+    step = spill;
   }
 }
 
@@ -79,8 +82,8 @@ class ShapeCastOpNDDownCastRewritePattern
     // and destination slice insertion and generate such instructions.
     for (int64_t i = 0; i < numElts; ++i) {
       if (i != 0) {
-        incIdx(srcIdx, sourceVectorType, /*step=*/1);
-        incIdx(resIdx, resultVectorType, /*step=*/extractSize);
+        incIdx(srcIdx, sourceVectorType.getShape(), /*step=*/1);
+        incIdx(resIdx, resultVectorType.getShape(), /*step=*/extractSize);
       }
 
       Value extract =
@@ -131,8 +134,8 @@ class ShapeCastOpNDUpCastRewritePattern
     Value result = rewriter.create<ub::PoisonOp>(loc, resultVectorType);
     for (int64_t i = 0; i < numElts; ++i) {
       if (i != 0) {
-        incIdx(srcIdx, sourceVectorType, /*step=*/extractSize);
-        incIdx(resIdx, resultVectorType, /*step=*/1);
+        incIdx(srcIdx, sourceVectorType.getShape(), /*step=*/extractSize);
+        incIdx(resIdx, resultVectorType.getShape(), /*step=*/1);
       }
 
       Value extract = rewriter.create<vector::ExtractStridedSliceOp>(
@@ -157,41 +160,54 @@ class ShapeCastOpRewritePattern : public OpRewritePattern<vector::ShapeCastOp> {
   LogicalResult matchAndRewrite(vector::ShapeCastOp op,
                                 PatternRewriter &rewriter) const override {
     Location loc = op.getLoc();
-    auto sourceVectorType = op.getSourceVectorType();
-    auto resultVectorType = op.getResultVectorType();
+    VectorType sourceType = op.getSourceVectorType();
+    VectorType resultType = op.getResultVectorType();
 
-    if (sourceVectorType.isScalable() || resultVectorType.isScalable())
+    if (sourceType.isScalable() || resultType.isScalable())
       return failure();
 
-    // Special case for n-D / 1-D lowerings with better implementations.
-    int64_t srcRank = sourceVectorType.getRank();
-    int64_t resRank = resultVectorType.getRank();
-    if ((srcRank > 1 && resRank == 1) || (srcRank == 1 && resRank > 1))
+    // Special case for n-D / 1-D lowerings with implementations that use
+    // extract_strided_slice / insert_strided_slice.
+    int64_t sourceRank = sourceType.getRank();
+    int64_t resultRank = resultType.getRank();
+    if ((sourceRank > 1 && resultRank == 1) ||
+        (sourceRank == 1 && resultRank > 1))
       return failure();
 
-    // Generic ShapeCast lowering path goes all the way down to unrolled scalar
-    // extract/insert chains.
-    int64_t numElts = 1;
-    for (int64_t r = 0; r < srcRank; r++)
-      numElts *= sourceVectorType.getDimSize(r);
+    int64_t numExtracts = sourceType.getNumElements();
+    int64_t nbCommonInnerDims = 0;
+    while (true) {
+      int64_t sourceDim = sourceRank - 1 - nbCommonInnerDims;
+      int64_t resultDim = resultRank - 1 - nbCommonInnerDims;
+      if (sourceDim < 0 || resultDim < 0)
+        break;
+      int64_t dimSize = sourceType.getDimSize(sourceDim);
+      if (dimSize != resultType.getDimSize(resultDim))
+        break;
+      numExtracts /= dimSize;
+      ++nbCommonInnerDims;
+    }
+
     // Replace with data movement operations:
     //    x[0,0,0] = y[0,0]
     //    x[0,0,1] = y[0,1]
     //    x[0,1,0] = y[0,2]
     // etc., incrementing the two index vectors "row-major"
     // within the source and result shape.
-    SmallVector<int64_t> srcIdx(srcRank, 0);
-    SmallVector<int64_t> resIdx(resRank, 0);
-    Value result = rewriter.create<ub::PoisonOp>(loc, resultVectorType);
-    for (int64_t i = 0; i < numElts; i++) {
+    SmallVector<int64_t> sourceIndex(sourceRank - nbCommonInnerDims, 0);
+    SmallVector<int64_t> resultIndex(resultRank - nbCommonInnerDims, 0);
+    Value result = rewriter.create<ub::PoisonOp>(loc, resultType);
+
+    for (int64_t i = 0; i < numExtracts; i++) {
       if (i != 0) {
-        incIdx(srcIdx, sourceVectorType);
-        incIdx(resIdx, resultVectorType);
+        incIdx(sourceIndex, sourceType.getShape().drop_back(nbCommonInnerDims));
+        incIdx(resultIndex, resultType.getShape().drop_back(nbCommonInnerDims));
       }
 
       Value extract =
-          rewriter.create<vector::ExtractOp>(loc, op.getSource(), srcIdx);
-      result = rewriter.create<vector::InsertOp>(loc, extract, result, resIdx);
+          rewriter.create<vector::ExtractOp>(loc, op.getSource(), sourceIndex);
+      result =
+          rewriter.create<vector::InsertOp>(loc, extract, result, resultIndex);
     }
     rewriter.replaceOp(op, result);
     return success();
@@ -329,8 +345,8 @@ class ScalableShapeCastOpRewritePattern
 
       // 4. Increment the insert/extract indices, stepping by minExtractionSize
       // for the trailing dimensions.
-      incIdx(srcIdx, sourceVectorType, /*step=*/minExtractionSize);
-      incIdx(resIdx, resultVectorType, /*step=*/minExtractionSize);
+      incIdx(srcIdx, sourceVectorType.getShape(), /*step=*/minExtractionSize);
+      incIdx(resIdx, resultVectorType.getShape(), /*step=*/minExtractionSize);
     }
 
     rewriter.replaceOp(op, result);
diff --git a/mlir/test/Dialect/Vector/vector-shape-cast-lowering-transforms.mlir b/mlir/test/Dialect/Vector/vector-shape-cast-lowering-transforms.mlir
index ef32f8c6a1cdb..044a73d71e665 100644
--- a/mlir/test/Dialect/Vector/vector-shape-cast-lowering-transforms.mlir
+++ b/mlir/test/Dialect/Vector/vector-shape-cast-lowering-transforms.mlir
@@ -140,6 +140,57 @@ func.func @shape_cast_1d0d(%arg0 : vector<1xf32>) -> vector<f32> {
   return %s : vector<f32>
 }
 
+
+// CHECK-LABEL:  func.func @shape_cast_squeeze_leading_one(
+// CHECK-SAME:               %[[ARG0:.*]]: vector<1x2x3xf32>) -> vector<2x3xf32> {
+// CHECK:          %[[EXTRACTED:.*]] = vector.extract %[[ARG0]][0] :
+// CHECK-SAME:               vector<2x3xf32> from vector<1x2x3xf32>
+// CHECK:          return %[[EXTRACTED]] : vector<2x3xf32>
+func.func @shape_cast_squeeze_leading_one(%arg0 : vector<1x2x3xf32>) -> vector<2x3xf32> {
+  %s = vector.shape_cast %arg0 : vector<1x2x3xf32> to vector<2x3xf32>
+  return %s : vector<2x3xf32>
+}
+
+// CHECK-LABEL:  func.func @shape_cast_squeeze_middle_one(
+// CHECK-SAME:               %[[ARG0:.*]]: vector<2x1x3xf32>) -> vector<2x3xf32> {
+// CHECK:         %[[UB:.*]] = ub.poison : vector<2x3xf32>
+// CHECK:         %[[E0:.*]] = vector.extract %[[ARG0]][0, 0] : vector<3xf32>
+// CHECK:         %[[I0:.*]] = vector.insert %[[E0]], %[[UB]] [0] : vector<3xf32>
+// CHECK-SAME:                 into vector<2x3xf32>
+// CHECK:         %[[E1:.*]] = vector.extract %[[ARG0]][1, 0] : vector<3xf32>
+// CHECK:         %[[I1:.*]] = vector.insert %[[E1]], %[[I0]] [1] : vector<3xf32>
+// CHECK-SAME:                 into vector<2x3xf32>
+// CHECK:         return %[[I1]] : vector<2x3xf32>
+func.func @shape_cast_squeeze_middle_one(%arg0 : vector<2x1x3xf32>) -> vector<2x3xf32> {
+  %s = vector.shape_cast %arg0 : vector<2x1x3xf32> to vector<2x3xf32>
+  return %s : vector<2x3xf32>
+}
+
+// CHECK-LABEL:  func.func @shape_cast_unsqueeze_leading_one(
+// CHECK-SAME:               %[[ARG0:.*]]: vector<2x3xf32>) -> vector<1x2x3xf32> {
+// CHECK:          %[[UB:.*]] = ub.poison : vector<1x2x3xf32>
+// CHECK:          %[[INSERTED:.*]] = vector.insert %[[ARG0]], %[[UB]] [0]
+// CHECK-SAME:               : vector<2x3xf32> into vector<1x2x3xf32>
+// CHECK:        return %[[INSERTED]] : vector<1x2x3xf32>
+func.func @shape_cast_unsqueeze_leading_one(%arg0 : vector<2x3xf32>) -> vector<1x2x3xf32> {
+  %s = vector.shape_cast %arg0 : vector<2x3xf32> to vector<1x2x3xf32>
+  return %s : vector<1x2x3xf32>
+}
+
+// CHECK-LABEL:  func.func @shape_cast_unsqueeze_middle_one(
+// CHECK-SAME:               %[[ARG0:.*]]: vector<2x3xf32>) -> vector<2x1x3xf32> {
+// CHECK:           %[[UB:.*]] = ub.poison : vector<2x1x3xf32>
+// CHECK:           %[[E0:.*]] = vector.extract %[[ARG0]][0] : vector<3xf32>
+// CHECK:           %[[I0:.*]] = vector.insert %[[E0]], %[[UB]] [0, 0] : vector<3xf32>
+// CHECK:           %[[E1:.*]] = vector.extract %[[ARG0]][1] : vector<3xf32>
+// CHECK:           %[[I1:.*]] = vector.insert %[[E1]], %[[I0]] [1, 0] : vector<3xf32>
+// CHECK:           return %[[I1]] : vector<2x1x3xf32>
+func.func @shape_cast_unsqueeze_middle_one(%arg0 : vector<2x3xf32>) -> vector<2x1x3xf32> {
+  %s = vector.shape_cast %arg0 : vector<2x3xf32> to vector<2x1x3xf32>
+  return %s : vector<2x1x3xf32>
+}
+
+
 module attributes {transform.with_named_sequence} {
   transform.named_sequence @__transform_main(%module_op: !transform.any_op {transform.readonly}) {
     %f = transform.structured.match ops{["func.func"]} in %module_op



More information about the Mlir-commits mailing list