[Mlir-commits] [mlir] [MLIR] Extend MPI dialect (PR #123255)

Sergio Sánchez Ramírez llvmlistbot at llvm.org
Sun Jan 26 11:19:41 PST 2025


https://github.com/mofeing updated https://github.com/llvm/llvm-project/pull/123255

>From 1fbae54306aff0c55ec544675bedb961574c2837 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Sergio=20S=C3=A1nchez=20Ram=C3=ADrez?=
 <sergio.sanchez.ramirez+git at bsc.es>
Date: Thu, 16 Jan 2025 22:31:59 +0100
Subject: [PATCH 1/6] Add `MPI_Comm`, `MPI_Request`, `MPI_Status`, `MPI_Op`
 type definitions

---
Review notes:
- TableGen needs a `;` after every `let` body item. In MPI_Comm and
  MPI_Request, `let summary = "..."` and the `}]` that closes
  `let description` have none, so MPITypes.td does not parse at this
  commit.
- All four `summary` strings are placeholders ("..." or "").
- The descriptions say "handler"; MPI's term for these objects is "handle".
- `def MPI_Op` reuses the name of the op base class that MPIOps.td
  instantiates as `MPI_Op<"comm_rank", []>`. Whether mlir-tblgen accepts a
  class and a def with the same name is worth confirming; a distinct name
  (e.g. MPI_OpType) avoids the ambiguity either way.

 mlir/include/mlir/Dialect/MPI/IR/MPITypes.td | 30 ++++++++++++++++++++
 1 file changed, 30 insertions(+)

diff --git a/mlir/include/mlir/Dialect/MPI/IR/MPITypes.td b/mlir/include/mlir/Dialect/MPI/IR/MPITypes.td
index 87eefa719d45c0..1d96b49d16585b 100644
--- a/mlir/include/mlir/Dialect/MPI/IR/MPITypes.td
+++ b/mlir/include/mlir/Dialect/MPI/IR/MPITypes.td
@@ -40,4 +40,34 @@ def MPI_Retval : MPI_Type<"Retval", "retval"> {
   }];
 }
 
+// TODO
+def MPI_Comm : MPI_Type<"Comm", "comm"> {
+  let summary = "..."
+  let description = [{
+    This type represents a handler to the MPI communicator.
+  }]
+}
+
+// TODO
+def MPI_Request : MPI_Type<"Request", "request"> {
+  let summary = "..."
+  let description = [{
+    This type represents a handler to an asynchronous requests.
+  }]
+}
+
+// TODO
+def MPI_Status : MPI_Type<"Status", "status"> {
+  let summary = "";
+  let description = [{
+  }];
+}
+
+// TODO
+def MPI_Op : MPI_Type<"Op", "op"> {
+  let summary = "";
+  let description = [{
+  }];
+}
+
 #endif // MLIR_DIALECT_MPI_IR_MPITYPES_TD

>From dc84ca4b87412dcfa9c83c48ae9916663c7ca38e Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Sergio=20S=C3=A1nchez=20Ram=C3=ADrez?=
 <sergio.sanchez.ramirez+git at bsc.es>
Date: Thu, 16 Jan 2025 23:02:23 +0100
Subject: [PATCH 2/6] Add `MPI_CommSize`, `MPI_ISend`, `MPI_IRecv` ops

---
Review notes:
- MPI_Isend and MPI_Irecv both hand back an MPI_Request through their
  last argument. Neither op has an `!mpi.request` result yet (see the
  TODO), so nothing can wait on or test the transfer. Until then the
  buffer cannot safely be reused (send) or read (receive) afterwards.
- MPI_Irecv has no status argument. The IRecvOp summary
  (`..., MPI_COMM_WORLD, MPI_STATUS_IGNORE)`) and the MPI_STATUS_IGNORE
  sentence in its description were carried over from RecvOp and do not
  apply.
- The summaries name the peer `dest` while the operand is `$rank`; for a
  receive the peer is the source.
- Both ops set `hasCanonicalizer = 1`, which declares
  getCanonicalizationPatterns(). Every patch in this series touches only
  .td files, so the C++ definitions have to exist elsewhere -- TODO
  confirm, otherwise the build fails at link time.
- Typo: "supprted" (twice).

 mlir/include/mlir/Dialect/MPI/IR/MPIOps.td | 85 ++++++++++++++++++++++
 1 file changed, 85 insertions(+)

