[Mlir-commits] [mlir] [mlir][mesh, mpi] Lower allreduce (PR #144060)

Christian Ulmann llvmlistbot at llvm.org
Mon Jun 16 04:20:32 PDT 2025


================
@@ -529,6 +519,124 @@ struct ConvertShardShapeOp : public OpConversionPattern<ShardShapeOp> {
   }
 };
 
+/// Translates a mesh `ReductionKindAttr` into the matching MPI reduction
+/// operation attribute, created in the context of `kind`.
+///
+/// Sum, Product, Min, Max, BitwiseAnd, BitwiseOr and BitwiseXor are mapped to
+/// MPI_SUM, MPI_PROD, MPI_MIN, MPI_MAX, MPI_BAND, MPI_BOR and MPI_BXOR
+/// respectively. Any other kind trips the assertion in asserts-enabled builds
+/// and yields a null attribute otherwise.
+static mpi::MPI_OpClassEnumAttr getMPIReduction(ReductionKindAttr kind) {
+  auto ctx = kind.getContext();
+  switch (kind.getValue()) {
+  case ReductionKind::Sum:
+    return mpi::MPI_OpClassEnumAttr::get(ctx, mpi::MPI_OpClassEnum::MPI_SUM);
+  case ReductionKind::Product:
+    return mpi::MPI_OpClassEnumAttr::get(ctx, mpi::MPI_OpClassEnum::MPI_PROD);
+  case ReductionKind::Min:
+    return mpi::MPI_OpClassEnumAttr::get(ctx, mpi::MPI_OpClassEnum::MPI_MIN);
+  case ReductionKind::Max:
+    return mpi::MPI_OpClassEnumAttr::get(ctx, mpi::MPI_OpClassEnum::MPI_MAX);
+  case ReductionKind::BitwiseAnd:
+    return mpi::MPI_OpClassEnumAttr::get(ctx, mpi::MPI_OpClassEnum::MPI_BAND);
+  case ReductionKind::BitwiseOr:
+    return mpi::MPI_OpClassEnumAttr::get(ctx, mpi::MPI_OpClassEnum::MPI_BOR);
+  case ReductionKind::BitwiseXor:
+    return mpi::MPI_OpClassEnumAttr::get(ctx, mpi::MPI_OpClassEnum::MPI_BXOR);
+  default:
+    assert(false && "Unknown/unsupported reduction kind");
+    // The assert disappears when NDEBUG is defined; return explicitly so that
+    // control never flows off the end of this non-void function (which would
+    // be undefined behaviour).
+    return mpi::MPI_OpClassEnumAttr();
+  }
+}
+
+struct ConvertAllReduceOp : public OpConversionPattern<AllReduceOp> {
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(AllReduceOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    SymbolTableCollection symbolTableCollection;
+    auto mesh = adaptor.getMesh();
+    auto meshOp = getMesh(op, symbolTableCollection);
+    if (!meshOp)
+      return op->emitError() << "No mesh found for AllReduceOp";
+    if (ShapedType::isDynamicShape(meshOp.getShape()))
+      return op->emitError()
+             << "Dynamic mesh shape not supported in AllReduceOp";
+
+    ImplicitLocOpBuilder iBuilder(op.getLoc(), rewriter);
----------------
Dinistro wrote:

Using a normal builder is not legal; you need to use the provided rewriter for all IR manipulations.

https://github.com/llvm/llvm-project/pull/144060


More information about the Mlir-commits mailing list