[Mlir-commits] [mlir] [mlir][mesh, MPI] Mesh2mpi (PR #104566)

Frank Schlimbach llvmlistbot at llvm.org
Wed Nov 27 07:48:45 PST 2024


https://github.com/fschlimb updated https://github.com/llvm/llvm-project/pull/104566

>From 1aa51f74277eace6dcaf6372ba645b4627548bb4 Mon Sep 17 00:00:00 2001
From: "Schlimbach, Frank" <frank.schlimbach at intel.com>
Date: Wed, 14 Aug 2024 19:29:23 +0200
Subject: [PATCH 01/15] initial hack lowering mesh.update_halo to MPI

---
 .../mlir/Conversion/MeshToMPI/MeshToMPI.h     |  27 +++
 mlir/include/mlir/Conversion/Passes.h         |   1 +
 mlir/include/mlir/Conversion/Passes.td        |  17 ++
 mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td  |  33 ++++
 mlir/lib/Conversion/CMakeLists.txt            |   1 +
 mlir/lib/Conversion/MeshToMPI/CMakeLists.txt  |  22 +++
 mlir/lib/Conversion/MeshToMPI/MeshToMPI.cpp   | 171 ++++++++++++++++++
 mlir/lib/Dialect/Mesh/IR/MeshOps.cpp          |  19 ++
 .../MeshToMPI/convert-mesh-to-mpi.mlir        |  34 ++++
 9 files changed, 325 insertions(+)
 create mode 100644 mlir/include/mlir/Conversion/MeshToMPI/MeshToMPI.h
 create mode 100644 mlir/lib/Conversion/MeshToMPI/CMakeLists.txt
 create mode 100644 mlir/lib/Conversion/MeshToMPI/MeshToMPI.cpp
 create mode 100644 mlir/test/Conversion/MeshToMPI/convert-mesh-to-mpi.mlir

diff --git a/mlir/include/mlir/Conversion/MeshToMPI/MeshToMPI.h b/mlir/include/mlir/Conversion/MeshToMPI/MeshToMPI.h
new file mode 100644
index 00000000000000..6a2c196da45577
--- /dev/null
+++ b/mlir/include/mlir/Conversion/MeshToMPI/MeshToMPI.h
@@ -0,0 +1,27 @@
+//===- MeshToMPI.h - Convert Mesh to MPI dialect --*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_CONVERSION_MESHTOMPI_MESHTOMPI_H
+#define MLIR_CONVERSION_MESHTOMPI_MESHTOMPI_H
+
+#include "mlir/Pass/Pass.h"
+#include "mlir/Support/LLVM.h"
+
+namespace mlir {
+class Pass;
+
+#define GEN_PASS_DECL_CONVERTMESHTOMPIPASS
+#include "mlir/Conversion/Passes.h.inc"
+
+/// Lowers Mesh communication operations (updateHalo, AllGather, ...)
+/// to MPI primitives.
+std::unique_ptr<Pass> createConvertMeshToMPIPass();
+
+} // namespace mlir
+
+#endif // MLIR_CONVERSION_MESHTOMPI_MESHTOMPI_H
\ No newline at end of file
diff --git a/mlir/include/mlir/Conversion/Passes.h b/mlir/include/mlir/Conversion/Passes.h
index 2ab32836c80b1c..b577aa83946f23 100644
--- a/mlir/include/mlir/Conversion/Passes.h
+++ b/mlir/include/mlir/Conversion/Passes.h
@@ -51,6 +51,7 @@
 #include "mlir/Conversion/MemRefToEmitC/MemRefToEmitCPass.h"
 #include "mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h"
 #include "mlir/Conversion/MemRefToSPIRV/MemRefToSPIRVPass.h"
+#include "mlir/Conversion/MeshToMPI/MeshToMPI.h"
 #include "mlir/Conversion/NVGPUToNVVM/NVGPUToNVVM.h"
 #include "mlir/Conversion/NVVMToLLVM/NVVMToLLVM.h"
 #include "mlir/Conversion/OpenACCToSCF/ConvertOpenACCToSCF.h"
diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td
index 4d272ba219c6f1..83e0c5a06c43f7 100644
--- a/mlir/include/mlir/Conversion/Passes.td
+++ b/mlir/include/mlir/Conversion/Passes.td
@@ -878,6 +878,23 @@ def ConvertMemRefToSPIRV : Pass<"convert-memref-to-spirv"> {
   ];
 }
 
+//===----------------------------------------------------------------------===//
+// MeshToMPI
+//===----------------------------------------------------------------------===//
+
+def ConvertMeshToMPIPass : Pass<"convert-mesh-to-mpi"> {
+  let summary = "Convert Mesh dialect to MPI dialect.";
+  let description = [{
+    This pass converts communication operations
+    from the Mesh dialect to operations from the MPI dialect.
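+    It can be exercised with `mlir-opt --convert-mesh-to-mpi` (see the
+    accompanying conversion tests).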
+  }];
+  let dependentDialects = [
+    "memref::MemRefDialect",
+    "mpi::MPIDialect",
+    "scf::SCFDialect"
+  ];
+}
+
 //===----------------------------------------------------------------------===//
 // NVVMToLLVM
 //===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td
index 19498fe5a32d69..2c2b6e20f3654d 100644
--- a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td
+++ b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td
@@ -156,6 +156,39 @@ def Mesh_ProcessLinearIndexOp : Mesh_Op<"process_linear_index", [
   ];
 }
 
+def Mesh_NeighborsLinearIndicesOp : Mesh_Op<"neighbors_linear_indices", [
+  Pure,
+  DeclareOpInterfaceMethods<SymbolUserOpInterface>,
+  DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>
+]> {
+  let summary =
+      "For given split axes get the linear index the direct neighbor processes.";
+  let description = [{
+    Example:
+    ```
+    %down, %up = mesh.neighbors_linear_indices on @mesh[%device]
+               split_axes = [1] : index, index
+    ```
+    Given `@mesh` with shape `(10, 20, 30)`, `$device` = `(1, 2, 3)`, and
+    `$split_axes` = `[1]`, it returns the linear indices of the processes at
+    positions `(1, 1, 3)` and `(1, 3, 3)`: `633` and `693`.
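+    With row-major linearization of the mesh, `(1, 1, 3)` maps to
+    `1 * (20 * 30) + 1 * 30 + 3 = 633` and `(1, 3, 3)` maps to
+    `1 * (20 * 30) + 3 * 30 + 3 = 693`.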
+
+    A negative value is returned if `$device` has no neighbor in the given
+    direction along the given `split_axes`.
+  }];
+  let arguments = (ins FlatSymbolRefAttr:$mesh,
+                       Variadic<Index>:$device,
+                       Mesh_MeshAxesAttr:$split_axes);
+  let results = (outs Index:$neighbor_down, Index:$neighbor_up);
+  let assemblyFormat =  [{
+      `on` $mesh `[` $device `]`
+      `split_axes` `=` $split_axes
+      attr-dict `:` type(results)
+  }];
+}
+
 //===----------------------------------------------------------------------===//
 // Sharding operations.
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Conversion/CMakeLists.txt b/mlir/lib/Conversion/CMakeLists.txt
index 6651d87162257f..62461c0cea08af 100644
--- a/mlir/lib/Conversion/CMakeLists.txt
+++ b/mlir/lib/Conversion/CMakeLists.txt
@@ -41,6 +41,7 @@ add_subdirectory(MathToSPIRV)
 add_subdirectory(MemRefToEmitC)
 add_subdirectory(MemRefToLLVM)
 add_subdirectory(MemRefToSPIRV)
+add_subdirectory(MeshToMPI)
 add_subdirectory(NVGPUToNVVM)
 add_subdirectory(NVVMToLLVM)
 add_subdirectory(OpenACCToSCF)
diff --git a/mlir/lib/Conversion/MeshToMPI/CMakeLists.txt b/mlir/lib/Conversion/MeshToMPI/CMakeLists.txt
new file mode 100644
index 00000000000000..95815a683f6d6a
--- /dev/null
+++ b/mlir/lib/Conversion/MeshToMPI/CMakeLists.txt
@@ -0,0 +1,22 @@
+add_mlir_conversion_library(MLIRMeshToMPI
+  MeshToMPI.cpp
+
+  ADDITIONAL_HEADER_DIRS
+  ${MLIR_MAIN_INCLUDE_DIR}/mlir/Conversion/MeshToMPI
+
+  DEPENDS
+  MLIRConversionPassIncGen
+
+  LINK_COMPONENTS
+  Core
+
+  LINK_LIBS PUBLIC
+  MLIRFuncDialect
+  MLIRIR
+  MLIRLinalgTransforms
+  MLIRMemRefDialect
+  MLIRPass
+  MLIRMeshDialect
+  MLIRMPIDialect
+  MLIRTransforms
+  )
diff --git a/mlir/lib/Conversion/MeshToMPI/MeshToMPI.cpp b/mlir/lib/Conversion/MeshToMPI/MeshToMPI.cpp
new file mode 100644
index 00000000000000..b4cf9da8497a2d
--- /dev/null
+++ b/mlir/lib/Conversion/MeshToMPI/MeshToMPI.cpp
@@ -0,0 +1,171 @@
+//===- MeshToMPI.cpp - Mesh to MPI dialect conversion ------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements a translation of Mesh communicatin ops tp MPI ops.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Conversion/MeshToMPI/MeshToMPI.h"
+
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/MPI/IR/MPI.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/Mesh/IR/MeshOps.h"
+#include "mlir/Dialect/SCF/IR/SCF.h"
+#include "mlir/Dialect/Utils/StaticValueUtils.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/BuiltinAttributes.h"
+#include "mlir/IR/BuiltinTypes.h"
+
+#define DEBUG_TYPE "mesh-to-mpi"
+#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
+
+namespace mlir {
+#define GEN_PASS_DEF_CONVERTMESHTOMPIPASS
+#include "mlir/Conversion/Passes.h.inc"
+} // namespace mlir
+
+using namespace mlir;
+using namespace mlir::mesh;
+
+namespace {
+struct ConvertMeshToMPIPass
+    : public impl::ConvertMeshToMPIPassBase<ConvertMeshToMPIPass> {
+  using Base::Base;
+
+  /// Run the dialect converter on the module.
+  void runOnOperation() override {
+    getOperation()->walk([&](UpdateHaloOp op) {
+      SymbolTableCollection symbolTableCollection;
+      OpBuilder builder(op);
+      auto loc = op.getLoc();
+
+      auto toValue = [&builder, &loc](OpFoldResult &v) {
+        return v.is<Value>()
+                   ? v.get<Value>()
+                   : builder.create<::mlir::arith::ConstantOp>(
+                         loc,
+                         builder.getIndexAttr(
+                             cast<IntegerAttr>(v.get<Attribute>()).getInt()));
+      };
+
+      auto array = op.getInput();
+      auto rank = array.getType().getRank();
+      auto mesh = op.getMesh();
+      auto meshOp = getMesh(op, symbolTableCollection);
+      auto haloSizes = getMixedValues(op.getStaticHaloSizes(),
+                                      op.getDynamicHaloSizes(), builder);
+      for (auto &sz : haloSizes) {
+        if (sz.is<Value>()) {
+          sz = builder
+                   .create<arith::IndexCastOp>(loc, builder.getIndexType(),
+                                               sz.get<Value>())
+                   .getResult();
+        }
+      }
+
+      SmallVector<OpFoldResult> offsets(rank, builder.getIndexAttr(0));
+      SmallVector<OpFoldResult> strides(rank, builder.getIndexAttr(1));
+      SmallVector<OpFoldResult> shape(rank);
+      for (auto [i, s] : llvm::enumerate(array.getType().getShape())) {
+        if (ShapedType::isDynamic(s)) {
+          shape[i] = builder.create<memref::DimOp>(loc, array, s).getResult();
+        } else {
+          shape[i] = builder.getIndexAttr(s);
+        }
+      }
+
+      auto tagAttr = builder.getI32IntegerAttr(91); // whatever
+      auto tag = builder.create<::mlir::arith::ConstantOp>(loc, tagAttr);
+      auto zeroAttr = builder.getI32IntegerAttr(0); // whatever
+      auto zero = builder.create<::mlir::arith::ConstantOp>(loc, zeroAttr);
+      SmallVector<Type> indexResultTypes(meshOp.getShape().size(),
+                                         builder.getIndexType());
+      auto myMultiIndex =
+          builder.create<ProcessMultiIndexOp>(loc, indexResultTypes, mesh)
+              .getResult();
+      auto currHaloDim = 0;
+
+      for (auto [dim, splitAxes] : llvm::enumerate(op.getSplitAxes())) {
+        if (!splitAxes.empty()) {
+          auto tmp = builder
+                         .create<NeighborsLinearIndicesOp>(
+                             loc, mesh, myMultiIndex, splitAxes)
+                         .getResults();
+          Value neighbourIDs[2] = {builder.create<arith::IndexCastOp>(
+                                       loc, builder.getI32Type(), tmp[0]),
+                                   builder.create<arith::IndexCastOp>(
+                                       loc, builder.getI32Type(), tmp[1])};
+          auto orgDimSize = shape[dim];
+          auto upperOffset = builder.create<arith::SubIOp>(
+              loc, toValue(shape[dim]), toValue(haloSizes[dim * 2 + 1]));
+
+          // make sure we send/recv in a way that does not lead to a dead-lock
+          // This is by far not optimal, this should be at least MPI_sendrecv
+          // and - probably even more importantly - buffers should be re-used
+          // Currently using temporary, contiguous buffer for MPI communication
+          auto genSendRecv = [&](auto dim, bool upperHalo) {
+            auto orgOffset = offsets[dim];
+            shape[dim] =
+                upperHalo ? haloSizes[dim * 2 + 1] : haloSizes[dim * 2];
+            auto to = upperHalo ? neighbourIDs[1] : neighbourIDs[0];
+            auto from = upperHalo ? neighbourIDs[0] : neighbourIDs[1];
+            auto hasFrom = builder.create<arith::CmpIOp>(
+                loc, arith::CmpIPredicate::sge, from, zero);
+            auto hasTo = builder.create<arith::CmpIOp>(
+                loc, arith::CmpIPredicate::sge, to, zero);
+            auto buffer = builder.create<memref::AllocOp>(
+                loc, shape, array.getType().getElementType());
+            builder.create<scf::IfOp>(
+                loc, hasTo, [&](OpBuilder &builder, Location loc) {
+                  offsets[dim] = upperHalo
+                                     ? OpFoldResult(builder.getIndexAttr(0))
+                                     : OpFoldResult(upperOffset);
+                  auto subview = builder.create<memref::SubViewOp>(
+                      loc, array, offsets, shape, strides);
+                  builder.create<memref::CopyOp>(loc, subview, buffer);
+                  builder.create<mpi::SendOp>(loc, TypeRange{}, buffer, tag,
+                                              to);
+                  builder.create<scf::YieldOp>(loc);
+                });
+            builder.create<scf::IfOp>(
+                loc, hasFrom, [&](OpBuilder &builder, Location loc) {
+                  offsets[dim] = upperHalo
+                                     ? OpFoldResult(upperOffset)
+                                     : OpFoldResult(builder.getIndexAttr(0));
+                  builder.create<mpi::RecvOp>(loc, TypeRange{}, buffer, tag,
+                                              from);
+                  auto subview = builder.create<memref::SubViewOp>(
+                      loc, array, offsets, shape, strides);
+                  builder.create<memref::CopyOp>(loc, buffer, subview);
+                  builder.create<scf::YieldOp>(loc);
+                });
+            builder.create<memref::DeallocOp>(loc, buffer);
+            offsets[dim] = orgOffset;
+          };
+
+          genSendRecv(dim, false);
+          genSendRecv(dim, true);
+
+          shape[dim] = builder
+                           .create<arith::SubIOp>(
+                               loc, toValue(orgDimSize),
+                               builder
+                                   .create<arith::AddIOp>(
+                                       loc, toValue(haloSizes[dim * 2]),
+                                       toValue(haloSizes[dim * 2 + 1]))
+                                   .getResult())
+                           .getResult();
+          offsets[dim] = haloSizes[dim * 2];
+          ++currHaloDim;
+        }
+      }
+    });
+  }
+};
+} // namespace
\ No newline at end of file
diff --git a/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp b/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
index c5570d8ee8a443..33460ff25e9e45 100644
--- a/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
+++ b/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
@@ -837,6 +837,25 @@ void ProcessLinearIndexOp::getAsmResultNames(
   setNameFn(getResult(), "proc_linear_idx");
 }
 
+//===----------------------------------------------------------------------===//
+// mesh.neighbors_linear_indices op
+//===----------------------------------------------------------------------===//
+
+LogicalResult
+NeighborsLinearIndicesOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
+  auto mesh = ::getMeshAndVerify(getOperation(), getMeshAttr(), symbolTable);
+  if (failed(mesh)) {
+    return failure();
+  }
+  return success();
+}
+
+void NeighborsLinearIndicesOp::getAsmResultNames(
+    function_ref<void(Value, StringRef)> setNameFn) {
+  setNameFn(getNeighborDown(), "down_linear_idx");
+  setNameFn(getNeighborUp(), "up_linear_idx");
+}
+
 //===----------------------------------------------------------------------===//
 // collective communication ops
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Conversion/MeshToMPI/convert-mesh-to-mpi.mlir b/mlir/test/Conversion/MeshToMPI/convert-mesh-to-mpi.mlir
new file mode 100644
index 00000000000000..9ef826ca0cdace
--- /dev/null
+++ b/mlir/test/Conversion/MeshToMPI/convert-mesh-to-mpi.mlir
@@ -0,0 +1,34 @@
+// RUN: mlir-opt %s -split-input-file -convert-mesh-to-mpi | FileCheck %s
+
+// CHECK: mesh.mesh @mesh0
+mesh.mesh @mesh0(shape = 2x2x4)
+
+// -----
+
+// CHECK-LABEL: func @update_halo
+func.func @update_halo_1d(
+    // CHECK-SAME: %[[ARG:.*]]: memref<12x12xi8>
+    %arg0 : memref<12x12xi8>) {
+  // CHECK-NEXT: %[[C2:.*]] = arith.constant 2 : i64
+  // CHECK-NEXT: mesh.update_halo %[[ARG]] on @mesh0
+  // CHECK-SAME: split_axes = {{\[\[}}0]]
+  // CHECK-SAME: halo_sizes = [2, %c2_i64] : memref<12x12xi8>
+  %c2 = arith.constant 2 : i64
+  mesh.update_halo %arg0 on @mesh0 split_axes = [[0]]
+    halo_sizes = [2, %c2] : memref<12x12xi8>
+  return
+}
+
+func.func @update_halo_2d(
+    // CHECK-SAME: %[[ARG:.*]]: memref<12x12xi8>
+    %arg0 : memref<12x12xi8>) {
+  %c2 = arith.constant 2 : i64
+  // CHECK-NEXT: mesh.update_halo %[[ARG]] on @mesh0
+  // CHECK-SAME: split_axes = {{\[\[}}0], [1]]
+  // CHECK-SAME: halo_sizes = [2, 2, %[[C2]], 2]
+  // CHECK-SAME: target_halo_sizes = [3, 3, 2, 2] : memref<12x12xi8>
+  mesh.update_halo %arg0 on @mesh0 split_axes = [[0], [1]]
+    halo_sizes = [2, 2, %c2, 2] target_halo_sizes = [3, 3, 2, 2]
+    : memref<12x12xi8>
+  return
+}

>From 8b8c6e4a12e1301d126b6dcd78ae69a506a58e12 Mon Sep 17 00:00:00 2001
From: "Schlimbach, Frank" <frank.schlimbach at intel.com>
Date: Fri, 16 Aug 2024 10:55:28 +0200
Subject: [PATCH 02/15] dim fixes, proper testing

---
 mlir/lib/Conversion/MeshToMPI/MeshToMPI.cpp   | 306 ++++++++++--------
 .../MeshToMPI/convert-mesh-to-mpi.mlir        | 179 ++++++++--
 2 files changed, 339 insertions(+), 146 deletions(-)

diff --git a/mlir/lib/Conversion/MeshToMPI/MeshToMPI.cpp b/mlir/lib/Conversion/MeshToMPI/MeshToMPI.cpp
index b4cf9da8497a2d..42d885a109ee79 100644
--- a/mlir/lib/Conversion/MeshToMPI/MeshToMPI.cpp
+++ b/mlir/lib/Conversion/MeshToMPI/MeshToMPI.cpp
@@ -6,7 +6,7 @@
 //
 //===----------------------------------------------------------------------===//
 //
-// This file implements a translation of Mesh communicatin ops tp MPI ops.
+// This file implements a translation of Mesh communication ops to MPI ops.
 //
 //===----------------------------------------------------------------------===//
 
@@ -21,6 +21,8 @@
 #include "mlir/IR/Builders.h"
 #include "mlir/IR/BuiltinAttributes.h"
 #include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
 
 #define DEBUG_TYPE "mesh-to-mpi"
 #define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
@@ -34,138 +36,190 @@ using namespace mlir;
 using namespace mlir::mesh;
 
 namespace {
+
+// This pattern converts the mesh.update_halo operation to MPI calls
+struct ConvertUpdateHaloOp
+    : public mlir::OpRewritePattern<mlir::mesh::UpdateHaloOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  mlir::LogicalResult
+  matchAndRewrite(mlir::mesh::UpdateHaloOp op,
+                  mlir::PatternRewriter &rewriter) const override {
+    // Halos are exchanged as 2 blocks per dimension (one for each side: down
+    // and up). It is assumed that the last dim in a default memref is
+    // contiguous, hence the halo of the first dim spans complete rows and
+    // should itself be contiguous (unless the source is not). The size of the
+    // exchanged data will decrease when iterating over dimensions. That's
+    // good because the halos of the last dim will be the most fragmented.
+    // memref.subview is used to read and write the halo data from and to the
+    // local data. Subviews and halos have dynamic and static values, so
+    // OpFoldResults are used whenever possible.
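+    // Viewing one split dim in isolation, the local array is laid out as
+    //   [ down halo | owned data | up halo ]
+    // and the halo regions get filled from the neighbors while the adjacent
+    // owned regions are what gets sent out.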
+
+    SymbolTableCollection symbolTableCollection;
+    auto loc = op.getLoc();
+
+    // convert an OpFoldResult into a Value
+    auto toValue = [&rewriter, &loc](OpFoldResult &v) {
+      return v.is<Value>()
+                 ? v.get<Value>()
+                 : rewriter.create<::mlir::arith::ConstantOp>(
+                       loc,
+                       rewriter.getIndexAttr(
+                           cast<IntegerAttr>(v.get<Attribute>()).getInt()));
+    };
+
+    auto array = op.getInput();
+    auto rank = array.getType().getRank();
+    auto mesh = op.getMesh();
+    auto meshOp = getMesh(op, symbolTableCollection);
+    auto haloSizes = getMixedValues(op.getStaticHaloSizes(),
+                                    op.getDynamicHaloSizes(), rewriter);
+    // subviews need Index values
+    for (auto &sz : haloSizes) {
+      if (sz.is<Value>()) {
+        sz = rewriter
+                 .create<arith::IndexCastOp>(loc, rewriter.getIndexType(),
+                                             sz.get<Value>())
+                 .getResult();
+      }
+    }
+
+    // most of the offset/size/stride data is the same for all dims
+    SmallVector<OpFoldResult> offsets(rank, rewriter.getIndexAttr(0));
+    SmallVector<OpFoldResult> strides(rank, rewriter.getIndexAttr(1));
+    SmallVector<OpFoldResult> shape(rank);
+    // we need the actual shape to compute offsets and sizes
+    for (auto [i, s] : llvm::enumerate(array.getType().getShape())) {
+      if (ShapedType::isDynamic(s)) {
+        shape[i] = rewriter.create<memref::DimOp>(loc, array, s).getResult();
+      } else {
+        shape[i] = rewriter.getIndexAttr(s);
+      }
+    }
+
+    auto tagAttr = rewriter.getI32IntegerAttr(91); // we just pick something
+    auto tag = rewriter.create<::mlir::arith::ConstantOp>(loc, tagAttr);
+    auto zeroAttr = rewriter.getI32IntegerAttr(0); // for detecting v<0
+    auto zero = rewriter.create<::mlir::arith::ConstantOp>(loc, zeroAttr);
+    SmallVector<Type> indexResultTypes(meshOp.getShape().size(),
+                                       rewriter.getIndexType());
+    auto myMultiIndex =
+        rewriter.create<ProcessMultiIndexOp>(loc, indexResultTypes, mesh)
+            .getResult();
+    // halo sizes are provided for split dimensions only
+    auto currHaloDim = 0;
+
+    for (auto [dim, splitAxes] : llvm::enumerate(op.getSplitAxes())) {
+      if (splitAxes.empty()) {
+        continue;
+      }
+      // Get the linearized ids of the neighbors (down and up) for the
+      // given split
+      auto tmp = rewriter
+                     .create<NeighborsLinearIndicesOp>(loc, mesh, myMultiIndex,
+                                                       splitAxes)
+                     .getResults();
+      // MPI operates on i32...
+      Value neighbourIDs[2] = {rewriter.create<arith::IndexCastOp>(
+                                   loc, rewriter.getI32Type(), tmp[0]),
+                               rewriter.create<arith::IndexCastOp>(
+                                   loc, rewriter.getI32Type(), tmp[1])};
+      // store for later
+      auto orgDimSize = shape[dim];
+      // this dim's offset to the start of the upper halo
+      auto upperOffset = rewriter.create<arith::SubIOp>(
+          loc, toValue(shape[dim]), toValue(haloSizes[currHaloDim * 2 + 1]));
+
+      // Make sure we send/recv in a way that does not lead to a deadlock.
+      // The current approach is far from optimal; it should at least use a
+      // red-black pattern or MPI_sendrecv.
+      // Also, buffers should be re-used.
+      // Still using temporary contiguous buffers for MPI communication...
+      // Still yielding a "serialized" communication pattern...
+      auto genSendRecv = [&](auto dim, bool upperHalo) {
+        auto orgOffset = offsets[dim];
+        shape[dim] = upperHalo ? haloSizes[currHaloDim * 2 + 1]
+                               : haloSizes[currHaloDim * 2];
+        // Check if we need to send and/or receive
+        // Processes on the mesh borders have only one neighbor
+        auto to = upperHalo ? neighbourIDs[1] : neighbourIDs[0];
+        auto from = upperHalo ? neighbourIDs[0] : neighbourIDs[1];
+        auto hasFrom = rewriter.create<arith::CmpIOp>(
+            loc, arith::CmpIPredicate::sge, from, zero);
+        auto hasTo = rewriter.create<arith::CmpIOp>(
+            loc, arith::CmpIPredicate::sge, to, zero);
+        auto buffer = rewriter.create<memref::AllocOp>(
+            loc, shape, array.getType().getElementType());
+        // if has neighbor: copy halo data from array to buffer and send
+        rewriter.create<scf::IfOp>(
+            loc, hasTo, [&](OpBuilder &builder, Location loc) {
+              offsets[dim] = upperHalo ? OpFoldResult(builder.getIndexAttr(0))
+                                       : OpFoldResult(upperOffset);
+              auto subview = builder.create<memref::SubViewOp>(
+                  loc, array, offsets, shape, strides);
+              builder.create<memref::CopyOp>(loc, subview, buffer);
+              builder.create<mpi::SendOp>(loc, TypeRange{}, buffer, tag, to);
+              builder.create<scf::YieldOp>(loc);
+            });
+        // if has neighbor: receive halo data into buffer and copy to array
+        rewriter.create<scf::IfOp>(
+            loc, hasFrom, [&](OpBuilder &builder, Location loc) {
+              offsets[dim] = upperHalo ? OpFoldResult(upperOffset)
+                                       : OpFoldResult(builder.getIndexAttr(0));
+              builder.create<mpi::RecvOp>(loc, TypeRange{}, buffer, tag, from);
+              auto subview = builder.create<memref::SubViewOp>(
+                  loc, array, offsets, shape, strides);
+              builder.create<memref::CopyOp>(loc, buffer, subview);
+              builder.create<scf::YieldOp>(loc);
+            });
+        rewriter.create<memref::DeallocOp>(loc, buffer);
+        offsets[dim] = orgOffset;
+      };
+
+      genSendRecv(dim, false);
+      genSendRecv(dim, true);
+
+      // prepare shape and offsets for next split dim
+      auto _haloSz =
+          rewriter
+              .create<arith::AddIOp>(loc, toValue(haloSizes[currHaloDim * 2]),
+                                     toValue(haloSizes[currHaloDim * 2 + 1]))
+              .getResult();
+      // the shape for next halo excludes the halo on both ends for the
+      // current dim
+      shape[dim] =
+          rewriter.create<arith::SubIOp>(loc, toValue(orgDimSize), _haloSz)
+              .getResult();
+      // the offsets for next halo starts after the down halo for the
+      // current dim
+      offsets[dim] = haloSizes[currHaloDim * 2];
+      // on to next halo
+      ++currHaloDim;
+    }
+    rewriter.eraseOp(op);
+    return mlir::success();
+  }
+};
+
 struct ConvertMeshToMPIPass
     : public impl::ConvertMeshToMPIPassBase<ConvertMeshToMPIPass> {
   using Base::Base;
 
   /// Run the dialect converter on the module.
   void runOnOperation() override {
-    getOperation()->walk([&](UpdateHaloOp op) {
-      SymbolTableCollection symbolTableCollection;
-      OpBuilder builder(op);
-      auto loc = op.getLoc();
-
-      auto toValue = [&builder, &loc](OpFoldResult &v) {
-        return v.is<Value>()
-                   ? v.get<Value>()
-                   : builder.create<::mlir::arith::ConstantOp>(
-                         loc,
-                         builder.getIndexAttr(
-                             cast<IntegerAttr>(v.get<Attribute>()).getInt()));
-      };
+    auto *ctx = &getContext();
+    mlir::RewritePatternSet patterns(ctx);
 
-      auto array = op.getInput();
-      auto rank = array.getType().getRank();
-      auto mesh = op.getMesh();
-      auto meshOp = getMesh(op, symbolTableCollection);
-      auto haloSizes = getMixedValues(op.getStaticHaloSizes(),
-                                      op.getDynamicHaloSizes(), builder);
-      for (auto &sz : haloSizes) {
-        if (sz.is<Value>()) {
-          sz = builder
-                   .create<arith::IndexCastOp>(loc, builder.getIndexType(),
-                                               sz.get<Value>())
-                   .getResult();
-        }
-      }
-
-      SmallVector<OpFoldResult> offsets(rank, builder.getIndexAttr(0));
-      SmallVector<OpFoldResult> strides(rank, builder.getIndexAttr(1));
-      SmallVector<OpFoldResult> shape(rank);
-      for (auto [i, s] : llvm::enumerate(array.getType().getShape())) {
-        if (ShapedType::isDynamic(s)) {
-          shape[i] = builder.create<memref::DimOp>(loc, array, s).getResult();
-        } else {
-          shape[i] = builder.getIndexAttr(s);
-        }
-      }
+    patterns.insert<ConvertUpdateHaloOp>(ctx);
 
-      auto tagAttr = builder.getI32IntegerAttr(91); // whatever
-      auto tag = builder.create<::mlir::arith::ConstantOp>(loc, tagAttr);
-      auto zeroAttr = builder.getI32IntegerAttr(0); // whatever
-      auto zero = builder.create<::mlir::arith::ConstantOp>(loc, zeroAttr);
-      SmallVector<Type> indexResultTypes(meshOp.getShape().size(),
-                                         builder.getIndexType());
-      auto myMultiIndex =
-          builder.create<ProcessMultiIndexOp>(loc, indexResultTypes, mesh)
-              .getResult();
-      auto currHaloDim = 0;
-
-      for (auto [dim, splitAxes] : llvm::enumerate(op.getSplitAxes())) {
-        if (!splitAxes.empty()) {
-          auto tmp = builder
-                         .create<NeighborsLinearIndicesOp>(
-                             loc, mesh, myMultiIndex, splitAxes)
-                         .getResults();
-          Value neighbourIDs[2] = {builder.create<arith::IndexCastOp>(
-                                       loc, builder.getI32Type(), tmp[0]),
-                                   builder.create<arith::IndexCastOp>(
-                                       loc, builder.getI32Type(), tmp[1])};
-          auto orgDimSize = shape[dim];
-          auto upperOffset = builder.create<arith::SubIOp>(
-              loc, toValue(shape[dim]), toValue(haloSizes[dim * 2 + 1]));
-
-          // make sure we send/recv in a way that does not lead to a dead-lock
-          // This is by far not optimal, this should be at least MPI_sendrecv
-          // and - probably even more importantly - buffers should be re-used
-          // Currently using temporary, contiguous buffer for MPI communication
-          auto genSendRecv = [&](auto dim, bool upperHalo) {
-            auto orgOffset = offsets[dim];
-            shape[dim] =
-                upperHalo ? haloSizes[dim * 2 + 1] : haloSizes[dim * 2];
-            auto to = upperHalo ? neighbourIDs[1] : neighbourIDs[0];
-            auto from = upperHalo ? neighbourIDs[0] : neighbourIDs[1];
-            auto hasFrom = builder.create<arith::CmpIOp>(
-                loc, arith::CmpIPredicate::sge, from, zero);
-            auto hasTo = builder.create<arith::CmpIOp>(
-                loc, arith::CmpIPredicate::sge, to, zero);
-            auto buffer = builder.create<memref::AllocOp>(
-                loc, shape, array.getType().getElementType());
-            builder.create<scf::IfOp>(
-                loc, hasTo, [&](OpBuilder &builder, Location loc) {
-                  offsets[dim] = upperHalo
-                                     ? OpFoldResult(builder.getIndexAttr(0))
-                                     : OpFoldResult(upperOffset);
-                  auto subview = builder.create<memref::SubViewOp>(
-                      loc, array, offsets, shape, strides);
-                  builder.create<memref::CopyOp>(loc, subview, buffer);
-                  builder.create<mpi::SendOp>(loc, TypeRange{}, buffer, tag,
-                                              to);
-                  builder.create<scf::YieldOp>(loc);
-                });
-            builder.create<scf::IfOp>(
-                loc, hasFrom, [&](OpBuilder &builder, Location loc) {
-                  offsets[dim] = upperHalo
-                                     ? OpFoldResult(upperOffset)
-                                     : OpFoldResult(builder.getIndexAttr(0));
-                  builder.create<mpi::RecvOp>(loc, TypeRange{}, buffer, tag,
-                                              from);
-                  auto subview = builder.create<memref::SubViewOp>(
-                      loc, array, offsets, shape, strides);
-                  builder.create<memref::CopyOp>(loc, buffer, subview);
-                  builder.create<scf::YieldOp>(loc);
-                });
-            builder.create<memref::DeallocOp>(loc, buffer);
-            offsets[dim] = orgOffset;
-          };
-
-          genSendRecv(dim, false);
-          genSendRecv(dim, true);
-
-          shape[dim] = builder
-                           .create<arith::SubIOp>(
-                               loc, toValue(orgDimSize),
-                               builder
-                                   .create<arith::AddIOp>(
-                                       loc, toValue(haloSizes[dim * 2]),
-                                       toValue(haloSizes[dim * 2 + 1]))
-                                   .getResult())
-                           .getResult();
-          offsets[dim] = haloSizes[dim * 2];
-          ++currHaloDim;
-        }
-      }
-    });
+    (void)mlir::applyPatternsAndFoldGreedily(getOperation(),
+                                             std::move(patterns));
   }
 };
-} // namespace
\ No newline at end of file
+
+} // namespace
+
+// Create a pass that converts Mesh to MPI.
+std::unique_ptr<::mlir::Pass> mlir::createConvertMeshToMPIPass() {
+  return std::make_unique<ConvertMeshToMPIPass>();
+}
diff --git a/mlir/test/Conversion/MeshToMPI/convert-mesh-to-mpi.mlir b/mlir/test/Conversion/MeshToMPI/convert-mesh-to-mpi.mlir
index 9ef826ca0cdace..5f563364272d96 100644
--- a/mlir/test/Conversion/MeshToMPI/convert-mesh-to-mpi.mlir
+++ b/mlir/test/Conversion/MeshToMPI/convert-mesh-to-mpi.mlir
@@ -1,34 +1,173 @@
-// RUN: mlir-opt %s -split-input-file -convert-mesh-to-mpi | FileCheck %s
+// RUN: mlir-opt %s -convert-mesh-to-mpi | FileCheck %s
 
 // CHECK: mesh.mesh @mesh0
 mesh.mesh @mesh0(shape = 2x2x4)
 
-// -----
-
-// CHECK-LABEL: func @update_halo
-func.func @update_halo_1d(
-    // CHECK-SAME: %[[ARG:.*]]: memref<12x12xi8>
+// CHECK-LABEL: func @update_halo_1d_first
+func.func @update_halo_1d_first(
+  // CHECK-SAME: [[varg0:%.*]]: memref<12x12xi8>
     %arg0 : memref<12x12xi8>) {
-  // CHECK-NEXT: %[[C2:.*]] = arith.constant 2 : i64
-  // CHECK-NEXT: mesh.update_halo %[[ARG]] on @mesh0
-  // CHECK-SAME: split_axes = {{\[\[}}0]]
-  // CHECK-SAME: halo_sizes = [2, %c2_i64] : memref<12x12xi8>
-  %c2 = arith.constant 2 : i64
+  // CHECK-NEXT: [[vc9:%.*]] = arith.constant 9 : index
+  // CHECK-NEXT: [[vc91_i32:%.*]] = arith.constant 91 : i32
+  // CHECK-NEXT: [[vc0_i32:%.*]] = arith.constant 0 : i32
+  // CHECK-NEXT: [[vproc_linear_idx:%.*]]:3 = mesh.process_multi_index on @mesh0 : index, index, index
+  // CHECK-NEXT: [[vdown_linear_idx:%.*]], [[vup_linear_idx:%.*]] = mesh.neighbors_linear_indices on @mesh0[[[vproc_linear_idx]]#0, [[vproc_linear_idx]]#1, [[vproc_linear_idx]]#2] split_axes = [0] : index, index
+  // CHECK-NEXT: [[v0:%.*]] = arith.index_cast [[vdown_linear_idx]] : index to i32
+  // CHECK-NEXT: [[v1:%.*]] = arith.index_cast [[vup_linear_idx]] : index to i32
+  // CHECK-NEXT: [[v2:%.*]] = arith.cmpi sge, [[v1]], [[vc0_i32]] : i32
+  // CHECK-NEXT: [[v3:%.*]] = arith.cmpi sge, [[v0]], [[vc0_i32]] : i32
+  // CHECK-NEXT: [[valloc:%.*]] = memref.alloc() : memref<2x12xi8>
+  // CHECK-NEXT: scf.if [[v3]] {
+  // CHECK-NEXT:   [[vsubview:%.*]] = memref.subview [[varg0]][[[vc9]], 0] [2, 12] [1, 1] : memref<12x12xi8> to memref<2x12xi8, strided<[12, 1], offset: ?>>
+  // CHECK-NEXT:   memref.copy [[vsubview]], [[valloc]] : memref<2x12xi8, strided<[12, 1], offset: ?>> to memref<2x12xi8>
+  // CHECK-NEXT:   mpi.send([[valloc]], [[vc91_i32]], [[v0]]) : memref<2x12xi8>, i32, i32
+  // CHECK-NEXT: }
+  // CHECK-NEXT: scf.if [[v2]] {
+  // CHECK-NEXT:   mpi.recv([[valloc]], [[vc91_i32]], [[v1]]) : memref<2x12xi8>, i32, i32
+  // CHECK-NEXT:   [[vsubview:%.*]] = memref.subview [[varg0]][0, 0] [2, 12] [1, 1] : memref<12x12xi8> to memref<2x12xi8, strided<[12, 1]>>
+  // CHECK-NEXT:   memref.copy [[valloc]], [[vsubview]] : memref<2x12xi8> to memref<2x12xi8, strided<[12, 1]>>
+  // CHECK-NEXT: }
+  // CHECK-NEXT: memref.dealloc [[valloc]] : memref<2x12xi8>
+  // CHECK-NEXT: [[v4:%.*]] = arith.cmpi sge, [[v0]], [[vc0_i32]] : i32
+  // CHECK-NEXT: [[v5:%.*]] = arith.cmpi sge, [[v1]], [[vc0_i32]] : i32
+  // CHECK-NEXT: [[valloc_0:%.*]] = memref.alloc() : memref<3x12xi8>
+  // CHECK-NEXT: scf.if [[v5]] {
+  // CHECK-NEXT:   [[vsubview:%.*]] = memref.subview [[varg0]][0, 0] [3, 12] [1, 1] : memref<12x12xi8> to memref<3x12xi8, strided<[12, 1]>>
+  // CHECK-NEXT:   memref.copy [[vsubview]], [[valloc_0]] : memref<3x12xi8, strided<[12, 1]>> to memref<3x12xi8>
+  // CHECK-NEXT:   mpi.send([[valloc_0]], [[vc91_i32]], [[v1]]) : memref<3x12xi8>, i32, i32
+  // CHECK-NEXT: }
+  // CHECK-NEXT: scf.if [[v4]] {
+  // CHECK-NEXT:   mpi.recv([[valloc_0]], [[vc91_i32]], [[v0]]) : memref<3x12xi8>, i32, i32
+  // CHECK-NEXT:   [[vsubview:%.*]] = memref.subview [[varg0]][[[vc9]], 0] [3, 12] [1, 1] : memref<12x12xi8> to memref<3x12xi8, strided<[12, 1], offset: ?>>
+  // CHECK-NEXT:   memref.copy [[valloc_0]], [[vsubview]] : memref<3x12xi8> to memref<3x12xi8, strided<[12, 1], offset: ?>>
+  // CHECK-NEXT: }
+  // CHECK-NEXT: memref.dealloc [[valloc_0]] : memref<3x12xi8>
+  // CHECK-NEXT: return
   mesh.update_halo %arg0 on @mesh0 split_axes = [[0]]
-    halo_sizes = [2, %c2] : memref<12x12xi8>
+    halo_sizes = [2, 3] : memref<12x12xi8>
+  return
+}
+
+// CHECK-LABEL: func @update_halo_1d_second
+func.func @update_halo_1d_second(
+  // CHECK-SAME: [[varg0:%.*]]: memref<12x12xi8>
+  %arg0 : memref<12x12xi8>) {
+  // CHECK-NEXT: [[vc9:%.*]] = arith.constant 9 : index
+  // CHECK-NEXT: [[vc91_i32:%.*]] = arith.constant 91 : i32
+  // CHECK-NEXT: [[vc0_i32:%.*]] = arith.constant 0 : i32
+  // CHECK-NEXT: [[vproc_linear_idx:%.*]]:3 = mesh.process_multi_index on @mesh0 : index, index, index
+  // CHECK-NEXT: [[vdown_linear_idx:%.*]], [[vup_linear_idx:%.*]] = mesh.neighbors_linear_indices on @mesh0[[[vproc_linear_idx]]#0, [[vproc_linear_idx]]#1, [[vproc_linear_idx]]#2] split_axes = [3] : index, index
+  // CHECK-NEXT: [[v0:%.*]] = arith.index_cast [[vdown_linear_idx]] : index to i32
+  // CHECK-NEXT: [[v1:%.*]] = arith.index_cast [[vup_linear_idx]] : index to i32
+  // CHECK-NEXT: [[v2:%.*]] = arith.cmpi sge, [[v1]], [[vc0_i32]] : i32
+  // CHECK-NEXT: [[v3:%.*]] = arith.cmpi sge, [[v0]], [[vc0_i32]] : i32
+  // CHECK-NEXT: [[valloc:%.*]] = memref.alloc() : memref<12x2xi8>
+  // CHECK-NEXT: scf.if [[v3]] {
+  // CHECK-NEXT:   [[vsubview:%.*]] = memref.subview [[varg0]][0, %c9] [12, 2] [1, 1] : memref<12x12xi8> to memref<12x2xi8, strided<[12, 1], offset: ?>>
+  // CHECK-NEXT:   memref.copy [[vsubview]], [[valloc]] : memref<12x2xi8, strided<[12, 1], offset: ?>> to memref<12x2xi8>
+  // CHECK-NEXT:   mpi.send([[valloc]], [[vc91_i32]], [[v0]]) : memref<12x2xi8>, i32, i32
+  // CHECK-NEXT: }
+  // CHECK-NEXT: scf.if [[v2]] {
+  // CHECK-NEXT:   mpi.recv([[valloc]], [[vc91_i32]], [[v1]]) : memref<12x2xi8>, i32, i32
+  // CHECK-NEXT:   [[vsubview:%.*]] = memref.subview [[varg0]][0, 0] [12, 2] [1, 1] : memref<12x12xi8> to memref<12x2xi8, strided<[12, 1]>>
+  // CHECK-NEXT:   memref.copy [[valloc]], [[vsubview]] : memref<12x2xi8> to memref<12x2xi8, strided<[12, 1]>>
+  // CHECK-NEXT: }
+  // CHECK-NEXT: memref.dealloc [[valloc]] : memref<12x2xi8>
+  // CHECK-NEXT: [[v4:%.*]] = arith.cmpi sge, [[v0]], [[vc0_i32]] : i32
+  // CHECK-NEXT: [[v5:%.*]] = arith.cmpi sge, [[v1]], [[vc0_i32]] : i32
+  // CHECK-NEXT: [[valloc_0:%.*]] = memref.alloc() : memref<12x3xi8>
+  // CHECK-NEXT: scf.if [[v5]] {
+  // CHECK-NEXT:   [[vsubview:%.*]] = memref.subview [[varg0]][0, 0] [12, 3] [1, 1] : memref<12x12xi8> to memref<12x3xi8, strided<[12, 1]>>
+  // CHECK-NEXT:   memref.copy [[vsubview]], [[valloc_0]] : memref<12x3xi8, strided<[12, 1]>> to memref<12x3xi8>
+  // CHECK-NEXT:   mpi.send([[valloc_0]], [[vc91_i32]], [[v1]]) : memref<12x3xi8>, i32, i32
+  // CHECK-NEXT: }
+  // CHECK-NEXT: scf.if [[v4]] {
+  // CHECK-NEXT:   mpi.recv([[valloc_0]], [[vc91_i32]], [[v0]]) : memref<12x3xi8>, i32, i32
+  // CHECK-NEXT:   [[vsubview:%.*]] = memref.subview [[varg0]][0, %c9] [12, 3] [1, 1] : memref<12x12xi8> to memref<12x3xi8, strided<[12, 1], offset: ?>>
+  // CHECK-NEXT:   memref.copy [[valloc_0]], [[vsubview]] : memref<12x3xi8> to memref<12x3xi8, strided<[12, 1], offset: ?>>
+  // CHECK-NEXT: }
+  // CHECK-NEXT: memref.dealloc [[valloc_0]] : memref<12x3xi8>
+  // CHECK-NEXT: return
+  mesh.update_halo %arg0 on @mesh0 split_axes = [[], [3]]
+    halo_sizes = [2, 3] : memref<12x12xi8>
   return
 }
 
+// CHECK-LABEL: func @update_halo_2d
 func.func @update_halo_2d(
-    // CHECK-SAME: %[[ARG:.*]]: memref<12x12xi8>
+    // CHECK-SAME: [[varg0:%.*]]: memref<12x12xi8>
     %arg0 : memref<12x12xi8>) {
-  %c2 = arith.constant 2 : i64
-  // CHECK-NEXT: mesh.update_halo %[[ARG]] on @mesh0
-  // CHECK-SAME: split_axes = {{\[\[}}0], [1]]
-  // CHECK-SAME: halo_sizes = [2, 2, %[[C2]], 2]
-  // CHECK-SAME: target_halo_sizes = [3, 3, 2, 2] : memref<12x12xi8>
+  // CHECK-NEXT: [[vc8:%.*]] = arith.constant 8 : index
+  // CHECK-NEXT: [[vc9:%.*]] = arith.constant 9 : index
+  // CHECK-NEXT: [[vc10:%.*]] = arith.constant 10 : index
+  // CHECK-NEXT: [[vc91_i32:%.*]] = arith.constant 91 : i32
+  // CHECK-NEXT: [[vc0_i32:%.*]] = arith.constant 0 : i32
+  // CHECK-NEXT: [[vproc_linear_idx:%.*]]:3 = mesh.process_multi_index on @mesh0 : index, index, index
+  // CHECK-NEXT: [[vdown_linear_idx:%.*]], [[vup_linear_idx:%.*]] = mesh.neighbors_linear_indices on @mesh0[[[vproc_linear_idx]]#0, [[vproc_linear_idx]]#1, [[vproc_linear_idx]]#2] split_axes = [0] : index, index
+  // CHECK-NEXT: [[v0:%.*]] = arith.index_cast [[vdown_linear_idx]] : index to i32
+  // CHECK-NEXT: [[v1:%.*]] = arith.index_cast [[vup_linear_idx]] : index to i32
+  // CHECK-NEXT: [[v2:%.*]] = arith.cmpi sge, [[v1]], [[vc0_i32]] : i32
+  // CHECK-NEXT: [[v3:%.*]] = arith.cmpi sge, [[v0]], [[vc0_i32]] : i32
+  // CHECK-NEXT: [[valloc:%.*]] = memref.alloc() : memref<1x12xi8>
+  // CHECK-NEXT: scf.if [[v3]] {
+  // CHECK-NEXT:   [[vsubview:%.*]] = memref.subview [[varg0]][[[vc10]], 0] [1, 12] [1, 1] : memref<12x12xi8> to memref<1x12xi8, strided<[12, 1], offset: ?>>
+  // CHECK-NEXT:   memref.copy [[vsubview]], [[valloc]] : memref<1x12xi8, strided<[12, 1], offset: ?>> to memref<1x12xi8>
+  // CHECK-NEXT:   mpi.send([[valloc]], [[vc91_i32]], [[v0]]) : memref<1x12xi8>, i32, i32
+  // CHECK-NEXT: }
+  // CHECK-NEXT: scf.if [[v2]] {
+  // CHECK-NEXT:   mpi.recv([[valloc]], [[vc91_i32]], [[v1]]) : memref<1x12xi8>, i32, i32
+  // CHECK-NEXT:   [[vsubview:%.*]] = memref.subview [[varg0]][0, 0] [1, 12] [1, 1] : memref<12x12xi8> to memref<1x12xi8, strided<[12, 1]>>
+  // CHECK-NEXT:   memref.copy [[valloc]], [[vsubview]] : memref<1x12xi8> to memref<1x12xi8, strided<[12, 1]>>
+  // CHECK-NEXT: }
+  // CHECK-NEXT: memref.dealloc [[valloc]] : memref<1x12xi8>
+  // CHECK-NEXT: [[v4:%.*]] = arith.cmpi sge, [[v0]], [[vc0_i32]] : i32
+  // CHECK-NEXT: [[v5:%.*]] = arith.cmpi sge, [[v1]], [[vc0_i32]] : i32
+  // CHECK-NEXT: [[valloc_0:%.*]] = memref.alloc() : memref<2x12xi8>
+  // CHECK-NEXT: scf.if [[v5]] {
+  // CHECK-NEXT:   [[vsubview:%.*]] = memref.subview [[varg0]][0, 0] [2, 12] [1, 1] : memref<12x12xi8> to memref<2x12xi8, strided<[12, 1]>>
+  // CHECK-NEXT:   memref.copy [[vsubview]], [[valloc_0]] : memref<2x12xi8, strided<[12, 1]>> to memref<2x12xi8>
+  // CHECK-NEXT:   mpi.send([[valloc_0]], [[vc91_i32]], [[v1]]) : memref<2x12xi8>, i32, i32
+  // CHECK-NEXT: }
+  // CHECK-NEXT: scf.if [[v4]] {
+  // CHECK-NEXT:   mpi.recv([[valloc_0]], [[vc91_i32]], [[v0]]) : memref<2x12xi8>, i32, i32
+  // CHECK-NEXT:   [[vsubview:%.*]] = memref.subview [[varg0]][[[vc10]], 0] [2, 12] [1, 1] : memref<12x12xi8> to memref<2x12xi8, strided<[12, 1], offset: ?>>
+  // CHECK-NEXT:   memref.copy [[valloc_0]], [[vsubview]] : memref<2x12xi8> to memref<2x12xi8, strided<[12, 1], offset: ?>>
+  // CHECK-NEXT: }
+  // CHECK-NEXT: memref.dealloc [[valloc_0]] : memref<2x12xi8>
+  // CHECK-NEXT: [[vdown_linear_idx_1:%.*]], [[vup_linear_idx_2:%.*]] = mesh.neighbors_linear_indices on @mesh0[[[vproc_linear_idx]]#0, [[vproc_linear_idx]]#1, [[vproc_linear_idx]]#2] split_axes = [1] : index, index
+  // CHECK-NEXT: [[v6:%.*]] = arith.index_cast [[vdown_linear_idx_1]] : index to i32
+  // CHECK-NEXT: [[v7:%.*]] = arith.index_cast [[vup_linear_idx_2]] : index to i32
+  // CHECK-NEXT: [[v8:%.*]] = arith.cmpi sge, [[v7]], [[vc0_i32]] : i32
+  // CHECK-NEXT: [[v9:%.*]] = arith.cmpi sge, [[v6]], [[vc0_i32]] : i32
+  // CHECK-NEXT: [[valloc_3:%.*]] = memref.alloc([[vc9]]) : memref<?x3xi8>
+  // CHECK-NEXT: scf.if [[v9]] {
+  // CHECK-NEXT:   [[vsubview:%.*]] = memref.subview [[varg0]][1, %c8] [[[vc9]], 3] [1, 1] : memref<12x12xi8> to memref<?x3xi8, strided<[12, 1], offset: ?>>
+  // CHECK-NEXT:   memref.copy [[vsubview]], [[valloc_3]] : memref<?x3xi8, strided<[12, 1], offset: ?>> to memref<?x3xi8>
+  // CHECK-NEXT:   mpi.send([[valloc_3]], [[vc91_i32]], [[v6]]) : memref<?x3xi8>, i32, i32
+  // CHECK-NEXT: }
+  // CHECK-NEXT: scf.if [[v8]] {
+  // CHECK-NEXT:   mpi.recv([[valloc_3]], [[vc91_i32]], [[v7]]) : memref<?x3xi8>, i32, i32
+  // CHECK-NEXT:   [[vsubview:%.*]] = memref.subview [[varg0]][1, 0] [[[vc9]], 3] [1, 1] : memref<12x12xi8> to memref<?x3xi8, strided<[12, 1], offset: 12>>
+  // CHECK-NEXT:   memref.copy [[valloc_3]], [[vsubview]] : memref<?x3xi8> to memref<?x3xi8, strided<[12, 1], offset: 12>>
+  // CHECK-NEXT: }
+  // CHECK-NEXT: memref.dealloc [[valloc_3]] : memref<?x3xi8>
+  // CHECK-NEXT: [[v10:%.*]] = arith.cmpi sge, [[v6]], [[vc0_i32]] : i32
+  // CHECK-NEXT: [[v11:%.*]] = arith.cmpi sge, [[v7]], [[vc0_i32]] : i32
+  // CHECK-NEXT: [[valloc_4:%.*]] = memref.alloc([[vc9]]) : memref<?x4xi8>
+  // CHECK-NEXT: scf.if [[v11]] {
+  // CHECK-NEXT:   [[vsubview:%.*]] = memref.subview [[varg0]][1, 0] [[[vc9]], 4] [1, 1] : memref<12x12xi8> to memref<?x4xi8, strided<[12, 1], offset: 12>>
+  // CHECK-NEXT:   memref.copy [[vsubview]], [[valloc_4]] : memref<?x4xi8, strided<[12, 1], offset: 12>> to memref<?x4xi8>
+  // CHECK-NEXT:   mpi.send([[valloc_4]], [[vc91_i32]], [[v7]]) : memref<?x4xi8>, i32, i32
+  // CHECK-NEXT: }
+  // CHECK-NEXT: scf.if [[v10]] {
+  // CHECK-NEXT:   mpi.recv([[valloc_4]], [[vc91_i32]], [[v6]]) : memref<?x4xi8>, i32, i32
+  // CHECK-NEXT:   [[vsubview:%.*]] = memref.subview [[varg0]][1, %c8] [[[vc9]], 4] [1, 1] : memref<12x12xi8> to memref<?x4xi8, strided<[12, 1], offset: ?>>
+  // CHECK-NEXT:   memref.copy [[valloc_4]], [[vsubview]] : memref<?x4xi8> to memref<?x4xi8, strided<[12, 1], offset: ?>>
+  // CHECK-NEXT: }
+  // CHECK-NEXT: memref.dealloc [[valloc_4]] : memref<?x4xi8>
+  // CHECK-NEXT: return
   mesh.update_halo %arg0 on @mesh0 split_axes = [[0], [1]]
-    halo_sizes = [2, 2, %c2, 2] target_halo_sizes = [3, 3, 2, 2]
-    : memref<12x12xi8>
+      halo_sizes = [1, 2, 3, 4]
+      : memref<12x12xi8>
   return
 }

>From aeee16c6e3454ab7c04168364d544ffda2e7f344 Mon Sep 17 00:00:00 2001
From: "Schlimbach, Frank" <frank.schlimbach at intel.com>
Date: Tue, 20 Aug 2024 19:23:13 +0200
Subject: [PATCH 03/15] fixed corner halos by reversing data-exchanges from
 high to low dims

---
 mlir/lib/Conversion/MeshToMPI/MeshToMPI.cpp   |  91 +++++-----
 .../MeshToMPI/convert-mesh-to-mpi.mlir        | 161 +++++++++---------
 2 files changed, 137 insertions(+), 115 deletions(-)

diff --git a/mlir/lib/Conversion/MeshToMPI/MeshToMPI.cpp b/mlir/lib/Conversion/MeshToMPI/MeshToMPI.cpp
index 42d885a109ee79..9cf9458ce2b687 100644
--- a/mlir/lib/Conversion/MeshToMPI/MeshToMPI.cpp
+++ b/mlir/lib/Conversion/MeshToMPI/MeshToMPI.cpp
@@ -70,6 +70,7 @@ struct ConvertUpdateHaloOp
 
     auto array = op.getInput();
     auto rank = array.getType().getRank();
+    auto opSplitAxes = op.getSplitAxes().getAxes();
     auto mesh = op.getMesh();
     auto meshOp = getMesh(op, symbolTableCollection);
     auto haloSizes = getMixedValues(op.getStaticHaloSizes(),
@@ -87,32 +88,54 @@ struct ConvertUpdateHaloOp
     // most of the offset/size/stride data is the same for all dims
     SmallVector<OpFoldResult> offsets(rank, rewriter.getIndexAttr(0));
     SmallVector<OpFoldResult> strides(rank, rewriter.getIndexAttr(1));
-    SmallVector<OpFoldResult> shape(rank);
+    SmallVector<OpFoldResult> shape(rank), dimSizes(rank);
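+    // `shape` holds the full sizes of each dim; `dimSizes` the sizes actually
+    // used for the exchanged subviews/buffers (for split dims: without halos).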
+    auto currHaloDim = -1; // halo sizes are provided for split dimensions only
     // we need the actual shape to compute offsets and sizes
-    for (auto [i, s] : llvm::enumerate(array.getType().getShape())) {
+    for (auto i = 0; i < rank; ++i) {
+      auto s = array.getType().getShape()[i];
       if (ShapedType::isDynamic(s)) {
         shape[i] = rewriter.create<memref::DimOp>(loc, array, s).getResult();
       } else {
         shape[i] = rewriter.getIndexAttr(s);
       }
+
+      if ((size_t)i < opSplitAxes.size() && !opSplitAxes[i].empty()) {
+        ++currHaloDim;
+        // the offsets for lower dims start after their down halo
+        offsets[i] = haloSizes[currHaloDim * 2];
+
+        // prepare shape and offsets of highest dim's halo exchange
+        auto _haloSz =
+            rewriter
+                .create<arith::AddIOp>(loc, toValue(haloSizes[currHaloDim * 2]),
+                                       toValue(haloSizes[currHaloDim * 2 + 1]))
+                .getResult();
+        // the halo shape used for lower dims excludes this dim's halos
+        dimSizes[i] =
+            rewriter.create<arith::SubIOp>(loc, toValue(shape[i]), _haloSz)
+                .getResult();
+      } else {
+        dimSizes[i] = shape[i];
+      }
     }
 
     auto tagAttr = rewriter.getI32IntegerAttr(91); // we just pick something
     auto tag = rewriter.create<::mlir::arith::ConstantOp>(loc, tagAttr);
     auto zeroAttr = rewriter.getI32IntegerAttr(0); // for detecting v<0
     auto zero = rewriter.create<::mlir::arith::ConstantOp>(loc, zeroAttr);
+
     SmallVector<Type> indexResultTypes(meshOp.getShape().size(),
                                        rewriter.getIndexType());
     auto myMultiIndex =
         rewriter.create<ProcessMultiIndexOp>(loc, indexResultTypes, mesh)
             .getResult();
-    // halo sizes are provided for split dimensions only
-    auto currHaloDim = 0;
-
-    for (auto [dim, splitAxes] : llvm::enumerate(op.getSplitAxes())) {
+    // traverse all split axes from high to low dim
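+    // (going from high to low is what makes the corner halos work: the
+    // lower dims' exchanges cover the full extent of the higher dims,
+    // including their already-updated halos)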
+    for (ssize_t dim = opSplitAxes.size() - 1; dim >= 0; --dim) {
+      auto splitAxes = opSplitAxes[dim];
       if (splitAxes.empty()) {
         continue;
       }
+      assert(currHaloDim >= 0 && (size_t)currHaloDim < haloSizes.size() / 2);
       // Get the linearized ids of the neighbors (down and up) for the
       // given split
       auto tmp = rewriter
@@ -124,11 +147,13 @@ struct ConvertUpdateHaloOp
                                    loc, rewriter.getI32Type(), tmp[0]),
                                rewriter.create<arith::IndexCastOp>(
                                    loc, rewriter.getI32Type(), tmp[1])};
-      // store for later
-      auto orgDimSize = shape[dim];
-      // this dim's offset to the start of the upper halo
-      auto upperOffset = rewriter.create<arith::SubIOp>(
+
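+      // Offsets of the halo regions (recv) and of the adjacent owned
+      // regions (send) at the lower and upper end of this dim.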
+      auto lowerRecvOffset = rewriter.getIndexAttr(0);
+      auto lowerSendOffset = toValue(haloSizes[currHaloDim * 2]);
+      auto upperRecvOffset = rewriter.create<arith::SubIOp>(
           loc, toValue(shape[dim]), toValue(haloSizes[currHaloDim * 2 + 1]));
+      auto upperSendOffset = rewriter.create<arith::SubIOp>(
+          loc, upperRecvOffset, toValue(haloSizes[currHaloDim * 2]));
 
       // Make sure we send/recv in a way that does not lead to a dead-lock.
       // The current approach is by far not optimal, this should be at least
@@ -136,10 +161,10 @@ struct ConvertUpdateHaloOp
       // Also, buffers should be re-used.
       // Still using temporary contiguous buffers for MPI communication...
       // Still yielding a "serialized" communication pattern...
-      auto genSendRecv = [&](auto dim, bool upperHalo) {
+      auto genSendRecv = [&](bool upperHalo) {
         auto orgOffset = offsets[dim];
-        shape[dim] = upperHalo ? haloSizes[currHaloDim * 2 + 1]
-                               : haloSizes[currHaloDim * 2];
+        dimSizes[dim] = upperHalo ? haloSizes[currHaloDim * 2 + 1]
+                                  : haloSizes[currHaloDim * 2];
         // Check if we need to send and/or receive
         // Processes on the mesh borders have only one neighbor
         auto to = upperHalo ? neighbourIDs[1] : neighbourIDs[0];
@@ -149,14 +174,14 @@ struct ConvertUpdateHaloOp
         auto hasTo = rewriter.create<arith::CmpIOp>(
             loc, arith::CmpIPredicate::sge, to, zero);
         auto buffer = rewriter.create<memref::AllocOp>(
-            loc, shape, array.getType().getElementType());
+            loc, dimSizes, array.getType().getElementType());
         // if has neighbor: copy halo data from array to buffer and send
         rewriter.create<scf::IfOp>(
             loc, hasTo, [&](OpBuilder &builder, Location loc) {
-              offsets[dim] = upperHalo ? OpFoldResult(builder.getIndexAttr(0))
-                                       : OpFoldResult(upperOffset);
+              offsets[dim] = upperHalo ? OpFoldResult(lowerSendOffset)
+                                       : OpFoldResult(upperSendOffset);
               auto subview = builder.create<memref::SubViewOp>(
-                  loc, array, offsets, shape, strides);
+                  loc, array, offsets, dimSizes, strides);
               builder.create<memref::CopyOp>(loc, subview, buffer);
               builder.create<mpi::SendOp>(loc, TypeRange{}, buffer, tag, to);
               builder.create<scf::YieldOp>(loc);
@@ -164,11 +189,11 @@ struct ConvertUpdateHaloOp
         // if has neighbor: receive halo data into buffer and copy to array
         rewriter.create<scf::IfOp>(
             loc, hasFrom, [&](OpBuilder &builder, Location loc) {
-              offsets[dim] = upperHalo ? OpFoldResult(upperOffset)
-                                       : OpFoldResult(builder.getIndexAttr(0));
+              offsets[dim] = upperHalo ? OpFoldResult(upperRecvOffset)
+                                       : OpFoldResult(lowerRecvOffset);
               builder.create<mpi::RecvOp>(loc, TypeRange{}, buffer, tag, from);
               auto subview = builder.create<memref::SubViewOp>(
-                  loc, array, offsets, shape, strides);
+                  loc, array, offsets, dimSizes, strides);
               builder.create<memref::CopyOp>(loc, buffer, subview);
               builder.create<scf::YieldOp>(loc);
             });
@@ -176,25 +201,15 @@ struct ConvertUpdateHaloOp
         offsets[dim] = orgOffset;
       };
 
-      genSendRecv(dim, false);
-      genSendRecv(dim, true);
-
-      // prepare shape and offsets for next split dim
-      auto _haloSz =
-          rewriter
-              .create<arith::AddIOp>(loc, toValue(haloSizes[currHaloDim * 2]),
-                                     toValue(haloSizes[currHaloDim * 2 + 1]))
-              .getResult();
-      // the shape for next halo excludes the halo on both ends for the
-      // current dim
-      shape[dim] =
-          rewriter.create<arith::SubIOp>(loc, toValue(orgDimSize), _haloSz)
-              .getResult();
-      // the offsets for next halo starts after the down halo for the
-      // current dim
-      offsets[dim] = haloSizes[currHaloDim * 2];
+      genSendRecv(false);
+      genSendRecv(true);
+
+      // the shape for lower dims includes higher dims' halos
+      dimSizes[dim] = shape[dim];
+      // -> the offset for higher dims is always 0
+      offsets[dim] = rewriter.getIndexAttr(0);
       // on to next halo
-      ++currHaloDim;
+      --currHaloDim;
     }
     rewriter.eraseOp(op);
     return mlir::success();
diff --git a/mlir/test/Conversion/MeshToMPI/convert-mesh-to-mpi.mlir b/mlir/test/Conversion/MeshToMPI/convert-mesh-to-mpi.mlir
index 5f563364272d96..c3b0dc12e6d746 100644
--- a/mlir/test/Conversion/MeshToMPI/convert-mesh-to-mpi.mlir
+++ b/mlir/test/Conversion/MeshToMPI/convert-mesh-to-mpi.mlir
@@ -6,8 +6,10 @@ mesh.mesh @mesh0(shape = 2x2x4)
 // CHECK-LABEL: func @update_halo_1d_first
 func.func @update_halo_1d_first(
   // CHECK-SAME: [[varg0:%.*]]: memref<12x12xi8>
-    %arg0 : memref<12x12xi8>) {
+  %arg0 : memref<12x12xi8>) {
+  // CHECK-NEXT: [[vc7:%.*]] = arith.constant 7 : index
   // CHECK-NEXT: [[vc9:%.*]] = arith.constant 9 : index
+  // CHECK-NEXT: [[vc2:%.*]] = arith.constant 2 : index
   // CHECK-NEXT: [[vc91_i32:%.*]] = arith.constant 91 : i32
   // CHECK-NEXT: [[vc0_i32:%.*]] = arith.constant 0 : i32
   // CHECK-NEXT: [[vproc_linear_idx:%.*]]:3 = mesh.process_multi_index on @mesh0 : index, index, index
@@ -18,7 +20,7 @@ func.func @update_halo_1d_first(
   // CHECK-NEXT: [[v3:%.*]] = arith.cmpi sge, [[v0]], [[vc0_i32]] : i32
   // CHECK-NEXT: [[valloc:%.*]] = memref.alloc() : memref<2x12xi8>
   // CHECK-NEXT: scf.if [[v3]] {
-  // CHECK-NEXT:   [[vsubview:%.*]] = memref.subview [[varg0]][[[vc9]], 0] [2, 12] [1, 1] : memref<12x12xi8> to memref<2x12xi8, strided<[12, 1], offset: ?>>
+  // CHECK-NEXT:   [[vsubview:%.*]] = memref.subview [[varg0]][[[vc7]], 0] [2, 12] [1, 1] : memref<12x12xi8> to memref<2x12xi8, strided<[12, 1], offset: ?>>
   // CHECK-NEXT:   memref.copy [[vsubview]], [[valloc]] : memref<2x12xi8, strided<[12, 1], offset: ?>> to memref<2x12xi8>
   // CHECK-NEXT:   mpi.send([[valloc]], [[vc91_i32]], [[v0]]) : memref<2x12xi8>, i32, i32
   // CHECK-NEXT: }
@@ -32,8 +34,8 @@ func.func @update_halo_1d_first(
   // CHECK-NEXT: [[v5:%.*]] = arith.cmpi sge, [[v1]], [[vc0_i32]] : i32
   // CHECK-NEXT: [[valloc_0:%.*]] = memref.alloc() : memref<3x12xi8>
   // CHECK-NEXT: scf.if [[v5]] {
-  // CHECK-NEXT:   [[vsubview:%.*]] = memref.subview [[varg0]][0, 0] [3, 12] [1, 1] : memref<12x12xi8> to memref<3x12xi8, strided<[12, 1]>>
-  // CHECK-NEXT:   memref.copy [[vsubview]], [[valloc_0]] : memref<3x12xi8, strided<[12, 1]>> to memref<3x12xi8>
+  // CHECK-NEXT:   [[vsubview:%.*]] = memref.subview [[varg0]][[[vc2]], 0] [3, 12] [1, 1] : memref<12x12xi8> to memref<3x12xi8, strided<[12, 1], offset: ?>>
+  // CHECK-NEXT:   memref.copy [[vsubview]], [[valloc_0]] : memref<3x12xi8, strided<[12, 1], offset: ?>> to memref<3x12xi8>
   // CHECK-NEXT:   mpi.send([[valloc_0]], [[vc91_i32]], [[v1]]) : memref<3x12xi8>, i32, i32
   // CHECK-NEXT: }
   // CHECK-NEXT: scf.if [[v4]] {
@@ -42,9 +44,9 @@ func.func @update_halo_1d_first(
   // CHECK-NEXT:   memref.copy [[valloc_0]], [[vsubview]] : memref<3x12xi8> to memref<3x12xi8, strided<[12, 1], offset: ?>>
   // CHECK-NEXT: }
   // CHECK-NEXT: memref.dealloc [[valloc_0]] : memref<3x12xi8>
-  // CHECK-NEXT: return
   mesh.update_halo %arg0 on @mesh0 split_axes = [[0]]
     halo_sizes = [2, 3] : memref<12x12xi8>
+  // CHECK-NEXT: return
   return
 }
 
@@ -52,44 +54,46 @@ func.func @update_halo_1d_first(
 func.func @update_halo_1d_second(
   // CHECK-SAME: [[varg0:%.*]]: memref<12x12xi8>
   %arg0 : memref<12x12xi8>) {
-  // CHECK-NEXT: [[vc9:%.*]] = arith.constant 9 : index
-  // CHECK-NEXT: [[vc91_i32:%.*]] = arith.constant 91 : i32
-  // CHECK-NEXT: [[vc0_i32:%.*]] = arith.constant 0 : i32
-  // CHECK-NEXT: [[vproc_linear_idx:%.*]]:3 = mesh.process_multi_index on @mesh0 : index, index, index
-  // CHECK-NEXT: [[vdown_linear_idx:%.*]], [[vup_linear_idx:%.*]] = mesh.neighbors_linear_indices on @mesh0[[[vproc_linear_idx]]#0, [[vproc_linear_idx]]#1, [[vproc_linear_idx]]#2] split_axes = [3] : index, index
-  // CHECK-NEXT: [[v0:%.*]] = arith.index_cast [[vdown_linear_idx]] : index to i32
-  // CHECK-NEXT: [[v1:%.*]] = arith.index_cast [[vup_linear_idx]] : index to i32
-  // CHECK-NEXT: [[v2:%.*]] = arith.cmpi sge, [[v1]], [[vc0_i32]] : i32
-  // CHECK-NEXT: [[v3:%.*]] = arith.cmpi sge, [[v0]], [[vc0_i32]] : i32
-  // CHECK-NEXT: [[valloc:%.*]] = memref.alloc() : memref<12x2xi8>
-  // CHECK-NEXT: scf.if [[v3]] {
-  // CHECK-NEXT:   [[vsubview:%.*]] = memref.subview [[varg0]][0, %c9] [12, 2] [1, 1] : memref<12x12xi8> to memref<12x2xi8, strided<[12, 1], offset: ?>>
-  // CHECK-NEXT:   memref.copy [[vsubview]], [[valloc]] : memref<12x2xi8, strided<[12, 1], offset: ?>> to memref<12x2xi8>
-  // CHECK-NEXT:   mpi.send([[valloc]], [[vc91_i32]], [[v0]]) : memref<12x2xi8>, i32, i32
-  // CHECK-NEXT: }
-  // CHECK-NEXT: scf.if [[v2]] {
-  // CHECK-NEXT:   mpi.recv([[valloc]], [[vc91_i32]], [[v1]]) : memref<12x2xi8>, i32, i32
-  // CHECK-NEXT:   [[vsubview:%.*]] = memref.subview [[varg0]][0, 0] [12, 2] [1, 1] : memref<12x12xi8> to memref<12x2xi8, strided<[12, 1]>>
-  // CHECK-NEXT:   memref.copy [[valloc]], [[vsubview]] : memref<12x2xi8> to memref<12x2xi8, strided<[12, 1]>>
-  // CHECK-NEXT: }
-  // CHECK-NEXT: memref.dealloc [[valloc]] : memref<12x2xi8>
-  // CHECK-NEXT: [[v4:%.*]] = arith.cmpi sge, [[v0]], [[vc0_i32]] : i32
-  // CHECK-NEXT: [[v5:%.*]] = arith.cmpi sge, [[v1]], [[vc0_i32]] : i32
-  // CHECK-NEXT: [[valloc_0:%.*]] = memref.alloc() : memref<12x3xi8>
-  // CHECK-NEXT: scf.if [[v5]] {
-  // CHECK-NEXT:   [[vsubview:%.*]] = memref.subview [[varg0]][0, 0] [12, 3] [1, 1] : memref<12x12xi8> to memref<12x3xi8, strided<[12, 1]>>
-  // CHECK-NEXT:   memref.copy [[vsubview]], [[valloc_0]] : memref<12x3xi8, strided<[12, 1]>> to memref<12x3xi8>
-  // CHECK-NEXT:   mpi.send([[valloc_0]], [[vc91_i32]], [[v1]]) : memref<12x3xi8>, i32, i32
-  // CHECK-NEXT: }
-  // CHECK-NEXT: scf.if [[v4]] {
-  // CHECK-NEXT:   mpi.recv([[valloc_0]], [[vc91_i32]], [[v0]]) : memref<12x3xi8>, i32, i32
-  // CHECK-NEXT:   [[vsubview:%.*]] = memref.subview [[varg0]][0, %c9] [12, 3] [1, 1] : memref<12x12xi8> to memref<12x3xi8, strided<[12, 1], offset: ?>>
-  // CHECK-NEXT:   memref.copy [[valloc_0]], [[vsubview]] : memref<12x3xi8> to memref<12x3xi8, strided<[12, 1], offset: ?>>
-  // CHECK-NEXT: }
-  // CHECK-NEXT: memref.dealloc [[valloc_0]] : memref<12x3xi8>
-  // CHECK-NEXT: return
+  //CHECK-NEXT: [[vc7:%.*]] = arith.constant 7 : index
+  //CHECK-NEXT: [[vc9:%.*]] = arith.constant 9 : index
+  //CHECK-NEXT: [[vc2:%.*]] = arith.constant 2 : index
+  //CHECK-NEXT: [[vc91_i32:%.*]] = arith.constant 91 : i32
+  //CHECK-NEXT: [[vc0_i32:%.*]] = arith.constant 0 : i32
+  //CHECK-NEXT: [[vproc_linear_idx:%.*]]:3 = mesh.process_multi_index on @mesh0 : index, index, index
+  //CHECK-NEXT: [[vdown_linear_idx:%.*]], [[vup_linear_idx:%.*]] = mesh.neighbors_linear_indices on @mesh0[[[vproc_linear_idx]]#0, [[vproc_linear_idx]]#1, [[vproc_linear_idx]]#2] split_axes = [3] : index, index
+  //CHECK-NEXT: [[v0:%.*]] = arith.index_cast [[vdown_linear_idx]] : index to i32
+  //CHECK-NEXT: [[v1:%.*]] = arith.index_cast [[vup_linear_idx]] : index to i32
+  //CHECK-NEXT: [[v2:%.*]] = arith.cmpi sge, [[v1]], [[vc0_i32]] : i32
+  //CHECK-NEXT: [[v3:%.*]] = arith.cmpi sge, [[v0]], [[vc0_i32]] : i32
+  //CHECK-NEXT: [[valloc:%.*]] = memref.alloc() : memref<12x2xi8>
+  //CHECK-NEXT: scf.if [[v3]] {
+  //CHECK-NEXT:   [[vsubview:%.*]] = memref.subview [[varg0]][0, %c7] [12, 2] [1, 1] : memref<12x12xi8> to memref<12x2xi8, strided<[12, 1], offset: ?>>
+  //CHECK-NEXT:   memref.copy [[vsubview]], [[valloc]] : memref<12x2xi8, strided<[12, 1], offset: ?>> to memref<12x2xi8>
+  //CHECK-NEXT:   mpi.send([[valloc]], [[vc91_i32]], [[v0]]) : memref<12x2xi8>, i32, i32
+  //CHECK-NEXT: }
+  //CHECK-NEXT: scf.if [[v2]] {
+  //CHECK-NEXT:   mpi.recv([[valloc]], [[vc91_i32]], [[v1]]) : memref<12x2xi8>, i32, i32
+  //CHECK-NEXT:   [[vsubview:%.*]] = memref.subview [[varg0]][0, 0] [12, 2] [1, 1] : memref<12x12xi8> to memref<12x2xi8, strided<[12, 1]>>
+  //CHECK-NEXT:   memref.copy [[valloc]], [[vsubview]] : memref<12x2xi8> to memref<12x2xi8, strided<[12, 1]>>
+  //CHECK-NEXT: }
+  //CHECK-NEXT: memref.dealloc [[valloc]] : memref<12x2xi8>
+  //CHECK-NEXT: [[v4:%.*]] = arith.cmpi sge, [[v0]], [[vc0_i32]] : i32
+  //CHECK-NEXT: [[v5:%.*]] = arith.cmpi sge, [[v1]], [[vc0_i32]] : i32
+  //CHECK-NEXT: [[valloc_0:%.*]] = memref.alloc() : memref<12x3xi8>
+  //CHECK-NEXT: scf.if [[v5]] {
+  //CHECK-NEXT:   [[vsubview:%.*]] = memref.subview [[varg0]][0, %c2] [12, 3] [1, 1] : memref<12x12xi8> to memref<12x3xi8, strided<[12, 1], offset: ?>>
+  //CHECK-NEXT:   memref.copy [[vsubview]], [[valloc_0]] : memref<12x3xi8, strided<[12, 1], offset: ?>> to memref<12x3xi8>
+  //CHECK-NEXT:   mpi.send([[valloc_0]], [[vc91_i32]], [[v1]]) : memref<12x3xi8>, i32, i32
+  //CHECK-NEXT: }
+  //CHECK-NEXT: scf.if [[v4]] {
+  //CHECK-NEXT:   mpi.recv([[valloc_0]], [[vc91_i32]], [[v0]]) : memref<12x3xi8>, i32, i32
+  //CHECK-NEXT:   [[vsubview:%.*]] = memref.subview [[varg0]][0, %c9] [12, 3] [1, 1] : memref<12x12xi8> to memref<12x3xi8, strided<[12, 1], offset: ?>>
+  //CHECK-NEXT:   memref.copy [[valloc_0]], [[vsubview]] : memref<12x3xi8> to memref<12x3xi8, strided<[12, 1], offset: ?>>
+  //CHECK-NEXT: }
+  //CHECK-NEXT: memref.dealloc [[valloc_0]] : memref<12x3xi8>
   mesh.update_halo %arg0 on @mesh0 split_axes = [[], [3]]
     halo_sizes = [2, 3] : memref<12x12xi8>
+  //CHECK-NEXT: return
   return
 }
 
@@ -97,77 +101,80 @@ func.func @update_halo_1d_second(
 func.func @update_halo_2d(
     // CHECK-SAME: [[varg0:%.*]]: memref<12x12xi8>
     %arg0 : memref<12x12xi8>) {
+  // CHECK-NEXT: [[vc10:%.*]] = arith.constant 10 : index
+  // CHECK-NEXT: [[vc1:%.*]] = arith.constant 1 : index
+  // CHECK-NEXT: [[vc5:%.*]] = arith.constant 5 : index
   // CHECK-NEXT: [[vc8:%.*]] = arith.constant 8 : index
+  // CHECK-NEXT: [[vc3:%.*]] = arith.constant 3 : index
   // CHECK-NEXT: [[vc9:%.*]] = arith.constant 9 : index
-  // CHECK-NEXT: [[vc10:%.*]] = arith.constant 10 : index
   // CHECK-NEXT: [[vc91_i32:%.*]] = arith.constant 91 : i32
   // CHECK-NEXT: [[vc0_i32:%.*]] = arith.constant 0 : i32
   // CHECK-NEXT: [[vproc_linear_idx:%.*]]:3 = mesh.process_multi_index on @mesh0 : index, index, index
-  // CHECK-NEXT: [[vdown_linear_idx:%.*]], [[vup_linear_idx:%.*]] = mesh.neighbors_linear_indices on @mesh0[[[vproc_linear_idx]]#0, [[vproc_linear_idx]]#1, [[vproc_linear_idx]]#2] split_axes = [0] : index, index
+  // CHECK-NEXT: [[vdown_linear_idx:%.*]], [[vup_linear_idx:%.*]] = mesh.neighbors_linear_indices on @mesh0[[[vproc_linear_idx]]#0, [[vproc_linear_idx]]#1, [[vproc_linear_idx]]#2] split_axes = [1] : index, index
   // CHECK-NEXT: [[v0:%.*]] = arith.index_cast [[vdown_linear_idx]] : index to i32
   // CHECK-NEXT: [[v1:%.*]] = arith.index_cast [[vup_linear_idx]] : index to i32
   // CHECK-NEXT: [[v2:%.*]] = arith.cmpi sge, [[v1]], [[vc0_i32]] : i32
   // CHECK-NEXT: [[v3:%.*]] = arith.cmpi sge, [[v0]], [[vc0_i32]] : i32
-  // CHECK-NEXT: [[valloc:%.*]] = memref.alloc() : memref<1x12xi8>
+  // CHECK-NEXT: [[valloc:%.*]] = memref.alloc([[vc9]]) : memref<?x3xi8>
   // CHECK-NEXT: scf.if [[v3]] {
-  // CHECK-NEXT:   [[vsubview:%.*]] = memref.subview [[varg0]][[[vc10]], 0] [1, 12] [1, 1] : memref<12x12xi8> to memref<1x12xi8, strided<[12, 1], offset: ?>>
-  // CHECK-NEXT:   memref.copy [[vsubview]], [[valloc]] : memref<1x12xi8, strided<[12, 1], offset: ?>> to memref<1x12xi8>
-  // CHECK-NEXT:   mpi.send([[valloc]], [[vc91_i32]], [[v0]]) : memref<1x12xi8>, i32, i32
+  // CHECK-NEXT:   [[vsubview:%.*]] = memref.subview [[varg0]][1, %c5] [[[vc9]], 3] [1, 1] : memref<12x12xi8> to memref<?x3xi8, strided<[12, 1], offset: ?>>
+  // CHECK-NEXT:   memref.copy [[vsubview]], [[valloc]] : memref<?x3xi8, strided<[12, 1], offset: ?>> to memref<?x3xi8>
+  // CHECK-NEXT:   mpi.send([[valloc]], [[vc91_i32]], [[v0]]) : memref<?x3xi8>, i32, i32
   // CHECK-NEXT: }
   // CHECK-NEXT: scf.if [[v2]] {
-  // CHECK-NEXT:   mpi.recv([[valloc]], [[vc91_i32]], [[v1]]) : memref<1x12xi8>, i32, i32
-  // CHECK-NEXT:   [[vsubview:%.*]] = memref.subview [[varg0]][0, 0] [1, 12] [1, 1] : memref<12x12xi8> to memref<1x12xi8, strided<[12, 1]>>
-  // CHECK-NEXT:   memref.copy [[valloc]], [[vsubview]] : memref<1x12xi8> to memref<1x12xi8, strided<[12, 1]>>
+  // CHECK-NEXT:   mpi.recv([[valloc]], [[vc91_i32]], [[v1]]) : memref<?x3xi8>, i32, i32
+  // CHECK-NEXT:   [[vsubview:%.*]] = memref.subview [[varg0]][1, 0] [[[vc9]], 3] [1, 1] : memref<12x12xi8> to memref<?x3xi8, strided<[12, 1], offset: 12>>
+  // CHECK-NEXT:   memref.copy [[valloc]], [[vsubview]] : memref<?x3xi8> to memref<?x3xi8, strided<[12, 1], offset: 12>>
   // CHECK-NEXT: }
-  // CHECK-NEXT: memref.dealloc [[valloc]] : memref<1x12xi8>
+  // CHECK-NEXT: memref.dealloc [[valloc]] : memref<?x3xi8>
   // CHECK-NEXT: [[v4:%.*]] = arith.cmpi sge, [[v0]], [[vc0_i32]] : i32
   // CHECK-NEXT: [[v5:%.*]] = arith.cmpi sge, [[v1]], [[vc0_i32]] : i32
-  // CHECK-NEXT: [[valloc_0:%.*]] = memref.alloc() : memref<2x12xi8>
+  // CHECK-NEXT: [[valloc_0:%.*]] = memref.alloc([[vc9]]) : memref<?x4xi8>
   // CHECK-NEXT: scf.if [[v5]] {
-  // CHECK-NEXT:   [[vsubview:%.*]] = memref.subview [[varg0]][0, 0] [2, 12] [1, 1] : memref<12x12xi8> to memref<2x12xi8, strided<[12, 1]>>
-  // CHECK-NEXT:   memref.copy [[vsubview]], [[valloc_0]] : memref<2x12xi8, strided<[12, 1]>> to memref<2x12xi8>
-  // CHECK-NEXT:   mpi.send([[valloc_0]], [[vc91_i32]], [[v1]]) : memref<2x12xi8>, i32, i32
+  // CHECK-NEXT:   [[vsubview:%.*]] = memref.subview [[varg0]][1, %c3] [[[vc9]], 4] [1, 1] : memref<12x12xi8> to memref<?x4xi8, strided<[12, 1], offset: ?>>
+  // CHECK-NEXT:   memref.copy [[vsubview]], [[valloc_0]] : memref<?x4xi8, strided<[12, 1], offset: ?>> to memref<?x4xi8>
+  // CHECK-NEXT:   mpi.send([[valloc_0]], [[vc91_i32]], [[v1]]) : memref<?x4xi8>, i32, i32
   // CHECK-NEXT: }
   // CHECK-NEXT: scf.if [[v4]] {
-  // CHECK-NEXT:   mpi.recv([[valloc_0]], [[vc91_i32]], [[v0]]) : memref<2x12xi8>, i32, i32
-  // CHECK-NEXT:   [[vsubview:%.*]] = memref.subview [[varg0]][[[vc10]], 0] [2, 12] [1, 1] : memref<12x12xi8> to memref<2x12xi8, strided<[12, 1], offset: ?>>
-  // CHECK-NEXT:   memref.copy [[valloc_0]], [[vsubview]] : memref<2x12xi8> to memref<2x12xi8, strided<[12, 1], offset: ?>>
+  // CHECK-NEXT:   mpi.recv([[valloc_0]], [[vc91_i32]], [[v0]]) : memref<?x4xi8>, i32, i32
+  // CHECK-NEXT:   [[vsubview:%.*]] = memref.subview [[varg0]][1, %c8] [[[vc9]], 4] [1, 1] : memref<12x12xi8> to memref<?x4xi8, strided<[12, 1], offset: ?>>
+  // CHECK-NEXT:   memref.copy [[valloc_0]], [[vsubview]] : memref<?x4xi8> to memref<?x4xi8, strided<[12, 1], offset: ?>>
   // CHECK-NEXT: }
-  // CHECK-NEXT: memref.dealloc [[valloc_0]] : memref<2x12xi8>
-  // CHECK-NEXT: [[vdown_linear_idx_1:%.*]], [[vup_linear_idx_2:%.*]] = mesh.neighbors_linear_indices on @mesh0[[[vproc_linear_idx]]#0, [[vproc_linear_idx]]#1, [[vproc_linear_idx]]#2] split_axes = [1] : index, index
+  // CHECK-NEXT: memref.dealloc [[valloc_0]] : memref<?x4xi8>
+  // CHECK-NEXT: [[vdown_linear_idx_1:%.*]], [[vup_linear_idx_2:%.*]] = mesh.neighbors_linear_indices on @mesh0[[[vproc_linear_idx]]#0, [[vproc_linear_idx]]#1, [[vproc_linear_idx]]#2] split_axes = [0] : index, index
   // CHECK-NEXT: [[v6:%.*]] = arith.index_cast [[vdown_linear_idx_1]] : index to i32
   // CHECK-NEXT: [[v7:%.*]] = arith.index_cast [[vup_linear_idx_2]] : index to i32
   // CHECK-NEXT: [[v8:%.*]] = arith.cmpi sge, [[v7]], [[vc0_i32]] : i32
   // CHECK-NEXT: [[v9:%.*]] = arith.cmpi sge, [[v6]], [[vc0_i32]] : i32
-  // CHECK-NEXT: [[valloc_3:%.*]] = memref.alloc([[vc9]]) : memref<?x3xi8>
+  // CHECK-NEXT: [[valloc_3:%.*]] = memref.alloc() : memref<1x12xi8>
   // CHECK-NEXT: scf.if [[v9]] {
-  // CHECK-NEXT:   [[vsubview:%.*]] = memref.subview [[varg0]][1, %c8] [[[vc9]], 3] [1, 1] : memref<12x12xi8> to memref<?x3xi8, strided<[12, 1], offset: ?>>
-  // CHECK-NEXT:   memref.copy [[vsubview]], [[valloc_3]] : memref<?x3xi8, strided<[12, 1], offset: ?>> to memref<?x3xi8>
-  // CHECK-NEXT:   mpi.send([[valloc_3]], [[vc91_i32]], [[v6]]) : memref<?x3xi8>, i32, i32
+  // CHECK-NEXT:   [[vsubview:%.*]] = memref.subview [[varg0]][[[vc9]], 0] [1, 12] [1, 1] : memref<12x12xi8> to memref<1x12xi8, strided<[12, 1], offset: ?>>
+  // CHECK-NEXT:   memref.copy [[vsubview]], [[valloc_3]] : memref<1x12xi8, strided<[12, 1], offset: ?>> to memref<1x12xi8>
+  // CHECK-NEXT:   mpi.send([[valloc_3]], [[vc91_i32]], [[v6]]) : memref<1x12xi8>, i32, i32
   // CHECK-NEXT: }
   // CHECK-NEXT: scf.if [[v8]] {
-  // CHECK-NEXT:   mpi.recv([[valloc_3]], [[vc91_i32]], [[v7]]) : memref<?x3xi8>, i32, i32
-  // CHECK-NEXT:   [[vsubview:%.*]] = memref.subview [[varg0]][1, 0] [[[vc9]], 3] [1, 1] : memref<12x12xi8> to memref<?x3xi8, strided<[12, 1], offset: 12>>
-  // CHECK-NEXT:   memref.copy [[valloc_3]], [[vsubview]] : memref<?x3xi8> to memref<?x3xi8, strided<[12, 1], offset: 12>>
+  // CHECK-NEXT:   mpi.recv([[valloc_3]], [[vc91_i32]], [[v7]]) : memref<1x12xi8>, i32, i32
+  // CHECK-NEXT:   [[vsubview:%.*]] = memref.subview [[varg0]][0, 0] [1, 12] [1, 1] : memref<12x12xi8> to memref<1x12xi8, strided<[12, 1]>>
+  // CHECK-NEXT:   memref.copy [[valloc_3]], [[vsubview]] : memref<1x12xi8> to memref<1x12xi8, strided<[12, 1]>>
   // CHECK-NEXT: }
-  // CHECK-NEXT: memref.dealloc [[valloc_3]] : memref<?x3xi8>
+  // CHECK-NEXT: memref.dealloc [[valloc_3]] : memref<1x12xi8>
   // CHECK-NEXT: [[v10:%.*]] = arith.cmpi sge, [[v6]], [[vc0_i32]] : i32
   // CHECK-NEXT: [[v11:%.*]] = arith.cmpi sge, [[v7]], [[vc0_i32]] : i32
-  // CHECK-NEXT: [[valloc_4:%.*]] = memref.alloc([[vc9]]) : memref<?x4xi8>
+  // CHECK-NEXT: [[valloc_4:%.*]] = memref.alloc() : memref<2x12xi8>
   // CHECK-NEXT: scf.if [[v11]] {
-  // CHECK-NEXT:   [[vsubview:%.*]] = memref.subview [[varg0]][1, 0] [[[vc9]], 4] [1, 1] : memref<12x12xi8> to memref<?x4xi8, strided<[12, 1], offset: 12>>
-  // CHECK-NEXT:   memref.copy [[vsubview]], [[valloc_4]] : memref<?x4xi8, strided<[12, 1], offset: 12>> to memref<?x4xi8>
-  // CHECK-NEXT:   mpi.send([[valloc_4]], [[vc91_i32]], [[v7]]) : memref<?x4xi8>, i32, i32
+  // CHECK-NEXT:   [[vsubview:%.*]] = memref.subview [[varg0]][[[vc1]], 0] [2, 12] [1, 1] : memref<12x12xi8> to memref<2x12xi8, strided<[12, 1], offset: ?>>
+  // CHECK-NEXT:   memref.copy [[vsubview]], [[valloc_4]] : memref<2x12xi8, strided<[12, 1], offset: ?>> to memref<2x12xi8>
+  // CHECK-NEXT:   mpi.send([[valloc_4]], [[vc91_i32]], [[v7]]) : memref<2x12xi8>, i32, i32
   // CHECK-NEXT: }
   // CHECK-NEXT: scf.if [[v10]] {
-  // CHECK-NEXT:   mpi.recv([[valloc_4]], [[vc91_i32]], [[v6]]) : memref<?x4xi8>, i32, i32
-  // CHECK-NEXT:   [[vsubview:%.*]] = memref.subview [[varg0]][1, %c8] [[[vc9]], 4] [1, 1] : memref<12x12xi8> to memref<?x4xi8, strided<[12, 1], offset: ?>>
-  // CHECK-NEXT:   memref.copy [[valloc_4]], [[vsubview]] : memref<?x4xi8> to memref<?x4xi8, strided<[12, 1], offset: ?>>
+  // CHECK-NEXT:   mpi.recv([[valloc_4]], [[vc91_i32]], [[v6]]) : memref<2x12xi8>, i32, i32
+  // CHECK-NEXT:   [[vsubview:%.*]] = memref.subview [[varg0]][[[vc10]], 0] [2, 12] [1, 1] : memref<12x12xi8> to memref<2x12xi8, strided<[12, 1], offset: ?>>
+  // CHECK-NEXT:   memref.copy [[valloc_4]], [[vsubview]] : memref<2x12xi8> to memref<2x12xi8, strided<[12, 1], offset: ?>>
   // CHECK-NEXT: }
-  // CHECK-NEXT: memref.dealloc [[valloc_4]] : memref<?x4xi8>
-  // CHECK-NEXT: return
+  // CHECK-NEXT: memref.dealloc [[valloc_4]] : memref<2x12xi8>
   mesh.update_halo %arg0 on @mesh0 split_axes = [[0], [1]]
       halo_sizes = [1, 2, 3, 4]
       : memref<12x12xi8>
+  // CHECK-NEXT: return
   return
 }
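
For readers checking the updated CHECK lines above: a minimal standalone
sketch (plain C++, not part of the patch; the names mirror the pattern's
lowerSendOffset/upperSendOffset values) of how the offsets fall out for a
dimension of size 12 with halo_sizes = [2, 3]:

  #include <cassert>
  #include <cstdint>

  int main() {
    int64_t dimSize = 12, lowerHalo = 2, upperHalo = 3;
    // offsets where received halo data is written
    int64_t lowerRecvOffset = 0;
    int64_t upperRecvOffset = dimSize - upperHalo;             // 9
    // offsets where the sent core data is read from
    int64_t lowerSendOffset = lowerHalo;                       // 2 (block of size upperHalo)
    int64_t upperSendOffset = dimSize - upperHalo - lowerHalo; // 7 (block of size lowerHalo)
    assert(lowerRecvOffset == 0 && upperRecvOffset == 9);
    assert(lowerSendOffset == 2 && upperSendOffset == 7);
    return 0;
  }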

>From 6e967fb9f51b542af7b5244eb609875e47433cd8 Mon Sep 17 00:00:00 2001
From: "Schlimbach, Frank" <frank.schlimbach at intel.com>
Date: Wed, 21 Aug 2024 12:25:08 +0200
Subject: [PATCH 04/15] addressed review comments (docs, formatting)

---
 .../mlir/Conversion/MeshToMPI/MeshToMPI.h        |  2 +-
 mlir/include/mlir/Conversion/Passes.td           |  2 +-
 mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td     |  6 +++---
 mlir/lib/Conversion/MeshToMPI/MeshToMPI.cpp      | 16 +++++++++-------
 4 files changed, 14 insertions(+), 12 deletions(-)

diff --git a/mlir/include/mlir/Conversion/MeshToMPI/MeshToMPI.h b/mlir/include/mlir/Conversion/MeshToMPI/MeshToMPI.h
index 6a2c196da45577..b8803f386f7356 100644
--- a/mlir/include/mlir/Conversion/MeshToMPI/MeshToMPI.h
+++ b/mlir/include/mlir/Conversion/MeshToMPI/MeshToMPI.h
@@ -1,4 +1,4 @@
-//===- MeshToMPI.h - Convert Mesh to MPI dialect --*- C++ -*-===//
+//===- MeshToMPI.h - Convert Mesh to MPI dialect ----------------*- C++ -*-===//
 //
 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
 // See https://llvm.org/LICENSE.txt for license information.
diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td
index 83e0c5a06c43f7..2781fab917048d 100644
--- a/mlir/include/mlir/Conversion/Passes.td
+++ b/mlir/include/mlir/Conversion/Passes.td
@@ -886,7 +886,7 @@ def ConvertMeshToMPIPass : Pass<"convert-mesh-to-mpi"> {
   let summary = "Convert Mesh dialect to MPI dialect.";
   let description = [{
     This pass converts communication operations
-    from the Mesh dialect to operations from the MPI dialect.
+    from the Mesh dialect to the MPI dialect.
   }];
   let dependentDialects = [
     "memref::MemRefDialect",
diff --git a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td
index 2c2b6e20f3654d..e6f61aa84a1312 100644
--- a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td
+++ b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td
@@ -162,7 +162,7 @@ def Mesh_NeighborsLinearIndicesOp : Mesh_Op<"neighbors_linear_indices", [
   DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>
 ]> {
   let summary =
-      "For given split axes get the linear index the direct neighbor processes.";
+      "For given split axes get the linear indices of the direct neighbor processes.";
   let description = [{
     Example:
     ```
@@ -172,8 +172,8 @@ def Mesh_NeighborsLinearIndicesOp : Mesh_Op<"neighbors_linear_indices", [
     Given `@mesh` with shape `(10, 20, 30)`,
           `device` = `(1, 2, 3)`
           `$split_axes` = `[1]`
-    it returns the linear indices of the processes at positions `(1, 1, 3)`: `633`
-    and `(1, 3, 3)`: `693`.
+    returns two indices, `633` and `693`, which correspond to the indices of the
+    previous process `(1, 1, 3)` and the next process `(1, 3, 3)` along the split
+    axis `1`.
 
     A negative value is returned if `$device` has no neighbor in the given
     direction along the given `split_axes`.
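
As a sanity check of the numbers in the example above (an illustrative
snippet, not part of the patch), using row-major strides (600, 30, 1) for the
mesh shape (10, 20, 30):

  #include <cassert>

  int main() {
    int strides[3] = {20 * 30, 30, 1}; // row-major strides of (10, 20, 30)
    int prev[3] = {1, 1, 3}, next[3] = {1, 3, 3};
    assert(prev[0] * strides[0] + prev[1] * strides[1] + prev[2] * strides[2] == 633);
    assert(next[0] * strides[0] + next[1] * strides[1] + next[2] * strides[2] == 693);
    return 0;
  }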
diff --git a/mlir/lib/Conversion/MeshToMPI/MeshToMPI.cpp b/mlir/lib/Conversion/MeshToMPI/MeshToMPI.cpp
index 9cf9458ce2b687..ea1323e43462cd 100644
--- a/mlir/lib/Conversion/MeshToMPI/MeshToMPI.cpp
+++ b/mlir/lib/Conversion/MeshToMPI/MeshToMPI.cpp
@@ -45,15 +45,17 @@ struct ConvertUpdateHaloOp
   mlir::LogicalResult
   matchAndRewrite(mlir::mesh::UpdateHaloOp op,
                   mlir::PatternRewriter &rewriter) const override {
+    // The input/output memref is assumed to be in C memory order.
     // Halos are exchanged as 2 blocks per dimension (one for each side: down
-    // and up). It is assumed that the last dim in a default memref is
-    // contiguous, hence iteration starts with the complete halo on the first
-    // dim which should be contiguous (unless the source is not). The size of
-    // the exchanged data will decrease when iterating over dimensions. That's
-    // good because the halos of last dim will be most fragmented.
+    // and up). For each haloed dimension `d`, the exchanged blocks are
+    // expressed as multi-dimensional subviews. The subviews include potential
+    // halos of higher dimensions `dh > d`, no halos for the lower dimensions
+    // `dl < d`, and for dimension `d` itself only the currently exchanged halo.
+    // By iterating from higher to lower dimensions this also updates the halos
+    // in the 'corners'.
     // memref.subview is used to read and write the halo data from and to the
-    // local data. subviews and halos have dynamic and static values, so
-    // OpFoldResults are used whenever possible.
+    // local data. Because subviews and halos can have mixed dynamic and static
+    // shapes, OpFoldResults are used whenever possible.
 
     SymbolTableCollection symbolTableCollection;
     auto loc = op.getLoc();
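
The comment above corresponds to this loop structure; a rough C++ sketch
(with a hypothetical exchangeHalo helper, not the actual pass code):

  #include <vector>

  void exchangeHalo(int dim); // hypothetical: send/recv the two halo blocks of `dim`

  void updateHalos(std::vector<long> &dimSizes, std::vector<long> &offsets,
                   const std::vector<long> &shape, int lastSplitDim) {
    // Iterate from the last (contiguous) split dimension down to the first.
    // Once a dimension is exchanged, its full extent - including the freshly
    // received halos - is covered by every lower dimension's subview, which
    // is what updates the 'corners'.
    for (int dim = lastSplitDim; dim >= 0; --dim) {
      exchangeHalo(dim);
      dimSizes[dim] = shape[dim];
      offsets[dim] = 0;
    }
  }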

>From a63cfa3446a065001a34512e31998a51f738ca39 Mon Sep 17 00:00:00 2001
From: "Schlimbach, Frank" <frank.schlimbach at intel.com>
Date: Tue, 3 Sep 2024 10:41:22 +0200
Subject: [PATCH 05/15] newline

---
 mlir/include/mlir/Conversion/MeshToMPI/MeshToMPI.h | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/mlir/include/mlir/Conversion/MeshToMPI/MeshToMPI.h b/mlir/include/mlir/Conversion/MeshToMPI/MeshToMPI.h
index b8803f386f7356..04271f8ab67b95 100644
--- a/mlir/include/mlir/Conversion/MeshToMPI/MeshToMPI.h
+++ b/mlir/include/mlir/Conversion/MeshToMPI/MeshToMPI.h
@@ -24,4 +24,4 @@ std::unique_ptr<Pass> createConvertMeshToMPIPass();
 
 } // namespace mlir
 
-#endif // MLIR_CONVERSION_MESHTOMPI_MESHTOMPI_H
\ No newline at end of file
+#endif // MLIR_CONVERSION_MESHTOMPI_MESHTOMPI_H

>From 38c21af59efdd57b31f8d8daf5b02f84c4d83dbc Mon Sep 17 00:00:00 2001
From: "Schlimbach, Frank" <frank.schlimbach at intel.com>
Date: Thu, 31 Oct 2024 17:54:52 +0100
Subject: [PATCH 06/15] removing source from UpdateHaloOp, as it is not
 required for destination-passing style

---
 mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td  | 19 +++++++------------
 mlir/lib/Conversion/MeshToMPI/MeshToMPI.cpp   | 10 +++++-----
 .../Dialect/Mesh/Transforms/Spmdization.cpp   |  4 +---
 mlir/test/Dialect/Mesh/ops.mlir               | 16 ++++++++--------
 mlir/test/Dialect/Mesh/spmdization.mlir       |  4 ++--
 5 files changed, 23 insertions(+), 30 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td
index e6f61aa84a1312..3c52c63330e95f 100644
--- a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td
+++ b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td
@@ -1095,8 +1095,7 @@ def Mesh_UpdateHaloOp : Mesh_Op<"update_halo", [
   TypesMatchWith<
     "result has same type as destination",
     "result", "destination", "$_self">,
-  DeclareOpInterfaceMethods<SymbolUserOpInterface>,
-  AttrSizedOperandSegments
+  DeclareOpInterfaceMethods<SymbolUserOpInterface>
 ]> {
   let summary = "Update halo data.";
   let description = [{
@@ -1105,7 +1104,7 @@ def Mesh_UpdateHaloOp : Mesh_Op<"update_halo", [
     on the remote devices. Changes might be caused by mutating operations
     and/or if the new halo regions are larger than the existing ones.
 
-    Source and destination might have different halo sizes.
+    The destination is expected to be initialized with the local data (excluding halos).
 
     Assumes all devices hold tensors with same-sized halo data as specified
     by `source_halo_sizes/static_source_halo_sizes` and
@@ -1117,25 +1116,21 @@ def Mesh_UpdateHaloOp : Mesh_Op<"update_halo", [
 
   }];
   let arguments = (ins
-    AnyTypeOf<[AnyNon0RankedMemRef, AnyNon0RankedTensor]>:$source,
     AnyTypeOf<[AnyNon0RankedMemRef, AnyNon0RankedTensor]>:$destination,
     FlatSymbolRefAttr:$mesh,
     Mesh_MeshAxesArrayAttr:$split_axes,
-    Variadic<I64>:$source_halo_sizes,
-    DefaultValuedAttr<DenseI64ArrayAttr, "{}">:$static_source_halo_sizes,
-    Variadic<I64>:$destination_halo_sizes,
-    DefaultValuedAttr<DenseI64ArrayAttr, "{}">:$static_destination_halo_sizes
+    Variadic<I64>:$halo_sizes,
+    DefaultValuedAttr<DenseI64ArrayAttr, "{}">:$static_halo_sizes
   );
   let results = (outs
     AnyTypeOf<[AnyNon0RankedMemRef, AnyNon0RankedTensor]>:$result
   );
   let assemblyFormat = [{
-    $source `into` $destination
+    $destination
     `on` $mesh
     `split_axes` `=` $split_axes
-    (`source_halo_sizes` `=` custom<DynamicIndexList>($source_halo_sizes, $static_source_halo_sizes)^)?
-    (`destination_halo_sizes` `=` custom<DynamicIndexList>($destination_halo_sizes, $static_destination_halo_sizes)^)?
-    attr-dict `:` type($source) `->` type($result)
+    (`halo_sizes` `=` custom<DynamicIndexList>($halo_sizes, $static_halo_sizes)^)?
+    attr-dict `:` type($result)
   }];
   let extraClassDeclaration = [{
     MutableOperandRange getDpsInitsMutable() { return getDestinationMutable(); }
diff --git a/mlir/lib/Conversion/MeshToMPI/MeshToMPI.cpp b/mlir/lib/Conversion/MeshToMPI/MeshToMPI.cpp
index ea1323e43462cd..11d7c0e08f1a67 100644
--- a/mlir/lib/Conversion/MeshToMPI/MeshToMPI.cpp
+++ b/mlir/lib/Conversion/MeshToMPI/MeshToMPI.cpp
@@ -70,13 +70,13 @@ struct ConvertUpdateHaloOp
                            cast<IntegerAttr>(v.get<Attribute>()).getInt()));
     };
 
-    auto array = op.getInput();
-    auto rank = array.getType().getRank();
+    auto array = op.getDestination();
+    auto rank = cast<ShapedType>(array.getType()).getRank();
     auto opSplitAxes = op.getSplitAxes().getAxes();
     auto mesh = op.getMesh();
     auto meshOp = getMesh(op, symbolTableCollection);
     auto haloSizes = getMixedValues(op.getStaticHaloSizes(),
-                                    op.getDynamicHaloSizes(), rewriter);
+                                    op.getHaloSizes(), rewriter);
     // subviews need Index values
     for (auto &sz : haloSizes) {
       if (sz.is<Value>()) {
@@ -94,7 +94,7 @@ struct ConvertUpdateHaloOp
     auto currHaloDim = -1; // halo sizes are provided for split dimensions only
     // we need the actual shape to compute offsets and sizes
     for (auto i = 0; i < rank; ++i) {
-      auto s = array.getType().getShape()[i];
+      auto s = cast<ShapedType>(array.getType()).getShape()[i];
       if (ShapedType::isDynamic(s)) {
         shape[i] = rewriter.create<memref::DimOp>(loc, array, i).getResult();
       } else {
@@ -176,7 +176,7 @@ struct ConvertUpdateHaloOp
         auto hasTo = rewriter.create<arith::CmpIOp>(
             loc, arith::CmpIPredicate::sge, to, zero);
         auto buffer = rewriter.create<memref::AllocOp>(
-            loc, dimSizes, array.getType().getElementType());
+            loc, dimSizes, cast<ShapedType>(array.getType()).getElementType());
         // if has neighbor: copy halo data from array to buffer and send
         rewriter.create<scf::IfOp>(
             loc, hasTo, [&](OpBuilder &builder, Location loc) {
diff --git a/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp b/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp
index b4d088cbd7088d..327ea0991e4e1e 100644
--- a/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp
+++ b/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp
@@ -496,11 +496,9 @@ tryUpdateHaloInResharding(ImplicitLocOpBuilder &builder, MeshOp mesh,
               sourceShard.getLoc(),
               RankedTensorType::get(outShape,
                                     sourceShard.getType().getElementType()),
-              sourceShard, initOprnd, mesh.getSymName(),
+              initOprnd, mesh.getSymName(),
               MeshAxesArrayAttr::get(builder.getContext(),
                                      sourceSharding.getSplitAxes()),
-              sourceSharding.getDynamicHaloSizes(),
-              sourceSharding.getStaticHaloSizes(),
               targetSharding.getDynamicHaloSizes(),
               targetSharding.getStaticHaloSizes())
           .getResult();
diff --git a/mlir/test/Dialect/Mesh/ops.mlir b/mlir/test/Dialect/Mesh/ops.mlir
index d8df01c3d6520d..978de4939ee77c 100644
--- a/mlir/test/Dialect/Mesh/ops.mlir
+++ b/mlir/test/Dialect/Mesh/ops.mlir
@@ -615,16 +615,16 @@ func.func @update_halo(
     // CHECK-SAME: %[[ARG:.*]]: memref<12x12xi8>
     %arg0 : memref<12x12xi8>) {
   // CHECK-NEXT: %[[C2:.*]] = arith.constant 2 : i64
-  // CHECK-NEXT: %[[UH1:.*]] = mesh.update_halo %[[ARG]] into %[[ARG]] on @mesh0
+  // CHECK-NEXT: %[[UH1:.*]] = mesh.update_halo %[[ARG]] on @mesh0
   // CHECK-SAME: split_axes = {{\[\[}}0]]
-  // CHECK-SAME: halo_sizes = [2, %c2_i64] : memref<12x12xi8> -> memref<12x12xi8>
+  // CHECK-SAME: halo_sizes = [2, %c2_i64] : memref<12x12xi8>
   %c2 = arith.constant 2 : i64
-  %uh1 = mesh.update_halo %arg0 into %arg0 on @mesh0 split_axes = [[0]]
-    source_halo_sizes = [2, %c2] : memref<12x12xi8> -> memref<12x12xi8>
-  // CHECK-NEXT: %[[UH2:.*]] = mesh.update_halo %[[ARG]] into %[[UH1]] on @mesh0
+  %uh1 = mesh.update_halo %arg0 on @mesh0 split_axes = [[0]]
+    halo_sizes = [2, %c2] : memref<12x12xi8>
+  // CHECK-NEXT: %[[UH2:.*]] = mesh.update_halo %[[UH1]] on @mesh0
   // CHECK-SAME: split_axes = {{\[\[}}0], [1]]
-  // CHECK-SAME: halo_sizes = [2, 2, %[[C2]], 2] : memref<12x12xi8> -> memref<12x12xi8>
-  %uh2 = mesh.update_halo %arg0 into %uh1 on @mesh0 split_axes = [[0], [1]]
-    source_halo_sizes = [2, 2, %c2, 2] : memref<12x12xi8> -> memref<12x12xi8>
+  // CHECK-SAME: halo_sizes = [2, 2, %[[C2]], 2] : memref<12x12xi8>
+  %uh2 = mesh.update_halo %uh1 on @mesh0 split_axes = [[0], [1]]
+    halo_sizes = [2, 2, %c2, 2] : memref<12x12xi8>
   return
 }
diff --git a/mlir/test/Dialect/Mesh/spmdization.mlir b/mlir/test/Dialect/Mesh/spmdization.mlir
index 22ddb72569835d..c1b96fda0f4a74 100644
--- a/mlir/test/Dialect/Mesh/spmdization.mlir
+++ b/mlir/test/Dialect/Mesh/spmdization.mlir
@@ -226,7 +226,7 @@ func.func @test_shard_update_halo(%arg0: tensor<1200x1200xi64>) -> tensor<1200x1
   %sharding = mesh.sharding @mesh_1d_4 split_axes = [[0]] : !mesh.sharding
   // CHECK: %[[T:.*]] = tensor.empty() : tensor<304x1200xi64>
   // CHECK: %[[inserted_slice:.*]] = tensor.insert_slice %[[IN1]] into %[[T]][2, 0] [300, 1200] [1, 1] : tensor<300x1200xi64> into tensor<304x1200xi64>
-  // CHECK: %[[UH:.*]] = mesh.update_halo %[[IN1]] into %[[inserted_slice]] on @mesh_1d_4 split_axes = {{\[\[0]]}} destination_halo_sizes = [2, 2] : tensor<300x1200xi64> -> tensor<304x1200xi64>
+  // CHECK: %[[UH:.*]] = mesh.update_halo %[[inserted_slice]] on @mesh_1d_4 split_axes = {{\[\[0]]}} halo_sizes = [2, 2] : tensor<304x1200xi64>
   %sharding_annotated = mesh.shard %arg0 to %sharding : tensor<1200x1200xi64>
   %sharding_0 = mesh.sharding @mesh_1d_4 split_axes = [[0]] halo_sizes = [2, 2] : !mesh.sharding
   %sharding_annotated_1 = mesh.shard %sharding_annotated to %sharding_0 : tensor<1200x1200xi64>
@@ -242,7 +242,7 @@ func.func @test_shard_update_halo2d(%arg0: tensor<1200x1200xi64>) -> tensor<1200
   %sharding = mesh.sharding @mesh4x4 split_axes = [[0], [1]] : !mesh.sharding
   // CHECK: %[[T:.*]] = tensor.empty() : tensor<303x307xi64>
   // CHECK: %[[inserted_slice:.*]] = tensor.insert_slice %[[IN1]] into %[[T]][1, 3] [300, 300] [1, 1] : tensor<300x300xi64> into tensor<303x307xi64>
-  // CHECK: %[[UH:.*]] = mesh.update_halo %[[IN1]] into %[[inserted_slice]] on @mesh4x4 split_axes = {{\[\[}}0], [1]] destination_halo_sizes = [1, 2, 3, 4] : tensor<300x300xi64> -> tensor<303x307xi64>
+  // CHECK: %[[UH:.*]] = mesh.update_halo %[[inserted_slice]] on @mesh4x4 split_axes = {{\[\[}}0], [1]] halo_sizes = [1, 2, 3, 4] : tensor<303x307xi64>
   %sharding_annotated = mesh.shard %arg0 to %sharding : tensor<1200x1200xi64>
   %sharding_0 = mesh.sharding @mesh4x4 split_axes = [[0], [1]] halo_sizes = [1, 2, 3, 4] : !mesh.sharding
   %sharding_annotated_1 = mesh.shard %sharding_annotated to %sharding_0 : tensor<1200x1200xi64>

>From 919498bdfdadd031449dc334e2d75dcd6794b514 Mon Sep 17 00:00:00 2001
From: "Schlimbach, Frank" <frank.schlimbach at intel.com>
Date: Tue, 5 Nov 2024 16:55:08 +0100
Subject: [PATCH 07/15] clang-format

---
 mlir/lib/Conversion/MeshToMPI/MeshToMPI.cpp | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/mlir/lib/Conversion/MeshToMPI/MeshToMPI.cpp b/mlir/lib/Conversion/MeshToMPI/MeshToMPI.cpp
index 11d7c0e08f1a67..5d9ea9cfccf8d4 100644
--- a/mlir/lib/Conversion/MeshToMPI/MeshToMPI.cpp
+++ b/mlir/lib/Conversion/MeshToMPI/MeshToMPI.cpp
@@ -75,8 +75,8 @@ struct ConvertUpdateHaloOp
     auto opSplitAxes = op.getSplitAxes().getAxes();
     auto mesh = op.getMesh();
     auto meshOp = getMesh(op, symbolTableCollection);
-    auto haloSizes = getMixedValues(op.getStaticHaloSizes(),
-                                    op.getHaloSizes(), rewriter);
+    auto haloSizes =
+        getMixedValues(op.getStaticHaloSizes(), op.getHaloSizes(), rewriter);
     // subviews need Index values
     for (auto &sz : haloSizes) {
       if (sz.is<Value>()) {

>From 60de21f6a063637db4d5474b42afde078a7e05fc Mon Sep 17 00:00:00 2001
From: "Schlimbach, Frank" <frank.schlimbach at intel.com>
Date: Tue, 5 Nov 2024 19:50:56 +0100
Subject: [PATCH 08/15] allow tensor as destination in UpdateHaloOp and fixing
 its tests

---
 mlir/include/mlir/Conversion/Passes.td        |  3 +-
 mlir/lib/Conversion/MeshToMPI/MeshToMPI.cpp   | 23 ++++++-
 .../MeshToMPI/convert-mesh-to-mpi.mlir        | 65 ++++++++++++++++---
 3 files changed, 79 insertions(+), 12 deletions(-)

diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td
index 2781fab917048d..43015ad5b11e65 100644
--- a/mlir/include/mlir/Conversion/Passes.td
+++ b/mlir/include/mlir/Conversion/Passes.td
@@ -891,7 +891,8 @@ def ConvertMeshToMPIPass : Pass<"convert-mesh-to-mpi"> {
   let dependentDialects = [
     "memref::MemRefDialect",
     "mpi::MPIDialect",
-    "scf::SCFDialect"
+    "scf::SCFDialect",
+    "bufferization::BufferizationDialect"
   ];
 }
 
diff --git a/mlir/lib/Conversion/MeshToMPI/MeshToMPI.cpp b/mlir/lib/Conversion/MeshToMPI/MeshToMPI.cpp
index 5d9ea9cfccf8d4..b1b58584aaae24 100644
--- a/mlir/lib/Conversion/MeshToMPI/MeshToMPI.cpp
+++ b/mlir/lib/Conversion/MeshToMPI/MeshToMPI.cpp
@@ -13,6 +13,7 @@
 #include "mlir/Conversion/MeshToMPI/MeshToMPI.h"
 
 #include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
 #include "mlir/Dialect/MPI/IR/MPI.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/Dialect/Mesh/IR/MeshOps.h"
@@ -70,7 +71,16 @@ struct ConvertUpdateHaloOp
                            cast<IntegerAttr>(v.get<Attribute>()).getInt()));
     };
 
-    auto array = op.getDestination();
+    auto dest = op.getDestination();
+    auto dstShape = cast<ShapedType>(dest.getType()).getShape();
+    Value array = dest;
+    if (isa<RankedTensorType>(array.getType())) {
+      // If the destination is a tensor, we need to cast it to a memref
+      auto memrefType = MemRefType::get(
+          dstShape, cast<ShapedType>(array.getType()).getElementType());
+      array = rewriter.create<bufferization::ToMemrefOp>(loc, memrefType, array)
+                  .getResult();
+    }
     auto rank = cast<ShapedType>(array.getType()).getRank();
     auto opSplitAxes = op.getSplitAxes().getAxes();
     auto mesh = op.getMesh();
@@ -94,7 +104,7 @@ struct ConvertUpdateHaloOp
     auto currHaloDim = -1; // halo sizes are provided for split dimensions only
     // we need the actual shape to compute offsets and sizes
     for (auto i = 0; i < rank; ++i) {
-      auto s = cast<ShapedType>(array.getType()).getShape()[i];
+      auto s = dstShape[i];
       if (ShapedType::isDynamic(s)) {
         shape[i] = rewriter.create<memref::DimOp>(loc, array, i).getResult();
       } else {
@@ -213,7 +223,14 @@ struct ConvertUpdateHaloOp
       // on to next halo
       --currHaloDim;
     }
-    rewriter.eraseOp(op);
+
+    if (isa<MemRefType>(op.getResult().getType())) {
+      rewriter.replaceOp(op, array);
+    } else {
+      assert(isa<RankedTensorType>(op.getResult().getType()));
+      rewriter.replaceOp(op, rewriter.create<bufferization::ToTensorOp>(
+                                 loc, op.getResult().getType(), array));
+    }
     return mlir::success();
   }
 };
diff --git a/mlir/test/Conversion/MeshToMPI/convert-mesh-to-mpi.mlir b/mlir/test/Conversion/MeshToMPI/convert-mesh-to-mpi.mlir
index c3b0dc12e6d746..d05c53bd83aaf9 100644
--- a/mlir/test/Conversion/MeshToMPI/convert-mesh-to-mpi.mlir
+++ b/mlir/test/Conversion/MeshToMPI/convert-mesh-to-mpi.mlir
@@ -53,7 +53,7 @@ func.func @update_halo_1d_first(
 // CHECK-LABEL: func @update_halo_1d_second
 func.func @update_halo_1d_second(
   // CHECK-SAME: [[varg0:%.*]]: memref<12x12xi8>
-  %arg0 : memref<12x12xi8>) {
+  %arg0 : memref<12x12xi8>) -> memref<12x12xi8> {
   //CHECK-NEXT: [[vc7:%.*]] = arith.constant 7 : index
   //CHECK-NEXT: [[vc9:%.*]] = arith.constant 9 : index
   //CHECK-NEXT: [[vc2:%.*]] = arith.constant 2 : index
@@ -91,16 +91,16 @@ func.func @update_halo_1d_second(
   //CHECK-NEXT:   memref.copy [[valloc_0]], [[vsubview]] : memref<12x3xi8> to memref<12x3xi8, strided<[12, 1], offset: ?>>
   //CHECK-NEXT: }
   //CHECK-NEXT: memref.dealloc [[valloc_0]] : memref<12x3xi8>
-  mesh.update_halo %arg0 on @mesh0 split_axes = [[], [3]]
+  %res = mesh.update_halo %arg0 on @mesh0 split_axes = [[], [3]]
     halo_sizes = [2, 3] : memref<12x12xi8>
-  //CHECK-NEXT: return
-  return
+  //CHECK-NEXT: return [[varg0]] : memref<12x12xi8>
+  return %res : memref<12x12xi8>
 }
 
 // CHECK-LABEL: func @update_halo_2d
 func.func @update_halo_2d(
     // CHECK-SAME: [[varg0:%.*]]: memref<12x12xi8>
-    %arg0 : memref<12x12xi8>) {
+    %arg0 : memref<12x12xi8>) -> memref<12x12xi8> {
   // CHECK-NEXT: [[vc10:%.*]] = arith.constant 10 : index
   // CHECK-NEXT: [[vc1:%.*]] = arith.constant 1 : index
   // CHECK-NEXT: [[vc5:%.*]] = arith.constant 5 : index
@@ -172,9 +172,58 @@ func.func @update_halo_2d(
   // CHECK-NEXT:   memref.copy [[valloc_4]], [[vsubview]] : memref<2x12xi8> to memref<2x12xi8, strided<[12, 1], offset: ?>>
   // CHECK-NEXT: }
   // CHECK-NEXT: memref.dealloc [[valloc_4]] : memref<2x12xi8>
-  mesh.update_halo %arg0 on @mesh0 split_axes = [[0], [1]]
+  %res = mesh.update_halo %arg0 on @mesh0 split_axes = [[0], [1]]
       halo_sizes = [1, 2, 3, 4]
       : memref<12x12xi8>
-  // CHECK-NEXT: return
-  return
+  // CHECK-NEXT: return [[varg0]] : memref<12x12xi8>
+  return %res : memref<12x12xi8>
+}
+
+// CHECK-LABEL: func @update_halo_1d_tnsr
+func.func @update_halo_1d_tnsr(
+  // CHECK-SAME: [[varg0:%.*]]: tensor<12x12xi8>
+  %arg0 : tensor<12x12xi8>) -> tensor<12x12xi8> {
+  // CHECK-NEXT: [[vc7:%.*]] = arith.constant 7 : index
+  // CHECK-NEXT: [[vc9:%.*]] = arith.constant 9 : index
+  // CHECK-NEXT: [[vc2:%.*]] = arith.constant 2 : index
+  // CHECK-NEXT: [[vc0_i32:%.*]] = arith.constant 0 : i32
+  // CHECK-NEXT: [[vc91_i32:%.*]] = arith.constant 91 : i32
+  // CHECK-NEXT: [[mref:%.*]] = bufferization.to_memref %arg0 : memref<12x12xi8>
+  // CHECK-NEXT: [[vproc_linear_idx:%.*]]:3 = mesh.process_multi_index on @mesh0 : index, index, index
+  // CHECK-NEXT: [[vdown_linear_idx:%.*]], [[vup_linear_idx:%.*]] = mesh.neighbors_linear_indices on @mesh0[[[vproc_linear_idx]]#0, [[vproc_linear_idx]]#1, [[vproc_linear_idx]]#2] split_axes = [0] : index, index
+  // CHECK-NEXT: [[v0:%.*]] = arith.index_cast [[vdown_linear_idx]] : index to i32
+  // CHECK-NEXT: [[v1:%.*]] = arith.index_cast [[vup_linear_idx]] : index to i32
+  // CHECK-NEXT: [[v2:%.*]] = arith.cmpi sge, [[v1]], [[vc0_i32]] : i32
+  // CHECK-NEXT: [[v3:%.*]] = arith.cmpi sge, [[v0]], [[vc0_i32]] : i32
+  // CHECK-NEXT: [[valloc:%.*]] = memref.alloc() : memref<2x12xi8>
+  // CHECK-NEXT: scf.if [[v3]] {
+  // CHECK-NEXT:   [[vsubview:%.*]] = memref.subview [[mref]][[[vc7]], 0] [2, 12] [1, 1] : memref<12x12xi8> to memref<2x12xi8, strided<[12, 1], offset: ?>>
+  // CHECK-NEXT:   memref.copy [[vsubview]], [[valloc]] : memref<2x12xi8, strided<[12, 1], offset: ?>> to memref<2x12xi8>
+  // CHECK-NEXT:   mpi.send([[valloc]], [[vc91_i32]], [[v0]]) : memref<2x12xi8>, i32, i32
+  // CHECK-NEXT: }
+  // CHECK-NEXT: scf.if [[v2]] {
+  // CHECK-NEXT:   mpi.recv([[valloc]], [[vc91_i32]], [[v1]]) : memref<2x12xi8>, i32, i32
+  // CHECK-NEXT:   [[vsubview:%.*]] = memref.subview [[mref]][0, 0] [2, 12] [1, 1] : memref<12x12xi8> to memref<2x12xi8, strided<[12, 1]>>
+  // CHECK-NEXT:   memref.copy [[valloc]], [[vsubview]] : memref<2x12xi8> to memref<2x12xi8, strided<[12, 1]>>
+  // CHECK-NEXT: }
+  // CHECK-NEXT: memref.dealloc [[valloc]] : memref<2x12xi8>
+  // CHECK-NEXT: [[v4:%.*]] = arith.cmpi sge, [[v0]], [[vc0_i32]] : i32
+  // CHECK-NEXT: [[v5:%.*]] = arith.cmpi sge, [[v1]], [[vc0_i32]] : i32
+  // CHECK-NEXT: [[valloc_0:%.*]] = memref.alloc() : memref<3x12xi8>
+  // CHECK-NEXT: scf.if [[v5]] {
+  // CHECK-NEXT:   [[vsubview:%.*]] = memref.subview [[mref]][[[vc2]], 0] [3, 12] [1, 1] : memref<12x12xi8> to memref<3x12xi8, strided<[12, 1], offset: ?>>
+  // CHECK-NEXT:   memref.copy [[vsubview]], [[valloc_0]] : memref<3x12xi8, strided<[12, 1], offset: ?>> to memref<3x12xi8>
+  // CHECK-NEXT:   mpi.send([[valloc_0]], [[vc91_i32]], [[v1]]) : memref<3x12xi8>, i32, i32
+  // CHECK-NEXT: }
+  // CHECK-NEXT: scf.if [[v4]] {
+  // CHECK-NEXT:   mpi.recv([[valloc_0]], [[vc91_i32]], [[v0]]) : memref<3x12xi8>, i32, i32
+  // CHECK-NEXT:   [[vsubview:%.*]] = memref.subview [[mref]][[[vc9]], 0] [3, 12] [1, 1] : memref<12x12xi8> to memref<3x12xi8, strided<[12, 1], offset: ?>>
+  // CHECK-NEXT:   memref.copy [[valloc_0]], [[vsubview]] : memref<3x12xi8> to memref<3x12xi8, strided<[12, 1], offset: ?>>
+  // CHECK-NEXT: }
+  // CHECK-NEXT: memref.dealloc [[valloc_0]] : memref<3x12xi8>
+  // CHECK-NEXT: [[res:%.*]] = bufferization.to_tensor [[mref]] : memref<12x12xi8>
+  %res = mesh.update_halo %arg0 on @mesh0 split_axes = [[0]]
+    halo_sizes = [2, 3] : tensor<12x12xi8>
+  // CHECK-NEXT: return [[res]]
+  return %res : tensor<12x12xi8>
 }

>From a90787028263f632b63ce7f277e59c86c6bd890b Mon Sep 17 00:00:00 2001
From: "Schlimbach, Frank" <frank.schlimbach at intel.com>
Date: Thu, 7 Nov 2024 13:17:00 +0100
Subject: [PATCH 09/15] converting LinearIndex, MultiIndex and NeighborsIndex
 to MPI

---
 .../mlir/Conversion/MeshToMPI/MeshToMPI.h     |   2 +-
 mlir/include/mlir/Conversion/Passes.td        |   1 +
 mlir/lib/Conversion/MeshToMPI/MeshToMPI.cpp   | 128 +++++++++++++++++-
 3 files changed, 128 insertions(+), 3 deletions(-)

diff --git a/mlir/include/mlir/Conversion/MeshToMPI/MeshToMPI.h b/mlir/include/mlir/Conversion/MeshToMPI/MeshToMPI.h
index 04271f8ab67b95..44a1cc0adb6a0c 100644
--- a/mlir/include/mlir/Conversion/MeshToMPI/MeshToMPI.h
+++ b/mlir/include/mlir/Conversion/MeshToMPI/MeshToMPI.h
@@ -20,7 +20,7 @@ class Pass;
 
 /// Lowers Mesh communication operations (updateHalo, AllGather, ...)
 /// to MPI primitives.
-std::unique_ptr<Pass> createConvertMeshToMPIPass();
+std::unique_ptr<::mlir::Pass> createConvertMeshToMPIPass();
 
 } // namespace mlir
 
diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td
index 43015ad5b11e65..15fc13f5e12d83 100644
--- a/mlir/include/mlir/Conversion/Passes.td
+++ b/mlir/include/mlir/Conversion/Passes.td
@@ -884,6 +884,7 @@ def ConvertMemRefToSPIRV : Pass<"convert-memref-to-spirv"> {
 
 def ConvertMeshToMPIPass : Pass<"convert-mesh-to-mpi"> {
   let summary = "Convert Mesh dialect to MPI dialect.";
+  let constructor = "mlir::createConvertMeshToMPIPass()";
   let description = [{
     This pass converts communication operations
     from the Mesh dialect to the MPI dialect.
diff --git a/mlir/lib/Conversion/MeshToMPI/MeshToMPI.cpp b/mlir/lib/Conversion/MeshToMPI/MeshToMPI.cpp
index b1b58584aaae24..0f0cc28ca363a2 100644
--- a/mlir/lib/Conversion/MeshToMPI/MeshToMPI.cpp
+++ b/mlir/lib/Conversion/MeshToMPI/MeshToMPI.cpp
@@ -37,6 +37,130 @@ using namespace mlir;
 using namespace mlir::mesh;
 
 namespace {
+static SmallVector<Value> linearToMultiIndex(Location loc, OpBuilder b, Value linearIndex, ValueRange dimensions) {
+  int n = dimensions.size();
+  SmallVector<Value> multiIndex(n);
+
+  for (int i = n - 1; i >= 0; --i) {
+    multiIndex[i] = b.create<arith::RemSIOp>(loc, linearIndex, dimensions[i]);
+    if(i > 0) {
+      linearIndex = b.create<arith::DivSIOp>(loc, linearIndex, dimensions[i]);
+    }
+  }
+
+  return multiIndex;
+}
+
+Value multiToLinearIndex(Location loc, OpBuilder b, ValueRange multiIndex, ValueRange dimensions) {
+  auto linearIndex = b.create<arith::ConstantIndexOp>(loc, 0).getResult();
+  auto stride = b.create<arith::ConstantIndexOp>(loc, 1).getResult();
+
+  for (int i = multiIndex.size() - 1; i >= 0; --i) {
+      auto dimOff = b.create<arith::MulIOp>(loc, multiIndex[i], stride);
+      linearIndex = b.create<arith::AddIOp>(loc, linearIndex, dimOff);
+      stride = b.create<arith::MulIOp>(loc, stride, dimensions[i]);
+  }
+
+  return linearIndex;
+}
+
+// This pattern converts the mesh.process_multi_index operation.
+struct ConvertProcessMultiIndexOp
+    : public mlir::OpRewritePattern<mlir::mesh::ProcessMultiIndexOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  mlir::LogicalResult
+  matchAndRewrite(mlir::mesh::ProcessMultiIndexOp op,
+                  mlir::PatternRewriter &rewriter) const override {
+    SymbolTableCollection symbolTableCollection;
+    auto loc = op.getLoc();
+    auto meshOp = getMesh(op, symbolTableCollection);
+    // For now we only support static mesh shapes
+    if(ShapedType::isDynamicShape(meshOp.getShape())) {
+      return mlir::failure();
+    }
+
+    SmallVector<Value> dims;
+    llvm::transform(meshOp.getShape(), std::back_inserter(dims), [&](int64_t i) {
+      return rewriter.create<arith::ConstantIndexOp>(loc, i).getResult();
+    });
+    auto rank = rewriter.create<ProcessLinearIndexOp>(op.getLoc(), meshOp).getResult();
+    auto mIdx = linearToMultiIndex(loc, rewriter, rank, dims);
+
+    // optionally extract subset of mesh axes
+    auto axes = op.getAxes();
+    if(!axes.empty()) {
+      SmallVector<Value> subIndex;
+      for(auto axis : axes) {
+        subIndex.push_back(mIdx[axis]);
+      }
+      mIdx = subIndex;
+    }
+
+    rewriter.replaceOp(op, mIdx);
+    return mlir::success();
+  }
+};
+
+// This pattern converts the mesh.process_linear_index operation to an MPI call
+struct ConvertProcessLinearIndexOp
+    : public mlir::OpRewritePattern<mlir::mesh::ProcessLinearIndexOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  mlir::LogicalResult
+  matchAndRewrite(mlir::mesh::ProcessLinearIndexOp op,
+                  mlir::PatternRewriter &rewriter) const override {
+    auto rank = rewriter.create<mpi::CommRankOp>(op.getLoc(), TypeRange{mpi::RetvalType::get(op->getContext()), rewriter.getI32Type()}).getRank();
+    rewriter.replaceOpWithNewOp<arith::IndexCastOp>(op, rewriter.getIndexType(), rank);
+    return mlir::success();
+  }
+};
+
+// This pattern converts the mesh.neighbors_linear_indices operation.
+struct ConvertNeighborsLinearIndicesOp
+    : public mlir::OpRewritePattern<mlir::mesh::NeighborsLinearIndicesOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  mlir::LogicalResult
+  matchAndRewrite(mlir::mesh::NeighborsLinearIndicesOp op,
+                  mlir::PatternRewriter &rewriter) const override {
+    auto axes = op.getSplitAxes();
+    // For now only single axis sharding is supported
+    if(axes.size() != 1) {
+      return mlir::failure();
+    }
+
+    auto loc = op.getLoc();
+    SymbolTableCollection symbolTableCollection;
+    auto meshOp = getMesh(op, symbolTableCollection);
+    SmallVector<Value> mIdx(op.getDevice().begin(), op.getDevice().end());
+    auto orgIdx = mIdx[axes[0]];
+    SmallVector<Value> dims;
+    llvm::transform(meshOp.getShape(), std::back_inserter(dims), [&](int64_t i) {
+      return rewriter.create<arith::ConstantIndexOp>(loc, i).getResult();
+    });
+    auto dimSz = dims[axes[0]];
+    auto one = rewriter.create<arith::ConstantIndexOp>(loc, 1).getResult();
+    auto zero = rewriter.create<arith::ConstantIndexOp>(loc, 0).getResult();
+    auto minus1 = rewriter.create<arith::ConstantIndexOp>(loc, -1).getResult();
+    auto atBorder = rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::sle, orgIdx, zero);
+    auto down = rewriter.create<scf::IfOp>(
+        loc, atBorder, [&](OpBuilder &builder, Location loc) {
+          builder.create<scf::YieldOp>(loc, minus1);
+        }, [&](OpBuilder &builder, Location loc) {
+          mIdx[axes[0]] = builder.create<arith::SubIOp>(loc, orgIdx, one).getResult();
+          builder.create<scf::YieldOp>(loc, multiToLinearIndex(loc, builder, mIdx, dims));
+        });
+    atBorder = rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::sge, orgIdx, rewriter.create<arith::SubIOp>(loc, dimSz, one).getResult());
+    auto up = rewriter.create<scf::IfOp>(
+        loc, atBorder, [&](OpBuilder &builder, Location loc) {
+          builder.create<scf::YieldOp>(loc, minus1);
+        }, [&](OpBuilder &builder, Location loc) {
+          mIdx[axes[0]] = builder.create<arith::AddIOp>(loc, orgIdx, one).getResult();
+          builder.create<scf::YieldOp>(loc, multiToLinearIndex(loc, builder, mIdx, dims));
+        });
+    rewriter.replaceOp(op, ValueRange{down.getResult(0), up.getResult(0)});
+    return mlir::success();
+  }
+};
 
 // This pattern converts the mesh.update_halo operation to MPI calls
 struct ConvertUpdateHaloOp
@@ -244,7 +368,7 @@ struct ConvertMeshToMPIPass
     auto *ctx = &getContext();
     mlir::RewritePatternSet patterns(ctx);
 
-    patterns.insert<ConvertUpdateHaloOp>(ctx);
+    patterns.insert<ConvertUpdateHaloOp, ConvertNeighborsLinearIndicesOp, ConvertProcessLinearIndexOp, ConvertProcessMultiIndexOp>(ctx);
 
     (void)mlir::applyPatternsAndFoldGreedily(getOperation(),
                                              std::move(patterns));
@@ -254,6 +378,6 @@ struct ConvertMeshToMPIPass
 } // namespace
 
 // Create a pass that converts Mesh to MPI
-std::unique_ptr<::mlir::OperationPass<void>> createConvertMeshToMPIPass() {
+std::unique_ptr<::mlir::Pass> mlir::createConvertMeshToMPIPass() {
   return std::make_unique<ConvertMeshToMPIPass>();
 }
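
The two helpers added above emit arith ops for standard row-major
(de)linearization; for reference, a plain C++ equivalent (a sketch, not the
pass code), reusing the numbers from the neighbors example:

  #include <cassert>
  #include <vector>

  // multi -> linear: sum of index * stride, strides accumulated from the right
  long multiToLinear(const std::vector<long> &idx, const std::vector<long> &dims) {
    long linear = 0, stride = 1;
    for (int i = (int)idx.size() - 1; i >= 0; --i) {
      linear += idx[i] * stride;
      stride *= dims[i];
    }
    return linear;
  }

  // linear -> multi: repeated rem/div from the right
  std::vector<long> linearToMulti(long linear, const std::vector<long> &dims) {
    std::vector<long> idx(dims.size());
    for (int i = (int)dims.size() - 1; i >= 0; --i) {
      idx[i] = linear % dims[i];
      linear /= dims[i];
    }
    return idx;
  }

  int main() {
    std::vector<long> dims = {10, 20, 30};
    assert(multiToLinear({1, 1, 3}, dims) == 633);
    assert(linearToMulti(633, dims) == (std::vector<long>{1, 1, 3}));
    return 0;
  }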

>From 7c3eddcac2c9cb9e84f2d59406d759174c4fa2bb Mon Sep 17 00:00:00 2001
From: "Schlimbach, Frank" <frank.schlimbach at intel.com>
Date: Thu, 7 Nov 2024 16:57:50 +0100
Subject: [PATCH 10/15] allow constant shape propagation & fusion through
 static_mpi_rank

---
 mlir/include/mlir/Conversion/Passes.td       |   8 +-
 mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td |   1 +
 mlir/lib/Conversion/MeshToMPI/MeshToMPI.cpp  | 106 +++++++++++++------
 3 files changed, 83 insertions(+), 32 deletions(-)

diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td
index 15fc13f5e12d83..4d6be8d18d1fe6 100644
--- a/mlir/include/mlir/Conversion/Passes.td
+++ b/mlir/include/mlir/Conversion/Passes.td
@@ -886,8 +886,12 @@ def ConvertMeshToMPIPass : Pass<"convert-mesh-to-mpi"> {
   let summary = "Convert Mesh dialect to MPI dialect.";
   let constructor = "mlir::createConvertMeshToMPIPass()";
   let description = [{
-    This pass converts communication operations
-    from the Mesh dialect to the MPI dialect.
+    This pass converts communication operations from the Mesh dialect to the
+    MPI dialect.
+    If it finds a global named "static_mpi_rank", it will use that global's
+    splat value instead of calling MPI_Comm_rank. This enables optimizations
+    such as constant shape propagation and fusion, because shard/partition
+    sizes depend on the rank.
   }];
   let dependentDialects = [
     "memref::MemRefDialect",
diff --git a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td
index 3c52c63330e95f..726c92d6ec4697 100644
--- a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td
+++ b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td
@@ -1091,6 +1091,7 @@ def Mesh_ShiftOp : Mesh_CollectiveCommunicationOpBase<"shift", [
 }
 
 def Mesh_UpdateHaloOp : Mesh_Op<"update_halo", [
+  Pure,
   DestinationStyleOpInterface,
   TypesMatchWith<
     "result has same type as destination",
diff --git a/mlir/lib/Conversion/MeshToMPI/MeshToMPI.cpp b/mlir/lib/Conversion/MeshToMPI/MeshToMPI.cpp
index 0f0cc28ca363a2..f20068c9a43dfd 100644
--- a/mlir/lib/Conversion/MeshToMPI/MeshToMPI.cpp
+++ b/mlir/lib/Conversion/MeshToMPI/MeshToMPI.cpp
@@ -18,11 +18,13 @@
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/Dialect/Mesh/IR/MeshOps.h"
 #include "mlir/Dialect/SCF/IR/SCF.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "mlir/Dialect/Utils/StaticValueUtils.h"
 #include "mlir/IR/Builders.h"
 #include "mlir/IR/BuiltinAttributes.h"
 #include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/PatternMatch.h"
+#include "mlir/IR/SymbolTable.h"
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
 
 #define DEBUG_TYPE "mesh-to-mpi"
@@ -37,14 +39,17 @@ using namespace mlir;
 using namespace mlir::mesh;
 
 namespace {
-static SmallVector<Value> linearToMultiIndex(Location loc, OpBuilder b, Value linearIndex, ValueRange dimensions) {
+// Create operations converting a linear index to a multi-dimensional index
+static SmallVector<Value> linearToMultiIndex(Location loc, OpBuilder b,
+                                             Value linearIndex,
+                                             ValueRange dimensions) {
   int n = dimensions.size();
   SmallVector<Value> multiIndex(n);
 
   for (int i = n - 1; i >= 0; --i) {
     multiIndex[i] = b.create<arith::RemSIOp>(loc, linearIndex, dimensions[i]);
-    if(i > 0) {
+    if (i > 0) {
       linearIndex = b.create<arith::DivSIOp>(loc, linearIndex, dimensions[i]);
     }
   }
@@ -52,13 +57,16 @@ static SmallVector<Value> linearToMultiIndex(Location loc, OpBuilder b, Value li
   return multiIndex;
 }
 
-Value multiToLinearIndex(Location loc, OpBuilder b, ValueRange multiIndex, ValueRange dimensions) {
+// Create operations converting a multi-dimensional index to a linear index
+Value multiToLinearIndex(Location loc, OpBuilder b, ValueRange multiIndex,
+                         ValueRange dimensions) {
+
   auto linearIndex = b.create<arith::ConstantIndexOp>(loc, 0).getResult();
   auto stride = b.create<arith::ConstantIndexOp>(loc, 1).getResult();
 
   for (int i = multiIndex.size() - 1; i >= 0; --i) {
-      auto dimOff = b.create<arith::MulIOp>(loc, multiIndex[i], stride);
-      linearIndex = b.create<arith::AddIOp>(loc, linearIndex, dimOff);
-      stride = b.create<arith::MulIOp>(loc, stride, dimensions[i]);
+    auto dimOff = b.create<arith::MulIOp>(loc, multiIndex[i], stride);
+    linearIndex = b.create<arith::AddIOp>(loc, linearIndex, dimOff);
+    stride = b.create<arith::MulIOp>(loc, stride, dimensions[i]);
   }
 
   return linearIndex;
@@ -76,22 +84,24 @@ struct ConvertProcessMultiIndexOp
     auto loc = op.getLoc();
     auto meshOp = getMesh(op, symbolTableCollection);
     // For now we only support static mesh shapes
-    if(ShapedType::isDynamicShape(meshOp.getShape())) {
+    if (ShapedType::isDynamicShape(meshOp.getShape())) {
       return mlir::failure();
     }
 
     SmallVector<Value> dims;
-    llvm::transform(meshOp.getShape(), std::back_inserter(dims), [&](int64_t i) {
-      return rewriter.create<arith::ConstantIndexOp>(loc, i).getResult();
-    });
-    auto rank = rewriter.create<ProcessLinearIndexOp>(op.getLoc(), meshOp).getResult();
+    llvm::transform(
+        meshOp.getShape(), std::back_inserter(dims), [&](int64_t i) {
+          return rewriter.create<arith::ConstantIndexOp>(loc, i).getResult();
+        });
+    auto rank =
+        rewriter.create<ProcessLinearIndexOp>(op.getLoc(), meshOp).getResult();
     auto mIdx = linearToMultiIndex(loc, rewriter, rank, dims);
 
     // optionally extract subset of mesh axes
     auto axes = op.getAxes();
-    if(!axes.empty()) {
+    if (!axes.empty()) {
       SmallVector<Value> subIndex;
-      for(auto axis : axes) {
+      for (auto axis : axes) {
         subIndex.push_back(mIdx[axis]);
       }
       mIdx = subIndex;
@@ -102,7 +112,9 @@ struct ConvertProcessMultiIndexOp
   }
 };
 
-// This pattern converts the mesh.update_halo operation to MPI calls
+// This pattern converts the mesh.process_linear_index operation to MPI calls.
+// If it finds a global named "static_mpi_rank" it will use that splat value.
+// Otherwise it defaults to mpi.comm_rank.
 struct ConvertProcessLinearIndexOp
     : public mlir::OpRewritePattern<mlir::mesh::ProcessLinearIndexOp> {
   using OpRewritePattern::OpRewritePattern;
@@ -110,8 +122,25 @@ struct ConvertProcessLinearIndexOp
   mlir::LogicalResult
   matchAndRewrite(mlir::mesh::ProcessLinearIndexOp op,
                   mlir::PatternRewriter &rewriter) const override {
-    auto rank = rewriter.create<mpi::CommRankOp>(op.getLoc(), TypeRange{mpi::RetvalType::get(op->getContext()), rewriter.getI32Type()}).getRank();
-    rewriter.replaceOpWithNewOp<arith::IndexCastOp>(op, rewriter.getIndexType(), rank);
+    auto loc = op.getLoc();
+    auto rankOpName = StringAttr::get(op->getContext(), "static_mpi_rank");
+    if (auto globalOp = SymbolTable::lookupNearestSymbolFrom<memref::GlobalOp>(
+            op, rankOpName)) {
+      if (auto initTnsr = globalOp.getInitialValueAttr()) {
+        auto val = cast<DenseElementsAttr>(initTnsr).getSplatValue<int64_t>();
+        rewriter.replaceOp(op,
+                           rewriter.create<arith::ConstantIndexOp>(loc, val));
+        return mlir::success();
+      }
+    }
+    auto rank =
+        rewriter
+            .create<mpi::CommRankOp>(
+                op.getLoc(), TypeRange{mpi::RetvalType::get(op->getContext()),
+                                       rewriter.getI32Type()})
+            .getRank();
+    rewriter.replaceOpWithNewOp<arith::IndexCastOp>(op, rewriter.getIndexType(),
+                                                    rank);
     return mlir::success();
   }
 };
@@ -126,7 +155,7 @@ struct ConvertNeighborsLinearIndicesOp
                   mlir::PatternRewriter &rewriter) const override {
     auto axes = op.getSplitAxes();
     // For now only single axis sharding is supported
-    if(axes.size() != 1) {
+    if (axes.size() != 1) {
       return mlir::failure();
     }
 
@@ -136,26 +165,41 @@ struct ConvertNeighborsLinearIndicesOp
     auto mIdx = op.getDevice();
     auto orgIdx = mIdx[axes[0]];
     SmallVector<Value> dims;
-    llvm::transform(meshOp.getShape(), std::back_inserter(dims), [&](int64_t i) {
-      return rewriter.create<arith::ConstantIndexOp>(loc, i).getResult();
-    });
+    llvm::transform(
+        meshOp.getShape(), std::back_inserter(dims), [&](int64_t i) {
+          return rewriter.create<arith::ConstantIndexOp>(loc, i).getResult();
+        });
     auto dimSz = dims[axes[0]];
     auto minus1 = rewriter.create<arith::ConstantIndexOp>(loc, -1).getResult();
-    auto atBorder = rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::sle, dimSz, rewriter.create<arith::ConstantIndexOp>(loc, 0).getResult());
+    auto atBorder = rewriter.create<arith::CmpIOp>(
+        loc, arith::CmpIPredicate::sle, dimSz,
+        rewriter.create<arith::ConstantIndexOp>(loc, 0).getResult());
     auto down = rewriter.create<scf::IfOp>(
-        loc, atBorder, [&](OpBuilder &builder, Location loc) {
+        loc, atBorder,
+        [&](OpBuilder &builder, Location loc) {
           builder.create<scf::YieldOp>(loc, minus1);
-        }, [&](OpBuilder &builder, Location loc) {
-          mIdx[axes[0]] = rewriter.create<arith::SubIOp>(op.getLoc(), orgIdx, dimSz).getResult();
-          builder.create<scf::YieldOp>(loc, multiToLinearIndex(loc, rewriter, mIdx, dims));
+        },
+        [&](OpBuilder &builder, Location loc) {
+          mIdx[axes[0]] =
+              rewriter.create<arith::SubIOp>(op.getLoc(), orgIdx, dimSz)
+                  .getResult();
+          builder.create<scf::YieldOp>(
+              loc, multiToLinearIndex(loc, rewriter, mIdx, dims));
         });
-    atBorder = rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::sge, dimSz, rewriter.create<arith::AddIOp>(loc, dimSz, minus1).getResult());
+    atBorder = rewriter.create<arith::CmpIOp>(
+        loc, arith::CmpIPredicate::sge, dimSz,
+        rewriter.create<arith::AddIOp>(loc, dimSz, minus1).getResult());
     auto up = rewriter.create<scf::IfOp>(
-        loc, atBorder, [&](OpBuilder &builder, Location loc) {
+        loc, atBorder,
+        [&](OpBuilder &builder, Location loc) {
           builder.create<scf::YieldOp>(loc, minus1);
-        }, [&](OpBuilder &builder, Location loc) {
-          mIdx[axes[0]] = rewriter.create<arith::AddIOp>(op.getLoc(), orgIdx, dimSz).getResult();
-          builder.create<scf::YieldOp>(loc, multiToLinearIndex(loc, rewriter, mIdx, dims));
+        },
+        [&](OpBuilder &builder, Location loc) {
+          mIdx[axes[0]] =
+              rewriter.create<arith::AddIOp>(op.getLoc(), orgIdx, dimSz)
+                  .getResult();
+          builder.create<scf::YieldOp>(
+              loc, multiToLinearIndex(loc, rewriter, mIdx, dims));
         });
     rewriter.replaceOp(op, ValueRange{down.getResult(0), up.getResult(0)});
     return mlir::success();
@@ -368,7 +412,9 @@ struct ConvertMeshToMPIPass
     auto *ctx = &getContext();
     mlir::RewritePatternSet patterns(ctx);
 
-    patterns.insert<ConvertUpdateHaloOp, ConvertNeighborsLinearIndicesOp, ConvertProcessLinearIndexOp, ConvertProcessMultiIndexOp>(ctx);
+    patterns.insert<ConvertUpdateHaloOp, ConvertNeighborsLinearIndicesOp,
+                    ConvertProcessLinearIndexOp, ConvertProcessMultiIndexOp>(
+        ctx);
 
     (void)mlir::applyPatternsAndFoldGreedily(getOperation(),
                                              std::move(patterns));

>From f0695fd50ded1d204a981b87fbf8e3f2ae7081f5 Mon Sep 17 00:00:00 2001
From: "Schlimbach, Frank" <frank.schlimbach at intel.com>
Date: Thu, 7 Nov 2024 19:08:02 +0100
Subject: [PATCH 11/15] fixing send/recv border check

---
 mlir/lib/Conversion/MeshToMPI/MeshToMPI.cpp | 5 ++---
 1 file changed, 2 insertions(+), 3 deletions(-)

diff --git a/mlir/lib/Conversion/MeshToMPI/MeshToMPI.cpp b/mlir/lib/Conversion/MeshToMPI/MeshToMPI.cpp
index f20068c9a43dfd..ee7c77b35b285c 100644
--- a/mlir/lib/Conversion/MeshToMPI/MeshToMPI.cpp
+++ b/mlir/lib/Conversion/MeshToMPI/MeshToMPI.cpp
@@ -47,7 +47,6 @@ static SmallVector<Value> linearToMultiIndex(Location loc, OpBuilder b,
   SmallVector<Value> multiIndex(n);
 
   for (int i = n - 1; i >= 0; --i) {
-    b.create<arith::DivSIOp>(loc, linearIndex, dimensions[i]);
     multiIndex[i] = b.create<arith::RemSIOp>(loc, linearIndex, dimensions[i]);
     if (i > 0) {
       linearIndex = b.create<arith::DivSIOp>(loc, linearIndex, dimensions[i]);
@@ -172,7 +171,7 @@ struct ConvertNeighborsLinearIndicesOp
     auto dimSz = dims[axes[0]];
     auto minus1 = rewriter.create<arith::ConstantIndexOp>(loc, -1).getResult();
     auto atBorder = rewriter.create<arith::CmpIOp>(
-        loc, arith::CmpIPredicate::sle, dimSz,
+        loc, arith::CmpIPredicate::sle, orgIdx,
         rewriter.create<arith::ConstantIndexOp>(loc, 0).getResult());
     auto down = rewriter.create<scf::IfOp>(
         loc, atBorder,
@@ -187,7 +186,7 @@ struct ConvertNeighborsLinearIndicesOp
               loc, multiToLinearIndex(loc, rewriter, mIdx, dims));
         });
     atBorder = rewriter.create<arith::CmpIOp>(
-        loc, arith::CmpIPredicate::sge, dimSz,
+        loc, arith::CmpIPredicate::sge, orgIdx,
         rewriter.create<arith::AddIOp>(loc, dimSz, minus1).getResult());
     auto up = rewriter.create<scf::IfOp>(
         loc, atBorder,

>From 725c7343d1ce0af1e165665fbc8dab4c2162ad8b Mon Sep 17 00:00:00 2001
From: "Schlimbach, Frank" <frank.schlimbach at intel.com>
Date: Fri, 8 Nov 2024 12:38:14 +0100
Subject: [PATCH 12/15] fixes and tests

---
 mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td  |  19 +-
 mlir/lib/Conversion/MeshToMPI/MeshToMPI.cpp   |  20 +-
 .../MeshToMPI/convert-mesh-to-mpi.mlir        | 421 +++++++++---------
 3 files changed, 226 insertions(+), 234 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td
index 726c92d6ec4697..6039e61a93fadc 100644
--- a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td
+++ b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td
@@ -162,20 +162,21 @@ def Mesh_NeighborsLinearIndicesOp : Mesh_Op<"neighbors_linear_indices", [
   DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>
 ]> {
   let summary =
-      "For given split axes get the linear indices of the direct neighbor processes.";
+      "For given mesh index get the linear indices of the direct neighbor processes along the given split.";
   let description = [{
     Example:
     ```
-    %idx = mesh.neighbor_linear_index on @mesh for $device 
-               split_axes = $split_axes : index
+    mesh.mesh @mesh0(shape = 10x20x30)
+    %c1 = arith.constant 1 : index
+    %c2 = arith.constant 2 : index
+    %c3 = arith.constant 3 : index
+    %idx = mesh.neighbors_linear_indices on @mesh[%c1, %c2, %c3] split_axes = [1] : index
     ```
-    Given `@mesh` with shape `(10, 20, 30)`,
-          `device` = `(1, 2, 3)`
-          `$split_axes` = `[1]`
-    returns two indices, `633` and `693`, which correspond to the index of the previous
-    process `(1, 1, 3)`, and the next process `(1, 3, 3) along the split axis `1`.
+    The above returns two indices, `633` and `693`, which correspond to the
+    indices of the previous process `(1, 1, 3)` and the next process
+    `(1, 3, 3)` along the split axis `1`.
 
-    A negative value is returned if `$device` has no neighbor in the given
+    A negative value is returned if there is no neighbor in the respective
     direction along the given `split_axes`.
   }];
   let arguments = (ins FlatSymbolRefAttr:$mesh,
diff --git a/mlir/lib/Conversion/MeshToMPI/MeshToMPI.cpp b/mlir/lib/Conversion/MeshToMPI/MeshToMPI.cpp
index ee7c77b35b285c..c51c5335fc6092 100644
--- a/mlir/lib/Conversion/MeshToMPI/MeshToMPI.cpp
+++ b/mlir/lib/Conversion/MeshToMPI/MeshToMPI.cpp
@@ -64,7 +64,8 @@ Value multiToLinearIndex(Location loc, OpBuilder b, ValueRange multiIndex,
   auto stride = b.create<arith::ConstantIndexOp>(loc, 1).getResult();
 
   for (int i = multiIndex.size() - 1; i >= 0; --i) {
-    linearIndex = b.create<arith::AddIOp>(loc, multiIndex[i], stride);
+    auto off = b.create<arith::MulIOp>(loc, multiIndex[i], stride);
+    linearIndex = b.create<arith::AddIOp>(loc, linearIndex, off);
     stride = b.create<arith::MulIOp>(loc, stride, dimensions[i]);
   }
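This turns the loop into the usual row-major multiply-accumulate, i.e. linearIndex = i2 + i1 * d2 + i0 * (d1 * d2) for a 3-d mesh. As a worked check against the tests: for @mesh0(shape = 3x4x5) and multi-index (1, 0, 4) this gives 4 + 0 * 5 + 1 * (4 * 5) = 24, consistent with the static_mpi_rank = dense<24> used below.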
 
@@ -169,6 +170,7 @@ struct ConvertNeighborsLinearIndicesOp
           return rewriter.create<arith::ConstantIndexOp>(loc, i).getResult();
         });
     auto dimSz = dims[axes[0]];
+    auto one = rewriter.create<arith::ConstantIndexOp>(loc, 1).getResult();
     auto minus1 = rewriter.create<arith::ConstantIndexOp>(loc, -1).getResult();
     auto atBorder = rewriter.create<arith::CmpIOp>(
         loc, arith::CmpIPredicate::sle, orgIdx,
@@ -179,26 +181,28 @@ struct ConvertNeighborsLinearIndicesOp
           builder.create<scf::YieldOp>(loc, minus1);
         },
         [&](OpBuilder &builder, Location loc) {
-          mIdx[axes[0]] =
-              rewriter.create<arith::SubIOp>(op.getLoc(), orgIdx, dimSz)
+          SmallVector<Value> tmp = mIdx;
+          tmp[axes[0]] =
+              rewriter.create<arith::SubIOp>(op.getLoc(), orgIdx, one)
                   .getResult();
           builder.create<scf::YieldOp>(
-              loc, multiToLinearIndex(loc, rewriter, mIdx, dims));
+              loc, multiToLinearIndex(loc, rewriter, tmp, dims));
         });
     atBorder = rewriter.create<arith::CmpIOp>(
         loc, arith::CmpIPredicate::sge, orgIdx,
-        rewriter.create<arith::AddIOp>(loc, dimSz, minus1).getResult());
+        rewriter.create<arith::SubIOp>(loc, dimSz, one).getResult());
     auto up = rewriter.create<scf::IfOp>(
         loc, atBorder,
         [&](OpBuilder &builder, Location loc) {
           builder.create<scf::YieldOp>(loc, minus1);
         },
         [&](OpBuilder &builder, Location loc) {
-          mIdx[axes[0]] =
-              rewriter.create<arith::AddIOp>(op.getLoc(), orgIdx, dimSz)
+          SmallVector<Value> tmp = mIdx;
+          tmp[axes[0]] =
+              rewriter.create<arith::AddIOp>(op.getLoc(), orgIdx, one)
                   .getResult();
           builder.create<scf::YieldOp>(
-              loc, multiToLinearIndex(loc, rewriter, mIdx, dims));
+              loc, multiToLinearIndex(loc, rewriter, tmp, dims));
         });
     rewriter.replaceOp(op, ValueRange{down.getResult(0), up.getResult(0)});
     return mlir::success();
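With the border checks now comparing orgIdx against 0 and dimSz - 1, and stepping by one instead of by dimSz, the results line up with the tests below: on @mesh0(shape = 3x4x5) at device (1, 0, 4), splitting along axis 0 yields the previous process (0, 0, 4) with linear index 4 and the next process (2, 0, 4) with linear index 44; along axis 1 the device sits at the lower border, so the previous index is -1 and the next, (1, 1, 4), is 29.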
diff --git a/mlir/test/Conversion/MeshToMPI/convert-mesh-to-mpi.mlir b/mlir/test/Conversion/MeshToMPI/convert-mesh-to-mpi.mlir
index d05c53bd83aaf9..38b7a12daef52b 100644
--- a/mlir/test/Conversion/MeshToMPI/convert-mesh-to-mpi.mlir
+++ b/mlir/test/Conversion/MeshToMPI/convert-mesh-to-mpi.mlir
@@ -1,229 +1,216 @@
-// RUN: mlir-opt %s -convert-mesh-to-mpi | FileCheck %s
+// RUN: mlir-opt %s -convert-mesh-to-mpi -canonicalize -split-input-file | FileCheck %s
 
+// -----
 // CHECK: mesh.mesh @mesh0
-mesh.mesh @mesh0(shape = 2x2x4)
+mesh.mesh @mesh0(shape = 3x4x5)
+func.func @process_multi_index() -> (index, index, index) {
+  // CHECK: mpi.comm_rank : !mpi.retval, i32
+  // CHECK-DAG: %[[v4:.*]] = arith.remsi
+  // CHECK-DAG: %[[v0:.*]] = arith.remsi
+  // CHECK-DAG: %[[v1:.*]] = arith.remsi
+  %0:3 = mesh.process_multi_index on @mesh0 axes = [] : index, index, index
+  // CHECK: return %[[v1]], %[[v0]], %[[v4]] : index, index, index
+  return %0#0, %0#1, %0#2 : index, index, index
+}
 
-// CHECK-LABEL: func @update_halo_1d_first
-func.func @update_halo_1d_first(
-  // CHECK-SAME: [[varg0:%.*]]: memref<12x12xi8>
-  %arg0 : memref<12x12xi8>) {
-  // CHECK-NEXT: [[vc7:%.*]] = arith.constant 7 : index
-  // CHECK-NEXT: [[vc9:%.*]] = arith.constant 9 : index
-  // CHECK-NEXT: [[vc2:%.*]] = arith.constant 2 : index
-  // CHECK-NEXT: [[vc91_i32:%.*]] = arith.constant 91 : i32
-  // CHECK-NEXT: [[vc0_i32:%.*]] = arith.constant 0 : i32
-  // CHECK-NEXT: [[vproc_linear_idx:%.*]]:3 = mesh.process_multi_index on @mesh0 : index, index, index
-  // CHECK-NEXT: [[vdown_linear_idx:%.*]], [[vup_linear_idx:%.*]] = mesh.neighbors_linear_indices on @mesh0[[[vproc_linear_idx]]#0, [[vproc_linear_idx]]#1, [[vproc_linear_idx]]#2] split_axes = [0] : index, index
-  // CHECK-NEXT: [[v0:%.*]] = arith.index_cast [[vdown_linear_idx]] : index to i32
-  // CHECK-NEXT: [[v1:%.*]] = arith.index_cast [[vup_linear_idx]] : index to i32
-  // CHECK-NEXT: [[v2:%.*]] = arith.cmpi sge, [[v1]], [[vc0_i32]] : i32
-  // CHECK-NEXT: [[v3:%.*]] = arith.cmpi sge, [[v0]], [[vc0_i32]] : i32
-  // CHECK-NEXT: [[valloc:%.*]] = memref.alloc() : memref<2x12xi8>
-  // CHECK-NEXT: scf.if [[v3]] {
-  // CHECK-NEXT:   [[vsubview:%.*]] = memref.subview [[varg0]][[[vc7]], 0] [2, 12] [1, 1] : memref<12x12xi8> to memref<2x12xi8, strided<[12, 1], offset: ?>>
-  // CHECK-NEXT:   memref.copy [[vsubview]], [[valloc]] : memref<2x12xi8, strided<[12, 1], offset: ?>> to memref<2x12xi8>
-  // CHECK-NEXT:   mpi.send([[valloc]], [[vc91_i32]], [[v0]]) : memref<2x12xi8>, i32, i32
-  // CHECK-NEXT: }
-  // CHECK-NEXT: scf.if [[v2]] {
-  // CHECK-NEXT:   mpi.recv([[valloc]], [[vc91_i32]], [[v1]]) : memref<2x12xi8>, i32, i32
-  // CHECK-NEXT:   [[vsubview:%.*]] = memref.subview [[varg0]][0, 0] [2, 12] [1, 1] : memref<12x12xi8> to memref<2x12xi8, strided<[12, 1]>>
-  // CHECK-NEXT:   memref.copy [[valloc]], [[vsubview]] : memref<2x12xi8> to memref<2x12xi8, strided<[12, 1]>>
-  // CHECK-NEXT: }
-  // CHECK-NEXT: memref.dealloc [[valloc]] : memref<2x12xi8>
-  // CHECK-NEXT: [[v4:%.*]] = arith.cmpi sge, [[v0]], [[vc0_i32]] : i32
-  // CHECK-NEXT: [[v5:%.*]] = arith.cmpi sge, [[v1]], [[vc0_i32]] : i32
-  // CHECK-NEXT: [[valloc_0:%.*]] = memref.alloc() : memref<3x12xi8>
-  // CHECK-NEXT: scf.if [[v5]] {
-  // CHECK-NEXT:   [[vsubview:%.*]] = memref.subview [[varg0]][[[vc2]], 0] [3, 12] [1, 1] : memref<12x12xi8> to memref<3x12xi8, strided<[12, 1], offset: ?>>
-  // CHECK-NEXT:   memref.copy [[vsubview]], [[valloc_0]] : memref<3x12xi8, strided<[12, 1], offset: ?>> to memref<3x12xi8>
-  // CHECK-NEXT:   mpi.send([[valloc_0]], [[vc91_i32]], [[v1]]) : memref<3x12xi8>, i32, i32
-  // CHECK-NEXT: }
-  // CHECK-NEXT: scf.if [[v4]] {
-  // CHECK-NEXT:   mpi.recv([[valloc_0]], [[vc91_i32]], [[v0]]) : memref<3x12xi8>, i32, i32
-  // CHECK-NEXT:   [[vsubview:%.*]] = memref.subview [[varg0]][[[vc9]], 0] [3, 12] [1, 1] : memref<12x12xi8> to memref<3x12xi8, strided<[12, 1], offset: ?>>
-  // CHECK-NEXT:   memref.copy [[valloc_0]], [[vsubview]] : memref<3x12xi8> to memref<3x12xi8, strided<[12, 1], offset: ?>>
-  // CHECK-NEXT: }
-  // CHECK-NEXT: memref.dealloc [[valloc_0]] : memref<3x12xi8>
-  mesh.update_halo %arg0 on @mesh0 split_axes = [[0]]
-    halo_sizes = [2, 3] : memref<12x12xi8>
-  // CHECK-NEXT: return
-  return
+// CHECK-LABEL: func @process_linear_index
+func.func @process_linear_index() -> index {
+  // CHECK: %[[RES:.*]], %[[rank:.*]] = mpi.comm_rank : !mpi.retval, i32
+  // CHECK: %[[cast:.*]] = arith.index_cast %[[rank]] : i32 to index
+  %0 = mesh.process_linear_index on @mesh0 : index
+  // CHECK: return %[[cast]] : index
+  return %0 : index
 }
 
-// CHECK-LABEL: func @update_halo_1d_second
-func.func @update_halo_1d_second(
-  // CHECK-SAME: [[varg0:%.*]]: memref<12x12xi8>
-  %arg0 : memref<12x12xi8>) -> memref<12x12xi8> {
-  //CHECK-NEXT: [[vc7:%.*]] = arith.constant 7 : index
-  //CHECK-NEXT: [[vc9:%.*]] = arith.constant 9 : index
-  //CHECK-NEXT: [[vc2:%.*]] = arith.constant 2 : index
-  //CHECK-NEXT: [[vc91_i32:%.*]] = arith.constant 91 : i32
-  //CHECK-NEXT: [[vc0_i32:%.*]] = arith.constant 0 : i32
-  //CHECK-NEXT: [[vproc_linear_idx:%.*]]:3 = mesh.process_multi_index on @mesh0 : index, index, index
-  //CHECK-NEXT: [[vdown_linear_idx:%.*]], [[vup_linear_idx:%.*]] = mesh.neighbors_linear_indices on @mesh0[[[vproc_linear_idx]]#0, [[vproc_linear_idx]]#1, [[vproc_linear_idx]]#2] split_axes = [3] : index, index
-  //CHECK-NEXT: [[v0:%.*]] = arith.index_cast [[vdown_linear_idx]] : index to i32
-  //CHECK-NEXT: [[v1:%.*]] = arith.index_cast [[vup_linear_idx]] : index to i32
-  //CHECK-NEXT: [[v2:%.*]] = arith.cmpi sge, [[v1]], [[vc0_i32]] : i32
-  //CHECK-NEXT: [[v3:%.*]] = arith.cmpi sge, [[v0]], [[vc0_i32]] : i32
-  //CHECK-NEXT: [[valloc:%.*]] = memref.alloc() : memref<12x2xi8>
-  //CHECK-NEXT: scf.if [[v3]] {
-  //CHECK-NEXT:   [[vsubview:%.*]] = memref.subview [[varg0]][0, %c7] [12, 2] [1, 1] : memref<12x12xi8> to memref<12x2xi8, strided<[12, 1], offset: ?>>
-  //CHECK-NEXT:   memref.copy [[vsubview]], [[valloc]] : memref<12x2xi8, strided<[12, 1], offset: ?>> to memref<12x2xi8>
-  //CHECK-NEXT:   mpi.send([[valloc]], [[vc91_i32]], [[v0]]) : memref<12x2xi8>, i32, i32
-  //CHECK-NEXT: }
-  //CHECK-NEXT: scf.if [[v2]] {
-  //CHECK-NEXT:   mpi.recv([[valloc]], [[vc91_i32]], [[v1]]) : memref<12x2xi8>, i32, i32
-  //CHECK-NEXT:   [[vsubview:%.*]] = memref.subview [[varg0]][0, 0] [12, 2] [1, 1] : memref<12x12xi8> to memref<12x2xi8, strided<[12, 1]>>
-  //CHECK-NEXT:   memref.copy [[valloc]], [[vsubview]] : memref<12x2xi8> to memref<12x2xi8, strided<[12, 1]>>
-  //CHECK-NEXT: }
-  //CHECK-NEXT: memref.dealloc [[valloc]] : memref<12x2xi8>
-  //CHECK-NEXT: [[v4:%.*]] = arith.cmpi sge, [[v0]], [[vc0_i32]] : i32
-  //CHECK-NEXT: [[v5:%.*]] = arith.cmpi sge, [[v1]], [[vc0_i32]] : i32
-  //CHECK-NEXT: [[valloc_0:%.*]] = memref.alloc() : memref<12x3xi8>
-  //CHECK-NEXT: scf.if [[v5]] {
-  //CHECK-NEXT:   [[vsubview:%.*]] = memref.subview [[varg0]][0, %c2] [12, 3] [1, 1] : memref<12x12xi8> to memref<12x3xi8, strided<[12, 1], offset: ?>>
-  //CHECK-NEXT:   memref.copy [[vsubview]], [[valloc_0]] : memref<12x3xi8, strided<[12, 1], offset: ?>> to memref<12x3xi8>
-  //CHECK-NEXT:   mpi.send([[valloc_0]], [[vc91_i32]], [[v1]]) : memref<12x3xi8>, i32, i32
-  //CHECK-NEXT: }
-  //CHECK-NEXT: scf.if [[v4]] {
-  //CHECK-NEXT:   mpi.recv([[valloc_0]], [[vc91_i32]], [[v0]]) : memref<12x3xi8>, i32, i32
-  //CHECK-NEXT:   [[vsubview:%.*]] = memref.subview [[varg0]][0, %c9] [12, 3] [1, 1] : memref<12x12xi8> to memref<12x3xi8, strided<[12, 1], offset: ?>>
-  //CHECK-NEXT:   memref.copy [[valloc_0]], [[vsubview]] : memref<12x3xi8> to memref<12x3xi8, strided<[12, 1], offset: ?>>
-  //CHECK-NEXT: }
-  //CHECK-NEXT: memref.dealloc [[valloc_0]] : memref<12x3xi8>
-  %res = mesh.update_halo %arg0 on @mesh0 split_axes = [[], [3]]
-    halo_sizes = [2, 3] : memref<12x12xi8>
-  //CHECK-NEXT: return [[varg0]] : memref<12x12xi8>
-  return %res : memref<12x12xi8>
+// CHECK-LABEL: func @neighbors_dim0
+func.func @neighbors_dim0(%arg0 : tensor<120x120x120xi8>) -> (index, index) {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %c4 = arith.constant 4 : index
+  // CHECK-DAG: [[up:%.*]] = arith.constant 44 : index
+  // CHECK-DAG: [[down:%.*]] = arith.constant 4 : index
+  %idx:2 = mesh.neighbors_linear_indices on @mesh0[%c1, %c0, %c4] split_axes = [0] : index, index
+  // CHECK: return [[down]], [[up]] : index, index
+  return %idx#0, %idx#1 : index, index
+}
+
+// CHECK-LABEL: func @neighbors_dim1
+func.func @neighbors_dim1(%arg0 : tensor<120x120x120xi8>) -> (index, index) {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %c4 = arith.constant 4 : index
+  // CHECK-DAG: [[up:%.*]] = arith.constant 29 : index
+  // CHECK-DAG: [[down:%.*]] = arith.constant -1 : index
+  %idx:2 = mesh.neighbors_linear_indices on @mesh0[%c1, %c0, %c4] split_axes = [1] : index, index
+  // CHECK: return [[down]], [[up]] : index, index
+  return %idx#0, %idx#1 : index, index
+}
+
+// CHECK-LABEL: func @neighbors_dim2
+func.func @neighbors_dim2(%arg0 : tensor<120x120x120xi8>) -> (index, index) {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %c4 = arith.constant 4 : index
+  // CHECK-DAG: [[up:%.*]] = arith.constant -1 : index
+  // CHECK-DAG: [[down:%.*]] = arith.constant 23 : index
+  %idx:2 = mesh.neighbors_linear_indices on @mesh0[%c1, %c0, %c4] split_axes = [2] : index, index
+  // CHECK: return [[down]], [[up]] : index, index
+  return %idx#0, %idx#1 : index, index
+}
+
+// -----
+// CHECK: mesh.mesh @mesh0
+mesh.mesh @mesh0(shape = 3x4x5)
+memref.global constant @static_mpi_rank : memref<index> = dense<24>
+func.func @process_multi_index() -> (index, index, index) {
+  // CHECK-DAG: %[[c4:.*]] = arith.constant 4 : index
+  // CHECK-DAG: %[[c0:.*]] = arith.constant 0 : index
+  // CHECK-DAG: %[[c1:.*]] = arith.constant 1 : index
+  %0:3 = mesh.process_multi_index on @mesh0 axes = [] : index, index, index
+  // CHECK: return %[[c1]], %[[c0]], %[[c4]] : index, index, index
+  return %0#0, %0#1, %0#2 : index, index, index
+}
+
+// CHECK-LABEL: func @process_linear_index
+func.func @process_linear_index() -> index {
+  // CHECK: %[[c24:.*]] = arith.constant 24 : index
+  %0 = mesh.process_linear_index on @mesh0 : index
+  // CHECK: return %[[c24]] : index
+  return %0 : index
+}
+
+// -----
+mesh.mesh @mesh0(shape = 3x4x5)
+// CHECK-LABEL: func @update_halo_1d_first
+func.func @update_halo_1d_first(
+  // CHECK-SAME: [[arg0:%.*]]: memref<120x120x120xi8>
+  %arg0 : memref<120x120x120xi8>) -> memref<120x120x120xi8> {
+  // CHECK: memref.subview [[arg0]][115, 0, 0] [2, 120, 120] [1, 1, 1] : memref<120x120x120xi8> to memref<2x120x120xi8
+  // CHECK: mpi.send(
+  // CHECK-SAME: : memref<2x120x120xi8>, i32, i32
+  // CHECK: mpi.recv(
+  // CHECK-SAME: : memref<2x120x120xi8>, i32, i32
+  // CHECK-NEXT: memref.subview [[arg0]][0, 0, 0] [2, 120, 120] [1, 1, 1] : memref<120x120x120xi8> to memref<2x120x120xi8
+  // CHECK: memref.subview [[arg0]][2, 0, 0] [3, 120, 120] [1, 1, 1] : memref<120x120x120xi8> to memref<3x120x120xi8
+  // CHECK: mpi.send(
+  // CHECK-SAME: : memref<3x120x120xi8>, i32, i32
+  // CHECK: mpi.recv(
+  // CHECK-SAME: : memref<3x120x120xi8>, i32, i32
+  // CHECK-NEXT: memref.subview [[arg0]][117, 0, 0] [3, 120, 120] [1, 1, 1] : memref<120x120x120xi8> to memref<3x120x120xi8
+  %res = mesh.update_halo %arg0 on @mesh0 split_axes = [[0]] halo_sizes = [2, 3] : memref<120x120x120xi8>
+  // CHECK: return [[res:%.*]] : memref<120x120x120xi8>
+  return %res : memref<120x120x120xi8>
 }
 
-// CHECK-LABEL: func @update_halo_2d
-func.func @update_halo_2d(
-    // CHECK-SAME: [[varg0:%.*]]: memref<12x12xi8>
-    %arg0 : memref<12x12xi8>) -> memref<12x12xi8> {
-  // CHECK-NEXT: [[vc10:%.*]] = arith.constant 10 : index
-  // CHECK-NEXT: [[vc1:%.*]] = arith.constant 1 : index
-  // CHECK-NEXT: [[vc5:%.*]] = arith.constant 5 : index
-  // CHECK-NEXT: [[vc8:%.*]] = arith.constant 8 : index
-  // CHECK-NEXT: [[vc3:%.*]] = arith.constant 3 : index
-  // CHECK-NEXT: [[vc9:%.*]] = arith.constant 9 : index
+// -----
+mesh.mesh @mesh0(shape = 3x4x5)
+memref.global constant @static_mpi_rank : memref<index> = dense<24>
+// CHECK-LABEL: func @update_halo_3d
+func.func @update_halo_3d(
+  // CHECK-SAME: [[varg0:%.*]]: memref<120x120x120xi8>
+  %arg0 : memref<120x120x120xi8>) -> memref<120x120x120xi8> {
+  // CHECK: [[vc23_i32:%.*]] = arith.constant 23 : i32
+  // CHECK-NEXT: [[vc29_i32:%.*]] = arith.constant 29 : i32
   // CHECK-NEXT: [[vc91_i32:%.*]] = arith.constant 91 : i32
-  // CHECK-NEXT: [[vc0_i32:%.*]] = arith.constant 0 : i32
-  // CHECK-NEXT: [[vproc_linear_idx:%.*]]:3 = mesh.process_multi_index on @mesh0 : index, index, index
-  // CHECK-NEXT: [[vdown_linear_idx:%.*]], [[vup_linear_idx:%.*]] = mesh.neighbors_linear_indices on @mesh0[[[vproc_linear_idx]]#0, [[vproc_linear_idx]]#1, [[vproc_linear_idx]]#2] split_axes = [1] : index, index
-  // CHECK-NEXT: [[v0:%.*]] = arith.index_cast [[vdown_linear_idx]] : index to i32
-  // CHECK-NEXT: [[v1:%.*]] = arith.index_cast [[vup_linear_idx]] : index to i32
-  // CHECK-NEXT: [[v2:%.*]] = arith.cmpi sge, [[v1]], [[vc0_i32]] : i32
-  // CHECK-NEXT: [[v3:%.*]] = arith.cmpi sge, [[v0]], [[vc0_i32]] : i32
-  // CHECK-NEXT: [[valloc:%.*]] = memref.alloc([[vc9]]) : memref<?x3xi8>
-  // CHECK-NEXT: scf.if [[v3]] {
-  // CHECK-NEXT:   [[vsubview:%.*]] = memref.subview [[varg0]][1, %c5] [[[vc9]], 3] [1, 1] : memref<12x12xi8> to memref<?x3xi8, strided<[12, 1], offset: ?>>
-  // CHECK-NEXT:   memref.copy [[vsubview]], [[valloc]] : memref<?x3xi8, strided<[12, 1], offset: ?>> to memref<?x3xi8>
-  // CHECK-NEXT:   mpi.send([[valloc]], [[vc91_i32]], [[v0]]) : memref<?x3xi8>, i32, i32
-  // CHECK-NEXT: }
-  // CHECK-NEXT: scf.if [[v2]] {
-  // CHECK-NEXT:   mpi.recv([[valloc]], [[vc91_i32]], [[v1]]) : memref<?x3xi8>, i32, i32
-  // CHECK-NEXT:   [[vsubview:%.*]] = memref.subview [[varg0]][1, 0] [[[vc9]], 3] [1, 1] : memref<12x12xi8> to memref<?x3xi8, strided<[12, 1], offset: 12>>
-  // CHECK-NEXT:   memref.copy [[valloc]], [[vsubview]] : memref<?x3xi8> to memref<?x3xi8, strided<[12, 1], offset: 12>>
-  // CHECK-NEXT: }
-  // CHECK-NEXT: memref.dealloc [[valloc]] : memref<?x3xi8>
-  // CHECK-NEXT: [[v4:%.*]] = arith.cmpi sge, [[v0]], [[vc0_i32]] : i32
-  // CHECK-NEXT: [[v5:%.*]] = arith.cmpi sge, [[v1]], [[vc0_i32]] : i32
-  // CHECK-NEXT: [[valloc_0:%.*]] = memref.alloc([[vc9]]) : memref<?x4xi8>
-  // CHECK-NEXT: scf.if [[v5]] {
-  // CHECK-NEXT:   [[vsubview:%.*]] = memref.subview [[varg0]][1, %c3] [[[vc9]], 4] [1, 1] : memref<12x12xi8> to memref<?x4xi8, strided<[12, 1], offset: ?>>
-  // CHECK-NEXT:   memref.copy [[vsubview]], [[valloc_0]] : memref<?x4xi8, strided<[12, 1], offset: ?>> to memref<?x4xi8>
-  // CHECK-NEXT:   mpi.send([[valloc_0]], [[vc91_i32]], [[v1]]) : memref<?x4xi8>, i32, i32
-  // CHECK-NEXT: }
-  // CHECK-NEXT: scf.if [[v4]] {
-  // CHECK-NEXT:   mpi.recv([[valloc_0]], [[vc91_i32]], [[v0]]) : memref<?x4xi8>, i32, i32
-  // CHECK-NEXT:   [[vsubview:%.*]] = memref.subview [[varg0]][1, %c8] [[[vc9]], 4] [1, 1] : memref<12x12xi8> to memref<?x4xi8, strided<[12, 1], offset: ?>>
-  // CHECK-NEXT:   memref.copy [[valloc_0]], [[vsubview]] : memref<?x4xi8> to memref<?x4xi8, strided<[12, 1], offset: ?>>
-  // CHECK-NEXT: }
-  // CHECK-NEXT: memref.dealloc [[valloc_0]] : memref<?x4xi8>
-  // CHECK-NEXT: [[vdown_linear_idx_1:%.*]], [[vup_linear_idx_2:%.*]] = mesh.neighbors_linear_indices on @mesh0[[[vproc_linear_idx]]#0, [[vproc_linear_idx]]#1, [[vproc_linear_idx]]#2] split_axes = [0] : index, index
-  // CHECK-NEXT: [[v6:%.*]] = arith.index_cast [[vdown_linear_idx_1]] : index to i32
-  // CHECK-NEXT: [[v7:%.*]] = arith.index_cast [[vup_linear_idx_2]] : index to i32
-  // CHECK-NEXT: [[v8:%.*]] = arith.cmpi sge, [[v7]], [[vc0_i32]] : i32
-  // CHECK-NEXT: [[v9:%.*]] = arith.cmpi sge, [[v6]], [[vc0_i32]] : i32
-  // CHECK-NEXT: [[valloc_3:%.*]] = memref.alloc() : memref<1x12xi8>
-  // CHECK-NEXT: scf.if [[v9]] {
-  // CHECK-NEXT:   [[vsubview:%.*]] = memref.subview [[varg0]][[[vc9]], 0] [1, 12] [1, 1] : memref<12x12xi8> to memref<1x12xi8, strided<[12, 1], offset: ?>>
-  // CHECK-NEXT:   memref.copy [[vsubview]], [[valloc_3]] : memref<1x12xi8, strided<[12, 1], offset: ?>> to memref<1x12xi8>
-  // CHECK-NEXT:   mpi.send([[valloc_3]], [[vc91_i32]], [[v6]]) : memref<1x12xi8>, i32, i32
-  // CHECK-NEXT: }
-  // CHECK-NEXT: scf.if [[v8]] {
-  // CHECK-NEXT:   mpi.recv([[valloc_3]], [[vc91_i32]], [[v7]]) : memref<1x12xi8>, i32, i32
-  // CHECK-NEXT:   [[vsubview:%.*]] = memref.subview [[varg0]][0, 0] [1, 12] [1, 1] : memref<12x12xi8> to memref<1x12xi8, strided<[12, 1]>>
-  // CHECK-NEXT:   memref.copy [[valloc_3]], [[vsubview]] : memref<1x12xi8> to memref<1x12xi8, strided<[12, 1]>>
-  // CHECK-NEXT: }
-  // CHECK-NEXT: memref.dealloc [[valloc_3]] : memref<1x12xi8>
-  // CHECK-NEXT: [[v10:%.*]] = arith.cmpi sge, [[v6]], [[vc0_i32]] : i32
-  // CHECK-NEXT: [[v11:%.*]] = arith.cmpi sge, [[v7]], [[vc0_i32]] : i32
-  // CHECK-NEXT: [[valloc_4:%.*]] = memref.alloc() : memref<2x12xi8>
-  // CHECK-NEXT: scf.if [[v11]] {
-  // CHECK-NEXT:   [[vsubview:%.*]] = memref.subview [[varg0]][[[vc1]], 0] [2, 12] [1, 1] : memref<12x12xi8> to memref<2x12xi8, strided<[12, 1], offset: ?>>
-  // CHECK-NEXT:   memref.copy [[vsubview]], [[valloc_4]] : memref<2x12xi8, strided<[12, 1], offset: ?>> to memref<2x12xi8>
-  // CHECK-NEXT:   mpi.send([[valloc_4]], [[vc91_i32]], [[v7]]) : memref<2x12xi8>, i32, i32
-  // CHECK-NEXT: }
-  // CHECK-NEXT: scf.if [[v10]] {
-  // CHECK-NEXT:   mpi.recv([[valloc_4]], [[vc91_i32]], [[v6]]) : memref<2x12xi8>, i32, i32
-  // CHECK-NEXT:   [[vsubview:%.*]] = memref.subview [[varg0]][[[vc10]], 0] [2, 12] [1, 1] : memref<12x12xi8> to memref<2x12xi8, strided<[12, 1], offset: ?>>
-  // CHECK-NEXT:   memref.copy [[valloc_4]], [[vsubview]] : memref<2x12xi8> to memref<2x12xi8, strided<[12, 1], offset: ?>>
-  // CHECK-NEXT: }
-  // CHECK-NEXT: memref.dealloc [[valloc_4]] : memref<2x12xi8>
-  %res = mesh.update_halo %arg0 on @mesh0 split_axes = [[0], [1]]
-      halo_sizes = [1, 2, 3, 4]
-      : memref<12x12xi8>
-  // CHECK-NEXT: return [[varg0]] : memref<12x12xi8>
-  return %res : memref<12x12xi8>
+  // CHECK-NEXT: [[vc4_i32:%.*]] = arith.constant 4 : i32
+  // CHECK-NEXT: [[vc44_i32:%.*]] = arith.constant 44 : i32
+  // CHECK-NEXT: [[valloc:%.*]] = memref.alloc() : memref<117x113x5xi8>
+  // CHECK-NEXT: [[vcast:%.*]] = memref.cast [[valloc]] : memref<117x113x5xi8> to memref<?x?x5xi8>
+  // CHECK-NEXT: [[vsubview:%.*]] = memref.subview [[varg0]][1, 3, 109] [117, 113, 5] [1, 1, 1] : memref<120x120x120xi8> to memref<117x113x5xi8, strided<[14400, 120, 1], offset: 14869>>
+  // CHECK-NEXT: memref.copy [[vsubview]], [[valloc]] : memref<117x113x5xi8, strided<[14400, 120, 1], offset: 14869>> to memref<117x113x5xi8>
+  // CHECK-NEXT: mpi.send([[vcast]], [[vc91_i32]], [[vc4_i32]]) : memref<?x?x5xi8>, i32, i32
+  // CHECK-NEXT: mpi.recv([[vcast]], [[vc91_i32]], [[vc44_i32]]) : memref<?x?x5xi8>, i32, i32
+  // CHECK-NEXT: [[vsubview_0:%.*]] = memref.subview [[varg0]][1, 3, 0] [117, 113, 5] [1, 1, 1] : memref<120x120x120xi8> to memref<117x113x5xi8, strided<[14400, 120, 1], offset: 14760>>
+  // CHECK-NEXT: memref.copy [[valloc]], [[vsubview_0]] : memref<117x113x5xi8> to memref<117x113x5xi8, strided<[14400, 120, 1], offset: 14760>>
+  // CHECK-NEXT: memref.dealloc [[valloc]] : memref<117x113x5xi8>
+  // CHECK-NEXT: [[valloc_1:%.*]] = memref.alloc() : memref<117x113x6xi8>
+  // CHECK-NEXT: [[vcast_2:%.*]] = memref.cast [[valloc_1]] : memref<117x113x6xi8> to memref<?x?x6xi8>
+  // CHECK-NEXT: [[vsubview_3:%.*]] = memref.subview [[varg0]][1, 3, 5] [117, 113, 6] [1, 1, 1] : memref<120x120x120xi8> to memref<117x113x6xi8, strided<[14400, 120, 1], offset: 14765>>
+  // CHECK-NEXT: memref.copy [[vsubview_3]], [[valloc_1]] : memref<117x113x6xi8, strided<[14400, 120, 1], offset: 14765>> to memref<117x113x6xi8>
+  // CHECK-NEXT: mpi.send([[vcast_2]], [[vc91_i32]], [[vc44_i32]]) : memref<?x?x6xi8>, i32, i32
+  // CHECK-NEXT: mpi.recv([[vcast_2]], [[vc91_i32]], [[vc4_i32]]) : memref<?x?x6xi8>, i32, i32
+  // CHECK-NEXT: [[vsubview_4:%.*]] = memref.subview [[varg0]][1, 3, 114] [117, 113, 6] [1, 1, 1] : memref<120x120x120xi8> to memref<117x113x6xi8, strided<[14400, 120, 1], offset: 14874>>
+  // CHECK-NEXT: memref.copy [[valloc_1]], [[vsubview_4]] : memref<117x113x6xi8> to memref<117x113x6xi8, strided<[14400, 120, 1], offset: 14874>>
+  // CHECK-NEXT: memref.dealloc [[valloc_1]] : memref<117x113x6xi8>
+  // CHECK-NEXT: [[valloc_5:%.*]] = memref.alloc() : memref<117x3x120xi8>
+  // CHECK-NEXT: [[vcast_6:%.*]] = memref.cast [[valloc_5]] : memref<117x3x120xi8> to memref<?x3x120xi8>
+  // CHECK-NEXT: mpi.recv([[vcast_6]], [[vc91_i32]], [[vc29_i32]]) : memref<?x3x120xi8>, i32, i32
+  // CHECK-NEXT: [[vsubview_7:%.*]] = memref.subview [[varg0]][1, 0, 0] [117, 3, 120] [1, 1, 1] : memref<120x120x120xi8> to memref<117x3x120xi8, strided<[14400, 120, 1], offset: 14400>>
+  // CHECK-NEXT: memref.copy [[valloc_5]], [[vsubview_7]] : memref<117x3x120xi8> to memref<117x3x120xi8, strided<[14400, 120, 1], offset: 14400>>
+  // CHECK-NEXT: memref.dealloc [[valloc_5]] : memref<117x3x120xi8>
+  // CHECK-NEXT: [[valloc_8:%.*]] = memref.alloc() : memref<117x4x120xi8>
+  // CHECK-NEXT: [[vcast_9:%.*]] = memref.cast [[valloc_8]] : memref<117x4x120xi8> to memref<?x4x120xi8>
+  // CHECK-NEXT: [[vsubview_10:%.*]] = memref.subview [[varg0]][1, 3, 0] [117, 4, 120] [1, 1, 1] : memref<120x120x120xi8> to memref<117x4x120xi8, strided<[14400, 120, 1], offset: 14760>>
+  // CHECK-NEXT: memref.copy [[vsubview_10]], [[valloc_8]] : memref<117x4x120xi8, strided<[14400, 120, 1], offset: 14760>> to memref<117x4x120xi8>
+  // CHECK-NEXT: mpi.send([[vcast_9]], [[vc91_i32]], [[vc29_i32]]) : memref<?x4x120xi8>, i32, i32
+  // CHECK-NEXT: memref.dealloc [[valloc_8]] : memref<117x4x120xi8>
+  // CHECK-NEXT: [[valloc_11:%.*]] = memref.alloc() : memref<1x120x120xi8>
+  // CHECK-NEXT: [[vsubview_12:%.*]] = memref.subview [[varg0]][117, 0, 0] [1, 120, 120] [1, 1, 1] : memref<120x120x120xi8> to memref<1x120x120xi8, strided<[14400, 120, 1], offset: 1684800>>
+  // CHECK-NEXT: memref.copy [[vsubview_12]], [[valloc_11]] : memref<1x120x120xi8, strided<[14400, 120, 1], offset: 1684800>> to memref<1x120x120xi8>
+  // CHECK-NEXT: mpi.send([[valloc_11]], [[vc91_i32]], [[vc23_i32]]) : memref<1x120x120xi8>, i32, i32
+  // CHECK-NEXT: memref.dealloc [[valloc_11]] : memref<1x120x120xi8>
+  // CHECK-NEXT: [[valloc_13:%.*]] = memref.alloc() : memref<2x120x120xi8>
+  // CHECK-NEXT: mpi.recv([[valloc_13]], [[vc91_i32]], [[vc23_i32]]) : memref<2x120x120xi8>, i32, i32
+  // CHECK-NEXT: [[vsubview_14:%.*]] = memref.subview [[varg0]][118, 0, 0] [2, 120, 120] [1, 1, 1] : memref<120x120x120xi8> to memref<2x120x120xi8, strided<[14400, 120, 1], offset: 1699200>>
+  // CHECK-NEXT: memref.copy [[valloc_13]], [[vsubview_14]] : memref<2x120x120xi8> to memref<2x120x120xi8, strided<[14400, 120, 1], offset: 1699200>>
+  // CHECK-NEXT: memref.dealloc [[valloc_13]] : memref<2x120x120xi8>
+  %res = mesh.update_halo %arg0 on @mesh0 split_axes = [[2], [1], [0]] halo_sizes = [1, 2, 3, 4, 5, 6] : memref<120x120x120xi8>
+  // CHECK: return [[varg0]] : memref<120x120x120xi8>
+  return %res : memref<120x120x120xi8>
 }
 
-// CHECK-LABEL: func @update_halo_1d_tnsr
-func.func @update_halo_1d_tnsr(
-  // CHECK-SAME: [[varg0:%.*]]: tensor<12x12xi8>
-  %arg0 : tensor<12x12xi8>) -> tensor<12x12xi8> {
-  // CHECK-NEXT: [[vc7:%.*]] = arith.constant 7 : index
-  // CHECK-NEXT: [[vc9:%.*]] = arith.constant 9 : index
-  // CHECK-NEXT: [[vc2:%.*]] = arith.constant 2 : index
-  // CHECK-NEXT: [[vc0_i32:%.*]] = arith.constant 0 : i32
+// CHECK-LABEL: func @update_halo_3d_tensor
+func.func @update_halo_3d_tensor(
+  // CHECK-SAME: [[varg0:%.*]]: tensor<120x120x120xi8>
+  %arg0 : tensor<120x120x120xi8>) -> tensor<120x120x120xi8> {
+  // CHECK: [[vc23_i32:%.*]] = arith.constant 23 : i32
+  // CHECK-NEXT: [[vc29_i32:%.*]] = arith.constant 29 : i32
+  // CHECK-NEXT: [[vc44_i32:%.*]] = arith.constant 44 : i32
+  // CHECK-NEXT: [[vc4_i32:%.*]] = arith.constant 4 : i32
   // CHECK-NEXT: [[vc91_i32:%.*]] = arith.constant 91 : i32
-  // CHECK-NEXT: [[mref:%.*]] = bufferization.to_memref %arg0 : memref<12x12xi8>
-  // CHECK-NEXT: [[vproc_linear_idx:%.*]]:3 = mesh.process_multi_index on @mesh0 : index, index, index
-  // CHECK-NEXT: [[vdown_linear_idx:%.*]], [[vup_linear_idx:%.*]] = mesh.neighbors_linear_indices on @mesh0[[[vproc_linear_idx]]#0, [[vproc_linear_idx]]#1, [[vproc_linear_idx]]#2] split_axes = [0] : index, index
-  // CHECK-NEXT: [[v0:%.*]] = arith.index_cast [[vdown_linear_idx]] : index to i32
-  // CHECK-NEXT: [[v1:%.*]] = arith.index_cast [[vup_linear_idx]] : index to i32
-  // CHECK-NEXT: [[v2:%.*]] = arith.cmpi sge, [[v1]], [[vc0_i32]] : i32
-  // CHECK-NEXT: [[v3:%.*]] = arith.cmpi sge, [[v0]], [[vc0_i32]] : i32
-  // CHECK-NEXT: [[valloc:%.*]] = memref.alloc() : memref<2x12xi8>
-  // CHECK-NEXT: scf.if [[v3]] {
-  // CHECK-NEXT:   [[vsubview:%.*]] = memref.subview [[mref]][[[vc7]], 0] [2, 12] [1, 1] : memref<12x12xi8> to memref<2x12xi8, strided<[12, 1], offset: ?>>
-  // CHECK-NEXT:   memref.copy [[vsubview]], [[valloc]] : memref<2x12xi8, strided<[12, 1], offset: ?>> to memref<2x12xi8>
-  // CHECK-NEXT:   mpi.send([[valloc]], [[vc91_i32]], [[v0]]) : memref<2x12xi8>, i32, i32
-  // CHECK-NEXT: }
-  // CHECK-NEXT: scf.if [[v2]] {
-  // CHECK-NEXT:   mpi.recv([[valloc]], [[vc91_i32]], [[v1]]) : memref<2x12xi8>, i32, i32
-  // CHECK-NEXT:   [[vsubview:%.*]] = memref.subview [[mref]][0, 0] [2, 12] [1, 1] : memref<12x12xi8> to memref<2x12xi8, strided<[12, 1]>>
-  // CHECK-NEXT:   memref.copy [[valloc]], [[vsubview]] : memref<2x12xi8> to memref<2x12xi8, strided<[12, 1]>>
-  // CHECK-NEXT: }
-  // CHECK-NEXT: memref.dealloc [[valloc]] : memref<2x12xi8>
-  // CHECK-NEXT: [[v4:%.*]] = arith.cmpi sge, [[v0]], [[vc0_i32]] : i32
-  // CHECK-NEXT: [[v5:%.*]] = arith.cmpi sge, [[v1]], [[vc0_i32]] : i32
-  // CHECK-NEXT: [[valloc_0:%.*]] = memref.alloc() : memref<3x12xi8>
-  // CHECK-NEXT: scf.if [[v5]] {
-  // CHECK-NEXT:   [[vsubview:%.*]] = memref.subview [[mref]][[[vc2]], 0] [3, 12] [1, 1] : memref<12x12xi8> to memref<3x12xi8, strided<[12, 1], offset: ?>>
-  // CHECK-NEXT:   memref.copy [[vsubview]], [[valloc_0]] : memref<3x12xi8, strided<[12, 1], offset: ?>> to memref<3x12xi8>
-  // CHECK-NEXT:   mpi.send([[valloc_0]], [[vc91_i32]], [[v1]]) : memref<3x12xi8>, i32, i32
-  // CHECK-NEXT: }
-  // CHECK-NEXT: scf.if [[v4]] {
-  // CHECK-NEXT:   mpi.recv([[valloc_0]], [[vc91_i32]], [[v0]]) : memref<3x12xi8>, i32, i32
-  // CHECK-NEXT:   [[vsubview:%.*]] = memref.subview [[mref]][[[vc9]], 0] [3, 12] [1, 1] : memref<12x12xi8> to memref<3x12xi8, strided<[12, 1], offset: ?>>
-  // CHECK-NEXT:   memref.copy [[valloc_0]], [[vsubview]] : memref<3x12xi8> to memref<3x12xi8, strided<[12, 1], offset: ?>>
-  // CHECK-NEXT: }
-  // CHECK-NEXT: memref.dealloc [[valloc_0]] : memref<3x12xi8>
-  // CHECK-NEXT: [[res:%.*]] = bufferization.to_tensor [[mref]] : memref<12x12xi8>
-  %res = mesh.update_halo %arg0 on @mesh0 split_axes = [[0]]
-    halo_sizes = [2, 3] : tensor<12x12xi8>
-  // CHECK-NEXT: return [[res]]
-  return %res : tensor<12x12xi8>
+  // CHECK-NEXT: [[v0:%.*]] = bufferization.to_memref [[varg0]] : memref<120x120x120xi8>
+  // CHECK-NEXT: [[valloc:%.*]] = memref.alloc() : memref<117x113x5xi8>
+  // CHECK-NEXT: [[vcast:%.*]] = memref.cast [[valloc]] : memref<117x113x5xi8> to memref<?x?x5xi8>
+  // CHECK-NEXT: [[vsubview:%.*]] = memref.subview [[v0]][1, 3, 109] [117, 113, 5] [1, 1, 1] : memref<120x120x120xi8> to memref<117x113x5xi8, strided<[14400, 120, 1], offset: 14869>>
+  // CHECK-NEXT: memref.copy [[vsubview]], [[valloc]] : memref<117x113x5xi8, strided<[14400, 120, 1], offset: 14869>> to memref<117x113x5xi8>
+  // CHECK-NEXT: mpi.send([[vcast]], [[vc91_i32]], [[vc4_i32]]) : memref<?x?x5xi8>, i32, i32
+  // CHECK-NEXT: mpi.recv([[vcast]], [[vc91_i32]], [[vc44_i32]]) : memref<?x?x5xi8>, i32, i32
+  // CHECK-NEXT: [[vsubview_0:%.*]] = memref.subview [[v0]][1, 3, 0] [117, 113, 5] [1, 1, 1] : memref<120x120x120xi8> to memref<117x113x5xi8, strided<[14400, 120, 1], offset: 14760>>
+  // CHECK-NEXT: memref.copy [[valloc]], [[vsubview_0]] : memref<117x113x5xi8> to memref<117x113x5xi8, strided<[14400, 120, 1], offset: 14760>>
+  // CHECK-NEXT: memref.dealloc [[valloc]] : memref<117x113x5xi8>
+  // CHECK-NEXT: [[valloc_1:%.*]] = memref.alloc() : memref<117x113x6xi8>
+  // CHECK-NEXT: [[vcast_2:%.*]] = memref.cast [[valloc_1]] : memref<117x113x6xi8> to memref<?x?x6xi8>
+  // CHECK-NEXT: [[vsubview_3:%.*]] = memref.subview [[v0]][1, 3, 5] [117, 113, 6] [1, 1, 1] : memref<120x120x120xi8> to memref<117x113x6xi8, strided<[14400, 120, 1], offset: 14765>>
+  // CHECK-NEXT: memref.copy [[vsubview_3]], [[valloc_1]] : memref<117x113x6xi8, strided<[14400, 120, 1], offset: 14765>> to memref<117x113x6xi8>
+  // CHECK-NEXT: mpi.send([[vcast_2]], [[vc91_i32]], [[vc44_i32]]) : memref<?x?x6xi8>, i32, i32
+  // CHECK-NEXT: mpi.recv([[vcast_2]], [[vc91_i32]], [[vc4_i32]]) : memref<?x?x6xi8>, i32, i32
+  // CHECK-NEXT: [[vsubview_4:%.*]] = memref.subview [[v0]][1, 3, 114] [117, 113, 6] [1, 1, 1] : memref<120x120x120xi8> to memref<117x113x6xi8, strided<[14400, 120, 1], offset: 14874>>
+  // CHECK-NEXT: memref.copy [[valloc_1]], [[vsubview_4]] : memref<117x113x6xi8> to memref<117x113x6xi8, strided<[14400, 120, 1], offset: 14874>>
+  // CHECK-NEXT: memref.dealloc [[valloc_1]] : memref<117x113x6xi8>
+  // CHECK-NEXT: [[valloc_5:%.*]] = memref.alloc() : memref<117x3x120xi8>
+  // CHECK-NEXT: [[vcast_6:%.*]] = memref.cast [[valloc_5]] : memref<117x3x120xi8> to memref<?x3x120xi8>
+  // CHECK-NEXT: mpi.recv([[vcast_6]], [[vc91_i32]], [[vc29_i32]]) : memref<?x3x120xi8>, i32, i32
+  // CHECK-NEXT: [[vsubview_7:%.*]] = memref.subview [[v0]][1, 0, 0] [117, 3, 120] [1, 1, 1] : memref<120x120x120xi8> to memref<117x3x120xi8, strided<[14400, 120, 1], offset: 14400>>
+  // CHECK-NEXT: memref.copy [[valloc_5]], [[vsubview_7]] : memref<117x3x120xi8> to memref<117x3x120xi8, strided<[14400, 120, 1], offset: 14400>>
+  // CHECK-NEXT: memref.dealloc [[valloc_5]] : memref<117x3x120xi8>
+  // CHECK-NEXT: [[valloc_8:%.*]] = memref.alloc() : memref<117x4x120xi8>
+  // CHECK-NEXT: [[vcast_9:%.*]] = memref.cast [[valloc_8]] : memref<117x4x120xi8> to memref<?x4x120xi8>
+  // CHECK-NEXT: [[vsubview_10:%.*]] = memref.subview [[v0]][1, 3, 0] [117, 4, 120] [1, 1, 1] : memref<120x120x120xi8> to memref<117x4x120xi8, strided<[14400, 120, 1], offset: 14760>>
+  // CHECK-NEXT: memref.copy [[vsubview_10]], [[valloc_8]] : memref<117x4x120xi8, strided<[14400, 120, 1], offset: 14760>> to memref<117x4x120xi8>
+  // CHECK-NEXT: mpi.send([[vcast_9]], [[vc91_i32]], [[vc29_i32]]) : memref<?x4x120xi8>, i32, i32
+  // CHECK-NEXT: memref.dealloc [[valloc_8]] : memref<117x4x120xi8>
+  // CHECK-NEXT: [[valloc_11:%.*]] = memref.alloc() : memref<1x120x120xi8>
+  // CHECK-NEXT: [[vsubview_12:%.*]] = memref.subview [[v0]][117, 0, 0] [1, 120, 120] [1, 1, 1] : memref<120x120x120xi8> to memref<1x120x120xi8, strided<[14400, 120, 1], offset: 1684800>>
+  // CHECK-NEXT: memref.copy [[vsubview_12]], [[valloc_11]] : memref<1x120x120xi8, strided<[14400, 120, 1], offset: 1684800>> to memref<1x120x120xi8>
+  // CHECK-NEXT: mpi.send([[valloc_11]], [[vc91_i32]], [[vc23_i32]]) : memref<1x120x120xi8>, i32, i32
+  // CHECK-NEXT: memref.dealloc [[valloc_11]] : memref<1x120x120xi8>
+  // CHECK-NEXT: [[valloc_13:%.*]] = memref.alloc() : memref<2x120x120xi8>
+  // CHECK-NEXT: mpi.recv([[valloc_13]], [[vc91_i32]], [[vc23_i32]]) : memref<2x120x120xi8>, i32, i32
+  // CHECK-NEXT: [[vsubview_14:%.*]] = memref.subview [[v0]][118, 0, 0] [2, 120, 120] [1, 1, 1] : memref<120x120x120xi8> to memref<2x120x120xi8, strided<[14400, 120, 1], offset: 1699200>>
+  // CHECK-NEXT: memref.copy [[valloc_13]], [[vsubview_14]] : memref<2x120x120xi8> to memref<2x120x120xi8, strided<[14400, 120, 1], offset: 1699200>>
+  // CHECK-NEXT: memref.dealloc [[valloc_13]] : memref<2x120x120xi8>
+  // CHECK-NEXT: [[v1:%.*]] = bufferization.to_tensor [[v0]] : memref<120x120x120xi8>
+  %res = mesh.update_halo %arg0 on @mesh0 split_axes = [[2], [1], [0]] halo_sizes = [1, 2, 3, 4, 5, 6] : tensor<120x120x120xi8>
+  // CHECK: return [[v1]] : tensor<120x120x120xi8>
+  return %res : tensor<120x120x120xi8>
 }

>From b5013d0a542e78714df9925083c8d5e78433373c Mon Sep 17 00:00:00 2001
From: "Schlimbach, Frank" <frank.schlimbach at intel.com>
Date: Fri, 8 Nov 2024 18:52:09 +0100
Subject: [PATCH 13/15] using restrict

---
 mlir/lib/Conversion/MeshToMPI/MeshToMPI.cpp | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/mlir/lib/Conversion/MeshToMPI/MeshToMPI.cpp b/mlir/lib/Conversion/MeshToMPI/MeshToMPI.cpp
index c51c5335fc6092..1c82881f67d3ae 100644
--- a/mlir/lib/Conversion/MeshToMPI/MeshToMPI.cpp
+++ b/mlir/lib/Conversion/MeshToMPI/MeshToMPI.cpp
@@ -400,7 +400,8 @@ struct ConvertUpdateHaloOp
     } else {
       assert(isa<RankedTensorType>(op.getResult().getType()));
       rewriter.replaceOp(op, rewriter.create<bufferization::ToTensorOp>(
-                                 loc, op.getResult().getType(), array));
+                                 loc, op.getResult().getType(), array,
+                                 /*restrict=*/true, /*writable=*/true));
     }
     return mlir::success();
   }
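As far as I understand the bufferization dialect, `restrict` asserts that this is the only to_tensor materialization of the buffer and `writable` allows writes through the resulting tensor, which lets One-Shot Bufferize avoid a defensive copy. A sketch of the assumed printed form (SSA names illustrative):

```mlir
// Assumed assembly: restrict/writable are printed as unit-attribute keywords.
%res = bufferization.to_tensor %buf restrict writable : memref<120x120x120xi8>
```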

>From 1ad7725a3aa4a0fa0b9bc1a8fd8e07d33bfe3a51 Mon Sep 17 00:00:00 2001
From: "Schlimbach, Frank" <frank.schlimbach at intel.com>
Date: Mon, 11 Nov 2024 16:45:31 +0100
Subject: [PATCH 14/15] canonicalizing send and recv towards static memref
 shapes

---
 mlir/include/mlir/Dialect/MPI/IR/MPIOps.td    |  2 +
 mlir/lib/Dialect/MPI/IR/MPIOps.cpp            | 40 +++++++++++++++++++
 .../MeshToMPI/convert-mesh-to-mpi.mlir        | 34 ++++++----------
 3 files changed, 55 insertions(+), 21 deletions(-)

diff --git a/mlir/include/mlir/Dialect/MPI/IR/MPIOps.td b/mlir/include/mlir/Dialect/MPI/IR/MPIOps.td
index 768f376e24da4c..240fac5104c34f 100644
--- a/mlir/include/mlir/Dialect/MPI/IR/MPIOps.td
+++ b/mlir/include/mlir/Dialect/MPI/IR/MPIOps.td
@@ -84,6 +84,7 @@ def MPI_SendOp : MPI_Op<"send", []> {
   let assemblyFormat = "`(` $ref `,` $tag `,` $rank `)` attr-dict `:` "
                        "type($ref) `,` type($tag) `,` type($rank)"
                        "(`->` type($retval)^)?";
+  let hasCanonicalizer = 1;
 }
 
 //===----------------------------------------------------------------------===//
@@ -114,6 +115,7 @@ def MPI_RecvOp : MPI_Op<"recv", []> {
   let assemblyFormat = "`(` $ref `,` $tag `,` $rank `)` attr-dict `:` "
                        "type($ref) `,` type($tag) `,` type($rank)"
                        "(`->` type($retval)^)?";
+  let hasCanonicalizer = 1;
 }
 
 
diff --git a/mlir/lib/Dialect/MPI/IR/MPIOps.cpp b/mlir/lib/Dialect/MPI/IR/MPIOps.cpp
index ddd77b8f586ee0..dcb55d8921364f 100644
--- a/mlir/lib/Dialect/MPI/IR/MPIOps.cpp
+++ b/mlir/lib/Dialect/MPI/IR/MPIOps.cpp
@@ -7,12 +7,52 @@
 //===----------------------------------------------------------------------===//
 
 #include "mlir/Dialect/MPI/IR/MPI.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/IR/Builders.h"
 #include "mlir/IR/BuiltinAttributes.h"
+#include "mlir/IR/PatternMatch.h"
 
 using namespace mlir;
 using namespace mlir::mpi;
 
+namespace {
+
+// If the input memref has a dynamic shape and is defined by a memref.cast
+// whose source has a static shape, fold the cast's source into the operation.
+template <typename OpT>
+struct FoldCast final : public mlir::OpRewritePattern<OpT> {
+  using mlir::OpRewritePattern<OpT>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(OpT op,
+                                mlir::PatternRewriter &b) const override {
+    auto mRef = op.getRef();
+    if (mRef.getType().hasStaticShape()) {
+      return mlir::failure();
+    }
+    auto defOp = mRef.getDefiningOp();
+    if (!defOp || !mlir::isa<mlir::memref::CastOp>(defOp)) {
+      return mlir::failure();
+    }
+    auto src = mlir::cast<mlir::memref::CastOp>(defOp).getSource();
+    if (!src.getType().hasStaticShape()) {
+      return mlir::failure();
+    }
+    op.getRefMutable().assign(src);
+    return mlir::success();
+  }
+};
+} // namespace
+
+void mlir::mpi::SendOp::getCanonicalizationPatterns(
+    mlir::RewritePatternSet &results, mlir::MLIRContext *context) {
+  results.add<FoldCast<mlir::mpi::SendOp>>(context);
+}
+
+void mlir::mpi::RecvOp::getCanonicalizationPatterns(
+    mlir::RewritePatternSet &results, mlir::MLIRContext *context) {
+  results.add<FoldCast<mlir::mpi::RecvOp>>(context);
+}
+
 //===----------------------------------------------------------------------===//
 // TableGen'd op method definitions
 //===----------------------------------------------------------------------===//
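Schematically, the FoldCast pattern rewrites the send/recv operand through the cast, as the updated test below shows; a sketch with the tag/rank operands abbreviated as %tag/%rank:

```mlir
// Before canonicalization: the buffer is erased to a dynamic shape.
%alloc = memref.alloc() : memref<117x113x5xi8>
%cast  = memref.cast %alloc : memref<117x113x5xi8> to memref<?x?x5xi8>
mpi.send(%cast, %tag, %rank) : memref<?x?x5xi8>, i32, i32

// After canonicalization: the static buffer is used directly, and the
// now-dead cast is removed by dead-code elimination.
mpi.send(%alloc, %tag, %rank) : memref<117x113x5xi8>, i32, i32
```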
diff --git a/mlir/test/Conversion/MeshToMPI/convert-mesh-to-mpi.mlir b/mlir/test/Conversion/MeshToMPI/convert-mesh-to-mpi.mlir
index 38b7a12daef52b..25d585a108c8ae 100644
--- a/mlir/test/Conversion/MeshToMPI/convert-mesh-to-mpi.mlir
+++ b/mlir/test/Conversion/MeshToMPI/convert-mesh-to-mpi.mlir
@@ -115,34 +115,30 @@ func.func @update_halo_3d(
   // CHECK-NEXT: [[vc4_i32:%.*]] = arith.constant 4 : i32
   // CHECK-NEXT: [[vc44_i32:%.*]] = arith.constant 44 : i32
   // CHECK-NEXT: [[valloc:%.*]] = memref.alloc() : memref<117x113x5xi8>
-  // CHECK-NEXT: [[vcast:%.*]] = memref.cast [[valloc]] : memref<117x113x5xi8> to memref<?x?x5xi8>
   // CHECK-NEXT: [[vsubview:%.*]] = memref.subview [[varg0]][1, 3, 109] [117, 113, 5] [1, 1, 1] : memref<120x120x120xi8> to memref<117x113x5xi8, strided<[14400, 120, 1], offset: 14869>>
   // CHECK-NEXT: memref.copy [[vsubview]], [[valloc]] : memref<117x113x5xi8, strided<[14400, 120, 1], offset: 14869>> to memref<117x113x5xi8>
-  // CHECK-NEXT: mpi.send([[vcast]], [[vc91_i32]], [[vc4_i32]]) : memref<?x?x5xi8>, i32, i32
-  // CHECK-NEXT: mpi.recv([[vcast]], [[vc91_i32]], [[vc44_i32]]) : memref<?x?x5xi8>, i32, i32
+  // CHECK-NEXT: mpi.send([[valloc]], [[vc91_i32]], [[vc4_i32]]) : memref<117x113x5xi8>, i32, i32
+  // CHECK-NEXT: mpi.recv([[valloc]], [[vc91_i32]], [[vc44_i32]]) : memref<117x113x5xi8>, i32, i32
   // CHECK-NEXT: [[vsubview_0:%.*]] = memref.subview [[varg0]][1, 3, 0] [117, 113, 5] [1, 1, 1] : memref<120x120x120xi8> to memref<117x113x5xi8, strided<[14400, 120, 1], offset: 14760>>
   // CHECK-NEXT: memref.copy [[valloc]], [[vsubview_0]] : memref<117x113x5xi8> to memref<117x113x5xi8, strided<[14400, 120, 1], offset: 14760>>
   // CHECK-NEXT: memref.dealloc [[valloc]] : memref<117x113x5xi8>
   // CHECK-NEXT: [[valloc_1:%.*]] = memref.alloc() : memref<117x113x6xi8>
-  // CHECK-NEXT: [[vcast_2:%.*]] = memref.cast [[valloc_1]] : memref<117x113x6xi8> to memref<?x?x6xi8>
   // CHECK-NEXT: [[vsubview_3:%.*]] = memref.subview [[varg0]][1, 3, 5] [117, 113, 6] [1, 1, 1] : memref<120x120x120xi8> to memref<117x113x6xi8, strided<[14400, 120, 1], offset: 14765>>
   // CHECK-NEXT: memref.copy [[vsubview_3]], [[valloc_1]] : memref<117x113x6xi8, strided<[14400, 120, 1], offset: 14765>> to memref<117x113x6xi8>
-  // CHECK-NEXT: mpi.send([[vcast_2]], [[vc91_i32]], [[vc44_i32]]) : memref<?x?x6xi8>, i32, i32
-  // CHECK-NEXT: mpi.recv([[vcast_2]], [[vc91_i32]], [[vc4_i32]]) : memref<?x?x6xi8>, i32, i32
+  // CHECK-NEXT: mpi.send([[valloc_1]], [[vc91_i32]], [[vc44_i32]]) : memref<117x113x6xi8>, i32, i32
+  // CHECK-NEXT: mpi.recv([[valloc_1]], [[vc91_i32]], [[vc4_i32]]) : memref<117x113x6xi8>, i32, i32
   // CHECK-NEXT: [[vsubview_4:%.*]] = memref.subview [[varg0]][1, 3, 114] [117, 113, 6] [1, 1, 1] : memref<120x120x120xi8> to memref<117x113x6xi8, strided<[14400, 120, 1], offset: 14874>>
   // CHECK-NEXT: memref.copy [[valloc_1]], [[vsubview_4]] : memref<117x113x6xi8> to memref<117x113x6xi8, strided<[14400, 120, 1], offset: 14874>>
   // CHECK-NEXT: memref.dealloc [[valloc_1]] : memref<117x113x6xi8>
   // CHECK-NEXT: [[valloc_5:%.*]] = memref.alloc() : memref<117x3x120xi8>
-  // CHECK-NEXT: [[vcast_6:%.*]] = memref.cast [[valloc_5]] : memref<117x3x120xi8> to memref<?x3x120xi8>
-  // CHECK-NEXT: mpi.recv([[vcast_6]], [[vc91_i32]], [[vc29_i32]]) : memref<?x3x120xi8>, i32, i32
+  // CHECK-NEXT: mpi.recv([[valloc_5]], [[vc91_i32]], [[vc29_i32]]) : memref<117x3x120xi8>, i32, i32
   // CHECK-NEXT: [[vsubview_7:%.*]] = memref.subview [[varg0]][1, 0, 0] [117, 3, 120] [1, 1, 1] : memref<120x120x120xi8> to memref<117x3x120xi8, strided<[14400, 120, 1], offset: 14400>>
   // CHECK-NEXT: memref.copy [[valloc_5]], [[vsubview_7]] : memref<117x3x120xi8> to memref<117x3x120xi8, strided<[14400, 120, 1], offset: 14400>>
   // CHECK-NEXT: memref.dealloc [[valloc_5]] : memref<117x3x120xi8>
   // CHECK-NEXT: [[valloc_8:%.*]] = memref.alloc() : memref<117x4x120xi8>
-  // CHECK-NEXT: [[vcast_9:%.*]] = memref.cast [[valloc_8]] : memref<117x4x120xi8> to memref<?x4x120xi8>
   // CHECK-NEXT: [[vsubview_10:%.*]] = memref.subview [[varg0]][1, 3, 0] [117, 4, 120] [1, 1, 1] : memref<120x120x120xi8> to memref<117x4x120xi8, strided<[14400, 120, 1], offset: 14760>>
   // CHECK-NEXT: memref.copy [[vsubview_10]], [[valloc_8]] : memref<117x4x120xi8, strided<[14400, 120, 1], offset: 14760>> to memref<117x4x120xi8>
-  // CHECK-NEXT: mpi.send([[vcast_9]], [[vc91_i32]], [[vc29_i32]]) : memref<?x4x120xi8>, i32, i32
+  // CHECK-NEXT: mpi.send([[valloc_8]], [[vc91_i32]], [[vc29_i32]]) : memref<117x4x120xi8>, i32, i32
   // CHECK-NEXT: memref.dealloc [[valloc_8]] : memref<117x4x120xi8>
   // CHECK-NEXT: [[valloc_11:%.*]] = memref.alloc() : memref<1x120x120xi8>
   // CHECK-NEXT: [[vsubview_12:%.*]] = memref.subview [[varg0]][117, 0, 0] [1, 120, 120] [1, 1, 1] : memref<120x120x120xi8> to memref<1x120x120xi8, strided<[14400, 120, 1], offset: 1684800>>
@@ -170,34 +166,30 @@ func.func @update_halo_3d_tensor(
   // CHECK-NEXT: [[vc91_i32:%.*]] = arith.constant 91 : i32
   // CHECK-NEXT: [[v0:%.*]] = bufferization.to_memref [[varg0]] : memref<120x120x120xi8>
   // CHECK-NEXT: [[valloc:%.*]] = memref.alloc() : memref<117x113x5xi8>
-  // CHECK-NEXT: [[vcast:%.*]] = memref.cast [[valloc]] : memref<117x113x5xi8> to memref<?x?x5xi8>
   // CHECK-NEXT: [[vsubview:%.*]] = memref.subview [[v0]][1, 3, 109] [117, 113, 5] [1, 1, 1] : memref<120x120x120xi8> to memref<117x113x5xi8, strided<[14400, 120, 1], offset: 14869>>
   // CHECK-NEXT: memref.copy [[vsubview]], [[valloc]] : memref<117x113x5xi8, strided<[14400, 120, 1], offset: 14869>> to memref<117x113x5xi8>
-  // CHECK-NEXT: mpi.send([[vcast]], [[vc91_i32]], [[vc4_i32]]) : memref<?x?x5xi8>, i32, i32
-  // CHECK-NEXT: mpi.recv([[vcast]], [[vc91_i32]], [[vc44_i32]]) : memref<?x?x5xi8>, i32, i32
+  // CHECK-NEXT: mpi.send([[valloc]], [[vc91_i32]], [[vc4_i32]]) : memref<117x113x5xi8>, i32, i32
+  // CHECK-NEXT: mpi.recv([[valloc]], [[vc91_i32]], [[vc44_i32]]) : memref<117x113x5xi8>, i32, i32
   // CHECK-NEXT: [[vsubview_0:%.*]] = memref.subview [[v0]][1, 3, 0] [117, 113, 5] [1, 1, 1] : memref<120x120x120xi8> to memref<117x113x5xi8, strided<[14400, 120, 1], offset: 14760>>
   // CHECK-NEXT: memref.copy [[valloc]], [[vsubview_0]] : memref<117x113x5xi8> to memref<117x113x5xi8, strided<[14400, 120, 1], offset: 14760>>
   // CHECK-NEXT: memref.dealloc [[valloc]] : memref<117x113x5xi8>
   // CHECK-NEXT: [[valloc_1:%.*]] = memref.alloc() : memref<117x113x6xi8>
-  // CHECK-NEXT: [[vcast_2:%.*]] = memref.cast [[valloc_1]] : memref<117x113x6xi8> to memref<?x?x6xi8>
   // CHECK-NEXT: [[vsubview_3:%.*]] = memref.subview [[v0]][1, 3, 5] [117, 113, 6] [1, 1, 1] : memref<120x120x120xi8> to memref<117x113x6xi8, strided<[14400, 120, 1], offset: 14765>>
   // CHECK-NEXT: memref.copy [[vsubview_3]], [[valloc_1]] : memref<117x113x6xi8, strided<[14400, 120, 1], offset: 14765>> to memref<117x113x6xi8>
-  // CHECK-NEXT: mpi.send([[vcast_2]], [[vc91_i32]], [[vc44_i32]]) : memref<?x?x6xi8>, i32, i32
-  // CHECK-NEXT: mpi.recv([[vcast_2]], [[vc91_i32]], [[vc4_i32]]) : memref<?x?x6xi8>, i32, i32
+  // CHECK-NEXT: mpi.send([[valloc_1]], [[vc91_i32]], [[vc44_i32]]) : memref<117x113x6xi8>, i32, i32
+  // CHECK-NEXT: mpi.recv([[valloc_1]], [[vc91_i32]], [[vc4_i32]]) : memref<117x113x6xi8>, i32, i32
   // CHECK-NEXT: [[vsubview_4:%.*]] = memref.subview [[v0]][1, 3, 114] [117, 113, 6] [1, 1, 1] : memref<120x120x120xi8> to memref<117x113x6xi8, strided<[14400, 120, 1], offset: 14874>>
   // CHECK-NEXT: memref.copy [[valloc_1]], [[vsubview_4]] : memref<117x113x6xi8> to memref<117x113x6xi8, strided<[14400, 120, 1], offset: 14874>>
   // CHECK-NEXT: memref.dealloc [[valloc_1]] : memref<117x113x6xi8>
   // CHECK-NEXT: [[valloc_5:%.*]] = memref.alloc() : memref<117x3x120xi8>
-  // CHECK-NEXT: [[vcast_6:%.*]] = memref.cast [[valloc_5]] : memref<117x3x120xi8> to memref<?x3x120xi8>
-  // CHECK-NEXT: mpi.recv([[vcast_6]], [[vc91_i32]], [[vc29_i32]]) : memref<?x3x120xi8>, i32, i32
+  // CHECK-NEXT: mpi.recv([[valloc_5]], [[vc91_i32]], [[vc29_i32]]) : memref<117x3x120xi8>, i32, i32
   // CHECK-NEXT: [[vsubview_7:%.*]] = memref.subview [[v0]][1, 0, 0] [117, 3, 120] [1, 1, 1] : memref<120x120x120xi8> to memref<117x3x120xi8, strided<[14400, 120, 1], offset: 14400>>
   // CHECK-NEXT: memref.copy [[valloc_5]], [[vsubview_7]] : memref<117x3x120xi8> to memref<117x3x120xi8, strided<[14400, 120, 1], offset: 14400>>
   // CHECK-NEXT: memref.dealloc [[valloc_5]] : memref<117x3x120xi8>
   // CHECK-NEXT: [[valloc_8:%.*]] = memref.alloc() : memref<117x4x120xi8>
-  // CHECK-NEXT: [[vcast_9:%.*]] = memref.cast [[valloc_8]] : memref<117x4x120xi8> to memref<?x4x120xi8>
   // CHECK-NEXT: [[vsubview_10:%.*]] = memref.subview [[v0]][1, 3, 0] [117, 4, 120] [1, 1, 1] : memref<120x120x120xi8> to memref<117x4x120xi8, strided<[14400, 120, 1], offset: 14760>>
   // CHECK-NEXT: memref.copy [[vsubview_10]], [[valloc_8]] : memref<117x4x120xi8, strided<[14400, 120, 1], offset: 14760>> to memref<117x4x120xi8>
-  // CHECK-NEXT: mpi.send([[vcast_9]], [[vc91_i32]], [[vc29_i32]]) : memref<?x4x120xi8>, i32, i32
+  // CHECK-NEXT: mpi.send([[valloc_8]], [[vc91_i32]], [[vc29_i32]]) : memref<117x4x120xi8>, i32, i32
   // CHECK-NEXT: memref.dealloc [[valloc_8]] : memref<117x4x120xi8>
   // CHECK-NEXT: [[valloc_11:%.*]] = memref.alloc() : memref<1x120x120xi8>
   // CHECK-NEXT: [[vsubview_12:%.*]] = memref.subview [[v0]][117, 0, 0] [1, 120, 120] [1, 1, 1] : memref<120x120x120xi8> to memref<1x120x120xi8, strided<[14400, 120, 1], offset: 1684800>>
@@ -209,7 +201,7 @@ func.func @update_halo_3d_tensor(
   // CHECK-NEXT: [[vsubview_14:%.*]] = memref.subview [[v0]][118, 0, 0] [2, 120, 120] [1, 1, 1] : memref<120x120x120xi8> to memref<2x120x120xi8, strided<[14400, 120, 1], offset: 1699200>>
   // CHECK-NEXT: memref.copy [[valloc_13]], [[vsubview_14]] : memref<2x120x120xi8> to memref<2x120x120xi8, strided<[14400, 120, 1], offset: 1699200>>
   // CHECK-NEXT: memref.dealloc [[valloc_13]] : memref<2x120x120xi8>
-  // CHECK-NEXT: [[v1:%.*]] = bufferization.to_tensor [[v0]] : memref<120x120x120xi8>
+  // CHECK-NEXT: [[v1:%.*]] = bufferization.to_tensor [[v0]] restrict writable : memref<120x120x120xi8>
   %res = mesh.update_halo %arg0 on @mesh0 split_axes = [[2], [1], [0]] halo_sizes = [1, 2, 3, 4, 5, 6] : tensor<120x120x120xi8>
   // CHECK: return [[v1]] : tensor<120x120x120xi8>
   return %res : tensor<120x120x120xi8>

>From c48c7f0b25ec0b5f9a22bbdba28c79057c298fb0 Mon Sep 17 00:00:00 2001
From: "Schlimbach, Frank" <frank.schlimbach at intel.com>
Date: Wed, 27 Nov 2024 16:48:25 +0100
Subject: [PATCH 15/15] fixing comments

---
 mlir/lib/Conversion/MeshToMPI/MeshToMPI.cpp | 19 +++++++++++++------
 1 file changed, 13 insertions(+), 6 deletions(-)

diff --git a/mlir/lib/Conversion/MeshToMPI/MeshToMPI.cpp b/mlir/lib/Conversion/MeshToMPI/MeshToMPI.cpp
index 1c82881f67d3ae..6dd89ecf4d5c2d 100644
--- a/mlir/lib/Conversion/MeshToMPI/MeshToMPI.cpp
+++ b/mlir/lib/Conversion/MeshToMPI/MeshToMPI.cpp
@@ -72,7 +72,6 @@ Value multiToLinearIndex(Location loc, OpBuilder b, ValueRange multiIndex,
   return linearIndex;
 }
 
-// This pattern converts the mesh.update_halo operation to MPI calls
 struct ConvertProcessMultiIndexOp
     : public mlir::OpRewritePattern<mlir::mesh::ProcessMultiIndexOp> {
   using OpRewritePattern::OpRewritePattern;
@@ -80,6 +79,9 @@ struct ConvertProcessMultiIndexOp
   mlir::LogicalResult
   matchAndRewrite(mlir::mesh::ProcessMultiIndexOp op,
                   mlir::PatternRewriter &rewriter) const override {
+
+    // Converts the process's linear index into a multi-dimensional index.
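+    // Illustrative example (assuming row-major linearization): on a 4x3
+    // mesh, linear index 7 decomposes as 7 = 2 * 3 + 1, i.e. into the
+    // multi-index (2, 1), via repeated div/mod over the mesh shape.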
+
     SymbolTableCollection symbolTableCollection;
     auto loc = op.getLoc();
     auto meshOp = getMesh(op, symbolTableCollection);
@@ -112,9 +114,6 @@ struct ConvertProcessMultiIndexOp
   }
 };
 
-// This pattern converts the mesh.update_halo operation to MPI calls.
-// If it finds a global named "static_mpi_rank" it will use that splat value.
-// Otherwise it defaults to mpi.comm_rank.
 struct ConvertProcessLinearIndexOp
     : public mlir::OpRewritePattern<mlir::mesh::ProcessLinearIndexOp> {
   using OpRewritePattern::OpRewritePattern;
@@ -122,6 +121,10 @@ struct ConvertProcessLinearIndexOp
   mlir::LogicalResult
   matchAndRewrite(mlir::mesh::ProcessLinearIndexOp op,
                   mlir::PatternRewriter &rewriter) const override {
+
+    // If a global named "static_mpi_rank" is found, uses its splat value.
+    // Otherwise defaults to mpi.comm_rank.
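+    // A sketch of the matched IR (the exact form of the global may differ):
+    //   memref.global constant @static_mpi_rank : memref<1xindex> = dense<7>
+    // With such a global present, the op lowers to the constant 7 instead
+    // of a runtime mpi.comm_rank call.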
+
     auto loc = op.getLoc();
     auto rankOpName = StringAttr::get(op->getContext(), "static_mpi_rank");
     if (auto globalOp = SymbolTable::lookupNearestSymbolFrom<memref::GlobalOp>(
@@ -145,7 +148,6 @@ struct ConvertProcessLinearIndexOp
   }
 };
 
-// This pattern converts the mesh.update_halo operation to MPI calls
 struct ConvertNeighborsLinearIndicesOp
     : public mlir::OpRewritePattern<mlir::mesh::NeighborsLinearIndicesOp> {
   using OpRewritePattern::OpRewritePattern;
@@ -153,6 +155,11 @@ struct ConvertNeighborsLinearIndicesOp
   mlir::LogicalResult
   matchAndRewrite(mlir::mesh::NeighborsLinearIndicesOp op,
                   mlir::PatternRewriter &rewriter) const override {
+
+    // Computes the neighbors' indices along a split axis by simply
+    // adding/subtracting 1 to/from the current index in that dimension.
+    // Assigns -1 if a neighbor is out of bounds.
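+    // Worked example (assuming row-major linearization): on a 4x3 mesh,
+    // for multi-index (2, 1) and split axis 1, the down/up neighbors are
+    // (2, 0) and (2, 2), i.e. linear indices 2*3+0 = 6 and 2*3+2 = 8.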
+
     auto axes = op.getSplitAxes();
     // For now only single axis sharding is supported
     if (axes.size() != 1) {
@@ -209,7 +216,6 @@ struct ConvertNeighborsLinearIndicesOp
   }
 };
 
-// This pattern converts the mesh.update_halo operation to MPI calls
 struct ConvertUpdateHaloOp
     : public mlir::OpRewritePattern<mlir::mesh::UpdateHaloOp> {
   using OpRewritePattern::OpRewritePattern;
@@ -217,6 +223,7 @@ struct ConvertUpdateHaloOp
   mlir::LogicalResult
   matchAndRewrite(mlir::mesh::UpdateHaloOp op,
                   mlir::PatternRewriter &rewriter) const override {
+
     // The input/output memref is assumed to be in C memory order.
     // Halos are exchanged as 2 blocks per dimension (one for each side: down
     // and up). For each haloed dimension `d`, the exchanged blocks are
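
For reference, the per-block exchange this pattern emits, condensed from the
CHECK lines above (a sketch only: `%tag`, `%to`, and `%from` stand in for the
concrete constants, and subview result types are abbreviated):

  %buf = memref.alloc() : memref<117x113x5xi8>
  %src = memref.subview %arg0[1, 3, 109] [117, 113, 5] [1, 1, 1]
  memref.copy %src, %buf
  mpi.send(%buf, %tag, %to) : memref<117x113x5xi8>, i32, i32
  mpi.recv(%buf, %tag, %from) : memref<117x113x5xi8>, i32, i32
  %dst = memref.subview %arg0[1, 3, 0] [117, 113, 5] [1, 1, 1]
  memref.copy %buf, %dst
  memref.dealloc %buf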


