[llvm] [mlir] [mlir][mesh] Add all-scatter operation (PR #81218)

Boian Petkantchin via llvm-commits llvm-commits at lists.llvm.org
Thu Feb 8 19:10:23 PST 2024


https://github.com/sogartar updated https://github.com/llvm/llvm-project/pull/81218

>From 5d1e4cdaf9b5722d8322adaf36b4dce1962381e3 Mon Sep 17 00:00:00 2001
From: Boian Petkantchin <boian.petkantchin at amd.com>
Date: Wed, 7 Feb 2024 08:05:53 -0800
Subject: [PATCH 1/2] [mlir][mesh] Add all-scatter operation

This op is the inverse of all-gather. It is useful to have an explicit
concise representation instead of having a blob of slicing logic.

Add lowering for the op that slices from the tensor based on the in-group
process index.

Make resharding generate an all-scatter instead of inserting the slicing logic
directly.
---
 mlir/include/mlir/Dialect/Affine/Utils.h      |   7 +
 mlir/include/mlir/Dialect/Arith/Utils/Utils.h |  17 +++
 mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h   |   6 +
 mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td  |  63 ++++++++
 .../mlir/Dialect/Mesh/Transforms/Transforms.h |  19 ++-
 mlir/include/mlir/IR/Builders.h               |   3 +
 mlir/lib/Dialect/Affine/Utils/Utils.cpp       |  27 ++++
 mlir/lib/Dialect/Arith/Utils/Utils.cpp        |  19 +++
 mlir/lib/Dialect/Mesh/IR/MeshOps.cpp          |  64 ++++++++-
 .../Dialect/Mesh/Transforms/CMakeLists.txt    |   2 +
 .../Mesh/Transforms/Simplifications.cpp       |  16 +--
 .../Dialect/Mesh/Transforms/Spmdization.cpp   |  84 ++---------
 .../Dialect/Mesh/Transforms/Transforms.cpp    | 136 ++++++++++++++++--
 .../Mesh/Transforms/TransformsDetail.h        |  35 +++++
 mlir/lib/IR/Builders.cpp                      |  18 +++
 .../Dialect/Mesh/all-scatter-op-lowering.mlir |  74 ++++++++++
 mlir/test/Dialect/Mesh/canonicalization.mlir  |  13 ++
 mlir/test/Dialect/Mesh/invalid.mlir           |  52 +++++++
 mlir/test/Dialect/Mesh/ops.mlir               |  24 ++++
 .../Dialect/Mesh/resharding-spmdization.mlir  |  31 +---
 mlir/test/Dialect/Mesh/spmdization.mlir       |  14 +-
 mlir/test/lib/Dialect/Mesh/CMakeLists.txt     |   2 +-
 mlir/test/lib/Dialect/Mesh/TestOpLowering.cpp |  79 ++++++++++
 .../Mesh/TestProcessMultiIndexOpLowering.cpp  |  54 -------
 mlir/tools/mlir-opt/mlir-opt.cpp              |   4 +-
 .../llvm-project-overlay/mlir/BUILD.bazel     |   2 +
 26 files changed, 673 insertions(+), 192 deletions(-)
 create mode 100644 mlir/lib/Dialect/Mesh/Transforms/TransformsDetail.h
 create mode 100644 mlir/test/Dialect/Mesh/all-scatter-op-lowering.mlir
 create mode 100644 mlir/test/lib/Dialect/Mesh/TestOpLowering.cpp
 delete mode 100644 mlir/test/lib/Dialect/Mesh/TestProcessMultiIndexOpLowering.cpp

diff --git a/mlir/include/mlir/Dialect/Affine/Utils.h b/mlir/include/mlir/Dialect/Affine/Utils.h
index b3ccbff3002fb1..441ce1952210bb 100644
--- a/mlir/include/mlir/Dialect/Affine/Utils.h
+++ b/mlir/include/mlir/Dialect/Affine/Utils.h
@@ -15,12 +15,14 @@
 
 #include "mlir/Dialect/Affine/Analysis/AffineAnalysis.h"
 #include "mlir/Dialect/Affine/IR/AffineOps.h"
+#include "mlir/IR/OpDefinition.h"
 #include <optional>
 
 namespace mlir {
 class DominanceInfo;
 class Operation;
 class PostDominanceInfo;
+class ImplicitLocOpBuilder;
 
 namespace func {
 class FuncOp;
@@ -309,6 +311,11 @@ DivModValue getDivMod(OpBuilder &b, Location loc, Value lhs, Value rhs);
 FailureOr<SmallVector<Value>> delinearizeIndex(OpBuilder &b, Location loc,
                                                Value linearIndex,
                                                ArrayRef<Value> basis);
+// Generate IR that extracts the linear index from a multi-index according to
+// a shape.
+OpFoldResult linearIndexFromShape(ArrayRef<OpFoldResult> multiIndex,
+                                  ArrayRef<OpFoldResult> shape,
+                                  ImplicitLocOpBuilder &builder);
 
 /// Ensure that all operations that could be executed after `start`
 /// (noninclusive) and prior to `memOp` (e.g. on a control flow/op path
diff --git a/mlir/include/mlir/Dialect/Arith/Utils/Utils.h b/mlir/include/mlir/Dialect/Arith/Utils/Utils.h
index 402bd196f0736a..2111a7c5810294 100644
--- a/mlir/include/mlir/Dialect/Arith/Utils/Utils.h
+++ b/mlir/include/mlir/Dialect/Arith/Utils/Utils.h
@@ -20,6 +20,7 @@
 #include "mlir/IR/Matchers.h"
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/IR/Value.h"
+#include "llvm/ADT/ArrayRef.h"
 
 namespace mlir {
 
@@ -81,6 +82,22 @@ struct ArithBuilder {
   OpBuilder &b;
   Location loc;
 };
+
+namespace arith {
+
+// Build the product of a sequence.
+// If values = (v0, v1, ..., vn) then the returned
+// value is v0 * v1 * ... * vn.
+// All values must have the same type.
+//
+// The version without `resultType` must contain at least one element in values.
+// Then the result will have the same type as the elements in `values`.
+// If `values` is empty in the version with `resultType` returns 1 with type
+// `resultType`.
+Value createProduct(OpBuilder &builder, Location loc, ArrayRef<Value> values);
+Value createProduct(OpBuilder &builder, Location loc, ArrayRef<Value> values,
+                    Type resultType);
+} // namespace arith
 } // namespace mlir
 
 #endif // MLIR_DIALECT_ARITH_UTILS_UTILS_H
diff --git a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h
index 9154e6fd803102..fb9425b96e68e2 100644
--- a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h
+++ b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h
@@ -92,6 +92,12 @@ int64_t collectiveProcessGroupSize(MeshAxesRange &&meshAxes,
   return res;
 }
 
+template <typename MeshAxesRange>
+int64_t collectiveProcessGroupSize(MeshAxesRange &&meshAxes, MeshOp mesh) {
+  return collectiveProcessGroupSize(std::forward<MeshAxesRange>(meshAxes),
+                                    mesh.getShape());
+}
+
 // Get the size of a sharded dimension.
 inline int64_t shardDimension(int64_t dimSize, int64_t shardCount) {
   if (ShapedType::isDynamic(dimSize) || ShapedType::isDynamic(shardCount))
diff --git a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td
index da372706ec724c..7792aac784d4be 100644
--- a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td
+++ b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td
@@ -96,6 +96,7 @@ def Mesh_MeshShapeOp : Mesh_Op<"mesh_shape", [
 
   let builders = [
     OpBuilder<(ins "::mlir::mesh::MeshOp":$mesh)>,
+    OpBuilder<(ins "::mlir::mesh::MeshOp":$mesh, "ArrayRef<MeshAxis>":$axes)>,
     OpBuilder<(ins "StringRef":$mesh, "ArrayRef<MeshAxis>":$axes)>
   ];
 }
@@ -341,6 +342,68 @@ def Mesh_AllReduceOp : Mesh_CollectiveCommunicationOpBase<"all_reduce", [
   let hasCanonicalizer = 1;
 }
 
+def Mesh_AllScatterOp : Mesh_CollectiveCommunicationOpBase<"all_scatter", [
+    Pure,
+    SameOperandsAndResultElementType,
+    SameOperandsAndResultRank
+  ]> {
+  let summary = "All-scatter over a device mesh. This is the inverse of all-gather.";
+  let description = [{
+    Scatter along the `scatter_axis` tensor axis.
+    This operation can be thought of as the inverse of all-gather.
+    Technically, it is not required that all processes have the same input tensor.
+    Each process will slice a piece of its local tensor based on its in-group device index.
+    The operation does not communicate data between devices.
+
+    Example:
+    ```mlir
+    mesh.mesh @mesh0(shape = 2x2)
+    ...
+    %1 = mesh.all_scatter %0 on @mesh0 mesh_axes = [1] scatter_axis = 1
+      : tensor<2x4xi8> -> tensor<2x2xi8>
+    ```
+    Input:
+    ```
+    +-------------+
+    |  1  2  5  6 | <- devices (0, 0) and (0, 1)
+    |  3  4  7  8 |
+    +-------------+
+    |  9 10 13 14 | <- devices (1, 0) and (1, 1)
+    | 11 12 15 16 |
+    +-------------+
+    ```
+    Result:
+    ```
+    scatter tensor
+    axis 1
+    ------------>
+                     +-------+-------+
+    device (0, 0) -> |  1  2 |  5  6 | <- device (0, 1)
+                     |  3  4 |  7  8 |
+                     +-------+-------+
+    device (1, 0) -> |  9 10 | 13 14 | <- device (1, 1)
+                     | 11 12 | 15 16 |
+                     +-------+-------+
+    ```
+  }];
+  let arguments = !con(commonArgs, (ins
+    AnyNon0RankedTensor:$input,
+    IndexAttr:$scatter_axis
+  ));
+  let results = (outs
+    AnyNon0RankedTensor:$result
+  );
+  let assemblyFormat = [{
+    $input `on` $mesh (`mesh_axes` `=` $mesh_axes^)? `scatter_axis` `=` $scatter_axis
+    attr-dict `:` type($input) `->` type($result)
+  }];
+  let hasCanonicalizer = 1;
+  let builders = [
+    OpBuilder<(ins "Value":$input, "MeshOp":$mesh, "ArrayRef<MeshAxis>":$meshAxes, "int64_t":$scatterAxis)>,
+    OpBuilder<(ins "Type":$result_type, "Value":$input, "StringRef":$mesh, "ArrayRef<MeshAxis>":$meshAxes, "int64_t":$scatterAxis)>
+  ];
+}
+
 def Mesh_AllToAllOp : Mesh_CollectiveCommunicationOpBase<"all_to_all", [
     Pure,
     SameOperandsAndResultElementType,
diff --git a/mlir/include/mlir/Dialect/Mesh/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Mesh/Transforms/Transforms.h
index 10a965daac71b9..d398bdd65330b1 100644
--- a/mlir/include/mlir/Dialect/Mesh/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Mesh/Transforms/Transforms.h
@@ -9,17 +9,34 @@
 #ifndef MLIR_DIALECT_MESH_TRANSFORMS_TRANSFORMS_H
 #define MLIR_DIALECT_MESH_TRANSFORMS_TRANSFORMS_H
 
+#include "mlir/Dialect/Mesh/IR/MeshOps.h"
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/Value.h"
+#include "mlir/Support/LLVM.h"
+
 namespace mlir {
 class RewritePatternSet;
 class SymbolTableCollection;
 class DialectRegistry;
+class ImplicitLocOpBuilder;
 namespace mesh {
 
 void processMultiIndexOpLoweringPopulatePatterns(
     RewritePatternSet &patterns, SymbolTableCollection &symbolTableCollection);
-
 void processMultiIndexOpLoweringRegisterDialects(DialectRegistry &registry);
 
+void allScatterOpLoweringPopulatePatterns(
+    RewritePatternSet &patterns, SymbolTableCollection &symbolTableCollection);
+void allScatterOpLoweringRegisterDialects(DialectRegistry &registry);
+
+void populateAllOpLoweringPatterns(
+    RewritePatternSet &patterns, SymbolTableCollection &symbolTableCollection);
+void registerAllOpLoweringDialects(DialectRegistry &registry);
+
+TypedValue<IndexType>
+createCollectiveProcessGroupSize(MeshOp mesh, ArrayRef<MeshAxis> axes,
+                                 ImplicitLocOpBuilder &builder);
+
 } // namespace mesh
 } // namespace mlir
 
diff --git a/mlir/include/mlir/IR/Builders.h b/mlir/include/mlir/IR/Builders.h
index 2fe1495b2b593b..43b6d2b3841690 100644
--- a/mlir/include/mlir/IR/Builders.h
+++ b/mlir/include/mlir/IR/Builders.h
@@ -118,6 +118,9 @@ class Builder {
   // supports boolean, integer, and 16-/32-/64-bit float types, and vector or
   // ranked tensor of them. Returns null attribute otherwise.
   TypedAttr getZeroAttr(Type type);
+  // Returns a 1-valued attribute of the given `type`.
+  // Type constraints are the same as `getZeroAttr`.
+  TypedAttr getOneAttr(Type type);
 
   // Convenience methods for fixed types.
   FloatAttr getF16FloatAttr(float value);
diff --git a/mlir/lib/Dialect/Affine/Utils/Utils.cpp b/mlir/lib/Dialect/Affine/Utils/Utils.cpp
index 4d4adb94a9fc8d..41bcd2b9f33e6b 100644
--- a/mlir/lib/Dialect/Affine/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/Affine/Utils/Utils.cpp
@@ -20,9 +20,11 @@
 #include "mlir/Dialect/Arith/Utils/Utils.h"
 #include "mlir/Dialect/Func/IR/FuncOps.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/Utils/IndexingUtils.h"
 #include "mlir/IR/AffineExprVisitor.h"
 #include "mlir/IR/Dominance.h"
 #include "mlir/IR/IRMapping.h"
+#include "mlir/IR/ImplicitLocOpBuilder.h"
 #include "mlir/IR/IntegerSet.h"
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
 #include <optional>
@@ -1869,3 +1871,28 @@ mlir::affine::delinearizeIndex(OpBuilder &b, Location loc, Value linearIndex,
   results.push_back(residual);
   return results;
 }
+
+OpFoldResult
+mlir::affine::linearIndexFromShape(ArrayRef<OpFoldResult> multiIndex,
+                                   ArrayRef<OpFoldResult> shape,
+                                   ImplicitLocOpBuilder &builder) {
+  assert(multiIndex.size() == shape.size());
+  SmallVector<AffineExpr> shapeAffine;
+  for (size_t i = 0; i < shape.size(); ++i) {
+    shapeAffine.push_back(getAffineSymbolExpr(i, builder.getContext()));
+  }
+
+  SmallVector<AffineExpr> stridesAffine = computeStrides(shapeAffine);
+  SmallVector<OpFoldResult> strides;
+  strides.reserve(stridesAffine.size());
+  llvm::transform(stridesAffine, std::back_inserter(strides),
+                  [&builder, &shape](AffineExpr strideExpr) {
+                    return affine::makeComposedFoldedAffineApply(
+                        builder, builder.getLoc(), strideExpr, shape);
+                  });
+
+  auto &&[linearIndexExpr, multiIndexAndStrides] = computeLinearIndex(
+      OpFoldResult(builder.getIndexAttr(0)), strides, multiIndex);
+  return affine::makeComposedFoldedAffineApply(
+      builder, builder.getLoc(), linearIndexExpr, multiIndexAndStrides);
+}
diff --git a/mlir/lib/Dialect/Arith/Utils/Utils.cpp b/mlir/lib/Dialect/Arith/Utils/Utils.cpp
index bf274d4ae27ed8..999cdbc5c10008 100644
--- a/mlir/lib/Dialect/Arith/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/Arith/Utils/Utils.cpp
@@ -15,6 +15,7 @@
 #include "mlir/Dialect/Complex/IR/Complex.h"
 #include "mlir/IR/ImplicitLocOpBuilder.h"
 #include "llvm/ADT/SmallBitVector.h"
+#include <numeric>
 
 using namespace mlir;
 
@@ -262,3 +263,25 @@ Value ArithBuilder::slt(Value lhs, Value rhs) {
 Value ArithBuilder::select(Value cmp, Value lhs, Value rhs) {
   return b.create<arith::SelectOp>(loc, cmp, lhs, rhs);
 }
+
+namespace mlir::arith {
+
+// The result type is taken from the first element, so `values` must be
+// non-empty.
+Value createProduct(OpBuilder &builder, Location loc, ArrayRef<Value> values) {
+  return createProduct(builder, loc, values, values.front().getType());
+}
+
+// Multiplies `values` left to right with ArithBuilder::mul, starting from a
+// constant 1 of type `resultType`; an empty `values` yields that constant.
+Value createProduct(OpBuilder &builder, Location loc, ArrayRef<Value> values,
+                    Type resultType) {
+  Value one = builder.create<ConstantOp>(loc, resultType,
+                                         builder.getOneAttr(resultType));
+  ArithBuilder arithBuilder(builder, loc);
+  return std::accumulate(
+      values.begin(), values.end(), one,
+      [&arithBuilder](Value acc, Value v) { return arithBuilder.mul(acc, v); });
+}
+
+} // namespace mlir::arith
diff --git a/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp b/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
index a65b8f2e5a2376..762725d2c56e66 100644
--- a/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
+++ b/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
@@ -252,14 +252,20 @@ MeshShapeOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
 
 void MeshShapeOp::build(OpBuilder &odsBuilder, OperationState &odsState,
                         MeshOp mesh) {
+  build(odsBuilder, odsState, mesh, SmallVector<MeshAxis>());
+}
+
+void MeshShapeOp::build(OpBuilder &odsBuilder, OperationState &odsState,
+                        MeshOp mesh, ArrayRef<MeshAxis> axes) {
   build(odsBuilder, odsState,
-        SmallVector<Type>(mesh.getRank(), odsBuilder.getIndexType()),
-        mesh.getSymName(),
-        MeshAxesAttr::get(odsBuilder.getContext(), SmallVector<MeshAxis>()));
+        SmallVector<Type>(axes.empty() ? mesh.getRank() : axes.size(),
+                          odsBuilder.getIndexType()),
+        mesh.getSymName(), MeshAxesAttr::get(odsBuilder.getContext(), axes));
 }
 
 void MeshShapeOp::build(OpBuilder &odsBuilder, OperationState &odsState,
                         StringRef mesh, ArrayRef<MeshAxis> axes) {
+  assert(!axes.empty());
   build(odsBuilder, odsState,
         SmallVector<Type>(axes.size(), odsBuilder.getIndexType()), mesh,
         MeshAxesAttr::get(odsBuilder.getContext(), axes));
@@ -590,6 +596,22 @@ static LogicalResult verifyScatterOperandAndResultShape(
   return success();
 }
 
+static RankedTensorType scatterResultType(Type operandType, MeshOp mesh,
+                                          ArrayRef<MeshAxis> meshAxes,
+                                          int64_t scatterAxis) {
+  RankedTensorType operandRankedTensorType =
+      operandType.cast<RankedTensorType>();
+  DimensionSize operandScatterAxisSize =
+      operandRankedTensorType.getShape()[scatterAxis];
+  SmallVector<int64_t> resultShape =
+      llvm::to_vector(operandRankedTensorType.getShape());
+
+  resultShape[scatterAxis] =
+      operandScatterAxisSize /
+      DimensionSize(collectiveProcessGroupSize(meshAxes, mesh));
+  return operandRankedTensorType.clone(resultShape);
+}
+
 //===----------------------------------------------------------------------===//
 // mesh.all_gather op
 //===----------------------------------------------------------------------===//
@@ -625,6 +647,42 @@ void AllReduceOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
   patterns.add<EmptyMeshAxesCanonicalizationPattern<AllReduceOp>>(context);
 }
 
+//===----------------------------------------------------------------------===//
+// mesh.all_scatter op
+//===----------------------------------------------------------------------===//
+
+LogicalResult
+AllScatterOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
+  auto mesh = getMeshAndVerifyAxes(*this, symbolTable);
+  if (failed(mesh)) {
+    return failure();
+  }
+  return verifyScatterOperandAndResultShape(
+      getOperand(), getResult(), getScatterAxis().getSExtValue(), getMeshAxes(),
+      mesh.value().getShape());
+}
+
+void AllScatterOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
+                                               MLIRContext *context) {
+  patterns.add<EmptyMeshAxesCanonicalizationPattern<AllScatterOp>>(context);
+}
+
+void AllScatterOp::build(OpBuilder &odsBuilder, OperationState &odsState,
+                         Value input, MeshOp mesh, ArrayRef<MeshAxis> meshAxes,
+                         int64_t scatterAxis) {
+  Type resultType =
+      scatterResultType(input.getType(), mesh, meshAxes, scatterAxis);
+  build(odsBuilder, odsState, resultType, input, mesh.getSymName(), meshAxes,
+        scatterAxis);
+}
+
+void AllScatterOp::build(OpBuilder &odsBuilder, OperationState &odsState,
+                         Type resultType, Value input, StringRef mesh,
+                         ArrayRef<MeshAxis> meshAxes, int64_t scatterAxis) {
+  build(odsBuilder, odsState, resultType, mesh, meshAxes, input,
+        APInt(sizeof(scatterAxis) * CHAR_BIT, scatterAxis));
+}
+
 //===----------------------------------------------------------------------===//
 // mesh.all_to_all op
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Mesh/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Mesh/Transforms/CMakeLists.txt
index dccb75848c94f0..28af820440076c 100644
--- a/mlir/lib/Dialect/Mesh/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Mesh/Transforms/CMakeLists.txt
@@ -13,7 +13,9 @@ add_mlir_dialect_library(MLIRMeshTransforms
 
   LINK_LIBS PUBLIC
   MLIRAffineDialect
+  MLIRAffineUtils
   MLIRArithDialect
+  MLIRArithUtils
   MLIRControlFlowDialect
   MLIRFuncDialect
   MLIRIR
diff --git a/mlir/lib/Dialect/Mesh/Transforms/Simplifications.cpp b/mlir/lib/Dialect/Mesh/Transforms/Simplifications.cpp
index c0273cdaef7144..7fcac2312444f3 100644
--- a/mlir/lib/Dialect/Mesh/Transforms/Simplifications.cpp
+++ b/mlir/lib/Dialect/Mesh/Transforms/Simplifications.cpp
@@ -7,6 +7,7 @@
 //===----------------------------------------------------------------------===//
 
 #include "mlir/Dialect/Mesh/Transforms/Simplifications.h"
+#include "TransformsDetail.h"
 #include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/Mesh/IR/MeshOps.h"
 #include "mlir/IR/BuiltinTypeInterfaces.h"
@@ -16,7 +17,6 @@
 #include "mlir/Support/LogicalResult.h"
 #include "llvm/ADT/STLExtras.h"
 #include "llvm/ADT/SmallVector.h"
-#include <iterator>
 #include <numeric>
 #include <utility>
 
@@ -56,13 +56,10 @@ namespace {
 // symbol tables.
 // We can't use DialectFoldInterface since the cache may be invalidated by some
 // pass changing the referenced MeshOp ops.
-struct MeshShapeFolder : OpRewritePattern<MeshShapeOp> {
-  template <typename... OpRewritePatternArgs>
-  MeshShapeFolder(SymbolTableCollection &symbolTableCollection,
-                  OpRewritePatternArgs &&...opRewritePatternArgs)
-      : OpRewritePattern(
-            std::forward<OpRewritePatternArgs...>(opRewritePatternArgs)...),
-        symbolTableCollection(symbolTableCollection) {}
+struct MeshShapeFolder
+    : OpRewritePatternWithSymbolTableCollection<MeshShapeOp> {
+  using OpRewritePatternWithSymbolTableCollection::
+      OpRewritePatternWithSymbolTableCollection;
   LogicalResult matchAndRewrite(MeshShapeOp op,
                                 PatternRewriter &rewriter) const override {
     ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
@@ -113,9 +110,6 @@ struct MeshShapeFolder : OpRewritePattern<MeshShapeOp> {
 
     return success();
   }
-
-private:
-  SymbolTableCollection &symbolTableCollection;
 };
 
 } // namespace
diff --git a/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp b/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp
index b649157a9e46de..21e03d9572c590 100644
--- a/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp
+++ b/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp
@@ -128,92 +128,24 @@ targetShardingInSplitLastAxis(MLIRContext *ctx, MeshShardingAttr sourceSharding,
       sourceSharding.getPartialAxes(), sourceSharding.getPartialType());
 }
 
-static ShapedType targetShapeInSplitLastAxis(ShapedType sourceShape,
-                                             int64_t splitTensorAxis,
-                                             int64_t splitCount) {
-  SmallVector<int64_t> targetShape = llvm::to_vector(sourceShape.getShape());
-  targetShape[splitTensorAxis] =
-      shardDimension(targetShape[splitTensorAxis], splitCount);
-  return sourceShape.cloneWith(targetShape, sourceShape.getElementType());
-}
-
 // Split a replicated tensor along a mesh axis.
 // e.g. [[0, 1]] -> [[0, 1, 2]].
 // Returns the spmdized target value with its sharding.
-//
-// The implementation is the extract the tensor slice corresponding
-// to the current device.
 static std::tuple<TypedValue<ShapedType>, MeshShardingAttr>
 splitLastAxisInResharding(ImplicitLocOpBuilder &builder,
                           MeshShardingAttr sourceSharding,
                           TypedValue<ShapedType> sourceShard, MeshOp mesh,
                           int64_t splitTensorAxis, MeshAxis splitMeshAxis) {
-  MLIRContext *ctx = builder.getContext();
-  builder.setInsertionPointAfterValue(sourceShard);
-
-  Value zero = builder.create<arith::ConstantOp>(builder.getIndexAttr(0));
-
-  Value processIndexAlongAxis =
+  TypedValue<ShapedType> targetShard =
       builder
-          .create<ProcessMultiIndexOp>(mesh.getSymName(),
-                                       SmallVector<MeshAxis>({splitMeshAxis}))
-          .getResult()[0];
-
+          .create<AllScatterOp>(sourceShard, mesh,
+                                ArrayRef<MeshAxis>(splitMeshAxis),
+                                splitTensorAxis)
+          .getResult()
+          .cast<TypedValue<ShapedType>>();
   MeshShardingAttr targetSharding = targetShardingInSplitLastAxis(
-      ctx, sourceSharding, splitTensorAxis, splitMeshAxis);
-  ShapedType targetShape = targetShapeInSplitLastAxis(
-      sourceShard.getType(), splitTensorAxis, mesh.getShape()[splitMeshAxis]);
-
-  Value meshAxisSize =
-      builder
-          .create<MeshShapeOp>(mesh.getSymName(),
-                               SmallVector<MeshAxis>({splitMeshAxis}))
-          .getResult()[0];
-
-  Value sourceAxisSize =
-      builder.create<tensor::DimOp>(sourceShard, splitTensorAxis);
-  Value sourceAxisSizeModMeshAxisSize =
-      builder.create<arith::RemUIOp>(sourceAxisSize, meshAxisSize);
-  Value isTargetShapeExactlyDivisible = builder.create<arith::CmpIOp>(
-      arith::CmpIPredicate::eq, sourceAxisSizeModMeshAxisSize, zero);
-  builder.create<cf::AssertOp>(
-      isTargetShapeExactlyDivisible,
-      "Sharding a tensor with axis size that is not exactly divisible by the "
-      "mesh axis size is not supported.");
-  Value targetAxisSize =
-      builder.create<arith::DivUIOp>(sourceAxisSize, meshAxisSize);
-  Value axisOffset =
-      builder.create<arith::MulIOp>(targetAxisSize, processIndexAlongAxis);
-  SmallVector<int64_t> staticOffsets(targetShape.getRank(), 0);
-  staticOffsets[splitTensorAxis] = ShapedType::kDynamic;
-  DenseI64ArrayAttr staticOffsetsAttr =
-      DenseI64ArrayAttr::get(ctx, staticOffsets);
-  SmallVector<Value> dynamicOffsets(1, axisOffset);
-
-  DenseI64ArrayAttr staticSizesAttr =
-      DenseI64ArrayAttr::get(ctx, targetShape.getShape());
-  SmallVector<Value> dynamicSizes;
-  for (int64_t i = 0; i < targetShape.getRank(); ++i) {
-    if (ShapedType::isDynamic(staticSizesAttr.asArrayRef()[i])) {
-      if (i == splitTensorAxis) {
-        dynamicSizes.push_back(targetAxisSize);
-      } else {
-        Value dimSize = builder.create<tensor::DimOp>(sourceShard, i);
-        dynamicSizes.push_back(dimSize);
-      }
-    }
-  }
-
-  DenseI64ArrayAttr staticStridesAttr = DenseI64ArrayAttr::get(
-      ctx, SmallVector<int64_t>(targetShape.getRank(), 1));
-  TypedValue<RankedTensorType> targetShard =
-      builder
-          .create<tensor::ExtractSliceOp>(
-              targetShape, sourceShard, dynamicOffsets, dynamicSizes,
-              SmallVector<Value>({}), staticOffsetsAttr, staticSizesAttr,
-              staticStridesAttr)
-          .getResult();
-  return {targetShard.cast<TypedValue<ShapedType>>(), targetSharding};
+      builder.getContext(), sourceSharding, splitTensorAxis, splitMeshAxis);
+  return {targetShard, targetSharding};
 }
 
 // Detect if the resharding is of type e.g.
diff --git a/mlir/lib/Dialect/Mesh/Transforms/Transforms.cpp b/mlir/lib/Dialect/Mesh/Transforms/Transforms.cpp
index 03b1d9b3498028..95739c18ce0f36 100644
--- a/mlir/lib/Dialect/Mesh/Transforms/Transforms.cpp
+++ b/mlir/lib/Dialect/Mesh/Transforms/Transforms.cpp
@@ -7,12 +7,21 @@
 //===----------------------------------------------------------------------===//
 
 #include "mlir/Dialect/Mesh/Transforms/Transforms.h"
+#include "TransformsDetail.h"
 #include "mlir/Dialect/Affine/IR/AffineOps.h"
+#include "mlir/Dialect/Affine/Utils.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Arith/Utils/Utils.h"
+#include "mlir/Dialect/ControlFlow/IR/ControlFlow.h"
+#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
 #include "mlir/Dialect/Mesh/IR/MeshDialect.h"
 #include "mlir/Dialect/Mesh/IR/MeshOps.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Dialect/Utils/StaticValueUtils.h"
 #include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/DialectRegistry.h"
 #include "mlir/IR/ImplicitLocOpBuilder.h"
+#include "mlir/IR/OpDefinition.h"
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/IR/Value.h"
 #include "llvm/ADT/STLExtras.h"
@@ -26,18 +35,14 @@ namespace {
 
 /// Lower `mesh.process_multi_index` into expression using
 /// `mesh.process_linear_index` and `mesh.mesh_shape`.
-struct ProcessMultiIndexOpLowering : OpRewritePattern<ProcessMultiIndexOp> {
-  template <typename... OpRewritePatternArgs>
-  ProcessMultiIndexOpLowering(SymbolTableCollection &symbolTableCollection,
-                              OpRewritePatternArgs &&...opRewritePatternArgs)
-      : OpRewritePattern(
-            std::forward<OpRewritePatternArgs...>(opRewritePatternArgs)...),
-        symbolTableCollection(symbolTableCollection) {}
+struct ProcessMultiIndexOpLowering
+    : OpRewritePatternWithSymbolTableCollection<ProcessMultiIndexOp> {
+  using OpRewritePatternWithSymbolTableCollection::
+      OpRewritePatternWithSymbolTableCollection;
 
   LogicalResult matchAndRewrite(ProcessMultiIndexOp op,
                                 PatternRewriter &rewriter) const override {
-    MeshOp mesh = symbolTableCollection.lookupNearestSymbolFrom<mesh::MeshOp>(
-        op.getOperation(), op.getMeshAttr());
+    MeshOp mesh = getMesh(op, symbolTableCollection);
     if (!mesh) {
       return failure();
     }
@@ -64,9 +70,100 @@ struct ProcessMultiIndexOpLowering : OpRewritePattern<ProcessMultiIndexOp> {
     rewriter.replaceAllUsesWith(op.getResults(), multiIndex);
     return success();
   }
+};
+
+/// Lower `mesh.all_scatter` into a `tensor.extract_slice` of the local
+/// operand: every process keeps the slice along `scatter_axis` selected by
+/// its linear index within the process group spanned by `mesh_axes`.
+struct AllScatterOpLowering
+    : OpRewritePatternWithSymbolTableCollection<AllScatterOp> {
+  using OpRewritePatternWithSymbolTableCollection::
+      OpRewritePatternWithSymbolTableCollection;
+
+  LogicalResult matchAndRewrite(AllScatterOp op,
+                                PatternRewriter &rewriter) const override {
+    MeshOp mesh = getMesh(op, symbolTableCollection);
+    if (!mesh) {
+      return failure();
+    }
+
+    ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
+    builder.setInsertionPointAfter(op.getOperation());
+
+    // Empty `mesh_axes` means a single-process group, so the op is the
+    // identity. `mesh.mesh_shape` would instead interpret an empty axis list
+    // as "all mesh axes" (and its by-name builder asserts it is non-empty).
+    if (op.getMeshAxes().empty()) {
+      Value newResult = builder.create<tensor::CastOp>(
+          op.getResult().getType(), op.getOperand());
+      rewriter.replaceAllUsesWith(op.getResult(), newResult);
+      return success();
+    }
+
+    Value zero = builder.create<arith::ConstantOp>(builder.getIndexAttr(0));
+
+    Operation::result_range processInGroupMultiIndex =
+        builder.create<ProcessMultiIndexOp>(mesh.getSymName(), op.getMeshAxes())
+            .getResults();
+
+    Operation::result_range processGroupShape =
+        builder.create<MeshShapeOp>(mesh.getSymName(), op.getMeshAxes())
+            .getResult();
+    Value processGroupSize =
+        createCollectiveProcessGroupSize(mesh, op.getMeshAxes(), builder);
+
+    // Each process gets operand_size / group_size elements along the scatter
+    // axis; the cf.assert rejects sizes that do not divide evenly.
+    int64_t scatterAxis = op.getScatterAxis().getSExtValue();
+    Value operandScatterAxisSize =
+        builder.create<tensor::DimOp>(op.getOperand(), scatterAxis);
+    Value operandScatterAxisSizeModProcessGroupSize =
+        builder.create<arith::RemUIOp>(operandScatterAxisSize,
+                                       processGroupSize);
+    Value isTargetShapeExactlyDivisible = builder.create<arith::CmpIOp>(
+        arith::CmpIPredicate::eq, operandScatterAxisSizeModProcessGroupSize,
+        zero);
+    builder.create<cf::AssertOp>(isTargetShapeExactlyDivisible,
+                                 "Scattering a tensor with axis size that is "
+                                 "not exactly divisible by the "
+                                 "mesh process group size is not supported.");
+    Value resultScatterAxisSize = builder.create<arith::DivUIOp>(
+        operandScatterAxisSize, processGroupSize);
+    OpFoldResult processInGroupLinearIndex = affine::linearIndexFromShape(
+        llvm::to_vector_of<OpFoldResult>(processInGroupMultiIndex),
+        llvm::to_vector_of<OpFoldResult>(processGroupShape), builder);
+
+    // Slice out the full extent of every axis except the scatter axis, which
+    // starts at processInGroupLinearIndex * resultScatterAxisSize.
+    RankedTensorType operandType =
+        op.getOperand().getType().cast<RankedTensorType>();
+    SmallVector<OpFoldResult> sizes;
+    for (int64_t i = 0; i < operandType.getRank(); ++i) {
+      if (i == scatterAxis) {
+        sizes.emplace_back(resultScatterAxisSize);
+      } else {
+        Value dimSize = builder.create<tensor::DimOp>(op.getOperand(), i);
+        sizes.emplace_back(dimSize);
+      }
+    }
+    SmallVector<OpFoldResult> offsets(
+        operandType.getRank(), getAsIndexOpFoldResult(builder.getContext(), 0));
+    offsets[scatterAxis] =
+        ArithBuilder(builder, builder.getLoc())
+            .mul(getValueOrCreateConstantIndexOp(builder, builder.getLoc(),
+                                                 processInGroupLinearIndex),
+                 resultScatterAxisSize);
+    SmallVector<OpFoldResult> strides(
+        operandType.getRank(), getAsIndexOpFoldResult(builder.getContext(), 1));
+    Value slice = builder.create<tensor::ExtractSliceOp>(
+        op.getOperand(), offsets, sizes, strides);
+    // The slice sizes are SSA values, so cast to the declared result type.
+    Value newResult =
+        builder.create<tensor::CastOp>(op.getResult().getType(), slice);
+    rewriter.replaceAllUsesWith(op.getResult(), newResult);
 
-private:
-  SymbolTableCollection &symbolTableCollection;
+    return success();
+  }
 };
 
 } // namespace
@@ -81,4 +161,46 @@ void processMultiIndexOpLoweringRegisterDialects(DialectRegistry &registry) {
   registry.insert<affine::AffineDialect, mesh::MeshDialect>();
 }

+/// Adds the pattern that lowers `mesh.all_scatter` ops to `patterns`.
+/// NOTE(review): the pattern is constructed with `symbolTableCollection`;
+/// assuming it keeps a reference, the collection must outlive `patterns`.
+void allScatterOpLoweringPopulatePatterns(
+    RewritePatternSet &patterns, SymbolTableCollection &symbolTableCollection) {
+  patterns.add<AllScatterOpLowering>(symbolTableCollection,
+                                     patterns.getContext());
+}
+
+/// Registers the dialects whose ops the all-scatter lowering creates.
+void allScatterOpLoweringRegisterDialects(DialectRegistry &registry) {
+  registry.insert<affine::AffineDialect, arith::ArithDialect,
+                  cf::ControlFlowDialect, mesh::MeshDialect,
+                  tensor::TensorDialect>();
+}
+
+/// Adds every mesh op lowering pattern (process_multi_index and all_scatter).
+void populateAllOpLoweringPatterns(
+    RewritePatternSet &patterns, SymbolTableCollection &symbolTableCollection) {
+  processMultiIndexOpLoweringPopulatePatterns(patterns, symbolTableCollection);
+  allScatterOpLoweringPopulatePatterns(patterns, symbolTableCollection);
+}
+
+/// Registers the union of the dialects needed by every mesh op lowering.
+void registerAllOpLoweringDialects(DialectRegistry &registry) {
+  processMultiIndexOpLoweringRegisterDialects(registry);
+  allScatterOpLoweringRegisterDialects(registry);
+}
+
+/// Returns the size of the process group spanned by `axes` of `mesh`, built as
+/// the product of the `mesh.mesh_shape` results for those axes.
+TypedValue<IndexType>
+createCollectiveProcessGroupSize(MeshOp mesh, ArrayRef<MeshAxis> axes,
+                                 ImplicitLocOpBuilder &builder) {
+  Operation::result_range meshShape =
+      builder.create<mesh::MeshShapeOp>(mesh, axes).getResults();
+  return llvm::cast<TypedValue<IndexType>>(
+      arith::createProduct(builder, builder.getLoc(),
+                           llvm::to_vector_of<Value>(meshShape),
+                           builder.getIndexType()));
+}
+
 } // namespace mlir::mesh
diff --git a/mlir/lib/Dialect/Mesh/Transforms/TransformsDetail.h b/mlir/lib/Dialect/Mesh/Transforms/TransformsDetail.h
new file mode 100644
index 00000000000000..3e3f584caca247
--- /dev/null
+++ b/mlir/lib/Dialect/Mesh/Transforms/TransformsDetail.h
@@ -0,0 +1,43 @@
+//===- TransformsDetail.h - Mesh transforms helpers -------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_MESH_TRANSFORMS_TRANSFORMSDETAIL_H
+#define MLIR_DIALECT_MESH_TRANSFORMS_TRANSFORMSDETAIL_H
+
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/IR/SymbolTable.h"
+
+namespace mlir {
+namespace mesh {
+
+/// An `OpRewritePattern<Op>` that additionally stores a reference to a
+/// `SymbolTableCollection` for use by derived patterns.
+///
+/// The collection is held by reference, so it must outlive the pattern. Every
+/// constructor argument after `symbolTableCollection` is perfect-forwarded to
+/// the `OpRewritePattern<Op>` constructor (e.g. the context and a benefit).
+template <typename Op>
+struct OpRewritePatternWithSymbolTableCollection : OpRewritePattern<Op> {
+  // Each argument is forwarded with its own deduced type: `std::forward`
+  // takes exactly one template argument.
+  template <typename... OpRewritePatternArgs>
+  OpRewritePatternWithSymbolTableCollection(
+      SymbolTableCollection &symbolTableCollection,
+      OpRewritePatternArgs &&...opRewritePatternArgs)
+      : OpRewritePattern<Op>(
+            std::forward<OpRewritePatternArgs>(opRewritePatternArgs)...),
+        symbolTableCollection(symbolTableCollection) {}
+
+protected:
+  SymbolTableCollection &symbolTableCollection;
+};
+
+} // namespace mesh
+} // namespace mlir
+
+#endif // MLIR_DIALECT_MESH_TRANSFORMS_TRANSFORMSDETAIL_H
diff --git a/mlir/lib/IR/Builders.cpp b/mlir/lib/IR/Builders.cpp
index 2e42c4e870716b..18ca3c332e0204 100644
--- a/mlir/lib/IR/Builders.cpp
+++ b/mlir/lib/IR/Builders.cpp
@@ -346,6 +346,27 @@ TypedAttr Builder::getZeroAttr(Type type) {
   return {};
 }

+/// Returns the multiplicative identity of `type` as an attribute: 1.0 for
+/// floats, 1 for index and integer types, and a splat of the element's one for
+/// ranked tensors and vectors. Returns a null attribute for any other type,
+/// including shaped types whose element type has no such attribute.
+TypedAttr Builder::getOneAttr(Type type) {
+  if (llvm::isa<FloatType>(type))
+    return getFloatAttr(type, 1.0);
+  if (llvm::isa<IndexType>(type))
+    return getIndexAttr(1);
+  if (auto integerType = llvm::dyn_cast<IntegerType>(type))
+    return getIntegerAttr(type, APInt(integerType.getWidth(), 1));
+  if (llvm::isa<RankedTensorType, VectorType>(type)) {
+    auto vtType = llvm::cast<ShapedType>(type);
+    auto element = getOneAttr(vtType.getElementType());
+    if (!element)
+      return {};
+    return DenseElementsAttr::get(vtType, element);
+  }
+  return {};
+}
+
 //===----------------------------------------------------------------------===//
 // Affine Expressions, Affine Maps, and Integer Sets.
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Mesh/all-scatter-op-lowering.mlir b/mlir/test/Dialect/Mesh/all-scatter-op-lowering.mlir
new file mode 100644
index 00000000000000..8aa3fc5160a001
--- /dev/null
+++ b/mlir/test/Dialect/Mesh/all-scatter-op-lowering.mlir
@@ -0,0 +1,74 @@
+// RUN: mlir-opt --split-input-file --test-mesh-all-scatter-op-lowering --test-mesh-simplifications --cse %s | FileCheck %s
+
+mesh.mesh @mesh_1d(shape = ?)
+
+// CHECK-LABEL: func.func @all_scatter_op_lowering_of_dynamic_1d_tensor_on_dynamic_1d_mesh
+func.func @all_scatter_op_lowering_of_dynamic_1d_tensor_on_dynamic_1d_mesh(
+  // CHECK: %[[ARG:.*]]: tensor<?xf16>
+  %arg0: tensor<?xf16>
+// CHECK-SAME: -> tensor<?xf16> {
+) -> tensor<?xf16> {
+  // CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
+  // CHECK-DAG: %[[PROC_IDX:.*]] = mesh.process_multi_index on @mesh_1d axes = [0] : index
+  // CHECK-DAG: %[[MESH_SIZE:.*]] = mesh.mesh_shape @mesh_1d axes = [0] : index
+  // CHECK: %[[TENSOR_AXIS_SIZE:.*]] = tensor.dim %[[ARG]], %[[C0]] : tensor<?xf16>
+  // CHECK: %[[AXIS_SIZE_CHECK_REMAINDER:.*]] = arith.remui %[[TENSOR_AXIS_SIZE]], %[[MESH_SIZE]] : index
+  // CHECK: %[[AXIS_SIZE_CHECK:.*]] = arith.cmpi eq, %[[AXIS_SIZE_CHECK_REMAINDER]], %[[C0]] : index
+  // CHECK: cf.assert %[[AXIS_SIZE_CHECK]]
+  // CHECK: %[[RESULT_AXIS_SIZE:.*]] = arith.divui %[[TENSOR_AXIS_SIZE]], %[[MESH_SIZE]] : index
+  // CHECK: %[[SLICE_OFFSET:.*]] = arith.muli %[[PROC_IDX]], %[[RESULT_AXIS_SIZE]] : index
+  // CHECK: %[[RESULT:.*]] = tensor.extract_slice %[[ARG]][%[[SLICE_OFFSET]]] [%[[RESULT_AXIS_SIZE]]] [1] : tensor<?xf16> to tensor<?xf16>
+  %0 = mesh.all_scatter %arg0 on @mesh_1d mesh_axes = [0] scatter_axis = 0 : tensor<?xf16> -> tensor<?xf16>
+  // CHECK: return %[[RESULT]] : tensor<?xf16>
+  return %0 : tensor<?xf16>
+}
+
+// -----
+
+mesh.mesh @mesh_1d(shape = 2)
+
+// CHECK-LABEL: func.func @all_scatter_op_lowering_of_static_1d_tensor_on_static_1d_mesh
+func.func @all_scatter_op_lowering_of_static_1d_tensor_on_static_1d_mesh(
+  // CHECK: %[[ARG:.*]]: tensor<2xf16>
+  %arg0: tensor<2xf16>
+// CHECK-SAME: -> tensor<1xf16> {
+) -> tensor<1xf16> {
+  // CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
+  // CHECK: %[[PROC_IDX:.*]] = mesh.process_multi_index on @mesh_1d axes = [0] : index
+  // CHECK: %[[SLICE:.*]] = tensor.extract_slice %[[ARG]][%[[PROC_IDX]]] [%[[C1]]] [1] : tensor<2xf16> to tensor<?xf16>
+  // CHECK: %[[RESULT:.*]] = tensor.cast %[[SLICE]] : tensor<?xf16> to tensor<1xf16>
+  %0 = mesh.all_scatter %arg0 on @mesh_1d mesh_axes = [0] scatter_axis = 0 : tensor<2xf16> -> tensor<1xf16>
+  // CHECK: return %[[RESULT]] : tensor<1xf16>
+  return %0 : tensor<1xf16>
+}
+
+// -----
+
+// CHECK: #map = affine_map<()[s0, s1, s2] -> (s0 * s1 + s2)>
+
+mesh.mesh @mesh_4d(shape = ?x?x?x?)
+
+// CHECK-LABEL: func.func @all_scatter_op_lowering_of_dynamic_2d_tensor_on_dynamic_4d_mesh
+func.func @all_scatter_op_lowering_of_dynamic_2d_tensor_on_dynamic_4d_mesh(
+  // CHECK: %[[ARG:.*]]: tensor<?x?xf16>
+  %arg0 : tensor<?x?xf16>
+// CHECK-SAME: -> tensor<?x?xf16> {
+) -> tensor<?x?xf16> {
+  // CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
+  // CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
+  // CHECK-DAG: %[[IN_GROUP_PROC_MULTI_IDX:.*]]:2 = mesh.process_multi_index on @mesh_4d axes = [3, 1] : index, index
+  // CHECK-DAG: %[[PROC_GROUP_SHAPE:.*]]:2 = mesh.mesh_shape @mesh_4d axes = [3, 1] : index, index
+  // CHECK: %[[PROC_GROUP_SIZE:.*]] = arith.muli %[[PROC_GROUP_SHAPE]]#0, %[[PROC_GROUP_SHAPE]]#1 : index
+  // CHECK: %[[SCATTER_AXIS_SIZE:.*]] = tensor.dim %[[ARG]], %[[C1]] : tensor<?x?xf16>
+  // CHECK: %[[AXIS_SIZE_CHECK_REMAINDER:.*]] = arith.remui %[[SCATTER_AXIS_SIZE]], %[[PROC_GROUP_SIZE]] : index
+  // CHECK: %[[AXIS_SIZE_CHECK:.*]] = arith.cmpi eq, %[[AXIS_SIZE_CHECK_REMAINDER]], %[[C0]] : index
+  // CHECK: cf.assert %[[AXIS_SIZE_CHECK]]
+  // CHECK: %[[RESULT_SCATTER_AXIS_SIZE:.*]] = arith.divui %[[SCATTER_AXIS_SIZE]], %[[PROC_GROUP_SIZE]] : index
+  // CHECK: %[[PROC_IN_GROUP_LINEAR_IDX:.*]] = affine.apply #map()[%[[IN_GROUP_PROC_MULTI_IDX]]#0, %[[PROC_GROUP_SHAPE]]#1, %[[IN_GROUP_PROC_MULTI_IDX]]#1]
+  // CHECK: %[[AXIS_0_SIZE:.*]] = tensor.dim %[[ARG]], %[[C0]] : tensor<?x?xf16>
+  // CHECK: %[[SCATTER_AXIS_OFFSET:.*]] = arith.muli %[[PROC_IN_GROUP_LINEAR_IDX]], %[[RESULT_SCATTER_AXIS_SIZE]] : index
+  // CHECK: %[[RESULT:.*]] = tensor.extract_slice %[[ARG]][0, %[[SCATTER_AXIS_OFFSET]]] [%[[AXIS_0_SIZE]], %[[RESULT_SCATTER_AXIS_SIZE]]] [1, 1] : tensor<?x?xf16> to tensor<?x?xf16>
+  %0 = mesh.all_scatter %arg0 on @mesh_4d mesh_axes = [3, 1] scatter_axis = 1 : tensor<?x?xf16> -> tensor<?x?xf16>
+  // CHECK: return %[[RESULT]] : tensor<?x?xf16>
+  return %0 : tensor<?x?xf16>
+}
diff --git a/mlir/test/Dialect/Mesh/canonicalization.mlir b/mlir/test/Dialect/Mesh/canonicalization.mlir
index 23c5b253b4c073..4fa9c81f4cd004 100644
--- a/mlir/test/Dialect/Mesh/canonicalization.mlir
+++ b/mlir/test/Dialect/Mesh/canonicalization.mlir
@@ -63,6 +63,19 @@ func.func @all_gather_empty_mesh_axes(
   return %0 : tensor<4xf32>
 }
 
+// CHECK-LABEL: func @all_scatter_empty_mesh_axes
+func.func @all_scatter_empty_mesh_axes(
+// CHECK-SAME: %[[ARG:.*]]: tensor<4xf32>
+    %arg0 : tensor<4xf32>) -> tensor<4xf32> {
+// CHECK-NOT: mesh.all_scatter
+  %0 = mesh.all_scatter %arg0 on @mesh0
+    mesh_axes = []
+    scatter_axis = 0
+    : tensor<4xf32> -> tensor<4xf32>
+// CHECK: return %[[ARG]]
+  return %0 : tensor<4xf32>
+}
+
 // CHECK-LABEL: func @broadcast_empty_mesh_axes
 func.func @broadcast_empty_mesh_axes(
 // CHECK-SAME: %[[ARG:.*]]: tensor<4xf32>
diff --git a/mlir/test/Dialect/Mesh/invalid.mlir b/mlir/test/Dialect/Mesh/invalid.mlir
index 3fa3ebd67b15e7..590d548d9c769e 100644
--- a/mlir/test/Dialect/Mesh/invalid.mlir
+++ b/mlir/test/Dialect/Mesh/invalid.mlir
@@ -316,6 +316,58 @@ func.func @all_gather_invalid_negative_gather_axis(
 
 // -----
 
+mesh.mesh @mesh0(shape = 3)
+
+func.func @all_scatter_duplicate_mesh_axis(
+    %arg0 : tensor<?xf32>) -> tensor<?xf32> {
+  // expected-error at +1 {{Mesh axes contains duplicate elements.}}
+  %0 = mesh.all_scatter %arg0 on @mesh0 mesh_axes = [0, 0]
+    scatter_axis = 0
+    : tensor<?xf32> -> tensor<?xf32>
+  return %0 : tensor<?xf32>
+}
+
+// -----
+
+mesh.mesh @mesh0(shape = 3)
+
+func.func @all_scatter_invalid_dynamic_dimension(
+    %arg0 : tensor<?xf32>) -> tensor<2xf32> {
+  // expected-error at +1 {{Dimension size mismatch for result axis 0. Expected dynamic, but got 2.}}
+  %0 = mesh.all_scatter %arg0 on @mesh0
+    scatter_axis = 0
+    : tensor<?xf32> -> tensor<2xf32>
+  return %0 : tensor<2xf32>
+}
+
+// -----
+
+mesh.mesh @mesh0(shape = 3)
+
+func.func @all_scatter_invalid_static_dimension_size(
+    %arg0 : tensor<3xf32>) -> tensor<2xf32> {
+  // expected-error at +1 {{Dimension size mismatch for result axis 0. Expected 1, but got 2.}}
+  %0 = mesh.all_scatter %arg0 on @mesh0 mesh_axes = [0]
+    scatter_axis = 0
+    : tensor<3xf32> -> tensor<2xf32>
+  return %0 : tensor<2xf32>
+}
+
+// -----
+
+mesh.mesh @mesh0(shape = 3)
+
+func.func @all_scatter_invalid_operand_static_dimension_size(
+    %arg0 : tensor<4xf32>) -> tensor<?xf32> {
+  // expected-error at +1 {{Operand dimension size 4 is not divisible by collective device group size 3 for scatter axis 0.}}
+  %0 = mesh.all_scatter %arg0 on @mesh0 mesh_axes = [0]
+    scatter_axis = 0
+    : tensor<4xf32> -> tensor<?xf32>
+  return %0 : tensor<?xf32>
+}
+
+// -----
+
 func.func @all_to_all_invalid_mesh_symbol(
     %arg0 : tensor<3x6xi8>) -> tensor<3x6xi8> {
   // expected-error at +1 {{Undefined required mesh symbol "this_mesh_symbol_does_not_exist".}}
diff --git a/mlir/test/Dialect/Mesh/ops.mlir b/mlir/test/Dialect/Mesh/ops.mlir
index 40a8469b264643..69e8545f3cb86f 100644
--- a/mlir/test/Dialect/Mesh/ops.mlir
+++ b/mlir/test/Dialect/Mesh/ops.mlir
@@ -208,6 +208,30 @@ func.func @all_gather_dynamic_dims_in_mesh(
   return %0 : tensor<5x?xf32>
 }
 
+// CHECK-LABEL: func @all_scatter_static_dimensions
+func.func @all_scatter_static_dimensions(
+    // CHECK-SAME: %[[ARG:.*]]: tensor<3x4xf32>
+    %arg0 : tensor<3x4xf32>) -> tensor<3x1xf32> {
+  // CHECK-NEXT: mesh.all_scatter %[[ARG]]
+  // CHECK-SAME: on @mesh0 mesh_axes = [2] scatter_axis = 1
+  // CHECK-SAME: : tensor<3x4xf32> -> tensor<3x1xf32>
+  %0 = mesh.all_scatter %arg0 on @mesh0 mesh_axes = [2] scatter_axis = 1
+    : tensor<3x4xf32> -> tensor<3x1xf32>
+  return %0 : tensor<3x1xf32>
+}
+
+// CHECK-LABEL: func @all_scatter_dynamic_dimensions
+func.func @all_scatter_dynamic_dimensions(
+    // CHECK-SAME: %[[ARG:.*]]: tensor<?xf32>
+    %arg0 : tensor<?xf32>) -> tensor<?xf32> {
+  // CHECK-NEXT: mesh.all_scatter %[[ARG]]
+  // CHECK-SAME: on @mesh3 mesh_axes = [0, 1] scatter_axis = 0
+  // CHECK-SAME: : tensor<?xf32> -> tensor<?xf32>
+  %0 = mesh.all_scatter %arg0 on @mesh3 mesh_axes = [0, 1] scatter_axis = 0
+    : tensor<?xf32> -> tensor<?xf32>
+  return %0 : tensor<?xf32>
+}
+
 // CHECK-LABEL: func @all_to_all
 func.func @all_to_all(
     // CHECK-SAME: %[[ARG:.*]]: tensor<3x6xi8>
diff --git a/mlir/test/Dialect/Mesh/resharding-spmdization.mlir b/mlir/test/Dialect/Mesh/resharding-spmdization.mlir
index f1d686135c28eb..fdf9905a62e99c 100644
--- a/mlir/test/Dialect/Mesh/resharding-spmdization.mlir
+++ b/mlir/test/Dialect/Mesh/resharding-spmdization.mlir
@@ -19,17 +19,9 @@ func.func @split_replicated_tensor_axis(
   // CHECK-SAME: %[[ARG:.*]]: tensor<3x14xf32>
   %arg0: tensor<3x14xf32>
 ) -> tensor<3x14xf32> {
-  // CHECK-DAG: %[[ZERO:.*]] = arith.constant 0 : index
-  // CHECK-DAG: %[[TENSOR_SPLIT_AXIS_SIZE:.*]] = arith.constant 14 : index
-  // CHECK: %[[PROCESS_INDEX:.*]] = mesh.process_multi_index on @mesh_1d axes = [0] : index
-  // CHECK: %[[MESH_AXIS_SIZE:.*]] = mesh.mesh_shape @mesh_1d axes = [0] : index
-  // CHECK: %[[TENSOR_SPLIT_AXIS_SIZE_MOD_MESH_AXIS_SIZE:.*]] = arith.remui %[[TENSOR_SPLIT_AXIS_SIZE]], %[[MESH_AXIS_SIZE]] : index
-  // CHECK: %[[RESULT_TENSOR_AXIS_SIZE_CHECK:.*]] = arith.cmpi eq, %[[TENSOR_SPLIT_AXIS_SIZE_MOD_MESH_AXIS_SIZE]], %[[ZERO]] : index
-  // CHECK: cf.assert %[[RESULT_TENSOR_AXIS_SIZE_CHECK]]
-  // CHECK: %[[RESULT_TENSOR_AXIS_SIZE:.*]] = arith.divui %[[TENSOR_SPLIT_AXIS_SIZE]], %[[MESH_AXIS_SIZE]] : index
-  // CHECK: %[[RESULT_TENSOR_AXIS_OFFSET:.*]] = arith.muli %[[RESULT_TENSOR_AXIS_SIZE]], %[[PROCESS_INDEX]] : index
-  // CHECK: %[[RESULT_TENSOR_SLICE:.*]] = tensor.extract_slice %[[ARG]][0, %[[RESULT_TENSOR_AXIS_OFFSET]]] [3, 7] [1, 1] : tensor<3x14xf32> to tensor<3x7xf32>
-  // CHECK: %[[RESULT:.*]] = builtin.unrealized_conversion_cast %[[RESULT_TENSOR_SLICE]] : tensor<3x7xf32> to tensor<3x14xf32>
+  // CHECK: %[[ALL_SCATTER:.*]] = mesh.all_scatter %[[ARG]] on @mesh_1d mesh_axes = [0] scatter_axis = 1
+  // CHECK-SAME: tensor<3x14xf32> -> tensor<3x7xf32>
+  // CHECK: %[[RESULT:.*]] = builtin.unrealized_conversion_cast %[[ALL_SCATTER]] : tensor<3x7xf32> to tensor<3x14xf32>
   %0 = mesh.shard %arg0 to <@mesh_1d, [[]]> : tensor<3x14xf32>
   %1 = mesh.shard %0 to <@mesh_1d, [[], [0]]> annotate_for_users : tensor<3x14xf32>
   // CHECK: return %[[RESULT]] : tensor<3x14xf32>
@@ -41,22 +33,11 @@ func.func @split_replicated_tensor_axis_dynamic(
   // CHECK-SAME: %[[ARG:.*]]: tensor<?x3x?xf32>
   %arg0: tensor<?x3x?xf32>
 ) -> tensor<?x3x?xf32> {
-  // CHECK-DAG: %[[ZERO:.*]] = arith.constant 0 : index
-  // CHECK-DAG: %[[TWO:.*]] = arith.constant 2 : index
-  // CHECK: %[[PROCESS_INDEX:.*]] = mesh.process_multi_index on @mesh_1d_dynamic axes = [0] : index
-  // CHECK: %[[MESH_AXIS_SIZE:.*]] = mesh.mesh_shape @mesh_1d_dynamic axes = [0] : index
-  // CHECK: %[[TENSOR_SPLIT_AXIS_SIZE:.*]] = tensor.dim %[[ARG]], %[[ZERO]] : tensor<?x3x?xf32>
-  // CHECK: %[[TENSOR_SPLIT_AXIS_SIZE_MOD_MESH_AXIS_SIZE:.*]] = arith.remui %[[TENSOR_SPLIT_AXIS_SIZE]], %[[MESH_AXIS_SIZE]] : index
-  // CHECK: %[[RESULT_TENSOR_AXIS_SIZE_CHECK:.*]] = arith.cmpi eq, %[[TENSOR_SPLIT_AXIS_SIZE_MOD_MESH_AXIS_SIZE]], %[[ZERO]] : index
-  // CHECK: cf.assert %[[RESULT_TENSOR_AXIS_SIZE_CHECK]]
-  // CHECK: %[[RESULT_TENSOR_SPLIT_AXIS_SIZE:.*]] = arith.divui %[[TENSOR_SPLIT_AXIS_SIZE]], %[[MESH_AXIS_SIZE]] : index
-  // CHECK: %[[RESULT_TENSOR_SPLIT_AXIS_OFFSET:.*]] = arith.muli %[[RESULT_TENSOR_SPLIT_AXIS_SIZE]], %[[PROCESS_INDEX]] : index
-  // CHECK: %[[TENSOR_AXIS_2_SIZE:.*]] = tensor.dim %[[ARG]], %[[TWO]] : tensor<?x3x?xf32>
-  // CHECK: %[[RESULT_TENSOR_SLICE:.*]] = tensor.extract_slice %[[ARG]][%[[RESULT_TENSOR_SPLIT_AXIS_OFFSET]], 0, 0]
-  // CHECK-SAME: [%[[RESULT_TENSOR_SPLIT_AXIS_SIZE]], 3, %[[TENSOR_AXIS_2_SIZE]]] [1, 1, 1] : tensor<?x3x?xf32> to tensor<?x3x?xf32>
+  // CHECK: %[[RESULT:.*]] = mesh.all_scatter %[[ARG]] on @mesh_1d_dynamic mesh_axes = [0] scatter_axis = 0
+  // CHECK-SAME: tensor<?x3x?xf32> -> tensor<?x3x?xf32>
   %0 = mesh.shard %arg0 to <@mesh_1d_dynamic, [[], [], []]> : tensor<?x3x?xf32>
   %1 = mesh.shard %0 to <@mesh_1d_dynamic, [[0]]> annotate_for_users : tensor<?x3x?xf32>
-  // CHECK: return %[[RESULT_TENSOR_SLICE]] : tensor<?x3x?xf32>
+  // CHECK: return %[[RESULT]] : tensor<?x3x?xf32>
   return %1 : tensor<?x3x?xf32>
 }
 
diff --git a/mlir/test/Dialect/Mesh/spmdization.mlir b/mlir/test/Dialect/Mesh/spmdization.mlir
index 9993c1518e9eae..3aeea169107d66 100644
--- a/mlir/test/Dialect/Mesh/spmdization.mlir
+++ b/mlir/test/Dialect/Mesh/spmdization.mlir
@@ -63,9 +63,8 @@ func.func @unary_elementwise_with_resharding(
   %arg0: tensor<2xi8>
 // CHECK-SAME: -> tensor<2xi8> {
 ) -> tensor<2xi8> {
-  // We don't care about the whole resharding IR, just that it happens.
-  // CHECK: %[[SLICE:.*]] = tensor.extract_slice %[[ARG]][%{{.*}}] [1] [1]
-  // CHECK-SAME: tensor<2xi8> to tensor<1xi8>
+  // CHECK: %[[SLICE:.*]] = mesh.all_scatter %[[ARG]] on @mesh_1d mesh_axes = [0] scatter_axis = 0
+  // CHECK-SAME: tensor<2xi8> -> tensor<1xi8>
   %0 = mesh.shard %arg0 to <@mesh_1d, [[]]> : tensor<2xi8>
   %1 = mesh.shard %0 to <@mesh_1d, [[0]]> annotate_for_users: tensor<2xi8>
   // CHECK: %[[ABS:.*]] = tosa.abs %[[SLICE]] : (tensor<1xi8>) -> tensor<1xi8>
@@ -109,9 +108,8 @@ func.func @multiple_chained_ops(
   %arg0: tensor<2xi8>
 // CHECK-SAME: -> tensor<1xi8> {
 ) -> tensor<2xi8> {
-  // We don't care about the whole resharding IR, just that it happens.
-  // CHECK: %[[RESHARD1:.*]] = tensor.extract_slice %[[ARG]][%{{.*}}] [1] [1]
-  // CHECK-SAME: tensor<2xi8> to tensor<1xi8>
+  // CHECK: %[[RESHARD1:.*]] = mesh.all_scatter %[[ARG]] on @mesh_1d mesh_axes = [0] scatter_axis = 0
+  // CHECK-SAME: tensor<2xi8> -> tensor<1xi8>
   %0 = mesh.shard %arg0 to <@mesh_1d, [[]]> : tensor<2xi8>
   %1 = mesh.shard %0 to <@mesh_1d, [[0]]> annotate_for_users: tensor<2xi8>
   // CHECK: %[[ABS1:.*]] = tosa.abs %[[RESHARD1]] : (tensor<1xi8>) -> tensor<1xi8>
@@ -122,8 +120,8 @@ func.func @multiple_chained_ops(
   %4 = mesh.shard %3 to <@mesh_1d, [[]]> annotate_for_users: tensor<2xi8>
   // CHECK: %[[ABS2:.*]] = tosa.abs %[[RESHARD2]] : (tensor<2xi8>) -> tensor<2xi8>
   %5 = tosa.abs %4 : (tensor<2xi8>) -> tensor<2xi8>
-  // CHECK: %[[RESHARD3:.*]] = tensor.extract_slice %[[ABS2]][%{{.*}}] [1] [1]
-  // CHECK-SAME: tensor<2xi8> to tensor<1xi8>
+  // CHECK: %[[RESHARD3:.*]] = mesh.all_scatter %[[ABS2]] on @mesh_1d mesh_axes = [0] scatter_axis = 0 : 
+  // CHECK-SAME: tensor<2xi8> -> tensor<1xi8>
   %6 = mesh.shard %5 to <@mesh_1d, [[]]> : tensor<2xi8>
   %7 = mesh.shard %6 to <@mesh_1d, [[0]]> annotate_for_users: tensor<2xi8>
   // CHECK: return %[[RESHARD3]] : tensor<1xi8>
diff --git a/mlir/test/lib/Dialect/Mesh/CMakeLists.txt b/mlir/test/lib/Dialect/Mesh/CMakeLists.txt
index 00931e6c94fc5f..07e9bb6f9f238a 100644
--- a/mlir/test/lib/Dialect/Mesh/CMakeLists.txt
+++ b/mlir/test/lib/Dialect/Mesh/CMakeLists.txt
@@ -1,6 +1,6 @@
 # Exclude tests from libMLIR.so
 add_mlir_library(MLIRMeshTest
-  TestProcessMultiIndexOpLowering.cpp
+  TestOpLowering.cpp
   TestReshardingSpmdization.cpp
   TestSimplifications.cpp
 
diff --git a/mlir/test/lib/Dialect/Mesh/TestOpLowering.cpp b/mlir/test/lib/Dialect/Mesh/TestOpLowering.cpp
new file mode 100644
index 00000000000000..2c9131e666ed69
--- /dev/null
+++ b/mlir/test/lib/Dialect/Mesh/TestOpLowering.cpp
@@ -0,0 +1,86 @@
+//===- TestOpLowering.cpp -------------------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Mesh/Transforms/Transforms.h"
+#include "mlir/Dialect/Utils/IndexingUtils.h"
+#include "mlir/IR/SymbolTable.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Support/LogicalResult.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+using namespace mlir;
+
+namespace {
+
+/// Test pass that greedily applies the `mesh.all_scatter` lowering patterns.
+/// Registered as `--test-mesh-all-scatter-op-lowering`.
+struct TestAllScatterOpLoweringPass
+    : public PassWrapper<TestAllScatterOpLoweringPass, OperationPass<>> {
+  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestAllScatterOpLoweringPass)
+
+  void runOnOperation() override {
+    RewritePatternSet patterns(&getContext());
+    SymbolTableCollection symbolTableCollection;
+    mesh::allScatterOpLoweringPopulatePatterns(patterns, symbolTableCollection);
+    LogicalResult status =
+        applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
+    // Keeps NDEBUG builds, where the assert compiles away, warning-free.
+    (void)status;
+    assert(succeeded(status) && "applyPatternsAndFoldGreedily failed.");
+  }
+  void getDependentDialects(DialectRegistry &registry) const override {
+    mesh::allScatterOpLoweringRegisterDialects(registry);
+  }
+  StringRef getArgument() const final {
+    return "test-mesh-all-scatter-op-lowering";
+  }
+  StringRef getDescription() const final {
+    return "Test lowering of all-scatter.";
+  }
+};
+
+/// Test pass that greedily applies the `mesh.process_multi_index` lowering
+/// patterns. Registered as `--test-mesh-process-multi-index-op-lowering`.
+struct TestMultiIndexOpLoweringPass
+    : public PassWrapper<TestMultiIndexOpLoweringPass, OperationPass<>> {
+  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestMultiIndexOpLoweringPass)
+
+  void runOnOperation() override {
+    RewritePatternSet patterns(&getContext());
+    SymbolTableCollection symbolTableCollection;
+    mesh::processMultiIndexOpLoweringPopulatePatterns(patterns,
+                                                      symbolTableCollection);
+    LogicalResult status =
+        applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
+    // Keeps NDEBUG builds, where the assert compiles away, warning-free.
+    (void)status;
+    assert(succeeded(status) && "applyPatternsAndFoldGreedily failed.");
+  }
+  void getDependentDialects(DialectRegistry &registry) const override {
+    mesh::processMultiIndexOpLoweringRegisterDialects(registry);
+  }
+  StringRef getArgument() const final {
+    return "test-mesh-process-multi-index-op-lowering";
+  }
+  StringRef getDescription() const final {
+    return "Test lowering of mesh.process_multi_index op.";
+  }
+};
+
+} // namespace
+
+namespace mlir {
+namespace test {
+/// Registers every mesh op-lowering test pass defined in this file.
+void registerTestOpLoweringPasses() {
+  PassRegistration<TestAllScatterOpLoweringPass>();
+  PassRegistration<TestMultiIndexOpLoweringPass>();
+}
+} // namespace test
+} // namespace mlir
diff --git a/mlir/test/lib/Dialect/Mesh/TestProcessMultiIndexOpLowering.cpp b/mlir/test/lib/Dialect/Mesh/TestProcessMultiIndexOpLowering.cpp
deleted file mode 100644
index 0bcc403a2734ed..00000000000000
--- a/mlir/test/lib/Dialect/Mesh/TestProcessMultiIndexOpLowering.cpp
+++ /dev/null
@@ -1,54 +0,0 @@
-//===- TestProcessMultiIndexOpLowering.cpp --------------------------------===//
-//
-// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-//
-//===----------------------------------------------------------------------===//
-
-#include "mlir/Dialect/Arith/IR/Arith.h"
-#include "mlir/Dialect/Mesh/Transforms/Transforms.h"
-#include "mlir/Dialect/Utils/IndexingUtils.h"
-#include "mlir/IR/SymbolTable.h"
-#include "mlir/Pass/Pass.h"
-#include "mlir/Support/LogicalResult.h"
-#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
-
-using namespace mlir;
-
-namespace {
-struct TestMultiIndexOpLoweringPass
-    : public PassWrapper<TestMultiIndexOpLoweringPass, OperationPass<>> {
-  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestMultiIndexOpLoweringPass)
-
-  void runOnOperation() override;
-  void getDependentDialects(DialectRegistry &registry) const override {
-    mesh::processMultiIndexOpLoweringRegisterDialects(registry);
-  }
-  StringRef getArgument() const final {
-    return "test-mesh-process-multi-index-op-lowering";
-  }
-  StringRef getDescription() const final {
-    return "Test lowering of mesh.process_multi_index op.";
-  }
-};
-} // namespace
-
-void TestMultiIndexOpLoweringPass::runOnOperation() {
-  RewritePatternSet patterns(&getContext());
-  SymbolTableCollection symbolTableCollection;
-  mesh::processMultiIndexOpLoweringPopulatePatterns(patterns,
-                                                    symbolTableCollection);
-  LogicalResult status =
-      applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
-  (void)status;
-  assert(succeeded(status) && "applyPatternsAndFoldGreedily failed.");
-}
-
-namespace mlir {
-namespace test {
-void registerTestMultiIndexOpLoweringPass() {
-  PassRegistration<TestMultiIndexOpLoweringPass>();
-}
-} // namespace test
-} // namespace mlir
diff --git a/mlir/tools/mlir-opt/mlir-opt.cpp b/mlir/tools/mlir-opt/mlir-opt.cpp
index 428bdd9691e095..a94f42927f3f3b 100644
--- a/mlir/tools/mlir-opt/mlir-opt.cpp
+++ b/mlir/tools/mlir-opt/mlir-opt.cpp
@@ -120,7 +120,7 @@ void registerTestMemRefDependenceCheck();
 void registerTestMemRefStrideCalculation();
 void registerTestMeshSimplificationsPass();
 void registerTestMeshReshardingSpmdizationPass();
-void registerTestMultiIndexOpLoweringPass();
+void registerTestOpLoweringPasses();
 void registerTestNextAccessPass();
 void registerTestOneToNTypeConversionPass();
 void registerTestOpaqueLoc();
@@ -241,7 +241,7 @@ void registerTestPasses() {
   mlir::test::registerTestMathPolynomialApproximationPass();
   mlir::test::registerTestMemRefDependenceCheck();
   mlir::test::registerTestMemRefStrideCalculation();
-  mlir::test::registerTestMultiIndexOpLoweringPass();
+  mlir::test::registerTestOpLoweringPasses();
   mlir::test::registerTestMeshSimplificationsPass();
   mlir::test::registerTestMeshReshardingSpmdizationPass();
   mlir::test::registerTestNextAccessPass();
diff --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
index 9f2c1f1120066b..8943e91c1f0ceb 100644
--- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
+++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
@@ -3314,7 +3314,9 @@ cc_library(
     includes = ["include"],
     deps = [
         ":AffineDialect",
+        ":AffineUtils",
         ":ArithDialect",
+        ":ArithUtils",
         ":ControlFlowDialect",
         ":DialectUtils",
         ":FuncDialect",

>From 8cbbb177dc5ba8803e4b068672dc679ef230fe80 Mon Sep 17 00:00:00 2001
From: Boian Petkantchin <boian.petkantchin at amd.com>
Date: Thu, 8 Feb 2024 19:10:01 -0800
Subject: [PATCH 2/2] Remove unused resharding dependencies

---
 mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp | 6 +-----
 1 file changed, 1 insertion(+), 5 deletions(-)

diff --git a/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp b/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp
index 21e03d9572c590..cd0481d7b87aec 100644
--- a/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp
+++ b/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp
@@ -8,9 +8,6 @@
 
 #include "mlir/Dialect/Mesh/Transforms/Spmdization.h"
 
-#include "mlir/Dialect/Arith/IR/Arith.h"
-#include "mlir/Dialect/ControlFlow/IR/ControlFlow.h"
-#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
 #include "mlir/Dialect/Func/IR/FuncOps.h"
 #include "mlir/Dialect/Mesh/IR/MeshDialect.h"
 #include "mlir/Dialect/Mesh/IR/MeshOps.h"
@@ -519,8 +516,7 @@ TypedValue<ShapedType> reshard(OpBuilder &builder, ShardOp source,
 }
 
 void reshardingRegisterDependentDialects(DialectRegistry &registry) {
-  registry.insert<arith::ArithDialect, mesh::MeshDialect, tensor::TensorDialect,
-                  cf::ControlFlowDialect>();
+  registry.insert<mesh::MeshDialect, tensor::TensorDialect>();
 }
 
 #define GEN_PASS_DEF_SPMDIZATION



More information about the llvm-commits mailing list