[Mlir-commits] [mlir] 48f8865 - [MLIR] Extend MPI dialect (#123255)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Sat Feb 1 05:33:26 PST 2025


Author: Sergio Sánchez Ramírez
Date: 2025-02-01T07:33:22-06:00
New Revision: 48f88651a01b050a28be99e5cdffe495754ea79a

URL: https://github.com/llvm/llvm-project/commit/48f88651a01b050a28be99e5cdffe495754ea79a
DIFF: https://github.com/llvm/llvm-project/commit/48f88651a01b050a28be99e5cdffe495754ea79a.diff

LOG: [MLIR] Extend MPI dialect (#123255)

cc @tobiasgrosser @wsmoses

This PR adds some new ops and types to the MLIR MPI dialect. The goal is
to add the minimum set of ops required to get one of our projects working
and, if everything works well, to continue adding ops to the MPI dialect in
subsequent PRs until we achieve some level of compliance with the MPI
standard.

---

Things left to do in subsequent PRs:

- Add back the `mpi.comm` type and add it as an optional argument of the
currently implemented ops that should support it (i.e. `send`, `recv`,
`isend`, `irecv`, `allreduce`, `barrier`).
- Support defining custom `MPI_Op`s (the MPI reduction operations, not the
TableGen `MPI_Op` class) as regions.
- Add more ops.

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/MPI/IR/MPI.td
    mlir/include/mlir/Dialect/MPI/IR/MPIOps.td
    mlir/include/mlir/Dialect/MPI/IR/MPITypes.td
    mlir/lib/Dialect/MPI/IR/MPIOps.cpp
    mlir/test/Dialect/MPI/ops.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/MPI/IR/MPI.td b/mlir/include/mlir/Dialect/MPI/IR/MPI.td
index 643612e1e2ee895..7c84443e5520d92 100644
--- a/mlir/include/mlir/Dialect/MPI/IR/MPI.td
+++ b/mlir/include/mlir/Dialect/MPI/IR/MPI.td
@@ -215,4 +215,43 @@ def MPI_ErrorClassAttr : EnumAttr<MPI_Dialect, MPI_ErrorClassEnum, "errclass"> {
   let assemblyFormat = "`<` $value `>`";
 }
 
+def MPI_OpNull : I32EnumAttrCase<"MPI_OP_NULL", 0, "MPI_OP_NULL">;
+def MPI_OpMax : I32EnumAttrCase<"MPI_MAX", 1, "MPI_MAX">;
+def MPI_OpMin : I32EnumAttrCase<"MPI_MIN", 2, "MPI_MIN">;
+def MPI_OpSum : I32EnumAttrCase<"MPI_SUM", 3, "MPI_SUM">;
+def MPI_OpProd : I32EnumAttrCase<"MPI_PROD", 4, "MPI_PROD">;
+def MPI_OpLand : I32EnumAttrCase<"MPI_LAND", 5, "MPI_LAND">;
+def MPI_OpBand : I32EnumAttrCase<"MPI_BAND", 6, "MPI_BAND">;
+def MPI_OpLor : I32EnumAttrCase<"MPI_LOR", 7, "MPI_LOR">;
+def MPI_OpBor : I32EnumAttrCase<"MPI_BOR", 8, "MPI_BOR">;
+def MPI_OpLxor : I32EnumAttrCase<"MPI_LXOR", 9, "MPI_LXOR">;
+def MPI_OpBxor : I32EnumAttrCase<"MPI_BXOR", 10, "MPI_BXOR">;
+def MPI_OpMinloc : I32EnumAttrCase<"MPI_MINLOC", 11, "MPI_MINLOC">;
+def MPI_OpMaxloc : I32EnumAttrCase<"MPI_MAXLOC", 12, "MPI_MAXLOC">;
+def MPI_OpReplace : I32EnumAttrCase<"MPI_REPLACE", 13, "MPI_REPLACE">;
+// Predefined MPI_Op reductions (MPI_OP_NULL..MPI_REPLACE) as an i32 enum.
+def MPI_OpClassEnum : I32EnumAttr<"MPI_OpClassEnum", "MPI operation class", [
+      MPI_OpNull,
+      MPI_OpMax,
+      MPI_OpMin,
+      MPI_OpSum,
+      MPI_OpProd,
+      MPI_OpLand,
+      MPI_OpBand,
+      MPI_OpLor,
+      MPI_OpBor,
+      MPI_OpLxor,
+      MPI_OpBxor,
+      MPI_OpMinloc,
+      MPI_OpMaxloc,
+      MPI_OpReplace
+    ]> {
+  let genSpecializedAttr = 0; // MPI_OpClassAttr below is the attribute form
+  let cppNamespace = "::mlir::mpi";
+}
+// Attribute wrapping an MPI_OpClassEnum value, printed as e.g. `<MPI_SUM>`.
+def MPI_OpClassAttr : EnumAttr<MPI_Dialect, MPI_OpClassEnum, "opclass"> {
+  let assemblyFormat = "`<` $value `>`";
+}
+
 #endif // MLIR_DIALECT_MPI_IR_MPI_TD

diff  --git a/mlir/include/mlir/Dialect/MPI/IR/MPIOps.td b/mlir/include/mlir/Dialect/MPI/IR/MPIOps.td
index 240fac5104c34f5..284ba72af9768b7 100644
--- a/mlir/include/mlir/Dialect/MPI/IR/MPIOps.td
+++ b/mlir/include/mlir/Dialect/MPI/IR/MPIOps.td
@@ -59,6 +59,28 @@ def MPI_CommRankOp : MPI_Op<"comm_rank", []> {
   let assemblyFormat = "attr-dict `:` type(results)";
 }
 
+//===----------------------------------------------------------------------===//
+// CommSizeOp
+//===----------------------------------------------------------------------===//
+
+def MPI_CommSizeOp : MPI_Op<"comm_size", []> {
+  let summary = "Get the size of the group associated with the communicator, "
+                "equivalent to `MPI_Comm_size(MPI_COMM_WORLD, &size)`";
+  let description = [{
+    Communicators other than `MPI_COMM_WORLD` are not supported for now.
+
+    This operation can optionally return an `!mpi.retval` value that can be used
+    to check for errors.
+  }];
+
+  let results = (
+    outs Optional<MPI_Retval> : $retval,
+    I32 : $size
+  );
+
+  let assemblyFormat = "attr-dict `:` type(results)";
+}
+
 //===----------------------------------------------------------------------===//
 // SendOp
 //===----------------------------------------------------------------------===//
@@ -71,13 +93,17 @@ def MPI_SendOp : MPI_Op<"send", []> {
     `dest`. The `tag` value and communicator enables the library to determine 
     the matching of multiple sends and receives between the same ranks.
 
-    Communicators other than `MPI_COMM_WORLD` are not supprted for now.
+    Communicators other than `MPI_COMM_WORLD` are not supported for now.
 
     This operation can optionally return an `!mpi.retval` value that can be used
     to check for errors.
   }];
 
-  let arguments = (ins AnyMemRef : $ref, I32 : $tag, I32 : $rank);
+  let arguments = (
+    ins AnyMemRef : $ref,
+    I32 : $tag,
+    I32 : $rank
+  );
 
   let results = (outs Optional<MPI_Retval>:$retval);
 
@@ -87,6 +113,42 @@ def MPI_SendOp : MPI_Op<"send", []> {
   let hasCanonicalizer = 1;
 }
 
+//===----------------------------------------------------------------------===//
+// ISendOp
+//===----------------------------------------------------------------------===//
+
+def MPI_ISendOp : MPI_Op<"isend", []> {
+  let summary = "Equivalent to `MPI_Isend(ptr, size, dtype, dest, tag, "
+                "MPI_COMM_WORLD, &req)`";
+  let description = [{
+    MPI_Isend begins a non-blocking send of `size` elements of type `dtype` to
+    rank `dest`. The `tag` value and communicator enables the library to
+    determine the matching of multiple sends and receives between the same
+    ranks. The returned `!mpi.request` can be completed with `mpi.wait`.
+
+    Communicators other than `MPI_COMM_WORLD` are not supported for now.
+
+    This operation can optionally return an `!mpi.retval` value that can be used
+    to check for errors.
+  }];
+
+  let arguments = (
+    ins AnyMemRef : $ref,
+    I32 : $tag,
+    I32 : $rank
+  );
+
+  let results = (
+    outs Optional<MPI_Retval>:$retval,
+    MPI_Request : $req
+  );
+
+  let assemblyFormat = "`(` $ref `,` $tag `,` $rank `)` attr-dict "
+                       "`:` type($ref) `,` type($tag) `,` type($rank) "
+                       "`->` type(results)";
+  let hasCanonicalizer = 1;
+}
+
 //===----------------------------------------------------------------------===//
 // RecvOp
 //===----------------------------------------------------------------------===//
@@ -100,7 +162,7 @@ def MPI_RecvOp : MPI_Op<"recv", []> {
     determine the matching of multiple sends and receives between the same 
     ranks.
 
-    Communicators other than `MPI_COMM_WORLD` are not supprted for now.
+    Communicators other than `MPI_COMM_WORLD` are not supported for now.
     The MPI_Status is set to `MPI_STATUS_IGNORE`, as the status object 
     is not yet ported to MLIR.
 
@@ -108,16 +170,134 @@ def MPI_RecvOp : MPI_Op<"recv", []> {
     to check for errors.
   }];
 
-  let arguments = (ins AnyMemRef : $ref, I32 : $tag, I32 : $rank);
+  let arguments = (
+    ins AnyMemRef : $ref,
+    I32 : $tag, I32 : $rank
+  );
 
   let results = (outs Optional<MPI_Retval>:$retval);
 
   let assemblyFormat = "`(` $ref `,` $tag `,` $rank `)` attr-dict `:` "
                        "type($ref) `,` type($tag) `,` type($rank)"
                        "(`->` type($retval)^)?";
   let hasCanonicalizer = 1;
 }
 
