[Mlir-commits] [mlir] [mlir][mpi] fixing in-place and 0d mpi.all_reduce (PR #134225)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Thu Apr 3 03:13:58 PDT 2025


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir

Author: Frank Schlimbach (fschlimb)

<details>
<summary>Changes</summary>

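* In-place allreduce (same send and receive buffer) must pass the special MPI token `MPI_IN_PLACE` as the send buffer.
* 0-d memrefs have no sizes/strides fields in the LLVM memref descriptor struct, so their element count is set to 1 instead of being extracted.
* The communicator loaded by the `mpi.comm_split` lowering is now sign-extended from i32 to i64.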

---
Full diff: https://github.com/llvm/llvm-project/pull/134225.diff


2 Files Affected:

- (modified) mlir/lib/Conversion/MPIToLLVM/MPIToLLVM.cpp (+29-5) 
- (modified) mlir/test/Conversion/MPIToLLVM/mpitollvm.mlir (+6-2) 


``````````diff
diff --git a/mlir/lib/Conversion/MPIToLLVM/MPIToLLVM.cpp b/mlir/lib/Conversion/MPIToLLVM/MPIToLLVM.cpp
index 9df5e992e8ebd..5575b295ae20a 100644
--- a/mlir/lib/Conversion/MPIToLLVM/MPIToLLVM.cpp
+++ b/mlir/lib/Conversion/MPIToLLVM/MPIToLLVM.cpp
@@ -15,8 +15,10 @@
 #include "mlir/Conversion/MPIToLLVM/MPIToLLVM.h"
 #include "mlir/Conversion/ConvertToLLVM/ToLLVMInterface.h"
 #include "mlir/Conversion/LLVMCommon/Pattern.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/DLTI/DLTI.h"
 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
+#include "mlir/Dialect/LLVMIR/LLVMTypes.h"
 #include "mlir/Dialect/MPI/IR/MPI.h"
 #include "mlir/Transforms/DialectConversion.h"
 #include <memory>
@@ -57,9 +59,14 @@ std::pair<Value, Value> getRawPtrAndSize(const Location loc,
       loc, rewriter.getI64Type(), memRef, 2);
   Value resPtr =
       rewriter.create<LLVM::GEPOp>(loc, ptrType, elType, dataPtr, offset);
-  Value size = rewriter.create<LLVM::ExtractValueOp>(loc, memRef,
-                                                     ArrayRef<int64_t>{3, 0});
-  size = rewriter.create<LLVM::TruncOp>(loc, rewriter.getI32Type(), size);
+  Value size;
+  if (cast<LLVM::LLVMStructType>(memRef.getType()).getBody().size() > 3) {
+    size = rewriter.create<LLVM::ExtractValueOp>(loc, memRef,
+                                                 ArrayRef<int64_t>{3, 0});
+    size = rewriter.create<LLVM::TruncOp>(loc, rewriter.getI32Type(), size);
+  } else {
+    size = rewriter.create<arith::ConstantIntOp>(loc, 1, 32);
+  }
   return {resPtr, size};
 }
 
@@ -97,6 +104,9 @@ class MPIImplTraits {
   /// Get the MPI_STATUS_IGNORE value (typically a pointer type).
   virtual intptr_t getStatusIgnore() = 0;
 
+  /// Get the MPI_IN_PLACE value (void *).
+  virtual void *getInPlace() = 0;
+
   /// Gets or creates an MPI datatype as a value which corresponds to the given
   /// type.
   virtual Value getDataType(const Location loc,
@@ -158,6 +168,8 @@ class MPICHImplTraits : public MPIImplTraits {
 
   intptr_t getStatusIgnore() override { return 1; }
 
+  void *getInPlace() override { return reinterpret_cast<void *>(-1); }
+
   Value getDataType(const Location loc, ConversionPatternRewriter &rewriter,
                     Type type) override {
     int32_t mtype = 0;
@@ -283,6 +295,8 @@ class OMPIImplTraits : public MPIImplTraits {
 
   intptr_t getStatusIgnore() override { return 0; }
 
+  void *getInPlace() override { return reinterpret_cast<void *>(1); }
+
   Value getDataType(const Location loc, ConversionPatternRewriter &rewriter,
                     Type type) override {
     StringRef mtype;
@@ -516,7 +530,8 @@ struct CommSplitOpLowering : public ConvertOpToLLVMPattern<mpi::CommSplitOp> {
                    outPtr.getRes()});
 
     // load the communicator into a register
-    auto res = rewriter.create<LLVM::LoadOp>(loc, i32, outPtr.getResult());
+    Value res = rewriter.create<LLVM::LoadOp>(loc, i32, outPtr.getResult());
+    res = rewriter.create<LLVM::SExtOp>(loc, rewriter.getI64Type(), res);
 
     // if retval is checked, replace uses of retval with the results from the
     // call op
@@ -525,7 +540,7 @@ struct CommSplitOpLowering : public ConvertOpToLLVMPattern<mpi::CommSplitOp> {
       replacements.push_back(callOp.getResult());
 
     // replace op
-    replacements.push_back(res.getRes());
+    replacements.push_back(res);
     rewriter.replaceOp(op, replacements);
 
     return success();
@@ -709,6 +724,7 @@ struct AllReduceOpLowering : public ConvertOpToLLVMPattern<mpi::AllReduceOp> {
     Location loc = op.getLoc();
     MLIRContext *context = rewriter.getContext();
     Type i32 = rewriter.getI32Type();
+    Type i64 = rewriter.getI64Type();
     Type elemType = op.getSendbuf().getType().getElementType();
 
     // ptrType `!llvm.ptr`
@@ -719,6 +735,14 @@ struct AllReduceOpLowering : public ConvertOpToLLVMPattern<mpi::AllReduceOp> {
         getRawPtrAndSize(loc, rewriter, adaptor.getSendbuf(), elemType);
     auto [recvPtr, recvSize] =
         getRawPtrAndSize(loc, rewriter, adaptor.getRecvbuf(), elemType);
+
+    // If input and output are the same, request in-place operation.
+    if (adaptor.getSendbuf() == adaptor.getRecvbuf()) {
+      sendPtr = rewriter.create<LLVM::ConstantOp>(
+          loc, i64, reinterpret_cast<int64_t>(mpiTraits->getInPlace()));
+      sendPtr = rewriter.create<LLVM::IntToPtrOp>(loc, ptrType, sendPtr);
+    }
+
     Value dataType = mpiTraits->getDataType(loc, rewriter, elemType);
     Value mpiOp = mpiTraits->getMPIOp(loc, rewriter, op.getOp());
     Value commWorld = mpiTraits->castComm(loc, rewriter, adaptor.getComm());
diff --git a/mlir/test/Conversion/MPIToLLVM/mpitollvm.mlir b/mlir/test/Conversion/MPIToLLVM/mpitollvm.mlir
index 174f7c79b9d50..35fc0f5d2e754 100644
--- a/mlir/test/Conversion/MPIToLLVM/mpitollvm.mlir
+++ b/mlir/test/Conversion/MPIToLLVM/mpitollvm.mlir
@@ -98,10 +98,12 @@ module attributes {dlti.map = #dlti.map<"MPI:Implementation" = "MPICH">} {
     // CHECK: [[v66:%.*]] = llvm.getelementptr [[v64]][[[v65]]] : (!llvm.ptr, i64) -> !llvm.ptr, f32
     // CHECK: [[v67:%.*]] = llvm.extractvalue [[v5]][3, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
     // CHECK: [[v68:%.*]] = llvm.trunc [[v67]] : i64 to i32
+    // CHECK: [[ip:%.*]] = llvm.mlir.constant(-1 : i64) : i64
+    // CHECK: [[ipp:%.*]] = llvm.inttoptr [[ip]] : i64 to !llvm.ptr
     // CHECK: [[v69:%.*]] = llvm.mlir.constant(1275069450 : i32) : i32
     // CHECK: [[v70:%.*]] = llvm.mlir.constant(1476395011 : i32) : i32
     // CHECK: [[v71:%.*]] = llvm.trunc [[comm]] : i64 to i32
-    // CHECK: [[v72:%.*]] = llvm.call @MPI_Allreduce([[v61]], [[v66]], [[v63]], [[v69]], [[v70]], [[v71]]) : (!llvm.ptr, !llvm.ptr, i32, i32, i32, i32) -> i32
+    // CHECK: [[v72:%.*]] = llvm.call @MPI_Allreduce([[ipp]], [[v66]], [[v63]], [[v69]], [[v70]], [[v71]]) : (!llvm.ptr, !llvm.ptr, i32, i32, i32, i32) -> i32
     mpi.allreduce(%arg0, %arg0, MPI_SUM, %comm) : memref<100xf32>, memref<100xf32>
 
     // CHECK: llvm.call @MPI_Finalize() : () -> i32
@@ -202,10 +204,12 @@ module attributes { dlti.map = #dlti.map<"MPI:Implementation" = "OpenMPI"> } {
     // CHECK: [[v56:%.*]] = llvm.getelementptr [[v54]][[[v55]]] : (!llvm.ptr, i64) -> !llvm.ptr, f32
     // CHECK: [[v57:%.*]] = llvm.extractvalue [[v5]][3, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
     // CHECK: [[v58:%.*]] = llvm.trunc [[v57]] : i64 to i32
+    // CHECK: [[ip:%.*]] = llvm.mlir.constant(1 : i64) : i64
+    // CHECK: [[ipp:%.*]] = llvm.inttoptr [[ip]] : i64 to !llvm.ptr
     // CHECK: [[v59:%.*]] = llvm.mlir.addressof @ompi_mpi_float : !llvm.ptr
     // CHECK: [[v60:%.*]] = llvm.mlir.addressof @ompi_mpi_sum : !llvm.ptr
     // CHECK: [[v61:%.*]] = llvm.inttoptr [[comm]] : i64 to !llvm.ptr
-    // CHECK: [[v62:%.*]] = llvm.call @MPI_Allreduce([[v51]], [[v56]], [[v53]], [[v59]], [[v60]], [[v61]]) : (!llvm.ptr, !llvm.ptr, i32, !llvm.ptr, !llvm.ptr, !llvm.ptr) -> i32
+    // CHECK: [[v62:%.*]] = llvm.call @MPI_Allreduce([[ipp]], [[v56]], [[v53]], [[v59]], [[v60]], [[v61]]) : (!llvm.ptr, !llvm.ptr, i32, !llvm.ptr, !llvm.ptr, !llvm.ptr) -> i32
     mpi.allreduce(%arg0, %arg0, MPI_SUM, %comm) : memref<100xf32>, memref<100xf32>
 
     // CHECK: [[v71:%.*]] = llvm.mlir.constant(10 : i32) : i32

``````````

</details>


https://github.com/llvm/llvm-project/pull/134225


More information about the Mlir-commits mailing list