[Mlir-commits] [mlir] [mlir][mpi] adding MPI_Allgather and lowering to LLVM (PR #176937)

Frank Schlimbach llvmlistbot at llvm.org
Tue Jan 20 08:48:56 PST 2026


https://github.com/fschlimb updated https://github.com/llvm/llvm-project/pull/176937

>From 20c83362ce3db4487b0deb221ff93d09f5b8e0f5 Mon Sep 17 00:00:00 2001
From: "Schlimbach, Frank" <frank.schlimbach at intel.com>
Date: Tue, 20 Jan 2026 06:38:59 -0800
Subject: [PATCH 1/4] adding MPI_Allgather and lowering to LLVM

---
 mlir/include/mlir/Dialect/MPI/IR/MPIOps.td    |  33 +++++
 mlir/lib/Conversion/MPIToLLVM/MPIToLLVM.cpp   | 118 +++++++++++++++++-
 mlir/test/Conversion/MPIToLLVM/mpitollvm.mlir |  15 +++
 mlir/test/Dialect/MPI/mpiops.mlir             |   6 +
 4 files changed, 170 insertions(+), 2 deletions(-)

diff --git a/mlir/include/mlir/Dialect/MPI/IR/MPIOps.td b/mlir/include/mlir/Dialect/MPI/IR/MPIOps.td
index 935e0f785ef0c..7c2b00c79e691 100644
--- a/mlir/include/mlir/Dialect/MPI/IR/MPIOps.td
+++ b/mlir/include/mlir/Dialect/MPI/IR/MPIOps.td
@@ -261,6 +261,39 @@ def MPI_IRecvOp : MPI_Op<"irecv", []> {
   let hasCanonicalizer = 1;
 }
 
+//===----------------------------------------------------------------------===//
+// AllGatherOp
+//===----------------------------------------------------------------------===//
+
+def MPI_AllGatherOp : MPI_Op<"allgather", []> {
+  let summary = [{
+    Equivalent to `MPI_Allgather(sendbuf, sendcount, sendtype,
+                                 recvbuf, recvcount, recvtype,
+                                 comm)`.
+  }];
+  let description = [{
+    MPI_Allgather collects data from all processes in a given communicator and
+    stores the data collected in the receive buffer of each process.
+
+    Each process is expected to contribute the same amount of data.
+
+    This operation can optionally return an `!mpi.retval` value that can be used
+    to check for errors.
+  }];
+
+  let arguments = (
+    ins AnyMemRef : $sendbuf,
+        AnyMemRef : $recvbuf,
+        MPI_Comm : $comm
+  );
+
+  let results = (outs Optional<MPI_Retval>:$retval);
+
+  let assemblyFormat = "`(` $sendbuf `,` $recvbuf `,` $comm `)` "
+                       "attr-dict `:` type($sendbuf) `,` type($recvbuf) "
+                       "(`->` type($retval)^)?";
+}
+
 //===----------------------------------------------------------------------===//
 // AllReduceOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Conversion/MPIToLLVM/MPIToLLVM.cpp b/mlir/lib/Conversion/MPIToLLVM/MPIToLLVM.cpp
index 59a16df9c59d3..f37622c0bf7f0 100644
--- a/mlir/lib/Conversion/MPIToLLVM/MPIToLLVM.cpp
+++ b/mlir/lib/Conversion/MPIToLLVM/MPIToLLVM.cpp
@@ -604,6 +604,62 @@ struct CommRankOpLowering : public ConvertOpToLLVMPattern<mpi::CommRankOp> {
   }
 };
 