+//===----------------------------------------------------------------------===//
+// IRecvOp
+//===----------------------------------------------------------------------===//
+
+def MPI_IRecvOp : MPI_Op<"irecv", []> {
+  let summary = "Equivalent to `MPI_Irecv(ptr, size, dtype, source, tag, "
+                "MPI_COMM_WORLD, &req)`";
+  let description = [{
+    MPI_Irecv begins a non-blocking receive of `size` elements of type `dtype`
+    from rank `source`. The `tag` value and communicator enables the library
+    to determine the matching of multiple sends and receives between the same
+    ranks. The returned `!mpi.request` can be completed with `mpi.wait`.
+
+    Communicators other than `MPI_COMM_WORLD` are not supported for now.
+
+    This operation can optionally return an `!mpi.retval` value that can be used
+    to check for errors.
+  }];
+
+  let arguments = (
+    ins AnyMemRef : $ref,
+    I32 : $tag,
+    I32 : $rank
+  );
+
+  let results = (
+    outs Optional<MPI_Retval>:$retval,
+    MPI_Request : $req
+  );
+
+  let assemblyFormat = "`(` $ref `,` $tag `,` $rank `)` attr-dict "
+                       "`:` type($ref) `,` type($tag) `,` type($rank) "
+                       "`->` type(results)";
+  let hasCanonicalizer = 1;
+}
+
+//===----------------------------------------------------------------------===//
+// AllReduceOp
+//===----------------------------------------------------------------------===//
+
+def MPI_AllReduceOp : MPI_Op<"allreduce", []> {
+  let summary = "Equivalent to `MPI_Allreduce(sendbuf, recvbuf, size, dtype, "
+                "op, MPI_COMM_WORLD)`";
+  let description = [{
+    MPI_Allreduce performs a reduction operation on the values in the sendbuf
+    array and stores the result in the recvbuf array. The operation is
+    performed across all processes in the communicator.
+
+    The `op` attribute specifies the reduction operation to be performed.
+    Currently only the predefined `MPI_Op`s of the standard (e.g. `MPI_SUM`)
+    are supported.
+
+    Communicators other than `MPI_COMM_WORLD` are not supported for now.
+
+    This operation can optionally return an `!mpi.retval` value that can be used
+    to check for errors.
+  }];
+
+  let arguments = (
+    ins AnyMemRef : $sendbuf,
+    AnyMemRef : $recvbuf,
+    MPI_OpClassAttr : $op
+  );
+
+  let results = (outs Optional<MPI_Retval>:$retval);
+
+  let assemblyFormat = "`(` $sendbuf `,` $recvbuf `,` $op `)` attr-dict `:` "
+                       "type($sendbuf) `,` type($recvbuf) "
+                       "(`->` type($retval)^)?";
+}
+
+//===----------------------------------------------------------------------===//
+// BarrierOp
+//===----------------------------------------------------------------------===//
+// NOTE(review): no `Op` suffix on this def, so the C++ class is mpi::Barrier.
+def MPI_Barrier : MPI_Op<"barrier", []> {
+  let summary = "Equivalent to `MPI_Barrier(MPI_COMM_WORLD)`";
+  let description = [{
+    MPI_Barrier blocks execution until all processes in the communicator have
+    reached this routine.
+
+    Communicators other than `MPI_COMM_WORLD` are not supported for now.
+
+    This operation can optionally return an `!mpi.retval` value that can be used
+    to check for errors.
+  }];
+
+  let results = (outs Optional<MPI_Retval>:$retval);
+
+  let assemblyFormat = "attr-dict (`:` type($retval)^)?";
+}
+
+//===----------------------------------------------------------------------===//
+// WaitOp
+//===----------------------------------------------------------------------===//
+// NOTE(review): no `Op` suffix on this def, so the C++ class is mpi::Wait.
+def MPI_Wait : MPI_Op<"wait", []> {
+  let summary = "Equivalent to `MPI_Wait(req, MPI_STATUS_IGNORE)`";
+  let description = [{
+    MPI_Wait blocks execution until the request has completed.
+
+    The MPI_Status is set to `MPI_STATUS_IGNORE`, as the status object
+    is not yet ported to MLIR.
+
+    This operation can optionally return an `!mpi.retval` value that can be used
+    to check for errors.
+  }];
+
+  let arguments = (ins MPI_Request : $req);
+
+  let results = (outs Optional<MPI_Retval>:$retval);
+
+  let assemblyFormat = "`(` $req `)` attr-dict `:` type($req) "
+                       "(`->` type($retval)^)?";
+}
 
 //===----------------------------------------------------------------------===//
 // FinalizeOp
