[Mlir-commits] [mlir] [mlir][mesh] Add verification and canonicalization for some collectives (PR #74905)

Chengji Yao llvmlistbot at llvm.org
Thu Dec 14 09:43:35 PST 2023


================
@@ -510,35 +537,94 @@ void AllToAllOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
 
 LogicalResult
 BroadcastOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
-  // TODO
-  return failure();
+  auto mesh = getMeshAndVerifyAxes(*this, symbolTable);
+  if (failed(mesh)) {
+    return failure();
+  }
+  auto meshShape = mesh.value().canonicalDimSizes();
+  if (failed(verifyInGroupDevice(getLoc(), getRootAttrName(), getRoot(),
+                                 getRootDynamic(), getMeshAxes(), meshShape))) {
+    return failure();
+  }
+
+  return success();
+}
+
+void BroadcastOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
+                                              MLIRContext *context) {
+  patterns.add<EmptyMeshAxesCanonicalizationPattern<BroadcastOp>>(context);
 }
 
 //===----------------------------------------------------------------------===//
 // mesh.gather op
 //===----------------------------------------------------------------------===//
 
 LogicalResult GatherOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
-  // TODO
-  return failure();
+  auto mesh = getMeshAndVerifyAxes(*this, symbolTable);
+  if (failed(mesh)) {
+    return failure();
+  }
+  auto meshShape = mesh.value().canonicalDimSizes();
+  if (failed(verifyInGroupDevice(getLoc(), getRootAttrName(), getRoot(),
+                                 getRootDynamic(), getMeshAxes(), meshShape))) {
+    return failure();
+  }
+
+  auto gatherAxis = getGatherAxis().getSExtValue();
+  return verifyGatherOperandAndResultShape(getInput(), getResult(), gatherAxis,
+                                           getMeshAxes(),
+                                           mesh.value().canonicalDimSizes());
+}
+
+void GatherOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
+                                           MLIRContext *context) {
+  patterns.add<EmptyMeshAxesCanonicalizationPattern<GatherOp>>(context);
 }
 
 //===----------------------------------------------------------------------===//
-// mesh.receive op
+// mesh.recv op
 //===----------------------------------------------------------------------===//
 
 LogicalResult RecvOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
-  // TODO
-  return failure();
+  auto mesh = getMeshAndVerifyAxes(*this, symbolTable);
+  if (failed(mesh)) {
+    return failure();
+  }
+  auto meshShape = mesh.value().canonicalDimSizes();
+  if (getSource() && failed(verifyInGroupDevice(
+                         getLoc(), getSourceAttrName(), getSource().value(),
+                         getSourceDynamic(), getMeshAxes(), meshShape))) {
+    return failure();
+  }
+  return success();
+}
+
+void RecvOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
+                                         MLIRContext *context) {
+  patterns.add<EmptyMeshAxesCanonicalizationPattern<RecvOp>>(context);
 }
 
 //===----------------------------------------------------------------------===//
 // mesh.reduce op
 //===----------------------------------------------------------------------===//
 
 LogicalResult ReduceOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
-  // TODO
-  return failure();
+  auto mesh = getMeshAndVerifyAxes(*this, symbolTable);
+  if (failed(mesh)) {
+    return failure();
+  }
+  auto meshShape = mesh.value().canonicalDimSizes();
+  if (failed(verifyInGroupDevice(getLoc(), getRootAttrName(), getRoot(),
----------------
yaochengji wrote:

Should this be `getRoot() && failed(verifyInGroupDevice(...))`, mirroring the `getSource() &&` guard used in `RecvOp::verifySymbolUses` above?

https://github.com/llvm/llvm-project/pull/74905


More information about the Mlir-commits mailing list