[Mlir-commits] [mlir] [mlir][mesh] Add collective communication operations (PR #71960)

Mehdi Amini llvmlistbot at llvm.org
Wed Nov 15 01:30:49 PST 2023


================
@@ -129,6 +205,347 @@ MeshShardingAttr::verify(function_ref<InFlightDiagnostic()> emitError,
   return success();
 }
 
+//===----------------------------------------------------------------------===//
+// collective communication ops
+//===----------------------------------------------------------------------===//
+
+namespace {
+
+std::optional<DenseI16ArrayAttr>
+canonicalizeAxesSetAttribute(DenseI16ArrayAttr attr) {
+  if (!attr) {
+    return std::nullopt;
+  }
+  SmallVector<int16_t> axes = llvm::to_vector(attr.asArrayRef());
+  canonicalizeSetAsVector(axes);
+  if (axes.empty()) {
+    return std::nullopt;
+  }
+  return DenseI16ArrayAttr::get(attr.getContext(), axes);
+}
+
+template <typename Op>
+struct AxesSetCanonicalizationPattern : OpRewritePattern<Op> {
+  AxesSetCanonicalizationPattern(MLIRContext *context, StringRef axisSetAttr)
+      : OpRewritePattern<Op>(context), axisSetAttr(axisSetAttr) {}
+  LogicalResult matchAndRewrite(Op op,
+                                PatternRewriter &rewriter) const override {
+    auto canonicalMeshAxesAttr = canonicalizeAxesSetAttribute(
+        op->template getAttrOfType<DenseI16ArrayAttr>(axisSetAttr));
+    if (!canonicalMeshAxesAttr) {
+      op->removeAttr(axisSetAttr);
+    } else {
+      op->setAttr(axisSetAttr, canonicalMeshAxesAttr.value());
+    }
+    return success();
+  }
+
+  std::string axisSetAttr;
+};
+
+template <typename Op>
+void populateMeshAxesSetCanonicalizationPatterns(RewritePatternSet &patterns,
+                                                 MLIRContext *context) {
+  patterns.add<AxesSetCanonicalizationPattern<Op>>(context, "mesh_axes");
+}
+
+template <typename Op>
+LogicalResult verifyMeshSymbolUses(Op op, SymbolTableCollection &symbolTable) {
----------------
joker-eph wrote:

You could also add the specific attributes, like `getMeshAttr()`, to the API of the non-templated function?

https://github.com/llvm/llvm-project/pull/71960


More information about the Mlir-commits mailing list