[Mlir-commits] [mlir] [mlir][tosa] Forward concat insert_slice destination into DPS provider (PR #183490)
Dhruv Chauhan
llvmlistbot at llvm.org
Thu Mar 5 06:00:00 PST 2026
================
@@ -454,11 +456,84 @@ struct ConcatConverter : public OpConversionPattern<tosa::ConcatOp> {
}
};
+// Forward the destination tensor of concat-generated tensor.insert_slice ops
+// into their single-use destination-style tensor producers. This avoids
+// having the producer write into a temporary tensor that is immediately
+// copied into the concat result tensor.
+struct ForwardConcatInsertSliceDest
+ : public OpConversionPattern<tensor::InsertSliceOp> {
+ using OpConversionPattern<tensor::InsertSliceOp>::OpConversionPattern;
+
+ LogicalResult
+ matchAndRewrite(tensor::InsertSliceOp insertOp, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ // Only rewrite when the insert source is an SSA result with a single use.
+ Value source = adaptor.getSource();
+ auto sourceResult = dyn_cast<OpResult>(source);
+ if (!sourceResult || !source.hasOneUse())
+ return failure();
+
+ // Restrict to concat-style insert chains where the destination is either
+ // the initial tensor.empty or a previous tensor.insert_slice result.
+ Operation *destDef = adaptor.getDest().getDefiningOp();
+ if (!isa_and_present<tensor::EmptyOp, tensor::InsertSliceOp>(destDef))
+ return failure();
+
+ // The source producer must be destination-style on tensors so we can
+ // retarget its tied output to a slice of the final concat destination.
+ auto producer = source.getDefiningOp<DestinationStyleOpInterface>();
+ if (!producer || !producer.hasPureTensorSemantics())
+ return failure();
+
+ if (producer->getNumResults() != 1)
+ return failure();
+
+ OpOperand *tiedInit = producer.getTiedOpOperand(sourceResult);
+ if (!tiedInit)
+ return failure();
+
+ auto sourceType = dyn_cast<RankedTensorType>(source.getType());
+ if (!sourceType || !isa<RankedTensorType>(adaptor.getDest().getType()))
+ return failure();
+
+ // Materialize explicit index values for offset/size/stride.
+ SmallVector<Value> offsets, sizes, strides;
+ for (OpFoldResult ofr : insertOp.getMixedOffsets())
+ offsets.push_back(
+ getValueOrCreateConstantIndexOp(rewriter, insertOp.getLoc(), ofr));
+ for (OpFoldResult ofr : insertOp.getMixedSizes())
+ sizes.push_back(
+ getValueOrCreateConstantIndexOp(rewriter, insertOp.getLoc(), ofr));
+ for (OpFoldResult ofr : insertOp.getMixedStrides())
+ strides.push_back(
+ getValueOrCreateConstantIndexOp(rewriter, insertOp.getLoc(), ofr));
----------------
dchauhan-arm wrote:
Done
https://github.com/llvm/llvm-project/pull/183490
More information about the Mlir-commits mailing list