[Mlir-commits] [mlir] [mlir][mpi] Lowering Mpi To LLVM (PR #127053)
Christian Ulmann
llvmlistbot at llvm.org
Mon Feb 17 08:53:58 PST 2025
================
@@ -0,0 +1,157 @@
+#define MPICH_SKIP_MPICXX 1
+#define OMPI_SKIP_MPICXX 1
+#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
+#include "mlir/Dialect/MPI/IR/MPI.h"
+#include "mlir/Transforms/DialectConversion.h"
+#include "llvm/Support/ErrorHandling.h"
+
+// Use the system MPI C header if one was found; otherwise fall back to mpi_fallback.h.
+#ifdef FOUND_MPI_C_HEADER
+#include <mpi.h>
+#else // not FOUND_MPI_C_HEADER
+#include "mpi_fallback.h"
+#endif // FOUND_MPI_C_HEADER
+
+namespace {
+
+// Lowering the MPI dialect to function calls depends on details that differ
+// between MPI implementations (how handles such as communicators and datatypes
+// are represented). This struct supplies those details for whichever MPI
+// implementation's header got included.
+struct MPIImplTraits {
+  /// Returns MPI_COMM_WORLD as an mlir::Value, getting or creating whatever IR
+  /// is needed to materialize it.
+  static mlir::Value getCommWorld(mlir::ModuleOp &moduleOp,
+                                  const mlir::Location loc,
+                                  mlir::ConversionPatternRewriter &rewriter);
+
+  /// Returns, as an mlir::Value, the MPI datatype handle that corresponds to
+  /// the given mlir::Type, getting or creating whatever IR is needed.
+  static mlir::Value getDataType(mlir::ModuleOp &moduleOp,
+                                 const mlir::Location loc,
+                                 mlir::ConversionPatternRewriter &rewriter,
+                                 mlir::Type type);
+};
+
+// ****************************************************************************
+// Intel MPI/MPICH
+#if defined(IMPI_DEVICE_EXPORT) || defined(_MPI_FALLBACK_DEFS)
+
+// Intel MPI/MPICH flavor: MPI_COMM_WORLD is emitted directly as an i32
+// constant, so no module-level symbol is needed and moduleOp is unused.
+mlir::Value
+MPIImplTraits::getCommWorld(mlir::ModuleOp &moduleOp, const mlir::Location loc,
+                            mlir::ConversionPatternRewriter &rewriter) {
+  mlir::Type i32Ty = rewriter.getI32Type();
+  return rewriter.create<mlir::LLVM::ConstantOp>(loc, i32Ty, MPI_COMM_WORLD);
+}
+
+// Intel MPI/MPICH flavor: maps a scalar MLIR type to the matching MPI datatype
+// handle and materializes it as an i32 constant. Signless integers map to the
+// signed MPI integer types. Any other type is a fatal error.
+mlir::Value
+MPIImplTraits::getDataType(mlir::ModuleOp &moduleOp, const mlir::Location loc,
+                           mlir::ConversionPatternRewriter &rewriter,
+                           mlir::Type type) {
+  int32_t mtype = 0;
+  if (type.isF32())
+    mtype = MPI_FLOAT;
+  else if (type.isF64())
+    mtype = MPI_DOUBLE;
+  else if (type.isUnsignedInteger(64))
+    mtype = MPI_UINT64_T;
+  else if (type.isInteger(64)) // signed or signless
+    mtype = MPI_INT64_T;
+  else if (type.isUnsignedInteger(32))
+    mtype = MPI_UINT32_T;
+  else if (type.isInteger(32)) // signed or signless
+    mtype = MPI_INT32_T;
+  else if (type.isUnsignedInteger(16))
+    mtype = MPI_UINT16_T;
+  else if (type.isInteger(16)) // signed or signless
+    mtype = MPI_INT16_T;
+  else if (type.isUnsignedInteger(8))
+    mtype = MPI_UINT8_T;
+  else if (type.isInteger(8)) // signed or signless
+    mtype = MPI_INT8_T;
+  // A bare assert compiles away under NDEBUG. An unsupported type would then
+  // silently yield the placeholder 0 instead of a real datatype handle, so fail
+  // loudly in every build mode.
+  else
+    llvm::report_fatal_error("unsupported type");
+  return rewriter.create<mlir::LLVM::ConstantOp>(loc, rewriter.getI32Type(),
+                                                 mtype);
+}
+
+// ****************************************************************************
+// OpenMPI
+#elif defined(OPEN_MPI) && OPEN_MPI == 1
+
+// TODO: this is pretty close to getOrDefineFunction and can probably be
+// factored out.
+//
+// Returns the LLVM global named `name` if the module already has one.
+// Otherwise declares a non-constant external global of struct type `type` at
+// the start of the module.
+static mlir::LLVM::GlobalOp
+getOrDefineExternalStruct(mlir::ModuleOp &moduleOp, const mlir::Location loc,
+                          mlir::ConversionPatternRewriter &rewriter,
+                          mlir::StringRef name,
+                          mlir::LLVM::LLVMStructType type) {
+  if (auto existing = moduleOp.lookupSymbol<mlir::LLVM::GlobalOp>(name))
+    return existing;
+
+  // Insert at the top of the module; the guard restores the caller's
+  // insertion point when this function returns.
+  mlir::ConversionPatternRewriter::InsertionGuard guard(rewriter);
+  rewriter.setInsertionPointToStart(moduleOp.getBody());
+  return rewriter.create<mlir::LLVM::GlobalOp>(
+      loc, type, /*isConstant=*/false, mlir::LLVM::Linkage::External, name,
+      /*value=*/mlir::Attribute(), /*alignment=*/0, /*addrSpace=*/0);
+}
+
+mlir::Value
+MPIImplTraits::getCommWorld(mlir::ModuleOp &moduleOp, const mlir::Location loc,
+ mlir::ConversionPatternRewriter &rewriter) {
+ auto context = rewriter.getContext();
+ // get external opaque struct pointer type
+ auto commStructT =
+ mlir::LLVM::LLVMStructType::getOpaque("ompi_communicator_t", context);
+ const char *name = "ompi_mpi_comm_world";
----------------
Dinistro wrote:
Nit: Don't use `const char *` in C++ unless you have a very good reason.
```suggestion
StringRef name = "ompi_mpi_comm_world";
```
https://github.com/llvm/llvm-project/pull/127053
More information about the Mlir-commits
mailing list