[Mlir-commits] [mlir] [mlir][shard] Add collective simplify rewrites (PR #193982)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Fri Apr 24 07:14:26 PDT 2026


https://github.com/zackc6 created https://github.com/llvm/llvm-project/pull/193982

Fold all_slice(all_gather(x)) into x and all_gather(reduce_scatter(x)) into all_reduce(x) to remove redundant communication. Add positive and negative shard simplify tests to lock in the legality conditions.

>From 4a84299edff4053dae5ddf5e78be118638704750 Mon Sep 17 00:00:00 2001
From: zack <zackchen666 at gmail.com>
Date: Fri, 24 Apr 2026 22:08:59 +0800
Subject: [PATCH] [mlir][shard] Add collective simplify rewrites

Fold all_slice(all_gather(x)) into x and all_gather(reduce_scatter(x)) into all_reduce(x) to remove redundant communication. Add positive and negative shard simplify tests to lock in the legality conditions.
---
 .../lib/Dialect/Shard/Transforms/Simplify.cpp |  63 +++++++-
 mlir/test/Dialect/Shard/simplify.mlir         | 147 ++++++++++++++++++
 2 files changed, 209 insertions(+), 1 deletion(-)

diff --git a/mlir/lib/Dialect/Shard/Transforms/Simplify.cpp b/mlir/lib/Dialect/Shard/Transforms/Simplify.cpp
index 525ff007bc2f6..bc11d5d3fcf93 100644
--- a/mlir/lib/Dialect/Shard/Transforms/Simplify.cpp
+++ b/mlir/lib/Dialect/Shard/Transforms/Simplify.cpp
@@ -131,6 +131,66 @@ struct AllReduceAllSliceSimplification : OpRewritePattern<AllSliceOp> {
   }
 };
 
