[Mlir-commits] [mlir] [mlir][shard, mpi] Marking explicitly bufferized buffers read-only where applicable (PR #186464)
Frank Schlimbach
llvmlistbot at llvm.org
Fri Mar 13 10:38:19 PDT 2026
https://github.com/fschlimb created https://github.com/llvm/llvm-project/pull/186464
Marking read-only buffers as such avoids unnecessary copies in `one-shot-bufferize`.
From adf8346968b1a7dfe25b2637021003466bc800c6 Mon Sep 17 00:00:00 2001
From: "Schlimbach, Frank" <frank.schlimbach at intel.com>
Date: Fri, 13 Mar 2026 10:35:12 -0700
Subject: [PATCH] marking explicitly bufferized buffers read-only where
applicable
---
mlir/lib/Conversion/ShardToMPI/ShardToMPI.cpp | 16 +++++++++-------
.../ShardToMPI/convert-shard-to-mpi.mlir | 14 +++++++-------
2 files changed, 16 insertions(+), 14 deletions(-)
diff --git a/mlir/lib/Conversion/ShardToMPI/ShardToMPI.cpp b/mlir/lib/Conversion/ShardToMPI/ShardToMPI.cpp
index daf8c9a154bd3..508917db1fc84 100644
--- a/mlir/lib/Conversion/ShardToMPI/ShardToMPI.cpp
+++ b/mlir/lib/Conversion/ShardToMPI/ShardToMPI.cpp
@@ -517,12 +517,14 @@ struct CommOpPattern : public OpConversionPattern<CommOp> {
return MemRefType::get(tensorType.getShape(), tensorType.getElementType());
}
- Value getAsMemref(Value input, ImplicitLocOpBuilder &iBuilder) const {
+ Value getAsMemref(Value input, ImplicitLocOpBuilder &iBuilder,
+ bool readOnly) const {
auto itype = input.getType();
// If the source is a memref, cast it to a tensor.
if (isa<RankedTensorType>(itype)) {
auto memrefType = getMemrefType(cast<ShapedType>(itype));
- input = bufferization::ToBufferOp::create(iBuilder, memrefType, input);
+ input = bufferization::ToBufferOp::create(iBuilder, memrefType, input,
+ readOnly);
} else {
assert(isa<MemRefType>(itype) &&
"expected input to be of MemRefType or TensorType");
@@ -589,7 +591,7 @@ struct ConvertAllReduceOp : public CommOpPattern<AllReduceOp> {
if (failed(gridOp))
return failure();
ImplicitLocOpBuilder iBuilder(op.getLoc(), rewriter);
- Value input = getAsMemref(adaptor.getInput(), iBuilder);
+ Value input = getAsMemref(adaptor.getInput(), iBuilder, true);
MemRefType inType = cast<MemRefType>(input.getType());
if (!memref::isStaticShapeAndContiguousRowMajor(inType))
return op.emitError(
@@ -683,7 +685,7 @@ struct ConvertReduceScatterOp : public CommOpPattern<ReduceScatterOp> {
if (scatterDim == 0) {
// scatter_dim == 0 maps directly to MPI_Reduce_scatter_block.
// Input must be contiguous for MPI.
- Value input = getAsMemref(rawInput, ib);
+ Value input = getAsMemref(rawInput, ib, true);
MemRefType inType = cast<MemRefType>(input.getType());
if (!memref::isStaticShapeAndContiguousRowMajor(inType))
return op.emitError("Input must be a statically shaped memref in "
@@ -743,7 +745,7 @@ struct ConvertReduceScatterOp : public CommOpPattern<ReduceScatterOp> {
// freshly allocated buffer.
auto mpiInType = MemRefType::get(transposedShape, elemType);
Value transposedBuf =
- bufferization::ToBufferOp::create(ib, mpiInType, tensorInput);
+ bufferization::ToBufferOp::create(ib, mpiInType, tensorInput, true);
mpiInput = memref::AllocOp::create(ib, mpiInType);
linalg::CopyOp::create(ib, transposedBuf, mpiInput);
}
@@ -787,7 +789,7 @@ struct ConvertAllGatherOp : public CommOpPattern<AllGatherOp> {
return failure();
ImplicitLocOpBuilder ib(op.getLoc(), rewriter);
- Value input = getAsMemref(adaptor.getInput(), ib);
+ Value input = getAsMemref(adaptor.getInput(), ib, true);
MemRefType inType = cast<MemRefType>(input.getType());
MemRefType outType = getMemrefType(cast<ShapedType>(op.getType()));
auto inputShape = inType.getShape();
@@ -900,7 +902,7 @@ struct ConvertAllGatherOp : public CommOpPattern<AllGatherOp> {
// 4. Cast back to memref if needed.
if (isa<MemRefType>(op.getType()))
finalOutput =
- bufferization::ToBufferOp::create(ib, outType, finalOutput);
+ bufferization::ToBufferOp::create(ib, outType, finalOutput, true);
}
rewriter.replaceOp(op, finalOutput);
diff --git a/mlir/test/Conversion/ShardToMPI/convert-shard-to-mpi.mlir b/mlir/test/Conversion/ShardToMPI/convert-shard-to-mpi.mlir
index 346a55658c170..c80c1798bb77b 100644
--- a/mlir/test/Conversion/ShardToMPI/convert-shard-to-mpi.mlir
+++ b/mlir/test/Conversion/ShardToMPI/convert-shard-to-mpi.mlir
@@ -115,7 +115,7 @@ module attributes { mpi.dlti = #dlti.map<"MPI:comm_world_rank" = 7> } {
%arg0 : tensor<3x4xf32>) -> tensor<3x4xf32> {
// CHECK-DAG: [[vc1_i32:%.*]] = arith.constant 1 : i32
// CHECK-DAG: [[vc2_i32:%.*]] = arith.constant 2 : i32
- // CHECK: [[v0:%.*]] = bufferization.to_buffer [[varg0]] : tensor<3x4xf32> to memref<3x4xf32>
+ // CHECK: [[v0:%.*]] = bufferization.to_buffer [[varg0]] read_only : tensor<3x4xf32> to memref<3x4xf32>
// CHECK: [[valloc:%.*]] = memref.alloc() : memref<3x4xf32>
// CHECK: linalg.copy ins([[v0]] : memref<3x4xf32>) outs([[valloc]] : memref<3x4xf32>)
// CHECK: [[v1:%.*]] = mpi.comm_world : !mpi.comm
@@ -181,7 +181,7 @@ module attributes { mpi.dlti = #dlti.map<"MPI:comm_world_rank" = 7> } {
// CHECK: [[vexpanded:%.*]] = tensor.expand_shape [[varg0]] {{\[\[}}0], [1, 2]] output_shape [2, 3, 4] : tensor<2x12xf32> into tensor<2x3x4xf32>
// CHECK: [[vempty:%.*]] = tensor.empty() : tensor<3x2x4xf32>
// CHECK: [[vtransposed:%.*]] = linalg.transpose ins([[vexpanded]] : tensor<2x3x4xf32>) outs([[vempty]] : tensor<3x2x4xf32>) permutation = [1, 0, 2]
- // CHECK: [[vtobuf:%.*]] = bufferization.to_buffer [[vtransposed]] : tensor<3x2x4xf32> to memref<3x2x4xf32>
+ // CHECK: [[vtobuf:%.*]] = bufferization.to_buffer [[vtransposed]] read_only : tensor<3x2x4xf32> to memref<3x2x4xf32>
// CHECK: [[valloctmp:%.*]] = memref.alloc() : memref<3x2x4xf32>
// CHECK: linalg.copy ins([[vtobuf]] : memref<3x2x4xf32>) outs([[valloctmp]] : memref<3x2x4xf32>)
// CHECK: [[valloc:%.*]] = memref.alloc() : memref<2x4xf32>
@@ -198,7 +198,7 @@ module attributes { mpi.dlti = #dlti.map<"MPI:comm_world_rank" = 7> } {
// CHECK-DAG: [[vc1_i32:%.*]] = arith.constant 1 : i32
// CHECK-DAG: [[vc2_i32:%.*]] = arith.constant 2 : i32
// CHECK-DAG: [[vc4:%.*]] = arith.constant 4 : index
- // CHECK: [[v0:%.*]] = bufferization.to_buffer [[varg0]] : tensor<3x4xf32> to memref<3x4xf32>
+ // CHECK: [[v0:%.*]] = bufferization.to_buffer [[varg0]] read_only : tensor<3x4xf32> to memref<3x4xf32>
// CHECK: [[v1:%.*]] = mpi.comm_world : !mpi.comm
// CHECK: [[vnewcomm:%.*]] = mpi.comm_split([[v1]], [[vc2_i32]], [[vc1_i32]]) : !mpi.comm
// CHECK: [[vsize:%.*]] = mpi.comm_size([[vnewcomm]]) : i32
@@ -222,7 +222,7 @@ module attributes { mpi.dlti = #dlti.map<"MPI:comm_world_rank" = 7> } {
// CHECK-DAG: [[vc2_i32:%.*]] = arith.constant 2 : i32
// CHECK-DAG: [[vc1_i32:%.*]] = arith.constant 1 : i32
// CHECK-DAG: [[vc5:%.*]] = arith.constant 5 : index
- // CHECK: [[v0:%.*]] = bufferization.to_buffer [[varg0]] : tensor<3x4xf32> to memref<3x4xf32>
+ // CHECK: [[v0:%.*]] = bufferization.to_buffer [[varg0]] read_only : tensor<3x4xf32> to memref<3x4xf32>
// CHECK: [[v1:%.*]] = mpi.comm_world : !mpi.comm
// CHECK: [[vnewcomm:%.*]] = mpi.comm_split([[v1]], [[vc1_i32]], [[vc2_i32]]) : !mpi.comm
// CHECK: [[vsize:%.*]] = mpi.comm_size([[vnewcomm]]) : i32
@@ -260,7 +260,7 @@ module attributes { mpi.dlti = #dlti.map<"MPI:comm_world_rank" = 7> } {
// CHECK: [[v4:%.*]] = tensor.empty() : tensor<3x5x4xf32>
// CHECK: [[vtransposed:%.*]] = linalg.transpose ins([[v3]] : tensor<5x3x4xf32>) outs([[v4]] : tensor<3x5x4xf32>) permutation = [1, 0, 2]
// CHECK: [[vcollapsed:%.*]] = tensor.collapse_shape [[vtransposed]] {{\[\[}}0], [1, 2]] : tensor<3x5x4xf32> into tensor<3x20xf32>
- // CHECK: [[v5:%.*]] = bufferization.to_buffer [[vcollapsed]] : tensor<3x20xf32> to memref<3x20xf32>
+ // CHECK: [[v5:%.*]] = bufferization.to_buffer [[vcollapsed]] read_only : tensor<3x20xf32> to memref<3x20xf32>
%0 = shard.all_gather %arg0 on @grid0 grid_axes = [2] gather_axis = 1 : memref<3x4xf32> -> memref<3x20xf32>
// CHECK: return [[v5]] : memref<3x20xf32>
return %0 : memref<3x20xf32>
@@ -505,7 +505,7 @@ func.func @mlp_1dgrid(%arg0: tensor<512x512xf32>, %arg1: tensor<2048x256xf32>, %
// CHECK-DAG: [[vc0:%.*]] = arith.constant 0 : index
%c0 = arith.constant 0 : index
// CHECK-DAG: [[vc4:%.*]] = arith.constant 4 : index
- // CHECK: [[v0:%.*]] = bufferization.to_buffer [[varg0]] : tensor<512x512xf32> to memref<512x512xf32>
+ // CHECK: [[v0:%.*]] = bufferization.to_buffer [[varg0]] read_only : tensor<512x512xf32> to memref<512x512xf32>
// CHECK: [[v1:%.*]] = mpi.comm_world : !mpi.comm
// CHECK: [[vsize:%.*]] = mpi.comm_size([[v1]]) : i32
// CHECK: [[v2:%.*]] = arith.index_cast [[vsize]] : i32 to index
@@ -541,7 +541,7 @@ func.func @mlp_1dgrid(%arg0: tensor<512x512xf32>, %arg1: tensor<2048x256xf32>, %
}
// CHECK: [[v16:%.*]] = linalg.matmul ins([[v9]], [[varg2]] : tensor<512x256xf32>, tensor<256x2048xf32>) outs([[v15]] : tensor<512x2048xf32>) -> tensor<512x2048xf32>
%8 = linalg.matmul ins(%3, %arg2 : tensor<512x256xf32>, tensor<256x2048xf32>) outs(%7 : tensor<512x2048xf32>) -> tensor<512x2048xf32>
- // CHECK: [[v17:%.*]] = bufferization.to_buffer [[v16]] : tensor<512x2048xf32> to memref<512x2048xf32>
+ // CHECK: [[v17:%.*]] = bufferization.to_buffer [[v16]] read_only : tensor<512x2048xf32> to memref<512x2048xf32>
// CHECK: [[valloc_0:%.*]] = memref.alloc() : memref<512x2048xf32>
// CHECK: linalg.copy ins([[v17]] : memref<512x2048xf32>) outs([[valloc_0]] : memref<512x2048xf32>)
// CHECK: [[v18:%.*]] = mpi.comm_world : !mpi.comm
More information about the Mlir-commits
mailing list