[Mlir-commits] [mlir] [mlir][Linalg] implement bufferization for `linalg.pack` (PR #177982)
Ryutaro Okada
llvmlistbot at llvm.org
Mon Jan 26 08:38:51 PST 2026
https://github.com/sakupan102 updated https://github.com/llvm/llvm-project/pull/177982
>From 7bfe3ae09cfeadf8925094432554cab86c032e93 Mon Sep 17 00:00:00 2001
From: Ryutaro Okada <1015ryu88 at gmail.com>
Date: Tue, 27 Jan 2026 00:46:44 +0900
Subject: [PATCH 1/2] [mlir][Linalg] implement bufferization for `linalg.pack`
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
Add a BufferizableOpInterface implementation for `linalg.pack` now that the
op supports memref semantics (https://github.com/llvm/llvm-project/commit/4b066c7fff3455dc547fabb676583391febe41e9).
This completes the op’s bufferization path. The destination operand is not
treated as a memory read, which avoids a copy-before-write for it.
Signed-off-by: Ryutaro Okada <1015ryu88 at gmail.com>
---
.../BufferizableOpInterfaceImpl.cpp | 42 +++++++++++++++++++
mlir/test/Dialect/Linalg/bufferize.mlir | 17 ++++++++
2 files changed, 59 insertions(+)
diff --git a/mlir/lib/Dialect/Linalg/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/Transforms/BufferizableOpInterfaceImpl.cpp
index 3512ecd9d2eb2..60c685578682a 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -191,6 +191,47 @@ struct SoftmaxOpInterface
return success();
}
};
+
+struct PackOpInterface
+    : public DstBufferizableOpInterfaceExternalModel<PackOpInterface,
+                                                     linalg::PackOp> {
+  /// The dest (DPS init) is not read, so no copy-before-write is needed.
+  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
+                              const AnalysisState &state) const {
+    return !cast<linalg::PackOp>(op).isDpsInit(&opOperand);
+  }
+
+  /// Re-creates the pack on memrefs, writing into the destination buffer.
+  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
+                          const BufferizationOptions &options,
+                          BufferizationState &state) const {
+    auto packOp = cast<linalg::PackOp>(op);
+    if (packOp.hasPureBufferSemantics())
+      return success(); // Already bufferized; nothing to do.
+    if (!packOp.hasPureTensorSemantics())
+      return packOp.emitError() << "op does not have pure tensor semantics";
+
+    FailureOr<Value> srcMemref =
+        getBuffer(rewriter, packOp.getSource(), options, state);
+    if (failed(srcMemref))
+      return failure();
+    FailureOr<Value> dstMemref =
+        getBuffer(rewriter, packOp.getDest(), options, state);
+    if (failed(dstMemref))
+      return failure();
+
+    // Operand order: source, dest, optional padding, dynamic inner tiles.
+    SmallVector<Value> newOperands = {*srcMemref, *dstMemref};
+    if (Value padding = packOp.getPaddingValue())
+      newOperands.push_back(padding);
+    llvm::append_range(newOperands, packOp.getInnerTiles());
+    // No result types: the memref form writes in place into the dest buffer.
+    linalg::PackOp::create(rewriter, packOp.getLoc(), TypeRange(), newOperands,
+                           op->getAttrs());
+    replaceOpWithBufferizedValues(rewriter, op, *dstMemref);
+    return success();
+  }
+};
} // namespace
void mlir::linalg::registerBufferizableOpInterfaceExternalModels(
@@ -206,5 +247,6 @@ void mlir::linalg::registerBufferizableOpInterfaceExternalModels(
>::registerOpInterface(ctx);
SoftmaxOp::attachInterface<SoftmaxOpInterface>(*ctx);
+ PackOp::attachInterface<PackOpInterface>(*ctx);
});
}
diff --git a/mlir/test/Dialect/Linalg/bufferize.mlir b/mlir/test/Dialect/Linalg/bufferize.mlir
index 1c6cb88fa028b..2cb09c39b5776 100644
--- a/mlir/test/Dialect/Linalg/bufferize.mlir
+++ b/mlir/test/Dialect/Linalg/bufferize.mlir
@@ -206,3 +206,20 @@ func.func @bufferize_softmax(%arg0: tensor<2x16x32xf32>, %arg1: tensor<2x16x32xf
outs(%arg1: tensor<2x16x32xf32>) -> tensor<2x16x32xf32>
return %1 : tensor<2x16x32xf32>
}
+
+// -----
+
+// CHECK-LABEL: func @bufferize_pack(
+// CHECK-SAME: %[[SRC:.*]]: tensor<128x256xf32>, %[[DST:.*]]: tensor<16x8x8x32xf32>) -> tensor<16x8x8x32xf32> {
+// CHECK-DAG: %[[SRC_BUF:.*]] = bufferization.to_buffer %[[SRC]] : tensor<128x256xf32> to memref<128x256xf32>
+// CHECK-DAG: %[[DST_BUF:.*]] = memref.alloc() {{.*}} : memref<16x8x8x32xf32>
+// CHECK-NOT: memref.copy
+// CHECK: linalg.pack %[[SRC_BUF]] inner_dims_pos = [0, 1] inner_tiles = [8, 32] into %[[DST_BUF]] : memref<128x256xf32> -> memref<16x8x8x32xf32>
+// CHECK: %[[RESULT:.*]] = bufferization.to_tensor %[[DST_BUF]] : memref<16x8x8x32xf32> to tensor<16x8x8x32xf32>
+// CHECK: return %[[RESULT]] : tensor<16x8x8x32xf32>
+func.func @bufferize_pack(%source: tensor<128x256xf32>, %dest: tensor<16x8x8x32xf32>) -> tensor<16x8x8x32xf32> {
+ %0 = linalg.pack %source inner_dims_pos = [0, 1] inner_tiles = [8, 32]
+ into %dest : tensor<128x256xf32> -> tensor<16x8x8x32xf32>
+ return %0 : tensor<16x8x8x32xf32>
+}
+
>From 340a43568744f6c2cefad3426a9695b5f1298678 Mon Sep 17 00:00:00 2001
From: Ryutaro Okada <1015ryu88 at gmail.com>
Date: Tue, 27 Jan 2026 01:38:29 +0900
Subject: [PATCH 2/2] expand test to include padding_value and outer_dims_perm
Signed-off-by: Ryutaro Okada <1015ryu88 at gmail.com>
---
mlir/test/Dialect/Linalg/bufferize.mlir | 24 +++++++++++++-----------
1 file changed, 13 insertions(+), 11 deletions(-)
diff --git a/mlir/test/Dialect/Linalg/bufferize.mlir b/mlir/test/Dialect/Linalg/bufferize.mlir
index 2cb09c39b5776..49585af730807 100644
--- a/mlir/test/Dialect/Linalg/bufferize.mlir
+++ b/mlir/test/Dialect/Linalg/bufferize.mlir
@@ -210,16 +210,18 @@ func.func @bufferize_softmax(%arg0: tensor<2x16x32xf32>, %arg1: tensor<2x16x32xf
// -----
// CHECK-LABEL: func @bufferize_pack(
-// CHECK-SAME: %[[SRC:.*]]: tensor<128x256xf32>, %[[DST:.*]]: tensor<16x8x8x32xf32>) -> tensor<16x8x8x32xf32> {
-// CHECK-DAG: %[[SRC_BUF:.*]] = bufferization.to_buffer %[[SRC]] : tensor<128x256xf32> to memref<128x256xf32>
-// CHECK-DAG: %[[DST_BUF:.*]] = memref.alloc() {{.*}} : memref<16x8x8x32xf32>
+// CHECK-SAME: %[[SRC:.*]]: tensor<200x127x256xf32>, %[[DST:.*]]: tensor<256x64x200x2xf32>) -> tensor<256x64x200x2xf32> {
+// CHECK-DAG: %[[CST:.*]] = arith.constant 0.000000e+00 : f32
+// CHECK-DAG: %[[SRC_BUF:.*]] = bufferization.to_buffer %[[SRC]] : tensor<200x127x256xf32> to memref<200x127x256xf32>
+// CHECK-DAG: %[[DST_BUF:.*]] = memref.alloc() {{.*}} : memref<256x64x200x2xf32>
// CHECK-NOT: memref.copy
-// CHECK: linalg.pack %[[SRC_BUF]] inner_dims_pos = [0, 1] inner_tiles = [8, 32] into %[[DST_BUF]] : memref<128x256xf32> -> memref<16x8x8x32xf32>
-// CHECK: %[[RESULT:.*]] = bufferization.to_tensor %[[DST_BUF]] : memref<16x8x8x32xf32> to tensor<16x8x8x32xf32>
-// CHECK: return %[[RESULT]] : tensor<16x8x8x32xf32>
-func.func @bufferize_pack(%source: tensor<128x256xf32>, %dest: tensor<16x8x8x32xf32>) -> tensor<16x8x8x32xf32> {
- %0 = linalg.pack %source inner_dims_pos = [0, 1] inner_tiles = [8, 32]
- into %dest : tensor<128x256xf32> -> tensor<16x8x8x32xf32>
- return %0 : tensor<16x8x8x32xf32>
+// CHECK: linalg.pack %[[SRC_BUF]] padding_value(%[[CST]] : f32) outer_dims_perm = [2, 1, 0] inner_dims_pos = [1] inner_tiles = [2] into %[[DST_BUF]] : memref<200x127x256xf32> -> memref<256x64x200x2xf32>
+// CHECK: %[[RESULT:.*]] = bufferization.to_tensor %[[DST_BUF]] : memref<256x64x200x2xf32> to tensor<256x64x200x2xf32>
+// CHECK: return %[[RESULT]] : tensor<256x64x200x2xf32>
+func.func @bufferize_pack(%arg0: tensor<200x127x256xf32>, %arg1: tensor<256x64x200x2xf32>) -> tensor<256x64x200x2xf32> {
+ %pad = arith.constant 0.0 : f32 // fills the partial tile: dim 1 (127) is not a multiple of 2
+ %0 = linalg.pack %arg0 padding_value(%pad : f32) outer_dims_perm = [2, 1, 0]
+ inner_dims_pos = [1] inner_tiles = [2] into %arg1 // dim 1: ceil(127 / 2) = 64 tiles of size 2
+ : tensor<200x127x256xf32> -> tensor<256x64x200x2xf32> // outer dims (200, 64, 256) permuted by [2, 1, 0], then the inner tile
+ return %0 : tensor<256x64x200x2xf32>
}
-
More information about the Mlir-commits
mailing list