[Mlir-commits] [mlir] [mlir][mesh] Add collective communication operations (PR #71960)

Boian Petkantchin llvmlistbot at llvm.org
Tue Nov 14 11:53:11 PST 2023


================
@@ -129,6 +205,347 @@ MeshShardingAttr::verify(function_ref<InFlightDiagnostic()> emitError,
   return success();
 }
 
+//===----------------------------------------------------------------------===//
+// collective communication ops
+//===----------------------------------------------------------------------===//
+
+namespace {
+
+// Canonicalizes an attribute that holds a set of mesh axes.
+// Returns std::nullopt when `attr` is null or when the canonicalized set is
+// empty. Otherwise returns a new DenseI16ArrayAttr (in the same context as
+// `attr`) whose elements are in the form produced by `canonicalizeSetAsVector`.
+// That helper is defined elsewhere; presumably it sorts and de-duplicates, but
+// confirm against its definition.
+std::optional<DenseI16ArrayAttr>
+canonicalizeAxesSetAttribute(DenseI16ArrayAttr attr) {
+  if (!attr)
+    return std::nullopt;
+
+  ArrayRef<int16_t> rawAxes = attr.asArrayRef();
+  SmallVector<int16_t> canonicalAxes(rawAxes.begin(), rawAxes.end());
+  canonicalizeSetAsVector(canonicalAxes);
+
+  if (canonicalAxes.empty())
+    return std::nullopt;
+  return DenseI16ArrayAttr::get(attr.getContext(), canonicalAxes);
+}
+
+// Rewrite pattern that canonicalizes the DenseI16ArrayAttr named `axisSetAttr`
+// on ops of type `Op`, using canonicalizeAxesSetAttribute.
+// - If the canonical form is non-empty, it replaces the attribute.
+// - If the canonical form is empty, or the attribute is present but not a
+//   DenseI16ArrayAttr, the attribute is removed.
+// The pattern reports success only when it actually changes the op. Returning
+// success unconditionally would keep the greedy rewrite driver from ever
+// reaching a fixed point.
+template <typename Op>
+struct AxesSetCanonicalizationPattern : OpRewritePattern<Op> {
+  AxesSetCanonicalizationPattern(MLIRContext *context, StringRef axisSetAttr)
+      : OpRewritePattern<Op>(context), axisSetAttr(axisSetAttr) {}
+  LogicalResult matchAndRewrite(Op op,
+                                PatternRewriter &rewriter) const override {
+    DenseI16ArrayAttr currentAttr =
+        op->template getAttrOfType<DenseI16ArrayAttr>(axisSetAttr);
+    auto canonicalMeshAxesAttr = canonicalizeAxesSetAttribute(currentAttr);
+    if (!canonicalMeshAxesAttr) {
+      // Nothing to remove: report "no match" so the driver can converge.
+      if (!op->hasAttr(axisSetAttr))
+        return failure();
+      // Mutations must go through the rewriter so it is notified of the
+      // change. This API is named `modifyOpInPlace` in newer MLIR.
+      rewriter.updateRootInPlace(op, [&] { op->removeAttr(axisSetAttr); });
+      return success();
+    }
+    // Attributes are uniqued, so equality means the op is already canonical.
+    if (*canonicalMeshAxesAttr == currentAttr)
+      return failure();
+    rewriter.updateRootInPlace(
+        op, [&] { op->setAttr(axisSetAttr, *canonicalMeshAxesAttr); });
+    return success();
+  }
+
+  // Name of the axes-set attribute this pattern canonicalizes.
+  std::string axisSetAttr;
+};
+
+// Registers an AxesSetCanonicalizationPattern<Op> in `patterns`. The pattern
+// targets the op's "mesh_axes" attribute.
+template <typename Op>
+void populateMeshAxesSetCanonicalizationPatterns(RewritePatternSet &patterns,
+                                                 MLIRContext *context) {
+  StringRef meshAxesAttrName = "mesh_axes";
+  patterns.add<AxesSetCanonicalizationPattern<Op>>(context, meshAxesAttrName);
+}
+
+// Verifies the symbol uses of a mesh collective op:
+// - `op` must carry a "mesh" symbol attribute.
+// - That symbol must resolve, via the nearest symbol table, to a
+//   `mesh::ClusterOp`.
+// - Every entry of the op's optional mesh axes attribute must be a valid
+//   0-based index into that mesh, i.e. lie in [0, rank).
+// The lookup goes through the caller-provided `symbolTable`, so symbol tables
+// it has already built can be reused across ops.
+template <typename Op>
+LogicalResult verifyMeshSymbolUses(Op op, SymbolTableCollection &symbolTable) {
+  FlatSymbolRefAttr symbolAttr = op.getMeshAttr();
+  if (!symbolAttr) {
+    return op.emitError() << "Unspecified \"mesh\" symbol attribute.";
+  }
+  // Use the caller's collection rather than a fresh local one. A local
+  // collection ignores the parameter and discards any cached symbol tables.
+  mesh::ClusterOp mesh = symbolTable.lookupNearestSymbolFrom<mesh::ClusterOp>(
+      op.getOperation(), symbolAttr);
+  if (!mesh) {
+    return op.emitError() << "Undefined required mesh symbol \""
+                          << symbolAttr.getValue() << "\".";
+  }
+  DenseI16ArrayAttr meshAxes = op.getMeshAxesAttr();
+  if (!meshAxes) {
+    // No axes listed, so there is nothing further to check.
+    return success();
+  }
+  MeshAxis rank = mesh.getRank();
+  for (auto axis : meshAxes.asArrayRef()) {
+    if (axis >= rank || axis < 0) {
+      return op.emitError()
+             << "0-based mesh axis index " << axis
+             << " is out of bounds. The referenced mesh \""
+             << symbolAttr.getValue() << "\" is of rank " << rank << ".";
+    }
+  }
+
+  return success();
+}
+
+// Multiplies together every element of [begin, end) and returns the result in
+// the iterator's decayed value type. An empty range yields 1.
+template <typename It>
+auto product(It begin, It end) {
+  using ElementType = std::decay_t<decltype(*begin)>;
+  auto multiplyInto = [](ElementType acc,
+                         const ElementType &value) -> ElementType {
+    return acc * value;
+  };
+  return std::accumulate(begin, end, ElementType(1), multiplyInto);
+}
+
+// Range overload of `product`. It obtains the range's iterators through
+// argument-dependent lookup and multiplies all of its elements together.
+template <typename R>
+auto product(R &&range) {
+  auto first = adl_begin(range);
+  auto last = adl_end(range);
+  return product(first, last);
+}
+
+// Returns the number of devices in one collective group: the product of the
+// mesh dimension sizes whose axis index appears in `meshAxes`.
+// - Entries of `meshAxes` outside [0, meshShape.size()) are ignored.
+// - A dimension listed more than once is counted once.
+// - When a selected dimension is dynamic (per isMeshDimensionDynamic), the
+//   result is ShapedType::kDynamic.
+int64_t collectiveDeviceGroupSize(ArrayRef<MeshAxis> meshAxes,
+                                  ArrayRef<int64_t> meshShape) {
+  int64_t groupSize = 1;
+  const MeshAxis axisCount = MeshAxis(meshShape.size());
+  for (MeshAxis axis = 0; axis < axisCount; ++axis) {
+    if (llvm::find(meshAxes, axis) != meshAxes.end()) {
+      int64_t dimSize = meshShape[axis];
+      if (isMeshDimensionDynamic(dimSize))
+        return ShapedType::kDynamic;
+      groupSize *= dimSize;
+    }
+  }
+  return groupSize;
+}
+
+// Checks one result dimension against its expected size.
+// - A dynamic `resultDimSize` is always accepted.
+// - Otherwise it must equal `expectedDimSize`. A dynamic expectation against a
+//   static result counts as a mismatch.
+// - On mismatch, an error naming `resultAxis` is emitted at `loc` and failure
+//   is returned.
+LogicalResult verifyDimensionCompatibility(Location loc,
+                                           int64_t expectedDimSize,
+                                           int64_t resultDimSize,
+                                           int64_t resultAxis) {
+  if (ShapedType::isDynamic(resultDimSize) || expectedDimSize == resultDimSize)
+    return success();
+
+  InFlightDiagnostic diag = emitError(loc);
+  diag << "Dimension size mismatch for result axis " << resultAxis
+       << ". Expected ";
+  if (ShapedType::isDynamic(expectedDimSize))
+    diag << "dynamic";
+  else
+    diag << expectedDimSize;
+  diag << ", but got " << resultDimSize << ".";
+  return diag;
+}
+
+LogicalResult verifyGatherOperandAndResultShape(Value operand, Value result,
----------------
sogartar wrote:

I am not sure what the restrictions on symbol resolution are. We may have to impose additional restrictions on the mesh operations to be able to use this.
It is likely that no one else has run into this problem. `func.call`, for example, does not have type inference; it would have hit this issue if it did.

https://github.com/llvm/llvm-project/pull/71960


More information about the Mlir-commits mailing list