[Mlir-commits] [mlir] mlir::mesh::shardingOp adding shard-size control (PR #98145)

Frank Schlimbach llvmlistbot at llvm.org
Tue Jul 9 06:46:31 PDT 2024


https://github.com/fschlimb updated https://github.com/llvm/llvm-project/pull/98145

>From 23611e543959dc1b9aa151c91d3ce19776797bdd Mon Sep 17 00:00:00 2001
From: Frank Schlimbach <frank.schlimbach at intel.com>
Date: Mon, 24 Jun 2024 17:25:44 +0200
Subject: [PATCH 1/8] New op mesh.sharding replace attribute. Adding halo_sizes
 and shard_dims_sizes to sharding. First spmdization of halo annotated
 sharding

---
 .../mlir/Dialect/Mesh/IR/CMakeLists.txt       |   4 +
 mlir/include/mlir/Dialect/Mesh/IR/MeshBase.td | 112 ++-----
 mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h   | 135 +++++++-
 mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td  | 286 +++++++++++++----
 .../Mesh/Interfaces/ShardingInterface.h       |  22 +-
 .../Mesh/Interfaces/ShardingInterface.td      |  12 +-
 .../Mesh/Interfaces/ShardingInterfaceImpl.h   |  16 +-
 .../Dialect/Tensor/IR/ShardingInterfaceImpl.h |  23 ++
 mlir/include/mlir/InitAllDialects.h           |   2 +
 .../mlir/Interfaces/InferTypeOpInterface.h    |   2 +-
 .../Transforms/MeshShardingInterfaceImpl.cpp  |  28 +-
 mlir/lib/Dialect/Mesh/IR/MeshOps.cpp          | 293 ++++++++++++++----
 .../Mesh/Interfaces/ShardingInterface.cpp     |  93 +++---
 .../Mesh/Transforms/ShardingPropagation.cpp   |  60 ++--
 .../Dialect/Mesh/Transforms/Spmdization.cpp   | 231 +++++++++-----
 mlir/lib/Dialect/Tensor/IR/CMakeLists.txt     |   1 +
 .../Tensor/IR/ShardingInterfaceImpl.cpp       | 101 ++++++
 .../test/Dialect/Linalg/mesh-spmdization.mlir |   2 +-
 mlir/test/Dialect/Mesh/ops.mlir               |  77 +++--
 mlir/test/Dialect/Mesh/spmdization.mlir       |  84 ++++-
 .../Mesh/TestReshardingSpmdization.cpp        |  11 +-
 21 files changed, 1148 insertions(+), 447 deletions(-)
 create mode 100644 mlir/include/mlir/Dialect/Tensor/IR/ShardingInterfaceImpl.h
 create mode 100644 mlir/lib/Dialect/Tensor/IR/ShardingInterfaceImpl.cpp

diff --git a/mlir/include/mlir/Dialect/Mesh/IR/CMakeLists.txt b/mlir/include/mlir/Dialect/Mesh/IR/CMakeLists.txt
index 7ba966d8cab7c..f26c6285efd89 100644
--- a/mlir/include/mlir/Dialect/Mesh/IR/CMakeLists.txt
+++ b/mlir/include/mlir/Dialect/Mesh/IR/CMakeLists.txt
@@ -13,6 +13,10 @@ set(LLVM_TARGET_DEFINITIONS MeshBase.td)
 mlir_tablegen(MeshEnums.h.inc -gen-enum-decls)
 mlir_tablegen(MeshEnums.cpp.inc -gen-enum-defs)
 
+set(LLVM_TARGET_DEFINITIONS MeshBase.td)
+mlir_tablegen(MeshTypes.h.inc -gen-typedef-decls)
+mlir_tablegen(MeshTypes.cpp.inc -gen-typedef-defs)
+
 set(LLVM_TARGET_DEFINITIONS MeshOps.td)
 mlir_tablegen(MeshOps.h.inc -gen-op-decls)
 mlir_tablegen(MeshOps.cpp.inc -gen-op-defs)
diff --git a/mlir/include/mlir/Dialect/Mesh/IR/MeshBase.td b/mlir/include/mlir/Dialect/Mesh/IR/MeshBase.td
index 3a85bf2d552f3..61403ac178980 100644
--- a/mlir/include/mlir/Dialect/Mesh/IR/MeshBase.td
+++ b/mlir/include/mlir/Dialect/Mesh/IR/MeshBase.td
@@ -12,6 +12,7 @@
 include "mlir/IR/OpBase.td"
 include "mlir/IR/AttrTypeBase.td"
 include "mlir/IR/BuiltinTypeInterfaces.td"
+include "mlir/IR/CommonAttrConstraints.td"
 include "mlir/IR/EnumAttr.td"
 
 //===----------------------------------------------------------------------===//
@@ -31,11 +32,13 @@ def Mesh_Dialect : Dialect {
   ];
 
   let useDefaultAttributePrinterParser = 1;
+  let useDefaultTypePrinterParser = 1;
   let hasConstantMaterializer = 1;
 }
 
 def Mesh_MeshAxis : I<16>;
 def Mesh_MeshAxesAttr : DenseArrayAttrBase<"DenseI16ArrayAttr", "int16_t", "i16">;
+def Mesh_ShardShapeAttr : DenseArrayAttrBase<"DenseI64ArrayAttr", "int64_t", "i64">;
 
 //===----------------------------------------------------------------------===//
 // Mesh Enums.
@@ -59,104 +62,33 @@ def Mesh_ReductionKind : I32EnumAttr<"ReductionKind",
 }
 
 def Mesh_ReductionKindAttr : EnumAttr<Mesh_Dialect, Mesh_ReductionKind, "partial"> {
-  let assemblyFormat = "`<` $value `>`";
+  let assemblyFormat = "$value";
+}
+
+class Mesh_Type<string name, string typeMnemonic, list<Trait> traits = [],
+                   string baseCppClass = "::mlir::Type">
+    : TypeDef<Mesh_Dialect, name, traits, baseCppClass> {
+  let mnemonic = typeMnemonic;
+}
+
+def Mesh_Sharding : Mesh_Type<"Sharding", "sharding"> {
+  let summary = "sharding definition";
+  let assemblyFormat = "";
 }
 
 //===----------------------------------------------------------------------===//
 // Mesh Attribute
 //===----------------------------------------------------------------------===//
 
-def MeshSharding : AttrDef<Mesh_Dialect, "MeshSharding"> {
-  let mnemonic = "shard";
-
-  let parameters = (ins
-    AttrParameter<"::mlir::FlatSymbolRefAttr",
-     "The mesh on which tensors are sharded.">:$mesh,
-    ArrayRefParameter<"MeshAxesAttr">:$split_axes,
-    OptionalArrayRefParameter<"MeshAxis">:$partial_axes,
-    OptionalParameter<"::mlir::mesh::ReductionKind">:$partial_type
-  );
-
-  let summary = "Attribute that extends tensor type to distributed tensor type.";
-
-  let description = [{
-    The MeshSharding attribute is used in a `mesh.shard` operation.
-    It specifies how a tensor is sharded and distributed across the process
-    mesh.
-
-    1. `mesh`: this attribute is a FlatSymbolRefAttr that refers to the device
-    mesh where the distributed tensor is placed. The symbol must resolve to a
-    `mesh.mesh` operation.
-
-    2. `split_axes`: is an array composed of int64_t sub-arrays. The outer array's
-    maximum size is the `rank` of the related tensor. For the i-th sub-array, if
-    its value is [x, y], it indicates that the tensor's i-th dimension is splitted
-    along the x and y axes of the device mesh.
-
-    3. `partial_axes`: if not empty, this signifies that the tensor is partial
-    one along the specified mesh axes. An all-reduce should be applied to obtain
-    the complete tensor, with reduction type being specified by `partial_type`.
-
-    4. `partial_type`: indicates the reduction type of the possible all-reduce
-    op. It has 4 possible values:
-    `generic`: is not an allowed value inside a shard attribute.
-
-    Example:
-
-    ```
-    mesh.mesh @mesh0(shape = 2x2x4)
-
-    // The tensor is fully replicated on @mesh0.
-    // Currently, there must be at least one sub-array present in axes, even
-    // if it's empty. Otherwise, a parsing error will occur.
-    #mesh.shard<@mesh0, [[]]>
-
-    // The tensor is sharded on the first dimension along axis 0 of @mesh0
-    #mesh.shard<@mesh0, [[0]]>
-
-    // The tensor is sharded on the first dimension along axis 0 of @mesh0 and
-    // it is also a partial_sum along mesh axis 1.
-    #mesh.shard<@mesh0, [[0], []], partial = sum[1]>
-
-    // The tensor is sharded on the first dimension along axis 0 of @mesh0 and
-    // it is also a partial_max along mesh axis 1.
-    #mesh.shard<@mesh0, [[0]], partial = max[1]>
-
-    // Could be used in the attribute of mesh.shard op
-    %0 = mesh.shard %arg0 to <@mesh0, [[0]]> : tensor<4x8xf32>
-    ```
-  }];
-  let assemblyFormat = [{
-    `<` $mesh `,` `[` $split_axes `]` (`,` `partial` `=` $partial_type `[`
-       $partial_axes^ `]`)? `>`
-  }];
-
-  let builders = [
-    AttrBuilder<(ins "FlatSymbolRefAttr":$mesh,
-                     "ArrayRef<SmallVector<MeshAxis>>":$split_axes,
-                     "ArrayRef<MeshAxis>": $partial_axes,
-                     "mesh::ReductionKind": $partial_type), [{
-      SmallVector<MeshAxesAttr> splitAxesAttr = llvm::map_to_vector(
-                  split_axes, [&](ArrayRef<MeshAxis> array) {
-          return MeshAxesAttr::get($_ctxt, array);
-      });
-      return $_get($_ctxt, mesh, splitAxesAttr, partial_axes,
-                   partial_type);
-    }]>,
-    AttrBuilder<(ins "FlatSymbolRefAttr":$mesh,
-                     "ArrayRef<SmallVector<MeshAxis>>":$split_axes), [{
-      return MeshShardingAttr::get($_ctxt, mesh, split_axes, {}, ReductionKind::Sum);
-    }]>
-  ];
-
+def Mesh_MeshAxesArrayAttr : AttrDef<Mesh_Dialect, "MeshAxesArray"> {
+  let mnemonic = "axisarray";
+  let parameters = (ins ArrayRefParameter<"MeshAxesAttr">:$axes);
+  let assemblyFormat = "`[` $axes `]`";
   let extraClassDeclaration = [{
-    bool operator==(::mlir::Attribute rhs) const;
-    bool operator!=(::mlir::Attribute rhs) const;
-    bool operator==(::mlir::mesh::MeshShardingAttr rhs) const;
-    bool operator!=(::mlir::mesh::MeshShardingAttr rhs) const;
+    size_t size() const { return getAxes().size(); }
+    auto begin() const { return getAxes().begin(); }
+    auto end() const { return getAxes().end(); }
   }];
-
-  let genVerifyDecl = 1;
 }
 
 #endif // MLIR_DIALECT_MESH_IR_MESHBASE_TD
diff --git a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h
index b27c9e81b3293..12677c0bae740 100644
--- a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h
+++ b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h
@@ -24,6 +24,8 @@ namespace mesh {
 
 using MeshAxis = int16_t;
 using MeshAxesAttr = DenseI16ArrayAttr;
+using ShardShapeAttr = DenseI64ArrayAttr;
+using HaloSizePairAttr = DenseI64ArrayAttr;
 
 } // namespace mesh
 } // namespace mlir
@@ -33,6 +35,55 @@ using MeshAxesAttr = DenseI16ArrayAttr;
 #define GET_ATTRDEF_CLASSES
 #include "mlir/Dialect/Mesh/IR/MeshAttributes.h.inc"
 
+namespace mlir {
+namespace mesh {
+// Describes how a tensor is sharded: mesh symbol, split/partial axes and optional static/dynamic halo or sharded-dims sizes.
+class MeshSharding {
+  private:
+    ::mlir::FlatSymbolRefAttr mesh;
+    SmallVector<MeshAxesAttr> split_axes;
+    SmallVector<MeshAxis> partial_axes;
+    ReductionKind partial_type = ReductionKind::Sum; // same default as get(); never left indeterminate
+    SmallVector<int64_t> static_halo_sizes;
+    SmallVector<int64_t> static_sharded_dims_sizes;
+    SmallVector<Value> dynamic_halo_sizes;
+    SmallVector<Value> dynamic_sharded_dims_sizes;
+  public:
+    MeshSharding(Value rhs);
+    static MeshSharding get(
+        ::mlir::FlatSymbolRefAttr mesh_,
+        ArrayRef<MeshAxesAttr> split_axes_,
+        ArrayRef<MeshAxis> partial_axes_ = {},
+        ReductionKind partial_type_ = ReductionKind::Sum,
+        ArrayRef<int64_t> static_halo_sizes_ = {},
+        ArrayRef<int64_t> static_sharded_dims_sizes_ = {},
+        ArrayRef<Value> dynamic_halo_sizes_ = {},
+        ArrayRef<Value> dynamic_sharded_dims_sizes_ = {});
+    MeshSharding() = default;
+    ::mlir::FlatSymbolRefAttr getMeshAttr() const { return mesh; }
+    ::llvm::StringRef getMesh() const { return mesh.getValue(); }
+    ArrayRef<MeshAxesAttr> getSplitAxes() const {return split_axes; }
+    ArrayRef<MeshAxis> getPartialAxes() const { return partial_axes; }
+    ReductionKind getPartialType() const { return partial_type; }
+    ArrayRef<int64_t> getStaticHaloSizes() const { return static_halo_sizes; }
+    ArrayRef<int64_t> getStaticShardedDimsSizes() const { return static_sharded_dims_sizes; }
+    ArrayRef<Value> getDynamicHaloSizes() const { return dynamic_halo_sizes; }
+    ArrayRef<Value> getDynamicShardedDimsSizes() const { return dynamic_sharded_dims_sizes; }
+    operator bool() const { return (!mesh) == false; } // true iff a mesh symbol is set
+    bool operator==(Value rhs) const;
+    bool operator!=(Value rhs) const;
+    bool operator==(MeshSharding rhs) const;
+    bool operator!=(MeshSharding rhs) const;
+    template<typename RHS> bool sameExceptConstraint(RHS rhs) const;
+    template<typename RHS> bool sameConstraint(RHS rhs) const;
+};
+
+} // namespace mesh
+} // namespace mlir
+
+#define GET_TYPEDEF_CLASSES
+#include "mlir/Dialect/Mesh/IR/MeshTypes.h.inc"
+
 #define GET_OP_CLASSES
 #include "mlir/Dialect/Mesh/IR/MeshOps.h.inc"
 
@@ -50,9 +101,9 @@ void removeTrailingEmptySubArray(SmallVector<SmallVector<T>> &array) {
 }
 
 // Is the same tensor replicated on all processes.
-inline bool isFullReplication(MeshShardingAttr attr) {
-  return attr.getPartialAxes().empty() &&
-         llvm::all_of(attr.getSplitAxes(), [](MeshAxesAttr axes) {
+inline bool isFullReplication(MeshSharding sharding) {
+  return sharding.getPartialAxes().empty() &&
+         llvm::all_of(sharding.getSplitAxes(), [](MeshAxesAttr axes) {
            return axes.asArrayRef().empty();
          });
 }
@@ -80,7 +131,8 @@ mesh::MeshOp getMesh(Op op, SymbolTableCollection &symbolTableCollection) {
 template <>
 inline mesh::MeshOp
 getMesh<ShardOp>(ShardOp op, SymbolTableCollection &symbolTableCollection) {
-  return getMesh(op.getOperation(), op.getShardAttr().getMesh(),
+  return getMesh(op.getOperation(),
+                 cast<ShardingOp>(op.getSharding().getDefiningOp()).getMeshAttr(),
                  symbolTableCollection);
 }
 
@@ -131,25 +183,90 @@ inline int64_t gatherDimension(int64_t dimSize, int64_t shardCount) {
 // On a 2x4x? mesh with split axes = [[0], [1], [2]] the shape ?x5x1 would
 // result in a shape for each shard of ?x2x?.
 ShapedType shardShapedType(ShapedType shape, MeshOp mesh,
-                           MeshShardingAttr sharding);
+                           MeshSharding sharding);
 
 // If ranked tensor type return its sharded counterpart.
 //
 // If not ranked tensor type return `type`.
 // `sharding` in that case must be null.
-Type shardType(Type type, MeshOp mesh, MeshShardingAttr sharding);
+Type shardType(Type type, MeshOp mesh, MeshSharding sharding);
 
 // Insert shard op if there is not one that already has the same sharding.
 // May insert resharding if required.
-void maybeInsertTargetShardingAnnotation(MeshShardingAttr sharding,
+void maybeInsertTargetShardingAnnotation(Value sharding,
                                          OpOperand &operand,
                                          OpBuilder &builder);
-void maybeInsertTargetShardingAnnotation(MeshShardingAttr sharding,
+void maybeInsertTargetShardingAnnotation(Value sharding,
                                          OpResult result, OpBuilder &builder);
-void maybeInsertSourceShardingAnnotation(MeshShardingAttr sharding,
+void maybeInsertSourceShardingAnnotation(Value sharding,
                                          OpOperand &operand,
                                          OpBuilder &builder);
 
+// True if mesh, partial axes (and partial type, if any) and split axes match `rhs`; trailing empty split axes are ignored.
+template<typename RHS>
+bool MeshSharding::sameExceptConstraint(RHS rhs) const {
+  if (getMesh() != rhs.getMesh() || getPartialAxes() != rhs.getPartialAxes()) {
+    return false;
+  }
+
+  if (!getPartialAxes().empty() && getPartialType() != rhs.getPartialType()) {
+    return false;
+  }
+
+  auto minSize = std::min(getSplitAxes().size(), rhs.getSplitAxes().size());
+  if (!llvm::equal(llvm::make_range(getSplitAxes().begin(),
+                                    getSplitAxes().begin() + minSize),
+                   llvm::make_range(rhs.getSplitAxes().begin(),
+                                    rhs.getSplitAxes().begin() + minSize))) {
+    return false;
+  }
+
+  return llvm::all_of(llvm::make_range(getSplitAxes().begin() + minSize,
+                                       getSplitAxes().end()),
+                      std::mem_fn(&MeshAxesAttr::empty)) &&
+         llvm::all_of(llvm::make_range(rhs.getSplitAxes().begin() + minSize,
+                                       rhs.getSplitAxes().end()),
+                      std::mem_fn(&MeshAxesAttr::empty));
+}
+
+// Returns true if `rhs` carries the same size constraints as this sharding:
+// the static/dynamic halo sizes and the static/dynamic sharded dims sizes must
+// each match their counterpart of the same kind, both in length and
+// element-wise. `RHS` is any type providing the four corresponding getters.
+template<typename RHS>
+bool MeshSharding::sameConstraint(RHS rhs) const {
+    // Static halo sizes.
+    if (rhs.getStaticHaloSizes().size() != getStaticHaloSizes().size() ||
+        !llvm::equal(llvm::make_range(getStaticHaloSizes().begin(), getStaticHaloSizes().end()),
+                     llvm::make_range(rhs.getStaticHaloSizes().begin(), rhs.getStaticHaloSizes().end()))) {
+      return false;
+    }
+
+    // Static sharded dims sizes (checked against rhs' static sharded dims
+    // sizes, not against a halo list).
+    if (rhs.getStaticShardedDimsSizes().size() != getStaticShardedDimsSizes().size() ||
+        !llvm::equal(llvm::make_range(getStaticShardedDimsSizes().begin(), getStaticShardedDimsSizes().end()),
+                     llvm::make_range(rhs.getStaticShardedDimsSizes().begin(), rhs.getStaticShardedDimsSizes().end()))) {
+      return false;
+    }
+
+    // Dynamic halo sizes.
+    if (rhs.getDynamicHaloSizes().size() != getDynamicHaloSizes().size() ||
+        !llvm::equal(llvm::make_range(getDynamicHaloSizes().begin(), getDynamicHaloSizes().end()),
+                     llvm::make_range(rhs.getDynamicHaloSizes().begin(), rhs.getDynamicHaloSizes().end()))) {
+      return false;
+    }
+
+    // Dynamic sharded dims sizes.
+    if (rhs.getDynamicShardedDimsSizes().size() != getDynamicShardedDimsSizes().size() ||
+        !llvm::equal(llvm::make_range(getDynamicShardedDimsSizes().begin(), getDynamicShardedDimsSizes().end()),
+                     llvm::make_range(rhs.getDynamicShardedDimsSizes().begin(), rhs.getDynamicShardedDimsSizes().end()))) {
+      return false;
+    }
+    return true;
+}
+
+
 } // namespace mesh
 } // namespace mlir
 
diff --git a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td
index 8e1e475463585..b4de29f8e3214 100644
--- a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td
+++ b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td
@@ -20,7 +20,7 @@ include "mlir/IR/OpAsmInterface.td"
 include "mlir/IR/SymbolInterfaces.td"
 
 //===----------------------------------------------------------------------===//
-// Mesh Dialect operations.
+// Mesh operations.
 //===----------------------------------------------------------------------===//
 
 class Mesh_Op<string mnemonic, list<Trait> traits = []> :
@@ -105,9 +105,200 @@ def Mesh_MeshShapeOp : Mesh_Op<"mesh_shape", [
   ];
 }
 
+def Mesh_ProcessMultiIndexOp : Mesh_Op<"process_multi_index", [
+  Pure,
+  DeclareOpInterfaceMethods<SymbolUserOpInterface>,
+  DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>
+]> {
+  let summary = "Get the multi index of current device along specified mesh axes.";
+  let description = [{
+    It is used in the SPMD format of IR.
+    The `axes` must be non-negative and less than the total number of mesh axes.
+    If the axes are empty then get the index along all axes.
+  }];
+  let arguments = (ins
+    FlatSymbolRefAttr:$mesh,
+    DefaultValuedAttr<Mesh_MeshAxesAttr, "{}">:$axes
+  );
+  let results = (outs
+    Variadic<Index>:$result
+  );
+  let assemblyFormat = [{
+    `on` $mesh (`axes` `=` $axes^)?
+    attr-dict `:` type($result)
+  }];
+  let builders = [
+    OpBuilder<(ins "::mlir::mesh::MeshOp":$mesh)>,
+    OpBuilder<(ins "StringRef":$mesh, "ArrayRef<MeshAxis>":$axes)>
+  ];
+}
+
+def Mesh_ProcessLinearIndexOp : Mesh_Op<"process_linear_index", [
+  Pure,
+  DeclareOpInterfaceMethods<SymbolUserOpInterface>,
+  DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>
+]> {
+  let summary = "Get the linear index of the current device.";
+  let description = [{
+    Example:
+    ```
+    %idx = mesh.process_linear_index on @mesh : index
+    ```
+    if `@mesh` has shape `(10, 20, 30)`, a device with multi
+    index `(1, 2, 3)` will have linear index `3 + 30*2 + 20*30*1`.
+  }];
+  let arguments = (ins FlatSymbolRefAttr:$mesh);
+  let results = (outs Index:$result);
+  let assemblyFormat = "`on` $mesh attr-dict `:` type($result)";
+  let builders = [
+    OpBuilder<(ins "::mlir::mesh::MeshOp":$mesh)>
+  ];
+}
+
+//===----------------------------------------------------------------------===//
+// Sharding operations.
+//===----------------------------------------------------------------------===//
+
+def Mesh_ShardingOp : Mesh_Op<"sharding", [
+    Pure,
+    AttrSizedOperandSegments,
+    DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>
+  ]> {
+  let summary = "Define a sharding of a tensor.";
+  let description = [{
+    The MeshSharding is used in a `mesh.shard` operation.
+    It specifies how a tensor is sharded and distributed across the process
+    mesh.
+
+    1. `mesh`: this attribute is a FlatSymbolRefAttr that refers to the device
+    mesh where the distributed tensor is placed. The symbol must resolve to a
+    `mesh.mesh` operation.
+
+    2. `split_axes`: is an array composed of int64_t sub-arrays. The outer array's
+    maximum size is the `rank` of the related tensor. For the i-th sub-array, if
+    its value is [x, y], it indicates that the tensor's i-th dimension is split
+    along the x and y axes of the device mesh.
+
+    3. `partial_axes`: if not empty, this signifies that the tensor is partial
+    one along the specified mesh axes. An all-reduce should be applied to obtain
+    the complete tensor, with reduction type being specified by `partial_type`.
+
+    4. `partial_type`: indicates the reduction type of the possible all-reduce
+    op. It has 4 possible values:
+    `generic`: is not an allowed value inside a shard attribute.
+
+    5. [Optional] Sizes of halos to be added for each sharded tensor dimension.
+    `halo_sizes` is provided as a flattened 1d array of i64s, 2 values for each sharded dimension.
+    `halo_sizes` = [1, 2] means that the first sharded dimension gets an additional
+    halo of size 1 at the start of the dimension and a halo of size 2 at the end.
+    `halo_sizes` = [1, 2, 2, 3] defines halos for the first 2 sharded dimensions,
+    i.e. the first sharded dimension gets [1,2] halos and the second gets [2,3] halos.
+    `?` indicates dynamic halo sizes.
+    
+    6. [Optional] Sizes of sharded dimensions of each shard.
+    `sharded_dims_sizes` is provided as a flattened 1d array of i64s: for each device of the
+    device-mesh one value for each sharded tensor dimension.
+    Assuming a 3d-tensor of shape 32x32x32 with the first 2 dimensions being sharded,
+    `sharded_dims_sizes` = [16, 8, 16, 24] means that the first device of
+    the device-mesh will get a shard of shape 16x8x32 and the second device will get a
+    shard of shape 16x24x32.
+    `?` indicates dynamic shard dimensions.
+    
+    
+    `halo_sizes` and `sharded_dims_sizes` are mutually exclusive.
+
+    Example:
+
+    ```
+    mesh.mesh @mesh0(shape = 2x2x4)
+
+    // The tensor is fully replicated on @mesh0.
+    // Currently, there must be at least one sub-array present in axes, even
+    // if it's empty. Otherwise, a parsing error will occur.
+    #mesh.shard<@mesh0, [[]]>
+
+    // The tensor is sharded on the first dimension along axis 0 of @mesh0
+    #mesh.shard<@mesh0, [[0]]>
+
+    // The tensor is sharded on the first dimension along axis 0 of @mesh0 and
+    // it is also a partial_sum along mesh axis 1.
+    #mesh.shard<@mesh0, [[0], []], partial = sum[1]>
+
+    // The tensor is sharded on the first dimension along axis 0 of @mesh0 and
+    // it is also a partial_max along mesh axis 1.
+    #mesh.shard<@mesh0, [[0]], partial = max[1]>
+
+    // Could be used in the attribute of mesh.shard op
+    %0 = mesh.shard %arg0 to <@mesh0, [[0]]> : tensor<4x8xf32>
+
+    // The tensor is sharded on the first dimension along axis 0 of @mesh0 and
+    // and it has halo-sizes of 1 and 2 on the sharded dim.
+    %0 = mesh.shard %arg0 to <@mesh0, [[0]] {<halo_sizes = [1, 2]>}> : tensor<4x8xf32>
+    ```
+  }];
+    
+  let arguments = (ins
+    FlatSymbolRefAttr:$mesh,
+    Mesh_MeshAxesArrayAttr:$split_axes,
+    OptionalAttr<Mesh_MeshAxesAttr>:$partial_axes,
+    OptionalAttr<Mesh_ReductionKindAttr>:$partial_type,
+    DefaultValuedAttr<DenseI64ArrayAttr, "{}">:$static_sharded_dims_sizes,
+    Variadic<I64>:$dynamic_sharded_dims_sizes,
+    DefaultValuedAttr<DenseI64ArrayAttr, "{}">:$static_halo_sizes,
+    Variadic<I64>:$dynamic_halo_sizes
+  );
+  let results = (outs
+    Mesh_Sharding:$result
+  );
+  let assemblyFormat = [{
+    $mesh `,` $split_axes
+    (`partial` `=` $partial_type $partial_axes^)?
+    oilist(`halo_sizes` `=` custom<DynamicIndexList>($dynamic_halo_sizes, $static_halo_sizes) |
+           `sharded_dims_sizes` `=` custom<DynamicIndexList>($dynamic_sharded_dims_sizes, $static_sharded_dims_sizes))
+    attr-dict `:` type($result)
+  }];
+  let builders = [
+    OpBuilder<(ins "FlatSymbolRefAttr":$mesh,
+                   "ArrayRef<MeshAxesAttr>":$split_axes,
+                   "ArrayRef<MeshAxis>":$partial_axes,
+                   "mesh::ReductionKind":$partial_type,
+                   CArg<"ArrayRef<int64_t>", "{}">:$static_halo_sizes,
+                   CArg<"ArrayRef<int64_t>", "{}">:$static_sharded_dims_sizes)>,
+    OpBuilder<(ins "FlatSymbolRefAttr":$mesh,
+                   "ArrayRef<MeshAxesAttr>":$split_axes)>,
+    OpBuilder<(ins "FlatSymbolRefAttr":$mesh,
+                   "ArrayRef<MeshAxesAttr>":$split_axes,
+                   "::mlir::ArrayRef<::mlir::OpFoldResult>":$halo_sizes,
+                   "::mlir::ArrayRef<::mlir::OpFoldResult>":$sharded_dims_sizes)>,
+    OpBuilder<(ins "mlir::mesh::MeshSharding":$from)>
+  ];
+  let hasVerifier = 1;
+}
+
+def Mesh_ShardShapeOp : Mesh_Op<"shard_shape", [Pure]> {
+  let summary = "Get the shard shape of a given process/device.";
+  let description = [{
+    The device/process id is a linearized id of the device/process in the mesh.
+    This operation might be used during spmdization when the shard shape depends
+    on values of the sharding.
+  }];
+  let arguments = (ins
+    DenseI64ArrayAttr:$shape,
+    Mesh_Sharding:$sharding,
+    Index:$device
+  );
+  let results = (outs Variadic<Index>:$result);
+  let assemblyFormat = [{
+      $shape $sharding $device attr-dict `:` type($result)
+  }];
+  let builders = [
+    OpBuilder<(ins "ArrayRef<int64_t>":$shape, "Value":$sharding, "Value":$device)>
+  ];
+}
+
 def Mesh_ShardOp : Mesh_Op<"shard", [
     Pure,
-    SameOperandsAndResultType,
+    AllTypesMatch<["result", "src"]>,
     DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>
   ]> {
   let summary = "Annotate on how a tensor is sharded across a mesh.";
@@ -119,7 +310,7 @@ def Mesh_ShardOp : Mesh_Op<"shard", [
     1. `input`: This operand represents the tensor value that needs to be
     annotated for sharding.
 
-    2. `shard`: This attribute is type of `MeshSharding`, which is the core data
+    2. `sharding`: This operand is of type `!mesh.sharding`, which is the core data
     structure to represent distribution of a tensor on a mesh.
 
     3. `annotate_for_users`: A unit attribute addressing the scenario when a
@@ -129,6 +320,10 @@ def Mesh_ShardOp : Mesh_Op<"shard", [
     as an operand in subsequent operations. If not, the sharding applies to the
     operation that defines the tensor value.
 
+    4. `force`: A unit attribute requesting an explicit sharding of the data, so
+    that the annotation is not allowed to be optimized away. This is useful in
+    the presence of halos and inplace semantics.
+
     Example:
     ```
     func.func @only_result_annotated(%arg0 : tensor<4x8xf32>) -> () {
@@ -188,68 +383,21 @@ def Mesh_ShardOp : Mesh_Op<"shard", [
   }];
   let arguments = (ins
     AnyRankedTensor:$src,
-    MeshSharding:$shard,
-    UnitAttr:$annotate_for_users
+    Mesh_Sharding:$sharding,
+    UnitAttr:$annotate_for_users,
+    UnitAttr:$force
   );
   let results = (outs
     AnyRankedTensor:$result
   );
   let assemblyFormat = [{
-    $src `to` $shard (`annotate_for_users` $annotate_for_users^)? attr-dict `:`
-      type($result)
+    $src `to` $sharding
+      (`annotate_for_users` $annotate_for_users^)?
+      (`force` $force^)?
+      attr-dict `:` type($result)
   }];
 }
 
-def Mesh_ProcessMultiIndexOp : Mesh_Op<"process_multi_index", [
-  Pure,
-  DeclareOpInterfaceMethods<SymbolUserOpInterface>,
-  DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>
-]> {
-  let summary = "Get the multi index of current device along specified mesh axes.";
-  let description = [{
-    It is used in the SPMD format of IR.
-    The `axes` mush be non-negative and less than the total number of mesh axes.
-    If the axes are empty then get the index along all axes.
-  }];
-  let arguments = (ins
-    FlatSymbolRefAttr:$mesh,
-    DefaultValuedAttr<Mesh_MeshAxesAttr, "{}">:$axes
-  );
-  let results = (outs
-    Variadic<Index>:$result
-  );
-  let assemblyFormat = [{
-    `on` $mesh (`axes` `=` $axes^)?
-    attr-dict `:` type($result)
-  }];
-  let builders = [
-    OpBuilder<(ins "::mlir::mesh::MeshOp":$mesh)>,
-    OpBuilder<(ins "StringRef":$mesh, "ArrayRef<MeshAxis>":$axes)>
-  ];
-}
-
-def Mesh_ProcessLinearIndexOp : Mesh_Op<"process_linear_index", [
-  Pure,
-  DeclareOpInterfaceMethods<SymbolUserOpInterface>,
-  DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>
-]> {
-  let summary = "Get the linear index of the current device.";
-  let description = [{
-    Example:
-    ```
-    %idx = mesh.process_linear_index on @mesh : index
-    ```
-    if `@mesh` has shape `(10, 20, 30)`, a device with multi
-    index `(1, 2, 3)` will have linear index `3 + 30*2 + 20*30*1`.
-  }];
-  let arguments = (ins FlatSymbolRefAttr:$mesh);
-  let results = (outs Index:$result);
-  let assemblyFormat = "`on` $mesh attr-dict `:` type($result)";
-  let builders = [
-    OpBuilder<(ins "::mlir::mesh::MeshOp":$mesh)>
-  ];
-}
-
 //===----------------------------------------------------------------------===//
 // collective communication ops
 //===----------------------------------------------------------------------===//
@@ -879,4 +1027,28 @@ def Mesh_ShiftOp : Mesh_CollectiveCommunicationOpBase<"shift", [
   let hasCanonicalizer = 1;
 }
 
+def Mesh_UpdateHaloOp : Mesh_CollectiveCommunicationOpBase<"update_halo", [
+    AllShapesMatch<["input", "result"]>,
+    AllElementTypesMatch<["input", "result"]>
+  ]> {
+  let summary = "Send over a device mesh.";
+  let description = [{
+    Send from one device to another within a device group.
+  }];
+  let arguments = !con(commonArgs, (ins
+    AnyNon0RankedTensor:$input,
+    DenseI64ArrayAttr:$dynamic_halo_sizes,
+    OptionalAttr<DenseI64ArrayAttr>:$target_halo_sizes
+  ));
+  let results = (outs
+    AnyRankedTensor:$result
+  );
+  let assemblyFormat = [{
+    $input `on` $mesh (`mesh_axes` `=` $mesh_axes^)?
+    `halo_sizes` `=` $dynamic_halo_sizes
+    (`target_halo_sizes` `=` $target_halo_sizes^)?
+    attr-dict `:` functional-type(operands, results)
+  }];
+  // let hasCanonicalizer = 1;
+}
 #endif // MLIR_DIALECT_MESH_IR_MESHOPS_TD
diff --git a/mlir/include/mlir/Dialect/Mesh/Interfaces/ShardingInterface.h b/mlir/include/mlir/Dialect/Mesh/Interfaces/ShardingInterface.h
index 216d7e10296df..dd662247dc639 100644
--- a/mlir/include/mlir/Dialect/Mesh/Interfaces/ShardingInterface.h
+++ b/mlir/include/mlir/Dialect/Mesh/Interfaces/ShardingInterface.h
@@ -44,24 +44,24 @@ struct ShardingOption {
   }
 };
 
-// This method retrieves the 'MeshShardingAttr' attribute from a given operation
+// This method retrieves the 'MeshSharding' from a given operation
 // result and includes the 'annotate_for_users' information.
-FailureOr<std::pair<bool, MeshShardingAttr>>
-getMeshShardingAttr(OpResult result);
+FailureOr<std::pair<bool, MeshSharding>>
+getMeshSharding(OpResult result);
 
-// This method retrieves the 'MeshShardingAttr' attribute from a given operation
+// This method retrieves the 'MeshSharding' from a given operation
 // operand and includes the 'annotate_for_users' information.
-FailureOr<std::pair<bool, MeshShardingAttr>>
-getMeshShardingAttr(OpOperand &opOperand);
+FailureOr<std::pair<bool, MeshSharding>>
+getMeshSharding(OpOperand &opOperand);
 
 namespace detail {
 
 FailureOr<ShardingOption>
 defaultGetShardingOption(Operation *op,
-                         ArrayRef<MeshShardingAttr> operandShardings,
-                         ArrayRef<MeshShardingAttr> resultShardings);
+                         ArrayRef<MeshSharding> operandShardings,
+                         ArrayRef<MeshSharding> resultShardings);
 
-FailureOr<SmallVector<MeshShardingAttr>>
+FailureOr<std::vector<MeshSharding>>
 defaultGetShardingAnnotations(Operation *op,
                               const ShardingOption &shardingOption);
 
@@ -74,8 +74,8 @@ defaultAddShardingAnnotations(Operation *op, OpBuilder &b,
 // Assumes full replication on all ranked tensor arguments and results.
 void spmdizeFullyReplicatedOperation(
     Operation &op, ArrayRef<Value> spmdizedOperands,
-    ArrayRef<MeshShardingAttr> operandShardings,
-    ArrayRef<MeshShardingAttr> resultShardings, IRMapping &spmdizationMap,
+    ArrayRef<MeshSharding> operandShardings,
+    ArrayRef<MeshSharding> resultShardings, IRMapping &spmdizationMap,
     SymbolTableCollection &symbolTable, OpBuilder &builder);
 
 } // namespace mesh
diff --git a/mlir/include/mlir/Dialect/Mesh/Interfaces/ShardingInterface.td b/mlir/include/mlir/Dialect/Mesh/Interfaces/ShardingInterface.td
index 47a74f619f56c..a70d2c3e03851 100644
--- a/mlir/include/mlir/Dialect/Mesh/Interfaces/ShardingInterface.td
+++ b/mlir/include/mlir/Dialect/Mesh/Interfaces/ShardingInterface.td
@@ -84,8 +84,8 @@ def ShardingInterface : OpInterface<"ShardingInterface"> {
         /*retTy=*/"FailureOr<ShardingOption>",
         /*methodName=*/"getShardingOption",
         /*args=*/(ins
-          "ArrayRef<MeshShardingAttr>": $operandShardings,
-          "ArrayRef<MeshShardingAttr>": $resultShardings
+          "ArrayRef<MeshSharding>": $operandShardings,
+          "ArrayRef<MeshSharding>": $resultShardings
         ),
         /*methodBody=*/"",
         /*defaultImplementation=*/[{
@@ -100,7 +100,7 @@ def ShardingInterface : OpInterface<"ShardingInterface"> {
           This is what shardings the operands and results need to have in order
           to shard the op according to shardingOption.
         }],
-        /*retTy=*/"FailureOr<SmallVector<MeshShardingAttr>>",
+        /*retTy=*/"FailureOr<std::vector<MeshSharding>>",
         /*methodName=*/"getShardingAnnotations",
         /*args=*/(ins
           "const ShardingOption &":$shardingOption
@@ -139,7 +139,7 @@ def ShardingInterface : OpInterface<"ShardingInterface"> {
           annotations from the IR for each argument/result and prepare
           `operandShardings` and `resultShardings`.
           Values that are not ranked tensors do not have sharding annotations.
-          In this case their corresponding MeshShardingAttr is null.
+          In this case their corresponding MeshSharding is null.
 
           For convenience it will also prepare `spmdizedOperands`, although
           they can be retrieved from the `spmdizationMap`.
@@ -161,8 +161,8 @@ def ShardingInterface : OpInterface<"ShardingInterface"> {
         /*methodName=*/"spmdize",
         /*args=*/(ins
           "ArrayRef<Value>": $spmdizedOperands,
-          "ArrayRef<MeshShardingAttr>": $operandShardings,
-          "ArrayRef<MeshShardingAttr>": $resultShardings,
+          "ArrayRef<MeshSharding>": $operandShardings,
+          "ArrayRef<MeshSharding>": $resultShardings,
           "IRMapping&": $spmdizationMap,
           "SymbolTableCollection &": $symbolTableCollection,
           "OpBuilder &":$builder
diff --git a/mlir/include/mlir/Dialect/Mesh/Interfaces/ShardingInterfaceImpl.h b/mlir/include/mlir/Dialect/Mesh/Interfaces/ShardingInterfaceImpl.h
index 5e4b4f3a66af9..a25ba2bf649b0 100644
--- a/mlir/include/mlir/Dialect/Mesh/Interfaces/ShardingInterfaceImpl.h
+++ b/mlir/include/mlir/Dialect/Mesh/Interfaces/ShardingInterfaceImpl.h
@@ -26,8 +26,8 @@ namespace mesh {
 // on the provided shardings for the op's operands and results.
 // Assumes that the indexingMaps are projected permutations.
 ShardingArray getMeshAxisAssignmentForLoopIterators(
-    ArrayRef<MeshShardingAttr> operandShardings,
-    ArrayRef<MeshShardingAttr> resultShardings,
+    ArrayRef<MeshSharding> operandShardings,
+    ArrayRef<MeshSharding> resultShardings,
     ArrayRef<utils::IteratorType> loopIteratorTypes,
     ArrayRef<AffineMap> indexingMaps);
 
@@ -44,8 +44,8 @@ SmallVector<MeshAxis> getReductionMeshAxes(
 // arguments/results sharded.
 void spmdizeTriviallyShardableOperation(
     Operation &op, ArrayRef<Value> spmdizedOperands,
-    ArrayRef<MeshShardingAttr> operandShardings,
-    ArrayRef<MeshShardingAttr> resultShardings, IRMapping &spmdizationMap,
+    ArrayRef<MeshSharding> operandShardings,
+    ArrayRef<MeshSharding> resultShardings, IRMapping &spmdizationMap,
     SymbolTableCollection &symbolTable, OpBuilder &builder);
 
 // All ranked tensor argument and result dimensions have
@@ -72,8 +72,8 @@ struct IndependentParallelIteratorDomainShardingInterface
   }
 
   LogicalResult spmdize(Operation *op, ArrayRef<Value> spmdizedOperands,
-                        ArrayRef<MeshShardingAttr> operandShardings,
-                        ArrayRef<MeshShardingAttr> resultShardings,
+                        ArrayRef<MeshSharding> operandShardings,
+                        ArrayRef<MeshSharding> resultShardings,
                         IRMapping &spmdizationMap,
                         SymbolTableCollection &symbolTable,
                         OpBuilder &builder) const {
@@ -128,8 +128,8 @@ struct ElementwiseShardingInterface
   }
 
   LogicalResult spmdize(Operation *op, ArrayRef<Value> spmdizedOperands,
-                        ArrayRef<MeshShardingAttr> operandShardings,
-                        ArrayRef<MeshShardingAttr> resultShardings,
+                        ArrayRef<MeshSharding> operandShardings,
+                        ArrayRef<MeshSharding> resultShardings,
                         IRMapping &spmdizationMap,
                         SymbolTableCollection &symbolTable,
                         OpBuilder &builder) const {
diff --git a/mlir/include/mlir/Dialect/Tensor/IR/ShardingInterfaceImpl.h b/mlir/include/mlir/Dialect/Tensor/IR/ShardingInterfaceImpl.h
new file mode 100644
index 0000000000000..3e23419eeec07
--- /dev/null
+++ b/mlir/include/mlir/Dialect/Tensor/IR/ShardingInterfaceImpl.h
@@ -0,0 +1,23 @@
+//===- ShardingInterfaceImpl.h - ------------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_TENSOR_IR_SHARDINGINTERFACEIMPL_H_
+#define MLIR_DIALECT_TENSOR_IR_SHARDINGINTERFACEIMPL_H_
+
+namespace mlir {
+
+class DialectRegistry;
+
+namespace tensor {
+// Hook invoked from registerAllDialects() with the registry being populated.
+void registerShardingInterfaceExternalModels(DialectRegistry &registry);
+
+} // namespace tensor
+} // namespace mlir
+
+#endif // MLIR_DIALECT_TENSOR_IR_SHARDINGINTERFACEIMPL_H_
diff --git a/mlir/include/mlir/InitAllDialects.h b/mlir/include/mlir/InitAllDialects.h
index 549c26c72d8a1..03b3a3709e472 100644
--- a/mlir/include/mlir/InitAllDialects.h
+++ b/mlir/include/mlir/InitAllDialects.h
@@ -75,6 +75,7 @@
 #include "mlir/Dialect/Shape/Transforms/BufferizableOpInterfaceImpl.h"
 #include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
 #include "mlir/Dialect/SparseTensor/Transforms/BufferizableOpInterfaceImpl.h"
+#include "mlir/Dialect/Tensor/IR/ShardingInterfaceImpl.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "mlir/Dialect/Tensor/IR/TensorInferTypeOpInterfaceImpl.h"
 #include "mlir/Dialect/Tensor/IR/TensorTilingInterfaceImpl.h"
@@ -179,6 +180,7 @@ inline void registerAllDialects(DialectRegistry &registry) {
   tensor::registerBufferizableOpInterfaceExternalModels(registry);
   tensor::registerFindPayloadReplacementOpInterfaceExternalModels(registry);
   tensor::registerInferTypeOpInterfaceExternalModels(registry);
+  tensor::registerShardingInterfaceExternalModels(registry);
   tensor::registerSubsetOpInterfaceExternalModels(registry);
   tensor::registerTilingInterfaceExternalModels(registry);
   tensor::registerValueBoundsOpInterfaceExternalModels(registry);
diff --git a/mlir/include/mlir/Interfaces/InferTypeOpInterface.h b/mlir/include/mlir/Interfaces/InferTypeOpInterface.h
index 67de05b0cb4ff..237cfea223b66 100644
--- a/mlir/include/mlir/Interfaces/InferTypeOpInterface.h
+++ b/mlir/include/mlir/Interfaces/InferTypeOpInterface.h
@@ -270,7 +270,7 @@ class InferShapedTypeOpAdaptor
 /// shape and elemental types.
 /// Requires: Op implements InferShapedTypeOpInterface and InferTypeOpInterface.
 ///   Less strict is possible (e.g., implements inferReturnTypeComponents and
-///   these always populates all element types and shapes or fails, but this\
+///   these always populates all element types and shapes or fails, but this
 ///   trait is currently only used where the interfaces are, so keep it
 ///   restricted for now).
 template <typename ConcreteType>
diff --git a/mlir/lib/Dialect/Linalg/Transforms/MeshShardingInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/Transforms/MeshShardingInterfaceImpl.cpp
index 36b6088b83cc2..33686d344e828 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/MeshShardingInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/MeshShardingInterfaceImpl.cpp
@@ -43,7 +43,7 @@ namespace mlir::linalg {
 
 using MeshAxis = mesh::MeshAxis;
 using ReductionKind = mesh::ReductionKind;
-using MeshShardingAttr = mesh::MeshShardingAttr;
+using MeshSharding = mesh::MeshSharding;
 using ShardingArray = mesh::ShardingArray;
 using MeshOp = mesh::MeshOp;
 
@@ -103,18 +103,18 @@ static ReductionKind getReductionKindOfLinalgOp(LinalgOp op) {
 }
 
 static MeshOp getMesh(Operation *op,
-                      ArrayRef<MeshShardingAttr> operandShardings,
-                      ArrayRef<MeshShardingAttr> resultShardings,
+                      ArrayRef<MeshSharding> operandShardings,
+                      ArrayRef<MeshSharding> resultShardings,
                       SymbolTableCollection &symbolTable) {
-  for (MeshShardingAttr sharding : operandShardings) {
+  for (MeshSharding sharding : operandShardings) {
     if (sharding) {
-      return mesh::getMesh(op, sharding.getMesh(), symbolTable);
+      return mesh::getMesh(op, sharding.getMeshAttr(), symbolTable);
     }
   }
 
-  for (MeshShardingAttr sharding : resultShardings) {
+  for (MeshSharding sharding : resultShardings) {
     if (sharding) {
-      return mesh::getMesh(op, sharding.getMesh(), symbolTable);
+      return mesh::getMesh(op, sharding.getMeshAttr(), symbolTable);
     }
   }
 
@@ -185,7 +185,7 @@ static SmallVector<Value> createDestinationPassingStyleInitOperands(
 
 static void createAllReduceForResultWithoutPartialSharding(
     Value unshardedLinalgOpResult, ArrayRef<MeshAxis> opReductionMeshAxes,
-    MeshShardingAttr resultSharding, ReductionKind reductionKind,
+    MeshSharding resultSharding, ReductionKind reductionKind,
     IRMapping &spmdizationMap, ImplicitLocOpBuilder &builder) {
   SmallVector<MeshAxis> allReduceMeshAxes;
   llvm::copy_if(opReductionMeshAxes, std::back_inserter(allReduceMeshAxes),
@@ -199,14 +199,14 @@ static void createAllReduceForResultWithoutPartialSharding(
 
   Value spmdizedLinalgOpResult = spmdizationMap.lookup(unshardedLinalgOpResult);
   Value reducedValue = builder.create<mesh::AllReduceOp>(
-      spmdizedLinalgOpResult, resultSharding.getMesh().getValue(),
+      spmdizedLinalgOpResult, resultSharding.getMesh(),
       allReduceMeshAxes, reductionKind);
   spmdizationMap.map(unshardedLinalgOpResult, reducedValue);
 }
 
 static void createAllReduceForResultsWithoutPartialShardings(
     LinalgOp unshardedOp, ArrayRef<MeshAxis> opReductionMeshAxes,
-    ArrayRef<MeshShardingAttr> resultShardings, IRMapping &spmdizationMap,
+    ArrayRef<MeshSharding> resultShardings, IRMapping &spmdizationMap,
     ImplicitLocOpBuilder &builder) {
   ReductionKind reductionKind = getReductionKindOfLinalgOp(unshardedOp);
   for (auto [unshardedLinalgOpResult, resultSharding] :
@@ -219,8 +219,8 @@ static void createAllReduceForResultsWithoutPartialShardings(
 
 static void spmdizeLinalgOpWithShardedReduction(
     LinalgOp op, ArrayRef<Value> spmdizedOperands,
-    ArrayRef<MeshShardingAttr> operandShardings,
-    ArrayRef<MeshShardingAttr> resultShardings,
+    ArrayRef<MeshSharding> operandShardings,
+    ArrayRef<MeshSharding> resultShardings,
     ArrayRef<utils::IteratorType> loopIteratorTypes,
     ArrayRef<SmallVector<MeshAxis>> meshAxisAssignmentForLoopIterators,
     IRMapping &spmdizationMap, SymbolTableCollection &symbolTable,
@@ -293,8 +293,8 @@ struct StructuredOpShardingInterface
   }
 
   LogicalResult spmdize(Operation *op, ArrayRef<Value> spmdizedOperands,
-                        ArrayRef<MeshShardingAttr> operandShardings,
-                        ArrayRef<MeshShardingAttr> resultShardings,
+                        ArrayRef<MeshSharding> operandShardings,
+                        ArrayRef<MeshSharding> resultShardings,
                         IRMapping &spmdizationMap,
                         SymbolTableCollection &symbolTable,
                         OpBuilder &builder) const {
diff --git a/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp b/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
index 75fceee14e123..9c7c79e602903 100644
--- a/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
+++ b/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
@@ -17,6 +17,7 @@
 #include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/Diagnostics.h"
 #include "mlir/IR/DialectImplementation.h"
+#include "mlir/IR/IRMapping.h"
 #include "mlir/IR/Location.h"
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/IR/TypeUtilities.h"
@@ -86,6 +87,10 @@ void MeshDialect::initialize() {
 #define GET_ATTRDEF_LIST
 #include "mlir/Dialect/Mesh/IR/MeshAttributes.cpp.inc"
       >();
+  addTypes<
+#define GET_TYPEDEF_LIST
+#include "mlir/Dialect/Mesh/IR/MeshTypes.cpp.inc"
+      >();
 }
 
 Operation *MeshDialect::materializeConstant(OpBuilder &builder, Attribute value,
@@ -150,36 +155,82 @@ static LogicalResult verifyMeshAxes(Location loc, ArrayRef<MeshAxis> axes,
 template <typename InShape, typename MeshShape, typename SplitAxes,
           typename OutShape>
 static void shardShape(const InShape &inShape, const MeshShape &meshShape,
-                       const SplitAxes &splitAxes, OutShape &outShape) {
+                       const SplitAxes &splitAxes, OutShape &outShape,
+                       ArrayRef<int64_t> shardedDims = {},
+                       ArrayRef<int64_t> haloSizes = {}) {
   std::copy(llvm::adl_begin(inShape), llvm::adl_end(inShape),
             llvm::adl_begin(outShape));
-  for (auto [tensorAxis, innerSplitAxes] : llvm::enumerate(splitAxes)) {
-    outShape[tensorAxis] = shardDimension(
-        inShape[tensorAxis],
-        collectiveProcessGroupSize(innerSplitAxes.asArrayRef(), meshShape));
+
+  if (!shardedDims.empty()) {
+    for (auto [tensorAxis, innerSplitAxes] : llvm::enumerate(splitAxes)) {
+      if (innerSplitAxes.empty()) {
+        for (auto dimSz : shardedDims) {
+          auto inAxis = dimSz % inShape.size();
+          assert(inShape[inAxis] == dimSz || dimSz == ShapedType::kDynamic ||
+                 inShape[inAxis] == ShapedType::kDynamic);
+        }
+      } else {
+        auto sz = shardedDims[tensorAxis];
+        bool same = true;
+        for (size_t i = tensorAxis + inShape.size(); i < shardedDims.size();
+             i += inShape.size()) {
+          if (shardedDims[i] != sz) {
+            same = false;
+            break;
+          }
+        }
+        if (same) {
+          outShape[tensorAxis] = sz;
+        } else {
+          outShape[tensorAxis] = ShapedType::kDynamic;
+        }
+      }
+    }
+  } else {
+    for (auto [tensorAxis, innerSplitAxes] : llvm::enumerate(splitAxes)) {
+      outShape[tensorAxis] = shardDimension(
+          inShape[tensorAxis],
+          collectiveProcessGroupSize(innerSplitAxes.asArrayRef(), meshShape));
+    }
+
+    if (!haloSizes.empty()) {
+      int haloAxis = 0;
+      for (auto [tensorAxis, innerSplitAxes] : llvm::enumerate(splitAxes)) {
+        if (!ShapedType::isDynamic(outShape[tensorAxis]) &&
+            !innerSplitAxes.empty()) {
+          if (haloSizes[haloAxis * 2] >= 0 &&
+              haloSizes[haloAxis * 2 + 1] >= 0) {
+            outShape[tensorAxis] +=
+                haloSizes[haloAxis * 2] + haloSizes[haloAxis * 2 + 1];
+            ++haloAxis;
+          } else {
+            outShape[tensorAxis] = ShapedType::kDynamic;
+          }
+        }
+      }
+    }
   }
 }
 
 ShapedType mesh::shardShapedType(ShapedType shape, MeshOp mesh,
-                                 MeshShardingAttr sharding) {
+                                 MeshSharding sharding) {
   using Dim = std::decay_t<decltype(shape.getDimSize(0))>;
   SmallVector<Dim> resShapeArr(shape.getShape().size());
-  shardShape(shape.getShape(), mesh.getShape(), sharding.getSplitAxes(),
-             resShapeArr);
+  shardShape(
+      shape.getShape(), mesh.getShape(), sharding.getSplitAxes(), resShapeArr,
+      sharding.getStaticShardedDimsSizes(), sharding.getStaticHaloSizes());
   return shape.clone(resShapeArr);
 }
 
-Type mesh::shardType(Type type, MeshOp mesh, MeshShardingAttr sharding) {
+Type mesh::shardType(Type type, MeshOp mesh, MeshSharding sharding) {
   RankedTensorType rankedTensorType = dyn_cast<RankedTensorType>(type);
   if (rankedTensorType) {
     return shardShapedType(rankedTensorType, mesh, sharding);
   }
-
-  assert(!sharding);
   return type;
 }
 
-void mlir::mesh::maybeInsertTargetShardingAnnotation(MeshShardingAttr sharding,
+void mlir::mesh::maybeInsertTargetShardingAnnotation(Value sharding,
                                                      OpOperand &operand,
                                                      OpBuilder &builder) {
   OpBuilder::InsertionGuard insertionGuard(builder);
@@ -187,7 +238,7 @@ void mlir::mesh::maybeInsertTargetShardingAnnotation(MeshShardingAttr sharding,
   Operation *operandOp = operand.getOwner();
   builder.setInsertionPointAfterValue(operandValue);
   ShardOp shardOp = dyn_cast<ShardOp>(operandOp);
-  if (shardOp && shardOp.getShard() == sharding &&
+  if (shardOp && sharding == shardOp.getSharding() &&
       !shardOp.getAnnotateForUsers()) {
     // No need for anything the correct sharding is already set.
     return;
@@ -211,7 +262,7 @@ void mlir::mesh::maybeInsertTargetShardingAnnotation(MeshShardingAttr sharding,
   rewriter.replaceAllUsesExcept(newShardOp, newShardOp2, newShardOp2);
 }
 
-void mlir::mesh::maybeInsertTargetShardingAnnotation(MeshShardingAttr sharding,
+void mlir::mesh::maybeInsertTargetShardingAnnotation(Value sharding,
                                                      OpResult result,
                                                      OpBuilder &builder) {
   for (auto &use : llvm::make_early_inc_range(result.getUses())) {
@@ -219,7 +270,7 @@ void mlir::mesh::maybeInsertTargetShardingAnnotation(MeshShardingAttr sharding,
   }
 }
 
-void mlir::mesh::maybeInsertSourceShardingAnnotation(MeshShardingAttr sharding,
+void mlir::mesh::maybeInsertSourceShardingAnnotation(Value sharding,
                                                      OpOperand &operand,
                                                      OpBuilder &builder) {
   OpBuilder::InsertionGuard insertionGuard(builder);
@@ -229,7 +280,7 @@ void mlir::mesh::maybeInsertSourceShardingAnnotation(MeshShardingAttr sharding,
   bool isBlockArg = !operandSrcOp;
   ShardOp shardOp = dyn_cast_or_null<ShardOp>(operandSrcOp);
 
-  if (shardOp && shardOp.getShard() == sharding &&
+  if (shardOp && sharding == shardOp.getSharding() &&
       shardOp.getAnnotateForUsers()) {
     // No need for anything the correct sharding is already set.
     return;
@@ -254,11 +305,10 @@ void mlir::mesh::maybeInsertSourceShardingAnnotation(MeshShardingAttr sharding,
   auto newPreceedingShardOp =
       builder.create<ShardOp>(operandValue.getLoc(), operandValue, sharding,
                               /*annotate_for_users*/ false);
-  rewriter.replaceUsesWithIf(newShardOp.getOperand(), newPreceedingShardOp,
-                             [&newShardOp](OpOperand &use) {
-                               return use.getOwner() ==
-                                      newShardOp.getOperation();
-                             });
+  rewriter.replaceUsesWithIf(
+      newShardOp.getSrc(), newPreceedingShardOp, [&newShardOp](OpOperand &use) {
+        return use.getOwner() == newShardOp.getOperation();
+      });
 }
 
 //===----------------------------------------------------------------------===//
@@ -331,16 +381,119 @@ void MeshShapeOp::getAsmResultNames(
 }
 
 //===----------------------------------------------------------------------===//
-// mesh.shard attr
+// mesh.sharding
 //===----------------------------------------------------------------------===//
 
-LogicalResult
-MeshShardingAttr::verify(function_ref<InFlightDiagnostic()> emitError,
-                         FlatSymbolRefAttr, ArrayRef<MeshAxesAttr> splitAxes,
-                         ArrayRef<MeshAxis> partialAxes, ReductionKind) {
-  // TODO: At present mesh symbol ref is not verified. This is due to the
-  // difficulty in fetching the corresponding symbol op based on an attribute.
+void ShardingOp::build(::mlir::OpBuilder &b, ::mlir::OperationState &odsState, FlatSymbolRefAttr mesh, ArrayRef<MeshAxesAttr> split_axes, ArrayRef<MeshAxis> partial_axes, mesh::ReductionKind partial_type, ArrayRef<int64_t> static_halo_sizes, ArrayRef<int64_t> static_sharded_dims_sizes) {
+      // Builds the op from static values only; the two `{}` arguments leave
+      // the dynamic halo / sharded-dims operand lists empty (presumably).
+      // NOTE(review): halo sizes precede sharded-dims sizes here, but the
+      // overload taking a MeshSharding passes them the other way round.
+      return build(b, odsState, mesh, MeshAxesArrayAttr::get(b.getContext(), split_axes),
+                   ::mlir::DenseI16ArrayAttr::get(b.getContext(), partial_axes), ::mlir::mesh::ReductionKindAttr::get(b.getContext(), partial_type),
+                   ::mlir::DenseI64ArrayAttr::get(b.getContext(), static_halo_sizes), {}, ::mlir::DenseI64ArrayAttr::get(b.getContext(), static_sharded_dims_sizes), {});
+    
+}
+
+void ShardingOp::build(::mlir::OpBuilder &b, ::mlir::OperationState &odsState, FlatSymbolRefAttr mesh, ArrayRef<MeshAxesAttr> split_axes) {
+      // Only mesh and split axes are supplied; partial type is set to Sum.
+      return build(b, odsState, mesh, MeshAxesArrayAttr::get(b.getContext(), split_axes), {}, ::mlir::mesh::ReductionKindAttr::get(b.getContext(), ReductionKind::Sum),
+                   {}, {}, {}, {});
+}
 
+void ShardingOp::build(::mlir::OpBuilder &b, ::mlir::OperationState &odsState, FlatSymbolRefAttr mesh, ArrayRef<MeshAxesAttr> split_axes, ::mlir::ArrayRef<::mlir::OpFoldResult> halo_sizes, ::mlir::ArrayRef<::mlir::OpFoldResult> sharded_dims_sizes) {
+      // Splits each OpFoldResult list into its static and dynamic parts.
+      mlir::SmallVector<int64_t> staticHalos, staticDims;
+      mlir::SmallVector<mlir::Value> dynamicHalos, dynamicDims;
+      dispatchIndexOpFoldResults(halo_sizes, dynamicHalos, staticHalos);
+      dispatchIndexOpFoldResults(sharded_dims_sizes, dynamicDims, staticDims);
+      return build(b, odsState, mesh, MeshAxesArrayAttr::get(b.getContext(), split_axes), {}, ::mlir::mesh::ReductionKindAttr::get(b.getContext(), ReductionKind::Sum),
+                   ::mlir::DenseI64ArrayAttr::get(b.getContext(), staticHalos), dynamicHalos, ::mlir::DenseI64ArrayAttr::get(b.getContext(), staticDims), dynamicDims);
+}
+
+// NOTE(review): dims sizes precede halo sizes here, unlike the other builders.
+void ShardingOp::build(::mlir::OpBuilder &b, ::mlir::OperationState &odsState, mlir::mesh::MeshSharding from) {
+  build(b, odsState,
+        b.getIndexType(),
+        from.getMeshAttr(),
+        MeshAxesArrayAttr::get(b.getContext(), from.getSplitAxes()),
+        b.getDenseI16ArrayAttr(from.getPartialAxes()),
+        ::mlir::mesh::ReductionKindAttr::get(b.getContext(), from.getPartialType()),
+        b.getDenseI64ArrayAttr(from.getStaticShardedDimsSizes()),
+        from.getDynamicShardedDimsSizes(),
+        b.getDenseI64ArrayAttr(from.getStaticHaloSizes()),
+        from.getDynamicHaloSizes());
+}
+
+bool MeshSharding::operator==(Value rhs) const {
+  // Null-safe: block arguments have no defining op and compare unequal.
+  auto shardingOp = rhs.getDefiningOp<ShardingOp>();
+  return shardingOp && sameExceptConstraint(shardingOp) &&
+         sameConstraint(shardingOp);
+}
+
+bool MeshSharding::operator!=(Value rhs) const {
+  return !(*this == rhs);
+}
+
+bool MeshSharding::operator==(MeshSharding rhs) const {
+  return sameExceptConstraint(rhs) && sameConstraint(rhs);
+}
+
+bool MeshSharding::operator!=(MeshSharding rhs) const {
+  return !(*this == rhs);
+}
+
+MeshSharding::MeshSharding(Value rhs) {
+  auto shardingOp = rhs.getDefiningOp<ShardingOp>();
+  assert(shardingOp && "expected sharding op");
+  // get() returns a new object and never touches *this: assign its result.
+  *this = get(shardingOp.getMeshAttr(), shardingOp.getSplitAxes().getAxes(),
+              shardingOp.getPartialAxes().value_or(ArrayRef<MeshAxis>()),
+              shardingOp.getPartialType().value_or(ReductionKind::Sum),
+              shardingOp.getStaticHaloSizes(),
+              shardingOp.getStaticShardedDimsSizes(),
+              SmallVector<Value>(shardingOp.getDynamicHaloSizes()),
+              SmallVector<Value>(shardingOp.getDynamicShardedDimsSizes()));
+}
+
+MeshSharding MeshSharding::get(::mlir::FlatSymbolRefAttr mesh_,
+                  ArrayRef<MeshAxesAttr> split_axes_,
+                  ArrayRef<MeshAxis> partial_axes_,
+                  ReductionKind partial_type_,
+                  ArrayRef<int64_t> static_halo_sizes_,
+                  ArrayRef<int64_t> static_sharded_dims_sizes_,
+                  ArrayRef<Value> dynamic_halo_sizes_,
+                  ArrayRef<Value> dynamic_sharded_dims_sizes_) {
+  // Assembles a MeshSharding whose members hold copies of all arguments.
+  MeshSharding res;
+  res.mesh = mesh_;
+  res.split_axes.resize(split_axes_.size());
+  for (auto [i, axis] : llvm::enumerate(split_axes_)) {
+    res.split_axes[i] = MeshAxesAttr::get(mesh_.getContext(), axis.asArrayRef());
+  }
+
+  // dst must be taken by reference: with a by-value parameter the elements
+  // would be written to a temporary copy and the member would stay unfilled.
+  auto do_copy = [&](auto src, auto &dst) {
+    dst.resize(src.size());
+    for (auto [i, v] : llvm::enumerate(src)) {
+      dst[i] = v;
+    }
+  };
+
+  // do_copy sizes each destination itself, so no separate resize is needed.
+  do_copy(partial_axes_, res.partial_axes);
+  res.partial_type = partial_type_;
+  do_copy(static_halo_sizes_, res.static_halo_sizes);
+  do_copy(static_sharded_dims_sizes_, res.static_sharded_dims_sizes);
+  do_copy(dynamic_halo_sizes_, res.dynamic_halo_sizes);
+  do_copy(dynamic_sharded_dims_sizes_, res.dynamic_sharded_dims_sizes);
+
+  return res;
+}
+
+LogicalResult ShardingOp::verify() {
   llvm::SmallSet<MeshAxis, 4> visitedAxes;
 
   auto checkMeshAxis = [&](ArrayRef<MeshAxis> axesArray) -> LogicalResult {
@@ -353,53 +506,45 @@ MeshShardingAttr::verify(function_ref<InFlightDiagnostic()> emitError,
     return success();
   };
 
-  for (MeshAxesAttr subAxes : splitAxes) {
+  for (auto subAxes : getSplitAxes().getAxes()) {
     ArrayRef<MeshAxis> subAxesArray = subAxes.asArrayRef();
     if (failed(checkMeshAxis(subAxesArray)))
       return failure();
   }
-  if (failed(checkMeshAxis(partialAxes)))
+  if (getPartialAxes().has_value() && failed(checkMeshAxis(getPartialAxes().value())))
     return failure();
-  return success();
-}
-
-bool MeshShardingAttr::operator==(Attribute rhs) const {
-  MeshShardingAttr rhsAsMeshShardingAttr =
-      mlir::dyn_cast<MeshShardingAttr>(rhs);
-  return rhsAsMeshShardingAttr && *this == rhsAsMeshShardingAttr;
-}
 
-bool MeshShardingAttr::operator!=(Attribute rhs) const {
-  return !(*this == rhs);
-}
-
-bool MeshShardingAttr::operator==(MeshShardingAttr rhs) const {
-  if (getMesh() != rhs.getMesh() || getPartialAxes() != rhs.getPartialAxes()) {
-    return false;
+  if (!getStaticHaloSizes().empty() && !getStaticShardedDimsSizes().empty()) {
+    return emitOpError("halo sizes and shard shapes are mutually exclusive");
   }
-
-  if (!getPartialAxes().empty() && getPartialType() != rhs.getPartialType()) {
-    return false;
+  
+  if(!getStaticHaloSizes().empty()) {
+    auto numSplitAxes = getSplitAxes().getAxes().size();
+    for (auto splitAxis : getSplitAxes().getAxes()) {
+      if (splitAxis.empty()) {
+        --numSplitAxes;
+      }
+    }
+    if (getStaticHaloSizes().size() != numSplitAxes * 2) {
+      return emitError() << "Halo sizes must be specified for all split axes.";
+    }
   }
 
-  auto minSize = std::min(getSplitAxes().size(), rhs.getSplitAxes().size());
-  if (!llvm::equal(llvm::make_range(getSplitAxes().begin(),
-                                    getSplitAxes().begin() + minSize),
-                   llvm::make_range(rhs.getSplitAxes().begin(),
-                                    rhs.getSplitAxes().begin() + minSize))) {
-    return false;
-  }
+  return success();
+}
 
-  return llvm::all_of(llvm::make_range(getSplitAxes().begin() + minSize,
-                                       getSplitAxes().end()),
-                      std::mem_fn(&MeshAxesAttr::empty)) &&
-         llvm::all_of(llvm::make_range(rhs.getSplitAxes().begin() + minSize,
-                                       rhs.getSplitAxes().end()),
-                      std::mem_fn(&MeshAxesAttr::empty));
+void ShardingOp::getAsmResultNames(
+    function_ref<void(Value, StringRef)> setNameFn) {
+  setNameFn(getResult(), "sharding");
 }
 
-bool MeshShardingAttr::operator!=(MeshShardingAttr rhs) const {
-  return !(*this == rhs);
+//===----------------------------------------------------------------------===//
+// mesh.shard_shape
+//===----------------------------------------------------------------------===//
+
+void ShardShapeOp::build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::llvm::ArrayRef<int64_t> shape, ::mlir::Value sharding, ::mlir::Value device) {
+  SmallVector<mlir::Type> resType(shape.size(), odsBuilder.getIndexType());
+  build(odsBuilder, odsState, resType, shape, sharding, device);
 }
 
 //===----------------------------------------------------------------------===//
@@ -1044,6 +1189,25 @@ void ShiftOp::getAsmResultNames(
   setNameFn(getResult(), "shift");
 }
 
+//===----------------------------------------------------------------------===//
+// mesh.update_halo op
+//===----------------------------------------------------------------------===//
+
+LogicalResult
+UpdateHaloOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
+  // Defer entirely to getMeshAndVerifyAxes; its payload is not needed here,
+  // only whether it succeeded.
+  auto meshOrFailure = getMeshAndVerifyAxes(*this, symbolTable);
+  if (succeeded(meshOrFailure))
+    return success();
+  return failure();
+}
+
+void UpdateHaloOp::getAsmResultNames(
+    function_ref<void(Value, StringRef)> setNameFn) {
+  setNameFn(getResult(), "update_halo");
+}
+
 //===----------------------------------------------------------------------===//
 // TableGen'd op method definitions
 //===----------------------------------------------------------------------===//
@@ -1054,4 +1218,7 @@ void ShiftOp::getAsmResultNames(
 #define GET_ATTRDEF_CLASSES
 #include "mlir/Dialect/Mesh/IR/MeshAttributes.cpp.inc"
 
+#define GET_TYPEDEF_CLASSES
+#include "mlir/Dialect/Mesh/IR/MeshTypes.cpp.inc"
+
 #include "mlir/Dialect/Mesh/IR/MeshEnums.cpp.inc"
diff --git a/mlir/lib/Dialect/Mesh/Interfaces/ShardingInterface.cpp b/mlir/lib/Dialect/Mesh/Interfaces/ShardingInterface.cpp
index bcd0e15561320..e525c31791261 100644
--- a/mlir/lib/Dialect/Mesh/Interfaces/ShardingInterface.cpp
+++ b/mlir/lib/Dialect/Mesh/Interfaces/ShardingInterface.cpp
@@ -91,12 +91,21 @@ checkOperandAffineExpr(AffineExpr expr, unsigned numDims) {
   return positions;
 }
 
+template<typename T>
+SmallVector<MeshAxesAttr> fromArrayOfVector(MLIRContext *ctxt, const SmallVector<SmallVector<T>> &vec) {
+  // Wrap each inner vector in a MeshAxesAttr, keeping the original order.
+  SmallVector<MeshAxesAttr> attrs;
+  for (const SmallVector<T> &axes : vec)
+    attrs.push_back(MeshAxesAttr::get(ctxt, axes));
+  return attrs;
+}
+
 //===----------------------------------------------------------------------===//
-// mesh::getMeshShardingAttr
+// mesh::getMeshSharding
 //===----------------------------------------------------------------------===//
 
-FailureOr<std::pair<bool, MeshShardingAttr>>
-mesh::getMeshShardingAttr(OpResult result) {
+FailureOr<std::pair<bool, MeshSharding>>
+mesh::getMeshSharding(OpResult result) {
   Value val = cast<Value>(result);
   bool anyShardedForDef = llvm::any_of(val.getUsers(), [](Operation *user) {
     auto shardOp = llvm::dyn_cast<mesh::ShardOp>(user);
@@ -111,7 +120,7 @@ mesh::getMeshShardingAttr(OpResult result) {
     if (!val.hasOneUse())
       return failure();
     auto shardOp = llvm::cast<mesh::ShardOp>(*val.getUsers().begin());
-    return std::make_pair(false, shardOp.getShard());
+    return std::make_pair(false, MeshSharding(shardOp.getSharding()));
   }
 
   bool anyShardedForUsers = llvm::any_of(val.getUsers(), [](Operation *user) {
@@ -127,11 +136,11 @@ mesh::getMeshShardingAttr(OpResult result) {
       if (shardOp)
         shardOps.push_back(shardOp);
     }
-    MeshShardingAttr shardForDef = shardOps[0].getShard();
+    MeshSharding shardForDef = shardOps[0].getSharding();
     for (size_t i = 1; i < shardOps.size(); ++i) {
       // TODO: Deduce a reasonable mesh sharding attr for def when they are
       // different
-      assert(shardOps[i].getShard() == shardForDef &&
+      assert(shardForDef == shardOps[i].getSharding()  &&
              "only support all shard ops have the same mesh sharding attr");
     }
     return std::make_pair(true, shardForDef);
@@ -139,11 +148,11 @@ mesh::getMeshShardingAttr(OpResult result) {
   return failure();
 }
 
-FailureOr<std::pair<bool, MeshShardingAttr>>
-mesh::getMeshShardingAttr(OpOperand &opOperand) {
+FailureOr<std::pair<bool, MeshSharding>>
+mesh::getMeshSharding(OpOperand &opOperand) {
   Value val = opOperand.get();
   if (ShardOp shardOp = val.getDefiningOp<ShardOp>())
-    return std::make_pair(shardOp.getAnnotateForUsers(), shardOp.getShard());
+    return std::make_pair(shardOp.getAnnotateForUsers(), MeshSharding(shardOp.getSharding()));
 
   return failure();
 }
@@ -251,8 +260,8 @@ static LogicalResult fillShardingOption(Operation *op,
 } // namespace
 
 FailureOr<ShardingOption> mesh::detail::defaultGetShardingOption(
-    Operation *op, ArrayRef<MeshShardingAttr> operandShardings,
-    ArrayRef<MeshShardingAttr> resultShardings) {
+    Operation *op, ArrayRef<MeshSharding> operandShardings,
+    ArrayRef<MeshSharding> resultShardings) {
   ShardingInterface shardingOp = llvm::cast<ShardingInterface>(op);
   ShardingOption shardingOption;
 
@@ -269,7 +278,7 @@ FailureOr<ShardingOption> mesh::detail::defaultGetShardingOption(
 
   // 1. Fill sharding option based on op results
   for (auto shardingIt : llvm::enumerate(resultShardings)) {
-    MeshShardingAttr shardAttr = shardingIt.value();
+    MeshSharding shardAttr = shardingIt.value();
     if (!shardAttr)
       continue;
     AffineMap map = maps[numOperands + shardingIt.index()];
@@ -283,7 +292,7 @@ FailureOr<ShardingOption> mesh::detail::defaultGetShardingOption(
       auto dim = cast<AffineDimExpr>(expr);
       unsigned index = dim.getPosition();
       visitedLoopIndices.insert(index);
-      if (failed(fillShardingOption(op, shardingOption, shardAttr.getMesh(),
+      if (failed(fillShardingOption(op, shardingOption, shardAttr.getMeshAttr(),
                                     axes, index)))
         return failure();
     }
@@ -307,7 +316,7 @@ FailureOr<ShardingOption> mesh::detail::defaultGetShardingOption(
 
   // 2. Fill sharding option based on operands
   for (auto shardingIt : llvm::enumerate(operandShardings)) {
-    MeshShardingAttr shardAttr = shardingIt.value();
+    MeshSharding shardAttr = shardingIt.value();
     if (!shardAttr)
       continue;
 
@@ -334,7 +343,7 @@ FailureOr<ShardingOption> mesh::detail::defaultGetShardingOption(
       if (loopIndices->size() == 1) {
         unsigned loopIdx = *loopIndices->begin();
         visitedLoopIndices.insert(loopIdx);
-        if (failed(fillShardingOption(op, shardingOption, shardAttr.getMesh(),
+        if (failed(fillShardingOption(op, shardingOption, shardAttr.getMeshAttr(),
                                       axes, loopIdx)))
           return failure();
       }
@@ -389,7 +398,7 @@ FailureOr<ShardingOption> mesh::detail::defaultGetShardingOption(
 }
 
 // Get the sharding attributed for the given result and sharding option.
-MeshShardingAttr
+MeshSharding
 getShardingAttribute(OpResult result, const ShardingOption &shardingOption,
                      AffineMap map, ArrayRef<utils::IteratorType> loopTypes,
                      ArrayRef<ReductionKind> reductionLoopKinds) {
@@ -399,6 +408,7 @@ getShardingAttribute(OpResult result, const ShardingOption &shardingOption,
 
   // process the split axes
   for (auto it : llvm::enumerate(map.getResults())) {
+    SmallVector<MeshAxis> tmp_axes;
     AffineExpr expr = it.value();
     // `expr` must be an `AffineDimExpr` because `map` is verified by
     // isProjectedPermutation
@@ -427,11 +437,11 @@ getShardingAttribute(OpResult result, const ShardingOption &shardingOption,
   }
 
   removeTrailingEmptySubArray(splitAxes);
-  return MeshShardingAttr::get(result.getContext(), shardingOption.mesh,
-                               splitAxes, partialAxes, partialType);
+  return MeshSharding::get(shardingOption.mesh,
+                           fromArrayOfVector(result.getContext(), splitAxes), partialAxes, partialType);
 }
 
-static FailureOr<MeshShardingAttr>
+static FailureOr<MeshSharding>
 getShardingAttribute(OpOperand &opOperand, const ShardingOption &shardingOption,
                      AffineMap map) {
   Value operandValue = opOperand.get();
@@ -461,14 +471,13 @@ getShardingAttribute(OpOperand &opOperand, const ShardingOption &shardingOption,
   }
 
   removeTrailingEmptySubArray(splitAxes);
-  return MeshShardingAttr::get(opOperand.get().getContext(),
-                               shardingOption.mesh, splitAxes);
+  return MeshSharding::get(shardingOption.mesh, fromArrayOfVector(opOperand.get().getContext(), splitAxes));
 }
 
-FailureOr<SmallVector<MeshShardingAttr>>
+FailureOr<std::vector<MeshSharding>>
 mesh::detail::defaultGetShardingAnnotations(
     Operation *op, const ShardingOption &shardingOption) {
-  SmallVector<MeshShardingAttr> res;
+  std::vector<MeshSharding> res;
 
   ShardingInterface shardingOp = llvm::cast<ShardingInterface>(op);
   SmallVector<utils::IteratorType> loopTypes =
@@ -479,7 +488,7 @@ mesh::detail::defaultGetShardingAnnotations(
   unsigned numOperands = op->getNumOperands();
 
   for (OpOperand &opOperand : op->getOpOperands()) {
-    FailureOr<MeshShardingAttr> shardingAttr = getShardingAttribute(
+    FailureOr<MeshSharding> shardingAttr = getShardingAttribute(
         opOperand, shardingOption, maps[opOperand.getOperandNumber()]);
     if (failed(shardingAttr))
       return failure();
@@ -506,9 +515,10 @@ static LogicalResult addShardOp(OpBuilder &b, OpResult result,
                                 AffineMap map,
                                 ArrayRef<utils::IteratorType> loopTypes,
                                 ArrayRef<ReductionKind> reductionLoopKinds) {
-  MeshShardingAttr shardAttr = getShardingAttribute(
+  MeshSharding sharding = getShardingAttribute(
       result, shardingOption, map, loopTypes, reductionLoopKinds);
-  maybeInsertTargetShardingAnnotation(shardAttr, result, b);
+  auto shardingOp = b.create<ShardingOp>(result.getLoc(), sharding);
+  maybeInsertTargetShardingAnnotation(shardingOp.getResult(), result, b);
 
   return success();
 }
@@ -519,13 +529,14 @@ static LogicalResult addShardOp(OpBuilder &b, OpOperand &opOperand,
                                 const ShardingOption &shardingOption,
                                 AffineMap map) {
 
-  FailureOr<MeshShardingAttr> shardAttr =
+  FailureOr<MeshSharding> sharding =
       getShardingAttribute(opOperand, shardingOption, map);
-  if (failed(shardAttr)) {
+  if (failed(sharding)) {
     return failure();
   }
   OpBuilder::InsertionGuard guard(b);
-  maybeInsertSourceShardingAnnotation(*shardAttr, opOperand, b);
+  auto shardingOp = b.create<ShardingOp>(opOperand.get().getLoc(), sharding.value());
+  maybeInsertSourceShardingAnnotation(shardingOp.getResult(), opOperand, b);
 
   return success();
 }
@@ -563,7 +574,7 @@ LogicalResult mesh::detail::defaultAddShardingAnnotations(
 #ifndef NDEBUG
 static bool
 isValueCompatibleWithFullReplicationSharding(Value value,
-                                             MeshShardingAttr sharding) {
+                                             MeshSharding sharding) {
   if (isa<RankedTensorType>(value.getType())) {
     return sharding && isFullReplication(sharding);
   }
@@ -571,15 +582,15 @@ isValueCompatibleWithFullReplicationSharding(Value value,
   return !sharding;
 }
 
-template <typename ValueRange, typename MeshShardingAttrRage>
+template <typename ValueRange, typename MeshShardingRage>
 static bool areValuesCompatibleWithFullReplicationShardings(
-    ValueRange &&values, MeshShardingAttrRage &&shardings) {
+    ValueRange &&values, MeshShardingRage &&shardings) {
   if (std::size(values) != std::size(shardings)) {
     return false;
   }
   return llvm::all_of(llvm::zip_equal(
                           std::forward<ValueRange>(values),
-                          std::forward<MeshShardingAttrRage>(shardings)),
+                          std::forward<MeshShardingRage>(shardings)),
                       [](auto valueAndSharding) {
                         return isValueCompatibleWithFullReplicationSharding(
                             std::get<0>(valueAndSharding),
@@ -590,8 +601,8 @@ static bool areValuesCompatibleWithFullReplicationShardings(
 
 void mesh::spmdizeFullyReplicatedOperation(
     Operation &op, ArrayRef<Value> spmdizedOperands,
-    ArrayRef<MeshShardingAttr> operandShardings,
-    ArrayRef<MeshShardingAttr> resultShardings, IRMapping &spmdizationMap,
+    ArrayRef<MeshSharding> operandShardings,
+    ArrayRef<MeshSharding> resultShardings, IRMapping &spmdizationMap,
     SymbolTableCollection &symbolTable, OpBuilder &builder) {
   assert(spmdizedOperands.size() == operandShardings.size());
   assert(areValuesCompatibleWithFullReplicationShardings(op.getOperands(),
@@ -618,13 +629,13 @@ static void updateMeshAxisAssignmentForLoopIterators(
 }
 
 ShardingArray mesh::getMeshAxisAssignmentForLoopIterators(
-    ArrayRef<MeshShardingAttr> operandShardings,
-    ArrayRef<MeshShardingAttr> resultShardings,
+    ArrayRef<MeshSharding> operandShardings,
+    ArrayRef<MeshSharding> resultShardings,
     ArrayRef<utils::IteratorType> loopIteratorTypes,
     ArrayRef<AffineMap> indexingMaps) {
   SmallVector<std::optional<SmallVector<MeshAxis>>>
       meshAxisAssignmentForLoopIterators(loopIteratorTypes.size());
-  SmallVector<MeshShardingAttr> operatorAndResultShardings;
+  std::vector<MeshSharding> operatorAndResultShardings;
   operatorAndResultShardings.reserve(operandShardings.size() +
                                      resultShardings.size());
   llvm::append_range(operatorAndResultShardings, operandShardings);
@@ -686,8 +697,8 @@ SmallVector<MeshAxis> mesh::getReductionMeshAxes(
 
 void mesh::spmdizeTriviallyShardableOperation(
     Operation &op, ArrayRef<Value> spmdizedOperands,
-    ArrayRef<MeshShardingAttr> operandShardings,
-    ArrayRef<MeshShardingAttr> resultShardings, IRMapping &spmdizationMap,
+    ArrayRef<MeshSharding> operandShardings,
+    ArrayRef<MeshSharding> resultShardings, IRMapping &spmdizationMap,
     SymbolTableCollection &symbolTable, OpBuilder &builder) {
   // `clone` will populate the mapping of old to new results.
   Operation *newOp = builder.clone(op, spmdizationMap);
@@ -695,7 +706,7 @@ void mesh::spmdizeTriviallyShardableOperation(
   for (auto [oldResult, newResult, sharding] :
        llvm::zip_equal(op.getResults(), newOp->getResults(), resultShardings)) {
     newResult.setType(shardType(newResult.getType(),
-                                getMesh(&op, sharding.getMesh(), symbolTable),
+                                getMesh(&op, sharding.getMeshAttr(), symbolTable),
                                 sharding));
   }
 }
diff --git a/mlir/lib/Dialect/Mesh/Transforms/ShardingPropagation.cpp b/mlir/lib/Dialect/Mesh/Transforms/ShardingPropagation.cpp
index 511c9102fa303..ac264d93a9776 100644
--- a/mlir/lib/Dialect/Mesh/Transforms/ShardingPropagation.cpp
+++ b/mlir/lib/Dialect/Mesh/Transforms/ShardingPropagation.cpp
@@ -108,16 +108,16 @@ operator<<(llvm::raw_ostream &stream, ReshardingRquirementKind v) {
 // specific shardings. For example, mustShardings = [shard0, None] and
 // optionalShardings = [None, shard1], the result will be [[shard0, shard1],
 // [shard0, None]]
-static SmallVector<SmallVector<MeshShardingAttr>>
-getOrderedPossibleShardingAttrs(ArrayRef<MeshShardingAttr> mustShardings,
-                                ArrayRef<MeshShardingAttr> optionalShardings) {
-  SmallVector<SmallVector<MeshShardingAttr>> allShardingAttrs;
-  SmallVector<MeshShardingAttr> curShardingAttrs;
+static SmallVector<std::vector<MeshSharding>>
+getOrderedPossibleShardingAttrs(ArrayRef<MeshSharding> mustShardings,
+                                ArrayRef<MeshSharding> optionalShardings) {
+  SmallVector<std::vector<MeshSharding>> allShardingAttrs;
+  std::vector<MeshSharding> curShardingAttrs;
 
   std::function<void(size_t)> dfsCreateShardingAttrs = [&](size_t i) {
     if (i == mustShardings.size()) {
       allShardingAttrs.push_back(
-          SmallVector<MeshShardingAttr>(curShardingAttrs));
+          std::vector<MeshSharding>(curShardingAttrs));
       return;
     }
 
@@ -132,13 +132,13 @@ getOrderedPossibleShardingAttrs(ArrayRef<MeshShardingAttr> mustShardings,
       curShardingAttrs.push_back(optionalShardings[i]);
       dfsCreateShardingAttrs(i + 1);
       curShardingAttrs.pop_back();
-      curShardingAttrs.push_back(nullptr);
+      curShardingAttrs.push_back({});
       dfsCreateShardingAttrs(i + 1);
       curShardingAttrs.pop_back();
       return;
     }
 
-    curShardingAttrs.push_back(nullptr);
+    curShardingAttrs.push_back({});
     dfsCreateShardingAttrs(i + 1);
     curShardingAttrs.pop_back();
   };
@@ -159,7 +159,7 @@ getOrderedPossibleShardingAttrs(ArrayRef<MeshShardingAttr> mustShardings,
 //   annotation targeting explicitly this operation.
 ReshardingRquirementKind getReshardingRquirementKind(
     Operation *op,
-    const SmallVector<MeshShardingAttr> &operandAndResultShardings) {
+    const std::vector<MeshSharding> &operandAndResultShardings) {
   ReshardingRquirementKind res = ReshardingRquirementKind::NO_RESHARDING;
 
   size_t operandsCount = op->getOperands().size();
@@ -176,7 +176,7 @@ ReshardingRquirementKind getReshardingRquirementKind(
     if (!shardOp) {
       continue;
     }
-    bool needsResharding = shardOp.getShardAttr() != sharding;
+    bool needsResharding = sharding != shardOp.getSharding();
     bool isExplicitAnnotationForThisOp = shardOp.getAnnotateForUsers();
     if (needsResharding) {
       if (isExplicitAnnotationForThisOp) {
@@ -194,7 +194,7 @@ ReshardingRquirementKind getReshardingRquirementKind(
       if (!shardOp) {
         continue;
       }
-      bool needsResharding = shardOp.getShardAttr() != sharding;
+      bool needsResharding = sharding != shardOp.getSharding();
       bool isExplicitAnnotationForThisOp = !shardOp.getAnnotateForUsers();
       if (needsResharding) {
         if (isExplicitAnnotationForThisOp) {
@@ -218,14 +218,14 @@ ReshardingRquirementKind getReshardingRquirementKind(
 // 3. Resharding of existing explicit sharding annotations for this op.
 static FailureOr<ShardingOption> selectShardingOption(
     ShardingInterface shardingOp,
-    ArrayRef<SmallVector<MeshShardingAttr>> possibleOperandShardingAttrs,
-    ArrayRef<SmallVector<MeshShardingAttr>> possibleResultShardingAttrs) {
+    ArrayRef<std::vector<MeshSharding>> possibleOperandShardingAttrs,
+    ArrayRef<std::vector<MeshSharding>> possibleResultShardingAttrs) {
   SmallVector<std::tuple<ShardingOption, ReshardingRquirementKind>>
       shardingOptionsAndReshardingRequirements;
 
-  for (ArrayRef<MeshShardingAttr> resultShardings :
+  for (ArrayRef<MeshSharding> resultShardings :
        possibleResultShardingAttrs) {
-    for (ArrayRef<MeshShardingAttr> operandShardings :
+    for (ArrayRef<MeshSharding> operandShardings :
          possibleOperandShardingAttrs) {
       FailureOr<ShardingOption> shardingOption =
           shardingOp.getShardingOption(operandShardings, resultShardings);
@@ -237,14 +237,14 @@ static FailureOr<ShardingOption> selectShardingOption(
       // They may be missing some annotations.
       // Whatever is returned by getShardingAnnotations is exactly what the op
       // needs.
-      FailureOr<SmallVector<MeshShardingAttr>> operandAndResultShardings =
+      FailureOr<std::vector<MeshSharding>> operandAndResultShardings =
           shardingOp.getShardingAnnotations(*shardingOption);
       if (failed(operandAndResultShardings)) {
         return failure();
       }
 
-      LLVM_DEBUG(DBGS() << "operandAndResultShardings = "
-                        << *operandAndResultShardings << "\n";);
+      // LLVM_DEBUG(DBGS() << "operandAndResultShardings = "
+      //                   << *operandAndResultShardings << "\n";);
 
       ReshardingRquirementKind reshardingRquirement =
           getReshardingRquirementKind(shardingOp, *operandAndResultShardings);
@@ -294,14 +294,14 @@ static LogicalResult visitOp(Operation *op, OpBuilder &builder) {
     return failure();
   }
 
-  // collect MeshShardingAttr from results
-  SmallVector<MeshShardingAttr> allowConflictsResultShardings;
+  // collect MeshSharding from results
+  std::vector<MeshSharding> allowConflictsResultShardings;
   allowConflictsResultShardings.resize(op->getNumResults());
-  SmallVector<MeshShardingAttr> resultMustShardings;
+  std::vector<MeshSharding> resultMustShardings;
   resultMustShardings.resize(op->getNumResults());
   for (OpResult result : op->getResults()) {
-    FailureOr<std::pair<bool, MeshShardingAttr>> maybeShardAttr =
-        getMeshShardingAttr(result);
+    FailureOr<std::pair<bool, MeshSharding>> maybeShardAttr =
+        getMeshSharding(result);
     if (failed(maybeShardAttr))
       continue;
     if (!maybeShardAttr->first)
@@ -311,14 +311,14 @@ static LogicalResult visitOp(Operation *op, OpBuilder &builder) {
           maybeShardAttr->second;
   }
 
-  // collect MeshShardingAttr from operands
-  SmallVector<MeshShardingAttr> allowConflictsOperandShardings;
+  // collect MeshSharding from operands
+  std::vector<MeshSharding> allowConflictsOperandShardings;
   allowConflictsOperandShardings.resize(op->getNumOperands());
-  SmallVector<MeshShardingAttr> operandMustShardings;
+  std::vector<MeshSharding> operandMustShardings;
   operandMustShardings.resize(op->getNumOperands());
   for (OpOperand &opOperand : op->getOpOperands()) {
-    FailureOr<std::pair<bool, MeshShardingAttr>> maybeShardAttr =
-        getMeshShardingAttr(opOperand);
+    FailureOr<std::pair<bool, MeshSharding>> maybeShardAttr =
+        getMeshSharding(opOperand);
     if (failed(maybeShardAttr))
       continue;
 
@@ -331,10 +331,10 @@ static LogicalResult visitOp(Operation *op, OpBuilder &builder) {
   }
 
   // try to get the sharding option
-  SmallVector<SmallVector<MeshShardingAttr>> possibleOperandShardingAttrs =
+  SmallVector<std::vector<MeshSharding>> possibleOperandShardingAttrs =
       getOrderedPossibleShardingAttrs(operandMustShardings,
                                       allowConflictsOperandShardings);
-  SmallVector<SmallVector<MeshShardingAttr>> possibleResultShardingAttrs =
+  SmallVector<std::vector<MeshSharding>> possibleResultShardingAttrs =
       getOrderedPossibleShardingAttrs(resultMustShardings,
                                       allowConflictsResultShardings);
   FailureOr<ShardingOption> shardingOption = selectShardingOption(
diff --git a/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp b/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp
index 1df3cf62c2b53..2ba7e9998b49c 100644
--- a/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp
+++ b/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp
@@ -54,10 +54,10 @@ static bool arePartialAxesCompatible(const SourceAxes &sourceAxes,
 // targetSharding = <@mesh_1d, [[]]>
 // Then will apply all-reduce on the source value
 // and return it with the sharding <@mesh_1d, [[0]]>.
-static std::tuple<TypedValue<ShapedType>, MeshShardingAttr>
+static std::tuple<TypedValue<ShapedType>, MeshSharding>
 handlePartialAxesDuringResharding(OpBuilder &builder,
-                                  MeshShardingAttr sourceSharding,
-                                  MeshShardingAttr targetSharding,
+                                  MeshSharding sourceSharding,
+                                  MeshSharding targetSharding,
                                   TypedValue<ShapedType> sourceShard) {
   if (sourceSharding.getPartialAxes().empty() &&
       targetSharding.getPartialAxes().empty()) {
@@ -88,7 +88,7 @@ handlePartialAxesDuringResharding(OpBuilder &builder,
   TypedValue<ShapedType> resultValue = cast<TypedValue<ShapedType>>(
       builder
           .create<AllReduceOp>(sourceShard.getLoc(), sourceShard.getType(),
-                               sourceSharding.getMesh().getLeafReference(),
+                               sourceSharding.getMeshAttr().getLeafReference(),
                                allReduceMeshAxes, sourceShard,
                                sourceSharding.getPartialType())
           .getResult());
@@ -99,15 +99,15 @@ handlePartialAxesDuringResharding(OpBuilder &builder,
                 [&targetShardingPartialAxesSet](Axis a) {
                   return targetShardingPartialAxesSet.contains(a);
                 });
-  MeshShardingAttr resultSharding =
-      MeshShardingAttr::get(builder.getContext(), sourceSharding.getMesh(),
+  MeshSharding resultSharding =
+      MeshSharding::get(sourceSharding.getMeshAttr(),
                             sourceSharding.getSplitAxes(), remainingPartialAxes,
                             sourceSharding.getPartialType());
   return {resultValue, resultSharding};
 }
 
-static MeshShardingAttr
-targetShardingInSplitLastAxis(MLIRContext *ctx, MeshShardingAttr sourceSharding,
+static MeshSharding
+targetShardingInSplitLastAxis(MLIRContext *ctx, MeshSharding sourceSharding,
                               int64_t splitTensorAxis, MeshAxis splitMeshAxis) {
   SmallVector<MeshAxesAttr> targetShardingSplitAxes =
       llvm::to_vector(sourceSharding.getSplitAxes());
@@ -120,17 +120,17 @@ targetShardingInSplitLastAxis(MLIRContext *ctx, MeshShardingAttr sourceSharding,
   targetSplitAxes.push_back(splitMeshAxis);
   targetShardingSplitAxes[splitTensorAxis] =
       MeshAxesAttr::get(ctx, targetSplitAxes);
-  return MeshShardingAttr::get(
-      ctx, sourceSharding.getMesh(), targetShardingSplitAxes,
+  return MeshSharding::get(
+      sourceSharding.getMeshAttr(), targetShardingSplitAxes,
       sourceSharding.getPartialAxes(), sourceSharding.getPartialType());
 }
 
 // Split a replicated tensor along a mesh axis.
 // e.g. [[0, 1]] -> [[0, 1, 2]].
 // Returns the spmdized target value with its sharding.
-static std::tuple<TypedValue<ShapedType>, MeshShardingAttr>
+static std::tuple<TypedValue<ShapedType>, MeshSharding>
 splitLastAxisInResharding(ImplicitLocOpBuilder &builder,
-                          MeshShardingAttr sourceSharding,
+                          MeshSharding sourceSharding,
                           TypedValue<ShapedType> sourceShard, MeshOp mesh,
                           int64_t splitTensorAxis, MeshAxis splitMeshAxis) {
   TypedValue<ShapedType> targetShard = cast<TypedValue<ShapedType>>(
@@ -139,7 +139,7 @@ splitLastAxisInResharding(ImplicitLocOpBuilder &builder,
                               ArrayRef<MeshAxis>(splitMeshAxis),
                               splitTensorAxis)
           .getResult());
-  MeshShardingAttr targetSharding = targetShardingInSplitLastAxis(
+  MeshSharding targetSharding = targetShardingInSplitLastAxis(
       builder.getContext(), sourceSharding, splitTensorAxis, splitMeshAxis);
   return {targetShard, targetSharding};
 }
@@ -150,8 +150,8 @@ splitLastAxisInResharding(ImplicitLocOpBuilder &builder,
 // Does not detect insertions like
 // [[0, 1]] -> [[0, 2, 1]].
 static std::optional<std::tuple<int64_t, MeshAxis>>
-detectSplitLastAxisInResharding(MeshShardingAttr sourceSharding,
-                                MeshShardingAttr targetSharding) {
+detectSplitLastAxisInResharding(MeshSharding sourceSharding,
+                                MeshSharding targetSharding) {
   for (size_t tensorAxis = 0; tensorAxis < targetSharding.getSplitAxes().size();
        ++tensorAxis) {
     if (sourceSharding.getSplitAxes().size() > tensorAxis) {
@@ -181,10 +181,10 @@ detectSplitLastAxisInResharding(MeshShardingAttr sourceSharding,
   return std::nullopt;
 }
 
-static std::optional<std::tuple<TypedValue<ShapedType>, MeshShardingAttr>>
+static std::optional<std::tuple<TypedValue<ShapedType>, MeshSharding>>
 trySplitLastAxisInResharding(ImplicitLocOpBuilder &builder, MeshOp mesh,
-                             MeshShardingAttr sourceSharding,
-                             MeshShardingAttr targetSharding,
+                             MeshSharding sourceSharding,
+                             MeshSharding targetSharding,
                              TypedValue<ShapedType> sourceShard) {
   if (auto detectRes =
           detectSplitLastAxisInResharding(sourceSharding, targetSharding)) {
@@ -200,8 +200,8 @@ trySplitLastAxisInResharding(ImplicitLocOpBuilder &builder, MeshOp mesh,
 // [[0, 1, 2]] -> [[0, 1]].
 // If detected, returns the corresponding tensor axis mesh axis pair.
 static std::optional<std::tuple<int64_t, MeshAxis>>
-detectUnsplitLastAxisInResharding(MeshShardingAttr sourceSharding,
-                                  MeshShardingAttr targetSharding) {
+detectUnsplitLastAxisInResharding(MeshSharding sourceSharding,
+                                  MeshSharding targetSharding) {
   for (size_t tensorAxis = 0; tensorAxis < sourceSharding.getSplitAxes().size();
        ++tensorAxis) {
     if (targetSharding.getSplitAxes().size() > tensorAxis) {
@@ -228,9 +228,9 @@ detectUnsplitLastAxisInResharding(MeshShardingAttr sourceSharding,
   return std::nullopt;
 }
 
-static MeshShardingAttr
+static MeshSharding
 targetShardingInUnsplitLastAxis(MLIRContext *ctx,
-                                MeshShardingAttr sourceSharding,
+                                MeshSharding sourceSharding,
                                 int64_t splitTensorAxis) {
   SmallVector<MeshAxesAttr> targetShardingSplitAxes =
       llvm::to_vector(sourceSharding.getSplitAxes());
@@ -242,8 +242,8 @@ targetShardingInUnsplitLastAxis(MLIRContext *ctx,
   targetSplitAxes.pop_back();
   targetShardingSplitAxes[splitTensorAxis] =
       MeshAxesAttr::get(ctx, targetSplitAxes);
-  return MeshShardingAttr::get(
-      ctx, sourceSharding.getMesh(), targetShardingSplitAxes,
+  return MeshSharding::get(
+      sourceSharding.getMeshAttr(), targetShardingSplitAxes,
       sourceSharding.getPartialAxes(), sourceSharding.getPartialType());
 }
 
@@ -255,16 +255,16 @@ static ShapedType allGatherResultShapeInUnsplitLastAxis(
   return sourceShape.cloneWith(targetShape, sourceShape.getElementType());
 }
 
-static std::tuple<TypedValue<ShapedType>, MeshShardingAttr>
+static std::tuple<TypedValue<ShapedType>, MeshSharding>
 unsplitLastAxisInResharding(ImplicitLocOpBuilder &builder,
-                            MeshShardingAttr sourceSharding,
+                            MeshSharding sourceSharding,
                             ShapedType sourceUnshardedShape,
                             TypedValue<ShapedType> sourceShard, MeshOp mesh,
                             int64_t splitTensorAxis, MeshAxis splitMeshAxis) {
   MLIRContext *ctx = builder.getContext();
   builder.setInsertionPointAfterValue(sourceShard);
 
-  MeshShardingAttr targetSharding =
+  MeshSharding targetSharding =
       targetShardingInUnsplitLastAxis(ctx, sourceSharding, splitTensorAxis);
   ShapedType allGatherResultShape = allGatherResultShapeInUnsplitLastAxis(
       sourceShard.getType(), mesh.getShape()[splitMeshAxis], splitTensorAxis);
@@ -280,10 +280,10 @@ unsplitLastAxisInResharding(ImplicitLocOpBuilder &builder,
   return {targetShard, targetSharding};
 }
 
-static std::optional<std::tuple<TypedValue<ShapedType>, MeshShardingAttr>>
+static std::optional<std::tuple<TypedValue<ShapedType>, MeshSharding>>
 tryUnsplitLastAxisInResharding(ImplicitLocOpBuilder &builder, MeshOp mesh,
-                               MeshShardingAttr sourceSharding,
-                               MeshShardingAttr targetSharding,
+                               MeshSharding sourceSharding,
+                               MeshSharding targetSharding,
                                ShapedType sourceUnshardedShape,
                                TypedValue<ShapedType> sourceShard) {
   if (auto detectRes =
@@ -303,8 +303,8 @@ tryUnsplitLastAxisInResharding(ImplicitLocOpBuilder &builder, MeshOp mesh,
 // If detected, returns the corresponding (source_tensor_axis,
 // target_tensor_axis, mesh_axis) tuple.
 static std::optional<std::tuple<int64_t, int64_t, MeshAxis>>
-detectMoveLastSplitAxisInResharding(MeshShardingAttr sourceSharding,
-                                    MeshShardingAttr targetSharding) {
+detectMoveLastSplitAxisInResharding(MeshSharding sourceSharding,
+                                    MeshSharding targetSharding) {
   for (size_t sourceTensorAxis = 0;
        sourceTensorAxis < sourceSharding.getSplitAxes().size();
        ++sourceTensorAxis) {
@@ -344,8 +344,8 @@ detectMoveLastSplitAxisInResharding(MeshShardingAttr sourceSharding,
   return std::nullopt;
 }
 
-static MeshShardingAttr
-targetShardingInMoveLastAxis(MLIRContext *ctx, MeshShardingAttr sourceSharding,
+static MeshSharding
+targetShardingInMoveLastAxis(MLIRContext *ctx, MeshSharding sourceSharding,
                              int64_t sourceTensorAxis,
                              int64_t targetTensorAxis) {
   SmallVector<MeshAxesAttr> targetShardingSplitAxes =
@@ -369,8 +369,8 @@ targetShardingInMoveLastAxis(MLIRContext *ctx, MeshShardingAttr sourceSharding,
   targetShardingSplitAxes[targetTensorAxis] =
       MeshAxesAttr::get(ctx, targetSplitAxes);
 
-  return MeshShardingAttr::get(
-      ctx, sourceSharding.getMesh(), targetShardingSplitAxes,
+  return MeshSharding::get(
+      sourceSharding.getMeshAttr(), targetShardingSplitAxes,
       sourceSharding.getPartialAxes(), sourceSharding.getPartialType());
 }
 
@@ -386,9 +386,9 @@ static ShapedType allToAllResultShapeInMoveLastAxis(ShapedType sourceShape,
   return sourceShape.cloneWith(targetShape, sourceShape.getElementType());
 }
 
-static std::tuple<TypedValue<ShapedType>, MeshShardingAttr>
+static std::tuple<TypedValue<ShapedType>, MeshSharding>
 moveLastSplitAxisInResharding(ImplicitLocOpBuilder &builder, MeshOp mesh,
-                              MeshShardingAttr sourceSharding,
+                              MeshSharding sourceSharding,
                               ShapedType sourceUnshardedShape,
                               TypedValue<ShapedType> sourceShard,
                               int64_t sourceTensorAxis,
@@ -396,7 +396,7 @@ moveLastSplitAxisInResharding(ImplicitLocOpBuilder &builder, MeshOp mesh,
   MLIRContext *ctx = builder.getContext();
   builder.setInsertionPointAfterValue(sourceShard);
 
-  MeshShardingAttr targetSharding = targetShardingInMoveLastAxis(
+  MeshSharding targetSharding = targetShardingInMoveLastAxis(
       ctx, sourceSharding, sourceTensorAxis, targetTensorAxis);
   ShapedType allToAllResultShape = allToAllResultShapeInMoveLastAxis(
       sourceShard.getType(), mesh.getShape()[meshAxis], sourceTensorAxis,
@@ -413,10 +413,10 @@ moveLastSplitAxisInResharding(ImplicitLocOpBuilder &builder, MeshOp mesh,
   return {targetShard, targetSharding};
 }
 
-static std::optional<std::tuple<TypedValue<ShapedType>, MeshShardingAttr>>
+static std::optional<std::tuple<TypedValue<ShapedType>, MeshSharding>>
 tryMoveLastSplitAxisInResharding(ImplicitLocOpBuilder &builder, MeshOp mesh,
-                                 MeshShardingAttr sourceSharding,
-                                 MeshShardingAttr targetSharding,
+                                 MeshSharding sourceSharding,
+                                 MeshSharding targetSharding,
                                  ShapedType sourceUnshardedShape,
                                  TypedValue<ShapedType> sourceShard) {
   if (auto detectRes =
@@ -430,13 +430,57 @@ tryMoveLastSplitAxisInResharding(ImplicitLocOpBuilder &builder, MeshOp mesh,
   return std::nullopt;
 }
 
+// Creates an UpdateHaloOp on `sourceShard` with the static halo sizes of
+// `sourceSharding`; the result has the same type as `sourceShard`.
+static TypedValue<ShapedType>
+updateHalosInResharding(ImplicitLocOpBuilder &builder, MeshOp mesh,
+                        TypedValue<ShapedType> sourceShard,
+                        MeshSharding sourceSharding,
+                        MeshSharding targetSharding) {
+  assert(sourceSharding.getMesh() == targetSharding.getMesh());
+  assert(sourceSharding.getSplitAxes() == targetSharding.getSplitAxes());
+
+  auto haloSizesAttr = ::mlir::DenseI64ArrayAttr::get(
+      builder.getContext(), sourceSharding.getStaticHaloSizes());
+  auto updateHaloOp = builder.create<UpdateHaloOp>(
+      sourceShard.getType(), // update halo keeps the source type
+      mesh.getSymName(), SmallVector<MeshAxis>(), sourceShard, haloSizesAttr,
+      nullptr);
+
+  return cast<TypedValue<ShapedType>>(updateHaloOp.getResult());
+}
+
+// Halo update: shardings are sameExceptConstraint and target has halo sizes.
+static bool detectUpdateHalosInResharding(MeshSharding sourceSharding,
+                                          MeshSharding targetSharding) {
+  const bool equivalent = sourceSharding.sameExceptConstraint(targetSharding);
+  return equivalent && !targetSharding.getStaticHaloSizes().empty();
+}
+
+// Applies the halo-update resharding when it is detected and returns the new
+// shard together with `targetSharding`; otherwise returns std::nullopt.
+static std::optional<std::tuple<TypedValue<ShapedType>, MeshSharding>>
+tryUpdateHalosInResharding(ImplicitLocOpBuilder &builder, MeshOp mesh,
+                           MeshSharding sourceSharding,
+                           MeshSharding targetSharding,
+                           ShapedType sourceUnshardedShape,
+                           TypedValue<ShapedType> sourceShard) {
+  if (!detectUpdateHalosInResharding(sourceSharding, targetSharding))
+    return std::nullopt;
+
+  TypedValue<ShapedType> updatedShard = updateHalosInResharding(
+      builder, mesh, sourceShard, sourceSharding, targetSharding);
+  return std::make_tuple(updatedShard, targetSharding);
+}
+
 // Handles only resharding on a 1D mesh.
 // Currently the sharded tensor axes must be exactly divisible by the single
 // mesh axis size.
 static TypedValue<ShapedType>
 reshardOn1DMesh(ImplicitLocOpBuilder &builder, MeshOp mesh,
-                MeshShardingAttr sourceSharding,
-                MeshShardingAttr targetSharding,
+                bool force,
+                MeshSharding sourceSharding,
+                MeshSharding targetSharding,
                 TypedValue<ShapedType> sourceUnshardedValue,
                 TypedValue<ShapedType> sourceShard) {
   assert(sourceShard.getType() ==
@@ -455,48 +499,61 @@ reshardOn1DMesh(ImplicitLocOpBuilder &builder, MeshOp mesh,
   }
 
   TypedValue<ShapedType> targetShard;
-  MeshShardingAttr actualTargetSharding;
-  if (auto tryRes = tryMoveLastSplitAxisInResharding(
-          builder, mesh, reducedSourceSharding, targetSharding,
-          sourceUnshardedValue.getType(), reducedSourceShard)) {
-    std::tie(targetShard, actualTargetSharding) = tryRes.value();
-  } else if (auto tryRes = trySplitLastAxisInResharding(
-                 builder, mesh, reducedSourceSharding, targetSharding,
-                 reducedSourceShard)) {
-    std::tie(targetShard, actualTargetSharding) = tryRes.value();
-  } else if (auto tryRes = tryUnsplitLastAxisInResharding(
-                 builder, mesh, reducedSourceSharding, targetSharding,
-                 sourceUnshardedValue.getType(), reducedSourceShard)) {
-    std::tie(targetShard, actualTargetSharding) = tryRes.value();
-  } else {
-    assert(false && "Did not find any pattern to apply.");
+  MeshSharding actualTargetSharding;
+  if (!force &&
+      reducedSourceSharding.getStaticHaloSizes().empty() &&
+      targetSharding.getStaticHaloSizes().empty() &&
+      reducedSourceSharding.getStaticShardedDimsSizes().empty() &&
+      targetSharding.getStaticShardedDimsSizes().empty()) {
+    if (auto tryRes = tryMoveLastSplitAxisInResharding(
+            builder, mesh, reducedSourceSharding, targetSharding,
+            sourceUnshardedValue.getType(), reducedSourceShard)) {
+      std::tie(targetShard, actualTargetSharding) = tryRes.value();
+    } else if (auto tryRes = trySplitLastAxisInResharding(
+                   builder, mesh, reducedSourceSharding, targetSharding,
+                   reducedSourceShard)) {
+      std::tie(targetShard, actualTargetSharding) = tryRes.value();
+    } else if (auto tryRes = tryUnsplitLastAxisInResharding(
+                   builder, mesh, reducedSourceSharding, targetSharding,
+                   sourceUnshardedValue.getType(), reducedSourceShard)) {
+      std::tie(targetShard, actualTargetSharding) = tryRes.value();
+    }
+  } else if(force) {
+    if (auto tryRes = tryUpdateHalosInResharding(
+            builder, mesh, reducedSourceSharding, targetSharding,
+            sourceUnshardedValue.getType(), reducedSourceShard)) {
+      std::tie(targetShard, actualTargetSharding) = tryRes.value();
+    }
   }
-
+  assert(targetShard && "Did not find any pattern to apply.");
   assert(actualTargetSharding == targetSharding);
   assert(targetShard.getType() == targetShardType);
   return targetShard;
 }
 
 TypedValue<ShapedType> reshard(ImplicitLocOpBuilder &builder, MeshOp mesh,
-                               MeshShardingAttr sourceSharding,
-                               MeshShardingAttr targetSharding,
+                               bool force,
+                               MeshSharding sourceSharding,
+                               MeshSharding targetSharding,
                                TypedValue<ShapedType> sourceUnshardedValue,
                                TypedValue<ShapedType> sourceShard) {
   // Resort to handling only 1D meshes since the general case is complicated if
   // it needs to be communication efficient in terms of minimizing the data
   // transfered between devices.
-  return reshardOn1DMesh(builder, mesh, sourceSharding, targetSharding,
+  return reshardOn1DMesh(builder, mesh, force, sourceSharding, targetSharding,
                          sourceUnshardedValue, sourceShard);
 }
 
 TypedValue<ShapedType> reshard(OpBuilder &builder, MeshOp mesh, ShardOp source,
                                ShardOp target,
                                TypedValue<ShapedType> sourceShardValue) {
-  assert(source.getResult() == target.getOperand());
+  assert(source.getResult() == target.getSrc());
+  auto sourceSharding = source.getSharding();
+  auto targetSharding = target.getSharding();
   ImplicitLocOpBuilder implicitLocOpBuilder(target->getLoc(), builder);
-  return reshard(
-      implicitLocOpBuilder, mesh, source.getShard(), target.getShard(),
-      cast<TypedValue<ShapedType>>(source.getSrc()), sourceShardValue);
+  return reshard(implicitLocOpBuilder, mesh, target.getForce(), sourceSharding, targetSharding,
+                 cast<TypedValue<ShapedType>>(source.getSrc()),
+                 sourceShardValue);
 }
 
 TypedValue<ShapedType> reshard(OpBuilder &builder, ShardOp source,
@@ -538,20 +595,28 @@ shardedBlockArgumentTypes(Block &block,
         assert(shardOp);
         MeshOp mesh = getMesh(shardOp, symbolTableCollection);
         return cast<Type>(shardShapedType(rankedTensorArg.getType(), mesh,
-                                          shardOp.getShardAttr()));
+                                          shardOp.getSharding()));
       });
   return res;
 }
 
+void spmdizeTriviallyShardableOperation(
+    Operation &op, ArrayRef<Value> spmdizedOperands,
+    ArrayRef<MeshSharding> operandShardings,
+    ArrayRef<MeshSharding> resultShardings, IRMapping &spmdizationMap,
+    SymbolTableCollection &symbolTable, OpBuilder &builder);
+
 static LogicalResult spmdizeOperation(
     Operation &op, ArrayRef<Value> spmdizedOperands,
-    ArrayRef<MeshShardingAttr> operandShardings,
-    ArrayRef<MeshShardingAttr> resultShardings, IRMapping &spmdizationMap,
+    ArrayRef<MeshSharding> operandShardings,
+    ArrayRef<MeshSharding> resultShardings, IRMapping &spmdizationMap,
     SymbolTableCollection &symbolTableCollection, OpBuilder &builder) {
   ShardingInterface shardingInterface = llvm::dyn_cast<ShardingInterface>(op);
   if (!shardingInterface) {
     // If there is no sharding interface we are conservative and assume that
     // the op should be fully replicated no all devices.
+    // FIXME: unclear whether spmdizeTriviallyShardableOperation should be
+    // used here instead of full replication -- confirm intent.
     spmdizeFullyReplicatedOperation(op, spmdizedOperands, operandShardings,
                                     resultShardings, spmdizationMap,
                                     symbolTableCollection, builder);
@@ -572,41 +637,41 @@ static LogicalResult spmdizeOperation(
 
 // Retrieve the sharding annotations for the operands of the given operation.
 // If the type is not a ranked tensor it is not require to have an annotation.
-static SmallVector<MeshShardingAttr> getOperandShardings(Operation &op) {
-  SmallVector<MeshShardingAttr> res;
+static std::vector<MeshSharding> getOperandShardings(Operation &op) {
+  std::vector<MeshSharding> res;
   res.reserve(op.getNumOperands());
   llvm::transform(op.getOperands(), std::back_inserter(res), [](Value operand) {
     TypedValue<RankedTensorType> rankedTensor =
         dyn_cast<TypedValue<RankedTensorType>>(operand);
     if (!rankedTensor) {
-      return MeshShardingAttr();
+      return MeshSharding();
     }
 
     Operation *definingOp = operand.getDefiningOp();
     assert(definingOp);
     ShardOp shardOp = llvm::cast<ShardOp>(definingOp);
-    return shardOp.getShard();
+    return MeshSharding(shardOp.getSharding());
   });
   return res;
 }
 
 // Retrieve the sharding annotations for the results of the given operation.
 // If the type is not a ranked tensor it is not require to have an annotation.
-static SmallVector<MeshShardingAttr> getResultShardings(Operation &op) {
-  SmallVector<MeshShardingAttr> res;
+static std::vector<MeshSharding> getResultShardings(Operation &op) {
+  std::vector<MeshSharding> res;
   res.reserve(op.getNumResults());
   llvm::transform(op.getResults(), std::back_inserter(res),
                   [](OpResult result) {
                     TypedValue<RankedTensorType> rankedTensor =
                         dyn_cast<TypedValue<RankedTensorType>>(result);
                     if (!rankedTensor) {
-                      return MeshShardingAttr();
+                      return MeshSharding();
                     }
 
                     assert(result.hasOneUse());
                     Operation *userOp = *result.getUsers().begin();
                     ShardOp shardOp = llvm::cast<ShardOp>(userOp);
-                    return shardOp.getShard();
+                    return MeshSharding(shardOp.getSharding());
                   });
   return res;
 }
@@ -620,13 +685,13 @@ spmdizeOperation(ShardOp shardOp, IRMapping &spmdizationMap,
   // Check if 2 shard ops are chained. If not there is no need for resharding
   // as the source and target shared the same sharding.
   ShardOp srcShardOp =
-      dyn_cast_or_null<ShardOp>(shardOp.getOperand().getDefiningOp());
+      dyn_cast_or_null<ShardOp>(shardOp.getSrc().getDefiningOp());
   if (!srcShardOp) {
-    targetSpmdValue = spmdizationMap.lookup(shardOp.getOperand());
+    targetSpmdValue = spmdizationMap.lookup(shardOp.getSrc());
   } else {
     // Insert resharding.
     TypedValue<ShapedType> srcSpmdValue = cast<TypedValue<ShapedType>>(
-        spmdizationMap.lookup(srcShardOp.getOperand()));
+        spmdizationMap.lookup(srcShardOp.getSrc()));
     targetSpmdValue = reshard(builder, srcShardOp, shardOp, srcSpmdValue,
                               symbolTableCollection);
   }
@@ -640,6 +705,10 @@ static LogicalResult
 spmdizeOperation(Operation &op, IRMapping &spmdizationMap,
                  SymbolTableCollection &symbolTableCollection,
                  OpBuilder &builder) {
+  if (isa<ShardingOp>(op)) {
+    return success();
+  }
+
   ShardOp shardOp = llvm::dyn_cast<ShardOp>(op);
   if (shardOp) {
     return spmdizeOperation(shardOp, spmdizationMap, symbolTableCollection,
diff --git a/mlir/lib/Dialect/Tensor/IR/CMakeLists.txt b/mlir/lib/Dialect/Tensor/IR/CMakeLists.txt
index 549b9f10388bd..9439f4099a49a 100644
--- a/mlir/lib/Dialect/Tensor/IR/CMakeLists.txt
+++ b/mlir/lib/Dialect/Tensor/IR/CMakeLists.txt
@@ -9,6 +9,7 @@ set(LLVM_OPTIONAL_SOURCES
 add_mlir_dialect_library(MLIRTensorDialect
   TensorDialect.cpp
   TensorOps.cpp
+  ShardingInterfaceImpl.cpp
   ValueBoundsOpInterfaceImpl.cpp
 
   ADDITIONAL_HEADER_DIRS
diff --git a/mlir/lib/Dialect/Tensor/IR/ShardingInterfaceImpl.cpp b/mlir/lib/Dialect/Tensor/IR/ShardingInterfaceImpl.cpp
new file mode 100644
index 0000000000000..15995a7bae038
--- /dev/null
+++ b/mlir/lib/Dialect/Tensor/IR/ShardingInterfaceImpl.cpp
@@ -0,0 +1,101 @@
+//===- ShardingInterfaceImpl.cpp ------------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Tensor/IR/ShardingInterfaceImpl.h"
+#include "mlir/Dialect/Mesh/Interfaces/ShardingInterface.h"
+#include "mlir/Dialect/Mesh/Interfaces/ShardingInterfaceImpl.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/IR/DialectRegistry.h"
+#include "llvm/Support/Debug.h"
+
+#define DEBUG_TYPE "tensor-sharding-impl"
+#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE << "]: ")
+
+using namespace mlir;
+using namespace mlir::tensor;
+using namespace mlir::mesh;
+
+namespace {
+
+// Sharding interface for tensor.empty: one parallel loop per result dimension.
+struct EmptyOpShardingInterface
+    : public ShardingInterface::ExternalModel<EmptyOpShardingInterface,
+                                              tensor::EmptyOp> {
+  SmallVector<utils::IteratorType> getLoopIteratorTypes(Operation *op) const {
+    auto type = dyn_cast<RankedTensorType>(op->getResult(0).getType());
+    SmallVector<utils::IteratorType> iterators(
+        type ? type.getRank() : 0, utils::IteratorType::parallel);
+    return iterators;
+  }
+
+  SmallVector<AffineMap> getIndexingMaps(Operation *op) const {
+    MLIRContext *ctx = op->getContext();
+    auto type = dyn_cast<RankedTensorType>(op->getResult(0).getType());
+    if (!type)
+      return {};
+    return {AffineMap::getMultiDimIdentityMap(type.getRank(), ctx)};
+  }
+
+  LogicalResult spmdize(Operation *op, ArrayRef<Value> spmdizedOperands,
+                        ArrayRef<MeshSharding> operandShardings,
+                        ArrayRef<MeshSharding> resultShardings,
+                        IRMapping &spmdizationMap,
+                        SymbolTableCollection &symbolTable,
+                        OpBuilder &builder) const {
+    assert(resultShardings.size() == 1 && "expected one result sharding");
+    auto shardType = cast<ShapedType>(mesh::shardType(
+        op->getResult(0).getType(),
+        mesh::getMesh(op, resultShardings[0].getMeshAttr(), symbolTable),
+        resultShardings[0]));
+    Operation *newOp = nullptr;
+    // Static dims made dynamic by sharding are sized via a ShardShapeOp.
+    if (!shardType.hasStaticShape()) {
+      assert(op->getResult(0).hasOneUse());
+      SmallVector<Value> newOperands;
+      auto oldType = cast<ShapedType>(op->getResult(0).getType());
+      assert(oldType.getRank() == shardType.getRank());
+      int currOldOprndNum = -1;
+      mesh::ShardShapeOp shapeForDevice;
+      Operation *newSharding = nullptr;
+      for (auto i = 0; i < oldType.getRank(); ++i) {
+        if (!oldType.isDynamicDim(i) && shardType.isDynamicDim(i)) {
+          if (!newSharding) {
+            newSharding =
+                builder.create<ShardingOp>(op->getLoc(), resultShardings[0]);
+            Value device = builder.create<mesh::ProcessLinearIndexOp>(
+                op->getLoc(), resultShardings[0].getMesh());
+            shapeForDevice = builder.create<mesh::ShardShapeOp>(
+                op->getLoc(), oldType.getShape(), newSharding->getResult(0),
+                device);
+          }
+          newOperands.emplace_back(shapeForDevice.getResult()[i]);
+        } else if (oldType.isDynamicDim(i)) {
+          assert(shardType.isDynamicDim(i));
+          newOperands.emplace_back(spmdizedOperands[++currOldOprndNum]);
+        }
+      }
+      newOp =
+          builder.create<tensor::EmptyOp>(op->getLoc(), shardType, newOperands);
+      spmdizationMap.map(op->getResult(0), newOp->getResult(0));
+    } else {
+      // `clone` will populate the mapping of old to new results.
+      newOp = builder.clone(*op, spmdizationMap);
+    }
+    newOp->getResult(0).setType(shardType);
+    return success();
+  }
+};
+} // namespace
+
+void mlir::tensor::registerShardingInterfaceExternalModels(
+    DialectRegistry &registry) {
+  // Register an extension attaching EmptyOpShardingInterface to EmptyOp.
+  registry.addExtension(+[](MLIRContext *context, TensorDialect *) {
+    EmptyOp::attachInterface<EmptyOpShardingInterface>(*context);
+  });
+}
diff --git a/mlir/test/Dialect/Linalg/mesh-spmdization.mlir b/mlir/test/Dialect/Linalg/mesh-spmdization.mlir
index 52f352cfedd8e..b105f5007d532 100644
--- a/mlir/test/Dialect/Linalg/mesh-spmdization.mlir
+++ b/mlir/test/Dialect/Linalg/mesh-spmdization.mlir
@@ -196,4 +196,4 @@ func.func @matmul_1d_mesh_static_tensors_parallel_iterator_unsplit_last_axis(
   %res_replicated = mesh.shard %res_sharded to <@mesh_1d, [[], []]> annotate_for_users: tensor<4x8xi8>
   // CHECK: return %[[ALL_GATHER]] : tensor<4x8xi8>
   return %res_replicated : tensor<4x8xi8>
-}
+}
\ No newline at end of file
diff --git a/mlir/test/Dialect/Mesh/ops.mlir b/mlir/test/Dialect/Mesh/ops.mlir
index 6e5df86b13106..74d8a086d5b34 100644
--- a/mlir/test/Dialect/Mesh/ops.mlir
+++ b/mlir/test/Dialect/Mesh/ops.mlir
@@ -20,24 +20,30 @@ mesh.mesh @mesh5(shape = ?)
 // CHECK-LABEL: func @mesh_shard_op_fully_replicated
 // CHECK-SAME: %[[ARG:.*]]: tensor<4x8xf32>
 func.func @mesh_shard_op_fully_replicated(%arg0 : tensor<4x8xf32>) -> tensor<4x8xf32> {
-  // CHECK-NEXT: mesh.shard %[[ARG]] to <@mesh0, {{\[\[}}]]> : tensor<4x8xf32>
-  %0 = mesh.shard %arg0 to <@mesh0, [[]]> : tensor<4x8xf32>
+  // CHECK-NEXT: %[[S:.*]] = mesh.sharding @mesh0, {{\[\[}}]] : !mesh.sharding
+  %s = mesh.sharding @mesh0, [[]] : !mesh.sharding
+  // CHECK-NEXT: mesh.shard %[[ARG]] to %[[S]] : tensor<4x8xf32>
+  %0 = mesh.shard %arg0 to %s : tensor<4x8xf32>
   return %0 : tensor<4x8xf32>
 }
 
 // CHECK-LABEL: func @mesh_shard_op_1st_dim
 // CHECK-SAME: %[[ARG:.*]]: tensor<4x8xf32>
 func.func @mesh_shard_op_1st_dim(%arg0 : tensor<4x8xf32>) -> tensor<4x8xf32> {
-  // CHECK-NEXT: mesh.shard %[[ARG]] to <@mesh0, {{\[\[}}0]]> : tensor<4x8xf32>
-  %0 = mesh.shard %arg0 to <@mesh0, [[0]]> : tensor<4x8xf32>
+  // CHECK-NEXT: %[[S:.*]] = mesh.sharding @mesh0, {{\[\[}}0]] : !mesh.sharding
+  %s = mesh.sharding @mesh0, [[0]] : !mesh.sharding
+  // CHECK-NEXT: mesh.shard %[[ARG]] to %[[S]] : tensor<4x8xf32>
+  %0 = mesh.shard %arg0 to %s : tensor<4x8xf32>
   return %0 : tensor<4x8xf32>
 }
 
 // CHECK-LABEL: func @mesh_shard_op_2nd_dim
 // CHECK-SAME: %[[ARG:.*]]: tensor<4x8xf32>
 func.func @mesh_shard_op_2nd_dim(%arg0 : tensor<4x8xf32>) -> tensor<4x8xf32> {
-  // CHECK-NEXT: mesh.shard %[[ARG]] to <@mesh1, {{\[\[}}], [0]]> : tensor<4x8xf32>
-  %0 = mesh.shard %arg0 to <@mesh1, [[], [0]]> : tensor<4x8xf32>
+  // CHECK-NEXT: %[[S:.*]] = mesh.sharding @mesh1, {{\[\[}}], [0]] : !mesh.sharding
+  %s = mesh.sharding @mesh1, [[], [0]] : !mesh.sharding
+  // CHECK-NEXT: mesh.shard %[[ARG]] to %[[S]] : tensor<4x8xf32>
+  %0 = mesh.shard %arg0 to %s : tensor<4x8xf32>
   return %0 : tensor<4x8xf32>
 }
 
@@ -45,8 +51,10 @@ func.func @mesh_shard_op_2nd_dim(%arg0 : tensor<4x8xf32>) -> tensor<4x8xf32> {
 func.func @mesh_shard_op_1st_and_3rd_dim(
     // CHECK-SAME: %[[ARG:.*]]: tensor<4x8x16xf32>
     %arg0 : tensor<4x8x16xf32>) -> tensor<4x8x16xf32> {
-  // CHECK-NEXT: mesh.shard %[[ARG]] to <@mesh3, {{\[\[}}0], [], [1]]> : tensor<4x8x16xf32>
-  %0 = mesh.shard %arg0 to <@mesh3, [[0], [], [1]]> : tensor<4x8x16xf32>
+  // CHECK-NEXT: %[[S:.*]] = mesh.sharding @mesh3, {{\[\[}}0], [], [1]] : !mesh.sharding
+  %s = mesh.sharding @mesh3, [[0], [], [1]] : !mesh.sharding
+  // CHECK-NEXT: mesh.shard %[[ARG]] to %[[S]] : tensor<4x8x16xf32>
+  %0 = mesh.shard %arg0 to %s : tensor<4x8x16xf32>
   return %0 : tensor<4x8x16xf32>
 }
 
@@ -54,8 +62,10 @@ func.func @mesh_shard_op_1st_and_3rd_dim(
 func.func @mesh_shard_op_partial_max(
     // CHECK-SAME: %[[ARG:.*]]: tensor<4x8xf32>
     %arg0 : tensor<4x8xf32>) -> tensor<4x8xf32> {
-  // CHECK-NEXT: mesh.shard %[[ARG]] to <@mesh3, {{\[\[}}0]], partial = max[1]> : tensor<4x8xf32>
-  %0 = mesh.shard %arg0 to <@mesh3, [[0]], partial = max[1]> : tensor<4x8xf32>
+  // CHECK-NEXT: %[[S:.*]] = mesh.sharding @mesh3, {{\[\[}}0]] partial = max [1] : !mesh.sharding
+  %s = mesh.sharding @mesh3, [[0]] partial = max[1] : !mesh.sharding
+  // CHECK-NEXT: mesh.shard %[[ARG]] to %[[S]] : tensor<4x8xf32>
+  %0 = mesh.shard %arg0 to %s : tensor<4x8xf32>
   return %0 : tensor<4x8xf32>
 }
 
@@ -63,8 +73,10 @@ func.func @mesh_shard_op_partial_max(
 func.func @mesh_shard_op_partial_min(
     // CHECK-SAME: %[[ARG:.*]]: tensor<4x8xf32>
     %arg0 : tensor<4x8xf32>) -> tensor<4x8xf32> {
-  // CHECK-NEXT: mesh.shard %[[ARG]] to <@mesh3, {{\[\[}}0]], partial = min[1]> : tensor<4x8xf32>
-  %0 = mesh.shard %arg0 to <@mesh3, [[0]], partial = min[1]> : tensor<4x8xf32>
+  // CHECK-NEXT: %[[S:.*]] = mesh.sharding @mesh3, {{\[\[}}0]] partial = min [1] : !mesh.sharding
+  %s = mesh.sharding @mesh3, [[0]] partial = min[1] : !mesh.sharding
+  // CHECK-NEXT: mesh.shard %[[ARG]] to %[[S]] : tensor<4x8xf32>
+  %0 = mesh.shard %arg0 to %s : tensor<4x8xf32>
   return %0 : tensor<4x8xf32>
 }
 
@@ -72,8 +84,10 @@ func.func @mesh_shard_op_partial_min(
 func.func @mesh_shard_op_partial_generic(
     // CHECK-SAME: %[[ARG:.*]]: tensor<4x8xf32>
     %arg0 : tensor<4x8xf32>) -> tensor<4x8xf32> {
-  // CHECK-NEXT: mesh.shard %[[ARG]] to <@mesh3, {{\[\[}}0]], partial = generic[1]> : tensor<4x8xf32>
-  %0 = mesh.shard %arg0 to <@mesh3, [[0]], partial = generic[1]> : tensor<4x8xf32>
+  // CHECK-NEXT: %[[S:.*]] = mesh.sharding @mesh3, {{\[\[}}0]] partial = generic [1] : !mesh.sharding
+  %s = mesh.sharding @mesh3, [[0]] partial = generic[1] : !mesh.sharding
+  // CHECK-NEXT: mesh.shard %[[ARG]] to %[[S]] : tensor<4x8xf32>
+  %0 = mesh.shard %arg0 to %s : tensor<4x8xf32>
   return %0 : tensor<4x8xf32>
 }
 
@@ -81,8 +95,10 @@ func.func @mesh_shard_op_partial_generic(
 func.func @mesh_shard_op_partial_sum(
     // CHECK-SAME: %[[ARG:.*]]: tensor<4x8xf32>
     %arg0 : tensor<4x8xf32>) -> tensor<4x8xf32> {
-  // CHECK-NEXT: mesh.shard %[[ARG]] to <@mesh3, {{\[\[}}0]], partial = sum[1]> : tensor<4x8xf32>
-  %0 = mesh.shard %arg0 to <@mesh3, [[0]], partial = sum[1]> : tensor<4x8xf32>
+  // CHECK-NEXT: %[[S:.*]] = mesh.sharding @mesh3, {{\[\[}}0]] partial = sum [1] : !mesh.sharding
+  %s = mesh.sharding @mesh3, [[0]] partial = sum[1] : !mesh.sharding
+  // CHECK-NEXT: mesh.shard %[[ARG]] to %[[S]] : tensor<4x8xf32>
+  %0 = mesh.shard %arg0 to %s : tensor<4x8xf32>
   return %0 : tensor<4x8xf32>
 }
 
@@ -90,8 +106,10 @@ func.func @mesh_shard_op_partial_sum(
 func.func @mesh_shard_op_partial_sum_multi_axes(
     // CHECK-SAME: %[[ARG:.*]]: tensor<4x8xf32>
     %arg0 : tensor<4x8xf32>) -> tensor<4x8xf32> {
-  // CHECK-NEXT: mesh.shard %[[ARG]] to <@mesh3, {{\[\[}}0]], partial = sum[1, 2]> : tensor<4x8xf32>
-  %0 = mesh.shard %arg0 to <@mesh3, [[0]], partial = sum[1, 2]> : tensor<4x8xf32>
+  // CHECK-NEXT: %[[S:.*]] = mesh.sharding @mesh3, {{\[\[}}0]] partial = sum [1, 2] : !mesh.sharding
+  %s = mesh.sharding @mesh3, [[0]] partial = sum[1, 2] : !mesh.sharding
+  // CHECK-NEXT: mesh.shard %[[ARG]] to %[[S]] : tensor<4x8xf32>
+  %0 = mesh.shard %arg0 to %s : tensor<4x8xf32>
   return %0 : tensor<4x8xf32>
 }
 
@@ -99,12 +117,15 @@ func.func @mesh_shard_op_partial_sum_multi_axes(
 // CHECK-SAME: %[[ARG:.*]]: tensor<4x8xf32>
 func.func @mesh_shard_op_two_users(%arg0 : tensor<4x8xf32>) -> 
                                   (tensor<4x8xf32>, tensor<4x8xf32>) {
-  // CHECK-NEXT: %[[V0:.*]] = mesh.shard %[[ARG]] to <@mesh0, {{\[\[}}0]]> : tensor<4x8xf32>                  
-  %0 = mesh.shard %arg0 to <@mesh0, [[0]]> : tensor<4x8xf32>
-  // CHECK-DAG: mesh.shard %[[V0]] to <@mesh0, {{\[\[}}1]]> annotate_for_users : tensor<4x8xf32>
-  %1 = mesh.shard %0 to <@mesh0, [[1]]> annotate_for_users : tensor<4x8xf32>
-  // CHECK-DAG: mesh.shard %[[V0]] to <@mesh0, {{\[\[}}2]]> annotate_for_users : tensor<4x8xf32>
-  %2 = mesh.shard %0 to <@mesh0, [[2]]> annotate_for_users : tensor<4x8xf32>
+  // CHECK-NEXT: %[[V0:.*]] = mesh.sharding @mesh0, {{\[\[}}0]] : !mesh.sharding                  
+  %s0 = mesh.sharding @mesh0, [[0]] : !mesh.sharding                
+  %0 = mesh.shard %arg0 to %s0 : tensor<4x8xf32>
+  // CHECK-DAG: mesh.sharding @mesh0, {{\[\[}}1]] : !mesh.sharding 
+  %s1 = mesh.sharding @mesh0, [[1]] : !mesh.sharding
+  %1 = mesh.shard %0 to %s1 annotate_for_users : tensor<4x8xf32>
+  // CHECK-DAG: mesh.sharding @mesh0, {{\[\[}}2]] : !mesh.sharding 
+  %s2 = mesh.sharding @mesh0, [[2]] : !mesh.sharding
+  %2 = mesh.shard %0 to %s2 annotate_for_users : tensor<4x8xf32>
   return %1, %2 : tensor<4x8xf32>, tensor<4x8xf32>
 }
 
@@ -168,9 +189,9 @@ func.func @process_linear_index() -> index {
 func.func @all_reduce(
     // CHECK-SAME: %[[ARG:.*]]: tensor<3x4xf32>
     %arg0 : tensor<3x4xf32>) -> tensor<3x4xf64> {
-  // CHECK-NEXT: mesh.all_reduce %[[ARG]] on @mesh0 mesh_axes = [1, 0] reduction = <max>
+  // CHECK-NEXT: mesh.all_reduce %[[ARG]] on @mesh0 mesh_axes = [1, 0] reduction = max
   // CHECK-SAME: : tensor<3x4xf32> -> tensor<3x4xf64>
-  %0 = mesh.all_reduce %arg0 on @mesh0 mesh_axes = [1, 0] reduction = <max>
+  %0 = mesh.all_reduce %arg0 on @mesh0 mesh_axes = [1, 0] reduction = max
     : tensor<3x4xf32> -> tensor<3x4xf64>
   return %0 : tensor<3x4xf64>
 }
@@ -442,10 +463,10 @@ func.func @reduce_scatter_static_dimensions(
     // CHECK-SAME: %[[ARG:.*]]: tensor<3x4xf32>
     %arg0 : tensor<3x4xf32>) -> tensor<3x1xf64> {
   // CHECK-NEXT: mesh.reduce_scatter %[[ARG]]
-  // CHECK-SAME: on @mesh0 mesh_axes = [2] reduction = <max> scatter_axis = 1
+  // CHECK-SAME: on @mesh0 mesh_axes = [2] reduction = max scatter_axis = 1
   // CHECK-SAME: : tensor<3x4xf32> -> tensor<3x1xf64>
   %0 = mesh.reduce_scatter %arg0 on @mesh0 mesh_axes = [2]
-    reduction = <max> scatter_axis = 1
+    reduction = max scatter_axis = 1
     : tensor<3x4xf32> -> tensor<3x1xf64>
   return %0 : tensor<3x1xf64>
 }
diff --git a/mlir/test/Dialect/Mesh/spmdization.mlir b/mlir/test/Dialect/Mesh/spmdization.mlir
index d7a1e2fd9d279..6888fa609601d 100644
--- a/mlir/test/Dialect/Mesh/spmdization.mlir
+++ b/mlir/test/Dialect/Mesh/spmdization.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt \
+// RUN: mlir-opt -allow-unregistered-dialect \
 // RUN:   --pass-pipeline="builtin.module(func.func(mesh-spmdization,test-constant-fold))" \
 // RUN:   %s | FileCheck %s
 
@@ -158,3 +158,85 @@ func.func @incomplete_sharding(
   // CHECK: return %[[RES]] : tensor<4x16xf32>
   return %2 : tensor<8x16xf32>
 }
+
+mesh.mesh @mesh_1d_4(shape = 4)
+// CHECK-LABEL: func @update_halo_constraint
+func.func @update_halo_constraint(
+  // CHECK-SAME: %[[IN1:[A-Za-z0-9_]+]]: tensor<11x16xi8>
+  %in1: tensor<32x16xi8>
+  // CHECK-SAME: -> tensor<11x16xi8> {
+) -> tensor<32x16xi8> {
+  // CHECK: %[[RES:.*]] = mesh.update_halo %[[IN1]] on @mesh_1d_4 halo_sizes = [2, 1] : (tensor<11x16xi8>) -> tensor<11x16xi8>
+  %in1_sharded1 = mesh.shard %in1 to <@mesh_1d_4, [[0]] {<halo_sizes = [2, 1]>}> : tensor<32x16xi8>
+  %in1_sharded2 = mesh.shard %in1_sharded1 to <@mesh_1d_4, [[0]] {<force = true halo_sizes = [2, 1]>}> annotate_for_users: tensor<32x16xi8>
+  // CHECK: return %[[RES]] : tensor<11x16xi8>
+  return %in1_sharded2 : tensor<32x16xi8>
+}
+
+// CHECK-LABEL: func @ew_chain_with_halo
+func.func @ew_chain_with_halo(
+  // CHECK-SAME: %[[IN1:[A-Za-z0-9_]+]]: tensor<5x16xf32>
+  %arg0: tensor<8x16xf32>)
+  // CHECK-SAME: -> tensor<5x16xf32>
+   -> tensor<8x16xf32> {
+  %sharding_annotated = mesh.shard %arg0 to <@mesh_1d_4, [[0]] {<halo_sizes = [2, 1]>}> annotate_for_users : tensor<8x16xf32>
+  // CHECK: %[[TMP1:.*]] = tosa.tanh %[[IN1]] : (tensor<5x16xf32>) -> tensor<5x16xf32>
+  %0 = tosa.tanh %sharding_annotated : (tensor<8x16xf32>) -> tensor<8x16xf32>
+  %sharding_annotated_0 = mesh.shard %0 to <@mesh_1d_4, [[0]] {<halo_sizes = [2, 1]>}> : tensor<8x16xf32>
+  %sharding_annotated_1 = mesh.shard %sharding_annotated_0 to <@mesh_1d_4, [[0]] {<halo_sizes = [2, 1]>}> annotate_for_users : tensor<8x16xf32>
+  // CHECK-NEXT: %[[TMP2:.*]] = tosa.abs %[[TMP1]] : (tensor<5x16xf32>) -> tensor<5x16xf32>
+  %1 = tosa.abs %sharding_annotated_1 : (tensor<8x16xf32>) -> tensor<8x16xf32>
+  %sharding_annotated_2 = mesh.shard %1 to <@mesh_1d_4, [[0]] {<halo_sizes = [2, 1]>}> : tensor<8x16xf32>
+  %sharding_annotated_4 = mesh.shard %sharding_annotated_2 to <@mesh_1d_4, [[0]] {<halo_sizes = [2, 1]>}> annotate_for_users : tensor<8x16xf32>
+  // CHECK-NEXT: %[[TMP3:.*]] = tosa.negate %[[TMP2]] : (tensor<5x16xf32>) -> tensor<5x16xf32>
+  %2 = tosa.negate %sharding_annotated_4 : (tensor<8x16xf32>) -> tensor<8x16xf32>
+  %sharding_annotated_5 = mesh.shard %2 to <@mesh_1d_4, [[0]] {<halo_sizes = [2, 1]>}> : tensor<8x16xf32>
+  %sharding_annotated_6 = mesh.shard %sharding_annotated_5 to <@mesh_1d_4, [[0]] {<halo_sizes = [2, 1]>}> annotate_for_users : tensor<8x16xf32>
+  // CHECK-NEXT: return %[[TMP3]] : tensor<5x16xf32>
+  return %sharding_annotated_6 : tensor<8x16xf32>
+}
+
+#map = affine_map<(d0, d1) -> (d0, d1)>
+// CHECK-LABEL: func @stencil_with_halo
+func.func @stencil_with_halo() -> () {
+  %a = "xxx.empty"() : () -> tensor<32x16xf32>
+  %sc1 = mesh.sharding_constraint sharded_dims = [] halo_sizes = [1, 2] : !mesh.constraint
+  %sa = mesh.shard %a to <@mesh_1d_4, [[0]]>, !mesh.constraint = %sc1 : tensor<32x16xf32>
+  %b = "xxx.empty"() : () -> tensor<8x16xf32>
+  %sc2 = mesh.sharding_constraint sharded_dims = [1, 2, 3, 2] halo_sizes = [] : !mesh.constraint
+  %sb = mesh.shard %b to <@mesh_1d_4, [[0]]>, !mesh.constraint = %sc2 : tensor<8x16xf32>
+
+  %sai1 = mesh.shard %sa to <@mesh_1d_4, [[0]]>, !mesh.constraint = %sc1 annotate_for_users : tensor<32x16xf32>
+  %v1 = "xxx.view"(%sa) {x = 1} : (tensor<32x16xf32>) -> tensor<8x16xf32>
+  %sv1 = mesh.shard %v1 to <@mesh_1d_4, [[0]]>, !mesh.constraint = %sc2 : tensor<8x16xf32>
+
+  %sai2 = mesh.shard %sa to <@mesh_1d_4, [[0]]>, !mesh.constraint = %sc1 annotate_for_users : tensor<32x16xf32>
+  %v2 = "xxx.view"(%sa) {x = 2} : (tensor<32x16xf32>) -> tensor<8x16xf32>
+  %sv2 = mesh.shard %v2 to <@mesh_1d_4, [[0]]>, !mesh.constraint = %sc2 : tensor<8x16xf32>
+  
+  %v1i = mesh.shard %sv1 to <@mesh_1d_4, [[0]]>, !mesh.constraint = %sc2 annotate_for_users : tensor<8x16xf32>
+  %v2i = mesh.shard %sv2 to <@mesh_1d_4, [[0]]>, !mesh.constraint = %sc2 annotate_for_users : tensor<8x16xf32>
+  %bo = mesh.shard %sb to <@mesh_1d_4, [[0]]>, !mesh.constraint = %sc2 annotate_for_users : tensor<8x16xf32>
+  %r = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel"]} ins(%v1i, %v2i : tensor<8x16xf32>, tensor<8x16xf32>) outs(%bo : tensor<8x16xf32>) {
+    ^bb0(%in: f32, %in_56: f32, %out: f32):
+      %47 = arith.addf %in, %in_56 : f32
+      linalg.yield %47 : f32
+    } -> tensor<8x16xf32>
+  %sr = mesh.shard %r to <@mesh_1d_4, [[0]]>, !mesh.constraint = %sc2 : tensor<8x16xf32>
+
+  %sai3 = mesh.shard %sa to <@mesh_1d_4, [[0]]>, !mesh.constraint = %sc1 annotate_for_users : tensor<32x16xf32>
+  %sri = mesh.shard %sr to <@mesh_1d_4, [[0]]>, !mesh.constraint = %sc2 annotate_for_users : tensor<8x16xf32>
+  "xxx.insert_slice"(%sai3, %sri) : (tensor<32x16xf32>, tensor<8x16xf32>) -> ()
+  %sc3 = mesh.sharding_constraint sharded_dims = [] halo_sizes = [1, 2] force : !mesh.constraint
+  %sai4 = mesh.shard %sa to <@mesh_1d_4, [[0]]>, !mesh.constraint = %sc3 : tensor<32x16xf32>
+
+  return
+}
+// CHECK: %[[V0:.*]] = "xxx.empty"() : () -> tensor<11x16xf32>
+// CHECK-NEXT: %[[V1:.*]] = "xxx.empty"() : () -> tensor<?x16xf32>
+// CHECK-NEXT: %[[V2:.*]] = "xxx.view"(%[[V0]]) {x = 1 : i64} : (tensor<11x16xf32>) -> tensor<?x16xf32>
+// CHECK-NEXT: %[[V3:.*]] = "xxx.view"(%[[V0]]) {x = 2 : i64} : (tensor<11x16xf32>) -> tensor<?x16xf32>
+// CHECK-NEXT: %[[V4:.*]] = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel"]} ins(%[[V2]], %[[V3]] : tensor<?x16xf32>, tensor<?x16xf32>) outs(%[[V1]] : tensor<?x16xf32>) {
+// CHECK: "xxx.insert_slice"(%[[V0]], %[[V4]]) : (tensor<11x16xf32>, tensor<?x16xf32>) -> ()
+// CHECK-NEXT: %{{.*}} = mesh.update_halo %[[V0]] on @mesh_1d_4 halo_sizes = [1, 2] : (tensor<11x16xf32>) -> tensor<11x16xf32>
+// CHECK-NEXT: return
diff --git a/mlir/test/lib/Dialect/Mesh/TestReshardingSpmdization.cpp b/mlir/test/lib/Dialect/Mesh/TestReshardingSpmdization.cpp
index f96410245f281..b72bcdbad3edd 100644
--- a/mlir/test/lib/Dialect/Mesh/TestReshardingSpmdization.cpp
+++ b/mlir/test/lib/Dialect/Mesh/TestReshardingSpmdization.cpp
@@ -37,14 +37,14 @@ struct TestMeshReshardingRewritePattern : OpRewritePattern<ShardOp> {
 
     SymbolTableCollection symbolTable;
     mesh::MeshOp mesh = symbolTable.lookupNearestSymbolFrom<mesh::MeshOp>(
-        op, op.getShard().getMesh());
+        op, cast<ShardingOp>(op.getSharding().getDefiningOp()).getMeshAttr());
 
     bool foundUser = false;
     for (auto user : op->getUsers()) {
       if (auto targetShardOp = llvm::dyn_cast<ShardOp>(user)) {
         if (targetShardOp.getAnnotateForUsers() &&
             mesh == symbolTable.lookupNearestSymbolFrom<mesh::MeshOp>(
-                        targetShardOp, targetShardOp.getShard().getMesh())) {
+                        targetShardOp, cast<ShardingOp>(targetShardOp.getSharding().getDefiningOp()).getMeshAttr())) {
           foundUser = true;
           break;
         }
@@ -59,17 +59,16 @@ struct TestMeshReshardingRewritePattern : OpRewritePattern<ShardOp> {
       auto targetShardOp = llvm::dyn_cast<ShardOp>(user);
       if (!targetShardOp || !targetShardOp.getAnnotateForUsers() ||
           symbolTable.lookupNearestSymbolFrom<mesh::MeshOp>(
-              targetShardOp, targetShardOp.getShard().getMesh()) != mesh) {
+              targetShardOp, cast<ShardingOp>(targetShardOp.getSharding().getDefiningOp()).getMeshAttr()) != mesh) {
         continue;
       }
 
       ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
       ShapedType sourceShardShape =
-          shardShapedType(op.getResult().getType(), mesh, op.getShard());
+          shardShapedType(op.getResult().getType(), mesh, op.getSharding());
       TypedValue<ShapedType> sourceShard = cast<TypedValue<ShapedType>>(
           builder
-              .create<UnrealizedConversionCastOp>(sourceShardShape,
-                                                  op.getOperand())
+              .create<UnrealizedConversionCastOp>(sourceShardShape, op.getSrc())
               ->getResult(0));
       TypedValue<ShapedType> targetShard =
           reshard(builder, mesh, op, targetShardOp, sourceShard);

>From 57641a75dde2fc796c607e2e845699d87ba5541f Mon Sep 17 00:00:00 2001
From: Frank Schlimbach <frank.schlimbach at intel.com>
Date: Mon, 8 Jul 2024 18:36:05 +0200
Subject: [PATCH 2/8] fixing some existing sharding tests

---
 mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h   |  81 +------
 mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td  |   4 +-
 mlir/lib/Dialect/Mesh/IR/MeshOps.cpp          | 105 +++++++--
 .../Mesh/Interfaces/ShardingInterface.cpp     |   6 +-
 .../Mesh/Transforms/ShardingPropagation.cpp   |   2 +-
 mlir/test/Dialect/Mesh/canonicalization.mlir  |   4 +-
 mlir/test/Dialect/Mesh/invalid.mlir           |  46 +++-
 .../Dialect/Mesh/resharding-spmdization.mlir  |  78 +++++--
 .../Dialect/Mesh/sharding-propagation.mlir    | 198 +++++++++++------
 mlir/test/Dialect/Mesh/simplifications.mlir   |  16 +-
 mlir/test/Dialect/Mesh/spmdization.mlir       | 209 +++++++++++-------
 11 files changed, 452 insertions(+), 297 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h
index 12677c0bae740..6e1afcde5f0f5 100644
--- a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h
+++ b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h
@@ -49,6 +49,7 @@ class MeshSharding {
     SmallVector<Value> dynamic_halo_sizes;
     SmallVector<Value> dynamic_sharded_dims_sizes;
   public:
+    MeshSharding() = default;
     MeshSharding(Value rhs);
     static MeshSharding get(
         ::mlir::FlatSymbolRefAttr mesh_,
@@ -59,7 +60,6 @@ class MeshSharding {
         ArrayRef<int64_t> static_sharded_dims_sizes_ = {},
         ArrayRef<Value> dynamic_halo_sizes_ = {},
         ArrayRef<Value> dynamic_sharded_dims_sizes_ = {});
-    MeshSharding() = default;
     ::mlir::FlatSymbolRefAttr getMeshAttr() const { return mesh; }
     ::llvm::StringRef getMesh() const { return mesh.getValue(); }
     ArrayRef<MeshAxesAttr> getSplitAxes() const {return split_axes; }
@@ -72,10 +72,10 @@ class MeshSharding {
     operator bool() const { return (!mesh) == false; }
     bool operator==(Value rhs) const;
     bool operator!=(Value rhs) const;
-    bool operator==(MeshSharding rhs) const;
-    bool operator!=(MeshSharding rhs) const;
-    template<typename RHS> bool sameExceptConstraint(RHS rhs) const;
-    template<typename RHS> bool sameConstraint(RHS rhs) const;
+    bool operator==(const MeshSharding &rhs) const;
+    bool operator!=(const MeshSharding &rhs) const;
+    bool sameExceptConstraint(const MeshSharding &rhs) const;
+    bool sameConstraint(const MeshSharding &rhs) const;
 };
 
 } // namespace mesh
@@ -193,80 +193,15 @@ Type shardType(Type type, MeshOp mesh, MeshSharding sharding);
 
 // Insert shard op if there is not one that already has the same sharding.
 // May insert resharding if required.
-void maybeInsertTargetShardingAnnotation(Value sharding,
+void maybeInsertTargetShardingAnnotation(MeshSharding sharding,
                                          OpOperand &operand,
                                          OpBuilder &builder);
-void maybeInsertTargetShardingAnnotation(Value sharding,
+void maybeInsertTargetShardingAnnotation(MeshSharding sharding,
                                          OpResult result, OpBuilder &builder);
-void maybeInsertSourceShardingAnnotation(Value sharding,
+void maybeInsertSourceShardingAnnotation(MeshSharding sharding,
                                          OpOperand &operand,
                                          OpBuilder &builder);
 
-
-template<typename RHS>
-bool MeshSharding::sameExceptConstraint(RHS rhs) const {
-  if (getMesh() != rhs.getMesh() || getPartialAxes() != rhs.getPartialAxes()) {
-    return false;
-  }
-
-  if (!getPartialAxes().empty() && getPartialType() != rhs.getPartialType()) {
-    return false;
-  }
-
-  auto minSize = std::min(getSplitAxes().size(), rhs.getSplitAxes().size());
-  if (!llvm::equal(llvm::make_range(getSplitAxes().begin(),
-                                    getSplitAxes().begin() + minSize),
-                   llvm::make_range(rhs.getSplitAxes().begin(),
-                                    rhs.getSplitAxes().begin() + minSize))) {
-    return false;
-  }
-
-  return llvm::all_of(llvm::make_range(getSplitAxes().begin() + minSize,
-                                       getSplitAxes().end()),
-                      std::mem_fn(&MeshAxesAttr::empty)) &&
-         llvm::all_of(llvm::make_range(rhs.getSplitAxes().begin() + minSize,
-                                       rhs.getSplitAxes().end()),
-                      std::mem_fn(&MeshAxesAttr::empty));
-}
-
-template<typename RHS>
-bool MeshSharding::sameConstraint(RHS rhs) const {
-    if (rhs.getStaticHaloSizes().size() == getStaticHaloSizes().size() ) {
-      if (!llvm::equal(llvm::make_range(getStaticHaloSizes().begin(), getStaticHaloSizes().end()),
-                       llvm::make_range(rhs.getStaticHaloSizes().begin(), rhs.getStaticHaloSizes().end()))) {
-        return false;
-      }
-    } else {
-      return false;
-    }
-    if (rhs.getStaticShardedDimsSizes().size() == getDynamicHaloSizes().size() ) {
-      if (!llvm::equal(llvm::make_range(getStaticShardedDimsSizes().begin(), getStaticShardedDimsSizes().end()),
-                       llvm::make_range(rhs.getStaticShardedDimsSizes().begin(), rhs.getStaticShardedDimsSizes().end()))) {
-        return false;
-      }
-    } else {
-      return false;
-    }
-    if (rhs.getDynamicHaloSizes().size() == getStaticShardedDimsSizes().size() ) {
-      if (!llvm::equal(llvm::make_range(getDynamicHaloSizes().begin(), getDynamicHaloSizes().end()),
-                       llvm::make_range(rhs.getDynamicHaloSizes().begin(), rhs.getDynamicHaloSizes().end()))) {
-        return false;
-      }
-    } else {
-      return false;
-    }
-    if (rhs.getDynamicShardedDimsSizes().size() == getDynamicShardedDimsSizes().size()) {
-      if (!llvm::equal(llvm::make_range(getDynamicShardedDimsSizes().begin(), getDynamicShardedDimsSizes().end()),
-                       llvm::make_range(rhs.getDynamicShardedDimsSizes().begin(), rhs.getDynamicShardedDimsSizes().end()))) {
-        return false;
-      }
-    } else {
-      return false;
-    }
-    return true;
-}
-
-
 } // namespace mesh
 } // namespace mlir
 
diff --git a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td
index b4de29f8e3214..17d7100b58165 100644
--- a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td
+++ b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td
@@ -253,8 +253,8 @@ def Mesh_ShardingOp : Mesh_Op<"sharding", [
   let assemblyFormat = [{
     $mesh `,` $split_axes
     (`partial` `=` $partial_type $partial_axes^)?
-    oilist(`halo_sizes` `=` custom<DynamicIndexList>($dynamic_halo_sizes, $static_halo_sizes) |
-           `sharded_dims_sizes` `=` custom<DynamicIndexList>($dynamic_sharded_dims_sizes, $static_sharded_dims_sizes))
+    (`halo_sizes` `=` custom<DynamicIndexList>($dynamic_halo_sizes, $static_halo_sizes)^)?
+    (`sharded_dims_sizes` `=` custom<DynamicIndexList>($dynamic_sharded_dims_sizes, $static_sharded_dims_sizes)^)?
     attr-dict `:` type($result)
   }];
   let builders = [
diff --git a/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp b/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
index 9c7c79e602903..bb6674fd02ecf 100644
--- a/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
+++ b/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
@@ -230,7 +230,7 @@ Type mesh::shardType(Type type, MeshOp mesh, MeshSharding sharding) {
   return type;
 }
 
-void mlir::mesh::maybeInsertTargetShardingAnnotation(Value sharding,
+void mlir::mesh::maybeInsertTargetShardingAnnotation(MeshSharding sharding,
                                                      OpOperand &operand,
                                                      OpBuilder &builder) {
   OpBuilder::InsertionGuard insertionGuard(builder);
@@ -244,8 +244,9 @@ void mlir::mesh::maybeInsertTargetShardingAnnotation(Value sharding,
     return;
   }
 
+  auto shardingOp = builder.create<ShardingOp>(operandValue.getLoc(), sharding);
   auto newShardOp =
-      builder.create<ShardOp>(operandValue.getLoc(), operandValue, sharding,
+      builder.create<ShardOp>(operandValue.getLoc(), operandValue, shardingOp,
                               /*annotate_for_users*/ false);
   IRRewriter rewriter(builder);
   rewriter.replaceUsesWithIf(
@@ -258,11 +259,11 @@ void mlir::mesh::maybeInsertTargetShardingAnnotation(Value sharding,
   }
 
   auto newShardOp2 = builder.create<ShardOp>(
-      operandValue.getLoc(), newShardOp, sharding, /*annotate_for_users*/ true);
+      operandValue.getLoc(), newShardOp, shardingOp, /*annotate_for_users*/ true);
   rewriter.replaceAllUsesExcept(newShardOp, newShardOp2, newShardOp2);
 }
 
-void mlir::mesh::maybeInsertTargetShardingAnnotation(Value sharding,
+void mlir::mesh::maybeInsertTargetShardingAnnotation(MeshSharding sharding,
                                                      OpResult result,
                                                      OpBuilder &builder) {
   for (auto &use : llvm::make_early_inc_range(result.getUses())) {
@@ -270,7 +271,7 @@ void mlir::mesh::maybeInsertTargetShardingAnnotation(Value sharding,
   }
 }
 
-void mlir::mesh::maybeInsertSourceShardingAnnotation(Value sharding,
+void mlir::mesh::maybeInsertSourceShardingAnnotation(MeshSharding sharding,
                                                      OpOperand &operand,
                                                      OpBuilder &builder) {
   OpBuilder::InsertionGuard insertionGuard(builder);
@@ -287,8 +288,9 @@ void mlir::mesh::maybeInsertSourceShardingAnnotation(Value sharding,
   }
 
   builder.setInsertionPoint(operandOp);
+  auto shardingOp = builder.create<ShardingOp>(operand.get().getLoc(), sharding);
   auto newShardOp =
-      builder.create<ShardOp>(operandValue.getLoc(), operandValue, sharding,
+      builder.create<ShardOp>(operandValue.getLoc(), operandValue, shardingOp,
                               /*annotate_for_users*/ true);
   IRRewriter rewriter(builder);
   rewriter.replaceUsesWithIf(
@@ -303,7 +305,7 @@ void mlir::mesh::maybeInsertSourceShardingAnnotation(Value sharding,
 
   builder.setInsertionPoint(newShardOp);
   auto newPreceedingShardOp =
-      builder.create<ShardOp>(operandValue.getLoc(), operandValue, sharding,
+      builder.create<ShardOp>(operandValue.getLoc(), operandValue, shardingOp,
                               /*annotate_for_users*/ false);
   rewriter.replaceUsesWithIf(
       newShardOp.getSrc(), newPreceedingShardOp, [&newShardOp](OpOperand &use) {
@@ -413,41 +415,104 @@ void ShardingOp::build(::mlir::OpBuilder &b, ::mlir::OperationState &odsState, F
 
 
 void ShardingOp::build(::mlir::OpBuilder &b, ::mlir::OperationState &odsState, mlir::mesh::MeshSharding from) {
+
   build(b, odsState,
-        b.getIndexType(),
+        ShardingType::get(b.getContext()),
         from.getMeshAttr(),
         MeshAxesArrayAttr::get(b.getContext(), from.getSplitAxes()),
-        b.getDenseI16ArrayAttr(from.getPartialAxes()),
+        from.getPartialAxes().empty() ? DenseI16ArrayAttr() : b.getDenseI16ArrayAttr(from.getPartialAxes()),
         ::mlir::mesh::ReductionKindAttr::get(b.getContext(), from.getPartialType()),
-        b.getDenseI64ArrayAttr(from.getStaticShardedDimsSizes()),
+        from.getStaticShardedDimsSizes().empty() ? DenseI64ArrayAttr() : b.getDenseI64ArrayAttr(from.getStaticShardedDimsSizes()),
         from.getDynamicShardedDimsSizes(),
-        b.getDenseI64ArrayAttr(from.getStaticHaloSizes()),
+        from.getStaticHaloSizes().empty() ? DenseI64ArrayAttr() : b.getDenseI64ArrayAttr(from.getStaticHaloSizes()),
         from.getDynamicHaloSizes());
 }
 
+//===----------------------------------------------------------------------===//
+// MeshSharding
+//===----------------------------------------------------------------------===//
+
+    // ::mlir::FlatSymbolRefAttr MeshSharding::getMeshAttr() const { return mesh; }
+    // ::llvm::StringRef MeshSharding::getMesh() const { return mesh.getValue(); }
+    // ArrayRef<MeshAxesAttr> MeshSharding::getSplitAxes() const {return split_axes; }
+    // ArrayRef<MeshAxis> MeshSharding::getPartialAxes() const { if (partial_axes.empty()) return {}; return partial_axes; }
+    // ReductionKind MeshSharding::getPartialType() const { return partial_type; }
+    // ArrayRef<int64_t> MeshSharding::getStaticHaloSizes() const { if(static_halo_sizes.empty()) return {}; return static_halo_sizes; }
+    // ArrayRef<int64_t> MeshSharding::getStaticShardedDimsSizes() const { if(static_sharded_dims_sizes.empty()) return {}; return static_sharded_dims_sizes; }
+    // ArrayRef<Value> MeshSharding::getDynamicHaloSizes() const { if(dynamic_halo_sizes.empty()) return {}; return dynamic_halo_sizes; }
+    // ArrayRef<Value> MeshSharding::getDynamicShardedDimsSizes() const { if(dynamic_sharded_dims_sizes.empty()) return {}; return dynamic_sharded_dims_sizes; }
+    // operator MeshSharding::bool() const { return (!mesh) == false; }
+
+// Returns true iff both shardings agree on mesh, partial axes, partial type
+// and split axes; halo sizes and sharded-dims sizes are not looked at.
+bool MeshSharding::sameExceptConstraint(const MeshSharding &rhs) const {
+  if (getMesh() != rhs.getMesh() ||
+      !llvm::equal(getPartialAxes(), rhs.getPartialAxes())) {
+    return false;
+  }
+
+  // The reduction kind is only compared when there are partial axes.
+  if (!getPartialAxes().empty() && getPartialType() != rhs.getPartialType()) {
+    return false;
+  }
+  // Split axes must match on the common prefix; beyond it, the longer list
+  // may only hold empty entries.
+  auto minSize = std::min(getSplitAxes().size(), rhs.getSplitAxes().size());
+  if (!llvm::equal(getSplitAxes().take_front(minSize),
+                   rhs.getSplitAxes().take_front(minSize))) {
+    return false;
+  }
+  return llvm::all_of(getSplitAxes().drop_front(minSize),
+                      std::mem_fn(&MeshAxesAttr::empty)) &&
+         llvm::all_of(rhs.getSplitAxes().drop_front(minSize),
+                      std::mem_fn(&MeshAxesAttr::empty));
+}
+
+// Returns true iff `rhs` carries the same halo sizes and sharded-dims sizes
+// as this sharding, for both the static and the dynamic variant of each.
+bool MeshSharding::sameConstraint(const MeshSharding &rhs) const {
+    // Every list is compared against its own counterpart in `rhs`: lengths
+    // first, then element by element.
+    if (rhs.getStaticHaloSizes().size() != getStaticHaloSizes().size()
+        || !llvm::equal(getStaticHaloSizes(), rhs.getStaticHaloSizes())) {
+      return false;
+    }
+    if (rhs.getStaticShardedDimsSizes().size() != getStaticShardedDimsSizes().size()
+        || !llvm::equal(getStaticShardedDimsSizes(), rhs.getStaticShardedDimsSizes())) {
+      return false;
+    }
+    if (rhs.getDynamicHaloSizes().size() != getDynamicHaloSizes().size()
+        || !llvm::equal(getDynamicHaloSizes(), rhs.getDynamicHaloSizes())) {
+      return false;
+    }
+    if (rhs.getDynamicShardedDimsSizes().size() != getDynamicShardedDimsSizes().size()
+        || !llvm::equal(getDynamicShardedDimsSizes(), rhs.getDynamicShardedDimsSizes())) {
+      return false;
+    }
+    return true;
+}
+
 bool MeshSharding::operator==(Value rhs) const {
-  auto shardingOp =
-      mlir::dyn_cast<ShardingOp>(rhs.getDefiningOp());
-  return shardingOp && sameExceptConstraint(shardingOp)
-         && sameConstraint(shardingOp);
+  return sameExceptConstraint(rhs)
+         && sameConstraint(rhs);
 }
 
 bool MeshSharding::operator!=(Value rhs) const {
   return !(*this == rhs);
 }
 
-bool MeshSharding::operator==(MeshSharding rhs) const {
+bool MeshSharding::operator==(const MeshSharding &rhs) const {
   return sameExceptConstraint(rhs) && sameConstraint(rhs);
 }
 
-bool MeshSharding::operator!=(MeshSharding rhs) const {
+bool MeshSharding::operator!=(const MeshSharding &rhs) const {
   return !(*this == rhs);
 }
 
 MeshSharding::MeshSharding(Value rhs) {
   auto shardingOp = mlir::dyn_cast<ShardingOp>(rhs.getDefiningOp());
   assert(shardingOp && "expected sharding op");
-  get(shardingOp.getMeshAttr(),
+  *this = get(shardingOp.getMeshAttr(),
       shardingOp.getSplitAxes().getAxes(),
       shardingOp.getPartialAxes().value_or(ArrayRef<MeshAxis>()),
       shardingOp.getPartialType().value_or(ReductionKind::Sum),
@@ -472,7 +537,7 @@ MeshSharding MeshSharding::get(::mlir::FlatSymbolRefAttr mesh_,
     res.split_axes[i] = MeshAxesAttr::get(mesh_.getContext(), axis.asArrayRef());
   }
 
-  auto do_copy = [&](auto src, auto dst) {
+  auto do_copy = [&](auto src, auto &dst) {
     dst.resize(src.size());
     for (auto [i, v] : llvm::enumerate(src)) {
       dst[i] = v;
@@ -526,7 +591,7 @@ LogicalResult ShardingOp::verify() {
       }
     }
     if (getStaticHaloSizes().size() != numSplitAxes * 2) {
-      return emitError() << "Halo sizes must be specified for all split axes.";
+      return emitError() << "halo sizes must be specified for all split axes.";
     }
   }
 
diff --git a/mlir/lib/Dialect/Mesh/Interfaces/ShardingInterface.cpp b/mlir/lib/Dialect/Mesh/Interfaces/ShardingInterface.cpp
index e525c31791261..df42a335e89df 100644
--- a/mlir/lib/Dialect/Mesh/Interfaces/ShardingInterface.cpp
+++ b/mlir/lib/Dialect/Mesh/Interfaces/ShardingInterface.cpp
@@ -517,8 +517,7 @@ static LogicalResult addShardOp(OpBuilder &b, OpResult result,
                                 ArrayRef<ReductionKind> reductionLoopKinds) {
   MeshSharding sharding = getShardingAttribute(
       result, shardingOption, map, loopTypes, reductionLoopKinds);
-  auto shardingOp = b.create<ShardingOp>(result.getLoc(), sharding);
-  maybeInsertTargetShardingAnnotation(shardingOp.getResult(), result, b);
+  maybeInsertTargetShardingAnnotation(sharding, result, b);
 
   return success();
 }
@@ -535,8 +534,7 @@ static LogicalResult addShardOp(OpBuilder &b, OpOperand &opOperand,
     return failure();
   }
   OpBuilder::InsertionGuard guard(b);
-  auto shardingOp = b.create<ShardingOp>(opOperand.get().getLoc(), sharding.value());
-  maybeInsertSourceShardingAnnotation(shardingOp.getResult(), opOperand, b);
+  maybeInsertSourceShardingAnnotation(sharding.value(), opOperand, b);
 
   return success();
 }
diff --git a/mlir/lib/Dialect/Mesh/Transforms/ShardingPropagation.cpp b/mlir/lib/Dialect/Mesh/Transforms/ShardingPropagation.cpp
index ac264d93a9776..c98d975392dbb 100644
--- a/mlir/lib/Dialect/Mesh/Transforms/ShardingPropagation.cpp
+++ b/mlir/lib/Dialect/Mesh/Transforms/ShardingPropagation.cpp
@@ -285,7 +285,7 @@ static FailureOr<ShardingOption> selectShardingOption(
 // a `mesh.shard` operation for all remaining operands and results that do not
 // have sharding annotations.
 static LogicalResult visitOp(Operation *op, OpBuilder &builder) {
-  if (op->hasTrait<OpTrait::IsTerminator>() || llvm::isa<mesh::ShardOp>(op))
+  if (op->hasTrait<OpTrait::IsTerminator>() || llvm::isa<mesh::ShardOp, mesh::ShardingOp>(op))
     return success();
 
   ShardingInterface shardingOp = llvm::dyn_cast<ShardingInterface>(op);
diff --git a/mlir/test/Dialect/Mesh/canonicalization.mlir b/mlir/test/Dialect/Mesh/canonicalization.mlir
index 633324ae680eb..ea2bd29056ec7 100644
--- a/mlir/test/Dialect/Mesh/canonicalization.mlir
+++ b/mlir/test/Dialect/Mesh/canonicalization.mlir
@@ -31,7 +31,7 @@ func.func @all_reduce_default_reduction(
   %0 = mesh.all_reduce %arg0 on @mesh0
     mesh_axes = [0]
 // CHECK-NOT: reduction
-    reduction = <sum>
+    reduction = sum
     : tensor<4xf32> -> tensor<4xf64>
   return %0 : tensor<4xf64>
 }
@@ -159,7 +159,7 @@ func.func @reduce_scatter_default_reduction(
   %0 = mesh.reduce_scatter %arg0 on @mesh0
     mesh_axes = [0]
 // CHECK-NOT: reduction
-    reduction = <sum>
+    reduction = sum
     scatter_axis = 0
     : tensor<4xf32> -> tensor<2xf64>
   return %0 : tensor<2xf64>
diff --git a/mlir/test/Dialect/Mesh/invalid.mlir b/mlir/test/Dialect/Mesh/invalid.mlir
index 6d7df86d78406..421e15f7cd57b 100644
--- a/mlir/test/Dialect/Mesh/invalid.mlir
+++ b/mlir/test/Dialect/Mesh/invalid.mlir
@@ -15,7 +15,8 @@ mesh.mesh @mesh0(shape = 2x4)
 func.func @mesh_axis_duplicated_different_subarray(
     %arg0 : tensor<4x8xf32>) -> tensor<4x8xf32> {
   // expected-error at +1 {{mesh axis duplicated}}
-  %0 = mesh.shard %arg0 to <@mesh0, [[0], [0]]> : tensor<4x8xf32>
+  %s = mesh.sharding @mesh0, [[0], [0]] : !mesh.sharding
+  %0 = mesh.shard %arg0 to %s : tensor<4x8xf32>
   return %0 : tensor<4x8xf32>
 }
 
@@ -26,7 +27,8 @@ mesh.mesh @mesh0(shape = 2x4)
 func.func @mesh_axis_duplicated_same_subarray(
     %arg0 : tensor<4x8xf32>) -> tensor<4x8xf32> {
   // expected-error at +1 {{mesh axis duplicated}}
-  %0 = mesh.shard %arg0 to <@mesh0, [[0, 0]]> : tensor<4x8xf32>
+  %s = mesh.sharding @mesh0, [[0, 0]] : !mesh.sharding
+  %0 = mesh.shard %arg0 to %s : tensor<4x8xf32>
   return %0 : tensor<4x8xf32>
 }
 
@@ -37,7 +39,8 @@ mesh.mesh @mesh0(shape = 2x4)
 func.func @mesh_axis_duplicated_bewteen_split_and_partial(
     %arg0 : tensor<4x8xf32>) -> tensor<4x8xf32> {
   // expected-error at +1 {{mesh axis duplicated}}
-  %0 = mesh.shard %arg0 to <@mesh0, [[0]], partial=max[0]> : tensor<4x8xf32>
+  %s = mesh.sharding @mesh0, [[0]] partial=max[0] : !mesh.sharding
+  %0 = mesh.shard %arg0 to %s : tensor<4x8xf32>
   return %0 : tensor<4x8xf32>
 }
 
@@ -48,7 +51,8 @@ mesh.mesh @mesh0(shape = 2x4)
 func.func @mesh_axis_negtive_in_split_part(
     %arg0 : tensor<4x8xf32>) -> tensor<4x8xf32> {
   // expected-error at +1 {{mesh axis is expected to be non-negative}}
-  %0 = mesh.shard %arg0 to <@mesh0, [[-1]]> : tensor<4x8xf32>
+  %s = mesh.sharding @mesh0, [[-1]] : !mesh.sharding
+  %0 = mesh.shard %arg0 to %s : tensor<4x8xf32>
   return %0 : tensor<4x8xf32>
 }
 
@@ -59,16 +63,36 @@ mesh.mesh @mesh0(shape = 2x4)
 func.func @mesh_axis_negtive_in_partial(
     %arg0 : tensor<4x8xf32>) -> tensor<4x8xf32> {
   // expected-error at +1 {{mesh axis is expected to be non-negative}}
-  %0 = mesh.shard %arg0 to <@mesh0, [[0]], partial=max[-1]> : tensor<4x8xf32>
+  %s = mesh.sharding @mesh0, [[0]] partial=max[-1] : !mesh.sharding
+  %0 = mesh.shard %arg0 to %s : tensor<4x8xf32>
   return %0 : tensor<4x8xf32>
 }
 
 // -----
 
 func.func @sharding_attribute_invalid_nested_symbol(%arg0 : tensor<4x8xf32>) {
-  // expected-error at +2 {{custom op 'mesh.shard' invalid kind of attribute specified}}
-  // expected-error at +1 {{custom op 'mesh.shard' failed to parse MeshSharding parameter 'mesh' which is to be a `::mlir::FlatSymbolRefAttr`}}
-  %0 = mesh.shard %arg0 to <@a::@b, [[0]]> : tensor<4x8xf32>
+  // expected-error at +1 {{custom op 'mesh.sharding' invalid kind of attribute specified}}
+  %s = mesh.sharding @a::@b, [[0]] : !mesh.sharding
+  %0 = mesh.shard %arg0 to %s : tensor<4x8xf32>
+  return
+}
+
+// -----
+
+func.func @sharding_attribute_invalid_halo(%arg0 : tensor<4x8xf32>) {
+  // expected-error at +1 {{halo sizes must be specified for all split axes}}
+  %s = mesh.sharding @mesh0, [[0], [1]] halo_sizes = [1, 2] : !mesh.sharding
+  %0 = mesh.shard %arg0 to %s : tensor<4x8xf32>
+  return
+}
+
+// -----
+
+func.func @sharding_attribute_invalid_sizes(%arg0 : tensor<4x8xf32>) {
+  // expected-error at +1 {{halo sizes and shard shapes are mutually exclusive}}
+  %s = mesh.sharding @mesh0, [[0]] halo_sizes = [1, 2] sharded_dims_sizes = [2, 2] : !mesh.sharding
+  %0 = mesh.shard %arg0 to %s : tensor<4x8xf32>
+  return
 }
 
 // -----
@@ -180,7 +204,7 @@ func.func @process_linear_index_invalid_mesh_name() -> (index) {
 func.func @all_reduce_invalid_mesh_symbol(
     %arg0 : tensor<4xf32>) -> tensor<4xf64> {
   // expected-error at +1 {{Undefined required mesh symbol "this_mesh_symbol_does_not_exist".}}
-  %0 = mesh.all_reduce %arg0 on @this_mesh_symbol_does_not_exist reduction = <sum>
+  %0 = mesh.all_reduce %arg0 on @this_mesh_symbol_does_not_exist reduction = sum
     : tensor<4xf32> -> tensor<4xf64>
   return %0 : tensor<4xf64>
 }
@@ -192,7 +216,7 @@ mesh.mesh @mesh0(shape = 2x4)
 func.func @all_reduce_invalid_mesh_axis(
     %arg0 : tensor<4xf32>) -> tensor<4xf64> {
   // expected-error at +1 {{0-based mesh axis index 2 is out of bounds. The referenced mesh "mesh0" is of rank 2.}}
-  %0 = mesh.all_reduce %arg0 on @mesh0 mesh_axes = [2] reduction = <sum>
+  %0 = mesh.all_reduce %arg0 on @mesh0 mesh_axes = [2] reduction = sum
     : tensor<4xf32> -> tensor<4xf64>
   return %0 : tensor<4xf64>
 }
@@ -204,7 +228,7 @@ mesh.mesh @mesh0(shape = 2x4)
 func.func @all_reduce_duplicate_mesh_axis(
     %arg0 : tensor<4xf32>) -> tensor<4xf64> {
   // expected-error at +1 {{Mesh axes contains duplicate elements.}}
-  %0 = mesh.all_reduce %arg0 on @mesh0 mesh_axes = [0, 1, 0] reduction = <sum>
+  %0 = mesh.all_reduce %arg0 on @mesh0 mesh_axes = [0, 1, 0] reduction = sum
     : tensor<4xf32> -> tensor<4xf64>
   return %0 : tensor<4xf64>
 }
diff --git a/mlir/test/Dialect/Mesh/resharding-spmdization.mlir b/mlir/test/Dialect/Mesh/resharding-spmdization.mlir
index b3e305135ad8b..0ff221ef89e81 100644
--- a/mlir/test/Dialect/Mesh/resharding-spmdization.mlir
+++ b/mlir/test/Dialect/Mesh/resharding-spmdization.mlir
@@ -8,8 +8,22 @@ func.func @same_source_and_target_sharding(
   // CHECK-SAME: %[[ARG:.*]]: tensor<2xf32>
   %arg0: tensor<2xf32>
 ) -> tensor<2xf32> {
-  %0 = mesh.shard %arg0 to <@mesh_1d, [[]]> : tensor<2xf32>
-  %1 = mesh.shard %0 to <@mesh_1d, [[]]> annotate_for_users : tensor<2xf32>
+  %s0 = mesh.sharding @mesh_1d, [[]] : !mesh.sharding
+  %0 = mesh.shard %arg0 to %s0 : tensor<2xf32>
+  %s1 = mesh.sharding @mesh_1d, [[]] : !mesh.sharding
+  %1 = mesh.shard %0 to %s1 annotate_for_users : tensor<2xf32>
+  // CHECK: return %[[ARG]]
+  return %1 : tensor<2xf32>
+}
+
+// CHECK-LABEL: func @identical_source_and_target_sharding
+func.func @identical_source_and_target_sharding(
+  // CHECK-SAME: %[[ARG:.*]]: tensor<2xf32>
+  %arg0: tensor<2xf32>
+) -> tensor<2xf32> {
+  %s0 = mesh.sharding @mesh_1d, [[]] : !mesh.sharding
+  %0 = mesh.shard %arg0 to %s0 : tensor<2xf32>
+  %1 = mesh.shard %0 to %s0 annotate_for_users : tensor<2xf32>
   // CHECK: return %[[ARG]]
   return %1 : tensor<2xf32>
 }
@@ -22,8 +36,10 @@ func.func @split_replicated_tensor_axis(
   // CHECK: %[[ALL_SLICE:.*]] = mesh.all_slice %[[ARG]] on @mesh_1d mesh_axes = [0] slice_axis = 1
   // CHECK-SAME: tensor<3x14xf32> -> tensor<3x7xf32>
   // CHECK: %[[RESULT:.*]] = builtin.unrealized_conversion_cast %[[ALL_SLICE]] : tensor<3x7xf32> to tensor<3x14xf32>
-  %0 = mesh.shard %arg0 to <@mesh_1d, [[]]> : tensor<3x14xf32>
-  %1 = mesh.shard %0 to <@mesh_1d, [[], [0]]> annotate_for_users : tensor<3x14xf32>
+  %s0 = mesh.sharding @mesh_1d, [[]] : !mesh.sharding
+  %0 = mesh.shard %arg0 to %s0 : tensor<3x14xf32>
+  %s1 = mesh.sharding @mesh_1d, [[], [0]] : !mesh.sharding
+  %1 = mesh.shard %0 to %s1 annotate_for_users : tensor<3x14xf32>
   // CHECK: return %[[RESULT]] : tensor<3x14xf32>
   return %1 : tensor<3x14xf32>
 }
@@ -35,8 +51,10 @@ func.func @split_replicated_tensor_axis_dynamic(
 ) -> tensor<?x3x?xf32> {
   // CHECK: %[[RESULT:.*]] = mesh.all_slice %[[ARG]] on @mesh_1d_dynamic mesh_axes = [0] slice_axis = 0
   // CHECK-SAME: tensor<?x3x?xf32> -> tensor<?x3x?xf32>
-  %0 = mesh.shard %arg0 to <@mesh_1d_dynamic, [[], [], []]> : tensor<?x3x?xf32>
-  %1 = mesh.shard %0 to <@mesh_1d_dynamic, [[0]]> annotate_for_users : tensor<?x3x?xf32>
+  %s0 = mesh.sharding @mesh_1d_dynamic, [[], [], []] : !mesh.sharding
+  %0 = mesh.shard %arg0 to %s0 : tensor<?x3x?xf32>
+  %s1 = mesh.sharding @mesh_1d_dynamic, [[0]] : !mesh.sharding
+  %1 = mesh.shard %0 to %s1 annotate_for_users : tensor<?x3x?xf32>
   // CHECK: return %[[RESULT]] : tensor<?x3x?xf32>
   return %1 : tensor<?x3x?xf32>
 }
@@ -49,8 +67,10 @@ func.func @move_split_axis(
   // CHECK: %[[SOURCE_SHARD:.*]] = builtin.unrealized_conversion_cast %[[ARG]] : tensor<10x14xf32> to tensor<5x14xf32>
   // CHECK: %[[TARGET_SHARD:.*]] = mesh.all_to_all %[[SOURCE_SHARD]] on @mesh_1d mesh_axes = [0] split_axis = 1 concat_axis = 0 : tensor<5x14xf32> -> tensor<10x7xf32>
   // CHECK: %[[RES:.*]] = builtin.unrealized_conversion_cast %[[TARGET_SHARD]] : tensor<10x7xf32> to tensor<10x14xf32>
-  %0 = mesh.shard %arg0 to <@mesh_1d, [[0]]> : tensor<10x14xf32>
-  %1 = mesh.shard %0 to <@mesh_1d, [[], [0]]> annotate_for_users : tensor<10x14xf32>
+  %s0 = mesh.sharding @mesh_1d, [[0]] : !mesh.sharding
+  %0 = mesh.shard %arg0 to %s0 : tensor<10x14xf32>
+  %s1 = mesh.sharding @mesh_1d, [[], [0]] : !mesh.sharding
+  %1 = mesh.shard %0 to %s1 annotate_for_users : tensor<10x14xf32>
   // CHECK: return %[[RES]] : tensor<10x14xf32>
   return %1 : tensor<10x14xf32>
 }
@@ -64,8 +84,10 @@ func.func @move_split_axis_dynamic_mesh(
   // CHECK: %[[ALL_TO_ALL:.*]] = mesh.all_to_all %[[SOURCE_SHARD]] on @mesh_1d_dynamic mesh_axes = [0] split_axis = 1 concat_axis = 0 : tensor<?x14xf32> -> tensor<?x?xf32>
   // CHECK: %[[TARGET_SHARD:.*]] = tensor.cast %[[ALL_TO_ALL]] : tensor<?x?xf32> to tensor<10x?xf32>
   // CHECK: %[[RES:.*]] = builtin.unrealized_conversion_cast %[[TARGET_SHARD]] : tensor<10x?xf32> to tensor<10x14xf32>
-  %0 = mesh.shard %arg0 to <@mesh_1d_dynamic, [[0]]> : tensor<10x14xf32>
-  %1 = mesh.shard %0 to <@mesh_1d_dynamic, [[], [0]]> annotate_for_users : tensor<10x14xf32>
+  %s0 = mesh.sharding @mesh_1d_dynamic, [[0]] : !mesh.sharding
+  %0 = mesh.shard %arg0 to %s0 : tensor<10x14xf32>
+  %s1 = mesh.sharding @mesh_1d_dynamic, [[], [0]] : !mesh.sharding
+  %1 = mesh.shard %0 to %s1 annotate_for_users : tensor<10x14xf32>
   // CHECK: return %[[RES]] : tensor<10x14xf32>
   return %1 : tensor<10x14xf32>
 }
@@ -77,8 +99,10 @@ func.func @move_split_dynamic_axis(
 ) -> tensor<?x14xf32> {
   // CHECK: %[[TARGET_SHARD:.*]] = mesh.all_to_all %[[ARG]] on @mesh_1d mesh_axes = [0] split_axis = 1 concat_axis = 0 : tensor<?x14xf32> -> tensor<?x7xf32>
   // CHECK: %[[RES:.*]] = builtin.unrealized_conversion_cast %[[TARGET_SHARD]] : tensor<?x7xf32> to tensor<?x14xf32>
-  %0 = mesh.shard %arg0 to <@mesh_1d, [[0]]> : tensor<?x14xf32>
-  %1 = mesh.shard %0 to <@mesh_1d, [[], [0]]> annotate_for_users : tensor<?x14xf32>
+  %s0 = mesh.sharding @mesh_1d, [[0]] : !mesh.sharding
+  %0 = mesh.shard %arg0 to %s0 : tensor<?x14xf32>
+  %s1 = mesh.sharding @mesh_1d, [[], [0]] : !mesh.sharding
+  %1 = mesh.shard %0 to %s1 annotate_for_users : tensor<?x14xf32>
   // CHECK: return %[[RES]] : tensor<?x14xf32>
   return %1 : tensor<?x14xf32>
 }
@@ -90,8 +114,10 @@ func.func @unshard_static_axis(
 ) -> tensor<10x14xf32> {
   // CHECK: %[[SOURCE_SHARD:.*]] = builtin.unrealized_conversion_cast %[[ARG]] : tensor<10x14xf32> to tensor<5x14xf32>
   // CHECK: %[[ALL_GATHER:.*]] = mesh.all_gather %[[SOURCE_SHARD]] on @mesh_1d mesh_axes = [0] gather_axis = 0 : tensor<5x14xf32> -> tensor<10x14xf32>
-  %0 = mesh.shard %arg0 to <@mesh_1d, [[0]]> : tensor<10x14xf32>
-  %1 = mesh.shard %0 to <@mesh_1d, [[]]> annotate_for_users : tensor<10x14xf32>
+  %s0 = mesh.sharding @mesh_1d, [[0]] : !mesh.sharding
+  %0 = mesh.shard %arg0 to %s0 : tensor<10x14xf32>
+  %s1 = mesh.sharding @mesh_1d, [[]] : !mesh.sharding
+  %1 = mesh.shard %0 to %s1 annotate_for_users : tensor<10x14xf32>
   // CHECK: return %[[ALL_GATHER]] : tensor<10x14xf32>
   return %1 : tensor<10x14xf32>
 }
@@ -103,8 +129,10 @@ func.func @unshard_static_last_axis(
 ) -> tensor<10x14xf32> {
   // CHECK: %[[SOURCE_SHARD:.*]] = builtin.unrealized_conversion_cast %[[ARG]] : tensor<10x14xf32> to tensor<10x7xf32>
   // CHECK: %[[ALL_GATHER:.*]] = mesh.all_gather %[[SOURCE_SHARD]] on @mesh_1d mesh_axes = [0] gather_axis = 1 : tensor<10x7xf32> -> tensor<10x14xf32>
-  %0 = mesh.shard %arg0 to <@mesh_1d, [[], [0]]> : tensor<10x14xf32>
-  %1 = mesh.shard %0 to <@mesh_1d, [[], []]> annotate_for_users : tensor<10x14xf32>
+  %s0 = mesh.sharding @mesh_1d, [[], [0]] : !mesh.sharding
+  %0 = mesh.shard %arg0 to %s0 : tensor<10x14xf32>
+  %s1 = mesh.sharding @mesh_1d, [[], []] : !mesh.sharding
+  %1 = mesh.shard %0 to %s1 annotate_for_users : tensor<10x14xf32>
   // CHECK: return %[[ALL_GATHER]] : tensor<10x14xf32>
   return %1 : tensor<10x14xf32>
 }
@@ -115,8 +143,10 @@ func.func @unshard_dynamic_axis(
   %arg0: tensor<?x14xf32>
 ) -> tensor<?x14xf32> {
   // CHECK: %[[ALL_GATHER:.*]] = mesh.all_gather %[[ARG]] on @mesh_1d mesh_axes = [0] gather_axis = 0 : tensor<?x14xf32> -> tensor<?x14xf32>
-  %0 = mesh.shard %arg0 to <@mesh_1d, [[0]]> : tensor<?x14xf32>
-  %1 = mesh.shard %0 to <@mesh_1d, [[]]> annotate_for_users : tensor<?x14xf32>
+  %s0 = mesh.sharding @mesh_1d, [[0]] : !mesh.sharding
+  %0 = mesh.shard %arg0 to %s0 : tensor<?x14xf32>
+  %s1 = mesh.sharding @mesh_1d, [[]] : !mesh.sharding
+  %1 = mesh.shard %0 to %s1 annotate_for_users : tensor<?x14xf32>
   // CHECK: return %[[ALL_GATHER]] : tensor<?x14xf32>
   return %1 : tensor<?x14xf32>
 }
@@ -129,8 +159,10 @@ func.func @unshard_static_axis_on_dynamic_mesh_axis(
   // CHECK: %[[SOURCE_SHARD:.*]] = builtin.unrealized_conversion_cast %[[ARG]] : tensor<10x14xf32> to tensor<?x14xf32>
   // CHECK: %[[ALL_GATHER:.*]] = mesh.all_gather %[[SOURCE_SHARD]] on @mesh_1d_dynamic mesh_axes = [0] gather_axis = 0 : tensor<?x14xf32> -> tensor<?x14xf32>
   // CHECK: %[[RES:.*]] = tensor.cast %[[ALL_GATHER]] : tensor<?x14xf32> to tensor<10x14xf32>
-  %0 = mesh.shard %arg0 to <@mesh_1d_dynamic, [[0]]> : tensor<10x14xf32>
-  %1 = mesh.shard %0 to <@mesh_1d_dynamic, [[]]> annotate_for_users : tensor<10x14xf32>
+  %s0 = mesh.sharding @mesh_1d_dynamic, [[0]] : !mesh.sharding
+  %0 = mesh.shard %arg0 to %s0 : tensor<10x14xf32>
+  %s1 = mesh.sharding @mesh_1d_dynamic, [[]] : !mesh.sharding
+  %1 = mesh.shard %0 to %s1 annotate_for_users : tensor<10x14xf32>
   // CHECK: return %[[RES]] : tensor<10x14xf32>
   return %1 : tensor<10x14xf32>
 }
@@ -141,8 +173,10 @@ func.func @partial_axis_to_full_replication(
   %arg0: tensor<10x14xf32>
 ) -> tensor<10x14xf32> {
   // CHECK: %[[ALL_REDUCE:.*]] = mesh.all_reduce %[[ARG]] on @mesh_1d mesh_axes = [0] : tensor<10x14xf32> -> tensor<10x14xf32>
-  %0 = mesh.shard %arg0 to <@mesh_1d, [[]], partial = sum[0]> : tensor<10x14xf32>
-  %1 = mesh.shard %0 to <@mesh_1d, [[]]> annotate_for_users : tensor<10x14xf32>
+  %s0 = mesh.sharding @mesh_1d, [[]] partial = sum [0] : !mesh.sharding
+  %0 = mesh.shard %arg0 to %s0 : tensor<10x14xf32>
+  %s1 = mesh.sharding @mesh_1d, [[]] : !mesh.sharding
+  %1 = mesh.shard %0 to %s1 annotate_for_users : tensor<10x14xf32>
   // CHECK: %[[ALL_REDUCE]] : tensor<10x14xf32>
   return %1 : tensor<10x14xf32>
 }
diff --git a/mlir/test/Dialect/Mesh/sharding-propagation.mlir b/mlir/test/Dialect/Mesh/sharding-propagation.mlir
index 11a80594adb79..1a4b21bea51b7 100644
--- a/mlir/test/Dialect/Mesh/sharding-propagation.mlir
+++ b/mlir/test/Dialect/Mesh/sharding-propagation.mlir
@@ -16,11 +16,14 @@ func.func @element_wise_empty_sharding_info(%arg0: tensor<8x16xf32>) -> tensor<8
 // CHECK-LABEL: func.func @element_wise_on_def
 // CHECK-SAME:    %[[ARG:.*]]: tensor<8x16xf32>
 func.func @element_wise_on_def(%arg0: tensor<8x16xf32>) -> tensor<8x16xf32> {
-  // CHECK-NEXT:  %[[V0:.*]] = mesh.shard %[[ARG]] to <@mesh_2d, {{\[\[}}0], [1]]> annotate_for_users : tensor<8x16xf32>
+  // CHECK-NEXT:  %[[S0:.*]] = mesh.sharding @mesh_2d, {{\[\[}}0], [1]] : !mesh.sharding
+  // CHECK-NEXT:  %[[V0:.*]] = mesh.shard %[[ARG]] to %[[S0]] annotate_for_users  : tensor<8x16xf32>
   // CHECK-NEXT:  %[[V1:.*]] = tosa.sigmoid %[[V0]]
   %0 = tosa.sigmoid %arg0 : (tensor<8x16xf32>) -> tensor<8x16xf32>
-  // CHECK-NEXT:  %[[V2:.*]] = mesh.shard %[[V1]] to <@mesh_2d, {{\[\[}}0], [1]]> : tensor<8x16xf32>
-  %1 = mesh.shard %0 to <@mesh_2d, [[0], [1]]> : tensor<8x16xf32>
+  // CHECK-NEXT:  %[[S2:.*]] = mesh.sharding @mesh_2d, {{\[\[}}0], [1]] : !mesh.sharding
+  // CHECK-NEXT:  %[[V2:.*]] = mesh.shard %[[V1]] to %[[S2]]  : tensor<8x16xf32>
+  %s1 = mesh.sharding @mesh_2d, [[0], [1]] : !mesh.sharding
+  %1 = mesh.shard %0 to %s1  : tensor<8x16xf32>
   // CHECK-NEXT:  return %[[V2]]
   return %1 : tensor<8x16xf32>
 }
@@ -28,11 +31,14 @@ func.func @element_wise_on_def(%arg0: tensor<8x16xf32>) -> tensor<8x16xf32> {
 // CHECK-LABEL: func.func @element_wise_on_use
 // CHECK-SAME:    %[[ARG:.*]]: tensor<8x16xf32>
 func.func @element_wise_on_use(%arg0: tensor<8x16xf32>) -> tensor<8x16xf32> {
-  // CHECK-NEXT:  %[[V0:.*]] = mesh.shard %[[ARG]] to <@mesh_2d, {{\[\[}}0], [1]]> annotate_for_users : tensor<8x16xf32>
-  %0 = mesh.shard %arg0 to <@mesh_2d, [[0], [1]]> annotate_for_users : tensor<8x16xf32>
+  // CHECK-NEXT:  %[[S0:.*]] = mesh.sharding @mesh_2d, {{\[\[}}0], [1]] : !mesh.sharding
+  // CHECK-NEXT:  %[[V0:.*]] = mesh.shard %[[ARG]] to %[[S0]] annotate_for_users  : tensor<8x16xf32>
+  %s0 = mesh.sharding @mesh_2d, [[0], [1]] : !mesh.sharding
+  %0 = mesh.shard %arg0 to %s0 annotate_for_users  : tensor<8x16xf32>
   // CHECK-NEXT:  %[[V1:.*]] = tosa.sigmoid %[[V0]]
   %1 = tosa.sigmoid %0 : (tensor<8x16xf32>) -> tensor<8x16xf32>
-  // CHECK-NEXT:  %[[V2:.*]] = mesh.shard %[[V1]] to <@mesh_2d, {{\[\[}}0], [1]]> : tensor<8x16xf32>
+  // CHECK-NEXT:  %[[S2:.*]] = mesh.sharding @mesh_2d, {{\[\[}}0], [1]] : !mesh.sharding
+  // CHECK-NEXT:  %[[V2:.*]] = mesh.shard %[[V1]] to %[[S2]]  : tensor<8x16xf32>
   // CHECK-NEXT:  return %[[V2]]
   return %1 : tensor<8x16xf32>
 }
@@ -40,12 +46,15 @@ func.func @element_wise_on_use(%arg0: tensor<8x16xf32>) -> tensor<8x16xf32> {
 // CHECK-LABEL: func.func @element_wise_on_graph_output
 // CHECK-SAME:    %[[ARG:.*]]: tensor<8x16xf32>
 func.func @element_wise_on_graph_output(%arg0: tensor<8x16xf32>) -> tensor<8x16xf32> {
-  // CHECK-NEXT:  %[[V0:.*]] = mesh.shard %[[ARG]] to <@mesh_2d, {{\[\[}}0], [1]]> annotate_for_users : tensor<8x16xf32>
+  // CHECK-NEXT:  %[[S0:.*]] = mesh.sharding @mesh_2d, {{\[\[}}0], [1]] : !mesh.sharding
+  // CHECK-NEXT:  %[[V0:.*]] = mesh.shard %[[ARG]] to %[[S0]] annotate_for_users  : tensor<8x16xf32>
   // CHECK-NEXT:  %[[V1:.*]] = tosa.sigmoid %[[V0]]
   %0 = tosa.sigmoid %arg0 : (tensor<8x16xf32>) -> tensor<8x16xf32>
-  // CHECK-NEXT:  %[[V2:.*]] = mesh.shard %[[V1]] to <@mesh_2d, {{\[\[}}0], [1]]> : tensor<8x16xf32>
-  // CHECK-NEXT:  %[[V3:.*]] = mesh.shard %[[V2]] to <@mesh_2d, {{\[\[}}0], [1]]> annotate_for_users : tensor<8x16xf32>
-  %1 = mesh.shard %0 to <@mesh_2d, [[0], [1]]> annotate_for_users : tensor<8x16xf32>
+  // CHECK-NEXT:  %[[V2:.*]] = mesh.shard %[[V1]] to %[[S0]]  : tensor<8x16xf32>
+  // CHECK-NEXT:  %[[S3:.*]] = mesh.sharding @mesh_2d, {{\[\[}}0], [1]] : !mesh.sharding
+  // CHECK-NEXT:  %[[V3:.*]] = mesh.shard %[[V2]] to %[[S3]] annotate_for_users  : tensor<8x16xf32>
+  %s1 = mesh.sharding @mesh_2d, [[0], [1]] : !mesh.sharding
+  %1 = mesh.shard %0 to %s1 annotate_for_users  : tensor<8x16xf32>
   // CHECK-NEXT:  return %[[V3]]
   return %1 : tensor<8x16xf32>
 }
@@ -53,12 +62,15 @@ func.func @element_wise_on_graph_output(%arg0: tensor<8x16xf32>) -> tensor<8x16x
 // CHECK-LABEL: func.func @element_wise_on_graph_input
 // CHECK-SAME:    %[[ARG:.*]]: tensor<8x16xf32>
 func.func @element_wise_on_graph_input(%arg0: tensor<8x16xf32>) -> tensor<8x16xf32> {
-  // CHECK-NEXT:  %[[V0:.*]] = mesh.shard %[[ARG]] to <@mesh_2d, {{\[\[}}0], [1]]> : tensor<8x16xf32>
-  // CHECK-NEXT:  %[[V1:.*]] = mesh.shard %[[V0]] to <@mesh_2d, {{\[\[}}0], [1]]> annotate_for_users : tensor<8x16xf32>
-  %0 = mesh.shard %arg0 to <@mesh_2d, [[0], [1]]> : tensor<8x16xf32>
+  // CHECK-NEXT:  %[[S0:.*]] = mesh.sharding @mesh_2d, {{\[\[}}0], [1]] : !mesh.sharding
+  // CHECK-NEXT:  %[[V0:.*]] = mesh.shard %[[ARG]] to %[[S0]]  : tensor<8x16xf32>
+  // CHECK-NEXT:  %[[S1:.*]] = mesh.sharding @mesh_2d, {{\[\[}}0], [1]] : !mesh.sharding
+  // CHECK-NEXT:  %[[V1:.*]] = mesh.shard %[[V0]] to %[[S1]] annotate_for_users  : tensor<8x16xf32>
+  %s0 = mesh.sharding @mesh_2d, [[0], [1]] : !mesh.sharding
+  %0 = mesh.shard %arg0 to %s0  : tensor<8x16xf32>
   // CHECK-NEXT:  %[[V2:.*]] = tosa.sigmoid %[[V1]]
   %1 = tosa.sigmoid %0 : (tensor<8x16xf32>) -> tensor<8x16xf32>
-  // CHECK-NEXT:  %[[V3:.*]] = mesh.shard %[[V2]] to <@mesh_2d, {{\[\[}}0], [1]]> : tensor<8x16xf32>
+  // CHECK-NEXT:  %[[V3:.*]] = mesh.shard %[[V2]] to %[[S1]]  : tensor<8x16xf32>
   // CHECK-NEXT:  return %[[V3]]
   return %1 : tensor<8x16xf32>
 }
@@ -66,18 +78,21 @@ func.func @element_wise_on_graph_input(%arg0: tensor<8x16xf32>) -> tensor<8x16xf
 // CHECK-LABEL: func.func @arrow_structure
 // CHECK-SAME:    %[[ARG:.*]]: tensor<8x16xf32>
 func.func @arrow_structure(%arg0: tensor<8x16xf32>) -> (tensor<8x16xf32>, tensor<8x16xf32>) {
-  // CHECK-NEXT:  %[[V1:.*]] = mesh.shard %[[ARG]] to <@mesh_2d, {{\[\[}}0], [1]]> annotate_for_users : tensor<8x16xf32>
+  // CHECK-NEXT:  %[[S1:.*]] = mesh.sharding @mesh_2d, {{\[\[}}0], [1]] : !mesh.sharding
+  // CHECK-NEXT:  %[[V1:.*]] = mesh.shard %[[ARG]] to %[[S1]] annotate_for_users  : tensor<8x16xf32>
   // CHECK-NEXT:  %[[V2:.*]] = tosa.tanh %[[V1]]
-  // CHECK-NEXT:  %[[V3:.*]] = mesh.shard %[[V2]] to <@mesh_2d, {{\[\[}}0], [1]]> : tensor<8x16xf32>
+  // CHECK-NEXT:  %[[V3:.*]] = mesh.shard %[[V2]] to %[[S1]]  : tensor<8x16xf32>
   %0 = tosa.tanh %arg0 : (tensor<8x16xf32>) -> tensor<8x16xf32>
-  // CHECK-NEXT:  %[[V4:.*]] = mesh.shard %[[V3]] to <@mesh_2d, {{\[\[}}0], [1]]> annotate_for_users : tensor<8x16xf32>
+  // CHECK-NEXT:  %[[V4:.*]] = mesh.shard %[[V3]] to %[[S1]] annotate_for_users  : tensor<8x16xf32>
   // CHECK-NEXT:  %[[V5:.*]] = tosa.abs %[[V4]]
-  // CHECK-NEXT:  %[[V6:.*]] = mesh.shard %[[V5]] to <@mesh_2d, {{\[\[}}0], [1]]> : tensor<8x16xf32>
+  // CHECK-NEXT:  %[[V6:.*]] = mesh.shard %[[V5]] to %[[S1]]  : tensor<8x16xf32>
   %1 = tosa.abs %0 : (tensor<8x16xf32>) -> tensor<8x16xf32>
   // CHECK-NEXT:  %[[V7:.*]] = tosa.negate %[[V4]]
-  // CHECK-NEXT:  %[[V8:.*]] = mesh.shard %[[V7]] to <@mesh_2d, {{\[\[}}0], [1]]> : tensor<8x16xf32>
+  // CHECK-NEXT:  %[[S8:.*]] = mesh.sharding @mesh_2d, {{\[\[}}0], [1]] : !mesh.sharding
+  // CHECK-NEXT:  %[[V8:.*]] = mesh.shard %[[V7]] to %[[S8]]  : tensor<8x16xf32>
   %2 = tosa.negate %0 : (tensor<8x16xf32>) -> tensor<8x16xf32>
-  %3 = mesh.shard %2 to <@mesh_2d, [[0], [1]]> : tensor<8x16xf32>
+  %s3 = mesh.sharding @mesh_2d, [[0], [1]] : !mesh.sharding
+  %3 = mesh.shard %2 to %s3  : tensor<8x16xf32>
   // CHECK-NEXT: return %[[V6]], %[[V8]]
   return %1, %3 : tensor<8x16xf32>, tensor<8x16xf32>
 }
@@ -85,12 +100,16 @@ func.func @arrow_structure(%arg0: tensor<8x16xf32>) -> (tensor<8x16xf32>, tensor
 // CHECK-LABEL: func.func @matmul_on_def_shard_batch_and_m
 // CHECK-SAME:     %[[ARG0:.*]]: tensor<2x16x8xf32>, %[[ARG1:.*]]: tensor<2x8x32xf32>
 func.func @matmul_on_def_shard_batch_and_m(%arg0: tensor<2x16x8xf32>, %arg1: tensor<2x8x32xf32>) -> tensor<2x16x32xf32> {
-  // CHECK-NEXT:  %[[V0:.*]] = mesh.shard %[[ARG0]] to <@mesh_2d, {{\[\[}}0], [1]]> annotate_for_users : tensor<2x16x8xf32>
-  // CHECK-NEXT:  %[[V1:.*]] = mesh.shard %[[ARG1]] to <@mesh_2d, {{\[\[}}0]]> annotate_for_users : tensor<2x8x32xf32>
+  // CHECK-NEXT:  %[[S0:.*]] = mesh.sharding @mesh_2d, {{\[\[}}0], [1]] : !mesh.sharding
+  // CHECK-NEXT:  %[[V0:.*]] = mesh.shard %[[ARG0]] to %[[S0]] annotate_for_users  : tensor<2x16x8xf32>
+  // CHECK-NEXT:  %[[S1:.*]] = mesh.sharding @mesh_2d, {{\[\[}}0]] : !mesh.sharding
+  // CHECK-NEXT:  %[[V1:.*]] = mesh.shard %[[ARG1]] to %[[S1]] annotate_for_users  : tensor<2x8x32xf32>
   // CHECK-NEXT:  %[[V2:.*]] = tosa.matmul %[[V0]], %[[V1]]
   %0 = tosa.matmul %arg0, %arg1 : (tensor<2x16x8xf32>, tensor<2x8x32xf32>) -> tensor<2x16x32xf32>
-  // CHECK-NEXT:  %[[V3:.*]] = mesh.shard %[[V2]] to <@mesh_2d, {{\[\[}}0], [1]]> : tensor<2x16x32xf32>
-  %1 = mesh.shard %0 to <@mesh_2d, [[0], [1]]> : tensor<2x16x32xf32>
+  // CHECK-NEXT:  %[[S3:.*]] = mesh.sharding @mesh_2d, {{\[\[}}0], [1]] : !mesh.sharding
+  // CHECK-NEXT:  %[[V3:.*]] = mesh.shard %[[V2]] to %[[S3]]  : tensor<2x16x32xf32>
+  %s1 = mesh.sharding @mesh_2d, [[0], [1]] : !mesh.sharding
+  %1 = mesh.shard %0 to %s1  : tensor<2x16x32xf32>
   // CHECK-NEXT:  return %[[V3]]
   return %1 : tensor<2x16x32xf32>
 }
@@ -98,12 +117,16 @@ func.func @matmul_on_def_shard_batch_and_m(%arg0: tensor<2x16x8xf32>, %arg1: ten
 // CHECK-LABEL: func.func @matmul_on_def_shard_m_and_k
 // CHECK-SAME:     %[[ARG0:.*]]: tensor<2x16x8xf32>, %[[ARG1:.*]]: tensor<2x8x32xf32>
 func.func @matmul_on_def_shard_m_and_k(%arg0: tensor<2x16x8xf32>, %arg1: tensor<2x8x32xf32>) -> tensor<2x16x32xf32> {
-  // CHECK-NEXT:  %[[V0:.*]] = mesh.shard %[[ARG0]] to <@mesh_2d, {{\[\[}}], [1], [0]]> annotate_for_users : tensor<2x16x8xf32>
-  // CHECK-NEXT:  %[[V1:.*]] = mesh.shard %[[ARG1]] to <@mesh_2d, {{\[\[}}], [0]]> annotate_for_users : tensor<2x8x32xf32>
+  // CHECK-NEXT:  %[[S0:.*]] = mesh.sharding @mesh_2d, {{\[\[}}], [1], [0]] : !mesh.sharding
+  // CHECK-NEXT:  %[[V0:.*]] = mesh.shard %[[ARG0]] to %[[S0]] annotate_for_users  : tensor<2x16x8xf32>
+  // CHECK-NEXT:  %[[S1:.*]] = mesh.sharding @mesh_2d, {{\[\[}}], [0]] : !mesh.sharding
+  // CHECK-NEXT:  %[[V1:.*]] = mesh.shard %[[ARG1]] to %[[S1]] annotate_for_users  : tensor<2x8x32xf32>
   // CHECK-NEXT:  %[[V2:.*]] = tosa.matmul %[[V0]], %[[V1]]
   %0 = tosa.matmul %arg0, %arg1 : (tensor<2x16x8xf32>, tensor<2x8x32xf32>) -> tensor<2x16x32xf32>
-  // CHECK-NEXT:  %[[V3:.*]] = mesh.shard %[[V2]] to <@mesh_2d, {{\[\[}}], [1]], partial = sum[0]> : tensor<2x16x32xf32>
-  %1 = mesh.shard %0 to <@mesh_2d, [[], [1]], partial = sum[0]> : tensor<2x16x32xf32>
+  // CHECK-NEXT:  %[[S3:.*]] = mesh.sharding @mesh_2d, {{\[\[}}], [1]] partial = sum [0] : !mesh.sharding
+  // CHECK-NEXT:  %[[V3:.*]] = mesh.shard %[[V2]] to %[[S3]]  : tensor<2x16x32xf32>
+  %s1 = mesh.sharding @mesh_2d, [[], [1]] partial = sum [0] : !mesh.sharding
+  %1 = mesh.shard %0 to %s1  : tensor<2x16x32xf32>
   // CHECK-NEXT:  return %[[V3]]
   return %1 : tensor<2x16x32xf32>
 }
@@ -111,12 +134,16 @@ func.func @matmul_on_def_shard_m_and_k(%arg0: tensor<2x16x8xf32>, %arg1: tensor<
 // CHECK-LABEL: func.func @matmul_on_use_shard_m_and_k
 // CHECK-SAME:     %[[ARG0:.*]]: tensor<2x16x8xf32>, %[[ARG1:.*]]: tensor<2x8x32xf32>
 func.func @matmul_on_use_shard_m_and_k(%arg0: tensor<2x16x8xf32>, %arg1: tensor<2x8x32xf32>) -> tensor<2x16x32xf32> {
-  // CHECK-NEXT:  %[[V0:.*]] = mesh.shard %[[ARG0]] to <@mesh_2d, {{\[\[}}], [1], [0]]> annotate_for_users : tensor<2x16x8xf32>
-  %0 = mesh.shard %arg0 to <@mesh_2d, [[], [1], [0]]> annotate_for_users : tensor<2x16x8xf32>
-  // CHECK-NEXT:  %[[V1:.*]] = mesh.shard %[[ARG1]] to <@mesh_2d, {{\[\[}}], [0]]> annotate_for_users : tensor<2x8x32xf32>
+  // CHECK-NEXT:  %[[S0:.*]] = mesh.sharding @mesh_2d, {{\[\[}}], [1], [0]] : !mesh.sharding
+  // CHECK-NEXT:  %[[V0:.*]] = mesh.shard %[[ARG0]] to %[[S0]] annotate_for_users  : tensor<2x16x8xf32>
+  %s0 = mesh.sharding @mesh_2d, [[], [1], [0]] : !mesh.sharding
+  %0 = mesh.shard %arg0 to %s0 annotate_for_users  : tensor<2x16x8xf32>
+  // CHECK-NEXT:  %[[S1:.*]] = mesh.sharding @mesh_2d, {{\[\[}}], [0]] : !mesh.sharding
+  // CHECK-NEXT:  %[[V1:.*]] = mesh.shard %[[ARG1]] to %[[S1]] annotate_for_users  : tensor<2x8x32xf32>
   // CHECK-NEXT:  %[[V2:.*]] = tosa.matmul %[[V0]], %[[V1]]
   %1 = tosa.matmul %0, %arg1 : (tensor<2x16x8xf32>, tensor<2x8x32xf32>) -> tensor<2x16x32xf32>
-  // CHECK-NEXT:  %[[V3:.*]] = mesh.shard %[[V2]] to <@mesh_2d, {{\[\[}}], [1]], partial = sum[0]> : tensor<2x16x32xf32>
+  // CHECK-NEXT:  %[[S3:.*]] = mesh.sharding @mesh_2d, {{\[\[}}], [1]] partial = sum [0] : !mesh.sharding
+  // CHECK-NEXT:  %[[V3:.*]] = mesh.shard %[[V2]] to %[[S3]]  : tensor<2x16x32xf32>
   // CHECK-NEXT:  return %[[V3]]
   return %1 : tensor<2x16x32xf32>
 }
@@ -124,13 +151,18 @@ func.func @matmul_on_use_shard_m_and_k(%arg0: tensor<2x16x8xf32>, %arg1: tensor<
 // CHECK-LABEL: func.func @matmul_on_use_shard_m_and_duplicted_k
 // CHECK-SAME:     %[[ARG0:.*]]: tensor<2x16x8xf32>, %[[ARG1:.*]]: tensor<2x8x32xf32>
 func.func @matmul_on_use_shard_m_and_duplicted_k(%arg0: tensor<2x16x8xf32>, %arg1: tensor<2x8x32xf32>) -> tensor<2x16x32xf32> {
-  // CHECK-NEXT:  %[[V0:.*]] = mesh.shard %[[ARG0]] to <@mesh_2d, {{\[\[}}], [1], [0]]> annotate_for_users : tensor<2x16x8xf32>
-  %0 = mesh.shard %arg0 to <@mesh_2d, [[], [1], [0]]> annotate_for_users : tensor<2x16x8xf32>
-  // CHECK-NEXT:  %[[V1:.*]] = mesh.shard %[[ARG1]] to <@mesh_2d, {{\[\[}}], [0]]> annotate_for_users : tensor<2x8x32xf32>
-  %1 = mesh.shard %arg1 to <@mesh_2d, [[], [0]]> annotate_for_users : tensor<2x8x32xf32>
+  // CHECK-NEXT:  %[[S0:.*]] = mesh.sharding @mesh_2d, {{\[\[}}], [1], [0]] : !mesh.sharding
+  // CHECK-NEXT:  %[[V0:.*]] = mesh.shard %[[ARG0]] to %[[S0]] annotate_for_users  : tensor<2x16x8xf32>
+  %s0 = mesh.sharding @mesh_2d, [[], [1], [0]] : !mesh.sharding
+  %0 = mesh.shard %arg0 to %s0 annotate_for_users  : tensor<2x16x8xf32>
+  // CHECK-NEXT:  %[[S1:.*]] = mesh.sharding @mesh_2d, {{\[\[}}], [0]] : !mesh.sharding
+  // CHECK-NEXT:  %[[V1:.*]] = mesh.shard %[[ARG1]] to %[[S1]] annotate_for_users  : tensor<2x8x32xf32>
+  %s1 = mesh.sharding @mesh_2d, [[], [0]] : !mesh.sharding
+  %1 = mesh.shard %arg1 to %s1 annotate_for_users  : tensor<2x8x32xf32>
   // CHECK-NEXT:  %[[V2:.*]] = tosa.matmul %[[V0]], %[[V1]]
   %2 = tosa.matmul %0, %1 : (tensor<2x16x8xf32>, tensor<2x8x32xf32>) -> tensor<2x16x32xf32>
-  // CHECK-NEXT:  %[[V3:.*]] = mesh.shard %[[V2]] to <@mesh_2d, {{\[\[}}], [1]], partial = sum[0]> : tensor<2x16x32xf32>
+  // CHECK-NEXT:  %[[S3:.*]] = mesh.sharding @mesh_2d, {{\[\[}}], [1]] partial = sum [0] : !mesh.sharding
+  // CHECK-NEXT:  %[[V3:.*]] = mesh.shard %[[V2]] to %[[S3]]  : tensor<2x16x32xf32>
   // CHECK-NEXT:  return %[[V3]]
   return %2 : tensor<2x16x32xf32>
 }
@@ -145,21 +177,23 @@ func.func @resolve_conflicting_annotations(
   %out_dps: tensor<2x2xf32>
 // CHECK-SAME: ) -> tensor<2x2xf32> {
 ) -> tensor<2x2xf32> {
-  // CHECK: %[[IN1_SHARDED1:.*]] = mesh.shard %[[IN1]] to <@mesh_2, {{\[\[}}0]]> : tensor<2x3xf32>
-  // CHECK: %[[IN1_SHARDED2:.*]] = mesh.shard %[[IN1_SHARDED1]] to <@mesh_2, {{\[}}]> annotate_for_users : tensor<2x3xf32>
-  // CHECK: %[[IN2_SHARDED:.*]] = mesh.shard %[[IN2]] to <@mesh_2, []> annotate_for_users : tensor<3x2xf32>
-  // CHECK: %[[OUT_DPS_SHARDED:.*]] = mesh.shard %[[OUT_DPS]] to <@mesh_2, {{\[}}]> annotate_for_users : tensor<2x2xf32>
-  %arg0_sharded = mesh.shard %arg0 to <@mesh_2, [[0]]> : tensor<2x3xf32>
-
+  // CHECK: %[[SIN1_SHARDED1:.*]] = mesh.sharding @mesh_2, {{\[\[}}0]] : !mesh.sharding
+  // CHECK-NEXT:  %[[IN1_SHARDED1:.*]] = mesh.shard %[[IN1]] to %[[SIN1_SHARDED1]]  : tensor<2x3xf32>
+  // CHECK: %[[SIN2_SHARDED:.*]] = mesh.sharding @mesh_2, [] : !mesh.sharding
+  // CHECK-NEXT:  %[[IN1_SHARDED2:.*]] = mesh.shard %[[IN1_SHARDED1]] to %[[SIN2_SHARDED]] annotate_for_users  : tensor<2x3xf32>
+  // CHECK-NEXT:  %[[IN2_SHARDED:.*]] = mesh.shard %[[IN2]] to %[[SIN2_SHARDED]] annotate_for_users  : tensor<3x2xf32>
+  // CHECK-NEXT:  %[[OUT_DPS_SHARDED:.*]] = mesh.shard %[[OUT_DPS]] to %[[SIN2_SHARDED]] annotate_for_users  : tensor<2x2xf32>
+  %sarg0_sharded = mesh.sharding @mesh_2, [[0]] : !mesh.sharding
+  %arg0_sharded = mesh.shard %arg0 to %sarg0_sharded  : tensor<2x3xf32>
   // CHECK: %[[MATMUL:.*]] = linalg.matmul ins(%[[IN1_SHARDED2]], %[[IN2_SHARDED]] : tensor<2x3xf32>, tensor<3x2xf32>)
   // CHECK-SAME: outs(%[[OUT_DPS_SHARDED]] : tensor<2x2xf32>) -> tensor<2x2xf32>
   %res = linalg.matmul ins(%arg0_sharded, %arg1 : tensor<2x3xf32>, tensor<3x2xf32>)
     outs(%out_dps : tensor<2x2xf32>) -> tensor<2x2xf32>
-
-  // CHECK: %[[MATMUL_SHARDED1:.*]] = mesh.shard %[[MATMUL]] to <@mesh_2, {{\[\[}}]]> : tensor<2x2xf32>
-  %res_sharded = mesh.shard %res to <@mesh_2, [[]]> : tensor<2x2xf32>
-
-  // CHECK: return %[[MATMUL_SHARDED1]] : tensor<2x2xf32>
+  // CHECK-NEXT: %[[SRES:.*]] = mesh.sharding @mesh_2, {{\[\[}}]] : !mesh.sharding
+  // CHECK-NEXT: %[[RES:.*]] = mesh.shard %[[MATMUL]] to %[[SRES]] : tensor<2x2xf32>
+  %sres_sharded = mesh.sharding @mesh_2, [[]] : !mesh.sharding
+  %res_sharded = mesh.shard %res to %sres_sharded  : tensor<2x2xf32>
+  // CHECK: return %[[RES]] : tensor<2x2xf32>
   return %res_sharded : tensor<2x2xf32>
 }
 
@@ -167,23 +201,30 @@ func.func @resolve_conflicting_annotations(
 // CHECK-LABEL: func.func @mlp_1d_weight_stationary
 // CHECK-SAME:     %[[ARG0:.*]]: tensor<2x4x8xf32>, %[[ARG1:.*]]: tensor<2x8x32xf32>, %[[ARG2:.*]]: tensor<2x32x8xf32>
 func.func @mlp_1d_weight_stationary(%arg0: tensor<2x4x8xf32>, %arg1: tensor<2x8x32xf32>, %arg2: tensor<2x32x8xf32>) -> tensor<2x4x8xf32> {
-  %0 = mesh.shard %arg0 to <@mesh_1d, [[], [], [0]]> : tensor<2x4x8xf32>
+  %s0 = mesh.sharding @mesh_1d, [[], [], [0]] : !mesh.sharding
+  %0 = mesh.shard %arg0 to %s0  : tensor<2x4x8xf32>
+  // CHECK-DAG: %[[S1:.*]] = mesh.sharding @mesh_1d, {{\[\[}}], [], [0]] : !mesh.sharding
+  // CHECK-DAG: %[[S2:.*]] = mesh.sharding @mesh_1d, {{\[\[}}], [], [0]] : !mesh.sharding
   // CHECK: %[[V0:.*]] = tosa.matmul
   %1 = tosa.matmul %0, %arg1 : (tensor<2x4x8xf32>, tensor<2x8x32xf32>) -> tensor<2x4x32xf32>
-  // CHECK-DAG: %[[V1:.*]] = mesh.shard %[[V0]] to <@mesh_1d, {{\[\[}}], [], [0]]> : tensor<2x4x32xf32>
-  // CHECK-DAG: %[[V2:.*]] = mesh.shard %[[V1]] to <@mesh_1d, {{\[\[}}], [], [0]]> annotate_for_users : tensor<2x4x32xf32>
+  // CHECK-NEXT:  %[[V1:.*]] = mesh.shard %[[V0]] to %[[S2]]  : tensor<2x4x32xf32>
+  // CHECK-NEXT:  %[[V2:.*]] = mesh.shard %[[V1]] to %[[S2]] annotate_for_users  : tensor<2x4x32xf32>
   // CHECK-DAG: %[[V3:.*]] = tosa.sigmoid %[[V2]]
   %2 = tosa.sigmoid %1 : (tensor<2x4x32xf32>) -> tensor<2x4x32xf32>
-  // CHECK-DAG: %[[V4:.*]] = mesh.shard %[[V3]] to <@mesh_1d, {{\[\[}}], [], [0]]> : tensor<2x4x32xf32>
-  // CHECK-DAG: %[[V5:.*]] = mesh.shard %[[V4]] to <@mesh_1d, {{\[\[}}], [], [0]]> annotate_for_users : tensor<2x4x32xf32>
-  // CHECK-DAG: %[[V6:.*]] = mesh.shard %[[ARG2]] to <@mesh_1d, {{\[\[}}], [0]]> annotate_for_users : tensor<2x32x8xf32>
+  // CHECK-NEXT:  %[[V4:.*]] = mesh.shard %[[V3]] to %[[S2]]  : tensor<2x4x32xf32>
+  // CHECK-NEXT:  %[[V5:.*]] = mesh.shard %[[V4]] to %[[S2]] annotate_for_users  : tensor<2x4x32xf32>
+  // CHECK-DAG: %[[S6:.*]] = mesh.sharding @mesh_1d, {{\[\[}}], [0]] : !mesh.sharding
+  // CHECK-NEXT:  %[[V6:.*]] = mesh.shard %[[ARG2]] to %[[S6]] annotate_for_users  : tensor<2x32x8xf32>
   // CHECK-DAG: %[[V7:.*]] = tosa.matmul %[[V5]], %[[V6]]
   %3 = tosa.matmul %2, %arg2 : (tensor<2x4x32xf32>, tensor<2x32x8xf32>) -> tensor<2x4x8xf32>
-  // CHECK-DAG: %[[V8:.*]] = mesh.shard %[[V7]] to <@mesh_1d, {{\[\[}}], [], []], partial = sum[0]> : tensor<2x4x8xf32>
-  %4 = mesh.shard %3 to <@mesh_1d, [[], [], []], partial = sum[0]> : tensor<2x4x8xf32>
-  // CHECK-DAG: %[[V9:.*]] = mesh.shard %[[V8]] to <@mesh_1d, {{\[\[}}], [], [0]]> annotate_for_users : tensor<2x4x8xf32>
-  %5 = mesh.shard %4 to <@mesh_1d, [[], [], [0]]> annotate_for_users : tensor<2x4x8xf32>
-  // CHECK-DAG: return %[[V9]]
+  %s4 = mesh.sharding @mesh_1d, [[], [], []] partial = sum [0] : !mesh.sharding
+  %4 = mesh.shard %3 to %s4  : tensor<2x4x8xf32>
+  // CHECK: %[[S8:.*]] = mesh.sharding @mesh_1d, {{\[\[}}], [], []] partial = sum [0] : !mesh.sharding
+  // CHECK-NEXT:  %[[V8:.*]] = mesh.shard %[[V7]] to %[[S8]] : tensor<2x4x8xf32>
+  %s5 = mesh.sharding @mesh_1d, [[], [], [0]] : !mesh.sharding
+  %5 = mesh.shard %4 to %s5 annotate_for_users  : tensor<2x4x8xf32>
+  // CHECK:  %[[V9:.*]] = mesh.shard %[[V8]] to %[[S1]] annotate_for_users  : tensor<2x4x8xf32>
+  // CHECK-NEXT: return %[[V9]]
   return %5 : tensor<2x4x8xf32>
 }
 
@@ -191,26 +232,37 @@ func.func @mlp_1d_weight_stationary(%arg0: tensor<2x4x8xf32>, %arg1: tensor<2x8x
 // CHECK-LABEL: func.func @mlp_2d_weight_stationary
 // CHECK-SAME:     %[[ARG0:.*]]: tensor<2x4x8xf32>, %[[ARG1:.*]]: tensor<2x8x32xf32>, %[[ARG2:.*]]: tensor<2x32x8xf32>
 func.func @mlp_2d_weight_stationary(%arg0: tensor<2x4x8xf32>, %arg1: tensor<2x8x32xf32>, %arg2: tensor<2x32x8xf32>) -> tensor<2x4x8xf32> {
-  // CHECK-DAG: %[[V0:.*]] = mesh.shard %[[ARG0]] to <@mesh_3d, {{\[\[}}], [], [0, 1, 2]]> : tensor<2x4x8xf32>
-  %0 = mesh.shard %arg0 to <@mesh_3d, [[], [], [0, 1, 2]]> : tensor<2x4x8xf32>
-  // CHECK-DAG: %[[V1:.*]] = mesh.shard %[[V0]] to <@mesh_3d, {{\[\[}}], [], [0]]> annotate_for_users : tensor<2x4x8xf32>
-  // CHECK-DAG: %[[V2:.*]] = mesh.shard %[[ARG1]] to <@mesh_3d, {{\[\[}}], [0], [1, 2]]> annotate_for_users : tensor<2x8x32xf32>
+  // CHECK-DAG: %[[S0:.*]] = mesh.sharding @mesh_3d, {{\[\[}}], [], [0, 1, 2]] : !mesh.sharding
+  // CHECK-NEXT:  %[[V0:.*]] = mesh.shard %[[ARG0]] to %[[S0]]  : tensor<2x4x8xf32>
+  %s0 = mesh.sharding @mesh_3d, [[], [], [0, 1, 2]] : !mesh.sharding
+  %0 = mesh.shard %arg0 to %s0  : tensor<2x4x8xf32>
+  // CHECK-DAG: %[[S1:.*]] = mesh.sharding @mesh_3d, {{\[\[}}], [], [0]] : !mesh.sharding
+  // CHECK-NEXT:  %[[V1:.*]] = mesh.shard %[[V0]] to %[[S1]] annotate_for_users  : tensor<2x4x8xf32>
+  // CHECK-DAG: %[[S2:.*]] = mesh.sharding @mesh_3d, {{\[\[}}], [0], [1, 2]] : !mesh.sharding
+  // CHECK-NEXT:  %[[V2:.*]] = mesh.shard %[[ARG1]] to %[[S2]] annotate_for_users  : tensor<2x8x32xf32>
   // CHECK-DAG: %[[V3:.*]] = tosa.matmul %[[V1]], %[[V2]]
   %1 = tosa.matmul %0, %arg1 : (tensor<2x4x8xf32>, tensor<2x8x32xf32>) -> tensor<2x4x32xf32>
-  // CHECK-DAG: %[[V4:.*]] = mesh.shard %[[V3]] to <@mesh_3d,  {{\[\[}}], [], [1, 2]], partial = sum[0]> : tensor<2x4x32xf32>
-  %2 = mesh.shard %1 to <@mesh_3d, [[], [], [1, 2]], partial = sum[0]> : tensor<2x4x32xf32>
-  // CHECK-DAG: %[[V5:.*]] = mesh.shard %[[V4]] to <@mesh_3d, {{\[\[}}], [], [1, 2]]> annotate_for_users : tensor<2x4x32xf32>
+  // CHECK-DAG: %[[S4:.*]] = mesh.sharding @mesh_3d,  {{\[\[}}], [], [1, 2]] partial = sum [0] : !mesh.sharding
+  // CHECK-NEXT:  %[[V4:.*]] = mesh.shard %[[V3]] to %[[S4]]  : tensor<2x4x32xf32>
+  %s2 = mesh.sharding @mesh_3d, [[], [], [1, 2]] partial = sum [0] : !mesh.sharding
+  %2 = mesh.shard %1 to %s2  : tensor<2x4x32xf32>
+  // CHECK-DAG: %[[S5:.*]] = mesh.sharding @mesh_3d, {{\[\[}}], [], [1, 2]] : !mesh.sharding
+  // CHECK-NEXT:  %[[V5:.*]] = mesh.shard %[[V4]] to %[[S5]] annotate_for_users  : tensor<2x4x32xf32>
   // CHECK-DAG: %[[V6:.*]] = tosa.sigmoid %[[V5]]
   %3 = tosa.sigmoid %2 : (tensor<2x4x32xf32>) -> tensor<2x4x32xf32>
-  // CHECK-DAG: %[[V7:.*]] = mesh.shard %[[V6]] to <@mesh_3d, {{\[\[}}], [], [1, 2]]> : tensor<2x4x32xf32>
-  // CHECK-DAG: %[[V8:.*]] = mesh.shard %[[V7]] to <@mesh_3d, {{\[\[}}], [], [1, 2]]> annotate_for_users : tensor<2x4x32xf32>
-  // CHECK-DAG: %[[V9:.*]] = mesh.shard %[[ARG2]] to <@mesh_3d, {{\[\[}}], [1, 2], [0]]> annotate_for_users : tensor<2x32x8xf32>
+  // CHECK-NEXT:  %[[V7:.*]] = mesh.shard %[[V6]] to %[[S5]]  : tensor<2x4x32xf32>
+  // CHECK-NEXT:  %[[V8:.*]] = mesh.shard %[[V7]] to %[[S5]] annotate_for_users  : tensor<2x4x32xf32>
+  // CHECK-DAG: %[[S9:.*]] = mesh.sharding @mesh_3d, {{\[\[}}], [1, 2], [0]] : !mesh.sharding
+  // CHECK-NEXT:  %[[V9:.*]] = mesh.shard %[[ARG2]] to %[[S9]] annotate_for_users  : tensor<2x32x8xf32>
   // CHECK-DAG: %[[V10:.*]] = tosa.matmul %[[V8]], %[[V9]]
   %4 = tosa.matmul %3, %arg2 : (tensor<2x4x32xf32>, tensor<2x32x8xf32>) -> tensor<2x4x8xf32>
-  // CHECK-DAG: %[[V11:.*]] = mesh.shard %[[V10]] to <@mesh_3d, {{\[\[}}], [], [0]], partial = sum[1, 2]> : tensor<2x4x8xf32>
-  %5 = mesh.shard %4 to <@mesh_3d, [[], [], [0]], partial = sum[1, 2]> : tensor<2x4x8xf32>
-  // CHECK-DAG: %[[V12:.*]] = mesh.shard %[[V11]] to <@mesh_3d, {{\[\[}}], [], [0, 1, 2]]> annotate_for_users : tensor<2x4x8xf32>
-  %6 = mesh.shard %5 to <@mesh_3d, [[], [], [0, 1, 2]]> annotate_for_users : tensor<2x4x8xf32>
+  // CHECK-DAG: %[[S11:.*]] = mesh.sharding @mesh_3d, {{\[\[}}], [], [0]] partial = sum [1, 2] : !mesh.sharding
+  // CHECK-NEXT:  %[[V11:.*]] = mesh.shard %[[V10]] to %[[S11]]  : tensor<2x4x8xf32>
+  %s5 = mesh.sharding @mesh_3d, [[], [], [0]] partial = sum [1, 2] : !mesh.sharding
+  %5 = mesh.shard %4 to %s5  : tensor<2x4x8xf32>
+  // CHECK-NEXT:  %[[V12:.*]] = mesh.shard %[[V11]] to %[[S0]] annotate_for_users  : tensor<2x4x8xf32>
+  %s6 = mesh.sharding @mesh_3d, [[], [], [0, 1, 2]] : !mesh.sharding
+  %6 = mesh.shard %5 to %s6 annotate_for_users  : tensor<2x4x8xf32>
   // CHECK-DAG: return %[[V12]]
   return %6 : tensor<2x4x8xf32>
 }
diff --git a/mlir/test/Dialect/Mesh/simplifications.mlir b/mlir/test/Dialect/Mesh/simplifications.mlir
index d748be82c5a46..2540fbf9510c4 100644
--- a/mlir/test/Dialect/Mesh/simplifications.mlir
+++ b/mlir/test/Dialect/Mesh/simplifications.mlir
@@ -100,8 +100,8 @@ func.func @all_reduce_arith_addf_no_endomorphism_wrong_reduction_kind(
     %arg0: tensor<5xf32>,
     // CHECK-SAME: %[[ARG1:[A-Za-z0-9_]*]]: tensor<5xf32>
     %arg1: tensor<5xf32>) -> tensor<5xf32> {
-  // CHECK: %[[ALL_REDUCE0:[A-Za-z0-9_]*]] = mesh.all_reduce %[[ARG0]] on @mesh0 mesh_axes = [0] reduction = <max>
-  %0 = mesh.all_reduce %arg0 on @mesh0 mesh_axes = [0] reduction = <max>
+  // CHECK: %[[ALL_REDUCE0:[A-Za-z0-9_]*]] = mesh.all_reduce %[[ARG0]] on @mesh0 mesh_axes = [0] reduction = max
+  %0 = mesh.all_reduce %arg0 on @mesh0 mesh_axes = [0] reduction = max
     : tensor<5xf32> -> tensor<5xf32>
   // CHECK: %[[ALL_REDUCE1:[A-Za-z0-9_]*]] = mesh.all_reduce %[[ARG1]] on @mesh0 mesh_axes = [0]
   %1 = mesh.all_reduce %arg1 on @mesh0 mesh_axes = [0]
@@ -138,13 +138,13 @@ func.func @all_reduce_arith_minimumf_endomorphism(
     %arg0: tensor<5xf32>,
     // CHECK-SAME: %[[ARG1:[A-Za-z0-9_]*]]: tensor<5xf32>
     %arg1: tensor<5xf32>) -> tensor<5xf32> {
-  %0 = mesh.all_reduce %arg0 on @mesh0 mesh_axes = [0] reduction = <min>
+  %0 = mesh.all_reduce %arg0 on @mesh0 mesh_axes = [0] reduction = min
     : tensor<5xf32> -> tensor<5xf32>
-  %1 = mesh.all_reduce %arg1 on @mesh0 mesh_axes = [0] reduction = <min>
+  %1 = mesh.all_reduce %arg1 on @mesh0 mesh_axes = [0] reduction = min
     : tensor<5xf32> -> tensor<5xf32>
   // CHECK: %[[ADD_RES:[A-Za-z0-9_]*]] = arith.minimumf %[[ARG0]], %[[ARG1]]
   %2 = arith.minimumf %0, %1 : tensor<5xf32>
-  // CHECK: %[[ALL_REDUCE_RES:[A-Za-z0-9_]*]] = mesh.all_reduce %[[ADD_RES]] on @mesh0 mesh_axes = [0] reduction = <min>
+  // CHECK: %[[ALL_REDUCE_RES:[A-Za-z0-9_]*]] = mesh.all_reduce %[[ADD_RES]] on @mesh0 mesh_axes = [0] reduction = min
   // CHECK: return %[[ALL_REDUCE_RES]]
   return %2 : tensor<5xf32>
 }
@@ -155,13 +155,13 @@ func.func @all_reduce_arith_minsi_endomorphism(
     %arg0: tensor<5xi32>,
     // CHECK-SAME: %[[ARG1:[A-Za-z0-9_]*]]: tensor<5xi32>
     %arg1: tensor<5xi32>) -> tensor<5xi32> {
-  %0 = mesh.all_reduce %arg0 on @mesh0 mesh_axes = [0] reduction = <min>
+  %0 = mesh.all_reduce %arg0 on @mesh0 mesh_axes = [0] reduction = min
     : tensor<5xi32> -> tensor<5xi32>
-  %1 = mesh.all_reduce %arg1 on @mesh0 mesh_axes = [0] reduction = <min>
+  %1 = mesh.all_reduce %arg1 on @mesh0 mesh_axes = [0] reduction = min
     : tensor<5xi32> -> tensor<5xi32>
   // CHECK: %[[ADD_RES:[A-Za-z0-9_]*]] = arith.minsi %[[ARG0]], %[[ARG1]]
   %2 = arith.minsi %0, %1 : tensor<5xi32>
-  // CHECK: %[[ALL_REDUCE_RES:[A-Za-z0-9_]*]] = mesh.all_reduce %[[ADD_RES]] on @mesh0 mesh_axes = [0] reduction = <min>
+  // CHECK: %[[ALL_REDUCE_RES:[A-Za-z0-9_]*]] = mesh.all_reduce %[[ADD_RES]] on @mesh0 mesh_axes = [0] reduction = min
   // CHECK: return %[[ALL_REDUCE_RES]]
   return %2 : tensor<5xi32>
 }
diff --git a/mlir/test/Dialect/Mesh/spmdization.mlir b/mlir/test/Dialect/Mesh/spmdization.mlir
index 6888fa609601d..e51b2a0b5ac40 100644
--- a/mlir/test/Dialect/Mesh/spmdization.mlir
+++ b/mlir/test/Dialect/Mesh/spmdization.mlir
@@ -10,8 +10,10 @@ func.func @full_replication(
   %arg0: tensor<2xi8>
 // CHECK-SAME: -> tensor<2xi8> {
 ) -> tensor<2xi8> {
-  %0 = mesh.shard %arg0 to <@mesh_1d, [[]]> : tensor<2xi8>
-  %1 = mesh.shard %0 to <@mesh_1d, [[]]> annotate_for_users: tensor<2xi8>
+  %s0 = mesh.sharding @mesh_1d, [[]] : !mesh.sharding
+  %0 = mesh.shard %arg0 to %s0  : tensor<2xi8>
+  %s1 = mesh.sharding @mesh_1d, [[]] : !mesh.sharding
+  %1 = mesh.shard %0 to %s1  annotate_for_users : tensor<2xi8>
   // CHECK: return %[[ARG]] : tensor<2xi8>
   return %1 : tensor<2xi8>
 }
@@ -23,9 +25,12 @@ func.func @sharding_triplet(
 // CHECK-SAME: ) -> tensor<2xf32> {
 ) -> tensor<2xf32> {
   // CHECK: %[[ALL_GATHER:.*]] = mesh.all_gather %[[ARG]] on @mesh_1d mesh_axes = [0] gather_axis = 0 : tensor<1xf32> -> tensor<2xf32>
-  %sharding_annotated = mesh.shard %arg0 to <@mesh_1d, [[0]]> : tensor<2xf32>
-  %sharding_annotated_0 = mesh.shard %sharding_annotated to <@mesh_1d, [[0]]> annotate_for_users : tensor<2xf32>
-  %sharding_annotated_1 = mesh.shard %sharding_annotated_0 to <@mesh_1d, [[]]> : tensor<2xf32>
+  %ssharding_annotated = mesh.sharding @mesh_1d, [[0]] : !mesh.sharding
+  %sharding_annotated = mesh.shard %arg0 to %ssharding_annotated  : tensor<2xf32>
+  %ssharding_annotated_0 = mesh.sharding @mesh_1d, [[0]] : !mesh.sharding
+  %sharding_annotated_0 = mesh.shard %sharding_annotated to %ssharding_annotated_0  annotate_for_users : tensor<2xf32>
+  %ssharding_annotated_1 = mesh.sharding @mesh_1d, [[]] : !mesh.sharding
+  %sharding_annotated_1 = mesh.shard %sharding_annotated_0 to %ssharding_annotated_1  : tensor<2xf32>
   // CHECK: return %[[ALL_GATHER]] : tensor<2xf32>
   return %sharding_annotated_1 : tensor<2xf32>
 }
@@ -39,8 +44,10 @@ func.func @move_split_axis(
 ) -> tensor<2x2xi8> {
   // CHECK: %[[ALL_TO_ALL:.*]] = mesh.all_to_all %[[ARG]] on @mesh_1d
   // CHECK-SAME: mesh_axes = [0] split_axis = 1 concat_axis = 0 : tensor<1x2xi8> -> tensor<2x1xi8>
-  %0 = mesh.shard %arg0 to <@mesh_1d, [[0]]> : tensor<2x2xi8>
-  %1 = mesh.shard %0 to <@mesh_1d, [[], [0]]> annotate_for_users: tensor<2x2xi8>
+  %s0 = mesh.sharding @mesh_1d, [[0]] : !mesh.sharding
+  %0 = mesh.shard %arg0 to %s0  : tensor<2x2xi8>
+  %s1 = mesh.sharding @mesh_1d, [[], [0]] : !mesh.sharding
+  %1 = mesh.shard %0 to %s1  annotate_for_users : tensor<2x2xi8>
   // CHECK: return %[[ALL_TO_ALL]] : tensor<2x1xi8>
   return %1 : tensor<2x2xi8>
 }
@@ -63,12 +70,16 @@ func.func @unary_elementwise(
   %arg0: tensor<2xi8>
 // CHECK-SAME: -> tensor<1xi8> {
 ) -> tensor<2xi8> {
-  %0 = mesh.shard %arg0 to <@mesh_1d, [[0]]> : tensor<2xi8>
-  %1 = mesh.shard %0 to <@mesh_1d, [[0]]> annotate_for_users: tensor<2xi8>
+  %s0 = mesh.sharding @mesh_1d, [[0]] : !mesh.sharding
+  %0 = mesh.shard %arg0 to %s0  : tensor<2xi8>
+  %s1 = mesh.sharding @mesh_1d, [[0]] : !mesh.sharding
+  %1 = mesh.shard %0 to %s1  annotate_for_users : tensor<2xi8>
   // CHECK: %[[RES:.*]] = tosa.abs %[[ARG]] : (tensor<1xi8>) -> tensor<1xi8>
   %2 = tosa.abs %1 : (tensor<2xi8>) -> tensor<2xi8>
-  %3 = mesh.shard %2 to <@mesh_1d, [[0]]> : tensor<2xi8>
-  %4 = mesh.shard %3 to <@mesh_1d, [[0]]> annotate_for_users: tensor<2xi8>
+  %s3 = mesh.sharding @mesh_1d, [[0]] : !mesh.sharding
+  %3 = mesh.shard %2 to %s3  : tensor<2xi8>
+  %s4 = mesh.sharding @mesh_1d, [[0]] : !mesh.sharding
+  %4 = mesh.shard %3 to %s4  annotate_for_users : tensor<2xi8>
   // CHECK: return %[[RES]] : tensor<1xi8>
   return %4 : tensor<2xi8>
 }
@@ -82,14 +93,18 @@ func.func @unary_elementwise_with_resharding(
 ) -> tensor<2xi8> {
   // CHECK: %[[SLICE:.*]] = mesh.all_slice %[[ARG]] on @mesh_1d mesh_axes = [0] slice_axis = 0
   // CHECK-SAME: tensor<2xi8> -> tensor<1xi8>
-  %0 = mesh.shard %arg0 to <@mesh_1d, [[]]> : tensor<2xi8>
-  %1 = mesh.shard %0 to <@mesh_1d, [[0]]> annotate_for_users: tensor<2xi8>
+  %s0 = mesh.sharding @mesh_1d, [[]] : !mesh.sharding
+  %0 = mesh.shard %arg0 to %s0  : tensor<2xi8>
+  %s1 = mesh.sharding @mesh_1d, [[0]] : !mesh.sharding
+  %1 = mesh.shard %0 to %s1  annotate_for_users : tensor<2xi8>
   // CHECK: %[[ABS:.*]] = tosa.abs %[[SLICE]] : (tensor<1xi8>) -> tensor<1xi8>
   %2 = tosa.abs %1 : (tensor<2xi8>) -> tensor<2xi8>
   // CHECK: %[[RES:.*]] = mesh.all_gather %[[ABS]] on @mesh_1d
   // CHECK-SAME: mesh_axes = [0] gather_axis = 0 : tensor<1xi8> -> tensor<2xi8>
-  %3 = mesh.shard %2 to <@mesh_1d, [[0]]> : tensor<2xi8>
-  %4 = mesh.shard %3 to <@mesh_1d, [[]]> annotate_for_users: tensor<2xi8>
+  %s3 = mesh.sharding @mesh_1d, [[0]] : !mesh.sharding
+  %3 = mesh.shard %2 to %s3  : tensor<2xi8>
+  %s4 = mesh.sharding @mesh_1d, [[]] : !mesh.sharding
+  %4 = mesh.shard %3 to %s4  annotate_for_users : tensor<2xi8>
   // CHECK: return %[[RES]] : tensor<2xi8>
   return %4 : tensor<2xi8>
 }
@@ -102,14 +117,20 @@ func.func @binary_elementwise(
   %arg1: tensor<2xi8>
 // CHECK-SAME: -> tensor<1xi8> {
 ) -> tensor<2xi8> {
-  %arg0_sharded = mesh.shard %arg0 to <@mesh_1d, [[0]]> : tensor<2xi8>
-  %op_arg0 = mesh.shard %arg0_sharded to <@mesh_1d, [[0]]> annotate_for_users: tensor<2xi8>
-  %arg1_sharded = mesh.shard %arg1 to <@mesh_1d, [[0]]> : tensor<2xi8>
-  %op_arg1 = mesh.shard %arg1_sharded to <@mesh_1d, [[0]]> annotate_for_users: tensor<2xi8>
+  %sarg0_sharded = mesh.sharding @mesh_1d, [[0]] : !mesh.sharding
+  %arg0_sharded = mesh.shard %arg0 to %sarg0_sharded  : tensor<2xi8>
+  %sop_arg0 = mesh.sharding @mesh_1d, [[0]] : !mesh.sharding
+  %op_arg0 = mesh.shard %arg0_sharded to %sop_arg0  annotate_for_users : tensor<2xi8>
+  %sarg1_sharded = mesh.sharding @mesh_1d, [[0]] : !mesh.sharding
+  %arg1_sharded = mesh.shard %arg1 to %sarg1_sharded  : tensor<2xi8>
+  %sop_arg1 = mesh.sharding @mesh_1d, [[0]] : !mesh.sharding
+  %op_arg1 = mesh.shard %arg1_sharded to %sop_arg1  annotate_for_users : tensor<2xi8>
   // CHECK: %[[RES:.*]] = tosa.add %[[ARG0]], %[[ARG1]] : (tensor<1xi8>, tensor<1xi8>) -> tensor<1xi8>
   %op_res = tosa.add %op_arg0, %op_arg1 : (tensor<2xi8>, tensor<2xi8>) -> tensor<2xi8>
-  %op_res_sharded = mesh.shard %op_res to <@mesh_1d, [[0]]> : tensor<2xi8>
-  %res = mesh.shard %op_res_sharded to <@mesh_1d, [[0]]> annotate_for_users: tensor<2xi8>
+  %sop_res_sharded = mesh.sharding @mesh_1d, [[0]] : !mesh.sharding
+  %op_res_sharded = mesh.shard %op_res to %sop_res_sharded  : tensor<2xi8>
+  %sres = mesh.sharding @mesh_1d, [[0]] : !mesh.sharding
+  %res = mesh.shard %op_res_sharded to %sres  annotate_for_users : tensor<2xi8>
   // CHECK: return %[[RES]] : tensor<1xi8>
   return %res : tensor<2xi8>
 }
@@ -127,20 +148,26 @@ func.func @multiple_chained_ops(
 ) -> tensor<2xi8> {
   // CHECK: %[[RESHARD1:.*]] = mesh.all_slice %[[ARG]] on @mesh_1d mesh_axes = [0] slice_axis = 0
   // CHECK-SAME: tensor<2xi8> -> tensor<1xi8>
-  %0 = mesh.shard %arg0 to <@mesh_1d, [[]]> : tensor<2xi8>
-  %1 = mesh.shard %0 to <@mesh_1d, [[0]]> annotate_for_users: tensor<2xi8>
+  %s0 = mesh.sharding @mesh_1d, [[]] : !mesh.sharding
+  %0 = mesh.shard %arg0 to %s0  : tensor<2xi8>
+  %s1 = mesh.sharding @mesh_1d, [[0]] : !mesh.sharding
+  %1 = mesh.shard %0 to %s1  annotate_for_users : tensor<2xi8>
   // CHECK: %[[ABS1:.*]] = tosa.abs %[[RESHARD1]] : (tensor<1xi8>) -> tensor<1xi8>
   %2 = tosa.abs %1 : (tensor<2xi8>) -> tensor<2xi8>
   // CHECK: %[[RESHARD2:.*]] = mesh.all_gather %[[ABS1]] on @mesh_1d
   // CHECK-SAME: mesh_axes = [0] gather_axis = 0 : tensor<1xi8> -> tensor<2xi8>
-  %3 = mesh.shard %2 to <@mesh_1d, [[0]]> : tensor<2xi8>
-  %4 = mesh.shard %3 to <@mesh_1d, [[]]> annotate_for_users: tensor<2xi8>
+  %s3 = mesh.sharding @mesh_1d, [[0]] : !mesh.sharding
+  %3 = mesh.shard %2 to %s3  : tensor<2xi8>
+  %s4 = mesh.sharding @mesh_1d, [[]] : !mesh.sharding
+  %4 = mesh.shard %3 to %s4  annotate_for_users : tensor<2xi8>
   // CHECK: %[[ABS2:.*]] = tosa.abs %[[RESHARD2]] : (tensor<2xi8>) -> tensor<2xi8>
   %5 = tosa.abs %4 : (tensor<2xi8>) -> tensor<2xi8>
   // CHECK: %[[RESHARD3:.*]] = mesh.all_slice %[[ABS2]] on @mesh_1d mesh_axes = [0] slice_axis = 0 : 
   // CHECK-SAME: tensor<2xi8> -> tensor<1xi8>
-  %6 = mesh.shard %5 to <@mesh_1d, [[]]> : tensor<2xi8>
-  %7 = mesh.shard %6 to <@mesh_1d, [[0]]> annotate_for_users: tensor<2xi8>
+  %s6 = mesh.sharding @mesh_1d, [[]] : !mesh.sharding
+  %6 = mesh.shard %5 to %s6  : tensor<2xi8>
+  %s7 = mesh.sharding @mesh_1d, [[0]] : !mesh.sharding
+  %7 = mesh.shard %6 to %s7  annotate_for_users : tensor<2xi8>
   // CHECK: return %[[RESHARD3]] : tensor<1xi8>
   return %7 : tensor<2xi8>
 }
@@ -151,10 +178,12 @@ func.func @incomplete_sharding(
   %arg0: tensor<8x16xf32>
 // CHECK-SAME: -> tensor<4x16xf32> {
 ) -> tensor<8x16xf32> {
-  %0 = mesh.shard %arg0 to <@mesh_1d, [[0]]> annotate_for_users : tensor<8x16xf32>
+  %s0 = mesh.sharding @mesh_1d, [[0]] : !mesh.sharding
+  %0 = mesh.shard %arg0 to %s0  annotate_for_users : tensor<8x16xf32>
   // CHECK: %[[RES:.*]] = tosa.sigmoid %[[ARG]] : (tensor<4x16xf32>) -> tensor<4x16xf32>
   %1 = tosa.sigmoid %0 : (tensor<8x16xf32>) -> tensor<8x16xf32>
-  %2 = mesh.shard %1 to <@mesh_1d, [[0]]> : tensor<8x16xf32>
+  %s2 = mesh.sharding @mesh_1d, [[0]] : !mesh.sharding
+  %2 = mesh.shard %1 to %s2  : tensor<8x16xf32>
   // CHECK: return %[[RES]] : tensor<4x16xf32>
   return %2 : tensor<8x16xf32>
 }
@@ -167,8 +196,9 @@ func.func @update_halo_constraint(
   // CHECK-SAME: -> tensor<11x16xi8> {
 ) -> tensor<32x16xi8> {
   // CHECK: %[[RES:.*]] = mesh.update_halo %[[IN1]] on @mesh_1d_4 halo_sizes = [2, 1] : (tensor<11x16xi8>) -> tensor<11x16xi8>
-  %in1_sharded1 = mesh.shard %in1 to <@mesh_1d_4, [[0]] {<halo_sizes = [2, 1]>}> : tensor<32x16xi8>
-  %in1_sharded2 = mesh.shard %in1_sharded1 to <@mesh_1d_4, [[0]] {<force = true halo_sizes = [2, 1]>}> annotate_for_users: tensor<32x16xi8>
+  %sin1_sharded1 = mesh.sharding @mesh_1d_4, [[0]] halo_sizes = [2, 1] : !mesh.sharding
+  %in1_sharded1 = mesh.shard %in1 to %sin1_sharded1  : tensor<32x16xi8>
+  %in1_sharded2 = mesh.shard %in1_sharded1 to %sin1_sharded1 annotate_for_users force : tensor<32x16xi8>
   // CHECK: return %[[RES]] : tensor<11x16xi8>
   return %in1_sharded2 : tensor<32x16xi8>
 }
@@ -179,64 +209,81 @@ func.func @ew_chain_with_halo(
   %arg0: tensor<8x16xf32>)
   // CHECK-SAME: -> tensor<5x16xf32>
    -> tensor<8x16xf32> {
-  %sharding_annotated = mesh.shard %arg0 to <@mesh_1d_4, [[0]] {<halo_sizes = [2, 1]>}> annotate_for_users : tensor<8x16xf32>
+  %ssharding_annotated = mesh.sharding @mesh_1d_4, [[0]] halo_sizes = [2, 1] : !mesh.sharding
+  %sharding_annotated = mesh.shard %arg0 to %ssharding_annotated  annotate_for_users : tensor<8x16xf32>
   // CHECK: %[[TMP1:.*]] = tosa.tanh %[[IN1]] : (tensor<5x16xf32>) -> tensor<5x16xf32>
   %0 = tosa.tanh %sharding_annotated : (tensor<8x16xf32>) -> tensor<8x16xf32>
-  %sharding_annotated_0 = mesh.shard %0 to <@mesh_1d_4, [[0]] {<halo_sizes = [2, 1]>}> : tensor<8x16xf32>
-  %sharding_annotated_1 = mesh.shard %sharding_annotated_0 to <@mesh_1d_4, [[0]] {<halo_sizes = [2, 1]>}> annotate_for_users : tensor<8x16xf32>
+  %ssharding_annotated_0 = mesh.sharding @mesh_1d_4, [[0]] halo_sizes = [2, 1] : !mesh.sharding
+  %sharding_annotated_0 = mesh.shard %0 to %ssharding_annotated_0  : tensor<8x16xf32>
+  %ssharding_annotated_1 = mesh.sharding @mesh_1d_4, [[0]] halo_sizes = [2, 1] : !mesh.sharding
+  %sharding_annotated_1 = mesh.shard %sharding_annotated_0 to %ssharding_annotated_1  annotate_for_users : tensor<8x16xf32>
   // CHECK-NEXT: %[[TMP2:.*]] = tosa.abs %[[TMP1]] : (tensor<5x16xf32>) -> tensor<5x16xf32>
   %1 = tosa.abs %sharding_annotated_1 : (tensor<8x16xf32>) -> tensor<8x16xf32>
-  %sharding_annotated_2 = mesh.shard %1 to <@mesh_1d_4, [[0]] {<halo_sizes = [2, 1]>}> : tensor<8x16xf32>
-  %sharding_annotated_4 = mesh.shard %sharding_annotated_2 to <@mesh_1d_4, [[0]] {<halo_sizes = [2, 1]>}> annotate_for_users : tensor<8x16xf32>
+  %ssharding_annotated_2 = mesh.sharding @mesh_1d_4, [[0]] halo_sizes = [2, 1] : !mesh.sharding
+  %sharding_annotated_2 = mesh.shard %1 to %ssharding_annotated_2  : tensor<8x16xf32>
+  %ssharding_annotated_4 = mesh.sharding @mesh_1d_4, [[0]] halo_sizes = [2, 1] : !mesh.sharding
+  %sharding_annotated_4 = mesh.shard %sharding_annotated_2 to %ssharding_annotated_4  annotate_for_users : tensor<8x16xf32>
   // CHECK-NEXT: %[[TMP3:.*]] = tosa.negate %[[TMP2]] : (tensor<5x16xf32>) -> tensor<5x16xf32>
   %2 = tosa.negate %sharding_annotated_4 : (tensor<8x16xf32>) -> tensor<8x16xf32>
-  %sharding_annotated_5 = mesh.shard %2 to <@mesh_1d_4, [[0]] {<halo_sizes = [2, 1]>}> : tensor<8x16xf32>
-  %sharding_annotated_6 = mesh.shard %sharding_annotated_5 to <@mesh_1d_4, [[0]] {<halo_sizes = [2, 1]>}> annotate_for_users : tensor<8x16xf32>
+  %ssharding_annotated_5 = mesh.sharding @mesh_1d_4, [[0]] halo_sizes = [2, 1] : !mesh.sharding
+  %sharding_annotated_5 = mesh.shard %2 to %ssharding_annotated_5  : tensor<8x16xf32>
+  %ssharding_annotated_6 = mesh.sharding @mesh_1d_4, [[0]] halo_sizes = [2, 1] : !mesh.sharding
+  %sharding_annotated_6 = mesh.shard %sharding_annotated_5 to %ssharding_annotated_6  annotate_for_users : tensor<8x16xf32>
   // CHECK-NEXT: return %[[TMP3]] : tensor<5x16xf32>
   return %sharding_annotated_6 : tensor<8x16xf32>
 }
 
-#map = affine_map<(d0, d1) -> (d0, d1)>
-// CHECK-LABEL: func @stencil_with_halo
-func.func @stencil_with_halo() -> () {
-  %a = "xxx.empty"() : () -> tensor<32x16xf32>
-  %sc1 = mesh.sharding_constraint sharded_dims = [] halo_sizes = [1, 2] : !mesh.constraint
-  %sa = mesh.shard %a to <@mesh_1d_4, [[0]]>, !mesh.constraint = %sc1 : tensor<32x16xf32>
-  %b = "xxx.empty"() : () -> tensor<8x16xf32>
-  %sc2 = mesh.sharding_constraint sharded_dims = [1, 2, 3, 2] halo_sizes = [] : !mesh.constraint
-  %sb = mesh.shard %b to <@mesh_1d_4, [[0]]>, !mesh.constraint = %sc2 : tensor<8x16xf32>
-
-  %sai1 = mesh.shard %sa to <@mesh_1d_4, [[0]]>, !mesh.constraint = %sc1 annotate_for_users : tensor<32x16xf32>
-  %v1 = "xxx.view"(%sa) {x = 1} : (tensor<32x16xf32>) -> tensor<8x16xf32>
-  %sv1 = mesh.shard %v1 to <@mesh_1d_4, [[0]]>, !mesh.constraint = %sc2 : tensor<8x16xf32>
-
-  %sai2 = mesh.shard %sa to <@mesh_1d_4, [[0]]>, !mesh.constraint = %sc1 annotate_for_users : tensor<32x16xf32>
-  %v2 = "xxx.view"(%sa) {x = 2} : (tensor<32x16xf32>) -> tensor<8x16xf32>
-  %sv2 = mesh.shard %v2 to <@mesh_1d_4, [[0]]>, !mesh.constraint = %sc2 : tensor<8x16xf32>
+// #map = affine_map<(d0, d1) -> (d0, d1)>
+// // CHECK-LABEL: func @stencil_with_halo
+// func.func @stencil_with_halo() -> () {
+//   %a = "xxx.empty"() : () -> tensor<32x16xf32>
+//   %ssa = mesh.sharding @mesh_1d_4, [[0]] halo_sizes = [1, 2] : !mesh.sharding
+//   %sa = mesh.shard %a to %ssa : tensor<32x16xf32>
+//   %b = "xxx.empty"() : () -> tensor<8x16xf32>
+//   %ssb = mesh.sharding @mesh_1d_4, [[0]] sharded_dims_sizes = [1, 2, 3, 2] : !mesh.sharding
+//   %sb = mesh.shard %b to %ssb  : tensor<8x16xf32>
+
+//   %ssai1 = mesh.sharding @mesh_1d_4, [[0]] halo_sizes = [1, 2] : !mesh.sharding
+//   %sai1 = mesh.shard %sa to %ssai1  annotate_for_users : tensor<32x16xf32>
+//   %v1 = "xxx.view"(%sa) {x = 1} : (tensor<32x16xf32>) -> tensor<8x16xf32>
+//   %ssv1 = mesh.sharding @mesh_1d_4, [[0]] sharded_dims_sizes = [1, 2, 3, 2] : !mesh.sharding
+//   %sv1 = mesh.shard %v1 to %ssv1  : tensor<8x16xf32>
+
+//   %ssai2 = mesh.sharding @mesh_1d_4, [[0]] halo_sizes = [1, 2] : !mesh.sharding
+//   %sai2 = mesh.shard %sa to %ssai2  annotate_for_users : tensor<32x16xf32>
+//   %v2 = "xxx.view"(%sa) {x = 2} : (tensor<32x16xf32>) -> tensor<8x16xf32>
+//   %ssv2 = mesh.sharding @mesh_1d_4, [[0]] sharded_dims_sizes = [1, 2, 3, 2] : !mesh.sharding
+//   %sv2 = mesh.shard %v2 to %ssv2  : tensor<8x16xf32>
   
-  %v1i = mesh.shard %sv1 to <@mesh_1d_4, [[0]]>, !mesh.constraint = %sc2 annotate_for_users : tensor<8x16xf32>
-  %v2i = mesh.shard %sv2 to <@mesh_1d_4, [[0]]>, !mesh.constraint = %sc2 annotate_for_users : tensor<8x16xf32>
-  %bo = mesh.shard %sb to <@mesh_1d_4, [[0]]>, !mesh.constraint = %sc2 annotate_for_users : tensor<8x16xf32>
-  %r = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel"]} ins(%v1i, %v2i : tensor<8x16xf32>, tensor<8x16xf32>) outs(%bo : tensor<8x16xf32>) {
-    ^bb0(%in: f32, %in_56: f32, %out: f32):
-      %47 = arith.addf %in, %in_56 : f32
-      linalg.yield %47 : f32
-    } -> tensor<8x16xf32>
-  %sr = mesh.shard %r to <@mesh_1d_4, [[0]]>, !mesh.constraint = %sc2 : tensor<8x16xf32>
-
-  %sai3 = mesh.shard %sa to <@mesh_1d_4, [[0]]>, !mesh.constraint = %sc1 annotate_for_users : tensor<32x16xf32>
-  %sri = mesh.shard %sr to <@mesh_1d_4, [[0]]>, !mesh.constraint = %sc2 annotate_for_users : tensor<8x16xf32>
-  "xxx.insert_slice"(%sai3, %sri) : (tensor<32x16xf32>, tensor<8x16xf32>) -> ()
-  %sc3 = mesh.sharding_constraint sharded_dims = [] halo_sizes = [1, 2] force : !mesh.constraint
-  %sai4 = mesh.shard %sa to <@mesh_1d_4, [[0]]>, !mesh.constraint = %sc3 : tensor<32x16xf32>
-
-  return
-}
-// CHECK: %[[V0:.*]] = "xxx.empty"() : () -> tensor<11x16xf32>
-// CHECK-NEXT: %[[V1:.*]] = "xxx.empty"() : () -> tensor<?x16xf32>
-// CHECK-NEXT: %[[V2:.*]] = "xxx.view"([[V0]]) {x = 1 : i64} : (tensor<11x16xf32>) -> tensor<?x16xf32>
-// CHECK-NEXT: %[[V3:.*]] = "xxx.view"([[V0]]) {x = 2 : i64} : (tensor<11x16xf32>) -> tensor<?x16xf32>
-// CHECK-NEXT: %[[V4:.*]] = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel"]} ins([[V2]], [[V3]] : tensor<?x16xf32>, tensor<?x16xf32>) outs([[V1]] : tensor<?x16xf32>) {
-// CHECK: "xxx.insert_slice"([[V0]], [[V4]]) : (tensor<11x16xf32>, tensor<?x16xf32>) -> ()
-// CHECK-NEXT: %update_halo = mesh.update_halo [[V0]] on @mesh_1d_4 halo_sizes = [1, 2] : (tensor<11x16xf32>) -> tensor<11x16xf32>
-// CHECK-NEXT: return
+//   %sv1i = mesh.sharding @mesh_1d_4, [[0]] sharded_dims_sizes = [1, 2, 3, 2] : !mesh.sharding
+//   %v1i = mesh.shard %sv1 to %sv1i  annotate_for_users : tensor<8x16xf32>
+//   %sv2i = mesh.sharding @mesh_1d_4, [[0]] sharded_dims_sizes = [1, 2, 3, 2] : !mesh.sharding
+//   %v2i = mesh.shard %sv2 to %sv2i  annotate_for_users : tensor<8x16xf32>
+//   %sbo = mesh.sharding @mesh_1d_4, [[0]] sharded_dims_sizes = [1, 2, 3, 2] : !mesh.sharding
+//   %bo = mesh.shard %sb to %sbo  annotate_for_users : tensor<8x16xf32>
+//   %r = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel"]} ins(%v1i, %v2i : tensor<8x16xf32>, tensor<8x16xf32>) outs(%bo : tensor<8x16xf32>) {
+//     ^bb0(%in: f32, %in_56: f32, %out: f32):
+//       %47 = arith.addf %in, %in_56 : f32
+//       linalg.yield %47 : f32
+//     } -> tensor<8x16xf32>
+//   %ssr = mesh.sharding @mesh_1d_4, [[0]] sharded_dims_sizes = [1, 2, 3, 2] : !mesh.sharding
+//   %sr = mesh.shard %r to %ssr  : tensor<8x16xf32>
+
+//   %ssai3 = mesh.sharding @mesh_1d_4, [[0]] halo_sizes = [1, 2] : !mesh.sharding
+//   %sai3 = mesh.shard %sa to %ssai3  annotate_for_users : tensor<32x16xf32>
+//   %ssri = mesh.sharding @mesh_1d_4, [[0]] sharded_dims_sizes = [1, 2, 3, 2] : !mesh.sharding
+//   %sri = mesh.shard %sr to %ssri  annotate_for_users : tensor<8x16xf32>
+//   "xxx.insert_slice"(%sai3, %sri) : (tensor<32x16xf32>, tensor<8x16xf32>) -> ()
+//   %ssai4 = mesh.sharding @mesh_1d_4, [[0]] halo_sizes = [1, 2] : !mesh.sharding
+//   %sai4 = mesh.shard %sa to %ssai4 force : tensor<32x16xf32>
+
+//   return
+// }
+// COM: CHECK: %[[V0:.*]] = "xxx.empty"() : () -> tensor<11x16xf32>
+// COM: CHECK-NEXT: %[[V1:.*]] = "xxx.empty"() : () -> tensor<?x16xf32>
+// COM: CHECK-NEXT: %[[V2:.*]] = "xxx.view"([[V0]]) {x = 1 : i64} : (tensor<11x16xf32>) -> tensor<?x16xf32>
+// COM: CHECK-NEXT: %[[V3:.*]] = "xxx.view"([[V0]]) {x = 2 : i64} : (tensor<11x16xf32>) -> tensor<?x16xf32>
+// COM: CHECK-NEXT: %[[V4:.*]] = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel"]} ins([[V2]], [[V3]] : tensor<?x16xf32>, tensor<?x16xf32>) outs([[V1]] : tensor<?x16xf32>) {
+// COM: CHECK: "xxx.insert_slice"([[V0]], [[V4]]) : (tensor<11x16xf32>, tensor<?x16xf32>) -> ()
+// COM: CHECK-NEXT: %update_halo = mesh.update_halo [[V0]] on @mesh_1d_4 halo_sizes = [1, 2] : (tensor<11x16xf32>) -> tensor<11x16xf32>
+// COM: CHECK-NEXT: return

>From 3e925b21faa1c9a56cf3bcd824cb7a5874f85d89 Mon Sep 17 00:00:00 2001
From: Frank Schlimbach <frank.schlimbach at intel.com>
Date: Tue, 9 Jul 2024 10:08:26 +0200
Subject: [PATCH 3/8] fixing use of force in spmdization, adding update_halo
 tests

---
 .../Dialect/Mesh/Transforms/Spmdization.cpp   |  2 +-
 mlir/test/Dialect/Mesh/spmdization.mlir       | 88 +++++++------------
 2 files changed, 32 insertions(+), 58 deletions(-)

diff --git a/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp b/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp
index 2ba7e9998b49c..8d3a355a6bbe5 100644
--- a/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp
+++ b/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp
@@ -494,7 +494,7 @@ reshardOn1DMesh(ImplicitLocOpBuilder &builder, MeshOp mesh,
       handlePartialAxesDuringResharding(builder, sourceSharding, targetSharding,
                                         sourceShard);
 
-  if (reducedSourceSharding == targetSharding) {
+  if (!force && reducedSourceSharding == targetSharding) {
     return reducedSourceShard;
   }
 
diff --git a/mlir/test/Dialect/Mesh/spmdization.mlir b/mlir/test/Dialect/Mesh/spmdization.mlir
index e51b2a0b5ac40..0b1c143c25334 100644
--- a/mlir/test/Dialect/Mesh/spmdization.mlir
+++ b/mlir/test/Dialect/Mesh/spmdization.mlir
@@ -189,8 +189,8 @@ func.func @incomplete_sharding(
 }
 
 mesh.mesh @mesh_1d_4(shape = 4)
-// CHECK-LABEL: func @update_halo_constraint
-func.func @update_halo_constraint(
+// CHECK-LABEL: func @update_halo_static
+func.func @update_halo_static(
   // CHECK-SAME: %[[IN1:[A-Za-z0-9_]+]]: tensor<11x16xi8>
   %in1: tensor<32x16xi8>
   // CHECK-SAME: -> tensor<11x16xi8> {
@@ -203,6 +203,35 @@ func.func @update_halo_constraint(
   return %in1_sharded2 : tensor<32x16xi8>
 }
 
+// CHECK-LABEL: func @update_halo_dynamic
+func.func @update_halo_dynamic(
+  // CHECK-SAME: %[[IN1:[A-Za-z0-9_]+]]: tensor<?x16xi8>
+  %in1: tensor<?x16xi8>
+  // CHECK-SAME: -> tensor<?x16xi8> {
+) -> tensor<?x16xi8> {
+  // CHECK: %[[RES:.*]] = mesh.update_halo %[[IN1]] on @mesh_1d_4 halo_sizes = [2, 1] : (tensor<?x16xi8>) -> tensor<?x16xi8>
+  %sin1_sharded1 = mesh.sharding @mesh_1d_4, [[0]] halo_sizes = [2, 1] : !mesh.sharding
+  %in1_sharded1 = mesh.shard %in1 to %sin1_sharded1  : tensor<?x16xi8>
+  %in1_sharded2 = mesh.shard %in1_sharded1 to %sin1_sharded1 annotate_for_users force : tensor<?x16xi8>
+  // CHECK: return %[[RES]] : tensor<?x16xi8>
+  return %in1_sharded2 : tensor<?x16xi8>
+}
+
+mesh.mesh @mesh_1d_dyn(shape = ?)
+// CHECK-LABEL: func @update_halo_dynamic_mesh
+func.func @update_halo_dynamic_mesh(
+  // CHECK-SAME: %[[IN1:[A-Za-z0-9_]+]]: tensor<?x16xi8>
+  %in1: tensor<32x16xi8>
+  // CHECK-SAME: -> tensor<?x16xi8> {
+) -> tensor<32x16xi8> {
+  // CHECK: %[[RES:.*]] = mesh.update_halo %[[IN1]] on @mesh_1d_dyn halo_sizes = [2, 1] : (tensor<?x16xi8>) -> tensor<?x16xi8>
+  %sin1_sharded1 = mesh.sharding @mesh_1d_dyn, [[0]] halo_sizes = [2, 1] : !mesh.sharding
+  %in1_sharded1 = mesh.shard %in1 to %sin1_sharded1  : tensor<32x16xi8>
+  %in1_sharded2 = mesh.shard %in1_sharded1 to %sin1_sharded1 annotate_for_users force : tensor<32x16xi8>
+  // CHECK: return %[[RES]] : tensor<?x16xi8>
+  return %in1_sharded2 : tensor<32x16xi8>
+}
+
 // CHECK-LABEL: func @ew_chain_with_halo
 func.func @ew_chain_with_halo(
   // CHECK-SAME: %[[IN1:[A-Za-z0-9_]+]]: tensor<5x16xf32>
@@ -232,58 +261,3 @@ func.func @ew_chain_with_halo(
   // CHECK-NEXT: return %[[TMP3]] : tensor<5x16xf32>
   return %sharding_annotated_6 : tensor<8x16xf32>
 }
-
-// #map = affine_map<(d0, d1) -> (d0, d1)>
-// // CHECK-LABEL: func @stencil_with_halo
-// func.func @stencil_with_halo() -> () {
-//   %a = "xxx.empty"() : () -> tensor<32x16xf32>
-//   %ssa = mesh.sharding @mesh_1d_4, [[0]] halo_sizes = [1, 2] : !mesh.sharding
-//   %sa = mesh.shard %a to %ssa : tensor<32x16xf32>
-//   %b = "xxx.empty"() : () -> tensor<8x16xf32>
-//   %ssb = mesh.sharding @mesh_1d_4, [[0]] sharded_dims_sizes = [1, 2, 3, 2] : !mesh.sharding
-//   %sb = mesh.shard %b to %ssb  : tensor<8x16xf32>
-
-//   %ssai1 = mesh.sharding @mesh_1d_4, [[0]] halo_sizes = [1, 2] : !mesh.sharding
-//   %sai1 = mesh.shard %sa to %ssai1  annotate_for_users : tensor<32x16xf32>
-//   %v1 = "xxx.view"(%sa) {x = 1} : (tensor<32x16xf32>) -> tensor<8x16xf32>
-//   %ssv1 = mesh.sharding @mesh_1d_4, [[0]] sharded_dims_sizes = [1, 2, 3, 2] : !mesh.sharding
-//   %sv1 = mesh.shard %v1 to %ssv1  : tensor<8x16xf32>
-
-//   %ssai2 = mesh.sharding @mesh_1d_4, [[0]] halo_sizes = [1, 2] : !mesh.sharding
-//   %sai2 = mesh.shard %sa to %ssai2  annotate_for_users : tensor<32x16xf32>
-//   %v2 = "xxx.view"(%sa) {x = 2} : (tensor<32x16xf32>) -> tensor<8x16xf32>
-//   %ssv2 = mesh.sharding @mesh_1d_4, [[0]] sharded_dims_sizes = [1, 2, 3, 2] : !mesh.sharding
-//   %sv2 = mesh.shard %v2 to %ssv2  : tensor<8x16xf32>
-  
-//   %sv1i = mesh.sharding @mesh_1d_4, [[0]] sharded_dims_sizes = [1, 2, 3, 2] : !mesh.sharding
-//   %v1i = mesh.shard %sv1 to %sv1i  annotate_for_users : tensor<8x16xf32>
-//   %sv2i = mesh.sharding @mesh_1d_4, [[0]] sharded_dims_sizes = [1, 2, 3, 2] : !mesh.sharding
-//   %v2i = mesh.shard %sv2 to %sv2i  annotate_for_users : tensor<8x16xf32>
-//   %sbo = mesh.sharding @mesh_1d_4, [[0]] sharded_dims_sizes = [1, 2, 3, 2] : !mesh.sharding
-//   %bo = mesh.shard %sb to %sbo  annotate_for_users : tensor<8x16xf32>
-//   %r = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel"]} ins(%v1i, %v2i : tensor<8x16xf32>, tensor<8x16xf32>) outs(%bo : tensor<8x16xf32>) {
-//     ^bb0(%in: f32, %in_56: f32, %out: f32):
-//       %47 = arith.addf %in, %in_56 : f32
-//       linalg.yield %47 : f32
-//     } -> tensor<8x16xf32>
-//   %ssr = mesh.sharding @mesh_1d_4, [[0]] sharded_dims_sizes = [1, 2, 3, 2] : !mesh.sharding
-//   %sr = mesh.shard %r to %ssr  : tensor<8x16xf32>
-
-//   %ssai3 = mesh.sharding @mesh_1d_4, [[0]] halo_sizes = [1, 2] : !mesh.sharding
-//   %sai3 = mesh.shard %sa to %ssai3  annotate_for_users : tensor<32x16xf32>
-//   %ssri = mesh.sharding @mesh_1d_4, [[0]] sharded_dims_sizes = [1, 2, 3, 2] : !mesh.sharding
-//   %sri = mesh.shard %sr to %ssri  annotate_for_users : tensor<8x16xf32>
-//   "xxx.insert_slice"(%sai3, %sri) : (tensor<32x16xf32>, tensor<8x16xf32>) -> ()
-//   %ssai4 = mesh.sharding @mesh_1d_4, [[0]] halo_sizes = [1, 2] : !mesh.sharding
-//   %sai4 = mesh.shard %sa to %ssai4 force : tensor<32x16xf32>
-
-//   return
-// }
-// COM: CHECK: %[[V0:.*]] = "xxx.empty"() : () -> tensor<11x16xf32>
-// COM: CHECK-NEXT: %[[V1:.*]] = "xxx.empty"() : () -> tensor<?x16xf32>
-// COM: CHECK-NEXT: %[[V2:.*]] = "xxx.view"([[V0]]) {x = 1 : i64} : (tensor<11x16xf32>) -> tensor<?x16xf32>
-// COM: CHECK-NEXT: %[[V3:.*]] = "xxx.view"([[V0]]) {x = 2 : i64} : (tensor<11x16xf32>) -> tensor<?x16xf32>
-// COM: CHECK-NEXT: %[[V4:.*]] = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel"]} ins([[V2]], [[V3]] : tensor<?x16xf32>, tensor<?x16xf32>) outs([[V1]] : tensor<?x16xf32>) {
-// COM: CHECK: "xxx.insert_slice"([[V0]], [[V4]]) : (tensor<11x16xf32>, tensor<?x16xf32>) -> ()
-// COM: CHECK-NEXT: %update_halo = mesh.update_halo [[V0]] on @mesh_1d_4 halo_sizes = [1, 2] : (tensor<11x16xf32>) -> tensor<11x16xf32>
-// COM: CHECK-NEXT: return

>From eb76c34ca35bb00f0dc4470f583e2ea1c3f6f0d0 Mon Sep 17 00:00:00 2001
From: Frank Schlimbach <frank.schlimbach at intel.com>
Date: Tue, 9 Jul 2024 11:11:40 +0200
Subject: [PATCH 4/8] fixing tests

---
 .../Linalg/mesh-sharding-propagation.mlir     | 24 +++--
 .../test/Dialect/Linalg/mesh-spmdization.mlir | 97 +++++++++++--------
 2 files changed, 71 insertions(+), 50 deletions(-)

diff --git a/mlir/test/Dialect/Linalg/mesh-sharding-propagation.mlir b/mlir/test/Dialect/Linalg/mesh-sharding-propagation.mlir
index 59fd548dc2ef2..0e79011624ca3 100644
--- a/mlir/test/Dialect/Linalg/mesh-sharding-propagation.mlir
+++ b/mlir/test/Dialect/Linalg/mesh-sharding-propagation.mlir
@@ -14,20 +14,28 @@ func.func @matmul_shard_prallel_axis(
   // CHECK-SAME: %[[DPS_OUT:[A-Za-z0-9_]+]]: tensor<2x2xf32>
   %out_dps: tensor<2x2xf32>
 ) -> tensor<2x2xf32> {
-  // CHECK: %[[IN1_ANNOTATED_0:.*]] = mesh.shard %[[IN1]] to <@mesh_2, {{\[}}[0]]> : tensor<2x3xf32>
-  // CHECK: %[[IN1_ANNOTATED_1:.*]] = mesh.shard %[[IN1_ANNOTATED_0]] to <@mesh_2, {{\[}}[0]]> annotate_for_users : tensor<2x3xf32>
-  // CHECK: %[[IN2_ANNOTATED:.*]] = mesh.shard %[[IN2]] to <@mesh_2, []> annotate_for_users : tensor<3x2xf32>
-  // CHECK: %[[DPS_OUT_ANNOTATED:.*]] = mesh.shard %[[DPS_OUT]] to <@mesh_2, {{\[}}[0]]> annotate_for_users : tensor<2x2xf32>
-  %arg0_sharded = mesh.shard %arg0 to <@mesh_2, [[0]]> : tensor<2x3xf32>
+  // CHECK: %[[SIN1_ANNOTATED_0:.*]] = mesh.sharding @mesh_2, {{\[}}[0]] : !mesh.sharding
+  // CHECK-NEXT: %[[IN1_ANNOTATED_0:.*]] = mesh.shard %[[IN1]] to %[[SIN1_ANNOTATED_0]] : tensor<2x3xf32>
+  // CHECK: %[[SIN1_ANNOTATED_1:.*]] = mesh.sharding @mesh_2, {{\[}}[0]] : !mesh.sharding
+  // CHECK-NEXT: %[[IN1_ANNOTATED_1:.*]] = mesh.shard %[[IN1_ANNOTATED_0]] to %[[SIN1_ANNOTATED_1]] annotate_for_users : tensor<2x3xf32>
+  // CHECK: %[[SIN2_ANNOTATED:.*]] = mesh.sharding @mesh_2, [] : !mesh.sharding
+  // CHECK-NEXT: %[[IN2_ANNOTATED:.*]] = mesh.shard %[[IN2]] to %[[SIN2_ANNOTATED]] annotate_for_users : tensor<3x2xf32>
+  // CHECK: %[[SDPS_OUT_ANNOTATED:.*]] = mesh.sharding @mesh_2, {{\[}}[0]] : !mesh.sharding
+  // CHECK-NEXT: %[[DPS_OUT_ANNOTATED:.*]] = mesh.shard %[[DPS_OUT]] to %[[SDPS_OUT_ANNOTATED]] annotate_for_users : tensor<2x2xf32>
+  %sarg0_sharded = mesh.sharding @mesh_2, [[0]] : !mesh.sharding
+  %arg0_sharded = mesh.shard %arg0 to %sarg0_sharded : tensor<2x3xf32>
 
   // CHECK: %[[RES:.*]] = linalg.matmul ins(%[[IN1_ANNOTATED_1]], %[[IN2_ANNOTATED]] : tensor<2x3xf32>, tensor<3x2xf32>)
   // CHECK-SAME:  outs(%[[DPS_OUT_ANNOTATED]] : tensor<2x2xf32>) -> tensor<2x2xf32>
   %res = linalg.matmul ins(%arg0_sharded, %arg1 : tensor<2x3xf32>, tensor<3x2xf32>)
     outs(%out_dps : tensor<2x2xf32>) -> tensor<2x2xf32>
 
-  // CHECK: %[[RES_ANNOTATED_0:.*]] = mesh.shard %[[RES]] to <@mesh_2, {{\[}}[0]]> : tensor<2x2xf32>
-  // CHECK: %[[RES_ANNOTATED_1:.*]] = mesh.shard %[[RES_ANNOTATED_0]] to <@mesh_2, {{\[}}[]]> annotate_for_users : tensor<2x2xf32>
-  %res_sharded = mesh.shard %res to <@mesh_2, [[]]> annotate_for_users : tensor<2x2xf32>
+  // CHECK: %[[SRES_ANNOTATED_0:.*]] = mesh.sharding @mesh_2, {{\[}}[0]] : !mesh.sharding
+  // CHECK-NEXT: %[[RES_ANNOTATED_0:.*]] = mesh.shard %[[RES]] to %[[SRES_ANNOTATED_0]] : tensor<2x2xf32>
+  // CHECK: %[[SRES_ANNOTATED_1:.*]] = mesh.sharding @mesh_2, {{\[}}[]] : !mesh.sharding
+  // CHECK-NEXT: %[[RES_ANNOTATED_1:.*]] = mesh.shard %[[RES_ANNOTATED_0]] to %[[SRES_ANNOTATED_1]] annotate_for_users : tensor<2x2xf32>
+  %sres_sharded = mesh.sharding @mesh_2, [[]] : !mesh.sharding
+  %res_sharded = mesh.shard %res to %sres_sharded annotate_for_users : tensor<2x2xf32>
 
   // CHECK: return %[[RES_ANNOTATED_1]] : tensor<2x2xf32>
   return %res_sharded : tensor<2x2xf32>
diff --git a/mlir/test/Dialect/Linalg/mesh-spmdization.mlir b/mlir/test/Dialect/Linalg/mesh-spmdization.mlir
index b105f5007d532..1c5d6b8ff016a 100644
--- a/mlir/test/Dialect/Linalg/mesh-spmdization.mlir
+++ b/mlir/test/Dialect/Linalg/mesh-spmdization.mlir
@@ -18,12 +18,13 @@ func.func @elementwise_static_1d_mesh_static_1d_tensor(
   %dps_out: tensor<2xi8>
 // CHECK-SAME: -> tensor<1xi8> {
 ) -> tensor<2xi8> {
-  %in1_shared1 = mesh.shard %in1 to <@mesh_1d, [[0]]> : tensor<2xi8>
-  %in1_shared2 = mesh.shard %in1_shared1 to <@mesh_1d, [[0]]> annotate_for_users: tensor<2xi8>
-  %in2_shared1 = mesh.shard %in2 to <@mesh_1d, [[0]]> : tensor<2xi8>
-  %in2_shared2 = mesh.shard %in2_shared1 to <@mesh_1d, [[0]]> annotate_for_users: tensor<2xi8>
-  %dps_out_shared1 = mesh.shard %dps_out to <@mesh_1d, [[0]]> : tensor<2xi8>
-  %dps_out_shared2 = mesh.shard %dps_out_shared1 to <@mesh_1d, [[0]]> annotate_for_users: tensor<2xi8>
+  %sharding = mesh.sharding @mesh_1d, [[0]]  : !mesh.sharding
+  %in1_sharded1 = mesh.shard %in1 to %sharding  : tensor<2xi8>
+  %in1_sharded2 = mesh.shard %in1_sharded1 to %sharding annotate_for_users : tensor<2xi8>
+  %in2_sharded1 = mesh.shard %in2 to %sharding : tensor<2xi8>
+  %in2_sharded2 = mesh.shard %in2_sharded1 to %sharding annotate_for_users : tensor<2xi8>
+  %dps_out_sharded1 = mesh.shard %dps_out to %sharding : tensor<2xi8>
+  %dps_out_shared2 = mesh.shard %dps_out_sharded1 to %sharding annotate_for_users : tensor<2xi8>
   // CHECK: %[[RES:.*]] = linalg.generic {
   // CHECK-SAME: indexing_maps = [#[[$MAP_IDENTITY_1D]], #[[$MAP_IDENTITY_1D]], #[[$MAP_IDENTITY_1D]]],
   // CHECK-SAME: iterator_types = ["parallel"]}
@@ -32,14 +33,14 @@ func.func @elementwise_static_1d_mesh_static_1d_tensor(
   %res = linalg.generic {
       indexing_maps = [#map_identity_1d, #map_identity_1d, #map_identity_1d],
       iterator_types = ["parallel"]
-    } ins(%in1_shared2, %in2_shared2 : tensor<2xi8>, tensor<2xi8>)
+    } ins(%in1_sharded2, %in2_sharded2 : tensor<2xi8>, tensor<2xi8>)
       outs(%dps_out_shared2 : tensor<2xi8>) {
     ^bb0(%in1_scalar: i8, %in2_scalar: i8, %out: i8):
       %res_scalar = arith.muli %in1_scalar, %in2_scalar : i8
       linalg.yield %res_scalar : i8
     } -> tensor<2xi8>
-  %res_shared1 = mesh.shard %res to <@mesh_1d, [[0]]> : tensor<2xi8>
-  %res_shared2 = mesh.shard %res_shared1 to <@mesh_1d, [[0]]> annotate_for_users: tensor<2xi8>
+  %res_sharded1 = mesh.shard %res to %sharding : tensor<2xi8>
+  %res_shared2 = mesh.shard %res_sharded1 to %sharding annotate_for_users : tensor<2xi8>
   // CHECK: return %[[RES]] : tensor<1xi8>
   return %res_shared2 : tensor<2xi8>
 }
@@ -58,20 +59,22 @@ func.func @matmul_1d_mesh_static_tensors_parallel_iterator_sharding(
   %dps_out: tensor<4x8xi8>
 // CHECK-SAME: -> tensor<1x8xi8> {
 ) -> tensor<4x8xi8> {
-  %in1_shared1 = mesh.shard %in1 to <@mesh_1d, [[0]]> : tensor<4x3xi8>
-  %in1_shared2 = mesh.shard %in1_shared1 to <@mesh_1d, [[0]]> annotate_for_users: tensor<4x3xi8>
-  %in2_shared1 = mesh.shard %in2 to <@mesh_1d, [[]]> : tensor<3x8xi8>
-  %in2_shared2 = mesh.shard %in2_shared1 to <@mesh_1d, [[]]> annotate_for_users: tensor<3x8xi8>
-  %dps_out_shared1 = mesh.shard %dps_out to <@mesh_1d, [[0]]> : tensor<4x8xi8>
-  %dps_out_shared2 = mesh.shard %dps_out_shared1 to <@mesh_1d, [[0]]> annotate_for_users: tensor<4x8xi8>
+  %sharding = mesh.sharding @mesh_1d, [[0]] : !mesh.sharding
+  %in1_shared1 = mesh.shard %in1 to %sharding : tensor<4x3xi8>
+  %in1_shared2 = mesh.shard %in1_shared1 to %sharding annotate_for_users : tensor<4x3xi8>
+  %sharding2 = mesh.sharding @mesh_1d, [[]] : !mesh.sharding
+  %in2_shared1 = mesh.shard %in2 to %sharding2 : tensor<3x8xi8>
+  %in2_shared2 = mesh.shard %in2_shared1 to %sharding2 annotate_for_users : tensor<3x8xi8>
+  %dps_out_shared1 = mesh.shard %dps_out to %sharding : tensor<4x8xi8>
+  %dps_out_shared2 = mesh.shard %dps_out_shared1 to %sharding annotate_for_users : tensor<4x8xi8>
   // CHECK: %[[RES:.*]] = linalg.matmul
   // CHECK-SAME: ins(%[[IN1]], %[[IN2]] : tensor<1x3xi8>, tensor<3x8xi8>)
   // CHECK-SAME: outs(%[[DPS_OUT]] : tensor<1x8xi8>)
   // CHECK-SAME: -> tensor<1x8xi8>
   %res = linalg.matmul ins(%in1_shared2, %in2_shared2 : tensor<4x3xi8>, tensor<3x8xi8>)
       outs(%dps_out_shared2 : tensor<4x8xi8>) -> tensor<4x8xi8>
-  %res_shared1 = mesh.shard %res to <@mesh_1d, [[0]]> : tensor<4x8xi8>
-  %res_shared2 = mesh.shard %res_shared1 to <@mesh_1d, [[0]]> annotate_for_users: tensor<4x8xi8>
+  %res_shared1 = mesh.shard %res to %sharding : tensor<4x8xi8>
+  %res_shared2 = mesh.shard %res_shared1 to %sharding annotate_for_users : tensor<4x8xi8>
   // CHECK: return %[[RES]] : tensor<1x8xi8>
   return %res_shared2 : tensor<4x8xi8>
 }
@@ -90,12 +93,15 @@ func.func @matmul_1d_mesh_static_tensors_reduction_iterator_sharding(
   %dps_out: tensor<4x8xi8>
 // CHECK-SAME: -> tensor<4x8xi8> {
 ) -> tensor<4x8xi8> {
-  %in1_shared1 = mesh.shard %in1 to <@mesh_1d, [[], [0]]> : tensor<4x6xi8>
-  %in1_shared2 = mesh.shard %in1_shared1 to <@mesh_1d, [[], [0]]> annotate_for_users: tensor<4x6xi8>
-  %in2_shared1 = mesh.shard %in2 to <@mesh_1d, [[0]]> : tensor<6x8xi8>
-  %in2_shared2 = mesh.shard %in2_shared1 to <@mesh_1d, [[0]]> annotate_for_users: tensor<6x8xi8>
-  %dps_out_shared1 = mesh.shard %dps_out to <@mesh_1d, [[]]> : tensor<4x8xi8>
-  %dps_out_shared2 = mesh.shard %dps_out_shared1 to <@mesh_1d, [[]]> annotate_for_users: tensor<4x8xi8>
+  %sharding = mesh.sharding @mesh_1d, [[], [0]] : !mesh.sharding
+  %in1_shared1 = mesh.shard %in1 to %sharding : tensor<4x6xi8>
+  %in1_shared2 = mesh.shard %in1_shared1 to %sharding annotate_for_users : tensor<4x6xi8>
+  %sharding2 = mesh.sharding @mesh_1d, [[0]] : !mesh.sharding
+  %in2_shared1 = mesh.shard %in2 to %sharding2 : tensor<6x8xi8>
+  %in2_shared2 = mesh.shard %in2_shared1 to %sharding2 annotate_for_users : tensor<6x8xi8>
+  %sharding3 = mesh.sharding @mesh_1d, [[]] : !mesh.sharding
+  %dps_out_shared1 = mesh.shard %dps_out to %sharding3 : tensor<4x8xi8>
+  %dps_out_shared2 = mesh.shard %dps_out_shared1 to %sharding3 annotate_for_users : tensor<4x8xi8>
   // CHECK-DAG:  %[[C0:.*]] = arith.constant 0 : index
   // CHECK-DAG:  %[[C0_I8:.*]] = arith.constant 0 : i8
   // CHECK-DAG:  %[[PROCESS_IDX:.*]] = mesh.process_multi_index on @mesh_1d axes = [0] : index
@@ -114,8 +120,8 @@ func.func @matmul_1d_mesh_static_tensors_reduction_iterator_sharding(
   // CHECK:      %[[ALL_REDUCED:.*]] = mesh.all_reduce %[[SHARDED_MATMUL]] on @mesh_1d mesh_axes = [0] : tensor<4x8xi8> -> tensor<4x8xi8>
   %res = linalg.matmul ins(%in1_shared2, %in2_shared2 : tensor<4x6xi8>, tensor<6x8xi8>)
       outs(%dps_out_shared2 : tensor<4x8xi8>) -> tensor<4x8xi8>
-  %res_shared1 = mesh.shard %res to <@mesh_1d, [[]]> : tensor<4x8xi8>
-  %res_shared2 = mesh.shard %res_shared1 to <@mesh_1d, [[]]> annotate_for_users: tensor<4x8xi8>
+  %res_shared1 = mesh.shard %res to %sharding3 : tensor<4x8xi8>
+  %res_shared2 = mesh.shard %res_shared1 to %sharding3 annotate_for_users : tensor<4x8xi8>
   // CHECK:      return %[[ALL_REDUCED]] : tensor<4x8xi8>
   return %res_shared2 : tensor<4x8xi8>
 }
@@ -134,12 +140,16 @@ func.func @matmul_1d_mesh_static_tensors_reduction_iterator_sharding_with_partia
   %dps_out: tensor<4x8xi8>
 // CHECK-SAME: -> tensor<4x8xi8> {
 ) -> tensor<4x8xi8> {
-  %in1_shared1 = mesh.shard %in1 to <@mesh_1d, [[], [0]]> : tensor<4x6xi8>
-  %in1_shared2 = mesh.shard %in1_shared1 to <@mesh_1d, [[], [0]]> annotate_for_users: tensor<4x6xi8>
-  %in2_shared1 = mesh.shard %in2 to <@mesh_1d, [[0]]> : tensor<6x8xi8>
-  %in2_shared2 = mesh.shard %in2_shared1 to <@mesh_1d, [[0]]> annotate_for_users: tensor<6x8xi8>
-  %dps_out_shared1 = mesh.shard %dps_out to <@mesh_1d, [[]]> : tensor<4x8xi8>
-  %dps_out_shared2 = mesh.shard %dps_out_shared1 to <@mesh_1d, [[]]> annotate_for_users: tensor<4x8xi8>
+  %sharding = mesh.sharding @mesh_1d, [[], [0]] : !mesh.sharding
+  %in1_shared1 = mesh.shard %in1 to %sharding : tensor<4x6xi8>
+  %in1_shared2 = mesh.shard %in1_shared1 to %sharding annotate_for_users : tensor<4x6xi8>
+  %sharding2 = mesh.sharding @mesh_1d, [[0]] : !mesh.sharding
+  %in2_shared1 = mesh.shard %in2 to %sharding2 : tensor<6x8xi8>
+  %in2_shared2 = mesh.shard %in2_shared1 to %sharding2 annotate_for_users : tensor<6x8xi8>
+  %sharding3 = mesh.sharding @mesh_1d, [[]] : !mesh.sharding
+  %dps_out_shared1 = mesh.shard %dps_out to %sharding3 : tensor<4x8xi8>
+  %sdps_out_shared2 = mesh.sharding @mesh_1d, [[]] : !mesh.sharding
+  %dps_out_shared2 = mesh.shard %dps_out_shared1 to %sharding3 annotate_for_users : tensor<4x8xi8>
   // CHECK-DAG:  %[[C0:.*]] = arith.constant 0 : index
   // CHECK-DAG:  %[[C0_I8:.*]] = arith.constant 0 : i8
   // CHECK-DAG:  %[[PROCESS_IDX:.*]] = mesh.process_multi_index on @mesh_1d axes = [0] : index
@@ -157,8 +167,9 @@ func.func @matmul_1d_mesh_static_tensors_reduction_iterator_sharding_with_partia
   // CHECK-SAME:     outs(%[[DPS_INIT_OPERAND]] : tensor<4x8xi8>) -> tensor<4x8xi8>
   %res = linalg.matmul ins(%in1_shared2, %in2_shared2 : tensor<4x6xi8>, tensor<6x8xi8>)
       outs(%dps_out_shared2 : tensor<4x8xi8>) -> tensor<4x8xi8>
-  %res_shared1 = mesh.shard %res to <@mesh_1d, [[]], partial = sum[0]> : tensor<4x8xi8>
-  %res_shared2 = mesh.shard %res_shared1 to <@mesh_1d, [[]], partial = sum[0]> annotate_for_users: tensor<4x8xi8>
+  %sharding4 = mesh.sharding @mesh_1d, [[]] partial = sum[0] : !mesh.sharding
+  %res_shared1 = mesh.shard %res to %sharding4 : tensor<4x8xi8>
+  %res_shared2 = mesh.shard %res_shared1 to %sharding4 annotate_for_users : tensor<4x8xi8>
   // CHECK:      return %[[SHARDED_MATMUL]] : tensor<4x8xi8>
   return %res_shared2 : tensor<4x8xi8>
 }
@@ -177,14 +188,16 @@ func.func @matmul_1d_mesh_static_tensors_parallel_iterator_unsplit_last_axis(
   %dps_out: tensor<4x8xi8>
   // CHECK-SAME: -> tensor<4x8xi8> {
 ) -> tensor<4x8xi8> {
-  %in1_replicated1 = mesh.shard %in1 to <@mesh_1d, [[], []]> : tensor<4x6xi8>
-  %in1_replicated2 = mesh.shard %in1_replicated1 to <@mesh_1d, [[], []]> annotate_for_users : tensor<4x6xi8>
+  %sharding1 = mesh.sharding @mesh_1d, [[], []] : !mesh.sharding
+  %in1_replicated1 = mesh.shard %in1 to %sharding1 : tensor<4x6xi8>
+  %in1_replicated2 = mesh.shard %in1_replicated1 to %sharding1 annotate_for_users : tensor<4x6xi8>
   // CHECK: %[[ALL_SLICE1:.*]] = mesh.all_slice %[[IN2]] on @mesh_1d mesh_axes = [0] slice_axis = 1
-  %in2_replicated = mesh.shard %in2 to <@mesh_1d, [[], []]> : tensor<6x8xi8>
-  %in2_sharded = mesh.shard %in2_replicated to <@mesh_1d, [[], [0]]> annotate_for_users : tensor<6x8xi8>
+  %in2_replicated = mesh.shard %in2 to %sharding1 : tensor<6x8xi8>
+  %sharding2 = mesh.sharding @mesh_1d, [[], [0]] : !mesh.sharding
+  %in2_sharded = mesh.shard %in2_replicated to %sharding2 annotate_for_users : tensor<6x8xi8>
   // CHECK: %[[ALL_SLICE2:.*]] = mesh.all_slice %[[DPS_OUT]] on @mesh_1d mesh_axes = [0] slice_axis = 1
-  %dps_out_replicated = mesh.shard %dps_out to <@mesh_1d, [[], []]> : tensor<4x8xi8>
-  %dps_out_sharded = mesh.shard %dps_out_replicated to <@mesh_1d, [[], [0]]> annotate_for_users: tensor<4x8xi8>
+  %dps_out_replicated = mesh.shard %dps_out to %sharding1 : tensor<4x8xi8>
+  %dps_out_sharded = mesh.shard %dps_out_replicated to %sharding2 annotate_for_users : tensor<4x8xi8>
   // CHECK: %[[MATMUL_RES:.*]] = linalg.matmul
   // CHECK-SAME: ins(%[[IN1]], %[[ALL_SLICE1]] : tensor<4x6xi8>, tensor<6x2xi8>)
   // CHECK-SAME: outs(%[[ALL_SLICE2]] : tensor<4x2xi8>)
@@ -192,8 +205,8 @@ func.func @matmul_1d_mesh_static_tensors_parallel_iterator_unsplit_last_axis(
   %res = linalg.matmul ins(%in1_replicated2, %in2_sharded : tensor<4x6xi8>, tensor<6x8xi8>)
       outs(%dps_out_sharded : tensor<4x8xi8>) -> tensor<4x8xi8>
   // CHECK: %[[ALL_GATHER:.*]] = mesh.all_gather %[[MATMUL_RES]] on @mesh_1d mesh_axes = [0] gather_axis = 1 : tensor<4x2xi8> -> tensor<4x8xi8>
-  %res_sharded = mesh.shard %res to <@mesh_1d, [[], [0]]> : tensor<4x8xi8>
-  %res_replicated = mesh.shard %res_sharded to <@mesh_1d, [[], []]> annotate_for_users: tensor<4x8xi8>
+  %res_sharded = mesh.shard %res to %sharding2 : tensor<4x8xi8>
+  %res_replicated = mesh.shard %res_sharded to %sharding1 annotate_for_users : tensor<4x8xi8>
   // CHECK: return %[[ALL_GATHER]] : tensor<4x8xi8>
   return %res_replicated : tensor<4x8xi8>
-}
\ No newline at end of file
+}

>From cd6c6312b2a305cf5e38c9b5b47ea998184241de Mon Sep 17 00:00:00 2001
From: Frank Schlimbach <frank.schlimbach at intel.com>
Date: Tue, 9 Jul 2024 13:00:32 +0200
Subject: [PATCH 5/8] improved descriptions; more tests

---
 mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td  | 104 +++++++++++-------
 mlir/test/Dialect/Mesh/ops.mlir               |  34 ++++++
 .../test/Dialect/Tensor/mesh-spmdization.mlir |  38 +++++++
 3 files changed, 138 insertions(+), 38 deletions(-)
 create mode 100644 mlir/test/Dialect/Tensor/mesh-spmdization.mlir

diff --git a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td
index 17d7100b58165..b707015cc3684 100644
--- a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td
+++ b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td
@@ -166,9 +166,9 @@ def Mesh_ShardingOp : Mesh_Op<"sharding", [
   ]> {
   let summary = "Define a sharding of a tensor.";
   let description = [{
-    The MeshSharding is used in a `mesh.shard` operation.
-    It specifies how a tensor is sharded and distributed across the process
-    mesh.
+    The MeshSharding specifies how a tensor is sharded and distributed across the
+    process mesh. It is typically used in a `mesh.shard` operation.
+    The operation has the following attributes and operands:
 
     1. `mesh`: this attribute is a FlatSymbolRefAttr that refers to the device
     mesh where the distributed tensor is placed. The symbol must resolve to a
@@ -179,18 +179,18 @@ def Mesh_ShardingOp : Mesh_Op<"sharding", [
     its value is [x, y], it indicates that the tensor's i-th dimension is splitted
     along the x and y axes of the device mesh.
 
-    3. `partial_axes`: if not empty, this signifies that the tensor is partial
+    3. [Optional] `partial_axes`: if not empty, this signifies that the tensor is partial
     one along the specified mesh axes. An all-reduce should be applied to obtain
     the complete tensor, with reduction type being specified by `partial_type`.
 
-    4. `partial_type`: indicates the reduction type of the possible all-reduce
+    4. [Optional] `partial_type`: indicates the reduction type of the possible all-reduce
     op. It has 4 possible values:
     `generic`: is not an allowed value inside a shard attribute.
 
     5. [Optional] Sizes of halos to be added for each sharded tensor dimension.
     `halo_sizes`is provided as a flattened 1d array of i64s, 2 values for each sharded dimension.
     `halo_sizes` = [1, 2] means that the first sharded dimension gets an additional
-    halo of size 1 at the start of the dimension and a halo size is 2 at the end.
+    halo of size 1 at the start of the first dimension and a halo of size 2 at its end.
     `halo_sizes` = [1, 2, 2, 3] defines halos for the first 2 sharded dimensions
     e.g. the first sharded dimension gets [1,2] halos and the seconds gets [2,3] halos.
     `?` indicates dynamic halo sizes.
@@ -204,36 +204,43 @@ def Mesh_ShardingOp : Mesh_Op<"sharding", [
     shard of shape 16x24x32.
     `?` indicates dynamic shard dimensions.
     
-    
     `halo_sizes` and `sharded_dims_sizes` are mutually exclusive.
 
-    Example:
+    Examples:
 
     ```
     mesh.mesh @mesh0(shape = 2x2x4)
+    mesh.mesh @mesh1d_4(shape = 4)
 
     // The tensor is fully replicated on @mesh0.
     // Currently, there must be at least one sub-array present in axes, even
     // if it's empty. Otherwise, a parsing error will occur.
-    #mesh.shard<@mesh0, [[]]>
+    %sharding0 = mesh.sharding @mesh0, [[]]
 
     // The tensor is sharded on the first dimension along axis 0 of @mesh0
-    #mesh.shard<@mesh0, [[0]]>
+    %sharding1 = mesh.sharding @mesh0, [[0]]
 
-    // The tensor is sharded on the first dimension along axis 0 of @mesh0 and
+    // The tensor is sharded on its first dimension along axis 0 of @mesh0 and
     // it is also a partial_sum along mesh axis 1.
-    #mesh.shard<@mesh0, [[0], []], partial = sum[1]>
+    %sharding2 = mesh.sharding @mesh0, [[0], []] partial = sum[1]
 
-    // The tensor is sharded on the first dimension along axis 0 of @mesh0 and
+    // The tensor is sharded on its first dimension along axis 0 of @mesh0 and
     // it is also a partial_max along mesh axis 1.
-    #mesh.shard<@mesh0, [[0]], partial = max[1]>
+    %sharding3 = mesh.sharding @mesh0, [[0]] partial = max[1]
 
-    // Could be used in the attribute of mesh.shard op
-    %0 = mesh.shard %arg0 to <@mesh0, [[0]]> : tensor<4x8xf32>
+    // Could be used for a mesh.shard op
+    %sharded0 = mesh.shard %arg0 to %sharding3 : tensor<4x8xf32>
 
-    // The tensor is sharded on the first dimension along axis 0 of @mesh0 and
+    // The tensor is sharded on its first dimension along axis 0 of @mesh0
     // and it has halo-sizes of 1 and 2 on the sharded dim.
-    %0 = mesh.shard %arg0 to <@mesh0, [[0]] {<halo_sizes = [1, 2]>}> : tensor<4x8xf32>
+    %halo_sharding = mesh.sharding @mesh0, [[0]] halo_sizes = [1, 2]
+    %sharded1 = mesh.shard %arg0 to %halo_sharding : tensor<4x8xf32>
+    
+    // The tensor is sharded on its second dimension along axis 0 of @mesh1d_4
+    // and it has pre-defined shard sizes. The shards of the devices will have
+    // the following shapes: [4x2, 4x3, 4x4, 4x5]
+    %sharding4 = mesh.sharding @mesh1d_4, [[], [0]] sharded_dims_sizes = [2, 3, 4, 5]
+    %sharded2 = mesh.shard %arg0 to %sharding4 : tensor<4x14xf32>
     ```
   }];
     
@@ -289,7 +296,7 @@ def Mesh_ShardShapeOp : Mesh_Op<"shard_shape", [Pure]> {
   );
   let results = (outs Variadic<Index>:$result);
   let assemblyFormat = [{
-      $shape $sharding $device attr-dict `:` type($result)
+      custom<DimensionList>($shape) $sharding $device attr-dict `:` type($result)
   }];
   let builders = [
     OpBuilder<(ins "ArrayRef<int64_t>":$shape, "Value":$sharding, "Value":$device)>
@@ -304,14 +311,15 @@ def Mesh_ShardOp : Mesh_Op<"shard", [
   let summary = "Annotate on how a tensor is sharded across a mesh.";
   let description = [{
     The mesh.shard operation is designed to specify and guide the sharding
-    behavior of a tensor value across a mesh topology. This operation has one
-    operand and two attributes:
+    behavior of a tensor value across a mesh topology. This operation has two
+    operands and two optional attributes:
 
     1. `input`: This operand represents the tensor value that needs to be
     annotated for sharding.
 
     2. `sharding`: This attribute is type of `MeshShardingType`, which is the core data
-    structure to represent distribution of a tensor on a mesh.
+    structure to represent distribution of a tensor on a mesh. It is typically defined
+    by a `mesh.sharding` operation.
 
     3. `annotate_for_users`: A unit attribute addressing the scenario when a
     tensor's sharding annotation differs based on its context of use (either as
@@ -320,19 +328,28 @@ def Mesh_ShardOp : Mesh_Op<"shard", [
     as an operand in subsequent operations. If not, the sharding applies to the
     operation that defines the tensor value.
 
-    4. `force`: A unit attribute requesting an explicit sharding of the data not
-    allowing to be optimizied away. This is useful in the presence of halos and
-    inplace semantics.
+    4. `force`: A unit attribute requesting an explicit sharding of the data,
+    therefore not allowing it to be optimized away. This is useful in the presence
+    of halos and inplace semantics.
 
     Example:
     ```
     func.func @only_result_annotated(%arg0 : tensor<4x8xf32>) -> () {
-      %0 = mesh.shard %arg0 to <@mesh0, [[0]]> : tensor<4x8xf32>
+      %sharding = mesh.sharding @mesh0, [[0]] : !mesh.sharding
+      %0 = mesh.shard %arg0 to %sharding : tensor<4x8xf32>
       ...
     }
 
     func.func @only_operand_annotated(%arg0 : tensor<4x8xf32>) -> () {
-      %0 = mesh.shard %arg0 to <@mesh0, [[0]]> annotate_for_users : tensor<4x8xf32>
+      %sharding = mesh.sharding @mesh0, [[0]] : !mesh.sharding
+      %0 = mesh.shard %arg0 to %sharding annotate_for_users : tensor<4x8xf32>
+      ...
+    }
+    
+    func.func @two_operands_annotated(%arg0 : tensor<4x8xf32>, %arg1 : tensor<16x8xf32>) -> () {
+      %sharding = mesh.sharding @mesh0, [[0]] : !mesh.sharding
+      %0 = mesh.shard %arg0 to %sharding annotate_for_users : tensor<4x8xf32>
+      %1 = mesh.shard %arg1 to %sharding annotate_for_users : tensor<16x8xf32>
       ...
     }
 
@@ -341,9 +358,12 @@ def Mesh_ShardOp : Mesh_Op<"shard", [
     // operand of op2
     func.func @both_result_and_multi_operands_annotated(
         %arg0 : tensor<4x8xf32>) -> () {
-      %0 = mesh.shard %arg0 to <@mesh0, [[0]]> : tensor<4x8xf32>
-      %1 = mesh.shard %0 to <@mesh0, [[1]]> annotate_for_users : tensor<4x8xf32>
-      %2 = mesh.shard %0 to <@mesh0, [[2]]> annotate_for_users : tensor<4x8xf32>
+      %sharding = mesh.sharding @mesh0, [[0]] : !mesh.sharding
+      %0 = mesh.shard %arg0 to %sharding : tensor<4x8xf32>
+      %sharding1 = mesh.sharding @mesh0, [[1]] : !mesh.sharding
+      %1 = mesh.shard %0 to %sharding1 annotate_for_users : tensor<4x8xf32>
+      %sharding2 = mesh.sharding @mesh0, [[2]] : !mesh.sharding
+      %2 = mesh.shard %0 to %sharding2 annotate_for_users : tensor<4x8xf32>
       "op0"(%1) : ...
       "op1"(%2) : ...
       ...
@@ -354,29 +374,37 @@ def Mesh_ShardOp : Mesh_Op<"shard", [
     ```
     func.func @annotate_on_same_result_with_different_sharding(
         %arg0 : tensor<4x8xf32>) -> () {
-      %0 = mesh.shard %arg0 to <@mesh0, [[0]]> : tensor<4x8xf32>
-      %1 = mesh.shard %0 to <@mesh0, [[1]]> : tensor<4x8xf32>
+      %sharding1 = mesh.sharding @mesh0, [[0]] : !mesh.sharding
+      %sharding2 = mesh.sharding @mesh0, [[1]] : !mesh.sharding
+      %0 = mesh.shard %arg0 to %sharding1 : tensor<4x8xf32>
+      %1 = mesh.shard %0 to %sharding2 : tensor<4x8xf32>
       ...
     }
 
     func.func @annotate_on_same_result_same_value_with_different_sharding(
         %arg0 : tensor<4x8xf32>) -> () {
-      %0 = mesh.shard %arg0 to <@mesh0, [[0]]> : tensor<4x8xf32>
-      %1 = mesh.shard %arg0 to <@mesh0, [[1]]> : tensor<4x8xf32>
+      %sharding1 = mesh.sharding @mesh0, [[0]] : !mesh.sharding
+      %sharding2 = mesh.sharding @mesh0, [[1]] : !mesh.sharding
+      %0 = mesh.shard %arg0 to %sharding1 : tensor<4x8xf32>
+      %1 = mesh.shard %arg0 to %sharding2 : tensor<4x8xf32>
       ...
     }
 
     func.func @annotate_on_same_operand_with_different_sharding(
         %arg0 : tensor<4x8xf32>) -> () {
-      %0 = mesh.shard %arg0 to <@mesh0, [[0]]> annotate_for_users : tensor<4x8xf32>
-      %1 = mesh.shard %0 to <@mesh0, [[1]]> annotate_for_users : tensor<4x8xf32>
+      %sharding1 = mesh.sharding @mesh0, [[0]] : !mesh.sharding
+      %sharding2 = mesh.sharding @mesh0, [[1]] : !mesh.sharding
+      %0 = mesh.shard %arg0 to %sharding1 annotate_for_users : tensor<4x8xf32>
+      %1 = mesh.shard %0 to %sharding2 annotate_for_users : tensor<4x8xf32>
       ...
     }
 
     func.func @result_annotated_after_operand(
         %arg0 : tensor<4x8xf32>) -> () {
-      %0 = mesh.shard %arg0 to <@mesh0, [[0]]> annotate_for_users : tensor<4x8xf32>
-      %1 = mesh.shard %0 to <@mesh0, [[1]]> : tensor<4x8xf32>
+      %sharding1 = mesh.sharding @mesh0, [[0]] : !mesh.sharding
+      %sharding2 = mesh.sharding @mesh0, [[1]] : !mesh.sharding
+      %0 = mesh.shard %arg0 to %sharding1 annotate_for_users : tensor<4x8xf32>
+      %1 = mesh.shard %0 to %sharding2 : tensor<4x8xf32>
       ...
     }
     ```
diff --git a/mlir/test/Dialect/Mesh/ops.mlir b/mlir/test/Dialect/Mesh/ops.mlir
index 74d8a086d5b34..8c9c505321c80 100644
--- a/mlir/test/Dialect/Mesh/ops.mlir
+++ b/mlir/test/Dialect/Mesh/ops.mlir
@@ -129,6 +129,40 @@ func.func @mesh_shard_op_two_users(%arg0 : tensor<4x8xf32>) ->
   return %1, %2 : tensor<4x8xf32>, tensor<4x8xf32>
 }
 
+// CHECK-LABEL: func @mesh_shard_halo_sizes
+func.func @mesh_shard_halo_sizes() -> () {
+  // CHECK: %[[C3:.*]] = arith.constant 3 : i64
+  %c3 = arith.constant 3 : i64
+  // CHECK: mesh.sharding @mesh4, {{\[\[}}0]] halo_sizes = [1, 4] : !mesh.sharding
+  %sharding1 = mesh.sharding @mesh4, [[0]] halo_sizes = [1, 4] : !mesh.sharding
+  // CHECK: mesh.sharding @mesh4, {{\[\[}}0]] halo_sizes = [4, %[[C3]]] : !mesh.sharding
+  %sharding2 = mesh.sharding @mesh4, [[0]] halo_sizes = [4, %c3] : !mesh.sharding
+  return
+}
+
+// CHECK-LABEL: func @mesh_shard_dim_sizes
+func.func @mesh_shard_dim_sizes() -> () {
+  // CHECK: %[[C3:.*]] = arith.constant 3 : i64
+  %c3 = arith.constant 3 : i64
+  // CHECK: mesh.sharding @mesh4, {{\[\[}}0]] sharded_dims_sizes = [1, 4, 2] : !mesh.sharding
+  %sharding1 = mesh.sharding @mesh4, [[0]] sharded_dims_sizes = [1, 4, 2] : !mesh.sharding
+  // CHECK: mesh.sharding @mesh4, {{\[\[}}0]] sharded_dims_sizes = [4, %[[C3]], 1] : !mesh.sharding
+  %sharding2 = mesh.sharding @mesh4, [[0]] sharded_dims_sizes = [4, %c3, 1] : !mesh.sharding
+  return
+}
+
+// CHECK-LABEL: func @mesh_shard_op_force
+// CHECK-SAME: %[[ARG:.*]]: tensor<4x8xf32>
+func.func @mesh_shard_op_force(%arg0 : tensor<4x8xf32>) -> (tensor<4x8xf32>, tensor<4x8xf32>) {
+  // CHECK-NEXT: %[[S:.*]] = mesh.sharding @mesh0, {{\[\[}}]] : !mesh.sharding
+  %s = mesh.sharding @mesh0, [[]] : !mesh.sharding
+  // CHECK-NEXT: mesh.shard %[[ARG]] to %[[S]] force : tensor<4x8xf32>
+  %1 = mesh.shard %arg0 to %s force : tensor<4x8xf32>
+  // CHECK-NEXT: mesh.shard %[[ARG]] to %[[S]] annotate_for_users force : tensor<4x8xf32>
+  %2 = mesh.shard %arg0 to %s annotate_for_users force : tensor<4x8xf32>
+  return %1, %2 : tensor<4x8xf32>, tensor<4x8xf32>
+}
+
 // CHECK-LABEL: func @mesh_shape
 func.func @mesh_shape() -> (index, index) {
   // CHECK: %[[RES:.*]]:2 = mesh.mesh_shape @mesh0 axes = [0, 1] : index, index
diff --git a/mlir/test/Dialect/Tensor/mesh-spmdization.mlir b/mlir/test/Dialect/Tensor/mesh-spmdization.mlir
new file mode 100644
index 0000000000000..b1fd8e6a423b1
--- /dev/null
+++ b/mlir/test/Dialect/Tensor/mesh-spmdization.mlir
@@ -0,0 +1,38 @@
+mesh.mesh @mesh_1d_4(shape = 4)
+
+// CHECK-LABEL: func @tensor_empty_static_sharded_dims_sizes
+func.func @tensor_empty_static_sharded_dims_sizes() -> () {
+  %b = tensor.empty() : tensor<8x16xf32>
+  %sharding = mesh.sharding @mesh_1d_4, [[0]] sharded_dims_sizes = [1, 3, 3, 1] : !mesh.sharding
+  %sharded= mesh.shard %b to %sharding : tensor<8x16xf32>
+  // CHECK:  %[[sharding:.*]] = mesh.sharding @mesh_1d_4, {{\[\[}}0]] sharded_dims_sizes = [1, 3, 3, 1] : !mesh.sharding
+  // CHECK:  %[[proc_linear_idx:.*]] = mesh.process_linear_index on @mesh_1d_4 : index
+  // CHECK:  %[[V0:.*]]:2 = mesh.shard_shape 8x16 %[[sharding]] %[[proc_linear_idx]] : index, index
+  // CHECK:  tensor.empty(%[[V0]]#0) : tensor<?x16xf32>
+
+  return
+}
+
+// CHECK-LABEL: func @tensor_empty_dynamic_sharded_dims_sizes
+// CHECK-SAME: %[[A0:.*]]: index
+func.func @tensor_empty_dynamic_sharded_dims_sizes(%arg0 : index) -> () {
+  %b = tensor.empty(%arg0) : tensor<8x?xf32>
+  %sharding = mesh.sharding @mesh_1d_4, [[0]] sharded_dims_sizes = [1, 3, 3, 1] : !mesh.sharding
+  %sharded= mesh.shard %b to %sharding : tensor<8x?xf32>
+  // CHECK:  %[[sharding:.*]] = mesh.sharding @mesh_1d_4, {{\[\[}}0]] sharded_dims_sizes = [1, 3, 3, 1] : !mesh.sharding
+  // CHECK:  %[[proc_linear_idx:.*]] = mesh.process_linear_index on @mesh_1d_4 : index
+  // CHECK:  %[[V0:.*]]:2 = mesh.shard_shape 8x? %[[sharding]] %[[proc_linear_idx]] : index, index
+  // CHECK:  tensor.empty(%[[V0]]#0, %[[A0]]) : tensor<?x?xf32>
+
+  return
+}
+
+// CHECK-LABEL: func @tensor_empty_same_static_dims_sizes
+func.func @tensor_empty_same_static_dims_sizes() -> () {
+  %b = tensor.empty() : tensor<8x16xf32>
+  %sharding = mesh.sharding @mesh_1d_4, [[0]] sharded_dims_sizes = [4, 4, 4, 4] : !mesh.sharding
+  %sharded= mesh.shard %b to %sharding : tensor<8x16xf32>
+  // CHECK-NEXT:  tensor.empty() : tensor<4x16xf32>
+
+  return
+}

>From 2926d8c54dfc5f910d9f64702d12e0e2413e04af Mon Sep 17 00:00:00 2001
From: Frank Schlimbach <frank.schlimbach at intel.com>
Date: Tue, 9 Jul 2024 13:08:51 +0200
Subject: [PATCH 6/8] adding test mesh_shape

---
 mlir/test/Dialect/Mesh/ops.mlir | 17 +++++++++++++++--
 1 file changed, 15 insertions(+), 2 deletions(-)

diff --git a/mlir/test/Dialect/Mesh/ops.mlir b/mlir/test/Dialect/Mesh/ops.mlir
index 8c9c505321c80..629011771b482 100644
--- a/mlir/test/Dialect/Mesh/ops.mlir
+++ b/mlir/test/Dialect/Mesh/ops.mlir
@@ -140,8 +140,8 @@ func.func @mesh_shard_halo_sizes() -> () {
   return
 }
 
-// CHECK-LABEL: func @mesh_shard_dim_sizes
-func.func @mesh_shard_dim_sizes() -> () {
+// CHECK-LABEL: func @mesh_shard_dims_sizes
+func.func @mesh_shard_dims_sizes() -> () {
   // CHECK: %[[C3:.*]] = arith.constant 3 : i64
   %c3 = arith.constant 3 : i64
   // CHECK: mesh.sharding @mesh4, {{\[\[}}0]] sharded_dims_sizes = [1, 4, 2] : !mesh.sharding
@@ -163,6 +163,19 @@ func.func @mesh_shard_op_force(%arg0 : tensor<4x8xf32>) -> (tensor<4x8xf32>, ten
   return %1, %2 : tensor<4x8xf32>, tensor<4x8xf32>
 }
 
+// CHECK-LABEL: func @mesh_shard_shape
+func.func @mesh_shard_shape() {
+  // CHECK: %[[C3:.*]] = arith.constant 3 : index
+  %c3 = arith.constant 3 : index
+  // CHECK-NEXT: %[[S:.*]] = mesh.sharding @mesh0, {{\[\[}}]] : !mesh.sharding
+  %s = mesh.sharding @mesh0, [[]] : !mesh.sharding
+  // CHECK-NEXT: mesh.shard_shape 8x? %[[S]] %[[C3]] : index, index
+  %shp:2 = mesh.shard_shape 8x? %s %c3 : index, index
+  // CHECK-NEXT: mesh.shard_shape 8x4 %[[S]] %[[C3]] : index, index
+  %shp1:2 = mesh.shard_shape 8x4 %s %c3 : index, index
+  return
+}
+
 // CHECK-LABEL: func @mesh_shape
 func.func @mesh_shape() -> (index, index) {
   // CHECK: %[[RES:.*]]:2 = mesh.mesh_shape @mesh0 axes = [0, 1] : index, index

>From 1f46e21327a49ba08a045872f88ee3e629d6072d Mon Sep 17 00:00:00 2001
From: Frank Schlimbach <frank.schlimbach at intel.com>
Date: Tue, 9 Jul 2024 13:28:33 +0200
Subject: [PATCH 7/8] clang-format

---
 mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h   |  89 ++++---
 .../Mesh/Interfaces/ShardingInterface.h       |  21 +-
 .../Mesh/Interfaces/ShardingInterfaceImpl.h   |  12 +-
 .../mlir/Interfaces/InferTypeOpInterface.h    |   2 +-
 .../Transforms/MeshShardingInterfaceImpl.cpp  |   7 +-
 mlir/lib/Dialect/Mesh/IR/MeshOps.cpp          | 240 ++++++++++--------
 .../Mesh/Interfaces/ShardingInterface.cpp     |  67 ++---
 .../Mesh/Transforms/ShardingPropagation.cpp   |  12 +-
 .../Dialect/Mesh/Transforms/Spmdization.cpp   |  77 +++---
 .../Tensor/IR/ShardingInterfaceImpl.cpp       |  10 +-
 .../Mesh/TestReshardingSpmdization.cpp        |   9 +-
 11 files changed, 301 insertions(+), 245 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h
index 6e1afcde5f0f5..3c467d6f95948 100644
--- a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h
+++ b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h
@@ -39,43 +39,47 @@ namespace mlir {
 namespace mesh {
 
 class MeshSharding {
-  private:
-    ::mlir::FlatSymbolRefAttr mesh;
-    SmallVector<MeshAxesAttr> split_axes;
-    SmallVector<MeshAxis> partial_axes;
-    ReductionKind partial_type;
-    SmallVector<int64_t> static_halo_sizes;
-    SmallVector<int64_t> static_sharded_dims_sizes;
-    SmallVector<Value> dynamic_halo_sizes;
-    SmallVector<Value> dynamic_sharded_dims_sizes;
-  public:
-    MeshSharding() = default;
-    MeshSharding(Value rhs);
-    static MeshSharding get(
-        ::mlir::FlatSymbolRefAttr mesh_,
-        ArrayRef<MeshAxesAttr> split_axes_,
-        ArrayRef<MeshAxis> partial_axes_ = {},
-        ReductionKind partial_type_ = ReductionKind::Sum,
-        ArrayRef<int64_t> static_halo_sizes_ = {},
-        ArrayRef<int64_t> static_sharded_dims_sizes_ = {},
-        ArrayRef<Value> dynamic_halo_sizes_ = {},
-        ArrayRef<Value> dynamic_sharded_dims_sizes_ = {});
-    ::mlir::FlatSymbolRefAttr getMeshAttr() const { return mesh; }
-    ::llvm::StringRef getMesh() const { return mesh.getValue(); }
-    ArrayRef<MeshAxesAttr> getSplitAxes() const {return split_axes; }
-    ArrayRef<MeshAxis> getPartialAxes() const { return partial_axes; }
-    ReductionKind getPartialType() const { return partial_type; }
-    ArrayRef<int64_t> getStaticHaloSizes() const { return static_halo_sizes; }
-    ArrayRef<int64_t> getStaticShardedDimsSizes() const { return static_sharded_dims_sizes; }
-    ArrayRef<Value> getDynamicHaloSizes() const { return dynamic_halo_sizes; }
-    ArrayRef<Value> getDynamicShardedDimsSizes() const { return dynamic_sharded_dims_sizes; }
-    operator bool() const { return (!mesh) == false; }
-    bool operator==(Value rhs) const;
-    bool operator!=(Value rhs) const;
-    bool operator==(const MeshSharding &rhs) const;
-    bool operator!=(const MeshSharding &rhs) const;
-    bool sameExceptConstraint(const MeshSharding &rhs) const;
-    bool sameConstraint(const MeshSharding &rhs) const;
+private:
+  ::mlir::FlatSymbolRefAttr mesh;
+  SmallVector<MeshAxesAttr> split_axes;
+  SmallVector<MeshAxis> partial_axes;
+  ReductionKind partial_type = ReductionKind::Sum; // same default as get()
+  SmallVector<int64_t> static_halo_sizes;
+  SmallVector<int64_t> static_sharded_dims_sizes;
+  SmallVector<Value> dynamic_halo_sizes;
+  SmallVector<Value> dynamic_sharded_dims_sizes;
+
+public:
+  MeshSharding() = default;
+  MeshSharding(Value rhs);
+  static MeshSharding get(::mlir::FlatSymbolRefAttr mesh_,
+                          ArrayRef<MeshAxesAttr> split_axes_,
+                          ArrayRef<MeshAxis> partial_axes_ = {},
+                          ReductionKind partial_type_ = ReductionKind::Sum,
+                          ArrayRef<int64_t> static_halo_sizes_ = {},
+                          ArrayRef<int64_t> static_sharded_dims_sizes_ = {},
+                          ArrayRef<Value> dynamic_halo_sizes_ = {},
+                          ArrayRef<Value> dynamic_sharded_dims_sizes_ = {});
+  ::mlir::FlatSymbolRefAttr getMeshAttr() const { return mesh; }
+  ::llvm::StringRef getMesh() const { return mesh.getValue(); }
+  ArrayRef<MeshAxesAttr> getSplitAxes() const { return split_axes; }
+  ArrayRef<MeshAxis> getPartialAxes() const { return partial_axes; }
+  ReductionKind getPartialType() const { return partial_type; }
+  ArrayRef<int64_t> getStaticHaloSizes() const { return static_halo_sizes; }
+  ArrayRef<int64_t> getStaticShardedDimsSizes() const {
+    return static_sharded_dims_sizes;
+  }
+  ArrayRef<Value> getDynamicHaloSizes() const { return dynamic_halo_sizes; }
+  ArrayRef<Value> getDynamicShardedDimsSizes() const {
+    return dynamic_sharded_dims_sizes;
+  }
+  operator bool() const { return (!mesh) == false; }
+  bool operator==(Value rhs) const;
+  bool operator!=(Value rhs) const;
+  bool operator==(const MeshSharding &rhs) const;
+  bool operator!=(const MeshSharding &rhs) const;
+  bool sameExceptConstraint(const MeshSharding &rhs) const;
+  bool sameConstraint(const MeshSharding &rhs) const;
 };
 
 } // namespace mesh
@@ -131,9 +135,10 @@ mesh::MeshOp getMesh(Op op, SymbolTableCollection &symbolTableCollection) {
 template <>
 inline mesh::MeshOp
 getMesh<ShardOp>(ShardOp op, SymbolTableCollection &symbolTableCollection) {
-  return getMesh(op.getOperation(),
-                 cast<ShardingOp>(op.getSharding().getDefiningOp()).getMeshAttr(),
-                 symbolTableCollection);
+  return getMesh(
+      op.getOperation(),
+      cast<ShardingOp>(op.getSharding().getDefiningOp()).getMeshAttr(),
+      symbolTableCollection);
 }
 
 // Get the number of processes that participate in each group
@@ -196,8 +201,8 @@ Type shardType(Type type, MeshOp mesh, MeshSharding sharding);
 void maybeInsertTargetShardingAnnotation(MeshSharding sharding,
                                          OpOperand &operand,
                                          OpBuilder &builder);
-void maybeInsertTargetShardingAnnotation(MeshSharding sharding,
-                                         OpResult result, OpBuilder &builder);
+void maybeInsertTargetShardingAnnotation(MeshSharding sharding, OpResult result,
+                                         OpBuilder &builder);
 void maybeInsertSourceShardingAnnotation(MeshSharding sharding,
                                          OpOperand &operand,
                                          OpBuilder &builder);
diff --git a/mlir/include/mlir/Dialect/Mesh/Interfaces/ShardingInterface.h b/mlir/include/mlir/Dialect/Mesh/Interfaces/ShardingInterface.h
index dd662247dc639..b4d25cef05a7b 100644
--- a/mlir/include/mlir/Dialect/Mesh/Interfaces/ShardingInterface.h
+++ b/mlir/include/mlir/Dialect/Mesh/Interfaces/ShardingInterface.h
@@ -46,19 +46,16 @@ struct ShardingOption {
 
 // This method retrieves the 'MeshSharding' from a given operation
 // result and includes the 'annotate_for_users' information.
-FailureOr<std::pair<bool, MeshSharding>>
-getMeshSharding(OpResult result);
+FailureOr<std::pair<bool, MeshSharding>> getMeshSharding(OpResult result);
 
 // This method retrieves the 'MeshSharding' from a given operation
 // operand and includes the 'annotate_for_users' information.
-FailureOr<std::pair<bool, MeshSharding>>
-getMeshSharding(OpOperand &opOperand);
+FailureOr<std::pair<bool, MeshSharding>> getMeshSharding(OpOperand &opOperand);
 
 namespace detail {
 
 FailureOr<ShardingOption>
-defaultGetShardingOption(Operation *op,
-                         ArrayRef<MeshSharding> operandShardings,
+defaultGetShardingOption(Operation *op, ArrayRef<MeshSharding> operandShardings,
                          ArrayRef<MeshSharding> resultShardings);
 
 FailureOr<std::vector<MeshSharding>>
@@ -72,11 +69,13 @@ defaultAddShardingAnnotations(Operation *op, OpBuilder &b,
 } // namespace detail
 
 // Assumes full replication on all ranked tensor arguments and results.
-void spmdizeFullyReplicatedOperation(
-    Operation &op, ArrayRef<Value> spmdizedOperands,
-    ArrayRef<MeshSharding> operandShardings,
-    ArrayRef<MeshSharding> resultShardings, IRMapping &spmdizationMap,
-    SymbolTableCollection &symbolTable, OpBuilder &builder);
+void spmdizeFullyReplicatedOperation(Operation &op,
+                                     ArrayRef<Value> spmdizedOperands,
+                                     ArrayRef<MeshSharding> operandShardings,
+                                     ArrayRef<MeshSharding> resultShardings,
+                                     IRMapping &spmdizationMap,
+                                     SymbolTableCollection &symbolTable,
+                                     OpBuilder &builder);
 
 } // namespace mesh
 } // namespace mlir
diff --git a/mlir/include/mlir/Dialect/Mesh/Interfaces/ShardingInterfaceImpl.h b/mlir/include/mlir/Dialect/Mesh/Interfaces/ShardingInterfaceImpl.h
index a25ba2bf649b0..2af8b2bd1d906 100644
--- a/mlir/include/mlir/Dialect/Mesh/Interfaces/ShardingInterfaceImpl.h
+++ b/mlir/include/mlir/Dialect/Mesh/Interfaces/ShardingInterfaceImpl.h
@@ -42,11 +42,13 @@ SmallVector<MeshAxis> getReductionMeshAxes(
 
 // Inserts a clone of the operation that has all ranked tensor
 // arguments/results sharded.
-void spmdizeTriviallyShardableOperation(
-    Operation &op, ArrayRef<Value> spmdizedOperands,
-    ArrayRef<MeshSharding> operandShardings,
-    ArrayRef<MeshSharding> resultShardings, IRMapping &spmdizationMap,
-    SymbolTableCollection &symbolTable, OpBuilder &builder);
+void spmdizeTriviallyShardableOperation(Operation &op,
+                                        ArrayRef<Value> spmdizedOperands,
+                                        ArrayRef<MeshSharding> operandShardings,
+                                        ArrayRef<MeshSharding> resultShardings,
+                                        IRMapping &spmdizationMap,
+                                        SymbolTableCollection &symbolTable,
+                                        OpBuilder &builder);
 
 // All ranked tensor argument and result dimensions have
 // independent parallel loop iterators.
diff --git a/mlir/include/mlir/Interfaces/InferTypeOpInterface.h b/mlir/include/mlir/Interfaces/InferTypeOpInterface.h
index 237cfea223b66..47bcfc9bbd4f9 100644
--- a/mlir/include/mlir/Interfaces/InferTypeOpInterface.h
+++ b/mlir/include/mlir/Interfaces/InferTypeOpInterface.h
@@ -110,7 +110,7 @@ class ShapedTypeComponents {
 
 public:
   /// Default construction is an unranked shape.
-  ShapedTypeComponents() : elementType(nullptr), attr(nullptr){};
+  ShapedTypeComponents() : elementType(nullptr), attr(nullptr) {};
   ShapedTypeComponents(Type elementType)
       : elementType(elementType), attr(nullptr), ranked(false) {}
   ShapedTypeComponents(ShapedType shapedType) : attr(nullptr) {
diff --git a/mlir/lib/Dialect/Linalg/Transforms/MeshShardingInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/Transforms/MeshShardingInterfaceImpl.cpp
index 33686d344e828..d47a82b59bcad 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/MeshShardingInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/MeshShardingInterfaceImpl.cpp
@@ -102,8 +102,7 @@ static ReductionKind getReductionKindOfLinalgOp(LinalgOp op) {
   return getReductionKind(reductionOp.value());
 }
 
-static MeshOp getMesh(Operation *op,
-                      ArrayRef<MeshSharding> operandShardings,
+static MeshOp getMesh(Operation *op, ArrayRef<MeshSharding> operandShardings,
                       ArrayRef<MeshSharding> resultShardings,
                       SymbolTableCollection &symbolTable) {
   for (MeshSharding sharding : operandShardings) {
@@ -199,8 +198,8 @@ static void createAllReduceForResultWithoutPartialSharding(
 
   Value spmdizedLinalgOpResult = spmdizationMap.lookup(unshardedLinalgOpResult);
   Value reducedValue = builder.create<mesh::AllReduceOp>(
-      spmdizedLinalgOpResult, resultSharding.getMesh(),
-      allReduceMeshAxes, reductionKind);
+      spmdizedLinalgOpResult, resultSharding.getMesh(), allReduceMeshAxes,
+      reductionKind);
   spmdizationMap.map(unshardedLinalgOpResult, reducedValue);
 }
 
diff --git a/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp b/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
index bb6674fd02ecf..e39f267f66b11 100644
--- a/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
+++ b/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
@@ -216,9 +216,9 @@ ShapedType mesh::shardShapedType(ShapedType shape, MeshOp mesh,
                                  MeshSharding sharding) {
   using Dim = std::decay_t<decltype(shape.getDimSize(0))>;
   SmallVector<Dim> resShapeArr(shape.getShape().size());
-  shardShape(
-      shape.getShape(), mesh.getShape(), sharding.getSplitAxes(), resShapeArr,
-      sharding.getStaticShardedDimsSizes(), sharding.getStaticHaloSizes());
+  shardShape(shape.getShape(), mesh.getShape(), sharding.getSplitAxes(),
+             resShapeArr, sharding.getStaticShardedDimsSizes(),
+             sharding.getStaticHaloSizes());
   return shape.clone(resShapeArr);
 }
 
@@ -258,8 +258,9 @@ void mlir::mesh::maybeInsertTargetShardingAnnotation(MeshSharding sharding,
     return;
   }
 
-  auto newShardOp2 = builder.create<ShardOp>(
-      operandValue.getLoc(), newShardOp, shardingOp, /*annotate_for_users*/ true);
+  auto newShardOp2 =
+      builder.create<ShardOp>(operandValue.getLoc(), newShardOp, shardingOp,
+                              /*annotate_for_users*/ true);
   rewriter.replaceAllUsesExcept(newShardOp, newShardOp2, newShardOp2);
 }
 
@@ -288,7 +289,8 @@ void mlir::mesh::maybeInsertSourceShardingAnnotation(MeshSharding sharding,
   }
 
   builder.setInsertionPoint(operandOp);
-  auto shardingOp = builder.create<ShardingOp>(operand.get().getLoc(), sharding);
+  auto shardingOp =
+      builder.create<ShardingOp>(operand.get().getLoc(), sharding);
   auto newShardOp =
       builder.create<ShardOp>(operandValue.getLoc(), operandValue, shardingOp,
                               /*annotate_for_users*/ true);
@@ -386,45 +388,68 @@ void MeshShapeOp::getAsmResultNames(
 // mesh.sharding
 //===----------------------------------------------------------------------===//
 
-void ShardingOp::build(::mlir::OpBuilder &b, ::mlir::OperationState &odsState, FlatSymbolRefAttr mesh, ArrayRef<MeshAxesAttr> split_axes, ArrayRef<MeshAxis> partial_axes, mesh::ReductionKind partial_type, ArrayRef<int64_t> static_halo_sizes, ArrayRef<int64_t> static_sharded_dims_sizes) {
-      // SmallVector<MeshAxesAttr> splitAxesAttr = llvm::map_to_vector(
-      //             split_axes, [&](ArrayRef<MeshAxis> array) {
-      //     return MeshAxesAttr::get(b.getContext(), array);
-      // });
-      return build(b, odsState, mesh, MeshAxesArrayAttr::get(b.getContext(), split_axes),
-                   ::mlir::DenseI16ArrayAttr::get(b.getContext(), partial_axes), ::mlir::mesh::ReductionKindAttr::get(b.getContext(), partial_type),
-                   ::mlir::DenseI64ArrayAttr::get(b.getContext(), static_halo_sizes), {}, ::mlir::DenseI64ArrayAttr::get(b.getContext(), static_sharded_dims_sizes), {});
-    
-}
-
-void ShardingOp::build(::mlir::OpBuilder &b, ::mlir::OperationState &odsState, FlatSymbolRefAttr mesh, ArrayRef<MeshAxesAttr> split_axes) {
-      return build(b, odsState, mesh, MeshAxesArrayAttr::get(b.getContext(), split_axes), {}, ::mlir::mesh::ReductionKindAttr::get(b.getContext(), ReductionKind::Sum),
-                   {}, {}, {}, {});
-    
-}
-
-void ShardingOp::build(::mlir::OpBuilder &b, ::mlir::OperationState &odsState, FlatSymbolRefAttr mesh, ArrayRef<MeshAxesAttr> split_axes, ::mlir::ArrayRef<::mlir::OpFoldResult> halo_sizes, ::mlir::ArrayRef<::mlir::OpFoldResult> sharded_dims_sizes) {
-      mlir::SmallVector<int64_t> staticHalos, staticDims;
-      mlir::SmallVector<mlir::Value> dynamicHalos, dynamicDims;
-      dispatchIndexOpFoldResults(halo_sizes, dynamicHalos, staticHalos);
-      dispatchIndexOpFoldResults(sharded_dims_sizes, dynamicDims, staticDims);
-      return build(b, odsState, mesh, MeshAxesArrayAttr::get(b.getContext(), split_axes), {}, ::mlir::mesh::ReductionKindAttr::get(b.getContext(), ReductionKind::Sum),
-                   ::mlir::DenseI64ArrayAttr::get(b.getContext(), staticHalos), dynamicHalos, ::mlir::DenseI64ArrayAttr::get(b.getContext(), staticDims), dynamicDims);
-    
-}
-
-
-void ShardingOp::build(::mlir::OpBuilder &b, ::mlir::OperationState &odsState, mlir::mesh::MeshSharding from) {
-
-  build(b, odsState,
-        ShardingType::get(b.getContext()),
-        from.getMeshAttr(),
+void ShardingOp::build(::mlir::OpBuilder &b, ::mlir::OperationState &odsState,
+                       FlatSymbolRefAttr mesh,
+                       ArrayRef<MeshAxesAttr> split_axes,
+                       ArrayRef<MeshAxis> partial_axes,
+                       mesh::ReductionKind partial_type,
+                       ArrayRef<int64_t> static_halo_sizes,
+                       ArrayRef<int64_t> static_sharded_dims_sizes) {
+  // SmallVector<MeshAxesAttr> splitAxesAttr = llvm::map_to_vector(
+  //             split_axes, [&](ArrayRef<MeshAxis> array) {
+  //     return MeshAxesAttr::get(b.getContext(), array);
+  // });
+  return build(
+      b, odsState, mesh, MeshAxesArrayAttr::get(b.getContext(), split_axes),
+      ::mlir::DenseI16ArrayAttr::get(b.getContext(), partial_axes),
+      ::mlir::mesh::ReductionKindAttr::get(b.getContext(), partial_type),
+      ::mlir::DenseI64ArrayAttr::get(b.getContext(), static_halo_sizes), {},
+      ::mlir::DenseI64ArrayAttr::get(b.getContext(), static_sharded_dims_sizes),
+      {});
+}
+
+void ShardingOp::build(::mlir::OpBuilder &b, ::mlir::OperationState &odsState,
+                       FlatSymbolRefAttr mesh,
+                       ArrayRef<MeshAxesAttr> split_axes) {
+  return build(
+      b, odsState, mesh, MeshAxesArrayAttr::get(b.getContext(), split_axes), {},
+      ::mlir::mesh::ReductionKindAttr::get(b.getContext(), ReductionKind::Sum),
+      {}, {}, {}, {});
+}
+
+void ShardingOp::build(
+    ::mlir::OpBuilder &b, ::mlir::OperationState &odsState,
+    FlatSymbolRefAttr mesh, ArrayRef<MeshAxesAttr> split_axes,
+    ::mlir::ArrayRef<::mlir::OpFoldResult> halo_sizes,
+    ::mlir::ArrayRef<::mlir::OpFoldResult> sharded_dims_sizes) {
+  mlir::SmallVector<int64_t> staticHalos, staticDims;
+  mlir::SmallVector<mlir::Value> dynamicHalos, dynamicDims;
+  dispatchIndexOpFoldResults(halo_sizes, dynamicHalos, staticHalos);
+  dispatchIndexOpFoldResults(sharded_dims_sizes, dynamicDims, staticDims);
+  return build(
+      b, odsState, mesh, MeshAxesArrayAttr::get(b.getContext(), split_axes), {},
+      ::mlir::mesh::ReductionKindAttr::get(b.getContext(), ReductionKind::Sum),
+      ::mlir::DenseI64ArrayAttr::get(b.getContext(), staticHalos), dynamicHalos,
+      ::mlir::DenseI64ArrayAttr::get(b.getContext(), staticDims), dynamicDims);
+}
+
+void ShardingOp::build(::mlir::OpBuilder &b, ::mlir::OperationState &odsState,
+                       mlir::mesh::MeshSharding from) {
+
+  build(b, odsState, ShardingType::get(b.getContext()), from.getMeshAttr(),
         MeshAxesArrayAttr::get(b.getContext(), from.getSplitAxes()),
-        from.getPartialAxes().empty() ? DenseI16ArrayAttr() : b.getDenseI16ArrayAttr(from.getPartialAxes()),
-        ::mlir::mesh::ReductionKindAttr::get(b.getContext(), from.getPartialType()),
-        from.getStaticShardedDimsSizes().empty() ? DenseI64ArrayAttr() : b.getDenseI64ArrayAttr(from.getStaticShardedDimsSizes()),
+        from.getPartialAxes().empty()
+            ? DenseI16ArrayAttr()
+            : b.getDenseI16ArrayAttr(from.getPartialAxes()),
+        ::mlir::mesh::ReductionKindAttr::get(b.getContext(),
+                                             from.getPartialType()),
+        from.getStaticShardedDimsSizes().empty()
+            ? DenseI64ArrayAttr()
+            : b.getDenseI64ArrayAttr(from.getStaticShardedDimsSizes()),
         from.getDynamicShardedDimsSizes(),
-        from.getStaticHaloSizes().empty() ? DenseI64ArrayAttr() : b.getDenseI64ArrayAttr(from.getStaticHaloSizes()),
+        from.getStaticHaloSizes().empty()
+            ? DenseI64ArrayAttr()
+            : b.getDenseI64ArrayAttr(from.getStaticHaloSizes()),
         from.getDynamicHaloSizes());
 }
 
@@ -432,16 +457,23 @@ void ShardingOp::build(::mlir::OpBuilder &b, ::mlir::OperationState &odsState, m
 // MeshSharding
 //===----------------------------------------------------------------------===//
 
-    // ::mlir::FlatSymbolRefAttr MeshSharding::getMeshAttr() const { return mesh; }
-    // ::llvm::StringRef MeshSharding::getMesh() const { return mesh.getValue(); }
-    // ArrayRef<MeshAxesAttr> MeshSharding::getSplitAxes() const {return split_axes; }
-    // ArrayRef<MeshAxis> MeshSharding::getPartialAxes() const { if (partial_axes.empty()) return {}; return partial_axes; }
-    // ReductionKind MeshSharding::getPartialType() const { return partial_type; }
-    // ArrayRef<int64_t> MeshSharding::getStaticHaloSizes() const { if(static_halo_sizes.empty()) return {}; return static_halo_sizes; }
-    // ArrayRef<int64_t> MeshSharding::getStaticShardedDimsSizes() const { if(static_sharded_dims_sizes.empty()) return {}; return static_sharded_dims_sizes; }
-    // ArrayRef<Value> MeshSharding::getDynamicHaloSizes() const { if(dynamic_halo_sizes.empty()) return {}; return dynamic_halo_sizes; }
-    // ArrayRef<Value> MeshSharding::getDynamicShardedDimsSizes() const { if(dynamic_sharded_dims_sizes.empty()) return {}; return dynamic_sharded_dims_sizes; }
-    // operator MeshSharding::bool() const { return (!mesh) == false; }
+// ::mlir::FlatSymbolRefAttr MeshSharding::getMeshAttr() const { return mesh; }
+// ::llvm::StringRef MeshSharding::getMesh() const { return mesh.getValue(); }
+// ArrayRef<MeshAxesAttr> MeshSharding::getSplitAxes() const {return split_axes;
+// } ArrayRef<MeshAxis> MeshSharding::getPartialAxes() const { if
+// (partial_axes.empty()) return {}; return partial_axes; } ReductionKind
+// MeshSharding::getPartialType() const { return partial_type; }
+// ArrayRef<int64_t> MeshSharding::getStaticHaloSizes() const {
+// if(static_halo_sizes.empty()) return {}; return static_halo_sizes; }
+// ArrayRef<int64_t> MeshSharding::getStaticShardedDimsSizes() const {
+// if(static_sharded_dims_sizes.empty()) return {}; return
+// static_sharded_dims_sizes; } ArrayRef<Value>
+// MeshSharding::getDynamicHaloSizes() const { if(dynamic_halo_sizes.empty())
+// return {}; return dynamic_halo_sizes; } ArrayRef<Value>
+// MeshSharding::getDynamicShardedDimsSizes() const {
+// if(dynamic_sharded_dims_sizes.empty()) return {}; return
+// dynamic_sharded_dims_sizes; } operator MeshSharding::bool() const { return
+// (!mesh) == false; }
 
 bool MeshSharding::sameExceptConstraint(const MeshSharding &rhs) const {
   if (getMesh() != rhs.getMesh()) {
@@ -469,37 +501,43 @@ bool MeshSharding::sameExceptConstraint(const MeshSharding &rhs) const {
 }
 
 bool MeshSharding::sameConstraint(const MeshSharding &rhs) const {
-    if (rhs.getStaticHaloSizes().size() != getStaticHaloSizes().size() 
-        || !llvm::equal(llvm::make_range(getStaticHaloSizes().begin(), getStaticHaloSizes().end()),
-                        llvm::make_range(rhs.getStaticHaloSizes().begin(), rhs.getStaticHaloSizes().end()))) {
-      return false;
-    }
-    if (rhs.getStaticShardedDimsSizes().size() != getDynamicHaloSizes().size()
-        || !llvm::equal(llvm::make_range(getStaticShardedDimsSizes().begin(), getStaticShardedDimsSizes().end()),
-                        llvm::make_range(rhs.getStaticShardedDimsSizes().begin(), rhs.getStaticShardedDimsSizes().end()))) {
-      return false;
-    }
-    if (rhs.getDynamicHaloSizes().size() != getStaticShardedDimsSizes().size()
-        || !llvm::equal(llvm::make_range(getDynamicHaloSizes().begin(), getDynamicHaloSizes().end()),
-                        llvm::make_range(rhs.getDynamicHaloSizes().begin(), rhs.getDynamicHaloSizes().end()))) {
-      return false;
-    }
-    if (rhs.getDynamicShardedDimsSizes().size() != getDynamicShardedDimsSizes().size()
-        || !llvm::equal(llvm::make_range(getDynamicShardedDimsSizes().begin(), getDynamicShardedDimsSizes().end()),
-                        llvm::make_range(rhs.getDynamicShardedDimsSizes().begin(), rhs.getDynamicShardedDimsSizes().end()))) {
-      return false;
-    }
-    return true;
+  // The constraints agree only when each of the four size lists (static and
+  // dynamic halo sizes, static and dynamic sharded-dims sizes) has the same
+  // length and the same elements as the corresponding list on rhs.
+  if (rhs.getStaticHaloSizes().size() != getStaticHaloSizes().size() ||
+      !llvm::equal(llvm::make_range(getStaticHaloSizes().begin(),
+                                    getStaticHaloSizes().end()),
+                   llvm::make_range(rhs.getStaticHaloSizes().begin(),
+                                    rhs.getStaticHaloSizes().end())))
+    return false;
+  if (rhs.getStaticShardedDimsSizes().size() !=
+          getStaticShardedDimsSizes().size() ||
+      !llvm::equal(llvm::make_range(getStaticShardedDimsSizes().begin(),
+                                    getStaticShardedDimsSizes().end()),
+                   llvm::make_range(rhs.getStaticShardedDimsSizes().begin(),
+                                    rhs.getStaticShardedDimsSizes().end())))
+    return false;
+  if (rhs.getDynamicHaloSizes().size() != getDynamicHaloSizes().size() ||
+      !llvm::equal(llvm::make_range(getDynamicHaloSizes().begin(),
+                                    getDynamicHaloSizes().end()),
+                   llvm::make_range(rhs.getDynamicHaloSizes().begin(),
+                                    rhs.getDynamicHaloSizes().end())))
+    return false;
+  if (rhs.getDynamicShardedDimsSizes().size() !=
+          getDynamicShardedDimsSizes().size() ||
+      !llvm::equal(llvm::make_range(getDynamicShardedDimsSizes().begin(),
+                                    getDynamicShardedDimsSizes().end()),
+                   llvm::make_range(rhs.getDynamicShardedDimsSizes().begin(),
+                                    rhs.getDynamicShardedDimsSizes().end())))
+    return false;
+  return true;
 }
 
 bool MeshSharding::operator==(Value rhs) const {
-  return sameExceptConstraint(rhs)
-         && sameConstraint(rhs);
+  return sameExceptConstraint(rhs) && sameConstraint(rhs);
 }
 
-bool MeshSharding::operator!=(Value rhs) const {
-  return !(*this == rhs);
-}
+bool MeshSharding::operator!=(Value rhs) const { return !(*this == rhs); }
 
 bool MeshSharding::operator==(const MeshSharding &rhs) const {
   return sameExceptConstraint(rhs) && sameConstraint(rhs);
@@ -512,29 +550,29 @@ bool MeshSharding::operator!=(const MeshSharding &rhs) const {
 MeshSharding::MeshSharding(Value rhs) {
   auto shardingOp = mlir::dyn_cast<ShardingOp>(rhs.getDefiningOp());
   assert(shardingOp && "expected sharding op");
-  *this = get(shardingOp.getMeshAttr(),
-      shardingOp.getSplitAxes().getAxes(),
-      shardingOp.getPartialAxes().value_or(ArrayRef<MeshAxis>()),
-      shardingOp.getPartialType().value_or(ReductionKind::Sum),
-      shardingOp.getStaticHaloSizes(),
-      shardingOp.getStaticShardedDimsSizes(),
-      SmallVector<Value>(shardingOp.getDynamicHaloSizes()),
-      SmallVector<Value>(shardingOp.getDynamicShardedDimsSizes()));
+  *this = get(shardingOp.getMeshAttr(), shardingOp.getSplitAxes().getAxes(),
+              shardingOp.getPartialAxes().value_or(ArrayRef<MeshAxis>()),
+              shardingOp.getPartialType().value_or(ReductionKind::Sum),
+              shardingOp.getStaticHaloSizes(),
+              shardingOp.getStaticShardedDimsSizes(),
+              SmallVector<Value>(shardingOp.getDynamicHaloSizes()),
+              SmallVector<Value>(shardingOp.getDynamicShardedDimsSizes()));
 }
 
 MeshSharding MeshSharding::get(::mlir::FlatSymbolRefAttr mesh_,
-                  ArrayRef<MeshAxesAttr> split_axes_,
-                  ArrayRef<MeshAxis> partial_axes_,
-                  ReductionKind partial_type_,
-                  ArrayRef<int64_t> static_halo_sizes_,
-                  ArrayRef<int64_t> static_sharded_dims_sizes_,
-                  ArrayRef<Value> dynamic_halo_sizes_,
-                  ArrayRef<Value> dynamic_sharded_dims_sizes_) {
+                               ArrayRef<MeshAxesAttr> split_axes_,
+                               ArrayRef<MeshAxis> partial_axes_,
+                               ReductionKind partial_type_,
+                               ArrayRef<int64_t> static_halo_sizes_,
+                               ArrayRef<int64_t> static_sharded_dims_sizes_,
+                               ArrayRef<Value> dynamic_halo_sizes_,
+                               ArrayRef<Value> dynamic_sharded_dims_sizes_) {
   MeshSharding res;
   res.mesh = mesh_;
   res.split_axes.resize(split_axes_.size());
   for (auto [i, axis] : llvm::enumerate(split_axes_)) {
-    res.split_axes[i] = MeshAxesAttr::get(mesh_.getContext(), axis.asArrayRef());
+    res.split_axes[i] =
+        MeshAxesAttr::get(mesh_.getContext(), axis.asArrayRef());
   }
 
   auto do_copy = [&](auto src, auto &dst) {
@@ -576,14 +614,15 @@ LogicalResult ShardingOp::verify() {
     if (failed(checkMeshAxis(subAxesArray)))
       return failure();
   }
-  if (getPartialAxes().has_value() && failed(checkMeshAxis(getPartialAxes().value())))
+  if (getPartialAxes().has_value() &&
+      failed(checkMeshAxis(getPartialAxes().value())))
     return failure();
 
   if (!getStaticHaloSizes().empty() && !getStaticShardedDimsSizes().empty()) {
     return emitOpError("halo sizes and shard shapes are mutually exclusive");
   }
-  
-  if(!getStaticHaloSizes().empty()) {
+
+  if (!getStaticHaloSizes().empty()) {
     auto numSplitAxes = getSplitAxes().getAxes().size();
     for (auto splitAxis : getSplitAxes().getAxes()) {
       if (splitAxis.empty()) {
@@ -607,7 +646,10 @@ void ShardingOp::getAsmResultNames(
 // mesh.shard_shape
 //===----------------------------------------------------------------------===//
 
-void ShardShapeOp::build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::llvm::ArrayRef<int64_t> shape, ::mlir::Value sharding, ::mlir::Value device) {
+void ShardShapeOp::build(::mlir::OpBuilder &odsBuilder,
+                         ::mlir::OperationState &odsState,
+                         ::llvm::ArrayRef<int64_t> shape,
+                         ::mlir::Value sharding, ::mlir::Value device) {
   SmallVector<mlir::Type> resType(shape.size(), odsBuilder.getIndexType());
   build(odsBuilder, odsState, resType, shape, sharding, device);
 }
diff --git a/mlir/lib/Dialect/Mesh/Interfaces/ShardingInterface.cpp b/mlir/lib/Dialect/Mesh/Interfaces/ShardingInterface.cpp
index df42a335e89df..2c5513b0b2c1c 100644
--- a/mlir/lib/Dialect/Mesh/Interfaces/ShardingInterface.cpp
+++ b/mlir/lib/Dialect/Mesh/Interfaces/ShardingInterface.cpp
@@ -91,8 +91,9 @@ checkOperandAffineExpr(AffineExpr expr, unsigned numDims) {
   return positions;
 }
 
-template<typename T>
-SmallVector<MeshAxesAttr> fromArrayOfVector(MLIRContext *ctxt, const SmallVector<SmallVector<T>> &vec) {
+template <typename T>
+SmallVector<MeshAxesAttr>
+fromArrayOfVector(MLIRContext *ctxt, const SmallVector<SmallVector<T>> &vec) {
   SmallVector<MeshAxesAttr> res;
   for (const auto &v : vec) {
     res.emplace_back(MeshAxesAttr::get(ctxt, v));
@@ -140,7 +141,7 @@ mesh::getMeshSharding(OpResult result) {
     for (size_t i = 1; i < shardOps.size(); ++i) {
       // TODO: Deduce a reasonable mesh sharding attr for def when they are
       // different
-      assert(shardForDef == shardOps[i].getSharding()  &&
+      assert(shardForDef == shardOps[i].getSharding() &&
              "only support all shard ops have the same mesh sharding attr");
     }
     return std::make_pair(true, shardForDef);
@@ -152,7 +153,8 @@ FailureOr<std::pair<bool, MeshSharding>>
 mesh::getMeshSharding(OpOperand &opOperand) {
   Value val = opOperand.get();
   if (ShardOp shardOp = val.getDefiningOp<ShardOp>())
-    return std::make_pair(shardOp.getAnnotateForUsers(), MeshSharding(shardOp.getSharding()));
+    return std::make_pair(shardOp.getAnnotateForUsers(),
+                          MeshSharding(shardOp.getSharding()));
 
   return failure();
 }
@@ -259,9 +261,10 @@ static LogicalResult fillShardingOption(Operation *op,
 
 } // namespace
 
-FailureOr<ShardingOption> mesh::detail::defaultGetShardingOption(
-    Operation *op, ArrayRef<MeshSharding> operandShardings,
-    ArrayRef<MeshSharding> resultShardings) {
+FailureOr<ShardingOption>
+mesh::detail::defaultGetShardingOption(Operation *op,
+                                       ArrayRef<MeshSharding> operandShardings,
+                                       ArrayRef<MeshSharding> resultShardings) {
   ShardingInterface shardingOp = llvm::cast<ShardingInterface>(op);
   ShardingOption shardingOption;
 
@@ -343,8 +346,8 @@ FailureOr<ShardingOption> mesh::detail::defaultGetShardingOption(
       if (loopIndices->size() == 1) {
         unsigned loopIdx = *loopIndices->begin();
         visitedLoopIndices.insert(loopIdx);
-        if (failed(fillShardingOption(op, shardingOption, shardAttr.getMeshAttr(),
-                                      axes, loopIdx)))
+        if (failed(fillShardingOption(op, shardingOption,
+                                      shardAttr.getMeshAttr(), axes, loopIdx)))
           return failure();
       }
       // If multiple loop indices correspond to a dimension of an operand, it is
@@ -398,10 +401,11 @@ FailureOr<ShardingOption> mesh::detail::defaultGetShardingOption(
 }
 
 // Get the sharding attributed for the given result and sharding option.
-MeshSharding
-getShardingAttribute(OpResult result, const ShardingOption &shardingOption,
-                     AffineMap map, ArrayRef<utils::IteratorType> loopTypes,
-                     ArrayRef<ReductionKind> reductionLoopKinds) {
+MeshSharding getShardingAttribute(OpResult result,
+                                  const ShardingOption &shardingOption,
+                                  AffineMap map,
+                                  ArrayRef<utils::IteratorType> loopTypes,
+                                  ArrayRef<ReductionKind> reductionLoopKinds) {
   auto resultType = cast<RankedTensorType>(result.getType());
   SmallVector<SmallVector<MeshAxis>> splitAxes(resultType.getRank());
   SmallVector<MeshAxis> partialAxes;
@@ -438,7 +442,8 @@ getShardingAttribute(OpResult result, const ShardingOption &shardingOption,
 
   removeTrailingEmptySubArray(splitAxes);
   return MeshSharding::get(shardingOption.mesh,
-                           fromArrayOfVector(result.getContext(), splitAxes), partialAxes, partialType);
+                           fromArrayOfVector(result.getContext(), splitAxes),
+                           partialAxes, partialType);
 }
 
 static FailureOr<MeshSharding>
@@ -471,7 +476,9 @@ getShardingAttribute(OpOperand &opOperand, const ShardingOption &shardingOption,
   }
 
   removeTrailingEmptySubArray(splitAxes);
-  return MeshSharding::get(shardingOption.mesh, fromArrayOfVector(opOperand.get().getContext(), splitAxes));
+  return MeshSharding::get(
+      shardingOption.mesh,
+      fromArrayOfVector(opOperand.get().getContext(), splitAxes));
 }
 
 FailureOr<std::vector<MeshSharding>>
@@ -515,8 +522,8 @@ static LogicalResult addShardOp(OpBuilder &b, OpResult result,
                                 AffineMap map,
                                 ArrayRef<utils::IteratorType> loopTypes,
                                 ArrayRef<ReductionKind> reductionLoopKinds) {
-  MeshSharding sharding = getShardingAttribute(
-      result, shardingOption, map, loopTypes, reductionLoopKinds);
+  MeshSharding sharding = getShardingAttribute(result, shardingOption, map,
+                                               loopTypes, reductionLoopKinds);
   maybeInsertTargetShardingAnnotation(sharding, result, b);
 
   return success();
@@ -581,19 +588,19 @@ isValueCompatibleWithFullReplicationSharding(Value value,
 }
 
 template <typename ValueRange, typename MeshShardingRage>
-static bool areValuesCompatibleWithFullReplicationShardings(
-    ValueRange &&values, MeshShardingRage &&shardings) {
+static bool
+areValuesCompatibleWithFullReplicationShardings(ValueRange &&values,
+                                                MeshShardingRage &&shardings) {
   if (std::size(values) != std::size(shardings)) {
     return false;
   }
-  return llvm::all_of(llvm::zip_equal(
-                          std::forward<ValueRange>(values),
-                          std::forward<MeshShardingRage>(shardings)),
-                      [](auto valueAndSharding) {
-                        return isValueCompatibleWithFullReplicationSharding(
-                            std::get<0>(valueAndSharding),
-                            std::get<1>(valueAndSharding));
-                      });
+  return llvm::all_of(
+      llvm::zip_equal(std::forward<ValueRange>(values),
+                      std::forward<MeshShardingRage>(shardings)),
+      [](auto valueAndSharding) {
+        return isValueCompatibleWithFullReplicationSharding(
+            std::get<0>(valueAndSharding), std::get<1>(valueAndSharding));
+      });
 }
 #endif // NDEBUG
 
@@ -703,8 +710,8 @@ void mesh::spmdizeTriviallyShardableOperation(
   // Set the result types to the sharded counterparts.
   for (auto [oldResult, newResult, sharding] :
        llvm::zip_equal(op.getResults(), newOp->getResults(), resultShardings)) {
-    newResult.setType(shardType(newResult.getType(),
-                                getMesh(&op, sharding.getMeshAttr(), symbolTable),
-                                sharding));
+    newResult.setType(
+        shardType(newResult.getType(),
+                  getMesh(&op, sharding.getMeshAttr(), symbolTable), sharding));
   }
 }
diff --git a/mlir/lib/Dialect/Mesh/Transforms/ShardingPropagation.cpp b/mlir/lib/Dialect/Mesh/Transforms/ShardingPropagation.cpp
index c98d975392dbb..4bd3b425219c1 100644
--- a/mlir/lib/Dialect/Mesh/Transforms/ShardingPropagation.cpp
+++ b/mlir/lib/Dialect/Mesh/Transforms/ShardingPropagation.cpp
@@ -116,8 +116,7 @@ getOrderedPossibleShardingAttrs(ArrayRef<MeshSharding> mustShardings,
 
   std::function<void(size_t)> dfsCreateShardingAttrs = [&](size_t i) {
     if (i == mustShardings.size()) {
-      allShardingAttrs.push_back(
-          std::vector<MeshSharding>(curShardingAttrs));
+      allShardingAttrs.push_back(std::vector<MeshSharding>(curShardingAttrs));
       return;
     }
 
@@ -158,8 +157,7 @@ getOrderedPossibleShardingAttrs(ArrayRef<MeshSharding> mustShardings,
 // 3. All other cases. Resharding is required for operands/results with
 //   annotation targeting explicitly this operation.
 ReshardingRquirementKind getReshardingRquirementKind(
-    Operation *op,
-    const std::vector<MeshSharding> &operandAndResultShardings) {
+    Operation *op, const std::vector<MeshSharding> &operandAndResultShardings) {
   ReshardingRquirementKind res = ReshardingRquirementKind::NO_RESHARDING;
 
   size_t operandsCount = op->getOperands().size();
@@ -223,8 +221,7 @@ static FailureOr<ShardingOption> selectShardingOption(
   SmallVector<std::tuple<ShardingOption, ReshardingRquirementKind>>
       shardingOptionsAndReshardingRequirements;
 
-  for (ArrayRef<MeshSharding> resultShardings :
-       possibleResultShardingAttrs) {
+  for (ArrayRef<MeshSharding> resultShardings : possibleResultShardingAttrs) {
     for (ArrayRef<MeshSharding> operandShardings :
          possibleOperandShardingAttrs) {
       FailureOr<ShardingOption> shardingOption =
@@ -285,7 +282,8 @@ static FailureOr<ShardingOption> selectShardingOption(
 // a `mesh.shard` operation for all remaining operands and results that do not
 // have sharding annotations.
 static LogicalResult visitOp(Operation *op, OpBuilder &builder) {
-  if (op->hasTrait<OpTrait::IsTerminator>() || llvm::isa<mesh::ShardOp, mesh::ShardingOp>(op))
+  if (op->hasTrait<OpTrait::IsTerminator>() ||
+      llvm::isa<mesh::ShardOp, mesh::ShardingOp>(op))
     return success();
 
   ShardingInterface shardingOp = llvm::dyn_cast<ShardingInterface>(op);
diff --git a/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp b/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp
index 8d3a355a6bbe5..a53677adca88b 100644
--- a/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp
+++ b/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp
@@ -99,16 +99,16 @@ handlePartialAxesDuringResharding(OpBuilder &builder,
                 [&targetShardingPartialAxesSet](Axis a) {
                   return targetShardingPartialAxesSet.contains(a);
                 });
-  MeshSharding resultSharding =
-      MeshSharding::get(sourceSharding.getMeshAttr(),
-                            sourceSharding.getSplitAxes(), remainingPartialAxes,
-                            sourceSharding.getPartialType());
+  MeshSharding resultSharding = MeshSharding::get(
+      sourceSharding.getMeshAttr(), sourceSharding.getSplitAxes(),
+      remainingPartialAxes, sourceSharding.getPartialType());
   return {resultValue, resultSharding};
 }
 
-static MeshSharding
-targetShardingInSplitLastAxis(MLIRContext *ctx, MeshSharding sourceSharding,
-                              int64_t splitTensorAxis, MeshAxis splitMeshAxis) {
+static MeshSharding targetShardingInSplitLastAxis(MLIRContext *ctx,
+                                                  MeshSharding sourceSharding,
+                                                  int64_t splitTensorAxis,
+                                                  MeshAxis splitMeshAxis) {
   SmallVector<MeshAxesAttr> targetShardingSplitAxes =
       llvm::to_vector(sourceSharding.getSplitAxes());
   while (static_cast<int64_t>(targetShardingSplitAxes.size()) <=
@@ -228,10 +228,9 @@ detectUnsplitLastAxisInResharding(MeshSharding sourceSharding,
   return std::nullopt;
 }
 
-static MeshSharding
-targetShardingInUnsplitLastAxis(MLIRContext *ctx,
-                                MeshSharding sourceSharding,
-                                int64_t splitTensorAxis) {
+static MeshSharding targetShardingInUnsplitLastAxis(MLIRContext *ctx,
+                                                    MeshSharding sourceSharding,
+                                                    int64_t splitTensorAxis) {
   SmallVector<MeshAxesAttr> targetShardingSplitAxes =
       llvm::to_vector(sourceSharding.getSplitAxes());
   assert(static_cast<int64_t>(targetShardingSplitAxes.size()) >
@@ -344,10 +343,10 @@ detectMoveLastSplitAxisInResharding(MeshSharding sourceSharding,
   return std::nullopt;
 }
 
-static MeshSharding
-targetShardingInMoveLastAxis(MLIRContext *ctx, MeshSharding sourceSharding,
-                             int64_t sourceTensorAxis,
-                             int64_t targetTensorAxis) {
+static MeshSharding targetShardingInMoveLastAxis(MLIRContext *ctx,
+                                                 MeshSharding sourceSharding,
+                                                 int64_t sourceTensorAxis,
+                                                 int64_t targetTensorAxis) {
   SmallVector<MeshAxesAttr> targetShardingSplitAxes =
       llvm::to_vector(sourceSharding.getSplitAxes());
   while (static_cast<int64_t>(targetShardingSplitAxes.size()) <=
@@ -438,15 +437,15 @@ updateHalosInResharding(ImplicitLocOpBuilder &builder, MeshOp mesh,
   assert(sourceSharding.getMesh() == targetSharding.getMesh());
   assert(sourceSharding.getSplitAxes() == targetSharding.getSplitAxes());
 
-  auto res = builder
-                 .create<UpdateHaloOp>(
-                     sourceShard.getType(), // update halo keeps the source type
-                     mesh.getSymName(), SmallVector<MeshAxis>(), sourceShard,
-                     ::mlir::DenseI64ArrayAttr::get(
-                         builder.getContext(),
-                         sourceSharding.getStaticHaloSizes()),
-                     nullptr)
-                 .getResult();
+  auto res =
+      builder
+          .create<UpdateHaloOp>(
+              sourceShard.getType(), // update halo keeps the source type
+              mesh.getSymName(), SmallVector<MeshAxis>(), sourceShard,
+              ::mlir::DenseI64ArrayAttr::get(
+                  builder.getContext(), sourceSharding.getStaticHaloSizes()),
+              nullptr)
+          .getResult();
   return cast<TypedValue<ShapedType>>(res);
 }
 
@@ -477,10 +476,8 @@ tryUpdateHalosInResharding(ImplicitLocOpBuilder &builder, MeshOp mesh,
 // Currently the sharded tensor axes must be exactly divisible by the single
 // mesh axis size.
 static TypedValue<ShapedType>
-reshardOn1DMesh(ImplicitLocOpBuilder &builder, MeshOp mesh,
-                bool force,
-                MeshSharding sourceSharding,
-                MeshSharding targetSharding,
+reshardOn1DMesh(ImplicitLocOpBuilder &builder, MeshOp mesh, bool force,
+                MeshSharding sourceSharding, MeshSharding targetSharding,
                 TypedValue<ShapedType> sourceUnshardedValue,
                 TypedValue<ShapedType> sourceShard) {
   assert(sourceShard.getType() ==
@@ -500,8 +497,7 @@ reshardOn1DMesh(ImplicitLocOpBuilder &builder, MeshOp mesh,
 
   TypedValue<ShapedType> targetShard;
   MeshSharding actualTargetSharding;
-  if (!force &&
-      reducedSourceSharding.getStaticHaloSizes().empty() &&
+  if (!force && reducedSourceSharding.getStaticHaloSizes().empty() &&
       targetSharding.getStaticHaloSizes().empty() &&
       reducedSourceSharding.getStaticShardedDimsSizes().empty() &&
       targetSharding.getStaticShardedDimsSizes().empty()) {
@@ -518,7 +514,7 @@ reshardOn1DMesh(ImplicitLocOpBuilder &builder, MeshOp mesh,
                    sourceUnshardedValue.getType(), reducedSourceShard)) {
       std::tie(targetShard, actualTargetSharding) = tryRes.value();
     }
-  } else if(force) {
+  } else if (force) {
     if (auto tryRes = tryUpdateHalosInResharding(
             builder, mesh, reducedSourceSharding, targetSharding,
             sourceUnshardedValue.getType(), reducedSourceShard)) {
@@ -532,8 +528,7 @@ reshardOn1DMesh(ImplicitLocOpBuilder &builder, MeshOp mesh,
 }
 
 TypedValue<ShapedType> reshard(ImplicitLocOpBuilder &builder, MeshOp mesh,
-                               bool force,
-                               MeshSharding sourceSharding,
+                               bool force, MeshSharding sourceSharding,
                                MeshSharding targetSharding,
                                TypedValue<ShapedType> sourceUnshardedValue,
                                TypedValue<ShapedType> sourceShard) {
@@ -551,8 +546,8 @@ TypedValue<ShapedType> reshard(OpBuilder &builder, MeshOp mesh, ShardOp source,
   auto sourceSharding = source.getSharding();
   auto targetSharding = target.getSharding();
   ImplicitLocOpBuilder implicitLocOpBuilder(target->getLoc(), builder);
-  return reshard(implicitLocOpBuilder, mesh, target.getForce(), sourceSharding, targetSharding,
-                 cast<TypedValue<ShapedType>>(source.getSrc()),
+  return reshard(implicitLocOpBuilder, mesh, target.getForce(), sourceSharding,
+                 targetSharding, cast<TypedValue<ShapedType>>(source.getSrc()),
                  sourceShardValue);
 }
 
@@ -600,11 +595,13 @@ shardedBlockArgumentTypes(Block &block,
   return res;
 }
 
-void spmdizeTriviallyShardableOperation(
-    Operation &op, ArrayRef<Value> spmdizedOperands,
-    ArrayRef<MeshSharding> operandShardings,
-    ArrayRef<MeshSharding> resultShardings, IRMapping &spmdizationMap,
-    SymbolTableCollection &symbolTable, OpBuilder &builder);
+void spmdizeTriviallyShardableOperation(Operation &op,
+                                        ArrayRef<Value> spmdizedOperands,
+                                        ArrayRef<MeshSharding> operandShardings,
+                                        ArrayRef<MeshSharding> resultShardings,
+                                        IRMapping &spmdizationMap,
+                                        SymbolTableCollection &symbolTable,
+                                        OpBuilder &builder);
 
 static LogicalResult spmdizeOperation(
     Operation &op, ArrayRef<Value> spmdizedOperands,
diff --git a/mlir/lib/Dialect/Tensor/IR/ShardingInterfaceImpl.cpp b/mlir/lib/Dialect/Tensor/IR/ShardingInterfaceImpl.cpp
index 15995a7bae038..4a79e983d955d 100644
--- a/mlir/lib/Dialect/Tensor/IR/ShardingInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/ShardingInterfaceImpl.cpp
@@ -53,7 +53,7 @@ struct EmptyOpShardingInterface
     // if the sharding introduces a new dynamic dimension, we take it from
     // the dynamic sharding info. For now bail out if it's not
     // provided.
-		assert(resultShardings.size() == 1);
+    assert(resultShardings.size() == 1);
     if (!shardType.hasStaticShape()) {
       assert(op->getResult(0).hasOneUse());
       SmallVector<Value> newOperands;
@@ -62,15 +62,17 @@ struct EmptyOpShardingInterface
       int currOldOprndNum = -1;
       mesh::ShardShapeOp shapeForDevice;
       Value device;
-			Operation *newSharding = nullptr;
+      Operation *newSharding = nullptr;
       for (auto i = 0; i < oldType.getRank(); ++i) {
         if (!oldType.isDynamicDim(i) && shardType.isDynamicDim(i)) {
           if (!newSharding) {
-						newSharding = builder.create<ShardingOp>(op->getLoc(), resultShardings[0]);
+            newSharding =
+                builder.create<ShardingOp>(op->getLoc(), resultShardings[0]);
             device = builder.create<mesh::ProcessLinearIndexOp>(
                 op->getLoc(), resultShardings[0].getMesh());
             shapeForDevice = builder.create<mesh::ShardShapeOp>(
-                op->getLoc(), oldType.getShape(), newSharding->getResult(0), device);
+                op->getLoc(), oldType.getShape(), newSharding->getResult(0),
+                device);
           }
           newOperands.emplace_back(shapeForDevice.getResult()[i]);
         } else if (oldType.isDynamicDim(i)) {
diff --git a/mlir/test/lib/Dialect/Mesh/TestReshardingSpmdization.cpp b/mlir/test/lib/Dialect/Mesh/TestReshardingSpmdization.cpp
index b72bcdbad3edd..98992c4cc11f9 100644
--- a/mlir/test/lib/Dialect/Mesh/TestReshardingSpmdization.cpp
+++ b/mlir/test/lib/Dialect/Mesh/TestReshardingSpmdization.cpp
@@ -44,7 +44,10 @@ struct TestMeshReshardingRewritePattern : OpRewritePattern<ShardOp> {
       if (auto targetShardOp = llvm::dyn_cast<ShardOp>(user)) {
         if (targetShardOp.getAnnotateForUsers() &&
             mesh == symbolTable.lookupNearestSymbolFrom<mesh::MeshOp>(
-                        targetShardOp, cast<ShardingOp>(targetShardOp.getSharding().getDefiningOp()).getMeshAttr())) {
+                        targetShardOp,
+                        cast<ShardingOp>(
+                            targetShardOp.getSharding().getDefiningOp())
+                            .getMeshAttr())) {
           foundUser = true;
           break;
         }
@@ -59,7 +62,9 @@ struct TestMeshReshardingRewritePattern : OpRewritePattern<ShardOp> {
       auto targetShardOp = llvm::dyn_cast<ShardOp>(user);
       if (!targetShardOp || !targetShardOp.getAnnotateForUsers() ||
           symbolTable.lookupNearestSymbolFrom<mesh::MeshOp>(
-              targetShardOp, cast<ShardingOp>(targetShardOp.getSharding().getDefiningOp()).getMeshAttr()) != mesh) {
+              targetShardOp,
+              cast<ShardingOp>(targetShardOp.getSharding().getDefiningOp())
+                  .getMeshAttr()) != mesh) {
         continue;
       }
 

>From 2ddab145342e60cbfc7ac04b81cc359711001e37 Mon Sep 17 00:00:00 2001
From: Frank Schlimbach <frank.schlimbach at intel.com>
Date: Tue, 9 Jul 2024 15:46:15 +0200
Subject: [PATCH 8/8] fixing axes parameter to mesh.update_halo, cleanup

---
 mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td  |  9 +++---
 mlir/lib/Dialect/Mesh/IR/MeshOps.cpp          | 32 +++----------------
 .../Dialect/Mesh/Transforms/Spmdization.cpp   |  9 ++++--
 .../Tensor/IR/ShardingInterfaceImpl.cpp       |  2 +-
 mlir/test/Dialect/Mesh/spmdization.mlir       |  8 ++---
 .../test/Dialect/Tensor/mesh-spmdization.mlir |  4 +++
 6 files changed, 25 insertions(+), 39 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td
index b707015cc3684..49c4037942f6f 100644
--- a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td
+++ b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td
@@ -1059,9 +1059,10 @@ def Mesh_UpdateHaloOp : Mesh_CollectiveCommunicationOpBase<"update_halo", [
     AllShapesMatch<["input", "result"]>,
     AllElementTypesMatch<["input", "result"]>
   ]> {
-  let summary = "Send over a device mesh.";
+  let summary = "Update halo data.";
   let description = [{
-    Send from one device to another within a device group.
+    Assume all devices hold tensors with same-sized halo data.
+    Optionally resize to new halo sizes.
   }];
   let arguments = !con(commonArgs, (ins
     AnyNon0RankedTensor:$input,
@@ -1072,11 +1073,11 @@ def Mesh_UpdateHaloOp : Mesh_CollectiveCommunicationOpBase<"update_halo", [
     AnyRankedTensor:$result
   );
   let assemblyFormat = [{
-    $input `on` $mesh (`mesh_axes` `=` $mesh_axes^)?
+    $input `on` $mesh
+    (`mesh_axes` `=` $mesh_axes^)?
     `halo_sizes` `=` $dynamic_halo_sizes
     (`target_halo_sizes` `=` $target_halo_sizes^)?
     attr-dict `:` functional-type(operands, results)
   }];
-  // let hasCanonicalizer = 1;
 }
 #endif // MLIR_DIALECT_MESH_IR_MESHOPS_TD
diff --git a/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp b/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
index e39f267f66b11..a00d49886e606 100644
--- a/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
+++ b/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
@@ -170,6 +170,9 @@ static void shardShape(const InShape &inShape, const MeshShape &meshShape,
                  inShape[inAxis] == ShapedType::kDynamic);
         }
       } else {
+        // find sharded dims in sharded_dims_sizes with same static size on
+        // all devices. Use kDynamic for dimensions with dynamic or different
+        // sizes in sharded_dims_sizes.
         auto sz = shardedDims[tensorAxis];
         bool same = true;
         for (size_t i = tensorAxis + inShape.size(); i < shardedDims.size();
@@ -179,11 +182,7 @@ static void shardShape(const InShape &inShape, const MeshShape &meshShape,
             break;
           }
         }
-        if (same) {
-          outShape[tensorAxis] = sz;
-        } else {
-          outShape[tensorAxis] = ShapedType::kDynamic;
-        }
+        outShape[tensorAxis] = same ? sz : ShapedType::kDynamic;
       }
     }
   } else {
@@ -194,6 +193,7 @@ static void shardShape(const InShape &inShape, const MeshShape &meshShape,
     }
 
     if (!haloSizes.empty()) {
+      // add halo sizes if requested
       int haloAxis = 0;
       for (auto [tensorAxis, innerSplitAxes] : llvm::enumerate(splitAxes)) {
         if (!ShapedType::isDynamic(outShape[tensorAxis]) &&
@@ -395,10 +395,6 @@ void ShardingOp::build(::mlir::OpBuilder &b, ::mlir::OperationState &odsState,
                        mesh::ReductionKind partial_type,
                        ArrayRef<int64_t> static_halo_sizes,
                        ArrayRef<int64_t> static_sharded_dims_sizes) {
-  // SmallVector<MeshAxesAttr> splitAxesAttr = llvm::map_to_vector(
-  //             split_axes, [&](ArrayRef<MeshAxis> array) {
-  //     return MeshAxesAttr::get(b.getContext(), array);
-  // });
   return build(
       b, odsState, mesh, MeshAxesArrayAttr::get(b.getContext(), split_axes),
       ::mlir::DenseI16ArrayAttr::get(b.getContext(), partial_axes),
@@ -457,24 +453,6 @@ void ShardingOp::build(::mlir::OpBuilder &b, ::mlir::OperationState &odsState,
 // MeshSharding
 //===----------------------------------------------------------------------===//
 
-// ::mlir::FlatSymbolRefAttr MeshSharding::getMeshAttr() const { return mesh; }
-// ::llvm::StringRef MeshSharding::getMesh() const { return mesh.getValue(); }
-// ArrayRef<MeshAxesAttr> MeshSharding::getSplitAxes() const {return split_axes;
-// } ArrayRef<MeshAxis> MeshSharding::getPartialAxes() const { if
-// (partial_axes.empty()) return {}; return partial_axes; } ReductionKind
-// MeshSharding::getPartialType() const { return partial_type; }
-// ArrayRef<int64_t> MeshSharding::getStaticHaloSizes() const {
-// if(static_halo_sizes.empty()) return {}; return static_halo_sizes; }
-// ArrayRef<int64_t> MeshSharding::getStaticShardedDimsSizes() const {
-// if(static_sharded_dims_sizes.empty()) return {}; return
-// static_sharded_dims_sizes; } ArrayRef<Value>
-// MeshSharding::getDynamicHaloSizes() const { if(dynamic_halo_sizes.empty())
-// return {}; return dynamic_halo_sizes; } ArrayRef<Value>
-// MeshSharding::getDynamicShardedDimsSizes() const {
-// if(dynamic_sharded_dims_sizes.empty()) return {}; return
-// dynamic_sharded_dims_sizes; } operator MeshSharding::bool() const { return
-// (!mesh) == false; }
-
 bool MeshSharding::sameExceptConstraint(const MeshSharding &rhs) const {
   if (getMesh() != rhs.getMesh()) {
     return false;
diff --git a/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp b/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp
index a53677adca88b..2ff18c23cc416 100644
--- a/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp
+++ b/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp
@@ -437,11 +437,16 @@ updateHalosInResharding(ImplicitLocOpBuilder &builder, MeshOp mesh,
   assert(sourceSharding.getMesh() == targetSharding.getMesh());
   assert(sourceSharding.getSplitAxes() == targetSharding.getSplitAxes());
 
+  SmallVector<MeshAxis> splitMeshAxes;
+  for (auto axis : targetSharding.getSplitAxes()) {
+    assert(axis.size() == 1);
+    splitMeshAxes.emplace_back(axis[0]);
+  }
   auto res =
       builder
           .create<UpdateHaloOp>(
               sourceShard.getType(), // update halo keeps the source type
-              mesh.getSymName(), SmallVector<MeshAxis>(), sourceShard,
+              mesh.getSymName(), splitMeshAxes, sourceShard,
               ::mlir::DenseI64ArrayAttr::get(
                   builder.getContext(), sourceSharding.getStaticHaloSizes()),
               nullptr)
@@ -612,8 +617,6 @@ static LogicalResult spmdizeOperation(
   if (!shardingInterface) {
     // If there is no sharding interface we are conservative and assume that
     // the op should be fully replicated no all devices.
-    // FIXME
-    // spmdizeTriviallyShardableOperation
     spmdizeFullyReplicatedOperation(op, spmdizedOperands, operandShardings,
                                     resultShardings, spmdizationMap,
                                     symbolTableCollection, builder);
diff --git a/mlir/lib/Dialect/Tensor/IR/ShardingInterfaceImpl.cpp b/mlir/lib/Dialect/Tensor/IR/ShardingInterfaceImpl.cpp
index 4a79e983d955d..354015367451e 100644
--- a/mlir/lib/Dialect/Tensor/IR/ShardingInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/ShardingInterfaceImpl.cpp
@@ -22,7 +22,7 @@ using namespace mlir::mesh;
 
 namespace {
 
-// Sharding of elementwise operations like tensor addition and multiplication.
+// Sharding of tensor.empty
 struct EmptyOpShardingInterface
     : public ShardingInterface::ExternalModel<EmptyOpShardingInterface,
                                               tensor::EmptyOp> {
diff --git a/mlir/test/Dialect/Mesh/spmdization.mlir b/mlir/test/Dialect/Mesh/spmdization.mlir
index 0b1c143c25334..751a4ea258431 100644
--- a/mlir/test/Dialect/Mesh/spmdization.mlir
+++ b/mlir/test/Dialect/Mesh/spmdization.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt -allow-unregistered-dialect \
+// RUN: mlir-opt \
 // RUN:   --pass-pipeline="builtin.module(func.func(mesh-spmdization,test-constant-fold))" \
 // RUN:   %s | FileCheck %s
 
@@ -195,7 +195,7 @@ func.func @update_halo_static(
   %in1: tensor<32x16xi8>
   // CHECK-SAME: -> tensor<11x16xi8> {
 ) -> tensor<32x16xi8> {
-  // CHECK: %[[RES:.*]] = mesh.update_halo %[[IN1]] on @mesh_1d_4 halo_sizes = [2, 1] : (tensor<11x16xi8>) -> tensor<11x16xi8>
+  // CHECK: %[[RES:.*]] = mesh.update_halo %[[IN1]] on @mesh_1d_4  mesh_axes = [0] halo_sizes = [2, 1] : (tensor<11x16xi8>) -> tensor<11x16xi8>
   %sin1_sharded1 = mesh.sharding @mesh_1d_4, [[0]] halo_sizes = [2, 1] : !mesh.sharding
   %in1_sharded1 = mesh.shard %in1 to %sin1_sharded1  : tensor<32x16xi8>
   %in1_sharded2 = mesh.shard %in1_sharded1 to %sin1_sharded1 annotate_for_users force : tensor<32x16xi8>
@@ -209,7 +209,7 @@ func.func @update_halo_dynamic(
   %in1: tensor<?x16xi8>
   // CHECK-SAME: -> tensor<?x16xi8> {
 ) -> tensor<?x16xi8> {
-  // CHECK: %[[RES:.*]] = mesh.update_halo %[[IN1]] on @mesh_1d_4 halo_sizes = [2, 1] : (tensor<?x16xi8>) -> tensor<?x16xi8>
+  // CHECK: %[[RES:.*]] = mesh.update_halo %[[IN1]] on @mesh_1d_4  mesh_axes = [0] halo_sizes = [2, 1] : (tensor<?x16xi8>) -> tensor<?x16xi8>
   %sin1_sharded1 = mesh.sharding @mesh_1d_4, [[0]] halo_sizes = [2, 1] : !mesh.sharding
   %in1_sharded1 = mesh.shard %in1 to %sin1_sharded1  : tensor<?x16xi8>
   %in1_sharded2 = mesh.shard %in1_sharded1 to %sin1_sharded1 annotate_for_users force : tensor<?x16xi8>
@@ -224,7 +224,7 @@ func.func @update_halo_dynamic_mesh(
   %in1: tensor<32x16xi8>
   // CHECK-SAME: -> tensor<?x16xi8> {
 ) -> tensor<32x16xi8> {
-  // CHECK: %[[RES:.*]] = mesh.update_halo %[[IN1]] on @mesh_1d_dyn halo_sizes = [2, 1] : (tensor<?x16xi8>) -> tensor<?x16xi8>
+  // CHECK: %[[RES:.*]] = mesh.update_halo %[[IN1]] on @mesh_1d_dyn  mesh_axes = [0] halo_sizes = [2, 1] : (tensor<?x16xi8>) -> tensor<?x16xi8>
   %sin1_sharded1 = mesh.sharding @mesh_1d_dyn, [[0]] halo_sizes = [2, 1] : !mesh.sharding
   %in1_sharded1 = mesh.shard %in1 to %sin1_sharded1  : tensor<32x16xi8>
   %in1_sharded2 = mesh.shard %in1_sharded1 to %sin1_sharded1 annotate_for_users force : tensor<32x16xi8>
diff --git a/mlir/test/Dialect/Tensor/mesh-spmdization.mlir b/mlir/test/Dialect/Tensor/mesh-spmdization.mlir
index b1fd8e6a423b1..0e2ed5b466ae1 100644
--- a/mlir/test/Dialect/Tensor/mesh-spmdization.mlir
+++ b/mlir/test/Dialect/Tensor/mesh-spmdization.mlir
@@ -1,3 +1,7 @@
+// RUN: mlir-opt \
+// RUN:   --pass-pipeline="builtin.module(func.func(mesh-spmdization,test-constant-fold))" \
+// RUN:   %s | FileCheck %s
+
 mesh.mesh @mesh_1d_4(shape = 4)
 
 // CHECK-LABEL: func @tensor_empty_static_sharded_dims_sizes



More information about the Mlir-commits mailing list