[Mlir-commits] [mlir] [mlir][mesh] Add collective communication operations (PR #71960)
Mehdi Amini
llvmlistbot at llvm.org
Wed Nov 15 01:30:49 PST 2023
================
@@ -129,6 +205,347 @@ MeshShardingAttr::verify(function_ref<InFlightDiagnostic()> emitError,
return success();
}
+//===----------------------------------------------------------------------===//
+// collective communication ops
+//===----------------------------------------------------------------------===//
+
+namespace {
+
+/// Produces the canonical form of an axes-set attribute.
+///
+/// Returns `std::nullopt` when `attr` is null or when the axes list is empty
+/// after being passed through `canonicalizeSetAsVector`; otherwise returns a
+/// new `DenseI16ArrayAttr` built from the processed axes in the context of
+/// `attr`.
+/// NOTE(review): `canonicalizeSetAsVector` is defined elsewhere; presumably it
+/// sorts and de-duplicates the axes -- confirm against its definition.
+std::optional<DenseI16ArrayAttr>
+canonicalizeAxesSetAttribute(DenseI16ArrayAttr attr) {
+  if (!attr)
+    return std::nullopt;
+
+  // Work on a private copy; the attribute storage itself is immutable.
+  ArrayRef<int16_t> originalAxes = attr.asArrayRef();
+  SmallVector<int16_t> canonicalAxes(originalAxes.begin(), originalAxes.end());
+  canonicalizeSetAsVector(canonicalAxes);
+
+  if (canonicalAxes.empty())
+    return std::nullopt;
+  return DenseI16ArrayAttr::get(attr.getContext(), canonicalAxes);
+}
+
+/// Rewrite pattern that canonicalizes the axes-set attribute named
+/// `axisSetAttr` on operations of type `Op`.
+///
+/// The attribute is replaced by the result of `canonicalizeAxesSetAttribute`,
+/// or removed when that function yields no value. The pattern only reports
+/// success when the attribute actually changes, and performs the change
+/// through the rewriter so that the pattern driver is notified and can reach
+/// a fixed point.
+template <typename Op>
+struct AxesSetCanonicalizationPattern : OpRewritePattern<Op> {
+  /// `axisSetAttr` is the name of the attribute to canonicalize; it is copied
+  /// into the pattern.
+  AxesSetCanonicalizationPattern(MLIRContext *context, StringRef axisSetAttr)
+      : OpRewritePattern<Op>(context), axisSetAttr(axisSetAttr) {}
+
+  LogicalResult matchAndRewrite(Op op,
+                                PatternRewriter &rewriter) const override {
+    // Untyped view of whatever is currently attached under the name; used to
+    // detect whether the rewrite would be a no-op.
+    Attribute existingAttr = op->getAttr(axisSetAttr);
+    std::optional<DenseI16ArrayAttr> canonicalMeshAxesAttr =
+        canonicalizeAxesSetAttribute(
+            op->template getAttrOfType<DenseI16ArrayAttr>(axisSetAttr));
+
+    // A null Attribute stands for "attribute must be absent".
+    Attribute desiredAttr = canonicalMeshAxesAttr
+                                ? Attribute(canonicalMeshAxesAttr.value())
+                                : Attribute();
+
+    // Attributes are uniqued, so equality means nothing would change.
+    // Returning success here would make the driver re-apply the pattern
+    // forever.
+    if (desiredAttr == existingAttr)
+      return failure();
+
+    // In-place op mutation inside a pattern must go through the rewriter.
+    rewriter.updateRootInPlace(op, [&] {
+      if (desiredAttr)
+        op->setAttr(axisSetAttr, desiredAttr);
+      else
+        op->removeAttr(axisSetAttr);
+    });
+    return success();
+  }
+
+  /// Name of the attribute holding the axes set.
+  std::string axisSetAttr;
+};
+
+/// Registers, for operation type `Op`, an `AxesSetCanonicalizationPattern`
+/// bound to the attribute named "mesh_axes".
+template <typename Op>
+void populateMeshAxesSetCanonicalizationPatterns(RewritePatternSet &patterns,
+                                                 MLIRContext *context) {
+  using MeshAxesPattern = AxesSetCanonicalizationPattern<Op>;
+  StringRef meshAxesAttrName = "mesh_axes";
+  patterns.add<MeshAxesPattern>(context, meshAxesAttrName);
+}
+
+template <typename Op>
+LogicalResult verifyMeshSymbolUses(Op op, SymbolTableCollection &symbolTable) {
----------------
joker-eph wrote:
You could also add the specific attributes like `getMeshAttr()` to the API of the non-templated function?
https://github.com/llvm/llvm-project/pull/71960
More information about the Mlir-commits
mailing list