[Mlir-commits] [mlir] [mlir][mesh] Add collective communication operations (PR #71960)
Boian Petkantchin
llvmlistbot at llvm.org
Wed Nov 15 16:30:04 PST 2023
@@ -129,6 +205,347 @@ MeshShardingAttr::verify(function_ref<InFlightDiagnostic()> emitError,
return success();
+// collective communication ops
+namespace {
+/// Returns the canonical form of the mesh-axes set held in `attr`, as
+/// normalised by `canonicalizeSetAsVector`. Returns std::nullopt when `attr`
+/// is null or when the canonicalized set is empty; the canonicalization
+/// pattern treats std::nullopt as "remove the attribute".
+std::optional<DenseI16ArrayAttr>
+canonicalizeAxesSetAttribute(DenseI16ArrayAttr attr) {
+  if (!attr)
+    return std::nullopt;
+
+  ArrayRef<int16_t> originalAxes = attr.asArrayRef();
+  SmallVector<int16_t> canonicalAxes(originalAxes.begin(), originalAxes.end());
+  canonicalizeSetAsVector(canonicalAxes);
+  if (canonicalAxes.empty())
+    return std::nullopt;
+  return DenseI16ArrayAttr::get(attr.getContext(), canonicalAxes);
+}
+/// Rewrite pattern that canonicalizes the mesh-axes set attribute named
+/// `axisSetAttr` on an `Op`. The attribute is replaced by its canonical form,
+/// or removed when it canonicalizes to nothing.
+///
+/// The pattern reports failure when the attribute is already canonical (or
+/// already absent), so the greedy rewrite driver can reach a fixed point.
+/// In-place updates go through the rewriter so that listeners are notified of
+/// the modification.
+template <typename Op>
+struct AxesSetCanonicalizationPattern : OpRewritePattern<Op> {
+  AxesSetCanonicalizationPattern(MLIRContext *context, StringRef axisSetAttr)
+      : OpRewritePattern<Op>(context), axisSetAttr(axisSetAttr) {}
+
+  LogicalResult matchAndRewrite(Op op,
+                                PatternRewriter &rewriter) const override {
+    auto currentAttr =
+        op->template getAttrOfType<DenseI16ArrayAttr>(axisSetAttr);
+    std::optional<DenseI16ArrayAttr> canonicalAttr =
+        canonicalizeAxesSetAttribute(currentAttr);
+
+    if (!canonicalAttr) {
+      // Nothing to remove. Reporting success here would make the greedy
+      // driver re-apply this pattern indefinitely.
+      if (!op->hasAttr(axisSetAttr))
+        return failure();
+      rewriter.updateRootInPlace(op, [&]() { op->removeAttr(axisSetAttr); });
+      return success();
+    }
+
+    // MLIR attributes are uniqued per context, so handle equality means the
+    // stored set is already canonical and there is nothing to rewrite.
+    if (*canonicalAttr == currentAttr)
+      return failure();
+    rewriter.updateRootInPlace(
+        op, [&]() { op->setAttr(axisSetAttr, *canonicalAttr); });
+    return success();
+  }
+
+  /// Name of the attribute that holds the mesh-axes set.
+  std::string axisSetAttr;
+};
+/// Registers in `patterns` the pattern that canonicalizes the "mesh_axes"
+/// attribute of `Op`.
+template <typename Op>
+void populateMeshAxesSetCanonicalizationPatterns(RewritePatternSet &patterns,
+                                                 MLIRContext *context) {
+  using MeshAxesPattern = AxesSetCanonicalizationPattern<Op>;
+  patterns.add<MeshAxesPattern>(context, "mesh_axes");
+}
+/// Verifies that `op` refers to an existing `mesh::ClusterOp` through its
+/// "mesh" symbol attribute, and that every entry of its optional mesh-axes
+/// attribute is a valid 0-based axis index of that mesh. Emits an error on
+/// `op` for each failure case.
+///
+/// Symbol lookups go through `symbolTable`, so a caller that passes a shared
+/// collection reuses its cached symbol tables.
+template <typename Op>
+LogicalResult verifyMeshSymbolUses(Op op, SymbolTableCollection &symbolTable) {
+  FlatSymbolRefAttr symbolAttr = op.getMeshAttr();
+  if (!symbolAttr)
+    return op.emitError() << "Unspecified \"mesh\" symbol attribute.";
+
+  // Use the caller-provided collection rather than a throwaway local one, so
+  // the lookup is cached for the caller (e.g. getMesh repeats this lookup).
+  mesh::ClusterOp mesh = symbolTable.lookupNearestSymbolFrom<mesh::ClusterOp>(
+      op.getOperation(), symbolAttr);
+  if (!mesh)
+    return op.emitError() << "Undefined required mesh symbol \""
+                          << symbolAttr.getValue() << "\".";
+
+  DenseI16ArrayAttr meshAxes = op.getMeshAxesAttr();
+  if (!meshAxes)
+    return success();
+
+  MeshAxis rank = mesh.getRank();
+  for (auto axis : meshAxes.asArrayRef()) {
+    if (axis < 0 || axis >= rank)
+      return op.emitError()
+             << "0-based mesh axis index " << axis
+             << " is out of bounds. The referenced mesh \""
+             << symbolAttr.getValue() << "\" is of rank " << rank << ".";
+  }
+  return success();
+}
+/// Returns the product of the elements in [begin, end), starting from 1.
+/// The accumulator has the decayed element type of the iterator.
+template <typename It>
+auto product(It begin, It end) {
+  using ElementType = std::decay_t<decltype(*begin)>;
+  const std::multiplies<ElementType> multiply{};
+  ElementType result(1);
+  for (It it = begin; it != end; ++it)
+    result = multiply(result, *it);
+  return result;
+}
+/// Range overload of `product`: multiplies together all elements of `range`.
+template <typename R>
+auto product(R &&range) {
+  auto first = adl_begin(range);
+  auto last = adl_end(range);
+  return product(first, last);
+}
+/// Returns the product of the sizes in `meshShape` of the axes selected by
+/// `meshAxes`, or ShapedType::kDynamic as soon as a selected axis is dynamic.
+/// Axes listed in `meshAxes` that fall outside [0, meshShape.size()) are
+/// ignored, and an axis listed more than once contributes only once. An empty
+/// `meshAxes` yields 1.
+int64_t collectiveDeviceGroupSize(ArrayRef<MeshAxis> meshAxes,
+                                  ArrayRef<int64_t> meshShape) {
+  int64_t groupSize = 1;
+  const MeshAxis meshRank = MeshAxis(meshShape.size());
+  for (MeshAxis axis = 0; axis < meshRank; ++axis) {
+    const bool selected = llvm::find(meshAxes, axis) != meshAxes.end();
+    if (!selected)
+      continue;
+    const int64_t axisSize = meshShape[axis];
+    if (isMeshDimensionDynamic(axisSize))
+      return ShapedType::kDynamic;
+    groupSize *= axisSize;
+  }
+  return groupSize;
+}
+/// Checks that the size `resultDimSize` of result axis `resultAxis` is
+/// compatible with `expectedDimSize`, emitting an error at `loc` if not.
+/// A dynamic result size is accepted unconditionally. Otherwise the sizes
+/// must be equal, so a static result size is rejected when the expected size
+/// is dynamic.
+LogicalResult verifyDimensionCompatibility(Location loc,
+                                           int64_t expectedDimSize,
+                                           int64_t resultDimSize,
+                                           int64_t resultAxis) {
+  if (ShapedType::isDynamic(resultDimSize) || expectedDimSize == resultDimSize)
+    return success();
+
+  InFlightDiagnostic diag = emitError(loc);
+  diag << "Dimension size mismatch for result axis " << resultAxis
+       << ". Expected ";
+  if (ShapedType::isDynamic(expectedDimSize))
+    diag << "dynamic";
+  else
+    diag << expectedDimSize;
+  diag << ", but got " << resultDimSize << ".";
+  return diag;
+}
+/// Checks that the shape of `result` matches gathering `operand` along
+/// `gatherAxis` across the device group spanned by `meshAxes` of a mesh with
+/// shape `meshShape`. The gather axis is scaled by the device-group size and
+/// every other axis must keep the operand's size. Dynamic result dimensions
+/// are accepted (see verifyDimensionCompatibility).
+///
+/// NOTE(review): iterates over the operand's rank and indexes the result with
+/// the same axes, so it assumes operand and result have equal rank --
+/// presumably guaranteed by the op definition; confirm.
+LogicalResult verifyGatherOperandAndResultShape(Value operand, Value result,
+                                                int64_t gatherAxis,
+                                                ArrayRef<MeshAxis> meshAxes,
+                                                ArrayRef<int64_t> meshShape) {
+  auto operandType = llvm::cast<ShapedType>(operand.getType());
+  auto resultType = llvm::cast<ShapedType>(result.getType());
+  DimensionSize groupSize(collectiveDeviceGroupSize(meshAxes, meshShape));
+
+  const int64_t rank = operandType.getRank();
+  for (int64_t axis = 0; axis < rank; ++axis) {
+    DimensionSize inputDim(operandType.getDimSize(axis));
+    DimensionSize actualDim(resultType.getDimSize(axis));
+    auto expectedDim = (axis == gatherAxis) ? groupSize * inputDim : inputDim;
+    LogicalResult status = verifyDimensionCompatibility(
+        result.getLoc(), expectedDim, actualDim, axis);
+    if (failed(status))
+      return failure();
+  }
+  return success();
+}
+/// Returns the `mesh::ClusterOp` referenced by `op`'s "mesh" symbol
+/// attribute after validating the reference with verifyMeshSymbolUses.
+/// Returns failure when that check fails; the check has already emitted a
+/// diagnostic on `op` in every failing case.
+template <typename Op>
+FailureOr<ClusterOp> getMesh(Op op) {
+  SymbolTableCollection symbols;
+  // The symbol has to be checked here because this runs before
+  // SymbolUserOpInterface verification.
+  if (succeeded(verifyMeshSymbolUses(op, symbols)))
+    return symbols.lookupNearestSymbolFrom<mesh::ClusterOp>(op.getOperation(),
+                                                            op.getMeshAttr());
+  return failure();
+}
+/// Verifies a gather-style collective `op`. Checks that its gather axis
+/// indexes a dimension of the result, that its mesh reference (and mesh axes)
+/// are valid, and that its operand and result shapes agree with gathering
+/// across the selected mesh axes.
+template <typename Op>
+LogicalResult verifyGather(Op op) {
+  const int64_t resultRank =
+      llvm::cast<ShapedType>(op.getResult().getType()).getRank();
+  const int64_t gatherAxis = op.getGatherAxis().getSExtValue();
+  const bool axisInBounds = gatherAxis >= 0 && gatherAxis < resultRank;
+  if (!axisInBounds)
+    return op.emitError() << "Gather axis " << gatherAxis
+                          << " is out of bounds [0, " << resultRank << ").";
+
+  FailureOr<ClusterOp> mesh = getMesh(op);
+  if (failed(mesh))
+    return failure();
+  return verifyGatherOperandAndResultShape(op.getOperand(), op.getResult(),
+                                           gatherAxis, op.getMeshAxes(),
+                                           mesh->canonicalDimSizes());
+}
+LogicalResult verifyAllToAllOperandAndResultShape(Value operand, Value result,
+ int64_t splitAxis,
+ int64_t concatAxis,
+ ArrayRef<MeshAxis> meshAxes,
+ ArrayRef<int64_t> meshShape) {
+ ShapedType operandType = operand.getType().cast<ShapedType>();
+ ShapedType resultType = result.getType().cast<ShapedType>();
+ for (int64_t axis = 0; axis < operandType.getRank(); ++axis) {
+ if ((axis != splitAxis && axis != concatAxis) || splitAxis == concatAxis) {
+ if (failed(verifyDimensionCompatibility(
+ result.getLoc(), operandType.getDimSize(axis),
+ resultType.getDimSize(axis), axis))) {
+ return failure();
+ }
+ }
+ }
+ if (splitAxis == concatAxis) {
+ return success();
+ }
+ auto deviceGroupSize =
+ DimensionSize(collectiveDeviceGroupSize(meshAxes, meshShape));
+ auto operandConcatDimSize = DimensionSize(operandType.getDimSize(concatAxis));
+ auto operandSplitDimSize = DimensionSize(operandType.getDimSize(splitAxis));
+ if (!operandSplitDimSize.isDynamic() && !deviceGroupSize.isDynamic() &&
+ int64_t(operandSplitDimSize) % int64_t(deviceGroupSize) != 0) {
+ return emitError(result.getLoc())
sogartar wrote:
How should the operation behave in that scenario?
StableHLO's [all-to-all](https://github.com/openxla/stablehlo/blob/main/docs/spec.md#all_to_all) enforces this constraint as well.
More information about the Mlir-commits
mailing list