[Mlir-commits] [mlir] [mlir][shard, mpi] Marking explicitly bufferized buffers read-only where applicable (PR #186464)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Fri Mar 13 10:38:54 PDT 2026


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir

Author: Frank Schlimbach (fschlimb)

<details>
<summary>Changes</summary>

Marking buffers that are only read as `read_only` (on the emitted `bufferization.to_buffer` ops) lets `one-shot-bufferize` avoid unnecessary copies.


---
Full diff: https://github.com/llvm/llvm-project/pull/186464.diff


2 Files Affected:

- (modified) mlir/lib/Conversion/ShardToMPI/ShardToMPI.cpp (+9-7) 
- (modified) mlir/test/Conversion/ShardToMPI/convert-shard-to-mpi.mlir (+7-7) 


``````````diff
diff --git a/mlir/lib/Conversion/ShardToMPI/ShardToMPI.cpp b/mlir/lib/Conversion/ShardToMPI/ShardToMPI.cpp
index daf8c9a154bd3..508917db1fc84 100644
--- a/mlir/lib/Conversion/ShardToMPI/ShardToMPI.cpp
+++ b/mlir/lib/Conversion/ShardToMPI/ShardToMPI.cpp
@@ -517,12 +517,14 @@ struct CommOpPattern : public OpConversionPattern<CommOp> {
     return MemRefType::get(tensorType.getShape(), tensorType.getElementType());
   }
 
-  Value getAsMemref(Value input, ImplicitLocOpBuilder &iBuilder) const {
+  Value getAsMemref(Value input, ImplicitLocOpBuilder &iBuilder,
+                    bool readOnly) const {
     auto itype = input.getType();
     // If the source is a memref, cast it to a tensor.
     if (isa<RankedTensorType>(itype)) {
       auto memrefType = getMemrefType(cast<ShapedType>(itype));
-      input = bufferization::ToBufferOp::create(iBuilder, memrefType, input);
+      input = bufferization::ToBufferOp::create(iBuilder, memrefType, input,
+                                                readOnly);
     } else {
       assert(isa<MemRefType>(itype) &&
              "expected input to be of MemRefType or TensorType");
@@ -589,7 +591,7 @@ struct ConvertAllReduceOp : public CommOpPattern<AllReduceOp> {
     if (failed(gridOp))
       return failure();
     ImplicitLocOpBuilder iBuilder(op.getLoc(), rewriter);
-    Value input = getAsMemref(adaptor.getInput(), iBuilder);
+    Value input = getAsMemref(adaptor.getInput(), iBuilder, true);
     MemRefType inType = cast<MemRefType>(input.getType());
     if (!memref::isStaticShapeAndContiguousRowMajor(inType))
       return op.emitError(
@@ -683,7 +685,7 @@ struct ConvertReduceScatterOp : public CommOpPattern<ReduceScatterOp> {
     if (scatterDim == 0) {
       // scatter_dim == 0 maps directly to MPI_Reduce_scatter_block.
       // Input must be contiguous for MPI.
-      Value input = getAsMemref(rawInput, ib);
+      Value input = getAsMemref(rawInput, ib, true);
       MemRefType inType = cast<MemRefType>(input.getType());
       if (!memref::isStaticShapeAndContiguousRowMajor(inType))
         return op.emitError("Input must be a statically shaped memref in "
@@ -743,7 +745,7 @@ struct ConvertReduceScatterOp : public CommOpPattern<ReduceScatterOp> {
       //    freshly allocated buffer.
       auto mpiInType = MemRefType::get(transposedShape, elemType);
       Value transposedBuf =
-          bufferization::ToBufferOp::create(ib, mpiInType, tensorInput);
+          bufferization::ToBufferOp::create(ib, mpiInType, tensorInput, true);
       mpiInput = memref::AllocOp::create(ib, mpiInType);
       linalg::CopyOp::create(ib, transposedBuf, mpiInput);
     }
@@ -787,7 +789,7 @@ struct ConvertAllGatherOp : public CommOpPattern<AllGatherOp> {
       return failure();
 
     ImplicitLocOpBuilder ib(op.getLoc(), rewriter);
-    Value input = getAsMemref(adaptor.getInput(), ib);
+    Value input = getAsMemref(adaptor.getInput(), ib, true);
     MemRefType inType = cast<MemRefType>(input.getType());
     MemRefType outType = getMemrefType(cast<ShapedType>(op.getType()));
     auto inputShape = inType.getShape();
@@ -900,7 +902,7 @@ struct ConvertAllGatherOp : public CommOpPattern<AllGatherOp> {
       // 4. Cast back to memref if needed.
       if (isa<MemRefType>(op.getType()))
         finalOutput =
-            bufferization::ToBufferOp::create(ib, outType, finalOutput);
+            bufferization::ToBufferOp::create(ib, outType, finalOutput, true);
     }
 
     rewriter.replaceOp(op, finalOutput);
diff --git a/mlir/test/Conversion/ShardToMPI/convert-shard-to-mpi.mlir b/mlir/test/Conversion/ShardToMPI/convert-shard-to-mpi.mlir
index 346a55658c170..c80c1798bb77b 100644
--- a/mlir/test/Conversion/ShardToMPI/convert-shard-to-mpi.mlir
+++ b/mlir/test/Conversion/ShardToMPI/convert-shard-to-mpi.mlir
@@ -115,7 +115,7 @@ module attributes { mpi.dlti = #dlti.map<"MPI:comm_world_rank" = 7> } {
     %arg0 : tensor<3x4xf32>) -> tensor<3x4xf32> {
     // CHECK-DAG: [[vc1_i32:%.*]] = arith.constant 1 : i32
     // CHECK-DAG: [[vc2_i32:%.*]] = arith.constant 2 : i32
-    // CHECK: [[v0:%.*]] = bufferization.to_buffer [[varg0]] : tensor<3x4xf32> to memref<3x4xf32>
+    // CHECK: [[v0:%.*]] = bufferization.to_buffer [[varg0]] read_only : tensor<3x4xf32> to memref<3x4xf32>
     // CHECK: [[valloc:%.*]] = memref.alloc() : memref<3x4xf32>
     // CHECK: linalg.copy ins([[v0]] : memref<3x4xf32>) outs([[valloc]] : memref<3x4xf32>)
     // CHECK: [[v1:%.*]] = mpi.comm_world : !mpi.comm
@@ -181,7 +181,7 @@ module attributes { mpi.dlti = #dlti.map<"MPI:comm_world_rank" = 7> } {
     // CHECK: [[vexpanded:%.*]] = tensor.expand_shape [[varg0]] {{\[\[}}0], [1, 2]] output_shape [2, 3, 4] : tensor<2x12xf32> into tensor<2x3x4xf32>
     // CHECK: [[vempty:%.*]] = tensor.empty() : tensor<3x2x4xf32>
     // CHECK: [[vtransposed:%.*]] = linalg.transpose ins([[vexpanded]] : tensor<2x3x4xf32>) outs([[vempty]] : tensor<3x2x4xf32>) permutation = [1, 0, 2]
-    // CHECK: [[vtobuf:%.*]] = bufferization.to_buffer [[vtransposed]] : tensor<3x2x4xf32> to memref<3x2x4xf32>
+    // CHECK: [[vtobuf:%.*]] = bufferization.to_buffer [[vtransposed]] read_only : tensor<3x2x4xf32> to memref<3x2x4xf32>
     // CHECK: [[valloctmp:%.*]] = memref.alloc() : memref<3x2x4xf32>
     // CHECK: linalg.copy ins([[vtobuf]] : memref<3x2x4xf32>) outs([[valloctmp]] : memref<3x2x4xf32>)
     // CHECK: [[valloc:%.*]] = memref.alloc() : memref<2x4xf32>
@@ -198,7 +198,7 @@ module attributes { mpi.dlti = #dlti.map<"MPI:comm_world_rank" = 7> } {
     // CHECK-DAG: [[vc1_i32:%.*]] = arith.constant 1 : i32
     // CHECK-DAG: [[vc2_i32:%.*]] = arith.constant 2 : i32
     // CHECK-DAG: [[vc4:%.*]] = arith.constant 4 : index
-    // CHECK: [[v0:%.*]] = bufferization.to_buffer [[varg0]] : tensor<3x4xf32> to memref<3x4xf32>
+    // CHECK: [[v0:%.*]] = bufferization.to_buffer [[varg0]] read_only : tensor<3x4xf32> to memref<3x4xf32>
     // CHECK: [[v1:%.*]] = mpi.comm_world : !mpi.comm
     // CHECK: [[vnewcomm:%.*]] = mpi.comm_split([[v1]], [[vc2_i32]], [[vc1_i32]]) : !mpi.comm
     // CHECK: [[vsize:%.*]] = mpi.comm_size([[vnewcomm]]) : i32
@@ -222,7 +222,7 @@ module attributes { mpi.dlti = #dlti.map<"MPI:comm_world_rank" = 7> } {
     // CHECK-DAG: [[vc2_i32:%.*]] = arith.constant 2 : i32
     // CHECK-DAG: [[vc1_i32:%.*]] = arith.constant 1 : i32
     // CHECK-DAG: [[vc5:%.*]] = arith.constant 5 : index
-    // CHECK: [[v0:%.*]] = bufferization.to_buffer [[varg0]] : tensor<3x4xf32> to memref<3x4xf32>
+    // CHECK: [[v0:%.*]] = bufferization.to_buffer [[varg0]] read_only : tensor<3x4xf32> to memref<3x4xf32>
     // CHECK: [[v1:%.*]] = mpi.comm_world : !mpi.comm
     // CHECK: [[vnewcomm:%.*]] = mpi.comm_split([[v1]], [[vc1_i32]], [[vc2_i32]]) : !mpi.comm
     // CHECK: [[vsize:%.*]] = mpi.comm_size([[vnewcomm]]) : i32
@@ -260,7 +260,7 @@ module attributes { mpi.dlti = #dlti.map<"MPI:comm_world_rank" = 7> } {
     // CHECK: [[v4:%.*]] = tensor.empty() : tensor<3x5x4xf32>
     // CHECK: [[vtransposed:%.*]] = linalg.transpose ins([[v3]] : tensor<5x3x4xf32>) outs([[v4]] : tensor<3x5x4xf32>) permutation = [1, 0, 2] 
     // CHECK: [[vcollapsed:%.*]] = tensor.collapse_shape [[vtransposed]] {{\[\[}}0], [1, 2]] : tensor<3x5x4xf32> into tensor<3x20xf32>
-    // CHECK: [[v5:%.*]] = bufferization.to_buffer [[vcollapsed]] : tensor<3x20xf32> to memref<3x20xf32>
+    // CHECK: [[v5:%.*]] = bufferization.to_buffer [[vcollapsed]] read_only : tensor<3x20xf32> to memref<3x20xf32>
     %0 = shard.all_gather %arg0 on @grid0 grid_axes = [2] gather_axis = 1 : memref<3x4xf32> -> memref<3x20xf32>
     // CHECK: return [[v5]] : memref<3x20xf32>
     return %0 : memref<3x20xf32>
@@ -505,7 +505,7 @@ func.func @mlp_1dgrid(%arg0: tensor<512x512xf32>, %arg1: tensor<2048x256xf32>, %
   // CHECK-DAG: [[vc0:%.*]] = arith.constant 0 : index
   %c0 = arith.constant 0 : index
   // CHECK-DAG: [[vc4:%.*]] = arith.constant 4 : index
-  // CHECK: [[v0:%.*]] = bufferization.to_buffer [[varg0]] : tensor<512x512xf32> to memref<512x512xf32>
+  // CHECK: [[v0:%.*]] = bufferization.to_buffer [[varg0]] read_only : tensor<512x512xf32> to memref<512x512xf32>
   // CHECK: [[v1:%.*]] = mpi.comm_world : !mpi.comm
   // CHECK: [[vsize:%.*]] = mpi.comm_size([[v1]]) : i32
   // CHECK: [[v2:%.*]] = arith.index_cast [[vsize]] : i32 to index
@@ -541,7 +541,7 @@ func.func @mlp_1dgrid(%arg0: tensor<512x512xf32>, %arg1: tensor<2048x256xf32>, %
   }
   // CHECK: [[v16:%.*]] = linalg.matmul ins([[v9]], [[varg2]] : tensor<512x256xf32>, tensor<256x2048xf32>) outs([[v15]] : tensor<512x2048xf32>) -> tensor<512x2048xf32>
   %8 = linalg.matmul ins(%3, %arg2 : tensor<512x256xf32>, tensor<256x2048xf32>) outs(%7 : tensor<512x2048xf32>) -> tensor<512x2048xf32>
-  // CHECK: [[v17:%.*]] = bufferization.to_buffer [[v16]] : tensor<512x2048xf32> to memref<512x2048xf32>
+  // CHECK: [[v17:%.*]] = bufferization.to_buffer [[v16]] read_only : tensor<512x2048xf32> to memref<512x2048xf32>
   // CHECK: [[valloc_0:%.*]] = memref.alloc() : memref<512x2048xf32>
   // CHECK: linalg.copy ins([[v17]] : memref<512x2048xf32>) outs([[valloc_0]] : memref<512x2048xf32>)
   // CHECK: [[v18:%.*]] = mpi.comm_world : !mpi.comm

``````````

</details>


https://github.com/llvm/llvm-project/pull/186464


More information about the Mlir-commits mailing list