[Mlir-commits] [mlir] [mlir][mpi] Lowering MPI_Allreduce (PR #133133)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Wed Mar 26 10:54:23 PDT 2025


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir

Author: Frank Schlimbach (fschlimb)

<details>
<summary>Changes</summary>

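Adds a lowering of `mpi.allreduce` to a call to `MPI_Allreduce` in the MPI-to-LLVM conversion, covering both the MPICH and OpenMPI implementations. As part of this change, the `op` argument of `mpi.allreduce` now uses `MPI_OpClassEnum` directly (the separate `MPI_OpClassAttr` is removed), and the pointer/size extraction previously duplicated in the send and recv lowerings is factored into a shared `getRawPtrAndSize` helper.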

FYI: @tkarna @mofeing


---

Patch is 21.33 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/133133.diff


5 Files Affected:

- (modified) mlir/include/mlir/Dialect/MPI/IR/MPI.td (-5) 
- (modified) mlir/include/mlir/Dialect/MPI/IR/MPIOps.td (+1-1) 
- (modified) mlir/lib/Conversion/MPIToLLVM/MPIToLLVM.cpp (+201-22) 
- (renamed) mlir/test/Conversion/MPIToLLVM/mpitollvm.mlir (+36-5) 
- (renamed) mlir/test/Dialect/MPI/mpiops.mlir (+4-4) 


``````````diff
diff --git a/mlir/include/mlir/Dialect/MPI/IR/MPI.td b/mlir/include/mlir/Dialect/MPI/IR/MPI.td
index 7c84443e5520d..f2837e71df060 100644
--- a/mlir/include/mlir/Dialect/MPI/IR/MPI.td
+++ b/mlir/include/mlir/Dialect/MPI/IR/MPI.td
@@ -246,12 +246,7 @@ def MPI_OpClassEnum : I32EnumAttr<"MPI_OpClassEnum", "MPI operation class", [
       MPI_OpMaxloc,
       MPI_OpReplace
     ]> {
-  let genSpecializedAttr = 0;
   let cppNamespace = "::mlir::mpi";
 }
 
-def MPI_OpClassAttr : EnumAttr<MPI_Dialect, MPI_OpClassEnum, "opclass"> {
-  let assemblyFormat = "`<` $value `>`";
-}
-
 #endif // MLIR_DIALECT_MPI_IR_MPI_TD
diff --git a/mlir/include/mlir/Dialect/MPI/IR/MPIOps.td b/mlir/include/mlir/Dialect/MPI/IR/MPIOps.td
index db28bd09678f8..a8267b115b9e6 100644
--- a/mlir/include/mlir/Dialect/MPI/IR/MPIOps.td
+++ b/mlir/include/mlir/Dialect/MPI/IR/MPIOps.td
@@ -244,7 +244,7 @@ def MPI_AllReduceOp : MPI_Op<"allreduce", []> {
   let arguments = (
     ins AnyMemRef : $sendbuf,
     AnyMemRef : $recvbuf,
-    MPI_OpClassAttr : $op
+    MPI_OpClassEnum : $op
   );
 
   let results = (outs Optional<MPI_Retval>:$retval);
diff --git a/mlir/lib/Conversion/MPIToLLVM/MPIToLLVM.cpp b/mlir/lib/Conversion/MPIToLLVM/MPIToLLVM.cpp
index d91f9512ccb8f..4e0f59305a647 100644
--- a/mlir/lib/Conversion/MPIToLLVM/MPIToLLVM.cpp
+++ b/mlir/lib/Conversion/MPIToLLVM/MPIToLLVM.cpp
@@ -47,6 +47,22 @@ static LLVM::LLVMFuncOp getOrDefineFunction(ModuleOp &moduleOp,
       moduleOp, loc, rewriter, name, name, type, LLVM::Linkage::External);
 }
 
+std::pair<Value, Value> getRawPtrAndSize(const Location loc,
+                                         ConversionPatternRewriter &rewriter,
+                                         Value memRef, Type elType) {
+  Type ptrType = LLVM::LLVMPointerType::get(rewriter.getContext());
+  Value dataPtr =
+      rewriter.create<LLVM::ExtractValueOp>(loc, ptrType, memRef, 1);
+  Value offset = rewriter.create<LLVM::ExtractValueOp>(
+      loc, rewriter.getI64Type(), memRef, 2);
+  Value resPtr =
+      rewriter.create<LLVM::GEPOp>(loc, ptrType, elType, dataPtr, offset);
+  Value size = rewriter.create<LLVM::ExtractValueOp>(loc, memRef,
+                                                     ArrayRef<int64_t>{3, 0});
+  size = rewriter.create<LLVM::TruncOp>(loc, rewriter.getI32Type(), size);
+  return {resPtr, size};
+}
+
 /// When lowering the mpi dialect to functions calls certain details
 /// differ between various MPI implementations. This class will provide
 /// these in a generic way, depending on the MPI implementation that got
@@ -77,6 +93,12 @@ class MPIImplTraits {
   /// type.
   virtual Value getDataType(const Location loc,
                             ConversionPatternRewriter &rewriter, Type type) = 0;
+
+  /// Gets or creates an MPI_Op value which corresponds to the given
+  /// enum value.
+  virtual Value getMPIOp(const Location loc,
+                         ConversionPatternRewriter &rewriter,
+                         mpi::MPI_OpClassEnum opAttr) = 0;
 };
 
 //===----------------------------------------------------------------------===//
@@ -94,6 +116,20 @@ class MPICHImplTraits : public MPIImplTraits {
   static constexpr int MPI_UINT16_T = 0x4c00023c;
   static constexpr int MPI_UINT32_T = 0x4c00043d;
   static constexpr int MPI_UINT64_T = 0x4c00083e;
+  static constexpr int MPI_MAX = 0x58000001;
+  static constexpr int MPI_MIN = 0x58000002;
+  static constexpr int MPI_SUM = 0x58000003;
+  static constexpr int MPI_PROD = 0x58000004;
+  static constexpr int MPI_LAND = 0x58000005;
+  static constexpr int MPI_BAND = 0x58000006;
+  static constexpr int MPI_LOR = 0x58000007;
+  static constexpr int MPI_BOR = 0x58000008;
+  static constexpr int MPI_LXOR = 0x58000009;
+  static constexpr int MPI_BXOR = 0x5800000a;
+  static constexpr int MPI_MINLOC = 0x5800000b;
+  static constexpr int MPI_MAXLOC = 0x5800000c;
+  static constexpr int MPI_REPLACE = 0x5800000d;
+  static constexpr int MPI_NO_OP = 0x5800000e;
 
 public:
   using MPIImplTraits::MPIImplTraits;
@@ -136,6 +172,56 @@ class MPICHImplTraits : public MPIImplTraits {
       assert(false && "unsupported type");
     return rewriter.create<LLVM::ConstantOp>(loc, rewriter.getI32Type(), mtype);
   }
+
+  Value getMPIOp(const Location loc, ConversionPatternRewriter &rewriter,
+                 mpi::MPI_OpClassEnum opAttr) override {
+    int32_t op = MPI_NO_OP;
+    switch (opAttr) {
+    case mpi::MPI_OpClassEnum::MPI_OP_NULL:
+      op = MPI_NO_OP;
+      break;
+    case mpi::MPI_OpClassEnum::MPI_MAX:
+      op = MPI_MAX;
+      break;
+    case mpi::MPI_OpClassEnum::MPI_MIN:
+      op = MPI_MIN;
+      break;
+    case mpi::MPI_OpClassEnum::MPI_SUM:
+      op = MPI_SUM;
+      break;
+    case mpi::MPI_OpClassEnum::MPI_PROD:
+      op = MPI_PROD;
+      break;
+    case mpi::MPI_OpClassEnum::MPI_LAND:
+      op = MPI_LAND;
+      break;
+    case mpi::MPI_OpClassEnum::MPI_BAND:
+      op = MPI_BAND;
+      break;
+    case mpi::MPI_OpClassEnum::MPI_LOR:
+      op = MPI_LOR;
+      break;
+    case mpi::MPI_OpClassEnum::MPI_BOR:
+      op = MPI_BOR;
+      break;
+    case mpi::MPI_OpClassEnum::MPI_LXOR:
+      op = MPI_LXOR;
+      break;
+    case mpi::MPI_OpClassEnum::MPI_BXOR:
+      op = MPI_BXOR;
+      break;
+    case mpi::MPI_OpClassEnum::MPI_MINLOC:
+      op = MPI_MINLOC;
+      break;
+    case mpi::MPI_OpClassEnum::MPI_MAXLOC:
+      op = MPI_MAXLOC;
+      break;
+    case mpi::MPI_OpClassEnum::MPI_REPLACE:
+      op = MPI_REPLACE;
+      break;
+    }
+    return rewriter.create<LLVM::ConstantOp>(loc, rewriter.getI32Type(), op);
+  }
 };
 
 //===----------------------------------------------------------------------===//
@@ -205,15 +291,74 @@ class OMPIImplTraits : public MPIImplTraits {
 
     auto context = rewriter.getContext();
     // get external opaque struct pointer type
-    auto commStructT =
+    auto typeStructT =
         LLVM::LLVMStructType::getOpaque("ompi_predefined_datatype_t", context);
     // make sure global op definition exists
-    getOrDefineExternalStruct(loc, rewriter, mtype, commStructT);
+    getOrDefineExternalStruct(loc, rewriter, mtype, typeStructT);
     // get address of symbol
     return rewriter.create<LLVM::AddressOfOp>(
         loc, LLVM::LLVMPointerType::get(context),
         SymbolRefAttr::get(context, mtype));
   }
+
+  Value getMPIOp(const Location loc, ConversionPatternRewriter &rewriter,
+                 mpi::MPI_OpClassEnum opAttr) override {
+    StringRef op;
+    switch (opAttr) {
+    case mpi::MPI_OpClassEnum::MPI_OP_NULL:
+      op = "ompi_mpi_no_op";
+      break;
+    case mpi::MPI_OpClassEnum::MPI_MAX:
+      op = "ompi_mpi_max";
+      break;
+    case mpi::MPI_OpClassEnum::MPI_MIN:
+      op = "ompi_mpi_min";
+      break;
+    case mpi::MPI_OpClassEnum::MPI_SUM:
+      op = "ompi_mpi_sum";
+      break;
+    case mpi::MPI_OpClassEnum::MPI_PROD:
+      op = "ompi_mpi_prod";
+      break;
+    case mpi::MPI_OpClassEnum::MPI_LAND:
+      op = "ompi_mpi_land";
+      break;
+    case mpi::MPI_OpClassEnum::MPI_BAND:
+      op = "ompi_mpi_band";
+      break;
+    case mpi::MPI_OpClassEnum::MPI_LOR:
+      op = "ompi_mpi_lor";
+      break;
+    case mpi::MPI_OpClassEnum::MPI_BOR:
+      op = "ompi_mpi_bor";
+      break;
+    case mpi::MPI_OpClassEnum::MPI_LXOR:
+      op = "ompi_mpi_lxor";
+      break;
+    case mpi::MPI_OpClassEnum::MPI_BXOR:
+      op = "ompi_mpi_bxor";
+      break;
+    case mpi::MPI_OpClassEnum::MPI_MINLOC:
+      op = "ompi_mpi_minloc";
+      break;
+    case mpi::MPI_OpClassEnum::MPI_MAXLOC:
+      op = "ompi_mpi_maxloc";
+      break;
+    case mpi::MPI_OpClassEnum::MPI_REPLACE:
+      op = "ompi_mpi_replace";
+      break;
+    }
+    auto context = rewriter.getContext();
+    // get external opaque struct pointer type
+    auto opStructT =
+        LLVM::LLVMStructType::getOpaque("ompi_predefined_op_t", context);
+    // make sure global op definition exists
+    getOrDefineExternalStruct(loc, rewriter, op, opStructT);
+    // get address of symbol
+    return rewriter.create<LLVM::AddressOfOp>(
+        loc, LLVM::LLVMPointerType::get(context),
+        SymbolRefAttr::get(context, op));
+  }
 };
 
 std::unique_ptr<MPIImplTraits> MPIImplTraits::get(ModuleOp &moduleOp) {
@@ -365,8 +510,6 @@ struct SendOpLowering : public ConvertOpToLLVMPattern<mpi::SendOp> {
     Location loc = op.getLoc();
     MLIRContext *context = rewriter.getContext();
     Type i32 = rewriter.getI32Type();
-    Type i64 = rewriter.getI64Type();
-    Value memRef = adaptor.getRef();
     Type elemType = op.getRef().getType().getElementType();
 
     // ptrType `!llvm.ptr`
@@ -376,14 +519,8 @@ struct SendOpLowering : public ConvertOpToLLVMPattern<mpi::SendOp> {
     auto moduleOp = op->getParentOfType<ModuleOp>();
 
     // get MPI_COMM_WORLD, dataType and pointer
-    Value dataPtr =
-        rewriter.create<LLVM::ExtractValueOp>(loc, ptrType, memRef, 1);
-    Value offset = rewriter.create<LLVM::ExtractValueOp>(loc, i64, memRef, 2);
-    dataPtr =
-        rewriter.create<LLVM::GEPOp>(loc, ptrType, elemType, dataPtr, offset);
-    Value size = rewriter.create<LLVM::ExtractValueOp>(loc, memRef,
-                                                       ArrayRef<int64_t>{3, 0});
-    size = rewriter.create<LLVM::TruncOp>(loc, i32, size);
+    auto [dataPtr, size] =
+        getRawPtrAndSize(loc, rewriter, adaptor.getRef(), elemType);
     auto mpiTraits = MPIImplTraits::get(moduleOp);
     Value dataType = mpiTraits->getDataType(loc, rewriter, elemType);
     Value commWorld = mpiTraits->getCommWorld(loc, rewriter);
@@ -425,7 +562,6 @@ struct RecvOpLowering : public ConvertOpToLLVMPattern<mpi::RecvOp> {
     MLIRContext *context = rewriter.getContext();
     Type i32 = rewriter.getI32Type();
     Type i64 = rewriter.getI64Type();
-    Value memRef = adaptor.getRef();
     Type elemType = op.getRef().getType().getElementType();
 
     // ptrType `!llvm.ptr`
@@ -435,14 +571,8 @@ struct RecvOpLowering : public ConvertOpToLLVMPattern<mpi::RecvOp> {
     auto moduleOp = op->getParentOfType<ModuleOp>();
 
     // get MPI_COMM_WORLD, dataType, status_ignore and pointer
-    Value dataPtr =
-        rewriter.create<LLVM::ExtractValueOp>(loc, ptrType, memRef, 1);
-    Value offset = rewriter.create<LLVM::ExtractValueOp>(loc, i64, memRef, 2);
-    dataPtr =
-        rewriter.create<LLVM::GEPOp>(loc, ptrType, elemType, dataPtr, offset);
-    Value size = rewriter.create<LLVM::ExtractValueOp>(loc, memRef,
-                                                       ArrayRef<int64_t>{3, 0});
-    size = rewriter.create<LLVM::TruncOp>(loc, i32, size);
+    auto [dataPtr, size] =
+        getRawPtrAndSize(loc, rewriter, adaptor.getRef(), elemType);
     auto mpiTraits = MPIImplTraits::get(moduleOp);
     Value dataType = mpiTraits->getDataType(loc, rewriter, elemType);
     Value commWorld = mpiTraits->getCommWorld(loc, rewriter);
@@ -474,6 +604,55 @@ struct RecvOpLowering : public ConvertOpToLLVMPattern<mpi::RecvOp> {
   }
 };
 
+//===----------------------------------------------------------------------===//
+// AllReduceOpLowering
+//===----------------------------------------------------------------------===//
+
+struct AllReduceOpLowering : public ConvertOpToLLVMPattern<mpi::AllReduceOp> {
+  using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;
+
+  LogicalResult
+  matchAndRewrite(mpi::AllReduceOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    Location loc = op.getLoc();
+    MLIRContext *context = rewriter.getContext();
+    Type i32 = rewriter.getI32Type();
+    Type elemType = op.getSendbuf().getType().getElementType();
+
+    // ptrType `!llvm.ptr`
+    Type ptrType = LLVM::LLVMPointerType::get(context);
+    auto moduleOp = op->getParentOfType<ModuleOp>();
+    auto mpiTraits = MPIImplTraits::get(moduleOp);
+    auto [sendPtr, sendSize] =
+        getRawPtrAndSize(loc, rewriter, adaptor.getSendbuf(), elemType);
+    auto [recvPtr, recvSize] =
+        getRawPtrAndSize(loc, rewriter, adaptor.getRecvbuf(), elemType);
+    Value dataType = mpiTraits->getDataType(loc, rewriter, elemType);
+    Value mpiOp = mpiTraits->getMPIOp(loc, rewriter, op.getOp());
+    Value commWorld = mpiTraits->getCommWorld(loc, rewriter);
+    // 'int MPI_Allreduce(const void *sendbuf, void *recvbuf, int count,
+    //                    MPI_Datatype datatype, MPI_Op op, MPI_Comm comm)'
+    auto funcType = LLVM::LLVMFunctionType::get(
+        i32, {ptrType, ptrType, i32, dataType.getType(), mpiOp.getType(),
+              commWorld.getType()});
+    // get or create function declaration:
+    LLVM::LLVMFuncOp funcDecl =
+        getOrDefineFunction(moduleOp, loc, rewriter, "MPI_Allreduce", funcType);
+
+    // replace op with function call
+    auto funcCall = rewriter.create<LLVM::CallOp>(
+        loc, funcDecl,
+        ValueRange{sendPtr, recvPtr, sendSize, dataType, mpiOp, commWorld});
+
+    if (op.getRetval())
+      rewriter.replaceOp(op, funcCall.getResult());
+    else
+      rewriter.eraseOp(op);
+
+    return success();
+  }
+};
+
 //===----------------------------------------------------------------------===//
 // ConvertToLLVMPatternInterface implementation
 //===----------------------------------------------------------------------===//
@@ -498,7 +677,7 @@ struct FuncToLLVMDialectInterface : public ConvertToLLVMPatternInterface {
 void mpi::populateMPIToLLVMConversionPatterns(LLVMTypeConverter &converter,
                                               RewritePatternSet &patterns) {
   patterns.add<CommRankOpLowering, FinalizeOpLowering, InitOpLowering,
-               SendOpLowering, RecvOpLowering>(converter);
+               SendOpLowering, RecvOpLowering, AllReduceOpLowering>(converter);
 }
 
 void mpi::registerConvertMPIToLLVMInterface(DialectRegistry &registry) {
diff --git a/mlir/test/Conversion/MPIToLLVM/ops.mlir b/mlir/test/Conversion/MPIToLLVM/mpitollvm.mlir
similarity index 78%
rename from mlir/test/Conversion/MPIToLLVM/ops.mlir
rename to mlir/test/Conversion/MPIToLLVM/mpitollvm.mlir
index 3c1b344efd50b..249ef195e8f5c 100644
--- a/mlir/test/Conversion/MPIToLLVM/ops.mlir
+++ b/mlir/test/Conversion/MPIToLLVM/mpitollvm.mlir
@@ -1,13 +1,13 @@
 // RUN: mlir-opt -split-input-file -convert-to-llvm %s | FileCheck %s
 
 // COM: Test MPICH ABI
-// CHECK: module attributes {mpi.dlti = #dlti.map<"MPI:Implementation" = "MPICH">} {
+// CHECK: module attributes {dlti.map = #dlti.map<"MPI:Implementation" = "MPICH">} {
 // CHECK: llvm.func @MPI_Finalize() -> i32
 // CHECK: llvm.func @MPI_Recv(!llvm.ptr, i32, i32, i32, i32, i32, !llvm.ptr) -> i32
 // CHECK: llvm.func @MPI_Send(!llvm.ptr, i32, i32, i32, i32, i32) -> i32
 // CHECK: llvm.func @MPI_Comm_rank(i32, !llvm.ptr) -> i32
 // CHECK: llvm.func @MPI_Init(!llvm.ptr, !llvm.ptr) -> i32
-module attributes { mpi.dlti = #dlti.map<"MPI:Implementation" = "MPICH"> } {
+module attributes {dlti.map = #dlti.map<"MPI:Implementation" = "MPICH">} {
 
   // CHECK: llvm.func @mpi_test_mpich([[varg0:%.+]]: !llvm.ptr, [[varg1:%.+]]: !llvm.ptr, [[varg2:%.+]]: i64, [[varg3:%.+]]: i64, [[varg4:%.+]]: i64) {
   func.func @mpi_test_mpich(%arg0: memref<100xf32>) {
@@ -73,7 +73,22 @@ module attributes { mpi.dlti = #dlti.map<"MPI:Implementation" = "MPICH"> } {
     // CHECK: [[v48:%.*]] = llvm.call @MPI_Recv([[v41]], [[v43]], [[v44]], [[v12]], [[v12]], [[v45]], [[v47]]) : (!llvm.ptr, i32, i32, i32, i32, i32, !llvm.ptr) -> i32
     %2 = mpi.recv(%arg0, %rank, %rank) : memref<100xf32>, i32, i32 -> !mpi.retval
 
-    // CHECK: [[v49:%.*]] = llvm.call @MPI_Finalize() : () -> i32
+    // CHECK-NEXT: [[v49:%.*]] = llvm.extractvalue [[v5]][1] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
+    // CHECK-NEXT: [[v50:%.*]] = llvm.extractvalue [[v5]][2] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
+    // CHECK-NEXT: [[v51:%.*]] = llvm.getelementptr [[v49]][[[v50]]] : (!llvm.ptr, i64) -> !llvm.ptr, f32
+    // CHECK-NEXT: [[v52:%.*]] = llvm.extractvalue [[v5]][3, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
+    // CHECK-NEXT: [[v53:%.*]] = llvm.trunc [[v52]] : i64 to i32
+    // CHECK-NEXT: [[v54:%.*]] = llvm.extractvalue [[v5]][1] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
+    // CHECK-NEXT: [[v55:%.*]] = llvm.extractvalue [[v5]][2] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
+    // CHECK-NEXT: [[v56:%.*]] = llvm.getelementptr [[v54]][[[v55]]] : (!llvm.ptr, i64) -> !llvm.ptr, f32
+    // CHECK-NEXT: [[v57:%.*]] = llvm.extractvalue [[v5]][3, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
+    // CHECK-NEXT: [[v58:%.*]] = llvm.trunc [[v57]] : i64 to i32
+    // CHECK-NEXT: [[v59:%.*]] = llvm.mlir.constant(1275069450 : i32) : i32
+    // CHECK-NEXT: [[v60:%.*]] = llvm.mlir.constant(1476395011 : i32) : i32
+    // CHECK-NEXT: [[v61:%.*]] = llvm.mlir.constant(1140850688 : i32) : i32
+    mpi.allreduce(%arg0, %arg0, MPI_SUM) : memref<100xf32>, memref<100xf32>
+
+    // CHECK: llvm.call @MPI_Finalize() : () -> i32
     %3 = mpi.finalize : !mpi.retval
 
     return
@@ -83,7 +98,7 @@ module attributes { mpi.dlti = #dlti.map<"MPI:Implementation" = "MPICH"> } {
 // -----
 
 // COM: Test OpenMPI ABI
-// CHECK: module attributes {mpi.dlti = #dlti.map<"MPI:Implementation" = "OpenMPI">} {
+// CHECK: module attributes {dlti.map = #dlti.map<"MPI:Implementation" = "OpenMPI">} {
 // CHECK: llvm.func @MPI_Finalize() -> i32
 // CHECK: llvm.func @MPI_Recv(!llvm.ptr, i32, !llvm.ptr, i32, i32, !llvm.ptr, !llvm.ptr) -> i32
 // CHECK: llvm.func @MPI_Send(!llvm.ptr, i32, !llvm.ptr, i32, i32, !llvm.ptr) -> i32
@@ -91,7 +106,7 @@ module attributes { mpi.dlti = #dlti.map<"MPI:Implementation" = "MPICH"> } {
 // CHECK: llvm.func @MPI_Comm_rank(!llvm.ptr, !llvm.ptr) -> i32
 // CHECK: llvm.mlir.global external @ompi_mpi_comm_world() {addr_space = 0 : i32} : !llvm.struct<"ompi_communicator_t", opaque>
 // CHECK: llvm.func @MPI_Init(!llvm.ptr, !llvm.ptr) -> i32
-module attributes { mpi.dlti = #dlti.map<"MPI:Implementation" = "OpenMPI"> } {
+module attributes { dlti.map = #dlti.map<"MPI:Implementation" = "OpenMPI"> } {
 
   // CHECK: llvm.func @mpi_test_openmpi([[varg0:%.+]]: !llvm.ptr, [[varg1:%.+]]: !llvm.ptr, [[varg2:%.+]]: i64, [[varg3:%.+]]: i64, [[varg4:%.+]]: i64) {
   func.func @mpi_test_openmpi(%arg0: memref<100xf32>) {
@@ -157,6 +172,22 @@ module attributes { mpi.dlti = #dlti.map<"MPI:Implementation" = "OpenMPI"> } {
     // CHECK: [[v48:%.*]] = llvm.call @MPI_Recv([[v41]], [[v43]], [[v44]], [[v12]], [[v12]], [[v45]], [[v47]]) : (!llvm.ptr, i32, !llvm.ptr, i32, i32, !llvm.ptr, !llvm.ptr) -> i32
     %2 = mpi.recv(%arg0, %rank, %rank) : memref<100xf32>, i32, i32 -> !mpi.retval
 
+    // CHECK-NEXT: [[v49:%.*]] = llvm.extractvalue [[v5]][1] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
+    // CHECK-NEXT: [[v50:%.*]] = llvm.extractvalue [[v5]][2] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
+    // CHECK-NEXT: [[v51:%.*]] = llvm.getelementptr [[v49]][[[v50]]] : (!llvm.ptr, i64) -> !llvm.ptr, f32
+    // CHECK-NEXT: [[v52:%.*]] = llvm.extractvalue [[v5]][3, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
+    // CHECK-NEXT: [[v53:%.*]] = llvm.trunc [[v52]] : i64 to i32
+    // CHECK-NEXT: [[v54:%.*]] = llvm.extractvalue [[v5]][1] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
+    // CHECK-NEXT: [[v55:%.*]] = llvm.extractvalue [[v5]][2] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
+    // CHECK-NEXT: [[v56:%.*]] = llvm.getelementptr [[v54]][[[v55]]] : (!llvm.ptr, i64) -> !llvm.ptr, f32
+    // CHECK-NEXT: [[v57:%.*]] = llvm.extractvalue [[v5]][3, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
+    // CHECK-NEXT: [[v58:%.*]] = llvm.trunc [[v57]] : i64 to i32
+    // CHECK-NEXT: [[v59:%.*]] = llvm.mlir.addressof @ompi_mpi_float : !llvm.ptr
+    // CHECK-NEXT: [[v60:%.*]] = llvm.mlir.addressof @ompi_mpi_sum : !llvm.ptr
+    // CHECK-NEXT: [[v61:%.*]] = llvm.mlir.addressof @ompi_mpi_comm_world : !llvm.ptr
+    // CHECK-NEXT: [[v62:%.*]] = llvm.call @MPI_Allreduce([[v51]], [[v56]], [[v53]], [[v59]], [[v60]], [[v61]]) : (!llvm.ptr, !llvm.ptr, i32, !llvm.ptr, !llvm.ptr, !llvm.ptr) -> i32
+    mpi.allreduce(%arg0, %arg0, MPI_SUM) : memref<100xf32>, memref<100xf32>
+
     // CHECK: [[v49:%.*]] = llvm.call @MPI_Finalize() : () -> i...
[truncated]

``````````

</details>


https://github.com/llvm/llvm-project/pull/133133


More information about the Mlir-commits mailing list