[Mlir-commits] [mlir] [mlir][mesh, MPI] Mesh2mpi (PR #104566)

Frank Schlimbach llvmlistbot at llvm.org
Fri Aug 16 02:09:07 PDT 2024


https://github.com/fschlimb created https://github.com/llvm/llvm-project/pull/104566

Pass for lowering `Mesh` to `MPI`.
Initial commit lowers `UpdateHaloOp` only.


>From 6b6751dd16754763dcb0384dc59fcab5d6f4e367 Mon Sep 17 00:00:00 2001
From: "Schlimbach, Frank" <frank.schlimbach at intel.com>
Date: Wed, 14 Aug 2024 19:29:23 +0200
Subject: [PATCH 1/2] initial hack lowering mesh.update_halo to MPI

---
 .../mlir/Conversion/MeshToMPI/MeshToMPI.h     |  27 +++
 mlir/include/mlir/Conversion/Passes.h         |   1 +
 mlir/include/mlir/Conversion/Passes.td        |  17 ++
 mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td  |  33 ++++
 mlir/lib/Conversion/CMakeLists.txt            |   1 +
 mlir/lib/Conversion/MeshToMPI/CMakeLists.txt  |  22 +++
 mlir/lib/Conversion/MeshToMPI/MeshToMPI.cpp   | 171 ++++++++++++++++++
 mlir/lib/Dialect/Mesh/IR/MeshOps.cpp          |  19 ++
 .../MeshToMPI/convert-mesh-to-mpi.mlir        |  34 ++++
 9 files changed, 325 insertions(+)
 create mode 100644 mlir/include/mlir/Conversion/MeshToMPI/MeshToMPI.h
 create mode 100644 mlir/lib/Conversion/MeshToMPI/CMakeLists.txt
 create mode 100644 mlir/lib/Conversion/MeshToMPI/MeshToMPI.cpp
 create mode 100644 mlir/test/Conversion/MeshToMPI/convert-mesh-to-mpi.mlir

diff --git a/mlir/include/mlir/Conversion/MeshToMPI/MeshToMPI.h b/mlir/include/mlir/Conversion/MeshToMPI/MeshToMPI.h
new file mode 100644
index 00000000000000..6a2c196da45577
--- /dev/null
+++ b/mlir/include/mlir/Conversion/MeshToMPI/MeshToMPI.h
@@ -0,0 +1,27 @@
+//===- MeshToMPI.h - Convert Mesh to MPI dialect --*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_CONVERSION_MESHTOMPI_MESHTOMPI_H
+#define MLIR_CONVERSION_MESHTOMPI_MESHTOMPI_H
+
+#include "mlir/Pass/Pass.h"
+#include "mlir/Support/LLVM.h"
+
+namespace mlir {
+class Pass;
+
+#define GEN_PASS_DECL_CONVERTMESHTOMPIPASS
+#include "mlir/Conversion/Passes.h.inc"
+
+/// Creates a pass lowering Mesh communication operations to MPI primitives.
+/// Currently only mesh.update_halo is handled.
+std::unique_ptr<Pass> createConvertMeshToMPIPass();
+
+} // namespace mlir
+
+#endif // MLIR_CONVERSION_MESHTOMPI_MESHTOMPI_H
\ No newline at end of file
diff --git a/mlir/include/mlir/Conversion/Passes.h b/mlir/include/mlir/Conversion/Passes.h
index 208f26489d6c39..ad8e98442ab8bc 100644
--- a/mlir/include/mlir/Conversion/Passes.h
+++ b/mlir/include/mlir/Conversion/Passes.h
@@ -51,6 +51,7 @@
 #include "mlir/Conversion/MemRefToEmitC/MemRefToEmitCPass.h"
 #include "mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h"
 #include "mlir/Conversion/MemRefToSPIRV/MemRefToSPIRVPass.h"
+#include "mlir/Conversion/MeshToMPI/MeshToMPI.h"
 #include "mlir/Conversion/NVGPUToNVVM/NVGPUToNVVM.h"
 #include "mlir/Conversion/NVVMToLLVM/NVVMToLLVM.h"
 #include "mlir/Conversion/OpenACCToSCF/ConvertOpenACCToSCF.h"
diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td
index 7bde9e490e4f4e..f9a6f52a22c6ed 100644
--- a/mlir/include/mlir/Conversion/Passes.td
+++ b/mlir/include/mlir/Conversion/Passes.td
@@ -869,6 +869,25 @@ def ConvertMemRefToSPIRV : Pass<"convert-memref-to-spirv"> {
   ];
 }
 
+//===----------------------------------------------------------------------===//
+// MeshToMPI
+//===----------------------------------------------------------------------===//
+
+def ConvertMeshToMPIPass : Pass<"convert-mesh-to-mpi"> {
+  let summary = "Convert Mesh dialect to MPI dialect.";
+  let description = [{
+    This pass converts communication operations from the Mesh dialect to
+    operations from the MPI dialect. Currently only `mesh.update_halo` is
+    lowered; the lowering also emits arith, memref, scf and mesh operations.
+  }];
+  let dependentDialects = [
+    "arith::ArithDialect",
+    "memref::MemRefDialect",
+    "mpi::MPIDialect",
+    "scf::SCFDialect"
+  ];
+}
+
 //===----------------------------------------------------------------------===//
 // NVVMToLLVM
 //===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td
index 8f696bbc1a0f6e..9d1684b78f34f2 100644
--- a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td
+++ b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td
@@ -155,6 +155,39 @@ def Mesh_ProcessLinearIndexOp : Mesh_Op<"process_linear_index", [
   ];
 }
 
+def Mesh_NeighborsLinearIndicesOp : Mesh_Op<"neighbors_linear_indices", [
+  Pure,
+  DeclareOpInterfaceMethods<SymbolUserOpInterface>,
+  DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>
+]> {
+  let summary =
+      "Get the linear indices of the direct neighbors along given split axes.";
+  let description = [{
+    Example:
+    ```
+    %down, %up = mesh.neighbors_linear_indices on @mesh[%d0, %d1, %d2]
+                   split_axes = [1] : index, index
+    ```
+    Given `@mesh` with shape `(10, 20, 30)`,
+          `device` = `(1, 2, 3)`
+          `$split_axes` = `[1]`
+    it returns the linear indices of the processes at positions `(1, 1, 3)`: `633`
+    and `(1, 3, 3)`: `693`.
+
+    `neighbor_down` (resp. `neighbor_up`) is negative if `$device` has no
+    neighbor in that direction along the given `split_axes`.
+  }];
+  let arguments = (ins FlatSymbolRefAttr:$mesh,
+                       Variadic<Index>:$device,
+                       Mesh_MeshAxesAttr:$split_axes);
+  let results = (outs Index:$neighbor_down, Index:$neighbor_up);
+  let assemblyFormat = [{
+      `on` $mesh `[` $device `]`
+      `split_axes` `=` $split_axes
+      attr-dict `:` type(results)
+  }];
+}
+
 //===----------------------------------------------------------------------===//
 // Sharding operations.
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Conversion/CMakeLists.txt b/mlir/lib/Conversion/CMakeLists.txt
index 813f700c5556e1..3ee237f4e62acd 100644
--- a/mlir/lib/Conversion/CMakeLists.txt
+++ b/mlir/lib/Conversion/CMakeLists.txt
@@ -41,6 +41,7 @@ add_subdirectory(MathToSPIRV)
 add_subdirectory(MemRefToEmitC)
 add_subdirectory(MemRefToLLVM)
 add_subdirectory(MemRefToSPIRV)
