[Mlir-commits] [mlir] [mlir][mpi] Lowering Mpi To LLVM (PR #127053)

Frank Schlimbach llvmlistbot at llvm.org
Mon Feb 17 12:11:37 PST 2025


================
@@ -0,0 +1,332 @@
+//===- MPIToLLVM.cpp - MPI to LLVM dialect conversion ---------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+// This must go first (MPI gets confused otherwise)
+#include "MPIImplTraits.h"
+
+#include "mlir/Conversion/LLVMCommon/Pattern.h"
+#include "mlir/Conversion/MPIToLLVM/MPIToLLVM.h"
+#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
+#include "mlir/Dialect/MPI/IR/MPI.h"
+#include "mlir/Pass/Pass.h"
+
+#include <mlir/Conversion/ConvertToLLVM/ToLLVMInterface.h>
+
+using namespace mlir;
+
+// TODO: this was copied from GPUOpsLowering.cpp:288
+// is this okay, or should this be moved to some common file?
+/// Returns the `llvm.func` symbol called `name` in `moduleOp`. When the module
+/// has no such symbol, a new function of type `type` with external linkage is
+/// created at the start of the module body and returned instead.
+static LLVM::LLVMFuncOp getOrDefineFunction(ModuleOp &moduleOp,
+                                            const Location loc,
+                                            ConversionPatternRewriter &rewriter,
+                                            StringRef name,
+                                            LLVM::LLVMFunctionType type) {
+  // Reuse the function if the module already contains it.
+  if (auto existing = moduleOp.lookupSymbol<LLVM::LLVMFuncOp>(name))
+    return existing;
+
+  // Otherwise create it at the top of the module. The guard keeps the
+  // caller's insertion point from being disturbed by the temporary move.
+  ConversionPatternRewriter::InsertionGuard guard(rewriter);
+  rewriter.setInsertionPointToStart(moduleOp.getBody());
+  return rewriter.create<LLVM::LLVMFuncOp>(loc, name, type,
+                                           LLVM::Linkage::External);
+}
+
+namespace {
+
+//===----------------------------------------------------------------------===//
+// InitOpLowering
+//===----------------------------------------------------------------------===//
+
+/// Rewrites `mpi.init` into an `llvm.call` to `MPI_Init`, passing a null
+/// pointer for each of the two pointer arguments.
+struct InitOpLowering : public ConvertOpToLLVMPattern<mpi::InitOp> {
+  using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;
+
+  LogicalResult
+  matchAndRewrite(mpi::InitOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    Location loc = op.getLoc();
+
+    // Module that will hold the `MPI_Init` declaration.
+    ModuleOp parentModule = op->getParentOfType<ModuleOp>();
+
+    // Both parameters of MPI_Init are typed as `!llvm.ptr`.
+    Type llvmPtrTy = LLVM::LLVMPointerType::get(rewriter.getContext());
+
+    // `%null = llvm.mlir.zero : !llvm.ptr`, used for both call operands.
+    Value nullArg = rewriter.create<LLVM::ZeroOp>(loc, llvmPtrTy).getRes();
+
+    // Fetch or create the declaration `i32 MPI_Init(ptr, ptr)`.
+    LLVM::LLVMFuncOp mpiInit = getOrDefineFunction(
+        parentModule, loc, rewriter, "MPI_Init",
+        LLVM::LLVMFunctionType::get(rewriter.getI32Type(),
+                                    {llvmPtrTy, llvmPtrTy}));
+
+    // The call takes the place of the original op.
+    rewriter.replaceOpWithNewOp<LLVM::CallOp>(op, mpiInit,
+                                              ValueRange{nullArg, nullArg});
+    return success();
+  }
+};
+
+//===----------------------------------------------------------------------===//
+// FinalizeOpLowering
+//===----------------------------------------------------------------------===//
+
+/// Rewrites `mpi.finalize` into an `llvm.call` to `MPI_Finalize`, which takes
+/// no arguments.
+struct FinalizeOpLowering : public ConvertOpToLLVMPattern<mpi::FinalizeOp> {
+  using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;
+
+  LogicalResult
+  matchAndRewrite(mpi::FinalizeOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    // Module that will hold the `MPI_Finalize` declaration.
+    ModuleOp parentModule = op->getParentOfType<ModuleOp>();
+
+    // Fetch or create the declaration `i32 MPI_Finalize()`.
+    LLVM::LLVMFuncOp mpiFinalize = getOrDefineFunction(
+        parentModule, op.getLoc(), rewriter, "MPI_Finalize",
+        LLVM::LLVMFunctionType::get(rewriter.getI32Type(), {}));
+
+    // The call takes the place of the original op.
+    rewriter.replaceOpWithNewOp<LLVM::CallOp>(op, mpiFinalize, ValueRange{});
+    return success();
+  }
+};
+
+//===----------------------------------------------------------------------===//
+// CommRankOpLowering
+//===----------------------------------------------------------------------===//
+
+struct CommRankOpLowering : public ConvertOpToLLVMPattern<mpi::CommRankOp> {
+  using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;
+
+  LogicalResult
+  matchAndRewrite(mpi::CommRankOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    // get some helper vars
+    auto loc = op.getLoc();
+    auto context = rewriter.getContext();
+    auto i32 = rewriter.getI32Type();
----------------
fschlimb wrote:

See above.

https://github.com/llvm/llvm-project/pull/127053


More information about the Mlir-commits mailing list