[Mlir-commits] [llvm] [mlir] [NFC][mlir][mesh, shard] Fixing misnomers in mesh dialect, renaming 'mesh' dialect to 'shard' (PR #150177)

Frank Schlimbach llvmlistbot at llvm.org
Thu Jul 24 02:30:52 PDT 2025


https://github.com/fschlimb updated https://github.com/llvm/llvm-project/pull/150177

From 48c6790984ca58630304bde4a32009263952858e Mon Sep 17 00:00:00 2001
From: "Schlimbach, Frank" <frank.schlimbach at intel.com>
Date: Mon, 21 Jul 2025 15:00:41 +0200
Subject: [PATCH 1/2] Fixing misnomers in mesh dialect

  - dialect name mesh -> shard
  - (device) mesh -> (device) grid
  - spmdize -> partition
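
For illustration only (this sketch is not part of the diff below), the rename
amounts to the following kind of change at the IR level. The ops and names are
assembled from the documentation examples updated in this patch, with the
`!shard.sharding` result type written out explicitly:

```mlir
// Before: 'mesh' dialect, device mesh
mesh.mesh @mesh0(shape = 4x8x12)
%s = mesh.sharding @mesh0 split_axes = [[0]] : !mesh.sharding
%v = mesh.shard %arg0 to %s : tensor<4x8xf32>

// After: 'shard' dialect, device grid
shard.grid @grid0(shape = 4x8x12)
%s = shard.sharding @grid0 split_axes = [[0]] : !shard.sharding
%v = shard.shard %arg0 to %s : tensor<4x8xf32>
```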
---
 mlir/docs/Dialects/{Mesh.md => Shard.md}      |  34 +-
 mlir/docs/Passes.md                           |   4 +-
 mlir/include/mlir/Conversion/Passes.h         |   2 +-
 mlir/include/mlir/Conversion/Passes.td        |   8 +-
 .../MeshToMPI.h => ShardToMPI/ShardToMPI.h}   |  10 +-
 mlir/include/mlir/Dialect/CMakeLists.txt      |   2 +-
 ...rdingExtensions.h => ShardingExtensions.h} |   2 +-
 ...nterfaceImpl.h => ShardingInterfaceImpl.h} |  10 +-
 .../mlir/Dialect/Mesh/IR/CMakeLists.txt       |  25 -
 .../Dialect/Mesh/Transforms/CMakeLists.txt    |   6 -
 .../Dialect/{Mesh => Shard}/CMakeLists.txt    |   0
 .../mlir/Dialect/Shard/IR/CMakeLists.txt      |  25 +
 .../IR/MeshBase.td => Shard/IR/ShardBase.td}  |  46 +-
 .../MeshDialect.h => Shard/IR/ShardDialect.h} |  10 +-
 .../IR/MeshOps.h => Shard/IR/ShardOps.h}      | 146 +++--
 .../IR/MeshOps.td => Shard/IR/ShardOps.td}    | 420 +++++++-------
 .../{Mesh => Shard}/Interfaces/CMakeLists.txt |   0
 .../Interfaces/ShardingInterface.h            |  58 +-
 .../Interfaces/ShardingInterface.td           |  48 +-
 .../Interfaces/ShardingInterfaceImpl.h        |  79 ++-
 .../Dialect/Shard/Transforms/CMakeLists.txt   |   6 +
 .../Transforms/Partition.h}                   |  26 +-
 .../{Mesh => Shard}/Transforms/Passes.h       |  16 +-
 .../{Mesh => Shard}/Transforms/Passes.td      |  48 +-
 .../Transforms/ReshardingPartitionDoc.md}     | 144 ++---
 .../Transforms/Simplifications.h              |  14 +-
 .../{Mesh => Shard}/Transforms/Transforms.h   |  28 +-
 ...rdingExtensions.h => ShardingExtensions.h} |   2 +-
 mlir/include/mlir/InitAllDialects.h           |   4 +-
 mlir/include/mlir/InitAllPasses.h             |   4 +-
 mlir/lib/Conversion/CMakeLists.txt            |   2 +-
 .../{MeshToMPI => ShardToMPI}/CMakeLists.txt  |   8 +-
 .../ShardToMPI.cpp}                           | 118 ++--
 .../Dialect/Arith/Transforms/CMakeLists.txt   |   2 +-
 .../Transforms/ShardingInterfaceImpl.cpp      |  36 +-
 mlir/lib/Dialect/CMakeLists.txt               |   2 +-
 .../Dialect/Func/Extensions/AllExtensions.cpp |   2 +-
 .../Dialect/Func/Extensions/CMakeLists.txt    |   8 +-
 ...gExtensions.cpp => ShardingExtensions.cpp} |   8 +-
 mlir/lib/Dialect/Linalg/IR/LinalgDialect.cpp  |   6 +-
 .../Linalg/Transforms/AllInterfaces.cpp       |   4 +-
 .../Dialect/Linalg/Transforms/CMakeLists.txt  |   4 +-
 ...faceImpl.cpp => ShardingInterfaceImpl.cpp} | 167 +++---
 .../Dialect/{Mesh => Shard}/CMakeLists.txt    |   0
 .../Dialect/{Mesh => Shard}/IR/CMakeLists.txt |   8 +-
 .../IR/MeshOps.cpp => Shard/IR/ShardOps.cpp}  | 526 +++++++++---------
 .../{Mesh => Shard}/Interfaces/CMakeLists.txt |   4 +-
 .../Interfaces/ShardingInterface.cpp          | 270 +++++----
 .../{Mesh => Shard}/Transforms/CMakeLists.txt |  10 +-
 .../Transforms/Partition.cpp}                 | 393 +++++++------
 .../Transforms/ShardingPropagation.cpp        |  79 ++-
 .../Transforms/Simplifications.cpp            |  64 +--
 .../{Mesh => Shard}/Transforms/Transforms.cpp |  82 +--
 .../Transforms/TransformsDetail.h             |  10 +-
 .../Tensor/Extensions/AllExtensions.cpp       |   2 +-
 .../Dialect/Tensor/Extensions/CMakeLists.txt  |   8 +-
 ...gExtensions.cpp => ShardingExtensions.cpp} |  38 +-
 mlir/lib/Dialect/Tosa/CMakeLists.txt          |   2 +-
 .../Dialect/Tosa/IR/ShardingInterfaceImpl.cpp |  26 +-
 mlir/lib/Dialect/Tosa/IR/TosaOps.cpp          |   4 +-
 .../convert-shard-to-mpi.mlir}                |  66 +--
 .../convert-shardshape-to-mpi.mlir            |  40 +-
 mlir/test/Dialect/Arith/mesh-spmdize.mlir     |  17 -
 mlir/test/Dialect/Arith/shard-partition.mlir  |  17 +
 .../Dialect/Arith/sharding-propagation.mlir   |  60 +-
 .../Linalg/mesh-sharding-propagation.mlir     |  42 --
 ...-spmdization.mlir => shard-partition.mlir} | 118 ++--
 .../Dialect/Linalg/sharding-propagation.mlir  |  42 ++
 mlir/test/Dialect/Mesh/canonicalization.mlir  | 248 ---------
 mlir/test/Dialect/Mesh/folding.mlir           |  22 -
 .../Mesh/forward-sharding-propagation.mlir    |  49 --
 mlir/test/Dialect/Mesh/inlining.mlir          |  15 -
 .../Mesh/process-multi-index-op-lowering.mlir |  23 -
 .../Dialect/Mesh/resharding-spmdization.mlir  | 168 ------
 .../Dialect/Mesh/sharding-propagation.mlir    | 301 ----------
 mlir/test/Dialect/Mesh/spmdization.mlir       | 317 -----------
 .../all-scatter-op-lowering.mlir              |  40 +-
 .../backward-sharding-propagation.mlir        |  12 +-
 mlir/test/Dialect/Shard/canonicalization.mlir | 248 +++++++++
 mlir/test/Dialect/Shard/folding.mlir          |  22 +
 ...forward-backward-sharding-propagation.mlir |  14 +-
 .../Shard/forward-sharding-propagation.mlir   |  49 ++
 mlir/test/Dialect/Shard/inlining.mlir         |  15 +
 .../test/Dialect/{Mesh => Shard}/invalid.mlir | 442 +++++++--------
 mlir/test/Dialect/{Mesh => Shard}/ops.mlir    | 350 ++++++------
 mlir/test/Dialect/Shard/partition.mlir        | 317 +++++++++++
 .../process-multi-index-op-lowering.mlir      |  23 +
 .../Dialect/Shard/resharding-partition.mlir   | 168 ++++++
 .../sharding-propagation-failed.mlir          |   0
 .../Dialect/Shard/sharding-propagation.mlir   | 301 ++++++++++
 .../{Mesh => Shard}/simplifications.mlir      |  78 +--
 .../test/Dialect/Tensor/mesh-spmdization.mlir |  52 --
 mlir/test/Dialect/Tensor/shard-partition.mlir |  52 ++
 mlir/test/lib/Dialect/CMakeLists.txt          |   2 +-
 .../Dialect/{Mesh => Shard}/CMakeLists.txt    |  10 +-
 .../{Mesh => Shard}/TestOpLowering.cpp        |  18 +-
 .../TestReshardingPartition.cpp}              |  40 +-
 .../{Mesh => Shard}/TestSimplifications.cpp   |  24 +-
 mlir/tools/mlir-opt/CMakeLists.txt            |   2 +-
 mlir/tools/mlir-opt/mlir-opt.cpp              |   8 +-
 .../llvm-project-overlay/mlir/BUILD.bazel     | 146 ++---
 .../mlir/test/BUILD.bazel                     |   8 +-
 102 files changed, 3542 insertions(+), 3564 deletions(-)
 rename mlir/docs/Dialects/{Mesh.md => Shard.md} (73%)
 rename mlir/include/mlir/Conversion/{MeshToMPI/MeshToMPI.h => ShardToMPI/ShardToMPI.h} (64%)
 rename mlir/include/mlir/Dialect/Func/Extensions/{MeshShardingExtensions.h => ShardingExtensions.h} (88%)
 rename mlir/include/mlir/Dialect/Linalg/Transforms/{MeshShardingInterfaceImpl.h => ShardingInterfaceImpl.h} (54%)
 delete mode 100644 mlir/include/mlir/Dialect/Mesh/IR/CMakeLists.txt
 delete mode 100644 mlir/include/mlir/Dialect/Mesh/Transforms/CMakeLists.txt
 rename mlir/include/mlir/Dialect/{Mesh => Shard}/CMakeLists.txt (100%)
 create mode 100644 mlir/include/mlir/Dialect/Shard/IR/CMakeLists.txt
 rename mlir/include/mlir/Dialect/{Mesh/IR/MeshBase.td => Shard/IR/ShardBase.td} (64%)
 rename mlir/include/mlir/Dialect/{Mesh/IR/MeshDialect.h => Shard/IR/ShardDialect.h} (57%)
 rename mlir/include/mlir/Dialect/{Mesh/IR/MeshOps.h => Shard/IR/ShardOps.h} (52%)
 rename mlir/include/mlir/Dialect/{Mesh/IR/MeshOps.td => Shard/IR/ShardOps.td} (70%)
 rename mlir/include/mlir/Dialect/{Mesh => Shard}/Interfaces/CMakeLists.txt (100%)
 rename mlir/include/mlir/Dialect/{Mesh => Shard}/Interfaces/ShardingInterface.h (52%)
 rename mlir/include/mlir/Dialect/{Mesh => Shard}/Interfaces/ShardingInterface.td (80%)
 rename mlir/include/mlir/Dialect/{Mesh => Shard}/Interfaces/ShardingInterfaceImpl.h (58%)
 create mode 100644 mlir/include/mlir/Dialect/Shard/Transforms/CMakeLists.txt
 rename mlir/include/mlir/Dialect/{Mesh/Transforms/Spmdization.h => Shard/Transforms/Partition.h} (61%)
 rename mlir/include/mlir/Dialect/{Mesh => Shard}/Transforms/Passes.h (75%)
 rename mlir/include/mlir/Dialect/{Mesh => Shard}/Transforms/Passes.td (65%)
 rename mlir/include/mlir/Dialect/{Mesh/Transforms/ReshardingSpmdizationDoc.md => Shard/Transforms/ReshardingPartitionDoc.md} (87%)
 rename mlir/include/mlir/Dialect/{Mesh => Shard}/Transforms/Simplifications.h (93%)
 rename mlir/include/mlir/Dialect/{Mesh => Shard}/Transforms/Transforms.h (65%)
 rename mlir/include/mlir/Dialect/Tensor/Extensions/{MeshShardingExtensions.h => ShardingExtensions.h} (88%)
 rename mlir/lib/Conversion/{MeshToMPI => ShardToMPI}/CMakeLists.txt (65%)
 rename mlir/lib/Conversion/{MeshToMPI/MeshToMPI.cpp => ShardToMPI/ShardToMPI.cpp} (92%)
 rename mlir/lib/Dialect/Func/Extensions/{MeshShardingExtensions.cpp => ShardingExtensions.cpp} (68%)
 rename mlir/lib/Dialect/Linalg/Transforms/{MeshShardingInterfaceImpl.cpp => ShardingInterfaceImpl.cpp} (65%)
 rename mlir/lib/Dialect/{Mesh => Shard}/CMakeLists.txt (100%)
 rename mlir/lib/Dialect/{Mesh => Shard}/IR/CMakeLists.txt (59%)
 rename mlir/lib/Dialect/{Mesh/IR/MeshOps.cpp => Shard/IR/ShardOps.cpp} (76%)
 rename mlir/lib/Dialect/{Mesh => Shard}/Interfaces/CMakeLists.txt (76%)
 rename mlir/lib/Dialect/{Mesh => Shard}/Interfaces/ShardingInterface.cpp (70%)
 rename mlir/lib/Dialect/{Mesh => Shard}/Transforms/CMakeLists.txt (73%)
 rename mlir/lib/Dialect/{Mesh/Transforms/Spmdization.cpp => Shard/Transforms/Partition.cpp} (66%)
 rename mlir/lib/Dialect/{Mesh => Shard}/Transforms/ShardingPropagation.cpp (85%)
 rename mlir/lib/Dialect/{Mesh => Shard}/Transforms/Simplifications.cpp (66%)
 rename mlir/lib/Dialect/{Mesh => Shard}/Transforms/Transforms.cpp (78%)
 rename mlir/lib/Dialect/{Mesh => Shard}/Transforms/TransformsDetail.h (82%)
 rename mlir/lib/Dialect/Tensor/Extensions/{MeshShardingExtensions.cpp => ShardingExtensions.cpp} (74%)
 rename mlir/test/Conversion/{MeshToMPI/convert-mesh-to-mpi.mlir => ShardToMPI/convert-shard-to-mpi.mlir} (90%)
 rename mlir/test/Conversion/{MeshToMPI => ShardToMPI}/convert-shardshape-to-mpi.mlir (62%)
 delete mode 100644 mlir/test/Dialect/Arith/mesh-spmdize.mlir
 create mode 100644 mlir/test/Dialect/Arith/shard-partition.mlir
 delete mode 100644 mlir/test/Dialect/Linalg/mesh-sharding-propagation.mlir
 rename mlir/test/Dialect/Linalg/{mesh-spmdization.mlir => shard-partition.mlir} (50%)
 create mode 100644 mlir/test/Dialect/Linalg/sharding-propagation.mlir
 delete mode 100644 mlir/test/Dialect/Mesh/canonicalization.mlir
 delete mode 100644 mlir/test/Dialect/Mesh/folding.mlir
 delete mode 100644 mlir/test/Dialect/Mesh/forward-sharding-propagation.mlir
 delete mode 100644 mlir/test/Dialect/Mesh/inlining.mlir
 delete mode 100644 mlir/test/Dialect/Mesh/process-multi-index-op-lowering.mlir
 delete mode 100644 mlir/test/Dialect/Mesh/resharding-spmdization.mlir
 delete mode 100644 mlir/test/Dialect/Mesh/sharding-propagation.mlir
 delete mode 100644 mlir/test/Dialect/Mesh/spmdization.mlir
 rename mlir/test/Dialect/{Mesh => Shard}/all-scatter-op-lowering.mlir (72%)
 rename mlir/test/Dialect/{Mesh => Shard}/backward-sharding-propagation.mlir (67%)
 create mode 100644 mlir/test/Dialect/Shard/canonicalization.mlir
 create mode 100644 mlir/test/Dialect/Shard/folding.mlir
 rename mlir/test/Dialect/{Mesh => Shard}/forward-backward-sharding-propagation.mlir (63%)
 create mode 100644 mlir/test/Dialect/Shard/forward-sharding-propagation.mlir
 create mode 100644 mlir/test/Dialect/Shard/inlining.mlir
 rename mlir/test/Dialect/{Mesh => Shard}/invalid.mlir (57%)
 rename mlir/test/Dialect/{Mesh => Shard}/ops.mlir (55%)
 create mode 100644 mlir/test/Dialect/Shard/partition.mlir
 create mode 100644 mlir/test/Dialect/Shard/process-multi-index-op-lowering.mlir
 create mode 100644 mlir/test/Dialect/Shard/resharding-partition.mlir
 rename mlir/test/Dialect/{Mesh => Shard}/sharding-propagation-failed.mlir (100%)
 create mode 100644 mlir/test/Dialect/Shard/sharding-propagation.mlir
 rename mlir/test/Dialect/{Mesh => Shard}/simplifications.mlir (69%)
 delete mode 100644 mlir/test/Dialect/Tensor/mesh-spmdization.mlir
 create mode 100644 mlir/test/Dialect/Tensor/shard-partition.mlir
 rename mlir/test/lib/Dialect/{Mesh => Shard}/CMakeLists.txt (51%)
 rename mlir/test/lib/Dialect/{Mesh => Shard}/TestOpLowering.cpp (80%)
 rename mlir/test/lib/Dialect/{Mesh/TestReshardingSpmdization.cpp => Shard/TestReshardingPartition.cpp} (75%)
 rename mlir/test/lib/Dialect/{Mesh => Shard}/TestSimplifications.cpp (60%)

diff --git a/mlir/docs/Dialects/Mesh.md b/mlir/docs/Dialects/Shard.md
similarity index 73%
rename from mlir/docs/Dialects/Mesh.md
rename to mlir/docs/Dialects/Shard.md
index 5eb6569c7044b..714b340db4cde 100644
--- a/mlir/docs/Dialects/Mesh.md
+++ b/mlir/docs/Dialects/Shard.md
@@ -1,28 +1,28 @@
-# 'mesh' Dialect
+# 'shard' Dialect
 
-The `mesh` dialect contains a set of attributes, operations and interfaces that
-are useful for representing sharding and communication on a device mesh
+The `shard` dialect contains a set of attributes, operations and interfaces that
+are useful for representing sharding and communication on a device grid
 cluster.
 
 [TOC]
 
 ## Collective Communication Operations
-There are a number of operations in the Mesh dialect to facilitate
-communication between devices in a mesh.
+There are a number of operations in the Shard dialect to facilitate
+communication between devices in a grid.
 It is assumed that the user is familiar with collective operations.
 [Wikipedia](https://en.wikipedia.org/wiki/Collective_operation) has a good
 explanation.
-The main addition is that the collectives in this dialect have mesh
+The main addition is that the collectives in this dialect have grid
 semantics.
 
 ### Device groups
-The operation attributes `mesh` and `mesh_axes` specifies a list of device mesh
+The operation attributes `grid` and `grid_axes` specify a list of device grid
 axes that partition the devices into disjoint groups.
 The collective operation is performed between devices in the same group.
-Devices that have the same coordinates outside of axes `mesh_axes` are in the
+Devices that have the same coordinates outside of axes `grid_axes` are in the
 same group.
-A group is described by its multi-index along the axes outside of `mesh_axes`.
-For example if we have a device mesh of size `2x3x4x5` and the partition mesh
+A group is described by its multi-index along the axes outside of `grid_axes`.
+For example if we have a device grid of size `2x3x4x5` and the partition grid
 axes list is `[0, 1]` then devices are partitioned into the groups
 `{ { (i, j, k, m) | 0<=i<2, 0<=j<3 } | 0<=k<4, 0<=m<5 }`.
 The device groups would be `{ (k, m) | 0<=k<4, 0<=m<5 }`.
@@ -31,7 +31,7 @@ Device (1, 0, 2, 4) will be in another group.
 Some collective operations like all-to-all and all-gather care about the
 order of devices.
 The order of device in a device group is induced by the order of axes in
-`mesh_axes`.
+`grid_axes`.
 The axes are ordered from outer to inner.
 If we have an axis list `[3, 1]` then device `(i, 1, k, 0)` will precede
 both devices `(i, 0, k, 1)` and `(i, 2, k, 0)`.
@@ -39,11 +39,11 @@ both devices `(i, 0, k, 1)` and `(i, 2, k, 0)`.
 ### In-group Device
 Some operations like `broadcast`, `scatter` and `send` specify devices in each
 device-group.
-These devices are represented with their multi-index over the mesh axes that
+These devices are represented with their multi-index over the grid axes that
 are not constant within a device group.
-These are the axes specified by `mesh_axes` attribute.
+These are the axes specified by the `grid_axes` attribute.
 
-For Example on a 3D mesh an operation with `mesh_axes = [0, 2]` would specify
+For example, on a 3D grid an operation with `grid_axes = [0, 2]` would specify
 an in-group device with `(i, j)`. Then for each group with index `g` on the
 second axis, the in-group device would be `(i, g, j)`.
 ### Purity
@@ -60,15 +60,15 @@ For example if a collective operation is optimized out, than it must also
 not appear in any path of execution on any process.
 
 Having the operations as `Pure` implies that if an interpreter is to execute
-the IR containing the `mesh` collectives, all processes would execute the same
+the IR containing the `shard` collectives, all processes would execute the same
 line when they reach a pure collective operation.
 This requirement stems from the need to be compatible with general optimization
 passes like dead code and common sub-expression elimination.
 
 ## Operations
 
-[include "Dialects/MeshOps.md"]
+[include "Dialects/ShardOps.md"]
 
 ## Attributes
 
-[include "Dialects/MeshAttrs.md"]
+[include "Dialects/ShardAttrs.md"]
diff --git a/mlir/docs/Passes.md b/mlir/docs/Passes.md
index e9d22d1e3dfac..9df32666415bb 100644
--- a/mlir/docs/Passes.md
+++ b/mlir/docs/Passes.md
@@ -72,9 +72,9 @@ This document describes the available MLIR passes and their contracts.
 
 [include "MemRefPasses.md"]
 
-## 'mesh' Dialect Passes
+## 'shard' Dialect Passes
 
-[include "MeshPasses.md"]
+[include "ShardPasses.md"]
 
 ## 'ml\_program' Dialect Passes
 
diff --git a/mlir/include/mlir/Conversion/Passes.h b/mlir/include/mlir/Conversion/Passes.h
index d93fbefab74aa..3dc48b2201cf2 100644
--- a/mlir/include/mlir/Conversion/Passes.h
+++ b/mlir/include/mlir/Conversion/Passes.h
@@ -52,7 +52,6 @@
 #include "mlir/Conversion/MemRefToEmitC/MemRefToEmitCPass.h"
 #include "mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h"
 #include "mlir/Conversion/MemRefToSPIRV/MemRefToSPIRVPass.h"
-#include "mlir/Conversion/MeshToMPI/MeshToMPI.h"
 #include "mlir/Conversion/NVGPUToNVVM/NVGPUToNVVM.h"
 #include "mlir/Conversion/NVVMToLLVM/NVVMToLLVM.h"
 #include "mlir/Conversion/OpenACCToSCF/ConvertOpenACCToSCF.h"
@@ -66,6 +65,7 @@
 #include "mlir/Conversion/SCFToSPIRV/SCFToSPIRVPass.h"
 #include "mlir/Conversion/SPIRVToLLVM/SPIRVToLLVMPass.h"
 #include "mlir/Conversion/ShapeToStandard/ShapeToStandard.h"
+#include "mlir/Conversion/ShardToMPI/ShardToMPI.h"
 #include "mlir/Conversion/TensorToLinalg/TensorToLinalgPass.h"
 #include "mlir/Conversion/TensorToSPIRV/TensorToSPIRVPass.h"
 #include "mlir/Conversion/TosaToArith/TosaToArith.h"
diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td
index 8183f355795a9..eb18160ea2eeb 100644
--- a/mlir/include/mlir/Conversion/Passes.td
+++ b/mlir/include/mlir/Conversion/Passes.td
@@ -903,13 +903,13 @@ def ConvertMemRefToSPIRVPass : Pass<"convert-memref-to-spirv"> {
 }
 
 //===----------------------------------------------------------------------===//
-// MeshToMPI
+// ShardToMPI
 //===----------------------------------------------------------------------===//
 
-def ConvertMeshToMPIPass : Pass<"convert-mesh-to-mpi"> {
-  let summary = "Convert Mesh dialect to MPI dialect.";
+def ConvertShardToMPIPass : Pass<"convert-shard-to-mpi"> {
+  let summary = "Convert Shard dialect to MPI dialect.";
   let description = [{
-    This pass converts communication operations from the Mesh dialect to the
+    This pass converts communication operations from the Shard dialect to the
     MPI dialect.
     If it finds the DLTI attribute "MPI:comm_world-rank" on the module it will
     use that integer value instead of calling MPI_Comm_rank. This allows
diff --git a/mlir/include/mlir/Conversion/MeshToMPI/MeshToMPI.h b/mlir/include/mlir/Conversion/ShardToMPI/ShardToMPI.h
similarity index 64%
rename from mlir/include/mlir/Conversion/MeshToMPI/MeshToMPI.h
rename to mlir/include/mlir/Conversion/ShardToMPI/ShardToMPI.h
index bc64e7a3c1c8c..b1aa08c432249 100644
--- a/mlir/include/mlir/Conversion/MeshToMPI/MeshToMPI.h
+++ b/mlir/include/mlir/Conversion/ShardToMPI/ShardToMPI.h
@@ -1,4 +1,4 @@
-//===- MeshToMPI.h - Convert Mesh to MPI dialect ----------------*- C++ -*-===//
+//===- ShardToMPI.h - Convert Shard to MPI dialect --------------*- C++ -*-===//
 //
 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
 // See https://llvm.org/LICENSE.txt for license information.
@@ -6,8 +6,8 @@
 //
 //===----------------------------------------------------------------------===//
 
-#ifndef MLIR_CONVERSION_MESHTOMPI_MESHTOMPI_H
-#define MLIR_CONVERSION_MESHTOMPI_MESHTOMPI_H
+#ifndef MLIR_CONVERSION_SHARDTOMPI_SHARDTOMPI_H
+#define MLIR_CONVERSION_SHARDTOMPI_SHARDTOMPI_H
 
 #include "mlir/Pass/Pass.h"
 #include "mlir/Support/LLVM.h"
@@ -15,9 +15,9 @@
 namespace mlir {
 class Pass;
 
-#define GEN_PASS_DECL_CONVERTMESHTOMPIPASS
+#define GEN_PASS_DECL_CONVERTSHARDTOMPIPASS
 #include "mlir/Conversion/Passes.h.inc"
 
 } // namespace mlir
 
-#endif // MLIR_CONVERSION_MESHTOMPI_MESHTOMPI_H
+#endif // MLIR_CONVERSION_SHARDTOMPI_SHARDTOMPI_H
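
A minimal sketch of how the renamed pass is driven (hypothetical test input;
the `--convert-shard-to-mpi` flag name comes from the
`Pass<"convert-shard-to-mpi">` definition above, and
`shard.process_linear_index` from the op rename later in this patch):

```mlir
// RUN: mlir-opt %s --convert-shard-to-mpi
shard.grid @grid(shape = 4)
func.func @rank() -> index {
  // After conversion this presumably becomes an MPI rank query
  // (see the pass description above).
  %r = shard.process_linear_index on @grid : index
  return %r : index
}
```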
diff --git a/mlir/include/mlir/Dialect/CMakeLists.txt b/mlir/include/mlir/Dialect/CMakeLists.txt
index 56dc97282fa4a..e27b1679c2a52 100644
--- a/mlir/include/mlir/Dialect/CMakeLists.txt
+++ b/mlir/include/mlir/Dialect/CMakeLists.txt
@@ -19,7 +19,7 @@ add_subdirectory(Linalg)
 add_subdirectory(LLVMIR)
 add_subdirectory(Math)
 add_subdirectory(MemRef)
-add_subdirectory(Mesh)
+add_subdirectory(Shard)
 add_subdirectory(MLProgram)
 add_subdirectory(MPI)
 add_subdirectory(NVGPU)
diff --git a/mlir/include/mlir/Dialect/Func/Extensions/MeshShardingExtensions.h b/mlir/include/mlir/Dialect/Func/Extensions/ShardingExtensions.h
similarity index 88%
rename from mlir/include/mlir/Dialect/Func/Extensions/MeshShardingExtensions.h
rename to mlir/include/mlir/Dialect/Func/Extensions/ShardingExtensions.h
index 30d3033209d21..e22b24b3446bb 100644
--- a/mlir/include/mlir/Dialect/Func/Extensions/MeshShardingExtensions.h
+++ b/mlir/include/mlir/Dialect/Func/Extensions/ShardingExtensions.h
@@ -1,4 +1,4 @@
-//===- MeshShardingExtensions.h - -----------------------------------------===//
+//===- ShardingExtensions.h - -----------------------------------------===//
 //
 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
 // See https://llvm.org/LICENSE.txt for license information.
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/MeshShardingInterfaceImpl.h b/mlir/include/mlir/Dialect/Linalg/Transforms/ShardingInterfaceImpl.h
similarity index 54%
rename from mlir/include/mlir/Dialect/Linalg/Transforms/MeshShardingInterfaceImpl.h
rename to mlir/include/mlir/Dialect/Linalg/Transforms/ShardingInterfaceImpl.h
index c57501ea86b7e..dc21bc05a2dc1 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/MeshShardingInterfaceImpl.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/ShardingInterfaceImpl.h
@@ -1,4 +1,4 @@
-//===- MeshShardingInterfaceImpl.h ----------------------------------------===//
+//===- ShardingInterfaceImpl.h ----------------------------------------===//
 //
 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
 // See https://llvm.org/LICENSE.txt for license information.
@@ -6,15 +6,15 @@
 //
 //===----------------------------------------------------------------------===//
 
-#ifndef MLIR_DIALECT_LINALG_MESHSHARDINGINTERFACEIMPL_H
-#define MLIR_DIALECT_LINALG_MESHSHARDINGINTERFACEIMPL_H
+#ifndef MLIR_DIALECT_LINALG_SHARDSHARDINGINTERFACEIMPL_H
+#define MLIR_DIALECT_LINALG_SHARDSHARDINGINTERFACEIMPL_H
 
 namespace mlir {
 class DialectRegistry;
 
 namespace linalg {
-void registerMeshShardingInterfaceExternalModels(DialectRegistry &registry);
+void registerShardingInterfaceExternalModels(DialectRegistry &registry);
 } // namespace linalg
 } // namespace mlir
 
-#endif // MLIR_DIALECT_LINALG_MESHSHARDINGINTERFACEIMPL_H
+#endif // MLIR_DIALECT_LINALG_SHARDSHARDINGINTERFACEIMPL_H
diff --git a/mlir/include/mlir/Dialect/Mesh/IR/CMakeLists.txt b/mlir/include/mlir/Dialect/Mesh/IR/CMakeLists.txt
deleted file mode 100644
index f26c6285efd89..0000000000000
--- a/mlir/include/mlir/Dialect/Mesh/IR/CMakeLists.txt
+++ /dev/null
@@ -1,25 +0,0 @@
-add_mlir_doc(MeshOps MeshOps Dialects/ -gen-op-doc -dialect=mesh)
-add_mlir_doc(MeshOps MeshAttrs Dialects/ -gen-attrdef-doc -dialect=mesh)
-
-set(LLVM_TARGET_DEFINITIONS MeshOps.td)
-mlir_tablegen(MeshDialect.cpp.inc -gen-dialect-defs -dialect=mesh)
-mlir_tablegen(MeshDialect.h.inc -gen-dialect-decls -dialect=mesh)
-
-set(LLVM_TARGET_DEFINITIONS MeshBase.td)
-mlir_tablegen(MeshAttributes.h.inc -gen-attrdef-decls)
-mlir_tablegen(MeshAttributes.cpp.inc -gen-attrdef-defs)
-
-set(LLVM_TARGET_DEFINITIONS MeshBase.td)
-mlir_tablegen(MeshEnums.h.inc -gen-enum-decls)
-mlir_tablegen(MeshEnums.cpp.inc -gen-enum-defs)
-
-set(LLVM_TARGET_DEFINITIONS MeshBase.td)
-mlir_tablegen(MeshTypes.h.inc -gen-typedef-decls)
-mlir_tablegen(MeshTypes.cpp.inc -gen-typedef-defs)
-
-set(LLVM_TARGET_DEFINITIONS MeshOps.td)
-mlir_tablegen(MeshOps.h.inc -gen-op-decls)
-mlir_tablegen(MeshOps.cpp.inc -gen-op-defs)
-
-add_public_tablegen_target(MLIRMeshIncGen)
-add_dependencies(mlir-headers MLIRMeshIncGen)
diff --git a/mlir/include/mlir/Dialect/Mesh/Transforms/CMakeLists.txt b/mlir/include/mlir/Dialect/Mesh/Transforms/CMakeLists.txt
deleted file mode 100644
index 8d768485103b6..0000000000000
--- a/mlir/include/mlir/Dialect/Mesh/Transforms/CMakeLists.txt
+++ /dev/null
@@ -1,6 +0,0 @@
-set(LLVM_TARGET_DEFINITIONS Passes.td)
-mlir_tablegen(Passes.h.inc -gen-pass-decls -name Mesh)
-add_public_tablegen_target(MLIRMeshPassIncGen)
-add_dependencies(mlir-headers MLIRMeshPassIncGen)
-
-add_mlir_doc(Passes MeshPasses ./ -gen-pass-doc)
diff --git a/mlir/include/mlir/Dialect/Mesh/CMakeLists.txt b/mlir/include/mlir/Dialect/Shard/CMakeLists.txt
similarity index 100%
rename from mlir/include/mlir/Dialect/Mesh/CMakeLists.txt
rename to mlir/include/mlir/Dialect/Shard/CMakeLists.txt
diff --git a/mlir/include/mlir/Dialect/Shard/IR/CMakeLists.txt b/mlir/include/mlir/Dialect/Shard/IR/CMakeLists.txt
new file mode 100644
index 0000000000000..a2495af135899
--- /dev/null
+++ b/mlir/include/mlir/Dialect/Shard/IR/CMakeLists.txt
@@ -0,0 +1,25 @@
+add_mlir_doc(ShardOps ShardOps Dialects/ -gen-op-doc -dialect=shard)
+add_mlir_doc(ShardOps ShardAttrs Dialects/ -gen-attrdef-doc -dialect=shard)
+
+set(LLVM_TARGET_DEFINITIONS ShardOps.td)
+mlir_tablegen(ShardDialect.cpp.inc -gen-dialect-defs -dialect=shard)
+mlir_tablegen(ShardDialect.h.inc -gen-dialect-decls -dialect=shard)
+
+set(LLVM_TARGET_DEFINITIONS ShardBase.td)
+mlir_tablegen(ShardAttributes.h.inc -gen-attrdef-decls)
+mlir_tablegen(ShardAttributes.cpp.inc -gen-attrdef-defs)
+
+set(LLVM_TARGET_DEFINITIONS ShardBase.td)
+mlir_tablegen(ShardEnums.h.inc -gen-enum-decls)
+mlir_tablegen(ShardEnums.cpp.inc -gen-enum-defs)
+
+set(LLVM_TARGET_DEFINITIONS ShardBase.td)
+mlir_tablegen(ShardTypes.h.inc -gen-typedef-decls)
+mlir_tablegen(ShardTypes.cpp.inc -gen-typedef-defs)
+
+set(LLVM_TARGET_DEFINITIONS ShardOps.td)
+mlir_tablegen(ShardOps.h.inc -gen-op-decls)
+mlir_tablegen(ShardOps.cpp.inc -gen-op-defs)
+
+add_public_tablegen_target(MLIRShardIncGen)
+add_dependencies(mlir-headers MLIRShardIncGen)
diff --git a/mlir/include/mlir/Dialect/Mesh/IR/MeshBase.td b/mlir/include/mlir/Dialect/Shard/IR/ShardBase.td
similarity index 64%
rename from mlir/include/mlir/Dialect/Mesh/IR/MeshBase.td
rename to mlir/include/mlir/Dialect/Shard/IR/ShardBase.td
index 61403ac178980..41ae31807c825 100644
--- a/mlir/include/mlir/Dialect/Mesh/IR/MeshBase.td
+++ b/mlir/include/mlir/Dialect/Shard/IR/ShardBase.td
@@ -1,4 +1,4 @@
-//===- MeshBase.td - Mesh Dialect --------------------------*- tablegen -*-===//
+//===- ShardBase.td - Shard Dialect ------------------------*- tablegen -*-===//
 //
 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
 // See https://llvm.org/LICENSE.txt for license information.
@@ -6,8 +6,8 @@
 //
 //===----------------------------------------------------------------------===//
 
-#ifndef MLIR_DIALECT_MESH_IR_MESHBASE_TD
-#define MLIR_DIALECT_MESH_IR_MESHBASE_TD
+#ifndef MLIR_DIALECT_SHARD_IR_SHARDBASE_TD
+#define MLIR_DIALECT_SHARD_IR_SHARDBASE_TD
 
 include "mlir/IR/OpBase.td"
 include "mlir/IR/AttrTypeBase.td"
@@ -16,15 +16,15 @@ include "mlir/IR/CommonAttrConstraints.td"
 include "mlir/IR/EnumAttr.td"
 
 //===----------------------------------------------------------------------===//
-// Mesh Dialect
+// Shard Dialect
 //===----------------------------------------------------------------------===//
 
-def Mesh_Dialect : Dialect {
-  let name = "mesh";
-  let cppNamespace = "::mlir::mesh";
+def Shard_Dialect : Dialect {
+  let name = "shard";
+  let cppNamespace = "::mlir::shard";
 
   let description = [{
-    See [Mesh dialect documentation](mlir/docs/Dialects/Mesh.md).
+    See [Shard dialect documentation](mlir/docs/Dialects/Shard.md).
   }];
 
   let dependentDialects = [
@@ -36,16 +36,16 @@ def Mesh_Dialect : Dialect {
   let hasConstantMaterializer = 1;
 }
 
-def Mesh_MeshAxis : I<16>;
-def Mesh_MeshAxesAttr : DenseArrayAttrBase<"DenseI16ArrayAttr", "int16_t", "i16">;
-def Mesh_ShardShapeAttr : DenseArrayAttrBase<"DenseI64ArrayAttr", "int64_t", "i64">;
+def Shard_GridAxis : I<16>;
+def Shard_GridAxesAttr : DenseArrayAttrBase<"DenseI16ArrayAttr", "int16_t", "i16">;
+def Shard_ShardShapeAttr : DenseArrayAttrBase<"DenseI64ArrayAttr", "int64_t", "i64">;
 
 //===----------------------------------------------------------------------===//
-// Mesh Enums.
+// Shard Enums.
 //===----------------------------------------------------------------------===//
 
-def Mesh_ReductionKind : I32EnumAttr<"ReductionKind",
-  "Reduction of an iterator/mesh dimension.", [
+def Shard_ReductionKind : I32EnumAttr<"ReductionKind",
+  "Reduction of an iterator/grid dimension.", [
   I32EnumAttrCase<"Sum", 1, "sum">,
   I32EnumAttrCase<"Max", 2, "max">,
   I32EnumAttrCase<"Min", 3, "min">,
@@ -58,31 +58,31 @@ def Mesh_ReductionKind : I32EnumAttr<"ReductionKind",
   I32EnumAttrCase<"Generic", 100, "generic">
 ]> {
   let genSpecializedAttr = 0;
-  let cppNamespace = "::mlir::mesh";
+  let cppNamespace = "::mlir::shard";
 }
 
-def Mesh_ReductionKindAttr : EnumAttr<Mesh_Dialect, Mesh_ReductionKind, "partial"> {
+def Shard_ReductionKindAttr : EnumAttr<Shard_Dialect, Shard_ReductionKind, "partial"> {
   let assemblyFormat = "$value";
 }
 
-class Mesh_Type<string name, string typeMnemonic, list<Trait> traits = [],
+class Shard_Type<string name, string typeMnemonic, list<Trait> traits = [],
                    string baseCppClass = "::mlir::Type">
-    : TypeDef<Mesh_Dialect, name, traits, baseCppClass> {
+    : TypeDef<Shard_Dialect, name, traits, baseCppClass> {
   let mnemonic = typeMnemonic;
 }
 
-def Mesh_Sharding : Mesh_Type<"Sharding", "sharding"> {
+def Shard_Sharding : Shard_Type<"Sharding", "sharding"> {
   let summary = "sharding definition";
   let assemblyFormat = "";
 }
 
 //===----------------------------------------------------------------------===//
-// Mesh Attribute
+// Shard Attribute
 //===----------------------------------------------------------------------===//
 
-def Mesh_MeshAxesArrayAttr : AttrDef<Mesh_Dialect, "MeshAxesArray"> {
+def Shard_GridAxesArrayAttr : AttrDef<Shard_Dialect, "GridAxesArray"> {
   let mnemonic = "axisarray";
-  let parameters = (ins ArrayRefParameter<"MeshAxesAttr">:$axes);
+  let parameters = (ins ArrayRefParameter<"GridAxesAttr">:$axes);
   let assemblyFormat = "`[` $axes `]`";
   let extraClassDeclaration = [{
     size_t size() const { return getAxes().size(); }
@@ -91,4 +91,4 @@ def Mesh_MeshAxesArrayAttr : AttrDef<Mesh_Dialect, "MeshAxesArray"> {
   }];
 }
 
-#endif // MLIR_DIALECT_MESH_IR_MESHBASE_TD
+#endif // MLIR_DIALECT_SHARD_IR_SHARDBASE_TD
diff --git a/mlir/include/mlir/Dialect/Mesh/IR/MeshDialect.h b/mlir/include/mlir/Dialect/Shard/IR/ShardDialect.h
similarity index 57%
rename from mlir/include/mlir/Dialect/Mesh/IR/MeshDialect.h
rename to mlir/include/mlir/Dialect/Shard/IR/ShardDialect.h
index a30cf91e851fe..4113a668d4b76 100644
--- a/mlir/include/mlir/Dialect/Mesh/IR/MeshDialect.h
+++ b/mlir/include/mlir/Dialect/Shard/IR/ShardDialect.h
@@ -1,4 +1,4 @@
-//===- MeshOps.h - Mesh Dialect ---------------------------------*- C++ -*-===//
+//===- ShardOps.h - Shard Dialect -------------------------------*- C++ -*-===//
 //
 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
 // See https://llvm.org/LICENSE.txt for license information.
@@ -6,11 +6,11 @@
 //
 //===----------------------------------------------------------------------===//
 
-#ifndef MLIR_DIALECT_MESH_IR_MESHDIALECT_H
-#define MLIR_DIALECT_MESH_IR_MESHDIALECT_H
+#ifndef MLIR_DIALECT_SHARD_IR_SHARDDIALECT_H
+#define MLIR_DIALECT_SHARD_IR_SHARDDIALECT_H
 
 #include "mlir/IR/Dialect.h"
 
-#include "mlir/Dialect/Mesh/IR/MeshDialect.h.inc"
+#include "mlir/Dialect/Shard/IR/ShardDialect.h.inc"
 
-#endif // MLIR_DIALECT_MESH_IR_MESHDIALECT_H
+#endif // MLIR_DIALECT_SHARD_IR_SHARDDIALECT_H
diff --git a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h b/mlir/include/mlir/Dialect/Shard/IR/ShardOps.h
similarity index 52%
rename from mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h
rename to mlir/include/mlir/Dialect/Shard/IR/ShardOps.h
index 7cfe59dd957ca..457fe6f6b8d0a 100644
--- a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h
+++ b/mlir/include/mlir/Dialect/Shard/IR/ShardOps.h
@@ -1,4 +1,4 @@
-//===- MeshOps.h - Mesh Dialect Operations ----------------------*- C++ -*-===//
+//===- ShardOps.h - Shard Dialect Operations --------------------*- C++ -*-===//
 //
 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
 // See https://llvm.org/LICENSE.txt for license information.
@@ -6,8 +6,8 @@
 //
 //===----------------------------------------------------------------------===//
 
-#ifndef MLIR_DIALECT_MESH_IR_MESHOPS_H
-#define MLIR_DIALECT_MESH_IR_MESHOPS_H
+#ifndef MLIR_DIALECT_SHARD_IR_SHARDOPS_H
+#define MLIR_DIALECT_SHARD_IR_SHARDOPS_H
 
 #include "mlir/Bytecode/BytecodeOpInterface.h"
 #include "mlir/Dialect/Utils/StructuredOpsUtils.h"
@@ -21,45 +21,45 @@
 #include "llvm/Support/MathExtras.h"
 
 namespace mlir {
-namespace mesh {
+namespace shard {
 
-using MeshAxis = int16_t;
-using MeshAxesAttr = DenseI16ArrayAttr;
+using GridAxis = int16_t;
+using GridAxesAttr = DenseI16ArrayAttr;
 using ShardShapeAttr = DenseI64ArrayAttr;
 using HaloSizePairAttr = DenseI64ArrayAttr;
 
-} // namespace mesh
+} // namespace shard
 } // namespace mlir
 
-#include "mlir/Dialect/Mesh/IR/MeshEnums.h.inc"
+#include "mlir/Dialect/Shard/IR/ShardEnums.h.inc"
 
 #define GET_ATTRDEF_CLASSES
-#include "mlir/Dialect/Mesh/IR/MeshAttributes.h.inc"
+#include "mlir/Dialect/Shard/IR/ShardAttributes.h.inc"
 
 namespace mlir {
-namespace mesh {
+namespace shard {
 
-class MeshSharding {
+class Sharding {
 private:
-  ::mlir::FlatSymbolRefAttr mesh;
-  SmallVector<MeshAxesAttr> split_axes;
+  ::mlir::FlatSymbolRefAttr grid;
+  SmallVector<GridAxesAttr> split_axes;
   SmallVector<int64_t> static_halo_sizes;
   SmallVector<int64_t> static_sharded_dims_offsets;
   SmallVector<Value> dynamic_halo_sizes;
   SmallVector<Value> dynamic_sharded_dims_offsets;
 
 public:
-  MeshSharding(::mlir::FlatSymbolRefAttr mesh_ = nullptr);
-  MeshSharding(Value rhs);
-  static MeshSharding get(::mlir::FlatSymbolRefAttr mesh_,
-                          ArrayRef<MeshAxesAttr> split_axes_,
-                          ArrayRef<int64_t> static_halo_sizes_ = {},
-                          ArrayRef<int64_t> static_sharded_dims_offsets_ = {},
-                          ArrayRef<Value> dynamic_halo_sizes_ = {},
-                          ArrayRef<Value> dynamic_sharded_dims_offsets_ = {});
-  ::mlir::FlatSymbolRefAttr getMeshAttr() const { return mesh; }
-  ::llvm::StringRef getMesh() const { return mesh ? mesh.getValue() : ""; }
-  ArrayRef<MeshAxesAttr> getSplitAxes() const { return split_axes; }
+  Sharding(::mlir::FlatSymbolRefAttr grid_ = nullptr);
+  Sharding(Value rhs);
+  static Sharding get(::mlir::FlatSymbolRefAttr grid_,
+                      ArrayRef<GridAxesAttr> split_axes_,
+                      ArrayRef<int64_t> static_halo_sizes_ = {},
+                      ArrayRef<int64_t> static_sharded_dims_offsets_ = {},
+                      ArrayRef<Value> dynamic_halo_sizes_ = {},
+                      ArrayRef<Value> dynamic_sharded_dims_offsets_ = {});
+  ::mlir::FlatSymbolRefAttr getGridAttr() const { return grid; }
+  ::llvm::StringRef getGrid() const { return grid ? grid.getValue() : ""; }
+  ArrayRef<GridAxesAttr> getSplitAxes() const { return split_axes; }
   ArrayRef<int64_t> getStaticHaloSizes() const { return static_halo_sizes; }
   ArrayRef<int64_t> getStaticShardedDimsOffsets() const {
     return static_sharded_dims_offsets;
@@ -68,28 +68,28 @@ class MeshSharding {
   ArrayRef<Value> getDynamicShardedDimsOffsets() const {
     return dynamic_sharded_dims_offsets;
   }
-  operator bool() const { return (!mesh) == false; }
+  operator bool() const { return (!grid) == false; }
   bool operator==(Value rhs) const;
   bool operator!=(Value rhs) const;
-  bool operator==(const MeshSharding &rhs) const;
-  bool operator!=(const MeshSharding &rhs) const;
-  bool equalSplitAxes(const MeshSharding &rhs) const;
-  bool equalHaloAndShardSizes(const MeshSharding &rhs) const;
-  bool equalHaloSizes(const MeshSharding &rhs) const;
-  bool equalShardSizes(const MeshSharding &rhs) const;
+  bool operator==(const Sharding &rhs) const;
+  bool operator!=(const Sharding &rhs) const;
+  bool equalSplitAxes(const Sharding &rhs) const;
+  bool equalHaloAndShardSizes(const Sharding &rhs) const;
+  bool equalHaloSizes(const Sharding &rhs) const;
+  bool equalShardSizes(const Sharding &rhs) const;
 };
 
-} // namespace mesh
+} // namespace shard
 } // namespace mlir
 
 #define GET_TYPEDEF_CLASSES
-#include "mlir/Dialect/Mesh/IR/MeshTypes.h.inc"
+#include "mlir/Dialect/Shard/IR/ShardTypes.h.inc"
 
 #define GET_OP_CLASSES
-#include "mlir/Dialect/Mesh/IR/MeshOps.h.inc"
+#include "mlir/Dialect/Shard/IR/ShardOps.h.inc"
 
 namespace mlir {
-namespace mesh {
+namespace shard {
 
 inline bool isReductionLoop(utils::IteratorType iType) {
   return iType == utils::IteratorType::reduction;
@@ -103,52 +103,52 @@ void removeTrailingEmptySubArray(SmallVector<SmallVector<T>> &array) {
 }
 
 // Is the same tensor replicated on all processes.
-inline bool isFullReplication(MeshSharding sharding) {
-  return llvm::all_of(sharding.getSplitAxes(), [](MeshAxesAttr axes) {
+inline bool isFullReplication(Sharding sharding) {
+  return llvm::all_of(sharding.getSplitAxes(), [](GridAxesAttr axes) {
     return axes.asArrayRef().empty();
   });
 }
 
-inline mesh::MeshOp
-getMeshOrNull(Operation *op, FlatSymbolRefAttr meshSymbol,
+inline shard::GridOp
+getGridOrNull(Operation *op, FlatSymbolRefAttr gridSymbol,
               SymbolTableCollection &symbolTableCollection) {
-  if (!meshSymbol)
+  if (!gridSymbol)
     return nullptr;
-  return symbolTableCollection.lookupNearestSymbolFrom<mesh::MeshOp>(
-      op, meshSymbol);
+  return symbolTableCollection.lookupNearestSymbolFrom<shard::GridOp>(
+      op, gridSymbol);
 }
 
-inline mesh::MeshOp getMesh(Operation *op, FlatSymbolRefAttr meshSymbol,
-                            SymbolTableCollection &symbolTableCollection) {
-  mesh::MeshOp meshOp = getMeshOrNull(op, meshSymbol, symbolTableCollection);
-  assert(meshOp);
-  return meshOp;
+inline shard::GridOp getGrid(Operation *op, FlatSymbolRefAttr gridSymbol,
+                             SymbolTableCollection &symbolTableCollection) {
+  shard::GridOp gridOp = getGridOrNull(op, gridSymbol, symbolTableCollection);
+  assert(gridOp);
+  return gridOp;
 }
 
-// Get the corresponding mesh op using the standard attribute nomenclature.
+// Get the corresponding grid op using the standard attribute nomenclature.
 template <typename Op>
-mesh::MeshOp getMesh(Op op, SymbolTableCollection &symbolTableCollection) {
-  return getMesh(op.getOperation(), op.getMeshAttr(), symbolTableCollection);
+shard::GridOp getGrid(Op op, SymbolTableCollection &symbolTableCollection) {
+  return getGrid(op.getOperation(), op.getGridAttr(), symbolTableCollection);
 }
 
 template <>
-inline mesh::MeshOp
-getMesh<ShardOp>(ShardOp op, SymbolTableCollection &symbolTableCollection) {
-  return getMesh(
+inline shard::GridOp
+getGrid<ShardOp>(ShardOp op, SymbolTableCollection &symbolTableCollection) {
+  return getGrid(
       op.getOperation(),
-      cast<ShardingOp>(op.getSharding().getDefiningOp()).getMeshAttr(),
+      cast<ShardingOp>(op.getSharding().getDefiningOp()).getGridAttr(),
       symbolTableCollection);
 }
 
 // Get the number of processes that participate in each group
-// induced by `meshAxes`.
-template <typename MeshAxesRange, typename MeshShapeRange>
-int64_t collectiveProcessGroupSize(MeshAxesRange &&meshAxes,
-                                   MeshShapeRange &&meshShape) {
+// induced by `gridAxes`.
+template <typename GridAxesRange, typename GridShapeRange>
+int64_t collectiveProcessGroupSize(GridAxesRange &&gridAxes,
+                                   GridShapeRange &&gridShape) {
   int64_t res = 1;
 
-  for (MeshAxis axis : meshAxes) {
-    auto axisSize = *(std::begin(meshShape) + axis);
+  for (GridAxis axis : gridAxes) {
+    auto axisSize = *(std::begin(gridShape) + axis);
     if (ShapedType::isDynamic(axisSize)) {
       return ShapedType::kDynamic;
     }
@@ -158,10 +158,10 @@ int64_t collectiveProcessGroupSize(MeshAxesRange &&meshAxes,
   return res;
 }
 
-template <typename MeshAxesRange>
-int64_t collectiveProcessGroupSize(MeshAxesRange &&meshAxes, MeshOp mesh) {
-  return collectiveProcessGroupSize(std::forward<MeshAxesRange>(meshAxes),
-                                    mesh.getShape());
+template <typename GridAxesRange>
+int64_t collectiveProcessGroupSize(GridAxesRange &&gridAxes, GridOp grid) {
+  return collectiveProcessGroupSize(std::forward<GridAxesRange>(gridAxes),
+                                    grid.getShape());
 }
 
 // Get the size of a sharded dimension.
@@ -182,27 +182,25 @@ inline int64_t gatherDimension(int64_t dimSize, int64_t shardCount) {
 }
 
 // Return the sharded shape `shape` according ot sharding `sharding`.
-// The shape for the tensor on each device in the mesh.
+// The shape for the tensor on each device in the grid.
 // Example:
-// On a 2x4x? mesh with split axes = [[0], [1], [2]] the shape ?x5x1 would
+// On a 2x4x? grid with split axes = [[0], [1], [2]] the shape ?x5x1 would
 // result in a shape for each shard of ?x2x?.
-ShapedType shardShapedType(ShapedType shape, MeshOp mesh,
-                           MeshSharding sharding);
+ShapedType shardShapedType(ShapedType shape, GridOp grid, Sharding sharding);
 
 // If ranked tensor type return its sharded counterpart.
 //
 // If not ranked tensor type return `type`.
 // `sharding` in that case must be null.
-Type shardType(Type type, MeshOp mesh, MeshSharding sharding);
+Type shardType(Type type, GridOp grid, Sharding sharding);
 
 // Insert shard op if there is not one that already has the same sharding.
 // Use newShardOp if it is not null. Otherwise create a new one.
 // May insert resharding if required.
 // Potentially updates newShardOp.
-void maybeInsertTargetShardingAnnotation(MeshSharding sharding, OpResult result,
+void maybeInsertTargetShardingAnnotation(Sharding sharding, OpResult result,
                                          OpBuilder &builder);
-void maybeInsertSourceShardingAnnotation(MeshSharding sharding,
-                                         OpOperand &operand,
+void maybeInsertSourceShardingAnnotation(Sharding sharding, OpOperand &operand,
                                          OpBuilder &builder);
 
 /// Converts a vector of OpFoldResults (ints) into vector of Values of the
@@ -210,7 +208,7 @@ void maybeInsertSourceShardingAnnotation(MeshSharding sharding,
 SmallVector<Value> getMixedAsValues(OpBuilder b, const Location &loc,
                                     llvm::ArrayRef<int64_t> statics,
                                     ValueRange dynamics, Type type = Type());
-} // namespace mesh
+} // namespace shard
 } // namespace mlir
 
-#endif // MLIR_DIALECT_MESH_IR_MESHOPS_H
+#endif // MLIR_DIALECT_SHARD_IR_SHARDOPS_H
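
As a reader aid (not in the patch): the `shardShapedType` example from the
comment above corresponds to IR along these lines, using the `shard.grid` and
`shard.sharding` syntax documented elsewhere in this change; the grid name and
SSA values are placeholders.

```mlir
shard.grid @grid(shape = 2x4x?)
// Split axes [[0], [1], [2]]: tensor dimension i is split along grid axis i.
%sharding = shard.sharding @grid split_axes = [[0], [1], [2]] : !shard.sharding
// For this sharding, shardShapedType maps the global type tensor<?x5x1xf32>
// to the per-device shard type tensor<?x2x?xf32>.
```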
diff --git a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td b/mlir/include/mlir/Dialect/Shard/IR/ShardOps.td
similarity index 70%
rename from mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td
rename to mlir/include/mlir/Dialect/Shard/IR/ShardOps.td
index 1662885c161e6..29b384f401876 100644
--- a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td
+++ b/mlir/include/mlir/Dialect/Shard/IR/ShardOps.td
@@ -1,4 +1,4 @@
-//===-- MeshOps.td - Mesh dialect operation definitions ----*- tablegen -*-===//
+//===-- ShardOps.td - Shard dialect operation definitions --*- tablegen -*-===//
 //
 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
 // See https://llvm.org/LICENSE.txt for license information.
@@ -6,10 +6,10 @@
 //
 //===----------------------------------------------------------------------===//
 
-#ifndef MLIR_DIALECT_MESH_IR_MESHOPS_TD
-#define MLIR_DIALECT_MESH_IR_MESHOPS_TD
+#ifndef MLIR_DIALECT_SHARD_IR_SHARDOPS_TD
+#define MLIR_DIALECT_SHARD_IR_SHARDOPS_TD
 
-include "mlir/Dialect/Mesh/IR/MeshBase.td"
+include "mlir/Dialect/Shard/IR/ShardBase.td"
 include "mlir/Dialect/Shape/IR/ShapeBase.td"
 include "mlir/Interfaces/DestinationStyleOpInterface.td"
 include "mlir/Interfaces/InferTypeOpInterface.td"
@@ -21,24 +21,24 @@ include "mlir/IR/OpAsmInterface.td"
 include "mlir/IR/SymbolInterfaces.td"
 
 //===----------------------------------------------------------------------===//
-// Mesh operations.
+// Shard operations.
 //===----------------------------------------------------------------------===//
 
-class Mesh_Op<string mnemonic, list<Trait> traits = []> :
-    Op<Mesh_Dialect, mnemonic, traits> {
+class Shard_Op<string mnemonic, list<Trait> traits = []> :
+    Op<Shard_Dialect, mnemonic, traits> {
 }
 
-def Mesh_MeshOp : Mesh_Op<"mesh", [Symbol, Pure]> {
-  let summary = "Description of a device/process mesh.";
+def Shard_GridOp : Shard_Op<"grid", [Symbol, Pure]> {
+  let summary = "Description of a device/process grid.";
   let description = [{
-    The mesh.mesh operation is a symbol operation that identifies a specific
-    mesh. The operation has three attributes:
+    The shard.grid operation is a symbol operation that identifies a specific
+    grid. The operation has three attributes:
 
-    1. `sym_name`: This attribute uniquely identifies the name of the mesh.
-    This name serves as a symbolic reference to the mesh throughout
+    1. `sym_name`: This attribute uniquely identifies the name of the grid.
+    This name serves as a symbolic reference to the grid throughout
     the MLIR module, allowing for consistent referencing and easier debugging.
 
-    2. `shape`: This attribute represents the shape of the device mesh.
+    2. `shape`: This attribute represents the shape of the device grid.
     It uses the same notation as a tensor shape. Also allowing for dynamic
     dimensions.
     This flexibility allows for dynamic device assignment or configurations
@@ -48,21 +48,21 @@ def Mesh_MeshOp : Mesh_Op<"mesh", [Symbol, Pure]> {
 
     Example:
     ```
-    // A device mesh with 3 axes, the total device number is 4 * 8 * 12
+    // A device grid with 3 axes, the total device number is 4 * 8 * 12
     // The dimension sizes are 4, 8, 12 
-    mesh.mesh @mesh0(shape = 4x8x12)
+    shard.grid @grid0(shape = 4x8x12)
 
-    // A device mesh with 2 axes, the total device number is unknown
+    // A device grid with 2 axes, the total device number is unknown
     // The first dimension size is 4 and the second is unknown
-    mesh.mesh @mesh1(shape = 4x?)
+    shard.grid @grid1(shape = 4x?)
 
-    // A device mesh with 2 axes, the total device number is unknown
+    // A device grid with 2 axes, the total device number is unknown
     // The first dimension size is unknown and the second is 4
-    mesh.mesh @mesh2(shape = ?x4)
+    shard.grid @grid2(shape = ?x4)
 
-    // A device mesh with 2 axes, the number of devices along both axes
+    // A device grid with 2 axes, the number of devices along both axes
     // is unknown
-    mesh.mesh @mesh3(shape = ?x?)
+    shard.grid @grid3(shape = ?x?)
     ```
   }];
   let arguments = (ins
@@ -79,15 +79,15 @@ def Mesh_MeshOp : Mesh_Op<"mesh", [Symbol, Pure]> {
   let hasVerifier = 1;
 }
 
-def Mesh_MeshShapeOp : Mesh_Op<"mesh_shape", [
+def Shard_GridShapeOp : Shard_Op<"grid_shape", [
     Pure,
     DeclareOpInterfaceMethods<SymbolUserOpInterface>,
     DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>
   ]> {
-  let summary = "Get the shape of the mesh.";
+  let summary = "Get the shape of the grid.";
   let arguments = (ins
-    FlatSymbolRefAttr:$mesh,
-    DefaultValuedAttr<Mesh_MeshAxesAttr, "{}">:$axes
+    FlatSymbolRefAttr:$grid,
+    DefaultValuedAttr<Shard_GridAxesAttr, "{}">:$axes
   );
 
   let results = (outs
@@ -95,46 +95,46 @@ def Mesh_MeshShapeOp : Mesh_Op<"mesh_shape", [
   );
 
   let assemblyFormat = [{
-    $mesh (`axes` `=` $axes^)?
+    $grid (`axes` `=` $axes^)?
     attr-dict `:` type($result)
   }];
 
   let builders = [
-    OpBuilder<(ins "::mlir::mesh::MeshOp":$mesh)>,
-    OpBuilder<(ins "::mlir::mesh::MeshOp":$mesh, "ArrayRef<MeshAxis>":$axes)>,
-    OpBuilder<(ins "StringRef":$mesh, "ArrayRef<MeshAxis>":$axes)>
+    OpBuilder<(ins "::mlir::shard::GridOp":$grid)>,
+    OpBuilder<(ins "::mlir::shard::GridOp":$grid, "ArrayRef<GridAxis>":$axes)>,
+    OpBuilder<(ins "StringRef":$grid, "ArrayRef<GridAxis>":$axes)>
   ];
 }
 
-def Mesh_ProcessMultiIndexOp : Mesh_Op<"process_multi_index", [
+def Shard_ProcessMultiIndexOp : Shard_Op<"process_multi_index", [
   Pure,
   DeclareOpInterfaceMethods<SymbolUserOpInterface>,
   DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>
 ]> {
-  let summary = "Get the multi index of current device along specified mesh axes.";
+  let summary = "Get the multi index of current device along specified grid axes.";
   let description = [{
     It is used in the SPMD format of IR.
-    The `axes` mush be non-negative and less than the total number of mesh axes.
+    The `axes` must be non-negative and less than the total number of grid axes.
     If the axes are empty then get the index along all axes.
   }];
   let arguments = (ins
-    FlatSymbolRefAttr:$mesh,
-    DefaultValuedAttr<Mesh_MeshAxesAttr, "{}">:$axes
+    FlatSymbolRefAttr:$grid,
+    DefaultValuedAttr<Shard_GridAxesAttr, "{}">:$axes
   );
   let results = (outs
     Variadic<Index>:$result
   );
   let assemblyFormat = [{
-    `on` $mesh (`axes` `=` $axes^)?
+    `on` $grid (`axes` `=` $axes^)?
     attr-dict `:` type($result)
   }];
   let builders = [
-    OpBuilder<(ins "::mlir::mesh::MeshOp":$mesh)>,
-    OpBuilder<(ins "StringRef":$mesh, "ArrayRef<MeshAxis>":$axes)>
+    OpBuilder<(ins "::mlir::shard::GridOp":$grid)>,
+    OpBuilder<(ins "StringRef":$grid, "ArrayRef<GridAxis>":$axes)>
   ];
 }
 
-def Mesh_ProcessLinearIndexOp : Mesh_Op<"process_linear_index", [
+def Shard_ProcessLinearIndexOp : Shard_Op<"process_linear_index", [
   Pure,
   DeclareOpInterfaceMethods<SymbolUserOpInterface>,
   DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>
@@ -143,34 +143,34 @@ def Mesh_ProcessLinearIndexOp : Mesh_Op<"process_linear_index", [
   let description = [{
     Example:
     ```
-    %idx = mesh.process_linear_index on @mesh : index
+    %idx = shard.process_linear_index on @grid : index
     ```
-    if `@mesh` has shape `(10, 20, 30)`, a device with multi
+    if `@grid` has shape `(10, 20, 30)`, a device with multi
     index `(1, 2, 3)` will have linear index `3 + 30*2 + 20*30*1`.
   }];
-  let arguments = (ins FlatSymbolRefAttr:$mesh);
+  let arguments = (ins FlatSymbolRefAttr:$grid);
   let results = (outs Index:$result);
-  let assemblyFormat = "`on` $mesh attr-dict `:` type($result)";
+  let assemblyFormat = "`on` $grid attr-dict `:` type($result)";
   let builders = [
-    OpBuilder<(ins "::mlir::mesh::MeshOp":$mesh)>
+    OpBuilder<(ins "::mlir::shard::GridOp":$grid)>
   ];
 }
 
-def Mesh_NeighborsLinearIndicesOp : Mesh_Op<"neighbors_linear_indices", [
+def Shard_NeighborsLinearIndicesOp : Shard_Op<"neighbors_linear_indices", [
   Pure,
   DeclareOpInterfaceMethods<SymbolUserOpInterface>,
   DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>
 ]> {
   let summary =
-      "For given mesh index get the linear indices of the direct neighbor processes along the given split.";
+      "For given grid index get the linear indices of the direct neighbor processes along the given split.";
   let description = [{
     Example:
     ```
-    mesh.mesh @mesh0(shape = 10x20x30)
+    shard.grid @grid0(shape = 10x20x30)
     %c1 = arith.constant 1 : index
     %c2 = arith.constant 2 : index
     %c3 = arith.constant 3 : index
-    %idx = mesh.neighbors_linear_indices on @mesh[%c1, %c2, %c3] split_axes = [1] : index
+    %idx = shard.neighbors_linear_indices on @grid[%c1, %c2, %c3] split_axes = [1] : index
     ```
     The above returns two indices, `633` and `693`, which correspond to the
     index of the previous process `(1, 1, 3)`, and the next process 
@@ -179,12 +179,12 @@ def Mesh_NeighborsLinearIndicesOp : Mesh_Op<"neighbors_linear_indices", [
     A negative value is returned if there is no neighbor in the respective
     direction along the given `split_axes`.
   }];
-  let arguments = (ins FlatSymbolRefAttr:$mesh,
+  let arguments = (ins FlatSymbolRefAttr:$grid,
                        Variadic<Index>:$device,
-                       Mesh_MeshAxesAttr:$split_axes);
+                       Shard_GridAxesAttr:$split_axes);
   let results = (outs Index:$neighbor_down, Index:$neighbor_up);
   let assemblyFormat =  [{
-      `on` $mesh `[` $device `]`
+      `on` $grid `[` $device `]`
       `split_axes` `=` $split_axes
       attr-dict `:` type(results)
   }];
@@ -194,7 +194,7 @@ def Mesh_NeighborsLinearIndicesOp : Mesh_Op<"neighbors_linear_indices", [
 // Sharding operations.
 //===----------------------------------------------------------------------===//
 
-def Mesh_ShardingOp : Mesh_Op<"sharding", [
+def Shard_ShardingOp : Shard_Op<"sharding", [
     Pure,
     AttrSizedOperandSegments,
     DeclareOpInterfaceMethods<SymbolUserOpInterface>,
@@ -202,18 +202,18 @@ def Mesh_ShardingOp : Mesh_Op<"sharding", [
   ]> {
   let summary = "Define a sharding of a tensor.";
   let description = [{
-    The MeshSharding specifies how a tensor is sharded and distributed across the
-    process mesh. It is typically used in a `mesh.shard` operation.
+    The Sharding specifies how a tensor is sharded and distributed across the
+    process grid. It is typically used in a `shard.shard` operation.
     The operation has the following attributes and operands:
 
-    1. `mesh`: this attribute is a FlatSymbolRefAttr that refers to the device
-    mesh where the distributed tensor is placed. The symbol must resolve to a
-    `mesh.mesh` operation.
+    1. `grid`: this attribute is a FlatSymbolRefAttr that refers to the device
+    grid where the distributed tensor is placed. The symbol must resolve to a
+    `shard.grid` operation.
 
     2. `split_axes`: is an array composed of int64_t sub-arrays. The outer array's
     maximum size is the `rank` of the related tensor. For the i-th sub-array, if
     its value is [x, y], it indicates that the tensor's i-th dimension is splitted
-    along the x and y axes of the device mesh.
+    along the x and y axes of the device grid.
 
     3. [Optional] Sizes of halos to be added for each sharded tensor dimension.
     `halo_sizes` is provided as a flattened 1d array of i64s, 2 values for each
@@ -233,7 +233,7 @@ def Mesh_ShardingOp : Mesh_Op<"sharding", [
     
     Assuming a 3d-tensor of shape 32x32x32 with the first 2 dimensions being sharded,
     `sharded_dims_offsets` = [0, 24, 32, 0, 20, 32] means that the first device of
-    the device-mesh will get a shard of shape 24x20x32 and the second device will get
+    the device-grid will get a shard of shape 24x20x32 and the second device will get
     a shard of shape 8x12x32. `?` indicates dynamic shard dimensions.
     
     `halo_sizes` and `sharded_dims_offsets` are mutually exclusive.
@@ -241,101 +241,101 @@ def Mesh_ShardingOp : Mesh_Op<"sharding", [
     Examples:
 
     ```
-    mesh.mesh @mesh0(shape = 2x2x4)
-    mesh.mesh @mesh1d_4(shape = 4)
+    shard.grid @grid0(shape = 2x2x4)
+    shard.grid @grid1d_4(shape = 4)
 
-    // The tensor is fully replicated on @mesh0.
+    // The tensor is fully replicated on @grid0.
     // Currently, there must be at least one sub-array present in axes, even
     // if it's empty. Otherwise, a parsing error will occur.
-    %sharding0 = mesh.sharding @mesh0 split_axes = [[]]
+    %sharding0 = shard.sharding @grid0 split_axes = [[]]
 
-    // The tensor is sharded on the first dimension along axis 0 of @mesh0
-    %sharding1 = mesh.sharding @mesh0 split_axes = [[0]]
+    // The tensor is sharded on the first dimension along axis 0 of @grid0
+    %sharding1 = shard.sharding @grid0 split_axes = [[0]]
 
-    // Could be used for a mesh.shard op
-    %sharded0 = mesh.shard %arg0 to %sharding3 : tensor<4x8xf32>
+    // Could be used for a shard.shard op
+    %sharded0 = shard.shard %arg0 to %sharding3 : tensor<4x8xf32>
 
-    // The tensor is sharded on its first dimension along axis 0 of @mesh0 and
+    // The tensor is sharded on its first dimension along axis 0 of @grid0
     // and it has halo-sizes of 1 and 2 on the sharded dim.
-    %halo_sharding = mesh.sharding @mesh0 split_axes = [[0]] halo_sizes = [1, 2]
-    %sharded1 = mesh.shard %arg0 to %halo_sharding : tensor<4x8xf32>
+    %halo_sharding = shard.sharding @grid0 split_axes = [[0]] halo_sizes = [1, 2]
+    %sharded1 = shard.shard %arg0 to %halo_sharding : tensor<4x8xf32>
     
-    // The tensor is sharded on its second dimension along axis 0 of @mesh1d_4
+    // The tensor is sharded on its second dimension along axis 0 of @grid1d_4
     // and it has pre-defined shard sizes. The shards of the devices will have
     // the following shapes: [4x2, 4x3, 4x4, 4x5]
-    %sharding4 = mesh.sharding @mesh1d_4 split_axes = [[], [0]] sharded_dims_offsets = [0, 2, 5, 9, 14]
-    %sharded2 = mesh.shard %arg0 to %sharding4 : tensor<4x14xf32>
+    %sharding4 = shard.sharding @grid1d_4 split_axes = [[], [0]] sharded_dims_offsets = [0, 2, 5, 9, 14]
+    %sharded2 = shard.shard %arg0 to %sharding4 : tensor<4x14xf32>
     ```
   }];
     
   let arguments = (ins
-    FlatSymbolRefAttr:$mesh,
-    Mesh_MeshAxesArrayAttr:$split_axes,
+    FlatSymbolRefAttr:$grid,
+    Shard_GridAxesArrayAttr:$split_axes,
     DefaultValuedAttr<DenseI64ArrayAttr, "{}">:$static_sharded_dims_offsets,
     Variadic<I64>:$dynamic_sharded_dims_offsets,
     DefaultValuedAttr<DenseI64ArrayAttr, "{}">:$static_halo_sizes,
     Variadic<I64>:$dynamic_halo_sizes
   );
   let results = (outs
-    Mesh_Sharding:$result
+    Shard_Sharding:$result
   );
   let assemblyFormat = [{
-    $mesh
+    $grid
     `split_axes` `=` $split_axes
     (`halo_sizes` `=` custom<DynamicIndexList>($dynamic_halo_sizes, $static_halo_sizes)^)?
     (`sharded_dims_offsets` `=` custom<DynamicIndexList>($dynamic_sharded_dims_offsets, $static_sharded_dims_offsets)^)?
     attr-dict `:` type($result)
   }];
   let builders = [
-    OpBuilder<(ins "FlatSymbolRefAttr":$mesh,
-                   "ArrayRef<MeshAxesAttr>":$split_axes,
+    OpBuilder<(ins "FlatSymbolRefAttr":$grid,
+                   "ArrayRef<GridAxesAttr>":$split_axes,
                    CArg<"ArrayRef<int64_t>", "{}">:$static_halo_sizes,
                    CArg<"ArrayRef<int64_t>", "{}">:$static_sharded_dims_offsets)>,
-    OpBuilder<(ins "FlatSymbolRefAttr":$mesh,
-                   "ArrayRef<MeshAxesAttr>":$split_axes,
+    OpBuilder<(ins "FlatSymbolRefAttr":$grid,
+                   "ArrayRef<GridAxesAttr>":$split_axes,
                    "::mlir::ArrayRef<::mlir::OpFoldResult>":$halo_sizes,
                    "::mlir::ArrayRef<::mlir::OpFoldResult>":$sharded_dims_offsets)>,
-    OpBuilder<(ins "llvm::StringRef":$mesh,
-                   "ArrayRef<MeshAxesAttr>":$split_axes,
+    OpBuilder<(ins "llvm::StringRef":$grid,
+                   "ArrayRef<GridAxesAttr>":$split_axes,
                    CArg<"ArrayRef<int64_t>", "{}">:$static_halo_sizes,
                    CArg<"ArrayRef<int64_t>", "{}">:$static_sharded_dims_offsets
     )>,
-    OpBuilder<(ins "mlir::mesh::MeshSharding":$from)>
+    OpBuilder<(ins "mlir::shard::Sharding":$from)>
   ];
   let hasVerifier = 1;
   let hasCanonicalizer = 1;
 }
 
-def Mesh_GetShardingOp : Mesh_Op<"get_sharding", [Pure]> {
+def Shard_GetShardingOp : Shard_Op<"get_sharding", [Pure]> {
   let summary = "Get the sharding of the given tensor.";
   let description = [{
-    This operation returns the sharding of the given tensor as a MeshSharding.
+    This operation returns the sharding of the given tensor as a Sharding.
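+
+    A minimal usage sketch (the value name and tensor type are chosen for
+    illustration only):
+    ```mlir
+    %sharding = shard.get_sharding %sharded : tensor<4x8xf32> -> !shard.sharding
+    ```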
   }];
   let arguments = (ins
     AnyRankedTensor:$source
   );
   let results = (outs
-    Mesh_Sharding:$result
+    Shard_Sharding:$result
   );
   let assemblyFormat = [{
     $source attr-dict `:` type($source) `->` type($result)
   }];
 }
 
-def Mesh_ShardShapeOp : Mesh_Op<"shard_shape", [
+def Shard_ShardShapeOp : Shard_Op<"shard_shape", [
     Pure, AttrSizedOperandSegments,
     DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>
   ]> {
   let summary = "Get the shard shape for a given process/device.";
   let description = [{
-    The device/process id is a multi-index of the device/process in the mesh.
-    This operation might be used during spmdization when the shard shape depends
-    on (non-constant) values used in `mesh.sharding`.
+    The device/process id is a multi-index of the device/process in the grid.
+    This operation might be used during partitioning when the shard shape
+    depends on (non-constant) values used in `shard.sharding`.
   }];
   let arguments = (ins
     DenseI64ArrayAttr:$dims,
     Variadic<Index>:$dims_dynamic,
-    Mesh_Sharding:$sharding,
+    Shard_Sharding:$sharding,
     DenseI64ArrayAttr:$device,
     Variadic<Index>:$device_dynamic
   );
@@ -351,23 +351,23 @@ def Mesh_ShardShapeOp : Mesh_Op<"shard_shape", [
   ];
 }
 
-def Mesh_ShardOp : Mesh_Op<"shard", [
+def Shard_ShardOp : Shard_Op<"shard", [
     Pure,
     AllTypesMatch<["result", "src"]>,
     DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>
   ]> {
-  let summary = "Annotate on how a tensor is sharded across a mesh.";
+  let summary = "Annotate how a tensor is sharded across a device grid.";
   let description = [{
-    The mesh.shard operation is designed to specify and guide the sharding
-    behavior of a tensor value across a mesh topology. This operation has two
+    The shard.shard operation is designed to specify and guide the sharding
+    behavior of a tensor value across a grid topology. This operation has two
    operands and one optional attribute:
 
     1. `input`: This operand represents the tensor value that needs to be
     annotated for sharding.
 
-    2. `sharding`: This attribute is type of `MeshShardingType`, which is the core data
-    structure to represent distribution of a tensor on a mesh. it is typically defiend
-    by an `mesh.sharding` operation.
+    2. `sharding`: This operand is of type `ShardingType`, which is the core data
+    structure to represent the distribution of a tensor on a grid. It is typically
+    defined by a `shard.sharding` operation.
 
     3. `annotate_for_users`: A unit attribute addressing the scenario when a
     tensor's sharding annotation differs based on its context of use (either as
@@ -378,36 +378,36 @@ def Mesh_ShardOp : Mesh_Op<"shard", [
 
     Example:
     ```
-    func.func @only_result_annotated(%arg0 : tensor<4x8xf32>) -> () {
-      %sharding = mesh.sharding @mesh0 split_axes = [[0]] : !mesh.sharding
-      %0 = mesh.shard %arg0 to %sharding : tensor<4x8xf32>
+    func.func @only_result_annotated(%arg0 : tensor<4x8xf32>) -> () {
+      %sharding = shard.sharding @grid0 split_axes = [[0]] : !shard.sharding
+      %0 = shard.shard %arg0 to %sharding : tensor<4x8xf32>
       ...
     }
 
     func.func @only_operand_annotated(%arg0 : tensor<4x8xf32>) -> () {
-      %sharding = mesh.sharding @mesh0 split_axes = [[0]] : !mesh.sharding
-      %0 = mesh.shard %arg0 to %sharding annotate_for_users : tensor<4x8xf32>
+      %sharding = shard.sharding @grid0 split_axes = [[0]] : !shard.sharding
+      %0 = shard.shard %arg0 to %sharding annotate_for_users : tensor<4x8xf32>
       ...
     }
     
     func.func @two_operands_annotated(%arg0 : tensor<4x8xf32>, %arg1 : tensor<16x8xf32>) -> () {
-      %sharding = mesh.sharding @mesh0 split_axes = [[0]] : !mesh.sharding
-      %0 = mesh.shard %arg0 to %sharding annotate_for_users : tensor<4x8xf32>
-      %1 = mesh.shard %arg1 to %sharding annotate_for_users : tensor<16x8xf32>
+      %sharding = shard.sharding @grid0 split_axes = [[0]] : !shard.sharding
+      %0 = shard.shard %arg0 to %sharding annotate_for_users : tensor<4x8xf32>
+      %1 = shard.shard %arg1 to %sharding annotate_for_users : tensor<16x8xf32>
       ...
     }
 
-    // The first mesh.shard op applies to %arg0, the second mesh.shard op
-    // applies for the operand of op0, the third mesh.shard op applies for the
+    // The first shard.shard op applies to %arg0, the second shard.shard op
+    // applies for the operand of op0, the third shard.shard op applies for the
     // operand of op2
     func.func @both_result_and_multi_operands_annotated(
         %arg0 : tensor<4x8xf32>) -> () {
-      %sharding = mesh.sharding @mesh0 split_axes = [[0]] : !mesh.sharding
-      %0 = mesh.shard %arg0 to %sharding : tensor<4x8xf32>
-      %sharding1 = mesh.sharding @mesh0 split_axes = [[1]] : !mesh.sharding
-      %1 = mesh.shard %0 to %sharding1 annotate_for_users : tensor<4x8xf32>
-      %sharding2 = mesh.sharding @mesh0 split_axes = [[2]] : !mesh.sharding
-      %2 = mesh.shard %0 to %sharding2 annotate_for_users : tensor<4x8xf32>
+      %sharding = shard.sharding @grid0 split_axes = [[0]] : !shard.sharding
+      %0 = shard.shard %arg0 to %sharding : tensor<4x8xf32>
+      %sharding1 = shard.sharding @grid0 split_axes = [[1]] : !shard.sharding
+      %1 = shard.shard %0 to %sharding1 annotate_for_users : tensor<4x8xf32>
+      %sharding2 = shard.sharding @grid0 split_axes = [[2]] : !shard.sharding
+      %2 = shard.shard %0 to %sharding2 annotate_for_users : tensor<4x8xf32>
       "op0"(%1) : ...
       "op1"(%2) : ...
       ...
@@ -418,44 +418,44 @@ def Mesh_ShardOp : Mesh_Op<"shard", [
     ```
     func.func @annotate_on_same_result_with_different_sharding(
         %arg0 : tensor<4x8xf32>) -> () {
-      %sharding1 = mesh.sharding @mesh0 split_axes = [[0]] : !mesh.sharding
-      %sharding2 = mesh.sharding @mesh0 split_axes = [[1]] : !mesh.sharding
-      %0 = mesh.shard %arg0 to $sharding1 : tensor<4x8xf32>
-      %1 = mesh.shard %0 to sharding2 : tensor<4x8xf32>
+      %sharding1 = shard.sharding @grid0 split_axes = [[0]] : !shard.sharding
+      %sharding2 = shard.sharding @grid0 split_axes = [[1]] : !shard.sharding
+      %0 = shard.shard %arg0 to %sharding1 : tensor<4x8xf32>
+      %1 = shard.shard %0 to %sharding2 : tensor<4x8xf32>
       ...
     }
 
     func.func @annotate_on_same_result_same_value_with_different_sharding(
         %arg0 : tensor<4x8xf32>) -> () {
-      %sharding1 = mesh.sharding @mesh0 split_axes = [[0]] : !mesh.sharding
-      %sharding2 = mesh.sharding @mesh0 split_axes = [[1]] : !mesh.sharding
-      %0 = mesh.shard %arg0 to %sharding1 : tensor<4x8xf32>
-      %1 = mesh.shard %arg0 to %sharding2 : tensor<4x8xf32>
+      %sharding1 = shard.sharding @grid0 split_axes = [[0]] : !shard.sharding
+      %sharding2 = shard.sharding @grid0 split_axes = [[1]] : !shard.sharding
+      %0 = shard.shard %arg0 to %sharding1 : tensor<4x8xf32>
+      %1 = shard.shard %arg0 to %sharding2 : tensor<4x8xf32>
       ...
     }
 
     func.func @annotate_on_same_operand_with_different_sharding(
         %arg0 : tensor<4x8xf32>) -> () {
-      %sharding1 = mesh.sharding @mesh0 split_axes = [[0]] : !mesh.sharding
-      %sharding2 = mesh.sharding @mesh0 split_axes = [[1]] : !mesh.sharding
-      %0 = mesh.shard %arg0 to %sharding1 annotate_for_users : tensor<4x8xf32>
-      %1 = mesh.shard %0 to %sharding2 annotate_for_users : tensor<4x8xf32>
+      %sharding1 = shard.sharding @grid0 split_axes = [[0]] : !shard.sharding
+      %sharding2 = shard.sharding @grid0 split_axes = [[1]] : !shard.sharding
+      %0 = shard.shard %arg0 to %sharding1 annotate_for_users : tensor<4x8xf32>
+      %1 = shard.shard %0 to %sharding2 annotate_for_users : tensor<4x8xf32>
       ...
     }
 
     func.func @result_annotated_after_operand(
         %arg0 : tensor<4x8xf32>) -> () {
-      %sharding1 = mesh.sharding @mesh0 split_axes = [[0]] : !mesh.sharding
-      %sharding2 = mesh.sharding @mesh0 split_axes = [[1]] : !mesh.sharding
-      %0 = mesh.shard %arg0 to %sharding1 annotate_for_users : tensor<4x8xf32>
-      %1 = mesh.shard %0 to %sharding2 : tensor<4x8xf32>
+      %sharding1 = shard.sharding @grid0 split_axes = [[0]] : !shard.sharding
+      %sharding2 = shard.sharding @grid0 split_axes = [[1]] : !shard.sharding
+      %0 = shard.shard %arg0 to %sharding1 annotate_for_users : tensor<4x8xf32>
+      %1 = shard.shard %0 to %sharding2 : tensor<4x8xf32>
       ...
     }
     ```
   }];
   let arguments = (ins
     AnyRankedTensor:$src,
-    Mesh_Sharding:$sharding,
+    Shard_Sharding:$sharding,
     UnitAttr:$annotate_for_users
   );
   let results = (outs
@@ -473,34 +473,34 @@ def Mesh_ShardOp : Mesh_Op<"shard", [
 // collective communication ops
 //===----------------------------------------------------------------------===//
 
-class Mesh_CollectiveCommunicationOpBase<
+class Shard_CollectiveCommunicationOpBase<
     string mnemonic, list<Trait> traits = []> :
-    Mesh_Op<mnemonic,
+    Shard_Op<mnemonic,
       !listconcat(traits,
         [
           DeclareOpInterfaceMethods<SymbolUserOpInterface>,
           DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>
         ])> {
   dag commonArgs = (ins
-    FlatSymbolRefAttr:$mesh,
-    DefaultValuedAttr<Mesh_MeshAxesAttr, "{}">:$mesh_axes
+    FlatSymbolRefAttr:$grid,
+    DefaultValuedAttr<Shard_GridAxesAttr, "{}">:$grid_axes
   );
 }
 
-def Mesh_AllGatherOp : Mesh_CollectiveCommunicationOpBase<"all_gather", [
+def Shard_AllGatherOp : Shard_CollectiveCommunicationOpBase<"all_gather", [
     Pure,
     SameOperandsAndResultElementType,
     SameOperandsAndResultRank,
   ]> {
-  let summary = "All-gather over a device mesh.";
+  let summary = "All-gather over a device grid.";
   let description = [{
     Gathers along the `gather_axis` tensor axis.
 
     Example:
     ```mlir
-    mesh.mesh @mesh0(shape = 2x2)
+    shard.grid @grid0(shape = 2x2)
     ...
-    %1 = mesh.all_gather %0 on @mesh0 mesh_axes = [1] gather_axis = 1
+    %1 = shard.all_gather %0 on @grid0 grid_axes = [1] gather_axis = 1
       : tensor<2x2xi8> -> tensor<2x4xi8>
     ```
     Input:
@@ -535,16 +535,16 @@ def Mesh_AllGatherOp : Mesh_CollectiveCommunicationOpBase<"all_gather", [
     AnyNon0RankedTensor:$result
   );
   let assemblyFormat = [{
-    $input `on` $mesh (`mesh_axes` `=` $mesh_axes^)? `gather_axis` `=` $gather_axis
+    $input `on` $grid (`grid_axes` `=` $grid_axes^)? `gather_axis` `=` $gather_axis
     attr-dict `:` type($input) `->` type($result)
   }];
   let hasCanonicalizer = 1;
 }
 
-def Mesh_AllReduceOp : Mesh_CollectiveCommunicationOpBase<"all_reduce", [
+def Shard_AllReduceOp : Shard_CollectiveCommunicationOpBase<"all_reduce", [
     Pure,
     SameOperandsAndResultShape]> {
-  let summary = "All-reduce over a device mesh.";
+  let summary = "All-reduce over a device grid.";
   let description = [{
     The accumulation element type is specified by the result type and
     it does not need to match the input element type.
@@ -556,34 +556,34 @@ def Mesh_AllReduceOp : Mesh_CollectiveCommunicationOpBase<"all_reduce", [
 
     Example:
     ```
-    %1 = mesh.all_reduce %0 on @mesh0 mesh_axes = [1, 0] reduction = <max>
+    %1 = shard.all_reduce %0 on @grid0 grid_axes = [1, 0] reduction = <max>
       : tensor<3x4xf32> -> tensor<3x4xf64>
     ```
   }];
   let arguments = !con(commonArgs, (ins
     AnyTypeOf<[AnyMemRef, AnyRankedTensor]>:$input,
-    DefaultValuedAttr<Mesh_ReductionKindAttr, "::mlir::mesh::ReductionKind::Sum">:$reduction
+    DefaultValuedAttr<Shard_ReductionKindAttr, "::mlir::shard::ReductionKind::Sum">:$reduction
   ));
   let results = (outs
     AnyTypeOf<[AnyMemRef, AnyRankedTensor]>:$result
   );
   let assemblyFormat = [{
-    $input `on` $mesh (`mesh_axes` `=` $mesh_axes^)? (`reduction` `=` $reduction^)?
+    $input `on` $grid (`grid_axes` `=` $grid_axes^)? (`reduction` `=` $reduction^)?
     attr-dict `:` type($input) `->` type($result)
   }];
   let hasCanonicalizer = 1;
     let builders = [
-    OpBuilder<(ins "Value":$input, "StringRef":$mesh,
-      "ArrayRef<MeshAxis>":$meshAxes, "ReductionKind":$reduction)>
+    OpBuilder<(ins "Value":$input, "StringRef":$grid,
+      "ArrayRef<GridAxis>":$gridAxes, "ReductionKind":$reduction)>
   ];
 }
 
-def Mesh_AllSliceOp : Mesh_CollectiveCommunicationOpBase<"all_slice", [
+def Shard_AllSliceOp : Shard_CollectiveCommunicationOpBase<"all_slice", [
     Pure,
     SameOperandsAndResultElementType,
     SameOperandsAndResultRank
   ]> {
-  let summary = "All-slice over a device mesh. This is the inverse of all-gather.";
+  let summary = "All-slice over a device grid. This is the inverse of all-gather.";
   let description = [{
     Slice along the `slice_axis` tensor axis.
     This operation can be thought of as the inverse of all-gather.
@@ -593,9 +593,9 @@ def Mesh_AllSliceOp : Mesh_CollectiveCommunicationOpBase<"all_slice", [
 
     Example:
     ```mlir
-    mesh.mesh @mesh0(shape = 2x2)
+    shard.grid @grid0(shape = 2x2)
     ...
-    %1 = mesh.all_slice %0 on @mesh0 mesh_axes = [1] slice_axis = 1
+    %1 = shard.all_slice %0 on @grid0 grid_axes = [1] slice_axis = 1
       : tensor<2x4xi8> -> tensor<2x2xi8>
     ```
     Input:
@@ -630,30 +630,30 @@ def Mesh_AllSliceOp : Mesh_CollectiveCommunicationOpBase<"all_slice", [
     AnyNon0RankedTensor:$result
   );
   let assemblyFormat = [{
-    $input `on` $mesh (`mesh_axes` `=` $mesh_axes^)? `slice_axis` `=` $slice_axis
+    $input `on` $grid (`grid_axes` `=` $grid_axes^)? `slice_axis` `=` $slice_axis
     attr-dict `:` type($input) `->` type($result)
   }];
   let hasCanonicalizer = 1;
   let builders = [
-    OpBuilder<(ins "Value":$input, "MeshOp":$mesh, "ArrayRef<MeshAxis>":$meshAxes, "int64_t":$sliceAxis)>,
-    OpBuilder<(ins "Type":$result_type, "Value":$input, "StringRef":$mesh, "ArrayRef<MeshAxis>":$meshAxes, "int64_t":$sliceAxis)>
+    OpBuilder<(ins "Value":$input, "GridOp":$grid, "ArrayRef<GridAxis>":$gridAxes, "int64_t":$sliceAxis)>,
+    OpBuilder<(ins "Type":$result_type, "Value":$input, "StringRef":$grid, "ArrayRef<GridAxis>":$gridAxes, "int64_t":$sliceAxis)>
   ];
 }
 
-def Mesh_AllToAllOp : Mesh_CollectiveCommunicationOpBase<"all_to_all", [
+def Shard_AllToAllOp : Shard_CollectiveCommunicationOpBase<"all_to_all", [
     Pure,
     SameOperandsAndResultElementType,
     SameOperandsAndResultRank]> {
-  let summary = "All-to-all over a device mesh.";
+  let summary = "All-to-all over a device grid.";
   let description = [{
     Performs an all-to-all on tensor pieces split along `split_axis`.
    The resulting pieces are concatenated along `concat_axis` on each device.
 
     Example:
     ```
-    mesh.mesh @mesh0(shape = 3)
+    shard.grid @grid0(shape = 3)
     ...
-    %1 = mesh.all_to_all %0 on @mesh0 mesh_axes = [0]
+    %1 = shard.all_to_all %0 on @grid0 grid_axes = [0]
       split_axis = 0 concat_axis = 0
       : tensor<3x2xi8> -> tensor<3x2xi8>
     ```
@@ -687,7 +687,7 @@ def Mesh_AllToAllOp : Mesh_CollectiveCommunicationOpBase<"all_to_all", [
     AnyNon0RankedTensor:$result
   );
   let assemblyFormat = [{
-    $input `on` $mesh (`mesh_axes` `=` $mesh_axes^)?
+    $input `on` $grid (`grid_axes` `=` $grid_axes^)?
     `split_axis` `=` $split_axis
     `concat_axis` `=` $concat_axis
     attr-dict `:` type($input) `->` type($result)
@@ -695,24 +695,24 @@ def Mesh_AllToAllOp : Mesh_CollectiveCommunicationOpBase<"all_to_all", [
   let hasCanonicalizer = 1;
 }
 
-def Mesh_BroadcastOp : Mesh_CollectiveCommunicationOpBase<"broadcast", [
+def Shard_BroadcastOp : Shard_CollectiveCommunicationOpBase<"broadcast", [
     Pure,
     AllShapesMatch<["input", "result"]>,
     AllElementTypesMatch<["input", "result"]>
   ]> {
-  let summary = "Broadcast over a device mesh.";
+  let summary = "Broadcast over a device grid.";
   let description = [{
     Broadcast the tensor on `root` to all devices in each respective group.
-    The operation broadcasts along mesh axes `mesh_axes`.
+    The operation broadcasts along grid axes `grid_axes`.
     The `root` device specifies the in-group multi-index that is broadcast to
     all other devices in the group.
     
     Example:
     ```
-    mesh.mesh @mesh0(shape = 2x2)
+    shard.grid @grid0(shape = 2x2)
 
-    %1 = mesh.broadcast %0 on @mesh0
-      mesh_axes = [0]
+    %1 = shard.broadcast %0 on @grid0
+      grid_axes = [0]
       root = [0]
       : (tensor<2xi8>) -> tensor<2xi8>
     ```
@@ -744,31 +744,31 @@ def Mesh_BroadcastOp : Mesh_CollectiveCommunicationOpBase<"broadcast", [
     AnyRankedTensor:$result
   );
   let assemblyFormat = [{
-    $input `on` $mesh (`mesh_axes` `=` $mesh_axes^)?
+    $input `on` $grid (`grid_axes` `=` $grid_axes^)?
     `root` `=` custom<DynamicIndexList>($root_dynamic, $root)
     attr-dict `:` functional-type(operands, results)
   }];
   let hasCanonicalizer = 1;
 }
 
-def Mesh_GatherOp : Mesh_CollectiveCommunicationOpBase<"gather", [
+def Shard_GatherOp : Shard_CollectiveCommunicationOpBase<"gather", [
     Pure,
     AllRanksMatch<["input", "result"]>,
     AllElementTypesMatch<["input", "result"]>
   ]> {
-  let summary = "Gather over a device mesh.";
+  let summary = "Gather over a device grid.";
   let description = [{
     Gathers on device `root` along the `gather_axis` tensor axis.
-    `root` specifies the coordinates of a device along `mesh_axes`.
+    `root` specifies the coordinates of a device along `grid_axes`.
     It uniquely identifies the root device for each device group.
     The result tensor on non-root devices is undefined.
     Using it will result in undefined behavior.
 
     Example:
     ```mlir
-    mesh.mesh @mesh0(shape = 2x2)
+    shard.grid @grid0(shape = 2x2)
     ...
-    %1 = mesh.gather %0 on @mesh0 mesh_axes = [1]
+    %1 = shard.gather %0 on @grid0 grid_axes = [1]
       gather_axis = 1 root = [1]
       : (tensor<2x2xi8>) -> tensor<2x4xi8>
     ```
@@ -807,7 +807,7 @@ def Mesh_GatherOp : Mesh_CollectiveCommunicationOpBase<"gather", [
     AnyNon0RankedTensor:$result
   );
   let assemblyFormat = [{
-    $input `on` $mesh (`mesh_axes` `=` $mesh_axes^)?
+    $input `on` $grid (`grid_axes` `=` $grid_axes^)?
     `gather_axis` `=` $gather_axis
     `root` `=` custom<DynamicIndexList>($root_dynamic, $root)
     attr-dict `:` functional-type(operands, results)
@@ -815,11 +815,11 @@ def Mesh_GatherOp : Mesh_CollectiveCommunicationOpBase<"gather", [
   let hasCanonicalizer = 1;
 }
 
-def Mesh_RecvOp : Mesh_CollectiveCommunicationOpBase<"recv", [
+def Shard_RecvOp : Shard_CollectiveCommunicationOpBase<"recv", [
     AllShapesMatch<["input", "result"]>,
     AllElementTypesMatch<["input", "result"]>
   ]> {
-  let summary = "Send over a device mesh.";
+  let summary = "Receive over a device grid.";
   let description = [{
     Receive from a device within a device group.
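+
+    A minimal usage sketch (the grid, source coordinates and tensor type are
+    assumed for illustration):
+    ```mlir
+    shard.grid @grid0(shape = 2x2)
+    %1 = shard.recv %0 on @grid0 grid_axes = [1] source = [0]
+      : (tensor<2xi8>) -> tensor<2xi8>
+    ```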
   }];
@@ -832,21 +832,21 @@ def Mesh_RecvOp : Mesh_CollectiveCommunicationOpBase<"recv", [
     AnyRankedTensor:$result
   );
   let assemblyFormat = [{
-    $input `on` $mesh (`mesh_axes` `=` $mesh_axes^)?
+    $input `on` $grid (`grid_axes` `=` $grid_axes^)?
     (`source` `=` custom<DynamicIndexList>($source_dynamic, $source)^)?
     attr-dict `:` functional-type(operands, results)
   }];
   let hasCanonicalizer = 1;
 }
 
-def Mesh_ReduceOp : Mesh_CollectiveCommunicationOpBase<"reduce", [
+def Shard_ReduceOp : Shard_CollectiveCommunicationOpBase<"reduce", [
     Pure,
     AllShapesMatch<["input", "result"]>
   ]> {
-  let summary = "Reduce over a device mesh.";
+  let summary = "Reduce over a device grid.";
   let description = [{
     Reduces on device `root` within each device group.
-    `root` specifies the coordinates of a device along `mesh_axes`.
+    `root` specifies the coordinates of a device along `grid_axes`.
     It uniquely identifies the root device within its device group.
     The accumulation element type is specified by the result type and
     it does not need to match the input element type.
@@ -858,14 +858,14 @@ def Mesh_ReduceOp : Mesh_CollectiveCommunicationOpBase<"reduce", [
 
     Example:
     ```
-    %1 = mesh.reduce %0 on @mesh0 mesh_axes = [1, 0]
+    %1 = shard.reduce %0 on @grid0 grid_axes = [1, 0]
       reduction = <max> root = [2, 3]
       : (tensor<3x4xf32>) -> tensor<3x4xf64>
     ```
   }];
   let arguments = !con(commonArgs, (ins
     AnyRankedTensor:$input,
-    DefaultValuedAttr<Mesh_ReductionKindAttr, "::mlir::mesh::ReductionKind::Sum">:$reduction,
+    DefaultValuedAttr<Shard_ReductionKindAttr, "::mlir::shard::ReductionKind::Sum">:$reduction,
     DenseI64ArrayAttr:$root,
     Variadic<Index>:$root_dynamic
   ));
@@ -873,7 +873,7 @@ def Mesh_ReduceOp : Mesh_CollectiveCommunicationOpBase<"reduce", [
     AnyRankedTensor:$result
   );
   let assemblyFormat = [{
-    $input `on` $mesh (`mesh_axes` `=` $mesh_axes^)?
+    $input `on` $grid (`grid_axes` `=` $grid_axes^)?
     (`reduction` `=` $reduction^)?
     `root` `=` custom<DynamicIndexList>($root_dynamic, $root)
     attr-dict `:` functional-type(operands, results)
@@ -881,19 +881,19 @@ def Mesh_ReduceOp : Mesh_CollectiveCommunicationOpBase<"reduce", [
   let hasCanonicalizer = 1;
 }
 
-def Mesh_ReduceScatterOp : Mesh_CollectiveCommunicationOpBase<"reduce_scatter", [
+def Shard_ReduceScatterOp : Shard_CollectiveCommunicationOpBase<"reduce_scatter", [
     Pure,
     SameOperandsAndResultRank]> {
-  let summary = "Reduce-scatter over a device mesh.";
+  let summary = "Reduce-scatter over a device grid.";
   let description = [{
     After the reduction, the result is scattered within each device group.
     The tensor is split along `scatter_axis` and the pieces distributed
     across the device group.
     Example:
     ```
-    mesh.mesh @mesh0(shape = 2x2)
+    shard.grid @grid0(shape = 2x2)
     ...
-    %1 = mesh.reduce_scatter %0 on @mesh0 mesh_axes = [1]
+    %1 = shard.reduce_scatter %0 on @grid0 grid_axes = [1]
       reduction = <max> scatter_axis = 0
       : tensor<3x4xf32> -> tensor<1x4xf64>
     ```
@@ -928,14 +928,14 @@ def Mesh_ReduceScatterOp : Mesh_CollectiveCommunicationOpBase<"reduce_scatter",
   }];
   let arguments = !con(commonArgs, (ins
     AnyNon0RankedTensor:$input,
-    DefaultValuedAttr<Mesh_ReductionKindAttr, "::mlir::mesh::ReductionKind::Sum">:$reduction,
+    DefaultValuedAttr<Shard_ReductionKindAttr, "::mlir::shard::ReductionKind::Sum">:$reduction,
     IndexAttr:$scatter_axis
   ));
   let results = (outs
     AnyRankedTensor:$result
   );
   let assemblyFormat = [{
-    $input `on` $mesh (`mesh_axes` `=` $mesh_axes^)?
+    $input `on` $grid (`grid_axes` `=` $grid_axes^)?
     (`reduction` `=` $reduction^)?
     `scatter_axis` `=` $scatter_axis
     attr-dict `:` type($input) `->` type($result)
@@ -943,20 +943,20 @@ def Mesh_ReduceScatterOp : Mesh_CollectiveCommunicationOpBase<"reduce_scatter",
   let hasCanonicalizer = 1;
 }
 
-def Mesh_ScatterOp : Mesh_CollectiveCommunicationOpBase<"scatter", [
+def Shard_ScatterOp : Shard_CollectiveCommunicationOpBase<"scatter", [
     Pure,
     AllRanksMatch<["input", "result"]>,
     AllElementTypesMatch<["input", "result"]>
   ]> {
-  let summary = "Scatter over a device mesh.";
+  let summary = "Scatter over a device grid.";
   let description = [{
     For each device group split the input tensor on the `root` device along
     axis `scatter_axis` and scatter the parts across the group devices.
 
     Example:
     ```
-    mesh.mesh @mesh0(shape = 2x2)
-    %1 = mesh.scatter %0 on @mesh0 mesh_axes = [0]
+    shard.grid @grid0(shape = 2x2)
+    %1 = shard.scatter %0 on @grid0 grid_axes = [0]
       scatter_axis = 0
       root = [1]
       : (tensor<2x2xi8>) -> tensor<1x2xi8>
@@ -1004,7 +1004,7 @@ def Mesh_ScatterOp : Mesh_CollectiveCommunicationOpBase<"scatter", [
     AnyRankedTensor:$result
   );
   let assemblyFormat = [{
-    $input `on` $mesh (`mesh_axes` `=` $mesh_axes^)?
+    $input `on` $grid (`grid_axes` `=` $grid_axes^)?
     `scatter_axis` `=` $scatter_axis
     `root` `=` custom<DynamicIndexList>($root_dynamic, $root)
     attr-dict `:` functional-type(operands, results)
@@ -1012,11 +1012,11 @@ def Mesh_ScatterOp : Mesh_CollectiveCommunicationOpBase<"scatter", [
   let hasCanonicalizer = 1;
 }
 
-def Mesh_SendOp : Mesh_CollectiveCommunicationOpBase<"send", [
+def Shard_SendOp : Shard_CollectiveCommunicationOpBase<"send", [
     AllShapesMatch<["input", "result"]>,
     AllElementTypesMatch<["input", "result"]>
   ]> {
-  let summary = "Send over a device mesh.";
+  let summary = "Send over a device grid.";
   let description = [{
     Send from one device to another within a device group.
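+
+    A minimal usage sketch (the grid, destination coordinates and tensor type
+    are assumed for illustration):
+    ```mlir
+    shard.grid @grid0(shape = 2x2)
+    %1 = shard.send %0 on @grid0 grid_axes = [1] destination = [1]
+      : (tensor<2xi8>) -> tensor<2xi8>
+    ```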
   }];
@@ -1029,38 +1029,38 @@ def Mesh_SendOp : Mesh_CollectiveCommunicationOpBase<"send", [
     AnyRankedTensor:$result
   );
   let assemblyFormat = [{
-    $input `on` $mesh (`mesh_axes` `=` $mesh_axes^)?
+    $input `on` $grid (`grid_axes` `=` $grid_axes^)?
     `destination` `=` custom<DynamicIndexList>($destination_dynamic, $destination)
     attr-dict `:` functional-type(operands, results)
   }];
   let hasCanonicalizer = 1;
 }
 
-def Mesh_ShiftOp : Mesh_CollectiveCommunicationOpBase<"shift", [
+def Shard_ShiftOp : Shard_CollectiveCommunicationOpBase<"shift", [
     Pure,
     SameOperandsAndResultElementType,
     SameOperandsAndResultShape
   ]> {
-  let summary = "Shift over a device mesh.";
+  let summary = "Shift over a device grid.";
   let description = [{
+    Within each device group, shift along grid axis `shift_axis` by an offset
+    Within each device group shift along grid axis `shift_axis` by an offset
     `offset`.
     The result on devices that do not have a corresponding source is undefined.
-    `shift_axis` must be one of `mesh_axes`.
+    `shift_axis` must be one of `grid_axes`.
    If the `rotate` attribute is present,
    a rotation is performed instead of a shift.
 
     Example:
     ```
-    mesh.mesh @mesh0(shape = 2x4)
-    %1 = mesh.shift on @mesh0 mesh_axes = [1]
+    shard.grid @grid0(shape = 2x4)
+    %1 = shard.shift on @grid0 grid_axes = [1]
       shift_axis = 1 offset = 2 rotate
       : tensor<2xi8> -> tensor<2xi8>
     ```
 
     Input:
     ```
-    mesh axis 1
+    grid axis 1
     ----------->
 
     +----+----+----+----+
@@ -1089,7 +1089,7 @@ def Mesh_ShiftOp : Mesh_CollectiveCommunicationOpBase<"shift", [
     AnyRankedTensor:$result
   );
   let assemblyFormat = [{
-    $input `on` $mesh (`mesh_axes` `=` $mesh_axes^)?
+    $input `on` $grid (`grid_axes` `=` $grid_axes^)?
     `shift_axis` `=` $shift_axis
     `offset` `=` $offset
     (`rotate` $rotate^)?
@@ -1098,7 +1098,7 @@ def Mesh_ShiftOp : Mesh_CollectiveCommunicationOpBase<"shift", [
   let hasCanonicalizer = 1;
 }
 
-def Mesh_UpdateHaloOp : Mesh_Op<"update_halo", [
+def Shard_UpdateHaloOp : Shard_Op<"update_halo", [
   Pure,
   DestinationStyleOpInterface,
   TypesMatchWith<
@@ -1120,14 +1120,14 @@ def Mesh_UpdateHaloOp : Mesh_Op<"update_halo", [
     `destination_halo_sizes/static_destination_halo_sizes` in source shard
     and destination/result shard.
 
-    `split_axes` specifies for each tensor axis along which mesh axes its halo
+    `split_axes` specifies for each tensor axis along which grid axes its halo
     data is updated.
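+
+    A minimal usage sketch (the grid, split axes, halo sizes and tensor type
+    are assumed for illustration):
+    ```mlir
+    shard.grid @grid0(shape = 4)
+    %1 = shard.update_halo %0 on @grid0 split_axes = [[0]]
+      halo_sizes = [1, 2] : tensor<7x8xf32>
+    ```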
 
   }];
   let arguments = (ins
     AnyTypeOf<[AnyNon0RankedMemRef, AnyNon0RankedTensor]>:$destination,
-    FlatSymbolRefAttr:$mesh,
-    Mesh_MeshAxesArrayAttr:$split_axes,
+    FlatSymbolRefAttr:$grid,
+    Shard_GridAxesArrayAttr:$split_axes,
     Variadic<I64>:$halo_sizes,
     DefaultValuedAttr<DenseI64ArrayAttr, "{}">:$static_halo_sizes
   );
@@ -1136,7 +1136,7 @@ def Mesh_UpdateHaloOp : Mesh_Op<"update_halo", [
   );
   let assemblyFormat = [{
     $destination
-    `on` $mesh
+    `on` $grid
     `split_axes` `=` $split_axes
     (`halo_sizes` `=` custom<DynamicIndexList>($halo_sizes, $static_halo_sizes)^)?
     attr-dict `:` type($result)
@@ -1145,4 +1145,4 @@ def Mesh_UpdateHaloOp : Mesh_Op<"update_halo", [
     MutableOperandRange getDpsInitsMutable() { return getDestinationMutable(); }
   }];
 }
-#endif // MLIR_DIALECT_MESH_IR_MESHOPS_TD
+#endif // MLIR_DIALECT_SHARD_IR_SHARDOPS_TD
diff --git a/mlir/include/mlir/Dialect/Mesh/Interfaces/CMakeLists.txt b/mlir/include/mlir/Dialect/Shard/Interfaces/CMakeLists.txt
similarity index 100%
rename from mlir/include/mlir/Dialect/Mesh/Interfaces/CMakeLists.txt
rename to mlir/include/mlir/Dialect/Shard/Interfaces/CMakeLists.txt
diff --git a/mlir/include/mlir/Dialect/Mesh/Interfaces/ShardingInterface.h b/mlir/include/mlir/Dialect/Shard/Interfaces/ShardingInterface.h
similarity index 52%
rename from mlir/include/mlir/Dialect/Mesh/Interfaces/ShardingInterface.h
rename to mlir/include/mlir/Dialect/Shard/Interfaces/ShardingInterface.h
index 14aad7f9f6783..46f6ed410ebed 100644
--- a/mlir/include/mlir/Dialect/Mesh/Interfaces/ShardingInterface.h
+++ b/mlir/include/mlir/Dialect/Shard/Interfaces/ShardingInterface.h
@@ -6,10 +6,10 @@
 //
 //===----------------------------------------------------------------------===//
 
-#ifndef MLIR_DIALECT_MESH_INTERFACES_SHARDINGINTERFACE_H_
-#define MLIR_DIALECT_MESH_INTERFACES_SHARDINGINTERFACE_H_
+#ifndef MLIR_DIALECT_SHARD_INTERFACES_SHARDINGINTERFACE_H_
+#define MLIR_DIALECT_SHARD_INTERFACES_SHARDINGINTERFACE_H_
 
-#include "mlir/Dialect/Mesh/IR/MeshOps.h"
+#include "mlir/Dialect/Shard/IR/ShardOps.h"
 #include "mlir/Dialect/Utils/StructuredOpsUtils.h"
 #include "mlir/IR/Value.h"
 #include "mlir/Support/LLVM.h"
@@ -20,24 +20,24 @@ class Operation;
 class IRMapping;
 class SymbolTableCollection;
 
-namespace mesh {
+namespace shard {
 
-using ShardingArray = SmallVector<SmallVector<MeshAxis>>;
-using ShardingArrayRef = ArrayRef<SmallVector<MeshAxis>>;
+using ShardingArray = SmallVector<SmallVector<GridAxis>>;
+using ShardingArrayRef = ArrayRef<SmallVector<GridAxis>>;
 
 struct ShardingOption {
  // An array of integer arrays. The sub-array at the i-th position signifies the
-  // mesh axes the i-th loop will be sharded on.
+  // grid axes the i-th loop will be sharded on.
   ShardingArray shardingArray = {};
-  FlatSymbolRefAttr mesh = nullptr;
+  FlatSymbolRefAttr grid = nullptr;
   // `empty` being true indicates that no sharding information can be inferred
   // at present. Note that it is different from the case where an operation is
   // not sharded.
   bool empty = false;
   ShardingOption() = default;
-  ShardingOption(ShardingArray shardingArray, FlatSymbolRefAttr mesh)
-      : shardingArray(std::move(shardingArray)), mesh(mesh) {
-    assert(this->mesh);
+  ShardingOption(ShardingArray shardingArray, FlatSymbolRefAttr grid)
+      : shardingArray(std::move(shardingArray)), grid(grid) {
+    assert(this->grid);
   }
   static ShardingOption makeEmpty() {
     auto res = ShardingOption();
@@ -46,21 +46,21 @@ struct ShardingOption {
   }
 };
 
-// This method retrieves the 'MeshSharding' from a given operation
+// This method retrieves the 'Sharding' from a given operation
 // result and includes the 'annotate_for_users' information.
-FailureOr<std::pair<bool, MeshSharding>> getMeshSharding(OpResult result);
+FailureOr<std::pair<bool, Sharding>> getSharding(OpResult result);
 
-// This method retrieves the 'MeshSharding' from a given operation
+// This method retrieves the 'Sharding' from a given operation
 // operand and includes the 'annotate_for_users' information.
-FailureOr<std::pair<bool, MeshSharding>> getMeshSharding(OpOperand &opOperand);
+FailureOr<std::pair<bool, Sharding>> getSharding(OpOperand &opOperand);
 
 namespace detail {
 
 FailureOr<ShardingOption>
-defaultGetShardingOption(Operation *op, ArrayRef<MeshSharding> operandShardings,
-                         ArrayRef<MeshSharding> resultShardings);
+defaultGetShardingOption(Operation *op, ArrayRef<Sharding> operandShardings,
+                         ArrayRef<Sharding> resultShardings);
 
-FailureOr<std::vector<MeshSharding>>
+FailureOr<std::vector<Sharding>>
 defaultGetShardingAnnotations(Operation *op,
                               const ShardingOption &shardingOption);
 
@@ -71,18 +71,18 @@ defaultAddShardingAnnotations(Operation *op, OpBuilder &b,
 } // namespace detail
 
 // Assumes full replication on all ranked tensor arguments and results.
-void spmdizeFullyReplicatedOperation(Operation &op,
-                                     ArrayRef<Value> spmdizedOperands,
-                                     ArrayRef<MeshSharding> operandShardings,
-                                     ArrayRef<MeshSharding> resultShardings,
-                                     IRMapping &spmdizationMap,
-                                     SymbolTableCollection &symbolTable,
-                                     OpBuilder &builder);
-
-} // namespace mesh
+void partitionFullyReplicatedOperation(Operation &op,
+                                       ArrayRef<Value> partitionedOperands,
+                                       ArrayRef<Sharding> operandShardings,
+                                       ArrayRef<Sharding> resultShardings,
+                                       IRMapping &partitionMap,
+                                       SymbolTableCollection &symbolTable,
+                                       OpBuilder &builder);
+
+} // namespace shard
 } // namespace mlir
 
 /// Include the ODS generated interface header files.
-#include "mlir/Dialect/Mesh/Interfaces/ShardingInterface.h.inc"
+#include "mlir/Dialect/Shard/Interfaces/ShardingInterface.h.inc"
 
-#endif // MLIR_DIALECT_MESH_INTERFACES_SHARDINGINTERFACE_H_
+#endif // MLIR_DIALECT_SHARD_INTERFACES_SHARDINGINTERFACE_H_
diff --git a/mlir/include/mlir/Dialect/Mesh/Interfaces/ShardingInterface.td b/mlir/include/mlir/Dialect/Shard/Interfaces/ShardingInterface.td
similarity index 80%
rename from mlir/include/mlir/Dialect/Mesh/Interfaces/ShardingInterface.td
rename to mlir/include/mlir/Dialect/Shard/Interfaces/ShardingInterface.td
index a70d2c3e03851..8f5332b41ca72 100644
--- a/mlir/include/mlir/Dialect/Mesh/Interfaces/ShardingInterface.td
+++ b/mlir/include/mlir/Dialect/Shard/Interfaces/ShardingInterface.td
@@ -6,8 +6,8 @@
 //
 //===----------------------------------------------------------------------===//
 
-#ifndef MLIR_DIALECT_MESH_INTERFACES_SHARDINGINTERFACE_TD
-#define MLIR_DIALECT_MESH_INTERFACES_SHARDINGINTERFACE_TD
+#ifndef MLIR_DIALECT_SHARD_INTERFACES_SHARDINGINTERFACE_TD
+#define MLIR_DIALECT_SHARD_INTERFACES_SHARDINGINTERFACE_TD
 
 include "mlir/IR/OpBase.td"
 
@@ -16,7 +16,7 @@ def ShardingInterface : OpInterface<"ShardingInterface"> {
         Interface for allowing operations to expose information needed to
         shard them.
     }];
-    let cppNamespace = "::mlir::mesh";
+    let cppNamespace = "::mlir::shard";
 
     let methods = [
       InterfaceMethod<
@@ -84,8 +84,8 @@ def ShardingInterface : OpInterface<"ShardingInterface"> {
         /*retTy=*/"FailureOr<ShardingOption>",
         /*methodName=*/"getShardingOption",
         /*args=*/(ins
-          "ArrayRef<MeshSharding>": $operandShardings,
-          "ArrayRef<MeshSharding>": $resultShardings
+          "ArrayRef<Sharding>": $operandShardings,
+          "ArrayRef<Sharding>": $resultShardings
         ),
         /*methodBody=*/"",
         /*defaultImplementation=*/[{
@@ -100,7 +100,7 @@ def ShardingInterface : OpInterface<"ShardingInterface"> {
           This is what shardings the operands and results need to have in order
           to shard the op according to shardingOption.
         }],
-        /*retTy=*/"FailureOr<std::vector<MeshSharding>>",
+        /*retTy=*/"FailureOr<std::vector<Sharding>>",
         /*methodName=*/"getShardingAnnotations",
         /*args=*/(ins
           "const ShardingOption &":$shardingOption
@@ -113,7 +113,7 @@ def ShardingInterface : OpInterface<"ShardingInterface"> {
       >,
       InterfaceMethod<
         /*desc=*/[{
-          Based on a given ShardingOption, this method adds `mesh.shard`
+          Based on a given ShardingOption, this method adds `shard.shard`
           operations for the operands and results that previously lacked
           sharding annotations.
         }],
@@ -132,21 +132,21 @@ def ShardingInterface : OpInterface<"ShardingInterface"> {
       InterfaceMethod<
         /*desc=*/[{
           Convert self to SPMD form.
-          This method is used during the spmdization pass of a program fully
+          This method is used during the partition pass of a program fully
           annotated with shardings.
 
-          The spmdization algorithm would read the surrounding sharding
+          The partition algorithm would read the surrounding sharding
           annotations from the IR for each argument/result and prepare
           `operandShardings` and `resultShardings`.
           Values that are not ranked tensors do not have sharding annotations.
-          In this case their corresponding MeshSharding is null.
+          In this case their corresponding Sharding is null.
 
-          For convenience it will also prepare `spmdizedOperands`, although
-          they can be retrieved from the `spmdizationMap`.
+          For convenience it will also prepare `partitionedOperands`, although
+          they can be retrieved from the `partitionMap`.
 
-          The `spmdizationMap` contains a mapping from unsharded to
-          sharded/spmdized values that are constructed during the spmdization
-          pass. The interface implementation must populate `spmdizationMap`
+          The `partitionMap` contains a mapping from unsharded to
+          sharded/partitioned values that are constructed during the partition
+          pass. The interface implementation must populate `partitionMap`
           with the mapping for this op's results.
 
          `builder` is set to insert new operations at the appropriate point.
@@ -158,20 +158,20 @@ def ShardingInterface : OpInterface<"ShardingInterface"> {
           This assumes that all sharding annotations are for full replication.
         }],
         /*retTy=*/"LogicalResult",
-        /*methodName=*/"spmdize",
+        /*methodName=*/"partition",
         /*args=*/(ins
-          "ArrayRef<Value>": $spmdizedOperands,
-          "ArrayRef<MeshSharding>": $operandShardings,
-          "ArrayRef<MeshSharding>": $resultShardings,
-          "IRMapping&": $spmdizationMap,
+          "ArrayRef<Value>": $partitionedOperands,
+          "ArrayRef<Sharding>": $operandShardings,
+          "ArrayRef<Sharding>": $resultShardings,
+          "IRMapping&": $partitionMap,
           "SymbolTableCollection &": $symbolTableCollection,
           "OpBuilder &":$builder
         ),
         /*methodBody=*/"",
         /*defaultImplementation=*/[{
-          spmdizeFullyReplicatedOperation(
-            *$_op.getOperation(), spmdizedOperands, operandShardings,
-              resultShardings, spmdizationMap, symbolTableCollection, builder);
+          partitionFullyReplicatedOperation(
+            *$_op.getOperation(), partitionedOperands, operandShardings,
+              resultShardings, partitionMap, symbolTableCollection, builder);
           return success();
         }]>
     ];
@@ -184,4 +184,4 @@ def ShardingInterface : OpInterface<"ShardingInterface"> {
 }
 
 
-#endif // MLIR_DIALECT_MESH_INTERFACES_SHARDINGINTERFACE_TD
+#endif // MLIR_DIALECT_SHARD_INTERFACES_SHARDINGINTERFACE_TD
diff --git a/mlir/include/mlir/Dialect/Mesh/Interfaces/ShardingInterfaceImpl.h b/mlir/include/mlir/Dialect/Shard/Interfaces/ShardingInterfaceImpl.h
similarity index 58%
rename from mlir/include/mlir/Dialect/Mesh/Interfaces/ShardingInterfaceImpl.h
rename to mlir/include/mlir/Dialect/Shard/Interfaces/ShardingInterfaceImpl.h
index 2af8b2bd1d906..d34ba79257ff8 100644
--- a/mlir/include/mlir/Dialect/Mesh/Interfaces/ShardingInterfaceImpl.h
+++ b/mlir/include/mlir/Dialect/Shard/Interfaces/ShardingInterfaceImpl.h
@@ -6,11 +6,11 @@
 //
 //===----------------------------------------------------------------------===//
 
-#ifndef MLIR_DIALECT_MESH_INTERFACES_SHARDINGINTERFACEIMPL_H_
-#define MLIR_DIALECT_MESH_INTERFACES_SHARDINGINTERFACEIMPL_H_
+#ifndef MLIR_DIALECT_SHARD_INTERFACES_SHARDINGINTERFACEIMPL_H_
+#define MLIR_DIALECT_SHARD_INTERFACES_SHARDINGINTERFACEIMPL_H_
 
-#include "mlir/Dialect/Mesh/IR/MeshOps.h"
-#include "mlir/Dialect/Mesh/Interfaces/ShardingInterface.h"
+#include "mlir/Dialect/Shard/IR/ShardOps.h"
+#include "mlir/Dialect/Shard/Interfaces/ShardingInterface.h"
 #include "mlir/IR/AffineMap.h"
 #include "mlir/IR/Value.h"
 
@@ -20,35 +20,34 @@ class Operation;
 class IRMapping;
 class SymbolTableCollection;
 
-namespace mesh {
+namespace shard {
 
-// Retrieve the mesh axes corresponding to each operation loop iterator based
+// Retrieve the grid axes corresponding to each operation loop iterator based
 // on the provided shardings for the op's operands and results.
 // Assumes that the indexingMaps are projected permutations.
-ShardingArray getMeshAxisAssignmentForLoopIterators(
-    ArrayRef<MeshSharding> operandShardings,
-    ArrayRef<MeshSharding> resultShardings,
+ShardingArray getGridAxisAssignmentForLoopIterators(
+    ArrayRef<Sharding> operandShardings, ArrayRef<Sharding> resultShardings,
     ArrayRef<utils::IteratorType> loopIteratorTypes,
     ArrayRef<AffineMap> indexingMaps);
 
 bool isAtLeastOneReductionIteratorSharded(
     ArrayRef<utils::IteratorType> loopIteratorTypes,
-    ArrayRef<SmallVector<MeshAxis>> meshAxisAssignmentForLoopIterators);
+    ArrayRef<SmallVector<GridAxis>> gridAxisAssignmentForLoopIterators);
 
-// Get the set of mesh axes that correspond to reduction loop iterators.
-SmallVector<MeshAxis> getReductionMeshAxes(
+// Get the set of grid axes that correspond to reduction loop iterators.
+SmallVector<GridAxis> getReductionGridAxes(
     ArrayRef<utils::IteratorType> loopIteratorTypes,
-    ArrayRef<SmallVector<MeshAxis>> meshAxisAssignmentForLoopIterators);
+    ArrayRef<SmallVector<GridAxis>> gridAxisAssignmentForLoopIterators);
 
 // Inserts a clone of the operation that has all ranked tensor
 // arguments/results sharded.
-void spmdizeTriviallyShardableOperation(Operation &op,
-                                        ArrayRef<Value> spmdizedOperands,
-                                        ArrayRef<MeshSharding> operandShardings,
-                                        ArrayRef<MeshSharding> resultShardings,
-                                        IRMapping &spmdizationMap,
-                                        SymbolTableCollection &symbolTable,
-                                        OpBuilder &builder);
+void partitionTriviallyShardableOperation(Operation &op,
+                                          ArrayRef<Value> partitionedOperands,
+                                          ArrayRef<Sharding> operandShardings,
+                                          ArrayRef<Sharding> resultShardings,
+                                          IRMapping &partitionMap,
+                                          SymbolTableCollection &symbolTable,
+                                          OpBuilder &builder);
 
 // All ranked tensor argument and result dimensions have
 // independent parallel loop iterators.
@@ -73,15 +72,15 @@ struct IndependentParallelIteratorDomainShardingInterface
     return SmallVector<AffineMap>();
   }
 
-  LogicalResult spmdize(Operation *op, ArrayRef<Value> spmdizedOperands,
-                        ArrayRef<MeshSharding> operandShardings,
-                        ArrayRef<MeshSharding> resultShardings,
-                        IRMapping &spmdizationMap,
-                        SymbolTableCollection &symbolTable,
-                        OpBuilder &builder) const {
-    spmdizeTriviallyShardableOperation(*op, spmdizedOperands, operandShardings,
-                                       resultShardings, spmdizationMap,
-                                       symbolTable, builder);
+  LogicalResult partition(Operation *op, ArrayRef<Value> partitionedOperands,
+                          ArrayRef<Sharding> operandShardings,
+                          ArrayRef<Sharding> resultShardings,
+                          IRMapping &partitionMap,
+                          SymbolTableCollection &symbolTable,
+                          OpBuilder &builder) const {
+    partitionTriviallyShardableOperation(*op, partitionedOperands,
+                                         operandShardings, resultShardings,
+                                         partitionMap, symbolTable, builder);
     return success();
   }
 
@@ -129,20 +128,20 @@ struct ElementwiseShardingInterface
     return maps;
   }
 
-  LogicalResult spmdize(Operation *op, ArrayRef<Value> spmdizedOperands,
-                        ArrayRef<MeshSharding> operandShardings,
-                        ArrayRef<MeshSharding> resultShardings,
-                        IRMapping &spmdizationMap,
-                        SymbolTableCollection &symbolTable,
-                        OpBuilder &builder) const {
-    spmdizeTriviallyShardableOperation(*op, spmdizedOperands, operandShardings,
-                                       resultShardings, spmdizationMap,
-                                       symbolTable, builder);
+  LogicalResult partition(Operation *op, ArrayRef<Value> partitionedOperands,
+                          ArrayRef<Sharding> operandShardings,
+                          ArrayRef<Sharding> resultShardings,
+                          IRMapping &partitionMap,
+                          SymbolTableCollection &symbolTable,
+                          OpBuilder &builder) const {
+    partitionTriviallyShardableOperation(*op, partitionedOperands,
+                                         operandShardings, resultShardings,
+                                         partitionMap, symbolTable, builder);
     return success();
   }
 };
 
-} // namespace mesh
+} // namespace shard
 } // namespace mlir
 
-#endif // MLIR_DIALECT_MESH_INTERFACES_SHARDINGINTERFACEIMPL_H_
+#endif // MLIR_DIALECT_SHARD_INTERFACES_SHARDINGINTERFACEIMPL_H_
diff --git a/mlir/include/mlir/Dialect/Shard/Transforms/CMakeLists.txt b/mlir/include/mlir/Dialect/Shard/Transforms/CMakeLists.txt
new file mode 100644
index 0000000000000..9e2c8d00b27f5
--- /dev/null
+++ b/mlir/include/mlir/Dialect/Shard/Transforms/CMakeLists.txt
@@ -0,0 +1,6 @@
+set(LLVM_TARGET_DEFINITIONS Passes.td)
+mlir_tablegen(Passes.h.inc -gen-pass-decls -name Shard)
+add_public_tablegen_target(MLIRShardPassIncGen)
+add_dependencies(mlir-headers MLIRShardPassIncGen)
+
+add_mlir_doc(Passes ShardPasses ./ -gen-pass-doc)
diff --git a/mlir/include/mlir/Dialect/Mesh/Transforms/Spmdization.h b/mlir/include/mlir/Dialect/Shard/Transforms/Partition.h
similarity index 61%
rename from mlir/include/mlir/Dialect/Mesh/Transforms/Spmdization.h
rename to mlir/include/mlir/Dialect/Shard/Transforms/Partition.h
index 2f6de3e134319..37903765903db 100644
--- a/mlir/include/mlir/Dialect/Mesh/Transforms/Spmdization.h
+++ b/mlir/include/mlir/Dialect/Shard/Transforms/Partition.h
@@ -1,4 +1,4 @@
-//===- Simplifications.h - Mesh Simplifications -----------------*- C++ -*-===//
+//===- Partition.h - Shard Partition -----------------------------*- C++ -*-===//
 //
 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
 // See https://llvm.org/LICENSE.txt for license information.
@@ -6,35 +6,35 @@
 //
 //===----------------------------------------------------------------------===//
 
-#ifndef MLIR_DIALECT_MESH_TRANSFORMS_SPMDIZATION_H
-#define MLIR_DIALECT_MESH_TRANSFORMS_SPMDIZATION_H
+#ifndef MLIR_DIALECT_SHARD_TRANSFORMS_PARTITION_H
+#define MLIR_DIALECT_SHARD_TRANSFORMS_PARTITION_H
 
-#include "mlir/Dialect/Mesh/IR/MeshOps.h"
+#include "mlir/Dialect/Shard/IR/ShardOps.h"
 #include "mlir/IR/DialectRegistry.h"
 
 namespace mlir {
-namespace mesh {
+namespace shard {
 
-// Insert resharding spmdization of the value `sourceShardValue`
+// Insert the partitioned resharding of the value `sourceShardValue`
 // from sharding `source` to sharding `target`.
 // `sourceShardValue` is the already sharded value according to `source`.
 //
 // Example
 //
 // ```mlir
-//   mesh.mesh @mesh_1d(shape = 2)
+//   shard.grid @grid_1d(shape = 2)
 //   ...
-//   %1 = mesh.shard %0 to <@mesh_1d, [[0]]> : tensor<2xi8>
-//   %2 = mesh.shard %1 to <@mesh_1d, [[]]> annotate_for_users: tensor<2xi8>
+//   %1 = shard.shard %0 to <@grid_1d, [[0]]> : tensor<2xi8>
+//   %2 = shard.shard %1 to <@grid_1d, [[]]> annotate_for_users: tensor<2xi8>
 // ```
 //
 // Will result in
 //
 // ```mlir
-//   %1 = mesh.all_gather %0 on @mesh_1d mesh_axes = [0] gather_axis = 0 :
+//   %1 = shard.all_gather %0 on @grid_1d grid_axes = [0] gather_axis = 0 :
 //     tensor<1xi8> -> tensor<2xi8>
 // ```
-TypedValue<ShapedType> reshard(OpBuilder &builder, MeshOp mesh, ShardOp source,
+TypedValue<ShapedType> reshard(OpBuilder &builder, GridOp grid, ShardOp source,
                                ShardOp target,
                                TypedValue<ShapedType> sourceShardValue);
 TypedValue<ShapedType> reshard(OpBuilder &builder, ShardOp source,
@@ -44,7 +44,7 @@ TypedValue<ShapedType> reshard(OpBuilder &builder, ShardOp source,
 
 void reshardingRegisterDependentDialects(DialectRegistry &registry);
 
-} // namespace mesh
+} // namespace shard
 } // namespace mlir
 
-#endif // MLIR_DIALECT_MESH_TRANSFORMS_SPMDIZATION_H
+#endif // MLIR_DIALECT_SHARD_TRANSFORMS_PARTITION_H
diff --git a/mlir/include/mlir/Dialect/Mesh/Transforms/Passes.h b/mlir/include/mlir/Dialect/Shard/Transforms/Passes.h
similarity index 75%
rename from mlir/include/mlir/Dialect/Mesh/Transforms/Passes.h
rename to mlir/include/mlir/Dialect/Shard/Transforms/Passes.h
index a2424d43a8ba9..88bb460255728 100644
--- a/mlir/include/mlir/Dialect/Mesh/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/Shard/Transforms/Passes.h
@@ -1,4 +1,4 @@
-//===- Passes.h - Mesh Passes -----------------------------------*- C++ -*-===//
+//===- Passes.h - Shard Passes ----------------------------------*- C++ -*-===//
 //
 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
 // See https://llvm.org/LICENSE.txt for license information.
@@ -6,8 +6,8 @@
 //
 //===----------------------------------------------------------------------===//
 
-#ifndef MLIR_DIALECT_MESH_TRANSFORMS_PASSES_H
-#define MLIR_DIALECT_MESH_TRANSFORMS_PASSES_H
+#ifndef MLIR_DIALECT_SHARD_TRANSFORMS_PASSES_H
+#define MLIR_DIALECT_SHARD_TRANSFORMS_PASSES_H
 
 #include "mlir/Pass/Pass.h"
 
@@ -17,7 +17,7 @@ namespace func {
 class FuncOp;
 }
 
-namespace mesh {
+namespace shard {
 
 /// This enum controls the traversal order for the sharding propagation.
 enum class TraversalOrder {
@@ -36,16 +36,16 @@ enum class TraversalOrder {
 //===----------------------------------------------------------------------===//
 
 #define GEN_PASS_DECL
-#include "mlir/Dialect/Mesh/Transforms/Passes.h.inc"
+#include "mlir/Dialect/Shard/Transforms/Passes.h.inc"
 
 //===----------------------------------------------------------------------===//
 // Registration
 //===----------------------------------------------------------------------===//
 
 #define GEN_PASS_REGISTRATION
-#include "mlir/Dialect/Mesh/Transforms/Passes.h.inc"
+#include "mlir/Dialect/Shard/Transforms/Passes.h.inc"
 
-} // namespace mesh
+} // namespace shard
 } // namespace mlir
 
-#endif // MLIR_DIALECT_MESH_TRANSFORMS_PASSES_H
+#endif // MLIR_DIALECT_SHARD_TRANSFORMS_PASSES_H
diff --git a/mlir/include/mlir/Dialect/Mesh/Transforms/Passes.td b/mlir/include/mlir/Dialect/Shard/Transforms/Passes.td
similarity index 65%
rename from mlir/include/mlir/Dialect/Mesh/Transforms/Passes.td
rename to mlir/include/mlir/Dialect/Shard/Transforms/Passes.td
index 11ec7e78cd5e6..bbc6a1977b13e 100644
--- a/mlir/include/mlir/Dialect/Mesh/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/Shard/Transforms/Passes.td
@@ -1,4 +1,4 @@
-//===-- Passes.td - Mesh transformation definition file ----*- tablegen -*-===//
+//===-- Passes.td - Shard transformation definition file ---*- tablegen -*-===//
 //
 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
 // See https://llvm.org/LICENSE.txt for license information.
@@ -7,8 +7,8 @@
 //===----------------------------------------------------------------------===//
 
 
-#ifndef MLIR_DIALECT_MESH_TRANSFORMS_PASSES_TD
-#define MLIR_DIALECT_MESH_TRANSFORMS_PASSES_TD
+#ifndef MLIR_DIALECT_SHARD_TRANSFORMS_PASSES_TD
+#define MLIR_DIALECT_SHARD_TRANSFORMS_PASSES_TD
 
 include "mlir/Pass/PassBase.td"
 
@@ -20,31 +20,31 @@ def ShardingPropagation : InterfacePass<"sharding-propagation", "mlir::FunctionO
   let summary = "sharding propagation";
   let description = [{
     Propagates sharding information throughout the graph. After this pass, each
-    of the operations' operands and results is annotated with a `mesh.shard`
+    of the operations' operands and results is annotated with a `shard.shard`
     operation, and the operations themselves are added with sharding option
     attributes.
   }];
   let options = [
     Option<"traversal", "traversal",
-           "mlir::mesh::TraversalOrder", /*default=*/"mlir::mesh::TraversalOrder::BackwardForward",
+           "mlir::shard::TraversalOrder", /*default=*/"mlir::shard::TraversalOrder::BackwardForward",
            "Traversal order to use for sharding propagation:",
             [{::llvm::cl::values(
-              clEnumValN(mlir::mesh::TraversalOrder::Forward, "forward",
+              clEnumValN(mlir::shard::TraversalOrder::Forward, "forward",
               "Forward only traversal."),
-              clEnumValN(mlir::mesh::TraversalOrder::Backward, "backward",
+              clEnumValN(mlir::shard::TraversalOrder::Backward, "backward",
               "backward only traversal."),
-              clEnumValN(mlir::mesh::TraversalOrder::ForwardBackward, "forward-backward",
+              clEnumValN(mlir::shard::TraversalOrder::ForwardBackward, "forward-backward",
               "forward-backward traversal."),
-              clEnumValN(mlir::mesh::TraversalOrder::BackwardForward, "backward-forward",
+              clEnumValN(mlir::shard::TraversalOrder::BackwardForward, "backward-forward",
               "backward-forward traversal.")
             )}]>,
   ];
   let dependentDialects = [
-    "mesh::MeshDialect"
+    "shard::ShardDialect"
   ];
 }
 
-def Spmdization : InterfacePass<"mesh-spmdization", "mlir::FunctionOpInterface"> {
+def Partition : InterfacePass<"shard-partition", "mlir::FunctionOpInterface"> {
   let summary = "Partition a function into SPMD form.";
   let description = [{
     This pass fits in right after a pass that annotates the function with
@@ -52,15 +52,15 @@ def Spmdization : InterfacePass<"mesh-spmdization", "mlir::FunctionOpInterface">
     It operates on a fully annotated IR.
 
     A fully annotated IR required that all ranked tensor operands, results and
-    block arguments are annotated with the `mesh.shard` operation.
+    block arguments are annotated with the `shard.shard` operation.
   
     All direct descendant operations in the function must implement the
     `ShardingInterface` interface or all their ranked tensor operands and
     results must have full replication sharding.
 
     The input IR must have sharding annotations such that each operation
-    that implements `ShardingInterface` can handle during spmdization with
-    its `spmdize` method.
+    that implements `ShardingInterface` can handle them during partitioning
+    with its `partition` method.
     This can be achieved with the `ShardingPropagation` pass.
 
     If the function has multiple terminating blocks,
@@ -70,36 +70,36 @@ def Spmdization : InterfacePass<"mesh-spmdization", "mlir::FunctionOpInterface">
 
     Example:
     ```mlir
-    mesh.mesh @mesh_1d(shape = 2)
+    shard.grid @grid_1d(shape = 2)
 
     func.func @f(
       %arg0: tensor<2xi8>
     ) -> tensor<2xi8> {
-      %0 = mesh.shard %arg0 to <@mesh_1d, [[0]]> : tensor<2xi8>
-      %1 = mesh.shard %0 to <@mesh_1d, [[0]]> annotate_for_users: tensor<2xi8>
+      %0 = shard.shard %arg0 to <@grid_1d, [[0]]> : tensor<2xi8>
+      %1 = shard.shard %0 to <@grid_1d, [[0]]> annotate_for_users: tensor<2xi8>
       %2 = tosa.abs %1 : (tensor<2xi8>) -> tensor<2xi8>
-      %3 = mesh.shard %2 to <@mesh_1d, [[0]]> : tensor<2xi8>
-      %4 = mesh.shard %3 to <@mesh_1d, [[]]> annotate_for_users: tensor<2xi8>
+      %3 = shard.shard %2 to <@grid_1d, [[0]]> : tensor<2xi8>
+      %4 = shard.shard %3 to <@grid_1d, [[]]> annotate_for_users: tensor<2xi8>
       return %4 : tensor<2xi8>
     }
     ```
-    Spmdizing the above would result in 
+    Partitioning the above would result in 
     * Performing the element-wise `abs` operation on each device.
     * Resharding to full replication with an all-gather.
 
     ```mlir
-    mesh.mesh @mesh_1d(shape = 2)
+    shard.grid @grid_1d(shape = 2)
   
     func.func @f(%arg0: tensor<1xi8>) -> tensor<2xi8> {
       %0 = tosa.abs %arg0 : (tensor<1xi8>) -> tensor<1xi8>
-      %1 = mesh.all_gather %0 on @mesh_1d mesh_axes = [0] gather_axis = 0 : tensor<1xi8> -> tensor<2xi8>
+      %1 = shard.all_gather %0 on @grid_1d grid_axes = [0] gather_axis = 0 : tensor<1xi8> -> tensor<2xi8>
       return %1 : tensor<2xi8>
     }
     ```
   }];
   let dependentDialects = [
-    "mesh::MeshDialect"
+    "shard::ShardDialect"
   ];
 }
 
-#endif // MLIR_DIALECT_MESH_TRANSFORMS_PASSES_TD
+#endif // MLIR_DIALECT_SHARD_TRANSFORMS_PASSES_TD
diff --git a/mlir/include/mlir/Dialect/Mesh/Transforms/ReshardingSpmdizationDoc.md b/mlir/include/mlir/Dialect/Shard/Transforms/ReshardingPartitionDoc.md
similarity index 87%
rename from mlir/include/mlir/Dialect/Mesh/Transforms/ReshardingSpmdizationDoc.md
rename to mlir/include/mlir/Dialect/Shard/Transforms/ReshardingPartitionDoc.md
index 6368931cf6e07..cf5ae12b54b2c 100644
--- a/mlir/include/mlir/Dialect/Mesh/Transforms/ReshardingSpmdizationDoc.md
+++ b/mlir/include/mlir/Dialect/Shard/Transforms/ReshardingPartitionDoc.md
@@ -1,6 +1,6 @@
-# Resharding Spmdization Examples
+# Resharding Partition Examples
 
-Reshard `2x3` tensor from sharding `[[0, 1]]` to sharding `[[0, 1]]` on a `2x3` mesh.
+Reshard `2x3` tensor from sharding `[[0, 1]]` to sharding `[[1, 0]]` on a `2x3` grid.
 
 unsharded `2x3` tensor
 ```
@@ -8,16 +8,16 @@ unsharded `2x3` tensor
 21 22 23
 ```
 
-sharded on a `2x3` mesh
+sharded on a `2x3` grid
 
 sharding = `[[0, 1]]`
 
-mesh contents:
+grid contents:
 
 ```
-mesh axis 1
+grid axis 1
 ----------->
-+----+----+----+ mesh axis 0 |
++----+----+----+ grid axis 0 |
 | 11 | 12 | 13 |             |
 +----+----+----+             |
 | 21 | 22 | 23 |             |
@@ -27,9 +27,9 @@ mesh axis 1
 Transform into
 sharding = `[[1, 0]]`
 ```
-mesh axis 1
+grid axis 1
 ----------->
-+----+----+----+ mesh axis 0 |
++----+----+----+ grid axis 0 |
 | 11 | 13 | 22 |             |
 +----+----+----+             |
 | 12 | 21 | 23 |             |
@@ -40,7 +40,7 @@ Swap contents on devices that have the same linear index in the 2 shardings.
 
 --------------------------------------------------------------
 
-Reshard `2x3` tensor from sharding `[[0, 1]]` to sharding `[[1]]` on a `2x3` mesh.
+Reshard `2x3` tensor from sharding `[[0, 1]]` to sharding `[[1]]` on a `2x3` grid.
 
 unsharded `2x3` tensor
 ```
@@ -48,15 +48,15 @@ unsharded `2x3` tensor
 21 22 23
 ```
 
-sharded on a `2x3` mesh
+sharded on a `2x3` grid
 
 sharding = `[[0, 1]]`
 
-mesh contents:
+grid contents:
 ```
-mesh axis 1
+grid axis 1
 ----------->
-+----+----+----+ mesh axis 0 |
++----+----+----+ grid axis 0 |
 | 11 | 12 | 13 |             |
 +----+----+----+             |
 | 21 | 22 | 23 |             |
@@ -66,9 +66,9 @@ mesh axis 1
 Transform into
 sharding = `[[1]]`
 ```
-mesh axis 1
+grid axis 1
 ----------->
-+----+----+----+ mesh axis 0 |
++----+----+----+ grid axis 0 |
 | 11 | 12 | 13 |             |
 | 21 | 22 | 23 |             |
 +----+----+----+             |
@@ -77,11 +77,11 @@ mesh axis 1
 +----+----+----+             ↓
 ```
 Algorithm:
-All-gather along mesh axis 0.
+All-gather along grid axis 0.
 
 --------------------------------------------------------------
 
-Reshard `4x6` tensor from sharding `[[], [0, 1]]` to sharding `[[], [0]]` on a `2x3` mesh.
+Reshard `4x6` tensor from sharding `[[], [0, 1]]` to sharding `[[], [0]]` on a `2x3` grid.
 
 unsharded `4x6` tensor
 ```
@@ -89,15 +89,15 @@ unsharded `4x6` tensor
 21 22 23 24 25 26
 ```
 
-sharded on a `2x3` mesh
+sharded on a `2x3` grid
 
 sharding = `[[], [0, 1]]`
 
-mesh contents:
+grid contents:
 ```
-mesh axis 1
+grid axis 1
 ----------->
-+----+----+----+ mesh axis 0 |
++----+----+----+ grid axis 0 |
 | 11 | 12 | 13 |             |
 | 21 | 22 | 23 |             |
 +----+----+----+             |
@@ -108,9 +108,9 @@ mesh axis 1
 Transform into
 sharding = `[[], [0]]`
 ```
-mesh axis 1
+grid axis 1
 ----------->
-+----------+----------+ mesh axis 0 |
++----------+----------+ grid axis 0 |
 | 11 12 13 | 11 12 13 |             |
 | 21 22 23 | 21 22 23 |             |
 +----------+----------+             |
@@ -119,11 +119,11 @@ mesh axis 1
 +----------+----------+             ↓
 ```
 Algorithm:
-All-gather along mesh axis 1.
+All-gather along grid axis 1.
 
 --------------------------------------------------------------
 
-Reshard `4x8` tensor from sharding `[[0], [1, 2]]` to sharding `[[0], [2]]` on a `2x2x2` mesh.
+Reshard `4x8` tensor from sharding `[[0], [1, 2]]` to sharding `[[0], [2]]` on a `2x2x2` grid.
 
 unsharded `4x8` tensor
 ```
@@ -132,15 +132,15 @@ unsharded `4x8` tensor
 31 32 33 34 35 36 37 38
 41 42 43 44 45 46 47 48
 ```
-sharded on a `2x2x2` mesh
+sharded on a `2x2x2` grid
 
 sharding = `[[0], [1, 2]]`
 
-mesh contents:
+grid contents:
 ```
-mesh axis 2
+grid axis 2
 ----------->
-+-------+-------+ mesh axis 1 | mesh axis 0 |
++-------+-------+ grid axis 1 | grid axis 0 |
 | 11 12 | 13 14 |             |             |
 | 21 22 | 23 24 |             |             |
 +-------+-------+             |             |
@@ -158,9 +158,9 @@ mesh axis 2
 Transform into
 sharding = `[[0], [2]]`
 ```
-mesh axis 2
+grid axis 2
 ----------->
-+-------------+-------------+ mesh axis 1 | mesh axis 0 |
++-------------+-------------+ grid axis 1 | grid axis 0 |
 | 11 12 13 14 | 15 16 17 18 |             |             |
 | 21 22 23 24 | 25 26 27 28 |             |             |
 +-------------+-------------+             |             |
@@ -177,13 +177,13 @@ mesh axis 2
 ```
 Algorithm:
 
-Can't be done with just an all-gather along mesh axis 1.
+Can't be done with just an all-gather along grid axis 1.
 Can be handled by multiple resharding transformations
 `[[0], [1, 2]] -> [[0], [2, 1]] -> [[0], [2]]`
 
 --------------------------------------------------------------
 
-Reshard `6x6` tensor from sharding `[[0], [1]]` to sharding `[[1], [0]]` on a `2x3` mesh.
+Reshard `6x6` tensor from sharding `[[0], [1]]` to sharding `[[1], [0]]` on a `2x3` grid.
 
 unsharded `6x6` tensor
 ```
@@ -194,13 +194,13 @@ unsharded `6x6` tensor
 51 52 53 54 55 56
 61 62 63 64 65 66
 ```
-sharded on a `2x3` mesh
+sharded on a `2x3` grid
 
 sharding = `[[0], [1]]`
 ```
-mesh axis 1
+grid axis 1
 ----------->
-+-------+-------+-------+ mesh axis 0 |
++-------+-------+-------+ grid axis 0 |
 | 11 12 | 13 14 | 15 16 |             |
 | 21 22 | 23 24 | 25 26 |             |
 | 31 32 | 33 34 | 35 36 |             |
@@ -213,9 +213,9 @@ mesh axis 1
 transform to
 sharding = `[[1], [0]]`
 ```
-mesh axis 1
+grid axis 1
 ----------->
-+----------+----------+----------+ mesh axis 0 |
++----------+----------+----------+ grid axis 0 |
 | 11 12 13 | 31 32 33 | 51 52 53 |             |
 | 21 22 23 | 41 42 43 | 61 62 63 |             |
 +----------+----------+----------+             |
@@ -223,9 +223,9 @@ mesh axis 1
 | 24 25 26 | 44 45 46 | 64 65 66 |             |
 +----------+----------+----------+             ↓
 
-mesh axis 0
+grid axis 0
 ----------->
-+----------+----------+ mesh axis 1 |
++----------+----------+ grid axis 1 |
 | 11 12 13 | 14 15 16 |             |
 | 21 22 23 | 24 25 26 |             |
 +----------+----------+             |
@@ -240,7 +240,7 @@ Algorithm: TODO
 
 --------------------------------------------------------------
 
-Reshard `6x6` tensor from sharding `[[0], [1]]` to sharding `[[1], [0]]` on a `2x6` mesh.
+Reshard `6x6` tensor from sharding `[[0], [1]]` to sharding `[[1], [0]]` on a `2x6` grid.
 
 unsharded 6x6 tensor
 ```
@@ -251,13 +251,13 @@ unsharded 6x6 tensor
 51 52 53 54 55 56
 61 62 63 64 65 66
 ```
-shard on `2x6` mesh
+shard on `2x6` grid
 
 sharding = `[[0], [1]]`
 ```
-mesh axis 1
+grid axis 1
 ----------->
-+----+----+----+----+----+----+ mesh axis 0 |
++----+----+----+----+----+----+ grid axis 0 |
 | 11 | 12 | 13 ‖ 14 | 15 | 16 |             |
 | 21 | 22 | 23 ‖ 24 | 23 | 26 |             |
 | 31 | 32 | 33 ‖ 34 | 35 | 36 |             |
@@ -270,9 +270,9 @@ mesh axis 1
 transform to
 sharding = `[[1], [0]]`
 ```
-mesh axis 0
+grid axis 0
 ----------->
-+----------+----------+ mesh axis 1 |
++----------+----------+ grid axis 1 |
 | 11 12 13 | 14 15 16 |             |
 +----------+----------+             |
 | 21 22 23 | 24 25 26 |             |
@@ -290,9 +290,9 @@ Algorithm: TODO
 
 --------------------------------------------------------------
 
-Reshard KxL tensor from `[[0], [1]]` to `[[1], [0]]` on `MxN` mesh.
+Reshard KxL tensor from `[[0], [1]]` to `[[1], [0]]` on `MxN` grid.
 
-`M x N` mesh.
+`M x N` grid.
 `K x L` tensor `t`.
 `d(m, n)` the tensor on device `(m, n)`.
 
@@ -433,9 +433,9 @@ TODO
 
 --------------------------------------------------------------
 
-Reshard `KxL` tensor from sharding `[[0], [1]]` to sharding `[[1], [0]]` on a `2x3` mesh.
+Reshard `KxL` tensor from sharding `[[0], [1]]` to sharding `[[1], [0]]` on a `2x3` grid.
 
-Device placement on a `2x3` mesh
+Device placement on a `2x3` grid
 ```
 11 12 13  <- devices
 21 22 23
@@ -512,7 +512,7 @@ TODO
 
 --------------------------------------------------------------
 
-Reshard `6x6` tensor from sharding `[[0], []]` to sharding `[[], [0]]` on a `3` mesh.
+Reshard `6x6` tensor from sharding `[[0], []]` to sharding `[[], [0]]` on a `3` grid.
 
 unsharded `6x6` tensor
 ```
@@ -523,11 +523,11 @@ unsharded `6x6` tensor
 51 52 53 54 55 56
 61 62 63 64 65 66
 ```
-sharded on a `3` mesh
+sharded on a `3` grid
 
 sharding = `[[0], []]`
 ```
-+-------------------+ mesh axis 0 |
++-------------------+ grid axis 0 |
 | 11 12 13 14 15 16 |             |
 | 21 22 23 24 25 26 |             |
 +-------------------+             |
@@ -541,7 +541,7 @@ sharding = `[[0], []]`
 transform to
 sharding = `[[], [0]]`
 ```
-mesh axis 0
+grid axis 0
 ----------->
 +-------+-------+-------+
 | 11 12 | 13 14 | 15 16 |
@@ -554,11 +554,11 @@ mesh axis 0
 ```
 Algorithm:
 ```mlir
-%1 = all_to_all %0 on @mesh mesh_axes = [0] split_axis = 1 concat_axis = 0 : tensor<2x6xi8> -> tensor<6x2xi8>
+%1 = all_to_all %0 on @grid grid_axes = [0] split_axis = 1 concat_axis = 0 : tensor<2x6xi8> -> tensor<6x2xi8>
 ```
 --------------------------------------------------------------
 
-Reshard `4x4` tensor from sharding `[[0], [1, 2]]` to sharding `[[0, 1], [2]]` on a `2x2x2` mesh.
+Reshard `4x4` tensor from sharding `[[0], [1, 2]]` to sharding `[[0, 1], [2]]` on a `2x2x2` grid.
 
 unsharded `4x4` tensor
 ```
@@ -567,13 +567,13 @@ unsharded `4x4` tensor
 31 32 33 34
 41 42 43 44
 ```
-sharded on a `2x2x2` mesh
+sharded on a `2x2x2` grid
 
 sharding = `[[0], [1, 2]]`
 ```
-mesh axis 2
+grid axis 2
 ----------->
-+----+----+ mesh axis 1 | mesh axis 0 |
++----+----+ grid axis 1 | grid axis 0 |
 | 11 | 12 |             |             |
 | 21 | 22 |             |             |
 +----+----+             |             |
@@ -591,9 +591,9 @@ mesh axis 2
 transform to
 sharding = `[[0, 1], [2]]`
 ```
-mesh axis 2
+grid axis 2
 ----------->
-+-------+-------+ mesh axis 1 | mesh axis 0 |
++-------+-------+ grid axis 1 | grid axis 0 |
 | 11 12 | 13 41 |             |             |
 +-------+-------+             |             |
 | 21 22 | 23 24 |             |             |
@@ -606,7 +606,7 @@ mesh axis 2
 ```
 Algorithm:
 ```mlir
-%1 = all_to_all %0 on @mesh mesh_axes = [2] split_axis = 1 concat_axis = 0 : tensor<2x1xi8> -> tensor<1x2xi8>
+%1 = all_to_all %0 on @grid grid_axes = [2] split_axis = 1 concat_axis = 0 : tensor<2x1xi8> -> tensor<1x2xi8>
 ```
 is not enough.
 
@@ -639,15 +639,15 @@ Basis:
     [[0]] -> [[]]
     [[0, 1]] -> [[1]]
     ```
-    All-gather along mesh axis 0.
+    All-gather along grid axis 0.
 
-*   Swap mesh axes order when assigned to the same tensor axis.
+*   Swap grid axes order when assigned to the same tensor axis.
     ```
     [[0, 1]] -> [[1, 0]]
     ```
     Swap contents on devices with the same linear index.
 
-*   Move mesh axis to different tensor dimension.
+*   Move grid axis to different tensor dimension.
     ```
     [[0], []] -> [[], [0]]
     ```
@@ -661,9 +661,9 @@ Example decomposition of
 ```
 into
 ```
-[[0], [1]] -> all-gather along mesh axis 1    ->
-[[0], []]  -> all-to-all along mesh axis 0    ->
-[[], [0]]  -> extract slice along mesh axis 1 ->
+[[0], [1]] -> all-gather along grid axis 1    ->
+[[0], []]  -> all-to-all along grid axis 0    ->
+[[], [0]]  -> extract slice along grid axis 1 ->
 [[1], [0]]
 ```
 
@@ -675,9 +675,9 @@ Example decomposition of
 ```
 into
 ```
-[[3, 2], [], [0, 1]] -> all-to-all along mesh axis 1 ->
-[[3, 2], [1], [0]]   -> all-to-all along mesh axis 2 ->
-[[3], [1, 2], [0]]   -> all-gather along mesh axis 3 ->
-[[], [1, 2], [0]]    -> all-to-all along mesh axis 0 ->
+[[3, 2], [], [0, 1]] -> all-to-all along grid axis 1 ->
+[[3, 2], [1], [0]]   -> all-to-all along grid axis 2 ->
+[[3], [1, 2], [0]]   -> all-gather along grid axis 3 ->
+[[], [1, 2], [0]]    -> all-to-all along grid axis 0 ->
 [[0], [1, 2], []]
 ```
diff --git a/mlir/include/mlir/Dialect/Mesh/Transforms/Simplifications.h b/mlir/include/mlir/Dialect/Shard/Transforms/Simplifications.h
similarity index 93%
rename from mlir/include/mlir/Dialect/Mesh/Transforms/Simplifications.h
rename to mlir/include/mlir/Dialect/Shard/Transforms/Simplifications.h
index 243dbf081b999..452d4f6b4ed61 100644
--- a/mlir/include/mlir/Dialect/Mesh/Transforms/Simplifications.h
+++ b/mlir/include/mlir/Dialect/Shard/Transforms/Simplifications.h
@@ -1,4 +1,4 @@
-//===- Simplifications.h - Mesh Simplifications -----------------*- C++ -*-===//
+//===- Simplifications.h - Shard Simplifications ----------------*- C++ -*-===//
 //
 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
 // See https://llvm.org/LICENSE.txt for license information.
@@ -6,10 +6,10 @@
 //
 //===----------------------------------------------------------------------===//
 
-#ifndef MLIR_DIALECT_MESH_TRANSFORMS_SIMPLIFICATIONS_H
-#define MLIR_DIALECT_MESH_TRANSFORMS_SIMPLIFICATIONS_H
+#ifndef MLIR_DIALECT_SHARD_TRANSFORMS_SIMPLIFICATIONS_H
+#define MLIR_DIALECT_SHARD_TRANSFORMS_SIMPLIFICATIONS_H
 
-#include "mlir/Dialect/Mesh/IR/MeshOps.h"
+#include "mlir/Dialect/Shard/IR/ShardOps.h"
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/Transforms/EndomorphismSimplification.h"
 #include "llvm/Support/Casting.h"
@@ -22,7 +22,7 @@ namespace mlir {
 
 class SymbolTableCollection;
 
-namespace mesh {
+namespace shard {
 
 // If we have an algebraic op like "+" and a summing all-reduce,
 // `all_reduce_sum(x) + all_reduce_sum(y)` will be transformed to
@@ -112,7 +112,7 @@ void populateSimplificationPatterns(
 void populateFoldingPatterns(RewritePatternSet &patterns,
                              SymbolTableCollection &symbolTableCollection);
 
-} // namespace mesh
+} // namespace shard
 } // namespace mlir
 
-#endif // MLIR_DIALECT_MESH_TRANSFORMS_SIMPLIFICATIONS_H
+#endif // MLIR_DIALECT_SHARD_TRANSFORMS_SIMPLIFICATIONS_H
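
The algebraic simplification referenced in the header comment above (folding `all_reduce_sum(x) + all_reduce_sum(y)` into a single all-reduce of the local sum) is unchanged by the rename; a minimal sketch in terms of the renamed ops, assuming a placeholder `@grid_1d` and values `%x`/`%y`, and relying on the default (sum) reduction:

```mlir
shard.grid @grid_1d(shape = 4)

// Before: two all-reduces feeding an elementwise add.
%a = shard.all_reduce %x on @grid_1d grid_axes = [0] : tensor<4xf32> -> tensor<4xf32>
%b = shard.all_reduce %y on @grid_1d grid_axes = [0] : tensor<4xf32> -> tensor<4xf32>
%r = arith.addf %a, %b : tensor<4xf32>

// After the simplification: add locally, then all-reduce once.
%l = arith.addf %x, %y : tensor<4xf32>
%r2 = shard.all_reduce %l on @grid_1d grid_axes = [0] : tensor<4xf32> -> tensor<4xf32>
```
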
diff --git a/mlir/include/mlir/Dialect/Mesh/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Shard/Transforms/Transforms.h
similarity index 65%
rename from mlir/include/mlir/Dialect/Mesh/Transforms/Transforms.h
rename to mlir/include/mlir/Dialect/Shard/Transforms/Transforms.h
index f46c0db846088..57d65e687ea35 100644
--- a/mlir/include/mlir/Dialect/Mesh/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Shard/Transforms/Transforms.h
@@ -1,4 +1,4 @@
-//===- Transforms.h - Mesh Transforms ---------------------------*- C++ -*-===//
+//===- Transforms.h - Shard Transforms --------------------------*- C++ -*-===//
 //
 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
 // See https://llvm.org/LICENSE.txt for license information.
@@ -6,10 +6,10 @@
 //
 //===----------------------------------------------------------------------===//
 
-#ifndef MLIR_DIALECT_MESH_TRANSFORMS_TRANSFORMS_H
-#define MLIR_DIALECT_MESH_TRANSFORMS_TRANSFORMS_H
+#ifndef MLIR_DIALECT_SHARD_TRANSFORMS_TRANSFORMS_H
+#define MLIR_DIALECT_SHARD_TRANSFORMS_TRANSFORMS_H
 
-#include "mlir/Dialect/Mesh/IR/MeshOps.h"
+#include "mlir/Dialect/Shard/IR/ShardOps.h"
 #include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/Value.h"
 #include "mlir/Support/LLVM.h"
@@ -20,7 +20,7 @@ class RewritePatternSet;
 class SymbolTableCollection;
 class DialectRegistry;
 class ImplicitLocOpBuilder;
-namespace mesh {
+namespace shard {
 
 void populateProcessMultiIndexOpLoweringPatterns(
     RewritePatternSet &patterns, SymbolTableCollection &symbolTableCollection);
@@ -35,20 +35,20 @@ void populateAllOpLoweringPatterns(
 void registerAllOpLoweringDialects(DialectRegistry &registry);
 
 TypedValue<IndexType>
-createCollectiveProcessGroupSize(MeshOp mesh, ArrayRef<MeshAxis> axes,
+createCollectiveProcessGroupSize(GridOp grid, ArrayRef<GridAxis> axes,
                                  ImplicitLocOpBuilder &builder);
 
-// Get process linear index along the given mesh axes.
-TypedValue<IndexType> createProcessLinearIndex(StringRef mesh,
-                                               ArrayRef<MeshAxis> meshAxes,
+// Get process linear index along the given grid axes.
+TypedValue<IndexType> createProcessLinearIndex(StringRef grid,
+                                               ArrayRef<GridAxis> gridAxes,
                                                ImplicitLocOpBuilder &builder);
-// Get process linear index from a multi-index along the given mesh axes .
+// Get process linear index from a multi-index along the given grid axes .
 TypedValue<IndexType>
-createProcessLinearIndex(StringRef mesh, ValueRange processInGroupMultiIndex,
-                         ArrayRef<MeshAxis> meshAxes,
+createProcessLinearIndex(StringRef grid, ValueRange processInGroupMultiIndex,
+                         ArrayRef<GridAxis> gridAxes,
                          ImplicitLocOpBuilder &builder);
 
-} // namespace mesh
+} // namespace shard
 } // namespace mlir
 
-#endif // MLIR_DIALECT_MESH_TRANSFORMS_TRANSFORMS_H
+#endif // MLIR_DIALECT_SHARD_TRANSFORMS_TRANSFORMS_H
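
For orientation, the builder helpers declared above (`createProcessLinearIndex` and friends) emit IR built around the renamed process-index ops; a rough fragment with `@grid_2d` as a placeholder grid (restricting to a subset of grid axes is done through the helpers' `gridAxes` argument):

```mlir
shard.grid @grid_2d(shape = 2x3)

// Multi-index of the calling process within @grid_2d.
%idx:2 = shard.process_multi_index on @grid_2d : index, index
// Linear index of the calling process within @grid_2d.
%lin = shard.process_linear_index on @grid_2d : index
```
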
diff --git a/mlir/include/mlir/Dialect/Tensor/Extensions/MeshShardingExtensions.h b/mlir/include/mlir/Dialect/Tensor/Extensions/ShardingExtensions.h
similarity index 88%
rename from mlir/include/mlir/Dialect/Tensor/Extensions/MeshShardingExtensions.h
rename to mlir/include/mlir/Dialect/Tensor/Extensions/ShardingExtensions.h
index cfac485b807f2..895e7e5939935 100644
--- a/mlir/include/mlir/Dialect/Tensor/Extensions/MeshShardingExtensions.h
+++ b/mlir/include/mlir/Dialect/Tensor/Extensions/ShardingExtensions.h
@@ -1,4 +1,4 @@
-//===- MeshShardingExtensions.h - -------------------------------*- C++ -*-===//
+//===- ShardingExtensions.h - -----------------------------------*- C++ -*-===//
 //
 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
 // See https://llvm.org/LICENSE.txt for license information.
diff --git a/mlir/include/mlir/InitAllDialects.h b/mlir/include/mlir/InitAllDialects.h
index c6fcf1a0d510b..856170e9308da 100644
--- a/mlir/include/mlir/InitAllDialects.h
+++ b/mlir/include/mlir/InitAllDialects.h
@@ -60,7 +60,6 @@
 #include "mlir/Dialect/MemRef/Transforms/AllocationOpInterfaceImpl.h"
 #include "mlir/Dialect/MemRef/Transforms/BufferViewFlowOpInterfaceImpl.h"
 #include "mlir/Dialect/MemRef/Transforms/RuntimeOpVerification.h"
-#include "mlir/Dialect/Mesh/IR/MeshDialect.h"
 #include "mlir/Dialect/NVGPU/IR/NVGPUDialect.h"
 #include "mlir/Dialect/OpenACC/OpenACC.h"
 #include "mlir/Dialect/OpenMP/OpenMPDialect.h"
@@ -77,6 +76,7 @@
 #include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
 #include "mlir/Dialect/Shape/IR/Shape.h"
 #include "mlir/Dialect/Shape/Transforms/BufferizableOpInterfaceImpl.h"
+#include "mlir/Dialect/Shard/IR/ShardDialect.h"
 #include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
 #include "mlir/Dialect/SparseTensor/Transforms/BufferizableOpInterfaceImpl.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
@@ -131,7 +131,7 @@ inline void registerAllDialects(DialectRegistry &registry) {
                   LLVM::LLVMDialect,
                   math::MathDialect,
                   memref::MemRefDialect,
-                  mesh::MeshDialect,
+                  shard::ShardDialect,
                   ml_program::MLProgramDialect,
                   mpi::MPIDialect,
                   nvgpu::NVGPUDialect,
diff --git a/mlir/include/mlir/InitAllPasses.h b/mlir/include/mlir/InitAllPasses.h
index dd8b292a87344..002ff61fb87dd 100644
--- a/mlir/include/mlir/InitAllPasses.h
+++ b/mlir/include/mlir/InitAllPasses.h
@@ -32,13 +32,13 @@
 #include "mlir/Dialect/MLProgram/Transforms/Passes.h"
 #include "mlir/Dialect/Math/Transforms/Passes.h"
 #include "mlir/Dialect/MemRef/Transforms/Passes.h"
-#include "mlir/Dialect/Mesh/Transforms/Passes.h"
 #include "mlir/Dialect/NVGPU/Transforms/Passes.h"
 #include "mlir/Dialect/OpenACC/Transforms/Passes.h"
 #include "mlir/Dialect/Quant/Transforms/Passes.h"
 #include "mlir/Dialect/SCF/Transforms/Passes.h"
 #include "mlir/Dialect/SPIRV/Transforms/Passes.h"
 #include "mlir/Dialect/Shape/Transforms/Passes.h"
+#include "mlir/Dialect/Shard/Transforms/Passes.h"
 #include "mlir/Dialect/SparseTensor/Pipelines/Passes.h"
 #include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
 #include "mlir/Dialect/Tensor/Transforms/Passes.h"
@@ -81,7 +81,7 @@ inline void registerAllPasses() {
   LLVM::registerLLVMPasses();
   math::registerMathPasses();
   memref::registerMemRefPasses();
-  mesh::registerMeshPasses();
+  shard::registerShardPasses();
   ml_program::registerMLProgramPasses();
   quant::registerQuantPasses();
   registerSCFPasses();
diff --git a/mlir/lib/Conversion/CMakeLists.txt b/mlir/lib/Conversion/CMakeLists.txt
index f84375b6b8d6a..785cb8293810c 100644
--- a/mlir/lib/Conversion/CMakeLists.txt
+++ b/mlir/lib/Conversion/CMakeLists.txt
@@ -43,7 +43,7 @@ add_subdirectory(MathToSPIRV)
 add_subdirectory(MemRefToEmitC)
 add_subdirectory(MemRefToLLVM)
 add_subdirectory(MemRefToSPIRV)
-add_subdirectory(MeshToMPI)
+add_subdirectory(ShardToMPI)
 add_subdirectory(MPIToLLVM)
 add_subdirectory(NVGPUToNVVM)
 add_subdirectory(NVVMToLLVM)
diff --git a/mlir/lib/Conversion/MeshToMPI/CMakeLists.txt b/mlir/lib/Conversion/ShardToMPI/CMakeLists.txt
similarity index 65%
rename from mlir/lib/Conversion/MeshToMPI/CMakeLists.txt
rename to mlir/lib/Conversion/ShardToMPI/CMakeLists.txt
index 15560aa61e145..564f36fd20abb 100644
--- a/mlir/lib/Conversion/MeshToMPI/CMakeLists.txt
+++ b/mlir/lib/Conversion/ShardToMPI/CMakeLists.txt
@@ -1,8 +1,8 @@
-add_mlir_conversion_library(MLIRMeshToMPI
-  MeshToMPI.cpp
+add_mlir_conversion_library(MLIRShardToMPI
+  ShardToMPI.cpp
 
   ADDITIONAL_HEADER_DIRS
-  ${MLIR_MAIN_INCLUDE_DIR}/mlir/Conversion/MeshToMPI
+  ${MLIR_MAIN_INCLUDE_DIR}/mlir/Conversion/ShardToMPI
 
   DEPENDS
   MLIRConversionPassIncGen
@@ -17,7 +17,7 @@ add_mlir_conversion_library(MLIRMeshToMPI
   MLIRLinalgTransforms
   MLIRMemRefDialect
   MLIRPass
-  MLIRMeshDialect
+  MLIRShardDialect
   MLIRMPIDialect
   MLIRTransforms
   )
diff --git a/mlir/lib/Conversion/MeshToMPI/MeshToMPI.cpp b/mlir/lib/Conversion/ShardToMPI/ShardToMPI.cpp
similarity index 92%
rename from mlir/lib/Conversion/MeshToMPI/MeshToMPI.cpp
rename to mlir/lib/Conversion/ShardToMPI/ShardToMPI.cpp
index 63b1fdabaf407..8525543760d99 100644
--- a/mlir/lib/Conversion/MeshToMPI/MeshToMPI.cpp
+++ b/mlir/lib/Conversion/ShardToMPI/ShardToMPI.cpp
@@ -1,4 +1,4 @@
-//===- MeshToMPI.cpp - Mesh to MPI  dialect conversion -----------------===//
+//===- ShardToMPI.cpp - Shard to MPI dialect conversion -------------------===//
 //
 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
 // See https://llvm.org/LICENSE.txt for license information.
@@ -6,11 +6,11 @@
 //
 //===----------------------------------------------------------------------===//
 //
-// This file implements a translation of Mesh communication ops tp MPI ops.
+// This file implements a translation of Shard communication ops to MPI ops.
 //
 //===----------------------------------------------------------------------===//
 
-#include "mlir/Conversion/MeshToMPI/MeshToMPI.h"
+#include "mlir/Conversion/ShardToMPI/ShardToMPI.h"
 
 #include "mlir/Dialect/Affine/IR/AffineOps.h"
 #include "mlir/Dialect/Arith/IR/Arith.h"
@@ -20,11 +20,11 @@
 #include "mlir/Dialect/Linalg/IR/Linalg.h"
 #include "mlir/Dialect/MPI/IR/MPI.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
-#include "mlir/Dialect/Mesh/IR/MeshDialect.h"
-#include "mlir/Dialect/Mesh/IR/MeshOps.h"
-#include "mlir/Dialect/Mesh/Transforms/Simplifications.h"
-#include "mlir/Dialect/Mesh/Transforms/Transforms.h"
 #include "mlir/Dialect/SCF/IR/SCF.h"
+#include "mlir/Dialect/Shard/IR/ShardDialect.h"
+#include "mlir/Dialect/Shard/IR/ShardOps.h"
+#include "mlir/Dialect/Shard/Transforms/Simplifications.h"
+#include "mlir/Dialect/Shard/Transforms/Transforms.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "mlir/Dialect/Utils/StaticValueUtils.h"
 #include "mlir/IR/Builders.h"
@@ -35,16 +35,16 @@
 #include "mlir/Transforms/DialectConversion.h"
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
 
-#define DEBUG_TYPE "mesh-to-mpi"
+#define DEBUG_TYPE "shard-to-mpi"
 #define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
 
 namespace mlir {
-#define GEN_PASS_DEF_CONVERTMESHTOMPIPASS
+#define GEN_PASS_DEF_CONVERTSHARDTOMPIPASS
 #include "mlir/Conversion/Passes.h.inc"
 } // namespace mlir
 
 using namespace mlir;
-using namespace mesh;
+using namespace shard;
 
 namespace {
 /// Converts a vector of OpFoldResults (ints) into vector of Values of the
@@ -188,18 +188,18 @@ struct ConvertShardingOp : public OpConversionPattern<ShardingOp> {
     // maxSplitSize+1}. Store the offsets in the tensor but set trailing
     // elements for smaller split-groups to -1. Computing the max size of the
     // split groups needs using collectiveProcessGroupSize (which needs the
-    // MeshOp)
+    // GridOp)
     Value resOffsets;
     if (adaptor.getStaticShardedDimsOffsets().empty()) {
       resOffsets = tensor::EmptyOp::create(rewriter, loc,
                                            std::array<int64_t, 2>{0, 0}, i64);
     } else {
       SymbolTableCollection symbolTableCollection;
-      auto meshOp = getMesh(op, symbolTableCollection);
+      auto gridOp = getGrid(op, symbolTableCollection);
       int64_t maxSplitSize = 0;
       for (auto axes : splitAxes) {
         int64_t splitSize =
-            collectiveProcessGroupSize(axes.asArrayRef(), meshOp.getShape());
+            collectiveProcessGroupSize(axes.asArrayRef(), gridOp.getShape());
         assert(splitSize != ShapedType::kDynamic);
         maxSplitSize = std::max<int64_t>(maxSplitSize, splitSize);
       }
@@ -218,7 +218,7 @@ struct ConvertShardingOp : public OpConversionPattern<ShardingOp> {
       int64_t curr = 0;
       for (auto [i, axes] : llvm::enumerate(splitAxes)) {
         int64_t splitSize =
-            collectiveProcessGroupSize(axes.asArrayRef(), meshOp.getShape());
+            collectiveProcessGroupSize(axes.asArrayRef(), gridOp.getShape());
         assert(splitSize != ShapedType::kDynamic && splitSize < maxSplitSize);
         ++splitSize; // add one for the total size
         ArrayRef<Value> values(&offsets[curr], splitSize);
@@ -264,20 +264,20 @@ struct ConvertProcessMultiIndexOp
 
     SymbolTableCollection symbolTableCollection;
     Location loc = op.getLoc();
-    auto meshOp = getMesh(op, symbolTableCollection);
-    // For now we only support static mesh shapes
-    if (ShapedType::isDynamicShape(meshOp.getShape()))
+    auto gridOp = getGrid(op, symbolTableCollection);
+    // For now we only support static grid shapes
+    if (ShapedType::isDynamicShape(gridOp.getShape()))
       return failure();
 
     SmallVector<Value> dims;
     llvm::transform(
-        meshOp.getShape(), std::back_inserter(dims), [&](int64_t i) {
+        gridOp.getShape(), std::back_inserter(dims), [&](int64_t i) {
           return arith::ConstantIndexOp::create(rewriter, loc, i).getResult();
         });
-    Value rank = ProcessLinearIndexOp::create(rewriter, op.getLoc(), meshOp);
+    Value rank = ProcessLinearIndexOp::create(rewriter, op.getLoc(), gridOp);
     auto mIdx = linearToMultiIndex(loc, rewriter, rank, dims);
 
-    // optionally extract subset of mesh axes
+    // optionally extract subset of grid axes
     auto axes = adaptor.getAxes();
     if (!axes.empty()) {
       SmallVector<Value> subIndex;
@@ -338,12 +338,12 @@ struct ConvertNeighborsLinearIndicesOp
 
     Location loc = op.getLoc();
     SymbolTableCollection symbolTableCollection;
-    auto meshOp = getMesh(op, symbolTableCollection);
+    auto gridOp = getGrid(op, symbolTableCollection);
     auto mIdx = adaptor.getDevice();
     auto orgIdx = mIdx[axes[0]];
     SmallVector<Value> dims;
     llvm::transform(
-        meshOp.getShape(), std::back_inserter(dims), [&](int64_t i) {
+        gridOp.getShape(), std::back_inserter(dims), [&](int64_t i) {
           return arith::ConstantIndexOp::create(rewriter, loc, i).getResult();
         });
     Value dimSz = dims[axes[0]];
@@ -394,14 +394,14 @@ struct ConvertShardShapeOp : public OpConversionPattern<ShardShapeOp> {
     auto sharding = op.getSharding().getDefiningOp<ShardingOp>();
     if (!sharding) {
       return op->emitError()
-             << "Expected SharingOp as defining op for sharding"
+             << "Expected ShardingOp as defining op for sharding"
              << " but found " << adaptor.getSharding()[0].getDefiningOp();
     }
 
     // Compute the sharded shape by applying the sharding to the input shape.
     // If shardedDimsOffsets is not defined in the sharding, the shard shape is
     // computed by dividing the dimension size by the number of shards in that
-    // dimension (which is given by the size of the mesh axes provided in
+    // dimension (which is given by the size of the grid axes provided in
     // split-axes). Odd elements get distributed to trailing shards. If a
     // shardedDimsOffsets is provided, the shard shape is computed by
     // subtracting the offset of the current shard from the offset of the next
@@ -431,11 +431,11 @@ struct ConvertShardShapeOp : public OpConversionPattern<ShardShapeOp> {
     SmallVector<Value> multiIdx =
         getMixedAsValues(rewriter, loc, adaptor.getDevice(), dynDevice, index);
 
-    // Get the MeshOp, the mesh shape is needed to compute the sharded shape.
+    // Get the GridOp, the grid shape is needed to compute the sharded shape.
     SymbolTableCollection symbolTableCollection;
-    auto meshOp = getMesh(sharding, symbolTableCollection);
-    // For now we only support static mesh shapes
-    if (ShapedType::isDynamicShape(meshOp.getShape()))
+    auto gridOp = getGrid(sharding, symbolTableCollection);
+    // For now we only support static grid shapes
+    if (ShapedType::isDynamicShape(gridOp.getShape()))
       return failure();
 
     auto splitAxes = sharding.getSplitAxes().getAxes();
@@ -455,7 +455,7 @@ struct ConvertShardShapeOp : public OpConversionPattern<ShardShapeOp> {
             tmp);
     }
 
-    // With static mesh shape the sizes of the split axes are known.
+    // With static grid shape the sizes of the split axes are known.
     // Hence the start/pos for each split axes in shardDimsOffsets can be
     // computed statically.
     int64_t pos = 0;
@@ -475,10 +475,10 @@ struct ConvertShardShapeOp : public OpConversionPattern<ShardShapeOp> {
         // Create a value from the static position in shardDimsOffsets.
         Value posVal = arith::ConstantOp::create(rewriter, loc,
                                                  rewriter.getIndexAttr(pos));
-        // Get the index of the local shard in the mesh axis.
+        // Get the index of the local shard in the grid axis.
         Value idx = multiIdx[axes[0]];
         auto numShards =
-            collectiveProcessGroupSize(axes.asArrayRef(), meshOp.getShape());
+            collectiveProcessGroupSize(axes.asArrayRef(), gridOp.getShape());
         if (shardedDimsOffs) {
           // If sharded dims offsets are provided, use them to compute the
           // sharded shape.
@@ -556,13 +556,13 @@ struct ConvertAllReduceOp : public OpConversionPattern<AllReduceOp> {
   matchAndRewrite(AllReduceOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
     SymbolTableCollection symbolTableCollection;
-    auto mesh = adaptor.getMesh();
-    mlir::mesh::MeshOp meshOp = getMesh(op, symbolTableCollection);
-    if (!meshOp)
-      return op->emitError() << "No mesh found for AllReduceOp";
-    if (ShapedType::isDynamicShape(meshOp.getShape()))
+    auto grid = adaptor.getGrid();
+    mlir::shard::GridOp gridOp = getGrid(op, symbolTableCollection);
+    if (!gridOp)
+      return op->emitError() << "No grid found for AllReduceOp";
+    if (ShapedType::isDynamicShape(gridOp.getShape()))
       return op->emitError()
-             << "Dynamic mesh shape not supported in AllReduceOp";
+             << "Dynamic grid shape not supported in AllReduceOp";
 
     ImplicitLocOpBuilder iBuilder(op.getLoc(), rewriter);
     Value input = adaptor.getInput();
@@ -592,27 +592,27 @@ struct ConvertAllReduceOp : public OpConversionPattern<AllReduceOp> {
     linalg::CopyOp::create(iBuilder, input, buffer);
 
     // Get an MPI_Comm_split for the AllReduce operation.
-    // The color is the linear index of the process in the mesh along the
-    // non-reduced axes. The key is the linear index of the process in the mesh
+    // The color is the linear index of the process in the grid along the
+    // non-reduced axes. The key is the linear index of the process in the grid
     // along the reduced axes.
-    SmallVector<Type> indexResultTypes(meshOp.getShape().size(),
+    SmallVector<Type> indexResultTypes(gridOp.getShape().size(),
                                        iBuilder.getIndexType());
     SmallVector<Value> myMultiIndex =
-        ProcessMultiIndexOp::create(iBuilder, indexResultTypes, mesh)
+        ProcessMultiIndexOp::create(iBuilder, indexResultTypes, grid)
             .getResult();
     Value zero = arith::ConstantIndexOp::create(iBuilder, 0);
     SmallVector<Value> multiKey(myMultiIndex.size(), zero);
 
-    auto redAxes = adaptor.getMeshAxes();
+    auto redAxes = adaptor.getGridAxes();
     for (auto axis : redAxes) {
       multiKey[axis] = myMultiIndex[axis];
       myMultiIndex[axis] = zero;
     }
 
     Value color =
-        createProcessLinearIndex(mesh, myMultiIndex, redAxes, iBuilder);
+        createProcessLinearIndex(grid, myMultiIndex, redAxes, iBuilder);
     color = arith::IndexCastOp::create(iBuilder, iBuilder.getI32Type(), color);
-    Value key = createProcessLinearIndex(mesh, multiKey, redAxes, iBuilder);
+    Value key = createProcessLinearIndex(grid, multiKey, redAxes, iBuilder);
     key = arith::IndexCastOp::create(iBuilder, iBuilder.getI32Type(), key);
 
     // Finally split the communicator
@@ -698,8 +698,8 @@ struct ConvertUpdateHaloOp : public OpConversionPattern<UpdateHaloOp> {
     }
     auto rank = cast<ShapedType>(array.getType()).getRank();
     auto opSplitAxes = adaptor.getSplitAxes().getAxes();
-    auto mesh = adaptor.getMesh();
-    auto meshOp = getMesh(op, symbolTableCollection);
+    auto grid = adaptor.getGrid();
+    auto gridOp = getGrid(op, symbolTableCollection);
     // subviews need Index values
     for (auto &sz : haloSizes) {
       if (auto value = dyn_cast<Value>(sz))
@@ -745,10 +745,10 @@ struct ConvertUpdateHaloOp : public OpConversionPattern<UpdateHaloOp> {
     auto zeroAttr = rewriter.getI32IntegerAttr(0); // for detecting v<0
     auto zero = arith::ConstantOp::create(rewriter, loc, zeroAttr);
 
-    SmallVector<Type> indexResultTypes(meshOp.getShape().size(),
+    SmallVector<Type> indexResultTypes(gridOp.getShape().size(),
                                        rewriter.getIndexType());
     auto myMultiIndex =
-        ProcessMultiIndexOp::create(rewriter, loc, indexResultTypes, mesh)
+        ProcessMultiIndexOp::create(rewriter, loc, indexResultTypes, grid)
             .getResult();
     // traverse all split axes from high to low dim
     for (ssize_t dim = opSplitAxes.size() - 1; dim >= 0; --dim) {
@@ -759,7 +759,7 @@ struct ConvertUpdateHaloOp : public OpConversionPattern<UpdateHaloOp> {
       // Get the linearized ids of the neighbors (down and up) for the
       // given split
       auto tmp = rewriter
-                     .create<NeighborsLinearIndicesOp>(loc, mesh, myMultiIndex,
+                     .create<NeighborsLinearIndicesOp>(loc, grid, myMultiIndex,
                                                        splitAxes)
                      .getResults();
       // MPI operates on i32...
@@ -791,7 +791,7 @@ struct ConvertUpdateHaloOp : public OpConversionPattern<UpdateHaloOp> {
         dimSizes[dim] = upperHalo ? haloSizes[currHaloDim * 2 + 1]
                                   : haloSizes[currHaloDim * 2];
         // Check if we need to send and/or receive
-        // Processes on the mesh borders have only one neighbor
+        // Processes on the grid borders have only one neighbor
         auto to = upperHalo ? neighbourIDs[0] : neighbourIDs[1];
         auto from = upperHalo ? neighbourIDs[1] : neighbourIDs[0];
         auto hasFrom = arith::CmpIOp::create(
@@ -869,8 +869,8 @@ struct ConvertUpdateHaloOp : public OpConversionPattern<UpdateHaloOp> {
   }
 };
 
-struct ConvertMeshToMPIPass
-    : public impl::ConvertMeshToMPIPassBase<ConvertMeshToMPIPass> {
+struct ConvertShardToMPIPass
+    : public impl::ConvertShardToMPIPassBase<ConvertShardToMPIPass> {
   using Base::Base;
 
   /// Run the dialect converter on the module.
@@ -879,12 +879,12 @@ struct ConvertMeshToMPIPass
     RewritePatternSet patterns(ctxt);
     ConversionTarget target(getContext());
 
-    // Define a type converter to convert mesh::ShardingType,
+    // Define a type converter to convert shard::ShardingType,
     // mostly for use in return operations.
     TypeConverter typeConverter;
     typeConverter.addConversion([](Type type) { return type; });
 
-    // convert mesh::ShardingType to a tuple of RankedTensorTypes
+    // convert shard::ShardingType to a tuple of RankedTensorTypes
     typeConverter.addConversion(
         [](ShardingType type,
            SmallVectorImpl<Type> &results) -> std::optional<LogicalResult> {
@@ -920,10 +920,10 @@ struct ConvertMeshToMPIPass
           return results;
         });
 
-    // No mesh dialect should left after conversion...
-    target.addIllegalDialect<mesh::MeshDialect>();
-    // ...except the global MeshOp. MeshShapeOp which will get folded later.
-    target.addLegalOp<mesh::MeshOp, mesh::MeshShapeOp>();
+    // No shard dialect should be left after conversion...
+    target.addIllegalDialect<shard::ShardDialect>();
+    // ...except the global GridOp, and GridShapeOp which will get folded later.
+    target.addLegalOp<shard::GridOp, shard::GridShapeOp>();
     // Allow all the stuff that our patterns will convert to
     target.addLegalDialect<
         BuiltinDialect, mpi::MPIDialect, scf::SCFDialect, arith::ArithDialect,
@@ -951,7 +951,7 @@ struct ConvertMeshToMPIPass
     // Folding patterns cannot be mixed with conversion patterns -> extra pass.
     patterns.clear();
     SymbolTableCollection symbolTableCollection;
-    mlir::mesh::populateFoldingPatterns(patterns, symbolTableCollection);
+    mlir::shard::populateFoldingPatterns(patterns, symbolTableCollection);
     (void)applyPatternsGreedily(getOperation(), std::move(patterns));
   }
 };
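
To make the shard-shape rule described in `ConvertShardShapeOp` concrete (an editorial worked example, not from the patch): without `sharded_dims_offsets` a dimension is divided evenly over the number of shards along its split axes and the remainder goes to the trailing shards; with offsets, each local size is the difference of adjacent offsets.

```
dim size 10 split over 4 shards (no offsets):  10 = 2 + 2 + 3 + 3
  -> local sizes for shard index 0..3: 2, 2, 3, 3
equivalent sharded_dims_offsets = [0, 2, 4, 7, 10]:
  -> local size = offsets[i+1] - offsets[i] = 2, 2, 3, 3
```
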
diff --git a/mlir/lib/Dialect/Arith/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Arith/Transforms/CMakeLists.txt
index f96bda603baa6..93682a9375dac 100644
--- a/mlir/lib/Dialect/Arith/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Arith/Transforms/CMakeLists.txt
@@ -27,7 +27,7 @@ add_mlir_dialect_library(MLIRArithTransforms
   MLIRInferIntRangeInterface
   MLIRIR
   MLIRMemRefDialect
-  MLIRMeshDialect
+  MLIRShardDialect
   MLIRPass
   MLIRShardingInterface
   MLIRTensorDialect
diff --git a/mlir/lib/Dialect/Arith/Transforms/ShardingInterfaceImpl.cpp b/mlir/lib/Dialect/Arith/Transforms/ShardingInterfaceImpl.cpp
index 3478adcb4a128..2906999f8def7 100644
--- a/mlir/lib/Dialect/Arith/Transforms/ShardingInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Arith/Transforms/ShardingInterfaceImpl.cpp
@@ -6,22 +6,22 @@
 //
 //===----------------------------------------------------------------------===//
 
-#include "mlir/Dialect/Mesh/Interfaces/ShardingInterfaceImpl.h"
+#include "mlir/Dialect/Shard/Interfaces/ShardingInterfaceImpl.h"
 #include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/Arith/Transforms/ShardingInterfaceImpl.h"
-#include "mlir/Dialect/Mesh/Interfaces/ShardingInterface.h"
+#include "mlir/Dialect/Shard/Interfaces/ShardingInterface.h"
 #include "mlir/IR/DialectRegistry.h"
 
 using namespace mlir;
 using namespace mlir::arith;
-using namespace mlir::mesh;
+using namespace mlir::shard;
 
 namespace {
 
 // Sharding of arith.constant
 // RankedTensor constants can be sharded like any other tensor.
 //   %cst = arith.constant dense<0.000000e+00> : tensor<1024x1024xf32>
-//   %sharding = mesh.sharding @mesh4x4 split_axes = [[0]] : !mesh.sharding
+//   %sharding = shard.sharding @grid4x4 split_axes = [[0]] : !shard.sharding
 // Scalar constants are always replicated and need no sharding annotation.
 
 struct ConstantShardingInterface
@@ -48,8 +48,8 @@ struct ConstantShardingInterface
   // Otherwise mirror result sharding if it is a tensor constant.
   // Otherwise return replication option.
   FailureOr<ShardingOption>
-  getShardingOption(Operation *op, ArrayRef<MeshSharding> operandShardings,
-                    ArrayRef<MeshSharding> resultShardings) const {
+  getShardingOption(Operation *op, ArrayRef<Sharding> operandShardings,
+                    ArrayRef<Sharding> resultShardings) const {
     assert(resultShardings.size() == 1 &&
            "Expecting exactly one result sharding for arith.constant");
     auto resultSharding = resultShardings[0];
@@ -61,17 +61,17 @@ struct ConstantShardingInterface
       for (auto [i, axes] : llvm::enumerate(resultSharding.getSplitAxes())) {
         axesArray[i].append(axes.asArrayRef().begin(), axes.asArrayRef().end());
       }
-      return ShardingOption(axesArray, resultSharding.getMeshAttr());
+      return ShardingOption(axesArray, resultSharding.getGridAttr());
     }
-    return ShardingOption({}, resultSharding.getMeshAttr());
+    return ShardingOption({}, resultSharding.getGridAttr());
   }
 
-  LogicalResult spmdize(Operation *op, ArrayRef<Value> spmdizedOperands,
-                        ArrayRef<MeshSharding> operandShardings,
-                        ArrayRef<MeshSharding> resultShardings,
-                        IRMapping &spmdizationMap,
-                        SymbolTableCollection &symbolTable,
-                        OpBuilder &builder) const {
+  LogicalResult partition(Operation *op, ArrayRef<Value> partitiondOperands,
+                          ArrayRef<Sharding> operandShardings,
+                          ArrayRef<Sharding> resultShardings,
+                          IRMapping &partitionMap,
+                          SymbolTableCollection &symbolTable,
+                          OpBuilder &builder) const {
     auto cOp = cast<ConstantOp>(op);
     if (auto value = dyn_cast<DenseIntOrFPElementsAttr>(cOp.getValue())) {
       if (!value.isSplat() || !resultShardings[0]) {
@@ -80,15 +80,15 @@ struct ConstantShardingInterface
       }
       auto sharding = resultShardings[0];
       auto newType = cast<RankedTensorType>(shardType(
-          cOp.getType(), getMesh(op, sharding.getMeshAttr(), symbolTable),
+          cOp.getType(), getGrid(op, sharding.getGridAttr(), symbolTable),
           sharding));
       auto newValue = value.resizeSplat(newType);
       auto newOp = builder.create<ConstantOp>(op->getLoc(), newType, newValue);
-      spmdizationMap.map(op->getResult(0), newOp.getResult());
-      spmdizationMap.map(op, newOp.getOperation());
+      partitionMap.map(op->getResult(0), newOp.getResult());
+      partitionMap.map(op, newOp.getOperation());
     } else {
       // `clone` will populate the mapping of old to new results.
-      (void)builder.clone(*op, spmdizationMap);
+      (void)builder.clone(*op, partitionMap);
     }
     return success();
   }
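
A rough before/after sketch of what the `partition` method above produces for the splat constant mentioned in the comment at the top of this file (assuming `@grid4x4` is a 4x4 grid and the result sharding splits dimension 0 along grid axis 0):

```mlir
shard.grid @grid4x4(shape = 4x4)

// Before partitioning (annotated with split_axes = [[0]] on @grid4x4):
%cst = arith.constant dense<0.000000e+00> : tensor<1024x1024xf32>

// After partitioning: the splat is resized to the shard-local type,
// 1024 / 4 = 256 rows per process.
%cst_local = arith.constant dense<0.000000e+00> : tensor<256x1024xf32>
```
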
diff --git a/mlir/lib/Dialect/CMakeLists.txt b/mlir/lib/Dialect/CMakeLists.txt
index 3cc52ebc0a8d9..053ee95e92053 100644
--- a/mlir/lib/Dialect/CMakeLists.txt
+++ b/mlir/lib/Dialect/CMakeLists.txt
@@ -19,7 +19,7 @@ add_subdirectory(Linalg)
 add_subdirectory(LLVMIR)
 add_subdirectory(Math)
 add_subdirectory(MemRef)
-add_subdirectory(Mesh)
+add_subdirectory(Shard)
 add_subdirectory(MLProgram)
 add_subdirectory(MPI)
 add_subdirectory(NVGPU)
diff --git a/mlir/lib/Dialect/Func/Extensions/AllExtensions.cpp b/mlir/lib/Dialect/Func/Extensions/AllExtensions.cpp
index eb6b59bb00f1b..1b18ef2dd04a7 100644
--- a/mlir/lib/Dialect/Func/Extensions/AllExtensions.cpp
+++ b/mlir/lib/Dialect/Func/Extensions/AllExtensions.cpp
@@ -8,7 +8,7 @@
 
 #include "mlir/Dialect/Func/Extensions/AllExtensions.h"
 #include "mlir/Dialect/Func/Extensions/InlinerExtension.h"
-#include "mlir/Dialect/Func/Extensions/MeshShardingExtensions.h"
+#include "mlir/Dialect/Func/Extensions/ShardingExtensions.h"
 
 using namespace mlir;
 
diff --git a/mlir/lib/Dialect/Func/Extensions/CMakeLists.txt b/mlir/lib/Dialect/Func/Extensions/CMakeLists.txt
index 47363f48f95cc..87ef51e63f1da 100644
--- a/mlir/lib/Dialect/Func/Extensions/CMakeLists.txt
+++ b/mlir/lib/Dialect/Func/Extensions/CMakeLists.txt
@@ -1,7 +1,7 @@
 set(LLVM_OPTIONAL_SOURCES
   AllExtensions.cpp
   InlinerExtension.cpp
-  MeshShardingExtensions.cpp
+  ShardingExtensions.cpp
   )
 
 add_mlir_extension_library(MLIRFuncInlinerExtension
@@ -17,8 +17,8 @@ add_mlir_extension_library(MLIRFuncInlinerExtension
   MLIRFuncDialect
   )
 
-add_mlir_extension_library(MLIRFuncMeshShardingExtensions
-  MeshShardingExtensions.cpp
+add_mlir_extension_library(MLIRFuncShardingExtensions
+  ShardingExtensions.cpp
 
   ADDITIONAL_HEADER_DIRS
   ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/Func/Extensions
@@ -38,5 +38,5 @@ add_mlir_extension_library(MLIRFuncAllExtensions
 
   LINK_LIBS PUBLIC
   MLIRFuncInlinerExtension
-  MLIRFuncMeshShardingExtensions
+  MLIRFuncShardingExtensions
   )
diff --git a/mlir/lib/Dialect/Func/Extensions/MeshShardingExtensions.cpp b/mlir/lib/Dialect/Func/Extensions/ShardingExtensions.cpp
similarity index 68%
rename from mlir/lib/Dialect/Func/Extensions/MeshShardingExtensions.cpp
rename to mlir/lib/Dialect/Func/Extensions/ShardingExtensions.cpp
index da508cc95bfe1..dfd1348c24441 100644
--- a/mlir/lib/Dialect/Func/Extensions/MeshShardingExtensions.cpp
+++ b/mlir/lib/Dialect/Func/Extensions/ShardingExtensions.cpp
@@ -1,4 +1,4 @@
-//===- MeshShardingExtensions.cpp - ---------------------------------------===//
+//===- ShardingExtensions.cpp - -------------------------------------------===//
 //
 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
 // See https://llvm.org/LICENSE.txt for license information.
@@ -6,9 +6,9 @@
 //
 //===----------------------------------------------------------------------===//
 
-#include "mlir/Dialect/Func/Extensions/MeshShardingExtensions.h"
+#include "mlir/Dialect/Func/Extensions/ShardingExtensions.h"
 #include "mlir/Dialect/Func/IR/FuncOps.h"
-#include "mlir/Dialect/Mesh/Interfaces/ShardingInterfaceImpl.h"
+#include "mlir/Dialect/Shard/Interfaces/ShardingInterfaceImpl.h"
 #include "mlir/IR/MLIRContext.h"
 
 namespace mlir::func {
@@ -16,7 +16,7 @@ namespace mlir::func {
 void registerShardingInterfaceExternalModels(DialectRegistry &registry) {
   registry.addExtension(+[](MLIRContext *ctx, FuncDialect *dialect) {
     ReturnOp::attachInterface<
-        mesh::IndependentParallelIteratorDomainShardingInterface<ReturnOp>>(
+        shard::IndependentParallelIteratorDomainShardingInterface<ReturnOp>>(
         *ctx);
   });
 }
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgDialect.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgDialect.cpp
index b6e168e95ee86..7f6ecab2d90f5 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgDialect.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgDialect.cpp
@@ -15,7 +15,7 @@
 #include "mlir/Dialect/Linalg/IR/Linalg.h"
 #include "mlir/Dialect/Math/IR/Math.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
-#include "mlir/Dialect/Mesh/Interfaces/ShardingInterface.h"
+#include "mlir/Dialect/Shard/Interfaces/ShardingInterface.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "mlir/IR/DialectImplementation.h"
 #include "mlir/Interfaces/SubsetOpInterface.h"
@@ -119,8 +119,8 @@ void mlir::linalg::LinalgDialect::initialize() {
 
   addInterfaces<LinalgInlinerInterface>();
 
-  declarePromisedInterface<mesh::ShardingInterface, GenericOp>();
-  declarePromisedInterfaces<mesh::ShardingInterface,
+  declarePromisedInterface<shard::ShardingInterface, GenericOp>();
+  declarePromisedInterfaces<shard::ShardingInterface,
 #define GET_OP_LIST
 #include "mlir/Dialect/Linalg/IR/LinalgStructuredOps.cpp.inc"
                             >();
diff --git a/mlir/lib/Dialect/Linalg/Transforms/AllInterfaces.cpp b/mlir/lib/Dialect/Linalg/Transforms/AllInterfaces.cpp
index 281d9f2204486..ba94ad7906ab7 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/AllInterfaces.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/AllInterfaces.cpp
@@ -10,14 +10,14 @@
 
 #include "mlir/Dialect/Linalg/IR/ValueBoundsOpInterfaceImpl.h"
 #include "mlir/Dialect/Linalg/Transforms/BufferizableOpInterfaceImpl.h"
-#include "mlir/Dialect/Linalg/Transforms/MeshShardingInterfaceImpl.h"
+#include "mlir/Dialect/Linalg/Transforms/ShardingInterfaceImpl.h"
 #include "mlir/Dialect/Linalg/Transforms/SubsetInsertionOpInterfaceImpl.h"
 #include "mlir/Dialect/Linalg/Transforms/TilingInterfaceImpl.h"
 
 void mlir::linalg::registerAllDialectInterfaceImplementations(
     DialectRegistry &registry) {
   registerBufferizableOpInterfaceExternalModels(registry);
-  registerMeshShardingInterfaceExternalModels(registry);
+  registerShardingInterfaceExternalModels(registry);
   registerSubsetOpInterfaceExternalModels(registry);
   registerTilingInterfaceExternalModels(registry);
   registerValueBoundsOpInterfaceExternalModels(registry);
diff --git a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
index 69e6fdabf9a58..70f846e5bbd20 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
@@ -24,7 +24,7 @@ add_mlir_dialect_library(MLIRLinalgTransforms
   Interchange.cpp
   Loops.cpp
   TransposeMatmul.cpp
-  MeshShardingInterfaceImpl.cpp
+  ShardingInterfaceImpl.cpp
   NamedOpConversions.cpp
   BlockPackMatmul.cpp
   PackAndUnpackPatterns.cpp
@@ -68,7 +68,7 @@ add_mlir_dialect_library(MLIRLinalgTransforms
   MLIRIR
   MLIRMemRefDialect
   MLIRMemRefTransforms
-  MLIRMeshTransforms
+  MLIRShardTransforms
   MLIRLinalgDialect
   MLIRLinalgUtils
   MLIRSCFDialect
diff --git a/mlir/lib/Dialect/Linalg/Transforms/MeshShardingInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/Transforms/ShardingInterfaceImpl.cpp
similarity index 65%
rename from mlir/lib/Dialect/Linalg/Transforms/MeshShardingInterfaceImpl.cpp
rename to mlir/lib/Dialect/Linalg/Transforms/ShardingInterfaceImpl.cpp
index 24b8765284fa5..4c36118033aed 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/MeshShardingInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ShardingInterfaceImpl.cpp
@@ -1,4 +1,4 @@
-//===- MeshShardingInterfaceImpl.cpp --------------------------------------===//
+//===- ShardingInterfaceImpl.cpp ------------------------------------------===//
 //
 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
 // See https://llvm.org/LICENSE.txt for license information.
@@ -6,18 +6,18 @@
 //
 //===----------------------------------------------------------------------===//
 
-#include "mlir/Dialect/Linalg/Transforms/MeshShardingInterfaceImpl.h"
+#include "mlir/Dialect/Linalg/Transforms/ShardingInterfaceImpl.h"
 
 #include "mlir/Analysis/SliceAnalysis.h"
 #include "mlir/Dialect/Affine/IR/AffineOps.h"
 #include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/Linalg/IR/Linalg.h"
 #include "mlir/Dialect/Linalg/IR/LinalgInterfaces.h"
-#include "mlir/Dialect/Mesh/IR/MeshOps.h"
-#include "mlir/Dialect/Mesh/Interfaces/ShardingInterface.h"
-#include "mlir/Dialect/Mesh/Interfaces/ShardingInterfaceImpl.h"
-#include "mlir/Dialect/Mesh/Transforms/Transforms.h"
 #include "mlir/Dialect/SCF/IR/SCF.h"
+#include "mlir/Dialect/Shard/IR/ShardOps.h"
+#include "mlir/Dialect/Shard/Interfaces/ShardingInterface.h"
+#include "mlir/Dialect/Shard/Interfaces/ShardingInterfaceImpl.h"
+#include "mlir/Dialect/Shard/Transforms/Transforms.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "mlir/Dialect/Utils/StructuredOpsUtils.h"
 #include "mlir/IR/AffineExpr.h"
@@ -36,13 +36,13 @@
 
 namespace mlir::linalg {
 
-using MeshAxis = mesh::MeshAxis;
-using ReductionKind = mesh::ReductionKind;
-using MeshSharding = mesh::MeshSharding;
-using ShardingArray = mesh::ShardingArray;
-using MeshOp = mesh::MeshOp;
+using GridAxis = shard::GridAxis;
+using ReductionKind = shard::ReductionKind;
+using Sharding = shard::Sharding;
+using ShardingArray = shard::ShardingArray;
+using GridOp = shard::GridOp;
 
-// Returns the corresponding mesh reduction kind for the given arith op.
+// Returns the corresponding grid reduction kind for the given arith op.
 static ReductionKind getReductionKind(Operation *op) {
   return llvm::TypeSwitch<Operation *, ReductionKind>(op)
       // Floating-point operations.
@@ -97,18 +97,18 @@ static ReductionKind getReductionKindOfLinalgOp(LinalgOp op) {
   return getReductionKind(reductionOp.value());
 }
 
-static MeshOp getMesh(Operation *op, ArrayRef<MeshSharding> operandShardings,
-                      ArrayRef<MeshSharding> resultShardings,
+static GridOp getGrid(Operation *op, ArrayRef<Sharding> operandShardings,
+                      ArrayRef<Sharding> resultShardings,
                       SymbolTableCollection &symbolTable) {
-  for (const MeshSharding &sharding : operandShardings) {
+  for (const Sharding &sharding : operandShardings) {
     if (sharding) {
-      return mesh::getMesh(op, sharding.getMeshAttr(), symbolTable);
+      return shard::getGrid(op, sharding.getGridAttr(), symbolTable);
     }
   }
 
-  for (const MeshSharding &sharding : resultShardings) {
+  for (const Sharding &sharding : resultShardings) {
     if (sharding) {
-      return mesh::getMesh(op, sharding.getMeshAttr(), symbolTable);
+      return shard::getGrid(op, sharding.getGridAttr(), symbolTable);
     }
   }
 
@@ -117,29 +117,29 @@ static MeshOp getMesh(Operation *op, ArrayRef<MeshSharding> operandShardings,
 }
 
 // Choose the operand based on the current process index along the reduction
-// mesh axes.
+// grid axes.
 // We need to use the initial value only once to avoid including it in the
 // reduction multiple times.
 // In each process group only the leading process with linear index 0 would use
 // the original operand.
 // The other processes would use the reduction operation neutral tensor.
 static Value createDestinationPassingStyleInitOperand(
-    LinalgOp op, int operandNumber, Value spmdizedOperand,
-    ArrayRef<MeshAxis> reductionMeshAxes, MeshOp meshOp,
+    LinalgOp op, int operandNumber, Value partitiondOperand,
+    ArrayRef<GridAxis> reductionGridAxes, GridOp gridOp,
     ImplicitLocOpBuilder &builder) {
-  Value processLinearIndexInReductionGroup = mesh::createProcessLinearIndex(
-      meshOp.getSymName(), reductionMeshAxes, builder);
+  Value processLinearIndexInReductionGroup = shard::createProcessLinearIndex(
+      gridOp.getSymName(), reductionGridAxes, builder);
   Value zero = builder.create<arith::ConstantIndexOp>(0);
   Value isLeadProcess = builder.create<arith::CmpIOp>(
       builder.getI1Type(), arith::CmpIPredicate::eq,
       processLinearIndexInReductionGroup, zero);
-  scf::IfOp ifOp = builder.create<scf::IfOp>(spmdizedOperand.getType(),
+  scf::IfOp ifOp = builder.create<scf::IfOp>(partitiondOperand.getType(),
                                              isLeadProcess, true, true);
   // Then block.
   {
     OpBuilder::InsertionGuard insertionGuard(builder);
     builder.setInsertionPointToEnd(&ifOp.getThenRegion().front());
-    builder.create<scf::YieldOp>(spmdizedOperand);
+    builder.create<scf::YieldOp>(partitiondOperand);
   }
 
   // Else block.
@@ -147,7 +147,7 @@ static Value createDestinationPassingStyleInitOperand(
     OpBuilder::InsertionGuard insertionGuard(builder);
     builder.setInsertionPointToEnd(&ifOp.getElseRegion().front());
     SmallVector<OpFoldResult> shape =
-        tensor::getMixedSizes(builder, builder.getLoc(), spmdizedOperand);
+        tensor::getMixedSizes(builder, builder.getLoc(), partitiondOperand);
 
     SmallVector<Operation *> combinerOps;
     matchReduction(op.getRegionOutputArgs(), operandNumber, combinerOps);
@@ -167,73 +167,72 @@ static Value createDestinationPassingStyleInitOperand(
   return ifOp.getResult(0);
 }
 
-// Create the DPS init operands for the spmdized Linalg op.
-// Return all the new spmdized operands.
+// Create the DPS init operands for the partitioned Linalg op.
+// Return all the new partitioned operands.
 static SmallVector<Value> createDestinationPassingStyleInitOperands(
-    LinalgOp op, MeshOp meshOp, ArrayRef<Value> spmdizedOperands,
-    ArrayRef<MeshAxis> reductionMeshAxes, IRMapping &spmdizationMap,
+    LinalgOp op, GridOp gridOp, ArrayRef<Value> partitiondOperands,
+    ArrayRef<GridAxis> reductionGridAxes, IRMapping &partitionMap,
     ImplicitLocOpBuilder &builder) {
   // TODO: add support for multiple destination passing style initial value
   // operands.
   assert(op.getNumDpsInits() == 1 && "Multiple initial values not supported.");
-  SmallVector<Value> newOperands = llvm::to_vector(spmdizedOperands);
+  SmallVector<Value> newOperands = llvm::to_vector(partitiondOperands);
   auto operandIdx = op.getDpsInitOperand(0)->getOperandNumber();
-  Value spmdizedInitOperand =
-      spmdizationMap.lookup(op->getOperands()[operandIdx]);
+  Value partitiondInitOperand =
+      partitionMap.lookup(op->getOperands()[operandIdx]);
   newOperands[operandIdx] = createDestinationPassingStyleInitOperand(
-      op, 0, spmdizedInitOperand, reductionMeshAxes, meshOp, builder);
+      op, 0, partitiondInitOperand, reductionGridAxes, gridOp, builder);
   return newOperands;
 }
 
 static void createAllReduceForResultsWithoutPartialShardings(
-    LinalgOp unshardedOp, ArrayRef<MeshAxis> opReductionMeshAxes,
-    ArrayRef<MeshSharding> resultShardings, IRMapping &spmdizationMap,
+    LinalgOp unshardedOp, ArrayRef<GridAxis> opReductionGridAxes,
+    ArrayRef<Sharding> resultShardings, IRMapping &partitionMap,
     ImplicitLocOpBuilder &builder) {
   ReductionKind reductionKind = getReductionKindOfLinalgOp(unshardedOp);
   for (auto [unshardedLinalgOpResult, resultSharding] :
        llvm::zip_equal(unshardedOp->getResults(), resultShardings)) {
-    Value spmdizedLinalgOpResult =
-        spmdizationMap.lookup(unshardedLinalgOpResult);
-    Value reducedValue = builder.create<mesh::AllReduceOp>(
-        spmdizedLinalgOpResult, resultSharding.getMesh(), opReductionMeshAxes,
+    Value partitiondLinalgOpResult =
+        partitionMap.lookup(unshardedLinalgOpResult);
+    Value reducedValue = builder.create<shard::AllReduceOp>(
+        partitiondLinalgOpResult, resultSharding.getGrid(), opReductionGridAxes,
         reductionKind);
-    spmdizationMap.map(unshardedLinalgOpResult, reducedValue);
+    partitionMap.map(unshardedLinalgOpResult, reducedValue);
   }
 }
 
-static void spmdizeLinalgOpWithShardedReduction(
-    LinalgOp op, ArrayRef<Value> spmdizedOperands,
-    ArrayRef<MeshSharding> operandShardings,
-    ArrayRef<MeshSharding> resultShardings,
+static void partitionLinalgOpWithShardedReduction(
+    LinalgOp op, ArrayRef<Value> partitiondOperands,
+    ArrayRef<Sharding> operandShardings, ArrayRef<Sharding> resultShardings,
     ArrayRef<utils::IteratorType> loopIteratorTypes,
-    ArrayRef<SmallVector<MeshAxis>> meshAxisAssignmentForLoopIterators,
-    IRMapping &spmdizationMap, SymbolTableCollection &symbolTable,
+    ArrayRef<SmallVector<GridAxis>> gridAxisAssignmentForLoopIterators,
+    IRMapping &partitionMap, SymbolTableCollection &symbolTable,
     ImplicitLocOpBuilder &builder) {
-  MeshOp mesh = getMesh(op, operandShardings, resultShardings, symbolTable);
-  SmallVector<MeshAxis> reductionMeshAxes = mesh::getReductionMeshAxes(
-      loopIteratorTypes, meshAxisAssignmentForLoopIterators);
-  SmallVector<Value> spmdizedLinalgOpOperands =
-      createDestinationPassingStyleInitOperands(op, mesh, spmdizedOperands,
-                                                reductionMeshAxes,
-                                                spmdizationMap, builder);
-  // We must not change the operand mappings of the original spmdizationMap as
-  // they are the mappings for the whole spmdization blob and may be used by
+  GridOp grid = getGrid(op, operandShardings, resultShardings, symbolTable);
+  SmallVector<GridAxis> reductionGridAxes = shard::getReductionGridAxes(
+      loopIteratorTypes, gridAxisAssignmentForLoopIterators);
+  SmallVector<Value> partitiondLinalgOpOperands =
+      createDestinationPassingStyleInitOperands(op, grid, partitiondOperands,
+                                                reductionGridAxes, partitionMap,
+                                                builder);
+  // We must not change the operand mappings of the original partitionMap as
+  // they are the mappings for the whole partition blob and may be used by
   // others.
-  IRMapping internalSpmdizationMap;
-  for (auto [unshardedOperand, spmdizedOperand] :
-       llvm::zip_equal(op->getOperands(), spmdizedLinalgOpOperands)) {
-    internalSpmdizationMap.map(unshardedOperand, spmdizedOperand);
+  IRMapping internalPartitionMap;
+  for (auto [unshardedOperand, partitiondOperand] :
+       llvm::zip_equal(op->getOperands(), partitiondLinalgOpOperands)) {
+    internalPartitionMap.map(unshardedOperand, partitiondOperand);
   }
-  spmdizeTriviallyShardableOperation(
-      *op, spmdizedLinalgOpOperands, operandShardings, resultShardings,
-      internalSpmdizationMap, symbolTable, builder);
+  partitionTriviallyShardableOperation(
+      *op, partitiondLinalgOpOperands, operandShardings, resultShardings,
+      internalPartitionMap, symbolTable, builder);
   for (Value result : op->getResults()) {
-    spmdizationMap.map(result, internalSpmdizationMap.lookup(result));
+    partitionMap.map(result, internalPartitionMap.lookup(result));
   }
 
   // Handle partial shardings.
   createAllReduceForResultsWithoutPartialShardings(
-      op, reductionMeshAxes, resultShardings, spmdizationMap, builder);
+      op, reductionGridAxes, resultShardings, partitionMap, builder);
 }
 
 namespace {
@@ -243,7 +242,7 @@ namespace {
 // permutations.
 template <typename Op>
 struct StructuredOpShardingInterface
-    : public mesh::ShardingInterface::ExternalModel<
+    : public shard::ShardingInterface::ExternalModel<
           StructuredOpShardingInterface<Op>, Op> {
   SmallVector<utils::IteratorType> getLoopIteratorTypes(Operation *op) const {
     return llvm::cast<LinalgOp>(op).getIteratorTypesArray();
@@ -272,16 +271,16 @@ struct StructuredOpShardingInterface
         [](unsigned count, utils::IteratorType iter) {
           return count + (iter == utils::IteratorType::reduction);
         });
-    mesh::ReductionKind reductionKind = getReductionKindOfLinalgOp(linalgOp);
+    shard::ReductionKind reductionKind = getReductionKindOfLinalgOp(linalgOp);
     return SmallVector<ReductionKind>(reductionItersCount, reductionKind);
   }
 
-  LogicalResult spmdize(Operation *op, ArrayRef<Value> spmdizedOperands,
-                        ArrayRef<MeshSharding> operandShardings,
-                        ArrayRef<MeshSharding> resultShardings,
-                        IRMapping &spmdizationMap,
-                        SymbolTableCollection &symbolTable,
-                        OpBuilder &builder) const {
+  LogicalResult partition(Operation *op, ArrayRef<Value> partitiondOperands,
+                          ArrayRef<Sharding> operandShardings,
+                          ArrayRef<Sharding> resultShardings,
+                          IRMapping &partitionMap,
+                          SymbolTableCollection &symbolTable,
+                          OpBuilder &builder) const {
     LinalgOp linalgOp = llvm::cast<LinalgOp>(op);
 
     SmallVector<AffineMap> indexingMaps = linalgOp.getIndexingMapsArray();
@@ -297,20 +296,20 @@ struct StructuredOpShardingInterface
 
     SmallVector<utils::IteratorType> loopIteratorTypes =
         linalgOp.getIteratorTypesArray();
-    ShardingArray meshAxisAssignmentForLoopIterators =
-        getMeshAxisAssignmentForLoopIterators(operandShardings, resultShardings,
+    ShardingArray gridAxisAssignmentForLoopIterators =
+        getGridAxisAssignmentForLoopIterators(operandShardings, resultShardings,
                                               loopIteratorTypes, indexingMaps);
-    if (mesh::isAtLeastOneReductionIteratorSharded(
-            loopIteratorTypes, meshAxisAssignmentForLoopIterators)) {
+    if (shard::isAtLeastOneReductionIteratorSharded(
+            loopIteratorTypes, gridAxisAssignmentForLoopIterators)) {
       ImplicitLocOpBuilder implicitLocBuilder(op->getLoc(), builder);
-      spmdizeLinalgOpWithShardedReduction(
-          linalgOp, spmdizedOperands, operandShardings, resultShardings,
-          loopIteratorTypes, meshAxisAssignmentForLoopIterators, spmdizationMap,
+      partitionLinalgOpWithShardedReduction(
+          linalgOp, partitiondOperands, operandShardings, resultShardings,
+          loopIteratorTypes, gridAxisAssignmentForLoopIterators, partitionMap,
           symbolTable, implicitLocBuilder);
     } else {
-      spmdizeTriviallyShardableOperation(*op, spmdizedOperands,
-                                         operandShardings, resultShardings,
-                                         spmdizationMap, symbolTable, builder);
+      partitionTriviallyShardableOperation(*op, partitiondOperands,
+                                           operandShardings, resultShardings,
+                                           partitionMap, symbolTable, builder);
     }
 
     return success();
@@ -330,7 +329,7 @@ static void registerAll(MLIRContext *ctx) {
   (registerOne<OpTypes>(ctx), ...);
 }
 
-void registerMeshShardingInterfaceExternalModels(DialectRegistry &registry) {
+void registerShardingInterfaceExternalModels(DialectRegistry &registry) {
   registry.addExtension(+[](MLIRContext *ctx, LinalgDialect *dialect) {
     DialectRegistry registry;
     registry.insert<affine::AffineDialect, arith::ArithDialect, scf::SCFDialect,
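
As a side note, not from this patch: a minimal sketch of the lead-process test described in the comments of createDestinationPassingStyleInitOperand, expressed with the renamed shard:: helpers and the same builder calls used above; the wrapper function name is hypothetical.

#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Shard/IR/ShardOps.h"
#include "mlir/Dialect/Shard/Transforms/Transforms.h"
#include "mlir/IR/ImplicitLocOpBuilder.h"

// Returns an i1 Value that is true only on the process with linear index 0
// within the group spanned by the given grid axes, i.e. the one process that
// feeds the original init operand into the reduction.
static mlir::Value isLeadProcessInGroup(mlir::shard::GridOp grid,
                                        llvm::ArrayRef<mlir::shard::GridAxis> axes,
                                        mlir::ImplicitLocOpBuilder &b) {
  mlir::Value linearIndex =
      mlir::shard::createProcessLinearIndex(grid.getSymName(), axes, b);
  mlir::Value zero = b.create<mlir::arith::ConstantIndexOp>(0);
  return b.create<mlir::arith::CmpIOp>(b.getI1Type(),
                                       mlir::arith::CmpIPredicate::eq,
                                       linearIndex, zero);
}
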
diff --git a/mlir/lib/Dialect/Mesh/CMakeLists.txt b/mlir/lib/Dialect/Shard/CMakeLists.txt
similarity index 100%
rename from mlir/lib/Dialect/Mesh/CMakeLists.txt
rename to mlir/lib/Dialect/Shard/CMakeLists.txt
diff --git a/mlir/lib/Dialect/Mesh/IR/CMakeLists.txt b/mlir/lib/Dialect/Shard/IR/CMakeLists.txt
similarity index 59%
rename from mlir/lib/Dialect/Mesh/IR/CMakeLists.txt
rename to mlir/lib/Dialect/Shard/IR/CMakeLists.txt
index 3fea4d67430e0..70c6049884e12 100644
--- a/mlir/lib/Dialect/Mesh/IR/CMakeLists.txt
+++ b/mlir/lib/Dialect/Shard/IR/CMakeLists.txt
@@ -1,11 +1,11 @@
-add_mlir_dialect_library(MLIRMeshDialect
-  MeshOps.cpp
+add_mlir_dialect_library(MLIRShardDialect
+  ShardOps.cpp
 
   ADDITIONAL_HEADER_DIRS
-  ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/Mesh
+  ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/Shard
 
   DEPENDS
-  MLIRMeshIncGen
+  MLIRShardIncGen
 
   LINK_LIBS PUBLIC
   MLIRArithDialect
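
For orientation, again outside the patch: with the library renamed to MLIRShardDialect, downstream dialect registration would look roughly as follows, assuming the header path used elsewhere in this patch and the usual DialectRegistry API.

#include "mlir/Dialect/Shard/IR/ShardDialect.h"
#include "mlir/IR/DialectRegistry.h"

// Makes the renamed dialect (class ShardDialect in namespace mlir::shard)
// available to any context that appends this registry.
static void registerShardDialect(mlir::DialectRegistry &registry) {
  registry.insert<mlir::shard::ShardDialect>();
}
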
diff --git a/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp b/mlir/lib/Dialect/Shard/IR/ShardOps.cpp
similarity index 76%
rename from mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
rename to mlir/lib/Dialect/Shard/IR/ShardOps.cpp
index 28608cb0dd96c..df2fcf4c2f4c8 100644
--- a/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
+++ b/mlir/lib/Dialect/Shard/IR/ShardOps.cpp
@@ -1,4 +1,4 @@
-//===- MeshOps.cpp - Mesh Dialect Operations ------------------------------===//
+//===- ShardOps.cpp - Shard Dialect Operations ----------------------------===//
 //
 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
 // See https://llvm.org/LICENSE.txt for license information.
@@ -6,10 +6,10 @@
 //
 //===----------------------------------------------------------------------===//
 
-#include "mlir/Dialect/Mesh/IR/MeshOps.h"
+#include "mlir/Dialect/Shard/IR/ShardOps.h"
 
 #include "mlir/Dialect/Arith/IR/Arith.h"
-#include "mlir/Dialect/Mesh/IR/MeshDialect.h"
+#include "mlir/Dialect/Shard/IR/ShardDialect.h"
 #include "mlir/Dialect/Utils/StaticValueUtils.h"
 #include "mlir/IR/Attributes.h"
 #include "mlir/IR/BuiltinAttributes.h"
@@ -37,13 +37,13 @@
 #include <optional>
 #include <utility>
 
-#define DEBUG_TYPE "mesh-ops"
+#define DEBUG_TYPE "shard-ops"
 #define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE << "]: ")
 
 using namespace mlir;
-using namespace mlir::mesh;
+using namespace mlir::shard;
 
-#include "mlir/Dialect/Mesh/IR/MeshDialect.cpp.inc"
+#include "mlir/Dialect/Shard/IR/ShardDialect.cpp.inc"
 
 namespace {
 
@@ -74,11 +74,10 @@ static DimensionSize operator*(DimensionSize lhs, DimensionSize rhs) {
   return lhs.value() * rhs.value();
 }
 
-SmallVector<Value> mlir::mesh::getMixedAsValues(OpBuilder b,
-                                                const Location &loc,
-                                                llvm::ArrayRef<int64_t> statics,
-                                                ValueRange dynamics,
-                                                Type type) {
+SmallVector<Value>
+mlir::shard::getMixedAsValues(OpBuilder b, const Location &loc,
+                              llvm::ArrayRef<int64_t> statics,
+                              ValueRange dynamics, Type type) {
   SmallVector<Value> values;
   auto dyn = dynamics.begin();
   Type i64 = b.getI64Type();
@@ -102,7 +101,7 @@ SmallVector<Value> mlir::mesh::getMixedAsValues(OpBuilder b,
 //===----------------------------------------------------------------------===//
 
 namespace {
-struct MeshInlinerInterface : public DialectInlinerInterface {
+struct ShardInlinerInterface : public DialectInlinerInterface {
   using DialectInlinerInterface::DialectInlinerInterface;
   // Currently no restrictions are encoded for inlining.
   bool isLegalToInline(Operation *, Operation *, bool) const final {
@@ -118,44 +117,45 @@ struct MeshInlinerInterface : public DialectInlinerInterface {
 } // namespace
 
 //===----------------------------------------------------------------------===//
-// Mesh dialect
+// Shard dialect
 //===----------------------------------------------------------------------===//
 
-void MeshDialect::initialize() {
+void ShardDialect::initialize() {
   addOperations<
 #define GET_OP_LIST
-#include "mlir/Dialect/Mesh/IR/MeshOps.cpp.inc"
+#include "mlir/Dialect/Shard/IR/ShardOps.cpp.inc"
       >();
   addAttributes<
 #define GET_ATTRDEF_LIST
-#include "mlir/Dialect/Mesh/IR/MeshAttributes.cpp.inc"
+#include "mlir/Dialect/Shard/IR/ShardAttributes.cpp.inc"
       >();
   addTypes<
 #define GET_TYPEDEF_LIST
-#include "mlir/Dialect/Mesh/IR/MeshTypes.cpp.inc"
+#include "mlir/Dialect/Shard/IR/ShardTypes.cpp.inc"
       >();
-  addInterface<MeshInlinerInterface>();
+  addInterface<ShardInlinerInterface>();
 }
 
-Operation *MeshDialect::materializeConstant(OpBuilder &builder, Attribute value,
-                                            Type type, Location loc) {
+Operation *ShardDialect::materializeConstant(OpBuilder &builder,
+                                             Attribute value, Type type,
+                                             Location loc) {
   return arith::ConstantOp::materialize(builder, value, type, loc);
 }
 
 //===----------------------------------------------------------------------===//
-// Mesh utilities
+// Shard utilities
 //===----------------------------------------------------------------------===//
 
-static FailureOr<MeshOp> getMeshAndVerify(Operation *op,
-                                          FlatSymbolRefAttr meshSymbol,
+static FailureOr<GridOp> getGridAndVerify(Operation *op,
+                                          FlatSymbolRefAttr gridSymbol,
                                           SymbolTableCollection &symbolTable) {
-  mesh::MeshOp mesh = getMeshOrNull(op, meshSymbol, symbolTable);
-  if (!mesh) {
-    return op->emitError() << "Undefined required mesh symbol \""
-                           << meshSymbol.getValue() << "\".";
+  shard::GridOp grid = getGridOrNull(op, gridSymbol, symbolTable);
+  if (!grid) {
+    return op->emitError() << "Undefined required grid symbol \""
+                           << gridSymbol.getValue() << "\".";
   }
 
-  return mesh;
+  return grid;
 }
 
 template <typename It>
@@ -175,20 +175,20 @@ bool isUnique(It begin, It end) {
   return true;
 }
 
-static LogicalResult verifyMeshAxes(Location loc, ArrayRef<MeshAxis> axes,
-                                    MeshOp mesh) {
-  SmallVector<MeshAxis> sorted = llvm::to_vector(axes);
+static LogicalResult verifyGridAxes(Location loc, ArrayRef<GridAxis> axes,
+                                    GridOp grid) {
+  SmallVector<GridAxis> sorted = llvm::to_vector(axes);
   llvm::sort(sorted);
   if (!isUnique(sorted.begin(), sorted.end())) {
-    return emitError(loc) << "Mesh axes contains duplicate elements.";
+    return emitError(loc) << "Grid axes contains duplicate elements.";
   }
 
-  MeshAxis rank = mesh.getRank();
+  GridAxis rank = grid.getRank();
   for (auto axis : axes) {
     if (axis >= rank || axis < 0) {
       return emitError(loc)
-             << "0-based mesh axis index " << axis
-             << " is out of bounds. The referenced mesh \"" << mesh.getSymName()
+             << "0-based grid axis index " << axis
+             << " is out of bounds. The referenced grid \"" << grid.getSymName()
              << "\" is of rank " << rank << ".";
     }
   }
@@ -197,22 +197,22 @@ static LogicalResult verifyMeshAxes(Location loc, ArrayRef<MeshAxis> axes,
 }
 
 template <typename Op>
-static FailureOr<MeshOp>
-getMeshAndVerifyAxes(Op op, SymbolTableCollection &symbolTable) {
-  auto mesh =
-      ::getMeshAndVerify(op.getOperation(), op.getMeshAttr(), symbolTable);
-  if (failed(mesh)) {
+static FailureOr<GridOp>
+getGridAndVerifyAxes(Op op, SymbolTableCollection &symbolTable) {
+  auto grid =
+      ::getGridAndVerify(op.getOperation(), op.getGridAttr(), symbolTable);
+  if (failed(grid)) {
     return failure();
   }
-  if (failed(verifyMeshAxes(op.getLoc(), op.getMeshAxes(), mesh.value()))) {
+  if (failed(verifyGridAxes(op.getLoc(), op.getGridAxes(), grid.value()))) {
     return failure();
   }
-  return mesh;
+  return grid;
 }
 
-template <typename InShape, typename MeshShape, typename SplitAxes,
+template <typename InShape, typename GridShape, typename SplitAxes,
           typename OutShape>
-static void shardShape(const InShape &inShape, const MeshShape &meshShape,
+static void shardShape(const InShape &inShape, const GridShape &gridShape,
                        const SplitAxes &splitAxes, OutShape &outShape,
                        ArrayRef<int64_t> shardedDimsOffsets = {},
                        ArrayRef<int64_t> haloSizes = {}) {
@@ -226,7 +226,7 @@ static void shardShape(const InShape &inShape, const MeshShape &meshShape,
             llvm::adl_begin(outShape));
 
   if (!shardedDimsOffsets.empty()) {
-    auto isDynShape = ShapedType::isDynamicShape(meshShape);
+    auto isDynShape = ShapedType::isDynamicShape(gridShape);
     uint64_t pos = 1;
     for (auto [tensorAxis, innerSplitAxes] : llvm::enumerate(splitAxes)) {
       if (!innerSplitAxes.empty()) {
@@ -238,7 +238,7 @@ static void shardShape(const InShape &inShape, const MeshShape &meshShape,
           // non-uniform offs in shardedDimsOffsets.
           uint64_t numShards = 0;
           for (auto i : innerSplitAxes.asArrayRef()) {
-            numShards += meshShape[i];
+            numShards += gridShape[i];
           }
           for (size_t i = 1; i < numShards; ++i) {
             if (shardedDimsOffsets[pos + i] - shardedDimsOffsets[pos + i - 1] !=
@@ -256,7 +256,7 @@ static void shardShape(const InShape &inShape, const MeshShape &meshShape,
     for (auto [tensorAxis, innerSplitAxes] : llvm::enumerate(splitAxes)) {
       outShape[tensorAxis] = shardDimension(
           inShape[tensorAxis],
-          collectiveProcessGroupSize(innerSplitAxes.asArrayRef(), meshShape));
+          collectiveProcessGroupSize(innerSplitAxes.asArrayRef(), gridShape));
     }
 
     if (!haloSizes.empty()) {
@@ -279,25 +279,25 @@ static void shardShape(const InShape &inShape, const MeshShape &meshShape,
   }
 }
 
-ShapedType mesh::shardShapedType(ShapedType shape, MeshOp mesh,
-                                 MeshSharding sharding) {
+ShapedType shard::shardShapedType(ShapedType shape, GridOp grid,
+                                  Sharding sharding) {
   using Dim = std::decay_t<decltype(shape.getDimSize(0))>;
   SmallVector<Dim> resShapeArr(shape.getShape().size());
-  shardShape(shape.getShape(), mesh.getShape(), sharding.getSplitAxes(),
+  shardShape(shape.getShape(), grid.getShape(), sharding.getSplitAxes(),
              resShapeArr, sharding.getStaticShardedDimsOffsets(),
              sharding.getStaticHaloSizes());
   return shape.clone(resShapeArr);
 }
 
-Type mesh::shardType(Type type, MeshOp mesh, MeshSharding sharding) {
+Type shard::shardType(Type type, GridOp grid, Sharding sharding) {
   RankedTensorType rankedTensorType = dyn_cast<RankedTensorType>(type);
   if (rankedTensorType && !rankedTensorType.getShape().empty()) {
-    return shardShapedType(rankedTensorType, mesh, sharding);
+    return shardShapedType(rankedTensorType, grid, sharding);
   }
   return type;
 }
 
-static void maybeInsertTargetShardingAnnotationImpl(MeshSharding sharding,
+static void maybeInsertTargetShardingAnnotationImpl(Sharding sharding,
                                                     Value &operandValue,
                                                     Operation *operandOp,
                                                     OpBuilder &builder,
@@ -336,9 +336,9 @@ static void maybeInsertTargetShardingAnnotationImpl(MeshSharding sharding,
   newShardOp.getResult().replaceAllUsesExcept(newShardOp2, newShardOp2);
 }
 
-void mlir::mesh::maybeInsertTargetShardingAnnotation(MeshSharding sharding,
-                                                     OpResult result,
-                                                     OpBuilder &builder) {
+void mlir::shard::maybeInsertTargetShardingAnnotation(Sharding sharding,
+                                                      OpResult result,
+                                                      OpBuilder &builder) {
   ShardOp newShardOp;
   SmallVector<std::pair<Value, Operation *>> uses;
   for (auto &use : result.getUses()) {
@@ -350,9 +350,9 @@ void mlir::mesh::maybeInsertTargetShardingAnnotation(MeshSharding sharding,
   }
 }
 
-void mlir::mesh::maybeInsertSourceShardingAnnotation(MeshSharding sharding,
-                                                     OpOperand &operand,
-                                                     OpBuilder &builder) {
+void mlir::shard::maybeInsertSourceShardingAnnotation(Sharding sharding,
+                                                      OpOperand &operand,
+                                                      OpBuilder &builder) {
   OpBuilder::InsertionGuard insertionGuard(builder);
   Value operandValue = operand.get();
   Operation *operandSrcOp = operandValue.getDefiningOp();
@@ -404,18 +404,18 @@ void mlir::mesh::maybeInsertSourceShardingAnnotation(MeshSharding sharding,
 }
 
 //===----------------------------------------------------------------------===//
-// mesh.mesh op
+// shard.grid op
 //===----------------------------------------------------------------------===//
 
-LogicalResult MeshOp::verify() {
+LogicalResult GridOp::verify() {
   int64_t rank = getRank();
 
   if (rank <= 0)
-    return emitOpError("rank of mesh is expected to be a positive integer");
+    return emitOpError("rank of grid is expected to be a positive integer");
 
   for (int64_t dimSize : getShape()) {
     if (dimSize < 0 && ShapedType::isStatic(dimSize))
-      return emitOpError("dimension size of a mesh is expected to be "
+      return emitOpError("dimension size of a grid is expected to be "
                          "non-negative or dynamic");
   }
 
@@ -423,21 +423,21 @@ LogicalResult MeshOp::verify() {
 }
 
 //===----------------------------------------------------------------------===//
-// mesh.mesh_shape op
+// shard.grid_shape op
 //===----------------------------------------------------------------------===//
 
 LogicalResult
-MeshShapeOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
-  auto mesh = ::getMeshAndVerify(getOperation(), getMeshAttr(), symbolTable);
-  if (failed(mesh)) {
+GridShapeOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
+  auto grid = ::getGridAndVerify(getOperation(), getGridAttr(), symbolTable);
+  if (failed(grid)) {
     return failure();
   }
-  if (failed(verifyMeshAxes(getLoc(), getAxes(), mesh.value()))) {
+  if (failed(verifyGridAxes(getLoc(), getAxes(), grid.value()))) {
     return failure();
   }
 
   size_t expectedResultsCount =
-      getAxes().empty() ? mesh->getRank() : getAxes().size();
+      getAxes().empty() ? grid->getRank() : getAxes().size();
   if (getResult().size() != expectedResultsCount) {
     return emitError() << "Unexpected number of results " << getResult().size()
                        << ". Expected " << expectedResultsCount << ".";
@@ -446,53 +446,53 @@ MeshShapeOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
   return success();
 }
 
-void MeshShapeOp::build(OpBuilder &odsBuilder, OperationState &odsState,
-                        MeshOp mesh) {
-  build(odsBuilder, odsState, mesh, SmallVector<MeshAxis>());
+void GridShapeOp::build(OpBuilder &odsBuilder, OperationState &odsState,
+                        GridOp grid) {
+  build(odsBuilder, odsState, grid, SmallVector<GridAxis>());
 }
 
-void MeshShapeOp::build(OpBuilder &odsBuilder, OperationState &odsState,
-                        MeshOp mesh, ArrayRef<MeshAxis> axes) {
+void GridShapeOp::build(OpBuilder &odsBuilder, OperationState &odsState,
+                        GridOp grid, ArrayRef<GridAxis> axes) {
   build(odsBuilder, odsState,
-        SmallVector<Type>(axes.empty() ? mesh.getRank() : axes.size(),
+        SmallVector<Type>(axes.empty() ? grid.getRank() : axes.size(),
                           odsBuilder.getIndexType()),
-        mesh.getSymName(), MeshAxesAttr::get(odsBuilder.getContext(), axes));
+        grid.getSymName(), GridAxesAttr::get(odsBuilder.getContext(), axes));
 }
 
-void MeshShapeOp::build(OpBuilder &odsBuilder, OperationState &odsState,
-                        StringRef mesh, ArrayRef<MeshAxis> axes) {
+void GridShapeOp::build(OpBuilder &odsBuilder, OperationState &odsState,
+                        StringRef grid, ArrayRef<GridAxis> axes) {
   assert(!axes.empty());
   build(odsBuilder, odsState,
-        SmallVector<Type>(axes.size(), odsBuilder.getIndexType()), mesh,
-        MeshAxesAttr::get(odsBuilder.getContext(), axes));
+        SmallVector<Type>(axes.size(), odsBuilder.getIndexType()), grid,
+        GridAxesAttr::get(odsBuilder.getContext(), axes));
 }
 
-void MeshShapeOp::getAsmResultNames(
+void GridShapeOp::getAsmResultNames(
     function_ref<void(Value, StringRef)> setNameFn) {
-  setNameFn(getResults()[0], "mesh_shape");
+  setNameFn(getResults()[0], "grid_shape");
 }
 
 //===----------------------------------------------------------------------===//
-// mesh.sharding
+// shard.sharding
 //===----------------------------------------------------------------------===//
 
 void ShardingOp::build(::mlir::OpBuilder &b, ::mlir::OperationState &odsState,
-                       FlatSymbolRefAttr mesh,
-                       ArrayRef<MeshAxesAttr> split_axes,
+                       FlatSymbolRefAttr grid,
+                       ArrayRef<GridAxesAttr> split_axes,
                        ArrayRef<int64_t> static_halos,
                        ArrayRef<int64_t> static_offsets) {
   return build(
-      b, odsState, mesh, MeshAxesArrayAttr::get(b.getContext(), split_axes),
+      b, odsState, grid, GridAxesArrayAttr::get(b.getContext(), split_axes),
       ::mlir::DenseI64ArrayAttr::get(b.getContext(), static_halos), {},
       ::mlir::DenseI64ArrayAttr::get(b.getContext(), static_offsets), {});
 }
 
 void ShardingOp::build(::mlir::OpBuilder &b, ::mlir::OperationState &odsState,
-                       llvm::StringRef mesh, ArrayRef<MeshAxesAttr> split_axes,
+                       llvm::StringRef grid, ArrayRef<GridAxesAttr> split_axes,
                        ArrayRef<int64_t> static_halos,
                        ArrayRef<int64_t> static_offsets) {
-  return build(b, odsState, FlatSymbolRefAttr::get(b.getContext(), mesh),
-               MeshAxesArrayAttr::get(b.getContext(), split_axes),
+  return build(b, odsState, FlatSymbolRefAttr::get(b.getContext(), grid),
+               GridAxesArrayAttr::get(b.getContext(), split_axes),
                ::mlir::DenseI64ArrayAttr::get(b.getContext(), static_halos), {},
                ::mlir::DenseI64ArrayAttr::get(b.getContext(), static_offsets),
                {});
@@ -500,7 +500,7 @@ void ShardingOp::build(::mlir::OpBuilder &b, ::mlir::OperationState &odsState,
 
 void ShardingOp::build(
     ::mlir::OpBuilder &b, ::mlir::OperationState &odsState,
-    FlatSymbolRefAttr mesh, ArrayRef<MeshAxesAttr> split_axes,
+    FlatSymbolRefAttr grid, ArrayRef<GridAxesAttr> split_axes,
     ::mlir::ArrayRef<::mlir::OpFoldResult> halo_sizes,
     ::mlir::ArrayRef<::mlir::OpFoldResult> sharded_dims_offsets) {
   mlir::SmallVector<int64_t> staticHalos, staticDims;
@@ -508,16 +508,16 @@ void ShardingOp::build(
   dispatchIndexOpFoldResults(halo_sizes, dynamicHalos, staticHalos);
   dispatchIndexOpFoldResults(sharded_dims_offsets, dynamicDims, staticDims);
   return build(
-      b, odsState, mesh, MeshAxesArrayAttr::get(b.getContext(), split_axes),
+      b, odsState, grid, GridAxesArrayAttr::get(b.getContext(), split_axes),
       ::mlir::DenseI64ArrayAttr::get(b.getContext(), staticHalos), dynamicHalos,
       ::mlir::DenseI64ArrayAttr::get(b.getContext(), staticDims), dynamicDims);
 }
 
 void ShardingOp::build(::mlir::OpBuilder &b, ::mlir::OperationState &odsState,
-                       mlir::mesh::MeshSharding from) {
+                       mlir::shard::Sharding from) {
 
-  build(b, odsState, ShardingType::get(b.getContext()), from.getMeshAttr(),
-        MeshAxesArrayAttr::get(b.getContext(), from.getSplitAxes()),
+  build(b, odsState, ShardingType::get(b.getContext()), from.getGridAttr(),
+        GridAxesArrayAttr::get(b.getContext(), from.getSplitAxes()),
         from.getStaticShardedDimsOffsets().empty()
             ? DenseI64ArrayAttr()
             : b.getDenseI64ArrayAttr(from.getStaticShardedDimsOffsets()),
@@ -529,21 +529,21 @@ void ShardingOp::build(::mlir::OpBuilder &b, ::mlir::OperationState &odsState,
 }
 
 LogicalResult ShardingOp::verify() {
-  llvm::SmallSet<MeshAxis, 4> visitedAxes;
+  llvm::SmallSet<GridAxis, 4> visitedAxes;
 
-  auto checkMeshAxis = [&](ArrayRef<MeshAxis> axesArray) -> LogicalResult {
-    for (MeshAxis axis : axesArray) {
+  auto checkGridAxis = [&](ArrayRef<GridAxis> axesArray) -> LogicalResult {
+    for (GridAxis axis : axesArray) {
       if (axis < 0)
-        return emitError() << "mesh axis is expected to be non-negative";
+        return emitError() << "grid axis is expected to be non-negative";
       if (!visitedAxes.insert(axis).second)
-        return emitError() << "mesh axis duplicated";
+        return emitError() << "grid axis duplicated";
     }
     return success();
   };
 
   for (auto subAxes : getSplitAxes().getAxes()) {
-    ArrayRef<MeshAxis> subAxesArray = subAxes.asArrayRef();
-    if (failed(checkMeshAxis(subAxesArray)))
+    ArrayRef<GridAxis> subAxesArray = subAxes.asArrayRef();
+    if (failed(checkGridAxis(subAxesArray)))
       return failure();
   }
 
@@ -572,26 +572,26 @@ void ShardingOp::getAsmResultNames(
 }
 
 LogicalResult ShardingOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
-  auto mesh = ::getMeshAndVerify(getOperation(), getMeshAttr(), symbolTable);
-  if (failed(mesh)) {
+  auto grid = ::getGridAndVerify(getOperation(), getGridAttr(), symbolTable);
+  if (failed(grid)) {
     return failure();
   }
-  if (mlir::ShapedType::isDynamicShape(mesh->getShape()) &&
+  if (mlir::ShapedType::isDynamicShape(grid->getShape()) &&
       getStaticShardedDimsOffsets().size() > 0) {
     return emitError() << "sharded dims offsets are not allowed for "
-                          "devices meshes with dynamic shape.";
+                          "device grids with dynamic shape.";
   }
 
   auto shardedDimsOffsets = getStaticShardedDimsOffsets();
   if (!shardedDimsOffsets.empty()) {
-    auto meshShape = mesh.value().getShape();
-    assert(ShapedType::isStaticShape(meshShape));
+    auto gridShape = grid.value().getShape();
+    assert(ShapedType::isStaticShape(gridShape));
     uint64_t pos = 0;
     for (auto [tensorAxis, innerSplitAxes] : llvm::enumerate(getSplitAxes())) {
       if (!innerSplitAxes.empty()) {
         int64_t numShards = 0, off = 0;
         for (auto i : innerSplitAxes.asArrayRef()) {
-          numShards += meshShape[i];
+          numShards += gridShape[i];
         }
         for (int64_t i = 0; i <= numShards; ++i) {
           if (shardedDimsOffsets.size() <= pos + i) {
@@ -684,11 +684,11 @@ void ShardingOp::getCanonicalizationPatterns(mlir::RewritePatternSet &results,
 }
 
 //===----------------------------------------------------------------------===//
-// MeshSharding
+// Sharding
 //===----------------------------------------------------------------------===//
 
-bool MeshSharding::equalSplitAxes(const MeshSharding &rhs) const {
-  if (getMesh() != rhs.getMesh()) {
+bool Sharding::equalSplitAxes(const Sharding &rhs) const {
+  if (getGrid() != rhs.getGrid()) {
     return false;
   }
 
@@ -701,16 +701,16 @@ bool MeshSharding::equalSplitAxes(const MeshSharding &rhs) const {
   }
 
   return llvm::all_of(llvm::drop_begin(getSplitAxes(), minSize),
-                      std::mem_fn(&MeshAxesAttr::empty)) &&
+                      std::mem_fn(&GridAxesAttr::empty)) &&
          llvm::all_of(llvm::drop_begin(rhs.getSplitAxes(), minSize),
-                      std::mem_fn(&MeshAxesAttr::empty));
+                      std::mem_fn(&GridAxesAttr::empty));
 }
 
-bool MeshSharding::equalHaloAndShardSizes(const MeshSharding &rhs) const {
+bool Sharding::equalHaloAndShardSizes(const Sharding &rhs) const {
   return equalShardSizes(rhs) && equalHaloSizes(rhs);
 }
 
-bool MeshSharding::equalShardSizes(const MeshSharding &rhs) const {
+bool Sharding::equalShardSizes(const Sharding &rhs) const {
   if (rhs.getStaticShardedDimsOffsets().size() !=
           getStaticShardedDimsOffsets().size() ||
       !llvm::equal(getStaticShardedDimsOffsets(),
@@ -726,7 +726,7 @@ bool MeshSharding::equalShardSizes(const MeshSharding &rhs) const {
   return true;
 }
 
-bool MeshSharding::equalHaloSizes(const MeshSharding &rhs) const {
+bool Sharding::equalHaloSizes(const Sharding &rhs) const {
   if (rhs.getStaticHaloSizes().size() != getStaticHaloSizes().size() ||
       !llvm::equal(getStaticHaloSizes(), rhs.getStaticHaloSizes())) {
     return false;
@@ -738,45 +738,43 @@ bool MeshSharding::equalHaloSizes(const MeshSharding &rhs) const {
   return true;
 }
 
-bool MeshSharding::operator==(Value rhs) const {
+bool Sharding::operator==(Value rhs) const {
   return equalSplitAxes(rhs) && equalHaloAndShardSizes(rhs);
 }
 
-bool MeshSharding::operator!=(Value rhs) const { return !(*this == rhs); }
+bool Sharding::operator!=(Value rhs) const { return !(*this == rhs); }
 
-bool MeshSharding::operator==(const MeshSharding &rhs) const {
+bool Sharding::operator==(const Sharding &rhs) const {
   return equalSplitAxes(rhs) && equalHaloAndShardSizes(rhs);
 }
 
-bool MeshSharding::operator!=(const MeshSharding &rhs) const {
-  return !(*this == rhs);
-}
+bool Sharding::operator!=(const Sharding &rhs) const { return !(*this == rhs); }
 
-MeshSharding::MeshSharding(::mlir::FlatSymbolRefAttr mesh_) : mesh(mesh_) {}
+Sharding::Sharding(::mlir::FlatSymbolRefAttr grid_) : grid(grid_) {}
 
-MeshSharding::MeshSharding(Value rhs) {
+Sharding::Sharding(Value rhs) {
   auto shardingOp = mlir::dyn_cast<ShardingOp>(rhs.getDefiningOp());
   assert(shardingOp && "expected sharding op");
   auto splitAxes = shardingOp.getSplitAxes().getAxes();
   // If splitAxes are empty, use "empty" constructor.
   if (splitAxes.empty()) {
-    *this = MeshSharding(shardingOp.getMeshAttr());
+    *this = Sharding(shardingOp.getGridAttr());
     return;
   }
   *this =
-      get(shardingOp.getMeshAttr(), splitAxes, shardingOp.getStaticHaloSizes(),
+      get(shardingOp.getGridAttr(), splitAxes, shardingOp.getStaticHaloSizes(),
           shardingOp.getStaticShardedDimsOffsets(),
           SmallVector<Value>(shardingOp.getDynamicHaloSizes()),
           SmallVector<Value>(shardingOp.getDynamicShardedDimsOffsets()));
 }
 
-MeshSharding MeshSharding::get(::mlir::FlatSymbolRefAttr mesh_,
-                               ArrayRef<MeshAxesAttr> split_axes_,
-                               ArrayRef<int64_t> static_halo_sizes_,
-                               ArrayRef<int64_t> static_sharded_dims_offsets_,
-                               ArrayRef<Value> dynamic_halo_sizes_,
-                               ArrayRef<Value> dynamic_sharded_dims_offsets_) {
-  MeshSharding res(mesh_);
+Sharding Sharding::get(::mlir::FlatSymbolRefAttr grid_,
+                       ArrayRef<GridAxesAttr> split_axes_,
+                       ArrayRef<int64_t> static_halo_sizes_,
+                       ArrayRef<int64_t> static_sharded_dims_offsets_,
+                       ArrayRef<Value> dynamic_halo_sizes_,
+                       ArrayRef<Value> dynamic_sharded_dims_offsets_) {
+  Sharding res(grid_);
   if (split_axes_.empty()) {
     return res;
   }
@@ -784,7 +782,7 @@ MeshSharding MeshSharding::get(::mlir::FlatSymbolRefAttr mesh_,
   res.split_axes.resize(split_axes_.size());
   for (auto [i, axis] : llvm::enumerate(split_axes_)) {
     res.split_axes[i] =
-        MeshAxesAttr::get(mesh_.getContext(), axis.asArrayRef());
+        GridAxesAttr::get(grid_.getContext(), axis.asArrayRef());
   }
 
   auto clone = [](const auto src, auto &dst) {
@@ -801,7 +799,7 @@ MeshSharding MeshSharding::get(::mlir::FlatSymbolRefAttr mesh_,
 }
 
 //===----------------------------------------------------------------------===//
-// mesh.shard_shape
+// shard.shard_shape
 //===----------------------------------------------------------------------===//
 
 void ShardShapeOp::getAsmResultNames(
@@ -820,7 +818,7 @@ void ShardShapeOp::build(::mlir::OpBuilder &odsBuilder,
 }
 
 //===----------------------------------------------------------------------===//
-// mesh.shard op
+// shard.shard op
 //===----------------------------------------------------------------------===//
 
 void ShardOp::getAsmResultNames(
@@ -850,10 +848,10 @@ class FoldDuplicateShardOp final : public OpRewritePattern<ShardOp> {
         if (!otherOp || !otherOp->isBeforeInBlock(op)) {
           return failure();
         }
-        // Create a MeshSharding object for the current and the other ShardOp
+        // Create a Sharding object for the current and the other ShardOp
         // If the two are equal replace current op with the other op.
-        MeshSharding currentSharding(op.getSharding());
-        MeshSharding otherSharding(otherOp.getSharding());
+        Sharding currentSharding(op.getSharding());
+        Sharding otherSharding(otherOp.getSharding());
         if (currentSharding == otherSharding) {
           b.replaceAllUsesWith(op.getResult(), otherOp.getResult());
           b.eraseOp(op.getOperation());
@@ -876,21 +874,21 @@ void ShardOp::getCanonicalizationPatterns(mlir::RewritePatternSet &results,
 }
 
 //===----------------------------------------------------------------------===//
-// mesh.process_multi_index op
+// shard.process_multi_index op
 //===----------------------------------------------------------------------===//
 
 LogicalResult
 ProcessMultiIndexOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
-  auto mesh = ::getMeshAndVerify(getOperation(), getMeshAttr(), symbolTable);
-  if (failed(mesh)) {
+  auto grid = ::getGridAndVerify(getOperation(), getGridAttr(), symbolTable);
+  if (failed(grid)) {
     return failure();
   }
-  if (failed(verifyMeshAxes(getLoc(), getAxes(), mesh.value()))) {
+  if (failed(verifyGridAxes(getLoc(), getAxes(), grid.value()))) {
     return failure();
   }
 
   size_t expectedResultsCount =
-      getAxes().empty() ? mesh->getRank() : getAxes().size();
+      getAxes().empty() ? grid->getRank() : getAxes().size();
   if (getResult().size() != expectedResultsCount) {
     return emitError() << "Unexpected number of results " << getResult().size()
                        << ". Expected " << expectedResultsCount << ".";
@@ -900,17 +898,17 @@ ProcessMultiIndexOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
 }
 
 void ProcessMultiIndexOp::build(OpBuilder &odsBuilder, OperationState &odsState,
-                                MeshOp mesh) {
+                                GridOp grid) {
   build(odsBuilder, odsState,
-        SmallVector<Type>(mesh.getRank(), odsBuilder.getIndexType()),
-        mesh.getSymName(), ArrayRef<MeshAxis>());
+        SmallVector<Type>(grid.getRank(), odsBuilder.getIndexType()),
+        grid.getSymName(), ArrayRef<GridAxis>());
 }
 
 void ProcessMultiIndexOp::build(OpBuilder &odsBuilder, OperationState &odsState,
-                                StringRef mesh, ArrayRef<MeshAxis> axes) {
+                                StringRef grid, ArrayRef<GridAxis> axes) {
   build(odsBuilder, odsState,
-        SmallVector<Type>(axes.size(), odsBuilder.getIndexType()), mesh,
-        MeshAxesAttr::get(odsBuilder.getContext(), axes));
+        SmallVector<Type>(axes.size(), odsBuilder.getIndexType()), grid,
+        GridAxesAttr::get(odsBuilder.getContext(), axes));
 }
 
 void ProcessMultiIndexOp::getAsmResultNames(
@@ -919,21 +917,21 @@ void ProcessMultiIndexOp::getAsmResultNames(
 }
 
 //===----------------------------------------------------------------------===//
-// mesh.process_linear_index op
+// shard.process_linear_index op
 //===----------------------------------------------------------------------===//
 
 LogicalResult
 ProcessLinearIndexOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
-  auto mesh = ::getMeshAndVerify(getOperation(), getMeshAttr(), symbolTable);
-  if (failed(mesh)) {
+  auto grid = ::getGridAndVerify(getOperation(), getGridAttr(), symbolTable);
+  if (failed(grid)) {
     return failure();
   }
   return success();
 }
 
 void ProcessLinearIndexOp::build(OpBuilder &odsBuilder,
-                                 OperationState &odsState, MeshOp mesh) {
-  build(odsBuilder, odsState, mesh.getSymName());
+                                 OperationState &odsState, GridOp grid) {
+  build(odsBuilder, odsState, grid.getSymName());
 }
 
 void ProcessLinearIndexOp::getAsmResultNames(
@@ -942,13 +940,13 @@ void ProcessLinearIndexOp::getAsmResultNames(
 }
 
 //===----------------------------------------------------------------------===//
-// mesh.neighbors_linear_indices op
+// shard.neighbors_linear_indices op
 //===----------------------------------------------------------------------===//
 
 LogicalResult
 NeighborsLinearIndicesOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
-  auto mesh = ::getMeshAndVerify(getOperation(), getMeshAttr(), symbolTable);
-  if (failed(mesh)) {
+  auto grid = ::getGridAndVerify(getOperation(), getGridAttr(), symbolTable);
+  if (failed(grid)) {
     return failure();
   }
   return success();
@@ -967,12 +965,12 @@ void NeighborsLinearIndicesOp::getAsmResultNames(
 namespace {
 
 template <typename Op>
-struct EmptyMeshAxesCanonicalizationPattern : OpRewritePattern<Op> {
+struct EmptyGridAxesCanonicalizationPattern : OpRewritePattern<Op> {
   using OpRewritePattern<Op>::OpRewritePattern;
   LogicalResult matchAndRewrite(Op op,
                                 PatternRewriter &rewriter) const override {
-    auto meshAxes = op.getMeshAxes();
-    if (!meshAxes.empty()) {
+    auto gridAxes = op.getGridAxes();
+    if (!gridAxes.empty()) {
       return failure();
     }
     if (op.getInput().getType() != op.getResult().getType()) {
@@ -990,24 +988,24 @@ struct EmptyMeshAxesCanonicalizationPattern : OpRewritePattern<Op> {
 static LogicalResult verifyInGroupDevice(Location loc, StringRef deviceName,
                                          ArrayRef<int64_t> device,
                                          Operation::operand_range deviceDynamic,
-                                         ArrayRef<MeshAxis> meshAxes,
-                                         ArrayRef<int64_t> meshShape) {
-  if (device.size() != meshAxes.size()) {
+                                         ArrayRef<GridAxis> gridAxes,
+                                         ArrayRef<int64_t> gridShape) {
+  if (device.size() != gridAxes.size()) {
     return emitError(loc) << "In-group device \"" << deviceName
                           << "\" has unexpected multi-index size "
-                          << device.size() << ". Expected " << meshAxes.size()
+                          << device.size() << ". Expected " << gridAxes.size()
                           << ".";
   }
 
   for (size_t i = 0; i < device.size(); ++i) {
     if (ShapedType::isStatic(device[i]) &&
-        ShapedType::isStatic(meshShape[meshAxes[i]]) &&
-        meshShape[meshAxes[i]] <= device[i]) {
+        ShapedType::isStatic(gridShape[gridAxes[i]]) &&
+        gridShape[gridAxes[i]] <= device[i]) {
       return emitError(loc)
              << "Out of bounds coordinate " << i << " for in-group device \""
              << deviceName << "\"."
              << " Got " << device[i] << ", but expected value in the range [0, "
-             << (meshShape[meshAxes[i]] - 1) << "].";
+             << (gridShape[gridAxes[i]] - 1) << "].";
     }
   }
   return success();
@@ -1043,7 +1041,7 @@ static LogicalResult verifyDimensionCompatibility(Location loc,
 
 static LogicalResult verifyGatherOperandAndResultShape(
     Value operand, Value result, int64_t gatherAxis,
-    ArrayRef<MeshAxis> meshAxes, ArrayRef<int64_t> meshShape) {
+    ArrayRef<GridAxis> gridAxes, ArrayRef<int64_t> gridShape) {
   auto resultRank = cast<ShapedType>(result.getType()).getRank();
   if (gatherAxis < 0 || gatherAxis >= resultRank) {
     return emitError(result.getLoc())
@@ -1054,7 +1052,7 @@ static LogicalResult verifyGatherOperandAndResultShape(
   ShapedType operandType = cast<ShapedType>(operand.getType());
   ShapedType resultType = cast<ShapedType>(result.getType());
   auto deviceGroupSize =
-      DimensionSize(collectiveProcessGroupSize(meshAxes, meshShape));
+      DimensionSize(collectiveProcessGroupSize(gridAxes, gridShape));
   for (int64_t axis = 0; axis < operandType.getRank(); ++axis) {
     auto operandDimSize = DimensionSize(operandType.getDimSize(axis));
     auto resultDimSize = DimensionSize(resultType.getDimSize(axis));
@@ -1070,7 +1068,7 @@ static LogicalResult verifyGatherOperandAndResultShape(
 
 static LogicalResult verifyAllToAllOperandAndResultShape(
     Value operand, Value result, int64_t splitAxis, int64_t concatAxis,
-    ArrayRef<MeshAxis> meshAxes, ArrayRef<int64_t> meshShape) {
+    ArrayRef<GridAxis> gridAxes, ArrayRef<int64_t> gridShape) {
   ShapedType operandType = cast<ShapedType>(operand.getType());
   ShapedType resultType = cast<ShapedType>(result.getType());
   for (int64_t axis = 0; axis < operandType.getRank(); ++axis) {
@@ -1088,7 +1086,7 @@ static LogicalResult verifyAllToAllOperandAndResultShape(
   }
 
   auto deviceGroupSize =
-      DimensionSize(collectiveProcessGroupSize(meshAxes, meshShape));
+      DimensionSize(collectiveProcessGroupSize(gridAxes, gridShape));
   auto operandConcatDimSize = DimensionSize(operandType.getDimSize(concatAxis));
   auto operandSplitDimSize = DimensionSize(operandType.getDimSize(splitAxis));
   DimensionSize expectedResultConcatDimSize =
@@ -1115,7 +1113,7 @@ static LogicalResult verifyAllToAllOperandAndResultShape(
 
 static LogicalResult verifyScatterOrSliceOperandAndResultShape(
     Value operand, Value result, int64_t tensorAxis,
-    ArrayRef<MeshAxis> meshAxes, ArrayRef<int64_t> meshShape) {
+    ArrayRef<GridAxis> gridAxes, ArrayRef<int64_t> gridShape) {
   ShapedType operandType = cast<ShapedType>(operand.getType());
   ShapedType resultType = cast<ShapedType>(result.getType());
   for (int64_t axis = 0; axis < operandType.getRank(); ++axis) {
@@ -1129,7 +1127,7 @@ static LogicalResult verifyScatterOrSliceOperandAndResultShape(
   }
 
   auto deviceGroupSize =
-      DimensionSize(collectiveProcessGroupSize(meshAxes, meshShape));
+      DimensionSize(collectiveProcessGroupSize(gridAxes, gridShape));
   auto operandScatterDimSize =
       DimensionSize(operandType.getDimSize(tensorAxis));
   if (!operandScatterDimSize.isDynamic() && !deviceGroupSize.isDynamic() &&
@@ -1151,8 +1149,8 @@ static LogicalResult verifyScatterOrSliceOperandAndResultShape(
   return success();
 }
 
-static RankedTensorType sliceResultType(Type operandType, MeshOp mesh,
-                                        ArrayRef<MeshAxis> meshAxes,
+static RankedTensorType sliceResultType(Type operandType, GridOp grid,
+                                        ArrayRef<GridAxis> gridAxes,
                                         int64_t sliceAxis) {
   RankedTensorType operandRankedTensorType =
       cast<RankedTensorType>(operandType);
@@ -1163,29 +1161,29 @@ static RankedTensorType sliceResultType(Type operandType, MeshOp mesh,
 
   resultShape[sliceAxis] =
       operandSliceAxisSize /
-      DimensionSize(collectiveProcessGroupSize(meshAxes, mesh));
+      DimensionSize(collectiveProcessGroupSize(gridAxes, grid));
   return operandRankedTensorType.clone(resultShape);
 }
 
 //===----------------------------------------------------------------------===//
-// mesh.all_gather op
+// shard.all_gather op
 //===----------------------------------------------------------------------===//
 
 LogicalResult
 AllGatherOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
-  auto mesh = getMeshAndVerifyAxes(*this, symbolTable);
-  if (failed(mesh)) {
+  auto grid = getGridAndVerifyAxes(*this, symbolTable);
+  if (failed(grid)) {
     return failure();
   }
   auto gatherAxis = getGatherAxis().getSExtValue();
   return verifyGatherOperandAndResultShape(getOperand(), getResult(),
-                                           gatherAxis, getMeshAxes(),
-                                           mesh.value().getShape());
+                                           gatherAxis, getGridAxes(),
+                                           grid.value().getShape());
 }
 
 void AllGatherOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                               MLIRContext *context) {
-  patterns.add<EmptyMeshAxesCanonicalizationPattern<AllGatherOp>>(context);
+  patterns.add<EmptyGridAxesCanonicalizationPattern<AllGatherOp>>(context);
 }
 
 void AllGatherOp::getAsmResultNames(
@@ -1194,23 +1192,23 @@ void AllGatherOp::getAsmResultNames(
 }
 
 //===----------------------------------------------------------------------===//
-// mesh.all_reduce op
+// shard.all_reduce op
 //===----------------------------------------------------------------------===//
 
 LogicalResult
 AllReduceOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
-  return getMeshAndVerifyAxes(*this, symbolTable);
+  return getGridAndVerifyAxes(*this, symbolTable);
 }
 
 void AllReduceOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                               MLIRContext *context) {
-  patterns.add<EmptyMeshAxesCanonicalizationPattern<AllReduceOp>>(context);
+  patterns.add<EmptyGridAxesCanonicalizationPattern<AllReduceOp>>(context);
 }
 
 void AllReduceOp::build(OpBuilder &odsBuilder, OperationState &odsState,
-                        Value input, StringRef mesh,
-                        ArrayRef<MeshAxis> meshAxes, ReductionKind reduction) {
-  build(odsBuilder, odsState, input.getType(), mesh, meshAxes, input,
+                        Value input, StringRef grid,
+                        ArrayRef<GridAxis> gridAxes, ReductionKind reduction) {
+  build(odsBuilder, odsState, input.getType(), grid, gridAxes, input,
         reduction);
 }
 
@@ -1220,36 +1218,36 @@ void AllReduceOp::getAsmResultNames(
 }
 
 //===----------------------------------------------------------------------===//
-// mesh.all_slice op
+// shard.all_slice op
 //===----------------------------------------------------------------------===//
 
 LogicalResult AllSliceOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
-  auto mesh = getMeshAndVerifyAxes(*this, symbolTable);
-  if (failed(mesh)) {
+  auto grid = getGridAndVerifyAxes(*this, symbolTable);
+  if (failed(grid)) {
     return failure();
   }
   return verifyScatterOrSliceOperandAndResultShape(
-      getOperand(), getResult(), getSliceAxis().getSExtValue(), getMeshAxes(),
-      mesh.value().getShape());
+      getOperand(), getResult(), getSliceAxis().getSExtValue(), getGridAxes(),
+      grid.value().getShape());
 }
 
 void AllSliceOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                              MLIRContext *context) {
-  patterns.add<EmptyMeshAxesCanonicalizationPattern<AllSliceOp>>(context);
+  patterns.add<EmptyGridAxesCanonicalizationPattern<AllSliceOp>>(context);
 }
 
 void AllSliceOp::build(OpBuilder &odsBuilder, OperationState &odsState,
-                       Value input, MeshOp mesh, ArrayRef<MeshAxis> meshAxes,
+                       Value input, GridOp grid, ArrayRef<GridAxis> gridAxes,
                        int64_t sliceAxis) {
-  Type resultType = sliceResultType(input.getType(), mesh, meshAxes, sliceAxis);
-  build(odsBuilder, odsState, resultType, input, mesh.getSymName(), meshAxes,
+  Type resultType = sliceResultType(input.getType(), grid, gridAxes, sliceAxis);
+  build(odsBuilder, odsState, resultType, input, grid.getSymName(), gridAxes,
         sliceAxis);
 }
 
 void AllSliceOp::build(OpBuilder &odsBuilder, OperationState &odsState,
-                       Type resultType, Value input, StringRef mesh,
-                       ArrayRef<MeshAxis> meshAxes, int64_t sliceAxis) {
-  build(odsBuilder, odsState, resultType, mesh, meshAxes, input,
+                       Type resultType, Value input, StringRef grid,
+                       ArrayRef<GridAxis> gridAxes, int64_t sliceAxis) {
+  build(odsBuilder, odsState, resultType, grid, gridAxes, input,
         APInt(sizeof(sliceAxis) * CHAR_BIT, sliceAxis));
 }
 
@@ -1259,23 +1257,23 @@ void AllSliceOp::getAsmResultNames(
 }
 
 //===----------------------------------------------------------------------===//
-// mesh.all_to_all op
+// shard.all_to_all op
 //===----------------------------------------------------------------------===//
 
 LogicalResult AllToAllOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
-  auto mesh = getMeshAndVerifyAxes(*this, symbolTable);
-  if (failed(mesh)) {
+  auto grid = getGridAndVerifyAxes(*this, symbolTable);
+  if (failed(grid)) {
     return failure();
   }
 
   return verifyAllToAllOperandAndResultShape(
       getOperand(), getResult(), getSplitAxis().getSExtValue(),
-      getConcatAxis().getSExtValue(), getMeshAxes(), mesh.value().getShape());
+      getConcatAxis().getSExtValue(), getGridAxes(), grid.value().getShape());
 }
 
 void AllToAllOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                              MLIRContext *context) {
-  patterns.add<EmptyMeshAxesCanonicalizationPattern<AllToAllOp>>(context);
+  patterns.add<EmptyGridAxesCanonicalizationPattern<AllToAllOp>>(context);
 }
 
 void AllToAllOp::getAsmResultNames(
@@ -1284,18 +1282,18 @@ void AllToAllOp::getAsmResultNames(
 }
 
 //===----------------------------------------------------------------------===//
-// mesh.broadcast op
+// shard.broadcast op
 //===----------------------------------------------------------------------===//
 
 LogicalResult
 BroadcastOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
-  auto mesh = getMeshAndVerifyAxes(*this, symbolTable);
-  if (failed(mesh)) {
+  auto grid = getGridAndVerifyAxes(*this, symbolTable);
+  if (failed(grid)) {
     return failure();
   }
   if (failed(verifyInGroupDevice(getLoc(), getRootAttrName(), getRoot(),
-                                 getRootDynamic(), getMeshAxes(),
-                                 mesh.value().getShape()))) {
+                                 getRootDynamic(), getGridAxes(),
+                                 grid.value().getShape()))) {
     return failure();
   }
 
@@ -1304,7 +1302,7 @@ BroadcastOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
 
 void BroadcastOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                               MLIRContext *context) {
-  patterns.add<EmptyMeshAxesCanonicalizationPattern<BroadcastOp>>(context);
+  patterns.add<EmptyGridAxesCanonicalizationPattern<BroadcastOp>>(context);
 }
 
 void BroadcastOp::getAsmResultNames(
@@ -1313,29 +1311,29 @@ void BroadcastOp::getAsmResultNames(
 }
 
 //===----------------------------------------------------------------------===//
-// mesh.gather op
+// shard.gather op
 //===----------------------------------------------------------------------===//
 
 LogicalResult GatherOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
-  auto mesh = getMeshAndVerifyAxes(*this, symbolTable);
-  if (failed(mesh)) {
+  auto grid = getGridAndVerifyAxes(*this, symbolTable);
+  if (failed(grid)) {
     return failure();
   }
   if (failed(verifyInGroupDevice(getLoc(), getRootAttrName(), getRoot(),
-                                 getRootDynamic(), getMeshAxes(),
-                                 mesh.value().getShape()))) {
+                                 getRootDynamic(), getGridAxes(),
+                                 grid.value().getShape()))) {
     return failure();
   }
 
   auto gatherAxis = getGatherAxis().getSExtValue();
   return verifyGatherOperandAndResultShape(getInput(), getResult(), gatherAxis,
-                                           getMeshAxes(),
-                                           mesh.value().getShape());
+                                           getGridAxes(),
+                                           grid.value().getShape());
 }
 
 void GatherOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                            MLIRContext *context) {
-  patterns.add<EmptyMeshAxesCanonicalizationPattern<GatherOp>>(context);
+  patterns.add<EmptyGridAxesCanonicalizationPattern<GatherOp>>(context);
 }
 
 void GatherOp::getAsmResultNames(
@@ -1344,18 +1342,18 @@ void GatherOp::getAsmResultNames(
 }
 
 //===----------------------------------------------------------------------===//
-// mesh.recv op
+// shard.recv op
 //===----------------------------------------------------------------------===//
 
 LogicalResult RecvOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
-  auto mesh = getMeshAndVerifyAxes(*this, symbolTable);
-  if (failed(mesh)) {
+  auto grid = getGridAndVerifyAxes(*this, symbolTable);
+  if (failed(grid)) {
     return failure();
   }
   if (getSource() &&
       failed(verifyInGroupDevice(getLoc(), getSourceAttrName(),
                                  getSource().value(), getSourceDynamic(),
-                                 getMeshAxes(), mesh.value().getShape()))) {
+                                 getGridAxes(), grid.value().getShape()))) {
     return failure();
   }
   return success();
@@ -1363,7 +1361,7 @@ LogicalResult RecvOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
 
 void RecvOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                          MLIRContext *context) {
-  patterns.add<EmptyMeshAxesCanonicalizationPattern<RecvOp>>(context);
+  patterns.add<EmptyGridAxesCanonicalizationPattern<RecvOp>>(context);
 }
 
 void RecvOp::getAsmResultNames(function_ref<void(Value, StringRef)> setNameFn) {
@@ -1371,17 +1369,17 @@ void RecvOp::getAsmResultNames(function_ref<void(Value, StringRef)> setNameFn) {
 }
 
 //===----------------------------------------------------------------------===//
-// mesh.reduce op
+// shard.reduce op
 //===----------------------------------------------------------------------===//
 
 LogicalResult ReduceOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
-  auto mesh = getMeshAndVerifyAxes(*this, symbolTable);
-  if (failed(mesh)) {
+  auto grid = getGridAndVerifyAxes(*this, symbolTable);
+  if (failed(grid)) {
     return failure();
   }
   if (failed(verifyInGroupDevice(getLoc(), getRootAttrName(), getRoot(),
-                                 getRootDynamic(), getMeshAxes(),
-                                 mesh.value().getShape()))) {
+                                 getRootDynamic(), getGridAxes(),
+                                 grid.value().getShape()))) {
     return failure();
   }
 
@@ -1390,7 +1388,7 @@ LogicalResult ReduceOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
 
 void ReduceOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                            MLIRContext *context) {
-  patterns.add<EmptyMeshAxesCanonicalizationPattern<ReduceOp>>(context);
+  patterns.add<EmptyGridAxesCanonicalizationPattern<ReduceOp>>(context);
 }
 
 void ReduceOp::getAsmResultNames(
@@ -1399,24 +1397,24 @@ void ReduceOp::getAsmResultNames(
 }
 
 //===----------------------------------------------------------------------===//
-// mesh.reduce_scatter op
+// shard.reduce_scatter op
 //===----------------------------------------------------------------------===//
 
 LogicalResult
 ReduceScatterOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
-  auto mesh = getMeshAndVerifyAxes(*this, symbolTable);
-  if (failed(mesh)) {
+  auto grid = getGridAndVerifyAxes(*this, symbolTable);
+  if (failed(grid)) {
     return failure();
   }
 
   return verifyScatterOrSliceOperandAndResultShape(
-      getOperand(), getResult(), getScatterAxis().getSExtValue(), getMeshAxes(),
-      mesh.value().getShape());
+      getOperand(), getResult(), getScatterAxis().getSExtValue(), getGridAxes(),
+      grid.value().getShape());
 }
 
 void ReduceScatterOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                                   MLIRContext *context) {
-  patterns.add<EmptyMeshAxesCanonicalizationPattern<ReduceScatterOp>>(context);
+  patterns.add<EmptyGridAxesCanonicalizationPattern<ReduceScatterOp>>(context);
 }
 
 void ReduceScatterOp::getAsmResultNames(
@@ -1425,29 +1423,29 @@ void ReduceScatterOp::getAsmResultNames(
 }
 
 //===----------------------------------------------------------------------===//
-// mesh.scatter op
+// shard.scatter op
 //===----------------------------------------------------------------------===//
 
 LogicalResult ScatterOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
-  auto mesh = getMeshAndVerifyAxes(*this, symbolTable);
-  if (failed(mesh)) {
+  auto grid = getGridAndVerifyAxes(*this, symbolTable);
+  if (failed(grid)) {
     return failure();
   }
   if (failed(verifyInGroupDevice(getLoc(), getRootAttrName(), getRoot(),
-                                 getRootDynamic(), getMeshAxes(),
-                                 mesh.value().getShape()))) {
+                                 getRootDynamic(), getGridAxes(),
+                                 grid.value().getShape()))) {
     return failure();
   }
 
   auto scatterAxis = getScatterAxis().getSExtValue();
   return verifyScatterOrSliceOperandAndResultShape(getInput(), getResult(),
-                                                   scatterAxis, getMeshAxes(),
-                                                   mesh.value().getShape());
+                                                   scatterAxis, getGridAxes(),
+                                                   grid.value().getShape());
 }
 
 void ScatterOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                             MLIRContext *context) {
-  patterns.add<EmptyMeshAxesCanonicalizationPattern<ScatterOp>>(context);
+  patterns.add<EmptyGridAxesCanonicalizationPattern<ScatterOp>>(context);
 }
 
 void ScatterOp::getAsmResultNames(
@@ -1456,17 +1454,17 @@ void ScatterOp::getAsmResultNames(
 }
 
 //===----------------------------------------------------------------------===//
-// mesh.send op
+// shard.send op
 //===----------------------------------------------------------------------===//
 
 LogicalResult SendOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
-  auto mesh = getMeshAndVerifyAxes(*this, symbolTable);
-  if (failed(mesh)) {
+  auto grid = getGridAndVerifyAxes(*this, symbolTable);
+  if (failed(grid)) {
     return failure();
   }
   if (failed(verifyInGroupDevice(getLoc(), getDestinationAttrName(),
                                  getDestination(), getDestinationDynamic(),
-                                 getMeshAxes(), mesh.value().getShape()))) {
+                                 getGridAxes(), grid.value().getShape()))) {
     return failure();
   }
   return success();
@@ -1474,7 +1472,7 @@ LogicalResult SendOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
 
 void SendOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                          MLIRContext *context) {
-  patterns.add<EmptyMeshAxesCanonicalizationPattern<SendOp>>(context);
+  patterns.add<EmptyGridAxesCanonicalizationPattern<SendOp>>(context);
 }
 
 void SendOp::getAsmResultNames(function_ref<void(Value, StringRef)> setNameFn) {
@@ -1482,20 +1480,20 @@ void SendOp::getAsmResultNames(function_ref<void(Value, StringRef)> setNameFn) {
 }
 
 //===----------------------------------------------------------------------===//
-// mesh.shift op
+// shard.shift op
 //===----------------------------------------------------------------------===//
 
 LogicalResult ShiftOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
-  auto mesh = getMeshAndVerifyAxes(*this, symbolTable);
-  if (failed(mesh)) {
+  auto grid = getGridAndVerifyAxes(*this, symbolTable);
+  if (failed(grid)) {
     return failure();
   }
 
-  auto meshAxes = getMeshAxes();
+  auto gridAxes = getGridAxes();
   auto shiftAxis = getShiftAxis().getZExtValue();
-  if (!llvm::is_contained(meshAxes, shiftAxis)) {
+  if (!llvm::is_contained(gridAxes, shiftAxis)) {
     return emitError() << "Invalid shift axis " << shiftAxis
-                       << ". It must be one of the grouping mesh axes.";
+                       << ". It must be one of the grouping grid axes.";
   }
 
   return success();
@@ -1504,7 +1502,7 @@ LogicalResult ShiftOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
 void ShiftOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                           MLIRContext *context) {
   // TODO: remove op when offset is 0 or if it is a rotate with an
-  // offset % shift_axis_mesh_dim_size == 0.
+  // offset % shift_axis_grid_dim_size == 0.
 }
 
 void ShiftOp::getAsmResultNames(
@@ -1513,13 +1511,13 @@ void ShiftOp::getAsmResultNames(
 }
 
 //===----------------------------------------------------------------------===//
-// mesh.update_halo op
+// shard.update_halo op
 //===----------------------------------------------------------------------===//
 
 LogicalResult
 UpdateHaloOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
-  auto mesh = getMeshAndVerify(getOperation(), getMeshAttr(), symbolTable);
-  if (failed(mesh)) {
+  auto grid = getGridAndVerify(getOperation(), getGridAttr(), symbolTable);
+  if (failed(grid)) {
     return failure();
   }
 
@@ -1531,12 +1529,12 @@ UpdateHaloOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
 //===----------------------------------------------------------------------===//
 
 #define GET_OP_CLASSES
-#include "mlir/Dialect/Mesh/IR/MeshOps.cpp.inc"
+#include "mlir/Dialect/Shard/IR/ShardOps.cpp.inc"
 
 #define GET_ATTRDEF_CLASSES
-#include "mlir/Dialect/Mesh/IR/MeshAttributes.cpp.inc"
+#include "mlir/Dialect/Shard/IR/ShardAttributes.cpp.inc"
 
 #define GET_TYPEDEF_CLASSES
-#include "mlir/Dialect/Mesh/IR/MeshTypes.cpp.inc"
+#include "mlir/Dialect/Shard/IR/ShardTypes.cpp.inc"
 
-#include "mlir/Dialect/Mesh/IR/MeshEnums.cpp.inc"
+#include "mlir/Dialect/Shard/IR/ShardEnums.cpp.inc"
diff --git a/mlir/lib/Dialect/Mesh/Interfaces/CMakeLists.txt b/mlir/lib/Dialect/Shard/Interfaces/CMakeLists.txt
similarity index 76%
rename from mlir/lib/Dialect/Mesh/Interfaces/CMakeLists.txt
rename to mlir/lib/Dialect/Shard/Interfaces/CMakeLists.txt
index afe76b539846a..01e8e56dd391d 100644
--- a/mlir/lib/Dialect/Mesh/Interfaces/CMakeLists.txt
+++ b/mlir/lib/Dialect/Shard/Interfaces/CMakeLists.txt
@@ -2,7 +2,7 @@ add_mlir_library(MLIRShardingInterface
   ShardingInterface.cpp
 
   ADDITIONAL_HEADER_DIRS
-  ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/Mesh
+  ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/Shard
 
   DEPENDS
   MLIRShardingInterfaceIncGen
@@ -10,7 +10,7 @@ add_mlir_library(MLIRShardingInterface
   LINK_LIBS PUBLIC
   MLIRDialectUtils
   MLIRIR
-  MLIRMeshDialect
+  MLIRShardDialect
   MLIRTensorDialect
   MLIRSupport
 )
diff --git a/mlir/lib/Dialect/Mesh/Interfaces/ShardingInterface.cpp b/mlir/lib/Dialect/Shard/Interfaces/ShardingInterface.cpp
similarity index 70%
rename from mlir/lib/Dialect/Mesh/Interfaces/ShardingInterface.cpp
rename to mlir/lib/Dialect/Shard/Interfaces/ShardingInterface.cpp
index 6b3d49e08b549..d4e76189f7b8a 100644
--- a/mlir/lib/Dialect/Mesh/Interfaces/ShardingInterface.cpp
+++ b/mlir/lib/Dialect/Shard/Interfaces/ShardingInterface.cpp
@@ -6,10 +6,10 @@
 //
 //===----------------------------------------------------------------------===//
 
-#include "mlir/Dialect/Mesh/Interfaces/ShardingInterface.h"
-#include "mlir/Dialect/Mesh/Interfaces/ShardingInterfaceImpl.h"
+#include "mlir/Dialect/Shard/Interfaces/ShardingInterface.h"
+#include "mlir/Dialect/Shard/Interfaces/ShardingInterfaceImpl.h"
 
-#include "mlir/Dialect/Mesh/IR/MeshOps.h"
+#include "mlir/Dialect/Shard/IR/ShardOps.h"
 #include "mlir/IR/AffineMap.h"
 #include "mlir/IR/IRMapping.h"
 #include "mlir/Support/LLVM.h"
@@ -24,9 +24,9 @@
 #define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE << "]: ")
 
 using namespace mlir;
-using namespace mlir::mesh;
+using namespace mlir::shard;
 
-#include "mlir/Dialect/Mesh/Interfaces/ShardingInterface.cpp.inc"
+#include "mlir/Dialect/Shard/Interfaces/ShardingInterface.cpp.inc"
 
 //===----------------------------------------------------------------------===//
 // common util functions
@@ -93,40 +93,39 @@ checkOperandAffineExpr(AffineExpr expr, unsigned numDims) {
 }
 
 template <typename T>
-SmallVector<MeshAxesAttr>
+SmallVector<GridAxesAttr>
 fromArrayOfVector(MLIRContext *ctxt, const SmallVector<SmallVector<T>> &vec) {
-  SmallVector<MeshAxesAttr> res;
+  SmallVector<GridAxesAttr> res;
   for (const auto &v : vec) {
-    res.emplace_back(MeshAxesAttr::get(ctxt, v));
+    res.emplace_back(GridAxesAttr::get(ctxt, v));
   }
   return res;
 }
 
 //===----------------------------------------------------------------------===//
-// mesh::getMeshSharding
+// shard::getSharding
 //===----------------------------------------------------------------------===//
 
-FailureOr<std::pair<bool, MeshSharding>>
-mesh::getMeshSharding(OpResult result) {
+FailureOr<std::pair<bool, Sharding>> shard::getSharding(OpResult result) {
   Value val = cast<Value>(result);
   bool anyShardedForDef = llvm::any_of(val.getUsers(), [](Operation *user) {
-    auto shardOp = llvm::dyn_cast<mesh::ShardOp>(user);
+    auto shardOp = llvm::dyn_cast<shard::ShardOp>(user);
     if (!shardOp)
       return false;
     return !shardOp.getAnnotateForUsers();
   });
 
   if (anyShardedForDef) {
-    // expected to have exact one use if it has a use of `mesh.shard` without
+    // expected to have exactly one use if it has a use of `shard.shard` without
     // unit attr annotate_for_users
     if (!val.hasOneUse())
       return failure();
-    auto shardOp = llvm::cast<mesh::ShardOp>(*val.getUsers().begin());
-    return std::make_pair(false, MeshSharding(shardOp.getSharding()));
+    auto shardOp = llvm::cast<shard::ShardOp>(*val.getUsers().begin());
+    return std::make_pair(false, Sharding(shardOp.getSharding()));
   }
 
   bool anyShardedForUsers = llvm::any_of(val.getUsers(), [](Operation *user) {
-    auto shardOp = llvm::dyn_cast<mesh::ShardOp>(user);
+    auto shardOp = llvm::dyn_cast<shard::ShardOp>(user);
     if (!shardOp)
       return false;
     return shardOp.getAnnotateForUsers();
@@ -138,24 +137,23 @@ mesh::getMeshSharding(OpResult result) {
       if (shardOp)
         shardOps.push_back(shardOp);
     }
-    MeshSharding shardForDef = shardOps[0].getSharding();
+    Sharding shardForDef = shardOps[0].getSharding();
     for (size_t i = 1; i < shardOps.size(); ++i) {
-      // TODO: Deduce a reasonable mesh sharding attr for def when they are
+      // TODO: Deduce a reasonable grid sharding attr for def when they are
       // different
       assert(shardForDef == shardOps[i].getSharding() &&
-             "only support all shard ops have the same mesh sharding attr");
+             "only support all shard ops have the same grid sharding attr");
     }
     return std::make_pair(true, shardForDef);
   }
   return failure();
 }
 
-FailureOr<std::pair<bool, MeshSharding>>
-mesh::getMeshSharding(OpOperand &opOperand) {
+FailureOr<std::pair<bool, Sharding>> shard::getSharding(OpOperand &opOperand) {
   Value val = opOperand.get();
   if (ShardOp shardOp = val.getDefiningOp<ShardOp>())
     return std::make_pair(shardOp.getAnnotateForUsers(),
-                          MeshSharding(shardOp.getSharding()));
+                          Sharding(shardOp.getSharding()));
 
   return failure();
 }
@@ -164,7 +162,7 @@ mesh::getMeshSharding(OpOperand &opOperand) {
 // ShardingInterface::verifyShardingInterfaceImpl
 //===----------------------------------------------------------------------===//
 
-LogicalResult mesh::ShardingInterface::verifyShardingInterfaceImpl() {
+LogicalResult shard::ShardingInterface::verifyShardingInterfaceImpl() {
   Operation *op = getOperation();
 
   // check operands and results type
@@ -201,7 +199,7 @@ LogicalResult mesh::ShardingInterface::verifyShardingInterfaceImpl() {
 // ShardingInterface::printLoopTypesAndIndexingMaps
 //===----------------------------------------------------------------------===//
 
-void mesh::ShardingInterface::printLoopTypesAndIndexingMaps(raw_ostream &os) {
+void shard::ShardingInterface::printLoopTypesAndIndexingMaps(raw_ostream &os) {
   os << "print loop types and indexing maps for: \n";
   getOperation()->print(os);
   os << "\n";
@@ -222,15 +220,15 @@ void mesh::ShardingInterface::printLoopTypesAndIndexingMaps(raw_ostream &os) {
 
 namespace {
 
-// Update the given `shardingOption` according to `meshAxes` and `loopIdx`
+// Update the given `shardingOption` according to `gridAxes` and `loopIdx`
 static LogicalResult fillShardingOption(Operation *op,
                                         ShardingOption &shardingOption,
-                                        FlatSymbolRefAttr mesh,
-                                        ArrayRef<MeshAxis> meshAxes,
+                                        FlatSymbolRefAttr grid,
+                                        ArrayRef<GridAxis> gridAxes,
                                         unsigned loopIdx) {
-  if ((shardingOption.mesh && mesh && shardingOption.mesh != mesh) ||
+  if ((shardingOption.grid && grid && shardingOption.grid != grid) ||
       (!shardingOption.shardingArray[loopIdx].empty() &&
-       shardingOption.shardingArray[loopIdx] != meshAxes)) {
+       shardingOption.shardingArray[loopIdx] != gridAxes)) {
     LLVM_DEBUG(DBGS() << "sharding option conflicts on loop iterator "
                       << loopIdx << "\n");
     return failure();
@@ -239,28 +237,28 @@ static LogicalResult fillShardingOption(Operation *op,
     if (i == loopIdx)
       continue;
 
-    for (MeshAxis axis : meshAxes) {
+    for (GridAxis axis : gridAxes) {
       if (llvm::is_contained(shardingOption.shardingArray[i], axis)) {
-        LLVM_DEBUG(DBGS() << "sharding option conflicts because mesh axes "
+        LLVM_DEBUG(DBGS() << "sharding option conflicts because grid axes "
                           << axis << " duplicate");
         return failure();
       }
     }
   }
-  if (mesh)
-    shardingOption.mesh = mesh;
+  if (grid)
+    shardingOption.grid = grid;
   if (shardingOption.shardingArray[loopIdx].empty())
-    shardingOption.shardingArray[loopIdx].append(meshAxes.begin(),
-                                                 meshAxes.end());
+    shardingOption.shardingArray[loopIdx].append(gridAxes.begin(),
+                                                 gridAxes.end());
   return success();
 }
 
 } // namespace
 
 FailureOr<ShardingOption>
-mesh::detail::defaultGetShardingOption(Operation *op,
-                                       ArrayRef<MeshSharding> operandShardings,
-                                       ArrayRef<MeshSharding> resultShardings) {
+shard::detail::defaultGetShardingOption(Operation *op,
+                                        ArrayRef<Sharding> operandShardings,
+                                        ArrayRef<Sharding> resultShardings) {
   ShardingInterface shardingOp = llvm::cast<ShardingInterface>(op);
   ShardingOption shardingOption;
 
@@ -276,25 +274,25 @@ mesh::detail::defaultGetShardingOption(Operation *op,
 
   // 1. Fill sharding option based on op results
   for (auto shardingIt : llvm::enumerate(resultShardings)) {
-    MeshSharding shardAttr = shardingIt.value();
+    Sharding shardAttr = shardingIt.value();
     if (!shardAttr)
       continue;
     AffineMap map = maps[numOperands + shardingIt.index()];
     anyShardingInResultsOrOperands = true;
     if (shardAttr.getSplitAxes().empty() || map.getResults().empty()) {
-      shardingOption.mesh = shardAttr.getMeshAttr();
+      shardingOption.grid = shardAttr.getGridAttr();
     } else {
       // Handle the split axes: calculate the corresponding loop index for each
       // split axes sub-array, and then store the sub-array to
       // shardingOption[index]
       for (auto it : llvm::zip(map.getResults(), shardAttr.getSplitAxes())) {
         AffineExpr expr = std::get<0>(it);
-        ArrayRef<MeshAxis> axes = std::get<1>(it).asArrayRef();
+        ArrayRef<GridAxis> axes = std::get<1>(it).asArrayRef();
         auto dim = cast<AffineDimExpr>(expr);
         unsigned index = dim.getPosition();
         visitedLoopIndices.insert(index);
         if (failed(fillShardingOption(op, shardingOption,
-                                      shardAttr.getMeshAttr(), axes, index)))
+                                      shardAttr.getGridAttr(), axes, index)))
           return failure();
       }
     }
@@ -302,7 +300,7 @@ mesh::detail::defaultGetShardingOption(Operation *op,
 
   // 2. Fill sharding option based on operands
   for (auto shardingIt : llvm::enumerate(operandShardings)) {
-    MeshSharding shardAttr = shardingIt.value();
+    Sharding shardAttr = shardingIt.value();
     if (!shardAttr)
       continue;
 
@@ -316,7 +314,7 @@ mesh::detail::defaultGetShardingOption(Operation *op,
     // then the operands with multiple loop indices.
     for (auto it : llvm::zip(map.getResults(), shardAttr.getSplitAxes())) {
       AffineExpr expr = std::get<0>(it);
-      ArrayRef<MeshAxis> axes = std::get<1>(it).asArrayRef();
+      ArrayRef<GridAxis> axes = std::get<1>(it).asArrayRef();
       FailureOr<llvm::SmallSet<unsigned, 2>> loopIndices =
           checkOperandAffineExpr(expr, numDims);
       if (failed(loopIndices))
@@ -329,7 +327,7 @@ mesh::detail::defaultGetShardingOption(Operation *op,
         unsigned loopIdx = *loopIndices->begin();
         visitedLoopIndices.insert(loopIdx);
         if (failed(fillShardingOption(op, shardingOption,
-                                      shardAttr.getMeshAttr(), axes, loopIdx)))
+                                      shardAttr.getGridAttr(), axes, loopIdx)))
           return failure();
       }
       // If multiple loop indices correspond to a dimension of an operand, it is
@@ -361,11 +359,11 @@ mesh::detail::defaultGetShardingOption(Operation *op,
 }
 
 // Get the sharding for the given result and sharding option.
-MeshSharding getSharding(OpResult result, const ShardingOption &shardingOption,
-                         AffineMap map,
-                         ArrayRef<utils::IteratorType> loopTypes) {
+static Sharding getSharding(OpResult result,
+                            const ShardingOption &shardingOption, AffineMap map,
+                            ArrayRef<utils::IteratorType> loopTypes) {
   auto resultType = cast<RankedTensorType>(result.getType());
-  SmallVector<SmallVector<MeshAxis>> splitAxes(resultType.getRank());
+  SmallVector<SmallVector<GridAxis>> splitAxes(resultType.getRank());
 
   // process the split axes
   for (auto it : llvm::enumerate(map.getResults())) {
@@ -379,25 +377,25 @@ MeshSharding getSharding(OpResult result, const ShardingOption &shardingOption,
   }
 
   removeTrailingEmptySubArray(splitAxes);
-  return MeshSharding::get(shardingOption.mesh,
-                           fromArrayOfVector(result.getContext(), splitAxes));
+  return Sharding::get(shardingOption.grid,
+                       fromArrayOfVector(result.getContext(), splitAxes));
 }
 
-static FailureOr<MeshSharding> getSharding(OpOperand &opOperand,
-                                           const ShardingOption &shardingOption,
-                                           AffineMap map) {
+static FailureOr<Sharding> getSharding(OpOperand &opOperand,
+                                       const ShardingOption &shardingOption,
+                                       AffineMap map) {
   Value operandValue = opOperand.get();
   auto operandType = dyn_cast<RankedTensorType>(operandValue.getType());
   if (!operandType) {
     if (operandValue.getType().isIntOrIndexOrFloat())
-      return MeshSharding();
+      return Sharding();
     return failure();
   }
   // 0d tensors cannot be sharded and must get replicated
   if (operandType.getRank() == 0) {
-    return MeshSharding(shardingOption.mesh);
+    return Sharding(shardingOption.grid);
   }
-  SmallVector<SmallVector<MeshAxis>> splitAxes(operandType.getRank());
+  SmallVector<SmallVector<GridAxis>> splitAxes(operandType.getRank());
   unsigned numDims = map.getNumDims();
   for (auto it : llvm::enumerate(map.getResults())) {
     int64_t idx = it.index();
@@ -422,15 +420,14 @@ static FailureOr<MeshSharding> getSharding(OpOperand &opOperand,
   }
 
   removeTrailingEmptySubArray(splitAxes);
-  return MeshSharding::get(
-      shardingOption.mesh,
+  return Sharding::get(
+      shardingOption.grid,
       fromArrayOfVector(opOperand.get().getContext(), splitAxes));
 }
 
-FailureOr<std::vector<MeshSharding>>
-mesh::detail::defaultGetShardingAnnotations(
+FailureOr<std::vector<Sharding>> shard::detail::defaultGetShardingAnnotations(
     Operation *op, const ShardingOption &shardingOption) {
-  std::vector<MeshSharding> res;
+  std::vector<Sharding> res;
 
   ShardingInterface shardingOp = llvm::cast<ShardingInterface>(op);
   SmallVector<utils::IteratorType> loopTypes =
@@ -439,7 +436,7 @@ mesh::detail::defaultGetShardingAnnotations(
   unsigned numOperands = op->getNumOperands();
 
   for (OpOperand &opOperand : op->getOpOperands()) {
-    FailureOr<MeshSharding> shardingAttr = getSharding(
+    FailureOr<Sharding> shardingAttr = ::getSharding(
         opOperand, shardingOption, maps[opOperand.getOperandNumber()]);
     if (failed(shardingAttr))
       return failure();
@@ -447,9 +444,9 @@ mesh::detail::defaultGetShardingAnnotations(
   }
 
   for (OpResult result : op->getResults()) {
-    res.push_back(getSharding(result, shardingOption,
-                              maps[numOperands + result.getResultNumber()],
-                              loopTypes));
+    res.push_back(::getSharding(result, shardingOption,
+                                maps[numOperands + result.getResultNumber()],
+                                loopTypes));
   }
 
   return res;
@@ -459,26 +456,25 @@ mesh::detail::defaultGetShardingAnnotations(
 // detail::defaultAddShardingAnnotations
 //===----------------------------------------------------------------------===//
 
-// To add a `mesh.shard` op for the given result, based on the details provided
+// To add a `shard.shard` op for the given result, based on the details provided
 // in `shardingOption`, `map`, and `loopTypes`.
 static LogicalResult addShardOp(OpBuilder &b, OpResult result,
                                 const ShardingOption &shardingOption,
                                 AffineMap map,
                                 ArrayRef<utils::IteratorType> loopTypes) {
-  MeshSharding sharding = getSharding(result, shardingOption, map, loopTypes);
+  Sharding sharding = getSharding(result, shardingOption, map, loopTypes);
   maybeInsertTargetShardingAnnotation(sharding, result, b);
 
   return success();
 }
 
-// To add a `mesh.shard` op for the given operand, based on the details provided
-// in `shardingOption`, `map`, and `loopTypes`.
+// To add a `shard.shard` op for the given operand, based on the details
+// provided in `shardingOption`, `map`, and `loopTypes`.
 static LogicalResult addShardOp(OpBuilder &b, OpOperand &opOperand,
                                 const ShardingOption &shardingOption,
                                 AffineMap map) {
 
-  FailureOr<MeshSharding> sharding =
-      getSharding(opOperand, shardingOption, map);
+  FailureOr<Sharding> sharding = getSharding(opOperand, shardingOption, map);
   if (failed(sharding)) {
     return failure();
   }
@@ -488,9 +484,9 @@ static LogicalResult addShardOp(OpBuilder &b, OpOperand &opOperand,
   return success();
 }
 
-LogicalResult mesh::detail::defaultAddShardingAnnotations(
+LogicalResult shard::detail::defaultAddShardingAnnotations(
     Operation *op, OpBuilder &b, const ShardingOption &shardingOption) {
-  assert(!shardingOption.empty && shardingOption.mesh);
+  assert(!shardingOption.empty && shardingOption.grid);
 
   ShardingInterface shardingOp = llvm::cast<ShardingInterface>(op);
   SmallVector<utils::IteratorType> loopTypes =
@@ -498,7 +494,7 @@ LogicalResult mesh::detail::defaultAddShardingAnnotations(
   SmallVector<AffineMap> maps = shardingOp.getIndexingMaps();
   unsigned numOperands = op->getNumOperands();
 
-  // 1. add mesh.shard ops for all op results
+  // 1. add shard.shard ops for all op results
   for (OpResult result : op->getResults()) {
     if (failed(addShardOp(b, result, shardingOption,
                           maps[numOperands + result.getResultNumber()],
@@ -506,7 +502,7 @@ LogicalResult mesh::detail::defaultAddShardingAnnotations(
       return failure();
   }
 
-  // 2. add mesh.shard ops for all operands
+  // 2. add shard.shard ops for all operands
   for (OpOperand &opOperand : op->getOpOperands()) {
     if (failed(addShardOp(b, opOperand, shardingOption,
                           maps[opOperand.getOperandNumber()])))
@@ -517,9 +513,8 @@ LogicalResult mesh::detail::defaultAddShardingAnnotations(
 }
 
 #ifndef NDEBUG
-static bool
-isValueCompatibleWithFullReplicationSharding(Value value,
-                                             MeshSharding sharding) {
+static bool isValueCompatibleWithFullReplicationSharding(Value value,
+                                                         Sharding sharding) {
   if (isa<RankedTensorType>(value.getType())) {
     return isFullReplication(sharding);
   }
@@ -527,60 +522,59 @@ isValueCompatibleWithFullReplicationSharding(Value value,
   return !sharding;
 }
 
-template <typename ValueRange, typename MeshShardingRage>
+template <typename ValueRange, typename ShardingRange>
 static bool
 areValuesCompatibleWithFullReplicationShardings(ValueRange &&values,
-                                                MeshShardingRage &&shardings) {
+                                                ShardingRange &&shardings) {
   if (std::size(values) != std::size(shardings)) {
     return false;
   }
-  return llvm::all_of(
-      llvm::zip_equal(std::forward<ValueRange>(values),
-                      std::forward<MeshShardingRage>(shardings)),
-      [](auto valueAndSharding) {
-        return isValueCompatibleWithFullReplicationSharding(
-            std::get<0>(valueAndSharding), std::get<1>(valueAndSharding));
-      });
+  return llvm::all_of(llvm::zip_equal(std::forward<ValueRange>(values),
+                                      std::forward<ShardingRange>(shardings)),
+                      [](auto valueAndSharding) {
+                        return isValueCompatibleWithFullReplicationSharding(
+                            std::get<0>(valueAndSharding),
+                            std::get<1>(valueAndSharding));
+                      });
 }
 #endif // NDEBUG
 
-void mesh::spmdizeFullyReplicatedOperation(
-    Operation &op, ArrayRef<Value> spmdizedOperands,
-    ArrayRef<MeshSharding> operandShardings,
-    ArrayRef<MeshSharding> resultShardings, IRMapping &spmdizationMap,
-    SymbolTableCollection &symbolTable, OpBuilder &builder) {
-  assert(spmdizedOperands.size() == operandShardings.size());
+void shard::partitionFullyReplicatedOperation(
+    Operation &op, ArrayRef<Value> partitionedOperands,
+    ArrayRef<Sharding> operandShardings, ArrayRef<Sharding> resultShardings,
+    IRMapping &partitionMap, SymbolTableCollection &symbolTable,
+    OpBuilder &builder) {
+  assert(partitionedOperands.size() == operandShardings.size());
   assert(areValuesCompatibleWithFullReplicationShardings(op.getOperands(),
                                                          operandShardings));
   assert(areValuesCompatibleWithFullReplicationShardings(op.getResults(),
                                                          resultShardings));
   // `clone` will populate the mapping of old to new results.
-  builder.clone(op, spmdizationMap);
+  builder.clone(op, partitionMap);
 }
 
-static void updateMeshAxisAssignmentForLoopIterators(
-    ArrayRef<MeshAxis> meshAxesAssignmentForTensorAxis, AffineExpr indexingExpr,
-    SmallVector<std::optional<SmallVector<MeshAxis>>>
-        &meshAxesAssignmentForLoopIterators) {
+static void updateGridAxisAssignmentForLoopIterators(
+    ArrayRef<GridAxis> gridAxesAssignmentForTensorAxis, AffineExpr indexingExpr,
+    SmallVector<std::optional<SmallVector<GridAxis>>>
+        &gridAxesAssignmentForLoopIterators) {
   AffineDimExpr affineDimExpr = cast<AffineDimExpr>(indexingExpr);
   unsigned loopIteratorIdx = affineDimExpr.getPosition();
-  if (meshAxesAssignmentForLoopIterators[loopIteratorIdx]) {
-    assert(llvm::equal(meshAxesAssignmentForTensorAxis,
-                       *meshAxesAssignmentForLoopIterators[loopIteratorIdx]));
+  if (gridAxesAssignmentForLoopIterators[loopIteratorIdx]) {
+    assert(llvm::equal(gridAxesAssignmentForTensorAxis,
+                       *gridAxesAssignmentForLoopIterators[loopIteratorIdx]));
   } else {
-    meshAxesAssignmentForLoopIterators[loopIteratorIdx] =
-        llvm::to_vector(meshAxesAssignmentForTensorAxis);
+    gridAxesAssignmentForLoopIterators[loopIteratorIdx] =
+        llvm::to_vector(gridAxesAssignmentForTensorAxis);
   }
 }
 
-ShardingArray mesh::getMeshAxisAssignmentForLoopIterators(
-    ArrayRef<MeshSharding> operandShardings,
-    ArrayRef<MeshSharding> resultShardings,
+ShardingArray shard::getGridAxisAssignmentForLoopIterators(
+    ArrayRef<Sharding> operandShardings, ArrayRef<Sharding> resultShardings,
     ArrayRef<utils::IteratorType> loopIteratorTypes,
     ArrayRef<AffineMap> indexingMaps) {
-  SmallVector<std::optional<SmallVector<MeshAxis>>>
-      meshAxisAssignmentForLoopIterators(loopIteratorTypes.size());
-  std::vector<MeshSharding> operatorAndResultShardings;
+  SmallVector<std::optional<SmallVector<GridAxis>>>
+      gridAxisAssignmentForLoopIterators(loopIteratorTypes.size());
+  std::vector<Sharding> operatorAndResultShardings;
   operatorAndResultShardings.reserve(operandShardings.size() +
                                      resultShardings.size());
   llvm::append_range(operatorAndResultShardings, operandShardings);
@@ -589,69 +583,69 @@ ShardingArray mesh::getMeshAxisAssignmentForLoopIterators(
     if (!sharding) {
       continue;
     }
-    for (auto [meshAxesAssignmentForTensorAxis, indexingExpr] :
+    for (auto [gridAxesAssignmentForTensorAxis, indexingExpr] :
          llvm::zip(sharding.getSplitAxes(), affineMap.getResults())) {
-      updateMeshAxisAssignmentForLoopIterators(
-          meshAxesAssignmentForTensorAxis.asArrayRef(), indexingExpr,
-          meshAxisAssignmentForLoopIterators);
+      updateGridAxisAssignmentForLoopIterators(
+          gridAxesAssignmentForTensorAxis.asArrayRef(), indexingExpr,
+          gridAxisAssignmentForLoopIterators);
     }
     // Missing trailing split axes means replication on those tensor dimensions.
     for (unsigned i = sharding.getSplitAxes().size();
          i < affineMap.getNumResults(); ++i) {
-      updateMeshAxisAssignmentForLoopIterators(
-          {}, affineMap.getResults()[i], meshAxisAssignmentForLoopIterators);
+      updateGridAxisAssignmentForLoopIterators(
+          {}, affineMap.getResults()[i], gridAxisAssignmentForLoopIterators);
     }
   }
 
   ShardingArray res;
-  llvm::transform(meshAxisAssignmentForLoopIterators, std::back_inserter(res),
-                  [](std::optional<SmallVector<MeshAxis>> &axes) {
+  llvm::transform(gridAxisAssignmentForLoopIterators, std::back_inserter(res),
+                  [](std::optional<SmallVector<GridAxis>> &axes) {
                     if (!axes) {
-                      return SmallVector<MeshAxis>();
+                      return SmallVector<GridAxis>();
                     };
                     return std::move(*axes);
                   });
   return res;
 }
 
-bool mesh::isAtLeastOneReductionIteratorSharded(
+bool shard::isAtLeastOneReductionIteratorSharded(
     ArrayRef<utils::IteratorType> loopIteratorTypes,
-    ArrayRef<SmallVector<MeshAxis>> meshAxisAssignmentForLoopIterators) {
-  for (auto [loopIteratorType, meshAxisAssignment] :
-       llvm::zip_equal(loopIteratorTypes, meshAxisAssignmentForLoopIterators)) {
+    ArrayRef<SmallVector<GridAxis>> gridAxisAssignmentForLoopIterators) {
+  for (auto [loopIteratorType, gridAxisAssignment] :
+       llvm::zip_equal(loopIteratorTypes, gridAxisAssignmentForLoopIterators)) {
     if (loopIteratorType == utils::IteratorType::reduction &&
-        !meshAxisAssignment.empty()) {
+        !gridAxisAssignment.empty()) {
       return true;
     }
   }
   return false;
 }
 
-SmallVector<MeshAxis> mesh::getReductionMeshAxes(
+SmallVector<GridAxis> shard::getReductionGridAxes(
     ArrayRef<utils::IteratorType> loopIteratorTypes,
-    ArrayRef<SmallVector<MeshAxis>> meshAxisAssignmentForLoopIterators) {
-  SmallVector<MeshAxis> meshAxes;
-  for (auto [loopIteratorType, meshAxisAssignment] :
-       llvm::zip_equal(loopIteratorTypes, meshAxisAssignmentForLoopIterators)) {
+    ArrayRef<SmallVector<GridAxis>> gridAxisAssignmentForLoopIterators) {
+  SmallVector<GridAxis> gridAxes;
+  for (auto [loopIteratorType, gridAxisAssignment] :
+       llvm::zip_equal(loopIteratorTypes, gridAxisAssignmentForLoopIterators)) {
     if (loopIteratorType == utils::IteratorType::reduction) {
-      llvm::append_range(meshAxes, meshAxisAssignment);
+      llvm::append_range(gridAxes, gridAxisAssignment);
     }
   }
-  return meshAxes;
+  return gridAxes;
 }
 
-void mesh::spmdizeTriviallyShardableOperation(
-    Operation &op, ArrayRef<Value> spmdizedOperands,
-    ArrayRef<MeshSharding> operandShardings,
-    ArrayRef<MeshSharding> resultShardings, IRMapping &spmdizationMap,
-    SymbolTableCollection &symbolTable, OpBuilder &builder) {
+void shard::partitionTriviallyShardableOperation(
+    Operation &op, ArrayRef<Value> partitionedOperands,
+    ArrayRef<Sharding> operandShardings, ArrayRef<Sharding> resultShardings,
+    IRMapping &partitionMap, SymbolTableCollection &symbolTable,
+    OpBuilder &builder) {
   // `clone` will populate the mapping of old to new results.
-  Operation *newOp = builder.clone(op, spmdizationMap);
+  Operation *newOp = builder.clone(op, partitionMap);
   // Set the result types to the sharded counterparts.
   for (auto [oldResult, newResult, sharding] :
        llvm::zip_equal(op.getResults(), newOp->getResults(), resultShardings)) {
     newResult.setType(shardType(
         newResult.getType(),
-        getMeshOrNull(&op, sharding.getMeshAttr(), symbolTable), sharding));
+        getGridOrNull(&op, sharding.getGridAttr(), symbolTable), sharding));
   }
 }
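
The free functions above keep their semantics under the new names: shard::getSharding(OpOperand &) still returns FailureOr<std::pair<bool, Sharding>>, where the bool is the annotate_for_users flag of the shard.shard annotation. A small, illustrative helper (not part of this patch; the function name is hypothetical) shows how downstream code can query an operand's sharding after the rename.

// Illustrative sketch only. shard::getSharding and shard::Sharding are the
// renamed entities from this file; "lookupOperandSharding" is hypothetical.
#include "mlir/Dialect/Shard/IR/ShardOps.h"
#include "mlir/Dialect/Shard/Interfaces/ShardingInterface.h"

#include <utility>

using namespace mlir;

// Returns the sharding annotated on `operand` through a defining shard.shard
// op, or failure if no such annotation exists.
static FailureOr<shard::Sharding> lookupOperandSharding(OpOperand &operand) {
  FailureOr<std::pair<bool, shard::Sharding>> annotated =
      shard::getSharding(operand);
  if (failed(annotated))
    return failure();
  // annotated->first is the annotate_for_users flag; callers that only need
  // the sharding itself can ignore it.
  return annotated->second;
}
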
diff --git a/mlir/lib/Dialect/Mesh/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Shard/Transforms/CMakeLists.txt
similarity index 73%
rename from mlir/lib/Dialect/Mesh/Transforms/CMakeLists.txt
rename to mlir/lib/Dialect/Shard/Transforms/CMakeLists.txt
index 381bc9afede07..a884764e70e92 100644
--- a/mlir/lib/Dialect/Mesh/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Shard/Transforms/CMakeLists.txt
@@ -1,14 +1,14 @@
-add_mlir_dialect_library(MLIRMeshTransforms
+add_mlir_dialect_library(MLIRShardTransforms
   Simplifications.cpp
   ShardingPropagation.cpp
-  Spmdization.cpp
+  Partition.cpp
   Transforms.cpp
 
   ADDITIONAL_HEADER_DIRS
-  ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/Mesh
+  ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/Shard
 
   DEPENDS
-  MLIRMeshPassIncGen
+  MLIRShardPassIncGen
   MLIRShardingInterface
 
   LINK_LIBS PUBLIC
@@ -21,7 +21,7 @@ add_mlir_dialect_library(MLIRMeshTransforms
   MLIRFuncDialect
   MLIRFunctionInterfaces
   MLIRIR
-  MLIRMeshDialect
+  MLIRShardDialect
   MLIRPass
   MLIRSupport
   MLIRTensorDialect
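
The Partition.cpp hunks below (the former Spmdization.cpp) describe reshardings with the split-axes notation of the sharding attribute, e.g. [[0, 1]] -> [[0, 1, 2]]: tensor axis 0 is split over grid axes 0 and 1 in the source, and additionally over grid axis 2 in the target. As an illustrative aid (not part of this patch; the grid symbol name "grid0" and the helper names are hypothetical), that source/target pair can be written with the renamed types as follows.

// Illustrative sketch only. GridAxesAttr, Sharding and Sharding::get are the
// renamed entities used throughout Partition.cpp below.
#include "mlir/Dialect/Shard/IR/ShardOps.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/MLIRContext.h"

using namespace mlir;

// Source sharding [[0, 1]]: tensor axis 0 split over grid axes 0 and 1,
// remaining tensor axes replicated.
static shard::Sharding exampleSourceSharding(MLIRContext *ctx) {
  auto grid = FlatSymbolRefAttr::get(ctx, "grid0"); // hypothetical grid symbol
  return shard::Sharding::get(grid, {shard::GridAxesAttr::get(ctx, {0, 1})});
}

// Target sharding [[0, 1, 2]]: tensor axis 0 additionally split over grid
// axis 2; splitLastAxisInResharding below realizes this with a shard.all_slice
// on grid axis 2.
static shard::Sharding exampleTargetSharding(MLIRContext *ctx) {
  auto grid = FlatSymbolRefAttr::get(ctx, "grid0");
  return shard::Sharding::get(grid,
                              {shard::GridAxesAttr::get(ctx, {0, 1, 2})});
}
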
diff --git a/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp b/mlir/lib/Dialect/Shard/Transforms/Partition.cpp
similarity index 66%
rename from mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp
rename to mlir/lib/Dialect/Shard/Transforms/Partition.cpp
index c6e76ecae0f21..621a43a211206 100644
--- a/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp
+++ b/mlir/lib/Dialect/Shard/Transforms/Partition.cpp
@@ -1,4 +1,4 @@
-//===- Spmdization.cpp --------------------------------------------- C++ --===//
+//===- Partition.cpp --------------------------------------------- C++ --===//
 //
 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
 // See https://llvm.org/LICENSE.txt for license information.
@@ -6,11 +6,11 @@
 //
 //===----------------------------------------------------------------------===//
 
-#include "mlir/Dialect/Mesh/Transforms/Spmdization.h"
+#include "mlir/Dialect/Shard/Transforms/Partition.h"
 
-#include "mlir/Dialect/Mesh/IR/MeshDialect.h"
-#include "mlir/Dialect/Mesh/IR/MeshOps.h"
-#include "mlir/Dialect/Mesh/Interfaces/ShardingInterface.h"
+#include "mlir/Dialect/Shard/IR/ShardDialect.h"
+#include "mlir/Dialect/Shard/IR/ShardOps.h"
+#include "mlir/Dialect/Shard/Interfaces/ShardingInterface.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "mlir/IR/Builders.h"
 #include "mlir/IR/BuiltinAttributes.h"
@@ -33,7 +33,7 @@
 #include <optional>
 #include <tuple>
 
-namespace mlir::mesh {
+namespace mlir::shard {
 
 template <typename SourceAxes, typename TargetAxes>
 static bool arePartialAxesCompatible(const SourceAxes &sourceAxes,
@@ -43,52 +43,51 @@ static bool arePartialAxesCompatible(const SourceAxes &sourceAxes,
   });
 }
 
-static MeshSharding targetShardingInSplitLastAxis(MLIRContext *ctx,
-                                                  MeshSharding sourceSharding,
-                                                  int64_t splitTensorAxis,
-                                                  MeshAxis splitMeshAxis) {
-  SmallVector<MeshAxesAttr> targetShardingSplitAxes =
+static Sharding targetShardingInSplitLastAxis(MLIRContext *ctx,
+                                              Sharding sourceSharding,
+                                              int64_t splitTensorAxis,
+                                              GridAxis splitGridAxis) {
+  SmallVector<GridAxesAttr> targetShardingSplitAxes =
       llvm::to_vector(sourceSharding.getSplitAxes());
   while (static_cast<int64_t>(targetShardingSplitAxes.size()) <=
          splitTensorAxis) {
-    targetShardingSplitAxes.push_back(MeshAxesAttr::get(ctx, {}));
+    targetShardingSplitAxes.push_back(GridAxesAttr::get(ctx, {}));
   }
   auto targetSplitAxes =
       llvm::to_vector(targetShardingSplitAxes[splitTensorAxis].asArrayRef());
-  targetSplitAxes.push_back(splitMeshAxis);
+  targetSplitAxes.push_back(splitGridAxis);
   targetShardingSplitAxes[splitTensorAxis] =
-      MeshAxesAttr::get(ctx, targetSplitAxes);
-  return MeshSharding::get(sourceSharding.getMeshAttr(),
-                           targetShardingSplitAxes);
+      GridAxesAttr::get(ctx, targetSplitAxes);
+  return Sharding::get(sourceSharding.getGridAttr(), targetShardingSplitAxes);
 }
 
-// Split a replicated tensor along a mesh axis.
+// Split a replicated tensor along a grid axis.
 // E.g. [[0, 1]] -> [[0, 1, 2]].
-// Returns the spmdized target value with its sharding.
-static std::tuple<TypedValue<ShapedType>, MeshSharding>
+// Returns the partitioned target value with its sharding.
+static std::tuple<TypedValue<ShapedType>, Sharding>
 splitLastAxisInResharding(ImplicitLocOpBuilder &builder,
-                          MeshSharding sourceSharding,
-                          TypedValue<ShapedType> sourceShard, MeshOp mesh,
-                          int64_t splitTensorAxis, MeshAxis splitMeshAxis) {
+                          Sharding sourceSharding,
+                          TypedValue<ShapedType> sourceShard, GridOp grid,
+                          int64_t splitTensorAxis, GridAxis splitGridAxis) {
   TypedValue<ShapedType> targetShard = cast<TypedValue<ShapedType>>(
       builder
-          .create<AllSliceOp>(sourceShard, mesh,
-                              ArrayRef<MeshAxis>(splitMeshAxis),
+          .create<AllSliceOp>(sourceShard, grid,
+                              ArrayRef<GridAxis>(splitGridAxis),
                               splitTensorAxis)
           .getResult());
-  MeshSharding targetSharding = targetShardingInSplitLastAxis(
-      builder.getContext(), sourceSharding, splitTensorAxis, splitMeshAxis);
+  Sharding targetSharding = targetShardingInSplitLastAxis(
+      builder.getContext(), sourceSharding, splitTensorAxis, splitGridAxis);
   return {targetShard, targetSharding};
 }
 
 // Detect if the resharding is of type e.g.
 // [[0, 1]] -> [[0, 1, 2]].
-// If detected, returns the corresponding tensor axis mesh axis pair.
+// If detected, returns the corresponding (tensor axis, grid axis) pair.
 // Does not detect insertions like
 // [[0, 1]] -> [[0, 2, 1]].
-static std::optional<std::tuple<int64_t, MeshAxis>>
-detectSplitLastAxisInResharding(MeshSharding sourceSharding,
-                                MeshSharding targetSharding) {
+static std::optional<std::tuple<int64_t, GridAxis>>
+detectSplitLastAxisInResharding(Sharding sourceSharding,
+                                Sharding targetSharding) {
   for (size_t tensorAxis = 0; tensorAxis < targetSharding.getSplitAxes().size();
        ++tensorAxis) {
     if (sourceSharding.getSplitAxes().size() > tensorAxis) {
@@ -118,16 +117,15 @@ detectSplitLastAxisInResharding(MeshSharding sourceSharding,
   return std::nullopt;
 }
 
-static std::optional<std::tuple<TypedValue<ShapedType>, MeshSharding>>
-trySplitLastAxisInResharding(ImplicitLocOpBuilder &builder, MeshOp mesh,
-                             MeshSharding sourceSharding,
-                             MeshSharding targetSharding,
+static std::optional<std::tuple<TypedValue<ShapedType>, Sharding>>
+trySplitLastAxisInResharding(ImplicitLocOpBuilder &builder, GridOp grid,
+                             Sharding sourceSharding, Sharding targetSharding,
                              TypedValue<ShapedType> sourceShard) {
   if (auto detectRes =
           detectSplitLastAxisInResharding(sourceSharding, targetSharding)) {
-    auto [tensorAxis, meshAxis] = detectRes.value();
-    return splitLastAxisInResharding(builder, sourceSharding, sourceShard, mesh,
-                                     tensorAxis, meshAxis);
+    auto [tensorAxis, gridAxis] = detectRes.value();
+    return splitLastAxisInResharding(builder, sourceSharding, sourceShard, grid,
+                                     tensorAxis, gridAxis);
   }
 
   return std::nullopt;
@@ -135,10 +133,10 @@ trySplitLastAxisInResharding(ImplicitLocOpBuilder &builder, MeshOp mesh,
 
 // Detect if the resharding is of type e.g.
 // [[0, 1, 2]] -> [[0, 1]].
-// If detected, returns the corresponding tensor axis mesh axis pair.
-static std::optional<std::tuple<int64_t, MeshAxis>>
-detectUnsplitLastAxisInResharding(MeshSharding sourceSharding,
-                                  MeshSharding targetSharding) {
+// If detected, returns the corresponding (tensor axis, grid axis) pair.
+static std::optional<std::tuple<int64_t, GridAxis>>
+detectUnsplitLastAxisInResharding(Sharding sourceSharding,
+                                  Sharding targetSharding) {
   for (size_t tensorAxis = 0; tensorAxis < sourceSharding.getSplitAxes().size();
        ++tensorAxis) {
     if (targetSharding.getSplitAxes().size() > tensorAxis) {
@@ -165,10 +163,10 @@ detectUnsplitLastAxisInResharding(MeshSharding sourceSharding,
   return std::nullopt;
 }
 
-static MeshSharding targetShardingInUnsplitLastAxis(MLIRContext *ctx,
-                                                    MeshSharding sourceSharding,
-                                                    int64_t splitTensorAxis) {
-  SmallVector<MeshAxesAttr> targetShardingSplitAxes =
+static Sharding targetShardingInUnsplitLastAxis(MLIRContext *ctx,
+                                                Sharding sourceSharding,
+                                                int64_t splitTensorAxis) {
+  SmallVector<GridAxesAttr> targetShardingSplitAxes =
       llvm::to_vector(sourceSharding.getSplitAxes());
   assert(static_cast<int64_t>(targetShardingSplitAxes.size()) >
          splitTensorAxis);
@@ -177,9 +175,8 @@ static MeshSharding targetShardingInUnsplitLastAxis(MLIRContext *ctx,
 
   targetSplitAxes.pop_back();
   targetShardingSplitAxes[splitTensorAxis] =
-      MeshAxesAttr::get(ctx, targetSplitAxes);
-  return MeshSharding::get(sourceSharding.getMeshAttr(),
-                           targetShardingSplitAxes);
+      GridAxesAttr::get(ctx, targetSplitAxes);
+  return Sharding::get(sourceSharding.getGridAttr(), targetShardingSplitAxes);
 }
 
 static ShapedType allGatherResultShapeInUnsplitLastAxis(
@@ -190,45 +187,42 @@ static ShapedType allGatherResultShapeInUnsplitLastAxis(
   return sourceShape.cloneWith(targetShape, sourceShape.getElementType());
 }
 
-static std::tuple<TypedValue<ShapedType>, MeshSharding>
-unsplitLastAxisInResharding(ImplicitLocOpBuilder &builder,
-                            MeshSharding sourceSharding,
-                            ShapedType sourceUnshardedShape,
-                            TypedValue<ShapedType> sourceShard, MeshOp mesh,
-                            int64_t splitTensorAxis, MeshAxis splitMeshAxis) {
+static std::tuple<TypedValue<ShapedType>, Sharding> unsplitLastAxisInResharding(
+    ImplicitLocOpBuilder &builder, Sharding sourceSharding,
+    ShapedType sourceUnshardedShape, TypedValue<ShapedType> sourceShard,
+    GridOp grid, int64_t splitTensorAxis, GridAxis splitGridAxis) {
   MLIRContext *ctx = builder.getContext();
   builder.setInsertionPointAfterValue(sourceShard);
 
-  MeshSharding targetSharding =
+  Sharding targetSharding =
       targetShardingInUnsplitLastAxis(ctx, sourceSharding, splitTensorAxis);
   ShapedType allGatherResultShape = allGatherResultShapeInUnsplitLastAxis(
-      sourceShard.getType(), mesh.getShape()[splitMeshAxis], splitTensorAxis);
+      sourceShard.getType(), grid.getShape()[splitGridAxis], splitTensorAxis);
   Value allGatherResult = AllGatherOp::create(
       builder,
       RankedTensorType::get(allGatherResultShape.getShape(),
                             allGatherResultShape.getElementType()),
-      mesh.getSymName(), SmallVector<MeshAxis>({splitMeshAxis}), sourceShard,
+      grid.getSymName(), SmallVector<GridAxis>({splitGridAxis}), sourceShard,
       APInt(64, splitTensorAxis));
   ShapedType targetShape =
-      shardShapedType(sourceUnshardedShape, mesh, targetSharding);
+      shardShapedType(sourceUnshardedShape, grid, targetSharding);
   TypedValue<ShapedType> targetShard = cast<TypedValue<ShapedType>>(
       tensor::CastOp::create(builder, targetShape, allGatherResult)
           .getResult());
   return {targetShard, targetSharding};
 }
 
-static std::optional<std::tuple<TypedValue<ShapedType>, MeshSharding>>
-tryUnsplitLastAxisInResharding(ImplicitLocOpBuilder &builder, MeshOp mesh,
-                               MeshSharding sourceSharding,
-                               MeshSharding targetSharding,
+static std::optional<std::tuple<TypedValue<ShapedType>, Sharding>>
+tryUnsplitLastAxisInResharding(ImplicitLocOpBuilder &builder, GridOp grid,
+                               Sharding sourceSharding, Sharding targetSharding,
                                ShapedType sourceUnshardedShape,
                                TypedValue<ShapedType> sourceShard) {
   if (auto detectRes =
           detectUnsplitLastAxisInResharding(sourceSharding, targetSharding)) {
-    auto [tensorAxis, meshAxis] = detectRes.value();
+    auto [tensorAxis, gridAxis] = detectRes.value();
     return unsplitLastAxisInResharding(builder, sourceSharding,
-                                       sourceUnshardedShape, sourceShard, mesh,
-                                       tensorAxis, meshAxis);
+                                       sourceUnshardedShape, sourceShard, grid,
+                                       tensorAxis, gridAxis);
   }
 
   return std::nullopt;
@@ -238,10 +232,10 @@ tryUnsplitLastAxisInResharding(ImplicitLocOpBuilder &builder, MeshOp mesh,
 // [[0, 1], [2]] -> [[0], [1, 2]].
 // Only moving the last axis counts.
 // If detected, returns the corresponding (source_tensor_axis,
-// target_tensor_axis, mesh_axis) tuple.
-static std::optional<std::tuple<int64_t, int64_t, MeshAxis>>
-detectMoveLastSplitAxisInResharding(MeshSharding sourceSharding,
-                                    MeshSharding targetSharding) {
+// target_tensor_axis, grid_axis) tuple.
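+// In the example above, grid axis 1 moves from tensor axis 0 to tensor axis 1,
+// so the detected tuple would be (0, 1, 1).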
+static std::optional<std::tuple<int64_t, int64_t, GridAxis>>
+detectMoveLastSplitAxisInResharding(Sharding sourceSharding,
+                                    Sharding targetSharding) {
   for (size_t sourceTensorAxis = 0;
        sourceTensorAxis < sourceSharding.getSplitAxes().size();
        ++sourceTensorAxis) {
@@ -281,33 +275,32 @@ detectMoveLastSplitAxisInResharding(MeshSharding sourceSharding,
   return std::nullopt;
 }
 
-static MeshSharding targetShardingInMoveLastAxis(MLIRContext *ctx,
-                                                 MeshSharding sourceSharding,
-                                                 int64_t sourceTensorAxis,
-                                                 int64_t targetTensorAxis) {
-  SmallVector<MeshAxesAttr> targetShardingSplitAxes =
+static Sharding targetShardingInMoveLastAxis(MLIRContext *ctx,
+                                             Sharding sourceSharding,
+                                             int64_t sourceTensorAxis,
+                                             int64_t targetTensorAxis) {
+  SmallVector<GridAxesAttr> targetShardingSplitAxes =
       llvm::to_vector(sourceSharding.getSplitAxes());
   while (static_cast<int64_t>(targetShardingSplitAxes.size()) <=
          targetTensorAxis) {
-    targetShardingSplitAxes.push_back(MeshAxesAttr::get(ctx, {}));
+    targetShardingSplitAxes.push_back(GridAxesAttr::get(ctx, {}));
   }
 
   auto sourceSplitAxes =
       llvm::to_vector(targetShardingSplitAxes[sourceTensorAxis].asArrayRef());
   assert(!sourceSplitAxes.empty());
-  auto meshAxis = sourceSplitAxes.back();
+  auto gridAxis = sourceSplitAxes.back();
   sourceSplitAxes.pop_back();
   targetShardingSplitAxes[sourceTensorAxis] =
-      MeshAxesAttr::get(ctx, sourceSplitAxes);
+      GridAxesAttr::get(ctx, sourceSplitAxes);
 
   auto targetSplitAxes =
       llvm::to_vector(targetShardingSplitAxes[targetTensorAxis].asArrayRef());
-  targetSplitAxes.push_back(meshAxis);
+  targetSplitAxes.push_back(gridAxis);
   targetShardingSplitAxes[targetTensorAxis] =
-      MeshAxesAttr::get(ctx, targetSplitAxes);
+      GridAxesAttr::get(ctx, targetSplitAxes);
 
-  return MeshSharding::get(sourceSharding.getMeshAttr(),
-                           targetShardingSplitAxes);
+  return Sharding::get(sourceSharding.getGridAttr(), targetShardingSplitAxes);
 }
 
 static ShapedType allToAllResultShapeInMoveLastAxis(ShapedType sourceShape,
@@ -322,46 +315,46 @@ static ShapedType allToAllResultShapeInMoveLastAxis(ShapedType sourceShape,
   return sourceShape.cloneWith(targetShape, sourceShape.getElementType());
 }
 
-static std::tuple<TypedValue<ShapedType>, MeshSharding>
-moveLastSplitAxisInResharding(ImplicitLocOpBuilder &builder, MeshOp mesh,
-                              MeshSharding sourceSharding,
+static std::tuple<TypedValue<ShapedType>, Sharding>
+moveLastSplitAxisInResharding(ImplicitLocOpBuilder &builder, GridOp grid,
+                              Sharding sourceSharding,
                               ShapedType sourceUnshardedShape,
                               TypedValue<ShapedType> sourceShard,
                               int64_t sourceTensorAxis,
-                              int64_t targetTensorAxis, MeshAxis meshAxis) {
+                              int64_t targetTensorAxis, GridAxis gridAxis) {
   MLIRContext *ctx = builder.getContext();
   builder.setInsertionPointAfterValue(sourceShard);
 
-  MeshSharding targetSharding = targetShardingInMoveLastAxis(
+  Sharding targetSharding = targetShardingInMoveLastAxis(
       ctx, sourceSharding, sourceTensorAxis, targetTensorAxis);
   ShapedType allToAllResultShape = allToAllResultShapeInMoveLastAxis(
-      sourceShard.getType(), mesh.getShape()[meshAxis], sourceTensorAxis,
+      sourceShard.getType(), grid.getShape()[gridAxis], sourceTensorAxis,
       targetTensorAxis);
   Value allToAllResult = AllToAllOp::create(
       builder,
       RankedTensorType::get(allToAllResultShape.getShape(),
                             allToAllResultShape.getElementType()),
-      mesh.getSymName(), SmallVector<MeshAxis>({meshAxis}), sourceShard,
+      grid.getSymName(), SmallVector<GridAxis>({gridAxis}), sourceShard,
       APInt(64, targetTensorAxis), APInt(64, sourceTensorAxis));
   ShapedType targetShape =
-      shardShapedType(sourceUnshardedShape, mesh, targetSharding);
+      shardShapedType(sourceUnshardedShape, grid, targetSharding);
   TypedValue<ShapedType> targetShard = cast<TypedValue<ShapedType>>(
       tensor::CastOp::create(builder, targetShape, allToAllResult).getResult());
   return {targetShard, targetSharding};
 }
 
-static std::optional<std::tuple<TypedValue<ShapedType>, MeshSharding>>
-tryMoveLastSplitAxisInResharding(ImplicitLocOpBuilder &builder, MeshOp mesh,
-                                 MeshSharding sourceSharding,
-                                 MeshSharding targetSharding,
+static std::optional<std::tuple<TypedValue<ShapedType>, Sharding>>
+tryMoveLastSplitAxisInResharding(ImplicitLocOpBuilder &builder, GridOp grid,
+                                 Sharding sourceSharding,
+                                 Sharding targetSharding,
                                  ShapedType sourceUnshardedShape,
                                  TypedValue<ShapedType> sourceShard) {
   if (auto detectRes =
           detectMoveLastSplitAxisInResharding(sourceSharding, targetSharding)) {
-    auto [sourceTensorAxis, targetTensorAxis, meshAxis] = detectRes.value();
+    auto [sourceTensorAxis, targetTensorAxis, gridAxis] = detectRes.value();
     return moveLastSplitAxisInResharding(
-        builder, mesh, sourceSharding, sourceUnshardedShape, sourceShard,
-        sourceTensorAxis, targetTensorAxis, meshAxis);
+        builder, grid, sourceSharding, sourceUnshardedShape, sourceShard,
+        sourceTensorAxis, targetTensorAxis, gridAxis);
   }
 
   return std::nullopt;
@@ -371,10 +364,9 @@ tryMoveLastSplitAxisInResharding(ImplicitLocOpBuilder &builder, MeshOp mesh,
 // needed. A change in halo sizes requires copying the "core" of the source tensor
 // into the "core" of the destination tensor followed by an update halo
 // operation.
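+// As an illustrative sketch (assuming a single statically-sized split axis):
+// a local shard of shape 6 with source halo sizes [1, 1] has a core of 4
+// elements; with target halo sizes [2, 2] the core is copied into a tensor of
+// shape 8 at offset 2, and the update-halo op then fills the wider halos from
+// neighboring shards.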
-static std::optional<std::tuple<TypedValue<ShapedType>, MeshSharding>>
-tryUpdateHaloInResharding(ImplicitLocOpBuilder &builder, MeshOp mesh,
-                          MeshSharding sourceSharding,
-                          MeshSharding targetSharding,
+static std::optional<std::tuple<TypedValue<ShapedType>, Sharding>>
+tryUpdateHaloInResharding(ImplicitLocOpBuilder &builder, GridOp grid,
+                          Sharding sourceSharding, Sharding targetSharding,
                           ShapedType sourceUnshardedShape,
                           TypedValue<ShapedType> sourceShard) {
   // Currently handles only cases where halo sizes differ but everything else
@@ -392,7 +384,7 @@ tryUpdateHaloInResharding(ImplicitLocOpBuilder &builder, MeshOp mesh,
   assert(((srcHaloSizes.empty() || ShapedType::isStaticShape(srcHaloSizes)) &&
           ShapedType::isStaticShape(tgtHaloSizes) &&
           sourceShard.getType().hasStaticShape()) &&
-         "dynamic shapes/halos are not supported yet for mesh-spmdization");
+         "dynamic shapes/halos are not supported yet for shard-partition");
   auto rank = sourceShard.getType().getRank();
   auto splitAxes = sourceSharding.getSplitAxes();
   SmallVector<int64_t> srcCoreOffs(rank, 0), tgtCoreOffs(rank, 0),
@@ -433,8 +425,8 @@ tryUpdateHaloInResharding(ImplicitLocOpBuilder &builder, MeshOp mesh,
               sourceShard.getLoc(),
               RankedTensorType::get(outShape,
                                     sourceShard.getType().getElementType()),
-              initOprnd, mesh.getSymName(),
-              MeshAxesArrayAttr::get(builder.getContext(),
+              initOprnd, grid.getSymName(),
+              GridAxesArrayAttr::get(builder.getContext(),
                                      sourceSharding.getSplitAxes()),
               targetSharding.getDynamicHaloSizes(),
               targetSharding.getStaticHaloSizes())
@@ -443,41 +435,41 @@ tryUpdateHaloInResharding(ImplicitLocOpBuilder &builder, MeshOp mesh,
                          targetSharding);
 }
 
-// Handles only resharding on a 1D mesh.
+// Handles only resharding on a 1D grid.
 // Currently the sharded tensor axes must be exactly divisible by the single
-// mesh axis size.
+// grid axis size.
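+// For example, splitting a tensor axis of size 8 across a 1D grid of size 4
+// yields local shards of size 2; an axis of size 10 would not be handled here.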
 static TypedValue<ShapedType>
-reshardOn1DMesh(ImplicitLocOpBuilder &builder, MeshOp mesh,
-                MeshSharding sourceSharding, MeshSharding targetSharding,
+reshardOn1DGrid(ImplicitLocOpBuilder &builder, GridOp grid,
+                Sharding sourceSharding, Sharding targetSharding,
                 TypedValue<ShapedType> sourceUnshardedValue,
                 TypedValue<ShapedType> sourceShard) {
   assert(sourceShard.getType() ==
-         shardShapedType(sourceUnshardedValue.getType(), mesh, sourceSharding));
+         shardShapedType(sourceUnshardedValue.getType(), grid, sourceSharding));
   [[maybe_unused]] ShapedType targetShardType =
-      shardShapedType(sourceUnshardedValue.getType(), mesh, targetSharding);
+      shardShapedType(sourceUnshardedValue.getType(), grid, targetSharding);
   assert(sourceShard.getType().getRank() == targetShardType.getRank());
-  assert(mesh.getRank() == 1 && "Only 1D meshes are currently supported.");
+  assert(grid.getRank() == 1 && "Only 1D grids are currently supported.");
 
   if (sourceSharding == targetSharding) {
     return sourceShard;
   }
 
   TypedValue<ShapedType> targetShard;
-  MeshSharding actualTargetSharding;
+  Sharding actualTargetSharding;
   if (sourceSharding.getStaticShardedDimsOffsets().empty() &&
       targetSharding.getStaticShardedDimsOffsets().empty() &&
       sourceSharding.getStaticHaloSizes().empty() &&
       targetSharding.getStaticHaloSizes().empty()) {
     if (auto tryRes = tryMoveLastSplitAxisInResharding(
-            builder, mesh, sourceSharding, targetSharding,
+            builder, grid, sourceSharding, targetSharding,
             sourceUnshardedValue.getType(), sourceShard)) {
       std::tie(targetShard, actualTargetSharding) = tryRes.value();
     } else if (auto tryRes =
-                   trySplitLastAxisInResharding(builder, mesh, sourceSharding,
+                   trySplitLastAxisInResharding(builder, grid, sourceSharding,
                                                 targetSharding, sourceShard)) {
       std::tie(targetShard, actualTargetSharding) = tryRes.value();
     } else if (auto tryRes = tryUnsplitLastAxisInResharding(
-                   builder, mesh, sourceSharding, targetSharding,
+                   builder, grid, sourceSharding, targetSharding,
                    sourceUnshardedValue.getType(), sourceShard)) {
       std::tie(targetShard, actualTargetSharding) = tryRes.value();
     }
@@ -488,9 +480,8 @@ reshardOn1DMesh(ImplicitLocOpBuilder &builder, MeshOp mesh,
   return targetShard;
 }
 
-TypedValue<ShapedType> reshard(ImplicitLocOpBuilder &builder, MeshOp mesh,
-                               MeshSharding sourceSharding,
-                               MeshSharding targetSharding,
+TypedValue<ShapedType> reshard(ImplicitLocOpBuilder &builder, GridOp grid,
+                               Sharding sourceSharding, Sharding targetSharding,
                                TypedValue<ShapedType> sourceUnshardedValue,
                                TypedValue<ShapedType> sourceShard) {
   // If source and destination sharding are the same, no need to do anything.
@@ -500,28 +491,28 @@ TypedValue<ShapedType> reshard(ImplicitLocOpBuilder &builder, MeshOp mesh,
   }
 
   // Tries to handle the case where the resharding is needed because the halo
-  // sizes are different. Supports arbitrary mesh dimensionality.
+  // sizes are different. Supports arbitrary grid dimensionality.
   if (auto tryRes = tryUpdateHaloInResharding(
-          builder, mesh, sourceSharding, targetSharding,
+          builder, grid, sourceSharding, targetSharding,
           sourceUnshardedValue.getType(), sourceShard)) {
     return std::get<0>(tryRes.value()); // targetShard
   }
 
-  // Resort to handling only 1D meshes since the general case is complicated if
+  // Resort to handling only 1D grids since the general case is complicated if
   // it needs to be communication-efficient in terms of minimizing the data
   // transferred between devices.
-  return reshardOn1DMesh(builder, mesh, sourceSharding, targetSharding,
+  return reshardOn1DGrid(builder, grid, sourceSharding, targetSharding,
                          sourceUnshardedValue, sourceShard);
 }
 
-TypedValue<ShapedType> reshard(OpBuilder &builder, MeshOp mesh, ShardOp source,
+TypedValue<ShapedType> reshard(OpBuilder &builder, GridOp grid, ShardOp source,
                                ShardOp target,
                                TypedValue<ShapedType> sourceShardValue) {
   assert(source.getResult() == target.getSrc());
   auto sourceSharding = source.getSharding();
   auto targetSharding = target.getSharding();
   ImplicitLocOpBuilder implicitLocOpBuilder(target->getLoc(), builder);
-  return reshard(implicitLocOpBuilder, mesh, sourceSharding, targetSharding,
+  return reshard(implicitLocOpBuilder, grid, sourceSharding, targetSharding,
                  cast<TypedValue<ShapedType>>(source.getSrc()),
                  sourceShardValue);
 }
@@ -530,21 +521,21 @@ TypedValue<ShapedType> reshard(OpBuilder &builder, ShardOp source,
                                ShardOp target,
                                TypedValue<ShapedType> sourceShardValue,
                                SymbolTableCollection &symbolTableCollection) {
-  MeshOp srcMesh = getMesh(source, symbolTableCollection);
-  assert(srcMesh && srcMesh == getMesh(target, symbolTableCollection));
-  return reshard(builder, srcMesh, source, target, sourceShardValue);
+  GridOp srcGrid = getGrid(source, symbolTableCollection);
+  assert(srcGrid && srcGrid == getGrid(target, symbolTableCollection));
+  return reshard(builder, srcGrid, source, target, sourceShardValue);
 }
 
 void reshardingRegisterDependentDialects(DialectRegistry &registry) {
-  registry.insert<mesh::MeshDialect, tensor::TensorDialect>();
+  registry.insert<shard::ShardDialect, tensor::TensorDialect>();
 }
 
-#define GEN_PASS_DEF_SPMDIZATION
-#include "mlir/Dialect/Mesh/Transforms/Passes.h.inc"
+#define GEN_PASS_DEF_PARTITION
+#include "mlir/Dialect/Shard/Transforms/Passes.h.inc"
 
 using UnshardedToShardedValueMap = DenseMap<Value, Value>;
 
-// Get the types of block arguments for an spmdized block.
+// Get the types of block arguments for a partitioned block.
 // Reads the sharding annotations of the arguments to deduce the sharded types.
 // Types that are not ranked tensors are left unchanged.
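+// For instance (assuming even division and no halo sizes), a tensor<8x16xf32>
+// argument whose sole use is a shard.shard op splitting dimension 0 over a
+// grid axis of size 4 becomes tensor<2x16xf32>; an i32 argument keeps its type.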
 SmallVector<Type>
@@ -563,35 +554,36 @@ shardedBlockArgumentTypes(Block &block,
         Operation *useOp = *rankedTensorArg.getUsers().begin();
         ShardOp shardOp = llvm::dyn_cast<ShardOp>(useOp);
         assert(shardOp);
-        MeshOp mesh = getMesh(shardOp, symbolTableCollection);
-        return cast<Type>(shardShapedType(rankedTensorArg.getType(), mesh,
+        GridOp grid = getGrid(shardOp, symbolTableCollection);
+        return cast<Type>(shardShapedType(rankedTensorArg.getType(), grid,
                                           shardOp.getSharding()));
       });
   return res;
 }
 
-static LogicalResult spmdizeOperation(
-    Operation &op, ArrayRef<Value> spmdizedOperands,
-    ArrayRef<MeshSharding> operandShardings,
-    ArrayRef<MeshSharding> resultShardings, IRMapping &spmdizationMap,
-    SymbolTableCollection &symbolTableCollection, OpBuilder &builder) {
+static LogicalResult
+partitionOperation(Operation &op, ArrayRef<Value> partitionedOperands,
+                   ArrayRef<Sharding> operandShardings,
+                   ArrayRef<Sharding> resultShardings, IRMapping &partitionMap,
+                   SymbolTableCollection &symbolTableCollection,
+                   OpBuilder &builder) {
   ShardingInterface shardingInterface = llvm::dyn_cast<ShardingInterface>(op);
   if (!shardingInterface) {
     // If there is no sharding interface, we are conservative and assume that
     // the op should be fully replicated on all devices.
-    spmdizeFullyReplicatedOperation(op, spmdizedOperands, operandShardings,
-                                    resultShardings, spmdizationMap,
-                                    symbolTableCollection, builder);
+    partitionFullyReplicatedOperation(op, partitionedOperands, operandShardings,
+                                      resultShardings, partitionMap,
+                                      symbolTableCollection, builder);
   } else {
-    if (failed(shardingInterface.spmdize(spmdizedOperands, operandShardings,
-                                         resultShardings, spmdizationMap,
-                                         symbolTableCollection, builder))) {
+    if (failed(shardingInterface.partition(partitionedOperands, operandShardings,
+                                           resultShardings, partitionMap,
+                                           symbolTableCollection, builder))) {
       return failure();
     }
   }
 
-  assert(llvm::all_of(op.getResults(), [&spmdizationMap](OpResult result) {
-    return spmdizationMap.contains(result);
+  assert(llvm::all_of(op.getResults(), [&partitionMap](OpResult result) {
+    return partitionMap.contains(result);
   }));
 
   return success();
@@ -599,88 +591,88 @@ static LogicalResult spmdizeOperation(
 
 // Retrieve the sharding annotations for the operands of the given operation.
 // If the type is not a ranked tensor it is not required to have an annotation.
-static std::vector<MeshSharding> getOperandShardings(Operation &op) {
-  std::vector<MeshSharding> res;
+static std::vector<Sharding> getOperandShardings(Operation &op) {
+  std::vector<Sharding> res;
   res.reserve(op.getNumOperands());
   llvm::transform(op.getOperands(), std::back_inserter(res), [](Value operand) {
     TypedValue<RankedTensorType> rankedTensor =
         dyn_cast<TypedValue<RankedTensorType>>(operand);
     if (!rankedTensor || rankedTensor.getType().getRank() == 0) {
-      return MeshSharding();
+      return Sharding();
     }
 
     Operation *definingOp = operand.getDefiningOp();
     assert(definingOp);
     ShardOp shardOp = llvm::cast<ShardOp>(definingOp);
-    return MeshSharding(shardOp.getSharding());
+    return Sharding(shardOp.getSharding());
   });
   return res;
 }
 
 // Retrieve the sharding annotations for the results of the given operation.
 // If the type is not a ranked tensor it is not required to have an annotation.
-static std::vector<MeshSharding> getResultShardings(Operation &op) {
-  std::vector<MeshSharding> res;
+static std::vector<Sharding> getResultShardings(Operation &op) {
+  std::vector<Sharding> res;
   res.reserve(op.getNumResults());
   llvm::transform(
       op.getResults(), std::back_inserter(res), [&op](OpResult result) {
         if (!result.hasOneUse() || result.use_empty()) {
-          return MeshSharding();
+          return Sharding();
         }
         TypedValue<RankedTensorType> rankedTensor =
             dyn_cast<TypedValue<RankedTensorType>>(result);
         if (!rankedTensor) {
-          return MeshSharding();
+          return Sharding();
         }
         Operation *userOp = *result.getUsers().begin();
         ShardOp shardOp = llvm::dyn_cast<ShardOp>(userOp);
         if (shardOp) {
-          return MeshSharding(shardOp.getSharding());
+          return Sharding(shardOp.getSharding());
         }
         if (rankedTensor.getType().getRank() == 0) {
           // This is a 0d tensor result without explicit sharding.
-          // Find mesh symbol from operands, if any.
-          // Shardings without mesh are not always fully supported yet.
+          // Find grid symbol from operands, if any.
+          // Shardings without grid are not always fully supported yet.
           for (auto operand : op.getOperands()) {
             if (auto sharding = operand.getDefiningOp<ShardingOp>()) {
-              return MeshSharding(sharding.getMeshAttr());
+              return Sharding(sharding.getGridAttr());
             }
           }
         }
-        return MeshSharding();
+        return Sharding();
       });
   return res;
 }
 
 static LogicalResult
-spmdizeOperation(ShardOp shardOp, IRMapping &spmdizationMap,
-                 SymbolTableCollection &symbolTableCollection,
-                 OpBuilder &builder) {
-  Value targetSpmdValue;
+partitionOperation(ShardOp shardOp, IRMapping &partitionMap,
+                   SymbolTableCollection &symbolTableCollection,
+                   OpBuilder &builder) {
+  Value targetPartitionValue;
 
   // Check if two shard ops are chained. If not, there is no need for
   // resharding, as the source and target share the same sharding.
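+  // Schematically (names and syntax abbreviated), the chained case looks like
+  //   %1 = shard.shard %0 to %src_sharding ...
+  //   %2 = shard.shard %1 to %tgt_sharding annotate_for_users ...
+  // and the reshard below rewrites the local shard from the source layout to
+  // the target layout.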
   ShardOp srcShardOp =
       dyn_cast_or_null<ShardOp>(shardOp.getSrc().getDefiningOp());
   if (!srcShardOp) {
-    targetSpmdValue = spmdizationMap.lookup(shardOp.getSrc());
+    targetPartitionValue = partitionMap.lookup(shardOp.getSrc());
   } else {
     // Insert resharding.
-    TypedValue<ShapedType> srcSpmdValue =
-        cast<TypedValue<ShapedType>>(spmdizationMap.lookup(srcShardOp));
-    targetSpmdValue = reshard(builder, srcShardOp, shardOp, srcSpmdValue,
-                              symbolTableCollection);
+    TypedValue<ShapedType> srcPartitionValue =
+        cast<TypedValue<ShapedType>>(partitionMap.lookup(srcShardOp));
+    targetPartitionValue = reshard(builder, srcShardOp, shardOp,
+                                   srcPartitionValue, symbolTableCollection);
   }
 
-  assert(!spmdizationMap.contains(shardOp.getResult()));
-  spmdizationMap.map(shardOp.getResult(), targetSpmdValue);
+  assert(!partitionMap.contains(shardOp.getResult()));
+  partitionMap.map(shardOp.getResult(), targetPartitionValue);
   return success();
 }
 
 static LogicalResult
-spmdizeOperation(Operation &op, IRMapping &spmdizationMap,
-                 SymbolTableCollection &symbolTableCollection,
-                 OpBuilder &builder) {
+partitionOperation(Operation &op, IRMapping &partitionMap,
+                   SymbolTableCollection &symbolTableCollection,
+                   OpBuilder &builder) {
   if (isa<ShardingOp>(op)) {
     return success();
   }
@@ -690,30 +682,31 @@ spmdizeOperation(Operation &op, IRMapping &spmdizationMap,
       return op.emitError("expected a shard op as source of get_sharding");
     }
     auto newSharding = builder.clone(*shardOp.getSharding().getDefiningOp());
-    spmdizationMap.map(op.getResult(0), newSharding->getResult(0));
+    partitionMap.map(op.getResult(0), newSharding->getResult(0));
     return success();
   }
 
   ShardOp shardOp = llvm::dyn_cast<ShardOp>(op);
   if (shardOp) {
-    return spmdizeOperation(shardOp, spmdizationMap, symbolTableCollection,
-                            builder);
+    return partitionOperation(shardOp, partitionMap, symbolTableCollection,
+                              builder);
   }
 
-  SmallVector<Value> spmdizedOperands;
-  llvm::transform(op.getOperands(), std::back_inserter(spmdizedOperands),
-                  [&spmdizationMap](Value operand) {
-                    assert(spmdizationMap.contains(operand));
-                    return spmdizationMap.lookup(operand);
+  SmallVector<Value> partitionedOperands;
+  llvm::transform(op.getOperands(), std::back_inserter(partitionedOperands),
+                  [&partitionMap](Value operand) {
+                    assert(partitionMap.contains(operand));
+                    return partitionMap.lookup(operand);
                   });
-  return spmdizeOperation(op, spmdizedOperands, getOperandShardings(op),
-                          getResultShardings(op), spmdizationMap,
-                          symbolTableCollection, builder);
+  return partitionOperation(op, partitionedOperands, getOperandShardings(op),
+                            getResultShardings(op), partitionMap,
+                            symbolTableCollection, builder);
 }
 
-static LogicalResult spmdizeBlock(Block &block, IRMapping &spmdizationMap,
-                                  SymbolTableCollection &symbolTableCollection,
-                                  OpBuilder &builder) {
+static LogicalResult
+partitionBlock(Block &block, IRMapping &partitionMap,
+               SymbolTableCollection &symbolTableCollection,
+               OpBuilder &builder) {
 
   SmallVector<Location> argLocations;
   llvm::transform(block.getArguments(), std::back_inserter(argLocations),
@@ -721,16 +714,16 @@ static LogicalResult spmdizeBlock(Block &block, IRMapping &spmdizationMap,
   Block *newBlock = builder.createBlock(
       block.getParent(), {},
       shardedBlockArgumentTypes(block, symbolTableCollection), argLocations);
-  for (auto [unshardedBlockArg, spmdizedBlockArg] :
+  for (auto [unshardedBlockArg, partitionedBlockArg] :
        llvm::zip(block.getArguments(), newBlock->getArguments())) {
-    spmdizationMap.map(unshardedBlockArg, spmdizedBlockArg);
+    partitionMap.map(unshardedBlockArg, partitionedBlockArg);
   }
 
   OpBuilder::InsertionGuard insertionGuard(builder);
   builder.setInsertionPointToEnd(newBlock);
   for (Operation &op : block.getOperations()) {
-    if (failed(spmdizeOperation(op, spmdizationMap, symbolTableCollection,
-                                builder))) {
+    if (failed(partitionOperation(op, partitionMap, symbolTableCollection,
+                                  builder))) {
       return failure();
     }
   }
@@ -739,8 +732,8 @@ static LogicalResult spmdizeBlock(Block &block, IRMapping &spmdizationMap,
 }
 
 static LogicalResult
-spmdizeFuncOp(FunctionOpInterface op, IRMapping &spmdizationMap,
-              SymbolTableCollection &symbolTableCollection) {
+partitionFuncOp(FunctionOpInterface op, IRMapping &partitionMap,
+                SymbolTableCollection &symbolTableCollection) {
   OpBuilder builder(op.getFunctionBody());
 
   // Snapshot the original blocks to not mess up the iteration when adding new
@@ -754,8 +747,8 @@ spmdizeFuncOp(FunctionOpInterface op, IRMapping &spmdizationMap,
   }
 
   for (Block *block : originalBlocks) {
-    if (failed(spmdizeBlock(*block, spmdizationMap, symbolTableCollection,
-                            builder))) {
+    if (failed(partitionBlock(*block, partitionMap, symbolTableCollection,
+                              builder))) {
       return failure();
     }
   }
@@ -788,22 +781,22 @@ spmdizeFuncOp(FunctionOpInterface op, IRMapping &spmdizationMap,
 
 namespace {
 
-struct Spmdization : public impl::SpmdizationBase<Spmdization> {
+struct Partition : public impl::PartitionBase<Partition> {
   void runOnOperation() override {
-    IRMapping spmdizationMap;
+    IRMapping partitionMap;
     SymbolTableCollection symbolTableCollection;
-    if (failed(spmdizeFuncOp(getOperation(), spmdizationMap,
-                             symbolTableCollection))) {
+    if (failed(partitionFuncOp(getOperation(), partitionMap,
+                               symbolTableCollection))) {
       return signalPassFailure();
     }
   }
 
   void getDependentDialects(DialectRegistry &registry) const override {
     reshardingRegisterDependentDialects(registry);
-    registry.insert<mesh::MeshDialect>();
+    registry.insert<shard::ShardDialect>();
   }
 };
 
 } // namespace
 
-} // namespace mlir::mesh
+} // namespace mlir::shard
diff --git a/mlir/lib/Dialect/Mesh/Transforms/ShardingPropagation.cpp b/mlir/lib/Dialect/Shard/Transforms/ShardingPropagation.cpp
similarity index 85%
rename from mlir/lib/Dialect/Mesh/Transforms/ShardingPropagation.cpp
rename to mlir/lib/Dialect/Shard/Transforms/ShardingPropagation.cpp
index 09c754da7a6b7..17220c4256bf3 100644
--- a/mlir/lib/Dialect/Mesh/Transforms/ShardingPropagation.cpp
+++ b/mlir/lib/Dialect/Shard/Transforms/ShardingPropagation.cpp
@@ -6,11 +6,11 @@
 //
 //===----------------------------------------------------------------------===//
 
-#include "mlir/Dialect/Mesh/Transforms/Passes.h"
+#include "mlir/Dialect/Shard/Transforms/Passes.h"
 
-#include "mlir/Dialect/Mesh/IR/MeshDialect.h"
-#include "mlir/Dialect/Mesh/IR/MeshOps.h"
-#include "mlir/Dialect/Mesh/Interfaces/ShardingInterface.h"
+#include "mlir/Dialect/Shard/IR/ShardDialect.h"
+#include "mlir/Dialect/Shard/IR/ShardOps.h"
+#include "mlir/Dialect/Shard/Interfaces/ShardingInterface.h"
 #include "mlir/IR/Verifier.h"
 #include "mlir/Interfaces/FunctionInterfaces.h"
 #include "llvm/ADT/STLExtras.h"
@@ -21,17 +21,17 @@
 #include <vector>
 
 namespace mlir {
-namespace mesh {
+namespace shard {
 #define GEN_PASS_DEF_SHARDINGPROPAGATION
-#include "mlir/Dialect/Mesh/Transforms/Passes.h.inc"
-} // namespace mesh
+#include "mlir/Dialect/Shard/Transforms/Passes.h.inc"
+} // namespace shard
 } // namespace mlir
 
 #define DEBUG_TYPE "sharding-propagation"
 #define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE << "]: ")
 
 using namespace mlir;
-using namespace mlir::mesh;
+using namespace mlir::shard;
 
 enum class ReshardingRquirementKind {
   NO_RESHARDING = 0,
@@ -68,7 +68,7 @@ static llvm::raw_ostream &operator<<(llvm::raw_ostream &stream,
 
 [[maybe_unused]] static llvm::raw_ostream &operator<<(llvm::raw_ostream &stream,
                                                       const ShardingOption &v) {
-  return stream << "{empty = " << v.empty << ", mesh" << v.mesh
+  return stream << "{empty = " << v.empty << ", grid" << v.grid
                 << ", shardingArray = " << v.shardingArray << "}";
 }
 
@@ -105,15 +105,15 @@ operator<<(llvm::raw_ostream &stream, ReshardingRquirementKind v) {
 // specific shardings. For example, with mustShardings = [shard0, None] and
 // optionalShardings = [None, shard1], the result will be [[shard0, shard1],
 // [shard0, None]]
-static SmallVector<std::vector<MeshSharding>>
-getOrderedPossibleShardingAttrs(ArrayRef<MeshSharding> mustShardings,
-                                ArrayRef<MeshSharding> optionalShardings) {
-  SmallVector<std::vector<MeshSharding>> allShardingAttrs;
-  std::vector<MeshSharding> curShardingAttrs;
+static SmallVector<std::vector<Sharding>>
+getOrderedPossibleShardingAttrs(ArrayRef<Sharding> mustShardings,
+                                ArrayRef<Sharding> optionalShardings) {
+  SmallVector<std::vector<Sharding>> allShardingAttrs;
+  std::vector<Sharding> curShardingAttrs;
 
   std::function<void(size_t)> dfsCreateShardingAttrs = [&](size_t i) {
     if (i == mustShardings.size()) {
-      allShardingAttrs.push_back(std::vector<MeshSharding>(curShardingAttrs));
+      allShardingAttrs.push_back(std::vector<Sharding>(curShardingAttrs));
       return;
     }
 
@@ -147,14 +147,14 @@ getOrderedPossibleShardingAttrs(ArrayRef<MeshSharding> mustShardings,
 // 1. No resharding is required (all existing annotations are compatible).
 // 2. No resharding for operands/results that have annotation specifically
 //   targeting this operation. This means
-//   * operands that are the result of `mesh.shard` ops marked with
+//   * operands that are the result of `shard.shard` ops marked with
 //     `annotate_for_users`.
-//   * results that are annotated with `mesh.shard` ops without
+//   * results that are annotated with `shard.shard` ops without
 //     `annotate_for_users`.
 // 3. All other cases. Resharding is required for operands/results with
 //   annotations explicitly targeting this operation.
 ReshardingRquirementKind getReshardingRquirementKind(
-    Operation *op, const std::vector<MeshSharding> &operandAndResultShardings) {
+    Operation *op, const std::vector<Sharding> &operandAndResultShardings) {
   ReshardingRquirementKind res = ReshardingRquirementKind::NO_RESHARDING;
 
   size_t operandsCount = op->getOperands().size();
@@ -213,14 +213,13 @@ ReshardingRquirementKind getReshardingRquirementKind(
 // 3. Resharding of existing explicit sharding annotations for this op.
 static FailureOr<ShardingOption> selectShardingOption(
     ShardingInterface shardingOp,
-    ArrayRef<std::vector<MeshSharding>> possibleOperandShardingAttrs,
-    ArrayRef<std::vector<MeshSharding>> possibleResultShardingAttrs) {
+    ArrayRef<std::vector<Sharding>> possibleOperandShardingAttrs,
+    ArrayRef<std::vector<Sharding>> possibleResultShardingAttrs) {
   SmallVector<std::tuple<ShardingOption, ReshardingRquirementKind>>
       shardingOptionsAndReshardingRequirements;
 
-  for (ArrayRef<MeshSharding> resultShardings : possibleResultShardingAttrs) {
-    for (ArrayRef<MeshSharding> operandShardings :
-         possibleOperandShardingAttrs) {
+  for (ArrayRef<Sharding> resultShardings : possibleResultShardingAttrs) {
+    for (ArrayRef<Sharding> operandShardings : possibleOperandShardingAttrs) {
       FailureOr<ShardingOption> shardingOption =
           shardingOp.getShardingOption(operandShardings, resultShardings);
       if (failed(shardingOption) || shardingOption->empty) {
@@ -231,7 +230,7 @@ static FailureOr<ShardingOption> selectShardingOption(
       // They may be missing some annotations.
       // Whatever is returned by getShardingAnnotations is exactly what the op
       // needs.
-      FailureOr<std::vector<MeshSharding>> operandAndResultShardings =
+      FailureOr<std::vector<Sharding>> operandAndResultShardings =
           shardingOp.getShardingAnnotations(*shardingOption);
       if (failed(operandAndResultShardings)) {
         return failure();
@@ -276,13 +275,13 @@ static FailureOr<ShardingOption> selectShardingOption(
 // For each operation that implements the ShardingInterface, infer the sharding
 // option of the operation from its operands and/or results using the
 // `getShardingOption` method. If the inferred sharding option is not empty, add
-// a `mesh.shard` operation for all remaining operands and results that do not
+// a `shard.shard` operation for all remaining operands and results that do not
 // have sharding annotations.
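+// For instance (a hypothetical elementwise op), if one operand is annotated as
+// split over grid axis 0 while the other operand and the result carry no
+// annotation, the propagation typically infers the same split for them and
+// materializes it with `shard.shard` ops.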
 static LogicalResult visitOp(Operation *op, OpBuilder &builder) {
   ShardingInterface shardingOp = llvm::dyn_cast<ShardingInterface>(op);
   if (op->hasTrait<OpTrait::IsTerminator>() ||
       (op->hasTrait<OpTrait::ConstantLike>() && !shardingOp) ||
-      llvm::isa<mesh::ShardOp, mesh::ShardingOp, mesh::GetShardingOp>(op))
+      llvm::isa<shard::ShardOp, shard::ShardingOp, shard::GetShardingOp>(op))
     return success();
 
   if (!shardingOp) {
@@ -290,14 +289,13 @@ static LogicalResult visitOp(Operation *op, OpBuilder &builder) {
     return failure();
   }
 
-  // collect MeshSharding from results
-  std::vector<MeshSharding> allowConflictsResultShardings;
+  // collect Sharding from results
+  std::vector<Sharding> allowConflictsResultShardings;
   allowConflictsResultShardings.resize(op->getNumResults());
-  std::vector<MeshSharding> resultMustShardings;
+  std::vector<Sharding> resultMustShardings;
   resultMustShardings.resize(op->getNumResults());
   for (OpResult result : op->getResults()) {
-    FailureOr<std::pair<bool, MeshSharding>> maybeShardAttr =
-        getMeshSharding(result);
+    FailureOr<std::pair<bool, Sharding>> maybeShardAttr = getSharding(result);
     if (failed(maybeShardAttr))
       continue;
     if (!maybeShardAttr->first)
@@ -307,14 +305,14 @@ static LogicalResult visitOp(Operation *op, OpBuilder &builder) {
           maybeShardAttr->second;
   }
 
-  // collect MeshSharding from operands
-  std::vector<MeshSharding> allowConflictsOperandShardings;
+  // collect Sharding from operands
+  std::vector<Sharding> allowConflictsOperandShardings;
   allowConflictsOperandShardings.resize(op->getNumOperands());
-  std::vector<MeshSharding> operandMustShardings;
+  std::vector<Sharding> operandMustShardings;
   operandMustShardings.resize(op->getNumOperands());
   for (OpOperand &opOperand : op->getOpOperands()) {
-    FailureOr<std::pair<bool, MeshSharding>> maybeShardAttr =
-        getMeshSharding(opOperand);
+    FailureOr<std::pair<bool, Sharding>> maybeShardAttr =
+        getSharding(opOperand);
     if (failed(maybeShardAttr))
       continue;
 
@@ -327,10 +325,10 @@ static LogicalResult visitOp(Operation *op, OpBuilder &builder) {
   }
 
   // try to get the sharding option
-  SmallVector<std::vector<MeshSharding>> possibleOperandShardingAttrs =
+  SmallVector<std::vector<Sharding>> possibleOperandShardingAttrs =
       getOrderedPossibleShardingAttrs(operandMustShardings,
                                       allowConflictsOperandShardings);
-  SmallVector<std::vector<MeshSharding>> possibleResultShardingAttrs =
+  SmallVector<std::vector<Sharding>> possibleResultShardingAttrs =
       getOrderedPossibleShardingAttrs(resultMustShardings,
                                       allowConflictsResultShardings);
   FailureOr<ShardingOption> shardingOption = selectShardingOption(
@@ -358,7 +356,7 @@ static LogicalResult visitOp(Operation *op, OpBuilder &builder) {
 // ShardingPropagation
 //===----------------------------------------------------------------------===//
 struct ShardingPropagation
-    : public mesh::impl::ShardingPropagationBase<ShardingPropagation> {
+    : public shard::impl::ShardingPropagationBase<ShardingPropagation> {
 
   using ShardingPropagationBase<ShardingPropagation>::ShardingPropagationBase;
 
@@ -376,8 +374,7 @@ struct ShardingPropagation
     LLVM_DEBUG(
         DBGS() << "print all the ops' iterator types and indexing maps in the "
                   "block.\n";
-        for (Operation &op
-             : block.getOperations()) {
+        for (Operation &op : block.getOperations()) {
           if (auto shardingOp = llvm::dyn_cast<ShardingInterface>(&op))
             shardingOp.printLoopTypesAndIndexingMaps(llvm::dbgs());
         });
diff --git a/mlir/lib/Dialect/Mesh/Transforms/Simplifications.cpp b/mlir/lib/Dialect/Shard/Transforms/Simplifications.cpp
similarity index 66%
rename from mlir/lib/Dialect/Mesh/Transforms/Simplifications.cpp
rename to mlir/lib/Dialect/Shard/Transforms/Simplifications.cpp
index 1315502801d72..a17671e5408c4 100644
--- a/mlir/lib/Dialect/Mesh/Transforms/Simplifications.cpp
+++ b/mlir/lib/Dialect/Shard/Transforms/Simplifications.cpp
@@ -1,4 +1,4 @@
-//===- Simplifications.cpp - Mesh Simplifications ---------------*- C++ -*-===//
+//===- Simplifications.cpp - Shard Simplifications --------------*- C++ -*-===//
 //
 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
 // See https://llvm.org/LICENSE.txt for license information.
@@ -6,10 +6,10 @@
 //
 //===----------------------------------------------------------------------===//
 
-#include "mlir/Dialect/Mesh/Transforms/Simplifications.h"
+#include "mlir/Dialect/Shard/Transforms/Simplifications.h"
 #include "TransformsDetail.h"
 #include "mlir/Dialect/Arith/IR/Arith.h"
-#include "mlir/Dialect/Mesh/IR/MeshOps.h"
+#include "mlir/Dialect/Shard/IR/ShardOps.h"
 #include "mlir/IR/BuiltinTypeInterfaces.h"
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/IR/SymbolTable.h"
@@ -18,7 +18,7 @@
 #include <numeric>
 
 namespace mlir {
-namespace mesh {
+namespace shard {
 
 void populateSimplificationPatterns(
     RewritePatternSet &patterns, SymbolTableCollection &symbolTableCollection) {
@@ -52,53 +52,53 @@ namespace {
 // DialectFoldInterface, because it needs a SymbolTableCollection to cache the
 // symbol tables.
 // We can't use DialectFoldInterface since the cache may be invalidated by some
-// pass changing the referenced MeshOp ops.
-struct MeshShapeFolder
-    : OpRewritePatternWithSymbolTableCollection<MeshShapeOp> {
+// pass changing the referenced GridOp ops.
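+// For example (a sketch): for a grid declared with shape 4x?, folding a
+// grid_shape query over axes [0, 1] replaces the axis-0 result with an
+// arith.constant 4 and leaves only axis 1 to be queried dynamically.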
+struct GridShapeFolder
+    : OpRewritePatternWithSymbolTableCollection<GridShapeOp> {
   using OpRewritePatternWithSymbolTableCollection::
       OpRewritePatternWithSymbolTableCollection;
-  LogicalResult matchAndRewrite(MeshShapeOp op,
+  LogicalResult matchAndRewrite(GridShapeOp op,
                                 PatternRewriter &rewriter) const override {
     ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
-    MeshOp mesh = symbolTableCollection.lookupNearestSymbolFrom<mesh::MeshOp>(
-        op.getOperation(), op.getMeshAttr());
-    if (!mesh) {
+    GridOp grid = symbolTableCollection.lookupNearestSymbolFrom<shard::GridOp>(
+        op.getOperation(), op.getGridAttr());
+    if (!grid) {
       return failure();
     }
-    ArrayRef<MeshAxis> opMeshAxes = op.getAxes();
-    SmallVector<MeshAxis> opAxesIota;
-    if (opMeshAxes.empty()) {
-      opAxesIota.resize(mesh.getRank());
+    ArrayRef<GridAxis> opGridAxes = op.getAxes();
+    SmallVector<GridAxis> opAxesIota;
+    if (opGridAxes.empty()) {
+      opAxesIota.resize(grid.getRank());
       std::iota(opAxesIota.begin(), opAxesIota.end(), 0);
-      opMeshAxes = opAxesIota;
+      opGridAxes = opAxesIota;
     }
-    if (llvm::all_of(opMeshAxes, [&mesh](MeshAxis axis) {
-          return ShapedType::isDynamic(mesh.getShape()[axis]);
+    if (llvm::all_of(opGridAxes, [&grid](GridAxis axis) {
+          return ShapedType::isDynamic(grid.getShape()[axis]);
         })) {
-      // All mesh dimensions are dynamic. Nothing to fold.
+      // All grid dimensions are dynamic. Nothing to fold.
       return failure();
     }
 
     SmallVector<Value> newResults(op->getResults().size());
-    SmallVector<MeshAxis> newShapeOpMeshAxes;
+    SmallVector<GridAxis> newShapeOpGridAxes;
     SmallVector<size_t> newToOldResultsIndexMap;
 
-    for (size_t i = 0; i < opMeshAxes.size(); ++i) {
-      auto meshAxisSize = mesh.getShape()[opMeshAxes[i]];
-      if (ShapedType::isDynamic(meshAxisSize)) {
+    for (size_t i = 0; i < opGridAxes.size(); ++i) {
+      auto gridAxisSize = grid.getShape()[opGridAxes[i]];
+      if (ShapedType::isDynamic(gridAxisSize)) {
         newToOldResultsIndexMap.push_back(i);
-        newShapeOpMeshAxes.push_back(opMeshAxes[i]);
+        newShapeOpGridAxes.push_back(opGridAxes[i]);
       } else {
-        // Fold static mesh axes.
+        // Fold static grid axes.
         newResults[i] = arith::ConstantOp::create(
-            builder, builder.getIndexAttr(meshAxisSize));
+            builder, builder.getIndexAttr(gridAxisSize));
       }
     }
 
-    // Leave only the dynamic mesh axes to be queried.
-    if (!newShapeOpMeshAxes.empty()) {
-      MeshShapeOp newShapeOp =
-          MeshShapeOp::create(builder, mesh.getSymName(), newShapeOpMeshAxes);
+    // Leave only the dynamic grid axes to be queried.
+    if (!newShapeOpGridAxes.empty()) {
+      GridShapeOp newShapeOp =
+          GridShapeOp::create(builder, grid.getSymName(), newShapeOpGridAxes);
       for (size_t i = 0; i < newShapeOp->getResults().size(); ++i) {
         newResults[newToOldResultsIndexMap[i]] = newShapeOp->getResults()[i];
       }
@@ -113,8 +113,8 @@ struct MeshShapeFolder
 
 void populateFoldingPatterns(RewritePatternSet &patterns,
                              SymbolTableCollection &symbolTableCollection) {
-  patterns.add<MeshShapeFolder>(symbolTableCollection, patterns.getContext());
+  patterns.add<GridShapeFolder>(symbolTableCollection, patterns.getContext());
 }
 
-} // namespace mesh
+} // namespace shard
 } // namespace mlir
diff --git a/mlir/lib/Dialect/Mesh/Transforms/Transforms.cpp b/mlir/lib/Dialect/Shard/Transforms/Transforms.cpp
similarity index 78%
rename from mlir/lib/Dialect/Mesh/Transforms/Transforms.cpp
rename to mlir/lib/Dialect/Shard/Transforms/Transforms.cpp
index 1bde1af28d8c3..772e66fee5c56 100644
--- a/mlir/lib/Dialect/Mesh/Transforms/Transforms.cpp
+++ b/mlir/lib/Dialect/Shard/Transforms/Transforms.cpp
@@ -6,7 +6,7 @@
 //
 //===----------------------------------------------------------------------===//
 
-#include "mlir/Dialect/Mesh/Transforms/Transforms.h"
+#include "mlir/Dialect/Shard/Transforms/Transforms.h"
 #include "TransformsDetail.h"
 #include "mlir/Dialect/Affine/IR/AffineOps.h"
 #include "mlir/Dialect/Affine/Utils.h"
@@ -14,8 +14,8 @@
 #include "mlir/Dialect/Arith/Utils/Utils.h"
 #include "mlir/Dialect/ControlFlow/IR/ControlFlow.h"
 #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
-#include "mlir/Dialect/Mesh/IR/MeshDialect.h"
-#include "mlir/Dialect/Mesh/IR/MeshOps.h"
+#include "mlir/Dialect/Shard/IR/ShardDialect.h"
+#include "mlir/Dialect/Shard/IR/ShardOps.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "mlir/Dialect/Utils/StaticValueUtils.h"
 #include "mlir/IR/BuiltinTypes.h"
@@ -28,12 +28,12 @@
 #include <iterator>
 #include <numeric>
 
-namespace mlir::mesh {
+namespace mlir::shard {
 
 namespace {
 
-/// Lower `mesh.process_multi_index` into expression using
-/// `mesh.process_linear_index` and `mesh.mesh_shape`.
+/// Lower `shard.process_multi_index` into expression using
+/// `shard.process_linear_index` and `shard.grid_shape`.
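+/// For example, on a grid of shape 2x3, the process with linear index 4
+/// delinearizes to the multi-index (1, 1), since 4 = 1 * 3 + 1.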
 struct ProcessMultiIndexOpLowering
     : OpRewritePatternWithSymbolTableCollection<ProcessMultiIndexOp> {
   using OpRewritePatternWithSymbolTableCollection::
@@ -41,30 +41,30 @@ struct ProcessMultiIndexOpLowering
 
   LogicalResult matchAndRewrite(ProcessMultiIndexOp op,
                                 PatternRewriter &rewriter) const override {
-    MeshOp mesh = getMesh(op, symbolTableCollection);
-    if (!mesh) {
+    GridOp grid = getGrid(op, symbolTableCollection);
+    if (!grid) {
       return failure();
     }
 
     ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
     builder.setInsertionPointAfter(op.getOperation());
-    Value linearIndex = ProcessLinearIndexOp::create(builder, mesh);
-    ValueRange meshShape = MeshShapeOp::create(builder, mesh).getResults();
+    Value linearIndex = ProcessLinearIndexOp::create(builder, grid);
+    ValueRange gridShape = GridShapeOp::create(builder, grid).getResults();
     SmallVector<Value> completeMultiIndex =
         affine::AffineDelinearizeIndexOp::create(builder, linearIndex,
-                                                 meshShape)
+                                                 gridShape)
             .getMultiIndex();
     SmallVector<Value> multiIndex;
-    ArrayRef<MeshAxis> opMeshAxes = op.getAxes();
-    SmallVector<MeshAxis> opAxesIota;
-    if (opMeshAxes.empty()) {
-      opAxesIota.resize(mesh.getRank());
+    ArrayRef<GridAxis> opGridAxes = op.getAxes();
+    SmallVector<GridAxis> opAxesIota;
+    if (opGridAxes.empty()) {
+      opAxesIota.resize(grid.getRank());
       std::iota(opAxesIota.begin(), opAxesIota.end(), 0);
-      opMeshAxes = opAxesIota;
+      opGridAxes = opAxesIota;
     }
-    llvm::transform(opMeshAxes, std::back_inserter(multiIndex),
-                    [&completeMultiIndex](MeshAxis meshAxis) {
-                      return completeMultiIndex[meshAxis];
+    llvm::transform(opGridAxes, std::back_inserter(multiIndex),
+                    [&completeMultiIndex](GridAxis gridAxis) {
+                      return completeMultiIndex[gridAxis];
                     });
     rewriter.replaceAllUsesWith(op.getResults(), multiIndex);
     return success();
@@ -86,15 +86,15 @@ struct AllSliceOpLowering
     // axis.
     // The slice axis is split into equisized parts, their count being
     // the number of processes in the collective process group induced by
-    // the mesh axes.
+    // the grid axes.
     // The part for each process is determined by the corresponding
     // linear-index in the process group.
     //
     // There are no collectives that require communication.
     // Each process operates on its local tensor.
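+    //
+    // For example, slicing an axis of size 12 over a process group of size 4
+    // gives each process a part of size 3; the process with in-group linear
+    // index i keeps elements [3 * i, 3 * (i + 1)) of that axis.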
 
-    MeshOp mesh = getMesh(op, symbolTableCollection);
-    if (!mesh) {
+    GridOp grid = getGrid(op, symbolTableCollection);
+    if (!grid) {
       return failure();
     }
 
@@ -104,15 +104,15 @@ struct AllSliceOpLowering
     Value zero = arith::ConstantOp::create(builder, builder.getIndexAttr(0));
 
     Operation::result_range processInGroupMultiIndex =
-        ProcessMultiIndexOp::create(builder, mesh.getSymName(),
-                                    op.getMeshAxes())
+        ProcessMultiIndexOp::create(builder, grid.getSymName(),
+                                    op.getGridAxes())
             .getResults();
 
     Operation::result_range processGroupShape =
-        MeshShapeOp::create(builder, mesh.getSymName(), op.getMeshAxes())
+        GridShapeOp::create(builder, grid.getSymName(), op.getGridAxes())
             .getResult();
     Value processGroupSize =
-        createCollectiveProcessGroupSize(mesh, op.getMeshAxes(), builder);
+        createCollectiveProcessGroupSize(grid, op.getGridAxes(), builder);
 
     int64_t sliceAxis = op.getSliceAxis().getSExtValue();
     Value operandSliceAxisSize =
@@ -125,7 +125,7 @@ struct AllSliceOpLowering
     cf::AssertOp::create(builder, isTargetShapeExactlyDivisible,
                          "Slicing a tensor with axis size that is "
                          "not exactly divisible by the "
-                         "mesh process group size is not supported.");
+                         "grid process group size is not supported.");
     Value resultSliceAxisSize =
         arith::DivUIOp::create(builder, operandSliceAxisSize, processGroupSize);
     OpFoldResult processInGroupLinearIndex = affine::linearizeIndex(
@@ -172,7 +172,7 @@ void populateProcessMultiIndexOpLoweringPatterns(
 }
 
 void registerProcessMultiIndexOpLoweringDialects(DialectRegistry &registry) {
-  registry.insert<affine::AffineDialect, mesh::MeshDialect>();
+  registry.insert<affine::AffineDialect, shard::ShardDialect>();
 }
 
 void populateAllSliceOpLoweringPatterns(
@@ -183,7 +183,7 @@ void populateAllSliceOpLoweringPatterns(
 
 void registerAllSliceOpLoweringDialects(DialectRegistry &registry) {
   registry.insert<affine::AffineDialect, arith::ArithDialect,
-                  cf::ControlFlowDialect, mesh::MeshDialect,
+                  cf::ControlFlowDialect, shard::ShardDialect,
                   tensor::TensorDialect>();
 }
 
@@ -199,21 +199,21 @@ void registerAllOpLoweringDialects(DialectRegistry &registry) {
 }
 
 TypedValue<IndexType>
-createCollectiveProcessGroupSize(MeshOp mesh, ArrayRef<MeshAxis> axes,
+createCollectiveProcessGroupSize(GridOp grid, ArrayRef<GridAxis> axes,
                                  ImplicitLocOpBuilder &builder) {
-  Operation::result_range meshShape =
-      mesh::MeshShapeOp::create(builder, mesh, axes).getResults();
+  Operation::result_range gridShape =
+      GridShapeOp::create(builder, grid, axes).getResults();
   return cast<TypedValue<IndexType>>(arith::createProduct(
-      builder, builder.getLoc(), llvm::to_vector_of<Value>(meshShape),
+      builder, builder.getLoc(), llvm::to_vector_of<Value>(gridShape),
       builder.getIndexType()));
 }
 
 TypedValue<IndexType>
-createProcessLinearIndex(StringRef mesh, ValueRange processInGroupMultiIndex,
-                         ArrayRef<MeshAxis> meshAxes,
+createProcessLinearIndex(StringRef grid, ValueRange processInGroupMultiIndex,
+                         ArrayRef<GridAxis> gridAxes,
                          ImplicitLocOpBuilder &builder) {
   Operation::result_range processGroupShape =
-      MeshShapeOp::create(builder, mesh, meshAxes).getResult();
+      GridShapeOp::create(builder, grid, gridAxes).getResult();
   OpFoldResult processInGroupLinearIndex = affine::linearizeIndex(
       llvm::to_vector_of<OpFoldResult>(processInGroupMultiIndex),
       llvm::to_vector_of<OpFoldResult>(processGroupShape), builder);
@@ -225,11 +225,11 @@ createProcessLinearIndex(StringRef mesh, ValueRange processInGroupMultiIndex,
   return cast<TypedValue<IndexType>>(res);
 }
 
-TypedValue<IndexType> createProcessLinearIndex(StringRef mesh,
-                                               ArrayRef<MeshAxis> meshAxes,
+TypedValue<IndexType> createProcessLinearIndex(StringRef grid,
+                                               ArrayRef<GridAxis> gridAxes,
                                                ImplicitLocOpBuilder &builder) {
   return createProcessLinearIndex(
-      mesh, ProcessMultiIndexOp::create(builder, mesh, meshAxes).getResults(),
-      meshAxes, builder);
+      grid, ProcessMultiIndexOp::create(builder, grid, gridAxes).getResults(),
+      gridAxes, builder);
 }
-} // namespace mlir::mesh
+} // namespace mlir::shard
diff --git a/mlir/lib/Dialect/Mesh/Transforms/TransformsDetail.h b/mlir/lib/Dialect/Shard/Transforms/TransformsDetail.h
similarity index 82%
rename from mlir/lib/Dialect/Mesh/Transforms/TransformsDetail.h
rename to mlir/lib/Dialect/Shard/Transforms/TransformsDetail.h
index 3e3f584caca24..60c9828ba736d 100644
--- a/mlir/lib/Dialect/Mesh/Transforms/TransformsDetail.h
+++ b/mlir/lib/Dialect/Shard/Transforms/TransformsDetail.h
@@ -6,14 +6,14 @@
 //
 //===----------------------------------------------------------------------===//
 
-#ifndef MLIR_DIALECT_MESH_TRANSFORMS_TRANSFORMSDETAIL_H
-#define MLIR_DIALECT_MESH_TRANSFORMS_TRANSFORMSDETAIL_H
+#ifndef MLIR_DIALECT_SHARD_TRANSFORMS_TRANSFORMSDETAIL_H
+#define MLIR_DIALECT_SHARD_TRANSFORMS_TRANSFORMSDETAIL_H
 
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/IR/SymbolTable.h"
 
 namespace mlir {
-namespace mesh {
+namespace shard {
 
 template <typename Op>
 struct OpRewritePatternWithSymbolTableCollection : OpRewritePattern<Op> {
@@ -29,7 +29,7 @@ struct OpRewritePatternWithSymbolTableCollection : OpRewritePattern<Op> {
   SymbolTableCollection &symbolTableCollection;
 };
 
-} // namespace mesh
+} // namespace shard
 } // namespace mlir
 
-#endif // MLIR_DIALECT_MESH_TRANSFORMS_TRANSFORMSDETAIL_H
+#endif // MLIR_DIALECT_SHARD_TRANSFORMS_TRANSFORMSDETAIL_H
diff --git a/mlir/lib/Dialect/Tensor/Extensions/AllExtensions.cpp b/mlir/lib/Dialect/Tensor/Extensions/AllExtensions.cpp
index 0421a6c0ff806..0784615b8edb8 100644
--- a/mlir/lib/Dialect/Tensor/Extensions/AllExtensions.cpp
+++ b/mlir/lib/Dialect/Tensor/Extensions/AllExtensions.cpp
@@ -7,7 +7,7 @@
 //===----------------------------------------------------------------------===//
 
 #include "mlir/Dialect/Tensor/Extensions/AllExtensions.h"
-#include "mlir/Dialect/Tensor/Extensions/MeshShardingExtensions.h"
+#include "mlir/Dialect/Tensor/Extensions/ShardingExtensions.h"
 
 using namespace mlir;
 
diff --git a/mlir/lib/Dialect/Tensor/Extensions/CMakeLists.txt b/mlir/lib/Dialect/Tensor/Extensions/CMakeLists.txt
index dba59333666f6..8f0b7da1fd7b5 100644
--- a/mlir/lib/Dialect/Tensor/Extensions/CMakeLists.txt
+++ b/mlir/lib/Dialect/Tensor/Extensions/CMakeLists.txt
@@ -1,10 +1,10 @@
 set(LLVM_OPTIONAL_SOURCES
   AllExtensions.cpp
-  MeshShardingExtensions.cpp
+  ShardingExtensions.cpp
   )
 
-add_mlir_extension_library(MLIRTensorMeshShardingExtensions
-  MeshShardingExtensions.cpp
+add_mlir_extension_library(MLIRTensorShardingExtensions
+  ShardingExtensions.cpp
 
   ADDITIONAL_HEADER_DIRS
   ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/Tensor/Extensions
@@ -22,5 +22,5 @@ add_mlir_extension_library(MLIRTensorAllExtensions
   ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/Tensor/Extensions
 
   LINK_LIBS PUBLIC
-  MLIRTensorMeshShardingExtensions
+  MLIRTensorShardingExtensions
   )
\ No newline at end of file
diff --git a/mlir/lib/Dialect/Tensor/Extensions/MeshShardingExtensions.cpp b/mlir/lib/Dialect/Tensor/Extensions/ShardingExtensions.cpp
similarity index 74%
rename from mlir/lib/Dialect/Tensor/Extensions/MeshShardingExtensions.cpp
rename to mlir/lib/Dialect/Tensor/Extensions/ShardingExtensions.cpp
index 7e4a5acb9867d..ca7287cec55ce 100644
--- a/mlir/lib/Dialect/Tensor/Extensions/MeshShardingExtensions.cpp
+++ b/mlir/lib/Dialect/Tensor/Extensions/ShardingExtensions.cpp
@@ -6,15 +6,15 @@
 //
 //===----------------------------------------------------------------------===//
 
-#include "mlir/Dialect/Mesh/Interfaces/ShardingInterface.h"
-#include "mlir/Dialect/Mesh/Interfaces/ShardingInterfaceImpl.h"
+#include "mlir/Dialect/Shard/Interfaces/ShardingInterface.h"
+#include "mlir/Dialect/Shard/Interfaces/ShardingInterfaceImpl.h"
 #include "mlir/Dialect/Tensor/IR/ShardingInterfaceImpl.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "mlir/IR/DialectRegistry.h"
 
 using namespace mlir;
 using namespace mlir::tensor;
-using namespace mlir::mesh;
+using namespace mlir::shard;
 
 namespace {
 
@@ -40,20 +40,20 @@ struct CreatorOpShardingInterface
         {AffineMap::getMultiDimIdentityMap(type.getRank(), ctx)});
   }
 
-  LogicalResult spmdize(Operation *op, ArrayRef<Value> spmdizedOperands,
-                        ArrayRef<MeshSharding> operandShardings,
-                        ArrayRef<MeshSharding> resultShardings,
-                        IRMapping &spmdizationMap,
-                        SymbolTableCollection &symbolTable,
-                        OpBuilder &builder) const {
+  LogicalResult partition(Operation *op, ArrayRef<Value> partitionedOperands,
+                          ArrayRef<Sharding> operandShardings,
+                          ArrayRef<Sharding> resultShardings,
+                          IRMapping &partitionMap,
+                          SymbolTableCollection &symbolTable,
+                          OpBuilder &builder) const {
     assert(resultShardings.size() == 1);
     auto resType = cast<RankedTensorType>(op->getResult(0).getType());
-    mlir::mesh::MeshOp mesh;
+    mlir::shard::GridOp grid;
     ShapedType shardType;
     if (resType.getRank() > 0) {
-      mesh = mesh::getMesh(op, resultShardings[0].getMeshAttr(), symbolTable);
+      grid = shard::getGrid(op, resultShardings[0].getGridAttr(), symbolTable);
       shardType =
-          cast<ShapedType>(mesh::shardType(resType, mesh, resultShardings[0]));
+          cast<ShapedType>(shard::shardType(resType, grid, resultShardings[0]));
     } else {
       shardType = resType;
     }
@@ -67,7 +67,7 @@ struct CreatorOpShardingInterface
       auto oldType = cast<ShapedType>(resType);
       assert(oldType.getRank() == shardType.getRank());
       int currOldOprndNum = -1;
-      mesh::ShardShapeOp shapeForDevice;
+      shard::ShardShapeOp shapeForDevice;
       ValueRange device;
       Operation *newSharding = nullptr;
       for (auto i = 0; i < oldType.getRank(); ++i) {
@@ -76,23 +76,23 @@ struct CreatorOpShardingInterface
             newSharding =
                 ShardingOp::create(builder, op->getLoc(), resultShardings[0]);
             device =
-                mesh::ProcessMultiIndexOp::create(builder, op->getLoc(), mesh)
+                shard::ProcessMultiIndexOp::create(builder, op->getLoc(), grid)
                     .getResults();
-            shapeForDevice = mesh::ShardShapeOp::create(
-                builder, op->getLoc(), oldType.getShape(), spmdizedOperands,
+            shapeForDevice = shard::ShardShapeOp::create(
+                builder, op->getLoc(), oldType.getShape(), partitionedOperands,
                 newSharding->getResult(0), device);
           }
           newOperands.emplace_back(shapeForDevice.getResult()[i]);
         } else if (oldType.isDynamicDim(i)) {
           assert(shardType.isDynamicDim(i));
-          newOperands.emplace_back(spmdizedOperands[++currOldOprndNum]);
+          newOperands.emplace_back(partitionedOperands[++currOldOprndNum]);
         }
       }
       newOp = OpTy::create(builder, op->getLoc(), shardType, newOperands);
-      spmdizationMap.map(op->getResult(0), newOp->getResult(0));
+      partitionMap.map(op->getResult(0), newOp->getResult(0));
     } else {
       // `clone` will populate the mapping of old to new results.
-      newOp = builder.clone(*op, spmdizationMap);
+      newOp = builder.clone(*op, partitionMap);
     }
     newOp->getResult(0).setType(shardType);
 
diff --git a/mlir/lib/Dialect/Tosa/CMakeLists.txt b/mlir/lib/Dialect/Tosa/CMakeLists.txt
index b1fac8c85a204..c6a438d348946 100644
--- a/mlir/lib/Dialect/Tosa/CMakeLists.txt
+++ b/mlir/lib/Dialect/Tosa/CMakeLists.txt
@@ -36,7 +36,7 @@ add_mlir_dialect_library(MLIRTosaShardingInterfaceImpl
 
   LINK_LIBS PUBLIC
   MLIRIR
-  MLIRMeshDialect
+  MLIRShardDialect
   MLIRShardingInterface
   MLIRSupport
   MLIRTosaDialect
diff --git a/mlir/lib/Dialect/Tosa/IR/ShardingInterfaceImpl.cpp b/mlir/lib/Dialect/Tosa/IR/ShardingInterfaceImpl.cpp
index d3a5f44798106..45994a7ec679f 100644
--- a/mlir/lib/Dialect/Tosa/IR/ShardingInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/ShardingInterfaceImpl.cpp
@@ -7,9 +7,9 @@
 //===----------------------------------------------------------------------===//
 
 #include "mlir/Dialect/Tosa/IR/ShardingInterfaceImpl.h"
-#include "mlir/Dialect/Mesh/IR/MeshOps.h"
-#include "mlir/Dialect/Mesh/Interfaces/ShardingInterface.h"
-#include "mlir/Dialect/Mesh/Interfaces/ShardingInterfaceImpl.h"
+#include "mlir/Dialect/Shard/IR/ShardOps.h"
+#include "mlir/Dialect/Shard/Interfaces/ShardingInterface.h"
+#include "mlir/Dialect/Shard/Interfaces/ShardingInterfaceImpl.h"
 #include "mlir/Dialect/Tosa/IR/TosaOps.h"
 #include "mlir/IR/AffineMap.h"
 #include "mlir/IR/DialectRegistry.h"
@@ -19,7 +19,7 @@
 
 using namespace mlir;
 using namespace mlir::tosa;
-using namespace mlir::mesh;
+using namespace mlir::shard;
 
 namespace {
 
@@ -87,15 +87,15 @@ struct NegateOpSharding
     return maps;
   }
 
-  LogicalResult spmdize(Operation *op, ArrayRef<Value> spmdizedOperands,
-                        ArrayRef<MeshSharding> operandShardings,
-                        ArrayRef<MeshSharding> resultShardings,
-                        IRMapping &spmdizationMap,
-                        SymbolTableCollection &symbolTable,
-                        OpBuilder &builder) const {
-    spmdizeTriviallyShardableOperation(*op, spmdizedOperands, operandShardings,
-                                       resultShardings, spmdizationMap,
-                                       symbolTable, builder);
+  LogicalResult partition(Operation *op, ArrayRef<Value> partitionedOperands,
+                          ArrayRef<Sharding> operandShardings,
+                          ArrayRef<Sharding> resultShardings,
+                          IRMapping &partitionMap,
+                          SymbolTableCollection &symbolTable,
+                          OpBuilder &builder) const {
+    partitionTriviallyShardableOperation(*op, partitionedOperands,
+                                         operandShardings, resultShardings,
+                                         partitionMap, symbolTable, builder);
     return success();
   }
 };
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
index 648e508a9788f..ecd93ff4c6e7b 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
@@ -13,8 +13,8 @@
 //===----------------------------------------------------------------------===//
 
 #include "mlir/Dialect/Tosa/IR/TosaOps.h"
-#include "mlir/Dialect/Mesh/Interfaces/ShardingInterface.h"
 #include "mlir/Dialect/Quant/IR/Quant.h"
+#include "mlir/Dialect/Shard/Interfaces/ShardingInterface.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "mlir/Dialect/Tosa/Utils/QuantUtils.h"
 #include "mlir/Dialect/Tosa/Utils/ShapeUtils.h"
@@ -166,7 +166,7 @@ void TosaDialect::initialize() {
       >();
   addInterfaces<TosaDialectBytecodeInterface, TosaInlinerInterface>();
   declarePromisedInterfaces<
-      mesh::ShardingInterface, ClampOp, SigmoidOp, TanhOp, AddOp,
+      shard::ShardingInterface, ClampOp, SigmoidOp, TanhOp, AddOp,
       ArithmeticRightShiftOp, BitwiseAndOp, BitwiseOrOp, BitwiseXorOp, IntDivOp,
       LogicalAndOp, LogicalLeftShiftOp, LogicalRightShiftOp, LogicalOrOp,
       LogicalXorOp, MaximumOp, MinimumOp, MulOp, PowOp, SubOp, AbsOp,
diff --git a/mlir/test/Conversion/MeshToMPI/convert-mesh-to-mpi.mlir b/mlir/test/Conversion/ShardToMPI/convert-shard-to-mpi.mlir
similarity index 90%
rename from mlir/test/Conversion/MeshToMPI/convert-mesh-to-mpi.mlir
rename to mlir/test/Conversion/ShardToMPI/convert-shard-to-mpi.mlir
index d54d0034da5be..5e20b5a59d927 100644
--- a/mlir/test/Conversion/MeshToMPI/convert-mesh-to-mpi.mlir
+++ b/mlir/test/Conversion/ShardToMPI/convert-shard-to-mpi.mlir
@@ -1,14 +1,14 @@
-// RUN: mlir-opt %s -convert-mesh-to-mpi -canonicalize -split-input-file | FileCheck %s
+// RUN: mlir-opt %s -convert-shard-to-mpi -canonicalize -split-input-file | FileCheck %s
 
 // -----
-// CHECK: mesh.mesh @mesh0
-mesh.mesh @mesh0(shape = 3x4x5)
+// CHECK: shard.grid @grid0
+shard.grid @grid0(shape = 3x4x5)
 func.func @process_multi_index() -> (index, index, index) {
   // CHECK: mpi.comm_rank
   // CHECK-DAG: %[[v4:.*]] = arith.remsi
   // CHECK-DAG: %[[v0:.*]] = arith.remsi
   // CHECK-DAG: %[[v1:.*]] = arith.remsi
-  %0:3 = mesh.process_multi_index on @mesh0 axes = [] : index, index, index
+  %0:3 = shard.process_multi_index on @grid0 axes = [] : index, index, index
   // CHECK: return %[[v1]], %[[v0]], %[[v4]] : index, index, index
   return %0#0, %0#1, %0#2 : index, index, index
 }
@@ -17,7 +17,7 @@ func.func @process_multi_index() -> (index, index, index) {
 func.func @process_linear_index() -> index {
   // CHECK: %[[RES:.*]], %[[rank:.*]] = mpi.comm_rank
   // CHECK: %[[cast:.*]] = arith.index_cast %[[rank]] : i32 to index
-  %0 = mesh.process_linear_index on @mesh0 : index
+  %0 = shard.process_linear_index on @grid0 : index
   // CHECK: return %[[cast]] : index
   return %0 : index
 }
@@ -29,7 +29,7 @@ func.func @neighbors_dim0(%arg0 : tensor<120x120x120xi8>) -> (index, index) {
   %c4 = arith.constant 4 : index
   // CHECK-DAG: [[up:%.*]] = arith.constant 44 : index
   // CHECK-DAG: [[down:%.*]] = arith.constant 4 : index
-  %idx:2 = mesh.neighbors_linear_indices on @mesh0[%c1, %c0, %c4] split_axes = [0] : index, index
+  %idx:2 = shard.neighbors_linear_indices on @grid0[%c1, %c0, %c4] split_axes = [0] : index, index
   // CHECK: return [[down]], [[up]] : index, index
   return %idx#0, %idx#1 : index, index
 }
@@ -41,7 +41,7 @@ func.func @neighbors_dim1(%arg0 : tensor<120x120x120xi8>) -> (index, index) {
   %c4 = arith.constant 4 : index
   // CHECK-DAG: [[up:%.*]] = arith.constant 29 : index
   // CHECK-DAG: [[down:%.*]] = arith.constant -1 : index
-  %idx:2 = mesh.neighbors_linear_indices on @mesh0[%c1, %c0, %c4] split_axes = [1] : index, index
+  %idx:2 = shard.neighbors_linear_indices on @grid0[%c1, %c0, %c4] split_axes = [1] : index, index
   // CHECK: return [[down]], [[up]] : index, index
   return %idx#0, %idx#1 : index, index
 }
@@ -53,20 +53,20 @@ func.func @neighbors_dim2(%arg0 : tensor<120x120x120xi8>) -> (index, index) {
   %c4 = arith.constant 4 : index
   // CHECK-DAG: [[up:%.*]] = arith.constant -1 : index
   // CHECK-DAG: [[down:%.*]] = arith.constant 23 : index
-  %idx:2 = mesh.neighbors_linear_indices on @mesh0[%c1, %c0, %c4] split_axes = [2] : index, index
+  %idx:2 = shard.neighbors_linear_indices on @grid0[%c1, %c0, %c4] split_axes = [2] : index, index
   // CHECK: return [[down]], [[up]] : index, index
   return %idx#0, %idx#1 : index, index
 }
 
 // -----
-// CHECK: mesh.mesh @mesh0
+// CHECK: shard.grid @grid0
 module attributes { mpi.dlti = #dlti.map<"MPI:comm_world_rank" = 24> } {
-  mesh.mesh @mesh0(shape = 3x4x5)
+  shard.grid @grid0(shape = 3x4x5)
   func.func @process_multi_index() -> (index, index, index) {
     // CHECK-DAG: %[[c4:.*]] = arith.constant 4 : index
     // CHECK-DAG: %[[c0:.*]] = arith.constant 0 : index
     // CHECK-DAG: %[[c1:.*]] = arith.constant 1 : index
-    %0:3 = mesh.process_multi_index on @mesh0 axes = [] : index, index, index
+    %0:3 = shard.process_multi_index on @grid0 axes = [] : index, index, index
     // CHECK: return %[[c1]], %[[c0]], %[[c4]] : index, index, index
     return %0#0, %0#1, %0#2 : index, index, index
   }
@@ -74,7 +74,7 @@ module attributes { mpi.dlti = #dlti.map<"MPI:comm_world_rank" = 24> } {
   // CHECK-LABEL: func @process_linear_index
   func.func @process_linear_index() -> index {
     // CHECK: %[[c24:.*]] = arith.constant 24 : index
-    %0 = mesh.process_linear_index on @mesh0 : index
+    %0 = shard.process_linear_index on @grid0 : index
     // CHECK: return %[[c24]] : index
     return %0 : index
   }
@@ -82,7 +82,7 @@ module attributes { mpi.dlti = #dlti.map<"MPI:comm_world_rank" = 24> } {
 
 // -----
 module attributes { mpi.dlti = #dlti.map<"MPI:comm_world_rank" = 7> } {
-  mesh.mesh @mesh0(shape = 3x4x5)
+  shard.grid @grid0(shape = 3x4x5)
   // CHECK-LABEL: func.func @allreduce_tensor(
   func.func @allreduce_tensor(
     // CHECK-SAME: [[varg0:%.*]]: tensor<3x4xf32>
@@ -97,7 +97,7 @@ module attributes { mpi.dlti = #dlti.map<"MPI:comm_world_rank" = 7> } {
     // CHECK: [[vcollapse_shape:%.*]] = memref.collapse_shape [[valloc]] {{\[\[}}0, 1]] : memref<3x4xf32> into memref<12xf32>
     // CHECK: mpi.allreduce([[vcollapse_shape]], [[vcollapse_shape]], MPI_MAX, [[vnewcomm]]) : memref<12xf32>, memref<12xf32>
     // CHECK: [[v2:%.*]] = bufferization.to_tensor [[valloc]] restrict : memref<3x4xf32> to tensor<3x4xf32>
-    %0 = mesh.all_reduce %arg0 on @mesh0 mesh_axes = [0, 1] reduction = max : tensor<3x4xf32> -> tensor<3x4xf32>
+    %0 = shard.all_reduce %arg0 on @grid0 grid_axes = [0, 1] reduction = max : tensor<3x4xf32> -> tensor<3x4xf32>
     // CHECK: return [[v2]] : tensor<3x4xf32>
     return %0 : tensor<3x4xf32>
   }
@@ -114,7 +114,7 @@ module attributes { mpi.dlti = #dlti.map<"MPI:comm_world_rank" = 7> } {
     // CHECK: [[vnewcomm:%.*]] = mpi.comm_split([[v0]], [[vc2_i32]], [[vc4_i32]]) : !mpi.comm
     // CHECK: [[vcollapse_shape:%.*]] = memref.collapse_shape [[valloc]] {{\[\[}}0, 1]] : memref<3x4xf32> into memref<12xf32>
     // CHECK: mpi.allreduce([[vcollapse_shape]], [[vcollapse_shape]], MPI_MAX, [[vnewcomm]]) : memref<12xf32>, memref<12xf32>
-    %0 = mesh.all_reduce %arg0 on @mesh0 mesh_axes = [0, 1] reduction = max : memref<3x4xf32> -> memref<3x4xf32>
+    %0 = shard.all_reduce %arg0 on @grid0 grid_axes = [0, 1] reduction = max : memref<3x4xf32> -> memref<3x4xf32>
     // CHECK: return [[valloc]] : memref<3x4xf32>
     return %0 : memref<3x4xf32>
   }
@@ -131,14 +131,14 @@ module attributes { mpi.dlti = #dlti.map<"MPI:comm_world_rank" = 7> } {
     // CHECK: [[vnewcomm:%.*]] = mpi.comm_split([[v0]], [[vc2_i32]], [[vc4_i32]]) : !mpi.comm
     // CHECK: [[vcollapse_shape:%.*]] = memref.collapse_shape [[valloc]] {{\[\[}}0, 1]] : memref<3x4xf64> into memref<12xf64>
     // CHECK: mpi.allreduce([[vcollapse_shape]], [[vcollapse_shape]], MPI_MAX, [[vnewcomm]]) : memref<12xf64>, memref<12xf64>
-    %0 = mesh.all_reduce %arg0 on @mesh0 mesh_axes = [0, 1] reduction = max : memref<3x4xf32> -> memref<3x4xf64>
+    %0 = shard.all_reduce %arg0 on @grid0 grid_axes = [0, 1] reduction = max : memref<3x4xf32> -> memref<3x4xf64>
     // CHECK: return [[valloc]] : memref<3x4xf64>
     return %0 : memref<3x4xf64>
   }
 }
 
 // -----
-mesh.mesh @mesh0(shape = 3x4x5)
+shard.grid @grid0(shape = 3x4x5)
 // CHECK-LABEL: func @update_halo_1d_first
 func.func @update_halo_1d_first(
   // CHECK-SAME: [[arg0:%.*]]: memref<120x120x120xi8>
@@ -155,14 +155,14 @@ func.func @update_halo_1d_first(
   // CHECK: mpi.recv(
   // CHECK-SAME: : memref<3x120x120xi8>, i32, i32
   // CHECK: memref.subview [[arg0]][117, 0, 0] [3, 120, 120] [1, 1, 1] : memref<120x120x120xi8> to memref<3x120x120xi8
-  %res = mesh.update_halo %arg0 on @mesh0 split_axes = [[0]] halo_sizes = [2, 3] : memref<120x120x120xi8>
+  %res = shard.update_halo %arg0 on @grid0 split_axes = [[0]] halo_sizes = [2, 3] : memref<120x120x120xi8>
   // CHECK: return [[res:%.*]] : memref<120x120x120xi8>
   return %res : memref<120x120x120xi8>
 }
 
 // -----
 module attributes { mpi.dlti = #dlti.map<"MPI:comm_world_rank" = 1> } {
-  mesh.mesh @mesh0(shape = 4)
+  shard.grid @grid0(shape = 4)
   // CHECK-LABEL: func @update_halo_1d_with_zero
   func.func @update_halo_1d_with_zero (
     // CHECK-SAME: [[varg0:%.*]]: memref<120x120x120xi8>
@@ -179,7 +179,7 @@ module attributes { mpi.dlti = #dlti.map<"MPI:comm_world_rank" = 1> } {
     // CHECK: [[vsubview_0:%.*]] = memref.subview [[varg0]][0, 0, 0] [2, 120, 120] [1, 1, 1] : memref<120x120x120xi8> to memref<2x120x120xi8, strided<[14400, 120, 1]>>
     // CHECK: memref.copy [[valloc]], [[vsubview_0]] : memref<2x120x120xi8> to memref<2x120x120xi8, strided<[14400, 120, 1]>>
     // CHECK: memref.dealloc [[valloc]] : memref<2x120x120xi8>
-    %res = mesh.update_halo %arg0 on @mesh0 split_axes = [[0]] halo_sizes = [2, 0] : memref<120x120x120xi8>
+    %res = shard.update_halo %arg0 on @grid0 split_axes = [[0]] halo_sizes = [2, 0] : memref<120x120x120xi8>
     // CHECK: return [[varg0]] : memref<120x120x120xi8>
     return %res : memref<120x120x120xi8>
   }
@@ -187,7 +187,7 @@ module attributes { mpi.dlti = #dlti.map<"MPI:comm_world_rank" = 1> } {
 
 // -----
 module attributes { mpi.dlti = #dlti.map<"MPI:comm_world_rank" = 24> } {
-  mesh.mesh @mesh0(shape = 3x4x5)
+  shard.grid @grid0(shape = 3x4x5)
   // CHECK-LABEL: func @update_halo_3d
   func.func @update_halo_3d(
     // CHECK-SAME: [[varg0:%.*]]: memref<120x120x120xi8>
@@ -236,7 +236,7 @@ module attributes { mpi.dlti = #dlti.map<"MPI:comm_world_rank" = 24> } {
     // CHECK: memref.copy [[vsubview_11]], [[valloc_10]] : memref<2x120x120xi8, strided<[14400, 120, 1], offset: 14400>> to memref<2x120x120xi8>
     // CHECK: mpi.send([[valloc_10]], [[vc91_i32]], [[vc23_i32]], [[v2]]) : memref<2x120x120xi8>, i32, i32
     // CHECK: memref.dealloc [[valloc_10]] : memref<2x120x120xi8>
-    %res = mesh.update_halo %arg0 on @mesh0 split_axes = [[2], [1], [0]] halo_sizes = [1, 2, 3, 4, 5, 6] : memref<120x120x120xi8>
+    %res = shard.update_halo %arg0 on @grid0 split_axes = [[2], [1], [0]] halo_sizes = [1, 2, 3, 4, 5, 6] : memref<120x120x120xi8>
     // CHECK: return [[varg0]] : memref<120x120x120xi8>
     return %res : memref<120x120x120xi8>
   }
@@ -291,18 +291,18 @@ module attributes { mpi.dlti = #dlti.map<"MPI:comm_world_rank" = 24> } {
     // CHECK: mpi.send([[valloc_10]], [[vc91_i32]], [[vc23_i32]], [[v3]]) : memref<2x120x120xi8>, i32, i32
     // CHECK: memref.dealloc [[valloc_10]] : memref<2x120x120xi8>
     // CHECK: [[v4:%.*]] = bufferization.to_tensor [[v0]] restrict writable : memref<120x120x120xi8> to tensor<120x120x120xi8>
-    %res = mesh.update_halo %arg0 on @mesh0 split_axes = [[2], [1], [0]] halo_sizes = [1, 2, 3, 4, 5, 6] : tensor<120x120x120xi8>
+    %res = shard.update_halo %arg0 on @grid0 split_axes = [[2], [1], [0]] halo_sizes = [1, 2, 3, 4, 5, 6] : tensor<120x120x120xi8>
     // CHECK: return [[v4]] : tensor<120x120x120xi8>
     return %res : tensor<120x120x120xi8>
   }
 }
 
 // -----
-mesh.mesh @mesh0(shape = 2x2x4)
+shard.grid @grid0(shape = 2x2x4)
 // CHECK-LABEL: func.func @return_sharding(
 // CHECK-SAME: [[varg0:%.*]]: tensor<2x4xf32>) -> (tensor<2x4xf32>, tensor<?x?xi16>, tensor<?x?xi64>, tensor<?x?xi64>) {
-func.func @return_sharding(%arg0: tensor<2x4xf32>) -> (tensor<2x4xf32>, !mesh.sharding) {
-  %sharding = mesh.sharding @mesh0 split_axes = [[0, 1], [2]] : !mesh.sharding
+func.func @return_sharding(%arg0: tensor<2x4xf32>) -> (tensor<2x4xf32>, !shard.sharding) {
+  %sharding = shard.sharding @grid0 split_axes = [[0, 1], [2]] : !shard.sharding
   // CHECK: [[vcst:%.*]] = arith.constant dense<2> : tensor<1xi16>
   // CHECK: [[vcst_0:%.*]] = arith.constant dense<[0, 1]> : tensor<2xi16>
   // CHECK: [[vcm1_i16:%.*]] = arith.constant -1 : i16
@@ -316,13 +316,13 @@ func.func @return_sharding(%arg0: tensor<2x4xf32>) -> (tensor<2x4xf32>, !mesh.sh
   // CHECK: [[vcast_2:%.*]] = tensor.cast [[v2]] : tensor<0x0xi64> to tensor<?x?xi64>
   // CHECK: [[vcast_3:%.*]] = tensor.cast [[v3]] : tensor<0x0xi64> to tensor<?x?xi64>
   // CHECK: return [[varg0]], [[vcast]], [[vcast_2]], [[vcast_3]] : tensor<2x4xf32>, tensor<?x?xi16>, tensor<?x?xi64>, tensor<?x?xi64>
-  return %arg0, %sharding : tensor<2x4xf32>, !mesh.sharding
+  return %arg0, %sharding : tensor<2x4xf32>, !shard.sharding
 }
 
 // CHECK-LABEL: func.func @return_sharding_halos(
 // CHECK-SAME: [[varg0:%.*]]: tensor<6x8xf32>) -> (tensor<6x8xf32>, tensor<?x?xi16>, tensor<?x?xi64>, tensor<?x?xi64>) {
-func.func @return_sharding_halos(%arg0: tensor<6x8xf32>) -> (tensor<6x8xf32>, !mesh.sharding) {
-  %sharding = mesh.sharding @mesh0 split_axes = [[0, 1], [2]] halo_sizes = [0, 4, 3, 1] : !mesh.sharding
+func.func @return_sharding_halos(%arg0: tensor<6x8xf32>) -> (tensor<6x8xf32>, !shard.sharding) {
+  %sharding = shard.sharding @grid0 split_axes = [[0, 1], [2]] halo_sizes = [0, 4, 3, 1] : !shard.sharding
   // CHECK: [[vcst:%.*]] = arith.constant dense<{{\[\[}}0, 4], [3, 1]]> : tensor<2x2xi64>
   // CHECK: [[vcst_0:%.*]] = arith.constant dense<2> : tensor<1xi16>
   // CHECK: [[vcst_1:%.*]] = arith.constant dense<[0, 1]> : tensor<2xi16>
@@ -336,13 +336,13 @@ func.func @return_sharding_halos(%arg0: tensor<6x8xf32>) -> (tensor<6x8xf32>, !m
   // CHECK: [[vcast_3:%.*]] = tensor.cast [[vcst]] : tensor<2x2xi64> to tensor<?x?xi64>
   // CHECK: [[vcast_4:%.*]] = tensor.cast [[v2]] : tensor<0x0xi64> to tensor<?x?xi64>
   // CHECK: return [[varg0]], [[vcast]], [[vcast_3]], [[vcast_4]] : tensor<6x8xf32>, tensor<?x?xi16>, tensor<?x?xi64>, tensor<?x?xi64>
-  return %arg0, %sharding : tensor<6x8xf32>, !mesh.sharding
+  return %arg0, %sharding : tensor<6x8xf32>, !shard.sharding
 }
 
 // CHECK-LABEL: func.func @return_sharding_offs(
 // CHECK-SAME: [[varg0:%.*]]: tensor<?x?xf32>) -> (tensor<?x?xf32>, tensor<?x?xi16>, tensor<?x?xi64>, tensor<?x?xi64>) {
-func.func @return_sharding_offs(%arg0: tensor<?x?xf32>) -> (tensor<?x?xf32>, !mesh.sharding) {
-  %sharding = mesh.sharding @mesh0 split_axes = [[0, 1], [2]] sharded_dims_offsets = [0, 3, 5, 7, 8, 0, 0, 5, 10, 16] : !mesh.sharding
+func.func @return_sharding_offs(%arg0: tensor<?x?xf32>) -> (tensor<?x?xf32>, !shard.sharding) {
+  %sharding = shard.sharding @grid0 split_axes = [[0, 1], [2]] sharded_dims_offsets = [0, 3, 5, 7, 8, 0, 0, 5, 10, 16] : !shard.sharding
   // CHECK: [[vcst:%.*]] = arith.constant dense<[0, 0, 5, 10, 16]> : tensor<5xi64>
   // CHECK: [[vcst_0:%.*]] = arith.constant dense<[0, 3, 5, 7, 8]> : tensor<5xi64>
   // CHECK: [[vcm9223372036854775808_i64:%.*]] = arith.constant -9223372036854775808 : i64
@@ -362,5 +362,5 @@ func.func @return_sharding_offs(%arg0: tensor<?x?xf32>) -> (tensor<?x?xf32>, !me
   // CHECK: [[vcast_6:%.*]] = tensor.cast [[v2]] : tensor<0x0xi64> to tensor<?x?xi64>
   // CHECK: [[vcast_7:%.*]] = tensor.cast [[vinserted_slice_5]] : tensor<2x5xi64> to tensor<?x?xi64>
   // CHECK: return [[varg0]], [[vcast]], [[vcast_6]], [[vcast_7]] : tensor<?x?xf32>, tensor<?x?xi16>, tensor<?x?xi64>, tensor<?x?xi64>
-  return %arg0, %sharding : tensor<?x?xf32>, !mesh.sharding
+  return %arg0, %sharding : tensor<?x?xf32>, !shard.sharding
 }
diff --git a/mlir/test/Conversion/MeshToMPI/convert-shardshape-to-mpi.mlir b/mlir/test/Conversion/ShardToMPI/convert-shardshape-to-mpi.mlir
similarity index 62%
rename from mlir/test/Conversion/MeshToMPI/convert-shardshape-to-mpi.mlir
rename to mlir/test/Conversion/ShardToMPI/convert-shardshape-to-mpi.mlir
index 156bbfb54845b..9729d2bfb384e 100644
--- a/mlir/test/Conversion/MeshToMPI/convert-shardshape-to-mpi.mlir
+++ b/mlir/test/Conversion/ShardToMPI/convert-shardshape-to-mpi.mlir
@@ -1,21 +1,21 @@
-// RUN: mlir-opt %s --convert-mesh-to-mpi -canonicalize | FileCheck %s
+// RUN: mlir-opt %s --convert-shard-to-mpi -canonicalize | FileCheck %s
 
 module attributes { mpi.dlti = #dlti.map<"MPI:comm_world_rank" = 24> } {
 
-  // CHECK: mesh.mesh @mesh0
-  mesh.mesh @mesh0(shape = 3x4x5)
+  // CHECK: shard.grid @grid0
+  shard.grid @grid0(shape = 3x4x5)
   
-  // Notice: comm_world_rank/linear index 24 is multiindex [1, 0, 4] in @mesh0
+  // Note: comm_world_rank (linear index) 24 corresponds to multi-index [1, 0, 4] in @grid0
 
   // all shards are equal
   // CHECK-LABEL: func.func @shard_shape_equal() -> (index, index, index) {
   func.func @shard_shape_equal() -> (index, index, index) {
-    %sharding = mesh.sharding @mesh0 split_axes = [[0], [1], [2]] : !mesh.sharding
-    %0:3 = mesh.process_multi_index on @mesh0 : index, index, index
+    %sharding = shard.sharding @grid0 split_axes = [[0], [1], [2]] : !shard.sharding
+    %0:3 = shard.process_multi_index on @grid0 : index, index, index
     %c9 = arith.constant 9 : index
     %c12 = arith.constant 12 : index
     // CHECK: [[vc3:%.*]] = arith.constant 3 : index
-    %1:3 = mesh.shard_shape dims = [%c9, %c12, 15] sharding = %sharding device = [%0#0, %0#1, %0#2] : index, index, index
+    %1:3 = shard.shard_shape dims = [%c9, %c12, 15] sharding = %sharding device = [%0#0, %0#1, %0#2] : index, index, index
     // CHECK: return [[vc3]], [[vc3]], [[vc3]] : index, index, index
     return %1#0, %1#1, %1#2 : index, index, index
   }
@@ -23,13 +23,13 @@ module attributes { mpi.dlti = #dlti.map<"MPI:comm_world_rank" = 24> } {
   // last shard in last dim gets an extra element
   // CHECK-LABEL: func.func @shard_shape_odd_1() -> (index, index, index) {
   func.func @shard_shape_odd_1() -> (index, index, index) {
-    %sharding = mesh.sharding @mesh0 split_axes = [[0], [1], [2]] : !mesh.sharding
-    %0:3 = mesh.process_multi_index on @mesh0 : index, index, index
+    %sharding = shard.sharding @grid0 split_axes = [[0], [1], [2]] : !shard.sharding
+    %0:3 = shard.process_multi_index on @grid0 : index, index, index
     %c9 = arith.constant 9 : index
     %c12 = arith.constant 12 : index
     // CHECK-DAG: [[vc3:%.*]] = arith.constant 3 : index
     // CHECK-DAG: [[vc4:%.*]] = arith.constant 4 : index
-    %1:3 = mesh.shard_shape dims = [%c9, %c12, 16] sharding = %sharding device = [%0#0, %0#1, %0#2] : index, index, index
+    %1:3 = shard.shard_shape dims = [%c9, %c12, 16] sharding = %sharding device = [%0#0, %0#1, %0#2] : index, index, index
     // CHECK: return [[vc3]], [[vc3]], [[vc4]] : index, index, index
     return %1#0, %1#1, %1#2 : index, index, index
   }
@@ -37,11 +37,11 @@ module attributes { mpi.dlti = #dlti.map<"MPI:comm_world_rank" = 24> } {
   // In the second dimension the shard sizes are now [3 4 4 4]
   // CHECK-LABEL: func.func @shard_shape_odd_2() -> (index, index, index) {
   func.func @shard_shape_odd_2() -> (index, index, index) {
-    %sharding = mesh.sharding @mesh0 split_axes = [[0], [1], [2]] : !mesh.sharding
-    %0:3 = mesh.process_multi_index on @mesh0 : index, index, index
+    %sharding = shard.sharding @grid0 split_axes = [[0], [1], [2]] : !shard.sharding
+    %0:3 = shard.process_multi_index on @grid0 : index, index, index
     %c9 = arith.constant 9 : index
     // CHECK: [[vc3:%.*]] = arith.constant 3 : index
-    %1:3 = mesh.shard_shape dims = [%c9, 15, 15] sharding = %sharding device = [%0#0, %0#1, %0#2] : index, index, index
+    %1:3 = shard.shard_shape dims = [%c9, 15, 15] sharding = %sharding device = [%0#0, %0#1, %0#2] : index, index, index
     // CHECK: return [[vc3]], [[vc3]], [[vc3]] : index, index, index
     return %1#0, %1#1, %1#2 : index, index, index
   }
@@ -49,11 +49,11 @@ module attributes { mpi.dlti = #dlti.map<"MPI:comm_world_rank" = 24> } {
   // In the first dimension the shard sizes are now [3 4 4]
   // CHECK-LABEL: func.func @shard_shape_odd_3() -> (index, index, index) {
   func.func @shard_shape_odd_3() -> (index, index, index) {
-    %sharding = mesh.sharding @mesh0 split_axes = [[0], [1], [2]] : !mesh.sharding
-    %0:3 = mesh.process_multi_index on @mesh0 : index, index, index
+    %sharding = shard.sharding @grid0 split_axes = [[0], [1], [2]] : !shard.sharding
+    %0:3 = shard.process_multi_index on @grid0 : index, index, index
     // CHECK-DAG: [[vc3:%.*]] = arith.constant 3 : index
     // CHECK-DAG: [[vc4:%.*]] = arith.constant 4 : index
-    %1:3 = mesh.shard_shape dims = [11, 12, 15] sharding = %sharding device = [%0#0, %0#1, %0#2] : index, index, index
+    %1:3 = shard.shard_shape dims = [11, 12, 15] sharding = %sharding device = [%0#0, %0#1, %0#2] : index, index, index
     // CHECK: return [[vc4]], [[vc3]], [[vc3]] : index, index, index
     return %1#0, %1#1, %1#2 : index, index, index
   }
@@ -61,14 +61,14 @@ module attributes { mpi.dlti = #dlti.map<"MPI:comm_world_rank" = 24> } {
   // extract from sharded_dims_offsets
   // CHECK-LABEL: func.func @shard_shape_sharded_dims_offs() -> (index, index, index) {
   func.func @shard_shape_sharded_dims_offs() -> (index, index, index) {
-    %sharding = mesh.sharding @mesh0 split_axes = [[0], [1], [2]]
-        sharded_dims_offsets = [0, 1, 4, 9, 0, 2, 6, 12, 12, 0, 3, 6, 9, 12, 15]: !mesh.sharding
-    %0:3 = mesh.process_multi_index on @mesh0 : index, index, index
+    %sharding = shard.sharding @grid0 split_axes = [[0], [1], [2]]
+        sharded_dims_offsets = [0, 1, 4, 9, 0, 2, 6, 12, 12, 0, 3, 6, 9, 12, 15]: !shard.sharding
+    %0:3 = shard.process_multi_index on @grid0 : index, index, index
     %c9 = arith.constant 9 : index
     %c12 = arith.constant 12 : index
     // CHECK: [[vc3:%.*]] = arith.constant 3 : index
     // CHECK: [[vc2:%.*]] = arith.constant 2 : index
-    %1:3 = mesh.shard_shape dims = [%c9, %c12, 15] sharding = %sharding device = [%0#0, %0#1, %0#2] : index, index, index
+    %1:3 = shard.shard_shape dims = [%c9, %c12, 15] sharding = %sharding device = [%0#0, %0#1, %0#2] : index, index, index
     // CHECK: return [[vc3]], [[vc2]], [[vc3]] : index, index, index
     return %1#0, %1#1, %1#2 : index, index, index
   }
diff --git a/mlir/test/Dialect/Arith/mesh-spmdize.mlir b/mlir/test/Dialect/Arith/mesh-spmdize.mlir
deleted file mode 100644
index 6b55dd533a92c..0000000000000
--- a/mlir/test/Dialect/Arith/mesh-spmdize.mlir
+++ /dev/null
@@ -1,17 +0,0 @@
-// RUN: mlir-opt \
-// RUN:   --pass-pipeline="builtin.module(func.func(mesh-spmdization))" \
-// RUN:   %s | FileCheck %s
-
-mesh.mesh @mesh4x4(shape = 4x4)
-
-// CHECK-LABEL: func @test_spmdize_constant
-// CHECK-NEXT: [[vcst:%.*]] = arith.constant dense<0.000000e+00> :
-// tensor<256x1024xf32> CHECK-NEXT: [[vc434_i32:%.*]] = arith.constant 434 :
-// i32 CHECK-NEXT: return [[vcst]] : tensor<256x1024xf32>
-func.func @test_spmdize_constant() ->(tensor<1024x1024xf32>)attributes{llvm.emit_c_interface} {
-  %cst = arith.constant dense<0.000000e+00> : tensor<1024x1024xf32>
-  %sharding_1 = mesh.sharding @mesh4x4 split_axes = [[0]] : !mesh.sharding
-  %sharding_annotated_1 = mesh.shard %cst to %sharding_1 : tensor<1024x1024xf32>
-  %ci = arith.constant 434 : i32
-  return %sharding_annotated_1 : tensor<1024x1024xf32>
-}
diff --git a/mlir/test/Dialect/Arith/shard-partition.mlir b/mlir/test/Dialect/Arith/shard-partition.mlir
new file mode 100644
index 0000000000000..be894278e5e95
--- /dev/null
+++ b/mlir/test/Dialect/Arith/shard-partition.mlir
@@ -0,0 +1,17 @@
+// RUN: mlir-opt \
+// RUN:   --pass-pipeline="builtin.module(func.func(shard-partition))" \
+// RUN:   %s | FileCheck %s
+
+shard.grid @grid4x4(shape = 4x4)
+
+// CHECK-LABEL: func @test_partition_constant
+// CHECK-NEXT: [[vcst:%.*]] = arith.constant dense<0.000000e+00> : tensor<256x1024xf32>
+// CHECK-NEXT: [[vc434_i32:%.*]] = arith.constant 434 : i32
+// CHECK-NEXT: return [[vcst]] : tensor<256x1024xf32>
+func.func @test_partition_constant() -> (tensor<1024x1024xf32>) attributes {llvm.emit_c_interface} {
+  %cst = arith.constant dense<0.000000e+00> : tensor<1024x1024xf32>
+  %sharding_1 = shard.sharding @grid4x4 split_axes = [[0]] : !shard.sharding
+  %sharded_1 = shard.shard %cst to %sharding_1 : tensor<1024x1024xf32>
+  %ci = arith.constant 434 : i32
+  return %sharded_1 : tensor<1024x1024xf32>
+}
diff --git a/mlir/test/Dialect/Arith/sharding-propagation.mlir b/mlir/test/Dialect/Arith/sharding-propagation.mlir
index 19eb340549b0b..762620d9dae0c 100644
--- a/mlir/test/Dialect/Arith/sharding-propagation.mlir
+++ b/mlir/test/Dialect/Arith/sharding-propagation.mlir
@@ -1,54 +1,54 @@
 // RUN: mlir-opt --pass-pipeline="builtin.module(func.func(sharding-propagation))" %s | FileCheck %s
 
-mesh.mesh @mesh4x4(shape = 4x4)
+shard.grid @grid4x4(shape = 4x4)
 
 // CHECK-LABEL: func.func @test_shard_constant() -> tensor<1024x1024xf32> attributes {llvm.emit_c_interface} {
 // CHECK-NEXT: [[vcst:%.*]] = arith.constant dense<0.000000e+00> : tensor<1024x1024xf32>
-// CHECK-NEXT: [[vsharding:%.*]] = mesh.sharding @mesh4x4 split_axes = {{\[\[}}0]] : !mesh.sharding
-// CHECK-NEXT: [[vsharding_annotated:%.*]] = mesh.shard [[vcst]] to [[vsharding]] : tensor<1024x1024xf32>
+// CHECK-NEXT: [[vsharding:%.*]] = shard.sharding @grid4x4 split_axes = {{\[\[}}0]] : !shard.sharding
+// CHECK-NEXT: [[vsharded:%.*]] = shard.shard [[vcst]] to [[vsharding]] : tensor<1024x1024xf32>
 // CHECK-NEXT: [[vcst_0:%.*]] = arith.constant 4.340000e+01 : f32
 // CHECK-NEXT: [[v0:%.*]] = tensor.empty() : tensor<1024x1024xf32>
-// CHECK-NEXT: [[vsharding_1:%.*]] = mesh.sharding @mesh4x4 split_axes = {{\[\[}}0]] : !mesh.sharding
-// CHECK-NEXT: [[vsharding_annotated_2:%.*]] = mesh.shard [[v0]] to [[vsharding_1]] : tensor<1024x1024xf32>
-// CHECK-NEXT: [[vsharding_3:%.*]] = mesh.sharding @mesh4x4 split_axes = {{\[\[}}0]] : !mesh.sharding
-// CHECK-NEXT: [[vsharding_annotated_4:%.*]] = mesh.shard [[vsharding_annotated]] to [[vsharding_3]] annotate_for_users : tensor<1024x1024xf32>
-// CHECK-NEXT: [[vsharding_5:%.*]] = mesh.sharding @mesh4x4 split_axes = {{\[\[}}0]] : !mesh.sharding
-// CHECK-NEXT: [[vsharding_annotated_6:%.*]] = mesh.shard [[vsharding_annotated_2]] to [[vsharding_5]] annotate_for_users : tensor<1024x1024xf32>
-// CHECK-NEXT: [[v1:%.*]] = linalg.add ins([[vsharding_annotated_4]], [[vcst_0]] : tensor<1024x1024xf32>, f32) outs([[vsharding_annotated_6]] : tensor<1024x1024xf32>) -> tensor<1024x1024xf32>
-// CHECK-NEXT: [[vsharding_7:%.*]] = mesh.sharding @mesh4x4 split_axes = {{\[\[}}0]] : !mesh.sharding
-// CHECK-NEXT: [[vsharding_annotated_8:%.*]] = mesh.shard [[v1]] to [[vsharding_7]] : tensor<1024x1024xf32>
-// CHECK-NEXT: return [[vsharding_annotated_8]] : tensor<1024x1024xf32>
+// CHECK-NEXT: [[vsharding_1:%.*]] = shard.sharding @grid4x4 split_axes = {{\[\[}}0]] : !shard.sharding
+// CHECK-NEXT: [[vsharded_2:%.*]] = shard.shard [[v0]] to [[vsharding_1]] : tensor<1024x1024xf32>
+// CHECK-NEXT: [[vsharding_3:%.*]] = shard.sharding @grid4x4 split_axes = {{\[\[}}0]] : !shard.sharding
+// CHECK-NEXT: [[vsharded_4:%.*]] = shard.shard [[vsharded]] to [[vsharding_3]] annotate_for_users : tensor<1024x1024xf32>
+// CHECK-NEXT: [[vsharding_5:%.*]] = shard.sharding @grid4x4 split_axes = {{\[\[}}0]] : !shard.sharding
+// CHECK-NEXT: [[vsharded_6:%.*]] = shard.shard [[vsharded_2]] to [[vsharding_5]] annotate_for_users : tensor<1024x1024xf32>
+// CHECK-NEXT: [[v1:%.*]] = linalg.add ins([[vsharded_4]], [[vcst_0]] : tensor<1024x1024xf32>, f32) outs([[vsharded_6]] : tensor<1024x1024xf32>) -> tensor<1024x1024xf32>
+// CHECK-NEXT: [[vsharding_7:%.*]] = shard.sharding @grid4x4 split_axes = {{\[\[}}0]] : !shard.sharding
+// CHECK-NEXT: [[vsharded_8:%.*]] = shard.shard [[v1]] to [[vsharding_7]] : tensor<1024x1024xf32>
+// CHECK-NEXT: return [[vsharded_8]] : tensor<1024x1024xf32>
 func.func @test_shard_constant() -> (tensor<1024x1024xf32>) attributes {llvm.emit_c_interface} {
     %cst_1 = arith.constant dense<0.000000e+00> : tensor<1024x1024xf32>
-    %sharding_1 = mesh.sharding @mesh4x4 split_axes = [[0]] : !mesh.sharding
-    %sharding_annotated_1 = mesh.shard %cst_1 to %sharding_1 : tensor<1024x1024xf32>
+    %sharding_1 = shard.sharding @grid4x4 split_axes = [[0]] : !shard.sharding
+    %sharded_1 = shard.shard %cst_1 to %sharding_1 : tensor<1024x1024xf32>
     %ci = arith.constant 43.4e+00 : f32
     %o1 = tensor.empty() : tensor<1024x1024xf32>
-    %res = linalg.add ins(%sharding_annotated_1, %ci : tensor<1024x1024xf32>, f32) outs(%o1 : tensor<1024x1024xf32>) -> tensor<1024x1024xf32>
+    %res = linalg.add ins(%sharded_1, %ci : tensor<1024x1024xf32>, f32) outs(%o1 : tensor<1024x1024xf32>) -> tensor<1024x1024xf32>
     return %res : tensor<1024x1024xf32>
 }
 
 // CHECK-LABEL: func.func @test_shard_constant_back() -> tensor<1024x1024xf32> attributes {llvm.emit_c_interface} {
 // CHECK-NEXT: [[vcst:%.*]] = arith.constant dense<0.000000e+00> : tensor<1024x1024xf32>
-// CHECK-NEXT: [[vsharding:%.*]] = mesh.sharding @mesh4x4 split_axes = {{\[\[}}0]] : !mesh.sharding
-// CHECK-NEXT: [[vsharding_annotated:%.*]] = mesh.shard [[vcst]] to [[vsharding]] : tensor<1024x1024xf32>
+// CHECK-NEXT: [[vsharding:%.*]] = shard.sharding @grid4x4 split_axes = {{\[\[}}0]] : !shard.sharding
+// CHECK-NEXT: [[vsharded:%.*]] = shard.shard [[vcst]] to [[vsharding]] : tensor<1024x1024xf32>
 // CHECK-NEXT: [[vcst_0:%.*]] = arith.constant 4.340000e+01 : f32
 // CHECK-NEXT: [[v0:%.*]] = tensor.empty() : tensor<1024x1024xf32>
-// CHECK-NEXT: [[vsharding_1:%.*]] = mesh.sharding @mesh4x4 split_axes = {{\[\[}}0]] : !mesh.sharding
-// CHECK-NEXT: [[vsharding_annotated_2:%.*]] = mesh.shard [[v0]] to [[vsharding_1]] : tensor<1024x1024xf32>
-// CHECK-NEXT: [[vsharding_3:%.*]] = mesh.sharding @mesh4x4 split_axes = {{\[\[}}0]] : !mesh.sharding
-// CHECK-NEXT: [[vsharding_annotated_4:%.*]] = mesh.shard [[vsharding_annotated]] to [[vsharding_3]] annotate_for_users : tensor<1024x1024xf32>
-// CHECK-NEXT: [[vsharding_5:%.*]] = mesh.sharding @mesh4x4 split_axes = {{\[\[}}0]] : !mesh.sharding
-// CHECK-NEXT: [[vsharding_annotated_6:%.*]] = mesh.shard [[vsharding_annotated_2]] to [[vsharding_5]] annotate_for_users : tensor<1024x1024xf32>
-// CHECK-NEXT: [[v1:%.*]] = linalg.add ins([[vsharding_annotated_4]], [[vcst_0]] : tensor<1024x1024xf32>, f32) outs([[vsharding_annotated_6]] : tensor<1024x1024xf32>) -> tensor<1024x1024xf32>
-// CHECK-NEXT: [[vsharding_7:%.*]] = mesh.sharding @mesh4x4 split_axes = {{\[\[}}0]] : !mesh.sharding
-// CHECK-NEXT: [[vsharding_annotated_8:%.*]] = mesh.shard [[v1]] to [[vsharding_7]] : tensor<1024x1024xf32>
+// CHECK-NEXT: [[vsharding_1:%.*]] = shard.sharding @grid4x4 split_axes = {{\[\[}}0]] : !shard.sharding
+// CHECK-NEXT: [[vsharded_2:%.*]] = shard.shard [[v0]] to [[vsharding_1]] : tensor<1024x1024xf32>
+// CHECK-NEXT: [[vsharding_3:%.*]] = shard.sharding @grid4x4 split_axes = {{\[\[}}0]] : !shard.sharding
+// CHECK-NEXT: [[vsharded_4:%.*]] = shard.shard [[vsharded]] to [[vsharding_3]] annotate_for_users : tensor<1024x1024xf32>
+// CHECK-NEXT: [[vsharding_5:%.*]] = shard.sharding @grid4x4 split_axes = {{\[\[}}0]] : !shard.sharding
+// CHECK-NEXT: [[vsharded_6:%.*]] = shard.shard [[vsharded_2]] to [[vsharding_5]] annotate_for_users : tensor<1024x1024xf32>
+// CHECK-NEXT: [[v1:%.*]] = linalg.add ins([[vsharded_4]], [[vcst_0]] : tensor<1024x1024xf32>, f32) outs([[vsharded_6]] : tensor<1024x1024xf32>) -> tensor<1024x1024xf32>
+// CHECK-NEXT: [[vsharding_7:%.*]] = shard.sharding @grid4x4 split_axes = {{\[\[}}0]] : !shard.sharding
+// CHECK-NEXT: [[vsharded_8:%.*]] = shard.shard [[v1]] to [[vsharding_7]] : tensor<1024x1024xf32>
 func.func @test_shard_constant_back() -> (tensor<1024x1024xf32>) attributes {llvm.emit_c_interface} {
     %cst_1 = arith.constant dense<0.000000e+00> : tensor<1024x1024xf32>
     %ci = arith.constant 43.4e+00 : f32
     %o1 = tensor.empty() : tensor<1024x1024xf32>
     %res = linalg.add ins(%cst_1, %ci : tensor<1024x1024xf32>, f32) outs(%o1 : tensor<1024x1024xf32>) -> tensor<1024x1024xf32>
-    %sharding_1 = mesh.sharding @mesh4x4 split_axes = [[0]] : !mesh.sharding
-    %sharding_annotated_1 = mesh.shard %res to %sharding_1 : tensor<1024x1024xf32>
-    return %sharding_annotated_1 : tensor<1024x1024xf32>
+    %sharding_1 = shard.sharding @grid4x4 split_axes = [[0]] : !shard.sharding
+    %sharded_1 = shard.shard %res to %sharding_1 : tensor<1024x1024xf32>
+    return %sharded_1 : tensor<1024x1024xf32>
 }
diff --git a/mlir/test/Dialect/Linalg/mesh-sharding-propagation.mlir b/mlir/test/Dialect/Linalg/mesh-sharding-propagation.mlir
deleted file mode 100644
index 5297eeb666c1e..0000000000000
--- a/mlir/test/Dialect/Linalg/mesh-sharding-propagation.mlir
+++ /dev/null
@@ -1,42 +0,0 @@
-// RUN: mlir-opt \
-// RUN:   --verify-each \
-// RUN:   --pass-pipeline="builtin.module(func.func(sharding-propagation))" \
-// RUN:   %s | FileCheck %s
-
-mesh.mesh @mesh_2(shape = 2)
-
-// CHECK-LABEL: func @matmul_shard_prallel_axis
-func.func @matmul_shard_prallel_axis(
-  // CHECK-SAME: %[[IN1:[A-Za-z0-9_]+]]: tensor<2x3xf32>,
-  %arg0 : tensor<2x3xf32>,
-  // CHECK-SAME: %[[IN2:[A-Za-z0-9_]+]]: tensor<3x2xf32>,
-  %arg1 : tensor<3x2xf32>,
-  // CHECK-SAME: %[[DPS_OUT:[A-Za-z0-9_]+]]: tensor<2x2xf32>
-  %out_dps: tensor<2x2xf32>
-) -> tensor<2x2xf32> {
-  // CHECK: %[[SIN1_ANNOTATED_0:.*]] = mesh.sharding @mesh_2 split_axes = {{\[}}[0]] : !mesh.sharding
-  // CHECK-NEXT: %[[IN1_ANNOTATED_0:.*]] = mesh.shard %[[IN1]] to %[[SIN1_ANNOTATED_0]] : tensor<2x3xf32>
-  // CHECK: %[[SIN1_ANNOTATED_1:.*]] = mesh.sharding @mesh_2 split_axes = {{\[}}[0]] : !mesh.sharding
-  // CHECK-NEXT: %[[IN1_ANNOTATED_1:.*]] = mesh.shard %[[IN1_ANNOTATED_0]] to %[[SIN1_ANNOTATED_1]] annotate_for_users : tensor<2x3xf32>
-  // CHECK: %[[SIN2_ANNOTATED:.*]] = mesh.sharding @mesh_2 split_axes = {{\[}}[]] : !mesh.sharding
-  // CHECK-NEXT: %[[IN2_ANNOTATED:.*]] = mesh.shard %[[IN2]] to %[[SIN2_ANNOTATED]] annotate_for_users : tensor<3x2xf32>
-  // CHECK: %[[SDPS_OUT_ANNOTATED:.*]] = mesh.sharding @mesh_2 split_axes = {{\[}}[0]] : !mesh.sharding
-  // CHECK-NEXT: %[[DPS_OUT_ANNOTATED:.*]] = mesh.shard %[[DPS_OUT]] to %[[SDPS_OUT_ANNOTATED]] annotate_for_users : tensor<2x2xf32>
-  %sarg0_sharded = mesh.sharding @mesh_2 split_axes = [[0]] : !mesh.sharding
-  %arg0_sharded = mesh.shard %arg0 to %sarg0_sharded : tensor<2x3xf32>
-
-  // CHECK: %[[RES:.*]] = linalg.matmul ins(%[[IN1_ANNOTATED_1]], %[[IN2_ANNOTATED]] : tensor<2x3xf32>, tensor<3x2xf32>)
-  // CHECK-SAME:  outs(%[[DPS_OUT_ANNOTATED]] : tensor<2x2xf32>) -> tensor<2x2xf32>
-  %res = linalg.matmul ins(%arg0_sharded, %arg1 : tensor<2x3xf32>, tensor<3x2xf32>)
-    outs(%out_dps : tensor<2x2xf32>) -> tensor<2x2xf32>
-
-  // CHECK: %[[SRES_ANNOTATED_0:.*]] = mesh.sharding @mesh_2 split_axes = {{\[}}[0]] : !mesh.sharding
-  // CHECK-NEXT: %[[RES_ANNOTATED_0:.*]] = mesh.shard %[[RES]] to %[[SRES_ANNOTATED_0]] : tensor<2x2xf32>
-  // CHECK: %[[SRES_ANNOTATED_1:.*]] = mesh.sharding @mesh_2 split_axes = {{\[}}[]] : !mesh.sharding
-  // CHECK-NEXT: %[[RES_ANNOTATED_1:.*]] = mesh.shard %[[RES_ANNOTATED_0]] to %[[SRES_ANNOTATED_1]] annotate_for_users : tensor<2x2xf32>
-  %sres_sharded = mesh.sharding @mesh_2 split_axes = [[]] : !mesh.sharding
-  %res_sharded = mesh.shard %res to %sres_sharded annotate_for_users : tensor<2x2xf32>
-
-  // CHECK: return %[[RES_ANNOTATED_1]] : tensor<2x2xf32>
-  return %res_sharded : tensor<2x2xf32>
-}
diff --git a/mlir/test/Dialect/Linalg/mesh-spmdization.mlir b/mlir/test/Dialect/Linalg/shard-partition.mlir
similarity index 50%
rename from mlir/test/Dialect/Linalg/mesh-spmdization.mlir
rename to mlir/test/Dialect/Linalg/shard-partition.mlir
index ce12b296df1fa..aee97079fb197 100644
--- a/mlir/test/Dialect/Linalg/mesh-spmdization.mlir
+++ b/mlir/test/Dialect/Linalg/shard-partition.mlir
@@ -1,15 +1,15 @@
 // RUN: mlir-opt \
-// RUN:  --pass-pipeline="builtin.module(func.func(mesh-spmdization,test-single-fold))" \
+// RUN:  --pass-pipeline="builtin.module(func.func(shard-partition,test-single-fold))" \
 // RUN:  --split-input-file \
 // RUN:  %s | FileCheck %s
 
 // CHECK: #[[$MAP_IDENTITY_1D:.*]] = affine_map<(d0) -> (d0)>
 #map_identity_1d = affine_map<(d0) -> (d0)>
 
-mesh.mesh @mesh_1d(shape = 2)
+shard.grid @grid_1d(shape = 2)
 
-// CHECK-LABEL: func @elementwise_static_1d_mesh_static_1d_tensor
-func.func @elementwise_static_1d_mesh_static_1d_tensor(
+// CHECK-LABEL: func @elementwise_static_1d_grid_static_1d_tensor
+func.func @elementwise_static_1d_grid_static_1d_tensor(
   // CHECK-SAME: %[[IN1:[A-Za-z0-9_]+]]: tensor<1xi8>,
   %in1: tensor<2xi8>,
   // CHECK-SAME: %[[IN2:[A-Za-z0-9_]+]]: tensor<1xi8>,
@@ -18,13 +18,13 @@ func.func @elementwise_static_1d_mesh_static_1d_tensor(
   %dps_out: tensor<2xi8>
 // CHECK-SAME: -> tensor<1xi8> {
 ) -> tensor<2xi8> {
-  %sharding = mesh.sharding @mesh_1d split_axes = [[0]]  : !mesh.sharding
-  %in1_sharded1 = mesh.shard %in1 to %sharding  : tensor<2xi8>
-  %in1_sharded2 = mesh.shard %in1_sharded1 to %sharding annotate_for_users : tensor<2xi8>
-  %in2_sharded1 = mesh.shard %in2 to %sharding : tensor<2xi8>
-  %in2_sharded2 = mesh.shard %in2_sharded1 to %sharding annotate_for_users : tensor<2xi8>
-  %dps_out_sharded1 = mesh.shard %dps_out to %sharding : tensor<2xi8>
-  %dps_out_shared2 = mesh.shard %dps_out_sharded1 to %sharding annotate_for_users : tensor<2xi8>
+  %sharding = shard.sharding @grid_1d split_axes = [[0]]  : !shard.sharding
+  %in1_sharded1 = shard.shard %in1 to %sharding  : tensor<2xi8>
+  %in1_sharded2 = shard.shard %in1_sharded1 to %sharding annotate_for_users : tensor<2xi8>
+  %in2_sharded1 = shard.shard %in2 to %sharding : tensor<2xi8>
+  %in2_sharded2 = shard.shard %in2_sharded1 to %sharding annotate_for_users : tensor<2xi8>
+  %dps_out_sharded1 = shard.shard %dps_out to %sharding : tensor<2xi8>
+  %dps_out_shared2 = shard.shard %dps_out_sharded1 to %sharding annotate_for_users : tensor<2xi8>
   // CHECK: %[[RES:.*]] = linalg.generic {
   // CHECK-SAME: indexing_maps = [#[[$MAP_IDENTITY_1D]], #[[$MAP_IDENTITY_1D]], #[[$MAP_IDENTITY_1D]]],
   // CHECK-SAME: iterator_types = ["parallel"]}
@@ -39,18 +39,18 @@ func.func @elementwise_static_1d_mesh_static_1d_tensor(
       %res_scalar = arith.muli %in1_scalar, %in2_scalar : i8
       linalg.yield %res_scalar : i8
     } -> tensor<2xi8>
-  %res_sharded1 = mesh.shard %res to %sharding : tensor<2xi8>
-  %res_shared2 = mesh.shard %res_sharded1 to %sharding annotate_for_users : tensor<2xi8>
+  %res_sharded1 = shard.shard %res to %sharding : tensor<2xi8>
+  %res_shared2 = shard.shard %res_sharded1 to %sharding annotate_for_users : tensor<2xi8>
   // CHECK: return %[[RES]] : tensor<1xi8>
   return %res_shared2 : tensor<2xi8>
 }
 
 // -----
 
-mesh.mesh @mesh_1d(shape = 4)
+shard.grid @grid_1d(shape = 4)
 
-// CHECK-LABEL: func @matmul_1d_mesh_static_tensors_parallel_iterator_sharding
-func.func @matmul_1d_mesh_static_tensors_parallel_iterator_sharding(
+// CHECK-LABEL: func @matmul_1d_grid_static_tensors_parallel_iterator_sharding
+func.func @matmul_1d_grid_static_tensors_parallel_iterator_sharding(
   // CHECK-SAME: %[[IN1:[A-Za-z0-9_]+]]: tensor<1x3xi8>,
   %in1: tensor<4x3xi8>,
 // CHECK-SAME: %[[IN2:[A-Za-z0-9_]+]]: tensor<3x8xi8>,
@@ -59,32 +59,32 @@ func.func @matmul_1d_mesh_static_tensors_parallel_iterator_sharding(
   %dps_out: tensor<4x8xi8>
 // CHECK-SAME: -> tensor<1x8xi8> {
 ) -> tensor<4x8xi8> {
-  %sharding = mesh.sharding @mesh_1d split_axes = [[0]] : !mesh.sharding
-  %in1_shared1 = mesh.shard %in1 to %sharding : tensor<4x3xi8>
-  %in1_shared2 = mesh.shard %in1_shared1 to %sharding annotate_for_users : tensor<4x3xi8>
-  %sharding2 = mesh.sharding @mesh_1d split_axes = [[]] : !mesh.sharding
-  %in2_shared1 = mesh.shard %in2 to %sharding2 : tensor<3x8xi8>
-  %in2_shared2 = mesh.shard %in2_shared1 to %sharding2 annotate_for_users : tensor<3x8xi8>
-  %dps_out_shared1 = mesh.shard %dps_out to %sharding : tensor<4x8xi8>
-  %dps_out_shared2 = mesh.shard %dps_out_shared1 to %sharding annotate_for_users : tensor<4x8xi8>
+  %sharding = shard.sharding @grid_1d split_axes = [[0]] : !shard.sharding
+  %in1_shared1 = shard.shard %in1 to %sharding : tensor<4x3xi8>
+  %in1_shared2 = shard.shard %in1_shared1 to %sharding annotate_for_users : tensor<4x3xi8>
+  %sharding2 = shard.sharding @grid_1d split_axes = [[]] : !shard.sharding
+  %in2_shared1 = shard.shard %in2 to %sharding2 : tensor<3x8xi8>
+  %in2_shared2 = shard.shard %in2_shared1 to %sharding2 annotate_for_users : tensor<3x8xi8>
+  %dps_out_shared1 = shard.shard %dps_out to %sharding : tensor<4x8xi8>
+  %dps_out_shared2 = shard.shard %dps_out_shared1 to %sharding annotate_for_users : tensor<4x8xi8>
   // CHECK: %[[RES:.*]] = linalg.matmul
   // CHECK-SAME: ins(%[[IN1]], %[[IN2]] : tensor<1x3xi8>, tensor<3x8xi8>)
   // CHECK-SAME: outs(%[[DPS_OUT]] : tensor<1x8xi8>)
   // CHECK-SAME: -> tensor<1x8xi8>
   %res = linalg.matmul ins(%in1_shared2, %in2_shared2 : tensor<4x3xi8>, tensor<3x8xi8>)
       outs(%dps_out_shared2 : tensor<4x8xi8>) -> tensor<4x8xi8>
-  %res_shared1 = mesh.shard %res to %sharding : tensor<4x8xi8>
-  %res_shared2 = mesh.shard %res_shared1 to %sharding annotate_for_users : tensor<4x8xi8>
+  %res_shared1 = shard.shard %res to %sharding : tensor<4x8xi8>
+  %res_shared2 = shard.shard %res_shared1 to %sharding annotate_for_users : tensor<4x8xi8>
   // CHECK: return %[[RES]] : tensor<1x8xi8>
   return %res_shared2 : tensor<4x8xi8>
 }
 
 // -----
 
-mesh.mesh @mesh_1d(shape = 3)
+shard.grid @grid_1d(shape = 3)
 
-// CHECK-LABEL: func @matmul_1d_mesh_static_tensors_reduction_iterator_sharding
-func.func @matmul_1d_mesh_static_tensors_reduction_iterator_sharding(
+// CHECK-LABEL: func @matmul_1d_grid_static_tensors_reduction_iterator_sharding
+func.func @matmul_1d_grid_static_tensors_reduction_iterator_sharding(
   // CHECK-SAME: %[[IN1:[A-Za-z0-9_]+]]: tensor<4x2xi8>,
   %in1: tensor<4x6xi8>,
 // CHECK-SAME: %[[IN2:[A-Za-z0-9_]+]]: tensor<2x8xi8>,
@@ -93,19 +93,19 @@ func.func @matmul_1d_mesh_static_tensors_reduction_iterator_sharding(
   %dps_out: tensor<4x8xi8>
 // CHECK-SAME: -> tensor<4x8xi8> {
 ) -> tensor<4x8xi8> {
-  %sharding = mesh.sharding @mesh_1d split_axes = [[], [0]] : !mesh.sharding
-  %in1_shared1 = mesh.shard %in1 to %sharding : tensor<4x6xi8>
-  %in1_shared2 = mesh.shard %in1_shared1 to %sharding annotate_for_users : tensor<4x6xi8>
-  %sharding2 = mesh.sharding @mesh_1d split_axes = [[0]] : !mesh.sharding
-  %in2_shared1 = mesh.shard %in2 to %sharding2 : tensor<6x8xi8>
-  %in2_shared2 = mesh.shard %in2_shared1 to %sharding2 annotate_for_users : tensor<6x8xi8>
-  %sharding3 = mesh.sharding @mesh_1d split_axes = [[]] : !mesh.sharding
-  %dps_out_shared1 = mesh.shard %dps_out to %sharding3 : tensor<4x8xi8>
-  %dps_out_shared2 = mesh.shard %dps_out_shared1 to %sharding3 annotate_for_users : tensor<4x8xi8>
+  %sharding = shard.sharding @grid_1d split_axes = [[], [0]] : !shard.sharding
+  %in1_shared1 = shard.shard %in1 to %sharding : tensor<4x6xi8>
+  %in1_shared2 = shard.shard %in1_shared1 to %sharding annotate_for_users : tensor<4x6xi8>
+  %sharding2 = shard.sharding @grid_1d split_axes = [[0]] : !shard.sharding
+  %in2_shared1 = shard.shard %in2 to %sharding2 : tensor<6x8xi8>
+  %in2_shared2 = shard.shard %in2_shared1 to %sharding2 annotate_for_users : tensor<6x8xi8>
+  %sharding3 = shard.sharding @grid_1d split_axes = [[]] : !shard.sharding
+  %dps_out_shared1 = shard.shard %dps_out to %sharding3 : tensor<4x8xi8>
+  %dps_out_shared2 = shard.shard %dps_out_shared1 to %sharding3 annotate_for_users : tensor<4x8xi8>
   // CHECK-DAG:  %[[C0:.*]] = arith.constant 0 : index
   // CHECK-DAG:  %[[C0_I8:.*]] = arith.constant 0 : i8
-  // CHECK-DAG:  %[[PROCESS_IDX:.*]] = mesh.process_multi_index on @mesh_1d axes = [0] : index
-  // CHECK-DAG:  %[[MESH_SIZE:.*]] = mesh.mesh_shape @mesh_1d axes = [0] : index
+  // CHECK-DAG:  %[[PROCESS_IDX:.*]] = shard.process_multi_index on @grid_1d axes = [0] : index
+  // CHECK-DAG:  %[[SHARD_SIZE:.*]] = shard.grid_shape @grid_1d axes = [0] : index
   // CHECK:      %[[DPS_INIT_OPERAND_CONDITION:.*]] = arith.cmpi eq, %[[PROCESS_IDX]], %[[C0]] : index
   // CHECK:      %[[DPS_INIT_OPERAND:.*]] = scf.if %[[DPS_INIT_OPERAND_CONDITION]] -> (tensor<4x8xi8>) {
   // CHECK:        scf.yield %[[DPS_OUT]] : tensor<4x8xi8>
@@ -117,21 +117,21 @@ func.func @matmul_1d_mesh_static_tensors_reduction_iterator_sharding(
   // CHECK:      }
   // CHECK:      %[[SHARDED_MATMUL:.*]] = linalg.matmul ins(%[[IN1]], %[[IN2]] : tensor<4x2xi8>, tensor<2x8xi8>)
   // CHECK-SAME:     outs(%[[DPS_INIT_OPERAND]] : tensor<4x8xi8>) -> tensor<4x8xi8>
-  // CHECK:      %[[ALL_REDUCED:.*]] = mesh.all_reduce %[[SHARDED_MATMUL]] on @mesh_1d mesh_axes = [0] : tensor<4x8xi8> -> tensor<4x8xi8>
+  // CHECK:      %[[ALL_REDUCED:.*]] = shard.all_reduce %[[SHARDED_MATMUL]] on @grid_1d grid_axes = [0] : tensor<4x8xi8> -> tensor<4x8xi8>
   %res = linalg.matmul ins(%in1_shared2, %in2_shared2 : tensor<4x6xi8>, tensor<6x8xi8>)
       outs(%dps_out_shared2 : tensor<4x8xi8>) -> tensor<4x8xi8>
-  %res_shared1 = mesh.shard %res to %sharding3 : tensor<4x8xi8>
-  %res_shared2 = mesh.shard %res_shared1 to %sharding3 annotate_for_users : tensor<4x8xi8>
+  %res_shared1 = shard.shard %res to %sharding3 : tensor<4x8xi8>
+  %res_shared2 = shard.shard %res_shared1 to %sharding3 annotate_for_users : tensor<4x8xi8>
   // CHECK:      return %[[ALL_REDUCED]] : tensor<4x8xi8>
   return %res_shared2 : tensor<4x8xi8>
 }
 
 // -----
 
-mesh.mesh @mesh_1d(shape = 4)
+shard.grid @grid_1d(shape = 4)
 
-// CHECK-LABEL: func @matmul_1d_mesh_static_tensors_parallel_iterator_unsplit_last_axis
-func.func @matmul_1d_mesh_static_tensors_parallel_iterator_unsplit_last_axis(
+// CHECK-LABEL: func @matmul_1d_grid_static_tensors_parallel_iterator_unsplit_last_axis
+func.func @matmul_1d_grid_static_tensors_parallel_iterator_unsplit_last_axis(
   // CHECK-SAME: %[[IN1:[A-Za-z0-9_]+]]: tensor<4x6xi8>,
   %in1: tensor<4x6xi8>,
   // CHECK-SAME: %[[IN2:[A-Za-z0-9_]+]]: tensor<6x8xi8>,
@@ -140,25 +140,25 @@ func.func @matmul_1d_mesh_static_tensors_parallel_iterator_unsplit_last_axis(
   %dps_out: tensor<4x8xi8>
   // CHECK-SAME: -> tensor<4x8xi8> {
 ) -> tensor<4x8xi8> {
-  %sharding1 = mesh.sharding @mesh_1d split_axes = [[], []] : !mesh.sharding
-  %in1_replicated1 = mesh.shard %in1 to %sharding1 : tensor<4x6xi8>
-  %in1_replicated2 = mesh.shard %in1_replicated1 to %sharding1 annotate_for_users : tensor<4x6xi8>
-  // CHECK: %[[ALL_SLICE1:.*]] = mesh.all_slice %[[IN2]] on @mesh_1d mesh_axes = [0] slice_axis = 1
-  %in2_replicated = mesh.shard %in2 to %sharding1 : tensor<6x8xi8>
-  %sharding2 = mesh.sharding @mesh_1d split_axes = [[], [0]] : !mesh.sharding
-  %in2_sharded = mesh.shard %in2_replicated to %sharding2 annotate_for_users : tensor<6x8xi8>
-  // CHECK: %[[ALL_SLICE2:.*]] = mesh.all_slice %[[DPS_OUT]] on @mesh_1d mesh_axes = [0] slice_axis = 1
-  %dps_out_replicated = mesh.shard %dps_out to %sharding1 : tensor<4x8xi8>
-  %dps_out_sharded = mesh.shard %dps_out_replicated to %sharding2 annotate_for_users : tensor<4x8xi8>
+  %sharding1 = shard.sharding @grid_1d split_axes = [[], []] : !shard.sharding
+  %in1_replicated1 = shard.shard %in1 to %sharding1 : tensor<4x6xi8>
+  %in1_replicated2 = shard.shard %in1_replicated1 to %sharding1 annotate_for_users : tensor<4x6xi8>
+  // CHECK: %[[ALL_SLICE1:.*]] = shard.all_slice %[[IN2]] on @grid_1d grid_axes = [0] slice_axis = 1
+  %in2_replicated = shard.shard %in2 to %sharding1 : tensor<6x8xi8>
+  %sharding2 = shard.sharding @grid_1d split_axes = [[], [0]] : !shard.sharding
+  %in2_sharded = shard.shard %in2_replicated to %sharding2 annotate_for_users : tensor<6x8xi8>
+  // CHECK: %[[ALL_SLICE2:.*]] = shard.all_slice %[[DPS_OUT]] on @grid_1d grid_axes = [0] slice_axis = 1
+  %dps_out_replicated = shard.shard %dps_out to %sharding1 : tensor<4x8xi8>
+  %dps_out_sharded = shard.shard %dps_out_replicated to %sharding2 annotate_for_users : tensor<4x8xi8>
   // CHECK: %[[MATMUL_RES:.*]] = linalg.matmul
   // CHECK-SAME: ins(%[[IN1]], %[[ALL_SLICE1]] : tensor<4x6xi8>, tensor<6x2xi8>)
   // CHECK-SAME: outs(%[[ALL_SLICE2]] : tensor<4x2xi8>)
   // CHECK-SAME: -> tensor<4x2xi8>
   %res = linalg.matmul ins(%in1_replicated2, %in2_sharded : tensor<4x6xi8>, tensor<6x8xi8>)
       outs(%dps_out_sharded : tensor<4x8xi8>) -> tensor<4x8xi8>
-  // CHECK: %[[ALL_GATHER:.*]] = mesh.all_gather %[[MATMUL_RES]] on @mesh_1d mesh_axes = [0] gather_axis = 1 : tensor<4x2xi8> -> tensor<4x8xi8>
-  %res_sharded = mesh.shard %res to %sharding2 : tensor<4x8xi8>
-  %res_replicated = mesh.shard %res_sharded to %sharding1 annotate_for_users : tensor<4x8xi8>
+  // CHECK: %[[ALL_GATHER:.*]] = shard.all_gather %[[MATMUL_RES]] on @grid_1d grid_axes = [0] gather_axis = 1 : tensor<4x2xi8> -> tensor<4x8xi8>
+  %res_sharded = shard.shard %res to %sharding2 : tensor<4x8xi8>
+  %res_replicated = shard.shard %res_sharded to %sharding1 annotate_for_users : tensor<4x8xi8>
   // CHECK: return %[[ALL_GATHER]] : tensor<4x8xi8>
   return %res_replicated : tensor<4x8xi8>
 }
diff --git a/mlir/test/Dialect/Linalg/sharding-propagation.mlir b/mlir/test/Dialect/Linalg/sharding-propagation.mlir
new file mode 100644
index 0000000000000..e0ecefcf2d6bd
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/sharding-propagation.mlir
@@ -0,0 +1,42 @@
+// RUN: mlir-opt \
+// RUN:   --verify-each \
+// RUN:   --pass-pipeline="builtin.module(func.func(sharding-propagation))" \
+// RUN:   %s | FileCheck %s
+
+shard.grid @grid_2(shape = 2)
+
+// CHECK-LABEL: func @matmul_shard_parallel_axis
+func.func @matmul_shard_parallel_axis(
+  // CHECK-SAME: %[[IN1:[A-Za-z0-9_]+]]: tensor<2x3xf32>,
+  %arg0 : tensor<2x3xf32>,
+  // CHECK-SAME: %[[IN2:[A-Za-z0-9_]+]]: tensor<3x2xf32>,
+  %arg1 : tensor<3x2xf32>,
+  // CHECK-SAME: %[[DPS_OUT:[A-Za-z0-9_]+]]: tensor<2x2xf32>
+  %out_dps: tensor<2x2xf32>
+) -> tensor<2x2xf32> {
+  // CHECK: %[[SIN1_ANNOTATED_0:.*]] = shard.sharding @grid_2 split_axes = {{\[}}[0]] : !shard.sharding
+  // CHECK-NEXT: %[[IN1_ANNOTATED_0:.*]] = shard.shard %[[IN1]] to %[[SIN1_ANNOTATED_0]] : tensor<2x3xf32>
+  // CHECK: %[[SIN1_ANNOTATED_1:.*]] = shard.sharding @grid_2 split_axes = {{\[}}[0]] : !shard.sharding
+  // CHECK-NEXT: %[[IN1_ANNOTATED_1:.*]] = shard.shard %[[IN1_ANNOTATED_0]] to %[[SIN1_ANNOTATED_1]] annotate_for_users : tensor<2x3xf32>
+  // CHECK: %[[SIN2_ANNOTATED:.*]] = shard.sharding @grid_2 split_axes = {{\[}}[]] : !shard.sharding
+  // CHECK-NEXT: %[[IN2_ANNOTATED:.*]] = shard.shard %[[IN2]] to %[[SIN2_ANNOTATED]] annotate_for_users : tensor<3x2xf32>
+  // CHECK: %[[SDPS_OUT_ANNOTATED:.*]] = shard.sharding @grid_2 split_axes = {{\[}}[0]] : !shard.sharding
+  // CHECK-NEXT: %[[DPS_OUT_ANNOTATED:.*]] = shard.shard %[[DPS_OUT]] to %[[SDPS_OUT_ANNOTATED]] annotate_for_users : tensor<2x2xf32>
+  %sarg0_sharded = shard.sharding @grid_2 split_axes = [[0]] : !shard.sharding
+  %arg0_sharded = shard.shard %arg0 to %sarg0_sharded : tensor<2x3xf32>
+
+  // CHECK: %[[RES:.*]] = linalg.matmul ins(%[[IN1_ANNOTATED_1]], %[[IN2_ANNOTATED]] : tensor<2x3xf32>, tensor<3x2xf32>)
+  // CHECK-SAME:  outs(%[[DPS_OUT_ANNOTATED]] : tensor<2x2xf32>) -> tensor<2x2xf32>
+  %res = linalg.matmul ins(%arg0_sharded, %arg1 : tensor<2x3xf32>, tensor<3x2xf32>)
+    outs(%out_dps : tensor<2x2xf32>) -> tensor<2x2xf32>
+
+  // CHECK: %[[SRES_ANNOTATED_0:.*]] = shard.sharding @grid_2 split_axes = {{\[}}[0]] : !shard.sharding
+  // CHECK-NEXT: %[[RES_ANNOTATED_0:.*]] = shard.shard %[[RES]] to %[[SRES_ANNOTATED_0]] : tensor<2x2xf32>
+  // CHECK: %[[SRES_ANNOTATED_1:.*]] = shard.sharding @grid_2 split_axes = {{\[}}[]] : !shard.sharding
+  // CHECK-NEXT: %[[RES_ANNOTATED_1:.*]] = shard.shard %[[RES_ANNOTATED_0]] to %[[SRES_ANNOTATED_1]] annotate_for_users : tensor<2x2xf32>
+  %sres_sharded = shard.sharding @grid_2 split_axes = [[]] : !shard.sharding
+  %res_sharded = shard.shard %res to %sres_sharded annotate_for_users : tensor<2x2xf32>
+
+  // CHECK: return %[[RES_ANNOTATED_1]] : tensor<2x2xf32>
+  return %res_sharded : tensor<2x2xf32>
+}
diff --git a/mlir/test/Dialect/Mesh/canonicalization.mlir b/mlir/test/Dialect/Mesh/canonicalization.mlir
deleted file mode 100644
index aff07bbf8a214..0000000000000
--- a/mlir/test/Dialect/Mesh/canonicalization.mlir
+++ /dev/null
@@ -1,248 +0,0 @@
-// RUN: mlir-opt --canonicalize %s | FileCheck %s
-
-mesh.mesh @mesh0(shape = 2x4)
-
-// CHECK-LABEL: func @all_reduce_empty_mesh_axes
-func.func @all_reduce_empty_mesh_axes(
-// CHECK-SAME: %[[ARG:.*]]: tensor<4xf32>
-    %arg0 : tensor<4xf32>) -> tensor<4xf32> {
-// CHECK-NOT: mesh.all_reduce
-  %0 = mesh.all_reduce %arg0 on @mesh0
-    mesh_axes = []
-    : tensor<4xf32> -> tensor<4xf32>
-// CHECK: return %[[ARG]]
-  return %0 : tensor<4xf32>
-}
-
-// CHECK-LABEL: func @all_reduce_empty_mesh_axes_different_return_type
-func.func @all_reduce_empty_mesh_axes_different_return_type(
-    %arg0 : tensor<4xf32>) -> tensor<4xf64> {
-// CHECK: mesh.all_reduce
-  %0 = mesh.all_reduce %arg0 on @mesh0
-// CHECK-NOT: mesh_axes
-    mesh_axes = []
-    : tensor<4xf32> -> tensor<4xf64>
-  return %0 : tensor<4xf64>
-}
-
-// CHECK-LABEL: func @all_reduce_default_reduction
-func.func @all_reduce_default_reduction(
-    %arg0 : tensor<4xf32>) -> tensor<4xf64> {
-  %0 = mesh.all_reduce %arg0 on @mesh0
-    mesh_axes = [0]
-// CHECK-NOT: reduction
-    reduction = sum
-    : tensor<4xf32> -> tensor<4xf64>
-  return %0 : tensor<4xf64>
-}
-
-// CHECK-LABEL: func @all_to_all_empty_mesh_axes
-func.func @all_to_all_empty_mesh_axes(
-// CHECK-SAME: %[[ARG:.*]]: tensor<8xf32>
-    %arg0 : tensor<8xf32>) -> tensor<8xf32> {
-// CHECK-NOT: mesh.all_to_all
-  %0 = mesh.all_to_all %arg0 on @mesh0
-    mesh_axes = []
-    split_axis = 0
-    concat_axis = 0
-    : tensor<8xf32> -> tensor<8xf32>
-// CHECK: return %[[ARG]]
-  return %0 : tensor<8xf32>
-}
-
-// CHECK-LABEL: func @all_gather_empty_mesh_axes
-func.func @all_gather_empty_mesh_axes(
-// CHECK-SAME: %[[ARG:.*]]: tensor<4xf32>
-    %arg0 : tensor<4xf32>) -> tensor<4xf32> {
-// CHECK-NOT: mesh.all_gather
-  %0 = mesh.all_gather %arg0 on @mesh0
-    mesh_axes = []
-    gather_axis = 0
-    : tensor<4xf32> -> tensor<4xf32>
-// CHECK: return %[[ARG]]
-  return %0 : tensor<4xf32>
-}
-
-// CHECK-LABEL: func @all_slice_empty_mesh_axes
-func.func @all_slice_empty_mesh_axes(
-// CHECK-SAME: %[[ARG:.*]]: tensor<4xf32>
-    %arg0 : tensor<4xf32>) -> tensor<4xf32> {
-// CHECK-NOT: mesh.scatter
-  %0 = mesh.all_slice %arg0 on @mesh0
-    mesh_axes = []
-    slice_axis = 0
-    : tensor<4xf32> -> tensor<4xf32>
-// CHECK: return %[[ARG]]
-  return %0 : tensor<4xf32>
-}
-
-// CHECK-LABEL: func @broadcast_empty_mesh_axes
-func.func @broadcast_empty_mesh_axes(
-// CHECK-SAME: %[[ARG:.*]]: tensor<4xf32>
-    %arg0 : tensor<4xf32>) -> tensor<4xf32> {
-// CHECK-NOT: mesh.broadcast
-  %0 = mesh.broadcast %arg0 on @mesh0
-    mesh_axes = []
-    root = []
-    : (tensor<4xf32>) -> tensor<4xf32>
-// CHECK: return %[[ARG]]
-  return %0 : tensor<4xf32>
-}
-
-// CHECK-LABEL: func @gather_empty_mesh_axes
-func.func @gather_empty_mesh_axes(
-// CHECK-SAME: %[[ARG:.*]]: tensor<4xf32>
-    %arg0 : tensor<4xf32>) -> tensor<4xf32> {
-// CHECK-NOT: mesh.gather
-  %0 = mesh.gather %arg0 on @mesh0
-    mesh_axes = []
-    gather_axis = 0
-    root = []
-    : (tensor<4xf32>) -> tensor<4xf32>
-// CHECK: return %[[ARG]]
-  return %0 : tensor<4xf32>
-}
-
-// CHECK-LABEL: func @receive_empty_mesh_axes
-func.func @receive_empty_mesh_axes(
-// CHECK-SAME: %[[ARG:.*]]: tensor<4xf32>
-    %arg0 : tensor<4xf32>) -> tensor<4xf32> {
-// CHECK-NOT: mesh.recv
-  %0 = mesh.recv %arg0 on @mesh0
-    mesh_axes = []
-    : (tensor<4xf32>) -> tensor<4xf32>
-// CHECK: return %[[ARG]]
-  return %0 : tensor<4xf32>
-}
-
-// CHECK-LABEL: func @reduce_empty_mesh_axes
-func.func @reduce_empty_mesh_axes(
-// CHECK-SAME: %[[ARG:.*]]: tensor<4xf32>
-    %arg0 : tensor<4xf32>) -> tensor<4xf32> {
-// CHECK-NOT: mesh.reduce
-  %0 = mesh.reduce %arg0 on @mesh0
-    mesh_axes = []
-    root = []
-    : (tensor<4xf32>) -> tensor<4xf32>
-// CHECK: return %[[ARG]]
-  return %0 : tensor<4xf32>
-}
-
-// CHECK-LABEL: func @reduce_scatter_empty_mesh_axes
-func.func @reduce_scatter_empty_mesh_axes(
-// CHECK-SAME: %[[ARG:.*]]: tensor<4xf32>
-    %arg0 : tensor<4xf32>) -> tensor<4xf32> {
-// CHECK-NOT: mesh.reduce_scatter
-  %0 = mesh.reduce_scatter %arg0 on @mesh0
-    mesh_axes = []
-    scatter_axis = 0
-    : tensor<4xf32> -> tensor<4xf32>
-// CHECK: return %[[ARG]]
-  return %0 : tensor<4xf32>
-}
-
-// CHECK-LABEL: func @reduce_scatter_empty_mesh_axes_different_return_type
-func.func @reduce_scatter_empty_mesh_axes_different_return_type(
-    %arg0 : tensor<4xf32>) -> tensor<4xf64> {
-// CHECK: mesh.reduce_scatter
-  %0 = mesh.reduce_scatter %arg0 on @mesh0
-// CHECK-NOT: mesh_axes
-    mesh_axes = []
-    scatter_axis = 0
-    : tensor<4xf32> -> tensor<4xf64>
-  return %0 : tensor<4xf64>
-}
-
-// CHECK-LABEL: func @reduce_scatter_default_reduction
-func.func @reduce_scatter_default_reduction(
-    %arg0 : tensor<4xf32>) -> tensor<2xf64> {
-  %0 = mesh.reduce_scatter %arg0 on @mesh0
-    mesh_axes = [0]
-// CHECK-NOT: reduction
-    reduction = sum
-    scatter_axis = 0
-    : tensor<4xf32> -> tensor<2xf64>
-  return %0 : tensor<2xf64>
-}
-
-// CHECK-LABEL: func @scatter_empty_mesh_axes
-func.func @scatter_empty_mesh_axes(
-// CHECK-SAME: %[[ARG:.*]]: tensor<4xf32>
-    %arg0 : tensor<4xf32>) -> tensor<4xf32> {
-// CHECK-NOT: mesh.scatter
-  %0 = mesh.scatter %arg0 on @mesh0
-    mesh_axes = []
-    scatter_axis = 0
-    root = []
-    : (tensor<4xf32>) -> tensor<4xf32>
-// CHECK: return %[[ARG]]
-  return %0 : tensor<4xf32>
-}
-
-// CHECK-LABEL: func @send_empty_mesh_axes
-func.func @send_empty_mesh_axes(
-// CHECK-SAME: %[[ARG:.*]]: tensor<4xf32>
-    %arg0 : tensor<4xf32>) -> tensor<4xf32> {
-// CHECK-NOT: mesh.send
-  %0 = mesh.send %arg0 on @mesh0
-    mesh_axes = []
-    destination = []
-    : (tensor<4xf32>) -> tensor<4xf32>
-// CHECK: return %[[ARG]]
-  return %0 : tensor<4xf32>
-}
-
-mesh.mesh @mesh4x4(shape = 4x4)
-// CHECK-LABEL: func @test_halo_sizes
-func.func @test_halo_sizes() -> !mesh.sharding {
-  %c2_i64 = arith.constant 2 : i64
-  // CHECK mesh.sharding @mesh4x4 split_axes = [[0], [1]] halo_sizes = [1, 2, 2, 22] : !mesh.sharding
-  %sharding = mesh.sharding @mesh4x4 split_axes = [[0], [1]] halo_sizes = [1, %c2_i64, %c2_i64, 22] : !mesh.sharding
-  return %sharding : !mesh.sharding
-}
-
-// CHECK-LABEL: func @test_shard_offs
-func.func @test_shard_offs() -> !mesh.sharding {
-  %c2_i64 = arith.constant 2 : i64
-  // CHECK mesh.sharding @mesh4x4 split_axes = [[0], [1]] sharded_dims_offsets = [0, 1, 2, 3, 4, 0, 2, 3, 4, 22] : !mesh.sharding
-  %sharding = mesh.sharding @mesh4x4 split_axes = [[0], [1]] sharded_dims_offsets = [0, 1, %c2_i64, 3, 4, 0, %c2_i64, 3, 4, 22] : !mesh.sharding
-  return %sharding : !mesh.sharding
-}
-
-// CHECK-LABEL: func @test_duplicate_shardops
-func.func @test_duplicate_shardops() -> (tensor<1024x1024xf32>, tensor<1024x1024xf32>) attributes {llvm.emit_c_interface} {
-  // CHECK-NEXT: [[vcst:%.*]] = arith.constant dense<0.000000e+00> : tensor<1024x1024xf32>
-  %cst_1 = arith.constant dense<0.000000e+00> : tensor<1024x1024xf32>
-  // CHECK-NEXT: [[vsharding:%.*]] = mesh.sharding @mesh4x4 split_axes = {{\[\[}}0, 1]] : !mesh.sharding
-  %sharding_1 = mesh.sharding @mesh4x4 split_axes = [[0, 1]] : !mesh.sharding
-  %cst_2 = arith.constant dense<0.000000e+00> : tensor<1024x1024xf32>
-  %sharding_2 = mesh.sharding @mesh4x4 split_axes = [[0, 1]] : !mesh.sharding
-  %sharding_annotated_2 = mesh.shard %cst_2 to %sharding_2 : tensor<1024x1024xf32>
-  %cst_3 = arith.constant dense<0.000000e+00> : tensor<1024x1024xf32>
-  %sharding_3 = mesh.sharding @mesh4x4 split_axes = [[0, 1]] : !mesh.sharding
-  %sharding_annotated_3 = mesh.shard %cst_3 to %sharding_3 : tensor<1024x1024xf32>
-  // CHECK-NEXT: [[vsharding_annotated:%.*]] = mesh.shard [[vcst]] to [[vsharding]] : tensor<1024x1024xf32>
-  %sharding_annotated_1 = mesh.shard %cst_1 to %sharding_1 : tensor<1024x1024xf32>
-  // CHECK-NEXT: return [[vsharding_annotated]], [[vsharding_annotated]] : tensor<1024x1024xf32>, tensor<1024x1024xf32>
-  return %sharding_annotated_1, %sharding_annotated_2 : tensor<1024x1024xf32>, tensor<1024x1024xf32>
-}
-
-// CHECK-LABEL: func @test_duplicate_shardops_diff
-func.func @test_duplicate_shardops_diff() -> (tensor<1024x1024xf32>, tensor<1024x1024xf32>) attributes {llvm.emit_c_interface} {
-  // CHECK-NEXT: [[vcst:%.*]] = arith.constant dense<0.000000e+00> : tensor<1024x1024xf32>
-  %cst_1 = arith.constant dense<0.000000e+00> : tensor<1024x1024xf32>
-  // CHECK-NEXT: [[vsharding:%.*]] = mesh.sharding @mesh4x4 split_axes = {{\[\[}}0]] : !mesh.sharding
-  %sharding_1 = mesh.sharding @mesh4x4 split_axes = [[0]] : !mesh.sharding
-  %cst_2 = arith.constant dense<0.000000e+00> : tensor<1024x1024xf32>
-  // CHECK-NEXT: [[vsharding_0:%.*]] = mesh.sharding @mesh4x4 split_axes = {{\[\[}}0, 1]] : !mesh.sharding
-  %sharding_2 = mesh.sharding @mesh4x4 split_axes = [[0, 1]] : !mesh.sharding
-  // CHECK-NEXT: [[vsharding_annotated:%.*]] = mesh.shard [[vcst]] to [[vsharding_0]] : tensor<1024x1024xf32>
-  %sharding_annotated_2 = mesh.shard %cst_2 to %sharding_2 : tensor<1024x1024xf32>
-  %cst_3 = arith.constant dense<0.000000e+00> : tensor<1024x1024xf32>
-  %sharding_3 = mesh.sharding @mesh4x4 split_axes = [[0]] : !mesh.sharding
-  %sharding_annotated_3 = mesh.shard %cst_3 to %sharding_3 : tensor<1024x1024xf32>
-  // CHECK-NEXT: [[vsharding_annotated_1:%.*]] = mesh.shard [[vsharding_annotated]] to [[vsharding]] : tensor<1024x1024xf32>
-  %sharding_annotated_1 = mesh.shard %cst_1 to %sharding_1 : tensor<1024x1024xf32>
-  // CHECK-NEXT: return [[vsharding_annotated_1]], [[vsharding_annotated]] : tensor<1024x1024xf32>, tensor<1024x1024xf32>
-  return %sharding_annotated_1, %sharding_annotated_2 : tensor<1024x1024xf32>, tensor<1024x1024xf32>
-}
diff --git a/mlir/test/Dialect/Mesh/folding.mlir b/mlir/test/Dialect/Mesh/folding.mlir
deleted file mode 100644
index 369f316d0f797..0000000000000
--- a/mlir/test/Dialect/Mesh/folding.mlir
+++ /dev/null
@@ -1,22 +0,0 @@
-// RUN: mlir-opt -test-mesh-simplifications %s | FileCheck %s
-
-mesh.mesh @mesh0(shape = 4x?x2)
-mesh.mesh @mesh1(shape = 2x3)
-
-// CHECK-LABEL: func.func @mesh_shape_op_folding
-func.func @mesh_shape_op_folding() -> (index, index) {
-  // CHECK: %[[AXIS_2_SIZE:.*]] = arith.constant 2 : index
-  // CHECK: %[[AXIS_1_SIZE:.*]] = mesh.mesh_shape @mesh0 axes = [1] : index
-  %0:2 = mesh.mesh_shape @mesh0 axes = [2, 1] : index, index
-  // CHECK: return %[[AXIS_2_SIZE]], %[[AXIS_1_SIZE]]
-  return %0#0, %0#1 : index, index
-}
-
-// CHECK-LABEL: func.func @mesh_shape_op_folding_all_axes_static_mesh
-func.func @mesh_shape_op_folding_all_axes_static_mesh() -> (index, index) {
-  // CHECK: %[[AXIS_0_SIZE:.*]] = arith.constant 2 : index
-  // CHECK: %[[AXIS_1_SIZE:.*]] = arith.constant 3 : index
-  %0:2 = mesh.mesh_shape @mesh1 : index, index
-  // CHECK: return %[[AXIS_0_SIZE]], %[[AXIS_1_SIZE]]
-  return %0#0, %0#1 : index, index
-}
diff --git a/mlir/test/Dialect/Mesh/forward-sharding-propagation.mlir b/mlir/test/Dialect/Mesh/forward-sharding-propagation.mlir
deleted file mode 100644
index 6ab711b1b653c..0000000000000
--- a/mlir/test/Dialect/Mesh/forward-sharding-propagation.mlir
+++ /dev/null
@@ -1,49 +0,0 @@
-// RUN: mlir-opt --pass-pipeline="builtin.module(func.func(sharding-propagation{traversal=forward}))" %s | FileCheck %s
-
-#map = affine_map<(d0, d1) -> (d0, d1)>
-module attributes {dlti.map = #dlti.map<"MPI:Implementation" = "mpich", "MPI:comm_world_rank" = 0 : i32>} {
-  mesh.mesh @mesh(shape = 1) {sym_visibility = "private"}
-  func.func @test_forward() -> (tensor<6x6xi32>, tensor<6x6xi32>, tensor<i32>) attributes {llvm.emit_c_interface} {
-    %c1_i32 = arith.constant 1 : i32
-    // CHECK: [[v3:%.*]] = tensor.empty() : tensor<6x6xi32>
-    %0 = tensor.empty() : tensor<6x6xi32>
-    // CHECK: [[v1:%.*]] = linalg.fill ins
-    // CHECK: [[vsharding_0:%.*]] = mesh.sharding @mesh split_axes = {{\[\[}}0]] : !mesh.sharding
-    // CHECK: [[vsharding_annotated_1:%.*]] = mesh.shard [[v1]] to [[vsharding_0]] : tensor<6x6xi32>
-    %1 = linalg.fill ins(%c1_i32 : i32) outs(%0 : tensor<6x6xi32>) -> tensor<6x6xi32>
-    %sharding = mesh.sharding @mesh split_axes = [[0]] : !mesh.sharding
-    %sharding_annotated = mesh.shard %1 to %sharding : tensor<6x6xi32>
-    // CHECK: [[v2:%.*]] = tensor.empty() : tensor<6x6xi32>
-    // CHECK: [[vsharding_2:%.*]] = mesh.sharding @mesh split_axes = {{\[\[}}0]] : !mesh.sharding
-    // CHECK: [[vsharding_annotated_3:%.*]] = mesh.shard [[vsharding_annotated_1]] to [[vsharding_2]] annotate_for_users : tensor<6x6xi32>
-    %3 = tensor.empty() : tensor<6x6xi32>
-    // CHECK: [[vsharding_4:%.*]] = mesh.sharding @mesh split_axes = {{\[\[}}0]] : !mesh.sharding
-    // CHECK: [[vsharding_annotated_5:%.*]] = mesh.shard [[v2]] to [[vsharding_4]] annotate_for_users : tensor<6x6xi32>
-    // CHECK: [[v3:%.*]] = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel"]}
-    // CHECK-SAME: ins([[vsharding_annotated_3]], [[vsharding_annotated_3]] : tensor<6x6xi32>, tensor<6x6xi32>) outs([[vsharding_annotated_5]] : tensor<6x6xi32>) {
-    // CHECK: [[vsharding_6:%.*]] = mesh.sharding @mesh split_axes = {{\[\[}}0]] : !mesh.sharding
-    // CHECK: [[vsharding_annotated_7:%.*]] = mesh.shard [[v3]] to [[vsharding_6]] : tensor<6x6xi32>
-    %4 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel"]} ins(%sharding_annotated, %sharding_annotated
-        : tensor<6x6xi32>, tensor<6x6xi32>) outs(%3 : tensor<6x6xi32>) {
-    ^bb0(%in: i32, %in_2: i32, %out: i32):
-      %9 = arith.addi %in, %in_2 : i32
-      linalg.yield %9 : i32
-    } -> tensor<6x6xi32>
-    %c0_i32 = arith.constant 0 : i32
-    %6 = tensor.empty() : tensor<i32>
-    %7 = linalg.fill ins(%c0_i32 : i32) outs(%6 : tensor<i32>) -> tensor<i32>
-    // CHECK: [[vreduced:%.*]] = linalg.reduce ins
-    // CHECK: [[vsharding_12:%.*]] = mesh.sharding @mesh split_axes = [] : !mesh.sharding
-    // CHECK: [[vsharding_annotated_13:%.*]] = mesh.shard [[vreduced]] to [[vsharding_12]] : tensor<i32>
-    %reduced = linalg.reduce ins(%4 : tensor<6x6xi32>) outs(%7 : tensor<i32>) dimensions = [0, 1] 
-      (%in: i32, %init: i32) {
-        %9 = arith.addi %in, %init : i32
-        linalg.yield %9 : i32
-      }
-    // CHECK: [[vsharding_14:%.*]] = mesh.sharding @mesh split_axes = {{\[\[}}]] : !mesh.sharding
-    %sharding_0 = mesh.sharding @mesh split_axes = [[]] : !mesh.sharding
-    // CHECK: [[vsharding_annotated_15:%.*]] = mesh.shard [[vsharding_annotated_13]] to [[vsharding_14]] annotate_for_users : tensor<i32>
-    %sharding_annotated_1 = mesh.shard %reduced to %sharding_0 annotate_for_users : tensor<i32>
-    return %sharding_annotated, %4, %sharding_annotated_1 : tensor<6x6xi32>, tensor<6x6xi32>, tensor<i32>
-  }
-}
diff --git a/mlir/test/Dialect/Mesh/inlining.mlir b/mlir/test/Dialect/Mesh/inlining.mlir
deleted file mode 100644
index c41a709e1a4eb..0000000000000
--- a/mlir/test/Dialect/Mesh/inlining.mlir
+++ /dev/null
@@ -1,15 +0,0 @@
-// RUN: mlir-opt -inline %s | FileCheck %s
-
-mesh.mesh @mesh0(shape = 4x?x2)
-
-func.func private @mesh_to_inline() -> (index, index) {
-  %0:2 = mesh.mesh_shape @mesh0 axes = [2, 1] : index, index
-  return %0#0, %0#1 : index, index
-}
-// CHECK-LABEL: func.func @main
-func.func @main() -> (index, index) {
-  // CHECK-NEXT: %[[AXIS_SIZE:.*]]:2 = mesh.mesh_shape @mesh0 axes = [2, 1] : index
-  %0:2 = func.call @mesh_to_inline() : () -> (index, index)
-  // CHECK-NEXT: return %[[AXIS_SIZE]]#0, %[[AXIS_SIZE]]#1
-  return %0#0, %0#1 : index, index
-}
diff --git a/mlir/test/Dialect/Mesh/process-multi-index-op-lowering.mlir b/mlir/test/Dialect/Mesh/process-multi-index-op-lowering.mlir
deleted file mode 100644
index e23cfd79a4274..0000000000000
--- a/mlir/test/Dialect/Mesh/process-multi-index-op-lowering.mlir
+++ /dev/null
@@ -1,23 +0,0 @@
-// RUN: mlir-opt -test-mesh-process-multi-index-op-lowering %s | FileCheck %s
-
-mesh.mesh @mesh2d(shape = ?x?)
-
-// CHECK-LABEL: func.func @multi_index_2d_mesh
-func.func @multi_index_2d_mesh() -> (index, index) {
-  // CHECK: %[[LINEAR_IDX:.*]] = mesh.process_linear_index on @mesh2d : index
-  // CHECK: %[[MESH_SHAPE:.*]]:2 = mesh.mesh_shape @mesh2d : index, index
-  // CHECK: %[[MULTI_IDX:.*]]:2 = affine.delinearize_index %[[LINEAR_IDX]] into (%[[MESH_SHAPE]]#0, %[[MESH_SHAPE]]#1) : index, index
-  %0:2 = mesh.process_multi_index on @mesh2d : index, index
-  // CHECK: return %[[MULTI_IDX]]#0, %[[MULTI_IDX]]#1 : index, index
-  return %0#0, %0#1 : index, index
-}
-
-// CHECK-LABEL: func.func @multi_index_2d_mesh_single_inner_axis
-func.func @multi_index_2d_mesh_single_inner_axis() -> index {
-  // CHECK: %[[LINEAR_IDX:.*]] = mesh.process_linear_index on @mesh2d : index
-  // CHECK: %[[MESH_SHAPE:.*]]:2 = mesh.mesh_shape @mesh2d : index, index
-  // CHECK: %[[MULTI_IDX:.*]]:2 = affine.delinearize_index %[[LINEAR_IDX]] into (%[[MESH_SHAPE]]#0, %[[MESH_SHAPE]]#1) : index, index
-  %0 = mesh.process_multi_index on @mesh2d axes = [0] : index
-  // CHECK: return %[[MULTI_IDX]]#0 : index
-  return %0 : index
-}
diff --git a/mlir/test/Dialect/Mesh/resharding-spmdization.mlir b/mlir/test/Dialect/Mesh/resharding-spmdization.mlir
deleted file mode 100644
index 5e62c929aa4ff..0000000000000
--- a/mlir/test/Dialect/Mesh/resharding-spmdization.mlir
+++ /dev/null
@@ -1,168 +0,0 @@
-// RUN: mlir-opt -test-mesh-resharding-spmdization %s | FileCheck %s
-
-mesh.mesh @mesh_1d(shape = 2)
-mesh.mesh @mesh_1d_dynamic(shape = ?)
-
-// CHECK-LABEL: func @same_source_and_target_sharding
-func.func @same_source_and_target_sharding(
-  // CHECK-SAME: %[[ARG:.*]]: tensor<2xf32>
-  %arg0: tensor<2xf32>
-) -> tensor<2xf32> {
-  %s0 = mesh.sharding @mesh_1d split_axes = [[]] : !mesh.sharding
-  %0 = mesh.shard %arg0 to %s0 : tensor<2xf32>
-  %s1 = mesh.sharding @mesh_1d split_axes = [[]] : !mesh.sharding
-  %1 = mesh.shard %0 to %s1 annotate_for_users : tensor<2xf32>
-  // CHECK: return %[[ARG]]
-  return %1 : tensor<2xf32>
-}
-
-// CHECK-LABEL: func @identical_source_and_target_sharding
-func.func @identical_source_and_target_sharding(
-  // CHECK-SAME: %[[ARG:.*]]: tensor<2xf32>
-  %arg0: tensor<2xf32>
-) -> tensor<2xf32> {
-  %s0 = mesh.sharding @mesh_1d split_axes = [[]] : !mesh.sharding
-  %0 = mesh.shard %arg0 to %s0 : tensor<2xf32>
-  %1 = mesh.shard %0 to %s0 annotate_for_users : tensor<2xf32>
-  // CHECK: return %[[ARG]]
-  return %1 : tensor<2xf32>
-}
-
-// CHECK-LABEL: func @split_replicated_tensor_axis
-func.func @split_replicated_tensor_axis(
-  // CHECK-SAME: %[[ARG:.*]]: tensor<3x14xf32>
-  %arg0: tensor<3x14xf32>
-) -> tensor<3x14xf32> {
-  // CHECK: %[[ALL_SLICE:.*]] = mesh.all_slice %[[ARG]] on @mesh_1d mesh_axes = [0] slice_axis = 1
-  // CHECK-SAME: tensor<3x14xf32> -> tensor<3x7xf32>
-  // CHECK: %[[RESULT:.*]] = builtin.unrealized_conversion_cast %[[ALL_SLICE]] : tensor<3x7xf32> to tensor<3x14xf32>
-  %s0 = mesh.sharding @mesh_1d split_axes = [[]] : !mesh.sharding
-  %0 = mesh.shard %arg0 to %s0 : tensor<3x14xf32>
-  %s1 = mesh.sharding @mesh_1d split_axes = [[], [0]] : !mesh.sharding
-  %1 = mesh.shard %0 to %s1 annotate_for_users : tensor<3x14xf32>
-  // CHECK: return %[[RESULT]] : tensor<3x14xf32>
-  return %1 : tensor<3x14xf32>
-}
-
-// CHECK-LABEL: func @split_replicated_tensor_axis_dynamic
-func.func @split_replicated_tensor_axis_dynamic(
-  // CHECK-SAME: %[[ARG:.*]]: tensor<?x3x?xf32>
-  %arg0: tensor<?x3x?xf32>
-) -> tensor<?x3x?xf32> {
-  // CHECK: %[[RESULT:.*]] = mesh.all_slice %[[ARG]] on @mesh_1d_dynamic mesh_axes = [0] slice_axis = 0
-  // CHECK-SAME: tensor<?x3x?xf32> -> tensor<?x3x?xf32>
-  %s0 = mesh.sharding @mesh_1d_dynamic split_axes = [[], [], []] : !mesh.sharding
-  %0 = mesh.shard %arg0 to %s0 : tensor<?x3x?xf32>
-  %s1 = mesh.sharding @mesh_1d_dynamic split_axes = [[0]] : !mesh.sharding
-  %1 = mesh.shard %0 to %s1 annotate_for_users : tensor<?x3x?xf32>
-  // CHECK: return %[[RESULT]] : tensor<?x3x?xf32>
-  return %1 : tensor<?x3x?xf32>
-}
-
-// CHECK-LABEL: func @move_split_axis
-func.func @move_split_axis(
-  // CHECK-SAME: %[[ARG:.*]]: tensor<10x14xf32>
-  %arg0: tensor<10x14xf32>
-) -> tensor<10x14xf32> {
-  // CHECK: %[[SOURCE_SHARD:.*]] = builtin.unrealized_conversion_cast %[[ARG]] : tensor<10x14xf32> to tensor<5x14xf32>
-  // CHECK: %[[TARGET_SHARD:.*]] = mesh.all_to_all %[[SOURCE_SHARD]] on @mesh_1d mesh_axes = [0] split_axis = 1 concat_axis = 0 : tensor<5x14xf32> -> tensor<10x7xf32>
-  // CHECK: %[[RES:.*]] = builtin.unrealized_conversion_cast %[[TARGET_SHARD]] : tensor<10x7xf32> to tensor<10x14xf32>
-  %s0 = mesh.sharding @mesh_1d split_axes = [[0]] : !mesh.sharding
-  %0 = mesh.shard %arg0 to %s0 : tensor<10x14xf32>
-  %s1 = mesh.sharding @mesh_1d split_axes = [[], [0]] : !mesh.sharding
-  %1 = mesh.shard %0 to %s1 annotate_for_users : tensor<10x14xf32>
-  // CHECK: return %[[RES]] : tensor<10x14xf32>
-  return %1 : tensor<10x14xf32>
-}
-
-// CHECK-LABEL: func @move_split_axis_dynamic_mesh
-func.func @move_split_axis_dynamic_mesh(
-  // CHECK-SAME: %[[ARG:.*]]: tensor<10x14xf32>
-  %arg0: tensor<10x14xf32>
-) -> tensor<10x14xf32> {
-  // CHECK: %[[SOURCE_SHARD:.*]] = builtin.unrealized_conversion_cast %[[ARG]] : tensor<10x14xf32> to tensor<?x14xf32>
-  // CHECK: %[[ALL_TO_ALL:.*]] = mesh.all_to_all %[[SOURCE_SHARD]] on @mesh_1d_dynamic mesh_axes = [0] split_axis = 1 concat_axis = 0 : tensor<?x14xf32> -> tensor<?x?xf32>
-  // CHECK: %[[TARGET_SHARD:.*]] = tensor.cast %[[ALL_TO_ALL]] : tensor<?x?xf32> to tensor<10x?xf32>
-  // CHECK: %[[RES:.*]] = builtin.unrealized_conversion_cast %[[TARGET_SHARD]] : tensor<10x?xf32> to tensor<10x14xf32>
-  %s0 = mesh.sharding @mesh_1d_dynamic split_axes = [[0]] : !mesh.sharding
-  %0 = mesh.shard %arg0 to %s0 : tensor<10x14xf32>
-  %s1 = mesh.sharding @mesh_1d_dynamic split_axes = [[], [0]] : !mesh.sharding
-  %1 = mesh.shard %0 to %s1 annotate_for_users : tensor<10x14xf32>
-  // CHECK: return %[[RES]] : tensor<10x14xf32>
-  return %1 : tensor<10x14xf32>
-}
-
-// CHECK-LABEL: func @move_split_dynamic_axis
-func.func @move_split_dynamic_axis(
-  // CHECK-SAME: %[[ARG:.*]]: tensor<?x14xf32>
-  %arg0: tensor<?x14xf32>
-) -> tensor<?x14xf32> {
-  // CHECK: %[[TARGET_SHARD:.*]] = mesh.all_to_all %[[ARG]] on @mesh_1d mesh_axes = [0] split_axis = 1 concat_axis = 0 : tensor<?x14xf32> -> tensor<?x7xf32>
-  // CHECK: %[[RES:.*]] = builtin.unrealized_conversion_cast %[[TARGET_SHARD]] : tensor<?x7xf32> to tensor<?x14xf32>
-  %s0 = mesh.sharding @mesh_1d split_axes = [[0]] : !mesh.sharding
-  %0 = mesh.shard %arg0 to %s0 : tensor<?x14xf32>
-  %s1 = mesh.sharding @mesh_1d split_axes = [[], [0]] : !mesh.sharding
-  %1 = mesh.shard %0 to %s1 annotate_for_users : tensor<?x14xf32>
-  // CHECK: return %[[RES]] : tensor<?x14xf32>
-  return %1 : tensor<?x14xf32>
-}
-
-// CHECK-LABEL: func @unshard_static_axis
-func.func @unshard_static_axis(
-  // CHECK-SAME: %[[ARG:.*]]: tensor<10x14xf32>
-  %arg0: tensor<10x14xf32>
-) -> tensor<10x14xf32> {
-  // CHECK: %[[SOURCE_SHARD:.*]] = builtin.unrealized_conversion_cast %[[ARG]] : tensor<10x14xf32> to tensor<5x14xf32>
-  // CHECK: %[[ALL_GATHER:.*]] = mesh.all_gather %[[SOURCE_SHARD]] on @mesh_1d mesh_axes = [0] gather_axis = 0 : tensor<5x14xf32> -> tensor<10x14xf32>
-  %s0 = mesh.sharding @mesh_1d split_axes = [[0]] : !mesh.sharding
-  %0 = mesh.shard %arg0 to %s0 : tensor<10x14xf32>
-  %s1 = mesh.sharding @mesh_1d split_axes = [[]] : !mesh.sharding
-  %1 = mesh.shard %0 to %s1 annotate_for_users : tensor<10x14xf32>
-  // CHECK: return %[[ALL_GATHER]] : tensor<10x14xf32>
-  return %1 : tensor<10x14xf32>
-}
-
-// CHECK-LABEL: func @unshard_static_last_axis
-func.func @unshard_static_last_axis(
-  // CHECK-SAME: %[[ARG:.*]]: tensor<10x14xf32>
-  %arg0: tensor<10x14xf32>
-) -> tensor<10x14xf32> {
-  // CHECK: %[[SOURCE_SHARD:.*]] = builtin.unrealized_conversion_cast %[[ARG]] : tensor<10x14xf32> to tensor<10x7xf32>
-  // CHECK: %[[ALL_GATHER:.*]] = mesh.all_gather %[[SOURCE_SHARD]] on @mesh_1d mesh_axes = [0] gather_axis = 1 : tensor<10x7xf32> -> tensor<10x14xf32>
-  %s0 = mesh.sharding @mesh_1d split_axes = [[], [0]] : !mesh.sharding
-  %0 = mesh.shard %arg0 to %s0 : tensor<10x14xf32>
-  %s1 = mesh.sharding @mesh_1d split_axes = [[], []] : !mesh.sharding
-  %1 = mesh.shard %0 to %s1 annotate_for_users : tensor<10x14xf32>
-  // CHECK: return %[[ALL_GATHER]] : tensor<10x14xf32>
-  return %1 : tensor<10x14xf32>
-}
-
-// CHECK-LABEL: func @unshard_dynamic_axis
-func.func @unshard_dynamic_axis(
-  // CHECK-SAME: %[[ARG:.*]]: tensor<?x14xf32>
-  %arg0: tensor<?x14xf32>
-) -> tensor<?x14xf32> {
-  // CHECK: %[[ALL_GATHER:.*]] = mesh.all_gather %[[ARG]] on @mesh_1d mesh_axes = [0] gather_axis = 0 : tensor<?x14xf32> -> tensor<?x14xf32>
-  %s0 = mesh.sharding @mesh_1d split_axes = [[0]] : !mesh.sharding
-  %0 = mesh.shard %arg0 to %s0 : tensor<?x14xf32>
-  %s1 = mesh.sharding @mesh_1d split_axes = [[]] : !mesh.sharding
-  %1 = mesh.shard %0 to %s1 annotate_for_users : tensor<?x14xf32>
-  // CHECK: return %[[ALL_GATHER]] : tensor<?x14xf32>
-  return %1 : tensor<?x14xf32>
-}
-
-// CHECK-LABEL: func @unshard_static_axis_on_dynamic_mesh_axis
-func.func @unshard_static_axis_on_dynamic_mesh_axis(
-// CHECK-SAME: %[[ARG:.*]]: tensor<10x14xf32>  
-  %arg0: tensor<10x14xf32>
-) -> tensor<10x14xf32> {
-  // CHECK: %[[SOURCE_SHARD:.*]] = builtin.unrealized_conversion_cast %[[ARG]] : tensor<10x14xf32> to tensor<?x14xf32>
-  // CHECK: %[[ALL_GATHER:.*]] = mesh.all_gather %[[SOURCE_SHARD]] on @mesh_1d_dynamic mesh_axes = [0] gather_axis = 0 : tensor<?x14xf32> -> tensor<?x14xf32>
-  // CHECK: %[[RES:.*]] = tensor.cast %[[ALL_GATHER]] : tensor<?x14xf32> to tensor<10x14xf32>
-  %s0 = mesh.sharding @mesh_1d_dynamic split_axes = [[0]] : !mesh.sharding
-  %0 = mesh.shard %arg0 to %s0 : tensor<10x14xf32>
-  %s1 = mesh.sharding @mesh_1d_dynamic split_axes = [[]] : !mesh.sharding
-  %1 = mesh.shard %0 to %s1 annotate_for_users : tensor<10x14xf32>
-  // CHECK: return %[[RES]] : tensor<10x14xf32>
-  return %1 : tensor<10x14xf32>
-}
diff --git a/mlir/test/Dialect/Mesh/sharding-propagation.mlir b/mlir/test/Dialect/Mesh/sharding-propagation.mlir
deleted file mode 100644
index 0881d994d60e7..0000000000000
--- a/mlir/test/Dialect/Mesh/sharding-propagation.mlir
+++ /dev/null
@@ -1,301 +0,0 @@
-// RUN: mlir-opt --pass-pipeline="builtin.module(func.func(sharding-propagation,cse))" %s | FileCheck %s
-
-mesh.mesh @mesh_2(shape = 2)
-mesh.mesh @mesh_1d(shape = ?)
-mesh.mesh @mesh_2d(shape = 2x4)
-mesh.mesh @mesh_3d(shape = ?x?x?)
-
-// CHECK-LABEL: func.func @element_wise_empty_sharding_info
-func.func @element_wise_empty_sharding_info(%arg0: tensor<8x16xf32>) -> tensor<8x16xf32> {
-  // CHECK-NEXT: tosa.sigmoid
-  %0 = tosa.sigmoid %arg0 : (tensor<8x16xf32>) -> tensor<8x16xf32>
-  // CHECK-NEXT: return
-  return %0 : tensor<8x16xf32>
-}
-
-// CHECK-LABEL: func.func @element_wise_on_def
-// CHECK-SAME:    %[[ARG:.*]]: tensor<8x16xf32>
-func.func @element_wise_on_def(%arg0: tensor<8x16xf32>) -> tensor<8x16xf32> {
-  // CHECK-NEXT:  %[[S0:.*]] = mesh.sharding @mesh_2d split_axes = {{\[\[}}0], [1]] : !mesh.sharding
-  // CHECK-NEXT:  %[[V0:.*]] = mesh.shard %[[ARG]] to %[[S0]] annotate_for_users  : tensor<8x16xf32>
-  // CHECK-NEXT:  %[[V1:.*]] = tosa.sigmoid %[[V0]]
-  %0 = tosa.sigmoid %arg0 : (tensor<8x16xf32>) -> tensor<8x16xf32>
-  // CHECK-NEXT:  %[[V2:.*]] = mesh.shard %[[V1]] to %[[S0]]  : tensor<8x16xf32>
-  %s1 = mesh.sharding @mesh_2d split_axes = [[0], [1]] : !mesh.sharding
-  %1 = mesh.shard %0 to %s1  : tensor<8x16xf32>
-  // CHECK-NEXT:  return %[[V2]]
-  return %1 : tensor<8x16xf32>
-}
-
-// CHECK-LABEL: func.func @element_wise_on_use
-// CHECK-SAME:    %[[ARG:.*]]: tensor<8x16xf32>
-func.func @element_wise_on_use(%arg0: tensor<8x16xf32>) -> tensor<8x16xf32> {
-  // CHECK-NEXT:  %[[S0:.*]] = mesh.sharding @mesh_2d split_axes = {{\[\[}}0], [1]] : !mesh.sharding
-  // CHECK-NEXT:  %[[V0:.*]] = mesh.shard %[[ARG]] to %[[S0]] annotate_for_users  : tensor<8x16xf32>
-  %s0 = mesh.sharding @mesh_2d split_axes = [[0], [1]] : !mesh.sharding
-  %0 = mesh.shard %arg0 to %s0 annotate_for_users  : tensor<8x16xf32>
-  // CHECK-NEXT:  %[[V1:.*]] = tosa.sigmoid %[[V0]]
-  %1 = tosa.sigmoid %0 : (tensor<8x16xf32>) -> tensor<8x16xf32>
-  // CHECK-NEXT:  %[[V2:.*]] = mesh.shard %[[V1]] to %[[S0]]  : tensor<8x16xf32>
-  // CHECK-NEXT:  return %[[V2]]
-  return %1 : tensor<8x16xf32>
-}
-
-// CHECK-LABEL: func.func @element_wise_on_graph_output
-// CHECK-SAME:    %[[ARG:.*]]: tensor<8x16xf32>
-func.func @element_wise_on_graph_output(%arg0: tensor<8x16xf32>) -> tensor<8x16xf32> {
-  // CHECK-NEXT:  %[[S0:.*]] = mesh.sharding @mesh_2d split_axes = {{\[\[}}0], [1]] : !mesh.sharding
-  // CHECK-NEXT:  %[[V0:.*]] = mesh.shard %[[ARG]] to %[[S0]] annotate_for_users  : tensor<8x16xf32>
-  // CHECK-NEXT:  %[[V1:.*]] = tosa.sigmoid %[[V0]]
-  %0 = tosa.sigmoid %arg0 : (tensor<8x16xf32>) -> tensor<8x16xf32>
-  // CHECK-NEXT:  %[[V2:.*]] = mesh.shard %[[V1]] to %[[S0]]  : tensor<8x16xf32>
-  // CHECK-NEXT:  %[[V3:.*]] = mesh.shard %[[V2]] to %[[S0]] annotate_for_users  : tensor<8x16xf32>
-  %s1 = mesh.sharding @mesh_2d split_axes = [[0], [1]] : !mesh.sharding
-  %1 = mesh.shard %0 to %s1 annotate_for_users  : tensor<8x16xf32>
-  // CHECK-NEXT:  return %[[V3]]
-  return %1 : tensor<8x16xf32>
-}
-
-// CHECK-LABEL: func.func @element_wise_on_graph_input
-// CHECK-SAME:    %[[ARG:.*]]: tensor<8x16xf32>
-func.func @element_wise_on_graph_input(%arg0: tensor<8x16xf32>) -> tensor<8x16xf32> {
-  // CHECK-NEXT:  %[[S0:.*]] = mesh.sharding @mesh_2d split_axes = {{\[\[}}0], [1]] : !mesh.sharding
-  // CHECK-NEXT:  %[[V0:.*]] = mesh.shard %[[ARG]] to %[[S0]]  : tensor<8x16xf32>
-  // CHECK-NEXT:  %[[V1:.*]] = mesh.shard %[[V0]] to %[[S0]] annotate_for_users  : tensor<8x16xf32>
-  %s0 = mesh.sharding @mesh_2d split_axes = [[0], [1]] : !mesh.sharding
-  %0 = mesh.shard %arg0 to %s0  : tensor<8x16xf32>
-  // CHECK-NEXT:  %[[V2:.*]] = tosa.sigmoid %[[V1]]
-  %1 = tosa.sigmoid %0 : (tensor<8x16xf32>) -> tensor<8x16xf32>
-  // CHECK-NEXT:  %[[V3:.*]] = mesh.shard %[[V2]] to %[[S0]]  : tensor<8x16xf32>
-  // CHECK-NEXT:  return %[[V3]]
-  return %1 : tensor<8x16xf32>
-}
-
-// CHECK-LABEL: func.func @arrow_structure
-// CHECK-SAME:    %[[ARG:.*]]: tensor<8x16xf32>
-func.func @arrow_structure(%arg0: tensor<8x16xf32>, %arg1: tensor<1xf32>, %arg2: tensor<1xf32>) -> (tensor<8x16xf32>, tensor<8x16xf32>) {
-  // CHECK-NEXT:  %[[S1:.*]] = mesh.sharding @mesh_2d split_axes = {{\[\[}}0], [1]] : !mesh.sharding
-  // CHECK-NEXT:  %[[V1:.*]] = mesh.shard %[[ARG]] to %[[S1]] annotate_for_users  : tensor<8x16xf32>
-  // CHECK-NEXT:  %[[V2:.*]] = tosa.tanh %[[V1]]
-  // CHECK-NEXT:  %[[V3:.*]] = mesh.shard %[[V2]] to %[[S1]]  : tensor<8x16xf32>
-  %0 = tosa.tanh %arg0 : (tensor<8x16xf32>) -> tensor<8x16xf32>
-  // CHECK-NEXT:  %[[V4:.*]] = mesh.shard %[[V3]] to %[[S1]] annotate_for_users  : tensor<8x16xf32>
-  // CHECK-NEXT:  %[[V5:.*]] = tosa.abs %[[V4]]
-  // CHECK-NEXT:  %[[V6:.*]] = mesh.shard %[[V5]] to %[[S1]]  : tensor<8x16xf32>
-  %1 = tosa.abs %0: (tensor<8x16xf32>) -> tensor<8x16xf32>
-  // CHECK-NEXT:  %[[S3:.*]] = mesh.sharding @mesh_2d split_axes = {{\[\[}}]] : !mesh.sharding
-  // CHECK-NEXT:  %[[ZP1:.*]] = mesh.shard %arg1 to %[[S3]] annotate_for_users : tensor<1xf32>
-  // CHECK-NEXT:  %[[ZP2:.*]] = mesh.shard %arg2 to %[[S3]] annotate_for_users : tensor<1xf32>
-  // CHECK-NEXT:  %[[V7:.*]] = tosa.negate %[[V4]], %[[ZP1]], %[[ZP2]]
-  // CHECK-NEXT:  %[[V8:.*]] = mesh.shard %[[V7]] to %[[S1]]  : tensor<8x16xf32>
-  %2 = tosa.negate %0, %arg1, %arg2 : (tensor<8x16xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<8x16xf32>
-  %s3 = mesh.sharding @mesh_2d split_axes = [[0], [1]] : !mesh.sharding
-  %3 = mesh.shard %2 to %s3  : tensor<8x16xf32>
-  // CHECK-NEXT: return %[[V6]], %[[V8]]
-  return %1, %3 : tensor<8x16xf32>, tensor<8x16xf32>
-}
-
-// CHECK-LABEL: func.func @matmul_on_def_shard_batch_and_m
-// CHECK-SAME:     %[[ARG0:.*]]: tensor<2x16x8xf32>, %[[ARG1:.*]]: tensor<2x8x32xf32>, %[[ARG2:.*]]: tensor<1xf32>
-func.func @matmul_on_def_shard_batch_and_m(%arg0: tensor<2x16x8xf32>, %arg1: tensor<2x8x32xf32>, %arg2: tensor<1xf32>) -> tensor<2x16x32xf32> {
-  // CHECK-NEXT:  %[[S0:.*]] = mesh.sharding @mesh_2d split_axes = {{\[\[}}0], [1]] : !mesh.sharding
-  // CHECK-NEXT:  %[[V0:.*]] = mesh.shard %[[ARG0]] to %[[S0]] annotate_for_users  : tensor<2x16x8xf32>
-  // CHECK-NEXT:  %[[S1:.*]] = mesh.sharding @mesh_2d split_axes = {{\[\[}}0]] : !mesh.sharding
-  // CHECK-NEXT:  %[[V1:.*]] = mesh.shard %[[ARG1]] to %[[S1]] annotate_for_users  : tensor<2x8x32xf32>
-  // CHECK-NEXT:  %[[S2:.*]] = mesh.sharding @mesh_2d split_axes = {{\[\[}}]] : !mesh.sharding
-  // CHECK-NEXT:  %[[ZP:.*]] = mesh.shard %[[ARG2]] to %[[S2]] annotate_for_users  : tensor<1xf32>
-  // CHECK-NEXT:  %[[V2:.*]] = tosa.matmul %[[V0]], %[[V1]], %[[ZP]], %[[ZP]]
-  %0 = tosa.matmul %arg0, %arg1, %arg2, %arg2 : (tensor<2x16x8xf32>, tensor<2x8x32xf32>, tensor<1xf32>, tensor<1xf32>)  -> tensor<2x16x32xf32>
-  // CHECK-NEXT:  %[[V3:.*]] = mesh.shard %[[V2]] to %[[S0]]  : tensor<2x16x32xf32>
-  %s1 = mesh.sharding @mesh_2d split_axes = [[0], [1]] : !mesh.sharding
-  %1 = mesh.shard %0 to %s1  : tensor<2x16x32xf32>
-  // CHECK-NEXT:  return %[[V3]]
-  return %1 : tensor<2x16x32xf32>
-}
-
-// CHECK-LABEL: func.func @matmul_on_def_shard_m_and_n
-// CHECK-SAME:     [[varg0:%.*]]: tensor<2x16x8xf32>, [[varg1:%.*]]: tensor<2x8x32xf32>, [[varg2:%.*]]: tensor<1xf32>
-func.func @matmul_on_def_shard_m_and_n(%arg0: tensor<2x16x8xf32>, %arg1: tensor<2x8x32xf32>, %arg2: tensor<1xf32>) -> tensor<2x16x32xf32> {
-  // CHECK: [[vsharding:%.*]] = mesh.sharding @mesh_2d split_axes = {{\[\[}}], [0]] : !mesh.sharding
-  // CHECK: [[vsharded:%.*]] = mesh.shard [[varg0]] to [[vsharding]] annotate_for_users : tensor<2x16x8xf32>
-  // CHECK: [[vsharding_0:%.*]] = mesh.sharding @mesh_2d split_axes = {{\[\[}}], [], [1]] : !mesh.sharding
-  // CHECK: [[vsharded_1:%.*]] = mesh.shard [[varg1]] to [[vsharding_0]] annotate_for_users : tensor<2x8x32xf32>
-  // CHECK: [[vsharding_2:%.*]] = mesh.sharding @mesh_2d split_axes = {{\[\[}}]] : !mesh.sharding
-  // CHECK: [[vsharded_3:%.*]] = mesh.shard [[varg2]] to [[vsharding_2]] annotate_for_users : tensor<1xf32>
-  // CHECK: [[v0:%.*]] = tosa.matmul
-  %0 = tosa.matmul %arg0, %arg1, %arg2, %arg2 : (tensor<2x16x8xf32>, tensor<2x8x32xf32>, tensor<1xf32>, tensor<1xf32>)  -> tensor<2x16x32xf32>
-  // CHECK: [[vsharding_4:%.*]] = mesh.sharding @mesh_2d split_axes = {{\[\[}}], [0], [1]] : !mesh.sharding
-  // CHECK: [[vsharded_5:%.*]] = mesh.shard [[v0]] to [[vsharding_4]] : tensor<2x16x32xf32>
-  %s1 = mesh.sharding @mesh_2d split_axes = [[], [0], [1]] : !mesh.sharding
-  %1 = mesh.shard %0 to %s1  : tensor<2x16x32xf32>
-  // CHECK-NEXT:  return [[vsharded_5]]
-  return %1 : tensor<2x16x32xf32>
-}
-
-// CHECK-LABEL: func.func @matmul_on_use_shard_m_and_k
-// CHECK-SAME:     [[varg0:%.*]]: tensor<2x16x8xf32>, [[varg1:%.*]]: tensor<2x8x32xf32>, [[varg2:%.*]]: tensor<1xf32>
-func.func @matmul_on_use_shard_m_and_k(%arg0: tensor<2x16x8xf32>, %arg1: tensor<2x8x32xf32>, %arg2: tensor<1xf32>) -> tensor<2x16x32xf32> {
-  // CHECK: [[vsharding:%.*]] = mesh.sharding @mesh_2d split_axes = {{\[\[}}], [0], [1]] : !mesh.sharding
-  %s0 = mesh.sharding @mesh_2d split_axes = [[], [0], [1]] : !mesh.sharding
-  // CHECK: [[vsharded:%.*]] = mesh.shard [[varg0]] to [[vsharding]] : tensor<2x16x8xf32>
-  %arg0_s = mesh.shard %arg0 to %s0 : tensor<2x16x8xf32>
-  // CHECK: [[vsharded_0:%.*]] = mesh.shard [[vsharded]] to [[vsharding]] annotate_for_users : tensor<2x16x8xf32>
-  // CHECK: [[vsharding_1:%.*]] = mesh.sharding @mesh_2d split_axes = {{\[\[}}], [1]] : !mesh.sharding
-  // CHECK: [[vsharded_2:%.*]] = mesh.shard [[varg1]] to [[vsharding_1]] annotate_for_users : tensor<2x8x32xf32>
-  // CHECK: [[vsharding_3:%.*]] = mesh.sharding @mesh_2d split_axes = {{\[\[}}]] : !mesh.sharding
-  // CHECK: [[vsharded_4:%.*]] = mesh.shard [[varg2]] to [[vsharding_3]] annotate_for_users : tensor<1xf32>
-  // CHECK: [[v0:%.*]] = tosa.matmul
-  // CHECK: [[vsharding_5:%.*]] = mesh.sharding @mesh_2d split_axes = {{\[\[}}], [0]] : !mesh.sharding
-  // CHECK: [[vsharded_6:%.*]] = mesh.shard [[v0]] to [[vsharding_5]] : tensor<2x16x32xf32>
-  %0 = tosa.matmul %arg0_s, %arg1, %arg2, %arg2 : (tensor<2x16x8xf32>, tensor<2x8x32xf32>, tensor<1xf32>, tensor<1xf32>)  -> tensor<2x16x32xf32>
-  // CHECK: return [[vsharded_6]]
-  return %0 : tensor<2x16x32xf32>
-}
-
-// CHECK-LABEL: func.func @matmul_on_use_shard_m_and_duplicted_k
-// CHECK-SAME:     %[[ARG0:.*]]: tensor<2x16x8xf32>, %[[ARG1:.*]]: tensor<2x8x32xf32>, %[[ARG2:.*]]: tensor<1xf32>
-func.func @matmul_on_use_shard_m_and_duplicted_k(%arg0: tensor<2x16x8xf32>, %arg1: tensor<2x8x32xf32>, %arg2: tensor<1xf32>) -> tensor<2x16x32xf32> {
-  // CHECK-NEXT:  %[[S0:.*]] = mesh.sharding @mesh_2d split_axes = {{\[\[}}], [1], [0]] : !mesh.sharding
-  // CHECK-NEXT:  %[[V0:.*]] = mesh.shard %[[ARG0]] to %[[S0]] annotate_for_users  : tensor<2x16x8xf32>
-  %s0 = mesh.sharding @mesh_2d split_axes = [[], [1], [0]] : !mesh.sharding
-  %0 = mesh.shard %arg0 to %s0 annotate_for_users  : tensor<2x16x8xf32>
-  // CHECK-NEXT:  %[[S1:.*]] = mesh.sharding @mesh_2d split_axes = {{\[\[}}], [0]] : !mesh.sharding
-  // CHECK-NEXT:  %[[V1:.*]] = mesh.shard %[[ARG1]] to %[[S1]] annotate_for_users  : tensor<2x8x32xf32>
-  %s1 = mesh.sharding @mesh_2d split_axes = [[], [0]] : !mesh.sharding
-  %1 = mesh.shard %arg1 to %s1 annotate_for_users  : tensor<2x8x32xf32>
-  // CHECK-NEXT:  %[[S2:.*]] = mesh.sharding @mesh_2d split_axes = {{\[\[}}]] : !mesh.sharding
-  // CHECK-NEXT:  %[[ZP:.*]] = mesh.shard %[[ARG2]] to %[[S2]] annotate_for_users  : tensor<1xf32>
-  // CHECK-NEXT:  %[[V2:.*]] = tosa.matmul %[[V0]], %[[V1]], %[[ZP]], %[[ZP]]
-  %2 = tosa.matmul %0, %1, %arg2, %arg2 : (tensor<2x16x8xf32>, tensor<2x8x32xf32>, tensor<1xf32>, tensor<1xf32>)  -> tensor<2x16x32xf32>
-  // CHECK-NEXT:  %[[S3:.*]] = mesh.sharding @mesh_2d split_axes = {{\[\[}}], [1]] : !mesh.sharding
-  // CHECK-NEXT:  %[[V3:.*]] = mesh.shard %[[V2]] to %[[S3]]  : tensor<2x16x32xf32>
-  // CHECK-NEXT:  return %[[V3]]
-  return %2 : tensor<2x16x32xf32>
-}
-
-// CHECK-LABEL: func.func @resolve_conflicting_annotations
-func.func @resolve_conflicting_annotations(
-  // CHECK-SAME: %[[IN1:.*]]: tensor<2x3xf32>,
-  %arg0: tensor<2x3xf32>,
-  // CHECK-SAME: %[[IN2:.*]]: tensor<3x2xf32>,
-  %arg1: tensor<3x2xf32>,
-  // CHECK-SAME: %[[OUT_DPS:.*]]: tensor<2x2xf32>
-  %out_dps: tensor<2x2xf32>
-// CHECK-SAME: ) -> tensor<2x2xf32> {
-) -> tensor<2x2xf32> {
-  // CHECK: %[[SIN1_SHARDED1:.*]] = mesh.sharding @mesh_2 split_axes = {{\[\[}}0]] : !mesh.sharding
-  // CHECK-NEXT:  %[[IN1_SHARDED1:.*]] = mesh.shard %[[IN1]] to %[[SIN1_SHARDED1]]  : tensor<2x3xf32>
-  // CHECK: %[[SIN2_SHARDED:.*]] = mesh.sharding @mesh_2 split_axes = {{\[\[}}]] : !mesh.sharding
-  // CHECK-NEXT:  %[[IN1_SHARDED2:.*]] = mesh.shard %[[IN1_SHARDED1]] to %[[SIN2_SHARDED]] annotate_for_users  : tensor<2x3xf32>
-  // CHECK-NEXT:  %[[IN2_SHARDED:.*]] = mesh.shard %[[IN2]] to %[[SIN2_SHARDED]] annotate_for_users  : tensor<3x2xf32>
-  // CHECK-NEXT:  %[[OUT_DPS_SHARDED:.*]] = mesh.shard %[[OUT_DPS]] to %[[SIN2_SHARDED]] annotate_for_users  : tensor<2x2xf32>
-  %sarg0_sharded = mesh.sharding @mesh_2 split_axes = [[0]] : !mesh.sharding
-  %arg0_sharded = mesh.shard %arg0 to %sarg0_sharded  : tensor<2x3xf32>
-  // CHECK: %[[MATMUL:.*]] = linalg.matmul ins(%[[IN1_SHARDED2]], %[[IN2_SHARDED]] : tensor<2x3xf32>, tensor<3x2xf32>)
-  // CHECK-SAME: outs(%[[OUT_DPS_SHARDED]] : tensor<2x2xf32>) -> tensor<2x2xf32>
-  %res = linalg.matmul ins(%arg0_sharded, %arg1 : tensor<2x3xf32>, tensor<3x2xf32>)
-    outs(%out_dps : tensor<2x2xf32>) -> tensor<2x2xf32>
-  // CHECK-NEXT: %[[RES:.*]] = mesh.shard %[[MATMUL]] to %[[SIN2_SHARDED]] : tensor<2x2xf32>
-  %sres_sharded = mesh.sharding @mesh_2 split_axes = [[]] : !mesh.sharding
-  %res_sharded = mesh.shard %res to %sres_sharded  : tensor<2x2xf32>
-  // CHECK: return %[[RES]] : tensor<2x2xf32>
-  return %res_sharded : tensor<2x2xf32>
-}
-
-// https://arxiv.org/abs/2211.05102 Figure 2(a)
-// The sharding propagation results in unnecessary reshards,
-//   an optimization pass should be able to remove them.
-// CHECK-LABEL: func.func @mlp_1d_weight_stationary
-// CHECK-SAME:     [[varg0:%.*]]: tensor<2x4x8xf32>, [[varg1:%.*]]: tensor<2x8x32xf32>, [[varg2:%.*]]: tensor<2x32x8xf32>, [[varg3:%.*]]: tensor<1xf32>
-func.func @mlp_1d_weight_stationary(%arg0: tensor<2x4x8xf32>, %arg1: tensor<2x8x32xf32>, %arg2: tensor<2x32x8xf32>, %arg3: tensor<1xf32>) -> tensor<2x4x8xf32> {
-  %s0 = mesh.sharding @mesh_1d split_axes = [[], [], [0, 1, 2]] : !mesh.sharding
-  %sharded0 = mesh.shard %arg0 to %s0 : tensor<2x4x8xf32>
-  %sharded1 = mesh.shard %arg1 to %s0 : tensor<2x8x32xf32>
-  // CHECK: [[vsharding:%.*]] = mesh.sharding @mesh_1d split_axes = {{\[\[}}], [], [0, 1, 2]] : !mesh.sharding
-  // CHECK: [[vsharded:%.*]] = mesh.shard [[varg0]] to [[vsharding]] : tensor<2x4x8xf32>
-  // CHECK: [[vsharded_0:%.*]] = mesh.shard [[varg1]] to [[vsharding]] : tensor<2x8x32xf32>
-  // CHECK: [[vsharded_1:%.*]] = mesh.shard [[vsharded]] to [[vsharding]] annotate_for_users : tensor<2x4x8xf32>
-  // CHECK: [[vsharding_2:%.*]] = mesh.sharding @mesh_1d split_axes = {{\[\[}}], [0, 1, 2]] : !mesh.sharding
-  // CHECK: [[vsharded_3:%.*]] = mesh.shard [[vsharded_0]] to [[vsharding_2]] annotate_for_users : tensor<2x8x32xf32>
-  // CHECK: [[vsharding_4:%.*]] = mesh.sharding @mesh_1d split_axes = {{\[\[}}]] : !mesh.sharding
-  // CHECK: [[vsharded_5:%.*]] = mesh.shard [[varg3]] to [[vsharding_4]] annotate_for_users : tensor<1xf32>
-  // CHECK: [[v0:%.*]] = tosa.matmul
-  %1 = tosa.matmul %sharded0, %sharded1, %arg3, %arg3 : (tensor<2x4x8xf32>, tensor<2x8x32xf32>, tensor<1xf32>, tensor<1xf32>)  -> tensor<2x4x32xf32>
-  // CHECK: [[vsharded_6:%.*]] = mesh.shard [[v0]] to [[vsharding_4]] : tensor<2x4x32xf32>
-  // CHECK: [[vsharded_7:%.*]] = mesh.shard [[vsharded_6]] to [[vsharding_4]] annotate_for_users : tensor<2x4x32xf32>
-  // CHECK: [[v1:%.*]] = tosa.sigmoid [[vsharded_7]] : (tensor<2x4x32xf32>) -> tensor<2x4x32xf32>
-  // CHECK: [[vsharded_8:%.*]] = mesh.shard [[v1]] to [[vsharding_4]] : tensor<2x4x32xf32>
-  %2 = tosa.sigmoid %1 : (tensor<2x4x32xf32>) -> tensor<2x4x32xf32>
-  %sharding = mesh.sharding @mesh_1d split_axes = [[], [0, 1, 2]] : !mesh.sharding
-  // CHECK: [[vsharded_9:%.*]] = mesh.shard [[varg2]] to [[vsharding_2]] : tensor<2x32x8xf32>
-  %sharded2 = mesh.shard %arg2 to %sharding  : tensor<2x32x8xf32>
-  // CHECK: [[vsharded_10:%.*]] = mesh.shard [[vsharded_8]] to [[vsharding_4]] annotate_for_users : tensor<2x4x32xf32>
-  // CHECK: [[vsharded_11:%.*]] = mesh.shard [[vsharded_9]] to [[vsharding]] annotate_for_users : tensor<2x32x8xf32>
-  // CHECK: [[v2:%.*]] = tosa.matmul
-  %3 = tosa.matmul %2, %sharded2, %arg3, %arg3 : (tensor<2x4x32xf32>, tensor<2x32x8xf32>, tensor<1xf32>, tensor<1xf32>)  -> tensor<2x4x8xf32>
-  // CHECK: [[vsharded_12:%.*]] = mesh.shard [[v2]] to [[vsharding]] : tensor<2x4x8xf32>
-  %s4 = mesh.sharding @mesh_1d split_axes = [[], [], [0, 1, 2]] : !mesh.sharding
-  %4 = mesh.shard %3 to %s4  : tensor<2x4x8xf32>
-  // CHECK: return [[vsharded_12]]
-  return %4 : tensor<2x4x8xf32>
-}
-
-// https://arxiv.org/abs/2211.05102 Figure 2(b)
-// The sharding propagation results in unnecessary reshards,
-//   an optimization pass should be able to remove them.
-// CHECK-LABEL: func.func @mlp_2d_weight_stationary
-// CHECK-SAME:     [[varg0:%.*]]: tensor<2x4x8xf32>, [[varg1:%.*]]: tensor<2x8x32xf32>, [[varg2:%.*]]: tensor<2x32x8xf32>, [[varg3:%.*]]: tensor<1xf32>
-func.func @mlp_2d_weight_stationary(%arg0: tensor<2x4x8xf32>, %arg1: tensor<2x8x32xf32>, %arg2: tensor<2x32x8xf32>, %arg3: tensor<1xf32>) -> tensor<2x4x8xf32> {
-    // CHECK: [[vsharding:%.*]] = mesh.sharding @mesh_3d split_axes = {{\[\[}}], [], [0, 1, 2]] : !mesh.sharding
-  %s0 = mesh.sharding @mesh_3d split_axes = [[], [], [0, 1, 2]] : !mesh.sharding
-    // CHECK: [[vsharded:%.*]] = mesh.shard [[varg0]] to [[vsharding]] : tensor<2x4x8xf32>
-  %arg0_s = mesh.shard %arg0 to %s0  : tensor<2x4x8xf32>
-    // CHECK: [[vsharding_0:%.*]] = mesh.sharding @mesh_3d split_axes = {{\[\[}}], [0], [1, 2]] : !mesh.sharding
-  %s1 = mesh.sharding @mesh_3d split_axes = [[], [0], [1, 2]] : !mesh.sharding
-    // CHECK: [[vsharded_1:%.*]] = mesh.shard [[varg1]] to [[vsharding_0]] : tensor<2x8x32xf32>
-  %arg1_s = mesh.shard %arg1 to %s1  : tensor<2x8x32xf32>
-    // CHECK: [[vsharding_2:%.*]] = mesh.sharding @mesh_3d split_axes = {{\[\[}}]] : !mesh.sharding
-    // CHECK: [[vsharded_3:%.*]] = mesh.shard [[vsharded]] to [[vsharding_2]] annotate_for_users : tensor<2x4x8xf32>
-    // CHECK: [[vsharded_4:%.*]] = mesh.shard [[vsharded_1]] to [[vsharding]] annotate_for_users : tensor<2x8x32xf32>
-    // CHECK: [[vsharded_5:%.*]] = mesh.shard [[varg3]] to [[vsharding_2]] annotate_for_users : tensor<1xf32>
-    // CHECK: [[v0:%.*]] = tosa.matmul
-  %1 = tosa.matmul %arg0_s, %arg1_s, %arg3, %arg3 : (tensor<2x4x8xf32>, tensor<2x8x32xf32>, tensor<1xf32>, tensor<1xf32>)  -> tensor<2x4x32xf32>
-    // CHECK: [[vsharded_6:%.*]] = mesh.shard [[v0]] to [[vsharding]] : tensor<2x4x32xf32>
-  %2 = mesh.shard %1 to %s0  : tensor<2x4x32xf32>
-    // CHECK: [[vsharded_7:%.*]] = mesh.shard [[vsharded_6]] to [[vsharding]] annotate_for_users : tensor<2x4x32xf32>
-    // CHECK: [[v1:%.*]] = tosa.sigmoid 
-    // CHECK: [[vsharded_8:%.*]] = mesh.shard [[v1]] to [[vsharding]] : tensor<2x4x32xf32>
-  %3 = tosa.sigmoid %2 : (tensor<2x4x32xf32>) -> tensor<2x4x32xf32>
-    // CHECK: [[vsharding_9:%.*]] = mesh.sharding @mesh_3d split_axes = {{\[\[}}], [1, 2], [0]] : !mesh.sharding
-  %s2 = mesh.sharding @mesh_3d split_axes = [[], [1, 2], [0]] : !mesh.sharding
-    // CHECK: [[vsharded_10:%.*]] = mesh.shard [[varg2]] to [[vsharding_9]] : tensor<2x32x8xf32>
-  %arg2_s = mesh.shard %arg2 to %s2  : tensor<2x32x8xf32>
-    // CHECK: [[vsharded_11:%.*]] = mesh.shard [[vsharded_8]] to [[vsharding_2]] annotate_for_users : tensor<2x4x32xf32>
-    // CHECK: [[vsharded_12:%.*]] = mesh.shard [[vsharded_10]] to [[vsharding]] annotate_for_users : tensor<2x32x8xf32>
-    // CHECK: [[v2:%.*]] = tosa.matmul
-  %4 = tosa.matmul %3, %arg2_s, %arg3, %arg3 : (tensor<2x4x32xf32>, tensor<2x32x8xf32>, tensor<1xf32>, tensor<1xf32>)  -> tensor<2x4x8xf32>
-    // CHECK: [[vsharded_13:%.*]] = mesh.shard [[v2]] to [[vsharding]] : tensor<2x4x8xf32>
-  %5 = mesh.shard %4 to %s0  : tensor<2x4x8xf32>
-    // CHECK: [[vsharded_14:%.*]] = mesh.shard [[vsharded_13]] to [[vsharding]] annotate_for_users : tensor<2x4x8xf32>
-  %6 = mesh.shard %5 to %s0 annotate_for_users  : tensor<2x4x8xf32>
-    // CHECK: return [[vsharded_14]]
-  return %6 : tensor<2x4x8xf32>
-}
-
-// CHECK-LABEL: func.func @elementwise_duplicated_chain
-// CHECK-SAME:    %[[ARG:.*]]: tensor<8x16xf32>
-func.func @elementwise_duplicated_chain(%arg0: tensor<8x16xf32>) -> tensor<8x16xf32> {
-  // CHECK-NEXT:  %[[S0:.*]] = mesh.sharding @mesh_2d split_axes = {{\[\[}}]] : !mesh.sharding
-  // CHECK-NEXT:  %[[V0:.*]] = mesh.shard %[[ARG]] to %[[S0]] annotate_for_users  : tensor<8x16xf32>
-  // CHECK-NEXT:  %[[V1:.*]] = tosa.sigmoid %[[V0]]
-  %0 = tosa.sigmoid %arg0 : (tensor<8x16xf32>) -> tensor<8x16xf32>
-  // CHECK-NEXT:  %[[V2:.*]] = mesh.shard %[[V1]] to %[[S0]] : tensor<8x16xf32>
-  // CHECK-NEXT:  %[[V3:.*]] = mesh.shard %[[V2]] to %[[S0]] annotate_for_users : tensor<8x16xf32>
-  // CHECK-NEXT:  %[[V4:.*]] = tosa.sigmoid %[[V3]]
-  %1 = tosa.sigmoid %0 : (tensor<8x16xf32>) -> tensor<8x16xf32>
-  // CHECK-NEXT:  %[[V5:.*]] = mesh.shard %[[V4]] to %[[S0]]  : tensor<8x16xf32>
-  %s0 = mesh.sharding @mesh_2d split_axes = [[]] : !mesh.sharding
-  %2 = mesh.shard %1 to %s0 : tensor<8x16xf32>
-  // CHECK-NEXT:  return %[[V5]]
-  return %2 : tensor<8x16xf32>
-}
diff --git a/mlir/test/Dialect/Mesh/spmdization.mlir b/mlir/test/Dialect/Mesh/spmdization.mlir
deleted file mode 100644
index 701898cbdc74d..0000000000000
--- a/mlir/test/Dialect/Mesh/spmdization.mlir
+++ /dev/null
@@ -1,317 +0,0 @@
-// RUN: mlir-opt \
-// RUN:   --pass-pipeline="builtin.module(func.func(mesh-spmdization,test-single-fold))" \
-// RUN:   %s | FileCheck %s
-
-mesh.mesh @mesh_1d(shape = 2)
-
-// CHECK-LABEL: func @return_sharding
-func.func @return_sharding(
-  // CHECK-SAME: [[ARG:%.*]]: tensor<1xf32>
-  %arg0: tensor<2xf32>
-// CHECK-SAME: ) -> (tensor<1xf32>, !mesh.sharding) {
-) -> (tensor<2xf32>, !mesh.sharding) {
-  %ssharding_annotated = mesh.sharding @mesh_1d split_axes = [[0]] : !mesh.sharding
-  %sharding_annotated = mesh.shard %arg0 to %ssharding_annotated  : tensor<2xf32>
-  // CHECK-NEXT: [[vsharding:%.*]] = mesh.sharding @mesh_1d split_axes = {{\[\[}}0]] : !mesh.sharding
-  %r = mesh.get_sharding %sharding_annotated : tensor<2xf32> -> !mesh.sharding
-  // CHECK-NEXT: return [[ARG]], [[vsharding]] : tensor<1xf32>, !mesh.sharding
-  return %sharding_annotated, %r : tensor<2xf32>, !mesh.sharding
-}
-
-// CHECK-LABEL: func @full_replication
-func.func @full_replication(
-  // CHECK-SAME: %[[ARG:.*]]: tensor<2xi8>
-  %arg0: tensor<2xi8>
-// CHECK-SAME: -> tensor<2xi8> {
-) -> tensor<2xi8> {
-  %s0 = mesh.sharding @mesh_1d split_axes = [[]] : !mesh.sharding
-  %0 = mesh.shard %arg0 to %s0  : tensor<2xi8>
-  %s1 = mesh.sharding @mesh_1d split_axes = [[]] : !mesh.sharding
-  %1 = mesh.shard %0 to %s1  annotate_for_users : tensor<2xi8>
-  // CHECK: return %[[ARG]] : tensor<2xi8>
-  return %1 : tensor<2xi8>
-}
-
-// CHECK-LABEL: func @sharding_triplet
-func.func @sharding_triplet(
-  // CHECK-SAME: %[[ARG:.*]]: tensor<1xf32>
-  %arg0: tensor<2xf32>
-// CHECK-SAME: ) -> tensor<2xf32> {
-) -> tensor<2xf32> {
-  // CHECK: %[[ALL_GATHER:.*]] = mesh.all_gather %[[ARG]] on @mesh_1d mesh_axes = [0] gather_axis = 0 : tensor<1xf32> -> tensor<2xf32>
-  %ssharding_annotated = mesh.sharding @mesh_1d split_axes = [[0]] : !mesh.sharding
-  %sharding_annotated = mesh.shard %arg0 to %ssharding_annotated  : tensor<2xf32>
-  %ssharding_annotated_0 = mesh.sharding @mesh_1d split_axes = [[0]] : !mesh.sharding
-  %sharding_annotated_0 = mesh.shard %sharding_annotated to %ssharding_annotated_0  annotate_for_users : tensor<2xf32>
-  %ssharding_annotated_1 = mesh.sharding @mesh_1d split_axes = [[]] : !mesh.sharding
-  %sharding_annotated_1 = mesh.shard %sharding_annotated_0 to %ssharding_annotated_1  : tensor<2xf32>
-  // CHECK: return %[[ALL_GATHER]] : tensor<2xf32>
-  return %sharding_annotated_1 : tensor<2xf32>
-}
-
-
-// CHECK-LABEL: func @move_split_axis
-func.func @move_split_axis(
-  // CHECK-SAME: %[[ARG:.*]]: tensor<1x2xi8>
-  %arg0: tensor<2x2xi8>
-// CHECK-SAME: -> tensor<2x1xi8> {
-) -> tensor<2x2xi8> {
-  // CHECK: %[[ALL_TO_ALL:.*]] = mesh.all_to_all %[[ARG]] on @mesh_1d
-  // CHECK-SAME: mesh_axes = [0] split_axis = 1 concat_axis = 0 : tensor<1x2xi8> -> tensor<2x1xi8>
-  %s0 = mesh.sharding @mesh_1d split_axes = [[0]] : !mesh.sharding
-  %0 = mesh.shard %arg0 to %s0  : tensor<2x2xi8>
-  %s1 = mesh.sharding @mesh_1d split_axes = [[], [0]] : !mesh.sharding
-  %1 = mesh.shard %0 to %s1  annotate_for_users : tensor<2x2xi8>
-  // CHECK: return %[[ALL_TO_ALL]] : tensor<2x1xi8>
-  return %1 : tensor<2x2xi8>
-}
-
-// CHECK-LABEL: func @non_tensor_value
-func.func @non_tensor_value(
-  // CHECK-SAME: %[[ARG:.*]]: i8
-  %arg0: i8
-// CHECK-SAME: -> i8 {
-) -> i8 {
-  // CHECK: %[[RES:.*]] = arith.addi %[[ARG]], %[[ARG]] : i8
-  %0 = arith.addi %arg0, %arg0 : i8
-  // CHECK: return %[[RES]] : i8
-  return %0 : i8
-}
-
-// CHECK-LABEL: func @unary_elementwise
-func.func @unary_elementwise(
-  // CHECK-SAME: %[[ARG:.*]]: tensor<1xi8>
-  %arg0: tensor<2xi8>
-// CHECK-SAME: -> tensor<1xi8> {
-) -> tensor<2xi8> {
-  %s0 = mesh.sharding @mesh_1d split_axes = [[0]] : !mesh.sharding
-  %0 = mesh.shard %arg0 to %s0  : tensor<2xi8>
-  %s1 = mesh.sharding @mesh_1d split_axes = [[0]] : !mesh.sharding
-  %1 = mesh.shard %0 to %s1  annotate_for_users : tensor<2xi8>
-  // CHECK: %[[RES:.*]] = tosa.abs %[[ARG]] : (tensor<1xi8>) -> tensor<1xi8>
-  %2 = tosa.abs %1 : (tensor<2xi8>) -> tensor<2xi8>
-  %s3 = mesh.sharding @mesh_1d split_axes = [[0]] : !mesh.sharding
-  %3 = mesh.shard %2 to %s3  : tensor<2xi8>
-  %s4 = mesh.sharding @mesh_1d split_axes = [[0]] : !mesh.sharding
-  %4 = mesh.shard %3 to %s4  annotate_for_users : tensor<2xi8>
-  // CHECK: return %[[RES]] : tensor<1xi8>
-  return %4 : tensor<2xi8>
-}
-
-// full replication -> shard axis -> abs -> shard axis -> full replication
-// CHECK-LABEL: func @unary_elementwise_with_resharding
-func.func @unary_elementwise_with_resharding(
-  // CHECK-SAME: %[[ARG:.*]]: tensor<2xi8>
-  %arg0: tensor<2xi8>
-// CHECK-SAME: -> tensor<2xi8> {
-) -> tensor<2xi8> {
-  // CHECK: %[[SLICE:.*]] = mesh.all_slice %[[ARG]] on @mesh_1d mesh_axes = [0] slice_axis = 0
-  // CHECK-SAME: tensor<2xi8> -> tensor<1xi8>
-  %s0 = mesh.sharding @mesh_1d split_axes = [[]] : !mesh.sharding
-  %0 = mesh.shard %arg0 to %s0  : tensor<2xi8>
-  %s1 = mesh.sharding @mesh_1d split_axes = [[0]] : !mesh.sharding
-  %1 = mesh.shard %0 to %s1  annotate_for_users : tensor<2xi8>
-  // CHECK: %[[ABS:.*]] = tosa.abs %[[SLICE]] : (tensor<1xi8>) -> tensor<1xi8>
-  %2 = tosa.abs %1 : (tensor<2xi8>) -> tensor<2xi8>
-  // CHECK: %[[RES:.*]] = mesh.all_gather %[[ABS]] on @mesh_1d
-  // CHECK-SAME: mesh_axes = [0] gather_axis = 0 : tensor<1xi8> -> tensor<2xi8>
-  %s3 = mesh.sharding @mesh_1d split_axes = [[0]] : !mesh.sharding
-  %3 = mesh.shard %2 to %s3  : tensor<2xi8>
-  %s4 = mesh.sharding @mesh_1d split_axes = [[]] : !mesh.sharding
-  %4 = mesh.shard %3 to %s4  annotate_for_users : tensor<2xi8>
-  // CHECK: return %[[RES]] : tensor<2xi8>
-  return %4 : tensor<2xi8>
-}
-
-// CHECK-LABEL: func @binary_elementwise
-func.func @binary_elementwise(
-  // CHECK-SAME: %[[ARG0:.*]]: tensor<1xi8>,
-  %arg0: tensor<2xi8>,
-  // CHECK-SAME: %[[ARG1:.*]]: tensor<1xi8>
-  %arg1: tensor<2xi8>
-// CHECK-SAME: -> tensor<1xi8> {
-) -> tensor<2xi8> {
-  %sarg0_sharded = mesh.sharding @mesh_1d split_axes = [[0]] : !mesh.sharding
-  %arg0_sharded = mesh.shard %arg0 to %sarg0_sharded  : tensor<2xi8>
-  %sop_arg0 = mesh.sharding @mesh_1d split_axes = [[0]] : !mesh.sharding
-  %op_arg0 = mesh.shard %arg0_sharded to %sop_arg0  annotate_for_users : tensor<2xi8>
-  %sarg1_sharded = mesh.sharding @mesh_1d split_axes = [[0]] : !mesh.sharding
-  %arg1_sharded = mesh.shard %arg1 to %sarg1_sharded  : tensor<2xi8>
-  %sop_arg1 = mesh.sharding @mesh_1d split_axes = [[0]] : !mesh.sharding
-  %op_arg1 = mesh.shard %arg1_sharded to %sop_arg1  annotate_for_users : tensor<2xi8>
-  // CHECK: %[[RES:.*]] = tosa.add %[[ARG0]], %[[ARG1]] : (tensor<1xi8>, tensor<1xi8>) -> tensor<1xi8>
-  %op_res = tosa.add %op_arg0, %op_arg1 : (tensor<2xi8>, tensor<2xi8>) -> tensor<2xi8>
-  %sop_res_sharded = mesh.sharding @mesh_1d split_axes = [[0]] : !mesh.sharding
-  %op_res_sharded = mesh.shard %op_res to %sop_res_sharded  : tensor<2xi8>
-  %sres = mesh.sharding @mesh_1d split_axes = [[0]] : !mesh.sharding
-  %res = mesh.shard %op_res_sharded to %sres  annotate_for_users : tensor<2xi8>
-  // CHECK: return %[[RES]] : tensor<1xi8>
-  return %res : tensor<2xi8>
-}
-
-// reshard
-// abs
-// reshard
-// abs
-// reshard
-// CHECK-LABEL: func @multiple_chained_ops
-func.func @multiple_chained_ops(
-  // CHECK-SAME: %[[ARG:.*]]: tensor<2xi8>
-  %arg0: tensor<2xi8>
-// CHECK-SAME: -> tensor<1xi8> {
-) -> tensor<2xi8> {
-  // CHECK: %[[RESHARD1:.*]] = mesh.all_slice %[[ARG]] on @mesh_1d mesh_axes = [0] slice_axis = 0
-  // CHECK-SAME: tensor<2xi8> -> tensor<1xi8>
-  %s0 = mesh.sharding @mesh_1d split_axes = [[]] : !mesh.sharding
-  %0 = mesh.shard %arg0 to %s0  : tensor<2xi8>
-  %s1 = mesh.sharding @mesh_1d split_axes = [[0]] : !mesh.sharding
-  %1 = mesh.shard %0 to %s1  annotate_for_users : tensor<2xi8>
-  // CHECK: %[[ABS1:.*]] = tosa.abs %[[RESHARD1]] : (tensor<1xi8>) -> tensor<1xi8>
-  %2 = tosa.abs %1 : (tensor<2xi8>) -> tensor<2xi8>
-  // CHECK: %[[RESHARD2:.*]] = mesh.all_gather %[[ABS1]] on @mesh_1d
-  // CHECK-SAME: mesh_axes = [0] gather_axis = 0 : tensor<1xi8> -> tensor<2xi8>
-  %s3 = mesh.sharding @mesh_1d split_axes = [[0]] : !mesh.sharding
-  %3 = mesh.shard %2 to %s3  : tensor<2xi8>
-  %s4 = mesh.sharding @mesh_1d split_axes = [[]] : !mesh.sharding
-  %4 = mesh.shard %3 to %s4  annotate_for_users : tensor<2xi8>
-  // CHECK: %[[ABS2:.*]] = tosa.abs %[[RESHARD2]] : (tensor<2xi8>) -> tensor<2xi8>
-  %5 = tosa.abs %4 : (tensor<2xi8>) -> tensor<2xi8>
-  // CHECK: %[[RESHARD3:.*]] = mesh.all_slice %[[ABS2]] on @mesh_1d mesh_axes = [0] slice_axis = 0 :
-  // CHECK-SAME: tensor<2xi8> -> tensor<1xi8>
-  %s6 = mesh.sharding @mesh_1d split_axes = [[]] : !mesh.sharding
-  %6 = mesh.shard %5 to %s6  : tensor<2xi8>
-  %s7 = mesh.sharding @mesh_1d split_axes = [[0]] : !mesh.sharding
-  %7 = mesh.shard %6 to %s7  annotate_for_users : tensor<2xi8>
-  // CHECK: return %[[RESHARD3]] : tensor<1xi8>
-  return %7 : tensor<2xi8>
-}
-
-// CHECK-LABEL: func @incomplete_sharding
-func.func @incomplete_sharding(
-  // CHECK-SAME: %[[ARG:.*]]: tensor<4x16xf32>
-  %arg0: tensor<8x16xf32>
-// CHECK-SAME: -> tensor<4x16xf32> {
-) -> tensor<8x16xf32> {
-  %s0 = mesh.sharding @mesh_1d split_axes = [[0]] : !mesh.sharding
-  %0 = mesh.shard %arg0 to %s0  annotate_for_users : tensor<8x16xf32>
-  // CHECK: %[[RES:.*]] = tosa.sigmoid %[[ARG]] : (tensor<4x16xf32>) -> tensor<4x16xf32>
-  %1 = tosa.sigmoid %0 : (tensor<8x16xf32>) -> tensor<8x16xf32>
-  %s2 = mesh.sharding @mesh_1d split_axes = [[0]] : !mesh.sharding
-  %2 = mesh.shard %1 to %s2  : tensor<8x16xf32>
-  // CHECK: return %[[RES]] : tensor<4x16xf32>
-  return %2 : tensor<8x16xf32>
-}
-
-mesh.mesh @mesh_1d_4(shape = 4)
-
-// CHECK-LABEL: func @ew_chain_with_halo
-func.func @ew_chain_with_halo(
-  // CHECK-SAME: %[[IN1:[A-Za-z0-9_]+]]: tensor<5x16xf32>
-  %arg0: tensor<8x16xf32>,
-  // CHECK-SAME: %[[IN2:[A-Za-z0-9_]+]]: tensor<1xf32>
-  %arg1: tensor<1xf32>,
-  // CHECK-SAME: %[[IN3:[A-Za-z0-9_]+]]: tensor<1xf32>
-  %arg2: tensor<1xf32>)
-  // CHECK-SAME: -> tensor<5x16xf32>
-   -> tensor<8x16xf32> {
-  %ssharding_annotated = mesh.sharding @mesh_1d_4 split_axes = [[0]] halo_sizes = [2, 1] : !mesh.sharding
-  %sharding_annotated = mesh.shard %arg0 to %ssharding_annotated  annotate_for_users : tensor<8x16xf32>
-  // CHECK: %[[TMP1:.*]] = tosa.tanh %[[IN1]] : (tensor<5x16xf32>) -> tensor<5x16xf32>
-  %0 = tosa.tanh %sharding_annotated : (tensor<8x16xf32>) -> tensor<8x16xf32>
-  %ssharding_annotated_0 = mesh.sharding @mesh_1d_4 split_axes = [[0]] halo_sizes = [2, 1] : !mesh.sharding
-  %sharding_annotated_0 = mesh.shard %0 to %ssharding_annotated_0  : tensor<8x16xf32>
-  %ssharding_annotated_1 = mesh.sharding @mesh_1d_4 split_axes = [[0]] halo_sizes = [2, 1] : !mesh.sharding
-  %sharding_annotated_1 = mesh.shard %sharding_annotated_0 to %ssharding_annotated_1  annotate_for_users : tensor<8x16xf32>
-  // CHECK-NEXT: %[[TMP2:.*]] = tosa.abs %[[TMP1]] : (tensor<5x16xf32>) -> tensor<5x16xf32>
-  %1 = tosa.abs %sharding_annotated_1 : (tensor<8x16xf32>) -> tensor<8x16xf32>
-  %ssharding_annotated_2 = mesh.sharding @mesh_1d_4 split_axes = [[0]] halo_sizes = [2, 1] : !mesh.sharding
-  %sharding_annotated_2 = mesh.shard %1 to %ssharding_annotated_2  : tensor<8x16xf32>
-  %ssharding_annotated_4 = mesh.sharding @mesh_1d_4 split_axes = [[0]] halo_sizes = [2, 1] : !mesh.sharding
-  %sharding_annotated_4 = mesh.shard %sharding_annotated_2 to %ssharding_annotated_4  annotate_for_users : tensor<8x16xf32>
-  // CHECK-NEXT: %[[TMP3:.*]] = tosa.negate %[[TMP2]], %[[IN2]], %[[IN3]] : (tensor<5x16xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<5x16xf32>
-  %sharding_1 = mesh.sharding @mesh_1d_4 split_axes = [[]] : !mesh.sharding
-  %zero_point_1 = mesh.shard %arg1 to %sharding_1 annotate_for_users : tensor<1xf32>
-  %zero_point_2 = mesh.shard %arg2 to %sharding_1 annotate_for_users : tensor<1xf32>
-  %2 = tosa.negate %sharding_annotated_4, %zero_point_1, %zero_point_2 : (tensor<8x16xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<8x16xf32>
-  %ssharding_annotated_5 = mesh.sharding @mesh_1d_4 split_axes = [[0]] halo_sizes = [2, 1] : !mesh.sharding
-  %sharding_annotated_5 = mesh.shard %2 to %ssharding_annotated_5  : tensor<8x16xf32>
-  %ssharding_annotated_6 = mesh.sharding @mesh_1d_4 split_axes = [[0]] halo_sizes = [2, 1] : !mesh.sharding
-  %sharding_annotated_6 = mesh.shard %sharding_annotated_5 to %ssharding_annotated_6  annotate_for_users : tensor<8x16xf32>
-  // CHECK-NEXT: return %[[TMP3]] : tensor<5x16xf32>
-  return %sharding_annotated_6 : tensor<8x16xf32>
-}
-
-// CHECK-LABEL: func @test_shard_update_halo
-// CHECK-SAME: %[[IN1:[A-Za-z0-9_]+]]: tensor<300x1200xi64>
-func.func @test_shard_update_halo(%arg0: tensor<1200x1200xi64>) -> tensor<1200x1200xi64> {
-  %sharding = mesh.sharding @mesh_1d_4 split_axes = [[0]] : !mesh.sharding
-  // CHECK: %[[T:.*]] = tensor.empty() : tensor<304x1200xi64>
-  // CHECK: %[[inserted_slice:.*]] = tensor.insert_slice %[[IN1]] into %[[T]][2, 0] [300, 1200] [1, 1] : tensor<300x1200xi64> into tensor<304x1200xi64>
-  // CHECK: %[[UH:.*]] = mesh.update_halo %[[inserted_slice]] on @mesh_1d_4 split_axes = {{\[\[0]]}} halo_sizes = [2, 2] : tensor<304x1200xi64>
-  %sharding_annotated = mesh.shard %arg0 to %sharding : tensor<1200x1200xi64>
-  %sharding_0 = mesh.sharding @mesh_1d_4 split_axes = [[0]] halo_sizes = [2, 2] : !mesh.sharding
-  %sharding_annotated_1 = mesh.shard %sharding_annotated to %sharding_0 : tensor<1200x1200xi64>
-  %sharding_annotated_3 = mesh.shard %sharding_annotated_1 to %sharding_0 annotate_for_users : tensor<1200x1200xi64>
-  // CHECK: return %[[UH]] : tensor<304x1200xi64>
-  return %sharding_annotated_3 : tensor<1200x1200xi64>
-}
-
-mesh.mesh @mesh4x4(shape = 4x4)
-// CHECK-LABEL: func @test_shard_update_halo2d
-// CHECK-SAME: %[[IN1:[A-Za-z0-9_]+]]: tensor<300x300xi64>
-func.func @test_shard_update_halo2d(%arg0: tensor<1200x1200xi64>) -> tensor<1200x1200xi64> {
-  %sharding = mesh.sharding @mesh4x4 split_axes = [[0], [1]] : !mesh.sharding
-  // CHECK: %[[T:.*]] = tensor.empty() : tensor<303x307xi64>
-  // CHECK: %[[inserted_slice:.*]] = tensor.insert_slice %[[IN1]] into %[[T]][1, 3] [300, 300] [1, 1] : tensor<300x300xi64> into tensor<303x307xi64>
-  // CHECK: %[[UH:.*]] = mesh.update_halo %[[inserted_slice]] on @mesh4x4 split_axes = {{\[\[}}0], [1]] halo_sizes = [1, 2, 3, 4] : tensor<303x307xi64>
-  %sharding_annotated = mesh.shard %arg0 to %sharding : tensor<1200x1200xi64>
-  %sharding_0 = mesh.sharding @mesh4x4 split_axes = [[0], [1]] halo_sizes = [1, 2, 3, 4] : !mesh.sharding
-  %sharding_annotated_1 = mesh.shard %sharding_annotated to %sharding_0 : tensor<1200x1200xi64>
-  %sharding_annotated_3 = mesh.shard %sharding_annotated_1 to %sharding_0 annotate_for_users : tensor<1200x1200xi64>
-  // CHECK: return %[[UH]] : tensor<303x307xi64>
-  return %sharding_annotated_3 : tensor<1200x1200xi64>
-}
-
-mesh.mesh @mesh(shape = 2)
-// CHECK-LABEL: func.func @test_reduce_0d(
-// CHECK-SAME: %[[ARG0:[A-Za-z0-9_]+]]: tensor<3x6xi32>
-func.func @test_reduce_0d(%arg0: tensor<6x6xi32>) -> (tensor<i32>) {
-  %sharding = mesh.sharding @mesh split_axes = [[0]] : !mesh.sharding
-  %sharded = mesh.shard %arg0 to %sharding annotate_for_users : tensor<6x6xi32>
-  %4 = tensor.empty() : tensor<i32>
-  %sharding_out = mesh.sharding @mesh split_axes = [[]] : !mesh.sharding
-  %sharded_out = mesh.shard %4 to %sharding_out : tensor<i32>
-  %sharded_in = mesh.shard %sharded to %sharding annotate_for_users : tensor<6x6xi32>
-  // CHECK: %[[reduced:.*]] = linalg.reduce ins(%arg0 : tensor<3x6xi32>)
-  %reduced = linalg.reduce ins(%sharded_in : tensor<6x6xi32>) outs(%sharded_out : tensor<i32>) dimensions = [0, 1] 
-    (%in: i32, %init: i32) {
-      %6 = arith.addi %in, %init : i32
-      linalg.yield %6 : i32
-    }
-  // CHECK: %[[all_reduce:.*]] = mesh.all_reduce %[[reduced]] on @mesh mesh_axes = [0] : tensor<i32> -> tensor<i32>
-  %sharded_red = mesh.shard %reduced to %sharding_out : tensor<i32>
-  %sharded_ret = mesh.shard %sharded_red to %sharding_out annotate_for_users : tensor<i32>
-  // CHECK: return %[[all_reduce]] : tensor<i32>
-  return %sharded_ret : tensor<i32>
-}
-
-// CHECK-LABEL: func.func @test_reduce_1d(
-// CHECK-SAME: %[[ARG0:[A-Za-z0-9_]+]]: tensor<3x6xi32>
-func.func @test_reduce_1d(%arg0: tensor<6x6xi32>) -> (tensor<6xi32>) {
-  %sharding = mesh.sharding @mesh split_axes = [[0]] : !mesh.sharding
-  %sharded = mesh.shard %arg0 to %sharding annotate_for_users : tensor<6x6xi32>
-  %4 = tensor.empty() : tensor<6xi32>
-  %sharded_out = mesh.shard %4 to %sharding : tensor<6xi32>
-  %sharded_in = mesh.shard %sharded to %sharding annotate_for_users : tensor<6x6xi32>
-  // CHECK: %[[reduced:.*]] = linalg.reduce ins(%arg0 : tensor<3x6xi32>)
-  %reduced = linalg.reduce ins(%sharded_in : tensor<6x6xi32>) outs(%sharded_out : tensor<6xi32>) dimensions = [1] 
-    (%in: i32, %init: i32) {
-      %6 = arith.addi %in, %init : i32
-      linalg.yield %6 : i32
-    }
-  // CHECK-NOT: mesh.all_reduce
-  %sharded_red = mesh.shard %reduced to %sharding : tensor<6xi32>
-  %sharded_ret = mesh.shard %sharded_red to %sharding annotate_for_users : tensor<6xi32>
-  // CHECK: return %[[reduced]] : tensor<3xi32>
-  return %sharded_ret : tensor<6xi32>
-}
diff --git a/mlir/test/Dialect/Mesh/all-scatter-op-lowering.mlir b/mlir/test/Dialect/Shard/all-scatter-op-lowering.mlir
similarity index 72%
rename from mlir/test/Dialect/Mesh/all-scatter-op-lowering.mlir
rename to mlir/test/Dialect/Shard/all-scatter-op-lowering.mlir
index 4f54607a1c7ff..bc911215851aa 100644
--- a/mlir/test/Dialect/Mesh/all-scatter-op-lowering.mlir
+++ b/mlir/test/Dialect/Shard/all-scatter-op-lowering.mlir
@@ -1,43 +1,43 @@
-// RUN: mlir-opt --split-input-file --test-mesh-all-slice-op-lowering --test-mesh-simplifications --cse %s | FileCheck %s
+// RUN: mlir-opt --split-input-file --test-grid-all-slice-op-lowering --test-grid-simplifications --cse %s | FileCheck %s
 
-mesh.mesh @mesh_1d(shape = ?)
+shard.grid @grid_1d(shape = ?)
 
-// CHECK-LABEL: func.func @all_slice_op_lowering_of_dynamic_1d_tensor_on_dynamic_1d_mesh
-func.func @all_slice_op_lowering_of_dynamic_1d_tensor_on_dynamic_1d_mesh(
+// CHECK-LABEL: func.func @all_slice_op_lowering_of_dynamic_1d_tensor_on_dynamic_1d_grid
+func.func @all_slice_op_lowering_of_dynamic_1d_tensor_on_dynamic_1d_grid(
   // CHECK: %[[ARG:.*]]: tensor<?xf16>
   %arg0: tensor<?xf16>
 // CHECK-SAME: -> tensor<?xf16> {
 ) -> tensor<?xf16> {
   // CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
-  // CHECK-DAG: %[[PROC_IDX:.*]] = mesh.process_multi_index on @mesh_1d axes = [0] : index
-  // CHECK-DAG: %[[MESH_SIZE:.*]] = mesh.mesh_shape @mesh_1d axes = [0] : index
+  // CHECK-DAG: %[[PROC_IDX:.*]] = shard.process_multi_index on @grid_1d axes = [0] : index
+  // CHECK-DAG: %[[SHARD_SIZE:.*]] = shard.grid_shape @grid_1d axes = [0] : index
   // CHECK: %[[TENSOR_AXIS_SIZE:.*]] = tensor.dim %[[ARG]], %c0 : tensor<?xf16>
-  // CHECK: %[[AXIS_SIZE_CHECK_REMINDER:.*]] = arith.remui %[[TENSOR_AXIS_SIZE]], %[[MESH_SIZE]] : index
+  // CHECK: %[[AXIS_SIZE_CHECK_REMINDER:.*]] = arith.remui %[[TENSOR_AXIS_SIZE]], %[[SHARD_SIZE]] : index
   // CHECK: %[[AXIS_SIZE_CHECK:.*]] = arith.cmpi eq, %[[AXIS_SIZE_CHECK_REMINDER]], %[[C0]] : index
   // CHECK: cf.assert %[[AXIS_SIZE_CHECK]]
-  // CHECK: %[[RESULT_AXIS_SIZE:.*]] = arith.divui %[[TENSOR_AXIS_SIZE]], %[[MESH_SIZE]] : index
+  // CHECK: %[[RESULT_AXIS_SIZE:.*]] = arith.divui %[[TENSOR_AXIS_SIZE]], %[[SHARD_SIZE]] : index
   // CHECK: %[[SLICE_OFFSET:.*]] = arith.muli %[[PROC_IDX]], %[[RESULT_AXIS_SIZE]] : index
   // CHECK: %[[RESULT:.*]] = tensor.extract_slice %[[ARG]][%[[SLICE_OFFSET]]] [%[[RESULT_AXIS_SIZE]]] [1] : tensor<?xf16> to tensor<?xf16>
-  %0 = mesh.all_slice %arg0 on @mesh_1d mesh_axes = [0] slice_axis = 0 : tensor<?xf16> -> tensor<?xf16>
+  %0 = shard.all_slice %arg0 on @grid_1d grid_axes = [0] slice_axis = 0 : tensor<?xf16> -> tensor<?xf16>
   // CHECK: return %[[RESULT]] : tensor<?xf16>
   return %0 : tensor<?xf16>
 }
 
 // -----
 
-mesh.mesh @mesh_1d(shape = 2)
+shard.grid @grid_1d(shape = 2)
 
-// CHECK-LABEL: func.func @all_slice_op_lowering_of_static_1d_tensor_on_static_1d_mesh
-func.func @all_slice_op_lowering_of_static_1d_tensor_on_static_1d_mesh(
+// CHECK-LABEL: func.func @all_slice_op_lowering_of_static_1d_tensor_on_static_1d_grid
+func.func @all_slice_op_lowering_of_static_1d_tensor_on_static_1d_grid(
   // CHECK: %[[ARG:.*]]: tensor<2xf16>
   %arg0: tensor<2xf16>
 // CHECK-SAME: -> tensor<1xf16> {
 ) -> tensor<1xf16> {
   // CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
-  // CHECK: %[[PROC_IDX:.*]] = mesh.process_multi_index on @mesh_1d axes = [0] : index
+  // CHECK: %[[PROC_IDX:.*]] = shard.process_multi_index on @grid_1d axes = [0] : index
   // CHECK: %[[SLICE:.*]] = tensor.extract_slice %[[ARG]][%[[PROC_IDX]]] [%[[C1]]] [1] : tensor<2xf16> to tensor<?xf16>
   // CHECK: %[[RESULT:.*]] = tensor.cast %[[SLICE]] : tensor<?xf16> to tensor<1xf16>
-  %0 = mesh.all_slice %arg0 on @mesh_1d mesh_axes = [0] slice_axis = 0 : tensor<2xf16> -> tensor<1xf16>
+  %0 = shard.all_slice %arg0 on @grid_1d grid_axes = [0] slice_axis = 0 : tensor<2xf16> -> tensor<1xf16>
   // CHECK: return %[[RESULT]] : tensor<1xf16>
   return %0 : tensor<1xf16>
 }
@@ -46,18 +46,18 @@ func.func @all_slice_op_lowering_of_static_1d_tensor_on_static_1d_mesh(
 
 // CHECK: #map = affine_map<()[s0, s1, s2] -> (s0 * s1 + s2)>
 
-mesh.mesh @mesh_4d(shape = ?x?x?x?)
+shard.grid @grid_4d(shape = ?x?x?x?)
 
-// CHECK-LABEL: func.func @all_slice_op_lowering_of_dynamic_2d_tensor_on_dynamic_4d_mesh
-func.func @all_slice_op_lowering_of_dynamic_2d_tensor_on_dynamic_4d_mesh(
+// CHECK-LABEL: func.func @all_slice_op_lowering_of_dynamic_2d_tensor_on_dynamic_4d_grid
+func.func @all_slice_op_lowering_of_dynamic_2d_tensor_on_dynamic_4d_grid(
   // CHECK: %[[ARG:.*]]: tensor<?x?xf16>
   %arg0 : tensor<?x?xf16>
 // CHECK-SAME: -> tensor<?x?xf16> {
 ) -> tensor<?x?xf16> {
   // CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
   // CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
-  // CHECK-DAG: %[[IN_GROUP_PROC_MULTI_IDX:.*]]:2 = mesh.process_multi_index on @mesh_4d axes = [3, 1] : index, index
-  // CHECK-DAG: %[[PROC_GROUP_SHAPE:.*]]:2 = mesh.mesh_shape @mesh_4d axes = [3, 1] : index, index
+  // CHECK-DAG: %[[IN_GROUP_PROC_MULTI_IDX:.*]]:2 = shard.process_multi_index on @grid_4d axes = [3, 1] : index, index
+  // CHECK-DAG: %[[PROC_GROUP_SHAPE:.*]]:2 = shard.grid_shape @grid_4d axes = [3, 1] : index, index
   // CHECK: %[[PROC_GROUP_SIZE:.*]] = arith.muli %[[PROC_GROUP_SHAPE]]#0, %[[PROC_GROUP_SHAPE]]#1 : index
   // CHECK: %[[SCATTER_AXIS_SIZE:.*]] = tensor.dim %[[ARG]], %[[C1]] : tensor<?x?xf16>
   // CHECK: %[[AXIS_SIZE_CHECK_REMINDER:.*]] = arith.remui %[[SCATTER_AXIS_SIZE]], %[[PROC_GROUP_SIZE]] : index
@@ -68,7 +68,7 @@ func.func @all_slice_op_lowering_of_dynamic_2d_tensor_on_dynamic_4d_mesh(
   // CHECK: %[[AXIS_0_SIZE:.*]] = tensor.dim %[[ARG]], %[[C0]] : tensor<?x?xf16>
   // CHECK: %[[SCATTER_AXIS_OFFSET:.*]] = arith.muli %[[PROC_IN_GROUP_LINEAR_IDX]], %[[RESULT_SCATTER_AXIS_SIZE]] : index
   // CHECK: %[[RESULT:.*]] = tensor.extract_slice %[[ARG]][0, %[[SCATTER_AXIS_OFFSET]]] [%[[AXIS_0_SIZE]], %[[RESULT_SCATTER_AXIS_SIZE]]] [1, 1] : tensor<?x?xf16> to tensor<?x?xf16>
-  %0 = mesh.all_slice %arg0 on @mesh_4d mesh_axes = [3, 1] slice_axis = 1 : tensor<?x?xf16> -> tensor<?x?xf16>
+  %0 = shard.all_slice %arg0 on @grid_4d grid_axes = [3, 1] slice_axis = 1 : tensor<?x?xf16> -> tensor<?x?xf16>
   // CHECK: return %[[RESULT]] : tensor<?x?xf16>
   return %0 : tensor<?x?xf16>
 }
diff --git a/mlir/test/Dialect/Mesh/backward-sharding-propagation.mlir b/mlir/test/Dialect/Shard/backward-sharding-propagation.mlir
similarity index 67%
rename from mlir/test/Dialect/Mesh/backward-sharding-propagation.mlir
rename to mlir/test/Dialect/Shard/backward-sharding-propagation.mlir
index 4223d01d65111..8894c4aee49c0 100644
--- a/mlir/test/Dialect/Mesh/backward-sharding-propagation.mlir
+++ b/mlir/test/Dialect/Shard/backward-sharding-propagation.mlir
@@ -2,17 +2,17 @@
 
 #map = affine_map<(d0, d1) -> (d0, d1)>
 module {
-  mesh.mesh @mesh(shape = 1) {sym_visibility = "private"}
+  shard.grid @grid(shape = 1) {sym_visibility = "private"}
   func.func @test_forward() -> tensor<6x6xi32> {
     %c1_i32 = arith.constant 1 : i32
     // CHECK: tensor.empty()
     %0 = tensor.empty() : tensor<6x6xi32>
-    %sharding = mesh.sharding @mesh split_axes = [[0]] : !mesh.sharding
-    // CHECK-COUNT-2: mesh.shard
-    %sharding_annotated = mesh.shard %0 to %sharding : tensor<6x6xi32>
-    %1 = linalg.fill ins(%c1_i32 : i32) outs(%sharding_annotated : tensor<6x6xi32>) -> tensor<6x6xi32>
+    %sharding = shard.sharding @grid split_axes = [[0]] : !shard.sharding
+    // CHECK-COUNT-2: shard.shard
+    %sharded = shard.shard %0 to %sharding : tensor<6x6xi32>
+    %1 = linalg.fill ins(%c1_i32 : i32) outs(%sharded : tensor<6x6xi32>) -> tensor<6x6xi32>
     // CHECK: tensor.empty()
-    // CHECK-NOT: mesh.shard @
+    // CHECK-NOT: shard.shard @
     %2 = tensor.empty() : tensor<6x6xi32>
     %3 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel"]} ins(%1, %1
         : tensor<6x6xi32>, tensor<6x6xi32>) outs(%2 : tensor<6x6xi32>) {
diff --git a/mlir/test/Dialect/Shard/canonicalization.mlir b/mlir/test/Dialect/Shard/canonicalization.mlir
new file mode 100644
index 0000000000000..ed40dfb7237da
--- /dev/null
+++ b/mlir/test/Dialect/Shard/canonicalization.mlir
@@ -0,0 +1,248 @@
+// RUN: mlir-opt --canonicalize %s | FileCheck %s
+
+shard.grid @grid0(shape = 2x4)
+
+// CHECK-LABEL: func @all_reduce_empty_grid_axes
+func.func @all_reduce_empty_grid_axes(
+// CHECK-SAME: %[[ARG:.*]]: tensor<4xf32>
+    %arg0 : tensor<4xf32>) -> tensor<4xf32> {
+// CHECK-NOT: shard.all_reduce
+  %0 = shard.all_reduce %arg0 on @grid0
+    grid_axes = []
+    : tensor<4xf32> -> tensor<4xf32>
+// CHECK: return %[[ARG]]
+  return %0 : tensor<4xf32>
+}
+
+// CHECK-LABEL: func @all_reduce_empty_grid_axes_different_return_type
+func.func @all_reduce_empty_grid_axes_different_return_type(
+    %arg0 : tensor<4xf32>) -> tensor<4xf64> {
+// CHECK: shard.all_reduce
+  %0 = shard.all_reduce %arg0 on @grid0
+// CHECK-NOT: grid_axes
+    grid_axes = []
+    : tensor<4xf32> -> tensor<4xf64>
+  return %0 : tensor<4xf64>
+}
+
+// CHECK-LABEL: func @all_reduce_default_reduction
+func.func @all_reduce_default_reduction(
+    %arg0 : tensor<4xf32>) -> tensor<4xf64> {
+  %0 = shard.all_reduce %arg0 on @grid0
+    grid_axes = [0]
+// CHECK-NOT: reduction
+    reduction = sum
+    : tensor<4xf32> -> tensor<4xf64>
+  return %0 : tensor<4xf64>
+}
+
+// CHECK-LABEL: func @all_to_all_empty_grid_axes
+func.func @all_to_all_empty_grid_axes(
+// CHECK-SAME: %[[ARG:.*]]: tensor<8xf32>
+    %arg0 : tensor<8xf32>) -> tensor<8xf32> {
+// CHECK-NOT: shard.all_to_all
+  %0 = shard.all_to_all %arg0 on @grid0
+    grid_axes = []
+    split_axis = 0
+    concat_axis = 0
+    : tensor<8xf32> -> tensor<8xf32>
+// CHECK: return %[[ARG]]
+  return %0 : tensor<8xf32>
+}
+
+// CHECK-LABEL: func @all_gather_empty_grid_axes
+func.func @all_gather_empty_grid_axes(
+// CHECK-SAME: %[[ARG:.*]]: tensor<4xf32>
+    %arg0 : tensor<4xf32>) -> tensor<4xf32> {
+// CHECK-NOT: shard.all_gather
+  %0 = shard.all_gather %arg0 on @grid0
+    grid_axes = []
+    gather_axis = 0
+    : tensor<4xf32> -> tensor<4xf32>
+// CHECK: return %[[ARG]]
+  return %0 : tensor<4xf32>
+}
+
+// CHECK-LABEL: func @all_slice_empty_grid_axes
+func.func @all_slice_empty_grid_axes(
+// CHECK-SAME: %[[ARG:.*]]: tensor<4xf32>
+    %arg0 : tensor<4xf32>) -> tensor<4xf32> {
+// CHECK-NOT: shard.scatter
+  %0 = shard.all_slice %arg0 on @grid0
+    grid_axes = []
+    slice_axis = 0
+    : tensor<4xf32> -> tensor<4xf32>
+// CHECK: return %[[ARG]]
+  return %0 : tensor<4xf32>
+}
+
+// CHECK-LABEL: func @broadcast_empty_grid_axes
+func.func @broadcast_empty_grid_axes(
+// CHECK-SAME: %[[ARG:.*]]: tensor<4xf32>
+    %arg0 : tensor<4xf32>) -> tensor<4xf32> {
+// CHECK-NOT: shard.broadcast
+  %0 = shard.broadcast %arg0 on @grid0
+    grid_axes = []
+    root = []
+    : (tensor<4xf32>) -> tensor<4xf32>
+// CHECK: return %[[ARG]]
+  return %0 : tensor<4xf32>
+}
+
+// CHECK-LABEL: func @gather_empty_grid_axes
+func.func @gather_empty_grid_axes(
+// CHECK-SAME: %[[ARG:.*]]: tensor<4xf32>
+    %arg0 : tensor<4xf32>) -> tensor<4xf32> {
+// CHECK-NOT: shard.gather
+  %0 = shard.gather %arg0 on @grid0
+    grid_axes = []
+    gather_axis = 0
+    root = []
+    : (tensor<4xf32>) -> tensor<4xf32>
+// CHECK: return %[[ARG]]
+  return %0 : tensor<4xf32>
+}
+
+// CHECK-LABEL: func @receive_empty_grid_axes
+func.func @receive_empty_grid_axes(
+// CHECK-SAME: %[[ARG:.*]]: tensor<4xf32>
+    %arg0 : tensor<4xf32>) -> tensor<4xf32> {
+// CHECK-NOT: shard.recv
+  %0 = shard.recv %arg0 on @grid0
+    grid_axes = []
+    : (tensor<4xf32>) -> tensor<4xf32>
+// CHECK: return %[[ARG]]
+  return %0 : tensor<4xf32>
+}
+
+// CHECK-LABEL: func @reduce_empty_grid_axes
+func.func @reduce_empty_grid_axes(
+// CHECK-SAME: %[[ARG:.*]]: tensor<4xf32>
+    %arg0 : tensor<4xf32>) -> tensor<4xf32> {
+// CHECK-NOT: shard.reduce
+  %0 = shard.reduce %arg0 on @grid0
+    grid_axes = []
+    root = []
+    : (tensor<4xf32>) -> tensor<4xf32>
+// CHECK: return %[[ARG]]
+  return %0 : tensor<4xf32>
+}
+
+// CHECK-LABEL: func @reduce_scatter_empty_grid_axes
+func.func @reduce_scatter_empty_grid_axes(
+// CHECK-SAME: %[[ARG:.*]]: tensor<4xf32>
+    %arg0 : tensor<4xf32>) -> tensor<4xf32> {
+// CHECK-NOT: shard.reduce_scatter
+  %0 = shard.reduce_scatter %arg0 on @grid0
+    grid_axes = []
+    scatter_axis = 0
+    : tensor<4xf32> -> tensor<4xf32>
+// CHECK: return %[[ARG]]
+  return %0 : tensor<4xf32>
+}
+
+// CHECK-LABEL: func @reduce_scatter_empty_grid_axes_different_return_type
+func.func @reduce_scatter_empty_grid_axes_different_return_type(
+    %arg0 : tensor<4xf32>) -> tensor<4xf64> {
+// CHECK: shard.reduce_scatter
+  %0 = shard.reduce_scatter %arg0 on @grid0
+// CHECK-NOT: grid_axes
+    grid_axes = []
+    scatter_axis = 0
+    : tensor<4xf32> -> tensor<4xf64>
+  return %0 : tensor<4xf64>
+}
+
+// CHECK-LABEL: func @reduce_scatter_default_reduction
+func.func @reduce_scatter_default_reduction(
+    %arg0 : tensor<4xf32>) -> tensor<2xf64> {
+  %0 = shard.reduce_scatter %arg0 on @grid0
+    grid_axes = [0]
+// CHECK-NOT: reduction
+    reduction = sum
+    scatter_axis = 0
+    : tensor<4xf32> -> tensor<2xf64>
+  return %0 : tensor<2xf64>
+}
+
+// CHECK-LABEL: func @scatter_empty_grid_axes
+func.func @scatter_empty_grid_axes(
+// CHECK-SAME: %[[ARG:.*]]: tensor<4xf32>
+    %arg0 : tensor<4xf32>) -> tensor<4xf32> {
+// CHECK-NOT: shard.scatter
+  %0 = shard.scatter %arg0 on @grid0
+    grid_axes = []
+    scatter_axis = 0
+    root = []
+    : (tensor<4xf32>) -> tensor<4xf32>
+// CHECK: return %[[ARG]]
+  return %0 : tensor<4xf32>
+}
+
+// CHECK-LABEL: func @send_empty_grid_axes
+func.func @send_empty_grid_axes(
+// CHECK-SAME: %[[ARG:.*]]: tensor<4xf32>
+    %arg0 : tensor<4xf32>) -> tensor<4xf32> {
+// CHECK-NOT: shard.send
+  %0 = shard.send %arg0 on @grid0
+    grid_axes = []
+    destination = []
+    : (tensor<4xf32>) -> tensor<4xf32>
+// CHECK: return %[[ARG]]
+  return %0 : tensor<4xf32>
+}
+
+shard.grid @grid4x4(shape = 4x4)
+// CHECK-LABEL: func @test_halo_sizes
+func.func @test_halo_sizes() -> !shard.sharding {
+  %c2_i64 = arith.constant 2 : i64
+  // CHECK shard.sharding @grid4x4 split_axes = [[0], [1]] halo_sizes = [1, 2, 2, 22] : !shard.sharding
+  %sharding = shard.sharding @grid4x4 split_axes = [[0], [1]] halo_sizes = [1, %c2_i64, %c2_i64, 22] : !shard.sharding
+  return %sharding : !shard.sharding
+}
+
+// CHECK-LABEL: func @test_shard_offs
+func.func @test_shard_offs() -> !shard.sharding {
+  %c2_i64 = arith.constant 2 : i64
+  // CHECK shard.sharding @grid4x4 split_axes = [[0], [1]] sharded_dims_offsets = [0, 1, 2, 3, 4, 0, 2, 3, 4, 22] : !shard.sharding
+  %sharding = shard.sharding @grid4x4 split_axes = [[0], [1]] sharded_dims_offsets = [0, 1, %c2_i64, 3, 4, 0, %c2_i64, 3, 4, 22] : !shard.sharding
+  return %sharding : !shard.sharding
+}
+
+// CHECK-LABEL: func @test_duplicate_shardops
+func.func @test_duplicate_shardops() -> (tensor<1024x1024xf32>, tensor<1024x1024xf32>) attributes {llvm.emit_c_interface} {
+  // CHECK-NEXT: [[vcst:%.*]] = arith.constant dense<0.000000e+00> : tensor<1024x1024xf32>
+  %cst_1 = arith.constant dense<0.000000e+00> : tensor<1024x1024xf32>
+  // CHECK-NEXT: [[vsharding:%.*]] = shard.sharding @grid4x4 split_axes = {{\[\[}}0, 1]] : !shard.sharding
+  %sharding_1 = shard.sharding @grid4x4 split_axes = [[0, 1]] : !shard.sharding
+  %cst_2 = arith.constant dense<0.000000e+00> : tensor<1024x1024xf32>
+  %sharding_2 = shard.sharding @grid4x4 split_axes = [[0, 1]] : !shard.sharding
+  %sharded_2 = shard.shard %cst_2 to %sharding_2 : tensor<1024x1024xf32>
+  %cst_3 = arith.constant dense<0.000000e+00> : tensor<1024x1024xf32>
+  %sharding_3 = shard.sharding @grid4x4 split_axes = [[0, 1]] : !shard.sharding
+  %sharded_3 = shard.shard %cst_3 to %sharding_3 : tensor<1024x1024xf32>
+  // CHECK-NEXT: [[vsharded:%.*]] = shard.shard [[vcst]] to [[vsharding]] : tensor<1024x1024xf32>
+  %sharded_1 = shard.shard %cst_1 to %sharding_1 : tensor<1024x1024xf32>
+  // CHECK-NEXT: return [[vsharded]], [[vsharded]] : tensor<1024x1024xf32>, tensor<1024x1024xf32>
+  return %sharded_1, %sharded_2 : tensor<1024x1024xf32>, tensor<1024x1024xf32>
+}
+
+// CHECK-LABEL: func @test_duplicate_shardops_diff
+func.func @test_duplicate_shardops_diff() -> (tensor<1024x1024xf32>, tensor<1024x1024xf32>) attributes {llvm.emit_c_interface} {
+  // CHECK-NEXT: [[vcst:%.*]] = arith.constant dense<0.000000e+00> : tensor<1024x1024xf32>
+  %cst_1 = arith.constant dense<0.000000e+00> : tensor<1024x1024xf32>
+  // CHECK-NEXT: [[vsharding:%.*]] = shard.sharding @grid4x4 split_axes = {{\[\[}}0]] : !shard.sharding
+  %sharding_1 = shard.sharding @grid4x4 split_axes = [[0]] : !shard.sharding
+  %cst_2 = arith.constant dense<0.000000e+00> : tensor<1024x1024xf32>
+  // CHECK-NEXT: [[vsharding_0:%.*]] = shard.sharding @grid4x4 split_axes = {{\[\[}}0, 1]] : !shard.sharding
+  %sharding_2 = shard.sharding @grid4x4 split_axes = [[0, 1]] : !shard.sharding
+  // CHECK-NEXT: [[vsharded:%.*]] = shard.shard [[vcst]] to [[vsharding_0]] : tensor<1024x1024xf32>
+  %sharded_2 = shard.shard %cst_2 to %sharding_2 : tensor<1024x1024xf32>
+  %cst_3 = arith.constant dense<0.000000e+00> : tensor<1024x1024xf32>
+  %sharding_3 = shard.sharding @grid4x4 split_axes = [[0]] : !shard.sharding
+  %sharded_3 = shard.shard %cst_3 to %sharding_3 : tensor<1024x1024xf32>
+  // CHECK-NEXT: [[vsharded_1:%.*]] = shard.shard [[vsharded]] to [[vsharding]] : tensor<1024x1024xf32>
+  %sharded_1 = shard.shard %cst_1 to %sharding_1 : tensor<1024x1024xf32>
+  // CHECK-NEXT: return [[vsharded_1]], [[vsharded]] : tensor<1024x1024xf32>, tensor<1024x1024xf32>
+  return %sharded_1, %sharded_2 : tensor<1024x1024xf32>, tensor<1024x1024xf32>
+}
diff --git a/mlir/test/Dialect/Shard/folding.mlir b/mlir/test/Dialect/Shard/folding.mlir
new file mode 100644
index 0000000000000..5a0f35b53a129
--- /dev/null
+++ b/mlir/test/Dialect/Shard/folding.mlir
@@ -0,0 +1,22 @@
+// RUN: mlir-opt -test-grid-simplifications %s | FileCheck %s
+
+shard.grid @grid0(shape = 4x?x2)
+shard.grid @grid1(shape = 2x3)
+
+// CHECK-LABEL: func.func @grid_shape_op_folding
+func.func @grid_shape_op_folding() -> (index, index) {
+  // CHECK: %[[AXIS_2_SIZE:.*]] = arith.constant 2 : index
+  // CHECK: %[[AXIS_1_SIZE:.*]] = shard.grid_shape @grid0 axes = [1] : index
+  %0:2 = shard.grid_shape @grid0 axes = [2, 1] : index, index
+  // CHECK: return %[[AXIS_2_SIZE]], %[[AXIS_1_SIZE]]
+  return %0#0, %0#1 : index, index
+}
+
+// CHECK-LABEL: func.func @grid_shape_op_folding_all_axes_static_grid
+func.func @grid_shape_op_folding_all_axes_static_grid() -> (index, index) {
+  // CHECK: %[[AXIS_0_SIZE:.*]] = arith.constant 2 : index
+  // CHECK: %[[AXIS_1_SIZE:.*]] = arith.constant 3 : index
+  %0:2 = shard.grid_shape @grid1 : index, index
+  // CHECK: return %[[AXIS_0_SIZE]], %[[AXIS_1_SIZE]]
+  return %0#0, %0#1 : index, index
+}
diff --git a/mlir/test/Dialect/Mesh/forward-backward-sharding-propagation.mlir b/mlir/test/Dialect/Shard/forward-backward-sharding-propagation.mlir
similarity index 63%
rename from mlir/test/Dialect/Mesh/forward-backward-sharding-propagation.mlir
rename to mlir/test/Dialect/Shard/forward-backward-sharding-propagation.mlir
index dd2eee2f7def8..0d8d99752620a 100644
--- a/mlir/test/Dialect/Mesh/forward-backward-sharding-propagation.mlir
+++ b/mlir/test/Dialect/Shard/forward-backward-sharding-propagation.mlir
@@ -2,25 +2,25 @@
 
 #map = affine_map<(d0, d1) -> (d0, d1)>
 module {
-  mesh.mesh @mesh(shape = 1) {sym_visibility = "private"}
+  shard.grid @grid(shape = 1) {sym_visibility = "private"}
   func.func @test_forward() -> tensor<6x6xi32> {
     %c1_i32 = arith.constant 1 : i32
     // CHECK: tensor.empty()
     %0 = tensor.empty() : tensor<6x6xi32>
-    // CHECK-COUNT-3: mesh.sharding @mesh split_axes = {{\[\[0}}]]
-    %sharding_row = mesh.sharding @mesh split_axes = [[0]] : !mesh.sharding
-    %annotated_row = mesh.shard %0 to %sharding_row : tensor<6x6xi32>
+    // CHECK-COUNT-3: shard.sharding @grid split_axes = {{\[\[0}}]]
+    %sharding_row = shard.sharding @grid split_axes = [[0]] : !shard.sharding
+    %annotated_row = shard.shard %0 to %sharding_row : tensor<6x6xi32>
     %1 = linalg.fill ins(%c1_i32 : i32) outs(%annotated_row : tensor<6x6xi32>) -> tensor<6x6xi32>
     %2 = tensor.empty() : tensor<6x6xi32>
-    // CHECK-COUNT-4: mesh.sharding @mesh split_axes = {{\[\[1}}]]
+    // CHECK-COUNT-4: shard.sharding @grid split_axes = {{\[\[1}}]]
     %3 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel"]} ins(%2, %1
         : tensor<6x6xi32>, tensor<6x6xi32>) outs(%2 : tensor<6x6xi32>) {
     ^bb0(%in: i32, %in_2: i32, %out: i32):
       %9 = arith.addi %in, %in_2 : i32
       linalg.yield %9 : i32
     } -> tensor<6x6xi32>
-    %sharding_col = mesh.sharding @mesh split_axes = [[1]] : !mesh.sharding
-    %annotated_col = mesh.shard %3 to %sharding_col : tensor<6x6xi32>
+    %sharding_col = shard.sharding @grid split_axes = [[1]] : !shard.sharding
+    %annotated_col = shard.shard %3 to %sharding_col : tensor<6x6xi32>
     // CHECK: return
     return %annotated_col : tensor<6x6xi32>
   }
diff --git a/mlir/test/Dialect/Shard/forward-sharding-propagation.mlir b/mlir/test/Dialect/Shard/forward-sharding-propagation.mlir
new file mode 100644
index 0000000000000..3cda9eaa365fd
--- /dev/null
+++ b/mlir/test/Dialect/Shard/forward-sharding-propagation.mlir
@@ -0,0 +1,49 @@
+// RUN: mlir-opt --pass-pipeline="builtin.module(func.func(sharding-propagation{traversal=forward}))" %s | FileCheck %s
+
+#map = affine_map<(d0, d1) -> (d0, d1)>
+module attributes {dlti.map = #dlti.map<"MPI:Implementation" = "mpich", "MPI:comm_world_rank" = 0 : i32>} {
+  shard.grid @grid(shape = 1) {sym_visibility = "private"}
+  func.func @test_forward() -> (tensor<6x6xi32>, tensor<6x6xi32>, tensor<i32>) attributes {llvm.emit_c_interface} {
+    %c1_i32 = arith.constant 1 : i32
+    // CHECK: [[v3:%.*]] = tensor.empty() : tensor<6x6xi32>
+    %0 = tensor.empty() : tensor<6x6xi32>
+    // CHECK: [[v1:%.*]] = linalg.fill ins
+    // CHECK: [[vsharding_0:%.*]] = shard.sharding @grid split_axes = {{\[\[}}0]] : !shard.sharding
+    // CHECK: [[vsharded_1:%.*]] = shard.shard [[v1]] to [[vsharding_0]] : tensor<6x6xi32>
+    %1 = linalg.fill ins(%c1_i32 : i32) outs(%0 : tensor<6x6xi32>) -> tensor<6x6xi32>
+    %sharding = shard.sharding @grid split_axes = [[0]] : !shard.sharding
+    %sharded = shard.shard %1 to %sharding : tensor<6x6xi32>
+    // CHECK: [[v2:%.*]] = tensor.empty() : tensor<6x6xi32>
+    // CHECK: [[vsharding_2:%.*]] = shard.sharding @grid split_axes = {{\[\[}}0]] : !shard.sharding
+    // CHECK: [[vsharded_3:%.*]] = shard.shard [[vsharded_1]] to [[vsharding_2]] annotate_for_users : tensor<6x6xi32>
+    %3 = tensor.empty() : tensor<6x6xi32>
+    // CHECK: [[vsharding_4:%.*]] = shard.sharding @grid split_axes = {{\[\[}}0]] : !shard.sharding
+    // CHECK: [[vsharded_5:%.*]] = shard.shard [[v2]] to [[vsharding_4]] annotate_for_users : tensor<6x6xi32>
+    // CHECK: [[v3:%.*]] = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel"]}
+    // CHECK-SAME: ins([[vsharded_3]], [[vsharded_3]] : tensor<6x6xi32>, tensor<6x6xi32>) outs([[vsharded_5]] : tensor<6x6xi32>) {
+    // CHECK: [[vsharding_6:%.*]] = shard.sharding @grid split_axes = {{\[\[}}0]] : !shard.sharding
+    // CHECK: [[vsharded_7:%.*]] = shard.shard [[v3]] to [[vsharding_6]] : tensor<6x6xi32>
+    %4 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel"]} ins(%sharded, %sharded
+        : tensor<6x6xi32>, tensor<6x6xi32>) outs(%3 : tensor<6x6xi32>) {
+    ^bb0(%in: i32, %in_2: i32, %out: i32):
+      %9 = arith.addi %in, %in_2 : i32
+      linalg.yield %9 : i32
+    } -> tensor<6x6xi32>
+    %c0_i32 = arith.constant 0 : i32
+    %6 = tensor.empty() : tensor<i32>
+    %7 = linalg.fill ins(%c0_i32 : i32) outs(%6 : tensor<i32>) -> tensor<i32>
+    // CHECK: [[vreduced:%.*]] = linalg.reduce ins
+    // CHECK: [[vsharding_12:%.*]] = shard.sharding @grid split_axes = [] : !shard.sharding
+    // CHECK: [[vsharded_13:%.*]] = shard.shard [[vreduced]] to [[vsharding_12]] : tensor<i32>
+    %reduced = linalg.reduce ins(%4 : tensor<6x6xi32>) outs(%7 : tensor<i32>) dimensions = [0, 1] 
+      (%in: i32, %init: i32) {
+        %9 = arith.addi %in, %init : i32
+        linalg.yield %9 : i32
+      }
+    // CHECK: [[vsharding_14:%.*]] = shard.sharding @grid split_axes = {{\[\[}}]] : !shard.sharding
+    %sharding_0 = shard.sharding @grid split_axes = [[]] : !shard.sharding
+    // CHECK: [[vsharded_15:%.*]] = shard.shard [[vsharded_13]] to [[vsharding_14]] annotate_for_users : tensor<i32>
+    %sharded_1 = shard.shard %reduced to %sharding_0 annotate_for_users : tensor<i32>
+    return %sharded, %4, %sharded_1 : tensor<6x6xi32>, tensor<6x6xi32>, tensor<i32>
+  }
+}
diff --git a/mlir/test/Dialect/Shard/inlining.mlir b/mlir/test/Dialect/Shard/inlining.mlir
new file mode 100644
index 0000000000000..ce664b31abf7a
--- /dev/null
+++ b/mlir/test/Dialect/Shard/inlining.mlir
@@ -0,0 +1,15 @@
+// RUN: mlir-opt -inline %s | FileCheck %s
+
+shard.grid @grid0(shape = 4x?x2)
+
+func.func private @grid_to_inline() -> (index, index) {
+  %0:2 = shard.grid_shape @grid0 axes = [2, 1] : index, index
+  return %0#0, %0#1 : index, index
+}
+// CHECK-LABEL: func.func @main
+func.func @main() -> (index, index) {
+  // CHECK-NEXT: %[[AXIS_SIZE:.*]]:2 = shard.grid_shape @grid0 axes = [2, 1] : index
+  %0:2 = func.call @grid_to_inline() : () -> (index, index)
+  // CHECK-NEXT: return %[[AXIS_SIZE]]#0, %[[AXIS_SIZE]]#1
+  return %0#0, %0#1 : index, index
+}
diff --git a/mlir/test/Dialect/Mesh/invalid.mlir b/mlir/test/Dialect/Shard/invalid.mlir
similarity index 57%
rename from mlir/test/Dialect/Mesh/invalid.mlir
rename to mlir/test/Dialect/Shard/invalid.mlir
index 2656332942382..6acac971164ed 100644
--- a/mlir/test/Dialect/Mesh/invalid.mlir
+++ b/mlir/test/Dialect/Shard/invalid.mlir
@@ -1,55 +1,55 @@
 // RUN: mlir-opt -split-input-file -verify-diagnostics %s
 
-// expected-error@+1 {{rank of mesh is expected to be a positive integer}}
-mesh.mesh @mesh0(shape = [])
+// expected-error@+1 {{rank of grid is expected to be a positive integer}}
+shard.grid @grid0(shape = [])
 
 // -----
 
-// expected-error@+1 {{custom op 'mesh.mesh' Failed parsing dimension list. Did you mean an empty list? It must be denoted by "[]".}}
-mesh.mesh @mesh0(shape = -1)
+// expected-error@+1 {{custom op 'shard.grid' Failed parsing dimension list. Did you mean an empty list? It must be denoted by "[]".}}
+shard.grid @grid0(shape = -1)
 
 // -----
 
-mesh.mesh @mesh0(shape = 2x4)
+shard.grid @grid0(shape = 2x4)
 
-func.func @mesh_axis_duplicated_different_subarray(
+func.func @grid_axis_duplicated_different_subarray(
     %arg0 : tensor<4x8xf32>) -> tensor<4x8xf32> {
-  // expected-error@+1 {{mesh axis duplicated}}
-  %s = mesh.sharding @mesh0 split_axes = [[0], [0]] : !mesh.sharding
-  %0 = mesh.shard %arg0 to %s : tensor<4x8xf32>
+  // expected-error@+1 {{grid axis duplicated}}
+  %s = shard.sharding @grid0 split_axes = [[0], [0]] : !shard.sharding
+  %0 = shard.shard %arg0 to %s : tensor<4x8xf32>
   return %0 : tensor<4x8xf32>
 }
 
 // -----
 
-mesh.mesh @mesh0(shape = 2x4)
+shard.grid @grid0(shape = 2x4)
 
-func.func @mesh_axis_duplicated_same_subarray(
+func.func @grid_axis_duplicated_same_subarray(
     %arg0 : tensor<4x8xf32>) -> tensor<4x8xf32> {
-  // expected-error@+1 {{mesh axis duplicated}}
-  %s = mesh.sharding @mesh0 split_axes = [[0, 0]] : !mesh.sharding
-  %0 = mesh.shard %arg0 to %s : tensor<4x8xf32>
+  // expected-error@+1 {{grid axis duplicated}}
+  %s = shard.sharding @grid0 split_axes = [[0, 0]] : !shard.sharding
+  %0 = shard.shard %arg0 to %s : tensor<4x8xf32>
   return %0 : tensor<4x8xf32>
 }
 
 // -----
 
-mesh.mesh @mesh0(shape = 2x4)
+shard.grid @grid0(shape = 2x4)
 
-func.func @mesh_axis_negtive_in_split_part(
+func.func @grid_axis_negtive_in_split_part(
     %arg0 : tensor<4x8xf32>) -> tensor<4x8xf32> {
-  // expected-error@+1 {{mesh axis is expected to be non-negative}}
-  %s = mesh.sharding @mesh0 split_axes = [[-1]] : !mesh.sharding
-  %0 = mesh.shard %arg0 to %s : tensor<4x8xf32>
+  // expected-error@+1 {{grid axis is expected to be non-negative}}
+  %s = shard.sharding @grid0 split_axes = [[-1]] : !shard.sharding
+  %0 = shard.shard %arg0 to %s : tensor<4x8xf32>
   return %0 : tensor<4x8xf32>
 }
 
 // -----
 
 func.func @sharding_attribute_invalid_nested_symbol(%arg0 : tensor<4x8xf32>) {
-  // expected-error@+1 {{custom op 'mesh.sharding' invalid kind of attribute specified}}
-  %s = mesh.sharding @a::@b split_axes = [[0]] : !mesh.sharding
-  %0 = mesh.shard %arg0 to %s : tensor<4x8xf32>
+  // expected-error@+1 {{custom op 'shard.sharding' invalid kind of attribute specified}}
+  %s = shard.sharding @a::@b split_axes = [[0]] : !shard.sharding
+  %0 = shard.shard %arg0 to %s : tensor<4x8xf32>
   return
 }
 
@@ -57,8 +57,8 @@ func.func @sharding_attribute_invalid_nested_symbol(%arg0 : tensor<4x8xf32>) {
 
 func.func @sharding_attribute_invalid_halo(%arg0 : tensor<4x8xf32>) {
   // expected-error@+1 {{halo sizes must be specified for all split axes}}
-  %s = mesh.sharding @mesh0 split_axes = [[0], [1]] halo_sizes = [1, 2] : !mesh.sharding
-  %0 = mesh.shard %arg0 to %s : tensor<4x8xf32>
+  %s = shard.sharding @grid0 split_axes = [[0], [1]] halo_sizes = [1, 2] : !shard.sharding
+  %0 = shard.shard %arg0 to %s : tensor<4x8xf32>
   return
 }
 
@@ -66,292 +66,292 @@ func.func @sharding_attribute_invalid_halo(%arg0 : tensor<4x8xf32>) {
 
 func.func @sharding_attribute_invalid_sizes(%arg0 : tensor<4x8xf32>) {
   // expected-error@+1 {{halo sizes and shard offsets are mutually exclusive}}
-  %s = mesh.sharding @mesh0 split_axes = [[0]] halo_sizes = [1, 2] sharded_dims_offsets = [0, 2, 2] : !mesh.sharding
-  %0 = mesh.shard %arg0 to %s : tensor<4x8xf32>
+  %s = shard.sharding @grid0 split_axes = [[0]] halo_sizes = [1, 2] sharded_dims_offsets = [0, 2, 2] : !shard.sharding
+  %0 = shard.shard %arg0 to %s : tensor<4x8xf32>
   return
 }
 
 // -----
 
-mesh.mesh @mesh_dyn(shape = ?x?)
-func.func @sharding_dyn_mesh_and_sizes(%arg0 : tensor<4x8xf32>) {
-  // expected-error@+1 {{sharded dims offsets are not allowed for devices meshes with dynamic shape}}
-  %s = mesh.sharding @mesh_dyn split_axes = [[0]] sharded_dims_offsets = [0, 2, 2] : !mesh.sharding
-  %0 = mesh.shard %arg0 to %s : tensor<4x8xf32>
+shard.grid @grid_dyn(shape = ?x?)
+func.func @sharding_dyn_grid_and_sizes(%arg0 : tensor<4x8xf32>) {
+  // expected-error@+1 {{sharded dims offsets are not allowed for device grids with dynamic shape}}
+  %s = shard.sharding @grid_dyn split_axes = [[0]] sharded_dims_offsets = [0, 2, 2] : !shard.sharding
+  %0 = shard.shard %arg0 to %s : tensor<4x8xf32>
   return
 }
 
 // -----
 
-mesh.mesh @mesh0(shape = 2x4)
+shard.grid @grid0(shape = 2x4)
 func.func @sharding_sizes_count(%arg0 : tensor<4x8xf32>) {
   // expected-error@+1 {{sharded dims offsets has wrong size}}
-  %s = mesh.sharding @mesh0 split_axes = [[0], [1]] sharded_dims_offsets = [0, 2, 4, 0, 2, 4, 6] : !mesh.sharding
-  %0 = mesh.shard %arg0 to %s : tensor<4x8xf32>
+  %s = shard.sharding @grid0 split_axes = [[0], [1]] sharded_dims_offsets = [0, 2, 4, 0, 2, 4, 6] : !shard.sharding
+  %0 = shard.shard %arg0 to %s : tensor<4x8xf32>
   return
 }
 
 // -----
 
-mesh.mesh @mesh0(shape = 4)
+shard.grid @grid0(shape = 4)
 func.func @sharding_sizes_decreasing(%arg0 : tensor<4x8xf32>) {
   // expected-error@+1 {{sharded dims offsets must be non-decreasing}}
-  %s = mesh.sharding @mesh0 split_axes = [[0]] sharded_dims_offsets = [0, 2, 3, 2] : !mesh.sharding
-  %0 = mesh.shard %arg0 to %s : tensor<4x8xf32>
+  %s = shard.sharding @grid0 split_axes = [[0]] sharded_dims_offsets = [0, 2, 3, 2] : !shard.sharding
+  %0 = shard.shard %arg0 to %s : tensor<4x8xf32>
   return
 }
 
 // -----
 
-mesh.mesh @mesh0(shape = 2x4)
+shard.grid @grid0(shape = 2x4)
 
-func.func @mesh_shape_mesh_axis_out_of_bounds() -> (index, index) {
-  // expected-error@+1 {{0-based mesh axis index 2 is out of bounds. The referenced mesh "mesh0" is of rank 2.}}
-  %0:2 = mesh.mesh_shape @mesh0 axes = [0, 2] : index, index
+func.func @grid_shape_grid_axis_out_of_bounds() -> (index, index) {
+  // expected-error@+1 {{0-based grid axis index 2 is out of bounds. The referenced grid "grid0" is of rank 2.}}
+  %0:2 = shard.grid_shape @grid0 axes = [0, 2] : index, index
   return %0#0, %0#1 : index, index
 }
 
 // -----
 
-mesh.mesh @mesh0(shape = 1x2x3)
+shard.grid @grid0(shape = 1x2x3)
 
-func.func @mesh_shape_duplicate_mesh_axis() -> (index, index, index) {
-  // expected-error@+1 {{Mesh axes contains duplicate elements.}}
-  %0:3 = mesh.mesh_shape @mesh0 axes = [0, 2, 0] : index, index, index
+func.func @grid_shape_duplicate_grid_axis() -> (index, index, index) {
+  // expected-error@+1 {{Grid axes contains duplicate elements.}}
+  %0:3 = shard.grid_shape @grid0 axes = [0, 2, 0] : index, index, index
   return %0#0, %0#1, %0#2 : index, index, index
 }
 
 // -----
 
-mesh.mesh @mesh0(shape = 2x4)
+shard.grid @grid0(shape = 2x4)
 
-func.func @mesh_shape_wrong_number_of_results() -> (index, index) {
+func.func @grid_shape_wrong_number_of_results() -> (index, index) {
   // expected-error@+1 {{Unexpected number of results 2. Expected 1.}}
-  %0:2 = mesh.mesh_shape @mesh0 axes = [0] : index, index
+  %0:2 = shard.grid_shape @grid0 axes = [0] : index, index
   return %0#0, %0#1 : index, index
 }
 
 // -----
 
-mesh.mesh @mesh0(shape = 1x2x3)
+shard.grid @grid0(shape = 1x2x3)
 
-func.func @mesh_shape_wrong_number_of_results_empty_mesh_axes() -> (index, index) {
+func.func @grid_shape_wrong_number_of_results_empty_grid_axes() -> (index, index) {
   // expected-error@+1 {{Unexpected number of results 2. Expected 3.}}
-  %0:2 = mesh.mesh_shape @mesh0 : index, index
+  %0:2 = shard.grid_shape @grid0 : index, index
   return %0#0, %0#1 : index, index
 }
 
 // -----
 
-func.func @mesh_shape_invalid_mesh_name() -> (index) {
-  // expected-error@+1 {{Undefined required mesh symbol "this_mesh_symbol_does_not_exist".}}
-  %0 = mesh.mesh_shape @this_mesh_symbol_does_not_exist : index
+func.func @grid_shape_invalid_grid_name() -> (index) {
+  // expected-error@+1 {{Undefined required grid symbol "this_grid_symbol_does_not_exist".}}
+  %0 = shard.grid_shape @this_grid_symbol_does_not_exist : index
   return %0#0 : index
 }
 
 // -----
 
-mesh.mesh @mesh0(shape = 2x4)
+shard.grid @grid0(shape = 2x4)
 
-func.func @process_multi_index_mesh_axis_out_of_bounds() -> (index, index) {
-  // expected-error@+1 {{0-based mesh axis index 2 is out of bounds. The referenced mesh "mesh0" is of rank 2.}}
-  %0:2 = mesh.process_multi_index on @mesh0 axes = [0, 2] : index, index
+func.func @process_multi_index_grid_axis_out_of_bounds() -> (index, index) {
+  // expected-error@+1 {{0-based grid axis index 2 is out of bounds. The referenced grid "grid0" is of rank 2.}}
+  %0:2 = shard.process_multi_index on @grid0 axes = [0, 2] : index, index
   return %0#0, %0#1 : index, index
 }
 
 // -----
 
-mesh.mesh @mesh0(shape = 1x2x3)
+shard.grid @grid0(shape = 1x2x3)
 
-func.func @process_multi_index_duplicate_mesh_axis() -> (index, index, index) {
-  // expected-error@+1 {{Mesh axes contains duplicate elements.}}
-  %0:3 = mesh.process_multi_index on @mesh0 axes = [0, 2, 0] : index, index, index
+func.func @process_multi_index_duplicate_grid_axis() -> (index, index, index) {
+  // expected-error@+1 {{Grid axes contains duplicate elements.}}
+  %0:3 = shard.process_multi_index on @grid0 axes = [0, 2, 0] : index, index, index
   return %0#0, %0#1, %0#2 : index, index, index
 }
 
 // -----
 
-mesh.mesh @mesh0(shape = 2x4)
+shard.grid @grid0(shape = 2x4)
 
 func.func @process_multi_index_wrong_number_of_results() -> (index, index) {
   // expected-error@+1 {{Unexpected number of results 2. Expected 1.}}
-  %0:2 = mesh.process_multi_index on @mesh0 axes = [0] : index, index
+  %0:2 = shard.process_multi_index on @grid0 axes = [0] : index, index
   return %0#0, %0#1 : index, index
 }
 
 // -----
 
-mesh.mesh @mesh0(shape = 1x2x3)
+shard.grid @grid0(shape = 1x2x3)
 
-func.func @process_multi_index_wrong_number_of_results_empty_mesh_axes() -> (index, index) {
+func.func @process_multi_index_wrong_number_of_results_empty_grid_axes() -> (index, index) {
   // expected-error@+1 {{Unexpected number of results 2. Expected 3.}}
-  %0:2 = mesh.process_multi_index on @mesh0 : index, index
+  %0:2 = shard.process_multi_index on @grid0 : index, index
   return %0#0, %0#1 : index, index
 }
 
 // -----
 
-func.func @process_multi_index_invalid_mesh_name() -> (index) {
-  // expected-error@+1 {{Undefined required mesh symbol "this_mesh_symbol_does_not_exist".}}
-  %0 = mesh.process_multi_index on @this_mesh_symbol_does_not_exist : index
+func.func @process_multi_index_invalid_grid_name() -> (index) {
+  // expected-error@+1 {{Undefined required grid symbol "this_grid_symbol_does_not_exist".}}
+  %0 = shard.process_multi_index on @this_grid_symbol_does_not_exist : index
   return %0 : index
 }
 
 // -----
 
-func.func @process_linear_index_invalid_mesh_name() -> (index) {
-  // expected-error@+1 {{Undefined required mesh symbol "this_mesh_symbol_does_not_exist".}}
-  %0 = mesh.process_linear_index on @this_mesh_symbol_does_not_exist : index
+func.func @process_linear_index_invalid_grid_name() -> (index) {
+  // expected-error@+1 {{Undefined required grid symbol "this_grid_symbol_does_not_exist".}}
+  %0 = shard.process_linear_index on @this_grid_symbol_does_not_exist : index
   return %0 : index
 }
 
 // -----
 
-func.func @all_reduce_invalid_mesh_symbol(
+func.func @all_reduce_invalid_grid_symbol(
     %arg0 : tensor<4xf32>) -> tensor<4xf64> {
-  // expected-error@+1 {{Undefined required mesh symbol "this_mesh_symbol_does_not_exist".}}
-  %0 = mesh.all_reduce %arg0 on @this_mesh_symbol_does_not_exist reduction = sum
+  // expected-error@+1 {{Undefined required grid symbol "this_grid_symbol_does_not_exist".}}
+  %0 = shard.all_reduce %arg0 on @this_grid_symbol_does_not_exist reduction = sum
     : tensor<4xf32> -> tensor<4xf64>
   return %0 : tensor<4xf64>
 }
 
 // -----
 
-mesh.mesh @mesh0(shape = 2x4)
+shard.grid @grid0(shape = 2x4)
 
-func.func @all_reduce_invalid_mesh_axis(
+func.func @all_reduce_invalid_grid_axis(
     %arg0 : tensor<4xf32>) -> tensor<4xf64> {
-  // expected-error@+1 {{0-based mesh axis index 2 is out of bounds. The referenced mesh "mesh0" is of rank 2.}}
-  %0 = mesh.all_reduce %arg0 on @mesh0 mesh_axes = [2] reduction = sum
+  // expected-error@+1 {{0-based grid axis index 2 is out of bounds. The referenced grid "grid0" is of rank 2.}}
+  %0 = shard.all_reduce %arg0 on @grid0 grid_axes = [2] reduction = sum
     : tensor<4xf32> -> tensor<4xf64>
   return %0 : tensor<4xf64>
 }
 
 // -----
 
-mesh.mesh @mesh0(shape = 2x4)
+shard.grid @grid0(shape = 2x4)
 
-func.func @all_reduce_duplicate_mesh_axis(
+func.func @all_reduce_duplicate_grid_axis(
     %arg0 : tensor<4xf32>) -> tensor<4xf64> {
-  // expected-error@+1 {{Mesh axes contains duplicate elements.}}
-  %0 = mesh.all_reduce %arg0 on @mesh0 mesh_axes = [0, 1, 0] reduction = sum
+  // expected-error@+1 {{Grid axes contains duplicate elements.}}
+  %0 = shard.all_reduce %arg0 on @grid0 grid_axes = [0, 1, 0] reduction = sum
     : tensor<4xf32> -> tensor<4xf64>
   return %0 : tensor<4xf64>
 }
 
 // -----
 
-mesh.mesh @mesh0(shape = 2x4)
+shard.grid @grid0(shape = 2x4)
 
 func.func @all_reduce_invalid_tensor_dimension_size(
     %arg0 : tensor<4xf32>) -> tensor<5xf64> {
-  // expected-error@+1 {{'mesh.all_reduce' op requires the same shape for all operands and results}}
-  %0 = mesh.all_reduce %arg0 on @mesh0 : tensor<4xf32> -> tensor<5xf64>
+  // expected-error@+1 {{'shard.all_reduce' op requires the same shape for all operands and results}}
+  %0 = shard.all_reduce %arg0 on @grid0 : tensor<4xf32> -> tensor<5xf64>
   return %0 : tensor<5xf64>
 }
 
 // -----
 
-func.func @all_gather_invalid_mesh_symbol(
+func.func @all_gather_invalid_grid_symbol(
     %arg0 : tensor<4xf32>) -> tensor<4xf32> {
-  // expected-error@+1 {{Undefined required mesh symbol "this_mesh_symbol_does_not_exist".}}
-  %0 = mesh.all_gather %arg0 on @this_mesh_symbol_does_not_exist gather_axis = 0
+  // expected-error@+1 {{Undefined required grid symbol "this_grid_symbol_does_not_exist".}}
+  %0 = shard.all_gather %arg0 on @this_grid_symbol_does_not_exist gather_axis = 0
     : tensor<4xf32> -> tensor<4xf32>
   return %0 : tensor<4xf32>
 }
 
 // -----
 
-mesh.mesh @mesh0(shape = 2x4)
+shard.grid @grid0(shape = 2x4)
 
-func.func @all_gather_invalid_mesh_axis(
+func.func @all_gather_invalid_grid_axis(
     %arg0 : tensor<4xf32>) -> tensor<4xf32> {
-  // expected-error@+1 {{0-based mesh axis index 2 is out of bounds. The referenced mesh "mesh0" is of rank 2.}}
-  %0 = mesh.all_gather %arg0 on @mesh0 mesh_axes = [2] gather_axis = 0
+  // expected-error@+1 {{0-based grid axis index 2 is out of bounds. The referenced grid "grid0" is of rank 2.}}
+  %0 = shard.all_gather %arg0 on @grid0 grid_axes = [2] gather_axis = 0
     : tensor<4xf32> -> tensor<4xf32>
   return %0 : tensor<4xf32>
 }
 
 // -----
 
-mesh.mesh @mesh0(shape = 2x4)
+shard.grid @grid0(shape = 2x4)
 
-func.func @all_reduce_duplicate_mesh_axis(
+func.func @all_reduce_duplicate_grid_axis(
     %arg0 : tensor<4xf32>) -> tensor<4xf32> {
-  // expected-error@+1 {{Mesh axes contains duplicate elements.}}
-  %0 = mesh.all_gather %arg0 on @mesh0 mesh_axes = [2, 2] gather_axis = 0
+  // expected-error@+1 {{Grid axes contains duplicate elements.}}
+  %0 = shard.all_gather %arg0 on @grid0 grid_axes = [2, 2] gather_axis = 0
     : tensor<4xf32> -> tensor<4xf32>
   return %0 : tensor<4xf32>
 }
 
 // -----
 
-mesh.mesh @mesh0(shape = 1)
+shard.grid @grid0(shape = 1)
 
 func.func @all_gather_invalid_non_gather_axis_dimension_size(
     %arg0 : tensor<3x4xf32>) -> tensor<3x5xf32> {
   // expected-error@+1 {{Dimension size mismatch for result axis 1. Expected 4, but got 5.}}
-  %0 = mesh.all_gather %arg0 on @mesh0 mesh_axes = [0] gather_axis = 0
+  %0 = shard.all_gather %arg0 on @grid0 grid_axes = [0] gather_axis = 0
     : tensor<3x4xf32> -> tensor<3x5xf32>
   return %0 : tensor<3x5xf32>
 }
 
 // -----
 
-mesh.mesh @mesh0(shape = 1x2)
+shard.grid @grid0(shape = 1x2)
 
 func.func @all_gather_invalid_gather_axis_dimension_size(
     %arg0 : tensor<3x4xf32>) -> tensor<3x5xf32> {
   // expected-error@+1 {{Dimension size mismatch for result axis 1. Expected 8, but got 5.}}
-  %0 = mesh.all_gather %arg0 on @mesh0 mesh_axes = [1] gather_axis = 1
+  %0 = shard.all_gather %arg0 on @grid0 grid_axes = [1] gather_axis = 1
     : tensor<3x4xf32> -> tensor<3x5xf32>
   return %0 : tensor<3x5xf32>
 }
 
 // -----
 
-mesh.mesh @mesh0(shape = 1)
+shard.grid @grid0(shape = 1)
 
 func.func @all_gather_invalid_gather_axis_dynamic_dimension(
     %arg0 : tensor<?xf32>) -> tensor<3xf32> {
   // expected-error@+1 {{Dimension size mismatch for result axis 0. Expected dynamic, but got 3.}}
-  %0 = mesh.all_gather %arg0 on @mesh0 gather_axis = 0
+  %0 = shard.all_gather %arg0 on @grid0 gather_axis = 0
     : tensor<?xf32> -> tensor<3xf32>
   return %0 : tensor<3xf32>
 }
 
 // -----
 
-mesh.mesh @mesh0(shape = 1)
+shard.grid @grid0(shape = 1)
 
 func.func @all_gather_invalid_gather_axis(
     %arg0 : tensor<3xf32>) -> tensor<3xf32> {
   // expected-error@+1 {{Gather axis 1 is out of bounds [0, 1).}}
-  %0 = mesh.all_gather %arg0 on @mesh0 mesh_axes = [0] gather_axis = 1
+  %0 = shard.all_gather %arg0 on @grid0 grid_axes = [0] gather_axis = 1
     : tensor<3xf32> -> tensor<3xf32>
   return %0 : tensor<3xf32>
 }
 
 // -----
 
-mesh.mesh @mesh0(shape = 1)
+shard.grid @grid0(shape = 1)
 
 func.func @all_gather_invalid_negative_gather_axis(
     %arg0 : tensor<3xf32>) -> tensor<3xf32> {
   // expected-error@+1 {{Gather axis -1 is out of bounds [0, 1).}}
-  %0 = mesh.all_gather %arg0 on @mesh0 mesh_axes = [0] gather_axis = -1
+  %0 = shard.all_gather %arg0 on @grid0 grid_axes = [0] gather_axis = -1
     : tensor<3xf32> -> tensor<3xf32>
   return %0 : tensor<3xf32>
 }
 
 // -----
 
-mesh.mesh @mesh0(shape = 3)
+shard.grid @grid0(shape = 3)
 
-func.func @all_slice_duplicate_mesh_axis(
+func.func @all_slice_duplicate_grid_axis(
     %arg0 : tensor<?xf32>) -> tensor<?xf32> {
-  // expected-error@+1 {{Mesh axes contains duplicate elements.}}
-  %0 = mesh.all_slice %arg0 on @mesh0 mesh_axes = [0, 0]
+  // expected-error@+1 {{Grid axes contains duplicate elements.}}
+  %0 = shard.all_slice %arg0 on @grid0 grid_axes = [0, 0]
     slice_axis = 0
     : tensor<?xf32> -> tensor<?xf32>
   return %0 : tensor<?xf32>
@@ -359,12 +359,12 @@ func.func @all_slice_duplicate_mesh_axis(
 
 // -----
 
-mesh.mesh @mesh0(shape = 3)
+shard.grid @grid0(shape = 3)
 
 func.func @all_slice_invalid_dynamic_dimension(
     %arg0 : tensor<?xf32>) -> tensor<2xf32> {
   // expected-error@+1 {{Dimension size mismatch for result axis 0. Expected dynamic, but got 2.}}
-  %0 = mesh.all_slice %arg0 on @mesh0
+  %0 = shard.all_slice %arg0 on @grid0
     slice_axis = 0
     : tensor<?xf32> -> tensor<2xf32>
   return %0 : tensor<2xf32>
@@ -372,12 +372,12 @@ func.func @all_slice_invalid_dynamic_dimension(
 
 // -----
 
-mesh.mesh @mesh0(shape = 3)
+shard.grid @grid0(shape = 3)
 
 func.func @all_slice_invalid_static_dimension_size(
     %arg0 : tensor<3xf32>) -> tensor<2xf32> {
   // expected-error@+1 {{Dimension size mismatch for result axis 0. Expected 1, but got 2.}}
-  %0 = mesh.all_slice %arg0 on @mesh0 mesh_axes = [0]
+  %0 = shard.all_slice %arg0 on @grid0 grid_axes = [0]
     slice_axis = 0
     : tensor<3xf32> -> tensor<2xf32>
   return %0 : tensor<2xf32>
@@ -385,12 +385,12 @@ func.func @all_slice_invalid_static_dimension_size(
 
 // -----
 
-mesh.mesh @mesh0(shape = 3)
+shard.grid @grid0(shape = 3)
 
 func.func @all_slice_invalid_operand_static_dimension_size(
     %arg0 : tensor<4xf32>) -> tensor<?xf32> {
   // expected-error@+1 {{Operand dimension size 4 is not divisible by collective device group size 3 for tensor axis 0.}}
-  %0 = mesh.all_slice %arg0 on @mesh0 mesh_axes = [0]
+  %0 = shard.all_slice %arg0 on @grid0 grid_axes = [0]
     slice_axis = 0
     : tensor<4xf32> -> tensor<?xf32>
   return %0 : tensor<?xf32>
@@ -398,10 +398,10 @@ func.func @all_slice_invalid_operand_static_dimension_size(
 
 // -----
 
-func.func @all_to_all_invalid_mesh_symbol(
+func.func @all_to_all_invalid_grid_symbol(
     %arg0 : tensor<3x6xi8>) -> tensor<3x6xi8> {
-  // expected-error@+1 {{Undefined required mesh symbol "this_mesh_symbol_does_not_exist".}}
-  %0 = mesh.all_to_all %arg0 on @this_mesh_symbol_does_not_exist
+  // expected-error@+1 {{Undefined required grid symbol "this_grid_symbol_does_not_exist".}}
+  %0 = shard.all_to_all %arg0 on @this_grid_symbol_does_not_exist
     split_axis = 1 concat_axis = 0
     : tensor<3x6xi8> -> tensor<3x6xi8>
   return %0 : tensor<3x6xi8>
@@ -409,12 +409,12 @@ func.func @all_to_all_invalid_mesh_symbol(
 
 // -----
 
-mesh.mesh @mesh0(shape = 1)
+shard.grid @grid0(shape = 1)
 
-func.func @all_to_all_duplicate_mesh_axis(
+func.func @all_to_all_duplicate_grid_axis(
     %arg0 : tensor<3x6xi8>) -> tensor<3x6xi8> {
-  // expected-error@+1 {{Mesh axes contains duplicate elements.}}
-  %0 = mesh.all_to_all %arg0 on @mesh0 mesh_axes = [0, 0]
+  // expected-error@+1 {{Grid axes contains duplicate elements.}}
+  %0 = shard.all_to_all %arg0 on @grid0 grid_axes = [0, 0]
     split_axis = 0 concat_axis = 0
     : tensor<3x6xi8> -> tensor<3x6xi8>
   return %0 : tensor<3x6xi8>
@@ -422,12 +422,12 @@ func.func @all_to_all_duplicate_mesh_axis(
 
 // -----
 
-mesh.mesh @mesh0(shape = ?x1)
+shard.grid @grid0(shape = ?x1)
 
 func.func @all_to_all_invalid_non_dynamic_result_dimension_induced_by_dynamic_device_group(
     %arg0 : tensor<3x6xi8>) -> tensor<3x6xi8> {
   // expected-error@+1 {{Dimension size mismatch for result axis 1. Expected dynamic, but got 6.}}
-  %0 = mesh.all_to_all %arg0 on @mesh0 mesh_axes = [0]
+  %0 = shard.all_to_all %arg0 on @grid0 grid_axes = [0]
     split_axis = 0 concat_axis = 1
     : tensor<3x6xi8> -> tensor<3x6xi8>
   return %0 : tensor<3x6xi8>
@@ -435,12 +435,12 @@ func.func @all_to_all_invalid_non_dynamic_result_dimension_induced_by_dynamic_de
 
 // -----
 
-mesh.mesh @mesh0(shape = 1x1)
+shard.grid @grid0(shape = 1x1)
 
 func.func @all_to_all_invalid_non_dynamic_result_split_dimension_induced_by_dynamic_operand_dimension(
     %arg0 : tensor<?x6xi8>) -> tensor<3x?xi8> {
   // expected-error@+1 {{Dimension size mismatch for result axis 0. Expected dynamic, but got 3.}}
-  %0 = mesh.all_to_all %arg0 on @mesh0 mesh_axes = [1]
+  %0 = shard.all_to_all %arg0 on @grid0 grid_axes = [1]
     split_axis = 0 concat_axis = 1
     : tensor<?x6xi8> -> tensor<3x?xi8>
   return %0 : tensor<3x?xi8>
@@ -448,12 +448,12 @@ func.func @all_to_all_invalid_non_dynamic_result_split_dimension_induced_by_dyna
 
 // -----
 
-mesh.mesh @mesh0(shape = 1x1)
+shard.grid @grid0(shape = 1x1)
 
 func.func @all_to_all_invalid_non_dynamic_result_concat_dimension_induced_by_dynamic_operand_dimension(
     %arg0 : tensor<3x?xi8>) -> tensor<?x3xi8> {
   // expected-error@+1 {{Dimension size mismatch for result axis 1. Expected dynamic, but got 3.}}
-  %0 = mesh.all_to_all %arg0 on @mesh0 mesh_axes = [1]
+  %0 = shard.all_to_all %arg0 on @grid0 grid_axes = [1]
     split_axis = 0 concat_axis = 1
     : tensor<3x?xi8> -> tensor<?x3xi8>
   return %0 : tensor<?x3xi8>
@@ -461,12 +461,12 @@ func.func @all_to_all_invalid_non_dynamic_result_concat_dimension_induced_by_dyn
 
 // -----
 
-mesh.mesh @mesh0(shape = 3)
+shard.grid @grid0(shape = 3)
 
 func.func @all_to_all_invalid_non_dynamic_result_concat_dimension_size(
     %arg0 : tensor<3x2xi8>) -> tensor<1x7xi8> {
   // expected-error@+1 {{Dimension size mismatch for result axis 1. Expected 6, but got 7.}}
-  %0 = mesh.all_to_all %arg0  on @mesh0 mesh_axes = [0]
+  %0 = shard.all_to_all %arg0  on @grid0 grid_axes = [0]
     split_axis = 0 concat_axis = 1
     : tensor<3x2xi8> -> tensor<1x7xi8>
   return %0 : tensor<1x7xi8>
@@ -474,12 +474,12 @@ func.func @all_to_all_invalid_non_dynamic_result_concat_dimension_size(
 
 // -----
 
-mesh.mesh @mesh0(shape = 3)
+shard.grid @grid0(shape = 3)
 
 func.func @all_to_all_invalid_non_dynamic_result_split_dimension_size(
     %arg0 : tensor<3x2xi8>) -> tensor<2x6xi8> {
   // expected-error@+1 {{Dimension size mismatch for result axis 0. Expected 1, but got 2.}}
-  %0 = mesh.all_to_all %arg0 on @mesh0 mesh_axes = [0]
+  %0 = shard.all_to_all %arg0 on @grid0 grid_axes = [0]
     split_axis = 0 concat_axis = 1
     : tensor<3x2xi8> -> tensor<2x6xi8>
   return %0 : tensor<2x6xi8>
@@ -487,12 +487,12 @@ func.func @all_to_all_invalid_non_dynamic_result_split_dimension_size(
 
 // -----
 
-mesh.mesh @mesh0(shape = 3x?)
+shard.grid @grid0(shape = 3x?)
 
 func.func @broadcast_root_dimension_out_of_bounds(
     %arg0 : tensor<2xi8>) -> tensor<2xi8> {
   // expected-error@+1 {{Out of bounds coordinate 0 for in-group device "root". Got 3, but expected value in the range [0, 2].}}
-  %0 = mesh.broadcast %arg0 on @mesh0 mesh_axes = [0]
+  %0 = shard.broadcast %arg0 on @grid0 grid_axes = [0]
     root = [3]
     : (tensor<2xi8>) -> tensor<2xi8>
   return %0 : tensor<2xi8>
@@ -500,12 +500,12 @@ func.func @broadcast_root_dimension_out_of_bounds(
 
 // -----
 
-mesh.mesh @mesh0(shape = 3x?)
+shard.grid @grid0(shape = 3x?)
 
 func.func @broadcast_root_wrong_number_dimensions(
     %arg0 : tensor<2xi8>) -> tensor<2xi8> {
   // expected-error@+1 {{In-group device "root" has unexpected multi-index size 2. Expected 1.}}
-  %0 = mesh.broadcast %arg0 on @mesh0 mesh_axes = [0]
+  %0 = shard.broadcast %arg0 on @grid0 grid_axes = [0]
     root = [2, 2]
     : (tensor<2xi8>) -> tensor<2xi8>
   return %0 : tensor<2xi8>
@@ -513,12 +513,12 @@ func.func @broadcast_root_wrong_number_dimensions(
 
 // -----
 
-mesh.mesh @mesh0(shape = 3x?)
+shard.grid @grid0(shape = 3x?)
 
 func.func @broadcast_different_input_and_result_type(
     %arg0 : tensor<2xi8>) -> tensor<2xi16> {
-  // expected-error@+1 {{'mesh.broadcast' op failed to verify that all of {input, result} have same element type}}
-  %0 = mesh.broadcast %arg0 on @mesh0 mesh_axes = [0]
+  // expected-error@+1 {{'shard.broadcast' op failed to verify that all of {input, result} have same element type}}
+  %0 = shard.broadcast %arg0 on @grid0 grid_axes = [0]
     root = [2]
     : (tensor<2xi8>) -> tensor<2xi16>
   return %0 : tensor<2xi16>
@@ -526,84 +526,84 @@ func.func @broadcast_different_input_and_result_type(
 
 // -----
 
-mesh.mesh @mesh0(shape = 1)
+shard.grid @grid0(shape = 1)
 
 func.func @gather_wrong_return_element_type(
     %arg0 : tensor<1xf32>) -> tensor<1xi8> {
-  // expected-error@+1 {{'mesh.gather' op failed to verify that all of {input, result} have same element type}}
-  %0 = mesh.gather %arg0 on @mesh0 mesh_axes = [0] gather_axis = 0 root = [0]
+  // expected-error@+1 {{'shard.gather' op failed to verify that all of {input, result} have same element type}}
+  %0 = shard.gather %arg0 on @grid0 grid_axes = [0] gather_axis = 0 root = [0]
     : (tensor<1xf32>) -> tensor<1xi8>
   return %0 : tensor<1xi8>
 }
 
 // -----
 
-mesh.mesh @mesh0(shape = 1)
+shard.grid @grid0(shape = 1)
 
 func.func @gather_invalid_non_gather_axis_dimension_size(
     %arg0 : tensor<3x4xf32>) -> tensor<3x5xf32> {
   // expected-error@+1 {{Dimension size mismatch for result axis 1. Expected 4, but got 5.}}
-  %0 = mesh.gather %arg0 on @mesh0 mesh_axes = [0] gather_axis = 0 root = [0]
+  %0 = shard.gather %arg0 on @grid0 grid_axes = [0] gather_axis = 0 root = [0]
     : (tensor<3x4xf32>) -> tensor<3x5xf32>
   return %0 : tensor<3x5xf32>
 }
 
 // -----
 
-mesh.mesh @mesh0(shape = 1x2)
+shard.grid @grid0(shape = 1x2)
 
 func.func @gather_invalid_gather_axis_dimension_size(
     %arg0 : tensor<3x4xf32>) -> tensor<3x5xf32> {
   // expected-error@+1 {{Dimension size mismatch for result axis 1. Expected 8, but got 5.}}
-  %0 = mesh.gather %arg0 on @mesh0 mesh_axes = [1] gather_axis = 1 root = [0]
+  %0 = shard.gather %arg0 on @grid0 grid_axes = [1] gather_axis = 1 root = [0]
     : (tensor<3x4xf32>) -> tensor<3x5xf32>
   return %0 : tensor<3x5xf32>
 }
 
 // -----
 
-mesh.mesh @mesh0(shape = 1)
+shard.grid @grid0(shape = 1)
 
 func.func @gather_invalid_gather_axis_dynamic_dimension(
     %arg0 : tensor<?xf32>) -> tensor<3xf32> {
   // expected-error@+1 {{Dimension size mismatch for result axis 0. Expected dynamic, but got 3.}}
-  %0 = mesh.gather %arg0 on @mesh0 gather_axis = 0 root = []
+  %0 = shard.gather %arg0 on @grid0 gather_axis = 0 root = []
     : (tensor<?xf32>) -> tensor<3xf32>
   return %0 : tensor<3xf32>
 }
 
 // -----
 
-mesh.mesh @mesh0(shape = 1)
+shard.grid @grid0(shape = 1)
 
 func.func @gather_invalid_gather_axis(
     %arg0 : tensor<3xf32>) -> tensor<3xf32> {
   // expected-error@+1 {{Gather axis 1 is out of bounds [0, 1).}}
-  %0 = mesh.gather %arg0 on @mesh0 mesh_axes = [0] gather_axis = 1 root = [0]
+  %0 = shard.gather %arg0 on @grid0 grid_axes = [0] gather_axis = 1 root = [0]
     : (tensor<3xf32>) -> tensor<3xf32>
   return %0 : tensor<3xf32>
 }
 
 // -----
 
-mesh.mesh @mesh0(shape = 1)
+shard.grid @grid0(shape = 1)
 
 func.func @gather_invalid_negative_gather_axis(
     %arg0 : tensor<3xf32>) -> tensor<3xf32> {
   // expected-error@+1 {{Gather axis -1 is out of bounds [0, 1).}}
-  %0 = mesh.gather %arg0 on @mesh0 mesh_axes = [0] gather_axis = -1 root = [0]
+  %0 = shard.gather %arg0 on @grid0 grid_axes = [0] gather_axis = -1 root = [0]
     : (tensor<3xf32>) -> tensor<3xf32>
   return %0 : tensor<3xf32>
 }
 
 // -----
 
-mesh.mesh @mesh0(shape = 3x?)
+shard.grid @grid0(shape = 3x?)
 
 func.func @gather_root_dimension_out_of_bounds(
     %arg0 : tensor<2xi8>) -> tensor<6xi8> {
   // expected-error@+1 {{Out of bounds coordinate 0 for in-group device "root". Got 3, but expected value in the range [0, 2].}}
-  %0 = mesh.gather %arg0 on @mesh0 mesh_axes = [0] gather_axis = 0
+  %0 = shard.gather %arg0 on @grid0 grid_axes = [0] gather_axis = 0
     root = [3]
     : (tensor<2xi8>) -> tensor<6xi8>
   return %0 : tensor<6xi8>
@@ -611,12 +611,12 @@ func.func @gather_root_dimension_out_of_bounds(
 
 // -----
 
-mesh.mesh @mesh0(shape = 3x?)
+shard.grid @grid0(shape = 3x?)
 
 func.func @gather_root_wrong_number_dimensions(
     %arg0 : tensor<2xi8>) -> tensor<2xi8> {
   // expected-error@+1 {{In-group device "root" has unexpected multi-index size 2. Expected 1.}}
-  %0 = mesh.gather %arg0 on @mesh0 mesh_axes = [0] gather_axis = 0
+  %0 = shard.gather %arg0 on @grid0 grid_axes = [0] gather_axis = 0
     root = [2, 2]
     : (tensor<2xi8>) -> tensor<2xi8>
   return %0 : tensor<2xi8>
@@ -624,12 +624,12 @@ func.func @gather_root_wrong_number_dimensions(
 
 // -----
 
-mesh.mesh @mesh0(shape = 3x?)
+shard.grid @grid0(shape = 3x?)
 
 func.func @receive_source_dimension_out_of_bounds(
     %arg0 : tensor<2xi8>) -> tensor<2xi8> {
   // expected-error@+1 {{Out of bounds coordinate 0 for in-group device "source". Got 3, but expected value in the range [0, 2].}}
-  %0 = mesh.recv %arg0 on @mesh0 mesh_axes = [0]
+  %0 = shard.recv %arg0 on @grid0 grid_axes = [0]
     source = [3]
     : (tensor<2xi8>) -> tensor<2xi8>
   return %0 : tensor<2xi8>
@@ -637,12 +637,12 @@ func.func @receive_source_dimension_out_of_bounds(
 
 // -----
 
-mesh.mesh @mesh0(shape = 3x?)
+shard.grid @grid0(shape = 3x?)
 
 func.func @receive_source_wrong_number_dimensions(
     %arg0 : tensor<2xi8>) -> tensor<2xi8> {
   // expected-error@+1 {{In-group device "source" has unexpected multi-index size 2. Expected 1.}}
-  %0 = mesh.recv %arg0 on @mesh0 mesh_axes = [0]
+  %0 = shard.recv %arg0 on @grid0 grid_axes = [0]
     source = [2, 2]
     : (tensor<2xi8>) -> tensor<2xi8>
   return %0 : tensor<2xi8>
@@ -650,12 +650,12 @@ func.func @receive_source_wrong_number_dimensions(
 
 // -----
 
-mesh.mesh @mesh0(shape = 3x?)
+shard.grid @grid0(shape = 3x?)
 
 func.func @receive_different_input_and_result_type(
     %arg0 : tensor<2xi8>) -> tensor<2xi16> {
-  // expected-error@+1 {{'mesh.recv' op failed to verify that all of {input, result} have same element type}}
-  %0 = mesh.recv %arg0 on @mesh0 mesh_axes = [0]
+  // expected-error@+1 {{'shard.recv' op failed to verify that all of {input, result} have same element type}}
+  %0 = shard.recv %arg0 on @grid0 grid_axes = [0]
     source = [2]
     : (tensor<2xi8>) -> tensor<2xi16>
   return %0 : tensor<2xi16>
@@ -663,12 +663,12 @@ func.func @receive_different_input_and_result_type(
 
 // -----
 
-mesh.mesh @mesh0(shape = 3x?)
+shard.grid @grid0(shape = 3x?)
 
 func.func @reduce_root_dimension_out_of_bounds(
     %arg0 : tensor<2xi8>) -> tensor<2xi8> {
   // expected-error@+1 {{Out of bounds coordinate 0 for in-group device "root". Got 3, but expected value in the range [0, 2].}}
-  %0 = mesh.reduce %arg0 on @mesh0 mesh_axes = [0]
+  %0 = shard.reduce %arg0 on @grid0 grid_axes = [0]
     root = [3]
     : (tensor<2xi8>) -> tensor<2xi8>
   return %0 : tensor<2xi8>
@@ -676,12 +676,12 @@ func.func @reduce_root_dimension_out_of_bounds(
 
 // -----
 
-mesh.mesh @mesh0(shape = 3x?)
+shard.grid @grid0(shape = 3x?)
 
 func.func @reduce_root_wrong_number_dimensions(
     %arg0 : tensor<2xi8>) -> tensor<2xi8> {
   // expected-error@+1 {{In-group device "root" has unexpected multi-index size 2. Expected 1.}}
-  %0 = mesh.reduce %arg0 on @mesh0 mesh_axes = [0]
+  %0 = shard.reduce %arg0 on @grid0 grid_axes = [0]
     root = [2, 2]
     : (tensor<2xi8>) -> tensor<2xi8>
   return %0 : tensor<2xi8>
@@ -689,12 +689,12 @@ func.func @reduce_root_wrong_number_dimensions(
 
 // -----
 
-mesh.mesh @mesh0(shape = 3x?)
+shard.grid @grid0(shape = 3x?)
 
 func.func @reduce_different_input_and_result_shape(
     %arg0 : tensor<2xi8>) -> tensor<3xi16> {
-  // expected-error@+1 {{'mesh.reduce' op failed to verify that all of {input, result} have same shape}}
-  %0 = mesh.reduce %arg0 on @mesh0 mesh_axes = [0]
+  // expected-error@+1 {{'shard.reduce' op failed to verify that all of {input, result} have same shape}}
+  %0 = shard.reduce %arg0 on @grid0 grid_axes = [0]
     root = [2]
     : (tensor<2xi8>) -> tensor<3xi16>
   return %0 : tensor<3xi16>
@@ -702,60 +702,60 @@ func.func @reduce_different_input_and_result_shape(
 
 // -----
 
-mesh.mesh @mesh0(shape = 3)
+shard.grid @grid0(shape = 3)
 
-func.func @reduce_scatter_duplicate_mesh_axis(
+func.func @reduce_scatter_duplicate_grid_axis(
     %arg0 : tensor<?xf32>) -> tensor<?xf64> {
-  // expected-error@+1 {{Mesh axes contains duplicate elements.}}
-  %0 = mesh.reduce_scatter %arg0 on @mesh0 mesh_axes = [0, 0] scatter_axis = 0
+  // expected-error@+1 {{Grid axes contains duplicate elements.}}
+  %0 = shard.reduce_scatter %arg0 on @grid0 grid_axes = [0, 0] scatter_axis = 0
     : tensor<?xf32> -> tensor<?xf64>
   return %0 : tensor<?xf64>
 }
 
 // -----
 
-mesh.mesh @mesh0(shape = 3)
+shard.grid @grid0(shape = 3)
 
 func.func @reduce_scatter_invalid_dynamic_dimension(
     %arg0 : tensor<?xf32>) -> tensor<2xf64> {
   // expected-error@+1 {{Dimension size mismatch for result axis 0. Expected dynamic, but got 2.}}
-  %0 = mesh.reduce_scatter %arg0 on @mesh0 scatter_axis = 0
+  %0 = shard.reduce_scatter %arg0 on @grid0 scatter_axis = 0
     : tensor<?xf32> -> tensor<2xf64>
   return %0 : tensor<2xf64>
 }
 
 // -----
 
-mesh.mesh @mesh0(shape = 3)
+shard.grid @grid0(shape = 3)
 
 func.func @reduce_scatter_invalid_static_dimension_size(
     %arg0 : tensor<3xf32>) -> tensor<2xf64> {
   // expected-error@+1 {{Dimension size mismatch for result axis 0. Expected 1, but got 2.}}
-  %0 = mesh.reduce_scatter %arg0 on @mesh0 mesh_axes = [0] scatter_axis = 0
+  %0 = shard.reduce_scatter %arg0 on @grid0 grid_axes = [0] scatter_axis = 0
     : tensor<3xf32> -> tensor<2xf64>
   return %0 : tensor<2xf64>
 }
 
 // -----
 
-mesh.mesh @mesh0(shape = 3)
+shard.grid @grid0(shape = 3)
 
 func.func @reduce_scatter_invalid_operand_static_dimension_size(
     %arg0 : tensor<4xf32>) -> tensor<?xf64> {
   // expected-error@+1 {{Operand dimension size 4 is not divisible by collective device group size 3 for tensor axis 0.}}
-  %0 = mesh.reduce_scatter %arg0 on @mesh0 mesh_axes = [0] scatter_axis = 0
+  %0 = shard.reduce_scatter %arg0 on @grid0 grid_axes = [0] scatter_axis = 0
     : tensor<4xf32> -> tensor<?xf64>
   return %0 : tensor<?xf64>
 }
 
 // -----
 
-mesh.mesh @mesh0(shape = 3)
+shard.grid @grid0(shape = 3)
 
-func.func @scatter_duplicate_mesh_axis(
+func.func @scatter_duplicate_grid_axis(
     %arg0 : tensor<?xf32>) -> tensor<?xf32> {
-  // expected-error@+1 {{Mesh axes contains duplicate elements.}}
-  %0 = mesh.scatter %arg0 on @mesh0 mesh_axes = [0, 0]
+  // expected-error@+1 {{Grid axes contains duplicate elements.}}
+  %0 = shard.scatter %arg0 on @grid0 grid_axes = [0, 0]
     scatter_axis = 0 root = [0, 0]
     : (tensor<?xf32>) -> tensor<?xf32>
   return %0 : tensor<?xf32>
@@ -763,12 +763,12 @@ func.func @scatter_duplicate_mesh_axis(
 
 // -----
 
-mesh.mesh @mesh0(shape = 3)
+shard.grid @grid0(shape = 3)
 
 func.func @scatter_invalid_dynamic_dimension(
     %arg0 : tensor<?xf32>) -> tensor<2xf32> {
   // expected-error@+1 {{Dimension size mismatch for result axis 0. Expected dynamic, but got 2.}}
-  %0 = mesh.scatter %arg0 on @mesh0
+  %0 = shard.scatter %arg0 on @grid0
     scatter_axis = 0 root = []
     : (tensor<?xf32>) -> tensor<2xf32>
   return %0 : tensor<2xf32>
@@ -776,12 +776,12 @@ func.func @scatter_invalid_dynamic_dimension(
 
 // -----
 
-mesh.mesh @mesh0(shape = 3)
+shard.grid @grid0(shape = 3)
 
 func.func @scatter_invalid_static_dimension_size(
     %arg0 : tensor<3xf32>) -> tensor<2xf32> {
   // expected-error@+1 {{Dimension size mismatch for result axis 0. Expected 1, but got 2.}}
-  %0 = mesh.scatter %arg0 on @mesh0 mesh_axes = [0]
+  %0 = shard.scatter %arg0 on @grid0 grid_axes = [0]
     scatter_axis = 0 root = [1]
     : (tensor<3xf32>) -> tensor<2xf32>
   return %0 : tensor<2xf32>
@@ -789,12 +789,12 @@ func.func @scatter_invalid_static_dimension_size(
 
 // -----
 
-mesh.mesh @mesh0(shape = 3)
+shard.grid @grid0(shape = 3)
 
 func.func @scatter_invalid_operand_static_dimension_size(
     %arg0 : tensor<4xf32>) -> tensor<?xf32> {
   // expected-error@+1 {{Operand dimension size 4 is not divisible by collective device group size 3 for tensor axis 0.}}
-  %0 = mesh.scatter %arg0 on @mesh0 mesh_axes = [0]
+  %0 = shard.scatter %arg0 on @grid0 grid_axes = [0]
     scatter_axis = 0 root = [1]
     : (tensor<4xf32>) -> tensor<?xf32>
   return %0 : tensor<?xf32>
@@ -802,12 +802,12 @@ func.func @scatter_invalid_operand_static_dimension_size(
 
 // -----
 
-mesh.mesh @mesh0(shape = 3x?)
+shard.grid @grid0(shape = 3x?)
 
 func.func @scatter_root_dimension_out_of_bounds(
     %arg0 : tensor<3xi8>) -> tensor<1xi8> {
   // expected-error@+1 {{Out of bounds coordinate 0 for in-group device "root". Got 3, but expected value in the range [0, 2].}}
-  %0 = mesh.scatter %arg0 on @mesh0 mesh_axes = [0]
+  %0 = shard.scatter %arg0 on @grid0 grid_axes = [0]
     scatter_axis = 0 root = [3]
     : (tensor<3xi8>) -> tensor<1xi8>
   return %0 : tensor<1xi8>
@@ -815,12 +815,12 @@ func.func @scatter_root_dimension_out_of_bounds(
 
 // -----
 
-mesh.mesh @mesh0(shape = 3x?)
+shard.grid @grid0(shape = 3x?)
 
 func.func @scatter_root_wrong_number_dimensions(
     %arg0 : tensor<3xi8>) -> tensor<1xi8> {
   // expected-error@+1 {{In-group device "root" has unexpected multi-index size 2. Expected 1.}}
-  %0 = mesh.scatter %arg0 on @mesh0 mesh_axes = [0]
+  %0 = shard.scatter %arg0 on @grid0 grid_axes = [0]
     scatter_axis = 0 root = [2, 2]
     : (tensor<3xi8>) -> tensor<1xi8>
   return %0 : tensor<1xi8>
@@ -828,12 +828,12 @@ func.func @scatter_root_wrong_number_dimensions(
 
 // -----
 
-mesh.mesh @mesh0(shape = 3x?)
+shard.grid @grid0(shape = 3x?)
 
 func.func @send_destination_dimension_out_of_bounds(
     %arg0 : tensor<2xi8>) -> tensor<2xi8> {
   // expected-error@+1 {{Out of bounds coordinate 0 for in-group device "destination". Got 3, but expected value in the range [0, 2].}}
-  %0 = mesh.send %arg0 on @mesh0 mesh_axes = [0]
+  %0 = shard.send %arg0 on @grid0 grid_axes = [0]
     destination = [3]
     : (tensor<2xi8>) -> tensor<2xi8>
   return %0 : tensor<2xi8>
@@ -841,12 +841,12 @@ func.func @send_destination_dimension_out_of_bounds(
 
 // -----
 
-mesh.mesh @mesh0(shape = 3x?)
+shard.grid @grid0(shape = 3x?)
 
 func.func @send_destination_wrong_number_dimensions(
     %arg0 : tensor<2xi8>) -> tensor<2xi8> {
   // expected-error@+1 {{In-group device "destination" has unexpected multi-index size 2. Expected 1.}}
-  %0 = mesh.send %arg0 on @mesh0 mesh_axes = [0]
+  %0 = shard.send %arg0 on @grid0 grid_axes = [0]
     destination = [2, 2]
     : (tensor<2xi8>) -> tensor<2xi8>
   return %0 : tensor<2xi8>
@@ -854,12 +854,12 @@ func.func @send_destination_wrong_number_dimensions(
 
 // -----
 
-mesh.mesh @mesh0(shape = 3x?)
+shard.grid @grid0(shape = 3x?)
 
 func.func @send_different_input_and_result_type(
     %arg0 : tensor<2xi8>) -> tensor<2xi16> {
-  // expected-error@+1 {{'mesh.send' op failed to verify that all of {input, result} have same element type}}
-  %0 = mesh.send %arg0 on @mesh0 mesh_axes = [0]
+  // expected-error@+1 {{'shard.send' op failed to verify that all of {input, result} have same element type}}
+  %0 = shard.send %arg0 on @grid0 grid_axes = [0]
     destination = [2]
     : (tensor<2xi8>) -> tensor<2xi16>
   return %0 : tensor<2xi16>
@@ -867,10 +867,10 @@ func.func @send_different_input_and_result_type(
 
 // -----
 
-func.func @shift_invalid_mesh_symbol(
+func.func @shift_invalid_grid_symbol(
     %arg0 : tensor<4xi8>) -> tensor<4xi8> {
-  // expected-error@+1 {{Undefined required mesh symbol "this_mesh_symbol_does_not_exist".}}
-  %0 = mesh.shift %arg0 on @this_mesh_symbol_does_not_exist
+  // expected-error@+1 {{Undefined required grid symbol "this_grid_symbol_does_not_exist".}}
+  %0 = shard.shift %arg0 on @this_grid_symbol_does_not_exist
     shift_axis = 0 offset = -2
     : tensor<4xi8> -> tensor<4xi8>
   return %0 : tensor<4xi8>
@@ -878,12 +878,12 @@ func.func @shift_invalid_mesh_symbol(
 
 // -----
 
-mesh.mesh @mesh0(shape = 2x4)
+shard.grid @grid0(shape = 2x4)
 
-func.func @shift_invalid_mesh_axis(
+func.func @shift_invalid_grid_axis(
     %arg0 : tensor<4xi8>) -> tensor<4xi8> {
-  // expected-error@+1 {{0-based mesh axis index 2 is out of bounds. The referenced mesh "mesh0" is of rank 2.}}
-  %0 = mesh.shift %arg0 on @mesh0 mesh_axes = [2]
+  // expected-error@+1 {{0-based grid axis index 2 is out of bounds. The referenced grid "grid0" is of rank 2.}}
+  %0 = shard.shift %arg0 on @grid0 grid_axes = [2]
         shift_axis = 2 offset = -2
     : tensor<4xi8> -> tensor<4xi8>
   return %0 : tensor<4xi8>
@@ -891,12 +891,12 @@ func.func @shift_invalid_mesh_axis(
 
 // -----
 
-mesh.mesh @mesh0(shape = 2x4)
+shard.grid @grid0(shape = 2x4)
 
-func.func @shift_duplicate_mesh_axis(
+func.func @shift_duplicate_grid_axis(
     %arg0 : tensor<4xi8>) -> tensor<4xi8> {
-  // expected-error@+1 {{Mesh axes contains duplicate elements.}}
-  %0 = mesh.shift %arg0 on @mesh0 mesh_axes = [0, 1, 0]
+  // expected-error@+1 {{Grid axes contains duplicate elements.}}
+  %0 = shard.shift %arg0 on @grid0 grid_axes = [0, 1, 0]
     shift_axis = 0 offset = -2
     : tensor<4xi8> -> tensor<4xi8>
   return %0 : tensor<4xi8>
@@ -904,12 +904,12 @@ func.func @shift_duplicate_mesh_axis(
 
 // -----
 
-mesh.mesh @mesh0(shape = 2x4)
+shard.grid @grid0(shape = 2x4)
 
 func.func @shift_invalid_tensor_dimension_size(
     %arg0 : tensor<4xi8>) -> tensor<5xi8> {
-  // expected-error@+1 {{'mesh.shift' op requires the same shape for all operands and results}}
-  %0 = mesh.shift %arg0 on @mesh0 mesh_axes = [0]
+  // expected-error@+1 {{'shard.shift' op requires the same shape for all operands and results}}
+  %0 = shard.shift %arg0 on @grid0 grid_axes = [0]
     shift_axis = 0 offset = 2
     : tensor<4xi8> -> tensor<5xi8>
   return %0 : tensor<5xi8>
@@ -917,12 +917,12 @@ func.func @shift_invalid_tensor_dimension_size(
 
 // -----
 
-mesh.mesh @mesh0(shape = 2x4)
+shard.grid @grid0(shape = 2x4)
 
 func.func @shift_invalid_shift_axis(
     %arg0 : tensor<4xi8>) -> tensor<4xi8> {
-  // expected-error@+1 {{Invalid shift axis 1. It must be one of the grouping mesh axes.}}
-  %0 = mesh.shift %arg0 on @mesh0 mesh_axes = [0]
+  // expected-error@+1 {{Invalid shift axis 1. It must be one of the grouping grid axes.}}
+  %0 = shard.shift %arg0 on @grid0 grid_axes = [0]
     shift_axis = 1 offset = 2
     : tensor<4xi8> -> tensor<4xi8>
   return %0 : tensor<4xi8>
diff --git a/mlir/test/Dialect/Mesh/ops.mlir b/mlir/test/Dialect/Shard/ops.mlir
similarity index 55%
rename from mlir/test/Dialect/Mesh/ops.mlir
rename to mlir/test/Dialect/Shard/ops.mlir
index c354de514fba8..5265dadd2a845 100644
--- a/mlir/test/Dialect/Mesh/ops.mlir
+++ b/mlir/test/Dialect/Shard/ops.mlir
@@ -1,176 +1,176 @@
 // RUN: mlir-opt %s | mlir-opt | FileCheck %s
 
-// CHECK: mesh.mesh @mesh0
-mesh.mesh @mesh0(shape = 2x2x4)
+// CHECK: shard.grid @grid0
+shard.grid @grid0(shape = 2x2x4)
 
-// CHECK: mesh.mesh @mesh1(shape = 4x?)
-mesh.mesh @mesh1(shape = 4x?)
+// CHECK: shard.grid @grid1(shape = 4x?)
+shard.grid @grid1(shape = 4x?)
 
-// CHECK: mesh.mesh @mesh2(shape = ?x4)
-mesh.mesh @mesh2(shape = ?x4)
+// CHECK: shard.grid @grid2(shape = ?x4)
+shard.grid @grid2(shape = ?x4)
 
-// CHECK: mesh.mesh @mesh3(shape = ?x?)
-mesh.mesh @mesh3(shape = ?x?)
+// CHECK: shard.grid @grid3(shape = ?x?)
+shard.grid @grid3(shape = ?x?)
 
-mesh.mesh @mesh4(shape = 3)
+shard.grid @grid4(shape = 3)
 
-// CHECK: mesh.mesh @mesh5(shape = ?)
-mesh.mesh @mesh5(shape = ?)
+// CHECK: shard.grid @grid5(shape = ?)
+shard.grid @grid5(shape = ?)
 
-// CHECK-LABEL: func @mesh_shard_op_fully_replicated
+// CHECK-LABEL: func @grid_shard_op_fully_replicated
 // CHECK-SAME: %[[ARG:.*]]: tensor<4x8xf32>
-func.func @mesh_shard_op_fully_replicated(%arg0 : tensor<4x8xf32>) -> tensor<4x8xf32> {
-  // CHECK-NEXT: %[[S:.*]] = mesh.sharding @mesh0 split_axes = {{\[\[}}]] : !mesh.sharding
-  %s = mesh.sharding @mesh0 split_axes = [[]] : !mesh.sharding
-  // CHECK-NEXT: mesh.shard %[[ARG]] to %[[S]] : tensor<4x8xf32>
-  %0 = mesh.shard %arg0 to %s : tensor<4x8xf32>
+func.func @grid_shard_op_fully_replicated(%arg0 : tensor<4x8xf32>) -> tensor<4x8xf32> {
+  // CHECK-NEXT: %[[S:.*]] = shard.sharding @grid0 split_axes = {{\[\[}}]] : !shard.sharding
+  %s = shard.sharding @grid0 split_axes = [[]] : !shard.sharding
+  // CHECK-NEXT: shard.shard %[[ARG]] to %[[S]] : tensor<4x8xf32>
+  %0 = shard.shard %arg0 to %s : tensor<4x8xf32>
   return %0 : tensor<4x8xf32>
 }
 
-// CHECK-LABEL: func @mesh_shard_op_1st_dim
+// CHECK-LABEL: func @grid_shard_op_1st_dim
 // CHECK-SAME: %[[ARG:.*]]: tensor<4x8xf32>
-func.func @mesh_shard_op_1st_dim(%arg0 : tensor<4x8xf32>) -> tensor<4x8xf32> {
-  // CHECK-NEXT: %[[S:.*]] = mesh.sharding @mesh0 split_axes = {{\[\[}}0]] : !mesh.sharding
-  %s = mesh.sharding @mesh0 split_axes = [[0]] : !mesh.sharding
+func.func @grid_shard_op_1st_dim(%arg0 : tensor<4x8xf32>) -> tensor<4x8xf32> {
+  // CHECK-NEXT: %[[S:.*]] = shard.sharding @grid0 split_axes = {{\[\[}}0]] : !shard.sharding
+  %s = shard.sharding @grid0 split_axes = [[0]] : !shard.sharding
 
-  %0 = mesh.shard %arg0 to %s : tensor<4x8xf32>
+  %0 = shard.shard %arg0 to %s : tensor<4x8xf32>
   return %0 : tensor<4x8xf32>
 }
 
-// CHECK-LABEL: func @mesh_shard_op_2nd_dim
+// CHECK-LABEL: func @grid_shard_op_2nd_dim
 // CHECK-SAME: %[[ARG:.*]]: tensor<4x8xf32>
-func.func @mesh_shard_op_2nd_dim(%arg0 : tensor<4x8xf32>) -> tensor<4x8xf32> {
-  // CHECK-NEXT: %[[S:.*]] = mesh.sharding @mesh1 split_axes = {{\[\[}}], [0]] : !mesh.sharding
-  %s = mesh.sharding @mesh1 split_axes = [[], [0]] : !mesh.sharding
-  // CHECK-NEXT: mesh.shard %[[ARG]] to %[[S]] : tensor<4x8xf32>
-  %0 = mesh.shard %arg0 to %s : tensor<4x8xf32>
+func.func @grid_shard_op_2nd_dim(%arg0 : tensor<4x8xf32>) -> tensor<4x8xf32> {
+  // CHECK-NEXT: %[[S:.*]] = shard.sharding @grid1 split_axes = {{\[\[}}], [0]] : !shard.sharding
+  %s = shard.sharding @grid1 split_axes = [[], [0]] : !shard.sharding
+  // CHECK-NEXT: shard.shard %[[ARG]] to %[[S]] : tensor<4x8xf32>
+  %0 = shard.shard %arg0 to %s : tensor<4x8xf32>
   return %0 : tensor<4x8xf32>
 }
 
-// CHECK-LABEL: func @mesh_shard_op_1st_and_3rd_dim
-func.func @mesh_shard_op_1st_and_3rd_dim(
+// CHECK-LABEL: func @grid_shard_op_1st_and_3rd_dim
+func.func @grid_shard_op_1st_and_3rd_dim(
     // CHECK-SAME: %[[ARG:.*]]: tensor<4x8x16xf32>
     %arg0 : tensor<4x8x16xf32>) -> tensor<4x8x16xf32> {
-  // CHECK-NEXT: %[[S:.*]] = mesh.sharding @mesh3 split_axes = {{\[\[}}0], [], [1]] : !mesh.sharding
-  %s = mesh.sharding @mesh3 split_axes = [[0], [], [1]] : !mesh.sharding
-  // CHECK-NEXT: mesh.shard %[[ARG]] to %[[S]] : tensor<4x8x16xf32>
-  %0 = mesh.shard %arg0 to %s : tensor<4x8x16xf32>
+  // CHECK-NEXT: %[[S:.*]] = shard.sharding @grid3 split_axes = {{\[\[}}0], [], [1]] : !shard.sharding
+  %s = shard.sharding @grid3 split_axes = [[0], [], [1]] : !shard.sharding
+  // CHECK-NEXT: shard.shard %[[ARG]] to %[[S]] : tensor<4x8x16xf32>
+  %0 = shard.shard %arg0 to %s : tensor<4x8x16xf32>
   return %0 : tensor<4x8x16xf32>
 }
 
-// CHECK-LABEL: func @mesh_shard_op_two_users
+// CHECK-LABEL: func @grid_shard_op_two_users
 // CHECK-SAME: %[[ARG:.*]]: tensor<4x8xf32>
-func.func @mesh_shard_op_two_users(%arg0 : tensor<4x8xf32>) ->
+func.func @grid_shard_op_two_users(%arg0 : tensor<4x8xf32>) ->
                                   (tensor<4x8xf32>, tensor<4x8xf32>) {
-  // CHECK-NEXT: %[[V0:.*]] = mesh.sharding @mesh0 split_axes = {{\[\[}}0]] : !mesh.sharding
-  %s0 = mesh.sharding @mesh0 split_axes = [[0]] : !mesh.sharding
-  %0 = mesh.shard %arg0 to %s0 : tensor<4x8xf32>
-  // CHECK-DAG: mesh.sharding @mesh0 split_axes = {{\[\[}}1]] : !mesh.sharding
-  %s1 = mesh.sharding @mesh0 split_axes = [[1]] : !mesh.sharding
-  %1 = mesh.shard %0 to %s1 annotate_for_users : tensor<4x8xf32>
-  // CHECK-DAG: mesh.sharding @mesh0 split_axes = {{\[\[}}2]] : !mesh.sharding
-  %s2 = mesh.sharding @mesh0 split_axes = [[2]] : !mesh.sharding
-  %2 = mesh.shard %0 to %s2 annotate_for_users : tensor<4x8xf32>
+  // CHECK-NEXT: %[[V0:.*]] = shard.sharding @grid0 split_axes = {{\[\[}}0]] : !shard.sharding
+  %s0 = shard.sharding @grid0 split_axes = [[0]] : !shard.sharding
+  %0 = shard.shard %arg0 to %s0 : tensor<4x8xf32>
+  // CHECK-DAG: shard.sharding @grid0 split_axes = {{\[\[}}1]] : !shard.sharding
+  %s1 = shard.sharding @grid0 split_axes = [[1]] : !shard.sharding
+  %1 = shard.shard %0 to %s1 annotate_for_users : tensor<4x8xf32>
+  // CHECK-DAG: shard.sharding @grid0 split_axes = {{\[\[}}2]] : !shard.sharding
+  %s2 = shard.sharding @grid0 split_axes = [[2]] : !shard.sharding
+  %2 = shard.shard %0 to %s2 annotate_for_users : tensor<4x8xf32>
   return %1, %2 : tensor<4x8xf32>, tensor<4x8xf32>
 }
 
-// CHECK-LABEL: func @mesh_shard_halo_sizes
-func.func @mesh_shard_halo_sizes() -> () {
+// CHECK-LABEL: func @grid_shard_halo_sizes
+func.func @grid_shard_halo_sizes() -> () {
   // CHECK: %[[C3:.*]] = arith.constant 3 : i64
   %c3 = arith.constant 3 : i64
-  // CHECK: mesh.sharding @mesh4 split_axes = {{\[\[}}0]] halo_sizes = [1, 4] : !mesh.sharding
-  %sharding1 = mesh.sharding @mesh4 split_axes = [[0]] halo_sizes = [1, 4] : !mesh.sharding
-  // CHECK: mesh.sharding @mesh4 split_axes = {{\[\[}}0]] halo_sizes = [4, %[[C3]]] : !mesh.sharding
-  %sharding2 = mesh.sharding @mesh4 split_axes = [[0]] halo_sizes = [4, %c3] : !mesh.sharding
+  // CHECK: shard.sharding @grid4 split_axes = {{\[\[}}0]] halo_sizes = [1, 4] : !shard.sharding
+  %sharding1 = shard.sharding @grid4 split_axes = [[0]] halo_sizes = [1, 4] : !shard.sharding
+  // CHECK: shard.sharding @grid4 split_axes = {{\[\[}}0]] halo_sizes = [4, %[[C3]]] : !shard.sharding
+  %sharding2 = shard.sharding @grid4 split_axes = [[0]] halo_sizes = [4, %c3] : !shard.sharding
   return
 }
 
-// CHECK-LABEL: func @mesh_shard_dims_sizes
-func.func @mesh_shard_dims_sizes() -> () {
+// CHECK-LABEL: func @grid_shard_dims_sizes
+func.func @grid_shard_dims_sizes() -> () {
   // CHECK: %[[C3:.*]] = arith.constant 3 : i64
   %c3 = arith.constant 3 : i64
-  // CHECK: mesh.sharding @mesh4 split_axes = {{\[\[}}0]] sharded_dims_offsets = [0, 1, 4, 6] : !mesh.sharding
-  %sharding1 = mesh.sharding @mesh4 split_axes = [[0]] sharded_dims_offsets = [0, 1, 4, 6] : !mesh.sharding
-  // CHECK: mesh.sharding @mesh4 split_axes = {{\[\[}}0]] sharded_dims_offsets = [0, 2, %[[C3]], 5] : !mesh.sharding
-  %sharding2 = mesh.sharding @mesh4 split_axes = [[0]] sharded_dims_offsets = [0, 2, %c3, 5] : !mesh.sharding
+  // CHECK: shard.sharding @grid4 split_axes = {{\[\[}}0]] sharded_dims_offsets = [0, 1, 4, 6] : !shard.sharding
+  %sharding1 = shard.sharding @grid4 split_axes = [[0]] sharded_dims_offsets = [0, 1, 4, 6] : !shard.sharding
+  // CHECK: shard.sharding @grid4 split_axes = {{\[\[}}0]] sharded_dims_offsets = [0, 2, %[[C3]], 5] : !shard.sharding
+  %sharding2 = shard.sharding @grid4 split_axes = [[0]] sharded_dims_offsets = [0, 2, %c3, 5] : !shard.sharding
   return
 }
 
-// CHECK-LABEL: func @mesh_shard_shape
-func.func @mesh_shard_shape() {
+// CHECK-LABEL: func @grid_shard_shape
+func.func @grid_shard_shape() {
   // CHECK: %[[C3:.*]] = arith.constant 3 : index
   %c3 = arith.constant 3 : index
-  // CHECK-NEXT: %[[S:.*]] = mesh.sharding @mesh0 split_axes = {{\[\[}}]] : !mesh.sharding
-  %s = mesh.sharding @mesh0 split_axes = [[]] : !mesh.sharding
-  // CHECK-NEXT: mesh.shard_shape dims = [8, %[[C3]]
+  // CHECK-NEXT: %[[S:.*]] = shard.sharding @grid0 split_axes = {{\[\[}}]] : !shard.sharding
+  %s = shard.sharding @grid0 split_axes = [[]] : !shard.sharding
+  // CHECK-NEXT: shard.shard_shape dims = [8, %[[C3]]
   // CHECK-SAME: ] sharding = %[[S]] device = [%[[C3]]
   // CHECK-SAME: ] : index, index
-  %shp:2 = mesh.shard_shape dims = [8, %c3] sharding = %s device = [%c3] : index, index
-  // CHECK-NEXT: mesh.shard_shape dims = [8, 4] sharding = %[[S]] device = [3] : index, index
-  %shp1:2 = mesh.shard_shape dims = [8, 4] sharding = %s device = [3] : index, index
+  %shp:2 = shard.shard_shape dims = [8, %c3] sharding = %s device = [%c3] : index, index
+  // CHECK-NEXT: shard.shard_shape dims = [8, 4] sharding = %[[S]] device = [3] : index, index
+  %shp1:2 = shard.shard_shape dims = [8, 4] sharding = %s device = [3] : index, index
   return
 }
 
-// CHECK-LABEL: func @mesh_get_sharding
+// CHECK-LABEL: func @grid_get_sharding
 // CHECK-SAME: %[[ARG:.*]]: tensor<4x8xf32>
-func.func @mesh_get_sharding(%arg0 : tensor<4x8xf32>) -> !mesh.sharding {
-  // CHECK-NEXT: mesh.get_sharding %[[ARG]] : tensor<4x8xf32> -> !mesh.sharding
-  %0 = mesh.get_sharding %arg0 : tensor<4x8xf32> -> !mesh.sharding
-  return %0 : !mesh.sharding
+func.func @grid_get_sharding(%arg0 : tensor<4x8xf32>) -> !shard.sharding {
+  // CHECK-NEXT: shard.get_sharding %[[ARG]] : tensor<4x8xf32> -> !shard.sharding
+  %0 = shard.get_sharding %arg0 : tensor<4x8xf32> -> !shard.sharding
+  return %0 : !shard.sharding
 }
 
-// CHECK-LABEL: func @mesh_shape
-func.func @mesh_shape() -> (index, index) {
-  // CHECK: %[[RES:.*]]:2 = mesh.mesh_shape @mesh0 axes = [0, 1] : index, index
-  %0:2 = mesh.mesh_shape @mesh0 axes = [0, 1] : index, index
+// CHECK-LABEL: func @grid_shape
+func.func @grid_shape() -> (index, index) {
+  // CHECK: %[[RES:.*]]:2 = shard.grid_shape @grid0 axes = [0, 1] : index, index
+  %0:2 = shard.grid_shape @grid0 axes = [0, 1] : index, index
   // CHECK: return %[[RES]]#0, %[[RES]]#1 : index, index
   return %0#0, %0#1 : index, index
 }
 
-// CHECK-LABEL: func @mesh_shape_default_axes
-func.func @mesh_shape_default_axes() -> (index, index, index) {
-  // CHECK: %[[RES:.*]]:3 = mesh.mesh_shape @mesh0 : index, index, index
-  %0:3 = mesh.mesh_shape @mesh0 : index, index, index
+// CHECK-LABEL: func @grid_shape_default_axes
+func.func @grid_shape_default_axes() -> (index, index, index) {
+  // CHECK: %[[RES:.*]]:3 = shard.grid_shape @grid0 : index, index, index
+  %0:3 = shard.grid_shape @grid0 : index, index, index
   // CHECK: return %[[RES]]#0, %[[RES]]#1, %[[RES]]#2 : index, index, index
   return %0#0, %0#1, %0#2 : index, index, index
 }
 
-// CHECK-LABEL: func @mesh_shape_empty_axes
-func.func @mesh_shape_empty_axes() -> (index, index, index) {
-  // CHECK: %[[RES:.*]]:3 = mesh.mesh_shape @mesh0 : index, index, index
-  %0:3 = mesh.mesh_shape @mesh0 axes = [] : index, index, index
+// CHECK-LABEL: func @grid_shape_empty_axes
+func.func @grid_shape_empty_axes() -> (index, index, index) {
+  // CHECK: %[[RES:.*]]:3 = shard.grid_shape @grid0 : index, index, index
+  %0:3 = shard.grid_shape @grid0 axes = [] : index, index, index
   // CHECK: return %[[RES]]#0, %[[RES]]#1, %[[RES]]#2 : index, index, index
   return %0#0, %0#1, %0#2 : index, index, index
 }
 
 // CHECK-LABEL: func @process_multi_index
 func.func @process_multi_index() -> (index, index) {
-  // CHECK: %[[RES:.*]]:2 = mesh.process_multi_index on @mesh0 axes = [0, 1] : index, index
-  %0:2 = mesh.process_multi_index on @mesh0 axes = [0, 1] : index, index
+  // CHECK: %[[RES:.*]]:2 = shard.process_multi_index on @grid0 axes = [0, 1] : index, index
+  %0:2 = shard.process_multi_index on @grid0 axes = [0, 1] : index, index
   // CHECK: return %[[RES]]#0, %[[RES]]#1 : index, index
   return %0#0, %0#1 : index, index
 }
 
 // CHECK-LABEL: func @process_multi_index_default_axes
 func.func @process_multi_index_default_axes() -> (index, index, index) {
-  // CHECK: %[[RES:.*]]:3 = mesh.process_multi_index on @mesh0 : index, index, index
-  %0:3 = mesh.process_multi_index on @mesh0 : index, index, index
+  // CHECK: %[[RES:.*]]:3 = shard.process_multi_index on @grid0 : index, index, index
+  %0:3 = shard.process_multi_index on @grid0 : index, index, index
   // CHECK: return %[[RES]]#0, %[[RES]]#1, %[[RES]]#2 : index, index, index
   return %0#0, %0#1, %0#2 : index, index, index
 }
 
 // CHECK-LABEL: func @process_multi_index_empty_axes
 func.func @process_multi_index_empty_axes() -> (index, index, index) {
-  // CHECK: %[[RES:.*]]:3 = mesh.process_multi_index on @mesh0 : index, index, index
-  %0:3 = mesh.process_multi_index on @mesh0 axes = [] : index, index, index
+  // CHECK: %[[RES:.*]]:3 = shard.process_multi_index on @grid0 : index, index, index
+  %0:3 = shard.process_multi_index on @grid0 axes = [] : index, index, index
   // CHECK: return %[[RES]]#0, %[[RES]]#1, %[[RES]]#2 : index, index, index
   return %0#0, %0#1, %0#2 : index, index, index
 }
 
 // CHECK-LABEL: func @process_linear_index
 func.func @process_linear_index() -> index {
-  // CHECK: %[[RES:.*]] = mesh.process_linear_index on @mesh0 : index
-  %0 = mesh.process_linear_index on @mesh0 : index
+  // CHECK: %[[RES:.*]] = shard.process_linear_index on @grid0 : index
+  %0 = shard.process_linear_index on @grid0 : index
   // CHECK: return %[[RES]] : index
   return %0 : index
 }
@@ -179,9 +179,9 @@ func.func @process_linear_index() -> index {
 func.func @all_reduce(
     // CHECK-SAME: %[[ARG:.*]]: tensor<3x4xf32>
     %arg0 : tensor<3x4xf32>) -> tensor<3x4xf64> {
-  // CHECK-NEXT: mesh.all_reduce %[[ARG]] on @mesh0 mesh_axes = [1, 0] reduction = max
+  // CHECK-NEXT: shard.all_reduce %[[ARG]] on @grid0 grid_axes = [1, 0] reduction = max
   // CHECK-SAME: : tensor<3x4xf32> -> tensor<3x4xf64>
-  %0 = mesh.all_reduce %arg0 on @mesh0 mesh_axes = [1, 0] reduction = max
+  %0 = shard.all_reduce %arg0 on @grid0 grid_axes = [1, 0] reduction = max
     : tensor<3x4xf32> -> tensor<3x4xf64>
   return %0 : tensor<3x4xf64>
 }
@@ -190,9 +190,9 @@ func.func @all_reduce(
 func.func @all_gather(
     // CHECK-SAME: %[[ARG:.*]]: tensor<3x4xf32>
     %arg0 : tensor<3x4xf32>) -> tensor<3x16xf32> {
-  // CHECK-NEXT: mesh.all_gather %[[ARG]] on @mesh0 mesh_axes = [2] gather_axis = 1
+  // CHECK-NEXT: shard.all_gather %[[ARG]] on @grid0 grid_axes = [2] gather_axis = 1
   // CHECK-SAME: : tensor<3x4xf32> -> tensor<3x16xf32>
-  %0 = mesh.all_gather %arg0 on @mesh0 mesh_axes = [2] gather_axis = 1
+  %0 = shard.all_gather %arg0 on @grid0 grid_axes = [2] gather_axis = 1
     : tensor<3x4xf32> -> tensor<3x16xf32>
   return %0 : tensor<3x16xf32>
 }
@@ -201,20 +201,20 @@ func.func @all_gather(
 func.func @all_gather_dynamic_dims_in_tensor(
     // CHECK-SAME: %[[ARG:.*]]: tensor<?x?xf32>
     %arg0 : tensor<?x?xf32>) -> tensor<?x?xf32> {
-  // CHECK-NEXT: mesh.all_gather %[[ARG]] on @mesh0 mesh_axes = [2] gather_axis = 1
+  // CHECK-NEXT: shard.all_gather %[[ARG]] on @grid0 grid_axes = [2] gather_axis = 1
   // CHECK-SAME: : tensor<?x?xf32> -> tensor<?x?xf32>
-  %0 = mesh.all_gather %arg0 on @mesh0 mesh_axes = [2] gather_axis = 1
+  %0 = shard.all_gather %arg0 on @grid0 grid_axes = [2] gather_axis = 1
     : tensor<?x?xf32> -> tensor<?x?xf32>
   return %0 : tensor<?x?xf32>
 }
 
-// CHECK-LABEL: func @all_gather_dynamic_dims_in_mesh
-func.func @all_gather_dynamic_dims_in_mesh(
+// CHECK-LABEL: func @all_gather_dynamic_dims_in_grid
+func.func @all_gather_dynamic_dims_in_grid(
     // CHECK-SAME: %[[ARG:.*]]: tensor<5x6xf32>
     %arg0 : tensor<5x6xf32>) -> tensor<5x?xf32> {
-  // CHECK-NEXT: mesh.all_gather %[[ARG]] on @mesh3 mesh_axes = [1] gather_axis = 1
+  // CHECK-NEXT: shard.all_gather %[[ARG]] on @grid3 grid_axes = [1] gather_axis = 1
   // CHECK-SAME: : tensor<5x6xf32> -> tensor<5x?xf32>
-  %0 = mesh.all_gather %arg0 on @mesh3 mesh_axes = [1] gather_axis = 1
+  %0 = shard.all_gather %arg0 on @grid3 grid_axes = [1] gather_axis = 1
     : tensor<5x6xf32> -> tensor<5x?xf32>
   return %0 : tensor<5x?xf32>
 }
@@ -223,10 +223,10 @@ func.func @all_gather_dynamic_dims_in_mesh(
 func.func @all_slice_static_dimensions(
     // CHECK-SAME: %[[ARG:.*]]: tensor<3x4xf32>
     %arg0 : tensor<3x4xf32>) -> tensor<3x1xf32> {
-  // CHECK-NEXT: mesh.all_slice %[[ARG]]
-  // CHECK-SAME: on @mesh0 mesh_axes = [2] slice_axis = 1
+  // CHECK-NEXT: shard.all_slice %[[ARG]]
+  // CHECK-SAME: on @grid0 grid_axes = [2] slice_axis = 1
   // CHECK-SAME: : tensor<3x4xf32> -> tensor<3x1xf32>
-  %0 = mesh.all_slice %arg0 on @mesh0 mesh_axes = [2] slice_axis = 1
+  %0 = shard.all_slice %arg0 on @grid0 grid_axes = [2] slice_axis = 1
     : tensor<3x4xf32> -> tensor<3x1xf32>
   return %0 : tensor<3x1xf32>
 }
@@ -235,10 +235,10 @@ func.func @all_slice_static_dimensions(
 func.func @all_slice_dynamic_dimensions(
     // CHECK-SAME: %[[ARG:.*]]: tensor<?xf32>
     %arg0 : tensor<?xf32>) -> tensor<?xf32> {
-  // CHECK-NEXT: mesh.all_slice %[[ARG]]
-  // CHECK-SAME: on @mesh3 mesh_axes = [0, 1] slice_axis = 0
+  // CHECK-NEXT: shard.all_slice %[[ARG]]
+  // CHECK-SAME: on @grid3 grid_axes = [0, 1] slice_axis = 0
   // CHECK-SAME: : tensor<?xf32> -> tensor<?xf32>
-  %0 = mesh.all_slice %arg0 on @mesh3 mesh_axes = [0, 1] slice_axis = 0
+  %0 = shard.all_slice %arg0 on @grid3 grid_axes = [0, 1] slice_axis = 0
     : tensor<?xf32> -> tensor<?xf32>
   return %0 : tensor<?xf32>
 }
@@ -247,10 +247,10 @@ func.func @all_slice_dynamic_dimensions(
 func.func @all_to_all(
     // CHECK-SAME: %[[ARG:.*]]: tensor<3x6xi8>
     %arg0 : tensor<3x6xi8>) -> tensor<3x6xi8> {
-  // CHECK-NEXT: mesh.all_to_all %[[ARG]]
-  // CHECK-SAME: on @mesh4 split_axis = 1 concat_axis = 0
+  // CHECK-NEXT: shard.all_to_all %[[ARG]]
+  // CHECK-SAME: on @grid4 split_axis = 1 concat_axis = 0
   // CHECK-SAME: : tensor<3x6xi8> -> tensor<3x6xi8>
-  %0 = mesh.all_to_all %arg0 on @mesh4
+  %0 = shard.all_to_all %arg0 on @grid4
     split_axis = 1 concat_axis = 0
     : tensor<3x6xi8> -> tensor<3x6xi8>
   return %0 : tensor<3x6xi8>
@@ -260,10 +260,10 @@ func.func @all_to_all(
 func.func @all_to_all_dynamic_dims_in_result(
     // CHECK-SAME: %[[ARG:.*]]: tensor<3x6xi8>
     %arg0 : tensor<3x6xi8>) -> tensor<3x?xi8> {
-  // CHECK-NEXT: mesh.all_to_all %[[ARG]]
-  // CHECK-SAME: on @mesh4 split_axis = 1 concat_axis = 0
+  // CHECK-NEXT: shard.all_to_all %[[ARG]]
+  // CHECK-SAME: on @grid4 split_axis = 1 concat_axis = 0
   // CHECK-SAME: : tensor<3x6xi8> -> tensor<3x?xi8>
-  %0 = mesh.all_to_all %arg0 on @mesh4
+  %0 = shard.all_to_all %arg0 on @grid4
     split_axis = 1 concat_axis = 0
     : tensor<3x6xi8> -> tensor<3x?xi8>
   return %0 : tensor<3x?xi8>
@@ -273,10 +273,10 @@ func.func @all_to_all_dynamic_dims_in_result(
 func.func @all_to_all_same_split_concat_dim_with_dynamic_device_group_size(
     // CHECK-SAME: %[[ARG:.*]]: tensor<3xi8>
     %arg0 : tensor<3xi8>) -> tensor<3xi8> {
-  // CHECK-NEXT: mesh.all_to_all %[[ARG]]
-  // CHECK-SAME: @mesh4 split_axis = 0 concat_axis = 0
+  // CHECK-NEXT: shard.all_to_all %[[ARG]]
+  // CHECK-SAME: @grid4 split_axis = 0 concat_axis = 0
   // CHECK-SAME: : tensor<3xi8> -> tensor<3xi8>
-  %0 = mesh.all_to_all %arg0 on @mesh4
+  %0 = shard.all_to_all %arg0 on @grid4
     split_axis = 0 concat_axis = 0
     : tensor<3xi8> -> tensor<3xi8>
   return %0 : tensor<3xi8>
@@ -286,10 +286,10 @@ func.func @all_to_all_same_split_concat_dim_with_dynamic_device_group_size(
 func.func @all_to_all_non_divisible_split_axis_size(
     // CHECK-SAME: %[[ARG:.*]]: tensor<2x3xi8>
     %arg0 : tensor<2x3xi8>) -> tensor<?x12xi8> {
-  // CHECK-NEXT: mesh.all_to_all %[[ARG]]
-  // CHECK-SAME: @mesh0 mesh_axes = [0, 1] split_axis = 0 concat_axis = 1
+  // CHECK-NEXT: shard.all_to_all %[[ARG]]
+  // CHECK-SAME: @grid0 grid_axes = [0, 1] split_axis = 0 concat_axis = 1
   // CHECK-SAME: : tensor<2x3xi8> -> tensor<?x12xi8>
-  %0 = mesh.all_to_all %arg0 on @mesh0 mesh_axes = [0, 1]
+  %0 = shard.all_to_all %arg0 on @grid0 grid_axes = [0, 1]
     split_axis = 0 concat_axis = 1
     : tensor<2x3xi8> -> tensor<?x12xi8>
   return %0 : tensor<?x12xi8>
@@ -299,11 +299,11 @@ func.func @all_to_all_non_divisible_split_axis_size(
 func.func @broadcast_static_root(
     // CHECK-SAME: %[[ARG:.*]]: tensor<3x6xi8>
     %arg0 : tensor<3x6xi8>) -> tensor<3x6xi8> {
-  // CHECK-NEXT: mesh.broadcast %[[ARG]]
-  // CHECK-SAME: on @mesh0 mesh_axes = [0, 2]
+  // CHECK-NEXT: shard.broadcast %[[ARG]]
+  // CHECK-SAME: on @grid0 grid_axes = [0, 2]
   // CHECK-SAME: root = [0, 1]
   // CHECK-SAME: : (tensor<3x6xi8>) -> tensor<3x6xi8>
-  %0 = mesh.broadcast %arg0 on @mesh0 mesh_axes = [0, 2]
+  %0 = shard.broadcast %arg0 on @grid0 grid_axes = [0, 2]
     root = [0, 1]
     : (tensor<3x6xi8>) -> tensor<3x6xi8>
   return %0 : tensor<3x6xi8>
@@ -316,11 +316,11 @@ func.func @broadcast_dynamic_root(
     // CHECK-SAME: %[[ARG1:.*]]: index
     %arg1 : index
     ) -> tensor<3x6xi8> {
-  // CHECK-NEXT: mesh.broadcast %[[ARG0]]
-  // CHECK-SAME: on @mesh0 mesh_axes = [0, 2]
+  // CHECK-NEXT: shard.broadcast %[[ARG0]]
+  // CHECK-SAME: on @grid0 grid_axes = [0, 2]
   // CHECK-SAME: root = [1, %[[ARG1]]]
   // CHECK-SAME: : (tensor<3x6xi8>, index) -> tensor<3x6xi8>
-  %0 = mesh.broadcast %arg0 on @mesh0 mesh_axes = [0, 2]
+  %0 = shard.broadcast %arg0 on @grid0 grid_axes = [0, 2]
     root = [1, %arg1]
     : (tensor<3x6xi8>, index) -> tensor<3x6xi8>
   return %0 : tensor<3x6xi8>
@@ -330,12 +330,12 @@ func.func @broadcast_dynamic_root(
 func.func @gather_static_root(
     // CHECK-SAME: %[[ARG:.*]]: tensor<3x6xi8>
     %arg0 : tensor<3x6xi8>) -> tensor<24x6xi8> {
-  // CHECK-NEXT: mesh.gather %[[ARG]]
-  // CHECK-SAME: on @mesh0 mesh_axes = [0, 2]
+  // CHECK-NEXT: shard.gather %[[ARG]]
+  // CHECK-SAME: on @grid0 grid_axes = [0, 2]
   // CHECK-SAME: gather_axis = 0
   // CHECK-SAME: root = [0, 1]
   // CHECK-SAME: : (tensor<3x6xi8>) -> tensor<24x6xi8>
-  %0 = mesh.gather %arg0 on @mesh0 mesh_axes = [0, 2]
+  %0 = shard.gather %arg0 on @grid0 grid_axes = [0, 2]
     gather_axis = 0
     root = [0, 1]
     : (tensor<3x6xi8>) -> tensor<24x6xi8>
@@ -349,12 +349,12 @@ func.func @gather_dynamic_root(
     // CHECK-SAME: %[[ARG1:.*]]: index
     %arg1 : index
     ) -> tensor<24x6xi8> {
-  // CHECK-NEXT: mesh.gather %[[ARG0]]
-  // CHECK-SAME: on @mesh0 mesh_axes = [0, 2]
+  // CHECK-NEXT: shard.gather %[[ARG0]]
+  // CHECK-SAME: on @grid0 grid_axes = [0, 2]
   // CHECK-SAME: gather_axis = 0
   // CHECK-SAME: root = [1, %[[ARG1]]]
   // CHECK-SAME: : (tensor<3x6xi8>, index) -> tensor<24x6xi8>
-  %0 = mesh.gather %arg0 on @mesh0 mesh_axes = [0, 2]
+  %0 = shard.gather %arg0 on @grid0 grid_axes = [0, 2]
     gather_axis = 0
     root = [1, %arg1]
     : (tensor<3x6xi8>, index) -> tensor<24x6xi8>
@@ -365,11 +365,11 @@ func.func @gather_dynamic_root(
 func.func @receive_static_source(
     // CHECK-SAME: %[[ARG:.*]]: tensor<2xi8>
     %arg0 : tensor<2xi8>) -> tensor<2xi8> {
-  // CHECK-NEXT: mesh.recv %[[ARG]]
-  // CHECK-SAME: on @mesh0 mesh_axes = [0, 2]
+  // CHECK-NEXT: shard.recv %[[ARG]]
+  // CHECK-SAME: on @grid0 grid_axes = [0, 2]
   // CHECK-SAME: source = [0, 1]
   // CHECK-SAME: : (tensor<2xi8>) -> tensor<2xi8>
-  %0 = mesh.recv %arg0 on @mesh0 mesh_axes = [0, 2]
+  %0 = shard.recv %arg0 on @grid0 grid_axes = [0, 2]
     source = [0, 1]
     : (tensor<2xi8>) -> tensor<2xi8>
   return %0 : tensor<2xi8>
@@ -382,11 +382,11 @@ func.func @receive_dynamic_source(
     // CHECK-SAME: %[[ARG1:.*]]: index
     %arg1 : index
     ) -> tensor<2xi8> {
-  // CHECK-NEXT: mesh.recv %[[ARG0]]
-  // CHECK-SAME: on @mesh0 mesh_axes = [0, 2]
+  // CHECK-NEXT: shard.recv %[[ARG0]]
+  // CHECK-SAME: on @grid0 grid_axes = [0, 2]
   // CHECK-SAME: source = [1, %[[ARG1]]]
   // CHECK-SAME: : (tensor<2xi8>, index) -> tensor<2xi8>
-  %0 = mesh.recv %arg0 on @mesh0 mesh_axes = [0, 2]
+  %0 = shard.recv %arg0 on @grid0 grid_axes = [0, 2]
     source = [1, %arg1]
     : (tensor<2xi8>, index) -> tensor<2xi8>
   return %0 : tensor<2xi8>
@@ -396,9 +396,9 @@ func.func @receive_dynamic_source(
 func.func @receive_no_source(
     // CHECK-SAME: %[[ARG:.*]]: tensor<2xi8>
     %arg0 : tensor<2xi8>) -> tensor<2xi8> {
-  // CHECK-NEXT: mesh.recv %[[ARG]]
+  // CHECK-NEXT: shard.recv %[[ARG]]
   // CHECK-NOT: source
-  %0 = mesh.recv %arg0 on @mesh0 mesh_axes = [0, 2]
+  %0 = shard.recv %arg0 on @grid0 grid_axes = [0, 2]
     : (tensor<2xi8>) -> tensor<2xi8>
   return %0 : tensor<2xi8>
 }
@@ -407,11 +407,11 @@ func.func @receive_no_source(
 func.func @reduce_static_root(
     // CHECK-SAME: %[[ARG:.*]]: tensor<2xi8>
     %arg0 : tensor<2xi8>) -> tensor<2xi8> {
-  // CHECK-NEXT: mesh.reduce %[[ARG]]
-  // CHECK-SAME: on @mesh0 mesh_axes = [0, 2]
+  // CHECK-NEXT: shard.reduce %[[ARG]]
+  // CHECK-SAME: on @grid0 grid_axes = [0, 2]
   // CHECK-SAME: root = [0, 1]
   // CHECK-SAME: : (tensor<2xi8>) -> tensor<2xi8>
-  %0 = mesh.reduce %arg0 on @mesh0 mesh_axes = [0, 2]
+  %0 = shard.reduce %arg0 on @grid0 grid_axes = [0, 2]
     root = [0, 1]
     : (tensor<2xi8>) -> tensor<2xi8>
   return %0 : tensor<2xi8>
@@ -424,11 +424,11 @@ func.func @reduce_dynamic_root(
     // CHECK-SAME: %[[ARG1:.*]]: index
     %arg1 : index
     ) -> tensor<2xi8> {
-  // CHECK-NEXT: mesh.reduce %[[ARG0]]
-  // CHECK-SAME: on @mesh0 mesh_axes = [0, 2]
+  // CHECK-NEXT: shard.reduce %[[ARG0]]
+  // CHECK-SAME: on @grid0 grid_axes = [0, 2]
   // CHECK-SAME: root = [1, %[[ARG1]]]
   // CHECK-SAME: : (tensor<2xi8>, index) -> tensor<2xi8>
-  %0 = mesh.reduce %arg0 on @mesh0 mesh_axes = [0, 2]
+  %0 = shard.reduce %arg0 on @grid0 grid_axes = [0, 2]
     root = [1, %arg1]
     : (tensor<2xi8>, index) -> tensor<2xi8>
   return %0 : tensor<2xi8>
@@ -438,11 +438,11 @@ func.func @reduce_dynamic_root(
 func.func @reduce_different_return_element_type(
     // CHECK-SAME: %[[ARG:.*]]: tensor<2xi8>
     %arg0 : tensor<2xi8>) -> tensor<2xi16> {
-  // CHECK-NEXT: mesh.reduce %[[ARG]]
-  // CHECK-SAME: on @mesh0 mesh_axes = [0, 2]
+  // CHECK-NEXT: shard.reduce %[[ARG]]
+  // CHECK-SAME: on @grid0 grid_axes = [0, 2]
   // CHECK-SAME: root = [0, 1]
   // CHECK-SAME: : (tensor<2xi8>) -> tensor<2xi16>
-  %0 = mesh.reduce %arg0 on @mesh0 mesh_axes = [0, 2]
+  %0 = shard.reduce %arg0 on @grid0 grid_axes = [0, 2]
     root = [0, 1]
     : (tensor<2xi8>) -> tensor<2xi16>
   return %0 : tensor<2xi16>
@@ -452,10 +452,10 @@ func.func @reduce_different_return_element_type(
 func.func @reduce_scatter_static_dimensions(
     // CHECK-SAME: %[[ARG:.*]]: tensor<3x4xf32>
     %arg0 : tensor<3x4xf32>) -> tensor<3x1xf64> {
-  // CHECK-NEXT: mesh.reduce_scatter %[[ARG]]
-  // CHECK-SAME: on @mesh0 mesh_axes = [2] reduction = max scatter_axis = 1
+  // CHECK-NEXT: shard.reduce_scatter %[[ARG]]
+  // CHECK-SAME: on @grid0 grid_axes = [2] reduction = max scatter_axis = 1
   // CHECK-SAME: : tensor<3x4xf32> -> tensor<3x1xf64>
-  %0 = mesh.reduce_scatter %arg0 on @mesh0 mesh_axes = [2]
+  %0 = shard.reduce_scatter %arg0 on @grid0 grid_axes = [2]
     reduction = max scatter_axis = 1
     : tensor<3x4xf32> -> tensor<3x1xf64>
   return %0 : tensor<3x1xf64>
@@ -465,10 +465,10 @@ func.func @reduce_scatter_static_dimensions(
 func.func @reduce_scatter_dynamic_dimensions(
     // CHECK-SAME: %[[ARG:.*]]: tensor<?xf32>
     %arg0 : tensor<?xf32>) -> tensor<?xf64> {
-  // CHECK-NEXT: mesh.reduce_scatter %[[ARG]]
-  // CHECK-SAME: on @mesh3 mesh_axes = [0, 1] scatter_axis = 0
+  // CHECK-NEXT: shard.reduce_scatter %[[ARG]]
+  // CHECK-SAME: on @grid3 grid_axes = [0, 1] scatter_axis = 0
   // CHECK-SAME: : tensor<?xf32> -> tensor<?xf64>
-  %0 = mesh.reduce_scatter %arg0 on @mesh3 mesh_axes = [0, 1] scatter_axis = 0
+  %0 = shard.reduce_scatter %arg0 on @grid3 grid_axes = [0, 1] scatter_axis = 0
     : tensor<?xf32> -> tensor<?xf64>
   return %0 : tensor<?xf64>
 }
@@ -477,11 +477,11 @@ func.func @reduce_scatter_dynamic_dimensions(
 func.func @scatter_static_dimensions(
     // CHECK-SAME: %[[ARG:.*]]: tensor<3x4xf32>
     %arg0 : tensor<3x4xf32>) -> tensor<3x1xf32> {
-  // CHECK-NEXT: mesh.scatter %[[ARG]]
-  // CHECK-SAME: on @mesh0 mesh_axes = [2]
+  // CHECK-NEXT: shard.scatter %[[ARG]]
+  // CHECK-SAME: on @grid0 grid_axes = [2]
   // CHECK-SAME: scatter_axis = 1 root = [1]
   // CHECK-SAME: : (tensor<3x4xf32>) -> tensor<3x1xf32>
-  %0 = mesh.scatter %arg0 on @mesh0 mesh_axes = [2]
+  %0 = shard.scatter %arg0 on @grid0 grid_axes = [2]
     scatter_axis = 1 root = [1]
     : (tensor<3x4xf32>) -> tensor<3x1xf32>
   return %0 : tensor<3x1xf32>
@@ -491,11 +491,11 @@ func.func @scatter_static_dimensions(
 func.func @scatter_dynamic_dimensions(
     // CHECK-SAME: %[[ARG:.*]]: tensor<?xf32>
     %arg0 : tensor<?xf32>) -> tensor<?xf32> {
-  // CHECK-NEXT: mesh.scatter %[[ARG]]
-  // CHECK-SAME: on @mesh3 mesh_axes = [0, 1]
+  // CHECK-NEXT: shard.scatter %[[ARG]]
+  // CHECK-SAME: on @grid3 grid_axes = [0, 1]
   // CHECK-SAME: scatter_axis = 0 root = [1, 2]
   // CHECK-SAME: : (tensor<?xf32>) -> tensor<?xf32>
-  %0 = mesh.scatter %arg0 on @mesh3 mesh_axes = [0, 1]
+  %0 = shard.scatter %arg0 on @grid3 grid_axes = [0, 1]
     scatter_axis = 0 root = [1, 2]
     : (tensor<?xf32>) -> tensor<?xf32>
   return %0 : tensor<?xf32>
@@ -508,12 +508,12 @@ func.func @scatter_dynamic_root(
     // CHECK-SAME: %[[ARG1:.*]]: index
     %arg1 : index
     ) -> tensor<1xi8> {
-  // CHECK-NEXT: mesh.scatter %[[ARG0]]
-  // CHECK-SAME: on @mesh0 mesh_axes = [0, 2]
+  // CHECK-NEXT: shard.scatter %[[ARG0]]
+  // CHECK-SAME: on @grid0 grid_axes = [0, 2]
   // CHECK-SAME: scatter_axis = 0
   // CHECK-SAME: root = [1, %[[ARG1]]]
   // CHECK-SAME: : (tensor<8xi8>, index) -> tensor<1xi8>
-  %0 = mesh.scatter %arg0 on @mesh0 mesh_axes = [0, 2]
+  %0 = shard.scatter %arg0 on @grid0 grid_axes = [0, 2]
     scatter_axis = 0
     root = [1, %arg1]
     : (tensor<8xi8>, index) -> tensor<1xi8>
@@ -524,11 +524,11 @@ func.func @scatter_dynamic_root(
 func.func @send_static_destination(
     // CHECK-SAME: %[[ARG:.*]]: tensor<2xi8>
     %arg0 : tensor<2xi8>) -> tensor<2xi8> {
-  // CHECK-NEXT: mesh.send %[[ARG]]
-  // CHECK-SAME: on @mesh0 mesh_axes = [0, 2]
+  // CHECK-NEXT: shard.send %[[ARG]]
+  // CHECK-SAME: on @grid0 grid_axes = [0, 2]
   // CHECK-SAME: destination = [0, 1]
   // CHECK-SAME: : (tensor<2xi8>) -> tensor<2xi8>
-  %0 = mesh.send %arg0 on @mesh0 mesh_axes = [0, 2]
+  %0 = shard.send %arg0 on @grid0 grid_axes = [0, 2]
     destination = [0, 1]
     : (tensor<2xi8>) -> tensor<2xi8>
   return %0 : tensor<2xi8>
@@ -541,11 +541,11 @@ func.func @send_dynamic_destination(
     // CHECK-SAME: %[[ARG1:.*]]: index
     %arg1 : index
     ) -> tensor<2xi8> {
-  // CHECK-NEXT: mesh.send %[[ARG0]]
-  // CHECK-SAME: on @mesh0 mesh_axes = [0, 2]
+  // CHECK-NEXT: shard.send %[[ARG0]]
+  // CHECK-SAME: on @grid0 grid_axes = [0, 2]
   // CHECK-SAME: destination = [1, %[[ARG1]]]
   // CHECK-SAME: : (tensor<2xi8>, index) -> tensor<2xi8>
-  %0 = mesh.send %arg0 on @mesh0 mesh_axes = [0, 2]
+  %0 = shard.send %arg0 on @grid0 grid_axes = [0, 2]
     destination = [1, %arg1]
     : (tensor<2xi8>, index) -> tensor<2xi8>
   return %0 : tensor<2xi8>
@@ -555,11 +555,11 @@ func.func @send_dynamic_destination(
 func.func @shift(
     // CHECK-SAME: %[[ARG:.*]]: tensor<2xi8>
     %arg0 : tensor<2xi8>) -> tensor<2xi8> {
-  // CHECK-NEXT: mesh.shift %[[ARG]]
-  // CHECK-SAME: on @mesh0 mesh_axes = [0, 2]
+  // CHECK-NEXT: shard.shift %[[ARG]]
+  // CHECK-SAME: on @grid0 grid_axes = [0, 2]
   // CHECK-SAME: shift_axis = 2 offset = -2 rotate
   // CHECK-SAME: : tensor<2xi8> -> tensor<2xi8>
-  %0 = mesh.shift %arg0 on @mesh0 mesh_axes = [0, 2]
+  %0 = shard.shift %arg0 on @grid0 grid_axes = [0, 2]
     shift_axis = 2 offset = -2 rotate
     : tensor<2xi8> -> tensor<2xi8>
   return %0 : tensor<2xi8>
@@ -570,16 +570,16 @@ func.func @update_halo(
     // CHECK-SAME: %[[ARG:.*]]: memref<12x12xi8>
     %arg0 : memref<12x12xi8>) {
   // CHECK-NEXT: %[[C2:.*]] = arith.constant 2 : i64
-  // CHECK-NEXT: %[[UH1:.*]] = mesh.update_halo %[[ARG]] on @mesh0
+  // CHECK-NEXT: %[[UH1:.*]] = shard.update_halo %[[ARG]] on @grid0
   // CHECK-SAME: split_axes = {{\[\[}}0]]
   // CHECK-SAME: halo_sizes = [2, %c2_i64] : memref<12x12xi8>
   %c2 = arith.constant 2 : i64
-  %uh1 = mesh.update_halo %arg0 on @mesh0 split_axes = [[0]]
+  %uh1 = shard.update_halo %arg0 on @grid0 split_axes = [[0]]
     halo_sizes = [2, %c2] : memref<12x12xi8>
-  // CHECK-NEXT: %[[UH2:.*]] = mesh.update_halo %[[UH1]] on @mesh0
+  // CHECK-NEXT: %[[UH2:.*]] = shard.update_halo %[[UH1]] on @grid0
   // CHECK-SAME: split_axes = {{\[\[}}0], [1]]
   // CHECK-SAME: halo_sizes = [2, 2, %[[C2]], 2] : memref<12x12xi8>
-  %uh2 = mesh.update_halo %uh1 on @mesh0 split_axes = [[0], [1]]
+  %uh2 = shard.update_halo %uh1 on @grid0 split_axes = [[0], [1]]
     halo_sizes = [2, 2, %c2, 2] : memref<12x12xi8>
   return
 }
diff --git a/mlir/test/Dialect/Shard/partition.mlir b/mlir/test/Dialect/Shard/partition.mlir
new file mode 100644
index 0000000000000..c2572cc3b987b
--- /dev/null
+++ b/mlir/test/Dialect/Shard/partition.mlir
@@ -0,0 +1,317 @@
+// RUN: mlir-opt \
+// RUN:   --pass-pipeline="builtin.module(func.func(shard-partition,test-single-fold))" \
+// RUN:   %s | FileCheck %s
+
+shard.grid @grid_1d(shape = 2)
+
+// CHECK-LABEL: func @return_sharding
+func.func @return_sharding(
+  // CHECK-SAME: [[ARG:%.*]]: tensor<1xf32>
+  %arg0: tensor<2xf32>
+// CHECK-SAME: ) -> (tensor<1xf32>, !shard.sharding) {
+) -> (tensor<2xf32>, !shard.sharding) {
+  %ssharded = shard.sharding @grid_1d split_axes = [[0]] : !shard.sharding
+  %sharded = shard.shard %arg0 to %ssharded  : tensor<2xf32>
+  // CHECK-NEXT: [[vsharding:%.*]] = shard.sharding @grid_1d split_axes = {{\[\[}}0]] : !shard.sharding
+  %r = shard.get_sharding %sharded : tensor<2xf32> -> !shard.sharding
+  // CHECK-NEXT: return [[ARG]], [[vsharding]] : tensor<1xf32>, !shard.sharding
+  return %sharded, %r : tensor<2xf32>, !shard.sharding
+}
+
+// CHECK-LABEL: func @full_replication
+func.func @full_replication(
+  // CHECK-SAME: %[[ARG:.*]]: tensor<2xi8>
+  %arg0: tensor<2xi8>
+// CHECK-SAME: -> tensor<2xi8> {
+) -> tensor<2xi8> {
+  %s0 = shard.sharding @grid_1d split_axes = [[]] : !shard.sharding
+  %0 = shard.shard %arg0 to %s0  : tensor<2xi8>
+  %s1 = shard.sharding @grid_1d split_axes = [[]] : !shard.sharding
+  %1 = shard.shard %0 to %s1  annotate_for_users : tensor<2xi8>
+  // CHECK: return %[[ARG]] : tensor<2xi8>
+  return %1 : tensor<2xi8>
+}
+
+// CHECK-LABEL: func @sharding_triplet
+func.func @sharding_triplet(
+  // CHECK-SAME: %[[ARG:.*]]: tensor<1xf32>
+  %arg0: tensor<2xf32>
+// CHECK-SAME: ) -> tensor<2xf32> {
+) -> tensor<2xf32> {
+  // CHECK: %[[ALL_GATHER:.*]] = shard.all_gather %[[ARG]] on @grid_1d grid_axes = [0] gather_axis = 0 : tensor<1xf32> -> tensor<2xf32>
+  %ssharded = shard.sharding @grid_1d split_axes = [[0]] : !shard.sharding
+  %sharded = shard.shard %arg0 to %ssharded  : tensor<2xf32>
+  %ssharded_0 = shard.sharding @grid_1d split_axes = [[0]] : !shard.sharding
+  %sharded_0 = shard.shard %sharded to %ssharded_0  annotate_for_users : tensor<2xf32>
+  %ssharded_1 = shard.sharding @grid_1d split_axes = [[]] : !shard.sharding
+  %sharded_1 = shard.shard %sharded_0 to %ssharded_1  : tensor<2xf32>
+  // CHECK: return %[[ALL_GATHER]] : tensor<2xf32>
+  return %sharded_1 : tensor<2xf32>
+}
+
+
+// CHECK-LABEL: func @move_split_axis
+func.func @move_split_axis(
+  // CHECK-SAME: %[[ARG:.*]]: tensor<1x2xi8>
+  %arg0: tensor<2x2xi8>
+// CHECK-SAME: -> tensor<2x1xi8> {
+) -> tensor<2x2xi8> {
+  // CHECK: %[[ALL_TO_ALL:.*]] = shard.all_to_all %[[ARG]] on @grid_1d
+  // CHECK-SAME: grid_axes = [0] split_axis = 1 concat_axis = 0 : tensor<1x2xi8> -> tensor<2x1xi8>
+  %s0 = shard.sharding @grid_1d split_axes = [[0]] : !shard.sharding
+  %0 = shard.shard %arg0 to %s0  : tensor<2x2xi8>
+  %s1 = shard.sharding @grid_1d split_axes = [[], [0]] : !shard.sharding
+  %1 = shard.shard %0 to %s1  annotate_for_users : tensor<2x2xi8>
+  // CHECK: return %[[ALL_TO_ALL]] : tensor<2x1xi8>
+  return %1 : tensor<2x2xi8>
+}
+
+// CHECK-LABEL: func @non_tensor_value
+func.func @non_tensor_value(
+  // CHECK-SAME: %[[ARG:.*]]: i8
+  %arg0: i8
+// CHECK-SAME: -> i8 {
+) -> i8 {
+  // CHECK: %[[RES:.*]] = arith.addi %[[ARG]], %[[ARG]] : i8
+  %0 = arith.addi %arg0, %arg0 : i8
+  // CHECK: return %[[RES]] : i8
+  return %0 : i8
+}
+
+// CHECK-LABEL: func @unary_elementwise
+func.func @unary_elementwise(
+  // CHECK-SAME: %[[ARG:.*]]: tensor<1xi8>
+  %arg0: tensor<2xi8>
+// CHECK-SAME: -> tensor<1xi8> {
+) -> tensor<2xi8> {
+  %s0 = shard.sharding @grid_1d split_axes = [[0]] : !shard.sharding
+  %0 = shard.shard %arg0 to %s0  : tensor<2xi8>
+  %s1 = shard.sharding @grid_1d split_axes = [[0]] : !shard.sharding
+  %1 = shard.shard %0 to %s1  annotate_for_users : tensor<2xi8>
+  // CHECK: %[[RES:.*]] = tosa.abs %[[ARG]] : (tensor<1xi8>) -> tensor<1xi8>
+  %2 = tosa.abs %1 : (tensor<2xi8>) -> tensor<2xi8>
+  %s3 = shard.sharding @grid_1d split_axes = [[0]] : !shard.sharding
+  %3 = shard.shard %2 to %s3  : tensor<2xi8>
+  %s4 = shard.sharding @grid_1d split_axes = [[0]] : !shard.sharding
+  %4 = shard.shard %3 to %s4  annotate_for_users : tensor<2xi8>
+  // CHECK: return %[[RES]] : tensor<1xi8>
+  return %4 : tensor<2xi8>
+}
+
+// full replication -> shard axis -> abs -> shard axis -> full replication
+// CHECK-LABEL: func @unary_elementwise_with_resharding
+func.func @unary_elementwise_with_resharding(
+  // CHECK-SAME: %[[ARG:.*]]: tensor<2xi8>
+  %arg0: tensor<2xi8>
+// CHECK-SAME: -> tensor<2xi8> {
+) -> tensor<2xi8> {
+  // CHECK: %[[SLICE:.*]] = shard.all_slice %[[ARG]] on @grid_1d grid_axes = [0] slice_axis = 0
+  // CHECK-SAME: tensor<2xi8> -> tensor<1xi8>
+  %s0 = shard.sharding @grid_1d split_axes = [[]] : !shard.sharding
+  %0 = shard.shard %arg0 to %s0  : tensor<2xi8>
+  %s1 = shard.sharding @grid_1d split_axes = [[0]] : !shard.sharding
+  %1 = shard.shard %0 to %s1  annotate_for_users : tensor<2xi8>
+  // CHECK: %[[ABS:.*]] = tosa.abs %[[SLICE]] : (tensor<1xi8>) -> tensor<1xi8>
+  %2 = tosa.abs %1 : (tensor<2xi8>) -> tensor<2xi8>
+  // CHECK: %[[RES:.*]] = shard.all_gather %[[ABS]] on @grid_1d
+  // CHECK-SAME: grid_axes = [0] gather_axis = 0 : tensor<1xi8> -> tensor<2xi8>
+  %s3 = shard.sharding @grid_1d split_axes = [[0]] : !shard.sharding
+  %3 = shard.shard %2 to %s3  : tensor<2xi8>
+  %s4 = shard.sharding @grid_1d split_axes = [[]] : !shard.sharding
+  %4 = shard.shard %3 to %s4  annotate_for_users : tensor<2xi8>
+  // CHECK: return %[[RES]] : tensor<2xi8>
+  return %4 : tensor<2xi8>
+}
+
+// CHECK-LABEL: func @binary_elementwise
+func.func @binary_elementwise(
+  // CHECK-SAME: %[[ARG0:.*]]: tensor<1xi8>,
+  %arg0: tensor<2xi8>,
+  // CHECK-SAME: %[[ARG1:.*]]: tensor<1xi8>
+  %arg1: tensor<2xi8>
+// CHECK-SAME: -> tensor<1xi8> {
+) -> tensor<2xi8> {
+  %sarg0_sharded = shard.sharding @grid_1d split_axes = [[0]] : !shard.sharding
+  %arg0_sharded = shard.shard %arg0 to %sarg0_sharded  : tensor<2xi8>
+  %sop_arg0 = shard.sharding @grid_1d split_axes = [[0]] : !shard.sharding
+  %op_arg0 = shard.shard %arg0_sharded to %sop_arg0  annotate_for_users : tensor<2xi8>
+  %sarg1_sharded = shard.sharding @grid_1d split_axes = [[0]] : !shard.sharding
+  %arg1_sharded = shard.shard %arg1 to %sarg1_sharded  : tensor<2xi8>
+  %sop_arg1 = shard.sharding @grid_1d split_axes = [[0]] : !shard.sharding
+  %op_arg1 = shard.shard %arg1_sharded to %sop_arg1  annotate_for_users : tensor<2xi8>
+  // CHECK: %[[RES:.*]] = tosa.add %[[ARG0]], %[[ARG1]] : (tensor<1xi8>, tensor<1xi8>) -> tensor<1xi8>
+  %op_res = tosa.add %op_arg0, %op_arg1 : (tensor<2xi8>, tensor<2xi8>) -> tensor<2xi8>
+  %sop_res_sharded = shard.sharding @grid_1d split_axes = [[0]] : !shard.sharding
+  %op_res_sharded = shard.shard %op_res to %sop_res_sharded  : tensor<2xi8>
+  %sres = shard.sharding @grid_1d split_axes = [[0]] : !shard.sharding
+  %res = shard.shard %op_res_sharded to %sres  annotate_for_users : tensor<2xi8>
+  // CHECK: return %[[RES]] : tensor<1xi8>
+  return %res : tensor<2xi8>
+}
+
+// reshard
+// abs
+// reshard
+// abs
+// reshard
+// CHECK-LABEL: func @multiple_chained_ops
+func.func @multiple_chained_ops(
+  // CHECK-SAME: %[[ARG:.*]]: tensor<2xi8>
+  %arg0: tensor<2xi8>
+// CHECK-SAME: -> tensor<1xi8> {
+) -> tensor<2xi8> {
+  // CHECK: %[[RESHARD1:.*]] = shard.all_slice %[[ARG]] on @grid_1d grid_axes = [0] slice_axis = 0
+  // CHECK-SAME: tensor<2xi8> -> tensor<1xi8>
+  %s0 = shard.sharding @grid_1d split_axes = [[]] : !shard.sharding
+  %0 = shard.shard %arg0 to %s0  : tensor<2xi8>
+  %s1 = shard.sharding @grid_1d split_axes = [[0]] : !shard.sharding
+  %1 = shard.shard %0 to %s1  annotate_for_users : tensor<2xi8>
+  // CHECK: %[[ABS1:.*]] = tosa.abs %[[RESHARD1]] : (tensor<1xi8>) -> tensor<1xi8>
+  %2 = tosa.abs %1 : (tensor<2xi8>) -> tensor<2xi8>
+  // CHECK: %[[RESHARD2:.*]] = shard.all_gather %[[ABS1]] on @grid_1d
+  // CHECK-SAME: grid_axes = [0] gather_axis = 0 : tensor<1xi8> -> tensor<2xi8>
+  %s3 = shard.sharding @grid_1d split_axes = [[0]] : !shard.sharding
+  %3 = shard.shard %2 to %s3  : tensor<2xi8>
+  %s4 = shard.sharding @grid_1d split_axes = [[]] : !shard.sharding
+  %4 = shard.shard %3 to %s4  annotate_for_users : tensor<2xi8>
+  // CHECK: %[[ABS2:.*]] = tosa.abs %[[RESHARD2]] : (tensor<2xi8>) -> tensor<2xi8>
+  %5 = tosa.abs %4 : (tensor<2xi8>) -> tensor<2xi8>
+  // CHECK: %[[RESHARD3:.*]] = shard.all_slice %[[ABS2]] on @grid_1d grid_axes = [0] slice_axis = 0 :
+  // CHECK-SAME: tensor<2xi8> -> tensor<1xi8>
+  %s6 = shard.sharding @grid_1d split_axes = [[]] : !shard.sharding
+  %6 = shard.shard %5 to %s6  : tensor<2xi8>
+  %s7 = shard.sharding @grid_1d split_axes = [[0]] : !shard.sharding
+  %7 = shard.shard %6 to %s7  annotate_for_users : tensor<2xi8>
+  // CHECK: return %[[RESHARD3]] : tensor<1xi8>
+  return %7 : tensor<2xi8>
+}
+
+// CHECK-LABEL: func @incomplete_sharding
+func.func @incomplete_sharding(
+  // CHECK-SAME: %[[ARG:.*]]: tensor<4x16xf32>
+  %arg0: tensor<8x16xf32>
+// CHECK-SAME: -> tensor<4x16xf32> {
+) -> tensor<8x16xf32> {
+  %s0 = shard.sharding @grid_1d split_axes = [[0]] : !shard.sharding
+  %0 = shard.shard %arg0 to %s0  annotate_for_users : tensor<8x16xf32>
+  // CHECK: %[[RES:.*]] = tosa.sigmoid %[[ARG]] : (tensor<4x16xf32>) -> tensor<4x16xf32>
+  %1 = tosa.sigmoid %0 : (tensor<8x16xf32>) -> tensor<8x16xf32>
+  %s2 = shard.sharding @grid_1d split_axes = [[0]] : !shard.sharding
+  %2 = shard.shard %1 to %s2  : tensor<8x16xf32>
+  // CHECK: return %[[RES]] : tensor<4x16xf32>
+  return %2 : tensor<8x16xf32>
+}
+
+shard.grid @grid_1d_4(shape = 4)
+
+// CHECK-LABEL: func @ew_chain_with_halo
+func.func @ew_chain_with_halo(
+  // CHECK-SAME: %[[IN1:[A-Za-z0-9_]+]]: tensor<5x16xf32>
+  %arg0: tensor<8x16xf32>,
+  // CHECK-SAME: %[[IN2:[A-Za-z0-9_]+]]: tensor<1xf32>
+  %arg1: tensor<1xf32>,
+  // CHECK-SAME: %[[IN3:[A-Za-z0-9_]+]]: tensor<1xf32>
+  %arg2: tensor<1xf32>)
+  // CHECK-SAME: -> tensor<5x16xf32>
+   -> tensor<8x16xf32> {
+  %ssharded = shard.sharding @grid_1d_4 split_axes = [[0]] halo_sizes = [2, 1] : !shard.sharding
+  %sharded = shard.shard %arg0 to %ssharded  annotate_for_users : tensor<8x16xf32>
+  // CHECK: %[[TMP1:.*]] = tosa.tanh %[[IN1]] : (tensor<5x16xf32>) -> tensor<5x16xf32>
+  %0 = tosa.tanh %sharded : (tensor<8x16xf32>) -> tensor<8x16xf32>
+  %ssharded_0 = shard.sharding @grid_1d_4 split_axes = [[0]] halo_sizes = [2, 1] : !shard.sharding
+  %sharded_0 = shard.shard %0 to %ssharded_0  : tensor<8x16xf32>
+  %ssharded_1 = shard.sharding @grid_1d_4 split_axes = [[0]] halo_sizes = [2, 1] : !shard.sharding
+  %sharded_1 = shard.shard %sharded_0 to %ssharded_1  annotate_for_users : tensor<8x16xf32>
+  // CHECK-NEXT: %[[TMP2:.*]] = tosa.abs %[[TMP1]] : (tensor<5x16xf32>) -> tensor<5x16xf32>
+  %1 = tosa.abs %sharded_1 : (tensor<8x16xf32>) -> tensor<8x16xf32>
+  %ssharded_2 = shard.sharding @grid_1d_4 split_axes = [[0]] halo_sizes = [2, 1] : !shard.sharding
+  %sharded_2 = shard.shard %1 to %ssharded_2  : tensor<8x16xf32>
+  %ssharded_4 = shard.sharding @grid_1d_4 split_axes = [[0]] halo_sizes = [2, 1] : !shard.sharding
+  %sharded_4 = shard.shard %sharded_2 to %ssharded_4  annotate_for_users : tensor<8x16xf32>
+  // CHECK-NEXT: %[[TMP3:.*]] = tosa.negate %[[TMP2]], %[[IN2]], %[[IN3]] : (tensor<5x16xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<5x16xf32>
+  %sharding_1 = shard.sharding @grid_1d_4 split_axes = [[]] : !shard.sharding
+  %zero_point_1 = shard.shard %arg1 to %sharding_1 annotate_for_users : tensor<1xf32>
+  %zero_point_2 = shard.shard %arg2 to %sharding_1 annotate_for_users : tensor<1xf32>
+  %2 = tosa.negate %sharded_4, %zero_point_1, %zero_point_2 : (tensor<8x16xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<8x16xf32>
+  %ssharded_5 = shard.sharding @grid_1d_4 split_axes = [[0]] halo_sizes = [2, 1] : !shard.sharding
+  %sharded_5 = shard.shard %2 to %ssharded_5  : tensor<8x16xf32>
+  %ssharded_6 = shard.sharding @grid_1d_4 split_axes = [[0]] halo_sizes = [2, 1] : !shard.sharding
+  %sharded_6 = shard.shard %sharded_5 to %ssharded_6  annotate_for_users : tensor<8x16xf32>
+  // CHECK-NEXT: return %[[TMP3]] : tensor<5x16xf32>
+  return %sharded_6 : tensor<8x16xf32>
+}
+
+// CHECK-LABEL: func @test_shard_update_halo
+// CHECK-SAME: %[[IN1:[A-Za-z0-9_]+]]: tensor<300x1200xi64>
+func.func @test_shard_update_halo(%arg0: tensor<1200x1200xi64>) -> tensor<1200x1200xi64> {
+  %sharding = shard.sharding @grid_1d_4 split_axes = [[0]] : !shard.sharding
+  // CHECK: %[[T:.*]] = tensor.empty() : tensor<304x1200xi64>
+  // CHECK: %[[inserted_slice:.*]] = tensor.insert_slice %[[IN1]] into %[[T]][2, 0] [300, 1200] [1, 1] : tensor<300x1200xi64> into tensor<304x1200xi64>
+  // CHECK: %[[UH:.*]] = shard.update_halo %[[inserted_slice]] on @grid_1d_4 split_axes = {{\[\[0]]}} halo_sizes = [2, 2] : tensor<304x1200xi64>
+  %sharded = shard.shard %arg0 to %sharding : tensor<1200x1200xi64>
+  %sharding_0 = shard.sharding @grid_1d_4 split_axes = [[0]] halo_sizes = [2, 2] : !shard.sharding
+  %sharded_1 = shard.shard %sharded to %sharding_0 : tensor<1200x1200xi64>
+  %sharded_3 = shard.shard %sharded_1 to %sharding_0 annotate_for_users : tensor<1200x1200xi64>
+  // CHECK: return %[[UH]] : tensor<304x1200xi64>
+  return %sharded_3 : tensor<1200x1200xi64>
+}
+
+shard.grid @grid4x4(shape = 4x4)
+// CHECK-LABEL: func @test_shard_update_halo2d
+// CHECK-SAME: %[[IN1:[A-Za-z0-9_]+]]: tensor<300x300xi64>
+func.func @test_shard_update_halo2d(%arg0: tensor<1200x1200xi64>) -> tensor<1200x1200xi64> {
+  %sharding = shard.sharding @grid4x4 split_axes = [[0], [1]] : !shard.sharding
+  // CHECK: %[[T:.*]] = tensor.empty() : tensor<303x307xi64>
+  // CHECK: %[[inserted_slice:.*]] = tensor.insert_slice %[[IN1]] into %[[T]][1, 3] [300, 300] [1, 1] : tensor<300x300xi64> into tensor<303x307xi64>
+  // CHECK: %[[UH:.*]] = shard.update_halo %[[inserted_slice]] on @grid4x4 split_axes = {{\[\[}}0], [1]] halo_sizes = [1, 2, 3, 4] : tensor<303x307xi64>
+  %sharded = shard.shard %arg0 to %sharding : tensor<1200x1200xi64>
+  %sharding_0 = shard.sharding @grid4x4 split_axes = [[0], [1]] halo_sizes = [1, 2, 3, 4] : !shard.sharding
+  %sharded_1 = shard.shard %sharded to %sharding_0 : tensor<1200x1200xi64>
+  %sharded_3 = shard.shard %sharded_1 to %sharding_0 annotate_for_users : tensor<1200x1200xi64>
+  // CHECK: return %[[UH]] : tensor<303x307xi64>
+  return %sharded_3 : tensor<1200x1200xi64>
+}
+
+shard.grid @grid(shape = 2)
+// CHECK-LABEL: func.func @test_reduce_0d(
+// CHECK-SAME: %[[ARG0:[A-Za-z0-9_]+]]: tensor<3x6xi32>
+func.func @test_reduce_0d(%arg0: tensor<6x6xi32>) -> (tensor<i32>) {
+  %sharding = shard.sharding @grid split_axes = [[0]] : !shard.sharding
+  %sharded = shard.shard %arg0 to %sharding annotate_for_users : tensor<6x6xi32>
+  %4 = tensor.empty() : tensor<i32>
+  %sharding_out = shard.sharding @grid split_axes = [[]] : !shard.sharding
+  %sharded_out = shard.shard %4 to %sharding_out : tensor<i32>
+  %sharded_in = shard.shard %sharded to %sharding annotate_for_users : tensor<6x6xi32>
+  // CHECK: %[[reduced:.*]] = linalg.reduce ins(%arg0 : tensor<3x6xi32>)
+  %reduced = linalg.reduce ins(%sharded_in : tensor<6x6xi32>) outs(%sharded_out : tensor<i32>) dimensions = [0, 1]
+    (%in: i32, %init: i32) {
+      %6 = arith.addi %in, %init : i32
+      linalg.yield %6 : i32
+    }
+  // CHECK: %[[all_reduce:.*]] = shard.all_reduce %[[reduced]] on @grid grid_axes = [0] : tensor<i32> -> tensor<i32>
+  %sharded_red = shard.shard %reduced to %sharding_out : tensor<i32>
+  %sharded_ret = shard.shard %sharded_red to %sharding_out annotate_for_users : tensor<i32>
+  // CHECK: return %[[all_reduce]] : tensor<i32>
+  return %sharded_ret : tensor<i32>
+}
+
+// CHECK-LABEL: func.func @test_reduce_1d(
+// CHECK-SAME: %[[ARG0:[A-Za-z0-9_]+]]: tensor<3x6xi32>
+func.func @test_reduce_1d(%arg0: tensor<6x6xi32>) -> (tensor<6xi32>) {
+  %sharding = shard.sharding @grid split_axes = [[0]] : !shard.sharding
+  %sharded = shard.shard %arg0 to %sharding annotate_for_users : tensor<6x6xi32>
+  %4 = tensor.empty() : tensor<6xi32>
+  %sharded_out = shard.shard %4 to %sharding : tensor<6xi32>
+  %sharded_in = shard.shard %sharded to %sharding annotate_for_users : tensor<6x6xi32>
+  // CHECK: %[[reduced:.*]] = linalg.reduce ins(%arg0 : tensor<3x6xi32>)
+  %reduced = linalg.reduce ins(%sharded_in : tensor<6x6xi32>) outs(%sharded_out : tensor<6xi32>) dimensions = [1]
+    (%in: i32, %init: i32) {
+      %6 = arith.addi %in, %init : i32
+      linalg.yield %6 : i32
+    }
+  // CHECK-NOT: shard.all_reduce
+  %sharded_red = shard.shard %reduced to %sharding : tensor<6xi32>
+  %sharded_ret = shard.shard %sharded_red to %sharding annotate_for_users : tensor<6xi32>
+  // CHECK: return %[[reduced]] : tensor<3xi32>
+  return %sharded_ret : tensor<6xi32>
+}
diff --git a/mlir/test/Dialect/Shard/process-multi-index-op-lowering.mlir b/mlir/test/Dialect/Shard/process-multi-index-op-lowering.mlir
new file mode 100644
index 0000000000000..33c7a8f96464d
--- /dev/null
+++ b/mlir/test/Dialect/Shard/process-multi-index-op-lowering.mlir
@@ -0,0 +1,23 @@
+// RUN: mlir-opt -test-grid-process-multi-index-op-lowering %s | FileCheck %s
+
+shard.grid @grid2d(shape = ?x?)
+
+// CHECK-LABEL: func.func @multi_index_2d_grid
+func.func @multi_index_2d_grid() -> (index, index) {
+  // CHECK: %[[LINEAR_IDX:.*]] = shard.process_linear_index on @grid2d : index
+  // CHECK: %[[SHARD_SHAPE:.*]]:2 = shard.grid_shape @grid2d : index, index
+  // CHECK: %[[MULTI_IDX:.*]]:2 = affine.delinearize_index %[[LINEAR_IDX]] into (%[[SHARD_SHAPE]]#0, %[[SHARD_SHAPE]]#1) : index, index
+  %0:2 = shard.process_multi_index on @grid2d : index, index
+  // CHECK: return %[[MULTI_IDX]]#0, %[[MULTI_IDX]]#1 : index, index
+  return %0#0, %0#1 : index, index
+}
+
+// CHECK-LABEL: func.func @multi_index_2d_grid_single_inner_axis
+func.func @multi_index_2d_grid_single_inner_axis() -> index {
+  // CHECK: %[[LINEAR_IDX:.*]] = shard.process_linear_index on @grid2d : index
+  // CHECK: %[[SHARD_SHAPE:.*]]:2 = shard.grid_shape @grid2d : index, index
+  // CHECK: %[[MULTI_IDX:.*]]:2 = affine.delinearize_index %[[LINEAR_IDX]] into (%[[SHARD_SHAPE]]#0, %[[SHARD_SHAPE]]#1) : index, index
+  %0 = shard.process_multi_index on @grid2d axes = [0] : index
+  // CHECK: return %[[MULTI_IDX]]#0 : index
+  return %0 : index
+}
diff --git a/mlir/test/Dialect/Shard/resharding-partition.mlir b/mlir/test/Dialect/Shard/resharding-partition.mlir
new file mode 100644
index 0000000000000..ff9e8408aa7fd
--- /dev/null
+++ b/mlir/test/Dialect/Shard/resharding-partition.mlir
@@ -0,0 +1,168 @@
+// RUN: mlir-opt -test-grid-resharding-partition %s | FileCheck %s
+
+shard.grid @grid_1d(shape = 2)
+shard.grid @grid_1d_dynamic(shape = ?)
+
+// CHECK-LABEL: func @same_source_and_target_sharding
+func.func @same_source_and_target_sharding(
+  // CHECK-SAME: %[[ARG:.*]]: tensor<2xf32>
+  %arg0: tensor<2xf32>
+) -> tensor<2xf32> {
+  %s0 = shard.sharding @grid_1d split_axes = [[]] : !shard.sharding
+  %0 = shard.shard %arg0 to %s0 : tensor<2xf32>
+  %s1 = shard.sharding @grid_1d split_axes = [[]] : !shard.sharding
+  %1 = shard.shard %0 to %s1 annotate_for_users : tensor<2xf32>
+  // CHECK: return %[[ARG]]
+  return %1 : tensor<2xf32>
+}
+
+// CHECK-LABEL: func @identical_source_and_target_sharding
+func.func @identical_source_and_target_sharding(
+  // CHECK-SAME: %[[ARG:.*]]: tensor<2xf32>
+  %arg0: tensor<2xf32>
+) -> tensor<2xf32> {
+  %s0 = shard.sharding @grid_1d split_axes = [[]] : !shard.sharding
+  %0 = shard.shard %arg0 to %s0 : tensor<2xf32>
+  %1 = shard.shard %0 to %s0 annotate_for_users : tensor<2xf32>
+  // CHECK: return %[[ARG]]
+  return %1 : tensor<2xf32>
+}
+
+// CHECK-LABEL: func @split_replicated_tensor_axis
+func.func @split_replicated_tensor_axis(
+  // CHECK-SAME: %[[ARG:.*]]: tensor<3x14xf32>
+  %arg0: tensor<3x14xf32>
+) -> tensor<3x14xf32> {
+  // CHECK: %[[ALL_SLICE:.*]] = shard.all_slice %[[ARG]] on @grid_1d grid_axes = [0] slice_axis = 1
+  // CHECK-SAME: tensor<3x14xf32> -> tensor<3x7xf32>
+  // CHECK: %[[RESULT:.*]] = builtin.unrealized_conversion_cast %[[ALL_SLICE]] : tensor<3x7xf32> to tensor<3x14xf32>
+  %s0 = shard.sharding @grid_1d split_axes = [[]] : !shard.sharding
+  %0 = shard.shard %arg0 to %s0 : tensor<3x14xf32>
+  %s1 = shard.sharding @grid_1d split_axes = [[], [0]] : !shard.sharding
+  %1 = shard.shard %0 to %s1 annotate_for_users : tensor<3x14xf32>
+  // CHECK: return %[[RESULT]] : tensor<3x14xf32>
+  return %1 : tensor<3x14xf32>
+}
+
+// CHECK-LABEL: func @split_replicated_tensor_axis_dynamic
+func.func @split_replicated_tensor_axis_dynamic(
+  // CHECK-SAME: %[[ARG:.*]]: tensor<?x3x?xf32>
+  %arg0: tensor<?x3x?xf32>
+) -> tensor<?x3x?xf32> {
+  // CHECK: %[[RESULT:.*]] = shard.all_slice %[[ARG]] on @grid_1d_dynamic grid_axes = [0] slice_axis = 0
+  // CHECK-SAME: tensor<?x3x?xf32> -> tensor<?x3x?xf32>
+  %s0 = shard.sharding @grid_1d_dynamic split_axes = [[], [], []] : !shard.sharding
+  %0 = shard.shard %arg0 to %s0 : tensor<?x3x?xf32>
+  %s1 = shard.sharding @grid_1d_dynamic split_axes = [[0]] : !shard.sharding
+  %1 = shard.shard %0 to %s1 annotate_for_users : tensor<?x3x?xf32>
+  // CHECK: return %[[RESULT]] : tensor<?x3x?xf32>
+  return %1 : tensor<?x3x?xf32>
+}
+
+// CHECK-LABEL: func @move_split_axis
+func.func @move_split_axis(
+  // CHECK-SAME: %[[ARG:.*]]: tensor<10x14xf32>
+  %arg0: tensor<10x14xf32>
+) -> tensor<10x14xf32> {
+  // CHECK: %[[SOURCE_SHARD:.*]] = builtin.unrealized_conversion_cast %[[ARG]] : tensor<10x14xf32> to tensor<5x14xf32>
+  // CHECK: %[[TARGET_SHARD:.*]] = shard.all_to_all %[[SOURCE_SHARD]] on @grid_1d grid_axes = [0] split_axis = 1 concat_axis = 0 : tensor<5x14xf32> -> tensor<10x7xf32>
+  // CHECK: %[[RES:.*]] = builtin.unrealized_conversion_cast %[[TARGET_SHARD]] : tensor<10x7xf32> to tensor<10x14xf32>
+  %s0 = shard.sharding @grid_1d split_axes = [[0]] : !shard.sharding
+  %0 = shard.shard %arg0 to %s0 : tensor<10x14xf32>
+  %s1 = shard.sharding @grid_1d split_axes = [[], [0]] : !shard.sharding
+  %1 = shard.shard %0 to %s1 annotate_for_users : tensor<10x14xf32>
+  // CHECK: return %[[RES]] : tensor<10x14xf32>
+  return %1 : tensor<10x14xf32>
+}
+
+// CHECK-LABEL: func @move_split_axis_dynamic_grid
+func.func @move_split_axis_dynamic_grid(
+  // CHECK-SAME: %[[ARG:.*]]: tensor<10x14xf32>
+  %arg0: tensor<10x14xf32>
+) -> tensor<10x14xf32> {
+  // CHECK: %[[SOURCE_SHARD:.*]] = builtin.unrealized_conversion_cast %[[ARG]] : tensor<10x14xf32> to tensor<?x14xf32>
+  // CHECK: %[[ALL_TO_ALL:.*]] = shard.all_to_all %[[SOURCE_SHARD]] on @grid_1d_dynamic grid_axes = [0] split_axis = 1 concat_axis = 0 : tensor<?x14xf32> -> tensor<?x?xf32>
+  // CHECK: %[[TARGET_SHARD:.*]] = tensor.cast %[[ALL_TO_ALL]] : tensor<?x?xf32> to tensor<10x?xf32>
+  // CHECK: %[[RES:.*]] = builtin.unrealized_conversion_cast %[[TARGET_SHARD]] : tensor<10x?xf32> to tensor<10x14xf32>
+  %s0 = shard.sharding @grid_1d_dynamic split_axes = [[0]] : !shard.sharding
+  %0 = shard.shard %arg0 to %s0 : tensor<10x14xf32>
+  %s1 = shard.sharding @grid_1d_dynamic split_axes = [[], [0]] : !shard.sharding
+  %1 = shard.shard %0 to %s1 annotate_for_users : tensor<10x14xf32>
+  // CHECK: return %[[RES]] : tensor<10x14xf32>
+  return %1 : tensor<10x14xf32>
+}
+
+// CHECK-LABEL: func @move_split_dynamic_axis
+func.func @move_split_dynamic_axis(
+  // CHECK-SAME: %[[ARG:.*]]: tensor<?x14xf32>
+  %arg0: tensor<?x14xf32>
+) -> tensor<?x14xf32> {
+  // CHECK: %[[TARGET_SHARD:.*]] = shard.all_to_all %[[ARG]] on @grid_1d grid_axes = [0] split_axis = 1 concat_axis = 0 : tensor<?x14xf32> -> tensor<?x7xf32>
+  // CHECK: %[[RES:.*]] = builtin.unrealized_conversion_cast %[[TARGET_SHARD]] : tensor<?x7xf32> to tensor<?x14xf32>
+  %s0 = shard.sharding @grid_1d split_axes = [[0]] : !shard.sharding
+  %0 = shard.shard %arg0 to %s0 : tensor<?x14xf32>
+  %s1 = shard.sharding @grid_1d split_axes = [[], [0]] : !shard.sharding
+  %1 = shard.shard %0 to %s1 annotate_for_users : tensor<?x14xf32>
+  // CHECK: return %[[RES]] : tensor<?x14xf32>
+  return %1 : tensor<?x14xf32>
+}
+
+// CHECK-LABEL: func @unshard_static_axis
+func.func @unshard_static_axis(
+  // CHECK-SAME: %[[ARG:.*]]: tensor<10x14xf32>
+  %arg0: tensor<10x14xf32>
+) -> tensor<10x14xf32> {
+  // CHECK: %[[SOURCE_SHARD:.*]] = builtin.unrealized_conversion_cast %[[ARG]] : tensor<10x14xf32> to tensor<5x14xf32>
+  // CHECK: %[[ALL_GATHER:.*]] = shard.all_gather %[[SOURCE_SHARD]] on @grid_1d grid_axes = [0] gather_axis = 0 : tensor<5x14xf32> -> tensor<10x14xf32>
+  %s0 = shard.sharding @grid_1d split_axes = [[0]] : !shard.sharding
+  %0 = shard.shard %arg0 to %s0 : tensor<10x14xf32>
+  %s1 = shard.sharding @grid_1d split_axes = [[]] : !shard.sharding
+  %1 = shard.shard %0 to %s1 annotate_for_users : tensor<10x14xf32>
+  // CHECK: return %[[ALL_GATHER]] : tensor<10x14xf32>
+  return %1 : tensor<10x14xf32>
+}
+
+// CHECK-LABEL: func @unshard_static_last_axis
+func.func @unshard_static_last_axis(
+  // CHECK-SAME: %[[ARG:.*]]: tensor<10x14xf32>
+  %arg0: tensor<10x14xf32>
+) -> tensor<10x14xf32> {
+  // CHECK: %[[SOURCE_SHARD:.*]] = builtin.unrealized_conversion_cast %[[ARG]] : tensor<10x14xf32> to tensor<10x7xf32>
+  // CHECK: %[[ALL_GATHER:.*]] = shard.all_gather %[[SOURCE_SHARD]] on @grid_1d grid_axes = [0] gather_axis = 1 : tensor<10x7xf32> -> tensor<10x14xf32>
+  %s0 = shard.sharding @grid_1d split_axes = [[], [0]] : !shard.sharding
+  %0 = shard.shard %arg0 to %s0 : tensor<10x14xf32>
+  %s1 = shard.sharding @grid_1d split_axes = [[], []] : !shard.sharding
+  %1 = shard.shard %0 to %s1 annotate_for_users : tensor<10x14xf32>
+  // CHECK: return %[[ALL_GATHER]] : tensor<10x14xf32>
+  return %1 : tensor<10x14xf32>
+}
+
+// CHECK-LABEL: func @unshard_dynamic_axis
+func.func @unshard_dynamic_axis(
+  // CHECK-SAME: %[[ARG:.*]]: tensor<?x14xf32>
+  %arg0: tensor<?x14xf32>
+) -> tensor<?x14xf32> {
+  // CHECK: %[[ALL_GATHER:.*]] = shard.all_gather %[[ARG]] on @grid_1d grid_axes = [0] gather_axis = 0 : tensor<?x14xf32> -> tensor<?x14xf32>
+  %s0 = shard.sharding @grid_1d split_axes = [[0]] : !shard.sharding
+  %0 = shard.shard %arg0 to %s0 : tensor<?x14xf32>
+  %s1 = shard.sharding @grid_1d split_axes = [[]] : !shard.sharding
+  %1 = shard.shard %0 to %s1 annotate_for_users : tensor<?x14xf32>
+  // CHECK: return %[[ALL_GATHER]] : tensor<?x14xf32>
+  return %1 : tensor<?x14xf32>
+}
+
+// CHECK-LABEL: func @unshard_static_axis_on_dynamic_grid_axis
+func.func @unshard_static_axis_on_dynamic_grid_axis(
+  // CHECK-SAME: %[[ARG:.*]]: tensor<10x14xf32>
+  %arg0: tensor<10x14xf32>
+) -> tensor<10x14xf32> {
+  // CHECK: %[[SOURCE_SHARD:.*]] = builtin.unrealized_conversion_cast %[[ARG]] : tensor<10x14xf32> to tensor<?x14xf32>
+  // CHECK: %[[ALL_GATHER:.*]] = shard.all_gather %[[SOURCE_SHARD]] on @grid_1d_dynamic grid_axes = [0] gather_axis = 0 : tensor<?x14xf32> -> tensor<?x14xf32>
+  // CHECK: %[[RES:.*]] = tensor.cast %[[ALL_GATHER]] : tensor<?x14xf32> to tensor<10x14xf32>
+  %s0 = shard.sharding @grid_1d_dynamic split_axes = [[0]] : !shard.sharding
+  %0 = shard.shard %arg0 to %s0 : tensor<10x14xf32>
+  %s1 = shard.sharding @grid_1d_dynamic split_axes = [[]] : !shard.sharding
+  %1 = shard.shard %0 to %s1 annotate_for_users : tensor<10x14xf32>
+  // CHECK: return %[[RES]] : tensor<10x14xf32>
+  return %1 : tensor<10x14xf32>
+}
diff --git a/mlir/test/Dialect/Mesh/sharding-propagation-failed.mlir b/mlir/test/Dialect/Shard/sharding-propagation-failed.mlir
similarity index 100%
rename from mlir/test/Dialect/Mesh/sharding-propagation-failed.mlir
rename to mlir/test/Dialect/Shard/sharding-propagation-failed.mlir
diff --git a/mlir/test/Dialect/Shard/sharding-propagation.mlir b/mlir/test/Dialect/Shard/sharding-propagation.mlir
new file mode 100644
index 0000000000000..34aaf0598b3f0
--- /dev/null
+++ b/mlir/test/Dialect/Shard/sharding-propagation.mlir
@@ -0,0 +1,301 @@
+// RUN: mlir-opt --pass-pipeline="builtin.module(func.func(sharding-propagation,cse))" %s | FileCheck %s
+
+shard.grid @grid_2(shape = 2)
+shard.grid @grid_1d(shape = ?)
+shard.grid @grid_2d(shape = 2x4)
+shard.grid @grid_3d(shape = ?x?x?)
+
+// CHECK-LABEL: func.func @element_wise_empty_sharding_info
+func.func @element_wise_empty_sharding_info(%arg0: tensor<8x16xf32>) -> tensor<8x16xf32> {
+  // CHECK-NEXT: tosa.sigmoid
+  %0 = tosa.sigmoid %arg0 : (tensor<8x16xf32>) -> tensor<8x16xf32>
+  // CHECK-NEXT: return
+  return %0 : tensor<8x16xf32>
+}
+
+// CHECK-LABEL: func.func @element_wise_on_def
+// CHECK-SAME:    %[[ARG:.*]]: tensor<8x16xf32>
+func.func @element_wise_on_def(%arg0: tensor<8x16xf32>) -> tensor<8x16xf32> {
+  // CHECK-NEXT:  %[[S0:.*]] = shard.sharding @grid_2d split_axes = {{\[\[}}0], [1]] : !shard.sharding
+  // CHECK-NEXT:  %[[V0:.*]] = shard.shard %[[ARG]] to %[[S0]] annotate_for_users  : tensor<8x16xf32>
+  // CHECK-NEXT:  %[[V1:.*]] = tosa.sigmoid %[[V0]]
+  %0 = tosa.sigmoid %arg0 : (tensor<8x16xf32>) -> tensor<8x16xf32>
+  // CHECK-NEXT:  %[[V2:.*]] = shard.shard %[[V1]] to %[[S0]]  : tensor<8x16xf32>
+  %s1 = shard.sharding @grid_2d split_axes = [[0], [1]] : !shard.sharding
+  %1 = shard.shard %0 to %s1  : tensor<8x16xf32>
+  // CHECK-NEXT:  return %[[V2]]
+  return %1 : tensor<8x16xf32>
+}
+
+// CHECK-LABEL: func.func @element_wise_on_use
+// CHECK-SAME:    %[[ARG:.*]]: tensor<8x16xf32>
+func.func @element_wise_on_use(%arg0: tensor<8x16xf32>) -> tensor<8x16xf32> {
+  // CHECK-NEXT:  %[[S0:.*]] = shard.sharding @grid_2d split_axes = {{\[\[}}0], [1]] : !shard.sharding
+  // CHECK-NEXT:  %[[V0:.*]] = shard.shard %[[ARG]] to %[[S0]] annotate_for_users  : tensor<8x16xf32>
+  %s0 = shard.sharding @grid_2d split_axes = [[0], [1]] : !shard.sharding
+  %0 = shard.shard %arg0 to %s0 annotate_for_users  : tensor<8x16xf32>
+  // CHECK-NEXT:  %[[V1:.*]] = tosa.sigmoid %[[V0]]
+  %1 = tosa.sigmoid %0 : (tensor<8x16xf32>) -> tensor<8x16xf32>
+  // CHECK-NEXT:  %[[V2:.*]] = shard.shard %[[V1]] to %[[S0]]  : tensor<8x16xf32>
+  // CHECK-NEXT:  return %[[V2]]
+  return %1 : tensor<8x16xf32>
+}
+
+// CHECK-LABEL: func.func @element_wise_on_graph_output
+// CHECK-SAME:    %[[ARG:.*]]: tensor<8x16xf32>
+func.func @element_wise_on_graph_output(%arg0: tensor<8x16xf32>) -> tensor<8x16xf32> {
+  // CHECK-NEXT:  %[[S0:.*]] = shard.sharding @grid_2d split_axes = {{\[\[}}0], [1]] : !shard.sharding
+  // CHECK-NEXT:  %[[V0:.*]] = shard.shard %[[ARG]] to %[[S0]] annotate_for_users  : tensor<8x16xf32>
+  // CHECK-NEXT:  %[[V1:.*]] = tosa.sigmoid %[[V0]]
+  %0 = tosa.sigmoid %arg0 : (tensor<8x16xf32>) -> tensor<8x16xf32>
+  // CHECK-NEXT:  %[[V2:.*]] = shard.shard %[[V1]] to %[[S0]]  : tensor<8x16xf32>
+  // CHECK-NEXT:  %[[V3:.*]] = shard.shard %[[V2]] to %[[S0]] annotate_for_users  : tensor<8x16xf32>
+  %s1 = shard.sharding @grid_2d split_axes = [[0], [1]] : !shard.sharding
+  %1 = shard.shard %0 to %s1 annotate_for_users  : tensor<8x16xf32>
+  // CHECK-NEXT:  return %[[V3]]
+  return %1 : tensor<8x16xf32>
+}
+
+// CHECK-LABEL: func.func @element_wise_on_graph_input
+// CHECK-SAME:    %[[ARG:.*]]: tensor<8x16xf32>
+func.func @element_wise_on_graph_input(%arg0: tensor<8x16xf32>) -> tensor<8x16xf32> {
+  // CHECK-NEXT:  %[[S0:.*]] = shard.sharding @grid_2d split_axes = {{\[\[}}0], [1]] : !shard.sharding
+  // CHECK-NEXT:  %[[V0:.*]] = shard.shard %[[ARG]] to %[[S0]]  : tensor<8x16xf32>
+  // CHECK-NEXT:  %[[V1:.*]] = shard.shard %[[V0]] to %[[S0]] annotate_for_users  : tensor<8x16xf32>
+  %s0 = shard.sharding @grid_2d split_axes = [[0], [1]] : !shard.sharding
+  %0 = shard.shard %arg0 to %s0  : tensor<8x16xf32>
+  // CHECK-NEXT:  %[[V2:.*]] = tosa.sigmoid %[[V1]]
+  %1 = tosa.sigmoid %0 : (tensor<8x16xf32>) -> tensor<8x16xf32>
+  // CHECK-NEXT:  %[[V3:.*]] = shard.shard %[[V2]] to %[[S0]]  : tensor<8x16xf32>
+  // CHECK-NEXT:  return %[[V3]]
+  return %1 : tensor<8x16xf32>
+}
+
+// CHECK-LABEL: func.func @arrow_structure
+// CHECK-SAME:    %[[ARG:.*]]: tensor<8x16xf32>
+func.func @arrow_structure(%arg0: tensor<8x16xf32>, %arg1: tensor<1xf32>, %arg2: tensor<1xf32>) -> (tensor<8x16xf32>, tensor<8x16xf32>) {
+  // CHECK-NEXT:  %[[S1:.*]] = shard.sharding @grid_2d split_axes = {{\[\[}}0], [1]] : !shard.sharding
+  // CHECK-NEXT:  %[[V1:.*]] = shard.shard %[[ARG]] to %[[S1]] annotate_for_users  : tensor<8x16xf32>
+  // CHECK-NEXT:  %[[V2:.*]] = tosa.tanh %[[V1]]
+  // CHECK-NEXT:  %[[V3:.*]] = shard.shard %[[V2]] to %[[S1]]  : tensor<8x16xf32>
+  %0 = tosa.tanh %arg0 : (tensor<8x16xf32>) -> tensor<8x16xf32>
+  // CHECK-NEXT:  %[[V4:.*]] = shard.shard %[[V3]] to %[[S1]] annotate_for_users  : tensor<8x16xf32>
+  // CHECK-NEXT:  %[[V5:.*]] = tosa.abs %[[V4]]
+  // CHECK-NEXT:  %[[V6:.*]] = shard.shard %[[V5]] to %[[S1]]  : tensor<8x16xf32>
+  %1 = tosa.abs %0: (tensor<8x16xf32>) -> tensor<8x16xf32>
+  // CHECK-NEXT:  %[[S3:.*]] = shard.sharding @grid_2d split_axes = {{\[\[}}]] : !shard.sharding
+  // CHECK-NEXT:  %[[ZP1:.*]] = shard.shard %arg1 to %[[S3]] annotate_for_users : tensor<1xf32>
+  // CHECK-NEXT:  %[[ZP2:.*]] = shard.shard %arg2 to %[[S3]] annotate_for_users : tensor<1xf32>
+  // CHECK-NEXT:  %[[V7:.*]] = tosa.negate %[[V4]], %[[ZP1]], %[[ZP2]]
+  // CHECK-NEXT:  %[[V8:.*]] = shard.shard %[[V7]] to %[[S1]]  : tensor<8x16xf32>
+  %2 = tosa.negate %0, %arg1, %arg2 : (tensor<8x16xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<8x16xf32>
+  %s3 = shard.sharding @grid_2d split_axes = [[0], [1]] : !shard.sharding
+  %3 = shard.shard %2 to %s3  : tensor<8x16xf32>
+  // CHECK-NEXT: return %[[V6]], %[[V8]]
+  return %1, %3 : tensor<8x16xf32>, tensor<8x16xf32>
+}
+
+// CHECK-LABEL: func.func @matmul_on_def_shard_batch_and_m
+// CHECK-SAME:     %[[ARG0:.*]]: tensor<2x16x8xf32>, %[[ARG1:.*]]: tensor<2x8x32xf32>, %[[ARG2:.*]]: tensor<1xf32>
+func.func @matmul_on_def_shard_batch_and_m(%arg0: tensor<2x16x8xf32>, %arg1: tensor<2x8x32xf32>, %arg2: tensor<1xf32>) -> tensor<2x16x32xf32> {
+  // CHECK-NEXT:  %[[S0:.*]] = shard.sharding @grid_2d split_axes = {{\[\[}}0], [1]] : !shard.sharding
+  // CHECK-NEXT:  %[[V0:.*]] = shard.shard %[[ARG0]] to %[[S0]] annotate_for_users  : tensor<2x16x8xf32>
+  // CHECK-NEXT:  %[[S1:.*]] = shard.sharding @grid_2d split_axes = {{\[\[}}0]] : !shard.sharding
+  // CHECK-NEXT:  %[[V1:.*]] = shard.shard %[[ARG1]] to %[[S1]] annotate_for_users  : tensor<2x8x32xf32>
+  // CHECK-NEXT:  %[[S2:.*]] = shard.sharding @grid_2d split_axes = {{\[\[}}]] : !shard.sharding
+  // CHECK-NEXT:  %[[ZP:.*]] = shard.shard %[[ARG2]] to %[[S2]] annotate_for_users  : tensor<1xf32>
+  // CHECK-NEXT:  %[[V2:.*]] = tosa.matmul %[[V0]], %[[V1]], %[[ZP]], %[[ZP]]
+  %0 = tosa.matmul %arg0, %arg1, %arg2, %arg2 : (tensor<2x16x8xf32>, tensor<2x8x32xf32>, tensor<1xf32>, tensor<1xf32>)  -> tensor<2x16x32xf32>
+  // CHECK-NEXT:  %[[V3:.*]] = shard.shard %[[V2]] to %[[S0]]  : tensor<2x16x32xf32>
+  %s1 = shard.sharding @grid_2d split_axes = [[0], [1]] : !shard.sharding
+  %1 = shard.shard %0 to %s1  : tensor<2x16x32xf32>
+  // CHECK-NEXT:  return %[[V3]]
+  return %1 : tensor<2x16x32xf32>
+}
+
+// CHECK-LABEL: func.func @matmul_on_def_shard_m_and_n
+// CHECK-SAME:     [[varg0:%.*]]: tensor<2x16x8xf32>, [[varg1:%.*]]: tensor<2x8x32xf32>, [[varg2:%.*]]: tensor<1xf32>
+func.func @matmul_on_def_shard_m_and_n(%arg0: tensor<2x16x8xf32>, %arg1: tensor<2x8x32xf32>, %arg2: tensor<1xf32>) -> tensor<2x16x32xf32> {
+  // CHECK: [[vsharding:%.*]] = shard.sharding @grid_2d split_axes = {{\[\[}}], [0]] : !shard.sharding
+  // CHECK: [[vsharded:%.*]] = shard.shard [[varg0]] to [[vsharding]] annotate_for_users : tensor<2x16x8xf32>
+  // CHECK: [[vsharding_0:%.*]] = shard.sharding @grid_2d split_axes = {{\[\[}}], [], [1]] : !shard.sharding
+  // CHECK: [[vsharded_1:%.*]] = shard.shard [[varg1]] to [[vsharding_0]] annotate_for_users : tensor<2x8x32xf32>
+  // CHECK: [[vsharding_2:%.*]] = shard.sharding @grid_2d split_axes = {{\[\[}}]] : !shard.sharding
+  // CHECK: [[vsharded_3:%.*]] = shard.shard [[varg2]] to [[vsharding_2]] annotate_for_users : tensor<1xf32>
+  // CHECK: [[v0:%.*]] = tosa.matmul
+  %0 = tosa.matmul %arg0, %arg1, %arg2, %arg2 : (tensor<2x16x8xf32>, tensor<2x8x32xf32>, tensor<1xf32>, tensor<1xf32>)  -> tensor<2x16x32xf32>
+  // CHECK: [[vsharding_4:%.*]] = shard.sharding @grid_2d split_axes = {{\[\[}}], [0], [1]] : !shard.sharding
+  // CHECK: [[vsharded_5:%.*]] = shard.shard [[v0]] to [[vsharding_4]] : tensor<2x16x32xf32>
+  %s1 = shard.sharding @grid_2d split_axes = [[], [0], [1]] : !shard.sharding
+  %1 = shard.shard %0 to %s1  : tensor<2x16x32xf32>
+  // CHECK-NEXT:  return [[vsharded_5]]
+  return %1 : tensor<2x16x32xf32>
+}
+
+// CHECK-LABEL: func.func @matmul_on_use_shard_m_and_k
+// CHECK-SAME:     [[varg0:%.*]]: tensor<2x16x8xf32>, [[varg1:%.*]]: tensor<2x8x32xf32>, [[varg2:%.*]]: tensor<1xf32>
+func.func @matmul_on_use_shard_m_and_k(%arg0: tensor<2x16x8xf32>, %arg1: tensor<2x8x32xf32>, %arg2: tensor<1xf32>) -> tensor<2x16x32xf32> {
+  // CHECK: [[vsharding:%.*]] = shard.sharding @grid_2d split_axes = {{\[\[}}], [0], [1]] : !shard.sharding
+  %s0 = shard.sharding @grid_2d split_axes = [[], [0], [1]] : !shard.sharding
+  // CHECK: [[vsharded:%.*]] = shard.shard [[varg0]] to [[vsharding]] : tensor<2x16x8xf32>
+  %arg0_s = shard.shard %arg0 to %s0 : tensor<2x16x8xf32>
+  // CHECK: [[vsharded_0:%.*]] = shard.shard [[vsharded]] to [[vsharding]] annotate_for_users : tensor<2x16x8xf32>
+  // CHECK: [[vsharding_1:%.*]] = shard.sharding @grid_2d split_axes = {{\[\[}}], [1]] : !shard.sharding
+  // CHECK: [[vsharded_2:%.*]] = shard.shard [[varg1]] to [[vsharding_1]] annotate_for_users : tensor<2x8x32xf32>
+  // CHECK: [[vsharding_3:%.*]] = shard.sharding @grid_2d split_axes = {{\[\[}}]] : !shard.sharding
+  // CHECK: [[vsharded_4:%.*]] = shard.shard [[varg2]] to [[vsharding_3]] annotate_for_users : tensor<1xf32>
+  // CHECK: [[v0:%.*]] = tosa.matmul
+  // CHECK: [[vsharding_5:%.*]] = shard.sharding @grid_2d split_axes = {{\[\[}}], [0]] : !shard.sharding
+  // CHECK: [[vsharded_6:%.*]] = shard.shard [[v0]] to [[vsharding_5]] : tensor<2x16x32xf32>
+  %0 = tosa.matmul %arg0_s, %arg1, %arg2, %arg2 : (tensor<2x16x8xf32>, tensor<2x8x32xf32>, tensor<1xf32>, tensor<1xf32>)  -> tensor<2x16x32xf32>
+  // CHECK: return [[vsharded_6]]
+  return %0 : tensor<2x16x32xf32>
+}
+
+// CHECK-LABEL: func.func @matmul_on_use_shard_m_and_duplicted_k
+// CHECK-SAME:     %[[ARG0:.*]]: tensor<2x16x8xf32>, %[[ARG1:.*]]: tensor<2x8x32xf32>, %[[ARG2:.*]]: tensor<1xf32>
+func.func @matmul_on_use_shard_m_and_duplicted_k(%arg0: tensor<2x16x8xf32>, %arg1: tensor<2x8x32xf32>, %arg2: tensor<1xf32>) -> tensor<2x16x32xf32> {
+  // CHECK-NEXT:  %[[S0:.*]] = shard.sharding @grid_2d split_axes = {{\[\[}}], [1], [0]] : !shard.sharding
+  // CHECK-NEXT:  %[[V0:.*]] = shard.shard %[[ARG0]] to %[[S0]] annotate_for_users  : tensor<2x16x8xf32>
+  %s0 = shard.sharding @grid_2d split_axes = [[], [1], [0]] : !shard.sharding
+  %0 = shard.shard %arg0 to %s0 annotate_for_users  : tensor<2x16x8xf32>
+  // CHECK-NEXT:  %[[S1:.*]] = shard.sharding @grid_2d split_axes = {{\[\[}}], [0]] : !shard.sharding
+  // CHECK-NEXT:  %[[V1:.*]] = shard.shard %[[ARG1]] to %[[S1]] annotate_for_users  : tensor<2x8x32xf32>
+  %s1 = shard.sharding @grid_2d split_axes = [[], [0]] : !shard.sharding
+  %1 = shard.shard %arg1 to %s1 annotate_for_users  : tensor<2x8x32xf32>
+  // CHECK-NEXT:  %[[S2:.*]] = shard.sharding @grid_2d split_axes = {{\[\[}}]] : !shard.sharding
+  // CHECK-NEXT:  %[[ZP:.*]] = shard.shard %[[ARG2]] to %[[S2]] annotate_for_users  : tensor<1xf32>
+  // CHECK-NEXT:  %[[V2:.*]] = tosa.matmul %[[V0]], %[[V1]], %[[ZP]], %[[ZP]]
+  %2 = tosa.matmul %0, %1, %arg2, %arg2 : (tensor<2x16x8xf32>, tensor<2x8x32xf32>, tensor<1xf32>, tensor<1xf32>)  -> tensor<2x16x32xf32>
+  // CHECK-NEXT:  %[[S3:.*]] = shard.sharding @grid_2d split_axes = {{\[\[}}], [1]] : !shard.sharding
+  // CHECK-NEXT:  %[[V3:.*]] = shard.shard %[[V2]] to %[[S3]]  : tensor<2x16x32xf32>
+  // CHECK-NEXT:  return %[[V3]]
+  return %2 : tensor<2x16x32xf32>
+}
+
+// CHECK-LABEL: func.func @resolve_conflicting_annotations
+func.func @resolve_conflicting_annotations(
+  // CHECK-SAME: %[[IN1:.*]]: tensor<2x3xf32>,
+  %arg0: tensor<2x3xf32>,
+  // CHECK-SAME: %[[IN2:.*]]: tensor<3x2xf32>,
+  %arg1: tensor<3x2xf32>,
+  // CHECK-SAME: %[[OUT_DPS:.*]]: tensor<2x2xf32>
+  %out_dps: tensor<2x2xf32>
+// CHECK-SAME: ) -> tensor<2x2xf32> {
+) -> tensor<2x2xf32> {
+  // CHECK: %[[SIN1_SHARDED1:.*]] = shard.sharding @grid_2 split_axes = {{\[\[}}0]] : !shard.sharding
+  // CHECK-NEXT:  %[[IN1_SHARDED1:.*]] = shard.shard %[[IN1]] to %[[SIN1_SHARDED1]]  : tensor<2x3xf32>
+  // CHECK: %[[SIN2_SHARDED:.*]] = shard.sharding @grid_2 split_axes = {{\[\[}}]] : !shard.sharding
+  // CHECK-NEXT:  %[[IN1_SHARDED2:.*]] = shard.shard %[[IN1_SHARDED1]] to %[[SIN2_SHARDED]] annotate_for_users  : tensor<2x3xf32>
+  // CHECK-NEXT:  %[[IN2_SHARDED:.*]] = shard.shard %[[IN2]] to %[[SIN2_SHARDED]] annotate_for_users  : tensor<3x2xf32>
+  // CHECK-NEXT:  %[[OUT_DPS_SHARDED:.*]] = shard.shard %[[OUT_DPS]] to %[[SIN2_SHARDED]] annotate_for_users  : tensor<2x2xf32>
+  %sarg0_sharded = shard.sharding @grid_2 split_axes = [[0]] : !shard.sharding
+  %arg0_sharded = shard.shard %arg0 to %sarg0_sharded  : tensor<2x3xf32>
+  // CHECK: %[[MATMUL:.*]] = linalg.matmul ins(%[[IN1_SHARDED2]], %[[IN2_SHARDED]] : tensor<2x3xf32>, tensor<3x2xf32>)
+  // CHECK-SAME: outs(%[[OUT_DPS_SHARDED]] : tensor<2x2xf32>) -> tensor<2x2xf32>
+  %res = linalg.matmul ins(%arg0_sharded, %arg1 : tensor<2x3xf32>, tensor<3x2xf32>)
+    outs(%out_dps : tensor<2x2xf32>) -> tensor<2x2xf32>
+  // CHECK-NEXT: %[[RES:.*]] = shard.shard %[[MATMUL]] to %[[SIN2_SHARDED]] : tensor<2x2xf32>
+  %sres_sharded = shard.sharding @grid_2 split_axes = [[]] : !shard.sharding
+  %res_sharded = shard.shard %res to %sres_sharded  : tensor<2x2xf32>
+  // CHECK: return %[[RES]] : tensor<2x2xf32>
+  return %res_sharded : tensor<2x2xf32>
+}
+
+// https://arxiv.org/abs/2211.05102 Figure 2(a)
+// The sharding propagation results in unnecessary reshards,
+//   an optimization pass should be able to remove them.
+// CHECK-LABEL: func.func @mlp_1d_weight_stationary
+// CHECK-SAME:     [[varg0:%.*]]: tensor<2x4x8xf32>, [[varg1:%.*]]: tensor<2x8x32xf32>, [[varg2:%.*]]: tensor<2x32x8xf32>, [[varg3:%.*]]: tensor<1xf32>
+func.func @mlp_1d_weight_stationary(%arg0: tensor<2x4x8xf32>, %arg1: tensor<2x8x32xf32>, %arg2: tensor<2x32x8xf32>, %arg3: tensor<1xf32>) -> tensor<2x4x8xf32> {
+  %s0 = shard.sharding @grid_1d split_axes = [[], [], [0, 1, 2]] : !shard.sharding
+  %sharded0 = shard.shard %arg0 to %s0 : tensor<2x4x8xf32>
+  %sharded1 = shard.shard %arg1 to %s0 : tensor<2x8x32xf32>
+  // CHECK: [[vsharding:%.*]] = shard.sharding @grid_1d split_axes = {{\[\[}}], [], [0, 1, 2]] : !shard.sharding
+  // CHECK: [[vsharded:%.*]] = shard.shard [[varg0]] to [[vsharding]] : tensor<2x4x8xf32>
+  // CHECK: [[vsharded_0:%.*]] = shard.shard [[varg1]] to [[vsharding]] : tensor<2x8x32xf32>
+  // CHECK: [[vsharded_1:%.*]] = shard.shard [[vsharded]] to [[vsharding]] annotate_for_users : tensor<2x4x8xf32>
+  // CHECK: [[vsharding_2:%.*]] = shard.sharding @grid_1d split_axes = {{\[\[}}], [0, 1, 2]] : !shard.sharding
+  // CHECK: [[vsharded_3:%.*]] = shard.shard [[vsharded_0]] to [[vsharding_2]] annotate_for_users : tensor<2x8x32xf32>
+  // CHECK: [[vsharding_4:%.*]] = shard.sharding @grid_1d split_axes = {{\[\[}}]] : !shard.sharding
+  // CHECK: [[vsharded_5:%.*]] = shard.shard [[varg3]] to [[vsharding_4]] annotate_for_users : tensor<1xf32>
+  // CHECK: [[v0:%.*]] = tosa.matmul
+  %1 = tosa.matmul %sharded0, %sharded1, %arg3, %arg3 : (tensor<2x4x8xf32>, tensor<2x8x32xf32>, tensor<1xf32>, tensor<1xf32>)  -> tensor<2x4x32xf32>
+  // CHECK: [[vsharded_6:%.*]] = shard.shard [[v0]] to [[vsharding_4]] : tensor<2x4x32xf32>
+  // CHECK: [[vsharded_7:%.*]] = shard.shard [[vsharded_6]] to [[vsharding_4]] annotate_for_users : tensor<2x4x32xf32>
+  // CHECK: [[v1:%.*]] = tosa.sigmoid [[vsharded_7]] : (tensor<2x4x32xf32>) -> tensor<2x4x32xf32>
+  // CHECK: [[vsharded_8:%.*]] = shard.shard [[v1]] to [[vsharding_4]] : tensor<2x4x32xf32>
+  %2 = tosa.sigmoid %1 : (tensor<2x4x32xf32>) -> tensor<2x4x32xf32>
+  %sharding = shard.sharding @grid_1d split_axes = [[], [0, 1, 2]] : !shard.sharding
+  // CHECK: [[vsharded_9:%.*]] = shard.shard [[varg2]] to [[vsharding_2]] : tensor<2x32x8xf32>
+  %sharded2 = shard.shard %arg2 to %sharding  : tensor<2x32x8xf32>
+  // CHECK: [[vsharded_10:%.*]] = shard.shard [[vsharded_8]] to [[vsharding_4]] annotate_for_users : tensor<2x4x32xf32>
+  // CHECK: [[vsharded_11:%.*]] = shard.shard [[vsharded_9]] to [[vsharding]] annotate_for_users : tensor<2x32x8xf32>
+  // CHECK: [[v2:%.*]] = tosa.matmul
+  %3 = tosa.matmul %2, %sharded2, %arg3, %arg3 : (tensor<2x4x32xf32>, tensor<2x32x8xf32>, tensor<1xf32>, tensor<1xf32>)  -> tensor<2x4x8xf32>
+  // CHECK: [[vsharded_12:%.*]] = shard.shard [[v2]] to [[vsharding]] : tensor<2x4x8xf32>
+  %s4 = shard.sharding @grid_1d split_axes = [[], [], [0, 1, 2]] : !shard.sharding
+  %4 = shard.shard %3 to %s4  : tensor<2x4x8xf32>
+  // CHECK: return [[vsharded_12]]
+  return %4 : tensor<2x4x8xf32>
+}
+
+// https://arxiv.org/abs/2211.05102 Figure 2(b)
+// The sharding propagation results in unnecessary reshards,
+//   an optimization pass should be able to remove them.
+// CHECK-LABEL: func.func @mlp_2d_weight_stationary
+// CHECK-SAME:     [[varg0:%.*]]: tensor<2x4x8xf32>, [[varg1:%.*]]: tensor<2x8x32xf32>, [[varg2:%.*]]: tensor<2x32x8xf32>, [[varg3:%.*]]: tensor<1xf32>
+func.func @mlp_2d_weight_stationary(%arg0: tensor<2x4x8xf32>, %arg1: tensor<2x8x32xf32>, %arg2: tensor<2x32x8xf32>, %arg3: tensor<1xf32>) -> tensor<2x4x8xf32> {
+    // CHECK: [[vsharding:%.*]] = shard.sharding @grid_3d split_axes = {{\[\[}}], [], [0, 1, 2]] : !shard.sharding
+  %s0 = shard.sharding @grid_3d split_axes = [[], [], [0, 1, 2]] : !shard.sharding
+    // CHECK: [[vsharded:%.*]] = shard.shard [[varg0]] to [[vsharding]] : tensor<2x4x8xf32>
+  %arg0_s = shard.shard %arg0 to %s0  : tensor<2x4x8xf32>
+    // CHECK: [[vsharding_0:%.*]] = shard.sharding @grid_3d split_axes = {{\[\[}}], [0], [1, 2]] : !shard.sharding
+  %s1 = shard.sharding @grid_3d split_axes = [[], [0], [1, 2]] : !shard.sharding
+    // CHECK: [[vsharded_1:%.*]] = shard.shard [[varg1]] to [[vsharding_0]] : tensor<2x8x32xf32>
+  %arg1_s = shard.shard %arg1 to %s1  : tensor<2x8x32xf32>
+    // CHECK: [[vsharding_2:%.*]] = shard.sharding @grid_3d split_axes = {{\[\[}}]] : !shard.sharding
+    // CHECK: [[vsharded_3:%.*]] = shard.shard [[vsharded]] to [[vsharding_2]] annotate_for_users : tensor<2x4x8xf32>
+    // CHECK: [[vsharded_4:%.*]] = shard.shard [[vsharded_1]] to [[vsharding]] annotate_for_users : tensor<2x8x32xf32>
+    // CHECK: [[vsharded_5:%.*]] = shard.shard [[varg3]] to [[vsharding_2]] annotate_for_users : tensor<1xf32>
+    // CHECK: [[v0:%.*]] = tosa.matmul
+  %1 = tosa.matmul %arg0_s, %arg1_s, %arg3, %arg3 : (tensor<2x4x8xf32>, tensor<2x8x32xf32>, tensor<1xf32>, tensor<1xf32>)  -> tensor<2x4x32xf32>
+    // CHECK: [[vsharded_6:%.*]] = shard.shard [[v0]] to [[vsharding]] : tensor<2x4x32xf32>
+  %2 = shard.shard %1 to %s0  : tensor<2x4x32xf32>
+    // CHECK: [[vsharded_7:%.*]] = shard.shard [[vsharded_6]] to [[vsharding]] annotate_for_users : tensor<2x4x32xf32>
+    // CHECK: [[v1:%.*]] = tosa.sigmoid
+    // CHECK: [[vsharded_8:%.*]] = shard.shard [[v1]] to [[vsharding]] : tensor<2x4x32xf32>
+  %3 = tosa.sigmoid %2 : (tensor<2x4x32xf32>) -> tensor<2x4x32xf32>
+    // CHECK: [[vsharding_9:%.*]] = shard.sharding @grid_3d split_axes = {{\[\[}}], [1, 2], [0]] : !shard.sharding
+  %s2 = shard.sharding @grid_3d split_axes = [[], [1, 2], [0]] : !shard.sharding
+    // CHECK: [[vsharded_10:%.*]] = shard.shard [[varg2]] to [[vsharding_9]] : tensor<2x32x8xf32>
+  %arg2_s = shard.shard %arg2 to %s2  : tensor<2x32x8xf32>
+    // CHECK: [[vsharded_11:%.*]] = shard.shard [[vsharded_8]] to [[vsharding_2]] annotate_for_users : tensor<2x4x32xf32>
+    // CHECK: [[vsharded_12:%.*]] = shard.shard [[vsharded_10]] to [[vsharding]] annotate_for_users : tensor<2x32x8xf32>
+    // CHECK: [[v2:%.*]] = tosa.matmul
+  %4 = tosa.matmul %3, %arg2_s, %arg3, %arg3 : (tensor<2x4x32xf32>, tensor<2x32x8xf32>, tensor<1xf32>, tensor<1xf32>)  -> tensor<2x4x8xf32>
+    // CHECK: [[vsharded_13:%.*]] = shard.shard [[v2]] to [[vsharding]] : tensor<2x4x8xf32>
+  %5 = shard.shard %4 to %s0  : tensor<2x4x8xf32>
+    // CHECK: [[vsharded_14:%.*]] = shard.shard [[vsharded_13]] to [[vsharding]] annotate_for_users : tensor<2x4x8xf32>
+  %6 = shard.shard %5 to %s0 annotate_for_users  : tensor<2x4x8xf32>
+    // CHECK: return [[vsharded_14]]
+  return %6 : tensor<2x4x8xf32>
+}
+
+// CHECK-LABEL: func.func @elementwise_duplicated_chain
+// CHECK-SAME:    %[[ARG:.*]]: tensor<8x16xf32>
+func.func @elementwise_duplicated_chain(%arg0: tensor<8x16xf32>) -> tensor<8x16xf32> {
+  // CHECK-NEXT:  %[[S0:.*]] = shard.sharding @grid_2d split_axes = {{\[\[}}]] : !shard.sharding
+  // CHECK-NEXT:  %[[V0:.*]] = shard.shard %[[ARG]] to %[[S0]] annotate_for_users  : tensor<8x16xf32>
+  // CHECK-NEXT:  %[[V1:.*]] = tosa.sigmoid %[[V0]]
+  %0 = tosa.sigmoid %arg0 : (tensor<8x16xf32>) -> tensor<8x16xf32>
+  // CHECK-NEXT:  %[[V2:.*]] = shard.shard %[[V1]] to %[[S0]] : tensor<8x16xf32>
+  // CHECK-NEXT:  %[[V3:.*]] = shard.shard %[[V2]] to %[[S0]] annotate_for_users : tensor<8x16xf32>
+  // CHECK-NEXT:  %[[V4:.*]] = tosa.sigmoid %[[V3]]
+  %1 = tosa.sigmoid %0 : (tensor<8x16xf32>) -> tensor<8x16xf32>
+  // CHECK-NEXT:  %[[V5:.*]] = shard.shard %[[V4]] to %[[S0]]  : tensor<8x16xf32>
+  %s0 = shard.sharding @grid_2d split_axes = [[]] : !shard.sharding
+  %2 = shard.shard %1 to %s0 : tensor<8x16xf32>
+  // CHECK-NEXT:  return %[[V5]]
+  return %2 : tensor<8x16xf32>
+}
diff --git a/mlir/test/Dialect/Mesh/simplifications.mlir b/mlir/test/Dialect/Shard/simplifications.mlir
similarity index 69%
rename from mlir/test/Dialect/Mesh/simplifications.mlir
rename to mlir/test/Dialect/Shard/simplifications.mlir
index e955f4c134259..33cd490be744a 100644
--- a/mlir/test/Dialect/Mesh/simplifications.mlir
+++ b/mlir/test/Dialect/Shard/simplifications.mlir
@@ -1,7 +1,7 @@
-// RUN: mlir-opt -test-mesh-simplifications %s | FileCheck %s
+// RUN: mlir-opt -test-grid-simplifications %s | FileCheck %s
 
-mesh.mesh @mesh0(shape = 4x2)
-mesh.mesh @mesh1(shape = 4)
+shard.grid @grid0(shape = 4x2)
+shard.grid @grid1(shape = 4)
 
 // Checks that `all_reduce(x) + all_reduce(y)` gets transformed to
 // `all_reduce(x + y)`.
@@ -11,13 +11,13 @@ func.func @all_reduce_arith_addf_endomorphism(
     %arg0: tensor<5xf32>,
     // CHECK-SAME: %[[ARG1:[A-Za-z0-9_]*]]: tensor<5xf32>
     %arg1: tensor<5xf32>) -> tensor<5xf32> {
-  %0 = mesh.all_reduce %arg0 on @mesh0 mesh_axes = [0]
+  %0 = shard.all_reduce %arg0 on @grid0 grid_axes = [0]
     : tensor<5xf32> -> tensor<5xf32>
-  %1 = mesh.all_reduce %arg1 on @mesh0 mesh_axes = [0]
+  %1 = shard.all_reduce %arg1 on @grid0 grid_axes = [0]
     : tensor<5xf32> -> tensor<5xf32>
   // CHECK: %[[ADD_RES:[A-Za-z0-9_]*]] = arith.addf %[[ARG0]], %[[ARG1]]
   %2 = arith.addf %0, %1 : tensor<5xf32>
-  // CHECK: %[[ALL_REDUCE_RES:[A-Za-z0-9_]*]] = mesh.all_reduce %[[ADD_RES]]
+  // CHECK: %[[ALL_REDUCE_RES:[A-Za-z0-9_]*]] = shard.all_reduce %[[ADD_RES]]
   // CHECK: return %[[ALL_REDUCE_RES]]
   return %2 : tensor<5xf32>
 }
@@ -28,13 +28,13 @@ func.func @all_reduce_arith_addf_endomorphism_multiple_uses_of_result(
     %arg0: tensor<5xf32>,
     // CHECK-SAME: %[[ARG1:[A-Za-z0-9_]*]]: tensor<5xf32>
     %arg1: tensor<5xf32>) -> (tensor<5xf32>, tensor<5xf32>) {
-  %0 = mesh.all_reduce %arg0 on @mesh0 mesh_axes = [0]
+  %0 = shard.all_reduce %arg0 on @grid0 grid_axes = [0]
     : tensor<5xf32> -> tensor<5xf32>
-  %1 = mesh.all_reduce %arg1 on @mesh0 mesh_axes = [0]
+  %1 = shard.all_reduce %arg1 on @grid0 grid_axes = [0]
     : tensor<5xf32> -> tensor<5xf32>
   // CHECK: %[[ADD_RES:[A-Za-z0-9_]*]] = arith.addf %[[ARG0]], %[[ARG1]]
   %2 = arith.addf %0, %1 : tensor<5xf32>
-  // CHECK: %[[ALL_REDUCE_RES:[A-Za-z0-9_]*]] = mesh.all_reduce %[[ADD_RES]]
+  // CHECK: %[[ALL_REDUCE_RES:[A-Za-z0-9_]*]] = shard.all_reduce %[[ADD_RES]]
   // CHECK: return %[[ALL_REDUCE_RES]], %[[ALL_REDUCE_RES]]
   return %2, %2 : tensor<5xf32>, tensor<5xf32>
 }
@@ -46,11 +46,11 @@ func.func @all_reduce_arith_addf_endomorphism_multiple_uses_of_all_reduce_result
     %arg0: tensor<5xf32>,
     // CHECK-SAME: %[[ARG1:[A-Za-z0-9_]*]]: tensor<5xf32>
     %arg1: tensor<5xf32>) -> (tensor<5xf32>, tensor<5xf32>) {
-  // CHECK: %[[ALL_REDUCE_0_RES:[A-Za-z0-9_]*]] = mesh.all_reduce %[[ARG0]]
-  %0 = mesh.all_reduce %arg0 on @mesh0 mesh_axes = [0]
+  // CHECK: %[[ALL_REDUCE_0_RES:[A-Za-z0-9_]*]] = shard.all_reduce %[[ARG0]]
+  %0 = shard.all_reduce %arg0 on @grid0 grid_axes = [0]
     : tensor<5xf32> -> tensor<5xf32>
-  // CHECK: %[[ALL_REDUCE_1_RES:[A-Za-z0-9_]*]] = mesh.all_reduce %[[ARG1]]
-  %1 = mesh.all_reduce %arg1 on @mesh0 mesh_axes = [0]
+  // CHECK: %[[ALL_REDUCE_1_RES:[A-Za-z0-9_]*]] = shard.all_reduce %[[ARG1]]
+  %1 = shard.all_reduce %arg1 on @grid0 grid_axes = [0]
     : tensor<5xf32> -> tensor<5xf32>
   // CHECK: %[[ADD_RES:[A-Za-z0-9_]*]] = arith.addf %[[ALL_REDUCE_0_RES]], %[[ALL_REDUCE_1_RES]]
   %2 = arith.addf %0, %1 : tensor<5xf32>
@@ -58,17 +58,17 @@ func.func @all_reduce_arith_addf_endomorphism_multiple_uses_of_all_reduce_result
   return %0, %2 : tensor<5xf32>, tensor<5xf32>
 }
 
-// CHECK-LABEL: func.func @all_reduce_arith_addf_no_endomorphism_different_mesh
-func.func @all_reduce_arith_addf_no_endomorphism_different_mesh(
+// CHECK-LABEL: func.func @all_reduce_arith_addf_no_endomorphism_different_grid
+func.func @all_reduce_arith_addf_no_endomorphism_different_grid(
     // CHECK-SAME: %[[ARG0:[A-Za-z0-9_]*]]: tensor<5xf32>
     %arg0: tensor<5xf32>,
     // CHECK-SAME: %[[ARG1:[A-Za-z0-9_]*]]: tensor<5xf32>
     %arg1: tensor<5xf32>) -> tensor<5xf32> {
-  // CHECK: %[[ALL_REDUCE0:[A-Za-z0-9_]*]] = mesh.all_reduce %[[ARG0]] on @mesh0
-  %0 = mesh.all_reduce %arg0 on @mesh0 mesh_axes = [0]
+  // CHECK: %[[ALL_REDUCE0:[A-Za-z0-9_]*]] = shard.all_reduce %[[ARG0]] on @grid0
+  %0 = shard.all_reduce %arg0 on @grid0 grid_axes = [0]
     : tensor<5xf32> -> tensor<5xf32>
-  // CHECK: %[[ALL_REDUCE1:[A-Za-z0-9_]*]] = mesh.all_reduce %[[ARG1]] on @mesh1
-  %1 = mesh.all_reduce %arg1 on @mesh1 mesh_axes = [0]
+  // CHECK: %[[ALL_REDUCE1:[A-Za-z0-9_]*]] = shard.all_reduce %[[ARG1]] on @grid1
+  %1 = shard.all_reduce %arg1 on @grid1 grid_axes = [0]
     : tensor<5xf32> -> tensor<5xf32>
   // CHECK: %[[ADD_RES:[A-Za-z0-9_]*]] = arith.addf %[[ALL_REDUCE0]], %[[ALL_REDUCE1]]
   %2 = arith.addf %0, %1 : tensor<5xf32>
@@ -76,17 +76,17 @@ func.func @all_reduce_arith_addf_no_endomorphism_different_mesh(
   return %2 : tensor<5xf32>
 }
 
-// CHECK-LABEL: func.func @all_reduce_arith_addf_no_endomorphism_different_mesh_axes
-func.func @all_reduce_arith_addf_no_endomorphism_different_mesh_axes(
+// CHECK-LABEL: func.func @all_reduce_arith_addf_no_endomorphism_different_grid_axes
+func.func @all_reduce_arith_addf_no_endomorphism_different_grid_axes(
     // CHECK-SAME: %[[ARG0:[A-Za-z0-9_]*]]: tensor<5xf32>
     %arg0: tensor<5xf32>,
     // CHECK-SAME: %[[ARG1:[A-Za-z0-9_]*]]: tensor<5xf32>
     %arg1: tensor<5xf32>) -> tensor<5xf32> {
-  // CHECK: %[[ALL_REDUCE0:[A-Za-z0-9_]*]] = mesh.all_reduce %[[ARG0]] on @mesh0 mesh_axes = [0]
-  %0 = mesh.all_reduce %arg0 on @mesh0 mesh_axes = [0]
+  // CHECK: %[[ALL_REDUCE0:[A-Za-z0-9_]*]] = shard.all_reduce %[[ARG0]] on @grid0 grid_axes = [0]
+  %0 = shard.all_reduce %arg0 on @grid0 grid_axes = [0]
     : tensor<5xf32> -> tensor<5xf32>
-  // CHECK: %[[ALL_REDUCE1:[A-Za-z0-9_]*]] = mesh.all_reduce %[[ARG1]] on @mesh0 mesh_axes = [1]
-  %1 = mesh.all_reduce %arg1 on @mesh0 mesh_axes = [1]
+  // CHECK: %[[ALL_REDUCE1:[A-Za-z0-9_]*]] = shard.all_reduce %[[ARG1]] on @grid0 grid_axes = [1]
+  %1 = shard.all_reduce %arg1 on @grid0 grid_axes = [1]
     : tensor<5xf32> -> tensor<5xf32>
   // CHECK: %[[ADD_RES:[A-Za-z0-9_]*]] = arith.addf %[[ALL_REDUCE0]], %[[ALL_REDUCE1]]
   %2 = arith.addf %0, %1 : tensor<5xf32>
@@ -100,11 +100,11 @@ func.func @all_reduce_arith_addf_no_endomorphism_wrong_reduction_kind(
     %arg0: tensor<5xf32>,
     // CHECK-SAME: %[[ARG1:[A-Za-z0-9_]*]]: tensor<5xf32>
     %arg1: tensor<5xf32>) -> tensor<5xf32> {
-  // CHECK: %[[ALL_REDUCE0:[A-Za-z0-9_]*]] = mesh.all_reduce %[[ARG0]] on @mesh0 mesh_axes = [0] reduction = max
-  %0 = mesh.all_reduce %arg0 on @mesh0 mesh_axes = [0] reduction = max
+  // CHECK: %[[ALL_REDUCE0:[A-Za-z0-9_]*]] = shard.all_reduce %[[ARG0]] on @grid0 grid_axes = [0] reduction = max
+  %0 = shard.all_reduce %arg0 on @grid0 grid_axes = [0] reduction = max
     : tensor<5xf32> -> tensor<5xf32>
-  // CHECK: %[[ALL_REDUCE1:[A-Za-z0-9_]*]] = mesh.all_reduce %[[ARG1]] on @mesh0 mesh_axes = [0]
-  %1 = mesh.all_reduce %arg1 on @mesh0 mesh_axes = [0]
+  // CHECK: %[[ALL_REDUCE1:[A-Za-z0-9_]*]] = shard.all_reduce %[[ARG1]] on @grid0 grid_axes = [0]
+  %1 = shard.all_reduce %arg1 on @grid0 grid_axes = [0]
     : tensor<5xf32> -> tensor<5xf32>
   // CHECK: %[[ADD_RES:[A-Za-z0-9_]*]] = arith.addf %[[ALL_REDUCE0]], %[[ALL_REDUCE1]]
   %2 = arith.addf %0, %1 : tensor<5xf32>
@@ -118,11 +118,11 @@ func.func @all_reduce_arith_addf_no_endomorphism_different_operand_result_elemen
     %arg0: tensor<5xf32>,
     // CHECK-SAME: %[[ARG1:[A-Za-z0-9_]*]]: tensor<5xf32>
     %arg1: tensor<5xf32>) -> tensor<5xf64> {
-  // CHECK: %[[ALL_REDUCE0:[A-Za-z0-9_]*]] = mesh.all_reduce %[[ARG0]] on @mesh0 mesh_axes = [0]
-  %0 = mesh.all_reduce %arg0 on @mesh0 mesh_axes = [0]
+  // CHECK: %[[ALL_REDUCE0:[A-Za-z0-9_]*]] = shard.all_reduce %[[ARG0]] on @grid0 grid_axes = [0]
+  %0 = shard.all_reduce %arg0 on @grid0 grid_axes = [0]
     : tensor<5xf32> -> tensor<5xf64>
-  // CHECK: %[[ALL_REDUCE1:[A-Za-z0-9_]*]] = mesh.all_reduce %[[ARG1]] on @mesh0 mesh_axes = [0]
-  %1 = mesh.all_reduce %arg1 on @mesh0 mesh_axes = [0]
+  // CHECK: %[[ALL_REDUCE1:[A-Za-z0-9_]*]] = shard.all_reduce %[[ARG1]] on @grid0 grid_axes = [0]
+  %1 = shard.all_reduce %arg1 on @grid0 grid_axes = [0]
     : tensor<5xf32> -> tensor<5xf64>
   // CHECK: %[[ADD_RES:[A-Za-z0-9_]*]] = arith.addf %[[ALL_REDUCE0]], %[[ALL_REDUCE1]]
   %2 = arith.addf %0, %1 : tensor<5xf64>
@@ -138,13 +138,13 @@ func.func @all_reduce_arith_minimumf_endomorphism(
     %arg0: tensor<5xf32>,
     // CHECK-SAME: %[[ARG1:[A-Za-z0-9_]*]]: tensor<5xf32>
     %arg1: tensor<5xf32>) -> tensor<5xf32> {
-  %0 = mesh.all_reduce %arg0 on @mesh0 mesh_axes = [0] reduction = min
+  %0 = shard.all_reduce %arg0 on @grid0 grid_axes = [0] reduction = min
     : tensor<5xf32> -> tensor<5xf32>
-  %1 = mesh.all_reduce %arg1 on @mesh0 mesh_axes = [0] reduction = min
+  %1 = shard.all_reduce %arg1 on @grid0 grid_axes = [0] reduction = min
     : tensor<5xf32> -> tensor<5xf32>
   // CHECK: %[[ADD_RES:[A-Za-z0-9_]*]] = arith.minimumf %[[ARG0]], %[[ARG1]]
   %2 = arith.minimumf %0, %1 : tensor<5xf32>
-  // CHECK: %[[ALL_REDUCE_RES:[A-Za-z0-9_]*]] = mesh.all_reduce %[[ADD_RES]] on @mesh0 mesh_axes = [0] reduction = min
+  // CHECK: %[[ALL_REDUCE_RES:[A-Za-z0-9_]*]] = shard.all_reduce %[[ADD_RES]] on @grid0 grid_axes = [0] reduction = min
   // CHECK: return %[[ALL_REDUCE_RES]]
   return %2 : tensor<5xf32>
 }
@@ -155,13 +155,13 @@ func.func @all_reduce_arith_minsi_endomorphism(
     %arg0: tensor<5xi32>,
     // CHECK-SAME: %[[ARG1:[A-Za-z0-9_]*]]: tensor<5xi32>
     %arg1: tensor<5xi32>) -> tensor<5xi32> {
-  %0 = mesh.all_reduce %arg0 on @mesh0 mesh_axes = [0] reduction = min
+  %0 = shard.all_reduce %arg0 on @grid0 grid_axes = [0] reduction = min
     : tensor<5xi32> -> tensor<5xi32>
-  %1 = mesh.all_reduce %arg1 on @mesh0 mesh_axes = [0] reduction = min
+  %1 = shard.all_reduce %arg1 on @grid0 grid_axes = [0] reduction = min
     : tensor<5xi32> -> tensor<5xi32>
   // CHECK: %[[ADD_RES:[A-Za-z0-9_]*]] = arith.minsi %[[ARG0]], %[[ARG1]]
   %2 = arith.minsi %0, %1 : tensor<5xi32>
-  // CHECK: %[[ALL_REDUCE_RES:[A-Za-z0-9_]*]] = mesh.all_reduce %[[ADD_RES]] on @mesh0 mesh_axes = [0] reduction = min
+  // CHECK: %[[ALL_REDUCE_RES:[A-Za-z0-9_]*]] = shard.all_reduce %[[ADD_RES]] on @grid0 grid_axes = [0] reduction = min
   // CHECK: return %[[ALL_REDUCE_RES]]
   return %2 : tensor<5xi32>
 }
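For readers looking at this test file in isolation: the @grid0 and @grid1 symbols referenced above are declared near the top of the same file with shard.grid ops. A minimal sketch of such declarations (the shapes here are illustrative, not taken from the actual file; @grid0 needs at least two axes because grid_axes = [1] is used above):

  shard.grid @grid0(shape = 4x4)
  shard.grid @grid1(shape = 4)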
diff --git a/mlir/test/Dialect/Tensor/mesh-spmdization.mlir b/mlir/test/Dialect/Tensor/mesh-spmdization.mlir
deleted file mode 100644
index 8598d81ff6cfa..0000000000000
--- a/mlir/test/Dialect/Tensor/mesh-spmdization.mlir
+++ /dev/null
@@ -1,52 +0,0 @@
-// RUN: mlir-opt \
-// RUN:   --pass-pipeline="builtin.module(func.func(mesh-spmdization,test-single-fold))" \
-// RUN:   %s | FileCheck %s
-
-mesh.mesh @mesh_1d_4(shape = 4)
-
-// CHECK-LABEL: func @tensor_empty_static_sharded_dims_offsets
-func.func @tensor_empty_static_sharded_dims_offsets() -> () {
-  %b = tensor.empty() : tensor<8x16xf32>
-  %sharding = mesh.sharding @mesh_1d_4 split_axes = [[0]] sharded_dims_offsets = [0, 1, 4, 7, 8] : !mesh.sharding
-  %sharded= mesh.shard %b to %sharding : tensor<8x16xf32>
-  // CHECK:  %[[sharding:.*]] = mesh.sharding @mesh_1d_4 split_axes = {{\[\[}}0]] sharded_dims_offsets = [0, 1, 4, 7, 8] : !mesh.sharding
-  // CHECK:  %[[proc_multi_idx:.*]] = mesh.process_multi_index on @mesh_1d_4 : index
-  // CHECK:  %[[V0:.*]]:2 = mesh.shard_shape dims = [8, 16] sharding = %[[sharding]] device = [%[[proc_multi_idx]]
-  // CHECK-SAME: ] : index, index
-  // CHECK:  tensor.empty(%[[V0]]#0) : tensor<?x16xf32>
-
-  return
-}
-
-// CHECK-LABEL: func @tensor_empty_dynamic_sharded_dims_offsets
-// CHECK-SAME: %[[A0:.*]]: index
-func.func @tensor_empty_dynamic_sharded_dims_offsets(%arg0 : index) -> () {
-  %b = tensor.empty(%arg0) : tensor<8x?xf32>
-  %sharding = mesh.sharding @mesh_1d_4 split_axes = [[0]] sharded_dims_offsets = [0, 1, 4, 7, 8] : !mesh.sharding
-  %sharded= mesh.shard %b to %sharding : tensor<8x?xf32>
-  // CHECK:  %[[sharding:.*]] = mesh.sharding @mesh_1d_4 split_axes = {{\[\[}}0]] sharded_dims_offsets = [0, 1, 4, 7, 8] : !mesh.sharding
-  // CHECK:  %[[proc_multi_idx:.*]] = mesh.process_multi_index on @mesh_1d_4 : index
-  // CHECK:  %[[V0:.*]]:2 = mesh.shard_shape dims = [8, %[[A0]]
-  // CHECK-SAME: ] sharding = %[[sharding]] device = [%[[proc_multi_idx]]
-  // CHECK-SAME: ] : index, index
-  // CHECK:  tensor.empty(%[[V0]]#0, %[[A0]]) : tensor<?x?xf32>
-
-  return
-}
-
-// CHECK-LABEL: func @tensor_empty_same_static_dims_sizes
-func.func @tensor_empty_same_static_dims_sizes() -> () {
-  %b = tensor.empty() : tensor<16x16xf32>
-  %sharding = mesh.sharding @mesh_1d_4 split_axes = [[0]] sharded_dims_offsets = [0, 4, 8, 12, 16] : !mesh.sharding
-  %sharded= mesh.shard %b to %sharding : tensor<16x16xf32>
-  // CHECK-NEXT:  tensor.empty() : tensor<4x16xf32>
-
-  return
-}
-
-// CHECK-LABEL: func @tensor_empty_0d
-func.func @tensor_empty_0d() -> () {
-  tensor.empty() : tensor<f32>
-  // CHECK-NEXT:  tensor.empty() : tensor<f32>
-  return
-}
diff --git a/mlir/test/Dialect/Tensor/shard-partition.mlir b/mlir/test/Dialect/Tensor/shard-partition.mlir
new file mode 100644
index 0000000000000..5918ee1eddf57
--- /dev/null
+++ b/mlir/test/Dialect/Tensor/shard-partition.mlir
@@ -0,0 +1,52 @@
+// RUN: mlir-opt \
+// RUN:   --pass-pipeline="builtin.module(func.func(shard-partition,test-single-fold))" \
+// RUN:   %s | FileCheck %s
+
+shard.grid @grid_1d_4(shape = 4)
+
+// CHECK-LABEL: func @tensor_empty_static_sharded_dims_offsets
+func.func @tensor_empty_static_sharded_dims_offsets() -> () {
+  %b = tensor.empty() : tensor<8x16xf32>
+  %sharding = shard.sharding @grid_1d_4 split_axes = [[0]] sharded_dims_offsets = [0, 1, 4, 7, 8] : !shard.sharding
+  %sharded = shard.shard %b to %sharding : tensor<8x16xf32>
+  // CHECK:  %[[sharding:.*]] = shard.sharding @grid_1d_4 split_axes = {{\[\[}}0]] sharded_dims_offsets = [0, 1, 4, 7, 8] : !shard.sharding
+  // CHECK:  %[[proc_multi_idx:.*]] = shard.process_multi_index on @grid_1d_4 : index
+  // CHECK:  %[[V0:.*]]:2 = shard.shard_shape dims = [8, 16] sharding = %[[sharding]] device = [%[[proc_multi_idx]]
+  // CHECK-SAME: ] : index, index
+  // CHECK:  tensor.empty(%[[V0]]#0) : tensor<?x16xf32>
+
+  return
+}
+
+// CHECK-LABEL: func @tensor_empty_dynamic_sharded_dims_offsets
+// CHECK-SAME: %[[A0:.*]]: index
+func.func @tensor_empty_dynamic_sharded_dims_offsets(%arg0 : index) -> () {
+  %b = tensor.empty(%arg0) : tensor<8x?xf32>
+  %sharding = shard.sharding @grid_1d_4 split_axes = [[0]] sharded_dims_offsets = [0, 1, 4, 7, 8] : !shard.sharding
+  %sharded = shard.shard %b to %sharding : tensor<8x?xf32>
+  // CHECK:  %[[sharding:.*]] = shard.sharding @grid_1d_4 split_axes = {{\[\[}}0]] sharded_dims_offsets = [0, 1, 4, 7, 8] : !shard.sharding
+  // CHECK:  %[[proc_multi_idx:.*]] = shard.process_multi_index on @grid_1d_4 : index
+  // CHECK:  %[[V0:.*]]:2 = shard.shard_shape dims = [8, %[[A0]]
+  // CHECK-SAME: ] sharding = %[[sharding]] device = [%[[proc_multi_idx]]
+  // CHECK-SAME: ] : index, index
+  // CHECK:  tensor.empty(%[[V0]]#0, %[[A0]]) : tensor<?x?xf32>
+
+  return
+}
+
+// CHECK-LABEL: func @tensor_empty_same_static_dims_sizes
+func.func @tensor_empty_same_static_dims_sizes() -> () {
+  %b = tensor.empty() : tensor<16x16xf32>
+  %sharding = shard.sharding @grid_1d_4 split_axes = [[0]] sharded_dims_offsets = [0, 4, 8, 12, 16] : !shard.sharding
+  %sharded = shard.shard %b to %sharding : tensor<16x16xf32>
+  // CHECK-NEXT:  tensor.empty() : tensor<4x16xf32>
+
+  return
+}
+
+// CHECK-LABEL: func @tensor_empty_0d
+func.func @tensor_empty_0d() -> () {
+  tensor.empty() : tensor<f32>
+  // CHECK-NEXT:  tensor.empty() : tensor<f32>
+  return
+}
diff --git a/mlir/test/lib/Dialect/CMakeLists.txt b/mlir/test/lib/Dialect/CMakeLists.txt
index eb2f74e8aeca1..3b7bd9b9637a8 100644
--- a/mlir/test/lib/Dialect/CMakeLists.txt
+++ b/mlir/test/lib/Dialect/CMakeLists.txt
@@ -10,7 +10,7 @@ add_subdirectory(Linalg)
 add_subdirectory(LLVM)
 add_subdirectory(Math)
 add_subdirectory(MemRef)
-add_subdirectory(Mesh)
+add_subdirectory(Shard)
 add_subdirectory(NVGPU)
 add_subdirectory(SCF)
 add_subdirectory(Shape)
diff --git a/mlir/test/lib/Dialect/Mesh/CMakeLists.txt b/mlir/test/lib/Dialect/Shard/CMakeLists.txt
similarity index 51%
rename from mlir/test/lib/Dialect/Mesh/CMakeLists.txt
rename to mlir/test/lib/Dialect/Shard/CMakeLists.txt
index 7bd0493d11a7e..f91c54721e030 100644
--- a/mlir/test/lib/Dialect/Mesh/CMakeLists.txt
+++ b/mlir/test/lib/Dialect/Shard/CMakeLists.txt
@@ -1,14 +1,14 @@
 # Exclude tests from libMLIR.so
-add_mlir_library(MLIRMeshTest
+add_mlir_library(MLIRShardTest
   TestOpLowering.cpp
-  TestReshardingSpmdization.cpp
+  TestReshardingPartition.cpp
   TestSimplifications.cpp
 
   EXCLUDE_FROM_LIBMLIR
   )
-mlir_target_link_libraries(MLIRMeshTest PUBLIC
-  MLIRMeshDialect
-  MLIRMeshTransforms
+mlir_target_link_libraries(MLIRShardTest PUBLIC
+  MLIRShardDialect
+  MLIRShardTransforms
   MLIRPass
   MLIRRewrite
   MLIRTransformUtils
diff --git a/mlir/test/lib/Dialect/Mesh/TestOpLowering.cpp b/mlir/test/lib/Dialect/Shard/TestOpLowering.cpp
similarity index 80%
rename from mlir/test/lib/Dialect/Mesh/TestOpLowering.cpp
rename to mlir/test/lib/Dialect/Shard/TestOpLowering.cpp
index dbae93b380f2b..43f3b3f239181 100644
--- a/mlir/test/lib/Dialect/Mesh/TestOpLowering.cpp
+++ b/mlir/test/lib/Dialect/Shard/TestOpLowering.cpp
@@ -7,7 +7,7 @@
 //===----------------------------------------------------------------------===//
 
 #include "mlir/Dialect/Arith/IR/Arith.h"
-#include "mlir/Dialect/Mesh/Transforms/Transforms.h"
+#include "mlir/Dialect/Shard/Transforms/Transforms.h"
 #include "mlir/Dialect/Utils/IndexingUtils.h"
 #include "mlir/IR/SymbolTable.h"
 #include "mlir/Pass/Pass.h"
@@ -24,17 +24,17 @@ struct TestAllSliceOpLoweringPass
   void runOnOperation() override {
     RewritePatternSet patterns(&getContext());
     SymbolTableCollection symbolTableCollection;
-    mesh::populateAllSliceOpLoweringPatterns(patterns, symbolTableCollection);
+    shard::populateAllSliceOpLoweringPatterns(patterns, symbolTableCollection);
     LogicalResult status =
         applyPatternsGreedily(getOperation(), std::move(patterns));
     (void)status;
     assert(succeeded(status) && "applyPatternsGreedily failed.");
   }
   void getDependentDialects(DialectRegistry &registry) const override {
-    mesh::registerAllSliceOpLoweringDialects(registry);
+    shard::registerAllSliceOpLoweringDialects(registry);
   }
   StringRef getArgument() const final {
-    return "test-mesh-all-slice-op-lowering";
+    return "test-grid-all-slice-op-lowering";
   }
   StringRef getDescription() const final {
     return "Test lowering of all-slice.";
@@ -48,21 +48,21 @@ struct TestMultiIndexOpLoweringPass
   void runOnOperation() override {
     RewritePatternSet patterns(&getContext());
     SymbolTableCollection symbolTableCollection;
-    mesh::populateProcessMultiIndexOpLoweringPatterns(patterns,
-                                                      symbolTableCollection);
+    shard::populateProcessMultiIndexOpLoweringPatterns(patterns,
+                                                       symbolTableCollection);
     LogicalResult status =
         applyPatternsGreedily(getOperation(), std::move(patterns));
     (void)status;
     assert(succeeded(status) && "applyPatternsGreedily failed.");
   }
   void getDependentDialects(DialectRegistry &registry) const override {
-    mesh::registerProcessMultiIndexOpLoweringDialects(registry);
+    shard::registerProcessMultiIndexOpLoweringDialects(registry);
   }
   StringRef getArgument() const final {
-    return "test-mesh-process-multi-index-op-lowering";
+    return "test-grid-process-multi-index-op-lowering";
   }
   StringRef getDescription() const final {
-    return "Test lowering of mesh.process_multi_index op.";
+    return "Test lowering of shard.process_multi_index op.";
   }
 };
 
diff --git a/mlir/test/lib/Dialect/Mesh/TestReshardingSpmdization.cpp b/mlir/test/lib/Dialect/Shard/TestReshardingPartition.cpp
similarity index 75%
rename from mlir/test/lib/Dialect/Mesh/TestReshardingSpmdization.cpp
rename to mlir/test/lib/Dialect/Shard/TestReshardingPartition.cpp
index 102e64de4bd1f..ac71ff60fc509 100644
--- a/mlir/test/lib/Dialect/Mesh/TestReshardingSpmdization.cpp
+++ b/mlir/test/lib/Dialect/Shard/TestReshardingPartition.cpp
@@ -7,9 +7,9 @@
 //===----------------------------------------------------------------------===//
 
 #include "mlir/Dialect/Func/IR/FuncOps.h"
-#include "mlir/Dialect/Mesh/IR/MeshOps.h"
-#include "mlir/Dialect/Mesh/Transforms/Spmdization.h"
 #include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
+#include "mlir/Dialect/Shard/IR/ShardOps.h"
+#include "mlir/Dialect/Shard/Transforms/Partition.h"
 #include "mlir/IR/BuiltinDialect.h"
 #include "mlir/IR/BuiltinOps.h"
 #include "mlir/IR/BuiltinTypeInterfaces.h"
@@ -22,11 +22,11 @@
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
 
 using namespace mlir;
-using namespace mlir::mesh;
+using namespace mlir::shard;
 
 namespace {
 
-struct TestMeshReshardingRewritePattern : OpRewritePattern<ShardOp> {
+struct TestReshardingRewritePattern : OpRewritePattern<ShardOp> {
   using OpRewritePattern<ShardOp>::OpRewritePattern;
 
   LogicalResult matchAndRewrite(ShardOp op,
@@ -36,18 +36,18 @@ struct TestMeshReshardingRewritePattern : OpRewritePattern<ShardOp> {
     }
 
     SymbolTableCollection symbolTable;
-    mesh::MeshOp mesh = symbolTable.lookupNearestSymbolFrom<mesh::MeshOp>(
-        op, cast<ShardingOp>(op.getSharding().getDefiningOp()).getMeshAttr());
+    shard::GridOp grid = symbolTable.lookupNearestSymbolFrom<shard::GridOp>(
+        op, cast<ShardingOp>(op.getSharding().getDefiningOp()).getGridAttr());
 
     bool foundUser = false;
     for (auto user : op->getUsers()) {
       if (auto targetShardOp = llvm::dyn_cast<ShardOp>(user)) {
         if (targetShardOp.getAnnotateForUsers() &&
-            mesh == symbolTable.lookupNearestSymbolFrom<mesh::MeshOp>(
+            grid == symbolTable.lookupNearestSymbolFrom<shard::GridOp>(
                         targetShardOp,
                         cast<ShardingOp>(
                             targetShardOp.getSharding().getDefiningOp())
-                            .getMeshAttr())) {
+                            .getGridAttr())) {
           foundUser = true;
           break;
         }
@@ -61,22 +61,22 @@ struct TestMeshReshardingRewritePattern : OpRewritePattern<ShardOp> {
     for (auto user : op->getUsers()) {
       auto targetShardOp = llvm::dyn_cast<ShardOp>(user);
       if (!targetShardOp || !targetShardOp.getAnnotateForUsers() ||
-          symbolTable.lookupNearestSymbolFrom<mesh::MeshOp>(
+          symbolTable.lookupNearestSymbolFrom<shard::GridOp>(
               targetShardOp,
               cast<ShardingOp>(targetShardOp.getSharding().getDefiningOp())
-                  .getMeshAttr()) != mesh) {
+                  .getGridAttr()) != grid) {
         continue;
       }
 
       ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
       ShapedType sourceShardShape =
-          shardShapedType(op.getResult().getType(), mesh, op.getSharding());
+          shardShapedType(op.getResult().getType(), grid, op.getSharding());
       TypedValue<ShapedType> sourceShard = cast<TypedValue<ShapedType>>(
           builder
               .create<UnrealizedConversionCastOp>(sourceShardShape, op.getSrc())
               ->getResult(0));
       TypedValue<ShapedType> targetShard =
-          reshard(builder, mesh, op, targetShardOp, sourceShard);
+          reshard(builder, grid, op, targetShardOp, sourceShard);
       Value newTargetUnsharded =
           builder
               .create<UnrealizedConversionCastOp>(
@@ -90,13 +90,13 @@ struct TestMeshReshardingRewritePattern : OpRewritePattern<ShardOp> {
   }
 };
 
-struct TestMeshReshardingPass
-    : public PassWrapper<TestMeshReshardingPass, OperationPass<ModuleOp>> {
-  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestMeshReshardingPass)
+struct TestReshardingPass
+    : public PassWrapper<TestReshardingPass, OperationPass<ModuleOp>> {
+  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestReshardingPass)
 
   void runOnOperation() override {
     RewritePatternSet patterns(&getContext());
-    patterns.insert<TestMeshReshardingRewritePattern>(&getContext());
+    patterns.insert<TestReshardingRewritePattern>(&getContext());
     if (failed(applyPatternsGreedily(getOperation().getOperation(),
                                      std::move(patterns)))) {
       return signalPassFailure();
@@ -107,18 +107,18 @@ struct TestMeshReshardingPass
     registry.insert<BuiltinDialect>();
   }
   StringRef getArgument() const final {
-    return "test-mesh-resharding-spmdization";
+    return "test-grid-resharding-partition";
   }
   StringRef getDescription() const final {
-    return "Test Mesh dialect resharding spmdization.";
+    return "Test Shard dialect resharding partition.";
   }
 };
 } // namespace
 
 namespace mlir {
 namespace test {
-void registerTestMeshReshardingSpmdizationPass() {
-  PassRegistration<TestMeshReshardingPass>();
+void registerTestReshardingPartitionPass() {
+  PassRegistration<TestReshardingPass>();
 }
 } // namespace test
 } // namespace mlir
diff --git a/mlir/test/lib/Dialect/Mesh/TestSimplifications.cpp b/mlir/test/lib/Dialect/Shard/TestSimplifications.cpp
similarity index 60%
rename from mlir/test/lib/Dialect/Mesh/TestSimplifications.cpp
rename to mlir/test/lib/Dialect/Shard/TestSimplifications.cpp
index 01e196d29f7a5..28852153f37f6 100644
--- a/mlir/test/lib/Dialect/Mesh/TestSimplifications.cpp
+++ b/mlir/test/lib/Dialect/Shard/TestSimplifications.cpp
@@ -7,8 +7,8 @@
 //===----------------------------------------------------------------------===//
 
 #include "mlir/Dialect/Arith/IR/Arith.h"
-#include "mlir/Dialect/Mesh/IR/MeshDialect.h"
-#include "mlir/Dialect/Mesh/Transforms/Simplifications.h"
+#include "mlir/Dialect/Shard/IR/ShardDialect.h"
+#include "mlir/Dialect/Shard/Transforms/Simplifications.h"
 #include "mlir/IR/SymbolTable.h"
 #include "mlir/Pass/Pass.h"
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
@@ -16,23 +16,23 @@
 using namespace mlir;
 
 namespace {
-struct TestMeshSimplificationsPass
-    : public PassWrapper<TestMeshSimplificationsPass, OperationPass<>> {
-  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestMeshSimplificationsPass)
+struct TestShardSimplificationsPass
+    : public PassWrapper<TestShardSimplificationsPass, OperationPass<>> {
+  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestShardSimplificationsPass)
 
   void runOnOperation() override;
   void getDependentDialects(DialectRegistry &registry) const override {
-    registry.insert<arith::ArithDialect, mesh::MeshDialect>();
+    registry.insert<arith::ArithDialect, shard::ShardDialect>();
   }
-  StringRef getArgument() const final { return "test-mesh-simplifications"; }
-  StringRef getDescription() const final { return "Test mesh simplifications"; }
+  StringRef getArgument() const final { return "test-grid-simplifications"; }
+  StringRef getDescription() const final { return "Test grid simplifications"; }
 };
 } // namespace
 
-void TestMeshSimplificationsPass::runOnOperation() {
+void TestShardSimplificationsPass::runOnOperation() {
   RewritePatternSet patterns(&getContext());
   SymbolTableCollection symbolTableCollection;
-  mesh::populateSimplificationPatterns(patterns, symbolTableCollection);
+  shard::populateSimplificationPatterns(patterns, symbolTableCollection);
   [[maybe_unused]] LogicalResult status =
       applyPatternsGreedily(getOperation(), std::move(patterns));
   assert(succeeded(status) && "Rewrite patterns application did not converge.");
@@ -40,8 +40,8 @@ void TestMeshSimplificationsPass::runOnOperation() {
 
 namespace mlir {
 namespace test {
-void registerTestMeshSimplificationsPass() {
-  PassRegistration<TestMeshSimplificationsPass>();
+void registerTestShardSimplificationsPass() {
+  PassRegistration<TestShardSimplificationsPass>();
 }
 } // namespace test
 } // namespace mlir
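Taken together, these renames also change the flags the test passes register with mlir-opt. A usage sketch, assuming a test-enabled mlir-opt build and a hypothetical input file test.mlir (the flags are the getArgument() strings shown in the diffs above):

  mlir-opt --test-grid-simplifications test.mlir
  mlir-opt --test-grid-resharding-partition test.mlir
  mlir-opt --test-grid-all-slice-op-lowering test.mlir
  mlir-opt --test-grid-process-multi-index-op-lowering test.mlir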
diff --git a/mlir/tools/mlir-opt/CMakeLists.txt b/mlir/tools/mlir-opt/CMakeLists.txt
index 26d7597347a8a..6958fe3001b89 100644
--- a/mlir/tools/mlir-opt/CMakeLists.txt
+++ b/mlir/tools/mlir-opt/CMakeLists.txt
@@ -29,7 +29,7 @@ if(MLIR_INCLUDE_TESTS)
     MLIRTestMathToVCIX
     MLIRMemRefTestPasses
     MLIRTestMemRefToLLVMWithTransforms
-    MLIRMeshTest
+    MLIRShardTest
     MLIRNVGPUTestPasses
     MLIRSCFTestPasses
     MLIRShapeTestPasses
diff --git a/mlir/tools/mlir-opt/mlir-opt.cpp b/mlir/tools/mlir-opt/mlir-opt.cpp
index 143a5e8e8f8dd..2c0975302e6a5 100644
--- a/mlir/tools/mlir-opt/mlir-opt.cpp
+++ b/mlir/tools/mlir-opt/mlir-opt.cpp
@@ -130,8 +130,8 @@ void registerTestIrdlTestDialectConversionPass();
 void registerTestMemRefDependenceCheck();
 void registerTestMemRefStrideCalculation();
 void registerTestMemRefToLLVMWithTransforms();
-void registerTestMeshReshardingSpmdizationPass();
-void registerTestMeshSimplificationsPass();
+void registerTestReshardingPartitionPass();
+void registerTestShardSimplificationsPass();
 void registerTestMultiBuffering();
 void registerTestNextAccessPass();
 void registerTestNVGPULowerings();
@@ -276,8 +276,8 @@ void registerTestPasses() {
   mlir::test::registerTestMemRefDependenceCheck();
   mlir::test::registerTestMemRefStrideCalculation();
   mlir::test::registerTestMemRefToLLVMWithTransforms();
-  mlir::test::registerTestMeshReshardingSpmdizationPass();
-  mlir::test::registerTestMeshSimplificationsPass();
+  mlir::test::registerTestReshardingPartitionPass();
+  mlir::test::registerTestShardSimplificationsPass();
   mlir::test::registerTestMultiBuffering();
   mlir::test::registerTestNextAccessPass();
   mlir::test::registerTestNVGPULowerings();
diff --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
index 9ec7c51da4065..6c528f8f7b6bd 100644
--- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
+++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
@@ -3056,14 +3056,14 @@ cc_library(
 )
 
 ##---------------------------------------------------------------------------##
-# Mesh Dialect
+# Shard Dialect
 ##---------------------------------------------------------------------------##
 
 td_library(
-    name = "MeshTdFiles",
+    name = "ShardTdFiles",
     srcs = [
-        "include/mlir/Dialect/Mesh/IR/MeshBase.td",
-        "include/mlir/Dialect/Mesh/IR/MeshOps.td",
+        "include/mlir/Dialect/Shard/IR/ShardBase.td",
+        "include/mlir/Dialect/Shard/IR/ShardOps.td",
     ],
     includes = ["include"],
     deps = [
@@ -3075,92 +3075,92 @@ td_library(
 )
 
 gentbl_cc_library(
-    name = "MeshIncGen",
+    name = "ShardIncGen",
     tbl_outs = {
-        "include/mlir/Dialect/Mesh/IR/MeshOps.h.inc": [
+        "include/mlir/Dialect/Shard/IR/ShardOps.h.inc": [
             "-gen-op-decls",
-            "-dialect=mesh",
+            "-dialect=shard",
         ],
-        "include/mlir/Dialect/Mesh/IR/MeshOps.cpp.inc": [
+        "include/mlir/Dialect/Shard/IR/ShardOps.cpp.inc": [
             "-gen-op-defs",
-            "-dialect=mesh",
+            "-dialect=shard",
         ],
-        "include/mlir/Dialect/Mesh/IR/MeshDialect.h.inc": [
+        "include/mlir/Dialect/Shard/IR/ShardDialect.h.inc": [
             "-gen-dialect-decls",
-            "-dialect=mesh",
+            "-dialect=shard",
         ],
-        "include/mlir/Dialect/Mesh/IR/MeshDialect.cpp.inc": [
+        "include/mlir/Dialect/Shard/IR/ShardDialect.cpp.inc": [
             "-gen-dialect-defs",
-            "-dialect=mesh",
+            "-dialect=shard",
         ],
-        "include/mlir/Dialect/Mesh/IR/MeshEnums.h.inc": [
+        "include/mlir/Dialect/Shard/IR/ShardEnums.h.inc": [
             "-gen-enum-decls",
-            "-dialect=mesh",
+            "-dialect=shard",
         ],
-        "include/mlir/Dialect/Mesh/IR/MeshEnums.cpp.inc": [
+        "include/mlir/Dialect/Shard/IR/ShardEnums.cpp.inc": [
             "-gen-enum-defs",
-            "-dialect=mesh",
+            "-dialect=shard",
         ],
-        "include/mlir/Dialect/Mesh/IR/MeshAttributes.h.inc": [
+        "include/mlir/Dialect/Shard/IR/ShardAttributes.h.inc": [
             "-gen-attrdef-decls",
-            "-dialect=mesh",
+            "-dialect=shard",
         ],
-        "include/mlir/Dialect/Mesh/IR/MeshAttributes.cpp.inc": [
+        "include/mlir/Dialect/Shard/IR/ShardAttributes.cpp.inc": [
             "-gen-attrdef-defs",
-            "-dialect=mesh",
+            "-dialect=shard",
         ],
-        "include/mlir/Dialect/Mesh/IR/MeshTypes.h.inc": [
+        "include/mlir/Dialect/Shard/IR/ShardTypes.h.inc": [
             "-gen-typedef-decls",
-            "-typedefs-dialect=mesh",
+            "-typedefs-dialect=shard",
         ],
-        "include/mlir/Dialect/Mesh/IR/MeshTypes.cpp.inc": [
+        "include/mlir/Dialect/Shard/IR/ShardTypes.cpp.inc": [
             "-gen-typedef-defs",
-            "-typedefs-dialect=mesh",
+            "-typedefs-dialect=shard",
         ],
     },
     tblgen = ":mlir-tblgen",
-    td_file = "include/mlir/Dialect/Mesh/IR/MeshOps.td",
+    td_file = "include/mlir/Dialect/Shard/IR/ShardOps.td",
     deps = [
-        ":MeshTdFiles",
+        ":ShardTdFiles",
         ":ShapeOpsTdFiles",
     ],
 )
 
 gentbl_cc_library(
-    name = "MeshShardingInterfaceIncGen",
+    name = "ShardingInterfaceIncGen",
     tbl_outs = {
-        "include/mlir/Dialect/Mesh/Interfaces/ShardingInterface.h.inc": ["-gen-op-interface-decls"],
-        "include/mlir/Dialect/Mesh/Interfaces/ShardingInterface.cpp.inc": ["-gen-op-interface-defs"],
+        "include/mlir/Dialect/Shard/Interfaces/ShardingInterface.h.inc": ["-gen-op-interface-decls"],
+        "include/mlir/Dialect/Shard/Interfaces/ShardingInterface.cpp.inc": ["-gen-op-interface-defs"],
     },
     tblgen = ":mlir-tblgen",
-    td_file = "include/mlir/Dialect/Mesh/Interfaces/ShardingInterface.td",
+    td_file = "include/mlir/Dialect/Shard/Interfaces/ShardingInterface.td",
     deps = [":OpBaseTdFiles"],
 )
 
 cc_library(
-    name = "MeshShardingInterface",
-    srcs = ["lib/Dialect/Mesh/Interfaces/ShardingInterface.cpp"],
+    name = "ShardingInterface",
+    srcs = ["lib/Dialect/Shard/Interfaces/ShardingInterface.cpp"],
     hdrs = [
-        "include/mlir/Dialect/Mesh/Interfaces/ShardingInterface.h",
-        "include/mlir/Dialect/Mesh/Interfaces/ShardingInterfaceImpl.h",
+        "include/mlir/Dialect/Shard/Interfaces/ShardingInterface.h",
+        "include/mlir/Dialect/Shard/Interfaces/ShardingInterfaceImpl.h",
     ],
     includes = ["include"],
     deps = [
         ":DialectUtils",
         ":IR",
-        ":MeshDialect",
-        ":MeshShardingInterfaceIncGen",
+        ":ShardDialect",
+        ":ShardingInterfaceIncGen",
         ":Support",
         "//llvm:Support",
     ],
 )
 
 cc_library(
-    name = "MeshDialect",
-    srcs = ["lib/Dialect/Mesh/IR/MeshOps.cpp"],
+    name = "ShardDialect",
+    srcs = ["lib/Dialect/Shard/IR/ShardOps.cpp"],
     hdrs = [
-        "include/mlir/Dialect/Mesh/IR/MeshDialect.h",
-        "include/mlir/Dialect/Mesh/IR/MeshOps.h",
+        "include/mlir/Dialect/Shard/IR/ShardDialect.h",
+        "include/mlir/Dialect/Shard/IR/ShardOps.h",
     ],
     includes = ["include"],
     deps = [
@@ -3171,7 +3171,7 @@ cc_library(
         ":IR",
         ":InferTypeOpInterface",
         ":InliningUtils",
-        ":MeshIncGen",
+        ":ShardIncGen",
         ":SideEffectInterfaces",
         ":Support",
         ":ViewLikeInterface",
@@ -3180,23 +3180,23 @@ cc_library(
 )
 
 gentbl_cc_library(
-    name = "MeshTransformsPassIncGen",
-    tbl_outs = {"include/mlir/Dialect/Mesh/Transforms/Passes.h.inc": [
+    name = "ShardTransformsPassIncGen",
+    tbl_outs = {"include/mlir/Dialect/Shard/Transforms/Passes.h.inc": [
         "-gen-pass-decls",
-        "-name=Mesh",
+        "-name=Shard",
     ]},
     tblgen = ":mlir-tblgen",
-    td_file = "include/mlir/Dialect/Mesh/Transforms/Passes.td",
+    td_file = "include/mlir/Dialect/Shard/Transforms/Passes.td",
     deps = [":PassBaseTdFiles"],
 )
 
 cc_library(
-    name = "MeshTransforms",
+    name = "ShardTransforms",
     srcs = glob([
-        "lib/Dialect/Mesh/Transforms/*.cpp",
-        "lib/Dialect/Mesh/Transforms/*.h",
+        "lib/Dialect/Shard/Transforms/*.cpp",
+        "lib/Dialect/Shard/Transforms/*.h",
     ]),
-    hdrs = glob(["include/mlir/Dialect/Mesh/Transforms/*.h"]),
+    hdrs = glob(["include/mlir/Dialect/Shard/Transforms/*.h"]),
     includes = ["include"],
     deps = [
         ":AffineDialect",
@@ -3209,9 +3209,9 @@ cc_library(
         ":FuncDialect",
         ":FunctionInterfaces",
         ":IR",
-        ":MeshDialect",
-        ":MeshShardingInterface",
-        ":MeshTransformsPassIncGen",
+        ":ShardDialect",
+        ":ShardingInterface",
+        ":ShardTransformsPassIncGen",
         ":Pass",
         ":Support",
         ":TensorDialect",
@@ -3221,11 +3221,11 @@ cc_library(
 )
 
 cc_library(
-    name = "MeshToMPIConversion",
+    name = "ShardToMPIConversion",
     srcs = glob([
-        "lib/Conversion/MeshToMPI/*.cpp",
+        "lib/Conversion/ShardToMPI/*.cpp",
     ]),
-    hdrs = glob(["include/mlir/Conversion/MeshToMPI/*.h"]),
+    hdrs = glob(["include/mlir/Conversion/ShardToMPI/*.h"]),
     includes = ["include"],
     deps = [
         ":AffineDialect",
@@ -3240,8 +3240,8 @@ cc_library(
         ":LinalgDialect",
         ":MPIDialect",
         ":MemRefDialect",
-        ":MeshDialect",
-        ":MeshTransforms",
+        ":ShardDialect",
+        ":ShardTransforms",
         ":Pass",
         ":SCFDialect",
         ":Support",
@@ -3988,7 +3988,7 @@ cc_library(
         ":MemRefToEmitC",
         ":MemRefToLLVM",
         ":MemRefToSPIRV",
-        ":MeshToMPIConversion",
+        ":ShardToMPIConversion",
         ":NVGPUToNVVM",
         ":NVVMToLLVM",
         ":OpenACCToSCF",
@@ -4522,7 +4522,7 @@ cc_library(
         ":FuncDialect",
         ":IR",
         ":InliningUtils",
-        ":MeshShardingInterface",
+        ":ShardingInterface",
     ],
 )
 
@@ -4621,7 +4621,7 @@ cc_library(
         ":MemRefToEmitC",
         ":MemRefToLLVM",
         ":MemRefTransformOps",
-        ":MeshDialect",
+        ":ShardDialect",
         ":NVGPUTransformOps",
         ":NVVMTarget",
         ":NVVMToLLVM",
@@ -7194,7 +7194,7 @@ cc_library(
     includes = ["include"],
     deps = [
         ":IR",
-        ":MeshShardingInterface",
+        ":ShardingInterface",
         ":TensorDialect",
         "//llvm:Support",
     ],
@@ -9019,8 +9019,8 @@ cc_library(
         ":MemRefToSPIRV",
         ":MemRefTransformOps",
         ":MemRefTransforms",
-        ":MeshDialect",
-        ":MeshTransforms",
+        ":ShardDialect",
+        ":ShardTransforms",
         ":NVGPUDialect",
         ":NVGPUPassIncGen",
         ":NVGPUToNVVM",
@@ -9120,7 +9120,7 @@ cc_binary(
         "//mlir/test:TestMath",
         "//mlir/test:TestMathToVCIX",
         "//mlir/test:TestMemRef",
-        "//mlir/test:TestMesh",
+        "//mlir/test:TestShard",
         "//mlir/test:TestNVGPU",
         "//mlir/test:TestPDLL",
         "//mlir/test:TestPass",
@@ -9182,7 +9182,7 @@ cc_binary(
         "//mlir/test:TestMathToVCIX",
         "//mlir/test:TestMemRef",
         "//mlir/test:TestMemRefToLLVMWithTransforms",
-        "//mlir/test:TestMesh",
+        "//mlir/test:TestShard",
         "//mlir/test:TestNVGPU",
         "//mlir/test:TestPDLL",
         "//mlir/test:TestPass",
@@ -10548,7 +10548,7 @@ cc_library(
         ":LinalgStructuredOpsIncGen",
         ":MathDialect",
         ":MemRefDialect",
-        ":MeshShardingInterface",
+        ":ShardingInterface",
         ":Parser",
         ":SCFDialect",
         ":SideEffectInterfaces",
@@ -10699,9 +10699,9 @@ cc_library(
         ":MathDialect",
         ":MemRefDialect",
         ":MemRefTransforms",
-        ":MeshDialect",
-        ":MeshShardingInterface",
-        ":MeshTransforms",
+        ":ShardDialect",
+        ":ShardingInterface",
+        ":ShardTransforms",
         ":Pass",
         ":RuntimeVerifiableOpInterface",
         ":SCFDialect",
@@ -11198,8 +11198,8 @@ cc_library(
         ":InferTypeOpInterface",
         ":InliningUtils",
         ":LoopLikeInterface",
-        ":MeshDialect",
-        ":MeshShardingInterface",
+        ":ShardDialect",
+        ":ShardingInterface",
         ":Pass",
         ":QuantOps",
         ":SideEffectInterfaces",
@@ -12141,7 +12141,7 @@ cc_library(
         ":FuncTransforms",
         ":IR",
         ":MemRefDialect",
-        ":MeshShardingInterface",
+        ":ShardingInterface",
         ":Pass",
         ":SideEffectInterfaces",
         ":TensorDialect",
diff --git a/utils/bazel/llvm-project-overlay/mlir/test/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/test/BUILD.bazel
index 95e3ee4df7bc5..e7770fcc9eabd 100644
--- a/utils/bazel/llvm-project-overlay/mlir/test/BUILD.bazel
+++ b/utils/bazel/llvm-project-overlay/mlir/test/BUILD.bazel
@@ -903,8 +903,8 @@ cc_library(
 )
 
 cc_library(
-    name = "TestMesh",
-    srcs = glob(["lib/Dialect/Mesh/**/*.cpp"]),
+    name = "TestShard",
+    srcs = glob(["lib/Dialect/Shard/**/*.cpp"]),
     includes = ["lib/Dialect/Test"],
     deps = [
         ":TestDialect",
@@ -912,8 +912,8 @@ cc_library(
         "//mlir:DialectUtils",
         "//mlir:FuncDialect",
         "//mlir:IR",
-        "//mlir:MeshDialect",
-        "//mlir:MeshTransforms",
+        "//mlir:ShardDialect",
+        "//mlir:ShardTransforms",
         "//mlir:Pass",
         "//mlir:SPIRVDialect",
         "//mlir:Support",

>From 81327bd1e64e143a760d4109addfa6ceca72e277 Mon Sep 17 00:00:00 2001
From: "Schlimbach, Frank" <frank.schlimbach at intel.com>
Date: Thu, 24 Jul 2025 11:29:19 +0200
Subject: [PATCH 2/2] mention rename mesh->shard in shard doc

---
 mlir/docs/Dialects/Shard.md | 8 +++++---
 1 file changed, 5 insertions(+), 3 deletions(-)

diff --git a/mlir/docs/Dialects/Shard.md b/mlir/docs/Dialects/Shard.md
index 714b340db4cde..153231d156edc 100644
--- a/mlir/docs/Dialects/Shard.md
+++ b/mlir/docs/Dialects/Shard.md
@@ -1,8 +1,10 @@
 # 'shard' Dialect
 
-The `shard` dialect contains a set of attributes, operations and interfaces that
-are useful for representing sharding and communication on a device grid
-cluster.
+This dialect contains a set of attributes, operations and interfaces that
+are useful for representing sharding and communication on a device grid.
+
+It was originally introduced under the name 'mesh' but was later renamed
+to 'shard' to better reflect its purpose.
 
 [TOC]
 
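To make the renamed surface syntax concrete, a minimal sharding-annotation snippet in the new dialect, assembled from the ops exercised in the tests above (the grid name, shape, and tensor type are illustrative):

  shard.grid @grid0(shape = 4)

  func.func @example() -> tensor<8x16xf32> {
    %t = tensor.empty() : tensor<8x16xf32>
    %s = shard.sharding @grid0 split_axes = [[0]] : !shard.sharding
    %annotated = shard.shard %t to %s : tensor<8x16xf32>
    return %annotated : tensor<8x16xf32>
  }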


