[Mlir-commits] [mlir] [mlir][mesh] Add collective communication operations (PR #71960)
Mehdi Amini
llvmlistbot at llvm.org
Wed Nov 15 01:30:47 PST 2023
================
@@ -129,6 +207,327 @@ MeshShardingAttr::verify(function_ref<InFlightDiagnostic()> emitError,
return success();
}
+//===----------------------------------------------------------------------===//
+// collective communication ops
+//===----------------------------------------------------------------------===//
+
+namespace {
+
+/// Verifies the mesh symbol and the mesh axes referenced by `op`.
+///
+/// Checks, in order, that
+///   1. `op` carries a "mesh" symbol attribute,
+///   2. the symbol resolves to a `mesh::ClusterOp`, and
+///   3. every entry of the optional mesh-axes attribute lies in
+///      `[0, mesh rank)`.
+///
+/// \param op          Operation exposing `getMeshAttr()`, `getMeshAxesAttr()`,
+///                    `getOperation()` and `emitError()`.
+/// \param symbolTable Symbol table collection used to resolve the mesh symbol.
+/// \returns `success()` when all checks pass; otherwise the error emitted on
+///          `op`.
+template <typename Op>
+LogicalResult verifyMeshSymbolUses(Op op, SymbolTableCollection &symbolTable) {
+  FlatSymbolRefAttr symbolAttr = op.getMeshAttr();
+  if (!symbolAttr) {
+    return op.emitError() << "Unspecified \"mesh\" symbol attribute.";
+  }
+  // Resolve through the collection supplied by the caller rather than a
+  // throwaway local one; otherwise the `symbolTable` argument is ignored.
+  mesh::ClusterOp mesh = symbolTable.lookupNearestSymbolFrom<mesh::ClusterOp>(
+      op.getOperation(), symbolAttr);
+  if (!mesh) {
+    return op.emitError() << "Undefined required mesh symbol \""
+                          << symbolAttr.getValue() << "\".";
+  }
+  // The axes attribute is optional; nothing more to check when it is absent.
+  DenseI16ArrayAttr meshAxes = op.getMeshAxesAttr();
+  if (!meshAxes) {
+    return success();
+  }
+  MeshAxis rank = mesh.getRank();
+  for (auto axis : meshAxes.asArrayRef()) {
+    if (axis >= rank || axis < 0) {
+      return op.emitError()
+             << "0-based mesh axis index " << axis
+             << " is out of bounds. The referenced mesh \""
+             << symbolAttr.getValue() << "\" is of rank " << rank << ".";
+    }
+  }
+
+  return success();
+}
+
+/// Returns true if no two *adjacent* elements of `[begin, end)` compare equal.
+///
+/// Only neighbouring elements are compared, so the answer means "all elements
+/// are distinct" only when equal elements are guaranteed to be adjacent, e.g.
+/// when the range is sorted. Empty and single-element ranges are unique.
+template <typename It>
+bool isUnique(It begin, It end) {
+  // adjacent_find returns `end` when no equal neighbouring pair exists; it
+  // handles the empty and one-element cases itself.
+  return std::adjacent_find(begin, end) == end;
+}
+
+/// Verifies that `axes` does not list the same mesh axis more than once.
+///
+/// \param loc  Location the error is attached to.
+/// \param axes Mesh axes to check; the caller's array is not modified.
+/// \returns `success()` if all axes are distinct, otherwise an error emitted
+///          at `loc`.
+LogicalResult verifyMeshAxes(Location loc, ArrayRef<MeshAxis> axes) {
+  // Sort a private copy so that equal axes end up next to each other, which
+  // is what the adjacent-element check in isUnique relies on.
+  SmallVector<MeshAxis> ordered(axes.begin(), axes.end());
+  std::sort(ordered.begin(), ordered.end());
+
+  const bool allDistinct = isUnique(ordered.begin(), ordered.end());
+  if (allDistinct) {
+    return success();
+  }
+  return emitError(loc) << "Mesh axes contains duplicate elements.";
+}
+
+/// Multiplies together all elements of `[begin, end)`.
+///
+/// The accumulator has the decayed element type of the range and starts at 1,
+/// so an empty range yields 1.
+template <typename It>
+auto product(It begin, It end) {
+  using ElementType = std::decay_t<decltype(*begin)>;
+  ElementType result = static_cast<ElementType>(1);
+  for (It cursor = begin; cursor != end; ++cursor) {
+    result = static_cast<ElementType>(result * *cursor);
+  }
+  return result;
+}
+
+/// Range convenience overload: multiplies together all elements of `range` by
+/// forwarding its begin/end iterators to the iterator-pair overload.
+template <typename R>
+auto product(R &&range) {
+  auto first = adl_begin(range);
+  auto last = adl_end(range);
+  return product(first, last);
+}
+
+/// Computes the number of devices in a collective's device group, i.e. the
+/// product of the mesh dimension sizes selected by `meshAxes`.
+///
+/// \param meshAxes  Indices into `meshShape` of the axes spanned by the group.
+/// \param meshShape Size of every mesh dimension.
+/// \returns The product of the selected sizes (1 when `meshAxes` is empty), or
+///          `ShapedType::kDynamic` as soon as a selected dimension is dynamic.
+int64_t collectiveDeviceGroupSize(ArrayRef<MeshAxis> meshAxes,
+                                  ArrayRef<int64_t> meshShape) {
+  int64_t res = 1;
+
+  for (MeshAxis axis : meshAxes) {
+    // Check the index before its first use; previously the check came after
+    // meshShape[axis] had already been read.
+    assert(axis >= 0 && static_cast<size_t>(axis) < meshShape.size() &&
+           "mesh axis is out of bounds of the mesh shape");
+    if (isMeshDimensionDynamic(meshShape[axis])) {
+      return ShapedType::kDynamic;
+    }
+    res *= meshShape[axis];
+  }
+
+  return res;
+}
+
+/// Checks that a result dimension size is compatible with the expected one.
+///
+/// A dynamic result size is accepted unconditionally; a static result size
+/// must equal `expectedDimSize`.
+///
+/// \param loc             Location the error is attached to.
+/// \param expectedDimSize Size the result dimension is expected to have.
+/// \param resultDimSize   Size the result dimension actually has.
+/// \param resultAxis      Index of the dimension, used only in the message.
+/// \returns `success()` when compatible, otherwise an error emitted at `loc`.
+LogicalResult verifyDimensionCompatibility(Location loc,
+                                           int64_t expectedDimSize,
+                                           int64_t resultDimSize,
+                                           int64_t resultAxis) {
+  if (ShapedType::isDynamic(resultDimSize)) {
+    return success();
+  }
+  if (expectedDimSize == resultDimSize) {
+    return success();
+  }
+
+  // Static result size that differs from the expectation: report it. The
+  // Twine temporaries must stay inside this single streaming expression.
+  const bool expectedIsDynamic = ShapedType::isDynamic(expectedDimSize);
+  return emitError(loc) << "Dimension size mismatch for result axis "
+                        << resultAxis << ". Expected "
+                        << (expectedIsDynamic ? Twine("dynamic")
+                                              : Twine(expectedDimSize))
+                        << ", but got " << resultDimSize << ".";
+}
+
+/// Verifies that the shape of an all-gather result matches its operand.
+///
+/// Along `gatherAxis` the expected result size is the device group size times
+/// the operand size; along every other axis it is the operand size. Each
+/// dimension is checked with verifyDimensionCompatibility, and errors are
+/// reported at the location of `result`.
+///
+/// NOTE(review): the result type is indexed with the operand's dimension
+/// indices, so this presumably relies on operand and result having the same
+/// rank - confirm against the op definition.
+///
+/// \param operand    Gathered value; its type is cast to ShapedType.
+/// \param result     Result value; its type is cast to ShapedType.
+/// \param gatherAxis Tensor axis along which the gather happens.
+/// \param meshAxes   Mesh axes forming the device group.
+/// \param meshShape  Size of every mesh dimension.
+LogicalResult verifyAllGatherOperandAndResultShape(
+    Value operand, Value result, int64_t gatherAxis,
+    ArrayRef<MeshAxis> meshAxes, ArrayRef<int64_t> meshShape) {
+  ShapedType inputType = llvm::cast<ShapedType>(operand.getType());
+  ShapedType outputType = llvm::cast<ShapedType>(result.getType());
+  auto groupSize =
+      DimensionSize(collectiveDeviceGroupSize(meshAxes, meshShape));
+
+  const int64_t inputRank = inputType.getRank();
+  for (int64_t dim = 0; dim < inputRank; ++dim) {
+    auto inputDimSize = DimensionSize(inputType.getDimSize(dim));
+    auto outputDimSize = DimensionSize(outputType.getDimSize(dim));
+    // Only the gather axis grows, by a factor of the device group size.
+    auto expectedDimSize =
+        dim == gatherAxis ? groupSize * inputDimSize : inputDimSize;
+    LogicalResult dimCheck = verifyDimensionCompatibility(
+        result.getLoc(), expectedDimSize, outputDimSize, dim);
+    if (failed(dimCheck)) {
+      return failure();
+    }
+  }
+  return success();
+}
+
+template <typename Op>
+FailureOr<ClusterOp> getMesh(Op op) {
+ SymbolTableCollection symbolTableCollection;
+ if (failed(verifyMeshSymbolUses(op, symbolTableCollection))) {
+ // We need to check the symbol here since this runs before
+ // SymbolUserOpInterface.
+ return failure();
+ }
+ return symbolTableCollection.lookupNearestSymbolFrom<mesh::ClusterOp>(
+ op.getOperation(), op.getMeshAttr());
----------------
joker-eph wrote:
(same as above)
https://github.com/llvm/llvm-project/pull/71960
More information about the Mlir-commits
mailing list