[Mlir-commits] [mlir] [mlir][bufferization] implement BufferizableOpInterface for concat op (PR #140171)
Jeremy Kun
llvmlistbot at llvm.org
Thu May 15 20:57:49 PDT 2025
================
@@ -1048,6 +1048,103 @@ struct SplatOpInterface
}
};
+/// Bufferization of tensor.concat. Bufferizes to a new allocation that is
+/// filled with copy ops. Similar to tensor.from_elements, but using memref.copy
+/// on subviews instead of memref.store.
+struct ConcatOpInterface
+ : public BufferizableOpInterface::ExternalModel<ConcatOpInterface,
+ tensor::ConcatOp> {
+
+ /// The concatenated result is always materialized in a brand-new buffer.
+ bool bufferizesToAllocation(Operation *op, Value value) const {
+   return true;
+ }
+
+ /// tensor.concat never writes to its operands: the copies target the new
+ /// buffer allocated for the result, so no operand bufferizes to a write.
+ /// Reporting a write here would wrongly tell the bufferization analysis that
+ /// each input buffer is modified in place by this op.
+ bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
+                              const AnalysisState &state) const {
+   return false;
+ }
+
+ /// Each operand is read, because its contents are copied into the result.
+ bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
+                             const AnalysisState &state) const { return true; }
+
+ /// The result lives in a freshly allocated buffer, so it aliases none of the
+ /// operands. Declaring an operand Equivalent to the result would claim the
+ /// input buffer *is* the result buffer. That is false here, and it is
+ /// contradictory when several distinct inputs are concatenated.
+ AliasingValueList getAliasingValues(Operation *op, OpOperand &opOperand,
+                                     const AnalysisState &state) const {
+   return {};
+ }
+
+ LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
+ const BufferizationOptions &options) const {
+ OpBuilder::InsertionGuard g(rewriter);
+ auto concatOp = cast<tensor::ConcatOp>(op);
+
+ // Allocate memory.
+ Location loc = op->getLoc();
+ FailureOr<Value> tensorAlloc = allocateTensorForShapedValue(
+ rewriter, loc, concatOp.getResult(), options,
+ /*copy=*/false);
+ if (failed(tensorAlloc))
+ return failure();
+ auto tensorType = cast<RankedTensorType>(tensorAlloc->getType());
+
+ // TODO: Implement memory space for this op.
+ if (options.defaultMemorySpaceFn(tensorType) != Attribute())
+ return op->emitError("memory space not implemented yet");
+
+ MemRefLayoutAttrInterface layout;
+ MemRefType memrefType =
+ MemRefType::get(concatOp.getResultType().getShape(),
+ concatOp.getResultType().getElementType(), layout);
+ Value dstBuffer = rewriter.create<bufferization::ToMemrefOp>(
+ op->getLoc(), memrefType, *tensorAlloc);
+
+ // Extract the dimension for the concat op
+ uint64_t concatDim = concatOp.getDim();
+
+ SmallVector<OpFoldResult> offsets(tensorType.getRank(),
+ rewriter.getIndexAttr(0));
+ SmallVector<OpFoldResult> strides(tensorType.getRank(),
+ rewriter.getIndexAttr(1));
+ SmallVector<OpFoldResult> sizes;
+ for (auto dimSize : tensorType.getShape()) {
+ sizes.push_back(rewriter.getIndexAttr(dimSize));
----------------
j2kun wrote:
Done
https://github.com/llvm/llvm-project/pull/140171
More information about the Mlir-commits mailing list