[Mlir-commits] [mlir] [mlir][mesh] Handling changed halo region sizes during spmdization (PR #114238)
llvmlistbot at llvm.org
Wed Oct 30 07:26:58 PDT 2024
llvmbot wrote:
@llvm/pr-subscribers-mlir
Author: Frank Schlimbach (fschlimb)
- Changed `MeshSharding::sharded_dims_sizes` from representing per-shard sizes to per-shard offsets relative to the tensor origin (a short sketch of the new semantics follows this list).
  - Local shard sizes are now a simple subtraction of consecutive offsets.
  - Offsets are now readily available without a reduction operation.
  - Enables constant value/shape propagation through standard canonicalization.
  - Renamed to `sharded_dims_offsets` accordingly.
- First spmdization pattern for halo regions.
  - Triggers when source and destination shardings differ only in their halo sizes.
  - Copies local data from the source into a new tensor and calls `update_halo`.
  - Supports arbitrary mesh dimensions (unlike the other patterns, which work on 1d meshes only).
- `UpdateHaloOp` implements `DestinationStyleOpInterface` and accepts tensors and memrefs (see the second sketch below).
  - Also accepts target and source halo sizes; both are required for proper lowering.
- Minor refactoring for testing partial `MeshSharding` equality.
- Canonicalization for `ShardingOp`, folding constant values into the respective `static_*` attributes (a before/after sketch appears just before the diff).
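To make the new offset semantics concrete, here is a minimal sketch (not part of the patch; the mesh and SSA names are illustrative). A `4x14` tensor is sharded on its second dimension across the four devices of `@mesh1d_4`; each local shard size is simply the difference of two consecutive offsets:

```mlir
// Offsets are given per sharded tensor dimension; the leading 0 is implicit
// and the last entry equals the dimension size. Shard i covers
// [offsets[i-1], offsets[i]), so the local sizes here are 2, 3, 4 and 5.
%sharding = mesh.sharding @mesh1d_4 split_axes = [[], [0]]
    sharded_dims_offsets = [2, 5, 9, 14] : !mesh.sharding
%sharded = mesh.shard %arg0 to %sharding : tensor<4x14xf32>
```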
At some point, we should probably refactor how spmdization treats various resharding patterns.
@sogartar @yaochengji @mfrancio Could you please have a look?
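Below is a hedged sketch of the reworked `update_halo` in its destination-style form (mesh name, values, and shapes are illustrative, not taken from the patch). The op consumes a source and a destination shard with possibly different halo sizes and returns a result of the destination type:

```mlir
// One tensor dimension is split along mesh axis 0. The source shard carries
// [1, 1] halos, the destination shard carries [2, 2] halos, so the same 8x16
// core data is surrounded by a wider halo region in the result.
%updated = mesh.update_halo %src into %dst on @mesh1d_4
    split_axes = [[0]]
    source_halo_sizes = [1, 1] destination_halo_sizes = [2, 2]
    : tensor<10x16xf32> -> tensor<12x16xf32>
```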
---
Patch is 41.80 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/114238.diff
9 Files Affected:
- (modified) mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h (+11-8)
- (modified) mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td (+46-25)
- (modified) mlir/lib/Dialect/Mesh/IR/MeshOps.cpp (+107-55)
- (modified) mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp (+94-5)
- (modified) mlir/test/Dialect/Mesh/canonicalization.mlir (+17)
- (modified) mlir/test/Dialect/Mesh/invalid.mlir (+3-3)
- (modified) mlir/test/Dialect/Mesh/ops.mlir (+12-14)
- (modified) mlir/test/Dialect/Mesh/spmdization.mlir (+31)
- (modified) mlir/test/Dialect/Tensor/mesh-spmdization.mlir (+11-11)
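One more illustration before the diff: the new `ShardingOp` canonicalization mentioned above folds constant dynamic operands into the `static_*` attributes. This is a hedged before/after sketch with illustrative names:

```mlir
// Before canonicalization: one halo size is passed as a constant SSA operand.
%c3 = arith.constant 3 : i64
%s0 = mesh.sharding @mesh1d_4 split_axes = [[0]] halo_sizes = [%c3, 2] : !mesh.sharding

// After canonicalization: the constant has been folded into static_halo_sizes.
%s1 = mesh.sharding @mesh1d_4 split_axes = [[0]] halo_sizes = [3, 2] : !mesh.sharding
```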
``````````diff
diff --git a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h
index db7b64fda57d7b..75cb096130ca6e 100644
--- a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h
+++ b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h
@@ -15,6 +15,7 @@
#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/SymbolTable.h"
+#include "mlir/Interfaces/DestinationStyleOpInterface.h"
#include "mlir/Interfaces/InferTypeOpInterface.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
#include "llvm/Support/MathExtras.h"
@@ -45,9 +46,9 @@ class MeshSharding {
SmallVector<MeshAxis> partial_axes;
ReductionKind partial_type;
SmallVector<int64_t> static_halo_sizes;
- SmallVector<int64_t> static_sharded_dims_sizes;
+ SmallVector<int64_t> static_sharded_dims_offsets;
SmallVector<Value> dynamic_halo_sizes;
- SmallVector<Value> dynamic_sharded_dims_sizes;
+ SmallVector<Value> dynamic_sharded_dims_offsets;
public:
MeshSharding() = default;
@@ -57,21 +58,21 @@ class MeshSharding {
ArrayRef<MeshAxis> partial_axes_ = {},
ReductionKind partial_type_ = ReductionKind::Sum,
ArrayRef<int64_t> static_halo_sizes_ = {},
- ArrayRef<int64_t> static_sharded_dims_sizes_ = {},
+ ArrayRef<int64_t> static_sharded_dims_offsets_ = {},
ArrayRef<Value> dynamic_halo_sizes_ = {},
- ArrayRef<Value> dynamic_sharded_dims_sizes_ = {});
+ ArrayRef<Value> dynamic_sharded_dims_offsets_ = {});
::mlir::FlatSymbolRefAttr getMeshAttr() const { return mesh; }
::llvm::StringRef getMesh() const { return mesh.getValue(); }
ArrayRef<MeshAxesAttr> getSplitAxes() const { return split_axes; }
ArrayRef<MeshAxis> getPartialAxes() const { return partial_axes; }
ReductionKind getPartialType() const { return partial_type; }
ArrayRef<int64_t> getStaticHaloSizes() const { return static_halo_sizes; }
- ArrayRef<int64_t> getStaticShardedDimsSizes() const {
- return static_sharded_dims_sizes;
+ ArrayRef<int64_t> getStaticShardedDimsOffsets() const {
+ return static_sharded_dims_offsets;
}
ArrayRef<Value> getDynamicHaloSizes() const { return dynamic_halo_sizes; }
- ArrayRef<Value> getDynamicShardedDimsSizes() const {
- return dynamic_sharded_dims_sizes;
+ ArrayRef<Value> getDynamicShardedDimsOffsets() const {
+ return dynamic_sharded_dims_offsets;
}
operator bool() const { return (!mesh) == false; }
bool operator==(Value rhs) const;
@@ -80,6 +81,8 @@ class MeshSharding {
bool operator!=(const MeshSharding &rhs) const;
bool equalSplitAndPartialAxes(const MeshSharding &rhs) const;
bool equalHaloAndShardSizes(const MeshSharding &rhs) const;
+ bool equalHaloSizes(const MeshSharding &rhs) const;
+ bool equalShardSizes(const MeshSharding &rhs) const;
};
} // namespace mesh
diff --git a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td
index 8f696bbc1a0f6e..04b4b55a433803 100644
--- a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td
+++ b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td
@@ -11,6 +11,7 @@
include "mlir/Dialect/Mesh/IR/MeshBase.td"
include "mlir/Dialect/Shape/IR/ShapeBase.td"
+include "mlir/Interfaces/DestinationStyleOpInterface.td"
include "mlir/Interfaces/InferTypeOpInterface.td"
include "mlir/Interfaces/SideEffectInterfaces.td"
include "mlir/IR/BuiltinTypes.td"
@@ -196,16 +197,18 @@ def Mesh_ShardingOp : Mesh_Op<"sharding", [
e.g. the first sharded dimension gets [1,2] halos and the seconds gets [2,3] halos.
`?` indicates dynamic halo sizes.
- 6. [Optional] Sizes of sharded dimensions of each shard.
- `sharded_dims_sizes`is provided as a flattened 1d array of i64s: for each device of the
- device-mesh one value for each sharded tensor dimension.
+ 6. [Optional] Offsets for each shard and sharded tensor dimension.
+ `sharded_dims_offsets` is provided as a flattened 1d array of i64s:
+ For each sharded tensor dimension the offsets (starting index) of all shards in that dimension.
+ The offset of each first shard is omitted and is implicitly assumed to be 0.
+ The last value per dimension denotes the end of the last shard.
Assuming a 3d-tensor of shape 32x32x32 with the first 2 dimensions being sharded,
- `sharded_dims_sizes` = [16, 8, 16, 24] means that the first device of
- the device-mesh will get a shard of shape 16x8x32 and the second device will get a
- shard of shape 16x24x32.
+ `sharded_dims_offsets` = [24, 32, 20, 32] means that the first device of
+ the device-mesh will get a shard of shape 24x20x32 and the second device will get a
+ shard of shape 8x12x32.
`?` indicates dynamic shard dimensions.
- `halo_sizes` and `sharded_dims_sizes` are mutually exclusive.
+ `halo_sizes` and `sharded_dims_offsets` are mutually exclusive.
Examples:
@@ -240,7 +243,7 @@ def Mesh_ShardingOp : Mesh_Op<"sharding", [
// The tensor is sharded on its second dimension along axis 0 of @mesh1d_4
// and it has pre-defined shard sizes. The shards of the devices will have
// the following shapes: [4x2, 4x3, 4x4, 4x5]
- %sharding4 = mesh.sharding @mesh1d_4 split_axes = [[] split_axes = [0]] sharded_dims_sizes = [2, 3, 4, 5]
+ %sharding4 = mesh.sharding @mesh1d_4 split_axes = [[] split_axes = [0]] sharded_dims_offsets = [2, 5, 9, 14]
%sharded2 = mesh.shard %arg0 to %sharding4 : tensor<4x14xf32>
```
}];
@@ -250,8 +253,8 @@ def Mesh_ShardingOp : Mesh_Op<"sharding", [
Mesh_MeshAxesArrayAttr:$split_axes,
OptionalAttr<Mesh_MeshAxesAttr>:$partial_axes,
OptionalAttr<Mesh_ReductionKindAttr>:$partial_type,
- DefaultValuedAttr<DenseI64ArrayAttr, "{}">:$static_sharded_dims_sizes,
- Variadic<I64>:$dynamic_sharded_dims_sizes,
+ DefaultValuedAttr<DenseI64ArrayAttr, "{}">:$static_sharded_dims_offsets,
+ Variadic<I64>:$dynamic_sharded_dims_offsets,
DefaultValuedAttr<DenseI64ArrayAttr, "{}">:$static_halo_sizes,
Variadic<I64>:$dynamic_halo_sizes
);
@@ -263,7 +266,7 @@ def Mesh_ShardingOp : Mesh_Op<"sharding", [
`split_axes` `=` $split_axes
(`partial` `=` $partial_type $partial_axes^)?
(`halo_sizes` `=` custom<DynamicIndexList>($dynamic_halo_sizes, $static_halo_sizes)^)?
- (`sharded_dims_sizes` `=` custom<DynamicIndexList>($dynamic_sharded_dims_sizes, $static_sharded_dims_sizes)^)?
+ (`sharded_dims_offsets` `=` custom<DynamicIndexList>($dynamic_sharded_dims_offsets, $static_sharded_dims_offsets)^)?
attr-dict `:` type($result)
}];
let builders = [
@@ -272,16 +275,17 @@ def Mesh_ShardingOp : Mesh_Op<"sharding", [
"ArrayRef<MeshAxis>":$partial_axes,
"mesh::ReductionKind":$partial_type,
CArg<"ArrayRef<int64_t>", "{}">:$static_halo_sizes,
- CArg<"ArrayRef<int64_t>", "{}">:$static_sharded_dims_sizes)>,
+ CArg<"ArrayRef<int64_t>", "{}">:$static_sharded_dims_offsets)>,
OpBuilder<(ins "FlatSymbolRefAttr":$mesh,
"ArrayRef<MeshAxesAttr>":$split_axes)>,
OpBuilder<(ins "FlatSymbolRefAttr":$mesh,
"ArrayRef<MeshAxesAttr>":$split_axes,
"::mlir::ArrayRef<::mlir::OpFoldResult>":$halo_sizes,
- "::mlir::ArrayRef<::mlir::OpFoldResult>":$sharded_dims_sizes)>,
+ "::mlir::ArrayRef<::mlir::OpFoldResult>":$sharded_dims_offsets)>,
OpBuilder<(ins "mlir::mesh::MeshSharding":$from)>
];
let hasVerifier = 1;
+ let hasCanonicalizer = 1;
}
def Mesh_ShardShapeOp : Mesh_Op<"shard_shape", [Pure]> {
@@ -1052,37 +1056,54 @@ def Mesh_ShiftOp : Mesh_CollectiveCommunicationOpBase<"shift", [
}
def Mesh_UpdateHaloOp : Mesh_Op<"update_halo", [
- DeclareOpInterfaceMethods<SymbolUserOpInterface>
+ DestinationStyleOpInterface,
+ TypesMatchWith<
+ "result has same type as destination",
+ "result", "destination", "$_self">,
+ DeclareOpInterfaceMethods<SymbolUserOpInterface>,
+ AttrSizedOperandSegments
]> {
let summary = "Update halo data.";
let description = [{
This operation updates halo regions of shards, e.g. if their sharding
- specified halos and the actual tensor data might have changed
+ specified halos and the actual tensor/memref data might have changed
on the remote devices. Changes might be caused by mutating operations
and/or if the new halo regions are larger than the existing ones.
+ Source and destination might have different halo sizes.
+
Assumes all devices hold tensors with same-sized halo data as specified
- by `dynamic/static_halo_sizes`.
+ by `source_halo_sizes/static_source_halo_sizes` and
+ `destination_halo_sizes/static_destination_halo_sizes` in source shard
+ and destination/result shard.
`split_axes` specifies for each tensor axis along which mesh axes its halo
data is updated.
- Optionally resizes to new halo sizes `target_halo_sizes`.
}];
let arguments = (ins
- AnyNon0RankedMemRef:$input,
+ AnyTypeOf<[AnyNon0RankedMemRef, AnyNon0RankedTensor]>:$source,
+ AnyTypeOf<[AnyNon0RankedMemRef, AnyNon0RankedTensor]>:$destination,
FlatSymbolRefAttr:$mesh,
Mesh_MeshAxesArrayAttr:$split_axes,
- Variadic<I64>:$dynamic_halo_sizes,
- DefaultValuedAttr<DenseI64ArrayAttr, "{}">:$static_halo_sizes,
- DefaultValuedAttr<DenseI64ArrayAttr, "{}">:$target_halo_sizes
+ Variadic<I64>:$source_halo_sizes,
+ DefaultValuedAttr<DenseI64ArrayAttr, "{}">:$static_source_halo_sizes,
+ Variadic<I64>:$destination_halo_sizes,
+ DefaultValuedAttr<DenseI64ArrayAttr, "{}">:$static_destination_halo_sizes
+ );
+ let results = (outs
+ AnyTypeOf<[AnyNon0RankedMemRef, AnyNon0RankedTensor]>:$result
);
let assemblyFormat = [{
- $input `on` $mesh
+ $source `into` $destination
+ `on` $mesh
`split_axes` `=` $split_axes
- (`halo_sizes` `=` custom<DynamicIndexList>($dynamic_halo_sizes, $static_halo_sizes)^)?
- (`target_halo_sizes` `=` $target_halo_sizes^)?
- attr-dict `:` type($input)
+ (`source_halo_sizes` `=` custom<DynamicIndexList>($source_halo_sizes, $static_source_halo_sizes)^)?
+ (`destination_halo_sizes` `=` custom<DynamicIndexList>($destination_halo_sizes, $static_destination_halo_sizes)^)?
+ attr-dict `:` type($source) `->` type($result)
+ }];
+ let extraClassDeclaration = [{
+ MutableOperandRange getDpsInitsMutable() { return getDestinationMutable(); }
}];
}
#endif // MLIR_DIALECT_MESH_IR_MESHOPS_TD
diff --git a/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp b/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
index 19e9212157ae47..d65f7e4bbadd1a 100644
--- a/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
+++ b/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
@@ -192,33 +192,33 @@ template <typename InShape, typename MeshShape, typename SplitAxes,
typename OutShape>
static void shardShape(const InShape &inShape, const MeshShape &meshShape,
const SplitAxes &splitAxes, OutShape &outShape,
- ArrayRef<int64_t> shardedDimsSizes = {},
+ ArrayRef<int64_t> shardedDimsOffsets = {},
ArrayRef<int64_t> haloSizes = {}) {
std::copy(llvm::adl_begin(inShape), llvm::adl_end(inShape),
llvm::adl_begin(outShape));
- if (!shardedDimsSizes.empty()) {
+ if (!shardedDimsOffsets.empty()) {
+ uint64_t pos = 0;
for (auto [tensorAxis, innerSplitAxes] : llvm::enumerate(splitAxes)) {
- if (innerSplitAxes.empty()) {
-#ifndef NDEBUG
- for (auto dimSz : shardedDimsSizes) {
- auto inAxis = dimSz % inShape.size();
- assert(inShape[inAxis] == dimSz || dimSz == ShapedType::kDynamic ||
- inShape[inAxis] == ShapedType::kDynamic);
- }
-#endif // NDEBUG
- } else {
- // find sharded dims in sharded_dims_sizes with same static size on
- // all devices. Use kDynamic for dimensions with dynamic or non-uniform
- // sizes in sharded_dims_sizes.
- auto sz = shardedDimsSizes[tensorAxis];
- bool same = true;
- for (size_t i = tensorAxis + inShape.size();
- i < shardedDimsSizes.size(); i += inShape.size()) {
- if (shardedDimsSizes[i] != sz) {
- same = false;
- break;
+ if (!innerSplitAxes.empty()) {
+ auto sz = shardedDimsOffsets[pos];
+ bool same = !ShapedType::isDynamicShape(meshShape);
+ if (same) {
+ // find sharded dims in shardedDimsOffsets with same static size on
+ // all devices. Use kDynamic for dimensions with dynamic or
+ // non-uniform offs in shardedDimsOffsets.
+ uint64_t numShards = 0;
+ for (auto i : innerSplitAxes.asArrayRef()) {
+ numShards += meshShape[i];
+ }
+ for (size_t i = 1; i < numShards; ++i) {
+ if (shardedDimsOffsets[pos + i] - shardedDimsOffsets[pos + i - 1] !=
+ sz) {
+ same = false;
+ break;
+ }
}
+ pos += numShards;
}
outShape[tensorAxis] = same ? sz : ShapedType::kDynamic;
}
@@ -255,7 +255,7 @@ ShapedType mesh::shardShapedType(ShapedType shape, MeshOp mesh,
using Dim = std::decay_t<decltype(shape.getDimSize(0))>;
SmallVector<Dim> resShapeArr(shape.getShape().size());
shardShape(shape.getShape(), mesh.getShape(), sharding.getSplitAxes(),
- resShapeArr, sharding.getStaticShardedDimsSizes(),
+ resShapeArr, sharding.getStaticShardedDimsOffsets(),
sharding.getStaticHaloSizes());
return shape.clone(resShapeArr);
}
@@ -432,13 +432,14 @@ void ShardingOp::build(::mlir::OpBuilder &b, ::mlir::OperationState &odsState,
ArrayRef<MeshAxis> partial_axes,
mesh::ReductionKind partial_type,
ArrayRef<int64_t> static_halo_sizes,
- ArrayRef<int64_t> static_sharded_dims_sizes) {
+ ArrayRef<int64_t> static_sharded_dims_offsets) {
return build(
b, odsState, mesh, MeshAxesArrayAttr::get(b.getContext(), split_axes),
::mlir::DenseI16ArrayAttr::get(b.getContext(), partial_axes),
::mlir::mesh::ReductionKindAttr::get(b.getContext(), partial_type),
::mlir::DenseI64ArrayAttr::get(b.getContext(), static_halo_sizes), {},
- ::mlir::DenseI64ArrayAttr::get(b.getContext(), static_sharded_dims_sizes),
+ ::mlir::DenseI64ArrayAttr::get(b.getContext(),
+ static_sharded_dims_offsets),
{});
}
@@ -455,11 +456,11 @@ void ShardingOp::build(
::mlir::OpBuilder &b, ::mlir::OperationState &odsState,
FlatSymbolRefAttr mesh, ArrayRef<MeshAxesAttr> split_axes,
::mlir::ArrayRef<::mlir::OpFoldResult> halo_sizes,
- ::mlir::ArrayRef<::mlir::OpFoldResult> sharded_dims_sizes) {
+ ::mlir::ArrayRef<::mlir::OpFoldResult> sharded_dims_offsets) {
mlir::SmallVector<int64_t> staticHalos, staticDims;
mlir::SmallVector<mlir::Value> dynamicHalos, dynamicDims;
dispatchIndexOpFoldResults(halo_sizes, dynamicHalos, staticHalos);
- dispatchIndexOpFoldResults(sharded_dims_sizes, dynamicDims, staticDims);
+ dispatchIndexOpFoldResults(sharded_dims_offsets, dynamicDims, staticDims);
return build(
b, odsState, mesh, MeshAxesArrayAttr::get(b.getContext(), split_axes), {},
::mlir::mesh::ReductionKindAttr::get(b.getContext(), ReductionKind::Sum),
@@ -477,10 +478,10 @@ void ShardingOp::build(::mlir::OpBuilder &b, ::mlir::OperationState &odsState,
: b.getDenseI16ArrayAttr(from.getPartialAxes()),
::mlir::mesh::ReductionKindAttr::get(b.getContext(),
from.getPartialType()),
- from.getStaticShardedDimsSizes().empty()
+ from.getStaticShardedDimsOffsets().empty()
? DenseI64ArrayAttr()
- : b.getDenseI64ArrayAttr(from.getStaticShardedDimsSizes()),
- from.getDynamicShardedDimsSizes(),
+ : b.getDenseI64ArrayAttr(from.getStaticShardedDimsOffsets()),
+ from.getDynamicShardedDimsOffsets(),
from.getStaticHaloSizes().empty()
? DenseI64ArrayAttr()
: b.getDenseI64ArrayAttr(from.getStaticHaloSizes()),
@@ -509,7 +510,7 @@ LogicalResult ShardingOp::verify() {
failed(checkMeshAxis(getPartialAxes().value())))
return failure();
- if (!getStaticHaloSizes().empty() && !getStaticShardedDimsSizes().empty()) {
+ if (!getStaticHaloSizes().empty() && !getStaticShardedDimsOffsets().empty()) {
return emitOpError("halo sizes and shard shapes are mutually exclusive");
}
@@ -539,13 +540,49 @@ LogicalResult ShardingOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
return failure();
}
if (mlir::ShapedType::isDynamicShape(mesh->getShape()) &&
- getStaticShardedDimsSizes().size() > 0) {
- return emitError() << "sharded dims sizes are not allowed for "
+ getStaticShardedDimsOffsets().size() > 0) {
+ return emitError() << "sharded dims offsets are not allowed for "
"devices meshes with dynamic shape.";
}
return success();
}
+namespace {
+class FoldDynamicLists final : public OpRewritePattern<ShardingOp> {
+public:
+ using OpRewritePattern<ShardingOp>::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(ShardingOp op,
+ PatternRewriter &b) const override {
+ auto mixedHalos =
+ getMixedValues(op.getStaticHaloSizes(), op.getDynamicHaloSizes(), b);
+ auto mixedOffs = getMixedValues(op.getStaticShardedDimsOffsets(),
+ op.getDynamicShardedDimsOffsets(), b);
+
+ // No constant operands were folded, just return;
+ if (failed(foldDynamicIndexList(mixedHalos, /*onlyNonNegative=*/true)) &&
+ failed(foldDynamicIndexList(mixedOffs, /*onlyNonNegative=*/true))) {
+ return failure();
+ }
+
+ auto halos = decomposeMixedValues(mixedHalos);
+ auto offs = decomposeMixedValues(mixedOffs);
+
+ op.setStaticHaloSizes(halos.first);
+ op.getDynamicHaloSizesMutable().assign(halos.second);
+ op.setStaticShardedDimsOffsets(offs.first);
+ op.getDynamicShardedDimsOffsetsMutable().assign(offs.second);
+
+ return success();
+ }
+};
+} // namespace
+
+void ShardingOp::getCanonicalizationPatterns(mlir::RewritePatternSet &results,
+ mlir::MLIRContext *context) {
+ results.add<FoldDynamicLists>(context);
+}
+
//===----------------------------------------------------------------------===//
// MeshSharding
//===----------------------------------------------------------------------===//
@@ -555,7 +592,12 @@ bool MeshSharding::equalSplitAndPartialAxes(const MeshSharding &rhs) const {
return false;
}
- if (!getPartialAxes().empty() && getPartialType() != rhs.getPartialType()) {
+ if (getPartialAxes().size() != rhs.getPartialAxes().size() ||
+ (!getPartialAxes().empty() && getPartialType() != rhs.getPartialType()) ||
+ !llvm::equal(
+ llvm::make_range(getPartialAxes().begin(), getPartialAxes().end()),
+ llvm::make_range(rhs.getPartialAxes().begin(),
+ rhs.getPartialAxes().end()))) {
return false;
}
@@ -576,6 +618,31 @@ bool MeshSharding::equalSplitAndPartialAxes(const MeshSharding &rhs) const {
}
bool MeshSharding::equalHaloAndShardSizes(const MeshSharding &rhs) const {
+ return equalShardSizes(rhs) && equalHaloSizes(rhs);
+}
+
+bool MeshSharding::equalShardSizes(const MeshSharding &rhs) const {
+ if (rhs.getStaticShardedDimsOffsets().size() !=
+ getStaticShardedDimsOffsets().size() ||
+ !llvm::equal(llvm::make_range(getStaticShardedDimsOffsets().begin(),
+ getStaticShardedDimsOffsets().end()),
+ llvm::make_range(rhs.getStaticShardedDimsOffsets().begin(),
+ rhs.getStaticShardedDimsOffsets().end()))) {
+ return false;
+ }
+ if (rhs.getDynamicShardedDimsOffsets().size() !=
+ getDynamicShardedDimsOffsets().size() ||
+ !llvm::equal(
+ llvm::make_range(getDynamicShardedDimsOffsets().begin(),
+ ...
[truncated]
``````````
https://github.com/llvm/llvm-project/pull/114238
More information about the Mlir-commits mailing list