[Mlir-commits] [mlir] [MLIR][Shard] Fold all_slice(all_gather(...)) pairs (PR #193906)

Frank Schlimbach llvmlistbot at llvm.org
Wed Apr 29 02:20:45 PDT 2026


================
@@ -131,6 +131,29 @@ struct AllReduceAllSliceSimplification : OpRewritePattern<AllSliceOp> {
   }
 };
 
+// Simplify AllSliceOp(AllGatherOp) -> input when both ops share the same grid,
+// grid_axes and axis. all_gather replicates grouped slices along gather_axis
+// and all_slice immediately picks the per-rank slice back out on the same axis.
+struct AllGatherAllSliceSimplification : OpRewritePattern<AllSliceOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(AllSliceOp sliceOp,
+                                PatternRewriter &rewriter) const override {
+    auto gatherOp = sliceOp.getInput().getDefiningOp<AllGatherOp>();
+    if (!gatherOp)
+      return failure();
+
+    if (gatherOp.getGrid() != sliceOp.getGrid() ||
+        gatherOp.getGridAxes() != sliceOp.getGridAxes())
+      return failure();
+    if (gatherOp.getGatherAxis() != sliceOp.getSliceAxis())
+      return failure();
----------------
fschlimb wrote:

nit: either split these conditions into three separate checks (grid, grid_axes, axis) or combine them into a single check. I'd prefer the latter.

https://github.com/llvm/llvm-project/pull/193906


More information about the Mlir-commits mailing list