diff --git a/mlir/include/mlir/Dialect/MPI/IR/MPIOps.td b/mlir/include/mlir/Dialect/MPI/IR/MPIOps.td
index 240fac5104c34f..8719b67cd7f5f0 100644
--- a/mlir/include/mlir/Dialect/MPI/IR/MPIOps.td
+++ b/mlir/include/mlir/Dialect/MPI/IR/MPIOps.td
@@ -59,6 +59,28 @@ def MPI_CommRankOp : MPI_Op<"comm_rank", []> {
   let assemblyFormat = "attr-dict `:` type(results)";
 }
 
+//===----------------------------------------------------------------------===//
+// CommSizeOp
+//===----------------------------------------------------------------------===//
+
+def MPI_CommSizeOp : MPI_Op<"comm_size", []> {
+  let summary = "Get the size of the group associated to the communicator, equivalent to "
+                "`MPI_Comm_size(MPI_COMM_WORLD, &size)`";
+  let description = [{
+    Communicators other than `MPI_COMM_WORLD` are not supported for now.
+
+    This operation can optionally return an `!mpi.retval` value that can be used
+    to check for errors.
+  }];
+
+  let results = (
+    outs Optional<MPI_Retval> : $retval,
+    I32 : $size
+  );
+
+  let assemblyFormat = "attr-dict `:` type(results)";
+}
+
 //===----------------------------------------------------------------------===//
 // SendOp
 //===----------------------------------------------------------------------===//
@@ -87,6 +109,37 @@ def MPI_SendOp : MPI_Op<"send", []> {
   let hasCanonicalizer = 1;
 }
 
+//===----------------------------------------------------------------------===//
+// ISendOp
+//===----------------------------------------------------------------------===//
+
+// TODO what about request handler?
+// NOTE datatype & count args are implicit by the type of the first argument (i.e. memref of eltype)
+// NOTE other communicators not yet supported by the `mpi` dialect
+def MPI_ISendOp : MPI_Op<"isend", []> {
+  let summary =
+      "Equivalent to `MPI_Isend(ptr, size, dtype, dest, tag, MPI_COMM_WORLD)`";
+  let description = [{
+    MPI_Isend begins a non-blocking send of `size` elements of type `dtype` to rank
+    `dest`. The `tag` value and communicator enables the library to determine 
+    the matching of multiple sends and receives between the same ranks.
+
+    Communicators other than `MPI_COMM_WORLD` are not supprted for now.
+
+    This operation can optionally return an `!mpi.retval` value that can be used
+    to check for errors.
+  }];
+
+  let arguments = (ins AnyMemRef : $ref, I32 : $tag, I32 : $rank);
+
+  let results = (outs Optional<MPI_Retval>:$retval);
+
+  let assemblyFormat = "`(` $ref `,` $tag `,` $rank `)` attr-dict `:` "
+                       "type($ref) `,` type($tag) `,` type($rank)"
+                       "(`->` type($retval)^)?";
+  let hasCanonicalizer = 1;
+}
+
 //===----------------------------------------------------------------------===//
 // RecvOp
 //===----------------------------------------------------------------------===//
@@ -118,6 +171,38 @@ def MPI_RecvOp : MPI_Op<"recv", []> {
   let hasCanonicalizer = 1;
 }
 
+//===----------------------------------------------------------------------===//
+// IRecvOp
+//===----------------------------------------------------------------------===//
+
+// TODO same as MPI_ISendOp
+def MPI_IRecvOp : MPI_Op<"irecv", []> {
+  let summary = "Equivalent to `MPI_Irecv(ptr, size, dtype, dest, tag, "
+                "MPI_COMM_WORLD, MPI_STATUS_IGNORE)`";
+  let description = [{
+    MPI_Irecv begins a non-blocking receive of `size` elements of type `dtype` 
+    from rank `dest`. The `tag` value and communicator enables the library to 
+    determine the matching of multiple sends and receives between the same 
+    ranks.
+
+    Communicators other than `MPI_COMM_WORLD` are not supprted for now.
+    The MPI_Status is set to `MPI_STATUS_IGNORE`, as the status object 
+    is not yet ported to MLIR.
+
+    This operation can optionally return an `!mpi.retval` value that can be used
+    to check for errors.
+  }];
+
+  let arguments = (ins AnyMemRef : $ref, I32 : $tag, I32 : $rank);
+
+  let results = (outs Optional<MPI_Retval>:$retval);
+
+  let assemblyFormat = "`(` $ref `,` $tag `,` $rank `)` attr-dict `:` "
+                       "type($ref) `,` type($tag) `,` type($rank)"
+                       "(`->` type($retval)^)?";
+  let hasCanonicalizer = 1;
+}
+
 
 //===----------------------------------------------------------------------===//
 // FinalizeOp

