[Mlir-commits] [mlir] [mlir][mesh] Add verification and canonicalization for some collectives (PR #74905)

Boian Petkantchin llvmlistbot at llvm.org
Thu Dec 14 11:06:04 PST 2023


================
@@ -567,26 +653,74 @@ void ReduceScatterOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
 //===----------------------------------------------------------------------===//
 
 LogicalResult ScatterOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
-  // TODO
-  return failure();
+  // Resolve the referenced mesh symbol and validate the grouping mesh axes
+  // against it.
+  auto mesh = getMeshAndVerifyAxes(*this, symbolTable);
+  if (failed(mesh)) {
+    return failure();
+  }
+  // Computed once and reused by every check below.
+  auto meshShape = mesh.value().canonicalDimSizes();
+  // Check the root device (static and dynamic parts) against the grouping
+  // mesh axes and the mesh shape.
+  if (failed(verifyInGroupDevice(getLoc(), getRootAttrName(), getRoot(),
+                                 getRootDynamic(), getMeshAxes(), meshShape))) {
+    return failure();
+  }
+
+  // Finally, check that the input and result shapes agree along the scatter
+  // axis for the given grouping.
+  auto scatterAxis = getScatterAxis().getSExtValue();
+  return verifyScatterOperandAndResultShape(getInput(), getResult(),
+                                            scatterAxis, getMeshAxes(),
+                                            meshShape);
+}
+
+void ScatterOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
+                                            MLIRContext *context) {
+  patterns.add<EmptyMeshAxesCanonicalizationPattern<ScatterOp>>(context);
 }
 
 //===----------------------------------------------------------------------===//
 // mesh.send op
 //===----------------------------------------------------------------------===//
 
 LogicalResult SendOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
-  // TODO
-  return failure();
+  // Look up the mesh symbol; stop here if it is missing or the grouping mesh
+  // axes are invalid for it.
+  auto meshOp = getMeshAndVerifyAxes(*this, symbolTable);
+  if (failed(meshOp))
+    return failure();
+  // The verification result of the destination device (static and dynamic
+  // parts, checked against the grouping axes and the mesh shape) is the
+  // result of the whole symbol-use check.
+  return verifyInGroupDevice(getLoc(), getDestinationAttrName(),
+                             getDestination(), getDestinationDynamic(),
+                             getMeshAxes(), meshOp.value().canonicalDimSizes());
+}
+
+void SendOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
+                                         MLIRContext *context) {
+  patterns.add<EmptyMeshAxesCanonicalizationPattern<SendOp>>(context);
 }
 
 //===----------------------------------------------------------------------===//
 // mesh.shift op
 //===----------------------------------------------------------------------===//
 
 LogicalResult ShiftOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
-  // TODO
-  return failure();
+  // Resolve the referenced mesh symbol and validate the grouping mesh axes
+  // against it.
+  auto mesh = getMeshAndVerifyAxes(*this, symbolTable);
+  if (failed(mesh)) {
+    return failure();
+  }
+
+  // The shift axis must be one of the grouping mesh axes.
+  auto meshAxes = getMeshAxes();
+  auto shiftAxis = getShiftAxis().getZExtValue();
+  if (!llvm::is_contained(meshAxes, shiftAxis)) {
+    return emitError() << "Invalid shift axis " << shiftAxis
+                       << ". It must be one of the grouping mesh axes.";
+  }
+
+  return success();
+}
+
+void ShiftOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
+                                          MLIRContext *context) {
+  // TODO: remove op when offset is 0 or if it is a rotate with an
+  // offset % shift_axis_mesh_dim_size == 0.
----------------
sogartar wrote:

Fixed it.

https://github.com/llvm/llvm-project/pull/74905


More information about the Mlir-commits mailing list