[Mlir-commits] [mlir] [mlir][mpi] Lowering Mpi To LLVM (PR #127053)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Thu Feb 13 04:46:41 PST 2025
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir
Author: Frank Schlimbach (fschlimb)
<details>
<summary>Changes</summary>
As agreed by @AntonLydike, this replaces #95524.
* The first set of patterns to convert the MPI dialect to LLVM.
* Further conversion patterns will be added in future PRs.
This adds the following on top of #95524:
* Support for Intel MPI by introducing MPIImplTraits to distinguish MPI implementations
* Works end-to-end in a downstream project using Intel MPI (going through Mesh dialect)
* Lowering MPI_Send/MPI_Recv
* Conversion will only be included if MPI is found
cc @mofeing @wsmoses @tobiasgrosser @hhkit
---
Patch is 44.08 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/127053.diff
10 Files Affected:
- (added) mlir/include/mlir/Conversion/MPIToLLVM/MPIToLLVM.h (+31)
- (modified) mlir/include/mlir/Dialect/MPI/IR/MPI.td (+77-77)
- (modified) mlir/include/mlir/Dialect/MPI/IR/MPIOps.td (+8-8)
- (modified) mlir/include/mlir/Dialect/MPI/IR/MPITypes.td (+1-1)
- (modified) mlir/include/mlir/InitAllExtensions.h (+2)
- (modified) mlir/lib/Conversion/CMakeLists.txt (+1)
- (added) mlir/lib/Conversion/MPIToLLVM/CMakeLists.txt (+34)
- (added) mlir/lib/Conversion/MPIToLLVM/MPIImplTraits.h (+151)
- (added) mlir/lib/Conversion/MPIToLLVM/MPIToLLVM.cpp (+332)
- (added) mlir/test/Conversion/MPIToLLVM/ops.mlir (+84)
``````````diff
diff --git a/mlir/include/mlir/Conversion/MPIToLLVM/MPIToLLVM.h b/mlir/include/mlir/Conversion/MPIToLLVM/MPIToLLVM.h
new file mode 100644
index 0000000000000..8d2698aa91c7c
--- /dev/null
+++ b/mlir/include/mlir/Conversion/MPIToLLVM/MPIToLLVM.h
@@ -0,0 +1,31 @@
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_CONVERSION_MPITOLLVM_H
+#define MLIR_CONVERSION_MPITOLLVM_H
+
+#include "mlir/IR/DialectRegistry.h"
+
+namespace mlir {
+
+class LLVMTypeConverter;
+class RewritePatternSet;
+
+#define GEN_PASS_DECL_MPITOLLVMCONVERSIONPASS
+#include "mlir/Conversion/Passes.h.inc"
+
+namespace mpi {
+
+void populateMPIToLLVMConversionPatterns(LLVMTypeConverter &converter,
+ RewritePatternSet &patterns);
+
+void registerConvertMPIToLLVMInterface(DialectRegistry &registry);
+
+} // namespace mpi
+} // namespace mlir
+
+#endif // MLIR_CONVERSION_MPITOLLVM_H
diff --git a/mlir/include/mlir/Dialect/MPI/IR/MPI.td b/mlir/include/mlir/Dialect/MPI/IR/MPI.td
index 7c84443e5520d..df0cf9d518faf 100644
--- a/mlir/include/mlir/Dialect/MPI/IR/MPI.td
+++ b/mlir/include/mlir/Dialect/MPI/IR/MPI.td
@@ -42,104 +42,104 @@ def MPI_Dialect : Dialect {
// Error classes enum:
//===----------------------------------------------------------------------===//
-def MPI_CodeSuccess : I32EnumAttrCase<"MPI_SUCCESS", 0, "MPI_SUCCESS">;
-def MPI_CodeErrAccess : I32EnumAttrCase<"MPI_ERR_ACCESS", 1, "MPI_ERR_ACCESS">;
-def MPI_CodeErrAmode : I32EnumAttrCase<"MPI_ERR_AMODE", 2, "MPI_ERR_AMODE">;
-def MPI_CodeErrArg : I32EnumAttrCase<"MPI_ERR_ARG", 3, "MPI_ERR_ARG">;
-def MPI_CodeErrAssert : I32EnumAttrCase<"MPI_ERR_ASSERT", 4, "MPI_ERR_ASSERT">;
+def MPI_CodeSuccess : I32EnumAttrCase<"_MPI_SUCCESS", 0, "MPI_SUCCESS">;
+def MPI_CodeErrAccess : I32EnumAttrCase<"_MPI_ERR_ACCESS", 1, "MPI_ERR_ACCESS">;
+def MPI_CodeErrAmode : I32EnumAttrCase<"_MPI_ERR_AMODE", 2, "MPI_ERR_AMODE">;
+def MPI_CodeErrArg : I32EnumAttrCase<"_MPI_ERR_ARG", 3, "MPI_ERR_ARG">;
+def MPI_CodeErrAssert : I32EnumAttrCase<"_MPI_ERR_ASSERT", 4, "MPI_ERR_ASSERT">;
def MPI_CodeErrBadFile
- : I32EnumAttrCase<"MPI_ERR_BAD_FILE", 5, "MPI_ERR_BAD_FILE">;
-def MPI_CodeErrBase : I32EnumAttrCase<"MPI_ERR_BASE", 6, "MPI_ERR_BASE">;
-def MPI_CodeErrBuffer : I32EnumAttrCase<"MPI_ERR_BUFFER", 7, "MPI_ERR_BUFFER">;
-def MPI_CodeErrComm : I32EnumAttrCase<"MPI_ERR_COMM", 8, "MPI_ERR_COMM">;
+ : I32EnumAttrCase<"_MPI_ERR_BAD_FILE", 5, "MPI_ERR_BAD_FILE">;
+def MPI_CodeErrBase : I32EnumAttrCase<"_MPI_ERR_BASE", 6, "MPI_ERR_BASE">;
+def MPI_CodeErrBuffer : I32EnumAttrCase<"_MPI_ERR_BUFFER", 7, "MPI_ERR_BUFFER">;
+def MPI_CodeErrComm : I32EnumAttrCase<"_MPI_ERR_COMM", 8, "MPI_ERR_COMM">;
def MPI_CodeErrConversion
- : I32EnumAttrCase<"MPI_ERR_CONVERSION", 9, "MPI_ERR_CONVERSION">;
-def MPI_CodeErrCount : I32EnumAttrCase<"MPI_ERR_COUNT", 10, "MPI_ERR_COUNT">;
-def MPI_CodeErrDims : I32EnumAttrCase<"MPI_ERR_DIMS", 11, "MPI_ERR_DIMS">;
-def MPI_CodeErrDisp : I32EnumAttrCase<"MPI_ERR_DISP", 12, "MPI_ERR_DISP">;
+ : I32EnumAttrCase<"_MPI_ERR_CONVERSION", 9, "MPI_ERR_CONVERSION">;
+def MPI_CodeErrCount : I32EnumAttrCase<"_MPI_ERR_COUNT", 10, "MPI_ERR_COUNT">;
+def MPI_CodeErrDims : I32EnumAttrCase<"_MPI_ERR_DIMS", 11, "MPI_ERR_DIMS">;
+def MPI_CodeErrDisp : I32EnumAttrCase<"_MPI_ERR_DISP", 12, "MPI_ERR_DISP">;
def MPI_CodeErrDupDatarep
- : I32EnumAttrCase<"MPI_ERR_DUP_DATAREP", 13, "MPI_ERR_DUP_DATAREP">;
+ : I32EnumAttrCase<"_MPI_ERR_DUP_DATAREP", 13, "MPI_ERR_DUP_DATAREP">;
def MPI_CodeErrErrhandler
- : I32EnumAttrCase<"MPI_ERR_ERRHANDLER", 14, "MPI_ERR_ERRHANDLER">;
-def MPI_CodeErrFile : I32EnumAttrCase<"MPI_ERR_FILE", 15, "MPI_ERR_FILE">;
+ : I32EnumAttrCase<"_MPI_ERR_ERRHANDLER", 14, "MPI_ERR_ERRHANDLER">;
+def MPI_CodeErrFile : I32EnumAttrCase<"_MPI_ERR_FILE", 15, "MPI_ERR_FILE">;
def MPI_CodeErrFileExists
- : I32EnumAttrCase<"MPI_ERR_FILE_EXISTS", 16, "MPI_ERR_FILE_EXISTS">;
+ : I32EnumAttrCase<"_MPI_ERR_FILE_EXISTS", 16, "MPI_ERR_FILE_EXISTS">;
def MPI_CodeErrFileInUse
- : I32EnumAttrCase<"MPI_ERR_FILE_IN_USE", 17, "MPI_ERR_FILE_IN_USE">;
-def MPI_CodeErrGroup : I32EnumAttrCase<"MPI_ERR_GROUP", 18, "MPI_ERR_GROUP">;
-def MPI_CodeErrInfo : I32EnumAttrCase<"MPI_ERR_INFO", 19, "MPI_ERR_INFO">;
+ : I32EnumAttrCase<"_MPI_ERR_FILE_IN_USE", 17, "MPI_ERR_FILE_IN_USE">;
+def MPI_CodeErrGroup : I32EnumAttrCase<"_MPI_ERR_GROUP", 18, "MPI_ERR_GROUP">;
+def MPI_CodeErrInfo : I32EnumAttrCase<"_MPI_ERR_INFO", 19, "MPI_ERR_INFO">;
def MPI_CodeErrInfoKey
- : I32EnumAttrCase<"MPI_ERR_INFO_KEY", 20, "MPI_ERR_INFO_KEY">;
+ : I32EnumAttrCase<"_MPI_ERR_INFO_KEY", 20, "MPI_ERR_INFO_KEY">;
def MPI_CodeErrInfoNokey
- : I32EnumAttrCase<"MPI_ERR_INFO_NOKEY", 21, "MPI_ERR_INFO_NOKEY">;
+ : I32EnumAttrCase<"_MPI_ERR_INFO_NOKEY", 21, "MPI_ERR_INFO_NOKEY">;
def MPI_CodeErrInfoValue
- : I32EnumAttrCase<"MPI_ERR_INFO_VALUE", 22, "MPI_ERR_INFO_VALUE">;
+ : I32EnumAttrCase<"_MPI_ERR_INFO_VALUE", 22, "MPI_ERR_INFO_VALUE">;
def MPI_CodeErrInStatus
- : I32EnumAttrCase<"MPI_ERR_IN_STATUS", 23, "MPI_ERR_IN_STATUS">;
-def MPI_CodeErrIntern : I32EnumAttrCase<"MPI_ERR_INTERN", 24, "MPI_ERR_INTERN">;
-def MPI_CodeErrIo : I32EnumAttrCase<"MPI_ERR_IO", 25, "MPI_ERR_IO">;
-def MPI_CodeErrKeyval : I32EnumAttrCase<"MPI_ERR_KEYVAL", 26, "MPI_ERR_KEYVAL">;
+ : I32EnumAttrCase<"_MPI_ERR_IN_STATUS", 23, "MPI_ERR_IN_STATUS">;
+def MPI_CodeErrIntern : I32EnumAttrCase<"_MPI_ERR_INTERN", 24, "MPI_ERR_INTERN">;
+def MPI_CodeErrIo : I32EnumAttrCase<"_MPI_ERR_IO", 25, "MPI_ERR_IO">;
+def MPI_CodeErrKeyval : I32EnumAttrCase<"_MPI_ERR_KEYVAL", 26, "MPI_ERR_KEYVAL">;
def MPI_CodeErrLocktype
- : I32EnumAttrCase<"MPI_ERR_LOCKTYPE", 27, "MPI_ERR_LOCKTYPE">;
-def MPI_CodeErrName : I32EnumAttrCase<"MPI_ERR_NAME", 28, "MPI_ERR_NAME">;
-def MPI_CodeErrNoMem : I32EnumAttrCase<"MPI_ERR_NO_MEM", 29, "MPI_ERR_NO_MEM">;
+ : I32EnumAttrCase<"_MPI_ERR_LOCKTYPE", 27, "MPI_ERR_LOCKTYPE">;
+def MPI_CodeErrName : I32EnumAttrCase<"_MPI_ERR_NAME", 28, "MPI_ERR_NAME">;
+def MPI_CodeErrNoMem : I32EnumAttrCase<"_MPI_ERR_NO_MEM", 29, "MPI_ERR_NO_MEM">;
def MPI_CodeErrNoSpace
- : I32EnumAttrCase<"MPI_ERR_NO_SPACE", 30, "MPI_ERR_NO_SPACE">;
+ : I32EnumAttrCase<"_MPI_ERR_NO_SPACE", 30, "MPI_ERR_NO_SPACE">;
def MPI_CodeErrNoSuchFile
- : I32EnumAttrCase<"MPI_ERR_NO_SUCH_FILE", 31, "MPI_ERR_NO_SUCH_FILE">;
+ : I32EnumAttrCase<"_MPI_ERR_NO_SUCH_FILE", 31, "MPI_ERR_NO_SUCH_FILE">;
def MPI_CodeErrNotSame
- : I32EnumAttrCase<"MPI_ERR_NOT_SAME", 32, "MPI_ERR_NOT_SAME">;
-def MPI_CodeErrOp : I32EnumAttrCase<"MPI_ERR_OP", 33, "MPI_ERR_OP">;
-def MPI_CodeErrOther : I32EnumAttrCase<"MPI_ERR_OTHER", 34, "MPI_ERR_OTHER">;
+ : I32EnumAttrCase<"_MPI_ERR_NOT_SAME", 32, "MPI_ERR_NOT_SAME">;
+def MPI_CodeErrOp : I32EnumAttrCase<"_MPI_ERR_OP", 33, "MPI_ERR_OP">;
+def MPI_CodeErrOther : I32EnumAttrCase<"_MPI_ERR_OTHER", 34, "MPI_ERR_OTHER">;
def MPI_CodeErrPending
- : I32EnumAttrCase<"MPI_ERR_PENDING", 35, "MPI_ERR_PENDING">;
-def MPI_CodeErrPort : I32EnumAttrCase<"MPI_ERR_PORT", 36, "MPI_ERR_PORT">;
+ : I32EnumAttrCase<"_MPI_ERR_PENDING", 35, "MPI_ERR_PENDING">;
+def MPI_CodeErrPort : I32EnumAttrCase<"_MPI_ERR_PORT", 36, "MPI_ERR_PORT">;
def MPI_CodeErrProcAborted
- : I32EnumAttrCase<"MPI_ERR_PROC_ABORTED", 37, "MPI_ERR_PROC_ABORTED">;
-def MPI_CodeErrQuota : I32EnumAttrCase<"MPI_ERR_QUOTA", 38, "MPI_ERR_QUOTA">;
-def MPI_CodeErrRank : I32EnumAttrCase<"MPI_ERR_RANK", 39, "MPI_ERR_RANK">;
+ : I32EnumAttrCase<"_MPI_ERR_PROC_ABORTED", 37, "MPI_ERR_PROC_ABORTED">;
+def MPI_CodeErrQuota : I32EnumAttrCase<"_MPI_ERR_QUOTA", 38, "MPI_ERR_QUOTA">;
+def MPI_CodeErrRank : I32EnumAttrCase<"_MPI_ERR_RANK", 39, "MPI_ERR_RANK">;
def MPI_CodeErrReadOnly
- : I32EnumAttrCase<"MPI_ERR_READ_ONLY", 40, "MPI_ERR_READ_ONLY">;
+ : I32EnumAttrCase<"_MPI_ERR_READ_ONLY", 40, "MPI_ERR_READ_ONLY">;
def MPI_CodeErrRequest
- : I32EnumAttrCase<"MPI_ERR_REQUEST", 41, "MPI_ERR_REQUEST">;
+ : I32EnumAttrCase<"_MPI_ERR_REQUEST", 41, "MPI_ERR_REQUEST">;
def MPI_CodeErrRmaAttach
- : I32EnumAttrCase<"MPI_ERR_RMA_ATTACH", 42, "MPI_ERR_RMA_ATTACH">;
+ : I32EnumAttrCase<"_MPI_ERR_RMA_ATTACH", 42, "MPI_ERR_RMA_ATTACH">;
def MPI_CodeErrRmaConflict
- : I32EnumAttrCase<"MPI_ERR_RMA_CONFLICT", 43, "MPI_ERR_RMA_CONFLICT">;
+ : I32EnumAttrCase<"_MPI_ERR_RMA_CONFLICT", 43, "MPI_ERR_RMA_CONFLICT">;
def MPI_CodeErrRmaFlavor
- : I32EnumAttrCase<"MPI_ERR_RMA_FLAVOR", 44, "MPI_ERR_RMA_FLAVOR">;
+ : I32EnumAttrCase<"_MPI_ERR_RMA_FLAVOR", 44, "MPI_ERR_RMA_FLAVOR">;
def MPI_CodeErrRmaRange
- : I32EnumAttrCase<"MPI_ERR_RMA_RANGE", 45, "MPI_ERR_RMA_RANGE">;
+ : I32EnumAttrCase<"_MPI_ERR_RMA_RANGE", 45, "MPI_ERR_RMA_RANGE">;
def MPI_CodeErrRmaShared
- : I32EnumAttrCase<"MPI_ERR_RMA_SHARED", 46, "MPI_ERR_RMA_SHARED">;
+ : I32EnumAttrCase<"_MPI_ERR_RMA_SHARED", 46, "MPI_ERR_RMA_SHARED">;
def MPI_CodeErrRmaSync
- : I32EnumAttrCase<"MPI_ERR_RMA_SYNC", 47, "MPI_ERR_RMA_SYNC">;
-def MPI_CodeErrRoot : I32EnumAttrCase<"MPI_ERR_ROOT", 48, "MPI_ERR_ROOT">;
+ : I32EnumAttrCase<"_MPI_ERR_RMA_SYNC", 47, "MPI_ERR_RMA_SYNC">;
+def MPI_CodeErrRoot : I32EnumAttrCase<"_MPI_ERR_ROOT", 48, "MPI_ERR_ROOT">;
def MPI_CodeErrService
- : I32EnumAttrCase<"MPI_ERR_SERVICE", 49, "MPI_ERR_SERVICE">;
+ : I32EnumAttrCase<"_MPI_ERR_SERVICE", 49, "MPI_ERR_SERVICE">;
def MPI_CodeErrSession
- : I32EnumAttrCase<"MPI_ERR_SESSION", 50, "MPI_ERR_SESSION">;
-def MPI_CodeErrSize : I32EnumAttrCase<"MPI_ERR_SIZE", 51, "MPI_ERR_SIZE">;
-def MPI_CodeErrSpawn : I32EnumAttrCase<"MPI_ERR_SPAWN", 52, "MPI_ERR_SPAWN">;
-def MPI_CodeErrTag : I32EnumAttrCase<"MPI_ERR_TAG", 53, "MPI_ERR_TAG">;
+ : I32EnumAttrCase<"_MPI_ERR_SESSION", 50, "MPI_ERR_SESSION">;
+def MPI_CodeErrSize : I32EnumAttrCase<"_MPI_ERR_SIZE", 51, "MPI_ERR_SIZE">;
+def MPI_CodeErrSpawn : I32EnumAttrCase<"_MPI_ERR_SPAWN", 52, "MPI_ERR_SPAWN">;
+def MPI_CodeErrTag : I32EnumAttrCase<"_MPI_ERR_TAG", 53, "MPI_ERR_TAG">;
def MPI_CodeErrTopology
- : I32EnumAttrCase<"MPI_ERR_TOPOLOGY", 54, "MPI_ERR_TOPOLOGY">;
+ : I32EnumAttrCase<"_MPI_ERR_TOPOLOGY", 54, "MPI_ERR_TOPOLOGY">;
def MPI_CodeErrTruncate
- : I32EnumAttrCase<"MPI_ERR_TRUNCATE", 55, "MPI_ERR_TRUNCATE">;
-def MPI_CodeErrType : I32EnumAttrCase<"MPI_ERR_TYPE", 56, "MPI_ERR_TYPE">;
+ : I32EnumAttrCase<"_MPI_ERR_TRUNCATE", 55, "MPI_ERR_TRUNCATE">;
+def MPI_CodeErrType : I32EnumAttrCase<"_MPI_ERR_TYPE", 56, "MPI_ERR_TYPE">;
def MPI_CodeErrUnknown
- : I32EnumAttrCase<"MPI_ERR_UNKNOWN", 57, "MPI_ERR_UNKNOWN">;
+ : I32EnumAttrCase<"_MPI_ERR_UNKNOWN", 57, "MPI_ERR_UNKNOWN">;
def MPI_CodeErrUnsupportedDatarep
- : I32EnumAttrCase<"MPI_ERR_UNSUPPORTED_DATAREP", 58,
+ : I32EnumAttrCase<"_MPI_ERR_UNSUPPORTED_DATAREP", 58,
"MPI_ERR_UNSUPPORTED_DATAREP">;
def MPI_CodeErrUnsupportedOperation
- : I32EnumAttrCase<"MPI_ERR_UNSUPPORTED_OPERATION", 59,
+ : I32EnumAttrCase<"_MPI_ERR_UNSUPPORTED_OPERATION", 59,
"MPI_ERR_UNSUPPORTED_OPERATION">;
def MPI_CodeErrValueTooLarge
- : I32EnumAttrCase<"MPI_ERR_VALUE_TOO_LARGE", 60, "MPI_ERR_VALUE_TOO_LARGE">;
-def MPI_CodeErrWin : I32EnumAttrCase<"MPI_ERR_WIN", 61, "MPI_ERR_WIN">;
+ : I32EnumAttrCase<"_MPI_ERR_VALUE_TOO_LARGE", 60, "MPI_ERR_VALUE_TOO_LARGE">;
+def MPI_CodeErrWin : I32EnumAttrCase<"_MPI_ERR_WIN", 61, "MPI_ERR_WIN">;
def MPI_CodeErrLastcode
- : I32EnumAttrCase<"MPI_ERR_LASTCODE", 62, "MPI_ERR_LASTCODE">;
+ : I32EnumAttrCase<"_MPI_ERR_LASTCODE", 62, "MPI_ERR_LASTCODE">;
def MPI_ErrorClassEnum
: I32EnumAttr<"MPI_ErrorClassEnum", "MPI error class name", [
@@ -215,20 +215,20 @@ def MPI_ErrorClassAttr : EnumAttr<MPI_Dialect, MPI_ErrorClassEnum, "errclass"> {
let assemblyFormat = "`<` $value `>`";
}
-def MPI_OpNull : I32EnumAttrCase<"MPI_OP_NULL", 0, "MPI_OP_NULL">;
-def MPI_OpMax : I32EnumAttrCase<"MPI_MAX", 1, "MPI_MAX">;
-def MPI_OpMin : I32EnumAttrCase<"MPI_MIN", 2, "MPI_MIN">;
-def MPI_OpSum : I32EnumAttrCase<"MPI_SUM", 3, "MPI_SUM">;
-def MPI_OpProd : I32EnumAttrCase<"MPI_PROD", 4, "MPI_PROD">;
-def MPI_OpLand : I32EnumAttrCase<"MPI_LAND", 5, "MPI_LAND">;
-def MPI_OpBand : I32EnumAttrCase<"MPI_BAND", 6, "MPI_BAND">;
-def MPI_OpLor : I32EnumAttrCase<"MPI_LOR", 7, "MPI_LOR">;
-def MPI_OpBor : I32EnumAttrCase<"MPI_BOR", 8, "MPI_BOR">;
-def MPI_OpLxor : I32EnumAttrCase<"MPI_LXOR", 9, "MPI_LXOR">;
-def MPI_OpBxor : I32EnumAttrCase<"MPI_BXOR", 10, "MPI_BXOR">;
-def MPI_OpMinloc : I32EnumAttrCase<"MPI_MINLOC", 11, "MPI_MINLOC">;
-def MPI_OpMaxloc : I32EnumAttrCase<"MPI_MAXLOC", 12, "MPI_MAXLOC">;
-def MPI_OpReplace : I32EnumAttrCase<"MPI_REPLACE", 13, "MPI_REPLACE">;
+def MPI_OpNull : I32EnumAttrCase<"_MPI_OP_NULL", 0, "MPI_OP_NULL">;
+def MPI_OpMax : I32EnumAttrCase<"_MPI_MAX", 1, "MPI_MAX">;
+def MPI_OpMin : I32EnumAttrCase<"_MPI_MIN", 2, "MPI_MIN">;
+def MPI_OpSum : I32EnumAttrCase<"_MPI_SUM", 3, "MPI_SUM">;
+def MPI_OpProd : I32EnumAttrCase<"_MPI_PROD", 4, "MPI_PROD">;
+def MPI_OpLand : I32EnumAttrCase<"_MPI_LAND", 5, "MPI_LAND">;
+def MPI_OpBand : I32EnumAttrCase<"_MPI_BAND", 6, "MPI_BAND">;
+def MPI_OpLor : I32EnumAttrCase<"_MPI_LOR", 7, "MPI_LOR">;
+def MPI_OpBor : I32EnumAttrCase<"_MPI_BOR", 8, "MPI_BOR">;
+def MPI_OpLxor : I32EnumAttrCase<"_MPI_LXOR", 9, "MPI_LXOR">;
+def MPI_OpBxor : I32EnumAttrCase<"_MPI_BXOR", 10, "MPI_BXOR">;
+def MPI_OpMinloc : I32EnumAttrCase<"_MPI_MINLOC", 11, "MPI_MINLOC">;
+def MPI_OpMaxloc : I32EnumAttrCase<"_MPI_MAXLOC", 12, "MPI_MAXLOC">;
+def MPI_OpReplace : I32EnumAttrCase<"_MPI_REPLACE", 13, "MPI_REPLACE">;
def MPI_OpClassEnum : I32EnumAttr<"MPI_OpClassEnum", "MPI operation class", [
MPI_OpNull,
diff --git a/mlir/include/mlir/Dialect/MPI/IR/MPIOps.td b/mlir/include/mlir/Dialect/MPI/IR/MPIOps.td
index 284ba72af9768..db28bd09678f8 100644
--- a/mlir/include/mlir/Dialect/MPI/IR/MPIOps.td
+++ b/mlir/include/mlir/Dialect/MPI/IR/MPIOps.td
@@ -102,13 +102,13 @@ def MPI_SendOp : MPI_Op<"send", []> {
let arguments = (
ins AnyMemRef : $ref,
I32 : $tag,
- I32 : $rank
+ I32 : $dest
);
let results = (outs Optional<MPI_Retval>:$retval);
- let assemblyFormat = "`(` $ref `,` $tag `,` $rank `)` attr-dict `:` "
- "type($ref) `,` type($tag) `,` type($rank)"
+ let assemblyFormat = "`(` $ref `,` $tag `,` $dest `)` attr-dict `:` "
+ "type($ref) `,` type($tag) `,` type($dest)"
"(`->` type($retval)^)?";
let hasCanonicalizer = 1;
}
@@ -154,11 +154,11 @@ def MPI_ISendOp : MPI_Op<"isend", []> {
//===----------------------------------------------------------------------===//
def MPI_RecvOp : MPI_Op<"recv", []> {
- let summary = "Equivalent to `MPI_Recv(ptr, size, dtype, dest, tag, "
+ let summary = "Equivalent to `MPI_Recv(ptr, size, dtype, source, tag, "
"MPI_COMM_WORLD, MPI_STATUS_IGNORE)`";
let description = [{
MPI_Recv performs a blocking receive of `size` elements of type `dtype`
- from rank `dest`. The `tag` value and communicator enables the library to
+ from rank `source`. The `tag` value and communicator enables the library to
determine the matching of multiple sends and receives between the same
ranks.
@@ -172,13 +172,13 @@ def MPI_RecvOp : MPI_Op<"recv", []> {
let arguments = (
ins AnyMemRef : $ref,
- I32 : $tag, I32 : $rank
+ I32 : $tag, I32 : $source
);
let results = (outs Optional<MPI_Retval>:$retval);
- let assemblyFormat = "`(` $ref `,` $tag `,` $rank `)` attr-dict `:`"
- "type($ref) `,` type($tag) `,` type($rank)"
+ let assemblyFormat = "`(` $ref `,` $tag `,` $source `)` attr-dict `:` "
+ "type($ref) `,` type($tag) `,` type($source)"
"(`->` type($retval)^)?";
let hasCanonicalizer = 1;
}
diff --git a/mlir/include/mlir/Dialect/MPI/IR/MPITypes.td b/mlir/include/mlir/Dialect/MPI/IR/MPITypes.td
index fafea0eac8bb7..a55d30e778e22 100644
--- a/mlir/include/mlir/Dialect/MPI/IR/MPITypes.td
+++ b/mlir/include/mlir/Dialect/MPI/IR/MPITypes.td
@@ -30,7 +30,7 @@ class MPI_Type<string name, string typeMnemonic, list<Trait> traits = []>
//===----------------------------------------------------------------------===//
def MPI_Retval : MPI_Type<"Retval", "retval"> {
- let summary = "MPI function call return value";
+ let summary = "MPI function call return value (!mpi.retval)";
let description = [{
This type represents a return value from an MPI function call.
This value can be MPI_SUCCESS, MPI_ERR_IN_STATUS, or any error code.
diff --git a/mlir/include/mlir/InitAllExtensions.h b/mlir/include/mlir/InitAllExtensions.h
index 887db344ed88b..6ab23ff86b3c6 100644
--- a/mlir/include/mlir/InitAllExtensions.h
+++ b/mlir/include/mlir/InitAllExtensions.h
@@ -14,6 +14,7 @@
#ifndef MLIR_INITALLEXTENSIONS_H_
#define MLIR_INITALLEXTENSIONS_H_
+#include "Conversion/MPIToLLVM/MPIToLLVM.h"
#include "mlir/Conversion/ArithToLLVM/ArithToLLVM.h"
#include "mlir/Conversion/ComplexToLLVM/ComplexToLLVM.h"
#include "mlir/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.h"
@@ -70,6 +71,7 @@ inline void registerAllExtensions(DialectRegistry &registry) {
registerConvertFuncToLLVMInterface(registry);
index::registerConvertIndexToLLVMInterface(registry);
registerConvertMathToLLVMInterface(registry);
+ mpi::registerConvertMPIToLLVMInterface(registry);
registerConvertMemRefToLLVMInterface(registry);
registerConvertNVVMToLLVMInterface(registry);
registerConvertOpenMPToLLVMInterface(registry);
diff --git a/mlir/lib/Conversion/CMakeLists.txt b/mlir/lib/Conversion/CMakeLists.txt
index 0bd08ec6333e6..3dc7472584cf9 100644
--- a/mlir/lib/Conversion/CMakeLists.txt
+++ b/mlir/lib/Conversion/CMakeLists.txt
@@ -41,6 +41,7 @@ add_subdirectory(MemRefToEmitC)
add_subdirectory(MemRefToLLVM)
add_subdirectory(MemRefToSPIRV)
add_subdirectory(MeshToMPI)
+add_subdirectory(MPIToLLVM)
add_subdirectory(NVGPUToNVVM)
add_subdirectory(NVVMToLLVM)
add_subdirectory(OpenACCToSCF)
diff --git a/mlir/lib/Conversion/MPIToLLVM/CMakeLists.txt b/mlir/lib/Conversion/MPIToLLVM/CMakeLists.txt
new file mode 100644
index 0000000000000..17df603ff5686
--- /dev/null
+++ b/mlir/lib/Conversion/MPIToLLVM/CMakeLists.txt
@@ -0,0 +1,34 @@
+find_path(MPI_C_HEADER_DIR mpi.h
+ PATHS $ENV{I_MPI_ROOT}/include
+ $ENV{MPI_HOME}/include
+ $ENV{MPI_ROOT}/include)
+if(MPI_C_HEADER_DIR)
+ # cmake_path(REMOVE_FILENAME MPI_C_HEADER_DIR)
+ message(STATUS "found MPI_C_HEADER_DIR: ${MPI_C_HEADER_DIR}")
+else()
+ message(WARNING "MPI not found, disabling MLIRMPIToLLVM conversion")
+ return()
+endif()
+
+add_mlir_conversion_library(MLIRMPIToLLVM
+ MPIToLLVM.cpp
+
+ ADDITIONAL_HEADER_DIRS
+ ${MLIR_MAIN_INCLUDE_DIR}/mlir/Conversion/MPIToLLVM
+
+ DEPENDS
+ MLIRConversionPassIncGen
+
+ LINK_COMPONENTS
+ Core
+
+ LINK_LIBS PUBLIC
+ MLIRLLVMCommonConversion
+ MLIRLLVMDialect
+ MLIRMPIDialect
+)
+target_include_directories(
+ MLIRMPIToLLVM
+ PRIVATE
+ ${MPI_C_HEADER_DIR}
+)
\ No newline at end of file
diff --git a/mlir/lib/Conversion/MPIToLLVM/MPIImplTraits.h b/mlir/lib/Conversion/MPIToLLVM/MPIImplTraits.h
new file mode 100644
index 0000000000000..09811e1cb7c61
--- /dev/null
+++ b/mlir/lib/Conversion/MPIToLLVM/MPIImplTraits.h
@@ -0,0 +1,151 @@
+#define MPICH_SKIP_MPICXX 1
+#define OMPI_SKIP_MPICXX 1
+#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
+#include "mlir/Dialect/MPI/IR/MPI.h"
+#include "mlir/Transforms/DialectConversion.h"
+#include <mpi.h>
+
+namespace {
+
+// when lowerring the mpi dialect to functions calls certain details
+// differ between various MPI implementations. This class will provide
+// these depending on the MPI implementation that got included.
+struct MPIImplTraits {
+ // get/create MPI_COMM_WORLD as a mlir::Value
+ static mlir::Value getCommWorld(mlir::ModuleOp &moduleOp,
+ const mlir::Location loc,
+ mlir::ConversionPatternRewriter &rewriter);
+ // get/create MPI datatype as a mlir::Value which corresponds to the given
+ // mlir::Type
+ static mlir::Value getDataType(mlir::ModuleOp &moduleOp,
+ cons...
[truncated]
``````````
</details>
https://github.com/llvm/llvm-project/pull/127053
More information about the Mlir-commits
mailing list