>From 2ee10ab60bb793b9164b0795e6657ceccc4704ae Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Sergio=20S=C3=A1nchez=20Ram=C3=ADrez?=
 <sergio.sanchez.ramirez+git at bsc.es>
Date: Thu, 16 Jan 2025 23:02:34 +0100
Subject: [PATCH 3/6] Fix typo

---
 mlir/include/mlir/Dialect/MPI/IR/MPIOps.td | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/mlir/include/mlir/Dialect/MPI/IR/MPIOps.td b/mlir/include/mlir/Dialect/MPI/IR/MPIOps.td
index 8719b67cd7f5f0..4be5a6dfea7777 100644
--- a/mlir/include/mlir/Dialect/MPI/IR/MPIOps.td
+++ b/mlir/include/mlir/Dialect/MPI/IR/MPIOps.td
@@ -251,7 +251,7 @@ def MPI_RetvalCheckOp : MPI_Op<"retval_check", []> {
 
 
 //===----------------------------------------------------------------------===//
-// RetvalCheckOp
+// ErrorClassOp
 //===----------------------------------------------------------------------===//
 
 def MPI_ErrorClassOp : MPI_Op<"error_class", []> {

>From 539bf43b5cf705e64183d7d84f08e7132dac3872 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Sergio=20S=C3=A1nchez=20Ram=C3=ADrez?=
 <sergio.sanchez.ramirez at bsc.es>
Date: Sat, 25 Jan 2025 20:59:08 +0100
Subject: [PATCH 4/6] Finish types

---
Revised: besides the section banners and wording fixes, this adds the `;`
that TableGen requires after `let summary` and after the `}]` closing
`let description` in MPI_Comm and MPI_Request, and fills in the placeholder
summaries. The blob ids on the `index` line predate this revision.

 mlir/include/mlir/Dialect/MPI/IR/MPITypes.td | 39 ++++++++++++++------
 1 file changed, 27 insertions(+), 12 deletions(-)

diff --git a/mlir/include/mlir/Dialect/MPI/IR/MPITypes.td b/mlir/include/mlir/Dialect/MPI/IR/MPITypes.td
index 1d96b49d16585b..20cde07d9a4b98 100644
--- a/mlir/include/mlir/Dialect/MPI/IR/MPITypes.td
+++ b/mlir/include/mlir/Dialect/MPI/IR/MPITypes.td
@@ -40,33 +40,48 @@ def MPI_Retval : MPI_Type<"Retval", "retval"> {
   }];
 }
 
-// TODO
+//===----------------------------------------------------------------------===//
+// mpi::CommType
+//===----------------------------------------------------------------------===//
+
 def MPI_Comm : MPI_Type<"Comm", "comm"> {
-  let summary = "..."
+  let summary = "MPI communicator handle";
   let description = [{
-    This type represents a handler to the MPI communicator.
-  }]
+    This type represents a handle to an MPI communicator.
+  }];
 }
 
-// TODO
+//===----------------------------------------------------------------------===//
+// mpi::RequestType
+//===----------------------------------------------------------------------===//
+
 def MPI_Request : MPI_Type<"Request", "request"> {
-  let summary = "..."
+  let summary = "MPI asynchronous request handle";
   let description = [{
-    This type represents a handler to an asynchronous requests.
-  }]
+    This type represents a handle to an asynchronous request.
+  }];
 }
 
-// TODO
+//===----------------------------------------------------------------------===//
+// mpi::StatusType
+//===----------------------------------------------------------------------===//
+
 def MPI_Status : MPI_Type<"Status", "status"> {
-  let summary = "";
+  let summary = "MPI receive status";
   let description = [{
+    This type represents the status of a receive operation.
   }];
 }
 
-// TODO
+//===----------------------------------------------------------------------===//
+// mpi::OpType
+//===----------------------------------------------------------------------===//
+
 def MPI_Op : MPI_Type<"Op", "op"> {
-  let summary = "";
+  let summary = "MPI reduction operation handle";
   let description = [{
+    This type represents a handle to an operation that can be used in MPI
+    reduce and scan routines.
   }];
 }
 

>From 662998d610c56eb83dd9897a94b8bd4131e95191 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Sergio=20S=C3=A1nchez=20Ram=C3=ADrez?=
 <sergio.sanchez.ramirez at bsc.es>
