[Mlir-commits] [mlir] f78027d - [mlir][mesh] Better op result names (#82408)
llvmlistbot at llvm.org
Tue Feb 20 16:53:31 PST 2024
Author: Boian Petkantchin
Date: 2024-02-20T16:53:26-08:00
New Revision: f78027dfeca9925efe7e025beb05b4cef8a1581a
URL: https://github.com/llvm/llvm-project/commit/f78027dfeca9925efe7e025beb05b4cef8a1581a
DIFF: https://github.com/llvm/llvm-project/commit/f78027dfeca9925efe7e025beb05b4cef8a1581a.diff
LOG: [mlir][mesh] Better op result names (#82408)
Implement OpAsmOpInterface for most ops to improve IR readability. For
example, `mesh.process_linear_index` now produces a value named
`proc_linear_idx`.
Added:
Modified:
mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td
mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
mlir/test/Dialect/Mesh/process-multi-index-op-lowering.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td
index 8ba7c111aea6bb..b9cd15e2062669 100644
--- a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td
+++ b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td
@@ -16,6 +16,7 @@ include "mlir/Interfaces/SideEffectInterfaces.td"
include "mlir/IR/BuiltinTypes.td"
include "mlir/IR/CommonAttrConstraints.td"
include "mlir/IR/CommonTypeConstraints.td"
+include "mlir/IR/OpAsmInterface.td"
include "mlir/IR/SymbolInterfaces.td"
//===----------------------------------------------------------------------===//
@@ -78,7 +79,10 @@ def Mesh_MeshOp : Mesh_Op<"mesh", [Symbol]> {
}
def Mesh_MeshShapeOp : Mesh_Op<"mesh_shape", [
- Pure, DeclareOpInterfaceMethods<SymbolUserOpInterface>]> {
+ Pure,
+ DeclareOpInterfaceMethods<SymbolUserOpInterface>,
+ DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>
+ ]> {
let summary = "Get the shape of the mesh.";
let arguments = (ins
FlatSymbolRefAttr:$mesh,
@@ -101,7 +105,11 @@ def Mesh_MeshShapeOp : Mesh_Op<"mesh_shape", [
];
}
-def Mesh_ShardOp : Mesh_Op<"shard", [Pure, SameOperandsAndResultType]> {
+def Mesh_ShardOp : Mesh_Op<"shard", [
+ Pure,
+ SameOperandsAndResultType,
+ DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>
+ ]> {
let summary = "Annotate on how a tensor is sharded across a mesh.";
let description = [{
The mesh.shard operation is designed to specify and guide the sharding
@@ -194,7 +202,8 @@ def Mesh_ShardOp : Mesh_Op<"shard", [Pure, SameOperandsAndResultType]> {
def Mesh_ProcessMultiIndexOp : Mesh_Op<"process_multi_index", [
Pure,
- DeclareOpInterfaceMethods<SymbolUserOpInterface>
+ DeclareOpInterfaceMethods<SymbolUserOpInterface>,
+ DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>
]> {
let summary = "Get the multi index of current device along specified mesh axes.";
let description = [{
@@ -221,7 +230,8 @@ def Mesh_ProcessMultiIndexOp : Mesh_Op<"process_multi_index", [
def Mesh_ProcessLinearIndexOp : Mesh_Op<"process_linear_index", [
Pure,
- DeclareOpInterfaceMethods<SymbolUserOpInterface>
+ DeclareOpInterfaceMethods<SymbolUserOpInterface>,
+ DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>
]> {
let summary = "Get the linear index of the current device.";
let description = [{
@@ -248,7 +258,10 @@ class Mesh_CollectiveCommunicationOpBase<
string mnemonic, list<Trait> traits = []> :
Mesh_Op<mnemonic,
!listconcat(traits,
- [DeclareOpInterfaceMethods<SymbolUserOpInterface>])> {
+ [
+ DeclareOpInterfaceMethods<SymbolUserOpInterface>,
+ DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>
+ ])> {
dag commonArgs = (ins
FlatSymbolRefAttr:$mesh,
DefaultValuedAttr<Mesh_MeshAxesAttr, "{}">:$mesh_axes
@@ -258,7 +271,7 @@ class Mesh_CollectiveCommunicationOpBase<
def Mesh_AllGatherOp : Mesh_CollectiveCommunicationOpBase<"all_gather", [
Pure,
SameOperandsAndResultElementType,
- SameOperandsAndResultRank
+ SameOperandsAndResultRank,
]> {
let summary = "All-gather over a device mesh.";
let description = [{
diff --git a/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp b/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
index 838255cf5a5ba3..50163880e85f96 100644
--- a/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
+++ b/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
@@ -24,7 +24,6 @@
#include "mlir/Support/LLVM.h"
#include "mlir/Support/LogicalResult.h"
#include "llvm/ADT/ArrayRef.h"
-#include "llvm/ADT/DenseSet.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallSet.h"
#include "llvm/ADT/SmallVector.h"
@@ -34,7 +33,6 @@
#include <iterator>
#include <numeric>
#include <optional>
-#include <string>
#include <utility>
#define DEBUG_TYPE "mesh-ops"
@@ -244,6 +242,11 @@ void MeshShapeOp::build(OpBuilder &odsBuilder, OperationState &odsState,
MeshAxesAttr::get(odsBuilder.getContext(), axes));
}
+void MeshShapeOp::getAsmResultNames(
+    function_ref<void(Value, StringRef)> setNameFn) {
+  setNameFn(getResults().front(), "mesh_shape"); // Names the result group.
+}
+
//===----------------------------------------------------------------------===//
// mesh.shard attr
//===----------------------------------------------------------------------===//
@@ -307,6 +310,15 @@ bool MeshShardingAttr::operator==(MeshShardingAttr rhs) const {
std::mem_fn(&MeshAxesAttr::empty));
}
+//===----------------------------------------------------------------------===//
+// mesh.shard op
+//===----------------------------------------------------------------------===//
+
+void ShardOp::getAsmResultNames(
+    function_ref<void(Value, StringRef)> setNameFn) {
+  setNameFn(getResult(), "sharding_annotated"); // Printed SSA name hint.
+}
+
//===----------------------------------------------------------------------===//
// mesh.process_multi_index op
//===----------------------------------------------------------------------===//
@@ -345,6 +357,11 @@ void ProcessMultiIndexOp::build(OpBuilder &odsBuilder, OperationState &odsState,
MeshAxesAttr::get(odsBuilder.getContext(), axes));
}
+void ProcessMultiIndexOp::getAsmResultNames(
+    function_ref<void(Value, StringRef)> setNameFn) {
+  setNameFn(getResults()[0], "proc_multi_idx"); // Names the result group.
+}
+
//===----------------------------------------------------------------------===//
// mesh.process_linear_index op
//===----------------------------------------------------------------------===//
@@ -363,6 +380,11 @@ void ProcessLinearIndexOp::build(OpBuilder &odsBuilder,
build(odsBuilder, odsState, mesh.getSymName());
}
+void ProcessLinearIndexOp::getAsmResultNames(
+    function_ref<void(Value, StringRef)> setNameFn) {
+  setNameFn(getResult(), "proc_linear_idx"); // Printed SSA name hint.
+}
+
//===----------------------------------------------------------------------===//
// collective communication ops
//===----------------------------------------------------------------------===//
@@ -606,6 +628,11 @@ void AllGatherOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
patterns.add<EmptyMeshAxesCanonicalizationPattern<AllGatherOp>>(context);
}
+void AllGatherOp::getAsmResultNames(
+    function_ref<void(Value, StringRef)> setNameFn) {
+  setNameFn(getResult(), "all_gather"); // Printed SSA name hint.
+}
+
//===----------------------------------------------------------------------===//
// mesh.all_reduce op
//===----------------------------------------------------------------------===//
@@ -620,6 +647,11 @@ void AllReduceOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
patterns.add<EmptyMeshAxesCanonicalizationPattern<AllReduceOp>>(context);
}
+void AllReduceOp::getAsmResultNames(
+    function_ref<void(Value, StringRef)> setNameFn) {
+  setNameFn(getResult(), "all_reduce"); // Printed SSA name hint.
+}
+
//===----------------------------------------------------------------------===//
// mesh.all_slice op
//===----------------------------------------------------------------------===//
@@ -654,6 +686,11 @@ void AllSliceOp::build(OpBuilder &odsBuilder, OperationState &odsState,
APInt(sizeof(sliceAxis) * CHAR_BIT, sliceAxis));
}
+void AllSliceOp::getAsmResultNames(
+    function_ref<void(Value, StringRef)> setNameFn) {
+  setNameFn(getResult(), "all_slice"); // Printed SSA name hint.
+}
+
//===----------------------------------------------------------------------===//
// mesh.all_to_all op
//===----------------------------------------------------------------------===//
@@ -674,6 +711,11 @@ void AllToAllOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
patterns.add<EmptyMeshAxesCanonicalizationPattern<AllToAllOp>>(context);
}
+void AllToAllOp::getAsmResultNames(
+    function_ref<void(Value, StringRef)> setNameFn) {
+  setNameFn(getResult(), "all_to_all"); // Printed SSA name hint.
+}
+
//===----------------------------------------------------------------------===//
// mesh.broadcast op
//===----------------------------------------------------------------------===//
@@ -698,6 +740,11 @@ void BroadcastOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
patterns.add<EmptyMeshAxesCanonicalizationPattern<BroadcastOp>>(context);
}
+void BroadcastOp::getAsmResultNames(
+    function_ref<void(Value, StringRef)> setNameFn) {
+  setNameFn(getResult(), "broadcast"); // Printed SSA name hint.
+}
+
//===----------------------------------------------------------------------===//
// mesh.gather op
//===----------------------------------------------------------------------===//
@@ -724,6 +771,11 @@ void GatherOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
patterns.add<EmptyMeshAxesCanonicalizationPattern<GatherOp>>(context);
}
+void GatherOp::getAsmResultNames(
+    function_ref<void(Value, StringRef)> setNameFn) {
+  setNameFn(getResult(), "gather"); // Printed SSA name hint.
+}
+
//===----------------------------------------------------------------------===//
// mesh.recv op
//===----------------------------------------------------------------------===//
@@ -747,6 +799,10 @@ void RecvOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
patterns.add<EmptyMeshAxesCanonicalizationPattern<RecvOp>>(context);
}
+void RecvOp::getAsmResultNames(function_ref<void(Value, StringRef)> setNameFn) {
+  setNameFn(getResult(), "recv"); // Printed SSA name hint.
+}
+
//===----------------------------------------------------------------------===//
// mesh.reduce op
//===----------------------------------------------------------------------===//
@@ -770,6 +826,11 @@ void ReduceOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
patterns.add<EmptyMeshAxesCanonicalizationPattern<ReduceOp>>(context);
}
+void ReduceOp::getAsmResultNames(
+    function_ref<void(Value, StringRef)> setNameFn) {
+  setNameFn(getResult(), "reduce"); // Printed SSA name hint.
+}
+
//===----------------------------------------------------------------------===//
// mesh.reduce_scatter op
//===----------------------------------------------------------------------===//
@@ -791,6 +852,11 @@ void ReduceScatterOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
patterns.add<EmptyMeshAxesCanonicalizationPattern<ReduceScatterOp>>(context);
}
+void ReduceScatterOp::getAsmResultNames(
+    function_ref<void(Value, StringRef)> setNameFn) {
+  setNameFn(getResult(), "reduce_scatter"); // Printed SSA name hint.
+}
+
//===----------------------------------------------------------------------===//
// mesh.scatter op
//===----------------------------------------------------------------------===//
@@ -817,6 +883,11 @@ void ScatterOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
patterns.add<EmptyMeshAxesCanonicalizationPattern<ScatterOp>>(context);
}
+void ScatterOp::getAsmResultNames(
+    function_ref<void(Value, StringRef)> setNameFn) {
+  setNameFn(getResult(), "scatter"); // Printed SSA name hint.
+}
+
//===----------------------------------------------------------------------===//
// mesh.send op
//===----------------------------------------------------------------------===//
@@ -839,6 +910,10 @@ void SendOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
patterns.add<EmptyMeshAxesCanonicalizationPattern<SendOp>>(context);
}
+void SendOp::getAsmResultNames(function_ref<void(Value, StringRef)> setNameFn) {
+  setNameFn(getResult(), "send"); // Printed SSA name hint.
+}
+
//===----------------------------------------------------------------------===//
// mesh.shift op
//===----------------------------------------------------------------------===//
@@ -865,6 +940,11 @@ void ShiftOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
// offset % shift_axis_mesh_dim_size == 0.
}
+void ShiftOp::getAsmResultNames(
+    function_ref<void(Value, StringRef)> setNameFn) {
+  setNameFn(getResult(), "shift"); // Printed SSA name hint.
+}
+
//===----------------------------------------------------------------------===//
// TableGen'd op method definitions
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Mesh/process-multi-index-op-lowering.mlir b/mlir/test/Dialect/Mesh/process-multi-index-op-lowering.mlir
index 677a5982ea2540..e23cfd79a42745 100644
--- a/mlir/test/Dialect/Mesh/process-multi-index-op-lowering.mlir
+++ b/mlir/test/Dialect/Mesh/process-multi-index-op-lowering.mlir
@@ -6,7 +6,7 @@ mesh.mesh @mesh2d(shape = ?x?)
func.func @multi_index_2d_mesh() -> (index, index) {
// CHECK: %[[LINEAR_IDX:.*]] = mesh.process_linear_index on @mesh2d : index
// CHECK: %[[MESH_SHAPE:.*]]:2 = mesh.mesh_shape @mesh2d : index, index
- // CHECK: %[[MULTI_IDX:.*]]:2 = affine.delinearize_index %0 into (%[[MESH_SHAPE]]#0, %[[MESH_SHAPE]]#1) : index, index
+ // CHECK: %[[MULTI_IDX:.*]]:2 = affine.delinearize_index %[[LINEAR_IDX]] into (%[[MESH_SHAPE]]#0, %[[MESH_SHAPE]]#1) : index, index
%0:2 = mesh.process_multi_index on @mesh2d : index, index
// CHECK: return %[[MULTI_IDX]]#0, %[[MULTI_IDX]]#1 : index, index
return %0#0, %0#1 : index, index
@@ -16,7 +16,7 @@ func.func @multi_index_2d_mesh() -> (index, index) {
func.func @multi_index_2d_mesh_single_inner_axis() -> index {
// CHECK: %[[LINEAR_IDX:.*]] = mesh.process_linear_index on @mesh2d : index
// CHECK: %[[MESH_SHAPE:.*]]:2 = mesh.mesh_shape @mesh2d : index, index
- // CHECK: %[[MULTI_IDX:.*]]:2 = affine.delinearize_index %0 into (%[[MESH_SHAPE]]#0, %[[MESH_SHAPE]]#1) : index, index
+ // CHECK: %[[MULTI_IDX:.*]]:2 = affine.delinearize_index %[[LINEAR_IDX]] into (%[[MESH_SHAPE]]#0, %[[MESH_SHAPE]]#1) : index, index
%0 = mesh.process_multi_index on @mesh2d axes = [0] : index
// CHECK: return %[[MULTI_IDX]]#0 : index
return %0 : index
More information about the Mlir-commits mailing list