[Mlir-commits] [mlir] [mlir][mesh] Add verification and canonicalization for some collectives (PR #74905)
Boian Petkantchin
llvmlistbot at llvm.org
Thu Dec 14 11:06:04 PST 2023
================
@@ -567,26 +653,74 @@ void ReduceScatterOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
//===----------------------------------------------------------------------===//
LogicalResult ScatterOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
- // TODO
- return failure();
+ // Look up the referenced mesh and verify the op's mesh axes against it.
+ auto mesh = getMeshAndVerifyAxes(*this, symbolTable);
+ if (failed(mesh)) {
+ return failure();
+ }
+ // Computed once and shared by both checks below.
+ auto meshShape = mesh.value().canonicalDimSizes();
+ // Validate the root device (the static `getRoot()` value and the dynamic
+ // `getRootDynamic()` values) with respect to the grouping axes and the
+ // mesh shape.
+ if (failed(verifyInGroupDevice(getLoc(), getRootAttrName(), getRoot(),
+ getRootDynamic(), getMeshAxes(), meshShape))) {
+ return failure();
+ }
+
+ // Check the input/result shapes along the scatter axis. Reuse `meshShape`
+ // instead of querying canonicalDimSizes() a second time.
+ auto scatterAxis = getScatterAxis().getSExtValue();
+ return verifyScatterOperandAndResultShape(getInput(), getResult(),
+ scatterAxis, getMeshAxes(), meshShape);
+}
+
+ void ScatterOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
+ MLIRContext *context) {
+ // Registers EmptyMeshAxesCanonicalizationPattern for mesh.scatter.
+ // NOTE(review): presumably rewrites the op when its mesh axes list is
+ // empty — confirm against the pattern's definition.
+ patterns.add<EmptyMeshAxesCanonicalizationPattern<ScatterOp>>(context);
}
//===----------------------------------------------------------------------===//
// mesh.send op
//===----------------------------------------------------------------------===//
LogicalResult SendOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
- // TODO
- return failure();
+ // Resolve the mesh symbol and validate the mesh axes; bail out early if
+ // either step fails.
+ auto meshOrFailure = getMeshAndVerifyAxes(*this, symbolTable);
+ if (failed(meshOrFailure))
+ return failure();
+
+ // The outcome of the destination check (static `getDestination()` plus
+ // dynamic `getDestinationDynamic()`) is the outcome of the whole verifier.
+ auto shape = meshOrFailure.value().canonicalDimSizes();
+ return verifyInGroupDevice(getLoc(), getDestinationAttrName(),
+ getDestination(), getDestinationDynamic(),
+ getMeshAxes(), shape);
+}
+
+ void SendOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
+ MLIRContext *context) {
+ // Registers EmptyMeshAxesCanonicalizationPattern for mesh.send.
+ // NOTE(review): presumably rewrites the op when its mesh axes list is
+ // empty — confirm against the pattern's definition.
+ patterns.add<EmptyMeshAxesCanonicalizationPattern<SendOp>>(context);
}
//===----------------------------------------------------------------------===//
// mesh.shift op
//===----------------------------------------------------------------------===//
LogicalResult ShiftOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
- // TODO
- return failure();
+ // Resolve the mesh symbol and validate the grouping mesh axes first; the
+ // resolved mesh itself is not needed afterwards.
+ if (failed(getMeshAndVerifyAxes(*this, symbolTable)))
+ return failure();
+
+ // The shift axis must be one of the axes that form the device group.
+ auto groupAxes = getMeshAxes();
+ auto axis = getShiftAxis().getZExtValue();
+ if (!llvm::is_contained(groupAxes, axis))
+ return emitError() << "Invalid shift axis " << axis
+ << ". It must be one of the grouping mesh axes.";
+
+ return success();
+}
+
+void ShiftOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
+ MLIRContext *context) {
+ // TODO: remove op when offset is 0 or if it is a rotate with an
+ // offset % shift_axis_mesh_dim_size == 0.
----------------
sogartar wrote:
Fixed it.
https://github.com/llvm/llvm-project/pull/74905
More information about the Mlir-commits mailing list