[Mlir-commits] [mlir] [mlir][mesh] Add collective communication operations (PR #71960)
Chengji Yao
llvmlistbot at llvm.org
Sun Nov 12 22:37:58 PST 2023
================
@@ -129,6 +205,347 @@ MeshShardingAttr::verify(function_ref<InFlightDiagnostic()> emitError,
return success();
}
+//===----------------------------------------------------------------------===//
+// collective communication ops
+//===----------------------------------------------------------------------===//
+
+namespace {
+
+/// Produces the canonical form of an axes-set attribute.
+///
+/// The axes stored in `attr` are copied into a vector and passed through
+/// `canonicalizeSetAsVector` (defined elsewhere in this file). Returns
+/// `std::nullopt` when `attr` is null or when no axes remain afterwards;
+/// otherwise returns a new attribute, created in the context of `attr`,
+/// that holds the canonicalized axes.
+std::optional<DenseI16ArrayAttr>
+canonicalizeAxesSetAttribute(DenseI16ArrayAttr attr) {
+  std::optional<DenseI16ArrayAttr> canonical;
+  if (attr) {
+    auto rawAxes = attr.asArrayRef();
+    SmallVector<int16_t> canonicalAxes(rawAxes.begin(), rawAxes.end());
+    canonicalizeSetAsVector(canonicalAxes);
+    if (!canonicalAxes.empty()) {
+      canonical = DenseI16ArrayAttr::get(attr.getContext(), canonicalAxes);
+    }
+  }
+  return canonical;
+}
+
+/// Rewrite pattern that brings the axes-set attribute named `axisSetAttr` on
+/// an `Op` into canonical form: the attribute is replaced by the result of
+/// `canonicalizeAxesSetAttribute`, or removed when that result is empty.
+///
+/// The pattern fails to match when the attribute is already canonical (or is
+/// absent), so that a greedy pattern driver can reach a fixed point.
+template <typename Op>
+struct AxesSetCanonicalizationPattern : OpRewritePattern<Op> {
+  /// `axisSetAttr` is the name of the attribute to canonicalize; it is copied
+  /// into the pattern.
+  AxesSetCanonicalizationPattern(MLIRContext *context, StringRef axisSetAttr)
+      : OpRewritePattern<Op>(context), axisSetAttr(axisSetAttr) {}
+
+  LogicalResult matchAndRewrite(Op op,
+                                PatternRewriter &rewriter) const override {
+    auto currentAttr =
+        op->template getAttrOfType<DenseI16ArrayAttr>(axisSetAttr);
+    auto canonicalMeshAxesAttr = canonicalizeAxesSetAttribute(currentAttr);
+
+    // Nothing to do when the attribute is absent, or when it already equals
+    // its canonical form. Returning success() here would tell the driver the
+    // IR changed on every invocation.
+    const bool alreadyCanonical =
+        canonicalMeshAxesAttr ? canonicalMeshAxesAttr.value() == currentAttr
+                              : !currentAttr;
+    if (alreadyCanonical) {
+      return failure();
+    }
+
+    // Mutate through the rewriter so the driver is notified of the in-place
+    // change (this hook is named `modifyOpInPlace` in newer MLIR releases).
+    rewriter.updateRootInPlace(op, [&]() {
+      if (!canonicalMeshAxesAttr) {
+        op->removeAttr(axisSetAttr);
+      } else {
+        op->setAttr(axisSetAttr, canonicalMeshAxesAttr.value());
+      }
+    });
+    return success();
+  }
+
+  /// Name of the attribute this pattern canonicalizes.
+  std::string axisSetAttr;
+};
+
+/// Adds to `patterns` an `AxesSetCanonicalizationPattern` for `Op` that
+/// targets the attribute named "mesh_axes".
+template <typename Op>
+void populateMeshAxesSetCanonicalizationPatterns(RewritePatternSet &patterns,
+                                                 MLIRContext *context) {
+  using Pattern = AxesSetCanonicalizationPattern<Op>;
+  const StringRef meshAxesAttrName = "mesh_axes";
+  patterns.add<Pattern>(context, meshAxesAttrName);
+}
+
+/// Verifies the mesh-related attributes of `op`.
+///
+/// Checks that:
+///  * `op` carries a mesh symbol attribute (`op.getMeshAttr()`),
+///  * the symbol resolves, from `op`, to a `mesh::ClusterOp`,
+///  * every entry of `op.getMeshAxesAttr()`, when that attribute is present,
+///    lies in the range [0, rank of the referenced mesh).
+///
+/// Symbol lookup goes through the caller-provided `symbolTable` so that
+/// tables built for earlier lookups are reused instead of being rebuilt on
+/// every call.
+///
+/// Returns success when all checks pass; otherwise emits an error on `op` and
+/// returns the resulting failure.
+template <typename Op>
+LogicalResult verifyMeshSymbolUses(Op op, SymbolTableCollection &symbolTable) {
+  FlatSymbolRefAttr symbolAttr = op.getMeshAttr();
+  if (!symbolAttr) {
+    return op.emitError() << "Unspecified \"mesh\" symbol attribute.";
+  }
+
+  mesh::ClusterOp mesh = symbolTable.lookupNearestSymbolFrom<mesh::ClusterOp>(
+      op.getOperation(), symbolAttr);
+  if (!mesh) {
+    return op.emitError() << "Undefined required mesh symbol \""
+                          << symbolAttr.getValue() << "\".";
+  }
+
+  // The axes attribute is optional; without it there is nothing to
+  // bounds-check.
+  DenseI16ArrayAttr meshAxes = op.getMeshAxesAttr();
+  if (!meshAxes) {
+    return success();
+  }
+
+  MeshAxis rank = mesh.getRank();
+  for (auto axis : meshAxes.asArrayRef()) {
+    if (axis >= rank || axis < 0) {
+      return op.emitError()
+             << "0-based mesh axis index " << axis
+             << " is out of bounds. The referenced mesh \""
+             << symbolAttr.getValue() << "\" is of rank " << rank << ".";
+    }
+  }
+
+  return success();
+}
+
+template <typename It>
+auto product(It begin, It end) {
+ using ElementType = std::decay_t<decltype(*begin)>;
+ return std::accumulate(begin, end, ElementType(1),
----------------
yaochengji wrote:
I'd suggest using `static_cast<ElementType>(1)` in case some classes have an explicit constructor.
https://github.com/llvm/llvm-project/pull/71960
More information about the Mlir-commits
mailing list