[Mlir-commits] [mlir] [mlir][tosa] Forward concat insert_slice destination into DPS provider (PR #183490)

Dhruv Chauhan llvmlistbot at llvm.org
Thu Feb 26 02:56:02 PST 2026


https://github.com/dchauhan-arm created https://github.com/llvm/llvm-project/pull/183490

In TosaToTensor, forward the destination slices of concat-generated tensor.insert_slice ops into single-use destination-style (DPS) producers. This avoids creating temporary producer results that are immediately copied into the concat result. Add a regression test for concat + fill forwarding.

>From e7f56d22b4e8891bb7777ef3d809e22f0525ffef Mon Sep 17 00:00:00 2001
From: Dhruv Chauhan <dhruv.chauhan at arm.com>
Date: Thu, 26 Feb 2026 09:57:33 +0000
Subject: [PATCH] [mlir][tosa] Forward concat insert_slice destination into DPS
 provider

In TosaToTensor, forward the destination slices of concat-generated
tensor.insert_slice ops into single-use destination-style (DPS)
producers. This avoids creating temporary producer results that are
immediately copied into the concat result.
Add a regression test for concat + fill forwarding.
---
 .../Conversion/TosaToTensor/TosaToTensor.cpp  | 81 ++++++++++++++++++-
 .../TosaToTensor/tosa-to-tensor.mlir          | 22 ++++-
 2 files changed, 98 insertions(+), 5 deletions(-)

diff --git a/mlir/lib/Conversion/TosaToTensor/TosaToTensor.cpp b/mlir/lib/Conversion/TosaToTensor/TosaToTensor.cpp
index 9bf9ca3ae7a89..60d16ecc2d507 100644
--- a/mlir/lib/Conversion/TosaToTensor/TosaToTensor.cpp
+++ b/mlir/lib/Conversion/TosaToTensor/TosaToTensor.cpp
@@ -13,9 +13,11 @@
 #include "mlir/Conversion/TosaToTensor/TosaToTensor.h"
 #include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/Arith/Utils/Utils.h"
+#include "mlir/Dialect/Linalg/IR/Linalg.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "mlir/Dialect/Tosa/IR/TosaOps.h"
 #include "mlir/Dialect/Tosa/Utils/ConversionUtils.h"
+#include "mlir/IR/IRMapping.h"
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/Transforms/DialectConversion.h"
 #include "llvm/ADT/STLExtras.h"
@@ -454,11 +456,84 @@ struct ConcatConverter : public OpConversionPattern<tosa::ConcatOp> {
   }
 };
 
+// Forward the destination tensor of concat generated tensor.insert_slice ops
+// into single-use destination-style tensor producers. This avoids creating a
+// producer on a temporary tensor that is immediately copied into the concat
+// result tensor.
+struct ForwardConcatInsertSliceDest
+    : public OpConversionPattern<tensor::InsertSliceOp> {
+  using OpConversionPattern<tensor::InsertSliceOp>::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(tensor::InsertSliceOp insertOp, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    // Only rewrite when the insert source is an SSA result with a single use.
+    Value source = adaptor.getSource();
+    auto sourceResult = dyn_cast<OpResult>(source);
+    if (!sourceResult || !source.hasOneUse())
+      return failure();
+
+    // Restrict to concat-style insert chains where the destination is either
+    // the initial tensor.empty or a previous tensor.insert_slice result.
+    Value dest = adaptor.getDest();
+    Operation *destDef = dest.getDefiningOp();
+    if (!isa_and_present<tensor::EmptyOp, tensor::InsertSliceOp>(destDef))
+      return failure();
+
+    // The source producer must be destination-style on tensors so we can
+    // retarget its tied output to a slice of the final concat destination.
+    auto producer = source.getDefiningOp<DestinationStyleOpInterface>();
+    if (!producer || !producer.hasPureTensorSemantics())
+      return failure();
+
+    if (producer->getNumResults() != 1)
+      return failure();
+
+    OpOperand *tiedInit = producer.getTiedOpOperand(sourceResult);
+    if (!tiedInit)
+      return failure();
+
+    // Retargeting is only sound when the init contents are undefined, since a
+    // producer may read its init. Requiring tensor.empty also stops the
+    // pattern from re-matching its own output (init is then extract_slice).
+    if (!tiedInit->get().getDefiningOp<tensor::EmptyOp>())
+      return failure();
+
+    auto sourceType = dyn_cast<RankedTensorType>(source.getType());
+    if (!sourceType || !isa<RankedTensorType>(dest.getType()))
+      return failure();
+
+    // Reuse the mixed static/dynamic slice parameters as-is so that constant
+    // offsets, sizes and strides stay static on the newly created ops.
+    Location loc = insertOp.getLoc();
+    SmallVector<OpFoldResult> offsets = insertOp.getMixedOffsets();
+    SmallVector<OpFoldResult> sizes = insertOp.getMixedSizes();
+    SmallVector<OpFoldResult> strides = insertOp.getMixedStrides();
+
+    // Extract the slice of the final destination the producer writes into.
+    Value extractedDest = tensor::ExtractSliceOp::create(
+        rewriter, loc, sourceType, dest, offsets, sizes, strides);
+
+    IRMapping mapping;
+    mapping.map(tiedInit->get(), extractedDest);
+    Operation *newProducer = rewriter.clone(*producer, mapping);
+    Value newSource = newProducer->getResult(sourceResult.getResultNumber());
+
+    // Rebuild insert_slice with the retargeted producer result, then erase the
+    // original producer (guaranteed to have a single use).
+    Value newInsert = tensor::InsertSliceOp::create(
+        rewriter, loc, newSource, dest, offsets, sizes, strides);
+    rewriter.replaceOp(insertOp, newInsert);
+    rewriter.eraseOp(producer.getOperation());
+    return success();
+  }
+};
+
 } // namespace
 
 void mlir::tosa::populateTosaToTensorConversionPatterns(
     const TypeConverter &converter, RewritePatternSet *patterns) {
-  patterns
-      ->add<ConcatConverter, PadConverter, ReshapeConverter, SliceConverter>(
-          converter, patterns->getContext());
+  patterns->add<ConcatConverter, ForwardConcatInsertSliceDest, PadConverter,
+                ReshapeConverter, SliceConverter>(converter,
+                                                  patterns->getContext());
 }
