[Mlir-commits] [mlir] [mlir][Linalg] implement bufferization for `linalg.pack` (PR #177982)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Mon Jan 26 07:48:38 PST 2026


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir-linalg

Author: Ryutaro Okada (sakupan102)

<details>
<summary>Changes</summary>

Add a BufferizableOpInterface implementation for linalg.pack, now that pack supports memref semantics (https://github.com/llvm/llvm-project/commit/4b066c7fff3455dc547fabb676583391febe41e9). This completes the op’s bufferization path and, because the destination operand is not read, avoids a copy-before-write for it.

---
Full diff: https://github.com/llvm/llvm-project/pull/177982.diff


2 Files Affected:

- (modified) mlir/lib/Dialect/Linalg/Transforms/BufferizableOpInterfaceImpl.cpp (+42) 
- (modified) mlir/test/Dialect/Linalg/bufferize.mlir (+17) 


``````````diff
diff --git a/mlir/lib/Dialect/Linalg/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/Transforms/BufferizableOpInterfaceImpl.cpp
index 3512ecd9d2eb2..60c685578682a 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -191,6 +191,47 @@ struct SoftmaxOpInterface
     return success();
   }
 };
+
+struct PackOpInterface // BufferizableOpInterface model for linalg.pack.
+    : public DstBufferizableOpInterfaceExternalModel<PackOpInterface,
+                                                     linalg::PackOp> {
+  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
+                              const AnalysisState &state) const {
+    auto packOp = cast<linalg::PackOp>(op);
+    return !packOp.isDpsInit(&opOperand); // The dest (init) is not read.
+  }
+
+  // Rewrite a tensor-semantics pack into a result-less pack on buffers.
+  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
+                          const BufferizationOptions &options,
+                          BufferizationState &state) const {
+    auto packOp = cast<linalg::PackOp>(op);
+    if (packOp.hasPureBufferSemantics())
+      return success(); // Already operates on buffers; nothing to do.
+    if (!packOp.hasPureTensorSemantics())
+      return packOp.emitError() << "op does not have pure tensor semantics";
+    FailureOr<Value> sourceBuffer =
+        getBuffer(rewriter, packOp.getSource(), options, state);
+    if (failed(sourceBuffer))
+      return failure();
+    FailureOr<Value> destBuffer =
+        getBuffer(rewriter, packOp.getDest(), options, state);
+    if (failed(destBuffer))
+      return failure();
+    // Operands: source and dest buffers, optional padding, inner tiles.
+    SmallVector<Value> operands;
+    operands.push_back(*sourceBuffer);
+    operands.push_back(*destBuffer);
+    if (auto val = packOp.getPaddingValue())
+      operands.push_back(val);
+    llvm::append_range(operands, packOp.getInnerTiles());
+    // Reuse the original attributes; the buffer form has no results.
+    linalg::PackOp::create(rewriter, packOp.getLoc(), TypeRange{}, operands,
+                           op->getAttrs());
+    replaceOpWithBufferizedValues(rewriter, op, *destBuffer);
+    return success();
+  }
+};
 } // namespace
 
 void mlir::linalg::registerBufferizableOpInterfaceExternalModels(
@@ -206,5 +247,6 @@ void mlir::linalg::registerBufferizableOpInterfaceExternalModels(
         >::registerOpInterface(ctx);
 
     SoftmaxOp::attachInterface<SoftmaxOpInterface>(*ctx);
+    PackOp::attachInterface<PackOpInterface>(*ctx);
   });
 }
diff --git a/mlir/test/Dialect/Linalg/bufferize.mlir b/mlir/test/Dialect/Linalg/bufferize.mlir
index 1c6cb88fa028b..2cb09c39b5776 100644
--- a/mlir/test/Dialect/Linalg/bufferize.mlir
+++ b/mlir/test/Dialect/Linalg/bufferize.mlir
@@ -206,3 +206,20 @@ func.func @bufferize_softmax(%arg0: tensor<2x16x32xf32>, %arg1: tensor<2x16x32xf
       outs(%arg1: tensor<2x16x32xf32>) -> tensor<2x16x32xf32>
   return %1 : tensor<2x16x32xf32>
 }
+
+// -----
+
+// linalg.pack does not read its destination: expect a fresh alloc, no copy.
+// CHECK-LABEL: func @bufferize_pack(
+// CHECK-SAME:   %[[SRC:.*]]: tensor<128x256xf32>, %[[DST:.*]]: tensor<16x8x8x32xf32>) -> tensor<16x8x8x32xf32> {
+// CHECK-DAG:     %[[SRC_BUF:.*]] = bufferization.to_buffer %[[SRC]] : tensor<128x256xf32> to memref<128x256xf32>
+// CHECK-DAG:     %[[DST_BUF:.*]] = memref.alloc() {{.*}} : memref<16x8x8x32xf32>
+// CHECK-NOT:     memref.copy
+// CHECK:         linalg.pack %[[SRC_BUF]] inner_dims_pos = [0, 1] inner_tiles = [8, 32] into %[[DST_BUF]] : memref<128x256xf32> -> memref<16x8x8x32xf32>
+// CHECK:         %[[RESULT:.*]] = bufferization.to_tensor %[[DST_BUF]] : memref<16x8x8x32xf32> to tensor<16x8x8x32xf32>
+// CHECK:         return %[[RESULT]] : tensor<16x8x8x32xf32>
+func.func @bufferize_pack(%source: tensor<128x256xf32>, %dest: tensor<16x8x8x32xf32>) -> tensor<16x8x8x32xf32> {
+  %0 = linalg.pack %source inner_dims_pos = [0, 1] inner_tiles = [8, 32]
+      into %dest : tensor<128x256xf32> -> tensor<16x8x8x32xf32>
+  return %0 : tensor<16x8x8x32xf32>
+}

``````````

</details>


https://github.com/llvm/llvm-project/pull/177982


More information about the Mlir-commits mailing list