[Mlir-commits] [mlir] [mlir][mpi] Lowering MPI_Allreduce (PR #133133)
Frank Schlimbach
llvmlistbot at llvm.org
Fri Mar 28 05:06:59 PDT 2025
https://github.com/fschlimb updated https://github.com/llvm/llvm-project/pull/133133
>From 80179cea012a3b1a159b86d851fb244e9d030b3d Mon Sep 17 00:00:00 2001
From: "Schlimbach, Frank" <frank.schlimbach at intel.com>
Date: Tue, 25 Mar 2025 19:22:40 +0100
Subject: [PATCH 1/4] lowering MPI_Allreduce (MPICH)
---
mlir/include/mlir/Dialect/MPI/IR/MPI.td | 8 +-
mlir/include/mlir/Dialect/MPI/IR/MPIOps.td | 2 +-
mlir/lib/Conversion/MPIToLLVM/MPIToLLVM.cpp | 166 +++++++++++++++++---
mlir/test/Conversion/MPIToLLVM/ops.mlir | 27 +++-
mlir/test/Dialect/MPI/ops.mlir | 4 +-
5 files changed, 174 insertions(+), 33 deletions(-)
diff --git a/mlir/include/mlir/Dialect/MPI/IR/MPI.td b/mlir/include/mlir/Dialect/MPI/IR/MPI.td
index 7c84443e5520d..17ff18c3f7d3c 100644
--- a/mlir/include/mlir/Dialect/MPI/IR/MPI.td
+++ b/mlir/include/mlir/Dialect/MPI/IR/MPI.td
@@ -246,12 +246,12 @@ def MPI_OpClassEnum : I32EnumAttr<"MPI_OpClassEnum", "MPI operation class", [
MPI_OpMaxloc,
MPI_OpReplace
]> {
- let genSpecializedAttr = 0;
+// let genSpecializedAttr = 0;
let cppNamespace = "::mlir::mpi";
}
-def MPI_OpClassAttr : EnumAttr<MPI_Dialect, MPI_OpClassEnum, "opclass"> {
- let assemblyFormat = "`<` $value `>`";
-}
+// def MPI_OpClassAttr : EnumAttr<MPI_Dialect, MPI_OpClassEnum, "opclass"> {
+// let assemblyFormat = "`<` $value `>`";
+// }
#endif // MLIR_DIALECT_MPI_IR_MPI_TD
diff --git a/mlir/include/mlir/Dialect/MPI/IR/MPIOps.td b/mlir/include/mlir/Dialect/MPI/IR/MPIOps.td
index db28bd09678f8..a8267b115b9e6 100644
--- a/mlir/include/mlir/Dialect/MPI/IR/MPIOps.td
+++ b/mlir/include/mlir/Dialect/MPI/IR/MPIOps.td
@@ -244,7 +244,7 @@ def MPI_AllReduceOp : MPI_Op<"allreduce", []> {
let arguments = (
ins AnyMemRef : $sendbuf,
AnyMemRef : $recvbuf,
- MPI_OpClassAttr : $op
+ MPI_OpClassEnum : $op
);
let results = (outs Optional<MPI_Retval>:$retval);
diff --git a/mlir/lib/Conversion/MPIToLLVM/MPIToLLVM.cpp b/mlir/lib/Conversion/MPIToLLVM/MPIToLLVM.cpp
index d91f9512ccb8f..49c0c398d32c3 100644
--- a/mlir/lib/Conversion/MPIToLLVM/MPIToLLVM.cpp
+++ b/mlir/lib/Conversion/MPIToLLVM/MPIToLLVM.cpp
@@ -47,6 +47,22 @@ static LLVM::LLVMFuncOp getOrDefineFunction(ModuleOp &moduleOp,
moduleOp, loc, rewriter, name, name, type, LLVM::Linkage::External);
}
+/// Returns the data pointer and the element count of the lowered memref
+/// descriptor \p memRef, as needed for MPI's `(void *buf, int count)` pairs.
+///
+/// The pointer is descriptor field 1 (the aligned pointer in MLIR's memref
+/// descriptor layout) advanced by the offset in field 2, in units of
+/// \p elType. The count is the product of all sizes in field 3, truncated to
+/// i32; a rank-0 descriptor has no sizes field and yields a count of 1.
+/// NOTE(review): strides are ignored, so the count only describes the buffer
+/// correctly for contiguous memrefs, and counts above INT32_MAX wrap.
+static std::pair<Value, Value>
+getRawPtrAndSize(const Location loc, ConversionPatternRewriter &rewriter,
+                 Value memRef, Type elType) {
+  Type ptrType = LLVM::LLVMPointerType::get(rewriter.getContext());
+  Type i64 = rewriter.getI64Type();
+  Type i32 = rewriter.getI32Type();
+  Value dataPtr =
+      rewriter.create<LLVM::ExtractValueOp>(loc, ptrType, memRef, 1);
+  Value offset = rewriter.create<LLVM::ExtractValueOp>(loc, i64, memRef, 2);
+  Value resPtr =
+      rewriter.create<LLVM::GEPOp>(loc, ptrType, elType, dataPtr, offset);
+
+  // Ranked descriptors are {allocated, aligned, offset, sizes, strides};
+  // rank 0 has no sizes/strides arrays, so position [3, 0] does not exist.
+  auto descType = cast<LLVM::LLVMStructType>(memRef.getType());
+  ArrayRef<Type> fields = descType.getBody();
+  if (fields.size() <= 3) {
+    Value one = rewriter.create<LLVM::ConstantOp>(loc, i32, 1);
+    return {resPtr, one};
+  }
+
+  // Element count = product of all dimension sizes. For rank 1 this is a
+  // single extractvalue followed by the trunc.
+  int64_t rank = cast<LLVM::LLVMArrayType>(fields[3]).getNumElements();
+  Value size = rewriter.create<LLVM::ExtractValueOp>(loc, memRef,
+                                                     ArrayRef<int64_t>{3, 0});
+  for (int64_t dim = 1; dim < rank; ++dim) {
+    Value dimSize = rewriter.create<LLVM::ExtractValueOp>(
+        loc, memRef, ArrayRef<int64_t>{3, dim});
+    size = rewriter.create<LLVM::MulOp>(loc, size, dimSize);
+  }
+  size = rewriter.create<LLVM::TruncOp>(loc, i32, size);
+  return {resPtr, size};
+}
+
/// When lowering the mpi dialect to functions calls certain details
/// differ between various MPI implementations. This class will provide
/// these in a generic way, depending on the MPI implementation that got
@@ -77,6 +93,12 @@ class MPIImplTraits {
/// type.
virtual Value getDataType(const Location loc,
ConversionPatternRewriter &rewriter, Type type) = 0;
+
+  /// Gets or creates an MPI_Op value which corresponds to the given
+  /// enum value, emitting any needed IR through \p rewriter at \p loc.
+  /// The type of the returned value is implementation specific (e.g. an
+  /// integer handle vs. a pointer to a predefined global), so callers should
+  /// take the MPI_Op parameter type from the returned value's type.
+  virtual Value getMPIOp(const Location loc,
+                         ConversionPatternRewriter &rewriter,
+                         mpi::MPI_OpClassEnum opAttr) = 0;
};
//===----------------------------------------------------------------------===//
@@ -94,6 +116,20 @@ class MPICHImplTraits : public MPIImplTraits {
static constexpr int MPI_UINT16_T = 0x4c00023c;
static constexpr int MPI_UINT32_T = 0x4c00043d;
static constexpr int MPI_UINT64_T = 0x4c00083e;
+ static constexpr int MPI_MAX = 0x58000001;
+ static constexpr int MPI_MIN = 0x58000002;
+ static constexpr int MPI_SUM = 0x58000003;
+ static constexpr int MPI_PROD = 0x58000004;
+ static constexpr int MPI_LAND = 0x58000005;
+ static constexpr int MPI_BAND = 0x58000006;
+ static constexpr int MPI_LOR = 0x58000007;
+ static constexpr int MPI_BOR = 0x58000008;
+ static constexpr int MPI_LXOR = 0x58000009;
+ static constexpr int MPI_BXOR = 0x5800000a;
+ static constexpr int MPI_MINLOC = 0x5800000b;
+ static constexpr int MPI_MAXLOC = 0x5800000c;
+ static constexpr int MPI_REPLACE = 0x5800000d;
+ static constexpr int MPI_NO_OP = 0x5800000e;
public:
using MPIImplTraits::MPIImplTraits;
@@ -136,6 +172,56 @@ class MPICHImplTraits : public MPIImplTraits {
assert(false && "unsupported type");
return rewriter.create<LLVM::ConstantOp>(loc, rewriter.getI32Type(), mtype);
}
+
+  /// Returns an i32 constant holding the MPICH MPI_Op handle for \p opAttr.
+  Value getMPIOp(const Location loc, ConversionPatternRewriter &rewriter,
+                 mpi::MPI_OpClassEnum opAttr) override {
+    // MPICH's null reduction handle MPI_OP_NULL (0x18000000) is distinct
+    // from MPI_NO_OP and is not among the class constants, so define it here.
+    constexpr int32_t MPI_OP_NULL = 0x18000000;
+    int32_t op = MPI_NO_OP;
+    switch (opAttr) {
+    case mpi::MPI_OpClassEnum::MPI_OP_NULL:
+      op = MPI_OP_NULL;
+      break;
+    case mpi::MPI_OpClassEnum::MPI_MAX:
+      op = MPI_MAX;
+      break;
+    case mpi::MPI_OpClassEnum::MPI_MIN:
+      op = MPI_MIN;
+      break;
+    case mpi::MPI_OpClassEnum::MPI_SUM:
+      op = MPI_SUM;
+      break;
+    case mpi::MPI_OpClassEnum::MPI_PROD:
+      op = MPI_PROD;
+      break;
+    case mpi::MPI_OpClassEnum::MPI_LAND:
+      op = MPI_LAND;
+      break;
+    case mpi::MPI_OpClassEnum::MPI_BAND:
+      op = MPI_BAND;
+      break;
+    case mpi::MPI_OpClassEnum::MPI_LOR:
+      op = MPI_LOR;
+      break;
+    case mpi::MPI_OpClassEnum::MPI_BOR:
+      op = MPI_BOR;
+      break;
+    case mpi::MPI_OpClassEnum::MPI_LXOR:
+      op = MPI_LXOR;
+      break;
+    case mpi::MPI_OpClassEnum::MPI_BXOR:
+      op = MPI_BXOR;
+      break;
+    case mpi::MPI_OpClassEnum::MPI_MINLOC:
+      op = MPI_MINLOC;
+      break;
+    case mpi::MPI_OpClassEnum::MPI_MAXLOC:
+      op = MPI_MAXLOC;
+      break;
+    case mpi::MPI_OpClassEnum::MPI_REPLACE:
+      op = MPI_REPLACE;
+      break;
+    }
+    return rewriter.create<LLVM::ConstantOp>(loc, rewriter.getI32Type(), op);
+  }
};
//===----------------------------------------------------------------------===//
@@ -214,6 +300,12 @@ class OMPIImplTraits : public MPIImplTraits {
loc, LLVM::LLVMPointerType::get(context),
SymbolRefAttr::get(context, mtype));
}
+
+  /// OpenMPI override of MPIImplTraits::getMPIOp.
+  Value getMPIOp(const Location loc, ConversionPatternRewriter &rewriter,
+                 mpi::MPI_OpClassEnum opAttr) override {
+    llvm_unreachable("getMPIOp not implemented for OpenMPI");
+    return Value();
+  }
};
std::unique_ptr<MPIImplTraits> MPIImplTraits::get(ModuleOp &moduleOp) {
@@ -365,8 +457,6 @@ struct SendOpLowering : public ConvertOpToLLVMPattern<mpi::SendOp> {
Location loc = op.getLoc();
MLIRContext *context = rewriter.getContext();
Type i32 = rewriter.getI32Type();
- Type i64 = rewriter.getI64Type();
- Value memRef = adaptor.getRef();
Type elemType = op.getRef().getType().getElementType();
// ptrType `!llvm.ptr`
@@ -376,14 +466,8 @@ struct SendOpLowering : public ConvertOpToLLVMPattern<mpi::SendOp> {
auto moduleOp = op->getParentOfType<ModuleOp>();
// get MPI_COMM_WORLD, dataType and pointer
- Value dataPtr =
- rewriter.create<LLVM::ExtractValueOp>(loc, ptrType, memRef, 1);
- Value offset = rewriter.create<LLVM::ExtractValueOp>(loc, i64, memRef, 2);
- dataPtr =
- rewriter.create<LLVM::GEPOp>(loc, ptrType, elemType, dataPtr, offset);
- Value size = rewriter.create<LLVM::ExtractValueOp>(loc, memRef,
- ArrayRef<int64_t>{3, 0});
- size = rewriter.create<LLVM::TruncOp>(loc, i32, size);
+ auto [dataPtr, size] =
+ getRawPtrAndSize(loc, rewriter, adaptor.getRef(), elemType);
auto mpiTraits = MPIImplTraits::get(moduleOp);
Value dataType = mpiTraits->getDataType(loc, rewriter, elemType);
Value commWorld = mpiTraits->getCommWorld(loc, rewriter);
@@ -425,7 +509,6 @@ struct RecvOpLowering : public ConvertOpToLLVMPattern<mpi::RecvOp> {
MLIRContext *context = rewriter.getContext();
Type i32 = rewriter.getI32Type();
Type i64 = rewriter.getI64Type();
- Value memRef = adaptor.getRef();
Type elemType = op.getRef().getType().getElementType();
// ptrType `!llvm.ptr`
@@ -435,14 +518,8 @@ struct RecvOpLowering : public ConvertOpToLLVMPattern<mpi::RecvOp> {
auto moduleOp = op->getParentOfType<ModuleOp>();
// get MPI_COMM_WORLD, dataType, status_ignore and pointer
- Value dataPtr =
- rewriter.create<LLVM::ExtractValueOp>(loc, ptrType, memRef, 1);
- Value offset = rewriter.create<LLVM::ExtractValueOp>(loc, i64, memRef, 2);
- dataPtr =
- rewriter.create<LLVM::GEPOp>(loc, ptrType, elemType, dataPtr, offset);
- Value size = rewriter.create<LLVM::ExtractValueOp>(loc, memRef,
- ArrayRef<int64_t>{3, 0});
- size = rewriter.create<LLVM::TruncOp>(loc, i32, size);
+ auto [dataPtr, size] =
+ getRawPtrAndSize(loc, rewriter, adaptor.getRef(), elemType);
auto mpiTraits = MPIImplTraits::get(moduleOp);
Value dataType = mpiTraits->getDataType(loc, rewriter, elemType);
Value commWorld = mpiTraits->getCommWorld(loc, rewriter);
@@ -474,6 +551,55 @@ struct RecvOpLowering : public ConvertOpToLLVMPattern<mpi::RecvOp> {
}
};
+//===----------------------------------------------------------------------===//
+// AllReduceOpLowering
+//===----------------------------------------------------------------------===//
+
+/// Lowers `mpi.allreduce` to a call of
+///   int MPI_Allreduce(const void *sendbuf, void *recvbuf, int count,
+///                     MPI_Datatype datatype, MPI_Op op, MPI_Comm comm)
+/// on MPI_COMM_WORLD. The count and the datatype are derived from the send
+/// buffer; the optional `!mpi.retval` result is the call's i32 return value.
+struct AllReduceOpLowering : public ConvertOpToLLVMPattern<mpi::AllReduceOp> {
+  using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;
+
+  LogicalResult
+  matchAndRewrite(mpi::AllReduceOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    const Location loc = op.getLoc();
+    Type i32Ty = rewriter.getI32Type();
+    Type elemTy = op.getSendbuf().getType().getElementType();
+    // `!llvm.ptr`
+    Type llvmPtrTy = LLVM::LLVMPointerType::get(rewriter.getContext());
+    auto moduleOp = op->getParentOfType<ModuleOp>();
+    auto impl = MPIImplTraits::get(moduleOp);
+
+    // MPI_Allreduce takes a single count, taken from the send buffer. The
+    // helper still materializes the receive buffer's count, which is unused.
+    auto [sendPtr, count] =
+        getRawPtrAndSize(loc, rewriter, adaptor.getSendbuf(), elemTy);
+    Value recvPtr =
+        getRawPtrAndSize(loc, rewriter, adaptor.getRecvbuf(), elemTy).first;
+
+    // Implementation-specific MPI handles, in call-argument order.
+    Value datatype = impl->getDataType(loc, rewriter, elemTy);
+    Value reduceOp = impl->getMPIOp(loc, rewriter, op.getOp());
+    Value comm = impl->getCommWorld(loc, rewriter);
+
+    auto fnTy = LLVM::LLVMFunctionType::get(
+        i32Ty, {llvmPtrTy, llvmPtrTy, i32Ty, datatype.getType(),
+                reduceOp.getType(), comm.getType()});
+    LLVM::LLVMFuncOp allreduceFn =
+        getOrDefineFunction(moduleOp, loc, rewriter, "MPI_Allreduce", fnTy);
+    auto call = rewriter.create<LLVM::CallOp>(
+        loc, allreduceFn,
+        ValueRange{sendPtr, recvPtr, count, datatype, reduceOp, comm});
+
+    if (!op.getRetval()) {
+      rewriter.eraseOp(op);
+      return success();
+    }
+    rewriter.replaceOp(op, call.getResult());
+    return success();
+  }
+};
+
//===----------------------------------------------------------------------===//
// ConvertToLLVMPatternInterface implementation
//===----------------------------------------------------------------------===//
@@ -498,7 +624,7 @@ struct FuncToLLVMDialectInterface : public ConvertToLLVMPatternInterface {
void mpi::populateMPIToLLVMConversionPatterns(LLVMTypeConverter &converter,
RewritePatternSet &patterns) {
patterns.add<CommRankOpLowering, FinalizeOpLowering, InitOpLowering,
- SendOpLowering, RecvOpLowering>(converter);
+ SendOpLowering, RecvOpLowering, AllReduceOpLowering>(converter);
}
void mpi::registerConvertMPIToLLVMInterface(DialectRegistry ®istry) {
diff --git a/mlir/test/Conversion/MPIToLLVM/ops.mlir b/mlir/test/Conversion/MPIToLLVM/ops.mlir
index 3c1b344efd50b..6917a6fa5798c 100644
--- a/mlir/test/Conversion/MPIToLLVM/ops.mlir
+++ b/mlir/test/Conversion/MPIToLLVM/ops.mlir
@@ -1,13 +1,13 @@
// RUN: mlir-opt -split-input-file -convert-to-llvm %s | FileCheck %s
// COM: Test MPICH ABI
-// CHECK: module attributes {mpi.dlti = #dlti.map<"MPI:Implementation" = "MPICH">} {
+// CHECK: module attributes {dlti.map = #dlti.map<"MPI:Implementation" = "MPICH">} {
// CHECK: llvm.func @MPI_Finalize() -> i32
// CHECK: llvm.func @MPI_Recv(!llvm.ptr, i32, i32, i32, i32, i32, !llvm.ptr) -> i32
// CHECK: llvm.func @MPI_Send(!llvm.ptr, i32, i32, i32, i32, i32) -> i32
// CHECK: llvm.func @MPI_Comm_rank(i32, !llvm.ptr) -> i32
// CHECK: llvm.func @MPI_Init(!llvm.ptr, !llvm.ptr) -> i32
-module attributes { mpi.dlti = #dlti.map<"MPI:Implementation" = "MPICH"> } {
+module attributes {dlti.map = #dlti.map<"MPI:Implementation" = "MPICH">} {
// CHECK: llvm.func @mpi_test_mpich([[varg0:%.+]]: !llvm.ptr, [[varg1:%.+]]: !llvm.ptr, [[varg2:%.+]]: i64, [[varg3:%.+]]: i64, [[varg4:%.+]]: i64) {
func.func @mpi_test_mpich(%arg0: memref<100xf32>) {
@@ -72,8 +72,23 @@ module attributes { mpi.dlti = #dlti.map<"MPI:Implementation" = "MPICH"> } {
// CHECK: [[v47:%.*]] = llvm.inttoptr [[v46]] : i64 to !llvm.ptr
// CHECK: [[v48:%.*]] = llvm.call @MPI_Recv([[v41]], [[v43]], [[v44]], [[v12]], [[v12]], [[v45]], [[v47]]) : (!llvm.ptr, i32, i32, i32, i32, i32, !llvm.ptr) -> i32
%2 = mpi.recv(%arg0, %rank, %rank) : memref<100xf32>, i32, i32 -> !mpi.retval
-
- // CHECK: [[v49:%.*]] = llvm.call @MPI_Finalize() : () -> i32
+
+ // CHECK-NEXT: [[v49:%.*]] = llvm.extractvalue [[v5]][1] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
+ // CHECK-NEXT: [[v50:%.*]] = llvm.extractvalue [[v5]][2] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
+ // CHECK-NEXT: [[v51:%.*]] = llvm.getelementptr [[v49]][[[v50]]] : (!llvm.ptr, i64) -> !llvm.ptr, f32
+ // CHECK-NEXT: [[v52:%.*]] = llvm.extractvalue [[v5]][3, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
+ // CHECK-NEXT: [[v53:%.*]] = llvm.trunc [[v52]] : i64 to i32
+ // CHECK-NEXT: [[v54:%.*]] = llvm.extractvalue [[v5]][1] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
+ // CHECK-NEXT: [[v55:%.*]] = llvm.extractvalue [[v5]][2] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
+ // CHECK-NEXT: [[v56:%.*]] = llvm.getelementptr [[v54]][[[v55]]] : (!llvm.ptr, i64) -> !llvm.ptr, f32
+ // CHECK-NEXT: [[v57:%.*]] = llvm.extractvalue [[v5]][3, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
+ // CHECK-NEXT: [[v58:%.*]] = llvm.trunc [[v57]] : i64 to i32
+ // CHECK-NEXT: [[v59:%.*]] = llvm.mlir.constant(1275069450 : i32) : i32
+ // CHECK-NEXT: [[v60:%.*]] = llvm.mlir.constant(1476395011 : i32) : i32
+ // CHECK-NEXT: [[v61:%.*]] = llvm.mlir.constant(1140850688 : i32) : i32
+ mpi.allreduce(%arg0, %arg0, MPI_SUM) : memref<100xf32>, memref<100xf32>
+
+ // CHECK: llvm.call @MPI_Finalize() : () -> i32
%3 = mpi.finalize : !mpi.retval
return
@@ -83,7 +98,7 @@ module attributes { mpi.dlti = #dlti.map<"MPI:Implementation" = "MPICH"> } {
// -----
// COM: Test OpenMPI ABI
-// CHECK: module attributes {mpi.dlti = #dlti.map<"MPI:Implementation" = "OpenMPI">} {
+// CHECK: module attributes {dlti.map = #dlti.map<"MPI:Implementation" = "OpenMPI">} {
// CHECK: llvm.func @MPI_Finalize() -> i32
// CHECK: llvm.func @MPI_Recv(!llvm.ptr, i32, !llvm.ptr, i32, i32, !llvm.ptr, !llvm.ptr) -> i32
// CHECK: llvm.func @MPI_Send(!llvm.ptr, i32, !llvm.ptr, i32, i32, !llvm.ptr) -> i32
@@ -91,7 +106,7 @@ module attributes { mpi.dlti = #dlti.map<"MPI:Implementation" = "MPICH"> } {
// CHECK: llvm.func @MPI_Comm_rank(!llvm.ptr, !llvm.ptr) -> i32
// CHECK: llvm.mlir.global external @ompi_mpi_comm_world() {addr_space = 0 : i32} : !llvm.struct<"ompi_communicator_t", opaque>
// CHECK: llvm.func @MPI_Init(!llvm.ptr, !llvm.ptr) -> i32
-module attributes { mpi.dlti = #dlti.map<"MPI:Implementation" = "OpenMPI"> } {
+module attributes { dlti.map = #dlti.map<"MPI:Implementation" = "OpenMPI"> } {
// CHECK: llvm.func @mpi_test_openmpi([[varg0:%.+]]: !llvm.ptr, [[varg1:%.+]]: !llvm.ptr, [[varg2:%.+]]: i64, [[varg3:%.+]]: i64, [[varg4:%.+]]: i64) {
func.func @mpi_test_openmpi(%arg0: memref<100xf32>) {
diff --git a/mlir/test/Dialect/MPI/ops.mlir b/mlir/test/Dialect/MPI/ops.mlir
index f23a7e18a2ee9..30e0032cde508 100644
--- a/mlir/test/Dialect/MPI/ops.mlir
+++ b/mlir/test/Dialect/MPI/ops.mlir
@@ -49,10 +49,10 @@ func.func @mpi_test(%ref : memref<100xf32>) -> () {
%err7 = mpi.barrier : !mpi.retval
// CHECK-NEXT: mpi.allreduce(%arg0, %arg0, <MPI_SUM>) : memref<100xf32>, memref<100xf32>
- mpi.allreduce(%ref, %ref, <MPI_SUM>) : memref<100xf32>, memref<100xf32>
+ mpi.allreduce(%ref, %ref, MPI_SUM) : memref<100xf32>, memref<100xf32>
// CHECK-NEXT: mpi.allreduce(%arg0, %arg0, <MPI_SUM>) : memref<100xf32>, memref<100xf32> -> !mpi.retval
- %err8 = mpi.allreduce(%ref, %ref, <MPI_SUM>) : memref<100xf32>, memref<100xf32> -> !mpi.retval
+ %err8 = mpi.allreduce(%ref, %ref, MPI_SUM) : memref<100xf32>, memref<100xf32> -> !mpi.retval
// CHECK-NEXT: %7 = mpi.finalize : !mpi.retval
%rval = mpi.finalize : !mpi.retval
>From 8aeb96db613df1473599d81414515321b042c632 Mon Sep 17 00:00:00 2001
From: "Schlimbach, Frank" <frank.schlimbach at intel.com>
Date: Wed, 26 Mar 2025 18:41:09 +0100
Subject: [PATCH 2/4] implementing OMPIImplTraits::getMPIOp
---
mlir/lib/Conversion/MPIToLLVM/MPIToLLVM.cpp | 61 +++++++++++++++++--
.../MPIToLLVM/{ops.mlir => mpitollvm.mlir} | 30 ++++++---
.../Dialect/MPI/{ops.mlir => mpiops.mlir} | 4 +-
3 files changed, 82 insertions(+), 13 deletions(-)
rename mlir/test/Conversion/MPIToLLVM/{ops.mlir => mpitollvm.mlir} (87%)
rename mlir/test/Dialect/MPI/{ops.mlir => mpiops.mlir} (93%)
diff --git a/mlir/lib/Conversion/MPIToLLVM/MPIToLLVM.cpp b/mlir/lib/Conversion/MPIToLLVM/MPIToLLVM.cpp
index 49c0c398d32c3..4e0f59305a647 100644
--- a/mlir/lib/Conversion/MPIToLLVM/MPIToLLVM.cpp
+++ b/mlir/lib/Conversion/MPIToLLVM/MPIToLLVM.cpp
@@ -291,10 +291,10 @@ class OMPIImplTraits : public MPIImplTraits {
auto context = rewriter.getContext();
// get external opaque struct pointer type
- auto commStructT =
+ auto typeStructT =
LLVM::LLVMStructType::getOpaque("ompi_predefined_datatype_t", context);
// make sure global op definition exists
- getOrDefineExternalStruct(loc, rewriter, mtype, commStructT);
+ getOrDefineExternalStruct(loc, rewriter, mtype, typeStructT);
// get address of symbol
return rewriter.create<LLVM::AddressOfOp>(
loc, LLVM::LLVMPointerType::get(context),
@@ -303,8 +303,61 @@ class OMPIImplTraits : public MPIImplTraits {
Value getMPIOp(const Location loc, ConversionPatternRewriter &rewriter,
mpi::MPI_OpClassEnum opAttr) override {
-    llvm_unreachable("getMPIOp not implemented for OpenMPI");
-    return Value();
+    // OpenMPI MPI_Op handles are emitted as the address of an external
+    // `ompi_predefined_op_t` global; map the enum to that global's name.
+    // NOTE(review): Open MPI's mpi.h appears to name these globals
+    // `ompi_mpi_op_<name>` (e.g. `ompi_mpi_op_sum`, `ompi_mpi_op_null`,
+    // `ompi_mpi_op_no_op`) rather than `ompi_mpi_<name>` as used below;
+    // confirm against the targeted Open MPI headers before relying on this,
+    // and update the `@ompi_mpi_sum` FileCheck expectation with any rename.
+    // NOTE(review): MPI_OP_NULL is mapped to the no-op handle rather than to
+    // a null-op handle — confirm this is intended.
+    StringRef op;
+    switch (opAttr) {
+    case mpi::MPI_OpClassEnum::MPI_OP_NULL:
+      op = "ompi_mpi_no_op";
+      break;
+    case mpi::MPI_OpClassEnum::MPI_MAX:
+      op = "ompi_mpi_max";
+      break;
+    case mpi::MPI_OpClassEnum::MPI_MIN:
+      op = "ompi_mpi_min";
+      break;
+    case mpi::MPI_OpClassEnum::MPI_SUM:
+      op = "ompi_mpi_sum";
+      break;
+    case mpi::MPI_OpClassEnum::MPI_PROD:
+      op = "ompi_mpi_prod";
+      break;
+    case mpi::MPI_OpClassEnum::MPI_LAND:
+      op = "ompi_mpi_land";
+      break;
+    case mpi::MPI_OpClassEnum::MPI_BAND:
+      op = "ompi_mpi_band";
+      break;
+    case mpi::MPI_OpClassEnum::MPI_LOR:
+      op = "ompi_mpi_lor";
+      break;
+    case mpi::MPI_OpClassEnum::MPI_BOR:
+      op = "ompi_mpi_bor";
+      break;
+    case mpi::MPI_OpClassEnum::MPI_LXOR:
+      op = "ompi_mpi_lxor";
+      break;
+    case mpi::MPI_OpClassEnum::MPI_BXOR:
+      op = "ompi_mpi_bxor";
+      break;
+    case mpi::MPI_OpClassEnum::MPI_MINLOC:
+      op = "ompi_mpi_minloc";
+      break;
+    case mpi::MPI_OpClassEnum::MPI_MAXLOC:
+      op = "ompi_mpi_maxloc";
+      break;
+    case mpi::MPI_OpClassEnum::MPI_REPLACE:
+      op = "ompi_mpi_replace";
+      break;
+    }
+    auto context = rewriter.getContext();
+    // get the external opaque struct type of predefined ops
+    auto opStructT =
+        LLVM::LLVMStructType::getOpaque("ompi_predefined_op_t", context);
+    // make sure global op definition exists
+    getOrDefineExternalStruct(loc, rewriter, op, opStructT);
+    // get address of symbol
+    return rewriter.create<LLVM::AddressOfOp>(
+        loc, LLVM::LLVMPointerType::get(context),
+        SymbolRefAttr::get(context, op));
}
};
diff --git a/mlir/test/Conversion/MPIToLLVM/ops.mlir b/mlir/test/Conversion/MPIToLLVM/mpitollvm.mlir
similarity index 87%
rename from mlir/test/Conversion/MPIToLLVM/ops.mlir
rename to mlir/test/Conversion/MPIToLLVM/mpitollvm.mlir
index 6917a6fa5798c..249ef195e8f5c 100644
--- a/mlir/test/Conversion/MPIToLLVM/ops.mlir
+++ b/mlir/test/Conversion/MPIToLLVM/mpitollvm.mlir
@@ -72,16 +72,16 @@ module attributes {dlti.map = #dlti.map<"MPI:Implementation" = "MPICH">} {
// CHECK: [[v47:%.*]] = llvm.inttoptr [[v46]] : i64 to !llvm.ptr
// CHECK: [[v48:%.*]] = llvm.call @MPI_Recv([[v41]], [[v43]], [[v44]], [[v12]], [[v12]], [[v45]], [[v47]]) : (!llvm.ptr, i32, i32, i32, i32, i32, !llvm.ptr) -> i32
%2 = mpi.recv(%arg0, %rank, %rank) : memref<100xf32>, i32, i32 -> !mpi.retval
-
- // CHECK-NEXT: [[v49:%.*]] = llvm.extractvalue [[v5]][1] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
- // CHECK-NEXT: [[v50:%.*]] = llvm.extractvalue [[v5]][2] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
+
+ // CHECK-NEXT: [[v49:%.*]] = llvm.extractvalue [[v5]][1] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
+ // CHECK-NEXT: [[v50:%.*]] = llvm.extractvalue [[v5]][2] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
// CHECK-NEXT: [[v51:%.*]] = llvm.getelementptr [[v49]][[[v50]]] : (!llvm.ptr, i64) -> !llvm.ptr, f32
- // CHECK-NEXT: [[v52:%.*]] = llvm.extractvalue [[v5]][3, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
+ // CHECK-NEXT: [[v52:%.*]] = llvm.extractvalue [[v5]][3, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
// CHECK-NEXT: [[v53:%.*]] = llvm.trunc [[v52]] : i64 to i32
- // CHECK-NEXT: [[v54:%.*]] = llvm.extractvalue [[v5]][1] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
- // CHECK-NEXT: [[v55:%.*]] = llvm.extractvalue [[v5]][2] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
+ // CHECK-NEXT: [[v54:%.*]] = llvm.extractvalue [[v5]][1] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
+ // CHECK-NEXT: [[v55:%.*]] = llvm.extractvalue [[v5]][2] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
// CHECK-NEXT: [[v56:%.*]] = llvm.getelementptr [[v54]][[[v55]]] : (!llvm.ptr, i64) -> !llvm.ptr, f32
- // CHECK-NEXT: [[v57:%.*]] = llvm.extractvalue [[v5]][3, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
+ // CHECK-NEXT: [[v57:%.*]] = llvm.extractvalue [[v5]][3, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
// CHECK-NEXT: [[v58:%.*]] = llvm.trunc [[v57]] : i64 to i32
// CHECK-NEXT: [[v59:%.*]] = llvm.mlir.constant(1275069450 : i32) : i32
// CHECK-NEXT: [[v60:%.*]] = llvm.mlir.constant(1476395011 : i32) : i32
@@ -172,6 +172,22 @@ module attributes { dlti.map = #dlti.map<"MPI:Implementation" = "OpenMPI"> } {
// CHECK: [[v48:%.*]] = llvm.call @MPI_Recv([[v41]], [[v43]], [[v44]], [[v12]], [[v12]], [[v45]], [[v47]]) : (!llvm.ptr, i32, !llvm.ptr, i32, i32, !llvm.ptr, !llvm.ptr) -> i32
%2 = mpi.recv(%arg0, %rank, %rank) : memref<100xf32>, i32, i32 -> !mpi.retval
+ // CHECK-NEXT: [[v49:%.*]] = llvm.extractvalue [[v5]][1] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
+ // CHECK-NEXT: [[v50:%.*]] = llvm.extractvalue [[v5]][2] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
+ // CHECK-NEXT: [[v51:%.*]] = llvm.getelementptr [[v49]][[[v50]]] : (!llvm.ptr, i64) -> !llvm.ptr, f32
+ // CHECK-NEXT: [[v52:%.*]] = llvm.extractvalue [[v5]][3, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
+ // CHECK-NEXT: [[v53:%.*]] = llvm.trunc [[v52]] : i64 to i32
+ // CHECK-NEXT: [[v54:%.*]] = llvm.extractvalue [[v5]][1] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
+ // CHECK-NEXT: [[v55:%.*]] = llvm.extractvalue [[v5]][2] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
+ // CHECK-NEXT: [[v56:%.*]] = llvm.getelementptr [[v54]][[[v55]]] : (!llvm.ptr, i64) -> !llvm.ptr, f32
+ // CHECK-NEXT: [[v57:%.*]] = llvm.extractvalue [[v5]][3, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
+ // CHECK-NEXT: [[v58:%.*]] = llvm.trunc [[v57]] : i64 to i32
+ // CHECK-NEXT: [[v59:%.*]] = llvm.mlir.addressof @ompi_mpi_float : !llvm.ptr
+ // CHECK-NEXT: [[v60:%.*]] = llvm.mlir.addressof @ompi_mpi_sum : !llvm.ptr
+ // CHECK-NEXT: [[v61:%.*]] = llvm.mlir.addressof @ompi_mpi_comm_world : !llvm.ptr
+ // CHECK-NEXT: [[v62:%.*]] = llvm.call @MPI_Allreduce([[v51]], [[v56]], [[v53]], [[v59]], [[v60]], [[v61]]) : (!llvm.ptr, !llvm.ptr, i32, !llvm.ptr, !llvm.ptr, !llvm.ptr) -> i32
+ mpi.allreduce(%arg0, %arg0, MPI_SUM) : memref<100xf32>, memref<100xf32>
+
// CHECK: [[v49:%.*]] = llvm.call @MPI_Finalize() : () -> i32
%3 = mpi.finalize : !mpi.retval
diff --git a/mlir/test/Dialect/MPI/ops.mlir b/mlir/test/Dialect/MPI/mpiops.mlir
similarity index 93%
rename from mlir/test/Dialect/MPI/ops.mlir
rename to mlir/test/Dialect/MPI/mpiops.mlir
index 30e0032cde508..fb4333611a246 100644
--- a/mlir/test/Dialect/MPI/ops.mlir
+++ b/mlir/test/Dialect/MPI/mpiops.mlir
@@ -48,10 +48,10 @@ func.func @mpi_test(%ref : memref<100xf32>) -> () {
// CHECK-NEXT: %5 = mpi.barrier : !mpi.retval
%err7 = mpi.barrier : !mpi.retval
- // CHECK-NEXT: mpi.allreduce(%arg0, %arg0, <MPI_SUM>) : memref<100xf32>, memref<100xf32>
+ // CHECK-NEXT: mpi.allreduce(%arg0, %arg0, MPI_SUM) : memref<100xf32>, memref<100xf32>
mpi.allreduce(%ref, %ref, MPI_SUM) : memref<100xf32>, memref<100xf32>
- // CHECK-NEXT: mpi.allreduce(%arg0, %arg0, <MPI_SUM>) : memref<100xf32>, memref<100xf32> -> !mpi.retval
+ // CHECK-NEXT: mpi.allreduce(%arg0, %arg0, MPI_SUM) : memref<100xf32>, memref<100xf32> -> !mpi.retval
%err8 = mpi.allreduce(%ref, %ref, MPI_SUM) : memref<100xf32>, memref<100xf32> -> !mpi.retval
// CHECK-NEXT: %7 = mpi.finalize : !mpi.retval
>From 6d293c5f8ab258712134cbff674ad3a680b14d14 Mon Sep 17 00:00:00 2001
From: "Schlimbach, Frank" <frank.schlimbach at intel.com>
Date: Wed, 26 Mar 2025 18:53:06 +0100
Subject: [PATCH 3/4] cleanup
---
mlir/include/mlir/Dialect/MPI/IR/MPI.td | 5 -----
1 file changed, 5 deletions(-)
diff --git a/mlir/include/mlir/Dialect/MPI/IR/MPI.td b/mlir/include/mlir/Dialect/MPI/IR/MPI.td
index 17ff18c3f7d3c..f2837e71df060 100644
--- a/mlir/include/mlir/Dialect/MPI/IR/MPI.td
+++ b/mlir/include/mlir/Dialect/MPI/IR/MPI.td
@@ -246,12 +246,7 @@ def MPI_OpClassEnum : I32EnumAttr<"MPI_OpClassEnum", "MPI operation class", [
MPI_OpMaxloc,
MPI_OpReplace
]> {
-// let genSpecializedAttr = 0;
let cppNamespace = "::mlir::mpi";
}
-// def MPI_OpClassAttr : EnumAttr<MPI_Dialect, MPI_OpClassEnum, "opclass"> {
-// let assemblyFormat = "`<` $value `>`";
-// }
-
#endif // MLIR_DIALECT_MPI_IR_MPI_TD
>From 838514cffb5759dd23b1a99b04c7b111cd9ab0d2 Mon Sep 17 00:00:00 2001
From: "Schlimbach, Frank" <frank.schlimbach at intel.com>
Date: Fri, 28 Mar 2025 13:06:46 +0100
Subject: [PATCH 4/4] adding missing check
---
mlir/test/Conversion/MPIToLLVM/mpitollvm.mlir | 55 ++++++++++---------
1 file changed, 28 insertions(+), 27 deletions(-)
diff --git a/mlir/test/Conversion/MPIToLLVM/mpitollvm.mlir b/mlir/test/Conversion/MPIToLLVM/mpitollvm.mlir
index 249ef195e8f5c..6b2b7c94b098b 100644
--- a/mlir/test/Conversion/MPIToLLVM/mpitollvm.mlir
+++ b/mlir/test/Conversion/MPIToLLVM/mpitollvm.mlir
@@ -73,19 +73,20 @@ module attributes {dlti.map = #dlti.map<"MPI:Implementation" = "MPICH">} {
// CHECK: [[v48:%.*]] = llvm.call @MPI_Recv([[v41]], [[v43]], [[v44]], [[v12]], [[v12]], [[v45]], [[v47]]) : (!llvm.ptr, i32, i32, i32, i32, i32, !llvm.ptr) -> i32
%2 = mpi.recv(%arg0, %rank, %rank) : memref<100xf32>, i32, i32 -> !mpi.retval
- // CHECK-NEXT: [[v49:%.*]] = llvm.extractvalue [[v5]][1] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
- // CHECK-NEXT: [[v50:%.*]] = llvm.extractvalue [[v5]][2] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
- // CHECK-NEXT: [[v51:%.*]] = llvm.getelementptr [[v49]][[[v50]]] : (!llvm.ptr, i64) -> !llvm.ptr, f32
- // CHECK-NEXT: [[v52:%.*]] = llvm.extractvalue [[v5]][3, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
- // CHECK-NEXT: [[v53:%.*]] = llvm.trunc [[v52]] : i64 to i32
- // CHECK-NEXT: [[v54:%.*]] = llvm.extractvalue [[v5]][1] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
- // CHECK-NEXT: [[v55:%.*]] = llvm.extractvalue [[v5]][2] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
- // CHECK-NEXT: [[v56:%.*]] = llvm.getelementptr [[v54]][[[v55]]] : (!llvm.ptr, i64) -> !llvm.ptr, f32
- // CHECK-NEXT: [[v57:%.*]] = llvm.extractvalue [[v5]][3, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
- // CHECK-NEXT: [[v58:%.*]] = llvm.trunc [[v57]] : i64 to i32
- // CHECK-NEXT: [[v59:%.*]] = llvm.mlir.constant(1275069450 : i32) : i32
- // CHECK-NEXT: [[v60:%.*]] = llvm.mlir.constant(1476395011 : i32) : i32
- // CHECK-NEXT: [[v61:%.*]] = llvm.mlir.constant(1140850688 : i32) : i32
+ // CHECK: [[v49:%.*]] = llvm.extractvalue [[v5]][1] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
+ // CHECK: [[v50:%.*]] = llvm.extractvalue [[v5]][2] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
+ // CHECK: [[v51:%.*]] = llvm.getelementptr [[v49]][[[v50]]] : (!llvm.ptr, i64) -> !llvm.ptr, f32
+ // CHECK: [[v52:%.*]] = llvm.extractvalue [[v5]][3, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
+ // CHECK: [[v53:%.*]] = llvm.trunc [[v52]] : i64 to i32
+ // CHECK: [[v54:%.*]] = llvm.extractvalue [[v5]][1] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
+ // CHECK: [[v55:%.*]] = llvm.extractvalue [[v5]][2] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
+ // CHECK: [[v56:%.*]] = llvm.getelementptr [[v54]][[[v55]]] : (!llvm.ptr, i64) -> !llvm.ptr, f32
+ // CHECK: [[v57:%.*]] = llvm.extractvalue [[v5]][3, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
+ // CHECK: [[v58:%.*]] = llvm.trunc [[v57]] : i64 to i32
+ // CHECK: [[v59:%.*]] = llvm.mlir.constant(1275069450 : i32) : i32
+ // CHECK: [[v60:%.*]] = llvm.mlir.constant(1476395011 : i32) : i32
+ // CHECK: [[v61:%.*]] = llvm.mlir.constant(1140850688 : i32) : i32
+ // CHECK-NEXT: [[v62:%.*]] = llvm.call @MPI_Allreduce([[v51]], [[v56]], [[v53]], [[v59]], [[v60]], [[v61]]) : (!llvm.ptr, !llvm.ptr, i32, i32, i32, i32) -> i32
mpi.allreduce(%arg0, %arg0, MPI_SUM) : memref<100xf32>, memref<100xf32>
// CHECK: llvm.call @MPI_Finalize() : () -> i32
@@ -172,20 +173,20 @@ module attributes { dlti.map = #dlti.map<"MPI:Implementation" = "OpenMPI"> } {
// CHECK: [[v48:%.*]] = llvm.call @MPI_Recv([[v41]], [[v43]], [[v44]], [[v12]], [[v12]], [[v45]], [[v47]]) : (!llvm.ptr, i32, !llvm.ptr, i32, i32, !llvm.ptr, !llvm.ptr) -> i32
%2 = mpi.recv(%arg0, %rank, %rank) : memref<100xf32>, i32, i32 -> !mpi.retval
- // CHECK-NEXT: [[v49:%.*]] = llvm.extractvalue [[v5]][1] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
- // CHECK-NEXT: [[v50:%.*]] = llvm.extractvalue [[v5]][2] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
- // CHECK-NEXT: [[v51:%.*]] = llvm.getelementptr [[v49]][[[v50]]] : (!llvm.ptr, i64) -> !llvm.ptr, f32
- // CHECK-NEXT: [[v52:%.*]] = llvm.extractvalue [[v5]][3, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
- // CHECK-NEXT: [[v53:%.*]] = llvm.trunc [[v52]] : i64 to i32
- // CHECK-NEXT: [[v54:%.*]] = llvm.extractvalue [[v5]][1] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
- // CHECK-NEXT: [[v55:%.*]] = llvm.extractvalue [[v5]][2] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
- // CHECK-NEXT: [[v56:%.*]] = llvm.getelementptr [[v54]][[[v55]]] : (!llvm.ptr, i64) -> !llvm.ptr, f32
- // CHECK-NEXT: [[v57:%.*]] = llvm.extractvalue [[v5]][3, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
- // CHECK-NEXT: [[v58:%.*]] = llvm.trunc [[v57]] : i64 to i32
- // CHECK-NEXT: [[v59:%.*]] = llvm.mlir.addressof @ompi_mpi_float : !llvm.ptr
- // CHECK-NEXT: [[v60:%.*]] = llvm.mlir.addressof @ompi_mpi_sum : !llvm.ptr
- // CHECK-NEXT: [[v61:%.*]] = llvm.mlir.addressof @ompi_mpi_comm_world : !llvm.ptr
- // CHECK-NEXT: [[v62:%.*]] = llvm.call @MPI_Allreduce([[v51]], [[v56]], [[v53]], [[v59]], [[v60]], [[v61]]) : (!llvm.ptr, !llvm.ptr, i32, !llvm.ptr, !llvm.ptr, !llvm.ptr) -> i32
+ // CHECK: [[v49:%.*]] = llvm.extractvalue [[v5]][1] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
+ // CHECK: [[v50:%.*]] = llvm.extractvalue [[v5]][2] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
+ // CHECK: [[v51:%.*]] = llvm.getelementptr [[v49]][[[v50]]] : (!llvm.ptr, i64) -> !llvm.ptr, f32
+ // CHECK: [[v52:%.*]] = llvm.extractvalue [[v5]][3, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
+ // CHECK: [[v53:%.*]] = llvm.trunc [[v52]] : i64 to i32
+ // CHECK: [[v54:%.*]] = llvm.extractvalue [[v5]][1] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
+ // CHECK: [[v55:%.*]] = llvm.extractvalue [[v5]][2] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
+ // CHECK: [[v56:%.*]] = llvm.getelementptr [[v54]][[[v55]]] : (!llvm.ptr, i64) -> !llvm.ptr, f32
+ // CHECK: [[v57:%.*]] = llvm.extractvalue [[v5]][3, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
+ // CHECK: [[v58:%.*]] = llvm.trunc [[v57]] : i64 to i32
+ // CHECK: [[v59:%.*]] = llvm.mlir.addressof @ompi_mpi_float : !llvm.ptr
+ // CHECK: [[v60:%.*]] = llvm.mlir.addressof @ompi_mpi_sum : !llvm.ptr
+ // CHECK: [[v61:%.*]] = llvm.mlir.addressof @ompi_mpi_comm_world : !llvm.ptr
+ // CHECK: [[v62:%.*]] = llvm.call @MPI_Allreduce([[v51]], [[v56]], [[v53]], [[v59]], [[v60]], [[v61]]) : (!llvm.ptr, !llvm.ptr, i32, !llvm.ptr, !llvm.ptr, !llvm.ptr) -> i32
mpi.allreduce(%arg0, %arg0, MPI_SUM) : memref<100xf32>, memref<100xf32>
// CHECK: [[v49:%.*]] = llvm.call @MPI_Finalize() : () -> i32
More information about the Mlir-commits
mailing list