[Mlir-commits] [mlir] ffc7fea - [mlir][mesh] Handling changed halo region sizes during spmdization (#114238)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Sun Nov 10 21:49:54 PST 2024


Author: Frank Schlimbach
Date: 2024-11-10T21:49:50-08:00
New Revision: ffc7feadece139c88f0e6930f16bfa9293747adc

URL: https://github.com/llvm/llvm-project/commit/ffc7feadece139c88f0e6930f16bfa9293747adc
DIFF: https://github.com/llvm/llvm-project/commit/ffc7feadece139c88f0e6930f16bfa9293747adc.diff

LOG: [mlir][mesh] Handling changed halo region sizes during spmdization (#114238)

* Changed `MeshSharding::sharded_dims_sizes` from representing sizes per
  shard to offsets to origin per shard.
  - Local shard sizes are now a simple subtraction
  - Offsets are now readily available without a reduction operation
  - Enables constant value/shape propagation through standard
    canonicalization
  - Renamed to `sharded_dims_offsets` accordingly.
* First spmdization pattern for halo regions.
  - Triggers when source and destination shardings differ only in their
    halo sizes
  - Copies local data from source into a new tensor and calls update_halo
  - Supports arbitrary mesh dimensions (unlike the other patterns, which
    work on 1d meshes only)
* `UpdateHaloOp` implements `DestinationStyleOpInterface` and accepts
  tensors and memrefs
  - also accepts target and source halo sizes; both are required for
    proper lowering
* minor refactoring for testing partial MeshSharding equality
* Canonicalization for ShardingOp folding constant values into
  respective `static_*` attributes

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h
    mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td
    mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
    mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp
    mlir/test/Dialect/Mesh/canonicalization.mlir
    mlir/test/Dialect/Mesh/invalid.mlir
    mlir/test/Dialect/Mesh/ops.mlir
    mlir/test/Dialect/Mesh/spmdization.mlir
    mlir/test/Dialect/Tensor/mesh-spmdization.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h
index db7b64fda57d7b..75cb096130ca6e 100644
--- a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h
+++ b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h
@@ -15,6 +15,7 @@
 #include "mlir/IR/OpDefinition.h"
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/IR/SymbolTable.h"
+#include "mlir/Interfaces/DestinationStyleOpInterface.h"
 #include "mlir/Interfaces/InferTypeOpInterface.h"
 #include "mlir/Interfaces/SideEffectInterfaces.h"
 #include "llvm/Support/MathExtras.h"
@@ -45,9 +46,9 @@ class MeshSharding {
   SmallVector<MeshAxis> partial_axes;
   ReductionKind partial_type;
   SmallVector<int64_t> static_halo_sizes;
-  SmallVector<int64_t> static_sharded_dims_sizes;
+  SmallVector<int64_t> static_sharded_dims_offsets;
   SmallVector<Value> dynamic_halo_sizes;
-  SmallVector<Value> dynamic_sharded_dims_sizes;
+  SmallVector<Value> dynamic_sharded_dims_offsets;
 
 public:
   MeshSharding() = default;
@@ -57,21 +58,21 @@ class MeshSharding {
                           ArrayRef<MeshAxis> partial_axes_ = {},
                           ReductionKind partial_type_ = ReductionKind::Sum,
                           ArrayRef<int64_t> static_halo_sizes_ = {},
-                          ArrayRef<int64_t> static_sharded_dims_sizes_ = {},
+                          ArrayRef<int64_t> static_sharded_dims_offsets_ = {},
                           ArrayRef<Value> dynamic_halo_sizes_ = {},
-                          ArrayRef<Value> dynamic_sharded_dims_sizes_ = {});
+                          ArrayRef<Value> dynamic_sharded_dims_offsets_ = {});
   ::mlir::FlatSymbolRefAttr getMeshAttr() const { return mesh; }
   ::llvm::StringRef getMesh() const { return mesh.getValue(); }
   ArrayRef<MeshAxesAttr> getSplitAxes() const { return split_axes; }
   ArrayRef<MeshAxis> getPartialAxes() const { return partial_axes; }
   ReductionKind getPartialType() const { return partial_type; }
   ArrayRef<int64_t> getStaticHaloSizes() const { return static_halo_sizes; }
-  ArrayRef<int64_t> getStaticShardedDimsSizes() const {
-    return static_sharded_dims_sizes;
+  ArrayRef<int64_t> getStaticShardedDimsOffsets() const {
+    return static_sharded_dims_offsets;
   }
   ArrayRef<Value> getDynamicHaloSizes() const { return dynamic_halo_sizes; }
-  ArrayRef<Value> getDynamicShardedDimsSizes() const {
-    return dynamic_sharded_dims_sizes;
+  ArrayRef<Value> getDynamicShardedDimsOffsets() const {
+    return dynamic_sharded_dims_offsets;
   }
   operator bool() const { return (!mesh) == false; }
   bool operator==(Value rhs) const;
@@ -80,6 +81,8 @@ class MeshSharding {
   bool operator!=(const MeshSharding &rhs) const;
   bool equalSplitAndPartialAxes(const MeshSharding &rhs) const;
   bool equalHaloAndShardSizes(const MeshSharding &rhs) const;
+  bool equalHaloSizes(const MeshSharding &rhs) const;
+  bool equalShardSizes(const MeshSharding &rhs) const;
 };
 
 } // namespace mesh

diff  --git a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td
index 8f696bbc1a0f6e..19498fe5a32d69 100644
--- a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td
+++ b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td
@@ -11,6 +11,7 @@
 
 include "mlir/Dialect/Mesh/IR/MeshBase.td"
 include "mlir/Dialect/Shape/IR/ShapeBase.td"
+include "mlir/Interfaces/DestinationStyleOpInterface.td"
 include "mlir/Interfaces/InferTypeOpInterface.td"
 include "mlir/Interfaces/SideEffectInterfaces.td"
 include "mlir/IR/BuiltinTypes.td"
@@ -189,23 +190,27 @@ def Mesh_ShardingOp : Mesh_Op<"sharding", [
     `generic`: is not an allowed value inside a shard attribute.
 
     5. [Optional] Sizes of halos to be added for each sharded tensor dimension.
-    `halo_sizes`is provided as a flattened 1d array of i64s, 2 values for each sharded dimension.
-    `halo_sizes` = [1, 2] means that the first sharded dimension gets an additional
-    halo of size 1 at the start of the first dimension and a halo size is 2 at its end.
-    `halo_sizes` = [1, 2, 2, 3] defines halos for the first 2 sharded dimensions
-    e.g. the first sharded dimension gets [1,2] halos and the seconds gets [2,3] halos.
-    `?` indicates dynamic halo sizes.
+    `halo_sizes` is provided as a flattened 1d array of i64s, 2 values for each
+    sharded dimension. `halo_sizes = [1, 2]` means that the first sharded dimension
+    gets an additional halo of size 1 at the start of that dimension and a halo
+    of size 2 at its end. `halo_sizes = [1, 2, 2, 3]` defines halos for the first 2
+    sharded dimensions, e.g. the first sharded dimension gets `[1,2]` halos and the
+    second gets `[2,3]` halos. `?` indicates dynamic halo sizes.
+    
+    6. [Optional] Offsets for each shard and sharded tensor dimension.
+    `sharded_dims_offsets` is provided as a flattened 1d array of i64s. For each
+    sharded tensor dimension the offsets (starting index) of all shards in that
+    dimension and an additional value for the end of the last shard are provided.
+    For a 1d sharding this means that position `i` has the exclusive prefix sum for
+    shard `i`, and since only contiguous sharding is supported, its inclusive prefix
+    sum is at position `i+1`.
     
-    6. [Optional] Sizes of sharded dimensions of each shard.
-    `sharded_dims_sizes`is provided as a flattened 1d array of i64s: for each device of the
-    device-mesh one value for each sharded tensor dimension.
     Assuming a 3d-tensor of shape 32x32x32 with the first 2 dimensions being sharded,
-    `sharded_dims_sizes` = [16, 8, 16, 24] means that the first device of
-    the device-mesh will get a shard of shape 16x8x32 and the second device will get a
-    shard of shape 16x24x32.
-    `?` indicates dynamic shard dimensions.
+    `sharded_dims_offsets` = [0, 24, 32, 0, 20, 32] means that the first device of
+    the device-mesh will get a shard of shape 24x20x32 and the second device will get
+    a shard of shape 8x12x32. `?` indicates dynamic shard dimensions.
     
-    `halo_sizes` and `sharded_dims_sizes` are mutually exclusive.
+    `halo_sizes` and `sharded_dims_offsets` are mutually exclusive.
 
     Examples:
 
@@ -240,7 +245,7 @@ def Mesh_ShardingOp : Mesh_Op<"sharding", [
     // The tensor is sharded on its second dimension along axis 0 of @mesh1d_4
     // and it has pre-defined shard sizes. The shards of the devices will have
     // the following shapes: [4x2, 4x3, 4x4, 4x5]
-    %sharding4 = mesh.sharding @mesh1d_4 split_axes = [[] split_axes = [0]] sharded_dims_sizes = [2, 3, 4, 5]
+    %sharding4 = mesh.sharding @mesh1d_4 split_axes = [[], [0]] sharded_dims_offsets = [0, 2, 5, 9, 14]
     %sharded2 = mesh.shard %arg0 to %sharding4 : tensor<4x14xf32>
     ```
   }];
@@ -250,8 +255,8 @@ def Mesh_ShardingOp : Mesh_Op<"sharding", [
     Mesh_MeshAxesArrayAttr:$split_axes,
     OptionalAttr<Mesh_MeshAxesAttr>:$partial_axes,
     OptionalAttr<Mesh_ReductionKindAttr>:$partial_type,
-    DefaultValuedAttr<DenseI64ArrayAttr, "{}">:$static_sharded_dims_sizes,
-    Variadic<I64>:$dynamic_sharded_dims_sizes,
+    DefaultValuedAttr<DenseI64ArrayAttr, "{}">:$static_sharded_dims_offsets,
+    Variadic<I64>:$dynamic_sharded_dims_offsets,
     DefaultValuedAttr<DenseI64ArrayAttr, "{}">:$static_halo_sizes,
     Variadic<I64>:$dynamic_halo_sizes
   );
@@ -263,7 +268,7 @@ def Mesh_ShardingOp : Mesh_Op<"sharding", [
     `split_axes` `=` $split_axes
     (`partial` `=` $partial_type $partial_axes^)?
     (`halo_sizes` `=` custom<DynamicIndexList>($dynamic_halo_sizes, $static_halo_sizes)^)?
-    (`sharded_dims_sizes` `=` custom<DynamicIndexList>($dynamic_sharded_dims_sizes, $static_sharded_dims_sizes)^)?
+    (`sharded_dims_offsets` `=` custom<DynamicIndexList>($dynamic_sharded_dims_offsets, $static_sharded_dims_offsets)^)?
     attr-dict `:` type($result)
   }];
   let builders = [
@@ -272,16 +277,17 @@ def Mesh_ShardingOp : Mesh_Op<"sharding", [
                    "ArrayRef<MeshAxis>":$partial_axes,
                    "mesh::ReductionKind":$partial_type,
                    CArg<"ArrayRef<int64_t>", "{}">:$static_halo_sizes,
-                   CArg<"ArrayRef<int64_t>", "{}">:$static_sharded_dims_sizes)>,
+                   CArg<"ArrayRef<int64_t>", "{}">:$static_sharded_dims_offsets)>,
     OpBuilder<(ins "FlatSymbolRefAttr":$mesh,
                    "ArrayRef<MeshAxesAttr>":$split_axes)>,
     OpBuilder<(ins "FlatSymbolRefAttr":$mesh,
                    "ArrayRef<MeshAxesAttr>":$split_axes,
                    "::mlir::ArrayRef<::mlir::OpFoldResult>":$halo_sizes,
-                   "::mlir::ArrayRef<::mlir::OpFoldResult>":$sharded_dims_sizes)>,
+                   "::mlir::ArrayRef<::mlir::OpFoldResult>":$sharded_dims_offsets)>,
     OpBuilder<(ins "mlir::mesh::MeshSharding":$from)>
   ];
   let hasVerifier = 1;
+  let hasCanonicalizer = 1;
 }
 
 def Mesh_ShardShapeOp : Mesh_Op<"shard_shape", [Pure]> {
@@ -1052,37 +1058,54 @@ def Mesh_ShiftOp : Mesh_CollectiveCommunicationOpBase<"shift", [
 }
 
 def Mesh_UpdateHaloOp : Mesh_Op<"update_halo", [
-  DeclareOpInterfaceMethods<SymbolUserOpInterface>
+  DestinationStyleOpInterface,
+  TypesMatchWith<
+    "result has same type as destination",
+    "result", "destination", "$_self">,
+  DeclareOpInterfaceMethods<SymbolUserOpInterface>,
+  AttrSizedOperandSegments
 ]> {
   let summary = "Update halo data.";
   let description = [{
     This operation updates halo regions of shards, e.g. if their sharding
-    specified halos and the actual tensor data might have changed
+    specified halos and the actual tensor/memref data might have changed
     on the remote devices. Changes might be caused by mutating operations
     and/or if the new halo regions are larger than the existing ones.
 
+    Source and destination might have different halo sizes.
+
     Assumes all devices hold tensors with same-sized halo data as specified
-    by `dynamic/static_halo_sizes`.
+    by `source_halo_sizes/static_source_halo_sizes` and
+    `destination_halo_sizes/static_destination_halo_sizes` in source shard
+    and destination/result shard.
 
     `split_axes` specifies for each tensor axis along which mesh axes its halo
     data is updated.
 
-    Optionally resizes to new halo sizes `target_halo_sizes`.
   }];
   let arguments = (ins
-    AnyNon0RankedMemRef:$input,
+    AnyTypeOf<[AnyNon0RankedMemRef, AnyNon0RankedTensor]>:$source,
+    AnyTypeOf<[AnyNon0RankedMemRef, AnyNon0RankedTensor]>:$destination,
     FlatSymbolRefAttr:$mesh,
     Mesh_MeshAxesArrayAttr:$split_axes,
-    Variadic<I64>:$dynamic_halo_sizes,
-    DefaultValuedAttr<DenseI64ArrayAttr, "{}">:$static_halo_sizes,
-    DefaultValuedAttr<DenseI64ArrayAttr, "{}">:$target_halo_sizes
+    Variadic<I64>:$source_halo_sizes,
+    DefaultValuedAttr<DenseI64ArrayAttr, "{}">:$static_source_halo_sizes,
+    Variadic<I64>:$destination_halo_sizes,
+    DefaultValuedAttr<DenseI64ArrayAttr, "{}">:$static_destination_halo_sizes
+  );
+  let results = (outs
+    AnyTypeOf<[AnyNon0RankedMemRef, AnyNon0RankedTensor]>:$result
   );
   let assemblyFormat = [{
-    $input `on` $mesh
+    $source `into` $destination
+    `on` $mesh
     `split_axes` `=` $split_axes
-    (`halo_sizes` `=` custom<DynamicIndexList>($dynamic_halo_sizes, $static_halo_sizes)^)?
-    (`target_halo_sizes` `=` $target_halo_sizes^)?
-    attr-dict `:` type($input)
+    (`source_halo_sizes` `=` custom<DynamicIndexList>($source_halo_sizes, $static_source_halo_sizes)^)?
+    (`destination_halo_sizes` `=` custom<DynamicIndexList>($destination_halo_sizes, $static_destination_halo_sizes)^)?
+    attr-dict `:` type($source) `->` type($result)
+  }];
+  let extraClassDeclaration = [{
+    MutableOperandRange getDpsInitsMutable() { return getDestinationMutable(); }
   }];
 }
 #endif // MLIR_DIALECT_MESH_IR_MESHOPS_TD

diff  --git a/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp b/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
index 19e9212157ae47..c5570d8ee8a443 100644
--- a/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
+++ b/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
@@ -192,33 +192,34 @@ template <typename InShape, typename MeshShape, typename SplitAxes,
           typename OutShape>
 static void shardShape(const InShape &inShape, const MeshShape &meshShape,
                        const SplitAxes &splitAxes, OutShape &outShape,
-                       ArrayRef<int64_t> shardedDimsSizes = {},
+                       ArrayRef<int64_t> shardedDimsOffsets = {},
                        ArrayRef<int64_t> haloSizes = {}) {
   std::copy(llvm::adl_begin(inShape), llvm::adl_end(inShape),
             llvm::adl_begin(outShape));
 
-  if (!shardedDimsSizes.empty()) {
+  if (!shardedDimsOffsets.empty()) {
+    auto isDynShape = ShapedType::isDynamicShape(meshShape);
+    uint64_t pos = 1;
     for (auto [tensorAxis, innerSplitAxes] : llvm::enumerate(splitAxes)) {
-      if (innerSplitAxes.empty()) {
-#ifndef NDEBUG
-        for (auto dimSz : shardedDimsSizes) {
-          auto inAxis = dimSz % inShape.size();
-          assert(inShape[inAxis] == dimSz || dimSz == ShapedType::kDynamic ||
-                 inShape[inAxis] == ShapedType::kDynamic);
-        }
-#endif // NDEBUG
-      } else {
-        // find sharded dims in sharded_dims_sizes with same static size on
-        // all devices. Use kDynamic for dimensions with dynamic or non-uniform
-        // sizes in sharded_dims_sizes.
-        auto sz = shardedDimsSizes[tensorAxis];
-        bool same = true;
-        for (size_t i = tensorAxis + inShape.size();
-             i < shardedDimsSizes.size(); i += inShape.size()) {
-          if (shardedDimsSizes[i] != sz) {
-            same = false;
-            break;
+      if (!innerSplitAxes.empty()) {
+        auto sz = shardedDimsOffsets[pos];
+        bool same = !isDynShape;
+        if (same) {
+          // Find sharded dims in shardedDimsOffsets with same static size on
+          // all devices. Use kDynamic for dimensions with dynamic or
+          // non-uniform offs in shardedDimsOffsets.
+          uint64_t numShards = 0;
+          for (auto i : innerSplitAxes.asArrayRef()) {
+            numShards += meshShape[i];
+          }
+          for (size_t i = 1; i < numShards; ++i) {
+            if (shardedDimsOffsets[pos + i] - shardedDimsOffsets[pos + i - 1] !=
+                sz) {
+              same = false;
+              break;
+            }
           }
+          pos += numShards + 1;
         }
         outShape[tensorAxis] = same ? sz : ShapedType::kDynamic;
       }
@@ -255,7 +256,7 @@ ShapedType mesh::shardShapedType(ShapedType shape, MeshOp mesh,
   using Dim = std::decay_t<decltype(shape.getDimSize(0))>;
   SmallVector<Dim> resShapeArr(shape.getShape().size());
   shardShape(shape.getShape(), mesh.getShape(), sharding.getSplitAxes(),
-             resShapeArr, sharding.getStaticShardedDimsSizes(),
+             resShapeArr, sharding.getStaticShardedDimsOffsets(),
              sharding.getStaticHaloSizes());
   return shape.clone(resShapeArr);
 }
@@ -432,13 +433,14 @@ void ShardingOp::build(::mlir::OpBuilder &b, ::mlir::OperationState &odsState,
                        ArrayRef<MeshAxis> partial_axes,
                        mesh::ReductionKind partial_type,
                        ArrayRef<int64_t> static_halo_sizes,
-                       ArrayRef<int64_t> static_sharded_dims_sizes) {
+                       ArrayRef<int64_t> static_sharded_dims_offsets) {
   return build(
       b, odsState, mesh, MeshAxesArrayAttr::get(b.getContext(), split_axes),
       ::mlir::DenseI16ArrayAttr::get(b.getContext(), partial_axes),
       ::mlir::mesh::ReductionKindAttr::get(b.getContext(), partial_type),
       ::mlir::DenseI64ArrayAttr::get(b.getContext(), static_halo_sizes), {},
-      ::mlir::DenseI64ArrayAttr::get(b.getContext(), static_sharded_dims_sizes),
+      ::mlir::DenseI64ArrayAttr::get(b.getContext(),
+                                     static_sharded_dims_offsets),
       {});
 }
 
@@ -455,11 +457,11 @@ void ShardingOp::build(
     ::mlir::OpBuilder &b, ::mlir::OperationState &odsState,
     FlatSymbolRefAttr mesh, ArrayRef<MeshAxesAttr> split_axes,
     ::mlir::ArrayRef<::mlir::OpFoldResult> halo_sizes,
-    ::mlir::ArrayRef<::mlir::OpFoldResult> sharded_dims_sizes) {
+    ::mlir::ArrayRef<::mlir::OpFoldResult> sharded_dims_offsets) {
   mlir::SmallVector<int64_t> staticHalos, staticDims;
   mlir::SmallVector<mlir::Value> dynamicHalos, dynamicDims;
   dispatchIndexOpFoldResults(halo_sizes, dynamicHalos, staticHalos);
-  dispatchIndexOpFoldResults(sharded_dims_sizes, dynamicDims, staticDims);
+  dispatchIndexOpFoldResults(sharded_dims_offsets, dynamicDims, staticDims);
   return build(
       b, odsState, mesh, MeshAxesArrayAttr::get(b.getContext(), split_axes), {},
       ::mlir::mesh::ReductionKindAttr::get(b.getContext(), ReductionKind::Sum),
@@ -477,10 +479,10 @@ void ShardingOp::build(::mlir::OpBuilder &b, ::mlir::OperationState &odsState,
             : b.getDenseI16ArrayAttr(from.getPartialAxes()),
         ::mlir::mesh::ReductionKindAttr::get(b.getContext(),
                                              from.getPartialType()),
-        from.getStaticShardedDimsSizes().empty()
+        from.getStaticShardedDimsOffsets().empty()
             ? DenseI64ArrayAttr()
-            : b.getDenseI64ArrayAttr(from.getStaticShardedDimsSizes()),
-        from.getDynamicShardedDimsSizes(),
+            : b.getDenseI64ArrayAttr(from.getStaticShardedDimsOffsets()),
+        from.getDynamicShardedDimsOffsets(),
         from.getStaticHaloSizes().empty()
             ? DenseI64ArrayAttr()
             : b.getDenseI64ArrayAttr(from.getStaticHaloSizes()),
@@ -509,8 +511,8 @@ LogicalResult ShardingOp::verify() {
       failed(checkMeshAxis(getPartialAxes().value())))
     return failure();
 
-  if (!getStaticHaloSizes().empty() && !getStaticShardedDimsSizes().empty()) {
-    return emitOpError("halo sizes and shard shapes are mutually exclusive");
+  if (!getStaticHaloSizes().empty() && !getStaticShardedDimsOffsets().empty()) {
+    return emitOpError("halo sizes and shard offsets are mutually exclusive");
   }
 
   if (!getStaticHaloSizes().empty()) {
@@ -539,13 +541,81 @@ LogicalResult ShardingOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
     return failure();
   }
   if (mlir::ShapedType::isDynamicShape(mesh->getShape()) &&
-      getStaticShardedDimsSizes().size() > 0) {
-    return emitError() << "sharded dims sizes are not allowed for "
+      getStaticShardedDimsOffsets().size() > 0) {
+    return emitError() << "sharded dims offsets are not allowed for "
                           "devices meshes with dynamic shape.";
   }
+
+  auto shardedDimsOffsets = getStaticShardedDimsOffsets();
+  if (!shardedDimsOffsets.empty()) {
+    auto meshShape = mesh.value().getShape();
+    assert(!ShapedType::isDynamicShape(meshShape));
+    uint64_t pos = 0;
+    for (auto [tensorAxis, innerSplitAxes] : llvm::enumerate(getSplitAxes())) {
+      if (!innerSplitAxes.empty()) {
+        int64_t numShards = 0, off = 0;
+        for (auto i : innerSplitAxes.asArrayRef()) {
+          numShards += meshShape[i];
+        }
+        for (int64_t i = 0; i <= numShards; ++i) {
+          if (shardedDimsOffsets.size() <= pos + i) {
+            return emitError() << "sharded dims offsets has wrong size.";
+          }
+          if (!ShapedType::isDynamic(shardedDimsOffsets[pos + i])) {
+            if (shardedDimsOffsets[pos + i] < off) {
+              return emitError()
+                     << "sharded dims offsets must be non-decreasing.";
+            }
+            off = shardedDimsOffsets[pos + i];
+          }
+        }
+        pos += numShards + 1;
+      }
+    }
+  }
   return success();
 }
 
+namespace {
+// Sharding annotations "halo sizes" and "sharded dims offsets"
+// are a mix of attributes and dynamic values. This canonicalization moves
+// constant values to the respective attribute lists and so minimizes the number
+// of values.
+class FoldDynamicLists final : public OpRewritePattern<ShardingOp> {
+public:
+  using OpRewritePattern<ShardingOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(ShardingOp op,
+                                PatternRewriter &b) const override {
+    auto mixedHalos =
+        getMixedValues(op.getStaticHaloSizes(), op.getDynamicHaloSizes(), b);
+    auto mixedOffs = getMixedValues(op.getStaticShardedDimsOffsets(),
+                                    op.getDynamicShardedDimsOffsets(), b);
+
+    // No constant operands were folded, just return;
+    if (failed(foldDynamicIndexList(mixedHalos, /*onlyNonNegative=*/true)) &&
+        failed(foldDynamicIndexList(mixedOffs, /*onlyNonNegative=*/true))) {
+      return failure();
+    }
+
+    auto halos = decomposeMixedValues(mixedHalos);
+    auto offs = decomposeMixedValues(mixedOffs);
+
+    op.setStaticHaloSizes(halos.first);
+    op.getDynamicHaloSizesMutable().assign(halos.second);
+    op.setStaticShardedDimsOffsets(offs.first);
+    op.getDynamicShardedDimsOffsetsMutable().assign(offs.second);
+
+    return success();
+  }
+};
+} // namespace
+
+void ShardingOp::getCanonicalizationPatterns(mlir::RewritePatternSet &results,
+                                             mlir::MLIRContext *context) {
+  results.add<FoldDynamicLists>(context);
+}
+
 //===----------------------------------------------------------------------===//
 // MeshSharding
 //===----------------------------------------------------------------------===//
@@ -555,7 +625,12 @@ bool MeshSharding::equalSplitAndPartialAxes(const MeshSharding &rhs) const {
     return false;
   }
 
-  if (!getPartialAxes().empty() && getPartialType() != rhs.getPartialType()) {
+  if (getPartialAxes().size() != rhs.getPartialAxes().size() ||
+      (!getPartialAxes().empty() && getPartialType() != rhs.getPartialType()) ||
+      !llvm::equal(
+          llvm::make_range(getPartialAxes().begin(), getPartialAxes().end()),
+          llvm::make_range(rhs.getPartialAxes().begin(),
+                           rhs.getPartialAxes().end()))) {
     return false;
   }
 
@@ -576,6 +651,31 @@ bool MeshSharding::equalSplitAndPartialAxes(const MeshSharding &rhs) const {
 }
 
 bool MeshSharding::equalHaloAndShardSizes(const MeshSharding &rhs) const {
+  return equalShardSizes(rhs) && equalHaloSizes(rhs);
+}
+
+bool MeshSharding::equalShardSizes(const MeshSharding &rhs) const {
+  if (rhs.getStaticShardedDimsOffsets().size() !=
+          getStaticShardedDimsOffsets().size() ||
+      !llvm::equal(llvm::make_range(getStaticShardedDimsOffsets().begin(),
+                                    getStaticShardedDimsOffsets().end()),
+                   llvm::make_range(rhs.getStaticShardedDimsOffsets().begin(),
+                                    rhs.getStaticShardedDimsOffsets().end()))) {
+    return false;
+  }
+  if (rhs.getDynamicShardedDimsOffsets().size() !=
+          getDynamicShardedDimsOffsets().size() ||
+      !llvm::equal(
+          llvm::make_range(getDynamicShardedDimsOffsets().begin(),
+                           getDynamicShardedDimsOffsets().end()),
+          llvm::make_range(rhs.getDynamicShardedDimsOffsets().begin(),
+                           rhs.getDynamicShardedDimsOffsets().end()))) {
+    return false;
+  }
+  return true;
+}
+
+bool MeshSharding::equalHaloSizes(const MeshSharding &rhs) const {
   if (rhs.getStaticHaloSizes().size() != getStaticHaloSizes().size() ||
       !llvm::equal(llvm::make_range(getStaticHaloSizes().begin(),
                                     getStaticHaloSizes().end()),
@@ -583,28 +683,13 @@ bool MeshSharding::equalHaloAndShardSizes(const MeshSharding &rhs) const {
                                     rhs.getStaticHaloSizes().end()))) {
     return false;
   }
-  if (rhs.getStaticShardedDimsSizes().size() != getDynamicHaloSizes().size() ||
-      !llvm::equal(llvm::make_range(getStaticShardedDimsSizes().begin(),
-                                    getStaticShardedDimsSizes().end()),
-                   llvm::make_range(rhs.getStaticShardedDimsSizes().begin(),
-                                    rhs.getStaticShardedDimsSizes().end()))) {
-    return false;
-  }
-  if (rhs.getDynamicHaloSizes().size() != getStaticShardedDimsSizes().size() ||
+  if (rhs.getDynamicHaloSizes().size() != getDynamicHaloSizes().size() ||
       !llvm::equal(llvm::make_range(getDynamicHaloSizes().begin(),
                                     getDynamicHaloSizes().end()),
                    llvm::make_range(rhs.getDynamicHaloSizes().begin(),
                                     rhs.getDynamicHaloSizes().end()))) {
     return false;
   }
-  if (rhs.getDynamicShardedDimsSizes().size() !=
-          getDynamicShardedDimsSizes().size() ||
-      !llvm::equal(llvm::make_range(getDynamicShardedDimsSizes().begin(),
-                                    getDynamicShardedDimsSizes().end()),
-                   llvm::make_range(rhs.getDynamicShardedDimsSizes().begin(),
-                                    rhs.getDynamicShardedDimsSizes().end()))) {
-    return false;
-  }
   return true;
 }
 
@@ -629,9 +714,9 @@ MeshSharding::MeshSharding(Value rhs) {
               shardingOp.getPartialAxes().value_or(ArrayRef<MeshAxis>()),
               shardingOp.getPartialType().value_or(ReductionKind::Sum),
               shardingOp.getStaticHaloSizes(),
-              shardingOp.getStaticShardedDimsSizes(),
+              shardingOp.getStaticShardedDimsOffsets(),
               SmallVector<Value>(shardingOp.getDynamicHaloSizes()),
-              SmallVector<Value>(shardingOp.getDynamicShardedDimsSizes()));
+              SmallVector<Value>(shardingOp.getDynamicShardedDimsOffsets()));
 }
 
 MeshSharding MeshSharding::get(::mlir::FlatSymbolRefAttr mesh_,
@@ -639,9 +724,9 @@ MeshSharding MeshSharding::get(::mlir::FlatSymbolRefAttr mesh_,
                                ArrayRef<MeshAxis> partial_axes_,
                                ReductionKind partial_type_,
                                ArrayRef<int64_t> static_halo_sizes_,
-                               ArrayRef<int64_t> static_sharded_dims_sizes_,
+                               ArrayRef<int64_t> static_sharded_dims_offsets_,
                                ArrayRef<Value> dynamic_halo_sizes_,
-                               ArrayRef<Value> dynamic_sharded_dims_sizes_) {
+                               ArrayRef<Value> dynamic_sharded_dims_offsets_) {
   MeshSharding res;
   res.mesh = mesh_;
   res.split_axes.resize(split_axes_.size());
@@ -658,9 +743,9 @@ MeshSharding MeshSharding::get(::mlir::FlatSymbolRefAttr mesh_,
   clone(partial_axes_, res.partial_axes);
   res.partial_type = partial_type_;
   clone(static_halo_sizes_, res.static_halo_sizes);
-  clone(static_sharded_dims_sizes_, res.static_sharded_dims_sizes);
+  clone(static_sharded_dims_offsets_, res.static_sharded_dims_offsets);
   clone(dynamic_halo_sizes_, res.dynamic_halo_sizes);
-  clone(dynamic_sharded_dims_sizes_, res.dynamic_sharded_dims_sizes);
+  clone(dynamic_sharded_dims_offsets_, res.dynamic_sharded_dims_offsets);
 
   return res;
 }

diff  --git a/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp b/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp
index fdfed39972fd52..b4d088cbd7088d 100644
--- a/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp
+++ b/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp
@@ -126,7 +126,7 @@ static MeshSharding targetShardingInSplitLastAxis(MLIRContext *ctx,
 }
 
 // Split a replicated tensor along a mesh axis.
-// e.g. [[0, 1]] -> [[0, 1, 2]].
+// E.g. [[0, 1]] -> [[0, 1, 2]].
 // Returns the spmdized target value with its sharding.
 static std::tuple<TypedValue<ShapedType>, MeshSharding>
 splitLastAxisInResharding(ImplicitLocOpBuilder &builder,
@@ -429,6 +429,85 @@ tryMoveLastSplitAxisInResharding(ImplicitLocOpBuilder &builder, MeshOp mesh,
   return std::nullopt;
 }
 
+// Detect a change in the halo size (only) and create necessary operations if
+// needed. A changed halo sizes requires copying the "core" of the source tensor
+// into the "core" of the destination tensor followed by an update halo
+// operation.
+static std::optional<std::tuple<TypedValue<ShapedType>, MeshSharding>>
+tryUpdateHaloInResharding(ImplicitLocOpBuilder &builder, MeshOp mesh,
+                          MeshSharding sourceSharding,
+                          MeshSharding targetSharding,
+                          ShapedType sourceUnshardedShape,
+                          TypedValue<ShapedType> sourceShard) {
+  // Currently handles only cases where halo sizes differ but everything else
+  // stays the same (from source to destination sharding).
+  if (!sourceSharding.equalSplitAndPartialAxes(targetSharding) ||
+      !sourceSharding.getPartialAxes().empty() ||
+      !targetSharding.getPartialAxes().empty() ||
+      !sourceSharding.getStaticShardedDimsOffsets().empty() ||
+      !targetSharding.getStaticShardedDimsOffsets().empty() ||
+      sourceSharding.equalHaloSizes(targetSharding)) {
+    return std::nullopt;
+  }
+
+  auto srcHaloSizes = sourceSharding.getStaticHaloSizes();
+  auto tgtHaloSizes = targetSharding.getStaticHaloSizes();
+  assert(srcHaloSizes.empty() || srcHaloSizes.size() == tgtHaloSizes.size());
+  assert(((srcHaloSizes.empty() || !ShapedType::isDynamicShape(srcHaloSizes)) &&
+          !ShapedType::isDynamicShape(tgtHaloSizes) &&
+          sourceShard.getType().hasStaticShape()) &&
+         "dynamic shapes/halos are not supported yet for mesh-spmdization");
+  auto rank = sourceShard.getType().getRank();
+  auto splitAxes = sourceSharding.getSplitAxes();
+  SmallVector<int64_t> srcCoreOffs(rank, 0), tgtCoreOffs(rank, 0),
+      strides(rank, 1), outShape(sourceShard.getType().getShape()),
+      coreShape(sourceShard.getType().getShape());
+
+  // Determine "core" of source and destination.
+  // The core is the local part of the shard excluding halo regions.
+  for (auto i = 0u; i < rank; ++i) {
+    if (i < splitAxes.size() && !splitAxes[i].empty()) {
+      if (!srcHaloSizes.empty()) {
+        coreShape[i] -= srcHaloSizes[i * 2] + srcHaloSizes[i * 2 + 1];
+        srcCoreOffs[i] = srcHaloSizes[i * 2];
+      }
+      tgtCoreOffs[i] = tgtHaloSizes[i * 2];
+      outShape[i] =
+          coreShape[i] + tgtHaloSizes[i * 2] + tgtHaloSizes[i * 2 + 1];
+    }
+  }
+
+  // Extract core from source and copy into destination core.
+  auto noVals = ValueRange{};
+  auto initVal = builder.create<tensor::EmptyOp>(
+      sourceShard.getLoc(), outShape, sourceShard.getType().getElementType());
+  auto core = builder.create<tensor::ExtractSliceOp>(
+      sourceShard.getLoc(),
+      RankedTensorType::get(coreShape, sourceShard.getType().getElementType()),
+      sourceShard, noVals, noVals, noVals, srcCoreOffs, coreShape, strides);
+  auto initOprnd = builder.create<tensor::InsertSliceOp>(
+      sourceShard.getLoc(), core, initVal, noVals, noVals, noVals, tgtCoreOffs,
+      coreShape, strides);
+
+  // Finally update the halo.
+  auto updateHaloResult =
+      builder
+          .create<UpdateHaloOp>(
+              sourceShard.getLoc(),
+              RankedTensorType::get(outShape,
+                                    sourceShard.getType().getElementType()),
+              sourceShard, initOprnd, mesh.getSymName(),
+              MeshAxesArrayAttr::get(builder.getContext(),
+                                     sourceSharding.getSplitAxes()),
+              sourceSharding.getDynamicHaloSizes(),
+              sourceSharding.getStaticHaloSizes(),
+              targetSharding.getDynamicHaloSizes(),
+              targetSharding.getStaticHaloSizes())
+          .getResult();
+  return std::make_tuple(cast<TypedValue<ShapedType>>(updateHaloResult),
+                         targetSharding);
+}
+
 // Handles only resharding on a 1D mesh.
 // Currently the sharded tensor axes must be exactly divisible by the single
 // mesh axis size.
@@ -454,10 +533,10 @@ reshardOn1DMesh(ImplicitLocOpBuilder &builder, MeshOp mesh,
 
   TypedValue<ShapedType> targetShard;
   MeshSharding actualTargetSharding;
-  if (reducedSourceSharding.getStaticHaloSizes().empty() &&
-      targetSharding.getStaticHaloSizes().empty() &&
-      reducedSourceSharding.getStaticShardedDimsSizes().empty() &&
-      targetSharding.getStaticShardedDimsSizes().empty()) {
+  if (reducedSourceSharding.getStaticShardedDimsOffsets().empty() &&
+      targetSharding.getStaticShardedDimsOffsets().empty() &&
+      reducedSourceSharding.getStaticHaloSizes().empty() &&
+      targetSharding.getStaticHaloSizes().empty()) {
     if (auto tryRes = tryMoveLastSplitAxisInResharding(
             builder, mesh, reducedSourceSharding, targetSharding,
             sourceUnshardedValue.getType(), reducedSourceShard)) {
@@ -483,6 +562,19 @@ TypedValue<ShapedType> reshard(ImplicitLocOpBuilder &builder, MeshOp mesh,
                                MeshSharding targetSharding,
                                TypedValue<ShapedType> sourceUnshardedValue,
                                TypedValue<ShapedType> sourceShard) {
+  // If source and destination sharding are the same, no need to do anything.
+  if (sourceSharding == targetSharding) {
+    return sourceShard;
+  }
+
+  // Tries to handle the case where the resharding is needed because the halo
+  // sizes are different. Supports arbitrary mesh dimensionality.
+  if (auto tryRes = tryUpdateHaloInResharding(
+          builder, mesh, sourceSharding, targetSharding,
+          sourceUnshardedValue.getType(), sourceShard)) {
+    return std::get<0>(tryRes.value()); // targetShard
+  }
+
   // Resort to handling only 1D meshes since the general case is complicated if
   // it needs to be communication efficient in terms of minimizing the data
   // transfered between devices.
@@ -636,8 +728,8 @@ spmdizeOperation(ShardOp shardOp, IRMapping &spmdizationMap,
     targetSpmdValue = spmdizationMap.lookup(shardOp.getSrc());
   } else {
     // Insert resharding.
-    TypedValue<ShapedType> srcSpmdValue = cast<TypedValue<ShapedType>>(
-        spmdizationMap.lookup(srcShardOp.getSrc()));
+    TypedValue<ShapedType> srcSpmdValue =
+        cast<TypedValue<ShapedType>>(spmdizationMap.lookup(srcShardOp));
     targetSpmdValue = reshard(builder, srcShardOp, shardOp, srcSpmdValue,
                               symbolTableCollection);
   }

diff --git a/mlir/test/Dialect/Mesh/canonicalization.mlir b/mlir/test/Dialect/Mesh/canonicalization.mlir
index ea2bd29056ec78..f0112d689805d3 100644
--- a/mlir/test/Dialect/Mesh/canonicalization.mlir
+++ b/mlir/test/Dialect/Mesh/canonicalization.mlir
@@ -191,3 +191,20 @@ func.func @send_empty_mesh_axes(
 // CHECK: return %[[ARG]]
   return %0 : tensor<4xf32>
 }
+
+mesh.mesh @mesh4x4(shape = 4x4)
+// CHECK-LABEL: func @test_halo_sizes
+func.func @test_halo_sizes() -> !mesh.sharding {
+  %c2_i64 = arith.constant 2 : i64
+  // CHECK mesh.sharding @mesh4x4 split_axes = [[0], [1]] halo_sizes = [1, 2, 2, 22] : !mesh.sharding
+  %sharding = mesh.sharding @mesh4x4 split_axes = [[0], [1]] halo_sizes = [1, %c2_i64, %c2_i64, 22] : !mesh.sharding
+  return %sharding : !mesh.sharding
+}
+
+// CHECK-LABEL: func @test_shard_offs
+func.func @test_shard_offs() -> !mesh.sharding {
+  %c2_i64 = arith.constant 2 : i64
+  // CHECK mesh.sharding @mesh4x4 split_axes = [[0], [1]] sharded_dims_offsets = [0, 1, 2, 3, 4, 0, 2, 3, 4, 22] : !mesh.sharding
+  %sharding = mesh.sharding @mesh4x4 split_axes = [[0], [1]] sharded_dims_offsets = [0, 1, %c2_i64, 3, 4, 0, %c2_i64, 3, 4, 22] : !mesh.sharding
+  return %sharding : !mesh.sharding
+}
\ No newline at end of file

diff --git a/mlir/test/Dialect/Mesh/invalid.mlir b/mlir/test/Dialect/Mesh/invalid.mlir
index 3827df90e6962f..29b900a8da4a60 100644
--- a/mlir/test/Dialect/Mesh/invalid.mlir
+++ b/mlir/test/Dialect/Mesh/invalid.mlir
@@ -89,8 +89,8 @@ func.func @sharding_attribute_invalid_halo(%arg0 : tensor<4x8xf32>) {
 // -----
 
 func.func @sharding_attribute_invalid_sizes(%arg0 : tensor<4x8xf32>) {
-  // expected-error at +1 {{halo sizes and shard shapes are mutually exclusive}}
-  %s = mesh.sharding @mesh0 split_axes = [[0]] halo_sizes = [1, 2] sharded_dims_sizes = [2, 2] : !mesh.sharding
+  // expected-error at +1 {{halo sizes and shard offsets are mutually exclusive}}
+  %s = mesh.sharding @mesh0 split_axes = [[0]] halo_sizes = [1, 2] sharded_dims_offsets = [0, 2, 2] : !mesh.sharding
   %0 = mesh.shard %arg0 to %s : tensor<4x8xf32>
   return
 }
@@ -99,8 +99,28 @@ func.func @sharding_attribute_invalid_sizes(%arg0 : tensor<4x8xf32>) {
 
 mesh.mesh @mesh_dyn(shape = ?x?)
 func.func @sharding_dyn_mesh_and_sizes(%arg0 : tensor<4x8xf32>) {
-  // expected-error at +1 {{sharded dims sizes are not allowed for devices meshes with dynamic shape}}
-  %s = mesh.sharding @mesh_dyn split_axes = [[0]] sharded_dims_sizes = [2, 2] : !mesh.sharding
+  // expected-error at +1 {{sharded dims offsets are not allowed for devices meshes with dynamic shape}}
+  %s = mesh.sharding @mesh_dyn split_axes = [[0]] sharded_dims_offsets = [0, 2, 2] : !mesh.sharding
+  %0 = mesh.shard %arg0 to %s : tensor<4x8xf32>
+  return
+}
+
+// -----
+
+mesh.mesh @mesh0(shape = 2x4)
+func.func @sharding_sizes_count(%arg0 : tensor<4x8xf32>) {
+  // expected-error at +1 {{sharded dims offsets has wrong size}}
+  %s = mesh.sharding @mesh0 split_axes = [[0], [1]] sharded_dims_offsets = [0, 2, 4, 0, 2, 4, 6] : !mesh.sharding
+  %0 = mesh.shard %arg0 to %s : tensor<4x8xf32>
+  return
+}
+
+// -----
+
+mesh.mesh @mesh0(shape = 4)
+func.func @sharding_sizes_decreasing(%arg0 : tensor<4x8xf32>) {
+  // expected-error at +1 {{sharded dims offsets must be non-decreasing}}
+  %s = mesh.sharding @mesh0 split_axes = [[0]] sharded_dims_offsets = [0, 2, 3, 2] : !mesh.sharding
   %0 = mesh.shard %arg0 to %s : tensor<4x8xf32>
   return
 }

diff --git a/mlir/test/Dialect/Mesh/ops.mlir b/mlir/test/Dialect/Mesh/ops.mlir
index 5ead7babe2c084..d8df01c3d6520d 100644
--- a/mlir/test/Dialect/Mesh/ops.mlir
+++ b/mlir/test/Dialect/Mesh/ops.mlir
@@ -144,10 +144,10 @@ func.func @mesh_shard_halo_sizes() -> () {
 func.func @mesh_shard_dims_sizes() -> () {
   // CHECK: %[[C3:.*]] = arith.constant 3 : i64
   %c3 = arith.constant 3 : i64
-  // CHECK: mesh.sharding @mesh4 split_axes = {{\[\[}}0]] sharded_dims_sizes = [1, 4, 2] : !mesh.sharding
-  %sharding1 = mesh.sharding @mesh4 split_axes = [[0]] sharded_dims_sizes = [1, 4, 2] : !mesh.sharding
-  // CHECK: mesh.sharding @mesh4 split_axes = {{\[\[}}0]] sharded_dims_sizes = [4, %[[C3]], 1] : !mesh.sharding
-  %sharding2 = mesh.sharding @mesh4 split_axes = [[0]] sharded_dims_sizes = [4, %c3, 1] : !mesh.sharding
+  // CHECK: mesh.sharding @mesh4 split_axes = {{\[\[}}0]] sharded_dims_offsets = [0, 1, 4, 6] : !mesh.sharding
+  %sharding1 = mesh.sharding @mesh4 split_axes = [[0]] sharded_dims_offsets = [0, 1, 4, 6] : !mesh.sharding
+  // CHECK: mesh.sharding @mesh4 split_axes = {{\[\[}}0]] sharded_dims_offsets = [0, 2, %[[C3]], 5] : !mesh.sharding
+  %sharding2 = mesh.sharding @mesh4 split_axes = [[0]] sharded_dims_offsets = [0, 2, %c3, 5] : !mesh.sharding
   return
 }
 
@@ -615,18 +615,16 @@ func.func @update_halo(
     // CHECK-SAME: %[[ARG:.*]]: memref<12x12xi8>
     %arg0 : memref<12x12xi8>) {
   // CHECK-NEXT: %[[C2:.*]] = arith.constant 2 : i64
-  // CHECK-NEXT: mesh.update_halo %[[ARG]] on @mesh0
+  // CHECK-NEXT: %[[UH1:.*]] = mesh.update_halo %[[ARG]] into %[[ARG]] on @mesh0
   // CHECK-SAME: split_axes = {{\[\[}}0]]
-  // CHECK-SAME: halo_sizes = [2, %c2_i64] : memref<12x12xi8>
+  // CHECK-SAME: halo_sizes = [2, %c2_i64] : memref<12x12xi8> -> memref<12x12xi8>
   %c2 = arith.constant 2 : i64
-  mesh.update_halo %arg0 on @mesh0 split_axes = [[0]]
-    halo_sizes = [2, %c2] : memref<12x12xi8>
-  // CHECK-NEXT: mesh.update_halo %[[ARG]] on @mesh0
+  %uh1 = mesh.update_halo %arg0 into %arg0 on @mesh0 split_axes = [[0]]
+    source_halo_sizes = [2, %c2] : memref<12x12xi8> -> memref<12x12xi8>
+  // CHECK-NEXT: %[[UH2:.*]] = mesh.update_halo %[[ARG]] into %[[UH1]] on @mesh0
   // CHECK-SAME: split_axes = {{\[\[}}0], [1]]
-  // CHECK-SAME: halo_sizes = [2, 2, %[[C2]], 2]
-  // CHECK-SAME: target_halo_sizes = [3, 3, 2, 2] : memref<12x12xi8>
-  mesh.update_halo %arg0 on @mesh0 split_axes = [[0], [1]]
-    halo_sizes = [2, 2, %c2, 2] target_halo_sizes = [3, 3, 2, 2]
-    : memref<12x12xi8>
+  // CHECK-SAME: halo_sizes = [2, 2, %[[C2]], 2] : memref<12x12xi8> -> memref<12x12xi8>
+  %uh2 = mesh.update_halo %arg0 into %uh1 on @mesh0 split_axes = [[0], [1]]
+    source_halo_sizes = [2, 2, %c2, 2] : memref<12x12xi8> -> memref<12x12xi8>
   return
 }

diff --git a/mlir/test/Dialect/Mesh/spmdization.mlir b/mlir/test/Dialect/Mesh/spmdization.mlir
index 8b0c4053b0dc7e..22ddb72569835d 100644
--- a/mlir/test/Dialect/Mesh/spmdization.mlir
+++ b/mlir/test/Dialect/Mesh/spmdization.mlir
@@ -219,3 +219,34 @@ func.func @ew_chain_with_halo(
   // CHECK-NEXT: return %[[TMP3]] : tensor<5x16xf32>
   return %sharding_annotated_6 : tensor<8x16xf32>
 }
+
+// CHECK-LABEL: func @test_shard_update_halo
+// CHECK-SAME: %[[IN1:[A-Za-z0-9_]+]]: tensor<300x1200xi64>
+func.func @test_shard_update_halo(%arg0: tensor<1200x1200xi64>) -> tensor<1200x1200xi64> {
+  %sharding = mesh.sharding @mesh_1d_4 split_axes = [[0]] : !mesh.sharding
+  // CHECK: %[[T:.*]] = tensor.empty() : tensor<304x1200xi64>
+  // CHECK: %[[inserted_slice:.*]] = tensor.insert_slice %[[IN1]] into %[[T]][2, 0] [300, 1200] [1, 1] : tensor<300x1200xi64> into tensor<304x1200xi64>
+  // CHECK: %[[UH:.*]] = mesh.update_halo %[[IN1]] into %[[inserted_slice]] on @mesh_1d_4 split_axes = {{\[\[0]]}} destination_halo_sizes = [2, 2] : tensor<300x1200xi64> -> tensor<304x1200xi64>
+  %sharding_annotated = mesh.shard %arg0 to %sharding : tensor<1200x1200xi64>
+  %sharding_0 = mesh.sharding @mesh_1d_4 split_axes = [[0]] halo_sizes = [2, 2] : !mesh.sharding
+  %sharding_annotated_1 = mesh.shard %sharding_annotated to %sharding_0 : tensor<1200x1200xi64>
+  %sharding_annotated_3 = mesh.shard %sharding_annotated_1 to %sharding_0 annotate_for_users : tensor<1200x1200xi64>
+  // CHECK: return %[[UH]] : tensor<304x1200xi64>
+  return %sharding_annotated_3 : tensor<1200x1200xi64>
+}
+
+mesh.mesh @mesh4x4(shape = 4x4)
+// CHECK-LABEL: func @test_shard_update_halo2d
+// CHECK-SAME: %[[IN1:[A-Za-z0-9_]+]]: tensor<300x300xi64>
+func.func @test_shard_update_halo2d(%arg0: tensor<1200x1200xi64>) -> tensor<1200x1200xi64> {
+  %sharding = mesh.sharding @mesh4x4 split_axes = [[0], [1]] : !mesh.sharding
+  // CHECK: %[[T:.*]] = tensor.empty() : tensor<303x307xi64>
+  // CHECK: %[[inserted_slice:.*]] = tensor.insert_slice %[[IN1]] into %[[T]][1, 3] [300, 300] [1, 1] : tensor<300x300xi64> into tensor<303x307xi64>
+  // CHECK: %[[UH:.*]] = mesh.update_halo %[[IN1]] into %[[inserted_slice]] on @mesh4x4 split_axes = {{\[\[}}0], [1]] destination_halo_sizes = [1, 2, 3, 4] : tensor<300x300xi64> -> tensor<303x307xi64>
+  %sharding_annotated = mesh.shard %arg0 to %sharding : tensor<1200x1200xi64>
+  %sharding_0 = mesh.sharding @mesh4x4 split_axes = [[0], [1]] halo_sizes = [1, 2, 3, 4] : !mesh.sharding
+  %sharding_annotated_1 = mesh.shard %sharding_annotated to %sharding_0 : tensor<1200x1200xi64>
+  %sharding_annotated_3 = mesh.shard %sharding_annotated_1 to %sharding_0 annotate_for_users : tensor<1200x1200xi64>
+  // CHECK: return %[[UH]] : tensor<303x307xi64>
+  return %sharding_annotated_3 : tensor<1200x1200xi64>
+}
\ No newline at end of file

diff --git a/mlir/test/Dialect/Tensor/mesh-spmdization.mlir b/mlir/test/Dialect/Tensor/mesh-spmdization.mlir
index 611acb5b41445b..5443eea83aa2d8 100644
--- a/mlir/test/Dialect/Tensor/mesh-spmdization.mlir
+++ b/mlir/test/Dialect/Tensor/mesh-spmdization.mlir
@@ -4,12 +4,12 @@
 
 mesh.mesh @mesh_1d_4(shape = 4)
 
-// CHECK-LABEL: func @tensor_empty_static_sharded_dims_sizes
-func.func @tensor_empty_static_sharded_dims_sizes() -> () {
+// CHECK-LABEL: func @tensor_empty_static_sharded_dims_offsets
+func.func @tensor_empty_static_sharded_dims_offsets() -> () {
   %b = tensor.empty() : tensor<8x16xf32>
-  %sharding = mesh.sharding @mesh_1d_4 split_axes = [[0]] sharded_dims_sizes = [1, 3, 3, 1] : !mesh.sharding
+  %sharding = mesh.sharding @mesh_1d_4 split_axes = [[0]] sharded_dims_offsets = [0, 1, 4, 7, 8] : !mesh.sharding
   %sharded= mesh.shard %b to %sharding : tensor<8x16xf32>
-  // CHECK:  %[[sharding:.*]] = mesh.sharding @mesh_1d_4 split_axes = {{\[\[}}0]] sharded_dims_sizes = [1, 3, 3, 1] : !mesh.sharding
+  // CHECK:  %[[sharding:.*]] = mesh.sharding @mesh_1d_4 split_axes = {{\[\[}}0]] sharded_dims_offsets = [0, 1, 4, 7, 8] : !mesh.sharding
   // CHECK:  %[[proc_linear_idx:.*]] = mesh.process_linear_index on @mesh_1d_4 : index
   // CHECK:  %[[V0:.*]]:2 = mesh.shard_shape 8x16 %[[sharding]] %[[proc_linear_idx]] : index, index
   // CHECK:  tensor.empty(%[[V0]]#0) : tensor<?x16xf32>
@@ -17,13 +17,13 @@ func.func @tensor_empty_static_sharded_dims_sizes() -> () {
   return
 }
 
-// CHECK-LABEL: func @tensor_empty_dynamic_sharded_dims_sizes
+// CHECK-LABEL: func @tensor_empty_dynamic_sharded_dims_offsets
 // CHECK-SAME: %[[A0:.*]]: index
-func.func @tensor_empty_dynamic_sharded_dims_sizes(%arg0 : index) -> () {
+func.func @tensor_empty_dynamic_sharded_dims_offsets(%arg0 : index) -> () {
   %b = tensor.empty(%arg0) : tensor<8x?xf32>
-  %sharding = mesh.sharding @mesh_1d_4 split_axes = [[0]] sharded_dims_sizes = [1, 3, 3, 1] : !mesh.sharding
+  %sharding = mesh.sharding @mesh_1d_4 split_axes = [[0]] sharded_dims_offsets = [0, 1, 4, 7, 8] : !mesh.sharding
   %sharded= mesh.shard %b to %sharding : tensor<8x?xf32>
-  // CHECK:  %[[sharding:.*]] = mesh.sharding @mesh_1d_4 split_axes = {{\[\[}}0]] sharded_dims_sizes = [1, 3, 3, 1] : !mesh.sharding
+  // CHECK:  %[[sharding:.*]] = mesh.sharding @mesh_1d_4 split_axes = {{\[\[}}0]] sharded_dims_offsets = [0, 1, 4, 7, 8] : !mesh.sharding
   // CHECK:  %[[proc_linear_idx:.*]] = mesh.process_linear_index on @mesh_1d_4 : index
   // CHECK:  %[[V0:.*]]:2 = mesh.shard_shape 8x? %[[sharding]] %[[proc_linear_idx]] : index, index
   // CHECK:  tensor.empty(%[[V0]]#0, %[[A0]]) : tensor<?x?xf32>
@@ -33,9 +33,9 @@ func.func @tensor_empty_dynamic_sharded_dims_sizes(%arg0 : index) -> () {
 
 // CHECK-LABEL: func @tensor_empty_same_static_dims_sizes
 func.func @tensor_empty_same_static_dims_sizes() -> () {
-  %b = tensor.empty() : tensor<8x16xf32>
-  %sharding = mesh.sharding @mesh_1d_4 split_axes = [[0]] sharded_dims_sizes = [4, 4, 4, 4] : !mesh.sharding
-  %sharded= mesh.shard %b to %sharding : tensor<8x16xf32>
+  %b = tensor.empty() : tensor<16x16xf32>
+  %sharding = mesh.sharding @mesh_1d_4 split_axes = [[0]] sharded_dims_offsets = [0, 4, 8, 12, 16] : !mesh.sharding
+  %sharded= mesh.shard %b to %sharding : tensor<16x16xf32>
   // CHECK-NEXT:  tensor.empty() : tensor<4x16xf32>
 
   return


        


More information about the Mlir-commits mailing list