+add_subdirectory(MeshToMPI)
 add_subdirectory(NVGPUToNVVM)
 add_subdirectory(NVVMToLLVM)
 add_subdirectory(OpenACCToSCF)
diff --git a/mlir/lib/Conversion/MeshToMPI/CMakeLists.txt b/mlir/lib/Conversion/MeshToMPI/CMakeLists.txt
new file mode 100644
index 00000000000000..95815a683f6d6a
--- /dev/null
+++ b/mlir/lib/Conversion/MeshToMPI/CMakeLists.txt
@@ -0,0 +1,24 @@
+add_mlir_conversion_library(MLIRMeshToMPI
+  MeshToMPI.cpp
+
+  ADDITIONAL_HEADER_DIRS
+  ${MLIR_MAIN_INCLUDE_DIR}/mlir/Conversion/MeshToMPI
+
+  DEPENDS
+  MLIRConversionPassIncGen
+
+  LINK_COMPONENTS
+  Core
+
+  LINK_LIBS PUBLIC
+  MLIRArithDialect
+  MLIRFuncDialect
+  MLIRIR
+  MLIRLinalgTransforms
+  MLIRMemRefDialect
+  MLIRPass
+  MLIRMeshDialect
+  MLIRMPIDialect
+  MLIRSCFDialect
+  MLIRTransforms
+  )
diff --git a/mlir/lib/Conversion/MeshToMPI/MeshToMPI.cpp b/mlir/lib/Conversion/MeshToMPI/MeshToMPI.cpp
new file mode 100644
index 00000000000000..b4cf9da8497a2d
--- /dev/null
+++ b/mlir/lib/Conversion/MeshToMPI/MeshToMPI.cpp
@@ -0,0 +1,171 @@
+//===- MeshToMPI.cpp - Mesh to MPI  dialect conversion -----------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements a translation of Mesh communicatin ops tp MPI ops.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Conversion/MeshToMPI/MeshToMPI.h"
+
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/MPI/IR/MPI.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/Mesh/IR/MeshOps.h"
+#include "mlir/Dialect/SCF/IR/SCF.h"
+#include "mlir/Dialect/Utils/StaticValueUtils.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/BuiltinAttributes.h"
+#include "mlir/IR/BuiltinTypes.h"
+
+#define DEBUG_TYPE "mesh-to-mpi"
+#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
+
+namespace mlir {
+#define GEN_PASS_DEF_CONVERTMESHTOMPIPASS
+#include "mlir/Conversion/Passes.h.inc"
+} // namespace mlir
+
+using namespace mlir;
+using namespace mlir::mesh;
+
+namespace {
+struct ConvertMeshToMPIPass
+    : public impl::ConvertMeshToMPIPassBase<ConvertMeshToMPIPass> {
+  using Base::Base;
+
+  /// Run the dialect converter on the module.
+  void runOnOperation() override {
+    getOperation()->walk([&](UpdateHaloOp op) {
+      SymbolTableCollection symbolTableCollection;
+      OpBuilder builder(op);
+      auto loc = op.getLoc();
+
+      auto toValue = [&builder, &loc](OpFoldResult &v) {
+        return v.is<Value>()
+                   ? v.get<Value>()
+                   : builder.create<::mlir::arith::ConstantOp>(
+                         loc,
+                         builder.getIndexAttr(
+                             cast<IntegerAttr>(v.get<Attribute>()).getInt()));
+      };
+
+      auto array = op.getInput();
+      auto rank = array.getType().getRank();
+      auto mesh = op.getMesh();
+      auto meshOp = getMesh(op, symbolTableCollection);
+      auto haloSizes = getMixedValues(op.getStaticHaloSizes(),
+                                      op.getDynamicHaloSizes(), builder);
+      for (auto &sz : haloSizes) {
+        if (sz.is<Value>()) {
+          sz = builder
+                   .create<arith::IndexCastOp>(loc, builder.getIndexType(),
+                                               sz.get<Value>())
+                   .getResult();
+        }
+      }
+
+      SmallVector<OpFoldResult> offsets(rank, builder.getIndexAttr(0));
+      SmallVector<OpFoldResult> strides(rank, builder.getIndexAttr(1));
+      SmallVector<OpFoldResult> shape(rank);
+      for (auto [i, s] : llvm::enumerate(array.getType().getShape())) {
+        if (ShapedType::isDynamic(s)) {
+          shape[i] = builder.create<memref::DimOp>(loc, array, s).getResult();
+        } else {
+          shape[i] = builder.getIndexAttr(s);
+        }
+      }
+
+      auto tagAttr = builder.getI32IntegerAttr(91); // whatever
+      auto tag = builder.create<::mlir::arith::ConstantOp>(loc, tagAttr);
+      auto zeroAttr = builder.getI32IntegerAttr(0); // whatever
+      auto zero = builder.create<::mlir::arith::ConstantOp>(loc, zeroAttr);
+      SmallVector<Type> indexResultTypes(meshOp.getShape().size(),
+                                         builder.getIndexType());
+      auto myMultiIndex =
+          builder.create<ProcessMultiIndexOp>(loc, indexResultTypes, mesh)
+              .getResult();
+      auto currHaloDim = 0;
+
+      for (auto [dim, splitAxes] : llvm::enumerate(op.getSplitAxes())) {
+        if (!splitAxes.empty()) {
+          auto tmp = builder
+                         .create<NeighborsLinearIndicesOp>(
+                             loc, mesh, myMultiIndex, splitAxes)
+                         .getResults();
+          Value neighbourIDs[2] = {builder.create<arith::IndexCastOp>(
+                                       loc, builder.getI32Type(), tmp[0]),
+                                   builder.create<arith::IndexCastOp>(
+                                       loc, builder.getI32Type(), tmp[1])};
+          auto orgDimSize = shape[dim];
+          auto upperOffset = builder.create<arith::SubIOp>(
+              loc, toValue(shape[dim]), toValue(haloSizes[dim * 2 + 1]));
+
+          // make sure we send/recv in a way that does not lead to a dead-lock
+          // This is by far not optimal, this should be at least MPI_sendrecv
+          // and - probably even more importantly - buffers should be re-used
+          // Currently using temporary, contiguous buffer for MPI communication
+          auto genSendRecv = [&](auto dim, bool upperHalo) {
+            auto orgOffset = offsets[dim];
+            shape[dim] =
+                upperHalo ? haloSizes[dim * 2 + 1] : haloSizes[dim * 2];
+            auto to = upperHalo ? neighbourIDs[1] : neighbourIDs[0];
+            auto from = upperHalo ? neighbourIDs[0] : neighbourIDs[1];
+            auto hasFrom = builder.create<arith::CmpIOp>(
+                loc, arith::CmpIPredicate::sge, from, zero);
+            auto hasTo = builder.create<arith::CmpIOp>(
+                loc, arith::CmpIPredicate::sge, to, zero);
+            auto buffer = builder.create<memref::AllocOp>(
+                loc, shape, array.getType().getElementType());
+            builder.create<scf::IfOp>(
+                loc, hasTo, [&](OpBuilder &builder, Location loc) {
+                  offsets[dim] = upperHalo
+                                     ? OpFoldResult(builder.getIndexAttr(0))
+                                     : OpFoldResult(upperOffset);
+                  auto subview = builder.create<memref::SubViewOp>(
+                      loc, array, offsets, shape, strides);
+                  builder.create<memref::CopyOp>(loc, subview, buffer);
+                  builder.create<mpi::SendOp>(loc, TypeRange{}, buffer, tag,
+                                              to);
+                  builder.create<scf::YieldOp>(loc);
+                });
+            builder.create<scf::IfOp>(
+                loc, hasFrom, [&](OpBuilder &builder, Location loc) {
+                  offsets[dim] = upperHalo
+                                     ? OpFoldResult(upperOffset)
+                                     : OpFoldResult(builder.getIndexAttr(0));
+                  builder.create<mpi::RecvOp>(loc, TypeRange{}, buffer, tag,
+                                              from);
+                  auto subview = builder.create<memref::SubViewOp>(
+                      loc, array, offsets, shape, strides);
+                  builder.create<memref::CopyOp>(loc, buffer, subview);
+                  builder.create<scf::YieldOp>(loc);
+                });
+            builder.create<memref::DeallocOp>(loc, buffer);
+            offsets[dim] = orgOffset;
+          };
+
+          genSendRecv(dim, false);
+          genSendRecv(dim, true);
+
+          shape[dim] = builder
+                           .create<arith::SubIOp>(
+                               loc, toValue(orgDimSize),
+                               builder
+                                   .create<arith::AddIOp>(
+                                       loc, toValue(haloSizes[dim * 2]),
+                                       toValue(haloSizes[dim * 2 + 1]))
+                                   .getResult())
+                           .getResult();
+          offsets[dim] = haloSizes[dim * 2];
+          ++currHaloDim;
+        }
+      }
+    });
+  }
+};
+} // namespace
\ No newline at end of file
diff --git a/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp b/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
index c35020b4c20ccc..f25bbbf8e274b6 100644
--- a/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
+++ b/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
@@ -730,6 +730,35 @@ void ProcessLinearIndexOp::getAsmResultNames(
   setNameFn(getResult(), "proc_linear_idx");
 }
 
