[Mlir-commits] [mlir] [mlir][tensor] Remove unit-stride restriction in InsertSliceOp folding (PR #192600)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Thu Apr 16 23:32:37 PDT 2026
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir
Author: Longsheng Mou (CoTinker)
<details>
<summary>Changes</summary>
This PR replaces the manual offset/size resolution in `InsertSliceOfInsertSliceFolder` with `affine::mergeOffsetsSizesAndStrides`, simplifying the code and extending insert_slice-of-insert_slice folding to support non-unit strides.
---
Full diff: https://github.com/llvm/llvm-project/pull/192600.diff
2 Files Affected:
- (modified) mlir/lib/Dialect/Tensor/Transforms/FoldTensorSubsetOps.cpp (+7-31)
- (modified) mlir/test/Dialect/Tensor/fold-tensor-subset-ops.mlir (+24-5)
``````````diff
diff --git a/mlir/lib/Dialect/Tensor/Transforms/FoldTensorSubsetOps.cpp b/mlir/lib/Dialect/Tensor/Transforms/FoldTensorSubsetOps.cpp
index b32faf481af80..14f96be5b56dd 100644
--- a/mlir/lib/Dialect/Tensor/Transforms/FoldTensorSubsetOps.cpp
+++ b/mlir/lib/Dialect/Tensor/Transforms/FoldTensorSubsetOps.cpp
@@ -183,16 +183,6 @@ struct InsertSliceOfInsertSliceFolder : public OpRewritePattern<OpTy> {
if (!sourceInsertSliceOp)
return failure();
- // TODO: relax unit stride assumption where possible.
- if (!insertSliceOp.hasUnitStride()) {
- return rewriter.notifyMatchFailure(insertSliceOp,
- "requires unit strides");
- }
- if (!sourceInsertSliceOp.hasUnitStride()) {
- return rewriter.notifyMatchFailure(sourceInsertSliceOp,
- "requires unit strides");
- }
-
int64_t srcDim = 0;
llvm::SmallBitVector droppedDims = insertSliceOp.getDroppedDims();
for (int64_t d = 0, e = insertSliceOp.getDestType().getRank(); d < e; ++d) {
@@ -206,15 +196,6 @@ struct InsertSliceOfInsertSliceFolder : public OpRewritePattern<OpTy> {
}
}
- // Resolve sizes according to dropped dims.
- SmallVector<OpFoldResult> resolvedSizes;
- // Note: the "insertSlice" case is symmetrical to the extract/subview case:
- // `insertSliceOp` is passed as the "source" and `sourceInsertSliceOp` is
- // passed as the destination to the helper function.
- affine::resolveSizesIntoOpWithSizes(insertSliceOp.getMixedSizes(),
- sourceInsertSliceOp.getMixedSizes(),
- droppedDims, resolvedSizes);
-
// If we are inside a ParallelCombining region, temporarily set the
// insertion point outside: only ops of ParallelCombiningOpInterface are
// allowed in there.
@@ -222,24 +203,19 @@ struct InsertSliceOfInsertSliceFolder : public OpRewritePattern<OpTy> {
rewriter.setInsertionPoint(insertSliceOp->getParentOp());
}
- // Resolve offsets according to source offsets and strides.
- SmallVector<Value> resolvedOffsets;
- // Note: the "insertSlice" case is symmetrical to the extract/subview case:
- // `insertSliceOp` is passed as the "source" and `sourceInsertSliceOp` is
- // passed as the destination to the helper function.
- affine::resolveIndicesIntoOpWithOffsetsAndStrides(
- rewriter, insertSliceOp.getLoc(), insertSliceOp.getMixedOffsets(),
- insertSliceOp.getMixedStrides(), droppedDims,
- sourceInsertSliceOp.getMixedOffsets(), resolvedOffsets);
+ SmallVector<OpFoldResult> newOffsets, newSizes, newStrides;
+ if (failed(affine::mergeOffsetsSizesAndStrides(
+ rewriter, insertSliceOp.getLoc(), insertSliceOp,
+ sourceInsertSliceOp, droppedDims, newOffsets, newSizes,
+ newStrides)))
+ return failure();
// Reset the insertion point.
rewriter.setInsertionPoint(insertSliceOp);
// Replace original op.
rewriter.replaceOpWithNewOp<OpTy>(
insertSliceOp, sourceInsertSliceOp.getSource(), insertSliceOp.getDest(),
- getAsOpFoldResult(resolvedOffsets), resolvedSizes,
- insertSliceOp.getMixedStrides());
-
+ newOffsets, newSizes, newStrides);
return success();
}
};
diff --git a/mlir/test/Dialect/Tensor/fold-tensor-subset-ops.mlir b/mlir/test/Dialect/Tensor/fold-tensor-subset-ops.mlir
index cf8711eb64ab9..45937e94f08ff 100644
--- a/mlir/test/Dialect/Tensor/fold-tensor-subset-ops.mlir
+++ b/mlir/test/Dialect/Tensor/fold-tensor-subset-ops.mlir
@@ -305,6 +305,25 @@ func.func @insert_slice_of_insert_slice(%t: tensor<f32>, %r0: tensor<1xf32>, %r1
// -----
+// CHECK-LABEL: func.func @insert_slice_of_insert_slice_non_unit_stride(
+// CHECK-SAME: %[[t:.*]]: tensor<f32>,
+// CHECK-SAME: %[[r0:.*]]: tensor<1xf32>,
+// CHECK-SAME: %[[r1:.*]]: tensor<1x14xf32>,
+// CHECK-SAME: %[[pos:.*]]: index) -> tensor<1x14xf32> {
+// CHECK: tensor.insert_slice %[[t]] into %[[r1]][0, %[[pos]]] [1, 1] [1, 2] : tensor<f32> into tensor<1x14xf32>
+func.func @insert_slice_of_insert_slice_non_unit_stride(
+ %t: tensor<f32>, %r0: tensor<1xf32>, %r1: tensor<1x14xf32>, %pos: index)
+ -> tensor<1x14xf32>
+{
+ %0 = tensor.insert_slice %t into %r0[0] [1] [1]
+ : tensor<f32> into tensor<1xf32>
+ %1 = tensor.insert_slice %0 into %r1[0, %pos] [1, 1] [1, 2]
+ : tensor<1xf32> into tensor<1x14xf32>
+ return %1 : tensor<1x14xf32>
+}
+
+// -----
+
// This test fails to fold because the size `4` and `%pos` do not match:
// this requires a copy
// CHECK-LABEL: func @fail_insert_slice_of_insert_slice(
@@ -324,21 +343,21 @@ func.func @fail_insert_slice_of_insert_slice(
// -----
// Here the sizes are the same and the folding occurs properly.
-// CHECK: #[[$map:.*]] = affine_map<()[s0] -> (s0 * 2)>
+// CHECK: #[[$map:.*]] = affine_map<()[s0] -> (s0 + s0 * s0)>
// CHECK-LABEL: func @insert_slice_of_insert_slice_dynamic(
// CHECK-SAME: %[[t:[0-9a-z]*]]: tensor<?xf32>
// CHECK-SAME: %[[r0:[0-9a-z]*]]: tensor<?xf32>
// CHECK-SAME: %[[r1:[0-9a-z]*]]: tensor<?x?xf32>
// CHECK-SAME: %[[pos:[0-9a-z]*]]: index
-// CHECK: %[[add:.*]] = affine.apply #[[$map]]()[%[[pos]]]
-// CHECK: tensor.insert_slice %[[t]] into %[[r1]][%[[add]], 423] [%[[pos]], 1] [1, 1] : tensor<?xf32> into tensor<?x?xf32>
+// CHECK: %[[offset:.*]] = affine.apply #[[$map]]()[%[[pos]]]
+// CHECK: tensor.insert_slice %[[t]] into %[[r1]][%[[offset]], 423] [%[[pos]], 1] [%[[pos]], 1] : tensor<?xf32> into tensor<?x?xf32>
func.func @insert_slice_of_insert_slice_dynamic(
%t: tensor<?xf32>, %r0: tensor<?xf32>, %r1: tensor<?x?xf32>, %pos: index)
-> tensor<?x?xf32>
{
%0 = tensor.insert_slice %t into %r0[%pos] [%pos] [1]
: tensor<?xf32> into tensor<?xf32>
- %1 = tensor.insert_slice %0 into %r1[%pos, 423] [%pos, 1] [1, 1]
+ %1 = tensor.insert_slice %0 into %r1[%pos, 423] [%pos, 1] [%pos, 1]
: tensor<?xf32> into tensor<?x?xf32>
return %1 : tensor<?x?xf32>
}
@@ -385,7 +404,7 @@ func.func @parallel_insert_slice_of_insert_slice_dynamic(
%tt2 = "make_me_another_tensor"() : () -> tensor<?x?xf32>
%inserted_slice = tensor.insert_slice %tt into %tt2[%o1, 0] [%sz0, %sz1] [1, 1] : tensor<?x?xf32> into tensor<?x?xf32>
- // CHECK: %[[add:.*]] = affine.apply #[[$map]]()[%[[o0]], %[[o1]]]
+ // CHECK: %[[add:.*]] = affine.apply #[[$map]]()[%[[o1]], %[[o0]]]
// CHECK: scf.forall.in_parallel
// CHECK: tensor.parallel_insert_slice %[[tt]] into %[[out]][%[[add]], %[[o1]]] [%[[sz0]], %[[sz1]]] [1, 1]
// CHECK-SAME: : tensor<?x?xf32> into tensor<12x34xf32>
``````````
</details>
https://github.com/llvm/llvm-project/pull/192600
More information about the Mlir-commits mailing list