[Mlir-commits] [mlir] [mlir][mpi] Mandatory Communicator (PR #133280)
Frank Schlimbach
llvmlistbot at llvm.org
Mon Mar 31 09:34:23 PDT 2025
Sergio Sánchez Ramírez,
Sergio Sánchez Ramírez,
Sergio Sánchez Ramírez,
Sergio Sánchez Ramírez,
Sergio Sánchez Ramírez,"Schlimbach, Frank"
<frank.schlimbach at intel.com>,"Schlimbach, Frank" <frank.schlimbach at intel.com>,
Frank Schlimbach <frank.schlimbach at intel.com>,Frank
Schlimbach <frank.schlimbach at intel.com>,"Schlimbach, Frank"
<frank.schlimbach at intel.com>,"Schlimbach, Frank" <frank.schlimbach at intel.com>,
"Schlimbach, Frank" <frank.schlimbach at intel.com>,Schlimbach,
Frank <frank.schlimbach at intel.com>,"Schlimbach, Frank"
<frank.schlimbach at intel.com>,Frank Schlimbach <frank.schlimbach at intel.com>
Message-ID:
In-Reply-To: <llvm.org/llvm/llvm-project/pull/133280 at github.com>
https://github.com/fschlimb updated https://github.com/llvm/llvm-project/pull/133280
>From 82c1e60878b31ff3232a487ff7d54e8fdf6deb85 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Sergio=20S=C3=A1nchez=20Ram=C3=ADrez?=
<sergio.sanchez.ramirez+git at bsc.es>
Date: Sat, 1 Feb 2025 20:43:30 +0100
Subject: [PATCH 01/16] Revert "Remove MPI_Comm type"
This reverts commit 6abba5a37d5ea73c2b177581db9d476da4a26c91.
---
mlir/include/mlir/Dialect/MPI/IR/MPIOps.td | 132 +++++++++++++------
mlir/include/mlir/Dialect/MPI/IR/MPITypes.td | 11 ++
mlir/test/Dialect/MPI/mpiops.mlir | 26 +++-
3 files changed, 128 insertions(+), 41 deletions(-)
diff --git a/mlir/include/mlir/Dialect/MPI/IR/MPIOps.td b/mlir/include/mlir/Dialect/MPI/IR/MPIOps.td
index a8267b115b9e6..3c6f5a8ac0ea8 100644
--- a/mlir/include/mlir/Dialect/MPI/IR/MPIOps.td
+++ b/mlir/include/mlir/Dialect/MPI/IR/MPIOps.td
@@ -37,26 +37,43 @@ def MPI_InitOp : MPI_Op<"init", []> {
let assemblyFormat = "attr-dict (`:` type($retval)^)?";
}
+//===----------------------------------------------------------------------===//
+// CommWorldOp
+//===----------------------------------------------------------------------===//
+
+def MPI_CommWorldOp : MPI_Op<"comm_world", []> {
+ let summary = "Get the World communicator, equivalent to `MPI_COMM_WORLD`";
+ let description = [{
+ This operation returns the predefined MPI_COMM_WORLD communicator.
+ }];
+
+ let results = (outs MPI_Comm : $comm);
+
+ let assemblyFormat = "attr-dict `:` type(results)";
+}
+
//===----------------------------------------------------------------------===//
// CommRankOp
//===----------------------------------------------------------------------===//
def MPI_CommRankOp : MPI_Op<"comm_rank", []> {
let summary = "Get the current rank, equivalent to "
- "`MPI_Comm_rank(MPI_COMM_WORLD, &rank)`";
+ "`MPI_Comm_rank(comm, &rank)`";
let description = [{
- Communicators other than `MPI_COMM_WORLD` are not supported for now.
+ If communicator is not specified, `MPI_COMM_WORLD` is used by default.
This operation can optionally return an `!mpi.retval` value that can be used
to check for errors.
}];
+ let arguments = (ins Optional<MPI_Comm> : $comm);
+
let results = (
outs Optional<MPI_Retval> : $retval,
I32 : $rank
);
- let assemblyFormat = "attr-dict `:` type(results)";
+ let assemblyFormat = "(`(` $comm ^ `)`)? attr-dict `:` type(results)";
}
//===----------------------------------------------------------------------===//
@@ -65,20 +82,50 @@ def MPI_CommRankOp : MPI_Op<"comm_rank", []> {
def MPI_CommSizeOp : MPI_Op<"comm_size", []> {
let summary = "Get the size of the group associated to the communicator, "
- "equivalent to `MPI_Comm_size(MPI_COMM_WORLD, &size)`";
+ "equivalent to `MPI_Comm_size(comm, &size)`";
let description = [{
- Communicators other than `MPI_COMM_WORLD` are not supported for now.
+ If communicator is not specified, `MPI_COMM_WORLD` is used by default.
This operation can optionally return an `!mpi.retval` value that can be used
to check for errors.
}];
+ let arguments = (ins Optional<MPI_Comm> : $comm);
+
let results = (
outs Optional<MPI_Retval> : $retval,
I32 : $size
);
- let assemblyFormat = "attr-dict `:` type(results)";
+ let assemblyFormat = "(`(` $comm ^ `)`)? attr-dict `:` type(results)";
+}
+
+//===----------------------------------------------------------------------===//
+// CommSplitOp
+//===----------------------------------------------------------------------===//
+
+def MPI_CommSplit : MPI_Op<"comm_split", []> {
+ let summary = "Partition the group associated to the given communicator into "
+ "disjoint subgroups";
+ let description = [{
+ This operation splits the communicator into multiple sub-communicators.
+ The color value determines the group of processes that will be part of the
+ new communicator. The key value determines the rank of the calling process
+ in the new communicator.
+
+ This operation can optionally return an `!mpi.retval` value that can be used
+ to check for errors.
+ }];
+
+ let arguments = (ins MPI_Comm : $comm, I32 : $color, I32 : $key);
+
+ let results = (
+ outs MPI_Comm : $newcomm,
+ Optional<MPI_Retval> : $retval
+ );
+
+ let assemblyFormat = "`(` $comm `,` $color `,` $key `)` attr-dict `:` "
+ "type(results)";
}
//===----------------------------------------------------------------------===//
@@ -87,13 +134,13 @@ def MPI_CommSizeOp : MPI_Op<"comm_size", []> {
def MPI_SendOp : MPI_Op<"send", []> {
let summary =
- "Equivalent to `MPI_Send(ptr, size, dtype, dest, tag, MPI_COMM_WORLD)`";
+ "Equivalent to `MPI_Send(ptr, size, dtype, dest, tag, comm)`";
let description = [{
MPI_Send performs a blocking send of `size` elements of type `dtype` to rank
`dest`. The `tag` value and communicator enables the library to determine
the matching of multiple sends and receives between the same ranks.
- Communicators other than `MPI_COMM_WORLD` are not supported for now.
+ If communicator is not specified, `MPI_COMM_WORLD` is used by default.
This operation can optionally return an `!mpi.retval` value that can be used
to check for errors.
@@ -102,12 +149,13 @@ def MPI_SendOp : MPI_Op<"send", []> {
let arguments = (
ins AnyMemRef : $ref,
I32 : $tag,
- I32 : $dest
+ I32 : $dest,
+ Optional<MPI_Comm> : $comm
);
let results = (outs Optional<MPI_Retval>:$retval);
- let assemblyFormat = "`(` $ref `,` $tag `,` $dest `)` attr-dict `:` "
+ let assemblyFormat = "`(` $ref `,` $tag `,` $dest (`,` $comm ^)? `)` attr-dict `:` "
"type($ref) `,` type($tag) `,` type($dest)"
"(`->` type($retval)^)?";
let hasCanonicalizer = 1;
@@ -119,14 +167,14 @@ def MPI_SendOp : MPI_Op<"send", []> {
def MPI_ISendOp : MPI_Op<"isend", []> {
let summary =
- "Equivalent to `MPI_Isend(ptr, size, dtype, dest, tag, MPI_COMM_WORLD)`";
+ "Equivalent to `MPI_Isend(ptr, size, dtype, dest, tag, comm)`";
let description = [{
MPI_Isend begins a non-blocking send of `size` elements of type `dtype` to
rank `dest`. The `tag` value and communicator enables the library to
determine the matching of multiple sends and receives between the same
ranks.
- Communicators other than `MPI_COMM_WORLD` are not supported for now.
+ If communicator is not specified, `MPI_COMM_WORLD` is used by default.
This operation can optionally return an `!mpi.retval` value that can be used
to check for errors.
@@ -135,7 +183,8 @@ def MPI_ISendOp : MPI_Op<"isend", []> {
let arguments = (
ins AnyMemRef : $ref,
I32 : $tag,
- I32 : $rank
+ I32 : $rank,
+ Optional<MPI_Comm> : $comm
);
let results = (
@@ -143,9 +192,9 @@ def MPI_ISendOp : MPI_Op<"isend", []> {
MPI_Request : $req
);
- let assemblyFormat = "`(` $ref `,` $tag `,` $rank `)` attr-dict "
+ let assemblyFormat = "`(` $ref `,` $tag `,` $rank (`,` $comm ^)?`)` attr-dict "
"`:` type($ref) `,` type($tag) `,` type($rank) "
- "`->` type(results)";
+ "(`,` type($comm) ^)? `->` type(results)";
let hasCanonicalizer = 1;
}
@@ -154,15 +203,15 @@ def MPI_ISendOp : MPI_Op<"isend", []> {
//===----------------------------------------------------------------------===//
def MPI_RecvOp : MPI_Op<"recv", []> {
- let summary = "Equivalent to `MPI_Recv(ptr, size, dtype, source, tag, "
- "MPI_COMM_WORLD, MPI_STATUS_IGNORE)`";
+ let summary = "Equivalent to `MPI_Recv(ptr, size, dtype, dest, tag, "
+ "comm, MPI_STATUS_IGNORE)`";
let description = [{
MPI_Recv performs a blocking receive of `size` elements of type `dtype`
from rank `source`. The `tag` value and communicator enables the library to
determine the matching of multiple sends and receives between the same
ranks.
- Communicators other than `MPI_COMM_WORLD` are not supported for now.
+ If communicator is not specified, `MPI_COMM_WORLD` is used by default.
The MPI_Status is set to `MPI_STATUS_IGNORE`, as the status object
is not yet ported to MLIR.
@@ -172,14 +221,15 @@ def MPI_RecvOp : MPI_Op<"recv", []> {
let arguments = (
ins AnyMemRef : $ref,
- I32 : $tag, I32 : $source
+ I32 : $tag, I32 : $source,
+ Optional<MPI_Comm> : $comm
);
let results = (outs Optional<MPI_Retval>:$retval);
- let assemblyFormat = "`(` $ref `,` $tag `,` $source `)` attr-dict `:` "
- "type($ref) `,` type($tag) `,` type($source)"
- "(`->` type($retval)^)?";
+ let assemblyFormat = "`(` $ref `,` $tag `,` $source (`,` $comm ^)?`)` attr-dict"
+ " `:` type($ref) `,` type($tag) `,` type($source) "
+ "(`,` type($comm) ^)? (`->` type($retval)^)?";
let hasCanonicalizer = 1;
}
@@ -189,14 +239,14 @@ def MPI_RecvOp : MPI_Op<"recv", []> {
def MPI_IRecvOp : MPI_Op<"irecv", []> {
let summary = "Equivalent to `MPI_Irecv(ptr, size, dtype, dest, tag, "
- "MPI_COMM_WORLD, &req)`";
+ "comm, &req)`";
let description = [{
MPI_Irecv begins a non-blocking receive of `size` elements of type `dtype`
from rank `dest`. The `tag` value and communicator enables the library to
determine the matching of multiple sends and receives between the same
ranks.
- Communicators other than `MPI_COMM_WORLD` are not supported for now.
+ If communicator is not specified, `MPI_COMM_WORLD` is used by default.
This operation can optionally return an `!mpi.retval` value that can be used
to check for errors.
@@ -205,7 +255,8 @@ def MPI_IRecvOp : MPI_Op<"irecv", []> {
let arguments = (
ins AnyMemRef : $ref,
I32 : $tag,
- I32 : $rank
+ I32 : $rank,
+ Optional<MPI_Comm> : $comm
);
let results = (
@@ -213,9 +264,9 @@ def MPI_IRecvOp : MPI_Op<"irecv", []> {
MPI_Request : $req
);
- let assemblyFormat = "`(` $ref `,` $tag `,` $rank `)` attr-dict `:`"
- "type($ref) `,` type($tag) `,` type($rank) `->`"
- "type(results)";
+ let assemblyFormat = "`(` $ref `,` $tag `,` $rank (`,` $comm ^)?`)` attr-dict "
+ "`:` type($ref) `,` type($tag) `,` type($rank)"
+ "(`,` type($comm) ^)? `->` type(results)";
let hasCanonicalizer = 1;
}
@@ -224,8 +275,7 @@ def MPI_IRecvOp : MPI_Op<"irecv", []> {
//===----------------------------------------------------------------------===//
def MPI_AllReduceOp : MPI_Op<"allreduce", []> {
- let summary = "Equivalent to `MPI_Allreduce(sendbuf, recvbuf, op, "
- "MPI_COMM_WORLD)`";
+ let summary = "Equivalent to `MPI_Allreduce(sendbuf, recvbuf, op, comm)`";
let description = [{
MPI_Allreduce performs a reduction operation on the values in the sendbuf
array and stores the result in the recvbuf array. The operation is
@@ -235,7 +285,7 @@ def MPI_AllReduceOp : MPI_Op<"allreduce", []> {
Currently only the `MPI_Op` predefined in the standard (e.g. `MPI_SUM`) are
supported.
- Communicators other than `MPI_COMM_WORLD` are not supported for now.
+ If communicator is not specified, `MPI_COMM_WORLD` is used by default.
This operation can optionally return an `!mpi.retval` value that can be used
to check for errors.
@@ -244,14 +294,15 @@ def MPI_AllReduceOp : MPI_Op<"allreduce", []> {
let arguments = (
ins AnyMemRef : $sendbuf,
AnyMemRef : $recvbuf,
- MPI_OpClassEnum : $op
+ MPI_OpClassAttr : $op,
+ Optional<MPI_Comm> : $comm
);
let results = (outs Optional<MPI_Retval>:$retval);
- let assemblyFormat = "`(` $sendbuf `,` $recvbuf `,` $op `)` attr-dict `:`"
- "type($sendbuf) `,` type($recvbuf)"
- "(`->` type($retval)^)?";
+ let assemblyFormat = "`(` $sendbuf `,` $recvbuf `,` $op (`,` $comm ^)?`)` "
+ "attr-dict `:` type($sendbuf) `,` type($recvbuf) "
+ "(`,` type($comm) ^)? (`->` type($retval)^)?";
}
//===----------------------------------------------------------------------===//
@@ -259,20 +310,22 @@ def MPI_AllReduceOp : MPI_Op<"allreduce", []> {
//===----------------------------------------------------------------------===//
def MPI_Barrier : MPI_Op<"barrier", []> {
- let summary = "Equivalent to `MPI_Barrier(MPI_COMM_WORLD)`";
+ let summary = "Equivalent to `MPI_Barrier(comm)`";
let description = [{
MPI_Barrier blocks execution until all processes in the communicator have
reached this routine.
- Communicators other than `MPI_COMM_WORLD` are not supported for now.
+ If communicator is not specified, `MPI_COMM_WORLD` is used by default.
This operation can optionally return an `!mpi.retval` value that can be used
to check for errors.
}];
+ let arguments = (ins Optional<MPI_Comm> : $comm);
+
let results = (outs Optional<MPI_Retval>:$retval);
- let assemblyFormat = "attr-dict (`:` type($retval) ^)?";
+ let assemblyFormat = "(`(` $comm ^ `)`)? attr-dict (`:` type($retval) ^)?";
}
//===----------------------------------------------------------------------===//
@@ -295,8 +348,7 @@ def MPI_Wait : MPI_Op<"wait", []> {
let results = (outs Optional<MPI_Retval>:$retval);
- let assemblyFormat = "`(` $req `)` attr-dict `:` type($req) "
- "(`->` type($retval) ^)?";
+ let assemblyFormat = "`(` $req `)` attr-dict `:` type($req) (`->` type($retval) ^)?";
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/MPI/IR/MPITypes.td b/mlir/include/mlir/Dialect/MPI/IR/MPITypes.td
index a55d30e778e22..b56a224d84774 100644
--- a/mlir/include/mlir/Dialect/MPI/IR/MPITypes.td
+++ b/mlir/include/mlir/Dialect/MPI/IR/MPITypes.td
@@ -40,6 +40,17 @@ def MPI_Retval : MPI_Type<"Retval", "retval"> {
}];
}
+//===----------------------------------------------------------------------===//
+// mpi::CommType
+//===----------------------------------------------------------------------===//
+
+def MPI_Comm : MPI_Type<"Comm", "comm"> {
+ let summary = "MPI communicator handler";
+ let description = [{
+ This type represents a handler to the MPI communicator.
+ }];
+}
+
//===----------------------------------------------------------------------===//
// mpi::RequestType
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/MPI/mpiops.mlir b/mlir/test/Dialect/MPI/mpiops.mlir
index fb4333611a246..7fb7323870719 100644
--- a/mlir/test/Dialect/MPI/mpiops.mlir
+++ b/mlir/test/Dialect/MPI/mpiops.mlir
@@ -12,30 +12,48 @@ func.func @mpi_test(%ref : memref<100xf32>) -> () {
// CHECK-NEXT: %retval_0, %size = mpi.comm_size : !mpi.retval, i32
%retval_0, %size = mpi.comm_size : !mpi.retval, i32
+ // CHECK-NEXT: %comm = mpi.comm_world : !mpi.comm
+ %comm = mpi.comm_world : !mpi.comm
+
+ // CHECK-NEXT: %new_comm, %retval3 = mpi.comm_split(%comm, %rank, %rank) : i32, !mpi.retval
+ %new_comm, %retval3 = mpi.comm_split(%comm, %rank, %rank) : mpi.comm, i32, i32
+
// CHECK-NEXT: mpi.send(%arg0, %rank, %rank) : memref<100xf32>, i32, i32
mpi.send(%ref, %rank, %rank) : memref<100xf32>, i32, i32
// CHECK-NEXT: %1 = mpi.send(%arg0, %rank, %rank) : memref<100xf32>, i32, i32 -> !mpi.retval
%err2 = mpi.send(%ref, %rank, %rank) : memref<100xf32>, i32, i32 -> !mpi.retval
+ // CHECK-NEXT: mpi.send(%arg0, %rank, %rank, %comm) : memref<100xf32>, i32, i32, !mpi.comm
+ mpi.send(%ref, %rank, %rank, %comm) : memref<100xf32>, i32, i32, !mpi.comm
+
// CHECK-NEXT: mpi.recv(%arg0, %rank, %rank) : memref<100xf32>, i32, i32
mpi.recv(%ref, %rank, %rank) : memref<100xf32>, i32, i32
// CHECK-NEXT: %2 = mpi.recv(%arg0, %rank, %rank) : memref<100xf32>, i32, i32 -> !mpi.retval
%err3 = mpi.recv(%ref, %rank, %rank) : memref<100xf32>, i32, i32 -> !mpi.retval
+ // CHECK-NEXT: mpi.recv(%arg0, %rank, %rank, %comm) : memref<100xf32>, i32, i32, !mpi.comm
+ mpi.recv(%ref, %rank, %rank, %comm) : memref<100xf32>, i32, i32, !mpi.comm
+
// CHECK-NEXT: %req = mpi.isend(%arg0, %rank, %rank) : memref<100xf32>, i32, i32 -> !mpi.request
%req = mpi.isend(%ref, %rank, %rank) : memref<100xf32>, i32, i32 -> !mpi.request
// CHECK-NEXT: %retval_1, %req_2 = mpi.isend(%arg0, %rank, %rank) : memref<100xf32>, i32, i32 -> !mpi.retval, !mpi.request
%err4, %req2 = mpi.isend(%ref, %rank, %rank) : memref<100xf32>, i32, i32 -> !mpi.retval, !mpi.request
+ // CHECK-NEXT: %3 = mpi.isend(%arg0, %rank, %rank, %comm) : memref<100xf32>, i32, i32, mpi.comm -> !mpi.request
+ %req1 = mpi.isend(%ref, %rank, %rank, %comm) : memref<100xf32>, i32, i32, !mpi.comm -> !mpi.request
+
// CHECK-NEXT: %req_3 = mpi.irecv(%arg0, %rank, %rank) : memref<100xf32>, i32, i32 -> !mpi.request
%req3 = mpi.irecv(%ref, %rank, %rank) : memref<100xf32>, i32, i32 -> !mpi.request
// CHECK-NEXT: %retval_4, %req_5 = mpi.irecv(%arg0, %rank, %rank) : memref<100xf32>, i32, i32 -> !mpi.retval, !mpi.request
%err5, %req4 = mpi.irecv(%ref, %rank, %rank) : memref<100xf32>, i32, i32 -> !mpi.retval, !mpi.request
+ // CHECK-NEXT: %6 = mpi.irecv(%arg0, %rank, %rank, %comm) : memref<100xf32>, i32, i32, mpi.comm -> mpi.request
+ %req3 = mpi.irecv(%ref, %rank, %rank, %comm) : memref<100xf32>, i32, i32, !mpi.comm -> !mpi.request
+
// CHECK-NEXT: mpi.wait(%req) : !mpi.request
mpi.wait(%req) : !mpi.request
@@ -48,12 +66,18 @@ func.func @mpi_test(%ref : memref<100xf32>) -> () {
// CHECK-NEXT: %5 = mpi.barrier : !mpi.retval
%err7 = mpi.barrier : !mpi.retval
- // CHECK-NEXT: mpi.allreduce(%arg0, %arg0, MPI_SUM) : memref<100xf32>, memref<100xf32>
+ // CHECK-NEXT: mpi.barrier(%comm) : !mpi.retval
+ mpi.barrier(%comm) : !mpi.retval
+
+ // CHECK-NEXT: mpi.allreduce(%arg0, %arg0, <MPI_SUM>) : memref<100xf32>, memref<100xf32>
mpi.allreduce(%ref, %ref, MPI_SUM) : memref<100xf32>, memref<100xf32>
// CHECK-NEXT: mpi.allreduce(%arg0, %arg0, MPI_SUM) : memref<100xf32>, memref<100xf32> -> !mpi.retval
%err8 = mpi.allreduce(%ref, %ref, MPI_SUM) : memref<100xf32>, memref<100xf32> -> !mpi.retval
+ // CHECK-NEXT: mpi.allreduce(%arg0, %arg0, MPI_SUM, %comm) : memref<100xf32>, memref<100xf32>, !mpi.comm
+ mpi.allreduce(%ref, %ref, MPI_SUM, %comm) : memref<100xf32>, memref<100xf32>, !mpi.comm
+
// CHECK-NEXT: %7 = mpi.finalize : !mpi.retval
%rval = mpi.finalize : !mpi.retval
>From 8d596ae3dd89c484fed21d0d3d9cc9637010fa86 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Sergio=20S=C3=A1nchez=20Ram=C3=ADrez?=
<sergio.sanchez.ramirez+git at bsc.es>
Date: Sat, 1 Feb 2025 20:52:39 +0100
Subject: [PATCH 02/16] Fix assembly format for `comm_size`, `comm_rank`
---
mlir/include/mlir/Dialect/MPI/IR/MPIOps.td | 6 ++++--
1 file changed, 4 insertions(+), 2 deletions(-)
diff --git a/mlir/include/mlir/Dialect/MPI/IR/MPIOps.td b/mlir/include/mlir/Dialect/MPI/IR/MPIOps.td
index 3c6f5a8ac0ea8..cae411599c381 100644
--- a/mlir/include/mlir/Dialect/MPI/IR/MPIOps.td
+++ b/mlir/include/mlir/Dialect/MPI/IR/MPIOps.td
@@ -73,7 +73,8 @@ def MPI_CommRankOp : MPI_Op<"comm_rank", []> {
I32 : $rank
);
- let assemblyFormat = "(`(` $comm ^ `)`)? attr-dict `:` type(results)";
+ let assemblyFormat = "(`(` $comm ^ `)`)? attr-dict `:` (type($comm) ^ `->`)?"
+ "type(results)";
}
//===----------------------------------------------------------------------===//
@@ -97,7 +98,8 @@ def MPI_CommSizeOp : MPI_Op<"comm_size", []> {
I32 : $size
);
- let assemblyFormat = "(`(` $comm ^ `)`)? attr-dict `:` type(results)";
+ let assemblyFormat = "(`(` $comm ^ `)`)? attr-dict `:` (type($comm) ^ `->`)?"
+ "type(results)";
}
//===----------------------------------------------------------------------===//
>From 25635931b982c6b1b5dde6e83f2be77221c4ff3e Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Sergio=20S=C3=A1nchez=20Ram=C3=ADrez?=
<sergio.sanchez.ramirez+git at bsc.es>
Date: Sat, 1 Feb 2025 20:53:03 +0100
Subject: [PATCH 03/16] add more tests for `comm_size`, `comm_rank`
---
mlir/test/Dialect/MPI/mpiops.mlir | 22 ++++++++++++++++++++--
1 file changed, 20 insertions(+), 2 deletions(-)
diff --git a/mlir/test/Dialect/MPI/mpiops.mlir b/mlir/test/Dialect/MPI/mpiops.mlir
index 7fb7323870719..d5c10a9faa774 100644
--- a/mlir/test/Dialect/MPI/mpiops.mlir
+++ b/mlir/test/Dialect/MPI/mpiops.mlir
@@ -6,14 +6,32 @@ func.func @mpi_test(%ref : memref<100xf32>) -> () {
// CHECK: %0 = mpi.init : !mpi.retval
%err = mpi.init : !mpi.retval
+ // CHECK-NEXT: %comm = mpi.comm_world : !mpi.comm
+ %comm = mpi.comm_world : !mpi.comm
+
+ // CHECK-NEXT: %rank = mpi.comm_rank : i32
+ %rank = mpi.comm_rank : i32
+
// CHECK-NEXT: %retval, %rank = mpi.comm_rank : !mpi.retval, i32
%retval, %rank = mpi.comm_rank : !mpi.retval, i32
+ // CHECK-NEXT: %retval, %rank = mpi.comm_rank : !mpi.comm -> i32
+ %rank = mpi.comm_rank(%comm) : !mpi.comm -> i32
+
+ // CHECK-NEXT: %retval, %rank = mpi.comm_rank : !mpi.comm -> !mpi.retval, i32
+ %retval, %rank = mpi.comm_rank(%comm) : !mpi.comm -> !mpi.retval, i32
+
+ // CHECK-NEXT: %size = mpi.comm_size : i32
+ %size = mpi.comm_size : i32
+
// CHECK-NEXT: %retval_0, %size = mpi.comm_size : !mpi.retval, i32
%retval_0, %size = mpi.comm_size : !mpi.retval, i32
- // CHECK-NEXT: %comm = mpi.comm_world : !mpi.comm
- %comm = mpi.comm_world : !mpi.comm
+ // CHECK-NEXT: %size = mpi.comm_size : !mpi.comm -> i32
+ %size = mpi.comm_size(%comm) : !mpi.comm -> i32
+
+ // CHECK-NEXT: %retval_0, %size = mpi.comm_size : !mpi.retval, i32
+ %retval_0, %size = mpi.comm_size(%comm) : !mpi.comm -> !mpi.retval, i32
// CHECK-NEXT: %new_comm, %retval3 = mpi.comm_split(%comm, %rank, %rank) : i32, !mpi.retval
%new_comm, %retval3 = mpi.comm_split(%comm, %rank, %rank) : mpi.comm, i32, i32
>From 8fd64632e015d20475a793f5d6c9e1cf5b94c1b0 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Sergio=20S=C3=A1nchez=20Ram=C3=ADrez?=
<sergio.sanchez.ramirez+git at bsc.es>
Date: Sat, 1 Feb 2025 22:19:45 +0100
Subject: [PATCH 04/16] fix some assembly formats
---
mlir/include/mlir/Dialect/MPI/IR/MPIOps.td | 13 +++++++------
1 file changed, 7 insertions(+), 6 deletions(-)
diff --git a/mlir/include/mlir/Dialect/MPI/IR/MPIOps.td b/mlir/include/mlir/Dialect/MPI/IR/MPIOps.td
index cae411599c381..70868cda9952d 100644
--- a/mlir/include/mlir/Dialect/MPI/IR/MPIOps.td
+++ b/mlir/include/mlir/Dialect/MPI/IR/MPIOps.td
@@ -73,8 +73,8 @@ def MPI_CommRankOp : MPI_Op<"comm_rank", []> {
I32 : $rank
);
- let assemblyFormat = "(`(` $comm ^ `)`)? attr-dict `:` (type($comm) ^ `->`)?"
- "type(results)";
+ let assemblyFormat = "(`(` $comm ^ `)`)? attr-dict (`:` type($comm) ^ `->`):"
+ "(`:`)? type(results)";
}
//===----------------------------------------------------------------------===//
@@ -98,8 +98,8 @@ def MPI_CommSizeOp : MPI_Op<"comm_size", []> {
I32 : $size
);
- let assemblyFormat = "(`(` $comm ^ `)`)? attr-dict `:` (type($comm) ^ `->`)?"
- "type(results)";
+ let assemblyFormat = "(`(` $comm ^ `)`)? attr-dict (`:` type($comm) ^ `->`):"
+ "(`:`)? type(results)";
}
//===----------------------------------------------------------------------===//
@@ -122,11 +122,12 @@ def MPI_CommSplit : MPI_Op<"comm_split", []> {
let arguments = (ins MPI_Comm : $comm, I32 : $color, I32 : $key);
let results = (
- outs MPI_Comm : $newcomm,
- Optional<MPI_Retval> : $retval
+ outs Optional<MPI_Retval> : $retval,
+ MPI_Comm : $newcomm
);
let assemblyFormat = "`(` $comm `,` $color `,` $key `)` attr-dict `:` "
+ "type($comm) `,` type($color) `,` type($key) `->` "
"type(results)";
}
>From b7080d386b44c693aca402c644db2fb99e7040c3 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Sergio=20S=C3=A1nchez=20Ram=C3=ADrez?=
<sergio.sanchez.ramirez+git at bsc.es>
Date: Sat, 1 Feb 2025 22:20:04 +0100
Subject: [PATCH 05/16] fix some tests
---
mlir/test/Dialect/MPI/mpiops.mlir | 18 ++++++++++++------
1 file changed, 12 insertions(+), 6 deletions(-)
diff --git a/mlir/test/Dialect/MPI/mpiops.mlir b/mlir/test/Dialect/MPI/mpiops.mlir
index d5c10a9faa774..1cf2f2de4cf45 100644
--- a/mlir/test/Dialect/MPI/mpiops.mlir
+++ b/mlir/test/Dialect/MPI/mpiops.mlir
@@ -33,8 +33,11 @@ func.func @mpi_test(%ref : memref<100xf32>) -> () {
// CHECK-NEXT: %retval_0, %size = mpi.comm_size : !mpi.retval, i32
%retval_0, %size = mpi.comm_size(%comm) : !mpi.comm -> !mpi.retval, i32
- // CHECK-NEXT: %new_comm, %retval3 = mpi.comm_split(%comm, %rank, %rank) : i32, !mpi.retval
- %new_comm, %retval3 = mpi.comm_split(%comm, %rank, %rank) : mpi.comm, i32, i32
+ // CHECK-NEXT: %new_comm = mpi.comm_split(%comm, %rank, %rank) : !mpi.comm, i32, i32 -> !mpi.comm
+ %new_comm = mpi.comm_split(%comm, %rank, %rank) : !mpi.comm, i32, i32
+
+ // CHECK-NEXT: %retval3, %new_comm = mpi.comm_split(%comm, %rank, %rank) : !mpi.comm, i32, i32 -> !mpi.retval, !mpi.comm
+ %retval3, %new_comm = mpi.comm_split(%comm, %rank, %rank) : !mpi.comm, i32, i32 -> !mpi.retval, !mpi.comm
// CHECK-NEXT: mpi.send(%arg0, %rank, %rank) : memref<100xf32>, i32, i32
mpi.send(%ref, %rank, %rank) : memref<100xf32>, i32, i32
@@ -78,14 +81,17 @@ func.func @mpi_test(%ref : memref<100xf32>) -> () {
// CHECK-NEXT: %3 = mpi.wait(%req_2) : !mpi.request -> !mpi.retval
%err6 = mpi.wait(%req2) : !mpi.request -> !mpi.retval
- // CHECK-NEXT: mpi.barrier : !mpi.retval
- mpi.barrier : !mpi.retval
+ // CHECK-NEXT: mpi.barrier
+ mpi.barrier
// CHECK-NEXT: %5 = mpi.barrier : !mpi.retval
%err7 = mpi.barrier : !mpi.retval
- // CHECK-NEXT: mpi.barrier(%comm) : !mpi.retval
- mpi.barrier(%comm) : !mpi.retval
+ // CHECK-NEXT: mpi.barrier(%comm)
+ mpi.barrier(%comm)
+
+ // CHECK-NEXT: %5 = mpi.barrier : !mpi.retval
+ %err7 = mpi.barrier : !mpi.retval
// CHECK-NEXT: mpi.allreduce(%arg0, %arg0, <MPI_SUM>) : memref<100xf32>, memref<100xf32>
mpi.allreduce(%ref, %ref, MPI_SUM) : memref<100xf32>, memref<100xf32>
>From fd9ad01452dd7e1dc60189d17dd98cbb90df7b9f Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Sergio=20S=C3=A1nchez=20Ram=C3=ADrez?=
<sergio.sanchez.ramirez+git at bsc.es>
Date: Sat, 1 Feb 2025 22:20:43 +0100
Subject: [PATCH 06/16] try fixing the assembly format of `barrier`
---
mlir/include/mlir/Dialect/MPI/IR/MPIOps.td | 13 ++++++++++++-
1 file changed, 12 insertions(+), 1 deletion(-)
diff --git a/mlir/include/mlir/Dialect/MPI/IR/MPIOps.td b/mlir/include/mlir/Dialect/MPI/IR/MPIOps.td
index 70868cda9952d..c76a5ee081552 100644
--- a/mlir/include/mlir/Dialect/MPI/IR/MPIOps.td
+++ b/mlir/include/mlir/Dialect/MPI/IR/MPIOps.td
@@ -328,7 +328,18 @@ def MPI_Barrier : MPI_Op<"barrier", []> {
let results = (outs Optional<MPI_Retval>:$retval);
- let assemblyFormat = "(`(` $comm ^ `)`)? attr-dict (`:` type($retval) ^)?";
+ // TODO fix assembly format
+ // let assemblyFormat = "("
+ // "(attr-dict) ^"
+ // "(attr-dict `:` type($retval)) ^"
+ // "(`(` $comm `)` attr-dict `:` type($comm)) ^"
+ // "(`(` $comm `)` attr-dict `:` type($comm) `->` type($retval))"
+ // ")?";
+ let assemblyFormat = [{
+ (`(` $comm ^ `)`)? attr-dict
+ (`:` type($comm) ^ `->`):(`:`)?
+ type(results)
+ }];
}
//===----------------------------------------------------------------------===//
>From 7d1a33db75839fd6a41865fd1708767878c0cdb6 Mon Sep 17 00:00:00 2001
From: "Schlimbach, Frank" <frank.schlimbach at intel.com>
Date: Thu, 27 Mar 2025 17:14:15 +0100
Subject: [PATCH 07/16] making communicator mandatory, fixing dependent code
and tests
---
mlir/include/mlir/Dialect/MPI/IR/MPIOps.td | 45 +++----
mlir/lib/Conversion/MeshToMPI/MeshToMPI.cpp | 24 +++-
mlir/test/Conversion/MPIToLLVM/mpitollvm.mlir | 22 ++--
.../MeshToMPI/convert-mesh-to-mpi.mlir | 62 +++++----
mlir/test/Dialect/MPI/mpiops.mlir | 121 +++++++-----------
5 files changed, 129 insertions(+), 145 deletions(-)
diff --git a/mlir/include/mlir/Dialect/MPI/IR/MPIOps.td b/mlir/include/mlir/Dialect/MPI/IR/MPIOps.td
index c76a5ee081552..6bc25054bf48a 100644
--- a/mlir/include/mlir/Dialect/MPI/IR/MPIOps.td
+++ b/mlir/include/mlir/Dialect/MPI/IR/MPIOps.td
@@ -73,8 +73,7 @@ def MPI_CommRankOp : MPI_Op<"comm_rank", []> {
I32 : $rank
);
- let assemblyFormat = "(`(` $comm ^ `)`)? attr-dict (`:` type($comm) ^ `->`):"
- "(`:`)? type(results)";
+ let assemblyFormat = "`(` $comm `)` attr-dict `:` type(results)";
}
//===----------------------------------------------------------------------===//
@@ -91,15 +90,14 @@ def MPI_CommSizeOp : MPI_Op<"comm_size", []> {
to check for errors.
}];
- let arguments = (ins Optional<MPI_Comm> : $comm);
+ let arguments = (ins MPI_Comm : $comm);
let results = (
outs Optional<MPI_Retval> : $retval,
I32 : $size
);
- let assemblyFormat = "(`(` $comm ^ `)`)? attr-dict (`:` type($comm) ^ `->`):"
- "(`:`)? type(results)";
+ let assemblyFormat = "`(` $comm `)` attr-dict `:` type(results)";
}
//===----------------------------------------------------------------------===//
@@ -127,7 +125,7 @@ def MPI_CommSplit : MPI_Op<"comm_split", []> {
);
let assemblyFormat = "`(` $comm `,` $color `,` $key `)` attr-dict `:` "
- "type($comm) `,` type($color) `,` type($key) `->` "
+ "type($color) `,` type($key) `->` "
"type(results)";
}
@@ -153,12 +151,12 @@ def MPI_SendOp : MPI_Op<"send", []> {
ins AnyMemRef : $ref,
I32 : $tag,
I32 : $dest,
- Optional<MPI_Comm> : $comm
+ MPI_Comm : $comm
);
let results = (outs Optional<MPI_Retval>:$retval);
- let assemblyFormat = "`(` $ref `,` $tag `,` $dest (`,` $comm ^)? `)` attr-dict `:` "
+ let assemblyFormat = "`(` $ref `,` $tag `,` $dest `,` $comm `)` attr-dict `:` "
"type($ref) `,` type($tag) `,` type($dest)"
"(`->` type($retval)^)?";
let hasCanonicalizer = 1;
@@ -187,7 +185,7 @@ def MPI_ISendOp : MPI_Op<"isend", []> {
ins AnyMemRef : $ref,
I32 : $tag,
I32 : $rank,
- Optional<MPI_Comm> : $comm
+ MPI_Comm : $comm
);
let results = (
@@ -195,9 +193,9 @@ def MPI_ISendOp : MPI_Op<"isend", []> {
MPI_Request : $req
);
- let assemblyFormat = "`(` $ref `,` $tag `,` $rank (`,` $comm ^)?`)` attr-dict "
+ let assemblyFormat = "`(` $ref `,` $tag `,` $rank `,` $comm`)` attr-dict "
"`:` type($ref) `,` type($tag) `,` type($rank) "
- "(`,` type($comm) ^)? `->` type(results)";
+ "`->` type(results)";
let hasCanonicalizer = 1;
}
@@ -225,14 +223,14 @@ def MPI_RecvOp : MPI_Op<"recv", []> {
let arguments = (
ins AnyMemRef : $ref,
I32 : $tag, I32 : $source,
- Optional<MPI_Comm> : $comm
+ MPI_Comm : $comm
);
let results = (outs Optional<MPI_Retval>:$retval);
- let assemblyFormat = "`(` $ref `,` $tag `,` $source (`,` $comm ^)?`)` attr-dict"
+ let assemblyFormat = "`(` $ref `,` $tag `,` $source `,` $comm `)` attr-dict"
" `:` type($ref) `,` type($tag) `,` type($source) "
- "(`,` type($comm) ^)? (`->` type($retval)^)?";
+ "(`->` type($retval)^)?";
let hasCanonicalizer = 1;
}
@@ -259,7 +257,7 @@ def MPI_IRecvOp : MPI_Op<"irecv", []> {
ins AnyMemRef : $ref,
I32 : $tag,
I32 : $rank,
- Optional<MPI_Comm> : $comm
+ MPI_Comm : $comm
);
let results = (
@@ -267,9 +265,9 @@ def MPI_IRecvOp : MPI_Op<"irecv", []> {
MPI_Request : $req
);
- let assemblyFormat = "`(` $ref `,` $tag `,` $rank (`,` $comm ^)?`)` attr-dict "
+ let assemblyFormat = "`(` $ref `,` $tag `,` $rank `,` $comm`)` attr-dict "
"`:` type($ref) `,` type($tag) `,` type($rank)"
- "(`,` type($comm) ^)? `->` type(results)";
+ "`->` type(results)";
let hasCanonicalizer = 1;
}
@@ -298,14 +296,14 @@ def MPI_AllReduceOp : MPI_Op<"allreduce", []> {
ins AnyMemRef : $sendbuf,
AnyMemRef : $recvbuf,
MPI_OpClassAttr : $op,
- Optional<MPI_Comm> : $comm
+ MPI_Comm : $comm
);
let results = (outs Optional<MPI_Retval>:$retval);
- let assemblyFormat = "`(` $sendbuf `,` $recvbuf `,` $op (`,` $comm ^)?`)` "
+ let assemblyFormat = "`(` $sendbuf `,` $recvbuf `,` $op `,` $comm `)` "
"attr-dict `:` type($sendbuf) `,` type($recvbuf) "
- "(`,` type($comm) ^)? (`->` type($retval)^)?";
+ "(`->` type($retval)^)?";
}
//===----------------------------------------------------------------------===//
@@ -324,7 +322,7 @@ def MPI_Barrier : MPI_Op<"barrier", []> {
to check for errors.
}];
- let arguments = (ins Optional<MPI_Comm> : $comm);
+ let arguments = (ins MPI_Comm : $comm);
let results = (outs Optional<MPI_Retval>:$retval);
@@ -336,9 +334,8 @@ def MPI_Barrier : MPI_Op<"barrier", []> {
// "(`(` $comm `)` attr-dict `:` type($comm) `->` type($retval))"
// ")?";
let assemblyFormat = [{
- (`(` $comm ^ `)`)? attr-dict
- (`:` type($comm) ^ `->`):(`:`)?
- type(results)
+ `(` $comm `)` attr-dict
+ (`->` type($retval)^)?
}];
}
diff --git a/mlir/lib/Conversion/MeshToMPI/MeshToMPI.cpp b/mlir/lib/Conversion/MeshToMPI/MeshToMPI.cpp
index 87c2938e4e52b..cafbf835de22f 100644
--- a/mlir/lib/Conversion/MeshToMPI/MeshToMPI.cpp
+++ b/mlir/lib/Conversion/MeshToMPI/MeshToMPI.cpp
@@ -310,11 +310,16 @@ class ConvertProcessLinearIndexOp
}
// Otherwise create an mpi::CommRankOp
- auto rank = rewriter
- .create<mpi::CommRankOp>(
- loc, TypeRange{mpi::RetvalType::get(op->getContext()),
- rewriter.getI32Type()})
- .getRank();
+ auto ctx = op.getContext();
+ Value commWorld =
+ rewriter.create<mpi::CommWorldOp>(loc, mpi::CommType::get(ctx));
+ auto rank =
+ rewriter
+ .create<mpi::CommRankOp>(
+ loc,
+ TypeRange{mpi::RetvalType::get(ctx), rewriter.getI32Type()},
+ commWorld)
+ .getRank();
rewriter.replaceOpWithNewOp<arith::IndexCastOp>(op, rewriter.getIndexType(),
rank);
return success();
@@ -652,6 +657,9 @@ struct ConvertUpdateHaloOp : public OpConversionPattern<UpdateHaloOp> {
auto upperSendOffset = rewriter.create<arith::SubIOp>(
loc, upperRecvOffset, toValue(haloSizes[currHaloDim * 2]));
+ Value commWorld = rewriter.create<mpi::CommWorldOp>(
+ loc, mpi::CommType::get(op->getContext()));
+
// Make sure we send/recv in a way that does not lead to a deadlock.
// The current approach is far from optimal; this should at least
// be a red-black pattern or use MPI_sendrecv.
@@ -680,7 +688,8 @@ struct ConvertUpdateHaloOp : public OpConversionPattern<UpdateHaloOp> {
auto subview = builder.create<memref::SubViewOp>(
loc, array, offsets, dimSizes, strides);
builder.create<memref::CopyOp>(loc, subview, buffer);
- builder.create<mpi::SendOp>(loc, TypeRange{}, buffer, tag, to);
+ builder.create<mpi::SendOp>(loc, TypeRange{}, buffer, tag, to,
+ commWorld);
builder.create<scf::YieldOp>(loc);
});
// if has neighbor: receive halo data into buffer and copy to array
@@ -688,7 +697,8 @@ struct ConvertUpdateHaloOp : public OpConversionPattern<UpdateHaloOp> {
loc, hasFrom, [&](OpBuilder &builder, Location loc) {
offsets[dim] = upperHalo ? OpFoldResult(upperRecvOffset)
: OpFoldResult(lowerRecvOffset);
- builder.create<mpi::RecvOp>(loc, TypeRange{}, buffer, tag, from);
+ builder.create<mpi::RecvOp>(loc, TypeRange{}, buffer, tag, from,
+ commWorld);
auto subview = builder.create<memref::SubViewOp>(
loc, array, offsets, dimSizes, strides);
builder.create<memref::CopyOp>(loc, buffer, subview);
diff --git a/mlir/test/Conversion/MPIToLLVM/mpitollvm.mlir b/mlir/test/Conversion/MPIToLLVM/mpitollvm.mlir
index b630ce3a23f30..c70077d507ad1 100644
--- a/mlir/test/Conversion/MPIToLLVM/mpitollvm.mlir
+++ b/mlir/test/Conversion/MPIToLLVM/mpitollvm.mlir
@@ -22,11 +22,12 @@ module attributes {dlti.map = #dlti.map<"MPI:Implementation" = "MPICH">} {
// CHECK: [[v7:%.*]] = llvm.call @MPI_Init([[v6]], [[v6]]) : (!llvm.ptr, !llvm.ptr) -> i32
%0 = mpi.init : !mpi.retval
+ %comm = mpi.comm_world : !mpi.comm
// CHECK: [[v8:%.*]] = llvm.mlir.constant(1140850688 : i32) : i32
// CHECK: [[v9:%.*]] = llvm.mlir.constant(1 : i32) : i32
// CHECK: [[v10:%.*]] = llvm.alloca [[v9]] x i32 : (i32) -> !llvm.ptr
// CHECK: [[v11:%.*]] = llvm.call @MPI_Comm_rank([[v8]], [[v10]]) : (i32, !llvm.ptr) -> i32
- %retval, %rank = mpi.comm_rank : !mpi.retval, i32
+ %retval, %rank = mpi.comm_rank(%comm) : !mpi.retval, i32
// CHECK: [[v12:%.*]] = llvm.load [[v10]] : !llvm.ptr -> i32
// CHECK: [[v13:%.*]] = llvm.extractvalue [[v5]][1] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
@@ -37,7 +38,7 @@ module attributes {dlti.map = #dlti.map<"MPI:Implementation" = "MPICH">} {
// CHECK: [[v18:%.*]] = llvm.mlir.constant(1275069450 : i32) : i32
// CHECK: [[v19:%.*]] = llvm.mlir.constant(1140850688 : i32) : i32
// CHECK: [[v20:%.*]] = llvm.call @MPI_Send([[v15]], [[v17]], [[v18]], [[v12]], [[v12]], [[v19]]) : (!llvm.ptr, i32, i32, i32, i32, i32) -> i32
- mpi.send(%arg0, %rank, %rank) : memref<100xf32>, i32, i32
+ mpi.send(%arg0, %rank, %rank, %comm) : memref<100xf32>, i32, i32
// CHECK: [[v21:%.*]] = llvm.extractvalue [[v5]][1] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
// CHECK: [[v22:%.*]] = llvm.extractvalue [[v5]][2] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
@@ -47,7 +48,7 @@ module attributes {dlti.map = #dlti.map<"MPI:Implementation" = "MPICH">} {
// CHECK: [[v26:%.*]] = llvm.mlir.constant(1275069450 : i32) : i32
// CHECK: [[v27:%.*]] = llvm.mlir.constant(1140850688 : i32) : i32
// CHECK: [[v28:%.*]] = llvm.call @MPI_Send([[v23]], [[v25]], [[v26]], [[v12]], [[v12]], [[v27]]) : (!llvm.ptr, i32, i32, i32, i32, i32) -> i32
- %1 = mpi.send(%arg0, %rank, %rank) : memref<100xf32>, i32, i32 -> !mpi.retval
+ %1 = mpi.send(%arg0, %rank, %rank, %comm) : memref<100xf32>, i32, i32 -> !mpi.retval
// CHECK: [[v29:%.*]] = llvm.extractvalue [[v5]][1] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
// CHECK: [[v30:%.*]] = llvm.extractvalue [[v5]][2] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
@@ -59,7 +60,7 @@ module attributes {dlti.map = #dlti.map<"MPI:Implementation" = "MPICH">} {
// CHECK: [[v36:%.*]] = llvm.mlir.constant(1 : i64) : i64
// CHECK: [[v37:%.*]] = llvm.inttoptr [[v36]] : i64 to !llvm.ptr
// CHECK: [[v38:%.*]] = llvm.call @MPI_Recv([[v31]], [[v33]], [[v34]], [[v12]], [[v12]], [[v35]], [[v37]]) : (!llvm.ptr, i32, i32, i32, i32, i32, !llvm.ptr) -> i32
- mpi.recv(%arg0, %rank, %rank) : memref<100xf32>, i32, i32
+ mpi.recv(%arg0, %rank, %rank, %comm) : memref<100xf32>, i32, i32
// CHECK: [[v39:%.*]] = llvm.extractvalue [[v5]][1] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
// CHECK: [[v40:%.*]] = llvm.extractvalue [[v5]][2] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
@@ -71,7 +72,7 @@ module attributes {dlti.map = #dlti.map<"MPI:Implementation" = "MPICH">} {
// CHECK: [[v46:%.*]] = llvm.mlir.constant(1 : i64) : i64
// CHECK: [[v47:%.*]] = llvm.inttoptr [[v46]] : i64 to !llvm.ptr
// CHECK: [[v48:%.*]] = llvm.call @MPI_Recv([[v41]], [[v43]], [[v44]], [[v12]], [[v12]], [[v45]], [[v47]]) : (!llvm.ptr, i32, i32, i32, i32, i32, !llvm.ptr) -> i32
- %2 = mpi.recv(%arg0, %rank, %rank) : memref<100xf32>, i32, i32 -> !mpi.retval
+ %2 = mpi.recv(%arg0, %rank, %rank, %comm) : memref<100xf32>, i32, i32 -> !mpi.retval
// CHECK: [[v49:%.*]] = llvm.extractvalue [[v5]][1] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
// CHECK: [[v50:%.*]] = llvm.extractvalue [[v5]][2] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
@@ -122,11 +123,12 @@ module attributes { dlti.map = #dlti.map<"MPI:Implementation" = "OpenMPI"> } {
// CHECK: [[v7:%.*]] = llvm.call @MPI_Init([[v6]], [[v6]]) : (!llvm.ptr, !llvm.ptr) -> i32
%0 = mpi.init : !mpi.retval
+ %comm = mpi.comm_world : !mpi.comm
// CHECK: [[v8:%.*]] = llvm.mlir.addressof @ompi_mpi_comm_world : !llvm.ptr
// CHECK: [[v9:%.*]] = llvm.mlir.constant(1 : i32) : i32
// CHECK: [[v10:%.*]] = llvm.alloca [[v9]] x i32 : (i32) -> !llvm.ptr
// CHECK: [[v11:%.*]] = llvm.call @MPI_Comm_rank([[v8]], [[v10]]) : (!llvm.ptr, !llvm.ptr) -> i32
- %retval, %rank = mpi.comm_rank : !mpi.retval, i32
+ %retval, %rank = mpi.comm_rank(%comm) : !mpi.retval, i32
// CHECK: [[v12:%.*]] = llvm.load [[v10]] : !llvm.ptr -> i32
// CHECK: [[v13:%.*]] = llvm.extractvalue [[v5]][1] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
@@ -137,7 +139,7 @@ module attributes { dlti.map = #dlti.map<"MPI:Implementation" = "OpenMPI"> } {
// CHECK: [[v18:%.*]] = llvm.mlir.addressof @ompi_mpi_float : !llvm.ptr
// CHECK: [[v19:%.*]] = llvm.mlir.addressof @ompi_mpi_comm_world : !llvm.ptr
// CHECK: [[v20:%.*]] = llvm.call @MPI_Send([[v15]], [[v17]], [[v18]], [[v12]], [[v12]], [[v19]]) : (!llvm.ptr, i32, !llvm.ptr, i32, i32, !llvm.ptr) -> i32
- mpi.send(%arg0, %rank, %rank) : memref<100xf32>, i32, i32
+ mpi.send(%arg0, %rank, %rank, %comm) : memref<100xf32>, i32, i32
// CHECK: [[v21:%.*]] = llvm.extractvalue [[v5]][1] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
// CHECK: [[v22:%.*]] = llvm.extractvalue [[v5]][2] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
@@ -147,7 +149,7 @@ module attributes { dlti.map = #dlti.map<"MPI:Implementation" = "OpenMPI"> } {
// CHECK: [[v26:%.*]] = llvm.mlir.addressof @ompi_mpi_float : !llvm.ptr
// CHECK: [[v27:%.*]] = llvm.mlir.addressof @ompi_mpi_comm_world : !llvm.ptr
// CHECK: [[v28:%.*]] = llvm.call @MPI_Send([[v23]], [[v25]], [[v26]], [[v12]], [[v12]], [[v27]]) : (!llvm.ptr, i32, !llvm.ptr, i32, i32, !llvm.ptr) -> i32
- %1 = mpi.send(%arg0, %rank, %rank) : memref<100xf32>, i32, i32 -> !mpi.retval
+ %1 = mpi.send(%arg0, %rank, %rank, %comm) : memref<100xf32>, i32, i32 -> !mpi.retval
// CHECK: [[v29:%.*]] = llvm.extractvalue [[v5]][1] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
// CHECK: [[v30:%.*]] = llvm.extractvalue [[v5]][2] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
@@ -159,7 +161,7 @@ module attributes { dlti.map = #dlti.map<"MPI:Implementation" = "OpenMPI"> } {
// CHECK: [[v36:%.*]] = llvm.mlir.constant(0 : i64) : i64
// CHECK: [[v37:%.*]] = llvm.inttoptr [[v36]] : i64 to !llvm.ptr
// CHECK: [[v38:%.*]] = llvm.call @MPI_Recv([[v31]], [[v33]], [[v34]], [[v12]], [[v12]], [[v35]], [[v37]]) : (!llvm.ptr, i32, !llvm.ptr, i32, i32, !llvm.ptr, !llvm.ptr) -> i32
- mpi.recv(%arg0, %rank, %rank) : memref<100xf32>, i32, i32
+ mpi.recv(%arg0, %rank, %rank, %comm) : memref<100xf32>, i32, i32
// CHECK: [[v39:%.*]] = llvm.extractvalue [[v5]][1] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
// CHECK: [[v40:%.*]] = llvm.extractvalue [[v5]][2] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
@@ -171,7 +173,7 @@ module attributes { dlti.map = #dlti.map<"MPI:Implementation" = "OpenMPI"> } {
// CHECK: [[v46:%.*]] = llvm.mlir.constant(0 : i64) : i64
// CHECK: [[v47:%.*]] = llvm.inttoptr [[v46]] : i64 to !llvm.ptr
// CHECK: [[v48:%.*]] = llvm.call @MPI_Recv([[v41]], [[v43]], [[v44]], [[v12]], [[v12]], [[v45]], [[v47]]) : (!llvm.ptr, i32, !llvm.ptr, i32, i32, !llvm.ptr, !llvm.ptr) -> i32
- %2 = mpi.recv(%arg0, %rank, %rank) : memref<100xf32>, i32, i32 -> !mpi.retval
+ %2 = mpi.recv(%arg0, %rank, %rank, %comm) : memref<100xf32>, i32, i32 -> !mpi.retval
// CHECK: [[v49:%.*]] = llvm.extractvalue [[v5]][1] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
// CHECK: [[v50:%.*]] = llvm.extractvalue [[v5]][2] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
diff --git a/mlir/test/Conversion/MeshToMPI/convert-mesh-to-mpi.mlir b/mlir/test/Conversion/MeshToMPI/convert-mesh-to-mpi.mlir
index 4e60c6f0d4e44..23756bb66928d 100644
--- a/mlir/test/Conversion/MeshToMPI/convert-mesh-to-mpi.mlir
+++ b/mlir/test/Conversion/MeshToMPI/convert-mesh-to-mpi.mlir
@@ -4,7 +4,7 @@
// CHECK: mesh.mesh @mesh0
mesh.mesh @mesh0(shape = 3x4x5)
func.func @process_multi_index() -> (index, index, index) {
- // CHECK: mpi.comm_rank : !mpi.retval, i32
+ // CHECK: mpi.comm_rank
// CHECK-DAG: %[[v4:.*]] = arith.remsi
// CHECK-DAG: %[[v0:.*]] = arith.remsi
// CHECK-DAG: %[[v1:.*]] = arith.remsi
@@ -15,7 +15,7 @@ func.func @process_multi_index() -> (index, index, index) {
// CHECK-LABEL: func @process_linear_index
func.func @process_linear_index() -> index {
- // CHECK: %[[RES:.*]], %[[rank:.*]] = mpi.comm_rank : !mpi.retval, i32
+ // CHECK: %[[RES:.*]], %[[rank:.*]] = mpi.comm_rank
// CHECK: %[[cast:.*]] = arith.index_cast %[[rank]] : i32 to index
%0 = mesh.process_linear_index on @mesh0 : index
// CHECK: return %[[cast]] : index
@@ -113,17 +113,17 @@ module attributes { mpi.dlti = #dlti.map<"MPI:comm_world_rank" = 1> } {
// CHECK: [[vc91_i32:%.*]] = arith.constant 91 : i32
// CHECK-NEXT: [[vc0_i32:%.*]] = arith.constant 0 : i32
// CHECK-NEXT: [[vc2_i32:%.*]] = arith.constant 2 : i32
+ // CHECK-NEXT: [[v0:%.*]] = mpi.comm_world : !mpi.comm
// CHECK-NEXT: [[valloc:%.*]] = memref.alloc() : memref<2x120x120xi8>
- // CHECK-NEXT: [[vsubview:%.*]] = memref.subview [[varg0]][118, 0, 0] [2, 120, 120] [1, 1, 1] : memref<120x120x120xi8> to memref<2x120x120xi8
- // CHECK-NEXT: memref.copy [[vsubview]], [[valloc]] : memref<2x120x120xi8
- // CHECK-SAME: to memref<2x120x120xi8>
- // CHECK-NEXT: mpi.send([[valloc]], [[vc91_i32]], [[vc2_i32]]) : memref<2x120x120xi8>, i32, i32
- // CHECK-NEXT: mpi.recv([[valloc]], [[vc91_i32]], [[vc0_i32]]) : memref<2x120x120xi8>, i32, i32
- // CHECK-NEXT: [[vsubview_0:%.*]] = memref.subview [[varg0]][0, 0, 0] [2, 120, 120] [1, 1, 1] : memref<120x120x120xi8> to memref<2x120x120xi8
- // CHECK-NEXT: memref.copy [[valloc]], [[vsubview_0]] : memref<2x120x120xi8> to memref<2x120x120xi8
+ // CHECK-NEXT: [[vsubview:%.*]] = memref.subview [[varg0]][118, 0, 0] [2, 120, 120] [1, 1, 1] : memref<120x120x120xi8> to memref<2x120x120xi8, strided<[14400, 120, 1], offset: 1699200>>
+ // CHECK-NEXT: memref.copy [[vsubview]], [[valloc]] : memref<2x120x120xi8, strided<[14400, 120, 1], offset: 1699200>> to memref<2x120x120xi8>
+ // CHECK-NEXT: mpi.send([[valloc]], [[vc91_i32]], [[vc2_i32]], [[v0]]) : memref<2x120x120xi8>, i32, i32
+ // CHECK-NEXT: mpi.recv([[valloc]], [[vc91_i32]], [[vc0_i32]], [[v0]]) : memref<2x120x120xi8>, i32, i32
+ // CHECK-NEXT: [[vsubview_0:%.*]] = memref.subview [[varg0]][0, 0, 0] [2, 120, 120] [1, 1, 1] : memref<120x120x120xi8> to memref<2x120x120xi8, strided<[14400, 120, 1]>>
+ // CHECK-NEXT: memref.copy [[valloc]], [[vsubview_0]] : memref<2x120x120xi8> to memref<2x120x120xi8, strided<[14400, 120, 1]>>
// CHECK-NEXT: memref.dealloc [[valloc]] : memref<2x120x120xi8>
%res = mesh.update_halo %arg0 on @mesh0 split_axes = [[0]] halo_sizes = [2, 0] : memref<120x120x120xi8>
- // CHECK: return [[res:%.*]] : memref<120x120x120xi8>
+ // CHECK: return [[varg0]] : memref<120x120x120xi8>
return %res : memref<120x120x120xi8>
}
}
@@ -140,41 +140,44 @@ module attributes { mpi.dlti = #dlti.map<"MPI:comm_world_rank" = 24> } {
// CHECK-NEXT: [[vc91_i32:%.*]] = arith.constant 91 : i32
// CHECK-NEXT: [[vc4_i32:%.*]] = arith.constant 4 : i32
// CHECK-NEXT: [[vc44_i32:%.*]] = arith.constant 44 : i32
+ // CHECK-NEXT: [[v0:%.*]] = mpi.comm_world : !mpi.comm
// CHECK-NEXT: [[valloc:%.*]] = memref.alloc() : memref<117x113x5xi8>
// CHECK-NEXT: [[vsubview:%.*]] = memref.subview [[varg0]][1, 3, 109] [117, 113, 5] [1, 1, 1] : memref<120x120x120xi8> to memref<117x113x5xi8, strided<[14400, 120, 1], offset: 14869>>
// CHECK-NEXT: memref.copy [[vsubview]], [[valloc]] : memref<117x113x5xi8, strided<[14400, 120, 1], offset: 14869>> to memref<117x113x5xi8>
- // CHECK-NEXT: mpi.send([[valloc]], [[vc91_i32]], [[vc44_i32]]) : memref<117x113x5xi8>, i32, i32
- // CHECK-NEXT: mpi.recv([[valloc]], [[vc91_i32]], [[vc4_i32]]) : memref<117x113x5xi8>, i32, i32
+ // CHECK-NEXT: mpi.send([[valloc]], [[vc91_i32]], [[vc44_i32]], [[v0]]) : memref<117x113x5xi8>, i32, i32
+ // CHECK-NEXT: mpi.recv([[valloc]], [[vc91_i32]], [[vc4_i32]], [[v0]]) : memref<117x113x5xi8>, i32, i32
// CHECK-NEXT: [[vsubview_0:%.*]] = memref.subview [[varg0]][1, 3, 0] [117, 113, 5] [1, 1, 1] : memref<120x120x120xi8> to memref<117x113x5xi8, strided<[14400, 120, 1], offset: 14760>>
// CHECK-NEXT: memref.copy [[valloc]], [[vsubview_0]] : memref<117x113x5xi8> to memref<117x113x5xi8, strided<[14400, 120, 1], offset: 14760>>
// CHECK-NEXT: memref.dealloc [[valloc]] : memref<117x113x5xi8>
// CHECK-NEXT: [[valloc_1:%.*]] = memref.alloc() : memref<117x113x6xi8>
// CHECK-NEXT: [[vsubview_2:%.*]] = memref.subview [[varg0]][1, 3, 5] [117, 113, 6] [1, 1, 1] : memref<120x120x120xi8> to memref<117x113x6xi8, strided<[14400, 120, 1], offset: 14765>>
// CHECK-NEXT: memref.copy [[vsubview_2]], [[valloc_1]] : memref<117x113x6xi8, strided<[14400, 120, 1], offset: 14765>> to memref<117x113x6xi8>
- // CHECK-NEXT: mpi.send([[valloc_1]], [[vc91_i32]], [[vc4_i32]]) : memref<117x113x6xi8>, i32, i32
- // CHECK-NEXT: mpi.recv([[valloc_1]], [[vc91_i32]], [[vc44_i32]]) : memref<117x113x6xi8>, i32, i32
+ // CHECK-NEXT: mpi.send([[valloc_1]], [[vc91_i32]], [[vc4_i32]], [[v0]]) : memref<117x113x6xi8>, i32, i32
+ // CHECK-NEXT: mpi.recv([[valloc_1]], [[vc91_i32]], [[vc44_i32]], [[v0]]) : memref<117x113x6xi8>, i32, i32
// CHECK-NEXT: [[vsubview_3:%.*]] = memref.subview [[varg0]][1, 3, 114] [117, 113, 6] [1, 1, 1] : memref<120x120x120xi8> to memref<117x113x6xi8, strided<[14400, 120, 1], offset: 14874>>
// CHECK-NEXT: memref.copy [[valloc_1]], [[vsubview_3]] : memref<117x113x6xi8> to memref<117x113x6xi8, strided<[14400, 120, 1], offset: 14874>>
// CHECK-NEXT: memref.dealloc [[valloc_1]] : memref<117x113x6xi8>
+ // CHECK-NEXT: [[v1:%.*]] = mpi.comm_world : !mpi.comm
// CHECK-NEXT: [[valloc_4:%.*]] = memref.alloc() : memref<117x3x120xi8>
// CHECK-NEXT: [[vsubview_5:%.*]] = memref.subview [[varg0]][1, 113, 0] [117, 3, 120] [1, 1, 1] : memref<120x120x120xi8> to memref<117x3x120xi8, strided<[14400, 120, 1], offset: 27960>>
// CHECK-NEXT: memref.copy [[vsubview_5]], [[valloc_4]] : memref<117x3x120xi8, strided<[14400, 120, 1], offset: 27960>> to memref<117x3x120xi8>
- // CHECK-NEXT: mpi.send([[valloc_4]], [[vc91_i32]], [[vc29_i32]]) : memref<117x3x120xi8>, i32, i32
+ // CHECK-NEXT: mpi.send([[valloc_4]], [[vc91_i32]], [[vc29_i32]], [[v1]]) : memref<117x3x120xi8>, i32, i32
// CHECK-NEXT: memref.dealloc [[valloc_4]] : memref<117x3x120xi8>
// CHECK-NEXT: [[valloc_6:%.*]] = memref.alloc() : memref<117x4x120xi8>
- // CHECK-NEXT: mpi.recv([[valloc_6]], [[vc91_i32]], [[vc29_i32]]) : memref<117x4x120xi8>, i32, i32
+ // CHECK-NEXT: mpi.recv([[valloc_6]], [[vc91_i32]], [[vc29_i32]], [[v1]]) : memref<117x4x120xi8>, i32, i32
// CHECK-NEXT: [[vsubview_7:%.*]] = memref.subview [[varg0]][1, 116, 0] [117, 4, 120] [1, 1, 1] : memref<120x120x120xi8> to memref<117x4x120xi8, strided<[14400, 120, 1], offset: 28320>>
// CHECK-NEXT: memref.copy [[valloc_6]], [[vsubview_7]] : memref<117x4x120xi8> to memref<117x4x120xi8, strided<[14400, 120, 1], offset: 28320>>
// CHECK-NEXT: memref.dealloc [[valloc_6]] : memref<117x4x120xi8>
+ // CHECK-NEXT: [[v2:%.*]] = mpi.comm_world : !mpi.comm
// CHECK-NEXT: [[valloc_8:%.*]] = memref.alloc() : memref<1x120x120xi8>
- // CHECK-NEXT: mpi.recv([[valloc_8]], [[vc91_i32]], [[vc23_i32]]) : memref<1x120x120xi8>, i32, i32
+ // CHECK-NEXT: mpi.recv([[valloc_8]], [[vc91_i32]], [[vc23_i32]], [[v2]]) : memref<1x120x120xi8>, i32, i32
// CHECK-NEXT: [[vsubview_9:%.*]] = memref.subview [[varg0]][0, 0, 0] [1, 120, 120] [1, 1, 1] : memref<120x120x120xi8> to memref<1x120x120xi8, strided<[14400, 120, 1]>>
// CHECK-NEXT: memref.copy [[valloc_8]], [[vsubview_9]] : memref<1x120x120xi8> to memref<1x120x120xi8, strided<[14400, 120, 1]>>
// CHECK-NEXT: memref.dealloc [[valloc_8]] : memref<1x120x120xi8>
// CHECK-NEXT: [[valloc_10:%.*]] = memref.alloc() : memref<2x120x120xi8>
// CHECK-NEXT: [[vsubview_11:%.*]] = memref.subview [[varg0]][1, 0, 0] [2, 120, 120] [1, 1, 1] : memref<120x120x120xi8> to memref<2x120x120xi8, strided<[14400, 120, 1], offset: 14400>>
// CHECK-NEXT: memref.copy [[vsubview_11]], [[valloc_10]] : memref<2x120x120xi8, strided<[14400, 120, 1], offset: 14400>> to memref<2x120x120xi8>
- // CHECK-NEXT: mpi.send([[valloc_10]], [[vc91_i32]], [[vc23_i32]]) : memref<2x120x120xi8>, i32, i32
+ // CHECK-NEXT: mpi.send([[valloc_10]], [[vc91_i32]], [[vc23_i32]], [[v2]]) : memref<2x120x120xi8>, i32, i32
// CHECK-NEXT: memref.dealloc [[valloc_10]] : memref<2x120x120xi8>
%res = mesh.update_halo %arg0 on @mesh0 split_axes = [[2], [1], [0]] halo_sizes = [1, 2, 3, 4, 5, 6] : memref<120x120x120xi8>
// CHECK: return [[varg0]] : memref<120x120x120xi8>
@@ -191,45 +194,48 @@ module attributes { mpi.dlti = #dlti.map<"MPI:comm_world_rank" = 24> } {
// CHECK-NEXT: [[vc4_i32:%.*]] = arith.constant 4 : i32
// CHECK-NEXT: [[vc91_i32:%.*]] = arith.constant 91 : i32
// CHECK-NEXT: [[v0:%.*]] = bufferization.to_memref [[varg0]] : tensor<120x120x120xi8> to memref<120x120x120xi8>
+ // CHECK-NEXT: [[v1:%.*]] = mpi.comm_world : !mpi.comm
// CHECK-NEXT: [[valloc:%.*]] = memref.alloc() : memref<117x113x5xi8>
// CHECK-NEXT: [[vsubview:%.*]] = memref.subview [[v0]][1, 3, 109] [117, 113, 5] [1, 1, 1] : memref<120x120x120xi8> to memref<117x113x5xi8, strided<[14400, 120, 1], offset: 14869>>
// CHECK-NEXT: memref.copy [[vsubview]], [[valloc]] : memref<117x113x5xi8, strided<[14400, 120, 1], offset: 14869>> to memref<117x113x5xi8>
- // CHECK-NEXT: mpi.send([[valloc]], [[vc91_i32]], [[vc44_i32]]) : memref<117x113x5xi8>, i32, i32
- // CHECK-NEXT: mpi.recv([[valloc]], [[vc91_i32]], [[vc4_i32]]) : memref<117x113x5xi8>, i32, i32
+ // CHECK-NEXT: mpi.send([[valloc]], [[vc91_i32]], [[vc44_i32]], [[v1]]) : memref<117x113x5xi8>, i32, i32
+ // CHECK-NEXT: mpi.recv([[valloc]], [[vc91_i32]], [[vc4_i32]], [[v1]]) : memref<117x113x5xi8>, i32, i32
// CHECK-NEXT: [[vsubview_0:%.*]] = memref.subview [[v0]][1, 3, 0] [117, 113, 5] [1, 1, 1] : memref<120x120x120xi8> to memref<117x113x5xi8, strided<[14400, 120, 1], offset: 14760>>
// CHECK-NEXT: memref.copy [[valloc]], [[vsubview_0]] : memref<117x113x5xi8> to memref<117x113x5xi8, strided<[14400, 120, 1], offset: 14760>>
// CHECK-NEXT: memref.dealloc [[valloc]] : memref<117x113x5xi8>
// CHECK-NEXT: [[valloc_1:%.*]] = memref.alloc() : memref<117x113x6xi8>
// CHECK-NEXT: [[vsubview_2:%.*]] = memref.subview [[v0]][1, 3, 5] [117, 113, 6] [1, 1, 1] : memref<120x120x120xi8> to memref<117x113x6xi8, strided<[14400, 120, 1], offset: 14765>>
// CHECK-NEXT: memref.copy [[vsubview_2]], [[valloc_1]] : memref<117x113x6xi8, strided<[14400, 120, 1], offset: 14765>> to memref<117x113x6xi8>
- // CHECK-NEXT: mpi.send([[valloc_1]], [[vc91_i32]], [[vc4_i32]]) : memref<117x113x6xi8>, i32, i32
- // CHECK-NEXT: mpi.recv([[valloc_1]], [[vc91_i32]], [[vc44_i32]]) : memref<117x113x6xi8>, i32, i32
+ // CHECK-NEXT: mpi.send([[valloc_1]], [[vc91_i32]], [[vc4_i32]], [[v1]]) : memref<117x113x6xi8>, i32, i32
+ // CHECK-NEXT: mpi.recv([[valloc_1]], [[vc91_i32]], [[vc44_i32]], [[v1]]) : memref<117x113x6xi8>, i32, i32
// CHECK-NEXT: [[vsubview_3:%.*]] = memref.subview [[v0]][1, 3, 114] [117, 113, 6] [1, 1, 1] : memref<120x120x120xi8> to memref<117x113x6xi8, strided<[14400, 120, 1], offset: 14874>>
// CHECK-NEXT: memref.copy [[valloc_1]], [[vsubview_3]] : memref<117x113x6xi8> to memref<117x113x6xi8, strided<[14400, 120, 1], offset: 14874>>
// CHECK-NEXT: memref.dealloc [[valloc_1]] : memref<117x113x6xi8>
+ // CHECK-NEXT: [[v2:%.*]] = mpi.comm_world : !mpi.comm
// CHECK-NEXT: [[valloc_4:%.*]] = memref.alloc() : memref<117x3x120xi8>
// CHECK-NEXT: [[vsubview_5:%.*]] = memref.subview [[v0]][1, 113, 0] [117, 3, 120] [1, 1, 1] : memref<120x120x120xi8> to memref<117x3x120xi8, strided<[14400, 120, 1], offset: 27960>>
// CHECK-NEXT: memref.copy [[vsubview_5]], [[valloc_4]] : memref<117x3x120xi8, strided<[14400, 120, 1], offset: 27960>> to memref<117x3x120xi8>
- // CHECK-NEXT: mpi.send([[valloc_4]], [[vc91_i32]], [[vc29_i32]]) : memref<117x3x120xi8>, i32, i32
+ // CHECK-NEXT: mpi.send([[valloc_4]], [[vc91_i32]], [[vc29_i32]], [[v2]]) : memref<117x3x120xi8>, i32, i32
// CHECK-NEXT: memref.dealloc [[valloc_4]] : memref<117x3x120xi8>
// CHECK-NEXT: [[valloc_6:%.*]] = memref.alloc() : memref<117x4x120xi8>
- // CHECK-NEXT: mpi.recv([[valloc_6]], [[vc91_i32]], [[vc29_i32]]) : memref<117x4x120xi8>, i32, i32
+ // CHECK-NEXT: mpi.recv([[valloc_6]], [[vc91_i32]], [[vc29_i32]], [[v2]]) : memref<117x4x120xi8>, i32, i32
// CHECK-NEXT: [[vsubview_7:%.*]] = memref.subview [[v0]][1, 116, 0] [117, 4, 120] [1, 1, 1] : memref<120x120x120xi8> to memref<117x4x120xi8, strided<[14400, 120, 1], offset: 28320>>
// CHECK-NEXT: memref.copy [[valloc_6]], [[vsubview_7]] : memref<117x4x120xi8> to memref<117x4x120xi8, strided<[14400, 120, 1], offset: 28320>>
// CHECK-NEXT: memref.dealloc [[valloc_6]] : memref<117x4x120xi8>
+ // CHECK-NEXT: [[v3:%.*]] = mpi.comm_world : !mpi.comm
// CHECK-NEXT: [[valloc_8:%.*]] = memref.alloc() : memref<1x120x120xi8>
- // CHECK-NEXT: mpi.recv([[valloc_8]], [[vc91_i32]], [[vc23_i32]]) : memref<1x120x120xi8>, i32, i32
+ // CHECK-NEXT: mpi.recv([[valloc_8]], [[vc91_i32]], [[vc23_i32]], [[v3]]) : memref<1x120x120xi8>, i32, i32
// CHECK-NEXT: [[vsubview_9:%.*]] = memref.subview [[v0]][0, 0, 0] [1, 120, 120] [1, 1, 1] : memref<120x120x120xi8> to memref<1x120x120xi8, strided<[14400, 120, 1]>>
// CHECK-NEXT: memref.copy [[valloc_8]], [[vsubview_9]] : memref<1x120x120xi8> to memref<1x120x120xi8, strided<[14400, 120, 1]>>
// CHECK-NEXT: memref.dealloc [[valloc_8]] : memref<1x120x120xi8>
// CHECK-NEXT: [[valloc_10:%.*]] = memref.alloc() : memref<2x120x120xi8>
// CHECK-NEXT: [[vsubview_11:%.*]] = memref.subview [[v0]][1, 0, 0] [2, 120, 120] [1, 1, 1] : memref<120x120x120xi8> to memref<2x120x120xi8, strided<[14400, 120, 1], offset: 14400>>
// CHECK-NEXT: memref.copy [[vsubview_11]], [[valloc_10]] : memref<2x120x120xi8, strided<[14400, 120, 1], offset: 14400>> to memref<2x120x120xi8>
- // CHECK-NEXT: mpi.send([[valloc_10]], [[vc91_i32]], [[vc23_i32]]) : memref<2x120x120xi8>, i32, i32
+ // CHECK-NEXT: mpi.send([[valloc_10]], [[vc91_i32]], [[vc23_i32]], [[v3]]) : memref<2x120x120xi8>, i32, i32
// CHECK-NEXT: memref.dealloc [[valloc_10]] : memref<2x120x120xi8>
- // CHECK-NEXT: [[v1:%.*]] = bufferization.to_tensor [[v0]] restrict writable : memref<120x120x120xi8> to tensor<120x120x120xi8>
+ // CHECK-NEXT: [[v4:%.*]] = bufferization.to_tensor [[v0]] restrict writable : memref<120x120x120xi8> to tensor<120x120x120xi8>
%res = mesh.update_halo %arg0 on @mesh0 split_axes = [[2], [1], [0]] halo_sizes = [1, 2, 3, 4, 5, 6] : tensor<120x120x120xi8>
- // CHECK: return [[v1]] : tensor<120x120x120xi8>
+ // CHECK-NEXT: return [[v4]] : tensor<120x120x120xi8>
return %res : tensor<120x120x120xi8>
}
}
diff --git a/mlir/test/Dialect/MPI/mpiops.mlir b/mlir/test/Dialect/MPI/mpiops.mlir
index 1cf2f2de4cf45..265687270e671 100644
--- a/mlir/test/Dialect/MPI/mpiops.mlir
+++ b/mlir/test/Dialect/MPI/mpiops.mlir
@@ -1,114 +1,83 @@
// RUN: mlir-opt %s | mlir-opt | FileCheck %s
+// CHECK-LABEL: func.func @mpi_test(
+// CHECK-SAME: [[varg0:%.*]]: memref<100xf32>) {
func.func @mpi_test(%ref : memref<100xf32>) -> () {
// Note: the !mpi.retval result is optional on all operations except mpi.error_class
- // CHECK: %0 = mpi.init : !mpi.retval
+ // CHECK-NEXT: [[v0:%.*]] = mpi.init : !mpi.retval
%err = mpi.init : !mpi.retval
- // CHECK-NEXT: %comm = mpi.comm_world : !mpi.comm
+ // CHECK-NEXT: [[v1:%.*]] = mpi.comm_world : !mpi.comm
%comm = mpi.comm_world : !mpi.comm
- // CHECK-NEXT: %rank = mpi.comm_rank : i32
- %rank = mpi.comm_rank : i32
+ // CHECK-NEXT: [[vrank:%.*]] = mpi.comm_rank([[v1]]) : i32
+ %rank = mpi.comm_rank(%comm) : i32
- // CHECK-NEXT: %retval, %rank = mpi.comm_rank : !mpi.retval, i32
- %retval, %rank = mpi.comm_rank : !mpi.retval, i32
+ // CHECK-NEXT: [[vretval:%.*]], [[vrank_0:%.*]] = mpi.comm_rank([[v1]]) : !mpi.retval, i32
+ %retval, %rank_1 = mpi.comm_rank(%comm) : !mpi.retval, i32
- // CHECK-NEXT: %retval, %rank = mpi.comm_rank : !mpi.comm -> i32
- %rank = mpi.comm_rank(%comm) : !mpi.comm -> i32
+ // CHECK-NEXT: [[vsize:%.*]] = mpi.comm_size([[v1]]) : i32
+ %size = mpi.comm_size(%comm) : i32
- // CHECK-NEXT: %retval, %rank = mpi.comm_rank : !mpi.comm -> !mpi.retval, i32
- %retval, %rank = mpi.comm_rank(%comm) : !mpi.comm -> !mpi.retval, i32
+ // CHECK-NEXT: [[vretval_1:%.*]], [[vsize_2:%.*]] = mpi.comm_size([[v1]]) : !mpi.retval, i32
+ %retval_0, %size_1 = mpi.comm_size(%comm) : !mpi.retval, i32
- // CHECK-NEXT: %size = mpi.comm_size : i32
- %size = mpi.comm_size : i32
+ // CHECK-NEXT: [[vnewcomm:%.*]] = mpi.comm_split([[v1]], [[vrank]], [[vrank]]) : i32, i32 -> !mpi.comm
+ %new_comm = mpi.comm_split(%comm, %rank, %rank) : i32, i32 -> !mpi.comm
- // CHECK-NEXT: %retval_0, %size = mpi.comm_size : !mpi.retval, i32
- %retval_0, %size = mpi.comm_size : !mpi.retval, i32
+ // CHECK-NEXT: [[vretval_3:%.*]], [[vnewcomm_4:%.*]] = mpi.comm_split([[v1]], [[vrank]], [[vrank]]) : i32, i32 -> !mpi.retval, !mpi.comm
+ %retval_1, %new_comm_1 = mpi.comm_split(%comm, %rank, %rank) : i32, i32 -> !mpi.retval, !mpi.comm
- // CHECK-NEXT: %size = mpi.comm_size : !mpi.comm -> i32
- %size = mpi.comm_size(%comm) : !mpi.comm -> i32
+ // CHECK-NEXT: mpi.send([[varg0]], [[vrank]], [[vrank]], [[v1]]) : memref<100xf32>, i32, i32
+ mpi.send(%ref, %rank, %rank, %comm) : memref<100xf32>, i32, i32
- // CHECK-NEXT: %retval_0, %size = mpi.comm_size : !mpi.retval, i32
- %retval_0, %size = mpi.comm_size(%comm) : !mpi.comm -> !mpi.retval, i32
+ // CHECK-NEXT: [[v2:%.*]] = mpi.send([[varg0]], [[vrank]], [[vrank]], [[v1]]) : memref<100xf32>, i32, i32 -> !mpi.retval
+ %retval_2 = mpi.send(%ref, %rank, %rank, %comm) : memref<100xf32>, i32, i32 -> !mpi.retval
- // CHECK-NEXT: %new_comm = mpi.comm_split(%comm, %rank, %rank) : !mpi.comm, i32, i32 -> !mpi.comm
- %new_comm = mpi.comm_split(%comm, %rank, %rank) : !mpi.comm, i32, i32
+ // CHECK-NEXT: mpi.recv([[varg0]], [[vrank]], [[vrank]], [[v1]]) : memref<100xf32>, i32, i32
+ mpi.recv(%ref, %rank, %rank, %comm) : memref<100xf32>, i32, i32
- // CHECK-NEXT: %retval3, %new_comm = mpi.comm_split(%comm, %rank, %rank) : !mpi.comm, i32, i32 -> !mpi.retval, !mpi.comm
- %retval3, %new_comm = mpi.comm_split(%comm, %rank, %rank) : !mpi.comm, i32, i32 -> !mpi.retval, !mpi.comm
+ // CHECK-NEXT: [[v3:%.*]] = mpi.recv([[varg0]], [[vrank]], [[vrank]], [[v1]]) : memref<100xf32>, i32, i32 -> !mpi.retval
+ %retval_3 = mpi.recv(%ref, %rank, %rank, %comm) : memref<100xf32>, i32, i32 -> !mpi.retval
- // CHECK-NEXT: mpi.send(%arg0, %rank, %rank) : memref<100xf32>, i32, i32
- mpi.send(%ref, %rank, %rank) : memref<100xf32>, i32, i32
+ // CHECK-NEXT: [[vretval_5:%.*]], [[vreq:%.*]] = mpi.isend([[varg0]], [[vrank]], [[vrank]], [[v1]]) : memref<100xf32>, i32, i32 -> !mpi.retval, !mpi.request
+ %err4, %req2 = mpi.isend(%ref, %rank, %rank, %comm) : memref<100xf32>, i32, i32 -> !mpi.retval, !mpi.request
- // CHECK-NEXT: %1 = mpi.send(%arg0, %rank, %rank) : memref<100xf32>, i32, i32 -> !mpi.retval
- %err2 = mpi.send(%ref, %rank, %rank) : memref<100xf32>, i32, i32 -> !mpi.retval
+ // CHECK-NEXT: [[vreq_6:%.*]] = mpi.isend([[varg0]], [[vrank]], [[vrank]], [[v1]]) : memref<100xf32>, i32, i32 -> !mpi.request
+ %req1 = mpi.isend(%ref, %rank, %rank, %comm) : memref<100xf32>, i32, i32 -> !mpi.request
- // CHECK-NEXT: mpi.send(%arg0, %rank, %rank, %comm) : memref<100xf32>, i32, i32, !mpi.comm
- mpi.send(%ref, %rank, %rank, %comm) : memref<100xf32>, i32, i32, !mpi.comm
+ // CHECK-NEXT: [[vreq_7:%.*]] = mpi.irecv([[varg0]], [[vrank]], [[vrank]], [[v1]]) : memref<100xf32>, i32, i32 -> !mpi.request
+ %req3 = mpi.irecv(%ref, %rank, %rank, %comm) : memref<100xf32>, i32, i32 -> !mpi.request
- // CHECK-NEXT: mpi.recv(%arg0, %rank, %rank) : memref<100xf32>, i32, i32
- mpi.recv(%ref, %rank, %rank) : memref<100xf32>, i32, i32
+ // CHECK-NEXT: [[vretval_8:%.*]], [[vreq_9:%.*]] = mpi.irecv([[varg0]], [[vrank]], [[vrank]], [[v1]]) : memref<100xf32>, i32, i32 -> !mpi.retval, !mpi.request
+ %err5, %req4 = mpi.irecv(%ref, %rank, %rank, %comm) : memref<100xf32>, i32, i32 -> !mpi.retval, !mpi.request
- // CHECK-NEXT: %2 = mpi.recv(%arg0, %rank, %rank) : memref<100xf32>, i32, i32 -> !mpi.retval
- %err3 = mpi.recv(%ref, %rank, %rank) : memref<100xf32>, i32, i32 -> !mpi.retval
+ // CHECK-NEXT: mpi.wait([[vreq_9]]) : !mpi.request
+ mpi.wait(%req4) : !mpi.request
- // CHECK-NEXT: mpi.recv(%arg0, %rank, %rank, %comm) : memref<100xf32>, i32, i32, !mpi.comm
- mpi.recv(%ref, %rank, %rank, %comm) : memref<100xf32>, i32, i32, !mpi.comm
-
- // CHECK-NEXT: %req = mpi.isend(%arg0, %rank, %rank) : memref<100xf32>, i32, i32 -> !mpi.request
- %req = mpi.isend(%ref, %rank, %rank) : memref<100xf32>, i32, i32 -> !mpi.request
-
- // CHECK-NEXT: %retval_1, %req_2 = mpi.isend(%arg0, %rank, %rank) : memref<100xf32>, i32, i32 -> !mpi.retval, !mpi.request
- %err4, %req2 = mpi.isend(%ref, %rank, %rank) : memref<100xf32>, i32, i32 -> !mpi.retval, !mpi.request
-
- // CHECK-NEXT: %3 = mpi.isend(%arg0, %rank, %rank, %comm) : memref<100xf32>, i32, i32, mpi.comm -> !mpi.request
- %req1 = mpi.isend(%ref, %rank, %rank, %comm) : memref<100xf32>, i32, i32, !mpi.comm -> !mpi.request
-
- // CHECK-NEXT: %req_3 = mpi.irecv(%arg0, %rank, %rank) : memref<100xf32>, i32, i32 -> !mpi.request
- %req3 = mpi.irecv(%ref, %rank, %rank) : memref<100xf32>, i32, i32 -> !mpi.request
-
- // CHECK-NEXT: %retval_4, %req_5 = mpi.irecv(%arg0, %rank, %rank) : memref<100xf32>, i32, i32 -> !mpi.retval, !mpi.request
- %err5, %req4 = mpi.irecv(%ref, %rank, %rank) : memref<100xf32>, i32, i32 -> !mpi.retval, !mpi.request
-
- // CHECK-NEXT: %6 = mpi.irecv(%arg0, %rank, %rank, %comm) : memref<100xf32>, i32, i32, mpi.comm -> mpi.request
- %req3 = mpi.irecv(%ref, %rank, %rank, %comm) : memref<100xf32>, i32, i32, !mpi.comm -> !mpi.request
-
- // CHECK-NEXT: mpi.wait(%req) : !mpi.request
- mpi.wait(%req) : !mpi.request
-
- // CHECK-NEXT: %3 = mpi.wait(%req_2) : !mpi.request -> !mpi.retval
+ // CHECK-NEXT: [[v4:%.*]] = mpi.wait([[vreq]]) : !mpi.request -> !mpi.retval
%err6 = mpi.wait(%req2) : !mpi.request -> !mpi.retval
- // CHECK-NEXT: mpi.barrier
- mpi.barrier
-
- // CHECK-NEXT: %5 = mpi.barrier : !mpi.retval
- %err7 = mpi.barrier : !mpi.retval
-
- // CHECK-NEXT: mpi.barrier(%comm)
+ // CHECK-NEXT: mpi.barrier([[v1]])
mpi.barrier(%comm)
- // CHECK-NEXT: %5 = mpi.barrier : !mpi.retval
- %err7 = mpi.barrier : !mpi.retval
-
- // CHECK-NEXT: mpi.allreduce(%arg0, %arg0, <MPI_SUM>) : memref<100xf32>, memref<100xf32>
- mpi.allreduce(%ref, %ref, MPI_SUM) : memref<100xf32>, memref<100xf32>
+ // CHECK-NEXT: [[v5:%.*]] = mpi.barrier([[v1]]) -> !mpi.retval
+ %err7 = mpi.barrier(%comm) -> !mpi.retval
- // CHECK-NEXT: mpi.allreduce(%arg0, %arg0, MPI_SUM) : memref<100xf32>, memref<100xf32> -> !mpi.retval
- %err8 = mpi.allreduce(%ref, %ref, MPI_SUM) : memref<100xf32>, memref<100xf32> -> !mpi.retval
+ // CHECK-NEXT: [[v6:%.*]] = mpi.allreduce([[varg0]], [[varg0]], MPI_SUM, [[v1]]) : memref<100xf32>, memref<100xf32> -> !mpi.retval
+ %err8 = mpi.allreduce(%ref, %ref, MPI_SUM, %comm) : memref<100xf32>, memref<100xf32> -> !mpi.retval
- // CHECK-NEXT: mpi.allreduce(%arg0, %arg0, MPI_SUM, %comm) : memref<100xf32>, memref<100xf32>, !mpi.comm
- mpi.allreduce(%ref, %ref, MPI_SUM, %comm) : memref<100xf32>, memref<100xf32>, !mpi.comm
+ // CHECK-NEXT: mpi.allreduce([[varg0]], [[varg0]], MPI_SUM, [[v1]]) : memref<100xf32>, memref<100xf32>
+ mpi.allreduce(%ref, %ref, MPI_SUM, %comm) : memref<100xf32>, memref<100xf32>
- // CHECK-NEXT: %7 = mpi.finalize : !mpi.retval
+ // CHECK-NEXT: [[v7:%.*]] = mpi.finalize : !mpi.retval
%rval = mpi.finalize : !mpi.retval
- // CHECK-NEXT: %8 = mpi.retval_check %retval = <MPI_SUCCESS> : i1
+ // CHECK-NEXT: [[v8:%.*]] = mpi.retval_check [[vretval]] = <MPI_SUCCESS> : i1
%res = mpi.retval_check %retval = <MPI_SUCCESS> : i1
- // CHECK-NEXT: %9 = mpi.error_class %0 : !mpi.retval
+ // CHECK-NEXT: [[v9:%.*]] = mpi.error_class [[v0]] : !mpi.retval
%errclass = mpi.error_class %err : !mpi.retval
// CHECK-NEXT: return
>From c6760058826b7853e7a792fc14a576636aca8657 Mon Sep 17 00:00:00 2001
From: "Schlimbach, Frank" <frank.schlimbach at intel.com>
Date: Thu, 27 Mar 2025 18:07:10 +0100
Subject: [PATCH 08/16] lowering mpi.commworld and tests, works for MPICH, not
openmpi
---
mlir/lib/Conversion/MPIToLLVM/MPIToLLVM.cpp | 49 ++++++++++++++-----
mlir/test/Conversion/MPIToLLVM/mpitollvm.mlir | 12 ++---
2 files changed, 41 insertions(+), 20 deletions(-)
diff --git a/mlir/lib/Conversion/MPIToLLVM/MPIToLLVM.cpp b/mlir/lib/Conversion/MPIToLLVM/MPIToLLVM.cpp
index 4e0f59305a647..7767a953142eb 100644
--- a/mlir/lib/Conversion/MPIToLLVM/MPIToLLVM.cpp
+++ b/mlir/lib/Conversion/MPIToLLVM/MPIToLLVM.cpp
@@ -440,6 +440,26 @@ struct FinalizeOpLowering : public ConvertOpToLLVMPattern<mpi::FinalizeOp> {
}
};
+//===----------------------------------------------------------------------===//
+// CommWorldOpLowering
+//===----------------------------------------------------------------------===//
+
+struct CommWorldOpLowering : public ConvertOpToLLVMPattern<mpi::CommWorldOp> {
+ using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;
+
+ LogicalResult
+ matchAndRewrite(mpi::CommWorldOp op, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ // grab a reference to the global module op:
+ auto moduleOp = op->getParentOfType<ModuleOp>();
+ auto mpiTraits = MPIImplTraits::get(moduleOp);
+ // get MPI_COMM_WORLD
+ rewriter.replaceOp(op, mpiTraits->getCommWorld(op.getLoc(), rewriter));
+
+ return success();
+ }
+};
+
//===----------------------------------------------------------------------===//
// CommRankOpLowering
//===----------------------------------------------------------------------===//
@@ -462,12 +482,12 @@ struct CommRankOpLowering : public ConvertOpToLLVMPattern<mpi::CommRankOp> {
auto moduleOp = op->getParentOfType<ModuleOp>();
auto mpiTraits = MPIImplTraits::get(moduleOp);
- // get MPI_COMM_WORLD
- Value commWorld = mpiTraits->getCommWorld(loc, rewriter);
+ // get communicator
+ Value comm = adaptor.getComm();
// LLVM Function type representing `i32 MPI_Comm_rank(ptr, ptr)`
auto rankFuncType =
- LLVM::LLVMFunctionType::get(i32, {commWorld.getType(), ptrType});
+ LLVM::LLVMFunctionType::get(i32, {comm.getType(), ptrType});
// get or create function declaration:
LLVM::LLVMFuncOp initDecl = getOrDefineFunction(
moduleOp, loc, rewriter, "MPI_Comm_rank", rankFuncType);
@@ -476,7 +496,7 @@ struct CommRankOpLowering : public ConvertOpToLLVMPattern<mpi::CommRankOp> {
auto one = rewriter.create<LLVM::ConstantOp>(loc, i32, 1);
auto rankptr = rewriter.create<LLVM::AllocaOp>(loc, ptrType, i32, one);
auto callOp = rewriter.create<LLVM::CallOp>(
- loc, initDecl, ValueRange{commWorld, rankptr.getRes()});
+ loc, initDecl, ValueRange{comm, rankptr.getRes()});
// load the rank into a register
auto loadedRank =
@@ -523,12 +543,12 @@ struct SendOpLowering : public ConvertOpToLLVMPattern<mpi::SendOp> {
getRawPtrAndSize(loc, rewriter, adaptor.getRef(), elemType);
auto mpiTraits = MPIImplTraits::get(moduleOp);
Value dataType = mpiTraits->getDataType(loc, rewriter, elemType);
- Value commWorld = mpiTraits->getCommWorld(loc, rewriter);
+ Value comm = adaptor.getComm();
// LLVM Function type representing `i32 MPI_send(data, count, datatype, dst,
// tag, comm)`
auto funcType = LLVM::LLVMFunctionType::get(
- i32, {ptrType, i32, dataType.getType(), i32, i32, commWorld.getType()});
+ i32, {ptrType, i32, dataType.getType(), i32, i32, comm.getType()});
// get or create function declaration:
LLVM::LLVMFuncOp funcDecl =
getOrDefineFunction(moduleOp, loc, rewriter, "MPI_Send", funcType);
@@ -537,7 +557,7 @@ struct SendOpLowering : public ConvertOpToLLVMPattern<mpi::SendOp> {
auto funcCall = rewriter.create<LLVM::CallOp>(
loc, funcDecl,
ValueRange{dataPtr, size, dataType, adaptor.getDest(), adaptor.getTag(),
- commWorld});
+ comm});
if (op.getRetval())
rewriter.replaceOp(op, funcCall.getResult());
else
@@ -575,7 +595,7 @@ struct RecvOpLowering : public ConvertOpToLLVMPattern<mpi::RecvOp> {
getRawPtrAndSize(loc, rewriter, adaptor.getRef(), elemType);
auto mpiTraits = MPIImplTraits::get(moduleOp);
Value dataType = mpiTraits->getDataType(loc, rewriter, elemType);
- Value commWorld = mpiTraits->getCommWorld(loc, rewriter);
+ Value comm = adaptor.getComm();
Value statusIgnore = rewriter.create<LLVM::ConstantOp>(
loc, i64, mpiTraits->getStatusIgnore());
statusIgnore =
@@ -585,7 +605,7 @@ struct RecvOpLowering : public ConvertOpToLLVMPattern<mpi::RecvOp> {
// tag, comm)`
auto funcType =
LLVM::LLVMFunctionType::get(i32, {ptrType, i32, dataType.getType(), i32,
- i32, commWorld.getType(), ptrType});
+ i32, comm.getType(), ptrType});
// get or create function declaration:
LLVM::LLVMFuncOp funcDecl =
getOrDefineFunction(moduleOp, loc, rewriter, "MPI_Recv", funcType);
@@ -594,7 +614,7 @@ struct RecvOpLowering : public ConvertOpToLLVMPattern<mpi::RecvOp> {
auto funcCall = rewriter.create<LLVM::CallOp>(
loc, funcDecl,
ValueRange{dataPtr, size, dataType, adaptor.getSource(),
- adaptor.getTag(), commWorld, statusIgnore});
+ adaptor.getTag(), comm, statusIgnore});
if (op.getRetval())
rewriter.replaceOp(op, funcCall.getResult());
else
@@ -676,8 +696,13 @@ struct FuncToLLVMDialectInterface : public ConvertToLLVMPatternInterface {
void mpi::populateMPIToLLVMConversionPatterns(LLVMTypeConverter &converter,
RewritePatternSet &patterns) {
- patterns.add<CommRankOpLowering, FinalizeOpLowering, InitOpLowering,
- SendOpLowering, RecvOpLowering, AllReduceOpLowering>(converter);
+ // FIXME: Need tldi info to get mpi implementation to know the Communicator
+ // type
+ Type commType = IntegerType::get(&converter.getContext(), 32);
+ converter.addConversion([&](mpi::CommType type) { return commType; });
+ patterns.add<CommRankOpLowering, CommWorldOpLowering, FinalizeOpLowering,
+ InitOpLowering, SendOpLowering, RecvOpLowering,
+ AllReduceOpLowering>(converter);
}
void mpi::registerConvertMPIToLLVMInterface(DialectRegistry &registry) {
diff --git a/mlir/test/Conversion/MPIToLLVM/mpitollvm.mlir b/mlir/test/Conversion/MPIToLLVM/mpitollvm.mlir
index c70077d507ad1..c176c80143bc5 100644
--- a/mlir/test/Conversion/MPIToLLVM/mpitollvm.mlir
+++ b/mlir/test/Conversion/MPIToLLVM/mpitollvm.mlir
@@ -36,8 +36,7 @@ module attributes {dlti.map = #dlti.map<"MPI:Implementation" = "MPICH">} {
// CHECK: [[v16:%.*]] = llvm.extractvalue [[v5]][3, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
// CHECK: [[v17:%.*]] = llvm.trunc [[v16]] : i64 to i32
// CHECK: [[v18:%.*]] = llvm.mlir.constant(1275069450 : i32) : i32
- // CHECK: [[v19:%.*]] = llvm.mlir.constant(1140850688 : i32) : i32
- // CHECK: [[v20:%.*]] = llvm.call @MPI_Send([[v15]], [[v17]], [[v18]], [[v12]], [[v12]], [[v19]]) : (!llvm.ptr, i32, i32, i32, i32, i32) -> i32
+ // CHECK: [[v20:%.*]] = llvm.call @MPI_Send([[v15]], [[v17]], [[v18]], [[v12]], [[v12]], [[v8]]) : (!llvm.ptr, i32, i32, i32, i32, i32) -> i32
mpi.send(%arg0, %rank, %rank, %comm) : memref<100xf32>, i32, i32
// CHECK: [[v21:%.*]] = llvm.extractvalue [[v5]][1] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
@@ -46,8 +45,7 @@ module attributes {dlti.map = #dlti.map<"MPI:Implementation" = "MPICH">} {
// CHECK: [[v24:%.*]] = llvm.extractvalue [[v5]][3, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
// CHECK: [[v25:%.*]] = llvm.trunc [[v24]] : i64 to i32
// CHECK: [[v26:%.*]] = llvm.mlir.constant(1275069450 : i32) : i32
- // CHECK: [[v27:%.*]] = llvm.mlir.constant(1140850688 : i32) : i32
- // CHECK: [[v28:%.*]] = llvm.call @MPI_Send([[v23]], [[v25]], [[v26]], [[v12]], [[v12]], [[v27]]) : (!llvm.ptr, i32, i32, i32, i32, i32) -> i32
+ // CHECK: [[v28:%.*]] = llvm.call @MPI_Send([[v23]], [[v25]], [[v26]], [[v12]], [[v12]], [[v8]]) : (!llvm.ptr, i32, i32, i32, i32, i32) -> i32
%1 = mpi.send(%arg0, %rank, %rank, %comm) : memref<100xf32>, i32, i32 -> !mpi.retval
// CHECK: [[v29:%.*]] = llvm.extractvalue [[v5]][1] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
@@ -56,10 +54,9 @@ module attributes {dlti.map = #dlti.map<"MPI:Implementation" = "MPICH">} {
// CHECK: [[v32:%.*]] = llvm.extractvalue [[v5]][3, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
// CHECK: [[v33:%.*]] = llvm.trunc [[v32]] : i64 to i32
// CHECK: [[v34:%.*]] = llvm.mlir.constant(1275069450 : i32) : i32
- // CHECK: [[v35:%.*]] = llvm.mlir.constant(1140850688 : i32) : i32
// CHECK: [[v36:%.*]] = llvm.mlir.constant(1 : i64) : i64
// CHECK: [[v37:%.*]] = llvm.inttoptr [[v36]] : i64 to !llvm.ptr
- // CHECK: [[v38:%.*]] = llvm.call @MPI_Recv([[v31]], [[v33]], [[v34]], [[v12]], [[v12]], [[v35]], [[v37]]) : (!llvm.ptr, i32, i32, i32, i32, i32, !llvm.ptr) -> i32
+ // CHECK: [[v38:%.*]] = llvm.call @MPI_Recv([[v31]], [[v33]], [[v34]], [[v12]], [[v12]], [[v8]], [[v37]]) : (!llvm.ptr, i32, i32, i32, i32, i32, !llvm.ptr) -> i32
mpi.recv(%arg0, %rank, %rank, %comm) : memref<100xf32>, i32, i32
// CHECK: [[v39:%.*]] = llvm.extractvalue [[v5]][1] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
@@ -68,10 +65,9 @@ module attributes {dlti.map = #dlti.map<"MPI:Implementation" = "MPICH">} {
// CHECK: [[v42:%.*]] = llvm.extractvalue [[v5]][3, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
// CHECK: [[v43:%.*]] = llvm.trunc [[v42]] : i64 to i32
// CHECK: [[v44:%.*]] = llvm.mlir.constant(1275069450 : i32) : i32
- // CHECK: [[v45:%.*]] = llvm.mlir.constant(1140850688 : i32) : i32
// CHECK: [[v46:%.*]] = llvm.mlir.constant(1 : i64) : i64
// CHECK: [[v47:%.*]] = llvm.inttoptr [[v46]] : i64 to !llvm.ptr
- // CHECK: [[v48:%.*]] = llvm.call @MPI_Recv([[v41]], [[v43]], [[v44]], [[v12]], [[v12]], [[v45]], [[v47]]) : (!llvm.ptr, i32, i32, i32, i32, i32, !llvm.ptr) -> i32
+ // CHECK: [[v48:%.*]] = llvm.call @MPI_Recv([[v41]], [[v43]], [[v44]], [[v12]], [[v12]], [[v8]], [[v47]]) : (!llvm.ptr, i32, i32, i32, i32, i32, !llvm.ptr) -> i32
%2 = mpi.recv(%arg0, %rank, %rank, %comm) : memref<100xf32>, i32, i32 -> !mpi.retval
// CHECK: [[v49:%.*]] = llvm.extractvalue [[v5]][1] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
>From 32daa5c5919ecc6bdce3b54284147751848f941a Mon Sep 17 00:00:00 2001
From: Frank Schlimbach <frank.schlimbach at intel.com>
Date: Fri, 28 Mar 2025 09:09:52 +0100
Subject: [PATCH 09/16] Fixing oversights and comments (review)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
Co-authored-by: Sergio Sánchez Ramírez <15837247+mofeing at users.noreply.github.com>
---
mlir/include/mlir/Dialect/MPI/IR/MPIOps.td | 4 +---
1 file changed, 1 insertion(+), 3 deletions(-)
diff --git a/mlir/include/mlir/Dialect/MPI/IR/MPIOps.td b/mlir/include/mlir/Dialect/MPI/IR/MPIOps.td
index 6bc25054bf48a..3c38f82bca291 100644
--- a/mlir/include/mlir/Dialect/MPI/IR/MPIOps.td
+++ b/mlir/include/mlir/Dialect/MPI/IR/MPIOps.td
@@ -60,13 +60,11 @@ def MPI_CommRankOp : MPI_Op<"comm_rank", []> {
let summary = "Get the current rank, equivalent to "
"`MPI_Comm_rank(comm, &rank)`";
let description = [{
- If communicator is not specified, `MPI_COMM_WORLD` is used by default.
-
This operation can optionally return an `!mpi.retval` value that can be used
to check for errors.
}];
- let arguments = (ins Optional<MPI_Comm> : $comm);
+ let arguments = (ins MPI_Comm : $comm);
let results = (
outs Optional<MPI_Retval> : $retval,
>From 313df8923a6b6a8eb2ad29e1b95e1d81b254ea53 Mon Sep 17 00:00:00 2001
From: Frank Schlimbach <frank.schlimbach at intel.com>
Date: Fri, 28 Mar 2025 09:12:54 +0100
Subject: [PATCH 10/16] Apply suggestions from code review
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
Co-authored-by: Sergio Sánchez Ramírez <15837247+mofeing at users.noreply.github.com>
---
mlir/include/mlir/Dialect/MPI/IR/MPIOps.td | 15 +--------------
1 file changed, 1 insertion(+), 14 deletions(-)
diff --git a/mlir/include/mlir/Dialect/MPI/IR/MPIOps.td b/mlir/include/mlir/Dialect/MPI/IR/MPIOps.td
index 3c38f82bca291..67e51bfa197ad 100644
--- a/mlir/include/mlir/Dialect/MPI/IR/MPIOps.td
+++ b/mlir/include/mlir/Dialect/MPI/IR/MPIOps.td
@@ -82,8 +82,6 @@ def MPI_CommSizeOp : MPI_Op<"comm_size", []> {
let summary = "Get the size of the group associated to the communicator, "
"equivalent to `MPI_Comm_size(comm, &size)`";
let description = [{
- If communicator is not specified, `MPI_COMM_WORLD` is used by default.
-
This operation can optionally return an `!mpi.retval` value that can be used
to check for errors.
}];
@@ -139,8 +137,6 @@ def MPI_SendOp : MPI_Op<"send", []> {
`dest`. The `tag` value and communicator enables the library to determine
the matching of multiple sends and receives between the same ranks.
- If communicator is not specified, `MPI_COMM_WORLD` is used by default.
-
This operation can optionally return an `!mpi.retval` value that can be used
to check for errors.
}];
@@ -173,8 +169,6 @@ def MPI_ISendOp : MPI_Op<"isend", []> {
determine the matching of multiple sends and receives between the same
ranks.
- If communicator is not specified, `MPI_COMM_WORLD` is used by default.
-
This operation can optionally return an `!mpi.retval` value that can be used
to check for errors.
}];
@@ -202,7 +196,7 @@ def MPI_ISendOp : MPI_Op<"isend", []> {
//===----------------------------------------------------------------------===//
def MPI_RecvOp : MPI_Op<"recv", []> {
- let summary = "Equivalent to `MPI_Recv(ptr, size, dtype, dest, tag, "
+ let summary = "Equivalent to `MPI_Recv(ptr, size, dtype, source, tag, "
"comm, MPI_STATUS_IGNORE)`";
let description = [{
MPI_Recv performs a blocking receive of `size` elements of type `dtype`
@@ -210,7 +204,6 @@ def MPI_RecvOp : MPI_Op<"recv", []> {
determine the matching of multiple sends and receives between the same
ranks.
- If communicator is not specified, `MPI_COMM_WORLD` is used by default.
The MPI_Status is set to `MPI_STATUS_IGNORE`, as the status object
is not yet ported to MLIR.
@@ -245,8 +238,6 @@ def MPI_IRecvOp : MPI_Op<"irecv", []> {
determine the matching of multiple sends and receives between the same
ranks.
- If communicator is not specified, `MPI_COMM_WORLD` is used by default.
-
This operation can optionally return an `!mpi.retval` value that can be used
to check for errors.
}];
@@ -284,8 +275,6 @@ def MPI_AllReduceOp : MPI_Op<"allreduce", []> {
Currently only the `MPI_Op` predefined in the standard (e.g. `MPI_SUM`) are
supported.
- If communicator is not specified, `MPI_COMM_WORLD` is used by default.
-
This operation can optionally return an `!mpi.retval` value that can be used
to check for errors.
}];
@@ -314,8 +303,6 @@ def MPI_Barrier : MPI_Op<"barrier", []> {
MPI_Barrier blocks execution until all processes in the communicator have
reached this routine.
- If communicator is not specified, `MPI_COMM_WORLD` is used by default.
-
This operation can optionally return an `!mpi.retval` value that can be used
to check for errors.
}];
>From 42c99a95ecad4b89217c8f7c40152adbe0994447 Mon Sep 17 00:00:00 2001
From: "Schlimbach, Frank" <frank.schlimbach at intel.com>
Date: Fri, 28 Mar 2025 10:59:59 +0100
Subject: [PATCH 11/16] using i64 as intermediate type for \!mpi.comm and
appropriate casting
---
mlir/lib/Conversion/MPIToLLVM/MPIToLLVM.cpp | 40 ++++++++++++++-----
mlir/test/Conversion/MPIToLLVM/mpitollvm.mlir | 28 ++++++++-----
2 files changed, 49 insertions(+), 19 deletions(-)
diff --git a/mlir/lib/Conversion/MPIToLLVM/MPIToLLVM.cpp b/mlir/lib/Conversion/MPIToLLVM/MPIToLLVM.cpp
index 7767a953142eb..c54dcdbcf27a4 100644
--- a/mlir/lib/Conversion/MPIToLLVM/MPIToLLVM.cpp
+++ b/mlir/lib/Conversion/MPIToLLVM/MPIToLLVM.cpp
@@ -83,9 +83,17 @@ class MPIImplTraits {
ModuleOp &getModuleOp() { return moduleOp; }
/// Gets or creates MPI_COMM_WORLD as a Value.
+ /// Different MPI implementations have different types for communicator.
+ /// Using i64 as a portable, intermediate type.
+ /// Appropriate cast needs to take place before calling MPI functions.
virtual Value getCommWorld(const Location loc,
ConversionPatternRewriter &rewriter) = 0;
+ /// Type converter provides i64 type for communicator type.
+ /// Converts to native type, which might be ptr or int or whatever.
+ virtual Value castComm(const Location loc,
+ ConversionPatternRewriter &rewriter, Value comm) = 0;
+
/// Get the MPI_STATUS_IGNORE value (typically a pointer type).
virtual intptr_t getStatusIgnore() = 0;
@@ -139,10 +147,15 @@ class MPICHImplTraits : public MPIImplTraits {
Value getCommWorld(const Location loc,
ConversionPatternRewriter &rewriter) override {
static constexpr int MPI_COMM_WORLD = 0x44000000;
- return rewriter.create<LLVM::ConstantOp>(loc, rewriter.getI32Type(),
+ return rewriter.create<LLVM::ConstantOp>(loc, rewriter.getI64Type(),
MPI_COMM_WORLD);
}
+ Value castComm(const Location loc, ConversionPatternRewriter &rewriter,
+ Value comm) override {
+ return rewriter.create<LLVM::TruncOp>(loc, rewriter.getI32Type(), comm);
+ }
+
intptr_t getStatusIgnore() override { return 1; }
Value getDataType(const Location loc, ConversionPatternRewriter &rewriter,
@@ -256,9 +269,16 @@ class OMPIImplTraits : public MPIImplTraits {
getOrDefineExternalStruct(loc, rewriter, name, commStructT);
// get address of symbol
- return rewriter.create<LLVM::AddressOfOp>(
+ auto comm = rewriter.create<LLVM::AddressOfOp>(
loc, LLVM::LLVMPointerType::get(context),
SymbolRefAttr::get(context, name));
+ return rewriter.create<LLVM::PtrToIntOp>(loc, rewriter.getI64Type(), comm);
+ }
+
+ Value castComm(const Location loc, ConversionPatternRewriter &rewriter,
+ Value comm) override {
+ return rewriter.create<LLVM::IntToPtrOp>(
+ loc, LLVM::LLVMPointerType::get(rewriter.getContext()), comm);
}
intptr_t getStatusIgnore() override { return 0; }
@@ -483,7 +503,7 @@ struct CommRankOpLowering : public ConvertOpToLLVMPattern<mpi::CommRankOp> {
auto mpiTraits = MPIImplTraits::get(moduleOp);
// get communicator
- Value comm = adaptor.getComm();
+ Value comm = mpiTraits->castComm(loc, rewriter, adaptor.getComm());
// LLVM Function type representing `i32 MPI_Comm_rank(ptr, ptr)`
auto rankFuncType =
@@ -543,7 +563,7 @@ struct SendOpLowering : public ConvertOpToLLVMPattern<mpi::SendOp> {
getRawPtrAndSize(loc, rewriter, adaptor.getRef(), elemType);
auto mpiTraits = MPIImplTraits::get(moduleOp);
Value dataType = mpiTraits->getDataType(loc, rewriter, elemType);
- Value comm = adaptor.getComm();
+ Value comm = mpiTraits->castComm(loc, rewriter, adaptor.getComm());
// LLVM Function type representing `i32 MPI_send(data, count, datatype, dst,
// tag, comm)`
@@ -595,7 +615,7 @@ struct RecvOpLowering : public ConvertOpToLLVMPattern<mpi::RecvOp> {
getRawPtrAndSize(loc, rewriter, adaptor.getRef(), elemType);
auto mpiTraits = MPIImplTraits::get(moduleOp);
Value dataType = mpiTraits->getDataType(loc, rewriter, elemType);
- Value comm = adaptor.getComm();
+ Value comm = mpiTraits->castComm(loc, rewriter, adaptor.getComm());
Value statusIgnore = rewriter.create<LLVM::ConstantOp>(
loc, i64, mpiTraits->getStatusIgnore());
statusIgnore =
@@ -696,10 +716,12 @@ struct FuncToLLVMDialectInterface : public ConvertToLLVMPatternInterface {
void mpi::populateMPIToLLVMConversionPatterns(LLVMTypeConverter &converter,
RewritePatternSet &patterns) {
- // FIXME: Need tldi info to get mpi implementation to know the Communicator
- // type
- Type commType = IntegerType::get(&converter.getContext(), 32);
- converter.addConversion([&](mpi::CommType type) { return commType; });
+ // Using i64 as a portable, intermediate type for !mpi.comm.
+ // It would be nicer to somehow get the right type directly, but TLDI is not
+ // available here.
+ converter.addConversion([](mpi::CommType type) {
+ return IntegerType::get(type.getContext(), 64);
+ });
patterns.add<CommRankOpLowering, CommWorldOpLowering, FinalizeOpLowering,
InitOpLowering, SendOpLowering, RecvOpLowering,
AllReduceOpLowering>(converter);
diff --git a/mlir/test/Conversion/MPIToLLVM/mpitollvm.mlir b/mlir/test/Conversion/MPIToLLVM/mpitollvm.mlir
index c176c80143bc5..ca02e1b5c9f72 100644
--- a/mlir/test/Conversion/MPIToLLVM/mpitollvm.mlir
+++ b/mlir/test/Conversion/MPIToLLVM/mpitollvm.mlir
@@ -22,8 +22,10 @@ module attributes {dlti.map = #dlti.map<"MPI:Implementation" = "MPICH">} {
// CHECK: [[v7:%.*]] = llvm.call @MPI_Init([[v6]], [[v6]]) : (!llvm.ptr, !llvm.ptr) -> i32
%0 = mpi.init : !mpi.retval
+ // CHECK: [[comm:%.*]] = llvm.mlir.constant(1140850688 : i64) : i64
%comm = mpi.comm_world : !mpi.comm
- // CHECK: [[v8:%.*]] = llvm.mlir.constant(1140850688 : i32) : i32
+
+ // CHECK: [[v8:%.*]] = llvm.trunc [[comm]] : i64 to i32
// CHECK: [[v9:%.*]] = llvm.mlir.constant(1 : i32) : i32
// CHECK: [[v10:%.*]] = llvm.alloca [[v9]] x i32 : (i32) -> !llvm.ptr
// CHECK: [[v11:%.*]] = llvm.call @MPI_Comm_rank([[v8]], [[v10]]) : (i32, !llvm.ptr) -> i32
@@ -36,7 +38,8 @@ module attributes {dlti.map = #dlti.map<"MPI:Implementation" = "MPICH">} {
// CHECK: [[v16:%.*]] = llvm.extractvalue [[v5]][3, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
// CHECK: [[v17:%.*]] = llvm.trunc [[v16]] : i64 to i32
// CHECK: [[v18:%.*]] = llvm.mlir.constant(1275069450 : i32) : i32
- // CHECK: [[v20:%.*]] = llvm.call @MPI_Send([[v15]], [[v17]], [[v18]], [[v12]], [[v12]], [[v8]]) : (!llvm.ptr, i32, i32, i32, i32, i32) -> i32
+ // CHECK: [[comm_1:%.*]] = llvm.trunc [[comm]] : i64 to i32
+ // CHECK: [[v20:%.*]] = llvm.call @MPI_Send([[v15]], [[v17]], [[v18]], [[v12]], [[v12]], [[comm_1]]) : (!llvm.ptr, i32, i32, i32, i32, i32) -> i32
mpi.send(%arg0, %rank, %rank, %comm) : memref<100xf32>, i32, i32
// CHECK: [[v21:%.*]] = llvm.extractvalue [[v5]][1] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
@@ -45,7 +48,8 @@ module attributes {dlti.map = #dlti.map<"MPI:Implementation" = "MPICH">} {
// CHECK: [[v24:%.*]] = llvm.extractvalue [[v5]][3, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
// CHECK: [[v25:%.*]] = llvm.trunc [[v24]] : i64 to i32
// CHECK: [[v26:%.*]] = llvm.mlir.constant(1275069450 : i32) : i32
- // CHECK: [[v28:%.*]] = llvm.call @MPI_Send([[v23]], [[v25]], [[v26]], [[v12]], [[v12]], [[v8]]) : (!llvm.ptr, i32, i32, i32, i32, i32) -> i32
+ // CHECK: [[comm_2:%.*]] = llvm.trunc [[comm]] : i64 to i32
+ // CHECK: [[v28:%.*]] = llvm.call @MPI_Send([[v23]], [[v25]], [[v26]], [[v12]], [[v12]], [[comm_2]]) : (!llvm.ptr, i32, i32, i32, i32, i32) -> i32
%1 = mpi.send(%arg0, %rank, %rank, %comm) : memref<100xf32>, i32, i32 -> !mpi.retval
// CHECK: [[v29:%.*]] = llvm.extractvalue [[v5]][1] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
@@ -54,9 +58,10 @@ module attributes {dlti.map = #dlti.map<"MPI:Implementation" = "MPICH">} {
// CHECK: [[v32:%.*]] = llvm.extractvalue [[v5]][3, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
// CHECK: [[v33:%.*]] = llvm.trunc [[v32]] : i64 to i32
// CHECK: [[v34:%.*]] = llvm.mlir.constant(1275069450 : i32) : i32
+ // CHECK: [[comm_3:%.*]] = llvm.trunc [[comm]] : i64 to i32
// CHECK: [[v36:%.*]] = llvm.mlir.constant(1 : i64) : i64
// CHECK: [[v37:%.*]] = llvm.inttoptr [[v36]] : i64 to !llvm.ptr
- // CHECK: [[v38:%.*]] = llvm.call @MPI_Recv([[v31]], [[v33]], [[v34]], [[v12]], [[v12]], [[v8]], [[v37]]) : (!llvm.ptr, i32, i32, i32, i32, i32, !llvm.ptr) -> i32
+ // CHECK: [[v38:%.*]] = llvm.call @MPI_Recv([[v31]], [[v33]], [[v34]], [[v12]], [[v12]], [[comm_3]], [[v37]]) : (!llvm.ptr, i32, i32, i32, i32, i32, !llvm.ptr) -> i32
mpi.recv(%arg0, %rank, %rank, %comm) : memref<100xf32>, i32, i32
// CHECK: [[v39:%.*]] = llvm.extractvalue [[v5]][1] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
@@ -65,9 +70,10 @@ module attributes {dlti.map = #dlti.map<"MPI:Implementation" = "MPICH">} {
// CHECK: [[v42:%.*]] = llvm.extractvalue [[v5]][3, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
// CHECK: [[v43:%.*]] = llvm.trunc [[v42]] : i64 to i32
// CHECK: [[v44:%.*]] = llvm.mlir.constant(1275069450 : i32) : i32
+ // CHECK: [[comm_4:%.*]] = llvm.trunc [[comm]] : i64 to i32
// CHECK: [[v46:%.*]] = llvm.mlir.constant(1 : i64) : i64
// CHECK: [[v47:%.*]] = llvm.inttoptr [[v46]] : i64 to !llvm.ptr
- // CHECK: [[v48:%.*]] = llvm.call @MPI_Recv([[v41]], [[v43]], [[v44]], [[v12]], [[v12]], [[v8]], [[v47]]) : (!llvm.ptr, i32, i32, i32, i32, i32, !llvm.ptr) -> i32
+ // CHECK: [[v48:%.*]] = llvm.call @MPI_Recv([[v41]], [[v43]], [[v44]], [[v12]], [[v12]], [[comm_4]], [[v47]]) : (!llvm.ptr, i32, i32, i32, i32, i32, !llvm.ptr) -> i32
%2 = mpi.recv(%arg0, %rank, %rank, %comm) : memref<100xf32>, i32, i32 -> !mpi.retval
// CHECK: [[v49:%.*]] = llvm.extractvalue [[v5]][1] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
@@ -121,9 +127,11 @@ module attributes { dlti.map = #dlti.map<"MPI:Implementation" = "OpenMPI"> } {
%comm = mpi.comm_world : !mpi.comm
// CHECK: [[v8:%.*]] = llvm.mlir.addressof @ompi_mpi_comm_world : !llvm.ptr
+ // CHECK: [[comm:%.*]] = llvm.ptrtoint [[v8]] : !llvm.ptr to i64
+ // CHECK: [[comm_1:%.*]] = llvm.inttoptr [[comm]] : i64 to !llvm.ptr
// CHECK: [[v9:%.*]] = llvm.mlir.constant(1 : i32) : i32
// CHECK: [[v10:%.*]] = llvm.alloca [[v9]] x i32 : (i32) -> !llvm.ptr
- // CHECK: [[v11:%.*]] = llvm.call @MPI_Comm_rank([[v8]], [[v10]]) : (!llvm.ptr, !llvm.ptr) -> i32
+ // CHECK: [[v11:%.*]] = llvm.call @MPI_Comm_rank([[comm_1]], [[v10]]) : (!llvm.ptr, !llvm.ptr) -> i32
%retval, %rank = mpi.comm_rank(%comm) : !mpi.retval, i32
// CHECK: [[v12:%.*]] = llvm.load [[v10]] : !llvm.ptr -> i32
@@ -133,7 +141,7 @@ module attributes { dlti.map = #dlti.map<"MPI:Implementation" = "OpenMPI"> } {
// CHECK: [[v16:%.*]] = llvm.extractvalue [[v5]][3, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
// CHECK: [[v17:%.*]] = llvm.trunc [[v16]] : i64 to i32
// CHECK: [[v18:%.*]] = llvm.mlir.addressof @ompi_mpi_float : !llvm.ptr
- // CHECK: [[v19:%.*]] = llvm.mlir.addressof @ompi_mpi_comm_world : !llvm.ptr
+ // CHECK: [[v19:%.*]] = llvm.inttoptr [[comm]] : i64 to !llvm.ptr
// CHECK: [[v20:%.*]] = llvm.call @MPI_Send([[v15]], [[v17]], [[v18]], [[v12]], [[v12]], [[v19]]) : (!llvm.ptr, i32, !llvm.ptr, i32, i32, !llvm.ptr) -> i32
mpi.send(%arg0, %rank, %rank, %comm) : memref<100xf32>, i32, i32
@@ -143,7 +151,7 @@ module attributes { dlti.map = #dlti.map<"MPI:Implementation" = "OpenMPI"> } {
// CHECK: [[v24:%.*]] = llvm.extractvalue [[v5]][3, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
// CHECK: [[v25:%.*]] = llvm.trunc [[v24]] : i64 to i32
// CHECK: [[v26:%.*]] = llvm.mlir.addressof @ompi_mpi_float : !llvm.ptr
- // CHECK: [[v27:%.*]] = llvm.mlir.addressof @ompi_mpi_comm_world : !llvm.ptr
+ // CHECK: [[v27:%.*]] = llvm.inttoptr [[comm]] : i64 to !llvm.ptr
// CHECK: [[v28:%.*]] = llvm.call @MPI_Send([[v23]], [[v25]], [[v26]], [[v12]], [[v12]], [[v27]]) : (!llvm.ptr, i32, !llvm.ptr, i32, i32, !llvm.ptr) -> i32
%1 = mpi.send(%arg0, %rank, %rank, %comm) : memref<100xf32>, i32, i32 -> !mpi.retval
@@ -153,7 +161,7 @@ module attributes { dlti.map = #dlti.map<"MPI:Implementation" = "OpenMPI"> } {
// CHECK: [[v32:%.*]] = llvm.extractvalue [[v5]][3, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
// CHECK: [[v33:%.*]] = llvm.trunc [[v32]] : i64 to i32
// CHECK: [[v34:%.*]] = llvm.mlir.addressof @ompi_mpi_float : !llvm.ptr
- // CHECK: [[v35:%.*]] = llvm.mlir.addressof @ompi_mpi_comm_world : !llvm.ptr
+ // CHECK: [[v35:%.*]] = llvm.inttoptr [[comm]] : i64 to !llvm.ptr
// CHECK: [[v36:%.*]] = llvm.mlir.constant(0 : i64) : i64
// CHECK: [[v37:%.*]] = llvm.inttoptr [[v36]] : i64 to !llvm.ptr
// CHECK: [[v38:%.*]] = llvm.call @MPI_Recv([[v31]], [[v33]], [[v34]], [[v12]], [[v12]], [[v35]], [[v37]]) : (!llvm.ptr, i32, !llvm.ptr, i32, i32, !llvm.ptr, !llvm.ptr) -> i32
@@ -165,7 +173,7 @@ module attributes { dlti.map = #dlti.map<"MPI:Implementation" = "OpenMPI"> } {
// CHECK: [[v42:%.*]] = llvm.extractvalue [[v5]][3, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
// CHECK: [[v43:%.*]] = llvm.trunc [[v42]] : i64 to i32
// CHECK: [[v44:%.*]] = llvm.mlir.addressof @ompi_mpi_float : !llvm.ptr
- // CHECK: [[v45:%.*]] = llvm.mlir.addressof @ompi_mpi_comm_world : !llvm.ptr
+ // CHECK: [[v45:%.*]] = llvm.inttoptr [[comm]] : i64 to !llvm.ptr
// CHECK: [[v46:%.*]] = llvm.mlir.constant(0 : i64) : i64
// CHECK: [[v47:%.*]] = llvm.inttoptr [[v46]] : i64 to !llvm.ptr
// CHECK: [[v48:%.*]] = llvm.call @MPI_Recv([[v41]], [[v43]], [[v44]], [[v12]], [[v12]], [[v45]], [[v47]]) : (!llvm.ptr, i32, !llvm.ptr, i32, i32, !llvm.ptr, !llvm.ptr) -> i32
>From d023e655384ee82634ec46c8450fef3e75e6017c Mon Sep 17 00:00:00 2001
From: "Schlimbach, Frank" <frank.schlimbach at intel.com>
Date: Fri, 28 Mar 2025 11:44:30 +0100
Subject: [PATCH 12/16] lowering mpi.comm_split
---
mlir/include/mlir/Dialect/MPI/IR/MPIOps.td | 3 +-
mlir/lib/Conversion/MPIToLLVM/MPIToLLVM.cpp | 61 +++++++++++++++++--
mlir/test/Conversion/MPIToLLVM/mpitollvm.mlir | 24 ++++++++
mlir/test/Dialect/MPI/mpiops.mlir | 8 +--
4 files changed, 86 insertions(+), 10 deletions(-)
diff --git a/mlir/include/mlir/Dialect/MPI/IR/MPIOps.td b/mlir/include/mlir/Dialect/MPI/IR/MPIOps.td
index 67e51bfa197ad..58b957357bc65 100644
--- a/mlir/include/mlir/Dialect/MPI/IR/MPIOps.td
+++ b/mlir/include/mlir/Dialect/MPI/IR/MPIOps.td
@@ -100,7 +100,7 @@ def MPI_CommSizeOp : MPI_Op<"comm_size", []> {
// CommSplitOp
//===----------------------------------------------------------------------===//
-def MPI_CommSplit : MPI_Op<"comm_split", []> {
+def MPI_CommSplitOp : MPI_Op<"comm_split", []> {
let summary = "Partition the group associated to the given communicator into "
"disjoint subgroups";
let description = [{
@@ -121,7 +121,6 @@ def MPI_CommSplit : MPI_Op<"comm_split", []> {
);
let assemblyFormat = "`(` $comm `,` $color `,` $key `)` attr-dict `:` "
- "type($color) `,` type($key) `->` "
"type(results)";
}
diff --git a/mlir/lib/Conversion/MPIToLLVM/MPIToLLVM.cpp b/mlir/lib/Conversion/MPIToLLVM/MPIToLLVM.cpp
index c54dcdbcf27a4..a0db634eb4c80 100644
--- a/mlir/lib/Conversion/MPIToLLVM/MPIToLLVM.cpp
+++ b/mlir/lib/Conversion/MPIToLLVM/MPIToLLVM.cpp
@@ -480,6 +480,58 @@ struct CommWorldOpLowering : public ConvertOpToLLVMPattern<mpi::CommWorldOp> {
}
};
+//===----------------------------------------------------------------------===//
+// CommSplitOpLowering
+//===----------------------------------------------------------------------===//
+
+struct CommSplitOpLowering : public ConvertOpToLLVMPattern<mpi::CommSplitOp> {
+ using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;
+
+ LogicalResult
+ matchAndRewrite(mpi::CommSplitOp op, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ // grab a reference to the global module op:
+ auto moduleOp = op->getParentOfType<ModuleOp>();
+ auto mpiTraits = MPIImplTraits::get(moduleOp);
+ Type i32 = rewriter.getI32Type();
+ Type ptrType = LLVM::LLVMPointerType::get(op->getContext());
+ Location loc = op.getLoc();
+
+ // get communicator
+ Value comm = mpiTraits->castComm(loc, rewriter, adaptor.getComm());
+ auto one = rewriter.create<LLVM::ConstantOp>(loc, i32, 1);
+ auto outPtr =
+ rewriter.create<LLVM::AllocaOp>(loc, ptrType, comm.getType(), one);
+
+ // int MPI_Comm_split(MPI_Comm comm, int color, int key, MPI_Comm * newcomm)
+ auto funcType =
+ LLVM::LLVMFunctionType::get(i32, {comm.getType(), i32, i32, ptrType});
+ // get or create function declaration:
+ LLVM::LLVMFuncOp funcDecl = getOrDefineFunction(moduleOp, loc, rewriter,
+ "MPI_Comm_split", funcType);
+
+ auto callOp = rewriter.create<LLVM::CallOp>(
+ loc, funcDecl,
+ ValueRange{comm, adaptor.getColor(), adaptor.getKey(),
+ outPtr.getRes()});
+
+ // load the communicator into a register
+ auto res = rewriter.create<LLVM::LoadOp>(loc, i32, outPtr.getResult());
+
+ // if retval is checked, replace uses of retval with the results from the
+ // call op
+ SmallVector<Value> replacements;
+ if (op.getRetval())
+ replacements.push_back(callOp.getResult());
+
+ // replace op
+ replacements.push_back(res.getRes());
+ rewriter.replaceOp(op, replacements);
+
+ return success();
+ }
+};
+
//===----------------------------------------------------------------------===//
// CommRankOpLowering
//===----------------------------------------------------------------------===//
@@ -512,7 +564,7 @@ struct CommRankOpLowering : public ConvertOpToLLVMPattern<mpi::CommRankOp> {
LLVM::LLVMFuncOp initDecl = getOrDefineFunction(
moduleOp, loc, rewriter, "MPI_Comm_rank", rankFuncType);
- // replace init with function call
+ // replace with function call
auto one = rewriter.create<LLVM::ConstantOp>(loc, i32, 1);
auto rankptr = rewriter.create<LLVM::AllocaOp>(loc, ptrType, i32, one);
auto callOp = rewriter.create<LLVM::CallOp>(
@@ -722,9 +774,10 @@ void mpi::populateMPIToLLVMConversionPatterns(LLVMTypeConverter &converter,
converter.addConversion([](mpi::CommType type) {
return IntegerType::get(type.getContext(), 64);
});
- patterns.add<CommRankOpLowering, CommWorldOpLowering, FinalizeOpLowering,
- InitOpLowering, SendOpLowering, RecvOpLowering,
- AllReduceOpLowering>(converter);
+ patterns
+ .add<CommRankOpLowering, CommSplitOpLowering, CommWorldOpLowering,
+ FinalizeOpLowering, InitOpLowering, SendOpLowering, RecvOpLowering,
+ AllReduceOpLowering>(converter);
}
void mpi::registerConvertMPIToLLVMInterface(DialectRegistry &registry) {
diff --git a/mlir/test/Conversion/MPIToLLVM/mpitollvm.mlir b/mlir/test/Conversion/MPIToLLVM/mpitollvm.mlir
index ca02e1b5c9f72..18d18dc53458d 100644
--- a/mlir/test/Conversion/MPIToLLVM/mpitollvm.mlir
+++ b/mlir/test/Conversion/MPIToLLVM/mpitollvm.mlir
@@ -3,6 +3,7 @@
// COM: Test MPICH ABI
// CHECK: module attributes {dlti.map = #dlti.map<"MPI:Implementation" = "MPICH">} {
// CHECK: llvm.func @MPI_Finalize() -> i32
+// CHECK: llvm.func @MPI_Comm_split(i32, i32, i32, !llvm.ptr) -> i32
// CHECK: llvm.func @MPI_Recv(!llvm.ptr, i32, i32, i32, i32, i32, !llvm.ptr) -> i32
// CHECK: llvm.func @MPI_Send(!llvm.ptr, i32, i32, i32, i32, i32) -> i32
// CHECK: llvm.func @MPI_Comm_rank(i32, !llvm.ptr) -> i32
@@ -75,6 +76,17 @@ module attributes {dlti.map = #dlti.map<"MPI:Implementation" = "MPICH">} {
// CHECK: [[v47:%.*]] = llvm.inttoptr [[v46]] : i64 to !llvm.ptr
// CHECK: [[v48:%.*]] = llvm.call @MPI_Recv([[v41]], [[v43]], [[v44]], [[v12]], [[v12]], [[comm_4]], [[v47]]) : (!llvm.ptr, i32, i32, i32, i32, i32, !llvm.ptr) -> i32
%2 = mpi.recv(%arg0, %rank, %rank, %comm) : memref<100xf32>, i32, i32 -> !mpi.retval
+
+ // CHECK: [[v51:%.*]] = llvm.mlir.constant(10 : i32) : i32
+ %color = arith.constant 10 : i32
+ // CHECK: [[v52:%.*]] = llvm.mlir.constant(22 : i32) : i32
+ %key = arith.constant 22 : i32
+ // CHECK: [[v53:%.*]] = llvm.trunc [[comm]] : i64 to i32
+ // CHECK: [[v54:%.*]] = llvm.mlir.constant(1 : i32) : i32
+ // CHECK: [[v55:%.*]] = llvm.alloca [[v54]] x i32 : (i32) -> !llvm.ptr
+ // CHECK: [[v56:%.*]] = llvm.call @MPI_Comm_split([[v53]], [[v51]], [[v52]], [[v55]]) : (i32, i32, i32, !llvm.ptr) -> i32
+ // CHECK: [[v57:%.*]] = llvm.load [[v55]] : !llvm.ptr -> i32
+ %split = mpi.comm_split(%comm, %color, %key) : !mpi.comm
// CHECK: [[v49:%.*]] = llvm.extractvalue [[v5]][1] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
// CHECK: [[v50:%.*]] = llvm.extractvalue [[v5]][2] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
@@ -104,6 +116,7 @@ module attributes {dlti.map = #dlti.map<"MPI:Implementation" = "MPICH">} {
// COM: Test OpenMPI ABI
// CHECK: module attributes {dlti.map = #dlti.map<"MPI:Implementation" = "OpenMPI">} {
// CHECK: llvm.func @MPI_Finalize() -> i32
+// CHECK: llvm.func @MPI_Comm_split(!llvm.ptr, i32, i32, !llvm.ptr) -> i32
// CHECK: llvm.func @MPI_Recv(!llvm.ptr, i32, !llvm.ptr, i32, i32, !llvm.ptr, !llvm.ptr) -> i32
// CHECK: llvm.func @MPI_Send(!llvm.ptr, i32, !llvm.ptr, i32, i32, !llvm.ptr) -> i32
// CHECK: llvm.mlir.global external @ompi_mpi_float() {addr_space = 0 : i32} : !llvm.struct<"ompi_predefined_datatype_t", opaque>
@@ -195,6 +208,17 @@ module attributes { dlti.map = #dlti.map<"MPI:Implementation" = "OpenMPI"> } {
// CHECK: [[v62:%.*]] = llvm.call @MPI_Allreduce([[v51]], [[v56]], [[v53]], [[v59]], [[v60]], [[v61]]) : (!llvm.ptr, !llvm.ptr, i32, !llvm.ptr, !llvm.ptr, !llvm.ptr) -> i32
mpi.allreduce(%arg0, %arg0, MPI_SUM) : memref<100xf32>, memref<100xf32>
+ // CHECK: [[v51:%.*]] = llvm.mlir.constant(10 : i32) : i32
+ %color = arith.constant 10 : i32
+ // CHECK: [[v52:%.*]] = llvm.mlir.constant(22 : i32) : i32
+ %key = arith.constant 22 : i32
+ // CHECK: [[v53:%.*]] = llvm.inttoptr [[comm]] : i64 to !llvm.ptr
+ // CHECK: [[v54:%.*]] = llvm.mlir.constant(1 : i32) : i32
+ // CHECK: [[v55:%.*]] = llvm.alloca [[v54]] x !llvm.ptr : (i32) -> !llvm.ptr
+ // CHECK: [[v56:%.*]] = llvm.call @MPI_Comm_split([[v53]], [[v51]], [[v52]], [[v55]]) : (!llvm.ptr, i32, i32, !llvm.ptr) -> i32
+ // CHECK: [[v57:%.*]] = llvm.load [[v55]] : !llvm.ptr -> i32
+ %split = mpi.comm_split(%comm, %color, %key) : !mpi.comm
+
// CHECK: [[v49:%.*]] = llvm.call @MPI_Finalize() : () -> i32
%3 = mpi.finalize : !mpi.retval
diff --git a/mlir/test/Dialect/MPI/mpiops.mlir b/mlir/test/Dialect/MPI/mpiops.mlir
index 265687270e671..ef457628fe2c4 100644
--- a/mlir/test/Dialect/MPI/mpiops.mlir
+++ b/mlir/test/Dialect/MPI/mpiops.mlir
@@ -23,11 +23,11 @@ func.func @mpi_test(%ref : memref<100xf32>) -> () {
// CHECK-NEXT: [[vretval_1:%.*]], [[vsize_2:%.*]] = mpi.comm_size([[v1]]) : !mpi.retval, i32
%retval_0, %size_1 = mpi.comm_size(%comm) : !mpi.retval, i32
- // CHECK-NEXT: [[vnewcomm:%.*]] = mpi.comm_split([[v1]], [[vrank]], [[vrank]]) : i32, i32 -> !mpi.comm
- %new_comm = mpi.comm_split(%comm, %rank, %rank) : i32, i32 -> !mpi.comm
+ // CHECK-NEXT: [[vnewcomm:%.*]] = mpi.comm_split([[v1]], [[vrank]], [[vrank]]) : !mpi.comm
+ %new_comm = mpi.comm_split(%comm, %rank, %rank) : !mpi.comm
- // CHECK-NEXT: [[vretval_3:%.*]], [[vnewcomm_4:%.*]] = mpi.comm_split([[v1]], [[vrank]], [[vrank]]) : i32, i32 -> !mpi.retval, !mpi.comm
- %retval_1, %new_comm_1 = mpi.comm_split(%comm, %rank, %rank) : i32, i32 -> !mpi.retval, !mpi.comm
+ // CHECK-NEXT: [[vretval_3:%.*]], [[vnewcomm_4:%.*]] = mpi.comm_split([[v1]], [[vrank]], [[vrank]]) : !mpi.retval, !mpi.comm
+ %retval_1, %new_comm_1 = mpi.comm_split(%comm, %rank, %rank) : !mpi.retval, !mpi.comm
// CHECK-NEXT: mpi.send([[varg0]], [[vrank]], [[vrank]], [[v1]]) : memref<100xf32>, i32, i32
mpi.send(%ref, %rank, %rank, %comm) : memref<100xf32>, i32, i32
>From 9f23e95c390881fe1813860eb214efd0fa20a7d4 Mon Sep 17 00:00:00 2001
From: "Schlimbach, Frank" <frank.schlimbach at intel.com>
Date: Fri, 28 Mar 2025 12:00:17 +0100
Subject: [PATCH 13/16] cleanup
---
mlir/include/mlir/Dialect/MPI/IR/MPIOps.td | 7 -------
1 file changed, 7 deletions(-)
diff --git a/mlir/include/mlir/Dialect/MPI/IR/MPIOps.td b/mlir/include/mlir/Dialect/MPI/IR/MPIOps.td
index 58b957357bc65..7283988d6d9f3 100644
--- a/mlir/include/mlir/Dialect/MPI/IR/MPIOps.td
+++ b/mlir/include/mlir/Dialect/MPI/IR/MPIOps.td
@@ -310,13 +310,6 @@ def MPI_Barrier : MPI_Op<"barrier", []> {
let results = (outs Optional<MPI_Retval>:$retval);
- // TODO fix assembly format
- // let assemblyFormat = "("
- // "(attr-dict) ^"
- // "(attr-dict `:` type($retval)) ^"
- // "(`(` $comm `)` attr-dict `:` type($comm)) ^"
- // "(`(` $comm `)` attr-dict `:` type($comm) `->` type($retval))"
- // ")?";
let assemblyFormat = [{
`(` $comm `)` attr-dict
(`->` type($retval)^)?
>From 1172687410fc8d6b1323815ab1e53eace3ca7c4a Mon Sep 17 00:00:00 2001
From: "Schlimbach, Frank" <frank.schlimbach at intel.com>
Date: Mon, 31 Mar 2025 13:32:44 +0200
Subject: [PATCH 14/16] merge conflicts, fixing mpi.all_reduce
---
mlir/include/mlir/Dialect/MPI/IR/MPIOps.td | 18 +++++++++---------
mlir/lib/Conversion/MPIToLLVM/MPIToLLVM.cpp | 10 +++++-----
mlir/test/Conversion/MPIToLLVM/mpitollvm.mlir | 8 ++++----
3 files changed, 18 insertions(+), 18 deletions(-)
diff --git a/mlir/include/mlir/Dialect/MPI/IR/MPIOps.td b/mlir/include/mlir/Dialect/MPI/IR/MPIOps.td
index 7283988d6d9f3..d549f5e575c5a 100644
--- a/mlir/include/mlir/Dialect/MPI/IR/MPIOps.td
+++ b/mlir/include/mlir/Dialect/MPI/IR/MPIOps.td
@@ -175,7 +175,7 @@ def MPI_ISendOp : MPI_Op<"isend", []> {
let arguments = (
ins AnyMemRef : $ref,
I32 : $tag,
- I32 : $rank,
+ I32 : $dest,
MPI_Comm : $comm
);
@@ -184,8 +184,8 @@ def MPI_ISendOp : MPI_Op<"isend", []> {
MPI_Request : $req
);
- let assemblyFormat = "`(` $ref `,` $tag `,` $rank `,` $comm`)` attr-dict "
- "`:` type($ref) `,` type($tag) `,` type($rank) "
+ let assemblyFormat = "`(` $ref `,` $tag `,` $dest `,` $comm`)` attr-dict "
+ "`:` type($ref) `,` type($tag) `,` type($dest) "
"`->` type(results)";
let hasCanonicalizer = 1;
}
@@ -229,11 +229,11 @@ def MPI_RecvOp : MPI_Op<"recv", []> {
//===----------------------------------------------------------------------===//
def MPI_IRecvOp : MPI_Op<"irecv", []> {
- let summary = "Equivalent to `MPI_Irecv(ptr, size, dtype, dest, tag, "
+ let summary = "Equivalent to `MPI_Irecv(ptr, size, dtype, source, tag, "
"comm, &req)`";
let description = [{
MPI_Irecv begins a non-blocking receive of `size` elements of type `dtype`
- from rank `dest`. The `tag` value and communicator enables the library to
+ from rank `source`. The `tag` value and communicator enables the library to
determine the matching of multiple sends and receives between the same
ranks.
@@ -244,7 +244,7 @@ def MPI_IRecvOp : MPI_Op<"irecv", []> {
let arguments = (
ins AnyMemRef : $ref,
I32 : $tag,
- I32 : $rank,
+ I32 : $source,
MPI_Comm : $comm
);
@@ -253,8 +253,8 @@ def MPI_IRecvOp : MPI_Op<"irecv", []> {
MPI_Request : $req
);
- let assemblyFormat = "`(` $ref `,` $tag `,` $rank `,` $comm`)` attr-dict "
- "`:` type($ref) `,` type($tag) `,` type($rank)"
+ let assemblyFormat = "`(` $ref `,` $tag `,` $source `,` $comm`)` attr-dict "
+ "`:` type($ref) `,` type($tag) `,` type($source)"
"`->` type(results)";
let hasCanonicalizer = 1;
}
@@ -281,7 +281,7 @@ def MPI_AllReduceOp : MPI_Op<"allreduce", []> {
let arguments = (
ins AnyMemRef : $sendbuf,
AnyMemRef : $recvbuf,
- MPI_OpClassAttr : $op,
+ MPI_OpClassEnum : $op,
MPI_Comm : $comm
);
diff --git a/mlir/lib/Conversion/MPIToLLVM/MPIToLLVM.cpp b/mlir/lib/Conversion/MPIToLLVM/MPIToLLVM.cpp
index a0db634eb4c80..5faf480816dd3 100644
--- a/mlir/lib/Conversion/MPIToLLVM/MPIToLLVM.cpp
+++ b/mlir/lib/Conversion/MPIToLLVM/MPIToLLVM.cpp
@@ -721,7 +721,8 @@ struct AllReduceOpLowering : public ConvertOpToLLVMPattern<mpi::AllReduceOp> {
getRawPtrAndSize(loc, rewriter, adaptor.getRecvbuf(), elemType);
Value dataType = mpiTraits->getDataType(loc, rewriter, elemType);
Value mpiOp = mpiTraits->getMPIOp(loc, rewriter, op.getOp());
- Value commWorld = mpiTraits->getCommWorld(loc, rewriter);
+ Value commWorld = mpiTraits->castComm(loc, rewriter, adaptor.getComm());
+
// 'int MPI_Allreduce(const void *sendbuf, void *recvbuf, int count,
// MPI_Datatype datatype, MPI_Op op, MPI_Comm comm)'
auto funcType = LLVM::LLVMFunctionType::get(
@@ -774,10 +775,9 @@ void mpi::populateMPIToLLVMConversionPatterns(LLVMTypeConverter &converter,
converter.addConversion([](mpi::CommType type) {
return IntegerType::get(type.getContext(), 64);
});
- patterns
- .add<CommRankOpLowering, CommSplitOpLowering, CommWorldOpLowering,
- FinalizeOpLowering, InitOpLowering, SendOpLowering, RecvOpLowering,
- AllReduceOpLowering>(converter);
+ patterns.add<CommRankOpLowering, CommSplitOpLowering, CommWorldOpLowering,
+ FinalizeOpLowering, InitOpLowering, SendOpLowering,
+ RecvOpLowering, AllReduceOpLowering>(converter);
}
void mpi::registerConvertMPIToLLVMInterface(DialectRegistry &registry) {
diff --git a/mlir/test/Conversion/MPIToLLVM/mpitollvm.mlir b/mlir/test/Conversion/MPIToLLVM/mpitollvm.mlir
index 18d18dc53458d..7c2e26dac0ca0 100644
--- a/mlir/test/Conversion/MPIToLLVM/mpitollvm.mlir
+++ b/mlir/test/Conversion/MPIToLLVM/mpitollvm.mlir
@@ -100,9 +100,9 @@ module attributes {dlti.map = #dlti.map<"MPI:Implementation" = "MPICH">} {
// CHECK: [[v58:%.*]] = llvm.trunc [[v57]] : i64 to i32
// CHECK: [[v59:%.*]] = llvm.mlir.constant(1275069450 : i32) : i32
// CHECK: [[v60:%.*]] = llvm.mlir.constant(1476395011 : i32) : i32
- // CHECK: [[v61:%.*]] = llvm.mlir.constant(1140850688 : i32) : i32
+ // CHECK: [[v61:%.*]] = llvm.trunc [[comm]] : i64 to i32
// CHECK: [[v62:%.*]] = llvm.call @MPI_Allreduce([[v51]], [[v56]], [[v53]], [[v59]], [[v60]], [[v61]]) : (!llvm.ptr, !llvm.ptr, i32, i32, i32, i32) -> i32
- mpi.allreduce(%arg0, %arg0, MPI_SUM) : memref<100xf32>, memref<100xf32>
+ mpi.allreduce(%arg0, %arg0, MPI_SUM, %comm) : memref<100xf32>, memref<100xf32>
// CHECK: llvm.call @MPI_Finalize() : () -> i32
%3 = mpi.finalize : !mpi.retval
@@ -204,9 +204,9 @@ module attributes { dlti.map = #dlti.map<"MPI:Implementation" = "OpenMPI"> } {
// CHECK: [[v58:%.*]] = llvm.trunc [[v57]] : i64 to i32
// CHECK: [[v59:%.*]] = llvm.mlir.addressof @ompi_mpi_float : !llvm.ptr
// CHECK: [[v60:%.*]] = llvm.mlir.addressof @ompi_mpi_sum : !llvm.ptr
- // CHECK: [[v61:%.*]] = llvm.mlir.addressof @ompi_mpi_comm_world : !llvm.ptr
+ // CHECK: [[v61:%.*]] = llvm.inttoptr [[comm]] : i64 to !llvm.ptr
// CHECK: [[v62:%.*]] = llvm.call @MPI_Allreduce([[v51]], [[v56]], [[v53]], [[v59]], [[v60]], [[v61]]) : (!llvm.ptr, !llvm.ptr, i32, !llvm.ptr, !llvm.ptr, !llvm.ptr) -> i32
- mpi.allreduce(%arg0, %arg0, MPI_SUM) : memref<100xf32>, memref<100xf32>
+ mpi.allreduce(%arg0, %arg0, MPI_SUM, %comm) : memref<100xf32>, memref<100xf32>
// CHECK: [[v51:%.*]] = llvm.mlir.constant(10 : i32) : i32
%color = arith.constant 10 : i32
>From 682424dffbb712b9949c6fe8d5f52c7bfc04eaa3 Mon Sep 17 00:00:00 2001
From: "Schlimbach, Frank" <frank.schlimbach at intel.com>
Date: Mon, 31 Mar 2025 18:11:57 +0200
Subject: [PATCH 15/16] fixing match results
---
mlir/test/Conversion/MPIToLLVM/mpitollvm.mlir | 44 +++++++++----------
1 file changed, 22 insertions(+), 22 deletions(-)
diff --git a/mlir/test/Conversion/MPIToLLVM/mpitollvm.mlir b/mlir/test/Conversion/MPIToLLVM/mpitollvm.mlir
index 7c2e26dac0ca0..174f7c79b9d50 100644
--- a/mlir/test/Conversion/MPIToLLVM/mpitollvm.mlir
+++ b/mlir/test/Conversion/MPIToLLVM/mpitollvm.mlir
@@ -88,20 +88,20 @@ module attributes {dlti.map = #dlti.map<"MPI:Implementation" = "MPICH">} {
// CHECK: [[v57:%.*]] = llvm.load [[v55]] : !llvm.ptr -> i32
%split = mpi.comm_split(%comm, %color, %key) : !mpi.comm
- // CHECK: [[v49:%.*]] = llvm.extractvalue [[v5]][1] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
- // CHECK: [[v50:%.*]] = llvm.extractvalue [[v5]][2] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
- // CHECK: [[v51:%.*]] = llvm.getelementptr [[v49]][[[v50]]] : (!llvm.ptr, i64) -> !llvm.ptr, f32
- // CHECK: [[v52:%.*]] = llvm.extractvalue [[v5]][3, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
- // CHECK: [[v53:%.*]] = llvm.trunc [[v52]] : i64 to i32
- // CHECK: [[v54:%.*]] = llvm.extractvalue [[v5]][1] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
- // CHECK: [[v55:%.*]] = llvm.extractvalue [[v5]][2] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
- // CHECK: [[v56:%.*]] = llvm.getelementptr [[v54]][[[v55]]] : (!llvm.ptr, i64) -> !llvm.ptr, f32
- // CHECK: [[v57:%.*]] = llvm.extractvalue [[v5]][3, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
- // CHECK: [[v58:%.*]] = llvm.trunc [[v57]] : i64 to i32
- // CHECK: [[v59:%.*]] = llvm.mlir.constant(1275069450 : i32) : i32
- // CHECK: [[v60:%.*]] = llvm.mlir.constant(1476395011 : i32) : i32
- // CHECK: [[v61:%.*]] = llvm.trunc [[comm]] : i64 to i32
- // CHECK: [[v62:%.*]] = llvm.call @MPI_Allreduce([[v51]], [[v56]], [[v53]], [[v59]], [[v60]], [[v61]]) : (!llvm.ptr, !llvm.ptr, i32, i32, i32, i32) -> i32
+ // CHECK: [[v59:%.*]] = llvm.extractvalue [[v5]][1] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
+ // CHECK: [[v60:%.*]] = llvm.extractvalue [[v5]][2] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
+ // CHECK: [[v61:%.*]] = llvm.getelementptr [[v59]][[[v60]]] : (!llvm.ptr, i64) -> !llvm.ptr, f32
+ // CHECK: [[v62:%.*]] = llvm.extractvalue [[v5]][3, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
+ // CHECK: [[v63:%.*]] = llvm.trunc [[v62]] : i64 to i32
+ // CHECK: [[v64:%.*]] = llvm.extractvalue [[v5]][1] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
+ // CHECK: [[v65:%.*]] = llvm.extractvalue [[v5]][2] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
+ // CHECK: [[v66:%.*]] = llvm.getelementptr [[v64]][[[v65]]] : (!llvm.ptr, i64) -> !llvm.ptr, f32
+ // CHECK: [[v67:%.*]] = llvm.extractvalue [[v5]][3, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
+ // CHECK: [[v68:%.*]] = llvm.trunc [[v67]] : i64 to i32
+ // CHECK: [[v69:%.*]] = llvm.mlir.constant(1275069450 : i32) : i32
+ // CHECK: [[v70:%.*]] = llvm.mlir.constant(1476395011 : i32) : i32
+ // CHECK: [[v71:%.*]] = llvm.trunc [[comm]] : i64 to i32
+ // CHECK: [[v72:%.*]] = llvm.call @MPI_Allreduce([[v61]], [[v66]], [[v63]], [[v69]], [[v70]], [[v71]]) : (!llvm.ptr, !llvm.ptr, i32, i32, i32, i32) -> i32
mpi.allreduce(%arg0, %arg0, MPI_SUM, %comm) : memref<100xf32>, memref<100xf32>
// CHECK: llvm.call @MPI_Finalize() : () -> i32
@@ -208,18 +208,18 @@ module attributes { dlti.map = #dlti.map<"MPI:Implementation" = "OpenMPI"> } {
// CHECK: [[v62:%.*]] = llvm.call @MPI_Allreduce([[v51]], [[v56]], [[v53]], [[v59]], [[v60]], [[v61]]) : (!llvm.ptr, !llvm.ptr, i32, !llvm.ptr, !llvm.ptr, !llvm.ptr) -> i32
mpi.allreduce(%arg0, %arg0, MPI_SUM, %comm) : memref<100xf32>, memref<100xf32>
- // CHECK: [[v51:%.*]] = llvm.mlir.constant(10 : i32) : i32
+ // CHECK: [[v71:%.*]] = llvm.mlir.constant(10 : i32) : i32
%color = arith.constant 10 : i32
- // CHECK: [[v52:%.*]] = llvm.mlir.constant(22 : i32) : i32
+ // CHECK: [[v72:%.*]] = llvm.mlir.constant(22 : i32) : i32
%key = arith.constant 22 : i32
- // CHECK: [[v53:%.*]] = llvm.inttoptr [[comm]] : i64 to !llvm.ptr
- // CHECK: [[v54:%.*]] = llvm.mlir.constant(1 : i32) : i32
- // CHECK: [[v55:%.*]] = llvm.alloca [[v54]] x !llvm.ptr : (i32) -> !llvm.ptr
- // CHECK: [[v56:%.*]] = llvm.call @MPI_Comm_split([[v53]], [[v51]], [[v52]], [[v55]]) : (!llvm.ptr, i32, i32, !llvm.ptr) -> i32
- // CHECK: [[v57:%.*]] = llvm.load [[v55]] : !llvm.ptr -> i32
+ // CHECK: [[v73:%.*]] = llvm.inttoptr [[comm]] : i64 to !llvm.ptr
+ // CHECK: [[v74:%.*]] = llvm.mlir.constant(1 : i32) : i32
+ // CHECK: [[v75:%.*]] = llvm.alloca [[v74]] x !llvm.ptr : (i32) -> !llvm.ptr
+ // CHECK: [[v76:%.*]] = llvm.call @MPI_Comm_split([[v73]], [[v71]], [[v72]], [[v75]]) : (!llvm.ptr, i32, i32, !llvm.ptr) -> i32
+ // CHECK: [[v77:%.*]] = llvm.load [[v75]] : !llvm.ptr -> i32
%split = mpi.comm_split(%comm, %color, %key) : !mpi.comm
- // CHECK: [[v49:%.*]] = llvm.call @MPI_Finalize() : () -> i32
+ // CHECK: llvm.call @MPI_Finalize() : () -> i32
%3 = mpi.finalize : !mpi.retval
return
>From b3e196f85bdb43be5e4305e82898a9bc9a19500e Mon Sep 17 00:00:00 2001
From: Frank Schlimbach <frank.schlimbach at intel.com>
Date: Mon, 31 Mar 2025 18:34:07 +0200
Subject: [PATCH 16/16] Apply suggestions from code review
Co-authored-by: Christian Ulmann <christianulmann at gmail.com>
---
mlir/include/mlir/Dialect/MPI/IR/MPIOps.td | 8 ++++----
mlir/include/mlir/Dialect/MPI/IR/MPITypes.td | 2 +-
mlir/lib/Conversion/MPIToLLVM/MPIToLLVM.cpp | 4 ++--
3 files changed, 7 insertions(+), 7 deletions(-)
diff --git a/mlir/include/mlir/Dialect/MPI/IR/MPIOps.td b/mlir/include/mlir/Dialect/MPI/IR/MPIOps.td
index d549f5e575c5a..d78aa92d201e7 100644
--- a/mlir/include/mlir/Dialect/MPI/IR/MPIOps.td
+++ b/mlir/include/mlir/Dialect/MPI/IR/MPIOps.td
@@ -71,7 +71,7 @@ def MPI_CommRankOp : MPI_Op<"comm_rank", []> {
I32 : $rank
);
- let assemblyFormat = "`(` $comm `)` attr-dict `:` type(results)";
+ let assemblyFormat = "`(` $comm `)` attr-dict `:` type(results)";
}
//===----------------------------------------------------------------------===//
@@ -93,7 +93,7 @@ def MPI_CommSizeOp : MPI_Op<"comm_size", []> {
I32 : $size
);
- let assemblyFormat = "`(` $comm `)` attr-dict `:` type(results)";
+ let assemblyFormat = "`(` $comm `)` attr-dict `:` type(results)";
}
//===----------------------------------------------------------------------===//
@@ -101,7 +101,7 @@ def MPI_CommSizeOp : MPI_Op<"comm_size", []> {
//===----------------------------------------------------------------------===//
def MPI_CommSplitOp : MPI_Op<"comm_split", []> {
- let summary = "Partition the group associated to the given communicator into "
+ let summary = "Partition the group associated with the given communicator into "
"disjoint subgroups";
let description = [{
This operation splits the communicator into multiple sub-communicators.
@@ -311,7 +311,7 @@ def MPI_Barrier : MPI_Op<"barrier", []> {
let results = (outs Optional<MPI_Retval>:$retval);
let assemblyFormat = [{
- `(` $comm `)` attr-dict
+ `(` $comm `)` attr-dict
(`->` type($retval)^)?
}];
}
diff --git a/mlir/include/mlir/Dialect/MPI/IR/MPITypes.td b/mlir/include/mlir/Dialect/MPI/IR/MPITypes.td
index b56a224d84774..adc35a70b5837 100644
--- a/mlir/include/mlir/Dialect/MPI/IR/MPITypes.td
+++ b/mlir/include/mlir/Dialect/MPI/IR/MPITypes.td
@@ -47,7 +47,7 @@ def MPI_Retval : MPI_Type<"Retval", "retval"> {
def MPI_Comm : MPI_Type<"Comm", "comm"> {
let summary = "MPI communicator handler";
let description = [{
- This type represents a handler to the MPI communicator.
+ This type represents a handler for the MPI communicator.
}];
}
diff --git a/mlir/lib/Conversion/MPIToLLVM/MPIToLLVM.cpp b/mlir/lib/Conversion/MPIToLLVM/MPIToLLVM.cpp
index 5faf480816dd3..9df5e992e8ebd 100644
--- a/mlir/lib/Conversion/MPIToLLVM/MPIToLLVM.cpp
+++ b/mlir/lib/Conversion/MPIToLLVM/MPIToLLVM.cpp
@@ -83,14 +83,14 @@ class MPIImplTraits {
ModuleOp &getModuleOp() { return moduleOp; }
/// Gets or creates MPI_COMM_WORLD as a Value.
- /// Different MPI implementations have different types for communicator.
+ /// Different MPI implementations have different communicator types.
/// Using i64 as a portable, intermediate type.
/// Appropriate cast needs to take place before calling MPI functions.
virtual Value getCommWorld(const Location loc,
ConversionPatternRewriter &rewriter) = 0;
/// Type converter provides i64 type for communicator type.
- /// Converts to native type, which might be ptr or int or whatever.
+ /// Converts to native type, which might be ptr or int or whatever.
virtual Value castComm(const Location loc,
ConversionPatternRewriter &rewriter, Value comm) = 0;
More information about the Mlir-commits
mailing list