+//===----------------------------------------------------------------------===//
+// mesh.neighbors_linear_indices op
+//===----------------------------------------------------------------------===//
+
+/// Verify the mesh symbol, the split axes and the number of device indices.
+LogicalResult
+NeighborsLinearIndicesOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
+  auto mesh = ::getMeshAndVerify(getOperation(), getMeshAttr(), symbolTable);
+  if (failed(mesh)) {
+    return failure();
+  }
+  // The split axes must be unique, in-bounds axes of the referenced mesh.
+  if (failed(verifyMeshAxes(getLoc(), getSplitAxes(), mesh.value()))) {
+    return failure();
+  }
+  // The device is identified by one index per mesh dimension.
+  if (getDevice().size() != static_cast<size_t>(mesh->getRank())) {
+    return emitOpError() << "expected " << mesh->getRank()
+                         << " device indices, but got " << getDevice().size();
+  }
+  return success();
+}
+
+void NeighborsLinearIndicesOp::getAsmResultNames(
+    function_ref<void(Value, StringRef)> setNameFn) {
+  setNameFn(getNeighborDown(), "down_linear_idx");
+  setNameFn(getNeighborUp(), "up_linear_idx");
+}
+
 //===----------------------------------------------------------------------===//
 // collective communication ops
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Conversion/MeshToMPI/convert-mesh-to-mpi.mlir b/mlir/test/Conversion/MeshToMPI/convert-mesh-to-mpi.mlir
new file mode 100644
index 00000000000000..9ef826ca0cdace
--- /dev/null
+++ b/mlir/test/Conversion/MeshToMPI/convert-mesh-to-mpi.mlir
@@ -0,0 +1,34 @@
+// RUN: mlir-opt %s -split-input-file -convert-mesh-to-mpi | FileCheck %s
+
+// CHECK: mesh.mesh @mesh0
+mesh.mesh @mesh0(shape = 2x2x4)
+
+// -----
+
+// CHECK-LABEL: func @update_halo
+func.func @update_halo_1d(
+    // CHECK-SAME: %[[ARG:.*]]: memref<12x12xi8>
+    %arg0 : memref<12x12xi8>) {
+  // CHECK-NEXT: %[[C2:.*]] = arith.constant 2 : i64
+  // CHECK-NEXT: mesh.update_halo %[[ARG]] on @mesh0
+  // CHECK-SAME: split_axes = {{\[\[}}0]]
+  // CHECK-SAME: halo_sizes = [2, %c2_i64] : memref<12x12xi8>
+  %c2 = arith.constant 2 : i64
+  mesh.update_halo %arg0 on @mesh0 split_axes = [[0]]
+    halo_sizes = [2, %c2] : memref<12x12xi8>
+  return
+}
+
+func.func @update_halo_2d(
+    // CHECK-SAME: %[[ARG:.*]]: memref<12x12xi8>
+    %arg0 : memref<12x12xi8>) {
+  %c2 = arith.constant 2 : i64
+  // CHECK-NEXT: mesh.update_halo %[[ARG]] on @mesh0
+  // CHECK-SAME: split_axes = {{\[\[}}0], [1]]
+  // CHECK-SAME: halo_sizes = [2, 2, %[[C2]], 2]
+  // CHECK-SAME: target_halo_sizes = [3, 3, 2, 2] : memref<12x12xi8>
+  mesh.update_halo %arg0 on @mesh0 split_axes = [[0], [1]]
+    halo_sizes = [2, 2, %c2, 2] target_halo_sizes = [3, 3, 2, 2]
+    : memref<12x12xi8>
+  return
+}

>From c370db1ab727853f48bf121966dc76ffceed8ea7 Mon Sep 17 00:00:00 2001
From: "Schlimbach, Frank" <frank.schlimbach at intel.com>
Date: Fri, 16 Aug 2024 10:55:28 +0200
Subject: [PATCH 2/2] dim fixes, proper testing

---
 mlir/lib/Conversion/MeshToMPI/MeshToMPI.cpp   | 306 ++++++++++--------
 .../MeshToMPI/convert-mesh-to-mpi.mlir        | 179 ++++++++--
 2 files changed, 339 insertions(+), 146 deletions(-)

diff --git a/mlir/lib/Conversion/MeshToMPI/MeshToMPI.cpp b/mlir/lib/Conversion/MeshToMPI/MeshToMPI.cpp
index b4cf9da8497a2d..42d885a109ee79 100644
--- a/mlir/lib/Conversion/MeshToMPI/MeshToMPI.cpp
+++ b/mlir/lib/Conversion/MeshToMPI/MeshToMPI.cpp
@@ -6,7 +6,7 @@
 //
 //===----------------------------------------------------------------------===//
 //
-// This file implements a translation of Mesh communicatin ops tp MPI ops.
+// This file implements a translation of Mesh communication ops to MPI ops.
 //
 //===----------------------------------------------------------------------===//
 
@@ -21,6 +21,8 @@
 #include "mlir/IR/Builders.h"
 #include "mlir/IR/BuiltinAttributes.h"
 #include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
 
 #define DEBUG_TYPE "mesh-to-mpi"
 #define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
