[Mlir-commits] [mlir] [mlir][tensor]-Handle Dynamic Offset in BubbleUpSliceOpThroughCollapse (PR #178921)
Amir Bishara
llvmlistbot at llvm.org
Wed Feb 4 01:58:29 PST 2026
https://github.com/amirBish updated https://github.com/llvm/llvm-project/pull/178921
>From 87d153955c0e45eb07f052bdf1f6687071ae85ab Mon Sep 17 00:00:00 2001
From: Amir Bishara <amir.bishara at mobileye.com>
Date: Thu, 29 Jan 2026 13:33:55 +0200
Subject: [PATCH 1/3] [mlir][tensor]-Handle Dynamic Offset in
BubbleUpSliceOpThroughCollapse
This patch extends the `BubbleUpExtractSliceThroughCollapseShape` pattern
to handle cases where `tensor.extract_slice` has a dynamic offset.
During tile and fuse transformations, it is common to encounter IR where
`tensor.extract_slice` operations appear after `tensor.collapse_shape`.
These patterns are used as cleanup transformations to canonicalize the IR
by bubbling up the slice operation before the reshape. This enables
further optimizations and simplifications downstream.
Previously, the pattern only handled:
1. Static offsets and sizes.
2. Dynamic sizes with a single non-unit expanded dimension.
This left a gap for additional common cases where we may have:
- Dynamic offsets with size == 1 (single element extraction).
- Size greater than 1 but the offset is computed dynamically.
Regarding the first case, it is always legal to perform such
a transformation, since a single element is trivially contiguous.
Regarding the second case (i.e. a dynamic offset with a static
size > 1), contiguity can only be guaranteed under more
restrictive conditions:
1. The innermost expanded dimension is divisible by the slice size
(ensures the slice fits within a single "row")
2. The offset is provably a multiple of the slice size
(ensures the slice starts at an aligned boundary)
For example, given:
- collapse_shape: tensor<6x10xf32> -> tensor<60xf32>
- extract_slice with a dynamic offset `offset` and a static size `size`
If 10 % size == 0 and offset % size == 0, the slice is guaranteed to be contiguous.
The transformation delinearizes the offset to get [row, col] indices
and extracts [1, size] from the expanded 6x10 shape.
The offset divisibility check uses affine analysis to statically verify
that affine expressions (e.g., `idx * 5`) are multiples of the slice size.
---
.../Tensor/Transforms/ReshapePatterns.cpp | 111 +++++++++++-
.../Tensor/bubble-up-extract-slice-op.mlir | 158 +++++++++++++++++-
2 files changed, 260 insertions(+), 9 deletions(-)
diff --git a/mlir/lib/Dialect/Tensor/Transforms/ReshapePatterns.cpp b/mlir/lib/Dialect/Tensor/Transforms/ReshapePatterns.cpp
index a53af98474245..d069c67daad81 100644
--- a/mlir/lib/Dialect/Tensor/Transforms/ReshapePatterns.cpp
+++ b/mlir/lib/Dialect/Tensor/Transforms/ReshapePatterns.cpp
@@ -579,6 +579,24 @@ LogicalResult mlir::tensor::getCollapsedExtractSliceInfo(
return success();
}
+// Checks if the `value` is a multiple of the `factor`.
+// Otherwise, returns false.
+// For now we are handling only the case where the value
+// is resulted from an affine.apply op, where affine ops
+// can be lowered to.
+static bool isValueMultipleOf(Value value, int64_t factor) {
+ auto applyOp = value.getDefiningOp<affine::AffineApplyOp>();
+ if (!applyOp)
+ return false;
+ // Fold producing affine.apply ops into a single map over the leaf operands
+ // so that nested applications (e.g. (s0 * 2) * 2) are checked as a whole.
+ AffineMap map = applyOp.getAffineMap();
+ SmallVector<Value> operands(applyOp.getOperands());
+ affine::fullyComposeAffineMapAndOperands(&map, &operands);
+ map = simplifyAffineMap(map);
+ // Guard: only a single-result map is reasoned about; the divisibility is
+ // then decided symbolically on that result expression.
+ if (map.getNumResults() != 1)
+ return false;
+ return map.getResult(0).isMultipleOf(factor);
+}
+
LogicalResult mlir::tensor::getExpandedExtractSliceInfo(
OpBuilder &b, tensor::ExtractSliceOp sliceOp,
ArrayRef<ReassociationIndices> reassociation,
@@ -610,26 +628,103 @@ LogicalResult mlir::tensor::getExpandedExtractSliceInfo(
for (auto [collapsedSize, collapsedOffset, reassocIndices] :
llvm::zip_equal(collapsedSizes, collapsedOffsets, reassociation)) {
// CASE #1 - size and/or offset are dynamic.
- // In this case, the slice can be represented as a contiguous slice only
- // if there is a single dimension in the reassociation group that has a
- // size not equal to 1.
if (isa<Value>(collapsedSize) || isa<Value>(collapsedOffset)) {
+ // When the size is dynamic, We will handle a simple case where there
+ // is a single dimension in the reassociation group that has a size
+ // not equal to 1, which guarantees it is a contiguous slice.
int nonUnitSizeCount = 0;
+ SmallVector<OpFoldResult> tentativeSizes, tentativeOffsets;
for (int64_t expandedShapeIdx : reassocIndices) {
if (expandedShape[expandedShapeIdx] != 1) {
nonUnitSizeCount++;
- expandedSizes.push_back(collapsedSize);
- expandedOffsets.push_back(collapsedOffset);
+ tentativeSizes.push_back(collapsedSize);
+ tentativeOffsets.push_back(collapsedOffset);
continue;
}
- expandedSizes.push_back(b.getIndexAttr(1));
- expandedOffsets.push_back(b.getIndexAttr(0));
+ tentativeSizes.push_back(b.getIndexAttr(1));
+ tentativeOffsets.push_back(b.getIndexAttr(0));
+ }
+
+ if (nonUnitSizeCount == 1) {
+ // Only one non-unit dimension, slice is contiguous.
+ expandedSizes.append(tentativeSizes);
+ expandedOffsets.append(tentativeOffsets);
+ continue;
}
- if (nonUnitSizeCount != 1) {
+ std::optional<int64_t> staticSize = getConstantIntValue(collapsedSize);
+ if (!staticSize.has_value())
return failure();
+
+ // If the size is statically 1, it means we're extracting a
+ // single element. In this case, we can always handle this by
+ // delinearizing the dynamic offset and get a contiguous slice.
+ if (staticSize.value() == 1) {
+ SmallVector<int64_t> basis;
+ for (int64_t expandedShapeIdx : reassocIndices) {
+ expandedSizes.push_back(b.getIndexAttr(1));
+ basis.push_back(expandedShape[expandedShapeIdx]);
+ }
+
+ Value offsetVal = getValueOrCreateConstantIndexOp(b, sliceOp.getLoc(),
+ collapsedOffset);
+ auto delinearizeOp = affine::AffineDelinearizeIndexOp::create(
+ b, sliceOp.getLoc(), offsetVal, basis, /*hasOuterBound=*/true);
+ for (auto result : delinearizeOp.getResults())
+ expandedOffsets.push_back(result);
+
+ continue;
}
+
+ // If the size is greater than 1, we will take a more
+ // restrictive case where both the offset and the innermost
+ // expanded dimension are divisible by the size.
+ //
+ // Example:
+ // collapse_shape tensor<6x10xf32> -> tensor<60xf32>
+ // extract_slice at offset_a with some size_b.
+ //
+ // Contiguity is guaranteed when:
+ // 1. innermost dim (10) % size_b == 0 i.e.
+ // the slice fits within a single row.
+ // 2. offset_a is multiple of size_b i.e.
+ // the slice starts at row-aligned boundary.
+ //
+ // Result: delinearize offset_a -> [row, col], extract [1, size_b]
+ // from expanded shape.
+ assert(staticSize.value() > 1 && "Expected size to be greater than 1");
+ int64_t sizeVal = staticSize.value();
+ int64_t innermostDim = expandedShape[reassocIndices.back()];
+
+ // Only try this path if offset is from affine.apply
+ Value offsetVal = cast<Value>(collapsedOffset);
+ if (!offsetVal.getDefiningOp<affine::AffineApplyOp>())
+ return failure();
+
+ if (innermostDim % sizeVal != 0)
+ return failure();
+
+ if (!isValueMultipleOf(offsetVal, sizeVal))
+ return failure();
+
+ SmallVector<int64_t> basis;
+ for (int64_t expandedShapeIdx : reassocIndices)
+ basis.push_back(expandedShape[expandedShapeIdx]);
+
+ auto delinearizeOp = affine::AffineDelinearizeIndexOp::create(
+ b, sliceOp.getLoc(), offsetVal, basis, /*hasOuterBound=*/true);
+
+ // Sizes: [1, 1, ..., 1, sizeVal] - size goes on innermost dimension.
+ for (size_t i = 0; i < reassocIndices.size() - 1; ++i)
+ expandedSizes.push_back(b.getIndexAttr(1));
+
+ expandedSizes.push_back(collapsedSize);
+
+ // Offsets from delinearization.
+ for (auto result : delinearizeOp.getResults())
+ expandedOffsets.push_back(result);
+
continue;
}
diff --git a/mlir/test/Dialect/Tensor/bubble-up-extract-slice-op.mlir b/mlir/test/Dialect/Tensor/bubble-up-extract-slice-op.mlir
index 34128d6a5ec8b..97157eb06fb96 100644
--- a/mlir/test/Dialect/Tensor/bubble-up-extract-slice-op.mlir
+++ b/mlir/test/Dialect/Tensor/bubble-up-extract-slice-op.mlir
@@ -190,7 +190,6 @@ func.func @bubble_up_extract_slice_through_collapse_shape_dynamic_size_and_src(%
return %extract : tensor<?xf32>
}
-
// CHECK-LABEL: func.func @bubble_up_extract_slice_through_collapse_shape_dynamic_offset(
// CHECK-SAME: %[[SRC:.*]]: tensor<1x5x1xf32>,
// CHECK-SAME: %[[OFFSET:.*]]: index) -> tensor<3xf32> {
@@ -229,6 +228,163 @@ func.func @bubble_up_extract_slice_through_collapse_shape_dynamic_and_static_gro
return %extract : tensor<20x?xf32>
}
+// CHECK-LABEL: func.func @bubble_up_extract_slice_through_collapse_shape_unit_size_dynamic_offset(
+// CHECK-SAME: %[[SRC:.*]]: tensor<2x3x10xf32>,
+// CHECK-SAME: %[[OFFSET:.*]]: index) -> tensor<1xf32> {
+// CHECK: %[[DELINEARIZE:.*]]:3 = affine.delinearize_index %[[OFFSET]] into (2, 3, 10)
+// CHECK: %[[EXTRACT:.*]] = tensor.extract_slice %[[SRC]]{{\[}}%[[DELINEARIZE]]#0, %[[DELINEARIZE]]#1, %[[DELINEARIZE]]#2] [1, 1, 1] [1, 1, 1]
+// CHECK: %[[COLLAPSE:.*]] = tensor.collapse_shape %[[EXTRACT]] {{\[\[}}0, 1, 2]]
+// CHECK: return %[[COLLAPSE]]
+func.func @bubble_up_extract_slice_through_collapse_shape_unit_size_dynamic_offset(%src: tensor<2x3x10xf32>, %offset : index) -> tensor<1xf32> {
+ %collapse = tensor.collapse_shape %src [[0, 1, 2]] : tensor<2x3x10xf32> into tensor<60xf32>
+ // size == 1: a single element is always contiguous, so the arbitrary dynamic offset is simply delinearized.
+ %extract = tensor.extract_slice %collapse[%offset][1][1] : tensor<60xf32> to tensor<1xf32>
+ return %extract : tensor<1xf32>
+}
+
+// CHECK-LABEL: func.func @bubble_up_extract_slice_through_collapse_shape_unit_size_dynamic_offset_multi_group(
+// CHECK-SAME: %[[SRC:.*]]: tensor<2x3x4x5xf32>,
+// CHECK-SAME: %[[OFFSET0:.*]]: index,
+// CHECK-SAME: %[[OFFSET1:.*]]: index) -> tensor<1x1xf32> {
+// CHECK-DAG: %[[DELINEARIZE0:.*]]:2 = affine.delinearize_index %[[OFFSET0]] into (2, 3)
+// CHECK-DAG: %[[DELINEARIZE1:.*]]:2 = affine.delinearize_index %[[OFFSET1]] into (4, 5)
+// CHECK: %[[EXTRACT:.*]] = tensor.extract_slice %[[SRC]]{{\[}}%[[DELINEARIZE0]]#0, %[[DELINEARIZE0]]#1, %[[DELINEARIZE1]]#0, %[[DELINEARIZE1]]#1] [1, 1, 1, 1] [1, 1, 1, 1]
+// CHECK: %[[COLLAPSE:.*]] = tensor.collapse_shape %[[EXTRACT]] {{\[\[}}0, 1], [2, 3]]
+// CHECK: return %[[COLLAPSE]]
+func.func @bubble_up_extract_slice_through_collapse_shape_unit_size_dynamic_offset_multi_group(%src: tensor<2x3x4x5xf32>, %offset0 : index, %offset1 : index) -> tensor<1x1xf32> {
+ %collapse = tensor.collapse_shape %src [[0, 1], [2, 3]] : tensor<2x3x4x5xf32> into tensor<6x20xf32>
+ // size == 1 in both groups: each dynamic offset is delinearized independently over its own group.
+ %extract = tensor.extract_slice %collapse[%offset0, %offset1][1, 1][1, 1] : tensor<6x20xf32> to tensor<1x1xf32>
+ return %extract : tensor<1x1xf32>
+}
+
+// CHECK-LABEL: func.func @bubble_up_extract_slice_through_collapse_shape_affine_apply_offset(
+// CHECK-SAME: %[[SRC:.*]]: tensor<2x3x10xf32>,
+// CHECK-SAME: %[[IDX:.*]]: index) -> tensor<5xf32> {
+// CHECK: %[[OFFSET:.*]] = affine.apply
+// CHECK: %[[DELINEARIZE:.*]]:3 = affine.delinearize_index %[[OFFSET]] into (2, 3, 10)
+// CHECK: %[[EXTRACT:.*]] = tensor.extract_slice %[[SRC]]{{\[}}%[[DELINEARIZE]]#0, %[[DELINEARIZE]]#1, %[[DELINEARIZE]]#2] [1, 1, 5] [1, 1, 1]
+// CHECK: %[[COLLAPSE:.*]] = tensor.collapse_shape %[[EXTRACT]] {{\[\[}}0, 1, 2]]
+// CHECK: return %[[COLLAPSE]]
+func.func @bubble_up_extract_slice_through_collapse_shape_affine_apply_offset(%src: tensor<2x3x10xf32>, %idx : index) -> tensor<5xf32> {
+ %collapse = tensor.collapse_shape %src [[0, 1, 2]] : tensor<2x3x10xf32> into tensor<60xf32>
+ // offset = idx * 5, size = 5, innermost dim = 10, 10 % 5 == 0, should work
+ %offset = affine.apply affine_map<()[s0] -> (s0 * 5)>()[%idx]
+ %extract = tensor.extract_slice %collapse[%offset][5][1] : tensor<60xf32> to tensor<5xf32>
+ return %extract : tensor<5xf32>
+}
+
+// CHECK-LABEL: func.func @bubble_up_extract_slice_through_collapse_shape_affine_apply_offset_full_inner_dim(
+// CHECK-SAME: %[[SRC:.*]]: tensor<4x5x8xf32>,
+// CHECK-SAME: %[[IDX:.*]]: index) -> tensor<8xf32> {
+// CHECK: %[[OFFSET:.*]] = affine.apply
+// CHECK: %[[DELINEARIZE:.*]]:3 = affine.delinearize_index %[[OFFSET]] into (4, 5, 8)
+// CHECK: %[[EXTRACT:.*]] = tensor.extract_slice %[[SRC]]{{\[}}%[[DELINEARIZE]]#0, %[[DELINEARIZE]]#1, %[[DELINEARIZE]]#2] [1, 1, 8] [1, 1, 1]
+// CHECK: %[[COLLAPSE:.*]] = tensor.collapse_shape %[[EXTRACT]] {{\[\[}}0, 1, 2]]
+// CHECK: return %[[COLLAPSE]]
+func.func @bubble_up_extract_slice_through_collapse_shape_affine_apply_offset_full_inner_dim(%src: tensor<4x5x8xf32>, %idx : index) -> tensor<8xf32> {
+ %collapse = tensor.collapse_shape %src [[0, 1, 2]] : tensor<4x5x8xf32> into tensor<160xf32>
+ // offset = idx * 8, size = 8, innermost dim = 8, 8 % 8 == 0, should work
+ %offset = affine.apply affine_map<()[s0] -> (s0 * 8)>()[%idx]
+ %extract = tensor.extract_slice %collapse[%offset][8][1] : tensor<160xf32> to tensor<8xf32>
+ return %extract : tensor<8xf32>
+}
+
+// CHECK-LABEL: func.func @bubble_up_extract_slice_through_collapse_shape_affine_apply_complex_expr(
+// CHECK-SAME: %[[SRC:.*]]: tensor<3x4x8xf32>,
+// CHECK-SAME: %[[IDX:.*]]: index) -> tensor<4xf32> {
+// CHECK: %[[OFFSET:.*]] = affine.apply
+// CHECK: %[[DELINEARIZE:.*]]:3 = affine.delinearize_index %[[OFFSET]] into (3, 4, 8)
+// CHECK: %[[EXTRACT:.*]] = tensor.extract_slice %[[SRC]]{{\[}}%[[DELINEARIZE]]#0, %[[DELINEARIZE]]#1, %[[DELINEARIZE]]#2] [1, 1, 4] [1, 1, 1]
+// CHECK: %[[COLLAPSE:.*]] = tensor.collapse_shape %[[EXTRACT]] {{\[\[}}0, 1, 2]]
+// CHECK: return %[[COLLAPSE]]
+func.func @bubble_up_extract_slice_through_collapse_shape_affine_apply_complex_expr(%src: tensor<3x4x8xf32>, %idx : index) -> tensor<4xf32> {
+ %collapse = tensor.collapse_shape %src [[0, 1, 2]] : tensor<3x4x8xf32> into tensor<96xf32>
+ // offset = idx * 4, size = 4, innermost dim = 8, 8 % 4 == 0, should work
+ %offset = affine.apply affine_map<()[s0] -> (s0 * 4)>()[%idx]
+ %extract = tensor.extract_slice %collapse[%offset][4][1] : tensor<96xf32> to tensor<4xf32>
+ return %extract : tensor<4xf32>
+}
+
+// CHECK-LABEL: func.func @bubble_up_extract_slice_through_collapse_shape_affine_apply_multi_group(
+// CHECK-SAME: %[[SRC:.*]]: tensor<2x3x4x6xf32>,
+// CHECK-SAME: %[[IDX0:.*]]: index,
+// CHECK-SAME: %[[IDX1:.*]]: index) -> tensor<3x2xf32> {
+// CHECK-DAG: %[[OFFSET0:.*]] = affine.apply {{.*}}{{\[}}%[[IDX0]]]
+// CHECK-DAG: %[[OFFSET1:.*]] = affine.apply {{.*}}{{\[}}%[[IDX1]]]
+// CHECK-DAG: %[[DELINEARIZE0:.*]]:2 = affine.delinearize_index %[[OFFSET0]] into (2, 3)
+// CHECK-DAG: %[[DELINEARIZE1:.*]]:2 = affine.delinearize_index %[[OFFSET1]] into (4, 6)
+// CHECK: %[[EXTRACT:.*]] = tensor.extract_slice %[[SRC]]{{\[}}%[[DELINEARIZE0]]#0, %[[DELINEARIZE0]]#1, %[[DELINEARIZE1]]#0, %[[DELINEARIZE1]]#1] [1, 3, 1, 2] [1, 1, 1, 1]
+// CHECK: %[[COLLAPSE:.*]] = tensor.collapse_shape %[[EXTRACT]] {{\[\[}}0, 1], [2, 3]]
+// CHECK: return %[[COLLAPSE]]
+func.func @bubble_up_extract_slice_through_collapse_shape_affine_apply_multi_group(%src: tensor<2x3x4x6xf32>, %idx0 : index, %idx1 : index) -> tensor<3x2xf32> {
+ %collapse = tensor.collapse_shape %src [[0, 1], [2, 3]] : tensor<2x3x4x6xf32> into tensor<6x24xf32>
+ // For first group: size=3, innermost=3, 3%3==0, offset=idx0*3, works
+ // For second group: size=2, innermost=6, 6%2==0, offset=idx1*2, works
+ %offset0 = affine.apply affine_map<()[s0] -> (s0 * 3)>()[%idx0]
+ %offset1 = affine.apply affine_map<()[s0] -> (s0 * 2)>()[%idx1]
+ %extract = tensor.extract_slice %collapse[%offset0, %offset1][3, 2][1, 1] : tensor<6x24xf32> to tensor<3x2xf32>
+ return %extract : tensor<3x2xf32>
+}
+
+// CHECK-LABEL: func.func @bubble_up_extract_slice_through_collapse_shape_affine_apply_nested(
+// CHECK-SAME: %[[SRC:.*]]: tensor<2x4x8xf32>,
+// CHECK-SAME: %[[IDX:.*]]: index) -> tensor<4xf32> {
+// CHECK: affine.apply
+// CHECK: %[[OFFSET:.*]] = affine.apply
+// CHECK: %[[DELINEARIZE:.*]]:3 = affine.delinearize_index %[[OFFSET]] into (2, 4, 8)
+// CHECK: %[[EXTRACT:.*]] = tensor.extract_slice %[[SRC]]{{\[}}%[[DELINEARIZE]]#0, %[[DELINEARIZE]]#1, %[[DELINEARIZE]]#2] [1, 1, 4] [1, 1, 1]
+// CHECK: %[[COLLAPSE:.*]] = tensor.collapse_shape %[[EXTRACT]] {{\[\[}}0, 1, 2]]
+// CHECK: return %[[COLLAPSE]]
+func.func @bubble_up_extract_slice_through_collapse_shape_affine_apply_nested(%src: tensor<2x4x8xf32>, %idx : index) -> tensor<4xf32> {
+ %collapse = tensor.collapse_shape %src [[0, 1, 2]] : tensor<2x4x8xf32> into tensor<64xf32>
+ // Nested affine.apply: offset = (idx * 2) * 2 = idx * 4, multiple of 4
+ %tmp = affine.apply affine_map<()[s0] -> (s0 * 2)>()[%idx]
+ %offset = affine.apply affine_map<()[s0] -> (s0 * 2)>()[%tmp]
+ %extract = tensor.extract_slice %collapse[%offset][4][1] : tensor<64xf32> to tensor<4xf32>
+ return %extract : tensor<4xf32>
+}
+
+// CHECK-LABEL: func.func @no_bubble_up_extract_slice_affine_apply_offset_not_multiple(
+// CHECK-SAME: %[[SRC:.*]]: tensor<2x3x10xf32>,
+// CHECK-SAME: %[[IDX:.*]]: index) -> tensor<4xf32> {
+// CHECK: %[[COLLAPSE:.*]] = tensor.collapse_shape
+// CHECK: %[[OFFSET:.*]] = affine.apply
+// CHECK: %[[EXTRACT:.*]] = tensor.extract_slice
+// CHECK: return %[[EXTRACT]]
+func.func @no_bubble_up_extract_slice_affine_apply_offset_not_multiple(%src: tensor<2x3x10xf32>, %idx : index) -> tensor<4xf32> {
+ %collapse = tensor.collapse_shape %src [[0, 1, 2]] : tensor<2x3x10xf32> into tensor<60xf32>
+ // offset = idx * 3 cannot be proven to be a multiple of size = 4 (and 10 % 4 != 0)
+ %offset = affine.apply affine_map<()[s0] -> (s0 * 3)>()[%idx]
+ %extract = tensor.extract_slice %collapse[%offset][4][1] : tensor<60xf32> to tensor<4xf32>
+ return %extract : tensor<4xf32>
+}
+
+// CHECK-LABEL: func.func @no_bubble_up_extract_slice_affine_apply_size_not_dividing_inner(
+// CHECK-SAME: %[[SRC:.*]]: tensor<2x3x10xf32>,
+// CHECK-SAME: %[[IDX:.*]]: index) -> tensor<3xf32> {
+// CHECK: %[[COLLAPSE:.*]] = tensor.collapse_shape
+// CHECK: %[[OFFSET:.*]] = affine.apply
+// CHECK: %[[EXTRACT:.*]] = tensor.extract_slice
+// CHECK: return %[[EXTRACT]]
+func.func @no_bubble_up_extract_slice_affine_apply_size_not_dividing_inner(%src: tensor<2x3x10xf32>, %idx : index) -> tensor<3xf32> {
+ %collapse = tensor.collapse_shape %src [[0, 1, 2]] : tensor<2x3x10xf32> into tensor<60xf32>
+ // offset = idx * 3 (multiple of 3), but size = 3, innermost = 10, 10 % 3 != 0
+ %offset = affine.apply affine_map<()[s0] -> (s0 * 3)>()[%idx]
+ %extract = tensor.extract_slice %collapse[%offset][3][1] : tensor<60xf32> to tensor<3xf32>
+ return %extract : tensor<3xf32>
+}
+
+// CHECK-LABEL: func.func @no_bubble_up_extract_slice_dynamic_offset_not_affine_apply(
+// CHECK-SAME: %[[SRC:.*]]: tensor<2x3x10xf32>,
+// CHECK-SAME: %[[OFFSET:.*]]: index) -> tensor<5xf32> {
+// CHECK: %[[COLLAPSE:.*]] = tensor.collapse_shape
+// CHECK: %[[EXTRACT:.*]] = tensor.extract_slice
+// CHECK: return %[[EXTRACT]]
+func.func @no_bubble_up_extract_slice_dynamic_offset_not_affine_apply(%src: tensor<2x3x10xf32>, %offset : index) -> tensor<5xf32> {
+ %collapse = tensor.collapse_shape %src [[0, 1, 2]] : tensor<2x3x10xf32> into tensor<60xf32>
+ // offset is a plain Value (not from affine.apply), can't prove it's a multiple of 5
+ %extract = tensor.extract_slice %collapse[%offset][5][1] : tensor<60xf32> to tensor<5xf32>
+ return %extract : tensor<5xf32>
+}
+
/// The 2 following tests are cases where the bubble up cannot occur because the contiguous size extracted
/// from the collapsed shape cannot be expressed via a single extract_slice op.
/// In the first test it is because the size extracted cannot be expressed as a slice
>From 30703b4e5081e8e9c316dac725bb5a5016291beb Mon Sep 17 00:00:00 2001
From: Amir Bishara <amir.bishara at mobileye.com>
Date: Tue, 3 Feb 2026 10:01:06 +0200
Subject: [PATCH 2/3] Unify the static and dynamic offset handling logic
---
.../Tensor/Transforms/ReshapePatterns.cpp | 364 +++++++++---------
.../Tensor/bubble-up-extract-slice-op.mlir | 15 +-
2 files changed, 194 insertions(+), 185 deletions(-)
diff --git a/mlir/lib/Dialect/Tensor/Transforms/ReshapePatterns.cpp b/mlir/lib/Dialect/Tensor/Transforms/ReshapePatterns.cpp
index d069c67daad81..2328b311367f6 100644
--- a/mlir/lib/Dialect/Tensor/Transforms/ReshapePatterns.cpp
+++ b/mlir/lib/Dialect/Tensor/Transforms/ReshapePatterns.cpp
@@ -11,6 +11,7 @@
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tensor/Transforms/Transforms.h"
#include "mlir/IR/PatternMatch.h"
+#include "mlir/IR/Value.h"
#include "mlir/Interfaces/ValueBoundsOpInterface.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/Support/LogicalResult.h"
@@ -579,12 +580,17 @@ LogicalResult mlir::tensor::getCollapsedExtractSliceInfo(
return success();
}
-// Checks if the `value` is a multiple of the `factor`.
-// Otherwise, returns false.
-// For now we are handling only the case where the value
-// is resulted from an affine.apply op, where affine ops
-// can be lowered to.
-static bool isValueMultipleOf(Value value, int64_t factor) {
+// Returns true if `ofr` is provably a multiple of `factor`.
+// Static integer values are checked directly. A dynamic value is only
+// handled when it is the result of an affine.apply, whose fully composed
+// map is checked symbolically; any other value conservatively yields false.
+// NOTE(review): `factor` must be non-zero, otherwise the static `%` below is
+// undefined behavior.
+static bool isMultipleOf(OpFoldResult ofr, int64_t factor) {
+ std::optional<int64_t> staticValue = getConstantIntValue(ofr);
+ if (staticValue.has_value())
+ return staticValue.value() % factor == 0;
+
+ // A non-integer attribute is not a Value and cannot be reasoned about.
+ Value value = dyn_cast<Value>(ofr);
+ if (!value)
+ return false;
auto applyOp = value.getDefiningOp<affine::AffineApplyOp>();
if (!applyOp)
return false;
@@ -597,6 +603,6 @@ static bool isValueMultipleOf(Value value, int64_t factor) {
return map.getResult(0).isMultipleOf(factor);
}
+/// Computes the expanded slice sizes (and, when they are directly known, the
+/// expanded offsets) for a single reassociation group of a
+/// `tensor.extract_slice` that is bubbled up through a `tensor.collapse_shape`.
+///
+/// `collapsedSize` / `collapsedOffset` describe the slice along the collapsed
+/// dimension, `reassocIndices` lists the expanded dimensions folded into it and
+/// `expandedShape` is the static shape of the collapse_shape source.
+///
+/// On success, `groupSizes` holds one size per entry of `reassocIndices`.
+/// `groupOffsets` is only populated when the group has exactly one non-unit
+/// dimension; otherwise it is left empty and the caller is responsible for
+/// deriving the per-dimension offsets from `collapsedOffset`.
+///
+/// Returns failure when the slice cannot be proven to be a single contiguous
+/// slice of the expanded shape.
+static LogicalResult computeExpandedSliceInfoForReassocGroup(
+    OpBuilder &b, OpFoldResult collapsedSize, OpFoldResult collapsedOffset,
+    const ReassociationIndices &reassocIndices, ArrayRef<int64_t> expandedShape,
+    SmallVectorImpl<OpFoldResult> &groupSizes,
+    SmallVectorImpl<OpFoldResult> &groupOffsets) {
+  assert(groupSizes.empty() && "Group sizes must be empty");
+  assert(groupOffsets.empty() && "Group offsets must be empty");
+
+  // With a single non-unit dimension in the group, the slice is trivially
+  // contiguous: the offset and size go directly on that dimension. This works
+  // for both dynamic sizes and dynamic offsets.
+  int nonUnitSizeCount = llvm::count_if(
+      reassocIndices, [&expandedShape](int64_t expandedShapeIdx) {
+        return expandedShape[expandedShapeIdx] != 1;
+      });
+  if (nonUnitSizeCount == 1) {
+    for (int64_t expandedShapeIdx : reassocIndices) {
+      if (expandedShape[expandedShapeIdx] != 1) {
+        groupSizes.push_back(collapsedSize);
+        groupOffsets.push_back(collapsedOffset);
+        continue;
+      }
+      groupSizes.push_back(b.getIndexAttr(1));
+      groupOffsets.push_back(b.getIndexAttr(0));
+    }
+    return success();
+  }
+
+  // A dynamic extracted size would need a more complex analysis to guarantee
+  // contiguous slicing.
+  if (isa<Value>(collapsedSize))
+    return failure();
+
+  std::optional<int64_t> staticSize = getConstantIntValue(collapsedSize);
+  assert(staticSize.has_value() && "Expected static size");
+
+  // Extracting a single element is always contiguous, whatever the (static or
+  // dynamic) offset is: every expanded dimension gets a unit size.
+  if (*staticSize == 1) {
+    for (size_t i = 0, e = reassocIndices.size(); i < e; ++i)
+      groupSizes.push_back(b.getIndexAttr(1));
+    return success();
+  }
+
+  // The size is static and greater than 1; the offset may be static or
+  // dynamic. The expected expanded slice has the form
+  // [1, ..., 1, Sk, A(k+1), ..., An]: full trailing dimensions n...k+1, a
+  // partially sliced dimension k and unit leading dimensions k-1...0, so the
+  // group is traversed innermost dimension first.
+  //
+  // `currentCollapsedSize` starts as the collapsed size and is divided by each
+  // fully covered trailing dimension, spreading the collapsed size over the
+  // expanded dimensions. `currentOffsetDivisor` starts at 1 and is multiplied
+  // by each fully covered trailing dimension, i.e. it is the number of
+  // collapsed elements spanned by one step of the dimension visited next.
+  assert(*staticSize > 1 && "Expected size to be greater than 1");
+  int64_t currentCollapsedSize = *staticSize;
+  int64_t currentOffsetDivisor = 1;
+
+  ReassociationIndices reversedReassocIndices(reassocIndices.rbegin(),
+                                              reassocIndices.rend());
+  int64_t idx = 0;
+  int64_t reassocGroupSize = reassocIndices.size();
+
+  // Trailing dimensions (n...k+1): the slice must cover each of them entirely
+  // and start at offset 0 within them.
+  for (; idx < reassocGroupSize; ++idx) {
+    int64_t expandedShapeSize = expandedShape[reversedReassocIndices[idx]];
+    if (currentCollapsedSize < expandedShapeSize)
+      break;
+
+    // The remaining size must be an exact number of rows of this dimension.
+    if ((currentCollapsedSize % expandedShapeSize) != 0)
+      return failure();
+
+    // The (static or dynamic) offset must be aligned to the rows formed by all
+    // dimensions visited so far.
+    currentOffsetDivisor *= expandedShapeSize;
+    if (!isMultipleOf(collapsedOffset, currentOffsetDivisor))
+      return failure();
+
+    groupSizes.push_back(b.getIndexAttr(expandedShapeSize));
+    currentCollapsedSize /= expandedShapeSize;
+  }
+
+  // Dimension k, where the actual slicing happens.
+  if (idx < reassocGroupSize) {
+    int64_t expandedShapeSize = expandedShape[reversedReassocIndices[idx]];
+    std::optional<int64_t> staticOffset = getConstantIntValue(collapsedOffset);
+
+    if (staticOffset.has_value()) {
+      // The slice must end at or before the end of dimension k. An exact fit
+      // (offsetInDim + size == dim) is valid and must not be rejected.
+      int64_t offsetInDim =
+          (*staticOffset / currentOffsetDivisor) % expandedShapeSize;
+      if ((currentCollapsedSize + offsetInDim) > expandedShapeSize)
+        return failure();
+    } else {
+      // For a dynamic offset, the offset within dimension k is
+      // (offset / currentOffsetDivisor) % expandedShapeSize. It is a multiple
+      // of `currentCollapsedSize` -- so the slice cannot run past the end of
+      // the dimension -- when the dimension is divisible by
+      // `currentCollapsedSize` and the collapsed offset is a multiple of
+      // `currentOffsetDivisor * currentCollapsedSize`. Checking the offset
+      // against `currentCollapsedSize` alone is not sufficient: for
+      // tensor<2x4x2> collapsed to tensor<16>, size 4 and offset 6 give an
+      // in-dimension offset of 3 in the dimension of size 4, crossing its end.
+      // For more complex cases, ValueBoundsOpInterface could be used to check
+      // the validity of the range.
+      if ((expandedShapeSize % currentCollapsedSize) != 0)
+        return failure();
+      if (!isMultipleOf(collapsedOffset,
+                        currentOffsetDivisor * currentCollapsedSize))
+        return failure();
+    }
+    // The slicing dimension gets the remaining collapsed size.
+    groupSizes.push_back(b.getIndexAttr(currentCollapsedSize));
+  }
+
+  // Leading dimensions (k-1...0) always get a unit size: the whole collapsed
+  // size has already been spread over dimensions n...k.
+  for (++idx; idx < reassocGroupSize; ++idx)
+    groupSizes.push_back(b.getIndexAttr(1));
+
+  // Sizes were collected innermost first; restore the group's order.
+  std::reverse(groupSizes.begin(), groupSizes.end());
+  return success();
+}
+
LogicalResult mlir::tensor::getExpandedExtractSliceInfo(
OpBuilder &b, tensor::ExtractSliceOp sliceOp,
ArrayRef<ReassociationIndices> reassociation,
@@ -620,193 +767,48 @@ LogicalResult mlir::tensor::getExpandedExtractSliceInfo(
return failure();
}
+ using ReassocGroupResult =
+ std::pair<SmallVector<OpFoldResult>, SmallVector<OpFoldResult>>;
+ SmallVector<ReassocGroupResult> groupResults;
+
// Compute new offsets, sizes, and strides for tensor.extract_slice.
// The new tensor.extract_slice will work on a tensor that has has a rank
// equal to the rank of the src of the collapse_shape. In each iteration of
// the loop, the offsets and sizes will be computed per reassociation group.
- expandedStrides.resize(expandedShape.size(), b.getIndexAttr(1));
for (auto [collapsedSize, collapsedOffset, reassocIndices] :
llvm::zip_equal(collapsedSizes, collapsedOffsets, reassociation)) {
- // CASE #1 - size and/or offset are dynamic.
- if (isa<Value>(collapsedSize) || isa<Value>(collapsedOffset)) {
- // When the size is dynamic, We will handle a simple case where there
- // is a single dimension in the reassociation group that has a size
- // not equal to 1, which guarantees it is a contiguous slice.
- int nonUnitSizeCount = 0;
- SmallVector<OpFoldResult> tentativeSizes, tentativeOffsets;
- for (int64_t expandedShapeIdx : reassocIndices) {
- if (expandedShape[expandedShapeIdx] != 1) {
- nonUnitSizeCount++;
- tentativeSizes.push_back(collapsedSize);
- tentativeOffsets.push_back(collapsedOffset);
- continue;
- }
-
- tentativeSizes.push_back(b.getIndexAttr(1));
- tentativeOffsets.push_back(b.getIndexAttr(0));
- }
- if (nonUnitSizeCount == 1) {
- // Only one non-unit dimension, slice is contiguous.
- expandedSizes.append(tentativeSizes);
- expandedOffsets.append(tentativeOffsets);
- continue;
- }
-
- std::optional<int64_t> staticSize = getConstantIntValue(collapsedSize);
- if (!staticSize.has_value())
- return failure();
-
- // If the size is statically 1, it means we're extracting a
- // single element. In this case, we can always handle this by
- // delinearizing the dynamic offset and get a contiguous slice.
- if (staticSize.value() == 1) {
- SmallVector<int64_t> basis;
- for (int64_t expandedShapeIdx : reassocIndices) {
- expandedSizes.push_back(b.getIndexAttr(1));
- basis.push_back(expandedShape[expandedShapeIdx]);
- }
-
- Value offsetVal = getValueOrCreateConstantIndexOp(b, sliceOp.getLoc(),
- collapsedOffset);
- auto delinearizeOp = affine::AffineDelinearizeIndexOp::create(
- b, sliceOp.getLoc(), offsetVal, basis, /*hasOuterBound=*/true);
- for (auto result : delinearizeOp.getResults())
- expandedOffsets.push_back(result);
-
- continue;
- }
-
- // If the size is greater than 1, we will take a more
- // restrictive case where both the offset and the innermost
- // expanded dimension are divisible by the size.
- //
- // Example:
- // collapse_shape tensor<6x10xf32> -> tensor<60xf32>
- // extract_slice at offset_a with some size_b.
- //
- // Contiguity is guaranteed when:
- // 1. innermost dim (10) % size_b == 0 i.e.
- // the slice fits within a single row.
- // 2. offset_a is multiple of size_b i.e.
- // the slice starts at row-aligned boundary.
- //
- // Result: delinearize offset_a -> [row, col], extract [1, size_b]
- // from expanded shape.
- assert(staticSize.value() > 1 && "Expected size to be greater than 1");
- int64_t sizeVal = staticSize.value();
- int64_t innermostDim = expandedShape[reassocIndices.back()];
-
- // Only try this path if offset is from affine.apply
- Value offsetVal = cast<Value>(collapsedOffset);
- if (!offsetVal.getDefiningOp<affine::AffineApplyOp>())
- return failure();
-
- if (innermostDim % sizeVal != 0)
- return failure();
-
- if (!isValueMultipleOf(offsetVal, sizeVal))
- return failure();
-
- SmallVector<int64_t> basis;
- for (int64_t expandedShapeIdx : reassocIndices)
- basis.push_back(expandedShape[expandedShapeIdx]);
-
- auto delinearizeOp = affine::AffineDelinearizeIndexOp::create(
- b, sliceOp.getLoc(), offsetVal, basis, /*hasOuterBound=*/true);
-
- // Sizes: [1, 1, ..., 1, sizeVal] - size goes on innermost dimension.
- for (size_t i = 0; i < reassocIndices.size() - 1; ++i)
- expandedSizes.push_back(b.getIndexAttr(1));
-
- expandedSizes.push_back(collapsedSize);
+ SmallVector<OpFoldResult> groupSizes;
+ SmallVector<OpFoldResult> groupOffsets;
+ LogicalResult result = computeExpandedSliceInfoForReassocGroup(
+ b, collapsedSize, collapsedOffset, reassocIndices, expandedShape,
+ groupSizes, groupOffsets);
+ if (failed(result))
+ return failure();
+ groupResults.emplace_back(std::make_pair(groupSizes, groupOffsets));
+ }
- // Offsets from delinearization.
- for (auto result : delinearizeOp.getResults())
- expandedOffsets.push_back(result);
+ expandedStrides.resize(expandedShape.size(), b.getIndexAttr(1));
+ for (auto [groupIdx, reassocIndices] : llvm::enumerate(reassociation)) {
+ auto &[sizes, offsets] = groupResults[groupIdx];
+ expandedSizes.append(sizes);
+ if (!offsets.empty()) {
+ expandedOffsets.append(offsets);
continue;
}
- // CASE #2 = size and offset are static.
- // Verify that the slice can be represented as a contiguous slice of the
- // src of the collapse_shape.
- // Checking this is done on order of most internal dimensions first,
- // so traversal is done in reverse order of the reassociation group.
- // If the expected slice shape is [1, 1, ..., 1, Sk, Ak + 1, Ak + 2,
- // ...,An] then we first find the size and offset for n...k+1 then for k
- // and then for k-1...0.
-
- // currentCollapsedsize and currentCollapsedOffset are initialized with
- // the original collapsed size and offset and divided by the expanded
- // shape size in each dimension as we go along the reassociation group.
- // In essence we are spreading the original collapsed size and offset over
- // the various expanded slice dimensions.
- // The variables are used both to check the validity of the slice and to
- // compute the expanded sizes and offsets.
- int64_t currentCollapsedsize = getConstantIntValue(collapsedSize).value();
- int64_t currentCollapsedOffset =
- getConstantIntValue(collapsedOffset).value();
- SmallVector<OpFoldResult> groupExpandedSizes, groupExpandedOffsets;
- ReassociationIndices reversedReassocIndices(reassocIndices.rbegin(),
- reassocIndices.rend());
- int64_t idx = 0;
- int64_t reassocGroupSize = reassocIndices.size();
-
- // First handle the trailing dimensions where the slice size should be
- // equal to the tensor shape and the offset should be 0 (n...k+1).
- for (; idx < reassocGroupSize; ++idx) {
- int64_t expandedShapeSize = expandedShape[reversedReassocIndices[idx]];
-
- if (currentCollapsedsize < expandedShapeSize)
- break;
-
- // We need to make sure that the slice size can be set to the shape size
- // and the offset to 0.
- if ((currentCollapsedsize % expandedShapeSize) != 0 ||
- (currentCollapsedOffset % expandedShapeSize) != 0) {
- return failure();
- }
-
- groupExpandedSizes.push_back(b.getIndexAttr(expandedShapeSize));
- groupExpandedOffsets.push_back(b.getIndexAttr(0));
-
- currentCollapsedsize /= expandedShapeSize;
- currentCollapsedOffset /= expandedShapeSize;
- }
-
- // Now handle the first dim where slicing occurs on (k).
- if (idx < reassocGroupSize) {
- int64_t expandedShapeSize = expandedShape[reversedReassocIndices[idx]];
- int64_t offsetInDim = currentCollapsedOffset % expandedShapeSize;
- // We need to make sure that the slice size in this dim + offset will
- // not exceed the shape size.
- if ((currentCollapsedsize + offsetInDim) >= expandedShapeSize) {
- return failure();
- }
- groupExpandedSizes.push_back(b.getIndexAttr(currentCollapsedsize));
- groupExpandedOffsets.push_back(b.getIndexAttr(offsetInDim));
- currentCollapsedOffset /= expandedShapeSize;
- }
-
- // Now handle the leading dimensions where the slice size is equal to 1
- // (k-1...0).
- // The size for these dimensions must be 1 because of how we constructed
- // the slice size of the expanded shape. We spread the original collapsed
- // size over the expanded shape sizes until we reached dimension k where
- // the remaining size was smaller than the expanded shape size, and spread
- // the remaining size on it. So, now we are left with only 1s.
- for (idx++; idx < reassocGroupSize; ++idx) {
- int64_t expandedShapeSize = expandedShape[reversedReassocIndices[idx]];
- int64_t offsetInDim = currentCollapsedOffset % expandedShapeSize;
- groupExpandedSizes.push_back(b.getIndexAttr(1));
- groupExpandedOffsets.push_back(b.getIndexAttr(offsetInDim));
- currentCollapsedOffset /= expandedShapeSize;
- }
- expandedSizes.append(groupExpandedSizes.rbegin(),
- groupExpandedSizes.rend());
- expandedOffsets.append(groupExpandedOffsets.rbegin(),
- groupExpandedOffsets.rend());
+ SmallVector<int64_t> basis;
+ for (int64_t expandedShapeIdx : reassocIndices)
+ basis.push_back(expandedShape[expandedShapeIdx]);
+
+ OpFoldResult collapsedOffset = collapsedOffsets[groupIdx];
+ Value offsetVal =
+ getValueOrCreateConstantIndexOp(b, sliceOp.getLoc(), collapsedOffset);
+ auto delinearizeOp = affine::AffineDelinearizeIndexOp::create(
+ b, sliceOp.getLoc(), offsetVal, basis, /*hasOuterBound=*/true);
+ for (OpResult result : delinearizeOp.getResults())
+ expandedOffsets.push_back(result);
}
return success();
}
diff --git a/mlir/test/Dialect/Tensor/bubble-up-extract-slice-op.mlir b/mlir/test/Dialect/Tensor/bubble-up-extract-slice-op.mlir
index 97157eb06fb96..e2bfacfad5c2c 100644
--- a/mlir/test/Dialect/Tensor/bubble-up-extract-slice-op.mlir
+++ b/mlir/test/Dialect/Tensor/bubble-up-extract-slice-op.mlir
@@ -135,7 +135,8 @@ func.func @bubble_up_extract_slice_affine_apply_not_folded(%src: tensor<60xf32>,
// CHECK-LABEL: func.func @bubble_up_extract_slice_through_collapse_shape_single_reassoc_group(
// CHECK-SAME: %[[SRC:.*]]: tensor<6x5x2xf32>) -> tensor<1xf32> {
-// CHECK: %[[EXTRACT:.*]] = tensor.extract_slice %[[SRC]][0, 0, 0] [1, 1, 1] [1, 1, 1]
+// CHECK: %[[C0:.*]] = arith.constant 0 : index
+// CHECK: %[[EXTRACT:.*]] = tensor.extract_slice %[[SRC]][%[[C0]], %[[C0]], %[[C0]]] [1, 1, 1] [1, 1, 1]
// CHECK: %[[COLLAPSE:.*]] = tensor.collapse_shape %[[EXTRACT]] {{\[\[}}0, 1, 2]]
// CHECK: return %[[COLLAPSE]]
func.func @bubble_up_extract_slice_through_collapse_shape_single_reassoc_group(%src: tensor<6x5x2xf32>) -> tensor<1xf32> {
@@ -146,7 +147,9 @@ func.func @bubble_up_extract_slice_through_collapse_shape_single_reassoc_group(%
// CHECK-LABEL: func.func @bubble_up_extract_slice_through_collapse_shape_multiple_reassoc_group(
// CHECK-SAME: %[[SRC:.*]]: tensor<6x5x3x10xf32>) -> tensor<15x10xf32> {
-// CHECK: %[[EXTRACT:.*]] = tensor.extract_slice %[[SRC]][1, 0, 1, 0] [3, 5, 1, 10] [1, 1, 1, 1]
+// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
+// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
+// CHECK: %[[EXTRACT:.*]] = tensor.extract_slice %[[SRC]][%[[C1]], %[[C0]], %[[C1]], %[[C0]]] [3, 5, 1, 10] [1, 1, 1, 1]
// CHECK: %[[COLLAPSE:.*]] = tensor.collapse_shape %[[EXTRACT]] {{\[\[}}0, 1], [2, 3]]
// CHECK: return %[[COLLAPSE]]
func.func @bubble_up_extract_slice_through_collapse_shape_multiple_reassoc_group(%src: tensor<6x5x3x10xf32>) -> tensor<15x10xf32> {
@@ -157,7 +160,9 @@ func.func @bubble_up_extract_slice_through_collapse_shape_multiple_reassoc_group
// CHECK-LABEL: func.func @bubble_up_extract_slice_through_collapse_shape_offset_on_leading_dim(
// CHECK-SAME: %[[SRC:.*]]: tensor<6x5x2xf32>) -> tensor<4xf32> {
-// CHECK: %[[EXTRACT:.*]] = tensor.extract_slice %[[SRC]][2, 0, 0] [1, 2, 2] [1, 1, 1]
+// CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index
+// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
+// CHECK: %[[EXTRACT:.*]] = tensor.extract_slice %[[SRC]][%[[C2]], %[[C0]], %[[C0]]] [1, 2, 2] [1, 1, 1]
// CHECK: %[[COLLAPSE:.*]] = tensor.collapse_shape %[[EXTRACT]] {{\[\[}}0, 1, 2]]
// CHECK: return %[[COLLAPSE]]
func.func @bubble_up_extract_slice_through_collapse_shape_offset_on_leading_dim(%src: tensor<6x5x2xf32>) -> tensor<4xf32> {
@@ -219,7 +224,9 @@ func.func @bubble_up_extract_slice_through_collapse_shape_dynamic_offset_and_siz
// CHECK-SAME: %[[SRC:.*]]: tensor<5x10x1x1x40xf32>,
// CHECK-SAME: %[[OFFSET:.*]]: index,
// CHECK-SAME: %[[SIZE:.*]]: index) -> tensor<20x?xf32> {
-// CHECK: %[[EXTRACT:.*]] = tensor.extract_slice %[[SRC]][1, 0, 0, 0, %[[OFFSET]]] [2, 10, 1, 1, %[[SIZE]]] [1, 1, 1, 1, 1]
+// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
+// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
+// CHECK: %[[EXTRACT:.*]] = tensor.extract_slice %[[SRC]][%[[C1]], %[[C0]], 0, 0, %[[OFFSET]]] [2, 10, 1, 1, %[[SIZE]]] [1, 1, 1, 1, 1]
// CHECK: %[[COLLAPSE:.*]] = tensor.collapse_shape %[[EXTRACT]] {{\[\[}}0, 1], [2, 3, 4]]
// CHECK: return %[[COLLAPSE]]
func.func @bubble_up_extract_slice_through_collapse_shape_dynamic_and_static_groups(%src: tensor<5x10x1x1x40xf32>, %offset : index, %size : index) -> tensor<20x?xf32> {
>From 9e9fc283d7f06c729b1091c80df4a339fe28264c Mon Sep 17 00:00:00 2001
From: Amir Bishara <amir.bishara at mobileye.com>
Date: Wed, 4 Feb 2026 11:57:23 +0200
Subject: [PATCH 3/3] Fix code review threads
---
.../Dialect/Tensor/Transforms/Transforms.h | 4 +-
.../Tensor/Transforms/ReshapePatterns.cpp | 48 ++++++++-----------
.../Tensor/bubble-up-extract-slice-op.mlir | 18 +++++--
3 files changed, 36 insertions(+), 34 deletions(-)
diff --git a/mlir/include/mlir/Dialect/Tensor/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Tensor/Transforms/Transforms.h
index 3e4da94bd714e..093393eca7436 100644
--- a/mlir/include/mlir/Dialect/Tensor/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Tensor/Transforms/Transforms.h
@@ -156,14 +156,14 @@ getCollapsedExtractSliceInfo(OpBuilder &b, tensor::ExtractSliceOp sliceOp,
/// Computes the offsets, sizes, and strides needed to build an expanded
/// `sliceOp`. The dimensions to expand are specified by `reassociation` and
-/// `expandedShape`.
+/// the shape of `expandedValue`.
///
/// This fails when the specified expansion cannot be represented by a valid
/// ExtractSliceOp.
LogicalResult
getExpandedExtractSliceInfo(OpBuilder &b, tensor::ExtractSliceOp sliceOp,
ArrayRef<ReassociationIndices> reassociation,
- ArrayRef<int64_t> expandedShape,
+ Value expandedValue,
SmallVectorImpl<OpFoldResult> &expandedOffsets,
SmallVectorImpl<OpFoldResult> &expandedSizes,
SmallVectorImpl<OpFoldResult> &expandedStrides);
diff --git a/mlir/lib/Dialect/Tensor/Transforms/ReshapePatterns.cpp b/mlir/lib/Dialect/Tensor/Transforms/ReshapePatterns.cpp
index 2328b311367f6..c74c0c930d4a3 100644
--- a/mlir/lib/Dialect/Tensor/Transforms/ReshapePatterns.cpp
+++ b/mlir/lib/Dialect/Tensor/Transforms/ReshapePatterns.cpp
@@ -444,7 +444,7 @@ struct BubbleUpExtractSliceThroughCollapseShape
SmallVector<OpFoldResult> offsets, sizes, strides;
if (failed(getExpandedExtractSliceInfo(
rewriter, sliceOp, collapseShapeOp.getReassociationIndices(),
- collapseShapeOp.getSrcType().getShape(), offsets, sizes, strides)))
+ collapseShapeOp.getSrc(), offsets, sizes, strides)))
return failure();
Value newSliceOp = tensor::ExtractSliceOp::create(
@@ -603,13 +603,17 @@ static bool isMultipleOf(OpFoldResult ofr, int64_t factor) {
return map.getResult(0).isMultipleOf(factor);
}
+/// Given a `collapsedOffset` and `collapsedSize`, this function
+/// validates that the slice is representable as a contiguous slice
+/// in the `expandedShape` and computes the corresponding expanded sizes.
+/// Returns failure if the slice cannot be guaranteed to be contiguous.
+/// On success, populates `groupSizes` with the expanded sizes for each
+/// dimension in the reassociation group.
static LogicalResult computeExpandedSliceInfoForReassocGroup(
OpBuilder &b, OpFoldResult collapsedSize, OpFoldResult collapsedOffset,
const ReassociationIndices &reassocIndices, ArrayRef<int64_t> expandedShape,
- SmallVectorImpl<OpFoldResult> &groupSizes,
- SmallVectorImpl<OpFoldResult> &groupOffsets) {
+ SmallVectorImpl<OpFoldResult> &groupSizes) {
assert(groupSizes.empty() && "Group sizes must be empty");
- assert(groupOffsets.empty() && "Group offsets must be empty");
// The first case is when there's only one non-unit dimension in the
// reassociation group.
// When there's only one non-unit dimension, the slice is trivially
@@ -621,13 +625,10 @@ static LogicalResult computeExpandedSliceInfoForReassocGroup(
});
if (nonUnitSizeCount == 1) {
for (int64_t expandedShapeIdx : reassocIndices) {
- if (expandedShape[expandedShapeIdx] != 1) {
+ if (expandedShape[expandedShapeIdx] != 1)
groupSizes.push_back(collapsedSize);
- groupOffsets.push_back(collapsedOffset);
- continue;
- }
- groupSizes.push_back(b.getIndexAttr(1));
- groupOffsets.push_back(b.getIndexAttr(0));
+ else
+ groupSizes.push_back(b.getIndexAttr(1));
}
return success();
}
@@ -746,8 +747,7 @@ static LogicalResult computeExpandedSliceInfoForReassocGroup(
LogicalResult mlir::tensor::getExpandedExtractSliceInfo(
OpBuilder &b, tensor::ExtractSliceOp sliceOp,
- ArrayRef<ReassociationIndices> reassociation,
- ArrayRef<int64_t> expandedShape,
+ ArrayRef<ReassociationIndices> reassociation, Value expandedValue,
SmallVectorImpl<OpFoldResult> &expandedOffsets,
SmallVectorImpl<OpFoldResult> &expandedSizes,
SmallVectorImpl<OpFoldResult> &expandedStrides) {
@@ -767,40 +767,34 @@ LogicalResult mlir::tensor::getExpandedExtractSliceInfo(
return failure();
}
- using ReassocGroupResult =
- std::pair<SmallVector<OpFoldResult>, SmallVector<OpFoldResult>>;
- SmallVector<ReassocGroupResult> groupResults;
-
// Compute new offsets, sizes, and strides for tensor.extract_slice.
// The new tensor.extract_slice will work on a tensor that has has a rank
// equal to the rank of the src of the collapse_shape. In each iteration of
// the loop, the offsets and sizes will be computed per reassociation group.
+ ArrayRef<int64_t> expandedShape =
+ cast<RankedTensorType>(expandedValue.getType()).getShape();
+ SmallVector<SmallVector<OpFoldResult>> groupResults;
for (auto [collapsedSize, collapsedOffset, reassocIndices] :
llvm::zip_equal(collapsedSizes, collapsedOffsets, reassociation)) {
SmallVector<OpFoldResult> groupSizes;
- SmallVector<OpFoldResult> groupOffsets;
LogicalResult result = computeExpandedSliceInfoForReassocGroup(
b, collapsedSize, collapsedOffset, reassocIndices, expandedShape,
- groupSizes, groupOffsets);
+ groupSizes);
if (failed(result))
return failure();
- groupResults.emplace_back(std::make_pair(groupSizes, groupOffsets));
+ groupResults.emplace_back(groupSizes);
}
expandedStrides.resize(expandedShape.size(), b.getIndexAttr(1));
for (auto [groupIdx, reassocIndices] : llvm::enumerate(reassociation)) {
- auto &[sizes, offsets] = groupResults[groupIdx];
+ auto &sizes = groupResults[groupIdx];
expandedSizes.append(sizes);
- if (!offsets.empty()) {
- expandedOffsets.append(offsets);
- continue;
- }
-
- SmallVector<int64_t> basis;
+ SmallVector<OpFoldResult> basis;
for (int64_t expandedShapeIdx : reassocIndices)
- basis.push_back(expandedShape[expandedShapeIdx]);
+ basis.push_back(tensor::getMixedSize(b, sliceOp.getLoc(), expandedValue,
+ expandedShapeIdx));
OpFoldResult collapsedOffset = collapsedOffsets[groupIdx];
Value offsetVal =
diff --git a/mlir/test/Dialect/Tensor/bubble-up-extract-slice-op.mlir b/mlir/test/Dialect/Tensor/bubble-up-extract-slice-op.mlir
index e2bfacfad5c2c..9a78b263fb84b 100644
--- a/mlir/test/Dialect/Tensor/bubble-up-extract-slice-op.mlir
+++ b/mlir/test/Dialect/Tensor/bubble-up-extract-slice-op.mlir
@@ -174,7 +174,8 @@ func.func @bubble_up_extract_slice_through_collapse_shape_offset_on_leading_dim(
// CHECK-LABEL: func.func @bubble_up_extract_slice_through_collapse_shape_dynamic_size(
// CHECK-SAME: %[[SRC:.*]]: tensor<1x5x1xf32>,
// CHECK-SAME: %[[SIZE:.*]]: index) -> tensor<?xf32> {
-// CHECK: %[[EXTRACT:.*]] = tensor.extract_slice %[[SRC]][0, 0, 0] [1, %[[SIZE]], 1] [1, 1, 1]
+// CHECK: %[[C0:.*]] = arith.constant 0 : index
+// CHECK: %[[EXTRACT:.*]] = tensor.extract_slice %[[SRC]][%[[C0]], %[[C0]], %[[C0]]] [1, %[[SIZE]], 1] [1, 1, 1]
// CHECK: %[[COLLAPSE:.*]] = tensor.collapse_shape %[[EXTRACT]] {{\[\[}}0, 1, 2]]
// CHECK: return %[[COLLAPSE]]
func.func @bubble_up_extract_slice_through_collapse_shape_dynamic_size(%src: tensor<1x5x1xf32>, %size : index) -> tensor<?xf32> {
@@ -186,7 +187,11 @@ func.func @bubble_up_extract_slice_through_collapse_shape_dynamic_size(%src: ten
// CHECK-LABEL: func.func @bubble_up_extract_slice_through_collapse_shape_dynamic_size_and_src(
// CHECK-SAME: %[[SRC:.*]]: tensor<1x?x1xf32>,
// CHECK-SAME: %[[SIZE:.*]]: index) -> tensor<?xf32> {
-// CHECK: %[[EXTRACT:.*]] = tensor.extract_slice %[[SRC]][0, 0, 0] [1, %[[SIZE]], 1] [1, 1, 1]
+// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
+// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
+// CHECK: %[[DIM:.*]] = tensor.dim %[[SRC]], %[[C1]]
+// CHECK: %[[DELIN:.*]]:3 = affine.delinearize_index %[[C0]] into (1, %[[DIM]], 1)
+// CHECK: %[[EXTRACT:.*]] = tensor.extract_slice %[[SRC]][%[[DELIN]]#0, %[[DELIN]]#1, %[[DELIN]]#2] [1, %[[SIZE]], 1] [1, 1, 1]
// CHECK: %[[COLLAPSE:.*]] = tensor.collapse_shape %[[EXTRACT]] {{\[\[}}0, 1, 2]]
// CHECK: return %[[COLLAPSE]]
func.func @bubble_up_extract_slice_through_collapse_shape_dynamic_size_and_src(%src: tensor<1x?x1xf32>, %size : index) -> tensor<?xf32> {
@@ -198,7 +203,8 @@ func.func @bubble_up_extract_slice_through_collapse_shape_dynamic_size_and_src(%
// CHECK-LABEL: func.func @bubble_up_extract_slice_through_collapse_shape_dynamic_offset(
// CHECK-SAME: %[[SRC:.*]]: tensor<1x5x1xf32>,
// CHECK-SAME: %[[OFFSET:.*]]: index) -> tensor<3xf32> {
-// CHECK: %[[EXTRACT:.*]] = tensor.extract_slice %[[SRC]][0, %[[OFFSET]], 0] [1, 3, 1] [1, 1, 1]
+// CHECK: %[[DELIN:.*]]:3 = affine.delinearize_index %[[OFFSET]] into (1, 5, 1)
+// CHECK: %[[EXTRACT:.*]] = tensor.extract_slice %[[SRC]][%[[DELIN]]#0, %[[DELIN]]#1, %[[DELIN]]#2] [1, 3, 1] [1, 1, 1]
// CHECK: %[[COLLAPSE:.*]] = tensor.collapse_shape %[[EXTRACT]] {{\[\[}}0, 1, 2]]
// CHECK: return %[[COLLAPSE]]
func.func @bubble_up_extract_slice_through_collapse_shape_dynamic_offset(%src: tensor<1x5x1xf32>, %offset : index) -> tensor<3xf32> {
@@ -211,7 +217,8 @@ func.func @bubble_up_extract_slice_through_collapse_shape_dynamic_offset(%src: t
// CHECK-SAME: %[[SRC:.*]]: tensor<14x1xf32>,
// CHECK-SAME: %[[OFFSET:.*]]: index,
// CHECK-SAME: %[[SIZE:.*]]: index) -> tensor<?xf32> {
-// CHECK: %[[EXTRACT:.*]] = tensor.extract_slice %[[SRC]]{{\[}}%[[OFFSET]], 0] {{\[}}%[[SIZE]], 1] [1, 1]
+// CHECK: %[[DELIN:.*]]:2 = affine.delinearize_index %[[OFFSET]] into (14, 1)
+// CHECK: %[[EXTRACT:.*]] = tensor.extract_slice %[[SRC]]{{\[}}%[[DELIN]]#0, %[[DELIN]]#1] {{\[}}%[[SIZE]], 1] [1, 1]
// CHECK: %[[COLLAPSE:.*]] = tensor.collapse_shape %[[EXTRACT]] {{\[\[}}0, 1]]
// CHECK: return %[[COLLAPSE]]
func.func @bubble_up_extract_slice_through_collapse_shape_dynamic_offset_and_size(%src: tensor<14x1xf32>, %offset : index, %size : index) -> tensor<?xf32> {
@@ -226,7 +233,8 @@ func.func @bubble_up_extract_slice_through_collapse_shape_dynamic_offset_and_siz
// CHECK-SAME: %[[SIZE:.*]]: index) -> tensor<20x?xf32> {
// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
-// CHECK: %[[EXTRACT:.*]] = tensor.extract_slice %[[SRC]][%[[C1]], %[[C0]], 0, 0, %[[OFFSET]]] [2, 10, 1, 1, %[[SIZE]]] [1, 1, 1, 1, 1]
+// CHECK: %[[DELIN:.*]]:3 = affine.delinearize_index %[[OFFSET]] into (1, 1, 40)
+// CHECK: %[[EXTRACT:.*]] = tensor.extract_slice %[[SRC]][%[[C1]], %[[C0]], %[[DELIN]]#0, %[[DELIN]]#1, %[[DELIN]]#2] [2, 10, 1, 1, %[[SIZE]]] [1, 1, 1, 1, 1]
// CHECK: %[[COLLAPSE:.*]] = tensor.collapse_shape %[[EXTRACT]] {{\[\[}}0, 1], [2, 3, 4]]
// CHECK: return %[[COLLAPSE]]
func.func @bubble_up_extract_slice_through_collapse_shape_dynamic_and_static_groups(%src: tensor<5x10x1x1x40xf32>, %offset : index, %size : index) -> tensor<20x?xf32> {
More information about the Mlir-commits
mailing list