[Mlir-commits] [mlir] [mlir][tensor] Remove unit-stride restriction in InsertSliceOp folding (PR #192600)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Thu Apr 16 23:32:37 PDT 2026


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir

Author: Longsheng Mou (CoTinker)

<details>
<summary>Changes</summary>

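This PR replaces the manual offset/size resolution in `InsertSliceOfInsertSliceFolder` with `affine::mergeOffsetsSizesAndStrides`. This simplifies the code and extends insert_slice-of-insert_slice folding to support non-unit strides. For each non-dropped dimension, the folded op's offset becomes `outerOffset + innerOffset * outerStride` and its stride becomes `outerStride * innerStride`. This is why the expected affine map in `@insert_slice_of_insert_slice_dynamic` is now `s0 + s0 * s0`.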

---
Full diff: https://github.com/llvm/llvm-project/pull/192600.diff


2 Files Affected:

- (modified) mlir/lib/Dialect/Tensor/Transforms/FoldTensorSubsetOps.cpp (+7-31) 
- (modified) mlir/test/Dialect/Tensor/fold-tensor-subset-ops.mlir (+24-5) 


``````````diff
diff --git a/mlir/lib/Dialect/Tensor/Transforms/FoldTensorSubsetOps.cpp b/mlir/lib/Dialect/Tensor/Transforms/FoldTensorSubsetOps.cpp
index b32faf481af80..14f96be5b56dd 100644
--- a/mlir/lib/Dialect/Tensor/Transforms/FoldTensorSubsetOps.cpp
+++ b/mlir/lib/Dialect/Tensor/Transforms/FoldTensorSubsetOps.cpp
@@ -183,16 +183,6 @@ struct InsertSliceOfInsertSliceFolder : public OpRewritePattern<OpTy> {
     if (!sourceInsertSliceOp)
       return failure();
 
-    // TODO: relax unit stride assumption where possible.
-    if (!insertSliceOp.hasUnitStride()) {
-      return rewriter.notifyMatchFailure(insertSliceOp,
-                                         "requires unit strides");
-    }
-    if (!sourceInsertSliceOp.hasUnitStride()) {
-      return rewriter.notifyMatchFailure(sourceInsertSliceOp,
-                                         "requires unit strides");
-    }
-
     int64_t srcDim = 0;
     llvm::SmallBitVector droppedDims = insertSliceOp.getDroppedDims();
     for (int64_t d = 0, e = insertSliceOp.getDestType().getRank(); d < e; ++d) {
@@ -206,15 +196,6 @@ struct InsertSliceOfInsertSliceFolder : public OpRewritePattern<OpTy> {
       }
     }
 
-    // Resolve sizes according to dropped dims.
-    SmallVector<OpFoldResult> resolvedSizes;
-    // Note: the "insertSlice" case is symmetrical to the extract/subview case:
-    // `insertSliceOp` is passed as the "source" and `sourceInsertSliceOp` is
-    // passed as the destination to the helper function.
-    affine::resolveSizesIntoOpWithSizes(insertSliceOp.getMixedSizes(),
-                                        sourceInsertSliceOp.getMixedSizes(),
-                                        droppedDims, resolvedSizes);
-
     // If we are inside a ParallelCombining region, temporarily set the
     // insertion point outside: only ops of ParallelCombiningOpInterface are
     // allowed in there.
@@ -222,24 +203,19 @@ struct InsertSliceOfInsertSliceFolder : public OpRewritePattern<OpTy> {
       rewriter.setInsertionPoint(insertSliceOp->getParentOp());
     }
 
-    // Resolve offsets according to source offsets and strides.
-    SmallVector<Value> resolvedOffsets;
-    // Note: the "insertSlice" case is symmetrical to the extract/subview case:
-    // `insertSliceOp` is passed as the "source" and `sourceInsertSliceOp` is
-    // passed as the destination to the helper function.
-    affine::resolveIndicesIntoOpWithOffsetsAndStrides(
-        rewriter, insertSliceOp.getLoc(), insertSliceOp.getMixedOffsets(),
-        insertSliceOp.getMixedStrides(), droppedDims,
-        sourceInsertSliceOp.getMixedOffsets(), resolvedOffsets);
+    SmallVector<OpFoldResult> newOffsets, newSizes, newStrides;
+    if (failed(affine::mergeOffsetsSizesAndStrides(
+            rewriter, insertSliceOp.getLoc(), insertSliceOp,
+            sourceInsertSliceOp, droppedDims, newOffsets, newSizes,
+            newStrides)))
+      return failure();
 
     // Reset the insertion point.
     rewriter.setInsertionPoint(insertSliceOp);
     // Replace original op.
     rewriter.replaceOpWithNewOp<OpTy>(
         insertSliceOp, sourceInsertSliceOp.getSource(), insertSliceOp.getDest(),
-        getAsOpFoldResult(resolvedOffsets), resolvedSizes,
-        insertSliceOp.getMixedStrides());
-
+        newOffsets, newSizes, newStrides);
     return success();
   }
 };
diff --git a/mlir/test/Dialect/Tensor/fold-tensor-subset-ops.mlir b/mlir/test/Dialect/Tensor/fold-tensor-subset-ops.mlir
index cf8711eb64ab9..45937e94f08ff 100644
--- a/mlir/test/Dialect/Tensor/fold-tensor-subset-ops.mlir
+++ b/mlir/test/Dialect/Tensor/fold-tensor-subset-ops.mlir
@@ -305,6 +305,25 @@ func.func @insert_slice_of_insert_slice(%t: tensor<f32>, %r0: tensor<1xf32>, %r1
 
 // -----
 
+// CHECK-LABEL:   func.func @insert_slice_of_insert_slice_non_unit_stride(
+// CHECK-SAME:      %[[t:.*]]: tensor<f32>,
+// CHECK-SAME:      %[[r0:.*]]: tensor<1xf32>,
+// CHECK-SAME:      %[[r1:.*]]: tensor<1x14xf32>,
+// CHECK-SAME:      %[[pos:.*]]: index) -> tensor<1x14xf32> {
+// CHECK:           tensor.insert_slice %[[t]] into %[[r1]][0, %[[pos]]] [1, 1] [1, 2] : tensor<f32> into tensor<1x14xf32>
+func.func @insert_slice_of_insert_slice_non_unit_stride(
+  %t: tensor<f32>, %r0: tensor<1xf32>, %r1: tensor<1x14xf32>, %pos: index)
+    -> tensor<1x14xf32> 
+{
+  %0 = tensor.insert_slice %t into %r0[0] [1] [1] 
+    : tensor<f32> into tensor<1xf32>
+  %1 = tensor.insert_slice %0 into %r1[0, %pos] [1, 1] [1, 2] 
+    : tensor<1xf32> into tensor<1x14xf32>
+  return %1 : tensor<1x14xf32>
+}
+
+// -----
+
 // This test fails to fold because the size `4` and `%pos` do not match: 
 // this requires a copy
 // CHECK-LABEL: func @fail_insert_slice_of_insert_slice(
@@ -324,21 +343,21 @@ func.func @fail_insert_slice_of_insert_slice(
 // -----
 
 // Here the sizes are the same and the folding occurs properly.
-//       CHECK: #[[$map:.*]] = affine_map<()[s0] -> (s0 * 2)>
+//       CHECK: #[[$map:.*]] = affine_map<()[s0] -> (s0 + s0 * s0)>
 // CHECK-LABEL: func @insert_slice_of_insert_slice_dynamic(
 //  CHECK-SAME:     %[[t:[0-9a-z]*]]: tensor<?xf32>
 //  CHECK-SAME:     %[[r0:[0-9a-z]*]]: tensor<?xf32>
 //  CHECK-SAME:     %[[r1:[0-9a-z]*]]: tensor<?x?xf32>
 //  CHECK-SAME:     %[[pos:[0-9a-z]*]]: index
-//       CHECK:   %[[add:.*]] = affine.apply #[[$map]]()[%[[pos]]]
-//       CHECK:   tensor.insert_slice %[[t]] into %[[r1]][%[[add]], 423] [%[[pos]], 1] [1, 1] : tensor<?xf32> into tensor<?x?xf32>
+//       CHECK:   %[[offset:.*]] = affine.apply #[[$map]]()[%[[pos]]]
+//       CHECK:   tensor.insert_slice %[[t]] into %[[r1]][%[[offset]], 423] [%[[pos]], 1] [%[[pos]], 1] : tensor<?xf32> into tensor<?x?xf32>
 func.func @insert_slice_of_insert_slice_dynamic(
   %t: tensor<?xf32>, %r0: tensor<?xf32>, %r1: tensor<?x?xf32>, %pos: index)
     -> tensor<?x?xf32> 
 {
   %0 = tensor.insert_slice %t into %r0[%pos] [%pos] [1] 
     : tensor<?xf32> into tensor<?xf32>
-  %1 = tensor.insert_slice %0 into %r1[%pos, 423] [%pos, 1] [1, 1] 
+  %1 = tensor.insert_slice %0 into %r1[%pos, 423] [%pos, 1] [%pos, 1] 
     : tensor<?xf32> into tensor<?x?xf32>
   return %1 : tensor<?x?xf32>
 }
@@ -385,7 +404,7 @@ func.func @parallel_insert_slice_of_insert_slice_dynamic(
     %tt2 = "make_me_another_tensor"() : () -> tensor<?x?xf32>
     %inserted_slice = tensor.insert_slice %tt into %tt2[%o1, 0] [%sz0, %sz1] [1, 1] : tensor<?x?xf32> into tensor<?x?xf32>
 
-    //      CHECK: %[[add:.*]] = affine.apply #[[$map]]()[%[[o0]], %[[o1]]]
+    //      CHECK: %[[add:.*]] = affine.apply #[[$map]]()[%[[o1]], %[[o0]]]
     //      CHECK: scf.forall.in_parallel
     //      CHECK:   tensor.parallel_insert_slice %[[tt]] into %[[out]][%[[add]], %[[o1]]] [%[[sz0]], %[[sz1]]] [1, 1]
     // CHECK-SAME:     : tensor<?x?xf32> into tensor<12x34xf32>

``````````

</details>


https://github.com/llvm/llvm-project/pull/192600


More information about the Mlir-commits mailing list