[Mlir-commits] [mlir] mlir::mesh::shardingOp adding shard-size control (PR #98145)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Tue Jul 9 06:55:41 PDT 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir-linalg

Author: Frank Schlimbach (fschlimb)

<details>
<summary>Changes</summary>

- Replacing `#mesh.sharding` attribute with operation `mesh.sharding`
  - extended semantics now allow providing optional `halo_sizes` and `sharded_dims_sizes`
  - internally a sharding is represented as a non-IR class `mesh::MeshSharding`

What previously was
```mlir
%sharded0 = mesh.shard %arg0 <@mesh0, [[0]]> : tensor<4x8xf32>
%sharded1 = mesh.shard %arg1 <@mesh0, [[0]]> annotate_for_users : tensor<16x8xf32>
```
is now
```mlir
%sharding = mesh.sharding @mesh0, [[0]] : !mesh.sharding
%0 = mesh.shard %arg0 to %sharding : tensor<4x8xf32>
%1 = mesh.shard %arg1 to %sharding annotate_for_users : tensor<16x8xf32>
```
and allows additional annotations to control the shard sizes:
```mlir
mesh.mesh @mesh1d_4(shape = 4)
%sharding0 = mesh.sharding @mesh1d_4, [[0]] halo_sizes = [1, 2] : !mesh.sharding
%0 = mesh.shard %arg0 to %sharding0 : tensor<4x8xf32>
%sharding1 = mesh.sharding @mesh1d_4, [[0]] sharded_dims_sizes = [3, 5, 5, 3] : !mesh.sharding
%1 = mesh.shard %arg1 to %sharding1 annotate_for_users : tensor<16x8xf32>
```
- `mesh.shard` op accepts additional optional attribute `force`, useful for halo updates
- Some initial spmdization support for the new semantics
- Support for `tensor.empty` reacting on `sharded_dims_sizes` and `halo_sizes` in the sharding
- New collective operation `mesh.update_halo` as a spmdized target for shardings with `halo_sizes`

@<!-- -->sogartar @<!-- -->yaochengji

---

Patch is 207.06 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/98145.diff


28 Files Affected:

- (modified) mlir/include/mlir/Dialect/Mesh/IR/CMakeLists.txt (+4) 
- (modified) mlir/include/mlir/Dialect/Mesh/IR/MeshBase.td (+22-90) 
- (modified) mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h (+68-11) 
- (modified) mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td (+274-73) 
- (modified) mlir/include/mlir/Dialect/Mesh/Interfaces/ShardingInterface.h (+14-15) 
- (modified) mlir/include/mlir/Dialect/Mesh/Interfaces/ShardingInterface.td (+6-6) 
- (modified) mlir/include/mlir/Dialect/Mesh/Interfaces/ShardingInterfaceImpl.h (+13-11) 
- (added) mlir/include/mlir/Dialect/Tensor/IR/ShardingInterfaceImpl.h (+23) 
- (modified) mlir/include/mlir/InitAllDialects.h (+2) 
- (modified) mlir/include/mlir/Interfaces/InferTypeOpInterface.h (+2-2) 
- (modified) mlir/lib/Dialect/Linalg/Transforms/MeshShardingInterfaceImpl.cpp (+15-16) 
- (modified) mlir/lib/Dialect/Mesh/IR/MeshOps.cpp (+317-65) 
- (modified) mlir/lib/Dialect/Mesh/Interfaces/ShardingInterface.cpp (+73-57) 
- (modified) mlir/lib/Dialect/Mesh/Transforms/ShardingPropagation.cpp (+32-34) 
- (modified) mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp (+159-90) 
- (modified) mlir/lib/Dialect/Tensor/IR/CMakeLists.txt (+1) 
- (added) mlir/lib/Dialect/Tensor/IR/ShardingInterfaceImpl.cpp (+103) 
- (modified) mlir/test/Dialect/Linalg/mesh-sharding-propagation.mlir (+16-8) 
- (modified) mlir/test/Dialect/Linalg/mesh-spmdization.mlir (+54-41) 
- (modified) mlir/test/Dialect/Mesh/canonicalization.mlir (+2-2) 
- (modified) mlir/test/Dialect/Mesh/invalid.mlir (+35-11) 
- (modified) mlir/test/Dialect/Mesh/ops.mlir (+96-28) 
- (modified) mlir/test/Dialect/Mesh/resharding-spmdization.mlir (+56-22) 
- (modified) mlir/test/Dialect/Mesh/sharding-propagation.mlir (+125-73) 
- (modified) mlir/test/Dialect/Mesh/simplifications.mlir (+8-8) 
- (modified) mlir/test/Dialect/Mesh/spmdization.mlir (+132-29) 
- (added) mlir/test/Dialect/Tensor/mesh-spmdization.mlir (+42) 
- (modified) mlir/test/lib/Dialect/Mesh/TestReshardingSpmdization.cpp (+10-6) 


``````````diff
diff --git a/mlir/include/mlir/Dialect/Mesh/IR/CMakeLists.txt b/mlir/include/mlir/Dialect/Mesh/IR/CMakeLists.txt
index 7ba966d8cab7c..f26c6285efd89 100644
--- a/mlir/include/mlir/Dialect/Mesh/IR/CMakeLists.txt
+++ b/mlir/include/mlir/Dialect/Mesh/IR/CMakeLists.txt
@@ -13,6 +13,10 @@ set(LLVM_TARGET_DEFINITIONS MeshBase.td)
 mlir_tablegen(MeshEnums.h.inc -gen-enum-decls)
 mlir_tablegen(MeshEnums.cpp.inc -gen-enum-defs)
 
+set(LLVM_TARGET_DEFINITIONS MeshBase.td)
+mlir_tablegen(MeshTypes.h.inc -gen-typedef-decls)
+mlir_tablegen(MeshTypes.cpp.inc -gen-typedef-defs)
+
 set(LLVM_TARGET_DEFINITIONS MeshOps.td)
 mlir_tablegen(MeshOps.h.inc -gen-op-decls)
 mlir_tablegen(MeshOps.cpp.inc -gen-op-defs)
diff --git a/mlir/include/mlir/Dialect/Mesh/IR/MeshBase.td b/mlir/include/mlir/Dialect/Mesh/IR/MeshBase.td
index 3a85bf2d552f3..61403ac178980 100644
--- a/mlir/include/mlir/Dialect/Mesh/IR/MeshBase.td
+++ b/mlir/include/mlir/Dialect/Mesh/IR/MeshBase.td
@@ -12,6 +12,7 @@
 include "mlir/IR/OpBase.td"
 include "mlir/IR/AttrTypeBase.td"
 include "mlir/IR/BuiltinTypeInterfaces.td"
+include "mlir/IR/CommonAttrConstraints.td"
 include "mlir/IR/EnumAttr.td"
 
 //===----------------------------------------------------------------------===//
@@ -31,11 +32,13 @@ def Mesh_Dialect : Dialect {
   ];
 
   let useDefaultAttributePrinterParser = 1;
+  let useDefaultTypePrinterParser = 1;
   let hasConstantMaterializer = 1;
 }
 
 def Mesh_MeshAxis : I<16>;
 def Mesh_MeshAxesAttr : DenseArrayAttrBase<"DenseI16ArrayAttr", "int16_t", "i16">;
