[Mlir-commits] [mlir] [mlir][mpi] Lowering Mpi To LLVM (PR #127053)
Christian Ulmann
llvmlistbot at llvm.org
Fri Feb 21 07:39:33 PST 2025
================
@@ -0,0 +1,501 @@
+//===- MPIToLLVM.cpp - MPI to LLVM dialect conversion ---------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+//
+// Copyright (C) by Argonne National Laboratory
+// See COPYRIGHT in top-level directory
+// of MPICH source repository.
+//
+
+#include "mlir/Conversion/MPIToLLVM/MPIToLLVM.h"
+#include "mlir/Conversion/ConvertToLLVM/ToLLVMInterface.h"
+#include "mlir/Conversion/LLVMCommon/Pattern.h"
+#include "mlir/Dialect/DLTI/DLTI.h"
+#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
+#include "mlir/Dialect/MPI/IR/MPI.h"
+#include "mlir/Transforms/DialectConversion.h"
+#include <memory>
+
+using namespace mlir;
+
+namespace {
+
+/// Returns the op of type `Op` registered under symbol `name` in `moduleOp`.
+/// If there is none, a new `Op` is built at the start of the module body by
+/// forwarding `args` to its builder; the rewriter's previous insertion point
+/// is restored afterwards.
+///
+/// `moduleOp` is taken by value: op classes such as ModuleOp are thin handles,
+/// so copying is cheap and callers may pass temporaries (e.g. the result of
+/// `getParentOfType<ModuleOp>()`).
+///
+/// NOTE(review): `lookupSymbol<Op>` presumably yields null when `name` is
+/// bound to an op of a *different* type, in which case a second symbol with
+/// the same name would be created — confirm callers never reuse names across
+/// op kinds.
+template <typename Op, typename... Args>
+static Op getOrDefineGlobal(ModuleOp moduleOp, const Location loc,
+                            ConversionPatternRewriter &rewriter, StringRef name,
+                            Args &&...args) {
+  if (Op existing = moduleOp.lookupSymbol<Op>(name))
+    return existing;
+
+  // Emit the new definition at the top of the module without disturbing the
+  // insertion point of the pattern that is currently being applied.
+  ConversionPatternRewriter::InsertionGuard guard(rewriter);
+  rewriter.setInsertionPointToStart(moduleOp.getBody());
+  return rewriter.create<Op>(loc, std::forward<Args>(args)...);
+}
+
+/// Returns the `LLVM::LLVMFuncOp` named `name` in `moduleOp`. If it does not
+/// exist yet, it is created at the start of the module with signature `type`
+/// and external linkage.
+///
+/// `moduleOp` is taken by value (ModuleOp is a cheap handle), which also lets
+/// callers pass temporaries.
+static LLVM::LLVMFuncOp getOrDefineFunction(ModuleOp moduleOp,
+                                            const Location loc,
+                                            ConversionPatternRewriter &rewriter,
+                                            StringRef name,
+                                            LLVM::LLVMFunctionType type) {
+  // `name` appears twice: first as the symbol to look up, then as the first
+  // builder argument (the function's symbol name) for the new op.
+  return getOrDefineGlobal<LLVM::LLVMFuncOp>(
+      moduleOp, loc, rewriter, name, name, type, LLVM::Linkage::External);
+}
+
+/// When lowering the mpi dialect to function calls, certain details differ
+/// between the various MPI implementations. This class provides these in a
+/// generic way, depending on the MPI implementation that got selected by the
+/// DLTI attribute on the module.
+class MPIImplTraits {
+  // Stored by value: op classes like ModuleOp are thin handles, whereas a
+  // reference member would dangle if the ModuleOp variable it was bound to
+  // went out of scope while this object is still alive.
+  ModuleOp moduleOp;
+
+public:
+  /// Instantiate a new MPIImplTraits object according to the DLTI attribute
+  /// on the given module.
+  static std::unique_ptr<MPIImplTraits> get(ModuleOp &moduleOp);
+
+  MPIImplTraits(ModuleOp moduleOp) : moduleOp(moduleOp) {}
+
+  // This class is abstract, so every object is of a derived implementation
+  // type, yet ownership is expressed as std::unique_ptr<MPIImplTraits> (see
+  // get()). The destructor must therefore be virtual for deletion through the
+  // base pointer to be well-defined.
+  virtual ~MPIImplTraits() = default;
+
+  ModuleOp &getModuleOp() { return moduleOp; }
+
+  /// Gets or creates MPI_COMM_WORLD as a Value.
+  virtual Value getCommWorld(const Location loc,
+                             ConversionPatternRewriter &rewriter) = 0;
+
+  /// Get the MPI_STATUS_IGNORE value (typically a pointer type).
+  virtual intptr_t getStatusIgnore() = 0;
+
+  /// Get/create the MPI datatype, as a Value, that corresponds to the given
+  /// Type.
+  virtual Value getDataType(const Location loc,
+                            ConversionPatternRewriter &rewriter, Type type) = 0;
+};
+
+//===----------------------------------------------------------------------===//
+// Implementation details for MPICH ABI compatible MPI implementations
+//===----------------------------------------------------------------------===//
+
+/// MPIImplTraits for MPICH-ABI-compatible libraries. Communicator and datatype
+/// handles are fixed integer values and are materialized as i32 constants.
+class MPICHImplTraits : public MPIImplTraits {
+  // Predefined handle values.
+  static constexpr int MPI_COMM_WORLD = 0x44000000;
+
+  static constexpr int MPI_FLOAT = 0x4c00040a;
+  static constexpr int MPI_DOUBLE = 0x4c00080b;
+
+  static constexpr int MPI_INT8_T = 0x4c000137;
+  static constexpr int MPI_INT16_T = 0x4c000238;
+  static constexpr int MPI_INT32_T = 0x4c000439;
+  static constexpr int MPI_INT64_T = 0x4c00083a;
+
+  static constexpr int MPI_UINT8_T = 0x4c00013b;
+  static constexpr int MPI_UINT16_T = 0x4c00023c;
+  static constexpr int MPI_UINT32_T = 0x4c00043d;
+  static constexpr int MPI_UINT64_T = 0x4c00083e;
+
+public:
+  using MPIImplTraits::MPIImplTraits;
+
+  /// Emits the MPI_COMM_WORLD handle as an i32 constant.
+  Value getCommWorld(const Location loc,
+                     ConversionPatternRewriter &rewriter) override {
+    auto i32Ty = rewriter.getI32Type();
+    return rewriter.create<LLVM::ConstantOp>(loc, i32Ty, MPI_COMM_WORLD);
+  }
+
+  /// MPI_STATUS_IGNORE is represented by the value 1 here.
+  intptr_t getStatusIgnore() override { return 1; }
+
+  /// Emits, as an i32 constant, the datatype handle matching `type`.
+  /// Supported types are f32, f64 and 8/16/32/64-bit integers. An integer maps
+  /// to the unsigned MPI type only if `type` is an unsigned integer type;
+  /// signless and signed integers map to the signed MPI type.
+  ///
+  /// NOTE(review): any other type only trips the assertion. With assertions
+  /// disabled, the emitted handle silently becomes 0 — consider reporting a
+  /// failure instead once the interface allows it.
+  Value getDataType(const Location loc, ConversionPatternRewriter &rewriter,
+                    Type type) override {
+    const bool isUnsigned = type.isUnsignedInteger();
+    int32_t mtype = 0;
+    if (type.isF32()) {
+      mtype = MPI_FLOAT;
+    } else if (type.isF64()) {
+      mtype = MPI_DOUBLE;
+    } else if (type.isInteger(64)) {
+      mtype = isUnsigned ? MPI_UINT64_T : MPI_INT64_T;
+    } else if (type.isInteger(32)) {
+      mtype = isUnsigned ? MPI_UINT32_T : MPI_INT32_T;
+    } else if (type.isInteger(16)) {
+      mtype = isUnsigned ? MPI_UINT16_T : MPI_INT16_T;
+    } else if (type.isInteger(8)) {
+      mtype = isUnsigned ? MPI_UINT8_T : MPI_INT8_T;
+    } else {
+      assert(false && "unsupported type");
+    }
+
+    auto i32Ty = rewriter.getI32Type();
+    return rewriter.create<LLVM::ConstantOp>(loc, i32Ty, mtype);
+  }
+};
+
+//===----------------------------------------------------------------------===//
+// Implementation details for OpenMPI
+//===----------------------------------------------------------------------===//
+class OMPIImplTraits : public MPIImplTraits {
+  /// Returns the module's external, non-constant LLVM global `name` of struct
+  /// type `type`, creating it at the start of the module if it does not exist
+  /// yet. The global has no initializer, default alignment and address
+  /// space 0.
+  LLVM::GlobalOp getOrDefineExternalStruct(const Location loc,
+                                           ConversionPatternRewriter &rewriter,
+                                           StringRef name,
+                                           LLVM::LLVMStructType type) {
+    ModuleOp &mod = getModuleOp();
+    return getOrDefineGlobal<LLVM::GlobalOp>(
+        mod, loc, rewriter, /*lookup name=*/name, type, /*isConstant=*/false,
+        LLVM::Linkage::External, /*symbol name=*/name, /*value=*/Attribute(),
+        /*alignment=*/0, /*addrSpace=*/0);
+  }
+
+public:
+ using MPIImplTraits::MPIImplTraits;
+
+  /// OpenMPI's MPI_COMM_WORLD is the address of the external global
+  /// `ompi_mpi_comm_world`, whose type is the opaque struct
+  /// `ompi_communicator_t`. The global is declared once per module; the
+  /// returned Value is a pointer to it.
+  Value getCommWorld(const Location loc,
+                     ConversionPatternRewriter &rewriter) override {
+    MLIRContext *ctx = rewriter.getContext();
+    StringRef symName = "ompi_mpi_comm_world";
+
+    // Ensure the external declaration exists; only its symbol is needed here.
+    auto opaqueCommT =
+        LLVM::LLVMStructType::getOpaque("ompi_communicator_t", ctx);
+    getOrDefineExternalStruct(loc, rewriter, symName, opaqueCommT);
+
+    auto ptrTy = LLVM::LLVMPointerType::get(ctx);
+    auto symRef = SymbolRefAttr::get(ctx, symName);
+    return rewriter.create<LLVM::AddressOfOp>(loc, ptrTy, symRef);
+  }
+
+  /// MPI_STATUS_IGNORE is represented by the value 0 (a null pointer value)
+  /// for OpenMPI.
+  intptr_t getStatusIgnore() override {
+    constexpr intptr_t statusIgnore = 0;
+    return statusIgnore;
+  }
+
+ Value getDataType(const Location loc, ConversionPatternRewriter &rewriter,
+ Type type) override {
+ StringRef mtype;
+ if (type.isF32())
+ mtype = "ompi_mpi_float";
+ else if (type.isF64())
+ mtype = "ompi_mpi_double";
+ else if (type.isInteger(64) && !type.isUnsignedInteger())
+ mtype = "ompi_mpi_int64_t";
+ else if (type.isInteger(64))
+ mtype = "ompi_mpi_uint64_t";
+ else if (type.isInteger(32) && !type.isUnsignedInteger())
+ mtype = "ompi_mpi_int32_t";
+ else if (type.isInteger(32))
+ mtype = "ompi_mpi_uint32_t";
+ else if (type.isInteger(16) && !type.isUnsignedInteger())
+ mtype = "ompi_mpi_int16_t";
+ else if (type.isInteger(16))
+ mtype = "ompi_mpi_uint16_t";
+ else if (type.isInteger(8) && !type.isUnsignedInteger())
+ mtype = "ompi_mpi_int8_t";
+ else if (type.isInteger(8))
+ mtype = "ompi_mpi_uint8_t";
+ else
+ assert(false && "unsupported type");
+
+ auto context = rewriter.getContext();
+ // get external opaque struct pointer type
+ auto commStructT =
+ LLVM::LLVMStructType::getOpaque("ompi_predefined_datatype_t", context);
+ // make sure global op definition exists
+ (void)getOrDefineExternalStruct(loc, rewriter, mtype, commStructT);
----------------
Dinistro wrote:
```suggestion
getOrDefineExternalStruct(loc, rewriter, mtype, commStructT);
```
https://github.com/llvm/llvm-project/pull/127053
More information about the Mlir-commits
mailing list