[Mlir-commits] [llvm] [mlir] [mlir][mesh] Add all-scatter operation (PR #81218)

Boian Petkantchin llvmlistbot at llvm.org
Tue Feb 13 14:35:20 PST 2024


https://github.com/sogartar updated https://github.com/llvm/llvm-project/pull/81218

>From 5d1e4cdaf9b5722d8322adaf36b4dce1962381e3 Mon Sep 17 00:00:00 2001
From: Boian Petkantchin <boian.petkantchin at amd.com>
Date: Wed, 7 Feb 2024 08:05:53 -0800
Subject: [PATCH 1/4] [mlir][mesh] Add all-scatter operation

This op is the inverse of all-gather. It is useful to have an explicit
concise representation instead of having a blob of slicing logic.

Add lowering for the op that slices from the tensor based on the in-group
process index.

Make resharding generate an all-scatter instead of inserting the slicing logic
directly.
---
 mlir/include/mlir/Dialect/Affine/Utils.h      |   7 +
 mlir/include/mlir/Dialect/Arith/Utils/Utils.h |  17 +++
 mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h   |   6 +
 mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td  |  63 ++++++++
 .../mlir/Dialect/Mesh/Transforms/Transforms.h |  19 ++-
 mlir/include/mlir/IR/Builders.h               |   3 +
 mlir/lib/Dialect/Affine/Utils/Utils.cpp       |  27 ++++
 mlir/lib/Dialect/Arith/Utils/Utils.cpp        |  19 +++
 mlir/lib/Dialect/Mesh/IR/MeshOps.cpp          |  64 ++++++++-
 .../Dialect/Mesh/Transforms/CMakeLists.txt    |   2 +
 .../Mesh/Transforms/Simplifications.cpp       |  16 +--
 .../Dialect/Mesh/Transforms/Spmdization.cpp   |  84 ++---------
 .../Dialect/Mesh/Transforms/Transforms.cpp    | 136 ++++++++++++++++--
 .../Mesh/Transforms/TransformsDetail.h        |  35 +++++
 mlir/lib/IR/Builders.cpp                      |  18 +++
 .../Dialect/Mesh/all-scatter-op-lowering.mlir |  74 ++++++++++
 mlir/test/Dialect/Mesh/canonicalization.mlir  |  13 ++
 mlir/test/Dialect/Mesh/invalid.mlir           |  52 +++++++
 mlir/test/Dialect/Mesh/ops.mlir               |  24 ++++
 .../Dialect/Mesh/resharding-spmdization.mlir  |  31 +---
 mlir/test/Dialect/Mesh/spmdization.mlir       |  14 +-
 mlir/test/lib/Dialect/Mesh/CMakeLists.txt     |   2 +-
 mlir/test/lib/Dialect/Mesh/TestOpLowering.cpp |  79 ++++++++++
 .../Mesh/TestProcessMultiIndexOpLowering.cpp  |  54 -------
 mlir/tools/mlir-opt/mlir-opt.cpp              |   4 +-
 .../llvm-project-overlay/mlir/BUILD.bazel     |   2 +
 26 files changed, 673 insertions(+), 192 deletions(-)
 create mode 100644 mlir/lib/Dialect/Mesh/Transforms/TransformsDetail.h
 create mode 100644 mlir/test/Dialect/Mesh/all-scatter-op-lowering.mlir
 create mode 100644 mlir/test/lib/Dialect/Mesh/TestOpLowering.cpp
 delete mode 100644 mlir/test/lib/Dialect/Mesh/TestProcessMultiIndexOpLowering.cpp

diff --git a/mlir/include/mlir/Dialect/Affine/Utils.h b/mlir/include/mlir/Dialect/Affine/Utils.h
index b3ccbff3002fb1..441ce1952210bb 100644
--- a/mlir/include/mlir/Dialect/Affine/Utils.h
+++ b/mlir/include/mlir/Dialect/Affine/Utils.h
@@ -15,12 +15,14 @@
 
 #include "mlir/Dialect/Affine/Analysis/AffineAnalysis.h"
 #include "mlir/Dialect/Affine/IR/AffineOps.h"
+#include "mlir/IR/OpDefinition.h"
 #include <optional>
 
 namespace mlir {
 class DominanceInfo;
 class Operation;
 class PostDominanceInfo;
+class ImplicitLocOpBuilder;
 
 namespace func {
 class FuncOp;
@@ -309,6 +311,11 @@ DivModValue getDivMod(OpBuilder &b, Location loc, Value lhs, Value rhs);
 FailureOr<SmallVector<Value>> delinearizeIndex(OpBuilder &b, Location loc,
                                                Value linearIndex,
                                                ArrayRef<Value> basis);
+// Generate IR that extracts the linear index from a multi-index according to
+// a shape.
+OpFoldResult linearIndexFromShape(ArrayRef<OpFoldResult> multiIndex,
+                                  ArrayRef<OpFoldResult> shape,
+                                  ImplicitLocOpBuilder &builder);
 
 /// Ensure that all operations that could be executed after `start`
 /// (noninclusive) and prior to `memOp` (e.g. on a control flow/op path
diff --git a/mlir/include/mlir/Dialect/Arith/Utils/Utils.h b/mlir/include/mlir/Dialect/Arith/Utils/Utils.h
index 402bd196f0736a..2111a7c5810294 100644
--- a/mlir/include/mlir/Dialect/Arith/Utils/Utils.h
+++ b/mlir/include/mlir/Dialect/Arith/Utils/Utils.h
@@ -20,6 +20,7 @@
 #include "mlir/IR/Matchers.h"
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/IR/Value.h"
+#include "llvm/ADT/ArrayRef.h"
 
 namespace mlir {
 
@@ -81,6 +82,22 @@ struct ArithBuilder {
   OpBuilder &b;
   Location loc;
 };
+
+namespace arith {
+
+// Build the product of a sequence.
+// If values = (v0, v1, ..., vn) then the returned
+// value is v0 * v1 * ... * vn.
+// All values must have the same type.
+//
+// The version without `resultType` requires `values` to contain at least one
+// element; the result has the same type as the elements of `values`.
+// The version with `resultType` returns 1 of type `resultType` if `values`
+// is empty.
+Value createProduct(OpBuilder &builder, Location loc, ArrayRef<Value> values);
+Value createProduct(OpBuilder &builder, Location loc, ArrayRef<Value> values,
+                    Type resultType);
+} // namespace arith
 } // namespace mlir
 
 #endif // MLIR_DIALECT_ARITH_UTILS_UTILS_H
diff --git a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h
index 9154e6fd803102..fb9425b96e68e2 100644
--- a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h
+++ b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h
@@ -92,6 +92,12 @@ int64_t collectiveProcessGroupSize(MeshAxesRange &&meshAxes,
   return res;
 }
 
+template <typename MeshAxesRange>
+int64_t collectiveProcessGroupSize(MeshAxesRange &&meshAxes, MeshOp mesh) {
+  return collectiveProcessGroupSize(std::forward<MeshAxesRange>(meshAxes),
+                                    mesh.getShape());
+}
+
 // Get the size of a sharded dimension.
 inline int64_t shardDimension(int64_t dimSize, int64_t shardCount) {
   if (ShapedType::isDynamic(dimSize) || ShapedType::isDynamic(shardCount))
diff --git a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td
index da372706ec724c..7792aac784d4be 100644
--- a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td
+++ b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td
@@ -96,6 +96,7 @@ def Mesh_MeshShapeOp : Mesh_Op<"mesh_shape", [
 
   let builders = [
     OpBuilder<(ins "::mlir::mesh::MeshOp":$mesh)>,
+    OpBuilder<(ins "::mlir::mesh::MeshOp":$mesh, "ArrayRef<MeshAxis>":$axes)>,
     OpBuilder<(ins "StringRef":$mesh, "ArrayRef<MeshAxis>":$axes)>
   ];
 }
@@ -341,6 +342,68 @@ def Mesh_AllReduceOp : Mesh_CollectiveCommunicationOpBase<"all_reduce", [
   let hasCanonicalizer = 1;
 }
 
+def Mesh_AllScatterOp : Mesh_CollectiveCommunicationOpBase<"all_scatter", [
+    Pure,
+    SameOperandsAndResultElementType,
+    SameOperandsAndResultRank
+  ]> {
+  let summary = "All-scatter over a device mesh. This is the inverse of all-gather.";
+  let description = [{
+    Scatter along the `scatter_axis` tensor axis.
+    This operation can be thought of as the inverse of all-gather.
+    Technically, it is not required that all processes have the same input tensor.
+    Each process will slice a piece of its local tensor based on its in-group device index.
+    The operation does not communicate data between devices.
+
+    Example:
+    ```mlir
+    mesh.mesh @mesh0(shape = 2x2)
+    ...
+    %1 = mesh.all_scatter %0 on @mesh0 mesh_axes = [1] scatter_axis = 1
+      : tensor<2x4xi8> -> tensor<2x2xi8>
+    ```
+    Input:
+    ```
+    +-------------+
+    |  1  2  5  6 | <- devices (0, 0) and (0, 1)
+    |  3  4  7  8 |
+    +-------------+
+    |  9 10 13 14 | <- devices (1, 0) and (1, 1)
+    | 11 12 15 16 |
+    +-------------+
+    ```
+    Result:
+    ```
+    scatter tensor
+    axis 1
+    ------------>
+                     +-------+-------+
+    device (0, 0) -> |  1  2 |  5  6 | <- device (0, 1)
+                     |  3  4 |  7  8 |
+                     +-------+-------+
+    device (1, 0) -> |  9 10 | 13 14 | <- device (1, 1)
+                     | 11 12 | 15 16 |
+                     +-------+-------+
+    ```
+  }];
+  let arguments = !con(commonArgs, (ins
+    AnyNon0RankedTensor:$input,
+    IndexAttr:$scatter_axis
+  ));
+  let results = (outs
+    AnyNon0RankedTensor:$result
+  );
+  let assemblyFormat = [{
+    $input `on` $mesh (`mesh_axes` `=` $mesh_axes^)? `scatter_axis` `=` $scatter_axis
+    attr-dict `:` type($input) `->` type($result)
+  }];
+  let hasCanonicalizer = 1;
+  let builders = [
+    OpBuilder<(ins "Value":$input, "MeshOp":$mesh, "ArrayRef<MeshAxis>":$meshAxes, "int64_t":$scatterAxis)>,
+    OpBuilder<(ins "Type":$result_type, "Value":$input, "StringRef":$mesh, "ArrayRef<MeshAxis>":$meshAxes, "int64_t":$scatterAxis)>
+  ];
+}
+
 def Mesh_AllToAllOp : Mesh_CollectiveCommunicationOpBase<"all_to_all", [
     Pure,
     SameOperandsAndResultElementType,
diff --git a/mlir/include/mlir/Dialect/Mesh/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Mesh/Transforms/Transforms.h
index 10a965daac71b9..d398bdd65330b1 100644
--- a/mlir/include/mlir/Dialect/Mesh/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Mesh/Transforms/Transforms.h
@@ -9,17 +9,34 @@
 #ifndef MLIR_DIALECT_MESH_TRANSFORMS_TRANSFORMS_H
 #define MLIR_DIALECT_MESH_TRANSFORMS_TRANSFORMS_H
 
+#include "mlir/Dialect/Mesh/IR/MeshOps.h"
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/Value.h"
+#include "mlir/Support/LLVM.h"
+
 namespace mlir {
 class RewritePatternSet;
 class SymbolTableCollection;
 class DialectRegistry;
+class ImplicitLocOpBuilder;
 namespace mesh {
 
 void processMultiIndexOpLoweringPopulatePatterns(
     RewritePatternSet &patterns, SymbolTableCollection &symbolTableCollection);
-
 void processMultiIndexOpLoweringRegisterDialects(DialectRegistry &registry);
 
+void allScatterOpLoweringPopulatePatterns(
+    RewritePatternSet &patterns, SymbolTableCollection &symbolTableCollection);
+void allScatterOpLoweringRegisterDialects(DialectRegistry &registry);
+
+void populateAllOpLoweringPatterns(
+    RewritePatternSet &patterns, SymbolTableCollection &symbolTableCollection);
+void registerAllOpLoweringDialects(DialectRegistry &registry);
+
+TypedValue<IndexType>
+createCollectiveProcessGroupSize(MeshOp mesh, ArrayRef<MeshAxis> axes,
+                                 ImplicitLocOpBuilder &builder);
+
 } // namespace mesh
 } // namespace mlir
 
diff --git a/mlir/include/mlir/IR/Builders.h b/mlir/include/mlir/IR/Builders.h
index 2fe1495b2b593b..43b6d2b3841690 100644
--- a/mlir/include/mlir/IR/Builders.h
+++ b/mlir/include/mlir/IR/Builders.h
@@ -118,6 +118,9 @@ class Builder {
   // supports boolean, integer, and 16-/32-/64-bit float types, and vector or
   // ranked tensor of them. Returns null attribute otherwise.
   TypedAttr getZeroAttr(Type type);
+  // Returns a 1-valued attribute of the given `type`.
+  // Type constraints are the same as `getZeroAttr`.
+  TypedAttr getOneAttr(Type type);
 
   // Convenience methods for fixed types.
   FloatAttr getF16FloatAttr(float value);
diff --git a/mlir/lib/Dialect/Affine/Utils/Utils.cpp b/mlir/lib/Dialect/Affine/Utils/Utils.cpp
index 4d4adb94a9fc8d..41bcd2b9f33e6b 100644
--- a/mlir/lib/Dialect/Affine/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/Affine/Utils/Utils.cpp
@@ -20,9 +20,11 @@
 #include "mlir/Dialect/Arith/Utils/Utils.h"
 #include "mlir/Dialect/Func/IR/FuncOps.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/Utils/IndexingUtils.h"
 #include "mlir/IR/AffineExprVisitor.h"
 #include "mlir/IR/Dominance.h"
 #include "mlir/IR/IRMapping.h"
+#include "mlir/IR/ImplicitLocOpBuilder.h"
 #include "mlir/IR/IntegerSet.h"
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
 #include <optional>
@@ -1869,3 +1871,28 @@ mlir::affine::delinearizeIndex(OpBuilder &b, Location loc, Value linearIndex,
   results.push_back(residual);
   return results;
 }
+
+OpFoldResult
+mlir::affine::linearIndexFromShape(ArrayRef<OpFoldResult> multiIndex,
+                                   ArrayRef<OpFoldResult> shape,
+                                   ImplicitLocOpBuilder &builder) {
+  assert(multiIndex.size() == shape.size());
+  SmallVector<AffineExpr> shapeAffine;
+  for (size_t i = 0; i < shape.size(); ++i) {
+    shapeAffine.push_back(getAffineSymbolExpr(i, builder.getContext()));
+  }
+
+  SmallVector<AffineExpr> stridesAffine = computeStrides(shapeAffine);
+  SmallVector<OpFoldResult> strides;
+  strides.reserve(stridesAffine.size());
+  llvm::transform(stridesAffine, std::back_inserter(strides),
+                  [&builder, &shape](AffineExpr strideExpr) {
+                    return affine::makeComposedFoldedAffineApply(
+                        builder, builder.getLoc(), strideExpr, shape);
+                  });
+
+  auto &&[linearIndexExpr, multiIndexAndStrides] = computeLinearIndex(
+      OpFoldResult(builder.getIndexAttr(0)), strides, multiIndex);
+  return affine::makeComposedFoldedAffineApply(
+      builder, builder.getLoc(), linearIndexExpr, multiIndexAndStrides);
+}
diff --git a/mlir/lib/Dialect/Arith/Utils/Utils.cpp b/mlir/lib/Dialect/Arith/Utils/Utils.cpp
index bf274d4ae27ed8..999cdbc5c10008 100644
--- a/mlir/lib/Dialect/Arith/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/Arith/Utils/Utils.cpp
@@ -15,6 +15,7 @@
 #include "mlir/Dialect/Complex/IR/Complex.h"
 #include "mlir/IR/ImplicitLocOpBuilder.h"
 #include "llvm/ADT/SmallBitVector.h"
+#include <numeric>
 
 using namespace mlir;
 
@@ -262,3 +263,21 @@ Value ArithBuilder::slt(Value lhs, Value rhs) {
 Value ArithBuilder::select(Value cmp, Value lhs, Value rhs) {
   return b.create<arith::SelectOp>(loc, cmp, lhs, rhs);
 }
+
+namespace mlir::arith {
+
+Value createProduct(OpBuilder &builder, Location loc, ArrayRef<Value> values) {
+  return createProduct(builder, loc, values, values.front().getType());
+}
+
+Value createProduct(OpBuilder &builder, Location loc, ArrayRef<Value> values,
+                    Type resultType) {
+  Value one = builder.create<ConstantOp>(loc, resultType,
+                                         builder.getOneAttr(resultType));
+  ArithBuilder arithBuilder(builder, loc);
+  return std::accumulate(
+      values.begin(), values.end(), one,
+      [&arithBuilder](Value acc, Value v) { return arithBuilder.mul(acc, v); });
+}
+
+} // namespace mlir::arith
\ No newline at end of file
diff --git a/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp b/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
index a65b8f2e5a2376..762725d2c56e66 100644
--- a/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
+++ b/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
@@ -252,14 +252,20 @@ MeshShapeOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
 
 void MeshShapeOp::build(OpBuilder &odsBuilder, OperationState &odsState,
                         MeshOp mesh) {
+  build(odsBuilder, odsState, mesh, SmallVector<MeshAxis>());
+}
+
+void MeshShapeOp::build(OpBuilder &odsBuilder, OperationState &odsState,
+                        MeshOp mesh, ArrayRef<MeshAxis> axes) {
   build(odsBuilder, odsState,
-        SmallVector<Type>(mesh.getRank(), odsBuilder.getIndexType()),
-        mesh.getSymName(),
-        MeshAxesAttr::get(odsBuilder.getContext(), SmallVector<MeshAxis>()));
+        SmallVector<Type>(axes.empty() ? mesh.getRank() : axes.size(),
+                          odsBuilder.getIndexType()),
+        mesh.getSymName(), MeshAxesAttr::get(odsBuilder.getContext(), axes));
 }
 
 void MeshShapeOp::build(OpBuilder &odsBuilder, OperationState &odsState,
                         StringRef mesh, ArrayRef<MeshAxis> axes) {
+  assert(!axes.empty());
   build(odsBuilder, odsState,
         SmallVector<Type>(axes.size(), odsBuilder.getIndexType()), mesh,
         MeshAxesAttr::get(odsBuilder.getContext(), axes));
@@ -590,6 +596,22 @@ static LogicalResult verifyScatterOperandAndResultShape(
   return success();
 }
 
+static RankedTensorType scatterResultType(Type operandType, MeshOp mesh,
+                                          ArrayRef<MeshAxis> meshAxes,
+                                          int64_t scatterAxis) {
+  RankedTensorType operandRankedTensorType =
+      operandType.cast<RankedTensorType>();
+  DimensionSize operandScatterAxisSize =
+      operandRankedTensorType.getShape()[scatterAxis];
+  SmallVector<int64_t> resultShape =
+      llvm::to_vector(operandRankedTensorType.getShape());
+
+  resultShape[scatterAxis] =
+      operandScatterAxisSize /
+      DimensionSize(collectiveProcessGroupSize(meshAxes, mesh));
+  return operandRankedTensorType.clone(resultShape);
+}
+
 //===----------------------------------------------------------------------===//
 // mesh.all_gather op
 //===----------------------------------------------------------------------===//
@@ -625,6 +647,42 @@ void AllReduceOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
   patterns.add<EmptyMeshAxesCanonicalizationPattern<AllReduceOp>>(context);
 }
 
+//===----------------------------------------------------------------------===//
+// mesh.all_scatter op
+//===----------------------------------------------------------------------===//
+
+LogicalResult
+AllScatterOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
+  auto mesh = getMeshAndVerifyAxes(*this, symbolTable);
+  if (failed(mesh)) {
+    return failure();
+  }
+  return verifyScatterOperandAndResultShape(
+      getOperand(), getResult(), getScatterAxis().getSExtValue(), getMeshAxes(),
+      mesh.value().getShape());
+}
+
+void AllScatterOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
+                                               MLIRContext *context) {
+  patterns.add<EmptyMeshAxesCanonicalizationPattern<AllScatterOp>>(context);
+}
+
+void AllScatterOp::build(OpBuilder &odsBuilder, OperationState &odsState,
+                         Value input, MeshOp mesh, ArrayRef<MeshAxis> meshAxes,
+                         int64_t scatterAxis) {
+  Type resultType =
+      scatterResultType(input.getType(), mesh, meshAxes, scatterAxis);
+  build(odsBuilder, odsState, resultType, input, mesh.getSymName(), meshAxes,
+        scatterAxis);
+}
+
+void AllScatterOp::build(OpBuilder &odsBuilder, OperationState &odsState,
+                         Type resultType, Value input, StringRef mesh,
+                         ArrayRef<MeshAxis> meshAxes, int64_t scatterAxis) {
+  build(odsBuilder, odsState, resultType, mesh, meshAxes, input,
+        APInt(sizeof(scatterAxis) * CHAR_BIT, scatterAxis));
+}
+
 //===----------------------------------------------------------------------===//
 // mesh.all_to_all op
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Mesh/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Mesh/Transforms/CMakeLists.txt
index dccb75848c94f0..28af820440076c 100644
--- a/mlir/lib/Dialect/Mesh/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Mesh/Transforms/CMakeLists.txt
@@ -13,7 +13,9 @@ add_mlir_dialect_library(MLIRMeshTransforms
 
   LINK_LIBS PUBLIC
   MLIRAffineDialect
+  MLIRAffineUtils
   MLIRArithDialect
+  MLIRArithUtils
   MLIRControlFlowDialect
   MLIRFuncDialect
   MLIRIR
diff --git a/mlir/lib/Dialect/Mesh/Transforms/Simplifications.cpp b/mlir/lib/Dialect/Mesh/Transforms/Simplifications.cpp
index c0273cdaef7144..7fcac2312444f3 100644
--- a/mlir/lib/Dialect/Mesh/Transforms/Simplifications.cpp
+++ b/mlir/lib/Dialect/Mesh/Transforms/Simplifications.cpp
@@ -7,6 +7,7 @@
 //===----------------------------------------------------------------------===//
 
 #include "mlir/Dialect/Mesh/Transforms/Simplifications.h"
+#include "TransformsDetail.h"
 #include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/Mesh/IR/MeshOps.h"
 #include "mlir/IR/BuiltinTypeInterfaces.h"
@@ -16,7 +17,6 @@
 #include "mlir/Support/LogicalResult.h"
 #include "llvm/ADT/STLExtras.h"
 #include "llvm/ADT/SmallVector.h"
-#include <iterator>
 #include <numeric>
 #include <utility>
 
@@ -56,13 +56,10 @@ namespace {
 // symbol tables.
 // We can't use DialectFoldInterface since the cache may be invalidated by some
 // pass changing the referenced MeshOp ops.
-struct MeshShapeFolder : OpRewritePattern<MeshShapeOp> {
-  template <typename... OpRewritePatternArgs>
-  MeshShapeFolder(SymbolTableCollection &symbolTableCollection,
-                  OpRewritePatternArgs &&...opRewritePatternArgs)
-      : OpRewritePattern(
-            std::forward<OpRewritePatternArgs...>(opRewritePatternArgs)...),
-        symbolTableCollection(symbolTableCollection) {}
+struct MeshShapeFolder
+    : OpRewritePatternWithSymbolTableCollection<MeshShapeOp> {
+  using OpRewritePatternWithSymbolTableCollection::
+      OpRewritePatternWithSymbolTableCollection;
   LogicalResult matchAndRewrite(MeshShapeOp op,
                                 PatternRewriter &rewriter) const override {
     ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
@@ -113,9 +110,6 @@ struct MeshShapeFolder : OpRewritePattern<MeshShapeOp> {
 
     return success();
   }
-
-private:
-  SymbolTableCollection &symbolTableCollection;
 };
 
 } // namespace
diff --git a/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp b/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp
index b649157a9e46de..21e03d9572c590 100644
--- a/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp
+++ b/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp
@@ -128,92 +128,24 @@ targetShardingInSplitLastAxis(MLIRContext *ctx, MeshShardingAttr sourceSharding,
       sourceSharding.getPartialAxes(), sourceSharding.getPartialType());
 }
 
-static ShapedType targetShapeInSplitLastAxis(ShapedType sourceShape,
-                                             int64_t splitTensorAxis,
-                                             int64_t splitCount) {
-  SmallVector<int64_t> targetShape = llvm::to_vector(sourceShape.getShape());
-  targetShape[splitTensorAxis] =
-      shardDimension(targetShape[splitTensorAxis], splitCount);
-  return sourceShape.cloneWith(targetShape, sourceShape.getElementType());
-}
-
 // Split a replicated tensor along a mesh axis.
 // e.g. [[0, 1]] -> [[0, 1, 2]].
 // Returns the spmdized target value with its sharding.
-//
-// The implementation is the extract the tensor slice corresponding
-// to the current device.
 static std::tuple<TypedValue<ShapedType>, MeshShardingAttr>
 splitLastAxisInResharding(ImplicitLocOpBuilder &builder,
                           MeshShardingAttr sourceSharding,
                           TypedValue<ShapedType> sourceShard, MeshOp mesh,
                           int64_t splitTensorAxis, MeshAxis splitMeshAxis) {
-  MLIRContext *ctx = builder.getContext();
-  builder.setInsertionPointAfterValue(sourceShard);
-
-  Value zero = builder.create<arith::ConstantOp>(builder.getIndexAttr(0));
-
-  Value processIndexAlongAxis =
+  TypedValue<ShapedType> targetShard =
       builder
-          .create<ProcessMultiIndexOp>(mesh.getSymName(),
-                                       SmallVector<MeshAxis>({splitMeshAxis}))
-          .getResult()[0];
-
+          .create<AllScatterOp>(sourceShard, mesh,
+                                ArrayRef<MeshAxis>(splitMeshAxis),
+                                splitTensorAxis)
+          .getResult()
+          .cast<TypedValue<ShapedType>>();
   MeshShardingAttr targetSharding = targetShardingInSplitLastAxis(
-      ctx, sourceSharding, splitTensorAxis, splitMeshAxis);
-  ShapedType targetShape = targetShapeInSplitLastAxis(
-      sourceShard.getType(), splitTensorAxis, mesh.getShape()[splitMeshAxis]);
-
-  Value meshAxisSize =
-      builder
-          .create<MeshShapeOp>(mesh.getSymName(),
-                               SmallVector<MeshAxis>({splitMeshAxis}))
-          .getResult()[0];
-
-  Value sourceAxisSize =
-      builder.create<tensor::DimOp>(sourceShard, splitTensorAxis);
-  Value sourceAxisSizeModMeshAxisSize =
-      builder.create<arith::RemUIOp>(sourceAxisSize, meshAxisSize);
-  Value isTargetShapeExactlyDivisible = builder.create<arith::CmpIOp>(
-      arith::CmpIPredicate::eq, sourceAxisSizeModMeshAxisSize, zero);
-  builder.create<cf::AssertOp>(
-      isTargetShapeExactlyDivisible,
-      "Sharding a tensor with axis size that is not exactly divisible by the "
-      "mesh axis size is not supported.");
-  Value targetAxisSize =
-      builder.create<arith::DivUIOp>(sourceAxisSize, meshAxisSize);
-  Value axisOffset =
-      builder.create<arith::MulIOp>(targetAxisSize, processIndexAlongAxis);
-  SmallVector<int64_t> staticOffsets(targetShape.getRank(), 0);
-  staticOffsets[splitTensorAxis] = ShapedType::kDynamic;
-  DenseI64ArrayAttr staticOffsetsAttr =
-      DenseI64ArrayAttr::get(ctx, staticOffsets);
-  SmallVector<Value> dynamicOffsets(1, axisOffset);
-
-  DenseI64ArrayAttr staticSizesAttr =
-      DenseI64ArrayAttr::get(ctx, targetShape.getShape());
-  SmallVector<Value> dynamicSizes;
-  for (int64_t i = 0; i < targetShape.getRank(); ++i) {
-    if (ShapedType::isDynamic(staticSizesAttr.asArrayRef()[i])) {
-      if (i == splitTensorAxis) {
-        dynamicSizes.push_back(targetAxisSize);
-      } else {
-        Value dimSize = builder.create<tensor::DimOp>(sourceShard, i);
-        dynamicSizes.push_back(dimSize);
-      }
-    }
-  }
-
-  DenseI64ArrayAttr staticStridesAttr = DenseI64ArrayAttr::get(
-      ctx, SmallVector<int64_t>(targetShape.getRank(), 1));
-  TypedValue<RankedTensorType> targetShard =
-      builder
-          .create<tensor::ExtractSliceOp>(
-              targetShape, sourceShard, dynamicOffsets, dynamicSizes,
-              SmallVector<Value>({}), staticOffsetsAttr, staticSizesAttr,
-              staticStridesAttr)
-          .getResult();
-  return {targetShard.cast<TypedValue<ShapedType>>(), targetSharding};
+      builder.getContext(), sourceSharding, splitTensorAxis, splitMeshAxis);
+  return {targetShard, targetSharding};
 }
 
 // Detect if the resharding is of type e.g.
diff --git a/mlir/lib/Dialect/Mesh/Transforms/Transforms.cpp b/mlir/lib/Dialect/Mesh/Transforms/Transforms.cpp
index 03b1d9b3498028..95739c18ce0f36 100644
--- a/mlir/lib/Dialect/Mesh/Transforms/Transforms.cpp
+++ b/mlir/lib/Dialect/Mesh/Transforms/Transforms.cpp
@@ -7,12 +7,21 @@
 //===----------------------------------------------------------------------===//
 
 #include "mlir/Dialect/Mesh/Transforms/Transforms.h"
+#include "TransformsDetail.h"
 #include "mlir/Dialect/Affine/IR/AffineOps.h"
+#include "mlir/Dialect/Affine/Utils.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Arith/Utils/Utils.h"
+#include "mlir/Dialect/ControlFlow/IR/ControlFlow.h"
+#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
 #include "mlir/Dialect/Mesh/IR/MeshDialect.h"
 #include "mlir/Dialect/Mesh/IR/MeshOps.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Dialect/Utils/StaticValueUtils.h"
 #include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/DialectRegistry.h"
 #include "mlir/IR/ImplicitLocOpBuilder.h"
+#include "mlir/IR/OpDefinition.h"
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/IR/Value.h"
 #include "llvm/ADT/STLExtras.h"
@@ -26,18 +35,14 @@ namespace {
 
 /// Lower `mesh.process_multi_index` into expression using
 /// `mesh.process_linear_index` and `mesh.mesh_shape`.
-struct ProcessMultiIndexOpLowering : OpRewritePattern<ProcessMultiIndexOp> {
-  template <typename... OpRewritePatternArgs>
-  ProcessMultiIndexOpLowering(SymbolTableCollection &symbolTableCollection,
-                              OpRewritePatternArgs &&...opRewritePatternArgs)
-      : OpRewritePattern(
-            std::forward<OpRewritePatternArgs...>(opRewritePatternArgs)...),
-        symbolTableCollection(symbolTableCollection) {}
+struct ProcessMultiIndexOpLowering
+    : OpRewritePatternWithSymbolTableCollection<ProcessMultiIndexOp> {
+  using OpRewritePatternWithSymbolTableCollection::
+      OpRewritePatternWithSymbolTableCollection;
 
   LogicalResult matchAndRewrite(ProcessMultiIndexOp op,
                                 PatternRewriter &rewriter) const override {
-    MeshOp mesh = symbolTableCollection.lookupNearestSymbolFrom<mesh::MeshOp>(
-        op.getOperation(), op.getMeshAttr());
+    MeshOp mesh = getMesh(op, symbolTableCollection);
     if (!mesh) {
       return failure();
     }
@@ -46,6 +51,7 @@ struct ProcessMultiIndexOpLowering : OpRewritePattern<ProcessMultiIndexOp> {
     builder.setInsertionPointAfter(op.getOperation());
     Value linearIndex = builder.create<ProcessLinearIndexOp>(mesh);
     ValueRange meshShape = builder.create<MeshShapeOp>(mesh).getResults();
+    llvm::errs() << "meshShape.size() = " << meshShape.size() << "\n";
     SmallVector<Value> completeMultiIndex =
         builder.create<affine::AffineDelinearizeIndexOp>(linearIndex, meshShape)
             .getMultiIndex();
@@ -64,9 +70,83 @@ struct ProcessMultiIndexOpLowering : OpRewritePattern<ProcessMultiIndexOp> {
     rewriter.replaceAllUsesWith(op.getResults(), multiIndex);
     return success();
   }
+};
+
+struct AllScatterOpLowering
+    : OpRewritePatternWithSymbolTableCollection<AllScatterOp> {
+  using OpRewritePatternWithSymbolTableCollection::
+      OpRewritePatternWithSymbolTableCollection;
+
+  LogicalResult matchAndRewrite(AllScatterOp op,
+                                PatternRewriter &rewriter) const override {
+    MeshOp mesh = getMesh(op, symbolTableCollection);
+    if (!mesh) {
+      return failure();
+    }
+
+    ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
+    builder.setInsertionPointAfter(op.getOperation());
+
+    Value zero = builder.create<arith::ConstantOp>(builder.getIndexAttr(0));
+
+    Operation::result_range processInGroupMultiIndex =
+        builder.create<ProcessMultiIndexOp>(mesh.getSymName(), op.getMeshAxes())
+            .getResults();
+
+    Operation::result_range processGroupShape =
+        builder.create<MeshShapeOp>(mesh.getSymName(), op.getMeshAxes())
+            .getResult();
+    Value processGroupSize =
+        createCollectiveProcessGroupSize(mesh, op.getMeshAxes(), builder);
+
+    int64_t scatterAxis = op.getScatterAxis().getSExtValue();
+    Value operandScatterAxisSize =
+        builder.create<tensor::DimOp>(op.getOperand(), scatterAxis);
+    Value operandScatterAxisSizeModProcessGroupSize =
+        builder.create<arith::RemUIOp>(operandScatterAxisSize,
+                                       processGroupSize);
+    Value isTargetShapeExactlyDivisible = builder.create<arith::CmpIOp>(
+        arith::CmpIPredicate::eq, operandScatterAxisSizeModProcessGroupSize,
+        zero);
+    builder.create<cf::AssertOp>(isTargetShapeExactlyDivisible,
+                                 "Scattering a tensor with axis size that is "
+                                 "not exactly divisible by the "
+                                 "mesh process group size is not supported.");
+    Value resultScatterAxisSize = builder.create<arith::DivUIOp>(
+        operandScatterAxisSize, processGroupSize);
+    OpFoldResult processInGroupLinearIndex = affine::linearIndexFromShape(
+        llvm::to_vector_of<OpFoldResult>(processInGroupMultiIndex),
+        llvm::to_vector_of<OpFoldResult>(processGroupShape), builder);
+
+    // extract slice
+    RankedTensorType operandType =
+        op.getOperand().getType().cast<RankedTensorType>();
+    SmallVector<OpFoldResult> sizes;
+    for (int64_t i = 0; i < operandType.getRank(); ++i) {
+      if (i == scatterAxis) {
+        sizes.emplace_back(resultScatterAxisSize);
+      } else {
+        Value dimSize = builder.create<tensor::DimOp>(op.getOperand(), i);
+        sizes.emplace_back(dimSize);
+      }
+    }
+    SmallVector<OpFoldResult> offsets(
+        operandType.getRank(), getAsIndexOpFoldResult(builder.getContext(), 0));
+    offsets[scatterAxis] =
+        ArithBuilder(builder, builder.getLoc())
+            .mul(getValueOrCreateConstantIndexOp(builder, builder.getLoc(),
+                                                 processInGroupLinearIndex),
+                 resultScatterAxisSize);
+    SmallVector<OpFoldResult> strides(
+        operandType.getRank(), getAsIndexOpFoldResult(builder.getContext(), 1));
+    Value slice = builder.create<tensor::ExtractSliceOp>(
+        op.getOperand(), offsets, sizes, strides);
+    Value newResult =
+        builder.create<tensor::CastOp>(op.getResult().getType(), slice);
+    rewriter.replaceAllUsesWith(op.getResult(), newResult);
 
-private:
-  SymbolTableCollection &symbolTableCollection;
+    return success();
+  }
 };
 
 } // namespace
@@ -81,4 +161,38 @@ void processMultiIndexOpLoweringRegisterDialects(DialectRegistry &registry) {
   registry.insert<affine::AffineDialect, mesh::MeshDialect>();
 }
 
+void allScatterOpLoweringPopulatePatterns(
+    RewritePatternSet &patterns, SymbolTableCollection &symbolTableCollection) {
+  patterns.add<AllScatterOpLowering>(symbolTableCollection,
+                                     patterns.getContext());
+}
+
+void allScatterOpLoweringRegisterDialects(DialectRegistry &registry) {
+  registry.insert<affine::AffineDialect, arith::ArithDialect,
+                  cf::ControlFlowDialect, mesh::MeshDialect,
+                  tensor::TensorDialect>();
+}
+
+void populateAllOpLoweringPatterns(
+    RewritePatternSet &patterns, SymbolTableCollection &symbolTableCollection) {
+  processMultiIndexOpLoweringPopulatePatterns(patterns, symbolTableCollection);
+  allScatterOpLoweringPopulatePatterns(patterns, symbolTableCollection);
+}
+
+void registerAllOpLoweringDialects(DialectRegistry &registry) {
+  processMultiIndexOpLoweringRegisterDialects(registry);
+  allScatterOpLoweringRegisterDialects(registry);
+}
+
+TypedValue<IndexType>
+createCollectiveProcessGroupSize(MeshOp mesh, ArrayRef<MeshAxis> axes,
+                                 ImplicitLocOpBuilder &builder) {
+  Operation::result_range meshShape =
+      builder.create<mesh::MeshShapeOp>(mesh, axes).getResults();
+  return arith::createProduct(builder, builder.getLoc(),
+                              llvm::to_vector_of<Value>(meshShape),
+                              builder.getIndexType())
+      .cast<TypedValue<IndexType>>();
+}
+
 } // namespace mlir::mesh
diff --git a/mlir/lib/Dialect/Mesh/Transforms/TransformsDetail.h b/mlir/lib/Dialect/Mesh/Transforms/TransformsDetail.h
new file mode 100644
index 00000000000000..3e3f584caca247
--- /dev/null
+++ b/mlir/lib/Dialect/Mesh/Transforms/TransformsDetail.h
@@ -0,0 +1,43 @@
+//===- TransformsDetail.h - -------------------------------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_MESH_TRANSFORMS_TRANSFORMSDETAIL_H
+#define MLIR_DIALECT_MESH_TRANSFORMS_TRANSFORMSDETAIL_H
+
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/IR/SymbolTable.h"
+
+#include <utility>
+
+namespace mlir {
+namespace mesh {
+
+/// Rewrite pattern base for `Op` that additionally carries a reference to a
+/// SymbolTableCollection, exposed to derived patterns as a protected member.
+/// Only a reference is stored, so the collection must outlive the pattern.
+template <typename Op>
+struct OpRewritePatternWithSymbolTableCollection : OpRewritePattern<Op> {
+  /// Binds `symbolTableCollection` and forwards all remaining arguments to
+  /// the OpRewritePattern<Op> constructor.
+  template <typename... OpRewritePatternArgs>
+  OpRewritePatternWithSymbolTableCollection(
+      SymbolTableCollection &symbolTableCollection,
+      OpRewritePatternArgs &&...opRewritePatternArgs)
+      : OpRewritePattern<Op>(
+            // Forward each argument with its own deduced type.
+            std::forward<OpRewritePatternArgs>(opRewritePatternArgs)...),
+        symbolTableCollection(symbolTableCollection) {}
+
+protected:
+  SymbolTableCollection &symbolTableCollection;
+};
+
+} // namespace mesh
+} // namespace mlir
+
+#endif // MLIR_DIALECT_MESH_TRANSFORMS_TRANSFORMSDETAIL_H
diff --git a/mlir/lib/IR/Builders.cpp b/mlir/lib/IR/Builders.cpp
index 2e42c4e870716b..18ca3c332e0204 100644
--- a/mlir/lib/IR/Builders.cpp
+++ b/mlir/lib/IR/Builders.cpp
@@ -346,6 +346,24 @@ TypedAttr Builder::getZeroAttr(Type type) {
   return {};
 }
 
+TypedAttr Builder::getOneAttr(Type type) {
+  if (llvm::isa<FloatType>(type))
+    return getFloatAttr(type, 1.0);
+  if (llvm::isa<IndexType>(type))
+    return getIndexAttr(1);
+  if (llvm::dyn_cast<IntegerType>(type))
+    return getIntegerAttr(type,
+                          APInt(llvm::cast<IntegerType>(type).getWidth(), 1));
+  if (llvm::isa<RankedTensorType, VectorType>(type)) {
+    auto vtType = llvm::cast<ShapedType>(type);
+    auto element = getOneAttr(vtType.getElementType());
+    if (!element)
+      return {};
+    return DenseElementsAttr::get(vtType, element);
+  }
+  return {};
+}
+
 //===----------------------------------------------------------------------===//
 // Affine Expressions, Affine Maps, and Integer Sets.
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Mesh/all-scatter-op-lowering.mlir b/mlir/test/Dialect/Mesh/all-scatter-op-lowering.mlir
new file mode 100644
index 00000000000000..8aa3fc5160a001
--- /dev/null
+++ b/mlir/test/Dialect/Mesh/all-scatter-op-lowering.mlir
@@ -0,0 +1,74 @@
+// RUN: mlir-opt --split-input-file --test-mesh-all-scatter-op-lowering --test-mesh-simplifications --cse %s | FileCheck %s
+
+mesh.mesh @mesh_1d(shape = ?)
+
+// CHECK-LABEL: func.func @all_scatter_op_lowering_of_dynamic_1d_tensor_on_dynamic_1d_mesh
+func.func @all_scatter_op_lowering_of_dynamic_1d_tensor_on_dynamic_1d_mesh(
+  // CHECK: %[[ARG:.*]]: tensor<?xf16>
+  %arg0: tensor<?xf16>
+// CHECK-SAME: -> tensor<?xf16> {
+) -> tensor<?xf16> {
+  // CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
+  // CHECK-DAG: %[[PROC_IDX:.*]] = mesh.process_multi_index on @mesh_1d axes = [0] : index
+  // CHECK-DAG: %[[MESH_SIZE:.*]] = mesh.mesh_shape @mesh_1d axes = [0] : index
+  // CHECK: %[[TENSOR_AXIS_SIZE:.*]] = tensor.dim %[[ARG]], %[[C0]] : tensor<?xf16>
+  // CHECK: %[[AXIS_SIZE_CHECK_REMAINDER:.*]] = arith.remui %[[TENSOR_AXIS_SIZE]], %[[MESH_SIZE]] : index
+  // CHECK: %[[AXIS_SIZE_CHECK:.*]] = arith.cmpi eq, %[[AXIS_SIZE_CHECK_REMAINDER]], %[[C0]] : index
+  // CHECK: cf.assert %[[AXIS_SIZE_CHECK]]
+  // CHECK: %[[RESULT_AXIS_SIZE:.*]] = arith.divui %[[TENSOR_AXIS_SIZE]], %[[MESH_SIZE]] : index
+  // CHECK: %[[SLICE_OFFSET:.*]] = arith.muli %[[PROC_IDX]], %[[RESULT_AXIS_SIZE]] : index
+  // CHECK: %[[RESULT:.*]] = tensor.extract_slice %[[ARG]][%[[SLICE_OFFSET]]] [%[[RESULT_AXIS_SIZE]]] [1] : tensor<?xf16> to tensor<?xf16>
+  %0 = mesh.all_scatter %arg0 on @mesh_1d mesh_axes = [0] scatter_axis = 0 : tensor<?xf16> -> tensor<?xf16>
+  // CHECK: return %[[RESULT]] : tensor<?xf16>
+  return %0 : tensor<?xf16>
+}
+
+// -----
+
+mesh.mesh @mesh_1d(shape = 2)
+
+// CHECK-LABEL: func.func @all_scatter_op_lowering_of_static_1d_tensor_on_static_1d_mesh
+func.func @all_scatter_op_lowering_of_static_1d_tensor_on_static_1d_mesh(
+  // CHECK: %[[ARG:.*]]: tensor<2xf16>
+  %arg0: tensor<2xf16>
+// CHECK-SAME: -> tensor<1xf16> {
+) -> tensor<1xf16> {
+  // CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
+  // CHECK: %[[PROC_IDX:.*]] = mesh.process_multi_index on @mesh_1d axes = [0] : index
+  // CHECK: %[[SLICE:.*]] = tensor.extract_slice %[[ARG]][%[[PROC_IDX]]] [%[[C1]]] [1] : tensor<2xf16> to tensor<?xf16>
+  // CHECK: %[[RESULT:.*]] = tensor.cast %[[SLICE]] : tensor<?xf16> to tensor<1xf16>
+  %0 = mesh.all_scatter %arg0 on @mesh_1d mesh_axes = [0] scatter_axis = 0 : tensor<2xf16> -> tensor<1xf16>
+  // CHECK: return %[[RESULT]] : tensor<1xf16>
+  return %0 : tensor<1xf16>
+}
+
+// -----
+
+// CHECK: #map = affine_map<()[s0, s1, s2] -> (s0 * s1 + s2)>
+
+mesh.mesh @mesh_4d(shape = ?x?x?x?)
+
+// CHECK-LABEL: func.func @all_scatter_op_lowering_of_dynamic_2d_tensor_on_dynamic_4d_mesh
+func.func @all_scatter_op_lowering_of_dynamic_2d_tensor_on_dynamic_4d_mesh(
+  // CHECK: %[[ARG:.*]]: tensor<?x?xf16>
+  %arg0 : tensor<?x?xf16>
+// CHECK-SAME: -> tensor<?x?xf16> {
+) -> tensor<?x?xf16> {
+  // CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
+  // CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
+  // CHECK-DAG: %[[IN_GROUP_PROC_MULTI_IDX:.*]]:2 = mesh.process_multi_index on @mesh_4d axes = [3, 1] : index, index
+  // CHECK-DAG: %[[PROC_GROUP_SHAPE:.*]]:2 = mesh.mesh_shape @mesh_4d axes = [3, 1] : index, index
+  // CHECK: %[[PROC_GROUP_SIZE:.*]] = arith.muli %[[PROC_GROUP_SHAPE]]#0, %[[PROC_GROUP_SHAPE]]#1 : index
+  // CHECK: %[[SCATTER_AXIS_SIZE:.*]] = tensor.dim %[[ARG]], %[[C1]] : tensor<?x?xf16>
+  // CHECK: %[[AXIS_SIZE_CHECK_REMAINDER:.*]] = arith.remui %[[SCATTER_AXIS_SIZE]], %[[PROC_GROUP_SIZE]] : index
+  // CHECK: %[[AXIS_SIZE_CHECK:.*]] = arith.cmpi eq, %[[AXIS_SIZE_CHECK_REMAINDER]], %[[C0]] : index
+  // CHECK: cf.assert %[[AXIS_SIZE_CHECK]]
+  // CHECK: %[[RESULT_SCATTER_AXIS_SIZE:.*]] = arith.divui %[[SCATTER_AXIS_SIZE]], %[[PROC_GROUP_SIZE]] : index
+  // CHECK: %[[PROC_IN_GROUP_LINEAR_IDX:.*]] = affine.apply #map()[%[[IN_GROUP_PROC_MULTI_IDX]]#0, %[[PROC_GROUP_SHAPE]]#1, %[[IN_GROUP_PROC_MULTI_IDX]]#1]
+  // CHECK: %[[AXIS_0_SIZE:.*]] = tensor.dim %[[ARG]], %[[C0]] : tensor<?x?xf16>
+  // CHECK: %[[SCATTER_AXIS_OFFSET:.*]] = arith.muli %[[PROC_IN_GROUP_LINEAR_IDX]], %[[RESULT_SCATTER_AXIS_SIZE]] : index
+  // CHECK: %[[RESULT:.*]] = tensor.extract_slice %[[ARG]][0, %[[SCATTER_AXIS_OFFSET]]] [%[[AXIS_0_SIZE]], %[[RESULT_SCATTER_AXIS_SIZE]]] [1, 1] : tensor<?x?xf16> to tensor<?x?xf16>
+  %0 = mesh.all_scatter %arg0 on @mesh_4d mesh_axes = [3, 1] scatter_axis = 1 : tensor<?x?xf16> -> tensor<?x?xf16>
+  // CHECK: return %[[RESULT]] : tensor<?x?xf16>
+  return %0 : tensor<?x?xf16>
+}
diff --git a/mlir/test/Dialect/Mesh/canonicalization.mlir b/mlir/test/Dialect/Mesh/canonicalization.mlir
index 23c5b253b4c073..4fa9c81f4cd004 100644
--- a/mlir/test/Dialect/Mesh/canonicalization.mlir
+++ b/mlir/test/Dialect/Mesh/canonicalization.mlir
@@ -63,6 +63,19 @@ func.func @all_gather_empty_mesh_axes(
   return %0 : tensor<4xf32>
 }
 
+// CHECK-LABEL: func @all_scatter_empty_mesh_axes
+func.func @all_scatter_empty_mesh_axes(
+// CHECK-SAME: %[[ARG:.*]]: tensor<4xf32>
+    %arg0 : tensor<4xf32>) -> tensor<4xf32> {
+// CHECK-NOT: mesh.all_scatter
+  %0 = mesh.all_scatter %arg0 on @mesh0
+    mesh_axes = []
+    scatter_axis = 0
+    : tensor<4xf32> -> tensor<4xf32>
+// CHECK: return %[[ARG]]
+  return %0 : tensor<4xf32>
+}
+
 // CHECK-LABEL: func @broadcast_empty_mesh_axes
 func.func @broadcast_empty_mesh_axes(
 // CHECK-SAME: %[[ARG:.*]]: tensor<4xf32>
diff --git a/mlir/test/Dialect/Mesh/invalid.mlir b/mlir/test/Dialect/Mesh/invalid.mlir
index 3fa3ebd67b15e7..590d548d9c769e 100644
--- a/mlir/test/Dialect/Mesh/invalid.mlir
+++ b/mlir/test/Dialect/Mesh/invalid.mlir
@@ -316,6 +316,58 @@ func.func @all_gather_invalid_negative_gather_axis(
 
 // -----
 
+mesh.mesh @mesh0(shape = 3)
+
+func.func @all_scatter_duplicate_mesh_axis(
+    %arg0 : tensor<?xf32>) -> tensor<?xf32> {
+  // expected-error at +1 {{Mesh axes contains duplicate elements.}}
+  %0 = mesh.all_scatter %arg0 on @mesh0 mesh_axes = [0, 0]
+    scatter_axis = 0
+    : tensor<?xf32> -> tensor<?xf32>
+  return %0 : tensor<?xf32>
+}
+
+// -----
+
+mesh.mesh @mesh0(shape = 3)
+
+func.func @all_scatter_invalid_dynamic_dimension(
+    %arg0 : tensor<?xf32>) -> tensor<2xf32> {
+  // expected-error at +1 {{Dimension size mismatch for result axis 0. Expected dynamic, but got 2.}}
+  %0 = mesh.all_scatter %arg0 on @mesh0
+    scatter_axis = 0
+    : tensor<?xf32> -> tensor<2xf32>
+  return %0 : tensor<2xf32>
+}
+
+// -----
+
+mesh.mesh @mesh0(shape = 3)
+
+func.func @all_scatter_invalid_static_dimension_size(
+    %arg0 : tensor<3xf32>) -> tensor<2xf32> {
+  // expected-error at +1 {{Dimension size mismatch for result axis 0. Expected 1, but got 2.}}
+  %0 = mesh.all_scatter %arg0 on @mesh0 mesh_axes = [0]
+    scatter_axis = 0
+    : tensor<3xf32> -> tensor<2xf32>
+  return %0 : tensor<2xf32>
+}
+
+// -----
+
+mesh.mesh @mesh0(shape = 3)
+
+func.func @all_scatter_invalid_operand_static_dimension_size(
+    %arg0 : tensor<4xf32>) -> tensor<?xf32> {
+  // expected-error at +1 {{Operand dimension size 4 is not divisible by collective device group size 3 for scatter axis 0.}}
+  %0 = mesh.all_scatter %arg0 on @mesh0 mesh_axes = [0]
+    scatter_axis = 0
+    : tensor<4xf32> -> tensor<?xf32>
+  return %0 : tensor<?xf32>
+}
+
+// -----
+
 func.func @all_to_all_invalid_mesh_symbol(
     %arg0 : tensor<3x6xi8>) -> tensor<3x6xi8> {
   // expected-error at +1 {{Undefined required mesh symbol "this_mesh_symbol_does_not_exist".}}
diff --git a/mlir/test/Dialect/Mesh/ops.mlir b/mlir/test/Dialect/Mesh/ops.mlir
index 40a8469b264643..69e8545f3cb86f 100644
--- a/mlir/test/Dialect/Mesh/ops.mlir
+++ b/mlir/test/Dialect/Mesh/ops.mlir
@@ -208,6 +208,30 @@ func.func @all_gather_dynamic_dims_in_mesh(
   return %0 : tensor<5x?xf32>
 }
 
+// CHECK-LABEL: func @all_scatter_static_dimensions
+func.func @all_scatter_static_dimensions(
+    // CHECK-SAME: %[[ARG:.*]]: tensor<3x4xf32>
+    %arg0 : tensor<3x4xf32>) -> tensor<3x1xf32> {
+  // CHECK-NEXT: mesh.all_scatter %[[ARG]]
+  // CHECK-SAME: on @mesh0 mesh_axes = [2] scatter_axis = 1
+  // CHECK-SAME: : tensor<3x4xf32> -> tensor<3x1xf32>
+  %0 = mesh.all_scatter %arg0 on @mesh0 mesh_axes = [2] scatter_axis = 1
+    : tensor<3x4xf32> -> tensor<3x1xf32>
+  return %0 : tensor<3x1xf32>
+}
+
+// CHECK-LABEL: func @all_scatter_dynamic_dimensions
+func.func @all_scatter_dynamic_dimensions(
+    // CHECK-SAME: %[[ARG:.*]]: tensor<?xf32>
+    %arg0 : tensor<?xf32>) -> tensor<?xf32> {
+  // CHECK-NEXT: mesh.all_scatter %[[ARG]]
+  // CHECK-SAME: on @mesh3 mesh_axes = [0, 1] scatter_axis = 0
+  // CHECK-SAME: : tensor<?xf32> -> tensor<?xf32>
+  %0 = mesh.all_scatter %arg0 on @mesh3 mesh_axes = [0, 1] scatter_axis = 0
+    : tensor<?xf32> -> tensor<?xf32>
+  return %0 : tensor<?xf32>
+}
+
 // CHECK-LABEL: func @all_to_all
 func.func @all_to_all(
     // CHECK-SAME: %[[ARG:.*]]: tensor<3x6xi8>
diff --git a/mlir/test/Dialect/Mesh/resharding-spmdization.mlir b/mlir/test/Dialect/Mesh/resharding-spmdization.mlir
index f1d686135c28eb..fdf9905a62e99c 100644
--- a/mlir/test/Dialect/Mesh/resharding-spmdization.mlir
+++ b/mlir/test/Dialect/Mesh/resharding-spmdization.mlir
@@ -19,17 +19,9 @@ func.func @split_replicated_tensor_axis(
   // CHECK-SAME: %[[ARG:.*]]: tensor<3x14xf32>
   %arg0: tensor<3x14xf32>
 ) -> tensor<3x14xf32> {
-  // CHECK-DAG: %[[ZERO:.*]] = arith.constant 0 : index
-  // CHECK-DAG: %[[TENSOR_SPLIT_AXIS_SIZE:.*]] = arith.constant 14 : index
-  // CHECK: %[[PROCESS_INDEX:.*]] = mesh.process_multi_index on @mesh_1d axes = [0] : index
-  // CHECK: %[[MESH_AXIS_SIZE:.*]] = mesh.mesh_shape @mesh_1d axes = [0] : index
-  // CHECK: %[[TENSOR_SPLIT_AXIS_SIZE_MOD_MESH_AXIS_SIZE:.*]] = arith.remui %[[TENSOR_SPLIT_AXIS_SIZE]], %[[MESH_AXIS_SIZE]] : index
-  // CHECK: %[[RESULT_TENSOR_AXIS_SIZE_CHECK:.*]] = arith.cmpi eq, %[[TENSOR_SPLIT_AXIS_SIZE_MOD_MESH_AXIS_SIZE]], %[[ZERO]] : index
-  // CHECK: cf.assert %[[RESULT_TENSOR_AXIS_SIZE_CHECK]]
-  // CHECK: %[[RESULT_TENSOR_AXIS_SIZE:.*]] = arith.divui %[[TENSOR_SPLIT_AXIS_SIZE]], %[[MESH_AXIS_SIZE]] : index
-  // CHECK: %[[RESULT_TENSOR_AXIS_OFFSET:.*]] = arith.muli %[[RESULT_TENSOR_AXIS_SIZE]], %[[PROCESS_INDEX]] : index
-  // CHECK: %[[RESULT_TENSOR_SLICE:.*]] = tensor.extract_slice %[[ARG]][0, %[[RESULT_TENSOR_AXIS_OFFSET]]] [3, 7] [1, 1] : tensor<3x14xf32> to tensor<3x7xf32>
-  // CHECK: %[[RESULT:.*]] = builtin.unrealized_conversion_cast %[[RESULT_TENSOR_SLICE]] : tensor<3x7xf32> to tensor<3x14xf32>
+  // CHECK: %[[ALL_SCATTER:.*]] = mesh.all_scatter %[[ARG]] on @mesh_1d mesh_axes = [0] scatter_axis = 1
+  // CHECK-SAME: tensor<3x14xf32> -> tensor<3x7xf32>
+  // CHECK: %[[RESULT:.*]] = builtin.unrealized_conversion_cast %[[ALL_SCATTER]] : tensor<3x7xf32> to tensor<3x14xf32>
   %0 = mesh.shard %arg0 to <@mesh_1d, [[]]> : tensor<3x14xf32>
   %1 = mesh.shard %0 to <@mesh_1d, [[], [0]]> annotate_for_users : tensor<3x14xf32>
   // CHECK: return %[[RESULT]] : tensor<3x14xf32>
@@ -41,22 +33,11 @@ func.func @split_replicated_tensor_axis_dynamic(
   // CHECK-SAME: %[[ARG:.*]]: tensor<?x3x?xf32>
   %arg0: tensor<?x3x?xf32>
 ) -> tensor<?x3x?xf32> {
-  // CHECK-DAG: %[[ZERO:.*]] = arith.constant 0 : index
-  // CHECK-DAG: %[[TWO:.*]] = arith.constant 2 : index
-  // CHECK: %[[PROCESS_INDEX:.*]] = mesh.process_multi_index on @mesh_1d_dynamic axes = [0] : index
-  // CHECK: %[[MESH_AXIS_SIZE:.*]] = mesh.mesh_shape @mesh_1d_dynamic axes = [0] : index
-  // CHECK: %[[TENSOR_SPLIT_AXIS_SIZE:.*]] = tensor.dim %[[ARG]], %[[ZERO]] : tensor<?x3x?xf32>
-  // CHECK: %[[TENSOR_SPLIT_AXIS_SIZE_MOD_MESH_AXIS_SIZE:.*]] = arith.remui %[[TENSOR_SPLIT_AXIS_SIZE]], %[[MESH_AXIS_SIZE]] : index
-  // CHECK: %[[RESULT_TENSOR_AXIS_SIZE_CHECK:.*]] = arith.cmpi eq, %[[TENSOR_SPLIT_AXIS_SIZE_MOD_MESH_AXIS_SIZE]], %[[ZERO]] : index
-  // CHECK: cf.assert %[[RESULT_TENSOR_AXIS_SIZE_CHECK]]
-  // CHECK: %[[RESULT_TENSOR_SPLIT_AXIS_SIZE:.*]] = arith.divui %[[TENSOR_SPLIT_AXIS_SIZE]], %[[MESH_AXIS_SIZE]] : index
-  // CHECK: %[[RESULT_TENSOR_SPLIT_AXIS_OFFSET:.*]] = arith.muli %[[RESULT_TENSOR_SPLIT_AXIS_SIZE]], %[[PROCESS_INDEX]] : index
-  // CHECK: %[[TENSOR_AXIS_2_SIZE:.*]] = tensor.dim %[[ARG]], %[[TWO]] : tensor<?x3x?xf32>
-  // CHECK: %[[RESULT_TENSOR_SLICE:.*]] = tensor.extract_slice %[[ARG]][%[[RESULT_TENSOR_SPLIT_AXIS_OFFSET]], 0, 0]
-  // CHECK-SAME: [%[[RESULT_TENSOR_SPLIT_AXIS_SIZE]], 3, %[[TENSOR_AXIS_2_SIZE]]] [1, 1, 1] : tensor<?x3x?xf32> to tensor<?x3x?xf32>
+  // CHECK: %[[RESULT:.*]] = mesh.all_scatter %[[ARG]] on @mesh_1d_dynamic mesh_axes = [0] scatter_axis = 0
+  // CHECK-SAME: tensor<?x3x?xf32> -> tensor<?x3x?xf32>
   %0 = mesh.shard %arg0 to <@mesh_1d_dynamic, [[], [], []]> : tensor<?x3x?xf32>
   %1 = mesh.shard %0 to <@mesh_1d_dynamic, [[0]]> annotate_for_users : tensor<?x3x?xf32>
-  // CHECK: return %[[RESULT_TENSOR_SLICE]] : tensor<?x3x?xf32>
+  // CHECK: return %[[RESULT]] : tensor<?x3x?xf32>
   return %1 : tensor<?x3x?xf32>
 }
 
diff --git a/mlir/test/Dialect/Mesh/spmdization.mlir b/mlir/test/Dialect/Mesh/spmdization.mlir
index 9993c1518e9eae..3aeea169107d66 100644
--- a/mlir/test/Dialect/Mesh/spmdization.mlir
+++ b/mlir/test/Dialect/Mesh/spmdization.mlir
@@ -63,9 +63,8 @@ func.func @unary_elementwise_with_resharding(
   %arg0: tensor<2xi8>
 // CHECK-SAME: -> tensor<2xi8> {
 ) -> tensor<2xi8> {
-  // We don't care about the whole resharding IR, just that it happens.
-  // CHECK: %[[SLICE:.*]] = tensor.extract_slice %[[ARG]][%{{.*}}] [1] [1]
-  // CHECK-SAME: tensor<2xi8> to tensor<1xi8>
+  // CHECK: %[[SLICE:.*]] = mesh.all_scatter %[[ARG]] on @mesh_1d mesh_axes = [0] scatter_axis = 0
+  // CHECK-SAME: tensor<2xi8> -> tensor<1xi8>
   %0 = mesh.shard %arg0 to <@mesh_1d, [[]]> : tensor<2xi8>
   %1 = mesh.shard %0 to <@mesh_1d, [[0]]> annotate_for_users: tensor<2xi8>
   // CHECK: %[[ABS:.*]] = tosa.abs %[[SLICE]] : (tensor<1xi8>) -> tensor<1xi8>
@@ -109,9 +108,8 @@ func.func @multiple_chained_ops(
   %arg0: tensor<2xi8>
 // CHECK-SAME: -> tensor<1xi8> {
 ) -> tensor<2xi8> {
-  // We don't care about the whole resharding IR, just that it happens.
-  // CHECK: %[[RESHARD1:.*]] = tensor.extract_slice %[[ARG]][%{{.*}}] [1] [1]
-  // CHECK-SAME: tensor<2xi8> to tensor<1xi8>
+  // CHECK: %[[RESHARD1:.*]] = mesh.all_scatter %[[ARG]] on @mesh_1d mesh_axes = [0] scatter_axis = 0
+  // CHECK-SAME: tensor<2xi8> -> tensor<1xi8>
   %0 = mesh.shard %arg0 to <@mesh_1d, [[]]> : tensor<2xi8>
   %1 = mesh.shard %0 to <@mesh_1d, [[0]]> annotate_for_users: tensor<2xi8>
   // CHECK: %[[ABS1:.*]] = tosa.abs %[[RESHARD1]] : (tensor<1xi8>) -> tensor<1xi8>
@@ -122,8 +120,8 @@ func.func @multiple_chained_ops(
   %4 = mesh.shard %3 to <@mesh_1d, [[]]> annotate_for_users: tensor<2xi8>
   // CHECK: %[[ABS2:.*]] = tosa.abs %[[RESHARD2]] : (tensor<2xi8>) -> tensor<2xi8>
   %5 = tosa.abs %4 : (tensor<2xi8>) -> tensor<2xi8>
-  // CHECK: %[[RESHARD3:.*]] = tensor.extract_slice %[[ABS2]][%{{.*}}] [1] [1]
-  // CHECK-SAME: tensor<2xi8> to tensor<1xi8>
+  // CHECK: %[[RESHARD3:.*]] = mesh.all_scatter %[[ABS2]] on @mesh_1d mesh_axes = [0] scatter_axis = 0 : 
+  // CHECK-SAME: tensor<2xi8> -> tensor<1xi8>
   %6 = mesh.shard %5 to <@mesh_1d, [[]]> : tensor<2xi8>
   %7 = mesh.shard %6 to <@mesh_1d, [[0]]> annotate_for_users: tensor<2xi8>
   // CHECK: return %[[RESHARD3]] : tensor<1xi8>
diff --git a/mlir/test/lib/Dialect/Mesh/CMakeLists.txt b/mlir/test/lib/Dialect/Mesh/CMakeLists.txt
index 00931e6c94fc5f..07e9bb6f9f238a 100644
--- a/mlir/test/lib/Dialect/Mesh/CMakeLists.txt
+++ b/mlir/test/lib/Dialect/Mesh/CMakeLists.txt
@@ -1,6 +1,6 @@
 # Exclude tests from libMLIR.so
 add_mlir_library(MLIRMeshTest
-  TestProcessMultiIndexOpLowering.cpp
+  TestOpLowering.cpp
   TestReshardingSpmdization.cpp
   TestSimplifications.cpp
 
diff --git a/mlir/test/lib/Dialect/Mesh/TestOpLowering.cpp b/mlir/test/lib/Dialect/Mesh/TestOpLowering.cpp
new file mode 100644
index 00000000000000..2c9131e666ed69
--- /dev/null
+++ b/mlir/test/lib/Dialect/Mesh/TestOpLowering.cpp
@@ -0,0 +1,79 @@
+//===- TestOpLowering.cpp -------------------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Mesh/Transforms/Transforms.h"
+#include "mlir/Dialect/Utils/IndexingUtils.h"
+#include "mlir/IR/SymbolTable.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Support/LogicalResult.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+using namespace mlir;
+
+namespace {
+
+struct TestAllScatterOpLoweringPass
+    : public PassWrapper<TestAllScatterOpLoweringPass, OperationPass<>> {
+  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestAllScatterOpLoweringPass)
+
+  void runOnOperation() override {
+    RewritePatternSet patterns(&getContext());
+    SymbolTableCollection symbolTableCollection;
+    mesh::allScatterOpLoweringPopulatePatterns(patterns, symbolTableCollection);
+    LogicalResult status =
+        applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
+    (void)status;
+    assert(succeeded(status) && "applyPatternsAndFoldGreedily failed.");
+  }
+  void getDependentDialects(DialectRegistry &registry) const override {
+    mesh::allScatterOpLoweringRegisterDialects(registry);
+  }
+  StringRef getArgument() const final {
+    return "test-mesh-all-scatter-op-lowering";
+  }
+  StringRef getDescription() const final {
+    return "Test lowering of all-scatter.";
+  }
+};
+
+struct TestMultiIndexOpLoweringPass
+    : public PassWrapper<TestMultiIndexOpLoweringPass, OperationPass<>> {
+  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestMultiIndexOpLoweringPass)
+
+  void runOnOperation() override {
+    RewritePatternSet patterns(&getContext());
+    SymbolTableCollection symbolTableCollection;
+    mesh::processMultiIndexOpLoweringPopulatePatterns(patterns,
+                                                      symbolTableCollection);
+    LogicalResult status =
+        applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
+    (void)status;
+    assert(succeeded(status) && "applyPatternsAndFoldGreedily failed.");
+  }
+  void getDependentDialects(DialectRegistry &registry) const override {
+    mesh::processMultiIndexOpLoweringRegisterDialects(registry);
+  }
+  StringRef getArgument() const final {
+    return "test-mesh-process-multi-index-op-lowering";
+  }
+  StringRef getDescription() const final {
+    return "Test lowering of mesh.process_multi_index op.";
+  }
+};
+
+} // namespace
+
+namespace mlir {
+namespace test {
+void registerTestOpLoweringPasses() {
+  PassRegistration<TestAllScatterOpLoweringPass>();
+  PassRegistration<TestMultiIndexOpLoweringPass>();
+}
+} // namespace test
+} // namespace mlir
diff --git a/mlir/test/lib/Dialect/Mesh/TestProcessMultiIndexOpLowering.cpp b/mlir/test/lib/Dialect/Mesh/TestProcessMultiIndexOpLowering.cpp
deleted file mode 100644
index 0bcc403a2734ed..00000000000000
--- a/mlir/test/lib/Dialect/Mesh/TestProcessMultiIndexOpLowering.cpp
+++ /dev/null
@@ -1,54 +0,0 @@
-//===- TestProcessMultiIndexOpLowering.cpp --------------------------------===//
-//
-// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-//
-//===----------------------------------------------------------------------===//
-
-#include "mlir/Dialect/Arith/IR/Arith.h"
-#include "mlir/Dialect/Mesh/Transforms/Transforms.h"
-#include "mlir/Dialect/Utils/IndexingUtils.h"
-#include "mlir/IR/SymbolTable.h"
-#include "mlir/Pass/Pass.h"
-#include "mlir/Support/LogicalResult.h"
-#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
-
-using namespace mlir;
-
-namespace {
-struct TestMultiIndexOpLoweringPass
-    : public PassWrapper<TestMultiIndexOpLoweringPass, OperationPass<>> {
-  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestMultiIndexOpLoweringPass)
-
-  void runOnOperation() override;
-  void getDependentDialects(DialectRegistry &registry) const override {
-    mesh::processMultiIndexOpLoweringRegisterDialects(registry);
-  }
-  StringRef getArgument() const final {
-    return "test-mesh-process-multi-index-op-lowering";
-  }
-  StringRef getDescription() const final {
-    return "Test lowering of mesh.process_multi_index op.";
-  }
-};
-} // namespace
-
-void TestMultiIndexOpLoweringPass::runOnOperation() {
-  RewritePatternSet patterns(&getContext());
-  SymbolTableCollection symbolTableCollection;
-  mesh::processMultiIndexOpLoweringPopulatePatterns(patterns,
-                                                    symbolTableCollection);
-  LogicalResult status =
-      applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
-  (void)status;
-  assert(succeeded(status) && "applyPatternsAndFoldGreedily failed.");
-}
-
-namespace mlir {
-namespace test {
-void registerTestMultiIndexOpLoweringPass() {
-  PassRegistration<TestMultiIndexOpLoweringPass>();
-}
-} // namespace test
-} // namespace mlir
diff --git a/mlir/tools/mlir-opt/mlir-opt.cpp b/mlir/tools/mlir-opt/mlir-opt.cpp
index 428bdd9691e095..a94f42927f3f3b 100644
--- a/mlir/tools/mlir-opt/mlir-opt.cpp
+++ b/mlir/tools/mlir-opt/mlir-opt.cpp
@@ -120,7 +120,7 @@ void registerTestMemRefDependenceCheck();
 void registerTestMemRefStrideCalculation();
 void registerTestMeshSimplificationsPass();
 void registerTestMeshReshardingSpmdizationPass();
-void registerTestMultiIndexOpLoweringPass();
+void registerTestOpLoweringPasses();
 void registerTestNextAccessPass();
 void registerTestOneToNTypeConversionPass();
 void registerTestOpaqueLoc();
@@ -241,7 +241,7 @@ void registerTestPasses() {
   mlir::test::registerTestMathPolynomialApproximationPass();
   mlir::test::registerTestMemRefDependenceCheck();
   mlir::test::registerTestMemRefStrideCalculation();
-  mlir::test::registerTestMultiIndexOpLoweringPass();
+  mlir::test::registerTestOpLoweringPasses();
   mlir::test::registerTestMeshSimplificationsPass();
   mlir::test::registerTestMeshReshardingSpmdizationPass();
   mlir::test::registerTestNextAccessPass();
diff --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
index 9f2c1f1120066b..8943e91c1f0ceb 100644
--- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
+++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
@@ -3314,7 +3314,9 @@ cc_library(
     includes = ["include"],
     deps = [
         ":AffineDialect",
+        ":AffineUtils",
         ":ArithDialect",
+        ":ArithUtils",
         ":ControlFlowDialect",
         ":DialectUtils",
         ":FuncDialect",

>From 8cbbb177dc5ba8803e4b068672dc679ef230fe80 Mon Sep 17 00:00:00 2001
From: Boian Petkantchin <boian.petkantchin at amd.com>
Date: Thu, 8 Feb 2024 19:10:01 -0800
Subject: [PATCH 2/4] Remove unused resharding dependencies

---
 mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp | 6 +-----
 1 file changed, 1 insertion(+), 5 deletions(-)

diff --git a/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp b/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp
index 21e03d9572c590..cd0481d7b87aec 100644
--- a/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp
+++ b/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp
@@ -8,9 +8,6 @@
 
 #include "mlir/Dialect/Mesh/Transforms/Spmdization.h"
 
-#include "mlir/Dialect/Arith/IR/Arith.h"
-#include "mlir/Dialect/ControlFlow/IR/ControlFlow.h"
-#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
 #include "mlir/Dialect/Func/IR/FuncOps.h"
 #include "mlir/Dialect/Mesh/IR/MeshDialect.h"
 #include "mlir/Dialect/Mesh/IR/MeshOps.h"
@@ -519,8 +516,7 @@ TypedValue<ShapedType> reshard(OpBuilder &builder, ShardOp source,
 }
 
 void reshardingRegisterDependentDialects(DialectRegistry &registry) {
-  registry.insert<arith::ArithDialect, mesh::MeshDialect, tensor::TensorDialect,
-                  cf::ControlFlowDialect>();
+  registry.insert<mesh::MeshDialect, tensor::TensorDialect>();
 }
 
 #define GEN_PASS_DEF_SPMDIZATION

>From 73cffb0c1cb1d9a4e02420ff12c3ee9cbcd251ba Mon Sep 17 00:00:00 2001
From: Boian Petkantchin <boian.petkantchin at amd.com>
Date: Tue, 13 Feb 2024 13:33:56 -0800
Subject: [PATCH 3/4] Address PR comments

---
 mlir/include/mlir/Dialect/Affine/Utils.h      | 10 +--
 mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td  | 18 ++---
 .../mlir/Dialect/Mesh/Transforms/Transforms.h |  8 +-
 mlir/lib/Dialect/Affine/Utils/Utils.cpp       | 21 +++---
 mlir/lib/Dialect/Arith/Utils/Utils.cpp        |  2 +-
 mlir/lib/Dialect/Mesh/IR/MeshOps.cpp          | 45 ++++++------
 .../Dialect/Mesh/Transforms/Spmdization.cpp   |  6 +-
 .../Dialect/Mesh/Transforms/Transforms.cpp    | 73 +++++++++++--------
 .../Dialect/Mesh/all-scatter-op-lowering.mlir | 20 ++---
 mlir/test/Dialect/Mesh/canonicalization.mlir  |  8 +-
 mlir/test/Dialect/Mesh/invalid.mlir           | 30 ++++----
 mlir/test/Dialect/Mesh/ops.mlir               | 20 ++---
 .../Dialect/Mesh/resharding-spmdization.mlir  |  6 +-
 mlir/test/Dialect/Mesh/spmdization.mlir       |  6 +-
 mlir/test/lib/Dialect/Mesh/TestOpLowering.cpp | 20 ++---
 15 files changed, 152 insertions(+), 141 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Affine/Utils.h b/mlir/include/mlir/Dialect/Affine/Utils.h
index 441ce1952210bb..9f37528c27046e 100644
--- a/mlir/include/mlir/Dialect/Affine/Utils.h
+++ b/mlir/include/mlir/Dialect/Affine/Utils.h
@@ -311,11 +311,11 @@ DivModValue getDivMod(OpBuilder &b, Location loc, Value lhs, Value rhs);
 FailureOr<SmallVector<Value>> delinearizeIndex(OpBuilder &b, Location loc,
                                                Value linearIndex,
                                                ArrayRef<Value> basis);
-// Generate IR that extracts the linear index form a multi-index according to
-// a shape.
-OpFoldResult linearIndexFromShape(ArrayRef<OpFoldResult> multiIndex,
-                                  ArrayRef<OpFoldResult> shape,
-                                  ImplicitLocOpBuilder &builder);
+// Generate IR that extracts the linear index from a multi-index according to
+// a basis/shape.
+OpFoldResult linearIndex(ArrayRef<OpFoldResult> multiIndex,
+                         ArrayRef<OpFoldResult> basis,
+                         ImplicitLocOpBuilder &builder);
 
 /// Ensure that all operations that could be executed after `start`
 /// (noninclusive) and prior to `memOp` (e.g. on a control flow/op path
diff --git a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td
index 7792aac784d4be..96636d5347ff6e 100644
--- a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td
+++ b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td
@@ -342,15 +342,15 @@ def Mesh_AllReduceOp : Mesh_CollectiveCommunicationOpBase<"all_reduce", [
   let hasCanonicalizer = 1;
 }
 
-def Mesh_AllScatterOp : Mesh_CollectiveCommunicationOpBase<"all_scatter", [
+def Mesh_AllSliceOp : Mesh_CollectiveCommunicationOpBase<"all_slice", [
     Pure,
     SameOperandsAndResultElementType,
     SameOperandsAndResultRank
   ]> {
-  let summary = "All-scatter over a device mesh. This is the inverse of all-gather.";
+  let summary = "All-slice over a device mesh. This is the inverse of all-gather.";
   let description = [{
-    Scatter along the `scatter_axis` tensor axis.
-    This operation can be taught of as the inverse of all-gather.
+    Slice along the `slice_axis` tensor axis.
+    This operation can be thought of as the inverse of all-gather.
     Technically, it is not required that all processes have the same input tensor.
     Each process will slice a piece of its local tensor based on its in-group device index.
     The operation does not communicate data between devices. 
@@ -359,7 +359,7 @@ def Mesh_AllScatterOp : Mesh_CollectiveCommunicationOpBase<"all_scatter", [
     ```mlir
     mesh.mesh @mesh0(shape = 2x2)
     ...
-    %1 = mesh.all_scatter %0 on @mesh0 mesh_axes = [1] scatter_axis = 1
+    %1 = mesh.all_slice %0 on @mesh0 mesh_axes = [1] slice_axis = 1
       : tensor<2x4xi8> -> tensor<2x2xi8>
     ```
     Input:
@@ -388,19 +388,19 @@ def Mesh_AllScatterOp : Mesh_CollectiveCommunicationOpBase<"all_scatter", [
   }];
   let arguments = !con(commonArgs, (ins
     AnyNon0RankedTensor:$input,
-    IndexAttr:$scatter_axis
+    IndexAttr:$slice_axis
   ));
   let results = (outs
     AnyNon0RankedTensor:$result
   );
   let assemblyFormat = [{
-    $input `on` $mesh (`mesh_axes` `=` $mesh_axes^)? `scatter_axis` `=` $scatter_axis
+    $input `on` $mesh (`mesh_axes` `=` $mesh_axes^)? `slice_axis` `=` $slice_axis
     attr-dict `:` type($input) `->` type($result)
   }];
   let hasCanonicalizer = 1;
   let builders = [
-    OpBuilder<(ins "Value":$input, "MeshOp":$mesh, "ArrayRef<MeshAxis>":$meshAxes, "int64_t":$scatterAxis)>,
-    OpBuilder<(ins "Type":$result_type, "Value":$input, "StringRef":$mesh, "ArrayRef<MeshAxis>":$meshAxes, "int64_t":$scatterAxis)>
+    OpBuilder<(ins "Value":$input, "MeshOp":$mesh, "ArrayRef<MeshAxis>":$meshAxes, "int64_t":$sliceAxis)>,
+    OpBuilder<(ins "Type":$result_type, "Value":$input, "StringRef":$mesh, "ArrayRef<MeshAxis>":$meshAxes, "int64_t":$sliceAxis)>
   ];
 }
 
diff --git a/mlir/include/mlir/Dialect/Mesh/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Mesh/Transforms/Transforms.h
index d398bdd65330b1..aeab28961a4e1e 100644
--- a/mlir/include/mlir/Dialect/Mesh/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Mesh/Transforms/Transforms.h
@@ -21,13 +21,13 @@ class DialectRegistry;
 class ImplicitLocOpBuilder;
 namespace mesh {
 
-void processMultiIndexOpLoweringPopulatePatterns(
+void populateProcessMultiIndexOpLoweringPatterns(
     RewritePatternSet &patterns, SymbolTableCollection &symbolTableCollection);
-void processMultiIndexOpLoweringRegisterDialects(DialectRegistry &registry);
+void registerProcessMultiIndexOpLoweringDialects(DialectRegistry &registry);
 
-void allScatterOpLoweringPopulatePatterns(
+void populateAllSliceOpLoweringPatterns(
     RewritePatternSet &patterns, SymbolTableCollection &symbolTableCollection);
-void allScatterOpLoweringRegisterDialects(DialectRegistry &registry);
+void registerAllSliceOpLoweringDialects(DialectRegistry &registry);
 
 void populateAllOpLoweringPatterns(
     RewritePatternSet &patterns, SymbolTableCollection &symbolTableCollection);
diff --git a/mlir/lib/Dialect/Affine/Utils/Utils.cpp b/mlir/lib/Dialect/Affine/Utils/Utils.cpp
index 41bcd2b9f33e6b..3afecd747f401c 100644
--- a/mlir/lib/Dialect/Affine/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/Affine/Utils/Utils.cpp
@@ -1872,23 +1872,22 @@ mlir::affine::delinearizeIndex(OpBuilder &b, Location loc, Value linearIndex,
   return results;
 }
 
-OpFoldResult
-mlir::affine::linearIndexFromShape(ArrayRef<OpFoldResult> multiIndex,
-                                   ArrayRef<OpFoldResult> shape,
-                                   ImplicitLocOpBuilder &builder) {
-  assert(multiIndex.size() == shape.size());
-  SmallVector<AffineExpr> shapeAffine;
-  for (size_t i = 0; i < shape.size(); ++i) {
-    shapeAffine.push_back(getAffineSymbolExpr(i, builder.getContext()));
+OpFoldResult mlir::affine::linearIndex(ArrayRef<OpFoldResult> multiIndex,
+                                       ArrayRef<OpFoldResult> basis,
+                                       ImplicitLocOpBuilder &builder) {
+  assert(multiIndex.size() == basis.size());
+  SmallVector<AffineExpr> basisAffine;
+  for (size_t i = 0; i < basis.size(); ++i) {
+    basisAffine.push_back(getAffineSymbolExpr(i, builder.getContext()));
   }
 
-  SmallVector<AffineExpr> stridesAffine = computeStrides(shapeAffine);
+  SmallVector<AffineExpr> stridesAffine = computeStrides(basisAffine);
   SmallVector<OpFoldResult> strides;
   strides.reserve(stridesAffine.size());
   llvm::transform(stridesAffine, std::back_inserter(strides),
-                  [&builder, &shape](AffineExpr strideExpr) {
+                  [&builder, &basis](AffineExpr strideExpr) {
                     return affine::makeComposedFoldedAffineApply(
-                        builder, builder.getLoc(), strideExpr, shape);
+                        builder, builder.getLoc(), strideExpr, basis);
                   });
 
   auto &&[linearIndexExpr, multiIndexAndStrides] = computeLinearIndex(
diff --git a/mlir/lib/Dialect/Arith/Utils/Utils.cpp b/mlir/lib/Dialect/Arith/Utils/Utils.cpp
index 999cdbc5c10008..aa239f5e053969 100644
--- a/mlir/lib/Dialect/Arith/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/Arith/Utils/Utils.cpp
@@ -280,4 +280,4 @@ Value createProduct(OpBuilder &builder, Location loc, ArrayRef<Value> values,
       [&arithBuilder](Value acc, Value v) { return arithBuilder.mul(acc, v); });
 }
 
-} // namespace mlir::arith
\ No newline at end of file
+} // namespace mlir::arith
diff --git a/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp b/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
index 762725d2c56e66..91165765245c6d 100644
--- a/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
+++ b/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
@@ -558,7 +558,7 @@ static LogicalResult verifyAllToAllOperandAndResultShape(
   return success();
 }
 
-static LogicalResult verifyScatterOperandAndResultShape(
+static LogicalResult verifyScatterOrSliceOperandAndResultShape(
     Value operand, Value result, int64_t scatterAxis,
     ArrayRef<MeshAxis> meshAxes, ArrayRef<int64_t> meshShape) {
   ShapedType operandType = operand.getType().cast<ShapedType>();
@@ -582,7 +582,7 @@ static LogicalResult verifyScatterOperandAndResultShape(
     return emitError(result.getLoc())
            << "Operand dimension size " << int64_t(operandScatterDimSize)
            << " is not divisible by collective device group size "
-           << int64_t(deviceGroupSize) << " for scatter axis " << scatterAxis
+           << int64_t(deviceGroupSize) << " for tensor axis " << scatterAxis
            << ".";
   }
   DimensionSize expectedResultScatterDimSize =
@@ -648,39 +648,38 @@ void AllReduceOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
 }
 
 //===----------------------------------------------------------------------===//
-// mesh.all_scatter op
+// mesh.all_slice op
 //===----------------------------------------------------------------------===//
 
-LogicalResult
-AllScatterOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
+LogicalResult AllSliceOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
   auto mesh = getMeshAndVerifyAxes(*this, symbolTable);
   if (failed(mesh)) {
     return failure();
   }
-  return verifyScatterOperandAndResultShape(
-      getOperand(), getResult(), getScatterAxis().getSExtValue(), getMeshAxes(),
+  return verifyScatterOrSliceOperandAndResultShape(
+      getOperand(), getResult(), getSliceAxis().getSExtValue(), getMeshAxes(),
       mesh.value().getShape());
 }
 
-void AllScatterOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
-                                               MLIRContext *context) {
-  patterns.add<EmptyMeshAxesCanonicalizationPattern<AllScatterOp>>(context);
+void AllSliceOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
+                                             MLIRContext *context) {
+  patterns.add<EmptyMeshAxesCanonicalizationPattern<AllSliceOp>>(context);
 }
 
-void AllScatterOp::build(OpBuilder &odsBuilder, OperationState &odsState,
-                         Value input, MeshOp mesh, ArrayRef<MeshAxis> meshAxes,
-                         int64_t scatterAxis) {
+void AllSliceOp::build(OpBuilder &odsBuilder, OperationState &odsState,
+                       Value input, MeshOp mesh, ArrayRef<MeshAxis> meshAxes,
+                       int64_t sliceAxis) {
   Type resultType =
-      scatterResultType(input.getType(), mesh, meshAxes, scatterAxis);
+      scatterResultType(input.getType(), mesh, meshAxes, sliceAxis);
   build(odsBuilder, odsState, resultType, input, mesh.getSymName(), meshAxes,
-        scatterAxis);
+        sliceAxis);
 }
 
-void AllScatterOp::build(OpBuilder &odsBuilder, OperationState &odsState,
-                         Type resultType, Value input, StringRef mesh,
-                         ArrayRef<MeshAxis> meshAxes, int64_t scatterAxis) {
+void AllSliceOp::build(OpBuilder &odsBuilder, OperationState &odsState,
+                       Type resultType, Value input, StringRef mesh,
+                       ArrayRef<MeshAxis> meshAxes, int64_t sliceAxis) {
   build(odsBuilder, odsState, resultType, mesh, meshAxes, input,
-        APInt(sizeof(scatterAxis) * CHAR_BIT, scatterAxis));
+        APInt(sizeof(sliceAxis) * CHAR_BIT, sliceAxis));
 }
 
 //===----------------------------------------------------------------------===//
@@ -810,7 +809,7 @@ ReduceScatterOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
     return failure();
   }
 
-  return verifyScatterOperandAndResultShape(
+  return verifyScatterOrSliceOperandAndResultShape(
       getOperand(), getResult(), getScatterAxis().getSExtValue(), getMeshAxes(),
       mesh.value().getShape());
 }
@@ -836,9 +835,9 @@ LogicalResult ScatterOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
   }
 
   auto scatterAxis = getScatterAxis().getSExtValue();
-  return verifyScatterOperandAndResultShape(getInput(), getResult(),
-                                            scatterAxis, getMeshAxes(),
-                                            mesh.value().getShape());
+  return verifyScatterOrSliceOperandAndResultShape(getInput(), getResult(),
+                                                   scatterAxis, getMeshAxes(),
+                                                   mesh.value().getShape());
 }
 
 void ScatterOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
diff --git a/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp b/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp
index cd0481d7b87aec..7cbe0de048769b 100644
--- a/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp
+++ b/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp
@@ -135,9 +135,9 @@ splitLastAxisInResharding(ImplicitLocOpBuilder &builder,
                           int64_t splitTensorAxis, MeshAxis splitMeshAxis) {
   TypedValue<ShapedType> targetShard =
       builder
-          .create<AllScatterOp>(sourceShard, mesh,
-                                ArrayRef<MeshAxis>(splitMeshAxis),
-                                splitTensorAxis)
+          .create<AllSliceOp>(sourceShard, mesh,
+                              ArrayRef<MeshAxis>(splitMeshAxis),
+                              splitTensorAxis)
           .getResult()
           .cast<TypedValue<ShapedType>>();
   MeshShardingAttr targetSharding = targetShardingInSplitLastAxis(
diff --git a/mlir/lib/Dialect/Mesh/Transforms/Transforms.cpp b/mlir/lib/Dialect/Mesh/Transforms/Transforms.cpp
index 95739c18ce0f36..c8c75b94a42206 100644
--- a/mlir/lib/Dialect/Mesh/Transforms/Transforms.cpp
+++ b/mlir/lib/Dialect/Mesh/Transforms/Transforms.cpp
@@ -51,7 +51,6 @@ struct ProcessMultiIndexOpLowering
     builder.setInsertionPointAfter(op.getOperation());
     Value linearIndex = builder.create<ProcessLinearIndexOp>(mesh);
     ValueRange meshShape = builder.create<MeshShapeOp>(mesh).getResults();
-    llvm::errs() << "meshShape.size() = " << meshShape.size() << "\n";
     SmallVector<Value> completeMultiIndex =
         builder.create<affine::AffineDelinearizeIndexOp>(linearIndex, meshShape)
             .getMultiIndex();
@@ -72,13 +71,28 @@ struct ProcessMultiIndexOpLowering
   }
 };
 
-struct AllScatterOpLowering
-    : OpRewritePatternWithSymbolTableCollection<AllScatterOp> {
+struct AllSliceOpLowering
+    : OpRewritePatternWithSymbolTableCollection<AllSliceOp> {
   using OpRewritePatternWithSymbolTableCollection::
       OpRewritePatternWithSymbolTableCollection;
 
-  LogicalResult matchAndRewrite(AllScatterOp op,
+  LogicalResult matchAndRewrite(AllSliceOp op,
                                 PatternRewriter &rewriter) const override {
+    // 1. Compute the process linear index inside the process group from its
+    // multi-index.
+    //
+    // 2. Extract a slice from the input tensor.
+    // Every axis other than the slice axis is left untouched: the slice
+    // spans its full extent.
+    // The slice axis is split into equally sized parts, one for each
+    // process in the collective process group induced by
+    // the mesh axes.
+    // The part for each process is determined by the corresponding
+    // linear-index in the process group.
+    //
+    // There are no collectives that require communication.
+    // Each process operates on its local tensor.
+
     MeshOp mesh = getMesh(op, symbolTableCollection);
     if (!mesh) {
       return failure();
@@ -99,32 +113,31 @@ struct AllScatterOpLowering
     Value processGroupSize =
         createCollectiveProcessGroupSize(mesh, op.getMeshAxes(), builder);
 
-    int64_t scatterAxis = op.getScatterAxis().getSExtValue();
-    Value operandScatterAxisSize =
-        builder.create<tensor::DimOp>(op.getOperand(), scatterAxis);
-    Value operandScatterAxisSizeModProcessGroupSize =
-        builder.create<arith::RemUIOp>(operandScatterAxisSize,
-                                       processGroupSize);
+    int64_t sliceAxis = op.getSliceAxis().getSExtValue();
+    Value operandSliceAxisSize =
+        builder.create<tensor::DimOp>(op.getOperand(), sliceAxis);
+    Value operandSliceAxisSizeModProcessGroupSize =
+        builder.create<arith::RemUIOp>(operandSliceAxisSize, processGroupSize);
     Value isTargetShapeExactlyDivisible = builder.create<arith::CmpIOp>(
-        arith::CmpIPredicate::eq, operandScatterAxisSizeModProcessGroupSize,
+        arith::CmpIPredicate::eq, operandSliceAxisSizeModProcessGroupSize,
         zero);
     builder.create<cf::AssertOp>(isTargetShapeExactlyDivisible,
-                                 "Scattering a tensor with axis size that is "
+                                 "Slicing a tensor with axis size that is "
                                  "not exactly divisible by the "
                                  "mesh process group size is not supported.");
-    Value resultScatterAxisSize = builder.create<arith::DivUIOp>(
-        operandScatterAxisSize, processGroupSize);
-    OpFoldResult processInGroupLinearIndex = affine::linearIndexFromShape(
+    Value resultSliceAxisSize =
+        builder.create<arith::DivUIOp>(operandSliceAxisSize, processGroupSize);
+    OpFoldResult processInGroupLinearIndex = affine::linearIndex(
         llvm::to_vector_of<OpFoldResult>(processInGroupMultiIndex),
         llvm::to_vector_of<OpFoldResult>(processGroupShape), builder);
 
-    // extract slice
+    // insert tensor.extract_slice
     RankedTensorType operandType =
         op.getOperand().getType().cast<RankedTensorType>();
     SmallVector<OpFoldResult> sizes;
     for (int64_t i = 0; i < operandType.getRank(); ++i) {
-      if (i == scatterAxis) {
-        sizes.emplace_back(resultScatterAxisSize);
+      if (i == sliceAxis) {
+        sizes.emplace_back(resultSliceAxisSize);
       } else {
         Value dimSize = builder.create<tensor::DimOp>(op.getOperand(), i);
         sizes.emplace_back(dimSize);
@@ -132,11 +145,11 @@ struct AllScatterOpLowering
     }
     SmallVector<OpFoldResult> offsets(
         operandType.getRank(), getAsIndexOpFoldResult(builder.getContext(), 0));
-    offsets[scatterAxis] =
+    offsets[sliceAxis] =
         ArithBuilder(builder, builder.getLoc())
             .mul(getValueOrCreateConstantIndexOp(builder, builder.getLoc(),
                                                  processInGroupLinearIndex),
-                 resultScatterAxisSize);
+                 resultSliceAxisSize);
     SmallVector<OpFoldResult> strides(
         operandType.getRank(), getAsIndexOpFoldResult(builder.getContext(), 1));
     Value slice = builder.create<tensor::ExtractSliceOp>(
@@ -151,23 +164,23 @@ struct AllScatterOpLowering
 
 } // namespace
 
-void processMultiIndexOpLoweringPopulatePatterns(
+void populateProcessMultiIndexOpLoweringPatterns(
     RewritePatternSet &patterns, SymbolTableCollection &symbolTableCollection) {
   patterns.add<ProcessMultiIndexOpLowering>(symbolTableCollection,
                                             patterns.getContext());
 }
 
-void processMultiIndexOpLoweringRegisterDialects(DialectRegistry &registry) {
+void registerProcessMultiIndexOpLoweringDialects(DialectRegistry &registry) {
   registry.insert<affine::AffineDialect, mesh::MeshDialect>();
 }
 
-void allScatterOpLoweringPopulatePatterns(
+void populateAllSliceOpLoweringPatterns(
     RewritePatternSet &patterns, SymbolTableCollection &symbolTableCollection) {
-  patterns.add<AllScatterOpLowering>(symbolTableCollection,
-                                     patterns.getContext());
+  patterns.add<AllSliceOpLowering>(symbolTableCollection,
+                                   patterns.getContext());
 }
 
-void allScatterOpLoweringRegisterDialects(DialectRegistry &registry) {
+void registerAllSliceOpLoweringDialects(DialectRegistry &registry) {
   registry.insert<affine::AffineDialect, arith::ArithDialect,
                   cf::ControlFlowDialect, mesh::MeshDialect,
                   tensor::TensorDialect>();
@@ -175,13 +188,13 @@ void allScatterOpLoweringRegisterDialects(DialectRegistry &registry) {
 
 void populateAllOpLoweringPatterns(
     RewritePatternSet &patterns, SymbolTableCollection &symbolTableCollection) {
-  processMultiIndexOpLoweringPopulatePatterns(patterns, symbolTableCollection);
-  allScatterOpLoweringPopulatePatterns(patterns, symbolTableCollection);
+  populateProcessMultiIndexOpLoweringPatterns(patterns, symbolTableCollection);
+  populateAllSliceOpLoweringPatterns(patterns, symbolTableCollection);
 }
 
 void registerAllOpLoweringDialects(DialectRegistry &registry) {
-  processMultiIndexOpLoweringRegisterDialects(registry);
-  allScatterOpLoweringRegisterDialects(registry);
+  registerProcessMultiIndexOpLoweringDialects(registry);
+  registerAllSliceOpLoweringDialects(registry);
 }
 
 TypedValue<IndexType>
diff --git a/mlir/test/Dialect/Mesh/all-scatter-op-lowering.mlir b/mlir/test/Dialect/Mesh/all-scatter-op-lowering.mlir
index 8aa3fc5160a001..4f54607a1c7ff0 100644
--- a/mlir/test/Dialect/Mesh/all-scatter-op-lowering.mlir
+++ b/mlir/test/Dialect/Mesh/all-scatter-op-lowering.mlir
@@ -1,9 +1,9 @@
-// RUN: mlir-opt --split-input-file --test-mesh-all-scatter-op-lowering --test-mesh-simplifications --cse %s | FileCheck %s
+// RUN: mlir-opt --split-input-file --test-mesh-all-slice-op-lowering --test-mesh-simplifications --cse %s | FileCheck %s
 
 mesh.mesh @mesh_1d(shape = ?)
 
-// CHECK-LABEL: func.func @all_scatter_op_lowering_of_dynamic_1d_tensor_on_dynamic_1d_mesh
-func.func @all_scatter_op_lowering_of_dynamic_1d_tensor_on_dynamic_1d_mesh(
+// CHECK-LABEL: func.func @all_slice_op_lowering_of_dynamic_1d_tensor_on_dynamic_1d_mesh
+func.func @all_slice_op_lowering_of_dynamic_1d_tensor_on_dynamic_1d_mesh(
   // CHECK: %[[ARG:.*]]: tensor<?xf16>
   %arg0: tensor<?xf16>
 // CHECK-SAME: -> tensor<?xf16> {
@@ -18,7 +18,7 @@ func.func @all_scatter_op_lowering_of_dynamic_1d_tensor_on_dynamic_1d_mesh(
   // CHECK: %[[RESULT_AXIS_SIZE:.*]] = arith.divui %[[TENSOR_AXIS_SIZE]], %[[MESH_SIZE]] : index
   // CHECK: %[[SLICE_OFFSET:.*]] = arith.muli %[[PROC_IDX]], %[[RESULT_AXIS_SIZE]] : index
   // CHECK: %[[RESULT:.*]] = tensor.extract_slice %[[ARG]][%[[SLICE_OFFSET]]] [%[[RESULT_AXIS_SIZE]]] [1] : tensor<?xf16> to tensor<?xf16>
-  %0 = mesh.all_scatter %arg0 on @mesh_1d mesh_axes = [0] scatter_axis = 0 : tensor<?xf16> -> tensor<?xf16>
+  %0 = mesh.all_slice %arg0 on @mesh_1d mesh_axes = [0] slice_axis = 0 : tensor<?xf16> -> tensor<?xf16>
   // CHECK: return %[[RESULT]] : tensor<?xf16>
   return %0 : tensor<?xf16>
 }
@@ -27,8 +27,8 @@ func.func @all_scatter_op_lowering_of_dynamic_1d_tensor_on_dynamic_1d_mesh(
 
 mesh.mesh @mesh_1d(shape = 2)
 
-// CHECK-LABEL: func.func @all_scatter_op_lowering_of_static_1d_tensor_on_static_1d_mesh
-func.func @all_scatter_op_lowering_of_static_1d_tensor_on_static_1d_mesh(
+// CHECK-LABEL: func.func @all_slice_op_lowering_of_static_1d_tensor_on_static_1d_mesh
+func.func @all_slice_op_lowering_of_static_1d_tensor_on_static_1d_mesh(
   // CHECK: %[[ARG:.*]]: tensor<2xf16>
   %arg0: tensor<2xf16>
 // CHECK-SAME: -> tensor<1xf16> {
@@ -37,7 +37,7 @@ func.func @all_scatter_op_lowering_of_static_1d_tensor_on_static_1d_mesh(
   // CHECK: %[[PROC_IDX:.*]] = mesh.process_multi_index on @mesh_1d axes = [0] : index
   // CHECK: %[[SLICE:.*]] = tensor.extract_slice %[[ARG]][%[[PROC_IDX]]] [%[[C1]]] [1] : tensor<2xf16> to tensor<?xf16>
   // CHECK: %[[RESULT:.*]] = tensor.cast %[[SLICE]] : tensor<?xf16> to tensor<1xf16>
-  %0 = mesh.all_scatter %arg0 on @mesh_1d mesh_axes = [0] scatter_axis = 0 : tensor<2xf16> -> tensor<1xf16>
+  %0 = mesh.all_slice %arg0 on @mesh_1d mesh_axes = [0] slice_axis = 0 : tensor<2xf16> -> tensor<1xf16>
   // CHECK: return %[[RESULT]] : tensor<1xf16>
   return %0 : tensor<1xf16>
 }
@@ -48,8 +48,8 @@ func.func @all_scatter_op_lowering_of_static_1d_tensor_on_static_1d_mesh(
 
 mesh.mesh @mesh_4d(shape = ?x?x?x?)
 
-// CHECK-LABEL: func.func @all_scatter_op_lowering_of_dynamic_2d_tensor_on_dynamic_4d_mesh
-func.func @all_scatter_op_lowering_of_dynamic_2d_tensor_on_dynamic_4d_mesh(
+// CHECK-LABEL: func.func @all_slice_op_lowering_of_dynamic_2d_tensor_on_dynamic_4d_mesh
+func.func @all_slice_op_lowering_of_dynamic_2d_tensor_on_dynamic_4d_mesh(
   // CHECK: %[[ARG:.*]]: tensor<?x?xf16>
   %arg0 : tensor<?x?xf16>
 // CHECK-SAME: -> tensor<?x?xf16> {
@@ -68,7 +68,7 @@ func.func @all_scatter_op_lowering_of_dynamic_2d_tensor_on_dynamic_4d_mesh(
   // CHECK: %[[AXIS_0_SIZE:.*]] = tensor.dim %[[ARG]], %[[C0]] : tensor<?x?xf16>
   // CHECK: %[[SCATTER_AXIS_OFFSET:.*]] = arith.muli %[[PROC_IN_GROUP_LINEAR_IDX]], %[[RESULT_SCATTER_AXIS_SIZE]] : index
   // CHECK: %[[RESULT:.*]] = tensor.extract_slice %[[ARG]][0, %[[SCATTER_AXIS_OFFSET]]] [%[[AXIS_0_SIZE]], %[[RESULT_SCATTER_AXIS_SIZE]]] [1, 1] : tensor<?x?xf16> to tensor<?x?xf16>
-  %0 = mesh.all_scatter %arg0 on @mesh_4d mesh_axes = [3, 1] scatter_axis = 1 : tensor<?x?xf16> -> tensor<?x?xf16>
+  %0 = mesh.all_slice %arg0 on @mesh_4d mesh_axes = [3, 1] slice_axis = 1 : tensor<?x?xf16> -> tensor<?x?xf16>
   // CHECK: return %[[RESULT]] : tensor<?x?xf16>
   return %0 : tensor<?x?xf16>
 }
diff --git a/mlir/test/Dialect/Mesh/canonicalization.mlir b/mlir/test/Dialect/Mesh/canonicalization.mlir
index 4fa9c81f4cd004..633324ae680eb1 100644
--- a/mlir/test/Dialect/Mesh/canonicalization.mlir
+++ b/mlir/test/Dialect/Mesh/canonicalization.mlir
@@ -63,14 +63,14 @@ func.func @all_gather_empty_mesh_axes(
   return %0 : tensor<4xf32>
 }
 
-// CHECK-LABEL: func @all_scatter_empty_mesh_axes
-func.func @all_scatter_empty_mesh_axes(
+// CHECK-LABEL: func @all_slice_empty_mesh_axes
+func.func @all_slice_empty_mesh_axes(
 // CHECK-SAME: %[[ARG:.*]]: tensor<4xf32>
     %arg0 : tensor<4xf32>) -> tensor<4xf32> {
 // CHECK-NOT: mesh.scatter
-  %0 = mesh.all_scatter %arg0 on @mesh0
+  %0 = mesh.all_slice %arg0 on @mesh0
     mesh_axes = []
-    scatter_axis = 0
+    slice_axis = 0
     : tensor<4xf32> -> tensor<4xf32>
 // CHECK: return %[[ARG]]
   return %0 : tensor<4xf32>
diff --git a/mlir/test/Dialect/Mesh/invalid.mlir b/mlir/test/Dialect/Mesh/invalid.mlir
index 590d548d9c769e..6d7df86d78406f 100644
--- a/mlir/test/Dialect/Mesh/invalid.mlir
+++ b/mlir/test/Dialect/Mesh/invalid.mlir
@@ -318,11 +318,11 @@ func.func @all_gather_invalid_negative_gather_axis(
 
 mesh.mesh @mesh0(shape = 3)
 
-func.func @all_scatter_duplicate_mesh_axis(
+func.func @all_slice_duplicate_mesh_axis(
     %arg0 : tensor<?xf32>) -> tensor<?xf32> {
   // expected-error at +1 {{Mesh axes contains duplicate elements.}}
-  %0 = mesh.all_scatter %arg0 on @mesh0 mesh_axes = [0, 0]
-    scatter_axis = 0
+  %0 = mesh.all_slice %arg0 on @mesh0 mesh_axes = [0, 0]
+    slice_axis = 0
     : tensor<?xf32> -> tensor<?xf32>
   return %0 : tensor<?xf32>
 }
@@ -331,11 +331,11 @@ func.func @all_scatter_duplicate_mesh_axis(
 
 mesh.mesh @mesh0(shape = 3)
 
-func.func @all_scatter_invalid_dynamic_dimension(
+func.func @all_slice_invalid_dynamic_dimension(
     %arg0 : tensor<?xf32>) -> tensor<2xf32> {
   // expected-error at +1 {{Dimension size mismatch for result axis 0. Expected dynamic, but got 2.}}
-  %0 = mesh.all_scatter %arg0 on @mesh0
-    scatter_axis = 0
+  %0 = mesh.all_slice %arg0 on @mesh0
+    slice_axis = 0
     : tensor<?xf32> -> tensor<2xf32>
   return %0 : tensor<2xf32>
 }
@@ -344,11 +344,11 @@ func.func @all_scatter_invalid_dynamic_dimension(
 
 mesh.mesh @mesh0(shape = 3)
 
-func.func @all_scatter_invalid_static_dimension_size(
+func.func @all_slice_invalid_static_dimension_size(
     %arg0 : tensor<3xf32>) -> tensor<2xf32> {
   // expected-error at +1 {{Dimension size mismatch for result axis 0. Expected 1, but got 2.}}
-  %0 = mesh.all_scatter %arg0 on @mesh0 mesh_axes = [0]
-    scatter_axis = 0
+  %0 = mesh.all_slice %arg0 on @mesh0 mesh_axes = [0]
+    slice_axis = 0
     : tensor<3xf32> -> tensor<2xf32>
   return %0 : tensor<2xf32>
 }
@@ -357,11 +357,11 @@ func.func @all_scatter_invalid_static_dimension_size(
 
 mesh.mesh @mesh0(shape = 3)
 
-func.func @all_scatter_invalid_operand_static_dimension_size(
+func.func @all_slice_invalid_operand_static_dimension_size(
     %arg0 : tensor<4xf32>) -> tensor<?xf32> {
-  // expected-error at +1 {{Operand dimension size 4 is not divisible by collective device group size 3 for scatter axis 0.}}
-  %0 = mesh.all_scatter %arg0 on @mesh0 mesh_axes = [0]
-    scatter_axis = 0
+  // expected-error at +1 {{Operand dimension size 4 is not divisible by collective device group size 3 for tensor axis 0.}}
+  %0 = mesh.all_slice %arg0 on @mesh0 mesh_axes = [0]
+    slice_axis = 0
     : tensor<4xf32> -> tensor<?xf32>
   return %0 : tensor<?xf32>
 }
@@ -712,7 +712,7 @@ mesh.mesh @mesh0(shape = 3)
 
 func.func @reduce_scatter_invalid_operand_static_dimension_size(
     %arg0 : tensor<4xf32>) -> tensor<?xf64> {
-  // expected-error at +1 {{Operand dimension size 4 is not divisible by collective device group size 3 for scatter axis 0.}}
+  // expected-error at +1 {{Operand dimension size 4 is not divisible by collective device group size 3 for tensor axis 0.}}
   %0 = mesh.reduce_scatter %arg0 on @mesh0 mesh_axes = [0] scatter_axis = 0
     : tensor<4xf32> -> tensor<?xf64>
   return %0 : tensor<?xf64>
@@ -763,7 +763,7 @@ mesh.mesh @mesh0(shape = 3)
 
 func.func @scatter_invalid_operand_static_dimension_size(
     %arg0 : tensor<4xf32>) -> tensor<?xf32> {
-  // expected-error at +1 {{Operand dimension size 4 is not divisible by collective device group size 3 for scatter axis 0.}}
+  // expected-error at +1 {{Operand dimension size 4 is not divisible by collective device group size 3 for tensor axis 0.}}
   %0 = mesh.scatter %arg0 on @mesh0 mesh_axes = [0]
     scatter_axis = 0 root = [1]
     : (tensor<4xf32>) -> tensor<?xf32>
diff --git a/mlir/test/Dialect/Mesh/ops.mlir b/mlir/test/Dialect/Mesh/ops.mlir
index 69e8545f3cb86f..6e5df86b13106a 100644
--- a/mlir/test/Dialect/Mesh/ops.mlir
+++ b/mlir/test/Dialect/Mesh/ops.mlir
@@ -208,26 +208,26 @@ func.func @all_gather_dynamic_dims_in_mesh(
   return %0 : tensor<5x?xf32>
 }
 
-// CHECK-LABEL: func @all_scatter_static_dimensions
-func.func @all_scatter_static_dimensions(
+// CHECK-LABEL: func @all_slice_static_dimensions
+func.func @all_slice_static_dimensions(
     // CHECK-SAME: %[[ARG:.*]]: tensor<3x4xf32>
     %arg0 : tensor<3x4xf32>) -> tensor<3x1xf32> {
-  // CHECK-NEXT: mesh.all_scatter %[[ARG]]
-  // CHECK-SAME: on @mesh0 mesh_axes = [2] scatter_axis = 1
+  // CHECK-NEXT: mesh.all_slice %[[ARG]]
+  // CHECK-SAME: on @mesh0 mesh_axes = [2] slice_axis = 1
   // CHECK-SAME: : tensor<3x4xf32> -> tensor<3x1xf32>
-  %0 = mesh.all_scatter %arg0 on @mesh0 mesh_axes = [2] scatter_axis = 1
+  %0 = mesh.all_slice %arg0 on @mesh0 mesh_axes = [2] slice_axis = 1
     : tensor<3x4xf32> -> tensor<3x1xf32>
   return %0 : tensor<3x1xf32>
 }
 
-// CHECK-LABEL: func @all_scatter_dynamic_dimensions
-func.func @all_scatter_dynamic_dimensions(
+// CHECK-LABEL: func @all_slice_dynamic_dimensions
+func.func @all_slice_dynamic_dimensions(
     // CHECK-SAME: %[[ARG:.*]]: tensor<?xf32>
     %arg0 : tensor<?xf32>) -> tensor<?xf32> {
-  // CHECK-NEXT: mesh.all_scatter %[[ARG]]
-  // CHECK-SAME: on @mesh3 mesh_axes = [0, 1] scatter_axis = 0
+  // CHECK-NEXT: mesh.all_slice %[[ARG]]
+  // CHECK-SAME: on @mesh3 mesh_axes = [0, 1] slice_axis = 0
   // CHECK-SAME: : tensor<?xf32> -> tensor<?xf32>
-  %0 = mesh.all_scatter %arg0 on @mesh3 mesh_axes = [0, 1] scatter_axis = 0
+  %0 = mesh.all_slice %arg0 on @mesh3 mesh_axes = [0, 1] slice_axis = 0
     : tensor<?xf32> -> tensor<?xf32>
   return %0 : tensor<?xf32>
 }
diff --git a/mlir/test/Dialect/Mesh/resharding-spmdization.mlir b/mlir/test/Dialect/Mesh/resharding-spmdization.mlir
index fdf9905a62e99c..ba05306598bcc6 100644
--- a/mlir/test/Dialect/Mesh/resharding-spmdization.mlir
+++ b/mlir/test/Dialect/Mesh/resharding-spmdization.mlir
@@ -19,9 +19,9 @@ func.func @split_replicated_tensor_axis(
   // CHECK-SAME: %[[ARG:.*]]: tensor<3x14xf32>
   %arg0: tensor<3x14xf32>
 ) -> tensor<3x14xf32> {
-  // CHECK: %[[ALL_SCATTER:.*]] = mesh.all_scatter %[[ARG]] on @mesh_1d mesh_axes = [0] scatter_axis = 1
+  // CHECK: %[[ALL_SLICE:.*]] = mesh.all_slice %[[ARG]] on @mesh_1d mesh_axes = [0] slice_axis = 1
   // CHECK-SAME: tensor<3x14xf32> -> tensor<3x7xf32>
-  // CHECK: %[[RESULT:.*]] = builtin.unrealized_conversion_cast %[[ALL_SCATTER]] : tensor<3x7xf32> to tensor<3x14xf32>
+  // CHECK: %[[RESULT:.*]] = builtin.unrealized_conversion_cast %[[ALL_SLICE]] : tensor<3x7xf32> to tensor<3x14xf32>
   %0 = mesh.shard %arg0 to <@mesh_1d, [[]]> : tensor<3x14xf32>
   %1 = mesh.shard %0 to <@mesh_1d, [[], [0]]> annotate_for_users : tensor<3x14xf32>
   // CHECK: return %[[RESULT]] : tensor<3x14xf32>
@@ -33,7 +33,7 @@ func.func @split_replicated_tensor_axis_dynamic(
   // CHECK-SAME: %[[ARG:.*]]: tensor<?x3x?xf32>
   %arg0: tensor<?x3x?xf32>
 ) -> tensor<?x3x?xf32> {
-  // CHECK: %[[RESULT:.*]] = mesh.all_scatter %[[ARG]] on @mesh_1d_dynamic mesh_axes = [0] scatter_axis = 0
+  // CHECK: %[[RESULT:.*]] = mesh.all_slice %[[ARG]] on @mesh_1d_dynamic mesh_axes = [0] slice_axis = 0
   // CHECK-SAME: tensor<?x3x?xf32> -> tensor<?x3x?xf32>
   %0 = mesh.shard %arg0 to <@mesh_1d_dynamic, [[], [], []]> : tensor<?x3x?xf32>
   %1 = mesh.shard %0 to <@mesh_1d_dynamic, [[0]]> annotate_for_users : tensor<?x3x?xf32>
diff --git a/mlir/test/Dialect/Mesh/spmdization.mlir b/mlir/test/Dialect/Mesh/spmdization.mlir
index 3aeea169107d66..2fb8029dfe64ae 100644
--- a/mlir/test/Dialect/Mesh/spmdization.mlir
+++ b/mlir/test/Dialect/Mesh/spmdization.mlir
@@ -63,7 +63,7 @@ func.func @unary_elementwise_with_resharding(
   %arg0: tensor<2xi8>
 // CHECK-SAME: -> tensor<2xi8> {
 ) -> tensor<2xi8> {
-  // CHECK: %[[SLICE:.*]] = mesh.all_scatter %[[ARG]] on @mesh_1d mesh_axes = [0] scatter_axis = 0
+  // CHECK: %[[SLICE:.*]] = mesh.all_slice %[[ARG]] on @mesh_1d mesh_axes = [0] slice_axis = 0
   // CHECK-SAME: tensor<2xi8> -> tensor<1xi8>
   %0 = mesh.shard %arg0 to <@mesh_1d, [[]]> : tensor<2xi8>
   %1 = mesh.shard %0 to <@mesh_1d, [[0]]> annotate_for_users: tensor<2xi8>
@@ -108,7 +108,7 @@ func.func @multiple_chained_ops(
   %arg0: tensor<2xi8>
 // CHECK-SAME: -> tensor<1xi8> {
 ) -> tensor<2xi8> {
-  // CHECK: %[[RESHARD1:.*]] = mesh.all_scatter %[[ARG]] on @mesh_1d mesh_axes = [0] scatter_axis = 0
+  // CHECK: %[[RESHARD1:.*]] = mesh.all_slice %[[ARG]] on @mesh_1d mesh_axes = [0] slice_axis = 0
   // CHECK-SAME: tensor<2xi8> -> tensor<1xi8>
   %0 = mesh.shard %arg0 to <@mesh_1d, [[]]> : tensor<2xi8>
   %1 = mesh.shard %0 to <@mesh_1d, [[0]]> annotate_for_users: tensor<2xi8>
@@ -120,7 +120,7 @@ func.func @multiple_chained_ops(
   %4 = mesh.shard %3 to <@mesh_1d, [[]]> annotate_for_users: tensor<2xi8>
   // CHECK: %[[ABS2:.*]] = tosa.abs %[[RESHARD2]] : (tensor<2xi8>) -> tensor<2xi8>
   %5 = tosa.abs %4 : (tensor<2xi8>) -> tensor<2xi8>
-  // CHECK: %[[RESHARD3:.*]] = mesh.all_scatter %[[ABS2]] on @mesh_1d mesh_axes = [0] scatter_axis = 0 : 
+  // CHECK: %[[RESHARD3:.*]] = mesh.all_slice %[[ABS2]] on @mesh_1d mesh_axes = [0] slice_axis = 0 : 
   // CHECK-SAME: tensor<2xi8> -> tensor<1xi8>
   %6 = mesh.shard %5 to <@mesh_1d, [[]]> : tensor<2xi8>
   %7 = mesh.shard %6 to <@mesh_1d, [[0]]> annotate_for_users: tensor<2xi8>
diff --git a/mlir/test/lib/Dialect/Mesh/TestOpLowering.cpp b/mlir/test/lib/Dialect/Mesh/TestOpLowering.cpp
index 2c9131e666ed69..321b6a42bf9667 100644
--- a/mlir/test/lib/Dialect/Mesh/TestOpLowering.cpp
+++ b/mlir/test/lib/Dialect/Mesh/TestOpLowering.cpp
@@ -18,27 +18,27 @@ using namespace mlir;
 
 namespace {
 
-struct TestAllScatterOpLoweringPass
-    : public PassWrapper<TestAllScatterOpLoweringPass, OperationPass<>> {
-  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestAllScatterOpLoweringPass)
+struct TestAllSliceOpLoweringPass
+    : public PassWrapper<TestAllSliceOpLoweringPass, OperationPass<>> {
+  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestAllSliceOpLoweringPass)
 
   void runOnOperation() override {
     RewritePatternSet patterns(&getContext());
     SymbolTableCollection symbolTableCollection;
-    mesh::allScatterOpLoweringPopulatePatterns(patterns, symbolTableCollection);
+    mesh::populateAllSliceOpLoweringPatterns(patterns, symbolTableCollection);
     LogicalResult status =
         applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
     (void)status;
     assert(succeeded(status) && "applyPatternsAndFoldGreedily failed.");
   }
   void getDependentDialects(DialectRegistry &registry) const override {
-    mesh::allScatterOpLoweringRegisterDialects(registry);
+    mesh::registerAllSliceOpLoweringDialects(registry);
   }
   StringRef getArgument() const final {
-    return "test-mesh-all-scatter-op-lowering";
+    return "test-mesh-all-slice-op-lowering";
   }
   StringRef getDescription() const final {
-    return "Test lowering of all-scatter.";
+    return "Test lowering of all-slice.";
   }
 };
 
@@ -49,7 +49,7 @@ struct TestMultiIndexOpLoweringPass
   void runOnOperation() override {
     RewritePatternSet patterns(&getContext());
     SymbolTableCollection symbolTableCollection;
-    mesh::processMultiIndexOpLoweringPopulatePatterns(patterns,
+    mesh::populateProcessMultiIndexOpLoweringPatterns(patterns,
                                                       symbolTableCollection);
     LogicalResult status =
         applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
@@ -57,7 +57,7 @@ struct TestMultiIndexOpLoweringPass
     assert(succeeded(status) && "applyPatternsAndFoldGreedily failed.");
   }
   void getDependentDialects(DialectRegistry &registry) const override {
-    mesh::processMultiIndexOpLoweringRegisterDialects(registry);
+    mesh::registerProcessMultiIndexOpLoweringDialects(registry);
   }
   StringRef getArgument() const final {
     return "test-mesh-process-multi-index-op-lowering";
@@ -72,7 +72,7 @@ struct TestMultiIndexOpLoweringPass
 namespace mlir {
 namespace test {
 void registerTestOpLoweringPasses() {
-  PassRegistration<TestAllScatterOpLoweringPass>();
+  PassRegistration<TestAllSliceOpLoweringPass>();
   PassRegistration<TestMultiIndexOpLoweringPass>();
 }
 } // namespace test

>From e85bc613caccdb8bf4fcda013d1bf0e2404eda25 Mon Sep 17 00:00:00 2001
From: Boian Petkantchin <boian.petkantchin at amd.com>
Date: Tue, 13 Feb 2024 14:33:45 -0800
Subject: [PATCH 4/4] A few missed things

---
 mlir/lib/Dialect/Mesh/IR/MeshOps.cpp | 33 ++++++++++++++--------------
 1 file changed, 16 insertions(+), 17 deletions(-)

diff --git a/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp b/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
index 91165765245c6d..3291010d27428a 100644
--- a/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
+++ b/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
@@ -559,12 +559,12 @@ static LogicalResult verifyAllToAllOperandAndResultShape(
 }
 
 static LogicalResult verifyScatterOrSliceOperandAndResultShape(
-    Value operand, Value result, int64_t scatterAxis,
+    Value operand, Value result, int64_t tensorAxis,
     ArrayRef<MeshAxis> meshAxes, ArrayRef<int64_t> meshShape) {
   ShapedType operandType = operand.getType().cast<ShapedType>();
   ShapedType resultType = result.getType().cast<ShapedType>();
   for (int64_t axis = 0; axis < operandType.getRank(); ++axis) {
-    if (axis != scatterAxis) {
+    if (axis != tensorAxis) {
       if (failed(verifyDimensionCompatibility(
               result.getLoc(), operandType.getDimSize(axis),
               resultType.getDimSize(axis), axis))) {
@@ -576,38 +576,38 @@ static LogicalResult verifyScatterOrSliceOperandAndResultShape(
   auto deviceGroupSize =
       DimensionSize(collectiveProcessGroupSize(meshAxes, meshShape));
   auto operandScatterDimSize =
-      DimensionSize(operandType.getDimSize(scatterAxis));
+      DimensionSize(operandType.getDimSize(tensorAxis));
   if (!operandScatterDimSize.isDynamic() && !deviceGroupSize.isDynamic() &&
       int64_t(operandScatterDimSize) % int64_t(deviceGroupSize) != 0) {
     return emitError(result.getLoc())
            << "Operand dimension size " << int64_t(operandScatterDimSize)
            << " is not divisible by collective device group size "
-           << int64_t(deviceGroupSize) << " for tensor axis " << scatterAxis
+           << int64_t(deviceGroupSize) << " for tensor axis " << tensorAxis
            << ".";
   }
-  DimensionSize expectedResultScatterDimSize =
+  DimensionSize expectedResultTensorDimSize =
       operandScatterDimSize / deviceGroupSize;
   if (failed(verifyDimensionCompatibility(
-          result.getLoc(), expectedResultScatterDimSize.value(),
-          resultType.getDimSize(scatterAxis), scatterAxis))) {
+          result.getLoc(), expectedResultTensorDimSize.value(),
+          resultType.getDimSize(tensorAxis), tensorAxis))) {
     return failure();
   }
 
   return success();
 }
 
-static RankedTensorType scatterResultType(Type operandType, MeshOp mesh,
-                                          ArrayRef<MeshAxis> meshAxes,
-                                          int64_t scatterAxis) {
+static RankedTensorType sliceResultType(Type operandType, MeshOp mesh,
+                                        ArrayRef<MeshAxis> meshAxes,
+                                        int64_t sliceAxis) {
   RankedTensorType operandRankedTensorType =
-      operandType.cast<RankedTensorType>();
-  DimensionSize operandScatterAxisSize =
-      operandRankedTensorType.getShape()[scatterAxis];
+      cast<RankedTensorType>(operandType);
+  DimensionSize operandSliceAxisSize =
+      operandRankedTensorType.getShape()[sliceAxis];
   SmallVector<int64_t> resultShape =
       llvm::to_vector(operandRankedTensorType.getShape());
 
-  resultShape[scatterAxis] =
-      operandScatterAxisSize /
+  resultShape[sliceAxis] =
+      operandSliceAxisSize /
       DimensionSize(collectiveProcessGroupSize(meshAxes, mesh));
   return operandRankedTensorType.clone(resultShape);
 }
@@ -669,8 +669,7 @@ void AllSliceOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
 void AllSliceOp::build(OpBuilder &odsBuilder, OperationState &odsState,
                        Value input, MeshOp mesh, ArrayRef<MeshAxis> meshAxes,
                        int64_t sliceAxis) {
-  Type resultType =
-      scatterResultType(input.getType(), mesh, meshAxes, sliceAxis);
+  Type resultType = sliceResultType(input.getType(), mesh, meshAxes, sliceAxis);
   build(odsBuilder, odsState, resultType, input, mesh.getSymName(), meshAxes,
         sliceAxis);
 }



More information about the Mlir-commits mailing list