Date: Sun, 26 Jan 2025 11:35:36 +0100
Subject: [PATCH 5/6] Define `MPI_Op` enum & attr

---
Review notes:
- On the TODO: I32 is fine as the attribute encoding, but the case values
  0..13 are this dialect's own numbering. Real MPI_Op handles (MPI_SUM
  etc.) are implementation-defined: integer constants in MPICH, pointers
  to library globals in Open MPI. Any lowering must therefore map each
  case to the library's handle rather than pass the integer through.
- The generated C++ enumerators are spelled exactly like the <mpi.h>
  macros (MPI_MAX, MPI_SUM, ...). A translation unit that includes both
  <mpi.h> and the generated enum header would see them macro-expanded --
  TODO confirm no such TU exists, or give the cases different C++ symbols.

 mlir/include/mlir/Dialect/MPI/IR/MPI.td | 40 +++++++++++++++++++++++++
 1 file changed, 40 insertions(+)

diff --git a/mlir/include/mlir/Dialect/MPI/IR/MPI.td b/mlir/include/mlir/Dialect/MPI/IR/MPI.td
index 643612e1e2ee89..182de03a5a8057 100644
--- a/mlir/include/mlir/Dialect/MPI/IR/MPI.td
+++ b/mlir/include/mlir/Dialect/MPI/IR/MPI.td
@@ -215,4 +215,44 @@ def MPI_ErrorClassAttr : EnumAttr<MPI_Dialect, MPI_ErrorClassEnum, "errclass"> {
   let assemblyFormat = "`<` $value `>`";
 }
 
+// TODO is it ok to have them as I32?
+def MPI_OpNull : I32EnumAttrCase<"MPI_OP_NULL", 0, "MPI_OP_NULL">;
+def MPI_OpMax : I32EnumAttrCase<"MPI_MAX", 1, "MPI_MAX">;
+def MPI_OpMin : I32EnumAttrCase<"MPI_MIN", 2, "MPI_MIN">;
+def MPI_OpSum : I32EnumAttrCase<"MPI_SUM", 3, "MPI_SUM">;
+def MPI_OpProd : I32EnumAttrCase<"MPI_PROD", 4, "MPI_PROD">;
+def MPI_OpLand : I32EnumAttrCase<"MPI_LAND", 5, "MPI_LAND">;
+def MPI_OpBand : I32EnumAttrCase<"MPI_BAND", 6, "MPI_BAND">;
+def MPI_OpLor : I32EnumAttrCase<"MPI_LOR", 7, "MPI_LOR">;
+def MPI_OpBor : I32EnumAttrCase<"MPI_BOR", 8, "MPI_BOR">;
+def MPI_OpLxor : I32EnumAttrCase<"MPI_LXOR", 9, "MPI_LXOR">;
+def MPI_OpBxor : I32EnumAttrCase<"MPI_BXOR", 10, "MPI_BXOR">;
+def MPI_OpMinloc : I32EnumAttrCase<"MPI_MINLOC", 11, "MPI_MINLOC">;
+def MPI_OpMaxloc : I32EnumAttrCase<"MPI_MAXLOC", 12, "MPI_MAXLOC">;
+def MPI_OpReplace : I32EnumAttrCase<"MPI_REPLACE", 13, "MPI_REPLACE">;
+
+def MPI_OpClassEnum : I32EnumAttr<"MPI_OpClassEnum", "MPI operation class", [
+      MPI_OpNull,
+      MPI_OpMax,
+      MPI_OpMin,
+      MPI_OpSum,
+      MPI_OpProd,
+      MPI_OpLand,
+      MPI_OpBand,
+      MPI_OpLor,
+      MPI_OpBor,
+      MPI_OpLxor,
+      MPI_OpBxor,
+      MPI_OpMinloc,
+      MPI_OpMaxloc,
+      MPI_OpReplace
+    ]> {
+  let genSpecializedAttr = 0;
+  let cppNamespace = "::mlir::mpi";
+}
+
+def MPI_OpClassAttr : EnumAttr<MPI_Dialect, MPI_OpClassEnum, "opclass"> {
+  let assemblyFormat = "`<` $value `>`";
+}
+
 #endif // MLIR_DIALECT_MPI_IR_MPI_TD

>From c1ec63c24ff9b1af1a4dde393b1e90767605044b Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Sergio=20S=C3=A1nchez=20Ram=C3=ADrez?=
 <sergio.sanchez.ramirez at bsc.es>
