[Mlir-commits] [mlir] [MLIR][MPI] Add first part of an `convert-mpi-to-llvm` lowering (PR #95524)
Anton Lydike
llvmlistbot at llvm.org
Fri Jun 14 03:31:35 PDT 2024
https://github.com/AntonLydike created https://github.com/llvm/llvm-project/pull/95524
The first set of patterns to convert the MPI dialect to LLVM.
Further conversion patterns will be added in future PRs.
>From 9d98c6bdd4495a7f3917ee86f2273a13967e4a6d Mon Sep 17 00:00:00 2001
From: Anton Lydike <me at antonlydike.de>
Date: Fri, 14 Jun 2024 11:26:26 +0100
Subject: [PATCH] add initial set of lowerings for MPI dialect
---
.../mlir/Conversion/MPIToLLVM/MPIToLLVM.h | 28 ++
mlir/include/mlir/Conversion/Passes.h | 1 +
mlir/include/mlir/Conversion/Passes.td | 18 ++
mlir/include/mlir/Dialect/MPI/IR/MPITypes.td | 2 +-
mlir/lib/Conversion/CMakeLists.txt | 1 +
mlir/lib/Conversion/MPIToLLVM/CMakeLists.txt | 17 ++
mlir/lib/Conversion/MPIToLLVM/MPIToLLVM.cpp | 249 ++++++++++++++++++
mlir/test/Conversion/MPIToLLVM/ops.mlir | 39 +++
8 files changed, 354 insertions(+), 1 deletion(-)
create mode 100644 mlir/include/mlir/Conversion/MPIToLLVM/MPIToLLVM.h
create mode 100644 mlir/lib/Conversion/MPIToLLVM/CMakeLists.txt
create mode 100644 mlir/lib/Conversion/MPIToLLVM/MPIToLLVM.cpp
create mode 100644 mlir/test/Conversion/MPIToLLVM/ops.mlir
diff --git a/mlir/include/mlir/Conversion/MPIToLLVM/MPIToLLVM.h b/mlir/include/mlir/Conversion/MPIToLLVM/MPIToLLVM.h
new file mode 100644
index 0000000000000..181e3c3e72b3f
--- /dev/null
+++ b/mlir/include/mlir/Conversion/MPIToLLVM/MPIToLLVM.h
@@ -0,0 +1,28 @@
+//===- MPIToLLVM.h - MPI to LLVM dialect conversion -------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_CONVERSION_MPITOLLVM_H
+#define MLIR_CONVERSION_MPITOLLVM_H
+
+#include <memory>
+
+namespace mlir {
+
+class LLVMTypeConverter;
+class RewritePatternSet;
+class Pass;
+
+// Pulls in the TableGen-generated declarations for the
+// `MPIToLLVMConversionPass` defined in Passes.td (`convert-mpi-to-llvm`).
+#define GEN_PASS_DECL_MPITOLLVMCONVERSIONPASS
+#include "mlir/Conversion/Passes.h.inc"
+
+namespace mpi {
+/// Adds to `patterns` the patterns that lower MPI dialect operations to
+/// function calls in the LLVM dialect. `converter` is the LLVM type converter
+/// the patterns are constructed with.
+void populateMPIToLLVMConversionPatterns(LLVMTypeConverter &converter,
+ RewritePatternSet &patterns);
+} // namespace mpi
+} // namespace mlir
+
+#endif // MLIR_CONVERSION_MPITOLLVM_H
diff --git a/mlir/include/mlir/Conversion/Passes.h b/mlir/include/mlir/Conversion/Passes.h
index 7700299b3a4f3..4b4e40d4f7463 100644
--- a/mlir/include/mlir/Conversion/Passes.h
+++ b/mlir/include/mlir/Conversion/Passes.h
@@ -42,6 +42,7 @@
#include "mlir/Conversion/IndexToLLVM/IndexToLLVM.h"
#include "mlir/Conversion/IndexToSPIRV/IndexToSPIRV.h"
#include "mlir/Conversion/LinalgToStandard/LinalgToStandard.h"
+#include "mlir/Conversion/MPIToLLVM/MPIToLLVM.h"
#include "mlir/Conversion/MathToFuncs/MathToFuncs.h"
#include "mlir/Conversion/MathToLLVM/MathToLLVM.h"
#include "mlir/Conversion/MathToLibm/MathToLibm.h"
diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td
index eb58f4adc31d3..e947c9fc49d8c 100644
--- a/mlir/include/mlir/Conversion/Passes.td
+++ b/mlir/include/mlir/Conversion/Passes.td
@@ -833,6 +833,24 @@ def ConvertMemRefToSPIRV : Pass<"convert-memref-to-spirv"> {
];
}
+//===----------------------------------------------------------------------===//
+// MPIToLLVM
+//===----------------------------------------------------------------------===//
+
+def MPIToLLVMConversionPass : Pass<"convert-mpi-to-llvm"> {
+  let summary = "Convert MPI dialect operations to LLVM dialect function calls";
+  let description = [{
+    This pass converts MPI dialect operations to function calls in the LLVM
+    dialect targeting the MPI stable ABI. Operations that do not have a
+    lowering yet (`mpi.send`, `mpi.recv`, `mpi.error_class`,
+    `mpi.retval_check`) are left unchanged.
+  }];
+  let dependentDialects = ["LLVM::LLVMDialect"];
+  let options = [
+    Option<"indexBitwidth", "index-bitwidth", "unsigned",
+           /*default=kDeriveIndexBitwidthFromDataLayout*/"0",
+           "Bitwidth of the index type, 0 to use size of machine word">,
+  ];
+}
+
//===----------------------------------------------------------------------===//
// NVVMToLLVM
//===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/MPI/IR/MPITypes.td b/mlir/include/mlir/Dialect/MPI/IR/MPITypes.td
index 87eefa719d45c..57ac512642829 100644
--- a/mlir/include/mlir/Dialect/MPI/IR/MPITypes.td
+++ b/mlir/include/mlir/Dialect/MPI/IR/MPITypes.td
@@ -30,7 +30,7 @@ class MPI_Type<string name, string typeMnemonic, list<Trait> traits = []>
//===----------------------------------------------------------------------===//
def MPI_Retval : MPI_Type<"Retval", "retval"> {
- let summary = "MPI function call return value";
+ let summary = "MPI function call return value (!mpi.retval)";
let description = [{
This type represents a return value from an MPI function call.
This value can be MPI_SUCCESS, MPI_ERR_IN_STATUS, or any error code.
diff --git a/mlir/lib/Conversion/CMakeLists.txt b/mlir/lib/Conversion/CMakeLists.txt
index 0a03a2e133db1..46e3768801560 100644
--- a/mlir/lib/Conversion/CMakeLists.txt
+++ b/mlir/lib/Conversion/CMakeLists.txt
@@ -39,6 +39,7 @@ add_subdirectory(MathToSPIRV)
add_subdirectory(MemRefToEmitC)
add_subdirectory(MemRefToLLVM)
add_subdirectory(MemRefToSPIRV)
+add_subdirectory(MPIToLLVM)
add_subdirectory(NVGPUToNVVM)
add_subdirectory(NVVMToLLVM)
add_subdirectory(OpenACCToSCF)
diff --git a/mlir/lib/Conversion/MPIToLLVM/CMakeLists.txt b/mlir/lib/Conversion/MPIToLLVM/CMakeLists.txt
new file mode 100644
index 0000000000000..f81fb25e56840
--- /dev/null
+++ b/mlir/lib/Conversion/MPIToLLVM/CMakeLists.txt
@@ -0,0 +1,17 @@
+# Conversion library implementing the `convert-mpi-to-llvm` pass and the
+# MPI-to-LLVM rewrite patterns (MPIToLLVM.cpp).
+add_mlir_conversion_library(MLIRMPIToLLVM
+ MPIToLLVM.cpp
+
+ ADDITIONAL_HEADER_DIRS
+ ${MLIR_MAIN_INCLUDE_DIR}/mlir/Conversion/MPIToLLVM
+
+ DEPENDS
+ MLIRConversionPassIncGen
+
+ LINK_COMPONENTS
+ Core
+
+ LINK_LIBS PUBLIC
+ MLIRLLVMCommonConversion
+ MLIRLLVMDialect
+ MLIRMPIDialect
+ )
diff --git a/mlir/lib/Conversion/MPIToLLVM/MPIToLLVM.cpp b/mlir/lib/Conversion/MPIToLLVM/MPIToLLVM.cpp
new file mode 100644
index 0000000000000..c4581dfbf3656
--- /dev/null
+++ b/mlir/lib/Conversion/MPIToLLVM/MPIToLLVM.cpp
@@ -0,0 +1,249 @@
+//===- MPIToLLVM.cpp - MPI to LLVM dialect conversion ---------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Conversion/MPIToLLVM/MPIToLLVM.h"
+
+#include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
+#include "mlir/Conversion/LLVMCommon/Pattern.h"
+#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
+#include "mlir/Dialect/MPI/IR/MPI.h"
+#include "mlir/Pass/Pass.h"
+
+namespace mlir {
+#define GEN_PASS_DEF_MPITOLLVMCONVERSIONPASS
+#include "mlir/Conversion/Passes.h.inc"
+} // namespace mlir
+
+using namespace mlir;
+
+namespace {
+
+/// Lowers `mpi.init` to a call to `MPI_Init`.
+struct InitOpLowering : ConvertOpToLLVMPattern<mpi::InitOp> {
+  using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;
+
+  LogicalResult
+  matchAndRewrite(mpi::InitOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override;
+};
+
+/// Lowers `mpi.comm_rank` to a call to `MPI_Comm_rank` on `MPI_COMM_WORLD`.
+struct CommRankOpLowering : ConvertOpToLLVMPattern<mpi::CommRankOp> {
+  using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;
+
+  LogicalResult
+  matchAndRewrite(mpi::CommRankOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override;
+};
+
+/// Lowers `mpi.finalize` to a call to `MPI_Finalize`.
+struct FinalizeOpLowering : ConvertOpToLLVMPattern<mpi::FinalizeOp> {
+  using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;
+
+  LogicalResult
+  matchAndRewrite(mpi::FinalizeOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override;
+};
+
+/// Returns the op of type `OpTy` named `name` in `moduleOp`. If there is none,
+/// calls `build()` with the rewriter's insertion point moved to the start of
+/// the module body and returns its result; the previous insertion point is
+/// restored afterwards.
+/// NOTE(review): a symbol `name` of a different op type is not detected, so a
+/// second symbol with the same name would be created in that case.
+template <typename OpTy, typename BuildFn>
+OpTy getOrInsertModuleSymbol(ModuleOp moduleOp,
+                             ConversionPatternRewriter &rewriter,
+                             StringRef name, BuildFn &&build) {
+  if (auto existing = moduleOp.lookupSymbol<OpTy>(name))
+    return existing;
+  ConversionPatternRewriter::InsertionGuard guard(rewriter);
+  rewriter.setInsertionPointToStart(moduleOp.getBody());
+  return build();
+}
+
+/// Returns the `llvm.func` named `name`, declaring it at the start of
+/// `moduleOp` with signature `type` and external linkage if it does not exist.
+/// An existing function is returned as-is; its signature is not compared
+/// against `type`. (Adapted from the helper of the same name in
+/// GPUOpsLowering.cpp.)
+LLVM::LLVMFuncOp getOrDefineFunction(ModuleOp moduleOp, Location loc,
+                                     ConversionPatternRewriter &rewriter,
+                                     StringRef name,
+                                     LLVM::LLVMFunctionType type) {
+  return getOrInsertModuleSymbol<LLVM::LLVMFuncOp>(
+      moduleOp, rewriter, name, [&] {
+        return rewriter.create<LLVM::LLVMFuncOp>(loc, name, type,
+                                                 LLVM::Linkage::External);
+      });
+}
+
+/// Returns the `llvm.mlir.global` named `name`, declaring it at the start of
+/// `moduleOp` as a non-constant external global of struct type `type` (no
+/// initializer, alignment 0, address space 0) if it does not exist.
+LLVM::GlobalOp getOrDefineExternalStruct(ModuleOp moduleOp, Location loc,
+                                         ConversionPatternRewriter &rewriter,
+                                         StringRef name,
+                                         LLVM::LLVMStructType type) {
+  return getOrInsertModuleSymbol<LLVM::GlobalOp>(
+      moduleOp, rewriter, name, [&] {
+        return rewriter.create<LLVM::GlobalOp>(
+            loc, type, /*isConstant=*/false, LLVM::Linkage::External, name,
+            /*value=*/Attribute(), /*alignment=*/0, /*addrSpace=*/0);
+      });
+}
+
+} // namespace
+
+//===----------------------------------------------------------------------===//
+// InitOpLowering
+//===----------------------------------------------------------------------===//
+
+LogicalResult
+InitOpLowering::matchAndRewrite(mpi::InitOp op, OpAdaptor adaptor,
+                                ConversionPatternRewriter &rewriter) const {
+  Location loc = op.getLoc();
+
+  // The declaration is inserted into the enclosing module; check for it before
+  // creating any IR so that failing leaves the op untouched.
+  auto moduleOp = op->getParentOfType<ModuleOp>();
+  if (!moduleOp)
+    return rewriter.notifyMatchFailure(op, "expected a parent module op");
+
+  // `%nullptr = llvm.mlir.zero : !llvm.ptr`, passed for both arguments.
+  Type ptrType = LLVM::LLVMPointerType::get(rewriter.getContext());
+  Value llvmnull = rewriter.create<LLVM::ZeroOp>(loc, ptrType).getRes();
+
+  // Declaration of `i32 MPI_Init(ptr, ptr)`.
+  auto initFuncType =
+      LLVM::LLVMFunctionType::get(rewriter.getI32Type(), {ptrType, ptrType});
+  LLVM::LLVMFuncOp initDecl =
+      getOrDefineFunction(moduleOp, loc, rewriter, "MPI_Init", initFuncType);
+
+  auto callOp = rewriter.create<LLVM::CallOp>(loc, initDecl,
+                                              ValueRange{llvmnull, llvmnull});
+
+  // The call always yields an i32, but `mpi.init` only has a result when its
+  // `!mpi.retval` is requested. Replacing a result-less op with the
+  // single-result call would be a result-count mismatch, so erase it instead.
+  if (op.getRetval())
+    rewriter.replaceOp(op, callOp.getResult());
+  else
+    rewriter.eraseOp(op);
+
+  return success();
+}
+
+//===----------------------------------------------------------------------===//
+// FinalizeOpLowering
+//===----------------------------------------------------------------------===//
+
+LogicalResult
+FinalizeOpLowering::matchAndRewrite(mpi::FinalizeOp op, OpAdaptor adaptor,
+                                    ConversionPatternRewriter &rewriter) const {
+  Location loc = op.getLoc();
+
+  // The declaration is inserted into the enclosing module; check for it before
+  // creating any IR so that failing leaves the op untouched.
+  auto moduleOp = op->getParentOfType<ModuleOp>();
+  if (!moduleOp)
+    return rewriter.notifyMatchFailure(op, "expected a parent module op");
+
+  // Declaration of `i32 MPI_Finalize()`.
+  auto finalizeFuncType =
+      LLVM::LLVMFunctionType::get(rewriter.getI32Type(), {});
+  LLVM::LLVMFuncOp finalizeDecl = getOrDefineFunction(
+      moduleOp, loc, rewriter, "MPI_Finalize", finalizeFuncType);
+
+  auto callOp =
+      rewriter.create<LLVM::CallOp>(loc, finalizeDecl, ValueRange{});
+
+  // The call always yields an i32, but `mpi.finalize` only has a result when
+  // its `!mpi.retval` is requested. Replacing a result-less op with the
+  // single-result call would be a result-count mismatch, so erase it instead.
+  if (op.getRetval())
+    rewriter.replaceOp(op, callOp.getResult());
+  else
+    rewriter.eraseOp(op);
+
+  return success();
+}
+
+//===----------------------------------------------------------------------===//
+// CommRankLowering
+//===----------------------------------------------------------------------===//
+
+LogicalResult
+CommRankOpLowering::matchAndRewrite(mpi::CommRankOp op, OpAdaptor adaptor,
+                                    ConversionPatternRewriter &rewriter) const {
+  Location loc = op.getLoc();
+  MLIRContext *context = rewriter.getContext();
+  auto i32 = rewriter.getI32Type();
+
+  // Declarations are inserted into the enclosing module; check for it before
+  // creating any IR so that failing leaves the op untouched.
+  auto moduleOp = op->getParentOfType<ModuleOp>();
+  if (!moduleOp)
+    return rewriter.notifyMatchFailure(op, "expected a parent module op");
+
+  // `!llvm.ptr`
+  Type ptrType = LLVM::LLVMPointerType::get(context);
+
+  // `MPI_COMM_WORLD` is referenced through an external global of the opaque
+  // struct type `MPI_ABI_Comm`; make sure it is declared in the module.
+  auto commStructT = LLVM::LLVMStructType::getOpaque("MPI_ABI_Comm", context);
+  getOrDefineExternalStruct(moduleOp, loc, rewriter, "MPI_COMM_WORLD",
+                            commStructT);
+
+  // One-element i32 stack slot passed as the second argument of the call and
+  // read back afterwards to obtain the rank.
+  auto one = rewriter.create<LLVM::ConstantOp>(loc, i32, 1);
+  auto rankPtr = rewriter.create<LLVM::AllocaOp>(loc, ptrType, i32, one);
+  // Address of @MPI_COMM_WORLD.
+  auto commWorld = rewriter.create<LLVM::AddressOfOp>(
+      loc, ptrType, SymbolRefAttr::get(context, "MPI_COMM_WORLD"));
+
+  // Declaration of `i32 MPI_Comm_rank(ptr, ptr)`.
+  auto rankFuncType = LLVM::LLVMFunctionType::get(i32, {ptrType, ptrType});
+  LLVM::LLVMFuncOp rankDecl = getOrDefineFunction(
+      moduleOp, loc, rewriter, "MPI_Comm_rank", rankFuncType);
+
+  auto callOp = rewriter.create<LLVM::CallOp>(
+      loc, rankDecl, ValueRange{commWorld.getRes(), rankPtr.getRes()});
+
+  // Load the rank from the stack slot.
+  auto loadedRank = rewriter.create<LLVM::LoadOp>(loc, i32, rankPtr.getRes());
+
+  // `mpi.comm_rank` yields an optional `!mpi.retval` followed by the rank;
+  // forward the call result only when the retval is present.
+  SmallVector<Value> replacements;
+  if (op.getRetval())
+    replacements.push_back(callOp.getResult());
+  replacements.push_back(loadedRank.getRes());
+  rewriter.replaceOp(op, replacements);
+
+  return success();
+}
+
+//===----------------------------------------------------------------------===//
+// Pass Definition
+//===----------------------------------------------------------------------===//
+
+namespace {
+/// Runs the MPI-to-LLVM patterns as a partial dialect conversion.
+struct MPIToLLVMConversionPass
+    : public impl::MPIToLLVMConversionPassBase<MPIToLLVMConversionPass> {
+  using Base::Base;
+
+  void runOnOperation() override {
+    MLIRContext *ctx = &getContext();
+
+    // Type conversion: the standard LLVM lowering (honouring `index-bitwidth`
+    // when it is set) extended with `!mpi.retval` -> i32.
+    LowerToLLVMOptions lowerOptions(ctx);
+    if (indexBitwidth != kDeriveIndexBitwidthFromDataLayout)
+      lowerOptions.overrideIndexBitwidth(indexBitwidth);
+    LLVMTypeConverter typeConverter(ctx, lowerOptions);
+    typeConverter.addConversion(
+        [ctx](mpi::RetvalType) { return IntegerType::get(ctx, 32); });
+
+    // Legality: every MPI op must be rewritten, except the ones that do not
+    // have a lowering yet (to be added in future patches).
+    LLVMConversionTarget target(*ctx);
+    target.addIllegalDialect<mpi::MPIDialect>();
+    target.addLegalOp<mpi::RecvOp, mpi::SendOp, mpi::ErrorClassOp,
+                      mpi::RetvalCheckOp>();
+
+    RewritePatternSet patterns(ctx);
+    mpi::populateMPIToLLVMConversionPatterns(typeConverter, patterns);
+
+    if (failed(applyPartialConversion(getOperation(), target,
+                                      std::move(patterns))))
+      signalPassFailure();
+  }
+};
+} // namespace
+
+//===----------------------------------------------------------------------===//
+// Pattern Population
+//===----------------------------------------------------------------------===//
+
+void mpi::populateMPIToLLVMConversionPatterns(LLVMTypeConverter &converter,
+                                              RewritePatternSet &patterns) {
+  // Every lowering is constructed from the same LLVM type converter; they are
+  // registered in the order Init, CommRank, Finalize.
+  patterns.add<InitOpLowering, CommRankOpLowering, FinalizeOpLowering>(
+      converter);
+}
diff --git a/mlir/test/Conversion/MPIToLLVM/ops.mlir b/mlir/test/Conversion/MPIToLLVM/ops.mlir
new file mode 100644
index 0000000000000..71bd7ba464e67
--- /dev/null
+++ b/mlir/test/Conversion/MPIToLLVM/ops.mlir
@@ -0,0 +1,39 @@
+// RUN: mlir-opt -convert-mpi-to-llvm %s | FileCheck %s
+
+// Declarations are inserted at the start of the module, so they are printed
+// in reverse order of first use. SSA values are matched through FileCheck
+// variables rather than their printed numbers, which depend on the ops that
+// are not converted yet (send/recv/error_class/retval_check).
+
+module {
+  // CHECK: llvm.func @MPI_Finalize() -> i32
+  // CHECK: llvm.func @MPI_Comm_rank(!llvm.ptr, !llvm.ptr) -> i32
+  // CHECK: llvm.mlir.global external @MPI_COMM_WORLD() {addr_space = 0 : i32} : !llvm.struct<"MPI_ABI_Comm", opaque>
+  // CHECK: llvm.func @MPI_Init(!llvm.ptr, !llvm.ptr) -> i32
+
+  // CHECK-LABEL: func.func @mpi_test(
+  func.func @mpi_test(%arg0: memref<100xf32>) {
+    %0 = mpi.init : !mpi.retval
+    // CHECK: %[[NULL:.*]] = llvm.mlir.zero : !llvm.ptr
+    // CHECK: %[[INIT:.*]] = llvm.call @MPI_Init(%[[NULL]], %[[NULL]]) : (!llvm.ptr, !llvm.ptr) -> i32
+    // CHECK: builtin.unrealized_conversion_cast %[[INIT]] : i32 to !mpi.retval
+
+    %retval, %rank = mpi.comm_rank : !mpi.retval, i32
+    // CHECK: %[[ONE:.*]] = llvm.mlir.constant(1 : i32) : i32
+    // CHECK: %[[RANK_PTR:.*]] = llvm.alloca %[[ONE]] x i32 : (i32) -> !llvm.ptr
+    // CHECK: %[[COMM:.*]] = llvm.mlir.addressof @MPI_COMM_WORLD : !llvm.ptr
+    // CHECK: %[[RANK_RET:.*]] = llvm.call @MPI_Comm_rank(%[[COMM]], %[[RANK_PTR]]) : (!llvm.ptr, !llvm.ptr) -> i32
+    // CHECK: llvm.load %[[RANK_PTR]] : !llvm.ptr -> i32
+    // CHECK: builtin.unrealized_conversion_cast %[[RANK_RET]] : i32 to !mpi.retval
+
+    mpi.send(%arg0, %rank, %rank) : memref<100xf32>, i32, i32
+
+    %1 = mpi.send(%arg0, %rank, %rank) : memref<100xf32>, i32, i32 -> !mpi.retval
+
+    mpi.recv(%arg0, %rank, %rank) : memref<100xf32>, i32, i32
+
+    %2 = mpi.recv(%arg0, %rank, %rank) : memref<100xf32>, i32, i32 -> !mpi.retval
+
+    %3 = mpi.finalize : !mpi.retval
+    // CHECK: llvm.call @MPI_Finalize() : () -> i32
+
+    %4 = mpi.retval_check %retval = <MPI_SUCCESS> : i1
+
+    %5 = mpi.error_class %0 : !mpi.retval
+    return
+  }
+}
More information about the Mlir-commits
mailing list