[Mlir-commits] [mlir] [mlir][mesh] Add collective communication operations (PR #71960)

Mehdi Amini llvmlistbot at llvm.org
Wed Nov 15 01:30:47 PST 2023


================
@@ -129,6 +207,327 @@ MeshShardingAttr::verify(function_ref<InFlightDiagnostic()> emitError,
   return success();
 }
 
+//===----------------------------------------------------------------------===//
+// collective communication ops
+//===----------------------------------------------------------------------===//
+
+namespace {
+
+template <typename Op>
+LogicalResult verifyMeshSymbolUses(Op op, SymbolTableCollection &symbolTable) {
+  FlatSymbolRefAttr symbolAttr = op.getMeshAttr();
+  if (!symbolAttr) {
+    return op.emitError() << "Unspecified \"mesh\" symbol attribute.";
+  }
+  SymbolTableCollection symbolTableCollection;
+  mesh::ClusterOp mesh =
+      symbolTableCollection.lookupNearestSymbolFrom<mesh::ClusterOp>(
+          op.getOperation(), symbolAttr);
+  if (!mesh) {
+    return op.emitError() << "Undefined required mesh symbol \""
+                          << symbolAttr.getValue() << "\".";
+  }
+  DenseI16ArrayAttr meshAxes = op.getMeshAxesAttr();
+  if (!meshAxes) {
+    return success();
+  }
+  MeshAxis rank = mesh.getRank();
+  for (auto axis : meshAxes.asArrayRef()) {
+    if (axis >= rank || axis < 0) {
+      return op.emitError()
+             << "0-based mesh axis index " << axis
+             << " is out of bounds. The referenced mesh \""
+             << symbolAttr.getValue() << "\" is of rank " << rank << ".";
+    }
+  }
+
+  return success();
+}
+
+template <typename It>
+bool isUnique(It begin, It end) {
+  if (begin == end) {
+    return true;
+  }
+  It next = std::next(begin);
+  if (next == end) {
+    return true;
+  }
+  for (; next != end; ++next, ++begin) {
+    if (*begin == *next) {
+      return false;
+    }
+  }
+  return true;
+}
+
+LogicalResult verifyMeshAxes(Location loc, ArrayRef<MeshAxis> axes) {
+  SmallVector<MeshAxis> sorted = llvm::to_vector(axes);
+  std::sort(sorted.begin(), sorted.end());
+  if (!isUnique(sorted.begin(), sorted.end())) {
+    return emitError(loc) << "Mesh axes contains duplicate elements.";
+  }
+  return success();
+}
+
+template <typename It>
+auto product(It begin, It end) {
+  using ElementType = std::decay_t<decltype(*begin)>;
+  return std::accumulate(begin, end, static_cast<ElementType>(1),
+                         std::multiplies<ElementType>());
+}
+
+template <typename R>
+auto product(R &&range) {
+  return product(adl_begin(range), adl_end(range));
+}
+
+int64_t collectiveDeviceGroupSize(ArrayRef<MeshAxis> meshAxes,
+                                  ArrayRef<int64_t> meshShape) {
+  int64_t res = 1;
+
+  for (MeshAxis axis : meshAxes) {
+    if (isMeshDimensionDynamic(meshShape[axis])) {
+      return ShapedType::kDynamic;
+    }
+    assert(size_t(axis) < meshShape.size());
+    res *= meshShape[axis];
+  }
+
+  return res;
+}
+
+LogicalResult verifyDimensionCompatibility(Location loc,
+                                           int64_t expectedDimSize,
+                                           int64_t resultDimSize,
+                                           int64_t resultAxis) {
+  if (!ShapedType::isDynamic(resultDimSize) &&
+      expectedDimSize != resultDimSize) {
+    return emitError(loc) << "Dimension size mismatch for result axis "
+                          << resultAxis << ". Expected "
+                          << (ShapedType::isDynamic(expectedDimSize)
+                                  ? Twine("dynamic")
+                                  : Twine(expectedDimSize))
+                          << ", but got " << resultDimSize << ".";
+  }
+
+  return success();
+}
+
+LogicalResult verifyAllGatherOperandAndResultShape(
+    Value operand, Value result, int64_t gatherAxis,
+    ArrayRef<MeshAxis> meshAxes, ArrayRef<int64_t> meshShape) {
+  ShapedType operandType = operand.getType().cast<ShapedType>();
+  ShapedType resultType = result.getType().cast<ShapedType>();
+  auto deviceGroupSize =
+      DimensionSize(collectiveDeviceGroupSize(meshAxes, meshShape));
+  for (int64_t axis = 0; axis < operandType.getRank(); ++axis) {
+    auto operandDimSize = DimensionSize(operandType.getDimSize(axis));
+    auto resultDimSize = DimensionSize(resultType.getDimSize(axis));
+    auto expectedResultDimSize =
+        axis == gatherAxis ? deviceGroupSize * operandDimSize : operandDimSize;
+    if (failed(verifyDimensionCompatibility(
+            result.getLoc(), expectedResultDimSize, resultDimSize, axis))) {
+      return failure();
+    }
+  }
+  return success();
+}
+
+template <typename Op>
+FailureOr<ClusterOp> getMesh(Op op) {
+  SymbolTableCollection symbolTableCollection;
+  if (failed(verifyMeshSymbolUses(op, symbolTableCollection))) {
+    // We need to check the symbol here since this runs before
+    // SymbolUserOpInterface.
+    return failure();
+  }
+  return symbolTableCollection.lookupNearestSymbolFrom<mesh::ClusterOp>(
+      op.getOperation(), op.getMeshAttr());
----------------
joker-eph wrote:

(same as above)

https://github.com/llvm/llvm-project/pull/71960


More information about the Mlir-commits mailing list