[Mlir-commits] [mlir] [mlir][mesh] Better Op result names (PR #82408)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Tue Feb 20 11:51:46 PST 2024
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir
Author: Boian Petkantchin (sogartar)
<details>
<summary>Changes</summary>
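Implement OpAsmOpInterface for most Mesh dialect ops so that their results get descriptive SSA names, which improves IR readability. For example, `mesh.process_linear_index` now produces a value named `proc_linear_idx` instead of a bare number such as `%0`. Because of this, the process-multi-index lowering test now captures the linear-index value with a FileCheck pattern instead of hard-coding `%0`.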
---
Full diff: https://github.com/llvm/llvm-project/pull/82408.diff
3 Files Affected:
- (modified) mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td (+19-6)
- (modified) mlir/lib/Dialect/Mesh/IR/MeshOps.cpp (+98)
- (modified) mlir/test/Dialect/Mesh/process-multi-index-op-lowering.mlir (+2-2)
``````````diff
diff --git a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td
index 8ba7c111aea6bb..b9cd15e2062669 100644
--- a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td
+++ b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td
@@ -16,6 +16,7 @@ include "mlir/Interfaces/SideEffectInterfaces.td"
include "mlir/IR/BuiltinTypes.td"
include "mlir/IR/CommonAttrConstraints.td"
include "mlir/IR/CommonTypeConstraints.td"
+include "mlir/IR/OpAsmInterface.td"
include "mlir/IR/SymbolInterfaces.td"
//===----------------------------------------------------------------------===//
@@ -78,7 +79,10 @@ def Mesh_MeshOp : Mesh_Op<"mesh", [Symbol]> {
}
def Mesh_MeshShapeOp : Mesh_Op<"mesh_shape", [
- Pure, DeclareOpInterfaceMethods<SymbolUserOpInterface>]> {
+ Pure,
+ DeclareOpInterfaceMethods<SymbolUserOpInterface>,
+ DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>
+ ]> {
let summary = "Get the shape of the mesh.";
let arguments = (ins
FlatSymbolRefAttr:$mesh,
@@ -101,7 +105,11 @@ def Mesh_MeshShapeOp : Mesh_Op<"mesh_shape", [
];
}
-def Mesh_ShardOp : Mesh_Op<"shard", [Pure, SameOperandsAndResultType]> {
+def Mesh_ShardOp : Mesh_Op<"shard", [
+ Pure,
+ SameOperandsAndResultType,
+ DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>
+ ]> {
let summary = "Annotate on how a tensor is sharded across a mesh.";
let description = [{
The mesh.shard operation is designed to specify and guide the sharding
@@ -194,7 +202,8 @@ def Mesh_ShardOp : Mesh_Op<"shard", [Pure, SameOperandsAndResultType]> {
def Mesh_ProcessMultiIndexOp : Mesh_Op<"process_multi_index", [
Pure,
- DeclareOpInterfaceMethods<SymbolUserOpInterface>
+ DeclareOpInterfaceMethods<SymbolUserOpInterface>,
+ DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>
]> {
let summary = "Get the multi index of current device along specified mesh axes.";
let description = [{
@@ -221,7 +230,8 @@ def Mesh_ProcessMultiIndexOp : Mesh_Op<"process_multi_index", [
def Mesh_ProcessLinearIndexOp : Mesh_Op<"process_linear_index", [
Pure,
- DeclareOpInterfaceMethods<SymbolUserOpInterface>
+ DeclareOpInterfaceMethods<SymbolUserOpInterface>,
+ DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>
]> {
let summary = "Get the linear index of the current device.";
let description = [{
@@ -248,7 +258,10 @@ class Mesh_CollectiveCommunicationOpBase<
string mnemonic, list<Trait> traits = []> :
Mesh_Op<mnemonic,
!listconcat(traits,
- [DeclareOpInterfaceMethods<SymbolUserOpInterface>])> {
+ [
+ DeclareOpInterfaceMethods<SymbolUserOpInterface>,
+ DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>
+ ])> {
dag commonArgs = (ins
FlatSymbolRefAttr:$mesh,
DefaultValuedAttr<Mesh_MeshAxesAttr, "{}">:$mesh_axes
@@ -258,7 +271,7 @@ class Mesh_CollectiveCommunicationOpBase<
def Mesh_AllGatherOp : Mesh_CollectiveCommunicationOpBase<"all_gather", [
Pure,
SameOperandsAndResultElementType,
- SameOperandsAndResultRank
+ SameOperandsAndResultRank,
]> {
let summary = "All-gather over a device mesh.";
let description = [{
diff --git a/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp b/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
index 838255cf5a5ba3..07a320752b2595 100644
--- a/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
+++ b/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
@@ -27,7 +27,9 @@
#include "llvm/ADT/DenseSet.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallSet.h"
+#include "llvm/ADT/SmallString.h"
#include "llvm/ADT/SmallVector.h"
+#include "llvm/ADT/Twine.h"
#include "llvm/ADT/TypeSwitch.h"
#include <algorithm>
#include <functional>
@@ -180,6 +182,20 @@ Type mesh::shardType(Type type, MeshOp mesh, MeshShardingAttr sharding) {
return type;
}
+
//===----------------------------------------------------------------------===//
// mesh.mesh op
//===----------------------------------------------------------------------===//
@@ -244,6 +260,11 @@ void MeshShapeOp::build(OpBuilder &odsBuilder, OperationState &odsState,
MeshAxesAttr::get(odsBuilder.getContext(), axes));
}
+/// Suggests `mesh_shape` as the printed SSA name for this op's results.
+/// The callback is given the first result, which starts the result pack, so
+/// all results share the name (e.g. printed as `%mesh_shape:2`).
+void MeshShapeOp::getAsmResultNames(
+    function_ref<void(Value, StringRef)> setNameFn) {
+  // The result list is variadic; an op that has not been verified may have
+  // no results at all, and indexing `getResults()[0]` would then be out of
+  // bounds. Fall back to default numbering in that case.
+  if (getResults().empty())
+    return;
+  setNameFn(getResults()[0], "mesh_shape");
+}
+
//===----------------------------------------------------------------------===//
// mesh.shard attr
//===----------------------------------------------------------------------===//
@@ -307,6 +328,15 @@ bool MeshShardingAttr::operator==(MeshShardingAttr rhs) const {
std::mem_fn(&MeshAxesAttr::empty));
}
+//===----------------------------------------------------------------------===//
+// mesh.shard op
+//===----------------------------------------------------------------------===//
+
+/// Suggests `sharding_annotated` as the printed SSA name of the tensor value
+/// produced by mesh.shard.
+void ShardOp::getAsmResultNames(
+    function_ref<void(Value, StringRef)> setNameFn) {
+  Value annotated = getResult();
+  setNameFn(annotated, "sharding_annotated");
+}
+
//===----------------------------------------------------------------------===//
// mesh.process_multi_index op
//===----------------------------------------------------------------------===//
@@ -345,6 +375,11 @@ void ProcessMultiIndexOp::build(OpBuilder &odsBuilder, OperationState &odsState,
MeshAxesAttr::get(odsBuilder.getContext(), axes));
}
+/// Suggests `proc_multi_idx` as the printed SSA name for this op's results.
+/// The callback is given the first result, which starts the result pack, so
+/// every index component shares the name.
+void ProcessMultiIndexOp::getAsmResultNames(
+    function_ref<void(Value, StringRef)> setNameFn) {
+  // The result list is variadic; an op that has not been verified may have
+  // no results, so avoid an out-of-bounds `getResults()[0]`.
+  if (getResults().empty())
+    return;
+  // Do not reuse `proc_linear_idx` (the name for mesh.process_linear_index
+  // results): these values are a multi-index, not a linear one.
+  setNameFn(getResults()[0], "proc_multi_idx");
+}
+
//===----------------------------------------------------------------------===//
// mesh.process_linear_index op
//===----------------------------------------------------------------------===//
@@ -363,6 +398,11 @@ void ProcessLinearIndexOp::build(OpBuilder &odsBuilder,
build(odsBuilder, odsState, mesh.getSymName());
}
+/// Suggests `proc_linear_idx` as the printed SSA name of the linear device
+/// index returned by mesh.process_linear_index.
+void ProcessLinearIndexOp::getAsmResultNames(
+    function_ref<void(Value, StringRef)> setNameFn) {
+  StringRef resultName = "proc_linear_idx";
+  setNameFn(getResult(), resultName);
+}
+
//===----------------------------------------------------------------------===//
// collective communication ops
//===----------------------------------------------------------------------===//
@@ -606,6 +646,11 @@ void AllGatherOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
patterns.add<EmptyMeshAxesCanonicalizationPattern<AllGatherOp>>(context);
}
+/// Suggests `all_gather` as the printed SSA name of this op's single result.
+void AllGatherOp::getAsmResultNames(
+    function_ref<void(Value, StringRef)> setNameFn) {
+  Value gathered = getResult();
+  setNameFn(gathered, "all_gather");
+}
+
//===----------------------------------------------------------------------===//
// mesh.all_reduce op
//===----------------------------------------------------------------------===//
@@ -620,6 +665,11 @@ void AllReduceOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
patterns.add<EmptyMeshAxesCanonicalizationPattern<AllReduceOp>>(context);
}
+/// Suggests `all_reduce` as the printed SSA name of this op's single result.
+void AllReduceOp::getAsmResultNames(
+    function_ref<void(Value, StringRef)> setNameFn) {
+  StringRef resultName = "all_reduce";
+  setNameFn(getResult(), resultName);
+}
+
//===----------------------------------------------------------------------===//
// mesh.all_slice op
//===----------------------------------------------------------------------===//
@@ -654,6 +704,11 @@ void AllSliceOp::build(OpBuilder &odsBuilder, OperationState &odsState,
APInt(sizeof(sliceAxis) * CHAR_BIT, sliceAxis));
}
+/// Suggests `all_slice` as the printed SSA name of this op's single result.
+void AllSliceOp::getAsmResultNames(
+    function_ref<void(Value, StringRef)> setNameFn) {
+  Value sliced = getResult();
+  setNameFn(sliced, "all_slice");
+}
+
//===----------------------------------------------------------------------===//
// mesh.all_to_all op
//===----------------------------------------------------------------------===//
@@ -674,6 +729,11 @@ void AllToAllOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
patterns.add<EmptyMeshAxesCanonicalizationPattern<AllToAllOp>>(context);
}
+/// Suggests `all_to_all` as the printed SSA name of this op's single result.
+void AllToAllOp::getAsmResultNames(
+    function_ref<void(Value, StringRef)> setNameFn) {
+  StringRef resultName = "all_to_all";
+  setNameFn(getResult(), resultName);
+}
+
//===----------------------------------------------------------------------===//
// mesh.broadcast op
//===----------------------------------------------------------------------===//
@@ -698,6 +758,11 @@ void BroadcastOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
patterns.add<EmptyMeshAxesCanonicalizationPattern<BroadcastOp>>(context);
}
+/// Suggests `broadcast` as the printed SSA name of this op's single result.
+void BroadcastOp::getAsmResultNames(
+    function_ref<void(Value, StringRef)> setNameFn) {
+  Value broadcasted = getResult();
+  setNameFn(broadcasted, "broadcast");
+}
+
//===----------------------------------------------------------------------===//
// mesh.gather op
//===----------------------------------------------------------------------===//
@@ -724,6 +789,11 @@ void GatherOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
patterns.add<EmptyMeshAxesCanonicalizationPattern<GatherOp>>(context);
}
+/// Suggests `gather` as the printed SSA name of this op's single result.
+void GatherOp::getAsmResultNames(
+    function_ref<void(Value, StringRef)> setNameFn) {
+  StringRef resultName = "gather";
+  setNameFn(getResult(), resultName);
+}
+
//===----------------------------------------------------------------------===//
// mesh.recv op
//===----------------------------------------------------------------------===//
@@ -747,6 +817,10 @@ void RecvOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
patterns.add<EmptyMeshAxesCanonicalizationPattern<RecvOp>>(context);
}
+/// Suggests `recv` as the printed SSA name of this op's single result.
+void RecvOp::getAsmResultNames(
+    function_ref<void(Value, StringRef)> setNameFn) {
+  Value received = getResult();
+  setNameFn(received, "recv");
+}
+
//===----------------------------------------------------------------------===//
// mesh.reduce op
//===----------------------------------------------------------------------===//
@@ -770,6 +844,11 @@ void ReduceOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
patterns.add<EmptyMeshAxesCanonicalizationPattern<ReduceOp>>(context);
}
+/// Suggests `reduce` as the printed SSA name of this op's single result.
+void ReduceOp::getAsmResultNames(
+    function_ref<void(Value, StringRef)> setNameFn) {
+  StringRef resultName = "reduce";
+  setNameFn(getResult(), resultName);
+}
+
//===----------------------------------------------------------------------===//
// mesh.reduce_scatter op
//===----------------------------------------------------------------------===//
@@ -791,6 +870,11 @@ void ReduceScatterOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
patterns.add<EmptyMeshAxesCanonicalizationPattern<ReduceScatterOp>>(context);
}
+/// Suggests `reduce_scatter` as the printed SSA name of this op's single
+/// result.
+void ReduceScatterOp::getAsmResultNames(
+    function_ref<void(Value, StringRef)> setNameFn) {
+  Value scattered = getResult();
+  setNameFn(scattered, "reduce_scatter");
+}
+
//===----------------------------------------------------------------------===//
// mesh.scatter op
//===----------------------------------------------------------------------===//
@@ -817,6 +901,11 @@ void ScatterOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
patterns.add<EmptyMeshAxesCanonicalizationPattern<ScatterOp>>(context);
}
+/// Suggests `scatter` as the printed SSA name of this op's single result.
+void ScatterOp::getAsmResultNames(
+    function_ref<void(Value, StringRef)> setNameFn) {
+  StringRef resultName = "scatter";
+  setNameFn(getResult(), resultName);
+}
+
//===----------------------------------------------------------------------===//
// mesh.send op
//===----------------------------------------------------------------------===//
@@ -839,6 +928,10 @@ void SendOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
patterns.add<EmptyMeshAxesCanonicalizationPattern<SendOp>>(context);
}
+/// Suggests `send` as the printed SSA name of this op's single result.
+void SendOp::getAsmResultNames(
+    function_ref<void(Value, StringRef)> setNameFn) {
+  Value sent = getResult();
+  setNameFn(sent, "send");
+}
+
//===----------------------------------------------------------------------===//
// mesh.shift op
//===----------------------------------------------------------------------===//
@@ -865,6 +958,11 @@ void ShiftOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
// offset % shift_axis_mesh_dim_size == 0.
}
+/// Suggests `shift` as the printed SSA name of this op's single result.
+void ShiftOp::getAsmResultNames(
+    function_ref<void(Value, StringRef)> setNameFn) {
+  StringRef resultName = "shift";
+  setNameFn(getResult(), resultName);
+}
+
//===----------------------------------------------------------------------===//
// TableGen'd op method definitions
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Mesh/process-multi-index-op-lowering.mlir b/mlir/test/Dialect/Mesh/process-multi-index-op-lowering.mlir
index 677a5982ea2540..e23cfd79a42745 100644
--- a/mlir/test/Dialect/Mesh/process-multi-index-op-lowering.mlir
+++ b/mlir/test/Dialect/Mesh/process-multi-index-op-lowering.mlir
@@ -6,7 +6,7 @@ mesh.mesh @mesh2d(shape = ?x?)
func.func @multi_index_2d_mesh() -> (index, index) {
// CHECK: %[[LINEAR_IDX:.*]] = mesh.process_linear_index on @mesh2d : index
// CHECK: %[[MESH_SHAPE:.*]]:2 = mesh.mesh_shape @mesh2d : index, index
- // CHECK: %[[MULTI_IDX:.*]]:2 = affine.delinearize_index %0 into (%[[MESH_SHAPE]]#0, %[[MESH_SHAPE]]#1) : index, index
+ // CHECK: %[[MULTI_IDX:.*]]:2 = affine.delinearize_index %[[LINEAR_IDX]] into (%[[MESH_SHAPE]]#0, %[[MESH_SHAPE]]#1) : index, index
%0:2 = mesh.process_multi_index on @mesh2d : index, index
// CHECK: return %[[MULTI_IDX]]#0, %[[MULTI_IDX]]#1 : index, index
return %0#0, %0#1 : index, index
@@ -16,7 +16,7 @@ func.func @multi_index_2d_mesh() -> (index, index) {
func.func @multi_index_2d_mesh_single_inner_axis() -> index {
// CHECK: %[[LINEAR_IDX:.*]] = mesh.process_linear_index on @mesh2d : index
// CHECK: %[[MESH_SHAPE:.*]]:2 = mesh.mesh_shape @mesh2d : index, index
- // CHECK: %[[MULTI_IDX:.*]]:2 = affine.delinearize_index %0 into (%[[MESH_SHAPE]]#0, %[[MESH_SHAPE]]#1) : index, index
+ // CHECK: %[[MULTI_IDX:.*]]:2 = affine.delinearize_index %[[LINEAR_IDX]] into (%[[MESH_SHAPE]]#0, %[[MESH_SHAPE]]#1) : index, index
%0 = mesh.process_multi_index on @mesh2d axes = [0] : index
// CHECK: return %[[MULTI_IDX]]#0 : index
return %0 : index
``````````
</details>
https://github.com/llvm/llvm-project/pull/82408
More information about the Mlir-commits
mailing list