[Mlir-commits] [mlir] [mlir][mesh] Handling changed halo region sizes during spmdization (PR #114238)

Frank Schlimbach llvmlistbot at llvm.org
Wed Oct 30 07:38:54 PDT 2024


https://github.com/fschlimb updated https://github.com/llvm/llvm-project/pull/114238

>From 0133634490675c26b548f00372e4c8e6d0b375f4 Mon Sep 17 00:00:00 2001
From: "Schlimbach, Frank" <frank.schlimbach at intel.com>
Date: Tue, 3 Sep 2024 09:29:35 +0200
Subject: [PATCH 1/9] cut&paste error

---
 mlir/lib/Dialect/Mesh/IR/MeshOps.cpp | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp b/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
index 19e9212157ae47..e0ecbee7ee11f5 100644
--- a/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
+++ b/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
@@ -583,14 +583,14 @@ bool MeshSharding::equalHaloAndShardSizes(const MeshSharding &rhs) const {
                                     rhs.getStaticHaloSizes().end()))) {
     return false;
   }
-  if (rhs.getStaticShardedDimsSizes().size() != getDynamicHaloSizes().size() ||
+  if (rhs.getStaticShardedDimsSizes().size() != getStaticShardedDimsSizes().size() ||
       !llvm::equal(llvm::make_range(getStaticShardedDimsSizes().begin(),
                                     getStaticShardedDimsSizes().end()),
                    llvm::make_range(rhs.getStaticShardedDimsSizes().begin(),
                                     rhs.getStaticShardedDimsSizes().end()))) {
     return false;
   }
-  if (rhs.getDynamicHaloSizes().size() != getStaticShardedDimsSizes().size() ||
+  if (rhs.getDynamicHaloSizes().size() != getDynamicHaloSizes().size() ||
       !llvm::equal(llvm::make_range(getDynamicHaloSizes().begin(),
                                     getDynamicHaloSizes().end()),
                    llvm::make_range(rhs.getDynamicHaloSizes().begin(),

>From e99596fe154f921003e6acd4b1e3d957c9d05208 Mon Sep 17 00:00:00 2001
From: "Schlimbach, Frank" <frank.schlimbach at intel.com>
Date: Thu, 10 Oct 2024 12:09:44 +0200
Subject: [PATCH 2/9] initial spmdizeation pattern for halo updates

---
 mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h   |  3 +
 mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td  | 32 +++++---
 mlir/lib/Dialect/Mesh/IR/MeshOps.cpp          | 35 +++++----
 .../Dialect/Mesh/Transforms/Spmdization.cpp   | 77 +++++++++++++++++--
 4 files changed, 116 insertions(+), 31 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h
index db7b64fda57d7b..d1cab4d819a0e0 100644
--- a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h
+++ b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h
@@ -15,6 +15,7 @@
 #include "mlir/IR/OpDefinition.h"
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/IR/SymbolTable.h"
+#include "mlir/Interfaces/DestinationStyleOpInterface.h"
 #include "mlir/Interfaces/InferTypeOpInterface.h"
 #include "mlir/Interfaces/SideEffectInterfaces.h"
 #include "llvm/Support/MathExtras.h"
@@ -80,6 +81,8 @@ class MeshSharding {
   bool operator!=(const MeshSharding &rhs) const;
   bool equalSplitAndPartialAxes(const MeshSharding &rhs) const;
   bool equalHaloAndShardSizes(const MeshSharding &rhs) const;
+  bool equalHaloSizes(const MeshSharding &rhs) const;
+  bool equalShardSizes(const MeshSharding &rhs) const;
 };
 
 } // namespace mesh
diff --git a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td
index 8f696bbc1a0f6e..0c5606a94a6dc7 100644
--- a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td
+++ b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td
@@ -11,6 +11,7 @@
 
 include "mlir/Dialect/Mesh/IR/MeshBase.td"
 include "mlir/Dialect/Shape/IR/ShapeBase.td"
+include "mlir/Interfaces/DestinationStyleOpInterface.td"
 include "mlir/Interfaces/InferTypeOpInterface.td"
 include "mlir/Interfaces/SideEffectInterfaces.td"
 include "mlir/IR/BuiltinTypes.td"
@@ -1052,6 +1053,10 @@ def Mesh_ShiftOp : Mesh_CollectiveCommunicationOpBase<"shift", [
 }
 
 def Mesh_UpdateHaloOp : Mesh_Op<"update_halo", [
+  DestinationStyleOpInterface,
+  TypesMatchWith<
+    "result has same type as destination",
+    "result", "destination", "$_self">,
   DeclareOpInterfaceMethods<SymbolUserOpInterface>
 ]> {
   let summary = "Update halo data.";
@@ -1062,27 +1067,34 @@ def Mesh_UpdateHaloOp : Mesh_Op<"update_halo", [
     and/or if the new halo regions are larger than the existing ones.
 
     Assumes all devices hold tensors with same-sized halo data as specified
-    by `dynamic/static_halo_sizes`.
+    by `source_halo_sizes/static_source_halo_sizes`.
 
     `split_axes` specifies for each tensor axis along which mesh axes its halo
     data is updated.
 
-    Optionally resizes to new halo sizes `target_halo_sizes`.
+    The destination halo sizes are allowed to differ from the source sizes. The size
+    of the inner (local) shard is inferred from the destination shape and source sharding.
   }];
   let arguments = (ins
-    AnyNon0RankedMemRef:$input,
+    AnyTypeOf<[AnyNon0RankedMemRef, AnyNon0RankedTensor]>:$source,
+    AnyTypeOf<[AnyNon0RankedMemRef, AnyNon0RankedTensor]>:$destination,
     FlatSymbolRefAttr:$mesh,
     Mesh_MeshAxesArrayAttr:$split_axes,
-    Variadic<I64>:$dynamic_halo_sizes,
-    DefaultValuedAttr<DenseI64ArrayAttr, "{}">:$static_halo_sizes,
-    DefaultValuedAttr<DenseI64ArrayAttr, "{}">:$target_halo_sizes
+    Variadic<I64>:$source_halo_sizes,
+    DefaultValuedAttr<DenseI64ArrayAttr, "{}">:$static_source_halo_sizes
+  );
+  let results = (outs
+    AnyTypeOf<[AnyNon0RankedMemRef, AnyNon0RankedTensor]>:$result
   );
   let assemblyFormat = [{
-    $input `on` $mesh
+    $source `into` $destination
+    `on` $mesh
     `split_axes` `=` $split_axes
-    (`halo_sizes` `=` custom<DynamicIndexList>($dynamic_halo_sizes, $static_halo_sizes)^)?
-    (`target_halo_sizes` `=` $target_halo_sizes^)?
-    attr-dict `:` type($input)
+    (`source_halo_sizes` `=` custom<DynamicIndexList>($source_halo_sizes, $static_source_halo_sizes)^)?
+    attr-dict `:` type($source) `->` type($result)
+  }];
+  let extraClassDeclaration = [{
+    MutableOperandRange getDpsInitsMutable() { return getDestinationMutable(); }
   }];
 }
 #endif // MLIR_DIALECT_MESH_IR_MESHOPS_TD
diff --git a/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp b/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
index e0ecbee7ee11f5..29356a9de73f56 100644
--- a/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
+++ b/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
@@ -576,13 +576,10 @@ bool MeshSharding::equalSplitAndPartialAxes(const MeshSharding &rhs) const {
 }
 
 bool MeshSharding::equalHaloAndShardSizes(const MeshSharding &rhs) const {
-  if (rhs.getStaticHaloSizes().size() != getStaticHaloSizes().size() ||
-      !llvm::equal(llvm::make_range(getStaticHaloSizes().begin(),
-                                    getStaticHaloSizes().end()),
-                   llvm::make_range(rhs.getStaticHaloSizes().begin(),
-                                    rhs.getStaticHaloSizes().end()))) {
-    return false;
-  }
+  return equalShardSizes(rhs) && equalHaloSizes(rhs);
+}
+
+bool MeshSharding::equalShardSizes(const MeshSharding &rhs) const {
   if (rhs.getStaticShardedDimsSizes().size() != getStaticShardedDimsSizes().size() ||
       !llvm::equal(llvm::make_range(getStaticShardedDimsSizes().begin(),
                                     getStaticShardedDimsSizes().end()),
@@ -590,13 +587,6 @@ bool MeshSharding::equalHaloAndShardSizes(const MeshSharding &rhs) const {
                                     rhs.getStaticShardedDimsSizes().end()))) {
     return false;
   }
-  if (rhs.getDynamicHaloSizes().size() != getDynamicHaloSizes().size() ||
-      !llvm::equal(llvm::make_range(getDynamicHaloSizes().begin(),
-                                    getDynamicHaloSizes().end()),
-                   llvm::make_range(rhs.getDynamicHaloSizes().begin(),
-                                    rhs.getDynamicHaloSizes().end()))) {
-    return false;
-  }
   if (rhs.getDynamicShardedDimsSizes().size() !=
           getDynamicShardedDimsSizes().size() ||
       !llvm::equal(llvm::make_range(getDynamicShardedDimsSizes().begin(),
@@ -605,6 +595,23 @@ bool MeshSharding::equalHaloAndShardSizes(const MeshSharding &rhs) const {
                                     rhs.getDynamicShardedDimsSizes().end()))) {
     return false;
   }
+}
+
+bool MeshSharding::equalHaloSizes(const MeshSharding &rhs) const {
+  if (rhs.getStaticHaloSizes().size() != getStaticHaloSizes().size() ||
+      !llvm::equal(llvm::make_range(getStaticHaloSizes().begin(),
+                                    getStaticHaloSizes().end()),
+                   llvm::make_range(rhs.getStaticHaloSizes().begin(),
+                                    rhs.getStaticHaloSizes().end()))) {
+    return false;
+  }
+  if (rhs.getDynamicHaloSizes().size() != getDynamicHaloSizes().size() ||
+      !llvm::equal(llvm::make_range(getDynamicHaloSizes().begin(),
+                                    getDynamicHaloSizes().end()),
+                   llvm::make_range(rhs.getDynamicHaloSizes().begin(),
+                                    rhs.getDynamicHaloSizes().end()))) {
+    return false;
+  }
   return true;
 }
 
diff --git a/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp b/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp
index fdfed39972fd52..8635890f9c5afd 100644
--- a/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp
+++ b/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp
@@ -429,6 +429,62 @@ tryMoveLastSplitAxisInResharding(ImplicitLocOpBuilder &builder, MeshOp mesh,
   return std::nullopt;
 }
 
+static std::optional<std::tuple<TypedValue<ShapedType>, MeshSharding>>
+tryUpdateHaloInResharding(ImplicitLocOpBuilder &builder, MeshOp mesh,
+                          MeshSharding sourceSharding,
+                          MeshSharding targetSharding,
+                          ShapedType sourceUnshardedShape,
+                          TypedValue<ShapedType> sourceShard) {
+  if (sourceSharding.equalSplitAndPartialAxes(targetSharding) &&
+      sourceSharding.getPartialAxes().empty() &&
+      targetSharding.getPartialAxes().empty() &&
+      sourceSharding.getStaticShardedDimsSizes().empty() &&
+      targetSharding.getStaticShardedDimsSizes().empty() &&
+      !sourceSharding.equalHaloSizes(targetSharding)) {
+    auto srcHaloSizes = sourceSharding.getStaticHaloSizes();
+    auto tgtHaloSizes = targetSharding.getStaticHaloSizes();
+    assert(
+        ((srcHaloSizes.empty() || !ShapedType::isDynamicShape(srcHaloSizes)) &&
+         !ShapedType::isDynamicShape(tgtHaloSizes) &&
+         sourceShard.getType().hasStaticShape()) &&
+        "dynamic shapes/halos are not supported yet for mesh-spmdization");
+    auto rank = sourceShard.getType().getRank();
+    SmallVector<int64_t> outShape, srcCoreOffs(rank, 0), tgtCoreOffs,
+        strides(rank, 1), coreShape(sourceShard.getType().getShape());
+    for (auto i = 0u; i < rank; ++i) {
+      if (!srcHaloSizes.empty()) {
+        coreShape[i] -= srcHaloSizes[i * 2] + srcHaloSizes[i * 2 + 1];
+        srcCoreOffs[i] = srcHaloSizes[i * 2];
+      }
+      tgtCoreOffs.emplace_back(tgtHaloSizes[i * 2]);
+      outShape.emplace_back(coreShape[i] + tgtHaloSizes[i * 2] +
+                            tgtHaloSizes[i * 2 + 1]);
+    }
+    auto noVals = ValueRange{};
+    auto initVal = builder.create<tensor::EmptyOp>(
+        sourceShard.getLoc(), outShape, sourceShard.getType().getElementType());
+    auto core = builder.create<tensor::ExtractSliceOp>(
+        sourceShard.getLoc(),
+        RankedTensorType::get(coreShape,
+                              sourceShard.getType().getElementType()),
+        sourceShard, noVals, noVals, noVals, srcCoreOffs, coreShape, strides);
+    auto initOprnd = builder.create<tensor::InsertSliceOp>(
+        sourceShard.getLoc(), core, initVal, noVals, noVals, noVals,
+        tgtCoreOffs, coreShape, strides);
+    auto targetShard = builder.create<UpdateHaloOp>(
+        sourceShard.getLoc(),
+        RankedTensorType::get(outShape, sourceShard.getType().getElementType()),
+        sourceShard, initOprnd, mesh.getSymName(),
+        MeshAxesArrayAttr::get(builder.getContext(),
+                               sourceSharding.getSplitAxes()),
+        sourceSharding.getDynamicHaloSizes(),
+        sourceSharding.getStaticHaloSizes());
+    return std::make_tuple(
+        cast<TypedValue<ShapedType>>(targetShard.getResult()), targetSharding);
+  }
+  return std::nullopt;
+}
+
 // Handles only resharding on a 1D mesh.
 // Currently the sharded tensor axes must be exactly divisible by the single
 // mesh axis size.
@@ -454,10 +510,10 @@ reshardOn1DMesh(ImplicitLocOpBuilder &builder, MeshOp mesh,
 
   TypedValue<ShapedType> targetShard;
   MeshSharding actualTargetSharding;
-  if (reducedSourceSharding.getStaticHaloSizes().empty() &&
-      targetSharding.getStaticHaloSizes().empty() &&
-      reducedSourceSharding.getStaticShardedDimsSizes().empty() &&
-      targetSharding.getStaticShardedDimsSizes().empty()) {
+  if (reducedSourceSharding.getStaticShardedDimsSizes().empty() &&
+      targetSharding.getStaticShardedDimsSizes().empty() &&
+      reducedSourceSharding.getStaticHaloSizes().empty() &&
+      targetSharding.getStaticHaloSizes().empty()) {
     if (auto tryRes = tryMoveLastSplitAxisInResharding(
             builder, mesh, reducedSourceSharding, targetSharding,
             sourceUnshardedValue.getType(), reducedSourceShard)) {
@@ -483,6 +539,12 @@ TypedValue<ShapedType> reshard(ImplicitLocOpBuilder &builder, MeshOp mesh,
                                MeshSharding targetSharding,
                                TypedValue<ShapedType> sourceUnshardedValue,
                                TypedValue<ShapedType> sourceShard) {
+  if (auto tryRes = tryUpdateHaloInResharding(
+          builder, mesh, sourceSharding, targetSharding,
+          sourceUnshardedValue.getType(), sourceShard)) {
+    return std::get<0>(tryRes.value()); // targetShard
+  }
+
   // Resort to handling only 1D meshes since the general case is complicated if
   // it needs to be communication efficient in terms of minimizing the data
   // transfered between devices.
@@ -497,9 +559,10 @@ TypedValue<ShapedType> reshard(OpBuilder &builder, MeshOp mesh, ShardOp source,
   auto sourceSharding = source.getSharding();
   auto targetSharding = target.getSharding();
   ImplicitLocOpBuilder implicitLocOpBuilder(target->getLoc(), builder);
-  return reshard(implicitLocOpBuilder, mesh, sourceSharding, targetSharding,
-                 cast<TypedValue<ShapedType>>(source.getSrc()),
-                 sourceShardValue);
+  auto shard =
+      reshard(implicitLocOpBuilder, mesh, sourceSharding, targetSharding,
+              cast<TypedValue<ShapedType>>(source.getSrc()), sourceShardValue);
+  return shard;
 }
 
 TypedValue<ShapedType> reshard(OpBuilder &builder, ShardOp source,

>From 3c1b4eb829aa44b3d68b1b09b83265bf97576a41 Mon Sep 17 00:00:00 2001
From: "Schlimbach, Frank" <frank.schlimbach at intel.com>
Date: Tue, 15 Oct 2024 17:44:55 +0200
Subject: [PATCH 3/9] update to mesh.update_halo; sharded_dims_sizes ->
 sharded_dims_offsets

---
 mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h   |  16 +--
 mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td  |  28 ++---
 mlir/lib/Dialect/Mesh/IR/MeshOps.cpp          | 102 +++++++++---------
 .../Dialect/Mesh/Transforms/Spmdization.cpp   |   8 +-
 mlir/test/Dialect/Mesh/invalid.mlir           |   6 +-
 mlir/test/Dialect/Mesh/ops.mlir               |  26 +++--
 .../test/Dialect/Tensor/mesh-spmdization.mlir |  22 ++--
 7 files changed, 106 insertions(+), 102 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h
index d1cab4d819a0e0..75cb096130ca6e 100644
--- a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h
+++ b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h
@@ -46,9 +46,9 @@ class MeshSharding {
   SmallVector<MeshAxis> partial_axes;
   ReductionKind partial_type;
   SmallVector<int64_t> static_halo_sizes;
-  SmallVector<int64_t> static_sharded_dims_sizes;
+  SmallVector<int64_t> static_sharded_dims_offsets;
   SmallVector<Value> dynamic_halo_sizes;
-  SmallVector<Value> dynamic_sharded_dims_sizes;
+  SmallVector<Value> dynamic_sharded_dims_offsets;
 
 public:
   MeshSharding() = default;
@@ -58,21 +58,21 @@ class MeshSharding {
                           ArrayRef<MeshAxis> partial_axes_ = {},
                           ReductionKind partial_type_ = ReductionKind::Sum,
                           ArrayRef<int64_t> static_halo_sizes_ = {},
-                          ArrayRef<int64_t> static_sharded_dims_sizes_ = {},
+                          ArrayRef<int64_t> static_sharded_dims_offsets_ = {},
                           ArrayRef<Value> dynamic_halo_sizes_ = {},
-                          ArrayRef<Value> dynamic_sharded_dims_sizes_ = {});
+                          ArrayRef<Value> dynamic_sharded_dims_offsets_ = {});
   ::mlir::FlatSymbolRefAttr getMeshAttr() const { return mesh; }
   ::llvm::StringRef getMesh() const { return mesh.getValue(); }
   ArrayRef<MeshAxesAttr> getSplitAxes() const { return split_axes; }
   ArrayRef<MeshAxis> getPartialAxes() const { return partial_axes; }
   ReductionKind getPartialType() const { return partial_type; }
   ArrayRef<int64_t> getStaticHaloSizes() const { return static_halo_sizes; }
-  ArrayRef<int64_t> getStaticShardedDimsSizes() const {
-    return static_sharded_dims_sizes;
+  ArrayRef<int64_t> getStaticShardedDimsOffsets() const {
+    return static_sharded_dims_offsets;
   }
   ArrayRef<Value> getDynamicHaloSizes() const { return dynamic_halo_sizes; }
-  ArrayRef<Value> getDynamicShardedDimsSizes() const {
-    return dynamic_sharded_dims_sizes;
+  ArrayRef<Value> getDynamicShardedDimsOffsets() const {
+    return dynamic_sharded_dims_offsets;
   }
   operator bool() const { return (!mesh) == false; }
   bool operator==(Value rhs) const;
diff --git a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td
index 0c5606a94a6dc7..f881f22adcf52a 100644
--- a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td
+++ b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td
@@ -197,16 +197,18 @@ def Mesh_ShardingOp : Mesh_Op<"sharding", [
     e.g. the first sharded dimension gets [1,2] halos and the seconds gets [2,3] halos.
     `?` indicates dynamic halo sizes.
     
-    6. [Optional] Sizes of sharded dimensions of each shard.
-    `sharded_dims_sizes`is provided as a flattened 1d array of i64s: for each device of the
-    device-mesh one value for each sharded tensor dimension.
+    6. [Optional] Offsets for each shard and sharded tensor dimension.
+    `sharded_dims_offsets` is provided as a flattened 1d array of i64s:
+    For each sharded tensor dimension the offsets (starting index) of all shards in that dimension.
+    The offset of each first shard is omitted and is implicitly assumed to be 0.
+    The last value per dimension denotes the end of the last shard.
     Assuming a 3d-tensor of shape 32x32x32 with the first 2 dimensions being sharded,
-    `sharded_dims_sizes` = [16, 8, 16, 24] means that the first device of
-    the device-mesh will get a shard of shape 16x8x32 and the second device will get a
-    shard of shape 16x24x32.
+    `sharded_dims_offsets` = [24, 32, 20, 32] means that the first device of
+    the device-mesh will get a shard of shape 24x20x32 and the second device will get a
+    shard of shape 8x12x32.
     `?` indicates dynamic shard dimensions.
     
-    `halo_sizes` and `sharded_dims_sizes` are mutually exclusive.
+    `halo_sizes` and `sharded_dims_offsets` are mutually exclusive.
 
     Examples:
 
@@ -241,7 +243,7 @@ def Mesh_ShardingOp : Mesh_Op<"sharding", [
     // The tensor is sharded on its second dimension along axis 0 of @mesh1d_4
     // and it has pre-defined shard sizes. The shards of the devices will have
     // the following shapes: [4x2, 4x3, 4x4, 4x5]
-    %sharding4 = mesh.sharding @mesh1d_4 split_axes = [[] split_axes = [0]] sharded_dims_sizes = [2, 3, 4, 5]
+    %sharding4 = mesh.sharding @mesh1d_4 split_axes = [[], [0]] sharded_dims_offsets = [2, 5, 9, 14]
     %sharded2 = mesh.shard %arg0 to %sharding4 : tensor<4x14xf32>
     ```
   }];
@@ -251,8 +253,8 @@ def Mesh_ShardingOp : Mesh_Op<"sharding", [
     Mesh_MeshAxesArrayAttr:$split_axes,
     OptionalAttr<Mesh_MeshAxesAttr>:$partial_axes,
     OptionalAttr<Mesh_ReductionKindAttr>:$partial_type,
-    DefaultValuedAttr<DenseI64ArrayAttr, "{}">:$static_sharded_dims_sizes,
-    Variadic<I64>:$dynamic_sharded_dims_sizes,
+    DefaultValuedAttr<DenseI64ArrayAttr, "{}">:$static_sharded_dims_offsets,
+    Variadic<I64>:$dynamic_sharded_dims_offsets,
     DefaultValuedAttr<DenseI64ArrayAttr, "{}">:$static_halo_sizes,
     Variadic<I64>:$dynamic_halo_sizes
   );
@@ -264,7 +266,7 @@ def Mesh_ShardingOp : Mesh_Op<"sharding", [
     `split_axes` `=` $split_axes
     (`partial` `=` $partial_type $partial_axes^)?
     (`halo_sizes` `=` custom<DynamicIndexList>($dynamic_halo_sizes, $static_halo_sizes)^)?
-    (`sharded_dims_sizes` `=` custom<DynamicIndexList>($dynamic_sharded_dims_sizes, $static_sharded_dims_sizes)^)?
+    (`sharded_dims_offsets` `=` custom<DynamicIndexList>($dynamic_sharded_dims_offsets, $static_sharded_dims_offsets)^)?
     attr-dict `:` type($result)
   }];
   let builders = [
@@ -273,13 +275,13 @@ def Mesh_ShardingOp : Mesh_Op<"sharding", [
                    "ArrayRef<MeshAxis>":$partial_axes,
                    "mesh::ReductionKind":$partial_type,
                    CArg<"ArrayRef<int64_t>", "{}">:$static_halo_sizes,
-                   CArg<"ArrayRef<int64_t>", "{}">:$static_sharded_dims_sizes)>,
+                   CArg<"ArrayRef<int64_t>", "{}">:$static_sharded_dims_offsets)>,
     OpBuilder<(ins "FlatSymbolRefAttr":$mesh,
                    "ArrayRef<MeshAxesAttr>":$split_axes)>,
     OpBuilder<(ins "FlatSymbolRefAttr":$mesh,
                    "ArrayRef<MeshAxesAttr>":$split_axes,
                    "::mlir::ArrayRef<::mlir::OpFoldResult>":$halo_sizes,
-                   "::mlir::ArrayRef<::mlir::OpFoldResult>":$sharded_dims_sizes)>,
+                   "::mlir::ArrayRef<::mlir::OpFoldResult>":$sharded_dims_offsets)>,
     OpBuilder<(ins "mlir::mesh::MeshSharding":$from)>
   ];
   let hasVerifier = 1;
diff --git a/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp b/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
index 29356a9de73f56..b10f8ac98a27d1 100644
--- a/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
+++ b/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
@@ -192,33 +192,33 @@ template <typename InShape, typename MeshShape, typename SplitAxes,
           typename OutShape>
 static void shardShape(const InShape &inShape, const MeshShape &meshShape,
                        const SplitAxes &splitAxes, OutShape &outShape,
-                       ArrayRef<int64_t> shardedDimsSizes = {},
+                       ArrayRef<int64_t> shardedDimsOffsets = {},
                        ArrayRef<int64_t> haloSizes = {}) {
   std::copy(llvm::adl_begin(inShape), llvm::adl_end(inShape),
             llvm::adl_begin(outShape));
 
-  if (!shardedDimsSizes.empty()) {
+  if (!shardedDimsOffsets.empty()) {
+    uint64_t pos = 0;
     for (auto [tensorAxis, innerSplitAxes] : llvm::enumerate(splitAxes)) {
-      if (innerSplitAxes.empty()) {
-#ifndef NDEBUG
-        for (auto dimSz : shardedDimsSizes) {
-          auto inAxis = dimSz % inShape.size();
-          assert(inShape[inAxis] == dimSz || dimSz == ShapedType::kDynamic ||
-                 inShape[inAxis] == ShapedType::kDynamic);
-        }
-#endif // NDEBUG
-      } else {
-        // find sharded dims in sharded_dims_sizes with same static size on
-        // all devices. Use kDynamic for dimensions with dynamic or non-uniform
-        // sizes in sharded_dims_sizes.
-        auto sz = shardedDimsSizes[tensorAxis];
-        bool same = true;
-        for (size_t i = tensorAxis + inShape.size();
-             i < shardedDimsSizes.size(); i += inShape.size()) {
-          if (shardedDimsSizes[i] != sz) {
-            same = false;
-            break;
+      if (!innerSplitAxes.empty()) {
+        auto sz = shardedDimsOffsets[pos];
+        bool same = !ShapedType::isDynamicShape(meshShape);
+        if (same) {
+          // find sharded dims in shardedDimsOffsets with same static size on
+          // all devices. Use kDynamic for dimensions with dynamic or
+          // non-uniform offs in shardedDimsOffsets.
+          uint64_t numShards = 0;
+          for (auto i : innerSplitAxes.asArrayRef()) {
+            numShards += meshShape[i];
+          }
+          for (size_t i = 1; i < numShards; ++i) {
+            if (shardedDimsOffsets[pos + i] - shardedDimsOffsets[pos + i - 1] !=
+                sz) {
+              same = false;
+              break;
+            }
           }
+          pos += numShards;
         }
         outShape[tensorAxis] = same ? sz : ShapedType::kDynamic;
       }
@@ -255,7 +255,7 @@ ShapedType mesh::shardShapedType(ShapedType shape, MeshOp mesh,
   using Dim = std::decay_t<decltype(shape.getDimSize(0))>;
   SmallVector<Dim> resShapeArr(shape.getShape().size());
   shardShape(shape.getShape(), mesh.getShape(), sharding.getSplitAxes(),
-             resShapeArr, sharding.getStaticShardedDimsSizes(),
+             resShapeArr, sharding.getStaticShardedDimsOffsets(),
              sharding.getStaticHaloSizes());
   return shape.clone(resShapeArr);
 }
@@ -432,13 +432,14 @@ void ShardingOp::build(::mlir::OpBuilder &b, ::mlir::OperationState &odsState,
                        ArrayRef<MeshAxis> partial_axes,
                        mesh::ReductionKind partial_type,
                        ArrayRef<int64_t> static_halo_sizes,
-                       ArrayRef<int64_t> static_sharded_dims_sizes) {
+                       ArrayRef<int64_t> static_sharded_dims_offsets) {
   return build(
       b, odsState, mesh, MeshAxesArrayAttr::get(b.getContext(), split_axes),
       ::mlir::DenseI16ArrayAttr::get(b.getContext(), partial_axes),
       ::mlir::mesh::ReductionKindAttr::get(b.getContext(), partial_type),
       ::mlir::DenseI64ArrayAttr::get(b.getContext(), static_halo_sizes), {},
-      ::mlir::DenseI64ArrayAttr::get(b.getContext(), static_sharded_dims_sizes),
+      ::mlir::DenseI64ArrayAttr::get(b.getContext(),
+                                     static_sharded_dims_offsets),
       {});
 }
 
@@ -455,11 +456,11 @@ void ShardingOp::build(
     ::mlir::OpBuilder &b, ::mlir::OperationState &odsState,
     FlatSymbolRefAttr mesh, ArrayRef<MeshAxesAttr> split_axes,
     ::mlir::ArrayRef<::mlir::OpFoldResult> halo_sizes,
-    ::mlir::ArrayRef<::mlir::OpFoldResult> sharded_dims_sizes) {
+    ::mlir::ArrayRef<::mlir::OpFoldResult> sharded_dims_offsets) {
   mlir::SmallVector<int64_t> staticHalos, staticDims;
   mlir::SmallVector<mlir::Value> dynamicHalos, dynamicDims;
   dispatchIndexOpFoldResults(halo_sizes, dynamicHalos, staticHalos);
-  dispatchIndexOpFoldResults(sharded_dims_sizes, dynamicDims, staticDims);
+  dispatchIndexOpFoldResults(sharded_dims_offsets, dynamicDims, staticDims);
   return build(
       b, odsState, mesh, MeshAxesArrayAttr::get(b.getContext(), split_axes), {},
       ::mlir::mesh::ReductionKindAttr::get(b.getContext(), ReductionKind::Sum),
@@ -477,10 +478,10 @@ void ShardingOp::build(::mlir::OpBuilder &b, ::mlir::OperationState &odsState,
             : b.getDenseI16ArrayAttr(from.getPartialAxes()),
         ::mlir::mesh::ReductionKindAttr::get(b.getContext(),
                                              from.getPartialType()),
-        from.getStaticShardedDimsSizes().empty()
+        from.getStaticShardedDimsOffsets().empty()
             ? DenseI64ArrayAttr()
-            : b.getDenseI64ArrayAttr(from.getStaticShardedDimsSizes()),
-        from.getDynamicShardedDimsSizes(),
+            : b.getDenseI64ArrayAttr(from.getStaticShardedDimsOffsets()),
+        from.getDynamicShardedDimsOffsets(),
         from.getStaticHaloSizes().empty()
             ? DenseI64ArrayAttr()
             : b.getDenseI64ArrayAttr(from.getStaticHaloSizes()),
@@ -509,7 +510,7 @@ LogicalResult ShardingOp::verify() {
       failed(checkMeshAxis(getPartialAxes().value())))
     return failure();
 
-  if (!getStaticHaloSizes().empty() && !getStaticShardedDimsSizes().empty()) {
+  if (!getStaticHaloSizes().empty() && !getStaticShardedDimsOffsets().empty()) {
     return emitOpError("halo sizes and shard shapes are mutually exclusive");
   }
 
@@ -539,8 +540,8 @@ LogicalResult ShardingOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
     return failure();
   }
   if (mlir::ShapedType::isDynamicShape(mesh->getShape()) &&
-      getStaticShardedDimsSizes().size() > 0) {
-    return emitError() << "sharded dims sizes are not allowed for "
+      getStaticShardedDimsOffsets().size() > 0) {
+    return emitError() << "sharded dims offsets are not allowed for "
                           "devices meshes with dynamic shape.";
   }
   return success();
@@ -580,21 +581,24 @@ bool MeshSharding::equalHaloAndShardSizes(const MeshSharding &rhs) const {
 }
 
 bool MeshSharding::equalShardSizes(const MeshSharding &rhs) const {
-  if (rhs.getStaticShardedDimsSizes().size() != getStaticShardedDimsSizes().size() ||
-      !llvm::equal(llvm::make_range(getStaticShardedDimsSizes().begin(),
-                                    getStaticShardedDimsSizes().end()),
-                   llvm::make_range(rhs.getStaticShardedDimsSizes().begin(),
-                                    rhs.getStaticShardedDimsSizes().end()))) {
+  if (rhs.getStaticShardedDimsOffsets().size() !=
+          getStaticShardedDimsOffsets().size() ||
+      !llvm::equal(llvm::make_range(getStaticShardedDimsOffsets().begin(),
+                                    getStaticShardedDimsOffsets().end()),
+                   llvm::make_range(rhs.getStaticShardedDimsOffsets().begin(),
+                                    rhs.getStaticShardedDimsOffsets().end()))) {
     return false;
   }
-  if (rhs.getDynamicShardedDimsSizes().size() !=
-          getDynamicShardedDimsSizes().size() ||
-      !llvm::equal(llvm::make_range(getDynamicShardedDimsSizes().begin(),
-                                    getDynamicShardedDimsSizes().end()),
-                   llvm::make_range(rhs.getDynamicShardedDimsSizes().begin(),
-                                    rhs.getDynamicShardedDimsSizes().end()))) {
+  if (rhs.getDynamicShardedDimsOffsets().size() !=
+          getDynamicShardedDimsOffsets().size() ||
+      !llvm::equal(
+          llvm::make_range(getDynamicShardedDimsOffsets().begin(),
+                           getDynamicShardedDimsOffsets().end()),
+          llvm::make_range(rhs.getDynamicShardedDimsOffsets().begin(),
+                           rhs.getDynamicShardedDimsOffsets().end()))) {
     return false;
   }
+  return true;
 }
 
 bool MeshSharding::equalHaloSizes(const MeshSharding &rhs) const {
@@ -636,9 +640,9 @@ MeshSharding::MeshSharding(Value rhs) {
               shardingOp.getPartialAxes().value_or(ArrayRef<MeshAxis>()),
               shardingOp.getPartialType().value_or(ReductionKind::Sum),
               shardingOp.getStaticHaloSizes(),
-              shardingOp.getStaticShardedDimsSizes(),
+              shardingOp.getStaticShardedDimsOffsets(),
               SmallVector<Value>(shardingOp.getDynamicHaloSizes()),
-              SmallVector<Value>(shardingOp.getDynamicShardedDimsSizes()));
+              SmallVector<Value>(shardingOp.getDynamicShardedDimsOffsets()));
 }
 
 MeshSharding MeshSharding::get(::mlir::FlatSymbolRefAttr mesh_,
@@ -646,9 +650,9 @@ MeshSharding MeshSharding::get(::mlir::FlatSymbolRefAttr mesh_,
                                ArrayRef<MeshAxis> partial_axes_,
                                ReductionKind partial_type_,
                                ArrayRef<int64_t> static_halo_sizes_,
-                               ArrayRef<int64_t> static_sharded_dims_sizes_,
+                               ArrayRef<int64_t> static_sharded_dims_offsets_,
                                ArrayRef<Value> dynamic_halo_sizes_,
-                               ArrayRef<Value> dynamic_sharded_dims_sizes_) {
+                               ArrayRef<Value> dynamic_sharded_dims_offsets_) {
   MeshSharding res;
   res.mesh = mesh_;
   res.split_axes.resize(split_axes_.size());
@@ -665,9 +669,9 @@ MeshSharding MeshSharding::get(::mlir::FlatSymbolRefAttr mesh_,
   clone(partial_axes_, res.partial_axes);
   res.partial_type = partial_type_;
   clone(static_halo_sizes_, res.static_halo_sizes);
-  clone(static_sharded_dims_sizes_, res.static_sharded_dims_sizes);
+  clone(static_sharded_dims_offsets_, res.static_sharded_dims_offsets);
   clone(dynamic_halo_sizes_, res.dynamic_halo_sizes);
-  clone(dynamic_sharded_dims_sizes_, res.dynamic_sharded_dims_sizes);
+  clone(dynamic_sharded_dims_offsets_, res.dynamic_sharded_dims_offsets);
 
   return res;
 }
diff --git a/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp b/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp
index 8635890f9c5afd..6684fbfa0aec0b 100644
--- a/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp
+++ b/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp
@@ -438,8 +438,8 @@ tryUpdateHaloInResharding(ImplicitLocOpBuilder &builder, MeshOp mesh,
   if (sourceSharding.equalSplitAndPartialAxes(targetSharding) &&
       sourceSharding.getPartialAxes().empty() &&
       targetSharding.getPartialAxes().empty() &&
-      sourceSharding.getStaticShardedDimsSizes().empty() &&
-      targetSharding.getStaticShardedDimsSizes().empty() &&
+      sourceSharding.getStaticShardedDimsOffsets().empty() &&
+      targetSharding.getStaticShardedDimsOffsets().empty() &&
       !sourceSharding.equalHaloSizes(targetSharding)) {
     auto srcHaloSizes = sourceSharding.getStaticHaloSizes();
     auto tgtHaloSizes = targetSharding.getStaticHaloSizes();
@@ -510,8 +510,8 @@ reshardOn1DMesh(ImplicitLocOpBuilder &builder, MeshOp mesh,
 
   TypedValue<ShapedType> targetShard;
   MeshSharding actualTargetSharding;
-  if (reducedSourceSharding.getStaticShardedDimsSizes().empty() &&
-      targetSharding.getStaticShardedDimsSizes().empty() &&
+  if (reducedSourceSharding.getStaticShardedDimsOffsets().empty() &&
+      targetSharding.getStaticShardedDimsOffsets().empty() &&
       reducedSourceSharding.getStaticHaloSizes().empty() &&
       targetSharding.getStaticHaloSizes().empty()) {
     if (auto tryRes = tryMoveLastSplitAxisInResharding(
diff --git a/mlir/test/Dialect/Mesh/invalid.mlir b/mlir/test/Dialect/Mesh/invalid.mlir
index 3827df90e6962f..f8abc6959a13a4 100644
--- a/mlir/test/Dialect/Mesh/invalid.mlir
+++ b/mlir/test/Dialect/Mesh/invalid.mlir
@@ -90,7 +90,7 @@ func.func @sharding_attribute_invalid_halo(%arg0 : tensor<4x8xf32>) {
 
 func.func @sharding_attribute_invalid_sizes(%arg0 : tensor<4x8xf32>) {
   // expected-error@+1 {{halo sizes and shard shapes are mutually exclusive}}
-  %s = mesh.sharding @mesh0 split_axes = [[0]] halo_sizes = [1, 2] sharded_dims_sizes = [2, 2] : !mesh.sharding
+  %s = mesh.sharding @mesh0 split_axes = [[0]] halo_sizes = [1, 2] sharded_dims_offsets = [2, 2] : !mesh.sharding
   %0 = mesh.shard %arg0 to %s : tensor<4x8xf32>
   return
 }
@@ -99,8 +99,8 @@ func.func @sharding_attribute_invalid_sizes(%arg0 : tensor<4x8xf32>) {
 
 mesh.mesh @mesh_dyn(shape = ?x?)
 func.func @sharding_dyn_mesh_and_sizes(%arg0 : tensor<4x8xf32>) {
-  // expected-error@+1 {{sharded dims sizes are not allowed for devices meshes with dynamic shape}}
-  %s = mesh.sharding @mesh_dyn split_axes = [[0]] sharded_dims_sizes = [2, 2] : !mesh.sharding
+  // expected-error@+1 {{sharded dims offsets are not allowed for devices meshes with dynamic shape}}
+  %s = mesh.sharding @mesh_dyn split_axes = [[0]] sharded_dims_offsets = [2, 2] : !mesh.sharding
   %0 = mesh.shard %arg0 to %s : tensor<4x8xf32>
   return
 }
diff --git a/mlir/test/Dialect/Mesh/ops.mlir b/mlir/test/Dialect/Mesh/ops.mlir
index 5ead7babe2c084..3b0ef39a6e212e 100644
--- a/mlir/test/Dialect/Mesh/ops.mlir
+++ b/mlir/test/Dialect/Mesh/ops.mlir
@@ -144,10 +144,10 @@ func.func @mesh_shard_halo_sizes() -> () {
 func.func @mesh_shard_dims_sizes() -> () {
   // CHECK: %[[C3:.*]] = arith.constant 3 : i64
   %c3 = arith.constant 3 : i64
-  // CHECK: mesh.sharding @mesh4 split_axes = {{\[\[}}0]] sharded_dims_sizes = [1, 4, 2] : !mesh.sharding
-  %sharding1 = mesh.sharding @mesh4 split_axes = [[0]] sharded_dims_sizes = [1, 4, 2] : !mesh.sharding
-  // CHECK: mesh.sharding @mesh4 split_axes = {{\[\[}}0]] sharded_dims_sizes = [4, %[[C3]], 1] : !mesh.sharding
-  %sharding2 = mesh.sharding @mesh4 split_axes = [[0]] sharded_dims_sizes = [4, %c3, 1] : !mesh.sharding
+  // CHECK: mesh.sharding @mesh4 split_axes = {{\[\[}}0]] sharded_dims_offsets = [1, 4, 2] : !mesh.sharding
+  %sharding1 = mesh.sharding @mesh4 split_axes = [[0]] sharded_dims_offsets = [1, 4, 2] : !mesh.sharding
+  // CHECK: mesh.sharding @mesh4 split_axes = {{\[\[}}0]] sharded_dims_offsets = [4, %[[C3]], 1] : !mesh.sharding
+  %sharding2 = mesh.sharding @mesh4 split_axes = [[0]] sharded_dims_offsets = [4, %c3, 1] : !mesh.sharding
   return
 }
 
@@ -615,18 +615,16 @@ func.func @update_halo(
     // CHECK-SAME: %[[ARG:.*]]: memref<12x12xi8>
     %arg0 : memref<12x12xi8>) {
   // CHECK-NEXT: %[[C2:.*]] = arith.constant 2 : i64
-  // CHECK-NEXT: mesh.update_halo %[[ARG]] on @mesh0
+  // CHECK-NEXT: %[[UH1:.*]] = mesh.update_halo %[[ARG]] into %[[ARG]] on @mesh0
   // CHECK-SAME: split_axes = {{\[\[}}0]]
-  // CHECK-SAME: halo_sizes = [2, %c2_i64] : memref<12x12xi8>
+  // CHECK-SAME: halo_sizes = [2, %c2_i64] : memref<12x12xi8> -> memref<12x12xi8>
   %c2 = arith.constant 2 : i64
-  mesh.update_halo %arg0 on @mesh0 split_axes = [[0]]
-    halo_sizes = [2, %c2] : memref<12x12xi8>
-  // CHECK-NEXT: mesh.update_halo %[[ARG]] on @mesh0
+  %uh1 = mesh.update_halo %arg0 into %arg0 on @mesh0 split_axes = [[0]]
+    source_halo_sizes = [2, %c2] : memref<12x12xi8> -> memref<12x12xi8>
+  // CHECK-NEXT: %[[UH2:.*]] = mesh.update_halo %[[ARG]] into %[[UH1]] on @mesh0
   // CHECK-SAME: split_axes = {{\[\[}}0], [1]]
-  // CHECK-SAME: halo_sizes = [2, 2, %[[C2]], 2]
-  // CHECK-SAME: target_halo_sizes = [3, 3, 2, 2] : memref<12x12xi8>
-  mesh.update_halo %arg0 on @mesh0 split_axes = [[0], [1]]
-    halo_sizes = [2, 2, %c2, 2] target_halo_sizes = [3, 3, 2, 2]
-    : memref<12x12xi8>
+  // CHECK-SAME: halo_sizes = [2, 2, %[[C2]], 2] : memref<12x12xi8> -> memref<12x12xi8>
+  %uh2 = mesh.update_halo %arg0 into %uh1 on @mesh0 split_axes = [[0], [1]]
+    source_halo_sizes = [2, 2, %c2, 2] : memref<12x12xi8> -> memref<12x12xi8>
   return
 }
diff --git a/mlir/test/Dialect/Tensor/mesh-spmdization.mlir b/mlir/test/Dialect/Tensor/mesh-spmdization.mlir
index 611acb5b41445b..fba9e17b53934a 100644
--- a/mlir/test/Dialect/Tensor/mesh-spmdization.mlir
+++ b/mlir/test/Dialect/Tensor/mesh-spmdization.mlir
@@ -4,12 +4,12 @@
 
 mesh.mesh @mesh_1d_4(shape = 4)
 
-// CHECK-LABEL: func @tensor_empty_static_sharded_dims_sizes
-func.func @tensor_empty_static_sharded_dims_sizes() -> () {
+// CHECK-LABEL: func @tensor_empty_static_sharded_dims_offsets
+func.func @tensor_empty_static_sharded_dims_offsets() -> () {
   %b = tensor.empty() : tensor<8x16xf32>
-  %sharding = mesh.sharding @mesh_1d_4 split_axes = [[0]] sharded_dims_sizes = [1, 3, 3, 1] : !mesh.sharding
+  %sharding = mesh.sharding @mesh_1d_4 split_axes = [[0]] sharded_dims_offsets = [1, 4, 7, 8] : !mesh.sharding
   %sharded= mesh.shard %b to %sharding : tensor<8x16xf32>
-  // CHECK:  %[[sharding:.*]] = mesh.sharding @mesh_1d_4 split_axes = {{\[\[}}0]] sharded_dims_sizes = [1, 3, 3, 1] : !mesh.sharding
+  // CHECK:  %[[sharding:.*]] = mesh.sharding @mesh_1d_4 split_axes = {{\[\[}}0]] sharded_dims_offsets = [1, 4, 7, 8] : !mesh.sharding
   // CHECK:  %[[proc_linear_idx:.*]] = mesh.process_linear_index on @mesh_1d_4 : index
   // CHECK:  %[[V0:.*]]:2 = mesh.shard_shape 8x16 %[[sharding]] %[[proc_linear_idx]] : index, index
   // CHECK:  tensor.empty(%[[V0]]#0) : tensor<?x16xf32>
@@ -17,13 +17,13 @@ func.func @tensor_empty_static_sharded_dims_sizes() -> () {
   return
 }
 
-// CHECK-LABEL: func @tensor_empty_dynamic_sharded_dims_sizes
+// CHECK-LABEL: func @tensor_empty_dynamic_sharded_dims_offsets
 // CHECK-SAME: %[[A0:.*]]: index
-func.func @tensor_empty_dynamic_sharded_dims_sizes(%arg0 : index) -> () {
+func.func @tensor_empty_dynamic_sharded_dims_offsets(%arg0 : index) -> () {
   %b = tensor.empty(%arg0) : tensor<8x?xf32>
-  %sharding = mesh.sharding @mesh_1d_4 split_axes = [[0]] sharded_dims_sizes = [1, 3, 3, 1] : !mesh.sharding
+  %sharding = mesh.sharding @mesh_1d_4 split_axes = [[0]] sharded_dims_offsets = [1, 4, 7, 8] : !mesh.sharding
   %sharded= mesh.shard %b to %sharding : tensor<8x?xf32>
-  // CHECK:  %[[sharding:.*]] = mesh.sharding @mesh_1d_4 split_axes = {{\[\[}}0]] sharded_dims_sizes = [1, 3, 3, 1] : !mesh.sharding
+  // CHECK:  %[[sharding:.*]] = mesh.sharding @mesh_1d_4 split_axes = {{\[\[}}0]] sharded_dims_offsets = [1, 4, 7, 8] : !mesh.sharding
   // CHECK:  %[[proc_linear_idx:.*]] = mesh.process_linear_index on @mesh_1d_4 : index
   // CHECK:  %[[V0:.*]]:2 = mesh.shard_shape 8x? %[[sharding]] %[[proc_linear_idx]] : index, index
   // CHECK:  tensor.empty(%[[V0]]#0, %[[A0]]) : tensor<?x?xf32>
@@ -33,9 +33,9 @@ func.func @tensor_empty_dynamic_sharded_dims_sizes(%arg0 : index) -> () {
 
 // CHECK-LABEL: func @tensor_empty_same_static_dims_sizes
 func.func @tensor_empty_same_static_dims_sizes() -> () {
-  %b = tensor.empty() : tensor<8x16xf32>
-  %sharding = mesh.sharding @mesh_1d_4 split_axes = [[0]] sharded_dims_sizes = [4, 4, 4, 4] : !mesh.sharding
-  %sharded= mesh.shard %b to %sharding : tensor<8x16xf32>
+  %b = tensor.empty() : tensor<16x16xf32>
+  %sharding = mesh.sharding @mesh_1d_4 split_axes = [[0]] sharded_dims_offsets = [4, 8, 12, 16] : !mesh.sharding
+  %sharded= mesh.shard %b to %sharding : tensor<16x16xf32>
   // CHECK-NEXT:  tensor.empty() : tensor<4x16xf32>
 
   return

>From aa26d2be192f5c7f0823e6c27f7697358db40c60 Mon Sep 17 00:00:00 2001
From: "Schlimbach, Frank" <frank.schlimbach at intel.com>
Date: Thu, 17 Oct 2024 11:57:00 +0200
Subject: [PATCH 4/9] canonicalize (debug) ProcessMultiIndexOp

---
 mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td  |  4 +-
 mlir/lib/Dialect/Mesh/IR/MeshOps.cpp          | 92 +++++++++++++++++++
 .../Dialect/Mesh/Transforms/Spmdization.cpp   | 25 +++--
 mlir/test/Dialect/Mesh/canonicalization.mlir  | 17 ++++
 4 files changed, 129 insertions(+), 9 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td
index f881f22adcf52a..c967bb74d5c1a8 100644
--- a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td
+++ b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td
@@ -132,6 +132,7 @@ def Mesh_ProcessMultiIndexOp : Mesh_Op<"process_multi_index", [
     OpBuilder<(ins "::mlir::mesh::MeshOp":$mesh)>,
     OpBuilder<(ins "StringRef":$mesh, "ArrayRef<MeshAxis>":$axes)>
   ];
+  let hasCanonicalizer = 1;
 }
 
 def Mesh_ProcessLinearIndexOp : Mesh_Op<"process_linear_index", [
@@ -198,7 +199,7 @@ def Mesh_ShardingOp : Mesh_Op<"sharding", [
     `?` indicates dynamic halo sizes.
     
     6. [Optional] Offsets for each shard and sharded tensor dimension.
-    `sharded_dims_offsets`is provided as a flattened 1d array of i64s:
+    `sharded_dims_offsets` is provided as a flattened 1d array of i64s:
     For each sharded tensor dimension the offsets (starting index) of all shards in that dimension.
     The offset of each first shard is omitted and is implicitly assumed to be 0.
     The last value per dimension denotes the end of the last shard.
@@ -285,6 +286,7 @@ def Mesh_ShardingOp : Mesh_Op<"sharding", [
     OpBuilder<(ins "mlir::mesh::MeshSharding":$from)>
   ];
   let hasVerifier = 1;
+  let hasCanonicalizer = 1;
 }
 
 def Mesh_ShardShapeOp : Mesh_Op<"shard_shape", [Pure]> {
diff --git a/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp b/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
index b10f8ac98a27d1..fb952bff077a65 100644
--- a/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
+++ b/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
@@ -547,6 +547,42 @@ LogicalResult ShardingOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
   return success();
 }
 
+namespace {
+class FoldDynamicLists final : public OpRewritePattern<ShardingOp> {
+public:
+  using OpRewritePattern<ShardingOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(ShardingOp op,
+                                PatternRewriter &b) const override {
+    auto mixedHalos =
+        getMixedValues(op.getStaticHaloSizes(), op.getDynamicHaloSizes(), b);
+    auto mixedOffs = getMixedValues(op.getStaticShardedDimsOffsets(),
+                                    op.getDynamicShardedDimsOffsets(), b);
+
+    // No constant operands were folded, just return.
+    if (failed(foldDynamicIndexList(mixedHalos, /*onlyNonNegative=*/true)) &&
+        failed(foldDynamicIndexList(mixedOffs, /*onlyNonNegative=*/true))) {
+      return failure();
+    }
+
+    auto halos = decomposeMixedValues(mixedHalos);
+    auto offs = decomposeMixedValues(mixedOffs);
+
+    op.setStaticHaloSizes(halos.first);
+    op.getDynamicHaloSizesMutable().assign(halos.second);
+    op.setStaticShardedDimsOffsets(offs.first);
+    op.getDynamicShardedDimsOffsetsMutable().assign(offs.second);
+
+    return success();
+  }
+};
+} // namespace
+
+void ShardingOp::getCanonicalizationPatterns(mlir::RewritePatternSet &results,
+                                             mlir::MLIRContext *context) {
+  results.add<FoldDynamicLists>(context);
+}
+
 //===----------------------------------------------------------------------===//
 // MeshSharding
 //===----------------------------------------------------------------------===//
@@ -740,6 +776,62 @@ void ProcessMultiIndexOp::getAsmResultNames(
   setNameFn(getResults()[0], "proc_linear_idx");
 }
 
+namespace {
+#ifndef NDEBUG
+static std::vector<int> convertStringToVector(const std::string &str) {
+  std::vector<int> result;
+  std::stringstream ss(str);
+  std::string item;
+  while (std::getline(ss, item, ',')) {
+    result.push_back(std::stoi(item));
+  }
+  return result;
+}
+#endif // NDEBUG
+
+std::optional<SmallVector<Value>> getMyMultiIndex(OpBuilder &b,
+                                                  ::mlir::mesh::MeshOp mesh) {
+#ifndef NDEBUG
+  if (auto envStr = getenv("DEBUG_MESH_INDEX")) {
+    auto myIdx = convertStringToVector(envStr);
+    if (myIdx.size() == mesh.getShape().size()) {
+      SmallVector<Value> idxs;
+      for (auto i : myIdx) {
+        idxs.push_back(b.create<::mlir::arith::ConstantOp>(mesh->getLoc(),
+                                                           b.getIndexAttr(i)));
+      }
+      return idxs;
+    } else {
+      mesh->emitError() << "DEBUG_MESH_INDEX has wrong size";
+    }
+  }
+#endif // NDEBUG
+  return std::nullopt;
+}
+
+class FoldStaticIndex final : public OpRewritePattern<ProcessMultiIndexOp> {
+public:
+  using OpRewritePattern<ProcessMultiIndexOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(ProcessMultiIndexOp op,
+                                PatternRewriter &b) const override {
+#ifndef NDEBUG
+    SymbolTableCollection tmp;
+    if (auto idxs = getMyMultiIndex(b, getMesh(op, tmp))) {
+      b.replaceOp(op, idxs.value());
+      return success();
+    }
+#endif // NDEBUG
+    return failure();
+  }
+};
+} // namespace
+
+void ProcessMultiIndexOp::getCanonicalizationPatterns(
+    mlir::RewritePatternSet &results, mlir::MLIRContext *context) {
+  results.add<FoldStaticIndex>(context);
+}
+
 //===----------------------------------------------------------------------===//
 // mesh.process_linear_index op
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp b/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp
index 6684fbfa0aec0b..c1eda4cb257238 100644
--- a/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp
+++ b/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp
@@ -443,22 +443,27 @@ tryUpdateHaloInResharding(ImplicitLocOpBuilder &builder, MeshOp mesh,
       !sourceSharding.equalHaloSizes(targetSharding)) {
     auto srcHaloSizes = sourceSharding.getStaticHaloSizes();
     auto tgtHaloSizes = targetSharding.getStaticHaloSizes();
+    assert(srcHaloSizes.empty() || srcHaloSizes.size() == tgtHaloSizes.size());
     assert(
         ((srcHaloSizes.empty() || !ShapedType::isDynamicShape(srcHaloSizes)) &&
          !ShapedType::isDynamicShape(tgtHaloSizes) &&
          sourceShard.getType().hasStaticShape()) &&
         "dynamic shapes/halos are not supported yet for mesh-spmdization");
     auto rank = sourceShard.getType().getRank();
-    SmallVector<int64_t> outShape, srcCoreOffs(rank, 0), tgtCoreOffs,
-        strides(rank, 1), coreShape(sourceShard.getType().getShape());
+    auto splitAxes = sourceSharding.getSplitAxes();
+    SmallVector<int64_t> srcCoreOffs(rank, 0), tgtCoreOffs(rank, 0),
+        strides(rank, 1), outShape(sourceShard.getType().getShape()),
+        coreShape(sourceShard.getType().getShape());
     for (auto i = 0u; i < rank; ++i) {
-      if (!srcHaloSizes.empty()) {
-        coreShape[i] -= srcHaloSizes[i * 2] + srcHaloSizes[i * 2 + 1];
-        srcCoreOffs[i] = srcHaloSizes[i * 2];
+      if (i < splitAxes.size() && !splitAxes[i].empty()) {
+        if (!srcHaloSizes.empty()) {
+          coreShape[i] -= srcHaloSizes[i * 2] + srcHaloSizes[i * 2 + 1];
+          srcCoreOffs[i] = srcHaloSizes[i * 2];
+        }
+        tgtCoreOffs[i] = tgtHaloSizes[i * 2];
+        outShape[i] =
+            coreShape[i] + tgtHaloSizes[i * 2] + tgtHaloSizes[i * 2 + 1];
       }
-      tgtCoreOffs.emplace_back(tgtHaloSizes[i * 2]);
-      outShape.emplace_back(coreShape[i] + tgtHaloSizes[i * 2] +
-                            tgtHaloSizes[i * 2 + 1]);
     }
     auto noVals = ValueRange{};
     auto initVal = builder.create<tensor::EmptyOp>(
@@ -539,6 +544,10 @@ TypedValue<ShapedType> reshard(ImplicitLocOpBuilder &builder, MeshOp mesh,
                                MeshSharding targetSharding,
                                TypedValue<ShapedType> sourceUnshardedValue,
                                TypedValue<ShapedType> sourceShard) {
+  if (sourceSharding == targetSharding) {
+    return sourceShard;
+  }
+
   if (auto tryRes = tryUpdateHaloInResharding(
           builder, mesh, sourceSharding, targetSharding,
           sourceUnshardedValue.getType(), sourceShard)) {
diff --git a/mlir/test/Dialect/Mesh/canonicalization.mlir b/mlir/test/Dialect/Mesh/canonicalization.mlir
index ea2bd29056ec78..0bd09056835d30 100644
--- a/mlir/test/Dialect/Mesh/canonicalization.mlir
+++ b/mlir/test/Dialect/Mesh/canonicalization.mlir
@@ -191,3 +191,20 @@ func.func @send_empty_mesh_axes(
 // CHECK: return %[[ARG]]
   return %0 : tensor<4xf32>
 }
+
+mesh.mesh @mesh4x4(shape = 4x4)
+// CHECK-LABEL: func @test_halo_sizes
+func.func @test_halo_sizes() -> !mesh.sharding {
+  %c2_i64 = arith.constant 2 : i64
+  // CHECK: mesh.sharding @mesh4x4 split_axes = {{\[\[}}0], [1]] halo_sizes = [1, 2, 2, 22] : !mesh.sharding
+  %sharding = mesh.sharding @mesh4x4 split_axes = [[0], [1]] halo_sizes = [1, %c2_i64, %c2_i64, 22] : !mesh.sharding
+  return %sharding : !mesh.sharding
+}
+
+// CHECK-LABEL: func @test_shard_offs
+func.func @test_shard_offs() -> !mesh.sharding {
+  %c2_i64 = arith.constant 2 : i64
+  // CHECK: mesh.sharding @mesh4x4 split_axes = {{\[\[}}0], [1]] sharded_dims_offsets = [1, 2, 2, 22] : !mesh.sharding
+  %sharding = mesh.sharding @mesh4x4 split_axes = [[0], [1]] sharded_dims_offsets = [1, %c2_i64, %c2_i64, 22] : !mesh.sharding
+  return %sharding : !mesh.sharding
+}
\ No newline at end of file

>From 5e4cadaf26ecdba44b2d59f561fca982ac80274a Mon Sep 17 00:00:00 2001
From: "Schlimbach, Frank" <frank.schlimbach at intel.com>
Date: Thu, 17 Oct 2024 14:00:22 +0200
Subject: [PATCH 5/9] removing ProcessMultiIndexOp canonicalizer, fixing
 MeshSharding.equalSplitAndPartialAxes

---
 mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td  | 15 +++--
 mlir/lib/Dialect/Mesh/IR/MeshOps.cpp          | 63 ++-----------------
 .../Dialect/Mesh/Transforms/Spmdization.cpp   | 11 ++--
 3 files changed, 21 insertions(+), 68 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td
index c967bb74d5c1a8..32af92b99bcd71 100644
--- a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td
+++ b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td
@@ -132,7 +132,6 @@ def Mesh_ProcessMultiIndexOp : Mesh_Op<"process_multi_index", [
     OpBuilder<(ins "::mlir::mesh::MeshOp":$mesh)>,
     OpBuilder<(ins "StringRef":$mesh, "ArrayRef<MeshAxis>":$axes)>
   ];
-  let hasCanonicalizer = 1;
 }
 
 def Mesh_ProcessLinearIndexOp : Mesh_Op<"process_linear_index", [
@@ -1061,7 +1060,8 @@ def Mesh_UpdateHaloOp : Mesh_Op<"update_halo", [
   TypesMatchWith<
     "result has same type as destination",
     "result", "destination", "$_self">,
-  DeclareOpInterfaceMethods<SymbolUserOpInterface>
+  DeclareOpInterfaceMethods<SymbolUserOpInterface>,
+  AttrSizedOperandSegments
 ]> {
   let summary = "Update halo data.";
   let description = [{
@@ -1071,13 +1071,13 @@ def Mesh_UpdateHaloOp : Mesh_Op<"update_halo", [
     and/or if the new halo regions are larger than the existing ones.
 
     Assumes all devices hold tensors with same-sized halo data as specified
-    by `source_halo_sizes/static_source_halo_sizes`.
+    by `source_halo_sizes/static_source_halo_sizes` and
+    `destination_halo_sizes/static_destination_halo_sizes`.
 
     `split_axes` specifies for each tensor axis along which mesh axes its halo
     data is updated.
 
-    The destination halo sizes are allowed differ from the source sizes. The sizes
-    of the inner (local) shard is inferred from the destination shape and source sharding.
+    Source and destination might have different halo sizes.
   }];
   let arguments = (ins
     AnyTypeOf<[AnyNon0RankedMemRef, AnyNon0RankedTensor]>:$source,
@@ -1085,7 +1085,9 @@ def Mesh_UpdateHaloOp : Mesh_Op<"update_halo", [
     FlatSymbolRefAttr:$mesh,
     Mesh_MeshAxesArrayAttr:$split_axes,
     Variadic<I64>:$source_halo_sizes,
-    DefaultValuedAttr<DenseI64ArrayAttr, "{}">:$static_source_halo_sizes
+    DefaultValuedAttr<DenseI64ArrayAttr, "{}">:$static_source_halo_sizes,
+    Variadic<I64>:$destination_halo_sizes,
+    DefaultValuedAttr<DenseI64ArrayAttr, "{}">:$static_destination_halo_sizes
   );
   let results = (outs
     AnyTypeOf<[AnyNon0RankedMemRef, AnyNon0RankedTensor]>:$result
@@ -1095,6 +1097,7 @@ def Mesh_UpdateHaloOp : Mesh_Op<"update_halo", [
     `on` $mesh
     `split_axes` `=` $split_axes
     (`source_halo_sizes` `=` custom<DynamicIndexList>($source_halo_sizes, $static_source_halo_sizes)^)?
+    (`destination_halo_sizes` `=` custom<DynamicIndexList>($destination_halo_sizes, $static_destination_halo_sizes)^)?
     attr-dict `:` type($source) `->` type($result)
   }];
   let extraClassDeclaration = [{
diff --git a/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp b/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
index fb952bff077a65..d65f7e4bbadd1a 100644
--- a/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
+++ b/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
@@ -592,7 +592,12 @@ bool MeshSharding::equalSplitAndPartialAxes(const MeshSharding &rhs) const {
     return false;
   }
 
-  if (!getPartialAxes().empty() && getPartialType() != rhs.getPartialType()) {
+  if (getPartialAxes().size() != rhs.getPartialAxes().size() ||
+      (!getPartialAxes().empty() && getPartialType() != rhs.getPartialType()) ||
+      !llvm::equal(
+          llvm::make_range(getPartialAxes().begin(), getPartialAxes().end()),
+          llvm::make_range(rhs.getPartialAxes().begin(),
+                           rhs.getPartialAxes().end()))) {
     return false;
   }
 
@@ -776,62 +781,6 @@ void ProcessMultiIndexOp::getAsmResultNames(
   setNameFn(getResults()[0], "proc_linear_idx");
 }
 
-namespace {
-#ifndef NDEBUG
-static std::vector<int> convertStringToVector(const std::string &str) {
-  std::vector<int> result;
-  std::stringstream ss(str);
-  std::string item;
-  while (std::getline(ss, item, ',')) {
-    result.push_back(std::stoi(item));
-  }
-  return result;
-}
-#endif // NDEBUG
-
-std::optional<SmallVector<Value>> getMyMultiIndex(OpBuilder &b,
-                                                  ::mlir::mesh::MeshOp mesh) {
-#ifndef NDEBUG
-  if (auto envStr = getenv("DEBUG_MESH_INDEX")) {
-    auto myIdx = convertStringToVector(envStr);
-    if (myIdx.size() == mesh.getShape().size()) {
-      SmallVector<Value> idxs;
-      for (auto i : myIdx) {
-        idxs.push_back(b.create<::mlir::arith::ConstantOp>(mesh->getLoc(),
-                                                           b.getIndexAttr(i)));
-      }
-      return idxs;
-    } else {
-      mesh->emitError() << "DEBUG_MESH_INDEX has wrong size";
-    }
-  }
-#endif // NDEBUG
-  return std::nullopt;
-}
-
-class FoldStaticIndex final : public OpRewritePattern<ProcessMultiIndexOp> {
-public:
-  using OpRewritePattern<ProcessMultiIndexOp>::OpRewritePattern;
-
-  LogicalResult matchAndRewrite(ProcessMultiIndexOp op,
-                                PatternRewriter &b) const override {
-#ifndef NDEBUG
-    SymbolTableCollection tmp;
-    if (auto idxs = getMyMultiIndex(b, getMesh(op, tmp))) {
-      b.replaceOp(op, idxs.value());
-      return success();
-    }
-#endif // NDEBUG
-    return failure();
-  }
-};
-} // namespace
-
-void ProcessMultiIndexOp::getCanonicalizationPatterns(
-    mlir::RewritePatternSet &results, mlir::MLIRContext *context) {
-  results.add<FoldStaticIndex>(context);
-}
-
 //===----------------------------------------------------------------------===//
 // mesh.process_linear_index op
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp b/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp
index c1eda4cb257238..6d7d036b2ea2e4 100644
--- a/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp
+++ b/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp
@@ -483,7 +483,9 @@ tryUpdateHaloInResharding(ImplicitLocOpBuilder &builder, MeshOp mesh,
         MeshAxesArrayAttr::get(builder.getContext(),
                                sourceSharding.getSplitAxes()),
         sourceSharding.getDynamicHaloSizes(),
-        sourceSharding.getStaticHaloSizes());
+        sourceSharding.getStaticHaloSizes(),
+        targetSharding.getDynamicHaloSizes(),
+        targetSharding.getStaticHaloSizes());
     return std::make_tuple(
         cast<TypedValue<ShapedType>>(targetShard.getResult()), targetSharding);
   }
@@ -568,10 +570,9 @@ TypedValue<ShapedType> reshard(OpBuilder &builder, MeshOp mesh, ShardOp source,
   auto sourceSharding = source.getSharding();
   auto targetSharding = target.getSharding();
   ImplicitLocOpBuilder implicitLocOpBuilder(target->getLoc(), builder);
-  auto shard =
-      reshard(implicitLocOpBuilder, mesh, sourceSharding, targetSharding,
-              cast<TypedValue<ShapedType>>(source.getSrc()), sourceShardValue);
-  return shard;
+  return reshard(implicitLocOpBuilder, mesh, sourceSharding, targetSharding,
+                 cast<TypedValue<ShapedType>>(source.getSrc()),
+                 sourceShardValue);
 }
 
 TypedValue<ShapedType> reshard(OpBuilder &builder, ShardOp source,

>From 1cce571499b1ee67b28be7aef9c10fce2d1b0596 Mon Sep 17 00:00:00 2001
From: "Schlimbach, Frank" <frank.schlimbach at intel.com>
Date: Fri, 18 Oct 2024 17:40:38 +0200
Subject: [PATCH 6/9] fixing resharding

---
 mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp |  8 ++++----
 mlir/test/Dialect/Mesh/spmdization.mlir          | 16 ++++++++++++++++
 2 files changed, 20 insertions(+), 4 deletions(-)

diff --git a/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp b/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp
index 6d7d036b2ea2e4..89ba282731c44a 100644
--- a/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp
+++ b/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp
@@ -476,7 +476,7 @@ tryUpdateHaloInResharding(ImplicitLocOpBuilder &builder, MeshOp mesh,
     auto initOprnd = builder.create<tensor::InsertSliceOp>(
         sourceShard.getLoc(), core, initVal, noVals, noVals, noVals,
         tgtCoreOffs, coreShape, strides);
-    auto targetShard = builder.create<UpdateHaloOp>(
+    auto updateHaloResult = builder.create<UpdateHaloOp>(
         sourceShard.getLoc(),
         RankedTensorType::get(outShape, sourceShard.getType().getElementType()),
         sourceShard, initOprnd, mesh.getSymName(),
@@ -485,9 +485,9 @@ tryUpdateHaloInResharding(ImplicitLocOpBuilder &builder, MeshOp mesh,
         sourceSharding.getDynamicHaloSizes(),
         sourceSharding.getStaticHaloSizes(),
         targetSharding.getDynamicHaloSizes(),
-        targetSharding.getStaticHaloSizes());
+        targetSharding.getStaticHaloSizes()).getResult();
     return std::make_tuple(
-        cast<TypedValue<ShapedType>>(targetShard.getResult()), targetSharding);
+        cast<TypedValue<ShapedType>>(updateHaloResult), targetSharding);
   }
   return std::nullopt;
 }
@@ -710,7 +710,7 @@ spmdizeOperation(ShardOp shardOp, IRMapping &spmdizationMap,
   } else {
     // Insert resharding.
     TypedValue<ShapedType> srcSpmdValue = cast<TypedValue<ShapedType>>(
-        spmdizationMap.lookup(srcShardOp.getSrc()));
+        spmdizationMap.lookup(srcShardOp));
     targetSpmdValue = reshard(builder, srcShardOp, shardOp, srcSpmdValue,
                               symbolTableCollection);
   }
diff --git a/mlir/test/Dialect/Mesh/spmdization.mlir b/mlir/test/Dialect/Mesh/spmdization.mlir
index 8b0c4053b0dc7e..16224441ffbb7d 100644
--- a/mlir/test/Dialect/Mesh/spmdization.mlir
+++ b/mlir/test/Dialect/Mesh/spmdization.mlir
@@ -219,3 +219,19 @@ func.func @ew_chain_with_halo(
   // CHECK-NEXT: return %[[TMP3]] : tensor<5x16xf32>
   return %sharding_annotated_6 : tensor<8x16xf32>
 }
+
+// CHECK-LABEL: func @test_shard_update_halo
+// CHECK-SAME: %[[IN1:[A-Za-z0-9_]+]]: tensor<300x1200xi64>
+func.func @test_shard_update_halo(%arg0: tensor<1200x1200xi64>) -> tensor<1200x1200xi64> {
+  %sharding = mesh.sharding @mesh_1d_4 split_axes = [[0]] : !mesh.sharding
+  // CHECK: %[[T:.*]] = tensor.empty() : tensor<304x1200xi64>
+  // CHECK: %[[inserted_slice:.*]] = tensor.insert_slice %[[IN1]] into %[[T]][2, 0] [300, 1200] [1, 1] : tensor<300x1200xi64> into tensor<304x1200xi64>
+  // CHECK: %[[UH:.*]] = mesh.update_halo %[[IN1]] into %[[inserted_slice]] on @mesh_1d_4 split_axes = {{\[\[0]]}} destination_halo_sizes = [2, 2] : tensor<300x1200xi64> -> tensor<304x1200xi64>
+  %sharding_annotated = mesh.shard %arg0 to %sharding : tensor<1200x1200xi64>
+  %sharding_0 = mesh.sharding @mesh_1d_4 split_axes = [[0]] halo_sizes = [2, 2] : !mesh.sharding
+  %sharding_annotated_1 = mesh.shard %sharding_annotated to %sharding_0 : tensor<1200x1200xi64>
+  %sharding_2 = mesh.sharding @mesh_1d_4 split_axes = [[0]] sharded_dims_offsets = [298, 598, 898, 1000] : !mesh.sharding
+  %sharding_annotated_3 = mesh.shard %sharding_annotated_1 to %sharding_0 annotate_for_users : tensor<1200x1200xi64>
+  // CHECK: return %[[UH]] : tensor<304x1200xi64>
+  return %sharding_annotated_3 : tensor<1200x1200xi64>
+}

>From a0aa3ebccd8bbc7f12d242995fed39ab242906dd Mon Sep 17 00:00:00 2001
From: "Schlimbach, Frank" <frank.schlimbach at intel.com>
Date: Wed, 30 Oct 2024 13:03:34 +0100
Subject: [PATCH 7/9] adding test update_halo 2d

---
 mlir/test/Dialect/Mesh/spmdization.mlir | 17 ++++++++++++++++-
 1 file changed, 16 insertions(+), 1 deletion(-)

diff --git a/mlir/test/Dialect/Mesh/spmdization.mlir b/mlir/test/Dialect/Mesh/spmdization.mlir
index 16224441ffbb7d..22ddb72569835d 100644
--- a/mlir/test/Dialect/Mesh/spmdization.mlir
+++ b/mlir/test/Dialect/Mesh/spmdization.mlir
@@ -230,8 +230,23 @@ func.func @test_shard_update_halo(%arg0: tensor<1200x1200xi64>) -> tensor<1200x1
   %sharding_annotated = mesh.shard %arg0 to %sharding : tensor<1200x1200xi64>
   %sharding_0 = mesh.sharding @mesh_1d_4 split_axes = [[0]] halo_sizes = [2, 2] : !mesh.sharding
   %sharding_annotated_1 = mesh.shard %sharding_annotated to %sharding_0 : tensor<1200x1200xi64>
-  %sharding_2 = mesh.sharding @mesh_1d_4 split_axes = [[0]] sharded_dims_offsets = [298, 598, 898, 1000] : !mesh.sharding
   %sharding_annotated_3 = mesh.shard %sharding_annotated_1 to %sharding_0 annotate_for_users : tensor<1200x1200xi64>
   // CHECK: return %[[UH]] : tensor<304x1200xi64>
   return %sharding_annotated_3 : tensor<1200x1200xi64>
 }
+
+mesh.mesh @mesh4x4(shape = 4x4)
+// CHECK-LABEL: func @test_shard_update_halo2d
+// CHECK-SAME: %[[IN1:[A-Za-z0-9_]+]]: tensor<300x300xi64>
+func.func @test_shard_update_halo2d(%arg0: tensor<1200x1200xi64>) -> tensor<1200x1200xi64> {
+  %sharding = mesh.sharding @mesh4x4 split_axes = [[0], [1]] : !mesh.sharding
+  // CHECK: %[[T:.*]] = tensor.empty() : tensor<303x307xi64>
+  // CHECK: %[[inserted_slice:.*]] = tensor.insert_slice %[[IN1]] into %[[T]][1, 3] [300, 300] [1, 1] : tensor<300x300xi64> into tensor<303x307xi64>
+  // CHECK: %[[UH:.*]] = mesh.update_halo %[[IN1]] into %[[inserted_slice]] on @mesh4x4 split_axes = {{\[\[}}0], [1]] destination_halo_sizes = [1, 2, 3, 4] : tensor<300x300xi64> -> tensor<303x307xi64>
+  %sharding_annotated = mesh.shard %arg0 to %sharding : tensor<1200x1200xi64>
+  %sharding_0 = mesh.sharding @mesh4x4 split_axes = [[0], [1]] halo_sizes = [1, 2, 3, 4] : !mesh.sharding
+  %sharding_annotated_1 = mesh.shard %sharding_annotated to %sharding_0 : tensor<1200x1200xi64>
+  %sharding_annotated_3 = mesh.shard %sharding_annotated_1 to %sharding_0 annotate_for_users : tensor<1200x1200xi64>
+  // CHECK: return %[[UH]] : tensor<303x307xi64>
+  return %sharding_annotated_3 : tensor<1200x1200xi64>
+}
\ No newline at end of file

>From 13c590c515a34ae699baac8c386d6749aace778e Mon Sep 17 00:00:00 2001
From: "Schlimbach, Frank" <frank.schlimbach at intel.com>
Date: Wed, 30 Oct 2024 15:03:19 +0100
Subject: [PATCH 8/9] comments/docs

---
 mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td     |  8 +++++---
 mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp | 16 ++++++++++++++++
 2 files changed, 21 insertions(+), 3 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td
index 32af92b99bcd71..04b4b55a433803 100644
--- a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td
+++ b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td
@@ -1066,18 +1066,20 @@ def Mesh_UpdateHaloOp : Mesh_Op<"update_halo", [
   let summary = "Update halo data.";
   let description = [{
     This operation updates halo regions of shards, e.g. if their sharding
-    specified halos and the actual tensor data might have changed
+    specified halos and the actual tensor/memref data might have changed
     on the remote devices. Changes might be caused by mutating operations
     and/or if the new halo regions are larger than the existing ones.
 
+    Source and destination might have different halo sizes.
+
     Assumes all devices hold tensors with same-sized halo data as specified
     by `source_halo_sizes/static_source_halo_sizes` and
-    `destination_halo_sizes/static_destination_halo_sizes`
+    `destination_halo_sizes/static_destination_halo_sizes` in source shard
+    and destination/result shard.
 
     `split_axes` specifies for each tensor axis along which mesh axes its halo
     data is updated.
 
-    Source and destination might have different halo sizes.
   }];
   let arguments = (ins
     AnyTypeOf<[AnyNon0RankedMemRef, AnyNon0RankedTensor]>:$source,
diff --git a/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp b/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp
index 89ba282731c44a..6a25f18e6afc5e 100644
--- a/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp
+++ b/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp
@@ -429,12 +429,18 @@ tryMoveLastSplitAxisInResharding(ImplicitLocOpBuilder &builder, MeshOp mesh,
   return std::nullopt;
 }
 
+// Detect a change in the halo size (only) and create necessary operations if
+// needed. Changed halo sizes require copying the "core" of the source tensor
+// into the "core" of the destination tensor followed by an update halo
+// operation.
 static std::optional<std::tuple<TypedValue<ShapedType>, MeshSharding>>
 tryUpdateHaloInResharding(ImplicitLocOpBuilder &builder, MeshOp mesh,
                           MeshSharding sourceSharding,
                           MeshSharding targetSharding,
                           ShapedType sourceUnshardedShape,
                           TypedValue<ShapedType> sourceShard) {
+  // currently handles only cases where halo sizes differ but everything else
+  // stays the same (from source to destination sharding)
   if (sourceSharding.equalSplitAndPartialAxes(targetSharding) &&
       sourceSharding.getPartialAxes().empty() &&
       targetSharding.getPartialAxes().empty() &&
@@ -454,6 +460,9 @@ tryUpdateHaloInResharding(ImplicitLocOpBuilder &builder, MeshOp mesh,
     SmallVector<int64_t> srcCoreOffs(rank, 0), tgtCoreOffs(rank, 0),
         strides(rank, 1), outShape(sourceShard.getType().getShape()),
         coreShape(sourceShard.getType().getShape());
+
+    // determine "core" of source and destination
+    // the core is the local part of the shard excluding halo regions
     for (auto i = 0u; i < rank; ++i) {
       if (i < splitAxes.size() && !splitAxes[i].empty()) {
         if (!srcHaloSizes.empty()) {
@@ -465,6 +474,8 @@ tryUpdateHaloInResharding(ImplicitLocOpBuilder &builder, MeshOp mesh,
             coreShape[i] + tgtHaloSizes[i * 2] + tgtHaloSizes[i * 2 + 1];
       }
     }
+
+    // extract core from source and copy into destination core
     auto noVals = ValueRange{};
     auto initVal = builder.create<tensor::EmptyOp>(
         sourceShard.getLoc(), outShape, sourceShard.getType().getElementType());
@@ -476,6 +487,8 @@ tryUpdateHaloInResharding(ImplicitLocOpBuilder &builder, MeshOp mesh,
     auto initOprnd = builder.create<tensor::InsertSliceOp>(
         sourceShard.getLoc(), core, initVal, noVals, noVals, noVals,
         tgtCoreOffs, coreShape, strides);
+
+    // finally update the halo
     auto updateHaloResult = builder.create<UpdateHaloOp>(
         sourceShard.getLoc(),
         RankedTensorType::get(outShape, sourceShard.getType().getElementType()),
@@ -546,10 +559,13 @@ TypedValue<ShapedType> reshard(ImplicitLocOpBuilder &builder, MeshOp mesh,
                                MeshSharding targetSharding,
                                TypedValue<ShapedType> sourceUnshardedValue,
                                TypedValue<ShapedType> sourceShard) {
+  // If source and destination sharding are the same, no need to do anything.
   if (sourceSharding == targetSharding) {
     return sourceShard;
   }
 
+  // tries to handle the case where the resharding is needed because the halo
+  // sizes are different. Supports arbitrary mesh dimensionality.
   if (auto tryRes = tryUpdateHaloInResharding(
           builder, mesh, sourceSharding, targetSharding,
           sourceUnshardedValue.getType(), sourceShard)) {

>From 8b209ce10e9495a75f0a0a04c58a939ab687a841 Mon Sep 17 00:00:00 2001
From: "Schlimbach, Frank" <frank.schlimbach at intel.com>
Date: Wed, 30 Oct 2024 15:38:39 +0100
Subject: [PATCH 9/9] clang-format

---
 .../Dialect/Mesh/Transforms/Spmdization.cpp   | 32 +++++++++++--------
 1 file changed, 18 insertions(+), 14 deletions(-)

diff --git a/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp b/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp
index 6a25f18e6afc5e..1444eec0f727ff 100644
--- a/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp
+++ b/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp
@@ -489,18 +489,22 @@ tryUpdateHaloInResharding(ImplicitLocOpBuilder &builder, MeshOp mesh,
         tgtCoreOffs, coreShape, strides);
 
     // finally update the halo
-    auto updateHaloResult = builder.create<UpdateHaloOp>(
-        sourceShard.getLoc(),
-        RankedTensorType::get(outShape, sourceShard.getType().getElementType()),
-        sourceShard, initOprnd, mesh.getSymName(),
-        MeshAxesArrayAttr::get(builder.getContext(),
-                               sourceSharding.getSplitAxes()),
-        sourceSharding.getDynamicHaloSizes(),
-        sourceSharding.getStaticHaloSizes(),
-        targetSharding.getDynamicHaloSizes(),
-        targetSharding.getStaticHaloSizes()).getResult();
-    return std::make_tuple(
-        cast<TypedValue<ShapedType>>(updateHaloResult), targetSharding);
+    auto updateHaloResult =
+        builder
+            .create<UpdateHaloOp>(
+                sourceShard.getLoc(),
+                RankedTensorType::get(outShape,
+                                      sourceShard.getType().getElementType()),
+                sourceShard, initOprnd, mesh.getSymName(),
+                MeshAxesArrayAttr::get(builder.getContext(),
+                                       sourceSharding.getSplitAxes()),
+                sourceSharding.getDynamicHaloSizes(),
+                sourceSharding.getStaticHaloSizes(),
+                targetSharding.getDynamicHaloSizes(),
+                targetSharding.getStaticHaloSizes())
+            .getResult();
+    return std::make_tuple(cast<TypedValue<ShapedType>>(updateHaloResult),
+                           targetSharding);
   }
   return std::nullopt;
 }
@@ -725,8 +729,8 @@ spmdizeOperation(ShardOp shardOp, IRMapping &spmdizationMap,
     targetSpmdValue = spmdizationMap.lookup(shardOp.getSrc());
   } else {
     // Insert resharding.
-    TypedValue<ShapedType> srcSpmdValue = cast<TypedValue<ShapedType>>(
-        spmdizationMap.lookup(srcShardOp));
+    TypedValue<ShapedType> srcSpmdValue =
+        cast<TypedValue<ShapedType>>(spmdizationMap.lookup(srcShardOp));
     targetSpmdValue = reshard(builder, srcShardOp, shardOp, srcSpmdValue,
                               symbolTableCollection);
   }



More information about the Mlir-commits mailing list