@@ -139,7 +319,6 @@ def MPI_FinalizeOp : MPI_Op<"finalize", []> {
   let assemblyFormat = "attr-dict (`:` type($retval)^)?";
 }
 
-
 //===----------------------------------------------------------------------===//
 // RetvalCheckOp
 //===----------------------------------------------------------------------===//
@@ -163,10 +342,8 @@ def MPI_RetvalCheckOp : MPI_Op<"retval_check", []> {
   let assemblyFormat = "$val `=` $errclass attr-dict `:` type($res)";
 }
 
-
-
 //===----------------------------------------------------------------------===//
-// RetvalCheckOp
+// ErrorClassOp
 //===----------------------------------------------------------------------===//
 
 def MPI_ErrorClassOp : MPI_Op<"error_class", []> {

diff  --git a/mlir/include/mlir/Dialect/MPI/IR/MPITypes.td b/mlir/include/mlir/Dialect/MPI/IR/MPITypes.td
index 87eefa719d45c07..fafea0eac8bb74c 100644
--- a/mlir/include/mlir/Dialect/MPI/IR/MPITypes.td
+++ b/mlir/include/mlir/Dialect/MPI/IR/MPITypes.td
@@ -40,4 +40,26 @@ def MPI_Retval : MPI_Type<"Retval", "retval"> {
   }];
 }
 