+//===----------------------------------------------------------------------===//
+// CommSizeOpLowering
+//===----------------------------------------------------------------------===//
+
+struct CommSizeOpLowering : public ConvertOpToLLVMPattern<mpi::CommSizeOp> {
+  using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;
+
+  LogicalResult
+  matchAndRewrite(mpi::CommSizeOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    // helpers: op location, MLIR context, and the i32 type used below
+    Location loc = op.getLoc();
+    MLIRContext *context = rewriter.getContext();
+    Type i32 = rewriter.getI32Type();
+
+    // `!llvm.ptr`: type of the callee's second (output) argument
+    Type ptrType = LLVM::LLVMPointerType::get(context);
+
+    // enclosing module; passed to MPIImplTraits and getOrDefineFunction below
+    auto moduleOp = op->getParentOfType<ModuleOp>();
+
+    auto mpiTraits = MPIImplTraits::get(moduleOp);
+    // communicator operand, cast by the implementation-specific traits
+    Value comm = mpiTraits->castComm(loc, rewriter, adaptor.getComm());
+
+    // LLVM function type representing `i32 MPI_Comm_size(comm, ptr)`
+    auto SizeFuncType =
+        LLVM::LLVMFunctionType::get(i32, {comm.getType(), ptrType});
+    // get or create decl. NOTE(review): MPI C API spells MPI_Comm_size; verify
+    LLVM::LLVMFuncOp initDecl = getOrDefineFunction(
+        moduleOp, loc, rewriter, "MPI_Comm_Size", SizeFuncType);
+
+    // allocate one i32 (llvm.alloca) as the output argument, then emit the call
+    auto one = LLVM::ConstantOp::create(rewriter, loc, i32, 1);
+    auto sizeptr = LLVM::AllocaOp::create(rewriter, loc, ptrType, i32, one);
+    auto callOp = LLVM::CallOp::create(rewriter, loc, initDecl,
+                                       ValueRange{comm, sizeptr.getRes()});
+
+    // read the i32 back from the alloca'd slot after the call
+    auto loadedSize =
+        LLVM::LoadOp::create(rewriter, loc, i32, sizeptr.getResult());
+
+    // replacement values are collected in order: the call's i32 result comes
+    // first, but only when the op actually has a retval result
+    SmallVector<Value> replacements;
+    if (op.getRetval())
+      replacements.push_back(callOp.getResult());
+
+    // the loaded size is always the last replacement; replaceOp removes the op
+    replacements.push_back(loadedSize.getRes());
+    rewriter.replaceOp(op, replacements);
+
+    return success();
+  }
+};
+
 //===----------------------------------------------------------------------===//
 // SendOpLowering
 //===----------------------------------------------------------------------===//
@@ -712,6 +768,64 @@ struct RecvOpLowering : public ConvertOpToLLVMPattern<mpi::RecvOp> {
   }
 };
 
+//===----------------------------------------------------------------------===//
+// AllGatherOpLowering
+//===----------------------------------------------------------------------===//
+
+struct AllGatherOpLowering : public ConvertOpToLLVMPattern<mpi::AllGatherOp> {
+  using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;
+
+  LogicalResult
+  matchAndRewrite(mpi::AllGatherOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    Location loc = op.getLoc();
+    MLIRContext *context = rewriter.getContext();
+    Type sElemType = op.getSendbuf().getType().getElementType();
+    Type rElemType = op.getRecvbuf().getType().getElementType();
+    auto [sendPtr, sendSize] =
+        getRawPtrAndSize(loc, rewriter, adaptor.getSendbuf(), sElemType);
+    auto [recvPtr, recvSize] =
+        getRawPtrAndSize(loc, rewriter, adaptor.getRecvbuf(), rElemType);
+
+    auto moduleOp = op->getParentOfType<ModuleOp>();
+    auto mpiTraits = MPIImplTraits::get(moduleOp);
+    Value sDataType = mpiTraits->getDataType(loc, rewriter, sElemType);
+    Value rDataType = mpiTraits->getDataType(loc, rewriter, rElemType);
+    Value comm = mpiTraits->castComm(loc, rewriter, adaptor.getComm());
+
+    Type ptrType = LLVM::LLVMPointerType::get(context);
+    Type i32 = rewriter.getI32Type();
+    // int MPI_Allgather(
+    //     const void* buffer_send, int count_send, MPI_Datatype datatype_send,
+    //     void* buffer_recv, int count_recv, MPI_Datatype datatype_recv,
+    //     MPI_Comm communicator);
+    auto funcType = LLVM::LLVMFunctionType::get(
+        i32, {ptrType, i32, sDataType.getType(),
+              ptrType, i32, rDataType.getType(),
+              comm.getType()});
+    // get or create function declaration:
+    LLVM::LLVMFuncOp funcDecl =
+        getOrDefineFunction(moduleOp, loc, rewriter, "MPI_Allgather", funcType);
+
+    // count_recv is the number of elements received from each rank, not total
+    Value nRanks = mlir::mpi::CommSizeOp::create(rewriter, loc, i32, adaptor.getComm()).getSize();
+    Value recvCountPerRank = LLVM::UDivOp::create(
+        rewriter, loc, i32, recvSize, nRanks);
+
+    // replace op with function call
+    auto funcCall = LLVM::CallOp::create(
+        rewriter, loc, funcDecl,
+        ValueRange{sendPtr, sendSize, sDataType, recvPtr, recvCountPerRank, rDataType, comm});
+
+    if (op.getRetval())
+      rewriter.replaceOp(op, funcCall.getResult());
+    else
+      rewriter.eraseOp(op);
+
+    return success();
+  }
+};
+
 //===----------------------------------------------------------------------===//
 // AllReduceOpLowering
 //===----------------------------------------------------------------------===//
