[Mlir-commits] [mlir] [mlir][mpi] Lowering Mpi To LLVM (PR #127053)
Frank Schlimbach
llvmlistbot at llvm.org
Thu Feb 13 07:27:41 PST 2025
https://github.com/fschlimb updated https://github.com/llvm/llvm-project/pull/127053
>From 4085cdfaefc62fc9ba21d9c695876c3130a60f6e Mon Sep 17 00:00:00 2001
From: Anton Lydike <me at antonlydike.de>
Date: Fri, 14 Jun 2024 11:26:26 +0100
Subject: [PATCH 1/4] add initial set of lowerings for MPI dialect
---
.../mlir/Conversion/MPIToLLVM/MPIToLLVM.h | 30 +++
mlir/include/mlir/Dialect/MPI/IR/MPITypes.td | 2 +-
mlir/include/mlir/InitAllExtensions.h | 2 +
mlir/lib/Conversion/CMakeLists.txt | 1 +
mlir/lib/Conversion/MPIToLLVM/CMakeLists.txt | 17 ++
mlir/lib/Conversion/MPIToLLVM/MPIToLLVM.cpp | 230 ++++++++++++++++++
mlir/test/Conversion/MPIToLLVM/ops.mlir | 40 +++
7 files changed, 321 insertions(+), 1 deletion(-)
create mode 100644 mlir/include/mlir/Conversion/MPIToLLVM/MPIToLLVM.h
create mode 100644 mlir/lib/Conversion/MPIToLLVM/CMakeLists.txt
create mode 100644 mlir/lib/Conversion/MPIToLLVM/MPIToLLVM.cpp
create mode 100644 mlir/test/Conversion/MPIToLLVM/ops.mlir
diff --git a/mlir/include/mlir/Conversion/MPIToLLVM/MPIToLLVM.h b/mlir/include/mlir/Conversion/MPIToLLVM/MPIToLLVM.h
new file mode 100644
index 0000000000000..940e5e8097318
--- /dev/null
+++ b/mlir/include/mlir/Conversion/MPIToLLVM/MPIToLLVM.h
@@ -0,0 +1,30 @@
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_CONVERSION_MPITOLLVM_H
+#define MLIR_CONVERSION_MPITOLLVM_H
+
+#include "mlir/IR/DialectRegistry.h"
+
+namespace mlir {
+
+class LLVMTypeConverter;
+class RewritePatternSet;
+
+#define GEN_PASS_DECL_MPITOLLVMCONVERSIONPASS
+#include "mlir/Conversion/Passes.h.inc"
+
+namespace mpi {
+/// Populates `patterns` with the MPI-to-LLVM lowering patterns, each of them
+/// constructed with `converter`.
+void populateMPIToLLVMConversionPatterns(LLVMTypeConverter &converter,
+                                         RewritePatternSet &patterns);
+
+/// Registers, on `registry`, the extension that attaches the ConvertToLLVM
+/// pattern interface to the MPI dialect.
+void registerConvertMPIToLLVMInterface(DialectRegistry &registry);
+
+} // namespace mpi
+} // namespace mlir
+
+#endif // MLIR_CONVERSION_MPITOLLVM_H
diff --git a/mlir/include/mlir/Dialect/MPI/IR/MPITypes.td b/mlir/include/mlir/Dialect/MPI/IR/MPITypes.td
index fafea0eac8bb7..a55d30e778e22 100644
--- a/mlir/include/mlir/Dialect/MPI/IR/MPITypes.td
+++ b/mlir/include/mlir/Dialect/MPI/IR/MPITypes.td
@@ -30,7 +30,7 @@ class MPI_Type<string name, string typeMnemonic, list<Trait> traits = []>
//===----------------------------------------------------------------------===//
def MPI_Retval : MPI_Type<"Retval", "retval"> {
- let summary = "MPI function call return value";
+ let summary = "MPI function call return value (!mpi.retval)";
let description = [{
This type represents a return value from an MPI function call.
This value can be MPI_SUCCESS, MPI_ERR_IN_STATUS, or any error code.
diff --git a/mlir/include/mlir/InitAllExtensions.h b/mlir/include/mlir/InitAllExtensions.h
index 887db344ed88b..6ab23ff86b3c6 100644
--- a/mlir/include/mlir/InitAllExtensions.h
+++ b/mlir/include/mlir/InitAllExtensions.h
@@ -14,6 +14,7 @@
#ifndef MLIR_INITALLEXTENSIONS_H_
#define MLIR_INITALLEXTENSIONS_H_
+#include "Conversion/MPIToLLVM/MPIToLLVM.h"
#include "mlir/Conversion/ArithToLLVM/ArithToLLVM.h"
#include "mlir/Conversion/ComplexToLLVM/ComplexToLLVM.h"
#include "mlir/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.h"
@@ -70,6 +71,7 @@ inline void registerAllExtensions(DialectRegistry &registry) {
registerConvertFuncToLLVMInterface(registry);
index::registerConvertIndexToLLVMInterface(registry);
registerConvertMathToLLVMInterface(registry);
+ mpi::registerConvertMPIToLLVMInterface(registry);
registerConvertMemRefToLLVMInterface(registry);
registerConvertNVVMToLLVMInterface(registry);
registerConvertOpenMPToLLVMInterface(registry);
diff --git a/mlir/lib/Conversion/CMakeLists.txt b/mlir/lib/Conversion/CMakeLists.txt
index 0bd08ec6333e6..3dc7472584cf9 100644
--- a/mlir/lib/Conversion/CMakeLists.txt
+++ b/mlir/lib/Conversion/CMakeLists.txt
@@ -41,6 +41,7 @@ add_subdirectory(MemRefToEmitC)
add_subdirectory(MemRefToLLVM)
add_subdirectory(MemRefToSPIRV)
add_subdirectory(MeshToMPI)
+add_subdirectory(MPIToLLVM)
add_subdirectory(NVGPUToNVVM)
add_subdirectory(NVVMToLLVM)
add_subdirectory(OpenACCToSCF)
diff --git a/mlir/lib/Conversion/MPIToLLVM/CMakeLists.txt b/mlir/lib/Conversion/MPIToLLVM/CMakeLists.txt
new file mode 100644
index 0000000000000..f81fb25e56840
--- /dev/null
+++ b/mlir/lib/Conversion/MPIToLLVM/CMakeLists.txt
@@ -0,0 +1,17 @@
+add_mlir_conversion_library(MLIRMPIToLLVM
+ MPIToLLVM.cpp
+
+ ADDITIONAL_HEADER_DIRS
+ ${MLIR_MAIN_INCLUDE_DIR}/mlir/Conversion/MPIToLLVM
+
+ DEPENDS
+ MLIRConversionPassIncGen
+
+ LINK_COMPONENTS
+ Core
+
+ LINK_LIBS PUBLIC
+ MLIRLLVMCommonConversion
+ MLIRLLVMDialect
+ MLIRMPIDialect
+ )
diff --git a/mlir/lib/Conversion/MPIToLLVM/MPIToLLVM.cpp b/mlir/lib/Conversion/MPIToLLVM/MPIToLLVM.cpp
new file mode 100644
index 0000000000000..d87a10aab8f49
--- /dev/null
+++ b/mlir/lib/Conversion/MPIToLLVM/MPIToLLVM.cpp
@@ -0,0 +1,230 @@
+//===- MPIToLLVM.cpp - MPI to LLVM dialect conversion ---------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Conversion/MPIToLLVM/MPIToLLVM.h"
+
+#include "mlir/Conversion/LLVMCommon/Pattern.h"
+#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
+#include "mlir/Dialect/MPI/IR/MPI.h"
+#include "mlir/Pass/Pass.h"
+
+#include <mlir/Conversion/ConvertToLLVM/ToLLVMInterface.h>
+
+using namespace mlir;
+
+namespace {
+
+/// Conversion pattern for `mpi.init`; the rewrite (a call to `MPI_Init`) is
+/// defined out of line further down in this file.
+struct InitOpLowering : public ConvertOpToLLVMPattern<mpi::InitOp> {
+  // Inherit the base-class constructors.
+  using ConvertOpToLLVMPattern<mpi::InitOp>::ConvertOpToLLVMPattern;
+
+  LogicalResult
+  matchAndRewrite(mpi::InitOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override;
+};
+
+/// Conversion pattern for `mpi.comm_rank`; the rewrite (a call to
+/// `MPI_Comm_rank`) is defined out of line further down in this file.
+struct CommRankOpLowering : public ConvertOpToLLVMPattern<mpi::CommRankOp> {
+  // Inherit the base-class constructors.
+  using ConvertOpToLLVMPattern<mpi::CommRankOp>::ConvertOpToLLVMPattern;
+
+  LogicalResult
+  matchAndRewrite(mpi::CommRankOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override;
+};
+
+/// Conversion pattern for `mpi.finalize`; the rewrite (a call to
+/// `MPI_Finalize`) is defined out of line further down in this file.
+struct FinalizeOpLowering : public ConvertOpToLLVMPattern<mpi::FinalizeOp> {
+  // Inherit the base-class constructors.
+  using ConvertOpToLLVMPattern<mpi::FinalizeOp>::ConvertOpToLLVMPattern;
+
+  LogicalResult
+  matchAndRewrite(mpi::FinalizeOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override;
+};
+
+// TODO: this was copied from GPUOpsLowering.cpp:288
+// is this okay, or should this be moved to some common file?
+/// Returns the `llvm.func` symbol `name` of `moduleOp`. If the module has no
+/// such symbol yet, an externally linked function of type `type` is created
+/// at the start of the module body and returned instead.
+LLVM::LLVMFuncOp getOrDefineFunction(ModuleOp &moduleOp, const Location loc,
+                                     ConversionPatternRewriter &rewriter,
+                                     StringRef name,
+                                     LLVM::LLVMFunctionType type) {
+  // Fast path: the function is already known to the module.
+  if (auto existingFunc = moduleOp.lookupSymbol<LLVM::LLVMFuncOp>(name))
+    return existingFunc;
+
+  // The guard scopes the insertion-point change to this function.
+  ConversionPatternRewriter::InsertionGuard insertionScope(rewriter);
+  rewriter.setInsertionPointToStart(moduleOp.getBody());
+  return rewriter.create<LLVM::LLVMFuncOp>(loc, name, type,
+                                           LLVM::Linkage::External);
+}
+
+// TODO: this is pretty close to getOrDefineFunction, can probably be factored
+/// Returns the `llvm.mlir.global` symbol `name` of `moduleOp`. If the module
+/// has no such symbol yet, a non-constant, externally linked global of struct
+/// type `type` without initializer is created at the start of the module body
+/// and returned instead.
+LLVM::GlobalOp getOrDefineExternalStruct(ModuleOp &moduleOp, const Location loc,
+                                         ConversionPatternRewriter &rewriter,
+                                         StringRef name,
+                                         LLVM::LLVMStructType type) {
+  // Fast path: the global is already known to the module.
+  if (auto existingGlobal = moduleOp.lookupSymbol<LLVM::GlobalOp>(name))
+    return existingGlobal;
+
+  // The guard scopes the insertion-point change to this function.
+  ConversionPatternRewriter::InsertionGuard insertionScope(rewriter);
+  rewriter.setInsertionPointToStart(moduleOp.getBody());
+  return rewriter.create<LLVM::GlobalOp>(
+      loc, type, /*isConstant=*/false, LLVM::Linkage::External, name,
+      /*value=*/Attribute(), /*alignment=*/0, /*addrSpace=*/0);
+}
+
+} // namespace
+
+//===----------------------------------------------------------------------===//
+// InitOpLowering
+//===----------------------------------------------------------------------===//
+
+/// Rewrites `mpi.init` into `llvm.call @MPI_Init(null, null)`, declaring
+/// `MPI_Init` in the enclosing module when it is not present yet. The i32
+/// result of the call replaces the op.
+LogicalResult
+InitOpLowering::matchAndRewrite(mpi::InitOp op, OpAdaptor adaptor,
+                                ConversionPatternRewriter &rewriter) const {
+  Location loc = op.getLoc();
+
+  // Both MPI_Init arguments are passed as a null `!llvm.ptr`:
+  // `%nullptr = llvm.mlir.zero : !llvm.ptr`
+  Type ptrTy = LLVM::LLVMPointerType::get(rewriter.getContext());
+  auto zeroOp = rewriter.create<LLVM::ZeroOp>(loc, ptrTy);
+  Value nullPtr = zeroOp.getRes();
+
+  // The declaration lives in the module that encloses the op.
+  auto moduleOp = op->getParentOfType<ModuleOp>();
+
+  // Signature `i32 MPI_Init(ptr, ptr)`; reuse or create its declaration.
+  auto mpiInitTy =
+      LLVM::LLVMFunctionType::get(rewriter.getI32Type(), {ptrTy, ptrTy});
+  LLVM::LLVMFuncOp mpiInitFn =
+      getOrDefineFunction(moduleOp, loc, rewriter, "MPI_Init", mpiInitTy);
+
+  // Swap the op for the call.
+  rewriter.replaceOpWithNewOp<LLVM::CallOp>(op, mpiInitFn,
+                                            ValueRange{nullPtr, nullPtr});
+  return success();
+}
+
+//===----------------------------------------------------------------------===//
+// FinalizeOpLowering
+//===----------------------------------------------------------------------===//
+
+/// Rewrites `mpi.finalize` into `llvm.call @MPI_Finalize()`, declaring
+/// `MPI_Finalize` in the enclosing module when it is not present yet. The i32
+/// result of the call replaces the op.
+LogicalResult
+FinalizeOpLowering::matchAndRewrite(mpi::FinalizeOp op, OpAdaptor adaptor,
+                                    ConversionPatternRewriter &rewriter) const {
+  Location loc = op.getLoc();
+
+  // The declaration lives in the module that encloses the op.
+  auto moduleOp = op->getParentOfType<ModuleOp>();
+
+  // Signature `i32 MPI_Finalize()`; reuse or create its declaration.
+  auto finalizeTy = LLVM::LLVMFunctionType::get(rewriter.getI32Type(), {});
+  LLVM::LLVMFuncOp finalizeFn =
+      getOrDefineFunction(moduleOp, loc, rewriter, "MPI_Finalize", finalizeTy);
+
+  // Swap the op for the (argument-less) call.
+  rewriter.replaceOpWithNewOp<LLVM::CallOp>(op, finalizeFn, ValueRange{});
+  return success();
+}
+
+//===----------------------------------------------------------------------===//
+// CommRankLowering
+//===----------------------------------------------------------------------===//
+
+/// Rewrites `mpi.comm_rank` into a call to `MPI_Comm_rank` on the external
+/// global `@MPI_COMM_WORLD`. The rank is received through a stack slot and
+/// loaded back; the op is replaced by the call result (only when the op has a
+/// retval) followed by the loaded rank.
+LogicalResult
+CommRankOpLowering::matchAndRewrite(mpi::CommRankOp op, OpAdaptor adaptor,
+                                    ConversionPatternRewriter &rewriter) const {
+  // Shorthands used throughout the rewrite.
+  Location loc = op.getLoc();
+  MLIRContext *ctx = rewriter.getContext();
+  auto i32Ty = rewriter.getI32Type();
+  Type ptrTy = LLVM::LLVMPointerType::get(ctx);
+  StringRef commWorldName = "MPI_COMM_WORLD";
+
+  // The communicator is modelled as an external global of an opaque struct
+  // type; make sure the enclosing module declares it.
+  auto commTy = LLVM::LLVMStructType::getOpaque("MPI_ABI_Comm", ctx);
+  auto moduleOp = op->getParentOfType<ModuleOp>();
+  getOrDefineExternalStruct(moduleOp, loc, rewriter, commWorldName, commTy);
+
+  // Stack slot for a single i32 that is handed to MPI_Comm_rank as its
+  // second argument.
+  auto oneCst = rewriter.create<LLVM::ConstantOp>(loc, i32Ty, 1);
+  auto rankSlot = rewriter.create<LLVM::AllocaOp>(loc, ptrTy, i32Ty, oneCst);
+
+  // Address of @MPI_COMM_WORLD.
+  auto commWorldAddr = rewriter.create<LLVM::AddressOfOp>(
+      loc, ptrTy, SymbolRefAttr::get(ctx, commWorldName));
+
+  // Signature `i32 MPI_Comm_rank(ptr, ptr)`; reuse or create its declaration.
+  auto rankFnTy = LLVM::LLVMFunctionType::get(i32Ty, {ptrTy, ptrTy});
+  LLVM::LLVMFuncOp rankFn = getOrDefineFunction(moduleOp, loc, rewriter,
+                                                "MPI_Comm_rank", rankFnTy);
+
+  // Emit the call, then read the rank back out of the stack slot.
+  auto rankCall = rewriter.create<LLVM::CallOp>(
+      loc, rankFn, ValueRange{commWorldAddr.getRes(), rankSlot.getRes()});
+  auto rankValue =
+      rewriter.create<LLVM::LoadOp>(loc, i32Ty, rankSlot.getResult());
+
+  // Replacement values in result order: the optional retval first, then the
+  // rank.
+  SmallVector<Value> newResults;
+  if (op.getRetval())
+    newResults.push_back(rankCall.getResult());
+  newResults.push_back(rankValue.getRes());
+  rewriter.replaceOp(op, newResults);
+
+  return success();
+}
+
+//===----------------------------------------------------------------------===//
+// Pattern Population
+//===----------------------------------------------------------------------===//
+
+/// Adds the init, comm_rank and finalize lowering patterns to `patterns`,
+/// each constructed with `converter`.
+void mpi::populateMPIToLLVMConversionPatterns(LLVMTypeConverter &converter,
+                                              RewritePatternSet &patterns) {
+  // A single variadic `add` registers the patterns in the listed order.
+  patterns.add<InitOpLowering, CommRankOpLowering, FinalizeOpLowering>(
+      converter);
+}
+
+//===----------------------------------------------------------------------===//
+// ConvertToLLVMPatternInterface implementation
+//===----------------------------------------------------------------------===//
+
+namespace {
+/// Implement the interface to convert MPI to LLVM (name is a Func leftover).
+struct FuncToLLVMDialectInterface : public ConvertToLLVMPatternInterface {
+ using ConvertToLLVMPatternInterface::ConvertToLLVMPatternInterface;
+ /// Hook that contributes the MPI-to-LLVM conversion patterns to `patterns`;
+ /// `target` is left untouched by this implementation.
+ void populateConvertToLLVMConversionPatterns(
+ ConversionTarget &target, LLVMTypeConverter &typeConverter,
+ RewritePatternSet &patterns) const final {
+ mpi::populateMPIToLLVMConversionPatterns(typeConverter, patterns);
+ }
+};
+} // namespace
+
+/// Registers an extension on `registry` that attaches the ConvertToLLVM
+/// pattern interface defined above to the MPI dialect.
+void mpi::registerConvertMPIToLLVMInterface(DialectRegistry &registry) {
+  // The unary `+` turns the capture-less lambda into a plain function pointer.
+  registry.addExtension(+[](MLIRContext *ctx, mpi::MPIDialect *dialect) {
+    dialect->addInterfaces<FuncToLLVMDialectInterface>();
+  });
+}
diff --git a/mlir/test/Conversion/MPIToLLVM/ops.mlir b/mlir/test/Conversion/MPIToLLVM/ops.mlir
new file mode 100644
index 0000000000000..a7a44ad24909a
--- /dev/null
+++ b/mlir/test/Conversion/MPIToLLVM/ops.mlir
@@ -0,0 +1,40 @@
+// RUN: mlir-opt -convert-to-llvm %s | FileCheck %s
+
+module {
+// CHECK: llvm.func @MPI_Finalize() -> i32
+// CHECK: llvm.func @MPI_Comm_rank(!llvm.ptr, !llvm.ptr) -> i32
+// CHECK: llvm.mlir.global external @MPI_COMM_WORLD() {addr_space = 0 : i32} : !llvm.struct<"MPI_ABI_Comm", opaque>
+// CHECK: llvm.func @MPI_Init(!llvm.ptr, !llvm.ptr) -> i32
+
+ func.func @mpi_test(%arg0: memref<100xf32>) {
+ %0 = mpi.init : !mpi.retval
+// CHECK: %7 = llvm.mlir.zero : !llvm.ptr
+// CHECK-NEXT: %8 = llvm.call @MPI_Init(%7, %7) : (!llvm.ptr, !llvm.ptr) -> i32
+// CHECK-NEXT: %9 = builtin.unrealized_conversion_cast %8 : i32 to !mpi.retval
+
+
+ %retval, %rank = mpi.comm_rank : !mpi.retval, i32
+// CHECK: %10 = llvm.mlir.constant(1 : i32) : i32
+// CHECK-NEXT: %11 = llvm.alloca %10 x i32 : (i32) -> !llvm.ptr
+// CHECK-NEXT: %12 = llvm.mlir.addressof @MPI_COMM_WORLD : !llvm.ptr
+// CHECK-NEXT: %13 = llvm.call @MPI_Comm_rank(%12, %11) : (!llvm.ptr, !llvm.ptr) -> i32
+// CHECK-NEXT: %14 = llvm.load %11 : !llvm.ptr -> i32
+// CHECK-NEXT: %15 = builtin.unrealized_conversion_cast %13 : i32 to !mpi.retval
+
+ mpi.send(%arg0, %rank, %rank) : memref<100xf32>, i32, i32
+
+ %1 = mpi.send(%arg0, %rank, %rank) : memref<100xf32>, i32, i32 -> !mpi.retval
+
+ mpi.recv(%arg0, %rank, %rank) : memref<100xf32>, i32, i32
+
+ %2 = mpi.recv(%arg0, %rank, %rank) : memref<100xf32>, i32, i32 -> !mpi.retval
+
+ %3 = mpi.finalize : !mpi.retval
+// CHECK: %18 = llvm.call @MPI_Finalize() : () -> i32
+
+ %4 = mpi.retval_check %retval = <MPI_SUCCESS> : i1
+
+ %5 = mpi.error_class %0 : !mpi.retval
+ return
+ }
+}
>From 64d9500519f0cc8f162b978d937c3bc73b175fc8 Mon Sep 17 00:00:00 2001
From: "Schlimbach, Frank" <frank.schlimbach at intel.com>
Date: Tue, 19 Nov 2024 18:02:18 +0100
Subject: [PATCH 2/4] lowering MPI_send and MPI_recv; some refactoring fixing
MPI_Send/Recv signature: added size fixing MPIR ops.mlir tests MPI: fixing
enum names and merge conflicts separate MPIImplTraits.h
---
.../mlir/Conversion/MPIToLLVM/MPIToLLVM.h | 1 +
mlir/include/mlir/Dialect/MPI/IR/MPI.td | 154 +++----
mlir/include/mlir/Dialect/MPI/IR/MPIOps.td | 16 +-
mlir/lib/Conversion/MPIToLLVM/CMakeLists.txt | 3 +
mlir/lib/Conversion/MPIToLLVM/MPIImplTraits.h | 119 ++++++
mlir/lib/Conversion/MPIToLLVM/MPIToLLVM.cpp | 393 +++++++++++-------
mlir/test/Conversion/MPIToLLVM/ops.mlir | 74 +++-
7 files changed, 512 insertions(+), 248 deletions(-)
create mode 100644 mlir/lib/Conversion/MPIToLLVM/MPIImplTraits.h
diff --git a/mlir/include/mlir/Conversion/MPIToLLVM/MPIToLLVM.h b/mlir/include/mlir/Conversion/MPIToLLVM/MPIToLLVM.h
index 940e5e8097318..8d2698aa91c7c 100644
--- a/mlir/include/mlir/Conversion/MPIToLLVM/MPIToLLVM.h
+++ b/mlir/include/mlir/Conversion/MPIToLLVM/MPIToLLVM.h
@@ -19,6 +19,7 @@ class RewritePatternSet;
#include "mlir/Conversion/Passes.h.inc"
namespace mpi {
+
void populateMPIToLLVMConversionPatterns(LLVMTypeConverter &converter,
RewritePatternSet &patterns);
diff --git a/mlir/include/mlir/Dialect/MPI/IR/MPI.td b/mlir/include/mlir/Dialect/MPI/IR/MPI.td
index 7c84443e5520d..df0cf9d518faf 100644
--- a/mlir/include/mlir/Dialect/MPI/IR/MPI.td
+++ b/mlir/include/mlir/Dialect/MPI/IR/MPI.td
@@ -42,104 +42,104 @@ def MPI_Dialect : Dialect {
// Error classes enum:
//===----------------------------------------------------------------------===//
-def MPI_CodeSuccess : I32EnumAttrCase<"MPI_SUCCESS", 0, "MPI_SUCCESS">;
-def MPI_CodeErrAccess : I32EnumAttrCase<"MPI_ERR_ACCESS", 1, "MPI_ERR_ACCESS">;
-def MPI_CodeErrAmode : I32EnumAttrCase<"MPI_ERR_AMODE", 2, "MPI_ERR_AMODE">;
-def MPI_CodeErrArg : I32EnumAttrCase<"MPI_ERR_ARG", 3, "MPI_ERR_ARG">;
-def MPI_CodeErrAssert : I32EnumAttrCase<"MPI_ERR_ASSERT", 4, "MPI_ERR_ASSERT">;
+def MPI_CodeSuccess : I32EnumAttrCase<"_MPI_SUCCESS", 0, "MPI_SUCCESS">;
+def MPI_CodeErrAccess : I32EnumAttrCase<"_MPI_ERR_ACCESS", 1, "MPI_ERR_ACCESS">;
+def MPI_CodeErrAmode : I32EnumAttrCase<"_MPI_ERR_AMODE", 2, "MPI_ERR_AMODE">;
+def MPI_CodeErrArg : I32EnumAttrCase<"_MPI_ERR_ARG", 3, "MPI_ERR_ARG">;
+def MPI_CodeErrAssert : I32EnumAttrCase<"_MPI_ERR_ASSERT", 4, "MPI_ERR_ASSERT">;
def MPI_CodeErrBadFile
- : I32EnumAttrCase<"MPI_ERR_BAD_FILE", 5, "MPI_ERR_BAD_FILE">;
-def MPI_CodeErrBase : I32EnumAttrCase<"MPI_ERR_BASE", 6, "MPI_ERR_BASE">;
-def MPI_CodeErrBuffer : I32EnumAttrCase<"MPI_ERR_BUFFER", 7, "MPI_ERR_BUFFER">;
-def MPI_CodeErrComm : I32EnumAttrCase<"MPI_ERR_COMM", 8, "MPI_ERR_COMM">;
+ : I32EnumAttrCase<"_MPI_ERR_BAD_FILE", 5, "MPI_ERR_BAD_FILE">;
+def MPI_CodeErrBase : I32EnumAttrCase<"_MPI_ERR_BASE", 6, "MPI_ERR_BASE">;
+def MPI_CodeErrBuffer : I32EnumAttrCase<"_MPI_ERR_BUFFER", 7, "MPI_ERR_BUFFER">;
+def MPI_CodeErrComm : I32EnumAttrCase<"_MPI_ERR_COMM", 8, "MPI_ERR_COMM">;
def MPI_CodeErrConversion
- : I32EnumAttrCase<"MPI_ERR_CONVERSION", 9, "MPI_ERR_CONVERSION">;
-def MPI_CodeErrCount : I32EnumAttrCase<"MPI_ERR_COUNT", 10, "MPI_ERR_COUNT">;
-def MPI_CodeErrDims : I32EnumAttrCase<"MPI_ERR_DIMS", 11, "MPI_ERR_DIMS">;
-def MPI_CodeErrDisp : I32EnumAttrCase<"MPI_ERR_DISP", 12, "MPI_ERR_DISP">;
+ : I32EnumAttrCase<"_MPI_ERR_CONVERSION", 9, "MPI_ERR_CONVERSION">;
+def MPI_CodeErrCount : I32EnumAttrCase<"_MPI_ERR_COUNT", 10, "MPI_ERR_COUNT">;
+def MPI_CodeErrDims : I32EnumAttrCase<"_MPI_ERR_DIMS", 11, "MPI_ERR_DIMS">;
+def MPI_CodeErrDisp : I32EnumAttrCase<"_MPI_ERR_DISP", 12, "MPI_ERR_DISP">;
def MPI_CodeErrDupDatarep
- : I32EnumAttrCase<"MPI_ERR_DUP_DATAREP", 13, "MPI_ERR_DUP_DATAREP">;
+ : I32EnumAttrCase<"_MPI_ERR_DUP_DATAREP", 13, "MPI_ERR_DUP_DATAREP">;
def MPI_CodeErrErrhandler
- : I32EnumAttrCase<"MPI_ERR_ERRHANDLER", 14, "MPI_ERR_ERRHANDLER">;
-def MPI_CodeErrFile : I32EnumAttrCase<"MPI_ERR_FILE", 15, "MPI_ERR_FILE">;
+ : I32EnumAttrCase<"_MPI_ERR_ERRHANDLER", 14, "MPI_ERR_ERRHANDLER">;
+def MPI_CodeErrFile : I32EnumAttrCase<"_MPI_ERR_FILE", 15, "MPI_ERR_FILE">;
def MPI_CodeErrFileExists
- : I32EnumAttrCase<"MPI_ERR_FILE_EXISTS", 16, "MPI_ERR_FILE_EXISTS">;
+ : I32EnumAttrCase<"_MPI_ERR_FILE_EXISTS", 16, "MPI_ERR_FILE_EXISTS">;
def MPI_CodeErrFileInUse
- : I32EnumAttrCase<"MPI_ERR_FILE_IN_USE", 17, "MPI_ERR_FILE_IN_USE">;
-def MPI_CodeErrGroup : I32EnumAttrCase<"MPI_ERR_GROUP", 18, "MPI_ERR_GROUP">;
-def MPI_CodeErrInfo : I32EnumAttrCase<"MPI_ERR_INFO", 19, "MPI_ERR_INFO">;
+ : I32EnumAttrCase<"_MPI_ERR_FILE_IN_USE", 17, "MPI_ERR_FILE_IN_USE">;
+def MPI_CodeErrGroup : I32EnumAttrCase<"_MPI_ERR_GROUP", 18, "MPI_ERR_GROUP">;
+def MPI_CodeErrInfo : I32EnumAttrCase<"_MPI_ERR_INFO", 19, "MPI_ERR_INFO">;
def MPI_CodeErrInfoKey
- : I32EnumAttrCase<"MPI_ERR_INFO_KEY", 20, "MPI_ERR_INFO_KEY">;
+ : I32EnumAttrCase<"_MPI_ERR_INFO_KEY", 20, "MPI_ERR_INFO_KEY">;
def MPI_CodeErrInfoNokey
- : I32EnumAttrCase<"MPI_ERR_INFO_NOKEY", 21, "MPI_ERR_INFO_NOKEY">;
+ : I32EnumAttrCase<"_MPI_ERR_INFO_NOKEY", 21, "MPI_ERR_INFO_NOKEY">;
def MPI_CodeErrInfoValue
- : I32EnumAttrCase<"MPI_ERR_INFO_VALUE", 22, "MPI_ERR_INFO_VALUE">;
+ : I32EnumAttrCase<"_MPI_ERR_INFO_VALUE", 22, "MPI_ERR_INFO_VALUE">;
def MPI_CodeErrInStatus
- : I32EnumAttrCase<"MPI_ERR_IN_STATUS", 23, "MPI_ERR_IN_STATUS">;
-def MPI_CodeErrIntern : I32EnumAttrCase<"MPI_ERR_INTERN", 24, "MPI_ERR_INTERN">;
-def MPI_CodeErrIo : I32EnumAttrCase<"MPI_ERR_IO", 25, "MPI_ERR_IO">;
-def MPI_CodeErrKeyval : I32EnumAttrCase<"MPI_ERR_KEYVAL", 26, "MPI_ERR_KEYVAL">;
+ : I32EnumAttrCase<"_MPI_ERR_IN_STATUS", 23, "MPI_ERR_IN_STATUS">;
+def MPI_CodeErrIntern : I32EnumAttrCase<"_MPI_ERR_INTERN", 24, "MPI_ERR_INTERN">;
+def MPI_CodeErrIo : I32EnumAttrCase<"_MPI_ERR_IO", 25, "MPI_ERR_IO">;
+def MPI_CodeErrKeyval : I32EnumAttrCase<"_MPI_ERR_KEYVAL", 26, "MPI_ERR_KEYVAL">;
def MPI_CodeErrLocktype
- : I32EnumAttrCase<"MPI_ERR_LOCKTYPE", 27, "MPI_ERR_LOCKTYPE">;
-def MPI_CodeErrName : I32EnumAttrCase<"MPI_ERR_NAME", 28, "MPI_ERR_NAME">;
-def MPI_CodeErrNoMem : I32EnumAttrCase<"MPI_ERR_NO_MEM", 29, "MPI_ERR_NO_MEM">;
+ : I32EnumAttrCase<"_MPI_ERR_LOCKTYPE", 27, "MPI_ERR_LOCKTYPE">;
+def MPI_CodeErrName : I32EnumAttrCase<"_MPI_ERR_NAME", 28, "MPI_ERR_NAME">;
+def MPI_CodeErrNoMem : I32EnumAttrCase<"_MPI_ERR_NO_MEM", 29, "MPI_ERR_NO_MEM">;
def MPI_CodeErrNoSpace
- : I32EnumAttrCase<"MPI_ERR_NO_SPACE", 30, "MPI_ERR_NO_SPACE">;
+ : I32EnumAttrCase<"_MPI_ERR_NO_SPACE", 30, "MPI_ERR_NO_SPACE">;
def MPI_CodeErrNoSuchFile
- : I32EnumAttrCase<"MPI_ERR_NO_SUCH_FILE", 31, "MPI_ERR_NO_SUCH_FILE">;
+ : I32EnumAttrCase<"_MPI_ERR_NO_SUCH_FILE", 31, "MPI_ERR_NO_SUCH_FILE">;
def MPI_CodeErrNotSame
- : I32EnumAttrCase<"MPI_ERR_NOT_SAME", 32, "MPI_ERR_NOT_SAME">;
-def MPI_CodeErrOp : I32EnumAttrCase<"MPI_ERR_OP", 33, "MPI_ERR_OP">;
-def MPI_CodeErrOther : I32EnumAttrCase<"MPI_ERR_OTHER", 34, "MPI_ERR_OTHER">;
+ : I32EnumAttrCase<"_MPI_ERR_NOT_SAME", 32, "MPI_ERR_NOT_SAME">;
+def MPI_CodeErrOp : I32EnumAttrCase<"_MPI_ERR_OP", 33, "MPI_ERR_OP">;
+def MPI_CodeErrOther : I32EnumAttrCase<"_MPI_ERR_OTHER", 34, "MPI_ERR_OTHER">;
def MPI_CodeErrPending
- : I32EnumAttrCase<"MPI_ERR_PENDING", 35, "MPI_ERR_PENDING">;
-def MPI_CodeErrPort : I32EnumAttrCase<"MPI_ERR_PORT", 36, "MPI_ERR_PORT">;
+ : I32EnumAttrCase<"_MPI_ERR_PENDING", 35, "MPI_ERR_PENDING">;
+def MPI_CodeErrPort : I32EnumAttrCase<"_MPI_ERR_PORT", 36, "MPI_ERR_PORT">;
def MPI_CodeErrProcAborted
- : I32EnumAttrCase<"MPI_ERR_PROC_ABORTED", 37, "MPI_ERR_PROC_ABORTED">;
-def MPI_CodeErrQuota : I32EnumAttrCase<"MPI_ERR_QUOTA", 38, "MPI_ERR_QUOTA">;
-def MPI_CodeErrRank : I32EnumAttrCase<"MPI_ERR_RANK", 39, "MPI_ERR_RANK">;
+ : I32EnumAttrCase<"_MPI_ERR_PROC_ABORTED", 37, "MPI_ERR_PROC_ABORTED">;
+def MPI_CodeErrQuota : I32EnumAttrCase<"_MPI_ERR_QUOTA", 38, "MPI_ERR_QUOTA">;
+def MPI_CodeErrRank : I32EnumAttrCase<"_MPI_ERR_RANK", 39, "MPI_ERR_RANK">;
def MPI_CodeErrReadOnly
- : I32EnumAttrCase<"MPI_ERR_READ_ONLY", 40, "MPI_ERR_READ_ONLY">;
+ : I32EnumAttrCase<"_MPI_ERR_READ_ONLY", 40, "MPI_ERR_READ_ONLY">;
def MPI_CodeErrRequest
- : I32EnumAttrCase<"MPI_ERR_REQUEST", 41, "MPI_ERR_REQUEST">;
+ : I32EnumAttrCase<"_MPI_ERR_REQUEST", 41, "MPI_ERR_REQUEST">;
def MPI_CodeErrRmaAttach
- : I32EnumAttrCase<"MPI_ERR_RMA_ATTACH", 42, "MPI_ERR_RMA_ATTACH">;
+ : I32EnumAttrCase<"_MPI_ERR_RMA_ATTACH", 42, "MPI_ERR_RMA_ATTACH">;
def MPI_CodeErrRmaConflict
- : I32EnumAttrCase<"MPI_ERR_RMA_CONFLICT", 43, "MPI_ERR_RMA_CONFLICT">;
+ : I32EnumAttrCase<"_MPI_ERR_RMA_CONFLICT", 43, "MPI_ERR_RMA_CONFLICT">;
def MPI_CodeErrRmaFlavor
- : I32EnumAttrCase<"MPI_ERR_RMA_FLAVOR", 44, "MPI_ERR_RMA_FLAVOR">;
+ : I32EnumAttrCase<"_MPI_ERR_RMA_FLAVOR", 44, "MPI_ERR_RMA_FLAVOR">;
def MPI_CodeErrRmaRange
- : I32EnumAttrCase<"MPI_ERR_RMA_RANGE", 45, "MPI_ERR_RMA_RANGE">;
+ : I32EnumAttrCase<"_MPI_ERR_RMA_RANGE", 45, "MPI_ERR_RMA_RANGE">;
def MPI_CodeErrRmaShared
- : I32EnumAttrCase<"MPI_ERR_RMA_SHARED", 46, "MPI_ERR_RMA_SHARED">;
+ : I32EnumAttrCase<"_MPI_ERR_RMA_SHARED", 46, "MPI_ERR_RMA_SHARED">;
def MPI_CodeErrRmaSync
- : I32EnumAttrCase<"MPI_ERR_RMA_SYNC", 47, "MPI_ERR_RMA_SYNC">;
-def MPI_CodeErrRoot : I32EnumAttrCase<"MPI_ERR_ROOT", 48, "MPI_ERR_ROOT">;
+ : I32EnumAttrCase<"_MPI_ERR_RMA_SYNC", 47, "MPI_ERR_RMA_SYNC">;
+def MPI_CodeErrRoot : I32EnumAttrCase<"_MPI_ERR_ROOT", 48, "MPI_ERR_ROOT">;
def MPI_CodeErrService
- : I32EnumAttrCase<"MPI_ERR_SERVICE", 49, "MPI_ERR_SERVICE">;
+ : I32EnumAttrCase<"_MPI_ERR_SERVICE", 49, "MPI_ERR_SERVICE">;
def MPI_CodeErrSession
- : I32EnumAttrCase<"MPI_ERR_SESSION", 50, "MPI_ERR_SESSION">;
-def MPI_CodeErrSize : I32EnumAttrCase<"MPI_ERR_SIZE", 51, "MPI_ERR_SIZE">;
-def MPI_CodeErrSpawn : I32EnumAttrCase<"MPI_ERR_SPAWN", 52, "MPI_ERR_SPAWN">;
-def MPI_CodeErrTag : I32EnumAttrCase<"MPI_ERR_TAG", 53, "MPI_ERR_TAG">;
+ : I32EnumAttrCase<"_MPI_ERR_SESSION", 50, "MPI_ERR_SESSION">;
+def MPI_CodeErrSize : I32EnumAttrCase<"_MPI_ERR_SIZE", 51, "MPI_ERR_SIZE">;
+def MPI_CodeErrSpawn : I32EnumAttrCase<"_MPI_ERR_SPAWN", 52, "MPI_ERR_SPAWN">;
+def MPI_CodeErrTag : I32EnumAttrCase<"_MPI_ERR_TAG", 53, "MPI_ERR_TAG">;
def MPI_CodeErrTopology
- : I32EnumAttrCase<"MPI_ERR_TOPOLOGY", 54, "MPI_ERR_TOPOLOGY">;
+ : I32EnumAttrCase<"_MPI_ERR_TOPOLOGY", 54, "MPI_ERR_TOPOLOGY">;
def MPI_CodeErrTruncate
- : I32EnumAttrCase<"MPI_ERR_TRUNCATE", 55, "MPI_ERR_TRUNCATE">;
-def MPI_CodeErrType : I32EnumAttrCase<"MPI_ERR_TYPE", 56, "MPI_ERR_TYPE">;
+ : I32EnumAttrCase<"_MPI_ERR_TRUNCATE", 55, "MPI_ERR_TRUNCATE">;
+def MPI_CodeErrType : I32EnumAttrCase<"_MPI_ERR_TYPE", 56, "MPI_ERR_TYPE">;
def MPI_CodeErrUnknown
- : I32EnumAttrCase<"MPI_ERR_UNKNOWN", 57, "MPI_ERR_UNKNOWN">;
+ : I32EnumAttrCase<"_MPI_ERR_UNKNOWN", 57, "MPI_ERR_UNKNOWN">;
def MPI_CodeErrUnsupportedDatarep
- : I32EnumAttrCase<"MPI_ERR_UNSUPPORTED_DATAREP", 58,
+ : I32EnumAttrCase<"_MPI_ERR_UNSUPPORTED_DATAREP", 58,
"MPI_ERR_UNSUPPORTED_DATAREP">;
def MPI_CodeErrUnsupportedOperation
- : I32EnumAttrCase<"MPI_ERR_UNSUPPORTED_OPERATION", 59,
+ : I32EnumAttrCase<"_MPI_ERR_UNSUPPORTED_OPERATION", 59,
"MPI_ERR_UNSUPPORTED_OPERATION">;
def MPI_CodeErrValueTooLarge
- : I32EnumAttrCase<"MPI_ERR_VALUE_TOO_LARGE", 60, "MPI_ERR_VALUE_TOO_LARGE">;
-def MPI_CodeErrWin : I32EnumAttrCase<"MPI_ERR_WIN", 61, "MPI_ERR_WIN">;
+ : I32EnumAttrCase<"_MPI_ERR_VALUE_TOO_LARGE", 60, "MPI_ERR_VALUE_TOO_LARGE">;
+def MPI_CodeErrWin : I32EnumAttrCase<"_MPI_ERR_WIN", 61, "MPI_ERR_WIN">;
def MPI_CodeErrLastcode
- : I32EnumAttrCase<"MPI_ERR_LASTCODE", 62, "MPI_ERR_LASTCODE">;
+ : I32EnumAttrCase<"_MPI_ERR_LASTCODE", 62, "MPI_ERR_LASTCODE">;
def MPI_ErrorClassEnum
: I32EnumAttr<"MPI_ErrorClassEnum", "MPI error class name", [
@@ -215,20 +215,20 @@ def MPI_ErrorClassAttr : EnumAttr<MPI_Dialect, MPI_ErrorClassEnum, "errclass"> {
let assemblyFormat = "`<` $value `>`";
}
-def MPI_OpNull : I32EnumAttrCase<"MPI_OP_NULL", 0, "MPI_OP_NULL">;
-def MPI_OpMax : I32EnumAttrCase<"MPI_MAX", 1, "MPI_MAX">;
-def MPI_OpMin : I32EnumAttrCase<"MPI_MIN", 2, "MPI_MIN">;
-def MPI_OpSum : I32EnumAttrCase<"MPI_SUM", 3, "MPI_SUM">;
-def MPI_OpProd : I32EnumAttrCase<"MPI_PROD", 4, "MPI_PROD">;
-def MPI_OpLand : I32EnumAttrCase<"MPI_LAND", 5, "MPI_LAND">;
-def MPI_OpBand : I32EnumAttrCase<"MPI_BAND", 6, "MPI_BAND">;
-def MPI_OpLor : I32EnumAttrCase<"MPI_LOR", 7, "MPI_LOR">;
-def MPI_OpBor : I32EnumAttrCase<"MPI_BOR", 8, "MPI_BOR">;
-def MPI_OpLxor : I32EnumAttrCase<"MPI_LXOR", 9, "MPI_LXOR">;
-def MPI_OpBxor : I32EnumAttrCase<"MPI_BXOR", 10, "MPI_BXOR">;
-def MPI_OpMinloc : I32EnumAttrCase<"MPI_MINLOC", 11, "MPI_MINLOC">;
-def MPI_OpMaxloc : I32EnumAttrCase<"MPI_MAXLOC", 12, "MPI_MAXLOC">;
-def MPI_OpReplace : I32EnumAttrCase<"MPI_REPLACE", 13, "MPI_REPLACE">;
+def MPI_OpNull : I32EnumAttrCase<"_MPI_OP_NULL", 0, "MPI_OP_NULL">;
+def MPI_OpMax : I32EnumAttrCase<"_MPI_MAX", 1, "MPI_MAX">;
+def MPI_OpMin : I32EnumAttrCase<"_MPI_MIN", 2, "MPI_MIN">;
+def MPI_OpSum : I32EnumAttrCase<"_MPI_SUM", 3, "MPI_SUM">;
+def MPI_OpProd : I32EnumAttrCase<"_MPI_PROD", 4, "MPI_PROD">;
+def MPI_OpLand : I32EnumAttrCase<"_MPI_LAND", 5, "MPI_LAND">;
+def MPI_OpBand : I32EnumAttrCase<"_MPI_BAND", 6, "MPI_BAND">;
+def MPI_OpLor : I32EnumAttrCase<"_MPI_LOR", 7, "MPI_LOR">;
+def MPI_OpBor : I32EnumAttrCase<"_MPI_BOR", 8, "MPI_BOR">;
+def MPI_OpLxor : I32EnumAttrCase<"_MPI_LXOR", 9, "MPI_LXOR">;
+def MPI_OpBxor : I32EnumAttrCase<"_MPI_BXOR", 10, "MPI_BXOR">;
+def MPI_OpMinloc : I32EnumAttrCase<"_MPI_MINLOC", 11, "MPI_MINLOC">;
+def MPI_OpMaxloc : I32EnumAttrCase<"_MPI_MAXLOC", 12, "MPI_MAXLOC">;
+def MPI_OpReplace : I32EnumAttrCase<"_MPI_REPLACE", 13, "MPI_REPLACE">;
def MPI_OpClassEnum : I32EnumAttr<"MPI_OpClassEnum", "MPI operation class", [
MPI_OpNull,
diff --git a/mlir/include/mlir/Dialect/MPI/IR/MPIOps.td b/mlir/include/mlir/Dialect/MPI/IR/MPIOps.td
index 284ba72af9768..db28bd09678f8 100644
--- a/mlir/include/mlir/Dialect/MPI/IR/MPIOps.td
+++ b/mlir/include/mlir/Dialect/MPI/IR/MPIOps.td
@@ -102,13 +102,13 @@ def MPI_SendOp : MPI_Op<"send", []> {
let arguments = (
ins AnyMemRef : $ref,
I32 : $tag,
- I32 : $rank
+ I32 : $dest
);
let results = (outs Optional<MPI_Retval>:$retval);
- let assemblyFormat = "`(` $ref `,` $tag `,` $rank `)` attr-dict `:` "
- "type($ref) `,` type($tag) `,` type($rank)"
+ let assemblyFormat = "`(` $ref `,` $tag `,` $dest `)` attr-dict `:` "
+ "type($ref) `,` type($tag) `,` type($dest)"
"(`->` type($retval)^)?";
let hasCanonicalizer = 1;
}
@@ -154,11 +154,11 @@ def MPI_ISendOp : MPI_Op<"isend", []> {
//===----------------------------------------------------------------------===//
def MPI_RecvOp : MPI_Op<"recv", []> {
- let summary = "Equivalent to `MPI_Recv(ptr, size, dtype, dest, tag, "
+ let summary = "Equivalent to `MPI_Recv(ptr, size, dtype, source, tag, "
"MPI_COMM_WORLD, MPI_STATUS_IGNORE)`";
let description = [{
MPI_Recv performs a blocking receive of `size` elements of type `dtype`
- from rank `dest`. The `tag` value and communicator enables the library to
+ from rank `source`. The `tag` value and communicator enables the library to
determine the matching of multiple sends and receives between the same
ranks.
@@ -172,13 +172,13 @@ def MPI_RecvOp : MPI_Op<"recv", []> {
let arguments = (
ins AnyMemRef : $ref,
- I32 : $tag, I32 : $rank
+ I32 : $tag, I32 : $source
);
let results = (outs Optional<MPI_Retval>:$retval);
- let assemblyFormat = "`(` $ref `,` $tag `,` $rank `)` attr-dict `:`"
- "type($ref) `,` type($tag) `,` type($rank)"
+ let assemblyFormat = "`(` $ref `,` $tag `,` $source `)` attr-dict `:` "
+ "type($ref) `,` type($tag) `,` type($source)"
"(`->` type($retval)^)?";
let hasCanonicalizer = 1;
}
diff --git a/mlir/lib/Conversion/MPIToLLVM/CMakeLists.txt b/mlir/lib/Conversion/MPIToLLVM/CMakeLists.txt
index f81fb25e56840..5d6b82e655b82 100644
--- a/mlir/lib/Conversion/MPIToLLVM/CMakeLists.txt
+++ b/mlir/lib/Conversion/MPIToLLVM/CMakeLists.txt
@@ -1,8 +1,11 @@
+find_package(MPI REQUIRED) # MPI_SKIP_COMPILER_WRAPPER TRUE MPI_CXX_SKIP_MPICXX TRUE)
+
add_mlir_conversion_library(MLIRMPIToLLVM
MPIToLLVM.cpp
ADDITIONAL_HEADER_DIRS
${MLIR_MAIN_INCLUDE_DIR}/mlir/Conversion/MPIToLLVM
+ ${MPI_C_HEADER_DIR}
DEPENDS
MLIRConversionPassIncGen
diff --git a/mlir/lib/Conversion/MPIToLLVM/MPIImplTraits.h b/mlir/lib/Conversion/MPIToLLVM/MPIImplTraits.h
new file mode 100644
index 0000000000000..5fd88d0aab4ac
--- /dev/null
+++ b/mlir/lib/Conversion/MPIToLLVM/MPIImplTraits.h
@@ -0,0 +1,119 @@
+#define MPICH_SKIP_MPICXX 1
+#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
+#include "mlir/Dialect/MPI/IR/MPI.h"
+#include "mlir/Transforms/DialectConversion.h"
+#include <mpi.h>
+
+namespace {
+
+// When lowering the mpi dialect to function calls, certain details
+// differ between various MPI implementations. This class will provide
+// these depending on the MPI implementation that got included.
+struct MPIImplTraits {
+ // get/create MPI_COMM_WORLD as a mlir::Value
+ static mlir::Value getCommWorld(mlir::ModuleOp &moduleOp,
+ const mlir::Location loc,
+ mlir::ConversionPatternRewriter &rewriter);
+ // get/create MPI datatype as a mlir::Value which corresponds to the given
+ // mlir::Type
+ static mlir::Value getDataType(const mlir::Location loc,
+ mlir::ConversionPatternRewriter &rewriter,
+ mlir::Type type);
+};
+
+// ****************************************************************************
+// Intel MPI
+#ifdef IMPI_DEVICE_EXPORT
+
+mlir::Value
+MPIImplTraits::getCommWorld(mlir::ModuleOp &moduleOp, const mlir::Location loc,
+ mlir::ConversionPatternRewriter &rewriter) {
+ return rewriter.create<mlir::LLVM::ConstantOp>(loc, rewriter.getI32Type(),
+ MPI_COMM_WORLD);
+}
+
+mlir::Value
+MPIImplTraits::getDataType(const mlir::Location loc,
+ mlir::ConversionPatternRewriter &rewriter,
+ mlir::Type type) {
+ int32_t mtype = 0;
+ if (type.isF32())
+ mtype = MPI_FLOAT;
+ else if (type.isF64())
+ mtype = MPI_DOUBLE;
+ else if (type.isInteger(64) && !type.isUnsignedInteger())
+ mtype = MPI_LONG;
+ else if (type.isInteger(64))
+ mtype = MPI_UNSIGNED_LONG;
+ else if (type.isInteger(32) && !type.isUnsignedInteger())
+ mtype = MPI_INT;
+ else if (type.isInteger(32))
+ mtype = MPI_UNSIGNED;
+ else if (type.isInteger(16) && !type.isUnsignedInteger())
+ mtype = MPI_SHORT;
+ else if (type.isInteger(16))
+ mtype = MPI_UNSIGNED_SHORT;
+ else if (type.isInteger(8) && !type.isUnsignedInteger())
+ mtype = MPI_CHAR;
+ else if (type.isInteger(8))
+ mtype = MPI_UNSIGNED_CHAR;
+ else
+ assert(false && "unsupported type");
+ return rewriter.create<mlir::LLVM::ConstantOp>(loc, rewriter.getI32Type(),
+ mtype);
+}
+
+// ****************************************************************************
+// OpenMPI
+#elif defined(OPEN_MPI) && OPEN_MPI == 1
+
+// TODO: this is pretty close to getOrDefineFunction, can probably be factored
+static mlir::LLVM::GlobalOp
+getOrDefineExternalStruct(mlir::ModuleOp &moduleOp, const mlir::Location loc,
+ mlir::ConversionPatternRewriter &rewriter,
+ mlir::StringRef name,
+ mlir::LLVM::LLVMStructType type) {
+ mlir::LLVM::GlobalOp ret;
+ if (!(ret = moduleOp.lookupSymbol<mlir::LLVM::GlobalOp>(name))) {
+ mlir::ConversionPatternRewriter::InsertionGuard guard(rewriter);
+ rewriter.setInsertionPointToStart(moduleOp.getBody());
+ ret = rewriter.create<mlir::LLVM::GlobalOp>(
+ loc, type, /*isConstant=*/false, mlir::LLVM::Linkage::External, name,
+ /*value=*/mlir::Attribute(), /*alignment=*/0, 0);
+ }
+ return ret;
+}
+
+mlir::Value
+MPIImplTraits::getCommWorld(mlir::ModuleOp &moduleOp, const mlir::Location loc,
+ mlir::ConversionPatternRewriter &rewriter) {
+ auto context = rewriter.getContext();
+ auto i32 = rewriter.getI32Type();
+ // ptrType `!llvm.ptr`
+ mlir::Type ptrType = mlir::LLVM::LLVMPointerType::get(context);
+ // get external opaque struct pointer type
+ auto commStructT =
+ mlir::LLVM::LLVMStructType::getOpaque("MPI_ABI_Comm", context);
+
+ // make sure global op definition exists
+ getOrDefineExternalStruct(moduleOp, loc, rewriter, "MPI_COMM_WORLD",
+ commStructT);
+
+ // get address of @MPI_COMM_WORLD
+ return rewriter.create<mlir::LLVM::AddressOfOp>(
+ loc, ptrType, mlir::SymbolRefAttr::get(context, "MPI_COMM_WORLD"));
+}
+
+mlir::Value
+MPIImplTraits::getDataType(const mlir::Location loc,
+ mlir::ConversionPatternRewriter &rewriter,
+ mlir::Type type) {
+ assert(false && "getDataType not implemented for this MPI implementation");
+ return {};
+}
+
+#else
+#error "Unsupported MPI implementation"
+#endif
+
+} // namespace
diff --git a/mlir/lib/Conversion/MPIToLLVM/MPIToLLVM.cpp b/mlir/lib/Conversion/MPIToLLVM/MPIToLLVM.cpp
index d87a10aab8f49..adeb9e85f05f7 100644
--- a/mlir/lib/Conversion/MPIToLLVM/MPIToLLVM.cpp
+++ b/mlir/lib/Conversion/MPIToLLVM/MPIToLLVM.cpp
@@ -6,9 +6,11 @@
//
//===----------------------------------------------------------------------===//
-#include "mlir/Conversion/MPIToLLVM/MPIToLLVM.h"
+// This must go first (MPI gets confused otherwise)
+#include "MPIImplTraits.h"
#include "mlir/Conversion/LLVMCommon/Pattern.h"
+#include "mlir/Conversion/MPIToLLVM/MPIToLLVM.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/MPI/IR/MPI.h"
#include "mlir/Pass/Pass.h"
@@ -17,38 +19,13 @@
using namespace mlir;
-namespace {
-
-struct InitOpLowering : ConvertOpToLLVMPattern<mpi::InitOp> {
- using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;
-
- LogicalResult
- matchAndRewrite(mpi::InitOp op, OpAdaptor adaptor,
- ConversionPatternRewriter &rewriter) const override;
-};
-
-struct CommRankOpLowering : ConvertOpToLLVMPattern<mpi::CommRankOp> {
- using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;
-
- LogicalResult
- matchAndRewrite(mpi::CommRankOp op, OpAdaptor adaptor,
- ConversionPatternRewriter &rewriter) const override;
-};
-
-struct FinalizeOpLowering : ConvertOpToLLVMPattern<mpi::FinalizeOp> {
- using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;
-
- LogicalResult
- matchAndRewrite(mpi::FinalizeOp op, OpAdaptor adaptor,
- ConversionPatternRewriter &rewriter) const override;
-};
-
// TODO: this was copied from GPUOpsLowering.cpp:288
// is this okay, or should this be moved to some common file?
-LLVM::LLVMFuncOp getOrDefineFunction(ModuleOp &moduleOp, const Location loc,
- ConversionPatternRewriter &rewriter,
- StringRef name,
- LLVM::LLVMFunctionType type) {
+static LLVM::LLVMFuncOp getOrDefineFunction(ModuleOp &moduleOp,
+ const Location loc,
+ ConversionPatternRewriter &rewriter,
+ StringRef name,
+ LLVM::LLVMFunctionType type) {
LLVM::LLVMFuncOp ret;
if (!(ret = moduleOp.lookupSymbol<LLVM::LLVMFuncOp>(name))) {
ConversionPatternRewriter::InsertionGuard guard(rewriter);
@@ -59,157 +36,267 @@ LLVM::LLVMFuncOp getOrDefineFunction(ModuleOp &moduleOp, const Location loc,
return ret;
}
-// TODO: this is pretty close to getOrDefineFunction, can probably be factored
-LLVM::GlobalOp getOrDefineExternalStruct(ModuleOp &moduleOp, const Location loc,
- ConversionPatternRewriter &rewriter,
- StringRef name,
- LLVM::LLVMStructType type) {
- LLVM::GlobalOp ret;
- if (!(ret = moduleOp.lookupSymbol<LLVM::GlobalOp>(name))) {
- ConversionPatternRewriter::InsertionGuard guard(rewriter);
- rewriter.setInsertionPointToStart(moduleOp.getBody());
- ret = rewriter.create<LLVM::GlobalOp>(
- loc, type, /*isConstant=*/false, LLVM::Linkage::External, name,
- /*value=*/Attribute(), /*alignment=*/0, 0);
- }
- return ret;
-}
-
-} // namespace
+namespace {
//===----------------------------------------------------------------------===//
// InitOpLowering
//===----------------------------------------------------------------------===//
-LogicalResult
-InitOpLowering::matchAndRewrite(mpi::InitOp op, OpAdaptor adaptor,
- ConversionPatternRewriter &rewriter) const {
- // get loc
- auto loc = op.getLoc();
+struct InitOpLowering : public ConvertOpToLLVMPattern<mpi::InitOp> {
+ using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;
- // ptrType `!llvm.ptr`
- Type ptrType = LLVM::LLVMPointerType::get(rewriter.getContext());
+ LogicalResult
+ matchAndRewrite(mpi::InitOp op, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ // get loc
+ auto loc = op.getLoc();
- // instantiate nullptr `%nullptr = llvm.mlir.zero : !llvm.ptr`
- auto nullPtrOp = rewriter.create<LLVM::ZeroOp>(loc, ptrType);
- Value llvmnull = nullPtrOp.getRes();
+ // ptrType `!llvm.ptr`
+ Type ptrType = LLVM::LLVMPointerType::get(rewriter.getContext());
- // grab a reference to the global module op:
- auto moduleOp = op->getParentOfType<ModuleOp>();
+ // instantiate nullptr `%nullptr = llvm.mlir.zero : !llvm.ptr`
+ auto nullPtrOp = rewriter.create<LLVM::ZeroOp>(loc, ptrType);
+ Value llvmnull = nullPtrOp.getRes();
- // LLVM Function type representing `i32 MPI_Init(ptr, ptr)`
- auto initFuncType =
- LLVM::LLVMFunctionType::get(rewriter.getI32Type(), {ptrType, ptrType});
- // get or create function declaration:
- LLVM::LLVMFuncOp initDecl =
- getOrDefineFunction(moduleOp, loc, rewriter, "MPI_Init", initFuncType);
+ // grab a reference to the global module op:
+ auto moduleOp = op->getParentOfType<ModuleOp>();
- // replace init with function call
- rewriter.replaceOpWithNewOp<LLVM::CallOp>(op, initDecl,
- ValueRange{llvmnull, llvmnull});
+ // LLVM Function type representing `i32 MPI_Init(ptr, ptr)`
+ auto initFuncType =
+ LLVM::LLVMFunctionType::get(rewriter.getI32Type(), {ptrType, ptrType});
+ // get or create function declaration:
+ LLVM::LLVMFuncOp initDecl =
+ getOrDefineFunction(moduleOp, loc, rewriter, "MPI_Init", initFuncType);
- return success();
-}
+ // replace init with function call
+ rewriter.replaceOpWithNewOp<LLVM::CallOp>(op, initDecl,
+ ValueRange{llvmnull, llvmnull});
+
+ return success();
+ }
+};
//===----------------------------------------------------------------------===//
// FinalizeOpLowering
//===----------------------------------------------------------------------===//
-LogicalResult
-FinalizeOpLowering::matchAndRewrite(mpi::FinalizeOp op, OpAdaptor adaptor,
- ConversionPatternRewriter &rewriter) const {
- // get loc
- auto loc = op.getLoc();
+struct FinalizeOpLowering : public ConvertOpToLLVMPattern<mpi::FinalizeOp> {
+ using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;
- // grab a reference to the global module op:
- auto moduleOp = op->getParentOfType<ModuleOp>();
+ LogicalResult
+ matchAndRewrite(mpi::FinalizeOp op, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ // get loc
+ auto loc = op.getLoc();
- // LLVM Function type representing `i32 MPI_Finalize()`
- auto initFuncType = LLVM::LLVMFunctionType::get(rewriter.getI32Type(), {});
- // get or create function declaration:
- LLVM::LLVMFuncOp initDecl = getOrDefineFunction(moduleOp, loc, rewriter,
- "MPI_Finalize", initFuncType);
+ // grab a reference to the global module op:
+ auto moduleOp = op->getParentOfType<ModuleOp>();
- // replace init with function call
- rewriter.replaceOpWithNewOp<LLVM::CallOp>(op, initDecl, ValueRange{});
+ // LLVM Function type representing `i32 MPI_Finalize()`
+ auto initFuncType = LLVM::LLVMFunctionType::get(rewriter.getI32Type(), {});
+ // get or create function declaration:
+ LLVM::LLVMFuncOp initDecl = getOrDefineFunction(
+ moduleOp, loc, rewriter, "MPI_Finalize", initFuncType);
- return success();
-}
+ // replace init with function call
+ rewriter.replaceOpWithNewOp<LLVM::CallOp>(op, initDecl, ValueRange{});
+
+ return success();
+ }
+};
//===----------------------------------------------------------------------===//
-// CommRankLowering
+// CommRankOpLowering
//===----------------------------------------------------------------------===//
-LogicalResult
-CommRankOpLowering::matchAndRewrite(mpi::CommRankOp op, OpAdaptor adaptor,
- ConversionPatternRewriter &rewriter) const {
- // get some helper vars
- auto loc = op.getLoc();
- auto context = rewriter.getContext();
- auto i32 = rewriter.getI32Type();
-
- // ptrType `!llvm.ptr`
- Type ptrType = LLVM::LLVMPointerType::get(context);
-
- // get external opaque struct pointer type
- auto commStructT = LLVM::LLVMStructType::getOpaque("MPI_ABI_Comm", context);
-
- // grab a reference to the global module op:
- auto moduleOp = op->getParentOfType<ModuleOp>();
-
- // make sure global op definition exists
- getOrDefineExternalStruct(moduleOp, loc, rewriter, "MPI_COMM_WORLD",
- commStructT);
-
- // get address of @MPI_COMM_WORLD
- auto one = rewriter.create<LLVM::ConstantOp>(loc, i32, 1);
- auto rankptr = rewriter.create<LLVM::AllocaOp>(loc, ptrType, i32, one);
- auto commWorld = rewriter.create<LLVM::AddressOfOp>(
- loc, ptrType, SymbolRefAttr::get(context, "MPI_COMM_WORLD"));
-
- // LLVM Function type representing `i32 MPI_Comm_rank(ptr, ptr)`
- auto rankFuncType = LLVM::LLVMFunctionType::get(i32, {ptrType, ptrType});
- // get or create function declaration:
- LLVM::LLVMFuncOp initDecl = getOrDefineFunction(
- moduleOp, loc, rewriter, "MPI_Comm_rank", rankFuncType);
-
- // replace init with function call
- auto callOp = rewriter.create<LLVM::CallOp>(
- loc, initDecl, ValueRange{commWorld.getRes(), rankptr.getRes()});
-
- // load the rank into a register
- auto loadedRank =
- rewriter.create<LLVM::LoadOp>(loc, i32, rankptr.getResult());
-
- // if retval is checked, replace uses of retval with the results from the call
- // op
- SmallVector<Value> replacements;
- if (op.getRetval()) {
- replacements.push_back(callOp.getResult());
+struct CommRankOpLowering : public ConvertOpToLLVMPattern<mpi::CommRankOp> {
+ using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;
+
+ LogicalResult
+ matchAndRewrite(mpi::CommRankOp op, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ // get some helper vars
+ auto loc = op.getLoc();
+ auto context = rewriter.getContext();
+ auto i32 = rewriter.getI32Type();
+
+ // ptrType `!llvm.ptr`
+ Type ptrType = LLVM::LLVMPointerType::get(context);
+
+ // grab a reference to the global module op:
+ auto moduleOp = op->getParentOfType<ModuleOp>();
+
+ // get MPI_COMM_WORLD
+ auto commWorld = MPIImplTraits::getCommWorld(moduleOp, loc, rewriter);
+
+ // LLVM Function type representing `i32 MPI_Comm_rank(ptr, ptr)`
+ auto rankFuncType =
+ LLVM::LLVMFunctionType::get(i32, {commWorld.getType(), ptrType});
+ // get or create function declaration:
+ LLVM::LLVMFuncOp initDecl = getOrDefineFunction(
+ moduleOp, loc, rewriter, "MPI_Comm_rank", rankFuncType);
+
+ // replace init with function call
+ auto one = rewriter.create<LLVM::ConstantOp>(loc, i32, 1);
+ auto rankptr = rewriter.create<LLVM::AllocaOp>(loc, ptrType, i32, one);
+ auto callOp = rewriter.create<LLVM::CallOp>(
+ loc, initDecl, ValueRange{commWorld, rankptr.getRes()});
+
+ // load the rank into a register
+ auto loadedRank =
+ rewriter.create<LLVM::LoadOp>(loc, i32, rankptr.getResult());
+
+ // if retval is checked, replace uses of retval with the results from the
+ // call op
+ SmallVector<Value> replacements;
+ if (op.getRetval()) {
+ replacements.push_back(callOp.getResult());
+ }
+ // replace all uses, then erase op
+ replacements.push_back(loadedRank.getRes());
+ rewriter.replaceOp(op, replacements);
+
+ return success();
}
- // replace all uses, then erase op
- replacements.push_back(loadedRank.getRes());
- rewriter.replaceOp(op, replacements);
+};
- return success();
-}
+//===----------------------------------------------------------------------===//
+// SendOpLowering
+//===----------------------------------------------------------------------===//
+
+struct SendOpLowering : public ConvertOpToLLVMPattern<mpi::SendOp> {
+ using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;
+
+ LogicalResult
+ matchAndRewrite(mpi::SendOp op, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ // get some helper vars
+ auto loc = op.getLoc();
+ auto context = rewriter.getContext();
+ auto i32 = rewriter.getI32Type();
+ auto i64 = rewriter.getI64Type();
+ auto memRef = adaptor.getRef();
+ auto elemType = op.getRef().getType().getElementType();
+
+ // ptrType `!llvm.ptr`
+ Type ptrType = LLVM::LLVMPointerType::get(context);
+
+ // grab a reference to the global module op:
+ auto moduleOp = op->getParentOfType<ModuleOp>();
+
+ // get MPI_COMM_WORLD, dataType and pointer
+ auto dataPtr =
+ rewriter.create<LLVM::ExtractValueOp>(loc, ptrType, memRef, 1)
+ .getResult();
+ auto offset =
+ rewriter.create<LLVM::ExtractValueOp>(loc, i64, memRef, 2).getResult();
+ dataPtr =
+ rewriter.create<LLVM::GEPOp>(loc, ptrType, elemType, dataPtr, offset)
+ .getResult();
+ auto size =
+ rewriter
+ .create<LLVM::ExtractValueOp>(loc, memRef, ArrayRef<int64_t>{3, 0})
+ .getResult();
+ size = rewriter.create<LLVM::TruncOp>(loc, i32, size).getResult();
+ auto dataType = MPIImplTraits::getDataType(loc, rewriter, elemType);
+ auto commWorld = MPIImplTraits::getCommWorld(moduleOp, loc, rewriter);
+
+ // LLVM Function type representing `i32 MPI_send(datatype, dst, tag, comm)`
+ auto funcType = LLVM::LLVMFunctionType::get(
+ i32, {ptrType, i32, i32, i32, i32, commWorld.getType()});
+ // get or create function declaration:
+ LLVM::LLVMFuncOp funcDecl =
+ getOrDefineFunction(moduleOp, loc, rewriter, "MPI_Send", funcType);
+
+ // replace op with function call
+ auto funcCall = rewriter.create<LLVM::CallOp>(
+ loc, funcDecl,
+ ValueRange{dataPtr, size, dataType, adaptor.getDest(), adaptor.getTag(),
+ commWorld});
+ if (op.getRetval()) {
+ rewriter.replaceOp(op, funcCall.getResult());
+ } else {
+ rewriter.eraseOp(op);
+ }
+
+ return success();
+ }
+};
//===----------------------------------------------------------------------===//
-// Pattern Population
+// RecvOpLowering
//===----------------------------------------------------------------------===//
-void mpi::populateMPIToLLVMConversionPatterns(LLVMTypeConverter &converter,
- RewritePatternSet &patterns) {
- patterns.add<InitOpLowering>(converter);
- patterns.add<CommRankOpLowering>(converter);
- patterns.add<FinalizeOpLowering>(converter);
-}
+struct RecvOpLowering : public ConvertOpToLLVMPattern<mpi::RecvOp> {
+ using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;
+
+ LogicalResult
+ matchAndRewrite(mpi::RecvOp op, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ // get some helper vars
+ auto loc = op.getLoc();
+ auto context = rewriter.getContext();
+ auto i32 = rewriter.getI32Type();
+ auto i64 = rewriter.getI64Type();
+ auto memRef = adaptor.getRef();
+ auto elemType = op.getRef().getType().getElementType();
+
+ // ptrType `!llvm.ptr`
+ Type ptrType = LLVM::LLVMPointerType::get(context);
+
+ // grab a reference to the global module op:
+ auto moduleOp = op->getParentOfType<ModuleOp>();
+
+ // get MPI_COMM_WORLD, dataType, status_ignore and pointer
+ auto dataPtr =
+ rewriter.create<LLVM::ExtractValueOp>(loc, ptrType, memRef, 1)
+ .getResult();
+ auto offset =
+ rewriter.create<LLVM::ExtractValueOp>(loc, i64, memRef, 2).getResult();
+ dataPtr =
+ rewriter.create<LLVM::GEPOp>(loc, ptrType, elemType, dataPtr, offset)
+ .getResult();
+ auto size =
+ rewriter
+ .create<LLVM::ExtractValueOp>(loc, memRef, ArrayRef<int64_t>{3, 0})
+ .getResult();
+ size = rewriter.create<LLVM::TruncOp>(loc, i32, size).getResult();
+ auto dataType = MPIImplTraits::getDataType(loc, rewriter, elemType);
+ auto commWorld = MPIImplTraits::getCommWorld(moduleOp, loc, rewriter);
+ auto statusIgnore =
+ rewriter
+ .create<LLVM::ConstantOp>(
+ loc, i64, reinterpret_cast<int64_t>(MPI_STATUS_IGNORE))
+ .getResult();
+ statusIgnore = rewriter.create<LLVM::IntToPtrOp>(loc, ptrType, statusIgnore)
+ .getResult();
+
+ // LLVM Function type representing `i32 MPI_Recv(datatype, dst, tag, comm)`
+ auto funcType = LLVM::LLVMFunctionType::get(
+ i32, {ptrType, i32, i32, i32, i32, commWorld.getType(), ptrType});
+ // get or create function declaration:
+ LLVM::LLVMFuncOp funcDecl =
+ getOrDefineFunction(moduleOp, loc, rewriter, "MPI_Recv", funcType);
+
+ // replace op with function call
+ auto funcCall = rewriter.create<LLVM::CallOp>(
+ loc, funcDecl,
+ ValueRange{dataPtr, size, dataType, adaptor.getSource(),
+ adaptor.getTag(), commWorld, statusIgnore});
+ if (op.getRetval()) {
+ rewriter.replaceOp(op, funcCall.getResult());
+ } else {
+ rewriter.eraseOp(op);
+ }
+
+ return success();
+ }
+};
//===----------------------------------------------------------------------===//
// ConvertToLLVMPatternInterface implementation
//===----------------------------------------------------------------------===//
-namespace {
/// Implement the interface to convert Func to LLVM.
struct FuncToLLVMDialectInterface : public ConvertToLLVMPatternInterface {
using ConvertToLLVMPatternInterface::ConvertToLLVMPatternInterface;
@@ -223,7 +310,17 @@ struct FuncToLLVMDialectInterface : public ConvertToLLVMPatternInterface {
};
} // namespace
-void mpi::registerConvertMPIToLLVMInterface(DialectRegistry &registry) {
+//===----------------------------------------------------------------------===//
+// Pattern Population
+//===----------------------------------------------------------------------===//
+
+void mlir::mpi::populateMPIToLLVMConversionPatterns(
+ LLVMTypeConverter &converter, RewritePatternSet &patterns) {
+ patterns.add<CommRankOpLowering, FinalizeOpLowering, InitOpLowering,
+ SendOpLowering, RecvOpLowering>(converter);
+}
+
+void mlir::mpi::registerConvertMPIToLLVMInterface(DialectRegistry &registry) {
registry.addExtension(+[](MLIRContext *ctx, mpi::MPIDialect *dialect) {
dialect->addInterfaces<FuncToLLVMDialectInterface>();
});
diff --git a/mlir/test/Conversion/MPIToLLVM/ops.mlir b/mlir/test/Conversion/MPIToLLVM/ops.mlir
index a7a44ad24909a..61a4219fe35fc 100644
--- a/mlir/test/Conversion/MPIToLLVM/ops.mlir
+++ b/mlir/test/Conversion/MPIToLLVM/ops.mlir
@@ -1,36 +1,80 @@
// RUN: mlir-opt -convert-to-llvm %s | FileCheck %s
module {
-// CHECK: llvm.func @MPI_Finalize() -> i32
-// CHECK: llvm.func @MPI_Comm_rank(!llvm.ptr, !llvm.ptr) -> i32
-// CHECK: llvm.mlir.global external @MPI_COMM_WORLD() {addr_space = 0 : i32} : !llvm.struct<"MPI_ABI_Comm", opaque>
-// CHECK: llvm.func @MPI_Init(!llvm.ptr, !llvm.ptr) -> i32
+ // CHECK: llvm.func @MPI_Finalize() -> i32
+ // CHECK: llvm.func @MPI_Recv(!llvm.ptr, i32, i32, i32, i32, i32, !llvm.ptr) -> i32
+ // CHECK: llvm.func @MPI_Send(!llvm.ptr, i32, i32, i32, i32, i32) -> i32
+ // CHECK: llvm.func @MPI_Comm_rank({{.*}}, !llvm.ptr) -> i32
+ // COMM: llvm.mlir.global external @MPI_COMM_WORLD() {addr_space = 0 : i32} : !llvm.struct<"MPI_ABI_Comm", opaque>
+ // CHECK: llvm.func @MPI_Init(!llvm.ptr, !llvm.ptr) -> i32
func.func @mpi_test(%arg0: memref<100xf32>) {
+ // CHECK: [[varg0:%.*]]: !llvm.ptr, [[varg1:%.*]]: !llvm.ptr, [[varg2:%.*]]: i64, [[varg3:%.*]]: i64, [[varg4:%.*]]: i64
+ // CHECK: [[v0:%.*]] = llvm.mlir.poison : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
+ // CHECK-NEXT: [[v1:%.*]] = llvm.insertvalue [[varg0]], [[v0]][0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
+ // CHECK-NEXT: [[v2:%.*]] = llvm.insertvalue [[varg1]], [[v1]][1] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
+ // CHECK-NEXT: [[v3:%.*]] = llvm.insertvalue [[varg2]], [[v2]][2] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
+ // CHECK-NEXT: [[v4:%.*]] = llvm.insertvalue [[varg3]], [[v3]][3, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
+ // CHECK-NEXT: [[v5:%.*]] = llvm.insertvalue [[varg4]], [[v4]][4, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
+ // CHECK-NEXT: [[v6:%.*]] = llvm.mlir.zero : !llvm.ptr
+ // CHECK-NEXT: [[v7:%.*]] = llvm.call @MPI_Init([[v6]], [[v6]]) : (!llvm.ptr, !llvm.ptr) -> i32
+ // CHECK-NEXT: [[v8:%.*]] = builtin.unrealized_conversion_cast [[v7]] : i32 to !mpi.retval
%0 = mpi.init : !mpi.retval
-// CHECK: %7 = llvm.mlir.zero : !llvm.ptr
-// CHECK-NEXT: %8 = llvm.call @MPI_Init(%7, %7) : (!llvm.ptr, !llvm.ptr) -> i32
-// CHECK-NEXT: %9 = builtin.unrealized_conversion_cast %8 : i32 to !mpi.retval
-
+ // CHECK: [[v9:%.*]] = llvm.mlir.constant(1140850688 : i32) : i32
+ // CHECK-NEXT: [[v10:%.*]] = llvm.mlir.constant(1 : i32) : i32
+ // CHECK-NEXT: [[v11:%.*]] = llvm.alloca [[v10]] x i32 : (i32) -> !llvm.ptr
+ // CHECK-NEXT: [[v12:%.*]] = llvm.call @MPI_Comm_rank([[v9]], [[v11]]) : (i32, !llvm.ptr) -> i32
+ // CHECK-NEXT: [[v13:%.*]] = builtin.unrealized_conversion_cast [[v12]] : i32 to !mpi.retval
+ // CHECK-NEXT: [[v14:%.*]] = llvm.load [[v11]] : !llvm.ptr -> i32
%retval, %rank = mpi.comm_rank : !mpi.retval, i32
-// CHECK: %10 = llvm.mlir.constant(1 : i32) : i32
-// CHECK-NEXT: %11 = llvm.alloca %10 x i32 : (i32) -> !llvm.ptr
-// CHECK-NEXT: %12 = llvm.mlir.addressof @MPI_COMM_WORLD : !llvm.ptr
-// CHECK-NEXT: %13 = llvm.call @MPI_Comm_rank(%12, %11) : (!llvm.ptr, !llvm.ptr) -> i32
-// CHECK-NEXT: %14 = llvm.load %11 : !llvm.ptr -> i32
-// CHECK-NEXT: %15 = builtin.unrealized_conversion_cast %13 : i32 to !mpi.retval
+ // CHECK: [[v15:%.*]] = llvm.extractvalue [[v5]][1] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
+ // CHECK-NEXT: [[v16:%.*]] = llvm.extractvalue [[v5]][2] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
+ // CHECK-NEXT: [[v17:%.*]] = llvm.getelementptr [[v15]][[[v16]]] : (!llvm.ptr, i64) -> !llvm.ptr, f32
+ // CHECK-NEXT: [[v18:%.*]] = llvm.extractvalue [[v5]][3, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
+ // CHECK-NEXT: [[v19:%.*]] = llvm.trunc [[v18]] : i64 to i32
+ // CHECK-NEXT: [[v20:%.*]] = llvm.mlir.constant(1275069450 : i32) : i32
+ // CHECK-NEXT: [[v21:%.*]] = llvm.mlir.constant(1140850688 : i32) : i32
+ // CHECK-NEXT: [[v22:%.*]] = llvm.call @MPI_Send([[v17]], [[v19]], [[v20]], [[v14]], [[v14]], [[v21]]) : (!llvm.ptr, i32, i32, i32, i32, i32) -> i32
mpi.send(%arg0, %rank, %rank) : memref<100xf32>, i32, i32
+ // CHECK: [[v23:%.*]] = llvm.extractvalue [[v5]][1] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
+ // CHECK-NEXT: [[v24:%.*]] = llvm.extractvalue [[v5]][2] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
+ // CHECK-NEXT: [[v25:%.*]] = llvm.getelementptr [[v23]][[[v24]]] : (!llvm.ptr, i64) -> !llvm.ptr, f32
+ // CHECK-NEXT: [[v26:%.*]] = llvm.extractvalue [[v5]][3, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
+ // CHECK-NEXT: [[v27:%.*]] = llvm.trunc [[v26]] : i64 to i32
+ // CHECK-NEXT: [[v28:%.*]] = llvm.mlir.constant(1275069450 : i32) : i32
+ // CHECK-NEXT: [[v29:%.*]] = llvm.mlir.constant(1140850688 : i32) : i32
+ // CHECK-NEXT: [[v30:%.*]] = llvm.call @MPI_Send([[v25]], [[v27]], [[v28]], [[v14]], [[v14]], [[v29]]) : (!llvm.ptr, i32, i32, i32, i32, i32) -> i32
%1 = mpi.send(%arg0, %rank, %rank) : memref<100xf32>, i32, i32 -> !mpi.retval
+ // CHECK: [[v31:%.*]] = llvm.extractvalue [[v5]][1] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
+ // CHECK-NEXT: [[v32:%.*]] = llvm.extractvalue [[v5]][2] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
+ // CHECK-NEXT: [[v33:%.*]] = llvm.getelementptr [[v31]][[[v32]]] : (!llvm.ptr, i64) -> !llvm.ptr, f32
+ // CHECK-NEXT: [[v34:%.*]] = llvm.extractvalue [[v5]][3, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
+ // CHECK-NEXT: [[v35:%.*]] = llvm.trunc [[v34]] : i64 to i32
+ // CHECK-NEXT: [[v36:%.*]] = llvm.mlir.constant(1275069450 : i32) : i32
+ // CHECK-NEXT: [[v37:%.*]] = llvm.mlir.constant(1140850688 : i32) : i32
+ // CHECK-NEXT: [[v38:%.*]] = llvm.mlir.constant(1 : i64) : i64
+ // CHECK-NEXT: [[v39:%.*]] = llvm.inttoptr [[v38]] : i64 to !llvm.ptr
+ // CHECK-NEXT: [[v40:%.*]] = llvm.call @MPI_Recv([[v33]], [[v35]], [[v36]], [[v14]], [[v14]], [[v37]], [[v39]]) : (!llvm.ptr, i32, i32, i32, i32, i32, !llvm.ptr) -> i32
mpi.recv(%arg0, %rank, %rank) : memref<100xf32>, i32, i32
+ // CHECK: [[v41:%.*]] = llvm.extractvalue [[v5]][1] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
+ // CHECK-NEXT: [[v42:%.*]] = llvm.extractvalue [[v5]][2] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
+ // CHECK-NEXT: [[v43:%.*]] = llvm.getelementptr [[v41]][[[v42]]] : (!llvm.ptr, i64) -> !llvm.ptr, f32
+ // CHECK-NEXT: [[v44:%.*]] = llvm.extractvalue [[v5]][3, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
+ // CHECK-NEXT: [[v45:%.*]] = llvm.trunc [[v44]] : i64 to i32
+ // CHECK-NEXT: [[v46:%.*]] = llvm.mlir.constant(1275069450 : i32) : i32
+ // CHECK-NEXT: [[v47:%.*]] = llvm.mlir.constant(1140850688 : i32) : i32
+ // CHECK-NEXT: [[v48:%.*]] = llvm.mlir.constant(1 : i64) : i64
+ // CHECK-NEXT: [[v49:%.*]] = llvm.inttoptr [[v48]] : i64 to !llvm.ptr
+ // CHECK-NEXT: [[v50:%.*]] = llvm.call @MPI_Recv([[v43]], [[v45]], [[v46]], [[v14]], [[v14]], [[v47]], [[v49]]) : (!llvm.ptr, i32, i32, i32, i32, i32, !llvm.ptr) -> i32
%2 = mpi.recv(%arg0, %rank, %rank) : memref<100xf32>, i32, i32 -> !mpi.retval
+ // CHECK: llvm.call @MPI_Finalize() : () -> i32
%3 = mpi.finalize : !mpi.retval
-// CHECK: %18 = llvm.call @MPI_Finalize() : () -> i32
%4 = mpi.retval_check %retval = <MPI_SUCCESS> : i1
>From 3672bdf0c7b2a47de6223181ca33d1dfac119cd4 Mon Sep 17 00:00:00 2001
From: "Schlimbach, Frank" <frank.schlimbach at intel.com>
Date: Fri, 7 Feb 2025 09:07:19 +0100
Subject: [PATCH 3/4] fixed finding MPI in cmake for MPIToLLVM MPIImplTraits
for OpenMPI
---
mlir/lib/Conversion/MPIToLLVM/CMakeLists.txt | 20 ++++-
mlir/lib/Conversion/MPIToLLVM/MPIImplTraits.h | 74 +++++++++++++------
mlir/lib/Conversion/MPIToLLVM/MPIToLLVM.cpp | 19 +++--
mlir/test/Conversion/MPIToLLVM/ops.mlir | 38 +++++-----
4 files changed, 101 insertions(+), 50 deletions(-)
diff --git a/mlir/lib/Conversion/MPIToLLVM/CMakeLists.txt b/mlir/lib/Conversion/MPIToLLVM/CMakeLists.txt
index 5d6b82e655b82..17df603ff5686 100644
--- a/mlir/lib/Conversion/MPIToLLVM/CMakeLists.txt
+++ b/mlir/lib/Conversion/MPIToLLVM/CMakeLists.txt
@@ -1,11 +1,20 @@
-find_package(MPI REQUIRED) # MPI_SKIP_COMPILER_WRAPPER TRUE MPI_CXX_SKIP_MPICXX TRUE)
+find_path(MPI_C_HEADER_DIR mpi.h
+ PATHS $ENV{I_MPI_ROOT}/include
+ $ENV{MPI_HOME}/include
+ $ENV{MPI_ROOT}/include)
+if(MPI_C_HEADER_DIR)
+ # cmake_path(REMOVE_FILENAME MPI_C_HEADER_DIR)
+ message(STATUS "found MPI_C_HEADER_DIR: ${MPI_C_HEADER_DIR}")
+else()
+ message(WARNING "MPI not found, disabling MLIRMPIToLLVM conversion")
+ return()
+endif()
add_mlir_conversion_library(MLIRMPIToLLVM
MPIToLLVM.cpp
ADDITIONAL_HEADER_DIRS
${MLIR_MAIN_INCLUDE_DIR}/mlir/Conversion/MPIToLLVM
- ${MPI_C_HEADER_DIR}
DEPENDS
MLIRConversionPassIncGen
@@ -17,4 +26,9 @@ add_mlir_conversion_library(MLIRMPIToLLVM
MLIRLLVMCommonConversion
MLIRLLVMDialect
MLIRMPIDialect
- )
+)
+target_include_directories(
+ MLIRMPIToLLVM
+ PRIVATE
+ ${MPI_C_HEADER_DIR}
+)
\ No newline at end of file
diff --git a/mlir/lib/Conversion/MPIToLLVM/MPIImplTraits.h b/mlir/lib/Conversion/MPIToLLVM/MPIImplTraits.h
index 5fd88d0aab4ac..09811e1cb7c61 100644
--- a/mlir/lib/Conversion/MPIToLLVM/MPIImplTraits.h
+++ b/mlir/lib/Conversion/MPIToLLVM/MPIImplTraits.h
@@ -1,4 +1,5 @@
#define MPICH_SKIP_MPICXX 1
+#define OMPI_SKIP_MPICXX 1
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/MPI/IR/MPI.h"
#include "mlir/Transforms/DialectConversion.h"
@@ -16,7 +17,8 @@ struct MPIImplTraits {
mlir::ConversionPatternRewriter &rewriter);
// get/create MPI datatype as a mlir::Value which corresponds to the given
// mlir::Type
- static mlir::Value getDataType(const mlir::Location loc,
+ static mlir::Value getDataType(mlir::ModuleOp &moduleOp,
+ const mlir::Location loc,
mlir::ConversionPatternRewriter &rewriter,
mlir::Type type);
};
@@ -33,7 +35,7 @@ MPIImplTraits::getCommWorld(mlir::ModuleOp &moduleOp, const mlir::Location loc,
}
mlir::Value
-MPIImplTraits::getDataType(const mlir::Location loc,
+MPIImplTraits::getDataType(mlir::ModuleOp &moduleOp, const mlir::Location loc,
mlir::ConversionPatternRewriter &rewriter,
mlir::Type type) {
int32_t mtype = 0;
@@ -42,21 +44,21 @@ MPIImplTraits::getDataType(const mlir::Location loc,
else if (type.isF64())
mtype = MPI_DOUBLE;
else if (type.isInteger(64) && !type.isUnsignedInteger())
- mtype = MPI_LONG;
+ mtype = MPI_INT64_T;
else if (type.isInteger(64))
- mtype = MPI_UNSIGNED_LONG;
+ mtype = MPI_UINT64_T;
else if (type.isInteger(32) && !type.isUnsignedInteger())
- mtype = MPI_INT;
+ mtype = MPI_INT32_T;
else if (type.isInteger(32))
- mtype = MPI_UNSIGNED;
+ mtype = MPI_UINT32_T;
else if (type.isInteger(16) && !type.isUnsignedInteger())
- mtype = MPI_SHORT;
+ mtype = MPI_INT16_T;
else if (type.isInteger(16))
- mtype = MPI_UNSIGNED_SHORT;
+ mtype = MPI_UINT16_T;
else if (type.isInteger(8) && !type.isUnsignedInteger())
- mtype = MPI_CHAR;
+ mtype = MPI_INT8_T;
else if (type.isInteger(8))
- mtype = MPI_UNSIGNED_CHAR;
+ mtype = MPI_UINT8_T;
else
assert(false && "unsupported type");
return rewriter.create<mlir::LLVM::ConstantOp>(loc, rewriter.getI32Type(),
@@ -88,28 +90,58 @@ mlir::Value
MPIImplTraits::getCommWorld(mlir::ModuleOp &moduleOp, const mlir::Location loc,
mlir::ConversionPatternRewriter &rewriter) {
auto context = rewriter.getContext();
- auto i32 = rewriter.getI32Type();
- // ptrType `!llvm.ptr`
- mlir::Type ptrType = mlir::LLVM::LLVMPointerType::get(context);
// get external opaque struct pointer type
auto commStructT =
- mlir::LLVM::LLVMStructType::getOpaque("MPI_ABI_Comm", context);
+ mlir::LLVM::LLVMStructType::getOpaque("ompi_communicator_t", context);
+ const char *name = "ompi_mpi_comm_world";
// make sure global op definition exists
- getOrDefineExternalStruct(moduleOp, loc, rewriter, "MPI_COMM_WORLD",
- commStructT);
+ (void)getOrDefineExternalStruct(moduleOp, loc, rewriter, name, commStructT);
- // get address of @MPI_COMM_WORLD
+ // get address of symbol
return rewriter.create<mlir::LLVM::AddressOfOp>(
- loc, ptrType, mlir::SymbolRefAttr::get(context, "MPI_COMM_WORLD"));
+ loc, mlir::LLVM::LLVMPointerType::get(context),
+ mlir::SymbolRefAttr::get(context, name));
}
mlir::Value
-MPIImplTraits::getDataType(const mlir::Location loc,
+MPIImplTraits::getDataType(mlir::ModuleOp &moduleOp, const mlir::Location loc,
mlir::ConversionPatternRewriter &rewriter,
mlir::Type type) {
- assert(false && "getDataType not implemented for this MPI implementation");
- return {};
+ const char *mtype = nullptr;
+ if (type.isF32())
+ mtype = "ompi_mpi_float";
+ else if (type.isF64())
+ mtype = "ompi_mpi_double";
+ else if (type.isInteger(64) && !type.isUnsignedInteger())
+ mtype = "ompi_mpi_int64_t";
+ else if (type.isInteger(64))
+ mtype = "ompi_mpi_uint64_t";
+ else if (type.isInteger(32) && !type.isUnsignedInteger())
+ mtype = "ompi_mpi_int32_t";
+ else if (type.isInteger(32))
+ mtype = "ompi_mpi_uint32_t";
+ else if (type.isInteger(16) && !type.isUnsignedInteger())
+ mtype = "ompi_mpi_int16_t";
+ else if (type.isInteger(16))
+ mtype = "ompi_mpi_uint16_t";
+ else if (type.isInteger(8) && !type.isUnsignedInteger())
+ mtype = "ompi_mpi_int8_t";
+ else if (type.isInteger(8))
+ mtype = "ompi_mpi_uint8_t";
+ else
+ assert(false && "unsupported type");
+
+ auto context = rewriter.getContext();
+ // get external opaque struct pointer type
+ auto commStructT = mlir::LLVM::LLVMStructType::getOpaque(
+ "ompi_predefined_datatype_t", context);
+ // make sure global op definition exists
+ (void)getOrDefineExternalStruct(moduleOp, loc, rewriter, mtype, commStructT);
+ // get address of symbol
+ return rewriter.create<mlir::LLVM::AddressOfOp>(
+ loc, mlir::LLVM::LLVMPointerType::get(context),
+ mlir::SymbolRefAttr::get(context, mtype));
}
#else
diff --git a/mlir/lib/Conversion/MPIToLLVM/MPIToLLVM.cpp b/mlir/lib/Conversion/MPIToLLVM/MPIToLLVM.cpp
index adeb9e85f05f7..9980c789cda6e 100644
--- a/mlir/lib/Conversion/MPIToLLVM/MPIToLLVM.cpp
+++ b/mlir/lib/Conversion/MPIToLLVM/MPIToLLVM.cpp
@@ -198,12 +198,14 @@ struct SendOpLowering : public ConvertOpToLLVMPattern<mpi::SendOp> {
.create<LLVM::ExtractValueOp>(loc, memRef, ArrayRef<int64_t>{3, 0})
.getResult();
size = rewriter.create<LLVM::TruncOp>(loc, i32, size).getResult();
- auto dataType = MPIImplTraits::getDataType(loc, rewriter, elemType);
+ auto dataType =
+ MPIImplTraits::getDataType(moduleOp, loc, rewriter, elemType);
auto commWorld = MPIImplTraits::getCommWorld(moduleOp, loc, rewriter);
- // LLVM Function type representing `i32 MPI_send(datatype, dst, tag, comm)`
+ // LLVM Function type representing `i32 MPI_Send(data, count, datatype, dst,
+ // tag, comm)`
auto funcType = LLVM::LLVMFunctionType::get(
- i32, {ptrType, i32, i32, i32, i32, commWorld.getType()});
+ i32, {ptrType, i32, dataType.getType(), i32, i32, commWorld.getType()});
// get or create function declaration:
LLVM::LLVMFuncOp funcDecl =
getOrDefineFunction(moduleOp, loc, rewriter, "MPI_Send", funcType);
@@ -261,7 +263,8 @@ struct RecvOpLowering : public ConvertOpToLLVMPattern<mpi::RecvOp> {
.create<LLVM::ExtractValueOp>(loc, memRef, ArrayRef<int64_t>{3, 0})
.getResult();
size = rewriter.create<LLVM::TruncOp>(loc, i32, size).getResult();
- auto dataType = MPIImplTraits::getDataType(loc, rewriter, elemType);
+ auto dataType =
+ MPIImplTraits::getDataType(moduleOp, loc, rewriter, elemType);
auto commWorld = MPIImplTraits::getCommWorld(moduleOp, loc, rewriter);
auto statusIgnore =
rewriter
@@ -271,9 +274,11 @@ struct RecvOpLowering : public ConvertOpToLLVMPattern<mpi::RecvOp> {
statusIgnore = rewriter.create<LLVM::IntToPtrOp>(loc, ptrType, statusIgnore)
.getResult();
- // LLVM Function type representing `i32 MPI_Recv(datatype, dst, tag, comm)`
- auto funcType = LLVM::LLVMFunctionType::get(
- i32, {ptrType, i32, i32, i32, i32, commWorld.getType(), ptrType});
+ // LLVM Function type representing `i32 MPI_Recv(data, count, datatype, src,
+ // tag, comm, status)`
+ auto funcType =
+ LLVM::LLVMFunctionType::get(i32, {ptrType, i32, dataType.getType(), i32,
+ i32, commWorld.getType(), ptrType});
// get or create function declaration:
LLVM::LLVMFuncOp funcDecl =
getOrDefineFunction(moduleOp, loc, rewriter, "MPI_Recv", funcType);
diff --git a/mlir/test/Conversion/MPIToLLVM/ops.mlir b/mlir/test/Conversion/MPIToLLVM/ops.mlir
index 61a4219fe35fc..449e6418976cc 100644
--- a/mlir/test/Conversion/MPIToLLVM/ops.mlir
+++ b/mlir/test/Conversion/MPIToLLVM/ops.mlir
@@ -2,9 +2,9 @@
module {
// CHECK: llvm.func @MPI_Finalize() -> i32
- // CHECK: llvm.func @MPI_Recv(!llvm.ptr, i32, i32, i32, i32, i32, !llvm.ptr) -> i32
- // CHECK: llvm.func @MPI_Send(!llvm.ptr, i32, i32, i32, i32, i32) -> i32
- // CHECK: llvm.func @MPI_Comm_rank({{.*}}, !llvm.ptr) -> i32
+ // CHECK: llvm.func @MPI_Recv(!llvm.ptr, i32, {{.+}}, i32, i32, {{.+}}, !llvm.ptr) -> i32
+ // CHECK: llvm.func @MPI_Send(!llvm.ptr, i32, {{.+}}, i32, i32, {{.+}}) -> i32
+ // CHECK: llvm.func @MPI_Comm_rank({{.+}}, !llvm.ptr) -> i32
// COMM: llvm.mlir.global external @MPI_COMM_WORLD() {addr_space = 0 : i32} : !llvm.struct<"MPI_ABI_Comm", opaque>
// CHECK: llvm.func @MPI_Init(!llvm.ptr, !llvm.ptr) -> i32
@@ -21,10 +21,10 @@ module {
// CHECK-NEXT: [[v8:%.*]] = builtin.unrealized_conversion_cast [[v7]] : i32 to !mpi.retval
%0 = mpi.init : !mpi.retval
- // CHECK: [[v9:%.*]] = llvm.mlir.constant(1140850688 : i32) : i32
+ // CHECK: [[v9:%.*]] = llvm.mlir.
// CHECK-NEXT: [[v10:%.*]] = llvm.mlir.constant(1 : i32) : i32
// CHECK-NEXT: [[v11:%.*]] = llvm.alloca [[v10]] x i32 : (i32) -> !llvm.ptr
- // CHECK-NEXT: [[v12:%.*]] = llvm.call @MPI_Comm_rank([[v9]], [[v11]]) : (i32, !llvm.ptr) -> i32
+ // CHECK-NEXT: [[v12:%.*]] = llvm.call @MPI_Comm_rank([[v9]], [[v11]]) : ({{.+}}, !llvm.ptr) -> i32
// CHECK-NEXT: [[v13:%.*]] = builtin.unrealized_conversion_cast [[v12]] : i32 to !mpi.retval
// CHECK-NEXT: [[v14:%.*]] = llvm.load [[v11]] : !llvm.ptr -> i32
%retval, %rank = mpi.comm_rank : !mpi.retval, i32
@@ -34,9 +34,9 @@ module {
// CHECK-NEXT: [[v17:%.*]] = llvm.getelementptr [[v15]][[[v16]]] : (!llvm.ptr, i64) -> !llvm.ptr, f32
// CHECK-NEXT: [[v18:%.*]] = llvm.extractvalue [[v5]][3, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
// CHECK-NEXT: [[v19:%.*]] = llvm.trunc [[v18]] : i64 to i32
- // CHECK-NEXT: [[v20:%.*]] = llvm.mlir.constant(1275069450 : i32) : i32
- // CHECK-NEXT: [[v21:%.*]] = llvm.mlir.constant(1140850688 : i32) : i32
- // CHECK-NEXT: [[v22:%.*]] = llvm.call @MPI_Send([[v17]], [[v19]], [[v20]], [[v14]], [[v14]], [[v21]]) : (!llvm.ptr, i32, i32, i32, i32, i32) -> i32
+ // CHECK-NEXT: [[v20:%.*]] = llvm.mlir.
+ // CHECK-NEXT: [[v21:%.*]] = llvm.mlir.
+ // CHECK-NEXT: [[v22:%.*]] = llvm.call @MPI_Send([[v17]], [[v19]], [[v20]], [[v14]], [[v14]], [[v21]])
mpi.send(%arg0, %rank, %rank) : memref<100xf32>, i32, i32
// CHECK: [[v23:%.*]] = llvm.extractvalue [[v5]][1] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
@@ -44,9 +44,9 @@ module {
// CHECK-NEXT: [[v25:%.*]] = llvm.getelementptr [[v23]][[[v24]]] : (!llvm.ptr, i64) -> !llvm.ptr, f32
// CHECK-NEXT: [[v26:%.*]] = llvm.extractvalue [[v5]][3, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
// CHECK-NEXT: [[v27:%.*]] = llvm.trunc [[v26]] : i64 to i32
- // CHECK-NEXT: [[v28:%.*]] = llvm.mlir.constant(1275069450 : i32) : i32
- // CHECK-NEXT: [[v29:%.*]] = llvm.mlir.constant(1140850688 : i32) : i32
- // CHECK-NEXT: [[v30:%.*]] = llvm.call @MPI_Send([[v25]], [[v27]], [[v28]], [[v14]], [[v14]], [[v29]]) : (!llvm.ptr, i32, i32, i32, i32, i32) -> i32
+ // CHECK-NEXT: [[v28:%.*]] = llvm.mlir.
+ // CHECK-NEXT: [[v29:%.*]] = llvm.mlir.
+ // CHECK-NEXT: [[v30:%.*]] = llvm.call @MPI_Send([[v25]], [[v27]], [[v28]], [[v14]], [[v14]], [[v29]])
%1 = mpi.send(%arg0, %rank, %rank) : memref<100xf32>, i32, i32 -> !mpi.retval
// CHECK: [[v31:%.*]] = llvm.extractvalue [[v5]][1] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
@@ -54,11 +54,11 @@ module {
// CHECK-NEXT: [[v33:%.*]] = llvm.getelementptr [[v31]][[[v32]]] : (!llvm.ptr, i64) -> !llvm.ptr, f32
// CHECK-NEXT: [[v34:%.*]] = llvm.extractvalue [[v5]][3, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
// CHECK-NEXT: [[v35:%.*]] = llvm.trunc [[v34]] : i64 to i32
- // CHECK-NEXT: [[v36:%.*]] = llvm.mlir.constant(1275069450 : i32) : i32
- // CHECK-NEXT: [[v37:%.*]] = llvm.mlir.constant(1140850688 : i32) : i32
- // CHECK-NEXT: [[v38:%.*]] = llvm.mlir.constant(1 : i64) : i64
+ // CHECK-NEXT: [[v36:%.*]] = llvm.mlir.
+ // CHECK-NEXT: [[v37:%.*]] = llvm.mlir.
+ // CHECK-NEXT: [[v38:%.*]] = llvm.mlir.constant({{[0-9]+}} : i64) : i64
// CHECK-NEXT: [[v39:%.*]] = llvm.inttoptr [[v38]] : i64 to !llvm.ptr
- // CHECK-NEXT: [[v40:%.*]] = llvm.call @MPI_Recv([[v33]], [[v35]], [[v36]], [[v14]], [[v14]], [[v37]], [[v39]]) : (!llvm.ptr, i32, i32, i32, i32, i32, !llvm.ptr) -> i32
+ // CHECK-NEXT: [[v40:%.*]] = llvm.call @MPI_Recv([[v33]], [[v35]], [[v36]], [[v14]], [[v14]], [[v37]], [[v39]])
mpi.recv(%arg0, %rank, %rank) : memref<100xf32>, i32, i32
// CHECK: [[v41:%.*]] = llvm.extractvalue [[v5]][1] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
@@ -66,11 +66,11 @@ module {
// CHECK-NEXT: [[v43:%.*]] = llvm.getelementptr [[v41]][[[v42]]] : (!llvm.ptr, i64) -> !llvm.ptr, f32
// CHECK-NEXT: [[v44:%.*]] = llvm.extractvalue [[v5]][3, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
// CHECK-NEXT: [[v45:%.*]] = llvm.trunc [[v44]] : i64 to i32
- // CHECK-NEXT: [[v46:%.*]] = llvm.mlir.constant(1275069450 : i32) : i32
- // CHECK-NEXT: [[v47:%.*]] = llvm.mlir.constant(1140850688 : i32) : i32
- // CHECK-NEXT: [[v48:%.*]] = llvm.mlir.constant(1 : i64) : i64
+ // CHECK-NEXT: [[v46:%.*]] = llvm.mlir.
+ // CHECK-NEXT: [[v47:%.*]] = llvm.mlir.
+ // CHECK-NEXT: [[v48:%.*]] = llvm.mlir.constant({{[0-9]+}} : i64) : i64
// CHECK-NEXT: [[v49:%.*]] = llvm.inttoptr [[v48]] : i64 to !llvm.ptr
- // CHECK-NEXT: [[v50:%.*]] = llvm.call @MPI_Recv([[v43]], [[v45]], [[v46]], [[v14]], [[v14]], [[v47]], [[v49]]) : (!llvm.ptr, i32, i32, i32, i32, i32, !llvm.ptr) -> i32
+ // CHECK-NEXT: [[v50:%.*]] = llvm.call @MPI_Recv([[v43]], [[v45]], [[v46]], [[v14]], [[v14]], [[v47]], [[v49]])
%2 = mpi.recv(%arg0, %rank, %rank) : memref<100xf32>, i32, i32 -> !mpi.retval
// CHECK: llvm.call @MPI_Finalize() : () -> i32
>From 0597c7473dce0bd62520b1cc821bc82f703b05b4 Mon Sep 17 00:00:00 2001
From: "Schlimbach, Frank" <frank.schlimbach at intel.com>
Date: Thu, 13 Feb 2025 16:27:29 +0100
Subject: [PATCH 4/4] provide empty register/populate functions even when
MPIToLLVM is disabled
---
mlir/lib/Conversion/MPIToLLVM/CMakeLists.txt | 21 +++++++++-----------
mlir/lib/Conversion/MPIToLLVM/MPIToLLVM.cpp | 15 ++++++++++++++
2 files changed, 24 insertions(+), 12 deletions(-)
diff --git a/mlir/lib/Conversion/MPIToLLVM/CMakeLists.txt b/mlir/lib/Conversion/MPIToLLVM/CMakeLists.txt
index 17df603ff5686..39c5a9d145818 100644
--- a/mlir/lib/Conversion/MPIToLLVM/CMakeLists.txt
+++ b/mlir/lib/Conversion/MPIToLLVM/CMakeLists.txt
@@ -2,13 +2,6 @@ find_path(MPI_C_HEADER_DIR mpi.h
PATHS $ENV{I_MPI_ROOT}/include
$ENV{MPI_HOME}/include
$ENV{MPI_ROOT}/include)
-if(MPI_C_HEADER_DIR)
- # cmake_path(REMOVE_FILENAME MPI_C_HEADER_DIR)
- message(STATUS "found MPI_C_HEADER_DIR: ${MPI_C_HEADER_DIR}")
-else()
- message(WARNING "MPI not found, disabling MLIRMPIToLLVM conversion")
- return()
-endif()
add_mlir_conversion_library(MLIRMPIToLLVM
MPIToLLVM.cpp
@@ -27,8 +20,12 @@ add_mlir_conversion_library(MLIRMPIToLLVM
MLIRLLVMDialect
MLIRMPIDialect
)
-target_include_directories(
- MLIRMPIToLLVM
- PRIVATE
- ${MPI_C_HEADER_DIR}
-)
\ No newline at end of file
+
+if(MPI_C_HEADER_DIR)
+ # cmake_path(REMOVE_FILENAME MPI_C_HEADER_DIR)
+ message(STATUS "found MPI_C_HEADER_DIR: ${MPI_C_HEADER_DIR}")
+ target_include_directories(obj.MLIRMPIToLLVM PRIVATE ${MPI_C_HEADER_DIR})
+ target_compile_definitions(obj.MLIRMPIToLLVM PUBLIC FOUND_MPI_C_HEADER=1)
+else()
+ message(WARNING "MPI not found, disabling MLIRMPIToLLVM conversion")
+endif()
diff --git a/mlir/lib/Conversion/MPIToLLVM/MPIToLLVM.cpp b/mlir/lib/Conversion/MPIToLLVM/MPIToLLVM.cpp
index 9980c789cda6e..9f9f0ee078746 100644
--- a/mlir/lib/Conversion/MPIToLLVM/MPIToLLVM.cpp
+++ b/mlir/lib/Conversion/MPIToLLVM/MPIToLLVM.cpp
@@ -6,6 +6,9 @@
//
//===----------------------------------------------------------------------===//
+// skip if no MPI C header was found
+#ifdef FOUND_MPI_C_HEADER
+
// This must go first (MPI gets confused otherwise)
#include "MPIImplTraits.h"
@@ -330,3 +333,15 @@ void mlir::mpi::registerConvertMPIToLLVMInterface(DialectRegistry &registry) {
dialect->addInterfaces<FuncToLLVMDialectInterface>();
});
}
+
+#else // FOUND_MPI_C_HEADER
+
+#include "mlir/Conversion/MPIToLLVM/MPIToLLVM.h"
+using namespace mlir;
+
+void mlir::mpi::populateMPIToLLVMConversionPatterns(
+ LLVMTypeConverter &converter, RewritePatternSet &patterns) {}
+
+void mlir::mpi::registerConvertMPIToLLVMInterface(DialectRegistry &registry) {}
+
+#endif // FOUND_MPI_C_HEADER
More information about the Mlir-commits
mailing list