[Mlir-commits] [mlir] [MLIR][Shard] Fold all_gather/all_slice inverse pairs (PR #193906)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Wed Apr 29 13:05:20 PDT 2026
https://github.com/cursor[bot] updated https://github.com/llvm/llvm-project/pull/193906
>From 0e40bd1b567e8177f3fc20601c9140cd622d2479 Mon Sep 17 00:00:00 2001
From: Cursor Agent <cursoragent at cursor.com>
Date: Wed, 29 Apr 2026 19:38:37 +0000
Subject: [PATCH 1/2] [MLIR][Shard] Fold and refactor inverse
all_gather/all_slice patterns
Co-authored-by: zackc6 <zackc6 at users.noreply.github.com>
---
.../lib/Dialect/Shard/Transforms/Simplify.cpp | 59 +++++++++-
mlir/test/Dialect/Shard/simplify.mlir | 102 ++++++++++++++++++
2 files changed, 158 insertions(+), 3 deletions(-)
diff --git a/mlir/lib/Dialect/Shard/Transforms/Simplify.cpp b/mlir/lib/Dialect/Shard/Transforms/Simplify.cpp
index 525ff007bc2f6..aa169dda47a1d 100644
--- a/mlir/lib/Dialect/Shard/Transforms/Simplify.cpp
+++ b/mlir/lib/Dialect/Shard/Transforms/Simplify.cpp
@@ -19,6 +19,7 @@
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include <numeric>
+#include <type_traits>
namespace mlir {
namespace shard {
@@ -28,6 +29,42 @@ namespace shard {
namespace {
+template <typename LhsOp, typename RhsOp>
+static bool haveSameGridAndGridAxes(LhsOp lhsOp, RhsOp rhsOp) {
+ return lhsOp.getGrid() == rhsOp.getGrid() &&
+ lhsOp.getGridAxes() == rhsOp.getGridAxes();
+}
+
+static bool areInverseAllGatherAllSlice(AllGatherOp gatherOp,
+ AllSliceOp sliceOp) {
+ return haveSameGridAndGridAxes(gatherOp, sliceOp) &&
+ gatherOp.getGatherAxis() == sliceOp.getSliceAxis();
+}
+
+template <typename OuterOp, typename InnerOp>
+static LogicalResult foldInverseAllGatherAllSlice(OuterOp outerOp,
+ InnerOp innerOp,
+ PatternRewriter &rewriter) {
+ if (!innerOp)
+ return failure();
+
+ AllGatherOp gatherOp;
+ AllSliceOp sliceOp;
+ if constexpr (std::is_same_v<OuterOp, AllGatherOp>) {
+ gatherOp = outerOp;
+ sliceOp = innerOp;
+ } else {
+ gatherOp = innerOp;
+ sliceOp = outerOp;
+ }
+
+ if (!areInverseAllGatherAllSlice(gatherOp, sliceOp))
+ return failure();
+
+ rewriter.replaceOp(outerOp, innerOp.getInput());
+ return success();
+}
+
// This folding can not be done with an operation's fold method or
// DialectFoldInterface, because it needs a SymbolTableCollection to cache the
// symbol tables.
@@ -117,8 +154,7 @@ struct AllReduceAllSliceSimplification : OpRewritePattern<AllSliceOp> {
return failure();
// Both ops must operate on the same grid and grid axes.
- if (reduceOp.getGrid() != sliceOp.getGrid() ||
- reduceOp.getGridAxes() != sliceOp.getGridAxes())
+ if (!haveSameGridAndGridAxes(reduceOp, sliceOp))
return failure();
// Replace with a single ReduceScatterOp.
@@ -131,6 +167,19 @@ struct AllReduceAllSliceSimplification : OpRewritePattern<AllSliceOp> {
}
};
+// Simplify all_slice(all_gather(x)) and all_gather(all_slice(x)) to x when
+// both ops share grid, grid_axes, and axis.
+template <typename OuterOp, typename InnerOp>
+struct InverseAllGatherAllSliceSimplification : OpRewritePattern<OuterOp> {
+ using OpRewritePattern<OuterOp>::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(OuterOp outerOp,
+ PatternRewriter &rewriter) const override {
+ auto innerOp = outerOp.getInput().template getDefiningOp<InnerOp>();
+ return foldInverseAllGatherAllSlice(outerOp, innerOp, rewriter);
+ }
+};
+
} // namespace
void populateSimplifyPatterns(RewritePatternSet &patterns,
@@ -154,7 +203,11 @@ void populateSimplifyPatterns(RewritePatternSet &patterns,
populateAllReduceEndomorphismSimplifyPatterns<arith::MaxUIOp>(
patterns, ReductionKind::Max);
- patterns.add<AllReduceAllSliceSimplification>(patterns.getContext());
+ patterns.add<
+ AllReduceAllSliceSimplification,
+ InverseAllGatherAllSliceSimplification<AllSliceOp, AllGatherOp>,
+ InverseAllGatherAllSliceSimplification<AllGatherOp, AllSliceOp>>(
+ patterns.getContext());
// TODO: add simplify patterns for all-gather and other collectives.
diff --git a/mlir/test/Dialect/Shard/simplify.mlir b/mlir/test/Dialect/Shard/simplify.mlir
index e5693a288fda6..181ccde98c505 100644
--- a/mlir/test/Dialect/Shard/simplify.mlir
+++ b/mlir/test/Dialect/Shard/simplify.mlir
@@ -1,3 +1,105 @@
+// RUN: mlir-opt %s -shard-simplify | FileCheck %s
+
+shard.grid @grid_ag(shape = 2x2)
+shard.grid @grid_ag_alt(shape = 2x2)
+
+// CHECK-LABEL: func.func @all_gather_all_slice_identity
+func.func @all_gather_all_slice_identity(
+ %arg0: tensor<4x4xf32>) -> tensor<4x4xf32> {
+ %0 = shard.all_gather %arg0 on @grid_ag grid_axes = [1] gather_axis = 1
+ : tensor<4x4xf32> -> tensor<4x8xf32>
+ %1 = shard.all_slice %0 on @grid_ag grid_axes = [1] slice_axis = 1
+ : tensor<4x8xf32> -> tensor<4x4xf32>
+ // CHECK-NOT: shard.all_gather
+ // CHECK-NOT: shard.all_slice
+ // CHECK: return %arg0 : tensor<4x4xf32>
+ return %1 : tensor<4x4xf32>
+}
+
+// CHECK-LABEL: func.func @all_gather_all_slice_different_axis
+func.func @all_gather_all_slice_different_axis(
+ %arg0: tensor<4x4xf32>) -> tensor<2x8xf32> {
+ %0 = shard.all_gather %arg0 on @grid_ag grid_axes = [1] gather_axis = 1
+ : tensor<4x4xf32> -> tensor<4x8xf32>
+ %1 = shard.all_slice %0 on @grid_ag grid_axes = [1] slice_axis = 0
+ : tensor<4x8xf32> -> tensor<2x8xf32>
+ // CHECK: shard.all_gather
+ // CHECK: shard.all_slice
+ return %1 : tensor<2x8xf32>
+}
+
+// CHECK-LABEL: func.func @all_gather_all_slice_different_grid_axes
+func.func @all_gather_all_slice_different_grid_axes(
+ %arg0: tensor<4x4xf32>) -> tensor<4x4xf32> {
+ %0 = shard.all_gather %arg0 on @grid_ag grid_axes = [0] gather_axis = 0
+ : tensor<4x4xf32> -> tensor<8x4xf32>
+ %1 = shard.all_slice %0 on @grid_ag grid_axes = [1] slice_axis = 0
+ : tensor<8x4xf32> -> tensor<4x4xf32>
+ // CHECK: shard.all_gather
+ // CHECK: shard.all_slice
+ return %1 : tensor<4x4xf32>
+}
+
+// CHECK-LABEL: func.func @all_gather_all_slice_different_grid
+func.func @all_gather_all_slice_different_grid(
+ %arg0: tensor<4x4xf32>) -> tensor<4x4xf32> {
+ %0 = shard.all_gather %arg0 on @grid_ag grid_axes = [1] gather_axis = 1
+ : tensor<4x4xf32> -> tensor<4x8xf32>
+ %1 = shard.all_slice %0 on @grid_ag_alt grid_axes = [1] slice_axis = 1
+ : tensor<4x8xf32> -> tensor<4x4xf32>
+ // CHECK: shard.all_gather
+ // CHECK: shard.all_slice
+ return %1 : tensor<4x4xf32>
+}
+
+// CHECK-LABEL: func.func @all_slice_all_gather_identity
+func.func @all_slice_all_gather_identity(
+ %arg0: tensor<4x4xf32>) -> tensor<4x4xf32> {
+ %0 = shard.all_slice %arg0 on @grid_ag grid_axes = [1] slice_axis = 1
+ : tensor<4x4xf32> -> tensor<4x2xf32>
+ %1 = shard.all_gather %0 on @grid_ag grid_axes = [1] gather_axis = 1
+ : tensor<4x2xf32> -> tensor<4x4xf32>
+ // CHECK-NOT: shard.all_slice
+ // CHECK-NOT: shard.all_gather
+ // CHECK: return %arg0 : tensor<4x4xf32>
+ return %1 : tensor<4x4xf32>
+}
+
+// CHECK-LABEL: func.func @all_slice_all_gather_different_axis
+func.func @all_slice_all_gather_different_axis(
+ %arg0: tensor<4x4xf32>) -> tensor<8x2xf32> {
+ %0 = shard.all_slice %arg0 on @grid_ag grid_axes = [1] slice_axis = 1
+ : tensor<4x4xf32> -> tensor<4x2xf32>
+ %1 = shard.all_gather %0 on @grid_ag grid_axes = [1] gather_axis = 0
+ : tensor<4x2xf32> -> tensor<8x2xf32>
+ // CHECK: shard.all_slice
+ // CHECK: shard.all_gather
+ return %1 : tensor<8x2xf32>
+}
+
+// CHECK-LABEL: func.func @all_slice_all_gather_different_grid
+func.func @all_slice_all_gather_different_grid(
+ %arg0: tensor<4x4xf32>) -> tensor<4x4xf32> {
+ %0 = shard.all_slice %arg0 on @grid_ag grid_axes = [1] slice_axis = 1
+ : tensor<4x4xf32> -> tensor<4x2xf32>
+ %1 = shard.all_gather %0 on @grid_ag_alt grid_axes = [1] gather_axis = 1
+ : tensor<4x2xf32> -> tensor<4x4xf32>
+ // CHECK: shard.all_slice
+ // CHECK: shard.all_gather
+ return %1 : tensor<4x4xf32>
+}
+
+// CHECK-LABEL: func.func @all_slice_all_gather_different_grid_axes
+func.func @all_slice_all_gather_different_grid_axes(
+ %arg0: tensor<4x4xf32>) -> tensor<4x4xf32> {
+ %0 = shard.all_slice %arg0 on @grid_ag grid_axes = [0] slice_axis = 0
+ : tensor<4x4xf32> -> tensor<2x4xf32>
+ %1 = shard.all_gather %0 on @grid_ag grid_axes = [1] gather_axis = 0
+ : tensor<2x4xf32> -> tensor<4x4xf32>
+ // CHECK: shard.all_slice
+ // CHECK: shard.all_gather
+ return %1 : tensor<4x4xf32>
+}
// RUN: mlir-opt -shard-simplify %s | FileCheck %s
shard.grid @grid0(shape = 4x2)
>From c407ed515fc8947cd960a9afbae0ac89c8f6fcb0 Mon Sep 17 00:00:00 2001
From: Cursor Agent <cursoragent at cursor.com>
Date: Wed, 29 Apr 2026 20:04:15 +0000
Subject: [PATCH 2/2] [MLIR][Shard] Fold and refactor inverse
all_gather/all_slice patterns
Co-authored-by: zackc6 <zackc6 at users.noreply.github.com>
---
mlir/lib/Dialect/Shard/Transforms/Simplify.cpp | 7 +++----
1 file changed, 3 insertions(+), 4 deletions(-)
diff --git a/mlir/lib/Dialect/Shard/Transforms/Simplify.cpp b/mlir/lib/Dialect/Shard/Transforms/Simplify.cpp
index aa169dda47a1d..acfc020948a30 100644
--- a/mlir/lib/Dialect/Shard/Transforms/Simplify.cpp
+++ b/mlir/lib/Dialect/Shard/Transforms/Simplify.cpp
@@ -203,10 +203,9 @@ void populateSimplifyPatterns(RewritePatternSet &patterns,
populateAllReduceEndomorphismSimplifyPatterns<arith::MaxUIOp>(
patterns, ReductionKind::Max);
- patterns.add<
- AllReduceAllSliceSimplification,
- InverseAllGatherAllSliceSimplification<AllSliceOp, AllGatherOp>,
- InverseAllGatherAllSliceSimplification<AllGatherOp, AllSliceOp>>(
+ patterns.add<AllReduceAllSliceSimplification,
+ InverseAllGatherAllSliceSimplification<AllSliceOp, AllGatherOp>,
+ InverseAllGatherAllSliceSimplification<AllGatherOp, AllSliceOp>>(
patterns.getContext());
// TODO: add simplify patterns for all-gather and other collectives.
More information about the Mlir-commits
mailing list