[Mlir-commits] [mlir] [mlir][mesh] Make most collectives pure (PR #79643)

llvmlistbot at llvm.org
Fri Jan 26 11:51:26 PST 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir

Author: Boian Petkantchin (sogartar)

<details>
<summary>Changes</summary>



---
Full diff: https://github.com/llvm/llvm-project/pull/79643.diff


2 Files Affected:

- (modified) mlir/docs/Dialects/Mesh.md (+18) 
- (modified) mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td (+9) 


``````````diff
diff --git a/mlir/docs/Dialects/Mesh.md b/mlir/docs/Dialects/Mesh.md
index 77da2f10d8902c9..6df3f53c9f0179f 100644
--- a/mlir/docs/Dialects/Mesh.md
+++ b/mlir/docs/Dialects/Mesh.md
@@ -47,6 +47,24 @@ For Example on a 3D mesh an operation with `mesh_axes = [0, 2]` would specify
 an in-group device with `(i, j)`. Then for each group with index `g` on the
 second axis, the in-group device would be `(i, g, j)`.
 
+### Purity
+Collectives that involve the whole device group to perform a single operation
+are pure. The exceptions are `send` and `recv`.
+
+Execution is assumed to be SPMD:
+not only does each process run the same program, but at the point of
+execution of a collective operation, all processes are in a coherent state.
+All compiler transformations must be consistent.
+Collective operations in the IR that may correspond to the same runtime
+collective operation must be transformed in a consistent manner.
+For example, if a collective operation is optimized out, then it must also
+not appear in any path of execution on any process.
+
+Having the operations as `Pure` implies that if an interpreter is to execute
+the IR containing the `mesh` collectives, all processes would execute the same
+line when they reach a pure collective operation.
+This requirement stems from the need to be compatible with general optimization
+passes like dead code and common sub-expression elimination.
 
 ## Operations
 
diff --git a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td
index 7b301025e687ae3..a28b9b429a2460a 100644
--- a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td
+++ b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td
@@ -259,6 +259,7 @@ class Mesh_CollectiveCommunicationOpBase<
 }
 
 def Mesh_AllGatherOp : Mesh_CollectiveCommunicationOpBase<"all_gather", [
+    Pure,
     SameOperandsAndResultElementType,
     SameOperandsAndResultRank
   ]> {
@@ -312,6 +313,7 @@ def Mesh_AllGatherOp : Mesh_CollectiveCommunicationOpBase<"all_gather", [
 }
 
 def Mesh_AllReduceOp : Mesh_CollectiveCommunicationOpBase<"all_reduce", [
+    Pure,
     SameOperandsAndResultShape]> {
   let summary = "All-reduce over a device mesh.";
   let description = [{
@@ -344,6 +346,7 @@ def Mesh_AllReduceOp : Mesh_CollectiveCommunicationOpBase<"all_reduce", [
 }
 
 def Mesh_AllToAllOp : Mesh_CollectiveCommunicationOpBase<"all_to_all", [
+    Pure,
     SameOperandsAndResultElementType,
     SameOperandsAndResultRank]> {
   let summary = "All-to-all over a device mesh.";
@@ -398,6 +401,7 @@ def Mesh_AllToAllOp : Mesh_CollectiveCommunicationOpBase<"all_to_all", [
 }
 
 def Mesh_BroadcastOp : Mesh_CollectiveCommunicationOpBase<"broadcast", [
+    Pure,
     AllShapesMatch<["input", "result"]>,
     AllElementTypesMatch<["input", "result"]>
   ]> {
@@ -453,6 +457,7 @@ def Mesh_BroadcastOp : Mesh_CollectiveCommunicationOpBase<"broadcast", [
 }
 
 def Mesh_GatherOp : Mesh_CollectiveCommunicationOpBase<"gather", [
+    Pure,
     AllRanksMatch<["input", "result"]>,
     AllElementTypesMatch<["input", "result"]>
   ]> {
@@ -540,6 +545,7 @@ def Mesh_RecvOp : Mesh_CollectiveCommunicationOpBase<"recv", [
 }
 
 def Mesh_ReduceOp : Mesh_CollectiveCommunicationOpBase<"reduce", [
+    Pure,
     AllShapesMatch<["input", "result"]>
   ]> {
   let summary = "Reduce over a device mesh.";
@@ -581,6 +587,7 @@ def Mesh_ReduceOp : Mesh_CollectiveCommunicationOpBase<"reduce", [
 }
 
 def Mesh_ReduceScatterOp : Mesh_CollectiveCommunicationOpBase<"reduce_scatter", [
+    Pure,
     SameOperandsAndResultRank]> {
   let summary = "Reduce-scatter over a device mesh.";
   let description = [{
@@ -642,6 +649,7 @@ def Mesh_ReduceScatterOp : Mesh_CollectiveCommunicationOpBase<"reduce_scatter",
 }
 
 def Mesh_ScatterOp : Mesh_CollectiveCommunicationOpBase<"scatter", [
+    Pure,
     AllRanksMatch<["input", "result"]>,
     AllElementTypesMatch<["input", "result"]>
   ]> {
@@ -734,6 +742,7 @@ def Mesh_SendOp : Mesh_CollectiveCommunicationOpBase<"send", [
 }
 
 def Mesh_ShiftOp : Mesh_CollectiveCommunicationOpBase<"shift", [
+    Pure,
     SameOperandsAndResultElementType,
     SameOperandsAndResultShape
   ]> {

``````````

</details>
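
The purity rule in the Mesh.md hunk can be illustrated with a small Python toy model. This is only a sketch under stated assumptions: the `Op` class, `make_op`, `cse`, `dce`, `optimize` and `PURE_COLLECTIVES` are hypothetical names invented for illustration, not MLIR's data structures or its actual CSE/DCE passes. Only the op names come from the diff. Ops in the pure set may be merged when identical or deleted when unused, while `send`, `recv` (and the terminating `return`) are treated as side-effecting and always kept.

```python
from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple

# Hypothetical toy IR for illustration only; this is not how MLIR represents ops.
@dataclass(frozen=True)
class Op:
    result: Optional[str]      # SSA value produced by the op, or None
    name: str                  # e.g. "all_reduce", "send", "return"
    operands: Tuple[str, ...]  # SSA values consumed by the op
    pure: bool                 # True if the op has no side effects

# Collectives this PR marks `Pure`; `send` and `recv` are deliberately absent.
PURE_COLLECTIVES = {
    "all_gather", "all_reduce", "all_to_all", "broadcast", "gather",
    "reduce", "reduce_scatter", "scatter", "shift",
}

def make_op(result: Optional[str], name: str, *operands: str) -> Op:
    # Anything outside PURE_COLLECTIVES (including "return") is side-effecting.
    return Op(result, name, tuple(operands), name in PURE_COLLECTIVES)

def cse(ops: List[Op]) -> List[Op]:
    """Merge pure ops that have the same name and operands."""
    seen: Dict[Tuple[str, Tuple[str, ...]], str] = {}
    rename: Dict[str, str] = {}  # eliminated result -> surviving result
    out: List[Op] = []
    for op in ops:
        operands = tuple(rename.get(v, v) for v in op.operands)
        op = Op(op.result, op.name, operands, op.pure)
        key = (op.name, operands)
        if op.pure and op.result is not None:
            if key in seen:
                rename[op.result] = seen[key]
                continue
            seen[key] = op.result
        out.append(op)
    return out

def dce(ops: List[Op]) -> List[Op]:
    """Drop pure ops whose results are never used; keep all side-effecting ops."""
    live = set()
    kept: List[Op] = []
    for op in reversed(ops):
        if not op.pure or op.result in live:
            kept.append(op)
            live.update(op.operands)
    return list(reversed(kept))

def optimize(ops: List[Op]) -> List[Op]:
    return dce(cse(ops))

program = [
    make_op("a", "all_reduce", "x"),
    make_op("b", "all_reduce", "x"),  # same as "a": CSE merges it
    make_op("c", "all_gather", "x"),  # result unused: DCE removes it
    make_op(None, "send", "x"),       # side effect: always kept
    make_op("d", "recv"),
    make_op("e", "recv"),             # same as "d", but not pure: kept
    make_op(None, "return", "a", "b", "d", "e"),
]

optimized = optimize(program)
for op in optimized:
    print(op.result, op.name, op.operands)

# SPMD: every process runs the same program through the same deterministic
# passes, so all processes keep exactly the same collectives.
per_process = [optimize(program) for _ in range(4)]
assert all(p == optimized for p in per_process)
```

Running it keeps a single `all_reduce`, drops the unused `all_gather`, and leaves the `send` and both `recv` ops untouched. Because every process applies the same deterministic transformation to the same program, the surviving collectives still match across processes, which is the consistency requirement the Purity section describes.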


https://github.com/llvm/llvm-project/pull/79643
