[Mlir-commits] [mlir] [mlir][shard, mpi] Adding Shard/MPI reduce_scatter (PR #184189)

Frank Schlimbach llvmlistbot at llvm.org
Tue Mar 3 02:49:48 PST 2026


https://github.com/fschlimb updated https://github.com/llvm/llvm-project/pull/184189

>From 0698503f87e0be19b78d468ef9acbe5a0652459c Mon Sep 17 00:00:00 2001
From: "Schlimbach, Frank" <frank.schlimbach at intel.com>
Date: Thu, 26 Feb 2026 04:42:34 -0800
Subject: [PATCH 1/7] adding mpi::reduce_scatter_block, lowering to llvm and
 tests

---
 mlir/include/mlir/Dialect/MPI/IR/MPIOps.td    | 34 +++++++++
 mlir/lib/Conversion/MPIToLLVM/MPIToLLVM.cpp   | 76 ++++++++++++++++++-
 mlir/test/Conversion/MPIToLLVM/mpitollvm.mlir | 10 +++
 mlir/test/Dialect/MPI/mpiops.mlir             |  6 ++
 4 files changed, 125 insertions(+), 1 deletion(-)

diff --git a/mlir/include/mlir/Dialect/MPI/IR/MPIOps.td b/mlir/include/mlir/Dialect/MPI/IR/MPIOps.td
index 7e68b152fdf75..ee79026bf8214 100644
--- a/mlir/include/mlir/Dialect/MPI/IR/MPIOps.td
+++ b/mlir/include/mlir/Dialect/MPI/IR/MPIOps.td
@@ -333,6 +333,40 @@ def MPI_AllReduceOp : MPI_Op<"allreduce", []> {
                        "(`->` type($retval)^)?";
 }
 
+//===----------------------------------------------------------------------===//
+// ReduceScatterBlockOp
+//===----------------------------------------------------------------------===//
+
+def MPI_ReduceScatterBlockOp : MPI_Op<"reduce_scatter_block", []> {
+  let summary = "Equivalent to `MPI_Reduce_scatter_block(sendbuf, recvbuf, "
+                "recvcount, dtype, op, comm)`";
+  let description = [{
+    MPI_Reduce_scatter_block first performs an element-wise reduction on the
+    sendbuf across all processes in the communicator, then scatters the result
+    by distributing equal-sized blocks to each process into recvbuf.
+
+    The `op` attribute specifies the reduction operation to be performed.
+    Currently only the `MPI_Op`s predefined in the standard (e.g. `MPI_SUM`)
+    are supported.
+
+    This operation can optionally return an `!mpi.retval` value that can be used
+    to check for errors.
+  }];
+
+  let arguments = (
+    ins AnyMemRef : $sendbuf,
+    AnyMemRef : $recvbuf,
+    MPI_ReductionOpEnum : $op,
+    MPI_Comm : $comm
+  );
+
+  let results = (outs Optional<MPI_Retval>:$retval);
+
+  let assemblyFormat = "`(` $sendbuf `,` $recvbuf `,` $op `,` $comm `)` "
+                       "attr-dict `:` type($sendbuf) `,` type($recvbuf) "
+                       "(`->` type($retval)^)?";
+}
+
 //===----------------------------------------------------------------------===//
 // BarrierOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Conversion/MPIToLLVM/MPIToLLVM.cpp b/mlir/lib/Conversion/MPIToLLVM/MPIToLLVM.cpp
index 0dbc0a126a5c6..8571b39d41780 100644
--- a/mlir/lib/Conversion/MPIToLLVM/MPIToLLVM.cpp
+++ b/mlir/lib/Conversion/MPIToLLVM/MPIToLLVM.cpp
@@ -16,6 +16,7 @@
 #include "mlir/Conversion/ConvertToLLVM/ToLLVMInterface.h"
 #include "mlir/Conversion/LLVMCommon/Pattern.h"
 #include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
 #include "mlir/Dialect/DLTI/DLTI.h"
 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
 #include "mlir/Dialect/LLVMIR/LLVMTypes.h"
@@ -911,6 +912,79 @@ struct AllReduceOpLowering : public ConvertOpToLLVMPattern<mpi::AllReduceOp> {
   }
 };
 
+//===----------------------------------------------------------------------===//
+// ReduceScatterBlockOpLowering
+//===----------------------------------------------------------------------===//
+
+/// Lowers mpi.reduce_scatter_block to a call to MPI_Reduce_scatter_block.
+struct ReduceScatterBlockOpLowering
+    : public ConvertOpToLLVMPattern<mpi::ReduceScatterBlockOp> {
+  using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;
+
+  LogicalResult
+  matchAndRewrite(mpi::ReduceScatterBlockOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    Location loc = op.getLoc();
+    MLIRContext *context = rewriter.getContext();
+    Type i32 = rewriter.getI32Type();
+    Type i64 = rewriter.getI64Type();
+    Type elemType = op.getSendbuf().getType().getElementType();
+    int64_t sRank = op.getSendbuf().getType().getRank();
+    int64_t rRank = op.getRecvbuf().getType().getRank();
+
+    Type ptrType = LLVM::LLVMPointerType::get(context);
+    auto moduleOp = op->getParentOfType<ModuleOp>();
+    auto mpiTraits = MPIImplTraits::get(moduleOp);
+    auto [sendPtr, sendSize] =
+        getRawPtrAndSize(loc, rewriter, adaptor.getSendbuf(), sRank, elemType);
+    auto [recvPtr, recvSize] =
+        getRawPtrAndSize(loc, rewriter, adaptor.getRecvbuf(), rRank, elemType);
+
+    // If input and output are the same, request in-place operation.
+    if (adaptor.getSendbuf() == adaptor.getRecvbuf()) {
+      sendPtr = LLVM::ConstantOp::create(
+          rewriter, loc, i64,
+          reinterpret_cast<int64_t>(mpiTraits->getInPlace()));
+      sendPtr = LLVM::IntToPtrOp::create(rewriter, loc, ptrType, sendPtr);
+    }
+
+    Value dataType = mpiTraits->getDataType(loc, rewriter, elemType);
+    Value mpiOp = mpiTraits->getMPIOp(loc, rewriter, op.getOp());
+    Value comm = mpiTraits->castComm(loc, rewriter, adaptor.getComm());
+
+    // recvcount is the size of one block: the element count of the vector to
+    // reduce (sendbuf; equals recvbuf when in-place) divided by nRanks.
+    Value nRanks = createOrFoldCommSize(rewriter, loc, op.getComm(), adaptor.getComm());
+    Value recvCnt = LLVM::UDivOp::create(rewriter, loc, i32, sendSize, nRanks);
+
+    // Check that the input size is a multiple of the number of ranks
+    Value checkSize = LLVM::MulOp::create(rewriter, loc, i32, recvCnt, nRanks);
+    Value sizeIsValid = LLVM::ICmpOp::create(rewriter, loc, LLVM::ICmpPredicate::eq, checkSize, sendSize);
+    cf::AssertOp::create(rewriter, loc, sizeIsValid, "Input buffer's size must be a multiple of the number of ranks");
+
+    // 'int MPI_Reduce_scatter_block(const void *sendbuf, void *recvbuf,
+    //     int recvcount, MPI_Datatype datatype, MPI_Op op, MPI_Comm comm)'
+    auto funcType = LLVM::LLVMFunctionType::get(
+        i32, {ptrType, ptrType, i32, dataType.getType(), mpiOp.getType(),
+              comm.getType()});
+    // get or create function declaration:
+    LLVM::LLVMFuncOp funcDecl = getOrDefineFunction(
+        moduleOp, loc, rewriter, "MPI_Reduce_scatter_block", funcType);
+
+    // replace op with function call
+    auto funcCall = LLVM::CallOp::create(
+        rewriter, loc, funcDecl,
+        ValueRange{sendPtr, recvPtr, recvCnt, dataType, mpiOp, comm});
+
+    if (op.getRetval())
+      rewriter.replaceOp(op, funcCall.getResult());
+    else
+      rewriter.eraseOp(op);
+
+    return success();
+  }
+};
+
 //===----------------------------------------------------------------------===//
 // ConvertToLLVMPatternInterface implementation
 //===----------------------------------------------------------------------===//
