[Mlir-commits] [mlir] [mlir][tosa] Forward concat insert_slice destination into DPS provider (PR #183490)
Dhruv Chauhan
llvmlistbot at llvm.org
Thu Feb 26 02:56:02 PST 2026
https://github.com/dchauhan-arm created https://github.com/llvm/llvm-project/pull/183490
In TosaToTensor, forward concat insert_slice destination slices into single-use destination-style producers, which avoids creating temporary producer results that are immediately copied into the concat result. Add a regression test for concat + fill forwarding.
>From e7f56d22b4e8891bb7777ef3d809e22f0525ffef Mon Sep 17 00:00:00 2001
From: Dhruv Chauhan <dhruv.chauhan at arm.com>
Date: Thu, 26 Feb 2026 09:57:33 +0000
Subject: [PATCH] [mlir][tosa] Forward concat insert_slice destination into DPS
provider
In TosaToTensor, forward concat insert_slice destination slices into
single-use destination-style producers, which avoids creating temporary
producer results that are immediately copied into the concat result.
Add a regression test for concat + fill forwarding.
---
.../Conversion/TosaToTensor/TosaToTensor.cpp | 81 ++++++++++++++++++-
.../TosaToTensor/tosa-to-tensor.mlir | 22 ++++-
2 files changed, 98 insertions(+), 5 deletions(-)
diff --git a/mlir/lib/Conversion/TosaToTensor/TosaToTensor.cpp b/mlir/lib/Conversion/TosaToTensor/TosaToTensor.cpp
index 9bf9ca3ae7a89..60d16ecc2d507 100644
--- a/mlir/lib/Conversion/TosaToTensor/TosaToTensor.cpp
+++ b/mlir/lib/Conversion/TosaToTensor/TosaToTensor.cpp
@@ -13,9 +13,11 @@
#include "mlir/Conversion/TosaToTensor/TosaToTensor.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Arith/Utils/Utils.h"
+#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
#include "mlir/Dialect/Tosa/Utils/ConversionUtils.h"
+#include "mlir/IR/IRMapping.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/DialectConversion.h"
#include "llvm/ADT/STLExtras.h"
@@ -454,11 +456,84 @@ struct ConcatConverter : public OpConversionPattern<tosa::ConcatOp> {
}
};
+// Forwards the destination of concat-generated tensor.insert_slice ops into
+// single-use destination-style tensor producers. This avoids creating a
+// producer on a temporary tensor that is immediately copied into the concat
+// result tensor.
+struct ForwardConcatInsertSliceDest
+ : public OpConversionPattern<tensor::InsertSliceOp> {
+ using OpConversionPattern<tensor::InsertSliceOp>::OpConversionPattern;
+
+ LogicalResult
+ matchAndRewrite(tensor::InsertSliceOp insertOp, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ // Bail out unless the insert source is an op result with exactly one use.
+ Value source = adaptor.getSource();
+ auto sourceResult = dyn_cast<OpResult>(source);
+ if (!sourceResult || !source.hasOneUse())
+ return failure();
+
+ // Restrict to concat-style insert chains, where the destination is either
+ // the initial tensor.empty or the result of a previous tensor.insert_slice.
+ Operation *destDef = adaptor.getDest().getDefiningOp();
+ if (!isa_and_present<tensor::EmptyOp, tensor::InsertSliceOp>(destDef))
+ return failure();
+
+ // The source producer must be a destination-style op on tensors so that its
+ // tied init operand can be retargeted to a slice of the concat destination.
+ auto producer = source.getDefiningOp<DestinationStyleOpInterface>();
+ if (!producer || !producer.hasPureTensorSemantics())
+ return failure();
+
+ if (producer->getNumResults() != 1)
+ return failure();
+
+ OpOperand *tiedInit = producer.getTiedOpOperand(sourceResult);
+ if (!tiedInit)
+ return failure();
+
+ auto sourceType = dyn_cast<RankedTensorType>(source.getType());
+ if (!sourceType || !isa<RankedTensorType>(adaptor.getDest().getType()))
+ return failure();
+
+ // Materialize each mixed offset/size/stride as an explicit index Value.
+ SmallVector<Value> offsets, sizes, strides;
+ for (OpFoldResult ofr : insertOp.getMixedOffsets())
+ offsets.push_back(
+ getValueOrCreateConstantIndexOp(rewriter, insertOp.getLoc(), ofr));
+ for (OpFoldResult ofr : insertOp.getMixedSizes())
+ sizes.push_back(
+ getValueOrCreateConstantIndexOp(rewriter, insertOp.getLoc(), ofr));
+ for (OpFoldResult ofr : insertOp.getMixedStrides())
+ strides.push_back(
+ getValueOrCreateConstantIndexOp(rewriter, insertOp.getLoc(), ofr));
+
+ // Extract the slice of the destination that this insert_slice writes to.
+ Value extractedDest = tensor::ExtractSliceOp::create(
+ rewriter, insertOp.getLoc(), sourceType, adaptor.getDest(), offsets,
+ sizes, strides);
+
+ IRMapping mapping;
+ mapping.map(tiedInit->get(), extractedDest);
+ Operation *newProducer = rewriter.clone(*producer, mapping);
+ Value newSource = newProducer->getResult(sourceResult.getResultNumber());
+
+ // Rebuild insert_slice with the retargeted producer result, then erase the
+ // original producer (its result was checked above to have a single use).
+ Value newInsert = tensor::InsertSliceOp::create(
+ rewriter, insertOp.getLoc(), newSource, adaptor.getDest(), offsets,
+ sizes, strides);
+ rewriter.replaceOp(insertOp, newInsert);
+ rewriter.eraseOp(producer.getOperation());
+ return success();
+ }
+};
+
} // namespace
void mlir::tosa::populateTosaToTensorConversionPatterns(
const TypeConverter &converter, RewritePatternSet *patterns) {
- patterns
- ->add<ConcatConverter, PadConverter, ReshapeConverter, SliceConverter>(
- converter, patterns->getContext());
+ patterns->add<ConcatConverter, ForwardConcatInsertSliceDest, PadConverter,
+ ReshapeConverter, SliceConverter>(converter,
+ patterns->getContext());
}
diff --git a/mlir/test/Conversion/TosaToTensor/tosa-to-tensor.mlir b/mlir/test/Conversion/TosaToTensor/tosa-to-tensor.mlir
index 0a276e2a5c3d2..f8a50ef43ebc9 100644
--- a/mlir/test/Conversion/TosaToTensor/tosa-to-tensor.mlir
+++ b/mlir/test/Conversion/TosaToTensor/tosa-to-tensor.mlir
@@ -641,7 +641,7 @@ func.func @concat_axis_dyn(%arg0: tensor<?x3xf32>, %arg1: tensor<?x3xf32>) -> ()
// CHECK: %[[INSERT1:.+]] = tensor.insert_slice %[[ARG1]] into %[[INSERT0]][%[[DIM0]], 0] [%[[DIM3]], 3] [1, 1] : tensor<?x3xf32> into tensor<?x3xf32>
%0 = "tosa.concat"(%arg0, %arg1) { axis = 0 : i32} : (tensor<?x3xf32>, tensor<?x3xf32>) -> (tensor<?x3xf32>)
- return
+ return ()
}
// -----
@@ -660,7 +660,6 @@ func.func @concat_axis_dyn_mixed(%arg0: tensor<?x1xf32>, %arg1: tensor<?x1xf32>,
// CHECK-DAG: %[[OFFSET2:.+]] = arith.addi %[[OFFSET1]], %[[DIM2_2]] : index
// CHECK-DAG: %[[INIT:.+]] = tensor.empty() : tensor<5x1xf32>
// CHECK-DAG: %[[C0_3:.+]] = arith.constant 0 : index
- // CHECK-DAG: %[[DIM_4:.+]] = tensor.dim %[[ARG0]], %[[C0_3]] : tensor<?x1xf32>
// CHECK-DAG: %[[INSERT0:.+]] = tensor.insert_slice %[[ARG0]] into %[[INIT]][0, 0] [%[[DIM_4]], 1] [1, 1] : tensor<?x1xf32> into tensor<5x1xf32>
// CHECK-DAG: %[[C0_4:.+]] = arith.constant 0 : index
// CHECK-DAG: %[[DIM_6:.+]] = tensor.dim %[[ARG1]], %[[C0_4]] : tensor<?x1xf32>
@@ -703,6 +702,25 @@ func.func @concat_non_axis_dyn_mixed(%arg0: tensor<?x1xf32>, %arg1: tensor<?x1xf
// -----
+// CHECK-LABEL: @concat_forward_insert_slice_dest
+// CHECK-SAME: (%[[ARG0:.*]]: tensor<4xf32>)
+func.func @concat_forward_insert_slice_dest(%arg0: tensor<4xf32>) -> tensor<8xf32> {
+ %cst = arith.constant 1.000000e+00 : f32
+ %init = tensor.empty() : tensor<4xf32>
+ %filled = linalg.fill ins(%cst : f32) outs(%init : tensor<4xf32>) -> tensor<4xf32>
+ %0 = "tosa.concat"(%filled, %arg0) {axis = 0 : i32} : (tensor<4xf32>, tensor<4xf32>) -> tensor<8xf32>
+ // CHECK: %[[INIT:.*]] = tensor.empty() : tensor<8xf32>
+ // CHECK: %[[SLICE:.*]] = tensor.extract_slice %[[INIT]][0] [4] [1] : tensor<8xf32> to tensor<4xf32>
+ // CHECK: %[[FILL:.*]] = linalg.fill ins(%{{.*}} : f32) outs(%[[SLICE]] : tensor<4xf32>) -> tensor<4xf32>
+ // CHECK-NOT: tensor.insert_slice %filled
+ // CHECK: %[[INSERT0:.*]] = tensor.insert_slice %[[FILL]] into %[[INIT]][0] [4] [1]
+ // CHECK: %[[INSERT1:.*]] = tensor.insert_slice %[[ARG0]] into %[[INSERT0]][4] [4] [1]
+ // CHECK: return %[[INSERT1]] : tensor<8xf32>
+ return %0 : tensor<8xf32>
+}
+
+// -----
+
// CHECK-LABEL: func @pad_variable_pad_const
// CHECK-SAME: (%[[ARG0_SSA:.*]]: tensor<2x2xi32>, %[[PAD_INPUT_TENSOR_SSA:.*]]: tensor<1xi32>)
func.func @pad_variable_pad_const(%arg0: tensor<2x2xi32>, %pad_input_tensor: tensor<1xi32>) -> tensor<4x5xi32> {
More information about the Mlir-commits
mailing list