[Mlir-commits] [mlir] Restore #140171 with to_memref -> to_buffer (PR #140355)
Jeremy Kun
llvmlistbot at llvm.org
Fri May 16 22:09:38 PDT 2025
https://github.com/j2kun created https://github.com/llvm/llvm-project/pull/140355
https://github.com/llvm/llvm-project/pull/140171 was reverted because an op's name was changed and I neglected to rebase before merging.
>From 5e7fc62967aa3adb335b0e27711ca17aebdd34a7 Mon Sep 17 00:00:00 2001
From: Jeremy Kun <j2kun at users.noreply.github.com>
Date: Fri, 16 May 2025 20:28:57 -0700
Subject: [PATCH 1/2] Restore [mlir][bufferization] implement
BufferizableOpInterface for concat op (#140171)
This restores the previously reverted commit with forward fixes
---
mlir/lib/Dialect/Tensor/IR/TensorDialect.cpp | 4 +-
.../BufferizableOpInterfaceImpl.cpp | 129 ++++++++++++++++++
mlir/test/Dialect/Tensor/bufferize.mlir | 91 ++++++++++++
3 files changed, 222 insertions(+), 2 deletions(-)
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorDialect.cpp b/mlir/lib/Dialect/Tensor/IR/TensorDialect.cpp
index 8af087cbf0f61..e7d8f52d309c9 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorDialect.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorDialect.cpp
@@ -49,8 +49,8 @@ void TensorDialect::initialize() {
>();
addInterfaces<TensorInlinerInterface>();
declarePromisedInterfaces<
- bufferization::BufferizableOpInterface, CastOp, CollapseShapeOp, DimOp,
- EmptyOp, ExpandShapeOp, ExtractSliceOp, ExtractOp, FromElementsOp,
+ bufferization::BufferizableOpInterface, CastOp, CollapseShapeOp, ConcatOp,
+ DimOp, EmptyOp, ExpandShapeOp, ExtractSliceOp, ExtractOp, FromElementsOp,
GenerateOp, InsertOp, InsertSliceOp, PadOp, ParallelInsertSliceOp, RankOp,
ReshapeOp, SplatOp>();
declarePromisedInterfaces<transform::FindPayloadReplacementOpInterface,
diff --git a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
index c0e697292d2a0..e1706b841eb31 100644
--- a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -1048,6 +1048,134 @@ struct SplatOpInterface
}
};
+/// Bufferization of tensor.concat. Bufferizes to a new allocation that is
+/// filled with copy ops. Similar to tensor.from_elements, but using memref.copy
+/// on subviews instead of memref.store.
+struct ConcatOpInterface
+ : public BufferizableOpInterface::ExternalModel<ConcatOpInterface,
+ tensor::ConcatOp> {
+
+ bool bufferizesToAllocation(Operation *op, Value value) const { return true; }
+
+ bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
+ const AnalysisState &state) const {
+ return false;
+ }
+
+ bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
+ const AnalysisState &state) const {
+ return true;
+ }
+
+ AliasingValueList getAliasingValues(Operation *op, OpOperand &opOperand,
+ const AnalysisState &state) const {
+ return {};
+ }
+
+ LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
+ const BufferizationOptions &options) const {
+ OpBuilder::InsertionGuard g(rewriter);
+ auto concatOp = cast<tensor::ConcatOp>(op);
+
+ // Allocate memory.
+ Location loc = op->getLoc();
+ FailureOr<Value> tensorAlloc = allocateTensorForShapedValue(
+ rewriter, loc, concatOp.getResult(), options,
+ /*copy=*/false);
+ if (failed(tensorAlloc))
+ return failure();
+ auto tensorType = cast<RankedTensorType>(tensorAlloc->getType());
+
+ // TODO: Implement memory space for this op.
+ if (options.defaultMemorySpaceFn(tensorType) != Attribute())
+ return op->emitError("memory space not implemented yet");
+
+ MemRefLayoutAttrInterface layout;
+ MemRefType memrefType =
+ MemRefType::get(concatOp.getResultType().getShape(),
+ concatOp.getResultType().getElementType(), layout);
+ Value dstBuffer = rewriter.create<bufferization::ToMemrefOp>(
+ op->getLoc(), memrefType, *tensorAlloc);
+
+ // Extract the dimension for the concat op
+ uint64_t concatDim = concatOp.getDim();
+ bool dynamicConcatDim = false;
+
+ SmallVector<OpFoldResult> offsets(tensorType.getRank(),
+ rewriter.getIndexAttr(0));
+ SmallVector<OpFoldResult> strides(tensorType.getRank(),
+ rewriter.getIndexAttr(1));
+ SmallVector<OpFoldResult> sizes;
+
+ for (const auto &[dimIdx, dimSize] :
+ llvm::enumerate(tensorType.getShape())) {
+ if (dimSize == ShapedType::kDynamic) {
+ auto dimOp = rewriter.create<memref::DimOp>(loc, dstBuffer, dimIdx);
+ sizes.push_back(dimOp.getResult());
+ if (dimIdx == concatDim)
+ dynamicConcatDim = true;
+ } else {
+ sizes.push_back(rewriter.getIndexAttr(dimSize));
+ }
+ }
+
+ int64_t concatDimOffset = 0;
+ std::optional<Value> dynamicOffset;
+ std::optional<Value> dynamicSize;
+ if (dynamicConcatDim) {
+ // One or more operands have dynamic size, so we must accumulate the
+ // offset with arith ops.
+ dynamicOffset = rewriter.create<arith::ConstantIndexOp>(loc, 0);
+ }
+
+ for (auto operand : concatOp.getInputs()) {
+ // Get the buffer for the operand.
+ FailureOr<Value> srcBuffer = getBuffer(rewriter, operand, options);
+ if (failed(srcBuffer))
+ return failure();
+
+ // Each operand may have a different size along the concat dimension,
+ // so the offset on that axis must accumulate through the loop, and the
+ // size must change to the size of the current operand.
+ auto operandTensorType = cast<RankedTensorType>(operand.getType());
+ int64_t operandConcatDimSize = operandTensorType.getDimSize(concatDim);
+
+ if (dynamicConcatDim) {
+ offsets[concatDim] = dynamicOffset.value();
+ dynamicSize = rewriter.create<memref::DimOp>(loc, *srcBuffer, concatDim)
+ .getResult();
+ sizes[concatDim] = dynamicSize.value();
+ } else {
+ sizes[concatDim] = rewriter.getIndexAttr(operandConcatDimSize);
+ offsets[concatDim] = rewriter.getIndexAttr(concatDimOffset);
+ }
+
+ // Create a subview of the destination buffer.
+ auto dstMemrefType = cast<MemRefType>(memrefType);
+ MemRefType subviewMemRefType =
+ memref::SubViewOp::inferRankReducedResultType(
+ operandTensorType.getShape(), dstMemrefType, offsets, sizes,
+ strides);
+ Value subview = rewriter.create<memref::SubViewOp>(
+ loc, subviewMemRefType, dstBuffer, offsets, sizes, strides);
+
+ // Copy the source buffer into the destination subview.
+ if (failed(options.createMemCpy(rewriter, loc, *srcBuffer, subview)))
+ return failure();
+
+ if (dynamicConcatDim) {
+ dynamicOffset = rewriter.create<arith::AddIOp>(
+ loc, dynamicOffset.value(), dynamicSize.value());
+ } else {
+ concatDimOffset += operandConcatDimSize;
+ }
+ }
+
+ replaceOpWithBufferizedValues(rewriter, op, dstBuffer);
+ return success();
+ }
+};
+
} // namespace
} // namespace tensor
} // namespace mlir
@@ -1057,6 +1185,7 @@ void mlir::tensor::registerBufferizableOpInterfaceExternalModels(
registry.addExtension(+[](MLIRContext *ctx, tensor::TensorDialect *dialect) {
CastOp::attachInterface<CastOpInterface>(*ctx);
CollapseShapeOp::attachInterface<CollapseShapeOpInterface>(*ctx);
+ ConcatOp::attachInterface<ConcatOpInterface>(*ctx);
DimOp::attachInterface<DimOpInterface>(*ctx);
EmptyOp::attachInterface<EmptyOpInterface>(*ctx);
ExpandShapeOp::attachInterface<ExpandShapeOpInterface>(*ctx);
diff --git a/mlir/test/Dialect/Tensor/bufferize.mlir b/mlir/test/Dialect/Tensor/bufferize.mlir
index 567c4abea488e..308e52a6d9b9a 100644
--- a/mlir/test/Dialect/Tensor/bufferize.mlir
+++ b/mlir/test/Dialect/Tensor/bufferize.mlir
@@ -615,6 +615,97 @@ func.func @tensor.splat(%f: f32) -> tensor<10x2x4xf32> {
// -----
+// CHECK-LABEL: func @tensor.concat(
+// CHECK-SAME: %[[F:.*]]: tensor<8xf32>)
+// CHECK: %[[F_MEMREF:.*]] = bufferization.to_memref %[[F]]
+// CHECK: %[[ALLOC:.*]] = memref.alloc() {{.*}} : memref<16xf32>
+// CHECK: %[[SUBVIEW1:.*]] = memref.subview %[[ALLOC]][0] [8] [1]
+// CHECK: memref.copy %[[F_MEMREF]], %[[SUBVIEW1]]
+// CHECK: %[[SUBVIEW2:.*]] = memref.subview %[[ALLOC]][8] [8] [1]
+// CHECK: memref.copy %[[F_MEMREF]], %[[SUBVIEW2]]
+// CHECK: %[[RET:.*]] = bufferization.to_tensor %[[ALLOC]]
+// CHECK: return %[[RET]]
+// CHECK: }
+func.func @tensor.concat(%f: tensor<8xf32>) -> tensor<16xf32> {
+ %t = tensor.concat dim(0) %f, %f : (tensor<8xf32>, tensor<8xf32>) -> tensor<16xf32>
+ return %t : tensor<16xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @tensor.concat_different_shapes(
+// CHECK-SAME: %[[F:.*]]: tensor<8x4xf32>
+// CHECK-SAME: %[[G:.*]]: tensor<8x5xf32>
+// CHECK-DAG: %[[F_MEMREF:.*]] = bufferization.to_memref %[[F]]
+// CHECK-DAG: %[[G_MEMREF:.*]] = bufferization.to_memref %[[G]]
+// CHECK: %[[ALLOC:.*]] = memref.alloc() {{.*}} : memref<8x9xf32>
+// CHECK: %[[SUBVIEW1:.*]] = memref.subview %[[ALLOC]][0, 0] [8, 4] [1, 1]
+// CHECK: memref.copy %[[F_MEMREF]], %[[SUBVIEW1]]
+// CHECK: %[[SUBVIEW2:.*]] = memref.subview %[[ALLOC]][0, 4] [8, 5] [1, 1]
+// CHECK: memref.copy %[[G_MEMREF]], %[[SUBVIEW2]]
+// CHECK: %[[RET:.*]] = bufferization.to_tensor %[[ALLOC]]
+// CHECK: return %[[RET]]
+// CHECK: }
+func.func @tensor.concat_different_shapes(%f: tensor<8x4xf32>, %g: tensor<8x5xf32>) -> tensor<8x9xf32> {
+ %t = tensor.concat dim(1) %f, %g : (tensor<8x4xf32>, tensor<8x5xf32>) -> tensor<8x9xf32>
+ return %t : tensor<8x9xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @tensor.concat_dynamic(
+// CHECK-SAME: %[[F:.*]]: tensor<8x?xf32>,
+// CHECK-SAME: %[[G:.*]]: tensor<8x?xf32>
+// CHECK-DAG: %[[F_MEMREF:.*]] = bufferization.to_memref %[[F]]
+// CHECK-DAG: %[[G_MEMREF:.*]] = bufferization.to_memref %[[G]]
+// CHECK-DAG: %[[c1:.*]] = arith.constant 1 : index
+// CHECK-DAG: %[[F_DIM:.*]] = memref.dim %[[F_MEMREF]], %[[c1]]
+// CHECK-DAG: %[[G_DIM:.*]] = memref.dim %[[G_MEMREF]], %[[c1]]
+// CHECK: %[[ALLOC:.*]] = memref.alloc
+// CHECK-SAME: memref<8x?xf32>
+// CHECK-DAG: %[[OFFSET:.*]] = arith.constant 0 : index
+// CHECK: %[[SUBVIEW1:.*]] = memref.subview %[[ALLOC]][0, %[[OFFSET]]] [8, %[[F_DIM]]] [1, 1]
+// CHECK: memref.copy %[[F_MEMREF]], %[[SUBVIEW1]]
+// CHECK: %[[OFFSET_2:.*]] = arith.addi %[[OFFSET]], %[[F_DIM]] : index
+// CHECK: %[[SUBVIEW2:.*]] = memref.subview %[[ALLOC]][0, %[[OFFSET_2]]] [8, %[[G_DIM]]] [1, 1]
+// CHECK: memref.copy %[[G_MEMREF]], %[[SUBVIEW2]]
+// CHECK: %[[RET:.*]] = bufferization.to_tensor %[[ALLOC]]
+// CHECK: return %[[RET]]
+// CHECK: }
+func.func @tensor.concat_dynamic(%f: tensor<8x?xf32>, %g: tensor<8x?xf32>) -> tensor<8x?xf32> {
+ %t = tensor.concat dim(1) %f, %g : (tensor<8x?xf32>, tensor<8x?xf32>) -> tensor<8x?xf32>
+ return %t : tensor<8x?xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @tensor.concat_dynamic_nonconcat_dim(
+// CHECK-SAME: %[[F:.*]]: tensor<?x?xf32>,
+// CHECK-SAME: %[[G:.*]]: tensor<?x?xf32>
+// CHECK-DAG: %[[F_MEMREF:.*]] = bufferization.to_memref %[[F]]
+// CHECK-DAG: %[[G_MEMREF:.*]] = bufferization.to_memref %[[G]]
+// CHECK-DAG: %[[c1:.*]] = arith.constant 1 : index
+// CHECK-DAG: %[[c0:.*]] = arith.constant 0 : index
+// CHECK-DAG: %[[F_DIM:.*]] = memref.dim %[[F_MEMREF]], %[[c1]]
+// CHECK-DAG: %[[G_DIM:.*]] = memref.dim %[[G_MEMREF]], %[[c1]]
+// CHECK: %[[ALLOC:.*]] = memref.alloc
+// CHECK-SAME: memref<?x?xf32>
+// CHECK-DAG: %[[NON_CONCAT_DIM:.*]] = memref.dim %[[ALLOC]], %[[c0]]
+// CHECK: %[[SUBVIEW1:.*]] = memref.subview %[[ALLOC]][0, %[[c0]]] [%[[NON_CONCAT_DIM]], %[[F_DIM]]] [1, 1]
+// CHECK: memref.copy %[[F_MEMREF]], %[[SUBVIEW1]]
+// CHECK: %[[OFFSET_2:.*]] = arith.addi %[[c0]], %[[F_DIM]] : index
+// CHECK: %[[SUBVIEW2:.*]] = memref.subview %[[ALLOC]][0, %[[OFFSET_2]]] [%[[NON_CONCAT_DIM]], %[[G_DIM]]] [1, 1]
+// CHECK: memref.copy %[[G_MEMREF]], %[[SUBVIEW2]]
+// CHECK: %[[RET:.*]] = bufferization.to_tensor %[[ALLOC]]
+// CHECK: return %[[RET]]
+// CHECK: }
+func.func @tensor.concat_dynamic_nonconcat_dim(%f: tensor<?x?xf32>, %g: tensor<?x?xf32>) -> tensor<?x?xf32> {
+ %t = tensor.concat dim(1) %f, %g : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
+ return %t : tensor<?x?xf32>
+}
+
+// -----
+
// CHECK-LABEL: func @tensor.splat_dynamic(
// CHECK-SAME: %[[F:[a-zA-Z0-9_]+]]: f32
// CHECK-SAME: %[[M:[a-zA-Z0-9_]+]]: index
>From 0832a8693777d17fbcd5d13d31e28deb11ae7133 Mon Sep 17 00:00:00 2001
From: Jeremy Kun <j2kun at users.noreply.github.com>
Date: Fri, 16 May 2025 22:08:26 -0700
Subject: [PATCH 2/2] to_memref -> to_buffer
---
.../Transforms/BufferizableOpInterfaceImpl.cpp | 2 +-
mlir/test/Dialect/Tensor/bufferize.mlir | 14 +++++++-------
2 files changed, 8 insertions(+), 8 deletions(-)
diff --git a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
index e1706b841eb31..6525e58d002a2 100644
--- a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -1094,7 +1094,7 @@ struct ConcatOpInterface
MemRefType memrefType =
MemRefType::get(concatOp.getResultType().getShape(),
concatOp.getResultType().getElementType(), layout);
- Value dstBuffer = rewriter.create<bufferization::ToMemrefOp>(
+ Value dstBuffer = rewriter.create<bufferization::ToBufferOp>(
op->getLoc(), memrefType, *tensorAlloc);
// Extract the dimension for the concat op
diff --git a/mlir/test/Dialect/Tensor/bufferize.mlir b/mlir/test/Dialect/Tensor/bufferize.mlir
index 308e52a6d9b9a..e9c3ba7e3b970 100644
--- a/mlir/test/Dialect/Tensor/bufferize.mlir
+++ b/mlir/test/Dialect/Tensor/bufferize.mlir
@@ -617,7 +617,7 @@ func.func @tensor.splat(%f: f32) -> tensor<10x2x4xf32> {
// CHECK-LABEL: func @tensor.concat(
// CHECK-SAME: %[[F:.*]]: tensor<8xf32>)
-// CHECK: %[[F_MEMREF:.*]] = bufferization.to_memref %[[F]]
+// CHECK: %[[F_MEMREF:.*]] = bufferization.to_buffer %[[F]]
// CHECK: %[[ALLOC:.*]] = memref.alloc() {{.*}} : memref<16xf32>
// CHECK: %[[SUBVIEW1:.*]] = memref.subview %[[ALLOC]][0] [8] [1]
// CHECK: memref.copy %[[F_MEMREF]], %[[SUBVIEW1]]
@@ -636,8 +636,8 @@ func.func @tensor.concat(%f: tensor<8xf32>) -> tensor<16xf32> {
// CHECK-LABEL: func @tensor.concat_different_shapes(
// CHECK-SAME: %[[F:.*]]: tensor<8x4xf32>
// CHECK-SAME: %[[G:.*]]: tensor<8x5xf32>
-// CHECK-DAG: %[[F_MEMREF:.*]] = bufferization.to_memref %[[F]]
-// CHECK-DAG: %[[G_MEMREF:.*]] = bufferization.to_memref %[[G]]
+// CHECK-DAG: %[[F_MEMREF:.*]] = bufferization.to_buffer %[[F]]
+// CHECK-DAG: %[[G_MEMREF:.*]] = bufferization.to_buffer %[[G]]
// CHECK: %[[ALLOC:.*]] = memref.alloc() {{.*}} : memref<8x9xf32>
// CHECK: %[[SUBVIEW1:.*]] = memref.subview %[[ALLOC]][0, 0] [8, 4] [1, 1]
// CHECK: memref.copy %[[F_MEMREF]], %[[SUBVIEW1]]
@@ -656,8 +656,8 @@ func.func @tensor.concat_different_shapes(%f: tensor<8x4xf32>, %g: tensor<8x5xf3
// CHECK-LABEL: func @tensor.concat_dynamic(
// CHECK-SAME: %[[F:.*]]: tensor<8x?xf32>,
// CHECK-SAME: %[[G:.*]]: tensor<8x?xf32>
-// CHECK-DAG: %[[F_MEMREF:.*]] = bufferization.to_memref %[[F]]
-// CHECK-DAG: %[[G_MEMREF:.*]] = bufferization.to_memref %[[G]]
+// CHECK-DAG: %[[F_MEMREF:.*]] = bufferization.to_buffer %[[F]]
+// CHECK-DAG: %[[G_MEMREF:.*]] = bufferization.to_buffer %[[G]]
// CHECK-DAG: %[[c1:.*]] = arith.constant 1 : index
// CHECK-DAG: %[[F_DIM:.*]] = memref.dim %[[F_MEMREF]], %[[c1]]
// CHECK-DAG: %[[G_DIM:.*]] = memref.dim %[[G_MEMREF]], %[[c1]]
@@ -682,8 +682,8 @@ func.func @tensor.concat_dynamic(%f: tensor<8x?xf32>, %g: tensor<8x?xf32>) -> te
// CHECK-LABEL: func @tensor.concat_dynamic_nonconcat_dim(
// CHECK-SAME: %[[F:.*]]: tensor<?x?xf32>,
// CHECK-SAME: %[[G:.*]]: tensor<?x?xf32>
-// CHECK-DAG: %[[F_MEMREF:.*]] = bufferization.to_memref %[[F]]
-// CHECK-DAG: %[[G_MEMREF:.*]] = bufferization.to_memref %[[G]]
+// CHECK-DAG: %[[F_MEMREF:.*]] = bufferization.to_buffer %[[F]]
+// CHECK-DAG: %[[G_MEMREF:.*]] = bufferization.to_buffer %[[G]]
// CHECK-DAG: %[[c1:.*]] = arith.constant 1 : index
// CHECK-DAG: %[[c0:.*]] = arith.constant 0 : index
// CHECK-DAG: %[[F_DIM:.*]] = memref.dim %[[F_MEMREF]], %[[c1]]
More information about the Mlir-commits
mailing list