[Mlir-commits] [mlir] [mlir][mesh] In sharding attr use FlatSymbolRefAttr instead of SymbolRefAttr (PR #76886)

Boian Petkantchin llvmlistbot at llvm.org
Wed Jan 3 17:45:35 PST 2024


https://github.com/sogartar created https://github.com/llvm/llvm-project/pull/76886

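Analogous to `func.call`, use `FlatSymbolRefAttr` to reference the corresponding mesh. The patch also relaxes the `mesh.shard` operand and result type constraints from `Builtin_RankedTensor` to `AnyRankedTensor`.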

>From 371af44dd4721e3a09126e8459827555ea14006e Mon Sep 17 00:00:00 2001
From: Boian Petkantchin <boian.petkantchin at amd.com>
Date: Wed, 3 Jan 2024 17:36:27 -0800
Subject: [PATCH] [mlir][mesh] In sharding attr use FlatSymbolRefAttr instead
 of SymbolRefAttr

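Analogous to func.call, use FlatSymbolRefAttr to reference the
corresponding mesh. Also relax the mesh.shard operand and result
type constraints from Builtin_RankedTensor to AnyRankedTensor.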
---
 mlir/include/mlir/Dialect/Mesh/IR/MeshBase.td             | 8 ++++----
 mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td              | 4 ++--
 .../mlir/Dialect/Mesh/Interfaces/ShardingInterface.h      | 4 ++--
 mlir/lib/Dialect/Mesh/IR/MeshOps.cpp                      | 2 +-
 mlir/lib/Dialect/Mesh/Interfaces/ShardingInterface.cpp    | 2 +-
 5 files changed, 10 insertions(+), 10 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Mesh/IR/MeshBase.td b/mlir/include/mlir/Dialect/Mesh/IR/MeshBase.td
index 060d54b82efa63..bda6467e9c5d4b 100644
--- a/mlir/include/mlir/Dialect/Mesh/IR/MeshBase.td
+++ b/mlir/include/mlir/Dialect/Mesh/IR/MeshBase.td
@@ -79,7 +79,7 @@ def MeshSharding : AttrDef<Mesh_Dialect, "MeshSharding"> {
   let mnemonic = "shard";
 
   let parameters = (ins
-    AttrParameter<"::mlir::SymbolRefAttr", "cluster placed">:$cluster,
+    AttrParameter<"::mlir::FlatSymbolRefAttr", "cluster placed">:$cluster,
     ArrayRefParameter<"MeshAxesAttr">:$split_axes,
     OptionalArrayRefParameter<"MeshAxis">:$partial_axes,
     OptionalParameter<"::mlir::mesh::Partial">:$partial_type
@@ -91,7 +91,7 @@ def MeshSharding : AttrDef<Mesh_Dialect, "MeshSharding"> {
     The MeshSharding attribute could be used in the encoding of a
     `RankedTensorType` or the mesh.shard op. it contains three sub-attributes:
 
-    1. `cluster`: this attribute is a SymbolRefAttr that refers to the mesh
+    1. `cluster`: this attribute is a FlatSymbolRefAttr that refers to the mesh
     cluster where the distributed tensor is placed. The symbol must resolve to a
     `mesh.cluster` operation.
 
@@ -145,7 +145,7 @@ def MeshSharding : AttrDef<Mesh_Dialect, "MeshSharding"> {
   }];
 
   let builders = [
-    AttrBuilder<(ins "SymbolRefAttr":$cluster, 
+    AttrBuilder<(ins "FlatSymbolRefAttr":$cluster,
                      "ArrayRef<SmallVector<MeshAxis>>":$split_axes,
                      "ArrayRef<MeshAxis>": $partial_axes,
                      "mesh::Partial": $partial_type), [{
@@ -156,7 +156,7 @@ def MeshSharding : AttrDef<Mesh_Dialect, "MeshSharding"> {
       return $_get($_ctxt, cluster, splitAxesAttr, partial_axes,
                    partial_type);
     }]>,
-    AttrBuilder<(ins "SymbolRefAttr":$cluster, 
+    AttrBuilder<(ins "FlatSymbolRefAttr":$cluster,
                      "ArrayRef<SmallVector<MeshAxis>>":$split_axes), [{
       return MeshShardingAttr::get($_ctxt, cluster, split_axes, {}, Partial::Sum);
     }]>
diff --git a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td
index 1934bdfb427059..f459077ea12022 100644
--- a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td
+++ b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td
@@ -196,12 +196,12 @@ def Mesh_ShardOp : Mesh_Op<"shard", [Pure, SameOperandsAndResultType]> {
     ```
   }];
   let arguments = (ins
-    Builtin_RankedTensor:$src,
+    AnyRankedTensor:$src,
     MeshSharding:$shard,
     UnitAttr:$annotate_for_users
   );
   let results = (outs
-    Builtin_RankedTensor:$result
+    AnyRankedTensor:$result
   );
   let assemblyFormat = [{
     $src `to` $shard (`annotate_for_users` $annotate_for_users^)? attr-dict `:`
diff --git a/mlir/include/mlir/Dialect/Mesh/Interfaces/ShardingInterface.h b/mlir/include/mlir/Dialect/Mesh/Interfaces/ShardingInterface.h
index 201c0151754eba..a32274d857f15d 100644
--- a/mlir/include/mlir/Dialect/Mesh/Interfaces/ShardingInterface.h
+++ b/mlir/include/mlir/Dialect/Mesh/Interfaces/ShardingInterface.h
@@ -25,13 +25,13 @@ struct ShardingOption {
   // An array of int array. The sub-array at the i-th position signifies the
   // mesh axes the i-th loop will be sharded on.
   ShardingArray shardingArray = {};
-  SymbolRefAttr cluster = nullptr;
+  FlatSymbolRefAttr cluster = nullptr;
   // `empty` being true indicates that no sharding information can be inferred
   // at present. Note that it is different from the case where an operation is
   // not sharded.
   bool empty = false;
   ShardingOption() = default;
-  ShardingOption(ShardingArray shardingArray, SymbolRefAttr cluster)
+  ShardingOption(ShardingArray shardingArray, FlatSymbolRefAttr cluster)
       : shardingArray(std::move(shardingArray)), cluster(cluster) {}
 };
 
diff --git a/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp b/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
index c3d8f1d456106d..6667d409df8b78 100644
--- a/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
+++ b/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
@@ -266,7 +266,7 @@ void ClusterShapeOp::build(OpBuilder &odsBuilder, OperationState &odsState,
 
 LogicalResult
 MeshShardingAttr::verify(function_ref<InFlightDiagnostic()> emitError,
-                         SymbolRefAttr, ArrayRef<MeshAxesAttr> splitAxes,
+                         FlatSymbolRefAttr, ArrayRef<MeshAxesAttr> splitAxes,
                          ArrayRef<MeshAxis> partialAxes, Partial) {
   // TODO: At present cluster symbol ref is not verified. This is due to the
   // difficulty in fetching the corresponding symbol op based on an attribute.
diff --git a/mlir/lib/Dialect/Mesh/Interfaces/ShardingInterface.cpp b/mlir/lib/Dialect/Mesh/Interfaces/ShardingInterface.cpp
index ee885ab16b7b06..dca7e86e6f07f5 100644
--- a/mlir/lib/Dialect/Mesh/Interfaces/ShardingInterface.cpp
+++ b/mlir/lib/Dialect/Mesh/Interfaces/ShardingInterface.cpp
@@ -215,7 +215,7 @@ namespace {
 // Update the given `shardingOption` according to `meshAxes` and `loopIdx`
 static LogicalResult fillShardingOption(Operation *op,
                                         ShardingOption &shardingOption,
-                                        SymbolRefAttr cluster,
+                                        FlatSymbolRefAttr cluster,
                                         ArrayRef<MeshAxis> meshAxes,
                                         unsigned loopIdx) {
   if ((shardingOption.cluster && cluster &&



More information about the Mlir-commits mailing list