[Mlir-commits] [mlir] [mlir][mpi] adding MPI_Allgather and lowering to LLVM (PR #176937)
Frank Schlimbach
llvmlistbot at llvm.org
Tue Jan 20 08:12:54 PST 2026
================
@@ -712,6 +768,64 @@ struct RecvOpLowering : public ConvertOpToLLVMPattern<mpi::RecvOp> {
}
};
+//===----------------------------------------------------------------------===//
+// AllGatherOpLowering
+//===----------------------------------------------------------------------===//
+
+struct AllGatherOpLowering : public ConvertOpToLLVMPattern<mpi::AllGatherOp> {
+ using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;
+
+ LogicalResult
+ matchAndRewrite(mpi::AllGatherOp op, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ Location loc = op.getLoc();
+ MLIRContext *context = rewriter.getContext();
+ Type sElemType = op.getSendbuf().getType().getElementType();
+ Type rElemType = op.getRecvbuf().getType().getElementType();
+ auto [sendPtr, sendSize] =
+ getRawPtrAndSize(loc, rewriter, adaptor.getSendbuf(), sElemType);
+ auto [recvPtr, recvSize] =
+ getRawPtrAndSize(loc, rewriter, adaptor.getRecvbuf(), rElemType);
+
+ auto moduleOp = op->getParentOfType<ModuleOp>();
+ auto mpiTraits = MPIImplTraits::get(moduleOp);
+ Value sDataType = mpiTraits->getDataType(loc, rewriter, sElemType);
+ Value rDataType = mpiTraits->getDataType(loc, rewriter, rElemType);
+ Value comm = mpiTraits->castComm(loc, rewriter, adaptor.getComm());
+
+ Type ptrType = LLVM::LLVMPointerType::get(context);
+ Type i32 = rewriter.getI32Type();
+ // int MPI_Allgather(
+ // const void* buffer_send, int count_send, MPI_Datatype datatype_send,
+ // void* buffer_recv, int count_recv, MPI_Datatype datatype_recv,
+ // MPI_Comm communicator);
+ auto funcType = LLVM::LLVMFunctionType::get(
+ i32, {ptrType, i32, sDataType.getType(),
+ ptrType, i32, rDataType.getType(),
+ comm.getType()});
+ // get or create function declaration:
+ LLVM::LLVMFuncOp funcDecl =
+ getOrDefineFunction(moduleOp, loc, rewriter, "MPI_Allgather", funcType);
+
+ // count_recv is the number of elements received from each rank, not total
+ Value nRanks = mlir::mpi::CommSizeOp::create(rewriter, loc, i32, adaptor.getComm()).getSize();
----------------
fschlimb wrote:
The dialect is based on memrefs, so all we have is the output buffer and therefore only its total size. We could add a `recvsize` argument to the op, but IMHO that wouldn't be any less odd.
https://github.com/llvm/llvm-project/pull/176937
More information about the Mlir-commits
mailing list