[Mlir-commits] [mlir] [mlir][mpi] Lowering Mpi To LLVM (PR #127053)

Frank Schlimbach llvmlistbot at llvm.org
Fri Feb 21 06:56:56 PST 2025


https://github.com/fschlimb updated https://github.com/llvm/llvm-project/pull/127053

>From 979b0ab2bfd62b9a254286e80bb6551c5ee2c3be Mon Sep 17 00:00:00 2001
From: Anton Lydike <me at antonlydike.de>
Date: Fri, 14 Jun 2024 11:26:26 +0100
Subject: [PATCH 01/15] add initial set of lowerings for MPI dialect

---
 .../mlir/Conversion/MPIToLLVM/MPIToLLVM.h     |  30 +++
 mlir/include/mlir/Dialect/MPI/IR/MPITypes.td  |   2 +-
 mlir/include/mlir/InitAllExtensions.h         |   2 +
 mlir/lib/Conversion/CMakeLists.txt            |   1 +
 mlir/lib/Conversion/MPIToLLVM/CMakeLists.txt  |  17 ++
 mlir/lib/Conversion/MPIToLLVM/MPIToLLVM.cpp   | 230 ++++++++++++++++++
 mlir/test/Conversion/MPIToLLVM/ops.mlir       |  40 +++
 7 files changed, 321 insertions(+), 1 deletion(-)
 create mode 100644 mlir/include/mlir/Conversion/MPIToLLVM/MPIToLLVM.h
 create mode 100644 mlir/lib/Conversion/MPIToLLVM/CMakeLists.txt
 create mode 100644 mlir/lib/Conversion/MPIToLLVM/MPIToLLVM.cpp
 create mode 100644 mlir/test/Conversion/MPIToLLVM/ops.mlir

diff --git a/mlir/include/mlir/Conversion/MPIToLLVM/MPIToLLVM.h b/mlir/include/mlir/Conversion/MPIToLLVM/MPIToLLVM.h
new file mode 100644
index 0000000000000..940e5e8097318
--- /dev/null
+++ b/mlir/include/mlir/Conversion/MPIToLLVM/MPIToLLVM.h
@@ -0,0 +1,30 @@
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_CONVERSION_MPITOLLVM_H
+#define MLIR_CONVERSION_MPITOLLVM_H
+
+#include "mlir/IR/DialectRegistry.h"
+
+namespace mlir {
+
+class LLVMTypeConverter;
+class RewritePatternSet;
+
+#define GEN_PASS_DECL_MPITOLLVMCONVERSIONPASS
+#include "mlir/Conversion/Passes.h.inc"
+
+namespace mpi {
+void populateMPIToLLVMConversionPatterns(LLVMTypeConverter &converter,
+                                         RewritePatternSet &patterns);
+
+void registerConvertMPIToLLVMInterface(DialectRegistry &registry);
+
+} // namespace mpi
+} // namespace mlir
+
+#endif // MLIR_CONVERSION_MPITOLLVM_H
diff --git a/mlir/include/mlir/Dialect/MPI/IR/MPITypes.td b/mlir/include/mlir/Dialect/MPI/IR/MPITypes.td
index fafea0eac8bb7..a55d30e778e22 100644
--- a/mlir/include/mlir/Dialect/MPI/IR/MPITypes.td
+++ b/mlir/include/mlir/Dialect/MPI/IR/MPITypes.td
@@ -30,7 +30,7 @@ class MPI_Type<string name, string typeMnemonic, list<Trait> traits = []>
 //===----------------------------------------------------------------------===//
 
 def MPI_Retval : MPI_Type<"Retval", "retval"> {
-  let summary = "MPI function call return value";
+  let summary = "MPI function call return value (!mpi.retval)";
   let description = [{
     This type represents a return value from an MPI function call.
     This value can be MPI_SUCCESS, MPI_ERR_IN_STATUS, or any error code.
diff --git a/mlir/include/mlir/InitAllExtensions.h b/mlir/include/mlir/InitAllExtensions.h
index 887db344ed88b..6ab23ff86b3c6 100644
--- a/mlir/include/mlir/InitAllExtensions.h
+++ b/mlir/include/mlir/InitAllExtensions.h
@@ -14,6 +14,7 @@
 #ifndef MLIR_INITALLEXTENSIONS_H_
 #define MLIR_INITALLEXTENSIONS_H_
 
+#include "Conversion/MPIToLLVM/MPIToLLVM.h"
 #include "mlir/Conversion/ArithToLLVM/ArithToLLVM.h"
 #include "mlir/Conversion/ComplexToLLVM/ComplexToLLVM.h"
 #include "mlir/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.h"
@@ -70,6 +71,7 @@ inline void registerAllExtensions(DialectRegistry &registry) {
   registerConvertFuncToLLVMInterface(registry);
   index::registerConvertIndexToLLVMInterface(registry);
   registerConvertMathToLLVMInterface(registry);
+  mpi::registerConvertMPIToLLVMInterface(registry);
   registerConvertMemRefToLLVMInterface(registry);
   registerConvertNVVMToLLVMInterface(registry);
   registerConvertOpenMPToLLVMInterface(registry);
diff --git a/mlir/lib/Conversion/CMakeLists.txt b/mlir/lib/Conversion/CMakeLists.txt
index fa904a33ebf96..b6c21440c571c 100644
--- a/mlir/lib/Conversion/CMakeLists.txt
+++ b/mlir/lib/Conversion/CMakeLists.txt
@@ -42,6 +42,7 @@ add_subdirectory(MemRefToEmitC)
 add_subdirectory(MemRefToLLVM)
 add_subdirectory(MemRefToSPIRV)
 add_subdirectory(MeshToMPI)
+add_subdirectory(MPIToLLVM)
 add_subdirectory(NVGPUToNVVM)
 add_subdirectory(NVVMToLLVM)
 add_subdirectory(OpenACCToSCF)
diff --git a/mlir/lib/Conversion/MPIToLLVM/CMakeLists.txt b/mlir/lib/Conversion/MPIToLLVM/CMakeLists.txt
new file mode 100644
index 0000000000000..f81fb25e56840
--- /dev/null
+++ b/mlir/lib/Conversion/MPIToLLVM/CMakeLists.txt
@@ -0,0 +1,17 @@
+add_mlir_conversion_library(MLIRMPIToLLVM
+  MPIToLLVM.cpp
+
+  ADDITIONAL_HEADER_DIRS
+  ${MLIR_MAIN_INCLUDE_DIR}/mlir/Conversion/MPIToLLVM
+
+  DEPENDS
+  MLIRConversionPassIncGen
+
+  LINK_COMPONENTS
+  Core
+
+  LINK_LIBS PUBLIC
+  MLIRLLVMCommonConversion
+  MLIRLLVMDialect
+  MLIRMPIDialect
+  )
diff --git a/mlir/lib/Conversion/MPIToLLVM/MPIToLLVM.cpp b/mlir/lib/Conversion/MPIToLLVM/MPIToLLVM.cpp
new file mode 100644
index 0000000000000..d87a10aab8f49
--- /dev/null
+++ b/mlir/lib/Conversion/MPIToLLVM/MPIToLLVM.cpp
@@ -0,0 +1,230 @@
+//===- MPIToLLVM.cpp - MPI to LLVM dialect conversion ---------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Conversion/MPIToLLVM/MPIToLLVM.h"
+
+#include "mlir/Conversion/LLVMCommon/Pattern.h"
+#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
+#include "mlir/Dialect/MPI/IR/MPI.h"
+#include "mlir/Pass/Pass.h"
+
+#include <mlir/Conversion/ConvertToLLVM/ToLLVMInterface.h>
+
+using namespace mlir;
+
+namespace {
+
+struct InitOpLowering : ConvertOpToLLVMPattern<mpi::InitOp> {
+  using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;
+
+  LogicalResult
+  matchAndRewrite(mpi::InitOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override;
+};
+
+struct CommRankOpLowering : ConvertOpToLLVMPattern<mpi::CommRankOp> {
+  using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;
+
+  LogicalResult
+  matchAndRewrite(mpi::CommRankOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override;
+};
+
+struct FinalizeOpLowering : ConvertOpToLLVMPattern<mpi::FinalizeOp> {
+  using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;
+
+  LogicalResult
+  matchAndRewrite(mpi::FinalizeOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override;
+};
+
+// TODO: this was copied from GPUOpsLowering.cpp:288
+// is this okay, or should this be moved to some common file?
+LLVM::LLVMFuncOp getOrDefineFunction(ModuleOp &moduleOp, const Location loc,
+                                     ConversionPatternRewriter &rewriter,
+                                     StringRef name,
+                                     LLVM::LLVMFunctionType type) {
+  LLVM::LLVMFuncOp ret;
+  if (!(ret = moduleOp.lookupSymbol<LLVM::LLVMFuncOp>(name))) {
+    ConversionPatternRewriter::InsertionGuard guard(rewriter);
+    rewriter.setInsertionPointToStart(moduleOp.getBody());
+    ret = rewriter.create<LLVM::LLVMFuncOp>(loc, name, type,
+                                            LLVM::Linkage::External);
+  }
+  return ret;
+}
+
+// TODO: this is pretty close to getOrDefineFunction, can probably be factored
+LLVM::GlobalOp getOrDefineExternalStruct(ModuleOp &moduleOp, const Location loc,
+                                         ConversionPatternRewriter &rewriter,
+                                         StringRef name,
+                                         LLVM::LLVMStructType type) {
+  LLVM::GlobalOp ret;
+  if (!(ret = moduleOp.lookupSymbol<LLVM::GlobalOp>(name))) {
+    ConversionPatternRewriter::InsertionGuard guard(rewriter);
+    rewriter.setInsertionPointToStart(moduleOp.getBody());
+    ret = rewriter.create<LLVM::GlobalOp>(
+        loc, type, /*isConstant=*/false, LLVM::Linkage::External, name,
+        /*value=*/Attribute(), /*alignment=*/0, 0);
+  }
+  return ret;
+}
+
+} // namespace
+
+//===----------------------------------------------------------------------===//
+// InitOpLowering
+//===----------------------------------------------------------------------===//
+
+LogicalResult
+InitOpLowering::matchAndRewrite(mpi::InitOp op, OpAdaptor adaptor,
+                                ConversionPatternRewriter &rewriter) const {
+  // get loc
+  auto loc = op.getLoc();
+
+  // ptrType `!llvm.ptr`
+  Type ptrType = LLVM::LLVMPointerType::get(rewriter.getContext());
+
+  // instantiate nullptr `%nullptr = llvm.mlir.zero : !llvm.ptr`
+  auto nullPtrOp = rewriter.create<LLVM::ZeroOp>(loc, ptrType);
+  Value llvmnull = nullPtrOp.getRes();
+
+  // grab a reference to the global module op:
+  auto moduleOp = op->getParentOfType<ModuleOp>();
+
+  // LLVM Function type representing `i32 MPI_Init(ptr, ptr)`
+  auto initFuncType =
+      LLVM::LLVMFunctionType::get(rewriter.getI32Type(), {ptrType, ptrType});
+  // get or create function declaration:
+  LLVM::LLVMFuncOp initDecl =
+      getOrDefineFunction(moduleOp, loc, rewriter, "MPI_Init", initFuncType);
+
+  // replace init with function call
+  rewriter.replaceOpWithNewOp<LLVM::CallOp>(op, initDecl,
+                                            ValueRange{llvmnull, llvmnull});
+
+  return success();
+}
+
+//===----------------------------------------------------------------------===//
+// FinalizeOpLowering
+//===----------------------------------------------------------------------===//
+
+LogicalResult
+FinalizeOpLowering::matchAndRewrite(mpi::FinalizeOp op, OpAdaptor adaptor,
+                                    ConversionPatternRewriter &rewriter) const {
+  // get loc
+  auto loc = op.getLoc();
+
+  // grab a reference to the global module op:
+  auto moduleOp = op->getParentOfType<ModuleOp>();
+
+  // LLVM Function type representing `i32 MPI_Finalize()`
+  auto initFuncType = LLVM::LLVMFunctionType::get(rewriter.getI32Type(), {});
+  // get or create function declaration:
+  LLVM::LLVMFuncOp initDecl = getOrDefineFunction(moduleOp, loc, rewriter,
+                                                  "MPI_Finalize", initFuncType);
+
+  // replace finalize with function call
+  rewriter.replaceOpWithNewOp<LLVM::CallOp>(op, initDecl, ValueRange{});
+
+  return success();
+}
+
+//===----------------------------------------------------------------------===//
+// CommRankLowering
+//===----------------------------------------------------------------------===//
+
+LogicalResult
+CommRankOpLowering::matchAndRewrite(mpi::CommRankOp op, OpAdaptor adaptor,
+                                    ConversionPatternRewriter &rewriter) const {
+  // get some helper vars
+  auto loc = op.getLoc();
+  auto context = rewriter.getContext();
+  auto i32 = rewriter.getI32Type();
+
+  // ptrType `!llvm.ptr`
+  Type ptrType = LLVM::LLVMPointerType::get(context);
+
+  // get external opaque struct pointer type
+  auto commStructT = LLVM::LLVMStructType::getOpaque("MPI_ABI_Comm", context);
+
+  // grab a reference to the global module op:
+  auto moduleOp = op->getParentOfType<ModuleOp>();
+
+  // make sure global op definition exists
+  getOrDefineExternalStruct(moduleOp, loc, rewriter, "MPI_COMM_WORLD",
+                            commStructT);
+
+  // allocate a stack slot for the rank, then get address of @MPI_COMM_WORLD
+  auto one = rewriter.create<LLVM::ConstantOp>(loc, i32, 1);
+  auto rankptr = rewriter.create<LLVM::AllocaOp>(loc, ptrType, i32, one);
+  auto commWorld = rewriter.create<LLVM::AddressOfOp>(
+      loc, ptrType, SymbolRefAttr::get(context, "MPI_COMM_WORLD"));
+
+  // LLVM Function type representing `i32 MPI_Comm_rank(ptr, ptr)`
+  auto rankFuncType = LLVM::LLVMFunctionType::get(i32, {ptrType, ptrType});
+  // get or create function declaration:
+  LLVM::LLVMFuncOp initDecl = getOrDefineFunction(
+      moduleOp, loc, rewriter, "MPI_Comm_rank", rankFuncType);
+
+  // emit the call to MPI_Comm_rank (the op itself is replaced below)
+  auto callOp = rewriter.create<LLVM::CallOp>(
+      loc, initDecl, ValueRange{commWorld.getRes(), rankptr.getRes()});
+
+  // load the rank into a register
+  auto loadedRank =
+      rewriter.create<LLVM::LoadOp>(loc, i32, rankptr.getResult());
+
+  // if retval is checked, replace uses of retval with the results from the call
+  // op
+  SmallVector<Value> replacements;
+  if (op.getRetval()) {
+    replacements.push_back(callOp.getResult());
+  }
+  // replace all uses, then erase op
+  replacements.push_back(loadedRank.getRes());
+  rewriter.replaceOp(op, replacements);
+
+  return success();
+}
+
+//===----------------------------------------------------------------------===//
+// Pattern Population
+//===----------------------------------------------------------------------===//
+
+void mpi::populateMPIToLLVMConversionPatterns(LLVMTypeConverter &converter,
+                                              RewritePatternSet &patterns) {
+  patterns.add<InitOpLowering>(converter);
+  patterns.add<CommRankOpLowering>(converter);
+  patterns.add<FinalizeOpLowering>(converter);
+}
+
+//===----------------------------------------------------------------------===//
+// ConvertToLLVMPatternInterface implementation
+//===----------------------------------------------------------------------===//
+
+namespace {
+/// Implement the interface to convert MPI to LLVM.
+struct FuncToLLVMDialectInterface : public ConvertToLLVMPatternInterface {
+  using ConvertToLLVMPatternInterface::ConvertToLLVMPatternInterface;
+  /// Hook for derived dialect interface to provide conversion patterns
+  /// and mark dialect legal for the conversion target.
+  void populateConvertToLLVMConversionPatterns(
+      ConversionTarget &target, LLVMTypeConverter &typeConverter,
+      RewritePatternSet &patterns) const final {
+    mpi::populateMPIToLLVMConversionPatterns(typeConverter, patterns);
+  }
+};
+} // namespace
+
+void mpi::registerConvertMPIToLLVMInterface(DialectRegistry &registry) {
+  registry.addExtension(+[](MLIRContext *ctx, mpi::MPIDialect *dialect) {
+    dialect->addInterfaces<FuncToLLVMDialectInterface>();
+  });
+}
diff --git a/mlir/test/Conversion/MPIToLLVM/ops.mlir b/mlir/test/Conversion/MPIToLLVM/ops.mlir
new file mode 100644
index 0000000000000..a7a44ad24909a
--- /dev/null
+++ b/mlir/test/Conversion/MPIToLLVM/ops.mlir
@@ -0,0 +1,40 @@
+// RUN: mlir-opt -convert-to-llvm %s | FileCheck %s
+
+module {
+// CHECK:  llvm.func @MPI_Finalize() -> i32
+// CHECK:  llvm.func @MPI_Comm_rank(!llvm.ptr, !llvm.ptr) -> i32
+// CHECK:  llvm.mlir.global external @MPI_COMM_WORLD() {addr_space = 0 : i32} : !llvm.struct<"MPI_ABI_Comm", opaque>
+// CHECK:  llvm.func @MPI_Init(!llvm.ptr, !llvm.ptr) -> i32
+
+  func.func @mpi_test(%arg0: memref<100xf32>) {
+    %0 = mpi.init : !mpi.retval
+// CHECK:      %7 = llvm.mlir.zero : !llvm.ptr
+// CHECK-NEXT: %8 = llvm.call @MPI_Init(%7, %7) : (!llvm.ptr, !llvm.ptr) -> i32
+// CHECK-NEXT: %9 = builtin.unrealized_conversion_cast %8 : i32 to !mpi.retval
+
+
+    %retval, %rank = mpi.comm_rank : !mpi.retval, i32
+// CHECK:      %10 = llvm.mlir.constant(1 : i32) : i32
+// CHECK-NEXT: %11 = llvm.alloca %10 x i32 : (i32) -> !llvm.ptr
+// CHECK-NEXT: %12 = llvm.mlir.addressof @MPI_COMM_WORLD : !llvm.ptr
+// CHECK-NEXT: %13 = llvm.call @MPI_Comm_rank(%12, %11) : (!llvm.ptr, !llvm.ptr) -> i32
+// CHECK-NEXT: %14 = llvm.load %11 : !llvm.ptr -> i32
+// CHECK-NEXT: %15 = builtin.unrealized_conversion_cast %13 : i32 to !mpi.retval
+
+    mpi.send(%arg0, %rank, %rank) : memref<100xf32>, i32, i32
+
+    %1 = mpi.send(%arg0, %rank, %rank) : memref<100xf32>, i32, i32 -> !mpi.retval
+
+    mpi.recv(%arg0, %rank, %rank) : memref<100xf32>, i32, i32
+
+    %2 = mpi.recv(%arg0, %rank, %rank) : memref<100xf32>, i32, i32 -> !mpi.retval
+
+    %3 = mpi.finalize : !mpi.retval
+// CHECK:      %18 = llvm.call @MPI_Finalize() : () -> i32
+
+    %4 = mpi.retval_check %retval = <MPI_SUCCESS> : i1
+
+    %5 = mpi.error_class %0 : !mpi.retval
+    return
+  }
+}

>From 3eb15c405ee8189feb6ca0ef976ef3d3d8ec87de Mon Sep 17 00:00:00 2001
From: "Schlimbach, Frank" <frank.schlimbach at intel.com>
Date: Tue, 19 Nov 2024 18:02:18 +0100
Subject: [PATCH 02/15] lowering MPI_send and MPI_recv; some refactoring fixing
 MPI_Send/Recv signature: added size fixing MPIR ops.mlir tests MPI: fixing
 enum names and merge conflicts separate MPIImplTraits.h

---
 .../mlir/Conversion/MPIToLLVM/MPIToLLVM.h     |   1 +
 mlir/include/mlir/Dialect/MPI/IR/MPI.td       | 154 +++----
 mlir/include/mlir/Dialect/MPI/IR/MPIOps.td    |  16 +-
 mlir/lib/Conversion/MPIToLLVM/CMakeLists.txt  |   3 +
 mlir/lib/Conversion/MPIToLLVM/MPIImplTraits.h | 119 ++++++
 mlir/lib/Conversion/MPIToLLVM/MPIToLLVM.cpp   | 393 +++++++++++-------
 mlir/test/Conversion/MPIToLLVM/ops.mlir       |  74 +++-
 7 files changed, 512 insertions(+), 248 deletions(-)
 create mode 100644 mlir/lib/Conversion/MPIToLLVM/MPIImplTraits.h

diff --git a/mlir/include/mlir/Conversion/MPIToLLVM/MPIToLLVM.h b/mlir/include/mlir/Conversion/MPIToLLVM/MPIToLLVM.h
index 940e5e8097318..8d2698aa91c7c 100644
--- a/mlir/include/mlir/Conversion/MPIToLLVM/MPIToLLVM.h
+++ b/mlir/include/mlir/Conversion/MPIToLLVM/MPIToLLVM.h
@@ -19,6 +19,7 @@ class RewritePatternSet;
 #include "mlir/Conversion/Passes.h.inc"
 
 namespace mpi {
+
 void populateMPIToLLVMConversionPatterns(LLVMTypeConverter &converter,
                                          RewritePatternSet &patterns);
 
diff --git a/mlir/include/mlir/Dialect/MPI/IR/MPI.td b/mlir/include/mlir/Dialect/MPI/IR/MPI.td
index 7c84443e5520d..df0cf9d518faf 100644
--- a/mlir/include/mlir/Dialect/MPI/IR/MPI.td
+++ b/mlir/include/mlir/Dialect/MPI/IR/MPI.td
@@ -42,104 +42,104 @@ def MPI_Dialect : Dialect {
 // Error classes enum:
 //===----------------------------------------------------------------------===//
 
-def MPI_CodeSuccess : I32EnumAttrCase<"MPI_SUCCESS", 0, "MPI_SUCCESS">;
-def MPI_CodeErrAccess : I32EnumAttrCase<"MPI_ERR_ACCESS", 1, "MPI_ERR_ACCESS">;
-def MPI_CodeErrAmode : I32EnumAttrCase<"MPI_ERR_AMODE", 2, "MPI_ERR_AMODE">;
-def MPI_CodeErrArg : I32EnumAttrCase<"MPI_ERR_ARG", 3, "MPI_ERR_ARG">;
-def MPI_CodeErrAssert : I32EnumAttrCase<"MPI_ERR_ASSERT", 4, "MPI_ERR_ASSERT">;
+def MPI_CodeSuccess : I32EnumAttrCase<"_MPI_SUCCESS", 0, "MPI_SUCCESS">;
+def MPI_CodeErrAccess : I32EnumAttrCase<"_MPI_ERR_ACCESS", 1, "MPI_ERR_ACCESS">;
+def MPI_CodeErrAmode : I32EnumAttrCase<"_MPI_ERR_AMODE", 2, "MPI_ERR_AMODE">;
+def MPI_CodeErrArg : I32EnumAttrCase<"_MPI_ERR_ARG", 3, "MPI_ERR_ARG">;
+def MPI_CodeErrAssert : I32EnumAttrCase<"_MPI_ERR_ASSERT", 4, "MPI_ERR_ASSERT">;
 def MPI_CodeErrBadFile
-    : I32EnumAttrCase<"MPI_ERR_BAD_FILE", 5, "MPI_ERR_BAD_FILE">;
-def MPI_CodeErrBase : I32EnumAttrCase<"MPI_ERR_BASE", 6, "MPI_ERR_BASE">;
-def MPI_CodeErrBuffer : I32EnumAttrCase<"MPI_ERR_BUFFER", 7, "MPI_ERR_BUFFER">;
-def MPI_CodeErrComm : I32EnumAttrCase<"MPI_ERR_COMM", 8, "MPI_ERR_COMM">;
+    : I32EnumAttrCase<"_MPI_ERR_BAD_FILE", 5, "MPI_ERR_BAD_FILE">;
+def MPI_CodeErrBase : I32EnumAttrCase<"_MPI_ERR_BASE", 6, "MPI_ERR_BASE">;
+def MPI_CodeErrBuffer : I32EnumAttrCase<"_MPI_ERR_BUFFER", 7, "MPI_ERR_BUFFER">;
+def MPI_CodeErrComm : I32EnumAttrCase<"_MPI_ERR_COMM", 8, "MPI_ERR_COMM">;
 def MPI_CodeErrConversion
-    : I32EnumAttrCase<"MPI_ERR_CONVERSION", 9, "MPI_ERR_CONVERSION">;
-def MPI_CodeErrCount : I32EnumAttrCase<"MPI_ERR_COUNT", 10, "MPI_ERR_COUNT">;
-def MPI_CodeErrDims : I32EnumAttrCase<"MPI_ERR_DIMS", 11, "MPI_ERR_DIMS">;
-def MPI_CodeErrDisp : I32EnumAttrCase<"MPI_ERR_DISP", 12, "MPI_ERR_DISP">;
+    : I32EnumAttrCase<"_MPI_ERR_CONVERSION", 9, "MPI_ERR_CONVERSION">;
+def MPI_CodeErrCount : I32EnumAttrCase<"_MPI_ERR_COUNT", 10, "MPI_ERR_COUNT">;
+def MPI_CodeErrDims : I32EnumAttrCase<"_MPI_ERR_DIMS", 11, "MPI_ERR_DIMS">;
+def MPI_CodeErrDisp : I32EnumAttrCase<"_MPI_ERR_DISP", 12, "MPI_ERR_DISP">;
 def MPI_CodeErrDupDatarep
-    : I32EnumAttrCase<"MPI_ERR_DUP_DATAREP", 13, "MPI_ERR_DUP_DATAREP">;
+    : I32EnumAttrCase<"_MPI_ERR_DUP_DATAREP", 13, "MPI_ERR_DUP_DATAREP">;
 def MPI_CodeErrErrhandler
-    : I32EnumAttrCase<"MPI_ERR_ERRHANDLER", 14, "MPI_ERR_ERRHANDLER">;
-def MPI_CodeErrFile : I32EnumAttrCase<"MPI_ERR_FILE", 15, "MPI_ERR_FILE">;
+    : I32EnumAttrCase<"_MPI_ERR_ERRHANDLER", 14, "MPI_ERR_ERRHANDLER">;
+def MPI_CodeErrFile : I32EnumAttrCase<"_MPI_ERR_FILE", 15, "MPI_ERR_FILE">;
 def MPI_CodeErrFileExists
-    : I32EnumAttrCase<"MPI_ERR_FILE_EXISTS", 16, "MPI_ERR_FILE_EXISTS">;
+    : I32EnumAttrCase<"_MPI_ERR_FILE_EXISTS", 16, "MPI_ERR_FILE_EXISTS">;
 def MPI_CodeErrFileInUse
-    : I32EnumAttrCase<"MPI_ERR_FILE_IN_USE", 17, "MPI_ERR_FILE_IN_USE">;
-def MPI_CodeErrGroup : I32EnumAttrCase<"MPI_ERR_GROUP", 18, "MPI_ERR_GROUP">;
-def MPI_CodeErrInfo : I32EnumAttrCase<"MPI_ERR_INFO", 19, "MPI_ERR_INFO">;
+    : I32EnumAttrCase<"_MPI_ERR_FILE_IN_USE", 17, "MPI_ERR_FILE_IN_USE">;
+def MPI_CodeErrGroup : I32EnumAttrCase<"_MPI_ERR_GROUP", 18, "MPI_ERR_GROUP">;
+def MPI_CodeErrInfo : I32EnumAttrCase<"_MPI_ERR_INFO", 19, "MPI_ERR_INFO">;
 def MPI_CodeErrInfoKey
-    : I32EnumAttrCase<"MPI_ERR_INFO_KEY", 20, "MPI_ERR_INFO_KEY">;
+    : I32EnumAttrCase<"_MPI_ERR_INFO_KEY", 20, "MPI_ERR_INFO_KEY">;
 def MPI_CodeErrInfoNokey
-    : I32EnumAttrCase<"MPI_ERR_INFO_NOKEY", 21, "MPI_ERR_INFO_NOKEY">;
+    : I32EnumAttrCase<"_MPI_ERR_INFO_NOKEY", 21, "MPI_ERR_INFO_NOKEY">;
 def MPI_CodeErrInfoValue
-    : I32EnumAttrCase<"MPI_ERR_INFO_VALUE", 22, "MPI_ERR_INFO_VALUE">;
+    : I32EnumAttrCase<"_MPI_ERR_INFO_VALUE", 22, "MPI_ERR_INFO_VALUE">;
 def MPI_CodeErrInStatus
-    : I32EnumAttrCase<"MPI_ERR_IN_STATUS", 23, "MPI_ERR_IN_STATUS">;
-def MPI_CodeErrIntern : I32EnumAttrCase<"MPI_ERR_INTERN", 24, "MPI_ERR_INTERN">;
-def MPI_CodeErrIo : I32EnumAttrCase<"MPI_ERR_IO", 25, "MPI_ERR_IO">;
-def MPI_CodeErrKeyval : I32EnumAttrCase<"MPI_ERR_KEYVAL", 26, "MPI_ERR_KEYVAL">;
+    : I32EnumAttrCase<"_MPI_ERR_IN_STATUS", 23, "MPI_ERR_IN_STATUS">;
+def MPI_CodeErrIntern : I32EnumAttrCase<"_MPI_ERR_INTERN", 24, "MPI_ERR_INTERN">;
+def MPI_CodeErrIo : I32EnumAttrCase<"_MPI_ERR_IO", 25, "MPI_ERR_IO">;
+def MPI_CodeErrKeyval : I32EnumAttrCase<"_MPI_ERR_KEYVAL", 26, "MPI_ERR_KEYVAL">;
 def MPI_CodeErrLocktype
-    : I32EnumAttrCase<"MPI_ERR_LOCKTYPE", 27, "MPI_ERR_LOCKTYPE">;
-def MPI_CodeErrName : I32EnumAttrCase<"MPI_ERR_NAME", 28, "MPI_ERR_NAME">;
-def MPI_CodeErrNoMem : I32EnumAttrCase<"MPI_ERR_NO_MEM", 29, "MPI_ERR_NO_MEM">;
+    : I32EnumAttrCase<"_MPI_ERR_LOCKTYPE", 27, "MPI_ERR_LOCKTYPE">;
+def MPI_CodeErrName : I32EnumAttrCase<"_MPI_ERR_NAME", 28, "MPI_ERR_NAME">;
+def MPI_CodeErrNoMem : I32EnumAttrCase<"_MPI_ERR_NO_MEM", 29, "MPI_ERR_NO_MEM">;
 def MPI_CodeErrNoSpace
-    : I32EnumAttrCase<"MPI_ERR_NO_SPACE", 30, "MPI_ERR_NO_SPACE">;
+    : I32EnumAttrCase<"_MPI_ERR_NO_SPACE", 30, "MPI_ERR_NO_SPACE">;
 def MPI_CodeErrNoSuchFile
-    : I32EnumAttrCase<"MPI_ERR_NO_SUCH_FILE", 31, "MPI_ERR_NO_SUCH_FILE">;
+    : I32EnumAttrCase<"_MPI_ERR_NO_SUCH_FILE", 31, "MPI_ERR_NO_SUCH_FILE">;
 def MPI_CodeErrNotSame
-    : I32EnumAttrCase<"MPI_ERR_NOT_SAME", 32, "MPI_ERR_NOT_SAME">;
-def MPI_CodeErrOp : I32EnumAttrCase<"MPI_ERR_OP", 33, "MPI_ERR_OP">;
-def MPI_CodeErrOther : I32EnumAttrCase<"MPI_ERR_OTHER", 34, "MPI_ERR_OTHER">;
+    : I32EnumAttrCase<"_MPI_ERR_NOT_SAME", 32, "MPI_ERR_NOT_SAME">;
+def MPI_CodeErrOp : I32EnumAttrCase<"_MPI_ERR_OP", 33, "MPI_ERR_OP">;
+def MPI_CodeErrOther : I32EnumAttrCase<"_MPI_ERR_OTHER", 34, "MPI_ERR_OTHER">;
 def MPI_CodeErrPending
-    : I32EnumAttrCase<"MPI_ERR_PENDING", 35, "MPI_ERR_PENDING">;
-def MPI_CodeErrPort : I32EnumAttrCase<"MPI_ERR_PORT", 36, "MPI_ERR_PORT">;
+    : I32EnumAttrCase<"_MPI_ERR_PENDING", 35, "MPI_ERR_PENDING">;
+def MPI_CodeErrPort : I32EnumAttrCase<"_MPI_ERR_PORT", 36, "MPI_ERR_PORT">;
 def MPI_CodeErrProcAborted
-    : I32EnumAttrCase<"MPI_ERR_PROC_ABORTED", 37, "MPI_ERR_PROC_ABORTED">;
-def MPI_CodeErrQuota : I32EnumAttrCase<"MPI_ERR_QUOTA", 38, "MPI_ERR_QUOTA">;
-def MPI_CodeErrRank : I32EnumAttrCase<"MPI_ERR_RANK", 39, "MPI_ERR_RANK">;
+    : I32EnumAttrCase<"_MPI_ERR_PROC_ABORTED", 37, "MPI_ERR_PROC_ABORTED">;
+def MPI_CodeErrQuota : I32EnumAttrCase<"_MPI_ERR_QUOTA", 38, "MPI_ERR_QUOTA">;
+def MPI_CodeErrRank : I32EnumAttrCase<"_MPI_ERR_RANK", 39, "MPI_ERR_RANK">;
 def MPI_CodeErrReadOnly
-    : I32EnumAttrCase<"MPI_ERR_READ_ONLY", 40, "MPI_ERR_READ_ONLY">;
+    : I32EnumAttrCase<"_MPI_ERR_READ_ONLY", 40, "MPI_ERR_READ_ONLY">;
 def MPI_CodeErrRequest
-    : I32EnumAttrCase<"MPI_ERR_REQUEST", 41, "MPI_ERR_REQUEST">;
+    : I32EnumAttrCase<"_MPI_ERR_REQUEST", 41, "MPI_ERR_REQUEST">;
 def MPI_CodeErrRmaAttach
-    : I32EnumAttrCase<"MPI_ERR_RMA_ATTACH", 42, "MPI_ERR_RMA_ATTACH">;
+    : I32EnumAttrCase<"_MPI_ERR_RMA_ATTACH", 42, "MPI_ERR_RMA_ATTACH">;
 def MPI_CodeErrRmaConflict
-    : I32EnumAttrCase<"MPI_ERR_RMA_CONFLICT", 43, "MPI_ERR_RMA_CONFLICT">;
+    : I32EnumAttrCase<"_MPI_ERR_RMA_CONFLICT", 43, "MPI_ERR_RMA_CONFLICT">;
 def MPI_CodeErrRmaFlavor
-    : I32EnumAttrCase<"MPI_ERR_RMA_FLAVOR", 44, "MPI_ERR_RMA_FLAVOR">;
+    : I32EnumAttrCase<"_MPI_ERR_RMA_FLAVOR", 44, "MPI_ERR_RMA_FLAVOR">;
 def MPI_CodeErrRmaRange
-    : I32EnumAttrCase<"MPI_ERR_RMA_RANGE", 45, "MPI_ERR_RMA_RANGE">;
+    : I32EnumAttrCase<"_MPI_ERR_RMA_RANGE", 45, "MPI_ERR_RMA_RANGE">;
 def MPI_CodeErrRmaShared
-    : I32EnumAttrCase<"MPI_ERR_RMA_SHARED", 46, "MPI_ERR_RMA_SHARED">;
+    : I32EnumAttrCase<"_MPI_ERR_RMA_SHARED", 46, "MPI_ERR_RMA_SHARED">;
 def MPI_CodeErrRmaSync
-    : I32EnumAttrCase<"MPI_ERR_RMA_SYNC", 47, "MPI_ERR_RMA_SYNC">;
-def MPI_CodeErrRoot : I32EnumAttrCase<"MPI_ERR_ROOT", 48, "MPI_ERR_ROOT">;
+    : I32EnumAttrCase<"_MPI_ERR_RMA_SYNC", 47, "MPI_ERR_RMA_SYNC">;
+def MPI_CodeErrRoot : I32EnumAttrCase<"_MPI_ERR_ROOT", 48, "MPI_ERR_ROOT">;
 def MPI_CodeErrService
-    : I32EnumAttrCase<"MPI_ERR_SERVICE", 49, "MPI_ERR_SERVICE">;
+    : I32EnumAttrCase<"_MPI_ERR_SERVICE", 49, "MPI_ERR_SERVICE">;
 def MPI_CodeErrSession
-    : I32EnumAttrCase<"MPI_ERR_SESSION", 50, "MPI_ERR_SESSION">;
-def MPI_CodeErrSize : I32EnumAttrCase<"MPI_ERR_SIZE", 51, "MPI_ERR_SIZE">;
-def MPI_CodeErrSpawn : I32EnumAttrCase<"MPI_ERR_SPAWN", 52, "MPI_ERR_SPAWN">;
-def MPI_CodeErrTag : I32EnumAttrCase<"MPI_ERR_TAG", 53, "MPI_ERR_TAG">;
+    : I32EnumAttrCase<"_MPI_ERR_SESSION", 50, "MPI_ERR_SESSION">;
+def MPI_CodeErrSize : I32EnumAttrCase<"_MPI_ERR_SIZE", 51, "MPI_ERR_SIZE">;
+def MPI_CodeErrSpawn : I32EnumAttrCase<"_MPI_ERR_SPAWN", 52, "MPI_ERR_SPAWN">;
+def MPI_CodeErrTag : I32EnumAttrCase<"_MPI_ERR_TAG", 53, "MPI_ERR_TAG">;
 def MPI_CodeErrTopology
-    : I32EnumAttrCase<"MPI_ERR_TOPOLOGY", 54, "MPI_ERR_TOPOLOGY">;
+    : I32EnumAttrCase<"_MPI_ERR_TOPOLOGY", 54, "MPI_ERR_TOPOLOGY">;
 def MPI_CodeErrTruncate
-    : I32EnumAttrCase<"MPI_ERR_TRUNCATE", 55, "MPI_ERR_TRUNCATE">;
-def MPI_CodeErrType : I32EnumAttrCase<"MPI_ERR_TYPE", 56, "MPI_ERR_TYPE">;
+    : I32EnumAttrCase<"_MPI_ERR_TRUNCATE", 55, "MPI_ERR_TRUNCATE">;
+def MPI_CodeErrType : I32EnumAttrCase<"_MPI_ERR_TYPE", 56, "MPI_ERR_TYPE">;
 def MPI_CodeErrUnknown
-    : I32EnumAttrCase<"MPI_ERR_UNKNOWN", 57, "MPI_ERR_UNKNOWN">;
+    : I32EnumAttrCase<"_MPI_ERR_UNKNOWN", 57, "MPI_ERR_UNKNOWN">;
 def MPI_CodeErrUnsupportedDatarep
-    : I32EnumAttrCase<"MPI_ERR_UNSUPPORTED_DATAREP", 58,
+    : I32EnumAttrCase<"_MPI_ERR_UNSUPPORTED_DATAREP", 58,
                       "MPI_ERR_UNSUPPORTED_DATAREP">;
 def MPI_CodeErrUnsupportedOperation
-    : I32EnumAttrCase<"MPI_ERR_UNSUPPORTED_OPERATION", 59,
+    : I32EnumAttrCase<"_MPI_ERR_UNSUPPORTED_OPERATION", 59,
                       "MPI_ERR_UNSUPPORTED_OPERATION">;
 def MPI_CodeErrValueTooLarge
-    : I32EnumAttrCase<"MPI_ERR_VALUE_TOO_LARGE", 60, "MPI_ERR_VALUE_TOO_LARGE">;
-def MPI_CodeErrWin : I32EnumAttrCase<"MPI_ERR_WIN", 61, "MPI_ERR_WIN">;
+    : I32EnumAttrCase<"_MPI_ERR_VALUE_TOO_LARGE", 60, "MPI_ERR_VALUE_TOO_LARGE">;
+def MPI_CodeErrWin : I32EnumAttrCase<"_MPI_ERR_WIN", 61, "MPI_ERR_WIN">;
 def MPI_CodeErrLastcode
-    : I32EnumAttrCase<"MPI_ERR_LASTCODE", 62, "MPI_ERR_LASTCODE">;
+    : I32EnumAttrCase<"_MPI_ERR_LASTCODE", 62, "MPI_ERR_LASTCODE">;
 
 def MPI_ErrorClassEnum
     : I32EnumAttr<"MPI_ErrorClassEnum", "MPI error class name", [
@@ -215,20 +215,20 @@ def MPI_ErrorClassAttr : EnumAttr<MPI_Dialect, MPI_ErrorClassEnum, "errclass"> {
   let assemblyFormat = "`<` $value `>`";
 }
 
-def MPI_OpNull : I32EnumAttrCase<"MPI_OP_NULL", 0, "MPI_OP_NULL">;
-def MPI_OpMax : I32EnumAttrCase<"MPI_MAX", 1, "MPI_MAX">;
-def MPI_OpMin : I32EnumAttrCase<"MPI_MIN", 2, "MPI_MIN">;
-def MPI_OpSum : I32EnumAttrCase<"MPI_SUM", 3, "MPI_SUM">;
-def MPI_OpProd : I32EnumAttrCase<"MPI_PROD", 4, "MPI_PROD">;
-def MPI_OpLand : I32EnumAttrCase<"MPI_LAND", 5, "MPI_LAND">;
-def MPI_OpBand : I32EnumAttrCase<"MPI_BAND", 6, "MPI_BAND">;
-def MPI_OpLor : I32EnumAttrCase<"MPI_LOR", 7, "MPI_LOR">;
-def MPI_OpBor : I32EnumAttrCase<"MPI_BOR", 8, "MPI_BOR">;
-def MPI_OpLxor : I32EnumAttrCase<"MPI_LXOR", 9, "MPI_LXOR">;
-def MPI_OpBxor : I32EnumAttrCase<"MPI_BXOR", 10, "MPI_BXOR">;
-def MPI_OpMinloc : I32EnumAttrCase<"MPI_MINLOC", 11, "MPI_MINLOC">;
-def MPI_OpMaxloc : I32EnumAttrCase<"MPI_MAXLOC", 12, "MPI_MAXLOC">;
-def MPI_OpReplace : I32EnumAttrCase<"MPI_REPLACE", 13, "MPI_REPLACE">;
+def MPI_OpNull : I32EnumAttrCase<"_MPI_OP_NULL", 0, "MPI_OP_NULL">;
+def MPI_OpMax : I32EnumAttrCase<"_MPI_MAX", 1, "MPI_MAX">;
+def MPI_OpMin : I32EnumAttrCase<"_MPI_MIN", 2, "MPI_MIN">;
+def MPI_OpSum : I32EnumAttrCase<"_MPI_SUM", 3, "MPI_SUM">;
+def MPI_OpProd : I32EnumAttrCase<"_MPI_PROD", 4, "MPI_PROD">;
+def MPI_OpLand : I32EnumAttrCase<"_MPI_LAND", 5, "MPI_LAND">;
+def MPI_OpBand : I32EnumAttrCase<"_MPI_BAND", 6, "MPI_BAND">;
+def MPI_OpLor : I32EnumAttrCase<"_MPI_LOR", 7, "MPI_LOR">;
+def MPI_OpBor : I32EnumAttrCase<"_MPI_BOR", 8, "MPI_BOR">;
+def MPI_OpLxor : I32EnumAttrCase<"_MPI_LXOR", 9, "MPI_LXOR">;
+def MPI_OpBxor : I32EnumAttrCase<"_MPI_BXOR", 10, "MPI_BXOR">;
+def MPI_OpMinloc : I32EnumAttrCase<"_MPI_MINLOC", 11, "MPI_MINLOC">;
+def MPI_OpMaxloc : I32EnumAttrCase<"_MPI_MAXLOC", 12, "MPI_MAXLOC">;
+def MPI_OpReplace : I32EnumAttrCase<"_MPI_REPLACE", 13, "MPI_REPLACE">;
 
 def MPI_OpClassEnum : I32EnumAttr<"MPI_OpClassEnum", "MPI operation class", [
       MPI_OpNull,
diff --git a/mlir/include/mlir/Dialect/MPI/IR/MPIOps.td b/mlir/include/mlir/Dialect/MPI/IR/MPIOps.td
index 284ba72af9768..db28bd09678f8 100644
--- a/mlir/include/mlir/Dialect/MPI/IR/MPIOps.td
+++ b/mlir/include/mlir/Dialect/MPI/IR/MPIOps.td
@@ -102,13 +102,13 @@ def MPI_SendOp : MPI_Op<"send", []> {
   let arguments = (
     ins AnyMemRef : $ref,
     I32 : $tag,
-    I32 : $rank
+    I32 : $dest
   );
 
   let results = (outs Optional<MPI_Retval>:$retval);
 
-  let assemblyFormat = "`(` $ref `,` $tag `,` $rank `)` attr-dict `:` "
-                       "type($ref) `,` type($tag) `,` type($rank)"
+  let assemblyFormat = "`(` $ref `,` $tag `,` $dest `)` attr-dict `:` "
+                       "type($ref) `,` type($tag) `,` type($dest)"
                        "(`->` type($retval)^)?";
   let hasCanonicalizer = 1;
 }
@@ -154,11 +154,11 @@ def MPI_ISendOp : MPI_Op<"isend", []> {
 //===----------------------------------------------------------------------===//
 
 def MPI_RecvOp : MPI_Op<"recv", []> {
-  let summary = "Equivalent to `MPI_Recv(ptr, size, dtype, dest, tag, "
+  let summary = "Equivalent to `MPI_Recv(ptr, size, dtype, source, tag, "
                 "MPI_COMM_WORLD, MPI_STATUS_IGNORE)`";
   let description = [{
     MPI_Recv performs a blocking receive of `size` elements of type `dtype` 
-    from rank `dest`. The `tag` value and communicator enables the library to 
+    from rank `source`. The `tag` value and communicator enables the library to
     determine the matching of multiple sends and receives between the same 
     ranks.
 
@@ -172,13 +172,13 @@ def MPI_RecvOp : MPI_Op<"recv", []> {
 
   let arguments = (
     ins AnyMemRef : $ref,
-    I32 : $tag, I32 : $rank
+    I32 : $tag, I32 : $source
   );
 
   let results = (outs Optional<MPI_Retval>:$retval);
 
-  let assemblyFormat = "`(` $ref `,` $tag `,` $rank `)` attr-dict `:`"
-                       "type($ref) `,` type($tag) `,` type($rank)"
+  let assemblyFormat = "`(` $ref `,` $tag `,` $source `)` attr-dict `:` "
+                       "type($ref) `,` type($tag) `,` type($source)"
                        "(`->` type($retval)^)?";
   let hasCanonicalizer = 1;
 }
diff --git a/mlir/lib/Conversion/MPIToLLVM/CMakeLists.txt b/mlir/lib/Conversion/MPIToLLVM/CMakeLists.txt
index f81fb25e56840..5d6b82e655b82 100644
--- a/mlir/lib/Conversion/MPIToLLVM/CMakeLists.txt
+++ b/mlir/lib/Conversion/MPIToLLVM/CMakeLists.txt
@@ -1,8 +1,11 @@
+find_package(MPI REQUIRED) # MPI_SKIP_COMPILER_WRAPPER TRUE MPI_CXX_SKIP_MPICXX TRUE)
+
 add_mlir_conversion_library(MLIRMPIToLLVM
   MPIToLLVM.cpp
 
   ADDITIONAL_HEADER_DIRS
   ${MLIR_MAIN_INCLUDE_DIR}/mlir/Conversion/MPIToLLVM
+  ${MPI_C_HEADER_DIR}
 
   DEPENDS
   MLIRConversionPassIncGen
diff --git a/mlir/lib/Conversion/MPIToLLVM/MPIImplTraits.h b/mlir/lib/Conversion/MPIToLLVM/MPIImplTraits.h
new file mode 100644
index 0000000000000..5fd88d0aab4ac
--- /dev/null
+++ b/mlir/lib/Conversion/MPIToLLVM/MPIImplTraits.h
@@ -0,0 +1,119 @@
+#define MPICH_SKIP_MPICXX 1
+#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
+#include "mlir/Dialect/MPI/IR/MPI.h"
+#include "mlir/Transforms/DialectConversion.h"
+#include <mpi.h>
+
+namespace {
+
+// When lowering the MPI dialect to function calls, certain details
+// differ between various MPI implementations. This class provides
+// these details depending on the MPI implementation that got included.
+struct MPIImplTraits {
+  // get/create MPI_COMM_WORLD as a mlir::Value
+  static mlir::Value getCommWorld(mlir::ModuleOp &moduleOp,
+                                  const mlir::Location loc,
+                                  mlir::ConversionPatternRewriter &rewriter);
+  // get/create MPI datatype as a mlir::Value which corresponds to the given
+  // mlir::Type
+  static mlir::Value getDataType(const mlir::Location loc,
+                                 mlir::ConversionPatternRewriter &rewriter,
+                                 mlir::Type type);
+};
+
+// ****************************************************************************
+// Intel MPI
+#ifdef IMPI_DEVICE_EXPORT
+
+mlir::Value
+MPIImplTraits::getCommWorld(mlir::ModuleOp &moduleOp, const mlir::Location loc,
+                            mlir::ConversionPatternRewriter &rewriter) {
+  return rewriter.create<mlir::LLVM::ConstantOp>(loc, rewriter.getI32Type(),
+                                                 MPI_COMM_WORLD);
+}
+
+mlir::Value
+MPIImplTraits::getDataType(const mlir::Location loc,
+                           mlir::ConversionPatternRewriter &rewriter,
+                           mlir::Type type) {
+  int32_t mtype = 0;
+  if (type.isF32())
+    mtype = MPI_FLOAT;
+  else if (type.isF64())
+    mtype = MPI_DOUBLE;
+  else if (type.isInteger(64) && !type.isUnsignedInteger())
+    mtype = MPI_LONG;
+  else if (type.isInteger(64))
+    mtype = MPI_UNSIGNED_LONG;
+  else if (type.isInteger(32) && !type.isUnsignedInteger())
+    mtype = MPI_INT;
+  else if (type.isInteger(32))
+    mtype = MPI_UNSIGNED;
+  else if (type.isInteger(16) && !type.isUnsignedInteger())
+    mtype = MPI_SHORT;
+  else if (type.isInteger(16))
+    mtype = MPI_UNSIGNED_SHORT;
+  else if (type.isInteger(8) && !type.isUnsignedInteger())
+    mtype = MPI_CHAR;
+  else if (type.isInteger(8))
+    mtype = MPI_UNSIGNED_CHAR;
+  else
+    assert(false && "unsupported type");
+  return rewriter.create<mlir::LLVM::ConstantOp>(loc, rewriter.getI32Type(),
+                                                 mtype);
+}
+
+// ****************************************************************************
+// OpenMPI
+#elif defined(OPEN_MPI) && OPEN_MPI == 1
+
+// TODO: this is pretty close to getOrDefineFunction, can probably be factored
+static mlir::LLVM::GlobalOp
+getOrDefineExternalStruct(mlir::ModuleOp &moduleOp, const mlir::Location loc,
+                          mlir::ConversionPatternRewriter &rewriter,
+                          mlir::StringRef name,
+                          mlir::LLVM::LLVMStructType type) {
+  mlir::LLVM::GlobalOp ret;
+  if (!(ret = moduleOp.lookupSymbol<mlir::LLVM::GlobalOp>(name))) {
+    mlir::ConversionPatternRewriter::InsertionGuard guard(rewriter);
+    rewriter.setInsertionPointToStart(moduleOp.getBody());
+    ret = rewriter.create<mlir::LLVM::GlobalOp>(
+        loc, type, /*isConstant=*/false, mlir::LLVM::Linkage::External, name,
+        /*value=*/mlir::Attribute(), /*alignment=*/0, 0);
+  }
+  return ret;
+}
+
+mlir::Value
+MPIImplTraits::getCommWorld(mlir::ModuleOp &moduleOp, const mlir::Location loc,
+                            mlir::ConversionPatternRewriter &rewriter) {
+  auto context = rewriter.getContext();
+  auto i32 = rewriter.getI32Type();
+  // ptrType `!llvm.ptr`
+  mlir::Type ptrType = mlir::LLVM::LLVMPointerType::get(context);
+  // get external opaque struct pointer type
+  auto commStructT =
+      mlir::LLVM::LLVMStructType::getOpaque("MPI_ABI_Comm", context);
+
+  // make sure global op definition exists
+  getOrDefineExternalStruct(moduleOp, loc, rewriter, "MPI_COMM_WORLD",
+                            commStructT);
+
+  // get address of @MPI_COMM_WORLD
+  return rewriter.create<mlir::LLVM::AddressOfOp>(
+      loc, ptrType, mlir::SymbolRefAttr::get(context, "MPI_COMM_WORLD"));
+}
+
+mlir::Value
+MPIImplTraits::getDataType(const mlir::Location loc,
+                           mlir::ConversionPatternRewriter &rewriter,
+                           mlir::Type type) {
+  assert(false && "getDataType not implemented for this MPI implementation");
+  return {};
+}
+
+#else
+#error "Unsupported MPI implementation"
+#endif
+
+} // namespace
diff --git a/mlir/lib/Conversion/MPIToLLVM/MPIToLLVM.cpp b/mlir/lib/Conversion/MPIToLLVM/MPIToLLVM.cpp
index d87a10aab8f49..adeb9e85f05f7 100644
--- a/mlir/lib/Conversion/MPIToLLVM/MPIToLLVM.cpp
+++ b/mlir/lib/Conversion/MPIToLLVM/MPIToLLVM.cpp
@@ -6,9 +6,11 @@
 //
 //===----------------------------------------------------------------------===//
 
-#include "mlir/Conversion/MPIToLLVM/MPIToLLVM.h"
+// This must go first (MPI gets confused otherwise)
+#include "MPIImplTraits.h"
 
 #include "mlir/Conversion/LLVMCommon/Pattern.h"
+#include "mlir/Conversion/MPIToLLVM/MPIToLLVM.h"
 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
 #include "mlir/Dialect/MPI/IR/MPI.h"
 #include "mlir/Pass/Pass.h"
@@ -17,38 +19,13 @@
 
 using namespace mlir;
 
-namespace {
-
-struct InitOpLowering : ConvertOpToLLVMPattern<mpi::InitOp> {
-  using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;
-
-  LogicalResult
-  matchAndRewrite(mpi::InitOp op, OpAdaptor adaptor,
-                  ConversionPatternRewriter &rewriter) const override;
-};
-
-struct CommRankOpLowering : ConvertOpToLLVMPattern<mpi::CommRankOp> {
-  using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;
-
-  LogicalResult
-  matchAndRewrite(mpi::CommRankOp op, OpAdaptor adaptor,
-                  ConversionPatternRewriter &rewriter) const override;
-};
-
-struct FinalizeOpLowering : ConvertOpToLLVMPattern<mpi::FinalizeOp> {
-  using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;
-
-  LogicalResult
-  matchAndRewrite(mpi::FinalizeOp op, OpAdaptor adaptor,
-                  ConversionPatternRewriter &rewriter) const override;
-};
-
 // TODO: this was copied from GPUOpsLowering.cpp:288
 // is this okay, or should this be moved to some common file?
-LLVM::LLVMFuncOp getOrDefineFunction(ModuleOp &moduleOp, const Location loc,
-                                     ConversionPatternRewriter &rewriter,
-                                     StringRef name,
-                                     LLVM::LLVMFunctionType type) {
+static LLVM::LLVMFuncOp getOrDefineFunction(ModuleOp &moduleOp,
+                                            const Location loc,
+                                            ConversionPatternRewriter &rewriter,
+                                            StringRef name,
+                                            LLVM::LLVMFunctionType type) {
   LLVM::LLVMFuncOp ret;
   if (!(ret = moduleOp.lookupSymbol<LLVM::LLVMFuncOp>(name))) {
     ConversionPatternRewriter::InsertionGuard guard(rewriter);
@@ -59,157 +36,267 @@ LLVM::LLVMFuncOp getOrDefineFunction(ModuleOp &moduleOp, const Location loc,
   return ret;
 }
 
-// TODO: this is pretty close to getOrDefineFunction, can probably be factored
-LLVM::GlobalOp getOrDefineExternalStruct(ModuleOp &moduleOp, const Location loc,
-                                         ConversionPatternRewriter &rewriter,
-                                         StringRef name,
-                                         LLVM::LLVMStructType type) {
-  LLVM::GlobalOp ret;
-  if (!(ret = moduleOp.lookupSymbol<LLVM::GlobalOp>(name))) {
-    ConversionPatternRewriter::InsertionGuard guard(rewriter);
-    rewriter.setInsertionPointToStart(moduleOp.getBody());
-    ret = rewriter.create<LLVM::GlobalOp>(
-        loc, type, /*isConstant=*/false, LLVM::Linkage::External, name,
-        /*value=*/Attribute(), /*alignment=*/0, 0);
-  }
-  return ret;
-}
-
-} // namespace
+namespace {
 
 //===----------------------------------------------------------------------===//
 // InitOpLowering
 //===----------------------------------------------------------------------===//
 
-LogicalResult
-InitOpLowering::matchAndRewrite(mpi::InitOp op, OpAdaptor adaptor,
-                                ConversionPatternRewriter &rewriter) const {
-  // get loc
-  auto loc = op.getLoc();
+struct InitOpLowering : public ConvertOpToLLVMPattern<mpi::InitOp> {
+  using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;
 
-  // ptrType `!llvm.ptr`
-  Type ptrType = LLVM::LLVMPointerType::get(rewriter.getContext());
+  LogicalResult
+  matchAndRewrite(mpi::InitOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    // get loc
+    auto loc = op.getLoc();
 
-  // instantiate nullptr `%nullptr = llvm.mlir.zero : !llvm.ptr`
-  auto nullPtrOp = rewriter.create<LLVM::ZeroOp>(loc, ptrType);
-  Value llvmnull = nullPtrOp.getRes();
+    // ptrType `!llvm.ptr`
+    Type ptrType = LLVM::LLVMPointerType::get(rewriter.getContext());
 
-  // grab a reference to the global module op:
-  auto moduleOp = op->getParentOfType<ModuleOp>();
+    // instantiate nullptr `%nullptr = llvm.mlir.zero : !llvm.ptr`
+    auto nullPtrOp = rewriter.create<LLVM::ZeroOp>(loc, ptrType);
+    Value llvmnull = nullPtrOp.getRes();
 
-  // LLVM Function type representing `i32 MPI_Init(ptr, ptr)`
-  auto initFuncType =
-      LLVM::LLVMFunctionType::get(rewriter.getI32Type(), {ptrType, ptrType});
-  // get or create function declaration:
-  LLVM::LLVMFuncOp initDecl =
-      getOrDefineFunction(moduleOp, loc, rewriter, "MPI_Init", initFuncType);
+    // grab a reference to the global module op:
+    auto moduleOp = op->getParentOfType<ModuleOp>();
 
-  // replace init with function call
-  rewriter.replaceOpWithNewOp<LLVM::CallOp>(op, initDecl,
-                                            ValueRange{llvmnull, llvmnull});
+    // LLVM Function type representing `i32 MPI_Init(ptr, ptr)`
+    auto initFuncType =
+        LLVM::LLVMFunctionType::get(rewriter.getI32Type(), {ptrType, ptrType});
+    // get or create function declaration:
+    LLVM::LLVMFuncOp initDecl =
+        getOrDefineFunction(moduleOp, loc, rewriter, "MPI_Init", initFuncType);
 
-  return success();
-}
+    // replace init with function call
+    rewriter.replaceOpWithNewOp<LLVM::CallOp>(op, initDecl,
+                                              ValueRange{llvmnull, llvmnull});
+
+    return success();
+  }
+};
 
 //===----------------------------------------------------------------------===//
 // FinalizeOpLowering
 //===----------------------------------------------------------------------===//
 
-LogicalResult
-FinalizeOpLowering::matchAndRewrite(mpi::FinalizeOp op, OpAdaptor adaptor,
-                                    ConversionPatternRewriter &rewriter) const {
-  // get loc
-  auto loc = op.getLoc();
+struct FinalizeOpLowering : public ConvertOpToLLVMPattern<mpi::FinalizeOp> {
+  using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;
 
-  // grab a reference to the global module op:
-  auto moduleOp = op->getParentOfType<ModuleOp>();
+  LogicalResult
+  matchAndRewrite(mpi::FinalizeOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    // get loc
+    auto loc = op.getLoc();
 
-  // LLVM Function type representing `i32 MPI_Finalize()`
-  auto initFuncType = LLVM::LLVMFunctionType::get(rewriter.getI32Type(), {});
-  // get or create function declaration:
-  LLVM::LLVMFuncOp initDecl = getOrDefineFunction(moduleOp, loc, rewriter,
-                                                  "MPI_Finalize", initFuncType);
+    // grab a reference to the global module op:
+    auto moduleOp = op->getParentOfType<ModuleOp>();
 
-  // replace init with function call
-  rewriter.replaceOpWithNewOp<LLVM::CallOp>(op, initDecl, ValueRange{});
+    // LLVM Function type representing `i32 MPI_Finalize()`
+    auto initFuncType = LLVM::LLVMFunctionType::get(rewriter.getI32Type(), {});
+    // get or create function declaration:
+    LLVM::LLVMFuncOp initDecl = getOrDefineFunction(
+        moduleOp, loc, rewriter, "MPI_Finalize", initFuncType);
 
-  return success();
-}
+    // replace finalize with function call
+    rewriter.replaceOpWithNewOp<LLVM::CallOp>(op, initDecl, ValueRange{});
+
+    return success();
+  }
+};
 
 //===----------------------------------------------------------------------===//
-// CommRankLowering
+// CommRankOpLowering
 //===----------------------------------------------------------------------===//
 
-LogicalResult
-CommRankOpLowering::matchAndRewrite(mpi::CommRankOp op, OpAdaptor adaptor,
-                                    ConversionPatternRewriter &rewriter) const {
-  // get some helper vars
-  auto loc = op.getLoc();
-  auto context = rewriter.getContext();
-  auto i32 = rewriter.getI32Type();
-
-  // ptrType `!llvm.ptr`
-  Type ptrType = LLVM::LLVMPointerType::get(context);
-
-  // get external opaque struct pointer type
-  auto commStructT = LLVM::LLVMStructType::getOpaque("MPI_ABI_Comm", context);
-
-  // grab a reference to the global module op:
-  auto moduleOp = op->getParentOfType<ModuleOp>();
-
-  // make sure global op definition exists
-  getOrDefineExternalStruct(moduleOp, loc, rewriter, "MPI_COMM_WORLD",
-                            commStructT);
-
-  // get address of @MPI_COMM_WORLD
-  auto one = rewriter.create<LLVM::ConstantOp>(loc, i32, 1);
-  auto rankptr = rewriter.create<LLVM::AllocaOp>(loc, ptrType, i32, one);
-  auto commWorld = rewriter.create<LLVM::AddressOfOp>(
-      loc, ptrType, SymbolRefAttr::get(context, "MPI_COMM_WORLD"));
-
-  // LLVM Function type representing `i32 MPI_Comm_rank(ptr, ptr)`
-  auto rankFuncType = LLVM::LLVMFunctionType::get(i32, {ptrType, ptrType});
-  // get or create function declaration:
-  LLVM::LLVMFuncOp initDecl = getOrDefineFunction(
-      moduleOp, loc, rewriter, "MPI_Comm_rank", rankFuncType);
-
-  // replace init with function call
-  auto callOp = rewriter.create<LLVM::CallOp>(
-      loc, initDecl, ValueRange{commWorld.getRes(), rankptr.getRes()});
-
-  // load the rank into a register
-  auto loadedRank =
-      rewriter.create<LLVM::LoadOp>(loc, i32, rankptr.getResult());
-
-  // if retval is checked, replace uses of retval with the results from the call
-  // op
-  SmallVector<Value> replacements;
-  if (op.getRetval()) {
-    replacements.push_back(callOp.getResult());
+struct CommRankOpLowering : public ConvertOpToLLVMPattern<mpi::CommRankOp> {
+  using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;
+
+  LogicalResult
+  matchAndRewrite(mpi::CommRankOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    // get some helper vars
+    auto loc = op.getLoc();
+    auto context = rewriter.getContext();
+    auto i32 = rewriter.getI32Type();
+
+    // ptrType `!llvm.ptr`
+    Type ptrType = LLVM::LLVMPointerType::get(context);
+
+    // grab a reference to the global module op:
+    auto moduleOp = op->getParentOfType<ModuleOp>();
+
+    // get MPI_COMM_WORLD
+    auto commWorld = MPIImplTraits::getCommWorld(moduleOp, loc, rewriter);
+
+    // LLVM Function type representing `i32 MPI_Comm_rank(comm, ptr)`
+    auto rankFuncType =
+        LLVM::LLVMFunctionType::get(i32, {commWorld.getType(), ptrType});
+    // get or create function declaration:
+    LLVM::LLVMFuncOp initDecl = getOrDefineFunction(
+        moduleOp, loc, rewriter, "MPI_Comm_rank", rankFuncType);
+
+    // allocate a slot for the rank and call MPI_Comm_rank
+    auto one = rewriter.create<LLVM::ConstantOp>(loc, i32, 1);
+    auto rankptr = rewriter.create<LLVM::AllocaOp>(loc, ptrType, i32, one);
+    auto callOp = rewriter.create<LLVM::CallOp>(
+        loc, initDecl, ValueRange{commWorld, rankptr.getRes()});
+
+    // load the rank into a register
+    auto loadedRank =
+        rewriter.create<LLVM::LoadOp>(loc, i32, rankptr.getResult());
+
+    // if retval is checked, replace uses of retval with the results from the
+    // call op
+    SmallVector<Value> replacements;
+    if (op.getRetval()) {
+      replacements.push_back(callOp.getResult());
+    }
+    // replace all uses, then erase op
+    replacements.push_back(loadedRank.getRes());
+    rewriter.replaceOp(op, replacements);
+
+    return success();
   }
-  // replace all uses, then erase op
-  replacements.push_back(loadedRank.getRes());
-  rewriter.replaceOp(op, replacements);
+};
 
-  return success();
-}
+//===----------------------------------------------------------------------===//
+// SendOpLowering
+//===----------------------------------------------------------------------===//
+
+struct SendOpLowering : public ConvertOpToLLVMPattern<mpi::SendOp> {
+  using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;
+
+  LogicalResult
+  matchAndRewrite(mpi::SendOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    // get some helper vars
+    auto loc = op.getLoc();
+    auto context = rewriter.getContext();
+    auto i32 = rewriter.getI32Type();
+    auto i64 = rewriter.getI64Type();
+    auto memRef = adaptor.getRef();
+    auto elemType = op.getRef().getType().getElementType();
+
+    // ptrType `!llvm.ptr`
+    Type ptrType = LLVM::LLVMPointerType::get(context);
+
+    // grab a reference to the global module op:
+    auto moduleOp = op->getParentOfType<ModuleOp>();
+
+    // get MPI_COMM_WORLD, dataType and pointer
+    auto dataPtr =
+        rewriter.create<LLVM::ExtractValueOp>(loc, ptrType, memRef, 1)
+            .getResult();
+    auto offset =
+        rewriter.create<LLVM::ExtractValueOp>(loc, i64, memRef, 2).getResult();
+    dataPtr =
+        rewriter.create<LLVM::GEPOp>(loc, ptrType, elemType, dataPtr, offset)
+            .getResult();
+    auto size =
+        rewriter
+            .create<LLVM::ExtractValueOp>(loc, memRef, ArrayRef<int64_t>{3, 0})
+            .getResult();
+    size = rewriter.create<LLVM::TruncOp>(loc, i32, size).getResult();
+    auto dataType = MPIImplTraits::getDataType(loc, rewriter, elemType);
+    auto commWorld = MPIImplTraits::getCommWorld(moduleOp, loc, rewriter);
+
+    // LLVM function type: `i32 MPI_Send(buf, count, datatype, dest, tag, comm)`
+    auto funcType = LLVM::LLVMFunctionType::get(
+        i32, {ptrType, i32, i32, i32, i32, commWorld.getType()});
+    // get or create function declaration:
+    LLVM::LLVMFuncOp funcDecl =
+        getOrDefineFunction(moduleOp, loc, rewriter, "MPI_Send", funcType);
+
+    // replace op with function call
+    auto funcCall = rewriter.create<LLVM::CallOp>(
+        loc, funcDecl,
+        ValueRange{dataPtr, size, dataType, adaptor.getDest(), adaptor.getTag(),
+                   commWorld});
+    if (op.getRetval()) {
+      rewriter.replaceOp(op, funcCall.getResult());
+    } else {
+      rewriter.eraseOp(op);
+    }
+
+    return success();
+  }
+};
 
 //===----------------------------------------------------------------------===//
-// Pattern Population
+// RecvOpLowering
 //===----------------------------------------------------------------------===//
 
-void mpi::populateMPIToLLVMConversionPatterns(LLVMTypeConverter &converter,
-                                              RewritePatternSet &patterns) {
-  patterns.add<InitOpLowering>(converter);
-  patterns.add<CommRankOpLowering>(converter);
-  patterns.add<FinalizeOpLowering>(converter);
-}
+struct RecvOpLowering : public ConvertOpToLLVMPattern<mpi::RecvOp> {
+  using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;
+
+  LogicalResult
+  matchAndRewrite(mpi::RecvOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    // get some helper vars
+    auto loc = op.getLoc();
+    auto context = rewriter.getContext();
+    auto i32 = rewriter.getI32Type();
+    auto i64 = rewriter.getI64Type();
+    auto memRef = adaptor.getRef();
+    auto elemType = op.getRef().getType().getElementType();
+
+    // ptrType `!llvm.ptr`
+    Type ptrType = LLVM::LLVMPointerType::get(context);
+
+    // grab a reference to the global module op:
+    auto moduleOp = op->getParentOfType<ModuleOp>();
+
+    // get MPI_COMM_WORLD, dataType, status_ignore and pointer
+    auto dataPtr =
+        rewriter.create<LLVM::ExtractValueOp>(loc, ptrType, memRef, 1)
+            .getResult();
+    auto offset =
+        rewriter.create<LLVM::ExtractValueOp>(loc, i64, memRef, 2).getResult();
+    dataPtr =
+        rewriter.create<LLVM::GEPOp>(loc, ptrType, elemType, dataPtr, offset)
+            .getResult();
+    auto size =
+        rewriter
+            .create<LLVM::ExtractValueOp>(loc, memRef, ArrayRef<int64_t>{3, 0})
+            .getResult();
+    size = rewriter.create<LLVM::TruncOp>(loc, i32, size).getResult();
+    auto dataType = MPIImplTraits::getDataType(loc, rewriter, elemType);
+    auto commWorld = MPIImplTraits::getCommWorld(moduleOp, loc, rewriter);
+    auto statusIgnore =
+        rewriter
+            .create<LLVM::ConstantOp>(
+                loc, i64, reinterpret_cast<int64_t>(MPI_STATUS_IGNORE))
+            .getResult();
+    statusIgnore = rewriter.create<LLVM::IntToPtrOp>(loc, ptrType, statusIgnore)
+                       .getResult();
+
+    // fn type: `i32 MPI_Recv(buf, count, datatype, source, tag, comm, status)`
+    auto funcType = LLVM::LLVMFunctionType::get(
+        i32, {ptrType, i32, i32, i32, i32, commWorld.getType(), ptrType});
+    // get or create function declaration:
+    LLVM::LLVMFuncOp funcDecl =
+        getOrDefineFunction(moduleOp, loc, rewriter, "MPI_Recv", funcType);
+
+    // replace op with function call
+    auto funcCall = rewriter.create<LLVM::CallOp>(
+        loc, funcDecl,
+        ValueRange{dataPtr, size, dataType, adaptor.getSource(),
+                   adaptor.getTag(), commWorld, statusIgnore});
+    if (op.getRetval()) {
+      rewriter.replaceOp(op, funcCall.getResult());
+    } else {
+      rewriter.eraseOp(op);
+    }
+
+    return success();
+  }
+};
 
 //===----------------------------------------------------------------------===//
 // ConvertToLLVMPatternInterface implementation
 //===----------------------------------------------------------------------===//
 
-namespace {
 /// Implement the interface to convert Func to LLVM.
 struct FuncToLLVMDialectInterface : public ConvertToLLVMPatternInterface {
   using ConvertToLLVMPatternInterface::ConvertToLLVMPatternInterface;
@@ -223,7 +310,17 @@ struct FuncToLLVMDialectInterface : public ConvertToLLVMPatternInterface {
 };
 } // namespace
 
-void mpi::registerConvertMPIToLLVMInterface(DialectRegistry &registry) {
+//===----------------------------------------------------------------------===//
+// Pattern Population
+//===----------------------------------------------------------------------===//
+
+void mlir::mpi::populateMPIToLLVMConversionPatterns(
+    LLVMTypeConverter &converter, RewritePatternSet &patterns) {
+  patterns.add<CommRankOpLowering, FinalizeOpLowering, InitOpLowering,
+               SendOpLowering, RecvOpLowering>(converter);
+}
+
+void mlir::mpi::registerConvertMPIToLLVMInterface(DialectRegistry &registry) {
   registry.addExtension(+[](MLIRContext *ctx, mpi::MPIDialect *dialect) {
     dialect->addInterfaces<FuncToLLVMDialectInterface>();
   });
diff --git a/mlir/test/Conversion/MPIToLLVM/ops.mlir b/mlir/test/Conversion/MPIToLLVM/ops.mlir
index a7a44ad24909a..61a4219fe35fc 100644
--- a/mlir/test/Conversion/MPIToLLVM/ops.mlir
+++ b/mlir/test/Conversion/MPIToLLVM/ops.mlir
@@ -1,36 +1,80 @@
 // RUN: mlir-opt -convert-to-llvm %s | FileCheck %s
 
 module {
-// CHECK:  llvm.func @MPI_Finalize() -> i32
-// CHECK:  llvm.func @MPI_Comm_rank(!llvm.ptr, !llvm.ptr) -> i32
-// CHECK:  llvm.mlir.global external @MPI_COMM_WORLD() {addr_space = 0 : i32} : !llvm.struct<"MPI_ABI_Comm", opaque>
-// CHECK:  llvm.func @MPI_Init(!llvm.ptr, !llvm.ptr) -> i32
+  // CHECK: llvm.func @MPI_Finalize() -> i32
+  // CHECK: llvm.func @MPI_Recv(!llvm.ptr, i32, i32, i32, i32, i32, !llvm.ptr) -> i32
+  // CHECK: llvm.func @MPI_Send(!llvm.ptr, i32, i32, i32, i32, i32) -> i32
+  // CHECK: llvm.func @MPI_Comm_rank({{.*}}, !llvm.ptr) -> i32
+  // COMM: llvm.mlir.global external @MPI_COMM_WORLD() {addr_space = 0 : i32} : !llvm.struct<"MPI_ABI_Comm", opaque>
+  // CHECK: llvm.func @MPI_Init(!llvm.ptr, !llvm.ptr) -> i32
 
   func.func @mpi_test(%arg0: memref<100xf32>) {
+    // CHECK: [[varg0:%.*]]: !llvm.ptr, [[varg1:%.*]]: !llvm.ptr, [[varg2:%.*]]: i64, [[varg3:%.*]]: i64, [[varg4:%.*]]: i64
+    // CHECK: [[v0:%.*]] = llvm.mlir.poison : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
+    // CHECK-NEXT: [[v1:%.*]] = llvm.insertvalue [[varg0]], [[v0]][0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> 
+    // CHECK-NEXT: [[v2:%.*]] = llvm.insertvalue [[varg1]], [[v1]][1] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> 
+    // CHECK-NEXT: [[v3:%.*]] = llvm.insertvalue [[varg2]], [[v2]][2] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> 
+    // CHECK-NEXT: [[v4:%.*]] = llvm.insertvalue [[varg3]], [[v3]][3, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> 
+    // CHECK-NEXT: [[v5:%.*]] = llvm.insertvalue [[varg4]], [[v4]][4, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> 
+    // CHECK-NEXT: [[v6:%.*]] = llvm.mlir.zero : !llvm.ptr
+    // CHECK-NEXT: [[v7:%.*]] = llvm.call @MPI_Init([[v6]], [[v6]]) : (!llvm.ptr, !llvm.ptr) -> i32
+    // CHECK-NEXT: [[v8:%.*]] = builtin.unrealized_conversion_cast [[v7]] : i32 to !mpi.retval
     %0 = mpi.init : !mpi.retval
-// CHECK:      %7 = llvm.mlir.zero : !llvm.ptr
-// CHECK-NEXT: %8 = llvm.call @MPI_Init(%7, %7) : (!llvm.ptr, !llvm.ptr) -> i32
-// CHECK-NEXT: %9 = builtin.unrealized_conversion_cast %8 : i32 to !mpi.retval
-
 
+    // CHECK: [[v9:%.*]] = llvm.mlir.constant(1140850688 : i32) : i32
+    // CHECK-NEXT: [[v10:%.*]] = llvm.mlir.constant(1 : i32) : i32
+    // CHECK-NEXT: [[v11:%.*]] = llvm.alloca [[v10]] x i32 : (i32) -> !llvm.ptr
+    // CHECK-NEXT: [[v12:%.*]] = llvm.call @MPI_Comm_rank([[v9]], [[v11]]) : (i32, !llvm.ptr) -> i32
+    // CHECK-NEXT: [[v13:%.*]] = builtin.unrealized_conversion_cast [[v12]] : i32 to !mpi.retval
+    // CHECK-NEXT: [[v14:%.*]] = llvm.load [[v11]] : !llvm.ptr -> i32
     %retval, %rank = mpi.comm_rank : !mpi.retval, i32
-// CHECK:      %10 = llvm.mlir.constant(1 : i32) : i32
-// CHECK-NEXT: %11 = llvm.alloca %10 x i32 : (i32) -> !llvm.ptr
-// CHECK-NEXT: %12 = llvm.mlir.addressof @MPI_COMM_WORLD : !llvm.ptr
-// CHECK-NEXT: %13 = llvm.call @MPI_Comm_rank(%12, %11) : (!llvm.ptr, !llvm.ptr) -> i32
-// CHECK-NEXT: %14 = llvm.load %11 : !llvm.ptr -> i32
-// CHECK-NEXT: %15 = builtin.unrealized_conversion_cast %13 : i32 to !mpi.retval
 
+    // CHECK: [[v15:%.*]] = llvm.extractvalue [[v5]][1] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> 
+    // CHECK-NEXT: [[v16:%.*]] = llvm.extractvalue [[v5]][2] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> 
+    // CHECK-NEXT: [[v17:%.*]] = llvm.getelementptr [[v15]][[[v16]]] : (!llvm.ptr, i64) -> !llvm.ptr, f32
+    // CHECK-NEXT: [[v18:%.*]] = llvm.extractvalue [[v5]][3, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> 
+    // CHECK-NEXT: [[v19:%.*]] = llvm.trunc [[v18]] : i64 to i32
+    // CHECK-NEXT: [[v20:%.*]] = llvm.mlir.constant(1275069450 : i32) : i32
+    // CHECK-NEXT: [[v21:%.*]] = llvm.mlir.constant(1140850688 : i32) : i32
+    // CHECK-NEXT: [[v22:%.*]] = llvm.call @MPI_Send([[v17]], [[v19]], [[v20]], [[v14]], [[v14]], [[v21]]) : (!llvm.ptr, i32, i32, i32, i32, i32) -> i32
     mpi.send(%arg0, %rank, %rank) : memref<100xf32>, i32, i32
 
+    // CHECK: [[v23:%.*]] = llvm.extractvalue [[v5]][1] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> 
+    // CHECK-NEXT: [[v24:%.*]] = llvm.extractvalue [[v5]][2] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> 
+    // CHECK-NEXT: [[v25:%.*]] = llvm.getelementptr [[v23]][[[v24]]] : (!llvm.ptr, i64) -> !llvm.ptr, f32
+    // CHECK-NEXT: [[v26:%.*]] = llvm.extractvalue [[v5]][3, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> 
+    // CHECK-NEXT: [[v27:%.*]] = llvm.trunc [[v26]] : i64 to i32
+    // CHECK-NEXT: [[v28:%.*]] = llvm.mlir.constant(1275069450 : i32) : i32
+    // CHECK-NEXT: [[v29:%.*]] = llvm.mlir.constant(1140850688 : i32) : i32
+    // CHECK-NEXT: [[v30:%.*]] = llvm.call @MPI_Send([[v25]], [[v27]], [[v28]], [[v14]], [[v14]], [[v29]]) : (!llvm.ptr, i32, i32, i32, i32, i32) -> i32
     %1 = mpi.send(%arg0, %rank, %rank) : memref<100xf32>, i32, i32 -> !mpi.retval
 
+    // CHECK: [[v31:%.*]] = llvm.extractvalue [[v5]][1] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> 
+    // CHECK-NEXT: [[v32:%.*]] = llvm.extractvalue [[v5]][2] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> 
+    // CHECK-NEXT: [[v33:%.*]] = llvm.getelementptr [[v31]][[[v32]]] : (!llvm.ptr, i64) -> !llvm.ptr, f32
+    // CHECK-NEXT: [[v34:%.*]] = llvm.extractvalue [[v5]][3, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> 
+    // CHECK-NEXT: [[v35:%.*]] = llvm.trunc [[v34]] : i64 to i32
+    // CHECK-NEXT: [[v36:%.*]] = llvm.mlir.constant(1275069450 : i32) : i32
+    // CHECK-NEXT: [[v37:%.*]] = llvm.mlir.constant(1140850688 : i32) : i32
+    // CHECK-NEXT: [[v38:%.*]] = llvm.mlir.constant(1 : i64) : i64
+    // CHECK-NEXT: [[v39:%.*]] = llvm.inttoptr [[v38]] : i64 to !llvm.ptr
+    // CHECK-NEXT: [[v40:%.*]] = llvm.call @MPI_Recv([[v33]], [[v35]], [[v36]], [[v14]], [[v14]], [[v37]], [[v39]]) : (!llvm.ptr, i32, i32, i32, i32, i32, !llvm.ptr) -> i32
     mpi.recv(%arg0, %rank, %rank) : memref<100xf32>, i32, i32
 
+    // CHECK: [[v41:%.*]] = llvm.extractvalue [[v5]][1] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> 
+    // CHECK-NEXT: [[v42:%.*]] = llvm.extractvalue [[v5]][2] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> 
+    // CHECK-NEXT: [[v43:%.*]] = llvm.getelementptr [[v41]][[[v42]]] : (!llvm.ptr, i64) -> !llvm.ptr, f32
+    // CHECK-NEXT: [[v44:%.*]] = llvm.extractvalue [[v5]][3, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> 
+    // CHECK-NEXT: [[v45:%.*]] = llvm.trunc [[v44]] : i64 to i32
+    // CHECK-NEXT: [[v46:%.*]] = llvm.mlir.constant(1275069450 : i32) : i32
+    // CHECK-NEXT: [[v47:%.*]] = llvm.mlir.constant(1140850688 : i32) : i32
+    // CHECK-NEXT: [[v48:%.*]] = llvm.mlir.constant(1 : i64) : i64
+    // CHECK-NEXT: [[v49:%.*]] = llvm.inttoptr [[v48]] : i64 to !llvm.ptr
+    // CHECK-NEXT: [[v50:%.*]] = llvm.call @MPI_Recv([[v43]], [[v45]], [[v46]], [[v14]], [[v14]], [[v47]], [[v49]]) : (!llvm.ptr, i32, i32, i32, i32, i32, !llvm.ptr) -> i32
     %2 = mpi.recv(%arg0, %rank, %rank) : memref<100xf32>, i32, i32 -> !mpi.retval
 
+    // CHECK: llvm.call @MPI_Finalize() : () -> i32
     %3 = mpi.finalize : !mpi.retval
-// CHECK:      %18 = llvm.call @MPI_Finalize() : () -> i32
 
     %4 = mpi.retval_check %retval = <MPI_SUCCESS> : i1
 

>From 3b53ba82288c1266fc505a879b1f1829b6559f3e Mon Sep 17 00:00:00 2001
From: "Schlimbach, Frank" <frank.schlimbach at intel.com>
Date: Fri, 7 Feb 2025 09:07:19 +0100
Subject: [PATCH 03/15] fixed finding MPI in cmake for MPIToLLVM MPIImplTraits
 for OpenMPI

---
 mlir/lib/Conversion/MPIToLLVM/CMakeLists.txt  | 20 ++++-
 mlir/lib/Conversion/MPIToLLVM/MPIImplTraits.h | 74 +++++++++++++------
 mlir/lib/Conversion/MPIToLLVM/MPIToLLVM.cpp   | 19 +++--
 mlir/test/Conversion/MPIToLLVM/ops.mlir       | 38 +++++-----
 4 files changed, 101 insertions(+), 50 deletions(-)

diff --git a/mlir/lib/Conversion/MPIToLLVM/CMakeLists.txt b/mlir/lib/Conversion/MPIToLLVM/CMakeLists.txt
index 5d6b82e655b82..17df603ff5686 100644
--- a/mlir/lib/Conversion/MPIToLLVM/CMakeLists.txt
+++ b/mlir/lib/Conversion/MPIToLLVM/CMakeLists.txt
@@ -1,11 +1,20 @@
-find_package(MPI REQUIRED) # MPI_SKIP_COMPILER_WRAPPER TRUE MPI_CXX_SKIP_MPICXX TRUE)
+find_path(MPI_C_HEADER_DIR mpi.h
+    PATHS $ENV{I_MPI_ROOT}/include
+          $ENV{MPI_HOME}/include
+          $ENV{MPI_ROOT}/include)
+if(MPI_C_HEADER_DIR)
+  # cmake_path(REMOVE_FILENAME MPI_C_HEADER_DIR)
+  message(STATUS "found MPI_C_HEADER_DIR: ${MPI_C_HEADER_DIR}")
+else()
+  message(WARNING "MPI not found, disabling MLIRMPIToLLVM conversion")
+  return()
+endif()
 
 add_mlir_conversion_library(MLIRMPIToLLVM
   MPIToLLVM.cpp
 
   ADDITIONAL_HEADER_DIRS
   ${MLIR_MAIN_INCLUDE_DIR}/mlir/Conversion/MPIToLLVM
-  ${MPI_C_HEADER_DIR}
 
   DEPENDS
   MLIRConversionPassIncGen
@@ -17,4 +26,9 @@ add_mlir_conversion_library(MLIRMPIToLLVM
   MLIRLLVMCommonConversion
   MLIRLLVMDialect
   MLIRMPIDialect
-  )
+)
+target_include_directories(
+    MLIRMPIToLLVM
+    PRIVATE
+    ${MPI_C_HEADER_DIR}
+)
\ No newline at end of file
diff --git a/mlir/lib/Conversion/MPIToLLVM/MPIImplTraits.h b/mlir/lib/Conversion/MPIToLLVM/MPIImplTraits.h
index 5fd88d0aab4ac..09811e1cb7c61 100644
--- a/mlir/lib/Conversion/MPIToLLVM/MPIImplTraits.h
+++ b/mlir/lib/Conversion/MPIToLLVM/MPIImplTraits.h
@@ -1,4 +1,5 @@
 #define MPICH_SKIP_MPICXX 1
+#define OMPI_SKIP_MPICXX 1
 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
 #include "mlir/Dialect/MPI/IR/MPI.h"
 #include "mlir/Transforms/DialectConversion.h"
@@ -16,7 +17,8 @@ struct MPIImplTraits {
                                   mlir::ConversionPatternRewriter &rewriter);
   // get/create MPI datatype as a mlir::Value which corresponds to the given
   // mlir::Type
-  static mlir::Value getDataType(const mlir::Location loc,
+  static mlir::Value getDataType(mlir::ModuleOp &moduleOp,
+                                 const mlir::Location loc,
                                  mlir::ConversionPatternRewriter &rewriter,
                                  mlir::Type type);
 };
@@ -33,7 +35,7 @@ MPIImplTraits::getCommWorld(mlir::ModuleOp &moduleOp, const mlir::Location loc,
 }
 
 mlir::Value
-MPIImplTraits::getDataType(const mlir::Location loc,
+MPIImplTraits::getDataType(mlir::ModuleOp &moduleOp, const mlir::Location loc,
                            mlir::ConversionPatternRewriter &rewriter,
                            mlir::Type type) {
   int32_t mtype = 0;
@@ -42,21 +44,21 @@ MPIImplTraits::getDataType(const mlir::Location loc,
   else if (type.isF64())
     mtype = MPI_DOUBLE;
   else if (type.isInteger(64) && !type.isUnsignedInteger())
-    mtype = MPI_LONG;
+    mtype = MPI_INT64_T;
   else if (type.isInteger(64))
-    mtype = MPI_UNSIGNED_LONG;
+    mtype = MPI_UINT64_T;
   else if (type.isInteger(32) && !type.isUnsignedInteger())
-    mtype = MPI_INT;
+    mtype = MPI_INT32_T;
   else if (type.isInteger(32))
-    mtype = MPI_UNSIGNED;
+    mtype = MPI_UINT32_T;
   else if (type.isInteger(16) && !type.isUnsignedInteger())
-    mtype = MPI_SHORT;
+    mtype = MPI_INT16_T;
   else if (type.isInteger(16))
-    mtype = MPI_UNSIGNED_SHORT;
+    mtype = MPI_UINT16_T;
   else if (type.isInteger(8) && !type.isUnsignedInteger())
-    mtype = MPI_CHAR;
+    mtype = MPI_INT8_T;
   else if (type.isInteger(8))
-    mtype = MPI_UNSIGNED_CHAR;
+    mtype = MPI_UINT8_T;
   else
     assert(false && "unsupported type");
   return rewriter.create<mlir::LLVM::ConstantOp>(loc, rewriter.getI32Type(),
@@ -88,28 +90,58 @@ mlir::Value
 MPIImplTraits::getCommWorld(mlir::ModuleOp &moduleOp, const mlir::Location loc,
                             mlir::ConversionPatternRewriter &rewriter) {
   auto context = rewriter.getContext();
-  auto i32 = rewriter.getI32Type();
-  // ptrType `!llvm.ptr`
-  mlir::Type ptrType = mlir::LLVM::LLVMPointerType::get(context);
   // get external opaque struct pointer type
   auto commStructT =
-      mlir::LLVM::LLVMStructType::getOpaque("MPI_ABI_Comm", context);
+      mlir::LLVM::LLVMStructType::getOpaque("ompi_communicator_t", context);
+  const char *name = "ompi_mpi_comm_world";
 
   // make sure global op definition exists
-  getOrDefineExternalStruct(moduleOp, loc, rewriter, "MPI_COMM_WORLD",
-                            commStructT);
+  (void)getOrDefineExternalStruct(moduleOp, loc, rewriter, name, commStructT);
 
-  // get address of @MPI_COMM_WORLD
+  // get address of symbol
   return rewriter.create<mlir::LLVM::AddressOfOp>(
-      loc, ptrType, mlir::SymbolRefAttr::get(context, "MPI_COMM_WORLD"));
+      loc, mlir::LLVM::LLVMPointerType::get(context),
+      mlir::SymbolRefAttr::get(context, name));
 }
 
 mlir::Value
-MPIImplTraits::getDataType(const mlir::Location loc,
+MPIImplTraits::getDataType(mlir::ModuleOp &moduleOp, const mlir::Location loc,
                            mlir::ConversionPatternRewriter &rewriter,
                            mlir::Type type) {
-  assert(false && "getDataType not implemented for this MPI implementation");
-  return {};
+  const char *mtype = nullptr;
+  if (type.isF32())
+    mtype = "ompi_mpi_float";
+  else if (type.isF64())
+    mtype = "ompi_mpi_double";
+  else if (type.isInteger(64) && !type.isUnsignedInteger())
+    mtype = "ompi_mpi_int64_t";
+  else if (type.isInteger(64))
+    mtype = "ompi_mpi_uint64_t";
+  else if (type.isInteger(32) && !type.isUnsignedInteger())
+    mtype = "ompi_mpi_int32_t";
+  else if (type.isInteger(32))
+    mtype = "ompi_mpi_uint32_t";
+  else if (type.isInteger(16) && !type.isUnsignedInteger())
+    mtype = "ompi_mpi_int16_t";
+  else if (type.isInteger(16))
+    mtype = "ompi_mpi_uint16_t";
+  else if (type.isInteger(8) && !type.isUnsignedInteger())
+    mtype = "ompi_mpi_int8_t";
+  else if (type.isInteger(8))
+    mtype = "ompi_mpi_uint8_t";
+  else
+    assert(false && "unsupported type");
+
+  auto context = rewriter.getContext();
+  // get external opaque struct pointer type
+  auto commStructT = mlir::LLVM::LLVMStructType::getOpaque(
+      "ompi_predefined_datatype_t", context);
+  // make sure global op definition exists
+  (void)getOrDefineExternalStruct(moduleOp, loc, rewriter, mtype, commStructT);
+  // get address of symbol
+  return rewriter.create<mlir::LLVM::AddressOfOp>(
+      loc, mlir::LLVM::LLVMPointerType::get(context),
+      mlir::SymbolRefAttr::get(context, mtype));
 }
 
 #else
diff --git a/mlir/lib/Conversion/MPIToLLVM/MPIToLLVM.cpp b/mlir/lib/Conversion/MPIToLLVM/MPIToLLVM.cpp
index adeb9e85f05f7..9980c789cda6e 100644
--- a/mlir/lib/Conversion/MPIToLLVM/MPIToLLVM.cpp
+++ b/mlir/lib/Conversion/MPIToLLVM/MPIToLLVM.cpp
@@ -198,12 +198,14 @@ struct SendOpLowering : public ConvertOpToLLVMPattern<mpi::SendOp> {
             .create<LLVM::ExtractValueOp>(loc, memRef, ArrayRef<int64_t>{3, 0})
             .getResult();
     size = rewriter.create<LLVM::TruncOp>(loc, i32, size).getResult();
-    auto dataType = MPIImplTraits::getDataType(loc, rewriter, elemType);
+    auto dataType =
+        MPIImplTraits::getDataType(moduleOp, loc, rewriter, elemType);
     auto commWorld = MPIImplTraits::getCommWorld(moduleOp, loc, rewriter);
 
-    // LLVM Function type representing `i32 MPI_send(datatype, dst, tag, comm)`
+    // LLVM Function type representing `i32 MPI_Send(data, count, datatype, dst,
+    // tag, comm)`
     auto funcType = LLVM::LLVMFunctionType::get(
-        i32, {ptrType, i32, i32, i32, i32, commWorld.getType()});
+        i32, {ptrType, i32, dataType.getType(), i32, i32, commWorld.getType()});
     // get or create function declaration:
     LLVM::LLVMFuncOp funcDecl =
         getOrDefineFunction(moduleOp, loc, rewriter, "MPI_Send", funcType);
@@ -261,7 +263,8 @@ struct RecvOpLowering : public ConvertOpToLLVMPattern<mpi::RecvOp> {
             .create<LLVM::ExtractValueOp>(loc, memRef, ArrayRef<int64_t>{3, 0})
             .getResult();
     size = rewriter.create<LLVM::TruncOp>(loc, i32, size).getResult();
-    auto dataType = MPIImplTraits::getDataType(loc, rewriter, elemType);
+    auto dataType =
+        MPIImplTraits::getDataType(moduleOp, loc, rewriter, elemType);
     auto commWorld = MPIImplTraits::getCommWorld(moduleOp, loc, rewriter);
     auto statusIgnore =
         rewriter
@@ -271,9 +274,11 @@ struct RecvOpLowering : public ConvertOpToLLVMPattern<mpi::RecvOp> {
     statusIgnore = rewriter.create<LLVM::IntToPtrOp>(loc, ptrType, statusIgnore)
                        .getResult();
 
-    // LLVM Function type representing `i32 MPI_Recv(datatype, dst, tag, comm)`
-    auto funcType = LLVM::LLVMFunctionType::get(
-        i32, {ptrType, i32, i32, i32, i32, commWorld.getType(), ptrType});
+    // LLVM Function type representing `i32 MPI_Recv(data, count, datatype, src,
+    // tag, comm, status)`
+    auto funcType =
+        LLVM::LLVMFunctionType::get(i32, {ptrType, i32, dataType.getType(), i32,
+                                          i32, commWorld.getType(), ptrType});
     // get or create function declaration:
     LLVM::LLVMFuncOp funcDecl =
         getOrDefineFunction(moduleOp, loc, rewriter, "MPI_Recv", funcType);
diff --git a/mlir/test/Conversion/MPIToLLVM/ops.mlir b/mlir/test/Conversion/MPIToLLVM/ops.mlir
index 61a4219fe35fc..449e6418976cc 100644
--- a/mlir/test/Conversion/MPIToLLVM/ops.mlir
+++ b/mlir/test/Conversion/MPIToLLVM/ops.mlir
@@ -2,9 +2,9 @@
 
 module {
   // CHECK: llvm.func @MPI_Finalize() -> i32
-  // CHECK: llvm.func @MPI_Recv(!llvm.ptr, i32, i32, i32, i32, i32, !llvm.ptr) -> i32
-  // CHECK: llvm.func @MPI_Send(!llvm.ptr, i32, i32, i32, i32, i32) -> i32
-  // CHECK: llvm.func @MPI_Comm_rank({{.*}}, !llvm.ptr) -> i32
+  // CHECK: llvm.func @MPI_Recv(!llvm.ptr, i32, {{.+}}, i32, i32, {{.+}}, !llvm.ptr) -> i32
+  // CHECK: llvm.func @MPI_Send(!llvm.ptr, i32, {{.+}}, i32, i32, {{.+}}) -> i32
+  // CHECK: llvm.func @MPI_Comm_rank({{.+}}, !llvm.ptr) -> i32
   // COMM: llvm.mlir.global external @MPI_COMM_WORLD() {addr_space = 0 : i32} : !llvm.struct<"MPI_ABI_Comm", opaque>
   // CHECK: llvm.func @MPI_Init(!llvm.ptr, !llvm.ptr) -> i32
 
@@ -21,10 +21,10 @@ module {
     // CHECK-NEXT: [[v8:%.*]] = builtin.unrealized_conversion_cast [[v7]] : i32 to !mpi.retval
     %0 = mpi.init : !mpi.retval
 
-    // CHECK: [[v9:%.*]] = llvm.mlir.constant(1140850688 : i32) : i32
+    // CHECK: [[v9:%.*]] = llvm.mlir.
     // CHECK-NEXT: [[v10:%.*]] = llvm.mlir.constant(1 : i32) : i32
     // CHECK-NEXT: [[v11:%.*]] = llvm.alloca [[v10]] x i32 : (i32) -> !llvm.ptr
-    // CHECK-NEXT: [[v12:%.*]] = llvm.call @MPI_Comm_rank([[v9]], [[v11]]) : (i32, !llvm.ptr) -> i32
+    // CHECK-NEXT: [[v12:%.*]] = llvm.call @MPI_Comm_rank([[v9]], [[v11]]) : ({{.+}}, !llvm.ptr) -> i32
     // CHECK-NEXT: [[v13:%.*]] = builtin.unrealized_conversion_cast [[v12]] : i32 to !mpi.retval
     // CHECK-NEXT: [[v14:%.*]] = llvm.load [[v11]] : !llvm.ptr -> i32
     %retval, %rank = mpi.comm_rank : !mpi.retval, i32
@@ -34,9 +34,9 @@ module {
     // CHECK-NEXT: [[v17:%.*]] = llvm.getelementptr [[v15]][[[v16]]] : (!llvm.ptr, i64) -> !llvm.ptr, f32
     // CHECK-NEXT: [[v18:%.*]] = llvm.extractvalue [[v5]][3, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> 
     // CHECK-NEXT: [[v19:%.*]] = llvm.trunc [[v18]] : i64 to i32
-    // CHECK-NEXT: [[v20:%.*]] = llvm.mlir.constant(1275069450 : i32) : i32
-    // CHECK-NEXT: [[v21:%.*]] = llvm.mlir.constant(1140850688 : i32) : i32
-    // CHECK-NEXT: [[v22:%.*]] = llvm.call @MPI_Send([[v17]], [[v19]], [[v20]], [[v14]], [[v14]], [[v21]]) : (!llvm.ptr, i32, i32, i32, i32, i32) -> i32
+    // CHECK-NEXT: [[v20:%.*]] = llvm.mlir.
+    // CHECK-NEXT: [[v21:%.*]] = llvm.mlir.
+    // CHECK-NEXT: [[v22:%.*]] = llvm.call @MPI_Send([[v17]], [[v19]], [[v20]], [[v14]], [[v14]], [[v21]])
     mpi.send(%arg0, %rank, %rank) : memref<100xf32>, i32, i32
 
     // CHECK: [[v23:%.*]] = llvm.extractvalue [[v5]][1] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> 
@@ -44,9 +44,9 @@ module {
     // CHECK-NEXT: [[v25:%.*]] = llvm.getelementptr [[v23]][[[v24]]] : (!llvm.ptr, i64) -> !llvm.ptr, f32
     // CHECK-NEXT: [[v26:%.*]] = llvm.extractvalue [[v5]][3, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> 
     // CHECK-NEXT: [[v27:%.*]] = llvm.trunc [[v26]] : i64 to i32
-    // CHECK-NEXT: [[v28:%.*]] = llvm.mlir.constant(1275069450 : i32) : i32
-    // CHECK-NEXT: [[v29:%.*]] = llvm.mlir.constant(1140850688 : i32) : i32
-    // CHECK-NEXT: [[v30:%.*]] = llvm.call @MPI_Send([[v25]], [[v27]], [[v28]], [[v14]], [[v14]], [[v29]]) : (!llvm.ptr, i32, i32, i32, i32, i32) -> i32
+    // CHECK-NEXT: [[v28:%.*]] = llvm.mlir.
+    // CHECK-NEXT: [[v29:%.*]] = llvm.mlir.
+    // CHECK-NEXT: [[v30:%.*]] = llvm.call @MPI_Send([[v25]], [[v27]], [[v28]], [[v14]], [[v14]], [[v29]])
     %1 = mpi.send(%arg0, %rank, %rank) : memref<100xf32>, i32, i32 -> !mpi.retval
 
     // CHECK: [[v31:%.*]] = llvm.extractvalue [[v5]][1] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> 
@@ -54,11 +54,11 @@ module {
     // CHECK-NEXT: [[v33:%.*]] = llvm.getelementptr [[v31]][[[v32]]] : (!llvm.ptr, i64) -> !llvm.ptr, f32
     // CHECK-NEXT: [[v34:%.*]] = llvm.extractvalue [[v5]][3, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> 
     // CHECK-NEXT: [[v35:%.*]] = llvm.trunc [[v34]] : i64 to i32
-    // CHECK-NEXT: [[v36:%.*]] = llvm.mlir.constant(1275069450 : i32) : i32
-    // CHECK-NEXT: [[v37:%.*]] = llvm.mlir.constant(1140850688 : i32) : i32
-    // CHECK-NEXT: [[v38:%.*]] = llvm.mlir.constant(1 : i64) : i64
+    // CHECK-NEXT: [[v36:%.*]] = llvm.mlir.
+    // CHECK-NEXT: [[v37:%.*]] = llvm.mlir.
+    // CHECK-NEXT: [[v38:%.*]] = llvm.mlir.constant({{[0-9]+}} : i64) : i64
     // CHECK-NEXT: [[v39:%.*]] = llvm.inttoptr [[v38]] : i64 to !llvm.ptr
-    // CHECK-NEXT: [[v40:%.*]] = llvm.call @MPI_Recv([[v33]], [[v35]], [[v36]], [[v14]], [[v14]], [[v37]], [[v39]]) : (!llvm.ptr, i32, i32, i32, i32, i32, !llvm.ptr) -> i32
+    // CHECK-NEXT: [[v40:%.*]] = llvm.call @MPI_Recv([[v33]], [[v35]], [[v36]], [[v14]], [[v14]], [[v37]], [[v39]])
     mpi.recv(%arg0, %rank, %rank) : memref<100xf32>, i32, i32
 
     // CHECK: [[v41:%.*]] = llvm.extractvalue [[v5]][1] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> 
@@ -66,11 +66,11 @@ module {
     // CHECK-NEXT: [[v43:%.*]] = llvm.getelementptr [[v41]][[[v42]]] : (!llvm.ptr, i64) -> !llvm.ptr, f32
     // CHECK-NEXT: [[v44:%.*]] = llvm.extractvalue [[v5]][3, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> 
     // CHECK-NEXT: [[v45:%.*]] = llvm.trunc [[v44]] : i64 to i32
-    // CHECK-NEXT: [[v46:%.*]] = llvm.mlir.constant(1275069450 : i32) : i32
-    // CHECK-NEXT: [[v47:%.*]] = llvm.mlir.constant(1140850688 : i32) : i32
-    // CHECK-NEXT: [[v48:%.*]] = llvm.mlir.constant(1 : i64) : i64
+    // CHECK-NEXT: [[v46:%.*]] = llvm.mlir.
+    // CHECK-NEXT: [[v47:%.*]] = llvm.mlir.
+    // CHECK-NEXT: [[v48:%.*]] = llvm.mlir.constant({{[0-9]+}} : i64) : i64
     // CHECK-NEXT: [[v49:%.*]] = llvm.inttoptr [[v48]] : i64 to !llvm.ptr
-    // CHECK-NEXT: [[v50:%.*]] = llvm.call @MPI_Recv([[v43]], [[v45]], [[v46]], [[v14]], [[v14]], [[v47]], [[v49]]) : (!llvm.ptr, i32, i32, i32, i32, i32, !llvm.ptr) -> i32
+    // CHECK-NEXT: [[v50:%.*]] = llvm.call @MPI_Recv([[v43]], [[v45]], [[v46]], [[v14]], [[v14]], [[v47]], [[v49]])
     %2 = mpi.recv(%arg0, %rank, %rank) : memref<100xf32>, i32, i32 -> !mpi.retval
 
     // CHECK: llvm.call @MPI_Finalize() : () -> i32

>From ede4372738a7ab1bc59989be0fc8d8814a37faa0 Mon Sep 17 00:00:00 2001
From: "Schlimbach, Frank" <frank.schlimbach at intel.com>
Date: Thu, 13 Feb 2025 16:27:29 +0100
Subject: [PATCH 04/15] provide empty register/populate functions even when
 MPIToLLVM is disabled

---
 mlir/lib/Conversion/MPIToLLVM/CMakeLists.txt | 21 +++++++++-----------
 mlir/lib/Conversion/MPIToLLVM/MPIToLLVM.cpp  | 15 ++++++++++++++
 2 files changed, 24 insertions(+), 12 deletions(-)

diff --git a/mlir/lib/Conversion/MPIToLLVM/CMakeLists.txt b/mlir/lib/Conversion/MPIToLLVM/CMakeLists.txt
index 17df603ff5686..39c5a9d145818 100644
--- a/mlir/lib/Conversion/MPIToLLVM/CMakeLists.txt
+++ b/mlir/lib/Conversion/MPIToLLVM/CMakeLists.txt
@@ -2,13 +2,6 @@ find_path(MPI_C_HEADER_DIR mpi.h
     PATHS $ENV{I_MPI_ROOT}/include
           $ENV{MPI_HOME}/include
           $ENV{MPI_ROOT}/include)
-if(MPI_C_HEADER_DIR)
-  # cmake_path(REMOVE_FILENAME MPI_C_HEADER_DIR)
-  message(STATUS "found MPI_C_HEADER_DIR: ${MPI_C_HEADER_DIR}")
-else()
-  message(WARNING "MPI not found, disabling MLIRMPIToLLVM conversion")
-  return()
-endif()
 
 add_mlir_conversion_library(MLIRMPIToLLVM
   MPIToLLVM.cpp
@@ -27,8 +20,12 @@ add_mlir_conversion_library(MLIRMPIToLLVM
   MLIRLLVMDialect
   MLIRMPIDialect
 )
-target_include_directories(
-    MLIRMPIToLLVM
-    PRIVATE
-    ${MPI_C_HEADER_DIR}
-)
\ No newline at end of file
+
+if(MPI_C_HEADER_DIR)
+  # cmake_path(REMOVE_FILENAME MPI_C_HEADER_DIR)
+  message(STATUS "found MPI_C_HEADER_DIR: ${MPI_C_HEADER_DIR}")
+  target_include_directories(obj.MLIRMPIToLLVM PRIVATE ${MPI_C_HEADER_DIR})
+  target_compile_definitions(obj.MLIRMPIToLLVM PUBLIC FOUND_MPI_C_HEADER=1)
+else()
+  message(WARNING "MPI not found, disabling MLIRMPIToLLVM conversion")
+endif()
diff --git a/mlir/lib/Conversion/MPIToLLVM/MPIToLLVM.cpp b/mlir/lib/Conversion/MPIToLLVM/MPIToLLVM.cpp
index 9980c789cda6e..9f9f0ee078746 100644
--- a/mlir/lib/Conversion/MPIToLLVM/MPIToLLVM.cpp
+++ b/mlir/lib/Conversion/MPIToLLVM/MPIToLLVM.cpp
@@ -6,6 +6,9 @@
 //
 //===----------------------------------------------------------------------===//
 
+// skip if no MPI C header was found
+#ifdef FOUND_MPI_C_HEADER
+
 // This must go first (MPI gets confused otherwise)
 #include "MPIImplTraits.h"
 
@@ -330,3 +333,15 @@ void mlir::mpi::registerConvertMPIToLLVMInterface(DialectRegistry &registry) {
     dialect->addInterfaces<FuncToLLVMDialectInterface>();
   });
 }
+
+#else // FOUND_MPI_C_HEADER
+
+#include "mlir/Conversion/MPIToLLVM/MPIToLLVM.h"
+using namespace mlir;
+
+void mlir::mpi::populateMPIToLLVMConversionPatterns(
+    LLVMTypeConverter &converter, RewritePatternSet &patterns) {}
+
+void mlir::mpi::registerConvertMPIToLLVMInterface(DialectRegistry &registry) {}
+
+#endif // FOUND_MPI_C_HEADER

>From e71bc9b762b6662ddd2ba25d99348360d8eae28c Mon Sep 17 00:00:00 2001
From: "Schlimbach, Frank" <frank.schlimbach at intel.com>
Date: Mon, 17 Feb 2025 12:20:59 +0100
Subject: [PATCH 05/15] fallbacks from MPICH if no mpi.h was found

---
 mlir/lib/Conversion/MPIToLLVM/CMakeLists.txt  |  2 +-
 mlir/lib/Conversion/MPIToLLVM/MPIImplTraits.h | 10 ++++++--
 mlir/lib/Conversion/MPIToLLVM/MPIToLLVM.cpp   | 15 -----------
 mlir/lib/Conversion/MPIToLLVM/mpi_fallback.h  | 25 +++++++++++++++++++
 4 files changed, 34 insertions(+), 18 deletions(-)
 create mode 100644 mlir/lib/Conversion/MPIToLLVM/mpi_fallback.h

diff --git a/mlir/lib/Conversion/MPIToLLVM/CMakeLists.txt b/mlir/lib/Conversion/MPIToLLVM/CMakeLists.txt
index 39c5a9d145818..a0efb9836b80a 100644
--- a/mlir/lib/Conversion/MPIToLLVM/CMakeLists.txt
+++ b/mlir/lib/Conversion/MPIToLLVM/CMakeLists.txt
@@ -27,5 +27,5 @@ if(MPI_C_HEADER_DIR)
   target_include_directories(obj.MLIRMPIToLLVM PRIVATE ${MPI_C_HEADER_DIR})
   target_compile_definitions(obj.MLIRMPIToLLVM PUBLIC FOUND_MPI_C_HEADER=1)
 else()
-  message(WARNING "MPI not found, disabling MLIRMPIToLLVM conversion")
+  message(WARNING "MPI not found, falling back to definitions from MPICH")
 endif()
diff --git a/mlir/lib/Conversion/MPIToLLVM/MPIImplTraits.h b/mlir/lib/Conversion/MPIToLLVM/MPIImplTraits.h
index 09811e1cb7c61..ab4a17206382c 100644
--- a/mlir/lib/Conversion/MPIToLLVM/MPIImplTraits.h
+++ b/mlir/lib/Conversion/MPIToLLVM/MPIImplTraits.h
@@ -3,7 +3,13 @@
 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
 #include "mlir/Dialect/MPI/IR/MPI.h"
 #include "mlir/Transforms/DialectConversion.h"
+
+// use the fallback definitions if no MPI C header was found
+#ifdef FOUND_MPI_C_HEADER
 #include <mpi.h>
+#else // not FOUND_MPI_C_HEADER
+#include "mpi_fallback.h"
+#endif // FOUND_MPI_C_HEADER
 
 namespace {
 
@@ -24,8 +30,8 @@ struct MPIImplTraits {
 };
 
 // ****************************************************************************
-// Intel MPI
-#ifdef IMPI_DEVICE_EXPORT
+// Intel MPI/MPICH
+#if defined(IMPI_DEVICE_EXPORT) || defined(_MPI_FALLBACK_DEFS)
 
 mlir::Value
 MPIImplTraits::getCommWorld(mlir::ModuleOp &moduleOp, const mlir::Location loc,
diff --git a/mlir/lib/Conversion/MPIToLLVM/MPIToLLVM.cpp b/mlir/lib/Conversion/MPIToLLVM/MPIToLLVM.cpp
index 9f9f0ee078746..9980c789cda6e 100644
--- a/mlir/lib/Conversion/MPIToLLVM/MPIToLLVM.cpp
+++ b/mlir/lib/Conversion/MPIToLLVM/MPIToLLVM.cpp
@@ -6,9 +6,6 @@
 //
 //===----------------------------------------------------------------------===//
 
-// skip if no MPI C header was found
-#ifdef FOUND_MPI_C_HEADER
-
 // This must go first (MPI gets confused otherwise)
 #include "MPIImplTraits.h"
 
@@ -333,15 +330,3 @@ void mlir::mpi::registerConvertMPIToLLVMInterface(DialectRegistry &registry) {
     dialect->addInterfaces<FuncToLLVMDialectInterface>();
   });
 }
-
-#else // FOUND_MPI_C_HEADER
-
-#include "mlir/Conversion/MPIToLLVM/MPIToLLVM.h"
-using namespace mlir;
-
-void mlir::mpi::populateMPIToLLVMConversionPatterns(
-    LLVMTypeConverter &converter, RewritePatternSet &patterns) {}
-
-void mlir::mpi::registerConvertMPIToLLVMInterface(DialectRegistry &registry) {}
-
-#endif // FOUND_MPI_C_HEADER
diff --git a/mlir/lib/Conversion/MPIToLLVM/mpi_fallback.h b/mlir/lib/Conversion/MPIToLLVM/mpi_fallback.h
new file mode 100644
index 0000000000000..a4b97f2149339
--- /dev/null
+++ b/mlir/lib/Conversion/MPIToLLVM/mpi_fallback.h
@@ -0,0 +1,25 @@
+/*
+ * Copyright (C) by Argonne National Laboratory
+ *     See COPYRIGHT in top-level directory
+ *     of MPICH source repository.
+ */
+
+typedef int MPI_Comm;
+#define MPI_COMM_WORLD ((MPI_Comm)0x44000000)
+
+typedef int MPI_Datatype;
+#define MPI_FLOAT ((MPI_Datatype)0x4c00040a)
+#define MPI_DOUBLE ((MPI_Datatype)0x4c00080b)
+#define MPI_INT8_T ((MPI_Datatype)0x4c000137)
+#define MPI_INT16_T ((MPI_Datatype)0x4c000238)
+#define MPI_INT32_T ((MPI_Datatype)0x4c000439)
+#define MPI_INT64_T ((MPI_Datatype)0x4c00083a)
+#define MPI_UINT8_T ((MPI_Datatype)0x4c00013b)
+#define MPI_UINT16_T ((MPI_Datatype)0x4c00023c)
+#define MPI_UINT32_T ((MPI_Datatype)0x4c00043d)
+#define MPI_UINT64_T ((MPI_Datatype)0x4c00083e)
+
+typedef struct MPI_Status;
+#define MPI_STATUS_IGNORE (MPI_Status *)1
+
+#define _MPI_FALLBACK_DEFS 1

>From 6d59dd681f8883431463db9a30e54eb0e4a6f018 Mon Sep 17 00:00:00 2001
From: "Schlimbach, Frank" <frank.schlimbach at intel.com>
Date: Mon, 17 Feb 2025 21:24:30 +0100
Subject: [PATCH 06/15] formatting and cleanup (review comments)

---
 .../mlir/Conversion/MPIToLLVM/MPIToLLVM.h     |   4 +-
 mlir/lib/Conversion/MPIToLLVM/MPIImplTraits.h |   4 +-
 mlir/lib/Conversion/MPIToLLVM/MPIToLLVM.cpp   |   1 -
 mlir/lib/Conversion/MPIToLLVM/mpi_fallback.h  |   8 +
 mlir/test/Conversion/MPIToLLVM/ops.mlir       | 145 +++++++++---------
 5 files changed, 83 insertions(+), 79 deletions(-)

diff --git a/mlir/include/mlir/Conversion/MPIToLLVM/MPIToLLVM.h b/mlir/include/mlir/Conversion/MPIToLLVM/MPIToLLVM.h
index 8d2698aa91c7c..3bae3a5f22248 100644
--- a/mlir/include/mlir/Conversion/MPIToLLVM/MPIToLLVM.h
+++ b/mlir/include/mlir/Conversion/MPIToLLVM/MPIToLLVM.h
@@ -1,3 +1,4 @@
+//===----------------------------------------------------------------------===//
 //
 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
 // See https://llvm.org/LICENSE.txt for license information.
@@ -15,9 +16,6 @@ namespace mlir {
 class LLVMTypeConverter;
 class RewritePatternSet;
 
-#define GEN_PASS_DECL_MPITOLLVMCONVERSIONPASS
-#include "mlir/Conversion/Passes.h.inc"
-
 namespace mpi {
 
 void populateMPIToLLVMConversionPatterns(LLVMTypeConverter &converter,
diff --git a/mlir/lib/Conversion/MPIToLLVM/MPIImplTraits.h b/mlir/lib/Conversion/MPIToLLVM/MPIImplTraits.h
index ab4a17206382c..829bc0d8b4d61 100644
--- a/mlir/lib/Conversion/MPIToLLVM/MPIImplTraits.h
+++ b/mlir/lib/Conversion/MPIToLLVM/MPIImplTraits.h
@@ -99,7 +99,7 @@ MPIImplTraits::getCommWorld(mlir::ModuleOp &moduleOp, const mlir::Location loc,
   // get external opaque struct pointer type
   auto commStructT =
       mlir::LLVM::LLVMStructType::getOpaque("ompi_communicator_t", context);
-  const char *name = "ompi_mpi_comm_world";
+  mlir::StringRef name = "ompi_mpi_comm_world";
 
   // make sure global op definition exists
   (void)getOrDefineExternalStruct(moduleOp, loc, rewriter, name, commStructT);
@@ -114,7 +114,7 @@ mlir::Value
 MPIImplTraits::getDataType(mlir::ModuleOp &moduleOp, const mlir::Location loc,
                            mlir::ConversionPatternRewriter &rewriter,
                            mlir::Type type) {
-  const char *mtype = nullptr;
+  mlir::StringRef mtype = nullptr;
   if (type.isF32())
     mtype = "ompi_mpi_float";
   else if (type.isF64())
diff --git a/mlir/lib/Conversion/MPIToLLVM/MPIToLLVM.cpp b/mlir/lib/Conversion/MPIToLLVM/MPIToLLVM.cpp
index 9980c789cda6e..e0050538316b2 100644
--- a/mlir/lib/Conversion/MPIToLLVM/MPIToLLVM.cpp
+++ b/mlir/lib/Conversion/MPIToLLVM/MPIToLLVM.cpp
@@ -13,7 +13,6 @@
 #include "mlir/Conversion/MPIToLLVM/MPIToLLVM.h"
 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
 #include "mlir/Dialect/MPI/IR/MPI.h"
-#include "mlir/Pass/Pass.h"
 
 #include <mlir/Conversion/ConvertToLLVM/ToLLVMInterface.h>
 
diff --git a/mlir/lib/Conversion/MPIToLLVM/mpi_fallback.h b/mlir/lib/Conversion/MPIToLLVM/mpi_fallback.h
index a4b97f2149339..e383a27c4e382 100644
--- a/mlir/lib/Conversion/MPIToLLVM/mpi_fallback.h
+++ b/mlir/lib/Conversion/MPIToLLVM/mpi_fallback.h
@@ -1,3 +1,11 @@
+//===- mpi_fallback.h - Fallback MPI definitions --------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
 /*
  * Copyright (C) by Argonne National Laboratory
  *     See COPYRIGHT in top-level directory
diff --git a/mlir/test/Conversion/MPIToLLVM/ops.mlir b/mlir/test/Conversion/MPIToLLVM/ops.mlir
index 449e6418976cc..33fb928802afd 100644
--- a/mlir/test/Conversion/MPIToLLVM/ops.mlir
+++ b/mlir/test/Conversion/MPIToLLVM/ops.mlir
@@ -1,84 +1,83 @@
 // RUN: mlir-opt -convert-to-llvm %s | FileCheck %s
 
-module {
-  // CHECK: llvm.func @MPI_Finalize() -> i32
-  // CHECK: llvm.func @MPI_Recv(!llvm.ptr, i32, {{.+}}, i32, i32, {{.+}}, !llvm.ptr) -> i32
-  // CHECK: llvm.func @MPI_Send(!llvm.ptr, i32, {{.+}}, i32, i32, {{.+}}) -> i32
-  // CHECK: llvm.func @MPI_Comm_rank({{.+}}, !llvm.ptr) -> i32
-  // COMM: llvm.mlir.global external @MPI_COMM_WORLD() {addr_space = 0 : i32} : !llvm.struct<"MPI_ABI_Comm", opaque>
-  // CHECK: llvm.func @MPI_Init(!llvm.ptr, !llvm.ptr) -> i32
+// CHECK: llvm.func @MPI_Finalize() -> i32
+// CHECK: llvm.func @MPI_Recv(!llvm.ptr, i32, {{.+}}, i32, i32, {{.+}}, !llvm.ptr) -> i32
+// CHECK: llvm.func @MPI_Send(!llvm.ptr, i32, {{.+}}, i32, i32, {{.+}}) -> i32
+// CHECK: llvm.func @MPI_Comm_rank({{.+}}, !llvm.ptr) -> i32
+// COMM: llvm.mlir.global external @MPI_COMM_WORLD() {addr_space = 0 : i32} : !llvm.struct<"MPI_ABI_Comm", opaque>
+// CHECK: llvm.func @MPI_Init(!llvm.ptr, !llvm.ptr) -> i32
 
-  func.func @mpi_test(%arg0: memref<100xf32>) {
-    // CHECK: [[varg0:%.*]]: !llvm.ptr, [[varg1:%.*]]: !llvm.ptr, [[varg2:%.*]]: i64, [[varg3:%.*]]: i64, [[varg4:%.*]]: i64
-    // CHECK: [[v0:%.*]] = llvm.mlir.poison : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
-    // CHECK-NEXT: [[v1:%.*]] = llvm.insertvalue [[varg0]], [[v0]][0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> 
-    // CHECK-NEXT: [[v2:%.*]] = llvm.insertvalue [[varg1]], [[v1]][1] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> 
-    // CHECK-NEXT: [[v3:%.*]] = llvm.insertvalue [[varg2]], [[v2]][2] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> 
-    // CHECK-NEXT: [[v4:%.*]] = llvm.insertvalue [[varg3]], [[v3]][3, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> 
-    // CHECK-NEXT: [[v5:%.*]] = llvm.insertvalue [[varg4]], [[v4]][4, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> 
-    // CHECK-NEXT: [[v6:%.*]] = llvm.mlir.zero : !llvm.ptr
-    // CHECK-NEXT: [[v7:%.*]] = llvm.call @MPI_Init([[v6]], [[v6]]) : (!llvm.ptr, !llvm.ptr) -> i32
-    // CHECK-NEXT: [[v8:%.*]] = builtin.unrealized_conversion_cast [[v7]] : i32 to !mpi.retval
-    %0 = mpi.init : !mpi.retval
+func.func @mpi_test(%arg0: memref<100xf32>) {
+  // CHECK: [[varg0:%.*]]: !llvm.ptr, [[varg1:%.*]]: !llvm.ptr, [[varg2:%.*]]: i64, [[varg3:%.*]]: i64, [[varg4:%.*]]: i64
+  // CHECK: [[v0:%.*]] = llvm.mlir.poison : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
+  // CHECK: [[v1:%.*]] = llvm.insertvalue [[varg0]], [[v0]][0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> 
+  // CHECK: [[v2:%.*]] = llvm.insertvalue [[varg1]], [[v1]][1] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> 
+  // CHECK: [[v3:%.*]] = llvm.insertvalue [[varg2]], [[v2]][2] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> 
+  // CHECK: [[v4:%.*]] = llvm.insertvalue [[varg3]], [[v3]][3, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> 
+  // CHECK: [[v5:%.*]] = llvm.insertvalue [[varg4]], [[v4]][4, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> 
+  // CHECK: [[v6:%.*]] = llvm.mlir.zero : !llvm.ptr
+  // CHECK: [[v7:%.*]] = llvm.call @MPI_Init([[v6]], [[v6]]) : (!llvm.ptr, !llvm.ptr) -> i32
+  // CHECK: [[v8:%.*]] = builtin.unrealized_conversion_cast [[v7]] : i32 to !mpi.retval
+  %0 = mpi.init : !mpi.retval
 
-    // CHECK: [[v9:%.*]] = llvm.mlir.
-    // CHECK-NEXT: [[v10:%.*]] = llvm.mlir.constant(1 : i32) : i32
-    // CHECK-NEXT: [[v11:%.*]] = llvm.alloca [[v10]] x i32 : (i32) -> !llvm.ptr
-    // CHECK-NEXT: [[v12:%.*]] = llvm.call @MPI_Comm_rank([[v9]], [[v11]]) : ({{.+}}, !llvm.ptr) -> i32
-    // CHECK-NEXT: [[v13:%.*]] = builtin.unrealized_conversion_cast [[v12]] : i32 to !mpi.retval
-    // CHECK-NEXT: [[v14:%.*]] = llvm.load [[v11]] : !llvm.ptr -> i32
-    %retval, %rank = mpi.comm_rank : !mpi.retval, i32
+  // CHECK: [[v9:%.*]] = llvm.mlir.
+  // CHECK: [[v10:%.*]] = llvm.mlir.constant(1 : i32) : i32
+  // CHECK: [[v11:%.*]] = llvm.alloca [[v10]] x i32 : (i32) -> !llvm.ptr
+  // CHECK: [[v12:%.*]] = llvm.call @MPI_Comm_rank([[v9]], [[v11]]) : ({{.+}}, !llvm.ptr) -> i32
+  // CHECK: [[v13:%.*]] = builtin.unrealized_conversion_cast [[v12]] : i32 to !mpi.retval
+  // CHECK: [[v14:%.*]] = llvm.load [[v11]] : !llvm.ptr -> i32
+  %retval, %rank = mpi.comm_rank : !mpi.retval, i32
 
-    // CHECK: [[v15:%.*]] = llvm.extractvalue [[v5]][1] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> 
-    // CHECK-NEXT: [[v16:%.*]] = llvm.extractvalue [[v5]][2] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> 
-    // CHECK-NEXT: [[v17:%.*]] = llvm.getelementptr [[v15]][[[v16]]] : (!llvm.ptr, i64) -> !llvm.ptr, f32
-    // CHECK-NEXT: [[v18:%.*]] = llvm.extractvalue [[v5]][3, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> 
-    // CHECK-NEXT: [[v19:%.*]] = llvm.trunc [[v18]] : i64 to i32
-    // CHECK-NEXT: [[v20:%.*]] = llvm.mlir.
-    // CHECK-NEXT: [[v21:%.*]] = llvm.mlir.
-    // CHECK-NEXT: [[v22:%.*]] = llvm.call @MPI_Send([[v17]], [[v19]], [[v20]], [[v14]], [[v14]], [[v21]])
-    mpi.send(%arg0, %rank, %rank) : memref<100xf32>, i32, i32
+  // CHECK: [[v15:%.*]] = llvm.extractvalue [[v5]][1] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> 
+  // CHECK: [[v16:%.*]] = llvm.extractvalue [[v5]][2] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> 
+  // CHECK: [[v17:%.*]] = llvm.getelementptr [[v15]][[[v16]]] : (!llvm.ptr, i64) -> !llvm.ptr, f32
+  // CHECK: [[v18:%.*]] = llvm.extractvalue [[v5]][3, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> 
+  // CHECK: [[v19:%.*]] = llvm.trunc [[v18]] : i64 to i32
+  // CHECK: [[v20:%.*]] = llvm.mlir.
+  // CHECK: [[v21:%.*]] = llvm.mlir.
+  // CHECK: [[v22:%.*]] = llvm.call @MPI_Send([[v17]], [[v19]], [[v20]], [[v14]], [[v14]], [[v21]])
+  mpi.send(%arg0, %rank, %rank) : memref<100xf32>, i32, i32
 
-    // CHECK: [[v23:%.*]] = llvm.extractvalue [[v5]][1] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> 
-    // CHECK-NEXT: [[v24:%.*]] = llvm.extractvalue [[v5]][2] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> 
-    // CHECK-NEXT: [[v25:%.*]] = llvm.getelementptr [[v23]][[[v24]]] : (!llvm.ptr, i64) -> !llvm.ptr, f32
-    // CHECK-NEXT: [[v26:%.*]] = llvm.extractvalue [[v5]][3, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> 
-    // CHECK-NEXT: [[v27:%.*]] = llvm.trunc [[v26]] : i64 to i32
-    // CHECK-NEXT: [[v28:%.*]] = llvm.mlir.
-    // CHECK-NEXT: [[v29:%.*]] = llvm.mlir.
-    // CHECK-NEXT: [[v30:%.*]] = llvm.call @MPI_Send([[v25]], [[v27]], [[v28]], [[v14]], [[v14]], [[v29]])
-    %1 = mpi.send(%arg0, %rank, %rank) : memref<100xf32>, i32, i32 -> !mpi.retval
+  // CHECK: [[v23:%.*]] = llvm.extractvalue [[v5]][1] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> 
+  // CHECK: [[v24:%.*]] = llvm.extractvalue [[v5]][2] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> 
+  // CHECK: [[v25:%.*]] = llvm.getelementptr [[v23]][[[v24]]] : (!llvm.ptr, i64) -> !llvm.ptr, f32
+  // CHECK: [[v26:%.*]] = llvm.extractvalue [[v5]][3, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> 
+  // CHECK: [[v27:%.*]] = llvm.trunc [[v26]] : i64 to i32
+  // CHECK: [[v28:%.*]] = llvm.mlir.
+  // CHECK: [[v29:%.*]] = llvm.mlir.
+  // CHECK: [[v30:%.*]] = llvm.call @MPI_Send([[v25]], [[v27]], [[v28]], [[v14]], [[v14]], [[v29]])
+  %1 = mpi.send(%arg0, %rank, %rank) : memref<100xf32>, i32, i32 -> !mpi.retval
 
-    // CHECK: [[v31:%.*]] = llvm.extractvalue [[v5]][1] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> 
-    // CHECK-NEXT: [[v32:%.*]] = llvm.extractvalue [[v5]][2] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> 
-    // CHECK-NEXT: [[v33:%.*]] = llvm.getelementptr [[v31]][[[v32]]] : (!llvm.ptr, i64) -> !llvm.ptr, f32
-    // CHECK-NEXT: [[v34:%.*]] = llvm.extractvalue [[v5]][3, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> 
-    // CHECK-NEXT: [[v35:%.*]] = llvm.trunc [[v34]] : i64 to i32
-    // CHECK-NEXT: [[v36:%.*]] = llvm.mlir.
-    // CHECK-NEXT: [[v37:%.*]] = llvm.mlir.
-    // CHECK-NEXT: [[v38:%.*]] = llvm.mlir.constant({{[0-9]+}} : i64) : i64
-    // CHECK-NEXT: [[v39:%.*]] = llvm.inttoptr [[v38]] : i64 to !llvm.ptr
-    // CHECK-NEXT: [[v40:%.*]] = llvm.call @MPI_Recv([[v33]], [[v35]], [[v36]], [[v14]], [[v14]], [[v37]], [[v39]])
-    mpi.recv(%arg0, %rank, %rank) : memref<100xf32>, i32, i32
+  // CHECK: [[v31:%.*]] = llvm.extractvalue [[v5]][1] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> 
+  // CHECK: [[v32:%.*]] = llvm.extractvalue [[v5]][2] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> 
+  // CHECK: [[v33:%.*]] = llvm.getelementptr [[v31]][[[v32]]] : (!llvm.ptr, i64) -> !llvm.ptr, f32
+  // CHECK: [[v34:%.*]] = llvm.extractvalue [[v5]][3, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> 
+  // CHECK: [[v35:%.*]] = llvm.trunc [[v34]] : i64 to i32
+  // CHECK: [[v36:%.*]] = llvm.mlir.
+  // CHECK: [[v37:%.*]] = llvm.mlir.
+  // CHECK: [[v38:%.*]] = llvm.mlir.constant({{[0-9]+}} : i64) : i64
+  // CHECK: [[v39:%.*]] = llvm.inttoptr [[v38]] : i64 to !llvm.ptr
+  // CHECK: [[v40:%.*]] = llvm.call @MPI_Recv([[v33]], [[v35]], [[v36]], [[v14]], [[v14]], [[v37]], [[v39]])
+  mpi.recv(%arg0, %rank, %rank) : memref<100xf32>, i32, i32
 
-    // CHECK: [[v41:%.*]] = llvm.extractvalue [[v5]][1] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> 
-    // CHECK-NEXT: [[v42:%.*]] = llvm.extractvalue [[v5]][2] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> 
-    // CHECK-NEXT: [[v43:%.*]] = llvm.getelementptr [[v41]][[[v42]]] : (!llvm.ptr, i64) -> !llvm.ptr, f32
-    // CHECK-NEXT: [[v44:%.*]] = llvm.extractvalue [[v5]][3, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> 
-    // CHECK-NEXT: [[v45:%.*]] = llvm.trunc [[v44]] : i64 to i32
-    // CHECK-NEXT: [[v46:%.*]] = llvm.mlir.
-    // CHECK-NEXT: [[v47:%.*]] = llvm.mlir.
-    // CHECK-NEXT: [[v48:%.*]] = llvm.mlir.constant({{[0-9]+}} : i64) : i64
-    // CHECK-NEXT: [[v49:%.*]] = llvm.inttoptr [[v48]] : i64 to !llvm.ptr
-    // CHECK-NEXT: [[v50:%.*]] = llvm.call @MPI_Recv([[v43]], [[v45]], [[v46]], [[v14]], [[v14]], [[v47]], [[v49]])
-    %2 = mpi.recv(%arg0, %rank, %rank) : memref<100xf32>, i32, i32 -> !mpi.retval
+  // CHECK: [[v41:%.*]] = llvm.extractvalue [[v5]][1] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> 
+  // CHECK: [[v42:%.*]] = llvm.extractvalue [[v5]][2] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> 
+  // CHECK: [[v43:%.*]] = llvm.getelementptr [[v41]][[[v42]]] : (!llvm.ptr, i64) -> !llvm.ptr, f32
+  // CHECK: [[v44:%.*]] = llvm.extractvalue [[v5]][3, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> 
+  // CHECK: [[v45:%.*]] = llvm.trunc [[v44]] : i64 to i32
+  // CHECK: [[v46:%.*]] = llvm.mlir.
+  // CHECK: [[v47:%.*]] = llvm.mlir.
+  // CHECK: [[v48:%.*]] = llvm.mlir.constant({{[0-9]+}} : i64) : i64
+  // CHECK: [[v49:%.*]] = llvm.inttoptr [[v48]] : i64 to !llvm.ptr
+  // CHECK: [[v50:%.*]] = llvm.call @MPI_Recv([[v43]], [[v45]], [[v46]], [[v14]], [[v14]], [[v47]], [[v49]])
+  %2 = mpi.recv(%arg0, %rank, %rank) : memref<100xf32>, i32, i32 -> !mpi.retval
 
-    // CHECK: llvm.call @MPI_Finalize() : () -> i32
-    %3 = mpi.finalize : !mpi.retval
+  // CHECK: llvm.call @MPI_Finalize() : () -> i32
+  %3 = mpi.finalize : !mpi.retval
 
-    %4 = mpi.retval_check %retval = <MPI_SUCCESS> : i1
-
-    %5 = mpi.error_class %0 : !mpi.retval
-    return
-  }
+  // CHECK: mpi.retval_check
+  %4 = mpi.retval_check %retval = <MPI_SUCCESS> : i1
+  // CHECK: mpi.error_class
+  %5 = mpi.error_class %0 : !mpi.retval
+  return
 }

>From c7744b0c04601b94b567a10d3ec20182a9279cbc Mon Sep 17 00:00:00 2001
From: "Schlimbach, Frank" <frank.schlimbach at intel.com>
Date: Tue, 18 Feb 2025 11:48:16 +0100
Subject: [PATCH 07/15] inlining MPIImplTraits and fallback into cpp file,
 minor mods addressing review comments

---
 mlir/lib/Conversion/MPIToLLVM/MPIImplTraits.h | 157 ----------
 mlir/lib/Conversion/MPIToLLVM/MPIToLLVM.cpp   | 294 ++++++++++++++----
 mlir/lib/Conversion/MPIToLLVM/mpi_fallback.h  |  33 --
 mlir/test/Conversion/MPIToLLVM/ops.mlir       |   6 -
 4 files changed, 226 insertions(+), 264 deletions(-)
 delete mode 100644 mlir/lib/Conversion/MPIToLLVM/MPIImplTraits.h
 delete mode 100644 mlir/lib/Conversion/MPIToLLVM/mpi_fallback.h

diff --git a/mlir/lib/Conversion/MPIToLLVM/MPIImplTraits.h b/mlir/lib/Conversion/MPIToLLVM/MPIImplTraits.h
deleted file mode 100644
index 829bc0d8b4d61..0000000000000
--- a/mlir/lib/Conversion/MPIToLLVM/MPIImplTraits.h
+++ /dev/null
@@ -1,157 +0,0 @@
-#define MPICH_SKIP_MPICXX 1
-#define OMPI_SKIP_MPICXX 1
-#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
-#include "mlir/Dialect/MPI/IR/MPI.h"
-#include "mlir/Transforms/DialectConversion.h"
-
-// skip if no MPI C header was found
-#ifdef FOUND_MPI_C_HEADER
-#include <mpi.h>
-#else // not FOUND_MPI_C_HEADER
-#include "mpi_fallback.h"
-#endif // FOUND_MPI_C_HEADER
-
-namespace {
-
-// when lowerring the mpi dialect to functions calls certain details
-// differ between various MPI implementations. This class will provide
-// these depending on the MPI implementation that got included.
-struct MPIImplTraits {
-  // get/create MPI_COMM_WORLD as a mlir::Value
-  static mlir::Value getCommWorld(mlir::ModuleOp &moduleOp,
-                                  const mlir::Location loc,
-                                  mlir::ConversionPatternRewriter &rewriter);
-  // get/create MPI datatype as a mlir::Value which corresponds to the given
-  // mlir::Type
-  static mlir::Value getDataType(mlir::ModuleOp &moduleOp,
-                                 const mlir::Location loc,
-                                 mlir::ConversionPatternRewriter &rewriter,
-                                 mlir::Type type);
-};
-
-// ****************************************************************************
-// Intel MPI/MPICH
-#if defined(IMPI_DEVICE_EXPORT) || defined(_MPI_FALLBACK_DEFS)
-
-mlir::Value
-MPIImplTraits::getCommWorld(mlir::ModuleOp &moduleOp, const mlir::Location loc,
-                            mlir::ConversionPatternRewriter &rewriter) {
-  return rewriter.create<mlir::LLVM::ConstantOp>(loc, rewriter.getI32Type(),
-                                                 MPI_COMM_WORLD);
-}
-
-mlir::Value
-MPIImplTraits::getDataType(mlir::ModuleOp &moduleOp, const mlir::Location loc,
-                           mlir::ConversionPatternRewriter &rewriter,
-                           mlir::Type type) {
-  int32_t mtype = 0;
-  if (type.isF32())
-    mtype = MPI_FLOAT;
-  else if (type.isF64())
-    mtype = MPI_DOUBLE;
-  else if (type.isInteger(64) && !type.isUnsignedInteger())
-    mtype = MPI_INT64_T;
-  else if (type.isInteger(64))
-    mtype = MPI_UINT64_T;
-  else if (type.isInteger(32) && !type.isUnsignedInteger())
-    mtype = MPI_INT32_T;
-  else if (type.isInteger(32))
-    mtype = MPI_UINT32_T;
-  else if (type.isInteger(16) && !type.isUnsignedInteger())
-    mtype = MPI_INT16_T;
-  else if (type.isInteger(16))
-    mtype = MPI_UINT16_T;
-  else if (type.isInteger(8) && !type.isUnsignedInteger())
-    mtype = MPI_INT8_T;
-  else if (type.isInteger(8))
-    mtype = MPI_UINT8_T;
-  else
-    assert(false && "unsupported type");
-  return rewriter.create<mlir::LLVM::ConstantOp>(loc, rewriter.getI32Type(),
-                                                 mtype);
-}
-
-// ****************************************************************************
-// OpenMPI
-#elif defined(OPEN_MPI) && OPEN_MPI == 1
-
-// TODO: this is pretty close to getOrDefineFunction, can probably be factored
-static mlir::LLVM::GlobalOp
-getOrDefineExternalStruct(mlir::ModuleOp &moduleOp, const mlir::Location loc,
-                          mlir::ConversionPatternRewriter &rewriter,
-                          mlir::StringRef name,
-                          mlir::LLVM::LLVMStructType type) {
-  mlir::LLVM::GlobalOp ret;
-  if (!(ret = moduleOp.lookupSymbol<mlir::LLVM::GlobalOp>(name))) {
-    mlir::ConversionPatternRewriter::InsertionGuard guard(rewriter);
-    rewriter.setInsertionPointToStart(moduleOp.getBody());
-    ret = rewriter.create<mlir::LLVM::GlobalOp>(
-        loc, type, /*isConstant=*/false, mlir::LLVM::Linkage::External, name,
-        /*value=*/mlir::Attribute(), /*alignment=*/0, 0);
-  }
-  return ret;
-}
-
-mlir::Value
-MPIImplTraits::getCommWorld(mlir::ModuleOp &moduleOp, const mlir::Location loc,
-                            mlir::ConversionPatternRewriter &rewriter) {
-  auto context = rewriter.getContext();
-  // get external opaque struct pointer type
-  auto commStructT =
-      mlir::LLVM::LLVMStructType::getOpaque("ompi_communicator_t", context);
-  mlir::StringRef name = "ompi_mpi_comm_world";
-
-  // make sure global op definition exists
-  (void)getOrDefineExternalStruct(moduleOp, loc, rewriter, name, commStructT);
-
-  // get address of symbol
-  return rewriter.create<mlir::LLVM::AddressOfOp>(
-      loc, mlir::LLVM::LLVMPointerType::get(context),
-      mlir::SymbolRefAttr::get(context, name));
-}
-
-mlir::Value
-MPIImplTraits::getDataType(mlir::ModuleOp &moduleOp, const mlir::Location loc,
-                           mlir::ConversionPatternRewriter &rewriter,
-                           mlir::Type type) {
-  mlir::StringRef mtype = nullptr;
-  if (type.isF32())
-    mtype = "ompi_mpi_float";
-  else if (type.isF64())
-    mtype = "ompi_mpi_double";
-  else if (type.isInteger(64) && !type.isUnsignedInteger())
-    mtype = "ompi_mpi_int64_t";
-  else if (type.isInteger(64))
-    mtype = "ompi_mpi_uint64_t";
-  else if (type.isInteger(32) && !type.isUnsignedInteger())
-    mtype = "ompi_mpi_int32_t";
-  else if (type.isInteger(32))
-    mtype = "ompi_mpi_uint32_t";
-  else if (type.isInteger(16) && !type.isUnsignedInteger())
-    mtype = "ompi_mpi_int16_t";
-  else if (type.isInteger(16))
-    mtype = "ompi_mpi_uint16_t";
-  else if (type.isInteger(8) && !type.isUnsignedInteger())
-    mtype = "ompi_mpi_int8_t";
-  else if (type.isInteger(8))
-    mtype = "ompi_mpi_uint8_t";
-  else
-    assert(false && "unsupported type");
-
-  auto context = rewriter.getContext();
-  // get external opaque struct pointer type
-  auto commStructT = mlir::LLVM::LLVMStructType::getOpaque(
-      "ompi_predefined_datatype_t", context);
-  // make sure global op definition exists
-  (void)getOrDefineExternalStruct(moduleOp, loc, rewriter, mtype, commStructT);
-  // get address of symbol
-  return rewriter.create<mlir::LLVM::AddressOfOp>(
-      loc, mlir::LLVM::LLVMPointerType::get(context),
-      mlir::SymbolRefAttr::get(context, mtype));
-}
-
-#else
-#error "Unsupported MPI implementation"
-#endif
-
-} // namespace
diff --git a/mlir/lib/Conversion/MPIToLLVM/MPIToLLVM.cpp b/mlir/lib/Conversion/MPIToLLVM/MPIToLLVM.cpp
index e0050538316b2..98184d4cad5cc 100644
--- a/mlir/lib/Conversion/MPIToLLVM/MPIToLLVM.cpp
+++ b/mlir/lib/Conversion/MPIToLLVM/MPIToLLVM.cpp
@@ -6,36 +6,207 @@
 //
 //===----------------------------------------------------------------------===//
 
+#define MPICH_SKIP_MPICXX 1
+#define OMPI_SKIP_MPICXX 1
+#ifdef FOUND_MPI_C_HEADER
 // This must go first (MPI gets confused otherwise)
-#include "MPIImplTraits.h"
-
+#include <mpi.h>
+#else // not FOUND_MPI_C_HEADER
+//
+// Copyright (C) by Argonne National Laboratory
+//    See COPYRIGHT in top-level directory
+//    of MPICH source repository.
+//
+typedef int MPI_Comm;
+#define MPI_COMM_WORLD ((MPI_Comm)0x44000000)
+
+typedef int MPI_Datatype;
+#define MPI_FLOAT ((MPI_Datatype)0x4c00040a)
+#define MPI_DOUBLE ((MPI_Datatype)0x4c00080b)
+#define MPI_INT8_T ((MPI_Datatype)0x4c000137)
+#define MPI_INT16_T ((MPI_Datatype)0x4c000238)
+#define MPI_INT32_T ((MPI_Datatype)0x4c000439)
+#define MPI_INT64_T ((MPI_Datatype)0x4c00083a)
+#define MPI_UINT8_T ((MPI_Datatype)0x4c00013b)
+#define MPI_UINT16_T ((MPI_Datatype)0x4c00023c)
+#define MPI_UINT32_T ((MPI_Datatype)0x4c00043d)
+#define MPI_UINT64_T ((MPI_Datatype)0x4c00083e)
+
+typedef struct MPI_Status;
+#define MPI_STATUS_IGNORE (MPI_Status *)1
+
+#define _MPI_FALLBACK_DEFS 1
+#endif // FOUND_MPI_C_HEADER
+
+#include "mlir/Conversion/ConvertToLLVM/ToLLVMInterface.h"
 #include "mlir/Conversion/LLVMCommon/Pattern.h"
 #include "mlir/Conversion/MPIToLLVM/MPIToLLVM.h"
 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
 #include "mlir/Dialect/MPI/IR/MPI.h"
-
-#include <mlir/Conversion/ConvertToLLVM/ToLLVMInterface.h>
+#include "mlir/Transforms/DialectConversion.h"
 
 using namespace mlir;
 
-// TODO: this was copied from GPUOpsLowering.cpp:288
-// is this okay, or should this be moved to some common file?
+namespace {
+
+template <typename Op, typename... Args>
+static Op getOrDefineGlobal(mlir::ModuleOp &moduleOp, const Location loc,
+                            ConversionPatternRewriter &rewriter, StringRef name,
+                            Args &&...args) {
+  Op ret;
+  if (!(ret = moduleOp.lookupSymbol<Op>(name))) {
+    ConversionPatternRewriter::InsertionGuard guard(rewriter);
+    rewriter.setInsertionPointToStart(moduleOp.getBody());
+    ret = rewriter.template create<Op>(loc, std::forward<Args>(args)...);
+  }
+  return ret;
+}
+
 static LLVM::LLVMFuncOp getOrDefineFunction(ModuleOp &moduleOp,
                                             const Location loc,
                                             ConversionPatternRewriter &rewriter,
                                             StringRef name,
                                             LLVM::LLVMFunctionType type) {
-  LLVM::LLVMFuncOp ret;
-  if (!(ret = moduleOp.lookupSymbol<LLVM::LLVMFuncOp>(name))) {
-    ConversionPatternRewriter::InsertionGuard guard(rewriter);
-    rewriter.setInsertionPointToStart(moduleOp.getBody());
-    ret = rewriter.create<LLVM::LLVMFuncOp>(loc, name, type,
-                                            LLVM::Linkage::External);
-  }
-  return ret;
+  return getOrDefineGlobal<LLVM::LLVMFuncOp>(
+      moduleOp, loc, rewriter, name, name, type, LLVM::Linkage::External);
 }
 
-namespace {
+// ****************************************************************************
+// When lowering the mpi dialect to function calls, certain details
+// differ between various MPI implementations. This class will provide
+// these depending on the MPI implementation that got included.
+struct MPIImplTraits {
+  // get/create MPI_COMM_WORLD as a mlir::Value
+  static mlir::Value getCommWorld(mlir::ModuleOp &moduleOp,
+                                  const mlir::Location loc,
+                                  mlir::ConversionPatternRewriter &rewriter);
+  // get/create MPI datatype as a mlir::Value which corresponds to the given
+  // mlir::Type
+  static mlir::Value getDataType(mlir::ModuleOp &moduleOp,
+                                 const mlir::Location loc,
+                                 mlir::ConversionPatternRewriter &rewriter,
+                                 mlir::Type type);
+};
+
+// ****************************************************************************
+// Intel MPI/MPICH
+#if defined(IMPI_DEVICE_EXPORT) || defined(_MPI_FALLBACK_DEFS)
+
+mlir::Value
+MPIImplTraits::getCommWorld(mlir::ModuleOp &moduleOp, const mlir::Location loc,
+                            mlir::ConversionPatternRewriter &rewriter) {
+  return rewriter.create<mlir::LLVM::ConstantOp>(loc, rewriter.getI32Type(),
+                                                 MPI_COMM_WORLD);
+}
+
+mlir::Value
+MPIImplTraits::getDataType(mlir::ModuleOp &moduleOp, const mlir::Location loc,
+                           mlir::ConversionPatternRewriter &rewriter,
+                           mlir::Type type) {
+  int32_t mtype = 0;
+  if (type.isF32())
+    mtype = MPI_FLOAT;
+  else if (type.isF64())
+    mtype = MPI_DOUBLE;
+  else if (type.isInteger(64) && !type.isUnsignedInteger())
+    mtype = MPI_INT64_T;
+  else if (type.isInteger(64))
+    mtype = MPI_UINT64_T;
+  else if (type.isInteger(32) && !type.isUnsignedInteger())
+    mtype = MPI_INT32_T;
+  else if (type.isInteger(32))
+    mtype = MPI_UINT32_T;
+  else if (type.isInteger(16) && !type.isUnsignedInteger())
+    mtype = MPI_INT16_T;
+  else if (type.isInteger(16))
+    mtype = MPI_UINT16_T;
+  else if (type.isInteger(8) && !type.isUnsignedInteger())
+    mtype = MPI_INT8_T;
+  else if (type.isInteger(8))
+    mtype = MPI_UINT8_T;
+  else
+    assert(false && "unsupported type");
+  return rewriter.create<mlir::LLVM::ConstantOp>(loc, rewriter.getI32Type(),
+                                                 mtype);
+}
+
+// ****************************************************************************
+// OpenMPI
+#elif defined(OPEN_MPI) && OPEN_MPI == 1
+
+static mlir::LLVM::GlobalOp
+getOrDefineExternalStruct(mlir::ModuleOp &moduleOp, const mlir::Location loc,
+                          mlir::ConversionPatternRewriter &rewriter,
+                          mlir::StringRef name,
+                          mlir::LLVM::LLVMStructType type) {
+
+  return getOrDefineGlobal<mlir::LLVM::GlobalOp>(
+      moduleOp, loc, rewriter, name, type, /*isConstant=*/false,
+      mlir::LLVM::Linkage::External, name,
+      /*value=*/mlir::Attribute(), /*alignment=*/0, 0);
+}
+
+mlir::Value
+MPIImplTraits::getCommWorld(mlir::ModuleOp &moduleOp, const mlir::Location loc,
+                            mlir::ConversionPatternRewriter &rewriter) {
+  auto context = rewriter.getContext();
+  // get external opaque struct pointer type
+  auto commStructT =
+      mlir::LLVM::LLVMStructType::getOpaque("ompi_communicator_t", context);
+  mlir::StringRef name = "ompi_mpi_comm_world";
+
+  // make sure global op definition exists
+  (void)getOrDefineExternalStruct(moduleOp, loc, rewriter, name, commStructT);
+
+  // get address of symbol
+  return rewriter.create<mlir::LLVM::AddressOfOp>(
+      loc, mlir::LLVM::LLVMPointerType::get(context),
+      mlir::SymbolRefAttr::get(context, name));
+}
+
+mlir::Value
+MPIImplTraits::getDataType(mlir::ModuleOp &moduleOp, const mlir::Location loc,
+                           mlir::ConversionPatternRewriter &rewriter,
+                           mlir::Type type) {
+  mlir::StringRef mtype = nullptr;
+  if (type.isF32())
+    mtype = "ompi_mpi_float";
+  else if (type.isF64())
+    mtype = "ompi_mpi_double";
+  else if (type.isInteger(64) && !type.isUnsignedInteger())
+    mtype = "ompi_mpi_int64_t";
+  else if (type.isInteger(64))
+    mtype = "ompi_mpi_uint64_t";
+  else if (type.isInteger(32) && !type.isUnsignedInteger())
+    mtype = "ompi_mpi_int32_t";
+  else if (type.isInteger(32))
+    mtype = "ompi_mpi_uint32_t";
+  else if (type.isInteger(16) && !type.isUnsignedInteger())
+    mtype = "ompi_mpi_int16_t";
+  else if (type.isInteger(16))
+    mtype = "ompi_mpi_uint16_t";
+  else if (type.isInteger(8) && !type.isUnsignedInteger())
+    mtype = "ompi_mpi_int8_t";
+  else if (type.isInteger(8))
+    mtype = "ompi_mpi_uint8_t";
+  else
+    assert(false && "unsupported type");
+
+  auto context = rewriter.getContext();
+  // get external opaque struct pointer type
+  auto commStructT = mlir::LLVM::LLVMStructType::getOpaque(
+      "ompi_predefined_datatype_t", context);
+  // make sure global op definition exists
+  (void)getOrDefineExternalStruct(moduleOp, loc, rewriter, mtype, commStructT);
+  // get address of symbol
+  return rewriter.create<mlir::LLVM::AddressOfOp>(
+      loc, mlir::LLVM::LLVMPointerType::get(context),
+      mlir::SymbolRefAttr::get(context, mtype));
+}
+
+#else
+#error "Unsupported MPI implementation"
+#endif
 
 //===----------------------------------------------------------------------===//
 // InitOpLowering
@@ -48,7 +219,7 @@ struct InitOpLowering : public ConvertOpToLLVMPattern<mpi::InitOp> {
   matchAndRewrite(mpi::InitOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
     // get loc
-    auto loc = op.getLoc();
+    Location loc = op.getLoc();
 
     // ptrType `!llvm.ptr`
     Type ptrType = LLVM::LLVMPointerType::get(rewriter.getContext());
@@ -86,7 +257,7 @@ struct FinalizeOpLowering : public ConvertOpToLLVMPattern<mpi::FinalizeOp> {
   matchAndRewrite(mpi::FinalizeOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
     // get loc
-    auto loc = op.getLoc();
+    Location loc = op.getLoc();
 
     // grab a reference to the global module op:
     auto moduleOp = op->getParentOfType<ModuleOp>();
@@ -115,9 +286,9 @@ struct CommRankOpLowering : public ConvertOpToLLVMPattern<mpi::CommRankOp> {
   matchAndRewrite(mpi::CommRankOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
     // get some helper vars
-    auto loc = op.getLoc();
-    auto context = rewriter.getContext();
-    auto i32 = rewriter.getI32Type();
+    Location loc = op.getLoc();
+    MLIRContext *context = rewriter.getContext();
+    Type i32 = rewriter.getI32Type();
 
     // ptrType `!llvm.ptr`
     Type ptrType = LLVM::LLVMPointerType::get(context);
@@ -126,7 +297,7 @@ struct CommRankOpLowering : public ConvertOpToLLVMPattern<mpi::CommRankOp> {
     auto moduleOp = op->getParentOfType<ModuleOp>();
 
     // get MPI_COMM_WORLD
-    auto commWorld = MPIImplTraits::getCommWorld(moduleOp, loc, rewriter);
+    Value commWorld = MPIImplTraits::getCommWorld(moduleOp, loc, rewriter);
 
     // LLVM Function type representing `i32 MPI_Comm_rank(ptr, ptr)`
     auto rankFuncType =
@@ -170,12 +341,12 @@ struct SendOpLowering : public ConvertOpToLLVMPattern<mpi::SendOp> {
   matchAndRewrite(mpi::SendOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
     // get some helper vars
-    auto loc = op.getLoc();
-    auto context = rewriter.getContext();
-    auto i32 = rewriter.getI32Type();
-    auto i64 = rewriter.getI64Type();
-    auto memRef = adaptor.getRef();
-    auto elemType = op.getRef().getType().getElementType();
+    Location loc = op.getLoc();
+    MLIRContext *context = rewriter.getContext();
+    Type i32 = rewriter.getI32Type();
+    Type i64 = rewriter.getI64Type();
+    Value memRef = adaptor.getRef();
+    Type elemType = op.getRef().getType().getElementType();
 
     // ptrType `!llvm.ptr`
     Type ptrType = LLVM::LLVMPointerType::get(context);
@@ -184,22 +355,17 @@ struct SendOpLowering : public ConvertOpToLLVMPattern<mpi::SendOp> {
     auto moduleOp = op->getParentOfType<ModuleOp>();
 
     // get MPI_COMM_WORLD, dataType and pointer
-    auto dataPtr =
-        rewriter.create<LLVM::ExtractValueOp>(loc, ptrType, memRef, 1)
-            .getResult();
-    auto offset =
-        rewriter.create<LLVM::ExtractValueOp>(loc, i64, memRef, 2).getResult();
+    Value dataPtr =
+        rewriter.create<LLVM::ExtractValueOp>(loc, ptrType, memRef, 1);
+    Value offset = rewriter.create<LLVM::ExtractValueOp>(loc, i64, memRef, 2);
     dataPtr =
-        rewriter.create<LLVM::GEPOp>(loc, ptrType, elemType, dataPtr, offset)
-            .getResult();
-    auto size =
-        rewriter
-            .create<LLVM::ExtractValueOp>(loc, memRef, ArrayRef<int64_t>{3, 0})
-            .getResult();
-    size = rewriter.create<LLVM::TruncOp>(loc, i32, size).getResult();
-    auto dataType =
+        rewriter.create<LLVM::GEPOp>(loc, ptrType, elemType, dataPtr, offset);
+    Value size = rewriter.create<LLVM::ExtractValueOp>(loc, memRef,
+                                                       ArrayRef<int64_t>{3, 0});
+    size = rewriter.create<LLVM::TruncOp>(loc, i32, size);
+    Value dataType =
         MPIImplTraits::getDataType(moduleOp, loc, rewriter, elemType);
-    auto commWorld = MPIImplTraits::getCommWorld(moduleOp, loc, rewriter);
+    Value commWorld = MPIImplTraits::getCommWorld(moduleOp, loc, rewriter);
 
     // LLVM Function type representing `i32 MPI_send(data, count, datatype, dst,
     // tag, comm)`
@@ -235,12 +401,12 @@ struct RecvOpLowering : public ConvertOpToLLVMPattern<mpi::RecvOp> {
   matchAndRewrite(mpi::RecvOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
     // get some helper vars
-    auto loc = op.getLoc();
-    auto context = rewriter.getContext();
-    auto i32 = rewriter.getI32Type();
-    auto i64 = rewriter.getI64Type();
-    auto memRef = adaptor.getRef();
-    auto elemType = op.getRef().getType().getElementType();
+    Location loc = op.getLoc();
+    MLIRContext *context = rewriter.getContext();
+    Type i32 = rewriter.getI32Type();
+    Type i64 = rewriter.getI64Type();
+    Value memRef = adaptor.getRef();
+    Type elemType = op.getRef().getType().getElementType();
 
     // ptrType `!llvm.ptr`
     Type ptrType = LLVM::LLVMPointerType::get(context);
@@ -249,29 +415,21 @@ struct RecvOpLowering : public ConvertOpToLLVMPattern<mpi::RecvOp> {
     auto moduleOp = op->getParentOfType<ModuleOp>();
 
     // get MPI_COMM_WORLD, dataType, status_ignore and pointer
-    auto dataPtr =
-        rewriter.create<LLVM::ExtractValueOp>(loc, ptrType, memRef, 1)
-            .getResult();
-    auto offset =
-        rewriter.create<LLVM::ExtractValueOp>(loc, i64, memRef, 2).getResult();
+    Value dataPtr =
+        rewriter.create<LLVM::ExtractValueOp>(loc, ptrType, memRef, 1);
+    Value offset = rewriter.create<LLVM::ExtractValueOp>(loc, i64, memRef, 2);
     dataPtr =
-        rewriter.create<LLVM::GEPOp>(loc, ptrType, elemType, dataPtr, offset)
-            .getResult();
-    auto size =
-        rewriter
-            .create<LLVM::ExtractValueOp>(loc, memRef, ArrayRef<int64_t>{3, 0})
-            .getResult();
-    size = rewriter.create<LLVM::TruncOp>(loc, i32, size).getResult();
-    auto dataType =
+        rewriter.create<LLVM::GEPOp>(loc, ptrType, elemType, dataPtr, offset);
+    Value size = rewriter.create<LLVM::ExtractValueOp>(loc, memRef,
+                                                       ArrayRef<int64_t>{3, 0});
+    size = rewriter.create<LLVM::TruncOp>(loc, i32, size);
+    Value dataType =
         MPIImplTraits::getDataType(moduleOp, loc, rewriter, elemType);
-    auto commWorld = MPIImplTraits::getCommWorld(moduleOp, loc, rewriter);
-    auto statusIgnore =
-        rewriter
-            .create<LLVM::ConstantOp>(
-                loc, i64, reinterpret_cast<int64_t>(MPI_STATUS_IGNORE))
-            .getResult();
-    statusIgnore = rewriter.create<LLVM::IntToPtrOp>(loc, ptrType, statusIgnore)
-                       .getResult();
+    Value commWorld = MPIImplTraits::getCommWorld(moduleOp, loc, rewriter);
+    Value statusIgnore = rewriter.create<LLVM::ConstantOp>(
+        loc, i64, reinterpret_cast<int64_t>(MPI_STATUS_IGNORE));
+    statusIgnore =
+        rewriter.create<LLVM::IntToPtrOp>(loc, ptrType, statusIgnore);
 
     // LLVM Function type representing `i32 MPI_Recv(data, count, datatype, dst,
     // tag, comm)`
diff --git a/mlir/lib/Conversion/MPIToLLVM/mpi_fallback.h b/mlir/lib/Conversion/MPIToLLVM/mpi_fallback.h
deleted file mode 100644
index e383a27c4e382..0000000000000
--- a/mlir/lib/Conversion/MPIToLLVM/mpi_fallback.h
+++ /dev/null
@@ -1,33 +0,0 @@
-//===- MPIToLLVM.cpp - MPI to LLVM dialect conversion ---------------------===//
-//
-// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-//
-//===----------------------------------------------------------------------===//
-
-/*
- * Copyright (C) by Argonne National Laboratory
- *     See COPYRIGHT in top-level directory
- *     of MPICH source repository.
- */
-
-typedef int MPI_Comm;
-#define MPI_COMM_WORLD ((MPI_Comm)0x44000000)
-
-typedef int MPI_Datatype;
-#define MPI_FLOAT ((MPI_Datatype)0x4c00040a)
-#define MPI_DOUBLE ((MPI_Datatype)0x4c00080b)
-#define MPI_INT8_T ((MPI_Datatype)0x4c000137)
-#define MPI_INT16_T ((MPI_Datatype)0x4c000238)
-#define MPI_INT32_T ((MPI_Datatype)0x4c000439)
-#define MPI_INT64_T ((MPI_Datatype)0x4c00083a)
-#define MPI_UINT8_T ((MPI_Datatype)0x4c00013b)
-#define MPI_UINT16_T ((MPI_Datatype)0x4c00023c)
-#define MPI_UINT32_T ((MPI_Datatype)0x4c00043d)
-#define MPI_UINT64_T ((MPI_Datatype)0x4c00083e)
-
-typedef struct MPI_Status;
-#define MPI_STATUS_IGNORE (MPI_Status *)1
-
-#define _MPI_FALLBACK_DEFS 1
diff --git a/mlir/test/Conversion/MPIToLLVM/ops.mlir b/mlir/test/Conversion/MPIToLLVM/ops.mlir
index 33fb928802afd..9b3c818b7eae3 100644
--- a/mlir/test/Conversion/MPIToLLVM/ops.mlir
+++ b/mlir/test/Conversion/MPIToLLVM/ops.mlir
@@ -17,14 +17,12 @@ func.func @mpi_test(%arg0: memref<100xf32>) {
   // CHECK: [[v5:%.*]] = llvm.insertvalue [[varg4]], [[v4]][4, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> 
   // CHECK: [[v6:%.*]] = llvm.mlir.zero : !llvm.ptr
   // CHECK: [[v7:%.*]] = llvm.call @MPI_Init([[v6]], [[v6]]) : (!llvm.ptr, !llvm.ptr) -> i32
-  // CHECK: [[v8:%.*]] = builtin.unrealized_conversion_cast [[v7]] : i32 to !mpi.retval
   %0 = mpi.init : !mpi.retval
 
   // CHECK: [[v9:%.*]] = llvm.mlir.
   // CHECK: [[v10:%.*]] = llvm.mlir.constant(1 : i32) : i32
   // CHECK: [[v11:%.*]] = llvm.alloca [[v10]] x i32 : (i32) -> !llvm.ptr
   // CHECK: [[v12:%.*]] = llvm.call @MPI_Comm_rank([[v9]], [[v11]]) : ({{.+}}, !llvm.ptr) -> i32
-  // CHECK: [[v13:%.*]] = builtin.unrealized_conversion_cast [[v12]] : i32 to !mpi.retval
   // CHECK: [[v14:%.*]] = llvm.load [[v11]] : !llvm.ptr -> i32
   %retval, %rank = mpi.comm_rank : !mpi.retval, i32
 
@@ -75,9 +73,5 @@ func.func @mpi_test(%arg0: memref<100xf32>) {
   // CHECK: llvm.call @MPI_Finalize() : () -> i32
   %3 = mpi.finalize : !mpi.retval
 
-  // CHECK: mpi.retval_check
-  %4 = mpi.retval_check %retval = <MPI_SUCCESS> : i1
-  // CEHCK: mpi.error_class
-  %5 = mpi.error_class %0 : !mpi.retval
   return
 }

>From 255de645854a3b4c203f9f0a04b68dd3bac70018 Mon Sep 17 00:00:00 2001
From: Frank Schlimbach <frank.schlimbach at intel.com>
Date: Tue, 18 Feb 2025 11:52:32 +0100
Subject: [PATCH 08/15] rm unused line

Co-authored-by: Christian Ulmann <christianulmann at gmail.com>
---
 mlir/lib/Conversion/MPIToLLVM/CMakeLists.txt | 1 -
 1 file changed, 1 deletion(-)

diff --git a/mlir/lib/Conversion/MPIToLLVM/CMakeLists.txt b/mlir/lib/Conversion/MPIToLLVM/CMakeLists.txt
index a0efb9836b80a..830b635981ea2 100644
--- a/mlir/lib/Conversion/MPIToLLVM/CMakeLists.txt
+++ b/mlir/lib/Conversion/MPIToLLVM/CMakeLists.txt
@@ -22,7 +22,6 @@ add_mlir_conversion_library(MLIRMPIToLLVM
 )
 
 if(MPI_C_HEADER_DIR)
-  # cmake_path(REMOVE_FILENAME MPI_C_HEADER_DIR)
   message(STATUS "found MPI_C_HEADER_DIR: ${MPI_C_HEADER_DIR}")
   target_include_directories(obj.MLIRMPIToLLVM PRIVATE ${MPI_C_HEADER_DIR})
   target_compile_definitions(obj.MLIRMPIToLLVM PUBLIC FOUND_MPI_C_HEADER=1)

>From 3a2fd9f0e2695eb12febcb2e4e1222b5b28a0f75 Mon Sep 17 00:00:00 2001
From: "Schlimbach, Frank" <frank.schlimbach at intel.com>
Date: Wed, 19 Feb 2025 12:12:01 +0100
Subject: [PATCH 09/15] remove dependency on mpi.h; TODO: runtime dispatch

---
 mlir/lib/Conversion/MPIToLLVM/CMakeLists.txt |  13 -
 mlir/lib/Conversion/MPIToLLVM/MPIToLLVM.cpp  | 305 ++++++++++---------
 mlir/test/Conversion/MPIToLLVM/ops.mlir      |   3 +
 3 files changed, 157 insertions(+), 164 deletions(-)

diff --git a/mlir/lib/Conversion/MPIToLLVM/CMakeLists.txt b/mlir/lib/Conversion/MPIToLLVM/CMakeLists.txt
index 830b635981ea2..29175964a10bf 100644
--- a/mlir/lib/Conversion/MPIToLLVM/CMakeLists.txt
+++ b/mlir/lib/Conversion/MPIToLLVM/CMakeLists.txt
@@ -1,8 +1,3 @@
-find_path(MPI_C_HEADER_DIR mpi.h
-    PATHS $ENV{I_MPI_ROOT}/include
-          $ENV{MPI_HOME}/include
-          $ENV{MPI_ROOT}/include)
-
 add_mlir_conversion_library(MLIRMPIToLLVM
   MPIToLLVM.cpp
 
@@ -20,11 +15,3 @@ add_mlir_conversion_library(MLIRMPIToLLVM
   MLIRLLVMDialect
   MLIRMPIDialect
 )
-
-if(MPI_C_HEADER_DIR)
-  message(STATUS "found MPI_C_HEADER_DIR: ${MPI_C_HEADER_DIR}")
-  target_include_directories(obj.MLIRMPIToLLVM PRIVATE ${MPI_C_HEADER_DIR})
-  target_compile_definitions(obj.MLIRMPIToLLVM PUBLIC FOUND_MPI_C_HEADER=1)
-else()
-  message(WARNING "MPI not found, falling back to definitions from MPICH")
-endif()
diff --git a/mlir/lib/Conversion/MPIToLLVM/MPIToLLVM.cpp b/mlir/lib/Conversion/MPIToLLVM/MPIToLLVM.cpp
index 98184d4cad5cc..40bdc6730e6f6 100644
--- a/mlir/lib/Conversion/MPIToLLVM/MPIToLLVM.cpp
+++ b/mlir/lib/Conversion/MPIToLLVM/MPIToLLVM.cpp
@@ -6,37 +6,11 @@
 //
 //===----------------------------------------------------------------------===//
 
-#define MPICH_SKIP_MPICXX 1
-#define OMPI_SKIP_MPICXX 1
-#ifdef FOUND_MPI_C_HEADER
-// This must go first (MPI gets confused otherwise)
-#include <mpi.h>
-#else // not FOUND_MPI_C_HEADER
 //
 // Copyright (C) by Argonne National Laboratory
 //    See COPYRIGHT in top-level directory
 //    of MPICH source repository.
 //
-typedef int MPI_Comm;
-#define MPI_COMM_WORLD ((MPI_Comm)0x44000000)
-
-typedef int MPI_Datatype;
-#define MPI_FLOAT ((MPI_Datatype)0x4c00040a)
-#define MPI_DOUBLE ((MPI_Datatype)0x4c00080b)
-#define MPI_INT8_T ((MPI_Datatype)0x4c000137)
-#define MPI_INT16_T ((MPI_Datatype)0x4c000238)
-#define MPI_INT32_T ((MPI_Datatype)0x4c000439)
-#define MPI_INT64_T ((MPI_Datatype)0x4c00083a)
-#define MPI_UINT8_T ((MPI_Datatype)0x4c00013b)
-#define MPI_UINT16_T ((MPI_Datatype)0x4c00023c)
-#define MPI_UINT32_T ((MPI_Datatype)0x4c00043d)
-#define MPI_UINT64_T ((MPI_Datatype)0x4c00083e)
-
-typedef struct MPI_Status;
-#define MPI_STATUS_IGNORE (MPI_Status *)1
-
-#define _MPI_FALLBACK_DEFS 1
-#endif // FOUND_MPI_C_HEADER
 
 #include "mlir/Conversion/ConvertToLLVM/ToLLVMInterface.h"
 #include "mlir/Conversion/LLVMCommon/Pattern.h"
@@ -71,143 +45,172 @@ static LLVM::LLVMFuncOp getOrDefineFunction(ModuleOp &moduleOp,
       moduleOp, loc, rewriter, name, name, type, LLVM::Linkage::External);
 }
 
-// ****************************************************************************
+//===----------------------------------------------------------------------===//
+// Implementation details for MPICH ABI compatible MPI implementations
+//===----------------------------------------------------------------------===//
+struct MPICHImplTraits {
+  static const int MPI_FLOAT = 0x4c00040a;
+  static const int MPI_DOUBLE = 0x4c00080b;
+  static const int MPI_INT8_T = 0x4c000137;
+  static const int MPI_INT16_T = 0x4c000238;
+  static const int MPI_INT32_T = 0x4c000439;
+  static const int MPI_INT64_T = 0x4c00083a;
+  static const int MPI_UINT8_T = 0x4c00013b;
+  static const int MPI_UINT16_T = 0x4c00023c;
+  static const int MPI_UINT32_T = 0x4c00043d;
+  static const int MPI_UINT64_T = 0x4c00083e;
+
+  static mlir::Value getCommWorld(mlir::ModuleOp &moduleOp,
+                                  const mlir::Location loc,
+                                  mlir::ConversionPatternRewriter &rewriter) {
+    static const int MPI_COMM_WORLD = 0x44000000;
+    return rewriter.create<mlir::LLVM::ConstantOp>(loc, rewriter.getI32Type(),
+                                                   MPI_COMM_WORLD);
+  }
+
+  static intptr_t getStatusIgnore() { return 1; }
+
+  static mlir::Value getDataType(mlir::ModuleOp &moduleOp,
+                                 const mlir::Location loc,
+                                 mlir::ConversionPatternRewriter &rewriter,
+                                 mlir::Type type) {
+    int32_t mtype = 0;
+    if (type.isF32())
+      mtype = MPI_FLOAT;
+    else if (type.isF64())
+      mtype = MPI_DOUBLE;
+    else if (type.isInteger(64) && !type.isUnsignedInteger())
+      mtype = MPI_INT64_T;
+    else if (type.isInteger(64))
+      mtype = MPI_UINT64_T;
+    else if (type.isInteger(32) && !type.isUnsignedInteger())
+      mtype = MPI_INT32_T;
+    else if (type.isInteger(32))
+      mtype = MPI_UINT32_T;
+    else if (type.isInteger(16) && !type.isUnsignedInteger())
+      mtype = MPI_INT16_T;
+    else if (type.isInteger(16))
+      mtype = MPI_UINT16_T;
+    else if (type.isInteger(8) && !type.isUnsignedInteger())
+      mtype = MPI_INT8_T;
+    else if (type.isInteger(8))
+      mtype = MPI_UINT8_T;
+    else
+      assert(false && "unsupported type");
+    return rewriter.create<mlir::LLVM::ConstantOp>(loc, rewriter.getI32Type(),
+                                                   mtype);
+  }
+};
+
+//===----------------------------------------------------------------------===//
+// Implementation details for OpenMPI
+//===----------------------------------------------------------------------===//
+struct OMPIImplTraits {
+
+  static mlir::LLVM::GlobalOp
+  getOrDefineExternalStruct(mlir::ModuleOp &moduleOp, const mlir::Location loc,
+                            mlir::ConversionPatternRewriter &rewriter,
+                            mlir::StringRef name,
+                            mlir::LLVM::LLVMStructType type) {
+
+    return getOrDefineGlobal<mlir::LLVM::GlobalOp>(
+        moduleOp, loc, rewriter, name, type, /*isConstant=*/false,
+        mlir::LLVM::Linkage::External, name,
+        /*value=*/mlir::Attribute(), /*alignment=*/0, 0);
+  }
+
+  static mlir::Value getCommWorld(mlir::ModuleOp &moduleOp,
+                                  const mlir::Location loc,
+                                  mlir::ConversionPatternRewriter &rewriter) {
+    auto context = rewriter.getContext();
+    // get external opaque struct pointer type
+    auto commStructT =
+        mlir::LLVM::LLVMStructType::getOpaque("ompi_communicator_t", context);
+    mlir::StringRef name = "ompi_mpi_comm_world";
+
+    // make sure global op definition exists
+    (void)getOrDefineExternalStruct(moduleOp, loc, rewriter, name, commStructT);
+
+    // get address of symbol
+    return rewriter.create<mlir::LLVM::AddressOfOp>(
+        loc, mlir::LLVM::LLVMPointerType::get(context),
+        mlir::SymbolRefAttr::get(context, name));
+  }
+
+  static intptr_t getStatusIgnore() { return 0; }
+
+  static mlir::Value getDataType(mlir::ModuleOp &moduleOp,
+                                 const mlir::Location loc,
+                                 mlir::ConversionPatternRewriter &rewriter,
+                                 mlir::Type type) {
+    mlir::StringRef mtype;
+    if (type.isF32())
+      mtype = "ompi_mpi_float";
+    else if (type.isF64())
+      mtype = "ompi_mpi_double";
+    else if (type.isInteger(64) && !type.isUnsignedInteger())
+      mtype = "ompi_mpi_int64_t";
+    else if (type.isInteger(64))
+      mtype = "ompi_mpi_uint64_t";
+    else if (type.isInteger(32) && !type.isUnsignedInteger())
+      mtype = "ompi_mpi_int32_t";
+    else if (type.isInteger(32))
+      mtype = "ompi_mpi_uint32_t";
+    else if (type.isInteger(16) && !type.isUnsignedInteger())
+      mtype = "ompi_mpi_int16_t";
+    else if (type.isInteger(16))
+      mtype = "ompi_mpi_uint16_t";
+    else if (type.isInteger(8) && !type.isUnsignedInteger())
+      mtype = "ompi_mpi_int8_t";
+    else if (type.isInteger(8))
+      mtype = "ompi_mpi_uint8_t";
+    else
+      assert(false && "unsupported type");
+
+    auto context = rewriter.getContext();
+    // get external opaque struct pointer type
+    auto commStructT = mlir::LLVM::LLVMStructType::getOpaque(
+        "ompi_predefined_datatype_t", context);
+    // make sure global op definition exists
+    (void)getOrDefineExternalStruct(moduleOp, loc, rewriter, mtype,
+                                    commStructT);
+    // get address of symbol
+    return rewriter.create<mlir::LLVM::AddressOfOp>(
+        loc, mlir::LLVM::LLVMPointerType::get(context),
+        mlir::SymbolRefAttr::get(context, mtype));
+  }
+};
+
+//===----------------------------------------------------------------------===//
 // When lowering the mpi dialect to functions calls certain details
 // differ between various MPI implementations. This class will provide
-// these depending on the MPI implementation that got included.
+// these in a gnereic way, depending on the MPI implementation that got
+// included.
+//===----------------------------------------------------------------------===//
 struct MPIImplTraits {
   // get/create MPI_COMM_WORLD as a mlir::Value
   static mlir::Value getCommWorld(mlir::ModuleOp &moduleOp,
                                   const mlir::Location loc,
-                                  mlir::ConversionPatternRewriter &rewriter);
+                                  mlir::ConversionPatternRewriter &rewriter) {
+    // TODO: dispatch based on the MPI implementation
+    return MPICHImplTraits::getCommWorld(moduleOp, loc, rewriter);
+  }
+  // Get the MPI_STATUS_IGNORE value (typically a pointer type).
+  static intptr_t getStatusIgnore() {
+    // TODO: dispatch based on the MPI implementation
+    return MPICHImplTraits::getStatusIgnore();
+  }
   // get/create MPI datatype as a mlir::Value which corresponds to the given
   // mlir::Type
   static mlir::Value getDataType(mlir::ModuleOp &moduleOp,
                                  const mlir::Location loc,
                                  mlir::ConversionPatternRewriter &rewriter,
-                                 mlir::Type type);
+                                 mlir::Type type) {
+    // TODO: dispatch based on the MPI implementation
+    return MPICHImplTraits::getDataType(moduleOp, loc, rewriter, type);
+  }
 };
 
-// ****************************************************************************
-// Intel MPI/MPICH
-#if defined(IMPI_DEVICE_EXPORT) || defined(_MPI_FALLBACK_DEFS)
-
-mlir::Value
-MPIImplTraits::getCommWorld(mlir::ModuleOp &moduleOp, const mlir::Location loc,
-                            mlir::ConversionPatternRewriter &rewriter) {
-  return rewriter.create<mlir::LLVM::ConstantOp>(loc, rewriter.getI32Type(),
-                                                 MPI_COMM_WORLD);
-}
-
-mlir::Value
-MPIImplTraits::getDataType(mlir::ModuleOp &moduleOp, const mlir::Location loc,
-                           mlir::ConversionPatternRewriter &rewriter,
-                           mlir::Type type) {
-  int32_t mtype = 0;
-  if (type.isF32())
-    mtype = MPI_FLOAT;
-  else if (type.isF64())
-    mtype = MPI_DOUBLE;
-  else if (type.isInteger(64) && !type.isUnsignedInteger())
-    mtype = MPI_INT64_T;
-  else if (type.isInteger(64))
-    mtype = MPI_UINT64_T;
-  else if (type.isInteger(32) && !type.isUnsignedInteger())
-    mtype = MPI_INT32_T;
-  else if (type.isInteger(32))
-    mtype = MPI_UINT32_T;
-  else if (type.isInteger(16) && !type.isUnsignedInteger())
-    mtype = MPI_INT16_T;
-  else if (type.isInteger(16))
-    mtype = MPI_UINT16_T;
-  else if (type.isInteger(8) && !type.isUnsignedInteger())
-    mtype = MPI_INT8_T;
-  else if (type.isInteger(8))
-    mtype = MPI_UINT8_T;
-  else
-    assert(false && "unsupported type");
-  return rewriter.create<mlir::LLVM::ConstantOp>(loc, rewriter.getI32Type(),
-                                                 mtype);
-}
-
-// ****************************************************************************
-// OpenMPI
-#elif defined(OPEN_MPI) && OPEN_MPI == 1
-
-static mlir::LLVM::GlobalOp
-getOrDefineExternalStruct(mlir::ModuleOp &moduleOp, const mlir::Location loc,
-                          mlir::ConversionPatternRewriter &rewriter,
-                          mlir::StringRef name,
-                          mlir::LLVM::LLVMStructType type) {
-
-  return getOrDefineGlobal<mlir::LLVM::GlobalOp>(
-      moduleOp, loc, rewriter, name, type, /*isConstant=*/false,
-      mlir::LLVM::Linkage::External, name,
-      /*value=*/mlir::Attribute(), /*alignment=*/0, 0);
-}
-
-mlir::Value
-MPIImplTraits::getCommWorld(mlir::ModuleOp &moduleOp, const mlir::Location loc,
-                            mlir::ConversionPatternRewriter &rewriter) {
-  auto context = rewriter.getContext();
-  // get external opaque struct pointer type
-  auto commStructT =
-      mlir::LLVM::LLVMStructType::getOpaque("ompi_communicator_t", context);
-  mlir::StringRef name = "ompi_mpi_comm_world";
-
-  // make sure global op definition exists
-  (void)getOrDefineExternalStruct(moduleOp, loc, rewriter, name, commStructT);
-
-  // get address of symbol
-  return rewriter.create<mlir::LLVM::AddressOfOp>(
-      loc, mlir::LLVM::LLVMPointerType::get(context),
-      mlir::SymbolRefAttr::get(context, name));
-}
-
-mlir::Value
-MPIImplTraits::getDataType(mlir::ModuleOp &moduleOp, const mlir::Location loc,
-                           mlir::ConversionPatternRewriter &rewriter,
-                           mlir::Type type) {
-  mlir::StringRef mtype = nullptr;
-  if (type.isF32())
-    mtype = "ompi_mpi_float";
-  else if (type.isF64())
-    mtype = "ompi_mpi_double";
-  else if (type.isInteger(64) && !type.isUnsignedInteger())
-    mtype = "ompi_mpi_int64_t";
-  else if (type.isInteger(64))
-    mtype = "ompi_mpi_uint64_t";
-  else if (type.isInteger(32) && !type.isUnsignedInteger())
-    mtype = "ompi_mpi_int32_t";
-  else if (type.isInteger(32))
-    mtype = "ompi_mpi_uint32_t";
-  else if (type.isInteger(16) && !type.isUnsignedInteger())
-    mtype = "ompi_mpi_int16_t";
-  else if (type.isInteger(16))
-    mtype = "ompi_mpi_uint16_t";
-  else if (type.isInteger(8) && !type.isUnsignedInteger())
-    mtype = "ompi_mpi_int8_t";
-  else if (type.isInteger(8))
-    mtype = "ompi_mpi_uint8_t";
-  else
-    assert(false && "unsupported type");
-
-  auto context = rewriter.getContext();
-  // get external opaque struct pointer type
-  auto commStructT = mlir::LLVM::LLVMStructType::getOpaque(
-      "ompi_predefined_datatype_t", context);
-  // make sure global op definition exists
-  (void)getOrDefineExternalStruct(moduleOp, loc, rewriter, mtype, commStructT);
-  // get address of symbol
-  return rewriter.create<mlir::LLVM::AddressOfOp>(
-      loc, mlir::LLVM::LLVMPointerType::get(context),
-      mlir::SymbolRefAttr::get(context, mtype));
-}
-
-#else
-#error "Unsupported MPI implementation"
-#endif
-
 //===----------------------------------------------------------------------===//
 // InitOpLowering
 //===----------------------------------------------------------------------===//
@@ -427,7 +430,7 @@ struct RecvOpLowering : public ConvertOpToLLVMPattern<mpi::RecvOp> {
         MPIImplTraits::getDataType(moduleOp, loc, rewriter, elemType);
     Value commWorld = MPIImplTraits::getCommWorld(moduleOp, loc, rewriter);
     Value statusIgnore = rewriter.create<LLVM::ConstantOp>(
-        loc, i64, reinterpret_cast<int64_t>(MPI_STATUS_IGNORE));
+        loc, i64, MPIImplTraits::getStatusIgnore());
     statusIgnore =
         rewriter.create<LLVM::IntToPtrOp>(loc, ptrType, statusIgnore);
 
diff --git a/mlir/test/Conversion/MPIToLLVM/ops.mlir b/mlir/test/Conversion/MPIToLLVM/ops.mlir
index 9b3c818b7eae3..c1938cade8609 100644
--- a/mlir/test/Conversion/MPIToLLVM/ops.mlir
+++ b/mlir/test/Conversion/MPIToLLVM/ops.mlir
@@ -1,5 +1,7 @@
 // RUN: mlir-opt -convert-to-llvm %s | FileCheck %s
 
+module attributes { mpi.dlti = #dlti.map<"MPI:Implemention" = "Intel"> } {
+
 // CHECK: llvm.func @MPI_Finalize() -> i32
 // CHECK: llvm.func @MPI_Recv(!llvm.ptr, i32, {{.+}}, i32, i32, {{.+}}, !llvm.ptr) -> i32
 // CHECK: llvm.func @MPI_Send(!llvm.ptr, i32, {{.+}}, i32, i32, {{.+}}) -> i32
@@ -75,3 +77,4 @@ func.func @mpi_test(%arg0: memref<100xf32>) {
 
   return
 }
+}
\ No newline at end of file

>From 7706fdf5fa574033e105050e37d0583622f1f0c3 Mon Sep 17 00:00:00 2001
From: "Schlimbach, Frank" <frank.schlimbach at intel.com>
Date: Wed, 19 Feb 2025 12:17:15 +0100
Subject: [PATCH 10/15] clang-format

---
 mlir/lib/Conversion/MPIToLLVM/MPIToLLVM.cpp | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/mlir/lib/Conversion/MPIToLLVM/MPIToLLVM.cpp b/mlir/lib/Conversion/MPIToLLVM/MPIToLLVM.cpp
index 40bdc6730e6f6..f75575ca55dfe 100644
--- a/mlir/lib/Conversion/MPIToLLVM/MPIToLLVM.cpp
+++ b/mlir/lib/Conversion/MPIToLLVM/MPIToLLVM.cpp
@@ -12,9 +12,9 @@
 //    of MPICH source repository.
 //
 
+#include "mlir/Conversion/MPIToLLVM/MPIToLLVM.h"
 #include "mlir/Conversion/ConvertToLLVM/ToLLVMInterface.h"
 #include "mlir/Conversion/LLVMCommon/Pattern.h"
-#include "mlir/Conversion/MPIToLLVM/MPIToLLVM.h"
 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
 #include "mlir/Dialect/MPI/IR/MPI.h"
 #include "mlir/Transforms/DialectConversion.h"

>From 6464e201820afad3491a4f66ce2f3abf28dc1c1b Mon Sep 17 00:00:00 2001
From: "Schlimbach, Frank" <frank.schlimbach at intel.com>
Date: Wed, 19 Feb 2025 18:27:20 +0100
Subject: [PATCH 11/15] MPI implementation selection at runtime

---
 mlir/lib/Conversion/MPIToLLVM/CMakeLists.txt |   1 +
 mlir/lib/Conversion/MPIToLLVM/MPIToLLVM.cpp  |  39 +++-
 mlir/test/Conversion/MPIToLLVM/ops.mlir      | 233 +++++++++++++------
 3 files changed, 192 insertions(+), 81 deletions(-)

diff --git a/mlir/lib/Conversion/MPIToLLVM/CMakeLists.txt b/mlir/lib/Conversion/MPIToLLVM/CMakeLists.txt
index 29175964a10bf..2c80d8230515a 100644
--- a/mlir/lib/Conversion/MPIToLLVM/CMakeLists.txt
+++ b/mlir/lib/Conversion/MPIToLLVM/CMakeLists.txt
@@ -11,6 +11,7 @@ add_mlir_conversion_library(MLIRMPIToLLVM
   Core
 
   LINK_LIBS PUBLIC
+  MLIRDLTIDialect
   MLIRLLVMCommonConversion
   MLIRLLVMDialect
   MLIRMPIDialect
diff --git a/mlir/lib/Conversion/MPIToLLVM/MPIToLLVM.cpp b/mlir/lib/Conversion/MPIToLLVM/MPIToLLVM.cpp
index f75575ca55dfe..37cda94829655 100644
--- a/mlir/lib/Conversion/MPIToLLVM/MPIToLLVM.cpp
+++ b/mlir/lib/Conversion/MPIToLLVM/MPIToLLVM.cpp
@@ -15,6 +15,7 @@
 #include "mlir/Conversion/MPIToLLVM/MPIToLLVM.h"
 #include "mlir/Conversion/ConvertToLLVM/ToLLVMInterface.h"
 #include "mlir/Conversion/LLVMCommon/Pattern.h"
+#include "mlir/Dialect/DLTI/DLTI.h"
 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
 #include "mlir/Dialect/MPI/IR/MPI.h"
 #include "mlir/Transforms/DialectConversion.h"
@@ -184,29 +185,53 @@ struct OMPIImplTraits {
 //===----------------------------------------------------------------------===//
 // When lowering the mpi dialect to functions calls certain details
 // differ between various MPI implementations. This class will provide
-// these in a gnereic way, depending on the MPI implementation that got
-// included.
+// these in a generic way, depending on the MPI implementation that got
+// selected by the DLTI attribute on the module.
 //===----------------------------------------------------------------------===//
 struct MPIImplTraits {
+  enum MPIImpl { MPICH, OMPI };
+
+  // Get the MPI implementation from a DLTI attribute on the module.
+  // Default to MPICH (and ABI compatible).
+  static MPIImpl getMPIImpl(mlir::ModuleOp &moduleOp) {
+    auto attr = dlti::query(*&moduleOp, {"MPI:Implementation"}, true);
+    if (failed(attr)) {
+      return MPICH;
+    }
+    auto strAttr = dyn_cast<StringAttr>(attr.value());
+    if (strAttr && strAttr.getValue() == "OpenMPI") {
+      return OMPI;
+    }
+    return MPICH;
+  }
+
   // get/create MPI_COMM_WORLD as a mlir::Value
   static mlir::Value getCommWorld(mlir::ModuleOp &moduleOp,
                                   const mlir::Location loc,
                                   mlir::ConversionPatternRewriter &rewriter) {
-    // TODO: dispatch based on the MPI implementation
+    if (MPIImplTraits::getMPIImpl(moduleOp) == OMPI) {
+      return OMPIImplTraits::getCommWorld(moduleOp, loc, rewriter);
+    }
     return MPICHImplTraits::getCommWorld(moduleOp, loc, rewriter);
   }
+
   // Get the MPI_STATUS_IGNORE value (typically a pointer type).
-  static intptr_t getStatusIgnore() {
-    // TODO: dispatch based on the MPI implementation
+  static intptr_t getStatusIgnore(mlir::ModuleOp &moduleOp) {
+    if (MPIImplTraits::getMPIImpl(moduleOp) == OMPI) {
+      return OMPIImplTraits::getStatusIgnore();
+    }
     return MPICHImplTraits::getStatusIgnore();
   }
+
   // get/create MPI datatype as a mlir::Value which corresponds to the given
   // mlir::Type
   static mlir::Value getDataType(mlir::ModuleOp &moduleOp,
                                  const mlir::Location loc,
                                  mlir::ConversionPatternRewriter &rewriter,
                                  mlir::Type type) {
-    // TODO: dispatch based on the MPI implementation
+    if (MPIImplTraits::getMPIImpl(moduleOp) == OMPI) {
+      return OMPIImplTraits::getDataType(moduleOp, loc, rewriter, type);
+    }
     return MPICHImplTraits::getDataType(moduleOp, loc, rewriter, type);
   }
 };
@@ -430,7 +455,7 @@ struct RecvOpLowering : public ConvertOpToLLVMPattern<mpi::RecvOp> {
         MPIImplTraits::getDataType(moduleOp, loc, rewriter, elemType);
     Value commWorld = MPIImplTraits::getCommWorld(moduleOp, loc, rewriter);
     Value statusIgnore = rewriter.create<LLVM::ConstantOp>(
-        loc, i64, MPIImplTraits::getStatusIgnore());
+        loc, i64, MPIImplTraits::getStatusIgnore(moduleOp));
     statusIgnore =
         rewriter.create<LLVM::IntToPtrOp>(loc, ptrType, statusIgnore);
 
diff --git a/mlir/test/Conversion/MPIToLLVM/ops.mlir b/mlir/test/Conversion/MPIToLLVM/ops.mlir
index c1938cade8609..2d0f81f038879 100644
--- a/mlir/test/Conversion/MPIToLLVM/ops.mlir
+++ b/mlir/test/Conversion/MPIToLLVM/ops.mlir
@@ -1,80 +1,165 @@
-// RUN: mlir-opt -convert-to-llvm %s | FileCheck %s
+// RUN: mlir-opt -split-input-file -convert-to-llvm %s | FileCheck %s
 
-module attributes { mpi.dlti = #dlti.map<"MPI:Implemention" = "Intel"> } {
+// COM: Test MPICH ABI
+// CHECK: module attributes {mpi.dlti = #dlti.map<"MPI:Implementation" = "MPICH">} {
+// CHECK: llvm.func @MPI_Finalize() -> i32
+// CHECK: llvm.func @MPI_Recv(!llvm.ptr, i32, i32, i32, i32, i32, !llvm.ptr) -> i32
+// CHECK: llvm.func @MPI_Send(!llvm.ptr, i32, i32, i32, i32, i32) -> i32
+// CHECK: llvm.func @MPI_Comm_rank(i32, !llvm.ptr) -> i32
+// CHECK: llvm.func @MPI_Init(!llvm.ptr, !llvm.ptr) -> i32
+module attributes { mpi.dlti = #dlti.map<"MPI:Implementation" = "MPICH"> } {
+
+  // CHECK: llvm.func @mpi_test_mpich([[varg0:%.+]]: !llvm.ptr, [[varg1:%.+]]: !llvm.ptr, [[varg2:%.+]]: i64, [[varg3:%.+]]: i64, [[varg4:%.+]]: i64) {
+  func.func @mpi_test_mpich(%arg0: memref<100xf32>) {
+
+    // CHECK: [[v0:%.*]] = llvm.mlir.poison : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
+    // CHECK: [[v1:%.*]] = llvm.insertvalue [[varg0]], [[v0]][0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
+    // CHECK: [[v2:%.*]] = llvm.insertvalue [[varg1]], [[v1]][1] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
+    // CHECK: [[v3:%.*]] = llvm.insertvalue [[varg2]], [[v2]][2] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
+    // CHECK: [[v4:%.*]] = llvm.insertvalue [[varg3]], [[v3]][3, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
+    // CHECK: [[v5:%.*]] = llvm.insertvalue [[varg4]], [[v4]][4, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
+    // CHECK: [[v6:%.*]] = llvm.mlir.zero : !llvm.ptr
+    // CHECK: [[v7:%.*]] = llvm.call @MPI_Init([[v6]], [[v6]]) : (!llvm.ptr, !llvm.ptr) -> i32
+    %0 = mpi.init : !mpi.retval
+
+    // CHECK: [[v8:%.*]] = llvm.mlir.constant(1140850688 : i32) : i32
+    // CHECK: [[v9:%.*]] = llvm.mlir.constant(1 : i32) : i32
+    // CHECK: [[v10:%.*]] = llvm.alloca [[v9]] x i32 : (i32) -> !llvm.ptr
+    // CHECK: [[v11:%.*]] = llvm.call @MPI_Comm_rank([[v8]], [[v10]]) : (i32, !llvm.ptr) -> i32
+    %retval, %rank = mpi.comm_rank : !mpi.retval, i32
+
+    // CHECK: [[v12:%.*]] = llvm.load [[v10]] : !llvm.ptr -> i32
+    // CHECK: [[v13:%.*]] = llvm.extractvalue [[v5]][1] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
+    // CHECK: [[v14:%.*]] = llvm.extractvalue [[v5]][2] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
+    // CHECK: [[v15:%.*]] = llvm.getelementptr [[v13]][[[v14]]] : (!llvm.ptr, i64) -> !llvm.ptr, f32
+    // CHECK: [[v16:%.*]] = llvm.extractvalue [[v5]][3, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
+    // CHECK: [[v17:%.*]] = llvm.trunc [[v16]] : i64 to i32
+    // CHECK: [[v18:%.*]] = llvm.mlir.constant(1275069450 : i32) : i32
+    // CHECK: [[v19:%.*]] = llvm.mlir.constant(1140850688 : i32) : i32
+    // CHECK: [[v20:%.*]] = llvm.call @MPI_Send([[v15]], [[v17]], [[v18]], [[v12]], [[v12]], [[v19]]) : (!llvm.ptr, i32, i32, i32, i32, i32) -> i32
+    mpi.send(%arg0, %rank, %rank) : memref<100xf32>, i32, i32
 
+    // CHECK: [[v21:%.*]] = llvm.extractvalue [[v5]][1] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
+    // CHECK: [[v22:%.*]] = llvm.extractvalue [[v5]][2] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
+    // CHECK: [[v23:%.*]] = llvm.getelementptr [[v21]][[[v22]]] : (!llvm.ptr, i64) -> !llvm.ptr, f32
+    // CHECK: [[v24:%.*]] = llvm.extractvalue [[v5]][3, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
+    // CHECK: [[v25:%.*]] = llvm.trunc [[v24]] : i64 to i32
+    // CHECK: [[v26:%.*]] = llvm.mlir.constant(1275069450 : i32) : i32
+    // CHECK: [[v27:%.*]] = llvm.mlir.constant(1140850688 : i32) : i32
+    // CHECK: [[v28:%.*]] = llvm.call @MPI_Send([[v23]], [[v25]], [[v26]], [[v12]], [[v12]], [[v27]]) : (!llvm.ptr, i32, i32, i32, i32, i32) -> i32
+    %1 = mpi.send(%arg0, %rank, %rank) : memref<100xf32>, i32, i32 -> !mpi.retval
+
+    // CHECK: [[v29:%.*]] = llvm.extractvalue [[v5]][1] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
+    // CHECK: [[v30:%.*]] = llvm.extractvalue [[v5]][2] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
+    // CHECK: [[v31:%.*]] = llvm.getelementptr [[v29]][[[v30]]] : (!llvm.ptr, i64) -> !llvm.ptr, f32
+    // CHECK: [[v32:%.*]] = llvm.extractvalue [[v5]][3, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
+    // CHECK: [[v33:%.*]] = llvm.trunc [[v32]] : i64 to i32
+    // CHECK: [[v34:%.*]] = llvm.mlir.constant(1275069450 : i32) : i32
+    // CHECK: [[v35:%.*]] = llvm.mlir.constant(1140850688 : i32) : i32
+    // CHECK: [[v36:%.*]] = llvm.mlir.constant(1 : i64) : i64
+    // CHECK: [[v37:%.*]] = llvm.inttoptr [[v36]] : i64 to !llvm.ptr
+    // CHECK: [[v38:%.*]] = llvm.call @MPI_Recv([[v31]], [[v33]], [[v34]], [[v12]], [[v12]], [[v35]], [[v37]]) : (!llvm.ptr, i32, i32, i32, i32, i32, !llvm.ptr) -> i32
+    mpi.recv(%arg0, %rank, %rank) : memref<100xf32>, i32, i32
+
+    // CHECK: [[v39:%.*]] = llvm.extractvalue [[v5]][1] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
+    // CHECK: [[v40:%.*]] = llvm.extractvalue [[v5]][2] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
+    // CHECK: [[v41:%.*]] = llvm.getelementptr [[v39]][[[v40]]] : (!llvm.ptr, i64) -> !llvm.ptr, f32
+    // CHECK: [[v42:%.*]] = llvm.extractvalue [[v5]][3, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
+    // CHECK: [[v43:%.*]] = llvm.trunc [[v42]] : i64 to i32
+    // CHECK: [[v44:%.*]] = llvm.mlir.constant(1275069450 : i32) : i32
+    // CHECK: [[v45:%.*]] = llvm.mlir.constant(1140850688 : i32) : i32
+    // CHECK: [[v46:%.*]] = llvm.mlir.constant(1 : i64) : i64
+    // CHECK: [[v47:%.*]] = llvm.inttoptr [[v46]] : i64 to !llvm.ptr
+    // CHECK: [[v48:%.*]] = llvm.call @MPI_Recv([[v41]], [[v43]], [[v44]], [[v12]], [[v12]], [[v45]], [[v47]]) : (!llvm.ptr, i32, i32, i32, i32, i32, !llvm.ptr) -> i32
+    %2 = mpi.recv(%arg0, %rank, %rank) : memref<100xf32>, i32, i32 -> !mpi.retval
+
+    // CHECK: [[v49:%.*]] = llvm.call @MPI_Finalize() : () -> i32
+    %3 = mpi.finalize : !mpi.retval
+
+    return
+  }
+}
+
+// -----
+
+// COM: Test OpenMPI ABI
+// CHECK: module attributes {mpi.dlti = #dlti.map<"MPI:Implementation" = "OpenMPI">} {
 // CHECK: llvm.func @MPI_Finalize() -> i32
-// CHECK: llvm.func @MPI_Recv(!llvm.ptr, i32, {{.+}}, i32, i32, {{.+}}, !llvm.ptr) -> i32
-// CHECK: llvm.func @MPI_Send(!llvm.ptr, i32, {{.+}}, i32, i32, {{.+}}) -> i32
-// CHECK: llvm.func @MPI_Comm_rank({{.+}}, !llvm.ptr) -> i32
-// COMM: llvm.mlir.global external @MPI_COMM_WORLD() {addr_space = 0 : i32} : !llvm.struct<"MPI_ABI_Comm", opaque>
+// CHECK: llvm.func @MPI_Recv(!llvm.ptr, i32, !llvm.ptr, i32, i32, !llvm.ptr, !llvm.ptr) -> i32
+// CHECK: llvm.func @MPI_Send(!llvm.ptr, i32, !llvm.ptr, i32, i32, !llvm.ptr) -> i32
+// CHECK: llvm.mlir.global external @ompi_mpi_float() {addr_space = 0 : i32} : !llvm.struct<"ompi_predefined_datatype_t", opaque>
+// CHECK: llvm.func @MPI_Comm_rank(!llvm.ptr, !llvm.ptr) -> i32
+// CHECK: llvm.mlir.global external @ompi_mpi_comm_world() {addr_space = 0 : i32} : !llvm.struct<"ompi_communicator_t", opaque>
 // CHECK: llvm.func @MPI_Init(!llvm.ptr, !llvm.ptr) -> i32
+module attributes { mpi.dlti = #dlti.map<"MPI:Implementation" = "OpenMPI"> } {
+
+  // CHECK: llvm.func @mpi_test_openmpi([[varg0:%.+]]: !llvm.ptr, [[varg1:%.+]]: !llvm.ptr, [[varg2:%.+]]: i64, [[varg3:%.+]]: i64, [[varg4:%.+]]: i64) {
+  func.func @mpi_test_openmpi(%arg0: memref<100xf32>) {
+
+    // CHECK: [[v0:%.*]] = llvm.mlir.poison : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
+    // CHECK: [[v1:%.*]] = llvm.insertvalue [[varg0]], [[v0]][0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
+    // CHECK: [[v2:%.*]] = llvm.insertvalue [[varg1]], [[v1]][1] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
+    // CHECK: [[v3:%.*]] = llvm.insertvalue [[varg2]], [[v2]][2] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
+    // CHECK: [[v4:%.*]] = llvm.insertvalue [[varg3]], [[v3]][3, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
+    // CHECK: [[v5:%.*]] = llvm.insertvalue [[varg4]], [[v4]][4, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
+    // CHECK: [[v6:%.*]] = llvm.mlir.zero : !llvm.ptr
+    // CHECK: [[v7:%.*]] = llvm.call @MPI_Init([[v6]], [[v6]]) : (!llvm.ptr, !llvm.ptr) -> i32
+    %0 = mpi.init : !mpi.retval
+
+    // CHECK: [[v8:%.*]] = llvm.mlir.addressof @ompi_mpi_comm_world : !llvm.ptr
+    // CHECK: [[v9:%.*]] = llvm.mlir.constant(1 : i32) : i32
+    // CHECK: [[v10:%.*]] = llvm.alloca [[v9]] x i32 : (i32) -> !llvm.ptr
+    // CHECK: [[v11:%.*]] = llvm.call @MPI_Comm_rank([[v8]], [[v10]]) : (!llvm.ptr, !llvm.ptr) -> i32
+    %retval, %rank = mpi.comm_rank : !mpi.retval, i32
+
+    // CHECK: [[v12:%.*]] = llvm.load [[v10]] : !llvm.ptr -> i32
+    // CHECK: [[v13:%.*]] = llvm.extractvalue [[v5]][1] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
+    // CHECK: [[v14:%.*]] = llvm.extractvalue [[v5]][2] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
+    // CHECK: [[v15:%.*]] = llvm.getelementptr [[v13]][[[v14]]] : (!llvm.ptr, i64) -> !llvm.ptr, f32
+    // CHECK: [[v16:%.*]] = llvm.extractvalue [[v5]][3, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
+    // CHECK: [[v17:%.*]] = llvm.trunc [[v16]] : i64 to i32
+    // CHECK: [[v18:%.*]] = llvm.mlir.addressof @ompi_mpi_float : !llvm.ptr
+    // CHECK: [[v19:%.*]] = llvm.mlir.addressof @ompi_mpi_comm_world : !llvm.ptr
+    // CHECK: [[v20:%.*]] = llvm.call @MPI_Send([[v15]], [[v17]], [[v18]], [[v12]], [[v12]], [[v19]]) : (!llvm.ptr, i32, !llvm.ptr, i32, i32, !llvm.ptr) -> i32
+    mpi.send(%arg0, %rank, %rank) : memref<100xf32>, i32, i32
+
+    // CHECK: [[v21:%.*]] = llvm.extractvalue [[v5]][1] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
+    // CHECK: [[v22:%.*]] = llvm.extractvalue [[v5]][2] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
+    // CHECK: [[v23:%.*]] = llvm.getelementptr [[v21]][[[v22]]] : (!llvm.ptr, i64) -> !llvm.ptr, f32
+    // CHECK: [[v24:%.*]] = llvm.extractvalue [[v5]][3, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
+    // CHECK: [[v25:%.*]] = llvm.trunc [[v24]] : i64 to i32
+    // CHECK: [[v26:%.*]] = llvm.mlir.addressof @ompi_mpi_float : !llvm.ptr
+    // CHECK: [[v27:%.*]] = llvm.mlir.addressof @ompi_mpi_comm_world : !llvm.ptr
+    // CHECK: [[v28:%.*]] = llvm.call @MPI_Send([[v23]], [[v25]], [[v26]], [[v12]], [[v12]], [[v27]]) : (!llvm.ptr, i32, !llvm.ptr, i32, i32, !llvm.ptr) -> i32
+    %1 = mpi.send(%arg0, %rank, %rank) : memref<100xf32>, i32, i32 -> !mpi.retval
+
+    // CHECK: [[v29:%.*]] = llvm.extractvalue [[v5]][1] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
+    // CHECK: [[v30:%.*]] = llvm.extractvalue [[v5]][2] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
+    // CHECK: [[v31:%.*]] = llvm.getelementptr [[v29]][[[v30]]] : (!llvm.ptr, i64) -> !llvm.ptr, f32
+    // CHECK: [[v32:%.*]] = llvm.extractvalue [[v5]][3, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
+    // CHECK: [[v33:%.*]] = llvm.trunc [[v32]] : i64 to i32
+    // CHECK: [[v34:%.*]] = llvm.mlir.addressof @ompi_mpi_float : !llvm.ptr
+    // CHECK: [[v35:%.*]] = llvm.mlir.addressof @ompi_mpi_comm_world : !llvm.ptr
+    // CHECK: [[v36:%.*]] = llvm.mlir.constant(0 : i64) : i64
+    // CHECK: [[v37:%.*]] = llvm.inttoptr [[v36]] : i64 to !llvm.ptr
+    // CHECK: [[v38:%.*]] = llvm.call @MPI_Recv([[v31]], [[v33]], [[v34]], [[v12]], [[v12]], [[v35]], [[v37]]) : (!llvm.ptr, i32, !llvm.ptr, i32, i32, !llvm.ptr, !llvm.ptr) -> i32
+    mpi.recv(%arg0, %rank, %rank) : memref<100xf32>, i32, i32
+
+    // CHECK: [[v39:%.*]] = llvm.extractvalue [[v5]][1] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
+    // CHECK: [[v40:%.*]] = llvm.extractvalue [[v5]][2] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
+    // CHECK: [[v41:%.*]] = llvm.getelementptr [[v39]][[[v40]]] : (!llvm.ptr, i64) -> !llvm.ptr, f32
+    // CHECK: [[v42:%.*]] = llvm.extractvalue [[v5]][3, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
+    // CHECK: [[v43:%.*]] = llvm.trunc [[v42]] : i64 to i32
+    // CHECK: [[v44:%.*]] = llvm.mlir.addressof @ompi_mpi_float : !llvm.ptr
+    // CHECK: [[v45:%.*]] = llvm.mlir.addressof @ompi_mpi_comm_world : !llvm.ptr
+    // CHECK: [[v46:%.*]] = llvm.mlir.constant(0 : i64) : i64
+    // CHECK: [[v47:%.*]] = llvm.inttoptr [[v46]] : i64 to !llvm.ptr
+    // CHECK: [[v48:%.*]] = llvm.call @MPI_Recv([[v41]], [[v43]], [[v44]], [[v12]], [[v12]], [[v45]], [[v47]]) : (!llvm.ptr, i32, !llvm.ptr, i32, i32, !llvm.ptr, !llvm.ptr) -> i32
+    %2 = mpi.recv(%arg0, %rank, %rank) : memref<100xf32>, i32, i32 -> !mpi.retval
+
+    // CHECK: [[v49:%.*]] = llvm.call @MPI_Finalize() : () -> i32
+    %3 = mpi.finalize : !mpi.retval
 
-func.func @mpi_test(%arg0: memref<100xf32>) {
-  // CHECK: [[varg0:%.*]]: !llvm.ptr, [[varg1:%.*]]: !llvm.ptr, [[varg2:%.*]]: i64, [[varg3:%.*]]: i64, [[varg4:%.*]]: i64
-  // CHECK: [[v0:%.*]] = llvm.mlir.poison : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
-  // CHECK: [[v1:%.*]] = llvm.insertvalue [[varg0]], [[v0]][0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> 
-  // CHECK: [[v2:%.*]] = llvm.insertvalue [[varg1]], [[v1]][1] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> 
-  // CHECK: [[v3:%.*]] = llvm.insertvalue [[varg2]], [[v2]][2] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> 
-  // CHECK: [[v4:%.*]] = llvm.insertvalue [[varg3]], [[v3]][3, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> 
-  // CHECK: [[v5:%.*]] = llvm.insertvalue [[varg4]], [[v4]][4, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> 
-  // CHECK: [[v6:%.*]] = llvm.mlir.zero : !llvm.ptr
-  // CHECK: [[v7:%.*]] = llvm.call @MPI_Init([[v6]], [[v6]]) : (!llvm.ptr, !llvm.ptr) -> i32
-  %0 = mpi.init : !mpi.retval
-
-  // CHECK: [[v9:%.*]] = llvm.mlir.
-  // CHECK: [[v10:%.*]] = llvm.mlir.constant(1 : i32) : i32
-  // CHECK: [[v11:%.*]] = llvm.alloca [[v10]] x i32 : (i32) -> !llvm.ptr
-  // CHECK: [[v12:%.*]] = llvm.call @MPI_Comm_rank([[v9]], [[v11]]) : ({{.+}}, !llvm.ptr) -> i32
-  // CHECK: [[v14:%.*]] = llvm.load [[v11]] : !llvm.ptr -> i32
-  %retval, %rank = mpi.comm_rank : !mpi.retval, i32
-
-  // CHECK: [[v15:%.*]] = llvm.extractvalue [[v5]][1] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> 
-  // CHECK: [[v16:%.*]] = llvm.extractvalue [[v5]][2] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> 
-  // CHECK: [[v17:%.*]] = llvm.getelementptr [[v15]][[[v16]]] : (!llvm.ptr, i64) -> !llvm.ptr, f32
-  // CHECK: [[v18:%.*]] = llvm.extractvalue [[v5]][3, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> 
-  // CHECK: [[v19:%.*]] = llvm.trunc [[v18]] : i64 to i32
-  // CHECK: [[v20:%.*]] = llvm.mlir.
-  // CHECK: [[v21:%.*]] = llvm.mlir.
-  // CHECK: [[v22:%.*]] = llvm.call @MPI_Send([[v17]], [[v19]], [[v20]], [[v14]], [[v14]], [[v21]])
-  mpi.send(%arg0, %rank, %rank) : memref<100xf32>, i32, i32
-
-  // CHECK: [[v23:%.*]] = llvm.extractvalue [[v5]][1] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> 
-  // CHECK: [[v24:%.*]] = llvm.extractvalue [[v5]][2] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> 
-  // CHECK: [[v25:%.*]] = llvm.getelementptr [[v23]][[[v24]]] : (!llvm.ptr, i64) -> !llvm.ptr, f32
-  // CHECK: [[v26:%.*]] = llvm.extractvalue [[v5]][3, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> 
-  // CHECK: [[v27:%.*]] = llvm.trunc [[v26]] : i64 to i32
-  // CHECK: [[v28:%.*]] = llvm.mlir.
-  // CHECK: [[v29:%.*]] = llvm.mlir.
-  // CHECK: [[v30:%.*]] = llvm.call @MPI_Send([[v25]], [[v27]], [[v28]], [[v14]], [[v14]], [[v29]])
-  %1 = mpi.send(%arg0, %rank, %rank) : memref<100xf32>, i32, i32 -> !mpi.retval
-
-  // CHECK: [[v31:%.*]] = llvm.extractvalue [[v5]][1] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> 
-  // CHECK: [[v32:%.*]] = llvm.extractvalue [[v5]][2] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> 
-  // CHECK: [[v33:%.*]] = llvm.getelementptr [[v31]][[[v32]]] : (!llvm.ptr, i64) -> !llvm.ptr, f32
-  // CHECK: [[v34:%.*]] = llvm.extractvalue [[v5]][3, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> 
-  // CHECK: [[v35:%.*]] = llvm.trunc [[v34]] : i64 to i32
-  // CHECK: [[v36:%.*]] = llvm.mlir.
-  // CHECK: [[v37:%.*]] = llvm.mlir.
-  // CHECK: [[v38:%.*]] = llvm.mlir.constant({{[0-9]+}} : i64) : i64
-  // CHECK: [[v39:%.*]] = llvm.inttoptr [[v38]] : i64 to !llvm.ptr
-  // CHECK: [[v40:%.*]] = llvm.call @MPI_Recv([[v33]], [[v35]], [[v36]], [[v14]], [[v14]], [[v37]], [[v39]])
-  mpi.recv(%arg0, %rank, %rank) : memref<100xf32>, i32, i32
-
-  // CHECK: [[v41:%.*]] = llvm.extractvalue [[v5]][1] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> 
-  // CHECK: [[v42:%.*]] = llvm.extractvalue [[v5]][2] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> 
-  // CHECK: [[v43:%.*]] = llvm.getelementptr [[v41]][[[v42]]] : (!llvm.ptr, i64) -> !llvm.ptr, f32
-  // CHECK: [[v44:%.*]] = llvm.extractvalue [[v5]][3, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> 
-  // CHECK: [[v45:%.*]] = llvm.trunc [[v44]] : i64 to i32
-  // CHECK: [[v46:%.*]] = llvm.mlir.
-  // CHECK: [[v47:%.*]] = llvm.mlir.
-  // CHECK: [[v48:%.*]] = llvm.mlir.constant({{[0-9]+}} : i64) : i64
-  // CHECK: [[v49:%.*]] = llvm.inttoptr [[v48]] : i64 to !llvm.ptr
-  // CHECK: [[v50:%.*]] = llvm.call @MPI_Recv([[v43]], [[v45]], [[v46]], [[v14]], [[v14]], [[v47]], [[v49]])
-  %2 = mpi.recv(%arg0, %rank, %rank) : memref<100xf32>, i32, i32 -> !mpi.retval
-
-  // CHECK: llvm.call @MPI_Finalize() : () -> i32
-  %3 = mpi.finalize : !mpi.retval
-
-  return
+    return
+  }
 }
-}
\ No newline at end of file

>From 5f2644a39c68609829814bfbfd9893bd36fd40b7 Mon Sep 17 00:00:00 2001
From: "Schlimbach, Frank" <frank.schlimbach at intel.com>
Date: Wed, 19 Feb 2025 19:48:52 +0100
Subject: [PATCH 12/15] warning about unknown MPI implementation

---
 mlir/lib/Conversion/MPIToLLVM/MPIToLLVM.cpp | 4 ++++
 1 file changed, 4 insertions(+)

diff --git a/mlir/lib/Conversion/MPIToLLVM/MPIToLLVM.cpp b/mlir/lib/Conversion/MPIToLLVM/MPIToLLVM.cpp
index 37cda94829655..3e8d7e820e6d6 100644
--- a/mlir/lib/Conversion/MPIToLLVM/MPIToLLVM.cpp
+++ b/mlir/lib/Conversion/MPIToLLVM/MPIToLLVM.cpp
@@ -202,6 +202,10 @@ struct MPIImplTraits {
     if (strAttr && strAttr.getValue() == "OpenMPI") {
       return OMPI;
     }
+    if (!strAttr || strAttr.getValue() != "MPICH") {
+      moduleOp.emitWarning("Unknown \"MPI:Implementation\" specified in DLTI, "
+                           "defaulting to MPICH");
+    }
     return MPICH;
   }
 

>From 4e2a795f9bac155c71fcf317be1ff21d598b6a2d Mon Sep 17 00:00:00 2001
From: Frank Schlimbach <frank.schlimbach at intel.com>
Date: Fri, 21 Feb 2025 09:45:48 +0100
Subject: [PATCH 13/15] Apply suggestions from code review

Co-authored-by: Christian Ulmann <christianulmann at gmail.com>
---
 mlir/lib/Conversion/MPIToLLVM/MPIToLLVM.cpp | 24 +++++++++------------
 1 file changed, 10 insertions(+), 14 deletions(-)

diff --git a/mlir/lib/Conversion/MPIToLLVM/MPIToLLVM.cpp b/mlir/lib/Conversion/MPIToLLVM/MPIToLLVM.cpp
index 3e8d7e820e6d6..5c986ddc7cafd 100644
--- a/mlir/lib/Conversion/MPIToLLVM/MPIToLLVM.cpp
+++ b/mlir/lib/Conversion/MPIToLLVM/MPIToLLVM.cpp
@@ -50,7 +50,7 @@ static LLVM::LLVMFuncOp getOrDefineFunction(ModuleOp &moduleOp,
 // Implementation details for MPICH ABI compatible MPI implementations
 //===----------------------------------------------------------------------===//
 struct MPICHImplTraits {
-  static const int MPI_FLOAT = 0x4c00040a;
+  static constexpr int MPI_FLOAT = 0x4c00040a;
   static const int MPI_DOUBLE = 0x4c00080b;
   static const int MPI_INT8_T = 0x4c000137;
   static const int MPI_INT16_T = 0x4c000238;
@@ -182,34 +182,31 @@ struct OMPIImplTraits {
   }
 };
 
-//===----------------------------------------------------------------------===//
-// When lowering the mpi dialect to functions calls certain details
-// differ between various MPI implementations. This class will provide
-// these in a generic way, depending on the MPI implementation that got
-// selected by the DLTI attribute on the module.
-//===----------------------------------------------------------------------===//
+/// When lowering the mpi dialect to functions calls certain details
+/// differ between various MPI implementations. This class will provide
+/// these in a generic way, depending on the MPI implementation that got
+/// selected by the DLTI attribute on the module.
 struct MPIImplTraits {
   enum MPIImpl { MPICH, OMPI };
 
-  // Get the MPI implementation from a DLTI attribute on the module.
-  // Default to MPICH (and ABI compatible).
+  /// Gets the MPI implementation from a DLTI attribute on the module.
+  /// Defaults to MPICH (and ABI compatible).
   static MPIImpl getMPIImpl(mlir::ModuleOp &moduleOp) {
     auto attr = dlti::query(*&moduleOp, {"MPI:Implementation"}, true);
-    if (failed(attr)) {
+    if (failed(attr))
       return MPICH;
-    }
     auto strAttr = dyn_cast<StringAttr>(attr.value());
     if (strAttr && strAttr.getValue() == "OpenMPI") {
       return OMPI;
     }
     if (!strAttr || strAttr.getValue() != "MPICH") {
-      moduleOp.emitWarning("Unknown \"MPI:Implementation\" specified in DLTI, "
+      moduleOp.emitWarning() << "Unknown \"MPI:Implementation\" value in DLTI (" << strAttr.getValue() << "), "
                            "defaulting to MPICH");
     }
     return MPICH;
   }
 
-  // get/create MPI_COMM_WORLD as a mlir::Value
+  /// Gets or creates MPI_COMM_WORLD as a mlir::Value.
   static mlir::Value getCommWorld(mlir::ModuleOp &moduleOp,
                                   const mlir::Location loc,
                                   mlir::ConversionPatternRewriter &rewriter) {
@@ -250,7 +247,6 @@ struct InitOpLowering : public ConvertOpToLLVMPattern<mpi::InitOp> {
   LogicalResult
   matchAndRewrite(mpi::InitOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
-    // get loc
     Location loc = op.getLoc();
 
     // ptrType `!llvm.ptr`

>From 2a8745be3069bdc433c2d8c6c731d0ba9c58833b Mon Sep 17 00:00:00 2001
From: "Schlimbach, Frank" <frank.schlimbach at intel.com>
Date: Fri, 21 Feb 2025 10:01:46 +0100
Subject: [PATCH 14/15] fixing formatting and typo

---
 mlir/lib/Conversion/MPIToLLVM/MPIToLLVM.cpp | 58 +++++++++------------
 1 file changed, 26 insertions(+), 32 deletions(-)

diff --git a/mlir/lib/Conversion/MPIToLLVM/MPIToLLVM.cpp b/mlir/lib/Conversion/MPIToLLVM/MPIToLLVM.cpp
index 5c986ddc7cafd..e1c693912dabf 100644
--- a/mlir/lib/Conversion/MPIToLLVM/MPIToLLVM.cpp
+++ b/mlir/lib/Conversion/MPIToLLVM/MPIToLLVM.cpp
@@ -49,17 +49,18 @@ static LLVM::LLVMFuncOp getOrDefineFunction(ModuleOp &moduleOp,
 //===----------------------------------------------------------------------===//
 // Implementation details for MPICH ABI compatible MPI implementations
 //===----------------------------------------------------------------------===//
+
 struct MPICHImplTraits {
   static constexpr int MPI_FLOAT = 0x4c00040a;
-  static const int MPI_DOUBLE = 0x4c00080b;
-  static const int MPI_INT8_T = 0x4c000137;
-  static const int MPI_INT16_T = 0x4c000238;
-  static const int MPI_INT32_T = 0x4c000439;
-  static const int MPI_INT64_T = 0x4c00083a;
-  static const int MPI_UINT8_T = 0x4c00013b;
-  static const int MPI_UINT16_T = 0x4c00023c;
-  static const int MPI_UINT32_T = 0x4c00043d;
-  static const int MPI_UINT64_T = 0x4c00083e;
+  static constexpr int MPI_DOUBLE = 0x4c00080b;
+  static constexpr int MPI_INT8_T = 0x4c000137;
+  static constexpr int MPI_INT16_T = 0x4c000238;
+  static constexpr int MPI_INT32_T = 0x4c000439;
+  static constexpr int MPI_INT64_T = 0x4c00083a;
+  static constexpr int MPI_UINT8_T = 0x4c00013b;
+  static constexpr int MPI_UINT16_T = 0x4c00023c;
+  static constexpr int MPI_UINT32_T = 0x4c00043d;
+  static constexpr int MPI_UINT64_T = 0x4c00083e;
 
   static mlir::Value getCommWorld(mlir::ModuleOp &moduleOp,
                                   const mlir::Location loc,
@@ -196,13 +197,11 @@ struct MPIImplTraits {
     if (failed(attr))
       return MPICH;
     auto strAttr = dyn_cast<StringAttr>(attr.value());
-    if (strAttr && strAttr.getValue() == "OpenMPI") {
+    if (strAttr && strAttr.getValue() == "OpenMPI")
       return OMPI;
-    }
-    if (!strAttr || strAttr.getValue() != "MPICH") {
-      moduleOp.emitWarning() << "Unknown \"MPI:Implementation\" value in DLTI (" << strAttr.getValue() << "), "
-                           "defaulting to MPICH");
-    }
+    if (!strAttr || strAttr.getValue() != "MPICH")
+      moduleOp.emitWarning() << "Unknown \"MPI:Implementation\" value in DLTI ("
+                             << strAttr.getValue() << "), defaulting to MPICH";
     return MPICH;
   }
 
@@ -210,29 +209,26 @@ struct MPIImplTraits {
   static mlir::Value getCommWorld(mlir::ModuleOp &moduleOp,
                                   const mlir::Location loc,
                                   mlir::ConversionPatternRewriter &rewriter) {
-    if (MPIImplTraits::getMPIImpl(moduleOp) == OMPI) {
+    if (MPIImplTraits::getMPIImpl(moduleOp) == OMPI)
       return OMPIImplTraits::getCommWorld(moduleOp, loc, rewriter);
-    }
     return MPICHImplTraits::getCommWorld(moduleOp, loc, rewriter);
   }
 
-  // Get the MPI_STATUS_IGNORE value (typically a pointer type).
+  /// Get the MPI_STATUS_IGNORE value (typically a pointer type).
   static intptr_t getStatusIgnore(mlir::ModuleOp &moduleOp) {
-    if (MPIImplTraits::getMPIImpl(moduleOp) == OMPI) {
+    if (MPIImplTraits::getMPIImpl(moduleOp) == OMPI)
       return OMPIImplTraits::getStatusIgnore();
-    }
     return MPICHImplTraits::getStatusIgnore();
   }
 
-  // get/create MPI datatype as a mlir::Value which corresponds to the given
-  // mlir::Type
+  /// get/create MPI datatype as a mlir::Value which corresponds to the given
+  /// mlir::Type
   static mlir::Value getDataType(mlir::ModuleOp &moduleOp,
                                  const mlir::Location loc,
                                  mlir::ConversionPatternRewriter &rewriter,
                                  mlir::Type type) {
-    if (MPIImplTraits::getMPIImpl(moduleOp) == OMPI) {
+    if (MPIImplTraits::getMPIImpl(moduleOp) == OMPI)
       return OMPIImplTraits::getDataType(moduleOp, loc, rewriter, type);
-    }
     return MPICHImplTraits::getDataType(moduleOp, loc, rewriter, type);
   }
 };
@@ -347,9 +343,9 @@ struct CommRankOpLowering : public ConvertOpToLLVMPattern<mpi::CommRankOp> {
     // if retval is checked, replace uses of retval with the results from the
     // call op
     SmallVector<Value> replacements;
-    if (op.getRetval()) {
+    if (op.getRetval())
       replacements.push_back(callOp.getResult());
-    }
+
     // replace all uses, then erase op
     replacements.push_back(loadedRank.getRes());
     rewriter.replaceOp(op, replacements);
@@ -408,11 +404,10 @@ struct SendOpLowering : public ConvertOpToLLVMPattern<mpi::SendOp> {
         loc, funcDecl,
         ValueRange{dataPtr, size, dataType, adaptor.getDest(), adaptor.getTag(),
                    commWorld});
-    if (op.getRetval()) {
+    if (op.getRetval())
       rewriter.replaceOp(op, funcCall.getResult());
-    } else {
+    else
       rewriter.eraseOp(op);
-    }
 
     return success();
   }
@@ -473,11 +468,10 @@ struct RecvOpLowering : public ConvertOpToLLVMPattern<mpi::RecvOp> {
         loc, funcDecl,
         ValueRange{dataPtr, size, dataType, adaptor.getSource(),
                    adaptor.getTag(), commWorld, statusIgnore});
-    if (op.getRetval()) {
+    if (op.getRetval())
       rewriter.replaceOp(op, funcCall.getResult());
-    } else {
+    else
       rewriter.eraseOp(op);
-    }
 
     return success();
   }

>From eaebf5e48320c803a96df91e6380857a7792a8b5 Mon Sep 17 00:00:00 2001
From: "Schlimbach, Frank" <frank.schlimbach at intel.com>
Date: Fri, 21 Feb 2025 15:56:34 +0100
Subject: [PATCH 15/15] using virtual dispatch for MPIImplTraits; cleanup

---
 mlir/lib/Conversion/MPIToLLVM/MPIToLLVM.cpp | 206 ++++++++++----------
 1 file changed, 98 insertions(+), 108 deletions(-)

diff --git a/mlir/lib/Conversion/MPIToLLVM/MPIToLLVM.cpp b/mlir/lib/Conversion/MPIToLLVM/MPIToLLVM.cpp
index e1c693912dabf..f1bd8562db4f4 100644
--- a/mlir/lib/Conversion/MPIToLLVM/MPIToLLVM.cpp
+++ b/mlir/lib/Conversion/MPIToLLVM/MPIToLLVM.cpp
@@ -19,13 +19,14 @@
 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
 #include "mlir/Dialect/MPI/IR/MPI.h"
 #include "mlir/Transforms/DialectConversion.h"
+#include <memory>
 
 using namespace mlir;
 
 namespace {
 
 template <typename Op, typename... Args>
-static Op getOrDefineGlobal(mlir::ModuleOp &moduleOp, const Location loc,
+static Op getOrDefineGlobal(ModuleOp &moduleOp, const Location loc,
                             ConversionPatternRewriter &rewriter, StringRef name,
                             Args &&...args) {
   Op ret;
@@ -46,11 +47,40 @@ static LLVM::LLVMFuncOp getOrDefineFunction(ModuleOp &moduleOp,
       moduleOp, loc, rewriter, name, name, type, LLVM::Linkage::External);
 }
 
+/// When lowering the mpi dialect to function calls, certain details
+/// differ between various MPI implementations. This class will provide
+/// these in a generic way, depending on the MPI implementation that got
+/// selected by the DLTI attribute on the module.
+class MPIImplTraits {
+  ModuleOp &moduleOp;
+
+public:
+  /// Creates the MPIImplTraits matching the DLTI attribute on the module.
+  static std::unique_ptr<MPIImplTraits> get(ModuleOp &moduleOp);
+
+  MPIImplTraits(ModuleOp &moduleOp) : moduleOp(moduleOp) {}
+  /// Virtual because get() hands out instances owned via a base pointer.
+  virtual ~MPIImplTraits() = default;
+
+  ModuleOp &getModuleOp() { return moduleOp; }
+
+  /// Gets or creates MPI_COMM_WORLD as a Value.
+  virtual Value getCommWorld(const Location loc,
+                             ConversionPatternRewriter &rewriter) = 0;
+
+  /// Get the MPI_STATUS_IGNORE value (typically a pointer type).
+  virtual intptr_t getStatusIgnore() = 0;
+
+  /// Gets or creates the MPI datatype Value corresponding to the given Type.
+  virtual Value getDataType(const Location loc,
+                            ConversionPatternRewriter &rewriter, Type type) = 0;
+};
+
 //===----------------------------------------------------------------------===//
 // Implementation details for MPICH ABI compatible MPI implementations
 //===----------------------------------------------------------------------===//
 
-struct MPICHImplTraits {
+class MPICHImplTraits : public MPIImplTraits {
   static constexpr int MPI_FLOAT = 0x4c00040a;
   static constexpr int MPI_DOUBLE = 0x4c00080b;
   static constexpr int MPI_INT8_T = 0x4c000137;
@@ -62,20 +92,20 @@ struct MPICHImplTraits {
   static constexpr int MPI_UINT32_T = 0x4c00043d;
   static constexpr int MPI_UINT64_T = 0x4c00083e;
 
-  static mlir::Value getCommWorld(mlir::ModuleOp &moduleOp,
-                                  const mlir::Location loc,
-                                  mlir::ConversionPatternRewriter &rewriter) {
+public:
+  using MPIImplTraits::MPIImplTraits;
+
+  Value getCommWorld(const Location loc,
+                     ConversionPatternRewriter &rewriter) override {
     static const int MPI_COMM_WORLD = 0x44000000;
-    return rewriter.create<mlir::LLVM::ConstantOp>(loc, rewriter.getI32Type(),
-                                                   MPI_COMM_WORLD);
+    return rewriter.create<LLVM::ConstantOp>(loc, rewriter.getI32Type(),
+                                             MPI_COMM_WORLD);
   }
 
-  static intptr_t getStatusIgnore() { return 1; }
+  intptr_t getStatusIgnore() override { return 1; }
 
-  static mlir::Value getDataType(mlir::ModuleOp &moduleOp,
-                                 const mlir::Location loc,
-                                 mlir::ConversionPatternRewriter &rewriter,
-                                 mlir::Type type) {
+  Value getDataType(const Location loc, ConversionPatternRewriter &rewriter,
+                    Type type) override {
     int32_t mtype = 0;
     if (type.isF32())
       mtype = MPI_FLOAT;
@@ -99,53 +129,50 @@ struct MPICHImplTraits {
       mtype = MPI_UINT8_T;
     else
       assert(false && "unsupported type");
-    return rewriter.create<mlir::LLVM::ConstantOp>(loc, rewriter.getI32Type(),
-                                                   mtype);
+    return rewriter.create<LLVM::ConstantOp>(loc, rewriter.getI32Type(), mtype);
   }
 };
 
 //===----------------------------------------------------------------------===//
 // Implementation details for OpenMPI
 //===----------------------------------------------------------------------===//
-struct OMPIImplTraits {
-
-  static mlir::LLVM::GlobalOp
-  getOrDefineExternalStruct(mlir::ModuleOp &moduleOp, const mlir::Location loc,
-                            mlir::ConversionPatternRewriter &rewriter,
-                            mlir::StringRef name,
-                            mlir::LLVM::LLVMStructType type) {
-
-    return getOrDefineGlobal<mlir::LLVM::GlobalOp>(
-        moduleOp, loc, rewriter, name, type, /*isConstant=*/false,
-        mlir::LLVM::Linkage::External, name,
-        /*value=*/mlir::Attribute(), /*alignment=*/0, 0);
+class OMPIImplTraits : public MPIImplTraits {
+  LLVM::GlobalOp getOrDefineExternalStruct(const Location loc,
+                                           ConversionPatternRewriter &rewriter,
+                                           StringRef name,
+                                           LLVM::LLVMStructType type) {
+
+    return getOrDefineGlobal<LLVM::GlobalOp>(
+        getModuleOp(), loc, rewriter, name, type, /*isConstant=*/false,
+        LLVM::Linkage::External, name,
+        /*value=*/Attribute(), /*alignment=*/0, 0);
   }
 
-  static mlir::Value getCommWorld(mlir::ModuleOp &moduleOp,
-                                  const mlir::Location loc,
-                                  mlir::ConversionPatternRewriter &rewriter) {
+public:
+  using MPIImplTraits::MPIImplTraits;
+
+  Value getCommWorld(const Location loc,
+                     ConversionPatternRewriter &rewriter) override {
     auto context = rewriter.getContext();
     // get external opaque struct pointer type
     auto commStructT =
-        mlir::LLVM::LLVMStructType::getOpaque("ompi_communicator_t", context);
-    mlir::StringRef name = "ompi_mpi_comm_world";
+        LLVM::LLVMStructType::getOpaque("ompi_communicator_t", context);
+    StringRef name = "ompi_mpi_comm_world";
 
     // make sure global op definition exists
-    (void)getOrDefineExternalStruct(moduleOp, loc, rewriter, name, commStructT);
+    (void)getOrDefineExternalStruct(loc, rewriter, name, commStructT);
 
     // get address of symbol
-    return rewriter.create<mlir::LLVM::AddressOfOp>(
-        loc, mlir::LLVM::LLVMPointerType::get(context),
-        mlir::SymbolRefAttr::get(context, name));
+    return rewriter.create<LLVM::AddressOfOp>(
+        loc, LLVM::LLVMPointerType::get(context),
+        SymbolRefAttr::get(context, name));
   }
 
-  static intptr_t getStatusIgnore() { return 0; }
+  intptr_t getStatusIgnore() override { return 0; }
 
-  static mlir::Value getDataType(mlir::ModuleOp &moduleOp,
-                                 const mlir::Location loc,
-                                 mlir::ConversionPatternRewriter &rewriter,
-                                 mlir::Type type) {
-    mlir::StringRef mtype;
+  Value getDataType(const Location loc, ConversionPatternRewriter &rewriter,
+                    Type type) override {
+    StringRef mtype;
     if (type.isF32())
       mtype = "ompi_mpi_float";
     else if (type.isF64())
@@ -171,67 +198,29 @@ struct OMPIImplTraits {
 
     auto context = rewriter.getContext();
     // get external opaque struct pointer type
-    auto commStructT = mlir::LLVM::LLVMStructType::getOpaque(
-        "ompi_predefined_datatype_t", context);
+    auto commStructT =
+        LLVM::LLVMStructType::getOpaque("ompi_predefined_datatype_t", context);
     // make sure global op definition exists
-    (void)getOrDefineExternalStruct(moduleOp, loc, rewriter, mtype,
-                                    commStructT);
+    (void)getOrDefineExternalStruct(loc, rewriter, mtype, commStructT);
     // get address of symbol
-    return rewriter.create<mlir::LLVM::AddressOfOp>(
-        loc, mlir::LLVM::LLVMPointerType::get(context),
-        mlir::SymbolRefAttr::get(context, mtype));
+    return rewriter.create<LLVM::AddressOfOp>(
+        loc, LLVM::LLVMPointerType::get(context),
+        SymbolRefAttr::get(context, mtype));
   }
 };
 
-/// When lowering the mpi dialect to functions calls certain details
-/// differ between various MPI implementations. This class will provide
-/// these in a generic way, depending on the MPI implementation that got
-/// selected by the DLTI attribute on the module.
-struct MPIImplTraits {
-  enum MPIImpl { MPICH, OMPI };
-
-  /// Gets the MPI implementation from a DLTI attribute on the module.
-  /// Defaults to MPICH (and ABI compatible).
-  static MPIImpl getMPIImpl(mlir::ModuleOp &moduleOp) {
-    auto attr = dlti::query(*&moduleOp, {"MPI:Implementation"}, true);
-    if (failed(attr))
-      return MPICH;
-    auto strAttr = dyn_cast<StringAttr>(attr.value());
-    if (strAttr && strAttr.getValue() == "OpenMPI")
-      return OMPI;
-    if (!strAttr || strAttr.getValue() != "MPICH")
-      moduleOp.emitWarning() << "Unknown \"MPI:Implementation\" value in DLTI ("
-                             << strAttr.getValue() << "), defaulting to MPICH";
-    return MPICH;
-  }
-
-  /// Gets or creates MPI_COMM_WORLD as a mlir::Value.
-  static mlir::Value getCommWorld(mlir::ModuleOp &moduleOp,
-                                  const mlir::Location loc,
-                                  mlir::ConversionPatternRewriter &rewriter) {
-    if (MPIImplTraits::getMPIImpl(moduleOp) == OMPI)
-      return OMPIImplTraits::getCommWorld(moduleOp, loc, rewriter);
-    return MPICHImplTraits::getCommWorld(moduleOp, loc, rewriter);
-  }
-
-  /// Get the MPI_STATUS_IGNORE value (typically a pointer type).
-  static intptr_t getStatusIgnore(mlir::ModuleOp &moduleOp) {
-    if (MPIImplTraits::getMPIImpl(moduleOp) == OMPI)
-      return OMPIImplTraits::getStatusIgnore();
-    return MPICHImplTraits::getStatusIgnore();
-  }
-
-  /// get/create MPI datatype as a mlir::Value which corresponds to the given
-  /// mlir::Type
-  static mlir::Value getDataType(mlir::ModuleOp &moduleOp,
-                                 const mlir::Location loc,
-                                 mlir::ConversionPatternRewriter &rewriter,
-                                 mlir::Type type) {
-    if (MPIImplTraits::getMPIImpl(moduleOp) == OMPI)
-      return OMPIImplTraits::getDataType(moduleOp, loc, rewriter, type);
-    return MPICHImplTraits::getDataType(moduleOp, loc, rewriter, type);
-  }
-};
+std::unique_ptr<MPIImplTraits> MPIImplTraits::get(ModuleOp &moduleOp) {
+  auto attr = dlti::query(*&moduleOp, {"MPI:Implementation"}, true);
+  if (failed(attr)) // No DLTI entry: fall back to the MPICH traits.
+    return std::make_unique<MPICHImplTraits>(moduleOp);
+  auto strAttr = dyn_cast<StringAttr>(attr.value());
+  if (strAttr && strAttr.getValue() == "OpenMPI")
+    return std::make_unique<OMPIImplTraits>(moduleOp);
+  if (!strAttr || strAttr.getValue() != "MPICH") // strAttr may be null here.
+    moduleOp.emitWarning() << "Unknown \"MPI:Implementation\" value in DLTI ("
+                           << attr.value() << "), defaulting to MPICH";
+  return std::make_unique<MPICHImplTraits>(moduleOp);
+}
 
 //===----------------------------------------------------------------------===//
 // InitOpLowering
@@ -320,8 +309,9 @@ struct CommRankOpLowering : public ConvertOpToLLVMPattern<mpi::CommRankOp> {
     // grab a reference to the global module op:
     auto moduleOp = op->getParentOfType<ModuleOp>();
 
+    auto mpiTraits = MPIImplTraits::get(moduleOp);
     // get MPI_COMM_WORLD
-    Value commWorld = MPIImplTraits::getCommWorld(moduleOp, loc, rewriter);
+    Value commWorld = mpiTraits->getCommWorld(loc, rewriter);
 
     // LLVM Function type representing `i32 MPI_Comm_rank(ptr, ptr)`
     auto rankFuncType =
@@ -387,9 +377,9 @@ struct SendOpLowering : public ConvertOpToLLVMPattern<mpi::SendOp> {
     Value size = rewriter.create<LLVM::ExtractValueOp>(loc, memRef,
                                                        ArrayRef<int64_t>{3, 0});
     size = rewriter.create<LLVM::TruncOp>(loc, i32, size);
-    Value dataType =
-        MPIImplTraits::getDataType(moduleOp, loc, rewriter, elemType);
-    Value commWorld = MPIImplTraits::getCommWorld(moduleOp, loc, rewriter);
+    auto mpiTraits = MPIImplTraits::get(moduleOp);
+    Value dataType = mpiTraits->getDataType(loc, rewriter, elemType);
+    Value commWorld = mpiTraits->getCommWorld(loc, rewriter);
 
     // LLVM Function type representing `i32 MPI_send(data, count, datatype, dst,
     // tag, comm)`
@@ -446,11 +436,11 @@ struct RecvOpLowering : public ConvertOpToLLVMPattern<mpi::RecvOp> {
     Value size = rewriter.create<LLVM::ExtractValueOp>(loc, memRef,
                                                        ArrayRef<int64_t>{3, 0});
     size = rewriter.create<LLVM::TruncOp>(loc, i32, size);
-    Value dataType =
-        MPIImplTraits::getDataType(moduleOp, loc, rewriter, elemType);
-    Value commWorld = MPIImplTraits::getCommWorld(moduleOp, loc, rewriter);
+    auto mpiTraits = MPIImplTraits::get(moduleOp);
+    Value dataType = mpiTraits->getDataType(loc, rewriter, elemType);
+    Value commWorld = mpiTraits->getCommWorld(loc, rewriter);
     Value statusIgnore = rewriter.create<LLVM::ConstantOp>(
-        loc, i64, MPIImplTraits::getStatusIgnore(moduleOp));
+        loc, i64, mpiTraits->getStatusIgnore());
     statusIgnore =
         rewriter.create<LLVM::IntToPtrOp>(loc, ptrType, statusIgnore);
 
@@ -498,13 +488,13 @@ struct FuncToLLVMDialectInterface : public ConvertToLLVMPatternInterface {
 // Pattern Population
 //===----------------------------------------------------------------------===//
 
-void mlir::mpi::populateMPIToLLVMConversionPatterns(
-    LLVMTypeConverter &converter, RewritePatternSet &patterns) {
+void mpi::populateMPIToLLVMConversionPatterns(LLVMTypeConverter &converter,
+                                              RewritePatternSet &patterns) {
   patterns.add<CommRankOpLowering, FinalizeOpLowering, InitOpLowering,
                SendOpLowering, RecvOpLowering>(converter);
 }
 
-void mlir::mpi::registerConvertMPIToLLVMInterface(DialectRegistry &registry) {
+void mpi::registerConvertMPIToLLVMInterface(DialectRegistry &registry) {
   registry.addExtension(+[](MLIRContext *ctx, mpi::MPIDialect *dialect) {
     dialect->addInterfaces<FuncToLLVMDialectInterface>();
   });



More information about the Mlir-commits mailing list