+// Fold AllSliceOp(AllGatherOp(x)) -> x when the two collectives run on the
+// same grid, over the same grid_axes, and along the same tensor axis.
+//
+// The all_gather stitches every in-group shard together along gather_axis and
+// hands the full tensor to each participant; an all_slice of that axis over the
+// same device group then hands each device back its own shard, i.e. x itself.
+struct AllGatherAllSliceSimplification : OpRewritePattern<AllSliceOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(AllSliceOp sliceOp,
+                                PatternRewriter &rewriter) const override {
+    auto producer = sliceOp.getInput().getDefiningOp<AllGatherOp>();
+    if (!producer)
+      return failure();
+
+    auto gathered = producer.getInput();
+    // The slice only undoes the gather when the device group, the tensor axis
+    // and the resulting type all line up exactly.
+    bool groupMismatch = producer.getGrid() != sliceOp.getGrid() ||
+                         producer.getGridAxes() != sliceOp.getGridAxes();
+    bool axisMismatch = producer.getGatherAxis() != sliceOp.getSliceAxis();
+    bool typeMismatch = gathered.getType() != sliceOp.getResult().getType();
+    if (groupMismatch || axisMismatch || typeMismatch)
+      return failure();
+
+    rewriter.replaceOp(sliceOp, gathered);
+    return success();
+  }
+};
+
+// Simplify AllGatherOp(ReduceScatterOp) -> AllReduceOp when both ops share the
+// same grid, grid_axes and gather/scatter axis. ReduceScatter reduces
+// element-wise and scatters along a tensor axis; AllGather along the same axis
+// reassembles and replicates the full reduced tensor, i.e. exactly AllReduce.
+// The reduce_scatter must have no other users: otherwise it would survive next
+// to the new all_reduce and the rewrite would add communication, not remove it.
+struct ReduceScatterAllGatherSimplification : OpRewritePattern<AllGatherOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(AllGatherOp gatherOp,
+                                PatternRewriter &rewriter) const override {
+    auto reduceScatterOp = gatherOp.getInput().getDefiningOp<ReduceScatterOp>();
+    if (!reduceScatterOp || !reduceScatterOp->hasOneUse())
+      return failure();
+
+    if (reduceScatterOp.getGrid() != gatherOp.getGrid() ||
+        reduceScatterOp.getGridAxes() != gatherOp.getGridAxes())
+      return failure();
+
+    if (reduceScatterOp.getScatterDim() != gatherOp.getGatherAxis())
+      return failure();
+
+    rewriter.replaceOpWithNewOp<AllReduceOp>(
+        gatherOp, gatherOp.getResult().getType(), gatherOp.getGridAttr(),
+        gatherOp.getGridAxesAttr(), reduceScatterOp.getInput(),
+        reduceScatterOp.getReductionAttr());
+    return success();
+  }
+};
+
 } // namespace
 
 void populateSimplifyPatterns(RewritePatternSet &patterns,
@@ -154,7 +214,8 @@ void populateSimplifyPatterns(RewritePatternSet &patterns,
   populateAllReduceEndomorphismSimplifyPatterns<arith::MaxUIOp>(
       patterns, ReductionKind::Max);
 
-  patterns.add<AllReduceAllSliceSimplification>(patterns.getContext());
+  patterns.add<AllReduceAllSliceSimplification, AllGatherAllSliceSimplification,
+               ReduceScatterAllGatherSimplification>(patterns.getContext());
 
   // TODO: add simplify patterns for all-gather and other collectives.
 
diff --git a/mlir/test/Dialect/Shard/simplify.mlir b/mlir/test/Dialect/Shard/simplify.mlir
index e5693a288fda6..5f8b1b0fac83a 100644
--- a/mlir/test/Dialect/Shard/simplify.mlir
+++ b/mlir/test/Dialect/Shard/simplify.mlir
@@ -260,3 +260,150 @@ func.func @all_reduce_all_slice_type_promotion(
   // CHECK: return %[[RS]]
   return %1 : tensor<1x8xf64>
 }
+
+// -----
+// AllGatherOp + AllSliceOp -> input tests
+// -----
+
+// Inverse pair: all_slice undoes all_gather on the same grid, axes and axis.
+// CHECK-LABEL: func.func @all_gather_all_slice_to_input
+func.func @all_gather_all_slice_to_input(
+    // CHECK-SAME: %[[ARG0:[A-Za-z0-9_]*]]: tensor<1x8xf32>
+    %x: tensor<1x8xf32>) -> tensor<1x8xf32> {
+  %gathered = shard.all_gather %x on @grid0 grid_axes = [0] gather_axis = 0
+    : tensor<1x8xf32> -> tensor<4x8xf32>
+  %sliced = shard.all_slice %gathered on @grid0 grid_axes = [0] slice_axis = 0
+    : tensor<4x8xf32> -> tensor<1x8xf32>
+  // CHECK-NOT: shard.all_gather
+  // CHECK-NOT: shard.all_slice
+  // CHECK: return %[[ARG0]]
+  return %sliced : tensor<1x8xf32>
+}
+
+// Negative: the slice runs over a different grid axis than the gather.
+// CHECK-LABEL: func.func @all_gather_all_slice_different_grid_axes
+func.func @all_gather_all_slice_different_grid_axes(
+    // CHECK-SAME: %[[ARG0:[A-Za-z0-9_]*]]: tensor<1x8xf32>
+    %x: tensor<1x8xf32>) -> tensor<2x8xf32> {
+  // CHECK: %[[AG:.*]] = shard.all_gather %[[ARG0]] on @grid0 grid_axes = [0] gather_axis = 0
+  %gathered = shard.all_gather %x on @grid0 grid_axes = [0] gather_axis = 0
+    : tensor<1x8xf32> -> tensor<4x8xf32>
+  // CHECK: %[[AS:.*]] = shard.all_slice %[[AG]] on @grid0 grid_axes = [1] slice_axis = 0
+  %sliced = shard.all_slice %gathered on @grid0 grid_axes = [1] slice_axis = 0
+    : tensor<4x8xf32> -> tensor<2x8xf32>
+  // CHECK: return %[[AS]]
+  return %sliced : tensor<2x8xf32>
+}
+
+// Negative: gather and slice are bound to different grids.
+// CHECK-LABEL: func.func @all_gather_all_slice_different_grid
+func.func @all_gather_all_slice_different_grid(
+    // CHECK-SAME: %[[ARG0:[A-Za-z0-9_]*]]: tensor<1x8xf32>
+    %x: tensor<1x8xf32>) -> tensor<1x8xf32> {
+  // CHECK: %[[AG:.*]] = shard.all_gather %[[ARG0]] on @grid0 grid_axes = [0] gather_axis = 0
+  %gathered = shard.all_gather %x on @grid0 grid_axes = [0] gather_axis = 0
+    : tensor<1x8xf32> -> tensor<4x8xf32>
+  // CHECK: %[[AS:.*]] = shard.all_slice %[[AG]] on @grid1 grid_axes = [0] slice_axis = 0
+  %sliced = shard.all_slice %gathered on @grid1 grid_axes = [0] slice_axis = 0
+    : tensor<4x8xf32> -> tensor<1x8xf32>
+  // CHECK: return %[[AS]]
+  return %sliced : tensor<1x8xf32>
+}
+
+// Negative: gather along dim 0 but slice along dim 1 is not an inverse pair.
+// CHECK-LABEL: func.func @all_gather_all_slice_different_tensor_axes
+func.func @all_gather_all_slice_different_tensor_axes(
+    // CHECK-SAME: %[[ARG0:[A-Za-z0-9_]*]]: tensor<2x2xf32>
+    %x: tensor<2x2xf32>) -> tensor<4x1xf32> {
+  // CHECK: %[[AG:.*]] = shard.all_gather %[[ARG0]] on @grid0 grid_axes = [1] gather_axis = 0
+  %gathered = shard.all_gather %x on @grid0 grid_axes = [1] gather_axis = 0
+    : tensor<2x2xf32> -> tensor<4x2xf32>
+  // CHECK: %[[AS:.*]] = shard.all_slice %[[AG]] on @grid0 grid_axes = [1] slice_axis = 1
+  %sliced = shard.all_slice %gathered on @grid0 grid_axes = [1] slice_axis = 1
+    : tensor<4x2xf32> -> tensor<4x1xf32>
+  // CHECK: return %[[AS]]
+  return %sliced : tensor<4x1xf32>
+}
+
+// -----
+// ReduceScatterOp + AllGatherOp -> AllReduceOp tests
+// -----
+
+// Basic case: all_gather(reduce_scatter(x)) folds into all_reduce(x).
+// CHECK-LABEL: func.func @reduce_scatter_all_gather_to_all_reduce
+func.func @reduce_scatter_all_gather_to_all_reduce(
+    // CHECK-SAME: %[[ARG0:[A-Za-z0-9_]*]]: tensor<4x8xf32>
+    %arg0: tensor<4x8xf32>) -> tensor<4x8xf32> {
+  %0 = shard.reduce_scatter %arg0 on @grid0 grid_axes = [0] scatter_dim = 0
+    : tensor<4x8xf32> -> tensor<1x8xf32>
+  %1 = shard.all_gather %0 on @grid0 grid_axes = [0] gather_axis = 0
+    : tensor<1x8xf32> -> tensor<4x8xf32>
+  // CHECK-NOT: shard.reduce_scatter
+  // CHECK: %[[AR:.*]] = shard.all_reduce %[[ARG0]] on @grid0 grid_axes = [0]
+  // CHECK-SAME: : tensor<4x8xf32> -> tensor<4x8xf32>
+  // CHECK: return %[[AR]]
+  return %1 : tensor<4x8xf32>
+}
+
+// The reduction kind of the reduce_scatter carries over to the all_reduce.
+// CHECK-LABEL: func.func @reduce_scatter_all_gather_preserve_reduction
+func.func @reduce_scatter_all_gather_preserve_reduction(
+    // CHECK-SAME: %[[ARG0:[A-Za-z0-9_]*]]: tensor<4x8xf32>
+    %x: tensor<4x8xf32>) -> tensor<4x8xf32> {
+  %scattered = shard.reduce_scatter %x on @grid0 grid_axes = [0] reduction = max scatter_dim = 0
+    : tensor<4x8xf32> -> tensor<1x8xf32>
+  %gathered = shard.all_gather %scattered on @grid0 grid_axes = [0] gather_axis = 0
+    : tensor<1x8xf32> -> tensor<4x8xf32>
+  // CHECK: %[[AR:.*]] = shard.all_reduce %[[ARG0]] on @grid0 grid_axes = [0] reduction = max
+  // CHECK-SAME: : tensor<4x8xf32> -> tensor<4x8xf32>
+  // CHECK: return %[[AR]]
+  return %gathered : tensor<4x8xf32>
+}
+
+// Negative: the gather runs over a different grid axis than the scatter.
+// CHECK-LABEL: func.func @reduce_scatter_all_gather_different_grid_axes
+func.func @reduce_scatter_all_gather_different_grid_axes(
+    // CHECK-SAME: %[[ARG0:[A-Za-z0-9_]*]]: tensor<4x8xf32>
+    %x: tensor<4x8xf32>) -> tensor<2x8xf32> {
+  // CHECK: %[[RS:.*]] = shard.reduce_scatter %[[ARG0]] on @grid0 grid_axes = [0] scatter_dim = 0
+  %scattered = shard.reduce_scatter %x on @grid0 grid_axes = [0] scatter_dim = 0
+    : tensor<4x8xf32> -> tensor<1x8xf32>
+  // CHECK: %[[AG:.*]] = shard.all_gather %[[RS]] on @grid0 grid_axes = [1] gather_axis = 0
+  %gathered = shard.all_gather %scattered on @grid0 grid_axes = [1] gather_axis = 0
+    : tensor<1x8xf32> -> tensor<2x8xf32>
+  // CHECK: return %[[AG]]
+  return %gathered : tensor<2x8xf32>
+}
+
+// Do not fold if only the reduce-scatter/all-gather grids differ.
+// CHECK-LABEL: func.func @reduce_scatter_all_gather_different_grid
+func.func @reduce_scatter_all_gather_different_grid(
+    // CHECK-SAME: %[[ARG0:[A-Za-z0-9_]*]]: tensor<4x8xf32>
+    %arg0: tensor<4x8xf32>) -> tensor<4x8xf32> {
+  // CHECK: %[[RS:.*]] = shard.reduce_scatter %[[ARG0]] on @grid1 grid_axes = [0] scatter_dim = 0
+  %0 = shard.reduce_scatter %arg0 on @grid1 grid_axes = [0] scatter_dim = 0
+    : tensor<4x8xf32> -> tensor<1x8xf32>
+  // CHECK: %[[AG:.*]] = shard.all_gather %[[RS]] on @grid0 grid_axes = [0] gather_axis = 0
+  %1 = shard.all_gather %0 on @grid0 grid_axes = [0] gather_axis = 0
+    : tensor<1x8xf32> -> tensor<4x8xf32>
+  // CHECK: return %[[AG]]
+  return %1 : tensor<4x8xf32>
+}
+
+// Do not fold if scatter/gather tensor axes differ: the reduced shards are
+// reassembled along a different axis, which is not an all_reduce.
+// CHECK-LABEL: func.func @reduce_scatter_all_gather_different_tensor_axes
+func.func @reduce_scatter_all_gather_different_tensor_axes(
+    // CHECK-SAME: %[[ARG0:[A-Za-z0-9_]*]]: tensor<4x8xf32>
+    %arg0: tensor<4x8xf32>) -> tensor<2x16xf32> {
+  // CHECK: %[[RS:.*]] = shard.reduce_scatter %[[ARG0]] on @grid0 grid_axes = [1] scatter_dim = 0
+  %0 = shard.reduce_scatter %arg0 on @grid0 grid_axes = [1] scatter_dim = 0
+    : tensor<4x8xf32> -> tensor<2x8xf32>
+  // CHECK: %[[AG:.*]] = shard.all_gather %[[RS]] on @grid0 grid_axes = [1] gather_axis = 1
+  %1 = shard.all_gather %0 on @grid0 grid_axes = [1] gather_axis = 1
+    : tensor<2x8xf32> -> tensor<2x16xf32>
+  // Return the all_gather directly so the check pins the unfolded op itself.
+  // CHECK-NOT: shard.all_reduce
+  // CHECK: return %[[AG]]
+  return %1 : tensor<2x16xf32>
+}



More information about the Mlir-commits mailing list