@@ -34,138 +36,190 @@ using namespace mlir;
 using namespace mlir::mesh;
 
 namespace {
+
+// This pattern converts the mesh.update_halo operation to MPI calls.
+struct ConvertUpdateHaloOp
+    : public mlir::OpRewritePattern<mlir::mesh::UpdateHaloOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  mlir::LogicalResult
+  matchAndRewrite(mlir::mesh::UpdateHaloOp op,
+                  mlir::PatternRewriter &rewriter) const override {
+    // Halos are exchanged as 2 blocks per dimension (one for each side: down
+    // and up). It is assumed that the last dim in a default memref is
+    // contiguous, hence iteration starts with the complete halo on the first
+    // dim which should be contiguous (unless the source is not). The size of
+    // the exchanged data will decrease when iterating over dimensions. That's
+    // good because the halos of last dim will be most fragmented.
+    // memref.subview is used to read and write the halo data from and to the
+    // local data. subviews and halos have dynamic and static values, so
+    // OpFoldResults are used whenever possible.
+
+    SymbolTableCollection symbolTableCollection;
+    auto loc = op.getLoc();
+
+    // Convert an OpFoldResult into a Value; attributes become index constants.
+    auto toValue = [&rewriter, &loc](OpFoldResult &v) {
+      return v.is<Value>()
+                 ? v.get<Value>()
+                 : rewriter.create<::mlir::arith::ConstantOp>(
+                       loc,
+                       rewriter.getIndexAttr(
+                           cast<IntegerAttr>(v.get<Attribute>()).getInt()));
+    };
+
+    auto array = op.getInput();
+    auto rank = array.getType().getRank();
+    auto mesh = op.getMesh();
+    auto meshOp = getMesh(op, symbolTableCollection);
+    auto haloSizes = getMixedValues(op.getStaticHaloSizes(),
+                                    op.getDynamicHaloSizes(), rewriter);
+    // subviews need Index values
+    for (auto &sz : haloSizes) {
+      if (sz.is<Value>()) {
+        sz = rewriter
+                 .create<arith::IndexCastOp>(loc, rewriter.getIndexType(),
+                                             sz.get<Value>())
+                 .getResult();
+      }
+    }
+
+    // most of the offset/size/stride data is the same for all dims
+    SmallVector<OpFoldResult> offsets(rank, rewriter.getIndexAttr(0));
+    SmallVector<OpFoldResult> strides(rank, rewriter.getIndexAttr(1));
+    SmallVector<OpFoldResult> shape(rank);
+    // we need the actual shape; memref.dim takes the dimension *index* i
+    for (auto [i, s] : llvm::enumerate(array.getType().getShape())) {
+      if (ShapedType::isDynamic(s)) {
+        shape[i] = rewriter.create<memref::DimOp>(loc, array, i).getResult();
+      } else {
+        shape[i] = rewriter.getIndexAttr(s);
+      }
+    }
+
+    auto tagAttr = rewriter.getI32IntegerAttr(91); // we just pick something
+    auto tag = rewriter.create<::mlir::arith::ConstantOp>(loc, tagAttr);
+    auto zeroAttr = rewriter.getI32IntegerAttr(0); // for detecting v<0
+    auto zero = rewriter.create<::mlir::arith::ConstantOp>(loc, zeroAttr);
+    SmallVector<Type> indexResultTypes(meshOp.getShape().size(),
+                                       rewriter.getIndexType());
+    auto myMultiIndex =
+        rewriter.create<ProcessMultiIndexOp>(loc, indexResultTypes, mesh)
+            .getResult();
+    // halo sizes are provided for split dimensions only
+    auto currHaloDim = 0;
+
+    for (auto [dim, splitAxes] : llvm::enumerate(op.getSplitAxes())) {
+      if (splitAxes.empty()) {
+        continue;
+      }
+      // Get the linearized ids of the neighbors (down and up) for the
+      // given split
+      auto tmp = rewriter
+                     .create<NeighborsLinearIndicesOp>(loc, mesh, myMultiIndex,
+                                                       splitAxes)
+                     .getResults();
+      // MPI operates on i32...
+      Value neighbourIDs[2] = {rewriter.create<arith::IndexCastOp>(
+                                   loc, rewriter.getI32Type(), tmp[0]),
+                               rewriter.create<arith::IndexCastOp>(
+                                   loc, rewriter.getI32Type(), tmp[1])};
+      // store for later
+      auto orgDimSize = shape[dim];
+      // this dim's offset to the start of the upper halo
+      auto upperOffset = rewriter.create<arith::SubIOp>(
+          loc, toValue(shape[dim]), toValue(haloSizes[currHaloDim * 2 + 1]));
+
+      // Make sure we send/recv in a way that does not lead to a dead-lock.
+      // The current approach is by far not optimal; it should at least use a
+      // red-black pattern or MPI_sendrecv, and buffers should be re-used.
+      // For now: temporary contiguous buffers, "serialized" communication.
+      // NOTE(review): the send subviews start at 0 / upperOffset, i.e. inside
+      // the halo regions rather than the owned data next to them; confirm.
+      auto genSendRecv = [&](auto dim, bool upperHalo) {
+        auto orgOffset = offsets[dim];
+        shape[dim] = upperHalo ? haloSizes[currHaloDim * 2 + 1]
+                               : haloSizes[currHaloDim * 2];
+        // Check if we need to send and/or receive
+        // Processes on the mesh borders have only one neighbor
+        auto to = upperHalo ? neighbourIDs[1] : neighbourIDs[0];
+        auto from = upperHalo ? neighbourIDs[0] : neighbourIDs[1];
+        auto hasFrom = rewriter.create<arith::CmpIOp>(
+            loc, arith::CmpIPredicate::sge, from, zero);
+        auto hasTo = rewriter.create<arith::CmpIOp>(
+            loc, arith::CmpIPredicate::sge, to, zero);
+        auto buffer = rewriter.create<memref::AllocOp>(
+            loc, shape, array.getType().getElementType());
+        // if has neighbor: copy halo data from array to buffer and send
+        rewriter.create<scf::IfOp>(
+            loc, hasTo, [&](OpBuilder &builder, Location loc) {
+              offsets[dim] = upperHalo ? OpFoldResult(builder.getIndexAttr(0))
+                                       : OpFoldResult(upperOffset);
+              auto subview = builder.create<memref::SubViewOp>(
+                  loc, array, offsets, shape, strides);
+              builder.create<memref::CopyOp>(loc, subview, buffer);
+              builder.create<mpi::SendOp>(loc, TypeRange{}, buffer, tag, to);
+              builder.create<scf::YieldOp>(loc);
+            });
+        // if has neighbor: receive halo data into buffer and copy to array
+        rewriter.create<scf::IfOp>(
+            loc, hasFrom, [&](OpBuilder &builder, Location loc) {
+              offsets[dim] = upperHalo ? OpFoldResult(upperOffset)
+                                       : OpFoldResult(builder.getIndexAttr(0));
+              builder.create<mpi::RecvOp>(loc, TypeRange{}, buffer, tag, from);
+              auto subview = builder.create<memref::SubViewOp>(
+                  loc, array, offsets, shape, strides);
+              builder.create<memref::CopyOp>(loc, buffer, subview);
+              builder.create<scf::YieldOp>(loc);
+            });
+        rewriter.create<memref::DeallocOp>(loc, buffer);
+        offsets[dim] = orgOffset;
+      };
+
+      genSendRecv(dim, false);
+      genSendRecv(dim, true);
+
+      // prepare shape and offsets for next split dim
+      auto haloSz =
+          rewriter
+              .create<arith::AddIOp>(loc, toValue(haloSizes[currHaloDim * 2]),
+                                     toValue(haloSizes[currHaloDim * 2 + 1]))
+              .getResult();
+      // the shape for next halo excludes the halo on both ends for the
+      // current dim
+      shape[dim] =
+          rewriter.create<arith::SubIOp>(loc, toValue(orgDimSize), haloSz)
+              .getResult();
+      // the offsets for next halo starts after the down halo for the
+      // current dim
+      offsets[dim] = haloSizes[currHaloDim * 2];
+      // on to next halo
+      ++currHaloDim;
+    }
+    rewriter.eraseOp(op);
+    return mlir::success();
+  }
+};
+
 struct ConvertMeshToMPIPass
     : public impl::ConvertMeshToMPIPassBase<ConvertMeshToMPIPass> {
   using Base::Base;
 
   /// Rewrite nested mesh.update_halo ops greedily (driver result ignored).
   void runOnOperation() override {
-    getOperation()->walk([&](UpdateHaloOp op) {
-      SymbolTableCollection symbolTableCollection;
-      OpBuilder builder(op);
-      auto loc = op.getLoc();
-
-      auto toValue = [&builder, &loc](OpFoldResult &v) {
-        return v.is<Value>()
-                   ? v.get<Value>()
-                   : builder.create<::mlir::arith::ConstantOp>(
-                         loc,
-                         builder.getIndexAttr(
-                             cast<IntegerAttr>(v.get<Attribute>()).getInt()));
-      };
+    auto *ctx = &getContext();
+    mlir::RewritePatternSet patterns(ctx);
 
-      auto array = op.getInput();
-      auto rank = array.getType().getRank();
-      auto mesh = op.getMesh();
-      auto meshOp = getMesh(op, symbolTableCollection);
-      auto haloSizes = getMixedValues(op.getStaticHaloSizes(),
-                                      op.getDynamicHaloSizes(), builder);
-      for (auto &sz : haloSizes) {
-        if (sz.is<Value>()) {
-          sz = builder
-                   .create<arith::IndexCastOp>(loc, builder.getIndexType(),
-                                               sz.get<Value>())
-                   .getResult();
-        }
-      }
-
-      SmallVector<OpFoldResult> offsets(rank, builder.getIndexAttr(0));
-      SmallVector<OpFoldResult> strides(rank, builder.getIndexAttr(1));
-      SmallVector<OpFoldResult> shape(rank);
-      for (auto [i, s] : llvm::enumerate(array.getType().getShape())) {
-        if (ShapedType::isDynamic(s)) {
-          shape[i] = builder.create<memref::DimOp>(loc, array, s).getResult();
-        } else {
-          shape[i] = builder.getIndexAttr(s);
-        }
-      }
+    patterns.insert<ConvertUpdateHaloOp>(ctx);
 
-      auto tagAttr = builder.getI32IntegerAttr(91); // whatever
-      auto tag = builder.create<::mlir::arith::ConstantOp>(loc, tagAttr);
-      auto zeroAttr = builder.getI32IntegerAttr(0); // whatever
-      auto zero = builder.create<::mlir::arith::ConstantOp>(loc, zeroAttr);
-      SmallVector<Type> indexResultTypes(meshOp.getShape().size(),
-                                         builder.getIndexType());
-      auto myMultiIndex =
-          builder.create<ProcessMultiIndexOp>(loc, indexResultTypes, mesh)
-              .getResult();
-      auto currHaloDim = 0;
-
-      for (auto [dim, splitAxes] : llvm::enumerate(op.getSplitAxes())) {
-        if (!splitAxes.empty()) {
-          auto tmp = builder
-                         .create<NeighborsLinearIndicesOp>(
-                             loc, mesh, myMultiIndex, splitAxes)
-                         .getResults();
-          Value neighbourIDs[2] = {builder.create<arith::IndexCastOp>(
-                                       loc, builder.getI32Type(), tmp[0]),
-                                   builder.create<arith::IndexCastOp>(
-                                       loc, builder.getI32Type(), tmp[1])};
-          auto orgDimSize = shape[dim];
-          auto upperOffset = builder.create<arith::SubIOp>(
-              loc, toValue(shape[dim]), toValue(haloSizes[dim * 2 + 1]));
-
-          // make sure we send/recv in a way that does not lead to a dead-lock
-          // This is by far not optimal, this should be at least MPI_sendrecv
-          // and - probably even more importantly - buffers should be re-used
-          // Currently using temporary, contiguous buffer for MPI communication
-          auto genSendRecv = [&](auto dim, bool upperHalo) {
-            auto orgOffset = offsets[dim];
-            shape[dim] =
-                upperHalo ? haloSizes[dim * 2 + 1] : haloSizes[dim * 2];
-            auto to = upperHalo ? neighbourIDs[1] : neighbourIDs[0];
-            auto from = upperHalo ? neighbourIDs[0] : neighbourIDs[1];
-            auto hasFrom = builder.create<arith::CmpIOp>(
-                loc, arith::CmpIPredicate::sge, from, zero);
-            auto hasTo = builder.create<arith::CmpIOp>(
-                loc, arith::CmpIPredicate::sge, to, zero);
-            auto buffer = builder.create<memref::AllocOp>(
-                loc, shape, array.getType().getElementType());
-            builder.create<scf::IfOp>(
-                loc, hasTo, [&](OpBuilder &builder, Location loc) {
-                  offsets[dim] = upperHalo
-                                     ? OpFoldResult(builder.getIndexAttr(0))
-                                     : OpFoldResult(upperOffset);
-                  auto subview = builder.create<memref::SubViewOp>(
-                      loc, array, offsets, shape, strides);
-                  builder.create<memref::CopyOp>(loc, subview, buffer);
-                  builder.create<mpi::SendOp>(loc, TypeRange{}, buffer, tag,
-                                              to);
-                  builder.create<scf::YieldOp>(loc);
-                });
-            builder.create<scf::IfOp>(
-                loc, hasFrom, [&](OpBuilder &builder, Location loc) {
-                  offsets[dim] = upperHalo
-                                     ? OpFoldResult(upperOffset)
-                                     : OpFoldResult(builder.getIndexAttr(0));
-                  builder.create<mpi::RecvOp>(loc, TypeRange{}, buffer, tag,
-                                              from);
-                  auto subview = builder.create<memref::SubViewOp>(
-                      loc, array, offsets, shape, strides);
-                  builder.create<memref::CopyOp>(loc, buffer, subview);
-                  builder.create<scf::YieldOp>(loc);
-                });
-            builder.create<memref::DeallocOp>(loc, buffer);
-            offsets[dim] = orgOffset;
-          };
-
-          genSendRecv(dim, false);
-          genSendRecv(dim, true);
-
-          shape[dim] = builder
-                           .create<arith::SubIOp>(
-                               loc, toValue(orgDimSize),
-                               builder
-                                   .create<arith::AddIOp>(
-                                       loc, toValue(haloSizes[dim * 2]),
-                                       toValue(haloSizes[dim * 2 + 1]))
-                                   .getResult())
-                           .getResult();
-          offsets[dim] = haloSizes[dim * 2];
-          ++currHaloDim;
-        }
-      }
-    });
+    (void)mlir::applyPatternsAndFoldGreedily(getOperation(),
+                                             std::move(patterns));
   }
 };
