[Mlir-commits] [mlir] [MLIR][MPI] Add LLVM lowering for mpi.comm_size (PR #140392)

llvmlistbot at llvm.org
Sat May 17 11:37:22 PDT 2025


llvmbot wrote:



@llvm/pr-subscribers-mlir

Author: Robert K Samuel (johnmaxrin)

<details>
<summary>Changes</summary>

This patch adds a conversion pattern that lowers the `mpi.comm_size` operation from the MPI dialect to the LLVM dialect. The lowering emits a call to the `MPI_Comm_size` runtime function, so programs written with the MPI dialect can obtain the size of an MPI communicator at run time.

The pattern is implemented in `MPIToLLVM.cpp` and registered in `mpi::populateMPIToLLVMConversionPatterns`.

This enables end-to-end lowering of MLIR programs that use `mpi.comm_size`.
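To make the emitted sequence concrete, here is a minimal C++ sketch of what the lowered code computes: reserve one `i32` stack slot (`llvm.alloca`), call `MPI_Comm_size` with the communicator and the slot's address (`llvm.call`), then read the size back (`llvm.load`), keeping the call's return value for the optional `retval` result. It does not use a real MPI library: `StubComm`, `stubCommSize`, `kStubCommWorld`, `kStubSuccess`, and the world size of 4 are hypothetical stand-ins, and the integer handle type only assumes an MPICH-style `MPI_Comm`.

```cpp
#include <cstdint>

// Hypothetical stand-ins (not part of any MPI library): an integer
// communicator handle, as MPICH uses, and a fake world size of 4.
using StubComm = std::int32_t;
constexpr StubComm kStubCommWorld = 0;
constexpr std::int32_t kStubWorldSize = 4;
constexpr std::int32_t kStubSuccess = 0;  // plays the role of MPI_SUCCESS

// Hypothetical replacement for `int MPI_Comm_size(MPI_Comm, int *)`:
// writes the communicator size through `size` and returns an error code.
constexpr std::int32_t stubCommSize(StubComm comm, std::int32_t *size) {
  if (comm != kStubCommWorld)
    return 1;  // any nonzero value signals an error; *size is not written
  *size = kStubWorldSize;
  return kStubSuccess;
}

// The two values the lowering produces for `mpi.comm_size`.
struct CommSizeResults {
  std::int32_t retval;  // the call's return value (used only if retval is)
  std::int32_t size;    // the value loaded back from the stack slot
};

constexpr CommSizeResults loweredCommSize(StubComm comm) {
  // llvm.alloca of one i32; zero-initialized here only because constexpr
  // requires it (the real alloca is uninitialized).
  std::int32_t slot = 0;
  std::int32_t ret = stubCommSize(comm, &slot);  // llvm.call
  return {ret, slot};                            // llvm.load of the slot
}

// Compile-time check: the world communicator reports 4 ranks.
static_assert(loweredCommSize(kStubCommWorld).size == kStubWorldSize, "");
static_assert(loweredCommSize(kStubCommWorld).retval == kStubSuccess, "");
```

Because the model is `constexpr`, the `static_assert` lines check it at compile time. In a real program the stub would be replaced by `MPI_Comm_size(MPI_COMM_WORLD, &size)` from `<mpi.h>`.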

---
Full diff: https://github.com/llvm/llvm-project/pull/140392.diff


1 File Affected:

- (modified) mlir/lib/Conversion/MPIToLLVM/MPIToLLVM.cpp (+57-1) 


``````````diff
diff --git a/mlir/lib/Conversion/MPIToLLVM/MPIToLLVM.cpp b/mlir/lib/Conversion/MPIToLLVM/MPIToLLVM.cpp
index 5575b295ae20a..3cfcbb1d1c940 100644
--- a/mlir/lib/Conversion/MPIToLLVM/MPIToLLVM.cpp
+++ b/mlir/lib/Conversion/MPIToLLVM/MPIToLLVM.cpp
@@ -445,6 +445,62 @@ struct InitOpLowering : public ConvertOpToLLVMPattern<mpi::InitOp> {
   }
 };
 
+//===----------------------------------------------------------------------===//
+// CommSizeOpLowering
+//===----------------------------------------------------------------------===//
+
+struct CommSizeOpLowering : public ConvertOpToLLVMPattern<mpi::CommSizeOp> {
+  using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;
+
+  LogicalResult
+  matchAndRewrite(mpi::CommSizeOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    // get some helper vars
+    Location loc = op.getLoc();
+    MLIRContext *context = rewriter.getContext();
+    Type i32 = rewriter.getI32Type();
+
+    // ptrType `!llvm.ptr`
+    Type ptrType = LLVM::LLVMPointerType::get(context);
+
+    // grab a reference to the global module op:
+    auto moduleOp = op->getParentOfType<ModuleOp>();
+
+    auto mpiTraits = MPIImplTraits::get(moduleOp);
+    // get communicator
+    Value comm = mpiTraits->castComm(loc, rewriter, adaptor.getComm());
+
+    // LLVM function type representing `i32 MPI_Comm_size(MPI_Comm, ptr)`
+    auto sizeFuncType =
+        LLVM::LLVMFunctionType::get(i32, {comm.getType(), ptrType});
+    // get or create function declaration:
+    LLVM::LLVMFuncOp sizeDecl = getOrDefineFunction(
+        moduleOp, loc, rewriter, "MPI_Comm_size", sizeFuncType);
+
+    // replace with function call
+    auto one = rewriter.create<LLVM::ConstantOp>(loc, i32, 1);
+    auto sizePtr = rewriter.create<LLVM::AllocaOp>(loc, ptrType, i32, one);
+    auto callOp = rewriter.create<LLVM::CallOp>(
+        loc, sizeDecl, ValueRange{comm, sizePtr.getRes()});
+
+    // load the size into a register
+    auto loadedSize =
+        rewriter.create<LLVM::LoadOp>(loc, i32, sizePtr.getRes());
+
+    // if retval is checked, replace uses of retval with the results from the
+    // call op
+    SmallVector<Value> replacements;
+    if (op.getRetval())
+      replacements.push_back(callOp.getResult());
+
+    // replace all uses, then erase op
+    replacements.push_back(loadedSize.getRes());
+    rewriter.replaceOp(op, replacements);
+
+    return success();
+  }
+};
+
 //===----------------------------------------------------------------------===//
 // FinalizeOpLowering
 //===----------------------------------------------------------------------===//
@@ -801,7 +857,7 @@ void mpi::populateMPIToLLVMConversionPatterns(LLVMTypeConverter &converter,
   });
   patterns.add<CommRankOpLowering, CommSplitOpLowering, CommWorldOpLowering,
                FinalizeOpLowering, InitOpLowering, SendOpLowering,
-               RecvOpLowering, AllReduceOpLowering>(converter);
+               RecvOpLowering, AllReduceOpLowering, CommSizeOpLowering>(converter);
 }
 
 void mpi::registerConvertMPIToLLVMInterface(DialectRegistry &registry) {

``````````

</details>
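`rewriter.replaceOp` expects one replacement value per result of the original op, in result order. Judging by the order in which the pattern pushes values, the optional `retval` result of `mpi.comm_size` comes before the `i32` size, so the call's return value is pushed only when `retval` is present and the loaded size always goes last. The constexpr C++ sketch below models just that ordering rule; `Repl`, `Replacements`, and `buildReplacements` are hypothetical names, not MLIR APIs.

```cpp
#include <cstddef>

// Hypothetical labels for the two SSA values the pattern creates.
enum class Repl { None, CallResult, LoadedSize };

// Hypothetical container: replacement values in the order replaceOp needs.
struct Replacements {
  Repl values[2];
  std::size_t count;
};

// Mirrors the pattern: the call result is pushed only when the optional
// retval result exists; the loaded size is always pushed last.
constexpr Replacements buildReplacements(bool hasRetval) {
  Replacements r{{Repl::None, Repl::None}, 0};
  if (hasRetval)
    r.values[r.count++] = Repl::CallResult;  // replaces the retval result
  r.values[r.count++] = Repl::LoadedSize;    // replaces the size result
  return r;
}

// With retval: two replacements, call result first.
static_assert(buildReplacements(true).count == 2, "");
static_assert(buildReplacements(true).values[0] == Repl::CallResult, "");
// Without retval: a single replacement, the loaded size.
static_assert(buildReplacements(false).count == 1, "");
static_assert(buildReplacements(false).values[0] == Repl::LoadedSize, "");
```

Pushing conditionally rather than always emitting two values is what keeps the replacement count equal to the op's result count in both forms of `mpi.comm_size`.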


https://github.com/llvm/llvm-project/pull/140392

