[Mlir-commits] [mlir] [mlir][mpi] Lowering Mpi To LLVM (PR #127053)

Frank Schlimbach llvmlistbot at llvm.org
Tue Feb 18 02:55:32 PST 2025


================
@@ -0,0 +1,157 @@
+#define MPICH_SKIP_MPICXX 1
+#define OMPI_SKIP_MPICXX 1
+#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
+#include "mlir/Dialect/MPI/IR/MPI.h"
+#include "mlir/Transforms/DialectConversion.h"
+#include "llvm/Support/ErrorHandling.h"
+
+// skip if no MPI C header was found
+#ifdef FOUND_MPI_C_HEADER
+#include <mpi.h>
+#else // not FOUND_MPI_C_HEADER
+#include "mpi_fallback.h"
+#endif // FOUND_MPI_C_HEADER
+
+namespace {
+
+// When the MPI dialect is lowered to function calls, certain details (such as
+// how communicator and datatype handles are represented) differ between MPI
+// implementations. MPIImplTraits provides those details for whichever MPI
+// header was included above (a real <mpi.h> or "mpi_fallback.h").
+struct MPIImplTraits {
+  /// Get or create MPI_COMM_WORLD as an mlir::Value at \p loc.
+  static mlir::Value getCommWorld(mlir::ModuleOp &moduleOp,
+                                  const mlir::Location loc,
+                                  mlir::ConversionPatternRewriter &rewriter);
+
+  /// Get or create, as an mlir::Value at \p loc, the MPI datatype handle
+  /// that corresponds to the MLIR element type \p type.
+  static mlir::Value getDataType(mlir::ModuleOp &moduleOp,
+                                 const mlir::Location loc,
+                                 mlir::ConversionPatternRewriter &rewriter,
+                                 mlir::Type type);
+};
+
+// ****************************************************************************
+// Intel MPI/MPICH
+#if defined(IMPI_DEVICE_EXPORT) || defined(_MPI_FALLBACK_DEFS)
+
+/// Intel MPI / MPICH / fallback-header variant: MPI_COMM_WORLD is emitted as
+/// an i32 LLVM constant whose value is the MPI_COMM_WORLD macro from the
+/// included header. \p moduleOp is unused in this variant.
+/// NOTE(review): this assumes MPI_Comm is a 32-bit integer handle in these
+/// implementations — confirm against each supported header.
+mlir::Value
+MPIImplTraits::getCommWorld(mlir::ModuleOp &moduleOp, const mlir::Location loc,
+                            mlir::ConversionPatternRewriter &rewriter) {
+  auto i32Type = rewriter.getI32Type();
+  auto commWorld =
+      rewriter.create<mlir::LLVM::ConstantOp>(loc, i32Type, MPI_COMM_WORLD);
+  return commWorld;
+}
+
+/// Intel MPI / MPICH / fallback-header variant: the MPI datatype handle that
+/// matches the element type \p type is emitted as an i32 LLVM constant.
+///
+/// Supported element types are f32, f64 and 8/16/32/64-bit integers.
+/// Unsigned integers map to MPI_UINT<N>_T; signed and signless integers map
+/// to MPI_INT<N>_T. Any other type aborts via llvm::report_fatal_error in all
+/// build modes, instead of emitting an invalid handle (0) when asserts are
+/// compiled out. \p moduleOp is unused in this variant.
+///
+/// NOTE(review): ideally an unsupported type would make the calling pattern
+/// fail to match, but this helper's signature cannot report failure.
+mlir::Value
+MPIImplTraits::getDataType(mlir::ModuleOp &moduleOp, const mlir::Location loc,
+                           mlir::ConversionPatternRewriter &rewriter,
+                           mlir::Type type) {
+  int32_t mtype = 0;
+  if (type.isF32()) {
+    mtype = MPI_FLOAT;
+  } else if (type.isF64()) {
+    mtype = MPI_DOUBLE;
+  } else if (auto intType = mlir::dyn_cast<mlir::IntegerType>(type)) {
+    // Only explicitly unsigned integers use the unsigned MPI types; signless
+    // and signed integers are both treated as signed.
+    const bool isUnsigned = intType.isUnsigned();
+    switch (intType.getWidth()) {
+    case 64:
+      mtype = isUnsigned ? MPI_UINT64_T : MPI_INT64_T;
+      break;
+    case 32:
+      mtype = isUnsigned ? MPI_UINT32_T : MPI_INT32_T;
+      break;
+    case 16:
+      mtype = isUnsigned ? MPI_UINT16_T : MPI_INT16_T;
+      break;
+    case 8:
+      mtype = isUnsigned ? MPI_UINT8_T : MPI_INT8_T;
+      break;
+    default:
+      llvm::report_fatal_error("MPI lowering: unsupported integer width");
+    }
+  } else {
+    llvm::report_fatal_error("MPI lowering: unsupported element type");
+  }
+  return rewriter.create<mlir::LLVM::ConstantOp>(loc, rewriter.getI32Type(),
+                                                 mtype);
+}
+
+// ****************************************************************************
+// OpenMPI
+#elif defined(OPEN_MPI) && OPEN_MPI == 1
+
+// TODO: this is pretty close to getOrDefineFunction, can probably be factored
----------------
fschlimb wrote:

done

https://github.com/llvm/llvm-project/pull/127053


More information about the Mlir-commits mailing list