[Mlir-commits] [mlir] [mlir][mesh] Handling changed halo region sizes during spmdization (PR #114238)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Wed Oct 30 07:26:58 PDT 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir

Author: Frank Schlimbach (fschlimb)

<details>
<summary>Changes</summary>

- Changed `MeshSharding::sharded_dims_sizes` from representing the size of each shard to representing each shard's offset from the origin.
  - Local shard sizes are now computed by a simple subtraction
  - Offsets are now readily available without a reduction operation
  - Enables constant value/shape propagation through standard canonicalization
  - Renamed to `sharded_dims_offsets` accordingly.
- First spmdization pattern for halo regions.
  - Triggers when source and destination shardings differ only in their halo sizes
  - Copies local data from source into a new tensor and calls update_halo
  - Supports arbitrary mesh dimensions (unlike the other patterns, which work only on 1d meshes)
- `UpdateHaloOp` implements `DestinationStyleOpInterface` and accepts tensors and memrefs
  - Also accepts source and destination halo sizes; both are required for proper lowering
- Minor refactoring for testing partial MeshSharding equality
- Canonicalization for ShardingOp folding constant values into respective `static_*` attributes

At some point, we should probably refactor how spmdization treats various resharding patterns.

@<!-- -->sogartar @<!-- -->yaochengji @<!-- -->mfrancio Could you please have a look?

---

Patch is 41.80 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/114238.diff


9 Files Affected:

- (modified) mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h (+11-8) 
- (modified) mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td (+46-25) 
- (modified) mlir/lib/Dialect/Mesh/IR/MeshOps.cpp (+107-55) 
- (modified) mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp (+94-5) 
- (modified) mlir/test/Dialect/Mesh/canonicalization.mlir (+17) 
- (modified) mlir/test/Dialect/Mesh/invalid.mlir (+3-3) 
- (modified) mlir/test/Dialect/Mesh/ops.mlir (+12-14) 
- (modified) mlir/test/Dialect/Mesh/spmdization.mlir (+31) 
- (modified) mlir/test/Dialect/Tensor/mesh-spmdization.mlir (+11-11) 


``````````diff
diff --git a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h
index db7b64fda57d7b..75cb096130ca6e 100644
--- a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h
+++ b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h
@@ -15,6 +15,7 @@
 #include "mlir/IR/OpDefinition.h"
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/IR/SymbolTable.h"
+#include "mlir/Interfaces/DestinationStyleOpInterface.h"
 #include "mlir/Interfaces/InferTypeOpInterface.h"
 #include "mlir/Interfaces/SideEffectInterfaces.h"
 #include "llvm/Support/MathExtras.h"
@@ -45,9 +46,9 @@ class MeshSharding {
   SmallVector<MeshAxis> partial_axes;
   ReductionKind partial_type;
   SmallVector<int64_t> static_halo_sizes;
-  SmallVector<int64_t> static_sharded_dims_sizes;
+  SmallVector<int64_t> static_sharded_dims_offsets;
   SmallVector<Value> dynamic_halo_sizes;
-  SmallVector<Value> dynamic_sharded_dims_sizes;
+  SmallVector<Value> dynamic_sharded_dims_offsets;
 
 public:
   MeshSharding() = default;
@@ -57,21 +58,21 @@ class MeshSharding {
                           ArrayRef<MeshAxis> partial_axes_ = {},
                           ReductionKind partial_type_ = ReductionKind::Sum,
                           ArrayRef<int64_t> static_halo_sizes_ = {},
-                          ArrayRef<int64_t> static_sharded_dims_sizes_ = {},
+                          ArrayRef<int64_t> static_sharded_dims_offsets_ = {},
                           ArrayRef<Value> dynamic_halo_sizes_ = {},
-                          ArrayRef<Value> dynamic_sharded_dims_sizes_ = {});
+                          ArrayRef<Value> dynamic_sharded_dims_offsets_ = {});
   ::mlir::FlatSymbolRefAttr getMeshAttr() const { return mesh; }
   ::llvm::StringRef getMesh() const { return mesh.getValue(); }
   ArrayRef<MeshAxesAttr> getSplitAxes() const { return split_axes; }
   ArrayRef<MeshAxis> getPartialAxes() const { return partial_axes; }
   ReductionKind getPartialType() const { return partial_type; }
   ArrayRef<int64_t> getStaticHaloSizes() const { return static_halo_sizes; }
-  ArrayRef<int64_t> getStaticShardedDimsSizes() const {
-    return static_sharded_dims_sizes;
+  ArrayRef<int64_t> getStaticShardedDimsOffsets() const {
+    return static_sharded_dims_offsets;
   }
   ArrayRef<Value> getDynamicHaloSizes() const { return dynamic_halo_sizes; }
-  ArrayRef<Value> getDynamicShardedDimsSizes() const {
-    return dynamic_sharded_dims_sizes;
+  ArrayRef<Value> getDynamicShardedDimsOffsets() const {
+    return dynamic_sharded_dims_offsets;
   }
   operator bool() const { return (!mesh) == false; }
   bool operator==(Value rhs) const;
@@ -80,6 +81,8 @@ class MeshSharding {
   bool operator!=(const MeshSharding &rhs) const;
   bool equalSplitAndPartialAxes(const MeshSharding &rhs) const;
   bool equalHaloAndShardSizes(const MeshSharding &rhs) const;
+  bool equalHaloSizes(const MeshSharding &rhs) const;
+  bool equalShardSizes(const MeshSharding &rhs) const;
 };
 
 } // namespace mesh
diff --git a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td
index 8f696bbc1a0f6e..04b4b55a433803 100644
--- a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td
+++ b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td
@@ -11,6 +11,7 @@
 
 include "mlir/Dialect/Mesh/IR/MeshBase.td"
 include "mlir/Dialect/Shape/IR/ShapeBase.td"
+include "mlir/Interfaces/DestinationStyleOpInterface.td"
 include "mlir/Interfaces/InferTypeOpInterface.td"
 include "mlir/Interfaces/SideEffectInterfaces.td"
 include "mlir/IR/BuiltinTypes.td"
@@ -196,16 +197,18 @@ def Mesh_ShardingOp : Mesh_Op<"sharding", [
     e.g. the first sharded dimension gets [1,2] halos and the seconds gets [2,3] halos.
     `?` indicates dynamic halo sizes.
     
-    6. [Optional] Sizes of sharded dimensions of each shard.
-    `sharded_dims_sizes`is provided as a flattened 1d array of i64s: for each device of the
-    device-mesh one value for each sharded tensor dimension.
+    6. [Optional] Offsets for each shard and sharded tensor dimension.
+    `sharded_dims_offsets` is provided as a flattened 1d array of i64s:
+    For each sharded tensor dimension the offsets (starting index) of all shards in that dimension.
+    The offset of each first shard is omitted and is implicitly assumed to be 0.
+    The last value per dimension denotes the end of the last shard.
     Assuming a 3d-tensor of shape 32x32x32 with the first 2 dimensions being sharded,
-    `sharded_dims_sizes` = [16, 8, 16, 24] means that the first device of
-    the device-mesh will get a shard of shape 16x8x32 and the second device will get a
-    shard of shape 16x24x32.
+    `sharded_dims_offsets` = [24, 32, 20, 32] means that the first device of
+    the device-mesh will get a shard of shape 24x20x32 and the second device will get a
+    shard of shape 8x12x32.
     `?` indicates dynamic shard dimensions.
     
-    `halo_sizes` and `sharded_dims_sizes` are mutually exclusive.
+    `halo_sizes` and `sharded_dims_offsets` are mutually exclusive.
 
     Examples:
 
@@ -240,7 +243,7 @@ def Mesh_ShardingOp : Mesh_Op<"sharding", [
     // The tensor is sharded on its second dimension along axis 0 of @mesh1d_4
     // and it has pre-defined shard sizes. The shards of the devices will have
     // the following shapes: [4x2, 4x3, 4x4, 4x5]
-    %sharding4 = mesh.sharding @mesh1d_4 split_axes = [[] split_axes = [0]] sharded_dims_sizes = [2, 3, 4, 5]
+    %sharding4 = mesh.sharding @mesh1d_4 split_axes = [[] split_axes = [0]] sharded_dims_offsets = [2, 5, 9, 14]
     %sharded2 = mesh.shard %arg0 to %sharding4 : tensor<4x14xf32>
     ```
   }];
@@ -250,8 +253,8 @@ def Mesh_ShardingOp : Mesh_Op<"sharding", [
     Mesh_MeshAxesArrayAttr:$split_axes,
     OptionalAttr<Mesh_MeshAxesAttr>:$partial_axes,
     OptionalAttr<Mesh_ReductionKindAttr>:$partial_type,
-    DefaultValuedAttr<DenseI64ArrayAttr, "{}">:$static_sharded_dims_sizes,
-    Variadic<I64>:$dynamic_sharded_dims_sizes,
+    DefaultValuedAttr<DenseI64ArrayAttr, "{}">:$static_sharded_dims_offsets,
+    Variadic<I64>:$dynamic_sharded_dims_offsets,
     DefaultValuedAttr<DenseI64ArrayAttr, "{}">:$static_halo_sizes,
     Variadic<I64>:$dynamic_halo_sizes
   );
@@ -263,7 +266,7 @@ def Mesh_ShardingOp : Mesh_Op<"sharding", [
     `split_axes` `=` $split_axes
     (`partial` `=` $partial_type $partial_axes^)?
     (`halo_sizes` `=` custom<DynamicIndexList>($dynamic_halo_sizes, $static_halo_sizes)^)?
-    (`sharded_dims_sizes` `=` custom<DynamicIndexList>($dynamic_sharded_dims_sizes, $static_sharded_dims_sizes)^)?
+    (`sharded_dims_offsets` `=` custom<DynamicIndexList>($dynamic_sharded_dims_offsets, $static_sharded_dims_offsets)^)?
     attr-dict `:` type($result)
   }];
   let builders = [
@@ -272,16 +275,17 @@ def Mesh_ShardingOp : Mesh_Op<"sharding", [
                    "ArrayRef<MeshAxis>":$partial_axes,
                    "mesh::ReductionKind":$partial_type,
                    CArg<"ArrayRef<int64_t>", "{}">:$static_halo_sizes,
-                   CArg<"ArrayRef<int64_t>", "{}">:$static_sharded_dims_sizes)>,
+                   CArg<"ArrayRef<int64_t>", "{}">:$static_sharded_dims_offsets)>,
     OpBuilder<(ins "FlatSymbolRefAttr":$mesh,
                    "ArrayRef<MeshAxesAttr>":$split_axes)>,
     OpBuilder<(ins "FlatSymbolRefAttr":$mesh,
                    "ArrayRef<MeshAxesAttr>":$split_axes,
                    "::mlir::ArrayRef<::mlir::OpFoldResult>":$halo_sizes,
-                   "::mlir::ArrayRef<::mlir::OpFoldResult>":$sharded_dims_sizes)>,
+                   "::mlir::ArrayRef<::mlir::OpFoldResult>":$sharded_dims_offsets)>,
     OpBuilder<(ins "mlir::mesh::MeshSharding":$from)>
   ];
   let hasVerifier = 1;
+  let hasCanonicalizer = 1;
 }
 
 def Mesh_ShardShapeOp : Mesh_Op<"shard_shape", [Pure]> {
@@ -1052,37 +1056,54 @@ def Mesh_ShiftOp : Mesh_CollectiveCommunicationOpBase<"shift", [
 }
 
 def Mesh_UpdateHaloOp : Mesh_Op<"update_halo", [
-  DeclareOpInterfaceMethods<SymbolUserOpInterface>
+  DestinationStyleOpInterface,
+  TypesMatchWith<
+    "result has same type as destination",
+    "result", "destination", "$_self">,
+  DeclareOpInterfaceMethods<SymbolUserOpInterface>,
+  AttrSizedOperandSegments
 ]> {
   let summary = "Update halo data.";
   let description = [{
     This operation updates halo regions of shards, e.g. if their sharding
-    specified halos and the actual tensor data might have changed
+    specified halos and the actual tensor/memref data might have changed
     on the remote devices. Changes might be caused by mutating operations
     and/or if the new halo regions are larger than the existing ones.
 
+    Source and destination might have different halo sizes.
+
     Assumes all devices hold tensors with same-sized halo data as specified
-    by `dynamic/static_halo_sizes`.
+    by `source_halo_sizes/static_source_halo_sizes` and
+    `destination_halo_sizes/static_destination_halo_sizes` in source shard
+    and destination/result shard.
 
     `split_axes` specifies for each tensor axis along which mesh axes its halo
     data is updated.
 
-    Optionally resizes to new halo sizes `target_halo_sizes`.
   }];
   let arguments = (ins
-    AnyNon0RankedMemRef:$input,
+    AnyTypeOf<[AnyNon0RankedMemRef, AnyNon0RankedTensor]>:$source,
+    AnyTypeOf<[AnyNon0RankedMemRef, AnyNon0RankedTensor]>:$destination,
     FlatSymbolRefAttr:$mesh,
     Mesh_MeshAxesArrayAttr:$split_axes,
-    Variadic<I64>:$dynamic_halo_sizes,
-    DefaultValuedAttr<DenseI64ArrayAttr, "{}">:$static_halo_sizes,
-    DefaultValuedAttr<DenseI64ArrayAttr, "{}">:$target_halo_sizes
+    Variadic<I64>:$source_halo_sizes,
+    DefaultValuedAttr<DenseI64ArrayAttr, "{}">:$static_source_halo_sizes,
+    Variadic<I64>:$destination_halo_sizes,
+    DefaultValuedAttr<DenseI64ArrayAttr, "{}">:$static_destination_halo_sizes
+  );
+  let results = (outs
+    AnyTypeOf<[AnyNon0RankedMemRef, AnyNon0RankedTensor]>:$result
   );
   let assemblyFormat = [{
-    $input `on` $mesh
+    $source `into` $destination
+    `on` $mesh
     `split_axes` `=` $split_axes
-    (`halo_sizes` `=` custom<DynamicIndexList>($dynamic_halo_sizes, $static_halo_sizes)^)?
-    (`target_halo_sizes` `=` $target_halo_sizes^)?
-    attr-dict `:` type($input)
+    (`source_halo_sizes` `=` custom<DynamicIndexList>($source_halo_sizes, $static_source_halo_sizes)^)?
+    (`destination_halo_sizes` `=` custom<DynamicIndexList>($destination_halo_sizes, $static_destination_halo_sizes)^)?
+    attr-dict `:` type($source) `->` type($result)
+  }];
+  let extraClassDeclaration = [{
+    MutableOperandRange getDpsInitsMutable() { return getDestinationMutable(); }
   }];
 }
 #endif // MLIR_DIALECT_MESH_IR_MESHOPS_TD
diff --git a/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp b/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
index 19e9212157ae47..d65f7e4bbadd1a 100644
--- a/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
+++ b/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
@@ -192,33 +192,33 @@ template <typename InShape, typename MeshShape, typename SplitAxes,
           typename OutShape>
 static void shardShape(const InShape &inShape, const MeshShape &meshShape,
                        const SplitAxes &splitAxes, OutShape &outShape,
-                       ArrayRef<int64_t> shardedDimsSizes = {},
+                       ArrayRef<int64_t> shardedDimsOffsets = {},
                        ArrayRef<int64_t> haloSizes = {}) {
   std::copy(llvm::adl_begin(inShape), llvm::adl_end(inShape),
             llvm::adl_begin(outShape));
 
-  if (!shardedDimsSizes.empty()) {
+  if (!shardedDimsOffsets.empty()) {
+    uint64_t pos = 0;
     for (auto [tensorAxis, innerSplitAxes] : llvm::enumerate(splitAxes)) {
-      if (innerSplitAxes.empty()) {
-#ifndef NDEBUG
-        for (auto dimSz : shardedDimsSizes) {
-          auto inAxis = dimSz % inShape.size();
-          assert(inShape[inAxis] == dimSz || dimSz == ShapedType::kDynamic ||
-                 inShape[inAxis] == ShapedType::kDynamic);
-        }
-#endif // NDEBUG
-      } else {
-        // find sharded dims in sharded_dims_sizes with same static size on
-        // all devices. Use kDynamic for dimensions with dynamic or non-uniform
-        // sizes in sharded_dims_sizes.
-        auto sz = shardedDimsSizes[tensorAxis];
-        bool same = true;
-        for (size_t i = tensorAxis + inShape.size();
-             i < shardedDimsSizes.size(); i += inShape.size()) {
-          if (shardedDimsSizes[i] != sz) {
-            same = false;
-            break;
+      if (!innerSplitAxes.empty()) {
+        auto sz = shardedDimsOffsets[pos];
+        bool same = !ShapedType::isDynamicShape(meshShape);
+        if (same) {
+          // find sharded dims in shardedDimsOffsets with same static size on
+          // all devices. Use kDynamic for dimensions with dynamic or
+          // non-uniform offs in shardedDimsOffsets.
+          uint64_t numShards = 0;
+          for (auto i : innerSplitAxes.asArrayRef()) {
+            numShards += meshShape[i];
+          }
+          for (size_t i = 1; i < numShards; ++i) {
+            if (shardedDimsOffsets[pos + i] - shardedDimsOffsets[pos + i - 1] !=
+                sz) {
+              same = false;
+              break;
+            }
           }
+          pos += numShards;
         }
         outShape[tensorAxis] = same ? sz : ShapedType::kDynamic;
       }
@@ -255,7 +255,7 @@ ShapedType mesh::shardShapedType(ShapedType shape, MeshOp mesh,
   using Dim = std::decay_t<decltype(shape.getDimSize(0))>;
   SmallVector<Dim> resShapeArr(shape.getShape().size());
   shardShape(shape.getShape(), mesh.getShape(), sharding.getSplitAxes(),
-             resShapeArr, sharding.getStaticShardedDimsSizes(),
+             resShapeArr, sharding.getStaticShardedDimsOffsets(),
              sharding.getStaticHaloSizes());
   return shape.clone(resShapeArr);
 }
@@ -432,13 +432,14 @@ void ShardingOp::build(::mlir::OpBuilder &b, ::mlir::OperationState &odsState,
                        ArrayRef<MeshAxis> partial_axes,
                        mesh::ReductionKind partial_type,
                        ArrayRef<int64_t> static_halo_sizes,
-                       ArrayRef<int64_t> static_sharded_dims_sizes) {
+                       ArrayRef<int64_t> static_sharded_dims_offsets) {
   return build(
       b, odsState, mesh, MeshAxesArrayAttr::get(b.getContext(), split_axes),
       ::mlir::DenseI16ArrayAttr::get(b.getContext(), partial_axes),
       ::mlir::mesh::ReductionKindAttr::get(b.getContext(), partial_type),
       ::mlir::DenseI64ArrayAttr::get(b.getContext(), static_halo_sizes), {},
-      ::mlir::DenseI64ArrayAttr::get(b.getContext(), static_sharded_dims_sizes),
+      ::mlir::DenseI64ArrayAttr::get(b.getContext(),
+                                     static_sharded_dims_offsets),
       {});
 }
 
@@ -455,11 +456,11 @@ void ShardingOp::build(
     ::mlir::OpBuilder &b, ::mlir::OperationState &odsState,
     FlatSymbolRefAttr mesh, ArrayRef<MeshAxesAttr> split_axes,
     ::mlir::ArrayRef<::mlir::OpFoldResult> halo_sizes,
-    ::mlir::ArrayRef<::mlir::OpFoldResult> sharded_dims_sizes) {
+    ::mlir::ArrayRef<::mlir::OpFoldResult> sharded_dims_offsets) {
   mlir::SmallVector<int64_t> staticHalos, staticDims;
   mlir::SmallVector<mlir::Value> dynamicHalos, dynamicDims;
   dispatchIndexOpFoldResults(halo_sizes, dynamicHalos, staticHalos);
-  dispatchIndexOpFoldResults(sharded_dims_sizes, dynamicDims, staticDims);
+  dispatchIndexOpFoldResults(sharded_dims_offsets, dynamicDims, staticDims);
   return build(
       b, odsState, mesh, MeshAxesArrayAttr::get(b.getContext(), split_axes), {},
       ::mlir::mesh::ReductionKindAttr::get(b.getContext(), ReductionKind::Sum),
@@ -477,10 +478,10 @@ void ShardingOp::build(::mlir::OpBuilder &b, ::mlir::OperationState &odsState,
             : b.getDenseI16ArrayAttr(from.getPartialAxes()),
         ::mlir::mesh::ReductionKindAttr::get(b.getContext(),
                                              from.getPartialType()),
-        from.getStaticShardedDimsSizes().empty()
+        from.getStaticShardedDimsOffsets().empty()
             ? DenseI64ArrayAttr()
-            : b.getDenseI64ArrayAttr(from.getStaticShardedDimsSizes()),
-        from.getDynamicShardedDimsSizes(),
+            : b.getDenseI64ArrayAttr(from.getStaticShardedDimsOffsets()),
+        from.getDynamicShardedDimsOffsets(),
         from.getStaticHaloSizes().empty()
             ? DenseI64ArrayAttr()
             : b.getDenseI64ArrayAttr(from.getStaticHaloSizes()),
@@ -509,7 +510,7 @@ LogicalResult ShardingOp::verify() {
       failed(checkMeshAxis(getPartialAxes().value())))
     return failure();
 
-  if (!getStaticHaloSizes().empty() && !getStaticShardedDimsSizes().empty()) {
+  if (!getStaticHaloSizes().empty() && !getStaticShardedDimsOffsets().empty()) {
     return emitOpError("halo sizes and shard shapes are mutually exclusive");
   }
 
@@ -539,13 +540,49 @@ LogicalResult ShardingOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
     return failure();
   }
   if (mlir::ShapedType::isDynamicShape(mesh->getShape()) &&
-      getStaticShardedDimsSizes().size() > 0) {
-    return emitError() << "sharded dims sizes are not allowed for "
+      getStaticShardedDimsOffsets().size() > 0) {
+    return emitError() << "sharded dims offsets are not allowed for "
                           "devices meshes with dynamic shape.";
   }
   return success();
 }
 
+namespace {
+class FoldDynamicLists final : public OpRewritePattern<ShardingOp> {
+public:
+  using OpRewritePattern<ShardingOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(ShardingOp op,
+                                PatternRewriter &b) const override {
+    auto mixedHalos =
+        getMixedValues(op.getStaticHaloSizes(), op.getDynamicHaloSizes(), b);
+    auto mixedOffs = getMixedValues(op.getStaticShardedDimsOffsets(),
+                                    op.getDynamicShardedDimsOffsets(), b);
+
+    // No constant operands were folded, just return;
+    if (failed(foldDynamicIndexList(mixedHalos, /*onlyNonNegative=*/true)) &&
+        failed(foldDynamicIndexList(mixedOffs, /*onlyNonNegative=*/true))) {
+      return failure();
+    }
+
+    auto halos = decomposeMixedValues(mixedHalos);
+    auto offs = decomposeMixedValues(mixedOffs);
+
+    op.setStaticHaloSizes(halos.first);
+    op.getDynamicHaloSizesMutable().assign(halos.second);
+    op.setStaticShardedDimsOffsets(offs.first);
+    op.getDynamicShardedDimsOffsetsMutable().assign(offs.second);
+
+    return success();
+  }
+};
+} // namespace
+
+void ShardingOp::getCanonicalizationPatterns(mlir::RewritePatternSet &results,
+                                             mlir::MLIRContext *context) {
+  results.add<FoldDynamicLists>(context);
+}
+
 //===----------------------------------------------------------------------===//
 // MeshSharding
 //===----------------------------------------------------------------------===//
@@ -555,7 +592,12 @@ bool MeshSharding::equalSplitAndPartialAxes(const MeshSharding &rhs) const {
     return false;
   }
 
-  if (!getPartialAxes().empty() && getPartialType() != rhs.getPartialType()) {
+  if (getPartialAxes().size() != rhs.getPartialAxes().size() ||
+      (!getPartialAxes().empty() && getPartialType() != rhs.getPartialType()) ||
+      !llvm::equal(
+          llvm::make_range(getPartialAxes().begin(), getPartialAxes().end()),
+          llvm::make_range(rhs.getPartialAxes().begin(),
+                           rhs.getPartialAxes().end()))) {
     return false;
   }
 
@@ -576,6 +618,31 @@ bool MeshSharding::equalSplitAndPartialAxes(const MeshSharding &rhs) const {
 }
 
 bool MeshSharding::equalHaloAndShardSizes(const MeshSharding &rhs) const {
+  return equalShardSizes(rhs) && equalHaloSizes(rhs);
+}
+
+bool MeshSharding::equalShardSizes(const MeshSharding &rhs) const {
+  if (rhs.getStaticShardedDimsOffsets().size() !=
+          getStaticShardedDimsOffsets().size() ||
+      !llvm::equal(llvm::make_range(getStaticShardedDimsOffsets().begin(),
+                                    getStaticShardedDimsOffsets().end()),
+                   llvm::make_range(rhs.getStaticShardedDimsOffsets().begin(),
+                                    rhs.getStaticShardedDimsOffsets().end()))) {
+    return false;
+  }
+  if (rhs.getDynamicShardedDimsOffsets().size() !=
+          getDynamicShardedDimsOffsets().size() ||
+      !llvm::equal(
+          llvm::make_range(getDynamicShardedDimsOffsets().begin(),
+                   ...
[truncated]

``````````

</details>


https://github.com/llvm/llvm-project/pull/114238


More information about the Mlir-commits mailing list