[Mlir-commits] [mlir] [MLIR][Shard] Fold all_slice(all_gather(...)) pairs (PR #193906)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Fri Apr 24 00:12:42 PDT 2026


https://github.com/zackc6 created https://github.com/llvm/llvm-project/pull/193906

Add a simplify pattern that replaces all_slice(all_gather(x)) with x when both ops use the same grid, the same grid axes, and the same gather/slice axis. Extend the shard simplify tests with positive and negative coverage for the new fold.

>From 73d021b0c9d81333433382a91bedb39b660f06b6 Mon Sep 17 00:00:00 2001
From: zack <zackchen666 at gmail.com>
Date: Fri, 24 Apr 2026 14:54:34 +0800
Subject: [PATCH] [MLIR][Shard] Fold all_slice(all_gather(...)) pairs

Add a simplify pattern that replaces all_slice(all_gather(x)) with x when both ops use the same grid, the same grid axes, and the same gather/slice axis. Extend the shard simplify tests with positive and negative coverage for the new fold.
---
 .../lib/Dialect/Shard/Transforms/Simplify.cpp | 24 +++++++++
 mlir/test/Dialect/Shard/simplify.mlir         | 51 +++++++++++++++++++
 2 files changed, 75 insertions(+)

diff --git a/mlir/lib/Dialect/Shard/Transforms/Simplify.cpp b/mlir/lib/Dialect/Shard/Transforms/Simplify.cpp
index 525ff007bc2f6..904080cf2836a 100644
--- a/mlir/lib/Dialect/Shard/Transforms/Simplify.cpp
+++ b/mlir/lib/Dialect/Shard/Transforms/Simplify.cpp
@@ -131,6 +131,29 @@ struct AllReduceAllSliceSimplification : OpRewritePattern<AllSliceOp> {
   }
 };

+// Simplify AllSliceOp(AllGatherOp) -> input when both ops share the same grid,
+// grid_axes and axis: all_gather concatenates every group member's shard
+// along gather_axis and all_slice hands each member its own shard back.
+struct AllGatherAllSliceSimplification : OpRewritePattern<AllSliceOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(AllSliceOp sliceOp,
+                                PatternRewriter &rewriter) const override {
+    auto gatherOp = sliceOp.getInput().getDefiningOp<AllGatherOp>();
+    if (!gatherOp || gatherOp.getGrid() != sliceOp.getGrid() ||
+        gatherOp.getGridAxes() != sliceOp.getGridAxes() ||
+        gatherOp.getGatherAxis() != sliceOp.getSliceAxis())
+      return failure();
+
+    // Endpoint types can still differ (static vs. dynamic dims); bail out.
+    if (gatherOp.getInput().getType() != sliceOp.getResult().getType())
+      return failure();
+
+    rewriter.replaceOp(sliceOp, gatherOp.getInput());
+    return success();
+  }
+};
+
 } // namespace

 void populateSimplifyPatterns(RewritePatternSet &patterns,
@@ -155,6 +178,7 @@ void populateSimplifyPatterns(RewritePatternSet &patterns,
       patterns, ReductionKind::Max);

   patterns.add<AllReduceAllSliceSimplification>(patterns.getContext());
+  patterns.add<AllGatherAllSliceSimplification>(patterns.getContext());

   // TODO: add simplify patterns for all-gather and other collectives.

diff --git a/mlir/test/Dialect/Shard/simplify.mlir b/mlir/test/Dialect/Shard/simplify.mlir
index e5693a288fda6..1cc243112b741 100644
--- a/mlir/test/Dialect/Shard/simplify.mlir
+++ b/mlir/test/Dialect/Shard/simplify.mlir
@@ -1,3 +1,54 @@
 // RUN: mlir-opt -shard-simplify %s | FileCheck %s

+shard.grid @grid_ag(shape = 2x2)
+
+// CHECK-LABEL: func.func @all_gather_all_slice_identity
+func.func @all_gather_all_slice_identity(
+    %arg0: tensor<4x4xf32>) -> tensor<4x4xf32> {
+  %0 = shard.all_gather %arg0 on @grid_ag grid_axes = [1] gather_axis = 1
+    : tensor<4x4xf32> -> tensor<4x8xf32>
+  %1 = shard.all_slice %0 on @grid_ag grid_axes = [1] slice_axis = 1
+    : tensor<4x8xf32> -> tensor<4x4xf32>
+  // CHECK-NOT: shard.all_gather
+  // CHECK-NOT: shard.all_slice
+  // CHECK: return %arg0 : tensor<4x4xf32>
+  return %1 : tensor<4x4xf32>
+}
+
+// CHECK-LABEL: func.func @all_gather_all_slice_different_axis
+func.func @all_gather_all_slice_different_axis(
+    %arg0: tensor<4x4xf32>) -> tensor<2x8xf32> {
+  %0 = shard.all_gather %arg0 on @grid_ag grid_axes = [1] gather_axis = 1
+    : tensor<4x4xf32> -> tensor<4x8xf32>
+  %1 = shard.all_slice %0 on @grid_ag grid_axes = [1] slice_axis = 0
+    : tensor<4x8xf32> -> tensor<2x8xf32>
+  // CHECK: shard.all_gather
+  // CHECK: shard.all_slice
+  return %1 : tensor<2x8xf32>
+}
+
+// CHECK-LABEL: func.func @all_gather_all_slice_different_grid_axes
+func.func @all_gather_all_slice_different_grid_axes(
+    %arg0: tensor<4x4xf32>) -> tensor<4x4xf32> {
+  %0 = shard.all_gather %arg0 on @grid_ag grid_axes = [0] gather_axis = 1
+    : tensor<4x4xf32> -> tensor<4x8xf32>
+  %1 = shard.all_slice %0 on @grid_ag grid_axes = [1] slice_axis = 1
+    : tensor<4x8xf32> -> tensor<4x4xf32>
+  // CHECK: shard.all_gather
+  // CHECK: shard.all_slice
+  return %1 : tensor<4x4xf32>
+}
+
+// CHECK-LABEL: func.func @all_gather_all_slice_type_mismatch
+func.func @all_gather_all_slice_type_mismatch(
+    %arg0: tensor<4x4xf32>) -> tensor<?x4xf32> {
+  %0 = shard.all_gather %arg0 on @grid_ag grid_axes = [1] gather_axis = 1
+    : tensor<4x4xf32> -> tensor<?x8xf32>
+  %1 = shard.all_slice %0 on @grid_ag grid_axes = [1] slice_axis = 1
+    : tensor<?x8xf32> -> tensor<?x4xf32>
+  // CHECK: shard.all_gather
+  // CHECK: shard.all_slice
+  return %1 : tensor<?x4xf32>
+}
+
 shard.grid @grid0(shape = 4x2)



More information about the Mlir-commits mailing list