[Mlir-commits] [mlir] [MLIR][Vector] Fix bug in ExtractStridedSliceOp canonicalization (PR #147591)
Tomás Longeri
llvmlistbot at llvm.org
Tue Jul 8 13:01:12 PDT 2025
https://github.com/tlongeri created https://github.com/llvm/llvm-project/pull/147591
The StridedSliceBroadcast canonicalization pattern would produce an invalid slice when some dimensions were both sliced and broadcast.
From ff1d804ae764020365d8ef4eb12453120367a7be Mon Sep 17 00:00:00 2001
From: Tomás Longeri <tlongeri at google.com>
Date: Tue, 8 Jul 2025 19:55:23 +0000
Subject: [PATCH] [MLIR][Vector] Fix bug in ExtractStridedSliceOp canonicalization
The StridedSliceBroadcast canonicalization pattern would produce an invalid
slice when some dimensions were both sliced and broadcast.
---
mlir/lib/Dialect/Vector/IR/VectorOps.cpp | 39 +++++++++++++---------
mlir/test/Dialect/Vector/canonicalize.mlir | 14 ++++++++
2 files changed, 37 insertions(+), 16 deletions(-)
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 214d2ba7e1b8e..2f5b831b3c40b 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -4169,28 +4169,35 @@ class StridedSliceBroadcast final
auto dstVecType = llvm::cast<VectorType>(op.getType());
unsigned dstRank = dstVecType.getRank();
unsigned rankDiff = dstRank - srcRank;
- // Check if the most inner dimensions of the source of the broadcast are the
- // same as the destination of the extract. If this is the case we can just
- // use a broadcast as the original dimensions are untouched.
- bool lowerDimMatch = true;
+ // Source dimensions can be broadcasted (1 -> n with n > 1) or sliced
+ // (n -> m with n > m). If they are originally both broadcasted *and*
+ // sliced, this can be simplified to just broadcasting.
+ bool needsSlice = false;
for (unsigned i = 0; i < srcRank; i++) {
- if (srcVecType.getDimSize(i) != dstVecType.getDimSize(i + rankDiff)) {
- lowerDimMatch = false;
+ if (srcVecType.getDimSize(i) != 1 &&
+ srcVecType.getDimSize(i) != dstVecType.getDimSize(i + rankDiff)) {
+ needsSlice = true;
break;
}
}
Value source = broadcast.getSource();
- // If the inner dimensions don't match, it means we need to extract from the
- // source of the orignal broadcast and then broadcast the extracted value.
- // We also need to handle degenerated cases where the source is effectively
- // just a single scalar.
- bool isScalarSrc = (srcRank == 0 || srcVecType.getNumElements() == 1);
- if (!lowerDimMatch && !isScalarSrc) {
+ if (needsSlice) {
+ SmallVector<int64_t> offsets =
+ getI64SubArray(op.getOffsets(), /*dropFront=*/rankDiff);
+ SmallVector<int64_t> sizes =
+ getI64SubArray(op.getSizes(), /*dropFront=*/rankDiff);
+ for (unsigned i = 0; i < srcRank; i++) {
+ if (srcVecType.getDimSize(i) == 1) {
+ // In case this dimension was broadcasted *and* sliced, the offset
+ // and size need to be updated now that there is no broadcast before
+ // the slice.
+ offsets[i] = 0;
+ sizes[i] = 1;
+ }
+ }
source = rewriter.create<ExtractStridedSliceOp>(
- op->getLoc(), source,
- getI64SubArray(op.getOffsets(), /* dropFront=*/rankDiff),
- getI64SubArray(op.getSizes(), /* dropFront=*/rankDiff),
- getI64SubArray(op.getStrides(), /* dropFront=*/rankDiff));
+ op->getLoc(), source, offsets, sizes,
+ getI64SubArray(op.getStrides(), /*dropFront=*/rankDiff));
}
rewriter.replaceOpWithNewOp<BroadcastOp>(op, op.getType(), source);
return success();
diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
index 8a9e27378df61..dfa2e1c2a5a24 100644
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -1344,6 +1344,20 @@ func.func @extract_strided_broadcast4(%arg0: f32) -> vector<1x4xf32> {
// -----
+// CHECK-LABEL: func @extract_strided_broadcast5
+// CHECK-SAME: (%[[ARG:.+]]: vector<2x1xf32>)
+// CHECK: %[[V:.+]] = vector.broadcast %[[ARG]] : vector<2x1xf32> to vector<2x4xf32>
+// CHECK: return %[[V]]
+func.func @extract_strided_broadcast5(%arg0: vector<2x1xf32>) -> vector<2x4xf32> {
+ %0 = vector.broadcast %arg0 : vector<2x1xf32> to vector<2x8xf32>
+ %1 = vector.extract_strided_slice %0
+ {offsets = [0, 4], sizes = [2, 4], strides = [1, 1]}
+ : vector<2x8xf32> to vector<2x4xf32>
+ return %1 : vector<2x4xf32>
+}
+
+// -----
+
// CHECK-LABEL: consecutive_shape_cast
// CHECK: %[[C:.*]] = vector.shape_cast %{{.*}} : vector<16xf16> to vector<4x4xf16>
// CHECK-NEXT: return %[[C]] : vector<4x4xf16>
More information about the Mlir-commits mailing list