[Mlir-commits] [mlir] 49f080a - [mlir][mpi] Mandatory Communicator (#133280)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Mon Mar 31 23:58:58 PDT 2025


Author: Frank Schlimbach
Date: 2025-04-01T08:58:55+02:00
New Revision: 49f080afc4466ddf415d7fc7e98989c0bd07d8ea

URL: https://github.com/llvm/llvm-project/commit/49f080afc4466ddf415d7fc7e98989c0bd07d8ea
DIFF: https://github.com/llvm/llvm-project/commit/49f080afc4466ddf415d7fc7e98989c0bd07d8ea.diff

LOG: [mlir][mpi] Mandatory Communicator (#133280)

This replaces #125361
- communicator is mandatory
- new mpi.comm_world
- new mpi.comm_split
- lowering and test

---------

Co-authored-by: Sergio Sánchez Ramírez <sergio.sanchez.ramirez+git at bsc.es>

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/MPI/IR/MPIOps.td
    mlir/include/mlir/Dialect/MPI/IR/MPITypes.td
    mlir/lib/Conversion/MPIToLLVM/MPIToLLVM.cpp
    mlir/lib/Conversion/MeshToMPI/MeshToMPI.cpp
    mlir/test/Conversion/MPIToLLVM/mpitollvm.mlir
    mlir/test/Conversion/MeshToMPI/convert-mesh-to-mpi.mlir
    mlir/test/Dialect/MPI/mpiops.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/MPI/IR/MPIOps.td b/mlir/include/mlir/Dialect/MPI/IR/MPIOps.td
index a8267b115b9e6..d78aa92d201e7 100644
--- a/mlir/include/mlir/Dialect/MPI/IR/MPIOps.td
+++ b/mlir/include/mlir/Dialect/MPI/IR/MPIOps.td
@@ -37,26 +37,41 @@ def MPI_InitOp : MPI_Op<"init", []> {
   let assemblyFormat = "attr-dict (`:` type($retval)^)?";
 }
 
+//===----------------------------------------------------------------------===//
+// CommWorldOp
+//===----------------------------------------------------------------------===//
+
+def MPI_CommWorldOp : MPI_Op<"comm_world", []> {
+  let summary = "Get the World communicator, equivalent to `MPI_COMM_WORLD`";
+  let description = [{
+    This operation returns the predefined MPI_COMM_WORLD communicator.
+  }];
+
+  let results = (outs MPI_Comm : $comm);
+
+  let assemblyFormat = "attr-dict `:` type(results)";
+}
+
 //===----------------------------------------------------------------------===//
 // CommRankOp
 //===----------------------------------------------------------------------===//
 
 def MPI_CommRankOp : MPI_Op<"comm_rank", []> {
   let summary = "Get the current rank, equivalent to "
-                "`MPI_Comm_rank(MPI_COMM_WORLD, &rank)`";
+                "`MPI_Comm_rank(comm, &rank)`";
   let description = [{
-    Communicators other than `MPI_COMM_WORLD` are not supported for now.
-
     This operation can optionally return an `!mpi.retval` value that can be used
     to check for errors.
   }];
 
+  let arguments = (ins MPI_Comm : $comm);
+
   let results = (
     outs Optional<MPI_Retval> : $retval,
     I32 : $rank
   );
 
-  let assemblyFormat = "attr-dict `:` type(results)";
+  let assemblyFormat = "`(` $comm `)` attr-dict `:` type(results)";
 }
 
 //===----------------------------------------------------------------------===//
@@ -65,20 +80,48 @@ def MPI_CommRankOp : MPI_Op<"comm_rank", []> {
 
 def MPI_CommSizeOp : MPI_Op<"comm_size", []> {
   let summary = "Get the size of the group associated to the communicator, "
-                "equivalent to `MPI_Comm_size(MPI_COMM_WORLD, &size)`";
+                "equivalent to `MPI_Comm_size(comm, &size)`";
   let description = [{
-    Communicators other than `MPI_COMM_WORLD` are not supported for now.
-
     This operation can optionally return an `!mpi.retval` value that can be used
     to check for errors.
   }];
 
+  let arguments = (ins MPI_Comm : $comm);
+
   let results = (
     outs Optional<MPI_Retval> : $retval,
     I32 : $size
   );
 
-  let assemblyFormat = "attr-dict `:` type(results)";
+  let assemblyFormat = "`(` $comm `)` attr-dict `:` type(results)";
+}
+
+//===----------------------------------------------------------------------===//
+// CommSplitOp
+//===----------------------------------------------------------------------===//
+
+def MPI_CommSplitOp : MPI_Op<"comm_split", []> {
+  let summary = "Partition the group associated with the given communicator into "
+                "disjoint subgroups";
+  let description = [{
+    This operation splits the communicator into multiple sub-communicators.
+    The color value determines the group of processes that will be part of the
+    new communicator. The key value determines the rank of the calling process
+    in the new communicator.
+
+    This operation can optionally return an `!mpi.retval` value that can be used
+    to check for errors.
+  }];
+
+  let arguments = (ins MPI_Comm : $comm, I32 : $color, I32 : $key);
+
+  let results = (
+    outs Optional<MPI_Retval> : $retval,
+    MPI_Comm : $newcomm
+  );
+
+  let assemblyFormat = "`(` $comm `,` $color `,` $key `)` attr-dict `:` "
+                       "type(results)";
 }
 
 //===----------------------------------------------------------------------===//
@@ -87,14 +130,12 @@ def MPI_CommSizeOp : MPI_Op<"comm_size", []> {
 
 def MPI_SendOp : MPI_Op<"send", []> {
   let summary =
-      "Equivalent to `MPI_Send(ptr, size, dtype, dest, tag, MPI_COMM_WORLD)`";
+      "Equivalent to `MPI_Send(ptr, size, dtype, dest, tag, comm)`";
   let description = [{
     MPI_Send performs a blocking send of `size` elements of type `dtype` to rank
     `dest`. The `tag` value and communicator enables the library to determine 
     the matching of multiple sends and receives between the same ranks.
 
-    Communicators other than `MPI_COMM_WORLD` are not supported for now.
-
     This operation can optionally return an `!mpi.retval` value that can be used
     to check for errors.
   }];
@@ -102,12 +143,13 @@ def MPI_SendOp : MPI_Op<"send", []> {
   let arguments = (
     ins AnyMemRef : $ref,
     I32 : $tag,
-    I32 : $dest
+    I32 : $dest,
+    MPI_Comm : $comm
   );
 
   let results = (outs Optional<MPI_Retval>:$retval);
 
-  let assemblyFormat = "`(` $ref `,` $tag `,` $dest `)` attr-dict `:` "
+  let assemblyFormat = "`(` $ref `,` $tag `,` $dest `,` $comm `)` attr-dict `:` "
                        "type($ref) `,` type($tag) `,` type($dest)"
                        "(`->` type($retval)^)?";
   let hasCanonicalizer = 1;
@@ -119,15 +161,13 @@ def MPI_SendOp : MPI_Op<"send", []> {
 
 def MPI_ISendOp : MPI_Op<"isend", []> {
   let summary =
-      "Equivalent to `MPI_Isend(ptr, size, dtype, dest, tag, MPI_COMM_WORLD)`";
+      "Equivalent to `MPI_Isend(ptr, size, dtype, dest, tag, comm)`";
   let description = [{
     MPI_Isend begins a non-blocking send of `size` elements of type `dtype` to
     rank `dest`. The `tag` value and communicator enables the library to
     determine the matching of multiple sends and receives between the same
     ranks.
 
-    Communicators other than `MPI_COMM_WORLD` are not supported for now.
-
     This operation can optionally return an `!mpi.retval` value that can be used
     to check for errors.
   }];
@@ -135,7 +175,8 @@ def MPI_ISendOp : MPI_Op<"isend", []> {
   let arguments = (
     ins AnyMemRef : $ref,
     I32 : $tag,
-    I32 : $rank
+    I32 : $dest,
+    MPI_Comm : $comm
   );
 
   let results = (
@@ -143,8 +184,8 @@ def MPI_ISendOp : MPI_Op<"isend", []> {
     MPI_Request : $req
   );
 
-  let assemblyFormat = "`(` $ref `,` $tag `,` $rank `)` attr-dict "
-                       "`:` type($ref) `,` type($tag) `,` type($rank) "
+  let assemblyFormat = "`(` $ref `,` $tag `,` $dest `,` $comm`)` attr-dict "
+                       "`:` type($ref) `,` type($tag) `,` type($dest) "
                        "`->` type(results)";
   let hasCanonicalizer = 1;
 }
@@ -155,14 +196,13 @@ def MPI_ISendOp : MPI_Op<"isend", []> {
 
 def MPI_RecvOp : MPI_Op<"recv", []> {
   let summary = "Equivalent to `MPI_Recv(ptr, size, dtype, source, tag, "
-                "MPI_COMM_WORLD, MPI_STATUS_IGNORE)`";
+                "comm, MPI_STATUS_IGNORE)`";
   let description = [{
     MPI_Recv performs a blocking receive of `size` elements of type `dtype` 
     from rank `source`. The `tag` value and communicator enables the library to
     determine the matching of multiple sends and receives between the same 
     ranks.
 
-    Communicators other than `MPI_COMM_WORLD` are not supported for now.
     The MPI_Status is set to `MPI_STATUS_IGNORE`, as the status object 
     is not yet ported to MLIR.
 
@@ -172,13 +212,14 @@ def MPI_RecvOp : MPI_Op<"recv", []> {
 
   let arguments = (
     ins AnyMemRef : $ref,
-    I32 : $tag, I32 : $source
+    I32 : $tag, I32 : $source,
+    MPI_Comm : $comm
   );
 
   let results = (outs Optional<MPI_Retval>:$retval);
 
-  let assemblyFormat = "`(` $ref `,` $tag `,` $source `)` attr-dict `:` "
-                       "type($ref) `,` type($tag) `,` type($source)"
+  let assemblyFormat = "`(` $ref `,` $tag `,` $source `,` $comm `)` attr-dict"
+                       " `:` type($ref) `,` type($tag) `,` type($source) "
                        "(`->` type($retval)^)?";
   let hasCanonicalizer = 1;
 }
@@ -188,16 +229,14 @@ def MPI_RecvOp : MPI_Op<"recv", []> {
 //===----------------------------------------------------------------------===//
 
 def MPI_IRecvOp : MPI_Op<"irecv", []> {
-  let summary = "Equivalent to `MPI_Irecv(ptr, size, dtype, dest, tag, "
-                "MPI_COMM_WORLD, &req)`";
+  let summary = "Equivalent to `MPI_Irecv(ptr, size, dtype, source, tag, "
+                "comm, &req)`";
   let description = [{
     MPI_Irecv begins a non-blocking receive of `size` elements of type `dtype` 
-    from rank `dest`. The `tag` value and communicator enables the library to 
+    from rank `source`. The `tag` value and communicator enables the library to
     determine the matching of multiple sends and receives between the same 
     ranks.
 
-    Communicators other than `MPI_COMM_WORLD` are not supported for now.
-
     This operation can optionally return an `!mpi.retval` value that can be used
     to check for errors.
   }];
@@ -205,7 +244,8 @@ def MPI_IRecvOp : MPI_Op<"irecv", []> {
   let arguments = (
     ins AnyMemRef : $ref,
     I32 : $tag,
-    I32 : $rank
+    I32 : $source,
+    MPI_Comm : $comm
   );
 
   let results = (
@@ -213,9 +253,9 @@ def MPI_IRecvOp : MPI_Op<"irecv", []> {
     MPI_Request : $req
   );
 
-  let assemblyFormat = "`(` $ref `,` $tag `,` $rank `)` attr-dict `:`"
-                       "type($ref) `,` type($tag) `,` type($rank) `->`"
-                       "type(results)";
+  let assemblyFormat = "`(` $ref `,` $tag `,` $source `,` $comm`)` attr-dict "
+                       "`:` type($ref) `,` type($tag) `,` type($source)"
+                       "`->` type(results)";
   let hasCanonicalizer = 1;
 }
 
@@ -224,8 +264,7 @@ def MPI_IRecvOp : MPI_Op<"irecv", []> {
 //===----------------------------------------------------------------------===//
 
 def MPI_AllReduceOp : MPI_Op<"allreduce", []> {
-  let summary = "Equivalent to `MPI_Allreduce(sendbuf, recvbuf, op, "
-                "MPI_COMM_WORLD)`";
+  let summary = "Equivalent to `MPI_Allreduce(sendbuf, recvbuf, op, comm)`";
   let description = [{
     MPI_Allreduce performs a reduction operation on the values in the sendbuf
     array and stores the result in the recvbuf array. The operation is 
@@ -235,8 +274,6 @@ def MPI_AllReduceOp : MPI_Op<"allreduce", []> {
     Currently only the `MPI_Op` predefined in the standard (e.g. `MPI_SUM`) are
     supported.
 
-    Communicators other than `MPI_COMM_WORLD` are not supported for now.
-
     This operation can optionally return an `!mpi.retval` value that can be used
     to check for errors.
   }];
@@ -244,13 +281,14 @@ def MPI_AllReduceOp : MPI_Op<"allreduce", []> {
   let arguments = (
     ins AnyMemRef : $sendbuf,
     AnyMemRef : $recvbuf,
-    MPI_OpClassEnum : $op
+    MPI_OpClassEnum : $op,
+    MPI_Comm : $comm
   );
 
   let results = (outs Optional<MPI_Retval>:$retval);
 
-  let assemblyFormat = "`(` $sendbuf `,` $recvbuf `,` $op `)` attr-dict `:`"
-                       "type($sendbuf) `,` type($recvbuf)"
+  let assemblyFormat = "`(` $sendbuf `,` $recvbuf `,` $op `,` $comm `)` "
+                       "attr-dict `:` type($sendbuf) `,` type($recvbuf) "
                        "(`->` type($retval)^)?";
 }
 
@@ -259,20 +297,23 @@ def MPI_AllReduceOp : MPI_Op<"allreduce", []> {
 //===----------------------------------------------------------------------===//
 
 def MPI_Barrier : MPI_Op<"barrier", []> {
-  let summary = "Equivalent to `MPI_Barrier(MPI_COMM_WORLD)`";
+  let summary = "Equivalent to `MPI_Barrier(comm)`";
   let description = [{
     MPI_Barrier blocks execution until all processes in the communicator have
     reached this routine.
 
-    Communicators other than `MPI_COMM_WORLD` are not supported for now.
-
     This operation can optionally return an `!mpi.retval` value that can be used
     to check for errors.
   }];
 
+  let arguments = (ins MPI_Comm : $comm);
+
   let results = (outs Optional<MPI_Retval>:$retval);
 
-  let assemblyFormat = "attr-dict (`:` type($retval) ^)?";
+  let assemblyFormat = [{
+    `(` $comm `)` attr-dict
+    (`->` type($retval)^)?
+  }];
 }
 
 //===----------------------------------------------------------------------===//
@@ -295,8 +336,7 @@ def MPI_Wait : MPI_Op<"wait", []> {
 
   let results = (outs Optional<MPI_Retval>:$retval);
 
-  let assemblyFormat = "`(` $req `)` attr-dict `:` type($req) "
-                       "(`->` type($retval) ^)?";
+  let assemblyFormat = "`(` $req `)` attr-dict `:` type($req) (`->` type($retval) ^)?";
 }
 
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/include/mlir/Dialect/MPI/IR/MPITypes.td b/mlir/include/mlir/Dialect/MPI/IR/MPITypes.td
index a55d30e778e22..adc35a70b5837 100644
--- a/mlir/include/mlir/Dialect/MPI/IR/MPITypes.td
+++ b/mlir/include/mlir/Dialect/MPI/IR/MPITypes.td
@@ -40,6 +40,17 @@ def MPI_Retval : MPI_Type<"Retval", "retval"> {
   }];
 }
 
+//===----------------------------------------------------------------------===//
+// mpi::CommType
+//===----------------------------------------------------------------------===//
+
+def MPI_Comm : MPI_Type<"Comm", "comm"> {
+  let summary = "MPI communicator handler";
+  let description = [{
+    This type represents a handler for the MPI communicator.
+  }];
+}
+
 //===----------------------------------------------------------------------===//
 // mpi::RequestType
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/lib/Conversion/MPIToLLVM/MPIToLLVM.cpp b/mlir/lib/Conversion/MPIToLLVM/MPIToLLVM.cpp
index 4e0f59305a647..9df5e992e8ebd 100644
--- a/mlir/lib/Conversion/MPIToLLVM/MPIToLLVM.cpp
+++ b/mlir/lib/Conversion/MPIToLLVM/MPIToLLVM.cpp
@@ -83,9 +83,17 @@ class MPIImplTraits {
   ModuleOp &getModuleOp() { return moduleOp; }
 
   /// Gets or creates MPI_COMM_WORLD as a Value.
+  /// Different MPI implementations have different communicator types.
+  /// Using i64 as a portable, intermediate type.
+  /// Appropriate cast needs to take place before calling MPI functions.
   virtual Value getCommWorld(const Location loc,
                              ConversionPatternRewriter &rewriter) = 0;
 
+  /// Type converter provides i64 type for communicator type.
+  /// Converts to native type, which might be ptr or int or whatever.
+  virtual Value castComm(const Location loc,
+                         ConversionPatternRewriter &rewriter, Value comm) = 0;
+
   /// Get the MPI_STATUS_IGNORE value (typically a pointer type).
   virtual intptr_t getStatusIgnore() = 0;
 
@@ -139,10 +147,15 @@ class MPICHImplTraits : public MPIImplTraits {
   Value getCommWorld(const Location loc,
                      ConversionPatternRewriter &rewriter) override {
     static constexpr int MPI_COMM_WORLD = 0x44000000;
-    return rewriter.create<LLVM::ConstantOp>(loc, rewriter.getI32Type(),
+    return rewriter.create<LLVM::ConstantOp>(loc, rewriter.getI64Type(),
                                              MPI_COMM_WORLD);
   }
 
+  Value castComm(const Location loc, ConversionPatternRewriter &rewriter,
+                 Value comm) override {
+    return rewriter.create<LLVM::TruncOp>(loc, rewriter.getI32Type(), comm);
+  }
+
   intptr_t getStatusIgnore() override { return 1; }
 
   Value getDataType(const Location loc, ConversionPatternRewriter &rewriter,
@@ -256,9 +269,16 @@ class OMPIImplTraits : public MPIImplTraits {
     getOrDefineExternalStruct(loc, rewriter, name, commStructT);
 
     // get address of symbol
-    return rewriter.create<LLVM::AddressOfOp>(
+    auto comm = rewriter.create<LLVM::AddressOfOp>(
         loc, LLVM::LLVMPointerType::get(context),
         SymbolRefAttr::get(context, name));
+    return rewriter.create<LLVM::PtrToIntOp>(loc, rewriter.getI64Type(), comm);
+  }
+
+  Value castComm(const Location loc, ConversionPatternRewriter &rewriter,
+                 Value comm) override {
+    return rewriter.create<LLVM::IntToPtrOp>(
+        loc, LLVM::LLVMPointerType::get(rewriter.getContext()), comm);
   }
 
   intptr_t getStatusIgnore() override { return 0; }
@@ -440,6 +460,78 @@ struct FinalizeOpLowering : public ConvertOpToLLVMPattern<mpi::FinalizeOp> {
   }
 };
 
+//===----------------------------------------------------------------------===//
+// CommWorldOpLowering
+//===----------------------------------------------------------------------===//
+
+struct CommWorldOpLowering : public ConvertOpToLLVMPattern<mpi::CommWorldOp> {
+  using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;
+
+  LogicalResult
+  matchAndRewrite(mpi::CommWorldOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    // grab a reference to the global module op:
+    auto moduleOp = op->getParentOfType<ModuleOp>();
+    auto mpiTraits = MPIImplTraits::get(moduleOp);
+    // get MPI_COMM_WORLD
+    rewriter.replaceOp(op, mpiTraits->getCommWorld(op.getLoc(), rewriter));
+
+    return success();
+  }
+};
+
+//===----------------------------------------------------------------------===//
+// CommSplitOpLowering
+//===----------------------------------------------------------------------===//
+
+struct CommSplitOpLowering : public ConvertOpToLLVMPattern<mpi::CommSplitOp> {
+  using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;
+
+  LogicalResult
+  matchAndRewrite(mpi::CommSplitOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    // grab a reference to the global module op:
+    auto moduleOp = op->getParentOfType<ModuleOp>();
+    auto mpiTraits = MPIImplTraits::get(moduleOp);
+    Type i32 = rewriter.getI32Type();
+    Type ptrType = LLVM::LLVMPointerType::get(op->getContext());
+    Location loc = op.getLoc();
+
+    // get communicator
+    Value comm = mpiTraits->castComm(loc, rewriter, adaptor.getComm());
+    auto one = rewriter.create<LLVM::ConstantOp>(loc, i32, 1);
+    auto outPtr =
+        rewriter.create<LLVM::AllocaOp>(loc, ptrType, comm.getType(), one);
+
+    // int MPI_Comm_split(MPI_Comm comm, int color, int key, MPI_Comm * newcomm)
+    auto funcType =
+        LLVM::LLVMFunctionType::get(i32, {comm.getType(), i32, i32, ptrType});
+    // get or create function declaration:
+    LLVM::LLVMFuncOp funcDecl = getOrDefineFunction(moduleOp, loc, rewriter,
+                                                    "MPI_Comm_split", funcType);
+
+    auto callOp = rewriter.create<LLVM::CallOp>(
+        loc, funcDecl,
+        ValueRange{comm, adaptor.getColor(), adaptor.getKey(),
+                   outPtr.getRes()});
+
+    // load the communicator into a register
+    auto res = rewriter.create<LLVM::LoadOp>(loc, i32, outPtr.getResult());
+
+    // if retval is checked, replace uses of retval with the results from the
+    // call op
+    SmallVector<Value> replacements;
+    if (op.getRetval())
+      replacements.push_back(callOp.getResult());
+
+    // replace op
+    replacements.push_back(res.getRes());
+    rewriter.replaceOp(op, replacements);
+
+    return success();
+  }
+};
+
 //===----------------------------------------------------------------------===//
 // CommRankOpLowering
 //===----------------------------------------------------------------------===//
@@ -462,21 +554,21 @@ struct CommRankOpLowering : public ConvertOpToLLVMPattern<mpi::CommRankOp> {
     auto moduleOp = op->getParentOfType<ModuleOp>();
 
     auto mpiTraits = MPIImplTraits::get(moduleOp);
-    // get MPI_COMM_WORLD
-    Value commWorld = mpiTraits->getCommWorld(loc, rewriter);
+    // get communicator
+    Value comm = mpiTraits->castComm(loc, rewriter, adaptor.getComm());
 
     // LLVM Function type representing `i32 MPI_Comm_rank(ptr, ptr)`
     auto rankFuncType =
-        LLVM::LLVMFunctionType::get(i32, {commWorld.getType(), ptrType});
+        LLVM::LLVMFunctionType::get(i32, {comm.getType(), ptrType});
     // get or create function declaration:
     LLVM::LLVMFuncOp initDecl = getOrDefineFunction(
         moduleOp, loc, rewriter, "MPI_Comm_rank", rankFuncType);
 
-    // replace init with function call
+    // replace with function call
     auto one = rewriter.create<LLVM::ConstantOp>(loc, i32, 1);
     auto rankptr = rewriter.create<LLVM::AllocaOp>(loc, ptrType, i32, one);
     auto callOp = rewriter.create<LLVM::CallOp>(
-        loc, initDecl, ValueRange{commWorld, rankptr.getRes()});
+        loc, initDecl, ValueRange{comm, rankptr.getRes()});
 
     // load the rank into a register
     auto loadedRank =
@@ -523,12 +615,12 @@ struct SendOpLowering : public ConvertOpToLLVMPattern<mpi::SendOp> {
         getRawPtrAndSize(loc, rewriter, adaptor.getRef(), elemType);
     auto mpiTraits = MPIImplTraits::get(moduleOp);
     Value dataType = mpiTraits->getDataType(loc, rewriter, elemType);
-    Value commWorld = mpiTraits->getCommWorld(loc, rewriter);
+    Value comm = mpiTraits->castComm(loc, rewriter, adaptor.getComm());
 
     // LLVM Function type representing `i32 MPI_send(data, count, datatype, dst,
     // tag, comm)`
     auto funcType = LLVM::LLVMFunctionType::get(
-        i32, {ptrType, i32, dataType.getType(), i32, i32, commWorld.getType()});
+        i32, {ptrType, i32, dataType.getType(), i32, i32, comm.getType()});
     // get or create function declaration:
     LLVM::LLVMFuncOp funcDecl =
         getOrDefineFunction(moduleOp, loc, rewriter, "MPI_Send", funcType);
@@ -537,7 +629,7 @@ struct SendOpLowering : public ConvertOpToLLVMPattern<mpi::SendOp> {
     auto funcCall = rewriter.create<LLVM::CallOp>(
         loc, funcDecl,
         ValueRange{dataPtr, size, dataType, adaptor.getDest(), adaptor.getTag(),
-                   commWorld});
+                   comm});
     if (op.getRetval())
       rewriter.replaceOp(op, funcCall.getResult());
     else
@@ -575,7 +667,7 @@ struct RecvOpLowering : public ConvertOpToLLVMPattern<mpi::RecvOp> {
         getRawPtrAndSize(loc, rewriter, adaptor.getRef(), elemType);
     auto mpiTraits = MPIImplTraits::get(moduleOp);
     Value dataType = mpiTraits->getDataType(loc, rewriter, elemType);
-    Value commWorld = mpiTraits->getCommWorld(loc, rewriter);
+    Value comm = mpiTraits->castComm(loc, rewriter, adaptor.getComm());
     Value statusIgnore = rewriter.create<LLVM::ConstantOp>(
         loc, i64, mpiTraits->getStatusIgnore());
     statusIgnore =
@@ -585,7 +677,7 @@ struct RecvOpLowering : public ConvertOpToLLVMPattern<mpi::RecvOp> {
     // tag, comm)`
     auto funcType =
         LLVM::LLVMFunctionType::get(i32, {ptrType, i32, dataType.getType(), i32,
-                                          i32, commWorld.getType(), ptrType});
+                                          i32, comm.getType(), ptrType});
     // get or create function declaration:
     LLVM::LLVMFuncOp funcDecl =
         getOrDefineFunction(moduleOp, loc, rewriter, "MPI_Recv", funcType);
@@ -594,7 +686,7 @@ struct RecvOpLowering : public ConvertOpToLLVMPattern<mpi::RecvOp> {
     auto funcCall = rewriter.create<LLVM::CallOp>(
         loc, funcDecl,
         ValueRange{dataPtr, size, dataType, adaptor.getSource(),
-                   adaptor.getTag(), commWorld, statusIgnore});
+                   adaptor.getTag(), comm, statusIgnore});
     if (op.getRetval())
       rewriter.replaceOp(op, funcCall.getResult());
     else
@@ -629,7 +721,8 @@ struct AllReduceOpLowering : public ConvertOpToLLVMPattern<mpi::AllReduceOp> {
         getRawPtrAndSize(loc, rewriter, adaptor.getRecvbuf(), elemType);
     Value dataType = mpiTraits->getDataType(loc, rewriter, elemType);
     Value mpiOp = mpiTraits->getMPIOp(loc, rewriter, op.getOp());
-    Value commWorld = mpiTraits->getCommWorld(loc, rewriter);
+    Value commWorld = mpiTraits->castComm(loc, rewriter, adaptor.getComm());
+
     // 'int MPI_Allreduce(const void *sendbuf, void *recvbuf, int count,
     //                    MPI_Datatype datatype, MPI_Op op, MPI_Comm comm)'
     auto funcType = LLVM::LLVMFunctionType::get(
@@ -676,8 +769,15 @@ struct FuncToLLVMDialectInterface : public ConvertToLLVMPatternInterface {
 
 void mpi::populateMPIToLLVMConversionPatterns(LLVMTypeConverter &converter,
                                               RewritePatternSet &patterns) {
-  patterns.add<CommRankOpLowering, FinalizeOpLowering, InitOpLowering,
-               SendOpLowering, RecvOpLowering, AllReduceOpLowering>(converter);
+  // Using i64 as a portable, intermediate type for !mpi.comm.
+  // It would be nicer to somehow get the right type directly, but TLDI is not
+  // available here.
+  converter.addConversion([](mpi::CommType type) {
+    return IntegerType::get(type.getContext(), 64);
+  });
+  patterns.add<CommRankOpLowering, CommSplitOpLowering, CommWorldOpLowering,
+               FinalizeOpLowering, InitOpLowering, SendOpLowering,
+               RecvOpLowering, AllReduceOpLowering>(converter);
 }
 
 void mpi::registerConvertMPIToLLVMInterface(DialectRegistry &registry) {

diff  --git a/mlir/lib/Conversion/MeshToMPI/MeshToMPI.cpp b/mlir/lib/Conversion/MeshToMPI/MeshToMPI.cpp
index 87c2938e4e52b..cafbf835de22f 100644
--- a/mlir/lib/Conversion/MeshToMPI/MeshToMPI.cpp
+++ b/mlir/lib/Conversion/MeshToMPI/MeshToMPI.cpp
@@ -310,11 +310,16 @@ class ConvertProcessLinearIndexOp
     }
 
     // Otherwise call create mpi::CommRankOp
-    auto rank = rewriter
-                    .create<mpi::CommRankOp>(
-                        loc, TypeRange{mpi::RetvalType::get(op->getContext()),
-                                       rewriter.getI32Type()})
-                    .getRank();
+    auto ctx = op.getContext();
+    Value commWorld =
+        rewriter.create<mpi::CommWorldOp>(loc, mpi::CommType::get(ctx));
+    auto rank =
+        rewriter
+            .create<mpi::CommRankOp>(
+                loc,
+                TypeRange{mpi::RetvalType::get(ctx), rewriter.getI32Type()},
+                commWorld)
+            .getRank();
     rewriter.replaceOpWithNewOp<arith::IndexCastOp>(op, rewriter.getIndexType(),
                                                     rank);
     return success();
@@ -652,6 +657,9 @@ struct ConvertUpdateHaloOp : public OpConversionPattern<UpdateHaloOp> {
       auto upperSendOffset = rewriter.create<arith::SubIOp>(
           loc, upperRecvOffset, toValue(haloSizes[currHaloDim * 2]));
 
+      Value commWorld = rewriter.create<mpi::CommWorldOp>(
+          loc, mpi::CommType::get(op->getContext()));
+
       // Make sure we send/recv in a way that does not lead to a dead-lock.
       // The current approach is by far not optimal, this should be at least
       // be a red-black pattern or using MPI_sendrecv.
@@ -680,7 +688,8 @@ struct ConvertUpdateHaloOp : public OpConversionPattern<UpdateHaloOp> {
               auto subview = builder.create<memref::SubViewOp>(
                   loc, array, offsets, dimSizes, strides);
               builder.create<memref::CopyOp>(loc, subview, buffer);
-              builder.create<mpi::SendOp>(loc, TypeRange{}, buffer, tag, to);
+              builder.create<mpi::SendOp>(loc, TypeRange{}, buffer, tag, to,
+                                          commWorld);
               builder.create<scf::YieldOp>(loc);
             });
         // if has neighbor: receive halo data into buffer and copy to array
@@ -688,7 +697,8 @@ struct ConvertUpdateHaloOp : public OpConversionPattern<UpdateHaloOp> {
             loc, hasFrom, [&](OpBuilder &builder, Location loc) {
               offsets[dim] = upperHalo ? OpFoldResult(upperRecvOffset)
                                        : OpFoldResult(lowerRecvOffset);
-              builder.create<mpi::RecvOp>(loc, TypeRange{}, buffer, tag, from);
+              builder.create<mpi::RecvOp>(loc, TypeRange{}, buffer, tag, from,
+                                          commWorld);
               auto subview = builder.create<memref::SubViewOp>(
                   loc, array, offsets, dimSizes, strides);
               builder.create<memref::CopyOp>(loc, buffer, subview);

diff  --git a/mlir/test/Conversion/MPIToLLVM/mpitollvm.mlir b/mlir/test/Conversion/MPIToLLVM/mpitollvm.mlir
index b630ce3a23f30..174f7c79b9d50 100644
--- a/mlir/test/Conversion/MPIToLLVM/mpitollvm.mlir
+++ b/mlir/test/Conversion/MPIToLLVM/mpitollvm.mlir
@@ -3,6 +3,7 @@
 // COM: Test MPICH ABI
 // CHECK: module attributes {dlti.map = #dlti.map<"MPI:Implementation" = "MPICH">} {
 // CHECK: llvm.func @MPI_Finalize() -> i32
+// CHECK: llvm.func @MPI_Comm_split(i32, i32, i32, !llvm.ptr) -> i32
 // CHECK: llvm.func @MPI_Recv(!llvm.ptr, i32, i32, i32, i32, i32, !llvm.ptr) -> i32
 // CHECK: llvm.func @MPI_Send(!llvm.ptr, i32, i32, i32, i32, i32) -> i32
 // CHECK: llvm.func @MPI_Comm_rank(i32, !llvm.ptr) -> i32
@@ -22,11 +23,14 @@ module attributes {dlti.map = #dlti.map<"MPI:Implementation" = "MPICH">} {
     // CHECK: [[v7:%.*]] = llvm.call @MPI_Init([[v6]], [[v6]]) : (!llvm.ptr, !llvm.ptr) -> i32
     %0 = mpi.init : !mpi.retval
 
-    // CHECK: [[v8:%.*]] = llvm.mlir.constant(1140850688 : i32) : i32
+    // CHECK: [[comm:%.*]] = llvm.mlir.constant(1140850688 : i64) : i64
+    %comm = mpi.comm_world : !mpi.comm
+
+    // CHECK: [[v8:%.*]] = llvm.trunc [[comm]] : i64 to i32
     // CHECK: [[v9:%.*]] = llvm.mlir.constant(1 : i32) : i32
     // CHECK: [[v10:%.*]] = llvm.alloca [[v9]] x i32 : (i32) -> !llvm.ptr
     // CHECK: [[v11:%.*]] = llvm.call @MPI_Comm_rank([[v8]], [[v10]]) : (i32, !llvm.ptr) -> i32
-    %retval, %rank = mpi.comm_rank : !mpi.retval, i32
+    %retval, %rank = mpi.comm_rank(%comm) : !mpi.retval, i32
 
     // CHECK: [[v12:%.*]] = llvm.load [[v10]] : !llvm.ptr -> i32
     // CHECK: [[v13:%.*]] = llvm.extractvalue [[v5]][1] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
@@ -35,9 +39,9 @@ module attributes {dlti.map = #dlti.map<"MPI:Implementation" = "MPICH">} {
     // CHECK: [[v16:%.*]] = llvm.extractvalue [[v5]][3, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
     // CHECK: [[v17:%.*]] = llvm.trunc [[v16]] : i64 to i32
     // CHECK: [[v18:%.*]] = llvm.mlir.constant(1275069450 : i32) : i32
-    // CHECK: [[v19:%.*]] = llvm.mlir.constant(1140850688 : i32) : i32
-    // CHECK: [[v20:%.*]] = llvm.call @MPI_Send([[v15]], [[v17]], [[v18]], [[v12]], [[v12]], [[v19]]) : (!llvm.ptr, i32, i32, i32, i32, i32) -> i32
-    mpi.send(%arg0, %rank, %rank) : memref<100xf32>, i32, i32
+    // CHECK: [[comm_1:%.*]] = llvm.trunc [[comm]] : i64 to i32
+    // CHECK: [[v20:%.*]] = llvm.call @MPI_Send([[v15]], [[v17]], [[v18]], [[v12]], [[v12]], [[comm_1]]) : (!llvm.ptr, i32, i32, i32, i32, i32) -> i32
+    mpi.send(%arg0, %rank, %rank, %comm) : memref<100xf32>, i32, i32
 
     // CHECK: [[v21:%.*]] = llvm.extractvalue [[v5]][1] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
     // CHECK: [[v22:%.*]] = llvm.extractvalue [[v5]][2] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
@@ -45,9 +49,9 @@ module attributes {dlti.map = #dlti.map<"MPI:Implementation" = "MPICH">} {
     // CHECK: [[v24:%.*]] = llvm.extractvalue [[v5]][3, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
     // CHECK: [[v25:%.*]] = llvm.trunc [[v24]] : i64 to i32
     // CHECK: [[v26:%.*]] = llvm.mlir.constant(1275069450 : i32) : i32
-    // CHECK: [[v27:%.*]] = llvm.mlir.constant(1140850688 : i32) : i32
-    // CHECK: [[v28:%.*]] = llvm.call @MPI_Send([[v23]], [[v25]], [[v26]], [[v12]], [[v12]], [[v27]]) : (!llvm.ptr, i32, i32, i32, i32, i32) -> i32
-    %1 = mpi.send(%arg0, %rank, %rank) : memref<100xf32>, i32, i32 -> !mpi.retval
+    // CHECK: [[comm_2:%.*]] = llvm.trunc [[comm]] : i64 to i32
+    // CHECK: [[v28:%.*]] = llvm.call @MPI_Send([[v23]], [[v25]], [[v26]], [[v12]], [[v12]], [[comm_2]]) : (!llvm.ptr, i32, i32, i32, i32, i32) -> i32
+    %1 = mpi.send(%arg0, %rank, %rank, %comm) : memref<100xf32>, i32, i32 -> !mpi.retval
 
     // CHECK: [[v29:%.*]] = llvm.extractvalue [[v5]][1] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
     // CHECK: [[v30:%.*]] = llvm.extractvalue [[v5]][2] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
@@ -55,11 +59,11 @@ module attributes {dlti.map = #dlti.map<"MPI:Implementation" = "MPICH">} {
     // CHECK: [[v32:%.*]] = llvm.extractvalue [[v5]][3, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
     // CHECK: [[v33:%.*]] = llvm.trunc [[v32]] : i64 to i32
     // CHECK: [[v34:%.*]] = llvm.mlir.constant(1275069450 : i32) : i32
-    // CHECK: [[v35:%.*]] = llvm.mlir.constant(1140850688 : i32) : i32
+    // CHECK: [[comm_3:%.*]] = llvm.trunc [[comm]] : i64 to i32
     // CHECK: [[v36:%.*]] = llvm.mlir.constant(1 : i64) : i64
     // CHECK: [[v37:%.*]] = llvm.inttoptr [[v36]] : i64 to !llvm.ptr
-    // CHECK: [[v38:%.*]] = llvm.call @MPI_Recv([[v31]], [[v33]], [[v34]], [[v12]], [[v12]], [[v35]], [[v37]]) : (!llvm.ptr, i32, i32, i32, i32, i32, !llvm.ptr) -> i32
-    mpi.recv(%arg0, %rank, %rank) : memref<100xf32>, i32, i32
+    // CHECK: [[v38:%.*]] = llvm.call @MPI_Recv([[v31]], [[v33]], [[v34]], [[v12]], [[v12]], [[comm_3]], [[v37]]) : (!llvm.ptr, i32, i32, i32, i32, i32, !llvm.ptr) -> i32
+    mpi.recv(%arg0, %rank, %rank, %comm) : memref<100xf32>, i32, i32
 
     // CHECK: [[v39:%.*]] = llvm.extractvalue [[v5]][1] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
     // CHECK: [[v40:%.*]] = llvm.extractvalue [[v5]][2] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
@@ -67,27 +71,38 @@ module attributes {dlti.map = #dlti.map<"MPI:Implementation" = "MPICH">} {
     // CHECK: [[v42:%.*]] = llvm.extractvalue [[v5]][3, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
     // CHECK: [[v43:%.*]] = llvm.trunc [[v42]] : i64 to i32
     // CHECK: [[v44:%.*]] = llvm.mlir.constant(1275069450 : i32) : i32
-    // CHECK: [[v45:%.*]] = llvm.mlir.constant(1140850688 : i32) : i32
+    // CHECK: [[comm_4:%.*]] = llvm.trunc [[comm]] : i64 to i32
     // CHECK: [[v46:%.*]] = llvm.mlir.constant(1 : i64) : i64
     // CHECK: [[v47:%.*]] = llvm.inttoptr [[v46]] : i64 to !llvm.ptr
-    // CHECK: [[v48:%.*]] = llvm.call @MPI_Recv([[v41]], [[v43]], [[v44]], [[v12]], [[v12]], [[v45]], [[v47]]) : (!llvm.ptr, i32, i32, i32, i32, i32, !llvm.ptr) -> i32
-    %2 = mpi.recv(%arg0, %rank, %rank) : memref<100xf32>, i32, i32 -> !mpi.retval
-
-    // CHECK: [[v49:%.*]] = llvm.extractvalue [[v5]][1] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
-    // CHECK: [[v50:%.*]] = llvm.extractvalue [[v5]][2] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
-    // CHECK: [[v51:%.*]] = llvm.getelementptr [[v49]][[[v50]]] : (!llvm.ptr, i64) -> !llvm.ptr, f32
-    // CHECK: [[v52:%.*]] = llvm.extractvalue [[v5]][3, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
-    // CHECK: [[v53:%.*]] = llvm.trunc [[v52]] : i64 to i32
-    // CHECK: [[v54:%.*]] = llvm.extractvalue [[v5]][1] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
-    // CHECK: [[v55:%.*]] = llvm.extractvalue [[v5]][2] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
-    // CHECK: [[v56:%.*]] = llvm.getelementptr [[v54]][[[v55]]] : (!llvm.ptr, i64) -> !llvm.ptr, f32
-    // CHECK: [[v57:%.*]] = llvm.extractvalue [[v5]][3, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
-    // CHECK: [[v58:%.*]] = llvm.trunc [[v57]] : i64 to i32
-    // CHECK: [[v59:%.*]] = llvm.mlir.constant(1275069450 : i32) : i32
-    // CHECK: [[v60:%.*]] = llvm.mlir.constant(1476395011 : i32) : i32
-    // CHECK: [[v61:%.*]] = llvm.mlir.constant(1140850688 : i32) : i32
-    // CHECK: [[v62:%.*]] = llvm.call @MPI_Allreduce([[v51]], [[v56]], [[v53]], [[v59]], [[v60]], [[v61]]) : (!llvm.ptr, !llvm.ptr, i32, i32, i32, i32) -> i32
-    mpi.allreduce(%arg0, %arg0, MPI_SUM) : memref<100xf32>, memref<100xf32>
+    // CHECK: [[v48:%.*]] = llvm.call @MPI_Recv([[v41]], [[v43]], [[v44]], [[v12]], [[v12]], [[comm_4]], [[v47]]) : (!llvm.ptr, i32, i32, i32, i32, i32, !llvm.ptr) -> i32
+    %2 = mpi.recv(%arg0, %rank, %rank, %comm) : memref<100xf32>, i32, i32 -> !mpi.retval
+    
+    // CHECK: [[v51:%.*]] = llvm.mlir.constant(10 : i32) : i32
+    %color = arith.constant 10 : i32
+    // CHECK: [[v52:%.*]] = llvm.mlir.constant(22 : i32) : i32
+    %key = arith.constant 22 : i32
+    // CHECK: [[v53:%.*]] = llvm.trunc [[comm]] : i64 to i32
+    // CHECK: [[v54:%.*]] = llvm.mlir.constant(1 : i32) : i32
+    // CHECK: [[v55:%.*]] = llvm.alloca [[v54]] x i32 : (i32) -> !llvm.ptr
+    // CHECK: [[v56:%.*]] = llvm.call @MPI_Comm_split([[v53]], [[v51]], [[v52]], [[v55]]) : (i32, i32, i32, !llvm.ptr) -> i32
+    // CHECK: [[v57:%.*]] = llvm.load [[v55]] : !llvm.ptr -> i32
+    %split = mpi.comm_split(%comm, %color, %key) : !mpi.comm
+
+    // CHECK: [[v59:%.*]] = llvm.extractvalue [[v5]][1] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
+    // CHECK: [[v60:%.*]] = llvm.extractvalue [[v5]][2] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
+    // CHECK: [[v61:%.*]] = llvm.getelementptr [[v59]][[[v60]]] : (!llvm.ptr, i64) -> !llvm.ptr, f32
+    // CHECK: [[v62:%.*]] = llvm.extractvalue [[v5]][3, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
+    // CHECK: [[v63:%.*]] = llvm.trunc [[v62]] : i64 to i32
+    // CHECK: [[v64:%.*]] = llvm.extractvalue [[v5]][1] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
+    // CHECK: [[v65:%.*]] = llvm.extractvalue [[v5]][2] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
+    // CHECK: [[v66:%.*]] = llvm.getelementptr [[v64]][[[v65]]] : (!llvm.ptr, i64) -> !llvm.ptr, f32
+    // CHECK: [[v67:%.*]] = llvm.extractvalue [[v5]][3, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
+    // CHECK: [[v68:%.*]] = llvm.trunc [[v67]] : i64 to i32
+    // CHECK: [[v69:%.*]] = llvm.mlir.constant(1275069450 : i32) : i32
+    // CHECK: [[v70:%.*]] = llvm.mlir.constant(1476395011 : i32) : i32
+    // CHECK: [[v71:%.*]] = llvm.trunc [[comm]] : i64 to i32
+    // CHECK: [[v72:%.*]] = llvm.call @MPI_Allreduce([[v61]], [[v66]], [[v63]], [[v69]], [[v70]], [[v71]]) : (!llvm.ptr, !llvm.ptr, i32, i32, i32, i32) -> i32
+    mpi.allreduce(%arg0, %arg0, MPI_SUM, %comm) : memref<100xf32>, memref<100xf32>
 
     // CHECK: llvm.call @MPI_Finalize() : () -> i32
     %3 = mpi.finalize : !mpi.retval
@@ -101,6 +116,7 @@ module attributes {dlti.map = #dlti.map<"MPI:Implementation" = "MPICH">} {
 // COM: Test OpenMPI ABI
 // CHECK: module attributes {dlti.map = #dlti.map<"MPI:Implementation" = "OpenMPI">} {
 // CHECK: llvm.func @MPI_Finalize() -> i32
+// CHECK: llvm.func @MPI_Comm_split(!llvm.ptr, i32, i32, !llvm.ptr) -> i32
 // CHECK: llvm.func @MPI_Recv(!llvm.ptr, i32, !llvm.ptr, i32, i32, !llvm.ptr, !llvm.ptr) -> i32
 // CHECK: llvm.func @MPI_Send(!llvm.ptr, i32, !llvm.ptr, i32, i32, !llvm.ptr) -> i32
 // CHECK: llvm.mlir.global external @ompi_mpi_float() {addr_space = 0 : i32} : !llvm.struct<"ompi_predefined_datatype_t", opaque>
@@ -122,11 +138,14 @@ module attributes { dlti.map = #dlti.map<"MPI:Implementation" = "OpenMPI"> } {
     // CHECK: [[v7:%.*]] = llvm.call @MPI_Init([[v6]], [[v6]]) : (!llvm.ptr, !llvm.ptr) -> i32
     %0 = mpi.init : !mpi.retval
 
+    %comm = mpi.comm_world : !mpi.comm
     // CHECK: [[v8:%.*]] = llvm.mlir.addressof @ompi_mpi_comm_world : !llvm.ptr
+    // CHECK: [[comm:%.*]] = llvm.ptrtoint [[v8]] : !llvm.ptr to i64
+    // CHECK: [[comm_1:%.*]] = llvm.inttoptr [[comm]] : i64 to !llvm.ptr
     // CHECK: [[v9:%.*]] = llvm.mlir.constant(1 : i32) : i32
     // CHECK: [[v10:%.*]] = llvm.alloca [[v9]] x i32 : (i32) -> !llvm.ptr
-    // CHECK: [[v11:%.*]] = llvm.call @MPI_Comm_rank([[v8]], [[v10]]) : (!llvm.ptr, !llvm.ptr) -> i32
-    %retval, %rank = mpi.comm_rank : !mpi.retval, i32
+    // CHECK: [[v11:%.*]] = llvm.call @MPI_Comm_rank([[comm_1]], [[v10]]) : (!llvm.ptr, !llvm.ptr) -> i32
+    %retval, %rank = mpi.comm_rank(%comm) : !mpi.retval, i32
 
     // CHECK: [[v12:%.*]] = llvm.load [[v10]] : !llvm.ptr -> i32
     // CHECK: [[v13:%.*]] = llvm.extractvalue [[v5]][1] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
@@ -135,9 +154,9 @@ module attributes { dlti.map = #dlti.map<"MPI:Implementation" = "OpenMPI"> } {
     // CHECK: [[v16:%.*]] = llvm.extractvalue [[v5]][3, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
     // CHECK: [[v17:%.*]] = llvm.trunc [[v16]] : i64 to i32
     // CHECK: [[v18:%.*]] = llvm.mlir.addressof @ompi_mpi_float : !llvm.ptr
-    // CHECK: [[v19:%.*]] = llvm.mlir.addressof @ompi_mpi_comm_world : !llvm.ptr
+    // CHECK: [[v19:%.*]] = llvm.inttoptr [[comm]] : i64 to !llvm.ptr
     // CHECK: [[v20:%.*]] = llvm.call @MPI_Send([[v15]], [[v17]], [[v18]], [[v12]], [[v12]], [[v19]]) : (!llvm.ptr, i32, !llvm.ptr, i32, i32, !llvm.ptr) -> i32
-    mpi.send(%arg0, %rank, %rank) : memref<100xf32>, i32, i32
+    mpi.send(%arg0, %rank, %rank, %comm) : memref<100xf32>, i32, i32
 
     // CHECK: [[v21:%.*]] = llvm.extractvalue [[v5]][1] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
     // CHECK: [[v22:%.*]] = llvm.extractvalue [[v5]][2] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
@@ -145,9 +164,9 @@ module attributes { dlti.map = #dlti.map<"MPI:Implementation" = "OpenMPI"> } {
     // CHECK: [[v24:%.*]] = llvm.extractvalue [[v5]][3, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
     // CHECK: [[v25:%.*]] = llvm.trunc [[v24]] : i64 to i32
     // CHECK: [[v26:%.*]] = llvm.mlir.addressof @ompi_mpi_float : !llvm.ptr
-    // CHECK: [[v27:%.*]] = llvm.mlir.addressof @ompi_mpi_comm_world : !llvm.ptr
+    // CHECK: [[v27:%.*]] = llvm.inttoptr [[comm]] : i64 to !llvm.ptr
     // CHECK: [[v28:%.*]] = llvm.call @MPI_Send([[v23]], [[v25]], [[v26]], [[v12]], [[v12]], [[v27]]) : (!llvm.ptr, i32, !llvm.ptr, i32, i32, !llvm.ptr) -> i32
-    %1 = mpi.send(%arg0, %rank, %rank) : memref<100xf32>, i32, i32 -> !mpi.retval
+    %1 = mpi.send(%arg0, %rank, %rank, %comm) : memref<100xf32>, i32, i32 -> !mpi.retval
 
     // CHECK: [[v29:%.*]] = llvm.extractvalue [[v5]][1] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
     // CHECK: [[v30:%.*]] = llvm.extractvalue [[v5]][2] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
@@ -155,11 +174,11 @@ module attributes { dlti.map = #dlti.map<"MPI:Implementation" = "OpenMPI"> } {
     // CHECK: [[v32:%.*]] = llvm.extractvalue [[v5]][3, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
     // CHECK: [[v33:%.*]] = llvm.trunc [[v32]] : i64 to i32
     // CHECK: [[v34:%.*]] = llvm.mlir.addressof @ompi_mpi_float : !llvm.ptr
-    // CHECK: [[v35:%.*]] = llvm.mlir.addressof @ompi_mpi_comm_world : !llvm.ptr
+    // CHECK: [[v35:%.*]] = llvm.inttoptr [[comm]] : i64 to !llvm.ptr
     // CHECK: [[v36:%.*]] = llvm.mlir.constant(0 : i64) : i64
     // CHECK: [[v37:%.*]] = llvm.inttoptr [[v36]] : i64 to !llvm.ptr
     // CHECK: [[v38:%.*]] = llvm.call @MPI_Recv([[v31]], [[v33]], [[v34]], [[v12]], [[v12]], [[v35]], [[v37]]) : (!llvm.ptr, i32, !llvm.ptr, i32, i32, !llvm.ptr, !llvm.ptr) -> i32
-    mpi.recv(%arg0, %rank, %rank) : memref<100xf32>, i32, i32
+    mpi.recv(%arg0, %rank, %rank, %comm) : memref<100xf32>, i32, i32
 
     // CHECK: [[v39:%.*]] = llvm.extractvalue [[v5]][1] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
     // CHECK: [[v40:%.*]] = llvm.extractvalue [[v5]][2] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
@@ -167,11 +186,11 @@ module attributes { dlti.map = #dlti.map<"MPI:Implementation" = "OpenMPI"> } {
     // CHECK: [[v42:%.*]] = llvm.extractvalue [[v5]][3, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
     // CHECK: [[v43:%.*]] = llvm.trunc [[v42]] : i64 to i32
     // CHECK: [[v44:%.*]] = llvm.mlir.addressof @ompi_mpi_float : !llvm.ptr
-    // CHECK: [[v45:%.*]] = llvm.mlir.addressof @ompi_mpi_comm_world : !llvm.ptr
+    // CHECK: [[v45:%.*]] = llvm.inttoptr [[comm]] : i64 to !llvm.ptr
     // CHECK: [[v46:%.*]] = llvm.mlir.constant(0 : i64) : i64
     // CHECK: [[v47:%.*]] = llvm.inttoptr [[v46]] : i64 to !llvm.ptr
     // CHECK: [[v48:%.*]] = llvm.call @MPI_Recv([[v41]], [[v43]], [[v44]], [[v12]], [[v12]], [[v45]], [[v47]]) : (!llvm.ptr, i32, !llvm.ptr, i32, i32, !llvm.ptr, !llvm.ptr) -> i32
-    %2 = mpi.recv(%arg0, %rank, %rank) : memref<100xf32>, i32, i32 -> !mpi.retval
+    %2 = mpi.recv(%arg0, %rank, %rank, %comm) : memref<100xf32>, i32, i32 -> !mpi.retval
 
     // CHECK: [[v49:%.*]] = llvm.extractvalue [[v5]][1] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
     // CHECK: [[v50:%.*]] = llvm.extractvalue [[v5]][2] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
@@ -185,11 +204,22 @@ module attributes { dlti.map = #dlti.map<"MPI:Implementation" = "OpenMPI"> } {
     // CHECK: [[v58:%.*]] = llvm.trunc [[v57]] : i64 to i32
     // CHECK: [[v59:%.*]] = llvm.mlir.addressof @ompi_mpi_float : !llvm.ptr
     // CHECK: [[v60:%.*]] = llvm.mlir.addressof @ompi_mpi_sum : !llvm.ptr
-    // CHECK: [[v61:%.*]] = llvm.mlir.addressof @ompi_mpi_comm_world : !llvm.ptr
+    // CHECK: [[v61:%.*]] = llvm.inttoptr [[comm]] : i64 to !llvm.ptr
     // CHECK: [[v62:%.*]] = llvm.call @MPI_Allreduce([[v51]], [[v56]], [[v53]], [[v59]], [[v60]], [[v61]]) : (!llvm.ptr, !llvm.ptr, i32, !llvm.ptr, !llvm.ptr, !llvm.ptr) -> i32
-    mpi.allreduce(%arg0, %arg0, MPI_SUM) : memref<100xf32>, memref<100xf32>
+    mpi.allreduce(%arg0, %arg0, MPI_SUM, %comm) : memref<100xf32>, memref<100xf32>
+
+    // CHECK: [[v71:%.*]] = llvm.mlir.constant(10 : i32) : i32
+    %color = arith.constant 10 : i32
+    // CHECK: [[v72:%.*]] = llvm.mlir.constant(22 : i32) : i32
+    %key = arith.constant 22 : i32
+    // CHECK: [[v73:%.*]] = llvm.inttoptr [[comm]] : i64 to !llvm.ptr
+    // CHECK: [[v74:%.*]] = llvm.mlir.constant(1 : i32) : i32
+    // CHECK: [[v75:%.*]] = llvm.alloca [[v74]] x !llvm.ptr : (i32) -> !llvm.ptr
+    // CHECK: [[v76:%.*]] = llvm.call @MPI_Comm_split([[v73]], [[v71]], [[v72]], [[v75]]) : (!llvm.ptr, i32, i32, !llvm.ptr) -> i32
+    // CHECK: [[v77:%.*]] = llvm.load [[v75]] : !llvm.ptr -> i32
+    %split = mpi.comm_split(%comm, %color, %key) : !mpi.comm
 
-    // CHECK: [[v49:%.*]] = llvm.call @MPI_Finalize() : () -> i32
+    // CHECK: llvm.call @MPI_Finalize() : () -> i32
     %3 = mpi.finalize : !mpi.retval
 
     return

diff  --git a/mlir/test/Conversion/MeshToMPI/convert-mesh-to-mpi.mlir b/mlir/test/Conversion/MeshToMPI/convert-mesh-to-mpi.mlir
index 4e60c6f0d4e44..23756bb66928d 100644
--- a/mlir/test/Conversion/MeshToMPI/convert-mesh-to-mpi.mlir
+++ b/mlir/test/Conversion/MeshToMPI/convert-mesh-to-mpi.mlir
@@ -4,7 +4,7 @@
 // CHECK: mesh.mesh @mesh0
 mesh.mesh @mesh0(shape = 3x4x5)
 func.func @process_multi_index() -> (index, index, index) {
-  // CHECK: mpi.comm_rank : !mpi.retval, i32
+  // CHECK: mpi.comm_rank
   // CHECK-DAG: %[[v4:.*]] = arith.remsi
   // CHECK-DAG: %[[v0:.*]] = arith.remsi
   // CHECK-DAG: %[[v1:.*]] = arith.remsi
@@ -15,7 +15,7 @@ func.func @process_multi_index() -> (index, index, index) {
 
 // CHECK-LABEL: func @process_linear_index
 func.func @process_linear_index() -> index {
-  // CHECK: %[[RES:.*]], %[[rank:.*]] = mpi.comm_rank : !mpi.retval, i32
+  // CHECK: %[[RES:.*]], %[[rank:.*]] = mpi.comm_rank
   // CHECK: %[[cast:.*]] = arith.index_cast %[[rank]] : i32 to index
   %0 = mesh.process_linear_index on @mesh0 : index
   // CHECK: return %[[cast]] : index
@@ -113,17 +113,17 @@ module attributes { mpi.dlti = #dlti.map<"MPI:comm_world_rank" = 1> } {
     // CHECK: [[vc91_i32:%.*]] = arith.constant 91 : i32
     // CHECK-NEXT: [[vc0_i32:%.*]] = arith.constant 0 : i32
     // CHECK-NEXT: [[vc2_i32:%.*]] = arith.constant 2 : i32
+    // CHECK-NEXT: [[v0:%.*]] = mpi.comm_world : !mpi.comm
     // CHECK-NEXT: [[valloc:%.*]] = memref.alloc() : memref<2x120x120xi8>
-    // CHECK-NEXT: [[vsubview:%.*]] = memref.subview [[varg0]][118, 0, 0] [2, 120, 120] [1, 1, 1] : memref<120x120x120xi8> to memref<2x120x120xi8
-    // CHECK-NEXT: memref.copy [[vsubview]], [[valloc]] : memref<2x120x120xi8
-    // CHECK-SAME: to memref<2x120x120xi8>
-    // CHECK-NEXT: mpi.send([[valloc]], [[vc91_i32]], [[vc2_i32]]) : memref<2x120x120xi8>, i32, i32
-    // CHECK-NEXT: mpi.recv([[valloc]], [[vc91_i32]], [[vc0_i32]]) : memref<2x120x120xi8>, i32, i32
-    // CHECK-NEXT: [[vsubview_0:%.*]] = memref.subview [[varg0]][0, 0, 0] [2, 120, 120] [1, 1, 1] : memref<120x120x120xi8> to memref<2x120x120xi8
-    // CHECK-NEXT: memref.copy [[valloc]], [[vsubview_0]] : memref<2x120x120xi8> to memref<2x120x120xi8
+    // CHECK-NEXT: [[vsubview:%.*]] = memref.subview [[varg0]][118, 0, 0] [2, 120, 120] [1, 1, 1] : memref<120x120x120xi8> to memref<2x120x120xi8, strided<[14400, 120, 1], offset: 1699200>>
+    // CHECK-NEXT: memref.copy [[vsubview]], [[valloc]] : memref<2x120x120xi8, strided<[14400, 120, 1], offset: 1699200>> to memref<2x120x120xi8>
+    // CHECK-NEXT: mpi.send([[valloc]], [[vc91_i32]], [[vc2_i32]], [[v0]]) : memref<2x120x120xi8>, i32, i32
+    // CHECK-NEXT: mpi.recv([[valloc]], [[vc91_i32]], [[vc0_i32]], [[v0]]) : memref<2x120x120xi8>, i32, i32
+    // CHECK-NEXT: [[vsubview_0:%.*]] = memref.subview [[varg0]][0, 0, 0] [2, 120, 120] [1, 1, 1] : memref<120x120x120xi8> to memref<2x120x120xi8, strided<[14400, 120, 1]>>
+    // CHECK-NEXT: memref.copy [[valloc]], [[vsubview_0]] : memref<2x120x120xi8> to memref<2x120x120xi8, strided<[14400, 120, 1]>>
     // CHECK-NEXT: memref.dealloc [[valloc]] : memref<2x120x120xi8>
     %res = mesh.update_halo %arg0 on @mesh0 split_axes = [[0]] halo_sizes = [2, 0] : memref<120x120x120xi8>
-    // CHECK: return [[res:%.*]] : memref<120x120x120xi8>
+    // CHECK: return [[varg0]] : memref<120x120x120xi8>
     return %res : memref<120x120x120xi8>
   }
 }
@@ -140,41 +140,44 @@ module attributes { mpi.dlti = #dlti.map<"MPI:comm_world_rank" = 24> } {
     // CHECK-NEXT: [[vc91_i32:%.*]] = arith.constant 91 : i32
     // CHECK-NEXT: [[vc4_i32:%.*]] = arith.constant 4 : i32
     // CHECK-NEXT: [[vc44_i32:%.*]] = arith.constant 44 : i32
+    // CHECK-NEXT: [[v0:%.*]] = mpi.comm_world : !mpi.comm
     // CHECK-NEXT: [[valloc:%.*]] = memref.alloc() : memref<117x113x5xi8>
     // CHECK-NEXT: [[vsubview:%.*]] = memref.subview [[varg0]][1, 3, 109] [117, 113, 5] [1, 1, 1] : memref<120x120x120xi8> to memref<117x113x5xi8, strided<[14400, 120, 1], offset: 14869>>
     // CHECK-NEXT: memref.copy [[vsubview]], [[valloc]] : memref<117x113x5xi8, strided<[14400, 120, 1], offset: 14869>> to memref<117x113x5xi8>
-    // CHECK-NEXT: mpi.send([[valloc]], [[vc91_i32]], [[vc44_i32]]) : memref<117x113x5xi8>, i32, i32
-    // CHECK-NEXT: mpi.recv([[valloc]], [[vc91_i32]], [[vc4_i32]]) : memref<117x113x5xi8>, i32, i32
+    // CHECK-NEXT: mpi.send([[valloc]], [[vc91_i32]], [[vc44_i32]], [[v0]]) : memref<117x113x5xi8>, i32, i32
+    // CHECK-NEXT: mpi.recv([[valloc]], [[vc91_i32]], [[vc4_i32]], [[v0]]) : memref<117x113x5xi8>, i32, i32
     // CHECK-NEXT: [[vsubview_0:%.*]] = memref.subview [[varg0]][1, 3, 0] [117, 113, 5] [1, 1, 1] : memref<120x120x120xi8> to memref<117x113x5xi8, strided<[14400, 120, 1], offset: 14760>>
     // CHECK-NEXT: memref.copy [[valloc]], [[vsubview_0]] : memref<117x113x5xi8> to memref<117x113x5xi8, strided<[14400, 120, 1], offset: 14760>>
     // CHECK-NEXT: memref.dealloc [[valloc]] : memref<117x113x5xi8>
     // CHECK-NEXT: [[valloc_1:%.*]] = memref.alloc() : memref<117x113x6xi8>
     // CHECK-NEXT: [[vsubview_2:%.*]] = memref.subview [[varg0]][1, 3, 5] [117, 113, 6] [1, 1, 1] : memref<120x120x120xi8> to memref<117x113x6xi8, strided<[14400, 120, 1], offset: 14765>>
     // CHECK-NEXT: memref.copy [[vsubview_2]], [[valloc_1]] : memref<117x113x6xi8, strided<[14400, 120, 1], offset: 14765>> to memref<117x113x6xi8>
-    // CHECK-NEXT: mpi.send([[valloc_1]], [[vc91_i32]], [[vc4_i32]]) : memref<117x113x6xi8>, i32, i32
-    // CHECK-NEXT: mpi.recv([[valloc_1]], [[vc91_i32]], [[vc44_i32]]) : memref<117x113x6xi8>, i32, i32
+    // CHECK-NEXT: mpi.send([[valloc_1]], [[vc91_i32]], [[vc4_i32]], [[v0]]) : memref<117x113x6xi8>, i32, i32
+    // CHECK-NEXT: mpi.recv([[valloc_1]], [[vc91_i32]], [[vc44_i32]], [[v0]]) : memref<117x113x6xi8>, i32, i32
     // CHECK-NEXT: [[vsubview_3:%.*]] = memref.subview [[varg0]][1, 3, 114] [117, 113, 6] [1, 1, 1] : memref<120x120x120xi8> to memref<117x113x6xi8, strided<[14400, 120, 1], offset: 14874>>
     // CHECK-NEXT: memref.copy [[valloc_1]], [[vsubview_3]] : memref<117x113x6xi8> to memref<117x113x6xi8, strided<[14400, 120, 1], offset: 14874>>
     // CHECK-NEXT: memref.dealloc [[valloc_1]] : memref<117x113x6xi8>
+    // CHECK-NEXT: [[v1:%.*]] = mpi.comm_world : !mpi.comm
     // CHECK-NEXT: [[valloc_4:%.*]] = memref.alloc() : memref<117x3x120xi8>
     // CHECK-NEXT: [[vsubview_5:%.*]] = memref.subview [[varg0]][1, 113, 0] [117, 3, 120] [1, 1, 1] : memref<120x120x120xi8> to memref<117x3x120xi8, strided<[14400, 120, 1], offset: 27960>>
     // CHECK-NEXT: memref.copy [[vsubview_5]], [[valloc_4]] : memref<117x3x120xi8, strided<[14400, 120, 1], offset: 27960>> to memref<117x3x120xi8>
-    // CHECK-NEXT: mpi.send([[valloc_4]], [[vc91_i32]], [[vc29_i32]]) : memref<117x3x120xi8>, i32, i32
+    // CHECK-NEXT: mpi.send([[valloc_4]], [[vc91_i32]], [[vc29_i32]], [[v1]]) : memref<117x3x120xi8>, i32, i32
     // CHECK-NEXT: memref.dealloc [[valloc_4]] : memref<117x3x120xi8>
     // CHECK-NEXT: [[valloc_6:%.*]] = memref.alloc() : memref<117x4x120xi8>
-    // CHECK-NEXT: mpi.recv([[valloc_6]], [[vc91_i32]], [[vc29_i32]]) : memref<117x4x120xi8>, i32, i32
+    // CHECK-NEXT: mpi.recv([[valloc_6]], [[vc91_i32]], [[vc29_i32]], [[v1]]) : memref<117x4x120xi8>, i32, i32
     // CHECK-NEXT: [[vsubview_7:%.*]] = memref.subview [[varg0]][1, 116, 0] [117, 4, 120] [1, 1, 1] : memref<120x120x120xi8> to memref<117x4x120xi8, strided<[14400, 120, 1], offset: 28320>>
     // CHECK-NEXT: memref.copy [[valloc_6]], [[vsubview_7]] : memref<117x4x120xi8> to memref<117x4x120xi8, strided<[14400, 120, 1], offset: 28320>>
     // CHECK-NEXT: memref.dealloc [[valloc_6]] : memref<117x4x120xi8>
+    // CHECK-NEXT: [[v2:%.*]] = mpi.comm_world : !mpi.comm
     // CHECK-NEXT: [[valloc_8:%.*]] = memref.alloc() : memref<1x120x120xi8>
-    // CHECK-NEXT: mpi.recv([[valloc_8]], [[vc91_i32]], [[vc23_i32]]) : memref<1x120x120xi8>, i32, i32
+    // CHECK-NEXT: mpi.recv([[valloc_8]], [[vc91_i32]], [[vc23_i32]], [[v2]]) : memref<1x120x120xi8>, i32, i32
     // CHECK-NEXT: [[vsubview_9:%.*]] = memref.subview [[varg0]][0, 0, 0] [1, 120, 120] [1, 1, 1] : memref<120x120x120xi8> to memref<1x120x120xi8, strided<[14400, 120, 1]>>
     // CHECK-NEXT: memref.copy [[valloc_8]], [[vsubview_9]] : memref<1x120x120xi8> to memref<1x120x120xi8, strided<[14400, 120, 1]>>
     // CHECK-NEXT: memref.dealloc [[valloc_8]] : memref<1x120x120xi8>
     // CHECK-NEXT: [[valloc_10:%.*]] = memref.alloc() : memref<2x120x120xi8>
     // CHECK-NEXT: [[vsubview_11:%.*]] = memref.subview [[varg0]][1, 0, 0] [2, 120, 120] [1, 1, 1] : memref<120x120x120xi8> to memref<2x120x120xi8, strided<[14400, 120, 1], offset: 14400>>
     // CHECK-NEXT: memref.copy [[vsubview_11]], [[valloc_10]] : memref<2x120x120xi8, strided<[14400, 120, 1], offset: 14400>> to memref<2x120x120xi8>
-    // CHECK-NEXT: mpi.send([[valloc_10]], [[vc91_i32]], [[vc23_i32]]) : memref<2x120x120xi8>, i32, i32
+    // CHECK-NEXT: mpi.send([[valloc_10]], [[vc91_i32]], [[vc23_i32]], [[v2]]) : memref<2x120x120xi8>, i32, i32
     // CHECK-NEXT: memref.dealloc [[valloc_10]] : memref<2x120x120xi8>
     %res = mesh.update_halo %arg0 on @mesh0 split_axes = [[2], [1], [0]] halo_sizes = [1, 2, 3, 4, 5, 6] : memref<120x120x120xi8>
     // CHECK: return [[varg0]] : memref<120x120x120xi8>
@@ -191,45 +194,48 @@ module attributes { mpi.dlti = #dlti.map<"MPI:comm_world_rank" = 24> } {
     // CHECK-NEXT: [[vc4_i32:%.*]] = arith.constant 4 : i32
     // CHECK-NEXT: [[vc91_i32:%.*]] = arith.constant 91 : i32
     // CHECK-NEXT: [[v0:%.*]] = bufferization.to_memref [[varg0]] : tensor<120x120x120xi8> to memref<120x120x120xi8>
+    // CHECK-NEXT: [[v1:%.*]] = mpi.comm_world : !mpi.comm
     // CHECK-NEXT: [[valloc:%.*]] = memref.alloc() : memref<117x113x5xi8>
     // CHECK-NEXT: [[vsubview:%.*]] = memref.subview [[v0]][1, 3, 109] [117, 113, 5] [1, 1, 1] : memref<120x120x120xi8> to memref<117x113x5xi8, strided<[14400, 120, 1], offset: 14869>>
     // CHECK-NEXT: memref.copy [[vsubview]], [[valloc]] : memref<117x113x5xi8, strided<[14400, 120, 1], offset: 14869>> to memref<117x113x5xi8>
-    // CHECK-NEXT: mpi.send([[valloc]], [[vc91_i32]], [[vc44_i32]]) : memref<117x113x5xi8>, i32, i32
-    // CHECK-NEXT: mpi.recv([[valloc]], [[vc91_i32]], [[vc4_i32]]) : memref<117x113x5xi8>, i32, i32
+    // CHECK-NEXT: mpi.send([[valloc]], [[vc91_i32]], [[vc44_i32]], [[v1]]) : memref<117x113x5xi8>, i32, i32
+    // CHECK-NEXT: mpi.recv([[valloc]], [[vc91_i32]], [[vc4_i32]], [[v1]]) : memref<117x113x5xi8>, i32, i32
     // CHECK-NEXT: [[vsubview_0:%.*]] = memref.subview [[v0]][1, 3, 0] [117, 113, 5] [1, 1, 1] : memref<120x120x120xi8> to memref<117x113x5xi8, strided<[14400, 120, 1], offset: 14760>>
     // CHECK-NEXT: memref.copy [[valloc]], [[vsubview_0]] : memref<117x113x5xi8> to memref<117x113x5xi8, strided<[14400, 120, 1], offset: 14760>>
     // CHECK-NEXT: memref.dealloc [[valloc]] : memref<117x113x5xi8>
     // CHECK-NEXT: [[valloc_1:%.*]] = memref.alloc() : memref<117x113x6xi8>
     // CHECK-NEXT: [[vsubview_2:%.*]] = memref.subview [[v0]][1, 3, 5] [117, 113, 6] [1, 1, 1] : memref<120x120x120xi8> to memref<117x113x6xi8, strided<[14400, 120, 1], offset: 14765>>
     // CHECK-NEXT: memref.copy [[vsubview_2]], [[valloc_1]] : memref<117x113x6xi8, strided<[14400, 120, 1], offset: 14765>> to memref<117x113x6xi8>
-    // CHECK-NEXT: mpi.send([[valloc_1]], [[vc91_i32]], [[vc4_i32]]) : memref<117x113x6xi8>, i32, i32
-    // CHECK-NEXT: mpi.recv([[valloc_1]], [[vc91_i32]], [[vc44_i32]]) : memref<117x113x6xi8>, i32, i32
+    // CHECK-NEXT: mpi.send([[valloc_1]], [[vc91_i32]], [[vc4_i32]], [[v1]]) : memref<117x113x6xi8>, i32, i32
+    // CHECK-NEXT: mpi.recv([[valloc_1]], [[vc91_i32]], [[vc44_i32]], [[v1]]) : memref<117x113x6xi8>, i32, i32
     // CHECK-NEXT: [[vsubview_3:%.*]] = memref.subview [[v0]][1, 3, 114] [117, 113, 6] [1, 1, 1] : memref<120x120x120xi8> to memref<117x113x6xi8, strided<[14400, 120, 1], offset: 14874>>
     // CHECK-NEXT: memref.copy [[valloc_1]], [[vsubview_3]] : memref<117x113x6xi8> to memref<117x113x6xi8, strided<[14400, 120, 1], offset: 14874>>
     // CHECK-NEXT: memref.dealloc [[valloc_1]] : memref<117x113x6xi8>
+    // CHECK-NEXT: [[v2:%.*]] = mpi.comm_world : !mpi.comm
     // CHECK-NEXT: [[valloc_4:%.*]] = memref.alloc() : memref<117x3x120xi8>
     // CHECK-NEXT: [[vsubview_5:%.*]] = memref.subview [[v0]][1, 113, 0] [117, 3, 120] [1, 1, 1] : memref<120x120x120xi8> to memref<117x3x120xi8, strided<[14400, 120, 1], offset: 27960>>
     // CHECK-NEXT: memref.copy [[vsubview_5]], [[valloc_4]] : memref<117x3x120xi8, strided<[14400, 120, 1], offset: 27960>> to memref<117x3x120xi8>
-    // CHECK-NEXT: mpi.send([[valloc_4]], [[vc91_i32]], [[vc29_i32]]) : memref<117x3x120xi8>, i32, i32
+    // CHECK-NEXT: mpi.send([[valloc_4]], [[vc91_i32]], [[vc29_i32]], [[v2]]) : memref<117x3x120xi8>, i32, i32
     // CHECK-NEXT: memref.dealloc [[valloc_4]] : memref<117x3x120xi8>
     // CHECK-NEXT: [[valloc_6:%.*]] = memref.alloc() : memref<117x4x120xi8>
-    // CHECK-NEXT: mpi.recv([[valloc_6]], [[vc91_i32]], [[vc29_i32]]) : memref<117x4x120xi8>, i32, i32
+    // CHECK-NEXT: mpi.recv([[valloc_6]], [[vc91_i32]], [[vc29_i32]], [[v2]]) : memref<117x4x120xi8>, i32, i32
     // CHECK-NEXT: [[vsubview_7:%.*]] = memref.subview [[v0]][1, 116, 0] [117, 4, 120] [1, 1, 1] : memref<120x120x120xi8> to memref<117x4x120xi8, strided<[14400, 120, 1], offset: 28320>>
     // CHECK-NEXT: memref.copy [[valloc_6]], [[vsubview_7]] : memref<117x4x120xi8> to memref<117x4x120xi8, strided<[14400, 120, 1], offset: 28320>>
     // CHECK-NEXT: memref.dealloc [[valloc_6]] : memref<117x4x120xi8>
+    // CHECK-NEXT: [[v3:%.*]] = mpi.comm_world : !mpi.comm
     // CHECK-NEXT: [[valloc_8:%.*]] = memref.alloc() : memref<1x120x120xi8>
-    // CHECK-NEXT: mpi.recv([[valloc_8]], [[vc91_i32]], [[vc23_i32]]) : memref<1x120x120xi8>, i32, i32
+    // CHECK-NEXT: mpi.recv([[valloc_8]], [[vc91_i32]], [[vc23_i32]], [[v3]]) : memref<1x120x120xi8>, i32, i32
     // CHECK-NEXT: [[vsubview_9:%.*]] = memref.subview [[v0]][0, 0, 0] [1, 120, 120] [1, 1, 1] : memref<120x120x120xi8> to memref<1x120x120xi8, strided<[14400, 120, 1]>>
     // CHECK-NEXT: memref.copy [[valloc_8]], [[vsubview_9]] : memref<1x120x120xi8> to memref<1x120x120xi8, strided<[14400, 120, 1]>>
     // CHECK-NEXT: memref.dealloc [[valloc_8]] : memref<1x120x120xi8>
     // CHECK-NEXT: [[valloc_10:%.*]] = memref.alloc() : memref<2x120x120xi8>
     // CHECK-NEXT: [[vsubview_11:%.*]] = memref.subview [[v0]][1, 0, 0] [2, 120, 120] [1, 1, 1] : memref<120x120x120xi8> to memref<2x120x120xi8, strided<[14400, 120, 1], offset: 14400>>
     // CHECK-NEXT: memref.copy [[vsubview_11]], [[valloc_10]] : memref<2x120x120xi8, strided<[14400, 120, 1], offset: 14400>> to memref<2x120x120xi8>
-    // CHECK-NEXT: mpi.send([[valloc_10]], [[vc91_i32]], [[vc23_i32]]) : memref<2x120x120xi8>, i32, i32
+    // CHECK-NEXT: mpi.send([[valloc_10]], [[vc91_i32]], [[vc23_i32]], [[v3]]) : memref<2x120x120xi8>, i32, i32
     // CHECK-NEXT: memref.dealloc [[valloc_10]] : memref<2x120x120xi8>
-    // CHECK-NEXT: [[v1:%.*]] = bufferization.to_tensor [[v0]] restrict writable : memref<120x120x120xi8> to tensor<120x120x120xi8>
+    // CHECK-NEXT: [[v4:%.*]] = bufferization.to_tensor [[v0]] restrict writable : memref<120x120x120xi8> to tensor<120x120x120xi8>
     %res = mesh.update_halo %arg0 on @mesh0 split_axes = [[2], [1], [0]] halo_sizes = [1, 2, 3, 4, 5, 6] : tensor<120x120x120xi8>
-    // CHECK: return [[v1]] : tensor<120x120x120xi8>
+    // CHECK-NEXT: return [[v4]] : tensor<120x120x120xi8>
     return %res : tensor<120x120x120xi8>
   }
 }

diff  --git a/mlir/test/Dialect/MPI/mpiops.mlir b/mlir/test/Dialect/MPI/mpiops.mlir
index fb4333611a246..ef457628fe2c4 100644
--- a/mlir/test/Dialect/MPI/mpiops.mlir
+++ b/mlir/test/Dialect/MPI/mpiops.mlir
@@ -1,66 +1,83 @@
 // RUN: mlir-opt %s | mlir-opt | FileCheck %s
 
+// CHECK-LABEL: func.func @mpi_test(
+// CHECK-SAME: [[varg0:%.*]]: memref<100xf32>) {
 func.func @mpi_test(%ref : memref<100xf32>) -> () {
     // Note: the !mpi.retval result is optional on all operations except mpi.error_class
 
-    // CHECK: %0 = mpi.init : !mpi.retval
+    // CHECK-NEXT: [[v0:%.*]] = mpi.init : !mpi.retval
     %err = mpi.init : !mpi.retval
 
-    // CHECK-NEXT: %retval, %rank = mpi.comm_rank : !mpi.retval, i32
-    %retval, %rank = mpi.comm_rank : !mpi.retval, i32
+    // CHECK-NEXT: [[v1:%.*]] = mpi.comm_world : !mpi.comm
+    %comm = mpi.comm_world : !mpi.comm
 
-    // CHECK-NEXT: %retval_0, %size = mpi.comm_size : !mpi.retval, i32
-    %retval_0, %size = mpi.comm_size : !mpi.retval, i32
+    // CHECK-NEXT: [[vrank:%.*]] = mpi.comm_rank([[v1]]) : i32
+    %rank = mpi.comm_rank(%comm) : i32
 
-    // CHECK-NEXT: mpi.send(%arg0, %rank, %rank) : memref<100xf32>, i32, i32
-    mpi.send(%ref, %rank, %rank) : memref<100xf32>, i32, i32
+    // CHECK-NEXT: [[vretval:%.*]], [[vrank_0:%.*]] = mpi.comm_rank([[v1]]) : !mpi.retval, i32
+    %retval, %rank_1 = mpi.comm_rank(%comm) : !mpi.retval, i32
 
-    // CHECK-NEXT: %1 = mpi.send(%arg0, %rank, %rank) : memref<100xf32>, i32, i32 -> !mpi.retval
-    %err2 = mpi.send(%ref, %rank, %rank) : memref<100xf32>, i32, i32 -> !mpi.retval
+    // CHECK-NEXT: [[vsize:%.*]] = mpi.comm_size([[v1]]) : i32
+    %size = mpi.comm_size(%comm) : i32
 
-    // CHECK-NEXT: mpi.recv(%arg0, %rank, %rank) : memref<100xf32>, i32, i32
-    mpi.recv(%ref, %rank, %rank) : memref<100xf32>, i32, i32
+    // CHECK-NEXT: [[vretval_1:%.*]], [[vsize_2:%.*]] = mpi.comm_size([[v1]]) : !mpi.retval, i32
+    %retval_0, %size_1 = mpi.comm_size(%comm) : !mpi.retval, i32
 
-    // CHECK-NEXT: %2 = mpi.recv(%arg0, %rank, %rank) : memref<100xf32>, i32, i32 -> !mpi.retval
-    %err3 = mpi.recv(%ref, %rank, %rank) : memref<100xf32>, i32, i32 -> !mpi.retval
+    // CHECK-NEXT: [[vnewcomm:%.*]] = mpi.comm_split([[v1]], [[vrank]], [[vrank]]) : !mpi.comm
+    %new_comm = mpi.comm_split(%comm, %rank, %rank) : !mpi.comm
 
-    // CHECK-NEXT: %req = mpi.isend(%arg0, %rank, %rank) : memref<100xf32>, i32, i32 -> !mpi.request
-    %req = mpi.isend(%ref, %rank, %rank) : memref<100xf32>, i32, i32 -> !mpi.request
+    // CHECK-NEXT: [[vretval_3:%.*]], [[vnewcomm_4:%.*]] = mpi.comm_split([[v1]], [[vrank]], [[vrank]]) : !mpi.retval, !mpi.comm
+    %retval_1, %new_comm_1 = mpi.comm_split(%comm, %rank, %rank) : !mpi.retval, !mpi.comm
 
-    // CHECK-NEXT: %retval_1, %req_2 = mpi.isend(%arg0, %rank, %rank) : memref<100xf32>, i32, i32 -> !mpi.retval, !mpi.request
-    %err4, %req2 = mpi.isend(%ref, %rank, %rank) : memref<100xf32>, i32, i32 -> !mpi.retval, !mpi.request
+    // CHECK-NEXT: mpi.send([[varg0]], [[vrank]], [[vrank]], [[v1]]) : memref<100xf32>, i32, i32
+    mpi.send(%ref, %rank, %rank, %comm) : memref<100xf32>, i32, i32
 
-    // CHECK-NEXT: %req_3 = mpi.irecv(%arg0, %rank, %rank) : memref<100xf32>, i32, i32 -> !mpi.request
-    %req3 = mpi.irecv(%ref, %rank, %rank) : memref<100xf32>, i32, i32 -> !mpi.request
+    // CHECK-NEXT: [[v2:%.*]] = mpi.send([[varg0]], [[vrank]], [[vrank]], [[v1]]) : memref<100xf32>, i32, i32 -> !mpi.retval
+    %retval_2 = mpi.send(%ref, %rank, %rank, %comm) : memref<100xf32>, i32, i32 -> !mpi.retval
 
-    // CHECK-NEXT: %retval_4, %req_5 = mpi.irecv(%arg0, %rank, %rank) : memref<100xf32>, i32, i32 -> !mpi.retval, !mpi.request
-    %err5, %req4 = mpi.irecv(%ref, %rank, %rank) : memref<100xf32>, i32, i32 -> !mpi.retval, !mpi.request
+    // CHECK-NEXT: mpi.recv([[varg0]], [[vrank]], [[vrank]], [[v1]]) : memref<100xf32>, i32, i32
+    mpi.recv(%ref, %rank, %rank, %comm) : memref<100xf32>, i32, i32
 
-    // CHECK-NEXT: mpi.wait(%req) : !mpi.request
-    mpi.wait(%req) : !mpi.request
+    // CHECK-NEXT: [[v3:%.*]] = mpi.recv([[varg0]], [[vrank]], [[vrank]], [[v1]]) : memref<100xf32>, i32, i32 -> !mpi.retval
+    %retval_3 = mpi.recv(%ref, %rank, %rank, %comm) : memref<100xf32>, i32, i32 -> !mpi.retval
 
-    // CHECK-NEXT: %3 = mpi.wait(%req_2) : !mpi.request -> !mpi.retval
+    // CHECK-NEXT: [[vretval_5:%.*]], [[vreq:%.*]] = mpi.isend([[varg0]], [[vrank]], [[vrank]], [[v1]]) : memref<100xf32>, i32, i32 -> !mpi.retval, !mpi.request
+    %err4, %req2 = mpi.isend(%ref, %rank, %rank, %comm) : memref<100xf32>, i32, i32 -> !mpi.retval, !mpi.request
+
+    // CHECK-NEXT: [[vreq_6:%.*]] = mpi.isend([[varg0]], [[vrank]], [[vrank]], [[v1]]) : memref<100xf32>, i32, i32 -> !mpi.request
+    %req1 = mpi.isend(%ref, %rank, %rank, %comm) : memref<100xf32>, i32, i32 -> !mpi.request
+
+    // CHECK-NEXT: [[vreq_7:%.*]] = mpi.irecv([[varg0]], [[vrank]], [[vrank]], [[v1]]) : memref<100xf32>, i32, i32 -> !mpi.request
+    %req3 = mpi.irecv(%ref, %rank, %rank, %comm) : memref<100xf32>, i32, i32 -> !mpi.request
+
+    // CHECK-NEXT: [[vretval_8:%.*]], [[vreq_9:%.*]] = mpi.irecv([[varg0]], [[vrank]], [[vrank]], [[v1]]) : memref<100xf32>, i32, i32 -> !mpi.retval, !mpi.request
+    %err5, %req4 = mpi.irecv(%ref, %rank, %rank, %comm) : memref<100xf32>, i32, i32 -> !mpi.retval, !mpi.request
+
+    // CHECK-NEXT: mpi.wait([[vreq_9]]) : !mpi.request
+    mpi.wait(%req4) : !mpi.request
+
+    // CHECK-NEXT: [[v4:%.*]] = mpi.wait([[vreq]]) : !mpi.request -> !mpi.retval
     %err6 = mpi.wait(%req2) : !mpi.request -> !mpi.retval
 
-    // CHECK-NEXT: mpi.barrier : !mpi.retval
-    mpi.barrier : !mpi.retval
+    // CHECK-NEXT: mpi.barrier([[v1]])
+    mpi.barrier(%comm)
 
-    // CHECK-NEXT: %5 = mpi.barrier : !mpi.retval
-    %err7 = mpi.barrier : !mpi.retval
+    // CHECK-NEXT: [[v5:%.*]] = mpi.barrier([[v1]]) -> !mpi.retval
+    %err7 = mpi.barrier(%comm) -> !mpi.retval
 
-    // CHECK-NEXT: mpi.allreduce(%arg0, %arg0, MPI_SUM) : memref<100xf32>, memref<100xf32>
-    mpi.allreduce(%ref, %ref, MPI_SUM) : memref<100xf32>, memref<100xf32>
+    // CHECK-NEXT: [[v6:%.*]] = mpi.allreduce([[varg0]], [[varg0]], MPI_SUM, [[v1]]) : memref<100xf32>, memref<100xf32> -> !mpi.retval
+    %err8 = mpi.allreduce(%ref, %ref, MPI_SUM, %comm) : memref<100xf32>, memref<100xf32> -> !mpi.retval
 
-    // CHECK-NEXT: mpi.allreduce(%arg0, %arg0, MPI_SUM) : memref<100xf32>, memref<100xf32> -> !mpi.retval
-    %err8 = mpi.allreduce(%ref, %ref, MPI_SUM) : memref<100xf32>, memref<100xf32> -> !mpi.retval
+    // CHECK-NEXT: mpi.allreduce([[varg0]], [[varg0]], MPI_SUM, [[v1]]) : memref<100xf32>, memref<100xf32>
+    mpi.allreduce(%ref, %ref, MPI_SUM, %comm) : memref<100xf32>, memref<100xf32>
 
-    // CHECK-NEXT: %7 = mpi.finalize : !mpi.retval
+    // CHECK-NEXT: [[v7:%.*]] = mpi.finalize : !mpi.retval
     %rval = mpi.finalize : !mpi.retval
 
-    // CHECK-NEXT: %8 = mpi.retval_check %retval = <MPI_SUCCESS> : i1
+    // CHECK-NEXT: [[v8:%.*]] = mpi.retval_check [[vretval:%.*]] = <MPI_SUCCESS> : i1
     %res = mpi.retval_check %retval = <MPI_SUCCESS> : i1
 
-    // CHECK-NEXT: %9 = mpi.error_class %0 : !mpi.retval
+    // CHECK-NEXT: [[v9:%.*]] = mpi.error_class [[v0]] : !mpi.retval
     %errclass = mpi.error_class %err : !mpi.retval
 
     // CHECK-NEXT: return


        


More information about the Mlir-commits mailing list