[Mlir-commits] [mlir] [MLIR][MPI] Add LLVM lowering for mpi.comm_size (PR #140392)
Robert K Samuel
llvmlistbot at llvm.org
Sat May 17 11:36:26 PDT 2025
https://github.com/johnmaxrin created https://github.com/llvm/llvm-project/pull/140392
This patch adds a conversion pattern to lower the `mpi.comm_size` operation from the MPI dialect to the LLVM dialect. The lowering emits a call to the `@MPI_Comm_size` runtime function, enabling programs using the MPI dialect to correctly obtain the size of an MPI communicator during code generation.
The change is implemented in MPIToLLVM.cpp, with associated logic registered in the conversion pass.
This allows end-to-end lowering of MLIR programs using `mpi.comm_size`.
>From 5eb9528acc31a077c9dc3bacbb569b1298ec567b Mon Sep 17 00:00:00 2001
From: 60b36t <johnmaxrin at gmail.com>
Date: Sun, 18 May 2025 00:03:07 +0530
Subject: [PATCH] [MLIR][MPI] Add LLVM lowering for mpi.comm_size
This patch adds a conversion pattern to lower the `mpi.comm_size` operation
from the MPI dialect to the LLVM dialect. The lowering emits a call to the
`@MPI_Comm_size` runtime function, enabling programs using the MPI dialect
to correctly obtain the size of an MPI communicator during code generation.
The change is implemented in MPIToLLVM.cpp, with associated logic registered
in the conversion pass. A test case is included under
`test/Conversion/MPIToLLVM/`.
This allows end-to-end lowering of MLIR programs using `mpi.comm_size`.
---
mlir/lib/Conversion/MPIToLLVM/MPIToLLVM.cpp | 58 ++++++++++++++++++++-
1 file changed, 57 insertions(+), 1 deletion(-)
diff --git a/mlir/lib/Conversion/MPIToLLVM/MPIToLLVM.cpp b/mlir/lib/Conversion/MPIToLLVM/MPIToLLVM.cpp
index 5575b295ae20a..3cfcbb1d1c940 100644
--- a/mlir/lib/Conversion/MPIToLLVM/MPIToLLVM.cpp
+++ b/mlir/lib/Conversion/MPIToLLVM/MPIToLLVM.cpp
@@ -445,6 +445,62 @@ struct InitOpLowering : public ConvertOpToLLVMPattern<mpi::InitOp> {
}
};
+//===----------------------------------------------------------------------===//
+// CommSizeOpLowering
+//===----------------------------------------------------------------------===//
+
+/// Converts `mpi.comm_size` into a call to the `MPI_Comm_size` runtime
+/// function. The call receives a pointer to an alloca'd i32, which is loaded
+/// afterwards and used as the op's size result.
+struct CommSizeOpLowering : public ConvertOpToLLVMPattern<mpi::CommSizeOp> {
+  using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;
+
+  LogicalResult
+  matchAndRewrite(mpi::CommSizeOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    Location loc = op.getLoc();
+    Type i32 = rewriter.getI32Type();
+    // Opaque `!llvm.ptr`, used for the size out-parameter.
+    Type ptrType = LLVM::LLVMPointerType::get(rewriter.getContext());
+
+    // The enclosing module is used to look up the MPI implementation traits
+    // and to host the runtime function declaration.
+    auto moduleOp = op->getParentOfType<ModuleOp>();
+    auto mpiTraits = MPIImplTraits::get(moduleOp);
+
+    // Cast the communicator operand through the implementation traits.
+    Value comm = mpiTraits->castComm(loc, rewriter, adaptor.getComm());
+
+    // Signature: `i32 MPI_Comm_size(<comm type>, ptr)`.
+    auto sizeFuncType =
+        LLVM::LLVMFunctionType::get(i32, {comm.getType(), ptrType});
+    // Get or create the function declaration.
+    LLVM::LLVMFuncOp sizeDecl = getOrDefineFunction(
+        moduleOp, loc, rewriter, "MPI_Comm_size", sizeFuncType);
+
+    // Allocate one i32 slot for the result and emit the call.
+    auto one = rewriter.create<LLVM::ConstantOp>(loc, i32, 1);
+    auto sizeSlot = rewriter.create<LLVM::AllocaOp>(loc, ptrType, i32, one);
+    auto callOp = rewriter.create<LLVM::CallOp>(
+        loc, sizeDecl, ValueRange{comm, sizeSlot.getRes()});
+
+    // Read the communicator size back out of the slot.
+    auto loadedSize =
+        rewriter.create<LLVM::LoadOp>(loc, i32, sizeSlot.getResult());
+
+    // Replacement values, in result order: the call's return value first
+    // (only when the op has a retval result), then the loaded size.
+    SmallVector<Value> replacements;
+    if (op.getRetval())
+      replacements.push_back(callOp.getResult());
+    replacements.push_back(loadedSize.getRes());
+
+    // Swap the original op for the lowered values.
+    rewriter.replaceOp(op, replacements);
+    return success();
+  }
+};
+
//===----------------------------------------------------------------------===//
// FinalizeOpLowering
//===----------------------------------------------------------------------===//
@@ -801,7 +857,7 @@ void mpi::populateMPIToLLVMConversionPatterns(LLVMTypeConverter &converter,
});
patterns.add<CommRankOpLowering, CommSplitOpLowering, CommWorldOpLowering,
FinalizeOpLowering, InitOpLowering, SendOpLowering,
- RecvOpLowering, AllReduceOpLowering>(converter);
+ RecvOpLowering, AllReduceOpLowering, CommSizeOpLowering>(converter);
}
void mpi::registerConvertMPIToLLVMInterface(DialectRegistry &registry) {
More information about the Mlir-commits
mailing list