[Mlir-commits] [mlir] [mlir][bufferization] implement BufferizableOpInterface for concat op (PR #140171)

Jeremy Kun llvmlistbot at llvm.org
Thu May 15 20:57:49 PDT 2025


================
@@ -1048,6 +1048,103 @@ struct SplatOpInterface
   }
 };
 
+/// Bufferization of tensor.concat. Bufferizes to a new allocation that is
+/// filled with copy ops. Similar to tensor.from_elements, but using memref.copy
+/// on subviews instead of memref.store.
+struct ConcatOpInterface
+    : public BufferizableOpInterface::ExternalModel<ConcatOpInterface,
+                                                    tensor::ConcatOp> {
+
+  /// Always reports true: as the struct comment states, tensor.concat
+  /// bufferizes to a new allocation. `op` and `value` are not inspected.
+  bool bufferizesToAllocation(Operation *op, Value value) const { return true; }
+
+  /// tensor.concat does not write to the buffers of its operands. Per the
+  /// struct comment, the inputs are only copied into a newly allocated result
+  /// buffer, so no operand is modified in place. Reporting a write here would
+  /// claim the inputs are mutated, which this lowering does not do.
+  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
+                               const AnalysisState &state) const {
+    return false;
+  }
+
+  /// Always reports true for every operand. Per the struct comment the result
+  /// is filled with copy ops, which presumably read each input buffer.
+  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
+                              const AnalysisState &state) const {
+    return true;
+  }
+
+  /// No operand aliases the result. The result lives in a fresh allocation
+  /// (bufferizesToAllocation returns true) and operand contents are copied
+  /// into it, so the returned alias list is empty for every operand.
+  AliasingValueList getAliasingValues(Operation *op, OpOperand &opOperand,
+                                      const AnalysisState &state) const {
+    return {};
+  }
+
+  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
+                          const BufferizationOptions &options) const {
+    OpBuilder::InsertionGuard g(rewriter);
+    auto concatOp = cast<tensor::ConcatOp>(op);
+
+    // Allocate memory.
+    Location loc = op->getLoc();
+    FailureOr<Value> tensorAlloc = allocateTensorForShapedValue(
+        rewriter, loc, concatOp.getResult(), options,
+        /*copy=*/false);
+    if (failed(tensorAlloc))
+      return failure();
+    auto tensorType = cast<RankedTensorType>(tensorAlloc->getType());
+
+    // TODO: Implement memory space for this op.
+    if (options.defaultMemorySpaceFn(tensorType) != Attribute())
+      return op->emitError("memory space not implemented yet");
+
+    MemRefLayoutAttrInterface layout;
+    MemRefType memrefType =
+        MemRefType::get(concatOp.getResultType().getShape(),
+                        concatOp.getResultType().getElementType(), layout);
+    Value dstBuffer = rewriter.create<bufferization::ToMemrefOp>(
+        op->getLoc(), memrefType, *tensorAlloc);
+
+    // Extract the dimension for the concat op
+    uint64_t concatDim = concatOp.getDim();
+
+    SmallVector<OpFoldResult> offsets(tensorType.getRank(),
+                                      rewriter.getIndexAttr(0));
+    SmallVector<OpFoldResult> strides(tensorType.getRank(),
+                                      rewriter.getIndexAttr(1));
+    SmallVector<OpFoldResult> sizes;
+    for (auto dimSize : tensorType.getShape()) {
+      sizes.push_back(rewriter.getIndexAttr(dimSize));
----------------
j2kun wrote:

Done

https://github.com/llvm/llvm-project/pull/140171


More information about the Mlir-commits mailing list