[Mlir-commits] [mlir] More on MeshToMPI (PR #129048)

Frank Schlimbach llvmlistbot at llvm.org
Thu Feb 27 05:09:56 PST 2025


https://github.com/fschlimb created https://github.com/llvm/llvm-project/pull/129048

- do not create MPI operations if no halo exchange is needed on a given boundary
- allow returning sharding information by returning `!mesh.sharding` (gets converted into a tuple of tensors)
- lowering `mesh.shard_shape` including fixes to the operation itself
- global symbol `static_mpi_rank` replaced by a DLTI attribute (now aligned with MPIToLLVM)
- smaller fixes and some minor cleanup

>From 8dabd5755d095bfa1b231375f9e1907f4e0c3bb5 Mon Sep 17 00:00:00 2001
From: "Schlimbach, Frank" <frank.schlimbach at intel.com>
Date: Fri, 13 Dec 2024 16:24:43 +0100
Subject: [PATCH 1/8] handling empty halos when lowering mesh.update_halos

---
 mlir/lib/Conversion/MeshToMPI/MeshToMPI.cpp | 10 ++++++++--
 1 file changed, 8 insertions(+), 2 deletions(-)

diff --git a/mlir/lib/Conversion/MeshToMPI/MeshToMPI.cpp b/mlir/lib/Conversion/MeshToMPI/MeshToMPI.cpp
index 48b3764d520c2..e36498b72009c 100644
--- a/mlir/lib/Conversion/MeshToMPI/MeshToMPI.cpp
+++ b/mlir/lib/Conversion/MeshToMPI/MeshToMPI.cpp
@@ -236,6 +236,14 @@ struct ConvertUpdateHaloOp
     // local data. Because subviews and halos can have mixed dynamic and static
     // shapes, OpFoldResults are used whenever possible.
 
+    auto haloSizes =
+        getMixedValues(op.getStaticHaloSizes(), op.getHaloSizes(), rewriter);
+    if (haloSizes.empty()) {
+      // no halos -> nothing to do
+      rewriter.replaceOp(op, op.getDestination());
+      return mlir::success();
+    }
+
     SymbolTableCollection symbolTableCollection;
     auto loc = op.getLoc();
 
@@ -262,8 +270,6 @@ struct ConvertUpdateHaloOp
     auto opSplitAxes = op.getSplitAxes().getAxes();
     auto mesh = op.getMesh();
     auto meshOp = getMesh(op, symbolTableCollection);
-    auto haloSizes =
-        getMixedValues(op.getStaticHaloSizes(), op.getHaloSizes(), rewriter);
     // subviews need Index values
     for (auto &sz : haloSizes) {
       if (auto value = dyn_cast<Value>(sz)) {

>From 29d7fb03b08c03d82d171bb286b7cf6ca688d0d4 Mon Sep 17 00:00:00 2001
From: "Schlimbach, Frank" <frank.schlimbach at intel.com>
Date: Thu, 19 Dec 2024 13:18:07 +0100
Subject: [PATCH 2/8] fixing send/recv direction for halo updates; don't create
 send/recv for empty halos

---
 mlir/lib/Conversion/MeshToMPI/MeshToMPI.cpp   | 23 ++++++++++++++----
 .../MeshToMPI/convert-mesh-to-mpi.mlir        | 24 +++++++++++++++++++
 2 files changed, 43 insertions(+), 4 deletions(-)

diff --git a/mlir/lib/Conversion/MeshToMPI/MeshToMPI.cpp b/mlir/lib/Conversion/MeshToMPI/MeshToMPI.cpp
index e36498b72009c..9fa718f1f1acb 100644
--- a/mlir/lib/Conversion/MeshToMPI/MeshToMPI.cpp
+++ b/mlir/lib/Conversion/MeshToMPI/MeshToMPI.cpp
@@ -362,8 +362,8 @@ struct ConvertUpdateHaloOp
                                   : haloSizes[currHaloDim * 2];
         // Check if we need to send and/or receive
         // Processes on the mesh borders have only one neighbor
-        auto to = upperHalo ? neighbourIDs[1] : neighbourIDs[0];
-        auto from = upperHalo ? neighbourIDs[0] : neighbourIDs[1];
+        auto to = upperHalo ? neighbourIDs[0] : neighbourIDs[1];
+        auto from = upperHalo ? neighbourIDs[1] : neighbourIDs[0];
         auto hasFrom = rewriter.create<arith::CmpIOp>(
             loc, arith::CmpIPredicate::sge, from, zero);
         auto hasTo = rewriter.create<arith::CmpIOp>(
@@ -396,8 +396,23 @@ struct ConvertUpdateHaloOp
         offsets[dim] = orgOffset;
       };
 
-      genSendRecv(false);
-      genSendRecv(true);
+      auto get_i32val = [&](OpFoldResult &v) {
+        return v.is<Value>()
+                   ? v.get<Value>()
+                   : rewriter.create<::mlir::arith::ConstantOp>(
+                         loc, rewriter.getI32IntegerAttr(
+                                  cast<IntegerAttr>(v.get<Attribute>()).getInt()));
+      };
+
+      for (int i=0; i<2; ++i) {
+        Value haloSz = get_i32val(haloSizes[currHaloDim * 2 + i]);                              
+        auto hasSize = rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::sgt, haloSz, zero);
+        rewriter.create<scf::IfOp>(
+            loc, hasSize, [&](OpBuilder &builder, Location loc) {
+              genSendRecv(i > 0);
+              builder.create<scf::YieldOp>(loc);
+            });
+      }
 
       // the shape for lower dims include higher dims' halos
       dimSizes[dim] = shape[dim];
diff --git a/mlir/test/Conversion/MeshToMPI/convert-mesh-to-mpi.mlir b/mlir/test/Conversion/MeshToMPI/convert-mesh-to-mpi.mlir
index c1aef97438bd5..04832ce11b7b3 100644
--- a/mlir/test/Conversion/MeshToMPI/convert-mesh-to-mpi.mlir
+++ b/mlir/test/Conversion/MeshToMPI/convert-mesh-to-mpi.mlir
@@ -102,6 +102,30 @@ func.func @update_halo_1d_first(
   return %res : memref<120x120x120xi8>
 }
 
+// -----
+mesh.mesh @mesh0(shape = 4)
+memref.global "public" constant @static_mpi_rank : memref<index> = dense<1>
+// CHECK-LABEL: func @update_halo_1d_with_zero
+func.func @update_halo_1d_with_zero (
+  // CHECK-SAME: [[varg0:%.*]]: memref<120x120x120xi8>
+  %arg0 : memref<120x120x120xi8>) -> memref<120x120x120xi8> {
+  // CHECK: [[vc91_i32:%.*]] = arith.constant 91 : i32
+  // CHECK-NEXT: [[vc0_i32:%.*]] = arith.constant 0 : i32
+  // CHECK-NEXT: [[vc2_i32:%.*]] = arith.constant 2 : i32
+  // CHECK-NEXT: [[valloc:%.*]] = memref.alloc() : memref<2x120x120xi8>
+  // CHECK-NEXT: [[vsubview:%.*]] = memref.subview [[varg0]][118, 0, 0] [2, 120, 120] [1, 1, 1] : memref<120x120x120xi8> to memref<2x120x120xi8
+  // CHECK-NEXT: memref.copy [[vsubview]], [[valloc]] : memref<2x120x120xi8
+  // CHECK-SAME: to memref<2x120x120xi8>
+  // CHECK-NEXT: mpi.send([[valloc]], [[vc91_i32]], [[vc2_i32]]) : memref<2x120x120xi8>, i32, i32
+  // CHECK-NEXT: mpi.recv([[valloc]], [[vc91_i32]], [[vc0_i32]]) : memref<2x120x120xi8>, i32, i32
+  // CHECK-NEXT: [[vsubview_0:%.*]] = memref.subview [[varg0]][0, 0, 0] [2, 120, 120] [1, 1, 1] : memref<120x120x120xi8> to memref<2x120x120xi8
+  // CHECK-NEXT: memref.copy [[valloc]], [[vsubview_0]] : memref<2x120x120xi8> to memref<2x120x120xi8
+  // CHECK-NEXT: memref.dealloc [[valloc]] : memref<2x120x120xi8>
+  %res = mesh.update_halo %arg0 on @mesh0 split_axes = [[0]] halo_sizes = [2, 0] : memref<120x120x120xi8>
+  // CHECK: return [[res:%.*]] : memref<120x120x120xi8>
+  return %res : memref<120x120x120xi8>
+}
+
 // -----
 mesh.mesh @mesh0(shape = 3x4x5)
 memref.global constant @static_mpi_rank : memref<index> = dense<24>

>From d9bd663f5a0aa2b23510b32ac615a0d008e37c52 Mon Sep 17 00:00:00 2001
From: "Schlimbach, Frank" <frank.schlimbach at intel.com>
Date: Mon, 20 Jan 2025 13:57:44 +0100
Subject: [PATCH 3/8] adding type conversion of Sharding to 3 tensors

---
 mlir/lib/Conversion/MeshToMPI/MeshToMPI.cpp   | 292 ++++++++++++++++--
 .../MeshToMPI/convert-mesh-to-mpi.mlir        |  68 ++++
 2 files changed, 328 insertions(+), 32 deletions(-)

diff --git a/mlir/lib/Conversion/MeshToMPI/MeshToMPI.cpp b/mlir/lib/Conversion/MeshToMPI/MeshToMPI.cpp
index 9fa718f1f1acb..91f235673f817 100644
--- a/mlir/lib/Conversion/MeshToMPI/MeshToMPI.cpp
+++ b/mlir/lib/Conversion/MeshToMPI/MeshToMPI.cpp
@@ -14,8 +14,12 @@
 
 #include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/Bufferization/IR/Bufferization.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/Func/Transforms/FuncConversions.h"
+#include "mlir/Dialect/Linalg/IR/Linalg.h"
 #include "mlir/Dialect/MPI/IR/MPI.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/Mesh/IR/MeshDialect.h"
 #include "mlir/Dialect/Mesh/IR/MeshOps.h"
 #include "mlir/Dialect/SCF/IR/SCF.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
@@ -25,6 +29,7 @@
 #include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/IR/SymbolTable.h"
+#include "mlir/Transforms/DialectConversion.h"
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
 
 #define DEBUG_TYPE "mesh-to-mpi"
@@ -72,13 +77,181 @@ Value multiToLinearIndex(Location loc, OpBuilder b, ValueRange multiIndex,
   return linearIndex;
 }
 
+// replace GetShardingOp with related ShardingOp
+struct ConvertGetShardingOp
+    : public mlir::OpConversionPattern<mlir::mesh::GetShardingOp> {
+  using OpConversionPattern::OpConversionPattern;
+
+  mlir::LogicalResult
+  matchAndRewrite(mlir::mesh::GetShardingOp op, OpAdaptor adaptor,
+                  mlir::ConversionPatternRewriter &rewriter) const override {
+    auto shardOp = adaptor.getSource().getDefiningOp<ShardOp>();
+    if (!shardOp) {
+      return mlir::failure();
+    }
+    auto shardingOp = shardOp.getSharding().getDefiningOp<ShardingOp>();
+    if (!shardingOp) {
+      return mlir::failure();
+    }
+
+    rewriter.replaceOp(op, shardingOp.getResult());
+    return mlir::success();
+  }
+};
+
+struct ConvertShardingOp
+    : public mlir::OpConversionPattern<mlir::mesh::ShardingOp> {
+  using OpConversionPattern::OpConversionPattern;
+
+  mlir::LogicalResult
+  matchAndRewrite(mlir::mesh::ShardingOp op, OpAdaptor adaptor,
+                  mlir::ConversionPatternRewriter &rewriter) const override {
+    auto splitAxes = op.getSplitAxes().getAxes();
+    int64_t maxNAxes = 0;
+    for (auto axes : splitAxes) {
+      maxNAxes = std::max<int64_t>(maxNAxes, axes.size());
+    }
+
+    // To hold the split axes, create empty 2d tensor with shape
+    // {splitAxes.size(), max-size-of-split-groups}.
+    // Set trailing elements for smaller split-groups to -1.
+    auto loc = op.getLoc();
+    auto i16 = rewriter.getI16Type();
+    auto i64 = rewriter.getI64Type();
+    int64_t shape[] = {static_cast<int64_t>(splitAxes.size()), maxNAxes};
+    Value resSplitAxes = rewriter.create<tensor::EmptyOp>(loc, shape, i16);
+    auto attr = IntegerAttr::get(i16, 0xffff);
+    auto fillValue = rewriter.create<arith::ConstantOp>(loc, i16, attr);
+    resSplitAxes =
+        rewriter
+            .create<linalg::FillOp>(loc, fillValue.getResult(), resSplitAxes)
+            .getResult(0);
+
+    // explicitly write values into tensor row by row
+    int64_t strides[] = {1, 1};
+    int64_t nSplits = 0;
+    ValueRange empty = {};
+    for (auto [i, axes] : llvm::enumerate(splitAxes)) {
+      int64_t size = axes.size();
+      if (size > 0) {
+        ++nSplits;
+      }
+      int64_t offs[] = {(int64_t)i, 0};
+      int64_t sizes[] = {1, size};
+      auto tensorType = RankedTensorType::get({size}, i16);
+      auto attrs = DenseIntElementsAttr::get(tensorType, axes.asArrayRef());
+      auto vals = rewriter.create<arith::ConstantOp>(loc, tensorType, attrs);
+      resSplitAxes = rewriter.create<tensor::InsertSliceOp>(
+          loc, vals, resSplitAxes, empty, empty, empty, offs, sizes, strides);
+    }
+
+    // convert vec of OpFoldResults (ints) into vec of Values
+    auto getMixedAsValues = [&](llvm::ArrayRef<int64_t> statics,
+                                ValueRange dynamics) {
+      SmallVector<Value> values;
+      auto dyn = dynamics.begin();
+      for (auto s : statics) {
+        values.emplace_back(
+            ShapedType::isDynamic(s)
+                ? *(dyn++)
+                : rewriter
+                      .create<arith::ConstantOp>(loc, i64,
+                                                 rewriter.getI64IntegerAttr(s))
+                      .getResult());
+      }
+      return values;
+    };
+
+    // To hold halos sizes, create 2d Tensor with shape {nSplits, 2}.
+    // Store the halo sizes in the tensor.
+    auto haloSizes = getMixedAsValues(adaptor.getStaticHaloSizes(),
+                                      adaptor.getDynamicHaloSizes());
+    auto type = RankedTensorType::get({nSplits, 2}, i64);
+    Value resHaloSizes =
+        haloSizes.empty()
+            ? rewriter
+                  .create<tensor::EmptyOp>(loc, std::array<int64_t, 2>{0, 0},
+                                           i64)
+                  .getResult()
+            : rewriter.create<tensor::FromElementsOp>(loc, type, haloSizes)
+                  .getResult();
+
+    // To hold sharded dims offsets, create Tensor with shape {nSplits,
+    // maxSplitSize+1}. Store the offsets in the tensor but set trailing
+    // elements for smaller split-groups to -1. Computing the max size of the
+    // split groups needs using collectiveProcessGroupSize (which needs the
+    // MeshOp)
+    Value resOffsets;
+    if (adaptor.getStaticShardedDimsOffsets().empty()) {
+      resOffsets = rewriter.create<tensor::EmptyOp>(
+          loc, std::array<int64_t, 2>{0, 0}, i64);
+    } else {
+      SymbolTableCollection symbolTableCollection;
+      auto meshOp = getMesh(op, symbolTableCollection);
+      auto maxSplitSize = 0;
+      for (auto axes : splitAxes) {
+        auto splitSize =
+            collectiveProcessGroupSize(axes.asArrayRef(), meshOp.getShape());
+        assert(splitSize != ShapedType::kDynamic);
+        maxSplitSize = std::max<int64_t>(maxSplitSize, splitSize);
+      }
+      assert(maxSplitSize);
+      ++maxSplitSize; // add one for the total size
+
+      resOffsets = rewriter.create<tensor::EmptyOp>(
+          loc, std::array<int64_t, 2>{nSplits, maxSplitSize}, i64);
+      auto zero =
+          rewriter
+              .create<arith::ConstantOp>(
+                  loc, i64, rewriter.getI64IntegerAttr(ShapedType::kDynamic))
+              .getResult();
+      resOffsets =
+          rewriter.create<linalg::FillOp>(loc, zero, resOffsets).getResult(0);
+      auto offsets = getMixedAsValues(adaptor.getStaticShardedDimsOffsets(),
+                                      adaptor.getDynamicShardedDimsOffsets());
+      int64_t curr = 0;
+      for (auto [i, axes] : llvm::enumerate(splitAxes)) {
+        auto splitSize =
+            collectiveProcessGroupSize(axes.asArrayRef(), meshOp.getShape());
+        assert(splitSize != ShapedType::kDynamic && splitSize < maxSplitSize);
+        ++splitSize; // add one for the total size
+        ArrayRef<Value> values(&offsets[curr], splitSize);
+        auto vals = rewriter.create<tensor::FromElementsOp>(loc, values);
+        int64_t offs[] = {(int64_t)i, 0};
+        int64_t sizes[] = {1, splitSize};
+        resOffsets = rewriter.create<tensor::InsertSliceOp>(
+            loc, vals, resOffsets, empty, empty, empty, offs, sizes, strides);
+        curr += splitSize;
+      }
+    }
+
+    // return a tuple of tensors as defined by type converter
+    SmallVector<Type> resTypes;
+    if (failed(getTypeConverter()->convertType(op.getResult().getType(),
+                                               resTypes))) {
+      return mlir::failure();
+    }
+    resSplitAxes =
+        rewriter.create<tensor::CastOp>(loc, resTypes[0], resSplitAxes);
+    resHaloSizes =
+        rewriter.create<tensor::CastOp>(loc, resTypes[1], resHaloSizes);
+    resOffsets = rewriter.create<tensor::CastOp>(loc, resTypes[2], resOffsets);
+
+    rewriter.replaceOpWithNewOp<UnrealizedConversionCastOp>(
+        op, TupleType::get(op.getContext(), resTypes),
+        ValueRange{resSplitAxes, resHaloSizes, resOffsets});
+
+    return mlir::success();
+  }
+};
+
 struct ConvertProcessMultiIndexOp
-    : public mlir::OpRewritePattern<mlir::mesh::ProcessMultiIndexOp> {
-  using OpRewritePattern::OpRewritePattern;
+    : public mlir::OpConversionPattern<mlir::mesh::ProcessMultiIndexOp> {
+  using OpConversionPattern::OpConversionPattern;
 
   mlir::LogicalResult
-  matchAndRewrite(mlir::mesh::ProcessMultiIndexOp op,
-                  mlir::PatternRewriter &rewriter) const override {
+  matchAndRewrite(mlir::mesh::ProcessMultiIndexOp op, OpAdaptor adaptor,
+                  mlir::ConversionPatternRewriter &rewriter) const override {
 
     // Currently converts its linear index to a multi-dimensional index.
 
@@ -100,7 +273,7 @@ struct ConvertProcessMultiIndexOp
     auto mIdx = linearToMultiIndex(loc, rewriter, rank, dims);
 
     // optionally extract subset of mesh axes
-    auto axes = op.getAxes();
+    auto axes = adaptor.getAxes();
     if (!axes.empty()) {
       SmallVector<Value> subIndex;
       for (auto axis : axes) {
@@ -115,12 +288,12 @@ struct ConvertProcessMultiIndexOp
 };
 
 struct ConvertProcessLinearIndexOp
-    : public mlir::OpRewritePattern<mlir::mesh::ProcessLinearIndexOp> {
-  using OpRewritePattern::OpRewritePattern;
+    : public mlir::OpConversionPattern<mlir::mesh::ProcessLinearIndexOp> {
+  using OpConversionPattern::OpConversionPattern;
 
   mlir::LogicalResult
-  matchAndRewrite(mlir::mesh::ProcessLinearIndexOp op,
-                  mlir::PatternRewriter &rewriter) const override {
+  matchAndRewrite(mlir::mesh::ProcessLinearIndexOp op, OpAdaptor adaptor,
+                  mlir::ConversionPatternRewriter &rewriter) const override {
 
     // Finds a global named "static_mpi_rank" it will use that splat value.
     // Otherwise it defaults to mpi.comm_rank.
@@ -149,18 +322,18 @@ struct ConvertProcessLinearIndexOp
 };
 
 struct ConvertNeighborsLinearIndicesOp
-    : public mlir::OpRewritePattern<mlir::mesh::NeighborsLinearIndicesOp> {
-  using OpRewritePattern::OpRewritePattern;
+    : public mlir::OpConversionPattern<mlir::mesh::NeighborsLinearIndicesOp> {
+  using OpConversionPattern::OpConversionPattern;
 
   mlir::LogicalResult
-  matchAndRewrite(mlir::mesh::NeighborsLinearIndicesOp op,
-                  mlir::PatternRewriter &rewriter) const override {
+  matchAndRewrite(mlir::mesh::NeighborsLinearIndicesOp op, OpAdaptor adaptor,
+                  mlir::ConversionPatternRewriter &rewriter) const override {
 
     // Computes the neighbors indices along a split axis by simply
     // adding/subtracting 1 to the current index in that dimension.
     // Assigns -1 if neighbor is out of bounds.
 
-    auto axes = op.getSplitAxes();
+    auto axes = adaptor.getSplitAxes();
     // For now only single axis sharding is supported
     if (axes.size() != 1) {
       return mlir::failure();
@@ -169,7 +342,7 @@ struct ConvertNeighborsLinearIndicesOp
     auto loc = op.getLoc();
     SymbolTableCollection symbolTableCollection;
     auto meshOp = getMesh(op, symbolTableCollection);
-    auto mIdx = op.getDevice();
+    auto mIdx = adaptor.getDevice();
     auto orgIdx = mIdx[axes[0]];
     SmallVector<Value> dims;
     llvm::transform(
@@ -217,12 +390,12 @@ struct ConvertNeighborsLinearIndicesOp
 };
 
 struct ConvertUpdateHaloOp
-    : public mlir::OpRewritePattern<mlir::mesh::UpdateHaloOp> {
-  using OpRewritePattern::OpRewritePattern;
+    : public mlir::OpConversionPattern<mlir::mesh::UpdateHaloOp> {
+  using OpConversionPattern::OpConversionPattern;
 
   mlir::LogicalResult
-  matchAndRewrite(mlir::mesh::UpdateHaloOp op,
-                  mlir::PatternRewriter &rewriter) const override {
+  matchAndRewrite(mlir::mesh::UpdateHaloOp op, OpAdaptor adaptor,
+                  mlir::ConversionPatternRewriter &rewriter) const override {
 
     // The input/output memref is assumed to be in C memory order.
     // Halos are exchanged as 2 blocks per dimension (one for each side: down
@@ -236,11 +409,11 @@ struct ConvertUpdateHaloOp
     // local data. Because subviews and halos can have mixed dynamic and static
     // shapes, OpFoldResults are used whenever possible.
 
-    auto haloSizes =
-        getMixedValues(op.getStaticHaloSizes(), op.getHaloSizes(), rewriter);
+    auto haloSizes = getMixedValues(adaptor.getStaticHaloSizes(),
+                                    adaptor.getHaloSizes(), rewriter);
     if (haloSizes.empty()) {
       // no halos -> nothing to do
-      rewriter.replaceOp(op, op.getDestination());
+      rewriter.replaceOp(op, adaptor.getDestination());
       return mlir::success();
     }
 
@@ -256,7 +429,7 @@ struct ConvertUpdateHaloOp
                    cast<IntegerAttr>(cast<Attribute>(v)).getInt()));
     };
 
-    auto dest = op.getDestination();
+    auto dest = adaptor.getDestination();
     auto dstShape = cast<ShapedType>(dest.getType()).getShape();
     Value array = dest;
     if (isa<RankedTensorType>(array.getType())) {
@@ -267,8 +440,8 @@ struct ConvertUpdateHaloOp
                   .getResult();
     }
     auto rank = cast<ShapedType>(array.getType()).getRank();
-    auto opSplitAxes = op.getSplitAxes().getAxes();
-    auto mesh = op.getMesh();
+    auto opSplitAxes = adaptor.getSplitAxes().getAxes();
+    auto mesh = adaptor.getMesh();
     auto meshOp = getMesh(op, symbolTableCollection);
     // subviews need Index values
     for (auto &sz : haloSizes) {
@@ -440,14 +613,69 @@ struct ConvertMeshToMPIPass
 
   /// Run the dialect converter on the module.
   void runOnOperation() override {
-    auto *ctx = &getContext();
-    mlir::RewritePatternSet patterns(ctx);
-
-    patterns.insert<ConvertUpdateHaloOp, ConvertNeighborsLinearIndicesOp,
-                    ConvertProcessLinearIndexOp, ConvertProcessMultiIndexOp>(
-        ctx);
+    auto *ctxt = &getContext();
+    mlir::RewritePatternSet patterns(ctxt);
+    ConversionTarget target(getContext());
+    mlir::TypeConverter typeConverter;
+
+    typeConverter.addConversion([](Type type) { return type; });
+    // convert mesh::ShardingType to a tuple of RankedTensorTypes
+    typeConverter.addConversion(
+        [](ShardingType type,
+           SmallVectorImpl<Type> &results) -> std::optional<LogicalResult> {
+          auto i16 = IntegerType::get(type.getContext(), 16);
+          auto i64 = IntegerType::get(type.getContext(), 64);
+          std::array<int64_t, 2> shp{ShapedType::kDynamic,
+                                     ShapedType::kDynamic};
+          results.emplace_back(RankedTensorType::get(shp, i16));
+          results.emplace_back(RankedTensorType::get(shp, i64)); // actually ?x2
+          results.emplace_back(RankedTensorType::get(shp, i64));
+          return mlir::success();
+        });
+    // To extract components a UnrealizedConversionCastOp is expected to define
+    // the input
+    typeConverter.addTargetMaterialization(
+        [&](OpBuilder &builder, TypeRange resultTypes, ValueRange inputs,
+            Location loc) {
+          if (inputs.size() != 1 || !isa<TupleType>(inputs[0].getType())) {
+            return SmallVector<Value>();
+          }
+          auto castOp = inputs[0].getDefiningOp<UnrealizedConversionCastOp>();
+          if (!castOp) {
+            return SmallVector<Value>();
+          }
+          SmallVector<Value> results;
+          for (auto oprnd : castOp.getInputs()) {
+            if (!isa<RankedTensorType>(oprnd.getType())) {
+              return SmallVector<Value>();
+            }
+            results.emplace_back(oprnd);
+          }
+          return results;
+        });
 
-    (void)mlir::applyPatternsGreedily(getOperation(), std::move(patterns));
+    target.addIllegalDialect<mesh::MeshDialect>();
+    target.addLegalOp<mesh::MeshOp>();
+    target.addLegalDialect<BuiltinDialect, mpi::MPIDialect, scf::SCFDialect,
+                           arith::ArithDialect, tensor::TensorDialect,
+                           bufferization::BufferizationDialect,
+                           linalg::LinalgDialect, memref::MemRefDialect>();
+    target.addDynamicallyLegalOp<func::FuncOp>([&](func::FuncOp op) {
+      return typeConverter.isSignatureLegal(op.getFunctionType());
+    });
+    target.addDynamicallyLegalOp<func::CallOp, func::ReturnOp>(
+        [&](Operation *op) { return typeConverter.isLegal(op); });
+
+    patterns.add<ConvertUpdateHaloOp, ConvertNeighborsLinearIndicesOp,
+                 ConvertProcessLinearIndexOp, ConvertProcessMultiIndexOp,
+                 ConvertShardingOp, ConvertGetShardingOp>(typeConverter, ctxt);
+    populateFunctionOpInterfaceTypeConversionPattern<func::FuncOp>(
+        patterns, typeConverter);
+    populateCallOpTypeConversionPattern(patterns, typeConverter);
+    populateReturnOpTypeConversionPattern(patterns, typeConverter);
+
+    (void)mlir::applyPartialConversion(getOperation(), target,
+                                       std::move(patterns));
   }
 };
 
diff --git a/mlir/test/Conversion/MeshToMPI/convert-mesh-to-mpi.mlir b/mlir/test/Conversion/MeshToMPI/convert-mesh-to-mpi.mlir
index 04832ce11b7b3..4ed6a2ffcd683 100644
--- a/mlir/test/Conversion/MeshToMPI/convert-mesh-to-mpi.mlir
+++ b/mlir/test/Conversion/MeshToMPI/convert-mesh-to-mpi.mlir
@@ -230,3 +230,71 @@ func.func @update_halo_3d_tensor(
   // CHECK: return [[v1]] : tensor<120x120x120xi8>
   return %res : tensor<120x120x120xi8>
 }
+
+// -----
+mesh.mesh @mesh0(shape = 2x2x4)
+// CHECK-LABEL: func.func @return_sharding(
+// CHECK-SAME: [[varg0:%.*]]: tensor<2x4xf32>) -> (tensor<2x4xf32>, tensor<?x?xi16>, tensor<?x?xi64>, tensor<?x?xi64>) {
+func.func @return_sharding(%arg0: tensor<2x4xf32>) -> (tensor<2x4xf32>, !mesh.sharding) {
+  %sharding = mesh.sharding @mesh0 split_axes = [[0, 1], [2]] : !mesh.sharding
+  // CHECK-NEXT: [[vcst:%.*]] = arith.constant dense<2> : tensor<1xi16>
+  // CHECK-NEXT: [[vcst_0:%.*]] = arith.constant dense<[0, 1]> : tensor<2xi16>
+  // CHECK-NEXT: [[vcm1_i16:%.*]] = arith.constant -1 : i16
+  // CHECK-NEXT: [[v0:%.*]] = tensor.empty() : tensor<2x2xi16>
+  // CHECK-NEXT: [[v1:%.*]] = linalg.fill ins([[vcm1_i16]] : i16) outs([[v0]] : tensor<2x2xi16>) -> tensor<2x2xi16>
+  // CHECK-NEXT: [[vinserted_slice:%.*]] = tensor.insert_slice [[vcst_0]] into [[v1]][0, 0] [1, 2] [1, 1] : tensor<2xi16> into tensor<2x2xi16>
+  // CHECK-NEXT: [[vinserted_slice_1:%.*]] = tensor.insert_slice [[vcst]] into [[vinserted_slice]][1, 0] [1, 1] [1, 1] : tensor<1xi16> into tensor<2x2xi16>
+  // CHECK-NEXT: [[v2:%.*]] = tensor.empty() : tensor<0x0xi64>
+  // CHECK-NEXT: [[v3:%.*]] = tensor.empty() : tensor<0x0xi64>
+  // CHECK-NEXT: [[vcast:%.*]] = tensor.cast [[vinserted_slice_1]] : tensor<2x2xi16> to tensor<?x?xi16>
+  // CHECK-NEXT: [[vcast_2:%.*]] = tensor.cast [[v2]] : tensor<0x0xi64> to tensor<?x?xi64>
+  // CHECK-NEXT: [[vcast_3:%.*]] = tensor.cast [[v3]] : tensor<0x0xi64> to tensor<?x?xi64>
+  // CHECK-NEXT: return [[varg0]], [[vcast]], [[vcast_2]], [[vcast_3]] : tensor<2x4xf32>, tensor<?x?xi16>, tensor<?x?xi64>, tensor<?x?xi64>
+  return %arg0, %sharding : tensor<2x4xf32>, !mesh.sharding
+}
+
+// CHECK-LABEL: func.func @return_sharding_halos(
+// CHECK-SAME: [[varg0:%.*]]: tensor<6x8xf32>) -> (tensor<6x8xf32>, tensor<?x?xi16>, tensor<?x?xi64>, tensor<?x?xi64>) {
+func.func @return_sharding_halos(%arg0: tensor<6x8xf32>) -> (tensor<6x8xf32>, !mesh.sharding) {
+  %sharding = mesh.sharding @mesh0 split_axes = [[0, 1], [2]] halo_sizes = [0, 4, 3, 1] : !mesh.sharding
+  // CHECK-NEXT: [[vcst:%.*]] = arith.constant dense<{{\[\[}}0, 4], [3, 1]]> : tensor<2x2xi64>
+  // CHECK-NEXT: [[vcst_0:%.*]] = arith.constant dense<2> : tensor<1xi16>
+  // CHECK-NEXT: [[vcst_1:%.*]] = arith.constant dense<[0, 1]> : tensor<2xi16>
+  // CHECK-NEXT: [[vcm1_i16:%.*]] = arith.constant -1 : i16
+  // CHECK-NEXT: [[v0:%.*]] = tensor.empty() : tensor<2x2xi16>
+  // CHECK-NEXT: [[v1:%.*]] = linalg.fill ins([[vcm1_i16]] : i16) outs([[v0]] : tensor<2x2xi16>) -> tensor<2x2xi16>
+  // CHECK-NEXT: [[vinserted_slice:%.*]] = tensor.insert_slice [[vcst_1]] into [[v1]][0, 0] [1, 2] [1, 1] : tensor<2xi16> into tensor<2x2xi16>
+  // CHECK-NEXT: [[vinserted_slice_2:%.*]] = tensor.insert_slice [[vcst_0]] into [[vinserted_slice]][1, 0] [1, 1] [1, 1] : tensor<1xi16> into tensor<2x2xi16>
+  // CHECK-NEXT: [[v2:%.*]] = tensor.empty() : tensor<0x0xi64>
+  // CHECK-NEXT: [[vcast:%.*]] = tensor.cast [[vinserted_slice_2]] : tensor<2x2xi16> to tensor<?x?xi16>
+  // CHECK-NEXT: [[vcast_3:%.*]] = tensor.cast [[vcst]] : tensor<2x2xi64> to tensor<?x?xi64>
+  // CHECK-NEXT: [[vcast_4:%.*]] = tensor.cast [[v2]] : tensor<0x0xi64> to tensor<?x?xi64>
+  // CHECK-NEXT: return [[varg0]], [[vcast]], [[vcast_3]], [[vcast_4]] : tensor<6x8xf32>, tensor<?x?xi16>, tensor<?x?xi64>, tensor<?x?xi64>
+  return %arg0, %sharding : tensor<6x8xf32>, !mesh.sharding
+}
+
+// CHECK-LABEL: func.func @return_sharding_offs(
+// CHECK-SAME: [[varg0:%.*]]: tensor<?x?xf32>) -> (tensor<?x?xf32>, tensor<?x?xi16>, tensor<?x?xi64>, tensor<?x?xi64>) {
+func.func @return_sharding_offs(%arg0: tensor<?x?xf32>) -> (tensor<?x?xf32>, !mesh.sharding) {
+  %sharding = mesh.sharding @mesh0 split_axes = [[0, 1], [2]] sharded_dims_offsets = [0, 3, 5, 7, 8, 0, 0, 5, 10, 16] : !mesh.sharding
+  // CHECK-NEXT: [[vcst:%.*]] = arith.constant dense<[0, 0, 5, 10, 16]> : tensor<5xi64>
+  // CHECK-NEXT: [[vcst_0:%.*]] = arith.constant dense<[0, 3, 5, 7, 8]> : tensor<5xi64>
+  // CHECK-NEXT: [[vcm9223372036854775808_i64:%.*]] = arith.constant -9223372036854775808 : i64
+  // CHECK-NEXT: [[vcst_1:%.*]] = arith.constant dense<2> : tensor<1xi16>
+  // CHECK-NEXT: [[vcst_2:%.*]] = arith.constant dense<[0, 1]> : tensor<2xi16>
+  // CHECK-NEXT: [[vcm1_i16:%.*]] = arith.constant -1 : i16
+  // CHECK-NEXT: [[v0:%.*]] = tensor.empty() : tensor<2x2xi16>
+  // CHECK-NEXT: [[v1:%.*]] = linalg.fill ins([[vcm1_i16]] : i16) outs([[v0]] : tensor<2x2xi16>) -> tensor<2x2xi16>
+  // CHECK-NEXT: [[vinserted_slice:%.*]] = tensor.insert_slice [[vcst_2]] into [[v1]][0, 0] [1, 2] [1, 1] : tensor<2xi16> into tensor<2x2xi16>
+  // CHECK-NEXT: [[vinserted_slice_3:%.*]] = tensor.insert_slice [[vcst_1]] into [[vinserted_slice]][1, 0] [1, 1] [1, 1] : tensor<1xi16> into tensor<2x2xi16>
+  // CHECK-NEXT: [[v2:%.*]] = tensor.empty() : tensor<0x0xi64>
+  // CHECK-NEXT: [[v3:%.*]] = tensor.empty() : tensor<2x5xi64>
+  // CHECK-NEXT: [[v4:%.*]] = linalg.fill ins([[vcm9223372036854775808_i64]] : i64) outs([[v3]] : tensor<2x5xi64>) -> tensor<2x5xi64>
+  // CHECK-NEXT: [[vinserted_slice_4:%.*]] = tensor.insert_slice [[vcst_0]] into [[v4]][0, 0] [1, 5] [1, 1] : tensor<5xi64> into tensor<2x5xi64>
+  // CHECK-NEXT: [[vinserted_slice_5:%.*]] = tensor.insert_slice [[vcst]] into [[vinserted_slice_4]][1, 0] [1, 5] [1, 1] : tensor<5xi64> into tensor<2x5xi64>
+  // CHECK-NEXT: [[vcast:%.*]] = tensor.cast [[vinserted_slice_3]] : tensor<2x2xi16> to tensor<?x?xi16>
+  // CHECK-NEXT: [[vcast_6:%.*]] = tensor.cast [[v2]] : tensor<0x0xi64> to tensor<?x?xi64>
+  // CHECK-NEXT: [[vcast_7:%.*]] = tensor.cast [[vinserted_slice_5]] : tensor<2x5xi64> to tensor<?x?xi64>
+  // CHECK-NEXT: return [[varg0]], [[vcast]], [[vcast_6]], [[vcast_7]] : tensor<?x?xf32>, tensor<?x?xi16>, tensor<?x?xi64>, tensor<?x?xi64>
+  return %arg0, %sharding : tensor<?x?xf32>, !mesh.sharding
+}

>From d4dbd746e0b2f8fc728c00c0053923146d9c74f1 Mon Sep 17 00:00:00 2001
From: "Schlimbach, Frank" <frank.schlimbach at intel.com>
Date: Fri, 24 Jan 2025 11:29:09 +0100
Subject: [PATCH 4/8] fixing convert-mesh-to-mpi.mlir tests

---
 .../MeshToMPI/convert-mesh-to-mpi.mlir        | 118 +++++++++---------
 1 file changed, 59 insertions(+), 59 deletions(-)

diff --git a/mlir/test/Conversion/MeshToMPI/convert-mesh-to-mpi.mlir b/mlir/test/Conversion/MeshToMPI/convert-mesh-to-mpi.mlir
index 4ed6a2ffcd683..6d865ddaf7ce5 100644
--- a/mlir/test/Conversion/MeshToMPI/convert-mesh-to-mpi.mlir
+++ b/mlir/test/Conversion/MeshToMPI/convert-mesh-to-mpi.mlir
@@ -133,7 +133,7 @@ memref.global constant @static_mpi_rank : memref<index> = dense<24>
 func.func @update_halo_3d(
   // CHECK-SAME: [[varg0:%.*]]: memref<120x120x120xi8>
   %arg0 : memref<120x120x120xi8>) -> memref<120x120x120xi8> {
-  // CHECK: [[vc23_i32:%.*]] = arith.constant 23 : i32
+  // CHECK-NEXT: [[vc23_i32:%.*]] = arith.constant 23 : i32
   // CHECK-NEXT: [[vc29_i32:%.*]] = arith.constant 29 : i32
   // CHECK-NEXT: [[vc91_i32:%.*]] = arith.constant 91 : i32
   // CHECK-NEXT: [[vc4_i32:%.*]] = arith.constant 4 : i32
@@ -141,39 +141,39 @@ func.func @update_halo_3d(
   // CHECK-NEXT: [[valloc:%.*]] = memref.alloc() : memref<117x113x5xi8>
   // CHECK-NEXT: [[vsubview:%.*]] = memref.subview [[varg0]][1, 3, 109] [117, 113, 5] [1, 1, 1] : memref<120x120x120xi8> to memref<117x113x5xi8, strided<[14400, 120, 1], offset: 14869>>
   // CHECK-NEXT: memref.copy [[vsubview]], [[valloc]] : memref<117x113x5xi8, strided<[14400, 120, 1], offset: 14869>> to memref<117x113x5xi8>
-  // CHECK-NEXT: mpi.send([[valloc]], [[vc91_i32]], [[vc4_i32]]) : memref<117x113x5xi8>, i32, i32
-  // CHECK-NEXT: mpi.recv([[valloc]], [[vc91_i32]], [[vc44_i32]]) : memref<117x113x5xi8>, i32, i32
+  // CHECK-NEXT: mpi.send([[valloc]], [[vc91_i32]], [[vc44_i32]]) : memref<117x113x5xi8>, i32, i32
+  // CHECK-NEXT: mpi.recv([[valloc]], [[vc91_i32]], [[vc4_i32]]) : memref<117x113x5xi8>, i32, i32
   // CHECK-NEXT: [[vsubview_0:%.*]] = memref.subview [[varg0]][1, 3, 0] [117, 113, 5] [1, 1, 1] : memref<120x120x120xi8> to memref<117x113x5xi8, strided<[14400, 120, 1], offset: 14760>>
   // CHECK-NEXT: memref.copy [[valloc]], [[vsubview_0]] : memref<117x113x5xi8> to memref<117x113x5xi8, strided<[14400, 120, 1], offset: 14760>>
   // CHECK-NEXT: memref.dealloc [[valloc]] : memref<117x113x5xi8>
   // CHECK-NEXT: [[valloc_1:%.*]] = memref.alloc() : memref<117x113x6xi8>
-  // CHECK-NEXT: [[vsubview_3:%.*]] = memref.subview [[varg0]][1, 3, 5] [117, 113, 6] [1, 1, 1] : memref<120x120x120xi8> to memref<117x113x6xi8, strided<[14400, 120, 1], offset: 14765>>
-  // CHECK-NEXT: memref.copy [[vsubview_3]], [[valloc_1]] : memref<117x113x6xi8, strided<[14400, 120, 1], offset: 14765>> to memref<117x113x6xi8>
-  // CHECK-NEXT: mpi.send([[valloc_1]], [[vc91_i32]], [[vc44_i32]]) : memref<117x113x6xi8>, i32, i32
-  // CHECK-NEXT: mpi.recv([[valloc_1]], [[vc91_i32]], [[vc4_i32]]) : memref<117x113x6xi8>, i32, i32
-  // CHECK-NEXT: [[vsubview_4:%.*]] = memref.subview [[varg0]][1, 3, 114] [117, 113, 6] [1, 1, 1] : memref<120x120x120xi8> to memref<117x113x6xi8, strided<[14400, 120, 1], offset: 14874>>
-  // CHECK-NEXT: memref.copy [[valloc_1]], [[vsubview_4]] : memref<117x113x6xi8> to memref<117x113x6xi8, strided<[14400, 120, 1], offset: 14874>>
+  // CHECK-NEXT: [[vsubview_2:%.*]] = memref.subview [[varg0]][1, 3, 5] [117, 113, 6] [1, 1, 1] : memref<120x120x120xi8> to memref<117x113x6xi8, strided<[14400, 120, 1], offset: 14765>>
+  // CHECK-NEXT: memref.copy [[vsubview_2]], [[valloc_1]] : memref<117x113x6xi8, strided<[14400, 120, 1], offset: 14765>> to memref<117x113x6xi8>
+  // CHECK-NEXT: mpi.send([[valloc_1]], [[vc91_i32]], [[vc4_i32]]) : memref<117x113x6xi8>, i32, i32
+  // CHECK-NEXT: mpi.recv([[valloc_1]], [[vc91_i32]], [[vc44_i32]]) : memref<117x113x6xi8>, i32, i32
+  // CHECK-NEXT: [[vsubview_3:%.*]] = memref.subview [[varg0]][1, 3, 114] [117, 113, 6] [1, 1, 1] : memref<120x120x120xi8> to memref<117x113x6xi8, strided<[14400, 120, 1], offset: 14874>>
+  // CHECK-NEXT: memref.copy [[valloc_1]], [[vsubview_3]] : memref<117x113x6xi8> to memref<117x113x6xi8, strided<[14400, 120, 1], offset: 14874>>
   // CHECK-NEXT: memref.dealloc [[valloc_1]] : memref<117x113x6xi8>
-  // CHECK-NEXT: [[valloc_5:%.*]] = memref.alloc() : memref<117x3x120xi8>
-  // CHECK-NEXT: mpi.recv([[valloc_5]], [[vc91_i32]], [[vc29_i32]]) : memref<117x3x120xi8>, i32, i32
-  // CHECK-NEXT: [[vsubview_7:%.*]] = memref.subview [[varg0]][1, 0, 0] [117, 3, 120] [1, 1, 1] : memref<120x120x120xi8> to memref<117x3x120xi8, strided<[14400, 120, 1], offset: 14400>>
-  // CHECK-NEXT: memref.copy [[valloc_5]], [[vsubview_7]] : memref<117x3x120xi8> to memref<117x3x120xi8, strided<[14400, 120, 1], offset: 14400>>
-  // CHECK-NEXT: memref.dealloc [[valloc_5]] : memref<117x3x120xi8>
-  // CHECK-NEXT: [[valloc_8:%.*]] = memref.alloc() : memref<117x4x120xi8>
-  // CHECK-NEXT: [[vsubview_10:%.*]] = memref.subview [[varg0]][1, 3, 0] [117, 4, 120] [1, 1, 1] : memref<120x120x120xi8> to memref<117x4x120xi8, strided<[14400, 120, 1], offset: 14760>>
-  // CHECK-NEXT: memref.copy [[vsubview_10]], [[valloc_8]] : memref<117x4x120xi8, strided<[14400, 120, 1], offset: 14760>> to memref<117x4x120xi8>
-  // CHECK-NEXT: mpi.send([[valloc_8]], [[vc91_i32]], [[vc29_i32]]) : memref<117x4x120xi8>, i32, i32
-  // CHECK-NEXT: memref.dealloc [[valloc_8]] : memref<117x4x120xi8>
-  // CHECK-NEXT: [[valloc_11:%.*]] = memref.alloc() : memref<1x120x120xi8>
-  // CHECK-NEXT: [[vsubview_12:%.*]] = memref.subview [[varg0]][117, 0, 0] [1, 120, 120] [1, 1, 1] : memref<120x120x120xi8> to memref<1x120x120xi8, strided<[14400, 120, 1], offset: 1684800>>
-  // CHECK-NEXT: memref.copy [[vsubview_12]], [[valloc_11]] : memref<1x120x120xi8, strided<[14400, 120, 1], offset: 1684800>> to memref<1x120x120xi8>
-  // CHECK-NEXT: mpi.send([[valloc_11]], [[vc91_i32]], [[vc23_i32]]) : memref<1x120x120xi8>, i32, i32
-  // CHECK-NEXT: memref.dealloc [[valloc_11]] : memref<1x120x120xi8>
-  // CHECK-NEXT: [[valloc_13:%.*]] = memref.alloc() : memref<2x120x120xi8>
-  // CHECK-NEXT: mpi.recv([[valloc_13]], [[vc91_i32]], [[vc23_i32]]) : memref<2x120x120xi8>, i32, i32
-  // CHECK-NEXT: [[vsubview_14:%.*]] = memref.subview [[varg0]][118, 0, 0] [2, 120, 120] [1, 1, 1] : memref<120x120x120xi8> to memref<2x120x120xi8, strided<[14400, 120, 1], offset: 1699200>>
-  // CHECK-NEXT: memref.copy [[valloc_13]], [[vsubview_14]] : memref<2x120x120xi8> to memref<2x120x120xi8, strided<[14400, 120, 1], offset: 1699200>>
-  // CHECK-NEXT: memref.dealloc [[valloc_13]] : memref<2x120x120xi8>
+  // CHECK-NEXT: [[valloc_4:%.*]] = memref.alloc() : memref<117x3x120xi8>
+  // CHECK-NEXT: [[vsubview_5:%.*]] = memref.subview [[varg0]][1, 113, 0] [117, 3, 120] [1, 1, 1] : memref<120x120x120xi8> to memref<117x3x120xi8, strided<[14400, 120, 1], offset: 27960>>
+  // CHECK-NEXT: memref.copy [[vsubview_5]], [[valloc_4]] : memref<117x3x120xi8, strided<[14400, 120, 1], offset: 27960>> to memref<117x3x120xi8>
+  // CHECK-NEXT: mpi.send([[valloc_4]], [[vc91_i32]], [[vc29_i32]]) : memref<117x3x120xi8>, i32, i32
+  // CHECK-NEXT: memref.dealloc [[valloc_4]] : memref<117x3x120xi8>
+  // CHECK-NEXT: [[valloc_6:%.*]] = memref.alloc() : memref<117x4x120xi8>
+  // CHECK-NEXT: mpi.recv([[valloc_6]], [[vc91_i32]], [[vc29_i32]]) : memref<117x4x120xi8>, i32, i32
+  // CHECK-NEXT: [[vsubview_7:%.*]] = memref.subview [[varg0]][1, 116, 0] [117, 4, 120] [1, 1, 1] : memref<120x120x120xi8> to memref<117x4x120xi8, strided<[14400, 120, 1], offset: 28320>>
+  // CHECK-NEXT: memref.copy [[valloc_6]], [[vsubview_7]] : memref<117x4x120xi8> to memref<117x4x120xi8, strided<[14400, 120, 1], offset: 28320>>
+  // CHECK-NEXT: memref.dealloc [[valloc_6]] : memref<117x4x120xi8>
+  // CHECK-NEXT: [[valloc_8:%.*]] = memref.alloc() : memref<1x120x120xi8>
+  // CHECK-NEXT: mpi.recv([[valloc_8]], [[vc91_i32]], [[vc23_i32]]) : memref<1x120x120xi8>, i32, i32
+  // CHECK-NEXT: [[vsubview_9:%.*]] = memref.subview [[varg0]][0, 0, 0] [1, 120, 120] [1, 1, 1] : memref<120x120x120xi8> to memref<1x120x120xi8, strided<[14400, 120, 1]>>
+  // CHECK-NEXT: memref.copy [[valloc_8]], [[vsubview_9]] : memref<1x120x120xi8> to memref<1x120x120xi8, strided<[14400, 120, 1]>>
+  // CHECK-NEXT: memref.dealloc [[valloc_8]] : memref<1x120x120xi8>
+  // CHECK-NEXT: [[valloc_10:%.*]] = memref.alloc() : memref<2x120x120xi8>
+  // CHECK-NEXT: [[vsubview_11:%.*]] = memref.subview [[varg0]][1, 0, 0] [2, 120, 120] [1, 1, 1] : memref<120x120x120xi8> to memref<2x120x120xi8, strided<[14400, 120, 1], offset: 14400>>
+  // CHECK-NEXT: memref.copy [[vsubview_11]], [[valloc_10]] : memref<2x120x120xi8, strided<[14400, 120, 1], offset: 14400>> to memref<2x120x120xi8>
+  // CHECK-NEXT: mpi.send([[valloc_10]], [[vc91_i32]], [[vc23_i32]]) : memref<2x120x120xi8>, i32, i32
+  // CHECK-NEXT: memref.dealloc [[valloc_10]] : memref<2x120x120xi8>
   %res = mesh.update_halo %arg0 on @mesh0 split_axes = [[2], [1], [0]] halo_sizes = [1, 2, 3, 4, 5, 6] : memref<120x120x120xi8>
   // CHECK: return [[varg0]] : memref<120x120x120xi8>
   return %res : memref<120x120x120xi8>
@@ -183,7 +183,7 @@ func.func @update_halo_3d(
 func.func @update_halo_3d_tensor(
   // CHECK-SAME: [[varg0:%.*]]: tensor<120x120x120xi8>
   %arg0 : tensor<120x120x120xi8>) -> tensor<120x120x120xi8> {
-  // CHECK: [[vc23_i32:%.*]] = arith.constant 23 : i32
+  // CHECK-NEXT: [[vc23_i32:%.*]] = arith.constant 23 : i32
   // CHECK-NEXT: [[vc29_i32:%.*]] = arith.constant 29 : i32
   // CHECK-NEXT: [[vc44_i32:%.*]] = arith.constant 44 : i32
   // CHECK-NEXT: [[vc4_i32:%.*]] = arith.constant 4 : i32
@@ -192,40 +192,40 @@ func.func @update_halo_3d_tensor(
   // CHECK-NEXT: [[valloc:%.*]] = memref.alloc() : memref<117x113x5xi8>
   // CHECK-NEXT: [[vsubview:%.*]] = memref.subview [[v0]][1, 3, 109] [117, 113, 5] [1, 1, 1] : memref<120x120x120xi8> to memref<117x113x5xi8, strided<[14400, 120, 1], offset: 14869>>
   // CHECK-NEXT: memref.copy [[vsubview]], [[valloc]] : memref<117x113x5xi8, strided<[14400, 120, 1], offset: 14869>> to memref<117x113x5xi8>
-  // CHECK-NEXT: mpi.send([[valloc]], [[vc91_i32]], [[vc4_i32]]) : memref<117x113x5xi8>, i32, i32
-  // CHECK-NEXT: mpi.recv([[valloc]], [[vc91_i32]], [[vc44_i32]]) : memref<117x113x5xi8>, i32, i32
+  // CHECK-NEXT: mpi.send([[valloc]], [[vc91_i32]], [[vc44_i32]]) : memref<117x113x5xi8>, i32, i32
+  // CHECK-NEXT: mpi.recv([[valloc]], [[vc91_i32]], [[vc4_i32]]) : memref<117x113x5xi8>, i32, i32
   // CHECK-NEXT: [[vsubview_0:%.*]] = memref.subview [[v0]][1, 3, 0] [117, 113, 5] [1, 1, 1] : memref<120x120x120xi8> to memref<117x113x5xi8, strided<[14400, 120, 1], offset: 14760>>
   // CHECK-NEXT: memref.copy [[valloc]], [[vsubview_0]] : memref<117x113x5xi8> to memref<117x113x5xi8, strided<[14400, 120, 1], offset: 14760>>
   // CHECK-NEXT: memref.dealloc [[valloc]] : memref<117x113x5xi8>
   // CHECK-NEXT: [[valloc_1:%.*]] = memref.alloc() : memref<117x113x6xi8>
-  // CHECK-NEXT: [[vsubview_3:%.*]] = memref.subview [[v0]][1, 3, 5] [117, 113, 6] [1, 1, 1] : memref<120x120x120xi8> to memref<117x113x6xi8, strided<[14400, 120, 1], offset: 14765>>
-  // CHECK-NEXT: memref.copy [[vsubview_3]], [[valloc_1]] : memref<117x113x6xi8, strided<[14400, 120, 1], offset: 14765>> to memref<117x113x6xi8>
-  // CHECK-NEXT: mpi.send([[valloc_1]], [[vc91_i32]], [[vc44_i32]]) : memref<117x113x6xi8>, i32, i32
-  // CHECK-NEXT: mpi.recv([[valloc_1]], [[vc91_i32]], [[vc4_i32]]) : memref<117x113x6xi8>, i32, i32
-  // CHECK-NEXT: [[vsubview_4:%.*]] = memref.subview [[v0]][1, 3, 114] [117, 113, 6] [1, 1, 1] : memref<120x120x120xi8> to memref<117x113x6xi8, strided<[14400, 120, 1], offset: 14874>>
-  // CHECK-NEXT: memref.copy [[valloc_1]], [[vsubview_4]] : memref<117x113x6xi8> to memref<117x113x6xi8, strided<[14400, 120, 1], offset: 14874>>
+  // CHECK-NEXT: [[vsubview_2:%.*]] = memref.subview [[v0]][1, 3, 5] [117, 113, 6] [1, 1, 1] : memref<120x120x120xi8> to memref<117x113x6xi8, strided<[14400, 120, 1], offset: 14765>>
+  // CHECK-NEXT: memref.copy [[vsubview_2]], [[valloc_1]] : memref<117x113x6xi8, strided<[14400, 120, 1], offset: 14765>> to memref<117x113x6xi8>
+  // CHECK-NEXT: mpi.send([[valloc_1]], [[vc91_i32]], [[vc4_i32]]) : memref<117x113x6xi8>, i32, i32
+  // CHECK-NEXT: mpi.recv([[valloc_1]], [[vc91_i32]], [[vc44_i32]]) : memref<117x113x6xi8>, i32, i32
+  // CHECK-NEXT: [[vsubview_3:%.*]] = memref.subview [[v0]][1, 3, 114] [117, 113, 6] [1, 1, 1] : memref<120x120x120xi8> to memref<117x113x6xi8, strided<[14400, 120, 1], offset: 14874>>
+  // CHECK-NEXT: memref.copy [[valloc_1]], [[vsubview_3]] : memref<117x113x6xi8> to memref<117x113x6xi8, strided<[14400, 120, 1], offset: 14874>>
   // CHECK-NEXT: memref.dealloc [[valloc_1]] : memref<117x113x6xi8>
-  // CHECK-NEXT: [[valloc_5:%.*]] = memref.alloc() : memref<117x3x120xi8>
-  // CHECK-NEXT: mpi.recv([[valloc_5]], [[vc91_i32]], [[vc29_i32]]) : memref<117x3x120xi8>, i32, i32
-  // CHECK-NEXT: [[vsubview_7:%.*]] = memref.subview [[v0]][1, 0, 0] [117, 3, 120] [1, 1, 1] : memref<120x120x120xi8> to memref<117x3x120xi8, strided<[14400, 120, 1], offset: 14400>>
-  // CHECK-NEXT: memref.copy [[valloc_5]], [[vsubview_7]] : memref<117x3x120xi8> to memref<117x3x120xi8, strided<[14400, 120, 1], offset: 14400>>
-  // CHECK-NEXT: memref.dealloc [[valloc_5]] : memref<117x3x120xi8>
-  // CHECK-NEXT: [[valloc_8:%.*]] = memref.alloc() : memref<117x4x120xi8>
-  // CHECK-NEXT: [[vsubview_10:%.*]] = memref.subview [[v0]][1, 3, 0] [117, 4, 120] [1, 1, 1] : memref<120x120x120xi8> to memref<117x4x120xi8, strided<[14400, 120, 1], offset: 14760>>
-  // CHECK-NEXT: memref.copy [[vsubview_10]], [[valloc_8]] : memref<117x4x120xi8, strided<[14400, 120, 1], offset: 14760>> to memref<117x4x120xi8>
-  // CHECK-NEXT: mpi.send([[valloc_8]], [[vc91_i32]], [[vc29_i32]]) : memref<117x4x120xi8>, i32, i32
-  // CHECK-NEXT: memref.dealloc [[valloc_8]] : memref<117x4x120xi8>
-  // CHECK-NEXT: [[valloc_11:%.*]] = memref.alloc() : memref<1x120x120xi8>
-  // CHECK-NEXT: [[vsubview_12:%.*]] = memref.subview [[v0]][117, 0, 0] [1, 120, 120] [1, 1, 1] : memref<120x120x120xi8> to memref<1x120x120xi8, strided<[14400, 120, 1], offset: 1684800>>
-  // CHECK-NEXT: memref.copy [[vsubview_12]], [[valloc_11]] : memref<1x120x120xi8, strided<[14400, 120, 1], offset: 1684800>> to memref<1x120x120xi8>
-  // CHECK-NEXT: mpi.send([[valloc_11]], [[vc91_i32]], [[vc23_i32]]) : memref<1x120x120xi8>, i32, i32
-  // CHECK-NEXT: memref.dealloc [[valloc_11]] : memref<1x120x120xi8>
-  // CHECK-NEXT: [[valloc_13:%.*]] = memref.alloc() : memref<2x120x120xi8>
-  // CHECK-NEXT: mpi.recv([[valloc_13]], [[vc91_i32]], [[vc23_i32]]) : memref<2x120x120xi8>, i32, i32
-  // CHECK-NEXT: [[vsubview_14:%.*]] = memref.subview [[v0]][118, 0, 0] [2, 120, 120] [1, 1, 1] : memref<120x120x120xi8> to memref<2x120x120xi8, strided<[14400, 120, 1], offset: 1699200>>
-  // CHECK-NEXT: memref.copy [[valloc_13]], [[vsubview_14]] : memref<2x120x120xi8> to memref<2x120x120xi8, strided<[14400, 120, 1], offset: 1699200>>
-  // CHECK-NEXT: memref.dealloc [[valloc_13]] : memref<2x120x120xi8>
-  // CHECK-NEXT: [[v1:%.*]] = bufferization.to_tensor [[v0]] restrict writable : memref<120x120x120xi8>
+  // CHECK-NEXT: [[valloc_4:%.*]] = memref.alloc() : memref<117x3x120xi8>
+  // CHECK-NEXT: [[vsubview_5:%.*]] = memref.subview [[v0]][1, 113, 0] [117, 3, 120] [1, 1, 1] : memref<120x120x120xi8> to memref<117x3x120xi8, strided<[14400, 120, 1], offset: 27960>>
+  // CHECK-NEXT: memref.copy [[vsubview_5]], [[valloc_4]] : memref<117x3x120xi8, strided<[14400, 120, 1], offset: 27960>> to memref<117x3x120xi8>
+  // CHECK-NEXT: mpi.send([[valloc_4]], [[vc91_i32]], [[vc29_i32]]) : memref<117x3x120xi8>, i32, i32
+  // CHECK-NEXT: memref.dealloc [[valloc_4]] : memref<117x3x120xi8>
+  // CHECK-NEXT: [[valloc_6:%.*]] = memref.alloc() : memref<117x4x120xi8>
+  // CHECK-NEXT: mpi.recv([[valloc_6]], [[vc91_i32]], [[vc29_i32]]) : memref<117x4x120xi8>, i32, i32
+  // CHECK-NEXT: [[vsubview_7:%.*]] = memref.subview [[v0]][1, 116, 0] [117, 4, 120] [1, 1, 1] : memref<120x120x120xi8> to memref<117x4x120xi8, strided<[14400, 120, 1], offset: 28320>>
+  // CHECK-NEXT: memref.copy [[valloc_6]], [[vsubview_7]] : memref<117x4x120xi8> to memref<117x4x120xi8, strided<[14400, 120, 1], offset: 28320>>
+  // CHECK-NEXT: memref.dealloc [[valloc_6]] : memref<117x4x120xi8>
+  // CHECK-NEXT: [[valloc_8:%.*]] = memref.alloc() : memref<1x120x120xi8>
+  // CHECK-NEXT: mpi.recv([[valloc_8]], [[vc91_i32]], [[vc23_i32]]) : memref<1x120x120xi8>, i32, i32
+  // CHECK-NEXT: [[vsubview_9:%.*]] = memref.subview [[v0]][0, 0, 0] [1, 120, 120] [1, 1, 1] : memref<120x120x120xi8> to memref<1x120x120xi8, strided<[14400, 120, 1]>>
+  // CHECK-NEXT: memref.copy [[valloc_8]], [[vsubview_9]] : memref<1x120x120xi8> to memref<1x120x120xi8, strided<[14400, 120, 1]>>
+  // CHECK-NEXT: memref.dealloc [[valloc_8]] : memref<1x120x120xi8>
+  // CHECK-NEXT: [[valloc_10:%.*]] = memref.alloc() : memref<2x120x120xi8>
+  // CHECK-NEXT: [[vsubview_11:%.*]] = memref.subview [[v0]][1, 0, 0] [2, 120, 120] [1, 1, 1] : memref<120x120x120xi8> to memref<2x120x120xi8, strided<[14400, 120, 1], offset: 14400>>
+  // CHECK-NEXT: memref.copy [[vsubview_11]], [[valloc_10]] : memref<2x120x120xi8, strided<[14400, 120, 1], offset: 14400>> to memref<2x120x120xi8>
+  // CHECK-NEXT: mpi.send([[valloc_10]], [[vc91_i32]], [[vc23_i32]]) : memref<2x120x120xi8>, i32, i32
+  // CHECK-NEXT: memref.dealloc [[valloc_10]] : memref<2x120x120xi8>
+  // CHECK-NEXT: [[v1:%.*]] = bufferization.to_tensor [[v0]] restrict writable : memref<120x120x120xi8> to tensor<120x120x120xi8>
   %res = mesh.update_halo %arg0 on @mesh0 split_axes = [[2], [1], [0]] halo_sizes = [1, 2, 3, 4, 5, 6] : tensor<120x120x120xi8>
   // CHECK: return [[v1]] : tensor<120x120x120xi8>
   return %res : tensor<120x120x120xi8>

>From 806ffd4207acf157f6a6def219b3e06da6e003c7 Mon Sep 17 00:00:00 2001
From: "Schlimbach, Frank" <frank.schlimbach at intel.com>
Date: Mon, 24 Feb 2025 12:50:23 +0100
Subject: [PATCH 5/8] clang-format

---
 mlir/lib/Conversion/MeshToMPI/MeshToMPI.cpp | 22 +++++++++++----------
 1 file changed, 12 insertions(+), 10 deletions(-)

diff --git a/mlir/lib/Conversion/MeshToMPI/MeshToMPI.cpp b/mlir/lib/Conversion/MeshToMPI/MeshToMPI.cpp
index 91f235673f817..21f1ba411dafb 100644
--- a/mlir/lib/Conversion/MeshToMPI/MeshToMPI.cpp
+++ b/mlir/lib/Conversion/MeshToMPI/MeshToMPI.cpp
@@ -573,18 +573,20 @@ struct ConvertUpdateHaloOp
         return v.is<Value>()
                    ? v.get<Value>()
                    : rewriter.create<::mlir::arith::ConstantOp>(
-                         loc, rewriter.getI32IntegerAttr(
-                                  cast<IntegerAttr>(v.get<Attribute>()).getInt()));
+                         loc,
+                         rewriter.getI32IntegerAttr(
+                             cast<IntegerAttr>(v.get<Attribute>()).getInt()));
       };
 
-      for (int i=0; i<2; ++i) {
-        Value haloSz = get_i32val(haloSizes[currHaloDim * 2 + i]);                              
-        auto hasSize = rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::sgt, haloSz, zero);
-        rewriter.create<scf::IfOp>(
-            loc, hasSize, [&](OpBuilder &builder, Location loc) {
-              genSendRecv(i > 0);
-              builder.create<scf::YieldOp>(loc);
-            });
+      for (int i = 0; i < 2; ++i) {
+        Value haloSz = get_i32val(haloSizes[currHaloDim * 2 + i]);
+        auto hasSize = rewriter.create<arith::CmpIOp>(
+            loc, arith::CmpIPredicate::sgt, haloSz, zero);
+        rewriter.create<scf::IfOp>(loc, hasSize,
+                                   [&](OpBuilder &builder, Location loc) {
+                                     genSendRecv(i > 0);
+                                     builder.create<scf::YieldOp>(loc);
+                                   });
       }
 
       // the shape for lower dims include higher dims' halos

>From 35f31d840023cc9fdf4050f0604586141794d848 Mon Sep 17 00:00:00 2001
From: "Schlimbach, Frank" <frank.schlimbach at intel.com>
Date: Wed, 26 Feb 2025 12:15:58 +0100
Subject: [PATCH 6/8] fixes to and lowering of mesh.shard_shape

---
 mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td  |  22 ++-
 mlir/lib/Conversion/MeshToMPI/MeshToMPI.cpp   | 176 +++++++++++++++---
 mlir/lib/Dialect/Mesh/IR/MeshOps.cpp          |  15 +-
 .../Extensions/MeshShardingExtensions.cpp     |  19 +-
 .../MeshToMPI/convert-shardshape-to-mpi.mlir  |  75 ++++++++
 mlir/test/Dialect/Mesh/ops.mlir               |  10 +-
 .../test/Dialect/Tensor/mesh-spmdization.mlir |  11 +-
 7 files changed, 275 insertions(+), 53 deletions(-)
 create mode 100644 mlir/test/Conversion/MeshToMPI/convert-shardshape-to-mpi.mlir

diff --git a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td
index 031e6f63bcb42..f59c4c4c67517 100644
--- a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td
+++ b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td
@@ -345,24 +345,32 @@ def Mesh_GetShardingOp : Mesh_Op<"get_sharding", [Pure]> {
   }];
 }
 
-def Mesh_ShardShapeOp : Mesh_Op<"shard_shape", [Pure]> {
-  let summary = "Get the shard shape of a given process/device.";
+def Mesh_ShardShapeOp : Mesh_Op<"shard_shape", [
+    Pure, AttrSizedOperandSegments,
+    DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>
+  ]> {
+  let summary = "Get the shard shape for a given process/device.";
   let description = [{
-    The device/process id is a linearized id of the device/process in the mesh.
+    The device/process id is a multi-index of the device/process in the mesh.
     This operation might be used during spmdization when the shard shape depends
     on (non-constant) values used in `mesh.sharding`.
   }];
   let arguments = (ins
-    DenseI64ArrayAttr:$shape,
+    DenseI64ArrayAttr:$dims,
+    Variadic<Index>:$dims_dynamic,
     Mesh_Sharding:$sharding,
-    Index:$device
+    DenseI64ArrayAttr:$device,
+    Variadic<Index>:$device_dynamic
   );
   let results = (outs Variadic<Index>:$result);
   let assemblyFormat = [{
-      custom<DimensionList>($shape) $sharding $device attr-dict `:` type($result)
+      `dims` `=` custom<DynamicIndexList>($dims_dynamic, $dims)
+      `sharding` `=` $sharding
+      `device` `=` custom<DynamicIndexList>($device_dynamic, $device)
+      attr-dict `:` type(results)
   }];
   let builders = [
-    OpBuilder<(ins "ArrayRef<int64_t>":$shape, "Value":$sharding, "Value":$device)>
+    OpBuilder<(ins "ArrayRef<int64_t>":$dims, "ArrayRef<Value>":$dims_dyn, "Value":$sharding, "ValueRange":$device)>
   ];
 }
 
diff --git a/mlir/lib/Conversion/MeshToMPI/MeshToMPI.cpp b/mlir/lib/Conversion/MeshToMPI/MeshToMPI.cpp
index 21f1ba411dafb..b94181dd4f061 100644
--- a/mlir/lib/Conversion/MeshToMPI/MeshToMPI.cpp
+++ b/mlir/lib/Conversion/MeshToMPI/MeshToMPI.cpp
@@ -44,6 +44,29 @@ using namespace mlir;
 using namespace mlir::mesh;
 
 namespace {
+/// Materializes mixed static/dynamic ints as Values: static entries become
+/// arith.constant of `type` (i64 if null), dynamic ones consume `dynamics`.
+static SmallVector<Value> getMixedAsValues(OpBuilder b, const Location &loc,
+                                           llvm::ArrayRef<int64_t> statics,
+                                           ValueRange dynamics,
+                                           Type type = Type()) {
+  SmallVector<Value> values;
+  auto dyn = dynamics.begin();
+  Type i64 = b.getI64Type();
+  type = type ? type : i64;
+  assert((i64 == type || b.getIndexType() == type) && "expected i64 or index");
+  for (int64_t s : statics) {
+    if (ShapedType::isDynamic(s)) {
+      assert(dyn != dynamics.end() && "too few dynamic values");
+      values.emplace_back(*(dyn++));
+    } else {
+      auto attr = i64 == type ? b.getI64IntegerAttr(s) : b.getIndexAttr(s);
+      values.emplace_back(b.create<arith::ConstantOp>(loc, type, attr));
+    }
+  }
+  return values;
+}
+
 // Create operations converting a linear index to a multi-dimensional index
 static SmallVector<Value> linearToMultiIndex(Location loc, OpBuilder b,
                                              Value linearIndex,
@@ -145,27 +168,11 @@ struct ConvertShardingOp
           loc, vals, resSplitAxes, empty, empty, empty, offs, sizes, strides);
     }
 
-    // convert vec of OpFoldResults (ints) into vec of Values
-    auto getMixedAsValues = [&](llvm::ArrayRef<int64_t> statics,
-                                ValueRange dynamics) {
-      SmallVector<Value> values;
-      auto dyn = dynamics.begin();
-      for (auto s : statics) {
-        values.emplace_back(
-            ShapedType::isDynamic(s)
-                ? *(dyn++)
-                : rewriter
-                      .create<arith::ConstantOp>(loc, i64,
-                                                 rewriter.getI64IntegerAttr(s))
-                      .getResult());
-      }
-      return values;
-    };
-
     // To hold halos sizes, create 2d Tensor with shape {nSplits, 2}.
     // Store the halo sizes in the tensor.
-    auto haloSizes = getMixedAsValues(adaptor.getStaticHaloSizes(),
-                                      adaptor.getDynamicHaloSizes());
+    auto haloSizes =
+        getMixedAsValues(rewriter, loc, adaptor.getStaticHaloSizes(),
+                         adaptor.getDynamicHaloSizes());
     auto type = RankedTensorType::get({nSplits, 2}, i64);
     Value resHaloSizes =
         haloSizes.empty()
@@ -207,8 +214,9 @@ struct ConvertShardingOp
               .getResult();
       resOffsets =
           rewriter.create<linalg::FillOp>(loc, zero, resOffsets).getResult(0);
-      auto offsets = getMixedAsValues(adaptor.getStaticShardedDimsOffsets(),
-                                      adaptor.getDynamicShardedDimsOffsets());
+      auto offsets =
+          getMixedAsValues(rewriter, loc, adaptor.getStaticShardedDimsOffsets(),
+                           adaptor.getDynamicShardedDimsOffsets());
       int64_t curr = 0;
       for (auto [i, axes] : llvm::enumerate(splitAxes)) {
         auto splitSize =
@@ -389,6 +397,123 @@ struct ConvertNeighborsLinearIndicesOp
   }
 };
 
+/// Lowers mesh.shard_shape to arith/tensor ops computing, per dimension, the
+/// size of the shard owned by the device with the given multi-index. The
+/// sharding must be defined by a mesh.sharding op on a statically shaped mesh.
+struct ConvertShardShapeOp
+    : public mlir::OpConversionPattern<mlir::mesh::ShardShapeOp> {
+  using OpConversionPattern::OpConversionPattern;
+
+  mlir::LogicalResult
+  matchAndRewrite(mlir::mesh::ShardShapeOp op, OneToNOpAdaptor adaptor,
+                  mlir::ConversionPatternRewriter &rewriter) const override {
+    auto sharding = op.getSharding().getDefiningOp<ShardingOp>();
+    if (!sharding) {
+      // Stream the Value: a block argument has no defining op to print.
+      return op->emitError()
+             << "Expected ShardingOp as defining op for sharding"
+             << " but found " << Value(op.getSharding());
+    }
+
+    auto loc = op.getLoc();
+    Type index = rewriter.getIndexType();
+
+    // Collect the dynamic dims/device operands. The type conversion should be
+    // 1:1 for ints, so each converted operand is expected to be one value.
+    SmallVector<Value> dynDims, dynDevice;
+    for (auto dim : adaptor.getDimsDynamic()) {
+      assert(dim.size() == 1);
+      dynDims.emplace_back(dim[0]);
+    }
+    for (auto device : adaptor.getDeviceDynamic()) {
+      assert(device.size() == 1);
+      dynDevice.emplace_back(device[0]);
+    }
+
+    // To keep the code simple, we convert dims/device to values when they are
+    // attributes.
+    auto shape = getMixedAsValues(rewriter, loc, op.getDims(), dynDims, index);
+    auto multiIdx =
+        getMixedAsValues(rewriter, loc, adaptor.getDevice(), dynDevice, index);
+
+    SymbolTableCollection symbolTableCollection;
+    auto meshOp = getMesh(sharding, symbolTableCollection);
+    // For now we only support static mesh shapes.
+    if (ShapedType::isDynamicShape(meshOp.getShape()))
+      return mlir::failure();
+
+    auto splitAxes = sharding.getSplitAxes().getAxes();
+    // 1d tensor of the sharded dims offsets; null if the sharding has none.
+    Value shardedDimsOffs;
+    {
+      auto tmp = getMixedAsValues(
+          rewriter, loc, sharding.getStaticShardedDimsOffsets(),
+          sharding.getDynamicShardedDimsOffsets(), index);
+      if (!tmp.empty()) {
+        auto numOffs = static_cast<int64_t>(tmp.size());
+        shardedDimsOffs = rewriter.create<tensor::FromElementsOp>(
+            loc, RankedTensorType::get({numOffs}, index), tmp);
+      }
+    }
+
+    int64_t pos = 0; // position in shardedDimsOffs
+    SmallVector<Value> shardShape;
+    Value zero =
+        rewriter.create<arith::ConstantOp>(loc, rewriter.getZeroAttr(index));
+    Value one =
+        rewriter.create<arith::ConstantOp>(loc, rewriter.getOneAttr(index));
+    for (auto [i, dim] : llvm::enumerate(shape)) {
+      if (i < splitAxes.size()) { // trailing dimensions might not be annotated
+        auto axes = splitAxes[i];
+        if (!axes.empty()) {
+          // Multi-axis splits are rejected: idx covers one mesh axis only.
+          if (axes.size() > 1)
+            return op->emitError() << "Only single axis sharding is "
+                                      "supported for each dimension.";
+          Value posVal = rewriter.create<arith::ConstantOp>(
+              loc, rewriter.getIndexAttr(pos));
+          Value idx = multiIdx[axes[0]];
+          auto groupSize =
+              collectiveProcessGroupSize(axes.asArrayRef(), meshOp.getShape());
+          if (shardedDimsOffs) {
+            // If sharded dims offsets are provided, use them to compute the
+            // shard size: offs[pos + idx + 1] - offs[pos + idx].
+            idx = rewriter.create<arith::AddIOp>(loc, posVal, idx);
+            Value off =
+                rewriter.create<tensor::ExtractOp>(loc, shardedDimsOffs, idx);
+            idx = rewriter.create<arith::AddIOp>(loc, idx, one);
+            Value nextOff =
+                rewriter.create<tensor::ExtractOp>(loc, shardedDimsOffs, idx);
+            Value sz = rewriter.create<arith::SubIOp>(loc, nextOff, off);
+            shardShape.emplace_back(sz);
+          } else {
+            auto numShards = rewriter.create<arith::ConstantOp>(
+                loc, rewriter.getIndexAttr(groupSize));
+            // Distribute the remainder r = dim % numShards to the trailing
+            // shards:
+            //   sz = dim / numShards + (idx >= numShards - r ? 1 : 0)
+            Value sz = rewriter.create<arith::DivSIOp>(loc, dim, numShards);
+            Value sz1 = rewriter.create<arith::RemSIOp>(loc, dim, numShards);
+            sz1 = rewriter.create<arith::SubIOp>(loc, numShards, sz1);
+            auto cond = rewriter.create<arith::CmpIOp>(
+                loc, arith::CmpIPredicate::sge, idx, sz1);
+            Value odd = rewriter.create<arith::SelectOp>(loc, cond, one, zero);
+            sz = rewriter.create<arith::AddIOp>(loc, sz, odd);
+            shardShape.emplace_back(sz);
+          }
+          pos += groupSize + 1; // add one for the total size
+        } // else no sharding in this dimension
+      }
+      // If no size was added -> no sharding in this dimension.
+      if (shardShape.size() <= i)
+        shardShape.emplace_back(dim);
+    }
+    assert(shardShape.size() == shape.size());
+    rewriter.replaceOp(op, shardShape);
+    return mlir::success();
+  }
+};
+
 struct ConvertUpdateHaloOp
     : public mlir::OpConversionPattern<mlir::mesh::UpdateHaloOp> {
   using OpConversionPattern::OpConversionPattern;
@@ -570,12 +695,12 @@ struct ConvertUpdateHaloOp
       };
 
       auto get_i32val = [&](OpFoldResult &v) {
-        return v.is<Value>()
-                   ? v.get<Value>()
+        return isa<Value>(v)
+                   ? cast<Value>(v)
                    : rewriter.create<::mlir::arith::ConstantOp>(
                          loc,
                          rewriter.getI32IntegerAttr(
-                             cast<IntegerAttr>(v.get<Attribute>()).getInt()));
+                             cast<IntegerAttr>(cast<Attribute>(v)).getInt()));
       };
 
       for (int i = 0; i < 2; ++i) {
@@ -670,7 +795,8 @@ struct ConvertMeshToMPIPass
 
     patterns.add<ConvertUpdateHaloOp, ConvertNeighborsLinearIndicesOp,
                  ConvertProcessLinearIndexOp, ConvertProcessMultiIndexOp,
-                 ConvertShardingOp, ConvertGetShardingOp>(typeConverter, ctxt);
+                 ConvertShardingOp, ConvertGetShardingOp, ConvertShardShapeOp>(
+        typeConverter, ctxt);
     populateFunctionOpInterfaceTypeConversionPattern<func::FuncOp>(
         patterns, typeConverter);
     populateCallOpTypeConversionPattern(patterns, typeConverter);
diff --git a/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp b/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
index 304ede195c762..3e9f86fde64f3 100644
--- a/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
+++ b/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
@@ -831,12 +831,19 @@ MeshSharding MeshSharding::get(::mlir::FlatSymbolRefAttr mesh_,
 // mesh.shard_shape
 //===----------------------------------------------------------------------===//
 
+void ShardShapeOp::getAsmResultNames(
+    function_ref<void(Value, StringRef)> setNameFn) {
+  setNameFn(getResult()[0], "shard_shape");
+}
+
 void ShardShapeOp::build(::mlir::OpBuilder &odsBuilder,
                          ::mlir::OperationState &odsState,
-                         ::llvm::ArrayRef<int64_t> shape,
-                         ::mlir::Value sharding, ::mlir::Value device) {
-  SmallVector<mlir::Type> resType(shape.size(), odsBuilder.getIndexType());
-  build(odsBuilder, odsState, resType, shape, sharding, device);
+                         ::llvm::ArrayRef<int64_t> dims,
+                         ArrayRef<Value> dims_dyn, ::mlir::Value sharding,
+                         ::mlir::ValueRange device) {
+  SmallVector<mlir::Type> resType(dims.size(), odsBuilder.getIndexType());
+  build(odsBuilder, odsState, resType, dims, dims_dyn, sharding,
+        SmallVector<int64_t>(device.size(), ShapedType::kDynamic), device);
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Tensor/Extensions/MeshShardingExtensions.cpp b/mlir/lib/Dialect/Tensor/Extensions/MeshShardingExtensions.cpp
index b2acbf20b3fb9..b3d69eb5e1a23 100644
--- a/mlir/lib/Dialect/Tensor/Extensions/MeshShardingExtensions.cpp
+++ b/mlir/lib/Dialect/Tensor/Extensions/MeshShardingExtensions.cpp
@@ -50,10 +50,10 @@ struct CreatorOpShardingInterface
                         IRMapping &spmdizationMap,
                         SymbolTableCollection &symbolTable,
                         OpBuilder &builder) const {
-    auto shardType = cast<ShapedType>(mesh::shardType(
-        op->getResult(0).getType(),
-        mesh::getMesh(op, resultShardings[0].getMeshAttr(), symbolTable),
-        resultShardings[0]));
+    auto mesh =
+        mesh::getMesh(op, resultShardings[0].getMeshAttr(), symbolTable);
+    auto shardType = cast<ShapedType>(
+        mesh::shardType(op->getResult(0).getType(), mesh, resultShardings[0]));
     Operation *newOp = nullptr;
     // if the sharding introduces a new dynamic dimension, we take it from
     // the dynamic sharding info. For now bail out if it's not
@@ -66,18 +66,19 @@ struct CreatorOpShardingInterface
       assert(oldType.getRank() == shardType.getRank());
       int currOldOprndNum = -1;
       mesh::ShardShapeOp shapeForDevice;
-      Value device;
+      ValueRange device;
       Operation *newSharding = nullptr;
       for (auto i = 0; i < oldType.getRank(); ++i) {
         if (!oldType.isDynamicDim(i) && shardType.isDynamicDim(i)) {
           if (!newSharding) {
             newSharding =
                 builder.create<ShardingOp>(op->getLoc(), resultShardings[0]);
-            device = builder.create<mesh::ProcessLinearIndexOp>(
-                op->getLoc(), resultShardings[0].getMesh());
+            device =
+                builder.create<mesh::ProcessMultiIndexOp>(op->getLoc(), mesh)
+                    .getResults();
             shapeForDevice = builder.create<mesh::ShardShapeOp>(
-                op->getLoc(), oldType.getShape(), newSharding->getResult(0),
-                device);
+                op->getLoc(), oldType.getShape(), spmdizedOperands,
+                newSharding->getResult(0), device);
           }
           newOperands.emplace_back(shapeForDevice.getResult()[i]);
         } else if (oldType.isDynamicDim(i)) {
diff --git a/mlir/test/Conversion/MeshToMPI/convert-shardshape-to-mpi.mlir b/mlir/test/Conversion/MeshToMPI/convert-shardshape-to-mpi.mlir
new file mode 100644
index 0000000000000..1d5cedf57ec6c
--- /dev/null
+++ b/mlir/test/Conversion/MeshToMPI/convert-shardshape-to-mpi.mlir
@@ -0,0 +1,75 @@
+// RUN: mlir-opt %s --convert-mesh-to-mpi \
+// RUN:    | sed 's/%retval, %rank = mpi.comm_rank : !mpi.retval, i32/%rank = arith.constant 24 : i32/g' \
+// RUN:    | mlir-opt -canonicalize \
+// RUN:    | FileCheck %s
+
+// Note: The sed replacement above fixes the linear index to 24, i.e. multi-index [1, 0, 4] in the 3x4x5 mesh.
+
+// CHECK: mesh.mesh @mesh0
+mesh.mesh @mesh0(shape = 3x4x5)
+
+// all shards are equal
+// CHECK-LABEL: func.func @shard_shape_equal() -> (index, index, index) {
+func.func @shard_shape_equal() -> (index, index, index) {
+  %sharding = mesh.sharding @mesh0 split_axes = [[0], [1], [2]] : !mesh.sharding
+  %0:3 = mesh.process_multi_index on @mesh0 : index, index, index
+  %c9 = arith.constant 9 : index
+  %c12 = arith.constant 12 : index
+  // CHECK: [[vc3:%.*]] = arith.constant 3 : index
+  %1:3 = mesh.shard_shape dims = [%c9, %c12, 15] sharding = %sharding device = [%0#0, %0#1, %0#2] : index, index, index
+  // CHECK: return [[vc3]], [[vc3]], [[vc3]] : index, index, index
+  return %1#0, %1#1, %1#2 : index, index, index
+}
+
+// last shard in last dim gets an extra element
+// CHECK-LABEL: func.func @shard_shape_odd_1() -> (index, index, index) {
+func.func @shard_shape_odd_1() -> (index, index, index) {
+  %sharding = mesh.sharding @mesh0 split_axes = [[0], [1], [2]] : !mesh.sharding
+  %0:3 = mesh.process_multi_index on @mesh0 : index, index, index
+  %c9 = arith.constant 9 : index
+  %c12 = arith.constant 12 : index
+  // CHECK-DAG: [[vc3:%.*]] = arith.constant 3 : index
+  // CHECK-DAG: [[vc4:%.*]] = arith.constant 4 : index
+  %1:3 = mesh.shard_shape dims = [%c9, %c12, 16] sharding = %sharding device = [%0#0, %0#1, %0#2] : index, index, index
+  // CHECK: return [[vc3]], [[vc3]], [[vc4]] : index, index, index
+  return %1#0, %1#1, %1#2 : index, index, index
+}
+
+// all except first shard in second dim get an extra element
+// CHECK-LABEL: func.func @shard_shape_odd_2() -> (index, index, index) {
+func.func @shard_shape_odd_2() -> (index, index, index) {
+  %sharding = mesh.sharding @mesh0 split_axes = [[0], [1], [2]] : !mesh.sharding
+  %0:3 = mesh.process_multi_index on @mesh0 : index, index, index
+  %c9 = arith.constant 9 : index
+  // CHECK: [[vc3:%.*]] = arith.constant 3 : index
+  %1:3 = mesh.shard_shape dims = [%c9, 15, 15] sharding = %sharding device = [%0#0, %0#1, %0#2] : index, index, index
+  // CHECK: return [[vc3]], [[vc3]], [[vc3]] : index, index, index
+  return %1#0, %1#1, %1#2 : index, index, index
+}
+
+// all except first shard in first dim get an extra element
+// CHECK-LABEL: func.func @shard_shape_odd_3() -> (index, index, index) {
+func.func @shard_shape_odd_3() -> (index, index, index) {
+  %sharding = mesh.sharding @mesh0 split_axes = [[0], [1], [2]] : !mesh.sharding
+  %0:3 = mesh.process_multi_index on @mesh0 : index, index, index
+  // CHECK-DAG: [[vc3:%.*]] = arith.constant 3 : index
+  // CHECK-DAG: [[vc4:%.*]] = arith.constant 4 : index
+  %1:3 = mesh.shard_shape dims = [11, 12, 15] sharding = %sharding device = [%0#0, %0#1, %0#2] : index, index, index
+  // CHECK: return [[vc4]], [[vc3]], [[vc3]] : index, index, index
+  return %1#0, %1#1, %1#2 : index, index, index
+}
+
+// extract from sharded_dims_offsets
+// CHECK-LABEL: func.func @shard_shape_sharded_dims_offs() -> (index, index, index) {
+func.func @shard_shape_sharded_dims_offs() -> (index, index, index) {
+  %sharding = mesh.sharding @mesh0 split_axes = [[0], [1], [2]]
+      sharded_dims_offsets = [0, 1, 4, 9, 0, 2, 6, 12, 12, 0, 3, 6, 9, 12, 15]: !mesh.sharding
+  %0:3 = mesh.process_multi_index on @mesh0 : index, index, index
+  %c9 = arith.constant 9 : index
+  %c12 = arith.constant 12 : index
+  // CHECK: [[vc3:%.*]] = arith.constant 3 : index
+  // CHECK: [[vc2:%.*]] = arith.constant 2 : index
+  %1:3 = mesh.shard_shape dims = [%c9, %c12, 15] sharding = %sharding device = [%0#0, %0#1, %0#2] : index, index, index
+  // CHECK: return [[vc3]], [[vc2]], [[vc3]] : index, index, index
+  return %1#0, %1#1, %1#2 : index, index, index
+}
diff --git a/mlir/test/Dialect/Mesh/ops.mlir b/mlir/test/Dialect/Mesh/ops.mlir
index 43a75bf3d8040..3d133f2255772 100644
--- a/mlir/test/Dialect/Mesh/ops.mlir
+++ b/mlir/test/Dialect/Mesh/ops.mlir
@@ -157,10 +157,12 @@ func.func @mesh_shard_shape() {
   %c3 = arith.constant 3 : index
   // CHECK-NEXT: %[[S:.*]] = mesh.sharding @mesh0 split_axes = {{\[\[}}]] : !mesh.sharding
   %s = mesh.sharding @mesh0 split_axes = [[]] : !mesh.sharding
-  // CHECK-NEXT: mesh.shard_shape 8x? %[[S]] %[[C3]] : index, index
-  %shp:2 = mesh.shard_shape 8x? %s %c3 : index, index
-  // CHECK-NEXT: mesh.shard_shape 8x4 %[[S]] %[[C3]] : index, index
-  %shp1:2 = mesh.shard_shape 8x4 %s %c3 : index, index
+  // CHECK-NEXT: mesh.shard_shape dims = [8, %[[C3]]
+  // CHECK-SAME: ] sharding = %[[S]] device = [%[[C3]]
+  // CHECK-SAME: ] : index, index
+  %shp:2 = mesh.shard_shape dims = [8, %c3] sharding = %s device = [%c3] : index, index
+  // CHECK-NEXT: mesh.shard_shape dims = [8, 4] sharding = %[[S]] device = [3] : index, index
+  %shp1:2 = mesh.shard_shape dims = [8, 4] sharding = %s device = [3] : index, index
   return
 }
 
diff --git a/mlir/test/Dialect/Tensor/mesh-spmdization.mlir b/mlir/test/Dialect/Tensor/mesh-spmdization.mlir
index 5443eea83aa2d..01cf5972177f4 100644
--- a/mlir/test/Dialect/Tensor/mesh-spmdization.mlir
+++ b/mlir/test/Dialect/Tensor/mesh-spmdization.mlir
@@ -10,8 +10,9 @@ func.func @tensor_empty_static_sharded_dims_offsets() -> () {
   %sharding = mesh.sharding @mesh_1d_4 split_axes = [[0]] sharded_dims_offsets = [0, 1, 4, 7, 8] : !mesh.sharding
   %sharded= mesh.shard %b to %sharding : tensor<8x16xf32>
   // CHECK:  %[[sharding:.*]] = mesh.sharding @mesh_1d_4 split_axes = {{\[\[}}0]] sharded_dims_offsets = [0, 1, 4, 7, 8] : !mesh.sharding
-  // CHECK:  %[[proc_linear_idx:.*]] = mesh.process_linear_index on @mesh_1d_4 : index
-  // CHECK:  %[[V0:.*]]:2 = mesh.shard_shape 8x16 %[[sharding]] %[[proc_linear_idx]] : index, index
+  // CHECK:  %[[proc_multi_idx:.*]] = mesh.process_multi_index on @mesh_1d_4 : index
+  // CHECK:  %[[V0:.*]]:2 = mesh.shard_shape dims = [8, 16] sharding = %[[sharding]] device = [%[[proc_multi_idx]]
+  // CHECK-SAME: ] : index, index
   // CHECK:  tensor.empty(%[[V0]]#0) : tensor<?x16xf32>
 
   return
@@ -24,8 +25,10 @@ func.func @tensor_empty_dynamic_sharded_dims_offsets(%arg0 : index) -> () {
   %sharding = mesh.sharding @mesh_1d_4 split_axes = [[0]] sharded_dims_offsets = [0, 1, 4, 7, 8] : !mesh.sharding
   %sharded= mesh.shard %b to %sharding : tensor<8x?xf32>
   // CHECK:  %[[sharding:.*]] = mesh.sharding @mesh_1d_4 split_axes = {{\[\[}}0]] sharded_dims_offsets = [0, 1, 4, 7, 8] : !mesh.sharding
-  // CHECK:  %[[proc_linear_idx:.*]] = mesh.process_linear_index on @mesh_1d_4 : index
-  // CHECK:  %[[V0:.*]]:2 = mesh.shard_shape 8x? %[[sharding]] %[[proc_linear_idx]] : index, index
+  // CHECK:  %[[proc_multi_idx:.*]] = mesh.process_multi_index on @mesh_1d_4 : index
+  // CHECK:  %[[V0:.*]]:2 = mesh.shard_shape dims = [8, %[[A0]]
+  // CHECK-SAME: ] sharding = %[[sharding]] device = [%[[proc_multi_idx]]
+  // CHECK-SAME: ] : index, index
   // CHECK:  tensor.empty(%[[V0]]#0, %[[A0]]) : tensor<?x?xf32>
 
   return

>From 1faf707c726f670cfadae43b489df00e233ec72b Mon Sep 17 00:00:00 2001
From: "Schlimbach, Frank" <frank.schlimbach at intel.com>
Date: Thu, 27 Feb 2025 11:58:17 +0100
Subject: [PATCH 7/8] cleanup & formatting

---
 mlir/lib/Conversion/MeshToMPI/MeshToMPI.cpp | 254 +++++++++-----------
 1 file changed, 112 insertions(+), 142 deletions(-)

diff --git a/mlir/lib/Conversion/MeshToMPI/MeshToMPI.cpp b/mlir/lib/Conversion/MeshToMPI/MeshToMPI.cpp
index b94181dd4f061..9d638100e637b 100644
--- a/mlir/lib/Conversion/MeshToMPI/MeshToMPI.cpp
+++ b/mlir/lib/Conversion/MeshToMPI/MeshToMPI.cpp
@@ -41,10 +41,10 @@ namespace mlir {
 } // namespace mlir
 
 using namespace mlir;
-using namespace mlir::mesh;
+using namespace mesh;
 
 namespace {
-// convert vec of OpFoldResults (ints) into vec of Values
+/// Convert mixed static ints and dynamic Values into a vector of Values.
 static SmallVector<Value> getMixedAsValues(OpBuilder b, const Location &loc,
                                            llvm::ArrayRef<int64_t> statics,
                                            ValueRange dynamics,
@@ -61,13 +61,12 @@ static SmallVector<Value> getMixedAsValues(OpBuilder b, const Location &loc,
             ? *(dyn++)
             : b.create<arith::ConstantOp>(loc, type,
                                           i64 == type ? b.getI64IntegerAttr(s)
-                                                      : b.getIndexAttr(s))
-                  .getResult());
+                                                      : b.getIndexAttr(s)));
   }
   return values;
 };
 
-// Create operations converting a linear index to a multi-dimensional index
+/// Create operations converting a linear index to a multi-dimensional index.
 static SmallVector<Value> linearToMultiIndex(Location loc, OpBuilder b,
                                              Value linearIndex,
                                              ValueRange dimensions) {
@@ -76,23 +75,22 @@ static SmallVector<Value> linearToMultiIndex(Location loc, OpBuilder b,
 
   for (int i = n - 1; i >= 0; --i) {
     multiIndex[i] = b.create<arith::RemSIOp>(loc, linearIndex, dimensions[i]);
-    if (i > 0) {
+    if (i > 0)
       linearIndex = b.create<arith::DivSIOp>(loc, linearIndex, dimensions[i]);
-    }
   }
 
   return multiIndex;
 }
 
-// Create operations converting a multi-dimensional index to a linear index
+/// Create operations converting a multi-dimensional index to a linear index.
 Value multiToLinearIndex(Location loc, OpBuilder b, ValueRange multiIndex,
                          ValueRange dimensions) {
 
-  auto linearIndex = b.create<arith::ConstantIndexOp>(loc, 0).getResult();
-  auto stride = b.create<arith::ConstantIndexOp>(loc, 1).getResult();
+  Value linearIndex = b.create<arith::ConstantIndexOp>(loc, 0);
+  Value stride = b.create<arith::ConstantIndexOp>(loc, 1);
 
   for (int i = multiIndex.size() - 1; i >= 0; --i) {
-    auto off = b.create<arith::MulIOp>(loc, multiIndex[i], stride);
+    Value off = b.create<arith::MulIOp>(loc, multiIndex[i], stride);
     linearIndex = b.create<arith::AddIOp>(loc, linearIndex, off);
     stride = b.create<arith::MulIOp>(loc, stride, dimensions[i]);
   }
@@ -100,35 +98,34 @@ Value multiToLinearIndex(Location loc, OpBuilder b, ValueRange multiIndex,
   return linearIndex;
 }
 
-// replace GetShardingOp with related ShardingOp
-struct ConvertGetShardingOp
-    : public mlir::OpConversionPattern<mlir::mesh::GetShardingOp> {
+/// Replace GetShardingOp with related/dependent ShardingOp.
+struct ConvertGetShardingOp : public OpConversionPattern<GetShardingOp> {
   using OpConversionPattern::OpConversionPattern;
 
-  mlir::LogicalResult
-  matchAndRewrite(mlir::mesh::GetShardingOp op, OpAdaptor adaptor,
-                  mlir::ConversionPatternRewriter &rewriter) const override {
+  LogicalResult
+  matchAndRewrite(GetShardingOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
     auto shardOp = adaptor.getSource().getDefiningOp<ShardOp>();
-    if (!shardOp) {
-      return mlir::failure();
-    }
+    if (!shardOp)
+      return failure();
     auto shardingOp = shardOp.getSharding().getDefiningOp<ShardingOp>();
-    if (!shardingOp) {
-      return mlir::failure();
-    }
+    if (!shardingOp)
+      return failure();
 
     rewriter.replaceOp(op, shardingOp.getResult());
-    return mlir::success();
+    return success();
   }
 };
 
-struct ConvertShardingOp
-    : public mlir::OpConversionPattern<mlir::mesh::ShardingOp> {
+/// Convert a sharding op to a tuple of tensors of its components
+///   (SplitAxes, HaloSizes, ShardedDimsOffsets)
+/// as defined by type converter.
+struct ConvertShardingOp : public OpConversionPattern<ShardingOp> {
   using OpConversionPattern::OpConversionPattern;
 
-  mlir::LogicalResult
-  matchAndRewrite(mlir::mesh::ShardingOp op, OpAdaptor adaptor,
-                  mlir::ConversionPatternRewriter &rewriter) const override {
+  LogicalResult
+  matchAndRewrite(ShardingOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
     auto splitAxes = op.getSplitAxes().getAxes();
     int64_t maxNAxes = 0;
     for (auto axes : splitAxes) {
@@ -138,17 +135,15 @@ struct ConvertShardingOp
     // To hold the split axes, create empty 2d tensor with shape
     // {splitAxes.size(), max-size-of-split-groups}.
     // Set trailing elements for smaller split-groups to -1.
-    auto loc = op.getLoc();
+    Location loc = op.getLoc();
     auto i16 = rewriter.getI16Type();
     auto i64 = rewriter.getI64Type();
     int64_t shape[] = {static_cast<int64_t>(splitAxes.size()), maxNAxes};
     Value resSplitAxes = rewriter.create<tensor::EmptyOp>(loc, shape, i16);
     auto attr = IntegerAttr::get(i16, 0xffff);
-    auto fillValue = rewriter.create<arith::ConstantOp>(loc, i16, attr);
-    resSplitAxes =
-        rewriter
-            .create<linalg::FillOp>(loc, fillValue.getResult(), resSplitAxes)
-            .getResult(0);
+    Value fillValue = rewriter.create<arith::ConstantOp>(loc, i16, attr);
+    resSplitAxes = rewriter.create<linalg::FillOp>(loc, fillValue, resSplitAxes)
+                       .getResult(0);
 
     // explicitly write values into tensor row by row
     int64_t strides[] = {1, 1};
@@ -156,9 +151,8 @@ struct ConvertShardingOp
     ValueRange empty = {};
     for (auto [i, axes] : llvm::enumerate(splitAxes)) {
       int64_t size = axes.size();
-      if (size > 0) {
+      if (size > 0)
         ++nSplits;
-      }
       int64_t offs[] = {(int64_t)i, 0};
       int64_t sizes[] = {1, size};
       auto tensorType = RankedTensorType::get({size}, i16);
@@ -197,7 +191,7 @@ struct ConvertShardingOp
       auto meshOp = getMesh(op, symbolTableCollection);
       auto maxSplitSize = 0;
       for (auto axes : splitAxes) {
-        auto splitSize =
+        int64_t splitSize =
             collectiveProcessGroupSize(axes.asArrayRef(), meshOp.getShape());
         assert(splitSize != ShapedType::kDynamic);
         maxSplitSize = std::max<int64_t>(maxSplitSize, splitSize);
@@ -207,11 +201,8 @@ struct ConvertShardingOp
 
       resOffsets = rewriter.create<tensor::EmptyOp>(
           loc, std::array<int64_t, 2>{nSplits, maxSplitSize}, i64);
-      auto zero =
-          rewriter
-              .create<arith::ConstantOp>(
-                  loc, i64, rewriter.getI64IntegerAttr(ShapedType::kDynamic))
-              .getResult();
+      Value zero = rewriter.create<arith::ConstantOp>(
+          loc, i64, rewriter.getI64IntegerAttr(ShapedType::kDynamic));
       resOffsets =
           rewriter.create<linalg::FillOp>(loc, zero, resOffsets).getResult(0);
       auto offsets =
@@ -219,12 +210,12 @@ struct ConvertShardingOp
                            adaptor.getDynamicShardedDimsOffsets());
       int64_t curr = 0;
       for (auto [i, axes] : llvm::enumerate(splitAxes)) {
-        auto splitSize =
+        int64_t splitSize =
             collectiveProcessGroupSize(axes.asArrayRef(), meshOp.getShape());
         assert(splitSize != ShapedType::kDynamic && splitSize < maxSplitSize);
         ++splitSize; // add one for the total size
         ArrayRef<Value> values(&offsets[curr], splitSize);
-        auto vals = rewriter.create<tensor::FromElementsOp>(loc, values);
+        Value vals = rewriter.create<tensor::FromElementsOp>(loc, values);
         int64_t offs[] = {(int64_t)i, 0};
         int64_t sizes[] = {1, splitSize};
         resOffsets = rewriter.create<tensor::InsertSliceOp>(
@@ -236,9 +227,9 @@ struct ConvertShardingOp
     // return a tuple of tensors as defined by type converter
     SmallVector<Type> resTypes;
     if (failed(getTypeConverter()->convertType(op.getResult().getType(),
-                                               resTypes))) {
-      return mlir::failure();
-    }
+                                               resTypes)))
+      return failure();
+
     resSplitAxes =
         rewriter.create<tensor::CastOp>(loc, resTypes[0], resSplitAxes);
     resHaloSizes =
@@ -249,35 +240,33 @@ struct ConvertShardingOp
         op, TupleType::get(op.getContext(), resTypes),
         ValueRange{resSplitAxes, resHaloSizes, resOffsets});
 
-    return mlir::success();
+    return success();
   }
 };
 
 struct ConvertProcessMultiIndexOp
-    : public mlir::OpConversionPattern<mlir::mesh::ProcessMultiIndexOp> {
+    : public OpConversionPattern<ProcessMultiIndexOp> {
   using OpConversionPattern::OpConversionPattern;
 
-  mlir::LogicalResult
-  matchAndRewrite(mlir::mesh::ProcessMultiIndexOp op, OpAdaptor adaptor,
-                  mlir::ConversionPatternRewriter &rewriter) const override {
+  LogicalResult
+  matchAndRewrite(ProcessMultiIndexOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
 
     // Currently converts its linear index to a multi-dimensional index.
 
     SymbolTableCollection symbolTableCollection;
-    auto loc = op.getLoc();
+    Location loc = op.getLoc();
     auto meshOp = getMesh(op, symbolTableCollection);
     // For now we only support static mesh shapes
-    if (ShapedType::isDynamicShape(meshOp.getShape())) {
-      return mlir::failure();
-    }
+    if (ShapedType::isDynamicShape(meshOp.getShape()))
+      return failure();
 
     SmallVector<Value> dims;
     llvm::transform(
         meshOp.getShape(), std::back_inserter(dims), [&](int64_t i) {
           return rewriter.create<arith::ConstantIndexOp>(loc, i).getResult();
         });
-    auto rank =
-        rewriter.create<ProcessLinearIndexOp>(op.getLoc(), meshOp).getResult();
+    Value rank = rewriter.create<ProcessLinearIndexOp>(op.getLoc(), meshOp);
     auto mIdx = linearToMultiIndex(loc, rewriter, rank, dims);
 
     // optionally extract subset of mesh axes
@@ -291,22 +280,22 @@ struct ConvertProcessMultiIndexOp
     }
 
     rewriter.replaceOp(op, mIdx);
-    return mlir::success();
+    return success();
   }
 };
 
 struct ConvertProcessLinearIndexOp
-    : public mlir::OpConversionPattern<mlir::mesh::ProcessLinearIndexOp> {
+    : public OpConversionPattern<ProcessLinearIndexOp> {
   using OpConversionPattern::OpConversionPattern;
 
-  mlir::LogicalResult
-  matchAndRewrite(mlir::mesh::ProcessLinearIndexOp op, OpAdaptor adaptor,
-                  mlir::ConversionPatternRewriter &rewriter) const override {
+  LogicalResult
+  matchAndRewrite(ProcessLinearIndexOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
 
     // Finds a global named "static_mpi_rank" it will use that splat value.
     // Otherwise it defaults to mpi.comm_rank.
 
-    auto loc = op.getLoc();
+    Location loc = op.getLoc();
     auto rankOpName = StringAttr::get(op->getContext(), "static_mpi_rank");
     if (auto globalOp = SymbolTable::lookupNearestSymbolFrom<memref::GlobalOp>(
             op, rankOpName)) {
@@ -314,7 +303,7 @@ struct ConvertProcessLinearIndexOp
         auto val = cast<DenseElementsAttr>(initTnsr).getSplatValue<int64_t>();
         rewriter.replaceOp(op,
                            rewriter.create<arith::ConstantIndexOp>(loc, val));
-        return mlir::success();
+        return success();
       }
     }
     auto rank =
@@ -325,17 +314,17 @@ struct ConvertProcessLinearIndexOp
             .getRank();
     rewriter.replaceOpWithNewOp<arith::IndexCastOp>(op, rewriter.getIndexType(),
                                                     rank);
-    return mlir::success();
+    return success();
   }
 };
 
 struct ConvertNeighborsLinearIndicesOp
-    : public mlir::OpConversionPattern<mlir::mesh::NeighborsLinearIndicesOp> {
+    : public OpConversionPattern<NeighborsLinearIndicesOp> {
   using OpConversionPattern::OpConversionPattern;
 
-  mlir::LogicalResult
-  matchAndRewrite(mlir::mesh::NeighborsLinearIndicesOp op, OpAdaptor adaptor,
-                  mlir::ConversionPatternRewriter &rewriter) const override {
+  LogicalResult
+  matchAndRewrite(NeighborsLinearIndicesOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
 
     // Computes the neighbors indices along a split axis by simply
     // adding/subtracting 1 to the current index in that dimension.
@@ -343,11 +332,10 @@ struct ConvertNeighborsLinearIndicesOp
 
     auto axes = adaptor.getSplitAxes();
     // For now only single axis sharding is supported
-    if (axes.size() != 1) {
-      return mlir::failure();
-    }
+    if (axes.size() != 1)
+      return failure();
 
-    auto loc = op.getLoc();
+    Location loc = op.getLoc();
     SymbolTableCollection symbolTableCollection;
     auto meshOp = getMesh(op, symbolTableCollection);
     auto mIdx = adaptor.getDevice();
@@ -357,12 +345,12 @@ struct ConvertNeighborsLinearIndicesOp
         meshOp.getShape(), std::back_inserter(dims), [&](int64_t i) {
           return rewriter.create<arith::ConstantIndexOp>(loc, i).getResult();
         });
-    auto dimSz = dims[axes[0]];
-    auto one = rewriter.create<arith::ConstantIndexOp>(loc, 1).getResult();
-    auto minus1 = rewriter.create<arith::ConstantIndexOp>(loc, -1).getResult();
-    auto atBorder = rewriter.create<arith::CmpIOp>(
+    Value dimSz = dims[axes[0]];
+    Value one = rewriter.create<arith::ConstantIndexOp>(loc, 1);
+    Value minus1 = rewriter.create<arith::ConstantIndexOp>(loc, -1);
+    Value atBorder = rewriter.create<arith::CmpIOp>(
         loc, arith::CmpIPredicate::sle, orgIdx,
-        rewriter.create<arith::ConstantIndexOp>(loc, 0).getResult());
+        rewriter.create<arith::ConstantIndexOp>(loc, 0));
     auto down = rewriter.create<scf::IfOp>(
         loc, atBorder,
         [&](OpBuilder &builder, Location loc) {
@@ -387,23 +375,21 @@ struct ConvertNeighborsLinearIndicesOp
         [&](OpBuilder &builder, Location loc) {
           SmallVector<Value> tmp = mIdx;
           tmp[axes[0]] =
-              rewriter.create<arith::AddIOp>(op.getLoc(), orgIdx, one)
-                  .getResult();
+              rewriter.create<arith::AddIOp>(op.getLoc(), orgIdx, one);
           builder.create<scf::YieldOp>(
               loc, multiToLinearIndex(loc, rewriter, tmp, dims));
         });
     rewriter.replaceOp(op, ValueRange{down.getResult(0), up.getResult(0)});
-    return mlir::success();
+    return success();
   }
 };
 
-struct ConvertShardShapeOp
-    : public mlir::OpConversionPattern<mlir::mesh::ShardShapeOp> {
+struct ConvertShardShapeOp : public OpConversionPattern<ShardShapeOp> {
   using OpConversionPattern::OpConversionPattern;
 
-  mlir::LogicalResult
-  matchAndRewrite(mlir::mesh::ShardShapeOp op, OneToNOpAdaptor adaptor,
-                  mlir::ConversionPatternRewriter &rewriter) const override {
+  LogicalResult
+  matchAndRewrite(ShardShapeOp op, OneToNOpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
     auto sharding = op.getSharding().getDefiningOp<ShardingOp>();
     if (!sharding) {
       return op->emitError()
@@ -411,7 +397,7 @@ struct ConvertShardShapeOp
              << " but found " << adaptor.getSharding()[0].getDefiningOp();
     }
 
-    auto loc = op.getLoc();
+    Location loc = op.getLoc();
     Type index = rewriter.getIndexType();
 
     // Fill a vector with the dynamic dims
@@ -436,9 +422,8 @@ struct ConvertShardShapeOp
     SymbolTableCollection symbolTableCollection;
     auto meshOp = getMesh(sharding, symbolTableCollection);
     // For now we only support static mesh shapes
-    if (ShapedType::isDynamicShape(meshOp.getShape())) {
-      return mlir::failure();
-    }
+    if (ShapedType::isDynamicShape(meshOp.getShape()))
+      return failure();
 
     auto splitAxes = sharding.getSplitAxes().getAxes();
     Value shardedDimsOffs;
@@ -447,12 +432,8 @@ struct ConvertShardShapeOp
           rewriter, loc, sharding.getStaticShardedDimsOffsets(),
           sharding.getDynamicShardedDimsOffsets(), index);
       if (!tmp.empty())
-        shardedDimsOffs =
-            rewriter
-                .create<tensor::FromElementsOp>(
-                    loc, RankedTensorType::get({(int64_t)tmp.size()}, index),
-                    tmp)
-                .getResult();
+        shardedDimsOffs = rewriter.create<tensor::FromElementsOp>(
+            loc, RankedTensorType::get({(int64_t)tmp.size()}, index), tmp);
     }
 
     int64_t pos = 0; // position in shardedDimsOffs
@@ -473,17 +454,16 @@ struct ConvertShardShapeOp
           if (shardedDimsOffs) {
             // If sharded dims offsets are provided, use them to compute the
             // sharded shape.
-            if (axes.size() > 1)
+            if (axes.size() > 1) {
               return op->emitError() << "Only single axis sharding is "
-                                        "supported for each dimension.";
+                                     << "supported for each dimension.";
+            }
             idx = rewriter.create<arith::AddIOp>(loc, posVal, idx);
             Value off =
-                rewriter.create<tensor::ExtractOp>(loc, shardedDimsOffs, idx)
-                    .getResult();
+                rewriter.create<tensor::ExtractOp>(loc, shardedDimsOffs, idx);
             idx = rewriter.create<arith::AddIOp>(loc, idx, one);
             Value nextOff =
-                rewriter.create<tensor::ExtractOp>(loc, shardedDimsOffs, idx)
-                    .getResult();
+                rewriter.create<tensor::ExtractOp>(loc, shardedDimsOffs, idx);
             Value sz = rewriter.create<arith::SubIOp>(loc, nextOff, off);
             shardShape.emplace_back(sz);
           } else {
@@ -510,17 +490,16 @@ struct ConvertShardShapeOp
     }
     assert(shardShape.size() == shape.size());
     rewriter.replaceOp(op, shardShape);
-    return mlir::success();
+    return success();
   }
 };
 
-struct ConvertUpdateHaloOp
-    : public mlir::OpConversionPattern<mlir::mesh::UpdateHaloOp> {
+struct ConvertUpdateHaloOp : public OpConversionPattern<UpdateHaloOp> {
   using OpConversionPattern::OpConversionPattern;
 
-  mlir::LogicalResult
-  matchAndRewrite(mlir::mesh::UpdateHaloOp op, OpAdaptor adaptor,
-                  mlir::ConversionPatternRewriter &rewriter) const override {
+  LogicalResult
+  matchAndRewrite(UpdateHaloOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
 
     // The input/output memref is assumed to be in C memory order.
     // Halos are exchanged as 2 blocks per dimension (one for each side: down
@@ -539,17 +518,17 @@ struct ConvertUpdateHaloOp
     if (haloSizes.empty()) {
       // no halos -> nothing to do
       rewriter.replaceOp(op, adaptor.getDestination());
-      return mlir::success();
+      return success();
     }
 
     SymbolTableCollection symbolTableCollection;
-    auto loc = op.getLoc();
+    Location loc = op.getLoc();
 
     // convert a OpFoldResult into a Value
     auto toValue = [&rewriter, &loc](OpFoldResult &v) -> Value {
       if (auto value = dyn_cast<Value>(v))
         return value;
-      return rewriter.create<::mlir::arith::ConstantOp>(
+      return rewriter.create<arith::ConstantOp>(
           loc, rewriter.getIndexAttr(
                    cast<IntegerAttr>(cast<Attribute>(v)).getInt()));
     };
@@ -561,8 +540,8 @@ struct ConvertUpdateHaloOp
       // If the destination is a memref, we need to cast it to a tensor
       auto tensorType = MemRefType::get(
           dstShape, cast<ShapedType>(array.getType()).getElementType());
-      array = rewriter.create<bufferization::ToMemrefOp>(loc, tensorType, array)
-                  .getResult();
+      array =
+          rewriter.create<bufferization::ToMemrefOp>(loc, tensorType, array);
     }
     auto rank = cast<ShapedType>(array.getType()).getRank();
     auto opSplitAxes = adaptor.getSplitAxes().getAxes();
@@ -570,12 +549,11 @@ struct ConvertUpdateHaloOp
     auto meshOp = getMesh(op, symbolTableCollection);
     // subviews need Index values
     for (auto &sz : haloSizes) {
-      if (auto value = dyn_cast<Value>(sz)) {
+      if (auto value = dyn_cast<Value>(sz))
         sz =
             rewriter
                 .create<arith::IndexCastOp>(loc, rewriter.getIndexType(), value)
                 .getResult();
-      }
     }
 
     // most of the offset/size/stride data is the same for all dims
@@ -586,11 +564,10 @@ struct ConvertUpdateHaloOp
     // we need the actual shape to compute offsets and sizes
     for (auto i = 0; i < rank; ++i) {
       auto s = dstShape[i];
-      if (ShapedType::isDynamic(s)) {
+      if (ShapedType::isDynamic(s))
         shape[i] = rewriter.create<memref::DimOp>(loc, array, s).getResult();
-      } else {
+      else
         shape[i] = rewriter.getIndexAttr(s);
-      }
 
       if ((size_t)i < opSplitAxes.size() && !opSplitAxes[i].empty()) {
         ++currHaloDim;
@@ -598,11 +575,9 @@ struct ConvertUpdateHaloOp
         offsets[i] = haloSizes[currHaloDim * 2];
 
         // prepare shape and offsets of highest dim's halo exchange
-        auto _haloSz =
-            rewriter
-                .create<arith::AddIOp>(loc, toValue(haloSizes[currHaloDim * 2]),
-                                       toValue(haloSizes[currHaloDim * 2 + 1]))
-                .getResult();
+        Value _haloSz = rewriter.create<arith::AddIOp>(
+            loc, toValue(haloSizes[currHaloDim * 2]),
+            toValue(haloSizes[currHaloDim * 2 + 1]));
         // the halo shape of lower dims exlude the halos
         dimSizes[i] =
             rewriter.create<arith::SubIOp>(loc, toValue(shape[i]), _haloSz)
@@ -613,9 +588,9 @@ struct ConvertUpdateHaloOp
     }
 
     auto tagAttr = rewriter.getI32IntegerAttr(91); // we just pick something
-    auto tag = rewriter.create<::mlir::arith::ConstantOp>(loc, tagAttr);
+    auto tag = rewriter.create<arith::ConstantOp>(loc, tagAttr);
     auto zeroAttr = rewriter.getI32IntegerAttr(0); // for detecting v<0
-    auto zero = rewriter.create<::mlir::arith::ConstantOp>(loc, zeroAttr);
+    auto zero = rewriter.create<arith::ConstantOp>(loc, zeroAttr);
 
     SmallVector<Type> indexResultTypes(meshOp.getShape().size(),
                                        rewriter.getIndexType());
@@ -625,9 +600,8 @@ struct ConvertUpdateHaloOp
     // traverse all split axes from high to low dim
     for (ssize_t dim = opSplitAxes.size() - 1; dim >= 0; --dim) {
       auto splitAxes = opSplitAxes[dim];
-      if (splitAxes.empty()) {
+      if (splitAxes.empty())
         continue;
-      }
       assert(currHaloDim >= 0 && (size_t)currHaloDim < haloSizes.size() / 2);
       // Get the linearized ids of the neighbors (down and up) for the
       // given split
@@ -697,7 +671,7 @@ struct ConvertUpdateHaloOp
       auto get_i32val = [&](OpFoldResult &v) {
         return isa<Value>(v)
                    ? cast<Value>(v)
-                   : rewriter.create<::mlir::arith::ConstantOp>(
+                   : rewriter.create<arith::ConstantOp>(
                          loc,
                          rewriter.getI32IntegerAttr(
                              cast<IntegerAttr>(cast<Attribute>(v)).getInt()));
@@ -730,7 +704,7 @@ struct ConvertUpdateHaloOp
                                  loc, op.getResult().getType(), array,
                                  /*restrict=*/true, /*writable=*/true));
     }
-    return mlir::success();
+    return success();
   }
 };
 
@@ -741,9 +715,9 @@ struct ConvertMeshToMPIPass
   /// Run the dialect converter on the module.
   void runOnOperation() override {
     auto *ctxt = &getContext();
-    mlir::RewritePatternSet patterns(ctxt);
+    RewritePatternSet patterns(ctxt);
     ConversionTarget target(getContext());
-    mlir::TypeConverter typeConverter;
+    TypeConverter typeConverter;
 
     typeConverter.addConversion([](Type type) { return type; });
     // convert mesh::ShardingType to a tuple of RankedTensorTypes
@@ -757,25 +731,22 @@ struct ConvertMeshToMPIPass
           results.emplace_back(RankedTensorType::get(shp, i16));
           results.emplace_back(RankedTensorType::get(shp, i64)); // actually ?x2
           results.emplace_back(RankedTensorType::get(shp, i64));
-          return mlir::success();
+          return success();
         });
     // To extract components a UnrealizedConversionCastOp is expected to define
     // the input
     typeConverter.addTargetMaterialization(
         [&](OpBuilder &builder, TypeRange resultTypes, ValueRange inputs,
             Location loc) {
-          if (inputs.size() != 1 || !isa<TupleType>(inputs[0].getType())) {
+          if (inputs.size() != 1 || !isa<TupleType>(inputs[0].getType()))
             return SmallVector<Value>();
-          }
           auto castOp = inputs[0].getDefiningOp<UnrealizedConversionCastOp>();
-          if (!castOp) {
+          if (!castOp)
             return SmallVector<Value>();
-          }
           SmallVector<Value> results;
           for (auto oprnd : castOp.getInputs()) {
-            if (!isa<RankedTensorType>(oprnd.getType())) {
+            if (!isa<RankedTensorType>(oprnd.getType()))
               return SmallVector<Value>();
-            }
             results.emplace_back(oprnd);
           }
           return results;
@@ -802,8 +773,7 @@ struct ConvertMeshToMPIPass
     populateCallOpTypeConversionPattern(patterns, typeConverter);
     populateReturnOpTypeConversionPattern(patterns, typeConverter);
 
-    (void)mlir::applyPartialConversion(getOperation(), target,
-                                       std::move(patterns));
+    (void)applyPartialConversion(getOperation(), target, std::move(patterns));
   }
 };
 

>From 95582a8ea943a62efac7d343eec6b466521afdad Mon Sep 17 00:00:00 2001
From: "Schlimbach, Frank" <frank.schlimbach at intel.com>
Date: Thu, 27 Feb 2025 12:56:42 +0100
Subject: [PATCH 8/8] using DLTI instead of global symbol to get optional
 static MPI rank

---
 mlir/include/mlir/Conversion/Passes.td        |   8 +-
 mlir/lib/Conversion/MeshToMPI/CMakeLists.txt  |   1 +
 mlir/lib/Conversion/MeshToMPI/MeshToMPI.cpp   | 121 ++++++--
 .../MeshToMPI/convert-mesh-to-mpi.mlir        | 279 +++++++++---------
 .../MeshToMPI/convert-shardshape-to-mpi.mlir  | 136 ++++-----
 5 files changed, 305 insertions(+), 240 deletions(-)

diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td
index cccdf0a8518bf..6074e0e8d822c 100644
--- a/mlir/include/mlir/Conversion/Passes.td
+++ b/mlir/include/mlir/Conversion/Passes.td
@@ -881,10 +881,10 @@ def ConvertMeshToMPIPass : Pass<"convert-mesh-to-mpi"> {
   let description = [{
     This pass converts communication operations from the Mesh dialect to the
     MPI dialect.
-    If it finds a global named "static_mpi_rank" it will use that splat value
-    instead of calling MPI_Comm_rank. This allows optimizations like constant
-    shape propagation and fusion because shard/partition sizes depend on the
-    rank.
+    If it finds the DLTI attribute "MPI:comm_world_rank" on the module it will
+    use that integer value instead of calling MPI_Comm_rank. This allows
+    optimizations like constant shape propagation and fusion because
+    shard/partition sizes depend on the rank.
   }];
   let dependentDialects = [
     "memref::MemRefDialect",
diff --git a/mlir/lib/Conversion/MeshToMPI/CMakeLists.txt b/mlir/lib/Conversion/MeshToMPI/CMakeLists.txt
index 95815a683f6d6..15560aa61e145 100644
--- a/mlir/lib/Conversion/MeshToMPI/CMakeLists.txt
+++ b/mlir/lib/Conversion/MeshToMPI/CMakeLists.txt
@@ -11,6 +11,7 @@ add_mlir_conversion_library(MLIRMeshToMPI
   Core
 
   LINK_LIBS PUBLIC
+  MLIRDLTIDialect
   MLIRFuncDialect
   MLIRIR
   MLIRLinalgTransforms
diff --git a/mlir/lib/Conversion/MeshToMPI/MeshToMPI.cpp b/mlir/lib/Conversion/MeshToMPI/MeshToMPI.cpp
index 9d638100e637b..84db6d456711c 100644
--- a/mlir/lib/Conversion/MeshToMPI/MeshToMPI.cpp
+++ b/mlir/lib/Conversion/MeshToMPI/MeshToMPI.cpp
@@ -14,6 +14,7 @@
 
 #include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/Bufferization/IR/Bufferization.h"
+#include "mlir/Dialect/DLTI/DLTI.h"
 #include "mlir/Dialect/Func/IR/FuncOps.h"
 #include "mlir/Dialect/Func/Transforms/FuncConversions.h"
 #include "mlir/Dialect/Linalg/IR/Linalg.h"
@@ -284,28 +285,29 @@ struct ConvertProcessMultiIndexOp
   }
 };
 
-struct ConvertProcessLinearIndexOp
+class ConvertProcessLinearIndexOp
     : public OpConversionPattern<ProcessLinearIndexOp> {
+  int64_t worldRank; // rank in MPI_COMM_WORLD if available, else < 0
+
+public:
   using OpConversionPattern::OpConversionPattern;
 
+  // Constructor accepting worldRank
+  ConvertProcessLinearIndexOp(const TypeConverter &typeConverter,
+                              MLIRContext *context, int64_t worldRank_ = -1)
+      : OpConversionPattern(typeConverter, context), worldRank(worldRank_) {}
+
   LogicalResult
   matchAndRewrite(ProcessLinearIndexOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
 
-    // Finds a global named "static_mpi_rank" it will use that splat value.
-    // Otherwise it defaults to mpi.comm_rank.
-
     Location loc = op.getLoc();
-    auto rankOpName = StringAttr::get(op->getContext(), "static_mpi_rank");
-    if (auto globalOp = SymbolTable::lookupNearestSymbolFrom<memref::GlobalOp>(
-            op, rankOpName)) {
-      if (auto initTnsr = globalOp.getInitialValueAttr()) {
-        auto val = cast<DenseElementsAttr>(initTnsr).getSplatValue<int64_t>();
-        rewriter.replaceOp(op,
-                           rewriter.create<arith::ConstantIndexOp>(loc, val));
-        return success();
-      }
+    if (worldRank >= 0) { // if rank in MPI_COMM_WORLD is known -> use it
+      rewriter.replaceOpWithNewOp<arith::ConstantIndexOp>(op, worldRank);
+      return success();
     }
+
+    // Otherwise create an mpi::CommRankOp
     auto rank =
         rewriter
             .create<mpi::CommRankOp>(
@@ -397,10 +399,22 @@ struct ConvertShardShapeOp : public OpConversionPattern<ShardShapeOp> {
              << " but found " << adaptor.getSharding()[0].getDefiningOp();
     }
 
+    // Compute the sharded shape by applying the sharding to the input shape.
+    // Without shardedDimsOffsets in the sharding, the shard shape is computed
+    // by dividing the dimension size by the number of shards in that dimension
+    // (which is given by the size of the mesh axes provided in split-axes).
+    // Odd elements get distributed to trailing shards.
+    // If a shardedDimsOffsets is provided, the shard shape is computed by
+    // subtracting the offset of the current shard from the offset of the next
+    // shard.
+
     Location loc = op.getLoc();
     Type index = rewriter.getIndexType();
 
-    // Fill a vector with the dynamic dims
+    // This is a 1:N conversion because the sharding op is a 1:3 conversion.
+    // The operands in the adaptor are a vector<ValueRange>. For dims and device
+    // we have a 1:1 conversion.
+    // For simpler access fill a vector with the dynamic dims.
     SmallVector<Value> dynDims, dynDevice;
     for (auto dim : adaptor.getDimsDynamic()) {
       // type conversion should be 1:1 for ints
@@ -413,12 +427,13 @@ struct ConvertShardShapeOp : public OpConversionPattern<ShardShapeOp> {
       dynDevice.emplace_back(device[0]);
     }
 
-    // To keep the code simple, we convert dims/device to values when they are
-    // attributes
+    // To keep the code simple, convert dims/device to values when they are
+    // attributes. Count on canonicalization to fold static values.
     auto shape = getMixedAsValues(rewriter, loc, op.getDims(), dynDims, index);
     auto multiIdx =
         getMixedAsValues(rewriter, loc, adaptor.getDevice(), dynDevice, index);
 
+    // Get the MeshOp, the mesh shape is needed to compute the sharded shape.
     SymbolTableCollection symbolTableCollection;
     auto meshOp = getMesh(sharding, symbolTableCollection);
     // For now we only support static mesh shapes
@@ -426,6 +441,11 @@ struct ConvertShardShapeOp : public OpConversionPattern<ShardShapeOp> {
       return failure();
 
     auto splitAxes = sharding.getSplitAxes().getAxes();
+    // shardedDimsOffsets are optional and might be Values (not attributes).
+    // Also, the shardId might be dynamic which means the position in the
+    // shardedDimsOffsets is not statically known. Create a tensor of the
+    // shardedDimsOffsets and later extract the offsets for computing the
+    // local shard-size.
     Value shardedDimsOffs;
     {
       auto tmp = getMixedAsValues(
@@ -436,18 +456,28 @@ struct ConvertShardShapeOp : public OpConversionPattern<ShardShapeOp> {
             loc, RankedTensorType::get({(int64_t)tmp.size()}, index), tmp);
     }
 
-    int64_t pos = 0; // position in shardedDimsOffs
+    // With static mesh shape the sizes of the split axes are known.
+    // Hence the start/pos for each split axis in shardedDimsOffsets can be
+    // computed statically.
+    int64_t pos = 0;
     SmallVector<Value> shardShape;
     Value zero =
         rewriter.create<arith::ConstantOp>(loc, rewriter.getZeroAttr(index));
     Value one =
         rewriter.create<arith::ConstantOp>(loc, rewriter.getOneAttr(index));
+
+    // Iterate over the dimensions of the tensor shape, get their split Axes,
+    // and compute the sharded shape.
     for (auto [i, dim] : llvm::enumerate(shape)) {
-      if (i < splitAxes.size()) { // trailing dimensions might not be annotated
+      // Trailing dimensions might not be annotated.
+      if (i < splitAxes.size()) {
         auto axes = splitAxes[i];
+        // The current dimension might not be sharded.
         if (!axes.empty()) {
+          // Create a value from the static position in shardedDimsOffsets.
           Value posVal = rewriter.create<arith::ConstantOp>(
               loc, rewriter.getIndexAttr(pos));
+          // Get the index of the local shard in the mesh axis.
           Value idx = multiIdx[axes[0]];
           auto _numShards =
               collectiveProcessGroupSize(axes.asArrayRef(), meshOp.getShape());
@@ -459,6 +489,7 @@ struct ConvertShardShapeOp : public OpConversionPattern<ShardShapeOp> {
                                      << "supported for each dimension.";
             }
             idx = rewriter.create<arith::AddIOp>(loc, posVal, idx);
+            // Compute size = shardedDimsOffs[idx+1] - shardedDimsOffs[idx].
             Value off =
                 rewriter.create<tensor::ExtractOp>(loc, shardedDimsOffs, idx);
             idx = rewriter.create<arith::AddIOp>(loc, idx, one);
@@ -469,9 +500,10 @@ struct ConvertShardShapeOp : public OpConversionPattern<ShardShapeOp> {
           } else {
             auto numShards = rewriter.create<arith::ConstantOp>(
                 loc, rewriter.getIndexAttr(_numShards));
-            // compute shard dim size by distributing odd elements to trailing
-            // shards sz = dim / numShards + (idx >= (numShards - (dim %
-            // numShards)) ? one : zero)
+            // Compute shard dim size by distributing odd elements to trailing
+            // shards:
+            // sz = dim / numShards
+            //      + (idx >= (numShards - (dim % numShards)) ? 1 : 0)
             Value sz = rewriter.create<arith::DivSIOp>(loc, dim, numShards);
             Value sz1 = rewriter.create<arith::RemSIOp>(loc, dim, numShards);
             sz1 = rewriter.create<arith::SubIOp>(loc, numShards, sz1);
@@ -481,9 +513,9 @@ struct ConvertShardShapeOp : public OpConversionPattern<ShardShapeOp> {
             sz = rewriter.create<arith::AddIOp>(loc, sz, odd);
             shardShape.emplace_back(sz);
           }
-          pos += _numShards + 1; // add one for the total size
-        } // else no sharding in this dimension
-      }
+          pos += _numShards + 1; // add one for the total size.
+        } // else no sharding if split axis is empty.
+      } // else no sharding if no split axis
       // If no size was added -> no sharding in this dimension.
       if (shardShape.size() <= i)
         shardShape.emplace_back(dim);
@@ -714,12 +746,31 @@ struct ConvertMeshToMPIPass
 
   /// Run the dialect converter on the module.
   void runOnOperation() override {
+    uint64_t worldRank = -1;
+    // Try to get DLTI attribute for MPI:comm_world_rank
+    // If found, set worldRank to the value of the attribute.
+    {
+      auto dltiAttr =
+          dlti::query(getOperation(), {"MPI:comm_world_rank"}, false);
+      if (succeeded(dltiAttr)) {
+        if (!isa<IntegerAttr>(dltiAttr.value())) {
+          getOperation()->emitError()
+              << "Expected an integer attribute for MPI:comm_world_rank";
+          return signalPassFailure();
+        }
+        worldRank = cast<IntegerAttr>(dltiAttr.value()).getInt();
+      }
+    }
+
     auto *ctxt = &getContext();
     RewritePatternSet patterns(ctxt);
     ConversionTarget target(getContext());
-    TypeConverter typeConverter;
 
+    // Define a type converter to convert mesh::ShardingType,
+    // mostly for use in return operations.
+    TypeConverter typeConverter;
     typeConverter.addConversion([](Type type) { return type; });
+
     // convert mesh::ShardingType to a tuple of RankedTensorTypes
     typeConverter.addConversion(
         [](ShardingType type,
@@ -733,16 +784,20 @@ struct ConvertMeshToMPIPass
           results.emplace_back(RankedTensorType::get(shp, i64));
           return success();
         });
-    // To extract components a UnrealizedConversionCastOp is expected to define
-    // the input
+
+    // To 'extract' components, an UnrealizedConversionCastOp is expected
+    // to define the input
     typeConverter.addTargetMaterialization(
         [&](OpBuilder &builder, TypeRange resultTypes, ValueRange inputs,
             Location loc) {
+          // Expecting a single input.
           if (inputs.size() != 1 || !isa<TupleType>(inputs[0].getType()))
             return SmallVector<Value>();
           auto castOp = inputs[0].getDefiningOp<UnrealizedConversionCastOp>();
+          // Expecting an UnrealizedConversionCastOp.
           if (!castOp)
             return SmallVector<Value>();
+          // Fill a vector with elements of the tuple/castOp.
           SmallVector<Value> results;
           for (auto oprnd : castOp.getInputs()) {
             if (!isa<RankedTensorType>(oprnd.getType()))
@@ -752,12 +807,16 @@ struct ConvertMeshToMPIPass
           return results;
         });
 
+    // No mesh dialect should be left after conversion...
     target.addIllegalDialect<mesh::MeshDialect>();
+    // ...except the global MeshOp
     target.addLegalOp<mesh::MeshOp>();
+    // Allow all the stuff that our patterns will convert to
     target.addLegalDialect<BuiltinDialect, mpi::MPIDialect, scf::SCFDialect,
                            arith::ArithDialect, tensor::TensorDialect,
                            bufferization::BufferizationDialect,
                            linalg::LinalgDialect, memref::MemRefDialect>();
+    // Make sure the function signature, calls etc. are legal
     target.addDynamicallyLegalOp<func::FuncOp>([&](func::FuncOp op) {
       return typeConverter.isSignatureLegal(op.getFunctionType());
     });
@@ -765,9 +824,11 @@ struct ConvertMeshToMPIPass
         [&](Operation *op) { return typeConverter.isLegal(op); });
 
     patterns.add<ConvertUpdateHaloOp, ConvertNeighborsLinearIndicesOp,
-                 ConvertProcessLinearIndexOp, ConvertProcessMultiIndexOp,
-                 ConvertShardingOp, ConvertGetShardingOp, ConvertShardShapeOp>(
-        typeConverter, ctxt);
+                 ConvertProcessMultiIndexOp, ConvertGetShardingOp,
+                 ConvertShardingOp, ConvertShardShapeOp>(typeConverter, ctxt);
+    // ConvertProcessLinearIndexOp accepts an optional worldRank
+    patterns.add<ConvertProcessLinearIndexOp>(typeConverter, ctxt, worldRank);
+
     populateFunctionOpInterfaceTypeConversionPattern<func::FuncOp>(
         patterns, typeConverter);
     populateCallOpTypeConversionPattern(patterns, typeConverter);
diff --git a/mlir/test/Conversion/MeshToMPI/convert-mesh-to-mpi.mlir b/mlir/test/Conversion/MeshToMPI/convert-mesh-to-mpi.mlir
index 6d865ddaf7ce5..4e60c6f0d4e44 100644
--- a/mlir/test/Conversion/MeshToMPI/convert-mesh-to-mpi.mlir
+++ b/mlir/test/Conversion/MeshToMPI/convert-mesh-to-mpi.mlir
@@ -60,23 +60,24 @@ func.func @neighbors_dim2(%arg0 : tensor<120x120x120xi8>) -> (index, index) {
 
 // -----
 // CHECK: mesh.mesh @mesh0
-mesh.mesh @mesh0(shape = 3x4x5)
-memref.global constant @static_mpi_rank : memref<index> = dense<24>
-func.func @process_multi_index() -> (index, index, index) {
-  // CHECK-DAG: %[[c4:.*]] = arith.constant 4 : index
-  // CHECK-DAG: %[[c0:.*]] = arith.constant 0 : index
-  // CHECK-DAG: %[[c1:.*]] = arith.constant 1 : index
-  %0:3 = mesh.process_multi_index on @mesh0 axes = [] : index, index, index
-  // CHECK: return %[[c1]], %[[c0]], %[[c4]] : index, index, index
-  return %0#0, %0#1, %0#2 : index, index, index
-}
+module attributes { mpi.dlti = #dlti.map<"MPI:comm_world_rank" = 24> } {
+  mesh.mesh @mesh0(shape = 3x4x5)
+  func.func @process_multi_index() -> (index, index, index) {
+    // CHECK-DAG: %[[c4:.*]] = arith.constant 4 : index
+    // CHECK-DAG: %[[c0:.*]] = arith.constant 0 : index
+    // CHECK-DAG: %[[c1:.*]] = arith.constant 1 : index
+    %0:3 = mesh.process_multi_index on @mesh0 axes = [] : index, index, index
+    // CHECK: return %[[c1]], %[[c0]], %[[c4]] : index, index, index
+    return %0#0, %0#1, %0#2 : index, index, index
+  }
 
-// CHECK-LABEL: func @process_linear_index
-func.func @process_linear_index() -> index {
-  // CHECK: %[[c24:.*]] = arith.constant 24 : index
-  %0 = mesh.process_linear_index on @mesh0 : index
-  // CHECK: return %[[c24]] : index
-  return %0 : index
+  // CHECK-LABEL: func @process_linear_index
+  func.func @process_linear_index() -> index {
+    // CHECK: %[[c24:.*]] = arith.constant 24 : index
+    %0 = mesh.process_linear_index on @mesh0 : index
+    // CHECK: return %[[c24]] : index
+    return %0 : index
+  }
 }
 
 // -----
@@ -103,132 +104,134 @@ func.func @update_halo_1d_first(
 }
 
 // -----
-mesh.mesh @mesh0(shape = 4)
-memref.global "public" constant @static_mpi_rank : memref<index> = dense<1>
-// CHECK-LABEL: func @update_halo_1d_with_zero
-func.func @update_halo_1d_with_zero (
-  // CHECK-SAME: [[varg0:%.*]]: memref<120x120x120xi8>
-  %arg0 : memref<120x120x120xi8>) -> memref<120x120x120xi8> {
-  // CHECK: [[vc91_i32:%.*]] = arith.constant 91 : i32
-  // CHECK-NEXT: [[vc0_i32:%.*]] = arith.constant 0 : i32
-  // CHECK-NEXT: [[vc2_i32:%.*]] = arith.constant 2 : i32
-  // CHECK-NEXT: [[valloc:%.*]] = memref.alloc() : memref<2x120x120xi8>
-  // CHECK-NEXT: [[vsubview:%.*]] = memref.subview [[varg0]][118, 0, 0] [2, 120, 120] [1, 1, 1] : memref<120x120x120xi8> to memref<2x120x120xi8
-  // CHECK-NEXT: memref.copy [[vsubview]], [[valloc]] : memref<2x120x120xi8
-  // CHECK-SAME: to memref<2x120x120xi8>
-  // CHECK-NEXT: mpi.send([[valloc]], [[vc91_i32]], [[vc2_i32]]) : memref<2x120x120xi8>, i32, i32
-  // CHECK-NEXT: mpi.recv([[valloc]], [[vc91_i32]], [[vc0_i32]]) : memref<2x120x120xi8>, i32, i32
-  // CHECK-NEXT: [[vsubview_0:%.*]] = memref.subview [[varg0]][0, 0, 0] [2, 120, 120] [1, 1, 1] : memref<120x120x120xi8> to memref<2x120x120xi8
-  // CHECK-NEXT: memref.copy [[valloc]], [[vsubview_0]] : memref<2x120x120xi8> to memref<2x120x120xi8
-  // CHECK-NEXT: memref.dealloc [[valloc]] : memref<2x120x120xi8>
-  %res = mesh.update_halo %arg0 on @mesh0 split_axes = [[0]] halo_sizes = [2, 0] : memref<120x120x120xi8>
-  // CHECK: return [[res:%.*]] : memref<120x120x120xi8>
-  return %res : memref<120x120x120xi8>
+module attributes { mpi.dlti = #dlti.map<"MPI:comm_world_rank" = 1> } {
+  mesh.mesh @mesh0(shape = 4)
+  // CHECK-LABEL: func @update_halo_1d_with_zero
+  func.func @update_halo_1d_with_zero (
+    // CHECK-SAME: [[varg0:%.*]]: memref<120x120x120xi8>
+    %arg0 : memref<120x120x120xi8>) -> memref<120x120x120xi8> {
+    // CHECK: [[vc91_i32:%.*]] = arith.constant 91 : i32
+    // CHECK-NEXT: [[vc0_i32:%.*]] = arith.constant 0 : i32
+    // CHECK-NEXT: [[vc2_i32:%.*]] = arith.constant 2 : i32
+    // CHECK-NEXT: [[valloc:%.*]] = memref.alloc() : memref<2x120x120xi8>
+    // CHECK-NEXT: [[vsubview:%.*]] = memref.subview [[varg0]][118, 0, 0] [2, 120, 120] [1, 1, 1] : memref<120x120x120xi8> to memref<2x120x120xi8
+    // CHECK-NEXT: memref.copy [[vsubview]], [[valloc]] : memref<2x120x120xi8
+    // CHECK-SAME: to memref<2x120x120xi8>
+    // CHECK-NEXT: mpi.send([[valloc]], [[vc91_i32]], [[vc2_i32]]) : memref<2x120x120xi8>, i32, i32
+    // CHECK-NEXT: mpi.recv([[valloc]], [[vc91_i32]], [[vc0_i32]]) : memref<2x120x120xi8>, i32, i32
+    // CHECK-NEXT: [[vsubview_0:%.*]] = memref.subview [[varg0]][0, 0, 0] [2, 120, 120] [1, 1, 1] : memref<120x120x120xi8> to memref<2x120x120xi8
+    // CHECK-NEXT: memref.copy [[valloc]], [[vsubview_0]] : memref<2x120x120xi8> to memref<2x120x120xi8
+    // CHECK-NEXT: memref.dealloc [[valloc]] : memref<2x120x120xi8>
+    %res = mesh.update_halo %arg0 on @mesh0 split_axes = [[0]] halo_sizes = [2, 0] : memref<120x120x120xi8>
+    // CHECK: return [[res:%.*]] : memref<120x120x120xi8>
+    return %res : memref<120x120x120xi8>
+  }
 }
 
 // -----
-mesh.mesh @mesh0(shape = 3x4x5)
-memref.global constant @static_mpi_rank : memref<index> = dense<24>
-// CHECK-LABEL: func @update_halo_3d
-func.func @update_halo_3d(
-  // CHECK-SAME: [[varg0:%.*]]: memref<120x120x120xi8>
-  %arg0 : memref<120x120x120xi8>) -> memref<120x120x120xi8> {
-  // CHECK-NEXT: [[vc23_i32:%.*]] = arith.constant 23 : i32
-  // CHECK-NEXT: [[vc29_i32:%.*]] = arith.constant 29 : i32
-  // CHECK-NEXT: [[vc91_i32:%.*]] = arith.constant 91 : i32
-  // CHECK-NEXT: [[vc4_i32:%.*]] = arith.constant 4 : i32
-  // CHECK-NEXT: [[vc44_i32:%.*]] = arith.constant 44 : i32
-  // CHECK-NEXT: [[valloc:%.*]] = memref.alloc() : memref<117x113x5xi8>
-  // CHECK-NEXT: [[vsubview:%.*]] = memref.subview [[varg0]][1, 3, 109] [117, 113, 5] [1, 1, 1] : memref<120x120x120xi8> to memref<117x113x5xi8, strided<[14400, 120, 1], offset: 14869>>
-  // CHECK-NEXT: memref.copy [[vsubview]], [[valloc]] : memref<117x113x5xi8, strided<[14400, 120, 1], offset: 14869>> to memref<117x113x5xi8>
-  // CHECK-NEXT: mpi.send([[valloc]], [[vc91_i32]], [[vc44_i32]]) : memref<117x113x5xi8>, i32, i32
-  // CHECK-NEXT: mpi.recv([[valloc]], [[vc91_i32]], [[vc4_i32]]) : memref<117x113x5xi8>, i32, i32
-  // CHECK-NEXT: [[vsubview_0:%.*]] = memref.subview [[varg0]][1, 3, 0] [117, 113, 5] [1, 1, 1] : memref<120x120x120xi8> to memref<117x113x5xi8, strided<[14400, 120, 1], offset: 14760>>
-  // CHECK-NEXT: memref.copy [[valloc]], [[vsubview_0]] : memref<117x113x5xi8> to memref<117x113x5xi8, strided<[14400, 120, 1], offset: 14760>>
-  // CHECK-NEXT: memref.dealloc [[valloc]] : memref<117x113x5xi8>
-  // CHECK-NEXT: [[valloc_1:%.*]] = memref.alloc() : memref<117x113x6xi8>
-  // CHECK-NEXT: [[vsubview_2:%.*]] = memref.subview [[varg0]][1, 3, 5] [117, 113, 6] [1, 1, 1] : memref<120x120x120xi8> to memref<117x113x6xi8, strided<[14400, 120, 1], offset: 14765>>
-  // CHECK-NEXT: memref.copy [[vsubview_2]], [[valloc_1]] : memref<117x113x6xi8, strided<[14400, 120, 1], offset: 14765>> to memref<117x113x6xi8>
-  // CHECK-NEXT: mpi.send([[valloc_1]], [[vc91_i32]], [[vc4_i32]]) : memref<117x113x6xi8>, i32, i32
-  // CHECK-NEXT: mpi.recv([[valloc_1]], [[vc91_i32]], [[vc44_i32]]) : memref<117x113x6xi8>, i32, i32
-  // CHECK-NEXT: [[vsubview_3:%.*]] = memref.subview [[varg0]][1, 3, 114] [117, 113, 6] [1, 1, 1] : memref<120x120x120xi8> to memref<117x113x6xi8, strided<[14400, 120, 1], offset: 14874>>
-  // CHECK-NEXT: memref.copy [[valloc_1]], [[vsubview_3]] : memref<117x113x6xi8> to memref<117x113x6xi8, strided<[14400, 120, 1], offset: 14874>>
-  // CHECK-NEXT: memref.dealloc [[valloc_1]] : memref<117x113x6xi8>
-  // CHECK-NEXT: [[valloc_4:%.*]] = memref.alloc() : memref<117x3x120xi8>
-  // CHECK-NEXT: [[vsubview_5:%.*]] = memref.subview [[varg0]][1, 113, 0] [117, 3, 120] [1, 1, 1] : memref<120x120x120xi8> to memref<117x3x120xi8, strided<[14400, 120, 1], offset: 27960>>
-  // CHECK-NEXT: memref.copy [[vsubview_5]], [[valloc_4]] : memref<117x3x120xi8, strided<[14400, 120, 1], offset: 27960>> to memref<117x3x120xi8>
-  // CHECK-NEXT: mpi.send([[valloc_4]], [[vc91_i32]], [[vc29_i32]]) : memref<117x3x120xi8>, i32, i32
-  // CHECK-NEXT: memref.dealloc [[valloc_4]] : memref<117x3x120xi8>
-  // CHECK-NEXT: [[valloc_6:%.*]] = memref.alloc() : memref<117x4x120xi8>
-  // CHECK-NEXT: mpi.recv([[valloc_6]], [[vc91_i32]], [[vc29_i32]]) : memref<117x4x120xi8>, i32, i32
-  // CHECK-NEXT: [[vsubview_7:%.*]] = memref.subview [[varg0]][1, 116, 0] [117, 4, 120] [1, 1, 1] : memref<120x120x120xi8> to memref<117x4x120xi8, strided<[14400, 120, 1], offset: 28320>>
-  // CHECK-NEXT: memref.copy [[valloc_6]], [[vsubview_7]] : memref<117x4x120xi8> to memref<117x4x120xi8, strided<[14400, 120, 1], offset: 28320>>
-  // CHECK-NEXT: memref.dealloc [[valloc_6]] : memref<117x4x120xi8>
-  // CHECK-NEXT: [[valloc_8:%.*]] = memref.alloc() : memref<1x120x120xi8>
-  // CHECK-NEXT: mpi.recv([[valloc_8]], [[vc91_i32]], [[vc23_i32]]) : memref<1x120x120xi8>, i32, i32
-  // CHECK-NEXT: [[vsubview_9:%.*]] = memref.subview [[varg0]][0, 0, 0] [1, 120, 120] [1, 1, 1] : memref<120x120x120xi8> to memref<1x120x120xi8, strided<[14400, 120, 1]>>
-  // CHECK-NEXT: memref.copy [[valloc_8]], [[vsubview_9]] : memref<1x120x120xi8> to memref<1x120x120xi8, strided<[14400, 120, 1]>>
-  // CHECK-NEXT: memref.dealloc [[valloc_8]] : memref<1x120x120xi8>
-  // CHECK-NEXT: [[valloc_10:%.*]] = memref.alloc() : memref<2x120x120xi8>
-  // CHECK-NEXT: [[vsubview_11:%.*]] = memref.subview [[varg0]][1, 0, 0] [2, 120, 120] [1, 1, 1] : memref<120x120x120xi8> to memref<2x120x120xi8, strided<[14400, 120, 1], offset: 14400>>
-  // CHECK-NEXT: memref.copy [[vsubview_11]], [[valloc_10]] : memref<2x120x120xi8, strided<[14400, 120, 1], offset: 14400>> to memref<2x120x120xi8>
-  // CHECK-NEXT: mpi.send([[valloc_10]], [[vc91_i32]], [[vc23_i32]]) : memref<2x120x120xi8>, i32, i32
-  // CHECK-NEXT: memref.dealloc [[valloc_10]] : memref<2x120x120xi8>
-  %res = mesh.update_halo %arg0 on @mesh0 split_axes = [[2], [1], [0]] halo_sizes = [1, 2, 3, 4, 5, 6] : memref<120x120x120xi8>
-  // CHECK: return [[varg0]] : memref<120x120x120xi8>
-  return %res : memref<120x120x120xi8>
-}
+module attributes { mpi.dlti = #dlti.map<"MPI:comm_world_rank" = 24> } {
+  mesh.mesh @mesh0(shape = 3x4x5)
+  // CHECK-LABEL: func @update_halo_3d
+  func.func @update_halo_3d(
+    // CHECK-SAME: [[varg0:%.*]]: memref<120x120x120xi8>
+    %arg0 : memref<120x120x120xi8>) -> memref<120x120x120xi8> {
+    // CHECK-NEXT: [[vc23_i32:%.*]] = arith.constant 23 : i32
+    // CHECK-NEXT: [[vc29_i32:%.*]] = arith.constant 29 : i32
+    // CHECK-NEXT: [[vc91_i32:%.*]] = arith.constant 91 : i32
+    // CHECK-NEXT: [[vc4_i32:%.*]] = arith.constant 4 : i32
+    // CHECK-NEXT: [[vc44_i32:%.*]] = arith.constant 44 : i32
+    // CHECK-NEXT: [[valloc:%.*]] = memref.alloc() : memref<117x113x5xi8>
+    // CHECK-NEXT: [[vsubview:%.*]] = memref.subview [[varg0]][1, 3, 109] [117, 113, 5] [1, 1, 1] : memref<120x120x120xi8> to memref<117x113x5xi8, strided<[14400, 120, 1], offset: 14869>>
+    // CHECK-NEXT: memref.copy [[vsubview]], [[valloc]] : memref<117x113x5xi8, strided<[14400, 120, 1], offset: 14869>> to memref<117x113x5xi8>
+    // CHECK-NEXT: mpi.send([[valloc]], [[vc91_i32]], [[vc44_i32]]) : memref<117x113x5xi8>, i32, i32
+    // CHECK-NEXT: mpi.recv([[valloc]], [[vc91_i32]], [[vc4_i32]]) : memref<117x113x5xi8>, i32, i32
+    // CHECK-NEXT: [[vsubview_0:%.*]] = memref.subview [[varg0]][1, 3, 0] [117, 113, 5] [1, 1, 1] : memref<120x120x120xi8> to memref<117x113x5xi8, strided<[14400, 120, 1], offset: 14760>>
+    // CHECK-NEXT: memref.copy [[valloc]], [[vsubview_0]] : memref<117x113x5xi8> to memref<117x113x5xi8, strided<[14400, 120, 1], offset: 14760>>
+    // CHECK-NEXT: memref.dealloc [[valloc]] : memref<117x113x5xi8>
+    // CHECK-NEXT: [[valloc_1:%.*]] = memref.alloc() : memref<117x113x6xi8>
+    // CHECK-NEXT: [[vsubview_2:%.*]] = memref.subview [[varg0]][1, 3, 5] [117, 113, 6] [1, 1, 1] : memref<120x120x120xi8> to memref<117x113x6xi8, strided<[14400, 120, 1], offset: 14765>>
+    // CHECK-NEXT: memref.copy [[vsubview_2]], [[valloc_1]] : memref<117x113x6xi8, strided<[14400, 120, 1], offset: 14765>> to memref<117x113x6xi8>
+    // CHECK-NEXT: mpi.send([[valloc_1]], [[vc91_i32]], [[vc4_i32]]) : memref<117x113x6xi8>, i32, i32
+    // CHECK-NEXT: mpi.recv([[valloc_1]], [[vc91_i32]], [[vc44_i32]]) : memref<117x113x6xi8>, i32, i32
+    // CHECK-NEXT: [[vsubview_3:%.*]] = memref.subview [[varg0]][1, 3, 114] [117, 113, 6] [1, 1, 1] : memref<120x120x120xi8> to memref<117x113x6xi8, strided<[14400, 120, 1], offset: 14874>>
+    // CHECK-NEXT: memref.copy [[valloc_1]], [[vsubview_3]] : memref<117x113x6xi8> to memref<117x113x6xi8, strided<[14400, 120, 1], offset: 14874>>
+    // CHECK-NEXT: memref.dealloc [[valloc_1]] : memref<117x113x6xi8>
+    // CHECK-NEXT: [[valloc_4:%.*]] = memref.alloc() : memref<117x3x120xi8>
+    // CHECK-NEXT: [[vsubview_5:%.*]] = memref.subview [[varg0]][1, 113, 0] [117, 3, 120] [1, 1, 1] : memref<120x120x120xi8> to memref<117x3x120xi8, strided<[14400, 120, 1], offset: 27960>>
+    // CHECK-NEXT: memref.copy [[vsubview_5]], [[valloc_4]] : memref<117x3x120xi8, strided<[14400, 120, 1], offset: 27960>> to memref<117x3x120xi8>
+    // CHECK-NEXT: mpi.send([[valloc_4]], [[vc91_i32]], [[vc29_i32]]) : memref<117x3x120xi8>, i32, i32
+    // CHECK-NEXT: memref.dealloc [[valloc_4]] : memref<117x3x120xi8>
+    // CHECK-NEXT: [[valloc_6:%.*]] = memref.alloc() : memref<117x4x120xi8>
+    // CHECK-NEXT: mpi.recv([[valloc_6]], [[vc91_i32]], [[vc29_i32]]) : memref<117x4x120xi8>, i32, i32
+    // CHECK-NEXT: [[vsubview_7:%.*]] = memref.subview [[varg0]][1, 116, 0] [117, 4, 120] [1, 1, 1] : memref<120x120x120xi8> to memref<117x4x120xi8, strided<[14400, 120, 1], offset: 28320>>
+    // CHECK-NEXT: memref.copy [[valloc_6]], [[vsubview_7]] : memref<117x4x120xi8> to memref<117x4x120xi8, strided<[14400, 120, 1], offset: 28320>>
+    // CHECK-NEXT: memref.dealloc [[valloc_6]] : memref<117x4x120xi8>
+    // CHECK-NEXT: [[valloc_8:%.*]] = memref.alloc() : memref<1x120x120xi8>
+    // CHECK-NEXT: mpi.recv([[valloc_8]], [[vc91_i32]], [[vc23_i32]]) : memref<1x120x120xi8>, i32, i32
+    // CHECK-NEXT: [[vsubview_9:%.*]] = memref.subview [[varg0]][0, 0, 0] [1, 120, 120] [1, 1, 1] : memref<120x120x120xi8> to memref<1x120x120xi8, strided<[14400, 120, 1]>>
+    // CHECK-NEXT: memref.copy [[valloc_8]], [[vsubview_9]] : memref<1x120x120xi8> to memref<1x120x120xi8, strided<[14400, 120, 1]>>
+    // CHECK-NEXT: memref.dealloc [[valloc_8]] : memref<1x120x120xi8>
+    // CHECK-NEXT: [[valloc_10:%.*]] = memref.alloc() : memref<2x120x120xi8>
+    // CHECK-NEXT: [[vsubview_11:%.*]] = memref.subview [[varg0]][1, 0, 0] [2, 120, 120] [1, 1, 1] : memref<120x120x120xi8> to memref<2x120x120xi8, strided<[14400, 120, 1], offset: 14400>>
+    // CHECK-NEXT: memref.copy [[vsubview_11]], [[valloc_10]] : memref<2x120x120xi8, strided<[14400, 120, 1], offset: 14400>> to memref<2x120x120xi8>
+    // CHECK-NEXT: mpi.send([[valloc_10]], [[vc91_i32]], [[vc23_i32]]) : memref<2x120x120xi8>, i32, i32
+    // CHECK-NEXT: memref.dealloc [[valloc_10]] : memref<2x120x120xi8>
+    %res = mesh.update_halo %arg0 on @mesh0 split_axes = [[2], [1], [0]] halo_sizes = [1, 2, 3, 4, 5, 6] : memref<120x120x120xi8>
+    // CHECK: return [[varg0]] : memref<120x120x120xi8>
+    return %res : memref<120x120x120xi8>
+  }
 
-// CHECK-LABEL: func @update_halo_3d_tensor
-func.func @update_halo_3d_tensor(
-  // CHECK-SAME: [[varg0:%.*]]: tensor<120x120x120xi8>
-  %arg0 : tensor<120x120x120xi8>) -> tensor<120x120x120xi8> {
-  // CHECK-NEXT: [[vc23_i32:%.*]] = arith.constant 23 : i32
-  // CHECK-NEXT: [[vc29_i32:%.*]] = arith.constant 29 : i32
-  // CHECK-NEXT: [[vc44_i32:%.*]] = arith.constant 44 : i32
-  // CHECK-NEXT: [[vc4_i32:%.*]] = arith.constant 4 : i32
-  // CHECK-NEXT: [[vc91_i32:%.*]] = arith.constant 91 : i32
-  // CHECK-NEXT: [[v0:%.*]] = bufferization.to_memref [[varg0]] : tensor<120x120x120xi8> to memref<120x120x120xi8>
-  // CHECK-NEXT: [[valloc:%.*]] = memref.alloc() : memref<117x113x5xi8>
-  // CHECK-NEXT: [[vsubview:%.*]] = memref.subview [[v0]][1, 3, 109] [117, 113, 5] [1, 1, 1] : memref<120x120x120xi8> to memref<117x113x5xi8, strided<[14400, 120, 1], offset: 14869>>
-  // CHECK-NEXT: memref.copy [[vsubview]], [[valloc]] : memref<117x113x5xi8, strided<[14400, 120, 1], offset: 14869>> to memref<117x113x5xi8>
-  // CHECK-NEXT: mpi.send([[valloc]], [[vc91_i32]], [[vc44_i32]]) : memref<117x113x5xi8>, i32, i32
-  // CHECK-NEXT: mpi.recv([[valloc]], [[vc91_i32]], [[vc4_i32]]) : memref<117x113x5xi8>, i32, i32
-  // CHECK-NEXT: [[vsubview_0:%.*]] = memref.subview [[v0]][1, 3, 0] [117, 113, 5] [1, 1, 1] : memref<120x120x120xi8> to memref<117x113x5xi8, strided<[14400, 120, 1], offset: 14760>>
-  // CHECK-NEXT: memref.copy [[valloc]], [[vsubview_0]] : memref<117x113x5xi8> to memref<117x113x5xi8, strided<[14400, 120, 1], offset: 14760>>
-  // CHECK-NEXT: memref.dealloc [[valloc]] : memref<117x113x5xi8>
-  // CHECK-NEXT: [[valloc_1:%.*]] = memref.alloc() : memref<117x113x6xi8>
-  // CHECK-NEXT: [[vsubview_2:%.*]] = memref.subview [[v0]][1, 3, 5] [117, 113, 6] [1, 1, 1] : memref<120x120x120xi8> to memref<117x113x6xi8, strided<[14400, 120, 1], offset: 14765>>
-  // CHECK-NEXT: memref.copy [[vsubview_2]], [[valloc_1]] : memref<117x113x6xi8, strided<[14400, 120, 1], offset: 14765>> to memref<117x113x6xi8>
-  // CHECK-NEXT: mpi.send([[valloc_1]], [[vc91_i32]], [[vc4_i32]]) : memref<117x113x6xi8>, i32, i32
-  // CHECK-NEXT: mpi.recv([[valloc_1]], [[vc91_i32]], [[vc44_i32]]) : memref<117x113x6xi8>, i32, i32
-  // CHECK-NEXT: [[vsubview_3:%.*]] = memref.subview [[v0]][1, 3, 114] [117, 113, 6] [1, 1, 1] : memref<120x120x120xi8> to memref<117x113x6xi8, strided<[14400, 120, 1], offset: 14874>>
-  // CHECK-NEXT: memref.copy [[valloc_1]], [[vsubview_3]] : memref<117x113x6xi8> to memref<117x113x6xi8, strided<[14400, 120, 1], offset: 14874>>
-  // CHECK-NEXT: memref.dealloc [[valloc_1]] : memref<117x113x6xi8>
-  // CHECK-NEXT: [[valloc_4:%.*]] = memref.alloc() : memref<117x3x120xi8>
-  // CHECK-NEXT: [[vsubview_5:%.*]] = memref.subview [[v0]][1, 113, 0] [117, 3, 120] [1, 1, 1] : memref<120x120x120xi8> to memref<117x3x120xi8, strided<[14400, 120, 1], offset: 27960>>
-  // CHECK-NEXT: memref.copy [[vsubview_5]], [[valloc_4]] : memref<117x3x120xi8, strided<[14400, 120, 1], offset: 27960>> to memref<117x3x120xi8>
-  // CHECK-NEXT: mpi.send([[valloc_4]], [[vc91_i32]], [[vc29_i32]]) : memref<117x3x120xi8>, i32, i32
-  // CHECK-NEXT: memref.dealloc [[valloc_4]] : memref<117x3x120xi8>
-  // CHECK-NEXT: [[valloc_6:%.*]] = memref.alloc() : memref<117x4x120xi8>
-  // CHECK-NEXT: mpi.recv([[valloc_6]], [[vc91_i32]], [[vc29_i32]]) : memref<117x4x120xi8>, i32, i32
-  // CHECK-NEXT: [[vsubview_7:%.*]] = memref.subview [[v0]][1, 116, 0] [117, 4, 120] [1, 1, 1] : memref<120x120x120xi8> to memref<117x4x120xi8, strided<[14400, 120, 1], offset: 28320>>
-  // CHECK-NEXT: memref.copy [[valloc_6]], [[vsubview_7]] : memref<117x4x120xi8> to memref<117x4x120xi8, strided<[14400, 120, 1], offset: 28320>>
-  // CHECK-NEXT: memref.dealloc [[valloc_6]] : memref<117x4x120xi8>
-  // CHECK-NEXT: [[valloc_8:%.*]] = memref.alloc() : memref<1x120x120xi8>
-  // CHECK-NEXT: mpi.recv([[valloc_8]], [[vc91_i32]], [[vc23_i32]]) : memref<1x120x120xi8>, i32, i32
-  // CHECK-NEXT: [[vsubview_9:%.*]] = memref.subview [[v0]][0, 0, 0] [1, 120, 120] [1, 1, 1] : memref<120x120x120xi8> to memref<1x120x120xi8, strided<[14400, 120, 1]>>
-  // CHECK-NEXT: memref.copy [[valloc_8]], [[vsubview_9]] : memref<1x120x120xi8> to memref<1x120x120xi8, strided<[14400, 120, 1]>>
-  // CHECK-NEXT: memref.dealloc [[valloc_8]] : memref<1x120x120xi8>
-  // CHECK-NEXT: [[valloc_10:%.*]] = memref.alloc() : memref<2x120x120xi8>
-  // CHECK-NEXT: [[vsubview_11:%.*]] = memref.subview [[v0]][1, 0, 0] [2, 120, 120] [1, 1, 1] : memref<120x120x120xi8> to memref<2x120x120xi8, strided<[14400, 120, 1], offset: 14400>>
-  // CHECK-NEXT: memref.copy [[vsubview_11]], [[valloc_10]] : memref<2x120x120xi8, strided<[14400, 120, 1], offset: 14400>> to memref<2x120x120xi8>
-  // CHECK-NEXT: mpi.send([[valloc_10]], [[vc91_i32]], [[vc23_i32]]) : memref<2x120x120xi8>, i32, i32
-  // CHECK-NEXT: memref.dealloc [[valloc_10]] : memref<2x120x120xi8>
-  // CHECK-NEXT: [[v1:%.*]] = bufferization.to_tensor [[v0]] restrict writable : memref<120x120x120xi8> to tensor<120x120x120xi8>
-  %res = mesh.update_halo %arg0 on @mesh0 split_axes = [[2], [1], [0]] halo_sizes = [1, 2, 3, 4, 5, 6] : tensor<120x120x120xi8>
-  // CHECK: return [[v1]] : tensor<120x120x120xi8>
-  return %res : tensor<120x120x120xi8>
+  // CHECK-LABEL: func @update_halo_3d_tensor
+  func.func @update_halo_3d_tensor(
+    // CHECK-SAME: [[varg0:%.*]]: tensor<120x120x120xi8>
+    %arg0 : tensor<120x120x120xi8>) -> tensor<120x120x120xi8> {
+    // CHECK-NEXT: [[vc23_i32:%.*]] = arith.constant 23 : i32
+    // CHECK-NEXT: [[vc29_i32:%.*]] = arith.constant 29 : i32
+    // CHECK-NEXT: [[vc44_i32:%.*]] = arith.constant 44 : i32
+    // CHECK-NEXT: [[vc4_i32:%.*]] = arith.constant 4 : i32
+    // CHECK-NEXT: [[vc91_i32:%.*]] = arith.constant 91 : i32
+    // CHECK-NEXT: [[v0:%.*]] = bufferization.to_memref [[varg0]] : tensor<120x120x120xi8> to memref<120x120x120xi8>
+    // CHECK-NEXT: [[valloc:%.*]] = memref.alloc() : memref<117x113x5xi8>
+    // CHECK-NEXT: [[vsubview:%.*]] = memref.subview [[v0]][1, 3, 109] [117, 113, 5] [1, 1, 1] : memref<120x120x120xi8> to memref<117x113x5xi8, strided<[14400, 120, 1], offset: 14869>>
+    // CHECK-NEXT: memref.copy [[vsubview]], [[valloc]] : memref<117x113x5xi8, strided<[14400, 120, 1], offset: 14869>> to memref<117x113x5xi8>
+    // CHECK-NEXT: mpi.send([[valloc]], [[vc91_i32]], [[vc44_i32]]) : memref<117x113x5xi8>, i32, i32
+    // CHECK-NEXT: mpi.recv([[valloc]], [[vc91_i32]], [[vc4_i32]]) : memref<117x113x5xi8>, i32, i32
+    // CHECK-NEXT: [[vsubview_0:%.*]] = memref.subview [[v0]][1, 3, 0] [117, 113, 5] [1, 1, 1] : memref<120x120x120xi8> to memref<117x113x5xi8, strided<[14400, 120, 1], offset: 14760>>
+    // CHECK-NEXT: memref.copy [[valloc]], [[vsubview_0]] : memref<117x113x5xi8> to memref<117x113x5xi8, strided<[14400, 120, 1], offset: 14760>>
+    // CHECK-NEXT: memref.dealloc [[valloc]] : memref<117x113x5xi8>
+    // CHECK-NEXT: [[valloc_1:%.*]] = memref.alloc() : memref<117x113x6xi8>
+    // CHECK-NEXT: [[vsubview_2:%.*]] = memref.subview [[v0]][1, 3, 5] [117, 113, 6] [1, 1, 1] : memref<120x120x120xi8> to memref<117x113x6xi8, strided<[14400, 120, 1], offset: 14765>>
+    // CHECK-NEXT: memref.copy [[vsubview_2]], [[valloc_1]] : memref<117x113x6xi8, strided<[14400, 120, 1], offset: 14765>> to memref<117x113x6xi8>
+    // CHECK-NEXT: mpi.send([[valloc_1]], [[vc91_i32]], [[vc4_i32]]) : memref<117x113x6xi8>, i32, i32
+    // CHECK-NEXT: mpi.recv([[valloc_1]], [[vc91_i32]], [[vc44_i32]]) : memref<117x113x6xi8>, i32, i32
+    // CHECK-NEXT: [[vsubview_3:%.*]] = memref.subview [[v0]][1, 3, 114] [117, 113, 6] [1, 1, 1] : memref<120x120x120xi8> to memref<117x113x6xi8, strided<[14400, 120, 1], offset: 14874>>
+    // CHECK-NEXT: memref.copy [[valloc_1]], [[vsubview_3]] : memref<117x113x6xi8> to memref<117x113x6xi8, strided<[14400, 120, 1], offset: 14874>>
+    // CHECK-NEXT: memref.dealloc [[valloc_1]] : memref<117x113x6xi8>
+    // CHECK-NEXT: [[valloc_4:%.*]] = memref.alloc() : memref<117x3x120xi8>
+    // CHECK-NEXT: [[vsubview_5:%.*]] = memref.subview [[v0]][1, 113, 0] [117, 3, 120] [1, 1, 1] : memref<120x120x120xi8> to memref<117x3x120xi8, strided<[14400, 120, 1], offset: 27960>>
+    // CHECK-NEXT: memref.copy [[vsubview_5]], [[valloc_4]] : memref<117x3x120xi8, strided<[14400, 120, 1], offset: 27960>> to memref<117x3x120xi8>
+    // CHECK-NEXT: mpi.send([[valloc_4]], [[vc91_i32]], [[vc29_i32]]) : memref<117x3x120xi8>, i32, i32
+    // CHECK-NEXT: memref.dealloc [[valloc_4]] : memref<117x3x120xi8>
+    // CHECK-NEXT: [[valloc_6:%.*]] = memref.alloc() : memref<117x4x120xi8>
+    // CHECK-NEXT: mpi.recv([[valloc_6]], [[vc91_i32]], [[vc29_i32]]) : memref<117x4x120xi8>, i32, i32
+    // CHECK-NEXT: [[vsubview_7:%.*]] = memref.subview [[v0]][1, 116, 0] [117, 4, 120] [1, 1, 1] : memref<120x120x120xi8> to memref<117x4x120xi8, strided<[14400, 120, 1], offset: 28320>>
+    // CHECK-NEXT: memref.copy [[valloc_6]], [[vsubview_7]] : memref<117x4x120xi8> to memref<117x4x120xi8, strided<[14400, 120, 1], offset: 28320>>
+    // CHECK-NEXT: memref.dealloc [[valloc_6]] : memref<117x4x120xi8>
+    // CHECK-NEXT: [[valloc_8:%.*]] = memref.alloc() : memref<1x120x120xi8>
+    // CHECK-NEXT: mpi.recv([[valloc_8]], [[vc91_i32]], [[vc23_i32]]) : memref<1x120x120xi8>, i32, i32
+    // CHECK-NEXT: [[vsubview_9:%.*]] = memref.subview [[v0]][0, 0, 0] [1, 120, 120] [1, 1, 1] : memref<120x120x120xi8> to memref<1x120x120xi8, strided<[14400, 120, 1]>>
+    // CHECK-NEXT: memref.copy [[valloc_8]], [[vsubview_9]] : memref<1x120x120xi8> to memref<1x120x120xi8, strided<[14400, 120, 1]>>
+    // CHECK-NEXT: memref.dealloc [[valloc_8]] : memref<1x120x120xi8>
+    // CHECK-NEXT: [[valloc_10:%.*]] = memref.alloc() : memref<2x120x120xi8>
+    // CHECK-NEXT: [[vsubview_11:%.*]] = memref.subview [[v0]][1, 0, 0] [2, 120, 120] [1, 1, 1] : memref<120x120x120xi8> to memref<2x120x120xi8, strided<[14400, 120, 1], offset: 14400>>
+    // CHECK-NEXT: memref.copy [[vsubview_11]], [[valloc_10]] : memref<2x120x120xi8, strided<[14400, 120, 1], offset: 14400>> to memref<2x120x120xi8>
+    // CHECK-NEXT: mpi.send([[valloc_10]], [[vc91_i32]], [[vc23_i32]]) : memref<2x120x120xi8>, i32, i32
+    // CHECK-NEXT: memref.dealloc [[valloc_10]] : memref<2x120x120xi8>
+    // CHECK-NEXT: [[v1:%.*]] = bufferization.to_tensor [[v0]] restrict writable : memref<120x120x120xi8> to tensor<120x120x120xi8>
+    %res = mesh.update_halo %arg0 on @mesh0 split_axes = [[2], [1], [0]] halo_sizes = [1, 2, 3, 4, 5, 6] : tensor<120x120x120xi8>
+    // CHECK: return [[v1]] : tensor<120x120x120xi8>
+    return %res : tensor<120x120x120xi8>
+  }
 }
 
 // -----
diff --git a/mlir/test/Conversion/MeshToMPI/convert-shardshape-to-mpi.mlir b/mlir/test/Conversion/MeshToMPI/convert-shardshape-to-mpi.mlir
index 1d5cedf57ec6c..1cc848333ced2 100644
--- a/mlir/test/Conversion/MeshToMPI/convert-shardshape-to-mpi.mlir
+++ b/mlir/test/Conversion/MeshToMPI/convert-shardshape-to-mpi.mlir
@@ -1,75 +1,75 @@
-// RUN: mlir-opt %s --convert-mesh-to-mpi \
-// RUN:    | sed 's/%retval, %rank = mpi.comm_rank : !mpi.retval, i32/%rank = arith.constant 24 : i32/g' \
-// RUN:    | mlir-opt -canonicalize \
-// RUN:    | FileCheck %s
+// RUN: mlir-opt %s --convert-mesh-to-mpi -canonicalize | FileCheck %s
 
-// Notice: The sed replacement above defines the linear index to be 24, which is multiindex [1, 0, 4]
+module attributes { mpi.dlti = #dlti.map<"MPI:comm_world_rank" = 24> } {
 
-// CHECK: mesh.mesh @mesh0
-mesh.mesh @mesh0(shape = 3x4x5)
+  // CHECK: mesh.mesh @mesh0
+  mesh.mesh @mesh0(shape = 3x4x5)
+  
+  // Note: comm_world_rank (linear index) 24 corresponds to multi-index [1, 0, 4] in the 3x4x5 @mesh0 (24 = 1*(4*5) + 0*5 + 4)
 
-// all shards are equal
-// CHECK-LABEL: func.func @shard_shape_equal() -> (index, index, index) {
-func.func @shard_shape_equal() -> (index, index, index) {
-  %sharding = mesh.sharding @mesh0 split_axes = [[0], [1], [2]] : !mesh.sharding
-  %0:3 = mesh.process_multi_index on @mesh0 : index, index, index
-  %c9 = arith.constant 9 : index
-  %c12 = arith.constant 12 : index
-  // CHECK: [[vc3:%.*]] = arith.constant 3 : index
-  %1:3 = mesh.shard_shape dims = [%c9, %c12, 15] sharding = %sharding device = [%0#0, %0#1, %0#2] : index, index, index
-  // CHECK: return [[vc3]], [[vc3]], [[vc3]] : index, index, index
-  return %1#0, %1#1, %1#2 : index, index, index
-}
+  // all shards are equal
+  // CHECK-LABEL: func.func @shard_shape_equal() -> (index, index, index) {
+  func.func @shard_shape_equal() -> (index, index, index) {
+    %sharding = mesh.sharding @mesh0 split_axes = [[0], [1], [2]] : !mesh.sharding
+    %0:3 = mesh.process_multi_index on @mesh0 : index, index, index
+    %c9 = arith.constant 9 : index
+    %c12 = arith.constant 12 : index
+    // CHECK: [[vc3:%.*]] = arith.constant 3 : index
+    %1:3 = mesh.shard_shape dims = [%c9, %c12, 15] sharding = %sharding device = [%0#0, %0#1, %0#2] : index, index, index
+    // CHECK: return [[vc3]], [[vc3]], [[vc3]] : index, index, index
+    return %1#0, %1#1, %1#2 : index, index, index
+  }
 
-// last shard in last dim gets an extra element
-// CHECK-LABEL: func.func @shard_shape_odd_1() -> (index, index, index) {
-func.func @shard_shape_odd_1() -> (index, index, index) {
-  %sharding = mesh.sharding @mesh0 split_axes = [[0], [1], [2]] : !mesh.sharding
-  %0:3 = mesh.process_multi_index on @mesh0 : index, index, index
-  %c9 = arith.constant 9 : index
-  %c12 = arith.constant 12 : index
-  // CHECK-DAG: [[vc3:%.*]] = arith.constant 3 : index
-  // CHECK-DAG: [[vc4:%.*]] = arith.constant 4 : index
-  %1:3 = mesh.shard_shape dims = [%c9, %c12, 16] sharding = %sharding device = [%0#0, %0#1, %0#2] : index, index, index
-  // CHECK: return [[vc3]], [[vc3]], [[vc4]] : index, index, index
-  return %1#0, %1#1, %1#2 : index, index, index
-}
+  // last shard in last dim gets an extra element
+  // CHECK-LABEL: func.func @shard_shape_odd_1() -> (index, index, index) {
+  func.func @shard_shape_odd_1() -> (index, index, index) {
+    %sharding = mesh.sharding @mesh0 split_axes = [[0], [1], [2]] : !mesh.sharding
+    %0:3 = mesh.process_multi_index on @mesh0 : index, index, index
+    %c9 = arith.constant 9 : index
+    %c12 = arith.constant 12 : index
+    // CHECK-DAG: [[vc3:%.*]] = arith.constant 3 : index
+    // CHECK-DAG: [[vc4:%.*]] = arith.constant 4 : index
+    %1:3 = mesh.shard_shape dims = [%c9, %c12, 16] sharding = %sharding device = [%0#0, %0#1, %0#2] : index, index, index
+    // CHECK: return [[vc3]], [[vc3]], [[vc4]] : index, index, index
+    return %1#0, %1#1, %1#2 : index, index, index
+  }
 
-// all except first shard in second dim get an extra element
-// CHECK-LABEL: func.func @shard_shape_odd_2() -> (index, index, index) {
-func.func @shard_shape_odd_2() -> (index, index, index) {
-  %sharding = mesh.sharding @mesh0 split_axes = [[0], [1], [2]] : !mesh.sharding
-  %0:3 = mesh.process_multi_index on @mesh0 : index, index, index
-  %c9 = arith.constant 9 : index
-  // CHECK: [[vc3:%.*]] = arith.constant 3 : index
-  %1:3 = mesh.shard_shape dims = [%c9, 15, 15] sharding = %sharding device = [%0#0, %0#1, %0#2] : index, index, index
-  // CHECK: return [[vc3]], [[vc3]], [[vc3]] : index, index, index
-  return %1#0, %1#1, %1#2 : index, index, index
-}
+  // all except first shard in second dim get an extra element
+  // CHECK-LABEL: func.func @shard_shape_odd_2() -> (index, index, index) {
+  func.func @shard_shape_odd_2() -> (index, index, index) {
+    %sharding = mesh.sharding @mesh0 split_axes = [[0], [1], [2]] : !mesh.sharding
+    %0:3 = mesh.process_multi_index on @mesh0 : index, index, index
+    %c9 = arith.constant 9 : index
+    // CHECK: [[vc3:%.*]] = arith.constant 3 : index
+    %1:3 = mesh.shard_shape dims = [%c9, 15, 15] sharding = %sharding device = [%0#0, %0#1, %0#2] : index, index, index
+    // CHECK: return [[vc3]], [[vc3]], [[vc3]] : index, index, index
+    return %1#0, %1#1, %1#2 : index, index, index
+  }
 
-// all except first shard in first dim get an extra element
-// CHECK-LABEL: func.func @shard_shape_odd_3() -> (index, index, index) {
-func.func @shard_shape_odd_3() -> (index, index, index) {
-  %sharding = mesh.sharding @mesh0 split_axes = [[0], [1], [2]] : !mesh.sharding
-  %0:3 = mesh.process_multi_index on @mesh0 : index, index, index
-  // CHECK-DAG: [[vc3:%.*]] = arith.constant 3 : index
-  // CHECK-DAG: [[vc4:%.*]] = arith.constant 4 : index
-  %1:3 = mesh.shard_shape dims = [11, 12, 15] sharding = %sharding device = [%0#0, %0#1, %0#2] : index, index, index
-  // CHECK: return [[vc4]], [[vc3]], [[vc3]] : index, index, index
-  return %1#0, %1#1, %1#2 : index, index, index
-}
+  // all except first shard in first dim get an extra element
+  // CHECK-LABEL: func.func @shard_shape_odd_3() -> (index, index, index) {
+  func.func @shard_shape_odd_3() -> (index, index, index) {
+    %sharding = mesh.sharding @mesh0 split_axes = [[0], [1], [2]] : !mesh.sharding
+    %0:3 = mesh.process_multi_index on @mesh0 : index, index, index
+    // CHECK-DAG: [[vc3:%.*]] = arith.constant 3 : index
+    // CHECK-DAG: [[vc4:%.*]] = arith.constant 4 : index
+    %1:3 = mesh.shard_shape dims = [11, 12, 15] sharding = %sharding device = [%0#0, %0#1, %0#2] : index, index, index
+    // CHECK: return [[vc4]], [[vc3]], [[vc3]] : index, index, index
+    return %1#0, %1#1, %1#2 : index, index, index
+  }
 
-// extract from sharded_dims_offsets
-// CHECK-LABEL: func.func @shard_shape_sharded_dims_offs() -> (index, index, index) {
-func.func @shard_shape_sharded_dims_offs() -> (index, index, index) {
-  %sharding = mesh.sharding @mesh0 split_axes = [[0], [1], [2]]
-      sharded_dims_offsets = [0, 1, 4, 9, 0, 2, 6, 12, 12, 0, 3, 6, 9, 12, 15]: !mesh.sharding
-  %0:3 = mesh.process_multi_index on @mesh0 : index, index, index
-  %c9 = arith.constant 9 : index
-  %c12 = arith.constant 12 : index
-  // CHECK: [[vc3:%.*]] = arith.constant 3 : index
-  // CHECK: [[vc2:%.*]] = arith.constant 2 : index
-  %1:3 = mesh.shard_shape dims = [%c9, %c12, 15] sharding = %sharding device = [%0#0, %0#1, %0#2] : index, index, index
-  // CHECK: return [[vc3]], [[vc2]], [[vc3]] : index, index, index
-  return %1#0, %1#1, %1#2 : index, index, index
-}
+  // extract from sharded_dims_offsets
+  // CHECK-LABEL: func.func @shard_shape_sharded_dims_offs() -> (index, index, index) {
+  func.func @shard_shape_sharded_dims_offs() -> (index, index, index) {
+    %sharding = mesh.sharding @mesh0 split_axes = [[0], [1], [2]]
+        sharded_dims_offsets = [0, 1, 4, 9, 0, 2, 6, 12, 12, 0, 3, 6, 9, 12, 15]: !mesh.sharding
+    %0:3 = mesh.process_multi_index on @mesh0 : index, index, index
+    %c9 = arith.constant 9 : index
+    %c12 = arith.constant 12 : index
+    // CHECK: [[vc3:%.*]] = arith.constant 3 : index
+    // CHECK: [[vc2:%.*]] = arith.constant 2 : index
+    %1:3 = mesh.shard_shape dims = [%c9, %c12, 15] sharding = %sharding device = [%0#0, %0#1, %0#2] : index, index, index
+    // CHECK: return [[vc3]], [[vc2]], [[vc3]] : index, index, index
+    return %1#0, %1#1, %1#2 : index, index, index
+  }
+}
\ No newline at end of file



More information about the Mlir-commits mailing list