[Mlir-commits] [mlir] Revert "[mlir][tensor] Forward concat insert_slice destination into DPS provider" (PR #190143)

Dhruv Chauhan llvmlistbot at llvm.org
Thu Apr 2 03:28:00 PDT 2026


https://github.com/dchauhan-arm updated https://github.com/llvm/llvm-project/pull/190143

>From de05fdfaf6d869461fe37614c8897e811192b87d Mon Sep 17 00:00:00 2001
From: Dhruv Chauhan <dhruv.chauhan at arm.com>
Date: Thu, 2 Apr 2026 11:09:17 +0100
Subject: [PATCH] =?UTF-8?q?Revert=20"[mlir][tensor]=20Forward=20concat=20i?=
 =?UTF-8?q?nsert=5Fslice=20destination=20into=20DPS=20provi=E2=80=A6"?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

This reverts commit 1418f80d5c4d1f54e2f11f0650818c6c602aa505.

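The change can cause an infinite rewrite loop when
ForwardConcatInsertSliceDest interacts with
FoldEmptyTensorWithExtractSliceOp. When the insert_slice destination is a
tensor.empty, the tensor.extract_slice that ForwardConcatInsertSliceDest
creates on it is folded back into a smaller tensor.empty. The producer's
init is then no longer recognized as an extract_slice of the destination,
so ForwardConcatInsertSliceDest matches and rewrites again, indefinitely.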
---
 .../Dialect/Tensor/Transforms/Transforms.h    |  5 -
 .../Tensor/Transforms/ConcatOpPatterns.cpp    | 98 -------------------
 .../Tensor/Transforms/FoldTensorSubsetOps.cpp |  1 -
 .../Tensor/fold-tensor-subset-ops.mlir        | 25 -----
 4 files changed, 129 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Tensor/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Tensor/Transforms/Transforms.h
index 3db9f5c542516..093393eca7436 100644
--- a/mlir/include/mlir/Dialect/Tensor/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Tensor/Transforms/Transforms.h
@@ -96,11 +96,6 @@ void populateFoldTensorEmptyPatterns(RewritePatternSet &patterns,
 /// that it can be bufferized into a sequence of copies.
 void populateDecomposeTensorConcatPatterns(RewritePatternSet &patterns);
 
-/// Populates `patterns` with patterns that forward concat-generated
-/// `tensor.insert_slice` destinations into single-use destination-style source
-/// producers.
-void populateForwardConcatInsertSliceDestPatterns(RewritePatternSet &patterns);
-
 using ControlFoldFn = std::function<bool(OpOperand *)>;
 
 /// Populates `patterns` with patterns that replace tensor ops (such as
diff --git a/mlir/lib/Dialect/Tensor/Transforms/ConcatOpPatterns.cpp b/mlir/lib/Dialect/Tensor/Transforms/ConcatOpPatterns.cpp
index e164fd7d60983..20bed05ecc11d 100644
--- a/mlir/lib/Dialect/Tensor/Transforms/ConcatOpPatterns.cpp
+++ b/mlir/lib/Dialect/Tensor/Transforms/ConcatOpPatterns.cpp
@@ -8,7 +8,6 @@
 
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "mlir/Dialect/Tensor/Transforms/Transforms.h"
-#include "mlir/IR/IRMapping.h"
 #include "mlir/IR/PatternMatch.h"
 
 using namespace mlir;
@@ -42,106 +41,9 @@ struct DecomposeTensorConcatOp : public OpRewritePattern<ConcatOp> {
   }
 };
 
-/// Forward the destination tensor of concat generated tensor.insert_slice ops
-/// into single-use destination-style tensor producers. This avoids creating a
-/// producer on a temporary tensor that is immediately copied into the concat
-/// result tensor.
-///
-/// Before:
-/// %small = tensor.empty() : tensor<4xf32>
-/// %fill = linalg.fill ins(%cst : f32) outs(%small : tensor<4xf32>)
-///     -> tensor<4xf32>
-/// %init = tensor.empty() : tensor<8xf32>
-/// %insert0 = tensor.insert_slice %fill into %init[0] [4] [1]
-///     : tensor<4xf32> into tensor<8xf32>
-/// %insert1 = tensor.insert_slice %arg0 into %insert0[4] [4] [1]
-///     : tensor<4xf32> into tensor<8xf32>
-///
-/// After:
-/// %init = tensor.empty() : tensor<8xf32>
-/// %slice = tensor.extract_slice %init[0] [4] [1]
-///     : tensor<8xf32> to tensor<4xf32>
-/// %fill = linalg.fill ins(%cst : f32) outs(%slice : tensor<4xf32>)
-///     -> tensor<4xf32>
-/// %insert0 = tensor.insert_slice %fill into %init[0] [4] [1]
-///     : tensor<4xf32> into tensor<8xf32>
-/// %insert1 = tensor.insert_slice %arg0 into %insert0[4] [4] [1]
-///     : tensor<4xf32> into tensor<8xf32>
-struct ForwardConcatInsertSliceDest : public OpRewritePattern<InsertSliceOp> {
-  using OpRewritePattern<InsertSliceOp>::OpRewritePattern;
-
-  LogicalResult matchAndRewrite(InsertSliceOp insertOp,
-                                PatternRewriter &rewriter) const override {
-    // Only rewrite when the insert source is an SSA result with a single use.
-    Value source = insertOp.getSource();
-    auto sourceResult = dyn_cast<OpResult>(source);
-    if (!sourceResult || !source.hasOneUse())
-      return failure();
-
-    // Restrict to concat-style insert chains where the destination is either
-    // the initial tensor.empty or a previous tensor.insert_slice result.
-    Operation *destDef = insertOp.getDest().getDefiningOp();
-    if (!isa_and_present<EmptyOp, InsertSliceOp>(destDef))
-      return failure();
-
-    // The source producer must be destination-style on tensors so we can
-    // retarget its tied output to a slice of the final concat destination.
-    auto producer = source.getDefiningOp<DestinationStyleOpInterface>();
-    if (!producer || !producer.hasPureTensorSemantics())
-      return failure();
-
-    if (producer->getNumResults() != 1)
-      return failure();
-
-    OpOperand *tiedInit = producer.getTiedOpOperand(sourceResult);
-    if (!tiedInit)
-      return failure();
-
-    auto sourceType = dyn_cast<RankedTensorType>(source.getType());
-    if (!sourceType || !isa<RankedTensorType>(insertOp.getDest().getType()))
-      return failure();
-
-    auto mixedOffsets = insertOp.getMixedOffsets();
-    auto mixedSizes = insertOp.getMixedSizes();
-    auto mixedStrides = insertOp.getMixedStrides();
-
-    auto extractedInit = tiedInit->get().getDefiningOp<ExtractSliceOp>();
-    if (extractedInit && extractedInit.getSource() == insertOp.getDest() &&
-        llvm::equal(extractedInit.getMixedOffsets(), mixedOffsets) &&
-        llvm::equal(extractedInit.getMixedSizes(), mixedSizes) &&
-        llvm::equal(extractedInit.getMixedStrides(), mixedStrides)) {
-      return failure();
-    }
-
-    // Extract slice from the final destination
-    Value extractedDest = ExtractSliceOp::create(
-        rewriter, insertOp.getLoc(), sourceType, insertOp.getDest(),
-        mixedOffsets, mixedSizes, mixedStrides);
-
-    IRMapping mapping;
-    mapping.map(tiedInit->get(), extractedDest);
-    Operation *newProducer = rewriter.clone(*producer, mapping);
-    Value newSource = newProducer->getResult(sourceResult.getResultNumber());
-
-    // Rebuild insert_slice with the retargeted producer result, then erase the
-    // original producer (guaranteed to have a single use)
-    Value newInsert = InsertSliceOp::create(
-        rewriter, insertOp.getLoc(), newSource, insertOp.getDest(),
-        mixedOffsets, mixedSizes, mixedStrides);
-    rewriter.replaceOp(insertOp, newInsert);
-    rewriter.eraseOp(producer.getOperation());
-    return success();
-  }
-};
-
 } // namespace
 
 void mlir::tensor::populateDecomposeTensorConcatPatterns(
     RewritePatternSet &patterns) {
   patterns.add<DecomposeTensorConcatOp>(patterns.getContext());
 }
-
-void mlir::tensor::populateForwardConcatInsertSliceDestPatterns(
-    RewritePatternSet &patterns) {
-  patterns.add<ForwardConcatInsertSliceDest>(patterns.getContext());
-}
diff --git a/mlir/lib/Dialect/Tensor/Transforms/FoldTensorSubsetOps.cpp b/mlir/lib/Dialect/Tensor/Transforms/FoldTensorSubsetOps.cpp
index 65b3bf27f0ae4..b32faf481af80 100644
--- a/mlir/lib/Dialect/Tensor/Transforms/FoldTensorSubsetOps.cpp
+++ b/mlir/lib/Dialect/Tensor/Transforms/FoldTensorSubsetOps.cpp
@@ -246,7 +246,6 @@ struct InsertSliceOfInsertSliceFolder : public OpRewritePattern<OpTy> {
 
 void tensor::populateFoldTensorSubsetOpPatterns(RewritePatternSet &patterns) {
   populateFoldTensorSubsetIntoVectorTransferPatterns(patterns);
-  populateForwardConcatInsertSliceDestPatterns(patterns);
   patterns.add<InsertSliceOfInsertSliceFolder<tensor::InsertSliceOp>,
                InsertSliceOfInsertSliceFolder<tensor::ParallelInsertSliceOp>>(
       patterns.getContext());
diff --git a/mlir/test/Dialect/Tensor/fold-tensor-subset-ops.mlir b/mlir/test/Dialect/Tensor/fold-tensor-subset-ops.mlir
index 724db05ccfa8c..cf8711eb64ab9 100644
--- a/mlir/test/Dialect/Tensor/fold-tensor-subset-ops.mlir
+++ b/mlir/test/Dialect/Tensor/fold-tensor-subset-ops.mlir
@@ -345,31 +345,6 @@ func.func @insert_slice_of_insert_slice_dynamic(
 
 // -----
 
-// CHECK-LABEL: func.func @forward_concat_insert_slice_dest
-// CHECK-SAME: (%[[ARG0:.*]]: tensor<4xf32>)
-func.func @forward_concat_insert_slice_dest(%arg0: tensor<4xf32>)
-    -> tensor<8xf32> {
-  %cst = arith.constant 1.000000e+00 : f32
-  %small = tensor.empty() : tensor<4xf32>
-  %fill = linalg.fill ins(%cst : f32) outs(%small : tensor<4xf32>)
-      -> tensor<4xf32>
-  %init = tensor.empty() : tensor<8xf32>
-  %insert0 = tensor.insert_slice %fill into %init[0] [4] [1]
-      : tensor<4xf32> into tensor<8xf32>
-  %insert1 = tensor.insert_slice %arg0 into %insert0[4] [4] [1]
-      : tensor<4xf32> into tensor<8xf32>
-  return %insert1 : tensor<8xf32>
-}
-// CHECK-DAG: %[[CST:.*]] = arith.constant 1.000000e+00 : f32
-// CHECK: %[[INIT:.*]] = tensor.empty() : tensor<8xf32>
-// CHECK: %[[SLICE:.*]] = tensor.extract_slice %[[INIT]][0] [4] [1] : tensor<8xf32> to tensor<4xf32>
-// CHECK: %[[FILL:.*]] = linalg.fill ins(%[[CST]] : f32) outs(%[[SLICE]] : tensor<4xf32>) -> tensor<4xf32>
-// CHECK: %[[INSERT0:.*]] = tensor.insert_slice %[[FILL]] into %[[INIT]][0] [4] [1] : tensor<4xf32> into tensor<8xf32>
-// CHECK: %[[INSERT1:.*]] = tensor.insert_slice %[[ARG0]] into %[[INSERT0]][4] [4] [1] : tensor<4xf32> into tensor<8xf32>
-// CHECK: return %[[INSERT1]] : tensor<8xf32>
-
-// -----
-
 // Here the sizes are the same and the folding occurs properly.
 //       CHECK: #[[$map:.*]] = affine_map<()[s0] -> (s0 * 2)>
 // CHECK-LABEL: func @insert_slice_of_insert_slice_dynamic(



More information about the Mlir-commits mailing list