Date: Sun, 26 Jan 2025 20:18:36 +0100
Subject: [PATCH 6/6] Add communicator argument to mpi ops as optional input
 argument

---
Review notes:
- Each op gains an optional `$comm` operand, but no assemblyFormat is
  updated. comm_rank/comm_size keep "attr-dict `:` type(results)".
  isend/irecv keep "`(` $ref `,` $tag `,` $rank `)` ...", and send/recv
  presumably do likewise. ODS's declarative format must reference every
  operand, so mlir-tblgen should reject these ops. For example:
  - add `(`(` $comm^ `)`)?` before attr-dict for comm_rank/comm_size;
  - add `(`,` $comm^)?` after `$rank` for the point-to-point ops.
  MPI_Comm has no parameters, so its type should be buildable and need
  not be printed -- TODO confirm.
- IRecvOp's new summary still ends in `MPI_STATUS_IGNORE`, but MPI_Irecv
  takes a request (MPI_Request *), not a status. The description keeps the
  same stale sentence.
- Only .td files change, so any existing lowering does not read `$comm`
  yet -- TODO confirm the documented "defaults to MPI_COMM_WORLD"
  behavior is actually implemented.

 mlir/include/mlir/Dialect/MPI/IR/MPIOps.td | 38 ++++++++++++----------
 1 file changed, 20 insertions(+), 18 deletions(-)

