[Mlir-commits] [mlir] [mlir][tosa] Forward concat insert_slice destination into DPS provider (PR #183490)

Dhruv Chauhan llvmlistbot at llvm.org
Thu Mar 5 05:58:56 PST 2026


https://github.com/dchauhan-arm updated https://github.com/llvm/llvm-project/pull/183490

>From c6168a16cd8fbbd4a7527e5e224e732c12c2c2c5 Mon Sep 17 00:00:00 2001
From: Dhruv Chauhan <dhruv.chauhan at arm.com>
Date: Thu, 26 Feb 2026 09:57:33 +0000
Subject: [PATCH] [mlir][tosa] Forward concat insert_slice destination into DPS
 provider

Implement concat insert_slice destination forwarding as a Tensor rewrite
pattern. The pattern forwards concat generated insert_slice destinations
into single use destination style producers, avoiding producer results
that are immediately copied into the concat result tensor.
---
 .../Dialect/Tensor/Transforms/Transforms.h    |  5 ++
 .../Conversion/TosaToTensor/CMakeLists.txt    |  1 +
 .../Conversion/TosaToTensor/TosaToTensor.cpp  |  2 +
 .../Tensor/Transforms/ConcatOpPatterns.cpp    | 90 +++++++++++++++++++
 .../TosaToTensor/tosa-to-tensor.mlir          | 19 ++++
 5 files changed, 117 insertions(+)

diff --git a/mlir/include/mlir/Dialect/Tensor/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Tensor/Transforms/Transforms.h
index 093393eca7436..3db9f5c542516 100644
--- a/mlir/include/mlir/Dialect/Tensor/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Tensor/Transforms/Transforms.h
@@ -96,6 +96,11 @@ void populateFoldTensorEmptyPatterns(RewritePatternSet &patterns,
 /// that it can be bufferized into a sequence of copies.
 void populateDecomposeTensorConcatPatterns(RewritePatternSet &patterns);
 
+/// Populates `patterns` with patterns that forward concat-generated
+/// `tensor.insert_slice` destinations into single-use destination-style source
+/// producers.
+void populateForwardConcatInsertSliceDestPatterns(RewritePatternSet &patterns);
+
 using ControlFoldFn = std::function<bool(OpOperand *)>;
 
 /// Populates `patterns` with patterns that replace tensor ops (such as
diff --git a/mlir/lib/Conversion/TosaToTensor/CMakeLists.txt b/mlir/lib/Conversion/TosaToTensor/CMakeLists.txt
index 2870baa20757b..568f44c52315b 100644
--- a/mlir/lib/Conversion/TosaToTensor/CMakeLists.txt
+++ b/mlir/lib/Conversion/TosaToTensor/CMakeLists.txt
@@ -11,6 +11,7 @@ add_mlir_conversion_library(MLIRTosaToTensor
 
   LINK_LIBS PUBLIC
   MLIRTensorDialect
+  MLIRTensorTransforms
   MLIRTensorUtils
   MLIRIR
   MLIRPass
diff --git a/mlir/lib/Conversion/TosaToTensor/TosaToTensor.cpp b/mlir/lib/Conversion/TosaToTensor/TosaToTensor.cpp
index 9bf9ca3ae7a89..9bd0e15f15a50 100644
--- a/mlir/lib/Conversion/TosaToTensor/TosaToTensor.cpp
+++ b/mlir/lib/Conversion/TosaToTensor/TosaToTensor.cpp
@@ -14,6 +14,7 @@
 #include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/Arith/Utils/Utils.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Dialect/Tensor/Transforms/Transforms.h"
 #include "mlir/Dialect/Tosa/IR/TosaOps.h"
 #include "mlir/Dialect/Tosa/Utils/ConversionUtils.h"
 #include "mlir/IR/PatternMatch.h"
@@ -461,4 +462,5 @@ void mlir::tosa::populateTosaToTensorConversionPatterns(
   patterns
       ->add<ConcatConverter, PadConverter, ReshapeConverter, SliceConverter>(
           converter, patterns->getContext());
+  tensor::populateForwardConcatInsertSliceDestPatterns(*patterns);
 }
diff --git a/mlir/lib/Dialect/Tensor/Transforms/ConcatOpPatterns.cpp b/mlir/lib/Dialect/Tensor/Transforms/ConcatOpPatterns.cpp
index 20bed05ecc11d..e7f3422c98cf0 100644
--- a/mlir/lib/Dialect/Tensor/Transforms/ConcatOpPatterns.cpp
+++ b/mlir/lib/Dialect/Tensor/Transforms/ConcatOpPatterns.cpp
@@ -8,6 +8,7 @@
 
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "mlir/Dialect/Tensor/Transforms/Transforms.h"
+#include "mlir/IR/IRMapping.h"
 #include "mlir/IR/PatternMatch.h"
 
 using namespace mlir;
@@ -41,9 +42,98 @@ struct DecomposeTensorConcatOp : public OpRewritePattern<ConcatOp> {
   }
 };
 
+/// Forward the destination tensor of concat generated tensor.insert_slice ops
+/// into single-use destination-style tensor producers. This avoids creating a
+/// producer on a temporary tensor that is immediately copied into the concat
+/// result tensor.
+///
+/// Before:
+/// %small = tensor.empty() : tensor<4xf32>
+/// %fill = linalg.fill ins(%cst : f32) outs(%small : tensor<4xf32>)
+///     -> tensor<4xf32>
+/// %init = tensor.empty() : tensor<8xf32>
+/// %insert0 = tensor.insert_slice %fill into %init[0] [4] [1]
+///     : tensor<4xf32> into tensor<8xf32>
+/// %insert1 = tensor.insert_slice %arg0 into %insert0[4] [4] [1]
+///     : tensor<4xf32> into tensor<8xf32>
+///
+/// After:
+/// %init = tensor.empty() : tensor<8xf32>
+/// %slice = tensor.extract_slice %init[0] [4] [1]
+///     : tensor<8xf32> to tensor<4xf32>
+/// %fill = linalg.fill ins(%cst : f32) outs(%slice : tensor<4xf32>)
+///     -> tensor<4xf32>
+/// %insert0 = tensor.insert_slice %fill into %init[0] [4] [1]
+///     : tensor<4xf32> into tensor<8xf32>
+/// %insert1 = tensor.insert_slice %arg0 into %insert0[4] [4] [1]
+///     : tensor<4xf32> into tensor<8xf32>
+struct ForwardConcatInsertSliceDest : public OpRewritePattern<InsertSliceOp> {
+  using OpRewritePattern<InsertSliceOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(InsertSliceOp insertOp,
+                                PatternRewriter &rewriter) const override {
+    // Only rewrite when the insert source is an SSA result with a single use.
+    Value source = insertOp.getSource();
+    auto sourceResult = dyn_cast<OpResult>(source);
+    if (!sourceResult || !source.hasOneUse())
+      return failure();
+
+    // Restrict to concat-style insert chains where the destination is either
+    // the initial tensor.empty or a previous tensor.insert_slice result.
+    Operation *destDef = insertOp.getDest().getDefiningOp();
+    if (!isa_and_present<EmptyOp, InsertSliceOp>(destDef))
+      return failure();
+
+    // The source producer must be destination-style on tensors so we can
+    // retarget its tied output to a slice of the final concat destination.
+    auto producer = source.getDefiningOp<DestinationStyleOpInterface>();
+    if (!producer || !producer.hasPureTensorSemantics())
+      return failure();
+
+    if (producer->getNumResults() != 1)
+      return failure();
+
+    OpOperand *tiedInit = producer.getTiedOpOperand(sourceResult);
+    if (!tiedInit)
+      return failure();
+
+    auto sourceType = dyn_cast<RankedTensorType>(source.getType());
+    if (!sourceType || !isa<RankedTensorType>(insertOp.getDest().getType()))
+      return failure();
+
+    auto mixedOffsets = insertOp.getMixedOffsets();
+    auto mixedSizes = insertOp.getMixedSizes();
+    auto mixedStrides = insertOp.getMixedStrides();
+
+    // Extract slice from the final destination
+    Value extractedDest = ExtractSliceOp::create(
+        rewriter, insertOp.getLoc(), sourceType, insertOp.getDest(),
+        mixedOffsets, mixedSizes, mixedStrides);
+
+    IRMapping mapping;
+    mapping.map(tiedInit->get(), extractedDest);
+    Operation *newProducer = rewriter.clone(*producer, mapping);
+    Value newSource = newProducer->getResult(sourceResult.getResultNumber());
+
+    // Rebuild insert_slice with the retargeted producer result, then erase the
+    // original producer (guaranteed to have a single use)
+    Value newInsert = InsertSliceOp::create(
+        rewriter, insertOp.getLoc(), newSource, insertOp.getDest(),
+        mixedOffsets, mixedSizes, mixedStrides);
+    rewriter.replaceOp(insertOp, newInsert);
+    rewriter.eraseOp(producer.getOperation());
+    return success();
+  }
+};
+
 } // namespace
 
 void mlir::tensor::populateDecomposeTensorConcatPatterns(
     RewritePatternSet &patterns) {
   patterns.add<DecomposeTensorConcatOp>(patterns.getContext());
 }
+
+void mlir::tensor::populateForwardConcatInsertSliceDestPatterns(
+    RewritePatternSet &patterns) {
+  patterns.add<ForwardConcatInsertSliceDest>(patterns.getContext());
+}
diff --git a/mlir/test/Conversion/TosaToTensor/tosa-to-tensor.mlir b/mlir/test/Conversion/TosaToTensor/tosa-to-tensor.mlir
index 0a276e2a5c3d2..a0695bf127b1d 100644
--- a/mlir/test/Conversion/TosaToTensor/tosa-to-tensor.mlir
+++ b/mlir/test/Conversion/TosaToTensor/tosa-to-tensor.mlir
@@ -703,6 +703,25 @@ func.func @concat_non_axis_dyn_mixed(%arg0: tensor<?x1xf32>, %arg1: tensor<?x1xf
 
 // -----
 
+// CHECK-LABEL: @concat_forward_insert_slice_dest
+// CHECK-SAME: (%[[ARG0:.*]]: tensor<4xf32>)
+func.func @concat_forward_insert_slice_dest(%arg0: tensor<4xf32>) -> tensor<8xf32> {
+  %cst = arith.constant 1.000000e+00 : f32
+  %init = tensor.empty() : tensor<4xf32>
+  %filled = linalg.fill ins(%cst : f32) outs(%init : tensor<4xf32>) -> tensor<4xf32>
+  %0 = "tosa.concat"(%filled, %arg0) {axis = 0 : i32} : (tensor<4xf32>, tensor<4xf32>) -> tensor<8xf32>
+  // CHECK-DAG: %[[CST:.*]] = arith.constant 1.000000e+00 : f32
+  // CHECK: %[[SMALL:.*]] = tensor.empty() : tensor<4xf32>
+  // CHECK: %[[FILL:.*]] = linalg.fill ins(%[[CST]] : f32) outs(%[[SMALL]] : tensor<4xf32>) -> tensor<4xf32>
+  // CHECK: %[[INIT:.*]] = tensor.empty() : tensor<8xf32>
+  // CHECK: %[[INSERT0:.*]] = tensor.insert_slice %[[FILL]] into %[[INIT]][0] [4] [1] : tensor<4xf32> into tensor<8xf32>
+  // CHECK: %[[INSERT1:.*]] = tensor.insert_slice %[[ARG0]] into %[[INSERT0]][4] [4] [1] : tensor<4xf32> into tensor<8xf32>
+  // CHECK: return %[[INSERT1]] : tensor<8xf32>
+  return %0 : tensor<8xf32>
+}
+
+// -----
+
 // CHECK-LABEL: func @pad_variable_pad_const
 // CHECK-SAME: (%[[ARG0_SSA:.*]]: tensor<2x2xi32>, %[[PAD_INPUT_TENSOR_SSA:.*]]: tensor<1xi32>)
 func.func @pad_variable_pad_const(%arg0: tensor<2x2xi32>, %pad_input_tensor: tensor<1xi32>) -> tensor<4x5xi32> {



More information about the Mlir-commits mailing list