@@ -801,9 +915,9 @@ void mpi::populateMPIToLLVMConversionPatterns(LLVMTypeConverter &converter,
   converter.addConversion([](mpi::CommType type) {
     return IntegerType::get(type.getContext(), 64);
   });
-  patterns.add<CommRankOpLowering, CommSplitOpLowering, CommWorldOpLowering,
+  patterns.add<CommRankOpLowering, CommSizeOpLowering, CommSplitOpLowering, CommWorldOpLowering,
                FinalizeOpLowering, InitOpLowering, SendOpLowering,
-               RecvOpLowering, AllReduceOpLowering>(converter);
+               RecvOpLowering, AllGatherOpLowering, AllReduceOpLowering>(converter);
 }
 
 void mpi::registerConvertMPIToLLVMInterface(DialectRegistry &registry) {
diff --git a/mlir/test/Conversion/MPIToLLVM/mpitollvm.mlir b/mlir/test/Conversion/MPIToLLVM/mpitollvm.mlir
index 35fc0f5d2e754..4c1beee2fe144 100644
--- a/mlir/test/Conversion/MPIToLLVM/mpitollvm.mlir
+++ b/mlir/test/Conversion/MPIToLLVM/mpitollvm.mlir
@@ -3,6 +3,9 @@
 // COM: Test MPICH ABI
 // CHECK: module attributes {dlti.map = #dlti.map<"MPI:Implementation" = "MPICH">} {
 // CHECK: llvm.func @MPI_Finalize() -> i32
+// CHECK: llvm.func @MPI_Allreduce(!llvm.ptr, !llvm.ptr, i32, i32, i32, i32) -> i32
+// CHECK: llvm.func @MPI_Comm_Size(i32, !llvm.ptr) -> i32
+// CHECK: llvm.func @MPI_Allgather(!llvm.ptr, i32, i32, !llvm.ptr, i32, i32, i32) -> i32
 // CHECK: llvm.func @MPI_Comm_split(i32, i32, i32, !llvm.ptr) -> i32
 // CHECK: llvm.func @MPI_Recv(!llvm.ptr, i32, i32, i32, i32, i32, !llvm.ptr) -> i32
 // CHECK: llvm.func @MPI_Send(!llvm.ptr, i32, i32, i32, i32, i32) -> i32
@@ -88,6 +91,9 @@ module attributes {dlti.map = #dlti.map<"MPI:Implementation" = "MPICH">} {
     // CHECK: [[v57:%.*]] = llvm.load [[v55]] : !llvm.ptr -> i32
     %split = mpi.comm_split(%comm, %color, %key) : !mpi.comm
 
+    // CHECK: llvm.call @MPI_Allgather({{.*}} : (!llvm.ptr, i32, i32, !llvm.ptr, i32, i32, i32) -> i32
+    %err3 = mpi.allgather(%arg0, %arg0, %comm) : memref<100xf32>, memref<100xf32> -> !mpi.retval
+
     // CHECK: [[v59:%.*]] = llvm.extractvalue [[v5]][1] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
     // CHECK: [[v60:%.*]] = llvm.extractvalue [[v5]][2] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
     // CHECK: [[v61:%.*]] = llvm.getelementptr [[v59]][[[v60]]] : (!llvm.ptr, i64) -> !llvm.ptr, f32
@@ -119,6 +125,10 @@ module attributes {dlti.map = #dlti.map<"MPI:Implementation" = "MPICH">} {
 // CHECK: module attributes {dlti.map = #dlti.map<"MPI:Implementation" = "OpenMPI">} {
 // CHECK: llvm.func @MPI_Finalize() -> i32
 // CHECK: llvm.func @MPI_Comm_split(!llvm.ptr, i32, i32, !llvm.ptr) -> i32
+// CHECK: llvm.func @MPI_Allreduce(!llvm.ptr, !llvm.ptr, i32, !llvm.ptr, !llvm.ptr, !llvm.ptr) -> i32
+// CHECK: llvm.mlir.global external @ompi_mpi_sum() {addr_space = 0 : i32} : !llvm.struct<"ompi_predefined_op_t", opaque>
+// CHECK: llvm.func @MPI_Comm_Size(!llvm.ptr, !llvm.ptr) -> i32
+// CHECK: llvm.func @MPI_Allgather(!llvm.ptr, i32, !llvm.ptr, !llvm.ptr, i32, !llvm.ptr, !llvm.ptr) -> i32
 // CHECK: llvm.func @MPI_Recv(!llvm.ptr, i32, !llvm.ptr, i32, i32, !llvm.ptr, !llvm.ptr) -> i32
 // CHECK: llvm.func @MPI_Send(!llvm.ptr, i32, !llvm.ptr, i32, i32, !llvm.ptr) -> i32
 // CHECK: llvm.mlir.global external @ompi_mpi_float() {addr_space = 0 : i32} : !llvm.struct<"ompi_predefined_datatype_t", opaque>
@@ -193,6 +203,11 @@ module attributes { dlti.map = #dlti.map<"MPI:Implementation" = "OpenMPI"> } {
     // CHECK: [[v47:%.*]] = llvm.inttoptr [[v46]] : i64 to !llvm.ptr
     // CHECK: [[v48:%.*]] = llvm.call @MPI_Recv([[v41]], [[v43]], [[v44]], [[v12]], [[v12]], [[v45]], [[v47]]) : (!llvm.ptr, i32, !llvm.ptr, i32, i32, !llvm.ptr, !llvm.ptr) -> i32
     %2 = mpi.recv(%arg0, %rank, %rank, %comm) : memref<100xf32>, i32, i32 -> !mpi.retval
+    
+    // CHECK: llvm.call @MPI_Comm_Size({{.*}}) : (!llvm.ptr, !llvm.ptr) -> i32
+    // CHECK: llvm.udiv {{.*}} : i32
+    // CHECK: llvm.call @MPI_Allgather({{.*}} : (!llvm.ptr, i32, !llvm.ptr, !llvm.ptr, i32, !llvm.ptr, !llvm.ptr) -> i32
+    mpi.allgather(%arg0, %arg0, %comm) : memref<100xf32>, memref<100xf32>
 
     // CHECK: [[v49:%.*]] = llvm.extractvalue [[v5]][1] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
     // CHECK: [[v50:%.*]] = llvm.extractvalue [[v5]][2] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
diff --git a/mlir/test/Dialect/MPI/mpiops.mlir b/mlir/test/Dialect/MPI/mpiops.mlir
index ef457628fe2c4..87a5647ee91d2 100644
--- a/mlir/test/Dialect/MPI/mpiops.mlir
+++ b/mlir/test/Dialect/MPI/mpiops.mlir
@@ -65,6 +65,12 @@ func.func @mpi_test(%ref : memref<100xf32>) -> () {
     // CHECK-NEXT: [[v5:%.*]] = mpi.barrier([[v1]]) -> !mpi.retval
     %err7 = mpi.barrier(%comm) -> !mpi.retval
 
+    // CHECK-NEXT: [[e3:%.*]] = mpi.allgather([[varg0]], [[varg0]], [[v1]]) : memref<100xf32>, memref<100xf32> -> !mpi.retval
+    %err3 = mpi.allgather(%ref, %ref, %comm) : memref<100xf32>, memref<100xf32> -> !mpi.retval
+
+    // CHECK-NEXT: mpi.allgather([[varg0]], [[varg0]], [[v1]]) : memref<100xf32>, memref<100xf32>
+    mpi.allgather(%ref, %ref, %comm) : memref<100xf32>, memref<100xf32>
+
     // CHECK-NEXT: [[v6:%.*]] = mpi.allreduce([[varg0]], [[varg0]], MPI_SUM, [[v1]]) : memref<100xf32>, memref<100xf32> -> !mpi.retval
     %err8 = mpi.allreduce(%ref, %ref, MPI_SUM, %comm) : memref<100xf32>, memref<100xf32> -> !mpi.retval
 

>From fc4b6704c466026ac67e7fbf4677c2a0b3b8f24e Mon Sep 17 00:00:00 2001
From: "Schlimbach, Frank" <frank.schlimbach at intel.com>
Date: Tue, 20 Jan 2026 08:13:42 -0800
Subject: [PATCH 2/4] addressing minor review comment

---
 mlir/lib/Conversion/MPIToLLVM/MPIToLLVM.cpp | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/mlir/lib/Conversion/MPIToLLVM/MPIToLLVM.cpp b/mlir/lib/Conversion/MPIToLLVM/MPIToLLVM.cpp
index f37622c0bf7f0..bc5aeec8cb4b0 100644
--- a/mlir/lib/Conversion/MPIToLLVM/MPIToLLVM.cpp
+++ b/mlir/lib/Conversion/MPIToLLVM/MPIToLLVM.cpp
@@ -808,7 +808,7 @@ struct AllGatherOpLowering : public ConvertOpToLLVMPattern<mpi::AllGatherOp> {
         getOrDefineFunction(moduleOp, loc, rewriter, "MPI_Allgather", funcType);
 
     // count_recv is the number of elements received from each rank, not total
-    Value nRanks = mlir::mpi::CommSizeOp::create(rewriter, loc, i32, adaptor.getComm()).getSize();
+    Value nRanks = mpi::CommSizeOp::create(rewriter, loc, i32, adaptor.getComm()).getSize();
     Value recvCountPerRank = LLVM::UDivOp::create(
         rewriter, loc, i32, recvSize, nRanks);
 

>From b61912fde1b81a451c7b7ed41a0aec4eb25856fc Mon Sep 17 00:00:00 2001
From: "Schlimbach, Frank" <frank.schlimbach at intel.com>
Date: Tue, 20 Jan 2026 08:31:40 -0800
Subject: [PATCH 3/4] clang-format

---
 mlir/lib/Conversion/MPIToLLVM/MPIToLLVM.cpp | 27 ++++++++++++---------
 1 file changed, 15 insertions(+), 12 deletions(-)

diff --git a/mlir/lib/Conversion/MPIToLLVM/MPIToLLVM.cpp b/mlir/lib/Conversion/MPIToLLVM/MPIToLLVM.cpp
index bc5aeec8cb4b0..4a1c5d1f7846c 100644
--- a/mlir/lib/Conversion/MPIToLLVM/MPIToLLVM.cpp
+++ b/mlir/lib/Conversion/MPIToLLVM/MPIToLLVM.cpp
@@ -800,22 +800,24 @@ struct AllGatherOpLowering : public ConvertOpToLLVMPattern<mpi::AllGatherOp> {
     //     void* buffer_recv, int count_recv, MPI_Datatype datatype_recv,
     //     MPI_Comm communicator);
     auto funcType = LLVM::LLVMFunctionType::get(
-        i32, {ptrType, i32, sDataType.getType(),
-              ptrType, i32, rDataType.getType(),
-              comm.getType()});
+        i32, {ptrType, i32, sDataType.getType(), ptrType, i32,
+              rDataType.getType(), comm.getType()});
     // get or create function declaration:
     LLVM::LLVMFuncOp funcDecl =
         getOrDefineFunction(moduleOp, loc, rewriter, "MPI_Allgather", funcType);
 
     // count_recv is the number of elements received from each rank, not total
-    Value nRanks = mpi::CommSizeOp::create(rewriter, loc, i32, adaptor.getComm()).getSize();
-    Value recvCountPerRank = LLVM::UDivOp::create(
-        rewriter, loc, i32, recvSize, nRanks);
+    Value nRanks =
+        mpi::CommSizeOp::create(rewriter, loc, i32, adaptor.getComm())
+            .getSize();
+    Value recvCountPerRank =
+        LLVM::UDivOp::create(rewriter, loc, i32, recvSize, nRanks);
 
     // replace op with function call
-    auto funcCall = LLVM::CallOp::create(
-        rewriter, loc, funcDecl,
-        ValueRange{sendPtr, sendSize, sDataType, recvPtr, recvCountPerRank, rDataType, comm});
+    auto funcCall =
+        LLVM::CallOp::create(rewriter, loc, funcDecl,
+                             ValueRange{sendPtr, sendSize, sDataType, recvPtr,
+                                        recvCountPerRank, rDataType, comm});
 
     if (op.getRetval())
       rewriter.replaceOp(op, funcCall.getResult());
@@ -915,9 +917,10 @@ void mpi::populateMPIToLLVMConversionPatterns(LLVMTypeConverter &converter,
   converter.addConversion([](mpi::CommType type) {
     return IntegerType::get(type.getContext(), 64);
   });
-  patterns.add<CommRankOpLowering, CommSizeOpLowering, CommSplitOpLowering, CommWorldOpLowering,
-               FinalizeOpLowering, InitOpLowering, SendOpLowering,
-               RecvOpLowering, AllGatherOpLowering, AllReduceOpLowering>(converter);
+  patterns.add<CommRankOpLowering, CommSizeOpLowering, CommSplitOpLowering,
+               CommWorldOpLowering, FinalizeOpLowering, InitOpLowering,
+               SendOpLowering, RecvOpLowering, AllGatherOpLowering,
+               AllReduceOpLowering>(converter);
 }
 
 void mpi::registerConvertMPIToLLVMInterface(DialectRegistry &registry) {

>From b083195405f3cd7d25a3db4392f6ad211be0571e Mon Sep 17 00:00:00 2001
From: "Schlimbach, Frank" <frank.schlimbach at intel.com>
Date: Tue, 20 Jan 2026 08:48:40 -0800
Subject: [PATCH 4/4] clarifying implicit def of recvcount in API call

---
 mlir/include/mlir/Dialect/MPI/IR/MPIOps.td | 13 +++++++++----
 1 file changed, 9 insertions(+), 4 deletions(-)

diff --git a/mlir/include/mlir/Dialect/MPI/IR/MPIOps.td b/mlir/include/mlir/Dialect/MPI/IR/MPIOps.td
index 7c2b00c79e691..b24c12459475e 100644
--- a/mlir/include/mlir/Dialect/MPI/IR/MPIOps.td
+++ b/mlir/include/mlir/Dialect/MPI/IR/MPIOps.td
@@ -273,12 +273,17 @@ def MPI_AllGatherOp : MPI_Op<"allgather", []> {
   }];
   let description = [{
     MPI_Allgather collects data from all processes in a given communicator and
-    stores the data collected in the receive buffer of each process.
+    stores the gathered data in the receive buffer of each process.
 
-    Each process is expected to contribute the same amount of data.
+    Each process contributes the same amount of data defined by `sendbuf`.
+    The MPI call specifies the number of elements contributed by each process
+    via the `recvcount` parameter. However, this operation assumes `recvbuf`
+    to be sufficiently large to hold the data contributed by all processes.
+    Therefore, `recvcount` is implicitly defined as
+    `num_elements(recvbuf) / MPI_Comm_size(comm)`.
 
-    This operation can optionally return an `!mpi.retval` value that can be used
-    to check for errors.
+    This operation may optionally return an !mpi.retval value, which can be
+    used for error checking.
   }];
 
   let arguments = (



More information about the Mlir-commits mailing list