[Mlir-commits] [mlir] [MLIR][Shard] Fold all_gather/all_slice inverse pairs (PR #193906)

Frank Schlimbach llvmlistbot at llvm.org
Thu Apr 30 04:03:44 PDT 2026


================
@@ -28,6 +29,42 @@ namespace shard {
 
 namespace {
 
+template <typename LhsOp, typename RhsOp>
+static bool haveSameGridAndGridAxes(LhsOp lhsOp, RhsOp rhsOp) {
+  return lhsOp.getGrid() == rhsOp.getGrid() &&
+         lhsOp.getGridAxes() == rhsOp.getGridAxes();
+}
+
+static bool areInverseAllGatherAllSlice(AllGatherOp gatherOp,
+                                        AllSliceOp sliceOp) {
+  return haveSameGridAndGridAxes(gatherOp, sliceOp) &&
+         gatherOp.getGatherAxis() == sliceOp.getSliceAxis();
+}
+
+template <typename OuterOp, typename InnerOp>
+static LogicalResult foldInverseAllGatherAllSlice(OuterOp outerOp,
----------------
fschlimb wrote:

Thanks for the refactor.
Can we remove "inverse" from the func name and the pattern class?

https://github.com/llvm/llvm-project/pull/193906
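The fold in this hunk relies on one property. An all_gather followed by an all_slice (or the reverse) cancels out when both ops use the same grid, the same grid axes, and the same tensor axis. The sketch below is a toy, self-contained model of that property. It is not the MLIR Shard dialect API. `CollectiveAttrs`, `areInversePair`, `allGather`, and `allSlice` are hypothetical names, and the model assumes a 1-D grid, a 1-D tensor, and equal-sized shards.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <string>
#include <vector>

// Hypothetical stand-in for the attributes the fold compares
// (not the real Shard dialect op accessors).
struct CollectiveAttrs {
  std::string grid;              // symbol name of the device grid
  std::vector<int64_t> gridAxes; // grid axes the collective runs over
  int64_t tensorAxis;            // gather axis or slice axis
};

// Mirrors the checks in haveSameGridAndGridAxes plus the gather/slice axis
// comparison in areInverseAllGatherAllSlice from the diff.
bool areInversePair(const CollectiveAttrs &gather,
                    const CollectiveAttrs &slice) {
  return gather.grid == slice.grid && gather.gridAxes == slice.gridAxes &&
         gather.tensorAxis == slice.tensorAxis;
}

// One local tensor per process on a 1-D grid.
using Shard = std::vector<int>;

// Toy all_gather: every process receives the concatenation of all shards.
std::vector<Shard> allGather(const std::vector<Shard> &shards) {
  Shard full;
  for (const Shard &s : shards)
    full.insert(full.end(), s.begin(), s.end());
  return std::vector<Shard>(shards.size(), full);
}

// Toy all_slice: process i keeps the i-th equal-sized chunk of its input.
// Assumes each input length is divisible by the number of processes.
std::vector<Shard> allSlice(const std::vector<Shard> &inputs) {
  std::vector<Shard> out;
  const std::size_t n = inputs.size();
  for (std::size_t i = 0; i < n; ++i) {
    const Shard &full = inputs[i];
    const std::size_t chunk = full.size() / n;
    auto first = full.begin() + static_cast<std::ptrdiff_t>(i * chunk);
    auto last = first + static_cast<std::ptrdiff_t>(chunk);
    out.emplace_back(first, last);
  }
  return out;
}
```

In this model, `allSlice(allGather(x))` returns the original shards for any `x`. `allGather(allSlice(x))` returns `x` only when `x` is replicated across the grid, because each process slices its own local copy.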
