[Mlir-commits] [mlir] [MLIR][Vector] Fix bug in ExtractStridedSliceOp canonicalization (PR #147591)

Tomás Longeri llvmlistbot at llvm.org
Tue Jul 8 13:01:12 PDT 2025


https://github.com/tlongeri created https://github.com/llvm/llvm-project/pull/147591

The StridedSliceBroadcast canonicalization pattern would produce an invalid slice when some dimensions were both broadcast and sliced.

>From ff1d804ae764020365d8ef4eb12453120367a7be Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Tom=C3=A1s=20Longeri?= <tlongeri at google.com>
Date: Tue, 8 Jul 2025 19:55:23 +0000
Subject: [PATCH] [MLIR][Vector] Fix bug in ExtractStridedSliceOp
 canonicalization

The StridedSliceBroadcast canonicalization pattern would produce an
invalid slice when some dimensions were both broadcast and sliced.
---
 mlir/lib/Dialect/Vector/IR/VectorOps.cpp   | 39 +++++++++++++---------
 mlir/test/Dialect/Vector/canonicalize.mlir | 14 ++++++++
 2 files changed, 37 insertions(+), 16 deletions(-)

diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 214d2ba7e1b8e..2f5b831b3c40b 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -4169,28 +4169,35 @@ class StridedSliceBroadcast final
     auto dstVecType = llvm::cast<VectorType>(op.getType());
     unsigned dstRank = dstVecType.getRank();
     unsigned rankDiff = dstRank - srcRank;
-    // Check if the most inner dimensions of the source of the broadcast are the
-    // same as the destination of the extract. If this is the case we can just
-    // use a broadcast as the original dimensions are untouched.
-    bool lowerDimMatch = true;
+    // Source dimensions can be broadcasted (1 -> n with n > 1) or sliced
+    // (n -> m with n > m). If they are originally both broadcasted *and*
+    // sliced, this can be simplified to just broadcasting.
+    bool needsSlice = false;
     for (unsigned i = 0; i < srcRank; i++) {
-      if (srcVecType.getDimSize(i) != dstVecType.getDimSize(i + rankDiff)) {
-        lowerDimMatch = false;
+      if (srcVecType.getDimSize(i) != 1 &&
+          srcVecType.getDimSize(i) != dstVecType.getDimSize(i + rankDiff)) {
+        needsSlice = true;
         break;
       }
     }
     Value source = broadcast.getSource();
-    // If the inner dimensions don't match, it means we need to extract from the
-    // source of the orignal broadcast and then broadcast the extracted value.
-    // We also need to handle degenerated cases where the source is effectively
-    // just a single scalar.
-    bool isScalarSrc = (srcRank == 0 || srcVecType.getNumElements() == 1);
-    if (!lowerDimMatch && !isScalarSrc) {
+    if (needsSlice) {
+      SmallVector<int64_t> offsets =
+          getI64SubArray(op.getOffsets(), /*dropFront=*/rankDiff);
+      SmallVector<int64_t> sizes =
+          getI64SubArray(op.getSizes(), /*dropFront=*/rankDiff);
+      for (unsigned i = 0; i < srcRank; i++) {
+        if (srcVecType.getDimSize(i) == 1) {
+          // In case this dimension was broadcasted *and* sliced, the offset
+          // and size need to be updated now that there is no broadcast before
+          // the slice.
+          offsets[i] = 0;
+          sizes[i] = 1;
+        }
+      }
       source = rewriter.create<ExtractStridedSliceOp>(
-          op->getLoc(), source,
-          getI64SubArray(op.getOffsets(), /* dropFront=*/rankDiff),
-          getI64SubArray(op.getSizes(), /* dropFront=*/rankDiff),
-          getI64SubArray(op.getStrides(), /* dropFront=*/rankDiff));
+          op->getLoc(), source, offsets, sizes,
+          getI64SubArray(op.getStrides(), /*dropFront=*/rankDiff));
     }
     rewriter.replaceOpWithNewOp<BroadcastOp>(op, op.getType(), source);
     return success();
diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
index 8a9e27378df61..dfa2e1c2a5a24 100644
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -1344,6 +1344,20 @@ func.func @extract_strided_broadcast4(%arg0: f32) -> vector<1x4xf32> {
 
 // -----
 
+// CHECK-LABEL: func @extract_strided_broadcast5
+//  CHECK-SAME: (%[[ARG:.+]]: vector<2x1xf32>)
+//       CHECK: %[[V:.+]] = vector.broadcast %[[ARG]] : vector<2x1xf32> to vector<2x4xf32>
+//       CHECK: return %[[V]]
+func.func @extract_strided_broadcast5(%arg0: vector<2x1xf32>) -> vector<2x4xf32> {
+  %0 = vector.broadcast %arg0 : vector<2x1xf32> to vector<2x8xf32>
+  %1 = vector.extract_strided_slice %0  // dim 1 is both broadcast (1 -> 8) and sliced (8 -> 4)
+      {offsets = [0, 4], sizes = [2, 4], strides = [1, 1]}
+      : vector<2x8xf32> to vector<2x4xf32>
+  return %1 : vector<2x4xf32>
+}
+
+// -----
+
 // CHECK-LABEL: consecutive_shape_cast
 //       CHECK:   %[[C:.*]] = vector.shape_cast %{{.*}} : vector<16xf16> to vector<4x4xf16>
 //  CHECK-NEXT:   return %[[C]] : vector<4x4xf16>



More information about the Mlir-commits mailing list