[Mlir-commits] [mlir] [mlir][mpi] adding MPI_Allgather and lowering to LLVM (PR #176937)
Anton Lydike
llvmlistbot at llvm.org
Tue Jan 20 08:20:06 PST 2026
================
@@ -712,6 +768,64 @@ struct RecvOpLowering : public ConvertOpToLLVMPattern<mpi::RecvOp> {
}
};
+//===----------------------------------------------------------------------===//
+// AllGatherOpLowering
+//===----------------------------------------------------------------------===//
+
+struct AllGatherOpLowering : public ConvertOpToLLVMPattern<mpi::AllGatherOp> {
+  using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;
+
+  LogicalResult
+  matchAndRewrite(mpi::AllGatherOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    Location loc = op.getLoc();
+    MLIRContext *context = rewriter.getContext();
+    Type sElemType = op.getSendbuf().getType().getElementType();
+    Type rElemType = op.getRecvbuf().getType().getElementType();
+    auto [sendPtr, sendSize] =
+        getRawPtrAndSize(loc, rewriter, adaptor.getSendbuf(), sElemType);
+    auto [recvPtr, recvSize] =
+        getRawPtrAndSize(loc, rewriter, adaptor.getRecvbuf(), rElemType);
+
+    auto moduleOp = op->getParentOfType<ModuleOp>();
+    auto mpiTraits = MPIImplTraits::get(moduleOp);
+    Value sDataType = mpiTraits->getDataType(loc, rewriter, sElemType);
+    Value rDataType = mpiTraits->getDataType(loc, rewriter, rElemType);
+    Value comm = mpiTraits->castComm(loc, rewriter, adaptor.getComm());
+
+    Type ptrType = LLVM::LLVMPointerType::get(context);
+    Type i32 = rewriter.getI32Type();
+    // int MPI_Allgather(
+    //   const void* buffer_send, int count_send, MPI_Datatype datatype_send,
+    //   void* buffer_recv, int count_recv, MPI_Datatype datatype_recv,
+    //   MPI_Comm communicator);
+    auto funcType = LLVM::LLVMFunctionType::get(
+        i32, {ptrType, i32, sDataType.getType(),
+              ptrType, i32, rDataType.getType(),
+              comm.getType()});
+    // get or create function declaration:
+    LLVM::LLVMFuncOp funcDecl =
+        getOrDefineFunction(moduleOp, loc, rewriter, "MPI_Allgather", funcType);
+
+    // count_recv is the number of elements received from each rank, not total
+    Value nRanks = mlir::mpi::CommSizeOp::create(rewriter, loc, i32, adaptor.getComm()).getSize();
----------------
AntonLydike wrote:
Hmm, I agree. I would be tempted to assume `sendcount == recvcount` and take the value from there. But there *are* valid cases where the two differ, e.g. when packing/unpacking structs as part of the allgather using derived MPI datatypes (sending structs of (int, int, int) and receiving "raw" ints).
So we do need some way to determine the receive count. Two options:
1. Pass the receive count as a separate parameter: clunky, and it would add at least one `memref.dim` op per invocation.
2. Infer the receive count from `MPI_Comm_size` and the memref size: cleaner, and it only requires a `memref.view` when the goal is to receive into a sub-region of the memref actually passed.
Looking at these two options, I think it's pretty clear (to me) which one is preferable. Beyond usability, the semantics of "will receive into the whole buffer" are much nicer. That said, we should add a line or two to the docs mentioning this.
https://github.com/llvm/llvm-project/pull/176937