[Mlir-commits] [mlir] [MLIR][MPI] adding MemoryEffects to MPI ops for buffer-deallocation-pipeline (PR #186158)

Frank Schlimbach llvmlistbot at llvm.org
Fri Mar 13 02:55:41 PDT 2026


https://github.com/fschlimb updated https://github.com/llvm/llvm-project/pull/186158

>From c0f240d707c2d36eb4e493086ab8e500f0baa048 Mon Sep 17 00:00:00 2001
From: "Schlimbach, Frank" <frank.schlimbach at intel.com>
Date: Thu, 12 Mar 2026 08:46:16 -0700
Subject: [PATCH 1/4] adding MemoryEffects to MPI ops to work with
 buffer-deallocation-pipeline

---
 mlir/include/mlir/Dialect/MPI/IR/MPIOps.td    | 48 +++++++++----------
 mlir/lib/Conversion/ShardToMPI/ShardToMPI.cpp |  9 ++--
 .../ShardToMPI/convert-shard-to-mpi.mlir      |  1 -
 3 files changed, 29 insertions(+), 29 deletions(-)

diff --git a/mlir/include/mlir/Dialect/MPI/IR/MPIOps.td b/mlir/include/mlir/Dialect/MPI/IR/MPIOps.td
index fb0192ba748ad..1a9937d9b7152 100644
--- a/mlir/include/mlir/Dialect/MPI/IR/MPIOps.td
+++ b/mlir/include/mlir/Dialect/MPI/IR/MPIOps.td
@@ -20,7 +20,7 @@ class MPI_Op<string mnemonic, list<Trait> traits = []>
 // InitOp
 //===----------------------------------------------------------------------===//
 
-def MPI_InitOp : MPI_Op<"init", []> {
+def MPI_InitOp : MPI_Op<"init", [MemoryEffects<[MemRead, MemWrite]>]> {
   let summary =
       "Initialize the MPI library, equivalent to `MPI_Init(NULL, NULL)`";
   let description = [{
@@ -42,7 +42,7 @@ def MPI_InitOp : MPI_Op<"init", []> {
 // CommWorldOp
 //===----------------------------------------------------------------------===//
 
-def MPI_CommWorldOp : MPI_Op<"comm_world", [Pure]> {
+def MPI_CommWorldOp : MPI_Op<"comm_world", [MemoryEffects<[MemRead]>]> {
   let summary = "Get the World communicator, equivalent to `MPI_COMM_WORLD`";
   let description = [{
     This operation returns the predefined MPI_COMM_WORLD communicator.
@@ -57,7 +57,7 @@ def MPI_CommWorldOp : MPI_Op<"comm_world", [Pure]> {
 // CommRankOp
 //===----------------------------------------------------------------------===//
 
-def MPI_CommRankOp : MPI_Op<"comm_rank", [Pure]> {
+def MPI_CommRankOp : MPI_Op<"comm_rank", [MemoryEffects<[MemRead]>]> {
   let summary = "Get the current rank, equivalent to "
                 "`MPI_Comm_rank(comm, &rank)`";
   let description = [{
@@ -80,7 +80,7 @@ def MPI_CommRankOp : MPI_Op<"comm_rank", [Pure]> {
 // CommSizeOp
 //===----------------------------------------------------------------------===//
 
-def MPI_CommSizeOp : MPI_Op<"comm_size", [Pure]> {
+def MPI_CommSizeOp : MPI_Op<"comm_size", [MemoryEffects<[MemRead]>]> {
   let summary = "Get the size of the group associated to the communicator, "
                 "equivalent to `MPI_Comm_size(comm, &size)`";
   let description = [{
@@ -103,7 +103,7 @@ def MPI_CommSizeOp : MPI_Op<"comm_size", [Pure]> {
 // CommSplitOp
 //===----------------------------------------------------------------------===//
 
-def MPI_CommSplitOp : MPI_Op<"comm_split"> {
+def MPI_CommSplitOp : MPI_Op<"comm_split", [MemoryEffects<[MemRead, MemWrite]>]> {
   let summary = "Partition the group associated with the given communicator into "
                 "disjoint subgroups";
   let description = [{
@@ -131,7 +131,7 @@ def MPI_CommSplitOp : MPI_Op<"comm_split"> {
 // SendOp
 //===----------------------------------------------------------------------===//
 
-def MPI_SendOp : MPI_Op<"send", []> {
+def MPI_SendOp : MPI_Op<"send", [MemoryEffects<[MemRead, MemWrite]>]> {
   let summary =
       "Equivalent to `MPI_Send(ptr, size, dtype, dest, tag, comm)`";
   let description = [{
@@ -144,7 +144,7 @@ def MPI_SendOp : MPI_Op<"send", []> {
   }];
 
   let arguments = (
-    ins AnyMemRef : $ref,
+    ins Arg<AnyMemRef, "send buffer", [MemRead]> : $ref,
     I32 : $tag,
     I32 : $dest,
     MPI_Comm : $comm
@@ -162,7 +162,7 @@ def MPI_SendOp : MPI_Op<"send", []> {
 // ISendOp
 //===----------------------------------------------------------------------===//
 
-def MPI_ISendOp : MPI_Op<"isend", []> {
+def MPI_ISendOp : MPI_Op<"isend", [MemoryEffects<[MemRead, MemWrite]>]> {
   let summary =
       "Equivalent to `MPI_Isend(ptr, size, dtype, dest, tag, comm)`";
   let description = [{
@@ -176,7 +176,7 @@ def MPI_ISendOp : MPI_Op<"isend", []> {
   }];
 
   let arguments = (
-    ins AnyMemRef : $ref,
+    ins Arg<AnyMemRef, "send buffer", [MemRead]> : $ref,
     I32 : $tag,
     I32 : $dest,
     MPI_Comm : $comm
@@ -197,7 +197,7 @@ def MPI_ISendOp : MPI_Op<"isend", []> {
 // RecvOp
 //===----------------------------------------------------------------------===//
 
-def MPI_RecvOp : MPI_Op<"recv", []> {
+def MPI_RecvOp : MPI_Op<"recv", [MemoryEffects<[MemRead, MemWrite]>]> {
   let summary = "Equivalent to `MPI_Recv(ptr, size, dtype, source, tag, "
                 "comm, MPI_STATUS_IGNORE)`";
   let description = [{
@@ -214,7 +214,7 @@ def MPI_RecvOp : MPI_Op<"recv", []> {
   }];
 
   let arguments = (
-    ins AnyMemRef : $ref,
+    ins Arg<AnyMemRef, "receive buffer", [MemWrite]> : $ref,
     I32 : $tag, I32 : $source,
     MPI_Comm : $comm
   );
@@ -231,7 +231,7 @@ def MPI_RecvOp : MPI_Op<"recv", []> {
 // IRecvOp
 //===----------------------------------------------------------------------===//
 
-def MPI_IRecvOp : MPI_Op<"irecv", []> {
+def MPI_IRecvOp : MPI_Op<"irecv", [MemoryEffects<[MemRead, MemWrite]>]> {
   let summary = "Equivalent to `MPI_Irecv(ptr, size, dtype, source, tag, "
                 "comm, &req)`";
   let description = [{
@@ -245,7 +245,7 @@ def MPI_IRecvOp : MPI_Op<"irecv", []> {
   }];
 
   let arguments = (
-    ins AnyMemRef : $ref,
+    ins Arg<AnyMemRef, "receive buffer", [MemWrite]> : $ref,
     I32 : $tag,
     I32 : $source,
     MPI_Comm : $comm
@@ -288,8 +288,8 @@ def MPI_AllGatherOp : MPI_Op<"allgather", []> {
   }];
 
   let arguments = (
-    ins AnyMemRef : $sendbuf,
-        AnyMemRef : $recvbuf,
+    ins Arg<AnyMemRef, "send buffer", [MemRead]> : $sendbuf,
+        Arg<AnyMemRef, "receive buffer", [MemWrite]> : $recvbuf,
         MPI_Comm : $comm
   );
 
@@ -320,8 +320,8 @@ def MPI_AllReduceOp : MPI_Op<"allreduce", []> {
   }];
 
   let arguments = (
-    ins AnyMemRef : $sendbuf,
-    AnyMemRef : $recvbuf,
+    ins Arg<AnyMemRef, "send buffer", [MemRead]> : $sendbuf,
+    Arg<AnyMemRef, "receive buffer", [MemWrite]> : $recvbuf,
     MPI_ReductionOpEnum : $op,
     MPI_Comm : $comm
   );
@@ -354,8 +354,8 @@ def MPI_ReduceScatterBlockOp : MPI_Op<"reduce_scatter_block", []> {
   }];
 
   let arguments = (
-    ins AnyNon0RankedMemRef : $sendbuf,
-    AnyNon0RankedMemRef : $recvbuf,
+    ins Arg<AnyNon0RankedMemRef, "send buffer", [MemRead]> : $sendbuf,
+    Arg<AnyNon0RankedMemRef, "receive buffer", [MemWrite]> : $recvbuf,
     MPI_ReductionOpEnum : $op,
     MPI_Comm : $comm
   );
@@ -372,7 +372,7 @@ def MPI_ReduceScatterBlockOp : MPI_Op<"reduce_scatter_block", []> {
 // BarrierOp
 //===----------------------------------------------------------------------===//
 
-def MPI_Barrier : MPI_Op<"barrier", []> {
+def MPI_Barrier : MPI_Op<"barrier", [MemoryEffects<[MemRead, MemWrite]>]> {
   let summary = "Equivalent to `MPI_Barrier(comm)`";
   let description = [{
     MPI_Barrier blocks execution until all processes in the communicator have
@@ -396,7 +396,7 @@ def MPI_Barrier : MPI_Op<"barrier", []> {
 // WaitOp
 //===----------------------------------------------------------------------===//
 
-def MPI_Wait : MPI_Op<"wait", []> {
+def MPI_Wait : MPI_Op<"wait", [MemoryEffects<[MemRead, MemWrite]>]> {
   let summary = "Equivalent to `MPI_Wait(req, MPI_STATUS_IGNORE)`";
   let description = [{
     MPI_Wait blocks execution until the request has completed.
@@ -419,7 +419,7 @@ def MPI_Wait : MPI_Op<"wait", []> {
 // FinalizeOp
 //===----------------------------------------------------------------------===//
 
-def MPI_FinalizeOp : MPI_Op<"finalize", []> {
+def MPI_FinalizeOp : MPI_Op<"finalize", [MemoryEffects<[MemRead, MemWrite]>]> {
   let summary = "Finalize the MPI library, equivalent to `MPI_Finalize()`";
   let description = [{
     This function cleans up the MPI state. Afterwards, no MPI methods may 
@@ -439,7 +439,7 @@ def MPI_FinalizeOp : MPI_Op<"finalize", []> {
 // RetvalCheckOp
 //===----------------------------------------------------------------------===//
 
-def MPI_RetvalCheckOp : MPI_Op<"retval_check", []> {
+def MPI_RetvalCheckOp : MPI_Op<"retval_check", [MemoryEffects<[MemRead]>]> {
   let summary = "Check an MPI return value against an error class";
   let description = [{
     This operation compares MPI status codes to known error class
@@ -462,7 +462,7 @@ def MPI_RetvalCheckOp : MPI_Op<"retval_check", []> {
 // ErrorClassOp
 //===----------------------------------------------------------------------===//
 
-def MPI_ErrorClassOp : MPI_Op<"error_class", []> {
+def MPI_ErrorClassOp : MPI_Op<"error_class", [MemoryEffects<[MemRead]>]> {
   let summary = "Get the error class from an error code, equivalent to "
                 "the `MPI_Error_class` function";
   let description = [{
diff --git a/mlir/lib/Conversion/ShardToMPI/ShardToMPI.cpp b/mlir/lib/Conversion/ShardToMPI/ShardToMPI.cpp
index 830a9333ade4a..daf8c9a154bd3 100644
--- a/mlir/lib/Conversion/ShardToMPI/ShardToMPI.cpp
+++ b/mlir/lib/Conversion/ShardToMPI/ShardToMPI.cpp
@@ -755,14 +755,15 @@ struct ConvertReduceScatterOp : public CommOpPattern<ReduceScatterOp> {
         ib, TypeRange(), mpiInput, output,
         getMPIReductionOp(adaptor.getReductionAttr()), comm);
 
-    // Deallocate the temporary input buffer if we allocated one.
-    if (scatterDim != 0)
-      memref::DeallocOp::create(ib, mpiInput);
-
     // If the destination is a tensor, cast it to a tensor.
     if (isa<RankedTensorType>(op.getType()))
       output =
           bufferization::ToTensorOp::create(ib, op.getType(), output, true);
+    else if (scatterDim != 0) // Deallocate the temporary input buffer
+      memref::DeallocOp::create(ib, mpiInput);
+    // Notice: If this is called from tensor-world, then we assume an extra pass
+    // will take care of deallocating the intermediate buffers.
+
     rewriter.replaceOp(op, output);
     return success();
   }
diff --git a/mlir/test/Conversion/ShardToMPI/convert-shard-to-mpi.mlir b/mlir/test/Conversion/ShardToMPI/convert-shard-to-mpi.mlir
index 08c3897e4e650..346a55658c170 100644
--- a/mlir/test/Conversion/ShardToMPI/convert-shard-to-mpi.mlir
+++ b/mlir/test/Conversion/ShardToMPI/convert-shard-to-mpi.mlir
@@ -186,7 +186,6 @@ module attributes { mpi.dlti = #dlti.map<"MPI:comm_world_rank" = 7> } {
     // CHECK: linalg.copy ins([[vtobuf]] : memref<3x2x4xf32>) outs([[valloctmp]] : memref<3x2x4xf32>)
     // CHECK: [[valloc:%.*]] = memref.alloc() : memref<2x4xf32>
     // CHECK: mpi.reduce_scatter_block([[valloctmp]], [[valloc]], MPI_SUM,
-    // CHECK: memref.dealloc [[valloctmp]] : memref<3x2x4xf32>
     // CHECK: [[vout:%.*]] = bufferization.to_tensor [[valloc]] restrict : memref<2x4xf32> to tensor<2x4xf32>
     %0 = shard.reduce_scatter %arg0 on @grid0 grid_axes = [0] scatter_dim = 1 : tensor<2x12xf32> -> tensor<2x4xf32>
     // CHECK: return [[vout]] : tensor<2x4xf32>

>From 8a300a147956abb04944f1b46fb71df87d98657f Mon Sep 17 00:00:00 2001
From: "Schlimbach, Frank" <frank.schlimbach at intel.com>
Date: Thu, 12 Mar 2026 09:11:43 -0700
Subject: [PATCH 2/4] mentioning deallocation dependency in pass doc.

---
 mlir/include/mlir/Conversion/Passes.td | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td
index e77860897399f..f82ffabb60da6 100644
--- a/mlir/include/mlir/Conversion/Passes.td
+++ b/mlir/include/mlir/Conversion/Passes.td
@@ -1011,6 +1011,8 @@ def ConvertShardToMPIPass : Pass<"convert-shard-to-mpi"> {
     use that integer value instead of calling MPI_Comm_rank. This allows
     optimizations like constant shape propagation and fusion because
     shard/partition sizes depend on the rank.
+    To support the buffer deallocation pipeline, intermediate memref allocations
+    are only deallocated if the lowered operations return buffers (not tensors).
   }];
   let dependentDialects = [
     "affine::AffineDialect",

>From 542861490a5a06a7d9a3b392db13142828e9dc1b Mon Sep 17 00:00:00 2001
From: "Schlimbach, Frank" <frank.schlimbach at intel.com>
Date: Thu, 12 Mar 2026 09:24:28 -0700
Subject: [PATCH 3/4] relaxing memory constraints

---
 mlir/include/mlir/Dialect/MPI/IR/MPIOps.td | 10 +++++-----
 1 file changed, 5 insertions(+), 5 deletions(-)

diff --git a/mlir/include/mlir/Dialect/MPI/IR/MPIOps.td b/mlir/include/mlir/Dialect/MPI/IR/MPIOps.td
index 1a9937d9b7152..80d2ca66faca9 100644
--- a/mlir/include/mlir/Dialect/MPI/IR/MPIOps.td
+++ b/mlir/include/mlir/Dialect/MPI/IR/MPIOps.td
@@ -42,7 +42,7 @@ def MPI_InitOp : MPI_Op<"init", [MemoryEffects<[MemRead, MemWrite]>]> {
 // CommWorldOp
 //===----------------------------------------------------------------------===//
 
-def MPI_CommWorldOp : MPI_Op<"comm_world", [MemoryEffects<[MemRead]>]> {
+def MPI_CommWorldOp : MPI_Op<"comm_world", [MemoryEffects<[]>]> {
   let summary = "Get the World communicator, equivalent to `MPI_COMM_WORLD`";
   let description = [{
     This operation returns the predefined MPI_COMM_WORLD communicator.
@@ -57,7 +57,7 @@ def MPI_CommWorldOp : MPI_Op<"comm_world", [MemoryEffects<[MemRead]>]> {
 // CommRankOp
 //===----------------------------------------------------------------------===//
 
-def MPI_CommRankOp : MPI_Op<"comm_rank", [MemoryEffects<[MemRead]>]> {
+def MPI_CommRankOp : MPI_Op<"comm_rank", [MemoryEffects<[]>]> {
   let summary = "Get the current rank, equivalent to "
                 "`MPI_Comm_rank(comm, &rank)`";
   let description = [{
@@ -80,7 +80,7 @@ def MPI_CommRankOp : MPI_Op<"comm_rank", [MemoryEffects<[MemRead]>]> {
 // CommSizeOp
 //===----------------------------------------------------------------------===//
 
-def MPI_CommSizeOp : MPI_Op<"comm_size", [MemoryEffects<[MemRead]>]> {
+def MPI_CommSizeOp : MPI_Op<"comm_size", [MemoryEffects<[]>]> {
   let summary = "Get the size of the group associated to the communicator, "
                 "equivalent to `MPI_Comm_size(comm, &size)`";
   let description = [{
@@ -439,7 +439,7 @@ def MPI_FinalizeOp : MPI_Op<"finalize", [MemoryEffects<[MemRead, MemWrite]>]> {
 // RetvalCheckOp
 //===----------------------------------------------------------------------===//
 
-def MPI_RetvalCheckOp : MPI_Op<"retval_check", [MemoryEffects<[MemRead]>]> {
+def MPI_RetvalCheckOp : MPI_Op<"retval_check", [MemoryEffects<[]>]> {
   let summary = "Check an MPI return value against an error class";
   let description = [{
     This operation compares MPI status codes to known error class
@@ -462,7 +462,7 @@ def MPI_RetvalCheckOp : MPI_Op<"retval_check", [MemoryEffects<[MemRead]>]> {
 // ErrorClassOp
 //===----------------------------------------------------------------------===//
 
-def MPI_ErrorClassOp : MPI_Op<"error_class", [MemoryEffects<[MemRead]>]> {
+def MPI_ErrorClassOp : MPI_Op<"error_class", [MemoryEffects<[]>]> {
   let summary = "Get the error class from an error code, equivalent to "
                 "the `MPI_Error_class` function";
   let description = [{

>From 91f69b6f684cc60208b129f6afe0e53c96d35c78 Mon Sep 17 00:00:00 2001
From: "Schlimbach, Frank" <frank.schlimbach at intel.com>
Date: Fri, 13 Mar 2026 02:51:47 -0700
Subject: [PATCH 4/4] clarified doc

---
 mlir/include/mlir/Conversion/Passes.td | 19 +++++++++++--------
 1 file changed, 11 insertions(+), 8 deletions(-)

diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td
index f82ffabb60da6..1cb938203c93d 100644
--- a/mlir/include/mlir/Conversion/Passes.td
+++ b/mlir/include/mlir/Conversion/Passes.td
@@ -1005,14 +1005,17 @@ def ConvertMemRefToSPIRVPass : Pass<"convert-memref-to-spirv"> {
 def ConvertShardToMPIPass : Pass<"convert-shard-to-mpi"> {
   let summary = "Convert Shard dialect to MPI dialect.";
   let description = [{
-    This pass converts communication operations from the Shard dialect to the
-    MPI dialect.
-    If it finds the DLTI attribute "MPI:comm_world-rank" on the module it will
-    use that integer value instead of calling MPI_Comm_rank. This allows
-    optimizations like constant shape propagation and fusion because
-    shard/partition sizes depend on the rank.
-    To support the buffer deallocation pipeline, intermediate memref allocations
-    are only deallocated if the lowered operations return buffers (not tensors).
+    This pass lowers communication operations from the Shard dialect to the
+    MPI dialect.
+    If the module carries the DLTI attribute "MPI:comm_world_rank", its integer
+    value is used as the rank instead of calling MPI_Comm_rank. This enables
+    optimizations such as constant shape propagation and fusion, since shard
+    and partition sizes can be determined from the rank.
+    Some operations require intermediate memref allocations during lowering.
+    For compatibility with the buffer deallocation pipeline, these allocations
+    are only freed explicitly when the lowered operation returns a memref.
+    When it returns a tensor, no explicit deallocation is emitted; the buffers
+    must then be freed by other means, e.g. the buffer deallocation pipeline.
   }];
   let dependentDialects = [
     "affine::AffineDialect",



More information about the Mlir-commits mailing list