diff --git a/mlir/test/Conversion/TosaToTensor/tosa-to-tensor.mlir b/mlir/test/Conversion/TosaToTensor/tosa-to-tensor.mlir
index 0a276e2a5c3d2..f8a50ef43ebc9 100644
--- a/mlir/test/Conversion/TosaToTensor/tosa-to-tensor.mlir
+++ b/mlir/test/Conversion/TosaToTensor/tosa-to-tensor.mlir
@@ -641,7 +641,7 @@ func.func @concat_axis_dyn(%arg0: tensor<?x3xf32>, %arg1: tensor<?x3xf32>) -> ()
   // CHECK: %[[INSERT1:.+]] = tensor.insert_slice %[[ARG1]] into %[[INSERT0]][%[[DIM0]], 0] [%[[DIM3]], 3] [1, 1] : tensor<?x3xf32> into tensor<?x3xf32>
 
   %0 = "tosa.concat"(%arg0, %arg1) { axis = 0 : i32} : (tensor<?x3xf32>, tensor<?x3xf32>)  -> (tensor<?x3xf32>)
-  return
+  return ()
 }
 
 // -----
@@ -660,7 +660,6 @@ func.func @concat_axis_dyn_mixed(%arg0: tensor<?x1xf32>, %arg1: tensor<?x1xf32>,
   // CHECK-DAG: %[[OFFSET2:.+]] = arith.addi %[[OFFSET1]], %[[DIM2_2]] : index
   // CHECK-DAG: %[[INIT:.+]] = tensor.empty() : tensor<5x1xf32>
   // CHECK-DAG: %[[C0_3:.+]] = arith.constant 0 : index
-  // CHECK-DAG: %[[DIM_4:.+]] = tensor.dim %[[ARG0]], %[[C0_3]] : tensor<?x1xf32>
   // CHECK-DAG: %[[INSERT0:.+]] = tensor.insert_slice %[[ARG0]] into %[[INIT]][0, 0] [%[[DIM_4]], 1] [1, 1] : tensor<?x1xf32> into tensor<5x1xf32>
   // CHECK-DAG: %[[C0_4:.+]] = arith.constant 0 : index
   // CHECK-DAG: %[[DIM_6:.+]] = tensor.dim %[[ARG1]], %[[C0_4]] : tensor<?x1xf32>
@@ -703,6 +702,25 @@ func.func @concat_non_axis_dyn_mixed(%arg0: tensor<?x1xf32>, %arg1: tensor<?x1xf
 
 // -----
 
+// CHECK-LABEL: @concat_forward_insert_slice_dest
+// CHECK-SAME: (%[[ARG0:.*]]: tensor<4xf32>)
+func.func @concat_forward_insert_slice_dest(%arg0: tensor<4xf32>) -> tensor<8xf32> {
+  %cst = arith.constant 1.000000e+00 : f32
+  %init = tensor.empty() : tensor<4xf32>
+  %filled = linalg.fill ins(%cst : f32) outs(%init : tensor<4xf32>) -> tensor<4xf32>
+  %0 = "tosa.concat"(%filled, %arg0) {axis = 0 : i32} : (tensor<4xf32>, tensor<4xf32>) -> tensor<8xf32>
+  // CHECK: %[[INIT:.*]] = tensor.empty() : tensor<8xf32>
+  // CHECK: %[[SLICE:.*]] = tensor.extract_slice %[[INIT]][0] [4] [1] : tensor<8xf32> to tensor<4xf32>
+  // CHECK: %[[FILL:.*]] = linalg.fill ins(%{{.*}} : f32) outs(%[[SLICE]] : tensor<4xf32>) -> tensor<4xf32>
+  // CHECK-NOT: tensor.insert_slice %filled
+  // CHECK: %[[INSERT0:.*]] = tensor.insert_slice %[[FILL]] into %[[INIT]][0] [4] [1]
+  // CHECK: %[[INSERT1:.*]] = tensor.insert_slice %[[ARG0]] into %[[INSERT0]][4] [4] [1]
+  // CHECK: return %[[INSERT1]] : tensor<8xf32>
+  return %0 : tensor<8xf32>
+}
+
+// -----
+
 // CHECK-LABEL: func @pad_variable_pad_const
 // CHECK-SAME: (%[[ARG0_SSA:.*]]: tensor<2x2xi32>, %[[PAD_INPUT_TENSOR_SSA:.*]]: tensor<1xi32>)
 func.func @pad_variable_pad_const(%arg0: tensor<2x2xi32>, %pad_input_tensor: tensor<1xi32>) -> tensor<4x5xi32> {



More information about the Mlir-commits mailing list