[Mlir-commits] [mlir] [mlir][mpi] Lowering MPI_Allreduce (PR #133133)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Wed Mar 26 10:54:23 PDT 2025
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir
Author: Frank Schlimbach (fschlimb)
<details>
<summary>Changes</summary>
Adding lowering of MPI_Allreduce.
FYI: @tkarna @mofeing
---
Patch is 21.33 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/133133.diff
5 Files Affected:
- (modified) mlir/include/mlir/Dialect/MPI/IR/MPI.td (-5)
- (modified) mlir/include/mlir/Dialect/MPI/IR/MPIOps.td (+1-1)
- (modified) mlir/lib/Conversion/MPIToLLVM/MPIToLLVM.cpp (+201-22)
- (renamed) mlir/test/Conversion/MPIToLLVM/mpitollvm.mlir (+36-5)
- (renamed) mlir/test/Dialect/MPI/mpiops.mlir (+4-4)
``````````diff
diff --git a/mlir/include/mlir/Dialect/MPI/IR/MPI.td b/mlir/include/mlir/Dialect/MPI/IR/MPI.td
index 7c84443e5520d..f2837e71df060 100644
--- a/mlir/include/mlir/Dialect/MPI/IR/MPI.td
+++ b/mlir/include/mlir/Dialect/MPI/IR/MPI.td
@@ -246,12 +246,7 @@ def MPI_OpClassEnum : I32EnumAttr<"MPI_OpClassEnum", "MPI operation class", [
MPI_OpMaxloc,
MPI_OpReplace
]> {
- let genSpecializedAttr = 0;
let cppNamespace = "::mlir::mpi";
}
-def MPI_OpClassAttr : EnumAttr<MPI_Dialect, MPI_OpClassEnum, "opclass"> {
- let assemblyFormat = "`<` $value `>`";
-}
-
#endif // MLIR_DIALECT_MPI_IR_MPI_TD
diff --git a/mlir/include/mlir/Dialect/MPI/IR/MPIOps.td b/mlir/include/mlir/Dialect/MPI/IR/MPIOps.td
index db28bd09678f8..a8267b115b9e6 100644
--- a/mlir/include/mlir/Dialect/MPI/IR/MPIOps.td
+++ b/mlir/include/mlir/Dialect/MPI/IR/MPIOps.td
@@ -244,7 +244,7 @@ def MPI_AllReduceOp : MPI_Op<"allreduce", []> {
let arguments = (
ins AnyMemRef : $sendbuf,
AnyMemRef : $recvbuf,
- MPI_OpClassAttr : $op
+ MPI_OpClassEnum : $op
);
let results = (outs Optional<MPI_Retval>:$retval);
diff --git a/mlir/lib/Conversion/MPIToLLVM/MPIToLLVM.cpp b/mlir/lib/Conversion/MPIToLLVM/MPIToLLVM.cpp
index d91f9512ccb8f..4e0f59305a647 100644
--- a/mlir/lib/Conversion/MPIToLLVM/MPIToLLVM.cpp
+++ b/mlir/lib/Conversion/MPIToLLVM/MPIToLLVM.cpp
@@ -47,6 +47,22 @@ static LLVM::LLVMFuncOp getOrDefineFunction(ModuleOp &moduleOp,
moduleOp, loc, rewriter, name, name, type, LLVM::Linkage::External);
}
+std::pair<Value, Value> getRawPtrAndSize(const Location loc,
+ ConversionPatternRewriter &rewriter,
+ Value memRef, Type elType) {
+ Type ptrType = LLVM::LLVMPointerType::get(rewriter.getContext());
+ Value dataPtr =
+ rewriter.create<LLVM::ExtractValueOp>(loc, ptrType, memRef, 1);
+ Value offset = rewriter.create<LLVM::ExtractValueOp>(
+ loc, rewriter.getI64Type(), memRef, 2);
+ Value resPtr =
+ rewriter.create<LLVM::GEPOp>(loc, ptrType, elType, dataPtr, offset);
+ Value size = rewriter.create<LLVM::ExtractValueOp>(loc, memRef,
+ ArrayRef<int64_t>{3, 0});
+ size = rewriter.create<LLVM::TruncOp>(loc, rewriter.getI32Type(), size);
+ return {resPtr, size};
+}
+
/// When lowering the mpi dialect to functions calls certain details
/// differ between various MPI implementations. This class will provide
/// these in a generic way, depending on the MPI implementation that got
@@ -77,6 +93,12 @@ class MPIImplTraits {
/// type.
virtual Value getDataType(const Location loc,
ConversionPatternRewriter &rewriter, Type type) = 0;
+
+ /// Gets or creates an MPI_Op value which corresponds to the given
+ /// enum value.
+ virtual Value getMPIOp(const Location loc,
+ ConversionPatternRewriter &rewriter,
+ mpi::MPI_OpClassEnum opAttr) = 0;
};
//===----------------------------------------------------------------------===//
@@ -94,6 +116,20 @@ class MPICHImplTraits : public MPIImplTraits {
static constexpr int MPI_UINT16_T = 0x4c00023c;
static constexpr int MPI_UINT32_T = 0x4c00043d;
static constexpr int MPI_UINT64_T = 0x4c00083e;
+ static constexpr int MPI_MAX = 0x58000001;
+ static constexpr int MPI_MIN = 0x58000002;
+ static constexpr int MPI_SUM = 0x58000003;
+ static constexpr int MPI_PROD = 0x58000004;
+ static constexpr int MPI_LAND = 0x58000005;
+ static constexpr int MPI_BAND = 0x58000006;
+ static constexpr int MPI_LOR = 0x58000007;
+ static constexpr int MPI_BOR = 0x58000008;
+ static constexpr int MPI_LXOR = 0x58000009;
+ static constexpr int MPI_BXOR = 0x5800000a;
+ static constexpr int MPI_MINLOC = 0x5800000b;
+ static constexpr int MPI_MAXLOC = 0x5800000c;
+ static constexpr int MPI_REPLACE = 0x5800000d;
+ static constexpr int MPI_NO_OP = 0x5800000e;
public:
using MPIImplTraits::MPIImplTraits;
@@ -136,6 +172,56 @@ class MPICHImplTraits : public MPIImplTraits {
assert(false && "unsupported type");
return rewriter.create<LLVM::ConstantOp>(loc, rewriter.getI32Type(), mtype);
}
+
+ Value getMPIOp(const Location loc, ConversionPatternRewriter &rewriter,
+ mpi::MPI_OpClassEnum opAttr) override {
+ int32_t op = MPI_NO_OP;
+ switch (opAttr) {
+ case mpi::MPI_OpClassEnum::MPI_OP_NULL:
+ op = MPI_NO_OP;
+ break;
+ case mpi::MPI_OpClassEnum::MPI_MAX:
+ op = MPI_MAX;
+ break;
+ case mpi::MPI_OpClassEnum::MPI_MIN:
+ op = MPI_MIN;
+ break;
+ case mpi::MPI_OpClassEnum::MPI_SUM:
+ op = MPI_SUM;
+ break;
+ case mpi::MPI_OpClassEnum::MPI_PROD:
+ op = MPI_PROD;
+ break;
+ case mpi::MPI_OpClassEnum::MPI_LAND:
+ op = MPI_LAND;
+ break;
+ case mpi::MPI_OpClassEnum::MPI_BAND:
+ op = MPI_BAND;
+ break;
+ case mpi::MPI_OpClassEnum::MPI_LOR:
+ op = MPI_LOR;
+ break;
+ case mpi::MPI_OpClassEnum::MPI_BOR:
+ op = MPI_BOR;
+ break;
+ case mpi::MPI_OpClassEnum::MPI_LXOR:
+ op = MPI_LXOR;
+ break;
+ case mpi::MPI_OpClassEnum::MPI_BXOR:
+ op = MPI_BXOR;
+ break;
+ case mpi::MPI_OpClassEnum::MPI_MINLOC:
+ op = MPI_MINLOC;
+ break;
+ case mpi::MPI_OpClassEnum::MPI_MAXLOC:
+ op = MPI_MAXLOC;
+ break;
+ case mpi::MPI_OpClassEnum::MPI_REPLACE:
+ op = MPI_REPLACE;
+ break;
+ }
+ return rewriter.create<LLVM::ConstantOp>(loc, rewriter.getI32Type(), op);
+ }
};
//===----------------------------------------------------------------------===//
@@ -205,15 +291,74 @@ class OMPIImplTraits : public MPIImplTraits {
auto context = rewriter.getContext();
// get external opaque struct pointer type
- auto commStructT =
+ auto typeStructT =
LLVM::LLVMStructType::getOpaque("ompi_predefined_datatype_t", context);
// make sure global op definition exists
- getOrDefineExternalStruct(loc, rewriter, mtype, commStructT);
+ getOrDefineExternalStruct(loc, rewriter, mtype, typeStructT);
// get address of symbol
return rewriter.create<LLVM::AddressOfOp>(
loc, LLVM::LLVMPointerType::get(context),
SymbolRefAttr::get(context, mtype));
}
+
+ Value getMPIOp(const Location loc, ConversionPatternRewriter &rewriter,
+ mpi::MPI_OpClassEnum opAttr) override {
+ StringRef op;
+ switch (opAttr) {
+ case mpi::MPI_OpClassEnum::MPI_OP_NULL:
+ op = "ompi_mpi_no_op";
+ break;
+ case mpi::MPI_OpClassEnum::MPI_MAX:
+ op = "ompi_mpi_max";
+ break;
+ case mpi::MPI_OpClassEnum::MPI_MIN:
+ op = "ompi_mpi_min";
+ break;
+ case mpi::MPI_OpClassEnum::MPI_SUM:
+ op = "ompi_mpi_sum";
+ break;
+ case mpi::MPI_OpClassEnum::MPI_PROD:
+ op = "ompi_mpi_prod";
+ break;
+ case mpi::MPI_OpClassEnum::MPI_LAND:
+ op = "ompi_mpi_land";
+ break;
+ case mpi::MPI_OpClassEnum::MPI_BAND:
+ op = "ompi_mpi_band";
+ break;
+ case mpi::MPI_OpClassEnum::MPI_LOR:
+ op = "ompi_mpi_lor";
+ break;
+ case mpi::MPI_OpClassEnum::MPI_BOR:
+ op = "ompi_mpi_bor";
+ break;
+ case mpi::MPI_OpClassEnum::MPI_LXOR:
+ op = "ompi_mpi_lxor";
+ break;
+ case mpi::MPI_OpClassEnum::MPI_BXOR:
+ op = "ompi_mpi_bxor";
+ break;
+ case mpi::MPI_OpClassEnum::MPI_MINLOC:
+ op = "ompi_mpi_minloc";
+ break;
+ case mpi::MPI_OpClassEnum::MPI_MAXLOC:
+ op = "ompi_mpi_maxloc";
+ break;
+ case mpi::MPI_OpClassEnum::MPI_REPLACE:
+ op = "ompi_mpi_replace";
+ break;
+ }
+ auto context = rewriter.getContext();
+ // get external opaque struct pointer type
+ auto opStructT =
+ LLVM::LLVMStructType::getOpaque("ompi_predefined_op_t", context);
+ // make sure global op definition exists
+ getOrDefineExternalStruct(loc, rewriter, op, opStructT);
+ // get address of symbol
+ return rewriter.create<LLVM::AddressOfOp>(
+ loc, LLVM::LLVMPointerType::get(context),
+ SymbolRefAttr::get(context, op));
+ }
};
std::unique_ptr<MPIImplTraits> MPIImplTraits::get(ModuleOp &moduleOp) {
@@ -365,8 +510,6 @@ struct SendOpLowering : public ConvertOpToLLVMPattern<mpi::SendOp> {
Location loc = op.getLoc();
MLIRContext *context = rewriter.getContext();
Type i32 = rewriter.getI32Type();
- Type i64 = rewriter.getI64Type();
- Value memRef = adaptor.getRef();
Type elemType = op.getRef().getType().getElementType();
// ptrType `!llvm.ptr`
@@ -376,14 +519,8 @@ struct SendOpLowering : public ConvertOpToLLVMPattern<mpi::SendOp> {
auto moduleOp = op->getParentOfType<ModuleOp>();
// get MPI_COMM_WORLD, dataType and pointer
- Value dataPtr =
- rewriter.create<LLVM::ExtractValueOp>(loc, ptrType, memRef, 1);
- Value offset = rewriter.create<LLVM::ExtractValueOp>(loc, i64, memRef, 2);
- dataPtr =
- rewriter.create<LLVM::GEPOp>(loc, ptrType, elemType, dataPtr, offset);
- Value size = rewriter.create<LLVM::ExtractValueOp>(loc, memRef,
- ArrayRef<int64_t>{3, 0});
- size = rewriter.create<LLVM::TruncOp>(loc, i32, size);
+ auto [dataPtr, size] =
+ getRawPtrAndSize(loc, rewriter, adaptor.getRef(), elemType);
auto mpiTraits = MPIImplTraits::get(moduleOp);
Value dataType = mpiTraits->getDataType(loc, rewriter, elemType);
Value commWorld = mpiTraits->getCommWorld(loc, rewriter);
@@ -425,7 +562,6 @@ struct RecvOpLowering : public ConvertOpToLLVMPattern<mpi::RecvOp> {
MLIRContext *context = rewriter.getContext();
Type i32 = rewriter.getI32Type();
Type i64 = rewriter.getI64Type();
- Value memRef = adaptor.getRef();
Type elemType = op.getRef().getType().getElementType();
// ptrType `!llvm.ptr`
@@ -435,14 +571,8 @@ struct RecvOpLowering : public ConvertOpToLLVMPattern<mpi::RecvOp> {
auto moduleOp = op->getParentOfType<ModuleOp>();
// get MPI_COMM_WORLD, dataType, status_ignore and pointer
- Value dataPtr =
- rewriter.create<LLVM::ExtractValueOp>(loc, ptrType, memRef, 1);
- Value offset = rewriter.create<LLVM::ExtractValueOp>(loc, i64, memRef, 2);
- dataPtr =
- rewriter.create<LLVM::GEPOp>(loc, ptrType, elemType, dataPtr, offset);
- Value size = rewriter.create<LLVM::ExtractValueOp>(loc, memRef,
- ArrayRef<int64_t>{3, 0});
- size = rewriter.create<LLVM::TruncOp>(loc, i32, size);
+ auto [dataPtr, size] =
+ getRawPtrAndSize(loc, rewriter, adaptor.getRef(), elemType);
auto mpiTraits = MPIImplTraits::get(moduleOp);
Value dataType = mpiTraits->getDataType(loc, rewriter, elemType);
Value commWorld = mpiTraits->getCommWorld(loc, rewriter);
@@ -474,6 +604,55 @@ struct RecvOpLowering : public ConvertOpToLLVMPattern<mpi::RecvOp> {
}
};
+//===----------------------------------------------------------------------===//
+// AllReduceOpLowering
+//===----------------------------------------------------------------------===//
+
+struct AllReduceOpLowering : public ConvertOpToLLVMPattern<mpi::AllReduceOp> {
+ using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;
+
+ LogicalResult
+ matchAndRewrite(mpi::AllReduceOp op, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ Location loc = op.getLoc();
+ MLIRContext *context = rewriter.getContext();
+ Type i32 = rewriter.getI32Type();
+ Type elemType = op.getSendbuf().getType().getElementType();
+
+ // ptrType `!llvm.ptr`
+ Type ptrType = LLVM::LLVMPointerType::get(context);
+ auto moduleOp = op->getParentOfType<ModuleOp>();
+ auto mpiTraits = MPIImplTraits::get(moduleOp);
+ auto [sendPtr, sendSize] =
+ getRawPtrAndSize(loc, rewriter, adaptor.getSendbuf(), elemType);
+ auto [recvPtr, recvSize] =
+ getRawPtrAndSize(loc, rewriter, adaptor.getRecvbuf(), elemType);
+ Value dataType = mpiTraits->getDataType(loc, rewriter, elemType);
+ Value mpiOp = mpiTraits->getMPIOp(loc, rewriter, op.getOp());
+ Value commWorld = mpiTraits->getCommWorld(loc, rewriter);
+ // 'int MPI_Allreduce(const void *sendbuf, void *recvbuf, int count,
+ // MPI_Datatype datatype, MPI_Op op, MPI_Comm comm)'
+ auto funcType = LLVM::LLVMFunctionType::get(
+ i32, {ptrType, ptrType, i32, dataType.getType(), mpiOp.getType(),
+ commWorld.getType()});
+ // get or create function declaration:
+ LLVM::LLVMFuncOp funcDecl =
+ getOrDefineFunction(moduleOp, loc, rewriter, "MPI_Allreduce", funcType);
+
+ // replace op with function call
+ auto funcCall = rewriter.create<LLVM::CallOp>(
+ loc, funcDecl,
+ ValueRange{sendPtr, recvPtr, sendSize, dataType, mpiOp, commWorld});
+
+ if (op.getRetval())
+ rewriter.replaceOp(op, funcCall.getResult());
+ else
+ rewriter.eraseOp(op);
+
+ return success();
+ }
+};
+
//===----------------------------------------------------------------------===//
// ConvertToLLVMPatternInterface implementation
//===----------------------------------------------------------------------===//
@@ -498,7 +677,7 @@ struct FuncToLLVMDialectInterface : public ConvertToLLVMPatternInterface {
void mpi::populateMPIToLLVMConversionPatterns(LLVMTypeConverter &converter,
RewritePatternSet &patterns) {
patterns.add<CommRankOpLowering, FinalizeOpLowering, InitOpLowering,
- SendOpLowering, RecvOpLowering>(converter);
+ SendOpLowering, RecvOpLowering, AllReduceOpLowering>(converter);
}
void mpi::registerConvertMPIToLLVMInterface(DialectRegistry &registry) {
diff --git a/mlir/test/Conversion/MPIToLLVM/ops.mlir b/mlir/test/Conversion/MPIToLLVM/mpitollvm.mlir
similarity index 78%
rename from mlir/test/Conversion/MPIToLLVM/ops.mlir
rename to mlir/test/Conversion/MPIToLLVM/mpitollvm.mlir
index 3c1b344efd50b..249ef195e8f5c 100644
--- a/mlir/test/Conversion/MPIToLLVM/ops.mlir
+++ b/mlir/test/Conversion/MPIToLLVM/mpitollvm.mlir
@@ -1,13 +1,13 @@
// RUN: mlir-opt -split-input-file -convert-to-llvm %s | FileCheck %s
// COM: Test MPICH ABI
-// CHECK: module attributes {mpi.dlti = #dlti.map<"MPI:Implementation" = "MPICH">} {
+// CHECK: module attributes {dlti.map = #dlti.map<"MPI:Implementation" = "MPICH">} {
// CHECK: llvm.func @MPI_Finalize() -> i32
// CHECK: llvm.func @MPI_Recv(!llvm.ptr, i32, i32, i32, i32, i32, !llvm.ptr) -> i32
// CHECK: llvm.func @MPI_Send(!llvm.ptr, i32, i32, i32, i32, i32) -> i32
// CHECK: llvm.func @MPI_Comm_rank(i32, !llvm.ptr) -> i32
// CHECK: llvm.func @MPI_Init(!llvm.ptr, !llvm.ptr) -> i32
-module attributes { mpi.dlti = #dlti.map<"MPI:Implementation" = "MPICH"> } {
+module attributes {dlti.map = #dlti.map<"MPI:Implementation" = "MPICH">} {
// CHECK: llvm.func @mpi_test_mpich([[varg0:%.+]]: !llvm.ptr, [[varg1:%.+]]: !llvm.ptr, [[varg2:%.+]]: i64, [[varg3:%.+]]: i64, [[varg4:%.+]]: i64) {
func.func @mpi_test_mpich(%arg0: memref<100xf32>) {
@@ -73,7 +73,22 @@ module attributes { mpi.dlti = #dlti.map<"MPI:Implementation" = "MPICH"> } {
// CHECK: [[v48:%.*]] = llvm.call @MPI_Recv([[v41]], [[v43]], [[v44]], [[v12]], [[v12]], [[v45]], [[v47]]) : (!llvm.ptr, i32, i32, i32, i32, i32, !llvm.ptr) -> i32
%2 = mpi.recv(%arg0, %rank, %rank) : memref<100xf32>, i32, i32 -> !mpi.retval
- // CHECK: [[v49:%.*]] = llvm.call @MPI_Finalize() : () -> i32
+ // CHECK-NEXT: [[v49:%.*]] = llvm.extractvalue [[v5]][1] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
+ // CHECK-NEXT: [[v50:%.*]] = llvm.extractvalue [[v5]][2] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
+ // CHECK-NEXT: [[v51:%.*]] = llvm.getelementptr [[v49]][[[v50]]] : (!llvm.ptr, i64) -> !llvm.ptr, f32
+ // CHECK-NEXT: [[v52:%.*]] = llvm.extractvalue [[v5]][3, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
+ // CHECK-NEXT: [[v53:%.*]] = llvm.trunc [[v52]] : i64 to i32
+ // CHECK-NEXT: [[v54:%.*]] = llvm.extractvalue [[v5]][1] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
+ // CHECK-NEXT: [[v55:%.*]] = llvm.extractvalue [[v5]][2] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
+ // CHECK-NEXT: [[v56:%.*]] = llvm.getelementptr [[v54]][[[v55]]] : (!llvm.ptr, i64) -> !llvm.ptr, f32
+ // CHECK-NEXT: [[v57:%.*]] = llvm.extractvalue [[v5]][3, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
+ // CHECK-NEXT: [[v58:%.*]] = llvm.trunc [[v57]] : i64 to i32
+ // CHECK-NEXT: [[v59:%.*]] = llvm.mlir.constant(1275069450 : i32) : i32
+ // CHECK-NEXT: [[v60:%.*]] = llvm.mlir.constant(1476395011 : i32) : i32
+ // CHECK-NEXT: [[v61:%.*]] = llvm.mlir.constant(1140850688 : i32) : i32
+ mpi.allreduce(%arg0, %arg0, MPI_SUM) : memref<100xf32>, memref<100xf32>
+
+ // CHECK: llvm.call @MPI_Finalize() : () -> i32
%3 = mpi.finalize : !mpi.retval
return
@@ -83,7 +98,7 @@ module attributes { mpi.dlti = #dlti.map<"MPI:Implementation" = "MPICH"> } {
// -----
// COM: Test OpenMPI ABI
-// CHECK: module attributes {mpi.dlti = #dlti.map<"MPI:Implementation" = "OpenMPI">} {
+// CHECK: module attributes {dlti.map = #dlti.map<"MPI:Implementation" = "OpenMPI">} {
// CHECK: llvm.func @MPI_Finalize() -> i32
// CHECK: llvm.func @MPI_Recv(!llvm.ptr, i32, !llvm.ptr, i32, i32, !llvm.ptr, !llvm.ptr) -> i32
// CHECK: llvm.func @MPI_Send(!llvm.ptr, i32, !llvm.ptr, i32, i32, !llvm.ptr) -> i32
@@ -91,7 +106,7 @@ module attributes { mpi.dlti = #dlti.map<"MPI:Implementation" = "MPICH"> } {
// CHECK: llvm.func @MPI_Comm_rank(!llvm.ptr, !llvm.ptr) -> i32
// CHECK: llvm.mlir.global external @ompi_mpi_comm_world() {addr_space = 0 : i32} : !llvm.struct<"ompi_communicator_t", opaque>
// CHECK: llvm.func @MPI_Init(!llvm.ptr, !llvm.ptr) -> i32
-module attributes { mpi.dlti = #dlti.map<"MPI:Implementation" = "OpenMPI"> } {
+module attributes { dlti.map = #dlti.map<"MPI:Implementation" = "OpenMPI"> } {
// CHECK: llvm.func @mpi_test_openmpi([[varg0:%.+]]: !llvm.ptr, [[varg1:%.+]]: !llvm.ptr, [[varg2:%.+]]: i64, [[varg3:%.+]]: i64, [[varg4:%.+]]: i64) {
func.func @mpi_test_openmpi(%arg0: memref<100xf32>) {
@@ -157,6 +172,22 @@ module attributes { mpi.dlti = #dlti.map<"MPI:Implementation" = "OpenMPI"> } {
// CHECK: [[v48:%.*]] = llvm.call @MPI_Recv([[v41]], [[v43]], [[v44]], [[v12]], [[v12]], [[v45]], [[v47]]) : (!llvm.ptr, i32, !llvm.ptr, i32, i32, !llvm.ptr, !llvm.ptr) -> i32
%2 = mpi.recv(%arg0, %rank, %rank) : memref<100xf32>, i32, i32 -> !mpi.retval
+ // CHECK-NEXT: [[v49:%.*]] = llvm.extractvalue [[v5]][1] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
+ // CHECK-NEXT: [[v50:%.*]] = llvm.extractvalue [[v5]][2] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
+ // CHECK-NEXT: [[v51:%.*]] = llvm.getelementptr [[v49]][[[v50]]] : (!llvm.ptr, i64) -> !llvm.ptr, f32
+ // CHECK-NEXT: [[v52:%.*]] = llvm.extractvalue [[v5]][3, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
+ // CHECK-NEXT: [[v53:%.*]] = llvm.trunc [[v52]] : i64 to i32
+ // CHECK-NEXT: [[v54:%.*]] = llvm.extractvalue [[v5]][1] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
+ // CHECK-NEXT: [[v55:%.*]] = llvm.extractvalue [[v5]][2] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
+ // CHECK-NEXT: [[v56:%.*]] = llvm.getelementptr [[v54]][[[v55]]] : (!llvm.ptr, i64) -> !llvm.ptr, f32
+ // CHECK-NEXT: [[v57:%.*]] = llvm.extractvalue [[v5]][3, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
+ // CHECK-NEXT: [[v58:%.*]] = llvm.trunc [[v57]] : i64 to i32
+ // CHECK-NEXT: [[v59:%.*]] = llvm.mlir.addressof @ompi_mpi_float : !llvm.ptr
+ // CHECK-NEXT: [[v60:%.*]] = llvm.mlir.addressof @ompi_mpi_sum : !llvm.ptr
+ // CHECK-NEXT: [[v61:%.*]] = llvm.mlir.addressof @ompi_mpi_comm_world : !llvm.ptr
+ // CHECK-NEXT: [[v62:%.*]] = llvm.call @MPI_Allreduce([[v51]], [[v56]], [[v53]], [[v59]], [[v60]], [[v61]]) : (!llvm.ptr, !llvm.ptr, i32, !llvm.ptr, !llvm.ptr, !llvm.ptr) -> i32
+ mpi.allreduce(%arg0, %arg0, MPI_SUM) : memref<100xf32>, memref<100xf32>
+
// CHECK: [[v49:%.*]] = llvm.call @MPI_Finalize() : () -> i...
[truncated]
``````````
</details>
https://github.com/llvm/llvm-project/pull/133133
More information about the Mlir-commits
mailing list