[Mlir-commits] [mlir] 9543fc7 - [mlir][mpi] adding MPI_Allgather and lowering to LLVM (#176937)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Wed Jan 21 03:24:23 PST 2026


Author: Frank Schlimbach
Date: 2026-01-21T12:24:18+01:00
New Revision: 9543fc766d3471d0bb92a1d42314ba21a20b5bbb

URL: https://github.com/llvm/llvm-project/commit/9543fc766d3471d0bb92a1d42314ba21a20b5bbb
DIFF: https://github.com/llvm/llvm-project/commit/9543fc766d3471d0bb92a1d42314ba21a20b5bbb.diff

LOG: [mlir][mpi] adding MPI_Allgather and lowering to LLVM (#176937)

- Add MPI_Allgather to the MPI dialect
- Add its lowering to MPIToLLVM
- Also lower MPI_Comm_size (see also #140392)

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/MPI/IR/MPIOps.td
    mlir/lib/Conversion/MPIToLLVM/MPIToLLVM.cpp
    mlir/test/Conversion/MPIToLLVM/mpitollvm.mlir
    mlir/test/Dialect/MPI/mpiops.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/MPI/IR/MPIOps.td b/mlir/include/mlir/Dialect/MPI/IR/MPIOps.td
index 935e0f785ef0c..b24c12459475e 100644
--- a/mlir/include/mlir/Dialect/MPI/IR/MPIOps.td
+++ b/mlir/include/mlir/Dialect/MPI/IR/MPIOps.td
@@ -261,6 +261,44 @@ def MPI_IRecvOp : MPI_Op<"irecv", []> {
   let hasCanonicalizer = 1;
 }
 
+//===----------------------------------------------------------------------===//
+// AllGatherOp
+//===----------------------------------------------------------------------===//
+
+def MPI_AllGatherOp : MPI_Op<"allgather", []> {
+  let summary = [{
+    Equivalent to `MPI_Allgather(sendbuf, sendcount, sendtype,
+                                 recvbuf, recvcount, recvtype,
+                                 comm)`.
+  }];
+  let description = [{
+    MPI_Allgather collects data from all processes in a given communicator and
+    stores the gathered data in the receive buffer of each process.
+
+    Each process contributes the same amount of data defined by `sendbuf`.
+    The MPI call specifies the number of elements contributed by each process
+    via the `recvcount` parameter. This operation instead assumes `recvbuf`
+    to hold exactly the data contributed by all processes and implicitly
+    defines `recvcount` as `num_elements(recvbuf) / MPI_Comm_size(comm)`
+    (integer division).
+
+    This operation may optionally return an !mpi.retval value, which can be
+    used for error checking.
+  }];
+
+  let arguments = (
+    ins AnyMemRef : $sendbuf,
+        AnyMemRef : $recvbuf,
+        MPI_Comm : $comm
+  );
+
+  let results = (outs Optional<MPI_Retval>:$retval);
+
+  let assemblyFormat = "`(` $sendbuf `,` $recvbuf `,` $comm `)` "
+                       "attr-dict `:` type($sendbuf) `,` type($recvbuf) "
+                       "(`->` type($retval)^)?";
+}
+
 //===----------------------------------------------------------------------===//
 // AllReduceOp
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/lib/Conversion/MPIToLLVM/MPIToLLVM.cpp b/mlir/lib/Conversion/MPIToLLVM/MPIToLLVM.cpp
index 59a16df9c59d3..4a1c5d1f7846c 100644
--- a/mlir/lib/Conversion/MPIToLLVM/MPIToLLVM.cpp
+++ b/mlir/lib/Conversion/MPIToLLVM/MPIToLLVM.cpp
@@ -604,6 +604,62 @@ struct CommRankOpLowering : public ConvertOpToLLVMPattern<mpi::CommRankOp> {
   }
 };
 
+//===----------------------------------------------------------------------===//
+// CommSizeOpLowering
+//===----------------------------------------------------------------------===//
+
+/// Emits `int MPI_Comm_size(MPI_Comm comm, int *size)` for `comm`, which must
+/// already be cast to the communicator type of the selected MPI
+/// implementation, declaring the function in `moduleOp` on first use.
+/// Returns {i32 return code of the call, i32 size loaded from the slot}.
+/// Shared by the `mpi.comm_size` and `mpi.allgather` lowerings.
+static std::pair<Value, Value>
+createCommSizeCall(Location loc, ConversionPatternRewriter &rewriter,
+                   ModuleOp moduleOp, Value comm) {
+  Type i32 = rewriter.getI32Type();
+  Type ptrType = LLVM::LLVMPointerType::get(rewriter.getContext());
+  // The C symbol is `MPI_Comm_size` (lower-case "s"); symbol names are
+  // case-sensitive, so any other spelling fails to link against MPI.
+  auto funcType = LLVM::LLVMFunctionType::get(i32, {comm.getType(), ptrType});
+  LLVM::LLVMFuncOp funcDecl = getOrDefineFunction(
+      moduleOp, loc, rewriter, "MPI_Comm_size", funcType);
+
+  // MPI reports the size through an out-pointer: reserve one i32 slot.
+  auto one = LLVM::ConstantOp::create(rewriter, loc, i32, 1);
+  auto sizePtr = LLVM::AllocaOp::create(rewriter, loc, ptrType, i32, one);
+  auto callOp = LLVM::CallOp::create(rewriter, loc, funcDecl,
+                                     ValueRange{comm, sizePtr.getRes()});
+  auto size = LLVM::LoadOp::create(rewriter, loc, i32, sizePtr.getRes());
+  return {callOp.getResult(), size.getRes()};
+}
+
+/// Lowers `mpi.comm_size` to a call to `MPI_Comm_size`. The optional
+/// `!mpi.retval` result is replaced by the call's return code and the size
+/// result by the value MPI stored through the out-pointer.
+struct CommSizeOpLowering : public ConvertOpToLLVMPattern<mpi::CommSizeOp> {
+  using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;
+
+  LogicalResult
+  matchAndRewrite(mpi::CommSizeOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    Location loc = op.getLoc();
+    auto moduleOp = op->getParentOfType<ModuleOp>();
+    auto mpiTraits = MPIImplTraits::get(moduleOp);
+    // Cast the lowered communicator to the type the MPI implementation uses.
+    Value comm = mpiTraits->castComm(loc, rewriter, adaptor.getComm());
+    auto [errCode, size] = createCommSizeCall(loc, rewriter, moduleOp, comm);
+
+    // Results are (retval?, size); only forward the return code when the op
+    // actually declares the optional retval result.
+    SmallVector<Value> replacements;
+    if (op.getRetval())
+      replacements.push_back(errCode);
+    replacements.push_back(size);
+    rewriter.replaceOp(op, replacements);
+    return success();
+  }
+};
+
 //===----------------------------------------------------------------------===//
 // SendOpLowering
 //===----------------------------------------------------------------------===//
@@ -712,6 +768,66 @@ struct RecvOpLowering : public ConvertOpToLLVMPattern<mpi::RecvOp> {
   }
 };
 
+//===----------------------------------------------------------------------===//
+// AllGatherOpLowering
+//===----------------------------------------------------------------------===//
+
+/// Lowers `mpi.allgather` to `MPI_Allgather`, passing the per-rank receive
+/// count as `num_elements(recvbuf) / MPI_Comm_size(comm)` (truncating).
+struct AllGatherOpLowering : public ConvertOpToLLVMPattern<mpi::AllGatherOp> {
+  using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;
+
+  LogicalResult
+  matchAndRewrite(mpi::AllGatherOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    Location loc = op.getLoc();
+    MLIRContext *context = rewriter.getContext();
+    Type sElemType = op.getSendbuf().getType().getElementType();
+    Type rElemType = op.getRecvbuf().getType().getElementType();
+    auto [sendPtr, sendSize] =
+        getRawPtrAndSize(loc, rewriter, adaptor.getSendbuf(), sElemType);
+    auto [recvPtr, recvSize] =
+        getRawPtrAndSize(loc, rewriter, adaptor.getRecvbuf(), rElemType);
+
+    auto moduleOp = op->getParentOfType<ModuleOp>();
+    auto mpiTraits = MPIImplTraits::get(moduleOp);
+    Value sDataType = mpiTraits->getDataType(loc, rewriter, sElemType);
+    Value rDataType = mpiTraits->getDataType(loc, rewriter, rElemType);
+    Value comm = mpiTraits->castComm(loc, rewriter, adaptor.getComm());
+
+    Type ptrType = LLVM::LLVMPointerType::get(context);
+    Type i32 = rewriter.getI32Type();
+    // int MPI_Allgather(
+    //     const void* buffer_send, int count_send, MPI_Datatype datatype_send,
+    //     void* buffer_recv, int count_recv, MPI_Datatype datatype_recv,
+    //     MPI_Comm communicator);
+    auto funcType = LLVM::LLVMFunctionType::get(
+        i32, {ptrType, i32, sDataType.getType(), ptrType, i32,
+              rDataType.getType(), comm.getType()});
+    LLVM::LLVMFuncOp funcDecl =
+        getOrDefineFunction(moduleOp, loc, rewriter, "MPI_Allgather", funcType);
+
+    // count_recv is per rank, not the total size of recvbuf. Query the size
+    // on the already-cast comm (MPI_Comm_size's return code is not checked).
+    Value nRanks = createCommSizeCall(loc, rewriter, moduleOp, comm).second;
+    Value recvCountPerRank =
+        LLVM::UDivOp::create(rewriter, loc, i32, recvSize, nRanks);
+
+    // replace op with function call
+    auto funcCall =
+        LLVM::CallOp::create(rewriter, loc, funcDecl,
+                             ValueRange{sendPtr, sendSize, sDataType, recvPtr,
+                                        recvCountPerRank, rDataType, comm});
+
+    if (op.getRetval())
+      rewriter.replaceOp(op, funcCall.getResult());
+    else
+      rewriter.eraseOp(op);
+
+    return success();
+  }
+};
+
 //===----------------------------------------------------------------------===//
 // AllReduceOpLowering
 //===----------------------------------------------------------------------===//
@@ -801,9 +917,10 @@ void mpi::populateMPIToLLVMConversionPatterns(LLVMTypeConverter &converter,
   converter.addConversion([](mpi::CommType type) {
     return IntegerType::get(type.getContext(), 64);
   });
-  patterns.add<CommRankOpLowering, CommSplitOpLowering, CommWorldOpLowering,
-               FinalizeOpLowering, InitOpLowering, SendOpLowering,
-               RecvOpLowering, AllReduceOpLowering>(converter);
+  patterns.add<CommRankOpLowering, CommSizeOpLowering, CommSplitOpLowering,
+               CommWorldOpLowering, FinalizeOpLowering, InitOpLowering,
+               SendOpLowering, RecvOpLowering, AllGatherOpLowering,
+               AllReduceOpLowering>(converter);
 }
 
 void mpi::registerConvertMPIToLLVMInterface(DialectRegistry &registry) {

diff  --git a/mlir/test/Conversion/MPIToLLVM/mpitollvm.mlir b/mlir/test/Conversion/MPIToLLVM/mpitollvm.mlir
index 35fc0f5d2e754..4c1beee2fe144 100644
--- a/mlir/test/Conversion/MPIToLLVM/mpitollvm.mlir
+++ b/mlir/test/Conversion/MPIToLLVM/mpitollvm.mlir
@@ -3,6 +3,9 @@
 // COM: Test MPICH ABI
 // CHECK: module attributes {dlti.map = #dlti.map<"MPI:Implementation" = "MPICH">} {
 // CHECK: llvm.func @MPI_Finalize() -> i32
+// CHECK: llvm.func @MPI_Allreduce(!llvm.ptr, !llvm.ptr, i32, i32, i32, i32) -> i32
+// CHECK: llvm.func @MPI_Comm_size(i32, !llvm.ptr) -> i32
+// CHECK: llvm.func @MPI_Allgather(!llvm.ptr, i32, i32, !llvm.ptr, i32, i32, i32) -> i32
 // CHECK: llvm.func @MPI_Comm_split(i32, i32, i32, !llvm.ptr) -> i32
 // CHECK: llvm.func @MPI_Recv(!llvm.ptr, i32, i32, i32, i32, i32, !llvm.ptr) -> i32
 // CHECK: llvm.func @MPI_Send(!llvm.ptr, i32, i32, i32, i32, i32) -> i32
@@ -88,6 +91,9 @@ module attributes {dlti.map = #dlti.map<"MPI:Implementation" = "MPICH">} {
     // CHECK: [[v57:%.*]] = llvm.load [[v55]] : !llvm.ptr -> i32
     %split = mpi.comm_split(%comm, %color, %key) : !mpi.comm
 
+    // CHECK: llvm.call @MPI_Allgather({{.*}} : (!llvm.ptr, i32, i32, !llvm.ptr, i32, i32, i32) -> i32
+    %err3 = mpi.allgather(%arg0, %arg0, %comm) : memref<100xf32>, memref<100xf32> -> !mpi.retval
+
     // CHECK: [[v59:%.*]] = llvm.extractvalue [[v5]][1] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
     // CHECK: [[v60:%.*]] = llvm.extractvalue [[v5]][2] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
     // CHECK: [[v61:%.*]] = llvm.getelementptr [[v59]][[[v60]]] : (!llvm.ptr, i64) -> !llvm.ptr, f32
@@ -119,6 +125,10 @@ module attributes {dlti.map = #dlti.map<"MPI:Implementation" = "MPICH">} {
 // CHECK: module attributes {dlti.map = #dlti.map<"MPI:Implementation" = "OpenMPI">} {
 // CHECK: llvm.func @MPI_Finalize() -> i32
 // CHECK: llvm.func @MPI_Comm_split(!llvm.ptr, i32, i32, !llvm.ptr) -> i32
+// CHECK: llvm.func @MPI_Allreduce(!llvm.ptr, !llvm.ptr, i32, !llvm.ptr, !llvm.ptr, !llvm.ptr) -> i32
+// CHECK: llvm.mlir.global external @ompi_mpi_sum() {addr_space = 0 : i32} : !llvm.struct<"ompi_predefined_op_t", opaque>
+// CHECK: llvm.func @MPI_Comm_size(!llvm.ptr, !llvm.ptr) -> i32
+// CHECK: llvm.func @MPI_Allgather(!llvm.ptr, i32, !llvm.ptr, !llvm.ptr, i32, !llvm.ptr, !llvm.ptr) -> i32
 // CHECK: llvm.func @MPI_Recv(!llvm.ptr, i32, !llvm.ptr, i32, i32, !llvm.ptr, !llvm.ptr) -> i32
 // CHECK: llvm.func @MPI_Send(!llvm.ptr, i32, !llvm.ptr, i32, i32, !llvm.ptr) -> i32
 // CHECK: llvm.mlir.global external @ompi_mpi_float() {addr_space = 0 : i32} : !llvm.struct<"ompi_predefined_datatype_t", opaque>
@@ -193,6 +203,11 @@ module attributes { dlti.map = #dlti.map<"MPI:Implementation" = "OpenMPI"> } {
     // CHECK: [[v47:%.*]] = llvm.inttoptr [[v46]] : i64 to !llvm.ptr
     // CHECK: [[v48:%.*]] = llvm.call @MPI_Recv([[v41]], [[v43]], [[v44]], [[v12]], [[v12]], [[v45]], [[v47]]) : (!llvm.ptr, i32, !llvm.ptr, i32, i32, !llvm.ptr, !llvm.ptr) -> i32
     %2 = mpi.recv(%arg0, %rank, %rank, %comm) : memref<100xf32>, i32, i32 -> !mpi.retval
+
+    // CHECK: llvm.call @MPI_Comm_size({{.*}}) : (!llvm.ptr, !llvm.ptr) -> i32
+    // CHECK: llvm.udiv {{.*}} : i32
+    // CHECK: llvm.call @MPI_Allgather({{.*}} : (!llvm.ptr, i32, !llvm.ptr, !llvm.ptr, i32, !llvm.ptr, !llvm.ptr) -> i32
+    mpi.allgather(%arg0, %arg0, %comm) : memref<100xf32>, memref<100xf32>
 
     // CHECK: [[v49:%.*]] = llvm.extractvalue [[v5]][1] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
     // CHECK: [[v50:%.*]] = llvm.extractvalue [[v5]][2] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>

diff  --git a/mlir/test/Dialect/MPI/mpiops.mlir b/mlir/test/Dialect/MPI/mpiops.mlir
index ef457628fe2c4..87a5647ee91d2 100644
--- a/mlir/test/Dialect/MPI/mpiops.mlir
+++ b/mlir/test/Dialect/MPI/mpiops.mlir
@@ -65,6 +65,12 @@ func.func @mpi_test(%ref : memref<100xf32>) -> () {
     // CHECK-NEXT: [[v5:%.*]] = mpi.barrier([[v1]]) -> !mpi.retval
     %err7 = mpi.barrier(%comm) -> !mpi.retval
 
+    // CHECK-NEXT: [[e3:%.*]] = mpi.allgather([[varg0]], [[varg0]], [[v1]]) : memref<100xf32>, memref<100xf32> -> !mpi.retval
+    %err3 = mpi.allgather(%ref, %ref, %comm) : memref<100xf32>, memref<100xf32> -> !mpi.retval
+
+    // CHECK-NEXT: mpi.allgather([[varg0]], [[varg0]], [[v1]]) : memref<100xf32>, memref<100xf32>
+    mpi.allgather(%ref, %ref, %comm) : memref<100xf32>, memref<100xf32>
+
     // CHECK-NEXT: [[v6:%.*]] = mpi.allreduce([[varg0]], [[varg0]], MPI_SUM, [[v1]]) : memref<100xf32>, memref<100xf32> -> !mpi.retval
     %err8 = mpi.allreduce(%ref, %ref, MPI_SUM, %comm) : memref<100xf32>, memref<100xf32> -> !mpi.retval
 


        


More information about the Mlir-commits mailing list