+def Mesh_ShardShapeAttr : DenseArrayAttrBase<"DenseI64ArrayAttr", "int64_t", "i64">;
 
 //===----------------------------------------------------------------------===//
 // Mesh Enums.
@@ -59,104 +62,33 @@ def Mesh_ReductionKind : I32EnumAttr<"ReductionKind",
 }
 
 def Mesh_ReductionKindAttr : EnumAttr<Mesh_Dialect, Mesh_ReductionKind, "partial"> {
-  let assemblyFormat = "`<` $value `>`";
+  let assemblyFormat = "$value";
+}
+
+class Mesh_Type<string name, string typeMnemonic, list<Trait> traits = [],
+                   string baseCppClass = "::mlir::Type">
+    : TypeDef<Mesh_Dialect, name, traits, baseCppClass> {
+  let mnemonic = typeMnemonic;
+}
+
+def Mesh_Sharding : Mesh_Type<"Sharding", "sharding"> {
+  let summary = "sharding definition";
+  let assemblyFormat = "";
 }
 
 //===----------------------------------------------------------------------===//
 // Mesh Attribute
 //===----------------------------------------------------------------------===//
 
-def MeshSharding : AttrDef<Mesh_Dialect, "MeshSharding"> {
-  let mnemonic = "shard";
-
-  let parameters = (ins
-    AttrParameter<"::mlir::FlatSymbolRefAttr",
-     "The mesh on which tensors are sharded.">:$mesh,
-    ArrayRefParameter<"MeshAxesAttr">:$split_axes,
-    OptionalArrayRefParameter<"MeshAxis">:$partial_axes,
-    OptionalParameter<"::mlir::mesh::ReductionKind">:$partial_type
-  );
-
-  let summary = "Attribute that extends tensor type to distributed tensor type.";
-
-  let description = [{
-    The MeshSharding attribute is used in a `mesh.shard` operation.
-    It specifies how a tensor is sharded and distributed across the process
-    mesh.
-
-    1. `mesh`: this attribute is a FlatSymbolRefAttr that refers to the device
-    mesh where the distributed tensor is placed. The symbol must resolve to a
-    `mesh.mesh` operation.
-
-    2. `split_axes`: is an array composed of int64_t sub-arrays. The outer array's
-    maximum size is the `rank` of the related tensor. For the i-th sub-array, if
-    its value is [x, y], it indicates that the tensor's i-th dimension is splitted
-    along the x and y axes of the device mesh.
-
-    3. `partial_axes`: if not empty, this signifies that the tensor is partial
-    one along the specified mesh axes. An all-reduce should be applied to obtain
-    the complete tensor, with reduction type being specified by `partial_type`.
-
-    4. `partial_type`: indicates the reduction type of the possible all-reduce
-    op. It has 4 possible values:
-    `generic`: is not an allowed value inside a shard attribute.
-
-    Example:
-
-    ```
-    mesh.mesh @mesh0(shape = 2x2x4)
-
-    // The tensor is fully replicated on @mesh0.
-    // Currently, there must be at least one sub-array present in axes, even
-    // if it's empty. Otherwise, a parsing error will occur.
-    #mesh.shard<@mesh0, [[]]>
-
-    // The tensor is sharded on the first dimension along axis 0 of @mesh0
-    #mesh.shard<@mesh0, [[0]]>
-
-    // The tensor is sharded on the first dimension along axis 0 of @mesh0 and
-    // it is also a partial_sum along mesh axis 1.
-    #mesh.shard<@mesh0, [[0], []], partial = sum[1]>
-
-    // The tensor is sharded on the first dimension along axis 0 of @mesh0 and
-    // it is also a partial_max along mesh axis 1.
-    #mesh.shard<@mesh0, [[0]], partial = max[1]>
-
-    // Could be used in the attribute of mesh.shard op
-    %0 = mesh.shard %arg0 to <@mesh0, [[0]]> : tensor<4x8xf32>
-    ```
-  }];
-  let assemblyFormat = [{
-    `<` $mesh `,` `[` $split_axes `]` (`,` `partial` `=` $partial_type `[`
-       $partial_axes^ `]`)? `>`
-  }];
-
-  let builders = [
-    AttrBuilder<(ins "FlatSymbolRefAttr":$mesh,
-                     "ArrayRef<SmallVector<MeshAxis>>":$split_axes,
-                     "ArrayRef<MeshAxis>": $partial_axes,
-                     "mesh::ReductionKind": $partial_type), [{
-      SmallVector<MeshAxesAttr> splitAxesAttr = llvm::map_to_vector(
-                  split_axes, [&](ArrayRef<MeshAxis> array) {
-          return MeshAxesAttr::get($_ctxt, array);
-      });
-      return $_get($_ctxt, mesh, splitAxesAttr, partial_axes,
-                   partial_type);
-    }]>,
-    AttrBuilder<(ins "FlatSymbolRefAttr":$mesh,
-                     "ArrayRef<SmallVector<MeshAxis>>":$split_axes), [{
-      return MeshShardingAttr::get($_ctxt, mesh, split_axes, {}, ReductionKind::Sum);
-    }]>
-  ];
-
+def Mesh_MeshAxesArrayAttr : AttrDef<Mesh_Dialect, "MeshAxesArray"> {
+  let mnemonic = "axisarray";
+  let parameters = (ins ArrayRefParameter<"MeshAxesAttr">:$axes);
+  let assemblyFormat = "`[` $axes `]`";
   let extraClassDeclaration = [{
-    bool operator==(::mlir::Attribute rhs) const;
-    bool operator!=(::mlir::Attribute rhs) const;
-    bool operator==(::mlir::mesh::MeshShardingAttr rhs) const;
-    bool operator!=(::mlir::mesh::MeshShardingAttr rhs) const;
+    size_t size() const { return getAxes().size(); }
+    auto begin() const { return getAxes().begin(); }
+    auto end() const { return getAxes().end(); }
   }];
-
-  let genVerifyDecl = 1;
 }
 
 #endif // MLIR_DIALECT_MESH_IR_MESHBASE_TD
diff --git a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h
index b27c9e81b3293..3c467d6f95948 100644
--- a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h
+++ b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h
@@ -24,6 +24,8 @@ namespace mesh {
 
 using MeshAxis = int16_t;
 using MeshAxesAttr = DenseI16ArrayAttr;
+using ShardShapeAttr = DenseI64ArrayAttr;
+using HaloSizePairAttr = DenseI64ArrayAttr;
 
 } // namespace mesh
 } // namespace mlir
