[Mlir-commits] [mlir] c0f29e8 - [mlir][mesh] Make most collectives pure (#79643)
llvmlistbot at llvm.org
Tue Jan 30 06:57:51 PST 2024
Author: Boian Petkantchin
Date: 2024-01-30T06:57:47-08:00
New Revision: c0f29e83dbcc6789e74918ac6d8d46b8833f45aa
URL: https://github.com/llvm/llvm-project/commit/c0f29e83dbcc6789e74918ac6d8d46b8833f45aa
DIFF: https://github.com/llvm/llvm-project/commit/c0f29e83dbcc6789e74918ac6d8d46b8833f45aa.diff
LOG: [mlir][mesh] Make most collectives pure (#79643)
Under SPMD, execution paths are assumed to be matching and consistent across processes. This assumption allows collective communication operations to be pure.
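A minimal sketch of the SPMD assumption in the log message, written in Python: it models a runtime that pairs the k-th collective reached by each process with the k-th collective reached by every other process. The function name `match_collectives` and the string-based traces are hypothetical and not part of MLIR or any real runtime. The sketch only shows why a collective optimized out on some processes but not on others leaves the runtime unable to match operations.

```python
from typing import List


def match_collectives(traces: List[List[str]]) -> List[str]:
    """Pair the k-th collective reached by every process (hypothetical toy runtime).

    traces[p] is the sequence of collectives process p reaches at run time.
    Returns the matched sequence, or raises RuntimeError when the processes
    diverge; a real runtime would typically hang instead.
    """
    matched: List[str] = []
    steps = max((len(t) for t in traces), default=0)
    for k in range(steps):
        # None marks a process that finished without reaching a k-th collective.
        issued = {t[k] if k < len(t) else None for t in traces}
        if len(issued) != 1:
            raise RuntimeError(f"step {k}: processes disagree on {sorted(map(str, issued))}")
        matched.append(issued.pop())
    return matched


# Every process runs the same, consistently optimized program: collectives match.
consistent = [["all_reduce", "all_gather"] for _ in range(3)]
print(match_collectives(consistent))  # ['all_reduce', 'all_gather']

# Only process 1 had its all_reduce optimized out: the runtime cannot pair the ops.
divergent = [["all_reduce", "all_gather"], ["all_gather"], ["all_reduce", "all_gather"]]
try:
    match_collectives(divergent)
except RuntimeError as err:
    print(err)
```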
Added:
Modified:
mlir/docs/Dialects/Mesh.md
mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td
Removed:
################################################################################
diff --git a/mlir/docs/Dialects/Mesh.md b/mlir/docs/Dialects/Mesh.md
index 77da2f10d8902..6df3f53c9f017 100644
--- a/mlir/docs/Dialects/Mesh.md
+++ b/mlir/docs/Dialects/Mesh.md
@@ -47,6 +47,24 @@ For Example on a 3D mesh an operation with `mesh_axes = [0, 2]` would specify
an in-group device with `(i, j)`. Then for each group with index `g` on the
second axis, the in-group device would be `(i, g, j)`.
+### Purity
+Collectives that involve the whole device group to perform a single operation
+are pure. The exceptions are `send` and `recv`.
+
+There is an assumption that the execution is SPMD: not only does each process
+run the same program, but at the point of execution of a collective operation,
+all processes are in a coherent state.
+All compiler transformations must be consistent.
+Collective operations in the IR that may correspond to the same runtime
+collective operation must be transformed in a consistent manner.
+For example, if a collective operation is optimized out, then it must also
+not appear in any path of execution on any process.
+
+Marking the operations as `Pure` implies that if an interpreter executes IR
+containing `mesh` collectives, all processes execute the same line of IR
+whenever they reach a pure collective operation.
+This requirement stems from the need to stay compatible with general
+optimization passes such as dead code and common subexpression elimination.
## Operations
diff --git a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td
index 7b301025e687a..a28b9b429a246 100644
--- a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td
+++ b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td
@@ -259,6 +259,7 @@ class Mesh_CollectiveCommunicationOpBase<
}
def Mesh_AllGatherOp : Mesh_CollectiveCommunicationOpBase<"all_gather", [
+ Pure,
SameOperandsAndResultElementType,
SameOperandsAndResultRank
]> {
@@ -312,6 +313,7 @@ def Mesh_AllGatherOp : Mesh_CollectiveCommunicationOpBase<"all_gather", [
}
def Mesh_AllReduceOp : Mesh_CollectiveCommunicationOpBase<"all_reduce", [
+ Pure,
SameOperandsAndResultShape]> {
let summary = "All-reduce over a device mesh.";
let description = [{
@@ -344,6 +346,7 @@ def Mesh_AllReduceOp : Mesh_CollectiveCommunicationOpBase<"all_reduce", [
}
def Mesh_AllToAllOp : Mesh_CollectiveCommunicationOpBase<"all_to_all", [
+ Pure,
SameOperandsAndResultElementType,
SameOperandsAndResultRank]> {
let summary = "All-to-all over a device mesh.";
@@ -398,6 +401,7 @@ def Mesh_AllToAllOp : Mesh_CollectiveCommunicationOpBase<"all_to_all", [
}
def Mesh_BroadcastOp : Mesh_CollectiveCommunicationOpBase<"broadcast", [
+ Pure,
AllShapesMatch<["input", "result"]>,
AllElementTypesMatch<["input", "result"]>
]> {
@@ -453,6 +457,7 @@ def Mesh_BroadcastOp : Mesh_CollectiveCommunicationOpBase<"broadcast", [
}
def Mesh_GatherOp : Mesh_CollectiveCommunicationOpBase<"gather", [
+ Pure,
AllRanksMatch<["input", "result"]>,
AllElementTypesMatch<["input", "result"]>
]> {
@@ -540,6 +545,7 @@ def Mesh_RecvOp : Mesh_CollectiveCommunicationOpBase<"recv", [
}
def Mesh_ReduceOp : Mesh_CollectiveCommunicationOpBase<"reduce", [
+ Pure,
AllShapesMatch<["input", "result"]>
]> {
let summary = "Reduce over a device mesh.";
@@ -581,6 +587,7 @@ def Mesh_ReduceOp : Mesh_CollectiveCommunicationOpBase<"reduce", [
}
def Mesh_ReduceScatterOp : Mesh_CollectiveCommunicationOpBase<"reduce_scatter", [
+ Pure,
SameOperandsAndResultRank]> {
let summary = "Reduce-scatter over a device mesh.";
let description = [{
@@ -642,6 +649,7 @@ def Mesh_ReduceScatterOp : Mesh_CollectiveCommunicationOpBase<"reduce_scatter",
}
def Mesh_ScatterOp : Mesh_CollectiveCommunicationOpBase<"scatter", [
+ Pure,
AllRanksMatch<["input", "result"]>,
AllElementTypesMatch<["input", "result"]>
]> {
@@ -734,6 +742,7 @@ def Mesh_SendOp : Mesh_CollectiveCommunicationOpBase<"send", [
}
def Mesh_ShiftOp : Mesh_CollectiveCommunicationOpBase<"shift", [
+ Pure,
SameOperandsAndResultElementType,
SameOperandsAndResultShape
]> {
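The Purity section added to `Mesh.md` ties the `Pure` trait to dead code elimination and common subexpression elimination. The Python below is a minimal sketch of those two rewrites on a toy IR, under loudly simplified assumptions: `AllReduce`, `cse_then_dce`, and `run` are hypothetical names, not the MLIR mesh API, and only sum-reduction of integers is modeled. In MLIR itself, `Pure` combines `NoMemoryEffect` and `AlwaysSpeculatable`, which is what allows the generic passes to perform rewrites of this kind on the collectives marked `Pure` in this commit.

```python
from dataclasses import dataclass
from typing import Dict, List, Set, Tuple


@dataclass(frozen=True)
class AllReduce:
    """Hypothetical toy pure collective: dst = sum of src over all processes."""
    dst: str
    src: str


def cse_then_dce(program: List[AllReduce], live: Set[str]) -> Tuple[List[AllReduce], Dict[str, str]]:
    """Merge all_reduce ops with the same operand, then drop ops with unused results.

    Both rewrites rely on the op being pure. Every process runs this same
    deterministic function on the same program, so the rewrite is consistent.
    Returns the new program and a map from removed results to surviving ones.
    """
    first: Dict[str, str] = {}  # operand -> result of the first op that reduced it
    alias: Dict[str, str] = {}  # removed result -> equivalent surviving result
    kept: List[AllReduce] = []
    for op in program:
        src = alias.get(op.src, op.src)
        if src in first:  # CSE: same pure op on the same operand
            alias[op.dst] = first[src]
            continue
        first[src] = op.dst
        kept.append(AllReduce(op.dst, src))
    needed = {alias.get(name, name) for name in live}
    out: List[AllReduce] = []
    for op in reversed(kept):  # DCE: keep only ops that feed a live result
        if op.dst in needed:
            out.append(op)
            needed.add(op.src)
    return out[::-1], alias


def run(program: List[AllReduce], inputs: List[Dict[str, int]]) -> List[Dict[str, int]]:
    """Evaluate the program on all processes in lockstep (toy interpreter)."""
    envs = [dict(env) for env in inputs]
    for op in program:
        total = sum(env[op.src] for env in envs)
        for env in envs:
            env[op.dst] = total
    return envs


program = [
    AllReduce("a", "x"),
    AllReduce("b", "x"),  # duplicate of "a": removed by CSE
    AllReduce("c", "b"),  # operand rewritten to "a"
    AllReduce("t", "y"),  # result never used: removed by DCE
]
inputs = [{"x": 1, "y": 10}, {"x": 2, "y": 20}, {"x": 3, "y": 30}]
optimized, alias = cse_then_dce(program, live={"c"})
print(optimized)  # [AllReduce(dst='a', src='x'), AllReduce(dst='c', src='a')]
print([env["c"] for env in run(program, inputs)])    # [18, 18, 18]
print([env["c"] for env in run(optimized, inputs)])  # [18, 18, 18]
```

Because the rewrite depends only on the program, every process that runs it ends up with the same sequence of collectives. Applying it on only some processes would break the consistency that the Purity section requires.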