[Mlir-commits] [mlir] [MLIR][Shard] Fold all_slice(all_gather(...)) pairs (PR #193906)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Fri Apr 24 00:12:42 PDT 2026
https://github.com/zackc6 created https://github.com/llvm/llvm-project/pull/193906
Add a simplify pattern that replaces all_slice(all_gather(x)) with x when grid, grid axes, and gather/slice axis match. Extend shard simplify tests with positive and negative coverage for the new fold.
From 73d021b0c9d81333433382a91bedb39b660f06b6 Mon Sep 17 00:00:00 2001
From: zack <zackchen666 at gmail.com>
Date: Fri, 24 Apr 2026 14:54:34 +0800
Subject: [PATCH] [MLIR][Shard] Fold all_slice(all_gather(...)) pairs
Add a simplify pattern that replaces all_slice(all_gather(x)) with x when grid, grid axes, and gather/slice axis match. Extend shard simplify tests with positive and negative coverage for the new fold.
---
.../lib/Dialect/Shard/Transforms/Simplify.cpp | 24 ++++++++++++++++
mlir/test/Dialect/Shard/simplify.mlir | 28 +++++++++++++++++++
2 files changed, 52 insertions(+)
diff --git a/mlir/lib/Dialect/Shard/Transforms/Simplify.cpp b/mlir/lib/Dialect/Shard/Transforms/Simplify.cpp
index 525ff007bc2f6..904080cf2836a 100644
--- a/mlir/lib/Dialect/Shard/Transforms/Simplify.cpp
+++ b/mlir/lib/Dialect/Shard/Transforms/Simplify.cpp
@@ -131,6 +131,29 @@ struct AllReduceAllSliceSimplification : OpRewritePattern<AllSliceOp> {
}
};
+// Simplify AllSliceOp(AllGatherOp) -> input when both ops share the same grid,
+// grid_axes and axis, and the slice result type equals the gather input type.
+// Types can differ via dynamic dims; replaceOp requires matching types.
+struct AllGatherAllSliceSimplification : OpRewritePattern<AllSliceOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(AllSliceOp sliceOp,
+                                PatternRewriter &rewriter) const override {
+    auto gatherOp = sliceOp.getInput().getDefiningOp<AllGatherOp>();
+    if (!gatherOp)
+      return failure();
+    if (gatherOp.getGrid() != sliceOp.getGrid() ||
+        gatherOp.getGridAxes() != sliceOp.getGridAxes() ||
+        gatherOp.getGatherAxis() != sliceOp.getSliceAxis())
+      return failure();
+    if (gatherOp.getInput().getType() != sliceOp.getType())
+      return failure();
+
+    rewriter.replaceOp(sliceOp, gatherOp.getInput());
+    return success();
+  }
+};
+
} // namespace
void populateSimplifyPatterns(RewritePatternSet &patterns,
@@ -155,6 +178,7 @@ void populateSimplifyPatterns(RewritePatternSet &patterns,
patterns, ReductionKind::Max);
patterns.add<AllReduceAllSliceSimplification>(patterns.getContext());
+ patterns.add<AllGatherAllSliceSimplification>(patterns.getContext());
// TODO: add simplify patterns for all-gather and other collectives.
diff --git a/mlir/test/Dialect/Shard/simplify.mlir b/mlir/test/Dialect/Shard/simplify.mlir
index e5693a288fda6..1cc243112b741 100644
--- a/mlir/test/Dialect/Shard/simplify.mlir
+++ b/mlir/test/Dialect/Shard/simplify.mlir
@@ -1,3 +1,31 @@
+// all_slice(all_gather(x)) fold tests, driven by the file's existing RUN line.
+
+shard.grid @grid_ag(shape = 2x2)
+
+// Positive case: gather and slice use the same grid, grid axes and tensor axis,
+// so the pair folds away and the function returns its argument directly.
+// CHECK-LABEL: func.func @all_gather_all_slice_identity
+// CHECK-SAME: %[[ARG:[a-z0-9]+]]: tensor<4x4xf32>
+func.func @all_gather_all_slice_identity(
+    %arg0: tensor<4x4xf32>) -> tensor<4x4xf32> {
+  %0 = shard.all_gather %arg0 on @grid_ag grid_axes = [1] gather_axis = 1
+    : tensor<4x4xf32> -> tensor<4x8xf32>
+  %1 = shard.all_slice %0 on @grid_ag grid_axes = [1] slice_axis = 1
+    : tensor<4x8xf32> -> tensor<4x4xf32>
+  // CHECK-NOT: shard.all_gather
+  // CHECK-NOT: shard.all_slice
+  // CHECK: return %[[ARG]] : tensor<4x4xf32>
+  return %1 : tensor<4x4xf32>
+}
+
+// Negative case: the slice axis (0) differs from the gather axis (1), so the
+// all_gather/all_slice pair must be left in place.
+// CHECK-LABEL: func.func @all_gather_all_slice_different_axis
+func.func @all_gather_all_slice_different_axis(
+    %arg0: tensor<4x4xf32>) -> tensor<2x8xf32> {
+  %gathered = shard.all_gather %arg0 on @grid_ag grid_axes = [1] gather_axis = 1
+    : tensor<4x4xf32> -> tensor<4x8xf32>
+  %sliced = shard.all_slice %gathered on @grid_ag grid_axes = [1] slice_axis = 0
+    : tensor<4x8xf32> -> tensor<2x8xf32>
+  // CHECK: shard.all_gather
+  // CHECK: shard.all_slice
+  return %sliced : tensor<2x8xf32>
+}
// RUN: mlir-opt -shard-simplify %s | FileCheck %s
shard.grid @grid0(shape = 4x2)
More information about the Mlir-commits
mailing list