[Mlir-commits] [mlir] [MLIR] Add `mpi.comm` type to MPI dialect (PR #125361)

Sergio Sánchez Ramírez llvmlistbot at llvm.org
Sat Feb 1 13:49:39 PST 2025


https://github.com/mofeing created https://github.com/llvm/llvm-project/pull/125361

cc @wsmoses @tobiasgrosser @hhkit

There are currently 3 problems for which I need help:

1. I don't know how to represent a constant in MLIR. I don't mean something like `arith.constant`; I'm referring to `MPI_COMM_WORLD`. What I did instead is create an `mpi.comm_world` op that returns an `!mpi.comm` value, but I'm not sure it's the best approach. Furthermore, it would be nice if optional `MPI_Comm` arguments had a default value that is exactly `MPI_COMM_WORLD`, i.e. the result of `mpi.comm_world`.
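One argument for an op (rather than an attribute-like constant) is that the value of `MPI_COMM_WORLD` is not portable across MPI implementations. As far as I know, MPICH's `mpi.h` defines it as the integer handle `0x44000000`, while Open MPI defines it as the address of the global `ompi_mpi_comm_world`, so a lowering of `mpi.comm_world` has to pick a representation per target ABI. The sketch below is a plain C++ model of that choice plus the proposed "no communicator means world" default; `MpiAbi`, `lowerCommWorld`, `CommHandle` and `resolveComm` are hypothetical names, not MLIR APIs.

```cpp
#include <cstdint>
#include <optional>
#include <string>

// Hypothetical enumeration of two common MPI ABIs.
enum class MpiAbi { MPICH, OpenMPI };

// How a lowering of `mpi.comm_world` might materialize the handle.
// Assumption: MPICH uses the integer handle 0x44000000, while Open MPI uses
// the address of the global `ompi_mpi_comm_world`.
struct CommWorldLowering {
  bool isIntegerConstant;   // true: emit an integer constant
  std::uint32_t intValue;   // meaningful only if isIntegerConstant
  std::string globalSymbol; // meaningful only if !isIntegerConstant
};

CommWorldLowering lowerCommWorld(MpiAbi abi) {
  switch (abi) {
  case MpiAbi::MPICH:
    return {true, 0x44000000u, ""};
  case MpiAbi::OpenMPI:
    return {false, 0, "ompi_mpi_comm_world"};
  }
  return {false, 0, ""}; // not reached for valid enumerators
}

// Stand-in for a lowered communicator value (textual, for illustration only).
struct CommHandle {
  std::string repr;
};

// Proposed default: an absent communicator operand resolves to MPI_COMM_WORLD.
CommHandle resolveComm(const std::optional<CommHandle> &comm, MpiAbi abi) {
  if (comm)
    return *comm; // an explicit communicator always wins
  CommWorldLowering world = lowerCommWorld(abi);
  if (world.isIntegerConstant)
    return {std::to_string(world.intValue)};
  return {"&" + world.globalSymbol};
}
```

Keeping `mpi.comm_world` as an op means this ABI decision stays in the lowering, and the dialect never has to encode an implementation-specific handle value.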

2. The `assemblyFormat` of `MPI_BarrierOp`. Since it has one optional input argument and one optional result, I think it should support the following assembly formats:

```mlir
mpi.barrier
mpi.barrier : !mpi.retval
mpi.barrier(%comm) : !mpi.comm
mpi.barrier(%comm) : !mpi.comm -> !mpi.retval
```

But this makes parsing tricky: specifically, `->` must be parsed only when both optional parts are present. Also, the 2nd and 3rd forms look confusing, because in the 2nd the only type is the result type, while in the 3rd it is the operand type.
What are your opinions on this?
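For what it's worth, the four forms seem ambiguous only to a human reader, not to a parser: once the optional `(%comm)` group has (or has not) been parsed, the parser knows whether the first type after `:` is the communicator type or the result type, and `->` is only looked for after a communicator type. A minimal sketch of that decision logic, assuming a toy whitespace tokenizer rather than MLIR's `OpAsmParser` (`BarrierSyntax`, `tokenize` and `parseBarrier` are hypothetical names); in the dialect itself this would presumably live in a custom directive or a hand-written parser:

```cpp
#include <cstddef>
#include <optional>
#include <sstream>
#include <string>
#include <vector>

// Parsed pieces of the (simplified) textual form of `mpi.barrier`.
struct BarrierSyntax {
  std::optional<std::string> comm;       // SSA name of the communicator, if any
  std::optional<std::string> commType;   // type printed for the communicator
  std::optional<std::string> retvalType; // type of the optional !mpi.retval
};

// Toy tokenizer: pads `(`, `)` and `:` with spaces, then splits on whitespace.
std::vector<std::string> tokenize(const std::string &s) {
  std::string padded;
  for (char c : s) {
    if (c == '(' || c == ')' || c == ':') {
      padded += ' ';
      padded += c;
      padded += ' ';
    } else {
      padded += c;
    }
  }
  std::istringstream in(padded);
  std::vector<std::string> toks;
  std::string t;
  while (in >> t)
    toks.push_back(t);
  return toks;
}

std::optional<BarrierSyntax> parseBarrier(const std::string &text) {
  std::vector<std::string> toks = tokenize(text);
  std::size_t i = 0;
  auto peek = [&](const std::string &s) { return i < toks.size() && toks[i] == s; };
  if (!peek("mpi.barrier"))
    return std::nullopt;
  ++i;
  BarrierSyntax r;
  if (peek("(")) { // optional `(%comm)` group
    ++i;
    if (i >= toks.size())
      return std::nullopt;
    r.comm = toks[i++];
    if (!peek(")"))
      return std::nullopt;
    ++i;
  }
  if (i == toks.size()) {
    if (r.comm)
      return std::nullopt; // a communicator operand needs its type
    return r;              // plain `mpi.barrier`
  }
  if (!peek(":"))
    return std::nullopt;
  ++i;
  if (i >= toks.size())
    return std::nullopt;
  // Key point: whether `(%comm)` was seen decides what the first type means.
  if (r.comm) {
    r.commType = toks[i++];
    if (peek("->")) { // `->` is only meaningful after the communicator type
      ++i;
      if (i >= toks.size())
        return std::nullopt;
      r.retvalType = toks[i++];
    }
  } else {
    r.retvalType = toks[i++];
  }
  if (i != toks.size())
    return std::nullopt; // trailing tokens are an error
  return r;
}
```

So the open question is mostly about readability of forms 2 and 3, not about whether they can be parsed.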

3. Some ops now accept an optional `MPI_Comm` argument that they previously didn't, so their generated `build` functions have changed and this breaks the Mesh to MPI conversion. I guess we just need to add some more `create` functions or update that conversion, but I would appreciate a little help here since I'm not very familiar with the C++ side of MLIR.

```console
cmake --build . --target check-mlir
[24/348] Building CXX object tools/mlir/lib/Conversion/MeshToMPI/CMakeFiles/obj.MLIRMeshToMPI.dir/MeshToMPI.cpp.o
FAILED: tools/mlir/lib/Conversion/MeshToMPI/CMakeFiles/obj.MLIRMeshToMPI.dir/MeshToMPI.cpp.o 
/opt/homebrew/opt/llvm/bin/clang++ -DGTEST_HAS_RTTI=0 -D_DEBUG -D_GLIBCXX_ASSERTIONS -D__STDC_CONSTANT_MACROS -D__STDC_FORMAT_MACROS -D__STDC_LIMIT_MACROS -I/Users/mofeing/Developer/llvm-project/build/tools/mlir/lib/Conversion/MeshToMPI -I/Users/mofeing/Developer/llvm-project/mlir/lib/Conversion/MeshToMPI -I/Users/mofeing/Developer/llvm-project/build/tools/mlir/include -I/Users/mofeing/Developer/llvm-project/mlir/include -I/Users/mofeing/Developer/llvm-project/build/include -I/Users/mofeing/Developer/llvm-project/llvm/include -I/opt/homebrew/include -fPIC -fvisibility-inlines-hidden -Werror=date-time -Werror=unguarded-availability-new -Wall -Wextra -Wno-unused-parameter -Wwrite-strings -Wcast-qual -Wmissing-field-initializers -pedantic -Wno-long-long -Wc++98-compat-extra-semi -Wimplicit-fallthrough -Wcovered-switch-default -Wno-noexcept-type -Wnon-virtual-dtor -Wdelete-non-virtual-dtor -Wsuggest-override -Wstring-conversion -Wmisleading-indentation -Wctad-maybe-unsupported -fdiagnostics-color -Wundef -Werror=mismatched-tags -Werror=global-constructors -g -std=c++17 -arch arm64 -isysroot /Applications/Xcode.app/Contents/Developer/Platforms/MacOSX.platform/Developer/SDKs/MacOSX14.5.sdk -mmacosx-version-min=14.4  -fno-exceptions -funwind-tables -fno-rtti -gsplit-dwarf -MD -MT tools/mlir/lib/Conversion/MeshToMPI/CMakeFiles/obj.MLIRMeshToMPI.dir/MeshToMPI.cpp.o -MF tools/mlir/lib/Conversion/MeshToMPI/CMakeFiles/obj.MLIRMeshToMPI.dir/MeshToMPI.cpp.o.d -o tools/mlir/lib/Conversion/MeshToMPI/CMakeFiles/obj.MLIRMeshToMPI.dir/MeshToMPI.cpp.o -c /Users/mofeing/Developer/llvm-project/mlir/lib/Conversion/MeshToMPI/MeshToMPI.cpp
In file included from /Users/mofeing/Developer/llvm-project/mlir/lib/Conversion/MeshToMPI/MeshToMPI.cpp:15:
In file included from /Users/mofeing/Developer/llvm-project/mlir/include/mlir/Dialect/Arith/IR/Arith.h:19:
In file included from /Users/mofeing/Developer/llvm-project/mlir/include/mlir/Interfaces/InferTypeOpInterface.h:18:
/Users/mofeing/Developer/llvm-project/mlir/include/mlir/IR/Builders.h:507:5: error: no matching function for call to 'build'
  507 |     OpTy::build(*this, state, std::forward<Args>(args)...);
      |     ^~~~~~~~~~~
/Users/mofeing/Developer/llvm-project/mlir/lib/Conversion/MeshToMPI/MeshToMPI.cpp:141:14: note: in instantiation of function template specialization 'mlir::OpBuilder::create<mlir::mpi::CommRankOp, mlir::TypeRange>' requested here
  141 |             .create<mpi::CommRankOp>(
      |              ^
/Users/mofeing/Developer/llvm-project/build/tools/mlir/include/mlir/Dialect/MPI/IR/MPIOps.h.inc:559:15: note: candidate function not viable: requires 4 arguments, but 3 were provided
  559 |   static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::TypeRange resultTypes, /*optional*/::mlir::Value comm);
      |               ^     ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
/Users/mofeing/Developer/llvm-project/build/tools/mlir/include/mlir/Dialect/MPI/IR/MPIOps.h.inc:558:15: note: candidate function not viable: requires 5 arguments, but 3 were provided
  558 |   static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, /*optional*/::mlir::Type retval, ::mlir::Type rank, /*optional*/::mlir::Value comm);
      |               ^     ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
/Users/mofeing/Developer/llvm-project/build/tools/mlir/include/mlir/Dialect/MPI/IR/MPIOps.h.inc:560:15: note: candidate function not viable: requires at least 4 arguments, but 3 were provided
  560 |   static void build(::mlir::OpBuilder &, ::mlir::OperationState &odsState, ::mlir::TypeRange resultTypes, ::mlir::ValueRange operands, ::llvm::ArrayRef<::mlir::NamedAttribute> attributes = {});
      |               ^     ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
In file included from /Users/mofeing/Developer/llvm-project/mlir/lib/Conversion/MeshToMPI/MeshToMPI.cpp:15:
In file included from /Users/mofeing/Developer/llvm-project/mlir/include/mlir/Dialect/Arith/IR/Arith.h:19:
In file included from /Users/mofeing/Developer/llvm-project/mlir/include/mlir/Interfaces/InferTypeOpInterface.h:18:
/Users/mofeing/Developer/llvm-project/mlir/include/mlir/IR/Builders.h:507:5: error: no matching function for call to 'build'
  507 |     OpTy::build(*this, state, std::forward<Args>(args)...);
      |     ^~~~~~~~~~~
/Users/mofeing/Developer/llvm-project/mlir/lib/Conversion/MeshToMPI/MeshToMPI.cpp:375:23: note: in instantiation of function template specialization 'mlir::OpBuilder::create<mlir::mpi::SendOp, mlir::TypeRange, mlir::memref::AllocOp &, mlir::arith::ConstantOp &, mlir::Value &>' requested here
  375 |               builder.create<mpi::SendOp>(loc, TypeRange{}, buffer, tag, to);
      |                       ^
/Users/mofeing/Developer/llvm-project/build/tools/mlir/include/mlir/Dialect/MPI/IR/MPIOps.h.inc:2232:15: note: candidate function not viable: requires at most 5 arguments, but 6 were provided
 2232 |   static void build(::mlir::OpBuilder &, ::mlir::OperationState &odsState, ::mlir::TypeRange resultTypes, ::mlir::ValueRange operands, ::llvm::ArrayRef<::mlir::NamedAttribute> attributes = {});
      |               ^     ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
/Users/mofeing/Developer/llvm-project/build/tools/mlir/include/mlir/Dialect/MPI/IR/MPIOps.h.inc:2230:15: note: candidate function not viable: requires 7 arguments, but 6 were provided
 2230 |   static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, /*optional*/::mlir::Type retval, ::mlir::Value ref, ::mlir::Value tag, ::mlir::Value rank, /*optional*/::mlir::Value comm);
      |               ^     ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
/Users/mofeing/Developer/llvm-project/build/tools/mlir/include/mlir/Dialect/MPI/IR/MPIOps.h.inc:2231:15: note: candidate function not viable: requires 7 arguments, but 6 were provided
 2231 |   static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::TypeRange resultTypes, ::mlir::Value ref, ::mlir::Value tag, ::mlir::Value rank, /*optional*/::mlir::Value comm);
      |               ^     ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
In file included from /Users/mofeing/Developer/llvm-project/mlir/lib/Conversion/MeshToMPI/MeshToMPI.cpp:15:
In file included from /Users/mofeing/Developer/llvm-project/mlir/include/mlir/Dialect/Arith/IR/Arith.h:19:
In file included from /Users/mofeing/Developer/llvm-project/mlir/include/mlir/Interfaces/InferTypeOpInterface.h:18:
/Users/mofeing/Developer/llvm-project/mlir/include/mlir/IR/Builders.h:507:5: error: no matching function for call to 'build'
  507 |     OpTy::build(*this, state, std::forward<Args>(args)...);
      |     ^~~~~~~~~~~
/Users/mofeing/Developer/llvm-project/mlir/lib/Conversion/MeshToMPI/MeshToMPI.cpp:383:23: note: in instantiation of function template specialization 'mlir::OpBuilder::create<mlir::mpi::RecvOp, mlir::TypeRange, mlir::memref::AllocOp &, mlir::arith::ConstantOp &, mlir::Value &>' requested here
  383 |               builder.create<mpi::RecvOp>(loc, TypeRange{}, buffer, tag, from);
      |                       ^
/Users/mofeing/Developer/llvm-project/build/tools/mlir/include/mlir/Dialect/MPI/IR/MPIOps.h.inc:1855:15: note: candidate function not viable: requires at most 5 arguments, but 6 were provided
 1855 |   static void build(::mlir::OpBuilder &, ::mlir::OperationState &odsState, ::mlir::TypeRange resultTypes, ::mlir::ValueRange operands, ::llvm::ArrayRef<::mlir::NamedAttribute> attributes = {});
      |               ^     ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
/Users/mofeing/Developer/llvm-project/build/tools/mlir/include/mlir/Dialect/MPI/IR/MPIOps.h.inc:1853:15: note: candidate function not viable: requires 7 arguments, but 6 were provided
 1853 |   static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, /*optional*/::mlir::Type retval, ::mlir::Value ref, ::mlir::Value tag, ::mlir::Value rank, /*optional*/::mlir::Value comm);
      |               ^     ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
/Users/mofeing/Developer/llvm-project/build/tools/mlir/include/mlir/Dialect/MPI/IR/MPIOps.h.inc:1854:15: note: candidate function not viable: requires 7 arguments, but 6 were provided
 1854 |   static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::TypeRange resultTypes, ::mlir::Value ref, ::mlir::Value tag, ::mlir::Value rank, /*optional*/::mlir::Value comm);
      |               ^     ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
3 errors generated.
[33/348] Building CXX object tools/mlir/test/lib/Dialect/Test/CMakeFiles/MLIRTestDialect.dir/TestOpsSyntax.cpp.o
ninja: build stopped: subcommand failed.
```
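All three errors appear to have the same cause: with `$comm` declared as `Optional<MPI_Comm>`, ODS emits builders that take the communicator as a trailing `Value` parameter without a default (see the `/*optional*/::mlir::Value comm` candidates above), so existing call sites that pass one argument fewer no longer match. Two possible fixes are to pass an empty `Value()` at each call site, or to add an extra builder whose communicator parameter defaults to `{}`. The plain C++ model below shows both; `Value` and `OperationState` are minimal stand-ins, not the real MLIR classes, and `buildSendGenerated`/`buildSendWithDefault` are hypothetical names.

```cpp
#include <string>
#include <vector>

// Minimal stand-in for mlir::Value: a null name means "no value".
struct Value {
  const char *name = nullptr;
  explicit operator bool() const { return name != nullptr; }
};

// Minimal stand-in for mlir::OperationState: records the operands added.
struct OperationState {
  std::vector<std::string> operands;
};

// Shape of the builder ODS generates once `$comm` is Optional<MPI_Comm>:
// the optional operand is a trailing C++ parameter *without* a default.
void buildSendGenerated(OperationState &state, Value ref, Value tag, Value rank,
                        Value comm) {
  for (Value v : {ref, tag, rank})
    state.operands.push_back(v.name);
  if (comm) // an optional operand is only added when it is set
    state.operands.push_back(comm.name);
}

// Hypothetical extra builder with a defaulted communicator, so call sites
// written before `$comm` existed keep compiling unchanged.
void buildSendWithDefault(OperationState &state, Value ref, Value tag,
                          Value rank, Value comm = {}) {
  buildSendGenerated(state, ref, tag, rank, comm);
}
```

With the first option, the failing call in `MeshToMPI.cpp` would presumably become `builder.create<mpi::SendOp>(loc, TypeRange{}, buffer, tag, to, Value());`, which matches the 7-argument `build` candidate listed in the log; the `RecvOp` and `CommRankOp` calls would get the same trailing `Value()`. The second option keeps old call sites source-compatible; in ODS it could likely be declared in the op's `builders` list using `CArg<"::mlir::Value", "{}">:$comm`, though I haven't checked that against this patch.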

>From b5b500a2b6f5db41d912ce5a5688bbb66ccb7825 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Sergio=20S=C3=A1nchez=20Ram=C3=ADrez?=
 <sergio.sanchez.ramirez+git at bsc.es>
Date: Sat, 1 Feb 2025 20:43:30 +0100
Subject: [PATCH 1/6] Revert "Remove MPI_Comm type"

This reverts commit 6abba5a37d5ea73c2b177581db9d476da4a26c91.
---
 mlir/include/mlir/Dialect/MPI/IR/MPIOps.td   | 130 +++++++++++++------
 mlir/include/mlir/Dialect/MPI/IR/MPITypes.td |  11 ++
 mlir/test/Dialect/MPI/ops.mlir               |  24 ++++
 3 files changed, 126 insertions(+), 39 deletions(-)

diff --git a/mlir/include/mlir/Dialect/MPI/IR/MPIOps.td b/mlir/include/mlir/Dialect/MPI/IR/MPIOps.td
index 284ba72af9768b..b36f15a4385bc1 100644
--- a/mlir/include/mlir/Dialect/MPI/IR/MPIOps.td
+++ b/mlir/include/mlir/Dialect/MPI/IR/MPIOps.td
@@ -37,26 +37,43 @@ def MPI_InitOp : MPI_Op<"init", []> {
   let assemblyFormat = "attr-dict (`:` type($retval)^)?";
 }
 
+//===----------------------------------------------------------------------===//
+// CommWorldOp
+//===----------------------------------------------------------------------===//
+
+def MPI_CommWorldOp : MPI_Op<"comm_world", []> {
+  let summary = "Get the World communicator, equivalent to `MPI_COMM_WORLD`";
+  let description = [{
+    This operation returns the predefined MPI_COMM_WORLD communicator.
+  }];
+
+  let results = (outs MPI_Comm : $comm);
+
+  let assemblyFormat = "attr-dict `:` type(results)";
+}
+
 //===----------------------------------------------------------------------===//
 // CommRankOp
 //===----------------------------------------------------------------------===//
 
 def MPI_CommRankOp : MPI_Op<"comm_rank", []> {
   let summary = "Get the current rank, equivalent to "
-                "`MPI_Comm_rank(MPI_COMM_WORLD, &rank)`";
+                "`MPI_Comm_rank(comm, &rank)`";
   let description = [{
-    Communicators other than `MPI_COMM_WORLD` are not supported for now.
+    If communicator is not specified, `MPI_COMM_WORLD` is used by default.
 
     This operation can optionally return an `!mpi.retval` value that can be used
     to check for errors.
   }];
 
+  let arguments = (ins Optional<MPI_Comm> : $comm);
+
   let results = (
     outs Optional<MPI_Retval> : $retval,
     I32 : $rank
   );
 
-  let assemblyFormat = "attr-dict `:` type(results)";
+  let assemblyFormat = "(`(` $comm ^ `)`)? attr-dict `:` type(results)";
 }
 
 //===----------------------------------------------------------------------===//
@@ -65,20 +82,50 @@ def MPI_CommRankOp : MPI_Op<"comm_rank", []> {
 
 def MPI_CommSizeOp : MPI_Op<"comm_size", []> {
   let summary = "Get the size of the group associated to the communicator, "
-                "equivalent to `MPI_Comm_size(MPI_COMM_WORLD, &size)`";
+                "equivalent to `MPI_Comm_size(comm, &size)`";
   let description = [{
-    Communicators other than `MPI_COMM_WORLD` are not supported for now.
+    If communicator is not specified, `MPI_COMM_WORLD` is used by default.
 
     This operation can optionally return an `!mpi.retval` value that can be used
     to check for errors.
   }];
 
+  let arguments = (ins Optional<MPI_Comm> : $comm);
+
   let results = (
     outs Optional<MPI_Retval> : $retval,
     I32 : $size
   );
 
-  let assemblyFormat = "attr-dict `:` type(results)";
+  let assemblyFormat = "(`(` $comm ^ `)`)? attr-dict `:` type(results)";
+}
+
+//===----------------------------------------------------------------------===//
+// CommSplitOp
+//===----------------------------------------------------------------------===//
+
+def MPI_CommSplit : MPI_Op<"comm_split", []> {
+  let summary = "Partition the group associated to the given communicator into "
+                "disjoint subgroups";
+  let description = [{
+    This operation splits the communicator into multiple sub-communicators.
+    The color value determines the group of processes that will be part of the
+    new communicator. The key value determines the rank of the calling process
+    in the new communicator.
+
+    This operation can optionally return an `!mpi.retval` value that can be used
+    to check for errors.
+  }];
+
+  let arguments = (ins MPI_Comm : $comm, I32 : $color, I32 : $key);
+
+  let results = (
+    outs MPI_Comm : $newcomm,
+    Optional<MPI_Retval> : $retval
+  );
+
+  let assemblyFormat = "`(` $comm `,` $color `,` $key `)` attr-dict `:` "
+                       "type(results)";
 }
 
 //===----------------------------------------------------------------------===//
@@ -87,13 +134,13 @@ def MPI_CommSizeOp : MPI_Op<"comm_size", []> {
 
 def MPI_SendOp : MPI_Op<"send", []> {
   let summary =
-      "Equivalent to `MPI_Send(ptr, size, dtype, dest, tag, MPI_COMM_WORLD)`";
+      "Equivalent to `MPI_Send(ptr, size, dtype, dest, tag, comm)`";
   let description = [{
     MPI_Send performs a blocking send of `size` elements of type `dtype` to rank
     `dest`. The `tag` value and communicator enables the library to determine 
     the matching of multiple sends and receives between the same ranks.
 
-    Communicators other than `MPI_COMM_WORLD` are not supported for now.
+    If communicator is not specified, `MPI_COMM_WORLD` is used by default.
 
     This operation can optionally return an `!mpi.retval` value that can be used
     to check for errors.
@@ -102,12 +149,13 @@ def MPI_SendOp : MPI_Op<"send", []> {
   let arguments = (
     ins AnyMemRef : $ref,
     I32 : $tag,
-    I32 : $rank
+    I32 : $rank,
+    Optional<MPI_Comm> : $comm
   );
 
   let results = (outs Optional<MPI_Retval>:$retval);
 
-  let assemblyFormat = "`(` $ref `,` $tag `,` $rank `)` attr-dict `:` "
+  let assemblyFormat = "`(` $ref `,` $tag `,` $rank (`,` $comm ^)? `)` attr-dict `:` "
                        "type($ref) `,` type($tag) `,` type($rank)"
                        "(`->` type($retval)^)?";
   let hasCanonicalizer = 1;
@@ -119,14 +167,14 @@ def MPI_SendOp : MPI_Op<"send", []> {
 
 def MPI_ISendOp : MPI_Op<"isend", []> {
   let summary =
-      "Equivalent to `MPI_Isend(ptr, size, dtype, dest, tag, MPI_COMM_WORLD)`";
+      "Equivalent to `MPI_Isend(ptr, size, dtype, dest, tag, comm)`";
   let description = [{
     MPI_Isend begins a non-blocking send of `size` elements of type `dtype` to
     rank `dest`. The `tag` value and communicator enables the library to
     determine the matching of multiple sends and receives between the same
     ranks.
 
-    Communicators other than `MPI_COMM_WORLD` are not supported for now.
+    If communicator is not specified, `MPI_COMM_WORLD` is used by default.
 
     This operation can optionally return an `!mpi.retval` value that can be used
     to check for errors.
@@ -135,7 +183,8 @@ def MPI_ISendOp : MPI_Op<"isend", []> {
   let arguments = (
     ins AnyMemRef : $ref,
     I32 : $tag,
-    I32 : $rank
+    I32 : $rank,
+    Optional<MPI_Comm> : $comm
   );
 
   let results = (
@@ -143,9 +192,9 @@ def MPI_ISendOp : MPI_Op<"isend", []> {
     MPI_Request : $req
   );
 
-  let assemblyFormat = "`(` $ref `,` $tag `,` $rank `)` attr-dict "
+  let assemblyFormat = "`(` $ref `,` $tag `,` $rank (`,` $comm ^)?`)` attr-dict "
                        "`:` type($ref) `,` type($tag) `,` type($rank) "
-                       "`->` type(results)";
+                       "(`,` type($comm) ^)? `->` type(results)";
   let hasCanonicalizer = 1;
 }
 
@@ -155,14 +204,14 @@ def MPI_ISendOp : MPI_Op<"isend", []> {
 
 def MPI_RecvOp : MPI_Op<"recv", []> {
   let summary = "Equivalent to `MPI_Recv(ptr, size, dtype, dest, tag, "
-                "MPI_COMM_WORLD, MPI_STATUS_IGNORE)`";
+                "comm, MPI_STATUS_IGNORE)`";
   let description = [{
     MPI_Recv performs a blocking receive of `size` elements of type `dtype` 
     from rank `dest`. The `tag` value and communicator enables the library to 
     determine the matching of multiple sends and receives between the same 
     ranks.
 
-    Communicators other than `MPI_COMM_WORLD` are not supported for now.
+    If communicator is not specified, `MPI_COMM_WORLD` is used by default.
     The MPI_Status is set to `MPI_STATUS_IGNORE`, as the status object 
     is not yet ported to MLIR.
 
@@ -172,14 +221,15 @@ def MPI_RecvOp : MPI_Op<"recv", []> {
 
   let arguments = (
     ins AnyMemRef : $ref,
-    I32 : $tag, I32 : $rank
+    I32 : $tag, I32 : $rank,
+    Optional<MPI_Comm> : $comm
   );
 
   let results = (outs Optional<MPI_Retval>:$retval);
 
-  let assemblyFormat = "`(` $ref `,` $tag `,` $rank `)` attr-dict `:`"
-                       "type($ref) `,` type($tag) `,` type($rank)"
-                       "(`->` type($retval)^)?";
+  let assemblyFormat = "`(` $ref `,` $tag `,` $rank (`,` $comm ^)?`)` attr-dict"
+                       " `:` type($ref) `,` type($tag) `,` type($rank) "
+                       "(`,` type($comm) ^)? (`->` type($retval)^)?";
   let hasCanonicalizer = 1;
 }
 
@@ -189,14 +239,14 @@ def MPI_RecvOp : MPI_Op<"recv", []> {
 
 def MPI_IRecvOp : MPI_Op<"irecv", []> {
   let summary = "Equivalent to `MPI_Irecv(ptr, size, dtype, dest, tag, "
-                "MPI_COMM_WORLD, &req)`";
+                "comm, &req)`";
   let description = [{
     MPI_Irecv begins a non-blocking receive of `size` elements of type `dtype` 
     from rank `dest`. The `tag` value and communicator enables the library to 
     determine the matching of multiple sends and receives between the same 
     ranks.
 
-    Communicators other than `MPI_COMM_WORLD` are not supported for now.
+    If communicator is not specified, `MPI_COMM_WORLD` is used by default.
 
     This operation can optionally return an `!mpi.retval` value that can be used
     to check for errors.
@@ -205,7 +255,8 @@ def MPI_IRecvOp : MPI_Op<"irecv", []> {
   let arguments = (
     ins AnyMemRef : $ref,
     I32 : $tag,
-    I32 : $rank
+    I32 : $rank,
+    Optional<MPI_Comm> : $comm
   );
 
   let results = (
@@ -213,9 +264,9 @@ def MPI_IRecvOp : MPI_Op<"irecv", []> {
     MPI_Request : $req
   );
 
-  let assemblyFormat = "`(` $ref `,` $tag `,` $rank `)` attr-dict `:`"
-                       "type($ref) `,` type($tag) `,` type($rank) `->`"
-                       "type(results)";
+  let assemblyFormat = "`(` $ref `,` $tag `,` $rank (`,` $comm ^)?`)` attr-dict "
+                       "`:` type($ref) `,` type($tag) `,` type($rank)"
+                       "(`,` type($comm) ^)? `->` type(results)";
   let hasCanonicalizer = 1;
 }
 
@@ -224,8 +275,7 @@ def MPI_IRecvOp : MPI_Op<"irecv", []> {
 //===----------------------------------------------------------------------===//
 
 def MPI_AllReduceOp : MPI_Op<"allreduce", []> {
-  let summary = "Equivalent to `MPI_Allreduce(sendbuf, recvbuf, op, "
-                "MPI_COMM_WORLD)`";
+  let summary = "Equivalent to `MPI_Allreduce(sendbuf, recvbuf, op, comm)`";
   let description = [{
     MPI_Allreduce performs a reduction operation on the values in the sendbuf
     array and stores the result in the recvbuf array. The operation is 
@@ -235,7 +285,7 @@ def MPI_AllReduceOp : MPI_Op<"allreduce", []> {
     Currently only the `MPI_Op` predefined in the standard (e.g. `MPI_SUM`) are
     supported.
 
-    Communicators other than `MPI_COMM_WORLD` are not supported for now.
+    If communicator is not specified, `MPI_COMM_WORLD` is used by default.
 
     This operation can optionally return an `!mpi.retval` value that can be used
     to check for errors.
@@ -244,14 +294,15 @@ def MPI_AllReduceOp : MPI_Op<"allreduce", []> {
   let arguments = (
     ins AnyMemRef : $sendbuf,
     AnyMemRef : $recvbuf,
-    MPI_OpClassAttr : $op
+    MPI_OpClassAttr : $op,
+    Optional<MPI_Comm> : $comm
   );
 
   let results = (outs Optional<MPI_Retval>:$retval);
 
-  let assemblyFormat = "`(` $sendbuf `,` $recvbuf `,` $op `)` attr-dict `:`"
-                       "type($sendbuf) `,` type($recvbuf)"
-                       "(`->` type($retval)^)?";
+  let assemblyFormat = "`(` $sendbuf `,` $recvbuf `,` $op (`,` $comm ^)?`)` "
+                       "attr-dict `:` type($sendbuf) `,` type($recvbuf) "
+                       "(`,` type($comm) ^)? (`->` type($retval)^)?";
 }
 
 //===----------------------------------------------------------------------===//
@@ -259,20 +310,22 @@ def MPI_AllReduceOp : MPI_Op<"allreduce", []> {
 //===----------------------------------------------------------------------===//
 
 def MPI_Barrier : MPI_Op<"barrier", []> {
-  let summary = "Equivalent to `MPI_Barrier(MPI_COMM_WORLD)`";
+  let summary = "Equivalent to `MPI_Barrier(comm)`";
   let description = [{
     MPI_Barrier blocks execution until all processes in the communicator have
     reached this routine.
 
-    Communicators other than `MPI_COMM_WORLD` are not supported for now.
+    If communicator is not specified, `MPI_COMM_WORLD` is used by default.
 
     This operation can optionally return an `!mpi.retval` value that can be used
     to check for errors.
   }];
 
+  let arguments = (ins Optional<MPI_Comm> : $comm);
+
   let results = (outs Optional<MPI_Retval>:$retval);
 
-  let assemblyFormat = "attr-dict (`:` type($retval) ^)?";
+  let assemblyFormat = "(`(` $comm ^ `)`)? attr-dict (`:` type($retval) ^)?";
 }
 
 //===----------------------------------------------------------------------===//
@@ -295,8 +348,7 @@ def MPI_Wait : MPI_Op<"wait", []> {
 
   let results = (outs Optional<MPI_Retval>:$retval);
 
-  let assemblyFormat = "`(` $req `)` attr-dict `:` type($req) "
-                       "(`->` type($retval) ^)?";
+  let assemblyFormat = "`(` $req `)` attr-dict `:` type($req) (`->` type($retval) ^)?";
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/MPI/IR/MPITypes.td b/mlir/include/mlir/Dialect/MPI/IR/MPITypes.td
index fafea0eac8bb74..868132a62abc4b 100644
--- a/mlir/include/mlir/Dialect/MPI/IR/MPITypes.td
+++ b/mlir/include/mlir/Dialect/MPI/IR/MPITypes.td
@@ -40,6 +40,17 @@ def MPI_Retval : MPI_Type<"Retval", "retval"> {
   }];
 }
 
+//===----------------------------------------------------------------------===//
+// mpi::CommType
+//===----------------------------------------------------------------------===//
+
+def MPI_Comm : MPI_Type<"Comm", "comm"> {
+  let summary = "MPI communicator handler";
+  let description = [{
+    This type represents a handler to the MPI communicator.
+  }];
+}
+
 //===----------------------------------------------------------------------===//
 // mpi::RequestType
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/MPI/ops.mlir b/mlir/test/Dialect/MPI/ops.mlir
index f23a7e18a2ee97..f5bdb86be94c47 100644
--- a/mlir/test/Dialect/MPI/ops.mlir
+++ b/mlir/test/Dialect/MPI/ops.mlir
@@ -12,30 +12,48 @@ func.func @mpi_test(%ref : memref<100xf32>) -> () {
     // CHECK-NEXT: %retval_0, %size = mpi.comm_size : !mpi.retval, i32
     %retval_0, %size = mpi.comm_size : !mpi.retval, i32
 
+    // CHECK-NEXT: %comm = mpi.comm_world : !mpi.comm
+    %comm = mpi.comm_world : !mpi.comm
+
+    // CHECK-NEXT: %new_comm, %retval3 = mpi.comm_split(%comm, %rank, %rank) : i32, !mpi.retval
+    %new_comm, %retval3 = mpi.comm_split(%comm, %rank, %rank) : mpi.comm, i32, i32
+
     // CHECK-NEXT: mpi.send(%arg0, %rank, %rank) : memref<100xf32>, i32, i32
     mpi.send(%ref, %rank, %rank) : memref<100xf32>, i32, i32
 
     // CHECK-NEXT: %1 = mpi.send(%arg0, %rank, %rank) : memref<100xf32>, i32, i32 -> !mpi.retval
     %err2 = mpi.send(%ref, %rank, %rank) : memref<100xf32>, i32, i32 -> !mpi.retval
 
+    // CHECK-NEXT: mpi.send(%arg0, %rank, %rank, %comm) : memref<100xf32>, i32, i32, !mpi.comm
+    mpi.send(%ref, %rank, %rank, %comm) : memref<100xf32>, i32, i32, !mpi.comm
+
     // CHECK-NEXT: mpi.recv(%arg0, %rank, %rank) : memref<100xf32>, i32, i32
     mpi.recv(%ref, %rank, %rank) : memref<100xf32>, i32, i32
 
     // CHECK-NEXT: %2 = mpi.recv(%arg0, %rank, %rank) : memref<100xf32>, i32, i32 -> !mpi.retval
     %err3 = mpi.recv(%ref, %rank, %rank) : memref<100xf32>, i32, i32 -> !mpi.retval
 
+    // CHECK-NEXT: mpi.recv(%arg0, %rank, %rank, %comm) : memref<100xf32>, i32, i32, !mpi.comm
+    mpi.recv(%ref, %rank, %rank, %comm) : memref<100xf32>, i32, i32, !mpi.comm
+
     // CHECK-NEXT: %req = mpi.isend(%arg0, %rank, %rank) : memref<100xf32>, i32, i32 -> !mpi.request
     %req = mpi.isend(%ref, %rank, %rank) : memref<100xf32>, i32, i32 -> !mpi.request
 
     // CHECK-NEXT: %retval_1, %req_2 = mpi.isend(%arg0, %rank, %rank) : memref<100xf32>, i32, i32 -> !mpi.retval, !mpi.request
     %err4, %req2 = mpi.isend(%ref, %rank, %rank) : memref<100xf32>, i32, i32 -> !mpi.retval, !mpi.request
 
+    // CHECK-NEXT: %3 = mpi.isend(%arg0, %rank, %rank, %comm) : memref<100xf32>, i32, i32, mpi.comm -> !mpi.request
+    %req1 = mpi.isend(%ref, %rank, %rank, %comm) : memref<100xf32>, i32, i32, !mpi.comm -> !mpi.request
+
     // CHECK-NEXT: %req_3 = mpi.irecv(%arg0, %rank, %rank) : memref<100xf32>, i32, i32 -> !mpi.request
     %req3 = mpi.irecv(%ref, %rank, %rank) : memref<100xf32>, i32, i32 -> !mpi.request
 
     // CHECK-NEXT: %retval_4, %req_5 = mpi.irecv(%arg0, %rank, %rank) : memref<100xf32>, i32, i32 -> !mpi.retval, !mpi.request
     %err5, %req4 = mpi.irecv(%ref, %rank, %rank) : memref<100xf32>, i32, i32 -> !mpi.retval, !mpi.request
 
+    // CHECK-NEXT: %6 = mpi.irecv(%arg0, %rank, %rank, %comm) : memref<100xf32>, i32, i32, mpi.comm -> mpi.request
+    %req3 = mpi.irecv(%ref, %rank, %rank, %comm) : memref<100xf32>, i32, i32, !mpi.comm -> !mpi.request
+
     // CHECK-NEXT: mpi.wait(%req) : !mpi.request
     mpi.wait(%req) : !mpi.request
 
@@ -48,12 +66,18 @@ func.func @mpi_test(%ref : memref<100xf32>) -> () {
     // CHECK-NEXT: %5 = mpi.barrier : !mpi.retval
     %err7 = mpi.barrier : !mpi.retval
 
+    // CHECK-NEXT: mpi.barrier(%comm) : !mpi.retval
+    mpi.barrier(%comm) : !mpi.retval
+
     // CHECK-NEXT: mpi.allreduce(%arg0, %arg0, <MPI_SUM>) : memref<100xf32>, memref<100xf32>
     mpi.allreduce(%ref, %ref, <MPI_SUM>) : memref<100xf32>, memref<100xf32>
 
     // CHECK-NEXT: mpi.allreduce(%arg0, %arg0, <MPI_SUM>) : memref<100xf32>, memref<100xf32> -> !mpi.retval
     %err8 = mpi.allreduce(%ref, %ref, <MPI_SUM>) : memref<100xf32>, memref<100xf32> -> !mpi.retval
 
+    // CHECK-NEXT: mpi.allreduce(%arg0, %arg0, MPI_SUM, %comm) : memref<100xf32>, memref<100xf32>, !mpi.comm
+    mpi.allreduce(%ref, %ref, MPI_SUM, %comm) : memref<100xf32>, memref<100xf32>, !mpi.comm
+
     // CHECK-NEXT: %7 = mpi.finalize : !mpi.retval
     %rval = mpi.finalize : !mpi.retval
 

>From 0d677f38bb3f75492085af47bee29b9ab1a6f349 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Sergio=20S=C3=A1nchez=20Ram=C3=ADrez?=
 <sergio.sanchez.ramirez+git at bsc.es>
Date: Sat, 1 Feb 2025 20:52:39 +0100
Subject: [PATCH 2/6] Fix assembly format for `comm_size`, `comm_rank`

---
 mlir/include/mlir/Dialect/MPI/IR/MPIOps.td | 6 ++++--
 1 file changed, 4 insertions(+), 2 deletions(-)

diff --git a/mlir/include/mlir/Dialect/MPI/IR/MPIOps.td b/mlir/include/mlir/Dialect/MPI/IR/MPIOps.td
index b36f15a4385bc1..dcf46192a43f83 100644
--- a/mlir/include/mlir/Dialect/MPI/IR/MPIOps.td
+++ b/mlir/include/mlir/Dialect/MPI/IR/MPIOps.td
@@ -73,7 +73,8 @@ def MPI_CommRankOp : MPI_Op<"comm_rank", []> {
     I32 : $rank
   );
 
-  let assemblyFormat = "(`(` $comm ^ `)`)? attr-dict `:` type(results)";
+  let assemblyFormat = "(`(` $comm ^ `)`)? attr-dict `:` (type($comm) ^ `->`)?"
+                       "type(results)";
 }
 
 //===----------------------------------------------------------------------===//
@@ -97,7 +98,8 @@ def MPI_CommSizeOp : MPI_Op<"comm_size", []> {
     I32 : $size
   );
 
-  let assemblyFormat = "(`(` $comm ^ `)`)? attr-dict `:` type(results)";
+  let assemblyFormat = "(`(` $comm ^ `)`)? attr-dict `:` (type($comm) ^ `->`)?"
+                       "type(results)";
 }
 
 //===----------------------------------------------------------------------===//

From f80e7b1cb101a35f23ae5491515431cb00ce2e99 Mon Sep 17 00:00:00 2001
From: Sergio Sánchez Ramírez
 <sergio.sanchez.ramirez+git at bsc.es>
Date: Sat, 1 Feb 2025 20:53:03 +0100
Subject: [PATCH 3/6] add more tests for `comm_size`, `comm_rank`

---
 mlir/test/Dialect/MPI/ops.mlir | 22 ++++++++++++++++++++--
 1 file changed, 20 insertions(+), 2 deletions(-)

diff --git a/mlir/test/Dialect/MPI/ops.mlir b/mlir/test/Dialect/MPI/ops.mlir
index f5bdb86be94c47..d7521353b34a14 100644
--- a/mlir/test/Dialect/MPI/ops.mlir
+++ b/mlir/test/Dialect/MPI/ops.mlir
@@ -6,14 +6,32 @@ func.func @mpi_test(%ref : memref<100xf32>) -> () {
     // CHECK: %0 = mpi.init : !mpi.retval
     %err = mpi.init : !mpi.retval
 
+    // CHECK-NEXT: %comm = mpi.comm_world : !mpi.comm
+    %comm = mpi.comm_world : !mpi.comm
+
+    // CHECK-NEXT: %rank = mpi.comm_rank : i32
+    %rank = mpi.comm_rank : i32
+
     // CHECK-NEXT: %retval, %rank = mpi.comm_rank : !mpi.retval, i32
     %retval, %rank = mpi.comm_rank : !mpi.retval, i32
 
+    // CHECK-NEXT: %retval, %rank = mpi.comm_rank : !mpi.comm -> i32
+    %rank = mpi.comm_rank(%comm) : !mpi.comm -> i32
+
+    // CHECK-NEXT: %retval, %rank = mpi.comm_rank : !mpi.comm -> !mpi.retval, i32
+    %retval, %rank = mpi.comm_rank(%comm) : !mpi.comm -> !mpi.retval, i32
+
+    // CHECK-NEXT: %size = mpi.comm_size : i32
+    %size = mpi.comm_size : i32
+
     // CHECK-NEXT: %retval_0, %size = mpi.comm_size : !mpi.retval, i32
     %retval_0, %size = mpi.comm_size : !mpi.retval, i32
 
-    // CHECK-NEXT: %comm = mpi.comm_world : !mpi.comm
-    %comm = mpi.comm_world : !mpi.comm
+    // CHECK-NEXT: %size = mpi.comm_size : !mpi.comm -> i32
+    %size = mpi.comm_size(%comm) : !mpi.comm -> i32
+
+    // CHECK-NEXT: %retval_0, %size = mpi.comm_size : !mpi.retval, i32
+    %retval_0, %size = mpi.comm_size(%comm) : !mpi.comm -> !mpi.retval, i32
 
     // CHECK-NEXT: %new_comm, %retval3 = mpi.comm_split(%comm, %rank, %rank) : i32, !mpi.retval
     %new_comm, %retval3 = mpi.comm_split(%comm, %rank, %rank) : mpi.comm, i32, i32
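
MLIR's parser rejects a second definition of the same SSA name within a region, and the tests added in this patch reuse names such as `%rank` and `%size`. The hypothetical Python helper below (`redefined_ssa_names`, not part of any LLVM tooling) lists such names; its regex is a simplification that skips comment lines and ignores result-group syntax like `%x:2`.

```python
import re
from collections import Counter

# Matches the result list left of `=`, e.g. "%retval, %rank = ...".
# Simplified: result groups such as "%x:2 = ..." are not recognized.
DEF_RE = re.compile(r"^\s*(%[\w$.-]+(?:\s*,\s*%[\w$.-]+)*)\s*=")


def redefined_ssa_names(lines):
    """Return the sorted SSA names that are defined more than once."""
    counts = Counter()
    for line in lines:
        if line.lstrip().startswith("//"):
            continue  # skip FileCheck directives and other comments
        m = DEF_RE.match(line)
        if m:
            for name in m.group(1).split(","):
                counts[name.strip()] += 1
    return sorted(n for n, c in counts.items() if c > 1)


# Input (non-CHECK) lines from the ops.mlir hunk above.
patch3_inputs = [
    "%comm = mpi.comm_world : !mpi.comm",
    "%rank = mpi.comm_rank : i32",
    "%retval, %rank = mpi.comm_rank : !mpi.retval, i32",
    "%rank = mpi.comm_rank(%comm) : !mpi.comm -> i32",
    "%retval, %rank = mpi.comm_rank(%comm) : !mpi.comm -> !mpi.retval, i32",
    "%size = mpi.comm_size : i32",
    "%retval_0, %size = mpi.comm_size : !mpi.retval, i32",
    "%size = mpi.comm_size(%comm) : !mpi.comm -> i32",
    "%retval_0, %size = mpi.comm_size(%comm) : !mpi.comm -> !mpi.retval, i32",
]
print(redefined_ssa_names(patch3_inputs))
```

Giving each variant its own name (for example `%rank_0`, `%rank_1`) would avoid these redefinitions; the printer then picks its own names, which the CHECK lines would have to follow.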

From 0ebf945c8566d8bd48a59ed0382193e28467bd72 Mon Sep 17 00:00:00 2001
From: Sergio Sánchez Ramírez
 <sergio.sanchez.ramirez+git at bsc.es>
Date: Sat, 1 Feb 2025 22:19:45 +0100
Subject: [PATCH 4/6] fix some assembly formats

---
 mlir/include/mlir/Dialect/MPI/IR/MPIOps.td | 13 +++++++------
 1 file changed, 7 insertions(+), 6 deletions(-)

diff --git a/mlir/include/mlir/Dialect/MPI/IR/MPIOps.td b/mlir/include/mlir/Dialect/MPI/IR/MPIOps.td
index dcf46192a43f83..16672bf12a7d21 100644
--- a/mlir/include/mlir/Dialect/MPI/IR/MPIOps.td
+++ b/mlir/include/mlir/Dialect/MPI/IR/MPIOps.td
@@ -73,8 +73,8 @@ def MPI_CommRankOp : MPI_Op<"comm_rank", []> {
     I32 : $rank
   );
 
-  let assemblyFormat = "(`(` $comm ^ `)`)? attr-dict `:` (type($comm) ^ `->`)?"
-                       "type(results)";
+  let assemblyFormat = "(`(` $comm ^ `)`)? attr-dict (`:` type($comm) ^ `->`):"
+                       "(`:`)? type(results)";
 }
 
 //===----------------------------------------------------------------------===//
@@ -98,8 +98,8 @@ def MPI_CommSizeOp : MPI_Op<"comm_size", []> {
     I32 : $size
   );
 
-  let assemblyFormat = "(`(` $comm ^ `)`)? attr-dict `:` (type($comm) ^ `->`)?"
-                       "type(results)";
+  let assemblyFormat = "(`(` $comm ^ `)`)? attr-dict (`:` type($comm) ^ `->`):"
+                       "(`:`)? type(results)";
 }
 
 //===----------------------------------------------------------------------===//
@@ -122,11 +122,12 @@ def MPI_CommSplit : MPI_Op<"comm_split", []> {
   let arguments = (ins MPI_Comm : $comm, I32 : $color, I32 : $key);
 
   let results = (
-    outs MPI_Comm : $newcomm,
-    Optional<MPI_Retval> : $retval
+    outs Optional<MPI_Retval> : $retval,
+    MPI_Comm : $newcomm
   );
 
   let assemblyFormat = "`(` $comm `,` $color `,` $key `)` attr-dict `:` "
+                       "type($comm) `,` type($color) `,` type($key) `->` "
                        "type(results)";
 }
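
The patch-4 formats can be modeled the same way. The Python sketch below is hypothetical (printing only, empty `attr-dict`, approximate whitespace, made-up function names) and covers the optional group with an else branch used for `comm_rank`, plus the new `comm_split` format with the reordered results.

```python
# Simplified model (an assumption, not MLIR's generated printer) of the
# patch-4 formats. attr-dict is assumed empty; whitespace is approximated.


def print_comm_rank_v2(has_comm: bool, has_retval: bool) -> str:
    # (`(` $comm ^ `)`)? attr-dict (`:` type($comm) ^ `->`):(`:`)? type(results)
    out = "mpi.comm_rank"
    if has_comm:
        out += "(%comm)"
    # The group and its else branch are mutually exclusive:
    # with $comm we print ": !mpi.comm ->", otherwise just ":".
    out += " : !mpi.comm ->" if has_comm else " :"
    results = (["!mpi.retval"] if has_retval else []) + ["i32"]
    return out + " " + ", ".join(results)


def print_comm_split(has_retval: bool) -> str:
    # `(` $comm `,` $color `,` $key `)` attr-dict `:`
    # type($comm) `,` type($color) `,` type($key) `->` type(results)
    out = "mpi.comm_split(%comm, %rank, %rank) : !mpi.comm, i32, i32 ->"
    # Results were reordered so the optional retval precedes $newcomm.
    results = (["!mpi.retval"] if has_retval else []) + ["!mpi.comm"]
    return out + " " + ", ".join(results)


print(print_comm_rank_v2(True, False))
print(print_comm_split(False))
print(print_comm_split(True))
```

Because `->` and `type(results)` sit outside any optional group in the `comm_split` format, the textual form always needs the arrow and at least `!mpi.comm` after it, so an input line ending in `: !mpi.comm, i32, i32` would presumably not parse. Note also that the `comm_rank` group and its else branch both begin with `:`; this model does not check whether the generated parser can tell them apart.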
 

From 845188d1781e73e0596fa24655ebf33ca2b5920f Mon Sep 17 00:00:00 2001
From: Sergio Sánchez Ramírez
 <sergio.sanchez.ramirez+git at bsc.es>
Date: Sat, 1 Feb 2025 22:20:04 +0100
Subject: [PATCH 5/6] fix some tests

---
 mlir/test/Dialect/MPI/ops.mlir | 18 ++++++++++++------
 1 file changed, 12 insertions(+), 6 deletions(-)

diff --git a/mlir/test/Dialect/MPI/ops.mlir b/mlir/test/Dialect/MPI/ops.mlir
index d7521353b34a14..fad203ded1d06a 100644
--- a/mlir/test/Dialect/MPI/ops.mlir
+++ b/mlir/test/Dialect/MPI/ops.mlir
@@ -33,8 +33,11 @@ func.func @mpi_test(%ref : memref<100xf32>) -> () {
     // CHECK-NEXT: %retval_0, %size = mpi.comm_size : !mpi.retval, i32
     %retval_0, %size = mpi.comm_size(%comm) : !mpi.comm -> !mpi.retval, i32
 
-    // CHECK-NEXT: %new_comm, %retval3 = mpi.comm_split(%comm, %rank, %rank) : i32, !mpi.retval
-    %new_comm, %retval3 = mpi.comm_split(%comm, %rank, %rank) : mpi.comm, i32, i32
+    // CHECK-NEXT: %new_comm = mpi.comm_split(%comm, %rank, %rank) : !mpi.comm, i32, i32 -> !mpi.comm
+    %new_comm = mpi.comm_split(%comm, %rank, %rank) : !mpi.comm, i32, i32
+
+    // CHECK-NEXT: %retval3, %new_comm = mpi.comm_split(%comm, %rank, %rank) : !mpi.comm, i32, i32 -> !mpi.retval, !mpi.comm
+    %retval3, %new_comm = mpi.comm_split(%comm, %rank, %rank) : !mpi.comm, i32, i32 -> !mpi.retval, !mpi.comm
 
     // CHECK-NEXT: mpi.send(%arg0, %rank, %rank) : memref<100xf32>, i32, i32
     mpi.send(%ref, %rank, %rank) : memref<100xf32>, i32, i32
@@ -78,14 +81,17 @@ func.func @mpi_test(%ref : memref<100xf32>) -> () {
     // CHECK-NEXT: %3 = mpi.wait(%req_2) : !mpi.request -> !mpi.retval
     %err6 = mpi.wait(%req2) : !mpi.request -> !mpi.retval
 
-    // CHECK-NEXT: mpi.barrier : !mpi.retval
-    mpi.barrier : !mpi.retval
+    // CHECK-NEXT: mpi.barrier
+    mpi.barrier
 
     // CHECK-NEXT: %5 = mpi.barrier : !mpi.retval
     %err7 = mpi.barrier : !mpi.retval
 
-    // CHECK-NEXT: mpi.barrier(%comm) : !mpi.retval
-    mpi.barrier(%comm) : !mpi.retval
+    // CHECK-NEXT: mpi.barrier(%comm)
+    mpi.barrier(%comm)
+
+    // CHECK-NEXT: %5 = mpi.barrier : !mpi.retval
+    %err7 = mpi.barrier : !mpi.retval
 
     // CHECK-NEXT: mpi.allreduce(%arg0, %arg0, <MPI_SUM>) : memref<100xf32>, memref<100xf32>
     mpi.allreduce(%ref, %ref, <MPI_SUM>) : memref<100xf32>, memref<100xf32>

From 45dbcc8173b3910c400cf68dbb4ef4158adb6f18 Mon Sep 17 00:00:00 2001
From: Sergio Sánchez Ramírez
 <sergio.sanchez.ramirez+git at bsc.es>
Date: Sat, 1 Feb 2025 22:20:43 +0100
Subject: [PATCH 6/6] try fixing the assembly format of `barrier`

---
 mlir/include/mlir/Dialect/MPI/IR/MPIOps.td | 13 ++++++++++++-
 1 file changed, 12 insertions(+), 1 deletion(-)

diff --git a/mlir/include/mlir/Dialect/MPI/IR/MPIOps.td b/mlir/include/mlir/Dialect/MPI/IR/MPIOps.td
index 16672bf12a7d21..baa279c62a16c4 100644
--- a/mlir/include/mlir/Dialect/MPI/IR/MPIOps.td
+++ b/mlir/include/mlir/Dialect/MPI/IR/MPIOps.td
@@ -328,7 +328,18 @@ def MPI_Barrier : MPI_Op<"barrier", []> {
 
   let results = (outs Optional<MPI_Retval>:$retval);
 
-  let assemblyFormat = "(`(` $comm ^ `)`)? attr-dict (`:` type($retval) ^)?";
+  // TODO fix assembly format
+  // let assemblyFormat = "("
+  //                      "(attr-dict) ^"
+  //                      "(attr-dict `:` type($retval)) ^"
+  //                      "(`(` $comm `)` attr-dict `:` type($comm)) ^"
+  //                      "(`(` $comm `)` attr-dict `:` type($comm) `->` type($retval))"
+  //                      ")?";
+  let assemblyFormat = [{
+    (`(` $comm ^ `)`)? attr-dict
+    (`:` type($comm) ^ `->`):(`:`)?
+    type(results)
+  }];
 }
 
 //===----------------------------------------------------------------------===//


