[Mlir-commits] [mlir] [mlir][mesh, mpi] Lower allreduce (PR #144060)
Christian Ulmann
llvmlistbot at llvm.org
Mon Jun 16 04:20:32 PDT 2025
================
@@ -529,6 +519,124 @@ struct ConvertShardShapeOp : public OpConversionPattern<ShardShapeOp> {
}
};
+/// Maps a mesh reduction kind attribute to the equivalent MPI reduction
+/// operation attribute, created in the same MLIRContext as `kind`.
+/// Only Sum, Product, Min, Max, BitwiseAnd, BitwiseOr and BitwiseXor are
+/// mapped; callers are expected to pass only one of these kinds.
+static mpi::MPI_OpClassEnumAttr getMPIReduction(ReductionKindAttr kind) {
+  // Spell out the type: it is not visible on the right-hand side.
+  MLIRContext *ctx = kind.getContext();
+  auto makeOp = [ctx](mpi::MPI_OpClassEnum op) {
+    return mpi::MPI_OpClassEnumAttr::get(ctx, op);
+  };
+
+  switch (kind.getValue()) {
+  case ReductionKind::Sum:
+    return makeOp(mpi::MPI_OpClassEnum::MPI_SUM);
+  case ReductionKind::Product:
+    return makeOp(mpi::MPI_OpClassEnum::MPI_PROD);
+  case ReductionKind::Min:
+    return makeOp(mpi::MPI_OpClassEnum::MPI_MIN);
+  case ReductionKind::Max:
+    return makeOp(mpi::MPI_OpClassEnum::MPI_MAX);
+  case ReductionKind::BitwiseAnd:
+    return makeOp(mpi::MPI_OpClassEnum::MPI_BAND);
+  case ReductionKind::BitwiseOr:
+    return makeOp(mpi::MPI_OpClassEnum::MPI_BOR);
+  case ReductionKind::BitwiseXor:
+    return makeOp(mpi::MPI_OpClassEnum::MPI_BXOR);
+  default:
+    // llvm_unreachable instead of assert(false): with NDEBUG an assert is
+    // compiled out and control would fall off the end of this non-void
+    // function, which is undefined behavior.
+    llvm_unreachable("Unknown/unsupported reduction kind");
+  }
+}
+
+struct ConvertAllReduceOp : public OpConversionPattern<AllReduceOp> {
+ using OpConversionPattern::OpConversionPattern;
+
+ LogicalResult
+ matchAndRewrite(AllReduceOp op, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ SymbolTableCollection symbolTableCollection;
+ auto mesh = adaptor.getMesh();
+ auto meshOp = getMesh(op, symbolTableCollection);
----------------
Dinistro wrote:
Nit: Avoid auto here, as the types aren't given on the RHS.
https://github.com/llvm/llvm-project/pull/144060
More information about the Mlir-commits
mailing list