[Mlir-commits] [mlir] [mlir][mpi] Lowering Mpi To LLVM (PR #127053)

Christian Ulmann llvmlistbot at llvm.org
Fri Feb 21 00:29:46 PST 2025


================
@@ -0,0 +1,521 @@
+//===- MPIToLLVM.cpp - MPI to LLVM dialect conversion ---------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+//
+// Copyright (C) by Argonne National Laboratory
+//    See COPYRIGHT in top-level directory
+//    of MPICH source repository.
+//
+
+#include "mlir/Conversion/MPIToLLVM/MPIToLLVM.h"
+#include "mlir/Conversion/ConvertToLLVM/ToLLVMInterface.h"
+#include "mlir/Conversion/LLVMCommon/Pattern.h"
+#include "mlir/Dialect/DLTI/DLTI.h"
+#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
+#include "mlir/Dialect/MPI/IR/MPI.h"
+#include "mlir/Transforms/DialectConversion.h"
+
+using namespace mlir;
+
+namespace {
+
+/// Looks up a symbol of kind `Op` named `name` in `moduleOp`. If none exists,
+/// creates one at the start of the module body by forwarding `args` to the
+/// `Op` builder. The rewriter's insertion point is restored afterwards.
+template <typename Op, typename... Args>
+static Op getOrDefineGlobal(mlir::ModuleOp &moduleOp, const Location loc,
+                            ConversionPatternRewriter &rewriter, StringRef name,
+                            Args &&...args) {
+  // Reuse the existing definition when there is one.
+  if (Op existing = moduleOp.lookupSymbol<Op>(name))
+    return existing;
+
+  // Otherwise emit the definition at the top of the module without disturbing
+  // the caller's insertion point.
+  ConversionPatternRewriter::InsertionGuard guard(rewriter);
+  rewriter.setInsertionPointToStart(moduleOp.getBody());
+  return rewriter.template create<Op>(loc, std::forward<Args>(args)...);
+}
+
+/// Returns the `llvm.func` named `name` in `moduleOp`. If it does not exist
+/// yet, declares it with signature `type` and external linkage.
+static LLVM::LLVMFuncOp getOrDefineFunction(ModuleOp &moduleOp,
+                                            const Location loc,
+                                            ConversionPatternRewriter &rewriter,
+                                            StringRef name,
+                                            LLVM::LLVMFunctionType type) {
+  using FuncOp = LLVM::LLVMFuncOp;
+  // `name` appears twice: first as the lookup key, then as the builder's
+  // symbol-name argument.
+  return getOrDefineGlobal<FuncOp>(moduleOp, loc, rewriter, /*lookup=*/name,
+                                   /*symName=*/name, type,
+                                   LLVM::Linkage::External);
+}
+
+//===----------------------------------------------------------------------===//
+// Implementation details for MPICH ABI compatible MPI implementations
+//===----------------------------------------------------------------------===//
+struct MPICHImplTraits {
+  // Predefined datatype handles, as integer constants used by this lowering.
+  static const int MPI_FLOAT = 0x4c00040a;
+  static const int MPI_DOUBLE = 0x4c00080b;
+  static const int MPI_INT8_T = 0x4c000137;
+  static const int MPI_INT16_T = 0x4c000238;
+  static const int MPI_INT32_T = 0x4c000439;
+  static const int MPI_INT64_T = 0x4c00083a;
+  static const int MPI_UINT8_T = 0x4c00013b;
+  static const int MPI_UINT16_T = 0x4c00023c;
+  static const int MPI_UINT32_T = 0x4c00043d;
+  static const int MPI_UINT64_T = 0x4c00083e;
+
+  /// Materializes the MPI_COMM_WORLD handle as an i32 LLVM constant.
+  /// `moduleOp` is unused here; OMPIImplTraits::getCommWorld takes the same
+  /// parameters.
+  static mlir::Value getCommWorld(mlir::ModuleOp &moduleOp,
+                                  const mlir::Location loc,
+                                  mlir::ConversionPatternRewriter &rewriter) {
+    static const int MPI_COMM_WORLD = 0x44000000;
+    auto i32Type = rewriter.getI32Type();
+    return rewriter.create<mlir::LLVM::ConstantOp>(loc, i32Type,
+                                                   MPI_COMM_WORLD);
+  }
+
+  /// Returns the value this lowering uses for MPI_STATUS_IGNORE.
+  static intptr_t getStatusIgnore() { return 1; }
+
+  /// Maps `type` to the matching datatype handle and materializes it as an
+  /// i32 LLVM constant. f32/f64 map to MPI_FLOAT/MPI_DOUBLE. 8/16/32/64-bit
+  /// integers map to MPI_UINT*_T when explicitly unsigned, and to MPI_INT*_T
+  /// otherwise (signless or signed).
+  static mlir::Value getDataType(mlir::ModuleOp &moduleOp,
+                                 const mlir::Location loc,
+                                 mlir::ConversionPatternRewriter &rewriter,
+                                 mlir::Type type) {
+    int32_t mtype = 0;
+    const bool isUnsigned = type.isUnsignedInteger();
+    if (type.isF32()) {
+      mtype = MPI_FLOAT;
+    } else if (type.isF64()) {
+      mtype = MPI_DOUBLE;
+    } else if (type.isInteger(64)) {
+      if (isUnsigned)
+        mtype = MPI_UINT64_T;
+      else
+        mtype = MPI_INT64_T;
+    } else if (type.isInteger(32)) {
+      if (isUnsigned)
+        mtype = MPI_UINT32_T;
+      else
+        mtype = MPI_INT32_T;
+    } else if (type.isInteger(16)) {
+      if (isUnsigned)
+        mtype = MPI_UINT16_T;
+      else
+        mtype = MPI_INT16_T;
+    } else if (type.isInteger(8)) {
+      if (isUnsigned)
+        mtype = MPI_UINT8_T;
+      else
+        mtype = MPI_INT8_T;
+    } else {
+      // NOTE(review): under NDEBUG this assert is compiled out and `mtype`
+      // stays 0 for unsupported types.
+      assert(false && "unsupported type");
+    }
+    return rewriter.create<mlir::LLVM::ConstantOp>(loc, rewriter.getI32Type(),
+                                                   mtype);
+  }
+};
+
+//===----------------------------------------------------------------------===//
+// Implementation details for OpenMPI
+//===----------------------------------------------------------------------===//
+struct OMPIImplTraits {
+
+  /// Returns the external `llvm.mlir.global` named `name` of struct type
+  /// `type`. If it does not exist, declares it at the start of `moduleOp`
+  /// (non-constant, external linkage, no initializer).
+  static mlir::LLVM::GlobalOp
+  getOrDefineExternalStruct(mlir::ModuleOp &moduleOp, const mlir::Location loc,
+                            mlir::ConversionPatternRewriter &rewriter,
+                            mlir::StringRef name,
+                            mlir::LLVM::LLVMStructType type) {
+
+    return getOrDefineGlobal<mlir::LLVM::GlobalOp>(
+        moduleOp, loc, rewriter, name, type, /*isConstant=*/false,
+        mlir::LLVM::Linkage::External, name,
+        /*value=*/mlir::Attribute(), /*alignment=*/0, /*addrSpace=*/0);
+  }
+
+  /// Ensures an external global `name` of opaque struct type `structName`
+  /// exists in `moduleOp`, then returns its address (an `llvm.ptr`) via
+  /// `llvm.mlir.addressof`. The struct layout is never needed, only the
+  /// global's address.
+  static mlir::Value
+  getAddressOfExternalStruct(mlir::ModuleOp &moduleOp, const mlir::Location loc,
+                             mlir::ConversionPatternRewriter &rewriter,
+                             mlir::StringRef name, mlir::StringRef structName) {
+    auto context = rewriter.getContext();
+    auto structT = mlir::LLVM::LLVMStructType::getOpaque(structName, context);
+    // make sure global op definition exists
+    (void)getOrDefineExternalStruct(moduleOp, loc, rewriter, name, structT);
+    // get address of symbol
+    return rewriter.create<mlir::LLVM::AddressOfOp>(
+        loc, mlir::LLVM::LLVMPointerType::get(context),
+        mlir::SymbolRefAttr::get(context, name));
+  }
+
+  /// Returns a pointer to the external `ompi_mpi_comm_world` global
+  /// (opaque `ompi_communicator_t`).
+  static mlir::Value getCommWorld(mlir::ModuleOp &moduleOp,
+                                  const mlir::Location loc,
+                                  mlir::ConversionPatternRewriter &rewriter) {
+    return getAddressOfExternalStruct(moduleOp, loc, rewriter,
+                                      "ompi_mpi_comm_world",
+                                      "ompi_communicator_t");
+  }
+
+  /// Returns the value this lowering uses for MPI_STATUS_IGNORE.
+  static intptr_t getStatusIgnore() { return 0; }
+
+  /// Maps `type` to the matching predefined datatype global
+  /// (opaque `ompi_predefined_datatype_t`) and returns a pointer to it.
+  /// Integers map to the unsigned variant only when explicitly unsigned;
+  /// signless and signed integers map to the signed variant.
+  static mlir::Value getDataType(mlir::ModuleOp &moduleOp,
+                                 const mlir::Location loc,
+                                 mlir::ConversionPatternRewriter &rewriter,
+                                 mlir::Type type) {
+    mlir::StringRef mtype;
+    if (type.isF32())
+      mtype = "ompi_mpi_float";
+    else if (type.isF64())
+      mtype = "ompi_mpi_double";
+    else if (type.isInteger(64) && !type.isUnsignedInteger())
+      mtype = "ompi_mpi_int64_t";
+    else if (type.isInteger(64))
+      mtype = "ompi_mpi_uint64_t";
+    else if (type.isInteger(32) && !type.isUnsignedInteger())
+      mtype = "ompi_mpi_int32_t";
+    else if (type.isInteger(32))
+      mtype = "ompi_mpi_uint32_t";
+    else if (type.isInteger(16) && !type.isUnsignedInteger())
+      mtype = "ompi_mpi_int16_t";
+    else if (type.isInteger(16))
+      mtype = "ompi_mpi_uint16_t";
+    else if (type.isInteger(8) && !type.isUnsignedInteger())
+      mtype = "ompi_mpi_int8_t";
+    else if (type.isInteger(8))
+      mtype = "ompi_mpi_uint8_t";
+    else
+      assert(false && "unsupported type");
+
+    // NOTE(review): under NDEBUG the assert above is compiled out and `mtype`
+    // stays empty, so an unsupported type would declare a global with an
+    // empty symbol name.
+    return getAddressOfExternalStruct(moduleOp, loc, rewriter, mtype,
+                                      "ompi_predefined_datatype_t");
+  }
+};
+
+//===----------------------------------------------------------------------===//
+// When lowering the mpi dialect to function calls, certain details
+// differ between MPI implementations. This class provides them
+// generically, based on the MPI implementation selected by the DLTI
+// attribute on the module.
+//===----------------------------------------------------------------------===//
+struct MPIImplTraits {
+  enum MPIImpl { MPICH, OMPI };
+
+  // Get the MPI implementation from a DLTI attribute on the module.
+  // Default to MPICH (and ABI compatible).
----------------
Dinistro wrote:

```suggestion
  /// Gets the MPI implementation from a DLTI attribute on the module.
  /// Defaults to MPICH (and ABI compatible).
```
Nit: In general, comments on classes and members should use `///` for doxygen to pick them up. 

https://github.com/llvm/llvm-project/pull/127053


More information about the Mlir-commits mailing list