[Mlir-commits] [mlir] [mlir] Initial patch to add an MPI dialect (PR #68892)

Anton Lydike llvmlistbot at llvm.org
Sat Jan 6 05:54:08 PST 2024


https://github.com/AntonLydike updated https://github.com/llvm/llvm-project/pull/68892

>From 17f7fe017d458a814cadb333c28d5f9dda3af0d4 Mon Sep 17 00:00:00 2001
From: Anton Lydike <me at antonlydike.de>
Date: Thu, 14 Dec 2023 18:02:00 +0100
Subject: [PATCH] [mlir] Initial patch to add an MPI dialect

---
 mlir/include/mlir/Dialect/CMakeLists.txt      |   1 +
 mlir/include/mlir/Dialect/MPI/CMakeLists.txt  |   1 +
 .../mlir/Dialect/MPI/IR/CMakeLists.txt        |  22 ++
 mlir/include/mlir/Dialect/MPI/IR/MPI.h        |  33 +++
 mlir/include/mlir/Dialect/MPI/IR/MPI.td       | 219 ++++++++++++++++++
 mlir/include/mlir/Dialect/MPI/IR/MPIOps.td    | 189 +++++++++++++++
 mlir/include/mlir/Dialect/MPI/IR/MPITypes.td  |  43 ++++
 mlir/include/mlir/InitAllDialects.h           |   2 +
 mlir/lib/Dialect/CMakeLists.txt               |   1 +
 mlir/lib/Dialect/MPI/CMakeLists.txt           |   1 +
 mlir/lib/Dialect/MPI/IR/CMakeLists.txt        |  19 ++
 mlir/lib/Dialect/MPI/IR/MPI.cpp               |  56 +++++
 mlir/lib/Dialect/MPI/IR/MPIOps.cpp            |  21 ++
 mlir/test/Dialect/MPI/invalid.mlir            |  50 ++++
 mlir/test/Dialect/MPI/ops.mlir                |  35 +++
 15 files changed, 693 insertions(+)
 create mode 100644 mlir/include/mlir/Dialect/MPI/CMakeLists.txt
 create mode 100644 mlir/include/mlir/Dialect/MPI/IR/CMakeLists.txt
 create mode 100644 mlir/include/mlir/Dialect/MPI/IR/MPI.h
 create mode 100644 mlir/include/mlir/Dialect/MPI/IR/MPI.td
 create mode 100644 mlir/include/mlir/Dialect/MPI/IR/MPIOps.td
 create mode 100644 mlir/include/mlir/Dialect/MPI/IR/MPITypes.td
 create mode 100644 mlir/lib/Dialect/MPI/CMakeLists.txt
 create mode 100644 mlir/lib/Dialect/MPI/IR/CMakeLists.txt
 create mode 100644 mlir/lib/Dialect/MPI/IR/MPI.cpp
 create mode 100644 mlir/lib/Dialect/MPI/IR/MPIOps.cpp
 create mode 100644 mlir/test/Dialect/MPI/invalid.mlir
 create mode 100644 mlir/test/Dialect/MPI/ops.mlir

diff --git a/mlir/include/mlir/Dialect/CMakeLists.txt b/mlir/include/mlir/Dialect/CMakeLists.txt
index 1c4569ecfa5848..9788e24e4a1d91 100644
--- a/mlir/include/mlir/Dialect/CMakeLists.txt
+++ b/mlir/include/mlir/Dialect/CMakeLists.txt
@@ -21,6 +21,7 @@ add_subdirectory(Math)
 add_subdirectory(MemRef)
 add_subdirectory(Mesh)
 add_subdirectory(MLProgram)
+add_subdirectory(MPI)
 add_subdirectory(NVGPU)
 add_subdirectory(OpenACC)
 add_subdirectory(OpenACCMPCommon)
diff --git a/mlir/include/mlir/Dialect/MPI/CMakeLists.txt b/mlir/include/mlir/Dialect/MPI/CMakeLists.txt
new file mode 100644
index 00000000000000..f33061b2d87cff
--- /dev/null
+++ b/mlir/include/mlir/Dialect/MPI/CMakeLists.txt
@@ -0,0 +1 @@
+add_subdirectory(IR)
diff --git a/mlir/include/mlir/Dialect/MPI/IR/CMakeLists.txt b/mlir/include/mlir/Dialect/MPI/IR/CMakeLists.txt
new file mode 100644
index 00000000000000..dfec2ea486cb29
--- /dev/null
+++ b/mlir/include/mlir/Dialect/MPI/IR/CMakeLists.txt
@@ -0,0 +1,22 @@
+add_mlir_dialect(MPI mpi)
+add_mlir_doc(MPIOps MPI Dialects/ -gen-dialect-doc)
+
+# Add MPI operations
+set(LLVM_TARGET_DEFINITIONS MPIOps.td)
+mlir_tablegen(MPIOps.h.inc -gen-op-decls)
+mlir_tablegen(MPIOps.cpp.inc -gen-op-defs)
+add_public_tablegen_target(MLIRMPIOpsIncGen)
+
+# Add MPI types
+set(LLVM_TARGET_DEFINITIONS MPITypes.td)
+mlir_tablegen(MPITypesGen.h.inc -gen-typedef-decls)
+mlir_tablegen(MPITypesGen.cpp.inc -gen-typedef-defs)
+add_public_tablegen_target(MLIRMPITypesIncGen)
+
+# Add MPI enums and attributes (both are defined in MPI.td)
+set(LLVM_TARGET_DEFINITIONS MPI.td)
+mlir_tablegen(MPIEnums.h.inc -gen-enum-decls)
+mlir_tablegen(MPIEnums.cpp.inc -gen-enum-defs)
+mlir_tablegen(MPIAttrDefs.h.inc -gen-attrdef-decls)
+mlir_tablegen(MPIAttrDefs.cpp.inc -gen-attrdef-defs)
+add_public_tablegen_target(MLIRMPIAttrsIncGen)
diff --git a/mlir/include/mlir/Dialect/MPI/IR/MPI.h b/mlir/include/mlir/Dialect/MPI/IR/MPI.h
new file mode 100644
index 00000000000000..f06b911ce3fe31
--- /dev/null
+++ b/mlir/include/mlir/Dialect/MPI/IR/MPI.h
@@ -0,0 +1,33 @@
+//===- MPI.h - MPI dialect ----------------------------------------*- C++-*-==//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+#ifndef MLIR_DIALECT_MPI_IR_MPI_H_
+#define MLIR_DIALECT_MPI_IR_MPI_H_
+
+#include "mlir/Bytecode/BytecodeOpInterface.h"
+#include "mlir/IR/Dialect.h"
+#include "mlir/IR/OpDefinition.h"
+#include "mlir/IR/OpImplementation.h"
+
+//===----------------------------------------------------------------------===//
+// MPIDialect, types, enums, attributes and operations
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/MPI/IR/MPIDialect.h.inc"
+
+#define GET_TYPEDEF_CLASSES
+#include "mlir/Dialect/MPI/IR/MPITypesGen.h.inc"
+
+#include "mlir/Dialect/MPI/IR/MPIEnums.h.inc"
+
+#define GET_ATTRDEF_CLASSES
+#include "mlir/Dialect/MPI/IR/MPIAttrDefs.h.inc"
+
+#define GET_OP_CLASSES
+#include "mlir/Dialect/MPI/IR/MPIOps.h.inc"
+
+#endif // MLIR_DIALECT_MPI_IR_MPI_H_
diff --git a/mlir/include/mlir/Dialect/MPI/IR/MPI.td b/mlir/include/mlir/Dialect/MPI/IR/MPI.td
new file mode 100644
index 00000000000000..04e84020681294
--- /dev/null
+++ b/mlir/include/mlir/Dialect/MPI/IR/MPI.td
@@ -0,0 +1,219 @@
+//===- MPI.td - Base defs for mpi dialect ------------------*- tablegen -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_MPI_IR_MPI_TD
+#define MLIR_DIALECT_MPI_IR_MPI_TD
+
+include "mlir/IR/AttrTypeBase.td"
+include "mlir/IR/OpBase.td"
+include "mlir/IR/EnumAttr.td"
+
+def MPI_Dialect : Dialect {
+  let name = "mpi";
+  let cppNamespace = "::mlir::mpi";
+  let description = [{
+    This dialect models the Message Passing Interface (MPI), version
+    4.0. It is meant to serve as an interfacing dialect that is targeted
+    by higher-level dialects. The MPI dialect itself can be lowered to
+    multiple MPI implementations, hiding their differences in ABI. The dialect
+    models the functions of the MPI specification as close to 1:1 as possible
+    while preserving SSA value semantics where it makes sense, and uses
+    `memref` types instead of bare pointers.
+
+    This dialect is under active development, and while stability is an
+    eventual goal, it is not guaranteed at this juncture. Given the early
+    state, it is recommended to inquire further prior to using this dialect.
+
+    For in-depth documentation of the MPI library interface, please refer
+    to official documentation such as the
+    [OpenMPI online documentation](https://www.open-mpi.org/doc/current/).
+  }];
+
+  let usePropertiesForAttributes = 1;
+  let useDefaultAttributePrinterParser = 1;
+  let useDefaultTypePrinterParser = 1;
+}
+
+//===----------------------------------------------------------------------===//
+// Error classes enum and the `#mpi.errclass` attribute wrapping it:
+//===----------------------------------------------------------------------===//
+
+def MPI_CodeSuccess : I32EnumAttrCase<"MPI_SUCCESS", 0, "MPI_SUCCESS">;
+def MPI_CodeErrAccess : I32EnumAttrCase<"MPI_ERR_ACCESS", 1, "MPI_ERR_ACCESS">;
+def MPI_CodeErrAmode : I32EnumAttrCase<"MPI_ERR_AMODE", 2, "MPI_ERR_AMODE">;
+def MPI_CodeErrArg : I32EnumAttrCase<"MPI_ERR_ARG", 3, "MPI_ERR_ARG">;
+def MPI_CodeErrAssert : I32EnumAttrCase<"MPI_ERR_ASSERT", 4, "MPI_ERR_ASSERT">;
+def MPI_CodeErrBadFile
+    : I32EnumAttrCase<"MPI_ERR_BAD_FILE", 5, "MPI_ERR_BAD_FILE">;
+def MPI_CodeErrBase : I32EnumAttrCase<"MPI_ERR_BASE", 6, "MPI_ERR_BASE">;
+def MPI_CodeErrBuffer : I32EnumAttrCase<"MPI_ERR_BUFFER", 7, "MPI_ERR_BUFFER">;
+def MPI_CodeErrComm : I32EnumAttrCase<"MPI_ERR_COMM", 8, "MPI_ERR_COMM">;
+def MPI_CodeErrConversion
+    : I32EnumAttrCase<"MPI_ERR_CONVERSION", 9, "MPI_ERR_CONVERSION">;
+def MPI_CodeErrCount : I32EnumAttrCase<"MPI_ERR_COUNT", 10, "MPI_ERR_COUNT">;
+def MPI_CodeErrDims : I32EnumAttrCase<"MPI_ERR_DIMS", 11, "MPI_ERR_DIMS">;
+def MPI_CodeErrDisp : I32EnumAttrCase<"MPI_ERR_DISP", 12, "MPI_ERR_DISP">;
+def MPI_CodeErrDupDatarep
+    : I32EnumAttrCase<"MPI_ERR_DUP_DATAREP", 13, "MPI_ERR_DUP_DATAREP">;
+def MPI_CodeErrErrhandler
+    : I32EnumAttrCase<"MPI_ERR_ERRHANDLER", 14, "MPI_ERR_ERRHANDLER">;
+def MPI_CodeErrFile : I32EnumAttrCase<"MPI_ERR_FILE", 15, "MPI_ERR_FILE">;
+def MPI_CodeErrFileExists
+    : I32EnumAttrCase<"MPI_ERR_FILE_EXISTS", 16, "MPI_ERR_FILE_EXISTS">;
+def MPI_CodeErrFileInUse
+    : I32EnumAttrCase<"MPI_ERR_FILE_IN_USE", 17, "MPI_ERR_FILE_IN_USE">;
+def MPI_CodeErrGroup : I32EnumAttrCase<"MPI_ERR_GROUP", 18, "MPI_ERR_GROUP">;
+def MPI_CodeErrInfo : I32EnumAttrCase<"MPI_ERR_INFO", 19, "MPI_ERR_INFO">;
+def MPI_CodeErrInfoKey
+    : I32EnumAttrCase<"MPI_ERR_INFO_KEY", 20, "MPI_ERR_INFO_KEY">;
+def MPI_CodeErrInfoNokey
+    : I32EnumAttrCase<"MPI_ERR_INFO_NOKEY", 21, "MPI_ERR_INFO_NOKEY">;
+def MPI_CodeErrInfoValue
+    : I32EnumAttrCase<"MPI_ERR_INFO_VALUE", 22, "MPI_ERR_INFO_VALUE">;
+def MPI_CodeErrInStatus
+    : I32EnumAttrCase<"MPI_ERR_IN_STATUS", 23, "MPI_ERR_IN_STATUS">;
+def MPI_CodeErrIntern : I32EnumAttrCase<"MPI_ERR_INTERN", 24, "MPI_ERR_INTERN">;
+def MPI_CodeErrIo : I32EnumAttrCase<"MPI_ERR_IO", 25, "MPI_ERR_IO">;
+def MPI_CodeErrKeyval : I32EnumAttrCase<"MPI_ERR_KEYVAL", 26, "MPI_ERR_KEYVAL">;
+def MPI_CodeErrLocktype
+    : I32EnumAttrCase<"MPI_ERR_LOCKTYPE", 27, "MPI_ERR_LOCKTYPE">;
+def MPI_CodeErrName : I32EnumAttrCase<"MPI_ERR_NAME", 28, "MPI_ERR_NAME">;
+def MPI_CodeErrNoMem : I32EnumAttrCase<"MPI_ERR_NO_MEM", 29, "MPI_ERR_NO_MEM">;
+def MPI_CodeErrNoSpace
+    : I32EnumAttrCase<"MPI_ERR_NO_SPACE", 30, "MPI_ERR_NO_SPACE">;
+def MPI_CodeErrNoSuchFile
+    : I32EnumAttrCase<"MPI_ERR_NO_SUCH_FILE", 31, "MPI_ERR_NO_SUCH_FILE">;
+def MPI_CodeErrNotSame
+    : I32EnumAttrCase<"MPI_ERR_NOT_SAME", 32, "MPI_ERR_NOT_SAME">;
+def MPI_CodeErrOp : I32EnumAttrCase<"MPI_ERR_OP", 33, "MPI_ERR_OP">;
+def MPI_CodeErrOther : I32EnumAttrCase<"MPI_ERR_OTHER", 34, "MPI_ERR_OTHER">;
+def MPI_CodeErrPending
+    : I32EnumAttrCase<"MPI_ERR_PENDING", 35, "MPI_ERR_PENDING">;
+def MPI_CodeErrPort : I32EnumAttrCase<"MPI_ERR_PORT", 36, "MPI_ERR_PORT">;
+def MPI_CodeErrProcAborted
+    : I32EnumAttrCase<"MPI_ERR_PROC_ABORTED", 37, "MPI_ERR_PROC_ABORTED">;
+def MPI_CodeErrQuota : I32EnumAttrCase<"MPI_ERR_QUOTA", 38, "MPI_ERR_QUOTA">;
+def MPI_CodeErrRank : I32EnumAttrCase<"MPI_ERR_RANK", 39, "MPI_ERR_RANK">;
+def MPI_CodeErrReadOnly
+    : I32EnumAttrCase<"MPI_ERR_READ_ONLY", 40, "MPI_ERR_READ_ONLY">;
+def MPI_CodeErrRequest
+    : I32EnumAttrCase<"MPI_ERR_REQUEST", 41, "MPI_ERR_REQUEST">;
+def MPI_CodeErrRmaAttach
+    : I32EnumAttrCase<"MPI_ERR_RMA_ATTACH", 42, "MPI_ERR_RMA_ATTACH">;
+def MPI_CodeErrRmaConflict
+    : I32EnumAttrCase<"MPI_ERR_RMA_CONFLICT", 43, "MPI_ERR_RMA_CONFLICT">;
+def MPI_CodeErrRmaFlavor
+    : I32EnumAttrCase<"MPI_ERR_RMA_FLAVOR", 44, "MPI_ERR_RMA_FLAVOR">;
+def MPI_CodeErrRmaRange
+    : I32EnumAttrCase<"MPI_ERR_RMA_RANGE", 45, "MPI_ERR_RMA_RANGE">;
+def MPI_CodeErrRmaShared
+    : I32EnumAttrCase<"MPI_ERR_RMA_SHARED", 46, "MPI_ERR_RMA_SHARED">;
+def MPI_CodeErrRmaSync
+    : I32EnumAttrCase<"MPI_ERR_RMA_SYNC", 47, "MPI_ERR_RMA_SYNC">;
+def MPI_CodeErrRoot : I32EnumAttrCase<"MPI_ERR_ROOT", 48, "MPI_ERR_ROOT">;
+def MPI_CodeErrService
+    : I32EnumAttrCase<"MPI_ERR_SERVICE", 49, "MPI_ERR_SERVICE">;
+def MPI_CodeErrSession
+    : I32EnumAttrCase<"MPI_ERR_SESSION", 50, "MPI_ERR_SESSION">;
+def MPI_CodeErrSize : I32EnumAttrCase<"MPI_ERR_SIZE", 51, "MPI_ERR_SIZE">;
+def MPI_CodeErrSpawn : I32EnumAttrCase<"MPI_ERR_SPAWN", 52, "MPI_ERR_SPAWN">;
+def MPI_CodeErrTag : I32EnumAttrCase<"MPI_ERR_TAG", 53, "MPI_ERR_TAG">;
+def MPI_CodeErrTopology
+    : I32EnumAttrCase<"MPI_ERR_TOPOLOGY", 54, "MPI_ERR_TOPOLOGY">;
+def MPI_CodeErrTruncate
+    : I32EnumAttrCase<"MPI_ERR_TRUNCATE", 55, "MPI_ERR_TRUNCATE">;
+def MPI_CodeErrType : I32EnumAttrCase<"MPI_ERR_TYPE", 56, "MPI_ERR_TYPE">;
+def MPI_CodeErrUnknown
+    : I32EnumAttrCase<"MPI_ERR_UNKNOWN", 57, "MPI_ERR_UNKNOWN">;
+def MPI_CodeErrUnsupportedDatarep
+    : I32EnumAttrCase<"MPI_ERR_UNSUPPORTED_DATAREP", 58,
+                      "MPI_ERR_UNSUPPORTED_DATAREP">;
+def MPI_CodeErrUnsupportedOperation
+    : I32EnumAttrCase<"MPI_ERR_UNSUPPORTED_OPERATION", 59,
+                      "MPI_ERR_UNSUPPORTED_OPERATION">;
+def MPI_CodeErrValueTooLarge
+    : I32EnumAttrCase<"MPI_ERR_VALUE_TOO_LARGE", 60, "MPI_ERR_VALUE_TOO_LARGE">;
+def MPI_CodeErrWin : I32EnumAttrCase<"MPI_ERR_WIN", 61, "MPI_ERR_WIN">;
+def MPI_CodeErrLastcode
+    : I32EnumAttrCase<"MPI_ERR_LASTCODE", 62, "MPI_ERR_LASTCODE">;
+
+def MPI_ErrorClassEnum
+    : I32EnumAttr<"MPI_ErrorClassEnum", "MPI error class name", [
+      MPI_CodeSuccess,
+      MPI_CodeErrAccess,
+      MPI_CodeErrAmode,
+      MPI_CodeErrArg,
+      MPI_CodeErrAssert,
+      MPI_CodeErrBadFile,
+      MPI_CodeErrBase,
+      MPI_CodeErrBuffer,
+      MPI_CodeErrComm,
+      MPI_CodeErrConversion,
+      MPI_CodeErrCount,
+      MPI_CodeErrDims,
+      MPI_CodeErrDisp,
+      MPI_CodeErrDupDatarep,
+      MPI_CodeErrErrhandler,
+      MPI_CodeErrFile,
+      MPI_CodeErrFileExists,
+      MPI_CodeErrFileInUse,
+      MPI_CodeErrGroup,
+      MPI_CodeErrInfo,
+      MPI_CodeErrInfoKey,
+      MPI_CodeErrInfoNokey,
+      MPI_CodeErrInfoValue,
+      MPI_CodeErrInStatus,
+      MPI_CodeErrIntern,
+      MPI_CodeErrIo,
+      MPI_CodeErrKeyval,
+      MPI_CodeErrLocktype,
+      MPI_CodeErrName,
+      MPI_CodeErrNoMem,
+      MPI_CodeErrNoSpace,
+      MPI_CodeErrNoSuchFile,
+      MPI_CodeErrNotSame,
+      MPI_CodeErrOp,
+      MPI_CodeErrOther,
+      MPI_CodeErrPending,
+      MPI_CodeErrPort,
+      MPI_CodeErrProcAborted,
+      MPI_CodeErrQuota,
+      MPI_CodeErrRank,
+      MPI_CodeErrReadOnly,
+      MPI_CodeErrRequest,
+      MPI_CodeErrRmaAttach,
+      MPI_CodeErrRmaConflict,
+      MPI_CodeErrRmaFlavor,
+      MPI_CodeErrRmaRange,
+      MPI_CodeErrRmaShared,
+      MPI_CodeErrRmaSync,
+      MPI_CodeErrRoot,
+      MPI_CodeErrService,
+      MPI_CodeErrSession,
+      MPI_CodeErrSize,
+      MPI_CodeErrSpawn,
+      MPI_CodeErrTag,
+      MPI_CodeErrTopology,
+      MPI_CodeErrTruncate,
+      MPI_CodeErrType,
+      MPI_CodeErrUnknown,
+      MPI_CodeErrUnsupportedDatarep,
+      MPI_CodeErrUnsupportedOperation,
+      MPI_CodeErrValueTooLarge,
+      MPI_CodeErrWin,
+      MPI_CodeErrLastcode
+    ]> {
+  let genSpecializedAttr = 0;
+  let cppNamespace = "::mlir::mpi";
+}
+
+def MPI_ErrorClassAttr : EnumAttr<MPI_Dialect, MPI_ErrorClassEnum, "errclass"> {
+  let assemblyFormat = "`<` $value `>`";
+}
+
+#endif // MLIR_DIALECT_MPI_IR_MPI_TD
diff --git a/mlir/include/mlir/Dialect/MPI/IR/MPIOps.td b/mlir/include/mlir/Dialect/MPI/IR/MPIOps.td
new file mode 100644
index 00000000000000..768f376e24da4c
--- /dev/null
+++ b/mlir/include/mlir/Dialect/MPI/IR/MPIOps.td
@@ -0,0 +1,189 @@
+//===- MPIOps.td - Message Passing Interface Ops -----------*- tablegen -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_MPI_IR_MPIOPS_TD
+#define MLIR_DIALECT_MPI_IR_MPIOPS_TD
+
+include "mlir/Dialect/MPI/IR/MPI.td"
+include "mlir/Dialect/MPI/IR/MPITypes.td"
+
+class MPI_Op<string mnemonic, list<Trait> traits = []>
+    : Op<MPI_Dialect, mnemonic, traits>;
+
+//===----------------------------------------------------------------------===//
+// InitOp
+//===----------------------------------------------------------------------===//
+
+def MPI_InitOp : MPI_Op<"init", []> {
+  let summary =
+      "Initialize the MPI library, equivalent to `MPI_Init(NULL, NULL)`";
+  let description = [{
+    This operation must precede most MPI calls (with very few exceptions;
+    please consult the MPI specification on these).
+
+    Passing &argc, &argv is not supported currently.
+
+    This operation can optionally return an `!mpi.retval` value that can be used
+    to check for errors.
+  }];
+
+  let results = (outs Optional<MPI_Retval>:$retval);
+
+  let assemblyFormat = "attr-dict (`:` type($retval)^)?";
+}
+
+//===----------------------------------------------------------------------===//
+// CommRankOp
+//===----------------------------------------------------------------------===//
+
+def MPI_CommRankOp : MPI_Op<"comm_rank", []> {
+  let summary = "Get the current rank, equivalent to "
+                "`MPI_Comm_rank(MPI_COMM_WORLD, &rank)`";
+  let description = [{
+    Communicators other than `MPI_COMM_WORLD` are not supported for now.
+
+    This operation can optionally return an `!mpi.retval` value that can be used
+    to check for errors.
+  }];
+
+  let results = (
+    outs Optional<MPI_Retval> : $retval,
+    I32 : $rank
+  );
+
+  let assemblyFormat = "attr-dict `:` type(results)";
+}
+
+//===----------------------------------------------------------------------===//
+// SendOp
+//===----------------------------------------------------------------------===//
+
+def MPI_SendOp : MPI_Op<"send", []> {
+  let summary =
+      "Equivalent to `MPI_Send(ptr, size, dtype, dest, tag, MPI_COMM_WORLD)`";
+  let description = [{
+    MPI_Send performs a blocking send of `size` elements of type `dtype` to rank
+    `dest`. The `tag` value and communicator enable the library to determine
+    the matching of multiple sends and receives between the same ranks.
+
+    Communicators other than `MPI_COMM_WORLD` are not supported for now.
+
+    This operation can optionally return an `!mpi.retval` value that can be used
+    to check for errors.
+  }];
+
+  let arguments = (ins AnyMemRef : $ref, I32 : $tag, I32 : $rank);
+
+  let results = (outs Optional<MPI_Retval>:$retval);
+
+  let assemblyFormat = "`(` $ref `,` $tag `,` $rank `)` attr-dict `:` "
+                       "type($ref) `,` type($tag) `,` type($rank)"
+                       "(`->` type($retval)^)?";
+}
+
+//===----------------------------------------------------------------------===//
+// RecvOp
+//===----------------------------------------------------------------------===//
+
+def MPI_RecvOp : MPI_Op<"recv", []> {
+  let summary = "Equivalent to `MPI_Recv(ptr, size, dtype, source, tag, "
+                "MPI_COMM_WORLD, MPI_STATUS_IGNORE)`";
+  let description = [{
+    MPI_Recv performs a blocking receive of `size` elements of type `dtype`
+    from rank `source`. The `tag` value and communicator enable the library
+    to determine the matching of multiple sends and receives between the same
+    ranks.
+
+    Communicators other than `MPI_COMM_WORLD` are not supported for now.
+    The MPI_Status is set to `MPI_STATUS_IGNORE`, as the status object
+    is not yet ported to MLIR.
+
+    This operation can optionally return an `!mpi.retval` value that can be used
+    to check for errors.
+  }];
+
+  let arguments = (ins AnyMemRef : $ref, I32 : $tag, I32 : $rank);
+
+  let results = (outs Optional<MPI_Retval>:$retval);
+
+  let assemblyFormat = "`(` $ref `,` $tag `,` $rank `)` attr-dict `:` "
+                       "type($ref) `,` type($tag) `,` type($rank)"
+                       "(`->` type($retval)^)?";
+}
+
+
+//===----------------------------------------------------------------------===//
+// FinalizeOp
+//===----------------------------------------------------------------------===//
+
+def MPI_FinalizeOp : MPI_Op<"finalize", []> {
+  let summary = "Finalize the MPI library, equivalent to `MPI_Finalize()`";
+  let description = [{
+    This function cleans up the MPI state. Afterwards, no MPI methods may
+    be invoked (except for MPI_Get_version, MPI_Initialized, and MPI_Finalized).
+    Notably, MPI_Init cannot be called again in the same program.
+
+    This operation can optionally return an `!mpi.retval` value that can be used
+    to check for errors.
+  }];
+
+  let results = (outs Optional<MPI_Retval>:$retval);
+
+  let assemblyFormat = "attr-dict (`:` type($retval)^)?";
+}
+
+
+//===----------------------------------------------------------------------===//
+// RetvalCheckOp
+//===----------------------------------------------------------------------===//
+
+def MPI_RetvalCheckOp : MPI_Op<"retval_check", []> {
+  let summary = "Check an MPI return value against an error class";
+  let description = [{
+    This operation compares MPI status codes to known error class
+    constants such as `MPI_SUCCESS`, or `MPI_ERR_COMM`.
+  }];
+
+  let arguments = (
+    ins MPI_Retval:$val,
+    MPI_ErrorClassAttr:$errclass
+  );
+
+  let results = (
+    outs I1:$res
+  );
+
+  let assemblyFormat = "$val `=` $errclass attr-dict `:` type($res)";
+}
+
+
+
+//===----------------------------------------------------------------------===//
+// ErrorClassOp
+//===----------------------------------------------------------------------===//
+
+def MPI_ErrorClassOp : MPI_Op<"error_class", []> {
+  let summary = "Get the error class from an error code, equivalent to "
+                "the `MPI_Error_class` function";
+  let description = [{
+    `MPI_Error_class` maps return values from MPI calls to a set of well-known
+    MPI error classes.
+  }];
+
+  let arguments = (
+    ins MPI_Retval:$val
+  );
+
+  let results = (
+    outs MPI_Retval:$errclass
+  );
+
+  let assemblyFormat = "$val attr-dict `:` type($val)";
+}
+
+#endif // MLIR_DIALECT_MPI_IR_MPIOPS_TD
diff --git a/mlir/include/mlir/Dialect/MPI/IR/MPITypes.td b/mlir/include/mlir/Dialect/MPI/IR/MPITypes.td
new file mode 100644
index 00000000000000..109f3ca61e26d5
--- /dev/null
+++ b/mlir/include/mlir/Dialect/MPI/IR/MPITypes.td
@@ -0,0 +1,43 @@
+//===- MPITypes.td - Message Passing Interface types -------*- tablegen -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file declares the Message Passing Interface dialect types.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_MPI_IR_MPITYPES_TD
+#define MLIR_DIALECT_MPI_IR_MPITYPES_TD
+
+include "mlir/IR/AttrTypeBase.td"
+include "mlir/Dialect/MPI/IR/MPI.td"
+
+//===----------------------------------------------------------------------===//
+// MPI Types
+//===----------------------------------------------------------------------===//
+
+class MPI_Type<string name, string typeMnemonic, list<Trait> traits = []>
+    : TypeDef<MPI_Dialect, name, traits> {
+  let mnemonic = typeMnemonic;
+}
+
+//===----------------------------------------------------------------------===//
+// mpi::RetvalType
+//===----------------------------------------------------------------------===//
+
+def MPI_Retval : MPI_Type<"Retval", "retval"> {
+  let summary = "MPI function call return value";
+  let description = [{
+    This type represents a return value from an MPI function call.
+    This value can be MPI_SUCCESS, MPI_ERR_IN_STATUS, or any error code.
+
+    This return value can be compared against the known MPI error classes
+    represented by `#mpi.errclass` using the `mpi.retval_check` operation.
+  }];
+}
+
+#endif // MLIR_DIALECT_MPI_IR_MPITYPES_TD
diff --git a/mlir/include/mlir/InitAllDialects.h b/mlir/include/mlir/InitAllDialects.h
index 00f400aab5d50a..42087994d0f0c8 100644
--- a/mlir/include/mlir/InitAllDialects.h
+++ b/mlir/include/mlir/InitAllDialects.h
@@ -48,6 +48,7 @@
 #include "mlir/Dialect/Linalg/Transforms/SubsetInsertionOpInterfaceImpl.h"
 #include "mlir/Dialect/Linalg/Transforms/TilingInterfaceImpl.h"
 #include "mlir/Dialect/MLProgram/IR/MLProgram.h"
+#include "mlir/Dialect/MPI/IR/MPI.h"
 #include "mlir/Dialect/Math/IR/Math.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/Dialect/MemRef/IR/MemRefMemorySlot.h"
@@ -120,6 +121,7 @@ inline void registerAllDialects(DialectRegistry &registry) {
                   memref::MemRefDialect,
                   mesh::MeshDialect,
                   ml_program::MLProgramDialect,
+                  mpi::MPIDialect,
                   nvgpu::NVGPUDialect,
                   NVVM::NVVMDialect,
                   omp::OpenMPDialect,
diff --git a/mlir/lib/Dialect/CMakeLists.txt b/mlir/lib/Dialect/CMakeLists.txt
index 68776a695cac4d..c72107939cf42b 100644
--- a/mlir/lib/Dialect/CMakeLists.txt
+++ b/mlir/lib/Dialect/CMakeLists.txt
@@ -21,6 +21,7 @@ add_subdirectory(Math)
 add_subdirectory(MemRef)
 add_subdirectory(Mesh)
 add_subdirectory(MLProgram)
+add_subdirectory(MPI)
 add_subdirectory(NVGPU)
 add_subdirectory(OpenACC)
 add_subdirectory(OpenACCMPCommon)
diff --git a/mlir/lib/Dialect/MPI/CMakeLists.txt b/mlir/lib/Dialect/MPI/CMakeLists.txt
new file mode 100644
index 00000000000000..f33061b2d87cff
--- /dev/null
+++ b/mlir/lib/Dialect/MPI/CMakeLists.txt
@@ -0,0 +1 @@
+add_subdirectory(IR)
diff --git a/mlir/lib/Dialect/MPI/IR/CMakeLists.txt b/mlir/lib/Dialect/MPI/IR/CMakeLists.txt
new file mode 100644
index 00000000000000..29d0b2379da747
--- /dev/null
+++ b/mlir/lib/Dialect/MPI/IR/CMakeLists.txt
@@ -0,0 +1,19 @@
+add_mlir_dialect_library(MLIRMPIDialect
+  MPIOps.cpp
+  MPI.cpp
+
+  ADDITIONAL_HEADER_DIRS
+  ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/MPI
+
+  DEPENDS
+  MLIRMPIIncGen
+  MLIRMPIOpsIncGen
+  MLIRMPITypesIncGen
+  MLIRMPIAttrsIncGen
+
+  LINK_LIBS PUBLIC
+  MLIRDialect
+  MLIRIR
+  MLIRInferTypeOpInterface
+  MLIRSideEffectInterfaces
+  )
diff --git a/mlir/lib/Dialect/MPI/IR/MPI.cpp b/mlir/lib/Dialect/MPI/IR/MPI.cpp
new file mode 100644
index 00000000000000..6c5f69febcd63d
--- /dev/null
+++ b/mlir/lib/Dialect/MPI/IR/MPI.cpp
@@ -0,0 +1,56 @@
+//===- MPI.cpp - MPI dialect implementation -------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/MPI/IR/MPI.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/BuiltinAttributes.h"
+#include "mlir/IR/DialectImplementation.h"
+#include "llvm/ADT/TypeSwitch.h"
+
+using namespace mlir;
+using namespace mlir::mpi;
+
+//===----------------------------------------------------------------------===//
+// Tablegen Definitions
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/MPI/IR/MPI.cpp.inc"
+
+#include "mlir/Dialect/MPI/IR/MPIDialect.cpp.inc"
+
+void MPIDialect::initialize() {
+  addOperations<
+#define GET_OP_LIST
+#include "mlir/Dialect/MPI/IR/MPIOps.cpp.inc"
+      >();
+
+  addTypes<
+#define GET_TYPEDEF_LIST
+#include "mlir/Dialect/MPI/IR/MPITypesGen.cpp.inc"
+      >();
+
+  addAttributes<
+#define GET_ATTRDEF_LIST
+#include "mlir/Dialect/MPI/IR/MPIAttrDefs.cpp.inc"
+      >();
+}
+
+//===----------------------------------------------------------------------===//
+// TableGen'd type, enum, and attribute definitions
+//===----------------------------------------------------------------------===//
+
+#define GET_TYPEDEF_CLASSES
+#include "mlir/Dialect/MPI/IR/MPITypesGen.cpp.inc"
+
+#include "mlir/Dialect/MPI/IR/MPIEnums.cpp.inc"
+
+#define GET_ATTRDEF_CLASSES
+#include "mlir/Dialect/MPI/IR/MPIAttrDefs.cpp.inc"
+
+// Op classes (GET_OP_CLASSES) are defined in MPIOps.cpp; emitting them here as
+// well would define every MPI op twice within the MLIRMPIDialect library.
diff --git a/mlir/lib/Dialect/MPI/IR/MPIOps.cpp b/mlir/lib/Dialect/MPI/IR/MPIOps.cpp
new file mode 100644
index 00000000000000..ddd77b8f586ee0
--- /dev/null
+++ b/mlir/lib/Dialect/MPI/IR/MPIOps.cpp
@@ -0,0 +1,21 @@
+//===- MPIOps.cpp - MPI dialect ops implementation ------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/MPI/IR/MPI.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/BuiltinAttributes.h"
+
+using namespace mlir;
+using namespace mlir::mpi;
+
+//===----------------------------------------------------------------------===//
+// TableGen'd op method definitions
+//===----------------------------------------------------------------------===//
+
+#define GET_OP_CLASSES
+#include "mlir/Dialect/MPI/IR/MPIOps.cpp.inc"
diff --git a/mlir/test/Dialect/MPI/invalid.mlir b/mlir/test/Dialect/MPI/invalid.mlir
new file mode 100644
index 00000000000000..1da154c7a58126
--- /dev/null
+++ b/mlir/test/Dialect/MPI/invalid.mlir
@@ -0,0 +1,50 @@
+// RUN: mlir-opt -split-input-file -verify-diagnostics %s
+
+// expected-error @+1 {{op result #0 must be 32-bit signless integer, but got 'i64'}}
+%rank = mpi.comm_rank : i64
+
+// -----
+
+func.func @mpi_test(%ref : !llvm.ptr, %rank: i32) -> () {
+    // expected-error @+1 {{invalid kind of type specified}}
+    mpi.send(%ref, %rank, %rank) : !llvm.ptr, i32, i32
+
+    return
+}
+
+// -----
+
+func.func @mpi_test(%ref : !llvm.ptr, %rank: i32) -> () {
+    // expected-error @+1 {{invalid kind of type specified}}
+    mpi.recv(%ref, %rank, %rank) : !llvm.ptr, i32, i32
+
+    return
+}
+
+// -----
+
+func.func @mpi_test(%ref : memref<100xf32>, %rank: i32) -> () {
+    // expected-error @+1 {{'mpi.recv' op result #0 must be MPI function call return value, but got 'i32'}}
+    %res = mpi.recv(%ref, %rank, %rank) : memref<100xf32>, i32, i32 -> i32
+
+    return
+}
+
+// -----
+
+func.func @mpi_test(%ref : memref<100xf32>, %rank: i32) -> () {
+    // expected-error @+1 {{'mpi.send' op result #0 must be MPI function call return value, but got 'i32'}}
+    %res = mpi.send(%ref, %rank, %rank) : memref<100xf32>, i32, i32 -> i32
+
+    return
+}
+
+// -----
+
+func.func @mpi_test(%retval: !mpi.retval) -> () {
+    // expected-error @+2 {{custom op 'mpi.retval_check' expected ::mlir::mpi::MPI_ErrorClassEnum}}
+    // expected-error @+1 {{custom op 'mpi.retval_check' failed to parse MPI_ErrorClassAttr parameter 'value'}}
+    %res = mpi.retval_check %retval = <MPI_ERR_DOES_NOT_EXIST> : i1
+
+    return
+}
diff --git a/mlir/test/Dialect/MPI/ops.mlir b/mlir/test/Dialect/MPI/ops.mlir
new file mode 100644
index 00000000000000..8f2421a73396c2
--- /dev/null
+++ b/mlir/test/Dialect/MPI/ops.mlir
@@ -0,0 +1,35 @@
+// RUN: mlir-opt %s | mlir-opt | FileCheck %s
+
+func.func @mpi_test(%ref : memref<100xf32>) -> () {
+    // Note: the !mpi.retval result is optional on all operations except mpi.error_class
+
+    // CHECK: %0 = mpi.init : !mpi.retval
+    %err = mpi.init : !mpi.retval
+
+    // CHECK-NEXT: %retval, %rank = mpi.comm_rank : !mpi.retval, i32
+    %retval, %rank = mpi.comm_rank : !mpi.retval, i32
+
+    // CHECK-NEXT: mpi.send(%arg0, %rank, %rank) : memref<100xf32>, i32, i32
+    mpi.send(%ref, %rank, %rank) : memref<100xf32>, i32, i32
+
+    // CHECK-NEXT: %1 = mpi.send(%arg0, %rank, %rank) : memref<100xf32>, i32, i32 -> !mpi.retval
+    %err2 = mpi.send(%ref, %rank, %rank) : memref<100xf32>, i32, i32 -> !mpi.retval
+
+    // CHECK-NEXT: mpi.recv(%arg0, %rank, %rank) : memref<100xf32>, i32, i32
+    mpi.recv(%ref, %rank, %rank) : memref<100xf32>, i32, i32
+
+    // CHECK-NEXT: %2 = mpi.recv(%arg0, %rank, %rank) : memref<100xf32>, i32, i32 -> !mpi.retval
+    %err3 = mpi.recv(%ref, %rank, %rank) : memref<100xf32>, i32, i32 -> !mpi.retval
+
+    // CHECK-NEXT: %3 = mpi.finalize : !mpi.retval
+    %rval = mpi.finalize : !mpi.retval
+
+    // CHECK-NEXT: %4 = mpi.retval_check %retval = <MPI_SUCCESS> : i1
+    %res = mpi.retval_check %retval = <MPI_SUCCESS> : i1
+
+    // CHECK-NEXT: %5 = mpi.error_class %0 : !mpi.retval
+    %errclass = mpi.error_class %err : !mpi.retval
+
+    // CHECK-NEXT: return
+    func.return
+}



More information about the Mlir-commits mailing list