[Mlir-commits] [mlir] [mlir][mesh, MPI] Mesh2mpi (PR #104566)

Frank Schlimbach llvmlistbot at llvm.org
Tue Nov 5 10:51:31 PST 2024


https://github.com/fschlimb updated https://github.com/llvm/llvm-project/pull/104566

>From 0133634490675c26b548f00372e4c8e6d0b375f4 Mon Sep 17 00:00:00 2001
From: "Schlimbach, Frank" <frank.schlimbach at intel.com>
Date: Tue, 3 Sep 2024 09:29:35 +0200
Subject: [PATCH 01/26] cut&paste error

---
 mlir/lib/Dialect/Mesh/IR/MeshOps.cpp | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp b/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
index 19e9212157ae47..e0ecbee7ee11f5 100644
--- a/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
+++ b/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
@@ -583,14 +583,14 @@ bool MeshSharding::equalHaloAndShardSizes(const MeshSharding &rhs) const {
                                     rhs.getStaticHaloSizes().end()))) {
     return false;
   }
-  if (rhs.getStaticShardedDimsSizes().size() != getDynamicHaloSizes().size() ||
+  if (rhs.getStaticShardedDimsSizes().size() != getStaticShardedDimsSizes().size() ||
       !llvm::equal(llvm::make_range(getStaticShardedDimsSizes().begin(),
                                     getStaticShardedDimsSizes().end()),
                    llvm::make_range(rhs.getStaticShardedDimsSizes().begin(),
                                     rhs.getStaticShardedDimsSizes().end()))) {
     return false;
   }
-  if (rhs.getDynamicHaloSizes().size() != getStaticShardedDimsSizes().size() ||
+  if (rhs.getDynamicHaloSizes().size() != getDynamicHaloSizes().size() ||
       !llvm::equal(llvm::make_range(getDynamicHaloSizes().begin(),
                                     getDynamicHaloSizes().end()),
                    llvm::make_range(rhs.getDynamicHaloSizes().begin(),

>From e99596fe154f921003e6acd4b1e3d957c9d05208 Mon Sep 17 00:00:00 2001
From: "Schlimbach, Frank" <frank.schlimbach at intel.com>
Date: Thu, 10 Oct 2024 12:09:44 +0200
Subject: [PATCH 02/26] initial spmdization pattern for halo updates

---
 mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h   |  3 +
 mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td  | 32 +++++---
 mlir/lib/Dialect/Mesh/IR/MeshOps.cpp          | 35 +++++----
 .../Dialect/Mesh/Transforms/Spmdization.cpp   | 77 +++++++++++++++++--
 4 files changed, 116 insertions(+), 31 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h
index db7b64fda57d7b..d1cab4d819a0e0 100644
--- a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h
+++ b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h
@@ -15,6 +15,7 @@
 #include "mlir/IR/OpDefinition.h"
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/IR/SymbolTable.h"
+#include "mlir/Interfaces/DestinationStyleOpInterface.h"
 #include "mlir/Interfaces/InferTypeOpInterface.h"
 #include "mlir/Interfaces/SideEffectInterfaces.h"
 #include "llvm/Support/MathExtras.h"
@@ -80,6 +81,8 @@ class MeshSharding {
   bool operator!=(const MeshSharding &rhs) const;
   bool equalSplitAndPartialAxes(const MeshSharding &rhs) const;
   bool equalHaloAndShardSizes(const MeshSharding &rhs) const;
+  bool equalHaloSizes(const MeshSharding &rhs) const;
+  bool equalShardSizes(const MeshSharding &rhs) const;
 };
 
 } // namespace mesh
diff --git a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td
index 8f696bbc1a0f6e..0c5606a94a6dc7 100644
--- a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td
+++ b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td
@@ -11,6 +11,7 @@
 
 include "mlir/Dialect/Mesh/IR/MeshBase.td"
 include "mlir/Dialect/Shape/IR/ShapeBase.td"
+include "mlir/Interfaces/DestinationStyleOpInterface.td"
 include "mlir/Interfaces/InferTypeOpInterface.td"
 include "mlir/Interfaces/SideEffectInterfaces.td"
 include "mlir/IR/BuiltinTypes.td"
@@ -1052,6 +1053,10 @@ def Mesh_ShiftOp : Mesh_CollectiveCommunicationOpBase<"shift", [
 }
 
 def Mesh_UpdateHaloOp : Mesh_Op<"update_halo", [
+  DestinationStyleOpInterface,
+  TypesMatchWith<
+    "result has same type as destination",
+    "result", "destination", "$_self">,
   DeclareOpInterfaceMethods<SymbolUserOpInterface>
 ]> {
   let summary = "Update halo data.";
@@ -1062,27 +1067,34 @@ def Mesh_UpdateHaloOp : Mesh_Op<"update_halo", [
     and/or if the new halo regions are larger than the existing ones.
 
     Assumes all devices hold tensors with same-sized halo data as specified
-    by `dynamic/static_halo_sizes`.
+    by `source_halo_sizes/static_source_halo_sizes`.
 
     `split_axes` specifies for each tensor axis along which mesh axes its halo
     data is updated.
 
-    Optionally resizes to new halo sizes `target_halo_sizes`.
+    The destination halo sizes are allowed to differ from the source sizes. The sizes
+    of the inner (local) shard are inferred from the destination shape and source sharding.
   }];
   let arguments = (ins
-    AnyNon0RankedMemRef:$input,
+    AnyTypeOf<[AnyNon0RankedMemRef, AnyNon0RankedTensor]>:$source,
+    AnyTypeOf<[AnyNon0RankedMemRef, AnyNon0RankedTensor]>:$destination,
     FlatSymbolRefAttr:$mesh,
     Mesh_MeshAxesArrayAttr:$split_axes,
-    Variadic<I64>:$dynamic_halo_sizes,
-    DefaultValuedAttr<DenseI64ArrayAttr, "{}">:$static_halo_sizes,
-    DefaultValuedAttr<DenseI64ArrayAttr, "{}">:$target_halo_sizes
+    Variadic<I64>:$source_halo_sizes,
+    DefaultValuedAttr<DenseI64ArrayAttr, "{}">:$static_source_halo_sizes
+  );
+  let results = (outs
+    AnyTypeOf<[AnyNon0RankedMemRef, AnyNon0RankedTensor]>:$result
   );
   let assemblyFormat = [{
-    $input `on` $mesh
+    $source `into` $destination
+    `on` $mesh
     `split_axes` `=` $split_axes
-    (`halo_sizes` `=` custom<DynamicIndexList>($dynamic_halo_sizes, $static_halo_sizes)^)?
-    (`target_halo_sizes` `=` $target_halo_sizes^)?
-    attr-dict `:` type($input)
+    (`source_halo_sizes` `=` custom<DynamicIndexList>($source_halo_sizes, $static_source_halo_sizes)^)?
+    attr-dict `:` type($source) `->` type($result)
+  }];
+  let extraClassDeclaration = [{
+    MutableOperandRange getDpsInitsMutable() { return getDestinationMutable(); }
   }];
 }
 #endif // MLIR_DIALECT_MESH_IR_MESHOPS_TD
diff --git a/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp b/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
index e0ecbee7ee11f5..29356a9de73f56 100644
--- a/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
+++ b/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
@@ -576,13 +576,10 @@ bool MeshSharding::equalSplitAndPartialAxes(const MeshSharding &rhs) const {
 }
 
 bool MeshSharding::equalHaloAndShardSizes(const MeshSharding &rhs) const {
-  if (rhs.getStaticHaloSizes().size() != getStaticHaloSizes().size() ||
-      !llvm::equal(llvm::make_range(getStaticHaloSizes().begin(),
-                                    getStaticHaloSizes().end()),
-                   llvm::make_range(rhs.getStaticHaloSizes().begin(),
-                                    rhs.getStaticHaloSizes().end()))) {
-    return false;
-  }
+  return equalShardSizes(rhs) && equalHaloSizes(rhs);
+}
+
+bool MeshSharding::equalShardSizes(const MeshSharding &rhs) const {
   if (rhs.getStaticShardedDimsSizes().size() != getStaticShardedDimsSizes().size() ||
       !llvm::equal(llvm::make_range(getStaticShardedDimsSizes().begin(),
                                     getStaticShardedDimsSizes().end()),
@@ -590,13 +587,6 @@ bool MeshSharding::equalHaloAndShardSizes(const MeshSharding &rhs) const {
                                     rhs.getStaticShardedDimsSizes().end()))) {
     return false;
   }
-  if (rhs.getDynamicHaloSizes().size() != getDynamicHaloSizes().size() ||
-      !llvm::equal(llvm::make_range(getDynamicHaloSizes().begin(),
-                                    getDynamicHaloSizes().end()),
-                   llvm::make_range(rhs.getDynamicHaloSizes().begin(),
-                                    rhs.getDynamicHaloSizes().end()))) {
-    return false;
-  }
   if (rhs.getDynamicShardedDimsSizes().size() !=
           getDynamicShardedDimsSizes().size() ||
       !llvm::equal(llvm::make_range(getDynamicShardedDimsSizes().begin(),
@@ -605,6 +595,23 @@ bool MeshSharding::equalHaloAndShardSizes(const MeshSharding &rhs) const {
                                     rhs.getDynamicShardedDimsSizes().end()))) {
     return false;
   }
+}
+
+bool MeshSharding::equalHaloSizes(const MeshSharding &rhs) const {
+  if (rhs.getStaticHaloSizes().size() != getStaticHaloSizes().size() ||
+      !llvm::equal(llvm::make_range(getStaticHaloSizes().begin(),
+                                    getStaticHaloSizes().end()),
+                   llvm::make_range(rhs.getStaticHaloSizes().begin(),
+                                    rhs.getStaticHaloSizes().end()))) {
+    return false;
+  }
+  if (rhs.getDynamicHaloSizes().size() != getDynamicHaloSizes().size() ||
+      !llvm::equal(llvm::make_range(getDynamicHaloSizes().begin(),
+                                    getDynamicHaloSizes().end()),
+                   llvm::make_range(rhs.getDynamicHaloSizes().begin(),
+                                    rhs.getDynamicHaloSizes().end()))) {
+    return false;
+  }
   return true;
 }
 
diff --git a/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp b/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp
index fdfed39972fd52..8635890f9c5afd 100644
--- a/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp
+++ b/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp
@@ -429,6 +429,62 @@ tryMoveLastSplitAxisInResharding(ImplicitLocOpBuilder &builder, MeshOp mesh,
   return std::nullopt;
 }
 
+static std::optional<std::tuple<TypedValue<ShapedType>, MeshSharding>>
+tryUpdateHaloInResharding(ImplicitLocOpBuilder &builder, MeshOp mesh,
+                          MeshSharding sourceSharding,
+                          MeshSharding targetSharding,
+                          ShapedType sourceUnshardedShape,
+                          TypedValue<ShapedType> sourceShard) {
+  if (sourceSharding.equalSplitAndPartialAxes(targetSharding) &&
+      sourceSharding.getPartialAxes().empty() &&
+      targetSharding.getPartialAxes().empty() &&
+      sourceSharding.getStaticShardedDimsSizes().empty() &&
+      targetSharding.getStaticShardedDimsSizes().empty() &&
+      !sourceSharding.equalHaloSizes(targetSharding)) {
+    auto srcHaloSizes = sourceSharding.getStaticHaloSizes();
+    auto tgtHaloSizes = targetSharding.getStaticHaloSizes();
+    assert(
+        ((srcHaloSizes.empty() || !ShapedType::isDynamicShape(srcHaloSizes)) &&
+         !ShapedType::isDynamicShape(tgtHaloSizes) &&
+         sourceShard.getType().hasStaticShape()) &&
+        "dynamic shapes/halos are not supported yet for mesh-spmdization");
+    auto rank = sourceShard.getType().getRank();
+    SmallVector<int64_t> outShape, srcCoreOffs(rank, 0), tgtCoreOffs,
+        strides(rank, 1), coreShape(sourceShard.getType().getShape());
+    for (auto i = 0u; i < rank; ++i) {
+      if (!srcHaloSizes.empty()) {
+        coreShape[i] -= srcHaloSizes[i * 2] + srcHaloSizes[i * 2 + 1];
+        srcCoreOffs[i] = srcHaloSizes[i * 2];
+      }
+      tgtCoreOffs.emplace_back(tgtHaloSizes[i * 2]);
+      outShape.emplace_back(coreShape[i] + tgtHaloSizes[i * 2] +
+                            tgtHaloSizes[i * 2 + 1]);
+    }
+    auto noVals = ValueRange{};
+    auto initVal = builder.create<tensor::EmptyOp>(
+        sourceShard.getLoc(), outShape, sourceShard.getType().getElementType());
+    auto core = builder.create<tensor::ExtractSliceOp>(
+        sourceShard.getLoc(),
+        RankedTensorType::get(coreShape,
+                              sourceShard.getType().getElementType()),
+        sourceShard, noVals, noVals, noVals, srcCoreOffs, coreShape, strides);
+    auto initOprnd = builder.create<tensor::InsertSliceOp>(
+        sourceShard.getLoc(), core, initVal, noVals, noVals, noVals,
+        tgtCoreOffs, coreShape, strides);
+    auto targetShard = builder.create<UpdateHaloOp>(
+        sourceShard.getLoc(),
+        RankedTensorType::get(outShape, sourceShard.getType().getElementType()),
+        sourceShard, initOprnd, mesh.getSymName(),
+        MeshAxesArrayAttr::get(builder.getContext(),
+                               sourceSharding.getSplitAxes()),
+        sourceSharding.getDynamicHaloSizes(),
+        sourceSharding.getStaticHaloSizes());
+    return std::make_tuple(
+        cast<TypedValue<ShapedType>>(targetShard.getResult()), targetSharding);
+  }
+  return std::nullopt;
+}
+
 // Handles only resharding on a 1D mesh.
 // Currently the sharded tensor axes must be exactly divisible by the single
 // mesh axis size.
@@ -454,10 +510,10 @@ reshardOn1DMesh(ImplicitLocOpBuilder &builder, MeshOp mesh,
 
   TypedValue<ShapedType> targetShard;
   MeshSharding actualTargetSharding;
-  if (reducedSourceSharding.getStaticHaloSizes().empty() &&
-      targetSharding.getStaticHaloSizes().empty() &&
-      reducedSourceSharding.getStaticShardedDimsSizes().empty() &&
-      targetSharding.getStaticShardedDimsSizes().empty()) {
+  if (reducedSourceSharding.getStaticShardedDimsSizes().empty() &&
+      targetSharding.getStaticShardedDimsSizes().empty() &&
+      reducedSourceSharding.getStaticHaloSizes().empty() &&
+      targetSharding.getStaticHaloSizes().empty()) {
     if (auto tryRes = tryMoveLastSplitAxisInResharding(
             builder, mesh, reducedSourceSharding, targetSharding,
             sourceUnshardedValue.getType(), reducedSourceShard)) {
@@ -483,6 +539,12 @@ TypedValue<ShapedType> reshard(ImplicitLocOpBuilder &builder, MeshOp mesh,
                                MeshSharding targetSharding,
                                TypedValue<ShapedType> sourceUnshardedValue,
                                TypedValue<ShapedType> sourceShard) {
+  if (auto tryRes = tryUpdateHaloInResharding(
+          builder, mesh, sourceSharding, targetSharding,
+          sourceUnshardedValue.getType(), sourceShard)) {
+    return std::get<0>(tryRes.value()); // targetShard
+  }
+
   // Resort to handling only 1D meshes since the general case is complicated if
   // it needs to be communication efficient in terms of minimizing the data
   // transfered between devices.
@@ -497,9 +559,10 @@ TypedValue<ShapedType> reshard(OpBuilder &builder, MeshOp mesh, ShardOp source,
   auto sourceSharding = source.getSharding();
   auto targetSharding = target.getSharding();
   ImplicitLocOpBuilder implicitLocOpBuilder(target->getLoc(), builder);
-  return reshard(implicitLocOpBuilder, mesh, sourceSharding, targetSharding,
-                 cast<TypedValue<ShapedType>>(source.getSrc()),
-                 sourceShardValue);
+  auto shard =
+      reshard(implicitLocOpBuilder, mesh, sourceSharding, targetSharding,
+              cast<TypedValue<ShapedType>>(source.getSrc()), sourceShardValue);
+  return shard;
 }
 
 TypedValue<ShapedType> reshard(OpBuilder &builder, ShardOp source,

>From 3c1b4eb829aa44b3d68b1b09b83265bf97576a41 Mon Sep 17 00:00:00 2001
From: "Schlimbach, Frank" <frank.schlimbach at intel.com>
Date: Tue, 15 Oct 2024 17:44:55 +0200
Subject: [PATCH 03/26] update to mesh.update_halo; sharded_dims_sizes ->
 sharded_dims_offsets

---
 mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h   |  16 +--
 mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td  |  28 ++---
 mlir/lib/Dialect/Mesh/IR/MeshOps.cpp          | 102 +++++++++---------
 .../Dialect/Mesh/Transforms/Spmdization.cpp   |   8 +-
 mlir/test/Dialect/Mesh/invalid.mlir           |   6 +-
 mlir/test/Dialect/Mesh/ops.mlir               |  26 +++--
 .../test/Dialect/Tensor/mesh-spmdization.mlir |  22 ++--
 7 files changed, 106 insertions(+), 102 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h
index d1cab4d819a0e0..75cb096130ca6e 100644
--- a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h
+++ b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h
@@ -46,9 +46,9 @@ class MeshSharding {
   SmallVector<MeshAxis> partial_axes;
   ReductionKind partial_type;
   SmallVector<int64_t> static_halo_sizes;
-  SmallVector<int64_t> static_sharded_dims_sizes;
+  SmallVector<int64_t> static_sharded_dims_offsets;
   SmallVector<Value> dynamic_halo_sizes;
-  SmallVector<Value> dynamic_sharded_dims_sizes;
+  SmallVector<Value> dynamic_sharded_dims_offsets;
 
 public:
   MeshSharding() = default;
@@ -58,21 +58,21 @@ class MeshSharding {
                           ArrayRef<MeshAxis> partial_axes_ = {},
                           ReductionKind partial_type_ = ReductionKind::Sum,
                           ArrayRef<int64_t> static_halo_sizes_ = {},
-                          ArrayRef<int64_t> static_sharded_dims_sizes_ = {},
+                          ArrayRef<int64_t> static_sharded_dims_offsets_ = {},
                           ArrayRef<Value> dynamic_halo_sizes_ = {},
-                          ArrayRef<Value> dynamic_sharded_dims_sizes_ = {});
+                          ArrayRef<Value> dynamic_sharded_dims_offsets_ = {});
   ::mlir::FlatSymbolRefAttr getMeshAttr() const { return mesh; }
   ::llvm::StringRef getMesh() const { return mesh.getValue(); }
   ArrayRef<MeshAxesAttr> getSplitAxes() const { return split_axes; }
   ArrayRef<MeshAxis> getPartialAxes() const { return partial_axes; }
   ReductionKind getPartialType() const { return partial_type; }
   ArrayRef<int64_t> getStaticHaloSizes() const { return static_halo_sizes; }
-  ArrayRef<int64_t> getStaticShardedDimsSizes() const {
-    return static_sharded_dims_sizes;
+  ArrayRef<int64_t> getStaticShardedDimsOffsets() const {
+    return static_sharded_dims_offsets;
   }
   ArrayRef<Value> getDynamicHaloSizes() const { return dynamic_halo_sizes; }
-  ArrayRef<Value> getDynamicShardedDimsSizes() const {
-    return dynamic_sharded_dims_sizes;
+  ArrayRef<Value> getDynamicShardedDimsOffsets() const {
+    return dynamic_sharded_dims_offsets;
   }
   operator bool() const { return (!mesh) == false; }
   bool operator==(Value rhs) const;
diff --git a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td
index 0c5606a94a6dc7..f881f22adcf52a 100644
--- a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td
+++ b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td
@@ -197,16 +197,18 @@ def Mesh_ShardingOp : Mesh_Op<"sharding", [
     e.g. the first sharded dimension gets [1,2] halos and the seconds gets [2,3] halos.
     `?` indicates dynamic halo sizes.
     
-    6. [Optional] Sizes of sharded dimensions of each shard.
-    `sharded_dims_sizes`is provided as a flattened 1d array of i64s: for each device of the
-    device-mesh one value for each sharded tensor dimension.
+    6. [Optional] Offsets for each shard and sharded tensor dimension.
+    `sharded_dims_offsets` is provided as a flattened 1d array of i64s:
+    For each sharded tensor dimension the offsets (starting index) of all shards in that dimension.
+    The offset of each first shard is omitted and is implicitly assumed to be 0.
+    The last value per dimension denotes the end of the last shard.
     Assuming a 3d-tensor of shape 32x32x32 with the first 2 dimensions being sharded,
-    `sharded_dims_sizes` = [16, 8, 16, 24] means that the first device of
-    the device-mesh will get a shard of shape 16x8x32 and the second device will get a
-    shard of shape 16x24x32.
+    `sharded_dims_offsets` = [24, 32, 20, 32] means that the first device of
+    the device-mesh will get a shard of shape 24x20x32 and the second device will get a
+    shard of shape 8x12x32.
     `?` indicates dynamic shard dimensions.
     
-    `halo_sizes` and `sharded_dims_sizes` are mutually exclusive.
+    `halo_sizes` and `sharded_dims_offsets` are mutually exclusive.
 
     Examples:
 
@@ -241,7 +243,7 @@ def Mesh_ShardingOp : Mesh_Op<"sharding", [
     // The tensor is sharded on its second dimension along axis 0 of @mesh1d_4
     // and it has pre-defined shard sizes. The shards of the devices will have
     // the following shapes: [4x2, 4x3, 4x4, 4x5]
-    %sharding4 = mesh.sharding @mesh1d_4 split_axes = [[] split_axes = [0]] sharded_dims_sizes = [2, 3, 4, 5]
+    %sharding4 = mesh.sharding @mesh1d_4 split_axes = [[] split_axes = [0]] sharded_dims_offsets = [2, 5, 9, 14]
     %sharded2 = mesh.shard %arg0 to %sharding4 : tensor<4x14xf32>
     ```
   }];
@@ -251,8 +253,8 @@ def Mesh_ShardingOp : Mesh_Op<"sharding", [
     Mesh_MeshAxesArrayAttr:$split_axes,
     OptionalAttr<Mesh_MeshAxesAttr>:$partial_axes,
     OptionalAttr<Mesh_ReductionKindAttr>:$partial_type,
-    DefaultValuedAttr<DenseI64ArrayAttr, "{}">:$static_sharded_dims_sizes,
-    Variadic<I64>:$dynamic_sharded_dims_sizes,
+    DefaultValuedAttr<DenseI64ArrayAttr, "{}">:$static_sharded_dims_offsets,
+    Variadic<I64>:$dynamic_sharded_dims_offsets,
     DefaultValuedAttr<DenseI64ArrayAttr, "{}">:$static_halo_sizes,
     Variadic<I64>:$dynamic_halo_sizes
   );
@@ -264,7 +266,7 @@ def Mesh_ShardingOp : Mesh_Op<"sharding", [
     `split_axes` `=` $split_axes
     (`partial` `=` $partial_type $partial_axes^)?
     (`halo_sizes` `=` custom<DynamicIndexList>($dynamic_halo_sizes, $static_halo_sizes)^)?
-    (`sharded_dims_sizes` `=` custom<DynamicIndexList>($dynamic_sharded_dims_sizes, $static_sharded_dims_sizes)^)?
+    (`sharded_dims_offsets` `=` custom<DynamicIndexList>($dynamic_sharded_dims_offsets, $static_sharded_dims_offsets)^)?
     attr-dict `:` type($result)
   }];
   let builders = [
@@ -273,13 +275,13 @@ def Mesh_ShardingOp : Mesh_Op<"sharding", [
                    "ArrayRef<MeshAxis>":$partial_axes,
                    "mesh::ReductionKind":$partial_type,
                    CArg<"ArrayRef<int64_t>", "{}">:$static_halo_sizes,
-                   CArg<"ArrayRef<int64_t>", "{}">:$static_sharded_dims_sizes)>,
+                   CArg<"ArrayRef<int64_t>", "{}">:$static_sharded_dims_offsets)>,
     OpBuilder<(ins "FlatSymbolRefAttr":$mesh,
                    "ArrayRef<MeshAxesAttr>":$split_axes)>,
     OpBuilder<(ins "FlatSymbolRefAttr":$mesh,
                    "ArrayRef<MeshAxesAttr>":$split_axes,
                    "::mlir::ArrayRef<::mlir::OpFoldResult>":$halo_sizes,
-                   "::mlir::ArrayRef<::mlir::OpFoldResult>":$sharded_dims_sizes)>,
+                   "::mlir::ArrayRef<::mlir::OpFoldResult>":$sharded_dims_offsets)>,
     OpBuilder<(ins "mlir::mesh::MeshSharding":$from)>
   ];
   let hasVerifier = 1;
diff --git a/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp b/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
index 29356a9de73f56..b10f8ac98a27d1 100644
--- a/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
+++ b/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
@@ -192,33 +192,33 @@ template <typename InShape, typename MeshShape, typename SplitAxes,
           typename OutShape>
 static void shardShape(const InShape &inShape, const MeshShape &meshShape,
                        const SplitAxes &splitAxes, OutShape &outShape,
-                       ArrayRef<int64_t> shardedDimsSizes = {},
+                       ArrayRef<int64_t> shardedDimsOffsets = {},
                        ArrayRef<int64_t> haloSizes = {}) {
   std::copy(llvm::adl_begin(inShape), llvm::adl_end(inShape),
             llvm::adl_begin(outShape));
 
-  if (!shardedDimsSizes.empty()) {
+  if (!shardedDimsOffsets.empty()) {
+    uint64_t pos = 0;
     for (auto [tensorAxis, innerSplitAxes] : llvm::enumerate(splitAxes)) {
-      if (innerSplitAxes.empty()) {
-#ifndef NDEBUG
-        for (auto dimSz : shardedDimsSizes) {
-          auto inAxis = dimSz % inShape.size();
-          assert(inShape[inAxis] == dimSz || dimSz == ShapedType::kDynamic ||
-                 inShape[inAxis] == ShapedType::kDynamic);
-        }
-#endif // NDEBUG
-      } else {
-        // find sharded dims in sharded_dims_sizes with same static size on
-        // all devices. Use kDynamic for dimensions with dynamic or non-uniform
-        // sizes in sharded_dims_sizes.
-        auto sz = shardedDimsSizes[tensorAxis];
-        bool same = true;
-        for (size_t i = tensorAxis + inShape.size();
-             i < shardedDimsSizes.size(); i += inShape.size()) {
-          if (shardedDimsSizes[i] != sz) {
-            same = false;
-            break;
+      if (!innerSplitAxes.empty()) {
+        auto sz = shardedDimsOffsets[pos];
+        bool same = !ShapedType::isDynamicShape(meshShape);
+        if (same) {
+          // find sharded dims in shardedDimsOffsets with same static size on
+          // all devices. Use kDynamic for dimensions with dynamic or
+          // non-uniform offs in shardedDimsOffsets.
+          uint64_t numShards = 0;
+          for (auto i : innerSplitAxes.asArrayRef()) {
+            numShards += meshShape[i];
+          }
+          for (size_t i = 1; i < numShards; ++i) {
+            if (shardedDimsOffsets[pos + i] - shardedDimsOffsets[pos + i - 1] !=
+                sz) {
+              same = false;
+              break;
+            }
           }
+          pos += numShards;
         }
         outShape[tensorAxis] = same ? sz : ShapedType::kDynamic;
       }
@@ -255,7 +255,7 @@ ShapedType mesh::shardShapedType(ShapedType shape, MeshOp mesh,
   using Dim = std::decay_t<decltype(shape.getDimSize(0))>;
   SmallVector<Dim> resShapeArr(shape.getShape().size());
   shardShape(shape.getShape(), mesh.getShape(), sharding.getSplitAxes(),
-             resShapeArr, sharding.getStaticShardedDimsSizes(),
+             resShapeArr, sharding.getStaticShardedDimsOffsets(),
              sharding.getStaticHaloSizes());
   return shape.clone(resShapeArr);
 }
@@ -432,13 +432,14 @@ void ShardingOp::build(::mlir::OpBuilder &b, ::mlir::OperationState &odsState,
                        ArrayRef<MeshAxis> partial_axes,
                        mesh::ReductionKind partial_type,
                        ArrayRef<int64_t> static_halo_sizes,
-                       ArrayRef<int64_t> static_sharded_dims_sizes) {
+                       ArrayRef<int64_t> static_sharded_dims_offsets) {
   return build(
       b, odsState, mesh, MeshAxesArrayAttr::get(b.getContext(), split_axes),
       ::mlir::DenseI16ArrayAttr::get(b.getContext(), partial_axes),
       ::mlir::mesh::ReductionKindAttr::get(b.getContext(), partial_type),
       ::mlir::DenseI64ArrayAttr::get(b.getContext(), static_halo_sizes), {},
-      ::mlir::DenseI64ArrayAttr::get(b.getContext(), static_sharded_dims_sizes),
+      ::mlir::DenseI64ArrayAttr::get(b.getContext(),
+                                     static_sharded_dims_offsets),
       {});
 }
 
@@ -455,11 +456,11 @@ void ShardingOp::build(
     ::mlir::OpBuilder &b, ::mlir::OperationState &odsState,
     FlatSymbolRefAttr mesh, ArrayRef<MeshAxesAttr> split_axes,
     ::mlir::ArrayRef<::mlir::OpFoldResult> halo_sizes,
-    ::mlir::ArrayRef<::mlir::OpFoldResult> sharded_dims_sizes) {
+    ::mlir::ArrayRef<::mlir::OpFoldResult> sharded_dims_offsets) {
   mlir::SmallVector<int64_t> staticHalos, staticDims;
   mlir::SmallVector<mlir::Value> dynamicHalos, dynamicDims;
   dispatchIndexOpFoldResults(halo_sizes, dynamicHalos, staticHalos);
-  dispatchIndexOpFoldResults(sharded_dims_sizes, dynamicDims, staticDims);
+  dispatchIndexOpFoldResults(sharded_dims_offsets, dynamicDims, staticDims);
   return build(
       b, odsState, mesh, MeshAxesArrayAttr::get(b.getContext(), split_axes), {},
       ::mlir::mesh::ReductionKindAttr::get(b.getContext(), ReductionKind::Sum),
@@ -477,10 +478,10 @@ void ShardingOp::build(::mlir::OpBuilder &b, ::mlir::OperationState &odsState,
             : b.getDenseI16ArrayAttr(from.getPartialAxes()),
         ::mlir::mesh::ReductionKindAttr::get(b.getContext(),
                                              from.getPartialType()),
-        from.getStaticShardedDimsSizes().empty()
+        from.getStaticShardedDimsOffsets().empty()
             ? DenseI64ArrayAttr()
-            : b.getDenseI64ArrayAttr(from.getStaticShardedDimsSizes()),
-        from.getDynamicShardedDimsSizes(),
+            : b.getDenseI64ArrayAttr(from.getStaticShardedDimsOffsets()),
+        from.getDynamicShardedDimsOffsets(),
         from.getStaticHaloSizes().empty()
             ? DenseI64ArrayAttr()
             : b.getDenseI64ArrayAttr(from.getStaticHaloSizes()),
@@ -509,7 +510,7 @@ LogicalResult ShardingOp::verify() {
       failed(checkMeshAxis(getPartialAxes().value())))
     return failure();
 
-  if (!getStaticHaloSizes().empty() && !getStaticShardedDimsSizes().empty()) {
+  if (!getStaticHaloSizes().empty() && !getStaticShardedDimsOffsets().empty()) {
     return emitOpError("halo sizes and shard shapes are mutually exclusive");
   }
 
@@ -539,8 +540,8 @@ LogicalResult ShardingOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
     return failure();
   }
   if (mlir::ShapedType::isDynamicShape(mesh->getShape()) &&
-      getStaticShardedDimsSizes().size() > 0) {
-    return emitError() << "sharded dims sizes are not allowed for "
+      getStaticShardedDimsOffsets().size() > 0) {
+    return emitError() << "sharded dims offsets are not allowed for "
                           "devices meshes with dynamic shape.";
   }
   return success();
@@ -580,21 +581,24 @@ bool MeshSharding::equalHaloAndShardSizes(const MeshSharding &rhs) const {
 }
 
 bool MeshSharding::equalShardSizes(const MeshSharding &rhs) const {
-  if (rhs.getStaticShardedDimsSizes().size() != getStaticShardedDimsSizes().size() ||
-      !llvm::equal(llvm::make_range(getStaticShardedDimsSizes().begin(),
-                                    getStaticShardedDimsSizes().end()),
-                   llvm::make_range(rhs.getStaticShardedDimsSizes().begin(),
-                                    rhs.getStaticShardedDimsSizes().end()))) {
+  if (rhs.getStaticShardedDimsOffsets().size() !=
+          getStaticShardedDimsOffsets().size() ||
+      !llvm::equal(llvm::make_range(getStaticShardedDimsOffsets().begin(),
+                                    getStaticShardedDimsOffsets().end()),
+                   llvm::make_range(rhs.getStaticShardedDimsOffsets().begin(),
+                                    rhs.getStaticShardedDimsOffsets().end()))) {
     return false;
   }
-  if (rhs.getDynamicShardedDimsSizes().size() !=
-          getDynamicShardedDimsSizes().size() ||
-      !llvm::equal(llvm::make_range(getDynamicShardedDimsSizes().begin(),
-                                    getDynamicShardedDimsSizes().end()),
-                   llvm::make_range(rhs.getDynamicShardedDimsSizes().begin(),
-                                    rhs.getDynamicShardedDimsSizes().end()))) {
+  if (rhs.getDynamicShardedDimsOffsets().size() !=
+          getDynamicShardedDimsOffsets().size() ||
+      !llvm::equal(
+          llvm::make_range(getDynamicShardedDimsOffsets().begin(),
+                           getDynamicShardedDimsOffsets().end()),
+          llvm::make_range(rhs.getDynamicShardedDimsOffsets().begin(),
+                           rhs.getDynamicShardedDimsOffsets().end()))) {
     return false;
   }
+  return true;
 }
 
 bool MeshSharding::equalHaloSizes(const MeshSharding &rhs) const {
@@ -636,9 +640,9 @@ MeshSharding::MeshSharding(Value rhs) {
               shardingOp.getPartialAxes().value_or(ArrayRef<MeshAxis>()),
               shardingOp.getPartialType().value_or(ReductionKind::Sum),
               shardingOp.getStaticHaloSizes(),
-              shardingOp.getStaticShardedDimsSizes(),
+              shardingOp.getStaticShardedDimsOffsets(),
               SmallVector<Value>(shardingOp.getDynamicHaloSizes()),
-              SmallVector<Value>(shardingOp.getDynamicShardedDimsSizes()));
+              SmallVector<Value>(shardingOp.getDynamicShardedDimsOffsets()));
 }
 
 MeshSharding MeshSharding::get(::mlir::FlatSymbolRefAttr mesh_,
@@ -646,9 +650,9 @@ MeshSharding MeshSharding::get(::mlir::FlatSymbolRefAttr mesh_,
                                ArrayRef<MeshAxis> partial_axes_,
                                ReductionKind partial_type_,
                                ArrayRef<int64_t> static_halo_sizes_,
-                               ArrayRef<int64_t> static_sharded_dims_sizes_,
+                               ArrayRef<int64_t> static_sharded_dims_offsets_,
                                ArrayRef<Value> dynamic_halo_sizes_,
-                               ArrayRef<Value> dynamic_sharded_dims_sizes_) {
+                               ArrayRef<Value> dynamic_sharded_dims_offsets_) {
   MeshSharding res;
   res.mesh = mesh_;
   res.split_axes.resize(split_axes_.size());
@@ -665,9 +669,9 @@ MeshSharding MeshSharding::get(::mlir::FlatSymbolRefAttr mesh_,
   clone(partial_axes_, res.partial_axes);
   res.partial_type = partial_type_;
   clone(static_halo_sizes_, res.static_halo_sizes);
-  clone(static_sharded_dims_sizes_, res.static_sharded_dims_sizes);
+  clone(static_sharded_dims_offsets_, res.static_sharded_dims_offsets);
   clone(dynamic_halo_sizes_, res.dynamic_halo_sizes);
-  clone(dynamic_sharded_dims_sizes_, res.dynamic_sharded_dims_sizes);
+  clone(dynamic_sharded_dims_offsets_, res.dynamic_sharded_dims_offsets);
 
   return res;
 }
diff --git a/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp b/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp
index 8635890f9c5afd..6684fbfa0aec0b 100644
--- a/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp
+++ b/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp
@@ -438,8 +438,8 @@ tryUpdateHaloInResharding(ImplicitLocOpBuilder &builder, MeshOp mesh,
   if (sourceSharding.equalSplitAndPartialAxes(targetSharding) &&
       sourceSharding.getPartialAxes().empty() &&
       targetSharding.getPartialAxes().empty() &&
-      sourceSharding.getStaticShardedDimsSizes().empty() &&
-      targetSharding.getStaticShardedDimsSizes().empty() &&
+      sourceSharding.getStaticShardedDimsOffsets().empty() &&
+      targetSharding.getStaticShardedDimsOffsets().empty() &&
       !sourceSharding.equalHaloSizes(targetSharding)) {
     auto srcHaloSizes = sourceSharding.getStaticHaloSizes();
     auto tgtHaloSizes = targetSharding.getStaticHaloSizes();
@@ -510,8 +510,8 @@ reshardOn1DMesh(ImplicitLocOpBuilder &builder, MeshOp mesh,
 
   TypedValue<ShapedType> targetShard;
   MeshSharding actualTargetSharding;
-  if (reducedSourceSharding.getStaticShardedDimsSizes().empty() &&
-      targetSharding.getStaticShardedDimsSizes().empty() &&
+  if (reducedSourceSharding.getStaticShardedDimsOffsets().empty() &&
+      targetSharding.getStaticShardedDimsOffsets().empty() &&
       reducedSourceSharding.getStaticHaloSizes().empty() &&
       targetSharding.getStaticHaloSizes().empty()) {
     if (auto tryRes = tryMoveLastSplitAxisInResharding(
diff --git a/mlir/test/Dialect/Mesh/invalid.mlir b/mlir/test/Dialect/Mesh/invalid.mlir
index 3827df90e6962f..f8abc6959a13a4 100644
--- a/mlir/test/Dialect/Mesh/invalid.mlir
+++ b/mlir/test/Dialect/Mesh/invalid.mlir
@@ -90,7 +90,7 @@ func.func @sharding_attribute_invalid_halo(%arg0 : tensor<4x8xf32>) {
 
 func.func @sharding_attribute_invalid_sizes(%arg0 : tensor<4x8xf32>) {
   // expected-error at +1 {{halo sizes and shard shapes are mutually exclusive}}
-  %s = mesh.sharding @mesh0 split_axes = [[0]] halo_sizes = [1, 2] sharded_dims_sizes = [2, 2] : !mesh.sharding
+  %s = mesh.sharding @mesh0 split_axes = [[0]] halo_sizes = [1, 2] sharded_dims_offsets = [2, 2] : !mesh.sharding
   %0 = mesh.shard %arg0 to %s : tensor<4x8xf32>
   return
 }
@@ -99,8 +99,8 @@ func.func @sharding_attribute_invalid_sizes(%arg0 : tensor<4x8xf32>) {
 
 mesh.mesh @mesh_dyn(shape = ?x?)
 func.func @sharding_dyn_mesh_and_sizes(%arg0 : tensor<4x8xf32>) {
-  // expected-error at +1 {{sharded dims sizes are not allowed for devices meshes with dynamic shape}}
-  %s = mesh.sharding @mesh_dyn split_axes = [[0]] sharded_dims_sizes = [2, 2] : !mesh.sharding
+  // expected-error at +1 {{sharded dims offsets are not allowed for devices meshes with dynamic shape}}
+  %s = mesh.sharding @mesh_dyn split_axes = [[0]] sharded_dims_offsets = [2, 2] : !mesh.sharding
   %0 = mesh.shard %arg0 to %s : tensor<4x8xf32>
   return
 }
diff --git a/mlir/test/Dialect/Mesh/ops.mlir b/mlir/test/Dialect/Mesh/ops.mlir
index 5ead7babe2c084..3b0ef39a6e212e 100644
--- a/mlir/test/Dialect/Mesh/ops.mlir
+++ b/mlir/test/Dialect/Mesh/ops.mlir
@@ -144,10 +144,10 @@ func.func @mesh_shard_halo_sizes() -> () {
 func.func @mesh_shard_dims_sizes() -> () {
   // CHECK: %[[C3:.*]] = arith.constant 3 : i64
   %c3 = arith.constant 3 : i64
-  // CHECK: mesh.sharding @mesh4 split_axes = {{\[\[}}0]] sharded_dims_sizes = [1, 4, 2] : !mesh.sharding
-  %sharding1 = mesh.sharding @mesh4 split_axes = [[0]] sharded_dims_sizes = [1, 4, 2] : !mesh.sharding
-  // CHECK: mesh.sharding @mesh4 split_axes = {{\[\[}}0]] sharded_dims_sizes = [4, %[[C3]], 1] : !mesh.sharding
-  %sharding2 = mesh.sharding @mesh4 split_axes = [[0]] sharded_dims_sizes = [4, %c3, 1] : !mesh.sharding
+  // CHECK: mesh.sharding @mesh4 split_axes = {{\[\[}}0]] sharded_dims_offsets = [1, 4, 2] : !mesh.sharding
+  %sharding1 = mesh.sharding @mesh4 split_axes = [[0]] sharded_dims_offsets = [1, 4, 2] : !mesh.sharding
+  // CHECK: mesh.sharding @mesh4 split_axes = {{\[\[}}0]] sharded_dims_offsets = [4, %[[C3]], 1] : !mesh.sharding
+  %sharding2 = mesh.sharding @mesh4 split_axes = [[0]] sharded_dims_offsets = [4, %c3, 1] : !mesh.sharding
   return
 }
 
@@ -615,18 +615,16 @@ func.func @update_halo(
     // CHECK-SAME: %[[ARG:.*]]: memref<12x12xi8>
     %arg0 : memref<12x12xi8>) {
   // CHECK-NEXT: %[[C2:.*]] = arith.constant 2 : i64
-  // CHECK-NEXT: mesh.update_halo %[[ARG]] on @mesh0
+  // CHECK-NEXT: %[[UH1:.*]] = mesh.update_halo %[[ARG]] into %[[ARG]] on @mesh0
   // CHECK-SAME: split_axes = {{\[\[}}0]]
-  // CHECK-SAME: halo_sizes = [2, %c2_i64] : memref<12x12xi8>
+  // CHECK-SAME: halo_sizes = [2, %c2_i64] : memref<12x12xi8> -> memref<12x12xi8>
   %c2 = arith.constant 2 : i64
-  mesh.update_halo %arg0 on @mesh0 split_axes = [[0]]
-    halo_sizes = [2, %c2] : memref<12x12xi8>
-  // CHECK-NEXT: mesh.update_halo %[[ARG]] on @mesh0
+  %uh1 = mesh.update_halo %arg0 into %arg0 on @mesh0 split_axes = [[0]]
+    source_halo_sizes = [2, %c2] : memref<12x12xi8> -> memref<12x12xi8>
+  // CHECK-NEXT: %[[UH2:.*]] = mesh.update_halo %[[ARG]] into %[[UH1]] on @mesh0
   // CHECK-SAME: split_axes = {{\[\[}}0], [1]]
-  // CHECK-SAME: halo_sizes = [2, 2, %[[C2]], 2]
-  // CHECK-SAME: target_halo_sizes = [3, 3, 2, 2] : memref<12x12xi8>
-  mesh.update_halo %arg0 on @mesh0 split_axes = [[0], [1]]
-    halo_sizes = [2, 2, %c2, 2] target_halo_sizes = [3, 3, 2, 2]
-    : memref<12x12xi8>
+  // CHECK-SAME: halo_sizes = [2, 2, %[[C2]], 2] : memref<12x12xi8> -> memref<12x12xi8>
+  %uh2 = mesh.update_halo %arg0 into %uh1 on @mesh0 split_axes = [[0], [1]]
+    source_halo_sizes = [2, 2, %c2, 2] : memref<12x12xi8> -> memref<12x12xi8>
   return
 }
diff --git a/mlir/test/Dialect/Tensor/mesh-spmdization.mlir b/mlir/test/Dialect/Tensor/mesh-spmdization.mlir
index 611acb5b41445b..fba9e17b53934a 100644
--- a/mlir/test/Dialect/Tensor/mesh-spmdization.mlir
+++ b/mlir/test/Dialect/Tensor/mesh-spmdization.mlir
@@ -4,12 +4,12 @@
 
 mesh.mesh @mesh_1d_4(shape = 4)
 
-// CHECK-LABEL: func @tensor_empty_static_sharded_dims_sizes
-func.func @tensor_empty_static_sharded_dims_sizes() -> () {
+// CHECK-LABEL: func @tensor_empty_static_sharded_dims_offsets
+func.func @tensor_empty_static_sharded_dims_offsets() -> () {
   %b = tensor.empty() : tensor<8x16xf32>
-  %sharding = mesh.sharding @mesh_1d_4 split_axes = [[0]] sharded_dims_sizes = [1, 3, 3, 1] : !mesh.sharding
+  %sharding = mesh.sharding @mesh_1d_4 split_axes = [[0]] sharded_dims_offsets = [1, 4, 7, 8] : !mesh.sharding
   %sharded= mesh.shard %b to %sharding : tensor<8x16xf32>
-  // CHECK:  %[[sharding:.*]] = mesh.sharding @mesh_1d_4 split_axes = {{\[\[}}0]] sharded_dims_sizes = [1, 3, 3, 1] : !mesh.sharding
+  // CHECK:  %[[sharding:.*]] = mesh.sharding @mesh_1d_4 split_axes = {{\[\[}}0]] sharded_dims_offsets = [1, 4, 7, 8] : !mesh.sharding
   // CHECK:  %[[proc_linear_idx:.*]] = mesh.process_linear_index on @mesh_1d_4 : index
   // CHECK:  %[[V0:.*]]:2 = mesh.shard_shape 8x16 %[[sharding]] %[[proc_linear_idx]] : index, index
   // CHECK:  tensor.empty(%[[V0]]#0) : tensor<?x16xf32>
@@ -17,13 +17,13 @@ func.func @tensor_empty_static_sharded_dims_sizes() -> () {
   return
 }
 
-// CHECK-LABEL: func @tensor_empty_dynamic_sharded_dims_sizes
+// CHECK-LABEL: func @tensor_empty_dynamic_sharded_dims_offsets
 // CHECK-SAME: %[[A0:.*]]: index
-func.func @tensor_empty_dynamic_sharded_dims_sizes(%arg0 : index) -> () {
+func.func @tensor_empty_dynamic_sharded_dims_offsets(%arg0 : index) -> () {
   %b = tensor.empty(%arg0) : tensor<8x?xf32>
-  %sharding = mesh.sharding @mesh_1d_4 split_axes = [[0]] sharded_dims_sizes = [1, 3, 3, 1] : !mesh.sharding
+  %sharding = mesh.sharding @mesh_1d_4 split_axes = [[0]] sharded_dims_offsets = [1, 4, 7, 8] : !mesh.sharding
   %sharded= mesh.shard %b to %sharding : tensor<8x?xf32>
-  // CHECK:  %[[sharding:.*]] = mesh.sharding @mesh_1d_4 split_axes = {{\[\[}}0]] sharded_dims_sizes = [1, 3, 3, 1] : !mesh.sharding
+  // CHECK:  %[[sharding:.*]] = mesh.sharding @mesh_1d_4 split_axes = {{\[\[}}0]] sharded_dims_offsets = [1, 4, 7, 8] : !mesh.sharding
   // CHECK:  %[[proc_linear_idx:.*]] = mesh.process_linear_index on @mesh_1d_4 : index
   // CHECK:  %[[V0:.*]]:2 = mesh.shard_shape 8x? %[[sharding]] %[[proc_linear_idx]] : index, index
   // CHECK:  tensor.empty(%[[V0]]#0, %[[A0]]) : tensor<?x?xf32>
@@ -33,9 +33,9 @@ func.func @tensor_empty_dynamic_sharded_dims_sizes(%arg0 : index) -> () {
 
 // CHECK-LABEL: func @tensor_empty_same_static_dims_sizes
 func.func @tensor_empty_same_static_dims_sizes() -> () {
-  %b = tensor.empty() : tensor<8x16xf32>
-  %sharding = mesh.sharding @mesh_1d_4 split_axes = [[0]] sharded_dims_sizes = [4, 4, 4, 4] : !mesh.sharding
-  %sharded= mesh.shard %b to %sharding : tensor<8x16xf32>
+  %b = tensor.empty() : tensor<16x16xf32>
+  %sharding = mesh.sharding @mesh_1d_4 split_axes = [[0]] sharded_dims_offsets = [4, 8, 12, 16] : !mesh.sharding
+  %sharded= mesh.shard %b to %sharding : tensor<16x16xf32>
   // CHECK-NEXT:  tensor.empty() : tensor<4x16xf32>
 
   return

>From aa26d2be192f5c7f0823e6c27f7697358db40c60 Mon Sep 17 00:00:00 2001
From: "Schlimbach, Frank" <frank.schlimbach at intel.com>
Date: Thu, 17 Oct 2024 11:57:00 +0200
Subject: [PATCH 04/26] canonicalize (debug) ProcessMultiIndexOp

---
 mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td  |  4 +-
 mlir/lib/Dialect/Mesh/IR/MeshOps.cpp          | 92 +++++++++++++++++++
 .../Dialect/Mesh/Transforms/Spmdization.cpp   | 25 +++--
 mlir/test/Dialect/Mesh/canonicalization.mlir  | 17 ++++
 4 files changed, 129 insertions(+), 9 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td
index f881f22adcf52a..c967bb74d5c1a8 100644
--- a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td
+++ b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td
@@ -132,6 +132,7 @@ def Mesh_ProcessMultiIndexOp : Mesh_Op<"process_multi_index", [
     OpBuilder<(ins "::mlir::mesh::MeshOp":$mesh)>,
     OpBuilder<(ins "StringRef":$mesh, "ArrayRef<MeshAxis>":$axes)>
   ];
+  let hasCanonicalizer = 1;
 }
 
 def Mesh_ProcessLinearIndexOp : Mesh_Op<"process_linear_index", [
@@ -198,7 +199,7 @@ def Mesh_ShardingOp : Mesh_Op<"sharding", [
     `?` indicates dynamic halo sizes.
     
     6. [Optional] Offsets for each shard and sharded tensor dimension.
-    `sharded_dims_offsets`is provided as a flattened 1d array of i64s:
+    `sharded_dims_offsets` is provided as a flattened 1d array of i64s:
     For each sharded tensor dimension the offsets (starting index) of all shards in that dimension.
     The offset of each first shard is omitted and is implicitly assumed to be 0.
     The last value per dimension denotes the end of the last shard.
@@ -285,6 +286,7 @@ def Mesh_ShardingOp : Mesh_Op<"sharding", [
     OpBuilder<(ins "mlir::mesh::MeshSharding":$from)>
   ];
   let hasVerifier = 1;
+  let hasCanonicalizer = 1;
 }
 
 def Mesh_ShardShapeOp : Mesh_Op<"shard_shape", [Pure]> {
diff --git a/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp b/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
index b10f8ac98a27d1..fb952bff077a65 100644
--- a/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
+++ b/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
@@ -547,6 +547,42 @@ LogicalResult ShardingOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
   return success();
 }
 
+namespace {
+class FoldDynamicLists final : public OpRewritePattern<ShardingOp> {
+public:
+  using OpRewritePattern<ShardingOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(ShardingOp op,
+                                PatternRewriter &b) const override {
+    auto mixedHalos =
+        getMixedValues(op.getStaticHaloSizes(), op.getDynamicHaloSizes(), b);
+    auto mixedOffs = getMixedValues(op.getStaticShardedDimsOffsets(),
+                                    op.getDynamicShardedDimsOffsets(), b);
+
+    // No constant operands were folded; nothing to rewrite.
+    if (failed(foldDynamicIndexList(mixedHalos, /*onlyNonNegative=*/true)) &&
+        failed(foldDynamicIndexList(mixedOffs, /*onlyNonNegative=*/true))) {
+      return failure();
+    }
+
+    auto halos = decomposeMixedValues(mixedHalos);
+    auto offs = decomposeMixedValues(mixedOffs);
+
+    op.setStaticHaloSizes(halos.first);
+    op.getDynamicHaloSizesMutable().assign(halos.second);
+    op.setStaticShardedDimsOffsets(offs.first);
+    op.getDynamicShardedDimsOffsetsMutable().assign(offs.second);
+
+    return success();
+  }
+};
+} // namespace
+
+void ShardingOp::getCanonicalizationPatterns(mlir::RewritePatternSet &results,
+                                             mlir::MLIRContext *context) {
+  results.add<FoldDynamicLists>(context);
+}
+
 //===----------------------------------------------------------------------===//
 // MeshSharding
 //===----------------------------------------------------------------------===//
@@ -740,6 +776,62 @@ void ProcessMultiIndexOp::getAsmResultNames(
   setNameFn(getResults()[0], "proc_linear_idx");
 }
 
+namespace {
+#ifndef NDEBUG
+static std::vector<int> convertStringToVector(const std::string &str) {
+  std::vector<int> result;
+  std::stringstream ss(str);
+  std::string item;
+  while (std::getline(ss, item, ',')) {
+    result.push_back(std::stoi(item));
+  }
+  return result;
+}
+#endif // NDEBUG
+
+std::optional<SmallVector<Value>> getMyMultiIndex(OpBuilder &b,
+                                                  ::mlir::mesh::MeshOp mesh) {
+#ifndef NDEBUG
+  if (auto envStr = getenv("DEBUG_MESH_INDEX")) {
+    auto myIdx = convertStringToVector(envStr);
+    if (myIdx.size() == mesh.getShape().size()) {
+      SmallVector<Value> idxs;
+      for (auto i : myIdx) {
+        idxs.push_back(b.create<::mlir::arith::ConstantOp>(mesh->getLoc(),
+                                                           b.getIndexAttr(i)));
+      }
+      return idxs;
+    } else {
+      mesh->emitError() << "DEBUG_MESH_INDEX has wrong size";
+    }
+  }
+#endif // NDEBUG
+  return std::nullopt;
+}
+
+class FoldStaticIndex final : public OpRewritePattern<ProcessMultiIndexOp> {
+public:
+  using OpRewritePattern<ProcessMultiIndexOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(ProcessMultiIndexOp op,
+                                PatternRewriter &b) const override {
+#ifndef NDEBUG
+    SymbolTableCollection tmp;
+    if (auto idxs = getMyMultiIndex(b, getMesh(op, tmp))) {
+      b.replaceOp(op, idxs.value());
+      return success();
+    }
+#endif // NDEBUG
+    return failure();
+  }
+};
+} // namespace
+
+void ProcessMultiIndexOp::getCanonicalizationPatterns(
+    mlir::RewritePatternSet &results, mlir::MLIRContext *context) {
+  results.add<FoldStaticIndex>(context);
+}
+
 //===----------------------------------------------------------------------===//
 // mesh.process_linear_index op
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp b/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp
index 6684fbfa0aec0b..c1eda4cb257238 100644
--- a/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp
+++ b/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp
@@ -443,22 +443,27 @@ tryUpdateHaloInResharding(ImplicitLocOpBuilder &builder, MeshOp mesh,
       !sourceSharding.equalHaloSizes(targetSharding)) {
     auto srcHaloSizes = sourceSharding.getStaticHaloSizes();
     auto tgtHaloSizes = targetSharding.getStaticHaloSizes();
+    assert(srcHaloSizes.empty() || srcHaloSizes.size() == tgtHaloSizes.size());
     assert(
         ((srcHaloSizes.empty() || !ShapedType::isDynamicShape(srcHaloSizes)) &&
          !ShapedType::isDynamicShape(tgtHaloSizes) &&
          sourceShard.getType().hasStaticShape()) &&
         "dynamic shapes/halos are not supported yet for mesh-spmdization");
     auto rank = sourceShard.getType().getRank();
-    SmallVector<int64_t> outShape, srcCoreOffs(rank, 0), tgtCoreOffs,
-        strides(rank, 1), coreShape(sourceShard.getType().getShape());
+    auto splitAxes = sourceSharding.getSplitAxes();
+    SmallVector<int64_t> srcCoreOffs(rank, 0), tgtCoreOffs(rank, 0),
+        strides(rank, 1), outShape(sourceShard.getType().getShape()),
+        coreShape(sourceShard.getType().getShape());
     for (auto i = 0u; i < rank; ++i) {
-      if (!srcHaloSizes.empty()) {
-        coreShape[i] -= srcHaloSizes[i * 2] + srcHaloSizes[i * 2 + 1];
-        srcCoreOffs[i] = srcHaloSizes[i * 2];
+      if (i < splitAxes.size() && !splitAxes[i].empty()) {
+        if (!srcHaloSizes.empty()) {
+          coreShape[i] -= srcHaloSizes[i * 2] + srcHaloSizes[i * 2 + 1];
+          srcCoreOffs[i] = srcHaloSizes[i * 2];
+        }
+        tgtCoreOffs[i] = tgtHaloSizes[i * 2];
+        outShape[i] =
+            coreShape[i] + tgtHaloSizes[i * 2] + tgtHaloSizes[i * 2 + 1];
       }
-      tgtCoreOffs.emplace_back(tgtHaloSizes[i * 2]);
-      outShape.emplace_back(coreShape[i] + tgtHaloSizes[i * 2] +
-                            tgtHaloSizes[i * 2 + 1]);
     }
     auto noVals = ValueRange{};
     auto initVal = builder.create<tensor::EmptyOp>(
@@ -539,6 +544,10 @@ TypedValue<ShapedType> reshard(ImplicitLocOpBuilder &builder, MeshOp mesh,
                                MeshSharding targetSharding,
                                TypedValue<ShapedType> sourceUnshardedValue,
                                TypedValue<ShapedType> sourceShard) {
+  if (sourceSharding == targetSharding) {
+    return sourceShard;
+  }
+
   if (auto tryRes = tryUpdateHaloInResharding(
           builder, mesh, sourceSharding, targetSharding,
           sourceUnshardedValue.getType(), sourceShard)) {
diff --git a/mlir/test/Dialect/Mesh/canonicalization.mlir b/mlir/test/Dialect/Mesh/canonicalization.mlir
index ea2bd29056ec78..0bd09056835d30 100644
--- a/mlir/test/Dialect/Mesh/canonicalization.mlir
+++ b/mlir/test/Dialect/Mesh/canonicalization.mlir
@@ -191,3 +191,20 @@ func.func @send_empty_mesh_axes(
 // CHECK: return %[[ARG]]
   return %0 : tensor<4xf32>
 }
+
+mesh.mesh @mesh4x4(shape = 4x4)
+// CHECK-LABEL: func @test_halo_sizes
+func.func @test_halo_sizes() -> !mesh.sharding {
+  %c2_i64 = arith.constant 2 : i64
+  // CHECK: mesh.sharding @mesh4x4 split_axes = {{\[\[}}0], [1]] halo_sizes = [1, 2, 2, 22] : !mesh.sharding
+  %sharding = mesh.sharding @mesh4x4 split_axes = [[0], [1]] halo_sizes = [1, %c2_i64, %c2_i64, 22] : !mesh.sharding
+  return %sharding : !mesh.sharding
+}
+
+// CHECK-LABEL: func @test_shard_offs
+func.func @test_shard_offs() -> !mesh.sharding {
+  %c2_i64 = arith.constant 2 : i64
+  // CHECK: mesh.sharding @mesh4x4 split_axes = {{\[\[}}0], [1]] sharded_dims_offsets = [1, 2, 2, 22] : !mesh.sharding
+  %sharding = mesh.sharding @mesh4x4 split_axes = [[0], [1]] sharded_dims_offsets = [1, %c2_i64, %c2_i64, 22] : !mesh.sharding
+  return %sharding : !mesh.sharding
+}
\ No newline at end of file

>From 5e4cadaf26ecdba44b2d59f561fca982ac80274a Mon Sep 17 00:00:00 2001
From: "Schlimbach, Frank" <frank.schlimbach at intel.com>
Date: Thu, 17 Oct 2024 14:00:22 +0200
Subject: [PATCH 05/26] removing ProcessMultiIndexOp canonicalizer, fixing
 MeshSharding.equalSplitAndPartialAxes

---
 mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td  | 15 +++--
 mlir/lib/Dialect/Mesh/IR/MeshOps.cpp          | 63 ++-----------------
 .../Dialect/Mesh/Transforms/Spmdization.cpp   | 11 ++--
 3 files changed, 21 insertions(+), 68 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td
index c967bb74d5c1a8..32af92b99bcd71 100644
--- a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td
+++ b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td
@@ -132,7 +132,6 @@ def Mesh_ProcessMultiIndexOp : Mesh_Op<"process_multi_index", [
     OpBuilder<(ins "::mlir::mesh::MeshOp":$mesh)>,
     OpBuilder<(ins "StringRef":$mesh, "ArrayRef<MeshAxis>":$axes)>
   ];
-  let hasCanonicalizer = 1;
 }
 
 def Mesh_ProcessLinearIndexOp : Mesh_Op<"process_linear_index", [
@@ -1061,7 +1060,8 @@ def Mesh_UpdateHaloOp : Mesh_Op<"update_halo", [
   TypesMatchWith<
     "result has same type as destination",
     "result", "destination", "$_self">,
-  DeclareOpInterfaceMethods<SymbolUserOpInterface>
+  DeclareOpInterfaceMethods<SymbolUserOpInterface>,
+  AttrSizedOperandSegments
 ]> {
   let summary = "Update halo data.";
   let description = [{
@@ -1071,13 +1071,13 @@ def Mesh_UpdateHaloOp : Mesh_Op<"update_halo", [
     and/or if the new halo regions are larger than the existing ones.
 
     Assumes all devices hold tensors with same-sized halo data as specified
-    by `source_halo_sizes/static_source_halo_sizes`.
+    by `source_halo_sizes/static_source_halo_sizes` and
+    `destination_halo_sizes/static_destination_halo_sizes`.
 
     `split_axes` specifies for each tensor axis along which mesh axes its halo
     data is updated.
 
-    The destination halo sizes are allowed differ from the source sizes. The sizes
-    of the inner (local) shard is inferred from the destination shape and source sharding.
+    Source and destination might have different halo sizes.
   }];
   let arguments = (ins
     AnyTypeOf<[AnyNon0RankedMemRef, AnyNon0RankedTensor]>:$source,
@@ -1085,7 +1085,9 @@ def Mesh_UpdateHaloOp : Mesh_Op<"update_halo", [
     FlatSymbolRefAttr:$mesh,
     Mesh_MeshAxesArrayAttr:$split_axes,
     Variadic<I64>:$source_halo_sizes,
-    DefaultValuedAttr<DenseI64ArrayAttr, "{}">:$static_source_halo_sizes
+    DefaultValuedAttr<DenseI64ArrayAttr, "{}">:$static_source_halo_sizes,
+    Variadic<I64>:$destination_halo_sizes,
+    DefaultValuedAttr<DenseI64ArrayAttr, "{}">:$static_destination_halo_sizes
   );
   let results = (outs
     AnyTypeOf<[AnyNon0RankedMemRef, AnyNon0RankedTensor]>:$result
@@ -1095,6 +1097,7 @@ def Mesh_UpdateHaloOp : Mesh_Op<"update_halo", [
     `on` $mesh
     `split_axes` `=` $split_axes
     (`source_halo_sizes` `=` custom<DynamicIndexList>($source_halo_sizes, $static_source_halo_sizes)^)?
+    (`destination_halo_sizes` `=` custom<DynamicIndexList>($destination_halo_sizes, $static_destination_halo_sizes)^)?
     attr-dict `:` type($source) `->` type($result)
   }];
   let extraClassDeclaration = [{
diff --git a/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp b/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
index fb952bff077a65..d65f7e4bbadd1a 100644
--- a/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
+++ b/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
@@ -592,7 +592,12 @@ bool MeshSharding::equalSplitAndPartialAxes(const MeshSharding &rhs) const {
     return false;
   }
 
-  if (!getPartialAxes().empty() && getPartialType() != rhs.getPartialType()) {
+  if (getPartialAxes().size() != rhs.getPartialAxes().size() ||
+      (!getPartialAxes().empty() && getPartialType() != rhs.getPartialType()) ||
+      !llvm::equal(
+          llvm::make_range(getPartialAxes().begin(), getPartialAxes().end()),
+          llvm::make_range(rhs.getPartialAxes().begin(),
+                           rhs.getPartialAxes().end()))) {
     return false;
   }
 
@@ -776,62 +781,6 @@ void ProcessMultiIndexOp::getAsmResultNames(
   setNameFn(getResults()[0], "proc_linear_idx");
 }
 
-namespace {
-#ifndef NDEBUG
-static std::vector<int> convertStringToVector(const std::string &str) {
-  std::vector<int> result;
-  std::stringstream ss(str);
-  std::string item;
-  while (std::getline(ss, item, ',')) {
-    result.push_back(std::stoi(item));
-  }
-  return result;
-}
-#endif // NDEBUG
-
-std::optional<SmallVector<Value>> getMyMultiIndex(OpBuilder &b,
-                                                  ::mlir::mesh::MeshOp mesh) {
-#ifndef NDEBUG
-  if (auto envStr = getenv("DEBUG_MESH_INDEX")) {
-    auto myIdx = convertStringToVector(envStr);
-    if (myIdx.size() == mesh.getShape().size()) {
-      SmallVector<Value> idxs;
-      for (auto i : myIdx) {
-        idxs.push_back(b.create<::mlir::arith::ConstantOp>(mesh->getLoc(),
-                                                           b.getIndexAttr(i)));
-      }
-      return idxs;
-    } else {
-      mesh->emitError() << "DEBUG_MESH_INDEX has wrong size";
-    }
-  }
-#endif // NDEBUG
-  return std::nullopt;
-}
-
-class FoldStaticIndex final : public OpRewritePattern<ProcessMultiIndexOp> {
-public:
-  using OpRewritePattern<ProcessMultiIndexOp>::OpRewritePattern;
-
-  LogicalResult matchAndRewrite(ProcessMultiIndexOp op,
-                                PatternRewriter &b) const override {
-#ifndef NDEBUG
-    SymbolTableCollection tmp;
-    if (auto idxs = getMyMultiIndex(b, getMesh(op, tmp))) {
-      b.replaceOp(op, idxs.value());
-      return success();
-    }
-#endif // NDEBUG
-    return failure();
-  }
-};
-} // namespace
-
-void ProcessMultiIndexOp::getCanonicalizationPatterns(
-    mlir::RewritePatternSet &results, mlir::MLIRContext *context) {
-  results.add<FoldStaticIndex>(context);
-}
-
 //===----------------------------------------------------------------------===//
 // mesh.process_linear_index op
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp b/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp
index c1eda4cb257238..6d7d036b2ea2e4 100644
--- a/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp
+++ b/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp
@@ -483,7 +483,9 @@ tryUpdateHaloInResharding(ImplicitLocOpBuilder &builder, MeshOp mesh,
         MeshAxesArrayAttr::get(builder.getContext(),
                                sourceSharding.getSplitAxes()),
         sourceSharding.getDynamicHaloSizes(),
-        sourceSharding.getStaticHaloSizes());
+        sourceSharding.getStaticHaloSizes(),
+        targetSharding.getDynamicHaloSizes(),
+        targetSharding.getStaticHaloSizes());
     return std::make_tuple(
         cast<TypedValue<ShapedType>>(targetShard.getResult()), targetSharding);
   }
@@ -568,10 +570,9 @@ TypedValue<ShapedType> reshard(OpBuilder &builder, MeshOp mesh, ShardOp source,
   auto sourceSharding = source.getSharding();
   auto targetSharding = target.getSharding();
   ImplicitLocOpBuilder implicitLocOpBuilder(target->getLoc(), builder);
-  auto shard =
-      reshard(implicitLocOpBuilder, mesh, sourceSharding, targetSharding,
-              cast<TypedValue<ShapedType>>(source.getSrc()), sourceShardValue);
-  return shard;
+  return reshard(implicitLocOpBuilder, mesh, sourceSharding, targetSharding,
+                 cast<TypedValue<ShapedType>>(source.getSrc()),
+                 sourceShardValue);
 }
 
 TypedValue<ShapedType> reshard(OpBuilder &builder, ShardOp source,

>From 1cce571499b1ee67b28be7aef9c10fce2d1b0596 Mon Sep 17 00:00:00 2001
From: "Schlimbach, Frank" <frank.schlimbach at intel.com>
Date: Fri, 18 Oct 2024 17:40:38 +0200
Subject: [PATCH 06/26] fixing resharding

---
 mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp |  8 ++++----
 mlir/test/Dialect/Mesh/spmdization.mlir          | 16 ++++++++++++++++
 2 files changed, 20 insertions(+), 4 deletions(-)

diff --git a/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp b/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp
index 6d7d036b2ea2e4..89ba282731c44a 100644
--- a/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp
+++ b/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp
@@ -476,7 +476,7 @@ tryUpdateHaloInResharding(ImplicitLocOpBuilder &builder, MeshOp mesh,
     auto initOprnd = builder.create<tensor::InsertSliceOp>(
         sourceShard.getLoc(), core, initVal, noVals, noVals, noVals,
         tgtCoreOffs, coreShape, strides);
-    auto targetShard = builder.create<UpdateHaloOp>(
+    auto updateHaloResult = builder.create<UpdateHaloOp>(
         sourceShard.getLoc(),
         RankedTensorType::get(outShape, sourceShard.getType().getElementType()),
         sourceShard, initOprnd, mesh.getSymName(),
@@ -485,9 +485,9 @@ tryUpdateHaloInResharding(ImplicitLocOpBuilder &builder, MeshOp mesh,
         sourceSharding.getDynamicHaloSizes(),
         sourceSharding.getStaticHaloSizes(),
         targetSharding.getDynamicHaloSizes(),
-        targetSharding.getStaticHaloSizes());
+        targetSharding.getStaticHaloSizes()).getResult();
     return std::make_tuple(
-        cast<TypedValue<ShapedType>>(targetShard.getResult()), targetSharding);
+        cast<TypedValue<ShapedType>>(updateHaloResult), targetSharding);
   }
   return std::nullopt;
 }
@@ -710,7 +710,7 @@ spmdizeOperation(ShardOp shardOp, IRMapping &spmdizationMap,
   } else {
     // Insert resharding.
     TypedValue<ShapedType> srcSpmdValue = cast<TypedValue<ShapedType>>(
-        spmdizationMap.lookup(srcShardOp.getSrc()));
+        spmdizationMap.lookup(srcShardOp));
     targetSpmdValue = reshard(builder, srcShardOp, shardOp, srcSpmdValue,
                               symbolTableCollection);
   }
diff --git a/mlir/test/Dialect/Mesh/spmdization.mlir b/mlir/test/Dialect/Mesh/spmdization.mlir
index 8b0c4053b0dc7e..16224441ffbb7d 100644
--- a/mlir/test/Dialect/Mesh/spmdization.mlir
+++ b/mlir/test/Dialect/Mesh/spmdization.mlir
@@ -219,3 +219,19 @@ func.func @ew_chain_with_halo(
   // CHECK-NEXT: return %[[TMP3]] : tensor<5x16xf32>
   return %sharding_annotated_6 : tensor<8x16xf32>
 }
+
+// CHECK-LABEL: func @test_shard_update_halo
+// CHECK-SAME: %[[IN1:[A-Za-z0-9_]+]]: tensor<300x1200xi64>
+func.func @test_shard_update_halo(%arg0: tensor<1200x1200xi64>) -> tensor<1200x1200xi64> {
+  %sharding = mesh.sharding @mesh_1d_4 split_axes = [[0]] : !mesh.sharding
+  // CHECK: %[[T:.*]] = tensor.empty() : tensor<304x1200xi64>
+  // CHECK: %[[inserted_slice:.*]] = tensor.insert_slice %[[IN1]] into %[[T]][2, 0] [300, 1200] [1, 1] : tensor<300x1200xi64> into tensor<304x1200xi64>
+  // CHECK: %[[UH:.*]] = mesh.update_halo %[[IN1]] into %[[inserted_slice]] on @mesh_1d_4 split_axes = {{\[\[0]]}} destination_halo_sizes = [2, 2] : tensor<300x1200xi64> -> tensor<304x1200xi64>
+  %sharding_annotated = mesh.shard %arg0 to %sharding : tensor<1200x1200xi64>
+  %sharding_0 = mesh.sharding @mesh_1d_4 split_axes = [[0]] halo_sizes = [2, 2] : !mesh.sharding
+  %sharding_annotated_1 = mesh.shard %sharding_annotated to %sharding_0 : tensor<1200x1200xi64>
+  %sharding_2 = mesh.sharding @mesh_1d_4 split_axes = [[0]] sharded_dims_offsets = [298, 598, 898, 1000] : !mesh.sharding
+  %sharding_annotated_3 = mesh.shard %sharding_annotated_1 to %sharding_0 annotate_for_users : tensor<1200x1200xi64>
+  // CHECK: return %[[UH]] : tensor<304x1200xi64>
+  return %sharding_annotated_3 : tensor<1200x1200xi64>
+}

>From a0aa3ebccd8bbc7f12d242995fed39ab242906dd Mon Sep 17 00:00:00 2001
From: "Schlimbach, Frank" <frank.schlimbach at intel.com>
Date: Wed, 30 Oct 2024 13:03:34 +0100
Subject: [PATCH 07/26] adding test update_halo 2d

---
 mlir/test/Dialect/Mesh/spmdization.mlir | 17 ++++++++++++++++-
 1 file changed, 16 insertions(+), 1 deletion(-)

diff --git a/mlir/test/Dialect/Mesh/spmdization.mlir b/mlir/test/Dialect/Mesh/spmdization.mlir
index 16224441ffbb7d..22ddb72569835d 100644
--- a/mlir/test/Dialect/Mesh/spmdization.mlir
+++ b/mlir/test/Dialect/Mesh/spmdization.mlir
@@ -230,8 +230,23 @@ func.func @test_shard_update_halo(%arg0: tensor<1200x1200xi64>) -> tensor<1200x1
   %sharding_annotated = mesh.shard %arg0 to %sharding : tensor<1200x1200xi64>
   %sharding_0 = mesh.sharding @mesh_1d_4 split_axes = [[0]] halo_sizes = [2, 2] : !mesh.sharding
   %sharding_annotated_1 = mesh.shard %sharding_annotated to %sharding_0 : tensor<1200x1200xi64>
-  %sharding_2 = mesh.sharding @mesh_1d_4 split_axes = [[0]] sharded_dims_offsets = [298, 598, 898, 1000] : !mesh.sharding
   %sharding_annotated_3 = mesh.shard %sharding_annotated_1 to %sharding_0 annotate_for_users : tensor<1200x1200xi64>
   // CHECK: return %[[UH]] : tensor<304x1200xi64>
   return %sharding_annotated_3 : tensor<1200x1200xi64>
 }
+
+mesh.mesh @mesh4x4(shape = 4x4)
+// CHECK-LABEL: func @test_shard_update_halo2d
+// CHECK-SAME: %[[IN1:[A-Za-z0-9_]+]]: tensor<300x300xi64>
+func.func @test_shard_update_halo2d(%arg0: tensor<1200x1200xi64>) -> tensor<1200x1200xi64> {
+  %sharding = mesh.sharding @mesh4x4 split_axes = [[0], [1]] : !mesh.sharding
+  // CHECK: %[[T:.*]] = tensor.empty() : tensor<303x307xi64>
+  // CHECK: %[[inserted_slice:.*]] = tensor.insert_slice %[[IN1]] into %[[T]][1, 3] [300, 300] [1, 1] : tensor<300x300xi64> into tensor<303x307xi64>
+  // CHECK: %[[UH:.*]] = mesh.update_halo %[[IN1]] into %[[inserted_slice]] on @mesh4x4 split_axes = {{\[\[}}0], [1]] destination_halo_sizes = [1, 2, 3, 4] : tensor<300x300xi64> -> tensor<303x307xi64>
+  %sharding_annotated = mesh.shard %arg0 to %sharding : tensor<1200x1200xi64>
+  %sharding_0 = mesh.sharding @mesh4x4 split_axes = [[0], [1]] halo_sizes = [1, 2, 3, 4] : !mesh.sharding
+  %sharding_annotated_1 = mesh.shard %sharding_annotated to %sharding_0 : tensor<1200x1200xi64>
+  %sharding_annotated_3 = mesh.shard %sharding_annotated_1 to %sharding_0 annotate_for_users : tensor<1200x1200xi64>
+  // CHECK: return %[[UH]] : tensor<303x307xi64>
+  return %sharding_annotated_3 : tensor<1200x1200xi64>
+}
\ No newline at end of file

>From 13c590c515a34ae699baac8c386d6749aace778e Mon Sep 17 00:00:00 2001
From: "Schlimbach, Frank" <frank.schlimbach at intel.com>
Date: Wed, 30 Oct 2024 15:03:19 +0100
Subject: [PATCH 08/26] comments/docs

---
 mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td     |  8 +++++---
 mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp | 16 ++++++++++++++++
 2 files changed, 21 insertions(+), 3 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td
index 32af92b99bcd71..04b4b55a433803 100644
--- a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td
+++ b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td
@@ -1066,18 +1066,20 @@ def Mesh_UpdateHaloOp : Mesh_Op<"update_halo", [
   let summary = "Update halo data.";
   let description = [{
     This operation updates halo regions of shards, e.g. if their sharding
-    specified halos and the actual tensor data might have changed
+    specified halos and the actual tensor/memref data might have changed
     on the remote devices. Changes might be caused by mutating operations
     and/or if the new halo regions are larger than the existing ones.
 
+    Source and destination might have different halo sizes.
+
     Assumes all devices hold tensors with same-sized halo data as specified
     by `source_halo_sizes/static_source_halo_sizes` and
-    `destination_halo_sizes/static_destination_halo_sizes`
+    `destination_halo_sizes/static_destination_halo_sizes` in source shard
+    and destination/result shard.
 
     `split_axes` specifies for each tensor axis along which mesh axes its halo
     data is updated.
 
-    Source and destination might have different halo sizes.
   }];
   let arguments = (ins
     AnyTypeOf<[AnyNon0RankedMemRef, AnyNon0RankedTensor]>:$source,
diff --git a/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp b/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp
index 89ba282731c44a..6a25f18e6afc5e 100644
--- a/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp
+++ b/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp
@@ -429,12 +429,18 @@ tryMoveLastSplitAxisInResharding(ImplicitLocOpBuilder &builder, MeshOp mesh,
   return std::nullopt;
 }
 
+// Detect a change in the halo size (only) and create necessary operations if
+// needed. Changed halo sizes require copying the "core" of the source tensor
+// into the "core" of the destination tensor followed by an update halo
+// operation.
 static std::optional<std::tuple<TypedValue<ShapedType>, MeshSharding>>
 tryUpdateHaloInResharding(ImplicitLocOpBuilder &builder, MeshOp mesh,
                           MeshSharding sourceSharding,
                           MeshSharding targetSharding,
                           ShapedType sourceUnshardedShape,
                           TypedValue<ShapedType> sourceShard) {
+  // currently handles only cases where halo sizes differ but everything else
+  // stays the same (from source to destination sharding)
   if (sourceSharding.equalSplitAndPartialAxes(targetSharding) &&
       sourceSharding.getPartialAxes().empty() &&
       targetSharding.getPartialAxes().empty() &&
@@ -454,6 +460,9 @@ tryUpdateHaloInResharding(ImplicitLocOpBuilder &builder, MeshOp mesh,
     SmallVector<int64_t> srcCoreOffs(rank, 0), tgtCoreOffs(rank, 0),
         strides(rank, 1), outShape(sourceShard.getType().getShape()),
         coreShape(sourceShard.getType().getShape());
+
+    // determine "core" of source and destination
+    // the core is the local part of the shard excluding halo regions
     for (auto i = 0u; i < rank; ++i) {
       if (i < splitAxes.size() && !splitAxes[i].empty()) {
         if (!srcHaloSizes.empty()) {
@@ -465,6 +474,8 @@ tryUpdateHaloInResharding(ImplicitLocOpBuilder &builder, MeshOp mesh,
             coreShape[i] + tgtHaloSizes[i * 2] + tgtHaloSizes[i * 2 + 1];
       }
     }
+
+    // extract core from source and copy into destination core
     auto noVals = ValueRange{};
     auto initVal = builder.create<tensor::EmptyOp>(
         sourceShard.getLoc(), outShape, sourceShard.getType().getElementType());
@@ -476,6 +487,8 @@ tryUpdateHaloInResharding(ImplicitLocOpBuilder &builder, MeshOp mesh,
     auto initOprnd = builder.create<tensor::InsertSliceOp>(
         sourceShard.getLoc(), core, initVal, noVals, noVals, noVals,
         tgtCoreOffs, coreShape, strides);
+
+    // finally update the halo
     auto updateHaloResult = builder.create<UpdateHaloOp>(
         sourceShard.getLoc(),
         RankedTensorType::get(outShape, sourceShard.getType().getElementType()),
@@ -546,10 +559,13 @@ TypedValue<ShapedType> reshard(ImplicitLocOpBuilder &builder, MeshOp mesh,
                                MeshSharding targetSharding,
                                TypedValue<ShapedType> sourceUnshardedValue,
                                TypedValue<ShapedType> sourceShard) {
+  // If source and destination sharding are the same, no need to do anything.
   if (sourceSharding == targetSharding) {
     return sourceShard;
   }
 
+  // tries to handle the case where the resharding is needed because the halo
+  // sizes are different. Supports arbitrary mesh dimensionality.
   if (auto tryRes = tryUpdateHaloInResharding(
           builder, mesh, sourceSharding, targetSharding,
           sourceUnshardedValue.getType(), sourceShard)) {

>From 8b209ce10e9495a75f0a0a04c58a939ab687a841 Mon Sep 17 00:00:00 2001
From: "Schlimbach, Frank" <frank.schlimbach at intel.com>
Date: Wed, 30 Oct 2024 15:38:39 +0100
Subject: [PATCH 09/26] clang-format

---
 .../Dialect/Mesh/Transforms/Spmdization.cpp   | 32 +++++++++++--------
 1 file changed, 18 insertions(+), 14 deletions(-)

diff --git a/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp b/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp
index 6a25f18e6afc5e..1444eec0f727ff 100644
--- a/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp
+++ b/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp
@@ -489,18 +489,22 @@ tryUpdateHaloInResharding(ImplicitLocOpBuilder &builder, MeshOp mesh,
         tgtCoreOffs, coreShape, strides);
 
     // finally update the halo
-    auto updateHaloResult = builder.create<UpdateHaloOp>(
-        sourceShard.getLoc(),
-        RankedTensorType::get(outShape, sourceShard.getType().getElementType()),
-        sourceShard, initOprnd, mesh.getSymName(),
-        MeshAxesArrayAttr::get(builder.getContext(),
-                               sourceSharding.getSplitAxes()),
-        sourceSharding.getDynamicHaloSizes(),
-        sourceSharding.getStaticHaloSizes(),
-        targetSharding.getDynamicHaloSizes(),
-        targetSharding.getStaticHaloSizes()).getResult();
-    return std::make_tuple(
-        cast<TypedValue<ShapedType>>(updateHaloResult), targetSharding);
+    auto updateHaloResult =
+        builder
+            .create<UpdateHaloOp>(
+                sourceShard.getLoc(),
+                RankedTensorType::get(outShape,
+                                      sourceShard.getType().getElementType()),
+                sourceShard, initOprnd, mesh.getSymName(),
+                MeshAxesArrayAttr::get(builder.getContext(),
+                                       sourceSharding.getSplitAxes()),
+                sourceSharding.getDynamicHaloSizes(),
+                sourceSharding.getStaticHaloSizes(),
+                targetSharding.getDynamicHaloSizes(),
+                targetSharding.getStaticHaloSizes())
+            .getResult();
+    return std::make_tuple(cast<TypedValue<ShapedType>>(updateHaloResult),
+                           targetSharding);
   }
   return std::nullopt;
 }
@@ -725,8 +729,8 @@ spmdizeOperation(ShardOp shardOp, IRMapping &spmdizationMap,
     targetSpmdValue = spmdizationMap.lookup(shardOp.getSrc());
   } else {
     // Insert resharding.
-    TypedValue<ShapedType> srcSpmdValue = cast<TypedValue<ShapedType>>(
-        spmdizationMap.lookup(srcShardOp));
+    TypedValue<ShapedType> srcSpmdValue =
+        cast<TypedValue<ShapedType>>(spmdizationMap.lookup(srcShardOp));
     targetSpmdValue = reshard(builder, srcShardOp, shardOp, srcSpmdValue,
                               symbolTableCollection);
   }

>From 1991285f9b0f6da9cbb3ac2b695bb609a13c031d Mon Sep 17 00:00:00 2001
From: Frank Schlimbach <frank.schlimbach at intel.com>
Date: Mon, 4 Nov 2024 11:09:53 +0100
Subject: [PATCH 10/26] formatting

Co-authored-by: Matteo Franciolini <m.franciolini at icloud.com>
---
 mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td
index 04b4b55a433803..47008cb72119ba 100644
--- a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td
+++ b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td
@@ -198,7 +198,7 @@ def Mesh_ShardingOp : Mesh_Op<"sharding", [
     `?` indicates dynamic halo sizes.
     
     6. [Optional] Offsets for each shard and sharded tensor dimension.
-    `sharded_dims_offsets` is provided as a flattened 1d array of i64s:
+    `sharded_dims_offsets` is provided as a flattened 1d array of i64s.
     For each sharded tensor dimension the offsets (starting index) of all shards in that dimension.
     The offset of each first shard is omitted and is implicitly assumed to be 0.
     The last value per dimension denotes the end of the last shard.

>From 130c0e00eb0dfa80bec8a71697cd13f173a7429e Mon Sep 17 00:00:00 2001
From: Frank Schlimbach <frank.schlimbach at intel.com>
Date: Mon, 4 Nov 2024 11:15:02 +0100
Subject: [PATCH 11/26] Formatting

Co-authored-by: Matteo Franciolini <m.franciolini at icloud.com>
---
 mlir/lib/Dialect/Mesh/IR/MeshOps.cpp | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp b/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
index d65f7e4bbadd1a..14371314b12d5b 100644
--- a/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
+++ b/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
@@ -204,7 +204,7 @@ static void shardShape(const InShape &inShape, const MeshShape &meshShape,
         auto sz = shardedDimsOffsets[pos];
         bool same = !ShapedType::isDynamicShape(meshShape);
         if (same) {
-          // find sharded dims in shardedDimsOffsets with same static size on
+          // Find sharded dims in shardedDimsOffsets with same static size on
           // all devices. Use kDynamic for dimensions with dynamic or
           // non-uniform offs in shardedDimsOffsets.
           uint64_t numShards = 0;

>From 942bddd18aadc747a4ffc4301baed9b6148bab39 Mon Sep 17 00:00:00 2001
From: Frank Schlimbach <frank.schlimbach at intel.com>
Date: Mon, 4 Nov 2024 11:16:13 +0100
Subject: [PATCH 12/26] cut&paste error

Co-authored-by: Chengji Yao <yaochengji at hotmail.com>
---
 mlir/test/Dialect/Mesh/invalid.mlir | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/mlir/test/Dialect/Mesh/invalid.mlir b/mlir/test/Dialect/Mesh/invalid.mlir
index f8abc6959a13a4..148cf474a9aef3 100644
--- a/mlir/test/Dialect/Mesh/invalid.mlir
+++ b/mlir/test/Dialect/Mesh/invalid.mlir
@@ -89,7 +89,7 @@ func.func @sharding_attribute_invalid_halo(%arg0 : tensor<4x8xf32>) {
 // -----
 
 func.func @sharding_attribute_invalid_sizes(%arg0 : tensor<4x8xf32>) {
-  // expected-error at +1 {{halo sizes and shard shapes are mutually exclusive}}
+  // expected-error at +1 {{halo sizes and shard offsets are mutually exclusive}}
   %s = mesh.sharding @mesh0 split_axes = [[0]] halo_sizes = [1, 2] sharded_dims_offsets = [2, 2] : !mesh.sharding
   %0 = mesh.shard %arg0 to %s : tensor<4x8xf32>
   return

>From fa7d412dd5411cafd94f2da30ab39f0a90b82518 Mon Sep 17 00:00:00 2001
From: Frank Schlimbach <frank.schlimbach at intel.com>
Date: Mon, 4 Nov 2024 12:02:54 +0100
Subject: [PATCH 13/26] cut&paste error in doc

Co-authored-by: Matteo Franciolini <m.franciolini at icloud.com>
---
 mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td
index 47008cb72119ba..4490f93e33d16e 100644
--- a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td
+++ b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td
@@ -243,7 +243,7 @@ def Mesh_ShardingOp : Mesh_Op<"sharding", [
     // The tensor is sharded on its second dimension along axis 0 of @mesh1d_4
     // and it has pre-defined shard sizes. The shards of the devices will have
     // the following shapes: [4x2, 4x3, 4x4, 4x5]
-    %sharding4 = mesh.sharding @mesh1d_4 split_axes = [[] split_axes = [0]] sharded_dims_offsets = [2, 5, 9, 14]
+    %sharding4 = mesh.sharding @mesh1d_4 split_axes = [[], [0]] sharded_dims_offsets = [2, 5, 9, 14]
     %sharded2 = mesh.shard %arg0 to %sharding4 : tensor<4x14xf32>
     ```
   }];

>From 86945fd0aae325651459eae28c9e46d6e923e4d4 Mon Sep 17 00:00:00 2001
From: "Schlimbach, Frank" <frank.schlimbach at intel.com>
Date: Mon, 4 Nov 2024 12:08:41 +0100
Subject: [PATCH 14/26] Better capitalization and punctuation in comments.

---
 mlir/lib/Dialect/Mesh/IR/MeshOps.cpp             |  2 +-
 mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp | 16 ++++++++--------
 2 files changed, 9 insertions(+), 9 deletions(-)

diff --git a/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp b/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
index 14371314b12d5b..f993f10e60ec6a 100644
--- a/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
+++ b/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
@@ -511,7 +511,7 @@ LogicalResult ShardingOp::verify() {
     return failure();
 
   if (!getStaticHaloSizes().empty() && !getStaticShardedDimsOffsets().empty()) {
-    return emitOpError("halo sizes and shard shapes are mutually exclusive");
+    return emitOpError("halo sizes and shard offsets are mutually exclusive");
   }
 
   if (!getStaticHaloSizes().empty()) {
diff --git a/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp b/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp
index 1444eec0f727ff..3ce60777bf7300 100644
--- a/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp
+++ b/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp
@@ -126,7 +126,7 @@ static MeshSharding targetShardingInSplitLastAxis(MLIRContext *ctx,
 }
 
 // Split a replicated tensor along a mesh axis.
-// e.g. [[0, 1]] -> [[0, 1, 2]].
+// E.g. [[0, 1]] -> [[0, 1, 2]].
 // Returns the spmdized target value with its sharding.
 static std::tuple<TypedValue<ShapedType>, MeshSharding>
 splitLastAxisInResharding(ImplicitLocOpBuilder &builder,
@@ -439,8 +439,8 @@ tryUpdateHaloInResharding(ImplicitLocOpBuilder &builder, MeshOp mesh,
                           MeshSharding targetSharding,
                           ShapedType sourceUnshardedShape,
                           TypedValue<ShapedType> sourceShard) {
-  // currently handles only cases where halo sizes differ but everything else
-  // stays the same (from source to destination sharding)
+  // Currently handles only cases where halo sizes differ but everything else
+  // stays the same (from source to destination sharding).
   if (sourceSharding.equalSplitAndPartialAxes(targetSharding) &&
       sourceSharding.getPartialAxes().empty() &&
       targetSharding.getPartialAxes().empty() &&
@@ -461,8 +461,8 @@ tryUpdateHaloInResharding(ImplicitLocOpBuilder &builder, MeshOp mesh,
         strides(rank, 1), outShape(sourceShard.getType().getShape()),
         coreShape(sourceShard.getType().getShape());
 
-    // determine "core" of source and destination
-    // the core is the local part of the shard excluding halo regions
+    // Determine "core" of source and destination.
+    // The core is the local part of the shard excluding halo regions.
     for (auto i = 0u; i < rank; ++i) {
       if (i < splitAxes.size() && !splitAxes[i].empty()) {
         if (!srcHaloSizes.empty()) {
@@ -475,7 +475,7 @@ tryUpdateHaloInResharding(ImplicitLocOpBuilder &builder, MeshOp mesh,
       }
     }
 
-    // extract core from source and copy into destination core
+    // Extract core from source and copy into destination core.
     auto noVals = ValueRange{};
     auto initVal = builder.create<tensor::EmptyOp>(
         sourceShard.getLoc(), outShape, sourceShard.getType().getElementType());
@@ -488,7 +488,7 @@ tryUpdateHaloInResharding(ImplicitLocOpBuilder &builder, MeshOp mesh,
         sourceShard.getLoc(), core, initVal, noVals, noVals, noVals,
         tgtCoreOffs, coreShape, strides);
 
-    // finally update the halo
+    // Finally update the halo.
     auto updateHaloResult =
         builder
             .create<UpdateHaloOp>(
@@ -568,7 +568,7 @@ TypedValue<ShapedType> reshard(ImplicitLocOpBuilder &builder, MeshOp mesh,
     return sourceShard;
   }
 
-  // tries to handle the case where the resharding is needed because the halo
+  // Tries to handle the case where the resharding is needed because the halo
   // sizes are different. Supports arbitrary mesh dimensionality.
   if (auto tryRes = tryUpdateHaloInResharding(
           builder, mesh, sourceSharding, targetSharding,

>From 17a96edeb3dc42e30e7441fe456cb5bd869ca834 Mon Sep 17 00:00:00 2001
From: "Schlimbach, Frank" <frank.schlimbach at intel.com>
Date: Mon, 4 Nov 2024 12:15:42 +0100
Subject: [PATCH 15/26] early return and reduce the indentation block

---
 .../Dialect/Mesh/Transforms/Spmdization.cpp   | 125 +++++++++---------
 1 file changed, 62 insertions(+), 63 deletions(-)

diff --git a/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp b/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp
index 3ce60777bf7300..b4d088cbd7088d 100644
--- a/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp
+++ b/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp
@@ -441,72 +441,71 @@ tryUpdateHaloInResharding(ImplicitLocOpBuilder &builder, MeshOp mesh,
                           TypedValue<ShapedType> sourceShard) {
   // Currently handles only cases where halo sizes differ but everything else
   // stays the same (from source to destination sharding).
-  if (sourceSharding.equalSplitAndPartialAxes(targetSharding) &&
-      sourceSharding.getPartialAxes().empty() &&
-      targetSharding.getPartialAxes().empty() &&
-      sourceSharding.getStaticShardedDimsOffsets().empty() &&
-      targetSharding.getStaticShardedDimsOffsets().empty() &&
-      !sourceSharding.equalHaloSizes(targetSharding)) {
-    auto srcHaloSizes = sourceSharding.getStaticHaloSizes();
-    auto tgtHaloSizes = targetSharding.getStaticHaloSizes();
-    assert(srcHaloSizes.empty() || srcHaloSizes.size() == tgtHaloSizes.size());
-    assert(
-        ((srcHaloSizes.empty() || !ShapedType::isDynamicShape(srcHaloSizes)) &&
-         !ShapedType::isDynamicShape(tgtHaloSizes) &&
-         sourceShard.getType().hasStaticShape()) &&
-        "dynamic shapes/halos are not supported yet for mesh-spmdization");
-    auto rank = sourceShard.getType().getRank();
-    auto splitAxes = sourceSharding.getSplitAxes();
-    SmallVector<int64_t> srcCoreOffs(rank, 0), tgtCoreOffs(rank, 0),
-        strides(rank, 1), outShape(sourceShard.getType().getShape()),
-        coreShape(sourceShard.getType().getShape());
-
-    // Determine "core" of source and destination.
-    // The core is the local part of the shard excluding halo regions.
-    for (auto i = 0u; i < rank; ++i) {
-      if (i < splitAxes.size() && !splitAxes[i].empty()) {
-        if (!srcHaloSizes.empty()) {
-          coreShape[i] -= srcHaloSizes[i * 2] + srcHaloSizes[i * 2 + 1];
-          srcCoreOffs[i] = srcHaloSizes[i * 2];
-        }
-        tgtCoreOffs[i] = tgtHaloSizes[i * 2];
-        outShape[i] =
-            coreShape[i] + tgtHaloSizes[i * 2] + tgtHaloSizes[i * 2 + 1];
+  if (!sourceSharding.equalSplitAndPartialAxes(targetSharding) ||
+      !sourceSharding.getPartialAxes().empty() ||
+      !targetSharding.getPartialAxes().empty() ||
+      !sourceSharding.getStaticShardedDimsOffsets().empty() ||
+      !targetSharding.getStaticShardedDimsOffsets().empty() ||
+      sourceSharding.equalHaloSizes(targetSharding)) {
+    return std::nullopt;
+  }
+
+  auto srcHaloSizes = sourceSharding.getStaticHaloSizes();
+  auto tgtHaloSizes = targetSharding.getStaticHaloSizes();
+  assert(srcHaloSizes.empty() || srcHaloSizes.size() == tgtHaloSizes.size());
+  assert(((srcHaloSizes.empty() || !ShapedType::isDynamicShape(srcHaloSizes)) &&
+          !ShapedType::isDynamicShape(tgtHaloSizes) &&
+          sourceShard.getType().hasStaticShape()) &&
+         "dynamic shapes/halos are not supported yet for mesh-spmdization");
+  auto rank = sourceShard.getType().getRank();
+  auto splitAxes = sourceSharding.getSplitAxes();
+  SmallVector<int64_t> srcCoreOffs(rank, 0), tgtCoreOffs(rank, 0),
+      strides(rank, 1), outShape(sourceShard.getType().getShape()),
+      coreShape(sourceShard.getType().getShape());
+
+  // Determine "core" of source and destination.
+  // The core is the local part of the shard excluding halo regions.
+  for (auto i = 0u; i < rank; ++i) {
+    if (i < splitAxes.size() && !splitAxes[i].empty()) {
+      if (!srcHaloSizes.empty()) {
+        coreShape[i] -= srcHaloSizes[i * 2] + srcHaloSizes[i * 2 + 1];
+        srcCoreOffs[i] = srcHaloSizes[i * 2];
       }
+      tgtCoreOffs[i] = tgtHaloSizes[i * 2];
+      outShape[i] =
+          coreShape[i] + tgtHaloSizes[i * 2] + tgtHaloSizes[i * 2 + 1];
     }
-
-    // Extract core from source and copy into destination core.
-    auto noVals = ValueRange{};
-    auto initVal = builder.create<tensor::EmptyOp>(
-        sourceShard.getLoc(), outShape, sourceShard.getType().getElementType());
-    auto core = builder.create<tensor::ExtractSliceOp>(
-        sourceShard.getLoc(),
-        RankedTensorType::get(coreShape,
-                              sourceShard.getType().getElementType()),
-        sourceShard, noVals, noVals, noVals, srcCoreOffs, coreShape, strides);
-    auto initOprnd = builder.create<tensor::InsertSliceOp>(
-        sourceShard.getLoc(), core, initVal, noVals, noVals, noVals,
-        tgtCoreOffs, coreShape, strides);
-
-    // Finally update the halo.
-    auto updateHaloResult =
-        builder
-            .create<UpdateHaloOp>(
-                sourceShard.getLoc(),
-                RankedTensorType::get(outShape,
-                                      sourceShard.getType().getElementType()),
-                sourceShard, initOprnd, mesh.getSymName(),
-                MeshAxesArrayAttr::get(builder.getContext(),
-                                       sourceSharding.getSplitAxes()),
-                sourceSharding.getDynamicHaloSizes(),
-                sourceSharding.getStaticHaloSizes(),
-                targetSharding.getDynamicHaloSizes(),
-                targetSharding.getStaticHaloSizes())
-            .getResult();
-    return std::make_tuple(cast<TypedValue<ShapedType>>(updateHaloResult),
-                           targetSharding);
   }
-  return std::nullopt;
+
+  // Extract core from source and copy into destination core.
+  auto noVals = ValueRange{};
+  auto initVal = builder.create<tensor::EmptyOp>(
+      sourceShard.getLoc(), outShape, sourceShard.getType().getElementType());
+  auto core = builder.create<tensor::ExtractSliceOp>(
+      sourceShard.getLoc(),
+      RankedTensorType::get(coreShape, sourceShard.getType().getElementType()),
+      sourceShard, noVals, noVals, noVals, srcCoreOffs, coreShape, strides);
+  auto initOprnd = builder.create<tensor::InsertSliceOp>(
+      sourceShard.getLoc(), core, initVal, noVals, noVals, noVals, tgtCoreOffs,
+      coreShape, strides);
+
+  // Finally update the halo.
+  auto updateHaloResult =
+      builder
+          .create<UpdateHaloOp>(
+              sourceShard.getLoc(),
+              RankedTensorType::get(outShape,
+                                    sourceShard.getType().getElementType()),
+              sourceShard, initOprnd, mesh.getSymName(),
+              MeshAxesArrayAttr::get(builder.getContext(),
+                                     sourceSharding.getSplitAxes()),
+              sourceSharding.getDynamicHaloSizes(),
+              sourceSharding.getStaticHaloSizes(),
+              targetSharding.getDynamicHaloSizes(),
+              targetSharding.getStaticHaloSizes())
+          .getResult();
+  return std::make_tuple(cast<TypedValue<ShapedType>>(updateHaloResult),
+                         targetSharding);
 }
 
 // Handles only resharding on a 1D mesh.

>From fb6335f84164bda95297f677815f35e674c6a7ab Mon Sep 17 00:00:00 2001
From: "Schlimbach, Frank" <frank.schlimbach at intel.com>
Date: Mon, 4 Nov 2024 12:22:59 +0100
Subject: [PATCH 16/26] doc FoldDynamicLists

---
 mlir/lib/Dialect/Mesh/IR/MeshOps.cpp | 4 ++++
 1 file changed, 4 insertions(+)

diff --git a/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp b/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
index f993f10e60ec6a..57460882daa082 100644
--- a/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
+++ b/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
@@ -548,6 +548,10 @@ LogicalResult ShardingOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
 }
 
 namespace {
+// Sharding annotations "halo sizes" and "sharded dims offsets"
+// are a mix of attributes and dynamic values. This canonicalization moves
+// constant values to the respective attribute lists and so minimizes the number
+// of values.
 class FoldDynamicLists final : public OpRewritePattern<ShardingOp> {
 public:
   using OpRewritePattern<ShardingOp>::OpRewritePattern;

>From 07887db7b13cf0e9816816158e55a39c849d89ea Mon Sep 17 00:00:00 2001
From: "Schlimbach, Frank" <frank.schlimbach at intel.com>
Date: Tue, 5 Nov 2024 13:13:08 +0100
Subject: [PATCH 17/26] sharded_dims_offsets with exclusive prefix sum at pos i
 for shard i

---
 mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td  | 32 ++++++++++---------
 mlir/lib/Dialect/Mesh/IR/MeshOps.cpp          |  2 +-
 mlir/test/Dialect/Mesh/canonicalization.mlir  |  4 +--
 mlir/test/Dialect/Mesh/invalid.mlir           |  4 +--
 mlir/test/Dialect/Mesh/ops.mlir               |  8 ++---
 .../test/Dialect/Tensor/mesh-spmdization.mlir | 10 +++---
 6 files changed, 31 insertions(+), 29 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td
index 4490f93e33d16e..19498fe5a32d69 100644
--- a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td
+++ b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td
@@ -190,23 +190,25 @@ def Mesh_ShardingOp : Mesh_Op<"sharding", [
     `generic`: is not an allowed value inside a shard attribute.
 
     5. [Optional] Sizes of halos to be added for each sharded tensor dimension.
-    `halo_sizes`is provided as a flattened 1d array of i64s, 2 values for each sharded dimension.
-    `halo_sizes` = [1, 2] means that the first sharded dimension gets an additional
-    halo of size 1 at the start of the first dimension and a halo size is 2 at its end.
-    `halo_sizes` = [1, 2, 2, 3] defines halos for the first 2 sharded dimensions
-    e.g. the first sharded dimension gets [1,2] halos and the seconds gets [2,3] halos.
-    `?` indicates dynamic halo sizes.
+    `halo_sizes` is provided as a flattened 1d array of i64s, 2 values for each
+    sharded dimension. `halo_sizes = [1, 2]` means that the first sharded dimension
+    gets an additional halo of size 1 at the start of the first dimension and a halo
+    of size 2 at its end. `halo_sizes = [1, 2, 2, 3]` defines halos for the first 2
+    sharded dimensions, e.g. the first sharded dimension gets `[1,2]` halos and the
+    second gets `[2,3]` halos. `?` indicates dynamic halo sizes.
     
     6. [Optional] Offsets for each shard and sharded tensor dimension.
-    `sharded_dims_offsets` is provided as a flattened 1d array of i64s.
-    For each sharded tensor dimension the offsets (starting index) of all shards in that dimension.
-    The offset of each first shard is omitted and is implicitly assumed to be 0.
-    The last value per dimension denotes the end of the last shard.
+    `sharded_dims_offsets` is provided as a flattened 1d array of i64s. For each
+    sharded tensor dimension the offsets (starting index) of all shards in that
+    dimension and an additional value for the end of the last shard are provided.
+    For a 1d sharding this means that position `i` has the exclusive prefix sum for
+    shard `i`, and since only contiguous sharding is supported, its inclusive prefix
+    sum is at position `i+1`.
+    
     Assuming a 3d-tensor of shape 32x32x32 with the first 2 dimensions being sharded,
-    `sharded_dims_offsets` = [24, 32, 20, 32] means that the first device of
-    the device-mesh will get a shard of shape 24x20x32 and the second device will get a
-    shard of shape 8x12x32.
-    `?` indicates dynamic shard dimensions.
+    `sharded_dims_offsets` = [0, 24, 32, 0, 20, 32] means that the first device of
+    the device-mesh will get a shard of shape 24x20x32 and the second device will get
+    a shard of shape 8x12x32. `?` indicates dynamic shard dimensions.
     
     `halo_sizes` and `sharded_dims_offsets` are mutually exclusive.
 
@@ -243,7 +245,7 @@ def Mesh_ShardingOp : Mesh_Op<"sharding", [
     // The tensor is sharded on its second dimension along axis 0 of @mesh1d_4
     // and it has pre-defined shard sizes. The shards of the devices will have
     // the following shapes: [4x2, 4x3, 4x4, 4x5]
-    %sharding4 = mesh.sharding @mesh1d_4 split_axes = [[], [0]] sharded_dims_offsets = [2, 5, 9, 14]
+    %sharding4 = mesh.sharding @mesh1d_4 split_axes = [[], [0]] sharded_dims_offsets = [0, 2, 5, 9, 14]
     %sharded2 = mesh.shard %arg0 to %sharding4 : tensor<4x14xf32>
     ```
   }];
diff --git a/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp b/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
index 57460882daa082..42ab14a19eaf58 100644
--- a/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
+++ b/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
@@ -198,7 +198,7 @@ static void shardShape(const InShape &inShape, const MeshShape &meshShape,
             llvm::adl_begin(outShape));
 
   if (!shardedDimsOffsets.empty()) {
-    uint64_t pos = 0;
+    uint64_t pos = 1;
     for (auto [tensorAxis, innerSplitAxes] : llvm::enumerate(splitAxes)) {
       if (!innerSplitAxes.empty()) {
         auto sz = shardedDimsOffsets[pos];
diff --git a/mlir/test/Dialect/Mesh/canonicalization.mlir b/mlir/test/Dialect/Mesh/canonicalization.mlir
index 0bd09056835d30..54d3e779d03727 100644
--- a/mlir/test/Dialect/Mesh/canonicalization.mlir
+++ b/mlir/test/Dialect/Mesh/canonicalization.mlir
@@ -204,7 +204,7 @@ func.func @test_halo_sizes() -> !mesh.sharding {
 // CHECK-LABEL: func @test_shard_offs
 func.func @test_shard_offs() -> !mesh.sharding {
   %c2_i64 = arith.constant 2 : i64
-  // CHECK mesh.sharding @mesh4x4 split_axes = [[0], [1]] sharded_dims_offsets = [1, 2, 2, 22] : !mesh.sharding
-  %sharding = mesh.sharding @mesh4x4 split_axes = [[0], [1]] sharded_dims_offsets = [1, %c2_i64, %c2_i64, 22] : !mesh.sharding
+  // CHECK mesh.sharding @mesh4x4 split_axes = [[0], [1]] sharded_dims_offsets = [0, 1, 2, 0, 2, 22] : !mesh.sharding
+  %sharding = mesh.sharding @mesh4x4 split_axes = [[0], [1]] sharded_dims_offsets = [0, 1, %c2_i64, 0, %c2_i64, 22] : !mesh.sharding
   return %sharding : !mesh.sharding
 }
\ No newline at end of file
diff --git a/mlir/test/Dialect/Mesh/invalid.mlir b/mlir/test/Dialect/Mesh/invalid.mlir
index 148cf474a9aef3..08e5248bda1981 100644
--- a/mlir/test/Dialect/Mesh/invalid.mlir
+++ b/mlir/test/Dialect/Mesh/invalid.mlir
@@ -90,7 +90,7 @@ func.func @sharding_attribute_invalid_halo(%arg0 : tensor<4x8xf32>) {
 
 func.func @sharding_attribute_invalid_sizes(%arg0 : tensor<4x8xf32>) {
   // expected-error at +1 {{halo sizes and shard offsets are mutually exclusive}}
-  %s = mesh.sharding @mesh0 split_axes = [[0]] halo_sizes = [1, 2] sharded_dims_offsets = [2, 2] : !mesh.sharding
+  %s = mesh.sharding @mesh0 split_axes = [[0]] halo_sizes = [1, 2] sharded_dims_offsets = [0, 2, 2] : !mesh.sharding
   %0 = mesh.shard %arg0 to %s : tensor<4x8xf32>
   return
 }
@@ -100,7 +100,7 @@ func.func @sharding_attribute_invalid_sizes(%arg0 : tensor<4x8xf32>) {
 mesh.mesh @mesh_dyn(shape = ?x?)
 func.func @sharding_dyn_mesh_and_sizes(%arg0 : tensor<4x8xf32>) {
   // expected-error at +1 {{sharded dims offsets are not allowed for devices meshes with dynamic shape}}
-  %s = mesh.sharding @mesh_dyn split_axes = [[0]] sharded_dims_offsets = [2, 2] : !mesh.sharding
+  %s = mesh.sharding @mesh_dyn split_axes = [[0]] sharded_dims_offsets = [0, 2, 2] : !mesh.sharding
   %0 = mesh.shard %arg0 to %s : tensor<4x8xf32>
   return
 }
diff --git a/mlir/test/Dialect/Mesh/ops.mlir b/mlir/test/Dialect/Mesh/ops.mlir
index 3b0ef39a6e212e..d8df01c3d6520d 100644
--- a/mlir/test/Dialect/Mesh/ops.mlir
+++ b/mlir/test/Dialect/Mesh/ops.mlir
@@ -144,10 +144,10 @@ func.func @mesh_shard_halo_sizes() -> () {
 func.func @mesh_shard_dims_sizes() -> () {
   // CHECK: %[[C3:.*]] = arith.constant 3 : i64
   %c3 = arith.constant 3 : i64
-  // CHECK: mesh.sharding @mesh4 split_axes = {{\[\[}}0]] sharded_dims_offsets = [1, 4, 2] : !mesh.sharding
-  %sharding1 = mesh.sharding @mesh4 split_axes = [[0]] sharded_dims_offsets = [1, 4, 2] : !mesh.sharding
-  // CHECK: mesh.sharding @mesh4 split_axes = {{\[\[}}0]] sharded_dims_offsets = [4, %[[C3]], 1] : !mesh.sharding
-  %sharding2 = mesh.sharding @mesh4 split_axes = [[0]] sharded_dims_offsets = [4, %c3, 1] : !mesh.sharding
+  // CHECK: mesh.sharding @mesh4 split_axes = {{\[\[}}0]] sharded_dims_offsets = [0, 1, 4, 6] : !mesh.sharding
+  %sharding1 = mesh.sharding @mesh4 split_axes = [[0]] sharded_dims_offsets = [0, 1, 4, 6] : !mesh.sharding
+  // CHECK: mesh.sharding @mesh4 split_axes = {{\[\[}}0]] sharded_dims_offsets = [0, 2, %[[C3]], 5] : !mesh.sharding
+  %sharding2 = mesh.sharding @mesh4 split_axes = [[0]] sharded_dims_offsets = [0, 2, %c3, 5] : !mesh.sharding
   return
 }
 
diff --git a/mlir/test/Dialect/Tensor/mesh-spmdization.mlir b/mlir/test/Dialect/Tensor/mesh-spmdization.mlir
index fba9e17b53934a..5443eea83aa2d8 100644
--- a/mlir/test/Dialect/Tensor/mesh-spmdization.mlir
+++ b/mlir/test/Dialect/Tensor/mesh-spmdization.mlir
@@ -7,9 +7,9 @@ mesh.mesh @mesh_1d_4(shape = 4)
 // CHECK-LABEL: func @tensor_empty_static_sharded_dims_offsets
 func.func @tensor_empty_static_sharded_dims_offsets() -> () {
   %b = tensor.empty() : tensor<8x16xf32>
-  %sharding = mesh.sharding @mesh_1d_4 split_axes = [[0]] sharded_dims_offsets = [1, 4, 7, 8] : !mesh.sharding
+  %sharding = mesh.sharding @mesh_1d_4 split_axes = [[0]] sharded_dims_offsets = [0, 1, 4, 7, 8] : !mesh.sharding
   %sharded= mesh.shard %b to %sharding : tensor<8x16xf32>
-  // CHECK:  %[[sharding:.*]] = mesh.sharding @mesh_1d_4 split_axes = {{\[\[}}0]] sharded_dims_offsets = [1, 4, 7, 8] : !mesh.sharding
+  // CHECK:  %[[sharding:.*]] = mesh.sharding @mesh_1d_4 split_axes = {{\[\[}}0]] sharded_dims_offsets = [0, 1, 4, 7, 8] : !mesh.sharding
   // CHECK:  %[[proc_linear_idx:.*]] = mesh.process_linear_index on @mesh_1d_4 : index
   // CHECK:  %[[V0:.*]]:2 = mesh.shard_shape 8x16 %[[sharding]] %[[proc_linear_idx]] : index, index
   // CHECK:  tensor.empty(%[[V0]]#0) : tensor<?x16xf32>
@@ -21,9 +21,9 @@ func.func @tensor_empty_static_sharded_dims_offsets() -> () {
 // CHECK-SAME: %[[A0:.*]]: index
 func.func @tensor_empty_dynamic_sharded_dims_offsets(%arg0 : index) -> () {
   %b = tensor.empty(%arg0) : tensor<8x?xf32>
-  %sharding = mesh.sharding @mesh_1d_4 split_axes = [[0]] sharded_dims_offsets = [1, 4, 7, 8] : !mesh.sharding
+  %sharding = mesh.sharding @mesh_1d_4 split_axes = [[0]] sharded_dims_offsets = [0, 1, 4, 7, 8] : !mesh.sharding
   %sharded= mesh.shard %b to %sharding : tensor<8x?xf32>
-  // CHECK:  %[[sharding:.*]] = mesh.sharding @mesh_1d_4 split_axes = {{\[\[}}0]] sharded_dims_offsets = [1, 4, 7, 8] : !mesh.sharding
+  // CHECK:  %[[sharding:.*]] = mesh.sharding @mesh_1d_4 split_axes = {{\[\[}}0]] sharded_dims_offsets = [0, 1, 4, 7, 8] : !mesh.sharding
   // CHECK:  %[[proc_linear_idx:.*]] = mesh.process_linear_index on @mesh_1d_4 : index
   // CHECK:  %[[V0:.*]]:2 = mesh.shard_shape 8x? %[[sharding]] %[[proc_linear_idx]] : index, index
   // CHECK:  tensor.empty(%[[V0]]#0, %[[A0]]) : tensor<?x?xf32>
@@ -34,7 +34,7 @@ func.func @tensor_empty_dynamic_sharded_dims_offsets(%arg0 : index) -> () {
 // CHECK-LABEL: func @tensor_empty_same_static_dims_sizes
 func.func @tensor_empty_same_static_dims_sizes() -> () {
   %b = tensor.empty() : tensor<16x16xf32>
-  %sharding = mesh.sharding @mesh_1d_4 split_axes = [[0]] sharded_dims_offsets = [4, 8, 12, 16] : !mesh.sharding
+  %sharding = mesh.sharding @mesh_1d_4 split_axes = [[0]] sharded_dims_offsets = [0, 4, 8, 12, 16] : !mesh.sharding
   %sharded= mesh.shard %b to %sharding : tensor<16x16xf32>
   // CHECK-NEXT:  tensor.empty() : tensor<4x16xf32>
 

>From b6ff1a8a1ed993441f83617a240440d5144c7eca Mon Sep 17 00:00:00 2001
From: "Schlimbach, Frank" <frank.schlimbach at intel.com>
Date: Tue, 5 Nov 2024 13:53:44 +0100
Subject: [PATCH 18/26] checking for non-decreasing dims sizes in sharding

---
 mlir/lib/Dialect/Mesh/IR/MeshOps.cpp         | 33 ++++++++++++++++++--
 mlir/test/Dialect/Mesh/canonicalization.mlir |  4 +--
 mlir/test/Dialect/Mesh/invalid.mlir          | 20 ++++++++++++
 3 files changed, 53 insertions(+), 4 deletions(-)

diff --git a/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp b/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
index 42ab14a19eaf58..c5570d8ee8a443 100644
--- a/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
+++ b/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
@@ -198,11 +198,12 @@ static void shardShape(const InShape &inShape, const MeshShape &meshShape,
             llvm::adl_begin(outShape));
 
   if (!shardedDimsOffsets.empty()) {
+    auto isDynShape = ShapedType::isDynamicShape(meshShape);
     uint64_t pos = 1;
     for (auto [tensorAxis, innerSplitAxes] : llvm::enumerate(splitAxes)) {
       if (!innerSplitAxes.empty()) {
         auto sz = shardedDimsOffsets[pos];
-        bool same = !ShapedType::isDynamicShape(meshShape);
+        bool same = !isDynShape;
         if (same) {
           // Find sharded dims in shardedDimsOffsets with same static size on
           // all devices. Use kDynamic for dimensions with dynamic or
@@ -218,7 +219,7 @@ static void shardShape(const InShape &inShape, const MeshShape &meshShape,
               break;
             }
           }
-          pos += numShards;
+          pos += numShards + 1;
         }
         outShape[tensorAxis] = same ? sz : ShapedType::kDynamic;
       }
@@ -544,6 +545,34 @@ LogicalResult ShardingOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
     return emitError() << "sharded dims offsets are not allowed for "
                           "devices meshes with dynamic shape.";
   }
+
+  auto shardedDimsOffsets = getStaticShardedDimsOffsets();
+  if (!shardedDimsOffsets.empty()) {
+    auto meshShape = mesh.value().getShape();
+    assert(!ShapedType::isDynamicShape(meshShape));
+    uint64_t pos = 0;
+    for (auto [tensorAxis, innerSplitAxes] : llvm::enumerate(getSplitAxes())) {
+      if (!innerSplitAxes.empty()) {
+        int64_t numShards = 0, off = 0;
+        for (auto i : innerSplitAxes.asArrayRef()) {
+          numShards += meshShape[i];
+        }
+        for (int64_t i = 0; i <= numShards; ++i) {
+          if (shardedDimsOffsets.size() <= pos + i) {
+            return emitError() << "sharded dims offsets has wrong size.";
+          }
+          if (!ShapedType::isDynamic(shardedDimsOffsets[pos + i])) {
+            if (shardedDimsOffsets[pos + i] < off) {
+              return emitError()
+                     << "sharded dims offsets must be non-decreasing.";
+            }
+            off = shardedDimsOffsets[pos + i];
+          }
+        }
+        pos += numShards + 1;
+      }
+    }
+  }
   return success();
 }
 
diff --git a/mlir/test/Dialect/Mesh/canonicalization.mlir b/mlir/test/Dialect/Mesh/canonicalization.mlir
index 54d3e779d03727..f0112d689805d3 100644
--- a/mlir/test/Dialect/Mesh/canonicalization.mlir
+++ b/mlir/test/Dialect/Mesh/canonicalization.mlir
@@ -204,7 +204,7 @@ func.func @test_halo_sizes() -> !mesh.sharding {
 // CHECK-LABEL: func @test_shard_offs
 func.func @test_shard_offs() -> !mesh.sharding {
   %c2_i64 = arith.constant 2 : i64
-  // CHECK mesh.sharding @mesh4x4 split_axes = [[0], [1]] sharded_dims_offsets = [0, 1, 2, 0, 2, 22] : !mesh.sharding
-  %sharding = mesh.sharding @mesh4x4 split_axes = [[0], [1]] sharded_dims_offsets = [0, 1, %c2_i64, 0, %c2_i64, 22] : !mesh.sharding
+  // CHECK mesh.sharding @mesh4x4 split_axes = [[0], [1]] sharded_dims_offsets = [0, 1, 2, 3, 4, 0, 2, 3, 4, 22] : !mesh.sharding
+  %sharding = mesh.sharding @mesh4x4 split_axes = [[0], [1]] sharded_dims_offsets = [0, 1, %c2_i64, 3, 4, 0, %c2_i64, 3, 4, 22] : !mesh.sharding
   return %sharding : !mesh.sharding
 }
\ No newline at end of file
diff --git a/mlir/test/Dialect/Mesh/invalid.mlir b/mlir/test/Dialect/Mesh/invalid.mlir
index 08e5248bda1981..29b900a8da4a60 100644
--- a/mlir/test/Dialect/Mesh/invalid.mlir
+++ b/mlir/test/Dialect/Mesh/invalid.mlir
@@ -107,6 +107,26 @@ func.func @sharding_dyn_mesh_and_sizes(%arg0 : tensor<4x8xf32>) {
 
 // -----
 
+mesh.mesh @mesh0(shape = 2x4)
+func.func @sharding_sizes_count(%arg0 : tensor<4x8xf32>) {
+  // expected-error at +1 {{sharded dims offsets has wrong size}}
+  %s = mesh.sharding @mesh0 split_axes = [[0], [1]] sharded_dims_offsets = [0, 2, 4, 0, 2, 4, 6] : !mesh.sharding
+  %0 = mesh.shard %arg0 to %s : tensor<4x8xf32>
+  return
+}
+
+// -----
+
+mesh.mesh @mesh0(shape = 4)
+func.func @sharding_sizes_decreasing(%arg0 : tensor<4x8xf32>) {
+  // expected-error at +1 {{sharded dims offsets must be non-decreasing}}
+  %s = mesh.sharding @mesh0 split_axes = [[0]] sharded_dims_offsets = [0, 2, 3, 2] : !mesh.sharding
+  %0 = mesh.shard %arg0 to %s : tensor<4x8xf32>
+  return
+}
+
+// -----
+
 mesh.mesh @mesh0(shape = 2x4)
 
 func.func @mesh_shape_mesh_axis_out_of_bounds() -> (index, index) {

>From 827c300974e5b7071ea08dba1fb7c119cda744c9 Mon Sep 17 00:00:00 2001
From: "Schlimbach, Frank" <frank.schlimbach at intel.com>
Date: Wed, 14 Aug 2024 19:29:23 +0200
Subject: [PATCH 19/26] initial hack lowering mesh.update_halo to MPI

---
 .../mlir/Conversion/MeshToMPI/MeshToMPI.h     |  27 +++
 mlir/include/mlir/Conversion/Passes.h         |   1 +
 mlir/include/mlir/Conversion/Passes.td        |  17 ++
 mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td  |  33 ++++
 mlir/lib/Conversion/CMakeLists.txt            |   1 +
 mlir/lib/Conversion/MeshToMPI/CMakeLists.txt  |  22 +++
 mlir/lib/Conversion/MeshToMPI/MeshToMPI.cpp   | 171 ++++++++++++++++++
 mlir/lib/Dialect/Mesh/IR/MeshOps.cpp          |  19 ++
 .../MeshToMPI/convert-mesh-to-mpi.mlir        |  34 ++++
 9 files changed, 325 insertions(+)
 create mode 100644 mlir/include/mlir/Conversion/MeshToMPI/MeshToMPI.h
 create mode 100644 mlir/lib/Conversion/MeshToMPI/CMakeLists.txt
 create mode 100644 mlir/lib/Conversion/MeshToMPI/MeshToMPI.cpp
 create mode 100644 mlir/test/Conversion/MeshToMPI/convert-mesh-to-mpi.mlir

diff --git a/mlir/include/mlir/Conversion/MeshToMPI/MeshToMPI.h b/mlir/include/mlir/Conversion/MeshToMPI/MeshToMPI.h
new file mode 100644
index 00000000000000..6a2c196da45577
--- /dev/null
+++ b/mlir/include/mlir/Conversion/MeshToMPI/MeshToMPI.h
@@ -0,0 +1,27 @@
+//===- MeshToMPI.h - Convert Mesh to MPI dialect --*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_CONVERSION_MESHTOMPI_MESHTOMPI_H
+#define MLIR_CONVERSION_MESHTOMPI_MESHTOMPI_H
+
+#include "mlir/Pass/Pass.h"
+#include "mlir/Support/LLVM.h"
+
+namespace mlir {
+class Pass;
+
+#define GEN_PASS_DECL_CONVERTMESHTOMPIPASS
+#include "mlir/Conversion/Passes.h.inc"
+
+/// Lowers Mesh communication operations (updateHalo, AllGather, ...)
+/// to MPI primitives.
+std::unique_ptr<Pass> createConvertMeshToMPIPass();
+
+} // namespace mlir
+
+#endif // MLIR_CONVERSION_MESHTOMPI_MESHTOMPI_H
\ No newline at end of file
diff --git a/mlir/include/mlir/Conversion/Passes.h b/mlir/include/mlir/Conversion/Passes.h
index 2ab32836c80b1c..b577aa83946f23 100644
--- a/mlir/include/mlir/Conversion/Passes.h
+++ b/mlir/include/mlir/Conversion/Passes.h
@@ -51,6 +51,7 @@
 #include "mlir/Conversion/MemRefToEmitC/MemRefToEmitCPass.h"
 #include "mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h"
 #include "mlir/Conversion/MemRefToSPIRV/MemRefToSPIRVPass.h"
+#include "mlir/Conversion/MeshToMPI/MeshToMPI.h"
 #include "mlir/Conversion/NVGPUToNVVM/NVGPUToNVVM.h"
 #include "mlir/Conversion/NVVMToLLVM/NVVMToLLVM.h"
 #include "mlir/Conversion/OpenACCToSCF/ConvertOpenACCToSCF.h"
diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td
index 4d272ba219c6f1..83e0c5a06c43f7 100644
--- a/mlir/include/mlir/Conversion/Passes.td
+++ b/mlir/include/mlir/Conversion/Passes.td
@@ -878,6 +878,23 @@ def ConvertMemRefToSPIRV : Pass<"convert-memref-to-spirv"> {
   ];
 }
 
+//===----------------------------------------------------------------------===//
+// MeshToMPI
+//===----------------------------------------------------------------------===//
+
+def ConvertMeshToMPIPass : Pass<"convert-mesh-to-mpi"> {
+  let summary = "Convert Mesh dialect to MPI dialect.";
+  let description = [{
+    This pass converts communication operations
+    from the Mesh dialect to operations from the MPI dialect.
+  }];
+  let dependentDialects = [
+    "memref::MemRefDialect",
+    "mpi::MPIDialect",
+    "scf::SCFDialect"
+  ];
+}
+
 //===----------------------------------------------------------------------===//
 // NVVMToLLVM
 //===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td
index 19498fe5a32d69..2c2b6e20f3654d 100644
--- a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td
+++ b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td
@@ -156,6 +156,39 @@ def Mesh_ProcessLinearIndexOp : Mesh_Op<"process_linear_index", [
   ];
 }
 
+def Mesh_NeighborsLinearIndicesOp : Mesh_Op<"neighbors_linear_indices", [
+  Pure,
+  DeclareOpInterfaceMethods<SymbolUserOpInterface>,
+  DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>
+]> {
+  let summary =
+      "For given split axes get the linear indices of the direct neighbor processes.";
+  let description = [{
+    Example:
+    ```
+    %idx = mesh.neighbors_linear_indices on @mesh for $device
+               split_axes = $split_axes : index
+    ```
+    Given `@mesh` with shape `(10, 20, 30)`,
+          `device` = `(1, 2, 3)`
+          `$split_axes` = `[1]`
+    it returns the linear indices of the processes at positions `(1, 1, 3)`: `633`
+    and `(1, 3, 3)`: `693`.
+
+    A negative value is returned if `$device` has no neighbor in the given
+    direction along the given `split_axes`.
+  }];
+  let arguments = (ins FlatSymbolRefAttr:$mesh,
+                       Variadic<Index>:$device,
+                       Mesh_MeshAxesAttr:$split_axes);
+  let results = (outs Index:$neighbor_down, Index:$neighbor_up);
+  let assemblyFormat =  [{
+      `on` $mesh `[` $device `]`
+      `split_axes` `=` $split_axes
+      attr-dict `:` type(results)
+  }];
+}
+
 //===----------------------------------------------------------------------===//
 // Sharding operations.
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Conversion/CMakeLists.txt b/mlir/lib/Conversion/CMakeLists.txt
index 6651d87162257f..62461c0cea08af 100644
--- a/mlir/lib/Conversion/CMakeLists.txt
+++ b/mlir/lib/Conversion/CMakeLists.txt
@@ -41,6 +41,7 @@ add_subdirectory(MathToSPIRV)
 add_subdirectory(MemRefToEmitC)
 add_subdirectory(MemRefToLLVM)
 add_subdirectory(MemRefToSPIRV)
+add_subdirectory(MeshToMPI)
 add_subdirectory(NVGPUToNVVM)
 add_subdirectory(NVVMToLLVM)
 add_subdirectory(OpenACCToSCF)
diff --git a/mlir/lib/Conversion/MeshToMPI/CMakeLists.txt b/mlir/lib/Conversion/MeshToMPI/CMakeLists.txt
new file mode 100644
index 00000000000000..95815a683f6d6a
--- /dev/null
+++ b/mlir/lib/Conversion/MeshToMPI/CMakeLists.txt
@@ -0,0 +1,22 @@
+add_mlir_conversion_library(MLIRMeshToMPI
+  MeshToMPI.cpp
+
+  ADDITIONAL_HEADER_DIRS
+  ${MLIR_MAIN_INCLUDE_DIR}/mlir/Conversion/MeshToMPI
+
+  DEPENDS
+  MLIRConversionPassIncGen
+
+  LINK_COMPONENTS
+  Core
+
+  LINK_LIBS PUBLIC
+  MLIRFuncDialect
+  MLIRIR
+  MLIRLinalgTransforms
+  MLIRMemRefDialect
+  MLIRPass
+  MLIRMeshDialect
+  MLIRMPIDialect
+  MLIRTransforms
+  )
diff --git a/mlir/lib/Conversion/MeshToMPI/MeshToMPI.cpp b/mlir/lib/Conversion/MeshToMPI/MeshToMPI.cpp
new file mode 100644
index 00000000000000..b4cf9da8497a2d
--- /dev/null
+++ b/mlir/lib/Conversion/MeshToMPI/MeshToMPI.cpp
@@ -0,0 +1,171 @@
+//===- MeshToMPI.cpp - Mesh to MPI  dialect conversion -----------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements a translation of Mesh communicatin ops tp MPI ops.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Conversion/MeshToMPI/MeshToMPI.h"
+
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/MPI/IR/MPI.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/Mesh/IR/MeshOps.h"
+#include "mlir/Dialect/SCF/IR/SCF.h"
+#include "mlir/Dialect/Utils/StaticValueUtils.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/BuiltinAttributes.h"
+#include "mlir/IR/BuiltinTypes.h"
+
+#define DEBUG_TYPE "mesh-to-mpi"
+#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
+
+namespace mlir {
+#define GEN_PASS_DEF_CONVERTMESHTOMPIPASS
+#include "mlir/Conversion/Passes.h.inc"
+} // namespace mlir
+
+using namespace mlir;
+using namespace mlir::mesh;
+
+namespace {
+struct ConvertMeshToMPIPass
+    : public impl::ConvertMeshToMPIPassBase<ConvertMeshToMPIPass> {
+  using Base::Base;
+
+  /// Run the dialect converter on the module.
+  void runOnOperation() override {
+    getOperation()->walk([&](UpdateHaloOp op) {
+      SymbolTableCollection symbolTableCollection;
+      OpBuilder builder(op);
+      auto loc = op.getLoc();
+
+      auto toValue = [&builder, &loc](OpFoldResult &v) {
+        return v.is<Value>()
+                   ? v.get<Value>()
+                   : builder.create<::mlir::arith::ConstantOp>(
+                         loc,
+                         builder.getIndexAttr(
+                             cast<IntegerAttr>(v.get<Attribute>()).getInt()));
+      };
+
+      auto array = op.getInput();
+      auto rank = array.getType().getRank();
+      auto mesh = op.getMesh();
+      auto meshOp = getMesh(op, symbolTableCollection);
+      auto haloSizes = getMixedValues(op.getStaticHaloSizes(),
+                                      op.getDynamicHaloSizes(), builder);
+      for (auto &sz : haloSizes) {
+        if (sz.is<Value>()) {
+          sz = builder
+                   .create<arith::IndexCastOp>(loc, builder.getIndexType(),
+                                               sz.get<Value>())
+                   .getResult();
+        }
+      }
+
+      SmallVector<OpFoldResult> offsets(rank, builder.getIndexAttr(0));
+      SmallVector<OpFoldResult> strides(rank, builder.getIndexAttr(1));
+      SmallVector<OpFoldResult> shape(rank);
+      for (auto [i, s] : llvm::enumerate(array.getType().getShape())) {
+        if (ShapedType::isDynamic(s)) {
+          shape[i] = builder.create<memref::DimOp>(loc, array, s).getResult();
+        } else {
+          shape[i] = builder.getIndexAttr(s);
+        }
+      }
+
+      auto tagAttr = builder.getI32IntegerAttr(91); // whatever
+      auto tag = builder.create<::mlir::arith::ConstantOp>(loc, tagAttr);
+      auto zeroAttr = builder.getI32IntegerAttr(0); // whatever
+      auto zero = builder.create<::mlir::arith::ConstantOp>(loc, zeroAttr);
+      SmallVector<Type> indexResultTypes(meshOp.getShape().size(),
+                                         builder.getIndexType());
+      auto myMultiIndex =
+          builder.create<ProcessMultiIndexOp>(loc, indexResultTypes, mesh)
+              .getResult();
+      auto currHaloDim = 0;
+
+      for (auto [dim, splitAxes] : llvm::enumerate(op.getSplitAxes())) {
+        if (!splitAxes.empty()) {
+          auto tmp = builder
+                         .create<NeighborsLinearIndicesOp>(
+                             loc, mesh, myMultiIndex, splitAxes)
+                         .getResults();
+          Value neighbourIDs[2] = {builder.create<arith::IndexCastOp>(
+                                       loc, builder.getI32Type(), tmp[0]),
+                                   builder.create<arith::IndexCastOp>(
+                                       loc, builder.getI32Type(), tmp[1])};
+          auto orgDimSize = shape[dim];
+          auto upperOffset = builder.create<arith::SubIOp>(
+              loc, toValue(shape[dim]), toValue(haloSizes[dim * 2 + 1]));
+
+          // make sure we send/recv in a way that does not lead to a dead-lock
+          // This is by far not optimal, this should be at least MPI_sendrecv
+          // and - probably even more importantly - buffers should be re-used
+          // Currently using temporary, contiguous buffer for MPI communication
+          auto genSendRecv = [&](auto dim, bool upperHalo) {
+            auto orgOffset = offsets[dim];
+            shape[dim] =
+                upperHalo ? haloSizes[dim * 2 + 1] : haloSizes[dim * 2];
+            auto to = upperHalo ? neighbourIDs[1] : neighbourIDs[0];
+            auto from = upperHalo ? neighbourIDs[0] : neighbourIDs[1];
+            auto hasFrom = builder.create<arith::CmpIOp>(
+                loc, arith::CmpIPredicate::sge, from, zero);
+            auto hasTo = builder.create<arith::CmpIOp>(
+                loc, arith::CmpIPredicate::sge, to, zero);
+            auto buffer = builder.create<memref::AllocOp>(
+                loc, shape, array.getType().getElementType());
+            builder.create<scf::IfOp>(
+                loc, hasTo, [&](OpBuilder &builder, Location loc) {
+                  offsets[dim] = upperHalo
+                                     ? OpFoldResult(builder.getIndexAttr(0))
+                                     : OpFoldResult(upperOffset);
+                  auto subview = builder.create<memref::SubViewOp>(
+                      loc, array, offsets, shape, strides);
+                  builder.create<memref::CopyOp>(loc, subview, buffer);
+                  builder.create<mpi::SendOp>(loc, TypeRange{}, buffer, tag,
+                                              to);
+                  builder.create<scf::YieldOp>(loc);
+                });
+            builder.create<scf::IfOp>(
+                loc, hasFrom, [&](OpBuilder &builder, Location loc) {
+                  offsets[dim] = upperHalo
+                                     ? OpFoldResult(upperOffset)
+                                     : OpFoldResult(builder.getIndexAttr(0));
+                  builder.create<mpi::RecvOp>(loc, TypeRange{}, buffer, tag,
+                                              from);
+                  auto subview = builder.create<memref::SubViewOp>(
+                      loc, array, offsets, shape, strides);
+                  builder.create<memref::CopyOp>(loc, buffer, subview);
+                  builder.create<scf::YieldOp>(loc);
+                });
+            builder.create<memref::DeallocOp>(loc, buffer);
+            offsets[dim] = orgOffset;
+          };
+
+          genSendRecv(dim, false);
+          genSendRecv(dim, true);
+
+          shape[dim] = builder
+                           .create<arith::SubIOp>(
+                               loc, toValue(orgDimSize),
+                               builder
+                                   .create<arith::AddIOp>(
+                                       loc, toValue(haloSizes[dim * 2]),
+                                       toValue(haloSizes[dim * 2 + 1]))
+                                   .getResult())
+                           .getResult();
+          offsets[dim] = haloSizes[dim * 2];
+          ++currHaloDim;
+        }
+      }
+    });
+  }
+};
+} // namespace
\ No newline at end of file
diff --git a/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp b/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
index c5570d8ee8a443..33460ff25e9e45 100644
--- a/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
+++ b/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
@@ -837,6 +837,25 @@ void ProcessLinearIndexOp::getAsmResultNames(
   setNameFn(getResult(), "proc_linear_idx");
 }
 
+//===----------------------------------------------------------------------===//
+// mesh.neighbors_linear_indices op
+//===----------------------------------------------------------------------===//
+
+LogicalResult
+NeighborsLinearIndicesOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
+  auto mesh = ::getMeshAndVerify(getOperation(), getMeshAttr(), symbolTable);
+  if (failed(mesh)) {
+    return failure();
+  }
+  return success();
+}
+
+void NeighborsLinearIndicesOp::getAsmResultNames(
+    function_ref<void(Value, StringRef)> setNameFn) {
+  setNameFn(getNeighborDown(), "down_linear_idx");
+  setNameFn(getNeighborUp(), "up_linear_idx");
+}
+
 //===----------------------------------------------------------------------===//
 // collective communication ops
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Conversion/MeshToMPI/convert-mesh-to-mpi.mlir b/mlir/test/Conversion/MeshToMPI/convert-mesh-to-mpi.mlir
new file mode 100644
index 00000000000000..9ef826ca0cdace
--- /dev/null
+++ b/mlir/test/Conversion/MeshToMPI/convert-mesh-to-mpi.mlir
@@ -0,0 +1,34 @@
+// RUN: mlir-opt %s -split-input-file -convert-mesh-to-mpi | FileCheck %s
+
+// CHECK: mesh.mesh @mesh0
+mesh.mesh @mesh0(shape = 2x2x4)
+
+// -----
+
+// CHECK-LABEL: func @update_halo
+func.func @update_halo_1d(
+    // CHECK-SAME: %[[ARG:.*]]: memref<12x12xi8>
+    %arg0 : memref<12x12xi8>) {
+  // CHECK-NEXT: %[[C2:.*]] = arith.constant 2 : i64
+  // CHECK-NEXT: mesh.update_halo %[[ARG]] on @mesh0
+  // CHECK-SAME: split_axes = {{\[\[}}0]]
+  // CHECK-SAME: halo_sizes = [2, %c2_i64] : memref<12x12xi8>
+  %c2 = arith.constant 2 : i64
+  mesh.update_halo %arg0 on @mesh0 split_axes = [[0]]
+    halo_sizes = [2, %c2] : memref<12x12xi8>
+  return
+}
+
+func.func @update_halo_2d(
+    // CHECK-SAME: %[[ARG:.*]]: memref<12x12xi8>
+    %arg0 : memref<12x12xi8>) {
+  %c2 = arith.constant 2 : i64
+  // CHECK-NEXT: mesh.update_halo %[[ARG]] on @mesh0
+  // CHECK-SAME: split_axes = {{\[\[}}0], [1]]
+  // CHECK-SAME: halo_sizes = [2, 2, %[[C2]], 2]
+  // CHECK-SAME: target_halo_sizes = [3, 3, 2, 2] : memref<12x12xi8>
+  mesh.update_halo %arg0 on @mesh0 split_axes = [[0], [1]]
+    halo_sizes = [2, 2, %c2, 2] target_halo_sizes = [3, 3, 2, 2]
+    : memref<12x12xi8>
+  return
+}

>From 80f4b5f1e1eef84419c2aa0bec6cceaaf01f25dc Mon Sep 17 00:00:00 2001
From: "Schlimbach, Frank" <frank.schlimbach at intel.com>
Date: Fri, 16 Aug 2024 10:55:28 +0200
Subject: [PATCH 20/26] dim fixes, proper testing

---
 mlir/lib/Conversion/MeshToMPI/MeshToMPI.cpp   | 306 ++++++++++--------
 .../MeshToMPI/convert-mesh-to-mpi.mlir        | 179 ++++++++--
 2 files changed, 339 insertions(+), 146 deletions(-)

diff --git a/mlir/lib/Conversion/MeshToMPI/MeshToMPI.cpp b/mlir/lib/Conversion/MeshToMPI/MeshToMPI.cpp
index b4cf9da8497a2d..42d885a109ee79 100644
--- a/mlir/lib/Conversion/MeshToMPI/MeshToMPI.cpp
+++ b/mlir/lib/Conversion/MeshToMPI/MeshToMPI.cpp
@@ -6,7 +6,7 @@
 //
 //===----------------------------------------------------------------------===//
 //
-// This file implements a translation of Mesh communicatin ops tp MPI ops.
+// This file implements a translation of Mesh communication ops to MPI ops.
 //
 //===----------------------------------------------------------------------===//
 
@@ -21,6 +21,8 @@
 #include "mlir/IR/Builders.h"
 #include "mlir/IR/BuiltinAttributes.h"
 #include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
 
 #define DEBUG_TYPE "mesh-to-mpi"
 #define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
@@ -34,138 +36,190 @@ using namespace mlir;
 using namespace mlir::mesh;
 
 namespace {
+
+// This pattern converts the mesh.update_halo operation to MPI calls
+struct ConvertUpdateHaloOp
+    : public mlir::OpRewritePattern<mlir::mesh::UpdateHaloOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  mlir::LogicalResult
+  matchAndRewrite(mlir::mesh::UpdateHaloOp op,
+                  mlir::PatternRewriter &rewriter) const override {
+    // Halos are exchanged as 2 blocks per dimension (one for each side: down
+    // and up). It is assumed that the last dim in a default memref is
+    // contiguous, hence iteration starts with the complete halo on the first
+    // dim which should be contiguous (unless the source is not). The size of
+    // the exchanged data will decrease when iterating over dimensions. That's
+    // good because the halos of last dim will be most fragmented.
+    // memref.subview is used to read and write the halo data from and to the
+    // local data. subviews and halos have dynamic and static values, so
+    // OpFoldResults are used whenever possible.
+
+    SymbolTableCollection symbolTableCollection;
+    auto loc = op.getLoc();
+
+    // convert a OpFoldResult into a Value
+    auto toValue = [&rewriter, &loc](OpFoldResult &v) {
+      return v.is<Value>()
+                 ? v.get<Value>()
+                 : rewriter.create<::mlir::arith::ConstantOp>(
+                       loc,
+                       rewriter.getIndexAttr(
+                           cast<IntegerAttr>(v.get<Attribute>()).getInt()));
+    };
+
+    auto array = op.getInput();
+    auto rank = array.getType().getRank();
+    auto mesh = op.getMesh();
+    auto meshOp = getMesh(op, symbolTableCollection);
+    auto haloSizes = getMixedValues(op.getStaticHaloSizes(),
+                                    op.getDynamicHaloSizes(), rewriter);
+    // subviews need Index values
+    for (auto &sz : haloSizes) {
+      if (sz.is<Value>()) {
+        sz = rewriter
+                 .create<arith::IndexCastOp>(loc, rewriter.getIndexType(),
+                                             sz.get<Value>())
+                 .getResult();
+      }
+    }
+
+    // most of the offset/size/stride data is the same for all dims
+    SmallVector<OpFoldResult> offsets(rank, rewriter.getIndexAttr(0));
+    SmallVector<OpFoldResult> strides(rank, rewriter.getIndexAttr(1));
+    SmallVector<OpFoldResult> shape(rank);
+    // we need the actual shape to compute offsets and sizes
+    for (auto [i, s] : llvm::enumerate(array.getType().getShape())) {
+      if (ShapedType::isDynamic(s)) {
+        shape[i] = rewriter.create<memref::DimOp>(loc, array, s).getResult();
+      } else {
+        shape[i] = rewriter.getIndexAttr(s);
+      }
+    }
+
+    auto tagAttr = rewriter.getI32IntegerAttr(91); // we just pick something
+    auto tag = rewriter.create<::mlir::arith::ConstantOp>(loc, tagAttr);
+    auto zeroAttr = rewriter.getI32IntegerAttr(0); // for detecting v<0
+    auto zero = rewriter.create<::mlir::arith::ConstantOp>(loc, zeroAttr);
+    SmallVector<Type> indexResultTypes(meshOp.getShape().size(),
+                                       rewriter.getIndexType());
+    auto myMultiIndex =
+        rewriter.create<ProcessMultiIndexOp>(loc, indexResultTypes, mesh)
+            .getResult();
+    // halo sizes are provided for split dimensions only
+    auto currHaloDim = 0;
+
+    for (auto [dim, splitAxes] : llvm::enumerate(op.getSplitAxes())) {
+      if (splitAxes.empty()) {
+        continue;
+      }
+      // Get the linearized ids of the neighbors (down and up) for the
+      // given split
+      auto tmp = rewriter
+                     .create<NeighborsLinearIndicesOp>(loc, mesh, myMultiIndex,
+                                                       splitAxes)
+                     .getResults();
+      // MPI operates on i32...
+      Value neighbourIDs[2] = {rewriter.create<arith::IndexCastOp>(
+                                   loc, rewriter.getI32Type(), tmp[0]),
+                               rewriter.create<arith::IndexCastOp>(
+                                   loc, rewriter.getI32Type(), tmp[1])};
+      // store for later
+      auto orgDimSize = shape[dim];
+      // this dim's offset to the start of the upper halo
+      auto upperOffset = rewriter.create<arith::SubIOp>(
+          loc, toValue(shape[dim]), toValue(haloSizes[currHaloDim * 2 + 1]));
+
+      // Make sure we send/recv in a way that does not lead to a dead-lock.
+      // The current approach is by far not optimal, this should at least
+      // be a red-black pattern or using MPI_sendrecv.
+      // Also, buffers should be re-used.
+      // Still using temporary contiguous buffers for MPI communication...
+      // Still yielding a "serialized" communication pattern...
+      auto genSendRecv = [&](auto dim, bool upperHalo) {
+        auto orgOffset = offsets[dim];
+        shape[dim] = upperHalo ? haloSizes[currHaloDim * 2 + 1]
+                               : haloSizes[currHaloDim * 2];
+        // Check if we need to send and/or receive
+        // Processes on the mesh borders have only one neighbor
+        auto to = upperHalo ? neighbourIDs[1] : neighbourIDs[0];
+        auto from = upperHalo ? neighbourIDs[0] : neighbourIDs[1];
+        auto hasFrom = rewriter.create<arith::CmpIOp>(
+            loc, arith::CmpIPredicate::sge, from, zero);
+        auto hasTo = rewriter.create<arith::CmpIOp>(
+            loc, arith::CmpIPredicate::sge, to, zero);
+        auto buffer = rewriter.create<memref::AllocOp>(
+            loc, shape, array.getType().getElementType());
+        // if has neighbor: copy halo data from array to buffer and send
+        rewriter.create<scf::IfOp>(
+            loc, hasTo, [&](OpBuilder &builder, Location loc) {
+              offsets[dim] = upperHalo ? OpFoldResult(builder.getIndexAttr(0))
+                                       : OpFoldResult(upperOffset);
+              auto subview = builder.create<memref::SubViewOp>(
+                  loc, array, offsets, shape, strides);
+              builder.create<memref::CopyOp>(loc, subview, buffer);
+              builder.create<mpi::SendOp>(loc, TypeRange{}, buffer, tag, to);
+              builder.create<scf::YieldOp>(loc);
+            });
+        // if has neighbor: receive halo data into buffer and copy to array
+        rewriter.create<scf::IfOp>(
+            loc, hasFrom, [&](OpBuilder &builder, Location loc) {
+              offsets[dim] = upperHalo ? OpFoldResult(upperOffset)
+                                       : OpFoldResult(builder.getIndexAttr(0));
+              builder.create<mpi::RecvOp>(loc, TypeRange{}, buffer, tag, from);
+              auto subview = builder.create<memref::SubViewOp>(
+                  loc, array, offsets, shape, strides);
+              builder.create<memref::CopyOp>(loc, buffer, subview);
+              builder.create<scf::YieldOp>(loc);
+            });
+        rewriter.create<memref::DeallocOp>(loc, buffer);
+        offsets[dim] = orgOffset;
+      };
+
+      genSendRecv(dim, false);
+      genSendRecv(dim, true);
+
+      // prepare shape and offsets for next split dim
+      auto _haloSz =
+          rewriter
+              .create<arith::AddIOp>(loc, toValue(haloSizes[currHaloDim * 2]),
+                                     toValue(haloSizes[currHaloDim * 2 + 1]))
+              .getResult();
+      // the shape for next halo excludes the halo on both ends for the
+      // current dim
+      shape[dim] =
+          rewriter.create<arith::SubIOp>(loc, toValue(orgDimSize), _haloSz)
+              .getResult();
+      // the offsets for next halo starts after the down halo for the
+      // current dim
+      offsets[dim] = haloSizes[currHaloDim * 2];
+      // on to next halo
+      ++currHaloDim;
+    }
+    rewriter.eraseOp(op);
+    return mlir::success();
+  }
+};
+
 struct ConvertMeshToMPIPass
     : public impl::ConvertMeshToMPIPassBase<ConvertMeshToMPIPass> {
   using Base::Base;
 
   /// Run the dialect converter on the module.
   void runOnOperation() override {
-    getOperation()->walk([&](UpdateHaloOp op) {
-      SymbolTableCollection symbolTableCollection;
-      OpBuilder builder(op);
-      auto loc = op.getLoc();
-
-      auto toValue = [&builder, &loc](OpFoldResult &v) {
-        return v.is<Value>()
-                   ? v.get<Value>()
-                   : builder.create<::mlir::arith::ConstantOp>(
-                         loc,
-                         builder.getIndexAttr(
-                             cast<IntegerAttr>(v.get<Attribute>()).getInt()));
-      };
+    auto *ctx = &getContext();
+    mlir::RewritePatternSet patterns(ctx);
 
-      auto array = op.getInput();
-      auto rank = array.getType().getRank();
-      auto mesh = op.getMesh();
-      auto meshOp = getMesh(op, symbolTableCollection);
-      auto haloSizes = getMixedValues(op.getStaticHaloSizes(),
-                                      op.getDynamicHaloSizes(), builder);
-      for (auto &sz : haloSizes) {
-        if (sz.is<Value>()) {
-          sz = builder
-                   .create<arith::IndexCastOp>(loc, builder.getIndexType(),
-                                               sz.get<Value>())
-                   .getResult();
-        }
-      }
-
-      SmallVector<OpFoldResult> offsets(rank, builder.getIndexAttr(0));
-      SmallVector<OpFoldResult> strides(rank, builder.getIndexAttr(1));
-      SmallVector<OpFoldResult> shape(rank);
-      for (auto [i, s] : llvm::enumerate(array.getType().getShape())) {
-        if (ShapedType::isDynamic(s)) {
-          shape[i] = builder.create<memref::DimOp>(loc, array, s).getResult();
-        } else {
-          shape[i] = builder.getIndexAttr(s);
-        }
-      }
+    patterns.insert<ConvertUpdateHaloOp>(ctx);
 
-      auto tagAttr = builder.getI32IntegerAttr(91); // whatever
-      auto tag = builder.create<::mlir::arith::ConstantOp>(loc, tagAttr);
-      auto zeroAttr = builder.getI32IntegerAttr(0); // whatever
-      auto zero = builder.create<::mlir::arith::ConstantOp>(loc, zeroAttr);
-      SmallVector<Type> indexResultTypes(meshOp.getShape().size(),
-                                         builder.getIndexType());
-      auto myMultiIndex =
-          builder.create<ProcessMultiIndexOp>(loc, indexResultTypes, mesh)
-              .getResult();
-      auto currHaloDim = 0;
-
-      for (auto [dim, splitAxes] : llvm::enumerate(op.getSplitAxes())) {
-        if (!splitAxes.empty()) {
-          auto tmp = builder
-                         .create<NeighborsLinearIndicesOp>(
-                             loc, mesh, myMultiIndex, splitAxes)
-                         .getResults();
-          Value neighbourIDs[2] = {builder.create<arith::IndexCastOp>(
-                                       loc, builder.getI32Type(), tmp[0]),
-                                   builder.create<arith::IndexCastOp>(
-                                       loc, builder.getI32Type(), tmp[1])};
-          auto orgDimSize = shape[dim];
-          auto upperOffset = builder.create<arith::SubIOp>(
-              loc, toValue(shape[dim]), toValue(haloSizes[dim * 2 + 1]));
-
-          // make sure we send/recv in a way that does not lead to a dead-lock
-          // This is by far not optimal, this should be at least MPI_sendrecv
-          // and - probably even more importantly - buffers should be re-used
-          // Currently using temporary, contiguous buffer for MPI communication
-          auto genSendRecv = [&](auto dim, bool upperHalo) {
-            auto orgOffset = offsets[dim];
-            shape[dim] =
-                upperHalo ? haloSizes[dim * 2 + 1] : haloSizes[dim * 2];
-            auto to = upperHalo ? neighbourIDs[1] : neighbourIDs[0];
-            auto from = upperHalo ? neighbourIDs[0] : neighbourIDs[1];
-            auto hasFrom = builder.create<arith::CmpIOp>(
-                loc, arith::CmpIPredicate::sge, from, zero);
-            auto hasTo = builder.create<arith::CmpIOp>(
-                loc, arith::CmpIPredicate::sge, to, zero);
-            auto buffer = builder.create<memref::AllocOp>(
-                loc, shape, array.getType().getElementType());
-            builder.create<scf::IfOp>(
-                loc, hasTo, [&](OpBuilder &builder, Location loc) {
-                  offsets[dim] = upperHalo
-                                     ? OpFoldResult(builder.getIndexAttr(0))
-                                     : OpFoldResult(upperOffset);
-                  auto subview = builder.create<memref::SubViewOp>(
-                      loc, array, offsets, shape, strides);
-                  builder.create<memref::CopyOp>(loc, subview, buffer);
-                  builder.create<mpi::SendOp>(loc, TypeRange{}, buffer, tag,
-                                              to);
-                  builder.create<scf::YieldOp>(loc);
-                });
-            builder.create<scf::IfOp>(
-                loc, hasFrom, [&](OpBuilder &builder, Location loc) {
-                  offsets[dim] = upperHalo
-                                     ? OpFoldResult(upperOffset)
-                                     : OpFoldResult(builder.getIndexAttr(0));
-                  builder.create<mpi::RecvOp>(loc, TypeRange{}, buffer, tag,
-                                              from);
-                  auto subview = builder.create<memref::SubViewOp>(
-                      loc, array, offsets, shape, strides);
-                  builder.create<memref::CopyOp>(loc, buffer, subview);
-                  builder.create<scf::YieldOp>(loc);
-                });
-            builder.create<memref::DeallocOp>(loc, buffer);
-            offsets[dim] = orgOffset;
-          };
-
-          genSendRecv(dim, false);
-          genSendRecv(dim, true);
-
-          shape[dim] = builder
-                           .create<arith::SubIOp>(
-                               loc, toValue(orgDimSize),
-                               builder
-                                   .create<arith::AddIOp>(
-                                       loc, toValue(haloSizes[dim * 2]),
-                                       toValue(haloSizes[dim * 2 + 1]))
-                                   .getResult())
-                           .getResult();
-          offsets[dim] = haloSizes[dim * 2];
-          ++currHaloDim;
-        }
-      }
-    });
+    (void)mlir::applyPatternsAndFoldGreedily(getOperation(),
+                                             std::move(patterns));
   }
 };
-} // namespace
\ No newline at end of file
+
+} // namespace
+
+// Create a pass that converts Mesh to MPI
+std::unique_ptr<::mlir::OperationPass<void>> createConvertMeshToMPIPass() {
+  return std::make_unique<ConvertMeshToMPIPass>();
+}
diff --git a/mlir/test/Conversion/MeshToMPI/convert-mesh-to-mpi.mlir b/mlir/test/Conversion/MeshToMPI/convert-mesh-to-mpi.mlir
index 9ef826ca0cdace..5f563364272d96 100644
--- a/mlir/test/Conversion/MeshToMPI/convert-mesh-to-mpi.mlir
+++ b/mlir/test/Conversion/MeshToMPI/convert-mesh-to-mpi.mlir
@@ -1,34 +1,173 @@
-// RUN: mlir-opt %s -split-input-file -convert-mesh-to-mpi | FileCheck %s
+// RUN: mlir-opt %s -convert-mesh-to-mpi | FileCheck %s
 
 // CHECK: mesh.mesh @mesh0
 mesh.mesh @mesh0(shape = 2x2x4)
 
-// -----
-
-// CHECK-LABEL: func @update_halo
-func.func @update_halo_1d(
-    // CHECK-SAME: %[[ARG:.*]]: memref<12x12xi8>
+// CHECK-LABEL: func @update_halo_1d_first
+func.func @update_halo_1d_first(
+  // CHECK-SAME: [[varg0:%.*]]: memref<12x12xi8>
     %arg0 : memref<12x12xi8>) {
-  // CHECK-NEXT: %[[C2:.*]] = arith.constant 2 : i64
-  // CHECK-NEXT: mesh.update_halo %[[ARG]] on @mesh0
-  // CHECK-SAME: split_axes = {{\[\[}}0]]
-  // CHECK-SAME: halo_sizes = [2, %c2_i64] : memref<12x12xi8>
-  %c2 = arith.constant 2 : i64
+  // CHECK-NEXT: [[vc9:%.*]] = arith.constant 9 : index
+  // CHECK-NEXT: [[vc91_i32:%.*]] = arith.constant 91 : i32
+  // CHECK-NEXT: [[vc0_i32:%.*]] = arith.constant 0 : i32
+  // CHECK-NEXT: [[vproc_linear_idx:%.*]]:3 = mesh.process_multi_index on @mesh0 : index, index, index
+  // CHECK-NEXT: [[vdown_linear_idx:%.*]], [[vup_linear_idx:%.*]] = mesh.neighbors_linear_indices on @mesh0[[[vproc_linear_idx]]#0, [[vproc_linear_idx]]#1, [[vproc_linear_idx]]#2] split_axes = [0] : index, index
+  // CHECK-NEXT: [[v0:%.*]] = arith.index_cast [[vdown_linear_idx]] : index to i32
+  // CHECK-NEXT: [[v1:%.*]] = arith.index_cast [[vup_linear_idx]] : index to i32
+  // CHECK-NEXT: [[v2:%.*]] = arith.cmpi sge, [[v1]], [[vc0_i32]] : i32
+  // CHECK-NEXT: [[v3:%.*]] = arith.cmpi sge, [[v0]], [[vc0_i32]] : i32
+  // CHECK-NEXT: [[valloc:%.*]] = memref.alloc() : memref<2x12xi8>
+  // CHECK-NEXT: scf.if [[v3]] {
+  // CHECK-NEXT:   [[vsubview:%.*]] = memref.subview [[varg0]][[[vc9]], 0] [2, 12] [1, 1] : memref<12x12xi8> to memref<2x12xi8, strided<[12, 1], offset: ?>>
+  // CHECK-NEXT:   memref.copy [[vsubview]], [[valloc]] : memref<2x12xi8, strided<[12, 1], offset: ?>> to memref<2x12xi8>
+  // CHECK-NEXT:   mpi.send([[valloc]], [[vc91_i32]], [[v0]]) : memref<2x12xi8>, i32, i32
+  // CHECK-NEXT: }
+  // CHECK-NEXT: scf.if [[v2]] {
+  // CHECK-NEXT:   mpi.recv([[valloc]], [[vc91_i32]], [[v1]]) : memref<2x12xi8>, i32, i32
+  // CHECK-NEXT:   [[vsubview:%.*]] = memref.subview [[varg0]][0, 0] [2, 12] [1, 1] : memref<12x12xi8> to memref<2x12xi8, strided<[12, 1]>>
+  // CHECK-NEXT:   memref.copy [[valloc]], [[vsubview]] : memref<2x12xi8> to memref<2x12xi8, strided<[12, 1]>>
+  // CHECK-NEXT: }
+  // CHECK-NEXT: memref.dealloc [[valloc]] : memref<2x12xi8>
+  // CHECK-NEXT: [[v4:%.*]] = arith.cmpi sge, [[v0]], [[vc0_i32]] : i32
+  // CHECK-NEXT: [[v5:%.*]] = arith.cmpi sge, [[v1]], [[vc0_i32]] : i32
+  // CHECK-NEXT: [[valloc_0:%.*]] = memref.alloc() : memref<3x12xi8>
+  // CHECK-NEXT: scf.if [[v5]] {
+  // CHECK-NEXT:   [[vsubview:%.*]] = memref.subview [[varg0]][0, 0] [3, 12] [1, 1] : memref<12x12xi8> to memref<3x12xi8, strided<[12, 1]>>
+  // CHECK-NEXT:   memref.copy [[vsubview]], [[valloc_0]] : memref<3x12xi8, strided<[12, 1]>> to memref<3x12xi8>
+  // CHECK-NEXT:   mpi.send([[valloc_0]], [[vc91_i32]], [[v1]]) : memref<3x12xi8>, i32, i32
+  // CHECK-NEXT: }
+  // CHECK-NEXT: scf.if [[v4]] {
+  // CHECK-NEXT:   mpi.recv([[valloc_0]], [[vc91_i32]], [[v0]]) : memref<3x12xi8>, i32, i32
+  // CHECK-NEXT:   [[vsubview:%.*]] = memref.subview [[varg0]][[[vc9]], 0] [3, 12] [1, 1] : memref<12x12xi8> to memref<3x12xi8, strided<[12, 1], offset: ?>>
+  // CHECK-NEXT:   memref.copy [[valloc_0]], [[vsubview]] : memref<3x12xi8> to memref<3x12xi8, strided<[12, 1], offset: ?>>
+  // CHECK-NEXT: }
+  // CHECK-NEXT: memref.dealloc [[valloc_0]] : memref<3x12xi8>
+  // CHECK-NEXT: return
   mesh.update_halo %arg0 on @mesh0 split_axes = [[0]]
-    halo_sizes = [2, %c2] : memref<12x12xi8>
+    halo_sizes = [2, 3] : memref<12x12xi8>
+  return
+}
+
+// CHECK-LABEL: func @update_halo_1d_second
+func.func @update_halo_1d_second(
+  // CHECK-SAME: [[varg0:%.*]]: memref<12x12xi8>
+  %arg0 : memref<12x12xi8>) {
+  // CHECK-NEXT: [[vc9:%.*]] = arith.constant 9 : index
+  // CHECK-NEXT: [[vc91_i32:%.*]] = arith.constant 91 : i32
+  // CHECK-NEXT: [[vc0_i32:%.*]] = arith.constant 0 : i32
+  // CHECK-NEXT: [[vproc_linear_idx:%.*]]:3 = mesh.process_multi_index on @mesh0 : index, index, index
+  // CHECK-NEXT: [[vdown_linear_idx:%.*]], [[vup_linear_idx:%.*]] = mesh.neighbors_linear_indices on @mesh0[[[vproc_linear_idx]]#0, [[vproc_linear_idx]]#1, [[vproc_linear_idx]]#2] split_axes = [3] : index, index
+  // CHECK-NEXT: [[v0:%.*]] = arith.index_cast [[vdown_linear_idx]] : index to i32
+  // CHECK-NEXT: [[v1:%.*]] = arith.index_cast [[vup_linear_idx]] : index to i32
+  // CHECK-NEXT: [[v2:%.*]] = arith.cmpi sge, [[v1]], [[vc0_i32]] : i32
+  // CHECK-NEXT: [[v3:%.*]] = arith.cmpi sge, [[v0]], [[vc0_i32]] : i32
+  // CHECK-NEXT: [[valloc:%.*]] = memref.alloc() : memref<12x2xi8>
+  // CHECK-NEXT: scf.if [[v3]] {
+  // CHECK-NEXT:   [[vsubview:%.*]] = memref.subview [[varg0]][0, %c9] [12, 2] [1, 1] : memref<12x12xi8> to memref<12x2xi8, strided<[12, 1], offset: ?>>
+  // CHECK-NEXT:   memref.copy [[vsubview]], [[valloc]] : memref<12x2xi8, strided<[12, 1], offset: ?>> to memref<12x2xi8>
+  // CHECK-NEXT:   mpi.send([[valloc]], [[vc91_i32]], [[v0]]) : memref<12x2xi8>, i32, i32
+  // CHECK-NEXT: }
+  // CHECK-NEXT: scf.if [[v2]] {
+  // CHECK-NEXT:   mpi.recv([[valloc]], [[vc91_i32]], [[v1]]) : memref<12x2xi8>, i32, i32
+  // CHECK-NEXT:   [[vsubview:%.*]] = memref.subview [[varg0]][0, 0] [12, 2] [1, 1] : memref<12x12xi8> to memref<12x2xi8, strided<[12, 1]>>
+  // CHECK-NEXT:   memref.copy [[valloc]], [[vsubview]] : memref<12x2xi8> to memref<12x2xi8, strided<[12, 1]>>
+  // CHECK-NEXT: }
+  // CHECK-NEXT: memref.dealloc [[valloc]] : memref<12x2xi8>
+  // CHECK-NEXT: [[v4:%.*]] = arith.cmpi sge, [[v0]], [[vc0_i32]] : i32
+  // CHECK-NEXT: [[v5:%.*]] = arith.cmpi sge, [[v1]], [[vc0_i32]] : i32
+  // CHECK-NEXT: [[valloc_0:%.*]] = memref.alloc() : memref<12x3xi8>
+  // CHECK-NEXT: scf.if [[v5]] {
+  // CHECK-NEXT:   [[vsubview:%.*]] = memref.subview [[varg0]][0, 0] [12, 3] [1, 1] : memref<12x12xi8> to memref<12x3xi8, strided<[12, 1]>>
+  // CHECK-NEXT:   memref.copy [[vsubview]], [[valloc_0]] : memref<12x3xi8, strided<[12, 1]>> to memref<12x3xi8>
+  // CHECK-NEXT:   mpi.send([[valloc_0]], [[vc91_i32]], [[v1]]) : memref<12x3xi8>, i32, i32
+  // CHECK-NEXT: }
+  // CHECK-NEXT: scf.if [[v4]] {
+  // CHECK-NEXT:   mpi.recv([[valloc_0]], [[vc91_i32]], [[v0]]) : memref<12x3xi8>, i32, i32
+  // CHECK-NEXT:   [[vsubview:%.*]] = memref.subview [[varg0]][0, %c9] [12, 3] [1, 1] : memref<12x12xi8> to memref<12x3xi8, strided<[12, 1], offset: ?>>
+  // CHECK-NEXT:   memref.copy [[valloc_0]], [[vsubview]] : memref<12x3xi8> to memref<12x3xi8, strided<[12, 1], offset: ?>>
+  // CHECK-NEXT: }
+  // CHECK-NEXT: memref.dealloc [[valloc_0]] : memref<12x3xi8>
+  // CHECK-NEXT: return
+  mesh.update_halo %arg0 on @mesh0 split_axes = [[], [3]]
+    halo_sizes = [2, 3] : memref<12x12xi8>
   return
 }
 
+// CHECK-LABEL: func @update_halo_2d
 func.func @update_halo_2d(
-    // CHECK-SAME: %[[ARG:.*]]: memref<12x12xi8>
+    // CHECK-SAME: [[varg0:%.*]]: memref<12x12xi8>
     %arg0 : memref<12x12xi8>) {
-  %c2 = arith.constant 2 : i64
-  // CHECK-NEXT: mesh.update_halo %[[ARG]] on @mesh0
-  // CHECK-SAME: split_axes = {{\[\[}}0], [1]]
-  // CHECK-SAME: halo_sizes = [2, 2, %[[C2]], 2]
-  // CHECK-SAME: target_halo_sizes = [3, 3, 2, 2] : memref<12x12xi8>
+  // CHECK-NEXT: [[vc8:%.*]] = arith.constant 8 : index
+  // CHECK-NEXT: [[vc9:%.*]] = arith.constant 9 : index
+  // CHECK-NEXT: [[vc10:%.*]] = arith.constant 10 : index
+  // CHECK-NEXT: [[vc91_i32:%.*]] = arith.constant 91 : i32
+  // CHECK-NEXT: [[vc0_i32:%.*]] = arith.constant 0 : i32
+  // CHECK-NEXT: [[vproc_linear_idx:%.*]]:3 = mesh.process_multi_index on @mesh0 : index, index, index
+  // CHECK-NEXT: [[vdown_linear_idx:%.*]], [[vup_linear_idx:%.*]] = mesh.neighbors_linear_indices on @mesh0[[[vproc_linear_idx]]#0, [[vproc_linear_idx]]#1, [[vproc_linear_idx]]#2] split_axes = [0] : index, index
+  // CHECK-NEXT: [[v0:%.*]] = arith.index_cast [[vdown_linear_idx]] : index to i32
+  // CHECK-NEXT: [[v1:%.*]] = arith.index_cast [[vup_linear_idx]] : index to i32
+  // CHECK-NEXT: [[v2:%.*]] = arith.cmpi sge, [[v1]], [[vc0_i32]] : i32
+  // CHECK-NEXT: [[v3:%.*]] = arith.cmpi sge, [[v0]], [[vc0_i32]] : i32
+  // CHECK-NEXT: [[valloc:%.*]] = memref.alloc() : memref<1x12xi8>
+  // CHECK-NEXT: scf.if [[v3]] {
+  // CHECK-NEXT:   [[vsubview:%.*]] = memref.subview [[varg0]][[[vc10]], 0] [1, 12] [1, 1] : memref<12x12xi8> to memref<1x12xi8, strided<[12, 1], offset: ?>>
+  // CHECK-NEXT:   memref.copy [[vsubview]], [[valloc]] : memref<1x12xi8, strided<[12, 1], offset: ?>> to memref<1x12xi8>
+  // CHECK-NEXT:   mpi.send([[valloc]], [[vc91_i32]], [[v0]]) : memref<1x12xi8>, i32, i32
+  // CHECK-NEXT: }
+  // CHECK-NEXT: scf.if [[v2]] {
+  // CHECK-NEXT:   mpi.recv([[valloc]], [[vc91_i32]], [[v1]]) : memref<1x12xi8>, i32, i32
+  // CHECK-NEXT:   [[vsubview:%.*]] = memref.subview [[varg0]][0, 0] [1, 12] [1, 1] : memref<12x12xi8> to memref<1x12xi8, strided<[12, 1]>>
+  // CHECK-NEXT:   memref.copy [[valloc]], [[vsubview]] : memref<1x12xi8> to memref<1x12xi8, strided<[12, 1]>>
+  // CHECK-NEXT: }
+  // CHECK-NEXT: memref.dealloc [[valloc]] : memref<1x12xi8>
+  // CHECK-NEXT: [[v4:%.*]] = arith.cmpi sge, [[v0]], [[vc0_i32]] : i32
+  // CHECK-NEXT: [[v5:%.*]] = arith.cmpi sge, [[v1]], [[vc0_i32]] : i32
+  // CHECK-NEXT: [[valloc_0:%.*]] = memref.alloc() : memref<2x12xi8>
+  // CHECK-NEXT: scf.if [[v5]] {
+  // CHECK-NEXT:   [[vsubview:%.*]] = memref.subview [[varg0]][0, 0] [2, 12] [1, 1] : memref<12x12xi8> to memref<2x12xi8, strided<[12, 1]>>
+  // CHECK-NEXT:   memref.copy [[vsubview]], [[valloc_0]] : memref<2x12xi8, strided<[12, 1]>> to memref<2x12xi8>
+  // CHECK-NEXT:   mpi.send([[valloc_0]], [[vc91_i32]], [[v1]]) : memref<2x12xi8>, i32, i32
+  // CHECK-NEXT: }
+  // CHECK-NEXT: scf.if [[v4]] {
+  // CHECK-NEXT:   mpi.recv([[valloc_0]], [[vc91_i32]], [[v0]]) : memref<2x12xi8>, i32, i32
+  // CHECK-NEXT:   [[vsubview:%.*]] = memref.subview [[varg0]][[[vc10]], 0] [2, 12] [1, 1] : memref<12x12xi8> to memref<2x12xi8, strided<[12, 1], offset: ?>>
+  // CHECK-NEXT:   memref.copy [[valloc_0]], [[vsubview]] : memref<2x12xi8> to memref<2x12xi8, strided<[12, 1], offset: ?>>
+  // CHECK-NEXT: }
+  // CHECK-NEXT: memref.dealloc [[valloc_0]] : memref<2x12xi8>
+  // CHECK-NEXT: [[vdown_linear_idx_1:%.*]], [[vup_linear_idx_2:%.*]] = mesh.neighbors_linear_indices on @mesh0[[[vproc_linear_idx]]#0, [[vproc_linear_idx]]#1, [[vproc_linear_idx]]#2] split_axes = [1] : index, index
+  // CHECK-NEXT: [[v6:%.*]] = arith.index_cast [[vdown_linear_idx_1]] : index to i32
+  // CHECK-NEXT: [[v7:%.*]] = arith.index_cast [[vup_linear_idx_2]] : index to i32
+  // CHECK-NEXT: [[v8:%.*]] = arith.cmpi sge, [[v7]], [[vc0_i32]] : i32
+  // CHECK-NEXT: [[v9:%.*]] = arith.cmpi sge, [[v6]], [[vc0_i32]] : i32
+  // CHECK-NEXT: [[valloc_3:%.*]] = memref.alloc([[vc9]]) : memref<?x3xi8>
+  // CHECK-NEXT: scf.if [[v9]] {
+  // CHECK-NEXT:   [[vsubview:%.*]] = memref.subview [[varg0]][1, %c8] [[[vc9]], 3] [1, 1] : memref<12x12xi8> to memref<?x3xi8, strided<[12, 1], offset: ?>>
+  // CHECK-NEXT:   memref.copy [[vsubview]], [[valloc_3]] : memref<?x3xi8, strided<[12, 1], offset: ?>> to memref<?x3xi8>
+  // CHECK-NEXT:   mpi.send([[valloc_3]], [[vc91_i32]], [[v6]]) : memref<?x3xi8>, i32, i32
+  // CHECK-NEXT: }
+  // CHECK-NEXT: scf.if [[v8]] {
+  // CHECK-NEXT:   mpi.recv([[valloc_3]], [[vc91_i32]], [[v7]]) : memref<?x3xi8>, i32, i32
+  // CHECK-NEXT:   [[vsubview:%.*]] = memref.subview [[varg0]][1, 0] [[[vc9]], 3] [1, 1] : memref<12x12xi8> to memref<?x3xi8, strided<[12, 1], offset: 12>>
+  // CHECK-NEXT:   memref.copy [[valloc_3]], [[vsubview]] : memref<?x3xi8> to memref<?x3xi8, strided<[12, 1], offset: 12>>
+  // CHECK-NEXT: }
+  // CHECK-NEXT: memref.dealloc [[valloc_3]] : memref<?x3xi8>
+  // CHECK-NEXT: [[v10:%.*]] = arith.cmpi sge, [[v6]], [[vc0_i32]] : i32
+  // CHECK-NEXT: [[v11:%.*]] = arith.cmpi sge, [[v7]], [[vc0_i32]] : i32
+  // CHECK-NEXT: [[valloc_4:%.*]] = memref.alloc([[vc9]]) : memref<?x4xi8>
+  // CHECK-NEXT: scf.if [[v11]] {
+  // CHECK-NEXT:   [[vsubview:%.*]] = memref.subview [[varg0]][1, 0] [[[vc9]], 4] [1, 1] : memref<12x12xi8> to memref<?x4xi8, strided<[12, 1], offset: 12>>
+  // CHECK-NEXT:   memref.copy [[vsubview]], [[valloc_4]] : memref<?x4xi8, strided<[12, 1], offset: 12>> to memref<?x4xi8>
+  // CHECK-NEXT:   mpi.send([[valloc_4]], [[vc91_i32]], [[v7]]) : memref<?x4xi8>, i32, i32
+  // CHECK-NEXT: }
+  // CHECK-NEXT: scf.if [[v10]] {
+  // CHECK-NEXT:   mpi.recv([[valloc_4]], [[vc91_i32]], [[v6]]) : memref<?x4xi8>, i32, i32
+  // CHECK-NEXT:   [[vsubview:%.*]] = memref.subview [[varg0]][1, %c8] [[[vc9]], 4] [1, 1] : memref<12x12xi8> to memref<?x4xi8, strided<[12, 1], offset: ?>>
+  // CHECK-NEXT:   memref.copy [[valloc_4]], [[vsubview]] : memref<?x4xi8> to memref<?x4xi8, strided<[12, 1], offset: ?>>
+  // CHECK-NEXT: }
+  // CHECK-NEXT: memref.dealloc [[valloc_4]] : memref<?x4xi8>
+  // CHECK-NEXT: return
   mesh.update_halo %arg0 on @mesh0 split_axes = [[0], [1]]
-    halo_sizes = [2, 2, %c2, 2] target_halo_sizes = [3, 3, 2, 2]
-    : memref<12x12xi8>
+      halo_sizes = [1, 2, 3, 4]
+      : memref<12x12xi8>
   return
 }

>From bb8378f70ea6bae6435916e1633f6bd008c0e3d4 Mon Sep 17 00:00:00 2001
From: "Schlimbach, Frank" <frank.schlimbach at intel.com>
Date: Tue, 20 Aug 2024 19:23:13 +0200
Subject: [PATCH 21/26] fixed corner halos by reversing data-exchanges from
 high to low dims

---
 mlir/lib/Conversion/MeshToMPI/MeshToMPI.cpp   |  91 +++++-----
 .../MeshToMPI/convert-mesh-to-mpi.mlir        | 161 +++++++++---------
 2 files changed, 137 insertions(+), 115 deletions(-)

diff --git a/mlir/lib/Conversion/MeshToMPI/MeshToMPI.cpp b/mlir/lib/Conversion/MeshToMPI/MeshToMPI.cpp
index 42d885a109ee79..9cf9458ce2b687 100644
--- a/mlir/lib/Conversion/MeshToMPI/MeshToMPI.cpp
+++ b/mlir/lib/Conversion/MeshToMPI/MeshToMPI.cpp
@@ -70,6 +70,7 @@ struct ConvertUpdateHaloOp
 
     auto array = op.getInput();
     auto rank = array.getType().getRank();
+    auto opSplitAxes = op.getSplitAxes().getAxes();
     auto mesh = op.getMesh();
     auto meshOp = getMesh(op, symbolTableCollection);
     auto haloSizes = getMixedValues(op.getStaticHaloSizes(),
@@ -87,32 +88,54 @@ struct ConvertUpdateHaloOp
     // most of the offset/size/stride data is the same for all dims
     SmallVector<OpFoldResult> offsets(rank, rewriter.getIndexAttr(0));
     SmallVector<OpFoldResult> strides(rank, rewriter.getIndexAttr(1));
-    SmallVector<OpFoldResult> shape(rank);
+    SmallVector<OpFoldResult> shape(rank), dimSizes(rank);
+    auto currHaloDim = -1; // halo sizes are provided for split dimensions only
     // we need the actual shape to compute offsets and sizes
-    for (auto [i, s] : llvm::enumerate(array.getType().getShape())) {
+    for (auto i = 0; i < rank; ++i) {
+      auto s = array.getType().getShape()[i];
       if (ShapedType::isDynamic(s)) {
         shape[i] = rewriter.create<memref::DimOp>(loc, array, s).getResult();
       } else {
         shape[i] = rewriter.getIndexAttr(s);
       }
+
+      if ((size_t)i < opSplitAxes.size() && !opSplitAxes[i].empty()) {
+        ++currHaloDim;
+        // the offsets for lower dims start after their down halo
+        offsets[i] = haloSizes[currHaloDim * 2];
+
+        // prepare shape and offsets of highest dim's halo exchange
+        auto _haloSz =
+            rewriter
+                .create<arith::AddIOp>(loc, toValue(haloSizes[currHaloDim * 2]),
+                                       toValue(haloSizes[currHaloDim * 2 + 1]))
+                .getResult();
+        // the halo shape of lower dims excludes the halos
+        dimSizes[i] =
+            rewriter.create<arith::SubIOp>(loc, toValue(shape[i]), _haloSz)
+                .getResult();
+      } else {
+        dimSizes[i] = shape[i];
+      }
     }
 
     auto tagAttr = rewriter.getI32IntegerAttr(91); // we just pick something
     auto tag = rewriter.create<::mlir::arith::ConstantOp>(loc, tagAttr);
     auto zeroAttr = rewriter.getI32IntegerAttr(0); // for detecting v<0
     auto zero = rewriter.create<::mlir::arith::ConstantOp>(loc, zeroAttr);
+
     SmallVector<Type> indexResultTypes(meshOp.getShape().size(),
                                        rewriter.getIndexType());
     auto myMultiIndex =
         rewriter.create<ProcessMultiIndexOp>(loc, indexResultTypes, mesh)
             .getResult();
-    // halo sizes are provided for split dimensions only
-    auto currHaloDim = 0;
-
-    for (auto [dim, splitAxes] : llvm::enumerate(op.getSplitAxes())) {
+    // traverse all split axes from high to low dim
+    for (ssize_t dim = opSplitAxes.size() - 1; dim >= 0; --dim) {
+      auto splitAxes = opSplitAxes[dim];
       if (splitAxes.empty()) {
         continue;
       }
+      assert(currHaloDim >= 0 && (size_t)currHaloDim < haloSizes.size() / 2);
       // Get the linearized ids of the neighbors (down and up) for the
       // given split
       auto tmp = rewriter
@@ -124,11 +147,13 @@ struct ConvertUpdateHaloOp
                                    loc, rewriter.getI32Type(), tmp[0]),
                                rewriter.create<arith::IndexCastOp>(
                                    loc, rewriter.getI32Type(), tmp[1])};
-      // store for later
-      auto orgDimSize = shape[dim];
-      // this dim's offset to the start of the upper halo
-      auto upperOffset = rewriter.create<arith::SubIOp>(
+
+      auto lowerRecvOffset = rewriter.getIndexAttr(0);
+      auto lowerSendOffset = toValue(haloSizes[currHaloDim * 2]);
+      auto upperRecvOffset = rewriter.create<arith::SubIOp>(
           loc, toValue(shape[dim]), toValue(haloSizes[currHaloDim * 2 + 1]));
+      auto upperSendOffset = rewriter.create<arith::SubIOp>(
+          loc, upperRecvOffset, toValue(haloSizes[currHaloDim * 2]));
 
       // Make sure we send/recv in a way that does not lead to a dead-lock.
       // The current approach is by far not optimal, this should be at least
@@ -136,10 +161,10 @@ struct ConvertUpdateHaloOp
       // Also, buffers should be re-used.
       // Still using temporary contiguous buffers for MPI communication...
       // Still yielding a "serialized" communication pattern...
-      auto genSendRecv = [&](auto dim, bool upperHalo) {
+      auto genSendRecv = [&](bool upperHalo) {
         auto orgOffset = offsets[dim];
-        shape[dim] = upperHalo ? haloSizes[currHaloDim * 2 + 1]
-                               : haloSizes[currHaloDim * 2];
+        dimSizes[dim] = upperHalo ? haloSizes[currHaloDim * 2 + 1]
+                                  : haloSizes[currHaloDim * 2];
         // Check if we need to send and/or receive
         // Processes on the mesh borders have only one neighbor
         auto to = upperHalo ? neighbourIDs[1] : neighbourIDs[0];
@@ -149,14 +174,14 @@ struct ConvertUpdateHaloOp
         auto hasTo = rewriter.create<arith::CmpIOp>(
             loc, arith::CmpIPredicate::sge, to, zero);
         auto buffer = rewriter.create<memref::AllocOp>(
-            loc, shape, array.getType().getElementType());
+            loc, dimSizes, array.getType().getElementType());
         // if has neighbor: copy halo data from array to buffer and send
         rewriter.create<scf::IfOp>(
             loc, hasTo, [&](OpBuilder &builder, Location loc) {
-              offsets[dim] = upperHalo ? OpFoldResult(builder.getIndexAttr(0))
-                                       : OpFoldResult(upperOffset);
+              offsets[dim] = upperHalo ? OpFoldResult(lowerSendOffset)
+                                       : OpFoldResult(upperSendOffset);
               auto subview = builder.create<memref::SubViewOp>(
-                  loc, array, offsets, shape, strides);
+                  loc, array, offsets, dimSizes, strides);
               builder.create<memref::CopyOp>(loc, subview, buffer);
               builder.create<mpi::SendOp>(loc, TypeRange{}, buffer, tag, to);
               builder.create<scf::YieldOp>(loc);
@@ -164,11 +189,11 @@ struct ConvertUpdateHaloOp
         // if has neighbor: receive halo data into buffer and copy to array
         rewriter.create<scf::IfOp>(
             loc, hasFrom, [&](OpBuilder &builder, Location loc) {
-              offsets[dim] = upperHalo ? OpFoldResult(upperOffset)
-                                       : OpFoldResult(builder.getIndexAttr(0));
+              offsets[dim] = upperHalo ? OpFoldResult(upperRecvOffset)
+                                       : OpFoldResult(lowerRecvOffset);
               builder.create<mpi::RecvOp>(loc, TypeRange{}, buffer, tag, from);
               auto subview = builder.create<memref::SubViewOp>(
-                  loc, array, offsets, shape, strides);
+                  loc, array, offsets, dimSizes, strides);
               builder.create<memref::CopyOp>(loc, buffer, subview);
               builder.create<scf::YieldOp>(loc);
             });
@@ -176,25 +201,15 @@ struct ConvertUpdateHaloOp
         offsets[dim] = orgOffset;
       };
 
-      genSendRecv(dim, false);
-      genSendRecv(dim, true);
-
-      // prepare shape and offsets for next split dim
-      auto _haloSz =
-          rewriter
-              .create<arith::AddIOp>(loc, toValue(haloSizes[currHaloDim * 2]),
-                                     toValue(haloSizes[currHaloDim * 2 + 1]))
-              .getResult();
-      // the shape for next halo excludes the halo on both ends for the
-      // current dim
-      shape[dim] =
-          rewriter.create<arith::SubIOp>(loc, toValue(orgDimSize), _haloSz)
-              .getResult();
-      // the offsets for next halo starts after the down halo for the
-      // current dim
-      offsets[dim] = haloSizes[currHaloDim * 2];
+      genSendRecv(false);
+      genSendRecv(true);
+
+      // the shape for lower dims includes higher dims' halos
+      dimSizes[dim] = shape[dim];
+      // -> the offset for higher dims is always 0
+      offsets[dim] = rewriter.getIndexAttr(0);
       // on to next halo
-      ++currHaloDim;
+      --currHaloDim;
     }
     rewriter.eraseOp(op);
     return mlir::success();
diff --git a/mlir/test/Conversion/MeshToMPI/convert-mesh-to-mpi.mlir b/mlir/test/Conversion/MeshToMPI/convert-mesh-to-mpi.mlir
index 5f563364272d96..c3b0dc12e6d746 100644
--- a/mlir/test/Conversion/MeshToMPI/convert-mesh-to-mpi.mlir
+++ b/mlir/test/Conversion/MeshToMPI/convert-mesh-to-mpi.mlir
@@ -6,8 +6,10 @@ mesh.mesh @mesh0(shape = 2x2x4)
 // CHECK-LABEL: func @update_halo_1d_first
 func.func @update_halo_1d_first(
   // CHECK-SAME: [[varg0:%.*]]: memref<12x12xi8>
-    %arg0 : memref<12x12xi8>) {
+  %arg0 : memref<12x12xi8>) {
+  // CHECK-NEXT: [[vc7:%.*]] = arith.constant 7 : index
   // CHECK-NEXT: [[vc9:%.*]] = arith.constant 9 : index
+  // CHECK-NEXT: [[vc2:%.*]] = arith.constant 2 : index
   // CHECK-NEXT: [[vc91_i32:%.*]] = arith.constant 91 : i32
   // CHECK-NEXT: [[vc0_i32:%.*]] = arith.constant 0 : i32
   // CHECK-NEXT: [[vproc_linear_idx:%.*]]:3 = mesh.process_multi_index on @mesh0 : index, index, index
@@ -18,7 +20,7 @@ func.func @update_halo_1d_first(
   // CHECK-NEXT: [[v3:%.*]] = arith.cmpi sge, [[v0]], [[vc0_i32]] : i32
   // CHECK-NEXT: [[valloc:%.*]] = memref.alloc() : memref<2x12xi8>
   // CHECK-NEXT: scf.if [[v3]] {
-  // CHECK-NEXT:   [[vsubview:%.*]] = memref.subview [[varg0]][[[vc9]], 0] [2, 12] [1, 1] : memref<12x12xi8> to memref<2x12xi8, strided<[12, 1], offset: ?>>
+  // CHECK-NEXT:   [[vsubview:%.*]] = memref.subview [[varg0]][[[vc7]], 0] [2, 12] [1, 1] : memref<12x12xi8> to memref<2x12xi8, strided<[12, 1], offset: ?>>
   // CHECK-NEXT:   memref.copy [[vsubview]], [[valloc]] : memref<2x12xi8, strided<[12, 1], offset: ?>> to memref<2x12xi8>
   // CHECK-NEXT:   mpi.send([[valloc]], [[vc91_i32]], [[v0]]) : memref<2x12xi8>, i32, i32
   // CHECK-NEXT: }
@@ -32,8 +34,8 @@ func.func @update_halo_1d_first(
   // CHECK-NEXT: [[v5:%.*]] = arith.cmpi sge, [[v1]], [[vc0_i32]] : i32
   // CHECK-NEXT: [[valloc_0:%.*]] = memref.alloc() : memref<3x12xi8>
   // CHECK-NEXT: scf.if [[v5]] {
-  // CHECK-NEXT:   [[vsubview:%.*]] = memref.subview [[varg0]][0, 0] [3, 12] [1, 1] : memref<12x12xi8> to memref<3x12xi8, strided<[12, 1]>>
-  // CHECK-NEXT:   memref.copy [[vsubview]], [[valloc_0]] : memref<3x12xi8, strided<[12, 1]>> to memref<3x12xi8>
+  // CHECK-NEXT:   [[vsubview:%.*]] = memref.subview [[varg0]][[[vc2]], 0] [3, 12] [1, 1] : memref<12x12xi8> to memref<3x12xi8, strided<[12, 1], offset: ?>>
+  // CHECK-NEXT:   memref.copy [[vsubview]], [[valloc_0]] : memref<3x12xi8, strided<[12, 1], offset: ?>> to memref<3x12xi8>
   // CHECK-NEXT:   mpi.send([[valloc_0]], [[vc91_i32]], [[v1]]) : memref<3x12xi8>, i32, i32
   // CHECK-NEXT: }
   // CHECK-NEXT: scf.if [[v4]] {
@@ -42,9 +44,9 @@ func.func @update_halo_1d_first(
   // CHECK-NEXT:   memref.copy [[valloc_0]], [[vsubview]] : memref<3x12xi8> to memref<3x12xi8, strided<[12, 1], offset: ?>>
   // CHECK-NEXT: }
   // CHECK-NEXT: memref.dealloc [[valloc_0]] : memref<3x12xi8>
-  // CHECK-NEXT: return
   mesh.update_halo %arg0 on @mesh0 split_axes = [[0]]
     halo_sizes = [2, 3] : memref<12x12xi8>
+  // CHECK-NEXT: return
   return
 }
 
@@ -52,44 +54,46 @@ func.func @update_halo_1d_first(
 func.func @update_halo_1d_second(
   // CHECK-SAME: [[varg0:%.*]]: memref<12x12xi8>
   %arg0 : memref<12x12xi8>) {
-  // CHECK-NEXT: [[vc9:%.*]] = arith.constant 9 : index
-  // CHECK-NEXT: [[vc91_i32:%.*]] = arith.constant 91 : i32
-  // CHECK-NEXT: [[vc0_i32:%.*]] = arith.constant 0 : i32
-  // CHECK-NEXT: [[vproc_linear_idx:%.*]]:3 = mesh.process_multi_index on @mesh0 : index, index, index
-  // CHECK-NEXT: [[vdown_linear_idx:%.*]], [[vup_linear_idx:%.*]] = mesh.neighbors_linear_indices on @mesh0[[[vproc_linear_idx]]#0, [[vproc_linear_idx]]#1, [[vproc_linear_idx]]#2] split_axes = [3] : index, index
-  // CHECK-NEXT: [[v0:%.*]] = arith.index_cast [[vdown_linear_idx]] : index to i32
-  // CHECK-NEXT: [[v1:%.*]] = arith.index_cast [[vup_linear_idx]] : index to i32
-  // CHECK-NEXT: [[v2:%.*]] = arith.cmpi sge, [[v1]], [[vc0_i32]] : i32
-  // CHECK-NEXT: [[v3:%.*]] = arith.cmpi sge, [[v0]], [[vc0_i32]] : i32
-  // CHECK-NEXT: [[valloc:%.*]] = memref.alloc() : memref<12x2xi8>
-  // CHECK-NEXT: scf.if [[v3]] {
-  // CHECK-NEXT:   [[vsubview:%.*]] = memref.subview [[varg0]][0, %c9] [12, 2] [1, 1] : memref<12x12xi8> to memref<12x2xi8, strided<[12, 1], offset: ?>>
-  // CHECK-NEXT:   memref.copy [[vsubview]], [[valloc]] : memref<12x2xi8, strided<[12, 1], offset: ?>> to memref<12x2xi8>
-  // CHECK-NEXT:   mpi.send([[valloc]], [[vc91_i32]], [[v0]]) : memref<12x2xi8>, i32, i32
-  // CHECK-NEXT: }
-  // CHECK-NEXT: scf.if [[v2]] {
-  // CHECK-NEXT:   mpi.recv([[valloc]], [[vc91_i32]], [[v1]]) : memref<12x2xi8>, i32, i32
-  // CHECK-NEXT:   [[vsubview:%.*]] = memref.subview [[varg0]][0, 0] [12, 2] [1, 1] : memref<12x12xi8> to memref<12x2xi8, strided<[12, 1]>>
-  // CHECK-NEXT:   memref.copy [[valloc]], [[vsubview]] : memref<12x2xi8> to memref<12x2xi8, strided<[12, 1]>>
-  // CHECK-NEXT: }
-  // CHECK-NEXT: memref.dealloc [[valloc]] : memref<12x2xi8>
-  // CHECK-NEXT: [[v4:%.*]] = arith.cmpi sge, [[v0]], [[vc0_i32]] : i32
-  // CHECK-NEXT: [[v5:%.*]] = arith.cmpi sge, [[v1]], [[vc0_i32]] : i32
-  // CHECK-NEXT: [[valloc_0:%.*]] = memref.alloc() : memref<12x3xi8>
-  // CHECK-NEXT: scf.if [[v5]] {
-  // CHECK-NEXT:   [[vsubview:%.*]] = memref.subview [[varg0]][0, 0] [12, 3] [1, 1] : memref<12x12xi8> to memref<12x3xi8, strided<[12, 1]>>
-  // CHECK-NEXT:   memref.copy [[vsubview]], [[valloc_0]] : memref<12x3xi8, strided<[12, 1]>> to memref<12x3xi8>
-  // CHECK-NEXT:   mpi.send([[valloc_0]], [[vc91_i32]], [[v1]]) : memref<12x3xi8>, i32, i32
-  // CHECK-NEXT: }
-  // CHECK-NEXT: scf.if [[v4]] {
-  // CHECK-NEXT:   mpi.recv([[valloc_0]], [[vc91_i32]], [[v0]]) : memref<12x3xi8>, i32, i32
-  // CHECK-NEXT:   [[vsubview:%.*]] = memref.subview [[varg0]][0, %c9] [12, 3] [1, 1] : memref<12x12xi8> to memref<12x3xi8, strided<[12, 1], offset: ?>>
-  // CHECK-NEXT:   memref.copy [[valloc_0]], [[vsubview]] : memref<12x3xi8> to memref<12x3xi8, strided<[12, 1], offset: ?>>
-  // CHECK-NEXT: }
-  // CHECK-NEXT: memref.dealloc [[valloc_0]] : memref<12x3xi8>
-  // CHECK-NEXT: return
+  //CHECK-NEXT: [[vc7:%.*]] = arith.constant 7 : index
+  //CHECK-NEXT: [[vc9:%.*]] = arith.constant 9 : index
+  //CHECK-NEXT: [[vc2:%.*]] = arith.constant 2 : index
+  //CHECK-NEXT: [[vc91_i32:%.*]] = arith.constant 91 : i32
+  //CHECK-NEXT: [[vc0_i32:%.*]] = arith.constant 0 : i32
+  //CHECK-NEXT: [[vproc_linear_idx:%.*]]:3 = mesh.process_multi_index on @mesh0 : index, index, index
+  //CHECK-NEXT: [[vdown_linear_idx:%.*]], [[vup_linear_idx:%.*]] = mesh.neighbors_linear_indices on @mesh0[[[vproc_linear_idx]]#0, [[vproc_linear_idx]]#1, [[vproc_linear_idx]]#2] split_axes = [3] : index, index
+  //CHECK-NEXT: [[v0:%.*]] = arith.index_cast [[vdown_linear_idx]] : index to i32
+  //CHECK-NEXT: [[v1:%.*]] = arith.index_cast [[vup_linear_idx]] : index to i32
+  //CHECK-NEXT: [[v2:%.*]] = arith.cmpi sge, [[v1]], [[vc0_i32]] : i32
+  //CHECK-NEXT: [[v3:%.*]] = arith.cmpi sge, [[v0]], [[vc0_i32]] : i32
+  //CHECK-NEXT: [[valloc:%.*]] = memref.alloc() : memref<12x2xi8>
+  //CHECK-NEXT: scf.if [[v3]] {
+  //CHECK-NEXT:   [[vsubview:%.*]] = memref.subview [[varg0]][0, %c7] [12, 2] [1, 1] : memref<12x12xi8> to memref<12x2xi8, strided<[12, 1], offset: ?>>
+  //CHECK-NEXT:   memref.copy [[vsubview]], [[valloc]] : memref<12x2xi8, strided<[12, 1], offset: ?>> to memref<12x2xi8>
+  //CHECK-NEXT:   mpi.send([[valloc]], [[vc91_i32]], [[v0]]) : memref<12x2xi8>, i32, i32
+  //CHECK-NEXT: }
+  //CHECK-NEXT: scf.if [[v2]] {
+  //CHECK-NEXT:   mpi.recv([[valloc]], [[vc91_i32]], [[v1]]) : memref<12x2xi8>, i32, i32
+  //CHECK-NEXT:   [[vsubview:%.*]] = memref.subview [[varg0]][0, 0] [12, 2] [1, 1] : memref<12x12xi8> to memref<12x2xi8, strided<[12, 1]>>
+  //CHECK-NEXT:   memref.copy [[valloc]], [[vsubview]] : memref<12x2xi8> to memref<12x2xi8, strided<[12, 1]>>
+  //CHECK-NEXT: }
+  //CHECK-NEXT: memref.dealloc [[valloc]] : memref<12x2xi8>
+  //CHECK-NEXT: [[v4:%.*]] = arith.cmpi sge, [[v0]], [[vc0_i32]] : i32
+  //CHECK-NEXT: [[v5:%.*]] = arith.cmpi sge, [[v1]], [[vc0_i32]] : i32
+  //CHECK-NEXT: [[valloc_0:%.*]] = memref.alloc() : memref<12x3xi8>
+  //CHECK-NEXT: scf.if [[v5]] {
+  //CHECK-NEXT:   [[vsubview:%.*]] = memref.subview [[varg0]][0, %c2] [12, 3] [1, 1] : memref<12x12xi8> to memref<12x3xi8, strided<[12, 1], offset: ?>>
+  //CHECK-NEXT:   memref.copy [[vsubview]], [[valloc_0]] : memref<12x3xi8, strided<[12, 1], offset: ?>> to memref<12x3xi8>
+  //CHECK-NEXT:   mpi.send([[valloc_0]], [[vc91_i32]], [[v1]]) : memref<12x3xi8>, i32, i32
+  //CHECK-NEXT: }
+  //CHECK-NEXT: scf.if [[v4]] {
+  //CHECK-NEXT:   mpi.recv([[valloc_0]], [[vc91_i32]], [[v0]]) : memref<12x3xi8>, i32, i32
+  //CHECK-NEXT:   [[vsubview:%.*]] = memref.subview [[varg0]][0, %c9] [12, 3] [1, 1] : memref<12x12xi8> to memref<12x3xi8, strided<[12, 1], offset: ?>>
+  //CHECK-NEXT:   memref.copy [[valloc_0]], [[vsubview]] : memref<12x3xi8> to memref<12x3xi8, strided<[12, 1], offset: ?>>
+  //CHECK-NEXT: }
+  //CHECK-NEXT: memref.dealloc [[valloc_0]] : memref<12x3xi8>
   mesh.update_halo %arg0 on @mesh0 split_axes = [[], [3]]
     halo_sizes = [2, 3] : memref<12x12xi8>
+  //CHECK-NEXT: return
   return
 }
 
@@ -97,77 +101,80 @@ func.func @update_halo_1d_second(
 func.func @update_halo_2d(
     // CHECK-SAME: [[varg0:%.*]]: memref<12x12xi8>
     %arg0 : memref<12x12xi8>) {
+  // CHECK-NEXT: [[vc10:%.*]] = arith.constant 10 : index
+  // CHECK-NEXT: [[vc1:%.*]] = arith.constant 1 : index
+  // CHECK-NEXT: [[vc5:%.*]] = arith.constant 5 : index
   // CHECK-NEXT: [[vc8:%.*]] = arith.constant 8 : index
+  // CHECK-NEXT: [[vc3:%.*]] = arith.constant 3 : index
   // CHECK-NEXT: [[vc9:%.*]] = arith.constant 9 : index
-  // CHECK-NEXT: [[vc10:%.*]] = arith.constant 10 : index
   // CHECK-NEXT: [[vc91_i32:%.*]] = arith.constant 91 : i32
   // CHECK-NEXT: [[vc0_i32:%.*]] = arith.constant 0 : i32
   // CHECK-NEXT: [[vproc_linear_idx:%.*]]:3 = mesh.process_multi_index on @mesh0 : index, index, index
-  // CHECK-NEXT: [[vdown_linear_idx:%.*]], [[vup_linear_idx:%.*]] = mesh.neighbors_linear_indices on @mesh0[[[vproc_linear_idx]]#0, [[vproc_linear_idx]]#1, [[vproc_linear_idx]]#2] split_axes = [0] : index, index
+  // CHECK-NEXT: [[vdown_linear_idx:%.*]], [[vup_linear_idx:%.*]] = mesh.neighbors_linear_indices on @mesh0[[[vproc_linear_idx]]#0, [[vproc_linear_idx]]#1, [[vproc_linear_idx]]#2] split_axes = [1] : index, index
   // CHECK-NEXT: [[v0:%.*]] = arith.index_cast [[vdown_linear_idx]] : index to i32
   // CHECK-NEXT: [[v1:%.*]] = arith.index_cast [[vup_linear_idx]] : index to i32
   // CHECK-NEXT: [[v2:%.*]] = arith.cmpi sge, [[v1]], [[vc0_i32]] : i32
   // CHECK-NEXT: [[v3:%.*]] = arith.cmpi sge, [[v0]], [[vc0_i32]] : i32
-  // CHECK-NEXT: [[valloc:%.*]] = memref.alloc() : memref<1x12xi8>
+  // CHECK-NEXT: [[valloc:%.*]] = memref.alloc([[vc9]]) : memref<?x3xi8>
   // CHECK-NEXT: scf.if [[v3]] {
-  // CHECK-NEXT:   [[vsubview:%.*]] = memref.subview [[varg0]][[[vc10]], 0] [1, 12] [1, 1] : memref<12x12xi8> to memref<1x12xi8, strided<[12, 1], offset: ?>>
-  // CHECK-NEXT:   memref.copy [[vsubview]], [[valloc]] : memref<1x12xi8, strided<[12, 1], offset: ?>> to memref<1x12xi8>
-  // CHECK-NEXT:   mpi.send([[valloc]], [[vc91_i32]], [[v0]]) : memref<1x12xi8>, i32, i32
+  // CHECK-NEXT:   [[vsubview:%.*]] = memref.subview [[varg0]][1, %c5] [[[vc9]], 3] [1, 1] : memref<12x12xi8> to memref<?x3xi8, strided<[12, 1], offset: ?>>
+  // CHECK-NEXT:   memref.copy [[vsubview]], [[valloc]] : memref<?x3xi8, strided<[12, 1], offset: ?>> to memref<?x3xi8>
+  // CHECK-NEXT:   mpi.send([[valloc]], [[vc91_i32]], [[v0]]) : memref<?x3xi8>, i32, i32
   // CHECK-NEXT: }
   // CHECK-NEXT: scf.if [[v2]] {
-  // CHECK-NEXT:   mpi.recv([[valloc]], [[vc91_i32]], [[v1]]) : memref<1x12xi8>, i32, i32
-  // CHECK-NEXT:   [[vsubview:%.*]] = memref.subview [[varg0]][0, 0] [1, 12] [1, 1] : memref<12x12xi8> to memref<1x12xi8, strided<[12, 1]>>
-  // CHECK-NEXT:   memref.copy [[valloc]], [[vsubview]] : memref<1x12xi8> to memref<1x12xi8, strided<[12, 1]>>
+  // CHECK-NEXT:   mpi.recv([[valloc]], [[vc91_i32]], [[v1]]) : memref<?x3xi8>, i32, i32
+  // CHECK-NEXT:   [[vsubview:%.*]] = memref.subview [[varg0]][1, 0] [[[vc9]], 3] [1, 1] : memref<12x12xi8> to memref<?x3xi8, strided<[12, 1], offset: 12>>
+  // CHECK-NEXT:   memref.copy [[valloc]], [[vsubview]] : memref<?x3xi8> to memref<?x3xi8, strided<[12, 1], offset: 12>>
   // CHECK-NEXT: }
-  // CHECK-NEXT: memref.dealloc [[valloc]] : memref<1x12xi8>
+  // CHECK-NEXT: memref.dealloc [[valloc]] : memref<?x3xi8>
   // CHECK-NEXT: [[v4:%.*]] = arith.cmpi sge, [[v0]], [[vc0_i32]] : i32
   // CHECK-NEXT: [[v5:%.*]] = arith.cmpi sge, [[v1]], [[vc0_i32]] : i32
-  // CHECK-NEXT: [[valloc_0:%.*]] = memref.alloc() : memref<2x12xi8>
+  // CHECK-NEXT: [[valloc_0:%.*]] = memref.alloc([[vc9]]) : memref<?x4xi8>
   // CHECK-NEXT: scf.if [[v5]] {
-  // CHECK-NEXT:   [[vsubview:%.*]] = memref.subview [[varg0]][0, 0] [2, 12] [1, 1] : memref<12x12xi8> to memref<2x12xi8, strided<[12, 1]>>
-  // CHECK-NEXT:   memref.copy [[vsubview]], [[valloc_0]] : memref<2x12xi8, strided<[12, 1]>> to memref<2x12xi8>
-  // CHECK-NEXT:   mpi.send([[valloc_0]], [[vc91_i32]], [[v1]]) : memref<2x12xi8>, i32, i32
+  // CHECK-NEXT:   [[vsubview:%.*]] = memref.subview [[varg0]][1, %c3] [[[vc9]], 4] [1, 1] : memref<12x12xi8> to memref<?x4xi8, strided<[12, 1], offset: ?>>
+  // CHECK-NEXT:   memref.copy [[vsubview]], [[valloc_0]] : memref<?x4xi8, strided<[12, 1], offset: ?>> to memref<?x4xi8>
+  // CHECK-NEXT:   mpi.send([[valloc_0]], [[vc91_i32]], [[v1]]) : memref<?x4xi8>, i32, i32
   // CHECK-NEXT: }
   // CHECK-NEXT: scf.if [[v4]] {
-  // CHECK-NEXT:   mpi.recv([[valloc_0]], [[vc91_i32]], [[v0]]) : memref<2x12xi8>, i32, i32
-  // CHECK-NEXT:   [[vsubview:%.*]] = memref.subview [[varg0]][[[vc10]], 0] [2, 12] [1, 1] : memref<12x12xi8> to memref<2x12xi8, strided<[12, 1], offset: ?>>
-  // CHECK-NEXT:   memref.copy [[valloc_0]], [[vsubview]] : memref<2x12xi8> to memref<2x12xi8, strided<[12, 1], offset: ?>>
+  // CHECK-NEXT:   mpi.recv([[valloc_0]], [[vc91_i32]], [[v0]]) : memref<?x4xi8>, i32, i32
+  // CHECK-NEXT:   [[vsubview:%.*]] = memref.subview [[varg0]][1, %c8] [[[vc9]], 4] [1, 1] : memref<12x12xi8> to memref<?x4xi8, strided<[12, 1], offset: ?>>
+  // CHECK-NEXT:   memref.copy [[valloc_0]], [[vsubview]] : memref<?x4xi8> to memref<?x4xi8, strided<[12, 1], offset: ?>>
   // CHECK-NEXT: }
-  // CHECK-NEXT: memref.dealloc [[valloc_0]] : memref<2x12xi8>
-  // CHECK-NEXT: [[vdown_linear_idx_1:%.*]], [[vup_linear_idx_2:%.*]] = mesh.neighbors_linear_indices on @mesh0[[[vproc_linear_idx]]#0, [[vproc_linear_idx]]#1, [[vproc_linear_idx]]#2] split_axes = [1] : index, index
+  // CHECK-NEXT: memref.dealloc [[valloc_0]] : memref<?x4xi8>
+  // CHECK-NEXT: [[vdown_linear_idx_1:%.*]], [[vup_linear_idx_2:%.*]] = mesh.neighbors_linear_indices on @mesh0[[[vproc_linear_idx]]#0, [[vproc_linear_idx]]#1, [[vproc_linear_idx]]#2] split_axes = [0] : index, index
   // CHECK-NEXT: [[v6:%.*]] = arith.index_cast [[vdown_linear_idx_1]] : index to i32
   // CHECK-NEXT: [[v7:%.*]] = arith.index_cast [[vup_linear_idx_2]] : index to i32
   // CHECK-NEXT: [[v8:%.*]] = arith.cmpi sge, [[v7]], [[vc0_i32]] : i32
   // CHECK-NEXT: [[v9:%.*]] = arith.cmpi sge, [[v6]], [[vc0_i32]] : i32
-  // CHECK-NEXT: [[valloc_3:%.*]] = memref.alloc([[vc9]]) : memref<?x3xi8>
+  // CHECK-NEXT: [[valloc_3:%.*]] = memref.alloc() : memref<1x12xi8>
   // CHECK-NEXT: scf.if [[v9]] {
-  // CHECK-NEXT:   [[vsubview:%.*]] = memref.subview [[varg0]][1, %c8] [[[vc9]], 3] [1, 1] : memref<12x12xi8> to memref<?x3xi8, strided<[12, 1], offset: ?>>
-  // CHECK-NEXT:   memref.copy [[vsubview]], [[valloc_3]] : memref<?x3xi8, strided<[12, 1], offset: ?>> to memref<?x3xi8>
-  // CHECK-NEXT:   mpi.send([[valloc_3]], [[vc91_i32]], [[v6]]) : memref<?x3xi8>, i32, i32
+  // CHECK-NEXT:   [[vsubview:%.*]] = memref.subview [[varg0]][[[vc9]], 0] [1, 12] [1, 1] : memref<12x12xi8> to memref<1x12xi8, strided<[12, 1], offset: ?>>
+  // CHECK-NEXT:   memref.copy [[vsubview]], [[valloc_3]] : memref<1x12xi8, strided<[12, 1], offset: ?>> to memref<1x12xi8>
+  // CHECK-NEXT:   mpi.send([[valloc_3]], [[vc91_i32]], [[v6]]) : memref<1x12xi8>, i32, i32
   // CHECK-NEXT: }
   // CHECK-NEXT: scf.if [[v8]] {
-  // CHECK-NEXT:   mpi.recv([[valloc_3]], [[vc91_i32]], [[v7]]) : memref<?x3xi8>, i32, i32
-  // CHECK-NEXT:   [[vsubview:%.*]] = memref.subview [[varg0]][1, 0] [[[vc9]], 3] [1, 1] : memref<12x12xi8> to memref<?x3xi8, strided<[12, 1], offset: 12>>
-  // CHECK-NEXT:   memref.copy [[valloc_3]], [[vsubview]] : memref<?x3xi8> to memref<?x3xi8, strided<[12, 1], offset: 12>>
+  // CHECK-NEXT:   mpi.recv([[valloc_3]], [[vc91_i32]], [[v7]]) : memref<1x12xi8>, i32, i32
+  // CHECK-NEXT:   [[vsubview:%.*]] = memref.subview [[varg0]][0, 0] [1, 12] [1, 1] : memref<12x12xi8> to memref<1x12xi8, strided<[12, 1]>>
+  // CHECK-NEXT:   memref.copy [[valloc_3]], [[vsubview]] : memref<1x12xi8> to memref<1x12xi8, strided<[12, 1]>>
   // CHECK-NEXT: }
-  // CHECK-NEXT: memref.dealloc [[valloc_3]] : memref<?x3xi8>
+  // CHECK-NEXT: memref.dealloc [[valloc_3]] : memref<1x12xi8>
   // CHECK-NEXT: [[v10:%.*]] = arith.cmpi sge, [[v6]], [[vc0_i32]] : i32
   // CHECK-NEXT: [[v11:%.*]] = arith.cmpi sge, [[v7]], [[vc0_i32]] : i32
-  // CHECK-NEXT: [[valloc_4:%.*]] = memref.alloc([[vc9]]) : memref<?x4xi8>
+  // CHECK-NEXT: [[valloc_4:%.*]] = memref.alloc() : memref<2x12xi8>
   // CHECK-NEXT: scf.if [[v11]] {
-  // CHECK-NEXT:   [[vsubview:%.*]] = memref.subview [[varg0]][1, 0] [[[vc9]], 4] [1, 1] : memref<12x12xi8> to memref<?x4xi8, strided<[12, 1], offset: 12>>
-  // CHECK-NEXT:   memref.copy [[vsubview]], [[valloc_4]] : memref<?x4xi8, strided<[12, 1], offset: 12>> to memref<?x4xi8>
-  // CHECK-NEXT:   mpi.send([[valloc_4]], [[vc91_i32]], [[v7]]) : memref<?x4xi8>, i32, i32
+  // CHECK-NEXT:   [[vsubview:%.*]] = memref.subview [[varg0]][[[vc1]], 0] [2, 12] [1, 1] : memref<12x12xi8> to memref<2x12xi8, strided<[12, 1], offset: ?>>
+  // CHECK-NEXT:   memref.copy [[vsubview]], [[valloc_4]] : memref<2x12xi8, strided<[12, 1], offset: ?>> to memref<2x12xi8>
+  // CHECK-NEXT:   mpi.send([[valloc_4]], [[vc91_i32]], [[v7]]) : memref<2x12xi8>, i32, i32
   // CHECK-NEXT: }
   // CHECK-NEXT: scf.if [[v10]] {
-  // CHECK-NEXT:   mpi.recv([[valloc_4]], [[vc91_i32]], [[v6]]) : memref<?x4xi8>, i32, i32
-  // CHECK-NEXT:   [[vsubview:%.*]] = memref.subview [[varg0]][1, %c8] [[[vc9]], 4] [1, 1] : memref<12x12xi8> to memref<?x4xi8, strided<[12, 1], offset: ?>>
-  // CHECK-NEXT:   memref.copy [[valloc_4]], [[vsubview]] : memref<?x4xi8> to memref<?x4xi8, strided<[12, 1], offset: ?>>
+  // CHECK-NEXT:   mpi.recv([[valloc_4]], [[vc91_i32]], [[v6]]) : memref<2x12xi8>, i32, i32
+  // CHECK-NEXT:   [[vsubview:%.*]] = memref.subview [[varg0]][[[vc10]], 0] [2, 12] [1, 1] : memref<12x12xi8> to memref<2x12xi8, strided<[12, 1], offset: ?>>
+  // CHECK-NEXT:   memref.copy [[valloc_4]], [[vsubview]] : memref<2x12xi8> to memref<2x12xi8, strided<[12, 1], offset: ?>>
   // CHECK-NEXT: }
-  // CHECK-NEXT: memref.dealloc [[valloc_4]] : memref<?x4xi8>
-  // CHECK-NEXT: return
+  // CHECK-NEXT: memref.dealloc [[valloc_4]] : memref<2x12xi8>
   mesh.update_halo %arg0 on @mesh0 split_axes = [[0], [1]]
       halo_sizes = [1, 2, 3, 4]
       : memref<12x12xi8>
+  // CHECK-NEXT: return
   return
 }

>From c5c0c3c44a794de66566b64f8722bbc35c0030c5 Mon Sep 17 00:00:00 2001
From: "Schlimbach, Frank" <frank.schlimbach at intel.com>
Date: Wed, 21 Aug 2024 12:25:08 +0200
Subject: [PATCH 22/26] addressed review comments (docs, formatting)

---
 .../mlir/Conversion/MeshToMPI/MeshToMPI.h        |  2 +-
 mlir/include/mlir/Conversion/Passes.td           |  2 +-
 mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td     |  6 +++---
 mlir/lib/Conversion/MeshToMPI/MeshToMPI.cpp      | 16 +++++++++-------
 4 files changed, 14 insertions(+), 12 deletions(-)

diff --git a/mlir/include/mlir/Conversion/MeshToMPI/MeshToMPI.h b/mlir/include/mlir/Conversion/MeshToMPI/MeshToMPI.h
index 6a2c196da45577..b8803f386f7356 100644
--- a/mlir/include/mlir/Conversion/MeshToMPI/MeshToMPI.h
+++ b/mlir/include/mlir/Conversion/MeshToMPI/MeshToMPI.h
@@ -1,4 +1,4 @@
-//===- MeshToMPI.h - Convert Mesh to MPI dialect --*- C++ -*-===//
+//===- MeshToMPI.h - Convert Mesh to MPI dialect ----------------*- C++ -*-===//
 //
 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
 // See https://llvm.org/LICENSE.txt for license information.
diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td
index 83e0c5a06c43f7..2781fab917048d 100644
--- a/mlir/include/mlir/Conversion/Passes.td
+++ b/mlir/include/mlir/Conversion/Passes.td
@@ -886,7 +886,7 @@ def ConvertMeshToMPIPass : Pass<"convert-mesh-to-mpi"> {
   let summary = "Convert Mesh dialect to MPI dialect.";
   let description = [{
     This pass converts communication operations
-    from the Mesh dialect to operations from the MPI dialect.
+    from the Mesh dialect to the MPI dialect.
   }];
   let dependentDialects = [
     "memref::MemRefDialect",
diff --git a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td
index 2c2b6e20f3654d..e6f61aa84a1312 100644
--- a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td
+++ b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td
@@ -162,7 +162,7 @@ def Mesh_NeighborsLinearIndicesOp : Mesh_Op<"neighbors_linear_indices", [
   DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>
 ]> {
   let summary =
-      "For given split axes get the linear index the direct neighbor processes.";
+      "For given split axes get the linear indices of the direct neighbor processes.";
   let description = [{
     Example:
     ```
@@ -172,8 +172,8 @@ def Mesh_NeighborsLinearIndicesOp : Mesh_Op<"neighbors_linear_indices", [
     Given `@mesh` with shape `(10, 20, 30)`,
           `device` = `(1, 2, 3)`
           `$split_axes` = `[1]`
-    it returns the linear indices of the processes at positions `(1, 1, 3)`: `633`
-    and `(1, 3, 3)`: `693`.
+    returns two indices, `633` and `693`, which correspond to the index of the previous
+    process `(1, 1, 3)`, and the next process `(1, 3, 3)` along the split axis `1`.
 
     A negative value is returned if `$device` has no neighbor in the given
     direction along the given `split_axes`.
diff --git a/mlir/lib/Conversion/MeshToMPI/MeshToMPI.cpp b/mlir/lib/Conversion/MeshToMPI/MeshToMPI.cpp
index 9cf9458ce2b687..ea1323e43462cd 100644
--- a/mlir/lib/Conversion/MeshToMPI/MeshToMPI.cpp
+++ b/mlir/lib/Conversion/MeshToMPI/MeshToMPI.cpp
@@ -45,15 +45,17 @@ struct ConvertUpdateHaloOp
   mlir::LogicalResult
   matchAndRewrite(mlir::mesh::UpdateHaloOp op,
                   mlir::PatternRewriter &rewriter) const override {
+    // The input/output memref is assumed to be in C memory order.
     // Halos are exchanged as 2 blocks per dimension (one for each side: down
-    // and up). It is assumed that the last dim in a default memref is
-    // contiguous, hence iteration starts with the complete halo on the first
-    // dim which should be contiguous (unless the source is not). The size of
-    // the exchanged data will decrease when iterating over dimensions. That's
-    // good because the halos of last dim will be most fragmented.
+    // and up). For each haloed dimension `d`, the exchanged blocks are
+    // expressed as multi-dimensional subviews. The subviews include potential
+    // halos of higher dimensions `dh > d`, no halos for the lower dimensions
+    // `dl < d` and for dimension `d` the currently exchanged halo only.
+    // By iterating from higher to lower dimensions this also updates the halos
+    // in the 'corners'.
     // memref.subview is used to read and write the halo data from and to the
-    // local data. subviews and halos have dynamic and static values, so
-    // OpFoldResults are used whenever possible.
+    // local data. Because subviews and halos can have mixed dynamic and static
+    // shapes, OpFoldResults are used whenever possible.
 
     SymbolTableCollection symbolTableCollection;
     auto loc = op.getLoc();

>From 9578fc415b699b5f4a4e756acfcc0af18e4c11f2 Mon Sep 17 00:00:00 2001
From: "Schlimbach, Frank" <frank.schlimbach at intel.com>
Date: Tue, 3 Sep 2024 10:41:22 +0200
Subject: [PATCH 23/26] newline

---
 mlir/include/mlir/Conversion/MeshToMPI/MeshToMPI.h | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/mlir/include/mlir/Conversion/MeshToMPI/MeshToMPI.h b/mlir/include/mlir/Conversion/MeshToMPI/MeshToMPI.h
index b8803f386f7356..04271f8ab67b95 100644
--- a/mlir/include/mlir/Conversion/MeshToMPI/MeshToMPI.h
+++ b/mlir/include/mlir/Conversion/MeshToMPI/MeshToMPI.h
@@ -24,4 +24,4 @@ std::unique_ptr<Pass> createConvertMeshToMPIPass();
 
 } // namespace mlir
 
-#endif // MLIR_CONVERSION_MESHTOMPI_MESHTOMPI_H
\ No newline at end of file
+#endif // MLIR_CONVERSION_MESHTOMPI_MESHTOMPI_H

>From a7c9fe7b7573f766a1a4416ced0b268e97d3e854 Mon Sep 17 00:00:00 2001
From: "Schlimbach, Frank" <frank.schlimbach at intel.com>
Date: Thu, 31 Oct 2024 17:54:52 +0100
Subject: [PATCH 24/26] removing source from UpdateHaloOp, because not required
 for destination passing style

---
 mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td  | 19 +++++++------------
 mlir/lib/Conversion/MeshToMPI/MeshToMPI.cpp   | 10 +++++-----
 .../Dialect/Mesh/Transforms/Spmdization.cpp   |  4 +---
 mlir/test/Dialect/Mesh/ops.mlir               | 16 ++++++++--------
 mlir/test/Dialect/Mesh/spmdization.mlir       |  4 ++--
 5 files changed, 23 insertions(+), 30 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td
index e6f61aa84a1312..3c52c63330e95f 100644
--- a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td
+++ b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td
@@ -1095,8 +1095,7 @@ def Mesh_UpdateHaloOp : Mesh_Op<"update_halo", [
   TypesMatchWith<
     "result has same type as destination",
     "result", "destination", "$_self">,
-  DeclareOpInterfaceMethods<SymbolUserOpInterface>,
-  AttrSizedOperandSegments
+  DeclareOpInterfaceMethods<SymbolUserOpInterface>
 ]> {
   let summary = "Update halo data.";
   let description = [{
@@ -1105,7 +1104,7 @@ def Mesh_UpdateHaloOp : Mesh_Op<"update_halo", [
     on the remote devices. Changes might be caused by mutating operations
     and/or if the new halo regions are larger than the existing ones.
 
-    Source and destination might have different halo sizes.
+    Destination is supposed to be initialized with the local data (not halos).
 
     Assumes all devices hold tensors with same-sized halo data as specified
     by `source_halo_sizes/static_source_halo_sizes` and
@@ -1117,25 +1116,21 @@ def Mesh_UpdateHaloOp : Mesh_Op<"update_halo", [
 
   }];
   let arguments = (ins
-    AnyTypeOf<[AnyNon0RankedMemRef, AnyNon0RankedTensor]>:$source,
     AnyTypeOf<[AnyNon0RankedMemRef, AnyNon0RankedTensor]>:$destination,
     FlatSymbolRefAttr:$mesh,
     Mesh_MeshAxesArrayAttr:$split_axes,
-    Variadic<I64>:$source_halo_sizes,
-    DefaultValuedAttr<DenseI64ArrayAttr, "{}">:$static_source_halo_sizes,
-    Variadic<I64>:$destination_halo_sizes,
-    DefaultValuedAttr<DenseI64ArrayAttr, "{}">:$static_destination_halo_sizes
+    Variadic<I64>:$halo_sizes,
+    DefaultValuedAttr<DenseI64ArrayAttr, "{}">:$static_halo_sizes
   );
   let results = (outs
     AnyTypeOf<[AnyNon0RankedMemRef, AnyNon0RankedTensor]>:$result
   );
   let assemblyFormat = [{
-    $source `into` $destination
+    $destination
     `on` $mesh
     `split_axes` `=` $split_axes
-    (`source_halo_sizes` `=` custom<DynamicIndexList>($source_halo_sizes, $static_source_halo_sizes)^)?
-    (`destination_halo_sizes` `=` custom<DynamicIndexList>($destination_halo_sizes, $static_destination_halo_sizes)^)?
-    attr-dict `:` type($source) `->` type($result)
+    (`halo_sizes` `=` custom<DynamicIndexList>($halo_sizes, $static_halo_sizes)^)?
+    attr-dict `:` type($result)
   }];
   let extraClassDeclaration = [{
     MutableOperandRange getDpsInitsMutable() { return getDestinationMutable(); }
diff --git a/mlir/lib/Conversion/MeshToMPI/MeshToMPI.cpp b/mlir/lib/Conversion/MeshToMPI/MeshToMPI.cpp
index ea1323e43462cd..11d7c0e08f1a67 100644
--- a/mlir/lib/Conversion/MeshToMPI/MeshToMPI.cpp
+++ b/mlir/lib/Conversion/MeshToMPI/MeshToMPI.cpp
@@ -70,13 +70,13 @@ struct ConvertUpdateHaloOp
                            cast<IntegerAttr>(v.get<Attribute>()).getInt()));
     };
 
-    auto array = op.getInput();
-    auto rank = array.getType().getRank();
+    auto array = op.getDestination();
+    auto rank = cast<ShapedType>(array.getType()).getRank();
     auto opSplitAxes = op.getSplitAxes().getAxes();
     auto mesh = op.getMesh();
     auto meshOp = getMesh(op, symbolTableCollection);
     auto haloSizes = getMixedValues(op.getStaticHaloSizes(),
-                                    op.getDynamicHaloSizes(), rewriter);
+                                    op.getHaloSizes(), rewriter);
     // subviews need Index values
     for (auto &sz : haloSizes) {
       if (sz.is<Value>()) {
@@ -94,7 +94,7 @@ struct ConvertUpdateHaloOp
     auto currHaloDim = -1; // halo sizes are provided for split dimensions only
     // we need the actual shape to compute offsets and sizes
     for (auto i = 0; i < rank; ++i) {
-      auto s = array.getType().getShape()[i];
+      auto s = cast<ShapedType>(array.getType()).getShape()[i];
       if (ShapedType::isDynamic(s)) {
         shape[i] = rewriter.create<memref::DimOp>(loc, array, s).getResult();
       } else {
@@ -176,7 +176,7 @@ struct ConvertUpdateHaloOp
         auto hasTo = rewriter.create<arith::CmpIOp>(
             loc, arith::CmpIPredicate::sge, to, zero);
         auto buffer = rewriter.create<memref::AllocOp>(
-            loc, dimSizes, array.getType().getElementType());
+            loc, dimSizes, cast<ShapedType>(array.getType()).getElementType());
         // if has neighbor: copy halo data from array to buffer and send
         rewriter.create<scf::IfOp>(
             loc, hasTo, [&](OpBuilder &builder, Location loc) {
diff --git a/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp b/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp
index b4d088cbd7088d..327ea0991e4e1e 100644
--- a/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp
+++ b/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp
@@ -496,11 +496,9 @@ tryUpdateHaloInResharding(ImplicitLocOpBuilder &builder, MeshOp mesh,
               sourceShard.getLoc(),
               RankedTensorType::get(outShape,
                                     sourceShard.getType().getElementType()),
-              sourceShard, initOprnd, mesh.getSymName(),
+              initOprnd, mesh.getSymName(),
               MeshAxesArrayAttr::get(builder.getContext(),
                                      sourceSharding.getSplitAxes()),
-              sourceSharding.getDynamicHaloSizes(),
-              sourceSharding.getStaticHaloSizes(),
               targetSharding.getDynamicHaloSizes(),
               targetSharding.getStaticHaloSizes())
           .getResult();
diff --git a/mlir/test/Dialect/Mesh/ops.mlir b/mlir/test/Dialect/Mesh/ops.mlir
index d8df01c3d6520d..978de4939ee77c 100644
--- a/mlir/test/Dialect/Mesh/ops.mlir
+++ b/mlir/test/Dialect/Mesh/ops.mlir
@@ -615,16 +615,16 @@ func.func @update_halo(
     // CHECK-SAME: %[[ARG:.*]]: memref<12x12xi8>
     %arg0 : memref<12x12xi8>) {
   // CHECK-NEXT: %[[C2:.*]] = arith.constant 2 : i64
-  // CHECK-NEXT: %[[UH1:.*]] = mesh.update_halo %[[ARG]] into %[[ARG]] on @mesh0
+  // CHECK-NEXT: %[[UH1:.*]] = mesh.update_halo %[[ARG]] on @mesh0
   // CHECK-SAME: split_axes = {{\[\[}}0]]
-  // CHECK-SAME: halo_sizes = [2, %c2_i64] : memref<12x12xi8> -> memref<12x12xi8>
+  // CHECK-SAME: halo_sizes = [2, %c2_i64] : memref<12x12xi8>
   %c2 = arith.constant 2 : i64
-  %uh1 = mesh.update_halo %arg0 into %arg0 on @mesh0 split_axes = [[0]]
-    source_halo_sizes = [2, %c2] : memref<12x12xi8> -> memref<12x12xi8>
-  // CHECK-NEXT: %[[UH2:.*]] = mesh.update_halo %[[ARG]] into %[[UH1]] on @mesh0
+  %uh1 = mesh.update_halo %arg0 on @mesh0 split_axes = [[0]]
+    halo_sizes = [2, %c2] : memref<12x12xi8>
+  // CHECK-NEXT: %[[UH2:.*]] = mesh.update_halo %[[UH1]] on @mesh0
   // CHECK-SAME: split_axes = {{\[\[}}0], [1]]
-  // CHECK-SAME: halo_sizes = [2, 2, %[[C2]], 2] : memref<12x12xi8> -> memref<12x12xi8>
-  %uh2 = mesh.update_halo %arg0 into %uh1 on @mesh0 split_axes = [[0], [1]]
-    source_halo_sizes = [2, 2, %c2, 2] : memref<12x12xi8> -> memref<12x12xi8>
+  // CHECK-SAME: halo_sizes = [2, 2, %[[C2]], 2] : memref<12x12xi8>
+  %uh2 = mesh.update_halo %uh1 on @mesh0 split_axes = [[0], [1]]
+    halo_sizes = [2, 2, %c2, 2] : memref<12x12xi8>
   return
 }
diff --git a/mlir/test/Dialect/Mesh/spmdization.mlir b/mlir/test/Dialect/Mesh/spmdization.mlir
index 22ddb72569835d..c1b96fda0f4a74 100644
--- a/mlir/test/Dialect/Mesh/spmdization.mlir
+++ b/mlir/test/Dialect/Mesh/spmdization.mlir
@@ -226,7 +226,7 @@ func.func @test_shard_update_halo(%arg0: tensor<1200x1200xi64>) -> tensor<1200x1
   %sharding = mesh.sharding @mesh_1d_4 split_axes = [[0]] : !mesh.sharding
   // CHECK: %[[T:.*]] = tensor.empty() : tensor<304x1200xi64>
   // CHECK: %[[inserted_slice:.*]] = tensor.insert_slice %[[IN1]] into %[[T]][2, 0] [300, 1200] [1, 1] : tensor<300x1200xi64> into tensor<304x1200xi64>
-  // CHECK: %[[UH:.*]] = mesh.update_halo %[[IN1]] into %[[inserted_slice]] on @mesh_1d_4 split_axes = {{\[\[0]]}} destination_halo_sizes = [2, 2] : tensor<300x1200xi64> -> tensor<304x1200xi64>
+  // CHECK: %[[UH:.*]] = mesh.update_halo %[[inserted_slice]] on @mesh_1d_4 split_axes = {{\[\[0]]}} halo_sizes = [2, 2] : tensor<304x1200xi64>
   %sharding_annotated = mesh.shard %arg0 to %sharding : tensor<1200x1200xi64>
   %sharding_0 = mesh.sharding @mesh_1d_4 split_axes = [[0]] halo_sizes = [2, 2] : !mesh.sharding
   %sharding_annotated_1 = mesh.shard %sharding_annotated to %sharding_0 : tensor<1200x1200xi64>
@@ -242,7 +242,7 @@ func.func @test_shard_update_halo2d(%arg0: tensor<1200x1200xi64>) -> tensor<1200
   %sharding = mesh.sharding @mesh4x4 split_axes = [[0], [1]] : !mesh.sharding
   // CHECK: %[[T:.*]] = tensor.empty() : tensor<303x307xi64>
   // CHECK: %[[inserted_slice:.*]] = tensor.insert_slice %[[IN1]] into %[[T]][1, 3] [300, 300] [1, 1] : tensor<300x300xi64> into tensor<303x307xi64>
-  // CHECK: %[[UH:.*]] = mesh.update_halo %[[IN1]] into %[[inserted_slice]] on @mesh4x4 split_axes = {{\[\[}}0], [1]] destination_halo_sizes = [1, 2, 3, 4] : tensor<300x300xi64> -> tensor<303x307xi64>
+  // CHECK: %[[UH:.*]] = mesh.update_halo %[[inserted_slice]] on @mesh4x4 split_axes = {{\[\[}}0], [1]] halo_sizes = [1, 2, 3, 4] : tensor<303x307xi64>
   %sharding_annotated = mesh.shard %arg0 to %sharding : tensor<1200x1200xi64>
   %sharding_0 = mesh.sharding @mesh4x4 split_axes = [[0], [1]] halo_sizes = [1, 2, 3, 4] : !mesh.sharding
   %sharding_annotated_1 = mesh.shard %sharding_annotated to %sharding_0 : tensor<1200x1200xi64>

>From 4214070e5c8ec0040723418145b60f6ead00616f Mon Sep 17 00:00:00 2001
From: "Schlimbach, Frank" <frank.schlimbach at intel.com>
Date: Tue, 5 Nov 2024 16:55:08 +0100
Subject: [PATCH 25/26] clang-format

---
 mlir/lib/Conversion/MeshToMPI/MeshToMPI.cpp | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/mlir/lib/Conversion/MeshToMPI/MeshToMPI.cpp b/mlir/lib/Conversion/MeshToMPI/MeshToMPI.cpp
index 11d7c0e08f1a67..5d9ea9cfccf8d4 100644
--- a/mlir/lib/Conversion/MeshToMPI/MeshToMPI.cpp
+++ b/mlir/lib/Conversion/MeshToMPI/MeshToMPI.cpp
@@ -75,8 +75,8 @@ struct ConvertUpdateHaloOp
     auto opSplitAxes = op.getSplitAxes().getAxes();
     auto mesh = op.getMesh();
     auto meshOp = getMesh(op, symbolTableCollection);
-    auto haloSizes = getMixedValues(op.getStaticHaloSizes(),
-                                    op.getHaloSizes(), rewriter);
+    auto haloSizes =
+        getMixedValues(op.getStaticHaloSizes(), op.getHaloSizes(), rewriter);
     // subviews need Index values
     for (auto &sz : haloSizes) {
       if (sz.is<Value>()) {

>From b13138b7931ef24ed9ea0921a8c8fe31aa7ed31c Mon Sep 17 00:00:00 2001
From: "Schlimbach, Frank" <frank.schlimbach at intel.com>
Date: Tue, 5 Nov 2024 19:50:56 +0100
Subject: [PATCH 26/26] allow tensor as destination in UpdateHaloOp and fixing
 its tests

---
 mlir/include/mlir/Conversion/Passes.td        |  3 +-
 mlir/lib/Conversion/MeshToMPI/MeshToMPI.cpp   | 23 ++++++-
 .../MeshToMPI/convert-mesh-to-mpi.mlir        | 65 ++++++++++++++++---
 3 files changed, 79 insertions(+), 12 deletions(-)

diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td
index 2781fab917048d..43015ad5b11e65 100644
--- a/mlir/include/mlir/Conversion/Passes.td
+++ b/mlir/include/mlir/Conversion/Passes.td
@@ -891,7 +891,8 @@ def ConvertMeshToMPIPass : Pass<"convert-mesh-to-mpi"> {
   let dependentDialects = [
     "memref::MemRefDialect",
     "mpi::MPIDialect",
-    "scf::SCFDialect"
+    "scf::SCFDialect",
+    "bufferization::BufferizationDialect"
   ];
 }
 
diff --git a/mlir/lib/Conversion/MeshToMPI/MeshToMPI.cpp b/mlir/lib/Conversion/MeshToMPI/MeshToMPI.cpp
index 5d9ea9cfccf8d4..b1b58584aaae24 100644
--- a/mlir/lib/Conversion/MeshToMPI/MeshToMPI.cpp
+++ b/mlir/lib/Conversion/MeshToMPI/MeshToMPI.cpp
@@ -13,6 +13,7 @@
 #include "mlir/Conversion/MeshToMPI/MeshToMPI.h"
 
 #include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
 #include "mlir/Dialect/MPI/IR/MPI.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/Dialect/Mesh/IR/MeshOps.h"
@@ -70,7 +71,16 @@ struct ConvertUpdateHaloOp
                            cast<IntegerAttr>(v.get<Attribute>()).getInt()));
     };
 
-    auto array = op.getDestination();
+    auto dest = op.getDestination();
+    auto dstShape = cast<ShapedType>(dest.getType()).getShape();
+    Value array = dest;
+    if (isa<RankedTensorType>(array.getType())) {
+      // If the destination is a tensor, we need to cast it to a memref
+      auto tensorType = MemRefType::get(
+          dstShape, cast<ShapedType>(array.getType()).getElementType());
+      array = rewriter.create<bufferization::ToMemrefOp>(loc, tensorType, array)
+                  .getResult();
+    }
     auto rank = cast<ShapedType>(array.getType()).getRank();
     auto opSplitAxes = op.getSplitAxes().getAxes();
     auto mesh = op.getMesh();
@@ -94,7 +104,7 @@ struct ConvertUpdateHaloOp
     auto currHaloDim = -1; // halo sizes are provided for split dimensions only
     // we need the actual shape to compute offsets and sizes
     for (auto i = 0; i < rank; ++i) {
-      auto s = cast<ShapedType>(array.getType()).getShape()[i];
+      auto s = dstShape[i];
       if (ShapedType::isDynamic(s)) {
         shape[i] = rewriter.create<memref::DimOp>(loc, array, s).getResult();
       } else {
@@ -213,7 +223,14 @@ struct ConvertUpdateHaloOp
       // on to next halo
       --currHaloDim;
     }
-    rewriter.eraseOp(op);
+
+    if (isa<MemRefType>(op.getResult().getType())) {
+      rewriter.replaceOp(op, array);
+    } else {
+      assert(isa<RankedTensorType>(op.getResult().getType()));
+      rewriter.replaceOp(op, rewriter.create<bufferization::ToTensorOp>(
+                                 loc, op.getResult().getType(), array));
+    }
     return mlir::success();
   }
 };
diff --git a/mlir/test/Conversion/MeshToMPI/convert-mesh-to-mpi.mlir b/mlir/test/Conversion/MeshToMPI/convert-mesh-to-mpi.mlir
index c3b0dc12e6d746..d05c53bd83aaf9 100644
--- a/mlir/test/Conversion/MeshToMPI/convert-mesh-to-mpi.mlir
+++ b/mlir/test/Conversion/MeshToMPI/convert-mesh-to-mpi.mlir
@@ -53,7 +53,7 @@ func.func @update_halo_1d_first(
 // CHECK-LABEL: func @update_halo_1d_second
 func.func @update_halo_1d_second(
   // CHECK-SAME: [[varg0:%.*]]: memref<12x12xi8>
-  %arg0 : memref<12x12xi8>) {
+  %arg0 : memref<12x12xi8>) -> memref<12x12xi8> {
   //CHECK-NEXT: [[vc7:%.*]] = arith.constant 7 : index
   //CHECK-NEXT: [[vc9:%.*]] = arith.constant 9 : index
   //CHECK-NEXT: [[vc2:%.*]] = arith.constant 2 : index
@@ -91,16 +91,16 @@ func.func @update_halo_1d_second(
   //CHECK-NEXT:   memref.copy [[valloc_0]], [[vsubview]] : memref<12x3xi8> to memref<12x3xi8, strided<[12, 1], offset: ?>>
   //CHECK-NEXT: }
   //CHECK-NEXT: memref.dealloc [[valloc_0]] : memref<12x3xi8>
-  mesh.update_halo %arg0 on @mesh0 split_axes = [[], [3]]
+  %res = mesh.update_halo %arg0 on @mesh0 split_axes = [[], [3]]
     halo_sizes = [2, 3] : memref<12x12xi8>
-  //CHECK-NEXT: return
-  return
+  //CHECK-NEXT: return [[varg0]] : memref<12x12xi8>
+  return %res : memref<12x12xi8>
 }
 
 // CHECK-LABEL: func @update_halo_2d
 func.func @update_halo_2d(
     // CHECK-SAME: [[varg0:%.*]]: memref<12x12xi8>
-    %arg0 : memref<12x12xi8>) {
+    %arg0 : memref<12x12xi8>) -> memref<12x12xi8> {
   // CHECK-NEXT: [[vc10:%.*]] = arith.constant 10 : index
   // CHECK-NEXT: [[vc1:%.*]] = arith.constant 1 : index
   // CHECK-NEXT: [[vc5:%.*]] = arith.constant 5 : index
@@ -172,9 +172,58 @@ func.func @update_halo_2d(
   // CHECK-NEXT:   memref.copy [[valloc_4]], [[vsubview]] : memref<2x12xi8> to memref<2x12xi8, strided<[12, 1], offset: ?>>
   // CHECK-NEXT: }
   // CHECK-NEXT: memref.dealloc [[valloc_4]] : memref<2x12xi8>
-  mesh.update_halo %arg0 on @mesh0 split_axes = [[0], [1]]
+  %res = mesh.update_halo %arg0 on @mesh0 split_axes = [[0], [1]]
       halo_sizes = [1, 2, 3, 4]
       : memref<12x12xi8>
-  // CHECK-NEXT: return
-  return
+  // CHECK-NEXT: return [[varg0]] : memref<12x12xi8>
+  return %res : memref<12x12xi8>
+}
+
+// CHECK-LABEL: func @update_halo_1d_tnsr
+func.func @update_halo_1d_tnsr(
+  // CHECK-SAME: [[varg0:%.*]]: tensor<12x12xi8>
+  %arg0 : tensor<12x12xi8>) -> tensor<12x12xi8> {
+  // CHECK-NEXT: [[vc7:%.*]] = arith.constant 7 : index
+  // CHECK-NEXT: [[vc9:%.*]] = arith.constant 9 : index
+  // CHECK-NEXT: [[vc2:%.*]] = arith.constant 2 : index
+  // CHECK-NEXT: [[vc0_i32:%.*]] = arith.constant 0 : i32
+  // CHECK-NEXT: [[vc91_i32:%.*]] = arith.constant 91 : i32
+  // CHECK-NEXT: [[mref:%.*]] = bufferization.to_memref %arg0 : memref<12x12xi8>
+  // CHECK-NEXT: [[vproc_linear_idx:%.*]]:3 = mesh.process_multi_index on @mesh0 : index, index, index
+  // CHECK-NEXT: [[vdown_linear_idx:%.*]], [[vup_linear_idx:%.*]] = mesh.neighbors_linear_indices on @mesh0[[[vproc_linear_idx]]#0, [[vproc_linear_idx]]#1, [[vproc_linear_idx]]#2] split_axes = [0] : index, index
+  // CHECK-NEXT: [[v0:%.*]] = arith.index_cast [[vdown_linear_idx]] : index to i32
+  // CHECK-NEXT: [[v1:%.*]] = arith.index_cast [[vup_linear_idx]] : index to i32
+  // CHECK-NEXT: [[v2:%.*]] = arith.cmpi sge, [[v1]], [[vc0_i32]] : i32
+  // CHECK-NEXT: [[v3:%.*]] = arith.cmpi sge, [[v0]], [[vc0_i32]] : i32
+  // CHECK-NEXT: [[valloc:%.*]] = memref.alloc() : memref<2x12xi8>
+  // CHECK-NEXT: scf.if [[v3]] {
+  // CHECK-NEXT:   [[vsubview:%.*]] = memref.subview [[mref]][[[vc7]], 0] [2, 12] [1, 1] : memref<12x12xi8> to memref<2x12xi8, strided<[12, 1], offset: ?>>
+  // CHECK-NEXT:   memref.copy [[vsubview]], [[valloc]] : memref<2x12xi8, strided<[12, 1], offset: ?>> to memref<2x12xi8>
+  // CHECK-NEXT:   mpi.send([[valloc]], [[vc91_i32]], [[v0]]) : memref<2x12xi8>, i32, i32
+  // CHECK-NEXT: }
+  // CHECK-NEXT: scf.if [[v2]] {
+  // CHECK-NEXT:   mpi.recv([[valloc]], [[vc91_i32]], [[v1]]) : memref<2x12xi8>, i32, i32
+  // CHECK-NEXT:   [[vsubview:%.*]] = memref.subview [[mref]][0, 0] [2, 12] [1, 1] : memref<12x12xi8> to memref<2x12xi8, strided<[12, 1]>>
+  // CHECK-NEXT:   memref.copy [[valloc]], [[vsubview]] : memref<2x12xi8> to memref<2x12xi8, strided<[12, 1]>>
+  // CHECK-NEXT: }
+  // CHECK-NEXT: memref.dealloc [[valloc]] : memref<2x12xi8>
+  // CHECK-NEXT: [[v4:%.*]] = arith.cmpi sge, [[v0]], [[vc0_i32]] : i32
+  // CHECK-NEXT: [[v5:%.*]] = arith.cmpi sge, [[v1]], [[vc0_i32]] : i32
+  // CHECK-NEXT: [[valloc_0:%.*]] = memref.alloc() : memref<3x12xi8>
+  // CHECK-NEXT: scf.if [[v5]] {
+  // CHECK-NEXT:   [[vsubview:%.*]] = memref.subview [[mref]][[[vc2]], 0] [3, 12] [1, 1] : memref<12x12xi8> to memref<3x12xi8, strided<[12, 1], offset: ?>>
+  // CHECK-NEXT:   memref.copy [[vsubview]], [[valloc_0]] : memref<3x12xi8, strided<[12, 1], offset: ?>> to memref<3x12xi8>
+  // CHECK-NEXT:   mpi.send([[valloc_0]], [[vc91_i32]], [[v1]]) : memref<3x12xi8>, i32, i32
+  // CHECK-NEXT: }
+  // CHECK-NEXT: scf.if [[v4]] {
+  // CHECK-NEXT:   mpi.recv([[valloc_0]], [[vc91_i32]], [[v0]]) : memref<3x12xi8>, i32, i32
+  // CHECK-NEXT:   [[vsubview:%.*]] = memref.subview [[mref]][[[vc9]], 0] [3, 12] [1, 1] : memref<12x12xi8> to memref<3x12xi8, strided<[12, 1], offset: ?>>
+  // CHECK-NEXT:   memref.copy [[valloc_0]], [[vsubview]] : memref<3x12xi8> to memref<3x12xi8, strided<[12, 1], offset: ?>>
+  // CHECK-NEXT: }
+  // CHECK-NEXT: memref.dealloc [[valloc_0]] : memref<3x12xi8>
+  // CHECK-NEXT: [[res:%.*]] = bufferization.to_tensor [[mref]] : memref<12x12xi8>
+  %res = mesh.update_halo %arg0 on @mesh0 split_axes = [[0]]
+    halo_sizes = [2, 3] : tensor<12x12xi8>
+  // CHECK-NEXT: return [[res]]
+  return %res : tensor<12x12xi8>
 }



More information about the Mlir-commits mailing list