[Mlir-commits] [mlir] [mlir][mpi] adding MPI_Allgather and lowering to LLVM (PR #176937)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Tue Jan 20 07:00:46 PST 2026
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir
Author: Frank Schlimbach (fschlimb)
<details>
<summary>Changes</summary>
- Adding MPI_Allgather to MPI dialect
- Adding lowering to MPIToLLVM
- Also adding a lowering for MPI_Comm_size (see also #<!-- -->140392)
@<!-- -->johnmaxrin
---
Full diff: https://github.com/llvm/llvm-project/pull/176937.diff
4 Files Affected:
- (modified) mlir/include/mlir/Dialect/MPI/IR/MPIOps.td (+33)
- (modified) mlir/lib/Conversion/MPIToLLVM/MPIToLLVM.cpp (+116-2)
- (modified) mlir/test/Conversion/MPIToLLVM/mpitollvm.mlir (+15)
- (modified) mlir/test/Dialect/MPI/mpiops.mlir (+6)
``````````diff
diff --git a/mlir/include/mlir/Dialect/MPI/IR/MPIOps.td b/mlir/include/mlir/Dialect/MPI/IR/MPIOps.td
index 935e0f785ef0c..7c2b00c79e691 100644
--- a/mlir/include/mlir/Dialect/MPI/IR/MPIOps.td
+++ b/mlir/include/mlir/Dialect/MPI/IR/MPIOps.td
@@ -261,6 +261,39 @@ def MPI_IRecvOp : MPI_Op<"irecv", []> {
let hasCanonicalizer = 1;
}
+//===----------------------------------------------------------------------===//
+// AllGatherOp
+//===----------------------------------------------------------------------===//
+
+def MPI_AllGatherOp : MPI_Op<"allgather", []> {
+  let summary = [{
+    Equivalent to `MPI_Allgather(sendbuf, sendcount, sendtype,
+                                 recvbuf, recvcount, recvtype, comm)`.
+  }];
+  let description = [{
+    MPI_Allgather collects data from all processes in a given communicator and
+    stores the data collected in the receive buffer of each process.
+
+    Each process is expected to contribute the same amount of data; `recvbuf`
+    must provide room for the contributions (`sendbuf`) of all processes.
+
+    This operation can optionally return an `!mpi.retval` value that can be used
+    to check for errors.
+  }];
+
+  let arguments = (
+    ins AnyMemRef : $sendbuf,
+    AnyMemRef : $recvbuf,
+    MPI_Comm : $comm
+  );
+
+  let results = (outs Optional<MPI_Retval>:$retval);
+
+  let assemblyFormat = "`(` $sendbuf `,` $recvbuf `,` $comm `)` "
+                       "attr-dict `:` type($sendbuf) `,` type($recvbuf) "
+                       "(`->` type($retval)^)?";
+}
+
//===----------------------------------------------------------------------===//
// AllReduceOp
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Conversion/MPIToLLVM/MPIToLLVM.cpp b/mlir/lib/Conversion/MPIToLLVM/MPIToLLVM.cpp
index 59a16df9c59d3..f37622c0bf7f0 100644
--- a/mlir/lib/Conversion/MPIToLLVM/MPIToLLVM.cpp
+++ b/mlir/lib/Conversion/MPIToLLVM/MPIToLLVM.cpp
@@ -604,6 +604,62 @@ struct CommRankOpLowering : public ConvertOpToLLVMPattern<mpi::CommRankOp> {
}
};
+//===----------------------------------------------------------------------===//
+// CommSizeOpLowering
+//===----------------------------------------------------------------------===//
+
+/// Lowers mpi::CommSizeOp to a call to `int MPI_Comm_size(MPI_Comm, int *)`:
+/// the size is written to a stack slot by the callee and loaded back as `i32`.
+struct CommSizeOpLowering : public ConvertOpToLLVMPattern<mpi::CommSizeOp> {
+  using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;
+
+  LogicalResult
+  matchAndRewrite(mpi::CommSizeOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    Location loc = op.getLoc();
+    MLIRContext *context = rewriter.getContext();
+    Type i32 = rewriter.getI32Type();
+    // pointer type `!llvm.ptr`, used for the size out-parameter
+    Type ptrType = LLVM::LLVMPointerType::get(context);
+
+    // enclosing module: source of the MPI traits, home of the declaration
+    auto moduleOp = op->getParentOfType<ModuleOp>();
+    auto mpiTraits = MPIImplTraits::get(moduleOp);
+    // cast the converted communicator to the type the MPI traits prescribe
+    Value comm = mpiTraits->castComm(loc, rewriter, adaptor.getComm());
+
+    // LLVM function type representing `i32 MPI_Comm_size(comm, ptr)`
+    auto sizeFuncType =
+        LLVM::LLVMFunctionType::get(i32, {comm.getType(), ptrType});
+    // get or create the function declaration; the symbol must match the MPI C
+    // API exactly (case-sensitive `MPI_Comm_size`) or the call cannot link
+    LLVM::LLVMFuncOp funcDecl = getOrDefineFunction(
+        moduleOp, loc, rewriter, "MPI_Comm_size", sizeFuncType);
+
+    // allocate one i32 on the stack for the out-parameter, then call
+    auto one = LLVM::ConstantOp::create(rewriter, loc, i32, 1);
+    auto sizePtr = LLVM::AllocaOp::create(rewriter, loc, ptrType, i32, one);
+    auto callOp = LLVM::CallOp::create(rewriter, loc, funcDecl,
+                                       ValueRange{comm, sizePtr.getRes()});
+
+    // load the size that the call stored through the pointer
+    auto loadedSize =
+        LLVM::LoadOp::create(rewriter, loc, i32, sizePtr.getRes());
+
+    // replacement values are ordered (optional retval, size); the call's
+    // return value is forwarded only if the op has a retval
+    SmallVector<Value> replacements;
+    if (op.getRetval())
+      replacements.push_back(callOp.getResult());
+    replacements.push_back(loadedSize.getRes());
+
+    // replace all uses, then erase op
+    rewriter.replaceOp(op, replacements);
+
+    return success();
+  }
+};
+
//===----------------------------------------------------------------------===//
// SendOpLowering
//===----------------------------------------------------------------------===//
@@ -712,6 +768,64 @@ struct RecvOpLowering : public ConvertOpToLLVMPattern<mpi::RecvOp> {
}
};
+//===----------------------------------------------------------------------===//
+// AllGatherOpLowering
+//===----------------------------------------------------------------------===//
+
+/// Lowers mpi::AllGatherOp to a call to the C function `MPI_Allgather`.
+struct AllGatherOpLowering : public ConvertOpToLLVMPattern<mpi::AllGatherOp> {
+  using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;
+
+  LogicalResult
+  matchAndRewrite(mpi::AllGatherOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    Location loc = op.getLoc();
+    Type i32 = rewriter.getI32Type();
+    Type ptrType = LLVM::LLVMPointerType::get(rewriter.getContext());
+    // raw data pointers and element counts of both buffers
+    Type sendElemType = op.getSendbuf().getType().getElementType();
+    Type recvElemType = op.getRecvbuf().getType().getElementType();
+    auto [sendPtr, sendCount] =
+        getRawPtrAndSize(loc, rewriter, adaptor.getSendbuf(), sendElemType);
+    auto [recvPtr, recvTotal] =
+        getRawPtrAndSize(loc, rewriter, adaptor.getRecvbuf(), recvElemType);
+
+    // implementation-specific MPI datatypes and communicator
+    auto moduleOp = op->getParentOfType<ModuleOp>();
+    auto mpiTraits = MPIImplTraits::get(moduleOp);
+    Value sendDataType = mpiTraits->getDataType(loc, rewriter, sendElemType);
+    Value recvDataType = mpiTraits->getDataType(loc, rewriter, recvElemType);
+    Value comm = mpiTraits->castComm(loc, rewriter, adaptor.getComm());
+
+    // int MPI_Allgather(
+    //     const void *sendbuf, int sendcount, MPI_Datatype sendtype,
+    //     void *recvbuf, int recvcount, MPI_Datatype recvtype, MPI_Comm comm);
+    auto funcType = LLVM::LLVMFunctionType::get(
+        i32, {ptrType, i32, sendDataType.getType(), ptrType, i32,
+              recvDataType.getType(), comm.getType()});
+    LLVM::LLVMFuncOp funcDecl =
+        getOrDefineFunction(moduleOp, loc, rewriter, "MPI_Allgather", funcType);
+
+    // recvcount is the number of elements received from each rank, not total
+    Value nRanks =
+        mpi::CommSizeOp::create(rewriter, loc, i32, adaptor.getComm())
+            .getSize();
+    Value recvCountPerRank =
+        LLVM::UDivOp::create(rewriter, loc, i32, recvTotal, nRanks);
+
+    // emit the call; its result replaces the op's retval when one is present
+    auto funcCall = LLVM::CallOp::create(
+        rewriter, loc, funcDecl,
+        ValueRange{sendPtr, sendCount, sendDataType, recvPtr, recvCountPerRank,
+                   recvDataType, comm});
+    if (!op.getRetval())
+      rewriter.eraseOp(op);
+    else
+      rewriter.replaceOp(op, funcCall.getResult());
+    return success();
+  }
+};
+
//===----------------------------------------------------------------------===//
// AllReduceOpLowering
//===----------------------------------------------------------------------===//
@@ -801,9 +915,9 @@ void mpi::populateMPIToLLVMConversionPatterns(LLVMTypeConverter &converter,
converter.addConversion([](mpi::CommType type) {
return IntegerType::get(type.getContext(), 64);
});
- patterns.add<CommRankOpLowering, CommSplitOpLowering, CommWorldOpLowering,
+ patterns.add<CommRankOpLowering, CommSizeOpLowering, CommSplitOpLowering, CommWorldOpLowering,
FinalizeOpLowering, InitOpLowering, SendOpLowering,
- RecvOpLowering, AllReduceOpLowering>(converter);
+ RecvOpLowering, AllGatherOpLowering, AllReduceOpLowering>(converter);
}
void mpi::registerConvertMPIToLLVMInterface(DialectRegistry &registry) {
diff --git a/mlir/test/Conversion/MPIToLLVM/mpitollvm.mlir b/mlir/test/Conversion/MPIToLLVM/mpitollvm.mlir
index 35fc0f5d2e754..4c1beee2fe144 100644
--- a/mlir/test/Conversion/MPIToLLVM/mpitollvm.mlir
+++ b/mlir/test/Conversion/MPIToLLVM/mpitollvm.mlir
@@ -3,6 +3,9 @@
// COM: Test MPICH ABI
// CHECK: module attributes {dlti.map = #dlti.map<"MPI:Implementation" = "MPICH">} {
// CHECK: llvm.func @MPI_Finalize() -> i32
+// CHECK: llvm.func @MPI_Allreduce(!llvm.ptr, !llvm.ptr, i32, i32, i32, i32) -> i32
+// CHECK: llvm.func @MPI_Comm_size(i32, !llvm.ptr) -> i32
+// CHECK: llvm.func @MPI_Allgather(!llvm.ptr, i32, i32, !llvm.ptr, i32, i32, i32) -> i32
// CHECK: llvm.func @MPI_Comm_split(i32, i32, i32, !llvm.ptr) -> i32
// CHECK: llvm.func @MPI_Recv(!llvm.ptr, i32, i32, i32, i32, i32, !llvm.ptr) -> i32
// CHECK: llvm.func @MPI_Send(!llvm.ptr, i32, i32, i32, i32, i32) -> i32
@@ -88,6 +91,9 @@ module attributes {dlti.map = #dlti.map<"MPI:Implementation" = "MPICH">} {
// CHECK: [[v57:%.*]] = llvm.load [[v55]] : !llvm.ptr -> i32
%split = mpi.comm_split(%comm, %color, %key) : !mpi.comm
+ // CHECK: llvm.call @MPI_Allgather({{.*}} : (!llvm.ptr, i32, i32, !llvm.ptr, i32, i32, i32) -> i32
+ %err3 = mpi.allgather(%arg0, %arg0, %comm) : memref<100xf32>, memref<100xf32> -> !mpi.retval
+
// CHECK: [[v59:%.*]] = llvm.extractvalue [[v5]][1] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
// CHECK: [[v60:%.*]] = llvm.extractvalue [[v5]][2] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
// CHECK: [[v61:%.*]] = llvm.getelementptr [[v59]][[[v60]]] : (!llvm.ptr, i64) -> !llvm.ptr, f32
@@ -119,6 +125,10 @@ module attributes {dlti.map = #dlti.map<"MPI:Implementation" = "MPICH">} {
// CHECK: module attributes {dlti.map = #dlti.map<"MPI:Implementation" = "OpenMPI">} {
// CHECK: llvm.func @MPI_Finalize() -> i32
// CHECK: llvm.func @MPI_Comm_split(!llvm.ptr, i32, i32, !llvm.ptr) -> i32
+// CHECK: llvm.func @MPI_Allreduce(!llvm.ptr, !llvm.ptr, i32, !llvm.ptr, !llvm.ptr, !llvm.ptr) -> i32
+// CHECK: llvm.mlir.global external @ompi_mpi_sum() {addr_space = 0 : i32} : !llvm.struct<"ompi_predefined_op_t", opaque>
+// CHECK: llvm.func @MPI_Comm_size(!llvm.ptr, !llvm.ptr) -> i32
+// CHECK: llvm.func @MPI_Allgather(!llvm.ptr, i32, !llvm.ptr, !llvm.ptr, i32, !llvm.ptr, !llvm.ptr) -> i32
// CHECK: llvm.func @MPI_Recv(!llvm.ptr, i32, !llvm.ptr, i32, i32, !llvm.ptr, !llvm.ptr) -> i32
// CHECK: llvm.func @MPI_Send(!llvm.ptr, i32, !llvm.ptr, i32, i32, !llvm.ptr) -> i32
// CHECK: llvm.mlir.global external @ompi_mpi_float() {addr_space = 0 : i32} : !llvm.struct<"ompi_predefined_datatype_t", opaque>
@@ -193,6 +203,11 @@ module attributes { dlti.map = #dlti.map<"MPI:Implementation" = "OpenMPI"> } {
// CHECK: [[v47:%.*]] = llvm.inttoptr [[v46]] : i64 to !llvm.ptr
// CHECK: [[v48:%.*]] = llvm.call @MPI_Recv([[v41]], [[v43]], [[v44]], [[v12]], [[v12]], [[v45]], [[v47]]) : (!llvm.ptr, i32, !llvm.ptr, i32, i32, !llvm.ptr, !llvm.ptr) -> i32
%2 = mpi.recv(%arg0, %rank, %rank, %comm) : memref<100xf32>, i32, i32 -> !mpi.retval
+
+ // CHECK: llvm.call @MPI_Comm_size({{.*}}) : (!llvm.ptr, !llvm.ptr) -> i32
+ // CHECK: llvm.udiv {{.*}} : i32
+ // CHECK: llvm.call @MPI_Allgather({{.*}} : (!llvm.ptr, i32, !llvm.ptr, !llvm.ptr, i32, !llvm.ptr, !llvm.ptr) -> i32
+ mpi.allgather(%arg0, %arg0, %comm) : memref<100xf32>, memref<100xf32>
// CHECK: [[v49:%.*]] = llvm.extractvalue [[v5]][1] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
// CHECK: [[v50:%.*]] = llvm.extractvalue [[v5]][2] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
diff --git a/mlir/test/Dialect/MPI/mpiops.mlir b/mlir/test/Dialect/MPI/mpiops.mlir
index ef457628fe2c4..87a5647ee91d2 100644
--- a/mlir/test/Dialect/MPI/mpiops.mlir
+++ b/mlir/test/Dialect/MPI/mpiops.mlir
@@ -65,6 +65,12 @@ func.func @mpi_test(%ref : memref<100xf32>) -> () {
// CHECK-NEXT: [[v5:%.*]] = mpi.barrier([[v1]]) -> !mpi.retval
%err7 = mpi.barrier(%comm) -> !mpi.retval
+ // CHECK-NEXT: [[e3:%.*]] = mpi.allgather([[varg0]], [[varg0]], [[v1]]) : memref<100xf32>, memref<100xf32> -> !mpi.retval
+ %err3 = mpi.allgather(%ref, %ref, %comm) : memref<100xf32>, memref<100xf32> -> !mpi.retval
+
+ // CHECK-NEXT: mpi.allgather([[varg0]], [[varg0]], [[v1]]) : memref<100xf32>, memref<100xf32>
+ mpi.allgather(%ref, %ref, %comm) : memref<100xf32>, memref<100xf32>
+
// CHECK-NEXT: [[v6:%.*]] = mpi.allreduce([[varg0]], [[varg0]], MPI_SUM, [[v1]]) : memref<100xf32>, memref<100xf32> -> !mpi.retval
%err8 = mpi.allreduce(%ref, %ref, MPI_SUM, %comm) : memref<100xf32>, memref<100xf32> -> !mpi.retval
``````````
</details>
https://github.com/llvm/llvm-project/pull/176937
More information about the Mlir-commits
mailing list