+//===----------------------------------------------------------------------===//
+// mpi::RequestType
+//===----------------------------------------------------------------------===//
+
+def MPI_Request : MPI_Type<"Request", "request"> {
+  let summary = "MPI asynchronous request handle";
+  let description = [{
+    This type represents a handle to an asynchronous request (`MPI_Request`).
+  }];
+}
+
+//===----------------------------------------------------------------------===//
+// mpi::StatusType
+//===----------------------------------------------------------------------===//
+// NOTE(review): not yet used by any op; recv and wait use MPI_STATUS_IGNORE.
+def MPI_Status : MPI_Type<"Status", "status"> {
+  let summary = "MPI reception operation status type";
+  let description = [{
+    This type represents the status of a reception operation.
+  }];
+}
+
 #endif // MLIR_DIALECT_MPI_IR_MPITYPES_TD

diff  --git a/mlir/lib/Dialect/MPI/IR/MPIOps.cpp b/mlir/lib/Dialect/MPI/IR/MPIOps.cpp
index dcb55d8921364f9..56d8edfbcc0255b 100644
--- a/mlir/lib/Dialect/MPI/IR/MPIOps.cpp
+++ b/mlir/lib/Dialect/MPI/IR/MPIOps.cpp
@@ -53,6 +53,16 @@ void mlir::mpi::RecvOp::getCanonicalizationPatterns(
   results.add<FoldCast<mlir::mpi::RecvOp>>(context);
 }
 