-} // namespace
\ No newline at end of file
+
+} // namespace
+
+// No hand-written factory here: TableGen emits
+// `mlir::createConvertMeshToMPIPass()` via GEN_PASS_DEF_CONVERTMESHTOMPIPASS
+// (the pass sets no custom `constructor`). A definition at global scope would
+// not define the function declared in MeshToMPI.h.
diff --git a/mlir/test/Conversion/MeshToMPI/convert-mesh-to-mpi.mlir b/mlir/test/Conversion/MeshToMPI/convert-mesh-to-mpi.mlir
index 9ef826ca0cdace..5f563364272d96 100644
--- a/mlir/test/Conversion/MeshToMPI/convert-mesh-to-mpi.mlir
+++ b/mlir/test/Conversion/MeshToMPI/convert-mesh-to-mpi.mlir
@@ -1,34 +1,180 @@
-// RUN: mlir-opt %s -split-input-file -convert-mesh-to-mpi | FileCheck %s
+// RUN: mlir-opt %s -convert-mesh-to-mpi | FileCheck %s
 
 // CHECK: mesh.mesh @mesh0
 mesh.mesh @mesh0(shape = 2x2x4)
 
-// -----
-
-// CHECK-LABEL: func @update_halo
-func.func @update_halo_1d(
-    // CHECK-SAME: %[[ARG:.*]]: memref<12x12xi8>
+// Halo exchange on tensor dim 0 only (mesh axis 0): leading halo 2, trailing halo 3.
+// NOTE(review): the send subviews below start at row offsets 9 and 0, so they read
+// from the halo regions [9, 12) and [0, 2) instead of only the owned rows [2, 9);
+// sends from offsets 7 and 2 would be expected. Confirm the UpdateHaloOp lowering.
+// CHECK-LABEL: func @update_halo_1d_first
+func.func @update_halo_1d_first(
+  // CHECK-SAME: [[varg0:%.*]]: memref<12x12xi8>
     %arg0 : memref<12x12xi8>) {
-  // CHECK-NEXT: %[[C2:.*]] = arith.constant 2 : i64
-  // CHECK-NEXT: mesh.update_halo %[[ARG]] on @mesh0
-  // CHECK-SAME: split_axes = {{\[\[}}0]]
-  // CHECK-SAME: halo_sizes = [2, %c2_i64] : memref<12x12xi8>
-  %c2 = arith.constant 2 : i64
+  // CHECK-NEXT: [[vc9:%.*]] = arith.constant 9 : index
+  // CHECK-NEXT: [[vc91_i32:%.*]] = arith.constant 91 : i32
+  // CHECK-NEXT: [[vc0_i32:%.*]] = arith.constant 0 : i32
+  // CHECK-NEXT: [[vproc_linear_idx:%.*]]:3 = mesh.process_multi_index on @mesh0 : index, index, index
+  // CHECK-NEXT: [[vdown_linear_idx:%.*]], [[vup_linear_idx:%.*]] = mesh.neighbors_linear_indices on @mesh0[[[vproc_linear_idx]]#0, [[vproc_linear_idx]]#1, [[vproc_linear_idx]]#2] split_axes = [0] : index, index
+  // CHECK-NEXT: [[v0:%.*]] = arith.index_cast [[vdown_linear_idx]] : index to i32
+  // CHECK-NEXT: [[v1:%.*]] = arith.index_cast [[vup_linear_idx]] : index to i32
+  // CHECK-NEXT: [[v2:%.*]] = arith.cmpi sge, [[v1]], [[vc0_i32]] : i32
+  // CHECK-NEXT: [[v3:%.*]] = arith.cmpi sge, [[v0]], [[vc0_i32]] : i32
+  // CHECK-NEXT: [[valloc:%.*]] = memref.alloc() : memref<2x12xi8>
+  // CHECK-NEXT: scf.if [[v3]] {
+  // CHECK-NEXT:   [[vsubview:%.*]] = memref.subview [[varg0]][[[vc9]], 0] [2, 12] [1, 1] : memref<12x12xi8> to memref<2x12xi8, strided<[12, 1], offset: ?>>
+  // CHECK-NEXT:   memref.copy [[vsubview]], [[valloc]] : memref<2x12xi8, strided<[12, 1], offset: ?>> to memref<2x12xi8>
+  // CHECK-NEXT:   mpi.send([[valloc]], [[vc91_i32]], [[v0]]) : memref<2x12xi8>, i32, i32
+  // CHECK-NEXT: }
+  // CHECK-NEXT: scf.if [[v2]] {
+  // CHECK-NEXT:   mpi.recv([[valloc]], [[vc91_i32]], [[v1]]) : memref<2x12xi8>, i32, i32
+  // CHECK-NEXT:   [[vsubview:%.*]] = memref.subview [[varg0]][0, 0] [2, 12] [1, 1] : memref<12x12xi8> to memref<2x12xi8, strided<[12, 1]>>
+  // CHECK-NEXT:   memref.copy [[valloc]], [[vsubview]] : memref<2x12xi8> to memref<2x12xi8, strided<[12, 1]>>
+  // CHECK-NEXT: }
+  // CHECK-NEXT: memref.dealloc [[valloc]] : memref<2x12xi8>
+  // CHECK-NEXT: [[v4:%.*]] = arith.cmpi sge, [[v0]], [[vc0_i32]] : i32
+  // CHECK-NEXT: [[v5:%.*]] = arith.cmpi sge, [[v1]], [[vc0_i32]] : i32
+  // CHECK-NEXT: [[valloc_0:%.*]] = memref.alloc() : memref<3x12xi8>
+  // CHECK-NEXT: scf.if [[v5]] {
+  // CHECK-NEXT:   [[vsubview:%.*]] = memref.subview [[varg0]][0, 0] [3, 12] [1, 1] : memref<12x12xi8> to memref<3x12xi8, strided<[12, 1]>>
+  // CHECK-NEXT:   memref.copy [[vsubview]], [[valloc_0]] : memref<3x12xi8, strided<[12, 1]>> to memref<3x12xi8>
+  // CHECK-NEXT:   mpi.send([[valloc_0]], [[vc91_i32]], [[v1]]) : memref<3x12xi8>, i32, i32
+  // CHECK-NEXT: }
+  // CHECK-NEXT: scf.if [[v4]] {
+  // CHECK-NEXT:   mpi.recv([[valloc_0]], [[vc91_i32]], [[v0]]) : memref<3x12xi8>, i32, i32
+  // CHECK-NEXT:   [[vsubview:%.*]] = memref.subview [[varg0]][[[vc9]], 0] [3, 12] [1, 1] : memref<12x12xi8> to memref<3x12xi8, strided<[12, 1], offset: ?>>
+  // CHECK-NEXT:   memref.copy [[valloc_0]], [[vsubview]] : memref<3x12xi8> to memref<3x12xi8, strided<[12, 1], offset: ?>>
+  // CHECK-NEXT: }
+  // CHECK-NEXT: memref.dealloc [[valloc_0]] : memref<3x12xi8>
+  // CHECK-NEXT: return
   mesh.update_halo %arg0 on @mesh0 split_axes = [[0]]
-    halo_sizes = [2, %c2] : memref<12x12xi8>
+    halo_sizes = [2, 3] : memref<12x12xi8>
+  return
+}
+
+// Same exchange on tensor dim 1 (mesh axis 2); tensor dim 0 is not split.
+// CHECK-LABEL: func @update_halo_1d_second
+func.func @update_halo_1d_second(
+  // CHECK-SAME: [[varg0:%.*]]: memref<12x12xi8>
+  %arg0 : memref<12x12xi8>) {
+  // CHECK-NEXT: [[vc9:%.*]] = arith.constant 9 : index
+  // CHECK-NEXT: [[vc91_i32:%.*]] = arith.constant 91 : i32
+  // CHECK-NEXT: [[vc0_i32:%.*]] = arith.constant 0 : i32
+  // CHECK-NEXT: [[vproc_linear_idx:%.*]]:3 = mesh.process_multi_index on @mesh0 : index, index, index
+  // CHECK-NEXT: [[vdown_linear_idx:%.*]], [[vup_linear_idx:%.*]] = mesh.neighbors_linear_indices on @mesh0[[[vproc_linear_idx]]#0, [[vproc_linear_idx]]#1, [[vproc_linear_idx]]#2] split_axes = [2] : index, index
+  // CHECK-NEXT: [[v0:%.*]] = arith.index_cast [[vdown_linear_idx]] : index to i32
+  // CHECK-NEXT: [[v1:%.*]] = arith.index_cast [[vup_linear_idx]] : index to i32
+  // CHECK-NEXT: [[v2:%.*]] = arith.cmpi sge, [[v1]], [[vc0_i32]] : i32
+  // CHECK-NEXT: [[v3:%.*]] = arith.cmpi sge, [[v0]], [[vc0_i32]] : i32
+  // CHECK-NEXT: [[valloc:%.*]] = memref.alloc() : memref<12x2xi8>
+  // CHECK-NEXT: scf.if [[v3]] {
+  // CHECK-NEXT:   [[vsubview:%.*]] = memref.subview [[varg0]][0, [[vc9]]] [12, 2] [1, 1] : memref<12x12xi8> to memref<12x2xi8, strided<[12, 1], offset: ?>>
+  // CHECK-NEXT:   memref.copy [[vsubview]], [[valloc]] : memref<12x2xi8, strided<[12, 1], offset: ?>> to memref<12x2xi8>
+  // CHECK-NEXT:   mpi.send([[valloc]], [[vc91_i32]], [[v0]]) : memref<12x2xi8>, i32, i32
+  // CHECK-NEXT: }
+  // CHECK-NEXT: scf.if [[v2]] {
+  // CHECK-NEXT:   mpi.recv([[valloc]], [[vc91_i32]], [[v1]]) : memref<12x2xi8>, i32, i32
+  // CHECK-NEXT:   [[vsubview:%.*]] = memref.subview [[varg0]][0, 0] [12, 2] [1, 1] : memref<12x12xi8> to memref<12x2xi8, strided<[12, 1]>>
+  // CHECK-NEXT:   memref.copy [[valloc]], [[vsubview]] : memref<12x2xi8> to memref<12x2xi8, strided<[12, 1]>>
+  // CHECK-NEXT: }
+  // CHECK-NEXT: memref.dealloc [[valloc]] : memref<12x2xi8>
+  // CHECK-NEXT: [[v4:%.*]] = arith.cmpi sge, [[v0]], [[vc0_i32]] : i32
+  // CHECK-NEXT: [[v5:%.*]] = arith.cmpi sge, [[v1]], [[vc0_i32]] : i32
+  // CHECK-NEXT: [[valloc_0:%.*]] = memref.alloc() : memref<12x3xi8>
+  // CHECK-NEXT: scf.if [[v5]] {
+  // CHECK-NEXT:   [[vsubview:%.*]] = memref.subview [[varg0]][0, 0] [12, 3] [1, 1] : memref<12x12xi8> to memref<12x3xi8, strided<[12, 1]>>
+  // CHECK-NEXT:   memref.copy [[vsubview]], [[valloc_0]] : memref<12x3xi8, strided<[12, 1]>> to memref<12x3xi8>
+  // CHECK-NEXT:   mpi.send([[valloc_0]], [[vc91_i32]], [[v1]]) : memref<12x3xi8>, i32, i32
+  // CHECK-NEXT: }
+  // CHECK-NEXT: scf.if [[v4]] {
+  // CHECK-NEXT:   mpi.recv([[valloc_0]], [[vc91_i32]], [[v0]]) : memref<12x3xi8>, i32, i32
+  // CHECK-NEXT:   [[vsubview:%.*]] = memref.subview [[varg0]][0, [[vc9]]] [12, 3] [1, 1] : memref<12x12xi8> to memref<12x3xi8, strided<[12, 1], offset: ?>>
+  // CHECK-NEXT:   memref.copy [[valloc_0]], [[vsubview]] : memref<12x3xi8> to memref<12x3xi8, strided<[12, 1], offset: ?>>
+  // CHECK-NEXT: }
+  // CHECK-NEXT: memref.dealloc [[valloc_0]] : memref<12x3xi8>
+  // CHECK-NEXT: return
+  mesh.update_halo %arg0 on @mesh0 split_axes = [[], [2]]
+    halo_sizes = [2, 3] : memref<12x12xi8>
   return
 }
 
+// Halo exchange on both tensor dims: dim 0 over mesh axis 0 (halos 1 and 2), then
+// dim 1 over mesh axis 1 (halos 3 and 4) restricted to the rows [1, 10).
+// CHECK-LABEL: func @update_halo_2d
 func.func @update_halo_2d(
-    // CHECK-SAME: %[[ARG:.*]]: memref<12x12xi8>
+    // CHECK-SAME: [[varg0:%.*]]: memref<12x12xi8>
     %arg0 : memref<12x12xi8>) {
-  %c2 = arith.constant 2 : i64
-  // CHECK-NEXT: mesh.update_halo %[[ARG]] on @mesh0
-  // CHECK-SAME: split_axes = {{\[\[}}0], [1]]
-  // CHECK-SAME: halo_sizes = [2, 2, %[[C2]], 2]
-  // CHECK-SAME: target_halo_sizes = [3, 3, 2, 2] : memref<12x12xi8>
+  // CHECK-NEXT: [[vc8:%.*]] = arith.constant 8 : index
+  // CHECK-NEXT: [[vc9:%.*]] = arith.constant 9 : index
+  // CHECK-NEXT: [[vc10:%.*]] = arith.constant 10 : index
+  // CHECK-NEXT: [[vc91_i32:%.*]] = arith.constant 91 : i32
+  // CHECK-NEXT: [[vc0_i32:%.*]] = arith.constant 0 : i32
+  // CHECK-NEXT: [[vproc_linear_idx:%.*]]:3 = mesh.process_multi_index on @mesh0 : index, index, index
+  // CHECK-NEXT: [[vdown_linear_idx:%.*]], [[vup_linear_idx:%.*]] = mesh.neighbors_linear_indices on @mesh0[[[vproc_linear_idx]]#0, [[vproc_linear_idx]]#1, [[vproc_linear_idx]]#2] split_axes = [0] : index, index
+  // CHECK-NEXT: [[v0:%.*]] = arith.index_cast [[vdown_linear_idx]] : index to i32
+  // CHECK-NEXT: [[v1:%.*]] = arith.index_cast [[vup_linear_idx]] : index to i32
+  // CHECK-NEXT: [[v2:%.*]] = arith.cmpi sge, [[v1]], [[vc0_i32]] : i32
+  // CHECK-NEXT: [[v3:%.*]] = arith.cmpi sge, [[v0]], [[vc0_i32]] : i32
+  // CHECK-NEXT: [[valloc:%.*]] = memref.alloc() : memref<1x12xi8>
+  // CHECK-NEXT: scf.if [[v3]] {
+  // CHECK-NEXT:   [[vsubview:%.*]] = memref.subview [[varg0]][[[vc10]], 0] [1, 12] [1, 1] : memref<12x12xi8> to memref<1x12xi8, strided<[12, 1], offset: ?>>
+  // CHECK-NEXT:   memref.copy [[vsubview]], [[valloc]] : memref<1x12xi8, strided<[12, 1], offset: ?>> to memref<1x12xi8>
+  // CHECK-NEXT:   mpi.send([[valloc]], [[vc91_i32]], [[v0]]) : memref<1x12xi8>, i32, i32
+  // CHECK-NEXT: }
+  // CHECK-NEXT: scf.if [[v2]] {
+  // CHECK-NEXT:   mpi.recv([[valloc]], [[vc91_i32]], [[v1]]) : memref<1x12xi8>, i32, i32
+  // CHECK-NEXT:   [[vsubview:%.*]] = memref.subview [[varg0]][0, 0] [1, 12] [1, 1] : memref<12x12xi8> to memref<1x12xi8, strided<[12, 1]>>
+  // CHECK-NEXT:   memref.copy [[valloc]], [[vsubview]] : memref<1x12xi8> to memref<1x12xi8, strided<[12, 1]>>
+  // CHECK-NEXT: }
+  // CHECK-NEXT: memref.dealloc [[valloc]] : memref<1x12xi8>
+  // CHECK-NEXT: [[v4:%.*]] = arith.cmpi sge, [[v0]], [[vc0_i32]] : i32
+  // CHECK-NEXT: [[v5:%.*]] = arith.cmpi sge, [[v1]], [[vc0_i32]] : i32
+  // CHECK-NEXT: [[valloc_0:%.*]] = memref.alloc() : memref<2x12xi8>
+  // CHECK-NEXT: scf.if [[v5]] {
+  // CHECK-NEXT:   [[vsubview:%.*]] = memref.subview [[varg0]][0, 0] [2, 12] [1, 1] : memref<12x12xi8> to memref<2x12xi8, strided<[12, 1]>>
+  // CHECK-NEXT:   memref.copy [[vsubview]], [[valloc_0]] : memref<2x12xi8, strided<[12, 1]>> to memref<2x12xi8>
+  // CHECK-NEXT:   mpi.send([[valloc_0]], [[vc91_i32]], [[v1]]) : memref<2x12xi8>, i32, i32
+  // CHECK-NEXT: }
+  // CHECK-NEXT: scf.if [[v4]] {
+  // CHECK-NEXT:   mpi.recv([[valloc_0]], [[vc91_i32]], [[v0]]) : memref<2x12xi8>, i32, i32
+  // CHECK-NEXT:   [[vsubview:%.*]] = memref.subview [[varg0]][[[vc10]], 0] [2, 12] [1, 1] : memref<12x12xi8> to memref<2x12xi8, strided<[12, 1], offset: ?>>
+  // CHECK-NEXT:   memref.copy [[valloc_0]], [[vsubview]] : memref<2x12xi8> to memref<2x12xi8, strided<[12, 1], offset: ?>>
+  // CHECK-NEXT: }
+  // CHECK-NEXT: memref.dealloc [[valloc_0]] : memref<2x12xi8>
+  // CHECK-NEXT: [[vdown_linear_idx_1:%.*]], [[vup_linear_idx_2:%.*]] = mesh.neighbors_linear_indices on @mesh0[[[vproc_linear_idx]]#0, [[vproc_linear_idx]]#1, [[vproc_linear_idx]]#2] split_axes = [1] : index, index
+  // CHECK-NEXT: [[v6:%.*]] = arith.index_cast [[vdown_linear_idx_1]] : index to i32
+  // CHECK-NEXT: [[v7:%.*]] = arith.index_cast [[vup_linear_idx_2]] : index to i32
+  // CHECK-NEXT: [[v8:%.*]] = arith.cmpi sge, [[v7]], [[vc0_i32]] : i32
+  // CHECK-NEXT: [[v9:%.*]] = arith.cmpi sge, [[v6]], [[vc0_i32]] : i32
+  // CHECK-NEXT: [[valloc_3:%.*]] = memref.alloc([[vc9]]) : memref<?x3xi8>
+  // CHECK-NEXT: scf.if [[v9]] {
+  // CHECK-NEXT:   [[vsubview:%.*]] = memref.subview [[varg0]][1, [[vc8]]] [[[vc9]], 3] [1, 1] : memref<12x12xi8> to memref<?x3xi8, strided<[12, 1], offset: ?>>
+  // CHECK-NEXT:   memref.copy [[vsubview]], [[valloc_3]] : memref<?x3xi8, strided<[12, 1], offset: ?>> to memref<?x3xi8>
+  // CHECK-NEXT:   mpi.send([[valloc_3]], [[vc91_i32]], [[v6]]) : memref<?x3xi8>, i32, i32
+  // CHECK-NEXT: }
+  // CHECK-NEXT: scf.if [[v8]] {
+  // CHECK-NEXT:   mpi.recv([[valloc_3]], [[vc91_i32]], [[v7]]) : memref<?x3xi8>, i32, i32
+  // CHECK-NEXT:   [[vsubview:%.*]] = memref.subview [[varg0]][1, 0] [[[vc9]], 3] [1, 1] : memref<12x12xi8> to memref<?x3xi8, strided<[12, 1], offset: 12>>
+  // CHECK-NEXT:   memref.copy [[valloc_3]], [[vsubview]] : memref<?x3xi8> to memref<?x3xi8, strided<[12, 1], offset: 12>>
+  // CHECK-NEXT: }
+  // CHECK-NEXT: memref.dealloc [[valloc_3]] : memref<?x3xi8>
+  // CHECK-NEXT: [[v10:%.*]] = arith.cmpi sge, [[v6]], [[vc0_i32]] : i32
+  // CHECK-NEXT: [[v11:%.*]] = arith.cmpi sge, [[v7]], [[vc0_i32]] : i32
+  // CHECK-NEXT: [[valloc_4:%.*]] = memref.alloc([[vc9]]) : memref<?x4xi8>
+  // CHECK-NEXT: scf.if [[v11]] {
+  // CHECK-NEXT:   [[vsubview:%.*]] = memref.subview [[varg0]][1, 0] [[[vc9]], 4] [1, 1] : memref<12x12xi8> to memref<?x4xi8, strided<[12, 1], offset: 12>>
+  // CHECK-NEXT:   memref.copy [[vsubview]], [[valloc_4]] : memref<?x4xi8, strided<[12, 1], offset: 12>> to memref<?x4xi8>
+  // CHECK-NEXT:   mpi.send([[valloc_4]], [[vc91_i32]], [[v7]]) : memref<?x4xi8>, i32, i32
+  // CHECK-NEXT: }
+  // CHECK-NEXT: scf.if [[v10]] {
+  // CHECK-NEXT:   mpi.recv([[valloc_4]], [[vc91_i32]], [[v6]]) : memref<?x4xi8>, i32, i32
+  // CHECK-NEXT:   [[vsubview:%.*]] = memref.subview [[varg0]][1, [[vc8]]] [[[vc9]], 4] [1, 1] : memref<12x12xi8> to memref<?x4xi8, strided<[12, 1], offset: ?>>
+  // CHECK-NEXT:   memref.copy [[valloc_4]], [[vsubview]] : memref<?x4xi8> to memref<?x4xi8, strided<[12, 1], offset: ?>>
+  // CHECK-NEXT: }
+  // CHECK-NEXT: memref.dealloc [[valloc_4]] : memref<?x4xi8>
+  // CHECK-NEXT: return
   mesh.update_halo %arg0 on @mesh0 split_axes = [[0], [1]]
-    halo_sizes = [2, 2, %c2, 2] target_halo_sizes = [3, 3, 2, 2]
-    : memref<12x12xi8>
+      halo_sizes = [1, 2, 3, 4]
+      : memref<12x12xi8>
   return
 }



More information about the Mlir-commits mailing list