[Mlir-commits] [mlir] [MLIR][MPI] adding MemoryEffects to MPI ops for buffer-deallocation-pipeline (PR #186158)

Frank Schlimbach llvmlistbot at llvm.org
Thu Mar 12 09:25:30 PDT 2026


https://github.com/fschlimb updated https://github.com/llvm/llvm-project/pull/186158

>From 4e473a9a081ff52176f719942975328db9ec25e5 Mon Sep 17 00:00:00 2001
From: "Schlimbach, Frank" <frank.schlimbach at intel.com>
Date: Thu, 12 Mar 2026 08:46:16 -0700
Subject: [PATCH 1/3] adding MemoryEffects to MPI ops to work with
 buffer-deallocation-pipeline

---
 mlir/include/mlir/Dialect/MPI/IR/MPIOps.td    | 48 +++++++++----------
 mlir/lib/Conversion/ShardToMPI/ShardToMPI.cpp |  9 ++--
 .../ShardToMPI/convert-shard-to-mpi.mlir      |  1 -
 3 files changed, 29 insertions(+), 29 deletions(-)

diff --git a/mlir/include/mlir/Dialect/MPI/IR/MPIOps.td b/mlir/include/mlir/Dialect/MPI/IR/MPIOps.td
index fb0192ba748ad..1a9937d9b7152 100644
--- a/mlir/include/mlir/Dialect/MPI/IR/MPIOps.td
+++ b/mlir/include/mlir/Dialect/MPI/IR/MPIOps.td
@@ -20,7 +20,7 @@ class MPI_Op<string mnemonic, list<Trait> traits = []>
 // InitOp
 //===----------------------------------------------------------------------===//
 
-def MPI_InitOp : MPI_Op<"init", []> {
+def MPI_InitOp : MPI_Op<"init", [MemoryEffects<[MemRead, MemWrite]>]> {
   let summary =
       "Initialize the MPI library, equivalent to `MPI_Init(NULL, NULL)`";
   let description = [{
@@ -42,7 +42,7 @@ def MPI_InitOp : MPI_Op<"init", []> {
 // CommWorldOp
 //===----------------------------------------------------------------------===//
 
-def MPI_CommWorldOp : MPI_Op<"comm_world", [Pure]> {
+def MPI_CommWorldOp : MPI_Op<"comm_world", [MemoryEffects<[MemRead]>]> {
   let summary = "Get the World communicator, equivalent to `MPI_COMM_WORLD`";
   let description = [{
     This operation returns the predefined MPI_COMM_WORLD communicator.
@@ -57,7 +57,7 @@ def MPI_CommWorldOp : MPI_Op<"comm_world", [Pure]> {
 // CommRankOp
 //===----------------------------------------------------------------------===//
 
-def MPI_CommRankOp : MPI_Op<"comm_rank", [Pure]> {
+def MPI_CommRankOp : MPI_Op<"comm_rank", [MemoryEffects<[MemRead]>]> {
   let summary = "Get the current rank, equivalent to "
                 "`MPI_Comm_rank(comm, &rank)`";
   let description = [{
@@ -80,7 +80,7 @@ def MPI_CommRankOp : MPI_Op<"comm_rank", [Pure]> {
 // CommSizeOp
 //===----------------------------------------------------------------------===//
 
-def MPI_CommSizeOp : MPI_Op<"comm_size", [Pure]> {
+def MPI_CommSizeOp : MPI_Op<"comm_size", [MemoryEffects<[MemRead]>]> {
   let summary = "Get the size of the group associated to the communicator, "
                 "equivalent to `MPI_Comm_size(comm, &size)`";
   let description = [{
@@ -103,7 +103,7 @@ def MPI_CommSizeOp : MPI_Op<"comm_size", [Pure]> {
 // CommSplitOp
 //===----------------------------------------------------------------------===//
 
-def MPI_CommSplitOp : MPI_Op<"comm_split"> {
+def MPI_CommSplitOp : MPI_Op<"comm_split", [MemoryEffects<[MemRead, MemWrite]>]> {
   let summary = "Partition the group associated with the given communicator into "
                 "disjoint subgroups";
   let description = [{
@@ -131,7 +131,7 @@ def MPI_CommSplitOp : MPI_Op<"comm_split"> {
 // SendOp
 //===----------------------------------------------------------------------===//
 
-def MPI_SendOp : MPI_Op<"send", []> {
+def MPI_SendOp : MPI_Op<"send", [MemoryEffects<[MemRead, MemWrite]>]> {
   let summary =
       "Equivalent to `MPI_Send(ptr, size, dtype, dest, tag, comm)`";
   let description = [{
@@ -144,7 +144,7 @@ def MPI_SendOp : MPI_Op<"send", []> {
   }];
 
   let arguments = (
-    ins AnyMemRef : $ref,
+    ins Arg<AnyMemRef, "send buffer", [MemRead]> : $ref,
     I32 : $tag,
     I32 : $dest,
     MPI_Comm : $comm
@@ -162,7 +162,7 @@ def MPI_SendOp : MPI_Op<"send", []> {
 // ISendOp
 //===----------------------------------------------------------------------===//
 
-def MPI_ISendOp : MPI_Op<"isend", []> {
+def MPI_ISendOp : MPI_Op<"isend", [MemoryEffects<[MemRead, MemWrite]>]> {
   let summary =
       "Equivalent to `MPI_Isend(ptr, size, dtype, dest, tag, comm)`";
   let description = [{
@@ -176,7 +176,7 @@ def MPI_ISendOp : MPI_Op<"isend", []> {
   }];
 
   let arguments = (
-    ins AnyMemRef : $ref,
+    ins Arg<AnyMemRef, "send buffer", [MemRead]> : $ref,
     I32 : $tag,
     I32 : $dest,
     MPI_Comm : $comm
@@ -197,7 +197,7 @@ def MPI_ISendOp : MPI_Op<"isend", []> {
 // RecvOp
 //===----------------------------------------------------------------------===//
 
-def MPI_RecvOp : MPI_Op<"recv", []> {
+def MPI_RecvOp : MPI_Op<"recv", [MemoryEffects<[MemRead, MemWrite]>]> {
   let summary = "Equivalent to `MPI_Recv(ptr, size, dtype, source, tag, "
                 "comm, MPI_STATUS_IGNORE)`";
   let description = [{
@@ -214,7 +214,7 @@ def MPI_RecvOp : MPI_Op<"recv", []> {
   }];
 
   let arguments = (
-    ins AnyMemRef : $ref,
+    ins Arg<AnyMemRef, "receive buffer", [MemWrite]> : $ref,
     I32 : $tag, I32 : $source,
     MPI_Comm : $comm
   );
@@ -231,7 +231,7 @@ def MPI_RecvOp : MPI_Op<"recv", []> {
 // IRecvOp
 //===----------------------------------------------------------------------===//
 
-def MPI_IRecvOp : MPI_Op<"irecv", []> {
+def MPI_IRecvOp : MPI_Op<"irecv", [MemoryEffects<[MemRead, MemWrite]>]> {
   let summary = "Equivalent to `MPI_Irecv(ptr, size, dtype, source, tag, "
                 "comm, &req)`";
   let description = [{
@@ -245,7 +245,7 @@ def MPI_IRecvOp : MPI_Op<"irecv", []> {
   }];
 
   let arguments = (
-    ins AnyMemRef : $ref,
+    ins Arg<AnyMemRef, "receive buffer", [MemWrite]> : $ref,
     I32 : $tag,
     I32 : $source,
     MPI_Comm : $comm
@@ -288,8 +288,8 @@ def MPI_AllGatherOp : MPI_Op<"allgather", []> {
   }];
 
   let arguments = (
-    ins AnyMemRef : $sendbuf,
-        AnyMemRef : $recvbuf,
+    ins Arg<AnyMemRef, "send buffer", [MemRead]> : $sendbuf,
+        Arg<AnyMemRef, "receive buffer", [MemWrite]> : $recvbuf,
         MPI_Comm : $comm
   );
 
@@ -320,8 +320,8 @@ def MPI_AllReduceOp : MPI_Op<"allreduce", []> {
   }];
 
   let arguments = (
-    ins AnyMemRef : $sendbuf,
-    AnyMemRef : $recvbuf,
+    ins Arg<AnyMemRef, "send buffer", [MemRead]> : $sendbuf,
+    Arg<AnyMemRef, "receive buffer", [MemWrite]> : $recvbuf,
     MPI_ReductionOpEnum : $op,
     MPI_Comm : $comm
   );
@@ -354,8 +354,8 @@ def MPI_ReduceScatterBlockOp : MPI_Op<"reduce_scatter_block", []> {
   }];
 
   let arguments = (
-    ins AnyNon0RankedMemRef : $sendbuf,
-    AnyNon0RankedMemRef : $recvbuf,
+    ins Arg<AnyNon0RankedMemRef, "send buffer", [MemRead]> : $sendbuf,
+    Arg<AnyNon0RankedMemRef, "receive buffer", [MemWrite]> : $recvbuf,
     MPI_ReductionOpEnum : $op,
     MPI_Comm : $comm
   );
@@ -372,7 +372,7 @@ def MPI_ReduceScatterBlockOp : MPI_Op<"reduce_scatter_block", []> {
 // BarrierOp
 //===----------------------------------------------------------------------===//
 
-def MPI_Barrier : MPI_Op<"barrier", []> {
+def MPI_Barrier : MPI_Op<"barrier", [MemoryEffects<[MemRead, MemWrite]>]> {
   let summary = "Equivalent to `MPI_Barrier(comm)`";
   let description = [{
     MPI_Barrier blocks execution until all processes in the communicator have
@@ -396,7 +396,7 @@ def MPI_Barrier : MPI_Op<"barrier", []> {
 // WaitOp
 //===----------------------------------------------------------------------===//
 
-def MPI_Wait : MPI_Op<"wait", []> {
+def MPI_Wait : MPI_Op<"wait", [MemoryEffects<[MemRead, MemWrite]>]> {
   let summary = "Equivalent to `MPI_Wait(req, MPI_STATUS_IGNORE)`";
   let description = [{
     MPI_Wait blocks execution until the request has completed.
@@ -419,7 +419,7 @@ def MPI_Wait : MPI_Op<"wait", []> {
 // FinalizeOp
 //===----------------------------------------------------------------------===//
 
-def MPI_FinalizeOp : MPI_Op<"finalize", []> {
+def MPI_FinalizeOp : MPI_Op<"finalize", [MemoryEffects<[MemRead, MemWrite]>]> {
   let summary = "Finalize the MPI library, equivalent to `MPI_Finalize()`";
   let description = [{
     This function cleans up the MPI state. Afterwards, no MPI methods may 
@@ -439,7 +439,7 @@ def MPI_FinalizeOp : MPI_Op<"finalize", []> {
 // RetvalCheckOp
 //===----------------------------------------------------------------------===//
 
-def MPI_RetvalCheckOp : MPI_Op<"retval_check", []> {
+def MPI_RetvalCheckOp : MPI_Op<"retval_check", [MemoryEffects<[MemRead]>]> {
   let summary = "Check an MPI return value against an error class";
   let description = [{
     This operation compares MPI status codes to known error class
@@ -462,7 +462,7 @@ def MPI_RetvalCheckOp : MPI_Op<"retval_check", []> {
 // ErrorClassOp
 //===----------------------------------------------------------------------===//
 
-def MPI_ErrorClassOp : MPI_Op<"error_class", []> {
+def MPI_ErrorClassOp : MPI_Op<"error_class", [MemoryEffects<[MemRead]>]> {
   let summary = "Get the error class from an error code, equivalent to "
                 "the `MPI_Error_class` function";
   let description = [{
diff --git a/mlir/lib/Conversion/ShardToMPI/ShardToMPI.cpp b/mlir/lib/Conversion/ShardToMPI/ShardToMPI.cpp
index 830a9333ade4a..daf8c9a154bd3 100644
--- a/mlir/lib/Conversion/ShardToMPI/ShardToMPI.cpp
+++ b/mlir/lib/Conversion/ShardToMPI/ShardToMPI.cpp
@@ -755,14 +755,15 @@ struct ConvertReduceScatterOp : public CommOpPattern<ReduceScatterOp> {
         ib, TypeRange(), mpiInput, output,
         getMPIReductionOp(adaptor.getReductionAttr()), comm);
 
-    // Deallocate the temporary input buffer if we allocated one.
-    if (scatterDim != 0)
-      memref::DeallocOp::create(ib, mpiInput);
-
     // If the destination is a tensor, cast it to a tensor.
     if (isa<RankedTensorType>(op.getType()))
       output =
           bufferization::ToTensorOp::create(ib, op.getType(), output, true);
+    else if (scatterDim != 0) // Deallocate the temporary input buffer
+      memref::DeallocOp::create(ib, mpiInput);
+    // Notice: If this is called from tensor-world, then we assume an extra pass
+    // will take care of deallocating the intermediate buffers.
+
     rewriter.replaceOp(op, output);
     return success();
   }
diff --git a/mlir/test/Conversion/ShardToMPI/convert-shard-to-mpi.mlir b/mlir/test/Conversion/ShardToMPI/convert-shard-to-mpi.mlir
index 08c3897e4e650..346a55658c170 100644
--- a/mlir/test/Conversion/ShardToMPI/convert-shard-to-mpi.mlir
+++ b/mlir/test/Conversion/ShardToMPI/convert-shard-to-mpi.mlir
@@ -186,7 +186,6 @@ module attributes { mpi.dlti = #dlti.map<"MPI:comm_world_rank" = 7> } {
     // CHECK: linalg.copy ins([[vtobuf]] : memref<3x2x4xf32>) outs([[valloctmp]] : memref<3x2x4xf32>)
     // CHECK: [[valloc:%.*]] = memref.alloc() : memref<2x4xf32>
     // CHECK: mpi.reduce_scatter_block([[valloctmp]], [[valloc]], MPI_SUM,
-    // CHECK: memref.dealloc [[valloctmp]] : memref<3x2x4xf32>
     // CHECK: [[vout:%.*]] = bufferization.to_tensor [[valloc]] restrict : memref<2x4xf32> to tensor<2x4xf32>
     %0 = shard.reduce_scatter %arg0 on @grid0 grid_axes = [0] scatter_dim = 1 : tensor<2x12xf32> -> tensor<2x4xf32>
     // CHECK: return [[vout]] : tensor<2x4xf32>

>From f21ffa7a3e1773ad5243aacb373a64903d35ad33 Mon Sep 17 00:00:00 2001
From: "Schlimbach, Frank" <frank.schlimbach at intel.com>
Date: Thu, 12 Mar 2026 09:11:43 -0700
Subject: [PATCH 2/3] mentioning deallocation dependency in pass doc.

---
 mlir/include/mlir/Conversion/Passes.td | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td
index e77860897399f..f82ffabb60da6 100644
--- a/mlir/include/mlir/Conversion/Passes.td
+++ b/mlir/include/mlir/Conversion/Passes.td
@@ -1011,6 +1011,8 @@ def ConvertShardToMPIPass : Pass<"convert-shard-to-mpi"> {
     use that integer value instead of calling MPI_Comm_rank. This allows
     optimizations like constant shape propagation and fusion because
     shard/partition sizes depend on the rank.
+    Intermediate memrefs are deallocated only when the lowered op returns memrefs;
+    with tensor results, run -buffer-deallocation-pipeline afterwards to free them.
   }];
   let dependentDialects = [
     "affine::AffineDialect",

>From dbf3b1a6b549e513bcdec8fb9603c2cf81ff76d5 Mon Sep 17 00:00:00 2001
From: "Schlimbach, Frank" <frank.schlimbach at intel.com>
Date: Thu, 12 Mar 2026 09:24:28 -0700
Subject: [PATCH 3/3] relaxing memory constraints

---
 mlir/include/mlir/Dialect/MPI/IR/MPIOps.td | 12 ++++++------
 1 file changed, 6 insertions(+), 6 deletions(-)

diff --git a/mlir/include/mlir/Dialect/MPI/IR/MPIOps.td b/mlir/include/mlir/Dialect/MPI/IR/MPIOps.td
index 1a9937d9b7152..80d2ca66faca9 100644
--- a/mlir/include/mlir/Dialect/MPI/IR/MPIOps.td
+++ b/mlir/include/mlir/Dialect/MPI/IR/MPIOps.td
@@ -42,7 +42,7 @@ def MPI_InitOp : MPI_Op<"init", [MemoryEffects<[MemRead, MemWrite]>]> {
 // CommWorldOp
 //===----------------------------------------------------------------------===//
 
-def MPI_CommWorldOp : MPI_Op<"comm_world", [MemoryEffects<[MemRead]>]> {
+def MPI_CommWorldOp : MPI_Op<"comm_world", [MemoryEffects<[]>]> {
   let summary = "Get the World communicator, equivalent to `MPI_COMM_WORLD`";
   let description = [{
     This operation returns the predefined MPI_COMM_WORLD communicator.
@@ -57,7 +57,7 @@ def MPI_CommWorldOp : MPI_Op<"comm_world", [MemoryEffects<[MemRead]>]> {
 // CommRankOp
 //===----------------------------------------------------------------------===//
 
-def MPI_CommRankOp : MPI_Op<"comm_rank", [MemoryEffects<[MemRead]>]> {
+def MPI_CommRankOp : MPI_Op<"comm_rank", [MemoryEffects<[]>]> {
   let summary = "Get the current rank, equivalent to "
                 "`MPI_Comm_rank(comm, &rank)`";
   let description = [{
@@ -80,7 +80,7 @@ def MPI_CommRankOp : MPI_Op<"comm_rank", [MemoryEffects<[MemRead]>]> {
 // CommSizeOp
 //===----------------------------------------------------------------------===//
 
-def MPI_CommSizeOp : MPI_Op<"comm_size", [MemoryEffects<[MemRead]>]> {
+def MPI_CommSizeOp : MPI_Op<"comm_size", [MemoryEffects<[]>]> {
   let summary = "Get the size of the group associated to the communicator, "
                 "equivalent to `MPI_Comm_size(comm, &size)`";
   let description = [{
@@ -396,7 +396,7 @@ def MPI_Barrier : MPI_Op<"barrier", [MemoryEffects<[MemRead, MemWrite]>]> {
 // WaitOp
 //===----------------------------------------------------------------------===//
 
-def MPI_Wait : MPI_Op<"wait", [MemoryEffects<[MemRead, MemWrite]>]> {
+def MPI_Wait : MPI_Op<"wait", [MemoryEffects<[MemRead, MemWrite]>]> {
   let summary = "Equivalent to `MPI_Wait(req, MPI_STATUS_IGNORE)`";
   let description = [{
     MPI_Wait blocks execution until the request has completed.
@@ -439,7 +439,7 @@ def MPI_FinalizeOp : MPI_Op<"finalize", [MemoryEffects<[MemRead, MemWrite]>]> {
 // RetvalCheckOp
 //===----------------------------------------------------------------------===//
 
-def MPI_RetvalCheckOp : MPI_Op<"retval_check", [MemoryEffects<[MemRead]>]> {
+def MPI_RetvalCheckOp : MPI_Op<"retval_check", [MemoryEffects<[]>]> {
   let summary = "Check an MPI return value against an error class";
   let description = [{
     This operation compares MPI status codes to known error class
@@ -462,7 +462,7 @@ def MPI_RetvalCheckOp : MPI_Op<"retval_check", [MemoryEffects<[MemRead]>]> {
 // ErrorClassOp
 //===----------------------------------------------------------------------===//
 
-def MPI_ErrorClassOp : MPI_Op<"error_class", [MemoryEffects<[MemRead]>]> {
+def MPI_ErrorClassOp : MPI_Op<"error_class", [MemoryEffects<[]>]> {
   let summary = "Get the error class from an error code, equivalent to "
                 "the `MPI_Error_class` function";
   let description = [{



More information about the Mlir-commits mailing list