[Mlir-commits] [mlir] [MLIR][Shard] Fold all_gather/all_slice inverse pairs (PR #193906)

Frank Schlimbach llvmlistbot at llvm.org
Thu Apr 30 00:18:21 PDT 2026


================
@@ -28,6 +29,42 @@ namespace shard {
 
 namespace {
 
+template <typename LhsOp, typename RhsOp>
+static bool haveSameGridAndGridAxes(LhsOp lhsOp, RhsOp rhsOp) {
+  return lhsOp.getGrid() == rhsOp.getGrid() &&
+         lhsOp.getGridAxes() == rhsOp.getGridAxes();
+}
+
+static bool areInverseAllGatherAllSlice(AllGatherOp gatherOp,
+                                        AllSliceOp sliceOp) {
+  return haveSameGridAndGridAxes(gatherOp, sliceOp) &&
+         gatherOp.getGatherAxis() == sliceOp.getSliceAxis();
+}
+
+template <typename OuterOp, typename InnerOp>
+static LogicalResult foldInverseAllGatherAllSlice(OuterOp outerOp,
+                                                  InnerOp innerOp,
+                                                  PatternRewriter &rewriter) {
+  if (!innerOp)
----------------
fschlimb wrote:

```suggestion
  if (!innerOp || !outerOp)
```
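The hunk checks whether an all_gather/all_slice pair is inverse by comparing the grid, the grid axes, and the gather vs. slice axis. The suggestion adds a null check on `outerOp` as well as `innerOp`. The plain-C++ sketch below models that check. `GatherLike`, `SliceLike` and `areInverseGatherSlice` are hypothetical stand-ins. They are not the MLIR op classes or any shard dialect API, and they model only the attributes the diff compares. The sketch guards both operands before reading any field, in the same way as the suggested line.

```cpp
#include <cassert>
#include <cstdint>
#include <string>
#include <vector>

// Hypothetical plain-C++ stand-ins for shard.all_gather / shard.all_slice.
// They model only the attributes compared in the diff; NOT the MLIR API.
struct GatherLike {
  std::string grid;               // name of the grid the collective uses
  std::vector<int16_t> gridAxes;  // grid axes the collective runs over
  int64_t gatherAxis;             // tensor axis that is gathered
};

struct SliceLike {
  std::string grid;
  std::vector<int16_t> gridAxes;
  int64_t sliceAxis;              // tensor axis that is sliced
};

// Same shape as the templated helper in the diff: grid and grid axes match.
template <typename Lhs, typename Rhs>
static bool haveSameGridAndGridAxes(const Lhs &lhs, const Rhs &rhs) {
  return lhs.grid == rhs.grid && lhs.gridAxes == rhs.gridAxes;
}

// Null-safe inverse check: both pointers are tested before any field is
// read, mirroring `if (!innerOp || !outerOp)` from the review suggestion.
static bool areInverseGatherSlice(const GatherLike *gather,
                                  const SliceLike *slice) {
  if (!gather || !slice)
    return false;
  return haveSameGridAndGridAxes(*gather, *slice) &&
         gather->gatherAxis == slice->sliceAxis;
}
```

With this model, a pair on the same grid and axes with matching tensor axes counts as inverse. A missing operand, or any mismatched attribute, rejects the fold.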

https://github.com/llvm/llvm-project/pull/193906

