[Mlir-commits] [mlir] [MLIR][Shard] Fold all_gather/all_slice inverse pairs (PR #193906)
Frank Schlimbach
llvmlistbot at llvm.org
Thu Apr 30 00:18:21 PDT 2026
================
@@ -28,6 +29,42 @@ namespace shard {
namespace {
+template <typename LhsOp, typename RhsOp>
+static bool haveSameGridAndGridAxes(LhsOp lhsOp, RhsOp rhsOp) {
+  return lhsOp.getGrid() == rhsOp.getGrid() &&
+         lhsOp.getGridAxes() == rhsOp.getGridAxes();
+}
+
+static bool areInverseAllGatherAllSlice(AllGatherOp gatherOp,
+                                        AllSliceOp sliceOp) {
+  return haveSameGridAndGridAxes(gatherOp, sliceOp) &&
+         gatherOp.getGatherAxis() == sliceOp.getSliceAxis();
+}
+
+template <typename OuterOp, typename InnerOp>
+static LogicalResult foldInverseAllGatherAllSlice(OuterOp outerOp,
+                                                  InnerOp innerOp,
+                                                  PatternRewriter &rewriter) {
+  if (!innerOp)
----------------
fschlimb wrote:
```suggestion
if (!innerOp || !outerOp)
```
https://github.com/llvm/llvm-project/pull/193906