[Mlir-commits] [mlir] baabcb2 - [mlir][mesh] Shardingcontrol (#102598)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Mon Aug 12 04:21:01 PDT 2024
Author: Frank Schlimbach
Date: 2024-08-12T12:20:58+01:00
New Revision: baabcb28983edf8f20e39b89e2b1745412073b44
URL: https://github.com/llvm/llvm-project/commit/baabcb28983edf8f20e39b89e2b1745412073b44
DIFF: https://github.com/llvm/llvm-project/commit/baabcb28983edf8f20e39b89e2b1745412073b44.diff
LOG: [mlir][mesh] Shardingcontrol (#102598)
This is a fixed copy of #98145 (necessary after it got reverted).
@sogartar @yaochengji
This PR adds the following to #98145:
- `UpdateHaloOp` accepts a `memref` (instead of a tensor) and does not
return a result, to clarify its in-place semantics
- `UpdateHaloOp` accepts `split_axis` to allow multiple mesh-axes per
tensor/memref-axis (similar to `mesh.sharding`)
- The implementation of `ShardingInterface` for tensor operations
(`tensor.empty` for now) was moved from the tensor library to the mesh
interface library. `spmdize` uses features from the `mesh` dialect.
@rengolin agreed that `tensor` should not depend on `mesh`, so this
functionality cannot live in a `tensor` lib. The unfulfilled dependency
caused the issues leading to reverting #98145. Such cases are generally
possible and might lead to re-considering the current structure (like
for tosa ops).
- rebased onto latest main
--------------------------
Replacing `#mesh.sharding` attribute with operation `mesh.sharding`
- extended semantics now allow providing optional `halo_sizes` and
`sharded_dims_sizes`
- internally a sharding is represented as a non-IR class
`mesh::MeshSharding`
What previously was
```mlir
%sharded0 = mesh.shard %arg0 <@mesh0, [[0]]> : tensor<4x8xf32>
%sharded1 = mesh.shard %arg1 <@mesh0, [[0]]> annotate_for_users : tensor<16x8xf32>
```
is now
```mlir
%sharding = mesh.sharding @mesh0, [[0]] : !mesh.sharding
%0 = mesh.shard %arg0 to %sharding : tensor<4x8xf32>
%1 = mesh.shard %arg1 to %sharding annotate_for_users : tensor<16x8xf32>
```
and allows additional annotations to control the shard sizes:
```mlir
mesh.mesh @mesh0 (shape = 4)
%sharding0 = mesh.sharding @mesh0, [[0]] halo_sizes = [1, 2] : !mesh.sharding
%0 = mesh.shard %arg0 to %sharding0 : tensor<4x8xf32>
%sharding1 = mesh.sharding @mesh0, [[0]] sharded_dims_sizes = [3, 5, 5, 3] : !mesh.sharding
%1 = mesh.shard %arg1 to %sharding1 annotate_for_users : tensor<16x8xf32>
```
- `mesh.shard` op accepts additional optional attribute `force`, useful
for halo updates
- Some initial spmdization support for the new semantics
- Support for `tensor.empty` reacting on `sharded_dims_sizes` and
`halo_sizes` in the sharding
- New collective operation `mesh.update_halo` as a spmdized target for
shardings with `halo_sizes`
---------
Co-authored-by: frank.schlimbach <fschlimb at smtp.igk.intel.com>
Co-authored-by: Jie Fu <jiefu at tencent.com>
Added:
mlir/include/mlir/Dialect/Mesh/IR/TensorShardingInterfaceImpl.h
mlir/lib/Dialect/Mesh/Interfaces/TensorShardingInterfaceImpl.cpp
mlir/test/Dialect/Tensor/mesh-spmdization.mlir
Modified:
mlir/include/mlir/Dialect/Mesh/IR/CMakeLists.txt
mlir/include/mlir/Dialect/Mesh/IR/MeshBase.td
mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h
mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td
mlir/include/mlir/Dialect/Mesh/Interfaces/ShardingInterface.h
mlir/include/mlir/Dialect/Mesh/Interfaces/ShardingInterface.td
mlir/include/mlir/Dialect/Mesh/Interfaces/ShardingInterfaceImpl.h
mlir/include/mlir/InitAllDialects.h
mlir/include/mlir/Interfaces/InferTypeOpInterface.h
mlir/lib/Dialect/Linalg/Transforms/MeshShardingInterfaceImpl.cpp
mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
mlir/lib/Dialect/Mesh/Interfaces/CMakeLists.txt
mlir/lib/Dialect/Mesh/Interfaces/ShardingInterface.cpp
mlir/lib/Dialect/Mesh/Transforms/ShardingPropagation.cpp
mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp
mlir/test/Dialect/Linalg/mesh-sharding-propagation.mlir
mlir/test/Dialect/Linalg/mesh-spmdization.mlir
mlir/test/Dialect/Mesh/canonicalization.mlir
mlir/test/Dialect/Mesh/invalid.mlir
mlir/test/Dialect/Mesh/ops.mlir
mlir/test/Dialect/Mesh/resharding-spmdization.mlir
mlir/test/Dialect/Mesh/sharding-propagation.mlir
mlir/test/Dialect/Mesh/simplifications.mlir
mlir/test/Dialect/Mesh/spmdization.mlir
mlir/test/lib/Dialect/Mesh/TestReshardingSpmdization.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Mesh/IR/CMakeLists.txt b/mlir/include/mlir/Dialect/Mesh/IR/CMakeLists.txt
index 7ba966d8cab7c8..f26c6285efd896 100644
--- a/mlir/include/mlir/Dialect/Mesh/IR/CMakeLists.txt
+++ b/mlir/include/mlir/Dialect/Mesh/IR/CMakeLists.txt
@@ -13,6 +13,10 @@ set(LLVM_TARGET_DEFINITIONS MeshBase.td)
mlir_tablegen(MeshEnums.h.inc -gen-enum-decls)
mlir_tablegen(MeshEnums.cpp.inc -gen-enum-defs)
+set(LLVM_TARGET_DEFINITIONS MeshBase.td)
+mlir_tablegen(MeshTypes.h.inc -gen-typedef-decls)
+mlir_tablegen(MeshTypes.cpp.inc -gen-typedef-defs)
+
set(LLVM_TARGET_DEFINITIONS MeshOps.td)
mlir_tablegen(MeshOps.h.inc -gen-op-decls)
mlir_tablegen(MeshOps.cpp.inc -gen-op-defs)
diff --git a/mlir/include/mlir/Dialect/Mesh/IR/MeshBase.td b/mlir/include/mlir/Dialect/Mesh/IR/MeshBase.td
index 3a85bf2d552f3b..61403ac1789802 100644
--- a/mlir/include/mlir/Dialect/Mesh/IR/MeshBase.td
+++ b/mlir/include/mlir/Dialect/Mesh/IR/MeshBase.td
@@ -12,6 +12,7 @@
include "mlir/IR/OpBase.td"
include "mlir/IR/AttrTypeBase.td"
include "mlir/IR/BuiltinTypeInterfaces.td"
+include "mlir/IR/CommonAttrConstraints.td"
include "mlir/IR/EnumAttr.td"
//===----------------------------------------------------------------------===//
@@ -31,11 +32,13 @@ def Mesh_Dialect : Dialect {
];
let useDefaultAttributePrinterParser = 1;
+ let useDefaultTypePrinterParser = 1;
let hasConstantMaterializer = 1;
}
def Mesh_MeshAxis : I<16>;
def Mesh_MeshAxesAttr : DenseArrayAttrBase<"DenseI16ArrayAttr", "int16_t", "i16">;
+def Mesh_ShardShapeAttr : DenseArrayAttrBase<"DenseI64ArrayAttr", "int64_t", "i64">;
//===----------------------------------------------------------------------===//
// Mesh Enums.
@@ -59,104 +62,33 @@ def Mesh_ReductionKind : I32EnumAttr<"ReductionKind",
}
def Mesh_ReductionKindAttr : EnumAttr<Mesh_Dialect, Mesh_ReductionKind, "partial"> {
- let assemblyFormat = "`<` $value `>`";
+ let assemblyFormat = "$value";
+}
+
+class Mesh_Type<string name, string typeMnemonic, list<Trait> traits = [],
+ string baseCppClass = "::mlir::Type">
+ : TypeDef<Mesh_Dialect, name, traits, baseCppClass> {
+ let mnemonic = typeMnemonic;
+}
+
+def Mesh_Sharding : Mesh_Type<"Sharding", "sharding"> {
+ let summary = "sharding definition";
+ let assemblyFormat = "";
}
//===----------------------------------------------------------------------===//
// Mesh Attribute
//===----------------------------------------------------------------------===//
-def MeshSharding : AttrDef<Mesh_Dialect, "MeshSharding"> {
- let mnemonic = "shard";
-
- let parameters = (ins
- AttrParameter<"::mlir::FlatSymbolRefAttr",
- "The mesh on which tensors are sharded.">:$mesh,
- ArrayRefParameter<"MeshAxesAttr">:$split_axes,
- OptionalArrayRefParameter<"MeshAxis">:$partial_axes,
- OptionalParameter<"::mlir::mesh::ReductionKind">:$partial_type
- );
-
- let summary = "Attribute that extends tensor type to distributed tensor type.";
-
- let description = [{
- The MeshSharding attribute is used in a `mesh.shard` operation.
- It specifies how a tensor is sharded and distributed across the process
- mesh.
-
- 1. `mesh`: this attribute is a FlatSymbolRefAttr that refers to the device
- mesh where the distributed tensor is placed. The symbol must resolve to a
- `mesh.mesh` operation.
-
- 2. `split_axes`: is an array composed of int64_t sub-arrays. The outer array's
- maximum size is the `rank` of the related tensor. For the i-th sub-array, if
- its value is [x, y], it indicates that the tensor's i-th dimension is splitted
- along the x and y axes of the device mesh.
-
- 3. `partial_axes`: if not empty, this signifies that the tensor is partial
- one along the specified mesh axes. An all-reduce should be applied to obtain
- the complete tensor, with reduction type being specified by `partial_type`.
-
- 4. `partial_type`: indicates the reduction type of the possible all-reduce
- op. It has 4 possible values:
- `generic`: is not an allowed value inside a shard attribute.
-
- Example:
-
- ```
- mesh.mesh @mesh0(shape = 2x2x4)
-
- // The tensor is fully replicated on @mesh0.
- // Currently, there must be at least one sub-array present in axes, even
- // if it's empty. Otherwise, a parsing error will occur.
- #mesh.shard<@mesh0, [[]]>
-
- // The tensor is sharded on the first dimension along axis 0 of @mesh0
- #mesh.shard<@mesh0, [[0]]>
-
- // The tensor is sharded on the first dimension along axis 0 of @mesh0 and
- // it is also a partial_sum along mesh axis 1.
- #mesh.shard<@mesh0, [[0], []], partial = sum[1]>
-
- // The tensor is sharded on the first dimension along axis 0 of @mesh0 and
- // it is also a partial_max along mesh axis 1.
- #mesh.shard<@mesh0, [[0]], partial = max[1]>
-
- // Could be used in the attribute of mesh.shard op
- %0 = mesh.shard %arg0 to <@mesh0, [[0]]> : tensor<4x8xf32>
- ```
- }];
- let assemblyFormat = [{
- `<` $mesh `,` `[` $split_axes `]` (`,` `partial` `=` $partial_type `[`
- $partial_axes^ `]`)? `>`
- }];
-
- let builders = [
- AttrBuilder<(ins "FlatSymbolRefAttr":$mesh,
- "ArrayRef<SmallVector<MeshAxis>>":$split_axes,
- "ArrayRef<MeshAxis>": $partial_axes,
- "mesh::ReductionKind": $partial_type), [{
- SmallVector<MeshAxesAttr> splitAxesAttr = llvm::map_to_vector(
- split_axes, [&](ArrayRef<MeshAxis> array) {
- return MeshAxesAttr::get($_ctxt, array);
- });
- return $_get($_ctxt, mesh, splitAxesAttr, partial_axes,
- partial_type);
- }]>,
- AttrBuilder<(ins "FlatSymbolRefAttr":$mesh,
- "ArrayRef<SmallVector<MeshAxis>>":$split_axes), [{
- return MeshShardingAttr::get($_ctxt, mesh, split_axes, {}, ReductionKind::Sum);
- }]>
- ];
-
+def Mesh_MeshAxesArrayAttr : AttrDef<Mesh_Dialect, "MeshAxesArray"> {
+ let mnemonic = "axisarray";
+ let parameters = (ins ArrayRefParameter<"MeshAxesAttr">:$axes);
+ let assemblyFormat = "`[` $axes `]`";
let extraClassDeclaration = [{
- bool operator==(::mlir::Attribute rhs) const;
- bool operator!=(::mlir::Attribute rhs) const;
- bool operator==(::mlir::mesh::MeshShardingAttr rhs) const;
- bool operator!=(::mlir::mesh::MeshShardingAttr rhs) const;
+ size_t size() const { return getAxes().size(); }
+ auto begin() const { return getAxes().begin(); }
+ auto end() const { return getAxes().end(); }
}];
-
- let genVerifyDecl = 1;
}
#endif // MLIR_DIALECT_MESH_IR_MESHBASE_TD
diff --git a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h
index b27c9e81b32933..683975bbf215ed 100644
--- a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h
+++ b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h
@@ -24,6 +24,8 @@ namespace mesh {
using MeshAxis = int16_t;
using MeshAxesAttr = DenseI16ArrayAttr;
+using ShardShapeAttr = DenseI64ArrayAttr;
+using HaloSizePairAttr = DenseI64ArrayAttr;
} // namespace mesh
} // namespace mlir
@@ -33,6 +35,59 @@ using MeshAxesAttr = DenseI16ArrayAttr;
#define GET_ATTRDEF_CLASSES
#include "mlir/Dialect/Mesh/IR/MeshAttributes.h.inc"
+namespace mlir {
+namespace mesh {
+
+class MeshSharding {
+private:
+ ::mlir::FlatSymbolRefAttr mesh;
+ SmallVector<MeshAxesAttr> split_axes;
+ SmallVector<MeshAxis> partial_axes;
+ ReductionKind partial_type;
+ SmallVector<int64_t> static_halo_sizes;
+ SmallVector<int64_t> static_sharded_dims_sizes;
+ SmallVector<Value> dynamic_halo_sizes;
+ SmallVector<Value> dynamic_sharded_dims_sizes;
+
+public:
+ MeshSharding() = default;
+ MeshSharding(Value rhs);
+ static MeshSharding get(::mlir::FlatSymbolRefAttr mesh_,
+ ArrayRef<MeshAxesAttr> split_axes_,
+ ArrayRef<MeshAxis> partial_axes_ = {},
+ ReductionKind partial_type_ = ReductionKind::Sum,
+ ArrayRef<int64_t> static_halo_sizes_ = {},
+ ArrayRef<int64_t> static_sharded_dims_sizes_ = {},
+ ArrayRef<Value> dynamic_halo_sizes_ = {},
+ ArrayRef<Value> dynamic_sharded_dims_sizes_ = {});
+ ::mlir::FlatSymbolRefAttr getMeshAttr() const { return mesh; }
+ ::llvm::StringRef getMesh() const { return mesh.getValue(); }
+ ArrayRef<MeshAxesAttr> getSplitAxes() const { return split_axes; }
+ ArrayRef<MeshAxis> getPartialAxes() const { return partial_axes; }
+ ReductionKind getPartialType() const { return partial_type; }
+ ArrayRef<int64_t> getStaticHaloSizes() const { return static_halo_sizes; }
+ ArrayRef<int64_t> getStaticShardedDimsSizes() const {
+ return static_sharded_dims_sizes;
+ }
+ ArrayRef<Value> getDynamicHaloSizes() const { return dynamic_halo_sizes; }
+ ArrayRef<Value> getDynamicShardedDimsSizes() const {
+ return dynamic_sharded_dims_sizes;
+ }
+ operator bool() const { return (!mesh) == false; }
+ bool operator==(Value rhs) const;
+ bool operator!=(Value rhs) const;
+ bool operator==(const MeshSharding &rhs) const;
+ bool operator!=(const MeshSharding &rhs) const;
+ bool equalSplitAndPartialAxes(const MeshSharding &rhs) const;
+ bool equalHaloAndShardSizes(const MeshSharding &rhs) const;
+};
+
+} // namespace mesh
+} // namespace mlir
+
+#define GET_TYPEDEF_CLASSES
+#include "mlir/Dialect/Mesh/IR/MeshTypes.h.inc"
+
#define GET_OP_CLASSES
#include "mlir/Dialect/Mesh/IR/MeshOps.h.inc"
@@ -50,9 +105,9 @@ void removeTrailingEmptySubArray(SmallVector<SmallVector<T>> &array) {
}
// Is the same tensor replicated on all processes.
-inline bool isFullReplication(MeshShardingAttr attr) {
- return attr.getPartialAxes().empty() &&
- llvm::all_of(attr.getSplitAxes(), [](MeshAxesAttr axes) {
+inline bool isFullReplication(MeshSharding sharding) {
+ return sharding.getPartialAxes().empty() &&
+ llvm::all_of(sharding.getSplitAxes(), [](MeshAxesAttr axes) {
return axes.asArrayRef().empty();
});
}
@@ -80,8 +135,10 @@ mesh::MeshOp getMesh(Op op, SymbolTableCollection &symbolTableCollection) {
template <>
inline mesh::MeshOp
getMesh<ShardOp>(ShardOp op, SymbolTableCollection &symbolTableCollection) {
- return getMesh(op.getOperation(), op.getShardAttr().getMesh(),
- symbolTableCollection);
+ return getMesh(
+ op.getOperation(),
+ cast<ShardingOp>(op.getSharding().getDefiningOp()).getMeshAttr(),
+ symbolTableCollection);
}
// Get the number of processes that participate in each group
@@ -131,22 +188,22 @@ inline int64_t gatherDimension(int64_t dimSize, int64_t shardCount) {
// On a 2x4x? mesh with split axes = [[0], [1], [2]] the shape ?x5x1 would
// result in a shape for each shard of ?x2x?.
ShapedType shardShapedType(ShapedType shape, MeshOp mesh,
- MeshShardingAttr sharding);
+ MeshSharding sharding);
// If ranked tensor type return its sharded counterpart.
//
// If not ranked tensor type return `type`.
// `sharding` in that case must be null.
-Type shardType(Type type, MeshOp mesh, MeshShardingAttr sharding);
+Type shardType(Type type, MeshOp mesh, MeshSharding sharding);
// Insert shard op if there is not one that already has the same sharding.
// May insert resharding if required.
-void maybeInsertTargetShardingAnnotation(MeshShardingAttr sharding,
+void maybeInsertTargetShardingAnnotation(MeshSharding sharding,
OpOperand &operand,
OpBuilder &builder);
-void maybeInsertTargetShardingAnnotation(MeshShardingAttr sharding,
- OpResult result, OpBuilder &builder);
-void maybeInsertSourceShardingAnnotation(MeshShardingAttr sharding,
+void maybeInsertTargetShardingAnnotation(MeshSharding sharding, OpResult result,
+ OpBuilder &builder);
+void maybeInsertSourceShardingAnnotation(MeshSharding sharding,
OpOperand &operand,
OpBuilder &builder);
diff --git a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td
index 8e1e475463585e..8f696bbc1a0f6e 100644
--- a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td
+++ b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td
@@ -20,7 +20,7 @@ include "mlir/IR/OpAsmInterface.td"
include "mlir/IR/SymbolInterfaces.td"
//===----------------------------------------------------------------------===//
-// Mesh Dialect operations.
+// Mesh operations.
//===----------------------------------------------------------------------===//
class Mesh_Op<string mnemonic, list<Trait> traits = []> :
@@ -105,22 +105,223 @@ def Mesh_MeshShapeOp : Mesh_Op<"mesh_shape", [
];
}
+def Mesh_ProcessMultiIndexOp : Mesh_Op<"process_multi_index", [
+ Pure,
+ DeclareOpInterfaceMethods<SymbolUserOpInterface>,
+ DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>
+]> {
+ let summary = "Get the multi index of current device along specified mesh axes.";
+ let description = [{
+ It is used in the SPMD format of IR.
+ The `axes` mush be non-negative and less than the total number of mesh axes.
+ If the axes are empty then get the index along all axes.
+ }];
+ let arguments = (ins
+ FlatSymbolRefAttr:$mesh,
+ DefaultValuedAttr<Mesh_MeshAxesAttr, "{}">:$axes
+ );
+ let results = (outs
+ Variadic<Index>:$result
+ );
+ let assemblyFormat = [{
+ `on` $mesh (`axes` `=` $axes^)?
+ attr-dict `:` type($result)
+ }];
+ let builders = [
+ OpBuilder<(ins "::mlir::mesh::MeshOp":$mesh)>,
+ OpBuilder<(ins "StringRef":$mesh, "ArrayRef<MeshAxis>":$axes)>
+ ];
+}
+
+def Mesh_ProcessLinearIndexOp : Mesh_Op<"process_linear_index", [
+ Pure,
+ DeclareOpInterfaceMethods<SymbolUserOpInterface>,
+ DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>
+]> {
+ let summary = "Get the linear index of the current device.";
+ let description = [{
+ Example:
+ ```
+ %idx = mesh.process_linear_index on @mesh : index
+ ```
+ if `@mesh` has shape `(10, 20, 30)`, a device with multi
+ index `(1, 2, 3)` will have linear index `3 + 30*2 + 20*30*1`.
+ }];
+ let arguments = (ins FlatSymbolRefAttr:$mesh);
+ let results = (outs Index:$result);
+ let assemblyFormat = "`on` $mesh attr-dict `:` type($result)";
+ let builders = [
+ OpBuilder<(ins "::mlir::mesh::MeshOp":$mesh)>
+ ];
+}
+
+//===----------------------------------------------------------------------===//
+// Sharding operations.
+//===----------------------------------------------------------------------===//
+
+def Mesh_ShardingOp : Mesh_Op<"sharding", [
+ Pure,
+ AttrSizedOperandSegments,
+ DeclareOpInterfaceMethods<SymbolUserOpInterface>,
+ DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>
+ ]> {
+ let summary = "Define a sharding of a tensor.";
+ let description = [{
+ The MeshSharding specifies how a tensor is sharded and distributed across the
+ process mesh. It is typically used in a `mesh.shard` operation.
+ The operation has the follwing attributes and operands:
+
+ 1. `mesh`: this attribute is a FlatSymbolRefAttr that refers to the device
+ mesh where the distributed tensor is placed. The symbol must resolve to a
+ `mesh.mesh` operation.
+
+ 2. `split_axes`: is an array composed of int64_t sub-arrays. The outer array's
+ maximum size is the `rank` of the related tensor. For the i-th sub-array, if
+ its value is [x, y], it indicates that the tensor's i-th dimension is splitted
+ along the x and y axes of the device mesh.
+
+ 3. [Optional] `partial_axes`: if not empty, this signifies that the tensor is partial
+ one along the specified mesh axes. An all-reduce should be applied to obtain
+ the complete tensor, with reduction type being specified by `partial_type`.
+
+ 4. [Optional] `partial_type`: indicates the reduction type of the possible all-reduce
+ op. It has 4 possible values:
+ `generic`: is not an allowed value inside a shard attribute.
+
+ 5. [Optional] Sizes of halos to be added for each sharded tensor dimension.
+ `halo_sizes`is provided as a flattened 1d array of i64s, 2 values for each sharded dimension.
+ `halo_sizes` = [1, 2] means that the first sharded dimension gets an additional
+ halo of size 1 at the start of the first dimension and a halo size is 2 at its end.
+ `halo_sizes` = [1, 2, 2, 3] defines halos for the first 2 sharded dimensions
+ e.g. the first sharded dimension gets [1,2] halos and the seconds gets [2,3] halos.
+ `?` indicates dynamic halo sizes.
+
+ 6. [Optional] Sizes of sharded dimensions of each shard.
+ `sharded_dims_sizes`is provided as a flattened 1d array of i64s: for each device of the
+ device-mesh one value for each sharded tensor dimension.
+ Assuming a 3d-tensor of shape 32x32x32 with the first 2 dimensions being sharded,
+ `sharded_dims_sizes` = [16, 8, 16, 24] means that the first device of
+ the device-mesh will get a shard of shape 16x8x32 and the second device will get a
+ shard of shape 16x24x32.
+ `?` indicates dynamic shard dimensions.
+
+ `halo_sizes` and `sharded_dims_sizes` are mutually exclusive.
+
+ Examples:
+
+ ```
+ mesh.mesh @mesh0(shape = 2x2x4)
+ mesh.mesh @mesh1d_4(shape = 4)
+
+ // The tensor is fully replicated on @mesh0.
+ // Currently, there must be at least one sub-array present in axes, even
+ // if it's empty. Otherwise, a parsing error will occur.
+ %sharding0 = mesh.sharding @mesh0 split_axes = [[]]
+
+ // The tensor is sharded on the first dimension along axis 0 of @mesh0
+ %sharding1 = mesh.sharding @mesh0 split_axes = [[0]]
+
+ // The tensor is sharded on its first dimension along axis 0 of @mesh0 and
+ // it is also a partial_sum along mesh axis 1.
+ %sharding2 = mesh.sharding @mesh0 split_axes = [[0] split_axes = []] partial = sum[1]
+
+ // The tensor is sharded on its first dimension along axis 0 of @mesh0 and
+ // it is also a partial_max along mesh axis 1.
+ %sharding3 = mesh.sharding @mesh0 split_axes = [[0]] partial = max[1]
+
+ // Could be used for a mesh.shard op
+ %sharded0 = mesh.shard %arg0 to %sharding3 : tensor<4x8xf32>
+
+ // The tensor is sharded on its first dimension along axis 0 of @mesh0 and
+ // and it has halo-sizes of 1 and 2 on the sharded dim.
+ %halo_sharding = mesh.sharding @mesh0 split_axes = [[0]] halo_sizes = [1, 2]
+ %sharded1 = mesh.shard %arg0 to %halo_sharding : tensor<4x8xf32>
+
+ // The tensor is sharded on its second dimension along axis 0 of @mesh1d_4
+ // and it has pre-defined shard sizes. The shards of the devices will have
+ // the following shapes: [4x2, 4x3, 4x4, 4x5]
+ %sharding4 = mesh.sharding @mesh1d_4 split_axes = [[] split_axes = [0]] sharded_dims_sizes = [2, 3, 4, 5]
+ %sharded2 = mesh.shard %arg0 to %sharding4 : tensor<4x14xf32>
+ ```
+ }];
+
+ let arguments = (ins
+ FlatSymbolRefAttr:$mesh,
+ Mesh_MeshAxesArrayAttr:$split_axes,
+ OptionalAttr<Mesh_MeshAxesAttr>:$partial_axes,
+ OptionalAttr<Mesh_ReductionKindAttr>:$partial_type,
+ DefaultValuedAttr<DenseI64ArrayAttr, "{}">:$static_sharded_dims_sizes,
+ Variadic<I64>:$dynamic_sharded_dims_sizes,
+ DefaultValuedAttr<DenseI64ArrayAttr, "{}">:$static_halo_sizes,
+ Variadic<I64>:$dynamic_halo_sizes
+ );
+ let results = (outs
+ Mesh_Sharding:$result
+ );
+ let assemblyFormat = [{
+ $mesh
+ `split_axes` `=` $split_axes
+ (`partial` `=` $partial_type $partial_axes^)?
+ (`halo_sizes` `=` custom<DynamicIndexList>($dynamic_halo_sizes, $static_halo_sizes)^)?
+ (`sharded_dims_sizes` `=` custom<DynamicIndexList>($dynamic_sharded_dims_sizes, $static_sharded_dims_sizes)^)?
+ attr-dict `:` type($result)
+ }];
+ let builders = [
+ OpBuilder<(ins "FlatSymbolRefAttr":$mesh,
+ "ArrayRef<MeshAxesAttr>":$split_axes,
+ "ArrayRef<MeshAxis>":$partial_axes,
+ "mesh::ReductionKind":$partial_type,
+ CArg<"ArrayRef<int64_t>", "{}">:$static_halo_sizes,
+ CArg<"ArrayRef<int64_t>", "{}">:$static_sharded_dims_sizes)>,
+ OpBuilder<(ins "FlatSymbolRefAttr":$mesh,
+ "ArrayRef<MeshAxesAttr>":$split_axes)>,
+ OpBuilder<(ins "FlatSymbolRefAttr":$mesh,
+ "ArrayRef<MeshAxesAttr>":$split_axes,
+ "::mlir::ArrayRef<::mlir::OpFoldResult>":$halo_sizes,
+ "::mlir::ArrayRef<::mlir::OpFoldResult>":$sharded_dims_sizes)>,
+ OpBuilder<(ins "mlir::mesh::MeshSharding":$from)>
+ ];
+ let hasVerifier = 1;
+}
+
+def Mesh_ShardShapeOp : Mesh_Op<"shard_shape", [Pure]> {
+ let summary = "Get the shard shape of a given process/device.";
+ let description = [{
+ The device/process id is a linearized id of the device/process in the mesh.
+ This operation might be used during spmdization when the shard shape depends
+ on (non-constant) values used in `mesh.sharding`.
+ }];
+ let arguments = (ins
+ DenseI64ArrayAttr:$shape,
+ Mesh_Sharding:$sharding,
+ Index:$device
+ );
+ let results = (outs Variadic<Index>:$result);
+ let assemblyFormat = [{
+ custom<DimensionList>($shape) $sharding $device attr-dict `:` type($result)
+ }];
+ let builders = [
+ OpBuilder<(ins "ArrayRef<int64_t>":$shape, "Value":$sharding, "Value":$device)>
+ ];
+}
+
def Mesh_ShardOp : Mesh_Op<"shard", [
Pure,
- SameOperandsAndResultType,
+ AllTypesMatch<["result", "src"]>,
DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>
]> {
let summary = "Annotate on how a tensor is sharded across a mesh.";
let description = [{
The mesh.shard operation is designed to specify and guide the sharding
- behavior of a tensor value across a mesh topology. This operation has one
- operand and two attributes:
+ behavior of a tensor value across a mesh topology. This operation has two
+ operands and two optional attributes:
1. `input`: This operand represents the tensor value that needs to be
annotated for sharding.
- 2. `shard`: This attribute is type of `MeshSharding`, which is the core data
- structure to represent distribution of a tensor on a mesh.
+ 2. `sharding`: This attribute is type of `MeshShardingType`, which is the core data
+ structure to represent distribution of a tensor on a mesh. it is typically defiend
+ by an `mesh.sharding` operation.
3. `annotate_for_users`: A unit attribute addressing the scenario when a
tensor's sharding annotation differs based on its context of use (either as
@@ -132,12 +333,21 @@ def Mesh_ShardOp : Mesh_Op<"shard", [
Example:
```
func.func @only_result_annotated(%arg0 : tensor<4x8xf32>) -> () {
- %0 = mesh.shard %arg0 to <@mesh0, [[0]]> : tensor<4x8xf32>
+ %sharding = mesh.sharding @mesh0 split_axes = [[0]] : !mesh.sharding
+ %0 = mesh.shard %arg0 to %sharding : tensor<4x8xf32>
...
}
func.func @only_operand_annotated(%arg0 : tensor<4x8xf32>) -> () {
- %0 = mesh.shard %arg0 to <@mesh0, [[0]]> annotate_for_users : tensor<4x8xf32>
+ %sharding = mesh.sharding @mesh0 split_axes = [[0]] : !mesh.sharding
+ %0 = mesh.shard %arg0 to %sharding annotate_for_users : tensor<4x8xf32>
+ ...
+ }
+
+ func.func @two_operands_annotated(%arg0 : tensor<4x8xf32>, %arg1 : tensor<16x8xf32>) -> () {
+ %sharding = mesh.sharding @mesh0 split_axes = [[0]] : !mesh.sharding
+ %0 = mesh.shard %arg0 to %sharding annotate_for_users : tensor<4x8xf32>
+ %1 = mesh.shard %arg1 to %sharding annotate_for_users : tensor<16x8xf32>
...
}
@@ -146,9 +356,12 @@ def Mesh_ShardOp : Mesh_Op<"shard", [
// operand of op2
func.func @both_result_and_multi_operands_annotated(
%arg0 : tensor<4x8xf32>) -> () {
- %0 = mesh.shard %arg0 to <@mesh0, [[0]]> : tensor<4x8xf32>
- %1 = mesh.shard %0 to <@mesh0, [[1]]> annotate_for_users : tensor<4x8xf32>
- %2 = mesh.shard %0 to <@mesh0, [[2]]> annotate_for_users : tensor<4x8xf32>
+ %sharding = mesh.sharding @mesh0 split_axes = [[0]] : !mesh.sharding
+ %0 = mesh.shard %arg0 to %sharding : tensor<4x8xf32>
+ %sharding1 = mesh.sharding @mesh0 split_axes = [[1]] : !mesh.sharding
+ %1 = mesh.shard %0 to %sharding1 annotate_for_users : tensor<4x8xf32>
+ %sharding2 = mesh.sharding @mesh0 split_axes = [[2]] : !mesh.sharding
+ %2 = mesh.shard %0 to %sharding2 annotate_for_users : tensor<4x8xf32>
"op0"(%1) : ...
"op1"(%2) : ...
...
@@ -159,97 +372,56 @@ def Mesh_ShardOp : Mesh_Op<"shard", [
```
func.func @annotate_on_same_result_with_different_sharding(
%arg0 : tensor<4x8xf32>) -> () {
- %0 = mesh.shard %arg0 to <@mesh0, [[0]]> : tensor<4x8xf32>
- %1 = mesh.shard %0 to <@mesh0, [[1]]> : tensor<4x8xf32>
+ %sharding1 = mesh.sharding @mesh0 split_axes = [[0]] : !mesh.sharding
+ %sharding2 = mesh.sharding @mesh0 split_axes = [[1]] : !mesh.sharding
+ %0 = mesh.shard %arg0 to $sharding1 : tensor<4x8xf32>
+ %1 = mesh.shard %0 to sharding2 : tensor<4x8xf32>
...
}
func.func @annotate_on_same_result_same_value_with_different_sharding(
%arg0 : tensor<4x8xf32>) -> () {
- %0 = mesh.shard %arg0 to <@mesh0, [[0]]> : tensor<4x8xf32>
- %1 = mesh.shard %arg0 to <@mesh0, [[1]]> : tensor<4x8xf32>
+ %sharding1 = mesh.sharding @mesh0 split_axes = [[0]] : !mesh.sharding
+ %sharding2 = mesh.sharding @mesh0 split_axes = [[1]] : !mesh.sharding
+ %0 = mesh.shard %arg0 to %sharding1 : tensor<4x8xf32>
+ %1 = mesh.shard %arg0 to %sharding2 : tensor<4x8xf32>
...
}
func.func @annotate_on_same_operand_with_different_sharding(
%arg0 : tensor<4x8xf32>) -> () {
- %0 = mesh.shard %arg0 to <@mesh0, [[0]]> annotate_for_users : tensor<4x8xf32>
- %1 = mesh.shard %0 to <@mesh0, [[1]]> annotate_for_users : tensor<4x8xf32>
+ %sharding1 = mesh.sharding @mesh0 split_axes = [[0]] : !mesh.sharding
+ %sharding2 = mesh.sharding @mesh0 split_axes = [[1]] : !mesh.sharding
+ %0 = mesh.shard %arg0 to %sharding1 annotate_for_users : tensor<4x8xf32>
+ %1 = mesh.shard %0 to %sharding2 annotate_for_users : tensor<4x8xf32>
...
}
func.func @result_annotated_after_operand(
%arg0 : tensor<4x8xf32>) -> () {
- %0 = mesh.shard %arg0 to <@mesh0, [[0]]> annotate_for_users : tensor<4x8xf32>
- %1 = mesh.shard %0 to <@mesh0, [[1]]> : tensor<4x8xf32>
+ %sharding1 = mesh.sharding @mesh0 split_axes = [[0]] : !mesh.sharding
+ %sharding2 = mesh.sharding @mesh0 split_axes = [[1]] : !mesh.sharding
+ %0 = mesh.shard %arg0 to %sharding1 annotate_for_users : tensor<4x8xf32>
+ %1 = mesh.shard %0 to %sharding2 : tensor<4x8xf32>
...
}
```
}];
let arguments = (ins
AnyRankedTensor:$src,
- MeshSharding:$shard,
+ Mesh_Sharding:$sharding,
UnitAttr:$annotate_for_users
);
let results = (outs
AnyRankedTensor:$result
);
let assemblyFormat = [{
- $src `to` $shard (`annotate_for_users` $annotate_for_users^)? attr-dict `:`
- type($result)
+ $src `to` $sharding
+ (`annotate_for_users` $annotate_for_users^)?
+ attr-dict `:` type($result)
}];
}
-def Mesh_ProcessMultiIndexOp : Mesh_Op<"process_multi_index", [
- Pure,
- DeclareOpInterfaceMethods<SymbolUserOpInterface>,
- DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>
-]> {
- let summary = "Get the multi index of current device along specified mesh axes.";
- let description = [{
- It is used in the SPMD format of IR.
- The `axes` mush be non-negative and less than the total number of mesh axes.
- If the axes are empty then get the index along all axes.
- }];
- let arguments = (ins
- FlatSymbolRefAttr:$mesh,
- DefaultValuedAttr<Mesh_MeshAxesAttr, "{}">:$axes
- );
- let results = (outs
- Variadic<Index>:$result
- );
- let assemblyFormat = [{
- `on` $mesh (`axes` `=` $axes^)?
- attr-dict `:` type($result)
- }];
- let builders = [
- OpBuilder<(ins "::mlir::mesh::MeshOp":$mesh)>,
- OpBuilder<(ins "StringRef":$mesh, "ArrayRef<MeshAxis>":$axes)>
- ];
-}
-
-def Mesh_ProcessLinearIndexOp : Mesh_Op<"process_linear_index", [
- Pure,
- DeclareOpInterfaceMethods<SymbolUserOpInterface>,
- DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>
-]> {
- let summary = "Get the linear index of the current device.";
- let description = [{
- Example:
- ```
- %idx = mesh.process_linear_index on @mesh : index
- ```
- if `@mesh` has shape `(10, 20, 30)`, a device with multi
- index `(1, 2, 3)` will have linear index `3 + 30*2 + 20*30*1`.
- }];
- let arguments = (ins FlatSymbolRefAttr:$mesh);
- let results = (outs Index:$result);
- let assemblyFormat = "`on` $mesh attr-dict `:` type($result)";
- let builders = [
- OpBuilder<(ins "::mlir::mesh::MeshOp":$mesh)>
- ];
-}
-
//===----------------------------------------------------------------------===//
// collective communication ops
//===----------------------------------------------------------------------===//
@@ -879,4 +1051,38 @@ def Mesh_ShiftOp : Mesh_CollectiveCommunicationOpBase<"shift", [
let hasCanonicalizer = 1;
}
+def Mesh_UpdateHaloOp : Mesh_Op<"update_halo", [ // Declares no results; its only value operands are the memref $input and the dynamic halo sizes.
+ DeclareOpInterfaceMethods<SymbolUserOpInterface>
+]> {
+ let summary = "Update halo data.";
+ let description = [{
+ This operation updates halo regions of shards, e.g. if their sharding
+ specified halos and the actual tensor data might have changed
+ on the remote devices. Changes might be caused by mutating operations
+ and/or if the new halo regions are larger than the existing ones.
+
+ Assumes all devices hold tensors with same-sized halo data as specified
+ by `dynamic/static_halo_sizes`.
+
+ `split_axes` specifies for each tensor axis along which mesh axes its halo
+ data is updated.
+
+ Optionally resizes to new halo sizes `target_halo_sizes`.
+ }];
+ let arguments = (ins
+ AnyNon0RankedMemRef:$input,
+ FlatSymbolRefAttr:$mesh,
+ Mesh_MeshAxesArrayAttr:$split_axes,
+ Variadic<I64>:$dynamic_halo_sizes,
+ DefaultValuedAttr<DenseI64ArrayAttr, "{}">:$static_halo_sizes,
+ DefaultValuedAttr<DenseI64ArrayAttr, "{}">:$target_halo_sizes
+ );
+ let assemblyFormat = [{
+ $input `on` $mesh
+ `split_axes` `=` $split_axes
+ (`halo_sizes` `=` custom<DynamicIndexList>($dynamic_halo_sizes, $static_halo_sizes)^)?
+ (`target_halo_sizes` `=` $target_halo_sizes^)?
+ attr-dict `:` type($input)
+ }];
+}
#endif // MLIR_DIALECT_MESH_IR_MESHOPS_TD
diff --git a/mlir/include/mlir/Dialect/Mesh/IR/TensorShardingInterfaceImpl.h b/mlir/include/mlir/Dialect/Mesh/IR/TensorShardingInterfaceImpl.h
new file mode 100644
index 00000000000000..3e23419eeec07c
--- /dev/null
+++ b/mlir/include/mlir/Dialect/Mesh/IR/TensorShardingInterfaceImpl.h
@@ -0,0 +1,23 @@
+//===- TensorShardingInterfaceImpl.h - ------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+// The include guard is derived from this header's own path
+// (mlir/Dialect/Mesh/IR/) so it cannot collide with a guard derived from a
+// Tensor/Transforms header path.
+#ifndef MLIR_DIALECT_MESH_IR_TENSORSHARDINGINTERFACEIMPL_H_
+#define MLIR_DIALECT_MESH_IR_TENSORSHARDINGINTERFACEIMPL_H_
+
+namespace mlir {
+
+class DialectRegistry;
+
+namespace tensor {
+
+// Registers the ShardingInterface external models for tensor ops in
+// `registry`.
+void registerShardingInterfaceExternalModels(DialectRegistry &registry);
+
+} // namespace tensor
+} // namespace mlir
+
+#endif // MLIR_DIALECT_MESH_IR_TENSORSHARDINGINTERFACEIMPL_H_
diff --git a/mlir/include/mlir/Dialect/Mesh/Interfaces/ShardingInterface.h b/mlir/include/mlir/Dialect/Mesh/Interfaces/ShardingInterface.h
index 216d7e10296df2..b4d25cef05a7b9 100644
--- a/mlir/include/mlir/Dialect/Mesh/Interfaces/ShardingInterface.h
+++ b/mlir/include/mlir/Dialect/Mesh/Interfaces/ShardingInterface.h
@@ -44,24 +44,21 @@ struct ShardingOption {
}
};
-// This method retrieves the 'MeshShardingAttr' attribute from a given operation
+// This method retrieves the 'MeshSharding' from a given operation
// result and includes the 'annotate_for_users' information.
-FailureOr<std::pair<bool, MeshShardingAttr>>
-getMeshShardingAttr(OpResult result);
+FailureOr<std::pair<bool, MeshSharding>> getMeshSharding(OpResult result);
-// This method retrieves the 'MeshShardingAttr' attribute from a given operation
+// This method retrieves the 'MeshSharding' from a given operation
// operand and includes the 'annotate_for_users' information.
-FailureOr<std::pair<bool, MeshShardingAttr>>
-getMeshShardingAttr(OpOperand &opOperand);
+FailureOr<std::pair<bool, MeshSharding>> getMeshSharding(OpOperand &opOperand);
namespace detail {
FailureOr<ShardingOption>
-defaultGetShardingOption(Operation *op,
- ArrayRef<MeshShardingAttr> operandShardings,
- ArrayRef<MeshShardingAttr> resultShardings);
+defaultGetShardingOption(Operation *op, ArrayRef<MeshSharding> operandShardings,
+ ArrayRef<MeshSharding> resultShardings);
-FailureOr<SmallVector<MeshShardingAttr>>
+FailureOr<std::vector<MeshSharding>>
defaultGetShardingAnnotations(Operation *op,
const ShardingOption &shardingOption);
@@ -72,11 +69,13 @@ defaultAddShardingAnnotations(Operation *op, OpBuilder &b,
} // namespace detail
// Assumes full replication on all ranked tensor arguments and results.
-void spmdizeFullyReplicatedOperation(
- Operation &op, ArrayRef<Value> spmdizedOperands,
- ArrayRef<MeshShardingAttr> operandShardings,
- ArrayRef<MeshShardingAttr> resultShardings, IRMapping &spmdizationMap,
- SymbolTableCollection &symbolTable, OpBuilder &builder);
+void spmdizeFullyReplicatedOperation(Operation &op,
+ ArrayRef<Value> spmdizedOperands,
+ ArrayRef<MeshSharding> operandShardings,
+ ArrayRef<MeshSharding> resultShardings,
+ IRMapping &spmdizationMap,
+ SymbolTableCollection &symbolTable,
+ OpBuilder &builder);
} // namespace mesh
} // namespace mlir
diff --git a/mlir/include/mlir/Dialect/Mesh/Interfaces/ShardingInterface.td b/mlir/include/mlir/Dialect/Mesh/Interfaces/ShardingInterface.td
index 47a74f619f56c4..a70d2c3e03851d 100644
--- a/mlir/include/mlir/Dialect/Mesh/Interfaces/ShardingInterface.td
+++ b/mlir/include/mlir/Dialect/Mesh/Interfaces/ShardingInterface.td
@@ -84,8 +84,8 @@ def ShardingInterface : OpInterface<"ShardingInterface"> {
/*retTy=*/"FailureOr<ShardingOption>",
/*methodName=*/"getShardingOption",
/*args=*/(ins
- "ArrayRef<MeshShardingAttr>": $operandShardings,
- "ArrayRef<MeshShardingAttr>": $resultShardings
+ "ArrayRef<MeshSharding>": $operandShardings,
+ "ArrayRef<MeshSharding>": $resultShardings
),
/*methodBody=*/"",
/*defaultImplementation=*/[{
@@ -100,7 +100,7 @@ def ShardingInterface : OpInterface<"ShardingInterface"> {
This is what shardings the operands and results need to have in order
to shard the op according to shardingOption.
}],
- /*retTy=*/"FailureOr<SmallVector<MeshShardingAttr>>",
+ /*retTy=*/"FailureOr<std::vector<MeshSharding>>",
/*methodName=*/"getShardingAnnotations",
/*args=*/(ins
"const ShardingOption &":$shardingOption
@@ -139,7 +139,7 @@ def ShardingInterface : OpInterface<"ShardingInterface"> {
annotations from the IR for each argument/result and prepare
`operandShardings` and `resultShardings`.
Values that are not ranked tensors do not have sharding annotations.
- In this case their corresponding MeshShardingAttr is null.
+ In this case their corresponding MeshSharding is null.
For convenience it will also prepare `spmdizedOperands`, although
they can be retrieved from the `spmdizationMap`.
@@ -161,8 +161,8 @@ def ShardingInterface : OpInterface<"ShardingInterface"> {
/*methodName=*/"spmdize",
/*args=*/(ins
"ArrayRef<Value>": $spmdizedOperands,
- "ArrayRef<MeshShardingAttr>": $operandShardings,
- "ArrayRef<MeshShardingAttr>": $resultShardings,
+ "ArrayRef<MeshSharding>": $operandShardings,
+ "ArrayRef<MeshSharding>": $resultShardings,
"IRMapping&": $spmdizationMap,
"SymbolTableCollection &": $symbolTableCollection,
"OpBuilder &":$builder
diff --git a/mlir/include/mlir/Dialect/Mesh/Interfaces/ShardingInterfaceImpl.h b/mlir/include/mlir/Dialect/Mesh/Interfaces/ShardingInterfaceImpl.h
index 5e4b4f3a66af9d..2af8b2bd1d906f 100644
--- a/mlir/include/mlir/Dialect/Mesh/Interfaces/ShardingInterfaceImpl.h
+++ b/mlir/include/mlir/Dialect/Mesh/Interfaces/ShardingInterfaceImpl.h
@@ -26,8 +26,8 @@ namespace mesh {
// on the provided shardings for the op's operands and results.
// Assumes that the indexingMaps are projected permutations.
ShardingArray getMeshAxisAssignmentForLoopIterators(
- ArrayRef<MeshShardingAttr> operandShardings,
- ArrayRef<MeshShardingAttr> resultShardings,
+ ArrayRef<MeshSharding> operandShardings,
+ ArrayRef<MeshSharding> resultShardings,
ArrayRef<utils::IteratorType> loopIteratorTypes,
ArrayRef<AffineMap> indexingMaps);
@@ -42,11 +42,13 @@ SmallVector<MeshAxis> getReductionMeshAxes(
// Inserts a clone of the operation that has all ranked tensor
// arguments/results sharded.
-void spmdizeTriviallyShardableOperation(
- Operation &op, ArrayRef<Value> spmdizedOperands,
- ArrayRef<MeshShardingAttr> operandShardings,
- ArrayRef<MeshShardingAttr> resultShardings, IRMapping &spmdizationMap,
- SymbolTableCollection &symbolTable, OpBuilder &builder);
+void spmdizeTriviallyShardableOperation(Operation &op,
+ ArrayRef<Value> spmdizedOperands,
+ ArrayRef<MeshSharding> operandShardings,
+ ArrayRef<MeshSharding> resultShardings,
+ IRMapping &spmdizationMap,
+ SymbolTableCollection &symbolTable,
+ OpBuilder &builder);
// All ranked tensor argument and result dimensions have
// independent parallel loop iterators.
@@ -72,8 +74,8 @@ struct IndependentParallelIteratorDomainShardingInterface
}
LogicalResult spmdize(Operation *op, ArrayRef<Value> spmdizedOperands,
- ArrayRef<MeshShardingAttr> operandShardings,
- ArrayRef<MeshShardingAttr> resultShardings,
+ ArrayRef<MeshSharding> operandShardings,
+ ArrayRef<MeshSharding> resultShardings,
IRMapping &spmdizationMap,
SymbolTableCollection &symbolTable,
OpBuilder &builder) const {
@@ -128,8 +130,8 @@ struct ElementwiseShardingInterface
}
LogicalResult spmdize(Operation *op, ArrayRef<Value> spmdizedOperands,
- ArrayRef<MeshShardingAttr> operandShardings,
- ArrayRef<MeshShardingAttr> resultShardings,
+ ArrayRef<MeshSharding> operandShardings,
+ ArrayRef<MeshSharding> resultShardings,
IRMapping &spmdizationMap,
SymbolTableCollection &symbolTable,
OpBuilder &builder) const {
diff --git a/mlir/include/mlir/InitAllDialects.h b/mlir/include/mlir/InitAllDialects.h
index 549c26c72d8a1e..01f28c5d21b37d 100644
--- a/mlir/include/mlir/InitAllDialects.h
+++ b/mlir/include/mlir/InitAllDialects.h
@@ -57,6 +57,7 @@
#include "mlir/Dialect/MemRef/Transforms/BufferViewFlowOpInterfaceImpl.h"
#include "mlir/Dialect/MemRef/Transforms/RuntimeOpVerification.h"
#include "mlir/Dialect/Mesh/IR/MeshDialect.h"
+#include "mlir/Dialect/Mesh/IR/TensorShardingInterfaceImpl.h"
#include "mlir/Dialect/NVGPU/IR/NVGPUDialect.h"
#include "mlir/Dialect/OpenACC/OpenACC.h"
#include "mlir/Dialect/OpenMP/OpenMPDialect.h"
@@ -179,6 +180,7 @@ inline void registerAllDialects(DialectRegistry &registry) {
tensor::registerBufferizableOpInterfaceExternalModels(registry);
tensor::registerFindPayloadReplacementOpInterfaceExternalModels(registry);
tensor::registerInferTypeOpInterfaceExternalModels(registry);
+ tensor::registerShardingInterfaceExternalModels(registry);
tensor::registerSubsetOpInterfaceExternalModels(registry);
tensor::registerTilingInterfaceExternalModels(registry);
tensor::registerValueBoundsOpInterfaceExternalModels(registry);
diff --git a/mlir/include/mlir/Interfaces/InferTypeOpInterface.h b/mlir/include/mlir/Interfaces/InferTypeOpInterface.h
index 67de05b0cb4ff3..47bcfc9bbd4f96 100644
--- a/mlir/include/mlir/Interfaces/InferTypeOpInterface.h
+++ b/mlir/include/mlir/Interfaces/InferTypeOpInterface.h
@@ -110,7 +110,7 @@ class ShapedTypeComponents {
public:
/// Default construction is an unranked shape.
- ShapedTypeComponents() : elementType(nullptr), attr(nullptr){};
+ ShapedTypeComponents() : elementType(nullptr), attr(nullptr) {};
ShapedTypeComponents(Type elementType)
: elementType(elementType), attr(nullptr), ranked(false) {}
ShapedTypeComponents(ShapedType shapedType) : attr(nullptr) {
@@ -270,7 +270,7 @@ class InferShapedTypeOpAdaptor
/// shape and elemental types.
/// Requires: Op implements InferShapedTypeOpInterface and InferTypeOpInterface.
/// Less strict is possible (e.g., implements inferReturnTypeComponents and
-/// these always populates all element types and shapes or fails, but this\
+/// these always populates all element types and shapes or fails, but this
/// trait is currently only used where the interfaces are, so keep it
/// restricted for now).
template <typename ConcreteType>
diff --git a/mlir/lib/Dialect/Linalg/Transforms/MeshShardingInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/Transforms/MeshShardingInterfaceImpl.cpp
index 36b6088b83cc2c..d47a82b59bcada 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/MeshShardingInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/MeshShardingInterfaceImpl.cpp
@@ -43,7 +43,7 @@ namespace mlir::linalg {
using MeshAxis = mesh::MeshAxis;
using ReductionKind = mesh::ReductionKind;
-using MeshShardingAttr = mesh::MeshShardingAttr;
+using MeshSharding = mesh::MeshSharding;
using ShardingArray = mesh::ShardingArray;
using MeshOp = mesh::MeshOp;
@@ -102,19 +102,18 @@ static ReductionKind getReductionKindOfLinalgOp(LinalgOp op) {
return getReductionKind(reductionOp.value());
}
-static MeshOp getMesh(Operation *op,
- ArrayRef<MeshShardingAttr> operandShardings,
- ArrayRef<MeshShardingAttr> resultShardings,
+static MeshOp getMesh(Operation *op, ArrayRef<MeshSharding> operandShardings,
+ ArrayRef<MeshSharding> resultShardings,
SymbolTableCollection &symbolTable) {
- for (MeshShardingAttr sharding : operandShardings) {
+ for (MeshSharding sharding : operandShardings) {
if (sharding) {
- return mesh::getMesh(op, sharding.getMesh(), symbolTable);
+ return mesh::getMesh(op, sharding.getMeshAttr(), symbolTable);
}
}
- for (MeshShardingAttr sharding : resultShardings) {
+ for (MeshSharding sharding : resultShardings) {
if (sharding) {
- return mesh::getMesh(op, sharding.getMesh(), symbolTable);
+ return mesh::getMesh(op, sharding.getMeshAttr(), symbolTable);
}
}
@@ -185,7 +184,7 @@ static SmallVector<Value> createDestinationPassingStyleInitOperands(
static void createAllReduceForResultWithoutPartialSharding(
Value unshardedLinalgOpResult, ArrayRef<MeshAxis> opReductionMeshAxes,
- MeshShardingAttr resultSharding, ReductionKind reductionKind,
+ MeshSharding resultSharding, ReductionKind reductionKind,
IRMapping &spmdizationMap, ImplicitLocOpBuilder &builder) {
SmallVector<MeshAxis> allReduceMeshAxes;
llvm::copy_if(opReductionMeshAxes, std::back_inserter(allReduceMeshAxes),
@@ -199,14 +198,14 @@ static void createAllReduceForResultWithoutPartialSharding(
Value spmdizedLinalgOpResult = spmdizationMap.lookup(unshardedLinalgOpResult);
Value reducedValue = builder.create<mesh::AllReduceOp>(
- spmdizedLinalgOpResult, resultSharding.getMesh().getValue(),
- allReduceMeshAxes, reductionKind);
+ spmdizedLinalgOpResult, resultSharding.getMesh(), allReduceMeshAxes,
+ reductionKind);
spmdizationMap.map(unshardedLinalgOpResult, reducedValue);
}
static void createAllReduceForResultsWithoutPartialShardings(
LinalgOp unshardedOp, ArrayRef<MeshAxis> opReductionMeshAxes,
- ArrayRef<MeshShardingAttr> resultShardings, IRMapping &spmdizationMap,
+ ArrayRef<MeshSharding> resultShardings, IRMapping &spmdizationMap,
ImplicitLocOpBuilder &builder) {
ReductionKind reductionKind = getReductionKindOfLinalgOp(unshardedOp);
for (auto [unshardedLinalgOpResult, resultSharding] :
@@ -219,8 +218,8 @@ static void createAllReduceForResultsWithoutPartialShardings(
static void spmdizeLinalgOpWithShardedReduction(
LinalgOp op, ArrayRef<Value> spmdizedOperands,
- ArrayRef<MeshShardingAttr> operandShardings,
- ArrayRef<MeshShardingAttr> resultShardings,
+ ArrayRef<MeshSharding> operandShardings,
+ ArrayRef<MeshSharding> resultShardings,
ArrayRef<utils::IteratorType> loopIteratorTypes,
ArrayRef<SmallVector<MeshAxis>> meshAxisAssignmentForLoopIterators,
IRMapping &spmdizationMap, SymbolTableCollection &symbolTable,
@@ -293,8 +292,8 @@ struct StructuredOpShardingInterface
}
LogicalResult spmdize(Operation *op, ArrayRef<Value> spmdizedOperands,
- ArrayRef<MeshShardingAttr> operandShardings,
- ArrayRef<MeshShardingAttr> resultShardings,
+ ArrayRef<MeshSharding> operandShardings,
+ ArrayRef<MeshSharding> resultShardings,
IRMapping &spmdizationMap,
SymbolTableCollection &symbolTable,
OpBuilder &builder) const {
diff --git a/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp b/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
index 75fceee14e1230..c35020b4c20ccc 100644
--- a/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
+++ b/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
@@ -17,6 +17,7 @@
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Diagnostics.h"
#include "mlir/IR/DialectImplementation.h"
+#include "mlir/IR/IRMapping.h"
#include "mlir/IR/Location.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/TypeUtilities.h"
@@ -86,6 +87,10 @@ void MeshDialect::initialize() {
#define GET_ATTRDEF_LIST
#include "mlir/Dialect/Mesh/IR/MeshAttributes.cpp.inc"
>();
+ addTypes<
+#define GET_TYPEDEF_LIST
+#include "mlir/Dialect/Mesh/IR/MeshTypes.cpp.inc"
+ >();
}
Operation *MeshDialect::materializeConstant(OpBuilder &builder, Attribute value,
@@ -147,39 +152,101 @@ static LogicalResult verifyMeshAxes(Location loc, ArrayRef<MeshAxis> axes,
return success();
}
+// Fetches the mesh referenced by `op` through ::getMeshAndVerify and then
+// runs verifyMeshAxes on the op's mesh axes. Yields the mesh when both steps
+// succeed and failure otherwise.
+template <typename Op>
+static FailureOr<MeshOp>
+getMeshAndVerifyAxes(Op op, SymbolTableCollection &symbolTable) {
+ auto meshOrFailure =
+ ::getMeshAndVerify(op.getOperation(), op.getMeshAttr(), symbolTable);
+ // The axes check only runs once the mesh lookup has succeeded.
+ if (failed(meshOrFailure) ||
+ failed(verifyMeshAxes(op.getLoc(), op.getMeshAxes(),
+ meshOrFailure.value())))
+ return failure();
+ return meshOrFailure;
+}
+
template <typename InShape, typename MeshShape, typename SplitAxes,
typename OutShape>
static void shardShape(const InShape &inShape, const MeshShape &meshShape,
- const SplitAxes &splitAxes, OutShape &outShape) {
+ const SplitAxes &splitAxes, OutShape &outShape,
+ ArrayRef<int64_t> shardedDimsSizes = {},
+ ArrayRef<int64_t> haloSizes = {}) {
std::copy(llvm::adl_begin(inShape), llvm::adl_end(inShape),
llvm::adl_begin(outShape));
- for (auto [tensorAxis, innerSplitAxes] : llvm::enumerate(splitAxes)) {
- outShape[tensorAxis] = shardDimension(
- inShape[tensorAxis],
- collectiveProcessGroupSize(innerSplitAxes.asArrayRef(), meshShape));
+
+ if (!shardedDimsSizes.empty()) {
+ for (auto [tensorAxis, innerSplitAxes] : llvm::enumerate(splitAxes)) {
+ if (innerSplitAxes.empty()) {
+#ifndef NDEBUG
+ for (auto dimSz : shardedDimsSizes) {
+ auto inAxis = dimSz % inShape.size();
+ assert(inShape[inAxis] == dimSz || dimSz == ShapedType::kDynamic ||
+ inShape[inAxis] == ShapedType::kDynamic);
+ }
+#endif // NDEBUG
+ } else {
+ // find sharded dims in sharded_dims_sizes with same static size on
+ // all devices. Use kDynamic for dimensions with dynamic or non-uniform
+ // sizes in sharded_dims_sizes.
+ auto sz = shardedDimsSizes[tensorAxis];
+ bool same = true;
+ for (size_t i = tensorAxis + inShape.size();
+ i < shardedDimsSizes.size(); i += inShape.size()) {
+ if (shardedDimsSizes[i] != sz) {
+ same = false;
+ break;
+ }
+ }
+ outShape[tensorAxis] = same ? sz : ShapedType::kDynamic;
+ }
+ }
+ } else {
+ for (auto [tensorAxis, innerSplitAxes] : llvm::enumerate(splitAxes)) {
+ outShape[tensorAxis] = shardDimension(
+ inShape[tensorAxis],
+ collectiveProcessGroupSize(innerSplitAxes.asArrayRef(), meshShape));
+ }
+
+ if (!haloSizes.empty()) {
+ // add halo sizes if requested
+ int haloAxis = 0;
+ for (auto [tensorAxis, innerSplitAxes] : llvm::enumerate(splitAxes)) {
+ if (!ShapedType::isDynamic(outShape[tensorAxis]) &&
+ !innerSplitAxes.empty()) {
+ if (haloSizes[haloAxis * 2] >= 0 &&
+ haloSizes[haloAxis * 2 + 1] >= 0) {
+ outShape[tensorAxis] +=
+ haloSizes[haloAxis * 2] + haloSizes[haloAxis * 2 + 1];
+ ++haloAxis;
+ } else {
+ outShape[tensorAxis] = ShapedType::kDynamic;
+ }
+ }
+ }
+ }
}
}
ShapedType mesh::shardShapedType(ShapedType shape, MeshOp mesh,
- MeshShardingAttr sharding) {
+ MeshSharding sharding) { // Returns `shape` cloned with the dimensions that shardShape computes.
using Dim = std::decay_t<decltype(shape.getDimSize(0))>;
SmallVector<Dim> resShapeArr(shape.getShape().size());
shardShape(shape.getShape(), mesh.getShape(), sharding.getSplitAxes(),
- resShapeArr);
+ resShapeArr, sharding.getStaticShardedDimsSizes(),
+ sharding.getStaticHaloSizes()); // Only the static size lists are forwarded to shardShape.
return shape.clone(resShapeArr);
}
-Type mesh::shardType(Type type, MeshOp mesh, MeshShardingAttr sharding) {
+Type mesh::shardType(Type type, MeshOp mesh, MeshSharding sharding) { // Shards ranked tensor types; any other type is returned as-is.
RankedTensorType rankedTensorType = dyn_cast<RankedTensorType>(type);
if (rankedTensorType) {
return shardShapedType(rankedTensorType, mesh, sharding);
}
-
- assert(!sharding);
return type;
}
-void mlir::mesh::maybeInsertTargetShardingAnnotation(MeshShardingAttr sharding,
+void mlir::mesh::maybeInsertTargetShardingAnnotation(MeshSharding sharding,
OpOperand &operand,
OpBuilder &builder) {
OpBuilder::InsertionGuard insertionGuard(builder);
@@ -187,14 +254,15 @@ void mlir::mesh::maybeInsertTargetShardingAnnotation(MeshShardingAttr sharding,
Operation *operandOp = operand.getOwner();
builder.setInsertionPointAfterValue(operandValue);
ShardOp shardOp = dyn_cast<ShardOp>(operandOp);
- if (shardOp && shardOp.getShard() == sharding &&
+ if (shardOp && sharding == shardOp.getSharding() &&
!shardOp.getAnnotateForUsers()) {
// No need for anything the correct sharding is already set.
return;
}
+ auto shardingOp = builder.create<ShardingOp>(operandValue.getLoc(), sharding);
auto newShardOp =
- builder.create<ShardOp>(operandValue.getLoc(), operandValue, sharding,
+ builder.create<ShardOp>(operandValue.getLoc(), operandValue, shardingOp,
/*annotate_for_users*/ false);
IRRewriter rewriter(builder);
rewriter.replaceUsesWithIf(
@@ -206,12 +274,13 @@ void mlir::mesh::maybeInsertTargetShardingAnnotation(MeshShardingAttr sharding,
return;
}
- auto newShardOp2 = builder.create<ShardOp>(
- operandValue.getLoc(), newShardOp, sharding, /*annotate_for_users*/ true);
+ auto newShardOp2 =
+ builder.create<ShardOp>(operandValue.getLoc(), newShardOp, shardingOp,
+ /*annotate_for_users*/ true);
rewriter.replaceAllUsesExcept(newShardOp, newShardOp2, newShardOp2);
}
-void mlir::mesh::maybeInsertTargetShardingAnnotation(MeshShardingAttr sharding,
+void mlir::mesh::maybeInsertTargetShardingAnnotation(MeshSharding sharding,
OpResult result,
OpBuilder &builder) {
for (auto &use : llvm::make_early_inc_range(result.getUses())) {
@@ -219,7 +288,7 @@ void mlir::mesh::maybeInsertTargetShardingAnnotation(MeshShardingAttr sharding,
}
}
-void mlir::mesh::maybeInsertSourceShardingAnnotation(MeshShardingAttr sharding,
+void mlir::mesh::maybeInsertSourceShardingAnnotation(MeshSharding sharding,
OpOperand &operand,
OpBuilder &builder) {
OpBuilder::InsertionGuard insertionGuard(builder);
@@ -229,15 +298,17 @@ void mlir::mesh::maybeInsertSourceShardingAnnotation(MeshShardingAttr sharding,
bool isBlockArg = !operandSrcOp;
ShardOp shardOp = dyn_cast_or_null<ShardOp>(operandSrcOp);
- if (shardOp && shardOp.getShard() == sharding &&
+ if (shardOp && sharding == shardOp.getSharding() &&
shardOp.getAnnotateForUsers()) {
// No need for anything the correct sharding is already set.
return;
}
builder.setInsertionPoint(operandOp);
+ auto shardingOp =
+ builder.create<ShardingOp>(operand.get().getLoc(), sharding);
auto newShardOp =
- builder.create<ShardOp>(operandValue.getLoc(), operandValue, sharding,
+ builder.create<ShardOp>(operandValue.getLoc(), operandValue, shardingOp,
/*annotate_for_users*/ true);
IRRewriter rewriter(builder);
rewriter.replaceUsesWithIf(
@@ -252,13 +323,12 @@ void mlir::mesh::maybeInsertSourceShardingAnnotation(MeshShardingAttr sharding,
builder.setInsertionPoint(newShardOp);
auto newPreceedingShardOp =
- builder.create<ShardOp>(operandValue.getLoc(), operandValue, sharding,
+ builder.create<ShardOp>(operandValue.getLoc(), operandValue, shardingOp,
/*annotate_for_users*/ false);
- rewriter.replaceUsesWithIf(newShardOp.getOperand(), newPreceedingShardOp,
- [&newShardOp](OpOperand &use) {
- return use.getOwner() ==
- newShardOp.getOperation();
- });
+ rewriter.replaceUsesWithIf(
+ newShardOp.getSrc(), newPreceedingShardOp, [&newShardOp](OpOperand &use) {
+ return use.getOwner() == newShardOp.getOperation();
+ });
}
//===----------------------------------------------------------------------===//
@@ -331,16 +401,71 @@ void MeshShapeOp::getAsmResultNames(
}
//===----------------------------------------------------------------------===//
-// mesh.shard attr
+// mesh.sharding
//===----------------------------------------------------------------------===//
-LogicalResult
-MeshShardingAttr::verify(function_ref<InFlightDiagnostic()> emitError,
- FlatSymbolRefAttr, ArrayRef<MeshAxesAttr> splitAxes,
- ArrayRef<MeshAxis> partialAxes, ReductionKind) {
- // TODO: At present mesh symbol ref is not verified. This is due to the
- // difficulty in fetching the corresponding symbol op based on an attribute.
-
+void ShardingOp::build(::mlir::OpBuilder &b, ::mlir::OperationState &odsState, // Builds from static size lists only; both dynamic operand lists are passed empty.
+ FlatSymbolRefAttr mesh,
+ ArrayRef<MeshAxesAttr> split_axes,
+ ArrayRef<MeshAxis> partial_axes,
+ mesh::ReductionKind partial_type,
+ ArrayRef<int64_t> static_halo_sizes,
+ ArrayRef<int64_t> static_sharded_dims_sizes) {
+ return build(
+ b, odsState, mesh, MeshAxesArrayAttr::get(b.getContext(), split_axes),
+ ::mlir::DenseI16ArrayAttr::get(b.getContext(), partial_axes),
+ ::mlir::mesh::ReductionKindAttr::get(b.getContext(), partial_type),
+ ::mlir::DenseI64ArrayAttr::get(b.getContext(), static_halo_sizes), {}, // NOTE(review): halo sizes are passed before sharded-dims sizes here, while the MeshSharding overload of build passes sharded-dims sizes first - confirm against the generated builder's argument order.
+ ::mlir::DenseI64ArrayAttr::get(b.getContext(), static_sharded_dims_sizes),
+ {});
+}
+
+// Convenience builder taking only a mesh and split axes. Partial axes and
+// all size arguments are passed as `{}`, and the reduction kind is `Sum`.
+void ShardingOp::build(::mlir::OpBuilder &b, ::mlir::OperationState &odsState,
+ FlatSymbolRefAttr mesh,
+ ArrayRef<MeshAxesAttr> split_axes) {
+ auto *ctx = b.getContext();
+ auto splitAxesAttr = MeshAxesArrayAttr::get(ctx, split_axes);
+ auto sumKindAttr =
+ ::mlir::mesh::ReductionKindAttr::get(ctx, ReductionKind::Sum);
+ build(b, odsState, mesh, splitAxesAttr, {}, sumKindAttr, {}, {}, {}, {});
+}
+
+// Builder taking halo sizes and sharded dims sizes as OpFoldResult lists.
+// Each list is run through dispatchIndexOpFoldResults and the resulting
+// static arrays and dynamic values are forwarded; partial axes are passed as
+// `{}` and the reduction kind is `Sum`.
+void ShardingOp::build(
+ ::mlir::OpBuilder &b, ::mlir::OperationState &odsState,
+ FlatSymbolRefAttr mesh, ArrayRef<MeshAxesAttr> split_axes,
+ ::mlir::ArrayRef<::mlir::OpFoldResult> halo_sizes,
+ ::mlir::ArrayRef<::mlir::OpFoldResult> sharded_dims_sizes) {
+ mlir::SmallVector<int64_t> haloConstants, dimConstants;
+ mlir::SmallVector<mlir::Value> haloValues, dimValues;
+ dispatchIndexOpFoldResults(halo_sizes, haloValues, haloConstants);
+ dispatchIndexOpFoldResults(sharded_dims_sizes, dimValues, dimConstants);
+
+ auto *ctx = b.getContext();
+ build(b, odsState, mesh, MeshAxesArrayAttr::get(ctx, split_axes), {},
+ ::mlir::mesh::ReductionKindAttr::get(ctx, ReductionKind::Sum),
+ ::mlir::DenseI64ArrayAttr::get(ctx, haloConstants), haloValues,
+ ::mlir::DenseI64ArrayAttr::get(ctx, dimConstants), dimValues);
+}
+
+void ShardingOp::build(::mlir::OpBuilder &b, ::mlir::OperationState &odsState, // Builds the op from an in-memory MeshSharding description.
+ mlir::mesh::MeshSharding from) {
+
+ build(b, odsState, ShardingType::get(b.getContext()), from.getMeshAttr(),
+ MeshAxesArrayAttr::get(b.getContext(), from.getSplitAxes()),
+ from.getPartialAxes().empty() // An empty list is passed as a null attribute instead of an empty array attribute.
+ ? DenseI16ArrayAttr()
+ : b.getDenseI16ArrayAttr(from.getPartialAxes()),
+ ::mlir::mesh::ReductionKindAttr::get(b.getContext(),
+ from.getPartialType()),
+ from.getStaticShardedDimsSizes().empty() // NOTE(review): sharded-dims sizes are passed before halo sizes here, while the other build overloads in this file pass halo sizes first - confirm against the generated builder's argument order.
+ ? DenseI64ArrayAttr()
+ : b.getDenseI64ArrayAttr(from.getStaticShardedDimsSizes()),
+ from.getDynamicShardedDimsSizes(),
+ from.getStaticHaloSizes().empty()
+ ? DenseI64ArrayAttr()
+ : b.getDenseI64ArrayAttr(from.getStaticHaloSizes()),
+ from.getDynamicHaloSizes());
+}
+
+LogicalResult ShardingOp::verify() {
llvm::SmallSet<MeshAxis, 4> visitedAxes;
auto checkMeshAxis = [&](ArrayRef<MeshAxis> axesArray) -> LogicalResult {
@@ -353,28 +478,58 @@ MeshShardingAttr::verify(function_ref<InFlightDiagnostic()> emitError,
return success();
};
- for (MeshAxesAttr subAxes : splitAxes) {
+ for (auto subAxes : getSplitAxes().getAxes()) {
ArrayRef<MeshAxis> subAxesArray = subAxes.asArrayRef();
if (failed(checkMeshAxis(subAxesArray)))
return failure();
}
- if (failed(checkMeshAxis(partialAxes)))
+ if (getPartialAxes().has_value() &&
+ failed(checkMeshAxis(getPartialAxes().value())))
return failure();
+
+ if (!getStaticHaloSizes().empty() && !getStaticShardedDimsSizes().empty()) {
+ return emitOpError("halo sizes and shard shapes are mutually exclusive");
+ }
+
+ if (!getStaticHaloSizes().empty()) {
+ auto numSplitAxes = getSplitAxes().getAxes().size();
+ for (auto splitAxis : getSplitAxes().getAxes()) {
+ if (splitAxis.empty()) {
+ --numSplitAxes;
+ }
+ }
+ if (getStaticHaloSizes().size() != numSplitAxes * 2) {
+ return emitError() << "halo sizes must be specified for all split axes.";
+ }
+ }
+
return success();
}
-bool MeshShardingAttr::operator==(Attribute rhs) const {
- MeshShardingAttr rhsAsMeshShardingAttr =
- mlir::dyn_cast<MeshShardingAttr>(rhs);
- return rhsAsMeshShardingAttr && *this == rhsAsMeshShardingAttr;
+void ShardingOp::getAsmResultNames( // Passes the name "sharding" for this op's result to setNameFn.
+ function_ref<void(Value, StringRef)> setNameFn) {
+ setNameFn(getResult(), "sharding");
}
-bool MeshShardingAttr::operator!=(Attribute rhs) const {
- return !(*this == rhs);
+// Symbol-level verification: the referenced mesh must pass
+// ::getMeshAndVerify, and static sharded dims sizes are rejected when the
+// mesh shape is dynamic.
+LogicalResult ShardingOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
+ auto meshOrFailure =
+ ::getMeshAndVerify(getOperation(), getMeshAttr(), symbolTable);
+ if (failed(meshOrFailure))
+ return failure();
+
+ const bool meshIsDynamic =
+ mlir::ShapedType::isDynamicShape(meshOrFailure->getShape());
+ if (!meshIsDynamic || getStaticShardedDimsSizes().size() == 0)
+ return success();
+
+ return emitError() << "sharded dims sizes are not allowed for "
+ "devices meshes with dynamic shape.";
+}
-bool MeshShardingAttr::operator==(MeshShardingAttr rhs) const {
- if (getMesh() != rhs.getMesh() || getPartialAxes() != rhs.getPartialAxes()) {
+//===----------------------------------------------------------------------===//
+// MeshSharding
+//===----------------------------------------------------------------------===//
+
+bool MeshSharding::equalSplitAndPartialAxes(const MeshSharding &rhs) const {
+ if (getMesh() != rhs.getMesh()) {
return false;
}
@@ -398,10 +553,108 @@ bool MeshShardingAttr::operator==(MeshShardingAttr rhs) const {
std::mem_fn(&MeshAxesAttr::empty));
}
-bool MeshShardingAttr::operator!=(MeshShardingAttr rhs) const {
+// Returns true iff this sharding and `rhs` have element-wise identical
+// static halo sizes, static sharded dims sizes, dynamic halo sizes and
+// dynamic sharded dims sizes.
+bool MeshSharding::equalHaloAndShardSizes(const MeshSharding &rhs) const {
+ // Each list is compared only with the same list of `rhs`. The range form
+ // of llvm::equal also reports a mismatch when the two lengths differ, so
+ // no separate (and easily cross-wired) size checks are needed.
+ return llvm::equal(getStaticHaloSizes(), rhs.getStaticHaloSizes()) &&
+ llvm::equal(getStaticShardedDimsSizes(),
+ rhs.getStaticShardedDimsSizes()) &&
+ llvm::equal(getDynamicHaloSizes(), rhs.getDynamicHaloSizes()) &&
+ llvm::equal(getDynamicShardedDimsSizes(),
+ rhs.getDynamicShardedDimsSizes());
+}
+
+// Compares this sharding with the sharding described by the value `rhs`.
+bool MeshSharding::operator==(Value rhs) const {
+ // Build the MeshSharding for `rhs` once. Passing the Value to each
+ // comparison helper would convert it separately for every call.
+ return *this == MeshSharding(rhs);
+}
+
+bool MeshSharding::operator!=(Value rhs) const { return !(*this == rhs); }
+
+// Two shardings are equal when both their split/partial axes and their
+// halo/sharded-dims sizes compare equal.
+bool MeshSharding::operator==(const MeshSharding &rhs) const {
+ return equalSplitAndPartialAxes(rhs) && equalHaloAndShardSizes(rhs);
+}
+
+bool MeshSharding::operator!=(const MeshSharding &rhs) const {
return !(*this == rhs);
}
+// Constructs a MeshSharding from a value that must be the result of a
+// ShardingOp. Absent partial axes become an empty list and an absent
+// partial type becomes ReductionKind::Sum.
+MeshSharding::MeshSharding(Value rhs) {
+ // getDefiningOp<OpTy>() tolerates values without a defining op (block
+ // arguments), so the assertion below is the one that reports the misuse.
+ auto shardingOp = rhs.getDefiningOp<ShardingOp>();
+ assert(shardingOp && "expected sharding op");
+ *this = get(shardingOp.getMeshAttr(), shardingOp.getSplitAxes().getAxes(),
+ shardingOp.getPartialAxes().value_or(ArrayRef<MeshAxis>()),
+ shardingOp.getPartialType().value_or(ReductionKind::Sum),
+ shardingOp.getStaticHaloSizes(),
+ shardingOp.getStaticShardedDimsSizes(),
+ SmallVector<Value>(shardingOp.getDynamicHaloSizes()),
+ SmallVector<Value>(shardingOp.getDynamicShardedDimsSizes()));
+}
+
+// Assembles a MeshSharding from its individual components, copying every
+// list argument into storage held by the returned object.
+MeshSharding MeshSharding::get(::mlir::FlatSymbolRefAttr mesh_,
+ ArrayRef<MeshAxesAttr> split_axes_,
+ ArrayRef<MeshAxis> partial_axes_,
+ ReductionKind partial_type_,
+ ArrayRef<int64_t> static_halo_sizes_,
+ ArrayRef<int64_t> static_sharded_dims_sizes_,
+ ArrayRef<Value> dynamic_halo_sizes_,
+ ArrayRef<Value> dynamic_sharded_dims_sizes_) {
+ MeshSharding result;
+ result.mesh = mesh_;
+ result.partial_type = partial_type_;
+
+ // Each split-axes entry is re-created through MeshAxesAttr::get using the
+ // context of the mesh attribute.
+ result.split_axes.resize(split_axes_.size());
+ for (size_t idx = 0, numAxes = split_axes_.size(); idx != numAxes; ++idx) {
+ result.split_axes[idx] =
+ MeshAxesAttr::get(mesh_.getContext(), split_axes_[idx].asArrayRef());
+ }
+
+ // Resizes `to` to match `from` and copies all elements over.
+ auto copyInto = [](const auto from, auto &to) {
+ to.resize(from.size());
+ llvm::copy(from, to.begin());
+ };
+
+ copyInto(partial_axes_, result.partial_axes);
+ copyInto(static_halo_sizes_, result.static_halo_sizes);
+ copyInto(static_sharded_dims_sizes_, result.static_sharded_dims_sizes);
+ copyInto(dynamic_halo_sizes_, result.dynamic_halo_sizes);
+ copyInto(dynamic_sharded_dims_sizes_, result.dynamic_sharded_dims_sizes);
+
+ return result;
+}
+
+//===----------------------------------------------------------------------===//
+// mesh.shard_shape
+//===----------------------------------------------------------------------===//
+
+// Builder that derives the result types itself: one index-typed result per
+// entry of `shape`.
+void ShardShapeOp::build(::mlir::OpBuilder &odsBuilder,
+ ::mlir::OperationState &odsState,
+ ::llvm::ArrayRef<int64_t> shape,
+ ::mlir::Value sharding, ::mlir::Value device) {
+ const size_t numResults = shape.size();
+ SmallVector<mlir::Type> indexResultTypes(numResults,
+ odsBuilder.getIndexType());
+ build(odsBuilder, odsState, indexResultTypes, shape, sharding, device);
+}
+
//===----------------------------------------------------------------------===//
// mesh.shard op
//===----------------------------------------------------------------------===//
@@ -530,20 +783,6 @@ static LogicalResult verifyInGroupDevice(Location loc, StringRef deviceName,
return success();
}
-template <typename Op>
-static FailureOr<MeshOp>
-getMeshAndVerifyAxes(Op op, SymbolTableCollection &symbolTable) {
- auto mesh =
- ::getMeshAndVerify(op.getOperation(), op.getMeshAttr(), symbolTable);
- if (failed(mesh)) {
- return failure();
- }
- if (failed(verifyMeshAxes(op.getLoc(), op.getMeshAxes(), mesh.value()))) {
- return failure();
- }
- return mesh;
-}
-
template <typename It>
static auto product(It begin, It end) {
using ElementType = std::decay_t<decltype(*begin)>;
@@ -1044,6 +1283,20 @@ void ShiftOp::getAsmResultNames(
setNameFn(getResult(), "shift");
}
+//===----------------------------------------------------------------------===//
+// mesh.update_halo op
+//===----------------------------------------------------------------------===//
+
+/// Checks the mesh symbol referenced by this op by delegating to
+/// `getMeshAndVerify`; the verification fails exactly when that call fails.
+LogicalResult
+UpdateHaloOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
+  if (failed(getMeshAndVerify(getOperation(), getMeshAttr(), symbolTable)))
+    return failure();
+  return success();
+}
+
//===----------------------------------------------------------------------===//
// TableGen'd op method definitions
//===----------------------------------------------------------------------===//
@@ -1054,4 +1307,7 @@ void ShiftOp::getAsmResultNames(
#define GET_ATTRDEF_CLASSES
#include "mlir/Dialect/Mesh/IR/MeshAttributes.cpp.inc"
+#define GET_TYPEDEF_CLASSES
+#include "mlir/Dialect/Mesh/IR/MeshTypes.cpp.inc"
+
#include "mlir/Dialect/Mesh/IR/MeshEnums.cpp.inc"
diff --git a/mlir/lib/Dialect/Mesh/Interfaces/CMakeLists.txt b/mlir/lib/Dialect/Mesh/Interfaces/CMakeLists.txt
index 1010756f1fe279..266fa6fa54557c 100644
--- a/mlir/lib/Dialect/Mesh/Interfaces/CMakeLists.txt
+++ b/mlir/lib/Dialect/Mesh/Interfaces/CMakeLists.txt
@@ -1,5 +1,6 @@
add_mlir_library(MLIRShardingInterface
ShardingInterface.cpp
+ TensorShardingInterfaceImpl.cpp
ADDITIONAL_HEADER_DIRS
${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/Mesh
@@ -11,5 +12,6 @@ add_mlir_library(MLIRShardingInterface
MLIRDialectUtils
MLIRIR
MLIRMeshDialect
+ MLIRTensorDialect
MLIRSupport
)
diff --git a/mlir/lib/Dialect/Mesh/Interfaces/ShardingInterface.cpp b/mlir/lib/Dialect/Mesh/Interfaces/ShardingInterface.cpp
index bcd0e155613208..c1f4d563d5b42c 100644
--- a/mlir/lib/Dialect/Mesh/Interfaces/ShardingInterface.cpp
+++ b/mlir/lib/Dialect/Mesh/Interfaces/ShardingInterface.cpp
@@ -91,12 +91,22 @@ checkOperandAffineExpr(AffineExpr expr, unsigned numDims) {
return positions;
}
+/// Converts a list of axis vectors into a list of MeshAxesAttr: one attribute
+/// is created via `ctxt` for each inner vector, in the original order.
+template <typename T>
+SmallVector<MeshAxesAttr>
+fromArrayOfVector(MLIRContext *ctxt, const SmallVector<SmallVector<T>> &vec) {
+  SmallVector<MeshAxesAttr> attrs;
+  for (size_t idx = 0, count = vec.size(); idx != count; ++idx)
+    attrs.push_back(MeshAxesAttr::get(ctxt, vec[idx]));
+  return attrs;
+}
+
//===----------------------------------------------------------------------===//
-// mesh::getMeshShardingAttr
+// mesh::getMeshSharding
//===----------------------------------------------------------------------===//
-FailureOr<std::pair<bool, MeshShardingAttr>>
-mesh::getMeshShardingAttr(OpResult result) {
+FailureOr<std::pair<bool, MeshSharding>>
+mesh::getMeshSharding(OpResult result) {
Value val = cast<Value>(result);
bool anyShardedForDef = llvm::any_of(val.getUsers(), [](Operation *user) {
auto shardOp = llvm::dyn_cast<mesh::ShardOp>(user);
@@ -111,7 +121,7 @@ mesh::getMeshShardingAttr(OpResult result) {
if (!val.hasOneUse())
return failure();
auto shardOp = llvm::cast<mesh::ShardOp>(*val.getUsers().begin());
- return std::make_pair(false, shardOp.getShard());
+ return std::make_pair(false, MeshSharding(shardOp.getSharding()));
}
bool anyShardedForUsers = llvm::any_of(val.getUsers(), [](Operation *user) {
@@ -127,11 +137,11 @@ mesh::getMeshShardingAttr(OpResult result) {
if (shardOp)
shardOps.push_back(shardOp);
}
- MeshShardingAttr shardForDef = shardOps[0].getShard();
+ MeshSharding shardForDef = shardOps[0].getSharding();
for (size_t i = 1; i < shardOps.size(); ++i) {
// TODO: Deduce a reasonable mesh sharding attr for def when they are
// different
- assert(shardOps[i].getShard() == shardForDef &&
+ assert(shardForDef == shardOps[i].getSharding() &&
"only support all shard ops have the same mesh sharding attr");
}
return std::make_pair(true, shardForDef);
@@ -139,11 +149,12 @@ mesh::getMeshShardingAttr(OpResult result) {
return failure();
}
-FailureOr<std::pair<bool, MeshShardingAttr>>
-mesh::getMeshShardingAttr(OpOperand &opOperand) {
+FailureOr<std::pair<bool, MeshSharding>>
+mesh::getMeshSharding(OpOperand &opOperand) {
Value val = opOperand.get();
if (ShardOp shardOp = val.getDefiningOp<ShardOp>())
- return std::make_pair(shardOp.getAnnotateForUsers(), shardOp.getShard());
+ return std::make_pair(shardOp.getAnnotateForUsers(),
+ MeshSharding(shardOp.getSharding()));
return failure();
}
@@ -250,9 +261,10 @@ static LogicalResult fillShardingOption(Operation *op,
} // namespace
-FailureOr<ShardingOption> mesh::detail::defaultGetShardingOption(
- Operation *op, ArrayRef<MeshShardingAttr> operandShardings,
- ArrayRef<MeshShardingAttr> resultShardings) {
+FailureOr<ShardingOption>
+mesh::detail::defaultGetShardingOption(Operation *op,
+ ArrayRef<MeshSharding> operandShardings,
+ ArrayRef<MeshSharding> resultShardings) {
ShardingInterface shardingOp = llvm::cast<ShardingInterface>(op);
ShardingOption shardingOption;
@@ -269,7 +281,7 @@ FailureOr<ShardingOption> mesh::detail::defaultGetShardingOption(
// 1. Fill sharding option based on op results
for (auto shardingIt : llvm::enumerate(resultShardings)) {
- MeshShardingAttr shardAttr = shardingIt.value();
+ MeshSharding shardAttr = shardingIt.value();
if (!shardAttr)
continue;
AffineMap map = maps[numOperands + shardingIt.index()];
@@ -283,7 +295,7 @@ FailureOr<ShardingOption> mesh::detail::defaultGetShardingOption(
auto dim = cast<AffineDimExpr>(expr);
unsigned index = dim.getPosition();
visitedLoopIndices.insert(index);
- if (failed(fillShardingOption(op, shardingOption, shardAttr.getMesh(),
+ if (failed(fillShardingOption(op, shardingOption, shardAttr.getMeshAttr(),
axes, index)))
return failure();
}
@@ -307,7 +319,7 @@ FailureOr<ShardingOption> mesh::detail::defaultGetShardingOption(
// 2. Fill sharding option based on operands
for (auto shardingIt : llvm::enumerate(operandShardings)) {
- MeshShardingAttr shardAttr = shardingIt.value();
+ MeshSharding shardAttr = shardingIt.value();
if (!shardAttr)
continue;
@@ -334,8 +346,8 @@ FailureOr<ShardingOption> mesh::detail::defaultGetShardingOption(
if (loopIndices->size() == 1) {
unsigned loopIdx = *loopIndices->begin();
visitedLoopIndices.insert(loopIdx);
- if (failed(fillShardingOption(op, shardingOption, shardAttr.getMesh(),
- axes, loopIdx)))
+ if (failed(fillShardingOption(op, shardingOption,
+ shardAttr.getMeshAttr(), axes, loopIdx)))
return failure();
}
// If multiple loop indices correspond to a dimension of an operand, it is
@@ -389,16 +401,16 @@ FailureOr<ShardingOption> mesh::detail::defaultGetShardingOption(
}
// Get the sharding attributed for the given result and sharding option.
-MeshShardingAttr
-getShardingAttribute(OpResult result, const ShardingOption &shardingOption,
- AffineMap map, ArrayRef<utils::IteratorType> loopTypes,
- ArrayRef<ReductionKind> reductionLoopKinds) {
+MeshSharding getSharding(OpResult result, const ShardingOption &shardingOption,
+ AffineMap map, ArrayRef<utils::IteratorType> loopTypes,
+ ArrayRef<ReductionKind> reductionLoopKinds) {
auto resultType = cast<RankedTensorType>(result.getType());
SmallVector<SmallVector<MeshAxis>> splitAxes(resultType.getRank());
SmallVector<MeshAxis> partialAxes;
// process the split axes
for (auto it : llvm::enumerate(map.getResults())) {
+ SmallVector<MeshAxis> tmp_axes;
AffineExpr expr = it.value();
// `expr` must be an `AffineDimExpr` because `map` is verified by
// isProjectedPermutation
@@ -427,13 +439,14 @@ getShardingAttribute(OpResult result, const ShardingOption &shardingOption,
}
removeTrailingEmptySubArray(splitAxes);
- return MeshShardingAttr::get(result.getContext(), shardingOption.mesh,
- splitAxes, partialAxes, partialType);
+ return MeshSharding::get(shardingOption.mesh,
+ fromArrayOfVector(result.getContext(), splitAxes),
+ partialAxes, partialType);
}
-static FailureOr<MeshShardingAttr>
-getShardingAttribute(OpOperand &opOperand, const ShardingOption &shardingOption,
- AffineMap map) {
+static FailureOr<MeshSharding> getSharding(OpOperand &opOperand,
+ const ShardingOption &shardingOption,
+ AffineMap map) {
Value operandValue = opOperand.get();
auto operandType = cast<RankedTensorType>(operandValue.getType());
SmallVector<SmallVector<MeshAxis>> splitAxes(operandType.getRank());
@@ -461,14 +474,15 @@ getShardingAttribute(OpOperand &opOperand, const ShardingOption &shardingOption,
}
removeTrailingEmptySubArray(splitAxes);
- return MeshShardingAttr::get(opOperand.get().getContext(),
- shardingOption.mesh, splitAxes);
+ return MeshSharding::get(
+ shardingOption.mesh,
+ fromArrayOfVector(opOperand.get().getContext(), splitAxes));
}
-FailureOr<SmallVector<MeshShardingAttr>>
+FailureOr<std::vector<MeshSharding>>
mesh::detail::defaultGetShardingAnnotations(
Operation *op, const ShardingOption &shardingOption) {
- SmallVector<MeshShardingAttr> res;
+ std::vector<MeshSharding> res;
ShardingInterface shardingOp = llvm::cast<ShardingInterface>(op);
SmallVector<utils::IteratorType> loopTypes =
@@ -479,7 +493,7 @@ mesh::detail::defaultGetShardingAnnotations(
unsigned numOperands = op->getNumOperands();
for (OpOperand &opOperand : op->getOpOperands()) {
- FailureOr<MeshShardingAttr> shardingAttr = getShardingAttribute(
+ FailureOr<MeshSharding> shardingAttr = getSharding(
opOperand, shardingOption, maps[opOperand.getOperandNumber()]);
if (failed(shardingAttr))
return failure();
@@ -487,9 +501,9 @@ mesh::detail::defaultGetShardingAnnotations(
}
for (OpResult result : op->getResults()) {
- res.push_back(getShardingAttribute(
- result, shardingOption, maps[numOperands + result.getResultNumber()],
- loopTypes, reductionKinds));
+ res.push_back(getSharding(result, shardingOption,
+ maps[numOperands + result.getResultNumber()],
+ loopTypes, reductionKinds));
}
return res;
@@ -506,9 +520,9 @@ static LogicalResult addShardOp(OpBuilder &b, OpResult result,
AffineMap map,
ArrayRef<utils::IteratorType> loopTypes,
ArrayRef<ReductionKind> reductionLoopKinds) {
- MeshShardingAttr shardAttr = getShardingAttribute(
- result, shardingOption, map, loopTypes, reductionLoopKinds);
- maybeInsertTargetShardingAnnotation(shardAttr, result, b);
+ MeshSharding sharding =
+ getSharding(result, shardingOption, map, loopTypes, reductionLoopKinds);
+ maybeInsertTargetShardingAnnotation(sharding, result, b);
return success();
}
@@ -519,13 +533,13 @@ static LogicalResult addShardOp(OpBuilder &b, OpOperand &opOperand,
const ShardingOption &shardingOption,
AffineMap map) {
- FailureOr<MeshShardingAttr> shardAttr =
- getShardingAttribute(opOperand, shardingOption, map);
- if (failed(shardAttr)) {
+ FailureOr<MeshSharding> sharding =
+ getSharding(opOperand, shardingOption, map);
+ if (failed(sharding)) {
return failure();
}
OpBuilder::InsertionGuard guard(b);
- maybeInsertSourceShardingAnnotation(*shardAttr, opOperand, b);
+ maybeInsertSourceShardingAnnotation(sharding.value(), opOperand, b);
return success();
}
@@ -563,7 +577,7 @@ LogicalResult mesh::detail::defaultAddShardingAnnotations(
#ifndef NDEBUG
static bool
isValueCompatibleWithFullReplicationSharding(Value value,
- MeshShardingAttr sharding) {
+ MeshSharding sharding) {
if (isa<RankedTensorType>(value.getType())) {
return sharding && isFullReplication(sharding);
}
@@ -571,27 +585,27 @@ isValueCompatibleWithFullReplicationSharding(Value value,
return !sharding;
}
-template <typename ValueRange, typename MeshShardingAttrRage>
-static bool areValuesCompatibleWithFullReplicationShardings(
- ValueRange &&values, MeshShardingAttrRage &&shardings) {
+template <typename ValueRange, typename MeshShardingRage>
+static bool
+areValuesCompatibleWithFullReplicationShardings(ValueRange &&values,
+ MeshShardingRage &&shardings) {
if (std::size(values) != std::size(shardings)) {
return false;
}
- return llvm::all_of(llvm::zip_equal(
- std::forward<ValueRange>(values),
- std::forward<MeshShardingAttrRage>(shardings)),
- [](auto valueAndSharding) {
- return isValueCompatibleWithFullReplicationSharding(
- std::get<0>(valueAndSharding),
- std::get<1>(valueAndSharding));
- });
+ return llvm::all_of(
+ llvm::zip_equal(std::forward<ValueRange>(values),
+ std::forward<MeshShardingRage>(shardings)),
+ [](auto valueAndSharding) {
+ return isValueCompatibleWithFullReplicationSharding(
+ std::get<0>(valueAndSharding), std::get<1>(valueAndSharding));
+ });
}
#endif // NDEBUG
void mesh::spmdizeFullyReplicatedOperation(
Operation &op, ArrayRef<Value> spmdizedOperands,
- ArrayRef<MeshShardingAttr> operandShardings,
- ArrayRef<MeshShardingAttr> resultShardings, IRMapping &spmdizationMap,
+ ArrayRef<MeshSharding> operandShardings,
+ ArrayRef<MeshSharding> resultShardings, IRMapping &spmdizationMap,
SymbolTableCollection &symbolTable, OpBuilder &builder) {
assert(spmdizedOperands.size() == operandShardings.size());
assert(areValuesCompatibleWithFullReplicationShardings(op.getOperands(),
@@ -618,13 +632,13 @@ static void updateMeshAxisAssignmentForLoopIterators(
}
ShardingArray mesh::getMeshAxisAssignmentForLoopIterators(
- ArrayRef<MeshShardingAttr> operandShardings,
- ArrayRef<MeshShardingAttr> resultShardings,
+ ArrayRef<MeshSharding> operandShardings,
+ ArrayRef<MeshSharding> resultShardings,
ArrayRef<utils::IteratorType> loopIteratorTypes,
ArrayRef<AffineMap> indexingMaps) {
SmallVector<std::optional<SmallVector<MeshAxis>>>
meshAxisAssignmentForLoopIterators(loopIteratorTypes.size());
- SmallVector<MeshShardingAttr> operatorAndResultShardings;
+ std::vector<MeshSharding> operatorAndResultShardings;
operatorAndResultShardings.reserve(operandShardings.size() +
resultShardings.size());
llvm::append_range(operatorAndResultShardings, operandShardings);
@@ -686,16 +700,16 @@ SmallVector<MeshAxis> mesh::getReductionMeshAxes(
void mesh::spmdizeTriviallyShardableOperation(
Operation &op, ArrayRef<Value> spmdizedOperands,
- ArrayRef<MeshShardingAttr> operandShardings,
- ArrayRef<MeshShardingAttr> resultShardings, IRMapping &spmdizationMap,
+ ArrayRef<MeshSharding> operandShardings,
+ ArrayRef<MeshSharding> resultShardings, IRMapping &spmdizationMap,
SymbolTableCollection &symbolTable, OpBuilder &builder) {
// `clone` will populate the mapping of old to new results.
Operation *newOp = builder.clone(op, spmdizationMap);
// Set the result types to the sharded counterparts.
for (auto [oldResult, newResult, sharding] :
llvm::zip_equal(op.getResults(), newOp->getResults(), resultShardings)) {
- newResult.setType(shardType(newResult.getType(),
- getMesh(&op, sharding.getMesh(), symbolTable),
- sharding));
+ newResult.setType(
+ shardType(newResult.getType(),
+ getMesh(&op, sharding.getMeshAttr(), symbolTable), sharding));
}
}
diff --git a/mlir/lib/Dialect/Mesh/Interfaces/TensorShardingInterfaceImpl.cpp b/mlir/lib/Dialect/Mesh/Interfaces/TensorShardingInterfaceImpl.cpp
new file mode 100644
index 00000000000000..9422dd4a529fd4
--- /dev/null
+++ b/mlir/lib/Dialect/Mesh/Interfaces/TensorShardingInterfaceImpl.cpp
@@ -0,0 +1,105 @@
+//===- TensorShardingInterfaceImpl.cpp ------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Mesh/IR/TensorShardingInterfaceImpl.h"
+#include "mlir/Dialect/Mesh/Interfaces/ShardingInterface.h"
+#include "mlir/Dialect/Mesh/Interfaces/ShardingInterfaceImpl.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/IR/DialectRegistry.h"
+#include "llvm/Support/Debug.h"
+
+#define DEBUG_TYPE "tensor-sharding-impl"
+#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE << "]: ")
+
+using namespace mlir;
+using namespace mlir::tensor;
+using namespace mlir::mesh;
+
+namespace {
+
+/// External ShardingInterface model for `tensor.empty`.
+///
+/// The op is described as a purely parallel computation whose single result
+/// is indexed by an identity map.
+struct EmptyOpShardingInterface
+    : public ShardingInterface::ExternalModel<EmptyOpShardingInterface,
+                                              tensor::EmptyOp> {
+  /// Returns one `parallel` iterator per dimension of the op's result.
+  SmallVector<utils::IteratorType> getLoopIteratorTypes(Operation *op) const {
+    auto ndims = mlir::cast<ShapedType>(op->getResult(0).getType()).getRank();
+    return SmallVector<utils::IteratorType>(ndims,
+                                            utils::IteratorType::parallel);
+  }
+
+  /// Returns the identity map over the result's dimensions, or no map at all
+  /// when the result type is not a ranked tensor.
+  SmallVector<AffineMap> getIndexingMaps(Operation *op) const {
+    MLIRContext *ctx = op->getContext();
+    Value val = op->getResult(0);
+    auto type = dyn_cast<RankedTensorType>(val.getType());
+    if (!type)
+      return {};
+    return {AffineMap::getMultiDimIdentityMap(type.getRank(), ctx)};
+  }
+
+  /// Creates the sharded counterpart of `op` and records the old-to-new
+  /// result mapping in `spmdizationMap`.
+  ///
+  /// Exactly one result sharding is expected. When the sharded result type is
+  /// fully static the op is cloned; otherwise a new `tensor.empty` is built
+  /// whose dynamic sizes come either from the already spmdized operands or
+  /// from a `mesh.shard_shape` op created on demand.
+  LogicalResult spmdize(Operation *op, ArrayRef<Value> spmdizedOperands,
+                        ArrayRef<MeshSharding> operandShardings,
+                        ArrayRef<MeshSharding> resultShardings,
+                        IRMapping &spmdizationMap,
+                        SymbolTableCollection &symbolTable,
+                        OpBuilder &builder) const {
+    // Check the number of result shardings before the first element is
+    // accessed below.
+    assert(resultShardings.size() == 1);
+    auto shardType = cast<ShapedType>(mesh::shardType(
+        op->getResult(0).getType(),
+        mesh::getMesh(op, resultShardings[0].getMeshAttr(), symbolTable),
+        resultShardings[0]));
+    Operation *newOp = nullptr;
+    // If the sharding turns a static dimension into a dynamic one, the
+    // per-device extent of that dimension is taken from a shard-shape op
+    // computed for the current device.
+    if (!shardType.hasStaticShape()) {
+      assert(op->getResult(0).hasOneUse());
+      SmallVector<Value> newOperands;
+      auto oldType = cast<ShapedType>(op->getResult(0).getType());
+      assert(oldType.getRank() == shardType.getRank());
+      int currOldOprndNum = -1;
+      mesh::ShardShapeOp shapeForDevice;
+      Value device;
+      Operation *newSharding = nullptr;
+      for (int64_t i = 0; i < oldType.getRank(); ++i) {
+        if (!oldType.isDynamicDim(i) && shardType.isDynamicDim(i)) {
+          // Materialize the sharding, device index and shard shape only once,
+          // the first time a newly dynamic dimension is encountered.
+          if (!newSharding) {
+            newSharding =
+                builder.create<ShardingOp>(op->getLoc(), resultShardings[0]);
+            device = builder.create<mesh::ProcessLinearIndexOp>(
+                op->getLoc(), resultShardings[0].getMesh());
+            shapeForDevice = builder.create<mesh::ShardShapeOp>(
+                op->getLoc(), oldType.getShape(), newSharding->getResult(0),
+                device);
+          }
+          newOperands.emplace_back(shapeForDevice.getResult()[i]);
+        } else if (oldType.isDynamicDim(i)) {
+          // A dimension that was already dynamic keeps using the next
+          // spmdized size operand, in order.
+          assert(shardType.isDynamicDim(i));
+          newOperands.emplace_back(spmdizedOperands[++currOldOprndNum]);
+        }
+      }
+      newOp =
+          builder.create<tensor::EmptyOp>(op->getLoc(), shardType, newOperands);
+      spmdizationMap.map(op->getResult(0), newOp->getResult(0));
+    } else {
+      // `clone` will populate the mapping of old to new results.
+      newOp = builder.clone(*op, spmdizationMap);
+    }
+    newOp->getResult(0).setType(shardType);
+
+    return success();
+  }
+};
+} // namespace
+
+/// Adds an extension to `registry` whose callback attaches
+/// EmptyOpShardingInterface to `tensor::EmptyOp` in the given context.
+void mlir::tensor::registerShardingInterfaceExternalModels(
+    DialectRegistry &registry) {
+
+  registry.addExtension(+[](MLIRContext *ctx, TensorDialect *dialect) {
+    EmptyOp::template attachInterface<EmptyOpShardingInterface>(*ctx);
+  });
+}
diff --git a/mlir/lib/Dialect/Mesh/Transforms/ShardingPropagation.cpp b/mlir/lib/Dialect/Mesh/Transforms/ShardingPropagation.cpp
index 511c9102fa3037..4bd3b425219c1a 100644
--- a/mlir/lib/Dialect/Mesh/Transforms/ShardingPropagation.cpp
+++ b/mlir/lib/Dialect/Mesh/Transforms/ShardingPropagation.cpp
@@ -108,16 +108,15 @@ operator<<(llvm::raw_ostream &stream, ReshardingRquirementKind v) {
// specific shardings. For example, mustShardings = [shard0, None] and
// optionalShardings = [None, shard1], the result will be [[shard0, shard1],
// [shard0, None]]
-static SmallVector<SmallVector<MeshShardingAttr>>
-getOrderedPossibleShardingAttrs(ArrayRef<MeshShardingAttr> mustShardings,
- ArrayRef<MeshShardingAttr> optionalShardings) {
- SmallVector<SmallVector<MeshShardingAttr>> allShardingAttrs;
- SmallVector<MeshShardingAttr> curShardingAttrs;
+static SmallVector<std::vector<MeshSharding>>
+getOrderedPossibleShardingAttrs(ArrayRef<MeshSharding> mustShardings,
+ ArrayRef<MeshSharding> optionalShardings) {
+ SmallVector<std::vector<MeshSharding>> allShardingAttrs;
+ std::vector<MeshSharding> curShardingAttrs;
std::function<void(size_t)> dfsCreateShardingAttrs = [&](size_t i) {
if (i == mustShardings.size()) {
- allShardingAttrs.push_back(
- SmallVector<MeshShardingAttr>(curShardingAttrs));
+ allShardingAttrs.push_back(std::vector<MeshSharding>(curShardingAttrs));
return;
}
@@ -132,13 +131,13 @@ getOrderedPossibleShardingAttrs(ArrayRef<MeshShardingAttr> mustShardings,
curShardingAttrs.push_back(optionalShardings[i]);
dfsCreateShardingAttrs(i + 1);
curShardingAttrs.pop_back();
- curShardingAttrs.push_back(nullptr);
+ curShardingAttrs.push_back({});
dfsCreateShardingAttrs(i + 1);
curShardingAttrs.pop_back();
return;
}
- curShardingAttrs.push_back(nullptr);
+ curShardingAttrs.push_back({});
dfsCreateShardingAttrs(i + 1);
curShardingAttrs.pop_back();
};
@@ -158,8 +157,7 @@ getOrderedPossibleShardingAttrs(ArrayRef<MeshShardingAttr> mustShardings,
// 3. All other cases. Resharding is required for operands/results with
// annotation targeting explicitly this operation.
ReshardingRquirementKind getReshardingRquirementKind(
- Operation *op,
- const SmallVector<MeshShardingAttr> &operandAndResultShardings) {
+ Operation *op, const std::vector<MeshSharding> &operandAndResultShardings) {
ReshardingRquirementKind res = ReshardingRquirementKind::NO_RESHARDING;
size_t operandsCount = op->getOperands().size();
@@ -176,7 +174,7 @@ ReshardingRquirementKind getReshardingRquirementKind(
if (!shardOp) {
continue;
}
- bool needsResharding = shardOp.getShardAttr() != sharding;
+ bool needsResharding = sharding != shardOp.getSharding();
bool isExplicitAnnotationForThisOp = shardOp.getAnnotateForUsers();
if (needsResharding) {
if (isExplicitAnnotationForThisOp) {
@@ -194,7 +192,7 @@ ReshardingRquirementKind getReshardingRquirementKind(
if (!shardOp) {
continue;
}
- bool needsResharding = shardOp.getShardAttr() != sharding;
+ bool needsResharding = sharding != shardOp.getSharding();
bool isExplicitAnnotationForThisOp = !shardOp.getAnnotateForUsers();
if (needsResharding) {
if (isExplicitAnnotationForThisOp) {
@@ -218,14 +216,13 @@ ReshardingRquirementKind getReshardingRquirementKind(
// 3. Resharding of existing explicit sharding annotations for this op.
static FailureOr<ShardingOption> selectShardingOption(
ShardingInterface shardingOp,
- ArrayRef<SmallVector<MeshShardingAttr>> possibleOperandShardingAttrs,
- ArrayRef<SmallVector<MeshShardingAttr>> possibleResultShardingAttrs) {
+ ArrayRef<std::vector<MeshSharding>> possibleOperandShardingAttrs,
+ ArrayRef<std::vector<MeshSharding>> possibleResultShardingAttrs) {
SmallVector<std::tuple<ShardingOption, ReshardingRquirementKind>>
shardingOptionsAndReshardingRequirements;
- for (ArrayRef<MeshShardingAttr> resultShardings :
- possibleResultShardingAttrs) {
- for (ArrayRef<MeshShardingAttr> operandShardings :
+ for (ArrayRef<MeshSharding> resultShardings : possibleResultShardingAttrs) {
+ for (ArrayRef<MeshSharding> operandShardings :
possibleOperandShardingAttrs) {
FailureOr<ShardingOption> shardingOption =
shardingOp.getShardingOption(operandShardings, resultShardings);
@@ -237,14 +234,14 @@ static FailureOr<ShardingOption> selectShardingOption(
// They may be missing some annotations.
// Whatever is returned by getShardingAnnotations is exactly what the op
// needs.
- FailureOr<SmallVector<MeshShardingAttr>> operandAndResultShardings =
+ FailureOr<std::vector<MeshSharding>> operandAndResultShardings =
shardingOp.getShardingAnnotations(*shardingOption);
if (failed(operandAndResultShardings)) {
return failure();
}
- LLVM_DEBUG(DBGS() << "operandAndResultShardings = "
- << *operandAndResultShardings << "\n";);
+ // LLVM_DEBUG(DBGS() << "operandAndResultShardings = "
+ // << *operandAndResultShardings << "\n";);
ReshardingRquirementKind reshardingRquirement =
getReshardingRquirementKind(shardingOp, *operandAndResultShardings);
@@ -285,7 +282,8 @@ static FailureOr<ShardingOption> selectShardingOption(
// a `mesh.shard` operation for all remaining operands and results that do not
// have sharding annotations.
static LogicalResult visitOp(Operation *op, OpBuilder &builder) {
- if (op->hasTrait<OpTrait::IsTerminator>() || llvm::isa<mesh::ShardOp>(op))
+ if (op->hasTrait<OpTrait::IsTerminator>() ||
+ llvm::isa<mesh::ShardOp, mesh::ShardingOp>(op))
return success();
ShardingInterface shardingOp = llvm::dyn_cast<ShardingInterface>(op);
@@ -294,14 +292,14 @@ static LogicalResult visitOp(Operation *op, OpBuilder &builder) {
return failure();
}
- // collect MeshShardingAttr from results
- SmallVector<MeshShardingAttr> allowConflictsResultShardings;
+ // collect MeshSharding from results
+ std::vector<MeshSharding> allowConflictsResultShardings;
allowConflictsResultShardings.resize(op->getNumResults());
- SmallVector<MeshShardingAttr> resultMustShardings;
+ std::vector<MeshSharding> resultMustShardings;
resultMustShardings.resize(op->getNumResults());
for (OpResult result : op->getResults()) {
- FailureOr<std::pair<bool, MeshShardingAttr>> maybeShardAttr =
- getMeshShardingAttr(result);
+ FailureOr<std::pair<bool, MeshSharding>> maybeShardAttr =
+ getMeshSharding(result);
if (failed(maybeShardAttr))
continue;
if (!maybeShardAttr->first)
@@ -311,14 +309,14 @@ static LogicalResult visitOp(Operation *op, OpBuilder &builder) {
maybeShardAttr->second;
}
- // collect MeshShardingAttr from operands
- SmallVector<MeshShardingAttr> allowConflictsOperandShardings;
+ // collect MeshSharding from operands
+ std::vector<MeshSharding> allowConflictsOperandShardings;
allowConflictsOperandShardings.resize(op->getNumOperands());
- SmallVector<MeshShardingAttr> operandMustShardings;
+ std::vector<MeshSharding> operandMustShardings;
operandMustShardings.resize(op->getNumOperands());
for (OpOperand &opOperand : op->getOpOperands()) {
- FailureOr<std::pair<bool, MeshShardingAttr>> maybeShardAttr =
- getMeshShardingAttr(opOperand);
+ FailureOr<std::pair<bool, MeshSharding>> maybeShardAttr =
+ getMeshSharding(opOperand);
if (failed(maybeShardAttr))
continue;
@@ -331,10 +329,10 @@ static LogicalResult visitOp(Operation *op, OpBuilder &builder) {
}
// try to get the sharding option
- SmallVector<SmallVector<MeshShardingAttr>> possibleOperandShardingAttrs =
+ SmallVector<std::vector<MeshSharding>> possibleOperandShardingAttrs =
getOrderedPossibleShardingAttrs(operandMustShardings,
allowConflictsOperandShardings);
- SmallVector<SmallVector<MeshShardingAttr>> possibleResultShardingAttrs =
+ SmallVector<std::vector<MeshSharding>> possibleResultShardingAttrs =
getOrderedPossibleShardingAttrs(resultMustShardings,
allowConflictsResultShardings);
FailureOr<ShardingOption> shardingOption = selectShardingOption(
diff --git a/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp b/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp
index 1df3cf62c2b53c..fdfed39972fd52 100644
--- a/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp
+++ b/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp
@@ -54,10 +54,10 @@ static bool arePartialAxesCompatible(const SourceAxes &sourceAxes,
// targetSharding = <@mesh_1d, [[]]>
// Then will apply all-reduce on the source value
// and return it with the sharding <@mesh_1d, [[0]]>.
-static std::tuple<TypedValue<ShapedType>, MeshShardingAttr>
+static std::tuple<TypedValue<ShapedType>, MeshSharding>
handlePartialAxesDuringResharding(OpBuilder &builder,
- MeshShardingAttr sourceSharding,
- MeshShardingAttr targetSharding,
+ MeshSharding sourceSharding,
+ MeshSharding targetSharding,
TypedValue<ShapedType> sourceShard) {
if (sourceSharding.getPartialAxes().empty() &&
targetSharding.getPartialAxes().empty()) {
@@ -88,7 +88,7 @@ handlePartialAxesDuringResharding(OpBuilder &builder,
TypedValue<ShapedType> resultValue = cast<TypedValue<ShapedType>>(
builder
.create<AllReduceOp>(sourceShard.getLoc(), sourceShard.getType(),
- sourceSharding.getMesh().getLeafReference(),
+ sourceSharding.getMeshAttr().getLeafReference(),
allReduceMeshAxes, sourceShard,
sourceSharding.getPartialType())
.getResult());
@@ -99,16 +99,16 @@ handlePartialAxesDuringResharding(OpBuilder &builder,
[&targetShardingPartialAxesSet](Axis a) {
return targetShardingPartialAxesSet.contains(a);
});
- MeshShardingAttr resultSharding =
- MeshShardingAttr::get(builder.getContext(), sourceSharding.getMesh(),
- sourceSharding.getSplitAxes(), remainingPartialAxes,
- sourceSharding.getPartialType());
+ MeshSharding resultSharding = MeshSharding::get(
+ sourceSharding.getMeshAttr(), sourceSharding.getSplitAxes(),
+ remainingPartialAxes, sourceSharding.getPartialType());
return {resultValue, resultSharding};
}
-static MeshShardingAttr
-targetShardingInSplitLastAxis(MLIRContext *ctx, MeshShardingAttr sourceSharding,
- int64_t splitTensorAxis, MeshAxis splitMeshAxis) {
+static MeshSharding targetShardingInSplitLastAxis(MLIRContext *ctx,
+ MeshSharding sourceSharding,
+ int64_t splitTensorAxis,
+ MeshAxis splitMeshAxis) {
SmallVector<MeshAxesAttr> targetShardingSplitAxes =
llvm::to_vector(sourceSharding.getSplitAxes());
while (static_cast<int64_t>(targetShardingSplitAxes.size()) <=
@@ -120,17 +120,17 @@ targetShardingInSplitLastAxis(MLIRContext *ctx, MeshShardingAttr sourceSharding,
targetSplitAxes.push_back(splitMeshAxis);
targetShardingSplitAxes[splitTensorAxis] =
MeshAxesAttr::get(ctx, targetSplitAxes);
- return MeshShardingAttr::get(
- ctx, sourceSharding.getMesh(), targetShardingSplitAxes,
+ return MeshSharding::get(
+ sourceSharding.getMeshAttr(), targetShardingSplitAxes,
sourceSharding.getPartialAxes(), sourceSharding.getPartialType());
}
// Split a replicated tensor along a mesh axis.
// e.g. [[0, 1]] -> [[0, 1, 2]].
// Returns the spmdized target value with its sharding.
-static std::tuple<TypedValue<ShapedType>, MeshShardingAttr>
+static std::tuple<TypedValue<ShapedType>, MeshSharding>
splitLastAxisInResharding(ImplicitLocOpBuilder &builder,
- MeshShardingAttr sourceSharding,
+ MeshSharding sourceSharding,
TypedValue<ShapedType> sourceShard, MeshOp mesh,
int64_t splitTensorAxis, MeshAxis splitMeshAxis) {
TypedValue<ShapedType> targetShard = cast<TypedValue<ShapedType>>(
@@ -139,7 +139,7 @@ splitLastAxisInResharding(ImplicitLocOpBuilder &builder,
ArrayRef<MeshAxis>(splitMeshAxis),
splitTensorAxis)
.getResult());
- MeshShardingAttr targetSharding = targetShardingInSplitLastAxis(
+ MeshSharding targetSharding = targetShardingInSplitLastAxis(
builder.getContext(), sourceSharding, splitTensorAxis, splitMeshAxis);
return {targetShard, targetSharding};
}
@@ -150,8 +150,8 @@ splitLastAxisInResharding(ImplicitLocOpBuilder &builder,
// Does not detect insertions like
// [[0, 1]] -> [[0, 2, 1]].
static std::optional<std::tuple<int64_t, MeshAxis>>
-detectSplitLastAxisInResharding(MeshShardingAttr sourceSharding,
- MeshShardingAttr targetSharding) {
+detectSplitLastAxisInResharding(MeshSharding sourceSharding,
+ MeshSharding targetSharding) {
for (size_t tensorAxis = 0; tensorAxis < targetSharding.getSplitAxes().size();
++tensorAxis) {
if (sourceSharding.getSplitAxes().size() > tensorAxis) {
@@ -181,10 +181,10 @@ detectSplitLastAxisInResharding(MeshShardingAttr sourceSharding,
return std::nullopt;
}
-static std::optional<std::tuple<TypedValue<ShapedType>, MeshShardingAttr>>
+static std::optional<std::tuple<TypedValue<ShapedType>, MeshSharding>>
trySplitLastAxisInResharding(ImplicitLocOpBuilder &builder, MeshOp mesh,
- MeshShardingAttr sourceSharding,
- MeshShardingAttr targetSharding,
+ MeshSharding sourceSharding,
+ MeshSharding targetSharding,
TypedValue<ShapedType> sourceShard) {
if (auto detectRes =
detectSplitLastAxisInResharding(sourceSharding, targetSharding)) {
@@ -200,8 +200,8 @@ trySplitLastAxisInResharding(ImplicitLocOpBuilder &builder, MeshOp mesh,
// [[0, 1, 2]] -> [[0, 1]].
// If detected, returns the corresponding tensor axis mesh axis pair.
static std::optional<std::tuple<int64_t, MeshAxis>>
-detectUnsplitLastAxisInResharding(MeshShardingAttr sourceSharding,
- MeshShardingAttr targetSharding) {
+detectUnsplitLastAxisInResharding(MeshSharding sourceSharding,
+ MeshSharding targetSharding) {
for (size_t tensorAxis = 0; tensorAxis < sourceSharding.getSplitAxes().size();
++tensorAxis) {
if (targetSharding.getSplitAxes().size() > tensorAxis) {
@@ -228,10 +228,9 @@ detectUnsplitLastAxisInResharding(MeshShardingAttr sourceSharding,
return std::nullopt;
}
-static MeshShardingAttr
-targetShardingInUnsplitLastAxis(MLIRContext *ctx,
- MeshShardingAttr sourceSharding,
- int64_t splitTensorAxis) {
+static MeshSharding targetShardingInUnsplitLastAxis(MLIRContext *ctx,
+ MeshSharding sourceSharding,
+ int64_t splitTensorAxis) {
SmallVector<MeshAxesAttr> targetShardingSplitAxes =
llvm::to_vector(sourceSharding.getSplitAxes());
assert(static_cast<int64_t>(targetShardingSplitAxes.size()) >
@@ -242,8 +241,8 @@ targetShardingInUnsplitLastAxis(MLIRContext *ctx,
targetSplitAxes.pop_back();
targetShardingSplitAxes[splitTensorAxis] =
MeshAxesAttr::get(ctx, targetSplitAxes);
- return MeshShardingAttr::get(
- ctx, sourceSharding.getMesh(), targetShardingSplitAxes,
+ return MeshSharding::get(
+ sourceSharding.getMeshAttr(), targetShardingSplitAxes,
sourceSharding.getPartialAxes(), sourceSharding.getPartialType());
}
@@ -255,16 +254,16 @@ static ShapedType allGatherResultShapeInUnsplitLastAxis(
return sourceShape.cloneWith(targetShape, sourceShape.getElementType());
}
-static std::tuple<TypedValue<ShapedType>, MeshShardingAttr>
+static std::tuple<TypedValue<ShapedType>, MeshSharding>
unsplitLastAxisInResharding(ImplicitLocOpBuilder &builder,
- MeshShardingAttr sourceSharding,
+ MeshSharding sourceSharding,
ShapedType sourceUnshardedShape,
TypedValue<ShapedType> sourceShard, MeshOp mesh,
int64_t splitTensorAxis, MeshAxis splitMeshAxis) {
MLIRContext *ctx = builder.getContext();
builder.setInsertionPointAfterValue(sourceShard);
- MeshShardingAttr targetSharding =
+ MeshSharding targetSharding =
targetShardingInUnsplitLastAxis(ctx, sourceSharding, splitTensorAxis);
ShapedType allGatherResultShape = allGatherResultShapeInUnsplitLastAxis(
sourceShard.getType(), mesh.getShape()[splitMeshAxis], splitTensorAxis);
@@ -280,10 +279,10 @@ unsplitLastAxisInResharding(ImplicitLocOpBuilder &builder,
return {targetShard, targetSharding};
}
-static std::optional<std::tuple<TypedValue<ShapedType>, MeshShardingAttr>>
+static std::optional<std::tuple<TypedValue<ShapedType>, MeshSharding>>
tryUnsplitLastAxisInResharding(ImplicitLocOpBuilder &builder, MeshOp mesh,
- MeshShardingAttr sourceSharding,
- MeshShardingAttr targetSharding,
+ MeshSharding sourceSharding,
+ MeshSharding targetSharding,
ShapedType sourceUnshardedShape,
TypedValue<ShapedType> sourceShard) {
if (auto detectRes =
@@ -303,8 +302,8 @@ tryUnsplitLastAxisInResharding(ImplicitLocOpBuilder &builder, MeshOp mesh,
// If detected, returns the corresponding (source_tensor_axis,
// target_tensor_axis, mesh_axis) tuple.
static std::optional<std::tuple<int64_t, int64_t, MeshAxis>>
-detectMoveLastSplitAxisInResharding(MeshShardingAttr sourceSharding,
- MeshShardingAttr targetSharding) {
+detectMoveLastSplitAxisInResharding(MeshSharding sourceSharding,
+ MeshSharding targetSharding) {
for (size_t sourceTensorAxis = 0;
sourceTensorAxis < sourceSharding.getSplitAxes().size();
++sourceTensorAxis) {
@@ -344,10 +343,10 @@ detectMoveLastSplitAxisInResharding(MeshShardingAttr sourceSharding,
return std::nullopt;
}
-static MeshShardingAttr
-targetShardingInMoveLastAxis(MLIRContext *ctx, MeshShardingAttr sourceSharding,
- int64_t sourceTensorAxis,
- int64_t targetTensorAxis) {
+static MeshSharding targetShardingInMoveLastAxis(MLIRContext *ctx,
+ MeshSharding sourceSharding,
+ int64_t sourceTensorAxis,
+ int64_t targetTensorAxis) {
SmallVector<MeshAxesAttr> targetShardingSplitAxes =
llvm::to_vector(sourceSharding.getSplitAxes());
while (static_cast<int64_t>(targetShardingSplitAxes.size()) <=
@@ -369,8 +368,8 @@ targetShardingInMoveLastAxis(MLIRContext *ctx, MeshShardingAttr sourceSharding,
targetShardingSplitAxes[targetTensorAxis] =
MeshAxesAttr::get(ctx, targetSplitAxes);
- return MeshShardingAttr::get(
- ctx, sourceSharding.getMesh(), targetShardingSplitAxes,
+ return MeshSharding::get(
+ sourceSharding.getMeshAttr(), targetShardingSplitAxes,
sourceSharding.getPartialAxes(), sourceSharding.getPartialType());
}
@@ -386,9 +385,9 @@ static ShapedType allToAllResultShapeInMoveLastAxis(ShapedType sourceShape,
return sourceShape.cloneWith(targetShape, sourceShape.getElementType());
}
-static std::tuple<TypedValue<ShapedType>, MeshShardingAttr>
+static std::tuple<TypedValue<ShapedType>, MeshSharding>
moveLastSplitAxisInResharding(ImplicitLocOpBuilder &builder, MeshOp mesh,
- MeshShardingAttr sourceSharding,
+ MeshSharding sourceSharding,
ShapedType sourceUnshardedShape,
TypedValue<ShapedType> sourceShard,
int64_t sourceTensorAxis,
@@ -396,7 +395,7 @@ moveLastSplitAxisInResharding(ImplicitLocOpBuilder &builder, MeshOp mesh,
MLIRContext *ctx = builder.getContext();
builder.setInsertionPointAfterValue(sourceShard);
- MeshShardingAttr targetSharding = targetShardingInMoveLastAxis(
+ MeshSharding targetSharding = targetShardingInMoveLastAxis(
ctx, sourceSharding, sourceTensorAxis, targetTensorAxis);
ShapedType allToAllResultShape = allToAllResultShapeInMoveLastAxis(
sourceShard.getType(), mesh.getShape()[meshAxis], sourceTensorAxis,
@@ -413,10 +412,10 @@ moveLastSplitAxisInResharding(ImplicitLocOpBuilder &builder, MeshOp mesh,
return {targetShard, targetSharding};
}
-static std::optional<std::tuple<TypedValue<ShapedType>, MeshShardingAttr>>
+static std::optional<std::tuple<TypedValue<ShapedType>, MeshSharding>>
tryMoveLastSplitAxisInResharding(ImplicitLocOpBuilder &builder, MeshOp mesh,
- MeshShardingAttr sourceSharding,
- MeshShardingAttr targetSharding,
+ MeshSharding sourceSharding,
+ MeshSharding targetSharding,
ShapedType sourceUnshardedShape,
TypedValue<ShapedType> sourceShard) {
if (auto detectRes =
@@ -435,8 +434,7 @@ tryMoveLastSplitAxisInResharding(ImplicitLocOpBuilder &builder, MeshOp mesh,
// mesh axis size.
static TypedValue<ShapedType>
reshardOn1DMesh(ImplicitLocOpBuilder &builder, MeshOp mesh,
- MeshShardingAttr sourceSharding,
- MeshShardingAttr targetSharding,
+ MeshSharding sourceSharding, MeshSharding targetSharding,
TypedValue<ShapedType> sourceUnshardedValue,
TypedValue<ShapedType> sourceShard) {
assert(sourceShard.getType() ==
@@ -455,31 +453,34 @@ reshardOn1DMesh(ImplicitLocOpBuilder &builder, MeshOp mesh,
}
TypedValue<ShapedType> targetShard;
- MeshShardingAttr actualTargetSharding;
- if (auto tryRes = tryMoveLastSplitAxisInResharding(
- builder, mesh, reducedSourceSharding, targetSharding,
- sourceUnshardedValue.getType(), reducedSourceShard)) {
- std::tie(targetShard, actualTargetSharding) = tryRes.value();
- } else if (auto tryRes = trySplitLastAxisInResharding(
- builder, mesh, reducedSourceSharding, targetSharding,
- reducedSourceShard)) {
- std::tie(targetShard, actualTargetSharding) = tryRes.value();
- } else if (auto tryRes = tryUnsplitLastAxisInResharding(
- builder, mesh, reducedSourceSharding, targetSharding,
- sourceUnshardedValue.getType(), reducedSourceShard)) {
- std::tie(targetShard, actualTargetSharding) = tryRes.value();
- } else {
- assert(false && "Did not find any pattern to apply.");
+ MeshSharding actualTargetSharding;
+ if (reducedSourceSharding.getStaticHaloSizes().empty() &&
+ targetSharding.getStaticHaloSizes().empty() &&
+ reducedSourceSharding.getStaticShardedDimsSizes().empty() &&
+ targetSharding.getStaticShardedDimsSizes().empty()) {
+ if (auto tryRes = tryMoveLastSplitAxisInResharding(
+ builder, mesh, reducedSourceSharding, targetSharding,
+ sourceUnshardedValue.getType(), reducedSourceShard)) {
+ std::tie(targetShard, actualTargetSharding) = tryRes.value();
+ } else if (auto tryRes = trySplitLastAxisInResharding(
+ builder, mesh, reducedSourceSharding, targetSharding,
+ reducedSourceShard)) {
+ std::tie(targetShard, actualTargetSharding) = tryRes.value();
+ } else if (auto tryRes = tryUnsplitLastAxisInResharding(
+ builder, mesh, reducedSourceSharding, targetSharding,
+ sourceUnshardedValue.getType(), reducedSourceShard)) {
+ std::tie(targetShard, actualTargetSharding) = tryRes.value();
+ }
}
-
+ assert(targetShard && "Did not find any pattern to apply.");
assert(actualTargetSharding == targetSharding);
assert(targetShard.getType() == targetShardType);
return targetShard;
}
TypedValue<ShapedType> reshard(ImplicitLocOpBuilder &builder, MeshOp mesh,
- MeshShardingAttr sourceSharding,
- MeshShardingAttr targetSharding,
+ MeshSharding sourceSharding,
+ MeshSharding targetSharding,
TypedValue<ShapedType> sourceUnshardedValue,
TypedValue<ShapedType> sourceShard) {
// Resort to handling only 1D meshes since the general case is complicated if
@@ -492,11 +493,13 @@ TypedValue<ShapedType> reshard(ImplicitLocOpBuilder &builder, MeshOp mesh,
TypedValue<ShapedType> reshard(OpBuilder &builder, MeshOp mesh, ShardOp source,
ShardOp target,
TypedValue<ShapedType> sourceShardValue) {
- assert(source.getResult() == target.getOperand());
+ assert(source.getResult() == target.getSrc());
+ auto sourceSharding = source.getSharding();
+ auto targetSharding = target.getSharding();
ImplicitLocOpBuilder implicitLocOpBuilder(target->getLoc(), builder);
- return reshard(
- implicitLocOpBuilder, mesh, source.getShard(), target.getShard(),
- cast<TypedValue<ShapedType>>(source.getSrc()), sourceShardValue);
+ return reshard(implicitLocOpBuilder, mesh, sourceSharding, targetSharding,
+ cast<TypedValue<ShapedType>>(source.getSrc()),
+ sourceShardValue);
}
TypedValue<ShapedType> reshard(OpBuilder &builder, ShardOp source,
@@ -538,15 +541,23 @@ shardedBlockArgumentTypes(Block &block,
assert(shardOp);
MeshOp mesh = getMesh(shardOp, symbolTableCollection);
return cast<Type>(shardShapedType(rankedTensorArg.getType(), mesh,
- shardOp.getShardAttr()));
+ shardOp.getSharding()));
});
return res;
}
+void spmdizeTriviallyShardableOperation(Operation &op,
+ ArrayRef<Value> spmdizedOperands,
+ ArrayRef<MeshSharding> operandShardings,
+ ArrayRef<MeshSharding> resultShardings,
+ IRMapping &spmdizationMap,
+ SymbolTableCollection &symbolTable,
+ OpBuilder &builder);
+
static LogicalResult spmdizeOperation(
Operation &op, ArrayRef<Value> spmdizedOperands,
- ArrayRef<MeshShardingAttr> operandShardings,
- ArrayRef<MeshShardingAttr> resultShardings, IRMapping &spmdizationMap,
+ ArrayRef<MeshSharding> operandShardings,
+ ArrayRef<MeshSharding> resultShardings, IRMapping &spmdizationMap,
SymbolTableCollection &symbolTableCollection, OpBuilder &builder) {
ShardingInterface shardingInterface = llvm::dyn_cast<ShardingInterface>(op);
if (!shardingInterface) {
@@ -572,41 +583,41 @@ static LogicalResult spmdizeOperation(
// Retrieve the sharding annotations for the operands of the given operation.
// If the type is not a ranked tensor it is not require to have an annotation.
-static SmallVector<MeshShardingAttr> getOperandShardings(Operation &op) {
- SmallVector<MeshShardingAttr> res;
+static std::vector<MeshSharding> getOperandShardings(Operation &op) {
+ std::vector<MeshSharding> res;
res.reserve(op.getNumOperands());
llvm::transform(op.getOperands(), std::back_inserter(res), [](Value operand) {
TypedValue<RankedTensorType> rankedTensor =
dyn_cast<TypedValue<RankedTensorType>>(operand);
if (!rankedTensor) {
- return MeshShardingAttr();
+ return MeshSharding();
}
Operation *definingOp = operand.getDefiningOp();
assert(definingOp);
ShardOp shardOp = llvm::cast<ShardOp>(definingOp);
- return shardOp.getShard();
+ return MeshSharding(shardOp.getSharding());
});
return res;
}
// Retrieve the sharding annotations for the results of the given operation.
// If the type is not a ranked tensor it is not require to have an annotation.
-static SmallVector<MeshShardingAttr> getResultShardings(Operation &op) {
- SmallVector<MeshShardingAttr> res;
+static std::vector<MeshSharding> getResultShardings(Operation &op) {
+ std::vector<MeshSharding> res;
res.reserve(op.getNumResults());
llvm::transform(op.getResults(), std::back_inserter(res),
[](OpResult result) {
TypedValue<RankedTensorType> rankedTensor =
dyn_cast<TypedValue<RankedTensorType>>(result);
if (!rankedTensor) {
- return MeshShardingAttr();
+ return MeshSharding();
}
assert(result.hasOneUse());
Operation *userOp = *result.getUsers().begin();
ShardOp shardOp = llvm::cast<ShardOp>(userOp);
- return shardOp.getShard();
+ return MeshSharding(shardOp.getSharding());
});
return res;
}
@@ -620,13 +631,13 @@ spmdizeOperation(ShardOp shardOp, IRMapping &spmdizationMap,
// Check if 2 shard ops are chained. If not there is no need for resharding
// as the source and target shared the same sharding.
ShardOp srcShardOp =
- dyn_cast_or_null<ShardOp>(shardOp.getOperand().getDefiningOp());
+ dyn_cast_or_null<ShardOp>(shardOp.getSrc().getDefiningOp());
if (!srcShardOp) {
- targetSpmdValue = spmdizationMap.lookup(shardOp.getOperand());
+ targetSpmdValue = spmdizationMap.lookup(shardOp.getSrc());
} else {
// Insert resharding.
TypedValue<ShapedType> srcSpmdValue = cast<TypedValue<ShapedType>>(
- spmdizationMap.lookup(srcShardOp.getOperand()));
+ spmdizationMap.lookup(srcShardOp.getSrc()));
targetSpmdValue = reshard(builder, srcShardOp, shardOp, srcSpmdValue,
symbolTableCollection);
}
@@ -640,6 +651,10 @@ static LogicalResult
spmdizeOperation(Operation &op, IRMapping &spmdizationMap,
SymbolTableCollection &symbolTableCollection,
OpBuilder &builder) {
+ if (isa<ShardingOp>(op)) {
+ return success();
+ }
+
ShardOp shardOp = llvm::dyn_cast<ShardOp>(op);
if (shardOp) {
return spmdizeOperation(shardOp, spmdizationMap, symbolTableCollection,
diff --git a/mlir/test/Dialect/Linalg/mesh-sharding-propagation.mlir b/mlir/test/Dialect/Linalg/mesh-sharding-propagation.mlir
index 59fd548dc2ef2c..f8521165e3244e 100644
--- a/mlir/test/Dialect/Linalg/mesh-sharding-propagation.mlir
+++ b/mlir/test/Dialect/Linalg/mesh-sharding-propagation.mlir
@@ -3,7 +3,7 @@
// RUN: --pass-pipeline="builtin.module(func.func(sharding-propagation))" \
// RUN: %s | FileCheck %s
-mesh.mesh @mesh_2_2(shape = 2)
+mesh.mesh @mesh_2(shape = 2)
// CHECK-LABEL: func @matmul_shard_prallel_axis
func.func @matmul_shard_prallel_axis(
@@ -14,20 +14,28 @@ func.func @matmul_shard_prallel_axis(
// CHECK-SAME: %[[DPS_OUT:[A-Za-z0-9_]+]]: tensor<2x2xf32>
%out_dps: tensor<2x2xf32>
) -> tensor<2x2xf32> {
- // CHECK: %[[IN1_ANNOTATED_0:.*]] = mesh.shard %[[IN1]] to <@mesh_2, {{\[}}[0]]> : tensor<2x3xf32>
- // CHECK: %[[IN1_ANNOTATED_1:.*]] = mesh.shard %[[IN1_ANNOTATED_0]] to <@mesh_2, {{\[}}[0]]> annotate_for_users : tensor<2x3xf32>
- // CHECK: %[[IN2_ANNOTATED:.*]] = mesh.shard %[[IN2]] to <@mesh_2, []> annotate_for_users : tensor<3x2xf32>
- // CHECK: %[[DPS_OUT_ANNOTATED:.*]] = mesh.shard %[[DPS_OUT]] to <@mesh_2, {{\[}}[0]]> annotate_for_users : tensor<2x2xf32>
- %arg0_sharded = mesh.shard %arg0 to <@mesh_2, [[0]]> : tensor<2x3xf32>
+ // CHECK: %[[SIN1_ANNOTATED_0:.*]] = mesh.sharding @mesh_2 split_axes = {{\[}}[0]] : !mesh.sharding
+ // CHECK-NEXT: %[[IN1_ANNOTATED_0:.*]] = mesh.shard %[[IN1]] to %[[SIN1_ANNOTATED_0]] : tensor<2x3xf32>
+ // CHECK: %[[SIN1_ANNOTATED_1:.*]] = mesh.sharding @mesh_2 split_axes = {{\[}}[0]] : !mesh.sharding
+ // CHECK-NEXT: %[[IN1_ANNOTATED_1:.*]] = mesh.shard %[[IN1_ANNOTATED_0]] to %[[SIN1_ANNOTATED_1]] annotate_for_users : tensor<2x3xf32>
+ // CHECK: %[[SIN2_ANNOTATED:.*]] = mesh.sharding @mesh_2 split_axes = [] : !mesh.sharding
+ // CHECK-NEXT: %[[IN2_ANNOTATED:.*]] = mesh.shard %[[IN2]] to %[[SIN2_ANNOTATED]] annotate_for_users : tensor<3x2xf32>
+ // CHECK: %[[SDPS_OUT_ANNOTATED:.*]] = mesh.sharding @mesh_2 split_axes = {{\[}}[0]] : !mesh.sharding
+ // CHECK-NEXT: %[[DPS_OUT_ANNOTATED:.*]] = mesh.shard %[[DPS_OUT]] to %[[SDPS_OUT_ANNOTATED]] annotate_for_users : tensor<2x2xf32>
+ %sarg0_sharded = mesh.sharding @mesh_2 split_axes = [[0]] : !mesh.sharding
+ %arg0_sharded = mesh.shard %arg0 to %sarg0_sharded : tensor<2x3xf32>
// CHECK: %[[RES:.*]] = linalg.matmul ins(%[[IN1_ANNOTATED_1]], %[[IN2_ANNOTATED]] : tensor<2x3xf32>, tensor<3x2xf32>)
// CHECK-SAME: outs(%[[DPS_OUT_ANNOTATED]] : tensor<2x2xf32>) -> tensor<2x2xf32>
%res = linalg.matmul ins(%arg0_sharded, %arg1 : tensor<2x3xf32>, tensor<3x2xf32>)
outs(%out_dps : tensor<2x2xf32>) -> tensor<2x2xf32>
- // CHECK: %[[RES_ANNOTATED_0:.*]] = mesh.shard %[[RES]] to <@mesh_2, {{\[}}[0]]> : tensor<2x2xf32>
- // CHECK: %[[RES_ANNOTATED_1:.*]] = mesh.shard %[[RES_ANNOTATED_0]] to <@mesh_2, {{\[}}[]]> annotate_for_users : tensor<2x2xf32>
- %res_sharded = mesh.shard %res to <@mesh_2, [[]]> annotate_for_users : tensor<2x2xf32>
+ // CHECK: %[[SRES_ANNOTATED_0:.*]] = mesh.sharding @mesh_2 split_axes = {{\[}}[0]] : !mesh.sharding
+ // CHECK-NEXT: %[[RES_ANNOTATED_0:.*]] = mesh.shard %[[RES]] to %[[SRES_ANNOTATED_0]] : tensor<2x2xf32>
+ // CHECK: %[[SRES_ANNOTATED_1:.*]] = mesh.sharding @mesh_2 split_axes = {{\[}}[]] : !mesh.sharding
+ // CHECK-NEXT: %[[RES_ANNOTATED_1:.*]] = mesh.shard %[[RES_ANNOTATED_0]] to %[[SRES_ANNOTATED_1]] annotate_for_users : tensor<2x2xf32>
+ %sres_sharded = mesh.sharding @mesh_2 split_axes = [[]] : !mesh.sharding
+ %res_sharded = mesh.shard %res to %sres_sharded annotate_for_users : tensor<2x2xf32>
// CHECK: return %[[RES_ANNOTATED_1]] : tensor<2x2xf32>
return %res_sharded : tensor<2x2xf32>
diff --git a/mlir/test/Dialect/Linalg/mesh-spmdization.mlir b/mlir/test/Dialect/Linalg/mesh-spmdization.mlir
index 52f352cfedd8e2..487cec00de16a3 100644
--- a/mlir/test/Dialect/Linalg/mesh-spmdization.mlir
+++ b/mlir/test/Dialect/Linalg/mesh-spmdization.mlir
@@ -18,12 +18,13 @@ func.func @elementwise_static_1d_mesh_static_1d_tensor(
%dps_out: tensor<2xi8>
// CHECK-SAME: -> tensor<1xi8> {
) -> tensor<2xi8> {
- %in1_shared1 = mesh.shard %in1 to <@mesh_1d, [[0]]> : tensor<2xi8>
- %in1_shared2 = mesh.shard %in1_shared1 to <@mesh_1d, [[0]]> annotate_for_users: tensor<2xi8>
- %in2_shared1 = mesh.shard %in2 to <@mesh_1d, [[0]]> : tensor<2xi8>
- %in2_shared2 = mesh.shard %in2_shared1 to <@mesh_1d, [[0]]> annotate_for_users: tensor<2xi8>
- %dps_out_shared1 = mesh.shard %dps_out to <@mesh_1d, [[0]]> : tensor<2xi8>
- %dps_out_shared2 = mesh.shard %dps_out_shared1 to <@mesh_1d, [[0]]> annotate_for_users: tensor<2xi8>
+ %sharding = mesh.sharding @mesh_1d split_axes = [[0]] : !mesh.sharding
+ %in1_sharded1 = mesh.shard %in1 to %sharding : tensor<2xi8>
+ %in1_sharded2 = mesh.shard %in1_sharded1 to %sharding annotate_for_users : tensor<2xi8>
+ %in2_sharded1 = mesh.shard %in2 to %sharding : tensor<2xi8>
+ %in2_sharded2 = mesh.shard %in2_sharded1 to %sharding annotate_for_users : tensor<2xi8>
+ %dps_out_sharded1 = mesh.shard %dps_out to %sharding : tensor<2xi8>
+ %dps_out_shared2 = mesh.shard %dps_out_sharded1 to %sharding annotate_for_users : tensor<2xi8>
// CHECK: %[[RES:.*]] = linalg.generic {
// CHECK-SAME: indexing_maps = [#[[$MAP_IDENTITY_1D]], #[[$MAP_IDENTITY_1D]], #[[$MAP_IDENTITY_1D]]],
// CHECK-SAME: iterator_types = ["parallel"]}
@@ -32,14 +33,14 @@ func.func @elementwise_static_1d_mesh_static_1d_tensor(
%res = linalg.generic {
indexing_maps = [#map_identity_1d, #map_identity_1d, #map_identity_1d],
iterator_types = ["parallel"]
- } ins(%in1_shared2, %in2_shared2 : tensor<2xi8>, tensor<2xi8>)
+ } ins(%in1_sharded2, %in2_sharded2 : tensor<2xi8>, tensor<2xi8>)
outs(%dps_out_shared2 : tensor<2xi8>) {
^bb0(%in1_scalar: i8, %in2_scalar: i8, %out: i8):
%res_scalar = arith.muli %in1_scalar, %in2_scalar : i8
linalg.yield %res_scalar : i8
} -> tensor<2xi8>
- %res_shared1 = mesh.shard %res to <@mesh_1d, [[0]]> : tensor<2xi8>
- %res_shared2 = mesh.shard %res_shared1 to <@mesh_1d, [[0]]> annotate_for_users: tensor<2xi8>
+ %res_sharded1 = mesh.shard %res to %sharding : tensor<2xi8>
+ %res_shared2 = mesh.shard %res_sharded1 to %sharding annotate_for_users : tensor<2xi8>
// CHECK: return %[[RES]] : tensor<1xi8>
return %res_shared2 : tensor<2xi8>
}
@@ -58,20 +59,22 @@ func.func @matmul_1d_mesh_static_tensors_parallel_iterator_sharding(
%dps_out: tensor<4x8xi8>
// CHECK-SAME: -> tensor<1x8xi8> {
) -> tensor<4x8xi8> {
- %in1_shared1 = mesh.shard %in1 to <@mesh_1d, [[0]]> : tensor<4x3xi8>
- %in1_shared2 = mesh.shard %in1_shared1 to <@mesh_1d, [[0]]> annotate_for_users: tensor<4x3xi8>
- %in2_shared1 = mesh.shard %in2 to <@mesh_1d, [[]]> : tensor<3x8xi8>
- %in2_shared2 = mesh.shard %in2_shared1 to <@mesh_1d, [[]]> annotate_for_users: tensor<3x8xi8>
- %dps_out_shared1 = mesh.shard %dps_out to <@mesh_1d, [[0]]> : tensor<4x8xi8>
- %dps_out_shared2 = mesh.shard %dps_out_shared1 to <@mesh_1d, [[0]]> annotate_for_users: tensor<4x8xi8>
+ %sharding = mesh.sharding @mesh_1d split_axes = [[0]] : !mesh.sharding
+ %in1_shared1 = mesh.shard %in1 to %sharding : tensor<4x3xi8>
+ %in1_shared2 = mesh.shard %in1_shared1 to %sharding annotate_for_users : tensor<4x3xi8>
+ %sharding2 = mesh.sharding @mesh_1d split_axes = [[]] : !mesh.sharding
+ %in2_shared1 = mesh.shard %in2 to %sharding2 : tensor<3x8xi8>
+ %in2_shared2 = mesh.shard %in2_shared1 to %sharding2 annotate_for_users : tensor<3x8xi8>
+ %dps_out_shared1 = mesh.shard %dps_out to %sharding : tensor<4x8xi8>
+ %dps_out_shared2 = mesh.shard %dps_out_shared1 to %sharding annotate_for_users : tensor<4x8xi8>
// CHECK: %[[RES:.*]] = linalg.matmul
// CHECK-SAME: ins(%[[IN1]], %[[IN2]] : tensor<1x3xi8>, tensor<3x8xi8>)
// CHECK-SAME: outs(%[[DPS_OUT]] : tensor<1x8xi8>)
// CHECK-SAME: -> tensor<1x8xi8>
%res = linalg.matmul ins(%in1_shared2, %in2_shared2 : tensor<4x3xi8>, tensor<3x8xi8>)
outs(%dps_out_shared2 : tensor<4x8xi8>) -> tensor<4x8xi8>
- %res_shared1 = mesh.shard %res to <@mesh_1d, [[0]]> : tensor<4x8xi8>
- %res_shared2 = mesh.shard %res_shared1 to <@mesh_1d, [[0]]> annotate_for_users: tensor<4x8xi8>
+ %res_shared1 = mesh.shard %res to %sharding : tensor<4x8xi8>
+ %res_shared2 = mesh.shard %res_shared1 to %sharding annotate_for_users : tensor<4x8xi8>
// CHECK: return %[[RES]] : tensor<1x8xi8>
return %res_shared2 : tensor<4x8xi8>
}
@@ -90,12 +93,15 @@ func.func @matmul_1d_mesh_static_tensors_reduction_iterator_sharding(
%dps_out: tensor<4x8xi8>
// CHECK-SAME: -> tensor<4x8xi8> {
) -> tensor<4x8xi8> {
- %in1_shared1 = mesh.shard %in1 to <@mesh_1d, [[], [0]]> : tensor<4x6xi8>
- %in1_shared2 = mesh.shard %in1_shared1 to <@mesh_1d, [[], [0]]> annotate_for_users: tensor<4x6xi8>
- %in2_shared1 = mesh.shard %in2 to <@mesh_1d, [[0]]> : tensor<6x8xi8>
- %in2_shared2 = mesh.shard %in2_shared1 to <@mesh_1d, [[0]]> annotate_for_users: tensor<6x8xi8>
- %dps_out_shared1 = mesh.shard %dps_out to <@mesh_1d, [[]]> : tensor<4x8xi8>
- %dps_out_shared2 = mesh.shard %dps_out_shared1 to <@mesh_1d, [[]]> annotate_for_users: tensor<4x8xi8>
+ %sharding = mesh.sharding @mesh_1d split_axes = [[], [0]] : !mesh.sharding
+ %in1_shared1 = mesh.shard %in1 to %sharding : tensor<4x6xi8>
+ %in1_shared2 = mesh.shard %in1_shared1 to %sharding annotate_for_users : tensor<4x6xi8>
+ %sharding2 = mesh.sharding @mesh_1d split_axes = [[0]] : !mesh.sharding
+ %in2_shared1 = mesh.shard %in2 to %sharding2 : tensor<6x8xi8>
+ %in2_shared2 = mesh.shard %in2_shared1 to %sharding2 annotate_for_users : tensor<6x8xi8>
+ %sharding3 = mesh.sharding @mesh_1d split_axes = [[]] : !mesh.sharding
+ %dps_out_shared1 = mesh.shard %dps_out to %sharding3 : tensor<4x8xi8>
+ %dps_out_shared2 = mesh.shard %dps_out_shared1 to %sharding3 annotate_for_users : tensor<4x8xi8>
// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
// CHECK-DAG: %[[C0_I8:.*]] = arith.constant 0 : i8
// CHECK-DAG: %[[PROCESS_IDX:.*]] = mesh.process_multi_index on @mesh_1d axes = [0] : index
@@ -114,8 +120,8 @@ func.func @matmul_1d_mesh_static_tensors_reduction_iterator_sharding(
// CHECK: %[[ALL_REDUCED:.*]] = mesh.all_reduce %[[SHARDED_MATMUL]] on @mesh_1d mesh_axes = [0] : tensor<4x8xi8> -> tensor<4x8xi8>
%res = linalg.matmul ins(%in1_shared2, %in2_shared2 : tensor<4x6xi8>, tensor<6x8xi8>)
outs(%dps_out_shared2 : tensor<4x8xi8>) -> tensor<4x8xi8>
- %res_shared1 = mesh.shard %res to <@mesh_1d, [[]]> : tensor<4x8xi8>
- %res_shared2 = mesh.shard %res_shared1 to <@mesh_1d, [[]]> annotate_for_users: tensor<4x8xi8>
+ %res_shared1 = mesh.shard %res to %sharding3 : tensor<4x8xi8>
+ %res_shared2 = mesh.shard %res_shared1 to %sharding3 annotate_for_users : tensor<4x8xi8>
// CHECK: return %[[ALL_REDUCED]] : tensor<4x8xi8>
return %res_shared2 : tensor<4x8xi8>
}
@@ -134,12 +140,16 @@ func.func @matmul_1d_mesh_static_tensors_reduction_iterator_sharding_with_partia
%dps_out: tensor<4x8xi8>
// CHECK-SAME: -> tensor<4x8xi8> {
) -> tensor<4x8xi8> {
- %in1_shared1 = mesh.shard %in1 to <@mesh_1d, [[], [0]]> : tensor<4x6xi8>
- %in1_shared2 = mesh.shard %in1_shared1 to <@mesh_1d, [[], [0]]> annotate_for_users: tensor<4x6xi8>
- %in2_shared1 = mesh.shard %in2 to <@mesh_1d, [[0]]> : tensor<6x8xi8>
- %in2_shared2 = mesh.shard %in2_shared1 to <@mesh_1d, [[0]]> annotate_for_users: tensor<6x8xi8>
- %dps_out_shared1 = mesh.shard %dps_out to <@mesh_1d, [[]]> : tensor<4x8xi8>
- %dps_out_shared2 = mesh.shard %dps_out_shared1 to <@mesh_1d, [[]]> annotate_for_users: tensor<4x8xi8>
+ %sharding = mesh.sharding @mesh_1d split_axes = [[], [0]] : !mesh.sharding
+ %in1_shared1 = mesh.shard %in1 to %sharding : tensor<4x6xi8>
+ %in1_shared2 = mesh.shard %in1_shared1 to %sharding annotate_for_users : tensor<4x6xi8>
+ %sharding2 = mesh.sharding @mesh_1d split_axes = [[0]] : !mesh.sharding
+ %in2_shared1 = mesh.shard %in2 to %sharding2 : tensor<6x8xi8>
+ %in2_shared2 = mesh.shard %in2_shared1 to %sharding2 annotate_for_users : tensor<6x8xi8>
+ %sharding3 = mesh.sharding @mesh_1d split_axes = [[]] : !mesh.sharding
+ %dps_out_shared1 = mesh.shard %dps_out to %sharding3 : tensor<4x8xi8>
+ %sdps_out_shared2 = mesh.sharding @mesh_1d split_axes = [[]] : !mesh.sharding
+ %dps_out_shared2 = mesh.shard %dps_out_shared1 to %sharding3 annotate_for_users : tensor<4x8xi8>
// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
// CHECK-DAG: %[[C0_I8:.*]] = arith.constant 0 : i8
// CHECK-DAG: %[[PROCESS_IDX:.*]] = mesh.process_multi_index on @mesh_1d axes = [0] : index
@@ -157,8 +167,9 @@ func.func @matmul_1d_mesh_static_tensors_reduction_iterator_sharding_with_partia
// CHECK-SAME: outs(%[[DPS_INIT_OPERAND]] : tensor<4x8xi8>) -> tensor<4x8xi8>
%res = linalg.matmul ins(%in1_shared2, %in2_shared2 : tensor<4x6xi8>, tensor<6x8xi8>)
outs(%dps_out_shared2 : tensor<4x8xi8>) -> tensor<4x8xi8>
- %res_shared1 = mesh.shard %res to <@mesh_1d, [[]], partial = sum[0]> : tensor<4x8xi8>
- %res_shared2 = mesh.shard %res_shared1 to <@mesh_1d, [[]], partial = sum[0]> annotate_for_users: tensor<4x8xi8>
+ %sharding4 = mesh.sharding @mesh_1d split_axes = [[]] partial = sum[0] : !mesh.sharding
+ %res_shared1 = mesh.shard %res to %sharding4 : tensor<4x8xi8>
+ %res_shared2 = mesh.shard %res_shared1 to %sharding4 annotate_for_users : tensor<4x8xi8>
// CHECK: return %[[SHARDED_MATMUL]] : tensor<4x8xi8>
return %res_shared2 : tensor<4x8xi8>
}
@@ -177,14 +188,16 @@ func.func @matmul_1d_mesh_static_tensors_parallel_iterator_unsplit_last_axis(
%dps_out: tensor<4x8xi8>
// CHECK-SAME: -> tensor<4x8xi8> {
) -> tensor<4x8xi8> {
- %in1_replicated1 = mesh.shard %in1 to <@mesh_1d, [[], []]> : tensor<4x6xi8>
- %in1_replicated2 = mesh.shard %in1_replicated1 to <@mesh_1d, [[], []]> annotate_for_users : tensor<4x6xi8>
+ %sharding1 = mesh.sharding @mesh_1d split_axes = [[], []] : !mesh.sharding
+ %in1_replicated1 = mesh.shard %in1 to %sharding1 : tensor<4x6xi8>
+ %in1_replicated2 = mesh.shard %in1_replicated1 to %sharding1 annotate_for_users : tensor<4x6xi8>
// CHECK: %[[ALL_SLICE1:.*]] = mesh.all_slice %[[IN2]] on @mesh_1d mesh_axes = [0] slice_axis = 1
- %in2_replicated = mesh.shard %in2 to <@mesh_1d, [[], []]> : tensor<6x8xi8>
- %in2_sharded = mesh.shard %in2_replicated to <@mesh_1d, [[], [0]]> annotate_for_users : tensor<6x8xi8>
+ %in2_replicated = mesh.shard %in2 to %sharding1 : tensor<6x8xi8>
+ %sharding2 = mesh.sharding @mesh_1d split_axes = [[], [0]] : !mesh.sharding
+ %in2_sharded = mesh.shard %in2_replicated to %sharding2 annotate_for_users : tensor<6x8xi8>
// CHECK: %[[ALL_SLICE2:.*]] = mesh.all_slice %[[DPS_OUT]] on @mesh_1d mesh_axes = [0] slice_axis = 1
- %dps_out_replicated = mesh.shard %dps_out to <@mesh_1d, [[], []]> : tensor<4x8xi8>
- %dps_out_sharded = mesh.shard %dps_out_replicated to <@mesh_1d, [[], [0]]> annotate_for_users: tensor<4x8xi8>
+ %dps_out_replicated = mesh.shard %dps_out to %sharding1 : tensor<4x8xi8>
+ %dps_out_sharded = mesh.shard %dps_out_replicated to %sharding2 annotate_for_users : tensor<4x8xi8>
// CHECK: %[[MATMUL_RES:.*]] = linalg.matmul
// CHECK-SAME: ins(%[[IN1]], %[[ALL_SLICE1]] : tensor<4x6xi8>, tensor<6x2xi8>)
// CHECK-SAME: outs(%[[ALL_SLICE2]] : tensor<4x2xi8>)
@@ -192,8 +205,8 @@ func.func @matmul_1d_mesh_static_tensors_parallel_iterator_unsplit_last_axis(
%res = linalg.matmul ins(%in1_replicated2, %in2_sharded : tensor<4x6xi8>, tensor<6x8xi8>)
outs(%dps_out_sharded : tensor<4x8xi8>) -> tensor<4x8xi8>
// CHECK: %[[ALL_GATHER:.*]] = mesh.all_gather %[[MATMUL_RES]] on @mesh_1d mesh_axes = [0] gather_axis = 1 : tensor<4x2xi8> -> tensor<4x8xi8>
- %res_sharded = mesh.shard %res to <@mesh_1d, [[], [0]]> : tensor<4x8xi8>
- %res_replicated = mesh.shard %res_sharded to <@mesh_1d, [[], []]> annotate_for_users: tensor<4x8xi8>
+ %res_sharded = mesh.shard %res to %sharding2 : tensor<4x8xi8>
+ %res_replicated = mesh.shard %res_sharded to %sharding1 annotate_for_users : tensor<4x8xi8>
// CHECK: return %[[ALL_GATHER]] : tensor<4x8xi8>
return %res_replicated : tensor<4x8xi8>
}
diff --git a/mlir/test/Dialect/Mesh/canonicalization.mlir b/mlir/test/Dialect/Mesh/canonicalization.mlir
index 633324ae680eb1..ea2bd29056ec78 100644
--- a/mlir/test/Dialect/Mesh/canonicalization.mlir
+++ b/mlir/test/Dialect/Mesh/canonicalization.mlir
@@ -31,7 +31,7 @@ func.func @all_reduce_default_reduction(
%0 = mesh.all_reduce %arg0 on @mesh0
mesh_axes = [0]
// CHECK-NOT: reduction
- reduction = <sum>
+ reduction = sum
: tensor<4xf32> -> tensor<4xf64>
return %0 : tensor<4xf64>
}
@@ -159,7 +159,7 @@ func.func @reduce_scatter_default_reduction(
%0 = mesh.reduce_scatter %arg0 on @mesh0
mesh_axes = [0]
// CHECK-NOT: reduction
- reduction = <sum>
+ reduction = sum
scatter_axis = 0
: tensor<4xf32> -> tensor<2xf64>
return %0 : tensor<2xf64>
diff --git a/mlir/test/Dialect/Mesh/invalid.mlir b/mlir/test/Dialect/Mesh/invalid.mlir
index 6d7df86d78406f..3827df90e6962f 100644
--- a/mlir/test/Dialect/Mesh/invalid.mlir
+++ b/mlir/test/Dialect/Mesh/invalid.mlir
@@ -15,7 +15,8 @@ mesh.mesh @mesh0(shape = 2x4)
func.func @mesh_axis_duplicated_different_subarray(
%arg0 : tensor<4x8xf32>) -> tensor<4x8xf32> {
// expected-error @+1 {{mesh axis duplicated}}
- %0 = mesh.shard %arg0 to <@mesh0, [[0], [0]]> : tensor<4x8xf32>
+ %s = mesh.sharding @mesh0 split_axes = [[0], [0]] : !mesh.sharding
+ %0 = mesh.shard %arg0 to %s : tensor<4x8xf32>
return %0 : tensor<4x8xf32>
}
@@ -26,7 +27,8 @@ mesh.mesh @mesh0(shape = 2x4)
func.func @mesh_axis_duplicated_same_subarray(
%arg0 : tensor<4x8xf32>) -> tensor<4x8xf32> {
// expected-error @+1 {{mesh axis duplicated}}
- %0 = mesh.shard %arg0 to <@mesh0, [[0, 0]]> : tensor<4x8xf32>
+ %s = mesh.sharding @mesh0 split_axes = [[0, 0]] : !mesh.sharding
+ %0 = mesh.shard %arg0 to %s : tensor<4x8xf32>
return %0 : tensor<4x8xf32>
}
@@ -37,7 +39,8 @@ mesh.mesh @mesh0(shape = 2x4)
func.func @mesh_axis_duplicated_bewteen_split_and_partial(
%arg0 : tensor<4x8xf32>) -> tensor<4x8xf32> {
// expected-error @+1 {{mesh axis duplicated}}
- %0 = mesh.shard %arg0 to <@mesh0, [[0]], partial=max[0]> : tensor<4x8xf32>
+ %s = mesh.sharding @mesh0 split_axes = [[0]] partial=max[0] : !mesh.sharding
+ %0 = mesh.shard %arg0 to %s : tensor<4x8xf32>
return %0 : tensor<4x8xf32>
}
@@ -48,7 +51,8 @@ mesh.mesh @mesh0(shape = 2x4)
func.func @mesh_axis_negtive_in_split_part(
%arg0 : tensor<4x8xf32>) -> tensor<4x8xf32> {
// expected-error@+1 {{mesh axis is expected to be non-negative}}
- %0 = mesh.shard %arg0 to <@mesh0, [[-1]]> : tensor<4x8xf32>
+ %s = mesh.sharding @mesh0 split_axes = [[-1]] : !mesh.sharding
+ %0 = mesh.shard %arg0 to %s : tensor<4x8xf32>
return %0 : tensor<4x8xf32>
}
@@ -59,16 +63,46 @@ mesh.mesh @mesh0(shape = 2x4)
func.func @mesh_axis_negtive_in_partial(
%arg0 : tensor<4x8xf32>) -> tensor<4x8xf32> {
// expected-error@+1 {{mesh axis is expected to be non-negative}}
- %0 = mesh.shard %arg0 to <@mesh0, [[0]], partial=max[-1]> : tensor<4x8xf32>
+ %s = mesh.sharding @mesh0 split_axes = [[0]] partial=max[-1] : !mesh.sharding
+ %0 = mesh.shard %arg0 to %s : tensor<4x8xf32>
return %0 : tensor<4x8xf32>
}
// -----
func.func @sharding_attribute_invalid_nested_symbol(%arg0 : tensor<4x8xf32>) {
- // expected-error@+2 {{custom op 'mesh.shard' invalid kind of attribute specified}}
- // expected-error@+1 {{custom op 'mesh.shard' failed to parse MeshSharding parameter 'mesh' which is to be a `::mlir::FlatSymbolRefAttr`}}
- %0 = mesh.shard %arg0 to <@a::@b, [[0]]> : tensor<4x8xf32>
+ // expected-error@+1 {{custom op 'mesh.sharding' invalid kind of attribute specified}}
+ %s = mesh.sharding @a::@b split_axes = [[0]] : !mesh.sharding
+ %0 = mesh.shard %arg0 to %s : tensor<4x8xf32>
+ return
+}
+
+// -----
+
+func.func @sharding_attribute_invalid_halo(%arg0 : tensor<4x8xf32>) {
+ // expected-error@+1 {{halo sizes must be specified for all split axes}}
+ %s = mesh.sharding @mesh0 split_axes = [[0], [1]] halo_sizes = [1, 2] : !mesh.sharding
+ %0 = mesh.shard %arg0 to %s : tensor<4x8xf32>
+ return
+}
+
+// -----
+
+func.func @sharding_attribute_invalid_sizes(%arg0 : tensor<4x8xf32>) {
+ // expected-error@+1 {{halo sizes and shard shapes are mutually exclusive}}
+ %s = mesh.sharding @mesh0 split_axes = [[0]] halo_sizes = [1, 2] sharded_dims_sizes = [2, 2] : !mesh.sharding
+ %0 = mesh.shard %arg0 to %s : tensor<4x8xf32>
+ return
+}
+
+// -----
+
+mesh.mesh @mesh_dyn(shape = ?x?)
+func.func @sharding_dyn_mesh_and_sizes(%arg0 : tensor<4x8xf32>) {
+ // expected-error@+1 {{sharded dims sizes are not allowed for devices meshes with dynamic shape}}
+ %s = mesh.sharding @mesh_dyn split_axes = [[0]] sharded_dims_sizes = [2, 2] : !mesh.sharding
+ %0 = mesh.shard %arg0 to %s : tensor<4x8xf32>
+ return
}
// -----
@@ -180,7 +214,7 @@ func.func @process_linear_index_invalid_mesh_name() -> (index) {
func.func @all_reduce_invalid_mesh_symbol(
%arg0 : tensor<4xf32>) -> tensor<4xf64> {
// expected-error@+1 {{Undefined required mesh symbol "this_mesh_symbol_does_not_exist".}}
- %0 = mesh.all_reduce %arg0 on @this_mesh_symbol_does_not_exist reduction = <sum>
+ %0 = mesh.all_reduce %arg0 on @this_mesh_symbol_does_not_exist reduction = sum
: tensor<4xf32> -> tensor<4xf64>
return %0 : tensor<4xf64>
}
@@ -192,7 +226,7 @@ mesh.mesh @mesh0(shape = 2x4)
func.func @all_reduce_invalid_mesh_axis(
%arg0 : tensor<4xf32>) -> tensor<4xf64> {
// expected-error@+1 {{0-based mesh axis index 2 is out of bounds. The referenced mesh "mesh0" is of rank 2.}}
- %0 = mesh.all_reduce %arg0 on @mesh0 mesh_axes = [2] reduction = <sum>
+ %0 = mesh.all_reduce %arg0 on @mesh0 mesh_axes = [2] reduction = sum
: tensor<4xf32> -> tensor<4xf64>
return %0 : tensor<4xf64>
}
@@ -204,7 +238,7 @@ mesh.mesh @mesh0(shape = 2x4)
func.func @all_reduce_duplicate_mesh_axis(
%arg0 : tensor<4xf32>) -> tensor<4xf64> {
// expected-error@+1 {{Mesh axes contains duplicate elements.}}
- %0 = mesh.all_reduce %arg0 on @mesh0 mesh_axes = [0, 1, 0] reduction = <sum>
+ %0 = mesh.all_reduce %arg0 on @mesh0 mesh_axes = [0, 1, 0] reduction = sum
: tensor<4xf32> -> tensor<4xf64>
return %0 : tensor<4xf64>
}
diff --git a/mlir/test/Dialect/Mesh/ops.mlir b/mlir/test/Dialect/Mesh/ops.mlir
index 6e5df86b13106a..5ead7babe2c084 100644
--- a/mlir/test/Dialect/Mesh/ops.mlir
+++ b/mlir/test/Dialect/Mesh/ops.mlir
@@ -20,24 +20,30 @@ mesh.mesh @mesh5(shape = ?)
// CHECK-LABEL: func @mesh_shard_op_fully_replicated
// CHECK-SAME: %[[ARG:.*]]: tensor<4x8xf32>
func.func @mesh_shard_op_fully_replicated(%arg0 : tensor<4x8xf32>) -> tensor<4x8xf32> {
- // CHECK-NEXT: mesh.shard %[[ARG]] to <@mesh0, {{\[\[}}]]> : tensor<4x8xf32>
- %0 = mesh.shard %arg0 to <@mesh0, [[]]> : tensor<4x8xf32>
+ // CHECK-NEXT: %[[S:.*]] = mesh.sharding @mesh0 split_axes = {{\[\[}}]] : !mesh.sharding
+ %s = mesh.sharding @mesh0 split_axes = [[]] : !mesh.sharding
+ // CHECK-NEXT: mesh.shard %[[ARG]] to %[[S]] : tensor<4x8xf32>
+ %0 = mesh.shard %arg0 to %s : tensor<4x8xf32>
return %0 : tensor<4x8xf32>
}
// CHECK-LABEL: func @mesh_shard_op_1st_dim
// CHECK-SAME: %[[ARG:.*]]: tensor<4x8xf32>
func.func @mesh_shard_op_1st_dim(%arg0 : tensor<4x8xf32>) -> tensor<4x8xf32> {
- // CHECK-NEXT: mesh.shard %[[ARG]] to <@mesh0, {{\[\[}}0]]> : tensor<4x8xf32>
- %0 = mesh.shard %arg0 to <@mesh0, [[0]]> : tensor<4x8xf32>
+ // CHECK-NEXT: %[[S:.*]] = mesh.sharding @mesh0 split_axes = {{\[\[}}0]] : !mesh.sharding
+ %s = mesh.sharding @mesh0 split_axes = [[0]] : !mesh.sharding
+
+ %0 = mesh.shard %arg0 to %s : tensor<4x8xf32>
return %0 : tensor<4x8xf32>
}
// CHECK-LABEL: func @mesh_shard_op_2nd_dim
// CHECK-SAME: %[[ARG:.*]]: tensor<4x8xf32>
func.func @mesh_shard_op_2nd_dim(%arg0 : tensor<4x8xf32>) -> tensor<4x8xf32> {
- // CHECK-NEXT: mesh.shard %[[ARG]] to <@mesh1, {{\[\[}}], [0]]> : tensor<4x8xf32>
- %0 = mesh.shard %arg0 to <@mesh1, [[], [0]]> : tensor<4x8xf32>
+ // CHECK-NEXT: %[[S:.*]] = mesh.sharding @mesh1 split_axes = {{\[\[}}], [0]] : !mesh.sharding
+ %s = mesh.sharding @mesh1 split_axes = [[], [0]] : !mesh.sharding
+ // CHECK-NEXT: mesh.shard %[[ARG]] to %[[S]] : tensor<4x8xf32>
+ %0 = mesh.shard %arg0 to %s : tensor<4x8xf32>
return %0 : tensor<4x8xf32>
}
@@ -45,8 +51,10 @@ func.func @mesh_shard_op_2nd_dim(%arg0 : tensor<4x8xf32>) -> tensor<4x8xf32> {
func.func @mesh_shard_op_1st_and_3rd_dim(
// CHECK-SAME: %[[ARG:.*]]: tensor<4x8x16xf32>
%arg0 : tensor<4x8x16xf32>) -> tensor<4x8x16xf32> {
- // CHECK-NEXT: mesh.shard %[[ARG]] to <@mesh3, {{\[\[}}0], [], [1]]> : tensor<4x8x16xf32>
- %0 = mesh.shard %arg0 to <@mesh3, [[0], [], [1]]> : tensor<4x8x16xf32>
+ // CHECK-NEXT: %[[S:.*]] = mesh.sharding @mesh3 split_axes = {{\[\[}}0], [], [1]] : !mesh.sharding
+ %s = mesh.sharding @mesh3 split_axes = [[0], [], [1]] : !mesh.sharding
+ // CHECK-NEXT: mesh.shard %[[ARG]] to %[[S]] : tensor<4x8x16xf32>
+ %0 = mesh.shard %arg0 to %s : tensor<4x8x16xf32>
return %0 : tensor<4x8x16xf32>
}
@@ -54,8 +62,10 @@ func.func @mesh_shard_op_1st_and_3rd_dim(
func.func @mesh_shard_op_partial_max(
// CHECK-SAME: %[[ARG:.*]]: tensor<4x8xf32>
%arg0 : tensor<4x8xf32>) -> tensor<4x8xf32> {
- // CHECK-NEXT: mesh.shard %[[ARG]] to <@mesh3, {{\[\[}}0]], partial = max[1]> : tensor<4x8xf32>
- %0 = mesh.shard %arg0 to <@mesh3, [[0]], partial = max[1]> : tensor<4x8xf32>
+ // CHECK-NEXT: %[[S:.*]] = mesh.sharding @mesh3 split_axes = {{\[\[}}0]] partial = max [1] : !mesh.sharding
+ %s = mesh.sharding @mesh3 split_axes = [[0]] partial = max[1] : !mesh.sharding
+ // CHECK-NEXT: mesh.shard %[[ARG]] to %[[S]] : tensor<4x8xf32>
+ %0 = mesh.shard %arg0 to %s : tensor<4x8xf32>
return %0 : tensor<4x8xf32>
}
@@ -63,8 +73,10 @@ func.func @mesh_shard_op_partial_max(
func.func @mesh_shard_op_partial_min(
// CHECK-SAME: %[[ARG:.*]]: tensor<4x8xf32>
%arg0 : tensor<4x8xf32>) -> tensor<4x8xf32> {
- // CHECK-NEXT: mesh.shard %[[ARG]] to <@mesh3, {{\[\[}}0]], partial = min[1]> : tensor<4x8xf32>
- %0 = mesh.shard %arg0 to <@mesh3, [[0]], partial = min[1]> : tensor<4x8xf32>
+ // CHECK-NEXT: %[[S:.*]] = mesh.sharding @mesh3 split_axes = {{\[\[}}0]] partial = min [1] : !mesh.sharding
+ %s = mesh.sharding @mesh3 split_axes = [[0]] partial = min[1] : !mesh.sharding
+ // CHECK-NEXT: mesh.shard %[[ARG]] to %[[S]] : tensor<4x8xf32>
+ %0 = mesh.shard %arg0 to %s : tensor<4x8xf32>
return %0 : tensor<4x8xf32>
}
@@ -72,8 +84,10 @@ func.func @mesh_shard_op_partial_min(
func.func @mesh_shard_op_partial_generic(
// CHECK-SAME: %[[ARG:.*]]: tensor<4x8xf32>
%arg0 : tensor<4x8xf32>) -> tensor<4x8xf32> {
- // CHECK-NEXT: mesh.shard %[[ARG]] to <@mesh3, {{\[\[}}0]], partial = generic[1]> : tensor<4x8xf32>
- %0 = mesh.shard %arg0 to <@mesh3, [[0]], partial = generic[1]> : tensor<4x8xf32>
+ // CHECK-NEXT: %[[S:.*]] = mesh.sharding @mesh3 split_axes = {{\[\[}}0]] partial = generic [1] : !mesh.sharding
+ %s = mesh.sharding @mesh3 split_axes = [[0]] partial = generic[1] : !mesh.sharding
+ // CHECK-NEXT: mesh.shard %[[ARG]] to %[[S]] : tensor<4x8xf32>
+ %0 = mesh.shard %arg0 to %s : tensor<4x8xf32>
return %0 : tensor<4x8xf32>
}
@@ -81,8 +95,10 @@ func.func @mesh_shard_op_partial_generic(
func.func @mesh_shard_op_partial_sum(
// CHECK-SAME: %[[ARG:.*]]: tensor<4x8xf32>
%arg0 : tensor<4x8xf32>) -> tensor<4x8xf32> {
- // CHECK-NEXT: mesh.shard %[[ARG]] to <@mesh3, {{\[\[}}0]], partial = sum[1]> : tensor<4x8xf32>
- %0 = mesh.shard %arg0 to <@mesh3, [[0]], partial = sum[1]> : tensor<4x8xf32>
+ // CHECK-NEXT: %[[S:.*]] = mesh.sharding @mesh3 split_axes = {{\[\[}}0]] partial = sum [1] : !mesh.sharding
+ %s = mesh.sharding @mesh3 split_axes = [[0]] partial = sum[1] : !mesh.sharding
+ // CHECK-NEXT: mesh.shard %[[ARG]] to %[[S]] : tensor<4x8xf32>
+ %0 = mesh.shard %arg0 to %s : tensor<4x8xf32>
return %0 : tensor<4x8xf32>
}
@@ -90,24 +106,64 @@ func.func @mesh_shard_op_partial_sum(
func.func @mesh_shard_op_partial_sum_multi_axes(
// CHECK-SAME: %[[ARG:.*]]: tensor<4x8xf32>
%arg0 : tensor<4x8xf32>) -> tensor<4x8xf32> {
- // CHECK-NEXT: mesh.shard %[[ARG]] to <@mesh3, {{\[\[}}0]], partial = sum[1, 2]> : tensor<4x8xf32>
- %0 = mesh.shard %arg0 to <@mesh3, [[0]], partial = sum[1, 2]> : tensor<4x8xf32>
+ // CHECK-NEXT: %[[S:.*]] = mesh.sharding @mesh3 split_axes = {{\[\[}}0]] partial = sum [1, 2] : !mesh.sharding
+ %s = mesh.sharding @mesh3 split_axes = [[0]] partial = sum[1, 2] : !mesh.sharding
+ // CHECK-NEXT: mesh.shard %[[ARG]] to %[[S]] : tensor<4x8xf32>
+ %0 = mesh.shard %arg0 to %s : tensor<4x8xf32>
return %0 : tensor<4x8xf32>
}
// CHECK-LABEL: func @mesh_shard_op_two_users
// CHECK-SAME: %[[ARG:.*]]: tensor<4x8xf32>
-func.func @mesh_shard_op_two_users(%arg0 : tensor<4x8xf32>) ->
+func.func @mesh_shard_op_two_users(%arg0 : tensor<4x8xf32>) ->
(tensor<4x8xf32>, tensor<4x8xf32>) {
- // CHECK-NEXT: %[[V0:.*]] = mesh.shard %[[ARG]] to <@mesh0, {{\[\[}}0]]> : tensor<4x8xf32>
- %0 = mesh.shard %arg0 to <@mesh0, [[0]]> : tensor<4x8xf32>
- // CHECK-DAG: mesh.shard %[[V0]] to <@mesh0, {{\[\[}}1]]> annotate_for_users : tensor<4x8xf32>
- %1 = mesh.shard %0 to <@mesh0, [[1]]> annotate_for_users : tensor<4x8xf32>
- // CHECK-DAG: mesh.shard %[[V0]] to <@mesh0, {{\[\[}}2]]> annotate_for_users : tensor<4x8xf32>
- %2 = mesh.shard %0 to <@mesh0, [[2]]> annotate_for_users : tensor<4x8xf32>
+ // CHECK-NEXT: %[[V0:.*]] = mesh.sharding @mesh0 split_axes = {{\[\[}}0]] : !mesh.sharding
+ %s0 = mesh.sharding @mesh0 split_axes = [[0]] : !mesh.sharding
+ %0 = mesh.shard %arg0 to %s0 : tensor<4x8xf32>
+ // CHECK-DAG: mesh.sharding @mesh0 split_axes = {{\[\[}}1]] : !mesh.sharding
+ %s1 = mesh.sharding @mesh0 split_axes = [[1]] : !mesh.sharding
+ %1 = mesh.shard %0 to %s1 annotate_for_users : tensor<4x8xf32>
+ // CHECK-DAG: mesh.sharding @mesh0 split_axes = {{\[\[}}2]] : !mesh.sharding
+ %s2 = mesh.sharding @mesh0 split_axes = [[2]] : !mesh.sharding
+ %2 = mesh.shard %0 to %s2 annotate_for_users : tensor<4x8xf32>
return %1, %2 : tensor<4x8xf32>, tensor<4x8xf32>
}
+// CHECK-LABEL: func @mesh_shard_halo_sizes
+func.func @mesh_shard_halo_sizes() -> () {
+ // CHECK: %[[C3:.*]] = arith.constant 3 : i64
+ %c3 = arith.constant 3 : i64
+ // CHECK: mesh.sharding @mesh4 split_axes = {{\[\[}}0]] halo_sizes = [1, 4] : !mesh.sharding
+ %sharding1 = mesh.sharding @mesh4 split_axes = [[0]] halo_sizes = [1, 4] : !mesh.sharding
+ // CHECK: mesh.sharding @mesh4 split_axes = {{\[\[}}0]] halo_sizes = [4, %[[C3]]] : !mesh.sharding
+ %sharding2 = mesh.sharding @mesh4 split_axes = [[0]] halo_sizes = [4, %c3] : !mesh.sharding
+ return
+}
+
+// CHECK-LABEL: func @mesh_shard_dims_sizes
+func.func @mesh_shard_dims_sizes() -> () {
+ // CHECK: %[[C3:.*]] = arith.constant 3 : i64
+ %c3 = arith.constant 3 : i64
+ // CHECK: mesh.sharding @mesh4 split_axes = {{\[\[}}0]] sharded_dims_sizes = [1, 4, 2] : !mesh.sharding
+ %sharding1 = mesh.sharding @mesh4 split_axes = [[0]] sharded_dims_sizes = [1, 4, 2] : !mesh.sharding
+ // CHECK: mesh.sharding @mesh4 split_axes = {{\[\[}}0]] sharded_dims_sizes = [4, %[[C3]], 1] : !mesh.sharding
+ %sharding2 = mesh.sharding @mesh4 split_axes = [[0]] sharded_dims_sizes = [4, %c3, 1] : !mesh.sharding
+ return
+}
+
+// CHECK-LABEL: func @mesh_shard_shape
+func.func @mesh_shard_shape() {
+ // CHECK: %[[C3:.*]] = arith.constant 3 : index
+ %c3 = arith.constant 3 : index
+ // CHECK-NEXT: %[[S:.*]] = mesh.sharding @mesh0 split_axes = {{\[\[}}]] : !mesh.sharding
+ %s = mesh.sharding @mesh0 split_axes = [[]] : !mesh.sharding
+ // CHECK-NEXT: mesh.shard_shape 8x? %[[S]] %[[C3]] : index, index
+ %shp:2 = mesh.shard_shape 8x? %s %c3 : index, index
+ // CHECK-NEXT: mesh.shard_shape 8x4 %[[S]] %[[C3]] : index, index
+ %shp1:2 = mesh.shard_shape 8x4 %s %c3 : index, index
+ return
+}
+
// CHECK-LABEL: func @mesh_shape
func.func @mesh_shape() -> (index, index) {
// CHECK: %[[RES:.*]]:2 = mesh.mesh_shape @mesh0 axes = [0, 1] : index, index
@@ -168,9 +224,9 @@ func.func @process_linear_index() -> index {
func.func @all_reduce(
// CHECK-SAME: %[[ARG:.*]]: tensor<3x4xf32>
%arg0 : tensor<3x4xf32>) -> tensor<3x4xf64> {
- // CHECK-NEXT: mesh.all_reduce %[[ARG]] on @mesh0 mesh_axes = [1, 0] reduction = <max>
+ // CHECK-NEXT: mesh.all_reduce %[[ARG]] on @mesh0 mesh_axes = [1, 0] reduction = max
// CHECK-SAME: : tensor<3x4xf32> -> tensor<3x4xf64>
- %0 = mesh.all_reduce %arg0 on @mesh0 mesh_axes = [1, 0] reduction = <max>
+ %0 = mesh.all_reduce %arg0 on @mesh0 mesh_axes = [1, 0] reduction = max
: tensor<3x4xf32> -> tensor<3x4xf64>
return %0 : tensor<3x4xf64>
}
@@ -442,10 +498,10 @@ func.func @reduce_scatter_static_dimensions(
// CHECK-SAME: %[[ARG:.*]]: tensor<3x4xf32>
%arg0 : tensor<3x4xf32>) -> tensor<3x1xf64> {
// CHECK-NEXT: mesh.reduce_scatter %[[ARG]]
- // CHECK-SAME: on @mesh0 mesh_axes = [2] reduction = <max> scatter_axis = 1
+ // CHECK-SAME: on @mesh0 mesh_axes = [2] reduction = max scatter_axis = 1
// CHECK-SAME: : tensor<3x4xf32> -> tensor<3x1xf64>
%0 = mesh.reduce_scatter %arg0 on @mesh0 mesh_axes = [2]
- reduction = <max> scatter_axis = 1
+ reduction = max scatter_axis = 1
: tensor<3x4xf32> -> tensor<3x1xf64>
return %0 : tensor<3x1xf64>
}
@@ -553,3 +609,24 @@ func.func @shift(
: tensor<2xi8> -> tensor<2xi8>
return %0 : tensor<2xi8>
}
+
+// CHECK-LABEL: func @update_halo
+func.func @update_halo(
+ // CHECK-SAME: %[[ARG:.*]]: memref<12x12xi8>
+ %arg0 : memref<12x12xi8>) {
+ // CHECK-NEXT: %[[C2:.*]] = arith.constant 2 : i64
+ // CHECK-NEXT: mesh.update_halo %[[ARG]] on @mesh0
+ // CHECK-SAME: split_axes = {{\[\[}}0]]
+ // CHECK-SAME: halo_sizes = [2, %c2_i64] : memref<12x12xi8>
+ %c2 = arith.constant 2 : i64
+ mesh.update_halo %arg0 on @mesh0 split_axes = [[0]]
+ halo_sizes = [2, %c2] : memref<12x12xi8>
+ // CHECK-NEXT: mesh.update_halo %[[ARG]] on @mesh0
+ // CHECK-SAME: split_axes = {{\[\[}}0], [1]]
+ // CHECK-SAME: halo_sizes = [2, 2, %[[C2]], 2]
+ // CHECK-SAME: target_halo_sizes = [3, 3, 2, 2] : memref<12x12xi8>
+ mesh.update_halo %arg0 on @mesh0 split_axes = [[0], [1]]
+ halo_sizes = [2, 2, %c2, 2] target_halo_sizes = [3, 3, 2, 2]
+ : memref<12x12xi8>
+ return
+}
diff --git a/mlir/test/Dialect/Mesh/resharding-spmdization.mlir b/mlir/test/Dialect/Mesh/resharding-spmdization.mlir
index b3e305135ad8b7..9ceaadacd6f664 100644
--- a/mlir/test/Dialect/Mesh/resharding-spmdization.mlir
+++ b/mlir/test/Dialect/Mesh/resharding-spmdization.mlir
@@ -8,8 +8,22 @@ func.func @same_source_and_target_sharding(
// CHECK-SAME: %[[ARG:.*]]: tensor<2xf32>
%arg0: tensor<2xf32>
) -> tensor<2xf32> {
- %0 = mesh.shard %arg0 to <@mesh_1d, [[]]> : tensor<2xf32>
- %1 = mesh.shard %0 to <@mesh_1d, [[]]> annotate_for_users : tensor<2xf32>
+ %s0 = mesh.sharding @mesh_1d split_axes = [[]] : !mesh.sharding
+ %0 = mesh.shard %arg0 to %s0 : tensor<2xf32>
+ %s1 = mesh.sharding @mesh_1d split_axes = [[]] : !mesh.sharding
+ %1 = mesh.shard %0 to %s1 annotate_for_users : tensor<2xf32>
+ // CHECK: return %[[ARG]]
+ return %1 : tensor<2xf32>
+}
+
+// CHECK-LABEL: func @identical_source_and_target_sharding
+func.func @identical_source_and_target_sharding(
+ // CHECK-SAME: %[[ARG:.*]]: tensor<2xf32>
+ %arg0: tensor<2xf32>
+) -> tensor<2xf32> {
+ %s0 = mesh.sharding @mesh_1d split_axes = [[]] : !mesh.sharding
+ %0 = mesh.shard %arg0 to %s0 : tensor<2xf32>
+ %1 = mesh.shard %0 to %s0 annotate_for_users : tensor<2xf32>
// CHECK: return %[[ARG]]
return %1 : tensor<2xf32>
}
@@ -22,8 +36,10 @@ func.func @split_replicated_tensor_axis(
// CHECK: %[[ALL_SLICE:.*]] = mesh.all_slice %[[ARG]] on @mesh_1d mesh_axes = [0] slice_axis = 1
// CHECK-SAME: tensor<3x14xf32> -> tensor<3x7xf32>
// CHECK: %[[RESULT:.*]] = builtin.unrealized_conversion_cast %[[ALL_SLICE]] : tensor<3x7xf32> to tensor<3x14xf32>
- %0 = mesh.shard %arg0 to <@mesh_1d, [[]]> : tensor<3x14xf32>
- %1 = mesh.shard %0 to <@mesh_1d, [[], [0]]> annotate_for_users : tensor<3x14xf32>
+ %s0 = mesh.sharding @mesh_1d split_axes = [[]] : !mesh.sharding
+ %0 = mesh.shard %arg0 to %s0 : tensor<3x14xf32>
+ %s1 = mesh.sharding @mesh_1d split_axes = [[], [0]] : !mesh.sharding
+ %1 = mesh.shard %0 to %s1 annotate_for_users : tensor<3x14xf32>
// CHECK: return %[[RESULT]] : tensor<3x14xf32>
return %1 : tensor<3x14xf32>
}
@@ -35,8 +51,10 @@ func.func @split_replicated_tensor_axis_dynamic(
) -> tensor<?x3x?xf32> {
// CHECK: %[[RESULT:.*]] = mesh.all_slice %[[ARG]] on @mesh_1d_dynamic mesh_axes = [0] slice_axis = 0
// CHECK-SAME: tensor<?x3x?xf32> -> tensor<?x3x?xf32>
- %0 = mesh.shard %arg0 to <@mesh_1d_dynamic, [[], [], []]> : tensor<?x3x?xf32>
- %1 = mesh.shard %0 to <@mesh_1d_dynamic, [[0]]> annotate_for_users : tensor<?x3x?xf32>
+ %s0 = mesh.sharding @mesh_1d_dynamic split_axes = [[], [], []] : !mesh.sharding
+ %0 = mesh.shard %arg0 to %s0 : tensor<?x3x?xf32>
+ %s1 = mesh.sharding @mesh_1d_dynamic split_axes = [[0]] : !mesh.sharding
+ %1 = mesh.shard %0 to %s1 annotate_for_users : tensor<?x3x?xf32>
// CHECK: return %[[RESULT]] : tensor<?x3x?xf32>
return %1 : tensor<?x3x?xf32>
}
@@ -49,8 +67,10 @@ func.func @move_split_axis(
// CHECK: %[[SOURCE_SHARD:.*]] = builtin.unrealized_conversion_cast %[[ARG]] : tensor<10x14xf32> to tensor<5x14xf32>
// CHECK: %[[TARGET_SHARD:.*]] = mesh.all_to_all %[[SOURCE_SHARD]] on @mesh_1d mesh_axes = [0] split_axis = 1 concat_axis = 0 : tensor<5x14xf32> -> tensor<10x7xf32>
// CHECK: %[[RES:.*]] = builtin.unrealized_conversion_cast %[[TARGET_SHARD]] : tensor<10x7xf32> to tensor<10x14xf32>
- %0 = mesh.shard %arg0 to <@mesh_1d, [[0]]> : tensor<10x14xf32>
- %1 = mesh.shard %0 to <@mesh_1d, [[], [0]]> annotate_for_users : tensor<10x14xf32>
+ %s0 = mesh.sharding @mesh_1d split_axes = [[0]] : !mesh.sharding
+ %0 = mesh.shard %arg0 to %s0 : tensor<10x14xf32>
+ %s1 = mesh.sharding @mesh_1d split_axes = [[], [0]] : !mesh.sharding
+ %1 = mesh.shard %0 to %s1 annotate_for_users : tensor<10x14xf32>
// CHECK: return %[[RES]] : tensor<10x14xf32>
return %1 : tensor<10x14xf32>
}
@@ -64,8 +84,10 @@ func.func @move_split_axis_dynamic_mesh(
// CHECK: %[[ALL_TO_ALL:.*]] = mesh.all_to_all %[[SOURCE_SHARD]] on @mesh_1d_dynamic mesh_axes = [0] split_axis = 1 concat_axis = 0 : tensor<?x14xf32> -> tensor<?x?xf32>
// CHECK: %[[TARGET_SHARD:.*]] = tensor.cast %[[ALL_TO_ALL]] : tensor<?x?xf32> to tensor<10x?xf32>
// CHECK: %[[RES:.*]] = builtin.unrealized_conversion_cast %[[TARGET_SHARD]] : tensor<10x?xf32> to tensor<10x14xf32>
- %0 = mesh.shard %arg0 to <@mesh_1d_dynamic, [[0]]> : tensor<10x14xf32>
- %1 = mesh.shard %0 to <@mesh_1d_dynamic, [[], [0]]> annotate_for_users : tensor<10x14xf32>
+ %s0 = mesh.sharding @mesh_1d_dynamic split_axes = [[0]] : !mesh.sharding
+ %0 = mesh.shard %arg0 to %s0 : tensor<10x14xf32>
+ %s1 = mesh.sharding @mesh_1d_dynamic split_axes = [[], [0]] : !mesh.sharding
+ %1 = mesh.shard %0 to %s1 annotate_for_users : tensor<10x14xf32>
// CHECK: return %[[RES]] : tensor<10x14xf32>
return %1 : tensor<10x14xf32>
}
@@ -77,8 +99,10 @@ func.func @move_split_dynamic_axis(
) -> tensor<?x14xf32> {
// CHECK: %[[TARGET_SHARD:.*]] = mesh.all_to_all %[[ARG]] on @mesh_1d mesh_axes = [0] split_axis = 1 concat_axis = 0 : tensor<?x14xf32> -> tensor<?x7xf32>
// CHECK: %[[RES:.*]] = builtin.unrealized_conversion_cast %[[TARGET_SHARD]] : tensor<?x7xf32> to tensor<?x14xf32>
- %0 = mesh.shard %arg0 to <@mesh_1d, [[0]]> : tensor<?x14xf32>
- %1 = mesh.shard %0 to <@mesh_1d, [[], [0]]> annotate_for_users : tensor<?x14xf32>
+ %s0 = mesh.sharding @mesh_1d split_axes = [[0]] : !mesh.sharding
+ %0 = mesh.shard %arg0 to %s0 : tensor<?x14xf32>
+ %s1 = mesh.sharding @mesh_1d split_axes = [[], [0]] : !mesh.sharding
+ %1 = mesh.shard %0 to %s1 annotate_for_users : tensor<?x14xf32>
// CHECK: return %[[RES]] : tensor<?x14xf32>
return %1 : tensor<?x14xf32>
}
@@ -90,8 +114,10 @@ func.func @unshard_static_axis(
) -> tensor<10x14xf32> {
// CHECK: %[[SOURCE_SHARD:.*]] = builtin.unrealized_conversion_cast %[[ARG]] : tensor<10x14xf32> to tensor<5x14xf32>
// CHECK: %[[ALL_GATHER:.*]] = mesh.all_gather %[[SOURCE_SHARD]] on @mesh_1d mesh_axes = [0] gather_axis = 0 : tensor<5x14xf32> -> tensor<10x14xf32>
- %0 = mesh.shard %arg0 to <@mesh_1d, [[0]]> : tensor<10x14xf32>
- %1 = mesh.shard %0 to <@mesh_1d, [[]]> annotate_for_users : tensor<10x14xf32>
+ %s0 = mesh.sharding @mesh_1d split_axes = [[0]] : !mesh.sharding
+ %0 = mesh.shard %arg0 to %s0 : tensor<10x14xf32>
+ %s1 = mesh.sharding @mesh_1d split_axes = [[]] : !mesh.sharding
+ %1 = mesh.shard %0 to %s1 annotate_for_users : tensor<10x14xf32>
// CHECK: return %[[ALL_GATHER]] : tensor<10x14xf32>
return %1 : tensor<10x14xf32>
}
@@ -103,8 +129,10 @@ func.func @unshard_static_last_axis(
) -> tensor<10x14xf32> {
// CHECK: %[[SOURCE_SHARD:.*]] = builtin.unrealized_conversion_cast %[[ARG]] : tensor<10x14xf32> to tensor<10x7xf32>
// CHECK: %[[ALL_GATHER:.*]] = mesh.all_gather %[[SOURCE_SHARD]] on @mesh_1d mesh_axes = [0] gather_axis = 1 : tensor<10x7xf32> -> tensor<10x14xf32>
- %0 = mesh.shard %arg0 to <@mesh_1d, [[], [0]]> : tensor<10x14xf32>
- %1 = mesh.shard %0 to <@mesh_1d, [[], []]> annotate_for_users : tensor<10x14xf32>
+ %s0 = mesh.sharding @mesh_1d split_axes = [[], [0]] : !mesh.sharding
+ %0 = mesh.shard %arg0 to %s0 : tensor<10x14xf32>
+ %s1 = mesh.sharding @mesh_1d split_axes = [[], []] : !mesh.sharding
+ %1 = mesh.shard %0 to %s1 annotate_for_users : tensor<10x14xf32>
// CHECK: return %[[ALL_GATHER]] : tensor<10x14xf32>
return %1 : tensor<10x14xf32>
}
@@ -115,8 +143,10 @@ func.func @unshard_dynamic_axis(
%arg0: tensor<?x14xf32>
) -> tensor<?x14xf32> {
// CHECK: %[[ALL_GATHER:.*]] = mesh.all_gather %[[ARG]] on @mesh_1d mesh_axes = [0] gather_axis = 0 : tensor<?x14xf32> -> tensor<?x14xf32>
- %0 = mesh.shard %arg0 to <@mesh_1d, [[0]]> : tensor<?x14xf32>
- %1 = mesh.shard %0 to <@mesh_1d, [[]]> annotate_for_users : tensor<?x14xf32>
+ %s0 = mesh.sharding @mesh_1d split_axes = [[0]] : !mesh.sharding
+ %0 = mesh.shard %arg0 to %s0 : tensor<?x14xf32>
+ %s1 = mesh.sharding @mesh_1d split_axes = [[]] : !mesh.sharding
+ %1 = mesh.shard %0 to %s1 annotate_for_users : tensor<?x14xf32>
// CHECK: return %[[ALL_GATHER]] : tensor<?x14xf32>
return %1 : tensor<?x14xf32>
}
@@ -129,8 +159,10 @@ func.func @unshard_static_axis_on_dynamic_mesh_axis(
// CHECK: %[[SOURCE_SHARD:.*]] = builtin.unrealized_conversion_cast %[[ARG]] : tensor<10x14xf32> to tensor<?x14xf32>
// CHECK: %[[ALL_GATHER:.*]] = mesh.all_gather %[[SOURCE_SHARD]] on @mesh_1d_dynamic mesh_axes = [0] gather_axis = 0 : tensor<?x14xf32> -> tensor<?x14xf32>
// CHECK: %[[RES:.*]] = tensor.cast %[[ALL_GATHER]] : tensor<?x14xf32> to tensor<10x14xf32>
- %0 = mesh.shard %arg0 to <@mesh_1d_dynamic, [[0]]> : tensor<10x14xf32>
- %1 = mesh.shard %0 to <@mesh_1d_dynamic, [[]]> annotate_for_users : tensor<10x14xf32>
+ %s0 = mesh.sharding @mesh_1d_dynamic split_axes = [[0]] : !mesh.sharding
+ %0 = mesh.shard %arg0 to %s0 : tensor<10x14xf32>
+ %s1 = mesh.sharding @mesh_1d_dynamic split_axes = [[]] : !mesh.sharding
+ %1 = mesh.shard %0 to %s1 annotate_for_users : tensor<10x14xf32>
// CHECK: return %[[RES]] : tensor<10x14xf32>
return %1 : tensor<10x14xf32>
}
@@ -141,8 +173,10 @@ func.func @partial_axis_to_full_replication(
%arg0: tensor<10x14xf32>
) -> tensor<10x14xf32> {
// CHECK: %[[ALL_REDUCE:.*]] = mesh.all_reduce %[[ARG]] on @mesh_1d mesh_axes = [0] : tensor<10x14xf32> -> tensor<10x14xf32>
- %0 = mesh.shard %arg0 to <@mesh_1d, [[]], partial = sum[0]> : tensor<10x14xf32>
- %1 = mesh.shard %0 to <@mesh_1d, [[]]> annotate_for_users : tensor<10x14xf32>
+ %s0 = mesh.sharding @mesh_1d split_axes = [[]] partial = sum[0] : !mesh.sharding
+ %0 = mesh.shard %arg0 to %s0 : tensor<10x14xf32>
+ %s1 = mesh.sharding @mesh_1d split_axes = [[]] : !mesh.sharding
+ %1 = mesh.shard %0 to %s1 annotate_for_users : tensor<10x14xf32>
// CHECK: %[[ALL_REDUCE]] : tensor<10x14xf32>
return %1 : tensor<10x14xf32>
}
diff --git a/mlir/test/Dialect/Mesh/sharding-propagation.mlir b/mlir/test/Dialect/Mesh/sharding-propagation.mlir
index 11a80594adb79b..5b00b45653dbb6 100644
--- a/mlir/test/Dialect/Mesh/sharding-propagation.mlir
+++ b/mlir/test/Dialect/Mesh/sharding-propagation.mlir
@@ -16,11 +16,14 @@ func.func @element_wise_empty_sharding_info(%arg0: tensor<8x16xf32>) -> tensor<8
// CHECK-LABEL: func.func @element_wise_on_def
// CHECK-SAME: %[[ARG:.*]]: tensor<8x16xf32>
func.func @element_wise_on_def(%arg0: tensor<8x16xf32>) -> tensor<8x16xf32> {
- // CHECK-NEXT: %[[V0:.*]] = mesh.shard %[[ARG]] to <@mesh_2d, {{\[\[}}0], [1]]> annotate_for_users : tensor<8x16xf32>
+ // CHECK-NEXT: %[[S0:.*]] = mesh.sharding @mesh_2d split_axes = {{\[\[}}0], [1]] : !mesh.sharding
+ // CHECK-NEXT: %[[V0:.*]] = mesh.shard %[[ARG]] to %[[S0]] annotate_for_users : tensor<8x16xf32>
// CHECK-NEXT: %[[V1:.*]] = tosa.sigmoid %[[V0]]
%0 = tosa.sigmoid %arg0 : (tensor<8x16xf32>) -> tensor<8x16xf32>
- // CHECK-NEXT: %[[V2:.*]] = mesh.shard %[[V1]] to <@mesh_2d, {{\[\[}}0], [1]]> : tensor<8x16xf32>
- %1 = mesh.shard %0 to <@mesh_2d, [[0], [1]]> : tensor<8x16xf32>
+ // CHECK-NEXT: %[[S2:.*]] = mesh.sharding @mesh_2d split_axes = {{\[\[}}0], [1]] : !mesh.sharding
+ // CHECK-NEXT: %[[V2:.*]] = mesh.shard %[[V1]] to %[[S2]] : tensor<8x16xf32>
+ %s1 = mesh.sharding @mesh_2d split_axes = [[0], [1]] : !mesh.sharding
+ %1 = mesh.shard %0 to %s1 : tensor<8x16xf32>
// CHECK-NEXT: return %[[V2]]
return %1 : tensor<8x16xf32>
}
@@ -28,11 +31,14 @@ func.func @element_wise_on_def(%arg0: tensor<8x16xf32>) -> tensor<8x16xf32> {
// CHECK-LABEL: func.func @element_wise_on_use
// CHECK-SAME: %[[ARG:.*]]: tensor<8x16xf32>
func.func @element_wise_on_use(%arg0: tensor<8x16xf32>) -> tensor<8x16xf32> {
- // CHECK-NEXT: %[[V0:.*]] = mesh.shard %[[ARG]] to <@mesh_2d, {{\[\[}}0], [1]]> annotate_for_users : tensor<8x16xf32>
- %0 = mesh.shard %arg0 to <@mesh_2d, [[0], [1]]> annotate_for_users : tensor<8x16xf32>
+ // CHECK-NEXT: %[[S0:.*]] = mesh.sharding @mesh_2d split_axes = {{\[\[}}0], [1]] : !mesh.sharding
+ // CHECK-NEXT: %[[V0:.*]] = mesh.shard %[[ARG]] to %[[S0]] annotate_for_users : tensor<8x16xf32>
+ %s0 = mesh.sharding @mesh_2d split_axes = [[0], [1]] : !mesh.sharding
+ %0 = mesh.shard %arg0 to %s0 annotate_for_users : tensor<8x16xf32>
// CHECK-NEXT: %[[V1:.*]] = tosa.sigmoid %[[V0]]
%1 = tosa.sigmoid %0 : (tensor<8x16xf32>) -> tensor<8x16xf32>
- // CHECK-NEXT: %[[V2:.*]] = mesh.shard %[[V1]] to <@mesh_2d, {{\[\[}}0], [1]]> : tensor<8x16xf32>
+ // CHECK-NEXT: %[[S2:.*]] = mesh.sharding @mesh_2d split_axes = {{\[\[}}0], [1]] : !mesh.sharding
+ // CHECK-NEXT: %[[V2:.*]] = mesh.shard %[[V1]] to %[[S2]] : tensor<8x16xf32>
// CHECK-NEXT: return %[[V2]]
return %1 : tensor<8x16xf32>
}
@@ -40,12 +46,15 @@ func.func @element_wise_on_use(%arg0: tensor<8x16xf32>) -> tensor<8x16xf32> {
// CHECK-LABEL: func.func @element_wise_on_graph_output
// CHECK-SAME: %[[ARG:.*]]: tensor<8x16xf32>
func.func @element_wise_on_graph_output(%arg0: tensor<8x16xf32>) -> tensor<8x16xf32> {
- // CHECK-NEXT: %[[V0:.*]] = mesh.shard %[[ARG]] to <@mesh_2d, {{\[\[}}0], [1]]> annotate_for_users : tensor<8x16xf32>
+ // CHECK-NEXT: %[[S0:.*]] = mesh.sharding @mesh_2d split_axes = {{\[\[}}0], [1]] : !mesh.sharding
+ // CHECK-NEXT: %[[V0:.*]] = mesh.shard %[[ARG]] to %[[S0]] annotate_for_users : tensor<8x16xf32>
// CHECK-NEXT: %[[V1:.*]] = tosa.sigmoid %[[V0]]
%0 = tosa.sigmoid %arg0 : (tensor<8x16xf32>) -> tensor<8x16xf32>
- // CHECK-NEXT: %[[V2:.*]] = mesh.shard %[[V1]] to <@mesh_2d, {{\[\[}}0], [1]]> : tensor<8x16xf32>
- // CHECK-NEXT: %[[V3:.*]] = mesh.shard %[[V2]] to <@mesh_2d, {{\[\[}}0], [1]]> annotate_for_users : tensor<8x16xf32>
- %1 = mesh.shard %0 to <@mesh_2d, [[0], [1]]> annotate_for_users : tensor<8x16xf32>
+ // CHECK-NEXT: %[[V2:.*]] = mesh.shard %[[V1]] to %[[S0]] : tensor<8x16xf32>
+ // CHECK-NEXT: %[[S3:.*]] = mesh.sharding @mesh_2d split_axes = {{\[\[}}0], [1]] : !mesh.sharding
+ // CHECK-NEXT: %[[V3:.*]] = mesh.shard %[[V2]] to %[[S3]] annotate_for_users : tensor<8x16xf32>
+ %s1 = mesh.sharding @mesh_2d split_axes = [[0], [1]] : !mesh.sharding
+ %1 = mesh.shard %0 to %s1 annotate_for_users : tensor<8x16xf32>
// CHECK-NEXT: return %[[V3]]
return %1 : tensor<8x16xf32>
}
@@ -53,12 +62,15 @@ func.func @element_wise_on_graph_output(%arg0: tensor<8x16xf32>) -> tensor<8x16x
// CHECK-LABEL: func.func @element_wise_on_graph_input
// CHECK-SAME: %[[ARG:.*]]: tensor<8x16xf32>
func.func @element_wise_on_graph_input(%arg0: tensor<8x16xf32>) -> tensor<8x16xf32> {
- // CHECK-NEXT: %[[V0:.*]] = mesh.shard %[[ARG]] to <@mesh_2d, {{\[\[}}0], [1]]> : tensor<8x16xf32>
- // CHECK-NEXT: %[[V1:.*]] = mesh.shard %[[V0]] to <@mesh_2d, {{\[\[}}0], [1]]> annotate_for_users : tensor<8x16xf32>
- %0 = mesh.shard %arg0 to <@mesh_2d, [[0], [1]]> : tensor<8x16xf32>
+ // CHECK-NEXT: %[[S0:.*]] = mesh.sharding @mesh_2d split_axes = {{\[\[}}0], [1]] : !mesh.sharding
+ // CHECK-NEXT: %[[V0:.*]] = mesh.shard %[[ARG]] to %[[S0]] : tensor<8x16xf32>
+ // CHECK-NEXT: %[[S1:.*]] = mesh.sharding @mesh_2d split_axes = {{\[\[}}0], [1]] : !mesh.sharding
+ // CHECK-NEXT: %[[V1:.*]] = mesh.shard %[[V0]] to %[[S1]] annotate_for_users : tensor<8x16xf32>
+ %s0 = mesh.sharding @mesh_2d split_axes = [[0], [1]] : !mesh.sharding
+ %0 = mesh.shard %arg0 to %s0 : tensor<8x16xf32>
// CHECK-NEXT: %[[V2:.*]] = tosa.sigmoid %[[V1]]
%1 = tosa.sigmoid %0 : (tensor<8x16xf32>) -> tensor<8x16xf32>
- // CHECK-NEXT: %[[V3:.*]] = mesh.shard %[[V2]] to <@mesh_2d, {{\[\[}}0], [1]]> : tensor<8x16xf32>
+ // CHECK-NEXT: %[[V3:.*]] = mesh.shard %[[V2]] to %[[S1]] : tensor<8x16xf32>
// CHECK-NEXT: return %[[V3]]
return %1 : tensor<8x16xf32>
}
@@ -66,18 +78,21 @@ func.func @element_wise_on_graph_input(%arg0: tensor<8x16xf32>) -> tensor<8x16xf
// CHECK-LABEL: func.func @arrow_structure
// CHECK-SAME: %[[ARG:.*]]: tensor<8x16xf32>
func.func @arrow_structure(%arg0: tensor<8x16xf32>) -> (tensor<8x16xf32>, tensor<8x16xf32>) {
- // CHECK-NEXT: %[[V1:.*]] = mesh.shard %[[ARG]] to <@mesh_2d, {{\[\[}}0], [1]]> annotate_for_users : tensor<8x16xf32>
+ // CHECK-NEXT: %[[S1:.*]] = mesh.sharding @mesh_2d split_axes = {{\[\[}}0], [1]] : !mesh.sharding
+ // CHECK-NEXT: %[[V1:.*]] = mesh.shard %[[ARG]] to %[[S1]] annotate_for_users : tensor<8x16xf32>
// CHECK-NEXT: %[[V2:.*]] = tosa.tanh %[[V1]]
- // CHECK-NEXT: %[[V3:.*]] = mesh.shard %[[V2]] to <@mesh_2d, {{\[\[}}0], [1]]> : tensor<8x16xf32>
+ // CHECK-NEXT: %[[V3:.*]] = mesh.shard %[[V2]] to %[[S1]] : tensor<8x16xf32>
%0 = tosa.tanh %arg0 : (tensor<8x16xf32>) -> tensor<8x16xf32>
- // CHECK-NEXT: %[[V4:.*]] = mesh.shard %[[V3]] to <@mesh_2d, {{\[\[}}0], [1]]> annotate_for_users : tensor<8x16xf32>
+ // CHECK-NEXT: %[[V4:.*]] = mesh.shard %[[V3]] to %[[S1]] annotate_for_users : tensor<8x16xf32>
// CHECK-NEXT: %[[V5:.*]] = tosa.abs %[[V4]]
- // CHECK-NEXT: %[[V6:.*]] = mesh.shard %[[V5]] to <@mesh_2d, {{\[\[}}0], [1]]> : tensor<8x16xf32>
+ // CHECK-NEXT: %[[V6:.*]] = mesh.shard %[[V5]] to %[[S1]] : tensor<8x16xf32>
%1 = tosa.abs %0 : (tensor<8x16xf32>) -> tensor<8x16xf32>
// CHECK-NEXT: %[[V7:.*]] = tosa.negate %[[V4]]
- // CHECK-NEXT: %[[V8:.*]] = mesh.shard %[[V7]] to <@mesh_2d, {{\[\[}}0], [1]]> : tensor<8x16xf32>
+ // CHECK-NEXT: %[[S8:.*]] = mesh.sharding @mesh_2d split_axes = {{\[\[}}0], [1]] : !mesh.sharding
+ // CHECK-NEXT: %[[V8:.*]] = mesh.shard %[[V7]] to %[[S8]] : tensor<8x16xf32>
%2 = tosa.negate %0 : (tensor<8x16xf32>) -> tensor<8x16xf32>
- %3 = mesh.shard %2 to <@mesh_2d, [[0], [1]]> : tensor<8x16xf32>
+ %s3 = mesh.sharding @mesh_2d split_axes = [[0], [1]] : !mesh.sharding
+ %3 = mesh.shard %2 to %s3 : tensor<8x16xf32>
// CHECK-NEXT: return %[[V6]], %[[V8]]
return %1, %3 : tensor<8x16xf32>, tensor<8x16xf32>
}
@@ -85,12 +100,16 @@ func.func @arrow_structure(%arg0: tensor<8x16xf32>) -> (tensor<8x16xf32>, tensor
// CHECK-LABEL: func.func @matmul_on_def_shard_batch_and_m
// CHECK-SAME: %[[ARG0:.*]]: tensor<2x16x8xf32>, %[[ARG1:.*]]: tensor<2x8x32xf32>
func.func @matmul_on_def_shard_batch_and_m(%arg0: tensor<2x16x8xf32>, %arg1: tensor<2x8x32xf32>) -> tensor<2x16x32xf32> {
- // CHECK-NEXT: %[[V0:.*]] = mesh.shard %[[ARG0]] to <@mesh_2d, {{\[\[}}0], [1]]> annotate_for_users : tensor<2x16x8xf32>
- // CHECK-NEXT: %[[V1:.*]] = mesh.shard %[[ARG1]] to <@mesh_2d, {{\[\[}}0]]> annotate_for_users : tensor<2x8x32xf32>
+ // CHECK-NEXT: %[[S0:.*]] = mesh.sharding @mesh_2d split_axes = {{\[\[}}0], [1]] : !mesh.sharding
+ // CHECK-NEXT: %[[V0:.*]] = mesh.shard %[[ARG0]] to %[[S0]] annotate_for_users : tensor<2x16x8xf32>
+ // CHECK-NEXT: %[[S1:.*]] = mesh.sharding @mesh_2d split_axes = {{\[\[}}0]] : !mesh.sharding
+ // CHECK-NEXT: %[[V1:.*]] = mesh.shard %[[ARG1]] to %[[S1]] annotate_for_users : tensor<2x8x32xf32>
// CHECK-NEXT: %[[V2:.*]] = tosa.matmul %[[V0]], %[[V1]]
%0 = tosa.matmul %arg0, %arg1 : (tensor<2x16x8xf32>, tensor<2x8x32xf32>) -> tensor<2x16x32xf32>
- // CHECK-NEXT: %[[V3:.*]] = mesh.shard %[[V2]] to <@mesh_2d, {{\[\[}}0], [1]]> : tensor<2x16x32xf32>
- %1 = mesh.shard %0 to <@mesh_2d, [[0], [1]]> : tensor<2x16x32xf32>
+ // CHECK-NEXT: %[[S3:.*]] = mesh.sharding @mesh_2d split_axes = {{\[\[}}0], [1]] : !mesh.sharding
+ // CHECK-NEXT: %[[V3:.*]] = mesh.shard %[[V2]] to %[[S3]] : tensor<2x16x32xf32>
+ %s1 = mesh.sharding @mesh_2d split_axes = [[0], [1]] : !mesh.sharding
+ %1 = mesh.shard %0 to %s1 : tensor<2x16x32xf32>
// CHECK-NEXT: return %[[V3]]
return %1 : tensor<2x16x32xf32>
}
@@ -98,12 +117,16 @@ func.func @matmul_on_def_shard_batch_and_m(%arg0: tensor<2x16x8xf32>, %arg1: ten
// CHECK-LABEL: func.func @matmul_on_def_shard_m_and_k
// CHECK-SAME: %[[ARG0:.*]]: tensor<2x16x8xf32>, %[[ARG1:.*]]: tensor<2x8x32xf32>
func.func @matmul_on_def_shard_m_and_k(%arg0: tensor<2x16x8xf32>, %arg1: tensor<2x8x32xf32>) -> tensor<2x16x32xf32> {
- // CHECK-NEXT: %[[V0:.*]] = mesh.shard %[[ARG0]] to <@mesh_2d, {{\[\[}}], [1], [0]]> annotate_for_users : tensor<2x16x8xf32>
- // CHECK-NEXT: %[[V1:.*]] = mesh.shard %[[ARG1]] to <@mesh_2d, {{\[\[}}], [0]]> annotate_for_users : tensor<2x8x32xf32>
+ // CHECK-NEXT: %[[S0:.*]] = mesh.sharding @mesh_2d split_axes = {{\[\[}}], [1], [0]] : !mesh.sharding
+ // CHECK-NEXT: %[[V0:.*]] = mesh.shard %[[ARG0]] to %[[S0]] annotate_for_users : tensor<2x16x8xf32>
+ // CHECK-NEXT: %[[S1:.*]] = mesh.sharding @mesh_2d split_axes = {{\[\[}}], [0]] : !mesh.sharding
+ // CHECK-NEXT: %[[V1:.*]] = mesh.shard %[[ARG1]] to %[[S1]] annotate_for_users : tensor<2x8x32xf32>
// CHECK-NEXT: %[[V2:.*]] = tosa.matmul %[[V0]], %[[V1]]
%0 = tosa.matmul %arg0, %arg1 : (tensor<2x16x8xf32>, tensor<2x8x32xf32>) -> tensor<2x16x32xf32>
- // CHECK-NEXT: %[[V3:.*]] = mesh.shard %[[V2]] to <@mesh_2d, {{\[\[}}], [1]], partial = sum[0]> : tensor<2x16x32xf32>
- %1 = mesh.shard %0 to <@mesh_2d, [[], [1]], partial = sum[0]> : tensor<2x16x32xf32>
+ // CHECK-NEXT: %[[S3:.*]] = mesh.sharding @mesh_2d split_axes = {{\[\[}}], [1]] partial = sum [0] : !mesh.sharding
+ // CHECK-NEXT: %[[V3:.*]] = mesh.shard %[[V2]] to %[[S3]] : tensor<2x16x32xf32>
+ %s1 = mesh.sharding @mesh_2d split_axes = [[], [1]] partial = sum [0] : !mesh.sharding
+ %1 = mesh.shard %0 to %s1 : tensor<2x16x32xf32>
// CHECK-NEXT: return %[[V3]]
return %1 : tensor<2x16x32xf32>
}
@@ -111,12 +134,16 @@ func.func @matmul_on_def_shard_m_and_k(%arg0: tensor<2x16x8xf32>, %arg1: tensor<
// CHECK-LABEL: func.func @matmul_on_use_shard_m_and_k
// CHECK-SAME: %[[ARG0:.*]]: tensor<2x16x8xf32>, %[[ARG1:.*]]: tensor<2x8x32xf32>
func.func @matmul_on_use_shard_m_and_k(%arg0: tensor<2x16x8xf32>, %arg1: tensor<2x8x32xf32>) -> tensor<2x16x32xf32> {
- // CHECK-NEXT: %[[V0:.*]] = mesh.shard %[[ARG0]] to <@mesh_2d, {{\[\[}}], [1], [0]]> annotate_for_users : tensor<2x16x8xf32>
- %0 = mesh.shard %arg0 to <@mesh_2d, [[], [1], [0]]> annotate_for_users : tensor<2x16x8xf32>
- // CHECK-NEXT: %[[V1:.*]] = mesh.shard %[[ARG1]] to <@mesh_2d, {{\[\[}}], [0]]> annotate_for_users : tensor<2x8x32xf32>
+ // CHECK-NEXT: %[[S0:.*]] = mesh.sharding @mesh_2d split_axes = {{\[\[}}], [1], [0]] : !mesh.sharding
+ // CHECK-NEXT: %[[V0:.*]] = mesh.shard %[[ARG0]] to %[[S0]] annotate_for_users : tensor<2x16x8xf32>
+ %s0 = mesh.sharding @mesh_2d split_axes = [[], [1], [0]] : !mesh.sharding
+ %0 = mesh.shard %arg0 to %s0 annotate_for_users : tensor<2x16x8xf32>
+ // CHECK-NEXT: %[[S1:.*]] = mesh.sharding @mesh_2d split_axes = {{\[\[}}], [0]] : !mesh.sharding
+ // CHECK-NEXT: %[[V1:.*]] = mesh.shard %[[ARG1]] to %[[S1]] annotate_for_users : tensor<2x8x32xf32>
// CHECK-NEXT: %[[V2:.*]] = tosa.matmul %[[V0]], %[[V1]]
%1 = tosa.matmul %0, %arg1 : (tensor<2x16x8xf32>, tensor<2x8x32xf32>) -> tensor<2x16x32xf32>
- // CHECK-NEXT: %[[V3:.*]] = mesh.shard %[[V2]] to <@mesh_2d, {{\[\[}}], [1]], partial = sum[0]> : tensor<2x16x32xf32>
+ // CHECK-NEXT: %[[S3:.*]] = mesh.sharding @mesh_2d split_axes = {{\[\[}}], [1]] partial = sum [0] : !mesh.sharding
+ // CHECK-NEXT: %[[V3:.*]] = mesh.shard %[[V2]] to %[[S3]] : tensor<2x16x32xf32>
// CHECK-NEXT: return %[[V3]]
return %1 : tensor<2x16x32xf32>
}
@@ -124,13 +151,18 @@ func.func @matmul_on_use_shard_m_and_k(%arg0: tensor<2x16x8xf32>, %arg1: tensor<
// CHECK-LABEL: func.func @matmul_on_use_shard_m_and_duplicted_k
// CHECK-SAME: %[[ARG0:.*]]: tensor<2x16x8xf32>, %[[ARG1:.*]]: tensor<2x8x32xf32>
func.func @matmul_on_use_shard_m_and_duplicted_k(%arg0: tensor<2x16x8xf32>, %arg1: tensor<2x8x32xf32>) -> tensor<2x16x32xf32> {
- // CHECK-NEXT: %[[V0:.*]] = mesh.shard %[[ARG0]] to <@mesh_2d, {{\[\[}}], [1], [0]]> annotate_for_users : tensor<2x16x8xf32>
- %0 = mesh.shard %arg0 to <@mesh_2d, [[], [1], [0]]> annotate_for_users : tensor<2x16x8xf32>
- // CHECK-NEXT: %[[V1:.*]] = mesh.shard %[[ARG1]] to <@mesh_2d, {{\[\[}}], [0]]> annotate_for_users : tensor<2x8x32xf32>
- %1 = mesh.shard %arg1 to <@mesh_2d, [[], [0]]> annotate_for_users : tensor<2x8x32xf32>
+ // CHECK-NEXT: %[[S0:.*]] = mesh.sharding @mesh_2d split_axes = {{\[\[}}], [1], [0]] : !mesh.sharding
+ // CHECK-NEXT: %[[V0:.*]] = mesh.shard %[[ARG0]] to %[[S0]] annotate_for_users : tensor<2x16x8xf32>
+ %s0 = mesh.sharding @mesh_2d split_axes = [[], [1], [0]] : !mesh.sharding
+ %0 = mesh.shard %arg0 to %s0 annotate_for_users : tensor<2x16x8xf32>
+ // CHECK-NEXT: %[[S1:.*]] = mesh.sharding @mesh_2d split_axes = {{\[\[}}], [0]] : !mesh.sharding
+ // CHECK-NEXT: %[[V1:.*]] = mesh.shard %[[ARG1]] to %[[S1]] annotate_for_users : tensor<2x8x32xf32>
+ %s1 = mesh.sharding @mesh_2d split_axes = [[], [0]] : !mesh.sharding
+ %1 = mesh.shard %arg1 to %s1 annotate_for_users : tensor<2x8x32xf32>
// CHECK-NEXT: %[[V2:.*]] = tosa.matmul %[[V0]], %[[V1]]
%2 = tosa.matmul %0, %1 : (tensor<2x16x8xf32>, tensor<2x8x32xf32>) -> tensor<2x16x32xf32>
- // CHECK-NEXT: %[[V3:.*]] = mesh.shard %[[V2]] to <@mesh_2d, {{\[\[}}], [1]], partial = sum[0]> : tensor<2x16x32xf32>
+ // CHECK-NEXT: %[[S3:.*]] = mesh.sharding @mesh_2d split_axes = {{\[\[}}], [1]] partial = sum [0] : !mesh.sharding
+ // CHECK-NEXT: %[[V3:.*]] = mesh.shard %[[V2]] to %[[S3]] : tensor<2x16x32xf32>
// CHECK-NEXT: return %[[V3]]
return %2 : tensor<2x16x32xf32>
}
@@ -145,21 +177,23 @@ func.func @resolve_conflicting_annotations(
%out_dps: tensor<2x2xf32>
// CHECK-SAME: ) -> tensor<2x2xf32> {
) -> tensor<2x2xf32> {
- // CHECK: %[[IN1_SHARDED1:.*]] = mesh.shard %[[IN1]] to <@mesh_2, {{\[\[}}0]]> : tensor<2x3xf32>
- // CHECK: %[[IN1_SHARDED2:.*]] = mesh.shard %[[IN1_SHARDED1]] to <@mesh_2, {{\[}}]> annotate_for_users : tensor<2x3xf32>
- // CHECK: %[[IN2_SHARDED:.*]] = mesh.shard %[[IN2]] to <@mesh_2, []> annotate_for_users : tensor<3x2xf32>
- // CHECK: %[[OUT_DPS_SHARDED:.*]] = mesh.shard %[[OUT_DPS]] to <@mesh_2, {{\[}}]> annotate_for_users : tensor<2x2xf32>
- %arg0_sharded = mesh.shard %arg0 to <@mesh_2, [[0]]> : tensor<2x3xf32>
-
+ // CHECK: %[[SIN1_SHARDED1:.*]] = mesh.sharding @mesh_2 split_axes = {{\[\[}}0]] : !mesh.sharding
+ // CHECK-NEXT: %[[IN1_SHARDED1:.*]] = mesh.shard %[[IN1]] to %[[SIN1_SHARDED1]] : tensor<2x3xf32>
+ // CHECK: %[[SIN2_SHARDED:.*]] = mesh.sharding @mesh_2 split_axes = [] : !mesh.sharding
+ // CHECK-NEXT: %[[IN1_SHARDED2:.*]] = mesh.shard %[[IN1_SHARDED1]] to %[[SIN2_SHARDED]] annotate_for_users : tensor<2x3xf32>
+ // CHECK-NEXT: %[[IN2_SHARDED:.*]] = mesh.shard %[[IN2]] to %[[SIN2_SHARDED]] annotate_for_users : tensor<3x2xf32>
+ // CHECK-NEXT: %[[OUT_DPS_SHARDED:.*]] = mesh.shard %[[OUT_DPS]] to %[[SIN2_SHARDED]] annotate_for_users : tensor<2x2xf32>
+ %sarg0_sharded = mesh.sharding @mesh_2 split_axes = [[0]] : !mesh.sharding
+ %arg0_sharded = mesh.shard %arg0 to %sarg0_sharded : tensor<2x3xf32>
// CHECK: %[[MATMUL:.*]] = linalg.matmul ins(%[[IN1_SHARDED2]], %[[IN2_SHARDED]] : tensor<2x3xf32>, tensor<3x2xf32>)
// CHECK-SAME: outs(%[[OUT_DPS_SHARDED]] : tensor<2x2xf32>) -> tensor<2x2xf32>
%res = linalg.matmul ins(%arg0_sharded, %arg1 : tensor<2x3xf32>, tensor<3x2xf32>)
outs(%out_dps : tensor<2x2xf32>) -> tensor<2x2xf32>
-
- // CHECK: %[[MATMUL_SHARDED1:.*]] = mesh.shard %[[MATMUL]] to <@mesh_2, {{\[\[}}]]> : tensor<2x2xf32>
- %res_sharded = mesh.shard %res to <@mesh_2, [[]]> : tensor<2x2xf32>
-
- // CHECK: return %[[MATMUL_SHARDED1]] : tensor<2x2xf32>
+ // CHECK-NEXT: %[[SRES:.*]] = mesh.sharding @mesh_2 split_axes = {{\[\[}}]] : !mesh.sharding
+ // CHECK-NEXT: %[[RES:.*]] = mesh.shard %[[MATMUL]] to %[[SRES]] : tensor<2x2xf32>
+ %sres_sharded = mesh.sharding @mesh_2 split_axes = [[]] : !mesh.sharding
+ %res_sharded = mesh.shard %res to %sres_sharded : tensor<2x2xf32>
+ // CHECK: return %[[RES]] : tensor<2x2xf32>
return %res_sharded : tensor<2x2xf32>
}
@@ -167,23 +201,30 @@ func.func @resolve_conflicting_annotations(
// CHECK-LABEL: func.func @mlp_1d_weight_stationary
// CHECK-SAME: %[[ARG0:.*]]: tensor<2x4x8xf32>, %[[ARG1:.*]]: tensor<2x8x32xf32>, %[[ARG2:.*]]: tensor<2x32x8xf32>
func.func @mlp_1d_weight_stationary(%arg0: tensor<2x4x8xf32>, %arg1: tensor<2x8x32xf32>, %arg2: tensor<2x32x8xf32>) -> tensor<2x4x8xf32> {
- %0 = mesh.shard %arg0 to <@mesh_1d, [[], [], [0]]> : tensor<2x4x8xf32>
+ %s0 = mesh.sharding @mesh_1d split_axes = [[], [], [0]] : !mesh.sharding
+ %0 = mesh.shard %arg0 to %s0 : tensor<2x4x8xf32>
+ // CHECK-DAG: %[[S1:.*]] = mesh.sharding @mesh_1d split_axes = {{\[\[}}], [], [0]] : !mesh.sharding
+ // CHECK-DAG: %[[S2:.*]] = mesh.sharding @mesh_1d split_axes = {{\[\[}}], [], [0]] : !mesh.sharding
// CHECK: %[[V0:.*]] = tosa.matmul
%1 = tosa.matmul %0, %arg1 : (tensor<2x4x8xf32>, tensor<2x8x32xf32>) -> tensor<2x4x32xf32>
- // CHECK-DAG: %[[V1:.*]] = mesh.shard %[[V0]] to <@mesh_1d, {{\[\[}}], [], [0]]> : tensor<2x4x32xf32>
- // CHECK-DAG: %[[V2:.*]] = mesh.shard %[[V1]] to <@mesh_1d, {{\[\[}}], [], [0]]> annotate_for_users : tensor<2x4x32xf32>
+ // CHECK-NEXT: %[[V1:.*]] = mesh.shard %[[V0]] to %[[S2]] : tensor<2x4x32xf32>
+ // CHECK-NEXT: %[[V2:.*]] = mesh.shard %[[V1]] to %[[S2]] annotate_for_users : tensor<2x4x32xf32>
// CHECK-DAG: %[[V3:.*]] = tosa.sigmoid %[[V2]]
%2 = tosa.sigmoid %1 : (tensor<2x4x32xf32>) -> tensor<2x4x32xf32>
- // CHECK-DAG: %[[V4:.*]] = mesh.shard %[[V3]] to <@mesh_1d, {{\[\[}}], [], [0]]> : tensor<2x4x32xf32>
- // CHECK-DAG: %[[V5:.*]] = mesh.shard %[[V4]] to <@mesh_1d, {{\[\[}}], [], [0]]> annotate_for_users : tensor<2x4x32xf32>
- // CHECK-DAG: %[[V6:.*]] = mesh.shard %[[ARG2]] to <@mesh_1d, {{\[\[}}], [0]]> annotate_for_users : tensor<2x32x8xf32>
+ // CHECK-NEXT: %[[V4:.*]] = mesh.shard %[[V3]] to %[[S2]] : tensor<2x4x32xf32>
+ // CHECK-NEXT: %[[V5:.*]] = mesh.shard %[[V4]] to %[[S2]] annotate_for_users : tensor<2x4x32xf32>
+ // CHECK-DAG: %[[S6:.*]] = mesh.sharding @mesh_1d split_axes = {{\[\[}}], [0]] : !mesh.sharding
+ // CHECK-NEXT: %[[V6:.*]] = mesh.shard %[[ARG2]] to %[[S6]] annotate_for_users : tensor<2x32x8xf32>
// CHECK-DAG: %[[V7:.*]] = tosa.matmul %[[V5]], %[[V6]]
%3 = tosa.matmul %2, %arg2 : (tensor<2x4x32xf32>, tensor<2x32x8xf32>) -> tensor<2x4x8xf32>
- // CHECK-DAG: %[[V8:.*]] = mesh.shard %[[V7]] to <@mesh_1d, {{\[\[}}], [], []], partial = sum[0]> : tensor<2x4x8xf32>
- %4 = mesh.shard %3 to <@mesh_1d, [[], [], []], partial = sum[0]> : tensor<2x4x8xf32>
- // CHECK-DAG: %[[V9:.*]] = mesh.shard %[[V8]] to <@mesh_1d, {{\[\[}}], [], [0]]> annotate_for_users : tensor<2x4x8xf32>
- %5 = mesh.shard %4 to <@mesh_1d, [[], [], [0]]> annotate_for_users : tensor<2x4x8xf32>
- // CHECK-DAG: return %[[V9]]
+ %s4 = mesh.sharding @mesh_1d split_axes = [[], [], []] partial = sum [0] : !mesh.sharding
+ %4 = mesh.shard %3 to %s4 : tensor<2x4x8xf32>
+ // CHECK: %[[S8:.*]] = mesh.sharding @mesh_1d split_axes = {{\[\[}}], [], []] partial = sum [0] : !mesh.sharding
+ // CHECK-NEXT: %[[V8:.*]] = mesh.shard %[[V7]] to %[[S8]] : tensor<2x4x8xf32>
+ %s5 = mesh.sharding @mesh_1d split_axes = [[], [], [0]] : !mesh.sharding
+ %5 = mesh.shard %4 to %s5 annotate_for_users : tensor<2x4x8xf32>
+ // CHECK: %[[V9:.*]] = mesh.shard %[[V8]] to %[[S1]] annotate_for_users : tensor<2x4x8xf32>
+ // CHECK-NEXT: return %[[V9]]
return %5 : tensor<2x4x8xf32>
}
@@ -191,26 +232,37 @@ func.func @mlp_1d_weight_stationary(%arg0: tensor<2x4x8xf32>, %arg1: tensor<2x8x
// CHECK-LABEL: func.func @mlp_2d_weight_stationary
// CHECK-SAME: %[[ARG0:.*]]: tensor<2x4x8xf32>, %[[ARG1:.*]]: tensor<2x8x32xf32>, %[[ARG2:.*]]: tensor<2x32x8xf32>
func.func @mlp_2d_weight_stationary(%arg0: tensor<2x4x8xf32>, %arg1: tensor<2x8x32xf32>, %arg2: tensor<2x32x8xf32>) -> tensor<2x4x8xf32> {
- // CHECK-DAG: %[[V0:.*]] = mesh.shard %[[ARG0]] to <@mesh_3d, {{\[\[}}], [], [0, 1, 2]]> : tensor<2x4x8xf32>
- %0 = mesh.shard %arg0 to <@mesh_3d, [[], [], [0, 1, 2]]> : tensor<2x4x8xf32>
- // CHECK-DAG: %[[V1:.*]] = mesh.shard %[[V0]] to <@mesh_3d, {{\[\[}}], [], [0]]> annotate_for_users : tensor<2x4x8xf32>
- // CHECK-DAG: %[[V2:.*]] = mesh.shard %[[ARG1]] to <@mesh_3d, {{\[\[}}], [0], [1, 2]]> annotate_for_users : tensor<2x8x32xf32>
+ // CHECK-DAG: %[[S0:.*]] = mesh.sharding @mesh_3d split_axes = {{\[\[}}], [], [0, 1, 2]] : !mesh.sharding
+ // CHECK-NEXT: %[[V0:.*]] = mesh.shard %[[ARG0]] to %[[S0]] : tensor<2x4x8xf32>
+ %s0 = mesh.sharding @mesh_3d split_axes = [[], [], [0, 1, 2]] : !mesh.sharding
+ %0 = mesh.shard %arg0 to %s0 : tensor<2x4x8xf32>
+ // CHECK-DAG: %[[S1:.*]] = mesh.sharding @mesh_3d split_axes = {{\[\[}}], [], [0]] : !mesh.sharding
+ // CHECK-NEXT: %[[V1:.*]] = mesh.shard %[[V0]] to %[[S1]] annotate_for_users : tensor<2x4x8xf32>
+ // CHECK-DAG: %[[S2:.*]] = mesh.sharding @mesh_3d split_axes = {{\[\[}}], [0], [1, 2]] : !mesh.sharding
+ // CHECK-NEXT: %[[V2:.*]] = mesh.shard %[[ARG1]] to %[[S2]] annotate_for_users : tensor<2x8x32xf32>
// CHECK-DAG: %[[V3:.*]] = tosa.matmul %[[V1]], %[[V2]]
%1 = tosa.matmul %0, %arg1 : (tensor<2x4x8xf32>, tensor<2x8x32xf32>) -> tensor<2x4x32xf32>
- // CHECK-DAG: %[[V4:.*]] = mesh.shard %[[V3]] to <@mesh_3d, {{\[\[}}], [], [1, 2]], partial = sum[0]> : tensor<2x4x32xf32>
- %2 = mesh.shard %1 to <@mesh_3d, [[], [], [1, 2]], partial = sum[0]> : tensor<2x4x32xf32>
- // CHECK-DAG: %[[V5:.*]] = mesh.shard %[[V4]] to <@mesh_3d, {{\[\[}}], [], [1, 2]]> annotate_for_users : tensor<2x4x32xf32>
+ // CHECK-DAG: %[[S4:.*]] = mesh.sharding @mesh_3d split_axes = {{\[\[}}], [], [1, 2]] partial = sum [0] : !mesh.sharding
+ // CHECK-NEXT: %[[V4:.*]] = mesh.shard %[[V3]] to %[[S4]] : tensor<2x4x32xf32>
+ %s2 = mesh.sharding @mesh_3d split_axes = [[], [], [1, 2]] partial = sum [0] : !mesh.sharding
+ %2 = mesh.shard %1 to %s2 : tensor<2x4x32xf32>
+ // CHECK-DAG: %[[S5:.*]] = mesh.sharding @mesh_3d split_axes = {{\[\[}}], [], [1, 2]] : !mesh.sharding
+ // CHECK-NEXT: %[[V5:.*]] = mesh.shard %[[V4]] to %[[S5]] annotate_for_users : tensor<2x4x32xf32>
// CHECK-DAG: %[[V6:.*]] = tosa.sigmoid %[[V5]]
%3 = tosa.sigmoid %2 : (tensor<2x4x32xf32>) -> tensor<2x4x32xf32>
- // CHECK-DAG: %[[V7:.*]] = mesh.shard %[[V6]] to <@mesh_3d, {{\[\[}}], [], [1, 2]]> : tensor<2x4x32xf32>
- // CHECK-DAG: %[[V8:.*]] = mesh.shard %[[V7]] to <@mesh_3d, {{\[\[}}], [], [1, 2]]> annotate_for_users : tensor<2x4x32xf32>
- // CHECK-DAG: %[[V9:.*]] = mesh.shard %[[ARG2]] to <@mesh_3d, {{\[\[}}], [1, 2], [0]]> annotate_for_users : tensor<2x32x8xf32>
+ // CHECK-NEXT: %[[V7:.*]] = mesh.shard %[[V6]] to %[[S5]] : tensor<2x4x32xf32>
+ // CHECK-NEXT: %[[V8:.*]] = mesh.shard %[[V7]] to %[[S5]] annotate_for_users : tensor<2x4x32xf32>
+ // CHECK-DAG: %[[S9:.*]] = mesh.sharding @mesh_3d split_axes = {{\[\[}}], [1, 2], [0]] : !mesh.sharding
+ // CHECK-NEXT: %[[V9:.*]] = mesh.shard %[[ARG2]] to %[[S9]] annotate_for_users : tensor<2x32x8xf32>
// CHECK-DAG: %[[V10:.*]] = tosa.matmul %[[V8]], %[[V9]]
%4 = tosa.matmul %3, %arg2 : (tensor<2x4x32xf32>, tensor<2x32x8xf32>) -> tensor<2x4x8xf32>
- // CHECK-DAG: %[[V11:.*]] = mesh.shard %[[V10]] to <@mesh_3d, {{\[\[}}], [], [0]], partial = sum[1, 2]> : tensor<2x4x8xf32>
- %5 = mesh.shard %4 to <@mesh_3d, [[], [], [0]], partial = sum[1, 2]> : tensor<2x4x8xf32>
- // CHECK-DAG: %[[V12:.*]] = mesh.shard %[[V11]] to <@mesh_3d, {{\[\[}}], [], [0, 1, 2]]> annotate_for_users : tensor<2x4x8xf32>
- %6 = mesh.shard %5 to <@mesh_3d, [[], [], [0, 1, 2]]> annotate_for_users : tensor<2x4x8xf32>
+ // CHECK-DAG: %[[S11:.*]] = mesh.sharding @mesh_3d split_axes = {{\[\[}}], [], [0]] partial = sum [1, 2] : !mesh.sharding
+ // CHECK-NEXT: %[[V11:.*]] = mesh.shard %[[V10]] to %[[S11]] : tensor<2x4x8xf32>
+ %s5 = mesh.sharding @mesh_3d split_axes = [[], [], [0]] partial = sum[1, 2] : !mesh.sharding
+ %5 = mesh.shard %4 to %s5 : tensor<2x4x8xf32>
+ // CHECK-NEXT: %[[V12:.*]] = mesh.shard %[[V11]] to %[[S0]] annotate_for_users : tensor<2x4x8xf32>
+ %s6 = mesh.sharding @mesh_3d split_axes = [[], [], [0, 1, 2]] : !mesh.sharding
+ %6 = mesh.shard %5 to %s6 annotate_for_users : tensor<2x4x8xf32>
// CHECK-DAG: return %[[V12]]
return %6 : tensor<2x4x8xf32>
}
diff --git a/mlir/test/Dialect/Mesh/simplifications.mlir b/mlir/test/Dialect/Mesh/simplifications.mlir
index d748be82c5a46a..2540fbf9510c40 100644
--- a/mlir/test/Dialect/Mesh/simplifications.mlir
+++ b/mlir/test/Dialect/Mesh/simplifications.mlir
@@ -100,8 +100,8 @@ func.func @all_reduce_arith_addf_no_endomorphism_wrong_reduction_kind(
%arg0: tensor<5xf32>,
// CHECK-SAME: %[[ARG1:[A-Za-z0-9_]*]]: tensor<5xf32>
%arg1: tensor<5xf32>) -> tensor<5xf32> {
- // CHECK: %[[ALL_REDUCE0:[A-Za-z0-9_]*]] = mesh.all_reduce %[[ARG0]] on @mesh0 mesh_axes = [0] reduction = <max>
- %0 = mesh.all_reduce %arg0 on @mesh0 mesh_axes = [0] reduction = <max>
+ // CHECK: %[[ALL_REDUCE0:[A-Za-z0-9_]*]] = mesh.all_reduce %[[ARG0]] on @mesh0 mesh_axes = [0] reduction = max
+ %0 = mesh.all_reduce %arg0 on @mesh0 mesh_axes = [0] reduction = max
: tensor<5xf32> -> tensor<5xf32>
// CHECK: %[[ALL_REDUCE1:[A-Za-z0-9_]*]] = mesh.all_reduce %[[ARG1]] on @mesh0 mesh_axes = [0]
%1 = mesh.all_reduce %arg1 on @mesh0 mesh_axes = [0]
@@ -138,13 +138,13 @@ func.func @all_reduce_arith_minimumf_endomorphism(
%arg0: tensor<5xf32>,
// CHECK-SAME: %[[ARG1:[A-Za-z0-9_]*]]: tensor<5xf32>
%arg1: tensor<5xf32>) -> tensor<5xf32> {
- %0 = mesh.all_reduce %arg0 on @mesh0 mesh_axes = [0] reduction = <min>
+ %0 = mesh.all_reduce %arg0 on @mesh0 mesh_axes = [0] reduction = min
: tensor<5xf32> -> tensor<5xf32>
- %1 = mesh.all_reduce %arg1 on @mesh0 mesh_axes = [0] reduction = <min>
+ %1 = mesh.all_reduce %arg1 on @mesh0 mesh_axes = [0] reduction = min
: tensor<5xf32> -> tensor<5xf32>
// CHECK: %[[ADD_RES:[A-Za-z0-9_]*]] = arith.minimumf %[[ARG0]], %[[ARG1]]
%2 = arith.minimumf %0, %1 : tensor<5xf32>
- // CHECK: %[[ALL_REDUCE_RES:[A-Za-z0-9_]*]] = mesh.all_reduce %[[ADD_RES]] on @mesh0 mesh_axes = [0] reduction = <min>
+ // CHECK: %[[ALL_REDUCE_RES:[A-Za-z0-9_]*]] = mesh.all_reduce %[[ADD_RES]] on @mesh0 mesh_axes = [0] reduction = min
// CHECK: return %[[ALL_REDUCE_RES]]
return %2 : tensor<5xf32>
}
@@ -155,13 +155,13 @@ func.func @all_reduce_arith_minsi_endomorphism(
%arg0: tensor<5xi32>,
// CHECK-SAME: %[[ARG1:[A-Za-z0-9_]*]]: tensor<5xi32>
%arg1: tensor<5xi32>) -> tensor<5xi32> {
- %0 = mesh.all_reduce %arg0 on @mesh0 mesh_axes = [0] reduction = <min>
+ %0 = mesh.all_reduce %arg0 on @mesh0 mesh_axes = [0] reduction = min
: tensor<5xi32> -> tensor<5xi32>
- %1 = mesh.all_reduce %arg1 on @mesh0 mesh_axes = [0] reduction = <min>
+ %1 = mesh.all_reduce %arg1 on @mesh0 mesh_axes = [0] reduction = min
: tensor<5xi32> -> tensor<5xi32>
// CHECK: %[[ADD_RES:[A-Za-z0-9_]*]] = arith.minsi %[[ARG0]], %[[ARG1]]
%2 = arith.minsi %0, %1 : tensor<5xi32>
- // CHECK: %[[ALL_REDUCE_RES:[A-Za-z0-9_]*]] = mesh.all_reduce %[[ADD_RES]] on @mesh0 mesh_axes = [0] reduction = <min>
+ // CHECK: %[[ALL_REDUCE_RES:[A-Za-z0-9_]*]] = mesh.all_reduce %[[ADD_RES]] on @mesh0 mesh_axes = [0] reduction = min
// CHECK: return %[[ALL_REDUCE_RES]]
return %2 : tensor<5xi32>
}
diff --git a/mlir/test/Dialect/Mesh/spmdization.mlir b/mlir/test/Dialect/Mesh/spmdization.mlir
index d7a1e2fd9d2790..8b0c4053b0dc7e 100644
--- a/mlir/test/Dialect/Mesh/spmdization.mlir
+++ b/mlir/test/Dialect/Mesh/spmdization.mlir
@@ -10,8 +10,10 @@ func.func @full_replication(
%arg0: tensor<2xi8>
// CHECK-SAME: -> tensor<2xi8> {
) -> tensor<2xi8> {
- %0 = mesh.shard %arg0 to <@mesh_1d, [[]]> : tensor<2xi8>
- %1 = mesh.shard %0 to <@mesh_1d, [[]]> annotate_for_users: tensor<2xi8>
+ %s0 = mesh.sharding @mesh_1d split_axes = [[]] : !mesh.sharding
+ %0 = mesh.shard %arg0 to %s0 : tensor<2xi8>
+ %s1 = mesh.sharding @mesh_1d split_axes = [[]] : !mesh.sharding
+ %1 = mesh.shard %0 to %s1 annotate_for_users : tensor<2xi8>
// CHECK: return %[[ARG]] : tensor<2xi8>
return %1 : tensor<2xi8>
}
@@ -23,9 +25,12 @@ func.func @sharding_triplet(
// CHECK-SAME: ) -> tensor<2xf32> {
) -> tensor<2xf32> {
// CHECK: %[[ALL_GATHER:.*]] = mesh.all_gather %[[ARG]] on @mesh_1d mesh_axes = [0] gather_axis = 0 : tensor<1xf32> -> tensor<2xf32>
- %sharding_annotated = mesh.shard %arg0 to <@mesh_1d, [[0]]> : tensor<2xf32>
- %sharding_annotated_0 = mesh.shard %sharding_annotated to <@mesh_1d, [[0]]> annotate_for_users : tensor<2xf32>
- %sharding_annotated_1 = mesh.shard %sharding_annotated_0 to <@mesh_1d, [[]]> : tensor<2xf32>
+ %ssharding_annotated = mesh.sharding @mesh_1d split_axes = [[0]] : !mesh.sharding
+ %sharding_annotated = mesh.shard %arg0 to %ssharding_annotated : tensor<2xf32>
+ %ssharding_annotated_0 = mesh.sharding @mesh_1d split_axes = [[0]] : !mesh.sharding
+ %sharding_annotated_0 = mesh.shard %sharding_annotated to %ssharding_annotated_0 annotate_for_users : tensor<2xf32>
+ %ssharding_annotated_1 = mesh.sharding @mesh_1d split_axes = [[]] : !mesh.sharding
+ %sharding_annotated_1 = mesh.shard %sharding_annotated_0 to %ssharding_annotated_1 : tensor<2xf32>
// CHECK: return %[[ALL_GATHER]] : tensor<2xf32>
return %sharding_annotated_1 : tensor<2xf32>
}
@@ -39,8 +44,10 @@ func.func @move_split_axis(
) -> tensor<2x2xi8> {
// CHECK: %[[ALL_TO_ALL:.*]] = mesh.all_to_all %[[ARG]] on @mesh_1d
// CHECK-SAME: mesh_axes = [0] split_axis = 1 concat_axis = 0 : tensor<1x2xi8> -> tensor<2x1xi8>
- %0 = mesh.shard %arg0 to <@mesh_1d, [[0]]> : tensor<2x2xi8>
- %1 = mesh.shard %0 to <@mesh_1d, [[], [0]]> annotate_for_users: tensor<2x2xi8>
+ %s0 = mesh.sharding @mesh_1d split_axes = [[0]] : !mesh.sharding
+ %0 = mesh.shard %arg0 to %s0 : tensor<2x2xi8>
+ %s1 = mesh.sharding @mesh_1d split_axes = [[], [0]] : !mesh.sharding
+ %1 = mesh.shard %0 to %s1 annotate_for_users : tensor<2x2xi8>
// CHECK: return %[[ALL_TO_ALL]] : tensor<2x1xi8>
return %1 : tensor<2x2xi8>
}
@@ -63,12 +70,16 @@ func.func @unary_elementwise(
%arg0: tensor<2xi8>
// CHECK-SAME: -> tensor<1xi8> {
) -> tensor<2xi8> {
- %0 = mesh.shard %arg0 to <@mesh_1d, [[0]]> : tensor<2xi8>
- %1 = mesh.shard %0 to <@mesh_1d, [[0]]> annotate_for_users: tensor<2xi8>
+ %s0 = mesh.sharding @mesh_1d split_axes = [[0]] : !mesh.sharding
+ %0 = mesh.shard %arg0 to %s0 : tensor<2xi8>
+ %s1 = mesh.sharding @mesh_1d split_axes = [[0]] : !mesh.sharding
+ %1 = mesh.shard %0 to %s1 annotate_for_users : tensor<2xi8>
// CHECK: %[[RES:.*]] = tosa.abs %[[ARG]] : (tensor<1xi8>) -> tensor<1xi8>
%2 = tosa.abs %1 : (tensor<2xi8>) -> tensor<2xi8>
- %3 = mesh.shard %2 to <@mesh_1d, [[0]]> : tensor<2xi8>
- %4 = mesh.shard %3 to <@mesh_1d, [[0]]> annotate_for_users: tensor<2xi8>
+ %s3 = mesh.sharding @mesh_1d split_axes = [[0]] : !mesh.sharding
+ %3 = mesh.shard %2 to %s3 : tensor<2xi8>
+ %s4 = mesh.sharding @mesh_1d split_axes = [[0]] : !mesh.sharding
+ %4 = mesh.shard %3 to %s4 annotate_for_users : tensor<2xi8>
// CHECK: return %[[RES]] : tensor<1xi8>
return %4 : tensor<2xi8>
}
@@ -82,14 +93,18 @@ func.func @unary_elementwise_with_resharding(
) -> tensor<2xi8> {
// CHECK: %[[SLICE:.*]] = mesh.all_slice %[[ARG]] on @mesh_1d mesh_axes = [0] slice_axis = 0
// CHECK-SAME: tensor<2xi8> -> tensor<1xi8>
- %0 = mesh.shard %arg0 to <@mesh_1d, [[]]> : tensor<2xi8>
- %1 = mesh.shard %0 to <@mesh_1d, [[0]]> annotate_for_users: tensor<2xi8>
+ %s0 = mesh.sharding @mesh_1d split_axes = [[]] : !mesh.sharding
+ %0 = mesh.shard %arg0 to %s0 : tensor<2xi8>
+ %s1 = mesh.sharding @mesh_1d split_axes = [[0]] : !mesh.sharding
+ %1 = mesh.shard %0 to %s1 annotate_for_users : tensor<2xi8>
// CHECK: %[[ABS:.*]] = tosa.abs %[[SLICE]] : (tensor<1xi8>) -> tensor<1xi8>
%2 = tosa.abs %1 : (tensor<2xi8>) -> tensor<2xi8>
// CHECK: %[[RES:.*]] = mesh.all_gather %[[ABS]] on @mesh_1d
// CHECK-SAME: mesh_axes = [0] gather_axis = 0 : tensor<1xi8> -> tensor<2xi8>
- %3 = mesh.shard %2 to <@mesh_1d, [[0]]> : tensor<2xi8>
- %4 = mesh.shard %3 to <@mesh_1d, [[]]> annotate_for_users: tensor<2xi8>
+ %s3 = mesh.sharding @mesh_1d split_axes = [[0]] : !mesh.sharding
+ %3 = mesh.shard %2 to %s3 : tensor<2xi8>
+ %s4 = mesh.sharding @mesh_1d split_axes = [[]] : !mesh.sharding
+ %4 = mesh.shard %3 to %s4 annotate_for_users : tensor<2xi8>
// CHECK: return %[[RES]] : tensor<2xi8>
return %4 : tensor<2xi8>
}
@@ -102,14 +117,20 @@ func.func @binary_elementwise(
%arg1: tensor<2xi8>
// CHECK-SAME: -> tensor<1xi8> {
) -> tensor<2xi8> {
- %arg0_sharded = mesh.shard %arg0 to <@mesh_1d, [[0]]> : tensor<2xi8>
- %op_arg0 = mesh.shard %arg0_sharded to <@mesh_1d, [[0]]> annotate_for_users: tensor<2xi8>
- %arg1_sharded = mesh.shard %arg1 to <@mesh_1d, [[0]]> : tensor<2xi8>
- %op_arg1 = mesh.shard %arg1_sharded to <@mesh_1d, [[0]]> annotate_for_users: tensor<2xi8>
+ %sarg0_sharded = mesh.sharding @mesh_1d split_axes = [[0]] : !mesh.sharding
+ %arg0_sharded = mesh.shard %arg0 to %sarg0_sharded : tensor<2xi8>
+ %sop_arg0 = mesh.sharding @mesh_1d split_axes = [[0]] : !mesh.sharding
+ %op_arg0 = mesh.shard %arg0_sharded to %sop_arg0 annotate_for_users : tensor<2xi8>
+ %sarg1_sharded = mesh.sharding @mesh_1d split_axes = [[0]] : !mesh.sharding
+ %arg1_sharded = mesh.shard %arg1 to %sarg1_sharded : tensor<2xi8>
+ %sop_arg1 = mesh.sharding @mesh_1d split_axes = [[0]] : !mesh.sharding
+ %op_arg1 = mesh.shard %arg1_sharded to %sop_arg1 annotate_for_users : tensor<2xi8>
// CHECK: %[[RES:.*]] = tosa.add %[[ARG0]], %[[ARG1]] : (tensor<1xi8>, tensor<1xi8>) -> tensor<1xi8>
%op_res = tosa.add %op_arg0, %op_arg1 : (tensor<2xi8>, tensor<2xi8>) -> tensor<2xi8>
- %op_res_sharded = mesh.shard %op_res to <@mesh_1d, [[0]]> : tensor<2xi8>
- %res = mesh.shard %op_res_sharded to <@mesh_1d, [[0]]> annotate_for_users: tensor<2xi8>
+ %sop_res_sharded = mesh.sharding @mesh_1d split_axes = [[0]] : !mesh.sharding
+ %op_res_sharded = mesh.shard %op_res to %sop_res_sharded : tensor<2xi8>
+ %sres = mesh.sharding @mesh_1d split_axes = [[0]] : !mesh.sharding
+ %res = mesh.shard %op_res_sharded to %sres annotate_for_users : tensor<2xi8>
// CHECK: return %[[RES]] : tensor<1xi8>
return %res : tensor<2xi8>
}
@@ -127,20 +148,26 @@ func.func @multiple_chained_ops(
) -> tensor<2xi8> {
// CHECK: %[[RESHARD1:.*]] = mesh.all_slice %[[ARG]] on @mesh_1d mesh_axes = [0] slice_axis = 0
// CHECK-SAME: tensor<2xi8> -> tensor<1xi8>
- %0 = mesh.shard %arg0 to <@mesh_1d, [[]]> : tensor<2xi8>
- %1 = mesh.shard %0 to <@mesh_1d, [[0]]> annotate_for_users: tensor<2xi8>
+ %s0 = mesh.sharding @mesh_1d split_axes = [[]] : !mesh.sharding
+ %0 = mesh.shard %arg0 to %s0 : tensor<2xi8>
+ %s1 = mesh.sharding @mesh_1d split_axes = [[0]] : !mesh.sharding
+ %1 = mesh.shard %0 to %s1 annotate_for_users : tensor<2xi8>
// CHECK: %[[ABS1:.*]] = tosa.abs %[[RESHARD1]] : (tensor<1xi8>) -> tensor<1xi8>
%2 = tosa.abs %1 : (tensor<2xi8>) -> tensor<2xi8>
// CHECK: %[[RESHARD2:.*]] = mesh.all_gather %[[ABS1]] on @mesh_1d
// CHECK-SAME: mesh_axes = [0] gather_axis = 0 : tensor<1xi8> -> tensor<2xi8>
- %3 = mesh.shard %2 to <@mesh_1d, [[0]]> : tensor<2xi8>
- %4 = mesh.shard %3 to <@mesh_1d, [[]]> annotate_for_users: tensor<2xi8>
+ %s3 = mesh.sharding @mesh_1d split_axes = [[0]] : !mesh.sharding
+ %3 = mesh.shard %2 to %s3 : tensor<2xi8>
+ %s4 = mesh.sharding @mesh_1d split_axes = [[]] : !mesh.sharding
+ %4 = mesh.shard %3 to %s4 annotate_for_users : tensor<2xi8>
// CHECK: %[[ABS2:.*]] = tosa.abs %[[RESHARD2]] : (tensor<2xi8>) -> tensor<2xi8>
%5 = tosa.abs %4 : (tensor<2xi8>) -> tensor<2xi8>
// CHECK: %[[RESHARD3:.*]] = mesh.all_slice %[[ABS2]] on @mesh_1d mesh_axes = [0] slice_axis = 0 :
// CHECK-SAME: tensor<2xi8> -> tensor<1xi8>
- %6 = mesh.shard %5 to <@mesh_1d, [[]]> : tensor<2xi8>
- %7 = mesh.shard %6 to <@mesh_1d, [[0]]> annotate_for_users: tensor<2xi8>
+ %s6 = mesh.sharding @mesh_1d split_axes = [[]] : !mesh.sharding
+ %6 = mesh.shard %5 to %s6 : tensor<2xi8>
+ %s7 = mesh.sharding @mesh_1d split_axes = [[0]] : !mesh.sharding
+ %7 = mesh.shard %6 to %s7 annotate_for_users : tensor<2xi8>
// CHECK: return %[[RESHARD3]] : tensor<1xi8>
return %7 : tensor<2xi8>
}
@@ -151,10 +178,44 @@ func.func @incomplete_sharding(
%arg0: tensor<8x16xf32>
// CHECK-SAME: -> tensor<4x16xf32> {
) -> tensor<8x16xf32> {
- %0 = mesh.shard %arg0 to <@mesh_1d, [[0]]> annotate_for_users : tensor<8x16xf32>
+ %s0 = mesh.sharding @mesh_1d split_axes = [[0]] : !mesh.sharding
+ %0 = mesh.shard %arg0 to %s0 annotate_for_users : tensor<8x16xf32>
// CHECK: %[[RES:.*]] = tosa.sigmoid %[[ARG]] : (tensor<4x16xf32>) -> tensor<4x16xf32>
%1 = tosa.sigmoid %0 : (tensor<8x16xf32>) -> tensor<8x16xf32>
- %2 = mesh.shard %1 to <@mesh_1d, [[0]]> : tensor<8x16xf32>
+ %s2 = mesh.sharding @mesh_1d split_axes = [[0]] : !mesh.sharding
+ %2 = mesh.shard %1 to %s2 : tensor<8x16xf32>
// CHECK: return %[[RES]] : tensor<4x16xf32>
return %2 : tensor<8x16xf32>
}
+
+mesh.mesh @mesh_1d_4(shape = 4)
+
+// CHECK-LABEL: func @ew_chain_with_halo
+func.func @ew_chain_with_halo(
+ // CHECK-SAME: %[[IN1:[A-Za-z0-9_]+]]: tensor<5x16xf32>
+ %arg0: tensor<8x16xf32>)
+ // CHECK-SAME: -> tensor<5x16xf32>
+ -> tensor<8x16xf32> {
+ %ssharding_annotated = mesh.sharding @mesh_1d_4 split_axes = [[0]] halo_sizes = [2, 1] : !mesh.sharding
+ %sharding_annotated = mesh.shard %arg0 to %ssharding_annotated annotate_for_users : tensor<8x16xf32>
+ // CHECK: %[[TMP1:.*]] = tosa.tanh %[[IN1]] : (tensor<5x16xf32>) -> tensor<5x16xf32>
+ %0 = tosa.tanh %sharding_annotated : (tensor<8x16xf32>) -> tensor<8x16xf32>
+ %ssharding_annotated_0 = mesh.sharding @mesh_1d_4 split_axes = [[0]] halo_sizes = [2, 1] : !mesh.sharding
+ %sharding_annotated_0 = mesh.shard %0 to %ssharding_annotated_0 : tensor<8x16xf32>
+ %ssharding_annotated_1 = mesh.sharding @mesh_1d_4 split_axes = [[0]] halo_sizes = [2, 1] : !mesh.sharding
+ %sharding_annotated_1 = mesh.shard %sharding_annotated_0 to %ssharding_annotated_1 annotate_for_users : tensor<8x16xf32>
+ // CHECK-NEXT: %[[TMP2:.*]] = tosa.abs %[[TMP1]] : (tensor<5x16xf32>) -> tensor<5x16xf32>
+ %1 = tosa.abs %sharding_annotated_1 : (tensor<8x16xf32>) -> tensor<8x16xf32>
+ %ssharding_annotated_2 = mesh.sharding @mesh_1d_4 split_axes = [[0]] halo_sizes = [2, 1] : !mesh.sharding
+ %sharding_annotated_2 = mesh.shard %1 to %ssharding_annotated_2 : tensor<8x16xf32>
+ %ssharding_annotated_4 = mesh.sharding @mesh_1d_4 split_axes = [[0]] halo_sizes = [2, 1] : !mesh.sharding
+ %sharding_annotated_4 = mesh.shard %sharding_annotated_2 to %ssharding_annotated_4 annotate_for_users : tensor<8x16xf32>
+ // CHECK-NEXT: %[[TMP3:.*]] = tosa.negate %[[TMP2]] : (tensor<5x16xf32>) -> tensor<5x16xf32>
+ %2 = tosa.negate %sharding_annotated_4 : (tensor<8x16xf32>) -> tensor<8x16xf32>
+ %ssharding_annotated_5 = mesh.sharding @mesh_1d_4 split_axes = [[0]] halo_sizes = [2, 1] : !mesh.sharding
+ %sharding_annotated_5 = mesh.shard %2 to %ssharding_annotated_5 : tensor<8x16xf32>
+ %ssharding_annotated_6 = mesh.sharding @mesh_1d_4 split_axes = [[0]] halo_sizes = [2, 1] : !mesh.sharding
+ %sharding_annotated_6 = mesh.shard %sharding_annotated_5 to %ssharding_annotated_6 annotate_for_users : tensor<8x16xf32>
+ // CHECK-NEXT: return %[[TMP3]] : tensor<5x16xf32>
+ return %sharding_annotated_6 : tensor<8x16xf32>
+}
diff --git a/mlir/test/Dialect/Tensor/mesh-spmdization.mlir b/mlir/test/Dialect/Tensor/mesh-spmdization.mlir
new file mode 100644
index 00000000000000..611acb5b41445b
--- /dev/null
+++ b/mlir/test/Dialect/Tensor/mesh-spmdization.mlir
@@ -0,0 +1,42 @@
+// RUN: mlir-opt \
+// RUN: --pass-pipeline="builtin.module(func.func(mesh-spmdization,test-constant-fold))" \
+// RUN: %s | FileCheck %s
+
+mesh.mesh @mesh_1d_4(shape = 4)
+
+// CHECK-LABEL: func @tensor_empty_static_sharded_dims_sizes
+func.func @tensor_empty_static_sharded_dims_sizes() -> () {
+ %b = tensor.empty() : tensor<8x16xf32>
+ %sharding = mesh.sharding @mesh_1d_4 split_axes = [[0]] sharded_dims_sizes = [1, 3, 3, 1] : !mesh.sharding
+ %sharded= mesh.shard %b to %sharding : tensor<8x16xf32>
+ // CHECK: %[[sharding:.*]] = mesh.sharding @mesh_1d_4 split_axes = {{\[\[}}0]] sharded_dims_sizes = [1, 3, 3, 1] : !mesh.sharding
+ // CHECK: %[[proc_linear_idx:.*]] = mesh.process_linear_index on @mesh_1d_4 : index
+ // CHECK: %[[V0:.*]]:2 = mesh.shard_shape 8x16 %[[sharding]] %[[proc_linear_idx]] : index, index
+ // CHECK: tensor.empty(%[[V0]]#0) : tensor<?x16xf32>
+
+ return
+}
+
+// CHECK-LABEL: func @tensor_empty_dynamic_sharded_dims_sizes
+// CHECK-SAME: %[[A0:.*]]: index
+func.func @tensor_empty_dynamic_sharded_dims_sizes(%arg0 : index) -> () {
+ %b = tensor.empty(%arg0) : tensor<8x?xf32>
+ %sharding = mesh.sharding @mesh_1d_4 split_axes = [[0]] sharded_dims_sizes = [1, 3, 3, 1] : !mesh.sharding
+ %sharded= mesh.shard %b to %sharding : tensor<8x?xf32>
+ // CHECK: %[[sharding:.*]] = mesh.sharding @mesh_1d_4 split_axes = {{\[\[}}0]] sharded_dims_sizes = [1, 3, 3, 1] : !mesh.sharding
+ // CHECK: %[[proc_linear_idx:.*]] = mesh.process_linear_index on @mesh_1d_4 : index
+ // CHECK: %[[V0:.*]]:2 = mesh.shard_shape 8x? %[[sharding]] %[[proc_linear_idx]] : index, index
+ // CHECK: tensor.empty(%[[V0]]#0, %[[A0]]) : tensor<?x?xf32>
+
+ return
+}
+
+// CHECK-LABEL: func @tensor_empty_same_static_dims_sizes
+func.func @tensor_empty_same_static_dims_sizes() -> () {
+ %b = tensor.empty() : tensor<8x16xf32>
+ %sharding = mesh.sharding @mesh_1d_4 split_axes = [[0]] sharded_dims_sizes = [4, 4, 4, 4] : !mesh.sharding
+ %sharded= mesh.shard %b to %sharding : tensor<8x16xf32>
+ // CHECK-NEXT: tensor.empty() : tensor<4x16xf32>
+
+ return
+}
diff --git a/mlir/test/lib/Dialect/Mesh/TestReshardingSpmdization.cpp b/mlir/test/lib/Dialect/Mesh/TestReshardingSpmdization.cpp
index f96410245f2815..98992c4cc11f92 100644
--- a/mlir/test/lib/Dialect/Mesh/TestReshardingSpmdization.cpp
+++ b/mlir/test/lib/Dialect/Mesh/TestReshardingSpmdization.cpp
@@ -37,14 +37,17 @@ struct TestMeshReshardingRewritePattern : OpRewritePattern<ShardOp> {
SymbolTableCollection symbolTable;
mesh::MeshOp mesh = symbolTable.lookupNearestSymbolFrom<mesh::MeshOp>(
- op, op.getShard().getMesh());
+ op, cast<ShardingOp>(op.getSharding().getDefiningOp()).getMeshAttr());
bool foundUser = false;
for (auto user : op->getUsers()) {
if (auto targetShardOp = llvm::dyn_cast<ShardOp>(user)) {
if (targetShardOp.getAnnotateForUsers() &&
mesh == symbolTable.lookupNearestSymbolFrom<mesh::MeshOp>(
- targetShardOp, targetShardOp.getShard().getMesh())) {
+ targetShardOp,
+ cast<ShardingOp>(
+ targetShardOp.getSharding().getDefiningOp())
+ .getMeshAttr())) {
foundUser = true;
break;
}
@@ -59,17 +62,18 @@ struct TestMeshReshardingRewritePattern : OpRewritePattern<ShardOp> {
auto targetShardOp = llvm::dyn_cast<ShardOp>(user);
if (!targetShardOp || !targetShardOp.getAnnotateForUsers() ||
symbolTable.lookupNearestSymbolFrom<mesh::MeshOp>(
- targetShardOp, targetShardOp.getShard().getMesh()) != mesh) {
+ targetShardOp,
+ cast<ShardingOp>(targetShardOp.getSharding().getDefiningOp())
+ .getMeshAttr()) != mesh) {
continue;
}
ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
ShapedType sourceShardShape =
- shardShapedType(op.getResult().getType(), mesh, op.getShard());
+ shardShapedType(op.getResult().getType(), mesh, op.getSharding());
TypedValue<ShapedType> sourceShard = cast<TypedValue<ShapedType>>(
builder
- .create<UnrealizedConversionCastOp>(sourceShardShape,
- op.getOperand())
+ .create<UnrealizedConversionCastOp>(sourceShardShape, op.getSrc())
->getResult(0));
TypedValue<ShapedType> targetShard =
reshard(builder, mesh, op, targetShardOp, sourceShard);
More information about the Mlir-commits mailing list