[Mlir-commits] [mlir] [MLIR][Shard] Fold all_slice(all_gather(...)) pairs (PR #193906)
Frank Schlimbach
llvmlistbot at llvm.org
Wed Apr 29 02:20:45 PDT 2026
================
@@ -131,6 +131,29 @@ struct AllReduceAllSliceSimplification : OpRewritePattern<AllSliceOp> {
}
};
+// Simplify AllSliceOp(AllGatherOp) -> input when both ops share the same grid,
+// grid_axes and axis. all_gather replicates grouped slices along gather_axis
+// and all_slice immediately picks the per-rank slice back out on the same axis.
+struct AllGatherAllSliceSimplification : OpRewritePattern<AllSliceOp> {
+ using OpRewritePattern::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(AllSliceOp sliceOp,
+ PatternRewriter &rewriter) const override {
+ auto gatherOp = sliceOp.getInput().getDefiningOp<AllGatherOp>();
+ if (!gatherOp)
+ return failure();
+
+ if (gatherOp.getGrid() != sliceOp.getGrid() ||
+ gatherOp.getGridAxes() != sliceOp.getGridAxes())
+ return failure();
+ if (gatherOp.getGatherAxis() != sliceOp.getSliceAxis())
+ return failure();
----------------
fschlimb wrote:
nit: either split these into three separate checks (grid, grid_axes, axis) or merge them into a single combined check. I'd prefer the single check.
https://github.com/llvm/llvm-project/pull/193906
More information about the Mlir-commits
mailing list