[Mlir-commits] [mlir] [mlir][mpi] Lowering Mpi To LLVM (PR #127053)
Christian Ulmann
llvmlistbot at llvm.org
Mon Feb 17 08:53:57 PST 2025
================
@@ -0,0 +1,157 @@
+#define MPICH_SKIP_MPICXX 1
+#define OMPI_SKIP_MPICXX 1
+#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
+#include "mlir/Dialect/MPI/IR/MPI.h"
+#include "mlir/Transforms/DialectConversion.h"
+#include "llvm/Support/ErrorHandling.h"
+
+// Use the MPI C header if one was found; otherwise fall back to mpi_fallback.h.
+#ifdef FOUND_MPI_C_HEADER
+#include <mpi.h>
+#else // not FOUND_MPI_C_HEADER
+#include "mpi_fallback.h"
+#endif // FOUND_MPI_C_HEADER
+
+namespace {
+
+// When lowering the MPI dialect to function calls, certain details differ
+// between MPI implementations. This class provides those details for
+// whichever MPI implementation was included.
+struct MPIImplTraits {
+  /// Materialize, via `rewriter` at `loc`, the MPI datatype handle that
+  /// corresponds to the MLIR element type `type`, and return it as a Value.
+  static mlir::Value getDataType(mlir::ModuleOp &moduleOp,
+                                 const mlir::Location loc,
+                                 mlir::ConversionPatternRewriter &rewriter,
+                                 mlir::Type type);
+
+  /// Get or create the MPI_COMM_WORLD communicator handle of the included
+  /// MPI implementation and return it as a Value.
+  static mlir::Value getCommWorld(mlir::ModuleOp &moduleOp,
+                                  const mlir::Location loc,
+                                  mlir::ConversionPatternRewriter &rewriter);
+};
+
+// ****************************************************************************
+// Intel MPI/MPICH
+#if defined(IMPI_DEVICE_EXPORT) || defined(_MPI_FALLBACK_DEFS)
+
+mlir::Value
+MPIImplTraits::getCommWorld(mlir::ModuleOp &moduleOp, const mlir::Location loc,
+                            mlir::ConversionPatternRewriter &rewriter) {
+  // In this implementation family MPI_COMM_WORLD is an integer-valued handle,
+  // so it is emitted directly as an i32 LLVM constant; `moduleOp` is unused.
+  mlir::Type handleType = rewriter.getI32Type();
+  mlir::Value commWorld = rewriter.create<mlir::LLVM::ConstantOp>(
+      loc, handleType, MPI_COMM_WORLD);
+  return commWorld;
+}
+
+/// Materialize the MPI datatype handle for the MLIR element type `type` as an
+/// i32 LLVM constant at `loc`.
+///
+/// Supported: f32, f64, and 8/16/32/64-bit integers. Unsigned integers map to
+/// MPI_UINTn_T; signed and signless integers map to MPI_INTn_T.
+/// Any other type aborts compilation with a fatal error. The previous
+/// assert-only check fell through in release builds and emitted handle 0,
+/// which is not a valid datatype. `moduleOp` is unused in this variant.
+mlir::Value
+MPIImplTraits::getDataType(mlir::ModuleOp &moduleOp, const mlir::Location loc,
+                           mlir::ConversionPatternRewriter &rewriter,
+                           mlir::Type type) {
+  int32_t mtype = 0;
+  if (type.isF32()) {
+    mtype = MPI_FLOAT;
+  } else if (type.isF64()) {
+    mtype = MPI_DOUBLE;
+  } else if (auto intType = mlir::dyn_cast<mlir::IntegerType>(type)) {
+    // Signless integers are treated as signed.
+    const bool isUnsigned = intType.isUnsigned();
+    switch (intType.getWidth()) {
+    case 64:
+      mtype = isUnsigned ? MPI_UINT64_T : MPI_INT64_T;
+      break;
+    case 32:
+      mtype = isUnsigned ? MPI_UINT32_T : MPI_INT32_T;
+      break;
+    case 16:
+      mtype = isUnsigned ? MPI_UINT16_T : MPI_INT16_T;
+      break;
+    case 8:
+      mtype = isUnsigned ? MPI_UINT8_T : MPI_INT8_T;
+      break;
+    default:
+      llvm::report_fatal_error(
+          "MPI lowering: unsupported integer width for MPI datatype");
+    }
+  } else {
+    // Fail loudly in all build modes instead of emitting an invalid handle.
+    llvm::report_fatal_error("MPI lowering: unsupported element type");
+  }
+  return rewriter.create<mlir::LLVM::ConstantOp>(loc, rewriter.getI32Type(),
+                                                 mtype);
+}
+
+// ****************************************************************************
+// OpenMPI
+#elif defined(OPEN_MPI) && OPEN_MPI == 1
+
+// TODO: this is pretty close to getOrDefineFunction, can probably be factored
----------------
Dinistro wrote:
Nit: Please do
https://github.com/llvm/llvm-project/pull/127053
More information about the Mlir-commits
mailing list