[Mlir-commits] [mlir] [mlir][mesh] Add collective communication operations (PR #71960)

Chengji Yao llvmlistbot at llvm.org
Thu Nov 16 15:56:29 PST 2023


================
@@ -129,6 +205,347 @@ MeshShardingAttr::verify(function_ref<InFlightDiagnostic()> emitError,
   return success();
 }
 
+//===----------------------------------------------------------------------===//
+// collective communication ops
+//===----------------------------------------------------------------------===//
+
+namespace {
+
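+// Descriptive comment (editorial): canonicalizes a set of axes held in a
+// DenseI16ArrayAttr via canonicalizeSetAsVector; returns std::nullopt when the
+// attribute is missing or the canonical set is empty.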
+std::optional<DenseI16ArrayAttr>
+canonicalizeAxesSetAttribute(DenseI16ArrayAttr attr) {
+  if (!attr) {
+    return std::nullopt;
+  }
+  SmallVector<int16_t> axes = llvm::to_vector(attr.asArrayRef());
+  canonicalizeSetAsVector(axes);
+  if (axes.empty()) {
+    return std::nullopt;
+  }
+  return DenseI16ArrayAttr::get(attr.getContext(), axes);
+}
+
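+// Descriptive comment (editorial): pattern that rewrites the attribute named
+// by axisSetAttr into its canonical form, removing the attribute entirely when
+// it canonicalizes away.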
+template <typename Op>
+struct AxesSetCanonicalizationPattern : OpRewritePattern<Op> {
+  AxesSetCanonicalizationPattern(MLIRContext *context, StringRef axisSetAttr)
+      : OpRewritePattern<Op>(context), axisSetAttr(axisSetAttr) {}
+  LogicalResult matchAndRewrite(Op op,
+                                PatternRewriter &rewriter) const override {
+    auto canonicalMeshAxesAttr = canonicalizeAxesSetAttribute(
+        op->template getAttrOfType<DenseI16ArrayAttr>(axisSetAttr));
+    if (!canonicalMeshAxesAttr) {
+      op->removeAttr(axisSetAttr);
+    } else {
+      op->setAttr(axisSetAttr, canonicalMeshAxesAttr.value());
+    }
+    return success();
+  }
+
+  std::string axisSetAttr;
+};
+
+template <typename Op>
+void populateMeshAxesSetCanonicalizationPatterns(RewritePatternSet &patterns,
+                                                 MLIRContext *context) {
+  patterns.add<AxesSetCanonicalizationPattern<Op>>(context, "mesh_axes");
+}
+
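+// Descriptive comment (editorial): checks that the op's "mesh" symbol refers
+// to an existing mesh::ClusterOp and that every listed mesh axis is within the
+// bounds of that cluster's rank.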
+template <typename Op>
+LogicalResult verifyMeshSymbolUses(Op op, SymbolTableCollection &symbolTable) {
+  FlatSymbolRefAttr symbolAttr = op.getMeshAttr();
+  if (!symbolAttr) {
+    return op.emitError() << "Unspecified \"mesh\" symbol attribute.";
+  }
+  mesh::ClusterOp mesh = symbolTable.lookupNearestSymbolFrom<mesh::ClusterOp>(
+      op.getOperation(), symbolAttr);
+  if (!mesh) {
+    return op.emitError() << "Undefined required mesh symbol \""
+                          << symbolAttr.getValue() << "\".";
+  }
+  DenseI16ArrayAttr meshAxes = op.getMeshAxesAttr();
+  if (!meshAxes) {
+    return success();
+  }
+  MeshAxis rank = mesh.getRank();
+  for (auto axis : meshAxes.asArrayRef()) {
+    if (axis >= rank || axis < 0) {
+      return op.emitError()
+             << "0-based mesh axis index " << axis
+             << " is out of bounds. The referenced mesh \""
+             << symbolAttr.getValue() << "\" is of rank " << rank << ".";
+    }
+  }
+
+  return success();
+}
+
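+// Descriptive comment (editorial): multiplies together all elements of an
+// iterator pair or a range.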
+template <typename It>
+auto product(It begin, It end) {
+  using ElementType = std::decay_t<decltype(*begin)>;
+  return std::accumulate(begin, end, ElementType(1),
+                         std::multiplies<ElementType>());
+}
+
+template <typename R>
+auto product(R &&range) {
+  return product(adl_begin(range), adl_end(range));
+}
+
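+// Descriptive comment (editorial): number of devices in each collective device
+// group, i.e. the product of the mesh dimensions selected by meshAxes, or
+// ShapedType::kDynamic if any selected dimension is dynamic.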
+int64_t collectiveDeviceGroupSize(ArrayRef<MeshAxis> meshAxes,
+                                  ArrayRef<int64_t> meshShape) {
+  int64_t res = 1;
+  for (MeshAxis axis = 0; axis < MeshAxis(meshShape.size()); ++axis) {
+    if (llvm::find(meshAxes, axis) == meshAxes.end()) {
+      continue;
+    }
+    if (isMeshDimensionDynamic(meshShape[axis])) {
+      return ShapedType::kDynamic;
+    }
+    res *= meshShape[axis];
+  }
+  return res;
+}
+
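+// Descriptive comment (editorial): emits an error when the result dimension is
+// static and differs from the expected size.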
+LogicalResult verifyDimensionCompatibility(Location loc,
+                                           int64_t expectedDimSize,
+                                           int64_t resultDimSize,
+                                           int64_t resultAxis) {
+  if (!ShapedType::isDynamic(resultDimSize) &&
+      expectedDimSize != resultDimSize) {
+    return emitError(loc) << "Dimension size mismatch for result axis "
+                          << resultAxis << ". Expected "
+                          << (ShapedType::isDynamic(expectedDimSize)
+                                  ? Twine("dynamic")
+                                  : Twine(expectedDimSize))
+                          << ", but got " << resultDimSize << ".";
+  }
+
+  return success();
+}
+
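+// Descriptive comment (editorial): for gather-like ops the result's gather
+// axis must be the operand's size multiplied by the device group size; every
+// other axis must match the operand.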
+LogicalResult verifyGatherOperandAndResultShape(Value operand, Value result,
+                                                int64_t gatherAxis,
+                                                ArrayRef<MeshAxis> meshAxes,
+                                                ArrayRef<int64_t> meshShape) {
+  ShapedType operandType = operand.getType().cast<ShapedType>();
+  ShapedType resultType = result.getType().cast<ShapedType>();
+  auto deviceGroupSize =
+      DimensionSize(collectiveDeviceGroupSize(meshAxes, meshShape));
+  for (int64_t axis = 0; axis < operandType.getRank(); ++axis) {
+    auto operandDimSize = DimensionSize(operandType.getDimSize(axis));
+    auto resultDimSize = DimensionSize(resultType.getDimSize(axis));
+    auto expectedResultDimSize =
+        axis == gatherAxis ? deviceGroupSize * operandDimSize : operandDimSize;
+    if (failed(verifyDimensionCompatibility(
+            result.getLoc(), expectedResultDimSize, resultDimSize, axis))) {
+      return failure();
+    }
+  }
+  return success();
+}
+
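+// Descriptive comment (editorial): resolves the mesh::ClusterOp referenced by
+// the op's "mesh" symbol, verifying the symbol use first.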
+template <typename Op>
+FailureOr<ClusterOp> getMesh(Op op) {
+  SymbolTableCollection symbolTableCollection;
+  if (failed(verifyMeshSymbolUses(op, symbolTableCollection))) {
+    // We need to check the symbol here since this runs before
+    // SymbolUserOpInterface.
+    return failure();
+  }
+  return symbolTableCollection.lookupNearestSymbolFrom<mesh::ClusterOp>(
+      op.getOperation(), op.getMeshAttr());
+}
+
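+// Descriptive comment (editorial): verifies that the gather axis is a valid
+// result dimension and that operand and result shapes are consistent with the
+// device group size over the given mesh axes.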
+template <typename Op>
+LogicalResult verifyGather(Op op) {
+  auto rank = op.getResult().getType().template cast<ShapedType>().getRank();
+  auto gatherAxis = op.getGatherAxis().getSExtValue();
+  if (gatherAxis < 0 || gatherAxis >= rank) {
+    return op.emitError() << "Gather axis " << gatherAxis
+                          << " is out of bounds [0, " << rank << ").";
+  }
+
+  auto mesh = getMesh(op);
+  if (failed(mesh)) {
+    return failure();
+  }
+  return verifyGatherOperandAndResultShape(op.getOperand(), op.getResult(),
+                                           gatherAxis, op.getMeshAxes(),
+                                           mesh.value().canonicalDimSizes());
+}
+
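+// Descriptive comment (editorial): for all-to-all, axes other than the split
+// and concat axes must match between operand and result; when the split and
+// concat axes differ, the operand's split axis must divide evenly by the
+// device group size.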
+LogicalResult verifyAllToAllOperandAndResultShape(Value operand, Value result,
+                                                  int64_t splitAxis,
+                                                  int64_t concatAxis,
+                                                  ArrayRef<MeshAxis> meshAxes,
+                                                  ArrayRef<int64_t> meshShape) {
+  ShapedType operandType = operand.getType().cast<ShapedType>();
+  ShapedType resultType = result.getType().cast<ShapedType>();
+  for (int64_t axis = 0; axis < operandType.getRank(); ++axis) {
+    if ((axis != splitAxis && axis != concatAxis) || splitAxis == concatAxis) {
+      if (failed(verifyDimensionCompatibility(
+              result.getLoc(), operandType.getDimSize(axis),
+              resultType.getDimSize(axis), axis))) {
+        return failure();
+      }
+    }
+  }
+
+  if (splitAxis == concatAxis) {
+    return success();
+  }
+
+  auto deviceGroupSize =
+      DimensionSize(collectiveDeviceGroupSize(meshAxes, meshShape));
+  auto operandConcatDimSize = DimensionSize(operandType.getDimSize(concatAxis));
+  auto operandSplitDimSize = DimensionSize(operandType.getDimSize(splitAxis));
+  if (!operandSplitDimSize.isDynamic() && !deviceGroupSize.isDynamic() &&
+      int64_t(operandSplitDimSize) % int64_t(deviceGroupSize) != 0) {
+    return emitError(result.getLoc())
----------------
yaochengji wrote:

In this PR, I suggest we simply mark the result size as dynamic when `operandSplitDimSize` is not evenly divisible by `deviceGroupSize`, or when at least one of the divisor and dividend is dynamic. That way we don't need to work out the exact formula at this stage.

Later, when the size is dynamic, we can add more information, which would depend on the formula.
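A minimal sketch of that relaxation inside the verifier (the placement and the local variable `expectedResultSplitDimSize` are illustrative, not part of the PR):

    // Only constrain the result's split axis when the operand's split dimension
    // and the device group size are both static and divide evenly; otherwise
    // accept whatever the result declares (effectively treating it as dynamic),
    // so no exact formula is needed at this stage.
    if (!operandSplitDimSize.isDynamic() && !deviceGroupSize.isDynamic() &&
        int64_t(operandSplitDimSize) % int64_t(deviceGroupSize) == 0) {
      int64_t expectedResultSplitDimSize =
          int64_t(operandSplitDimSize) / int64_t(deviceGroupSize);
      if (failed(verifyDimensionCompatibility(
              result.getLoc(), expectedResultSplitDimSize,
              resultType.getDimSize(splitAxis), splitAxis)))
        return failure();
    }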

https://github.com/llvm/llvm-project/pull/71960

