[Mlir-commits] [mlir] [mlir][mpi] adding MPI_Allgather and lowering to LLVM (PR #176937)

Frank Schlimbach llvmlistbot at llvm.org
Tue Jan 20 08:12:54 PST 2026


================
@@ -712,6 +768,64 @@ struct RecvOpLowering : public ConvertOpToLLVMPattern<mpi::RecvOp> {
   }
 };
 
+//===----------------------------------------------------------------------===//
+// AllGatherOpLowering
+//===----------------------------------------------------------------------===//
+
+struct AllGatherOpLowering : public ConvertOpToLLVMPattern<mpi::AllGatherOp> {
+  using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;
+
+  LogicalResult
+  matchAndRewrite(mpi::AllGatherOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    Location loc = op.getLoc();
+    MLIRContext *context = rewriter.getContext();
+    Type sElemType = op.getSendbuf().getType().getElementType();
+    Type rElemType = op.getRecvbuf().getType().getElementType();
+    auto [sendPtr, sendSize] =
+        getRawPtrAndSize(loc, rewriter, adaptor.getSendbuf(), sElemType);
+    auto [recvPtr, recvSize] =
+        getRawPtrAndSize(loc, rewriter, adaptor.getRecvbuf(), rElemType);
+
+    auto moduleOp = op->getParentOfType<ModuleOp>();
+    auto mpiTraits = MPIImplTraits::get(moduleOp);
+    Value sDataType = mpiTraits->getDataType(loc, rewriter, sElemType);
+    Value rDataType = mpiTraits->getDataType(loc, rewriter, rElemType);
+    Value comm = mpiTraits->castComm(loc, rewriter, adaptor.getComm());
+
+    Type ptrType = LLVM::LLVMPointerType::get(context);
+    Type i32 = rewriter.getI32Type();
+    // int MPI_Allgather(
+    //     const void* buffer_send, int count_send, MPI_Datatype datatype_send,
+    //     void* buffer_recv, int count_recv, MPI_Datatype datatype_recv,
+    //     MPI_Comm communicator);
+    auto funcType = LLVM::LLVMFunctionType::get(
+        i32, {ptrType, i32, sDataType.getType(),
+              ptrType, i32, rDataType.getType(),
+              comm.getType()});
+    // get or create function declaration:
+    LLVM::LLVMFuncOp funcDecl =
+        getOrDefineFunction(moduleOp, loc, rewriter, "MPI_Allgather", funcType);
+
+    // count_recv is the number of elements received from each rank, not total
+    Value nRanks = mlir::mpi::CommSizeOp::create(rewriter, loc, i32, adaptor.getComm()).getSize();
----------------
fschlimb wrote:

The dialect is based on memrefs, so all we have is the output buffer and therefore only its total size. We could add a `recvsize` argument to the op, but IMHO that would be no less odd.

https://github.com/llvm/llvm-project/pull/176937


More information about the Mlir-commits mailing list