[Mlir-commits] [mlir] ffc7fea - [mlir][mesh] Handling changed halo region sizes during spmdization (#114238)
llvmlistbot at llvm.org
Sun Nov 10 21:49:54 PST 2024
Author: Frank Schlimbach
Date: 2024-11-10T21:49:50-08:00
New Revision: ffc7feadece139c88f0e6930f16bfa9293747adc
URL: https://github.com/llvm/llvm-project/commit/ffc7feadece139c88f0e6930f16bfa9293747adc
DIFF: https://github.com/llvm/llvm-project/commit/ffc7feadece139c88f0e6930f16bfa9293747adc.diff
LOG: [mlir][mesh] Handling changed halo region sizes during spmdization (#114238)
* Changed `MeshSharding::sharded_dims_sizes` from representing sizes per
shard to representing offsets from the origin per shard (see the sketch
after this list).
- Local shard sizes are now a simple subtraction
- Offsets are readily available without a reduction operation
- Enables constant value/shape propagation through standard
canonicalization
- Renamed to `sharded_dims_offsets` accordingly.
* First spmdization pattern for halo regions.
- Triggers when source and destination shardings differ only in their
halo sizes
- Copies local data from source into a new tensor and calls update_halo
- Supports arbitrary mesh dimensions (unlike the other patterns, which
work on 1d meshes only)
* `UpdateHaloOp` implements `DestinationStyleOpInterface` and accepts
tensors and memrefs
- Also accepts source and destination halo sizes; both are required for
proper lowering
* Minor refactoring for testing partial `MeshSharding` equality
* Canonicalization for `ShardingOp` that folds constant values into the
respective `static_*` attributes
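For illustration, a hedged sketch of the encoding change, using the non-uniform
4-way split from the updated `Tensor/mesh-spmdization.mlir` tests:

```mlir
// Old encoding (sizes per shard):   sharded_dims_sizes   = [1, 3, 3, 1]
// New encoding (offsets per shard): sharded_dims_offsets = [0, 1, 4, 7, 8]
%sharding = mesh.sharding @mesh_1d_4 split_axes = [[0]]
    sharded_dims_offsets = [0, 1, 4, 7, 8] : !mesh.sharding
// The local size of shard i is offsets[i+1] - offsets[i];
// e.g. shard 1 spans [1, 4) and thus has size 3.
```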
Added:
Modified:
mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h
mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td
mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp
mlir/test/Dialect/Mesh/canonicalization.mlir
mlir/test/Dialect/Mesh/invalid.mlir
mlir/test/Dialect/Mesh/ops.mlir
mlir/test/Dialect/Mesh/spmdization.mlir
mlir/test/Dialect/Tensor/mesh-spmdization.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h
index db7b64fda57d7b..75cb096130ca6e 100644
--- a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h
+++ b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h
@@ -15,6 +15,7 @@
#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/SymbolTable.h"
+#include "mlir/Interfaces/DestinationStyleOpInterface.h"
#include "mlir/Interfaces/InferTypeOpInterface.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
#include "llvm/Support/MathExtras.h"
@@ -45,9 +46,9 @@ class MeshSharding {
SmallVector<MeshAxis> partial_axes;
ReductionKind partial_type;
SmallVector<int64_t> static_halo_sizes;
- SmallVector<int64_t> static_sharded_dims_sizes;
+ SmallVector<int64_t> static_sharded_dims_offsets;
SmallVector<Value> dynamic_halo_sizes;
- SmallVector<Value> dynamic_sharded_dims_sizes;
+ SmallVector<Value> dynamic_sharded_dims_offsets;
public:
MeshSharding() = default;
@@ -57,21 +58,21 @@ class MeshSharding {
ArrayRef<MeshAxis> partial_axes_ = {},
ReductionKind partial_type_ = ReductionKind::Sum,
ArrayRef<int64_t> static_halo_sizes_ = {},
- ArrayRef<int64_t> static_sharded_dims_sizes_ = {},
+ ArrayRef<int64_t> static_sharded_dims_offsets_ = {},
ArrayRef<Value> dynamic_halo_sizes_ = {},
- ArrayRef<Value> dynamic_sharded_dims_sizes_ = {});
+ ArrayRef<Value> dynamic_sharded_dims_offsets_ = {});
::mlir::FlatSymbolRefAttr getMeshAttr() const { return mesh; }
::llvm::StringRef getMesh() const { return mesh.getValue(); }
ArrayRef<MeshAxesAttr> getSplitAxes() const { return split_axes; }
ArrayRef<MeshAxis> getPartialAxes() const { return partial_axes; }
ReductionKind getPartialType() const { return partial_type; }
ArrayRef<int64_t> getStaticHaloSizes() const { return static_halo_sizes; }
- ArrayRef<int64_t> getStaticShardedDimsSizes() const {
- return static_sharded_dims_sizes;
+ ArrayRef<int64_t> getStaticShardedDimsOffsets() const {
+ return static_sharded_dims_offsets;
}
ArrayRef<Value> getDynamicHaloSizes() const { return dynamic_halo_sizes; }
- ArrayRef<Value> getDynamicShardedDimsSizes() const {
- return dynamic_sharded_dims_sizes;
+ ArrayRef<Value> getDynamicShardedDimsOffsets() const {
+ return dynamic_sharded_dims_offsets;
}
operator bool() const { return (!mesh) == false; }
bool operator==(Value rhs) const;
@@ -80,6 +81,8 @@ class MeshSharding {
bool operator!=(const MeshSharding &rhs) const;
bool equalSplitAndPartialAxes(const MeshSharding &rhs) const;
bool equalHaloAndShardSizes(const MeshSharding &rhs) const;
+ bool equalHaloSizes(const MeshSharding &rhs) const;
+ bool equalShardSizes(const MeshSharding &rhs) const;
};
} // namespace mesh
diff --git a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td
index 8f696bbc1a0f6e..19498fe5a32d69 100644
--- a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td
+++ b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td
@@ -11,6 +11,7 @@
include "mlir/Dialect/Mesh/IR/MeshBase.td"
include "mlir/Dialect/Shape/IR/ShapeBase.td"
+include "mlir/Interfaces/DestinationStyleOpInterface.td"
include "mlir/Interfaces/InferTypeOpInterface.td"
include "mlir/Interfaces/SideEffectInterfaces.td"
include "mlir/IR/BuiltinTypes.td"
@@ -189,23 +190,27 @@ def Mesh_ShardingOp : Mesh_Op<"sharding", [
`generic`: is not an allowed value inside a shard attribute.
5. [Optional] Sizes of halos to be added for each sharded tensor dimension.
- `halo_sizes`is provided as a flattened 1d array of i64s, 2 values for each sharded dimension.
- `halo_sizes` = [1, 2] means that the first sharded dimension gets an additional
- halo of size 1 at the start of the first dimension and a halo size is 2 at its end.
- `halo_sizes` = [1, 2, 2, 3] defines halos for the first 2 sharded dimensions
- e.g. the first sharded dimension gets [1,2] halos and the seconds gets [2,3] halos.
- `?` indicates dynamic halo sizes.
+ `halo_sizes` is provided as a flattened 1d array of i64s, 2 values for each
+ sharded dimension. `halo_sizes = [1, 2]` means that the first sharded dimension
+ gets an additional halo of size 1 at the start of the first dimension and a halo
+ of size 2 at its end. `halo_sizes = [1, 2, 2, 3]` defines halos for the first 2
+ sharded dimensions, e.g. the first sharded dimension gets `[1,2]` halos and the
+ second gets `[2,3]` halos. `?` indicates dynamic halo sizes.
+
+ 6. [Optional] Offsets for each shard and sharded tensor dimension.
+ `sharded_dims_offsets` is provided as a flattened 1d array of i64s. For each
+ sharded tensor dimension the offsets (starting index) of all shards in that
+ dimension and an additional value for the end of the last shard are provided.
+ For a 1d sharding this means that position `i` has the exclusive prefix sum for
+ shard `i`, and since only contiguous sharding is supported, its inclusive prefix
+ sum is at position `i+1`.
- 6. [Optional] Sizes of sharded dimensions of each shard.
- `sharded_dims_sizes`is provided as a flattened 1d array of i64s: for each device of the
- device-mesh one value for each sharded tensor dimension.
Assuming a 3d-tensor of shape 32x32x32 with the first 2 dimensions being sharded,
- `sharded_dims_sizes` = [16, 8, 16, 24] means that the first device of
- the device-mesh will get a shard of shape 16x8x32 and the second device will get a
- shard of shape 16x24x32.
- `?` indicates dynamic shard dimensions.
+ `sharded_dims_offsets` = [0, 24, 32, 0, 20, 32] means that the first device of
+ the device-mesh will get a shard of shape 24x20x32 and the second device will get
+ a shard of shape 8x12x32. `?` indicates dynamic shard dimensions.
- `halo_sizes` and `sharded_dims_sizes` are mutually exclusive.
+ `halo_sizes` and `sharded_dims_offsets` are mutually exclusive.
Examples:
@@ -240,7 +245,7 @@ def Mesh_ShardingOp : Mesh_Op<"sharding", [
// The tensor is sharded on its second dimension along axis 0 of @mesh1d_4
// and it has pre-defined shard sizes. The shards of the devices will have
// the following shapes: [4x2, 4x3, 4x4, 4x5]
- %sharding4 = mesh.sharding @mesh1d_4 split_axes = [[] split_axes = [0]] sharded_dims_sizes = [2, 3, 4, 5]
+ %sharding4 = mesh.sharding @mesh1d_4 split_axes = [[], [0]] sharded_dims_offsets = [0, 2, 5, 9, 14]
%sharded2 = mesh.shard %arg0 to %sharding4 : tensor<4x14xf32>
```
}];
@@ -250,8 +255,8 @@ def Mesh_ShardingOp : Mesh_Op<"sharding", [
Mesh_MeshAxesArrayAttr:$split_axes,
OptionalAttr<Mesh_MeshAxesAttr>:$partial_axes,
OptionalAttr<Mesh_ReductionKindAttr>:$partial_type,
- DefaultValuedAttr<DenseI64ArrayAttr, "{}">:$static_sharded_dims_sizes,
- Variadic<I64>:$dynamic_sharded_dims_sizes,
+ DefaultValuedAttr<DenseI64ArrayAttr, "{}">:$static_sharded_dims_offsets,
+ Variadic<I64>:$dynamic_sharded_dims_offsets,
DefaultValuedAttr<DenseI64ArrayAttr, "{}">:$static_halo_sizes,
Variadic<I64>:$dynamic_halo_sizes
);
@@ -263,7 +268,7 @@ def Mesh_ShardingOp : Mesh_Op<"sharding", [
`split_axes` `=` $split_axes
(`partial` `=` $partial_type $partial_axes^)?
(`halo_sizes` `=` custom<DynamicIndexList>($dynamic_halo_sizes, $static_halo_sizes)^)?
- (`sharded_dims_sizes` `=` custom<DynamicIndexList>($dynamic_sharded_dims_sizes, $static_sharded_dims_sizes)^)?
+ (`sharded_dims_offsets` `=` custom<DynamicIndexList>($dynamic_sharded_dims_offsets, $static_sharded_dims_offsets)^)?
attr-dict `:` type($result)
}];
let builders = [
@@ -272,16 +277,17 @@ def Mesh_ShardingOp : Mesh_Op<"sharding", [
"ArrayRef<MeshAxis>":$partial_axes,
"mesh::ReductionKind":$partial_type,
CArg<"ArrayRef<int64_t>", "{}">:$static_halo_sizes,
- CArg<"ArrayRef<int64_t>", "{}">:$static_sharded_dims_sizes)>,
+ CArg<"ArrayRef<int64_t>", "{}">:$static_sharded_dims_offsets)>,
OpBuilder<(ins "FlatSymbolRefAttr":$mesh,
"ArrayRef<MeshAxesAttr>":$split_axes)>,
OpBuilder<(ins "FlatSymbolRefAttr":$mesh,
"ArrayRef<MeshAxesAttr>":$split_axes,
"::mlir::ArrayRef<::mlir::OpFoldResult>":$halo_sizes,
- "::mlir::ArrayRef<::mlir::OpFoldResult>":$sharded_dims_sizes)>,
+ "::mlir::ArrayRef<::mlir::OpFoldResult>":$sharded_dims_offsets)>,
OpBuilder<(ins "mlir::mesh::MeshSharding":$from)>
];
let hasVerifier = 1;
+ let hasCanonicalizer = 1;
}
def Mesh_ShardShapeOp : Mesh_Op<"shard_shape", [Pure]> {
@@ -1052,37 +1058,54 @@ def Mesh_ShiftOp : Mesh_CollectiveCommunicationOpBase<"shift", [
}
def Mesh_UpdateHaloOp : Mesh_Op<"update_halo", [
- DeclareOpInterfaceMethods<SymbolUserOpInterface>
+ DestinationStyleOpInterface,
+ TypesMatchWith<
+ "result has same type as destination",
+ "result", "destination", "$_self">,
+ DeclareOpInterfaceMethods<SymbolUserOpInterface>,
+ AttrSizedOperandSegments
]> {
let summary = "Update halo data.";
let description = [{
This operation updates halo regions of shards, e.g. if their sharding
- specified halos and the actual tensor data might have changed
+ specified halos and the actual tensor/memref data might have changed
on the remote devices. Changes might be caused by mutating operations
and/or if the new halo regions are larger than the existing ones.
+ Source and destination might have different halo sizes.
+
Assumes all devices hold tensors with same-sized halo data as specified
- by `dynamic/static_halo_sizes`.
+ by `source_halo_sizes/static_source_halo_sizes` and
+ `destination_halo_sizes/static_destination_halo_sizes` in source shard
+ and destination/result shard.
`split_axes` specifies for each tensor axis along which mesh axes its halo
data is updated.
- Optionally resizes to new halo sizes `target_halo_sizes`.
}];
let arguments = (ins
- AnyNon0RankedMemRef:$input,
+ AnyTypeOf<[AnyNon0RankedMemRef, AnyNon0RankedTensor]>:$source,
+ AnyTypeOf<[AnyNon0RankedMemRef, AnyNon0RankedTensor]>:$destination,
FlatSymbolRefAttr:$mesh,
Mesh_MeshAxesArrayAttr:$split_axes,
- Variadic<I64>:$dynamic_halo_sizes,
- DefaultValuedAttr<DenseI64ArrayAttr, "{}">:$static_halo_sizes,
- DefaultValuedAttr<DenseI64ArrayAttr, "{}">:$target_halo_sizes
+ Variadic<I64>:$source_halo_sizes,
+ DefaultValuedAttr<DenseI64ArrayAttr, "{}">:$static_source_halo_sizes,
+ Variadic<I64>:$destination_halo_sizes,
+ DefaultValuedAttr<DenseI64ArrayAttr, "{}">:$static_destination_halo_sizes
+ );
+ let results = (outs
+ AnyTypeOf<[AnyNon0RankedMemRef, AnyNon0RankedTensor]>:$result
);
let assemblyFormat = [{
- $input `on` $mesh
+ $source `into` $destination
+ `on` $mesh
`split_axes` `=` $split_axes
- (`halo_sizes` `=` custom<DynamicIndexList>($dynamic_halo_sizes, $static_halo_sizes)^)?
- (`target_halo_sizes` `=` $target_halo_sizes^)?
- attr-dict `:` type($input)
+ (`source_halo_sizes` `=` custom<DynamicIndexList>($source_halo_sizes, $static_source_halo_sizes)^)?
+ (`destination_halo_sizes` `=` custom<DynamicIndexList>($destination_halo_sizes, $static_destination_halo_sizes)^)?
+ attr-dict `:` type($source) `->` type($result)
+ }];
+ let extraClassDeclaration = [{
+ MutableOperandRange getDpsInitsMutable() { return getDestinationMutable(); }
}];
}
#endif // MLIR_DIALECT_MESH_IR_MESHOPS_TD
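A minimal sketch of the new `update_halo` assembly format (value names and
shapes are hypothetical, and `%src` stands for a local shard; in real lowerings
the core data is first copied into the destination, as the spmdization pattern
further below does):

```mlir
mesh.mesh @mesh_1d_4(shape = 4)
// Resize the dim-0 halos of a local shard from [2, 2] to [3, 3]:
// the 296-row core is kept, so the result has 296 + 3 + 3 = 302 rows.
%init = tensor.empty() : tensor<302x1200xi64>
%dst = mesh.update_halo %src into %init on @mesh_1d_4 split_axes = [[0]]
    source_halo_sizes = [2, 2] destination_halo_sizes = [3, 3]
    : tensor<300x1200xi64> -> tensor<302x1200xi64>
```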
diff --git a/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp b/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
index 19e9212157ae47..c5570d8ee8a443 100644
--- a/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
+++ b/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
@@ -192,33 +192,34 @@ template <typename InShape, typename MeshShape, typename SplitAxes,
typename OutShape>
static void shardShape(const InShape &inShape, const MeshShape &meshShape,
const SplitAxes &splitAxes, OutShape &outShape,
- ArrayRef<int64_t> shardedDimsSizes = {},
+ ArrayRef<int64_t> shardedDimsOffsets = {},
ArrayRef<int64_t> haloSizes = {}) {
std::copy(llvm::adl_begin(inShape), llvm::adl_end(inShape),
llvm::adl_begin(outShape));
- if (!shardedDimsSizes.empty()) {
+ if (!shardedDimsOffsets.empty()) {
+ auto isDynShape = ShapedType::isDynamicShape(meshShape);
+ uint64_t pos = 1;
for (auto [tensorAxis, innerSplitAxes] : llvm::enumerate(splitAxes)) {
- if (innerSplitAxes.empty()) {
-#ifndef NDEBUG
- for (auto dimSz : shardedDimsSizes) {
- auto inAxis = dimSz % inShape.size();
- assert(inShape[inAxis] == dimSz || dimSz == ShapedType::kDynamic ||
- inShape[inAxis] == ShapedType::kDynamic);
- }
-#endif // NDEBUG
- } else {
- // find sharded dims in sharded_dims_sizes with same static size on
- // all devices. Use kDynamic for dimensions with dynamic or non-uniform
- // sizes in sharded_dims_sizes.
- auto sz = shardedDimsSizes[tensorAxis];
- bool same = true;
- for (size_t i = tensorAxis + inShape.size();
- i < shardedDimsSizes.size(); i += inShape.size()) {
- if (shardedDimsSizes[i] != sz) {
- same = false;
- break;
+ if (!innerSplitAxes.empty()) {
+ auto sz = shardedDimsOffsets[pos];
+ bool same = !isDynShape;
+ if (same) {
+ // Find sharded dims in shardedDimsOffsets with same static size on
+ // all devices. Use kDynamic for dimensions with dynamic or
+ // non-uniform offs in shardedDimsOffsets.
+ uint64_t numShards = 0;
+ for (auto i : innerSplitAxes.asArrayRef()) {
+ numShards += meshShape[i];
+ }
+ for (size_t i = 1; i < numShards; ++i) {
+ if (shardedDimsOffsets[pos + i] - shardedDimsOffsets[pos + i - 1] !=
+ sz) {
+ same = false;
+ break;
+ }
}
+ pos += numShards + 1;
}
outShape[tensorAxis] = same ? sz : ShapedType::kDynamic;
}
@@ -255,7 +256,7 @@ ShapedType mesh::shardShapedType(ShapedType shape, MeshOp mesh,
using Dim = std::decay_t<decltype(shape.getDimSize(0))>;
SmallVector<Dim> resShapeArr(shape.getShape().size());
shardShape(shape.getShape(), mesh.getShape(), sharding.getSplitAxes(),
- resShapeArr, sharding.getStaticShardedDimsSizes(),
+ resShapeArr, sharding.getStaticShardedDimsOffsets(),
sharding.getStaticHaloSizes());
return shape.clone(resShapeArr);
}
@@ -432,13 +433,14 @@ void ShardingOp::build(::mlir::OpBuilder &b, ::mlir::OperationState &odsState,
ArrayRef<MeshAxis> partial_axes,
mesh::ReductionKind partial_type,
ArrayRef<int64_t> static_halo_sizes,
- ArrayRef<int64_t> static_sharded_dims_sizes) {
+ ArrayRef<int64_t> static_sharded_dims_offsets) {
return build(
b, odsState, mesh, MeshAxesArrayAttr::get(b.getContext(), split_axes),
::mlir::DenseI16ArrayAttr::get(b.getContext(), partial_axes),
::mlir::mesh::ReductionKindAttr::get(b.getContext(), partial_type),
::mlir::DenseI64ArrayAttr::get(b.getContext(), static_halo_sizes), {},
- ::mlir::DenseI64ArrayAttr::get(b.getContext(), static_sharded_dims_sizes),
+ ::mlir::DenseI64ArrayAttr::get(b.getContext(),
+ static_sharded_dims_offsets),
{});
}
@@ -455,11 +457,11 @@ void ShardingOp::build(
::mlir::OpBuilder &b, ::mlir::OperationState &odsState,
FlatSymbolRefAttr mesh, ArrayRef<MeshAxesAttr> split_axes,
::mlir::ArrayRef<::mlir::OpFoldResult> halo_sizes,
- ::mlir::ArrayRef<::mlir::OpFoldResult> sharded_dims_sizes) {
+ ::mlir::ArrayRef<::mlir::OpFoldResult> sharded_dims_offsets) {
mlir::SmallVector<int64_t> staticHalos, staticDims;
mlir::SmallVector<mlir::Value> dynamicHalos, dynamicDims;
dispatchIndexOpFoldResults(halo_sizes, dynamicHalos, staticHalos);
- dispatchIndexOpFoldResults(sharded_dims_sizes, dynamicDims, staticDims);
+ dispatchIndexOpFoldResults(sharded_dims_offsets, dynamicDims, staticDims);
return build(
b, odsState, mesh, MeshAxesArrayAttr::get(b.getContext(), split_axes), {},
::mlir::mesh::ReductionKindAttr::get(b.getContext(), ReductionKind::Sum),
@@ -477,10 +479,10 @@ void ShardingOp::build(::mlir::OpBuilder &b, ::mlir::OperationState &odsState,
: b.getDenseI16ArrayAttr(from.getPartialAxes()),
::mlir::mesh::ReductionKindAttr::get(b.getContext(),
from.getPartialType()),
- from.getStaticShardedDimsSizes().empty()
+ from.getStaticShardedDimsOffsets().empty()
? DenseI64ArrayAttr()
- : b.getDenseI64ArrayAttr(from.getStaticShardedDimsSizes()),
- from.getDynamicShardedDimsSizes(),
+ : b.getDenseI64ArrayAttr(from.getStaticShardedDimsOffsets()),
+ from.getDynamicShardedDimsOffsets(),
from.getStaticHaloSizes().empty()
? DenseI64ArrayAttr()
: b.getDenseI64ArrayAttr(from.getStaticHaloSizes()),
@@ -509,8 +511,8 @@ LogicalResult ShardingOp::verify() {
failed(checkMeshAxis(getPartialAxes().value())))
return failure();
- if (!getStaticHaloSizes().empty() && !getStaticShardedDimsSizes().empty()) {
- return emitOpError("halo sizes and shard shapes are mutually exclusive");
+ if (!getStaticHaloSizes().empty() && !getStaticShardedDimsOffsets().empty()) {
+ return emitOpError("halo sizes and shard offsets are mutually exclusive");
}
if (!getStaticHaloSizes().empty()) {
@@ -539,13 +541,81 @@ LogicalResult ShardingOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
return failure();
}
if (mlir::ShapedType::isDynamicShape(mesh->getShape()) &&
- getStaticShardedDimsSizes().size() > 0) {
- return emitError() << "sharded dims sizes are not allowed for "
+ getStaticShardedDimsOffsets().size() > 0) {
+ return emitError() << "sharded dims offsets are not allowed for "
"devices meshes with dynamic shape.";
}
+
+ auto shardedDimsOffsets = getStaticShardedDimsOffsets();
+ if (!shardedDimsOffsets.empty()) {
+ auto meshShape = mesh.value().getShape();
+ assert(!ShapedType::isDynamicShape(meshShape));
+ uint64_t pos = 0;
+ for (auto [tensorAxis, innerSplitAxes] : llvm::enumerate(getSplitAxes())) {
+ if (!innerSplitAxes.empty()) {
+ int64_t numShards = 0, off = 0;
+ for (auto i : innerSplitAxes.asArrayRef()) {
+ numShards += meshShape[i];
+ }
+ for (int64_t i = 0; i <= numShards; ++i) {
+ if (shardedDimsOffsets.size() <= pos + i) {
+ return emitError() << "sharded dims offsets has wrong size.";
+ }
+ if (!ShapedType::isDynamic(shardedDimsOffsets[pos + i])) {
+ if (shardedDimsOffsets[pos + i] < off) {
+ return emitError()
+ << "sharded dims offsets must be non-decreasing.";
+ }
+ off = shardedDimsOffsets[pos + i];
+ }
+ }
+ pos += numShards + 1;
+ }
+ }
+ }
return success();
}
+namespace {
+// Sharding annotations "halo sizes" and "sharded dims offsets"
+// are a mix of attributes and dynamic values. This canonicalization moves
+// constant values to the respective attribute lists and so minimizes the number
+// of values.
+class FoldDynamicLists final : public OpRewritePattern<ShardingOp> {
+public:
+ using OpRewritePattern<ShardingOp>::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(ShardingOp op,
+ PatternRewriter &b) const override {
+ auto mixedHalos =
+ getMixedValues(op.getStaticHaloSizes(), op.getDynamicHaloSizes(), b);
+ auto mixedOffs = getMixedValues(op.getStaticShardedDimsOffsets(),
+ op.getDynamicShardedDimsOffsets(), b);
+
+ // No constant operands were folded, just return.
+ if (failed(foldDynamicIndexList(mixedHalos, /*onlyNonNegative=*/true)) &&
+ failed(foldDynamicIndexList(mixedOffs, /*onlyNonNegative=*/true))) {
+ return failure();
+ }
+
+ auto halos = decomposeMixedValues(mixedHalos);
+ auto offs = decomposeMixedValues(mixedOffs);
+
+ op.setStaticHaloSizes(halos.first);
+ op.getDynamicHaloSizesMutable().assign(halos.second);
+ op.setStaticShardedDimsOffsets(offs.first);
+ op.getDynamicShardedDimsOffsetsMutable().assign(offs.second);
+
+ return success();
+ }
+};
+} // namespace
+
+void ShardingOp::getCanonicalizationPatterns(mlir::RewritePatternSet &results,
+ mlir::MLIRContext *context) {
+ results.add<FoldDynamicLists>(context);
+}
+
//===----------------------------------------------------------------------===//
// MeshSharding
//===----------------------------------------------------------------------===//
@@ -555,7 +625,12 @@ bool MeshSharding::equalSplitAndPartialAxes(const MeshSharding &rhs) const {
return false;
}
- if (!getPartialAxes().empty() && getPartialType() != rhs.getPartialType()) {
+ if (getPartialAxes().size() != rhs.getPartialAxes().size() ||
+ (!getPartialAxes().empty() && getPartialType() != rhs.getPartialType()) ||
+ !llvm::equal(
+ llvm::make_range(getPartialAxes().begin(), getPartialAxes().end()),
+ llvm::make_range(rhs.getPartialAxes().begin(),
+ rhs.getPartialAxes().end()))) {
return false;
}
@@ -576,6 +651,31 @@ bool MeshSharding::equalSplitAndPartialAxes(const MeshSharding &rhs) const {
}
bool MeshSharding::equalHaloAndShardSizes(const MeshSharding &rhs) const {
+ return equalShardSizes(rhs) && equalHaloSizes(rhs);
+}
+
+bool MeshSharding::equalShardSizes(const MeshSharding &rhs) const {
+ if (rhs.getStaticShardedDimsOffsets().size() !=
+ getStaticShardedDimsOffsets().size() ||
+ !llvm::equal(llvm::make_range(getStaticShardedDimsOffsets().begin(),
+ getStaticShardedDimsOffsets().end()),
+ llvm::make_range(rhs.getStaticShardedDimsOffsets().begin(),
+ rhs.getStaticShardedDimsOffsets().end()))) {
+ return false;
+ }
+ if (rhs.getDynamicShardedDimsOffsets().size() !=
+ getDynamicShardedDimsOffsets().size() ||
+ !llvm::equal(
+ llvm::make_range(getDynamicShardedDimsOffsets().begin(),
+ getDynamicShardedDimsOffsets().end()),
+ llvm::make_range(rhs.getDynamicShardedDimsOffsets().begin(),
+ rhs.getDynamicShardedDimsOffsets().end()))) {
+ return false;
+ }
+ return true;
+}
+
+bool MeshSharding::equalHaloSizes(const MeshSharding &rhs) const {
if (rhs.getStaticHaloSizes().size() != getStaticHaloSizes().size() ||
!llvm::equal(llvm::make_range(getStaticHaloSizes().begin(),
getStaticHaloSizes().end()),
@@ -583,28 +683,13 @@ bool MeshSharding::equalHaloAndShardSizes(const MeshSharding &rhs) const {
rhs.getStaticHaloSizes().end()))) {
return false;
}
- if (rhs.getStaticShardedDimsSizes().size() != getDynamicHaloSizes().size() ||
- !llvm::equal(llvm::make_range(getStaticShardedDimsSizes().begin(),
- getStaticShardedDimsSizes().end()),
- llvm::make_range(rhs.getStaticShardedDimsSizes().begin(),
- rhs.getStaticShardedDimsSizes().end()))) {
- return false;
- }
- if (rhs.getDynamicHaloSizes().size() != getStaticShardedDimsSizes().size() ||
+ if (rhs.getDynamicHaloSizes().size() != getDynamicHaloSizes().size() ||
!llvm::equal(llvm::make_range(getDynamicHaloSizes().begin(),
getDynamicHaloSizes().end()),
llvm::make_range(rhs.getDynamicHaloSizes().begin(),
rhs.getDynamicHaloSizes().end()))) {
return false;
}
- if (rhs.getDynamicShardedDimsSizes().size() !=
- getDynamicShardedDimsSizes().size() ||
- !llvm::equal(llvm::make_range(getDynamicShardedDimsSizes().begin(),
- getDynamicShardedDimsSizes().end()),
- llvm::make_range(rhs.getDynamicShardedDimsSizes().begin(),
- rhs.getDynamicShardedDimsSizes().end()))) {
- return false;
- }
return true;
}
@@ -629,9 +714,9 @@ MeshSharding::MeshSharding(Value rhs) {
shardingOp.getPartialAxes().value_or(ArrayRef<MeshAxis>()),
shardingOp.getPartialType().value_or(ReductionKind::Sum),
shardingOp.getStaticHaloSizes(),
- shardingOp.getStaticShardedDimsSizes(),
+ shardingOp.getStaticShardedDimsOffsets(),
SmallVector<Value>(shardingOp.getDynamicHaloSizes()),
- SmallVector<Value>(shardingOp.getDynamicShardedDimsSizes()));
+ SmallVector<Value>(shardingOp.getDynamicShardedDimsOffsets()));
}
MeshSharding MeshSharding::get(::mlir::FlatSymbolRefAttr mesh_,
@@ -639,9 +724,9 @@ MeshSharding MeshSharding::get(::mlir::FlatSymbolRefAttr mesh_,
ArrayRef<MeshAxis> partial_axes_,
ReductionKind partial_type_,
ArrayRef<int64_t> static_halo_sizes_,
- ArrayRef<int64_t> static_sharded_dims_sizes_,
+ ArrayRef<int64_t> static_sharded_dims_offsets_,
ArrayRef<Value> dynamic_halo_sizes_,
- ArrayRef<Value> dynamic_sharded_dims_sizes_) {
+ ArrayRef<Value> dynamic_sharded_dims_offsets_) {
MeshSharding res;
res.mesh = mesh_;
res.split_axes.resize(split_axes_.size());
@@ -658,9 +743,9 @@ MeshSharding MeshSharding::get(::mlir::FlatSymbolRefAttr mesh_,
clone(partial_axes_, res.partial_axes);
res.partial_type = partial_type_;
clone(static_halo_sizes_, res.static_halo_sizes);
- clone(static_sharded_dims_sizes_, res.static_sharded_dims_sizes);
+ clone(static_sharded_dims_offsets_, res.static_sharded_dims_offsets);
clone(dynamic_halo_sizes_, res.dynamic_halo_sizes);
- clone(dynamic_sharded_dims_sizes_, res.dynamic_sharded_dims_sizes);
+ clone(dynamic_sharded_dims_offsets_, res.dynamic_sharded_dims_offsets);
return res;
}
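The effect of the new `FoldDynamicLists` canonicalization, sketched after the
added `test_halo_sizes` test case further below: constant operands migrate into
the static attribute list.

```mlir
// Before canonicalization: a constant shows up as a dynamic operand.
%c2 = arith.constant 2 : i64
%s = mesh.sharding @mesh4x4 split_axes = [[0], [1]]
    halo_sizes = [1, %c2, %c2, 22] : !mesh.sharding
// After FoldDynamicLists the constant is folded into static_halo_sizes:
//   mesh.sharding @mesh4x4 split_axes = [[0], [1]] halo_sizes = [1, 2, 2, 22]
```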
diff --git a/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp b/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp
index fdfed39972fd52..b4d088cbd7088d 100644
--- a/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp
+++ b/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp
@@ -126,7 +126,7 @@ static MeshSharding targetShardingInSplitLastAxis(MLIRContext *ctx,
}
// Split a replicated tensor along a mesh axis.
-// e.g. [[0, 1]] -> [[0, 1, 2]].
+// E.g. [[0, 1]] -> [[0, 1, 2]].
// Returns the spmdized target value with its sharding.
static std::tuple<TypedValue<ShapedType>, MeshSharding>
splitLastAxisInResharding(ImplicitLocOpBuilder &builder,
@@ -429,6 +429,85 @@ tryMoveLastSplitAxisInResharding(ImplicitLocOpBuilder &builder, MeshOp mesh,
return std::nullopt;
}
+// Detect a change in the halo size (only) and create necessary operations if
+// needed. Changed halo sizes require copying the "core" of the source tensor
+// into the "core" of the destination tensor followed by an update halo
+// operation.
+static std::optional<std::tuple<TypedValue<ShapedType>, MeshSharding>>
+tryUpdateHaloInResharding(ImplicitLocOpBuilder &builder, MeshOp mesh,
+ MeshSharding sourceSharding,
+ MeshSharding targetSharding,
+ ShapedType sourceUnshardedShape,
+ TypedValue<ShapedType> sourceShard) {
+  // Currently handles only cases where halo sizes differ but everything else
+ // stays the same (from source to destination sharding).
+ if (!sourceSharding.equalSplitAndPartialAxes(targetSharding) ||
+ !sourceSharding.getPartialAxes().empty() ||
+ !targetSharding.getPartialAxes().empty() ||
+ !sourceSharding.getStaticShardedDimsOffsets().empty() ||
+ !targetSharding.getStaticShardedDimsOffsets().empty() ||
+ sourceSharding.equalHaloSizes(targetSharding)) {
+ return std::nullopt;
+ }
+
+ auto srcHaloSizes = sourceSharding.getStaticHaloSizes();
+ auto tgtHaloSizes = targetSharding.getStaticHaloSizes();
+ assert(srcHaloSizes.empty() || srcHaloSizes.size() == tgtHaloSizes.size());
+ assert(((srcHaloSizes.empty() || !ShapedType::isDynamicShape(srcHaloSizes)) &&
+ !ShapedType::isDynamicShape(tgtHaloSizes) &&
+ sourceShard.getType().hasStaticShape()) &&
+ "dynamic shapes/halos are not supported yet for mesh-spmdization");
+ auto rank = sourceShard.getType().getRank();
+ auto splitAxes = sourceSharding.getSplitAxes();
+ SmallVector<int64_t> srcCoreOffs(rank, 0), tgtCoreOffs(rank, 0),
+ strides(rank, 1), outShape(sourceShard.getType().getShape()),
+ coreShape(sourceShard.getType().getShape());
+
+ // Determine "core" of source and destination.
+ // The core is the local part of the shard excluding halo regions.
+ for (auto i = 0u; i < rank; ++i) {
+ if (i < splitAxes.size() && !splitAxes[i].empty()) {
+ if (!srcHaloSizes.empty()) {
+ coreShape[i] -= srcHaloSizes[i * 2] + srcHaloSizes[i * 2 + 1];
+ srcCoreOffs[i] = srcHaloSizes[i * 2];
+ }
+ tgtCoreOffs[i] = tgtHaloSizes[i * 2];
+ outShape[i] =
+ coreShape[i] + tgtHaloSizes[i * 2] + tgtHaloSizes[i * 2 + 1];
+ }
+ }
+
+ // Extract core from source and copy into destination core.
+ auto noVals = ValueRange{};
+ auto initVal = builder.create<tensor::EmptyOp>(
+ sourceShard.getLoc(), outShape, sourceShard.getType().getElementType());
+ auto core = builder.create<tensor::ExtractSliceOp>(
+ sourceShard.getLoc(),
+ RankedTensorType::get(coreShape, sourceShard.getType().getElementType()),
+ sourceShard, noVals, noVals, noVals, srcCoreOffs, coreShape, strides);
+ auto initOprnd = builder.create<tensor::InsertSliceOp>(
+ sourceShard.getLoc(), core, initVal, noVals, noVals, noVals, tgtCoreOffs,
+ coreShape, strides);
+
+ // Finally update the halo.
+ auto updateHaloResult =
+ builder
+ .create<UpdateHaloOp>(
+ sourceShard.getLoc(),
+ RankedTensorType::get(outShape,
+ sourceShard.getType().getElementType()),
+ sourceShard, initOprnd, mesh.getSymName(),
+ MeshAxesArrayAttr::get(builder.getContext(),
+ sourceSharding.getSplitAxes()),
+ sourceSharding.getDynamicHaloSizes(),
+ sourceSharding.getStaticHaloSizes(),
+ targetSharding.getDynamicHaloSizes(),
+ targetSharding.getStaticHaloSizes())
+ .getResult();
+ return std::make_tuple(cast<TypedValue<ShapedType>>(updateHaloResult),
+ targetSharding);
+}
+
// Handles only resharding on a 1D mesh.
// Currently the sharded tensor axes must be exactly divisible by the single
// mesh axis size.
@@ -454,10 +533,10 @@ reshardOn1DMesh(ImplicitLocOpBuilder &builder, MeshOp mesh,
TypedValue<ShapedType> targetShard;
MeshSharding actualTargetSharding;
- if (reducedSourceSharding.getStaticHaloSizes().empty() &&
- targetSharding.getStaticHaloSizes().empty() &&
- reducedSourceSharding.getStaticShardedDimsSizes().empty() &&
- targetSharding.getStaticShardedDimsSizes().empty()) {
+ if (reducedSourceSharding.getStaticShardedDimsOffsets().empty() &&
+ targetSharding.getStaticShardedDimsOffsets().empty() &&
+ reducedSourceSharding.getStaticHaloSizes().empty() &&
+ targetSharding.getStaticHaloSizes().empty()) {
if (auto tryRes = tryMoveLastSplitAxisInResharding(
builder, mesh, reducedSourceSharding, targetSharding,
sourceUnshardedValue.getType(), reducedSourceShard)) {
@@ -483,6 +562,19 @@ TypedValue<ShapedType> reshard(ImplicitLocOpBuilder &builder, MeshOp mesh,
MeshSharding targetSharding,
TypedValue<ShapedType> sourceUnshardedValue,
TypedValue<ShapedType> sourceShard) {
+ // If source and destination sharding are the same, no need to do anything.
+ if (sourceSharding == targetSharding) {
+ return sourceShard;
+ }
+
+ // Tries to handle the case where the resharding is needed because the halo
+  // sizes are different. Supports arbitrary mesh dimensionality.
+ if (auto tryRes = tryUpdateHaloInResharding(
+ builder, mesh, sourceSharding, targetSharding,
+ sourceUnshardedValue.getType(), sourceShard)) {
+ return std::get<0>(tryRes.value()); // targetShard
+ }
+
// Resort to handling only 1D meshes since the general case is complicated if
// it needs to be communication efficient in terms of minimizing the data
// transfered between devices.
@@ -636,8 +728,8 @@ spmdizeOperation(ShardOp shardOp, IRMapping &spmdizationMap,
targetSpmdValue = spmdizationMap.lookup(shardOp.getSrc());
} else {
// Insert resharding.
- TypedValue<ShapedType> srcSpmdValue = cast<TypedValue<ShapedType>>(
- spmdizationMap.lookup(srcShardOp.getSrc()));
+ TypedValue<ShapedType> srcSpmdValue =
+ cast<TypedValue<ShapedType>>(spmdizationMap.lookup(srcShardOp));
targetSpmdValue = reshard(builder, srcShardOp, shardOp, srcSpmdValue,
symbolTableCollection);
}
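Putting the pieces together, a halo-only resharding lowers to an empty
destination tensor, an insertion of the source core, and an `update_halo`.
A sketch of the emitted IR for a 1d sharding whose halos grow from none to
`[2, 2]` (mirroring the `test_shard_update_halo` case below; `%shard` stands
for the local source shard):

```mlir
%init = tensor.empty() : tensor<304x1200xi64>
%dest = tensor.insert_slice %shard into %init[2, 0] [300, 1200] [1, 1]
    : tensor<300x1200xi64> into tensor<304x1200xi64>
%res = mesh.update_halo %shard into %dest on @mesh_1d_4 split_axes = [[0]]
    destination_halo_sizes = [2, 2]
    : tensor<300x1200xi64> -> tensor<304x1200xi64>
```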
diff --git a/mlir/test/Dialect/Mesh/canonicalization.mlir b/mlir/test/Dialect/Mesh/canonicalization.mlir
index ea2bd29056ec78..f0112d689805d3 100644
--- a/mlir/test/Dialect/Mesh/canonicalization.mlir
+++ b/mlir/test/Dialect/Mesh/canonicalization.mlir
@@ -191,3 +191,20 @@ func.func @send_empty_mesh_axes(
// CHECK: return %[[ARG]]
return %0 : tensor<4xf32>
}
+
+mesh.mesh @mesh4x4(shape = 4x4)
+// CHECK-LABEL: func @test_halo_sizes
+func.func @test_halo_sizes() -> !mesh.sharding {
+ %c2_i64 = arith.constant 2 : i64
+  // CHECK: mesh.sharding @mesh4x4 split_axes = [[0], [1]] halo_sizes = [1, 2, 2, 22] : !mesh.sharding
+ %sharding = mesh.sharding @mesh4x4 split_axes = [[0], [1]] halo_sizes = [1, %c2_i64, %c2_i64, 22] : !mesh.sharding
+ return %sharding : !mesh.sharding
+}
+
+// CHECK-LABEL: func @test_shard_offs
+func.func @test_shard_offs() -> !mesh.sharding {
+ %c2_i64 = arith.constant 2 : i64
+  // CHECK: mesh.sharding @mesh4x4 split_axes = [[0], [1]] sharded_dims_offsets = [0, 1, 2, 3, 4, 0, 2, 3, 4, 22] : !mesh.sharding
+ %sharding = mesh.sharding @mesh4x4 split_axes = [[0], [1]] sharded_dims_offsets = [0, 1, %c2_i64, 3, 4, 0, %c2_i64, 3, 4, 22] : !mesh.sharding
+ return %sharding : !mesh.sharding
+}
\ No newline at end of file
diff --git a/mlir/test/Dialect/Mesh/invalid.mlir b/mlir/test/Dialect/Mesh/invalid.mlir
index 3827df90e6962f..29b900a8da4a60 100644
--- a/mlir/test/Dialect/Mesh/invalid.mlir
+++ b/mlir/test/Dialect/Mesh/invalid.mlir
@@ -89,8 +89,8 @@ func.func @sharding_attribute_invalid_halo(%arg0 : tensor<4x8xf32>) {
// -----
func.func @sharding_attribute_invalid_sizes(%arg0 : tensor<4x8xf32>) {
- // expected-error at +1 {{halo sizes and shard shapes are mutually exclusive}}
- %s = mesh.sharding @mesh0 split_axes = [[0]] halo_sizes = [1, 2] sharded_dims_sizes = [2, 2] : !mesh.sharding
+ // expected-error at +1 {{halo sizes and shard offsets are mutually exclusive}}
+ %s = mesh.sharding @mesh0 split_axes = [[0]] halo_sizes = [1, 2] sharded_dims_offsets = [0, 2, 2] : !mesh.sharding
%0 = mesh.shard %arg0 to %s : tensor<4x8xf32>
return
}
@@ -99,8 +99,28 @@ func.func @sharding_attribute_invalid_sizes(%arg0 : tensor<4x8xf32>) {
mesh.mesh @mesh_dyn(shape = ?x?)
func.func @sharding_dyn_mesh_and_sizes(%arg0 : tensor<4x8xf32>) {
- // expected-error at +1 {{sharded dims sizes are not allowed for devices meshes with dynamic shape}}
- %s = mesh.sharding @mesh_dyn split_axes = [[0]] sharded_dims_sizes = [2, 2] : !mesh.sharding
+ // expected-error at +1 {{sharded dims offsets are not allowed for devices meshes with dynamic shape}}
+ %s = mesh.sharding @mesh_dyn split_axes = [[0]] sharded_dims_offsets = [0, 2, 2] : !mesh.sharding
+ %0 = mesh.shard %arg0 to %s : tensor<4x8xf32>
+ return
+}
+
+// -----
+
+mesh.mesh @mesh0(shape = 2x4)
+func.func @sharding_sizes_count(%arg0 : tensor<4x8xf32>) {
+ // expected-error at +1 {{sharded dims offsets has wrong size}}
+ %s = mesh.sharding @mesh0 split_axes = [[0], [1]] sharded_dims_offsets = [0, 2, 4, 0, 2, 4, 6] : !mesh.sharding
+ %0 = mesh.shard %arg0 to %s : tensor<4x8xf32>
+ return
+}
+
+// -----
+
+mesh.mesh @mesh0(shape = 4)
+func.func @sharding_sizes_decreasing(%arg0 : tensor<4x8xf32>) {
+ // expected-error at +1 {{sharded dims offsets must be non-decreasing}}
+ %s = mesh.sharding @mesh0 split_axes = [[0]] sharded_dims_offsets = [0, 2, 3, 2] : !mesh.sharding
%0 = mesh.shard %arg0 to %s : tensor<4x8xf32>
return
}
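The two checks above enforce the expected layout of `sharded_dims_offsets`:
per sharded tensor dimension, one offset per shard plus a trailing end value,
all non-decreasing. A valid counterpart to the failing `sharding_sizes_count`
case would carry 3 + 5 values on the 2x4 mesh (a sketch; the offset values
are illustrative):

```mlir
mesh.mesh @mesh0(shape = 2x4)
// dim 0 is split 2 ways -> 3 offsets; dim 1 is split 4 ways -> 5 offsets.
%ok = mesh.sharding @mesh0 split_axes = [[0], [1]]
    sharded_dims_offsets = [0, 2, 4, 0, 2, 4, 6, 8] : !mesh.sharding
```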
diff --git a/mlir/test/Dialect/Mesh/ops.mlir b/mlir/test/Dialect/Mesh/ops.mlir
index 5ead7babe2c084..d8df01c3d6520d 100644
--- a/mlir/test/Dialect/Mesh/ops.mlir
+++ b/mlir/test/Dialect/Mesh/ops.mlir
@@ -144,10 +144,10 @@ func.func @mesh_shard_halo_sizes() -> () {
func.func @mesh_shard_dims_sizes() -> () {
// CHECK: %[[C3:.*]] = arith.constant 3 : i64
%c3 = arith.constant 3 : i64
- // CHECK: mesh.sharding @mesh4 split_axes = {{\[\[}}0]] sharded_dims_sizes = [1, 4, 2] : !mesh.sharding
- %sharding1 = mesh.sharding @mesh4 split_axes = [[0]] sharded_dims_sizes = [1, 4, 2] : !mesh.sharding
- // CHECK: mesh.sharding @mesh4 split_axes = {{\[\[}}0]] sharded_dims_sizes = [4, %[[C3]], 1] : !mesh.sharding
- %sharding2 = mesh.sharding @mesh4 split_axes = [[0]] sharded_dims_sizes = [4, %c3, 1] : !mesh.sharding
+ // CHECK: mesh.sharding @mesh4 split_axes = {{\[\[}}0]] sharded_dims_offsets = [0, 1, 4, 6] : !mesh.sharding
+ %sharding1 = mesh.sharding @mesh4 split_axes = [[0]] sharded_dims_offsets = [0, 1, 4, 6] : !mesh.sharding
+ // CHECK: mesh.sharding @mesh4 split_axes = {{\[\[}}0]] sharded_dims_offsets = [0, 2, %[[C3]], 5] : !mesh.sharding
+ %sharding2 = mesh.sharding @mesh4 split_axes = [[0]] sharded_dims_offsets = [0, 2, %c3, 5] : !mesh.sharding
return
}
@@ -615,18 +615,16 @@ func.func @update_halo(
// CHECK-SAME: %[[ARG:.*]]: memref<12x12xi8>
%arg0 : memref<12x12xi8>) {
// CHECK-NEXT: %[[C2:.*]] = arith.constant 2 : i64
- // CHECK-NEXT: mesh.update_halo %[[ARG]] on @mesh0
+ // CHECK-NEXT: %[[UH1:.*]] = mesh.update_halo %[[ARG]] into %[[ARG]] on @mesh0
// CHECK-SAME: split_axes = {{\[\[}}0]]
- // CHECK-SAME: halo_sizes = [2, %c2_i64] : memref<12x12xi8>
+ // CHECK-SAME: halo_sizes = [2, %c2_i64] : memref<12x12xi8> -> memref<12x12xi8>
%c2 = arith.constant 2 : i64
- mesh.update_halo %arg0 on @mesh0 split_axes = [[0]]
- halo_sizes = [2, %c2] : memref<12x12xi8>
- // CHECK-NEXT: mesh.update_halo %[[ARG]] on @mesh0
+ %uh1 = mesh.update_halo %arg0 into %arg0 on @mesh0 split_axes = [[0]]
+ source_halo_sizes = [2, %c2] : memref<12x12xi8> -> memref<12x12xi8>
+ // CHECK-NEXT: %[[UH2:.*]] = mesh.update_halo %[[ARG]] into %[[UH1]] on @mesh0
// CHECK-SAME: split_axes = {{\[\[}}0], [1]]
- // CHECK-SAME: halo_sizes = [2, 2, %[[C2]], 2]
- // CHECK-SAME: target_halo_sizes = [3, 3, 2, 2] : memref<12x12xi8>
- mesh.update_halo %arg0 on @mesh0 split_axes = [[0], [1]]
- halo_sizes = [2, 2, %c2, 2] target_halo_sizes = [3, 3, 2, 2]
- : memref<12x12xi8>
+ // CHECK-SAME: halo_sizes = [2, 2, %[[C2]], 2] : memref<12x12xi8> -> memref<12x12xi8>
+ %uh2 = mesh.update_halo %arg0 into %uh1 on @mesh0 split_axes = [[0], [1]]
+ source_halo_sizes = [2, 2, %c2, 2] : memref<12x12xi8> -> memref<12x12xi8>
return
}
diff --git a/mlir/test/Dialect/Mesh/spmdization.mlir b/mlir/test/Dialect/Mesh/spmdization.mlir
index 8b0c4053b0dc7e..22ddb72569835d 100644
--- a/mlir/test/Dialect/Mesh/spmdization.mlir
+++ b/mlir/test/Dialect/Mesh/spmdization.mlir
@@ -219,3 +219,34 @@ func.func @ew_chain_with_halo(
// CHECK-NEXT: return %[[TMP3]] : tensor<5x16xf32>
return %sharding_annotated_6 : tensor<8x16xf32>
}
+
+// CHECK-LABEL: func @test_shard_update_halo
+// CHECK-SAME: %[[IN1:[A-Za-z0-9_]+]]: tensor<300x1200xi64>
+func.func @test_shard_update_halo(%arg0: tensor<1200x1200xi64>) -> tensor<1200x1200xi64> {
+ %sharding = mesh.sharding @mesh_1d_4 split_axes = [[0]] : !mesh.sharding
+ // CHECK: %[[T:.*]] = tensor.empty() : tensor<304x1200xi64>
+ // CHECK: %[[inserted_slice:.*]] = tensor.insert_slice %[[IN1]] into %[[T]][2, 0] [300, 1200] [1, 1] : tensor<300x1200xi64> into tensor<304x1200xi64>
+ // CHECK: %[[UH:.*]] = mesh.update_halo %[[IN1]] into %[[inserted_slice]] on @mesh_1d_4 split_axes = {{\[\[0]]}} destination_halo_sizes = [2, 2] : tensor<300x1200xi64> -> tensor<304x1200xi64>
+ %sharding_annotated = mesh.shard %arg0 to %sharding : tensor<1200x1200xi64>
+ %sharding_0 = mesh.sharding @mesh_1d_4 split_axes = [[0]] halo_sizes = [2, 2] : !mesh.sharding
+ %sharding_annotated_1 = mesh.shard %sharding_annotated to %sharding_0 : tensor<1200x1200xi64>
+ %sharding_annotated_3 = mesh.shard %sharding_annotated_1 to %sharding_0 annotate_for_users : tensor<1200x1200xi64>
+ // CHECK: return %[[UH]] : tensor<304x1200xi64>
+ return %sharding_annotated_3 : tensor<1200x1200xi64>
+}
+
+mesh.mesh @mesh4x4(shape = 4x4)
+// CHECK-LABEL: func @test_shard_update_halo2d
+// CHECK-SAME: %[[IN1:[A-Za-z0-9_]+]]: tensor<300x300xi64>
+func.func @test_shard_update_halo2d(%arg0: tensor<1200x1200xi64>) -> tensor<1200x1200xi64> {
+ %sharding = mesh.sharding @mesh4x4 split_axes = [[0], [1]] : !mesh.sharding
+ // CHECK: %[[T:.*]] = tensor.empty() : tensor<303x307xi64>
+ // CHECK: %[[inserted_slice:.*]] = tensor.insert_slice %[[IN1]] into %[[T]][1, 3] [300, 300] [1, 1] : tensor<300x300xi64> into tensor<303x307xi64>
+ // CHECK: %[[UH:.*]] = mesh.update_halo %[[IN1]] into %[[inserted_slice]] on @mesh4x4 split_axes = {{\[\[}}0], [1]] destination_halo_sizes = [1, 2, 3, 4] : tensor<300x300xi64> -> tensor<303x307xi64>
+ %sharding_annotated = mesh.shard %arg0 to %sharding : tensor<1200x1200xi64>
+ %sharding_0 = mesh.sharding @mesh4x4 split_axes = [[0], [1]] halo_sizes = [1, 2, 3, 4] : !mesh.sharding
+ %sharding_annotated_1 = mesh.shard %sharding_annotated to %sharding_0 : tensor<1200x1200xi64>
+ %sharding_annotated_3 = mesh.shard %sharding_annotated_1 to %sharding_0 annotate_for_users : tensor<1200x1200xi64>
+ // CHECK: return %[[UH]] : tensor<303x307xi64>
+ return %sharding_annotated_3 : tensor<1200x1200xi64>
+}
\ No newline at end of file
diff --git a/mlir/test/Dialect/Tensor/mesh-spmdization.mlir b/mlir/test/Dialect/Tensor/mesh-spmdization.mlir
index 611acb5b41445b..5443eea83aa2d8 100644
--- a/mlir/test/Dialect/Tensor/mesh-spmdization.mlir
+++ b/mlir/test/Dialect/Tensor/mesh-spmdization.mlir
@@ -4,12 +4,12 @@
mesh.mesh @mesh_1d_4(shape = 4)
-// CHECK-LABEL: func @tensor_empty_static_sharded_dims_sizes
-func.func @tensor_empty_static_sharded_dims_sizes() -> () {
+// CHECK-LABEL: func @tensor_empty_static_sharded_dims_offsets
+func.func @tensor_empty_static_sharded_dims_offsets() -> () {
%b = tensor.empty() : tensor<8x16xf32>
- %sharding = mesh.sharding @mesh_1d_4 split_axes = [[0]] sharded_dims_sizes = [1, 3, 3, 1] : !mesh.sharding
+ %sharding = mesh.sharding @mesh_1d_4 split_axes = [[0]] sharded_dims_offsets = [0, 1, 4, 7, 8] : !mesh.sharding
%sharded= mesh.shard %b to %sharding : tensor<8x16xf32>
- // CHECK: %[[sharding:.*]] = mesh.sharding @mesh_1d_4 split_axes = {{\[\[}}0]] sharded_dims_sizes = [1, 3, 3, 1] : !mesh.sharding
+ // CHECK: %[[sharding:.*]] = mesh.sharding @mesh_1d_4 split_axes = {{\[\[}}0]] sharded_dims_offsets = [0, 1, 4, 7, 8] : !mesh.sharding
// CHECK: %[[proc_linear_idx:.*]] = mesh.process_linear_index on @mesh_1d_4 : index
// CHECK: %[[V0:.*]]:2 = mesh.shard_shape 8x16 %[[sharding]] %[[proc_linear_idx]] : index, index
// CHECK: tensor.empty(%[[V0]]#0) : tensor<?x16xf32>
@@ -17,13 +17,13 @@ func.func @tensor_empty_static_sharded_dims_sizes() -> () {
return
}
-// CHECK-LABEL: func @tensor_empty_dynamic_sharded_dims_sizes
+// CHECK-LABEL: func @tensor_empty_dynamic_sharded_dims_offsets
// CHECK-SAME: %[[A0:.*]]: index
-func.func @tensor_empty_dynamic_sharded_dims_sizes(%arg0 : index) -> () {
+func.func @tensor_empty_dynamic_sharded_dims_offsets(%arg0 : index) -> () {
%b = tensor.empty(%arg0) : tensor<8x?xf32>
- %sharding = mesh.sharding @mesh_1d_4 split_axes = [[0]] sharded_dims_sizes = [1, 3, 3, 1] : !mesh.sharding
+ %sharding = mesh.sharding @mesh_1d_4 split_axes = [[0]] sharded_dims_offsets = [0, 1, 4, 7, 8] : !mesh.sharding
%sharded= mesh.shard %b to %sharding : tensor<8x?xf32>
- // CHECK: %[[sharding:.*]] = mesh.sharding @mesh_1d_4 split_axes = {{\[\[}}0]] sharded_dims_sizes = [1, 3, 3, 1] : !mesh.sharding
+ // CHECK: %[[sharding:.*]] = mesh.sharding @mesh_1d_4 split_axes = {{\[\[}}0]] sharded_dims_offsets = [0, 1, 4, 7, 8] : !mesh.sharding
// CHECK: %[[proc_linear_idx:.*]] = mesh.process_linear_index on @mesh_1d_4 : index
// CHECK: %[[V0:.*]]:2 = mesh.shard_shape 8x? %[[sharding]] %[[proc_linear_idx]] : index, index
// CHECK: tensor.empty(%[[V0]]#0, %[[A0]]) : tensor<?x?xf32>
@@ -33,9 +33,9 @@ func.func @tensor_empty_dynamic_sharded_dims_sizes(%arg0 : index) -> () {
// CHECK-LABEL: func @tensor_empty_same_static_dims_sizes
func.func @tensor_empty_same_static_dims_sizes() -> () {
- %b = tensor.empty() : tensor<8x16xf32>
- %sharding = mesh.sharding @mesh_1d_4 split_axes = [[0]] sharded_dims_sizes = [4, 4, 4, 4] : !mesh.sharding
- %sharded= mesh.shard %b to %sharding : tensor<8x16xf32>
+ %b = tensor.empty() : tensor<16x16xf32>
+ %sharding = mesh.sharding @mesh_1d_4 split_axes = [[0]] sharded_dims_offsets = [0, 4, 8, 12, 16] : !mesh.sharding
+ %sharded= mesh.shard %b to %sharding : tensor<16x16xf32>
// CHECK-NEXT: tensor.empty() : tensor<4x16xf32>
return