[Mlir-commits] [mlir] [mlir][mesh] Mandatory Communicator (PR #133280)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Thu Mar 27 10:19:51 PDT 2025
Sergio Sánchez Ramírez,
Sergio Sánchez Ramírez,
Sergio Sánchez Ramírez,
Sergio Sánchez Ramírez,
Sergio Sánchez Ramírez,"Schlimbach, Frank"
<frank.schlimbach at intel.com>,"Schlimbach, Frank" <frank.schlimbach at intel.com>
Message-ID:
In-Reply-To: <llvm.org/llvm/llvm-project/pull/133280 at github.com>
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir
Author: Frank Schlimbach (fschlimb)
<details>
<summary>Changes</summary>
@<!-- -->mofeing
#<!-- -->125361
---
Patch is 54.91 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/133280.diff
7 Files Affected:
- (modified) mlir/include/mlir/Dialect/MPI/IR/MPIOps.td (+100-37)
- (modified) mlir/include/mlir/Dialect/MPI/IR/MPITypes.td (+11)
- (modified) mlir/lib/Conversion/MPIToLLVM/MPIToLLVM.cpp (+36-12)
- (modified) mlir/lib/Conversion/MeshToMPI/MeshToMPI.cpp (+17-7)
- (modified) mlir/test/Conversion/MPIToLLVM/ops.mlir (+16-18)
- (modified) mlir/test/Conversion/MeshToMPI/convert-mesh-to-mpi.mlir (+34-28)
- (modified) mlir/test/Dialect/MPI/ops.mlir (+52-35)
``````````diff
diff --git a/mlir/include/mlir/Dialect/MPI/IR/MPIOps.td b/mlir/include/mlir/Dialect/MPI/IR/MPIOps.td
index db28bd09678f8..6bc25054bf48a 100644
--- a/mlir/include/mlir/Dialect/MPI/IR/MPIOps.td
+++ b/mlir/include/mlir/Dialect/MPI/IR/MPIOps.td
@@ -37,26 +37,43 @@ def MPI_InitOp : MPI_Op<"init", []> {
let assemblyFormat = "attr-dict (`:` type($retval)^)?";
}
+//===----------------------------------------------------------------------===//
+// CommWorldOp
+//===----------------------------------------------------------------------===//
+
+def MPI_CommWorldOp : MPI_Op<"comm_world", []> {
+ let summary = "Get the World communicator, equivalent to `MPI_COMM_WORLD`";
+ let description = [{
+ Produces a `!mpi.comm` handle referring to the predefined `MPI_COMM_WORLD`.
+ }];
+
+ let results = (outs MPI_Comm:$comm);
+
+ let assemblyFormat = "attr-dict `:` type($comm)";
+}
+
//===----------------------------------------------------------------------===//
// CommRankOp
//===----------------------------------------------------------------------===//
def MPI_CommRankOp : MPI_Op<"comm_rank", []> {
let summary = "Get the current rank, equivalent to "
- "`MPI_Comm_rank(MPI_COMM_WORLD, &rank)`";
+ "`MPI_Comm_rank(comm, &rank)`";
let description = [{
- Communicators other than `MPI_COMM_WORLD` are not supported for now.
+ The communicator `comm` is mandatory; use `mpi.comm_world` for `MPI_COMM_WORLD`.
This operation can optionally return an `!mpi.retval` value that can be used
to check for errors.
}];
+ let arguments = (ins MPI_Comm : $comm);
+
let results = (
outs Optional<MPI_Retval> : $retval,
I32 : $rank
);
- let assemblyFormat = "attr-dict `:` type(results)";
+ let assemblyFormat = "`(` $comm `)` attr-dict `:` type(results)";
}
//===----------------------------------------------------------------------===//
@@ -65,20 +82,51 @@ def MPI_CommRankOp : MPI_Op<"comm_rank", []> {
def MPI_CommSizeOp : MPI_Op<"comm_size", []> {
let summary = "Get the size of the group associated to the communicator, "
- "equivalent to `MPI_Comm_size(MPI_COMM_WORLD, &size)`";
+ "equivalent to `MPI_Comm_size(comm, &size)`";
let description = [{
- Communicators other than `MPI_COMM_WORLD` are not supported for now.
+ The communicator `comm` is mandatory; use `mpi.comm_world` for `MPI_COMM_WORLD`.
This operation can optionally return an `!mpi.retval` value that can be used
to check for errors.
}];
+ let arguments = (ins MPI_Comm : $comm);
+
let results = (
outs Optional<MPI_Retval> : $retval,
I32 : $size
);
- let assemblyFormat = "attr-dict `:` type(results)";
+ let assemblyFormat = "`(` $comm `)` attr-dict `:` type(results)";
+}
+
+//===----------------------------------------------------------------------===//
+// CommSplitOp
+//===----------------------------------------------------------------------===//
+
+def MPI_CommSplit : MPI_Op<"comm_split", []> {
+ let summary = "Partition the group associated with the given communicator "
+ "into disjoint subgroups";
+ let description = [{
+ This operation splits the communicator into multiple sub-communicators.
+ Processes passing the same `color` value become members of the same new
+ communicator; within it, they are ranked in the order of their `key` values
+ (ties are broken by their rank in `comm`).
+
+ This operation can optionally return an `!mpi.retval` value that can be used
+ to check for errors.
+ }];
+
+ let arguments = (ins MPI_Comm : $comm, I32 : $color, I32 : $key);
+
+ let results = (
+ outs Optional<MPI_Retval> : $retval,
+ MPI_Comm : $newcomm
+ );
+
+ let assemblyFormat = "`(` $comm `,` $color `,` $key `)` attr-dict `:` "
+ "type($color) `,` type($key) `->` "
+ "type(results)";
}
//===----------------------------------------------------------------------===//
@@ -87,13 +135,13 @@ def MPI_CommSizeOp : MPI_Op<"comm_size", []> {
def MPI_SendOp : MPI_Op<"send", []> {
let summary =
- "Equivalent to `MPI_Send(ptr, size, dtype, dest, tag, MPI_COMM_WORLD)`";
+ "Equivalent to `MPI_Send(ptr, size, dtype, dest, tag, comm)`";
let description = [{
MPI_Send performs a blocking send of `size` elements of type `dtype` to rank
`dest`. The `tag` value and communicator enables the library to determine
the matching of multiple sends and receives between the same ranks.
- Communicators other than `MPI_COMM_WORLD` are not supported for now.
+ The communicator `comm` is mandatory; use `mpi.comm_world` for `MPI_COMM_WORLD`.
This operation can optionally return an `!mpi.retval` value that can be used
to check for errors.
@@ -102,12 +150,13 @@ def MPI_SendOp : MPI_Op<"send", []> {
let arguments = (
ins AnyMemRef : $ref,
I32 : $tag,
- I32 : $dest
+ I32 : $dest,
+ MPI_Comm : $comm
);
let results = (outs Optional<MPI_Retval>:$retval);
- let assemblyFormat = "`(` $ref `,` $tag `,` $dest `)` attr-dict `:` "
+ let assemblyFormat = "`(` $ref `,` $tag `,` $dest `,` $comm `)` attr-dict `:` "
"type($ref) `,` type($tag) `,` type($dest)"
"(`->` type($retval)^)?";
let hasCanonicalizer = 1;
@@ -119,14 +168,14 @@ def MPI_SendOp : MPI_Op<"send", []> {
def MPI_ISendOp : MPI_Op<"isend", []> {
let summary =
- "Equivalent to `MPI_Isend(ptr, size, dtype, dest, tag, MPI_COMM_WORLD)`";
+ "Equivalent to `MPI_Isend(ptr, size, dtype, dest, tag, comm)`";
let description = [{
MPI_Isend begins a non-blocking send of `size` elements of type `dtype` to
rank `dest`. The `tag` value and communicator enables the library to
determine the matching of multiple sends and receives between the same
ranks.
- Communicators other than `MPI_COMM_WORLD` are not supported for now.
+ The communicator `comm` is mandatory; use `mpi.comm_world` for `MPI_COMM_WORLD`.
This operation can optionally return an `!mpi.retval` value that can be used
to check for errors.
@@ -135,7 +184,8 @@ def MPI_ISendOp : MPI_Op<"isend", []> {
let arguments = (
ins AnyMemRef : $ref,
I32 : $tag,
- I32 : $rank
+ I32 : $rank,
+ MPI_Comm : $comm
);
let results = (
@@ -143,7 +193,7 @@ def MPI_ISendOp : MPI_Op<"isend", []> {
MPI_Request : $req
);
- let assemblyFormat = "`(` $ref `,` $tag `,` $rank `)` attr-dict "
+ let assemblyFormat = "`(` $ref `,` $tag `,` $rank `,` $comm`)` attr-dict "
"`:` type($ref) `,` type($tag) `,` type($rank) "
"`->` type(results)";
let hasCanonicalizer = 1;
@@ -154,15 +204,15 @@ def MPI_ISendOp : MPI_Op<"isend", []> {
//===----------------------------------------------------------------------===//
def MPI_RecvOp : MPI_Op<"recv", []> {
- let summary = "Equivalent to `MPI_Recv(ptr, size, dtype, source, tag, "
- "MPI_COMM_WORLD, MPI_STATUS_IGNORE)`";
+ let summary = "Equivalent to `MPI_Recv(ptr, size, dtype, source, tag, "
+ "comm, MPI_STATUS_IGNORE)`";
let description = [{
MPI_Recv performs a blocking receive of `size` elements of type `dtype`
from rank `source`. The `tag` value and communicator enables the library to
determine the matching of multiple sends and receives between the same
ranks.
- Communicators other than `MPI_COMM_WORLD` are not supported for now.
+ The communicator `comm` is mandatory; use `mpi.comm_world` for `MPI_COMM_WORLD`.
The MPI_Status is set to `MPI_STATUS_IGNORE`, as the status object
is not yet ported to MLIR.
@@ -172,13 +222,14 @@ def MPI_RecvOp : MPI_Op<"recv", []> {
let arguments = (
ins AnyMemRef : $ref,
- I32 : $tag, I32 : $source
+ I32 : $tag, I32 : $source,
+ MPI_Comm : $comm
);
let results = (outs Optional<MPI_Retval>:$retval);
- let assemblyFormat = "`(` $ref `,` $tag `,` $source `)` attr-dict `:` "
- "type($ref) `,` type($tag) `,` type($source)"
+ let assemblyFormat = "`(` $ref `,` $tag `,` $source `,` $comm `)` attr-dict"
+ " `:` type($ref) `,` type($tag) `,` type($source) "
"(`->` type($retval)^)?";
let hasCanonicalizer = 1;
}
@@ -189,14 +240,14 @@ def MPI_RecvOp : MPI_Op<"recv", []> {
def MPI_IRecvOp : MPI_Op<"irecv", []> {
let summary = "Equivalent to `MPI_Irecv(ptr, size, dtype, dest, tag, "
- "MPI_COMM_WORLD, &req)`";
+ "comm, &req)`";
let description = [{
MPI_Irecv begins a non-blocking receive of `size` elements of type `dtype`
from rank `dest`. The `tag` value and communicator enables the library to
determine the matching of multiple sends and receives between the same
ranks.
- Communicators other than `MPI_COMM_WORLD` are not supported for now.
+ The communicator `comm` is mandatory; use `mpi.comm_world` for `MPI_COMM_WORLD`.
This operation can optionally return an `!mpi.retval` value that can be used
to check for errors.
@@ -205,7 +256,8 @@ def MPI_IRecvOp : MPI_Op<"irecv", []> {
let arguments = (
ins AnyMemRef : $ref,
I32 : $tag,
- I32 : $rank
+ I32 : $rank,
+ MPI_Comm : $comm
);
let results = (
@@ -213,9 +265,9 @@ def MPI_IRecvOp : MPI_Op<"irecv", []> {
MPI_Request : $req
);
- let assemblyFormat = "`(` $ref `,` $tag `,` $rank `)` attr-dict `:`"
- "type($ref) `,` type($tag) `,` type($rank) `->`"
- "type(results)";
+ let assemblyFormat = "`(` $ref `,` $tag `,` $rank `,` $comm`)` attr-dict "
+ "`:` type($ref) `,` type($tag) `,` type($rank)"
+ "`->` type(results)";
let hasCanonicalizer = 1;
}
@@ -224,8 +276,7 @@ def MPI_IRecvOp : MPI_Op<"irecv", []> {
//===----------------------------------------------------------------------===//
def MPI_AllReduceOp : MPI_Op<"allreduce", []> {
- let summary = "Equivalent to `MPI_Allreduce(sendbuf, recvbuf, op, "
- "MPI_COMM_WORLD)`";
+ let summary = "Equivalent to `MPI_Allreduce(sendbuf, recvbuf, op, comm)`";
let description = [{
MPI_Allreduce performs a reduction operation on the values in the sendbuf
array and stores the result in the recvbuf array. The operation is
@@ -235,7 +286,7 @@ def MPI_AllReduceOp : MPI_Op<"allreduce", []> {
Currently only the `MPI_Op` predefined in the standard (e.g. `MPI_SUM`) are
supported.
- Communicators other than `MPI_COMM_WORLD` are not supported for now.
+ The communicator `comm` is mandatory; use `mpi.comm_world` for `MPI_COMM_WORLD`.
This operation can optionally return an `!mpi.retval` value that can be used
to check for errors.
@@ -244,13 +295,14 @@ def MPI_AllReduceOp : MPI_Op<"allreduce", []> {
let arguments = (
ins AnyMemRef : $sendbuf,
AnyMemRef : $recvbuf,
- MPI_OpClassAttr : $op
+ MPI_OpClassAttr : $op,
+ MPI_Comm : $comm
);
let results = (outs Optional<MPI_Retval>:$retval);
- let assemblyFormat = "`(` $sendbuf `,` $recvbuf `,` $op `)` attr-dict `:`"
- "type($sendbuf) `,` type($recvbuf)"
+ let assemblyFormat = "`(` $sendbuf `,` $recvbuf `,` $op `,` $comm `)` "
+ "attr-dict `:` type($sendbuf) `,` type($recvbuf) "
"(`->` type($retval)^)?";
}
@@ -259,20 +311,32 @@ def MPI_AllReduceOp : MPI_Op<"allreduce", []> {
//===----------------------------------------------------------------------===//
def MPI_Barrier : MPI_Op<"barrier", []> {
- let summary = "Equivalent to `MPI_Barrier(MPI_COMM_WORLD)`";
+ let summary = "Equivalent to `MPI_Barrier(comm)`";
let description = [{
MPI_Barrier blocks execution until all processes in the communicator have
reached this routine.
- Communicators other than `MPI_COMM_WORLD` are not supported for now.
+ The communicator `comm` is mandatory; use `mpi.comm_world` for `MPI_COMM_WORLD`.
This operation can optionally return an `!mpi.retval` value that can be used
to check for errors.
}];
+ let arguments = (ins MPI_Comm : $comm);
+
let results = (outs Optional<MPI_Retval>:$retval);
- let assemblyFormat = "attr-dict (`:` type($retval) ^)?";
+ // The communicator operand is mandatory and always has type `!mpi.comm`,
+ // so its type is not printed. The optional `!mpi.retval` result type is
+ // printed after `->` only when that result is present, e.g.:
+ //
+ //   mpi.barrier(%comm)
+ //   %err = mpi.barrier(%comm) -> !mpi.retval
+ //
+ let assemblyFormat = [{
+ `(` $comm `)` attr-dict
+ (`->` type($retval)^)?
+ }];
}
//===----------------------------------------------------------------------===//
@@ -295,8 +359,7 @@ def MPI_Wait : MPI_Op<"wait", []> {
let results = (outs Optional<MPI_Retval>:$retval);
- let assemblyFormat = "`(` $req `)` attr-dict `:` type($req) "
- "(`->` type($retval) ^)?";
+ let assemblyFormat = "`(` $req `)` attr-dict `:` type($req) (`->` type($retval) ^)?";
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/MPI/IR/MPITypes.td b/mlir/include/mlir/Dialect/MPI/IR/MPITypes.td
index a55d30e778e22..b56a224d84774 100644
--- a/mlir/include/mlir/Dialect/MPI/IR/MPITypes.td
+++ b/mlir/include/mlir/Dialect/MPI/IR/MPITypes.td
@@ -40,6 +40,17 @@ def MPI_Retval : MPI_Type<"Retval", "retval"> {
}];
}
+//===----------------------------------------------------------------------===//
+// mpi::CommType
+//===----------------------------------------------------------------------===//
+
+def MPI_Comm : MPI_Type<"Comm", "comm"> {
+ let summary = "MPI communicator handle";
+ let description = [{
+ This type represents a handle to an MPI communicator.
+ }];
+}
+
//===----------------------------------------------------------------------===//
// mpi::RequestType
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Conversion/MPIToLLVM/MPIToLLVM.cpp b/mlir/lib/Conversion/MPIToLLVM/MPIToLLVM.cpp
index d91f9512ccb8f..be8f5989740e3 100644
--- a/mlir/lib/Conversion/MPIToLLVM/MPIToLLVM.cpp
+++ b/mlir/lib/Conversion/MPIToLLVM/MPIToLLVM.cpp
@@ -295,6 +295,26 @@ struct FinalizeOpLowering : public ConvertOpToLLVMPattern<mpi::FinalizeOp> {
}
};
+//===----------------------------------------------------------------------===//
+// CommWorldOpLowering
+//===----------------------------------------------------------------------===//
+
+struct CommWorldOpLowering : public ConvertOpToLLVMPattern<mpi::CommWorldOp> {
+ using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;
+
+ /// Replaces `mpi.comm_world` with the implementation's MPI_COMM_WORLD value.
+ LogicalResult
+ matchAndRewrite(mpi::CommWorldOp op, OpAdaptor /*adaptor*/,
+ ConversionPatternRewriter &rewriter) const override {
+ // MPI implementation traits are looked up from the enclosing module.
+ ModuleOp parentModule = op->getParentOfType<ModuleOp>();
+ Location loc = op.getLoc();
+ Value commWorld = MPIImplTraits::get(parentModule)->getCommWorld(loc, rewriter);
+ rewriter.replaceOp(op, commWorld);
+ return success();
+ }
+};
+
//===----------------------------------------------------------------------===//
// CommRankOpLowering
//===----------------------------------------------------------------------===//
@@ -317,12 +337,12 @@ struct CommRankOpLowering : public ConvertOpToLLVMPattern<mpi::CommRankOp> {
auto moduleOp = op->getParentOfType<ModuleOp>();
auto mpiTraits = MPIImplTraits::get(moduleOp);
- // get MPI_COMM_WORLD
- Value commWorld = mpiTraits->getCommWorld(loc, rewriter);
+ // get communicator
+ Value comm = adaptor.getComm();
// LLVM Function type representing `i32 MPI_Comm_rank(ptr, ptr)`
auto rankFuncType =
- LLVM::LLVMFunctionType::get(i32, {commWorld.getType(), ptrType});
+ LLVM::LLVMFunctionType::get(i32, {comm.getType(), ptrType});
// get or create function declaration:
LLVM::LLVMFuncOp initDecl = getOrDefineFunction(
moduleOp, loc, rewriter, "MPI_Comm_rank", rankFuncType);
@@ -331,7 +351,7 @@ struct CommRankOpLowering : public ConvertOpToLLVMPattern<mpi::CommRankOp> {
auto one = rewriter.create<LLVM::ConstantOp>(loc, i32, 1);
auto rankptr = rewriter.create<LLVM::AllocaOp>(loc, ptrType, i32, one);
auto callOp = rewriter.create<LLVM::CallOp>(
- loc, initDecl, ValueRange{commWorld, rankptr.getRes()});
+ loc, initDecl, ValueRange{comm, rankptr.getRes()});
// load the rank into a register
auto loadedRank =
@@ -386,12 +406,12 @@ struct SendOpLowering : public ConvertOpToLLVMPattern<mpi::SendOp> {
size = rewriter.create<LLVM::TruncOp>(loc, i32, size);
auto mpiTraits = MPIImplTraits::get(moduleOp);
Value dataType = mpiTraits->getDataType(loc, rewriter, elemType);
- Value commWorld = mpiTraits->getCommWorld(loc, rewriter);
+ Value comm = adaptor.getComm();
// LLVM Function type representing `i32 MPI_send(data, count, datatype, dst,
// tag, comm)`
auto funcType = LLVM::LLVMFunctionType::get(
- i32, {ptrType, i32, dataType.getType(), i32, i32, commWorld.getType()});
+ i32, {ptrType, i32, dataType.getType(), i32, i32, comm.getType()});
// get or create function declaration:
LLVM::LLVMFuncOp funcDecl =
getOrDefineFunction(moduleOp, loc, rewriter, "MPI_Send", funcType);
@@ -400,7 +420,7 @@ struct SendOpLowering : public ConvertOpToLLVMPattern<mpi::SendOp> {
auto funcCall = rewriter.create<LLVM::CallOp>(
loc, funcDecl,
ValueRange{dataPtr, size, dataType, adaptor.getDest(), adaptor.getTag(),
- commWorld});
+ comm});
if (op.getRetval())
rewriter.replaceOp(op, funcCall.getResult());
else
@@ -445,7 +465,7 @@ struct RecvOpLowering : public ConvertOpToLLVMPattern<mpi::RecvOp> {
size = rewriter.create<LLVM::TruncOp>(loc, i32, size);
auto mpiTraits = MPIImplTraits::get(moduleOp);
Value dataType = mpiTraits->getDataType(loc, rewriter, elemType);
- Value commWorld = mpiTraits->getCommWorld(loc, rewriter);
+ Value comm = adaptor.getComm();
Value statusIgnore = rewriter.create<LLVM::ConstantOp>(
loc, i64, mpiTraits->getStatusIgnore());
statusIgnore =
@@ -455,7 +475,7 @@ struct RecvOpLowering : public ConvertOpToLLVMPattern<mpi::RecvOp> {
// tag, comm)`
auto funcType =
LLVM::LLVMFunctionType::get(i32, {ptrType, i32, dataType.getType(), i32,
- i32, commWorld.getType(), ptrType});
+ i32, comm.getType(), ptrType});
// get or create function declaration:
LLVM::LLVMFuncOp funcDecl =
getOrDefineFunction(moduleOp, loc, rewriter, "MPI_Recv", funcType);
@@ -464,7 +484,7 @@ struct RecvOpLowering : public ConvertOpToLLVMPattern<mpi::RecvOp> {
auto funcCall = rewriter.create<LLVM::CallOp>(
loc, funcDecl,
ValueRange{dataPtr, size, dataType, adaptor.getSource(),
- adaptor.getTag(), commWorld, statusIgnore});
+ adaptor.getTag(), comm, statusIgnore});
if (op.getRetval())
rewriter.replaceOp(op, funcCall.getResult());
else
@@ -497,8 +517,12 @@ struct FuncToLLVMDialectInterface : public ConvertToLLVMPatternInterface {
void mpi::populateMPIToLLVMConversionPatterns(LLVMTypeConverter &converter,
RewritePatternSet &patterns) {
- patterns.add<CommRankOpLowering, FinalizeOpLowering, InitOpLowering,
- SendOpLowering, RecvOpLowering>(converter);
+ // FIXME: Need DLTI info to get the MPI implementation and thus the
+ // communicator type; i32 is assumed for now.
+ Type commType = IntegerType::get(&converter.getContext(), 32);
+ converter.addConversion([commType](mpi::CommType) { return commType; });
+ patterns.add<CommRankOpLowering, CommWorldOpLowering, FinalizeOpLowering,
+ InitOpLowering, SendOpLowering, RecvOpLowering>(converter);
}
void mpi::registerConvertMPIToLLVMInterface(DialectRegistry ®istry) {
diff --git a/mlir/lib/Conversion/MeshToMPI/MeshToMPI.cpp b/mlir/lib/Conversion/MeshToMPI/MeshToMPI.cpp
index 87c2938e4e52b..cafbf835de22f 100644
--- a/mlir/lib/Conversion/MeshToMPI/MeshToMPI...
[truncated]
``````````
</details>
https://github.com/llvm/llvm-project/pull/133280
More information about the Mlir-commits
mailing list