@@ -33,6 +35,59 @@ using MeshAxesAttr = DenseI16ArrayAttr;
 #define GET_ATTRDEF_CLASSES
 #include "mlir/Dialect/Mesh/IR/MeshAttributes.h.inc"
 
+namespace mlir {
+namespace mesh {
+
+class MeshSharding {
+private:
+  ::mlir::FlatSymbolRefAttr mesh;
+  SmallVector<MeshAxesAttr> split_axes;
+  SmallVector<MeshAxis> partial_axes;
+  ReductionKind partial_type;
+  SmallVector<int64_t> static_halo_sizes;
+  SmallVector<int64_t> static_sharded_dims_sizes;
+  SmallVector<Value> dynamic_halo_sizes;
+  SmallVector<Value> dynamic_sharded_dims_sizes;
+
+public:
+  MeshSharding() = default;
+  MeshSharding(Value rhs);
+  static MeshSharding get(::mlir::FlatSymbolRefAttr mesh_,
+                          ArrayRef<MeshAxesAttr> split_axes_,
+                          ArrayRef<MeshAxis> partial_axes_ = {},
+                          ReductionKind partial_type_ = ReductionKind::Sum,
+                          ArrayRef<int64_t> static_halo_sizes_ = {},
+                          ArrayRef<int64_t> static_sharded_dims_sizes_ = {},
+                          ArrayRef<Value> dynamic_halo_sizes_ = {},
+                          ArrayRef<Value> dynamic_sharded_dims_sizes_ = {});
+  ::mlir::FlatSymbolRefAttr getMeshAttr() const { return mesh; }
+  ::llvm::StringRef getMesh() const { return mesh.getValue(); }
+  ArrayRef<MeshAxesAttr> getSplitAxes() const { return split_axes; }
+  ArrayRef<MeshAxis> getPartialAxes() const { return partial_axes; }
+  ReductionKind getPartialType() const { return partial_type; }
+  ArrayRef<int64_t> getStaticHaloSizes() const { return static_halo_sizes; }
+  ArrayRef<int64_t> getStaticShardedDimsSizes() const {
+    return static_sharded_dims_sizes;
+  }
+  ArrayRef<Value> getDynamicHaloSizes() const { return dynamic_halo_sizes; }
+  ArrayRef<Value> getDynamicShardedDimsSizes() const {
+    return dynamic_sharded_dims_sizes;
+  }
+  operator bool() const { return (!mesh) == false; }
+  bool operator==(Value rhs) const;
+  bool operator!=(Value rhs) const;
+  bool operator==(const MeshSharding &rhs) const;
+  bool operator!=(const MeshSharding &rhs) const;
+  bool sameExceptConstraint(const MeshSharding &rhs) const;
+  bool sameConstraint(const MeshSharding &rhs) const;
+};
+
+} // namespace mesh
+} // namespace mlir
+
+#define GET_TYPEDEF_CLASSES
+#include "mlir/Dialect/Mesh/IR/MeshTypes.h.inc"
+
 #define GET_OP_CLASSES
 #include "mlir/Dialect/Mesh/IR/MeshOps.h.inc"
 
