[Mlir-commits] [mlir] [mlir][mesh] Better Op result names (PR #82408)

Boian Petkantchin llvmlistbot at llvm.org
Tue Feb 20 11:51:16 PST 2024


https://github.com/sogartar created https://github.com/llvm/llvm-project/pull/82408

Implement OpAsmOpInterface for most ops to increase IR readability. For example `mesh.process_linear_index` would produce a value with name `proc_linear_idx`.

From 87baff81275558f85d0d05daa754ee9f73b42778 Mon Sep 17 00:00:00 2001
From: Boian Petkantchin <boian.petkantchin at amd.com>
Date: Tue, 20 Feb 2024 11:40:31 -0800
Subject: [PATCH] [mlir][mesh] Better Op result names

Implement OpAsmOpInterface for most ops to increase IR readability.
For example `mesh.process_linear_index` would produce a value with name
`proc_linear_idx`.
---
 mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td  | 25 +++--
 mlir/lib/Dialect/Mesh/IR/MeshOps.cpp          | 98 +++++++++++++++++++
 .../Mesh/process-multi-index-op-lowering.mlir |  4 +-
 3 files changed, 119 insertions(+), 8 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td
index 8ba7c111aea6bb..b9cd15e2062669 100644
--- a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td
+++ b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td
@@ -16,6 +16,7 @@ include "mlir/Interfaces/SideEffectInterfaces.td"
 include "mlir/IR/BuiltinTypes.td"
 include "mlir/IR/CommonAttrConstraints.td"
 include "mlir/IR/CommonTypeConstraints.td"
+include "mlir/IR/OpAsmInterface.td"
 include "mlir/IR/SymbolInterfaces.td"
 
 //===----------------------------------------------------------------------===//
@@ -78,7 +79,10 @@ def Mesh_MeshOp : Mesh_Op<"mesh", [Symbol]> {
 }
 
 def Mesh_MeshShapeOp : Mesh_Op<"mesh_shape", [
-  Pure, DeclareOpInterfaceMethods<SymbolUserOpInterface>]> {
+    Pure,
+    DeclareOpInterfaceMethods<SymbolUserOpInterface>,
+    DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>
+  ]> {
   let summary = "Get the shape of the mesh.";
   let arguments = (ins
     FlatSymbolRefAttr:$mesh,
@@ -101,7 +105,11 @@ def Mesh_MeshShapeOp : Mesh_Op<"mesh_shape", [
   ];
 }
 
-def Mesh_ShardOp : Mesh_Op<"shard", [Pure, SameOperandsAndResultType]> {
+def Mesh_ShardOp : Mesh_Op<"shard", [
+    Pure,
+    SameOperandsAndResultType,
+    DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>
+  ]> {
   let summary = "Annotate on how a tensor is sharded across a mesh.";
   let description = [{
     The mesh.shard operation is designed to specify and guide the sharding
@@ -194,7 +202,8 @@ def Mesh_ShardOp : Mesh_Op<"shard", [Pure, SameOperandsAndResultType]> {
 
 def Mesh_ProcessMultiIndexOp : Mesh_Op<"process_multi_index", [
   Pure,
-  DeclareOpInterfaceMethods<SymbolUserOpInterface>
+  DeclareOpInterfaceMethods<SymbolUserOpInterface>,
+  DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>
 ]> {
   let summary = "Get the multi index of current device along specified mesh axes.";
   let description = [{
@@ -221,7 +230,8 @@ def Mesh_ProcessMultiIndexOp : Mesh_Op<"process_multi_index", [
 
 def Mesh_ProcessLinearIndexOp : Mesh_Op<"process_linear_index", [
   Pure,
-  DeclareOpInterfaceMethods<SymbolUserOpInterface>
+  DeclareOpInterfaceMethods<SymbolUserOpInterface>,
+  DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>
 ]> {
   let summary = "Get the linear index of the current device.";
   let description = [{
@@ -248,7 +258,10 @@ class Mesh_CollectiveCommunicationOpBase<
     string mnemonic, list<Trait> traits = []> :
     Mesh_Op<mnemonic,
       !listconcat(traits,
-      [DeclareOpInterfaceMethods<SymbolUserOpInterface>])> {
+        [
+          DeclareOpInterfaceMethods<SymbolUserOpInterface>,
+          DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>
+        ])> {
   dag commonArgs = (ins
     FlatSymbolRefAttr:$mesh,
     DefaultValuedAttr<Mesh_MeshAxesAttr, "{}">:$mesh_axes
@@ -258,7 +271,7 @@ class Mesh_CollectiveCommunicationOpBase<
 def Mesh_AllGatherOp : Mesh_CollectiveCommunicationOpBase<"all_gather", [
     Pure,
     SameOperandsAndResultElementType,
-    SameOperandsAndResultRank
+    SameOperandsAndResultRank,
   ]> {
   let summary = "All-gather over a device mesh.";
   let description = [{
diff --git a/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp b/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
index 838255cf5a5ba3..07a320752b2595 100644
--- a/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
+++ b/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
@@ -27,7 +27,9 @@
 #include "llvm/ADT/DenseSet.h"
 #include "llvm/ADT/STLExtras.h"
 #include "llvm/ADT/SmallSet.h"
+#include "llvm/ADT/SmallString.h"
 #include "llvm/ADT/SmallVector.h"
+#include "llvm/ADT/Twine.h"
 #include "llvm/ADT/TypeSwitch.h"
 #include <algorithm>
 #include <functional>
@@ -180,6 +182,20 @@ Type mesh::shardType(Type type, MeshOp mesh, MeshShardingAttr sharding) {
   return type;
 }
 
+// static void getAsmMultiResultNames(Operation* op, StringRef namePrefix,
+// function_ref<void(Value, StringRef)> setNameFn) {
+//   if (op->getNumResults() == 1) {
+//     setNameFn(op->getResult(0), namePrefix);
+//     return;
+//   }
+//   SmallString<64> str;
+//   for (auto [i, result]: llvm::enumerate(op->getResults())) {
+//     (Twine(namePrefix) + "_" + Twine(i) + "_").toStringRef(str);
+//     setNameFn(result, str);
+//     str.clear();
+//   }
+// }
+
 //===----------------------------------------------------------------------===//
 // mesh.mesh op
 //===----------------------------------------------------------------------===//
@@ -244,6 +260,11 @@ void MeshShapeOp::build(OpBuilder &odsBuilder, OperationState &odsState,
         MeshAxesAttr::get(odsBuilder.getContext(), axes));
 }
 
+void MeshShapeOp::getAsmResultNames(
+    function_ref<void(Value, StringRef)> setNameFn) {
+  setNameFn(getResults()[0], "mesh_shape");
+}
+
 //===----------------------------------------------------------------------===//
 // mesh.shard attr
 //===----------------------------------------------------------------------===//
@@ -307,6 +328,15 @@ bool MeshShardingAttr::operator==(MeshShardingAttr rhs) const {
                       std::mem_fn(&MeshAxesAttr::empty));
 }
 
+//===----------------------------------------------------------------------===//
+// mesh.shard op
+//===----------------------------------------------------------------------===//
+
+void ShardOp::getAsmResultNames(
+    function_ref<void(Value, StringRef)> setNameFn) {
+  setNameFn(getResult(), "sharding_annotated");
+}
+
 //===----------------------------------------------------------------------===//
 // mesh.process_multi_index op
 //===----------------------------------------------------------------------===//
@@ -345,6 +375,11 @@ void ProcessMultiIndexOp::build(OpBuilder &odsBuilder, OperationState &odsState,
         MeshAxesAttr::get(odsBuilder.getContext(), axes));
 }
 
+void ProcessMultiIndexOp::getAsmResultNames(
+    function_ref<void(Value, StringRef)> setNameFn) {
+  setNameFn(getResults()[0], "proc_multi_idx");
+}
+
 //===----------------------------------------------------------------------===//
 // mesh.process_linear_index op
 //===----------------------------------------------------------------------===//
@@ -363,6 +398,11 @@ void ProcessLinearIndexOp::build(OpBuilder &odsBuilder,
   build(odsBuilder, odsState, mesh.getSymName());
 }
 
+void ProcessLinearIndexOp::getAsmResultNames(
+    function_ref<void(Value, StringRef)> setNameFn) {
+  setNameFn(getResult(), "proc_linear_idx");
+}
+
 //===----------------------------------------------------------------------===//
 // collective communication ops
 //===----------------------------------------------------------------------===//
@@ -606,6 +646,11 @@ void AllGatherOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
   patterns.add<EmptyMeshAxesCanonicalizationPattern<AllGatherOp>>(context);
 }
 
+void AllGatherOp::getAsmResultNames(
+    function_ref<void(Value, StringRef)> setNameFn) {
+  setNameFn(getResult(), "all_gather");
+}
+
 //===----------------------------------------------------------------------===//
 // mesh.all_reduce op
 //===----------------------------------------------------------------------===//
@@ -620,6 +665,11 @@ void AllReduceOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
   patterns.add<EmptyMeshAxesCanonicalizationPattern<AllReduceOp>>(context);
 }
 
+void AllReduceOp::getAsmResultNames(
+    function_ref<void(Value, StringRef)> setNameFn) {
+  setNameFn(getResult(), "all_reduce");
+}
+
 //===----------------------------------------------------------------------===//
 // mesh.all_slice op
 //===----------------------------------------------------------------------===//
@@ -654,6 +704,11 @@ void AllSliceOp::build(OpBuilder &odsBuilder, OperationState &odsState,
         APInt(sizeof(sliceAxis) * CHAR_BIT, sliceAxis));
 }
 
+void AllSliceOp::getAsmResultNames(
+    function_ref<void(Value, StringRef)> setNameFn) {
+  setNameFn(getResult(), "all_slice");
+}
+
 //===----------------------------------------------------------------------===//
 // mesh.all_to_all op
 //===----------------------------------------------------------------------===//
@@ -674,6 +729,11 @@ void AllToAllOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
   patterns.add<EmptyMeshAxesCanonicalizationPattern<AllToAllOp>>(context);
 }
 
+void AllToAllOp::getAsmResultNames(
+    function_ref<void(Value, StringRef)> setNameFn) {
+  setNameFn(getResult(), "all_to_all");
+}
+
 //===----------------------------------------------------------------------===//
 // mesh.broadcast op
 //===----------------------------------------------------------------------===//
@@ -698,6 +758,11 @@ void BroadcastOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
   patterns.add<EmptyMeshAxesCanonicalizationPattern<BroadcastOp>>(context);
 }
 
+void BroadcastOp::getAsmResultNames(
+    function_ref<void(Value, StringRef)> setNameFn) {
+  setNameFn(getResult(), "broadcast");
+}
+
 //===----------------------------------------------------------------------===//
 // mesh.gather op
 //===----------------------------------------------------------------------===//
@@ -724,6 +789,11 @@ void GatherOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
   patterns.add<EmptyMeshAxesCanonicalizationPattern<GatherOp>>(context);
 }
 
+void GatherOp::getAsmResultNames(
+    function_ref<void(Value, StringRef)> setNameFn) {
+  setNameFn(getResult(), "gather");
+}
+
 //===----------------------------------------------------------------------===//
 // mesh.recv op
 //===----------------------------------------------------------------------===//
@@ -747,6 +817,10 @@ void RecvOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
   patterns.add<EmptyMeshAxesCanonicalizationPattern<RecvOp>>(context);
 }
 
+void RecvOp::getAsmResultNames(function_ref<void(Value, StringRef)> setNameFn) {
+  setNameFn(getResult(), "recv");
+}
+
 //===----------------------------------------------------------------------===//
 // mesh.reduce op
 //===----------------------------------------------------------------------===//
@@ -770,6 +844,11 @@ void ReduceOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
   patterns.add<EmptyMeshAxesCanonicalizationPattern<ReduceOp>>(context);
 }
 
+void ReduceOp::getAsmResultNames(
+    function_ref<void(Value, StringRef)> setNameFn) {
+  setNameFn(getResult(), "reduce");
+}
+
 //===----------------------------------------------------------------------===//
 // mesh.reduce_scatter op
 //===----------------------------------------------------------------------===//
@@ -791,6 +870,11 @@ void ReduceScatterOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
   patterns.add<EmptyMeshAxesCanonicalizationPattern<ReduceScatterOp>>(context);
 }
 
+void ReduceScatterOp::getAsmResultNames(
+    function_ref<void(Value, StringRef)> setNameFn) {
+  setNameFn(getResult(), "reduce_scatter");
+}
+
 //===----------------------------------------------------------------------===//
 // mesh.scatter op
 //===----------------------------------------------------------------------===//
@@ -817,6 +901,11 @@ void ScatterOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
   patterns.add<EmptyMeshAxesCanonicalizationPattern<ScatterOp>>(context);
 }
 
+void ScatterOp::getAsmResultNames(
+    function_ref<void(Value, StringRef)> setNameFn) {
+  setNameFn(getResult(), "scatter");
+}
+
 //===----------------------------------------------------------------------===//
 // mesh.send op
 //===----------------------------------------------------------------------===//
@@ -839,6 +928,10 @@ void SendOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
   patterns.add<EmptyMeshAxesCanonicalizationPattern<SendOp>>(context);
 }
 
+void SendOp::getAsmResultNames(function_ref<void(Value, StringRef)> setNameFn) {
+  setNameFn(getResult(), "send");
+}
+
 //===----------------------------------------------------------------------===//
 // mesh.shift op
 //===----------------------------------------------------------------------===//
@@ -865,6 +958,11 @@ void ShiftOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
   // offset % shift_axis_mesh_dim_size == 0.
 }
 
+void ShiftOp::getAsmResultNames(
+    function_ref<void(Value, StringRef)> setNameFn) {
+  setNameFn(getResult(), "shift");
+}
+
 //===----------------------------------------------------------------------===//
 // TableGen'd op method definitions
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Mesh/process-multi-index-op-lowering.mlir b/mlir/test/Dialect/Mesh/process-multi-index-op-lowering.mlir
index 677a5982ea2540..e23cfd79a42745 100644
--- a/mlir/test/Dialect/Mesh/process-multi-index-op-lowering.mlir
+++ b/mlir/test/Dialect/Mesh/process-multi-index-op-lowering.mlir
@@ -6,7 +6,7 @@ mesh.mesh @mesh2d(shape = ?x?)
 func.func @multi_index_2d_mesh() -> (index, index) {
   // CHECK: %[[LINEAR_IDX:.*]] = mesh.process_linear_index on @mesh2d : index
   // CHECK: %[[MESH_SHAPE:.*]]:2 = mesh.mesh_shape @mesh2d : index, index
-  // CHECK: %[[MULTI_IDX:.*]]:2 = affine.delinearize_index %0 into (%[[MESH_SHAPE]]#0, %[[MESH_SHAPE]]#1) : index, index
+  // CHECK: %[[MULTI_IDX:.*]]:2 = affine.delinearize_index %[[LINEAR_IDX]] into (%[[MESH_SHAPE]]#0, %[[MESH_SHAPE]]#1) : index, index
   %0:2 = mesh.process_multi_index on @mesh2d : index, index
   // CHECK: return %[[MULTI_IDX]]#0, %[[MULTI_IDX]]#1 : index, index
   return %0#0, %0#1 : index, index
@@ -16,7 +16,7 @@ func.func @multi_index_2d_mesh() -> (index, index) {
 func.func @multi_index_2d_mesh_single_inner_axis() -> index {
   // CHECK: %[[LINEAR_IDX:.*]] = mesh.process_linear_index on @mesh2d : index
   // CHECK: %[[MESH_SHAPE:.*]]:2 = mesh.mesh_shape @mesh2d : index, index
-  // CHECK: %[[MULTI_IDX:.*]]:2 = affine.delinearize_index %0 into (%[[MESH_SHAPE]]#0, %[[MESH_SHAPE]]#1) : index, index
+  // CHECK: %[[MULTI_IDX:.*]]:2 = affine.delinearize_index %[[LINEAR_IDX]] into (%[[MESH_SHAPE]]#0, %[[MESH_SHAPE]]#1) : index, index
   %0 = mesh.process_multi_index on @mesh2d axes = [0] : index
   // CHECK: return %[[MULTI_IDX]]#0 : index
   return %0 : index



More information about the Mlir-commits mailing list