[Mlir-commits] [mlir] [mlir][mpi] adding MPI_Allgather and lowering to LLVM (PR #176937)
Frank Schlimbach
llvmlistbot at llvm.org
Tue Jan 20 08:48:56 PST 2026
https://github.com/fschlimb updated https://github.com/llvm/llvm-project/pull/176937
>From 20c83362ce3db4487b0deb221ff93d09f5b8e0f5 Mon Sep 17 00:00:00 2001
From: "Schlimbach, Frank" <frank.schlimbach at intel.com>
Date: Tue, 20 Jan 2026 06:38:59 -0800
Subject: [PATCH 1/4] adding MPI_Allgather and lowering to LLVM
---
mlir/include/mlir/Dialect/MPI/IR/MPIOps.td | 33 +++++
mlir/lib/Conversion/MPIToLLVM/MPIToLLVM.cpp | 118 +++++++++++++++++-
mlir/test/Conversion/MPIToLLVM/mpitollvm.mlir | 15 +++
mlir/test/Dialect/MPI/mpiops.mlir | 6 +
4 files changed, 170 insertions(+), 2 deletions(-)
diff --git a/mlir/include/mlir/Dialect/MPI/IR/MPIOps.td b/mlir/include/mlir/Dialect/MPI/IR/MPIOps.td
index 935e0f785ef0c..7c2b00c79e691 100644
--- a/mlir/include/mlir/Dialect/MPI/IR/MPIOps.td
+++ b/mlir/include/mlir/Dialect/MPI/IR/MPIOps.td
@@ -261,6 +261,39 @@ def MPI_IRecvOp : MPI_Op<"irecv", []> {
let hasCanonicalizer = 1;
}
+//===----------------------------------------------------------------------===//
+// AllGatherOp
+//===----------------------------------------------------------------------===//
+
+def MPI_AllGatherOp : MPI_Op<"allgather", []> {
+ let summary = [{
+ Equivalent to `MPI_Allgather(sendbuf, sendcount, sendtype,
+ recvbuf, recvcount, recvtype,
+ comm)`.
+ }];
+ let description = [{
+ MPI_Allgather collects data from all processes in a given communicator and
+ stores the data collected in the receive buffer of each process.
+
+ Each process is expected to contribute the same amount of data.
+
+ This operation can optionally return an `!mpi.retval` value that can be used
+ to check for errors.
+ }];
+
+ let arguments = (
+ ins AnyMemRef : $sendbuf,
+ AnyMemRef : $recvbuf,
+ MPI_Comm : $comm
+ );
+ // recvbuf receives every rank's data; lowering uses numel(recvbuf)/comm size.
+ let results = (outs Optional<MPI_Retval>:$retval);
+ // When lowered to LLVM, the optional result is MPI_Allgather's i32 status.
+ let assemblyFormat = "`(` $sendbuf `,` $recvbuf `,` $comm `)` "
+ "attr-dict `:` type($sendbuf) `,` type($recvbuf) "
+ "(`->` type($retval)^)?";
+}
+
//===----------------------------------------------------------------------===//
// AllReduceOp
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Conversion/MPIToLLVM/MPIToLLVM.cpp b/mlir/lib/Conversion/MPIToLLVM/MPIToLLVM.cpp
index 59a16df9c59d3..f37622c0bf7f0 100644
--- a/mlir/lib/Conversion/MPIToLLVM/MPIToLLVM.cpp
+++ b/mlir/lib/Conversion/MPIToLLVM/MPIToLLVM.cpp
@@ -604,6 +604,62 @@ struct CommRankOpLowering : public ConvertOpToLLVMPattern<mpi::CommRankOp> {
}
};
+//===----------------------------------------------------------------------===//
+// CommSizeOpLowering
+//===----------------------------------------------------------------------===//
+
+struct CommSizeOpLowering : public ConvertOpToLLVMPattern<mpi::CommSizeOp> {
+  using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;
+
+  LogicalResult
+  matchAndRewrite(mpi::CommSizeOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    // Lowers to `int MPI_Comm_size(MPI_Comm comm, int *size)` (lowercase 's').
+    Location loc = op.getLoc();
+    MLIRContext *context = rewriter.getContext();
+    Type i32 = rewriter.getI32Type();
+
+    // Opaque pointer type `!llvm.ptr`, used for the `size` out-parameter.
+    Type ptrType = LLVM::LLVMPointerType::get(context);
+
+    // The function declaration is placed in the enclosing module.
+    auto moduleOp = op->getParentOfType<ModuleOp>();
+
+    auto mpiTraits = MPIImplTraits::get(moduleOp);
+    // Cast the communicator to the implementation-specific MPI_Comm type.
+    Value comm = mpiTraits->castComm(loc, rewriter, adaptor.getComm());
+
+    // LLVM function type representing `i32 MPI_Comm_size(comm, ptr)`.
+    auto sizeFuncType =
+        LLVM::LLVMFunctionType::get(i32, {comm.getType(), ptrType});
+    // Get or create the function declaration (symbol names are case-sensitive).
+    LLVM::LLVMFuncOp sizeDecl = getOrDefineFunction(
+        moduleOp, loc, rewriter, "MPI_Comm_size", sizeFuncType);
+
+    // Allocate a single i32 slot and pass its address to MPI_Comm_size.
+    auto one = LLVM::ConstantOp::create(rewriter, loc, i32, 1);
+    auto sizePtr = LLVM::AllocaOp::create(rewriter, loc, ptrType, i32, one);
+    auto callOp = LLVM::CallOp::create(rewriter, loc, sizeDecl,
+                                       ValueRange{comm, sizePtr.getRes()});
+
+    // Load the communicator size written by the call.
+    auto loadedSize =
+        LLVM::LoadOp::create(rewriter, loc, i32, sizePtr.getResult());
+
+    // Replacements follow the op's result order: the call's return code first
+    // (only if the op has a retval result), then the loaded size.
+    SmallVector<Value> replacements;
+    if (op.getRetval())
+      replacements.push_back(callOp.getResult());
+
+    // Replace all uses, then erase the op.
+    replacements.push_back(loadedSize.getRes());
+    rewriter.replaceOp(op, replacements);
+
+    return success();
+  }
+};
+
//===----------------------------------------------------------------------===//
// SendOpLowering
//===----------------------------------------------------------------------===//
@@ -712,6 +768,64 @@ struct RecvOpLowering : public ConvertOpToLLVMPattern<mpi::RecvOp> {
}
};
+//===----------------------------------------------------------------------===//
+// AllGatherOpLowering
+//===----------------------------------------------------------------------===//
+
+struct AllGatherOpLowering : public ConvertOpToLLVMPattern<mpi::AllGatherOp> {
+ using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;
+ // Lowers mpi.allgather to a call to MPI_Allgather (declared on demand).
+ LogicalResult
+ matchAndRewrite(mpi::AllGatherOp op, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ Location loc = op.getLoc();
+ MLIRContext *context = rewriter.getContext();
+ Type sElemType = op.getSendbuf().getType().getElementType();
+ Type rElemType = op.getRecvbuf().getType().getElementType();
+ auto [sendPtr, sendSize] =
+ getRawPtrAndSize(loc, rewriter, adaptor.getSendbuf(), sElemType);
+ auto [recvPtr, recvSize] =
+ getRawPtrAndSize(loc, rewriter, adaptor.getRecvbuf(), rElemType);
+ // MPI datatype handles (derived separately for send/recv) and communicator.
+ auto moduleOp = op->getParentOfType<ModuleOp>();
+ auto mpiTraits = MPIImplTraits::get(moduleOp);
+ Value sDataType = mpiTraits->getDataType(loc, rewriter, sElemType);
+ Value rDataType = mpiTraits->getDataType(loc, rewriter, rElemType);
+ Value comm = mpiTraits->castComm(loc, rewriter, adaptor.getComm());
+ // NOTE(review): CommSizeOp below gets converted adaptor.getComm(); consider op.getComm().
+ Type ptrType = LLVM::LLVMPointerType::get(context);
+ Type i32 = rewriter.getI32Type();
+ // int MPI_Allgather(
+ // const void* buffer_send, int count_send, MPI_Datatype datatype_send,
+ // void* buffer_recv, int count_recv, MPI_Datatype datatype_recv,
+ // MPI_Comm communicator);
+ auto funcType = LLVM::LLVMFunctionType::get(
+ i32, {ptrType, i32, sDataType.getType(),
+ ptrType, i32, rDataType.getType(),
+ comm.getType()});
+ // get or create function declaration:
+ LLVM::LLVMFuncOp funcDecl =
+ getOrDefineFunction(moduleOp, loc, rewriter, "MPI_Allgather", funcType);
+
+ // count_recv is the number of elements received from each rank, not total
+ Value nRanks = mlir::mpi::CommSizeOp::create(rewriter, loc, i32, adaptor.getComm()).getSize();
+ Value recvCountPerRank = LLVM::UDivOp::create(
+ rewriter, loc, i32, recvSize, nRanks);
+
+ // replace op with function call
+ auto funcCall = LLVM::CallOp::create(
+ rewriter, loc, funcDecl,
+ ValueRange{sendPtr, sendSize, sDataType, recvPtr, recvCountPerRank, rDataType, comm});
+
+ if (op.getRetval())
+ rewriter.replaceOp(op, funcCall.getResult());
+ else
+ rewriter.eraseOp(op);
+
+ return success();
+ }
+};
+
//===----------------------------------------------------------------------===//
// AllReduceOpLowering
//===----------------------------------------------------------------------===//
@@ -801,9 +915,9 @@ void mpi::populateMPIToLLVMConversionPatterns(LLVMTypeConverter &converter,
converter.addConversion([](mpi::CommType type) {
return IntegerType::get(type.getContext(), 64);
});
- patterns.add<CommRankOpLowering, CommSplitOpLowering, CommWorldOpLowering,
+ patterns.add<CommRankOpLowering, CommSizeOpLowering, CommSplitOpLowering, CommWorldOpLowering,
FinalizeOpLowering, InitOpLowering, SendOpLowering,
- RecvOpLowering, AllReduceOpLowering>(converter);
+ RecvOpLowering, AllGatherOpLowering, AllReduceOpLowering>(converter);
}
void mpi::registerConvertMPIToLLVMInterface(DialectRegistry &registry) {
diff --git a/mlir/test/Conversion/MPIToLLVM/mpitollvm.mlir b/mlir/test/Conversion/MPIToLLVM/mpitollvm.mlir
index 35fc0f5d2e754..4c1beee2fe144 100644
--- a/mlir/test/Conversion/MPIToLLVM/mpitollvm.mlir
+++ b/mlir/test/Conversion/MPIToLLVM/mpitollvm.mlir
@@ -3,6 +3,9 @@
// COM: Test MPICH ABI
// CHECK: module attributes {dlti.map = #dlti.map<"MPI:Implementation" = "MPICH">} {
// CHECK: llvm.func @MPI_Finalize() -> i32
+// CHECK: llvm.func @MPI_Allreduce(!llvm.ptr, !llvm.ptr, i32, i32, i32, i32) -> i32
+// CHECK: llvm.func @MPI_Comm_size(i32, !llvm.ptr) -> i32
+// CHECK: llvm.func @MPI_Allgather(!llvm.ptr, i32, i32, !llvm.ptr, i32, i32, i32) -> i32
// CHECK: llvm.func @MPI_Comm_split(i32, i32, i32, !llvm.ptr) -> i32
// CHECK: llvm.func @MPI_Recv(!llvm.ptr, i32, i32, i32, i32, i32, !llvm.ptr) -> i32
// CHECK: llvm.func @MPI_Send(!llvm.ptr, i32, i32, i32, i32, i32) -> i32
@@ -88,6 +91,9 @@ module attributes {dlti.map = #dlti.map<"MPI:Implementation" = "MPICH">} {
// CHECK: [[v57:%.*]] = llvm.load [[v55]] : !llvm.ptr -> i32
%split = mpi.comm_split(%comm, %color, %key) : !mpi.comm
+ // CHECK: llvm.call @MPI_Allgather({{.*}} : (!llvm.ptr, i32, i32, !llvm.ptr, i32, i32, i32) -> i32
+ %err3 = mpi.allgather(%arg0, %arg0, %comm) : memref<100xf32>, memref<100xf32> -> !mpi.retval
+
// CHECK: [[v59:%.*]] = llvm.extractvalue [[v5]][1] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
// CHECK: [[v60:%.*]] = llvm.extractvalue [[v5]][2] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
// CHECK: [[v61:%.*]] = llvm.getelementptr [[v59]][[[v60]]] : (!llvm.ptr, i64) -> !llvm.ptr, f32
@@ -119,6 +125,10 @@ module attributes {dlti.map = #dlti.map<"MPI:Implementation" = "MPICH">} {
// CHECK: module attributes {dlti.map = #dlti.map<"MPI:Implementation" = "OpenMPI">} {
// CHECK: llvm.func @MPI_Finalize() -> i32
// CHECK: llvm.func @MPI_Comm_split(!llvm.ptr, i32, i32, !llvm.ptr) -> i32
+// CHECK: llvm.func @MPI_Allreduce(!llvm.ptr, !llvm.ptr, i32, !llvm.ptr, !llvm.ptr, !llvm.ptr) -> i32
+// CHECK: llvm.mlir.global external @ompi_mpi_sum() {addr_space = 0 : i32} : !llvm.struct<"ompi_predefined_op_t", opaque>
+// CHECK: llvm.func @MPI_Comm_size(!llvm.ptr, !llvm.ptr) -> i32
+// CHECK: llvm.func @MPI_Allgather(!llvm.ptr, i32, !llvm.ptr, !llvm.ptr, i32, !llvm.ptr, !llvm.ptr) -> i32
// CHECK: llvm.func @MPI_Recv(!llvm.ptr, i32, !llvm.ptr, i32, i32, !llvm.ptr, !llvm.ptr) -> i32
// CHECK: llvm.func @MPI_Send(!llvm.ptr, i32, !llvm.ptr, i32, i32, !llvm.ptr) -> i32
// CHECK: llvm.mlir.global external @ompi_mpi_float() {addr_space = 0 : i32} : !llvm.struct<"ompi_predefined_datatype_t", opaque>
@@ -193,6 +203,11 @@ module attributes { dlti.map = #dlti.map<"MPI:Implementation" = "OpenMPI"> } {
// CHECK: [[v47:%.*]] = llvm.inttoptr [[v46]] : i64 to !llvm.ptr
// CHECK: [[v48:%.*]] = llvm.call @MPI_Recv([[v41]], [[v43]], [[v44]], [[v12]], [[v12]], [[v45]], [[v47]]) : (!llvm.ptr, i32, !llvm.ptr, i32, i32, !llvm.ptr, !llvm.ptr) -> i32
%2 = mpi.recv(%arg0, %rank, %rank, %comm) : memref<100xf32>, i32, i32 -> !mpi.retval
+
+ // CHECK: llvm.call @MPI_Comm_size({{.*}}) : (!llvm.ptr, !llvm.ptr) -> i32
+ // CHECK: llvm.udiv {{.*}} : i32
+ // CHECK: llvm.call @MPI_Allgather({{.*}} : (!llvm.ptr, i32, !llvm.ptr, !llvm.ptr, i32, !llvm.ptr, !llvm.ptr) -> i32
+ mpi.allgather(%arg0, %arg0, %comm) : memref<100xf32>, memref<100xf32>
// CHECK: [[v49:%.*]] = llvm.extractvalue [[v5]][1] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
// CHECK: [[v50:%.*]] = llvm.extractvalue [[v5]][2] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
diff --git a/mlir/test/Dialect/MPI/mpiops.mlir b/mlir/test/Dialect/MPI/mpiops.mlir
index ef457628fe2c4..87a5647ee91d2 100644
--- a/mlir/test/Dialect/MPI/mpiops.mlir
+++ b/mlir/test/Dialect/MPI/mpiops.mlir
@@ -65,6 +65,12 @@ func.func @mpi_test(%ref : memref<100xf32>) -> () {
// CHECK-NEXT: [[v5:%.*]] = mpi.barrier([[v1]]) -> !mpi.retval
%err7 = mpi.barrier(%comm) -> !mpi.retval
+ // CHECK-NEXT: [[e3:%.*]] = mpi.allgather([[varg0]], [[varg0]], [[v1]]) : memref<100xf32>, memref<100xf32> -> !mpi.retval
+ %err3 = mpi.allgather(%ref, %ref, %comm) : memref<100xf32>, memref<100xf32> -> !mpi.retval
+
+ // CHECK-NEXT: mpi.allgather([[varg0]], [[varg0]], [[v1]]) : memref<100xf32>, memref<100xf32>
+ mpi.allgather(%ref, %ref, %comm) : memref<100xf32>, memref<100xf32>
+
// CHECK-NEXT: [[v6:%.*]] = mpi.allreduce([[varg0]], [[varg0]], MPI_SUM, [[v1]]) : memref<100xf32>, memref<100xf32> -> !mpi.retval
%err8 = mpi.allreduce(%ref, %ref, MPI_SUM, %comm) : memref<100xf32>, memref<100xf32> -> !mpi.retval
>From fc4b6704c466026ac67e7fbf4677c2a0b3b8f24e Mon Sep 17 00:00:00 2001
From: "Schlimbach, Frank" <frank.schlimbach at intel.com>
Date: Tue, 20 Jan 2026 08:13:42 -0800
Subject: [PATCH 2/4] addressing minor review comment
---
mlir/lib/Conversion/MPIToLLVM/MPIToLLVM.cpp | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/mlir/lib/Conversion/MPIToLLVM/MPIToLLVM.cpp b/mlir/lib/Conversion/MPIToLLVM/MPIToLLVM.cpp
index f37622c0bf7f0..bc5aeec8cb4b0 100644
--- a/mlir/lib/Conversion/MPIToLLVM/MPIToLLVM.cpp
+++ b/mlir/lib/Conversion/MPIToLLVM/MPIToLLVM.cpp
@@ -808,7 +808,7 @@ struct AllGatherOpLowering : public ConvertOpToLLVMPattern<mpi::AllGatherOp> {
getOrDefineFunction(moduleOp, loc, rewriter, "MPI_Allgather", funcType);
// count_recv is the number of elements received from each rank, not total
- Value nRanks = mlir::mpi::CommSizeOp::create(rewriter, loc, i32, adaptor.getComm()).getSize();
+ Value nRanks = mpi::CommSizeOp::create(rewriter, loc, i32, adaptor.getComm()).getSize();
Value recvCountPerRank = LLVM::UDivOp::create(
rewriter, loc, i32, recvSize, nRanks);
>From b61912fde1b81a451c7b7ed41a0aec4eb25856fc Mon Sep 17 00:00:00 2001
From: "Schlimbach, Frank" <frank.schlimbach at intel.com>
Date: Tue, 20 Jan 2026 08:31:40 -0800
Subject: [PATCH 3/4] clang-format
---
mlir/lib/Conversion/MPIToLLVM/MPIToLLVM.cpp | 27 ++++++++++++---------
1 file changed, 15 insertions(+), 12 deletions(-)
diff --git a/mlir/lib/Conversion/MPIToLLVM/MPIToLLVM.cpp b/mlir/lib/Conversion/MPIToLLVM/MPIToLLVM.cpp
index bc5aeec8cb4b0..4a1c5d1f7846c 100644
--- a/mlir/lib/Conversion/MPIToLLVM/MPIToLLVM.cpp
+++ b/mlir/lib/Conversion/MPIToLLVM/MPIToLLVM.cpp
@@ -800,22 +800,24 @@ struct AllGatherOpLowering : public ConvertOpToLLVMPattern<mpi::AllGatherOp> {
// void* buffer_recv, int count_recv, MPI_Datatype datatype_recv,
// MPI_Comm communicator);
auto funcType = LLVM::LLVMFunctionType::get(
- i32, {ptrType, i32, sDataType.getType(),
- ptrType, i32, rDataType.getType(),
- comm.getType()});
+ i32, {ptrType, i32, sDataType.getType(), ptrType, i32,
+ rDataType.getType(), comm.getType()});
// get or create function declaration:
LLVM::LLVMFuncOp funcDecl =
getOrDefineFunction(moduleOp, loc, rewriter, "MPI_Allgather", funcType);
// count_recv is the number of elements received from each rank, not total
- Value nRanks = mpi::CommSizeOp::create(rewriter, loc, i32, adaptor.getComm()).getSize();
- Value recvCountPerRank = LLVM::UDivOp::create(
- rewriter, loc, i32, recvSize, nRanks);
+ Value nRanks =
+ mpi::CommSizeOp::create(rewriter, loc, i32, adaptor.getComm())
+ .getSize();
+ Value recvCountPerRank =
+ LLVM::UDivOp::create(rewriter, loc, i32, recvSize, nRanks);
// replace op with function call
- auto funcCall = LLVM::CallOp::create(
- rewriter, loc, funcDecl,
- ValueRange{sendPtr, sendSize, sDataType, recvPtr, recvCountPerRank, rDataType, comm});
+ auto funcCall =
+ LLVM::CallOp::create(rewriter, loc, funcDecl,
+ ValueRange{sendPtr, sendSize, sDataType, recvPtr,
+ recvCountPerRank, rDataType, comm});
if (op.getRetval())
rewriter.replaceOp(op, funcCall.getResult());
@@ -915,9 +917,10 @@ void mpi::populateMPIToLLVMConversionPatterns(LLVMTypeConverter &converter,
converter.addConversion([](mpi::CommType type) {
return IntegerType::get(type.getContext(), 64);
});
- patterns.add<CommRankOpLowering, CommSizeOpLowering, CommSplitOpLowering, CommWorldOpLowering,
- FinalizeOpLowering, InitOpLowering, SendOpLowering,
- RecvOpLowering, AllGatherOpLowering, AllReduceOpLowering>(converter);
+ patterns.add<CommRankOpLowering, CommSizeOpLowering, CommSplitOpLowering,
+ CommWorldOpLowering, FinalizeOpLowering, InitOpLowering,
+ SendOpLowering, RecvOpLowering, AllGatherOpLowering,
+ AllReduceOpLowering>(converter);
}
void mpi::registerConvertMPIToLLVMInterface(DialectRegistry &registry) {
>From b083195405f3cd7d25a3db4392f6ad211be0571e Mon Sep 17 00:00:00 2001
From: "Schlimbach, Frank" <frank.schlimbach at intel.com>
Date: Tue, 20 Jan 2026 08:48:40 -0800
Subject: [PATCH 4/4] clarifying implicit def of recvcount in API call
---
mlir/include/mlir/Dialect/MPI/IR/MPIOps.td | 13 +++++++++----
1 file changed, 9 insertions(+), 4 deletions(-)
diff --git a/mlir/include/mlir/Dialect/MPI/IR/MPIOps.td b/mlir/include/mlir/Dialect/MPI/IR/MPIOps.td
index 7c2b00c79e691..b24c12459475e 100644
--- a/mlir/include/mlir/Dialect/MPI/IR/MPIOps.td
+++ b/mlir/include/mlir/Dialect/MPI/IR/MPIOps.td
@@ -273,12 +273,17 @@ def MPI_AllGatherOp : MPI_Op<"allgather", []> {
}];
let description = [{
MPI_Allgather collects data from all processes in a given communicator and
- stores the data collected in the receive buffer of each process.
+ stores the gathered data in the receive buffer of each process.
- Each process is expected to contribute the same amount of data.
+ Each process contributes the same amount of data, given by `sendbuf`.
+ The MPI call specifies the number of elements contributed by each process
+ via the `recvcount` parameter. This operation, however, assumes `recvbuf`
+ is large enough to hold the data contributed by all processes.
+ Therefore, `recvcount` is implicitly defined as
+ `num_elements(recvbuf) / MPI_Comm_size(comm)`.
- This operation can optionally return an `!mpi.retval` value that can be used
- to check for errors.
+ This operation may optionally return an `!mpi.retval` value, which can be
+ used for error checking.
}];
let arguments = (
More information about the Mlir-commits
mailing list