diff --git a/mlir/include/mlir/Dialect/MPI/IR/MPIOps.td b/mlir/include/mlir/Dialect/MPI/IR/MPIOps.td
index 4be5a6dfea7777..1330313c41a8c3 100644
--- a/mlir/include/mlir/Dialect/MPI/IR/MPIOps.td
+++ b/mlir/include/mlir/Dialect/MPI/IR/MPIOps.td
@@ -43,14 +43,16 @@ def MPI_InitOp : MPI_Op<"init", []> {
 
 def MPI_CommRankOp : MPI_Op<"comm_rank", []> {
   let summary = "Get the current rank, equivalent to "
-                "`MPI_Comm_rank(MPI_COMM_WORLD, &rank)`";
+                "`MPI_Comm_rank(comm, &rank)`";
   let description = [{
-    Communicators other than `MPI_COMM_WORLD` are not supported for now.
+    If communicator is not specified, `MPI_COMM_WORLD` is used by default.
 
     This operation can optionally return an `!mpi.retval` value that can be used
     to check for errors.
   }];
 
+  let arguments = (ins Optional<MPI_Comm> : $comm);
+
   let results = (
     outs Optional<MPI_Retval> : $retval,
     I32 : $rank
@@ -65,14 +67,16 @@ def MPI_CommRankOp : MPI_Op<"comm_rank", []> {
 
 def MPI_CommSizeOp : MPI_Op<"comm_size", []> {
   let summary = "Get the size of the group associated to the communicator, equivalent to "
-                "`MPI_Comm_size(MPI_COMM_WORLD, &size)`";
+                "`MPI_Comm_size(comm, &size)`";
   let description = [{
-    Communicators other than `MPI_COMM_WORLD` are not supported for now.
+    If communicator is not specified, `MPI_COMM_WORLD` is used by default.
 
     This operation can optionally return an `!mpi.retval` value that can be used
     to check for errors.
   }];
 
+  let arguments = (ins Optional<MPI_Comm> : $comm);
+
   let results = (
     outs Optional<MPI_Retval> : $retval,
     I32 : $size
@@ -87,19 +91,19 @@ def MPI_CommSizeOp : MPI_Op<"comm_size", []> {
 
 def MPI_SendOp : MPI_Op<"send", []> {
   let summary =
-      "Equivalent to `MPI_Send(ptr, size, dtype, dest, tag, MPI_COMM_WORLD)`";
+      "Equivalent to `MPI_Send(ptr, size, dtype, dest, tag, comm)`";
   let description = [{
     MPI_Send performs a blocking send of `size` elements of type `dtype` to rank
     `dest`. The `tag` value and communicator enables the library to determine 
     the matching of multiple sends and receives between the same ranks.
 
-    Communicators other than `MPI_COMM_WORLD` are not supprted for now.
+    If communicator is not specified, `MPI_COMM_WORLD` is used by default.
 
     This operation can optionally return an `!mpi.retval` value that can be used
     to check for errors.
   }];
 
-  let arguments = (ins AnyMemRef : $ref, I32 : $tag, I32 : $rank);
+  let arguments = (ins AnyMemRef : $ref, I32 : $tag, I32 : $rank, Optional<MPI_Comm> : $comm);
 
   let results = (outs Optional<MPI_Retval>:$retval);
 
@@ -115,22 +119,21 @@ def MPI_SendOp : MPI_Op<"send", []> {
 
 // TODO what about request handler?
 // NOTE datatype & count args are implicit by the type of the first argument (i.e. memref of eltype)
-// NOTE other communicators not yet supported by the `mpi` dialect
 def MPI_ISendOp : MPI_Op<"isend", []> {
   let summary =
-      "Equivalent to `MPI_Isend(ptr, size, dtype, dest, tag, MPI_COMM_WORLD)`";
+      "Equivalent to `MPI_Isend(ptr, size, dtype, dest, tag, comm)`";
   let description = [{
     MPI_Isend begins a non-blocking send of `size` elements of type `dtype` to rank
     `dest`. The `tag` value and communicator enables the library to determine 
     the matching of multiple sends and receives between the same ranks.
 
-    Communicators other than `MPI_COMM_WORLD` are not supprted for now.
+    If communicator is not specified, `MPI_COMM_WORLD` is used by default.
 
     This operation can optionally return an `!mpi.retval` value that can be used
     to check for errors.
   }];
 
-  let arguments = (ins AnyMemRef : $ref, I32 : $tag, I32 : $rank);
+  let arguments = (ins AnyMemRef : $ref, I32 : $tag, I32 : $rank, Optional<MPI_Comm> : $comm);
 
   let results = (outs Optional<MPI_Retval>:$retval);
 
@@ -146,14 +149,14 @@ def MPI_ISendOp : MPI_Op<"isend", []> {
 
 def MPI_RecvOp : MPI_Op<"recv", []> {
   let summary = "Equivalent to `MPI_Recv(ptr, size, dtype, dest, tag, "
-                "MPI_COMM_WORLD, MPI_STATUS_IGNORE)`";
+                "comm, MPI_STATUS_IGNORE)`";
   let description = [{
     MPI_Recv performs a blocking receive of `size` elements of type `dtype` 
     from rank `dest`. The `tag` value and communicator enables the library to 
     determine the matching of multiple sends and receives between the same 
     ranks.
 
-    Communicators other than `MPI_COMM_WORLD` are not supprted for now.
+    If communicator is not specified, `MPI_COMM_WORLD` is used by default.
     The MPI_Status is set to `MPI_STATUS_IGNORE`, as the status object 
     is not yet ported to MLIR.
 
@@ -161,7 +164,7 @@ def MPI_RecvOp : MPI_Op<"recv", []> {
     to check for errors.
   }];
 
-  let arguments = (ins AnyMemRef : $ref, I32 : $tag, I32 : $rank);
+  let arguments = (ins AnyMemRef : $ref, I32 : $tag, I32 : $rank, Optional<MPI_Comm> : $comm);
 
   let results = (outs Optional<MPI_Retval>:$retval);
 
@@ -175,17 +178,16 @@ def MPI_RecvOp : MPI_Op<"recv", []> {
 // IRecvOp
 //===----------------------------------------------------------------------===//
 
-// TODO same as MPI_ISendOp
 def MPI_IRecvOp : MPI_Op<"irecv", []> {
   let summary = "Equivalent to `MPI_Irecv(ptr, size, dtype, dest, tag, "
-                "MPI_COMM_WORLD, MPI_STATUS_IGNORE)`";
+                "comm, MPI_STATUS_IGNORE)`";
   let description = [{
     MPI_Irecv begins a non-blocking receive of `size` elements of type `dtype` 
     from rank `dest`. The `tag` value and communicator enables the library to 
     determine the matching of multiple sends and receives between the same 
     ranks.
 
-    Communicators other than `MPI_COMM_WORLD` are not supprted for now.
+    If communicator is not specified, `MPI_COMM_WORLD` is used by default.
     The MPI_Status is set to `MPI_STATUS_IGNORE`, as the status object 
     is not yet ported to MLIR.
 
@@ -193,7 +195,7 @@ def MPI_IRecvOp : MPI_Op<"irecv", []> {
     to check for errors.
   }];
 
-  let arguments = (ins AnyMemRef : $ref, I32 : $tag, I32 : $rank);
+  let arguments = (ins AnyMemRef : $ref, I32 : $tag, I32 : $rank, Optional<MPI_Comm> : $comm);
 
   let results = (outs Optional<MPI_Retval>:$retval);
 



More information about the Mlir-commits mailing list