[Mlir-commits] [mlir] [mlir][shard] Add collective simplify rewrites (PR #193982)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Fri Apr 24 07:14:26 PDT 2026
https://github.com/zackc6 created https://github.com/llvm/llvm-project/pull/193982
Fold all_slice(all_gather(x)) and all_gather(reduce_scatter(x)) into their simpler collective forms to remove redundant communication patterns. Add positive and negative shard simplify tests to lock in legality conditions.
>From 4a84299edff4053dae5ddf5e78be118638704750 Mon Sep 17 00:00:00 2001
From: zack <zackchen666 at gmail.com>
Date: Fri, 24 Apr 2026 22:08:59 +0800
Subject: [PATCH] [mlir][shard] Add collective simplify rewrites
Fold all_slice(all_gather(x)) and all_gather(reduce_scatter(x)) into their simpler collective forms to remove redundant communication patterns. Add positive and negative shard simplify tests to lock in legality conditions.
---
.../lib/Dialect/Shard/Transforms/Simplify.cpp | 63 +++++++-
mlir/test/Dialect/Shard/simplify.mlir | 147 ++++++++++++++++++
2 files changed, 209 insertions(+), 1 deletion(-)
diff --git a/mlir/lib/Dialect/Shard/Transforms/Simplify.cpp b/mlir/lib/Dialect/Shard/Transforms/Simplify.cpp
index 525ff007bc2f6..bc11d5d3fcf93 100644
--- a/mlir/lib/Dialect/Shard/Transforms/Simplify.cpp
+++ b/mlir/lib/Dialect/Shard/Transforms/Simplify.cpp
@@ -131,6 +131,66 @@ struct AllReduceAllSliceSimplification : OpRewritePattern<AllSliceOp> {
}
};
+// Simplify AllSliceOp(AllGatherOp) -> input when both ops share the same grid,
+// grid_axes and gather/slice axis.
+//
+// AllGather concatenates in-group slices along gather_axis and replicates the
+// concatenated result. AllSlice on the same axis then takes each device-local
+// in-group slice from that replicated tensor, i.e. exactly the original input.
+struct AllGatherAllSliceSimplification : OpRewritePattern<AllSliceOp> {
+ using OpRewritePattern::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(AllSliceOp sliceOp,
+ PatternRewriter &rewriter) const override {
+ auto gatherOp = sliceOp.getInput().getDefiningOp<AllGatherOp>();
+ if (!gatherOp)
+ return failure();
+
+ if (gatherOp.getGrid() != sliceOp.getGrid() ||
+ gatherOp.getGridAxes() != sliceOp.getGridAxes())
+ return failure();
+
+ if (gatherOp.getGatherAxis() != sliceOp.getSliceAxis())
+ return failure();
+
+ if (gatherOp.getInput().getType() != sliceOp.getResult().getType())
+ return failure();
+
+ rewriter.replaceOp(sliceOp, gatherOp.getInput());
+ return success();
+ }
+};
+
+// Simplify AllGatherOp(ReduceScatterOp) -> AllReduceOp when both ops share the
+// same grid, grid_axes and gather/scatter axis.
+//
+// ReduceScatter computes an element-wise reduction and scatters along a tensor
+// axis. AllGather along the same axis reassembles that full reduced tensor and
+// replicates it to all participants, which is exactly AllReduce.
+struct ReduceScatterAllGatherSimplification : OpRewritePattern<AllGatherOp> {
+ using OpRewritePattern::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(AllGatherOp gatherOp,
+ PatternRewriter &rewriter) const override {
+ auto reduceScatterOp = gatherOp.getInput().getDefiningOp<ReduceScatterOp>();
+ if (!reduceScatterOp)
+ return failure();
+
+ if (reduceScatterOp.getGrid() != gatherOp.getGrid() ||
+ reduceScatterOp.getGridAxes() != gatherOp.getGridAxes())
+ return failure();
+
+ if (reduceScatterOp.getScatterDim() != gatherOp.getGatherAxis())
+ return failure();
+
+ rewriter.replaceOpWithNewOp<AllReduceOp>(
+ gatherOp, gatherOp.getResult().getType(), gatherOp.getGridAttr(),
+ gatherOp.getGridAxesAttr(), reduceScatterOp.getInput(),
+ reduceScatterOp.getReductionAttr());
+ return success();
+ }
+};
+
} // namespace
void populateSimplifyPatterns(RewritePatternSet &patterns,
@@ -154,7 +214,8 @@ void populateSimplifyPatterns(RewritePatternSet &patterns,
populateAllReduceEndomorphismSimplifyPatterns<arith::MaxUIOp>(
patterns, ReductionKind::Max);
- patterns.add<AllReduceAllSliceSimplification>(patterns.getContext());
+ patterns.add<AllReduceAllSliceSimplification, AllGatherAllSliceSimplification,
+ ReduceScatterAllGatherSimplification>(patterns.getContext());
// TODO: add simplify patterns for all-gather and other collectives.
diff --git a/mlir/test/Dialect/Shard/simplify.mlir b/mlir/test/Dialect/Shard/simplify.mlir
index e5693a288fda6..5f8b1b0fac83a 100644
--- a/mlir/test/Dialect/Shard/simplify.mlir
+++ b/mlir/test/Dialect/Shard/simplify.mlir
@@ -260,3 +260,150 @@ func.func @all_reduce_all_slice_type_promotion(
// CHECK: return %[[RS]]
return %1 : tensor<1x8xf64>
}
+
+// -----
+// AllGatherOp + AllSliceOp -> input tests
+// -----
+
+// Basic inverse case: all_slice(all_gather(x)) with matching grid/axes/axis.
+// CHECK-LABEL: func.func @all_gather_all_slice_to_input
+func.func @all_gather_all_slice_to_input(
+ // CHECK-SAME: %[[ARG0:[A-Za-z0-9_]*]]: tensor<1x8xf32>
+ %arg0: tensor<1x8xf32>) -> tensor<1x8xf32> {
+ %0 = shard.all_gather %arg0 on @grid0 grid_axes = [0] gather_axis = 0
+ : tensor<1x8xf32> -> tensor<4x8xf32>
+ %1 = shard.all_slice %0 on @grid0 grid_axes = [0] slice_axis = 0
+ : tensor<4x8xf32> -> tensor<1x8xf32>
+ // CHECK-NOT: shard.all_gather
+ // CHECK-NOT: shard.all_slice
+ // CHECK: return %[[ARG0]]
+ return %1 : tensor<1x8xf32>
+}
+
+// Do not fold if gather/slice grid axes differ.
+// CHECK-LABEL: func.func @all_gather_all_slice_different_grid_axes
+func.func @all_gather_all_slice_different_grid_axes(
+ // CHECK-SAME: %[[ARG0:[A-Za-z0-9_]*]]: tensor<1x8xf32>
+ %arg0: tensor<1x8xf32>) -> tensor<2x8xf32> {
+ // CHECK: %[[AG:.*]] = shard.all_gather %[[ARG0]] on @grid0 grid_axes = [0] gather_axis = 0
+ %0 = shard.all_gather %arg0 on @grid0 grid_axes = [0] gather_axis = 0
+ : tensor<1x8xf32> -> tensor<4x8xf32>
+ // CHECK: %[[AS:.*]] = shard.all_slice %[[AG]] on @grid0 grid_axes = [1] slice_axis = 0
+ %1 = shard.all_slice %0 on @grid0 grid_axes = [1] slice_axis = 0
+ : tensor<4x8xf32> -> tensor<2x8xf32>
+ // CHECK: return %[[AS]]
+ return %1 : tensor<2x8xf32>
+}
+
+// Do not fold if gather/slice grids differ.
+// CHECK-LABEL: func.func @all_gather_all_slice_different_grid
+func.func @all_gather_all_slice_different_grid(
+ // CHECK-SAME: %[[ARG0:[A-Za-z0-9_]*]]: tensor<1x8xf32>
+ %arg0: tensor<1x8xf32>) -> tensor<1x8xf32> {
+ // CHECK: %[[AG:.*]] = shard.all_gather %[[ARG0]] on @grid0 grid_axes = [0] gather_axis = 0
+ %0 = shard.all_gather %arg0 on @grid0 grid_axes = [0] gather_axis = 0
+ : tensor<1x8xf32> -> tensor<4x8xf32>
+ // CHECK: %[[AS:.*]] = shard.all_slice %[[AG]] on @grid1 grid_axes = [0] slice_axis = 0
+ %1 = shard.all_slice %0 on @grid1 grid_axes = [0] slice_axis = 0
+ : tensor<4x8xf32> -> tensor<1x8xf32>
+ // CHECK: return %[[AS]]
+ return %1 : tensor<1x8xf32>
+}
+
+// Do not fold if gather/slice tensor axes differ.
+// CHECK-LABEL: func.func @all_gather_all_slice_different_tensor_axes
+func.func @all_gather_all_slice_different_tensor_axes(
+ // CHECK-SAME: %[[ARG0:[A-Za-z0-9_]*]]: tensor<2x2xf32>
+ %arg0: tensor<2x2xf32>) -> tensor<4x1xf32> {
+ // CHECK: %[[AG:.*]] = shard.all_gather %[[ARG0]] on @grid0 grid_axes = [1] gather_axis = 0
+ %0 = shard.all_gather %arg0 on @grid0 grid_axes = [1] gather_axis = 0
+ : tensor<2x2xf32> -> tensor<4x2xf32>
+ // CHECK: %[[AS:.*]] = shard.all_slice %[[AG]] on @grid0 grid_axes = [1] slice_axis = 1
+ %1 = shard.all_slice %0 on @grid0 grid_axes = [1] slice_axis = 1
+ : tensor<4x2xf32> -> tensor<4x1xf32>
+ // CHECK: return %[[AS]]
+ return %1 : tensor<4x1xf32>
+}
+
+// -----
+// ReduceScatterOp + AllGatherOp -> AllReduceOp tests
+// -----
+
+// Basic case: all_gather(reduce_scatter(x)) with matching grid/axes/axis folds
+// into all_reduce.
+// CHECK-LABEL: func.func @reduce_scatter_all_gather_to_all_reduce
+func.func @reduce_scatter_all_gather_to_all_reduce(
+ // CHECK-SAME: %[[ARG0:[A-Za-z0-9_]*]]: tensor<4x8xf32>
+ %arg0: tensor<4x8xf32>) -> tensor<4x8xf32> {
+ %0 = shard.reduce_scatter %arg0 on @grid0 grid_axes = [0] scatter_dim = 0
+ : tensor<4x8xf32> -> tensor<1x8xf32>
+ %1 = shard.all_gather %0 on @grid0 grid_axes = [0] gather_axis = 0
+ : tensor<1x8xf32> -> tensor<4x8xf32>
+ // CHECK: %[[AR:.*]] = shard.all_reduce %[[ARG0]] on @grid0 grid_axes = [0]
+ // CHECK-SAME: : tensor<4x8xf32> -> tensor<4x8xf32>
+ // CHECK: return %[[AR]]
+ return %1 : tensor<4x8xf32>
+}
+
+// Verify reduction kind is preserved through the rewrite.
+// CHECK-LABEL: func.func @reduce_scatter_all_gather_preserve_reduction
+func.func @reduce_scatter_all_gather_preserve_reduction(
+ // CHECK-SAME: %[[ARG0:[A-Za-z0-9_]*]]: tensor<4x8xf32>
+ %arg0: tensor<4x8xf32>) -> tensor<4x8xf32> {
+ %0 = shard.reduce_scatter %arg0 on @grid0 grid_axes = [0] reduction = max scatter_dim = 0
+ : tensor<4x8xf32> -> tensor<1x8xf32>
+ %1 = shard.all_gather %0 on @grid0 grid_axes = [0] gather_axis = 0
+ : tensor<1x8xf32> -> tensor<4x8xf32>
+ // CHECK: %[[AR:.*]] = shard.all_reduce %[[ARG0]] on @grid0 grid_axes = [0] reduction = max
+ // CHECK-SAME: : tensor<4x8xf32> -> tensor<4x8xf32>
+ // CHECK: return %[[AR]]
+ return %1 : tensor<4x8xf32>
+}
+
+// Do not fold if reduce-scatter/all-gather grid axes differ.
+// CHECK-LABEL: func.func @reduce_scatter_all_gather_different_grid_axes
+func.func @reduce_scatter_all_gather_different_grid_axes(
+ // CHECK-SAME: %[[ARG0:[A-Za-z0-9_]*]]: tensor<4x8xf32>
+ %arg0: tensor<4x8xf32>) -> tensor<2x8xf32> {
+ // CHECK: %[[RS:.*]] = shard.reduce_scatter %[[ARG0]] on @grid0 grid_axes = [0] scatter_dim = 0
+ %0 = shard.reduce_scatter %arg0 on @grid0 grid_axes = [0] scatter_dim = 0
+ : tensor<4x8xf32> -> tensor<1x8xf32>
+ // CHECK: %[[AG:.*]] = shard.all_gather %[[RS]] on @grid0 grid_axes = [1] gather_axis = 0
+ %1 = shard.all_gather %0 on @grid0 grid_axes = [1] gather_axis = 0
+ : tensor<1x8xf32> -> tensor<2x8xf32>
+ // CHECK: return %[[AG]]
+ return %1 : tensor<2x8xf32>
+}
+
+// Do not fold if reduce-scatter/all-gather grids differ (grid axes differ too).
+// CHECK-LABEL: func.func @reduce_scatter_all_gather_different_grid
+func.func @reduce_scatter_all_gather_different_grid(
+ // CHECK-SAME: %[[ARG0:[A-Za-z0-9_]*]]: tensor<4x8xf32>
+ %arg0: tensor<4x8xf32>) -> tensor<2x8xf32> {
+ // CHECK: %[[RS:.*]] = shard.reduce_scatter %[[ARG0]] on @grid1 grid_axes = [0] scatter_dim = 0
+ %0 = shard.reduce_scatter %arg0 on @grid1 grid_axes = [0] scatter_dim = 0
+ : tensor<4x8xf32> -> tensor<1x8xf32>
+ // CHECK: %[[AG:.*]] = shard.all_gather %[[RS]] on @grid0 grid_axes = [1] gather_axis = 0
+ %1 = shard.all_gather %0 on @grid0 grid_axes = [1] gather_axis = 0
+ : tensor<1x8xf32> -> tensor<2x8xf32>
+ // CHECK: return %[[AG]]
+ return %1 : tensor<2x8xf32>
+}
+
+// Do not fold if scatter/gather tensor axes differ.
+// CHECK-LABEL: func.func @reduce_scatter_all_gather_different_tensor_axes
+func.func @reduce_scatter_all_gather_different_tensor_axes(
+ // CHECK-SAME: %[[ARG0:[A-Za-z0-9_]*]]: tensor<4x8xf32>
+ %arg0: tensor<4x8xf32>) -> tensor<2x8xf32> {
+ // CHECK: %[[RS:.*]] = shard.reduce_scatter %[[ARG0]] on @grid0 grid_axes = [1] scatter_dim = 0
+ %0 = shard.reduce_scatter %arg0 on @grid0 grid_axes = [1] scatter_dim = 0
+ : tensor<4x8xf32> -> tensor<2x8xf32>
+ // CHECK: %[[AG:.*]] = shard.all_gather %[[RS]] on @grid0 grid_axes = [1] gather_axis = 1
+ %1 = shard.all_gather %0 on @grid0 grid_axes = [1] gather_axis = 1
+ : tensor<2x8xf32> -> tensor<2x16xf32>
+ // Keep function result type simple by slicing back.
+ %2 = tensor.extract_slice %1[0, 0] [2, 8] [1, 1]
+ : tensor<2x16xf32> to tensor<2x8xf32>
+ // CHECK: return %{{.*}}
+ return %2 : tensor<2x8xf32>
+}
More information about the Mlir-commits
mailing list