[Mlir-commits] [mlir] [mlir][bufferization] implement BufferizableOpInterface for concat op (PR #140171)

Jeremy Kun llvmlistbot at llvm.org
Thu May 15 20:57:26 PDT 2025


https://github.com/j2kun updated https://github.com/llvm/llvm-project/pull/140171

>From 0c44188bfe415450a3e7888a5d229c8b68b22776 Mon Sep 17 00:00:00 2001
From: Jeremy Kun <j2kun at users.noreply.github.com>
Date: Thu, 15 May 2025 16:38:37 -0700
Subject: [PATCH 1/3] add bufferization for concat op

---
 mlir/lib/Dialect/Tensor/IR/TensorDialect.cpp  |  4 +-
 .../BufferizableOpInterfaceImpl.cpp           | 98 +++++++++++++++++++
 mlir/test/Dialect/Tensor/bufferize.mlir       | 42 ++++++++
 3 files changed, 142 insertions(+), 2 deletions(-)

diff --git a/mlir/lib/Dialect/Tensor/IR/TensorDialect.cpp b/mlir/lib/Dialect/Tensor/IR/TensorDialect.cpp
index 8af087cbf0f61..e7d8f52d309c9 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorDialect.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorDialect.cpp
@@ -49,8 +49,8 @@ void TensorDialect::initialize() {
       >();
   addInterfaces<TensorInlinerInterface>();
   declarePromisedInterfaces<
-      bufferization::BufferizableOpInterface, CastOp, CollapseShapeOp, DimOp,
-      EmptyOp, ExpandShapeOp, ExtractSliceOp, ExtractOp, FromElementsOp,
+      bufferization::BufferizableOpInterface, CastOp, CollapseShapeOp, ConcatOp,
+      DimOp, EmptyOp, ExpandShapeOp, ExtractSliceOp, ExtractOp, FromElementsOp,
       GenerateOp, InsertOp, InsertSliceOp, PadOp, ParallelInsertSliceOp, RankOp,
       ReshapeOp, SplatOp>();
   declarePromisedInterfaces<transform::FindPayloadReplacementOpInterface,
diff --git a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
index 31014172a9555..e19d6a50e706a 100644
--- a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -1048,6 +1048,103 @@ struct SplatOpInterface
   }
 };
 
+/// Bufferization of tensor.concat. Bufferizes to a new allocation that is
+/// filled with copy ops. Similar to tensor.from_elements, but using memref.copy
+/// on subviews instead of memref.store.
+struct ConcatOpInterface
+    : public BufferizableOpInterface::ExternalModel<ConcatOpInterface,
+                                                    tensor::ConcatOp> {
+
+  bool bufferizesToAllocation(Operation *op, Value value) const { return true; }
+
+  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
+                               const AnalysisState &state) const {
+    return true;
+  }
+
+  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
+                              const AnalysisState &state) const {
+    return true;
+  }
+
+  AliasingValueList getAliasingValues(Operation *op, OpOperand &opOperand,
+                                      const AnalysisState &state) const {
+    return {{op->getResult(0), BufferRelation::Equivalent}};
+  }
+
+  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
+                          const BufferizationOptions &options) const {
+    OpBuilder::InsertionGuard g(rewriter);
+    auto concatOp = cast<tensor::ConcatOp>(op);
+
+    // Allocate memory.
+    Location loc = op->getLoc();
+    FailureOr<Value> tensorAlloc = allocateTensorForShapedValue(
+        rewriter, loc, concatOp.getResult(), options,
+        /*copy=*/false);
+    if (failed(tensorAlloc))
+      return failure();
+    auto tensorType = cast<RankedTensorType>(tensorAlloc->getType());
+
+    // TODO: Implement memory space for this op.
+    if (options.defaultMemorySpaceFn(tensorType) != Attribute())
+      return op->emitError("memory space not implemented yet");
+
+    MemRefLayoutAttrInterface layout;
+    MemRefType memrefType =
+        MemRefType::get(concatOp.getResultType().getShape(),
+                        concatOp.getResultType().getElementType(), layout);
+    Value dstBuffer = rewriter.create<bufferization::ToMemrefOp>(
+        op->getLoc(), memrefType, *tensorAlloc);
+
+    // Extract the dimension for the concat op
+    uint64_t concatDim = concatOp.getDim();
+
+    SmallVector<OpFoldResult> offsets(tensorType.getRank(),
+                                      rewriter.getIndexAttr(0));
+    SmallVector<OpFoldResult> strides(tensorType.getRank(),
+                                      rewriter.getIndexAttr(1));
+    SmallVector<OpFoldResult> sizes;
+    for (auto dimSize : tensorType.getShape()) {
+      sizes.push_back(rewriter.getIndexAttr(dimSize));
+    }
+
+    int concatDimOffset = 0;
+    for (auto operand : concatOp.getInputs()) {
+      // Get the buffer for the operand.
+      FailureOr<Value> srcBuffer = getBuffer(rewriter, operand, options);
+      if (failed(srcBuffer))
+        return failure();
+
+      // Each operand may have a different size along the concat dimension,
+      // so the offset on that axis must accumulate through the loop, and the
+      // size must change to the size of the current operand.
+      auto operandTensorType = cast<RankedTensorType>(operand.getType());
+      int operandConcatDimSize = operandTensorType.getDimSize(concatDim);
+      sizes[concatDim] = rewriter.getIndexAttr(operandConcatDimSize);
+      offsets[concatDim] = rewriter.getIndexAttr(concatDimOffset);
+
+      // Create a subview of the destination buffer.
+      auto dstMemrefType = cast<MemRefType>(memrefType);
+      MemRefType subviewMemRefType =
+          memref::SubViewOp::inferRankReducedResultType(
+              operandTensorType.getShape(), dstMemrefType, offsets, sizes,
+              strides);
+      Value subview = rewriter.create<memref::SubViewOp>(
+          loc, subviewMemRefType, dstBuffer, offsets, sizes, strides);
+
+      // Copy the source buffer into the destination subview.
+      if (failed(options.createMemCpy(rewriter, loc, *srcBuffer, subview)))
+        return failure();
+
+      concatDimOffset += operandConcatDimSize;
+    }
+
+    replaceOpWithBufferizedValues(rewriter, op, dstBuffer);
+    return success();
+  }
+};
+
 } // namespace
 } // namespace tensor
 } // namespace mlir
@@ -1057,6 +1154,7 @@ void mlir::tensor::registerBufferizableOpInterfaceExternalModels(
   registry.addExtension(+[](MLIRContext *ctx, tensor::TensorDialect *dialect) {
     CastOp::attachInterface<CastOpInterface>(*ctx);
     CollapseShapeOp::attachInterface<CollapseShapeOpInterface>(*ctx);
+    ConcatOp::attachInterface<ConcatOpInterface>(*ctx);
     DimOp::attachInterface<DimOpInterface>(*ctx);
     EmptyOp::attachInterface<EmptyOpInterface>(*ctx);
     ExpandShapeOp::attachInterface<ExpandShapeOpInterface>(*ctx);
diff --git a/mlir/test/Dialect/Tensor/bufferize.mlir b/mlir/test/Dialect/Tensor/bufferize.mlir
index c1beed95f2006..a9ee707c670b9 100644
--- a/mlir/test/Dialect/Tensor/bufferize.mlir
+++ b/mlir/test/Dialect/Tensor/bufferize.mlir
@@ -615,6 +615,48 @@ func.func @tensor.splat(%f: f32) -> tensor<10x2x4xf32> {
 
 // -----
 
+// CHECK-LABEL:   func @tensor.concat(
+// CHECK-SAME:        %[[F:.*]]: tensor<8xf32>)
+// CHECK:           %[[F_MEMREF:.*]] = bufferization.to_memref %[[F]]
+// CHECK:           memref.copy %[[F_MEMREF]], %[[F_ALLOC:.*]] :
+// CHECK:           memref.copy %[[F_MEMREF]], %[[F_ALLOC_2:.*]] :
+// CHECK:           %[[ALLOC:.*]] = memref.alloc() {{.*}} : memref<16xf32>
+// CHECK:           %[[SUBVIEW1:.*]] = memref.subview %[[ALLOC]][0] [8] [1]
+// CHECK:           memref.copy %[[F_ALLOC]], %[[SUBVIEW1]]
+// CHECK:           %[[SUBVIEW2:.*]] = memref.subview %[[ALLOC]][8] [8] [1]
+// CHECK:           memref.copy %[[F_ALLOC_2]], %[[SUBVIEW2]]
+// CHECK:           %[[RET:.*]] = bufferization.to_tensor %[[ALLOC]]
+// CHECK:           return %[[RET]]
+// CHECK:         }
+func.func @tensor.concat(%f: tensor<8xf32>) -> tensor<16xf32> {
+  %t = tensor.concat dim(0) %f, %f : (tensor<8xf32>, tensor<8xf32>) -> tensor<16xf32>
+  return %t : tensor<16xf32>
+}
+
+// -----
+
+// CHECK-LABEL:   func @tensor.concat_different_shapes(
+// CHECK-SAME:        %[[F:.*]]: tensor<8x4xf32>
+// CHECK-SAME:        %[[G:.*]]: tensor<8x5xf32>
+// CHECK-DAG:       %[[F_MEMREF:.*]] = bufferization.to_memref %[[F]]
+// CHECK-DAG:       %[[G_MEMREF:.*]] = bufferization.to_memref %[[G]]
+// CHECK:           memref.copy %[[F_MEMREF]], %[[F_ALLOC:.*]] :
+// CHECK:           memref.copy %[[G_MEMREF]], %[[F_ALLOC_2:.*]] :
+// CHECK:           %[[ALLOC:.*]] = memref.alloc() {{.*}} : memref<8x9xf32>
+// CHECK:           %[[SUBVIEW1:.*]] = memref.subview %[[ALLOC]][0, 0] [8, 4] [1, 1]
+// CHECK:           memref.copy %[[F_ALLOC]], %[[SUBVIEW1]]
+// CHECK:           %[[SUBVIEW2:.*]] = memref.subview %[[ALLOC]][0, 4] [8, 5] [1, 1]
+// CHECK:           memref.copy %[[F_ALLOC_2]], %[[SUBVIEW2]]
+// CHECK:           %[[RET:.*]] = bufferization.to_tensor %[[ALLOC]]
+// CHECK:           return %[[RET]]
+// CHECK:         }
+func.func @tensor.concat_different_shapes(%f: tensor<8x4xf32>, %g: tensor<8x5xf32>) -> tensor<8x9xf32> {
+  %t = tensor.concat dim(1) %f, %g : (tensor<8x4xf32>, tensor<8x5xf32>) -> tensor<8x9xf32>
+  return %t : tensor<8x9xf32>
+}
+
+// -----
+
 // CHECK-LABEL: func @tensor.splat_dynamic(
 // CHECK-SAME:  %[[F:[a-zA-Z0-9_]+]]: f32
 // CHECK-SAME:  %[[M:[a-zA-Z0-9_]+]]: index
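
For reference, the core rewrite implemented by ConcatOpInterface::bufferize boils down to the pattern sketched below for the static @tensor.concat test above: allocate a destination buffer, take a subview per operand along the concat dimension, and memref.copy each operand into its subview. This is only an illustrative sketch mirroring the CHECK lines; value names are made up, %f_memref stands for the bufferized input (assumed to have an identity layout), and the to_memref/to_tensor conversions as well as the extra operand copies that this first revision still produces are elided:

  %alloc = memref.alloc() : memref<16xf32>
  %sv0 = memref.subview %alloc[0] [8] [1]
      : memref<16xf32> to memref<8xf32, strided<[1]>>
  memref.copy %f_memref, %sv0 : memref<8xf32> to memref<8xf32, strided<[1]>>
  %sv1 = memref.subview %alloc[8] [8] [1]
      : memref<16xf32> to memref<8xf32, strided<[1], offset: 8>>
  memref.copy %f_memref, %sv1
      : memref<8xf32> to memref<8xf32, strided<[1], offset: 8>>
  // The result tensor is then rematerialized with bufferization.to_tensor.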

>From 6ec00404c76b3b5848c8d0c83cb8944b5660dd45 Mon Sep 17 00:00:00 2001
From: Jeremy Kun <j2kun at users.noreply.github.com>
Date: Thu, 15 May 2025 20:15:06 -0700
Subject: [PATCH 2/3] remove duplicates by marking operands not written-to

---
 .../Transforms/BufferizableOpInterfaceImpl.cpp       |  6 +++---
 mlir/test/Dialect/Tensor/bufferize.mlir              | 12 ++++--------
 2 files changed, 7 insertions(+), 11 deletions(-)

diff --git a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
index e19d6a50e706a..19d294b40244e 100644
--- a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -1059,7 +1059,7 @@ struct ConcatOpInterface
 
   bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                                const AnalysisState &state) const {
-    return true;
+    return false;
   }
 
   bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
@@ -1069,7 +1069,7 @@ struct ConcatOpInterface
 
   AliasingValueList getAliasingValues(Operation *op, OpOperand &opOperand,
                                       const AnalysisState &state) const {
-    return {{op->getResult(0), BufferRelation::Equivalent}};
+    return {};
   }
 
   LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
@@ -1109,7 +1109,7 @@ struct ConcatOpInterface
       sizes.push_back(rewriter.getIndexAttr(dimSize));
     }
 
-    int concatDimOffset = 0;
+    int64_t concatDimOffset = 0;
     for (auto operand : concatOp.getInputs()) {
       // Get the buffer for the operand.
       FailureOr<Value> srcBuffer = getBuffer(rewriter, operand, options);
diff --git a/mlir/test/Dialect/Tensor/bufferize.mlir b/mlir/test/Dialect/Tensor/bufferize.mlir
index a9ee707c670b9..1225eb840eed0 100644
--- a/mlir/test/Dialect/Tensor/bufferize.mlir
+++ b/mlir/test/Dialect/Tensor/bufferize.mlir
@@ -618,13 +618,11 @@ func.func @tensor.splat(%f: f32) -> tensor<10x2x4xf32> {
 // CHECK-LABEL:   func @tensor.concat(
 // CHECK-SAME:        %[[F:.*]]: tensor<8xf32>)
 // CHECK:           %[[F_MEMREF:.*]] = bufferization.to_memref %[[F]]
-// CHECK:           memref.copy %[[F_MEMREF]], %[[F_ALLOC:.*]] :
-// CHECK:           memref.copy %[[F_MEMREF]], %[[F_ALLOC_2:.*]] :
 // CHECK:           %[[ALLOC:.*]] = memref.alloc() {{.*}} : memref<16xf32>
 // CHECK:           %[[SUBVIEW1:.*]] = memref.subview %[[ALLOC]][0] [8] [1]
-// CHECK:           memref.copy %[[F_ALLOC]], %[[SUBVIEW1]]
+// CHECK:           memref.copy %[[F_MEMREF]], %[[SUBVIEW1]]
 // CHECK:           %[[SUBVIEW2:.*]] = memref.subview %[[ALLOC]][8] [8] [1]
-// CHECK:           memref.copy %[[F_ALLOC_2]], %[[SUBVIEW2]]
+// CHECK:           memref.copy %[[F_MEMREF]], %[[SUBVIEW2]]
 // CHECK:           %[[RET:.*]] = bufferization.to_tensor %[[ALLOC]]
 // CHECK:           return %[[RET]]
 // CHECK:         }
@@ -640,13 +638,11 @@ func.func @tensor.concat(%f: tensor<8xf32>) -> tensor<16xf32> {
 // CHECK-SAME:        %[[G:.*]]: tensor<8x5xf32>
 // CHECK-DAG:       %[[F_MEMREF:.*]] = bufferization.to_memref %[[F]]
 // CHECK-DAG:       %[[G_MEMREF:.*]] = bufferization.to_memref %[[G]]
-// CHECK:           memref.copy %[[F_MEMREF]], %[[F_ALLOC:.*]] :
-// CHECK:           memref.copy %[[G_MEMREF]], %[[F_ALLOC_2:.*]] :
 // CHECK:           %[[ALLOC:.*]] = memref.alloc() {{.*}} : memref<8x9xf32>
 // CHECK:           %[[SUBVIEW1:.*]] = memref.subview %[[ALLOC]][0, 0] [8, 4] [1, 1]
-// CHECK:           memref.copy %[[F_ALLOC]], %[[SUBVIEW1]]
+// CHECK:           memref.copy %[[F_MEMREF]], %[[SUBVIEW1]]
 // CHECK:           %[[SUBVIEW2:.*]] = memref.subview %[[ALLOC]][0, 4] [8, 5] [1, 1]
-// CHECK:           memref.copy %[[F_ALLOC_2]], %[[SUBVIEW2]]
+// CHECK:           memref.copy %[[G_MEMREF]], %[[SUBVIEW2]]
 // CHECK:           %[[RET:.*]] = bufferization.to_tensor %[[ALLOC]]
 // CHECK:           return %[[RET]]
 // CHECK:         }
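
The practical effect of this change: because the operands are only read (bufferizesToMemoryWrite now returns false) and the result is a fresh allocation that aliases no operand (the AliasingValueList is empty), One-Shot Bufferize no longer has to privatize the inputs before the op. Schematically, with illustrative names and the surrounding IR elided, the first revision produced

  memref.copy %f_memref, %f_alloc    // copy of the operand inserted by the analysis
  ...
  memref.copy %f_alloc, %subview     // copy into the concat allocation

whereas with this patch the intermediate allocation disappears and the copy reads straight from the input buffer:

  memref.copy %f_memref, %subview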

>From 5e441a53cc24fa8e182d3e1b4db644a748da9fab Mon Sep 17 00:00:00 2001
From: Jeremy Kun <j2kun at users.noreply.github.com>
Date: Thu, 15 May 2025 20:57:15 -0700
Subject: [PATCH 3/3] support dynamic tensors

---
 .../BufferizableOpInterfaceImpl.cpp           | 43 ++++++++++++++++---
 mlir/test/Dialect/Tensor/bufferize.mlir       | 26 +++++++++++
 2 files changed, 63 insertions(+), 6 deletions(-)

diff --git a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
index 19d294b40244e..935716ed28711 100644
--- a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -1099,17 +1099,35 @@ struct ConcatOpInterface
 
     // Extract the dimension for the concat op
     uint64_t concatDim = concatOp.getDim();
+    bool dynamicConcatDim = false;
 
     SmallVector<OpFoldResult> offsets(tensorType.getRank(),
                                       rewriter.getIndexAttr(0));
     SmallVector<OpFoldResult> strides(tensorType.getRank(),
                                       rewriter.getIndexAttr(1));
     SmallVector<OpFoldResult> sizes;
-    for (auto dimSize : tensorType.getShape()) {
-      sizes.push_back(rewriter.getIndexAttr(dimSize));
+
+    for (const auto &[dimIdx, dimSize] :
+         llvm::enumerate(tensorType.getShape())) {
+      if (dimSize == ShapedType::kDynamic) {
+        auto dimOp = rewriter.create<memref::DimOp>(loc, dstBuffer, dimIdx);
+        sizes.push_back(dimOp.getResult());
+        if (dimIdx == concatDim)
+          dynamicConcatDim = true;
+      } else {
+        sizes.push_back(rewriter.getIndexAttr(dimSize));
+      }
     }
 
     int64_t concatDimOffset = 0;
+    std::optional<Value> dynamicOffset;
+    std::optional<Value> dynamicSize;
+    if (dynamicConcatDim) {
+      // One or more operands have dynamic size, so we must accumulate the
+      // offset with arith ops.
+      dynamicOffset = rewriter.create<arith::ConstantIndexOp>(loc, 0);
+    }
+
     for (auto operand : concatOp.getInputs()) {
       // Get the buffer for the operand.
       FailureOr<Value> srcBuffer = getBuffer(rewriter, operand, options);
@@ -1120,9 +1138,17 @@ struct ConcatOpInterface
       // so the offset on that axis must accumulate through the loop, and the
       // size must change to the size of the current operand.
       auto operandTensorType = cast<RankedTensorType>(operand.getType());
-      int operandConcatDimSize = operandTensorType.getDimSize(concatDim);
-      sizes[concatDim] = rewriter.getIndexAttr(operandConcatDimSize);
-      offsets[concatDim] = rewriter.getIndexAttr(concatDimOffset);
+      int64_t operandConcatDimSize = operandTensorType.getDimSize(concatDim);
+
+      if (dynamicConcatDim) {
+        offsets[concatDim] = dynamicOffset.value();
+        dynamicSize = rewriter.create<memref::DimOp>(loc, *srcBuffer, concatDim)
+                          .getResult();
+        sizes[concatDim] = dynamicSize.value();
+      } else {
+        sizes[concatDim] = rewriter.getIndexAttr(operandConcatDimSize);
+        offsets[concatDim] = rewriter.getIndexAttr(concatDimOffset);
+      }
 
       // Create a subview of the destination buffer.
       auto dstMemrefType = cast<MemRefType>(memrefType);
@@ -1137,7 +1163,12 @@ struct ConcatOpInterface
       if (failed(options.createMemCpy(rewriter, loc, *srcBuffer, subview)))
         return failure();
 
-      concatDimOffset += operandConcatDimSize;
+      if (dynamicConcatDim) {
+        dynamicOffset = rewriter.create<arith::AddIOp>(
+            loc, dynamicOffset.value(), dynamicSize.value());
+      } else {
+        concatDimOffset += operandConcatDimSize;
+      }
     }
 
     replaceOpWithBufferizedValues(rewriter, op, dstBuffer);
diff --git a/mlir/test/Dialect/Tensor/bufferize.mlir b/mlir/test/Dialect/Tensor/bufferize.mlir
index 1225eb840eed0..999d63705a781 100644
--- a/mlir/test/Dialect/Tensor/bufferize.mlir
+++ b/mlir/test/Dialect/Tensor/bufferize.mlir
@@ -653,6 +653,32 @@ func.func @tensor.concat_different_shapes(%f: tensor<8x4xf32>, %g: tensor<8x5xf3
 
 // -----
 
+// CHECK-LABEL:   func @tensor.concat_dynamic(
+// CHECK-SAME:        %[[F:.*]]: tensor<8x?xf32>,
+// CHECK-SAME:        %[[G:.*]]: tensor<8x?xf32>
+// CHECK-DAG:       %[[F_MEMREF:.*]] = bufferization.to_memref %[[F]]
+// CHECK-DAG:       %[[G_MEMREF:.*]] = bufferization.to_memref %[[G]]
+// CHECK-DAG:       %[[c1:.*]] = arith.constant 1 : index
+// CHECK-DAG:       %[[F_DIM:.*]] = memref.dim %[[F_MEMREF]], %[[c1]]
+// CHECK-DAG:       %[[G_DIM:.*]] = memref.dim %[[G_MEMREF]], %[[c1]]
+// CHECK:           %[[ALLOC:.*]] = memref.alloc
+// CHECK-SAME:                                    memref<8x?xf32>
+// CHECK-DAG:       %[[OFFSET:.*]] = arith.constant 0 : index
+// CHECK:           %[[SUBVIEW1:.*]] = memref.subview %[[ALLOC]][0, %[[OFFSET]]] [8, %[[F_DIM]]] [1, 1]
+// CHECK:           memref.copy %[[F_MEMREF]], %[[SUBVIEW1]]
+// CHECK:           %[[OFFSET_2:.*]] = arith.addi %[[OFFSET]], %[[F_DIM]] : index
+// CHECK:           %[[SUBVIEW2:.*]] = memref.subview %[[ALLOC]][0, %[[OFFSET_2]]] [8, %[[G_DIM]]] [1, 1]
+// CHECK:           memref.copy %[[G_MEMREF]], %[[SUBVIEW2]]
+// CHECK:           %[[RET:.*]] = bufferization.to_tensor %[[ALLOC]]
+// CHECK:           return %[[RET]]
+// CHECK:         }
+func.func @tensor.concat_dynamic(%f: tensor<8x?xf32>, %g: tensor<8x?xf32>) -> tensor<8x?xf32> {
+  %t = tensor.concat dim(1) %f, %g : (tensor<8x?xf32>, tensor<8x?xf32>) -> tensor<8x?xf32>
+  return %t : tensor<8x?xf32>
+}
+
+// -----
+
 // CHECK-LABEL: func @tensor.splat_dynamic(
 // CHECK-SAME:  %[[F:[a-zA-Z0-9_]+]]: f32
 // CHECK-SAME:  %[[M:[a-zA-Z0-9_]+]]: index
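
When the concatenation dimension is dynamic, the running offset can no longer be encoded as a static subview attribute, so it is carried through the loop as an index value: each operand's extent along the concat dimension is read with memref.dim and accumulated with arith.addi. A rough MLIR sketch of the output for the @tensor.concat_dynamic test above (illustrative names; input buffers assumed to have identity layout; %dst_dim stands for the dynamic size computed for the allocation):

  %c1 = arith.constant 1 : index
  %f_dim = memref.dim %f_memref, %c1 : memref<8x?xf32>
  %g_dim = memref.dim %g_memref, %c1 : memref<8x?xf32>
  %alloc = memref.alloc(%dst_dim) : memref<8x?xf32>
  %off0 = arith.constant 0 : index
  %sv0 = memref.subview %alloc[0, %off0] [8, %f_dim] [1, 1]
      : memref<8x?xf32> to memref<8x?xf32, strided<[?, 1], offset: ?>>
  memref.copy %f_memref, %sv0
      : memref<8x?xf32> to memref<8x?xf32, strided<[?, 1], offset: ?>>
  %off1 = arith.addi %off0, %f_dim : index
  %sv1 = memref.subview %alloc[0, %off1] [8, %g_dim] [1, 1]
      : memref<8x?xf32> to memref<8x?xf32, strided<[?, 1], offset: ?>>
  memref.copy %g_memref, %sv1
      : memref<8x?xf32> to memref<8x?xf32, strided<[?, 1], offset: ?>>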


