[Mlir-commits] [mlir] [mlir][tensor] Forward concat insert_slice destination into DPS provider (PR #183490)
Dhruv Chauhan
llvmlistbot at llvm.org
Tue Mar 17 08:32:43 PDT 2026
https://github.com/dchauhan-arm updated https://github.com/llvm/llvm-project/pull/183490
>From f52be5d4cd71cfe78f36ea4ef75f0138cd311406 Mon Sep 17 00:00:00 2001
From: Dhruv Chauhan <dhruv.chauhan at arm.com>
Date: Thu, 26 Feb 2026 09:57:33 +0000
Subject: [PATCH] [mlir][tensor] Forward concat insert_slice destination into
DPS provider
Implement concat insert_slice destination forwarding as a Tensor rewrite
pattern and apply it from the tensor subset folding pass.
The pattern forwards concat-generated `insert_slice` destinations into
single-use destination-style producers, avoiding producer results that
are immediately copied into the concat result tensor.
Add a Tensor dialect regression lit test that checks the forwarded
`tensor.extract_slice + linalg.fill` shape.
---
.../Dialect/Tensor/Transforms/Transforms.h | 5 +
.../Tensor/Transforms/ConcatOpPatterns.cpp | 98 +++++++++++++++++++
.../Tensor/Transforms/FoldTensorSubsetOps.cpp | 1 +
.../Tensor/fold-tensor-subset-ops.mlir | 25 +++++
4 files changed, 129 insertions(+)
diff --git a/mlir/include/mlir/Dialect/Tensor/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Tensor/Transforms/Transforms.h
index 093393eca7436..3db9f5c542516 100644
--- a/mlir/include/mlir/Dialect/Tensor/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Tensor/Transforms/Transforms.h
@@ -96,6 +96,11 @@ void populateFoldTensorEmptyPatterns(RewritePatternSet &patterns,
/// that it can be bufferized into a sequence of copies.
void populateDecomposeTensorConcatPatterns(RewritePatternSet &patterns);
+/// Populates `patterns` with patterns that forward concat-generated
+/// `tensor.insert_slice` destinations into single-use destination-style source
+/// producers: the producer's tied init is rewritten to a `tensor.extract_slice`
+/// of the insert destination, so the producer no longer computes into a
+/// temporary tensor that is then copied into the concat result.
+void populateForwardConcatInsertSliceDestPatterns(RewritePatternSet &patterns);
+
using ControlFoldFn = std::function<bool(OpOperand *)>;
/// Populates `patterns` with patterns that replace tensor ops (such as
diff --git a/mlir/lib/Dialect/Tensor/Transforms/ConcatOpPatterns.cpp b/mlir/lib/Dialect/Tensor/Transforms/ConcatOpPatterns.cpp
index 20bed05ecc11d..e164fd7d60983 100644
--- a/mlir/lib/Dialect/Tensor/Transforms/ConcatOpPatterns.cpp
+++ b/mlir/lib/Dialect/Tensor/Transforms/ConcatOpPatterns.cpp
@@ -8,6 +8,7 @@
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tensor/Transforms/Transforms.h"
+#include "mlir/IR/IRMapping.h"
#include "mlir/IR/PatternMatch.h"
using namespace mlir;
@@ -41,9 +42,106 @@ struct DecomposeTensorConcatOp : public OpRewritePattern<ConcatOp> {
}
};
+/// Forward the destination tensor of concat-generated tensor.insert_slice ops
+/// into single-use destination-style tensor producers. This avoids creating a
+/// producer on a temporary tensor that is immediately copied into the concat
+/// result tensor.
+///
+/// The producer's tied init is only replaced when it is a `tensor.empty`. A
+/// producer may read its init (e.g. an op accumulating into a previously
+/// filled tensor); substituting a slice of the concat destination for such an
+/// init would change the computed values.
+///
+/// NOTE(review): if this pattern is combined with a pattern that folds
+/// `extract_slice(tensor.empty)` back into a fresh `tensor.empty`, the two
+/// could ping-pong; confirm before adding it to such pattern sets.
+///
+/// Before:
+///   %small = tensor.empty() : tensor<4xf32>
+///   %fill = linalg.fill ins(%cst : f32) outs(%small : tensor<4xf32>)
+///       -> tensor<4xf32>
+///   %init = tensor.empty() : tensor<8xf32>
+///   %insert0 = tensor.insert_slice %fill into %init[0] [4] [1]
+///       : tensor<4xf32> into tensor<8xf32>
+///   %insert1 = tensor.insert_slice %arg0 into %insert0[4] [4] [1]
+///       : tensor<4xf32> into tensor<8xf32>
+///
+/// After:
+///   %init = tensor.empty() : tensor<8xf32>
+///   %slice = tensor.extract_slice %init[0] [4] [1]
+///       : tensor<8xf32> to tensor<4xf32>
+///   %fill = linalg.fill ins(%cst : f32) outs(%slice : tensor<4xf32>)
+///       -> tensor<4xf32>
+///   %insert0 = tensor.insert_slice %fill into %init[0] [4] [1]
+///       : tensor<4xf32> into tensor<8xf32>
+///   %insert1 = tensor.insert_slice %arg0 into %insert0[4] [4] [1]
+///       : tensor<4xf32> into tensor<8xf32>
+struct ForwardConcatInsertSliceDest : public OpRewritePattern<InsertSliceOp> {
+  using OpRewritePattern<InsertSliceOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(InsertSliceOp insertOp,
+                                PatternRewriter &rewriter) const override {
+    // Only rewrite when the insert source is an SSA result with a single use.
+    // The producer then feeds only this insert, so it can be moved and
+    // retargeted without affecting any other user.
+    Value source = insertOp.getSource();
+    auto sourceResult = dyn_cast<OpResult>(source);
+    if (!sourceResult || !source.hasOneUse())
+      return failure();
+
+    // Restrict to concat-style insert chains where the destination is either
+    // the initial tensor.empty or a previous tensor.insert_slice result.
+    Operation *destDef = insertOp.getDest().getDefiningOp();
+    if (!isa_and_present<EmptyOp, InsertSliceOp>(destDef))
+      return failure();
+
+    // The source producer must be a single-result destination-style op on
+    // tensors so that its tied init can be retargeted to a destination slice.
+    auto producer = source.getDefiningOp<DestinationStyleOpInterface>();
+    if (!producer || !producer.hasPureTensorSemantics() ||
+        producer->getNumResults() != 1)
+      return failure();
+
+    // The producer is moved next to the insert below. Keep it within its own
+    // block so it is never sunk into a nested region (e.g. a loop body), which
+    // could repeat its work.
+    if (producer->getBlock() != insertOp->getBlock())
+      return failure();
+
+    OpOperand *tiedInit = producer.getTiedOpOperand(sourceResult);
+    if (!tiedInit)
+      return failure();
+
+    // Only a tensor.empty init may be replaced: its contents are unspecified,
+    // so the producer cannot depend on them. Any other init (e.g. a filled
+    // accumulator) may be read by the producer, and swapping it for a slice of
+    // the destination would change the result. This check also stops the
+    // pattern from re-matching an already rewritten producer, whose init is
+    // then an extract_slice.
+    if (!tiedInit->get().getDefiningOp<EmptyOp>())
+      return failure();
+
+    // Extract the part of the destination that the insert overwrites. The
+    // source type is used as the result type, so the tied init keeps its type.
+    Value extractedDest = ExtractSliceOp::create(
+        rewriter, insertOp.getLoc(), insertOp.getSourceType(),
+        insertOp.getDest(), insertOp.getMixedOffsets(),
+        insertOp.getMixedSizes(), insertOp.getMixedStrides());
+
+    // The destination may be defined after the producer, so move the producer
+    // right before the insert (after the new extract_slice) to preserve
+    // dominance. Then update only the tied init operand; other operands that
+    // happen to use the same SSA value are left untouched. The insert keeps
+    // consuming the same producer result, so it needs no rewrite.
+    Operation *producerOp = producer.getOperation();
+    rewriter.moveOpBefore(producerOp, insertOp.getOperation());
+    rewriter.modifyOpInPlace(producerOp,
+                             [&]() { tiedInit->set(extractedDest); });
+    return success();
+  }
+};
+
} // namespace
/// Registers DecomposeTensorConcatOp in `patterns`.
void mlir::tensor::populateDecomposeTensorConcatPatterns(
RewritePatternSet &patterns) {
patterns.add<DecomposeTensorConcatOp>(patterns.getContext());
}
+
+void mlir::tensor::populateForwardConcatInsertSliceDestPatterns(
+    RewritePatternSet &patterns) {
+  // Register the concat insert_slice destination-forwarding rewrite using the
+  // pattern set's own context.
+  MLIRContext *context = patterns.getContext();
+  patterns.add<ForwardConcatInsertSliceDest>(context);
+}
diff --git a/mlir/lib/Dialect/Tensor/Transforms/FoldTensorSubsetOps.cpp b/mlir/lib/Dialect/Tensor/Transforms/FoldTensorSubsetOps.cpp
index b32faf481af80..65b3bf27f0ae4 100644
--- a/mlir/lib/Dialect/Tensor/Transforms/FoldTensorSubsetOps.cpp
+++ b/mlir/lib/Dialect/Tensor/Transforms/FoldTensorSubsetOps.cpp
@@ -246,6 +246,7 @@ struct InsertSliceOfInsertSliceFolder : public OpRewritePattern<OpTy> {
void tensor::populateFoldTensorSubsetOpPatterns(RewritePatternSet &patterns) {
populateFoldTensorSubsetIntoVectorTransferPatterns(patterns);
+ populateForwardConcatInsertSliceDestPatterns(patterns);
patterns.add<InsertSliceOfInsertSliceFolder<tensor::InsertSliceOp>,
InsertSliceOfInsertSliceFolder<tensor::ParallelInsertSliceOp>>(
patterns.getContext());
diff --git a/mlir/test/Dialect/Tensor/fold-tensor-subset-ops.mlir b/mlir/test/Dialect/Tensor/fold-tensor-subset-ops.mlir
index cf8711eb64ab9..724db05ccfa8c 100644
--- a/mlir/test/Dialect/Tensor/fold-tensor-subset-ops.mlir
+++ b/mlir/test/Dialect/Tensor/fold-tensor-subset-ops.mlir
@@ -345,6 +345,31 @@ func.func @insert_slice_of_insert_slice_dynamic(
// -----
+// The fill on the temporary tensor<4xf32> is retargeted to a slice of the
+// concat destination, so the temporary tensor.empty becomes dead and is gone.
+// CHECK-LABEL: func.func @forward_concat_insert_slice_dest
+// CHECK-SAME: (%[[ARG0:.*]]: tensor<4xf32>)
+func.func @forward_concat_insert_slice_dest(%arg0: tensor<4xf32>)
+    -> tensor<8xf32> {
+  %cst = arith.constant 1.000000e+00 : f32
+  %small = tensor.empty() : tensor<4xf32>
+  %fill = linalg.fill ins(%cst : f32) outs(%small : tensor<4xf32>)
+      -> tensor<4xf32>
+  %init = tensor.empty() : tensor<8xf32>
+  %insert0 = tensor.insert_slice %fill into %init[0] [4] [1]
+      : tensor<4xf32> into tensor<8xf32>
+  %insert1 = tensor.insert_slice %arg0 into %insert0[4] [4] [1]
+      : tensor<4xf32> into tensor<8xf32>
+  return %insert1 : tensor<8xf32>
+}
+// CHECK-DAG: %[[CST:.*]] = arith.constant 1.000000e+00 : f32
+// CHECK-NOT: tensor.empty() : tensor<4xf32>
+// CHECK: %[[INIT:.*]] = tensor.empty() : tensor<8xf32>
+// CHECK: %[[SLICE:.*]] = tensor.extract_slice %[[INIT]][0] [4] [1] : tensor<8xf32> to tensor<4xf32>
+// CHECK: %[[FILL:.*]] = linalg.fill ins(%[[CST]] : f32) outs(%[[SLICE]] : tensor<4xf32>) -> tensor<4xf32>
+// CHECK: %[[INSERT0:.*]] = tensor.insert_slice %[[FILL]] into %[[INIT]][0] [4] [1] : tensor<4xf32> into tensor<8xf32>
+// CHECK: %[[INSERT1:.*]] = tensor.insert_slice %[[ARG0]] into %[[INSERT0]][4] [4] [1] : tensor<4xf32> into tensor<8xf32>
+// CHECK: return %[[INSERT1]] : tensor<8xf32>
+
+// -----
+
// Here the sizes are the same and the folding occurs properly.
// CHECK: #[[$map:.*]] = affine_map<()[s0] -> (s0 * 2)>
// CHECK-LABEL: func @insert_slice_of_insert_slice_dynamic(
More information about the Mlir-commits
mailing list