[Mlir-commits] [mlir] f78027d - [mlir][mesh] Better op result names (#82408)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Tue Feb 20 16:53:31 PST 2024


Author: Boian Petkantchin
Date: 2024-02-20T16:53:26-08:00
New Revision: f78027dfeca9925efe7e025beb05b4cef8a1581a

URL: https://github.com/llvm/llvm-project/commit/f78027dfeca9925efe7e025beb05b4cef8a1581a
DIFF: https://github.com/llvm/llvm-project/commit/f78027dfeca9925efe7e025beb05b4cef8a1581a.diff

LOG: [mlir][mesh] Better op result names (#82408)

Implement OpAsmOpInterface for most ops to increase IR readability. For
example `mesh.process_linear_index` would produce a value with name
`proc_linear_idx`.

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td
    mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
    mlir/test/Dialect/Mesh/process-multi-index-op-lowering.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td
index 8ba7c111aea6bb..b9cd15e2062669 100644
--- a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td
+++ b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td
@@ -16,6 +16,7 @@ include "mlir/Interfaces/SideEffectInterfaces.td"
 include "mlir/IR/BuiltinTypes.td"
 include "mlir/IR/CommonAttrConstraints.td"
 include "mlir/IR/CommonTypeConstraints.td"
+include "mlir/IR/OpAsmInterface.td"
 include "mlir/IR/SymbolInterfaces.td"
 
 //===----------------------------------------------------------------------===//
@@ -78,7 +79,10 @@ def Mesh_MeshOp : Mesh_Op<"mesh", [Symbol]> {
 }
 
 def Mesh_MeshShapeOp : Mesh_Op<"mesh_shape", [
-  Pure, DeclareOpInterfaceMethods<SymbolUserOpInterface>]> {
+    Pure,
+    DeclareOpInterfaceMethods<SymbolUserOpInterface>,
+    DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>
+  ]> {
   let summary = "Get the shape of the mesh.";
   let arguments = (ins
     FlatSymbolRefAttr:$mesh,
@@ -101,7 +105,11 @@ def Mesh_MeshShapeOp : Mesh_Op<"mesh_shape", [
   ];
 }
 
-def Mesh_ShardOp : Mesh_Op<"shard", [Pure, SameOperandsAndResultType]> {
+def Mesh_ShardOp : Mesh_Op<"shard", [
+    Pure,
+    SameOperandsAndResultType,
+    DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>
+  ]> {
   let summary = "Annotate on how a tensor is sharded across a mesh.";
   let description = [{
     The mesh.shard operation is designed to specify and guide the sharding
@@ -194,7 +202,8 @@ def Mesh_ShardOp : Mesh_Op<"shard", [Pure, SameOperandsAndResultType]> {
 
 def Mesh_ProcessMultiIndexOp : Mesh_Op<"process_multi_index", [
   Pure,
-  DeclareOpInterfaceMethods<SymbolUserOpInterface>
+  DeclareOpInterfaceMethods<SymbolUserOpInterface>,
+  DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>
 ]> {
   let summary = "Get the multi index of current device along specified mesh axes.";
   let description = [{
@@ -221,7 +230,8 @@ def Mesh_ProcessMultiIndexOp : Mesh_Op<"process_multi_index", [
 
 def Mesh_ProcessLinearIndexOp : Mesh_Op<"process_linear_index", [
   Pure,
-  DeclareOpInterfaceMethods<SymbolUserOpInterface>
+  DeclareOpInterfaceMethods<SymbolUserOpInterface>,
+  DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>
 ]> {
   let summary = "Get the linear index of the current device.";
   let description = [{
@@ -248,7 +258,10 @@ class Mesh_CollectiveCommunicationOpBase<
     string mnemonic, list<Trait> traits = []> :
     Mesh_Op<mnemonic,
       !listconcat(traits,
-      [DeclareOpInterfaceMethods<SymbolUserOpInterface>])> {
+        [
+          DeclareOpInterfaceMethods<SymbolUserOpInterface>,
+          DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>
+        ])> {
   dag commonArgs = (ins
     FlatSymbolRefAttr:$mesh,
     DefaultValuedAttr<Mesh_MeshAxesAttr, "{}">:$mesh_axes
@@ -258,7 +271,7 @@ class Mesh_CollectiveCommunicationOpBase<
 def Mesh_AllGatherOp : Mesh_CollectiveCommunicationOpBase<"all_gather", [
     Pure,
     SameOperandsAndResultElementType,
-    SameOperandsAndResultRank
+    SameOperandsAndResultRank,
   ]> {
   let summary = "All-gather over a device mesh.";
   let description = [{

diff  --git a/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp b/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
index 838255cf5a5ba3..50163880e85f96 100644
--- a/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
+++ b/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
@@ -24,7 +24,6 @@
 #include "mlir/Support/LLVM.h"
 #include "mlir/Support/LogicalResult.h"
 #include "llvm/ADT/ArrayRef.h"
-#include "llvm/ADT/DenseSet.h"
 #include "llvm/ADT/STLExtras.h"
 #include "llvm/ADT/SmallSet.h"
 #include "llvm/ADT/SmallVector.h"
@@ -34,7 +33,6 @@
 #include <iterator>
 #include <numeric>
 #include <optional>
-#include <string>
 #include <utility>
 
 #define DEBUG_TYPE "mesh-ops"
@@ -244,6 +242,11 @@ void MeshShapeOp::build(OpBuilder &odsBuilder, OperationState &odsState,
         MeshAxesAttr::get(odsBuilder.getContext(), axes));
 }
 
+void MeshShapeOp::getAsmResultNames(
+    function_ref<void(Value, StringRef)> setNameFn) {
+  setNameFn(getResults()[0], "mesh_shape"); // SSA name hint for printing.
+}
+
 //===----------------------------------------------------------------------===//
 // mesh.shard attr
 //===----------------------------------------------------------------------===//
@@ -307,6 +310,15 @@ bool MeshShardingAttr::operator==(MeshShardingAttr rhs) const {
                       std::mem_fn(&MeshAxesAttr::empty));
 }
 
+//===----------------------------------------------------------------------===//
+// mesh.shard op
+//===----------------------------------------------------------------------===//
+
+void ShardOp::getAsmResultNames(
+    function_ref<void(Value, StringRef)> setNameFn) {
+  setNameFn(getResult(), "sharding_annotated"); // SSA name hint for printing.
+}
+
 //===----------------------------------------------------------------------===//
 // mesh.process_multi_index op
 //===----------------------------------------------------------------------===//
@@ -345,6 +357,11 @@ void ProcessMultiIndexOp::build(OpBuilder &odsBuilder, OperationState &odsState,
         MeshAxesAttr::get(odsBuilder.getContext(), axes));
 }
 
+void ProcessMultiIndexOp::getAsmResultNames(
+    function_ref<void(Value, StringRef)> setNameFn) {
+  setNameFn(getResults()[0], "proc_multi_idx"); // Not the linear index.
+}
+
 //===----------------------------------------------------------------------===//
 // mesh.process_linear_index op
 //===----------------------------------------------------------------------===//
@@ -363,6 +380,11 @@ void ProcessLinearIndexOp::build(OpBuilder &odsBuilder,
   build(odsBuilder, odsState, mesh.getSymName());
 }
 
+void ProcessLinearIndexOp::getAsmResultNames(
+    function_ref<void(Value, StringRef)> setNameFn) {
+  setNameFn(getResult(), "proc_linear_idx"); // SSA name hint for printing.
+}
+
 //===----------------------------------------------------------------------===//
 // collective communication ops
 //===----------------------------------------------------------------------===//
@@ -606,6 +628,11 @@ void AllGatherOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
   patterns.add<EmptyMeshAxesCanonicalizationPattern<AllGatherOp>>(context);
 }
 
+void AllGatherOp::getAsmResultNames(
+    function_ref<void(Value, StringRef)> setNameFn) {
+  setNameFn(getResult(), "all_gather"); // SSA name hint for printing.
+}
+
 //===----------------------------------------------------------------------===//
 // mesh.all_reduce op
 //===----------------------------------------------------------------------===//
@@ -620,6 +647,11 @@ void AllReduceOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
   patterns.add<EmptyMeshAxesCanonicalizationPattern<AllReduceOp>>(context);
 }
 
+void AllReduceOp::getAsmResultNames(
+    function_ref<void(Value, StringRef)> setNameFn) {
+  setNameFn(getResult(), "all_reduce"); // SSA name hint for printing.
+}
+
 //===----------------------------------------------------------------------===//
 // mesh.all_slice op
 //===----------------------------------------------------------------------===//
@@ -654,6 +686,11 @@ void AllSliceOp::build(OpBuilder &odsBuilder, OperationState &odsState,
         APInt(sizeof(sliceAxis) * CHAR_BIT, sliceAxis));
 }
 
+void AllSliceOp::getAsmResultNames(
+    function_ref<void(Value, StringRef)> setNameFn) {
+  setNameFn(getResult(), "all_slice"); // SSA name hint for printing.
+}
+
 //===----------------------------------------------------------------------===//
 // mesh.all_to_all op
 //===----------------------------------------------------------------------===//
@@ -674,6 +711,11 @@ void AllToAllOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
   patterns.add<EmptyMeshAxesCanonicalizationPattern<AllToAllOp>>(context);
 }
 
+void AllToAllOp::getAsmResultNames(
+    function_ref<void(Value, StringRef)> setNameFn) {
+  setNameFn(getResult(), "all_to_all"); // SSA name hint for printing.
+}
+
 //===----------------------------------------------------------------------===//
 // mesh.broadcast op
 //===----------------------------------------------------------------------===//
@@ -698,6 +740,11 @@ void BroadcastOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
   patterns.add<EmptyMeshAxesCanonicalizationPattern<BroadcastOp>>(context);
 }
 
+void BroadcastOp::getAsmResultNames(
+    function_ref<void(Value, StringRef)> setNameFn) {
+  setNameFn(getResult(), "broadcast"); // SSA name hint for printing.
+}
+
 //===----------------------------------------------------------------------===//
 // mesh.gather op
 //===----------------------------------------------------------------------===//
@@ -724,6 +771,11 @@ void GatherOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
   patterns.add<EmptyMeshAxesCanonicalizationPattern<GatherOp>>(context);
 }
 
+void GatherOp::getAsmResultNames(
+    function_ref<void(Value, StringRef)> setNameFn) {
+  setNameFn(getResult(), "gather"); // SSA name hint for printing.
+}
+
 //===----------------------------------------------------------------------===//
 // mesh.recv op
 //===----------------------------------------------------------------------===//
@@ -747,6 +799,10 @@ void RecvOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
   patterns.add<EmptyMeshAxesCanonicalizationPattern<RecvOp>>(context);
 }
 
+void RecvOp::getAsmResultNames(function_ref<void(Value, StringRef)> setNameFn) {
+  setNameFn(getResult(), "recv"); // SSA name hint for printing.
+}
+
 //===----------------------------------------------------------------------===//
 // mesh.reduce op
 //===----------------------------------------------------------------------===//
@@ -770,6 +826,11 @@ void ReduceOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
   patterns.add<EmptyMeshAxesCanonicalizationPattern<ReduceOp>>(context);
 }
 
+void ReduceOp::getAsmResultNames(
+    function_ref<void(Value, StringRef)> setNameFn) {
+  setNameFn(getResult(), "reduce"); // SSA name hint for printing.
+}
+
 //===----------------------------------------------------------------------===//
 // mesh.reduce_scatter op
 //===----------------------------------------------------------------------===//
@@ -791,6 +852,11 @@ void ReduceScatterOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
   patterns.add<EmptyMeshAxesCanonicalizationPattern<ReduceScatterOp>>(context);
 }
 
+void ReduceScatterOp::getAsmResultNames(
+    function_ref<void(Value, StringRef)> setNameFn) {
+  setNameFn(getResult(), "reduce_scatter"); // SSA name hint for printing.
+}
+
 //===----------------------------------------------------------------------===//
 // mesh.scatter op
 //===----------------------------------------------------------------------===//
@@ -817,6 +883,11 @@ void ScatterOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
   patterns.add<EmptyMeshAxesCanonicalizationPattern<ScatterOp>>(context);
 }
 
+void ScatterOp::getAsmResultNames(
+    function_ref<void(Value, StringRef)> setNameFn) {
+  setNameFn(getResult(), "scatter"); // SSA name hint for printing.
+}
+
 //===----------------------------------------------------------------------===//
 // mesh.send op
 //===----------------------------------------------------------------------===//
@@ -839,6 +910,10 @@ void SendOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
   patterns.add<EmptyMeshAxesCanonicalizationPattern<SendOp>>(context);
 }
 
+void SendOp::getAsmResultNames(function_ref<void(Value, StringRef)> setNameFn) {
+  setNameFn(getResult(), "send"); // SSA name hint for printing.
+}
+
 //===----------------------------------------------------------------------===//
 // mesh.shift op
 //===----------------------------------------------------------------------===//
@@ -865,6 +940,11 @@ void ShiftOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
   // offset % shift_axis_mesh_dim_size == 0.
 }
 
+void ShiftOp::getAsmResultNames(
+    function_ref<void(Value, StringRef)> setNameFn) {
+  setNameFn(getResult(), "shift"); // SSA name hint for printing.
+}
+
 //===----------------------------------------------------------------------===//
 // TableGen'd op method definitions
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/test/Dialect/Mesh/process-multi-index-op-lowering.mlir b/mlir/test/Dialect/Mesh/process-multi-index-op-lowering.mlir
index 677a5982ea2540..e23cfd79a42745 100644
--- a/mlir/test/Dialect/Mesh/process-multi-index-op-lowering.mlir
+++ b/mlir/test/Dialect/Mesh/process-multi-index-op-lowering.mlir
@@ -6,7 +6,7 @@ mesh.mesh @mesh2d(shape = ?x?)
 func.func @multi_index_2d_mesh() -> (index, index) {
   // CHECK: %[[LINEAR_IDX:.*]] = mesh.process_linear_index on @mesh2d : index
   // CHECK: %[[MESH_SHAPE:.*]]:2 = mesh.mesh_shape @mesh2d : index, index
-  // CHECK: %[[MULTI_IDX:.*]]:2 = affine.delinearize_index %0 into (%[[MESH_SHAPE]]#0, %[[MESH_SHAPE]]#1) : index, index
+  // CHECK: %[[MULTI_IDX:.*]]:2 = affine.delinearize_index %[[LINEAR_IDX]] into (%[[MESH_SHAPE]]#0, %[[MESH_SHAPE]]#1) : index, index
   %0:2 = mesh.process_multi_index on @mesh2d : index, index
   // CHECK: return %[[MULTI_IDX]]#0, %[[MULTI_IDX]]#1 : index, index
   return %0#0, %0#1 : index, index
@@ -16,7 +16,7 @@ func.func @multi_index_2d_mesh() -> (index, index) {
 func.func @multi_index_2d_mesh_single_inner_axis() -> index {
   // CHECK: %[[LINEAR_IDX:.*]] = mesh.process_linear_index on @mesh2d : index
   // CHECK: %[[MESH_SHAPE:.*]]:2 = mesh.mesh_shape @mesh2d : index, index
-  // CHECK: %[[MULTI_IDX:.*]]:2 = affine.delinearize_index %0 into (%[[MESH_SHAPE]]#0, %[[MESH_SHAPE]]#1) : index, index
+  // CHECK: %[[MULTI_IDX:.*]]:2 = affine.delinearize_index %[[LINEAR_IDX]] into (%[[MESH_SHAPE]]#0, %[[MESH_SHAPE]]#1) : index, index
   %0 = mesh.process_multi_index on @mesh2d axes = [0] : index
   // CHECK: return %[[MULTI_IDX]]#0 : index
   return %0 : index


        


More information about the Mlir-commits mailing list