[Mlir-commits] [mlir] [mlir][mesh] Add verification and canonicalization for some collectives (PR #74905)
Chengji Yao
llvmlistbot at llvm.org
Thu Dec 14 09:43:35 PST 2023
================
@@ -510,35 +537,94 @@ void AllToAllOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
+/// Symbol verification for mesh.broadcast: resolve the mesh and check the
+/// mesh axes with getMeshAndVerifyAxes, then validate the root device
+/// (getRoot() together with getRootDynamic()) with verifyInGroupDevice
+/// against the mesh's canonical dimension sizes.
LogicalResult
BroadcastOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
- // TODO
- return failure();
+ auto meshOr = getMeshAndVerifyAxes(*this, symbolTable);
+ if (failed(meshOr)) {
+ return failure();
+ }
+ auto shape = meshOr->canonicalDimSizes();
+ // The root check is the last step, so its result is the verdict.
+ return verifyInGroupDevice(getLoc(), getRootAttrName(), getRoot(),
+ getRootDynamic(), getMeshAxes(), shape);
+}
+
+/// Registers EmptyMeshAxesCanonicalizationPattern for BroadcastOp.
+void BroadcastOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
+ MLIRContext *context) {
+ patterns.add<EmptyMeshAxesCanonicalizationPattern<BroadcastOp>>(context);
}
//===----------------------------------------------------------------------===//
// mesh.gather op
//===----------------------------------------------------------------------===//
+/// Symbol verification for mesh.gather: resolve the mesh and check the mesh
+/// axes with getMeshAndVerifyAxes, validate the root device (getRoot() plus
+/// getRootDynamic()) with verifyInGroupDevice, then check the input/result
+/// shapes along the gather axis with verifyGatherOperandAndResultShape.
LogicalResult GatherOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
- // TODO
- return failure();
+ auto mesh = getMeshAndVerifyAxes(*this, symbolTable);
+ if (failed(mesh)) {
+ return failure();
+ }
+ auto meshShape = mesh.value().canonicalDimSizes();
+ if (failed(verifyInGroupDevice(getLoc(), getRootAttrName(), getRoot(),
+ getRootDynamic(), getMeshAxes(), meshShape))) {
+ return failure();
+ }
+
+ auto gatherAxis = getGatherAxis().getSExtValue();
+ // Reuse meshShape instead of querying canonicalDimSizes() a second time, so
+ // both checks are guaranteed to see the same mesh shape.
+ return verifyGatherOperandAndResultShape(getInput(), getResult(), gatherAxis,
+ getMeshAxes(), meshShape);
+}
+
+/// Registers EmptyMeshAxesCanonicalizationPattern for GatherOp.
+void GatherOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
+ MLIRContext *context) {
+ patterns.add<EmptyMeshAxesCanonicalizationPattern<GatherOp>>(context);
}
//===----------------------------------------------------------------------===//
-// mesh.receive op
+// mesh.recv op
//===----------------------------------------------------------------------===//
+/// Symbol verification for mesh.recv: resolve the mesh and check the mesh
+/// axes with getMeshAndVerifyAxes. When the optional source attribute is
+/// present, the source device (getSource() plus getSourceDynamic()) is also
+/// validated with verifyInGroupDevice.
LogicalResult RecvOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
- // TODO
- return failure();
+ auto meshOr = getMeshAndVerifyAxes(*this, symbolTable);
+ if (failed(meshOr)) {
+ return failure();
+ }
+ auto shape = meshOr->canonicalDimSizes();
+ auto source = getSource();
+ // Without a static source attribute only the mesh/axes check applies.
+ if (!source) {
+ return success();
+ }
+ return verifyInGroupDevice(getLoc(), getSourceAttrName(), source.value(),
+ getSourceDynamic(), getMeshAxes(), shape);
+}
+
+/// Registers EmptyMeshAxesCanonicalizationPattern for RecvOp.
+void RecvOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
+ MLIRContext *context) {
+ patterns.add<EmptyMeshAxesCanonicalizationPattern<RecvOp>>(context);
}
//===----------------------------------------------------------------------===//
// mesh.reduce op
//===----------------------------------------------------------------------===//
LogicalResult ReduceOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
- // TODO
- return failure();
+ auto mesh = getMeshAndVerifyAxes(*this, symbolTable);
+ if (failed(mesh)) {
+ return failure();
+ }
+ auto meshShape = mesh.value().canonicalDimSizes();
+ if (failed(verifyInGroupDevice(getLoc(), getRootAttrName(), getRoot(),
----------------
yaochengji wrote:
`getRoot() && failed(verifyInGroupDevice(...))`?
https://github.com/llvm/llvm-project/pull/74905
More information about the Mlir-commits
mailing list