[Mlir-commits] [mlir] [MLIR][Shard] Fold all_gather/all_slice inverse pairs (PR #193906)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Wed Apr 29 13:05:20 PDT 2026


https://github.com/cursor[bot] updated https://github.com/llvm/llvm-project/pull/193906

>From 0e40bd1b567e8177f3fc20601c9140cd622d2479 Mon Sep 17 00:00:00 2001
From: Cursor Agent <cursoragent at cursor.com>
Date: Wed, 29 Apr 2026 19:38:37 +0000
Subject: [PATCH 1/2] [MLIR][Shard] Fold and refactor inverse
 all_gather/all_slice patterns

Co-authored-by: zackc6 <zackc6 at users.noreply.github.com>
---
 .../lib/Dialect/Shard/Transforms/Simplify.cpp |  59 +++++++++-
 mlir/test/Dialect/Shard/simplify.mlir         | 102 ++++++++++++++++++
 2 files changed, 158 insertions(+), 3 deletions(-)

diff --git a/mlir/lib/Dialect/Shard/Transforms/Simplify.cpp b/mlir/lib/Dialect/Shard/Transforms/Simplify.cpp
index 525ff007bc2f6..aa169dda47a1d 100644
--- a/mlir/lib/Dialect/Shard/Transforms/Simplify.cpp
+++ b/mlir/lib/Dialect/Shard/Transforms/Simplify.cpp
@@ -19,6 +19,7 @@
 #include "llvm/ADT/STLExtras.h"
 #include "llvm/ADT/SmallVector.h"
 #include <numeric>
+#include <type_traits>
 
 namespace mlir {
 namespace shard {
@@ -28,6 +29,42 @@ namespace shard {
 
 namespace {
 
+template <typename LhsOp, typename RhsOp>
+static bool haveSameGridAndGridAxes(LhsOp lhsOp, RhsOp rhsOp) {
+  return lhsOp.getGrid() == rhsOp.getGrid() &&
+         lhsOp.getGridAxes() == rhsOp.getGridAxes();
+}
+
+static bool areInverseAllGatherAllSlice(AllGatherOp gatherOp,
+                                        AllSliceOp sliceOp) {
+  return haveSameGridAndGridAxes(gatherOp, sliceOp) &&
+         gatherOp.getGatherAxis() == sliceOp.getSliceAxis();
+}
+
+template <typename OuterOp, typename InnerOp>
+static LogicalResult foldInverseAllGatherAllSlice(OuterOp outerOp,
+                                                  InnerOp innerOp,
+                                                  PatternRewriter &rewriter) {
+  if (!innerOp)
+    return failure();
+
+  AllGatherOp gatherOp;
+  AllSliceOp sliceOp;
+  if constexpr (std::is_same_v<OuterOp, AllGatherOp>) {
+    gatherOp = outerOp;
+    sliceOp = innerOp;
+  } else {
+    gatherOp = innerOp;
+    sliceOp = outerOp;
+  }
+
+  if (!areInverseAllGatherAllSlice(gatherOp, sliceOp))
+    return failure();
+
+  rewriter.replaceOp(outerOp, innerOp.getInput());
+  return success();
+}
+
 // This folding can not be done with an operation's fold method or
 // DialectFoldInterface, because it needs a SymbolTableCollection to cache the
 // symbol tables.
@@ -117,8 +154,7 @@ struct AllReduceAllSliceSimplification : OpRewritePattern<AllSliceOp> {
       return failure();
 
     // Both ops must operate on the same grid and grid axes.
-    if (reduceOp.getGrid() != sliceOp.getGrid() ||
-        reduceOp.getGridAxes() != sliceOp.getGridAxes())
+    if (!haveSameGridAndGridAxes(reduceOp, sliceOp))
       return failure();
 
     // Replace with a single ReduceScatterOp.
@@ -131,6 +167,19 @@ struct AllReduceAllSliceSimplification : OpRewritePattern<AllSliceOp> {
   }
 };
 
+// Fold all_slice(all_gather(x)) and all_gather(all_slice(x)) to x when both
+// ops use the same grid, the same grid_axes, and gather_axis == slice_axis.
+template <typename OuterOp, typename InnerOp>
+struct InverseAllGatherAllSliceSimplification : OpRewritePattern<OuterOp> {
+  using OpRewritePattern<OuterOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(OuterOp outerOp,
+                                PatternRewriter &rewriter) const override {
+    auto innerOp = outerOp.getInput().template getDefiningOp<InnerOp>();
+    return foldInverseAllGatherAllSlice(outerOp, innerOp, rewriter);
+  }
+};
+
 } // namespace
 
 void populateSimplifyPatterns(RewritePatternSet &patterns,
@@ -154,7 +203,11 @@ void populateSimplifyPatterns(RewritePatternSet &patterns,
   populateAllReduceEndomorphismSimplifyPatterns<arith::MaxUIOp>(
       patterns, ReductionKind::Max);
 
-  patterns.add<AllReduceAllSliceSimplification>(patterns.getContext());
+  patterns.add<
+      AllReduceAllSliceSimplification,
+      InverseAllGatherAllSliceSimplification<AllSliceOp, AllGatherOp>,
+      InverseAllGatherAllSliceSimplification<AllGatherOp, AllSliceOp>>(
+      patterns.getContext());
 
   // TODO: add simplify patterns for all-gather and other collectives.
 
diff --git a/mlir/test/Dialect/Shard/simplify.mlir b/mlir/test/Dialect/Shard/simplify.mlir
index e5693a288fda6..181ccde98c505 100644
--- a/mlir/test/Dialect/Shard/simplify.mlir
+++ b/mlir/test/Dialect/Shard/simplify.mlir
@@ -1,3 +1,105 @@
+// RUN: mlir-opt %s -shard-simplify | FileCheck %s
+
+shard.grid @grid_ag(shape = 2x2)
+shard.grid @grid_ag_alt(shape = 2x2)
+
+// CHECK-LABEL: func.func @all_gather_all_slice_identity
+func.func @all_gather_all_slice_identity(
+    %arg0: tensor<4x4xf32>) -> tensor<4x4xf32> {
+  %0 = shard.all_gather %arg0 on @grid_ag grid_axes = [1] gather_axis = 1
+    : tensor<4x4xf32> -> tensor<4x8xf32>
+  %1 = shard.all_slice %0 on @grid_ag grid_axes = [1] slice_axis = 1
+    : tensor<4x8xf32> -> tensor<4x4xf32>
+  // CHECK-NOT: shard.all_gather
+  // CHECK-NOT: shard.all_slice
+  // CHECK: return %arg0 : tensor<4x4xf32>
+  return %1 : tensor<4x4xf32>
+}
+
+// CHECK-LABEL: func.func @all_gather_all_slice_different_axis
+func.func @all_gather_all_slice_different_axis(
+    %arg0: tensor<4x4xf32>) -> tensor<2x8xf32> {
+  %0 = shard.all_gather %arg0 on @grid_ag grid_axes = [1] gather_axis = 1
+    : tensor<4x4xf32> -> tensor<4x8xf32>
+  %1 = shard.all_slice %0 on @grid_ag grid_axes = [1] slice_axis = 0
+    : tensor<4x8xf32> -> tensor<2x8xf32>
+  // CHECK: shard.all_gather
+  // CHECK: shard.all_slice
+  return %1 : tensor<2x8xf32>
+}
+
+// CHECK-LABEL: func.func @all_gather_all_slice_different_grid_axes
+func.func @all_gather_all_slice_different_grid_axes(
+    %arg0: tensor<4x4xf32>) -> tensor<4x4xf32> {
+  %0 = shard.all_gather %arg0 on @grid_ag grid_axes = [0] gather_axis = 0
+    : tensor<4x4xf32> -> tensor<8x4xf32>
+  %1 = shard.all_slice %0 on @grid_ag grid_axes = [1] slice_axis = 0
+    : tensor<8x4xf32> -> tensor<4x4xf32>
+  // CHECK: shard.all_gather
+  // CHECK: shard.all_slice
+  return %1 : tensor<4x4xf32>
+}
+
+// CHECK-LABEL: func.func @all_gather_all_slice_different_grid
+func.func @all_gather_all_slice_different_grid(
+    %arg0: tensor<4x4xf32>) -> tensor<4x4xf32> {
+  %0 = shard.all_gather %arg0 on @grid_ag grid_axes = [1] gather_axis = 1
+    : tensor<4x4xf32> -> tensor<4x8xf32>
+  %1 = shard.all_slice %0 on @grid_ag_alt grid_axes = [1] slice_axis = 1
+    : tensor<4x8xf32> -> tensor<4x4xf32>
+  // CHECK: shard.all_gather
+  // CHECK: shard.all_slice
+  return %1 : tensor<4x4xf32>
+}
+
+// CHECK-LABEL: func.func @all_slice_all_gather_identity
+func.func @all_slice_all_gather_identity(
+    %arg0: tensor<4x4xf32>) -> tensor<4x4xf32> {
+  %0 = shard.all_slice %arg0 on @grid_ag grid_axes = [1] slice_axis = 1
+    : tensor<4x4xf32> -> tensor<4x2xf32>
+  %1 = shard.all_gather %0 on @grid_ag grid_axes = [1] gather_axis = 1
+    : tensor<4x2xf32> -> tensor<4x4xf32>
+  // CHECK-NOT: shard.all_slice
+  // CHECK-NOT: shard.all_gather
+  // CHECK: return %arg0 : tensor<4x4xf32>
+  return %1 : tensor<4x4xf32>
+}
+
+// CHECK-LABEL: func.func @all_slice_all_gather_different_axis
+func.func @all_slice_all_gather_different_axis(
+    %arg0: tensor<4x4xf32>) -> tensor<8x2xf32> {
+  %0 = shard.all_slice %arg0 on @grid_ag grid_axes = [1] slice_axis = 1
+    : tensor<4x4xf32> -> tensor<4x2xf32>
+  %1 = shard.all_gather %0 on @grid_ag grid_axes = [1] gather_axis = 0
+    : tensor<4x2xf32> -> tensor<8x2xf32>
+  // CHECK: shard.all_slice
+  // CHECK: shard.all_gather
+  return %1 : tensor<8x2xf32>
+}
+
+// CHECK-LABEL: func.func @all_slice_all_gather_different_grid
+func.func @all_slice_all_gather_different_grid(
+    %arg0: tensor<4x4xf32>) -> tensor<4x4xf32> {
+  %0 = shard.all_slice %arg0 on @grid_ag grid_axes = [1] slice_axis = 1
+    : tensor<4x4xf32> -> tensor<4x2xf32>
+  %1 = shard.all_gather %0 on @grid_ag_alt grid_axes = [1] gather_axis = 1
+    : tensor<4x2xf32> -> tensor<4x4xf32>
+  // CHECK: shard.all_slice
+  // CHECK: shard.all_gather
+  return %1 : tensor<4x4xf32>
+}
+
+// CHECK-LABEL: func.func @all_slice_all_gather_different_grid_axes
+func.func @all_slice_all_gather_different_grid_axes(
+    %arg0: tensor<4x4xf32>) -> tensor<4x4xf32> {
+  %0 = shard.all_slice %arg0 on @grid_ag grid_axes = [0] slice_axis = 0
+    : tensor<4x4xf32> -> tensor<2x4xf32>
+  %1 = shard.all_gather %0 on @grid_ag grid_axes = [1] gather_axis = 0
+    : tensor<2x4xf32> -> tensor<4x4xf32>
+  // CHECK: shard.all_slice
+  // CHECK: shard.all_gather
+  return %1 : tensor<4x4xf32>
+}
 // RUN: mlir-opt -shard-simplify %s | FileCheck %s
 
 shard.grid @grid0(shape = 4x2)

>From c407ed515fc8947cd960a9afbae0ac89c8f6fcb0 Mon Sep 17 00:00:00 2001
From: Cursor Agent <cursoragent at cursor.com>
Date: Wed, 29 Apr 2026 20:04:15 +0000
Subject: [PATCH 2/2] [MLIR][Shard] Fold and refactor inverse
 all_gather/all_slice patterns

Co-authored-by: zackc6 <zackc6 at users.noreply.github.com>
---
 mlir/lib/Dialect/Shard/Transforms/Simplify.cpp | 7 +++----
 1 file changed, 3 insertions(+), 4 deletions(-)

diff --git a/mlir/lib/Dialect/Shard/Transforms/Simplify.cpp b/mlir/lib/Dialect/Shard/Transforms/Simplify.cpp
index aa169dda47a1d..acfc020948a30 100644
--- a/mlir/lib/Dialect/Shard/Transforms/Simplify.cpp
+++ b/mlir/lib/Dialect/Shard/Transforms/Simplify.cpp
@@ -203,10 +203,9 @@ void populateSimplifyPatterns(RewritePatternSet &patterns,
   populateAllReduceEndomorphismSimplifyPatterns<arith::MaxUIOp>(
       patterns, ReductionKind::Max);
 
-  patterns.add<
-      AllReduceAllSliceSimplification,
-      InverseAllGatherAllSliceSimplification<AllSliceOp, AllGatherOp>,
-      InverseAllGatherAllSliceSimplification<AllGatherOp, AllSliceOp>>(
+  patterns.add<AllReduceAllSliceSimplification,
+               InverseAllGatherAllSliceSimplification<AllSliceOp, AllGatherOp>,
+               InverseAllGatherAllSliceSimplification<AllGatherOp, AllSliceOp>>(
       patterns.getContext());
 
   // TODO: add simplify patterns for all-gather and other collectives.



More information about the Mlir-commits mailing list