[Mlir-commits] [mlir] c0f29e8 - [mlir][mesh] Make most collectives pure (#79643)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Tue Jan 30 06:57:51 PST 2024


Author: Boian Petkantchin
Date: 2024-01-30T06:57:47-08:00
New Revision: c0f29e83dbcc6789e74918ac6d8d46b8833f45aa

URL: https://github.com/llvm/llvm-project/commit/c0f29e83dbcc6789e74918ac6d8d46b8833f45aa
DIFF: https://github.com/llvm/llvm-project/commit/c0f29e83dbcc6789e74918ac6d8d46b8833f45aa.diff

LOG: [mlir][mesh] Make most collectives pure (#79643)

Under SPMD, the paths of execution are assumed to match and stay consistent across processes, which allows most collective communication operations to be pure.

Added: 
    

Modified: 
    mlir/docs/Dialects/Mesh.md
    mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td

Removed: 
    


################################################################################
diff  --git a/mlir/docs/Dialects/Mesh.md b/mlir/docs/Dialects/Mesh.md
index 77da2f10d8902..6df3f53c9f017 100644
--- a/mlir/docs/Dialects/Mesh.md
+++ b/mlir/docs/Dialects/Mesh.md
@@ -47,6 +47,24 @@ For Example on a 3D mesh an operation with `mesh_axes = [0, 2]` would specify
 an in-group device with `(i, j)`. Then for each group with index `g` on the
 second axis, the in-group device would be `(i, g, j)`.
 
+### Purity
+Collectives that involve the whole device group to perform a single operation
+are pure. The exceptions are `send` and `recv`.
+
+Execution is assumed to be SPMD: not only does each process run the same
+program, but at the point of execution of a collective operation, all
+processes are in a coherent state.
+All compiler transformations must be consistent.
+Collective operations in the IR that may correspond to the same runtime
+collective operation must be transformed in a consistent manner.
+For example, if a collective operation is optimized out, then it must also
+not appear in any path of execution on any process.
+
+Having the operations as `Pure` implies that if an interpreter is to execute
+the IR containing the `mesh` collectives, all processes would execute the same
+line when they reach a pure collective operation.
+This requirement stems from the need to be compatible with general optimization
+passes like dead code and common sub-expression elimination.
 
 ## Operations
 

diff  --git a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td
index 7b301025e687a..a28b9b429a246 100644
--- a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td
+++ b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td
@@ -259,6 +259,7 @@ class Mesh_CollectiveCommunicationOpBase<
 }
 
 def Mesh_AllGatherOp : Mesh_CollectiveCommunicationOpBase<"all_gather", [
+    Pure,
     SameOperandsAndResultElementType,
     SameOperandsAndResultRank
   ]> {
@@ -312,6 +313,7 @@ def Mesh_AllGatherOp : Mesh_CollectiveCommunicationOpBase<"all_gather", [
 }
 
 def Mesh_AllReduceOp : Mesh_CollectiveCommunicationOpBase<"all_reduce", [
+    Pure,
     SameOperandsAndResultShape]> {
   let summary = "All-reduce over a device mesh.";
   let description = [{
@@ -344,6 +346,7 @@ def Mesh_AllReduceOp : Mesh_CollectiveCommunicationOpBase<"all_reduce", [
 }
 
 def Mesh_AllToAllOp : Mesh_CollectiveCommunicationOpBase<"all_to_all", [
+    Pure,
     SameOperandsAndResultElementType,
     SameOperandsAndResultRank]> {
   let summary = "All-to-all over a device mesh.";
@@ -398,6 +401,7 @@ def Mesh_AllToAllOp : Mesh_CollectiveCommunicationOpBase<"all_to_all", [
 }
 
 def Mesh_BroadcastOp : Mesh_CollectiveCommunicationOpBase<"broadcast", [
+    Pure,
     AllShapesMatch<["input", "result"]>,
     AllElementTypesMatch<["input", "result"]>
   ]> {
@@ -453,6 +457,7 @@ def Mesh_BroadcastOp : Mesh_CollectiveCommunicationOpBase<"broadcast", [
 }
 
 def Mesh_GatherOp : Mesh_CollectiveCommunicationOpBase<"gather", [
+    Pure,
     AllRanksMatch<["input", "result"]>,
     AllElementTypesMatch<["input", "result"]>
   ]> {
@@ -540,6 +545,7 @@ def Mesh_RecvOp : Mesh_CollectiveCommunicationOpBase<"recv", [
 }
 
 def Mesh_ReduceOp : Mesh_CollectiveCommunicationOpBase<"reduce", [
+    Pure,
     AllShapesMatch<["input", "result"]>
   ]> {
   let summary = "Reduce over a device mesh.";
@@ -581,6 +587,7 @@ def Mesh_ReduceOp : Mesh_CollectiveCommunicationOpBase<"reduce", [
 }
 
 def Mesh_ReduceScatterOp : Mesh_CollectiveCommunicationOpBase<"reduce_scatter", [
+    Pure,
     SameOperandsAndResultRank]> {
   let summary = "Reduce-scatter over a device mesh.";
   let description = [{
@@ -642,6 +649,7 @@ def Mesh_ReduceScatterOp : Mesh_CollectiveCommunicationOpBase<"reduce_scatter",
 }
 
 def Mesh_ScatterOp : Mesh_CollectiveCommunicationOpBase<"scatter", [
+    Pure,
     AllRanksMatch<["input", "result"]>,
     AllElementTypesMatch<["input", "result"]>
   ]> {
@@ -734,6 +742,7 @@ def Mesh_SendOp : Mesh_CollectiveCommunicationOpBase<"send", [
 }
 
 def Mesh_ShiftOp : Mesh_CollectiveCommunicationOpBase<"shift", [
+    Pure,
     SameOperandsAndResultElementType,
     SameOperandsAndResultShape
   ]> {
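
The consistency requirement described in the Purity section can be illustrated with a small stand-alone model. This is only a sketch under stated assumptions: `Op`, `dead_code_elimination`, `collective_trace` and `traces_match` are hypothetical names invented for illustration, not MLIR or `mesh` dialect APIs, and the "program" is a flat list of operations rather than real IR. It shows that removing an unused pure collective is harmless when every process removes it, while removing it on only one process leaves the processes issuing different sequences of collectives.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class Op:
    """Hypothetical stand-in for an IR operation (not an MLIR class)."""
    name: str                 # e.g. "all_reduce" or "add"
    result: str               # name of the value this op defines
    operands: tuple           # names of the values this op reads
    collective: bool = False  # True for collective communication ops


def dead_code_elimination(program, live_outputs):
    """Drop ops whose results are never used.

    This is only valid when every op is pure, which is the property the
    commit adds to most mesh collectives.
    """
    live = set(live_outputs)
    kept = []
    for op in reversed(program):
        if op.result in live:
            kept.append(op)
            live.update(op.operands)
    kept.reverse()
    return kept


def collective_trace(program):
    """Ordered list of the collectives one process would execute."""
    return [f"{op.name}:{op.result}" for op in program if op.collective]


def traces_match(programs):
    """True when all processes issue the same collectives in the same order."""
    traces = [collective_trace(p) for p in programs]
    return all(trace == traces[0] for trace in traces)


# One SPMD program; the all_reduce result r0 is never used.
program = [
    Op("all_reduce", "r0", ("x",), collective=True),
    Op("all_gather", "r1", ("x",), collective=True),
    Op("add", "r2", ("r1", "x")),
]
num_processes = 4

# Consistent: every process runs the same optimized program.
consistent = [dead_code_elimination(program, ["r2"]) for _ in range(num_processes)]

# Inconsistent: only process 0 is optimized; the rest keep the all_reduce.
inconsistent = [dead_code_elimination(program, ["r2"])] + [
    list(program) for _ in range(num_processes - 1)
]

print(traces_match(consistent))    # consistent removal keeps traces aligned
print(traces_match(inconsistent))  # mismatched traces would hang at runtime
```

In this model a mismatch in `collective_trace` stands in for the runtime failure: some processes would wait in an `all_reduce` that process 0 never enters.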


        