@@ -50,9 +105,9 @@ void removeTrailingEmptySubArray(SmallVector<SmallVector<T>> &array) {
 }
 
 // Is the same tensor replicated on all processes.
-inline bool isFullReplication(MeshShardingAttr attr) {
-  return attr.getPartialAxes().empty() &&
-         llvm::all_of(attr.getSplitAxes(), [](MeshAxesAttr axes) {
+inline bool isFullReplication(MeshSharding sharding) {
+  return sharding.getPartialAxes().empty() &&
+         llvm::all_of(sharding.getSplitAxes(), [](MeshAxesAttr axes) {
            return axes.asArrayRef().empty();
          });
 }
@@ -80,8 +135,10 @@ mesh::MeshOp getMesh(Op op, SymbolTableCollection &symbolTableCollection) {
 template <>
 inline mesh::MeshOp
 getMesh<ShardOp>(ShardOp op, SymbolTableCollection &symbolTableCollection) {
-  return getMesh(op.getOperation(), op.getShardAttr().getMesh(),
-                 symbolTableCollection);
+  return getMesh(
+      op.getOperation(),
+      cast<ShardingOp>(op.getSharding().getDefiningOp()).getMeshAttr(),
+      symbolTableCollection);
 }
 
 // Get the number of processes that participate in each group
@@ -131,22 +188,22 @@ inline int64_t gatherDimension(int64_t dimSize, int64_t shardCount) {
 // On a 2x4x? mesh with split axes = [[0], [1], [2]] the shape ?x5x1 would
 // result in a shape for each shard of ?x2x?.
 ShapedType shardShapedType(ShapedType shape, MeshOp mesh,
-                           MeshShardingAttr sharding);
+                           MeshSharding sharding);
 
 // If ranked tensor type return its sharded counterpart.
 //
 // If not ranked tensor type return `type`.
 // `sharding` in that case must be null.
-Type shardType(Type type, MeshOp mesh, MeshShardingAttr sharding);
+Type shardType(Type type, MeshOp mesh, MeshSharding sharding);
 
 // Insert shard op if there is not one that already has the same sharding.
 // May insert resharding if required.
-void maybeInsertTargetShardingAnnotation(MeshShardingAttr sharding,
+void maybeInsertTargetShardingAnnotation(MeshSharding sharding,
                                          OpOperand &operand,
                                          OpBuilder &builder);
-void maybeInsertTargetShardingAnnotation(MeshShardingAttr sharding,
-                                         OpResult result, OpBuilder &builder);
-void maybeInsertSourceShardingAnnotation(MeshShardingAttr sharding,
+void maybeInsertTargetShardingAnnotation(MeshSharding sharding, OpResult result,
+                                         OpBuilder &builder);
+void maybeInsertSourceShardingAnnotation(MeshSharding sharding,
                                          OpOperand &operand,
                                          OpBuilder &builder);
 
diff --git a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td
index 8e1e475463585..49c4037942f6f 100644
--- a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td
+++ b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td
@@ -20,7 +20,7 @@ include "mlir/IR/OpAsmInterface.td"
 include "mlir/IR/SymbolInterfaces.td"
 
 //===----------------------------------------------------------------------===//
-// Mesh Dialect operations.
+// Mesh operations.
 //===----------------------------------------------------------------------===//
 
 class Mesh_Op<string mnemonic, list<Trait> traits = []> :
@@ -105,22 +105,221 @@ def Mesh_MeshShapeOp : Mesh_Op<"mesh_shape", [
   ];
 }
 
+def Mesh_ProcessMultiIndexOp : Mesh_Op<"process_multi_index", [
+  Pure,
+  DeclareOpInterfaceMethods<SymbolUserOpInterface>,
+  DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>
+]> {
+  let summary = "Get the multi index of current device along specified mesh axes.";
+  let description = [{
+    It is used in the SPMD format of IR.
+    The `axes` must be non-negative and less than the total number of mesh axes.
+    If the axes are empty then get the index along all axes.
+  }];
+  let arguments = (ins
+    FlatSymbolRefAttr:$mesh,
+    DefaultValuedAttr<Mesh_MeshAxesAttr, "{}">:$axes
+  );
+  let results = (outs
+    Variadic<Index>:$result
+  );
+  let assemblyFormat = [{
+    `on` $mesh (`axes` `=` $axes^)?
+    attr-dict `:` type($result)
+  }];
+  let builders = [
+    OpBuilder<(ins "::mlir::mesh::MeshOp":$mesh)>,
+    OpBuilder<(ins "StringRef":$mesh, "ArrayRef<MeshAxis>":$axes)>
+  ];
+}
+
+def Mesh_ProcessLinearIndexOp : Mesh_Op<"process_linear_index", [
+  Pure,
+  DeclareOpInterfaceMethods<SymbolUserOpInterface>,
+  DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>
+]> {
+  let summary = "Get the linear index of the current device.";
+  let description = [{
+    Example:
+    ```
+    %idx = mesh.process_linear_index on @mesh : index
+    ```
+    if `@mesh` has shape `(10, 20, 30)`, a device with multi
+    index `(1, 2, 3)` will have linear index `3 + 30*2 + 20*30*1`.
+  }];
+  let arguments = (ins FlatSymbolRefAttr:$mesh);
+  let results = (outs Index:$result);
+  let assemblyFormat = "`on` $mesh attr-dict `:` type($result)";
+  let builders = [
+    OpBuilder<(ins "::mlir::mesh::MeshOp":$mesh)>
+  ];
+}
+
+//===----------------------------------------------------------------------===//
+// Sharding operations.
+//===----------------------------------------------------------------------===//
+
+def Mesh_ShardingOp : Mesh_Op<"sharding", [
+    Pure,
+    AttrSizedOperandSegments,
+    DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>
+  ]> {
+  let summary = "Define a sharding of a tensor.";
+  let description = [{
+    The MeshSharding specifies how a tensor is sharded and distributed across the
+    process mesh. It is typically used in a `mesh.shard` operation.
+    The operation has the following attributes and operands:
+
+    1. `mesh`: this attribute is a FlatSymbolRefAttr that refers to the device
+    mesh where the distributed tensor is placed. The symbol must resolve to a
+    `mesh.mesh` operation.
+
+    2. `split_axes`: is an array composed of int64_t sub-arrays. The outer array's
+    maximum size is the `rank` of the related tensor. For the i-th sub-array, if
+    its value is [x, y], it indicates that the tensor's i-th dimension is split
+    along the x and y axes of the device mesh.
+
+    3. [Optional] `partial_axes`: if not empty, this signifies that the tensor is partial
+    one along the specified mesh axes. An all-reduce should be applied to obtain
+    the complete tensor, with reduction type being specified by `partial_type`.
+
+    4. [Optional] `partial_type`: indicates the reduction type of the possible all-reduce
+    op. It has 4 possible values:
+    `generic`: is not an allowed value inside a shard attribute.
+
+    5. [Optional] Sizes of halos to be added for each sharded tensor dimension.
+    `halo_sizes` is provided as a flattened 1d array of i64s, 2 values for each sharded dimension.
+    `halo_sizes` = [1, 2] means that the first sharded dimension gets an additional
+    halo of size 1 at the start of the first dimension and a halo of size 2 at its end.
+    `halo_sizes` = [1, 2, 2, 3] defines halos for the first 2 sharded dimensions,
+    e.g. the first sharded dimension gets [1,2] halos and the second gets [2,3] halos.
+    `?` indicates dynamic halo sizes.
+    
+    6. [Optional] Sizes of sharded dimensions of each shard.
+    `sharded_dims_sizes` is provided as a flattened 1d array of i64s: for each device of the
+    device-mesh one value for each sharded tensor dimension.
+    Assuming a 3d-tensor of shape 32x32x32 with the first 2 dimensions being sharded,
+    `sharded_dims_sizes` = [16, 8, 16, 24] means that the first device of
+    the device-mesh will get a shard of shape 16x8x32 and the second device will get a
+    shard of shape 16x24x32.
+    `?` indicates dynamic shard dimensions.
+    
+    `halo_sizes` and `sharded_dims_sizes` are mutually exclusive.
+
+    Examples:
+
+    ```
+    mesh.mesh @mesh0(shape = 2x2x4)
+    mesh.mesh @mesh1d_4(shape = 4)
+
+    // The tensor is fully replicated on @mesh0.
+    // Currently, there must be at least one sub-array present in axes, even
+    // if it's empty. Otherwise, a parsing error will occur.
+    %sharding0 = mesh.sharding @mesh0, [[]]
+
+    // The tensor is sharded on the first dimension along axis 0 of @mesh0
+    %sharding1 = mesh.sharding @mesh0, [[0]]
+
+    // The tensor is sharded on its first dimension along axis 0 of @mesh0 and
+    // it is also a partial_sum along mesh axis 1.
+    %sharding2 = mesh.sharding @mesh0, [[0], []] partial = sum[1]
+
+    // The tensor is sharded on its first dimension along axis 0 of @mesh0 and
+    // it is also a partial_max along mesh axis 1.
+    %sharding3 = mesh.sharding @mesh0, [[0]] partial = max[1]
+
+    // Could be used for a mesh.shard op
+    %sharded0 = mesh.shard %arg0 to %sharding3 : tensor<4x8xf32>
+
+    // The tensor is sharded on its first dimension along axis 0 of @mesh0 and
+    // it has halo-sizes of 1 and 2 on the sharded dim.
+    %halo_sharding = mesh.sharding @mesh0, [[0]] halo_sizes = [1, 2]
+    %sharded1 = mesh.shard %arg0 to %halo_sharding : tensor<4x8xf32>
+    
+    // The tensor is sharded on its second dimension along axis 0 of @mesh1d_4
+    // and it has pre-defined shard sizes. The shards of the devices will have
+    // the following shapes: [4x2, 4x3, 4x4, 4x5]
+    %sharding4 = mesh.sharding @mesh1d_4, [[], [0]] sharded_dims_sizes = [2, 3, 4, 5]
+    %sharded2 = mesh.shard %arg0 to %sharding4 : tensor<4x14xf32>
+    ```
+  }];
+    
+  let arguments = (ins
+    FlatSymbolRefAttr:$mesh,
+    Mesh_MeshAxesArrayAttr:$split_axes,
+    OptionalAttr<Mesh_MeshAxesAttr>:$partial_axes,
+    OptionalAttr<Mesh_ReductionKindAttr>:$partial_type,
+    DefaultValuedAttr<DenseI64ArrayAttr, "{}">:$static_sharded_dims_sizes,
+    Variadic<I64>:$dynamic_sharded_dims_sizes,
+    DefaultValuedAttr<DenseI64ArrayAttr, "{}">:$static_halo_sizes,
+    Variadic<I64>:$dynamic_halo_sizes
+  );
+  let results = (outs
+    Mesh_Sharding:$result
+  );
+  let assemblyFormat = [{
+    $mesh `,` $split_axes
+    (`partial` `=` $partial_type $partial_axes^)?
+    (`halo_sizes` `=` custom<DynamicIndexList>($dynamic_halo_sizes, $static_halo_sizes)^)?
+    (`sharded_dims_sizes` `=` custom<DynamicIndexList>($dynamic_sharded_dims_sizes, $static_sharded_dims_sizes)^)?
+    attr-dict `:` type($result)
+  }];
+  let builders = [
+    OpBuilder<(ins "FlatSymbolRefAttr":$mesh,
+                   "ArrayRef<MeshAxesAttr>":$split_axes,
+                   "ArrayRef<MeshAxis>":$partial_axes,
+                   "mesh::ReductionKind":$partial_type,
+                   CArg<"ArrayRef<int64_t>", "{}">:$static_halo_sizes,
+                   CArg<"ArrayRef<int64_t>", "{}">:$static_sharded_dims_sizes)>,
+    OpBuilder<(ins "FlatSymbolRefAttr":$mesh,
+                   "ArrayRef<MeshAxesAttr>":$split_axes)>,
+    OpBuilder<(ins "FlatSymbolRefAttr":$mesh,
+                   "ArrayRef<MeshAxesAttr...
[truncated]

``````````

</details>


https://github.com/llvm/llvm-project/pull/98145


More information about the Mlir-commits mailing list