[Mlir-commits] [mlir] [mlir][mesh] Make most collectives pure (PR #79643)
llvmlistbot at llvm.org
Fri Jan 26 11:51:26 PST 2024
llvmbot wrote:
@llvm/pr-subscribers-mlir
Author: Boian Petkantchin (sogartar)
<details>
<summary>Changes</summary>
---
Full diff: https://github.com/llvm/llvm-project/pull/79643.diff
2 Files Affected:
- (modified) mlir/docs/Dialects/Mesh.md (+18)
- (modified) mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td (+9)
``````````diff
diff --git a/mlir/docs/Dialects/Mesh.md b/mlir/docs/Dialects/Mesh.md
index 77da2f10d8902c9..6df3f53c9f0179f 100644
--- a/mlir/docs/Dialects/Mesh.md
+++ b/mlir/docs/Dialects/Mesh.md
@@ -47,6 +47,24 @@ For Example on a 3D mesh an operation with `mesh_axes = [0, 2]` would specify
an in-group device with `(i, j)`. Then for each group with index `g` on the
second axis, the in-group device would be `(i, g, j)`.
+### Purity
+Collectives that involve the whole device group to perform a single operation
+are pure. The exceptions are `send` and `recv`.
+
+Execution is assumed to be SPMD: not only does each process run the same
+program, but when a collective operation executes, all processes are in a
+coherent state.
+All compiler transformations must be consistent.
+Collective operations in the IR that may correspond to the same runtime
+collective operation must be transformed in a consistent manner.
+For example, if a collective operation is optimized out, then it must also
+not appear on any execution path of any process.
+
+Marking these operations `Pure` implies that if an interpreter executes IR
+containing `mesh` collectives, all processes execute the same line when they
+reach a pure collective operation.
+This requirement stems from the need to be compatible with general optimization
+passes such as dead code elimination and common subexpression elimination.
## Operations
diff --git a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td
index 7b301025e687ae3..a28b9b429a2460a 100644
--- a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td
+++ b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td
@@ -259,6 +259,7 @@ class Mesh_CollectiveCommunicationOpBase<
}
def Mesh_AllGatherOp : Mesh_CollectiveCommunicationOpBase<"all_gather", [
+ Pure,
SameOperandsAndResultElementType,
SameOperandsAndResultRank
]> {
@@ -312,6 +313,7 @@ def Mesh_AllGatherOp : Mesh_CollectiveCommunicationOpBase<"all_gather", [
}
def Mesh_AllReduceOp : Mesh_CollectiveCommunicationOpBase<"all_reduce", [
+ Pure,
SameOperandsAndResultShape]> {
let summary = "All-reduce over a device mesh.";
let description = [{
@@ -344,6 +346,7 @@ def Mesh_AllReduceOp : Mesh_CollectiveCommunicationOpBase<"all_reduce", [
}
def Mesh_AllToAllOp : Mesh_CollectiveCommunicationOpBase<"all_to_all", [
+ Pure,
SameOperandsAndResultElementType,
SameOperandsAndResultRank]> {
let summary = "All-to-all over a device mesh.";
@@ -398,6 +401,7 @@ def Mesh_AllToAllOp : Mesh_CollectiveCommunicationOpBase<"all_to_all", [
}
def Mesh_BroadcastOp : Mesh_CollectiveCommunicationOpBase<"broadcast", [
+ Pure,
AllShapesMatch<["input", "result"]>,
AllElementTypesMatch<["input", "result"]>
]> {
@@ -453,6 +457,7 @@ def Mesh_BroadcastOp : Mesh_CollectiveCommunicationOpBase<"broadcast", [
}
def Mesh_GatherOp : Mesh_CollectiveCommunicationOpBase<"gather", [
+ Pure,
AllRanksMatch<["input", "result"]>,
AllElementTypesMatch<["input", "result"]>
]> {
@@ -540,6 +545,7 @@ def Mesh_RecvOp : Mesh_CollectiveCommunicationOpBase<"recv", [
}
def Mesh_ReduceOp : Mesh_CollectiveCommunicationOpBase<"reduce", [
+ Pure,
AllShapesMatch<["input", "result"]>
]> {
let summary = "Reduce over a device mesh.";
@@ -581,6 +587,7 @@ def Mesh_ReduceOp : Mesh_CollectiveCommunicationOpBase<"reduce", [
}
def Mesh_ReduceScatterOp : Mesh_CollectiveCommunicationOpBase<"reduce_scatter", [
+ Pure,
SameOperandsAndResultRank]> {
let summary = "Reduce-scatter over a device mesh.";
let description = [{
@@ -642,6 +649,7 @@ def Mesh_ReduceScatterOp : Mesh_CollectiveCommunicationOpBase<"reduce_scatter",
}
def Mesh_ScatterOp : Mesh_CollectiveCommunicationOpBase<"scatter", [
+ Pure,
AllRanksMatch<["input", "result"]>,
AllElementTypesMatch<["input", "result"]>
]> {
@@ -734,6 +742,7 @@ def Mesh_SendOp : Mesh_CollectiveCommunicationOpBase<"send", [
}
def Mesh_ShiftOp : Mesh_CollectiveCommunicationOpBase<"shift", [
+ Pure,
SameOperandsAndResultElementType,
SameOperandsAndResultShape
]> {
``````````
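The Purity section's argument can be illustrated with a small model. The Python sketch below is hypothetical and simplified, not MLIR's API: `Op`, `cse`, `dce`, and `spmd_consistent` are illustrative names, and the `PURE` set only mirrors the nine ops this diff marks `Pure`. It shows that common subexpression elimination and dead code elimination may merge or erase pure collectives while leaving `send` untouched, and that the transformed program is only valid if every process runs the same transformed version.

```python
from dataclasses import dataclass
from typing import List, Tuple

# Hypothetical, simplified stand-in for an IR operation (not MLIR's API).
@dataclass(frozen=True)
class Op:
    result: str                      # SSA value defined by the op
    name: str                        # e.g. "all_reduce" or "send"
    operands: Tuple[str, ...]        # SSA values the op reads
    attrs: Tuple[Tuple[str, object], ...] = ()

# Mirrors this PR: every collective except send/recv is marked Pure.
PURE = {"all_gather", "all_reduce", "all_to_all", "broadcast", "gather",
        "reduce", "reduce_scatter", "scatter", "shift"}


def cse(ops: List[Op], returns: List[str]) -> Tuple[List[Op], List[str]]:
    """Merge identical pure ops and redirect uses to the first occurrence."""
    seen, rename, out = {}, {}, []
    for op in ops:
        op = Op(op.result, op.name,
                tuple(rename.get(v, v) for v in op.operands), op.attrs)
        if op.name in PURE:
            key = (op.name, op.operands, op.attrs)
            if key in seen:
                rename[op.result] = seen[key]  # reuse the earlier result
                continue
            seen[key] = op.result
        out.append(op)
    return out, [rename.get(v, v) for v in returns]


def dce(ops: List[Op], returns: List[str]) -> List[Op]:
    """Erase pure ops whose results are never used; keep everything else."""
    live, kept = set(returns), []
    for op in reversed(ops):
        if op.name in PURE and op.result not in live:
            continue  # pure and unused: safe to drop
        live.update(op.operands)
        kept.append(op)
    return kept[::-1]


def spmd_consistent(programs: List[List[Op]]) -> bool:
    """Simplified runtime check: every process must issue the same sequence
    of mesh ops, which are matched by position across processes."""
    traces = [[(op.name, op.attrs) for op in p] for p in programs]
    return all(t == traces[0] for t in traces)


AXES = (("mesh_axes", (0,)),)
program = [
    Op("%a", "all_reduce", ("%x",), AXES),
    Op("%b", "all_reduce", ("%x",), AXES),  # duplicate of %a: CSE candidate
    Op("%c", "all_gather", ("%x",), AXES),  # never used: DCE candidate
    Op("%d", "send", ("%b",), AXES),        # not Pure: must be kept
]
ops, rets = cse(program, ["%a", "%b"])
optimized = dce(ops, rets)

print([op.name for op in optimized])            # ['all_reduce', 'send']
print(spmd_consistent([optimized, optimized]))  # True
print(spmd_consistent([program, optimized]))    # False
```

Applying the passes to only one process's copy of the program leaves the processes with mismatched collective sequences, which a real runtime would typically experience as a hang or an error. This is the consistency requirement the Purity section places on compiler transformations.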
</details>
https://github.com/llvm/llvm-project/pull/79643