[Mlir-commits] [mlir] [mlir][vector] Improve shape_cast lowering (PR #140800)
James Newling
llvmlistbot at llvm.org
Fri May 23 09:59:22 PDT 2025
https://github.com/newling updated https://github.com/llvm/llvm-project/pull/140800
From 49ebfa731b9168cad0946b6e43229ae0ce064bd7 Mon Sep 17 00:00:00 2001
From: James Newling <james.newling at gmail.com>
Date: Tue, 20 May 2025 13:07:23 -0700
Subject: [PATCH 1/2] extract as large a chunk as possible in shape_cast
lowering
---
.../Transforms/LowerVectorShapeCast.cpp | 80 +++++++++++--------
...vector-shape-cast-lowering-transforms.mlir | 51 ++++++++++++
2 files changed, 99 insertions(+), 32 deletions(-)
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorShapeCast.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorShapeCast.cpp
index 23324a007377e..d0085bffca23c 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorShapeCast.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorShapeCast.cpp
@@ -28,17 +28,20 @@ using namespace mlir;
using namespace mlir::vector;
/// Increments n-D `indices` by `step` starting from the innermost dimension.
-static void incIdx(SmallVectorImpl<int64_t> &indices, VectorType vecType,
+static void incIdx(MutableArrayRef<int64_t> indices, ArrayRef<int64_t> shape,
int step = 1) {
for (int dim : llvm::reverse(llvm::seq<int>(0, indices.size()))) {
- assert(indices[dim] < vecType.getDimSize(dim) &&
- "Indices are out of bound");
+ int64_t dimSize = shape[dim];
+    assert(indices[dim] < dimSize && "Indices are out of bounds");
+
indices[dim] += step;
- if (indices[dim] < vecType.getDimSize(dim))
+
+ int64_t spill = indices[dim] / dimSize;
+ if (spill == 0)
break;
- indices[dim] = 0;
- step = 1;
+ indices[dim] %= dimSize;
+ step = spill;
}
}
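
A note on the reworked increment: it is a mixed-radix odometer that now
propagates a multi-step carry (the "spill") outward, rather than assuming
the overflow is always exactly one. A minimal standalone sketch of the same
logic, with hypothetical names and plain std::vector in place of the MLIR
types:

  #include <cassert>
  #include <cstdint>
  #include <vector>

  // Adds `step` to a row-major index vector over `shape`, carrying any
  // overflow into the next-outer dimension; e.g. over shape {2, 3},
  // {0, 2} + 1 -> {1, 0}.
  static void incrementIndex(std::vector<int64_t> &indices,
                             const std::vector<int64_t> &shape,
                             int64_t step = 1) {
    for (int64_t dim = static_cast<int64_t>(indices.size()) - 1; dim >= 0;
         --dim) {
      int64_t dimSize = shape[dim];
      assert(indices[dim] < dimSize && "index out of bounds");
      indices[dim] += step;
      int64_t spill = indices[dim] / dimSize;
      if (spill == 0)
        break;
      indices[dim] %= dimSize; // keep the in-dimension remainder
      step = spill;            // carry into the next-outer dimension
    }
  }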
@@ -79,8 +82,8 @@ class ShapeCastOpNDDownCastRewritePattern
// and destination slice insertion and generate such instructions.
for (int64_t i = 0; i < numElts; ++i) {
if (i != 0) {
- incIdx(srcIdx, sourceVectorType, /*step=*/1);
- incIdx(resIdx, resultVectorType, /*step=*/extractSize);
+ incIdx(srcIdx, sourceVectorType.getShape(), /*step=*/1);
+ incIdx(resIdx, resultVectorType.getShape(), /*step=*/extractSize);
}
Value extract =
@@ -131,8 +134,8 @@ class ShapeCastOpNDUpCastRewritePattern
Value result = rewriter.create<ub::PoisonOp>(loc, resultVectorType);
for (int64_t i = 0; i < numElts; ++i) {
if (i != 0) {
- incIdx(srcIdx, sourceVectorType, /*step=*/extractSize);
- incIdx(resIdx, resultVectorType, /*step=*/1);
+ incIdx(srcIdx, sourceVectorType.getShape(), /*step=*/extractSize);
+ incIdx(resIdx, resultVectorType.getShape(), /*step=*/1);
}
Value extract = rewriter.create<vector::ExtractStridedSliceOp>(
@@ -157,41 +160,54 @@ class ShapeCastOpRewritePattern : public OpRewritePattern<vector::ShapeCastOp> {
LogicalResult matchAndRewrite(vector::ShapeCastOp op,
PatternRewriter &rewriter) const override {
Location loc = op.getLoc();
- auto sourceVectorType = op.getSourceVectorType();
- auto resultVectorType = op.getResultVectorType();
+ VectorType sourceType = op.getSourceVectorType();
+ VectorType resultType = op.getResultVectorType();
- if (sourceVectorType.isScalable() || resultVectorType.isScalable())
+ if (sourceType.isScalable() || resultType.isScalable())
return failure();
- // Special case for n-D / 1-D lowerings with better implementations.
- int64_t srcRank = sourceVectorType.getRank();
- int64_t resRank = resultVectorType.getRank();
- if ((srcRank > 1 && resRank == 1) || (srcRank == 1 && resRank > 1))
+ // Special case for n-D / 1-D lowerings with implementations that use
+ // extract_strided_slice / insert_strided_slice.
+ int64_t sourceRank = sourceType.getRank();
+ int64_t resultRank = resultType.getRank();
+ if ((sourceRank > 1 && resultRank == 1) ||
+ (sourceRank == 1 && resultRank > 1))
return failure();
- // Generic ShapeCast lowering path goes all the way down to unrolled scalar
- // extract/insert chains.
- int64_t numElts = 1;
- for (int64_t r = 0; r < srcRank; r++)
- numElts *= sourceVectorType.getDimSize(r);
+ int64_t numExtracts = sourceType.getNumElements();
+ int64_t nbCommonInnerDims = 0;
+ while (true) {
+ int64_t sourceDim = sourceRank - 1 - nbCommonInnerDims;
+ int64_t resultDim = resultRank - 1 - nbCommonInnerDims;
+ if (sourceDim < 0 || resultDim < 0)
+ break;
+ int64_t dimSize = sourceType.getDimSize(sourceDim);
+ if (dimSize != resultType.getDimSize(resultDim))
+ break;
+ numExtracts /= dimSize;
+ ++nbCommonInnerDims;
+ }
+
// Replace with data movement operations:
// x[0,0,0] = y[0,0]
// x[0,0,1] = y[0,1]
// x[0,1,0] = y[0,2]
// etc., incrementing the two index vectors "row-major"
// within the source and result shape.
- SmallVector<int64_t> srcIdx(srcRank, 0);
- SmallVector<int64_t> resIdx(resRank, 0);
- Value result = rewriter.create<ub::PoisonOp>(loc, resultVectorType);
- for (int64_t i = 0; i < numElts; i++) {
+ SmallVector<int64_t> sourceIndex(sourceRank - nbCommonInnerDims, 0);
+ SmallVector<int64_t> resultIndex(resultRank - nbCommonInnerDims, 0);
+ Value result = rewriter.create<ub::PoisonOp>(loc, resultType);
+
+ for (int64_t i = 0; i < numExtracts; i++) {
if (i != 0) {
- incIdx(srcIdx, sourceVectorType);
- incIdx(resIdx, resultVectorType);
+ incIdx(sourceIndex, sourceType.getShape().drop_back(nbCommonInnerDims));
+ incIdx(resultIndex, resultType.getShape().drop_back(nbCommonInnerDims));
}
Value extract =
- rewriter.create<vector::ExtractOp>(loc, op.getSource(), srcIdx);
- result = rewriter.create<vector::InsertOp>(loc, extract, result, resIdx);
+ rewriter.create<vector::ExtractOp>(loc, op.getSource(), sourceIndex);
+ result =
+ rewriter.create<vector::InsertOp>(loc, extract, result, resultIndex);
}
rewriter.replaceOp(op, result);
return success();
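
For what it's worth, the new nbCommonInnerDims loop is just greedy suffix
matching: walk both shapes from the innermost dimension and stop at the
first size mismatch (or when either shape runs out of dimensions). A
standalone sketch, with a hypothetical helper name:

  #include <cstdint>
  #include <vector>

  // Counts how many trailing dimension sizes two shapes share; the chunks
  // moved by the lowering above are then rank-`count` slices.
  static int64_t countCommonSuffixDims(const std::vector<int64_t> &source,
                                       const std::vector<int64_t> &result) {
    int64_t count = 0;
    while (count < static_cast<int64_t>(source.size()) &&
           count < static_cast<int64_t>(result.size()) &&
           source[source.size() - 1 - count] ==
               result[result.size() - 1 - count])
      ++count;
    return count;
  }

  // countCommonSuffixDims({2, 1, 3}, {2, 3}) == 1: rank-1 moves of 3 elements.
  // countCommonSuffixDims({1, 2, 3}, {2, 3}) == 2: a single rank-2 move.

With the common suffix factored out, vector<2x1x3xf32> -> vector<2x3xf32>
takes 2 extract/insert pairs of vector<3xf32> instead of the 6 scalar pairs
the old lowering produced (as the new tests below check).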
@@ -329,8 +345,8 @@ class ScalableShapeCastOpRewritePattern
// 4. Increment the insert/extract indices, stepping by minExtractionSize
// for the trailing dimensions.
- incIdx(srcIdx, sourceVectorType, /*step=*/minExtractionSize);
- incIdx(resIdx, resultVectorType, /*step=*/minExtractionSize);
+ incIdx(srcIdx, sourceVectorType.getShape(), /*step=*/minExtractionSize);
+ incIdx(resIdx, resultVectorType.getShape(), /*step=*/minExtractionSize);
}
rewriter.replaceOp(op, result);
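
And a quick usage sketch of the stepped increment (reusing the hypothetical
incrementIndex from the sketch above), mirroring how this scalable pattern
advances indices by minExtractionSize:

  // Stepping by 4 over shape {2, 8} visits the start of each length-4
  // chunk: {0, 0} -> {0, 4} -> {1, 0} -> {1, 4}. The middle step exercises
  // the carry: 4 + 4 = 8 spills into the outer dimension.
  std::vector<int64_t> idx{0, 0};
  const std::vector<int64_t> shape{2, 8};
  for (int i = 0; i < 3; ++i)
    incrementIndex(idx, shape, /*step=*/4);
  // idx == {1, 4}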
diff --git a/mlir/test/Dialect/Vector/vector-shape-cast-lowering-transforms.mlir b/mlir/test/Dialect/Vector/vector-shape-cast-lowering-transforms.mlir
index ef32f8c6a1cdb..044a73d71e665 100644
--- a/mlir/test/Dialect/Vector/vector-shape-cast-lowering-transforms.mlir
+++ b/mlir/test/Dialect/Vector/vector-shape-cast-lowering-transforms.mlir
@@ -140,6 +140,57 @@ func.func @shape_cast_1d0d(%arg0 : vector<1xf32>) -> vector<f32> {
return %s : vector<f32>
}
+
+// CHECK-LABEL: func.func @shape_cast_squeeze_leading_one(
+// CHECK-SAME: %[[ARG0:.*]]: vector<1x2x3xf32>) -> vector<2x3xf32> {
+// CHECK: %[[EXTRACTED:.*]] = vector.extract %[[ARG0]][0] :
+// CHECK-SAME: vector<2x3xf32> from vector<1x2x3xf32>
+// CHECK: return %[[EXTRACTED]] : vector<2x3xf32>
+func.func @shape_cast_squeeze_leading_one(%arg0 : vector<1x2x3xf32>) -> vector<2x3xf32> {
+ %s = vector.shape_cast %arg0 : vector<1x2x3xf32> to vector<2x3xf32>
+ return %s : vector<2x3xf32>
+}
+
+// CHECK-LABEL: func.func @shape_cast_squeeze_middle_one(
+// CHECK-SAME: %[[ARG0:.*]]: vector<2x1x3xf32>) -> vector<2x3xf32> {
+// CHECK: %[[UB:.*]] = ub.poison : vector<2x3xf32>
+// CHECK: %[[E0:.*]] = vector.extract %[[ARG0]][0, 0] : vector<3xf32>
+// CHECK: %[[I0:.*]] = vector.insert %[[E0]], %[[UB]] [0] : vector<3xf32>
+// CHECK-SAME: into vector<2x3xf32>
+// CHECK: %[[E1:.*]] = vector.extract %[[ARG0]][1, 0] : vector<3xf32>
+// CHECK: %[[I1:.*]] = vector.insert %[[E1]], %[[I0]] [1] : vector<3xf32>
+// CHECK-SAME: into vector<2x3xf32>
+// CHECK: return %[[I1]] : vector<2x3xf32>
+func.func @shape_cast_squeeze_middle_one(%arg0 : vector<2x1x3xf32>) -> vector<2x3xf32> {
+ %s = vector.shape_cast %arg0 : vector<2x1x3xf32> to vector<2x3xf32>
+ return %s : vector<2x3xf32>
+}
+
+// CHECK-LABEL: func.func @shape_cast_unsqueeze_leading_one(
+// CHECK-SAME: %[[ARG0:.*]]: vector<2x3xf32>) -> vector<1x2x3xf32> {
+// CHECK: %[[UB:.*]] = ub.poison : vector<1x2x3xf32>
+// CHECK: %[[INSERTED:.*]] = vector.insert %[[ARG0]], %[[UB]] [0]
+// CHECK-SAME: : vector<2x3xf32> into vector<1x2x3xf32>
+// CHECK: return %[[INSERTED]] : vector<1x2x3xf32>
+func.func @shape_cast_unsqueeze_leading_one(%arg0 : vector<2x3xf32>) -> vector<1x2x3xf32> {
+ %s = vector.shape_cast %arg0 : vector<2x3xf32> to vector<1x2x3xf32>
+ return %s : vector<1x2x3xf32>
+}
+
+// CHECK-LABEL: func.func @shape_cast_unsqueeze_middle_one(
+// CHECK-SAME: %[[ARG0:.*]]: vector<2x3xf32>) -> vector<2x1x3xf32> {
+// CHECK: %[[UB:.*]] = ub.poison : vector<2x1x3xf32>
+// CHECK: %[[E0:.*]] = vector.extract %[[ARG0]][0] : vector<3xf32>
+// CHECK: %[[I0:.*]] = vector.insert %[[E0]], %[[UB]] [0, 0] : vector<3xf32>
+// CHECK: %[[E1:.*]] = vector.extract %[[ARG0]][1] : vector<3xf32>
+// CHECK: %[[I1:.*]] = vector.insert %[[E1]], %[[I0]] [1, 0] : vector<3xf32>
+// CHECK: return %[[I1]] : vector<2x1x3xf32>
+func.func @shape_cast_unsqueeze_middle_one(%arg0 : vector<2x3xf32>) -> vector<2x1x3xf32> {
+ %s = vector.shape_cast %arg0 : vector<2x3xf32> to vector<2x1x3xf32>
+ return %s : vector<2x1x3xf32>
+}
+
+
module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%module_op: !transform.any_op {transform.readonly}) {
%f = transform.structured.match ops{["func.func"]} in %module_op
From 1a4b759e753c5649aceb0a12cfd65f06a3d35a82 Mon Sep 17 00:00:00 2001
From: James Newling <james.newling at gmail.com>
Date: Thu, 22 May 2025 11:54:33 -0700
Subject: [PATCH 2/2] test name improvements
---
.../vector-shape-cast-lowering-transforms.mlir | 18 ++++++++++--------
1 file changed, 10 insertions(+), 8 deletions(-)
diff --git a/mlir/test/Dialect/Vector/vector-shape-cast-lowering-transforms.mlir b/mlir/test/Dialect/Vector/vector-shape-cast-lowering-transforms.mlir
index 044a73d71e665..2875f159a2df9 100644
--- a/mlir/test/Dialect/Vector/vector-shape-cast-lowering-transforms.mlir
+++ b/mlir/test/Dialect/Vector/vector-shape-cast-lowering-transforms.mlir
@@ -141,17 +141,19 @@ func.func @shape_cast_1d0d(%arg0 : vector<1xf32>) -> vector<f32> {
}
-// CHECK-LABEL: func.func @shape_cast_squeeze_leading_one(
+// The shapes have 2 inner dimension sizes in common, so the extract result is rank-2.
+// CHECK-LABEL: func.func @squeeze_out_prefix_unit_dim(
// CHECK-SAME: %[[ARG0:.*]]: vector<1x2x3xf32>) -> vector<2x3xf32> {
// CHECK: %[[EXTRACTED:.*]] = vector.extract %[[ARG0]][0] :
// CHECK-SAME: vector<2x3xf32> from vector<1x2x3xf32>
// CHECK: return %[[EXTRACTED]] : vector<2x3xf32>
-func.func @shape_cast_squeeze_leading_one(%arg0 : vector<1x2x3xf32>) -> vector<2x3xf32> {
+func.func @squeeze_out_prefix_unit_dim(%arg0 : vector<1x2x3xf32>) -> vector<2x3xf32> {
%s = vector.shape_cast %arg0 : vector<1x2x3xf32> to vector<2x3xf32>
return %s : vector<2x3xf32>
}
-// CHECK-LABEL: func.func @shape_cast_squeeze_middle_one(
+// The shapes have 1 inner dimension size in common, so the extract results are rank-1.
+// CHECK-LABEL: func.func @squeeze_out_middle_unit_dim(
// CHECK-SAME: %[[ARG0:.*]]: vector<2x1x3xf32>) -> vector<2x3xf32> {
// CHECK: %[[UB:.*]] = ub.poison : vector<2x3xf32>
// CHECK: %[[E0:.*]] = vector.extract %[[ARG0]][0, 0] : vector<3xf32>
@@ -161,23 +163,23 @@ func.func @shape_cast_squeeze_leading_one(%arg0 : vector<1x2x3xf32>) -> vector<2
// CHECK: %[[I1:.*]] = vector.insert %[[E1]], %[[I0]] [1] : vector<3xf32>
// CHECK-SAME: into vector<2x3xf32>
// CHECK: return %[[I1]] : vector<2x3xf32>
-func.func @shape_cast_squeeze_middle_one(%arg0 : vector<2x1x3xf32>) -> vector<2x3xf32> {
+func.func @squeeze_out_middle_unit_dim(%arg0 : vector<2x1x3xf32>) -> vector<2x3xf32> {
%s = vector.shape_cast %arg0 : vector<2x1x3xf32> to vector<2x3xf32>
return %s : vector<2x3xf32>
}
-// CHECK-LABEL: func.func @shape_cast_unsqueeze_leading_one(
+// CHECK-LABEL: func.func @prepend_unit_dim(
// CHECK-SAME: %[[ARG0:.*]]: vector<2x3xf32>) -> vector<1x2x3xf32> {
// CHECK: %[[UB:.*]] = ub.poison : vector<1x2x3xf32>
// CHECK: %[[INSERTED:.*]] = vector.insert %[[ARG0]], %[[UB]] [0]
// CHECK-SAME: : vector<2x3xf32> into vector<1x2x3xf32>
// CHECK: return %[[INSERTED]] : vector<1x2x3xf32>
-func.func @shape_cast_unsqueeze_leading_one(%arg0 : vector<2x3xf32>) -> vector<1x2x3xf32> {
+func.func @prepend_unit_dim(%arg0 : vector<2x3xf32>) -> vector<1x2x3xf32> {
%s = vector.shape_cast %arg0 : vector<2x3xf32> to vector<1x2x3xf32>
return %s : vector<1x2x3xf32>
}
-// CHECK-LABEL: func.func @shape_cast_unsqueeze_middle_one(
+// CHECK-LABEL: func.func @insert_middle_unit_dim(
// CHECK-SAME: %[[ARG0:.*]]: vector<2x3xf32>) -> vector<2x1x3xf32> {
// CHECK: %[[UB:.*]] = ub.poison : vector<2x1x3xf32>
// CHECK: %[[E0:.*]] = vector.extract %[[ARG0]][0] : vector<3xf32>
@@ -185,7 +187,7 @@ func.func @shape_cast_unsqueeze_leading_one(%arg0 : vector<2x3xf32>) -> vector<1
// CHECK: %[[E1:.*]] = vector.extract %[[ARG0]][1] : vector<3xf32>
// CHECK: %[[I1:.*]] = vector.insert %[[E1]], %[[I0]] [1, 0] : vector<3xf32>
// CHECK: return %[[I1]] : vector<2x1x3xf32>
-func.func @shape_cast_unsqueeze_middle_one(%arg0 : vector<2x3xf32>) -> vector<2x1x3xf32> {
+func.func @insert_middle_unit_dim(%arg0 : vector<2x3xf32>) -> vector<2x1x3xf32> {
%s = vector.shape_cast %arg0 : vector<2x3xf32> to vector<2x1x3xf32>
return %s : vector<2x1x3xf32>
}