@@ -943,7 +1017,7 @@ void mpi::populateMPIToLLVMConversionPatterns(LLVMTypeConverter &converter,
   patterns.add<CommRankOpLowering, CommSizeOpLowering, CommSplitOpLowering,
                CommWorldOpLowering, FinalizeOpLowering, InitOpLowering,
                SendOpLowering, RecvOpLowering, AllGatherOpLowering,
-               AllReduceOpLowering>(converter);
+               AllReduceOpLowering, ReduceScatterBlockOpLowering>(converter);
 }
 
 void mpi::registerConvertMPIToLLVMInterface(DialectRegistry &registry) {
diff --git a/mlir/test/Conversion/MPIToLLVM/mpitollvm.mlir b/mlir/test/Conversion/MPIToLLVM/mpitollvm.mlir
index 9ec81c53b41f8..7e35f331b51a2 100644
--- a/mlir/test/Conversion/MPIToLLVM/mpitollvm.mlir
+++ b/mlir/test/Conversion/MPIToLLVM/mpitollvm.mlir
@@ -3,6 +3,7 @@
 // COM: Test MPICH ABI
 // CHECK: module attributes {dlti.map = #dlti.map<"MPI:Implementation" = "MPICH">} {
 // CHECK: llvm.func @MPI_Finalize() -> i32
+// CHECK: llvm.func @MPI_Reduce_scatter_block(!llvm.ptr, !llvm.ptr, i32, i32, i32, i32) -> i32
 // CHECK: llvm.func @MPI_Allreduce(!llvm.ptr, !llvm.ptr, i32, i32, i32, i32) -> i32
 // CHECK: llvm.func @MPI_Comm_size(i32, !llvm.ptr) -> i32
 // CHECK: llvm.func @MPI_Allgather(!llvm.ptr, i32, i32, !llvm.ptr, i32, i32, i32) -> i32
@@ -119,6 +120,10 @@ module attributes {dlti.map = #dlti.map<"MPI:Implementation" = "MPICH">} {
     // CHECK: [[v72:%.*]] = llvm.call @MPI_Allreduce([[ipp]], [[v66]], [[v63]], [[v69]], [[v70]], [[v71]]) : (!llvm.ptr, !llvm.ptr, i32, i32, i32, i32) -> i32
     mpi.allreduce(%arg0, %arg0, MPI_SUM, %comm) : memref<100xf32>, memref<100xf32>
 
+    // CHECK: [[v73:%.*]] = llvm.inttoptr {{.*}} : i64 to !llvm.ptr
+    // CHECK: llvm.call @MPI_Reduce_scatter_block([[v73]], {{.*}}) : (!llvm.ptr, !llvm.ptr, i32, i32, i32, i32) -> i32
+    mpi.reduce_scatter_block(%arg0, %arg0, MPI_SUM, %comm) : memref<100xf32>, memref<100xf32>
+
     // CHECK: llvm.call @MPI_Finalize() : () -> i32
     %3 = mpi.finalize : !mpi.retval
 
@@ -132,6 +137,7 @@ module attributes {dlti.map = #dlti.map<"MPI:Implementation" = "MPICH">} {
 // CHECK: module attributes {dlti.map = #dlti.map<"MPI:Implementation" = "OpenMPI">} {
 // CHECK: llvm.func @MPI_Finalize() -> i32
 // CHECK: llvm.func @MPI_Comm_split(!llvm.ptr, i32, i32, !llvm.ptr) -> i32
+// CHECK: llvm.func @MPI_Reduce_scatter_block(!llvm.ptr, !llvm.ptr, i32, !llvm.ptr, !llvm.ptr, !llvm.ptr) -> i32
 // CHECK: llvm.func @MPI_Allreduce(!llvm.ptr, !llvm.ptr, i32, !llvm.ptr, !llvm.ptr, !llvm.ptr) -> i32
 // CHECK: llvm.mlir.global external @ompi_mpi_sum() {addr_space = 0 : i32} : !llvm.struct<"ompi_predefined_op_t", opaque>
 // CHECK: llvm.func @MPI_Comm_size(!llvm.ptr, !llvm.ptr) -> i32
@@ -240,6 +246,10 @@ module attributes { dlti.map = #dlti.map<"MPI:Implementation" = "OpenMPI"> } {
     // CHECK: [[v62:%.*]] = llvm.call @MPI_Allreduce([[ipp]], [[v56]], [[v53]], [[v59]], [[v60]], [[v61]]) : (!llvm.ptr, !llvm.ptr, i32, !llvm.ptr, !llvm.ptr, !llvm.ptr) -> i32
     mpi.allreduce(%arg0, %arg0, MPI_SUM, %comm) : memref<100xf32>, memref<100xf32>
 
+    // CHECK: [[v63:%.*]] = llvm.inttoptr {{.*}} : i64 to !llvm.ptr
+    // CHECK: llvm.call @MPI_Reduce_scatter_block([[v63]], {{.*}}) : (!llvm.ptr, !llvm.ptr, i32, !llvm.ptr, !llvm.ptr, !llvm.ptr) -> i32
+    mpi.reduce_scatter_block(%arg0, %arg0, MPI_SUM, %comm) : memref<100xf32>, memref<100xf32>
+
     // CHECK: [[v71:%.*]] = llvm.mlir.constant(10 : i32) : i32
     %color = arith.constant 10 : i32
     // CHECK: [[v72:%.*]] = llvm.mlir.constant(22 : i32) : i32
diff --git a/mlir/test/Dialect/MPI/mpiops.mlir b/mlir/test/Dialect/MPI/mpiops.mlir
index 87a5647ee91d2..ee979d33c699d 100644
--- a/mlir/test/Dialect/MPI/mpiops.mlir
+++ b/mlir/test/Dialect/MPI/mpiops.mlir
@@ -77,6 +77,12 @@ func.func @mpi_test(%ref : memref<100xf32>) -> () {
     // CHECK-NEXT: mpi.allreduce([[varg0]], [[varg0]], MPI_SUM, [[v1]]) : memref<100xf32>, memref<100xf32>
     mpi.allreduce(%ref, %ref, MPI_SUM, %comm) : memref<100xf32>, memref<100xf32>
 
+    // CHECK-NEXT: [[v10:%.*]] = mpi.reduce_scatter_block([[varg0]], [[varg0]], MPI_SUM, [[v1]]) : memref<100xf32>, memref<100xf32> -> !mpi.retval
+    %err9 = mpi.reduce_scatter_block(%ref, %ref, MPI_SUM, %comm) : memref<100xf32>, memref<100xf32> -> !mpi.retval
+
+    // CHECK-NEXT: mpi.reduce_scatter_block([[varg0]], [[varg0]], MPI_SUM, [[v1]]) : memref<100xf32>, memref<100xf32>
+    mpi.reduce_scatter_block(%ref, %ref, MPI_SUM, %comm) : memref<100xf32>, memref<100xf32>
+
     // CHECK-NEXT: [[v7:%.*]] = mpi.finalize : !mpi.retval
     %rval = mpi.finalize : !mpi.retval
 

>From 9bb1b80ad4b4cb95179803d6df11e439294e7d9c Mon Sep 17 00:00:00 2001
From: "Schlimbach, Frank" <frank.schlimbach at intel.com>
Date: Fri, 27 Feb 2026 06:43:08 -0800
Subject: [PATCH 2/7] simplify pass (was test-simplification) and lowering
 shard.reduce_scatter to MPI

---
 .../include/mlir/Dialect/Shard/IR/ShardOps.td |  20 +--
 .../mlir/Dialect/Shard/Transforms/Partition.h |   2 +-
 .../mlir/Dialect/Shard/Transforms/Passes.td   |  17 ++
 .../{Simplifications.h => Simplify.h}         |  16 +-
 mlir/lib/Conversion/ShardToMPI/ShardToMPI.cpp | 153 +++++++++++++++++-
 mlir/lib/Dialect/Shard/IR/ShardOps.cpp        |   6 +-
 .../Dialect/Shard/Transforms/CMakeLists.txt   |   3 +-
 .../{Simplifications.cpp => Simplify.cpp}     | 108 +++++++++----
 .../ShardToMPI/convert-shard-to-mpi.mlir      |  34 ++++
 .../Shard/all-scatter-op-lowering.mlir        |  12 +-
 mlir/test/Dialect/Shard/canonicalization.mlir |   8 +-
 mlir/test/Dialect/Shard/folding.mlir          |   2 +-
 mlir/test/Dialect/Shard/invalid.mlir          |  20 +--
 mlir/test/Dialect/Shard/ops.mlir              |  20 +--
 .../{simplifications.mlir => simplify.mlir}   |  85 +++++++++-
 mlir/test/lib/Dialect/Shard/CMakeLists.txt    |   1 -
 .../Dialect/Shard/TestReshardingPartition.cpp |   2 +-
 .../lib/Dialect/Shard/TestSimplifications.cpp |  47 ------
 mlir/tools/mlir-opt/mlir-opt.cpp              |   2 -
 19 files changed, 423 insertions(+), 135 deletions(-)
 rename mlir/include/mlir/Dialect/Shard/Transforms/{Simplifications.h => Simplify.h} (89%)
 rename mlir/lib/Dialect/Shard/Transforms/{Simplifications.cpp => Simplify.cpp} (55%)
 rename mlir/test/Dialect/Shard/{simplifications.mlir => simplify.mlir} (66%)
 delete mode 100644 mlir/test/lib/Dialect/Shard/TestSimplifications.cpp

diff --git a/mlir/include/mlir/Dialect/Shard/IR/ShardOps.td b/mlir/include/mlir/Dialect/Shard/IR/ShardOps.td
index 6ef7c72d305ee..60f6d6fc1ffe4 100644
--- a/mlir/include/mlir/Dialect/Shard/IR/ShardOps.td
+++ b/mlir/include/mlir/Dialect/Shard/IR/ShardOps.td
@@ -901,13 +901,13 @@ def Shard_ReduceScatterOp : Shard_CollectiveCommunicationOpBase<"reduce_scatter"
     `grid_axes` using the specified reduction method. The reduction is performed
     element-wise across the tensor pieces from all devices in the group.
     After reduction, the reduction result is scattered (split and distributed)
-    across the device group along `scatter_axis`.
+    across the device group along `scatter_dim`.
     Example:
     ```
     shard.grid @grid0(shape = 2x2)
     ...
     %1 = shard.reduce_scatter %0 on @grid0 grid_axes = [1]
-      reduction = <max> scatter_axis = 0
+      reduction = <max> scatter_dim = 0
       : tensor<2x2xf32> -> tensor<1x2xf64>
     ```
     Input:
@@ -940,17 +940,17 @@ def Shard_ReduceScatterOp : Shard_CollectiveCommunicationOpBase<"reduce_scatter"
     ```
   }];
   let arguments = !con(commonArgs, (ins
-    AnyNon0RankedTensor:$input,
+    AnyTypeOf<[AnyMemRef, AnyRankedTensor]>:$input,
     DefaultValuedAttr<Shard_ReductionKindAttr, "::mlir::shard::ReductionKind::Sum">:$reduction,
-    IndexAttr:$scatter_axis
+    IndexAttr:$scatter_dim
   ));
   let results = (outs
-    AnyRankedTensor:$result
+    AnyTypeOf<[AnyMemRef, AnyRankedTensor]>:$result
   );
   let assemblyFormat = [{
     $input `on` $grid (`grid_axes` `=` $grid_axes^)?
     (`reduction` `=` $reduction^)?
-    `scatter_axis` `=` $scatter_axis
+    `scatter_dim` `=` $scatter_dim
     attr-dict `:` type($input) `->` type($result)
   }];
   let hasCanonicalizer = 1;
@@ -964,7 +964,7 @@ def Shard_ScatterOp : Shard_CollectiveCommunicationOpBase<"scatter", [
   let summary = "Scatter over a device grid.";
   let description = [{
     For each device group defined by `grid_axes`, the input tensor on the `root`
-    device is split along axis `scatter_axis` and distributed across the group.
+    device is split along axis `scatter_dim` and distributed across the group.
     The content of the input on all other (non-root) devices is ignored.
     The `root` device is defined by its in-group multi-index.
 
@@ -972,7 +972,7 @@ def Shard_ScatterOp : Shard_CollectiveCommunicationOpBase<"scatter", [
     ```
     shard.grid @grid0(shape = 2x2)
     %1 = shard.scatter %0 on @grid0 grid_axes = [0]
-      scatter_axis = 0
+      scatter_dim = 0
       root = [1]
       : (tensor<2x2xi8>) -> tensor<1x2xi8>
     ```
@@ -1011,7 +1011,7 @@ def Shard_ScatterOp : Shard_CollectiveCommunicationOpBase<"scatter", [
   }];
   let arguments = !con(commonArgs, (ins
     AnyNon0RankedTensor:$input,
-    IndexAttr:$scatter_axis,
+    IndexAttr:$scatter_dim,
     DenseI64ArrayAttr:$root,
     Variadic<Index>:$root_dynamic
   ));
@@ -1020,7 +1020,7 @@ def Shard_ScatterOp : Shard_CollectiveCommunicationOpBase<"scatter", [
   );
   let assemblyFormat = [{
     $input `on` $grid (`grid_axes` `=` $grid_axes^)?
-    `scatter_axis` `=` $scatter_axis
+    `scatter_dim` `=` $scatter_dim
     `root` `=` custom<DynamicIndexList>($root_dynamic, $root)
     attr-dict `:` functional-type(operands, results)
   }];
diff --git a/mlir/include/mlir/Dialect/Shard/Transforms/Partition.h b/mlir/include/mlir/Dialect/Shard/Transforms/Partition.h
index 37903765903db..ba12002b9f1eb 100644
--- a/mlir/include/mlir/Dialect/Shard/Transforms/Partition.h
+++ b/mlir/include/mlir/Dialect/Shard/Transforms/Partition.h
@@ -1,4 +1,4 @@
-//===- Simplifications.h - Shard Simplifications ----------------*- C++ -*-===//
+//===- Partition.h - Shard Partition ----------------------------*- C++ -*-===//
 //
 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
 // See https://llvm.org/LICENSE.txt for license information.
diff --git a/mlir/include/mlir/Dialect/Shard/Transforms/Passes.td b/mlir/include/mlir/Dialect/Shard/Transforms/Passes.td
index bbc6a1977b13e..575c176217e61 100644
--- a/mlir/include/mlir/Dialect/Shard/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/Shard/Transforms/Passes.td
@@ -44,6 +44,23 @@ def ShardingPropagation : InterfacePass<"sharding-propagation", "mlir::FunctionO
   ];
 }
 
+def ShardSimplify : Pass<"shard-simplify"> { // Pass argument is "shard-simplify"; no anchor op type is given.
+  let summary = "Shard simplify patterns.";
+  let description = [{
+    Applies simplification patterns on the Shard dialect operations.
+    This includes:
+    - All-reduce endomorphism simplification, e.g. transforming
+      `all_reduce_sum(x) + all_reduce_sum(y)` into `all_reduce_sum(x + y)`.
+    - Folding `AllSliceOp(AllReduceOp)` into `ReduceScatterOp` when both ops
+      share the same grid and grid_axes.
+    - Folding static grid shapes into constants.
+  }];
+  let dependentDialects = [ // NOTE(review): presumably the dialects whose ops the patterns may create - confirm.
+    "arith::ArithDialect",
+    "shard::ShardDialect"
+  ];
+}
+
 def Partition : InterfacePass<"shard-partition", "mlir::FunctionOpInterface"> {
   let summary = "Partition a function into SPMD form.";
   let description = [{
diff --git a/mlir/include/mlir/Dialect/Shard/Transforms/Simplifications.h b/mlir/include/mlir/Dialect/Shard/Transforms/Simplify.h
similarity index 89%
rename from mlir/include/mlir/Dialect/Shard/Transforms/Simplifications.h
rename to mlir/include/mlir/Dialect/Shard/Transforms/Simplify.h
index 45ae758ec14c2..f3f4feffd8a71 100644
--- a/mlir/include/mlir/Dialect/Shard/Transforms/Simplifications.h
+++ b/mlir/include/mlir/Dialect/Shard/Transforms/Simplify.h
@@ -1,4 +1,4 @@
-//===- Simplifications.h - Shard Simplifications ----------------*- C++ -*-===//
+//===- Simplify.h - Shard Simplify ------------------------------*- C++ -*-===//
 //
 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
 // See https://llvm.org/LICENSE.txt for license information.
@@ -6,8 +6,8 @@
 //
 //===----------------------------------------------------------------------===//
 
-#ifndef MLIR_DIALECT_SHARD_TRANSFORMS_SIMPLIFICATIONS_H
-#define MLIR_DIALECT_SHARD_TRANSFORMS_SIMPLIFICATIONS_H
+#ifndef MLIR_DIALECT_SHARD_TRANSFORMS_SIMPLIFY_H
+#define MLIR_DIALECT_SHARD_TRANSFORMS_SIMPLIFY_H
 
 #include "mlir/Dialect/Shard/IR/ShardOps.h"
 #include "mlir/IR/PatternMatch.h"
@@ -37,8 +37,8 @@ namespace shard {
 // Will not work with some op `f(x, y, z)` where only `x` and `y` form
 // the algebraic structure.
 template <typename AlgebraicOp>
-void populateAllReduceEndomorphismSimplificationPatterns(
-    RewritePatternSet &patterns, ReductionKind reduction) {
+void populateAllReduceEndomorphismSimplifyPatterns(RewritePatternSet &patterns,
+                                                   ReductionKind reduction) {
   auto getEndomorphismOpOperand = [](Operation *op) {
     auto allReduceOp = llvm::cast<AllReduceOp>(op);
     return &allReduceOp.getInputMutable();
@@ -105,12 +105,12 @@ void populateAllReduceEndomorphismSimplificationPatterns(
 
 // It is invalid to change ops that declare symbols during the application of
 // these patterns, because symbolTableCollection is used to cache them.
-void populateSimplificationPatterns(
-    RewritePatternSet &patterns, SymbolTableCollection &symbolTableCollection);
+void populateSimplifyPatterns(RewritePatternSet &patterns,
+                              SymbolTableCollection &symbolTableCollection);
 void populateFoldingPatterns(RewritePatternSet &patterns,
                              SymbolTableCollection &symbolTableCollection);
 
 } // namespace shard
 } // namespace mlir
 
-#endif // MLIR_DIALECT_SHARD_TRANSFORMS_SIMPLIFICATIONS_H
+#endif // MLIR_DIALECT_SHARD_TRANSFORMS_SIMPLIFY_H
diff --git a/mlir/lib/Conversion/ShardToMPI/ShardToMPI.cpp b/mlir/lib/Conversion/ShardToMPI/ShardToMPI.cpp
index 1db14e60c5a7f..01196dd1b6fcd 100644
--- a/mlir/lib/Conversion/ShardToMPI/ShardToMPI.cpp
+++ b/mlir/lib/Conversion/ShardToMPI/ShardToMPI.cpp
@@ -26,7 +26,7 @@
 #include "mlir/Dialect/SCF/IR/SCF.h"
 #include "mlir/Dialect/Shard/IR/ShardDialect.h"
 #include "mlir/Dialect/Shard/IR/ShardOps.h"
-#include "mlir/Dialect/Shard/Transforms/Simplifications.h"
+#include "mlir/Dialect/Shard/Transforms/Simplify.h"
 #include "mlir/Dialect/Shard/Transforms/Transforms.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "mlir/Dialect/Utils/StaticValueUtils.h"
@@ -618,6 +618,155 @@ struct ConvertAllReduceOp : public CommOpPattern<AllReduceOp> {
   }
 };
 
+struct ConvertReduceScatterOp : public CommOpPattern<ReduceScatterOp> {
+  using CommOpPattern::CommOpPattern;
+
+  // shard.reduce_scatter reduces and then scatters along a specified
+  // scatter-dim. mpi.reduce_scatter_block always scatters along the first
+  // dimension. Hence, if scatter-dim != 0, we need to rearrange the input
+  // data by expanding the scatter-dim into {nRanks, output_scatter_dim} and
+  // transposing nRanks to the first dimension.
+
+  LogicalResult
+  matchAndRewrite(ReduceScatterOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    auto gridAxes = adaptor.getGridAxes();
+    int64_t scatterDim = adaptor.getScatterDimAttr().getInt();
+
+    SymbolTableCollection symbolTableCollection;
+    FailureOr<GridOp> gridOp = checkGrid(op, symbolTableCollection);
+    if (failed(gridOp))
+      return failure();
+
+    ImplicitLocOpBuilder ib(op.getLoc(), rewriter);
+    Value rawInput = adaptor.getInput();
+    auto inShapedType = cast<ShapedType>(rawInput.getType());
+    auto elemType = inShapedType.getElementType();
+    MemRefType outType = getMemrefType(cast<ShapedType>(op.getType()));
+    // Reject dynamic shapes up front: dynamic extents are sentinel values and
+    // must not enter the divisibility / scatter-factor arithmetic below.
+    if (!inShapedType.hasStaticShape())
+      return op.emitError("Input must be statically shaped.");
+    if (!memref::isStaticShapeAndContiguousRowMajor(outType))
+      return op.emitError("Result must be a statically shaped memref in "
+                          "contiguous row-major layout.");
+
+    auto inputShape = inShapedType.getShape();
+    auto outputShape = outType.getShape();
+    int64_t inputDimOnAxis = inputShape[scatterDim];
+    int64_t outputDimOnAxis = outputShape[scatterDim];
+    for (int64_t i = 0, e = outType.getRank(); i < e; ++i)
+      if (i != scatterDim && outputShape[i] != inputShape[i])
+        return op.emitError(
+            "Result and input shapes must match along non-scatter axes.");
+    if (outputDimOnAxis == 0)
+      return op.emitError(
+          "Output size along the scatter axis must be non-zero.");
+    if (inputDimOnAxis % outputDimOnAxis != 0)
+      return op.emitError(
+          "Input size along the scatter axis must be an exact "
+          "multiple of the output size along the scatter axis.");
+
+    int64_t nRanks = inputDimOnAxis / outputDimOnAxis;
+
+    // Verify that nRanks matches the number of devices along the grid axes.
+    int64_t gridGroupSize =
+        collectiveProcessGroupSize(gridAxes, gridOp->getShape());
+    if (nRanks != gridGroupSize)
+      return op.emitError()
+             << "Expected the scatter factor (" << nRanks
+             << ") to match the number of devices along grid_axes ("
+             << gridGroupSize << ").";
+
+    // Get the right communicator.
+    Value comm = getComm(*gridOp, gridAxes, ib);
+
+    Value mpiInput;
+    if (scatterDim == 0) {
+      // scatter_dim == 0 maps directly to MPI; input must be contiguous.
+      Value input = getAsMemref(rawInput, ib);
+      MemRefType inType = cast<MemRefType>(input.getType());
+      if (!memref::isStaticShapeAndContiguousRowMajor(inType))
+        return op.emitError("Input must be a statically shaped memref in "
+                            "contiguous row-major layout.");
+      mpiInput = input;
+    } else {
+      // For scatter_dim != 0 we rearrange the input so the scatter factor
+      // becomes the first dimension.
+      // 1. Get a tensor representation of the input (avoid memref->tensor
+      //    round-trip if the input is already a tensor).
+      Value tensorInput = rawInput;
+      if (!isa<RankedTensorType>(rawInput.getType())) {
+        auto inTensorType = RankedTensorType::get(inputShape, elemType);
+        tensorInput =
+            bufferization::ToTensorOp::create(ib, inTensorType, rawInput, true);
+      }
+
+      // 2. Expand the scatter dim from {d0, ..., d_sd, ..., dN} to
+      //    {d0, ..., nRanks, o_sd, ..., dN}.
+      SmallVector<int64_t> expandedShape;
+      SmallVector<ReassociationIndices> expandReassociation;
+      int64_t expandedIdx = 0;
+      for (int64_t i = 0, e = inShapedType.getRank(); i < e; ++i) {
+        if (i == scatterDim) {
+          expandedShape.push_back(nRanks);
+          expandedShape.push_back(outputDimOnAxis);
+          expandReassociation.push_back({expandedIdx, expandedIdx + 1});
+          expandedIdx += 2;
+        } else {
+          expandedShape.push_back(inputShape[i]);
+          expandReassociation.push_back({expandedIdx});
+          expandedIdx += 1;
+        }
+      }
+      auto expandedType = RankedTensorType::get(expandedShape, elemType);
+      tensorInput = tensor::ExpandShapeOp::create(ib, expandedType, tensorInput,
+                                                  expandReassociation);
+
+      // 3. Transpose to move nRanks (at position scatterDim) to position 0:
+      //    {d0, ..., nRanks, o_sd, ..., dN} -> {nRanks, d0, ..., o_sd, ..., dN}
+      SmallVector<int64_t> permutation, transposedShape;
+      permutation.emplace_back(scatterDim);
+      for (int64_t i = 0; i < scatterDim; ++i)
+        permutation.emplace_back(i);
+      for (int64_t i = scatterDim + 1, e = expandedShape.size(); i < e; ++i)
+        permutation.emplace_back(i);
+      for (auto p : permutation)
+        transposedShape.emplace_back(expandedShape[p]);
+
+      Value permOutput = tensor::EmptyOp::create(ib, transposedShape, elemType);
+      tensorInput =
+          linalg::TransposeOp::create(ib, tensorInput, permOutput, permutation)
+              ->getResult(0);
+
+      // 4. Materialize as contiguous memref for MPI by copying into a
+      //    freshly allocated buffer.
+      auto mpiInType = MemRefType::get(transposedShape, elemType);
+      Value transposedBuf =
+          bufferization::ToBufferOp::create(ib, mpiInType, tensorInput);
+      mpiInput = memref::AllocOp::create(ib, mpiInType);
+      linalg::CopyOp::create(ib, transposedBuf, mpiInput);
+    }
+
+    // Allocate the output buffer and emit the MPI reduce-scatter into it.
+    Value output = memref::AllocOp::create(ib, outType);
+    mpi::ReduceScatterBlockOp::create(
+        ib, TypeRange(), mpiInput, output,
+        getMPIReductionOp(adaptor.getReductionAttr()), comm);
+
+    // Deallocate the temporary input buffer if we allocated one.
+    if (scatterDim != 0)
+      memref::DeallocOp::create(ib, mpiInput);
+
+    // If the destination is a tensor, cast it to a tensor.
+    if (isa<RankedTensorType>(op.getType()))
+      output =
+          bufferization::ToTensorOp::create(ib, op.getType(), output, true);
+    rewriter.replaceOp(op, output);
+    return success();
+  }
+};
+
 struct ConvertAllGatherOp : public CommOpPattern<AllGatherOp> {
   using CommOpPattern::CommOpPattern;
 
@@ -1048,7 +1197,7 @@ struct ConvertShardToMPIPass
 
     patterns.add<ConvertUpdateHaloOp, ConvertNeighborsLinearIndicesOp,
                  ConvertGetShardingOp, ConvertShardingOp, ConvertShardShapeOp,
-                 ConvertAllGatherOp, ConvertAllReduceOp,
+                 ConvertAllGatherOp, ConvertAllReduceOp, ConvertReduceScatterOp,
                  ConvertProcessLinearIndexOp>(typeConverter, ctxt);
     SymbolTableCollection stc;
     populateProcessMultiIndexOpLoweringPatterns(patterns, stc);
diff --git a/mlir/lib/Dialect/Shard/IR/ShardOps.cpp b/mlir/lib/Dialect/Shard/IR/ShardOps.cpp
index 98234bada09e4..a173da3db1d18 100644
--- a/mlir/lib/Dialect/Shard/IR/ShardOps.cpp
+++ b/mlir/lib/Dialect/Shard/IR/ShardOps.cpp
@@ -1416,7 +1416,7 @@ ReduceScatterOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
   }
 
   return verifyScatterOrSliceOperandAndResultShape(
-      getOperand(), getResult(), getScatterAxis().getSExtValue(), getGridAxes(),
+      getOperand(), getResult(), getScatterDim().getSExtValue(), getGridAxes(),
       grid.value().getShape());
 }
 
@@ -1445,9 +1445,9 @@ LogicalResult ScatterOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
     return failure();
   }
 
-  auto scatterAxis = getScatterAxis().getSExtValue();
+  auto scatterDim = getScatterDim().getSExtValue();
   return verifyScatterOrSliceOperandAndResultShape(getInput(), getResult(),
-                                                   scatterAxis, getGridAxes(),
+                                                   scatterDim, getGridAxes(),
                                                    grid.value().getShape());
 }
 
diff --git a/mlir/lib/Dialect/Shard/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Shard/Transforms/CMakeLists.txt
index a884764e70e92..4e3fb6db9966c 100644
--- a/mlir/lib/Dialect/Shard/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Shard/Transforms/CMakeLists.txt
@@ -1,5 +1,5 @@
 add_mlir_dialect_library(MLIRShardTransforms
-  Simplifications.cpp
+  Simplify.cpp
   ShardingPropagation.cpp
   Partition.cpp
   Transforms.cpp
@@ -26,4 +26,5 @@ add_mlir_dialect_library(MLIRShardTransforms
   MLIRSupport
   MLIRTensorDialect
   MLIRTosaShardingInterfaceImpl
+  MLIRTransformUtils
 )
diff --git a/mlir/lib/Dialect/Shard/Transforms/Simplifications.cpp b/mlir/lib/Dialect/Shard/Transforms/Simplify.cpp
similarity index 55%
rename from mlir/lib/Dialect/Shard/Transforms/Simplifications.cpp
rename to mlir/lib/Dialect/Shard/Transforms/Simplify.cpp
index a17671e5408c4..bb12b430b1be2 100644
--- a/mlir/lib/Dialect/Shard/Transforms/Simplifications.cpp
+++ b/mlir/lib/Dialect/Shard/Transforms/Simplify.cpp
@@ -1,4 +1,4 @@
-//===- Simplifications.cpp - Shard Simplifications -_------------*- C++ -*-===//
+//===- Simplify.cpp - Shard Simplify ----------------------------*- C++ -*-===//
 //
 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
 // See https://llvm.org/LICENSE.txt for license information.
@@ -6,13 +6,16 @@
 //
 //===----------------------------------------------------------------------===//
 
-#include "mlir/Dialect/Shard/Transforms/Simplifications.h"
+#include "mlir/Dialect/Shard/Transforms/Simplify.h"
 #include "TransformsDetail.h"
 #include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Shard/IR/ShardDialect.h"
 #include "mlir/Dialect/Shard/IR/ShardOps.h"
+#include "mlir/Dialect/Shard/Transforms/Passes.h"
 #include "mlir/IR/BuiltinTypeInterfaces.h"
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/IR/SymbolTable.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
 #include "llvm/ADT/STLExtras.h"
 #include "llvm/ADT/SmallVector.h"
 #include <numeric>
@@ -20,31 +23,8 @@
 namespace mlir {
 namespace shard {
 
-void populateSimplificationPatterns(
-    RewritePatternSet &patterns, SymbolTableCollection &symbolTableCollection) {
-  populateAllReduceEndomorphismSimplificationPatterns<arith::AddFOp>(
-      patterns, ReductionKind::Sum);
-  populateAllReduceEndomorphismSimplificationPatterns<arith::AddIOp>(
-      patterns, ReductionKind::Sum);
-
-  populateAllReduceEndomorphismSimplificationPatterns<arith::MinimumFOp>(
-      patterns, ReductionKind::Min);
-  populateAllReduceEndomorphismSimplificationPatterns<arith::MinSIOp>(
-      patterns, ReductionKind::Min);
-  populateAllReduceEndomorphismSimplificationPatterns<arith::MinUIOp>(
-      patterns, ReductionKind::Min);
-
-  populateAllReduceEndomorphismSimplificationPatterns<arith::MaximumFOp>(
-      patterns, ReductionKind::Max);
-  populateAllReduceEndomorphismSimplificationPatterns<arith::MaxSIOp>(
-      patterns, ReductionKind::Max);
-  populateAllReduceEndomorphismSimplificationPatterns<arith::MaxUIOp>(
-      patterns, ReductionKind::Max);
-
-  // TODO: add simplifications for all-gather and other collectives.
-
-  populateFoldingPatterns(patterns, symbolTableCollection);
-}
+#define GEN_PASS_DEF_SHARDSIMPLIFY
+#include "mlir/Dialect/Shard/Transforms/Passes.h.inc"
 
 namespace {
 
@@ -109,12 +89,86 @@ struct GridShapeFolder
   }
 };
 
+// Simplify AllSliceOp(AllReduceOp) -> ReduceScatterOp when both ops share the
+// same grid and grid_axes.
+//
+// AllReduceOp performs an element-wise reduction across all devices in the
+// group, and AllSliceOp then slices (scatters) the result along a tensor
+// dimension. This is exactly what ReduceScatterOp does in a single collective.
+struct AllReduceAllSliceSimplification : OpRewritePattern<AllSliceOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(AllSliceOp sliceOp,
+                                PatternRewriter &rewriter) const override {
+    // Check if the input to AllSliceOp is produced by an AllReduceOp.
+    auto reduceOp = sliceOp.getInput().getDefiningOp<AllReduceOp>();
+    if (!reduceOp)
+      return failure();
+
+    // Both ops must operate on the same grid and grid axes.
+    if (reduceOp.getGrid() != sliceOp.getGrid() ||
+        reduceOp.getGridAxes() != sliceOp.getGridAxes())
+      return failure();
+
+    // Replace with a single ReduceScatterOp.
+    rewriter.replaceOpWithNewOp<ReduceScatterOp>(
+        sliceOp, sliceOp.getResult().getType(), sliceOp.getGridAttr(),
+        sliceOp.getGridAxesAttr(), reduceOp.getInput(),
+        reduceOp.getReductionAttr(), sliceOp.getSliceAxisAttr());
+
+    return success();
+  }
+};
+
 } // namespace
 
+void populateSimplifyPatterns(RewritePatternSet &patterns,
+                              SymbolTableCollection &symbolTableCollection) {
+  populateAllReduceEndomorphismSimplifyPatterns<arith::AddFOp>(
+      patterns, ReductionKind::Sum);
+  populateAllReduceEndomorphismSimplifyPatterns<arith::AddIOp>(
+      patterns, ReductionKind::Sum);
+
+  populateAllReduceEndomorphismSimplifyPatterns<arith::MinimumFOp>(
+      patterns, ReductionKind::Min);
+  populateAllReduceEndomorphismSimplifyPatterns<arith::MinSIOp>(
+      patterns, ReductionKind::Min);
+  populateAllReduceEndomorphismSimplifyPatterns<arith::MinUIOp>(
+      patterns, ReductionKind::Min);
+
+  populateAllReduceEndomorphismSimplifyPatterns<arith::MaximumFOp>(
+      patterns, ReductionKind::Max);
+  populateAllReduceEndomorphismSimplifyPatterns<arith::MaxSIOp>(
+      patterns, ReductionKind::Max);
+  populateAllReduceEndomorphismSimplifyPatterns<arith::MaxUIOp>(
+      patterns, ReductionKind::Max);
+
+  patterns.add<AllReduceAllSliceSimplification>(patterns.getContext());
+
+  // TODO: add simplify patterns for all-gather and other collectives.
+
+  populateFoldingPatterns(patterns, symbolTableCollection);
+}
+
 void populateFoldingPatterns(RewritePatternSet &patterns,
                              SymbolTableCollection &symbolTableCollection) {
   patterns.add<GridShapeFolder>(symbolTableCollection, patterns.getContext());
 }
 
+namespace {
+
+struct ShardSimplifyPass : public impl::ShardSimplifyBase<ShardSimplifyPass> {
+
+  void runOnOperation() override {
+    RewritePatternSet patterns(&getContext());
+    SymbolTableCollection symbolTableCollection;
+    populateSimplifyPatterns(patterns, symbolTableCollection);
+    if (failed(applyPatternsGreedily(getOperation(), std::move(patterns))))
+      signalPassFailure();
+  }
+};
+
+} // namespace
+
 } // namespace shard
 } // namespace mlir
diff --git a/mlir/test/Conversion/ShardToMPI/convert-shard-to-mpi.mlir b/mlir/test/Conversion/ShardToMPI/convert-shard-to-mpi.mlir
index f3da09d05e3b8..08c3897e4e650 100644
--- a/mlir/test/Conversion/ShardToMPI/convert-shard-to-mpi.mlir
+++ b/mlir/test/Conversion/ShardToMPI/convert-shard-to-mpi.mlir
@@ -159,6 +159,40 @@ module attributes { mpi.dlti = #dlti.map<"MPI:comm_world_rank" = 7> } {
     return %0 : memref<3x4xf64>
   }
 
+  // CHECK-LABEL: func.func @reduce_scatter_memref(
+  func.func @reduce_scatter_memref(
+    // CHECK-SAME: [[varg0:%.*]]: memref<3x4xf32>
+    %arg0 : memref<3x4xf32>) -> memref<1x4xf32> {
+    // CHECK-DAG: [[vc1_i32:%.*]] = arith.constant 0 : i32
+    // CHECK-DAG: [[vc2_i32:%.*]] = arith.constant 7 : i32
+    // CHECK: [[v0:%.*]] = mpi.comm_world : !mpi.comm
+    // CHECK: [[vnewcomm:%.*]] = mpi.comm_split([[v0]], [[vc2_i32]], [[vc1_i32]]) : !mpi.comm
+    // CHECK: [[valloc:%.*]] = memref.alloc() : memref<1x4xf32>
+    // CHECK: mpi.reduce_scatter_block([[varg0]], [[valloc]], MPI_SUM, [[vnewcomm]]) : memref<3x4xf32>, memref<1x4xf32>
+    %0 = shard.reduce_scatter %arg0 on @grid0 grid_axes = [0] scatter_dim = 0 : memref<3x4xf32> -> memref<1x4xf32>
+    // CHECK: return [[valloc]] : memref<1x4xf32>
+    return %0 : memref<1x4xf32>
+  }
+
+  // CHECK-LABEL: func.func @reduce_scatter_tensor_dim1(
+  func.func @reduce_scatter_tensor_dim1(
+    // CHECK-SAME: [[varg0:%.*]]: tensor<2x12xf32>
+    %arg0 : tensor<2x12xf32>) -> tensor<2x4xf32> {
+    // CHECK: [[vexpanded:%.*]] = tensor.expand_shape [[varg0]] {{\[\[}}0], [1, 2]] output_shape [2, 3, 4] : tensor<2x12xf32> into tensor<2x3x4xf32>
+    // CHECK: [[vempty:%.*]] = tensor.empty() : tensor<3x2x4xf32>
+    // CHECK: [[vtransposed:%.*]] = linalg.transpose ins([[vexpanded]] : tensor<2x3x4xf32>) outs([[vempty]] : tensor<3x2x4xf32>) permutation = [1, 0, 2]
+    // CHECK: [[vtobuf:%.*]] = bufferization.to_buffer [[vtransposed]] : tensor<3x2x4xf32> to memref<3x2x4xf32>
+    // CHECK: [[valloctmp:%.*]] = memref.alloc() : memref<3x2x4xf32>
+    // CHECK: linalg.copy ins([[vtobuf]] : memref<3x2x4xf32>) outs([[valloctmp]] : memref<3x2x4xf32>)
+    // CHECK: [[valloc:%.*]] = memref.alloc() : memref<2x4xf32>
+    // CHECK: mpi.reduce_scatter_block([[valloctmp]], [[valloc]], MPI_SUM,
+    // CHECK: memref.dealloc [[valloctmp]] : memref<3x2x4xf32>
+    // CHECK: [[vout:%.*]] = bufferization.to_tensor [[valloc]] restrict : memref<2x4xf32> to tensor<2x4xf32>
+    %0 = shard.reduce_scatter %arg0 on @grid0 grid_axes = [0] scatter_dim = 1 : tensor<2x12xf32> -> tensor<2x4xf32>
+    // CHECK: return [[vout]] : tensor<2x4xf32>
+    return %0 : tensor<2x4xf32>
+  }
+
   // CHECK-LABEL: func @allgather_tensor_0
   // CHECK-SAME: [[varg0:%.*]]: tensor<3x4xf32>
   func.func @allgather_tensor_0(%arg0 : tensor<3x4xf32>) -> tensor<12x4xf32> {
diff --git a/mlir/test/Dialect/Shard/all-scatter-op-lowering.mlir b/mlir/test/Dialect/Shard/all-scatter-op-lowering.mlir
index bc911215851aa..5b11078c32bbf 100644
--- a/mlir/test/Dialect/Shard/all-scatter-op-lowering.mlir
+++ b/mlir/test/Dialect/Shard/all-scatter-op-lowering.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt --split-input-file --test-grid-all-slice-op-lowering --test-grid-simplifications --cse %s | FileCheck %s
+// RUN: mlir-opt --split-input-file --test-grid-all-slice-op-lowering --shard-simplify --cse %s | FileCheck %s
 
 shard.grid @grid_1d(shape = ?)
 
@@ -59,15 +59,15 @@ func.func @all_slice_op_lowering_of_dynamic_2d_tensor_on_dynamic_4d_grid(
   // CHECK-DAG: %[[IN_GROUP_PROC_MULTI_IDX:.*]]:2 = shard.process_multi_index on @grid_4d axes = [3, 1] : index, index
   // CHECK-DAG: %[[PROC_GROUP_SHAPE:.*]]:2 = shard.grid_shape @grid_4d axes = [3, 1] : index, index
   // CHECK: %[[PROC_GROUP_SIZE:.*]] = arith.muli %[[PROC_GROUP_SHAPE]]#0, %[[PROC_GROUP_SHAPE]]#1 : index
-  // CHECK: %[[SCATTER_AXIS_SIZE:.*]] = tensor.dim %[[ARG]], %[[C1]] : tensor<?x?xf16>
-  // CHECK: %[[AXIS_SIZE_CHECK_REMINDER:.*]] = arith.remui %[[SCATTER_AXIS_SIZE]], %[[PROC_GROUP_SIZE]] : index
+  // CHECK: %[[scatter_dim_SIZE:.*]] = tensor.dim %[[ARG]], %[[C1]] : tensor<?x?xf16>
+  // CHECK: %[[AXIS_SIZE_CHECK_REMINDER:.*]] = arith.remui %[[scatter_dim_SIZE]], %[[PROC_GROUP_SIZE]] : index
   // CHECK: %[[AXIS_SIZE_CHECK:.*]] = arith.cmpi eq, %[[AXIS_SIZE_CHECK_REMINDER]], %[[C0]] : index
   // CHECK: cf.assert %[[AXIS_SIZE_CHECK]]
-  // CHECK: %[[RESULT_SCATTER_AXIS_SIZE:.*]] = arith.divui %[[SCATTER_AXIS_SIZE]], %[[PROC_GROUP_SIZE]] : index
+  // CHECK: %[[RESULT_scatter_dim_SIZE:.*]] = arith.divui %[[scatter_dim_SIZE]], %[[PROC_GROUP_SIZE]] : index
   // CHECK: %[[PROC_IN_GROUP_LINEAR_IDX:.*]] = affine.apply #map()[%[[IN_GROUP_PROC_MULTI_IDX]]#0, %[[PROC_GROUP_SHAPE]]#1, %[[IN_GROUP_PROC_MULTI_IDX]]#1]
   // CHECK: %[[AXIS_0_SIZE:.*]] = tensor.dim %[[ARG]], %[[C0]] : tensor<?x?xf16>
-  // CHECK: %[[SCATTER_AXIS_OFFSET:.*]] = arith.muli %[[PROC_IN_GROUP_LINEAR_IDX]], %[[RESULT_SCATTER_AXIS_SIZE]] : index
-  // CHECK: %[[RESULT:.*]] = tensor.extract_slice %[[ARG]][0, %[[SCATTER_AXIS_OFFSET]]] [%[[AXIS_0_SIZE]], %[[RESULT_SCATTER_AXIS_SIZE]]] [1, 1] : tensor<?x?xf16> to tensor<?x?xf16>
+  // CHECK: %[[scatter_dim_OFFSET:.*]] = arith.muli %[[PROC_IN_GROUP_LINEAR_IDX]], %[[RESULT_scatter_dim_SIZE]] : index
+  // CHECK: %[[RESULT:.*]] = tensor.extract_slice %[[ARG]][0, %[[scatter_dim_OFFSET]]] [%[[AXIS_0_SIZE]], %[[RESULT_scatter_dim_SIZE]]] [1, 1] : tensor<?x?xf16> to tensor<?x?xf16>
   %0 = shard.all_slice %arg0 on @grid_4d grid_axes = [3, 1] slice_axis = 1 : tensor<?x?xf16> -> tensor<?x?xf16>
   // CHECK: return %[[RESULT]] : tensor<?x?xf16>
   return %0 : tensor<?x?xf16>
diff --git a/mlir/test/Dialect/Shard/canonicalization.mlir b/mlir/test/Dialect/Shard/canonicalization.mlir
index ed40dfb7237da..a3a9c592ff0ae 100644
--- a/mlir/test/Dialect/Shard/canonicalization.mlir
+++ b/mlir/test/Dialect/Shard/canonicalization.mlir
@@ -135,7 +135,7 @@ func.func @reduce_scatter_empty_grid_axes(
 // CHECK-NOT: shard.reduce_scatter
   %0 = shard.reduce_scatter %arg0 on @grid0
     grid_axes = []
-    scatter_axis = 0
+    scatter_dim = 0
     : tensor<4xf32> -> tensor<4xf32>
 // CHECK: return %[[ARG]]
   return %0 : tensor<4xf32>
@@ -148,7 +148,7 @@ func.func @reduce_scatter_empty_grid_axes_different_return_type(
   %0 = shard.reduce_scatter %arg0 on @grid0
 // CHECK-NOT: grid_axes
     grid_axes = []
-    scatter_axis = 0
+    scatter_dim = 0
     : tensor<4xf32> -> tensor<4xf64>
   return %0 : tensor<4xf64>
 }
@@ -160,7 +160,7 @@ func.func @reduce_scatter_default_reduction(
     grid_axes = [0]
 // CHECK-NOT: reduction
     reduction = sum
-    scatter_axis = 0
+    scatter_dim = 0
     : tensor<4xf32> -> tensor<2xf64>
   return %0 : tensor<2xf64>
 }
@@ -172,7 +172,7 @@ func.func @scatter_empty_grid_axes(
 // CHECK-NOT: shard.scatter
   %0 = shard.scatter %arg0 on @grid0
     grid_axes = []
-    scatter_axis = 0
+    scatter_dim = 0
     root = []
     : (tensor<4xf32>) -> tensor<4xf32>
 // CHECK: return %[[ARG]]
diff --git a/mlir/test/Dialect/Shard/folding.mlir b/mlir/test/Dialect/Shard/folding.mlir
index 5a0f35b53a129..7b6ba33f84fd9 100644
--- a/mlir/test/Dialect/Shard/folding.mlir
+++ b/mlir/test/Dialect/Shard/folding.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt -test-grid-simplifications %s | FileCheck %s
+// RUN: mlir-opt -shard-simplify %s | FileCheck %s
 
 shard.grid @grid0(shape = 4x?x2)
 shard.grid @grid1(shape = 2x3)
diff --git a/mlir/test/Dialect/Shard/invalid.mlir b/mlir/test/Dialect/Shard/invalid.mlir
index 6acac971164ed..c92932a725d1c 100644
--- a/mlir/test/Dialect/Shard/invalid.mlir
+++ b/mlir/test/Dialect/Shard/invalid.mlir
@@ -707,7 +707,7 @@ shard.grid @grid0(shape = 3)
 func.func @reduce_scatter_duplicate_grid_axis(
     %arg0 : tensor<?xf32>) -> tensor<?xf64> {
   // expected-error at +1 {{Grid axes contains duplicate elements.}}
-  %0 = shard.reduce_scatter %arg0 on @grid0 grid_axes = [0, 0] scatter_axis = 0
+  %0 = shard.reduce_scatter %arg0 on @grid0 grid_axes = [0, 0] scatter_dim = 0
     : tensor<?xf32> -> tensor<?xf64>
   return %0 : tensor<?xf64>
 }
@@ -719,7 +719,7 @@ shard.grid @grid0(shape = 3)
 func.func @reduce_scatter_invalid_dynamic_dimension(
     %arg0 : tensor<?xf32>) -> tensor<2xf64> {
   // expected-error at +1 {{Dimension size mismatch for result axis 0. Expected dynamic, but got 2.}}
-  %0 = shard.reduce_scatter %arg0 on @grid0 scatter_axis = 0
+  %0 = shard.reduce_scatter %arg0 on @grid0 scatter_dim = 0
     : tensor<?xf32> -> tensor<2xf64>
   return %0 : tensor<2xf64>
 }
@@ -731,7 +731,7 @@ shard.grid @grid0(shape = 3)
 func.func @reduce_scatter_invalid_static_dimension_size(
     %arg0 : tensor<3xf32>) -> tensor<2xf64> {
   // expected-error at +1 {{Dimension size mismatch for result axis 0. Expected 1, but got 2.}}
-  %0 = shard.reduce_scatter %arg0 on @grid0 grid_axes = [0] scatter_axis = 0
+  %0 = shard.reduce_scatter %arg0 on @grid0 grid_axes = [0] scatter_dim = 0
     : tensor<3xf32> -> tensor<2xf64>
   return %0 : tensor<2xf64>
 }
@@ -743,7 +743,7 @@ shard.grid @grid0(shape = 3)
 func.func @reduce_scatter_invalid_operand_static_dimension_size(
     %arg0 : tensor<4xf32>) -> tensor<?xf64> {
   // expected-error at +1 {{Operand dimension size 4 is not divisible by collective device group size 3 for tensor axis 0.}}
-  %0 = shard.reduce_scatter %arg0 on @grid0 grid_axes = [0] scatter_axis = 0
+  %0 = shard.reduce_scatter %arg0 on @grid0 grid_axes = [0] scatter_dim = 0
     : tensor<4xf32> -> tensor<?xf64>
   return %0 : tensor<?xf64>
 }
@@ -756,7 +756,7 @@ func.func @scatter_duplicate_grid_axis(
     %arg0 : tensor<?xf32>) -> tensor<?xf32> {
   // expected-error at +1 {{Grid axes contains duplicate elements.}}
   %0 = shard.scatter %arg0 on @grid0 grid_axes = [0, 0]
-    scatter_axis = 0 root = [0, 0]
+    scatter_dim = 0 root = [0, 0]
     : (tensor<?xf32>) -> tensor<?xf32>
   return %0 : tensor<?xf32>
 }
@@ -769,7 +769,7 @@ func.func @scatter_invalid_dynamic_dimension(
     %arg0 : tensor<?xf32>) -> tensor<2xf32> {
   // expected-error at +1 {{Dimension size mismatch for result axis 0. Expected dynamic, but got 2.}}
   %0 = shard.scatter %arg0 on @grid0
-    scatter_axis = 0 root = []
+    scatter_dim = 0 root = []
     : (tensor<?xf32>) -> tensor<2xf32>
   return %0 : tensor<2xf32>
 }
@@ -782,7 +782,7 @@ func.func @scatter_invalid_static_dimension_size(
     %arg0 : tensor<3xf32>) -> tensor<2xf32> {
   // expected-error at +1 {{Dimension size mismatch for result axis 0. Expected 1, but got 2.}}
   %0 = shard.scatter %arg0 on @grid0 grid_axes = [0]
-    scatter_axis = 0 root = [1]
+    scatter_dim = 0 root = [1]
     : (tensor<3xf32>) -> tensor<2xf32>
   return %0 : tensor<2xf32>
 }
@@ -795,7 +795,7 @@ func.func @scatter_invalid_operand_static_dimension_size(
     %arg0 : tensor<4xf32>) -> tensor<?xf32> {
   // expected-error at +1 {{Operand dimension size 4 is not divisible by collective device group size 3 for tensor axis 0.}}
   %0 = shard.scatter %arg0 on @grid0 grid_axes = [0]
-    scatter_axis = 0 root = [1]
+    scatter_dim = 0 root = [1]
     : (tensor<4xf32>) -> tensor<?xf32>
   return %0 : tensor<?xf32>
 }
@@ -808,7 +808,7 @@ func.func @scatter_root_dimension_out_of_bounds(
     %arg0 : tensor<3xi8>) -> tensor<1xi8> {
   // expected-error at +1 {{Out of bounds coordinate 0 for in-group device "root". Got 3, but expected value in the range [0, 2].}}
   %0 = shard.scatter %arg0 on @grid0 grid_axes = [0]
-    scatter_axis = 0 root = [3]
+    scatter_dim = 0 root = [3]
     : (tensor<3xi8>) -> tensor<1xi8>
   return %0 : tensor<1xi8>
 }
@@ -821,7 +821,7 @@ func.func @scatter_root_wrong_number_dimensions(
     %arg0 : tensor<3xi8>) -> tensor<1xi8> {
   // expected-error at +1 {{In-group device "root" has unexpected multi-index size 2. Expected 1.}}
   %0 = shard.scatter %arg0 on @grid0 grid_axes = [0]
-    scatter_axis = 0 root = [2, 2]
+    scatter_dim = 0 root = [2, 2]
     : (tensor<3xi8>) -> tensor<1xi8>
   return %0 : tensor<1xi8>
 }
diff --git a/mlir/test/Dialect/Shard/ops.mlir b/mlir/test/Dialect/Shard/ops.mlir
index 5265dadd2a845..5d9ea064bed44 100644
--- a/mlir/test/Dialect/Shard/ops.mlir
+++ b/mlir/test/Dialect/Shard/ops.mlir
@@ -453,10 +453,10 @@ func.func @reduce_scatter_static_dimensions(
     // CHECK-SAME: %[[ARG:.*]]: tensor<3x4xf32>
     %arg0 : tensor<3x4xf32>) -> tensor<3x1xf64> {
   // CHECK-NEXT: shard.reduce_scatter %[[ARG]]
-  // CHECK-SAME: on @grid0 grid_axes = [2] reduction = max scatter_axis = 1
+  // CHECK-SAME: on @grid0 grid_axes = [2] reduction = max scatter_dim = 1
   // CHECK-SAME: : tensor<3x4xf32> -> tensor<3x1xf64>
   %0 = shard.reduce_scatter %arg0 on @grid0 grid_axes = [2]
-    reduction = max scatter_axis = 1
+    reduction = max scatter_dim = 1
     : tensor<3x4xf32> -> tensor<3x1xf64>
   return %0 : tensor<3x1xf64>
 }
@@ -466,9 +466,9 @@ func.func @reduce_scatter_dynamic_dimensions(
     // CHECK-SAME: %[[ARG:.*]]: tensor<?xf32>
     %arg0 : tensor<?xf32>) -> tensor<?xf64> {
   // CHECK-NEXT: shard.reduce_scatter %[[ARG]]
-  // CHECK-SAME: on @grid3 grid_axes = [0, 1] scatter_axis = 0
+  // CHECK-SAME: on @grid3 grid_axes = [0, 1] scatter_dim = 0
   // CHECK-SAME: : tensor<?xf32> -> tensor<?xf64>
-  %0 = shard.reduce_scatter %arg0 on @grid3 grid_axes = [0, 1] scatter_axis = 0
+  %0 = shard.reduce_scatter %arg0 on @grid3 grid_axes = [0, 1] scatter_dim = 0
     : tensor<?xf32> -> tensor<?xf64>
   return %0 : tensor<?xf64>
 }
@@ -479,10 +479,10 @@ func.func @scatter_static_dimensions(
     %arg0 : tensor<3x4xf32>) -> tensor<3x1xf32> {
   // CHECK-NEXT: shard.scatter %[[ARG]]
   // CHECK-SAME: on @grid0 grid_axes = [2]
-  // CHECK-SAME: scatter_axis = 1 root = [1]
+  // CHECK-SAME: scatter_dim = 1 root = [1]
   // CHECK-SAME: : (tensor<3x4xf32>) -> tensor<3x1xf32>
   %0 = shard.scatter %arg0 on @grid0 grid_axes = [2]
-    scatter_axis = 1 root = [1]
+    scatter_dim = 1 root = [1]
     : (tensor<3x4xf32>) -> tensor<3x1xf32>
   return %0 : tensor<3x1xf32>
 }
@@ -493,10 +493,10 @@ func.func @scatter_dynamic_dimensions(
     %arg0 : tensor<?xf32>) -> tensor<?xf32> {
   // CHECK-NEXT: shard.scatter %[[ARG]]
   // CHECK-SAME: on @grid3 grid_axes = [0, 1]
-  // CHECK-SAME: scatter_axis = 0 root = [1, 2]
+  // CHECK-SAME: scatter_dim = 0 root = [1, 2]
   // CHECK-SAME: : (tensor<?xf32>) -> tensor<?xf32>
   %0 = shard.scatter %arg0 on @grid3 grid_axes = [0, 1]
-    scatter_axis = 0 root = [1, 2]
+    scatter_dim = 0 root = [1, 2]
     : (tensor<?xf32>) -> tensor<?xf32>
   return %0 : tensor<?xf32>
 }
@@ -510,11 +510,11 @@ func.func @scatter_dynamic_root(
     ) -> tensor<1xi8> {
   // CHECK-NEXT: shard.scatter %[[ARG0]]
   // CHECK-SAME: on @grid0 grid_axes = [0, 2]
-  // CHECK-SAME: scatter_axis = 0
+  // CHECK-SAME: scatter_dim = 0
   // CHECK-SAME: root = [1, %[[ARG1]]]
   // CHECK-SAME: : (tensor<8xi8>, index) -> tensor<1xi8>
   %0 = shard.scatter %arg0 on @grid0 grid_axes = [0, 2]
-    scatter_axis = 0
+    scatter_dim = 0
     root = [1, %arg1]
     : (tensor<8xi8>, index) -> tensor<1xi8>
   return %0 : tensor<1xi8>
diff --git a/mlir/test/Dialect/Shard/simplifications.mlir b/mlir/test/Dialect/Shard/simplify.mlir
similarity index 66%
rename from mlir/test/Dialect/Shard/simplifications.mlir
rename to mlir/test/Dialect/Shard/simplify.mlir
index 33cd490be744a..e5693a288fda6 100644
--- a/mlir/test/Dialect/Shard/simplifications.mlir
+++ b/mlir/test/Dialect/Shard/simplify.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt -test-grid-simplifications %s | FileCheck %s
+// RUN: mlir-opt -shard-simplify %s | FileCheck %s
 
 shard.grid @grid0(shape = 4x2)
 shard.grid @grid1(shape = 4)
@@ -177,3 +177,86 @@ func.func @no_endomorphism_op(%arg0: tensor<2xi64>) -> i64 {
   %0 = arith.maxsi %extracted, %c1_i64 : i64
   return %0 : i64
 }
+
+// -----
+// AllReduceOp + AllSliceOp -> ReduceScatterOp tests
+// -----
+
+// Basic case: all_slice(all_reduce(x)) with matching grid and axes folds
+// into reduce_scatter.
+// CHECK-LABEL: func.func @all_reduce_all_slice_to_reduce_scatter
+func.func @all_reduce_all_slice_to_reduce_scatter(
+    // CHECK-SAME: %[[ARG0:[A-Za-z0-9_]*]]: tensor<4x8xf32>
+    %arg0: tensor<4x8xf32>) -> tensor<1x8xf32> {
+  %0 = shard.all_reduce %arg0 on @grid0 grid_axes = [0] : tensor<4x8xf32> -> tensor<4x8xf32>
+  %1 = shard.all_slice %0 on @grid0 grid_axes = [0] slice_axis = 0 : tensor<4x8xf32> -> tensor<1x8xf32>
+  // CHECK: %[[RS:.*]] = shard.reduce_scatter %[[ARG0]] on @grid0 grid_axes = [0] scatter_dim = 0
+  // CHECK-SAME: : tensor<4x8xf32> -> tensor<1x8xf32>
+  // CHECK: return %[[RS]]
+  return %1 : tensor<1x8xf32>
+}
+
+// Verify non-default reduction kind is preserved.
+// CHECK-LABEL: func.func @all_reduce_all_slice_max_reduction
+func.func @all_reduce_all_slice_max_reduction(
+    // CHECK-SAME: %[[ARG0:[A-Za-z0-9_]*]]: tensor<4x8xf32>
+    %arg0: tensor<4x8xf32>) -> tensor<1x8xf32> {
+  %0 = shard.all_reduce %arg0 on @grid0 grid_axes = [0] reduction = max : tensor<4x8xf32> -> tensor<4x8xf32>
+  %1 = shard.all_slice %0 on @grid0 grid_axes = [0] slice_axis = 0 : tensor<4x8xf32> -> tensor<1x8xf32>
+  // CHECK: %[[RS:.*]] = shard.reduce_scatter %[[ARG0]] on @grid0 grid_axes = [0] reduction = max scatter_dim = 0
+  // CHECK-SAME: : tensor<4x8xf32> -> tensor<1x8xf32>
+  // CHECK: return %[[RS]]
+  return %1 : tensor<1x8xf32>
+}
+
+// Slice along tensor axis 1 (instead of 0), using grid axis 1 for both ops.
+// CHECK-LABEL: func.func @all_reduce_all_slice_different_slice_axis
+func.func @all_reduce_all_slice_different_slice_axis(
+    // CHECK-SAME: %[[ARG0:[A-Za-z0-9_]*]]: tensor<4x8xf32>
+    %arg0: tensor<4x8xf32>) -> tensor<4x4xf32> {
+  %0 = shard.all_reduce %arg0 on @grid0 grid_axes = [1] : tensor<4x8xf32> -> tensor<4x8xf32>
+  %1 = shard.all_slice %0 on @grid0 grid_axes = [1] slice_axis = 1 : tensor<4x8xf32> -> tensor<4x4xf32>
+  // CHECK: %[[RS:.*]] = shard.reduce_scatter %[[ARG0]] on @grid0 grid_axes = [1] scatter_dim = 1
+  // CHECK-SAME: : tensor<4x8xf32> -> tensor<4x4xf32>
+  // CHECK: return %[[RS]]
+  return %1 : tensor<4x4xf32>
+}
+
+// Do not fold when grids differ.
+// CHECK-LABEL: func.func @all_reduce_all_slice_different_grid
+func.func @all_reduce_all_slice_different_grid(
+    // CHECK-SAME: %[[ARG0:[A-Za-z0-9_]*]]: tensor<4x8xf32>
+    %arg0: tensor<4x8xf32>) -> tensor<1x8xf32> {
+  // CHECK: %[[AR:.*]] = shard.all_reduce %[[ARG0]] on @grid0
+  %0 = shard.all_reduce %arg0 on @grid0 grid_axes = [0] : tensor<4x8xf32> -> tensor<4x8xf32>
+  // CHECK: %[[AS:.*]] = shard.all_slice %[[AR]] on @grid1
+  %1 = shard.all_slice %0 on @grid1 grid_axes = [0] slice_axis = 0 : tensor<4x8xf32> -> tensor<1x8xf32>
+  // CHECK: return %[[AS]]
+  return %1 : tensor<1x8xf32>
+}
+
+// Do not fold when grid_axes differ.
+// CHECK-LABEL: func.func @all_reduce_all_slice_different_grid_axes
+func.func @all_reduce_all_slice_different_grid_axes(
+    // CHECK-SAME: %[[ARG0:[A-Za-z0-9_]*]]: tensor<4x8xf32>
+    %arg0: tensor<4x8xf32>) -> tensor<4x4xf32> {
+  // CHECK: %[[AR:.*]] = shard.all_reduce %[[ARG0]] on @grid0 grid_axes = [0]
+  %0 = shard.all_reduce %arg0 on @grid0 grid_axes = [0] : tensor<4x8xf32> -> tensor<4x8xf32>
+  // CHECK: %[[AS:.*]] = shard.all_slice %[[AR]] on @grid0 grid_axes = [1]
+  %1 = shard.all_slice %0 on @grid0 grid_axes = [1] slice_axis = 1 : tensor<4x8xf32> -> tensor<4x4xf32>
+  // CHECK: return %[[AS]]
+  return %1 : tensor<4x4xf32>
+}
+
+// Verify element type conversion is preserved (all_reduce input/output types may differ).
+// CHECK-LABEL: func.func @all_reduce_all_slice_type_promotion
+func.func @all_reduce_all_slice_type_promotion(
+    // CHECK-SAME: %[[ARG0:[A-Za-z0-9_]*]]: tensor<4x8xf32>
+    %arg0: tensor<4x8xf32>) -> tensor<1x8xf64> {
+  %0 = shard.all_reduce %arg0 on @grid0 grid_axes = [0] : tensor<4x8xf32> -> tensor<4x8xf64>
+  %1 = shard.all_slice %0 on @grid0 grid_axes = [0] slice_axis = 0 : tensor<4x8xf64> -> tensor<1x8xf64>
+  // CHECK: %[[RS:.*]] = shard.reduce_scatter %[[ARG0]] on @grid0 grid_axes = [0] scatter_dim = 0
+  // CHECK-SAME: : tensor<4x8xf32> -> tensor<1x8xf64>
+  // CHECK: return %[[RS]]
+  return %1 : tensor<1x8xf64>
+}
diff --git a/mlir/test/lib/Dialect/Shard/CMakeLists.txt b/mlir/test/lib/Dialect/Shard/CMakeLists.txt
index f91c54721e030..a97839b6b1ffd 100644
--- a/mlir/test/lib/Dialect/Shard/CMakeLists.txt
+++ b/mlir/test/lib/Dialect/Shard/CMakeLists.txt
@@ -2,7 +2,6 @@
 add_mlir_library(MLIRShardTest
   TestOpLowering.cpp
   TestReshardingPartition.cpp
-  TestSimplifications.cpp
 
   EXCLUDE_FROM_LIBMLIR
   )
diff --git a/mlir/test/lib/Dialect/Shard/TestReshardingPartition.cpp b/mlir/test/lib/Dialect/Shard/TestReshardingPartition.cpp
index 23fdad1bd624d..1d1812e4aea39 100644
--- a/mlir/test/lib/Dialect/Shard/TestReshardingPartition.cpp
+++ b/mlir/test/lib/Dialect/Shard/TestReshardingPartition.cpp
@@ -1,4 +1,4 @@
-//===- TestSimplification.cpp - Test simplification -----------------------===//
+//===- TestReshardingPartition.cpp - Test resharding partition ------------===//
 //
 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
 // See https://llvm.org/LICENSE.txt for license information.
diff --git a/mlir/test/lib/Dialect/Shard/TestSimplifications.cpp b/mlir/test/lib/Dialect/Shard/TestSimplifications.cpp
deleted file mode 100644
index 28852153f37f6..0000000000000
--- a/mlir/test/lib/Dialect/Shard/TestSimplifications.cpp
+++ /dev/null
@@ -1,47 +0,0 @@
-//===- TestSimplification.cpp - Test simplification -----------------------===//
-//
-// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-//
-//===----------------------------------------------------------------------===//
-
-#include "mlir/Dialect/Arith/IR/Arith.h"
-#include "mlir/Dialect/Shard/IR/ShardDialect.h"
-#include "mlir/Dialect/Shard/Transforms/Simplifications.h"
-#include "mlir/IR/SymbolTable.h"
-#include "mlir/Pass/Pass.h"
-#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
-
-using namespace mlir;
-
-namespace {
-struct TestShardSimplificationsPass
-    : public PassWrapper<TestShardSimplificationsPass, OperationPass<>> {
-  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestShardSimplificationsPass)
-
-  void runOnOperation() override;
-  void getDependentDialects(DialectRegistry &registry) const override {
-    registry.insert<arith::ArithDialect, shard::ShardDialect>();
-  }
-  StringRef getArgument() const final { return "test-grid-simplifications"; }
-  StringRef getDescription() const final { return "Test grid simplifications"; }
-};
-} // namespace
-
-void TestShardSimplificationsPass::runOnOperation() {
-  RewritePatternSet patterns(&getContext());
-  SymbolTableCollection symbolTableCollection;
-  shard::populateSimplificationPatterns(patterns, symbolTableCollection);
-  [[maybe_unused]] LogicalResult status =
-      applyPatternsGreedily(getOperation(), std::move(patterns));
-  assert(succeeded(status) && "Rewrite patters application did not converge.");
-}
-
-namespace mlir {
-namespace test {
-void registerTestShardSimplificationsPass() {
-  PassRegistration<TestShardSimplificationsPass>();
-}
-} // namespace test
-} // namespace mlir
diff --git a/mlir/tools/mlir-opt/mlir-opt.cpp b/mlir/tools/mlir-opt/mlir-opt.cpp
index a427132247e6d..564bb700b53e3 100644
--- a/mlir/tools/mlir-opt/mlir-opt.cpp
+++ b/mlir/tools/mlir-opt/mlir-opt.cpp
@@ -131,7 +131,6 @@ void registerTestMemRefDependenceCheck();
 void registerTestMemRefStrideCalculation();
 void registerTestMemRefToLLVMWithTransforms();
 void registerTestReshardingPartitionPass();
-void registerTestShardSimplificationsPass();
 void registerTestMultiBuffering();
 void registerTestNextAccessPass();
 void registerTestNVGPULowerings();
@@ -281,7 +280,6 @@ static void registerTestPasses() {
   mlir::test::registerTestMemRefStrideCalculation();
   mlir::test::registerTestMemRefToLLVMWithTransforms();
   mlir::test::registerTestReshardingPartitionPass();
-  mlir::test::registerTestShardSimplificationsPass();
   mlir::test::registerTestMultiBuffering();
   mlir::test::registerTestNextAccessPass();
   mlir::test::registerTestNVGPULowerings();

>From 6767033a1b1a3dd00fa81a4a8116cd420f035532 Mon Sep 17 00:00:00 2001
From: "Schlimbach, Frank" <frank.schlimbach at intel.com>
Date: Mon, 2 Mar 2026 08:56:09 -0800
Subject: [PATCH 3/7] using correct recvcount for lowering mpi.reduce_scatter

---
 mlir/include/mlir/Dialect/MPI/IR/MPIOps.td    |  4 ++--
 mlir/lib/Conversion/MPIToLLVM/MPIToLLVM.cpp   | 19 ++++++++--------
 mlir/lib/Conversion/ShardToMPI/ShardToMPI.cpp |  5 +++--
 .../lib/Dialect/Shard/Transforms/Simplify.cpp |  2 +-
 mlir/test/Conversion/MPIToLLVM/mpitollvm.mlir | 22 +++++++++++++++----
 5 files changed, 33 insertions(+), 19 deletions(-)

diff --git a/mlir/include/mlir/Dialect/MPI/IR/MPIOps.td b/mlir/include/mlir/Dialect/MPI/IR/MPIOps.td
index ee79026bf8214..0ed2a98545870 100644
--- a/mlir/include/mlir/Dialect/MPI/IR/MPIOps.td
+++ b/mlir/include/mlir/Dialect/MPI/IR/MPIOps.td
@@ -354,8 +354,8 @@ def MPI_ReduceScatterBlockOp : MPI_Op<"reduce_scatter_block", []> {
   }];
 
   let arguments = (
-    ins AnyMemRef : $sendbuf,
-    AnyMemRef : $recvbuf,
+    ins AnyNon0RankedMemRef : $sendbuf,
+    AnyNon0RankedMemRef : $recvbuf,
     MPI_ReductionOpEnum : $op,
     MPI_Comm : $comm
   );
diff --git a/mlir/lib/Conversion/MPIToLLVM/MPIToLLVM.cpp b/mlir/lib/Conversion/MPIToLLVM/MPIToLLVM.cpp
index 8571b39d41780..5087b96adf8ec 100644
--- a/mlir/lib/Conversion/MPIToLLVM/MPIToLLVM.cpp
+++ b/mlir/lib/Conversion/MPIToLLVM/MPIToLLVM.cpp
@@ -952,15 +952,14 @@ struct ReduceScatterBlockOpLowering
     Value mpiOp = mpiTraits->getMPIOp(loc, rewriter, op.getOp());
     Value comm = mpiTraits->castComm(loc, rewriter, adaptor.getComm());
 
-    // recvCnt should be the size of the block -> we need to divide by the
-    // number of ranks to get the count argument for the function call.
-    Value nRanks = createOrFoldCommSize(rewriter, loc, op.getComm(), adaptor.getComm());
-    Value recvCnt = LLVM::UDivOp::create(rewriter, loc, i32, recvSize, nRanks);
-
-    // Check that recvSize is a multiple of the number of ranks
-    Value checkSize = LLVM::MulOp::create(rewriter, loc, i32, recvCnt, nRanks);
-    Value sizeIsValid = LLVM::ICmpOp::create(rewriter, loc, LLVM::ICmpPredicate::eq, checkSize, recvSize);
-    cf::AssertOp::create(rewriter, loc, sizeIsValid, "Output buffer's size must be a multiple of the number of ranks");
+    Value nRanks =
+        createOrFoldCommSize(rewriter, loc, op.getComm(), adaptor.getComm());
+    Value expected = LLVM::UDivOp::create(rewriter, loc, i32, sendSize, nRanks);
+    Value sizeIsValid = LLVM::ICmpOp::create(
+        rewriter, loc, LLVM::ICmpPredicate::eq, expected, recvSize);
+    cf::AssertOp::create(rewriter, loc, sizeIsValid,
+                         "Send buffer's size must be the receive buffer's size "
+                         "times the number of ranks");
 
     // 'int MPI_Reduce_scatter_block(const void *sendbuf, void *recvbuf,
     //     int recvcount, MPI_Datatype datatype, MPI_Op op, MPI_Comm comm)'
@@ -974,7 +973,7 @@ struct ReduceScatterBlockOpLowering
     // replace op with function call
     auto funcCall = LLVM::CallOp::create(
         rewriter, loc, funcDecl,
-        ValueRange{sendPtr, recvPtr, recvCnt, dataType, mpiOp, comm});
+        ValueRange{sendPtr, recvPtr, recvSize, dataType, mpiOp, comm});
 
     if (op.getRetval())
       rewriter.replaceOp(op, funcCall.getResult());
diff --git a/mlir/lib/Conversion/ShardToMPI/ShardToMPI.cpp b/mlir/lib/Conversion/ShardToMPI/ShardToMPI.cpp
index 01196dd1b6fcd..68448fcbd4427 100644
--- a/mlir/lib/Conversion/ShardToMPI/ShardToMPI.cpp
+++ b/mlir/lib/Conversion/ShardToMPI/ShardToMPI.cpp
@@ -649,7 +649,8 @@ struct ConvertReduceScatterOp : public CommOpPattern<ReduceScatterOp> {
     int64_t outputDimOnAxis = outputShape[scatterDim];
 
     for (size_t i = 0; i < outputShape.size(); ++i)
-      if (outputShape[i] != inputShape[i] && i != (size_t)scatterDim)
+      if (outputShape[i] != inputShape[i] &&
+          i != static_cast<size_t>(scatterDim))
         return op.emitError(
             "Result and input shapes must match along non-scatter axes.");
     if (outputDimOnAxis == 0)
@@ -706,7 +707,7 @@ struct ConvertReduceScatterOp : public CommOpPattern<ReduceScatterOp> {
       SmallVector<int64_t> expandedShape;
       SmallVector<ReassociationIndices> expandReassociation;
       int64_t expandedIdx = 0;
-      for (int64_t i = 0; i < (int64_t)inputShape.size(); ++i) {
+      for (int64_t i = 0; i < static_cast<int64_t>(inputShape.size()); ++i) {
         if (i == scatterDim) {
           expandedShape.push_back(nRanks);
           expandedShape.push_back(outputDimOnAxis);
diff --git a/mlir/lib/Dialect/Shard/Transforms/Simplify.cpp b/mlir/lib/Dialect/Shard/Transforms/Simplify.cpp
index bb12b430b1be2..c4256b9d3d80a 100644
--- a/mlir/lib/Dialect/Shard/Transforms/Simplify.cpp
+++ b/mlir/lib/Dialect/Shard/Transforms/Simplify.cpp
@@ -102,7 +102,7 @@ struct AllReduceAllSliceSimplification : OpRewritePattern<AllSliceOp> {
                                 PatternRewriter &rewriter) const override {
     // Check if the input to AllSliceOp is produced by an AllReduceOp.
     auto reduceOp = sliceOp.getInput().getDefiningOp<AllReduceOp>();
-    if (!reduceOp)
+    if (!reduceOp || !reduceOp->hasOneUse())
       return failure();
 
     // Both ops must operate on the same grid and grid axes.
diff --git a/mlir/test/Conversion/MPIToLLVM/mpitollvm.mlir b/mlir/test/Conversion/MPIToLLVM/mpitollvm.mlir
index 7e35f331b51a2..7069db44a3f58 100644
--- a/mlir/test/Conversion/MPIToLLVM/mpitollvm.mlir
+++ b/mlir/test/Conversion/MPIToLLVM/mpitollvm.mlir
@@ -120,8 +120,15 @@ module attributes {dlti.map = #dlti.map<"MPI:Implementation" = "MPICH">} {
     // CHECK: [[v72:%.*]] = llvm.call @MPI_Allreduce([[ipp]], [[v66]], [[v63]], [[v69]], [[v70]], [[v71]]) : (!llvm.ptr, !llvm.ptr, i32, i32, i32, i32) -> i32
     mpi.allreduce(%arg0, %arg0, MPI_SUM, %comm) : memref<100xf32>, memref<100xf32>
 
-    // CHECK: [[v73:%.*]] = llvm.inttoptr {{.*}} : i64 to !llvm.ptr
-    // CHECK: llvm.call @MPI_Reduce_scatter_block([[v73]], {{.*}} : (!llvm.ptr, !llvm.ptr, i32, i32, i32, i32) -> i32
+    // CHECK: llvm.mul
+    // CHECK: [[rsb_cst:%.*]] = llvm.mlir.constant(1 : index) : i32
+    // CHECK: [[rsb_dim:%.*]] = llvm.extractvalue [[v5]][3, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
+    // CHECK: [[rsb_trunc:%.*]] = llvm.trunc [[rsb_dim]] : i64 to i32
+    // CHECK: [[rsb_recvcount:%.*]] = llvm.mul [[rsb_trunc]], [[rsb_cst]] : i32
+    // CHECK: [[rsb_ipp:%.*]] = llvm.inttoptr {{.*}} : i64 to !llvm.ptr
+    // CHECK: llvm.cond_br {{.*}}, ^[[rsb_bb1:.*]], ^{{.*}}
+    // CHECK: ^[[rsb_bb1]]:
+    // CHECK: llvm.call @MPI_Reduce_scatter_block([[rsb_ipp]], {{.*}}, [[rsb_recvcount]], {{.*}}) : (!llvm.ptr, !llvm.ptr, i32, i32, i32, i32) -> i32
     mpi.reduce_scatter_block(%arg0, %arg0, MPI_SUM, %comm) : memref<100xf32>, memref<100xf32>
 
     // CHECK: llvm.call @MPI_Finalize() : () -> i32
@@ -246,8 +253,15 @@ module attributes { dlti.map = #dlti.map<"MPI:Implementation" = "OpenMPI"> } {
     // CHECK: [[v62:%.*]] = llvm.call @MPI_Allreduce([[ipp]], [[v56]], [[v53]], [[v59]], [[v60]], [[v61]]) : (!llvm.ptr, !llvm.ptr, i32, !llvm.ptr, !llvm.ptr, !llvm.ptr) -> i32
     mpi.allreduce(%arg0, %arg0, MPI_SUM, %comm) : memref<100xf32>, memref<100xf32>
 
-    // CHECK: [[v63:%.*]] = llvm.inttoptr {{.*}} : i64 to !llvm.ptr
-    // CHECK: llvm.call @MPI_Reduce_scatter_block([[v63]], {{.*}}) : (!llvm.ptr, !llvm.ptr, i32, !llvm.ptr, !llvm.ptr, !llvm.ptr) -> i32
+    // CHECK: llvm.mul
+    // CHECK: [[rsb_cst:%.*]] = llvm.mlir.constant(1 : index) : i32
+    // CHECK: [[rsb_dim:%.*]] = llvm.extractvalue [[v5]][3, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
+    // CHECK: [[rsb_trunc:%.*]] = llvm.trunc [[rsb_dim]] : i64 to i32
+    // CHECK: [[rsb_recvcount:%.*]] = llvm.mul [[rsb_trunc]], [[rsb_cst]] : i32
+    // CHECK: [[rsb_ipp:%.*]] = llvm.inttoptr {{.*}} : i64 to !llvm.ptr
+    // CHECK: llvm.cond_br {{.*}}, ^[[rsb_bb1:.*]], ^{{.*}}
+    // CHECK: ^[[rsb_bb1]]:
+    // CHECK: llvm.call @MPI_Reduce_scatter_block([[rsb_ipp]], {{.*}}, [[rsb_recvcount]], {{.*}}) : (!llvm.ptr, !llvm.ptr, i32, !llvm.ptr, !llvm.ptr, !llvm.ptr) -> i32
     mpi.reduce_scatter_block(%arg0, %arg0, MPI_SUM, %comm) : memref<100xf32>, memref<100xf32>
 
     // CHECK: [[v71:%.*]] = llvm.mlir.constant(10 : i32) : i32

>From 06803b0b4eab0ff6f9391269961f758e039e7f06 Mon Sep 17 00:00:00 2001
From: "Schlimbach, Frank" <frank.schlimbach at intel.com>
Date: Mon, 2 Mar 2026 10:19:00 -0800
Subject: [PATCH 4/7] adding performance hints

---
 mlir/lib/Dialect/Shard/Transforms/Simplify.cpp | 11 +++++++++++
 1 file changed, 11 insertions(+)

diff --git a/mlir/lib/Dialect/Shard/Transforms/Simplify.cpp b/mlir/lib/Dialect/Shard/Transforms/Simplify.cpp
index c4256b9d3d80a..525ff007bc2f6 100644
--- a/mlir/lib/Dialect/Shard/Transforms/Simplify.cpp
+++ b/mlir/lib/Dialect/Shard/Transforms/Simplify.cpp
@@ -95,6 +95,17 @@ struct GridShapeFolder
 // AllReduceOp performs an element-wise reduction across all devices in the
 // group, and AllSliceOp then slices (scatters) the result along a tensor
 // dimension. This is exactly what ReduceScatterOp does in a single collective.
+//
+// With a ring algorithm over N ranks and M elements, each rank sends:
+//   AllReduce:      2*(N-1) steps of M/N each  =>  ~2M elements in total
+//   AllSlice:       local slice, no communication
+//   ReduceScatter:  (N-1) steps of M/N each    =>  ~M elements in total
+// So this fusion roughly halves the communication volume.
+//
+// Memory-wise, AllReduce produces a full-sized M-element result that must be
+// kept alive until the subsequent AllSlice has taken its slice. ReduceScatter
+// only materializes the M/N-element local slice, shrinking the result buffer
+// by a factor of N.
 struct AllReduceAllSliceSimplification : OpRewritePattern<AllSliceOp> {
   using OpRewritePattern::OpRewritePattern;
 

>From 130cf4178b82e0f2257611920f8466289187baa4 Mon Sep 17 00:00:00 2001
From: Frank Schlimbach <frank.schlimbach at intel.com>
Date: Mon, 2 Mar 2026 19:22:32 +0100
Subject: [PATCH 5/7] Update mlir/lib/Conversion/MPIToLLVM/MPIToLLVM.cpp

Co-authored-by: Copilot <175728472+Copilot at users.noreply.github.com>
---
 mlir/lib/Conversion/MPIToLLVM/MPIToLLVM.cpp | 5 +++--
 1 file changed, 3 insertions(+), 2 deletions(-)

diff --git a/mlir/lib/Conversion/MPIToLLVM/MPIToLLVM.cpp b/mlir/lib/Conversion/MPIToLLVM/MPIToLLVM.cpp
index 5087b96adf8ec..50817cf5e00dd 100644
--- a/mlir/lib/Conversion/MPIToLLVM/MPIToLLVM.cpp
+++ b/mlir/lib/Conversion/MPIToLLVM/MPIToLLVM.cpp
@@ -954,9 +954,10 @@ struct ReduceScatterBlockOpLowering
 
     Value nRanks =
         createOrFoldCommSize(rewriter, loc, op.getComm(), adaptor.getComm());
-    Value expected = LLVM::UDivOp::create(rewriter, loc, i32, sendSize, nRanks);
+    Value totalExpected =
+        LLVM::MulOp::create(rewriter, loc, i32, recvSize, nRanks);
     Value sizeIsValid = LLVM::ICmpOp::create(
-        rewriter, loc, LLVM::ICmpPredicate::eq, expected, recvSize);
+        rewriter, loc, LLVM::ICmpPredicate::eq, sendSize, totalExpected);
     cf::AssertOp::create(rewriter, loc, sizeIsValid,
                          "Send buffer's size must be the receive buffer's size "
                          "times the number of ranks");

>From 77b42fbfa7b67e9987cb2a3f3aad650e7aac3e46 Mon Sep 17 00:00:00 2001
From: Frank Schlimbach <frank.schlimbach at intel.com>
Date: Mon, 2 Mar 2026 19:24:58 +0100
Subject: [PATCH 6/7] Update mlir/lib/Conversion/ShardToMPI/ShardToMPI.cpp

Co-authored-by: Copilot <175728472+Copilot at users.noreply.github.com>
---
 mlir/lib/Conversion/ShardToMPI/ShardToMPI.cpp | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/mlir/lib/Conversion/ShardToMPI/ShardToMPI.cpp b/mlir/lib/Conversion/ShardToMPI/ShardToMPI.cpp
index 68448fcbd4427..830a9333ade4a 100644
--- a/mlir/lib/Conversion/ShardToMPI/ShardToMPI.cpp
+++ b/mlir/lib/Conversion/ShardToMPI/ShardToMPI.cpp
@@ -641,8 +641,8 @@ struct ConvertReduceScatterOp : public CommOpPattern<ReduceScatterOp> {
     ImplicitLocOpBuilder ib(op.getLoc(), rewriter);
     Value rawInput = adaptor.getInput();
     auto inShapedType = cast<ShapedType>(rawInput.getType());
-    auto elemType = inShapedType.getElementType();
     MemRefType outType = getMemrefType(cast<ShapedType>(op.getType()));
+    auto elemType = outType.getElementType();
     auto inputShape = inShapedType.getShape();
     auto outputShape = outType.getShape();
     int64_t inputDimOnAxis = inputShape[scatterDim];

>From b7d9e6a4848d6944d38982e751306e7f12daae19 Mon Sep 17 00:00:00 2001
From: "Schlimbach, Frank" <frank.schlimbach at intel.com>
Date: Tue, 3 Mar 2026 02:49:28 -0800
Subject: [PATCH 7/7] verify same elementtype

---
 mlir/include/mlir/Dialect/MPI/IR/MPIOps.td |  1 +
 mlir/lib/Dialect/MPI/IR/MPIOps.cpp         | 15 +++++++++++++++
 2 files changed, 16 insertions(+)

diff --git a/mlir/include/mlir/Dialect/MPI/IR/MPIOps.td b/mlir/include/mlir/Dialect/MPI/IR/MPIOps.td
index 0ed2a98545870..fb0192ba748ad 100644
--- a/mlir/include/mlir/Dialect/MPI/IR/MPIOps.td
+++ b/mlir/include/mlir/Dialect/MPI/IR/MPIOps.td
@@ -365,6 +365,7 @@ def MPI_ReduceScatterBlockOp : MPI_Op<"reduce_scatter_block", []> {
   let assemblyFormat = "`(` $sendbuf `,` $recvbuf `,` $op `,` $comm `)` "
                        "attr-dict `:` type($sendbuf) `,` type($recvbuf) "
                        "(`->` type($retval)^)?";
+  let hasVerifier = 1;
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/MPI/IR/MPIOps.cpp b/mlir/lib/Dialect/MPI/IR/MPIOps.cpp
index 6cca853071dc2..e5e09e28998ba 100644
--- a/mlir/lib/Dialect/MPI/IR/MPIOps.cpp
+++ b/mlir/lib/Dialect/MPI/IR/MPIOps.cpp
@@ -15,8 +15,23 @@
 using namespace mlir;
 using namespace mlir::mpi;
 
+//===----------------------------------------------------------------------===//
+// Verifiers
+//===----------------------------------------------------------------------===//
+
+LogicalResult mlir::mpi::ReduceScatterBlockOp::verify() {
+  if (getSendbuf().getType().getElementType() !=
+      getRecvbuf().getType().getElementType())
+    return emitOpError("sendbuf and recvbuf must have the same element type");
+  return success();
+}
+
 namespace {
 
+//===----------------------------------------------------------------------===//
+// Canonicalization patterns
+//===----------------------------------------------------------------------===//
+
 // If input memref has dynamic shape and is a cast and if the cast's input has
 // static shape, fold the cast's static input into the given operation.
 template <typename OpT>



More information about the Mlir-commits mailing list