+void mlir::mpi::ISendOp::getCanonicalizationPatterns(
+    mlir::RewritePatternSet &results, mlir::MLIRContext *context) {
+  results.add<FoldCast<mlir::mpi::ISendOp>>(context); // same as RecvOp
+}
+
+void mlir::mpi::IRecvOp::getCanonicalizationPatterns(
+    mlir::RewritePatternSet &results, mlir::MLIRContext *context) {
+  results.add<FoldCast<mlir::mpi::IRecvOp>>(context); // same as RecvOp
+}
+
 //===----------------------------------------------------------------------===//
 // TableGen'd op method definitions
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/test/Dialect/MPI/ops.mlir b/mlir/test/Dialect/MPI/ops.mlir
index 8f2421a73396c21..f23a7e18a2ee977 100644
--- a/mlir/test/Dialect/MPI/ops.mlir
+++ b/mlir/test/Dialect/MPI/ops.mlir
@@ -9,6 +9,9 @@ func.func @mpi_test(%ref : memref<100xf32>) -> () {
     // CHECK-NEXT: %retval, %rank = mpi.comm_rank : !mpi.retval, i32
     %retval, %rank = mpi.comm_rank : !mpi.retval, i32
 
+    // CHECK-NEXT: %retval_0, %size = mpi.comm_size : !mpi.retval, i32
+    %retval_0, %size = mpi.comm_size : !mpi.retval, i32
+
     // CHECK-NEXT: mpi.send(%arg0, %rank, %rank) : memref<100xf32>, i32, i32
     mpi.send(%ref, %rank, %rank) : memref<100xf32>, i32, i32
 
@@ -21,13 +24,43 @@ func.func @mpi_test(%ref : memref<100xf32>) -> () {
     // CHECK-NEXT: %2 = mpi.recv(%arg0, %rank, %rank) : memref<100xf32>, i32, i32 -> !mpi.retval
     %err3 = mpi.recv(%ref, %rank, %rank) : memref<100xf32>, i32, i32 -> !mpi.retval
 
-    // CHECK-NEXT: %3 = mpi.finalize : !mpi.retval
+    // CHECK-NEXT: %req = mpi.isend(%arg0, %rank, %rank) : memref<100xf32>, i32, i32 -> !mpi.request
+    %req = mpi.isend(%ref, %rank, %rank) : memref<100xf32>, i32, i32 -> !mpi.request
+
+    // CHECK-NEXT: %retval_1, %req_2 = mpi.isend(%arg0, %rank, %rank) : memref<100xf32>, i32, i32 -> !mpi.retval, !mpi.request
+    %err4, %req2 = mpi.isend(%ref, %rank, %rank) : memref<100xf32>, i32, i32 -> !mpi.retval, !mpi.request
+
+    // CHECK-NEXT: %req_3 = mpi.irecv(%arg0, %rank, %rank) : memref<100xf32>, i32, i32 -> !mpi.request
+    %req3 = mpi.irecv(%ref, %rank, %rank) : memref<100xf32>, i32, i32 -> !mpi.request
+
+    // CHECK-NEXT: %retval_4, %req_5 = mpi.irecv(%arg0, %rank, %rank) : memref<100xf32>, i32, i32 -> !mpi.retval, !mpi.request
+    %err5, %req4 = mpi.irecv(%ref, %rank, %rank) : memref<100xf32>, i32, i32 -> !mpi.retval, !mpi.request
+
+    // CHECK-NEXT: mpi.wait(%req) : !mpi.request
+    mpi.wait(%req) : !mpi.request
+
+    // CHECK-NEXT: %3 = mpi.wait(%req_2) : !mpi.request -> !mpi.retval
+    %err6 = mpi.wait(%req2) : !mpi.request -> !mpi.retval
+
+    // CHECK-NEXT: mpi.barrier
+    mpi.barrier
+
+    // CHECK-NEXT: %4 = mpi.barrier : !mpi.retval
+    %err7 = mpi.barrier : !mpi.retval
+
+    // CHECK-NEXT: mpi.allreduce(%arg0, %arg0, <MPI_SUM>) : memref<100xf32>, memref<100xf32>
+    mpi.allreduce(%ref, %ref, <MPI_SUM>) : memref<100xf32>, memref<100xf32>
+
+    // CHECK-NEXT: %5 = mpi.allreduce(%arg0, %arg0, <MPI_SUM>) : memref<100xf32>, memref<100xf32> -> !mpi.retval
+    %err8 = mpi.allreduce(%ref, %ref, <MPI_SUM>) : memref<100xf32>, memref<100xf32> -> !mpi.retval
+
+    // CHECK-NEXT: %6 = mpi.finalize : !mpi.retval
     %rval = mpi.finalize : !mpi.retval
 
-    // CHECK-NEXT: %4 = mpi.retval_check %retval = <MPI_SUCCESS> : i1
+    // CHECK-NEXT: %7 = mpi.retval_check %retval = <MPI_SUCCESS> : i1
     %res = mpi.retval_check %retval = <MPI_SUCCESS> : i1
 
-    // CHECK-NEXT: %5 = mpi.error_class %0 : !mpi.retval
+    // CHECK-NEXT: %8 = mpi.error_class %0 : !mpi.retval
     %errclass = mpi.error_class %err : !mpi.retval
 
     // CHECK-NEXT: return


        


More information about the Mlir-commits mailing list