[Mlir-commits] [mlir] [mlir][tosa] Forward concat insert_slice destination into DPS provider (PR #183490)

Dhruv Chauhan llvmlistbot at llvm.org
Thu Mar 5 06:00:00 PST 2026


================
@@ -454,11 +456,84 @@ struct ConcatConverter : public OpConversionPattern<tosa::ConcatOp> {
   }
 };
 
+// Forward the destination tensor of concat-generated tensor.insert_slice ops
+// into single-use destination-style tensor producers. This avoids creating a
+// producer on a temporary tensor that is immediately copied into the concat
+// result tensor.
+struct ForwardConcatInsertSliceDest
+    : public OpConversionPattern<tensor::InsertSliceOp> {
+  using OpConversionPattern<tensor::InsertSliceOp>::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(tensor::InsertSliceOp insertOp, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    // Only rewrite when the insert source is an SSA result with a single use.
+    Value source = adaptor.getSource();
+    auto sourceResult = dyn_cast<OpResult>(source);
+    if (!sourceResult || !source.hasOneUse())
+      return failure();
+
+    // Restrict to concat-style insert chains where the destination is either
+    // the initial tensor.empty or a previous tensor.insert_slice result.
+    Operation *destDef = adaptor.getDest().getDefiningOp();
+    if (!isa_and_present<tensor::EmptyOp, tensor::InsertSliceOp>(destDef))
+      return failure();
+
+    // The source producer must be destination-style on tensors so we can
+    // retarget its tied output to a slice of the final concat destination.
+    auto producer = source.getDefiningOp<DestinationStyleOpInterface>();
+    if (!producer || !producer.hasPureTensorSemantics())
+      return failure();
+
+    if (producer->getNumResults() != 1)
+      return failure();
+
+    OpOperand *tiedInit = producer.getTiedOpOperand(sourceResult);
+    if (!tiedInit)
+      return failure();
+
+    auto sourceType = dyn_cast<RankedTensorType>(source.getType());
+    if (!sourceType || !isa<RankedTensorType>(adaptor.getDest().getType()))
+      return failure();
+
+    // Materialize explicit index values for offset/size/stride.
+    SmallVector<Value> offsets, sizes, strides;
+    for (OpFoldResult ofr : insertOp.getMixedOffsets())
+      offsets.push_back(
+          getValueOrCreateConstantIndexOp(rewriter, insertOp.getLoc(), ofr));
+    for (OpFoldResult ofr : insertOp.getMixedSizes())
+      sizes.push_back(
+          getValueOrCreateConstantIndexOp(rewriter, insertOp.getLoc(), ofr));
+    for (OpFoldResult ofr : insertOp.getMixedStrides())
+      strides.push_back(
+          getValueOrCreateConstantIndexOp(rewriter, insertOp.getLoc(), ofr));
----------------
dchauhan-arm wrote:

Done

https://github.com/llvm/llvm-project/pull/183490


More information about the Mlir-commits mailing list