[Mlir-commits] [mlir] [mlir][mesh, mpi] More on MeshToMPI (PR #129048)

Frank Schlimbach llvmlistbot at llvm.org
Fri Feb 28 09:41:22 PST 2025


https://github.com/fschlimb updated https://github.com/llvm/llvm-project/pull/129048

>From 2e59db0e7aec204f74a1e385fd618f8a2b466a53 Mon Sep 17 00:00:00 2001
From: "Schlimbach, Frank" <frank.schlimbach at intel.com>
Date: Thu, 27 Feb 2025 17:43:25 +0100
Subject: [PATCH 1/6] type converter for ShardingType, allowing returning a
 \!mesh.sharding

---
 mlir/lib/Conversion/MeshToMPI/MeshToMPI.cpp   | 271 +++++++++++++++-
 .../MeshToMPI/convert-mesh-to-mpi.mlir        | 294 ++++++++++++------
 2 files changed, 457 insertions(+), 108 deletions(-)

diff --git a/mlir/lib/Conversion/MeshToMPI/MeshToMPI.cpp b/mlir/lib/Conversion/MeshToMPI/MeshToMPI.cpp
index 48b3764d520c2..5d230a24a6316 100644
--- a/mlir/lib/Conversion/MeshToMPI/MeshToMPI.cpp
+++ b/mlir/lib/Conversion/MeshToMPI/MeshToMPI.cpp
@@ -14,8 +14,13 @@
 
 #include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/Bufferization/IR/Bufferization.h"
+#include "mlir/Dialect/DLTI/DLTI.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/Func/Transforms/FuncConversions.h"
+#include "mlir/Dialect/Linalg/IR/Linalg.h"
 #include "mlir/Dialect/MPI/IR/MPI.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/Mesh/IR/MeshDialect.h"
 #include "mlir/Dialect/Mesh/IR/MeshOps.h"
 #include "mlir/Dialect/SCF/IR/SCF.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
@@ -25,6 +30,7 @@
 #include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/IR/SymbolTable.h"
+#include "mlir/Transforms/DialectConversion.h"
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
 
 #define DEBUG_TYPE "mesh-to-mpi"
@@ -36,10 +42,32 @@ namespace mlir {
 } // namespace mlir
 
 using namespace mlir;
-using namespace mlir::mesh;
+using namespace mesh;
 
 namespace {
-// Create operations converting a linear index to a multi-dimensional index
+/// Merge static ints and dynamic Values (mixed form) into a vector of Values.
+static SmallVector<Value> getMixedAsValues(OpBuilder b, const Location &loc,
+                                           llvm::ArrayRef<int64_t> statics,
+                                           ValueRange dynamics,
+                                           Type type = Type()) {
+  SmallVector<Value> values;
+  auto dyn = dynamics.begin();
+  Type i64 = b.getI64Type();
+  if (!type)
+    type = i64;
+  assert(i64 == type || b.getIndexType() == type);
+  for (auto s : statics) {
+    values.emplace_back(
+        ShapedType::isDynamic(s)
+            ? *(dyn++)
+            : b.create<arith::ConstantOp>(loc, type,
+                                          i64 == type ? b.getI64IntegerAttr(s)
+                                                      : b.getIndexAttr(s)));
+  }
+  return values;
+};
+
+/// Create operations converting a linear index to a multi-dimensional index.
 static SmallVector<Value> linearToMultiIndex(Location loc, OpBuilder b,
                                              Value linearIndex,
                                              ValueRange dimensions) {
@@ -72,6 +100,152 @@ Value multiToLinearIndex(Location loc, OpBuilder b, ValueRange multiIndex,
   return linearIndex;
 }
 
+/// Replace GetShardingOp with the ShardingOp used by its source's ShardOp.
+struct ConvertGetShardingOp : public OpConversionPattern<GetShardingOp> {
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(GetShardingOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    auto shardOp = adaptor.getSource().getDefiningOp<ShardOp>();
+    if (!shardOp)
+      return failure();
+    auto shardingOp = shardOp.getSharding().getDefiningOp<ShardingOp>();
+    if (!shardingOp)
+      return failure();
+
+    rewriter.replaceOp(op, shardingOp.getResult());
+    return success();
+  }
+};
+
+/// Convert a sharding op to a tuple of tensors of its components
+///   (SplitAxes, HaloSizes, ShardedDimsOffsets)
+/// as defined by type converter.
+struct ConvertShardingOp : public OpConversionPattern<ShardingOp> {
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(ShardingOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    auto splitAxes = op.getSplitAxes().getAxes();
+    int64_t maxNAxes = 0;
+    for (auto axes : splitAxes) {
+      maxNAxes = std::max<int64_t>(maxNAxes, axes.size());
+    }
+
+    // To hold the split axes, create empty 2d tensor with shape
+    // {splitAxes.size(), max-size-of-split-groups}.
+    // Set trailing elements for smaller split-groups to -1.
+    Location loc = op.getLoc();
+    auto i16 = rewriter.getI16Type();
+    auto i64 = rewriter.getI64Type();
+    int64_t shape[] = {static_cast<int64_t>(splitAxes.size()), maxNAxes};
+    Value resSplitAxes = rewriter.create<tensor::EmptyOp>(loc, shape, i16);
+    auto attr = IntegerAttr::get(i16, 0xffff);
+    Value fillValue = rewriter.create<arith::ConstantOp>(loc, i16, attr);
+    resSplitAxes = rewriter.create<linalg::FillOp>(loc, fillValue, resSplitAxes)
+                       .getResult(0);
+
+    // Explicitly write values into the tensor row by row.
+    int64_t strides[] = {1, 1};
+    int64_t nSplits = 0;
+    ValueRange empty = {};
+    for (auto [i, axes] : llvm::enumerate(splitAxes)) {
+      int64_t size = axes.size();
+      if (size > 0)
+        ++nSplits;
+      int64_t offs[] = {(int64_t)i, 0};
+      int64_t sizes[] = {1, size};
+      auto tensorType = RankedTensorType::get({size}, i16);
+      auto attrs = DenseIntElementsAttr::get(tensorType, axes.asArrayRef());
+      auto vals = rewriter.create<arith::ConstantOp>(loc, tensorType, attrs);
+      resSplitAxes = rewriter.create<tensor::InsertSliceOp>(
+          loc, vals, resSplitAxes, empty, empty, empty, offs, sizes, strides);
+    }
+
+    // To hold halo sizes, create 2d Tensor with shape {nSplits, 2}.
+    // Store the halo sizes in the tensor.
+    auto haloSizes =
+        getMixedAsValues(rewriter, loc, adaptor.getStaticHaloSizes(),
+                         adaptor.getDynamicHaloSizes());
+    auto type = RankedTensorType::get({nSplits, 2}, i64);
+    Value resHaloSizes =
+        haloSizes.empty()
+            ? rewriter
+                  .create<tensor::EmptyOp>(loc, std::array<int64_t, 2>{0, 0},
+                                           i64)
+                  .getResult()
+            : rewriter.create<tensor::FromElementsOp>(loc, type, haloSizes)
+                  .getResult();
+
+    // To hold sharded dims offsets, create Tensor with shape {nSplits,
+    // maxSplitSize+1}. Store the offsets in the tensor but set trailing
+    // elements for smaller split-groups to ShapedType::kDynamic. Computing
+    // the max size of the split groups requires collectiveProcessGroupSize
+    // (which needs the MeshOp).
+    Value resOffsets;
+    if (adaptor.getStaticShardedDimsOffsets().empty()) {
+      resOffsets = rewriter.create<tensor::EmptyOp>(
+          loc, std::array<int64_t, 2>{0, 0}, i64);
+    } else {
+      SymbolTableCollection symbolTableCollection;
+      auto meshOp = getMesh(op, symbolTableCollection);
+      auto maxSplitSize = 0;
+      for (auto axes : splitAxes) {
+        int64_t splitSize =
+            collectiveProcessGroupSize(axes.asArrayRef(), meshOp.getShape());
+        assert(splitSize != ShapedType::kDynamic);
+        maxSplitSize = std::max<int64_t>(maxSplitSize, splitSize);
+      }
+      assert(maxSplitSize);
+      ++maxSplitSize; // add one for the total size
+
+      resOffsets = rewriter.create<tensor::EmptyOp>(
+          loc, std::array<int64_t, 2>{nSplits, maxSplitSize}, i64);
+      Value zero = rewriter.create<arith::ConstantOp>(
+          loc, i64, rewriter.getI64IntegerAttr(ShapedType::kDynamic));
+      resOffsets =
+          rewriter.create<linalg::FillOp>(loc, zero, resOffsets).getResult(0);
+      auto offsets =
+          getMixedAsValues(rewriter, loc, adaptor.getStaticShardedDimsOffsets(),
+                           adaptor.getDynamicShardedDimsOffsets());
+      int64_t curr = 0;
+      for (auto [i, axes] : llvm::enumerate(splitAxes)) {
+        int64_t splitSize =
+            collectiveProcessGroupSize(axes.asArrayRef(), meshOp.getShape());
+        assert(splitSize != ShapedType::kDynamic && splitSize < maxSplitSize);
+        ++splitSize; // add one for the total size
+        ArrayRef<Value> values(&offsets[curr], splitSize);
+        Value vals = rewriter.create<tensor::FromElementsOp>(loc, values);
+        int64_t offs[] = {(int64_t)i, 0};
+        int64_t sizes[] = {1, splitSize};
+        resOffsets = rewriter.create<tensor::InsertSliceOp>(
+            loc, vals, resOffsets, empty, empty, empty, offs, sizes, strides);
+        curr += splitSize;
+      }
+    }
+
+    // Return a tuple of tensors as defined by the type converter.
+    SmallVector<Type> resTypes;
+    if (failed(getTypeConverter()->convertType(op.getResult().getType(),
+                                               resTypes)))
+      return failure();
+
+    resSplitAxes =
+        rewriter.create<tensor::CastOp>(loc, resTypes[0], resSplitAxes);
+    resHaloSizes =
+        rewriter.create<tensor::CastOp>(loc, resTypes[1], resHaloSizes);
+    resOffsets = rewriter.create<tensor::CastOp>(loc, resTypes[2], resOffsets);
+
+    rewriter.replaceOpWithNewOp<UnrealizedConversionCastOp>(
+        op, TupleType::get(op.getContext(), resTypes),
+        ValueRange{resSplitAxes, resHaloSizes, resOffsets});
+
+    return success();
+  }
+};
+
 struct ConvertProcessMultiIndexOp
     : public mlir::OpRewritePattern<mlir::mesh::ProcessMultiIndexOp> {
   using OpRewritePattern::OpRewritePattern;
@@ -419,14 +593,95 @@ struct ConvertMeshToMPIPass
 
   /// Run the dialect converter on the module.
   void runOnOperation() override {
-    auto *ctx = &getContext();
-    mlir::RewritePatternSet patterns(ctx);
+    uint64_t worldRank = -1;
+    // Try to get the DLTI attribute for MPI:comm_world_rank.
+    // If found, set worldRank to the value of the attribute.
+    {
+      auto dltiAttr =
+          dlti::query(getOperation(), {"MPI:comm_world_rank"}, false);
+      if (succeeded(dltiAttr)) {
+        if (!isa<IntegerAttr>(dltiAttr.value())) {
+          getOperation()->emitError()
+              << "Expected an integer attribute for MPI:comm_world_rank";
+          return signalPassFailure();
+        }
+        worldRank = cast<IntegerAttr>(dltiAttr.value()).getInt();
+      }
+    }
 
-    patterns.insert<ConvertUpdateHaloOp, ConvertNeighborsLinearIndicesOp,
-                    ConvertProcessLinearIndexOp, ConvertProcessMultiIndexOp>(
-        ctx);
+    auto *ctxt = &getContext();
+    RewritePatternSet patterns(ctxt);
+    ConversionTarget target(getContext());
+
+    // Define a type converter to convert mesh::ShardingType,
+    // mostly for use in return operations.
+    TypeConverter typeConverter;
+    typeConverter.addConversion([](Type type) { return type; });
+
+    // Convert mesh::ShardingType to a tuple of RankedTensorTypes.
+    typeConverter.addConversion(
+        [](ShardingType type,
+           SmallVectorImpl<Type> &results) -> std::optional<LogicalResult> {
+          auto i16 = IntegerType::get(type.getContext(), 16);
+          auto i64 = IntegerType::get(type.getContext(), 64);
+          std::array<int64_t, 2> shp{ShapedType::kDynamic,
+                                     ShapedType::kDynamic};
+          results.emplace_back(RankedTensorType::get(shp, i16));
+          results.emplace_back(RankedTensorType::get(shp, i64)); // actually ?x2
+          results.emplace_back(RankedTensorType::get(shp, i64));
+          return success();
+        });
+
+    // To 'extract' components, an UnrealizedConversionCastOp is expected
+    // to define the input.
+    typeConverter.addTargetMaterialization(
+        [&](OpBuilder &builder, TypeRange resultTypes, ValueRange inputs,
+            Location loc) {
+          // Expecting a single input.
+          if (inputs.size() != 1 || !isa<TupleType>(inputs[0].getType()))
+            return SmallVector<Value>();
+          auto castOp = inputs[0].getDefiningOp<UnrealizedConversionCastOp>();
+          // Expecting an UnrealizedConversionCastOp.
+          if (!castOp)
+            return SmallVector<Value>();
+          // Fill a vector with elements of the tuple/castOp.
+          SmallVector<Value> results;
+          for (auto oprnd : castOp.getInputs()) {
+            if (!isa<RankedTensorType>(oprnd.getType()))
+              return SmallVector<Value>();
+            results.emplace_back(oprnd);
+          }
+          return results;
+        });
 
-    (void)mlir::applyPatternsGreedily(getOperation(), std::move(patterns));
+    // No mesh dialect ops should be left after conversion...
+    target.addIllegalDialect<mesh::MeshDialect>();
+    // ...except the global MeshOp.
+    target.addLegalOp<mesh::MeshOp>();
+    // Allow all the stuff that our patterns will convert to.
+    target.addLegalDialect<BuiltinDialect, mpi::MPIDialect, scf::SCFDialect,
+                           arith::ArithDialect, tensor::TensorDialect,
+                           bufferization::BufferizationDialect,
+                           linalg::LinalgDialect, memref::MemRefDialect>();
+    // Make sure the function signature, calls etc. are legal.
+    target.addDynamicallyLegalOp<func::FuncOp>([&](func::FuncOp op) {
+      return typeConverter.isSignatureLegal(op.getFunctionType());
+    });
+    target.addDynamicallyLegalOp<func::CallOp, func::ReturnOp>(
+        [&](Operation *op) { return typeConverter.isLegal(op); });
+
+    patterns.add<ConvertUpdateHaloOp, ConvertNeighborsLinearIndicesOp,
+                 ConvertProcessMultiIndexOp, ConvertGetShardingOp,
+                 ConvertShardingOp, ConvertShardShapeOp>(typeConverter, ctxt);
+    // ConvertProcessLinearIndexOp accepts an optional worldRank.
+    patterns.add<ConvertProcessLinearIndexOp>(typeConverter, ctxt, worldRank);
+
+    populateFunctionOpInterfaceTypeConversionPattern<func::FuncOp>(
+        patterns, typeConverter);
+    populateCallOpTypeConversionPattern(patterns, typeConverter);
+    populateReturnOpTypeConversionPattern(patterns, typeConverter);
+
+    (void)applyPartialConversion(getOperation(), target, std::move(patterns));
   }
 };
 
diff --git a/mlir/test/Conversion/MeshToMPI/convert-mesh-to-mpi.mlir b/mlir/test/Conversion/MeshToMPI/convert-mesh-to-mpi.mlir
index c1aef97438bd5..90bd80472c2b9 100644
--- a/mlir/test/Conversion/MeshToMPI/convert-mesh-to-mpi.mlir
+++ b/mlir/test/Conversion/MeshToMPI/convert-mesh-to-mpi.mlir
@@ -103,106 +103,200 @@ func.func @update_halo_1d_first(
 }
 
 // -----
-mesh.mesh @mesh0(shape = 3x4x5)
-memref.global constant @static_mpi_rank : memref<index> = dense<24>
-// CHECK-LABEL: func @update_halo_3d
-func.func @update_halo_3d(
-  // CHECK-SAME: [[varg0:%.*]]: memref<120x120x120xi8>
-  %arg0 : memref<120x120x120xi8>) -> memref<120x120x120xi8> {
-  // CHECK: [[vc23_i32:%.*]] = arith.constant 23 : i32
-  // CHECK-NEXT: [[vc29_i32:%.*]] = arith.constant 29 : i32
-  // CHECK-NEXT: [[vc91_i32:%.*]] = arith.constant 91 : i32
-  // CHECK-NEXT: [[vc4_i32:%.*]] = arith.constant 4 : i32
-  // CHECK-NEXT: [[vc44_i32:%.*]] = arith.constant 44 : i32
-  // CHECK-NEXT: [[valloc:%.*]] = memref.alloc() : memref<117x113x5xi8>
-  // CHECK-NEXT: [[vsubview:%.*]] = memref.subview [[varg0]][1, 3, 109] [117, 113, 5] [1, 1, 1] : memref<120x120x120xi8> to memref<117x113x5xi8, strided<[14400, 120, 1], offset: 14869>>
-  // CHECK-NEXT: memref.copy [[vsubview]], [[valloc]] : memref<117x113x5xi8, strided<[14400, 120, 1], offset: 14869>> to memref<117x113x5xi8>
-  // CHECK-NEXT: mpi.send([[valloc]], [[vc91_i32]], [[vc4_i32]]) : memref<117x113x5xi8>, i32, i32
-  // CHECK-NEXT: mpi.recv([[valloc]], [[vc91_i32]], [[vc44_i32]]) : memref<117x113x5xi8>, i32, i32
-  // CHECK-NEXT: [[vsubview_0:%.*]] = memref.subview [[varg0]][1, 3, 0] [117, 113, 5] [1, 1, 1] : memref<120x120x120xi8> to memref<117x113x5xi8, strided<[14400, 120, 1], offset: 14760>>
-  // CHECK-NEXT: memref.copy [[valloc]], [[vsubview_0]] : memref<117x113x5xi8> to memref<117x113x5xi8, strided<[14400, 120, 1], offset: 14760>>
-  // CHECK-NEXT: memref.dealloc [[valloc]] : memref<117x113x5xi8>
-  // CHECK-NEXT: [[valloc_1:%.*]] = memref.alloc() : memref<117x113x6xi8>
-  // CHECK-NEXT: [[vsubview_3:%.*]] = memref.subview [[varg0]][1, 3, 5] [117, 113, 6] [1, 1, 1] : memref<120x120x120xi8> to memref<117x113x6xi8, strided<[14400, 120, 1], offset: 14765>>
-  // CHECK-NEXT: memref.copy [[vsubview_3]], [[valloc_1]] : memref<117x113x6xi8, strided<[14400, 120, 1], offset: 14765>> to memref<117x113x6xi8>
-  // CHECK-NEXT: mpi.send([[valloc_1]], [[vc91_i32]], [[vc44_i32]]) : memref<117x113x6xi8>, i32, i32
-  // CHECK-NEXT: mpi.recv([[valloc_1]], [[vc91_i32]], [[vc4_i32]]) : memref<117x113x6xi8>, i32, i32
-  // CHECK-NEXT: [[vsubview_4:%.*]] = memref.subview [[varg0]][1, 3, 114] [117, 113, 6] [1, 1, 1] : memref<120x120x120xi8> to memref<117x113x6xi8, strided<[14400, 120, 1], offset: 14874>>
-  // CHECK-NEXT: memref.copy [[valloc_1]], [[vsubview_4]] : memref<117x113x6xi8> to memref<117x113x6xi8, strided<[14400, 120, 1], offset: 14874>>
-  // CHECK-NEXT: memref.dealloc [[valloc_1]] : memref<117x113x6xi8>
-  // CHECK-NEXT: [[valloc_5:%.*]] = memref.alloc() : memref<117x3x120xi8>
-  // CHECK-NEXT: mpi.recv([[valloc_5]], [[vc91_i32]], [[vc29_i32]]) : memref<117x3x120xi8>, i32, i32
-  // CHECK-NEXT: [[vsubview_7:%.*]] = memref.subview [[varg0]][1, 0, 0] [117, 3, 120] [1, 1, 1] : memref<120x120x120xi8> to memref<117x3x120xi8, strided<[14400, 120, 1], offset: 14400>>
-  // CHECK-NEXT: memref.copy [[valloc_5]], [[vsubview_7]] : memref<117x3x120xi8> to memref<117x3x120xi8, strided<[14400, 120, 1], offset: 14400>>
-  // CHECK-NEXT: memref.dealloc [[valloc_5]] : memref<117x3x120xi8>
-  // CHECK-NEXT: [[valloc_8:%.*]] = memref.alloc() : memref<117x4x120xi8>
-  // CHECK-NEXT: [[vsubview_10:%.*]] = memref.subview [[varg0]][1, 3, 0] [117, 4, 120] [1, 1, 1] : memref<120x120x120xi8> to memref<117x4x120xi8, strided<[14400, 120, 1], offset: 14760>>
-  // CHECK-NEXT: memref.copy [[vsubview_10]], [[valloc_8]] : memref<117x4x120xi8, strided<[14400, 120, 1], offset: 14760>> to memref<117x4x120xi8>
-  // CHECK-NEXT: mpi.send([[valloc_8]], [[vc91_i32]], [[vc29_i32]]) : memref<117x4x120xi8>, i32, i32
-  // CHECK-NEXT: memref.dealloc [[valloc_8]] : memref<117x4x120xi8>
-  // CHECK-NEXT: [[valloc_11:%.*]] = memref.alloc() : memref<1x120x120xi8>
-  // CHECK-NEXT: [[vsubview_12:%.*]] = memref.subview [[varg0]][117, 0, 0] [1, 120, 120] [1, 1, 1] : memref<120x120x120xi8> to memref<1x120x120xi8, strided<[14400, 120, 1], offset: 1684800>>
-  // CHECK-NEXT: memref.copy [[vsubview_12]], [[valloc_11]] : memref<1x120x120xi8, strided<[14400, 120, 1], offset: 1684800>> to memref<1x120x120xi8>
-  // CHECK-NEXT: mpi.send([[valloc_11]], [[vc91_i32]], [[vc23_i32]]) : memref<1x120x120xi8>, i32, i32
-  // CHECK-NEXT: memref.dealloc [[valloc_11]] : memref<1x120x120xi8>
-  // CHECK-NEXT: [[valloc_13:%.*]] = memref.alloc() : memref<2x120x120xi8>
-  // CHECK-NEXT: mpi.recv([[valloc_13]], [[vc91_i32]], [[vc23_i32]]) : memref<2x120x120xi8>, i32, i32
-  // CHECK-NEXT: [[vsubview_14:%.*]] = memref.subview [[varg0]][118, 0, 0] [2, 120, 120] [1, 1, 1] : memref<120x120x120xi8> to memref<2x120x120xi8, strided<[14400, 120, 1], offset: 1699200>>
-  // CHECK-NEXT: memref.copy [[valloc_13]], [[vsubview_14]] : memref<2x120x120xi8> to memref<2x120x120xi8, strided<[14400, 120, 1], offset: 1699200>>
-  // CHECK-NEXT: memref.dealloc [[valloc_13]] : memref<2x120x120xi8>
-  %res = mesh.update_halo %arg0 on @mesh0 split_axes = [[2], [1], [0]] halo_sizes = [1, 2, 3, 4, 5, 6] : memref<120x120x120xi8>
-  // CHECK: return [[varg0]] : memref<120x120x120xi8>
-  return %res : memref<120x120x120xi8>
+module attributes { mpi.dlti = #dlti.map<"MPI:comm_world_rank" = 1> } {
+  mesh.mesh @mesh0(shape = 4)
+  // CHECK-LABEL: func @update_halo_1d_with_zero
+  func.func @update_halo_1d_with_zero (
+    // CHECK-SAME: [[varg0:%.*]]: memref<120x120x120xi8>
+    %arg0 : memref<120x120x120xi8>) -> memref<120x120x120xi8> {
+    // CHECK: [[vc91_i32:%.*]] = arith.constant 91 : i32
+    // CHECK-NEXT: [[vc0_i32:%.*]] = arith.constant 0 : i32
+    // CHECK-NEXT: [[vc2_i32:%.*]] = arith.constant 2 : i32
+    // CHECK-NEXT: [[valloc:%.*]] = memref.alloc() : memref<2x120x120xi8>
+    // CHECK-NEXT: [[vsubview:%.*]] = memref.subview [[varg0]][118, 0, 0] [2, 120, 120] [1, 1, 1] : memref<120x120x120xi8> to memref<2x120x120xi8
+    // CHECK-NEXT: memref.copy [[vsubview]], [[valloc]] : memref<2x120x120xi8
+    // CHECK-SAME: to memref<2x120x120xi8>
+    // CHECK-NEXT: mpi.send([[valloc]], [[vc91_i32]], [[vc2_i32]]) : memref<2x120x120xi8>, i32, i32
+    // CHECK-NEXT: mpi.recv([[valloc]], [[vc91_i32]], [[vc0_i32]]) : memref<2x120x120xi8>, i32, i32
+    // CHECK-NEXT: [[vsubview_0:%.*]] = memref.subview [[varg0]][0, 0, 0] [2, 120, 120] [1, 1, 1] : memref<120x120x120xi8> to memref<2x120x120xi8
+    // CHECK-NEXT: memref.copy [[valloc]], [[vsubview_0]] : memref<2x120x120xi8> to memref<2x120x120xi8
+    // CHECK-NEXT: memref.dealloc [[valloc]] : memref<2x120x120xi8>
+    %res = mesh.update_halo %arg0 on @mesh0 split_axes = [[0]] halo_sizes = [2, 0] : memref<120x120x120xi8>
+    // CHECK: return [[res:%.*]] : memref<120x120x120xi8>
+    return %res : memref<120x120x120xi8>
+  }
+}
+
+// -----
+module attributes { mpi.dlti = #dlti.map<"MPI:comm_world_rank" = 24> } {
+  mesh.mesh @mesh0(shape = 3x4x5)
+  // CHECK-LABEL: func @update_halo_3d
+  func.func @update_halo_3d(
+    // CHECK-SAME: [[varg0:%.*]]: memref<120x120x120xi8>
+    %arg0 : memref<120x120x120xi8>) -> memref<120x120x120xi8> {
+    // CHECK-NEXT: [[vc23_i32:%.*]] = arith.constant 23 : i32
+    // CHECK-NEXT: [[vc29_i32:%.*]] = arith.constant 29 : i32
+    // CHECK-NEXT: [[vc91_i32:%.*]] = arith.constant 91 : i32
+    // CHECK-NEXT: [[vc4_i32:%.*]] = arith.constant 4 : i32
+    // CHECK-NEXT: [[vc44_i32:%.*]] = arith.constant 44 : i32
+    // CHECK-NEXT: [[valloc:%.*]] = memref.alloc() : memref<117x113x5xi8>
+    // CHECK-NEXT: [[vsubview:%.*]] = memref.subview [[varg0]][1, 3, 109] [117, 113, 5] [1, 1, 1] : memref<120x120x120xi8> to memref<117x113x5xi8, strided<[14400, 120, 1], offset: 14869>>
+    // CHECK-NEXT: memref.copy [[vsubview]], [[valloc]] : memref<117x113x5xi8, strided<[14400, 120, 1], offset: 14869>> to memref<117x113x5xi8>
+    // CHECK-NEXT: mpi.send([[valloc]], [[vc91_i32]], [[vc44_i32]]) : memref<117x113x5xi8>, i32, i32
+    // CHECK-NEXT: mpi.recv([[valloc]], [[vc91_i32]], [[vc4_i32]]) : memref<117x113x5xi8>, i32, i32
+    // CHECK-NEXT: [[vsubview_0:%.*]] = memref.subview [[varg0]][1, 3, 0] [117, 113, 5] [1, 1, 1] : memref<120x120x120xi8> to memref<117x113x5xi8, strided<[14400, 120, 1], offset: 14760>>
+    // CHECK-NEXT: memref.copy [[valloc]], [[vsubview_0]] : memref<117x113x5xi8> to memref<117x113x5xi8, strided<[14400, 120, 1], offset: 14760>>
+    // CHECK-NEXT: memref.dealloc [[valloc]] : memref<117x113x5xi8>
+    // CHECK-NEXT: [[valloc_1:%.*]] = memref.alloc() : memref<117x113x6xi8>
+    // CHECK-NEXT: [[vsubview_2:%.*]] = memref.subview [[varg0]][1, 3, 5] [117, 113, 6] [1, 1, 1] : memref<120x120x120xi8> to memref<117x113x6xi8, strided<[14400, 120, 1], offset: 14765>>
+    // CHECK-NEXT: memref.copy [[vsubview_2]], [[valloc_1]] : memref<117x113x6xi8, strided<[14400, 120, 1], offset: 14765>> to memref<117x113x6xi8>
+    // CHECK-NEXT: mpi.send([[valloc_1]], [[vc91_i32]], [[vc4_i32]]) : memref<117x113x6xi8>, i32, i32
+    // CHECK-NEXT: mpi.recv([[valloc_1]], [[vc91_i32]], [[vc44_i32]]) : memref<117x113x6xi8>, i32, i32
+    // CHECK-NEXT: [[vsubview_3:%.*]] = memref.subview [[varg0]][1, 3, 114] [117, 113, 6] [1, 1, 1] : memref<120x120x120xi8> to memref<117x113x6xi8, strided<[14400, 120, 1], offset: 14874>>
+    // CHECK-NEXT: memref.copy [[valloc_1]], [[vsubview_3]] : memref<117x113x6xi8> to memref<117x113x6xi8, strided<[14400, 120, 1], offset: 14874>>
+    // CHECK-NEXT: memref.dealloc [[valloc_1]] : memref<117x113x6xi8>
+    // CHECK-NEXT: [[valloc_4:%.*]] = memref.alloc() : memref<117x3x120xi8>
+    // CHECK-NEXT: [[vsubview_5:%.*]] = memref.subview [[varg0]][1, 113, 0] [117, 3, 120] [1, 1, 1] : memref<120x120x120xi8> to memref<117x3x120xi8, strided<[14400, 120, 1], offset: 27960>>
+    // CHECK-NEXT: memref.copy [[vsubview_5]], [[valloc_4]] : memref<117x3x120xi8, strided<[14400, 120, 1], offset: 27960>> to memref<117x3x120xi8>
+    // CHECK-NEXT: mpi.send([[valloc_4]], [[vc91_i32]], [[vc29_i32]]) : memref<117x3x120xi8>, i32, i32
+    // CHECK-NEXT: memref.dealloc [[valloc_4]] : memref<117x3x120xi8>
+    // CHECK-NEXT: [[valloc_6:%.*]] = memref.alloc() : memref<117x4x120xi8>
+    // CHECK-NEXT: mpi.recv([[valloc_6]], [[vc91_i32]], [[vc29_i32]]) : memref<117x4x120xi8>, i32, i32
+    // CHECK-NEXT: [[vsubview_7:%.*]] = memref.subview [[varg0]][1, 116, 0] [117, 4, 120] [1, 1, 1] : memref<120x120x120xi8> to memref<117x4x120xi8, strided<[14400, 120, 1], offset: 28320>>
+    // CHECK-NEXT: memref.copy [[valloc_6]], [[vsubview_7]] : memref<117x4x120xi8> to memref<117x4x120xi8, strided<[14400, 120, 1], offset: 28320>>
+    // CHECK-NEXT: memref.dealloc [[valloc_6]] : memref<117x4x120xi8>
+    // CHECK-NEXT: [[valloc_8:%.*]] = memref.alloc() : memref<1x120x120xi8>
+    // CHECK-NEXT: mpi.recv([[valloc_8]], [[vc91_i32]], [[vc23_i32]]) : memref<1x120x120xi8>, i32, i32
+    // CHECK-NEXT: [[vsubview_9:%.*]] = memref.subview [[varg0]][0, 0, 0] [1, 120, 120] [1, 1, 1] : memref<120x120x120xi8> to memref<1x120x120xi8, strided<[14400, 120, 1]>>
+    // CHECK-NEXT: memref.copy [[valloc_8]], [[vsubview_9]] : memref<1x120x120xi8> to memref<1x120x120xi8, strided<[14400, 120, 1]>>
+    // CHECK-NEXT: memref.dealloc [[valloc_8]] : memref<1x120x120xi8>
+    // CHECK-NEXT: [[valloc_10:%.*]] = memref.alloc() : memref<2x120x120xi8>
+    // CHECK-NEXT: [[vsubview_11:%.*]] = memref.subview [[varg0]][1, 0, 0] [2, 120, 120] [1, 1, 1] : memref<120x120x120xi8> to memref<2x120x120xi8, strided<[14400, 120, 1], offset: 14400>>
+    // CHECK-NEXT: memref.copy [[vsubview_11]], [[valloc_10]] : memref<2x120x120xi8, strided<[14400, 120, 1], offset: 14400>> to memref<2x120x120xi8>
+    // CHECK-NEXT: mpi.send([[valloc_10]], [[vc91_i32]], [[vc23_i32]]) : memref<2x120x120xi8>, i32, i32
+    // CHECK-NEXT: memref.dealloc [[valloc_10]] : memref<2x120x120xi8>
+    %res = mesh.update_halo %arg0 on @mesh0 split_axes = [[2], [1], [0]] halo_sizes = [1, 2, 3, 4, 5, 6] : memref<120x120x120xi8>
+    // CHECK: return [[varg0]] : memref<120x120x120xi8>
+    return %res : memref<120x120x120xi8>
+  }
+
+  // CHECK-LABEL: func @update_halo_3d_tensor
+  func.func @update_halo_3d_tensor(
+    // CHECK-SAME: [[varg0:%.*]]: tensor<120x120x120xi8>
+    %arg0 : tensor<120x120x120xi8>) -> tensor<120x120x120xi8> {
+    // CHECK-NEXT: [[vc23_i32:%.*]] = arith.constant 23 : i32
+    // CHECK-NEXT: [[vc29_i32:%.*]] = arith.constant 29 : i32
+    // CHECK-NEXT: [[vc44_i32:%.*]] = arith.constant 44 : i32
+    // CHECK-NEXT: [[vc4_i32:%.*]] = arith.constant 4 : i32
+    // CHECK-NEXT: [[vc91_i32:%.*]] = arith.constant 91 : i32
+    // CHECK-NEXT: [[v0:%.*]] = bufferization.to_memref [[varg0]] : tensor<120x120x120xi8> to memref<120x120x120xi8>
+    // CHECK-NEXT: [[valloc:%.*]] = memref.alloc() : memref<117x113x5xi8>
+    // CHECK-NEXT: [[vsubview:%.*]] = memref.subview [[v0]][1, 3, 109] [117, 113, 5] [1, 1, 1] : memref<120x120x120xi8> to memref<117x113x5xi8, strided<[14400, 120, 1], offset: 14869>>
+    // CHECK-NEXT: memref.copy [[vsubview]], [[valloc]] : memref<117x113x5xi8, strided<[14400, 120, 1], offset: 14869>> to memref<117x113x5xi8>
+    // CHECK-NEXT: mpi.send([[valloc]], [[vc91_i32]], [[vc44_i32]]) : memref<117x113x5xi8>, i32, i32
+    // CHECK-NEXT: mpi.recv([[valloc]], [[vc91_i32]], [[vc4_i32]]) : memref<117x113x5xi8>, i32, i32
+    // CHECK-NEXT: [[vsubview_0:%.*]] = memref.subview [[v0]][1, 3, 0] [117, 113, 5] [1, 1, 1] : memref<120x120x120xi8> to memref<117x113x5xi8, strided<[14400, 120, 1], offset: 14760>>
+    // CHECK-NEXT: memref.copy [[valloc]], [[vsubview_0]] : memref<117x113x5xi8> to memref<117x113x5xi8, strided<[14400, 120, 1], offset: 14760>>
+    // CHECK-NEXT: memref.dealloc [[valloc]] : memref<117x113x5xi8>
+    // CHECK-NEXT: [[valloc_1:%.*]] = memref.alloc() : memref<117x113x6xi8>
+    // CHECK-NEXT: [[vsubview_2:%.*]] = memref.subview [[v0]][1, 3, 5] [117, 113, 6] [1, 1, 1] : memref<120x120x120xi8> to memref<117x113x6xi8, strided<[14400, 120, 1], offset: 14765>>
+    // CHECK-NEXT: memref.copy [[vsubview_2]], [[valloc_1]] : memref<117x113x6xi8, strided<[14400, 120, 1], offset: 14765>> to memref<117x113x6xi8>
+    // CHECK-NEXT: mpi.send([[valloc_1]], [[vc91_i32]], [[vc4_i32]]) : memref<117x113x6xi8>, i32, i32
+    // CHECK-NEXT: mpi.recv([[valloc_1]], [[vc91_i32]], [[vc44_i32]]) : memref<117x113x6xi8>, i32, i32
+    // CHECK-NEXT: [[vsubview_3:%.*]] = memref.subview [[v0]][1, 3, 114] [117, 113, 6] [1, 1, 1] : memref<120x120x120xi8> to memref<117x113x6xi8, strided<[14400, 120, 1], offset: 14874>>
+    // CHECK-NEXT: memref.copy [[valloc_1]], [[vsubview_3]] : memref<117x113x6xi8> to memref<117x113x6xi8, strided<[14400, 120, 1], offset: 14874>>
+    // CHECK-NEXT: memref.dealloc [[valloc_1]] : memref<117x113x6xi8>
+    // CHECK-NEXT: [[valloc_4:%.*]] = memref.alloc() : memref<117x3x120xi8>
+    // CHECK-NEXT: [[vsubview_5:%.*]] = memref.subview [[v0]][1, 113, 0] [117, 3, 120] [1, 1, 1] : memref<120x120x120xi8> to memref<117x3x120xi8, strided<[14400, 120, 1], offset: 27960>>
+    // CHECK-NEXT: memref.copy [[vsubview_5]], [[valloc_4]] : memref<117x3x120xi8, strided<[14400, 120, 1], offset: 27960>> to memref<117x3x120xi8>
+    // CHECK-NEXT: mpi.send([[valloc_4]], [[vc91_i32]], [[vc29_i32]]) : memref<117x3x120xi8>, i32, i32
+    // CHECK-NEXT: memref.dealloc [[valloc_4]] : memref<117x3x120xi8>
+    // CHECK-NEXT: [[valloc_6:%.*]] = memref.alloc() : memref<117x4x120xi8>
+    // CHECK-NEXT: mpi.recv([[valloc_6]], [[vc91_i32]], [[vc29_i32]]) : memref<117x4x120xi8>, i32, i32
+    // CHECK-NEXT: [[vsubview_7:%.*]] = memref.subview [[v0]][1, 116, 0] [117, 4, 120] [1, 1, 1] : memref<120x120x120xi8> to memref<117x4x120xi8, strided<[14400, 120, 1], offset: 28320>>
+    // CHECK-NEXT: memref.copy [[valloc_6]], [[vsubview_7]] : memref<117x4x120xi8> to memref<117x4x120xi8, strided<[14400, 120, 1], offset: 28320>>
+    // CHECK-NEXT: memref.dealloc [[valloc_6]] : memref<117x4x120xi8>
+    // CHECK-NEXT: [[valloc_8:%.*]] = memref.alloc() : memref<1x120x120xi8>
+    // CHECK-NEXT: mpi.recv([[valloc_8]], [[vc91_i32]], [[vc23_i32]]) : memref<1x120x120xi8>, i32, i32
+    // CHECK-NEXT: [[vsubview_9:%.*]] = memref.subview [[v0]][0, 0, 0] [1, 120, 120] [1, 1, 1] : memref<120x120x120xi8> to memref<1x120x120xi8, strided<[14400, 120, 1]>>
+    // CHECK-NEXT: memref.copy [[valloc_8]], [[vsubview_9]] : memref<1x120x120xi8> to memref<1x120x120xi8, strided<[14400, 120, 1]>>
+    // CHECK-NEXT: memref.dealloc [[valloc_8]] : memref<1x120x120xi8>
+    // CHECK-NEXT: [[valloc_10:%.*]] = memref.alloc() : memref<2x120x120xi8>
+    // CHECK-NEXT: [[vsubview_11:%.*]] = memref.subview [[v0]][1, 0, 0] [2, 120, 120] [1, 1, 1] : memref<120x120x120xi8> to memref<2x120x120xi8, strided<[14400, 120, 1], offset: 14400>>
+    // CHECK-NEXT: memref.copy [[vsubview_11]], [[valloc_10]] : memref<2x120x120xi8, strided<[14400, 120, 1], offset: 14400>> to memref<2x120x120xi8>
+    // CHECK-NEXT: mpi.send([[valloc_10]], [[vc91_i32]], [[vc23_i32]]) : memref<2x120x120xi8>, i32, i32
+    // CHECK-NEXT: memref.dealloc [[valloc_10]] : memref<2x120x120xi8>
+    // CHECK-NEXT: [[v1:%.*]] = bufferization.to_tensor [[v0]] restrict writable : memref<120x120x120xi8> to tensor<120x120x120xi8>
+    %res = mesh.update_halo %arg0 on @mesh0 split_axes = [[2], [1], [0]] halo_sizes = [1, 2, 3, 4, 5, 6] : tensor<120x120x120xi8>
+    // CHECK: return [[v1]] : tensor<120x120x120xi8>
+    return %res : tensor<120x120x120xi8>
+  }
+}
+
+// -----
+mesh.mesh @mesh0(shape = 2x2x4)
+// CHECK-LABEL: func.func @return_sharding(
+// CHECK-SAME: [[varg0:%.*]]: tensor<2x4xf32>) -> (tensor<2x4xf32>, tensor<?x?xi16>, tensor<?x?xi64>, tensor<?x?xi64>) {
+func.func @return_sharding(%arg0: tensor<2x4xf32>) -> (tensor<2x4xf32>, !mesh.sharding) {
+  %sharding = mesh.sharding @mesh0 split_axes = [[0, 1], [2]] : !mesh.sharding
+  // CHECK-NEXT: [[vcst:%.*]] = arith.constant dense<2> : tensor<1xi16>
+  // CHECK-NEXT: [[vcst_0:%.*]] = arith.constant dense<[0, 1]> : tensor<2xi16>
+  // CHECK-NEXT: [[vcm1_i16:%.*]] = arith.constant -1 : i16
+  // CHECK-NEXT: [[v0:%.*]] = tensor.empty() : tensor<2x2xi16>
+  // CHECK-NEXT: [[v1:%.*]] = linalg.fill ins([[vcm1_i16]] : i16) outs([[v0]] : tensor<2x2xi16>) -> tensor<2x2xi16>
+  // CHECK-NEXT: [[vinserted_slice:%.*]] = tensor.insert_slice [[vcst_0]] into [[v1]][0, 0] [1, 2] [1, 1] : tensor<2xi16> into tensor<2x2xi16>
+  // CHECK-NEXT: [[vinserted_slice_1:%.*]] = tensor.insert_slice [[vcst]] into [[vinserted_slice]][1, 0] [1, 1] [1, 1] : tensor<1xi16> into tensor<2x2xi16>
+  // CHECK-NEXT: [[v2:%.*]] = tensor.empty() : tensor<0x0xi64>
+  // CHECK-NEXT: [[v3:%.*]] = tensor.empty() : tensor<0x0xi64>
+  // CHECK-NEXT: [[vcast:%.*]] = tensor.cast [[vinserted_slice_1]] : tensor<2x2xi16> to tensor<?x?xi16>
+  // CHECK-NEXT: [[vcast_2:%.*]] = tensor.cast [[v2]] : tensor<0x0xi64> to tensor<?x?xi64>
+  // CHECK-NEXT: [[vcast_3:%.*]] = tensor.cast [[v3]] : tensor<0x0xi64> to tensor<?x?xi64>
+  // CHECK-NEXT: return [[varg0]], [[vcast]], [[vcast_2]], [[vcast_3]] : tensor<2x4xf32>, tensor<?x?xi16>, tensor<?x?xi64>, tensor<?x?xi64>
+  return %arg0, %sharding : tensor<2x4xf32>, !mesh.sharding
+}
+
+// CHECK-LABEL: func.func @return_sharding_halos(
+// CHECK-SAME: [[varg0:%.*]]: tensor<6x8xf32>) -> (tensor<6x8xf32>, tensor<?x?xi16>, tensor<?x?xi64>, tensor<?x?xi64>) {
+func.func @return_sharding_halos(%arg0: tensor<6x8xf32>) -> (tensor<6x8xf32>, !mesh.sharding) {
+  %sharding = mesh.sharding @mesh0 split_axes = [[0, 1], [2]] halo_sizes = [0, 4, 3, 1] : !mesh.sharding
+  // CHECK-NEXT: [[vcst:%.*]] = arith.constant dense<{{\[\[}}0, 4], [3, 1]]> : tensor<2x2xi64>
+  // CHECK-NEXT: [[vcst_0:%.*]] = arith.constant dense<2> : tensor<1xi16>
+  // CHECK-NEXT: [[vcst_1:%.*]] = arith.constant dense<[0, 1]> : tensor<2xi16>
+  // CHECK-NEXT: [[vcm1_i16:%.*]] = arith.constant -1 : i16
+  // CHECK-NEXT: [[v0:%.*]] = tensor.empty() : tensor<2x2xi16>
+  // CHECK-NEXT: [[v1:%.*]] = linalg.fill ins([[vcm1_i16]] : i16) outs([[v0]] : tensor<2x2xi16>) -> tensor<2x2xi16>
+  // CHECK-NEXT: [[vinserted_slice:%.*]] = tensor.insert_slice [[vcst_1]] into [[v1]][0, 0] [1, 2] [1, 1] : tensor<2xi16> into tensor<2x2xi16>
+  // CHECK-NEXT: [[vinserted_slice_2:%.*]] = tensor.insert_slice [[vcst_0]] into [[vinserted_slice]][1, 0] [1, 1] [1, 1] : tensor<1xi16> into tensor<2x2xi16>
+  // CHECK-NEXT: [[v2:%.*]] = tensor.empty() : tensor<0x0xi64>
+  // CHECK-NEXT: [[vcast:%.*]] = tensor.cast [[vinserted_slice_2]] : tensor<2x2xi16> to tensor<?x?xi16>
+  // CHECK-NEXT: [[vcast_3:%.*]] = tensor.cast [[vcst]] : tensor<2x2xi64> to tensor<?x?xi64>
+  // CHECK-NEXT: [[vcast_4:%.*]] = tensor.cast [[v2]] : tensor<0x0xi64> to tensor<?x?xi64>
+  // CHECK-NEXT: return [[varg0]], [[vcast]], [[vcast_3]], [[vcast_4]] : tensor<6x8xf32>, tensor<?x?xi16>, tensor<?x?xi64>, tensor<?x?xi64>
+  return %arg0, %sharding : tensor<6x8xf32>, !mesh.sharding
 }
 
-// CHECK-LABEL: func @update_halo_3d_tensor
-func.func @update_halo_3d_tensor(
-  // CHECK-SAME: [[varg0:%.*]]: tensor<120x120x120xi8>
-  %arg0 : tensor<120x120x120xi8>) -> tensor<120x120x120xi8> {
-  // CHECK: [[vc23_i32:%.*]] = arith.constant 23 : i32
-  // CHECK-NEXT: [[vc29_i32:%.*]] = arith.constant 29 : i32
-  // CHECK-NEXT: [[vc44_i32:%.*]] = arith.constant 44 : i32
-  // CHECK-NEXT: [[vc4_i32:%.*]] = arith.constant 4 : i32
-  // CHECK-NEXT: [[vc91_i32:%.*]] = arith.constant 91 : i32
-  // CHECK-NEXT: [[v0:%.*]] = bufferization.to_memref [[varg0]] : tensor<120x120x120xi8> to memref<120x120x120xi8>
-  // CHECK-NEXT: [[valloc:%.*]] = memref.alloc() : memref<117x113x5xi8>
-  // CHECK-NEXT: [[vsubview:%.*]] = memref.subview [[v0]][1, 3, 109] [117, 113, 5] [1, 1, 1] : memref<120x120x120xi8> to memref<117x113x5xi8, strided<[14400, 120, 1], offset: 14869>>
-  // CHECK-NEXT: memref.copy [[vsubview]], [[valloc]] : memref<117x113x5xi8, strided<[14400, 120, 1], offset: 14869>> to memref<117x113x5xi8>
-  // CHECK-NEXT: mpi.send([[valloc]], [[vc91_i32]], [[vc4_i32]]) : memref<117x113x5xi8>, i32, i32
-  // CHECK-NEXT: mpi.recv([[valloc]], [[vc91_i32]], [[vc44_i32]]) : memref<117x113x5xi8>, i32, i32
-  // CHECK-NEXT: [[vsubview_0:%.*]] = memref.subview [[v0]][1, 3, 0] [117, 113, 5] [1, 1, 1] : memref<120x120x120xi8> to memref<117x113x5xi8, strided<[14400, 120, 1], offset: 14760>>
-  // CHECK-NEXT: memref.copy [[valloc]], [[vsubview_0]] : memref<117x113x5xi8> to memref<117x113x5xi8, strided<[14400, 120, 1], offset: 14760>>
-  // CHECK-NEXT: memref.dealloc [[valloc]] : memref<117x113x5xi8>
-  // CHECK-NEXT: [[valloc_1:%.*]] = memref.alloc() : memref<117x113x6xi8>
-  // CHECK-NEXT: [[vsubview_3:%.*]] = memref.subview [[v0]][1, 3, 5] [117, 113, 6] [1, 1, 1] : memref<120x120x120xi8> to memref<117x113x6xi8, strided<[14400, 120, 1], offset: 14765>>
-  // CHECK-NEXT: memref.copy [[vsubview_3]], [[valloc_1]] : memref<117x113x6xi8, strided<[14400, 120, 1], offset: 14765>> to memref<117x113x6xi8>
-  // CHECK-NEXT: mpi.send([[valloc_1]], [[vc91_i32]], [[vc44_i32]]) : memref<117x113x6xi8>, i32, i32
-  // CHECK-NEXT: mpi.recv([[valloc_1]], [[vc91_i32]], [[vc4_i32]]) : memref<117x113x6xi8>, i32, i32
-  // CHECK-NEXT: [[vsubview_4:%.*]] = memref.subview [[v0]][1, 3, 114] [117, 113, 6] [1, 1, 1] : memref<120x120x120xi8> to memref<117x113x6xi8, strided<[14400, 120, 1], offset: 14874>>
-  // CHECK-NEXT: memref.copy [[valloc_1]], [[vsubview_4]] : memref<117x113x6xi8> to memref<117x113x6xi8, strided<[14400, 120, 1], offset: 14874>>
-  // CHECK-NEXT: memref.dealloc [[valloc_1]] : memref<117x113x6xi8>
-  // CHECK-NEXT: [[valloc_5:%.*]] = memref.alloc() : memref<117x3x120xi8>
-  // CHECK-NEXT: mpi.recv([[valloc_5]], [[vc91_i32]], [[vc29_i32]]) : memref<117x3x120xi8>, i32, i32
-  // CHECK-NEXT: [[vsubview_7:%.*]] = memref.subview [[v0]][1, 0, 0] [117, 3, 120] [1, 1, 1] : memref<120x120x120xi8> to memref<117x3x120xi8, strided<[14400, 120, 1], offset: 14400>>
-  // CHECK-NEXT: memref.copy [[valloc_5]], [[vsubview_7]] : memref<117x3x120xi8> to memref<117x3x120xi8, strided<[14400, 120, 1], offset: 14400>>
-  // CHECK-NEXT: memref.dealloc [[valloc_5]] : memref<117x3x120xi8>
-  // CHECK-NEXT: [[valloc_8:%.*]] = memref.alloc() : memref<117x4x120xi8>
-  // CHECK-NEXT: [[vsubview_10:%.*]] = memref.subview [[v0]][1, 3, 0] [117, 4, 120] [1, 1, 1] : memref<120x120x120xi8> to memref<117x4x120xi8, strided<[14400, 120, 1], offset: 14760>>
-  // CHECK-NEXT: memref.copy [[vsubview_10]], [[valloc_8]] : memref<117x4x120xi8, strided<[14400, 120, 1], offset: 14760>> to memref<117x4x120xi8>
-  // CHECK-NEXT: mpi.send([[valloc_8]], [[vc91_i32]], [[vc29_i32]]) : memref<117x4x120xi8>, i32, i32
-  // CHECK-NEXT: memref.dealloc [[valloc_8]] : memref<117x4x120xi8>
-  // CHECK-NEXT: [[valloc_11:%.*]] = memref.alloc() : memref<1x120x120xi8>
-  // CHECK-NEXT: [[vsubview_12:%.*]] = memref.subview [[v0]][117, 0, 0] [1, 120, 120] [1, 1, 1] : memref<120x120x120xi8> to memref<1x120x120xi8, strided<[14400, 120, 1], offset: 1684800>>
-  // CHECK-NEXT: memref.copy [[vsubview_12]], [[valloc_11]] : memref<1x120x120xi8, strided<[14400, 120, 1], offset: 1684800>> to memref<1x120x120xi8>
-  // CHECK-NEXT: mpi.send([[valloc_11]], [[vc91_i32]], [[vc23_i32]]) : memref<1x120x120xi8>, i32, i32
-  // CHECK-NEXT: memref.dealloc [[valloc_11]] : memref<1x120x120xi8>
-  // CHECK-NEXT: [[valloc_13:%.*]] = memref.alloc() : memref<2x120x120xi8>
-  // CHECK-NEXT: mpi.recv([[valloc_13]], [[vc91_i32]], [[vc23_i32]]) : memref<2x120x120xi8>, i32, i32
-  // CHECK-NEXT: [[vsubview_14:%.*]] = memref.subview [[v0]][118, 0, 0] [2, 120, 120] [1, 1, 1] : memref<120x120x120xi8> to memref<2x120x120xi8, strided<[14400, 120, 1], offset: 1699200>>
-  // CHECK-NEXT: memref.copy [[valloc_13]], [[vsubview_14]] : memref<2x120x120xi8> to memref<2x120x120xi8, strided<[14400, 120, 1], offset: 1699200>>
-  // CHECK-NEXT: memref.dealloc [[valloc_13]] : memref<2x120x120xi8>
-  // CHECK-NEXT: [[v1:%.*]] = bufferization.to_tensor [[v0]] restrict writable : memref<120x120x120xi8>
-  %res = mesh.update_halo %arg0 on @mesh0 split_axes = [[2], [1], [0]] halo_sizes = [1, 2, 3, 4, 5, 6] : tensor<120x120x120xi8>
-  // CHECK: return [[v1]] : tensor<120x120x120xi8>
-  return %res : tensor<120x120x120xi8>
+// CHECK-LABEL: func.func @return_sharding_offs(
+// CHECK-SAME: [[varg0:%.*]]: tensor<?x?xf32>) -> (tensor<?x?xf32>, tensor<?x?xi16>, tensor<?x?xi64>, tensor<?x?xi64>) {
+func.func @return_sharding_offs(%arg0: tensor<?x?xf32>) -> (tensor<?x?xf32>, !mesh.sharding) {
+  %sharding = mesh.sharding @mesh0 split_axes = [[0, 1], [2]] sharded_dims_offsets = [0, 3, 5, 7, 8, 0, 0, 5, 10, 16] : !mesh.sharding
+  // CHECK-NEXT: [[vcst:%.*]] = arith.constant dense<[0, 0, 5, 10, 16]> : tensor<5xi64>
+  // CHECK-NEXT: [[vcst_0:%.*]] = arith.constant dense<[0, 3, 5, 7, 8]> : tensor<5xi64>
+  // CHECK-NEXT: [[vcm9223372036854775808_i64:%.*]] = arith.constant -9223372036854775808 : i64
+  // CHECK-NEXT: [[vcst_1:%.*]] = arith.constant dense<2> : tensor<1xi16>
+  // CHECK-NEXT: [[vcst_2:%.*]] = arith.constant dense<[0, 1]> : tensor<2xi16>
+  // CHECK-NEXT: [[vcm1_i16:%.*]] = arith.constant -1 : i16
+  // CHECK-NEXT: [[v0:%.*]] = tensor.empty() : tensor<2x2xi16>
+  // CHECK-NEXT: [[v1:%.*]] = linalg.fill ins([[vcm1_i16]] : i16) outs([[v0]] : tensor<2x2xi16>) -> tensor<2x2xi16>
+  // CHECK-NEXT: [[vinserted_slice:%.*]] = tensor.insert_slice [[vcst_2]] into [[v1]][0, 0] [1, 2] [1, 1] : tensor<2xi16> into tensor<2x2xi16>
+  // CHECK-NEXT: [[vinserted_slice_3:%.*]] = tensor.insert_slice [[vcst_1]] into [[vinserted_slice]][1, 0] [1, 1] [1, 1] : tensor<1xi16> into tensor<2x2xi16>
+  // CHECK-NEXT: [[v2:%.*]] = tensor.empty() : tensor<0x0xi64>
+  // CHECK-NEXT: [[v3:%.*]] = tensor.empty() : tensor<2x5xi64>
+  // CHECK-NEXT: [[v4:%.*]] = linalg.fill ins([[vcm9223372036854775808_i64]] : i64) outs([[v3]] : tensor<2x5xi64>) -> tensor<2x5xi64>
+  // CHECK-NEXT: [[vinserted_slice_4:%.*]] = tensor.insert_slice [[vcst_0]] into [[v4]][0, 0] [1, 5] [1, 1] : tensor<5xi64> into tensor<2x5xi64>
+  // CHECK-NEXT: [[vinserted_slice_5:%.*]] = tensor.insert_slice [[vcst]] into [[vinserted_slice_4]][1, 0] [1, 5] [1, 1] : tensor<5xi64> into tensor<2x5xi64>
+  // CHECK-NEXT: [[vcast:%.*]] = tensor.cast [[vinserted_slice_3]] : tensor<2x2xi16> to tensor<?x?xi16>
+  // CHECK-NEXT: [[vcast_6:%.*]] = tensor.cast [[v2]] : tensor<0x0xi64> to tensor<?x?xi64>
+  // CHECK-NEXT: [[vcast_7:%.*]] = tensor.cast [[vinserted_slice_5]] : tensor<2x5xi64> to tensor<?x?xi64>
+  // CHECK-NEXT: return [[varg0]], [[vcast]], [[vcast_6]], [[vcast_7]] : tensor<?x?xf32>, tensor<?x?xi16>, tensor<?x?xi64>, tensor<?x?xi64>
+  return %arg0, %sharding : tensor<?x?xf32>, !mesh.sharding
 }

>From 30e6cd18a84cbc6dccfef0857070403f9015fa70 Mon Sep 17 00:00:00 2001
From: "Schlimbach, Frank" <frank.schlimbach at intel.com>
Date: Thu, 27 Feb 2025 17:48:07 +0100
Subject: [PATCH 2/6] fixing halo send/recv direction, no communication if no
 halo

---
 mlir/lib/Conversion/MeshToMPI/MeshToMPI.cpp | 54 +++++++++++++++------
 1 file changed, 38 insertions(+), 16 deletions(-)

diff --git a/mlir/lib/Conversion/MeshToMPI/MeshToMPI.cpp b/mlir/lib/Conversion/MeshToMPI/MeshToMPI.cpp
index 5d230a24a6316..ee05276eb8366 100644
--- a/mlir/lib/Conversion/MeshToMPI/MeshToMPI.cpp
+++ b/mlir/lib/Conversion/MeshToMPI/MeshToMPI.cpp
@@ -410,42 +410,47 @@ struct ConvertUpdateHaloOp
     // local data. Because subviews and halos can have mixed dynamic and static
     // shapes, OpFoldResults are used whenever possible.
 
+    auto haloSizes = getMixedValues(adaptor.getStaticHaloSizes(),
+                                    adaptor.getHaloSizes(), rewriter);
+    if (haloSizes.empty()) {
+      // no halos -> nothing to do
+      rewriter.replaceOp(op, adaptor.getDestination());
+      return success();
+    }
+
     SymbolTableCollection symbolTableCollection;
-    auto loc = op.getLoc();
+    Location loc = op.getLoc();
 
     // convert a OpFoldResult into a Value
     auto toValue = [&rewriter, &loc](OpFoldResult &v) -> Value {
       if (auto value = dyn_cast<Value>(v))
         return value;
-      return rewriter.create<::mlir::arith::ConstantOp>(
+      return rewriter.create<arith::ConstantOp>(
           loc, rewriter.getIndexAttr(
                    cast<IntegerAttr>(cast<Attribute>(v)).getInt()));
     };
 
-    auto dest = op.getDestination();
+    auto dest = adaptor.getDestination();
     auto dstShape = cast<ShapedType>(dest.getType()).getShape();
     Value array = dest;
     if (isa<RankedTensorType>(array.getType())) {
       // If the destination is a tensor, we need to cast it to a memref
       auto tensorType = MemRefType::get(
           dstShape, cast<ShapedType>(array.getType()).getElementType());
-      array = rewriter.create<bufferization::ToMemrefOp>(loc, tensorType, array)
-                  .getResult();
+      array =
+          rewriter.create<bufferization::ToMemrefOp>(loc, tensorType, array);
     }
     auto rank = cast<ShapedType>(array.getType()).getRank();
-    auto opSplitAxes = op.getSplitAxes().getAxes();
-    auto mesh = op.getMesh();
+    auto opSplitAxes = adaptor.getSplitAxes().getAxes();
+    auto mesh = adaptor.getMesh();
     auto meshOp = getMesh(op, symbolTableCollection);
-    auto haloSizes =
-        getMixedValues(op.getStaticHaloSizes(), op.getHaloSizes(), rewriter);
     // subviews need Index values
     for (auto &sz : haloSizes) {
-      if (auto value = dyn_cast<Value>(sz)) {
+      if (auto value = dyn_cast<Value>(sz))
         sz =
             rewriter
                 .create<arith::IndexCastOp>(loc, rewriter.getIndexType(), value)
                 .getResult();
-      }
     }
 
     // most of the offset/size/stride data is the same for all dims
@@ -530,8 +535,8 @@ struct ConvertUpdateHaloOp
                                   : haloSizes[currHaloDim * 2];
         // Check if we need to send and/or receive
         // Processes on the mesh borders have only one neighbor
-        auto to = upperHalo ? neighbourIDs[1] : neighbourIDs[0];
-        auto from = upperHalo ? neighbourIDs[0] : neighbourIDs[1];
+        auto to = upperHalo ? neighbourIDs[0] : neighbourIDs[1];
+        auto from = upperHalo ? neighbourIDs[1] : neighbourIDs[0];
         auto hasFrom = rewriter.create<arith::CmpIOp>(
             loc, arith::CmpIPredicate::sge, from, zero);
         auto hasTo = rewriter.create<arith::CmpIOp>(
@@ -564,8 +569,25 @@ struct ConvertUpdateHaloOp
         offsets[dim] = orgOffset;
       };
 
-      genSendRecv(false);
-      genSendRecv(true);
+      auto get_i32val = [&](OpFoldResult &v) {
+        return isa<Value>(v)
+                   ? cast<Value>(v)
+                   : rewriter.create<arith::ConstantOp>(
+                         loc,
+                         rewriter.getI32IntegerAttr(
+                             cast<IntegerAttr>(cast<Attribute>(v)).getInt()));
+      };
+
+      for (int i = 0; i < 2; ++i) {
+        Value haloSz = get_i32val(haloSizes[currHaloDim * 2 + i]);
+        auto hasSize = rewriter.create<arith::CmpIOp>(
+            loc, arith::CmpIPredicate::sgt, haloSz, zero);
+        rewriter.create<scf::IfOp>(loc, hasSize,
+                                   [&](OpBuilder &builder, Location loc) {
+                                     genSendRecv(i > 0);
+                                     builder.create<scf::YieldOp>(loc);
+                                   });
+      }
 
       // the shape for lower dims include higher dims' halos
       dimSizes[dim] = shape[dim];
@@ -583,7 +605,7 @@ struct ConvertUpdateHaloOp
                                  loc, op.getResult().getType(), array,
                                  /*restrict=*/true, /*writable=*/true));
     }
-    return mlir::success();
+    return success();
   }
 };
 

>From 3e7d98dcdec057166e26d5b3303d53250ec49bc9 Mon Sep 17 00:00:00 2001
From: "Schlimbach, Frank" <frank.schlimbach at intel.com>
Date: Thu, 27 Feb 2025 17:52:26 +0100
Subject: [PATCH 3/6] Fixing definition of ShardShapeOp and lowering it to MPI

---
 mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td  |  22 ++-
 mlir/lib/Conversion/MeshToMPI/MeshToMPI.cpp   | 154 +++++++++++++++++-
 mlir/lib/Dialect/Mesh/IR/MeshOps.cpp          |  15 +-
 .../Extensions/MeshShardingExtensions.cpp     |  19 ++-
 .../MeshToMPI/convert-shardshape-to-mpi.mlir  |  75 +++++++++
 mlir/test/Dialect/Mesh/ops.mlir               |  10 +-
 .../test/Dialect/Tensor/mesh-spmdization.mlir |  11 +-
 7 files changed, 269 insertions(+), 37 deletions(-)
 create mode 100644 mlir/test/Conversion/MeshToMPI/convert-shardshape-to-mpi.mlir

diff --git a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td
index 031e6f63bcb42..f59c4c4c67517 100644
--- a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td
+++ b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td
@@ -345,24 +345,32 @@ def Mesh_GetShardingOp : Mesh_Op<"get_sharding", [Pure]> {
   }];
 }
 
-def Mesh_ShardShapeOp : Mesh_Op<"shard_shape", [Pure]> {
-  let summary = "Get the shard shape of a given process/device.";
+def Mesh_ShardShapeOp : Mesh_Op<"shard_shape", [
+    Pure, AttrSizedOperandSegments,
+    DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>
+  ]> {
+  let summary = "Get the shard shape for a given process/device.";
   let description = [{
-    The device/process id is a linearized id of the device/process in the mesh.
+    The device/process id is a multi-index of the device/process in the mesh.
     This operation might be used during spmdization when the shard shape depends
     on (non-constant) values used in `mesh.sharding`.
   }];
   let arguments = (ins
-    DenseI64ArrayAttr:$shape,
+    DenseI64ArrayAttr:$dims,
+    Variadic<Index>:$dims_dynamic,
     Mesh_Sharding:$sharding,
-    Index:$device
+    DenseI64ArrayAttr:$device,
+    Variadic<Index>:$device_dynamic
   );
   let results = (outs Variadic<Index>:$result);
   let assemblyFormat = [{
-      custom<DimensionList>($shape) $sharding $device attr-dict `:` type($result)
+      `dims` `=` custom<DynamicIndexList>($dims_dynamic, $dims)
+      `sharding` `=` $sharding
+      `device` `=` custom<DynamicIndexList>($device_dynamic, $device)
+      attr-dict `:` type(results)
   }];
   let builders = [
-    OpBuilder<(ins "ArrayRef<int64_t>":$shape, "Value":$sharding, "Value":$device)>
+    OpBuilder<(ins "ArrayRef<int64_t>":$dims, "ArrayRef<Value>":$dims_dyn, "Value":$sharding, "ValueRange":$device)>
   ];
 }
 
diff --git a/mlir/lib/Conversion/MeshToMPI/MeshToMPI.cpp b/mlir/lib/Conversion/MeshToMPI/MeshToMPI.cpp
index ee05276eb8366..3f95b803bc2ed 100644
--- a/mlir/lib/Conversion/MeshToMPI/MeshToMPI.cpp
+++ b/mlir/lib/Conversion/MeshToMPI/MeshToMPI.cpp
@@ -380,23 +380,159 @@ struct ConvertNeighborsLinearIndicesOp
         [&](OpBuilder &builder, Location loc) {
           SmallVector<Value> tmp = mIdx;
           tmp[axes[0]] =
-              rewriter.create<arith::AddIOp>(op.getLoc(), orgIdx, one)
-                  .getResult();
+              rewriter.create<arith::AddIOp>(op.getLoc(), orgIdx, one);
           builder.create<scf::YieldOp>(
               loc, multiToLinearIndex(loc, rewriter, tmp, dims));
         });
     rewriter.replaceOp(op, ValueRange{down.getResult(0), up.getResult(0)});
-    return mlir::success();
+    return success();
   }
 };
 
-struct ConvertUpdateHaloOp
-    : public mlir::OpRewritePattern<mlir::mesh::UpdateHaloOp> {
-  using OpRewritePattern::OpRewritePattern;
+struct ConvertShardShapeOp : public OpConversionPattern<ShardShapeOp> {
+  using OpConversionPattern::OpConversionPattern;
 
-  mlir::LogicalResult
-  matchAndRewrite(mlir::mesh::UpdateHaloOp op,
-                  mlir::PatternRewriter &rewriter) const override {
+  LogicalResult
+  matchAndRewrite(ShardShapeOp op, OneToNOpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    auto sharding = op.getSharding().getDefiningOp<ShardingOp>();
+    if (!sharding) {
+      return op->emitError()
+             << "Expected SharingOp as defining op for sharding"
+             << " but found " << adaptor.getSharding()[0].getDefiningOp();
+    }
+
+    // Compute the sharded shape by applying the sharding to the input shape.
+    // Without shardedDimsOffsets in the sharding, the shard shape is computed
+    // by dividing the dimension size by the number of shards in that dimension
+    // (which is given by the size of the mesh axes provided in split-axes).
+    // Odd elements get distributed to trailing shards.
+    // If a shardedDimsOffsets is provided, the shard shape is computed by
+    // subtracting the offset of the current shard from the offset of the next
+    // shard.
+
+    Location loc = op.getLoc();
+    Type index = rewriter.getIndexType();
+
+    // This is a 1:N conversion because the sharding op is a 1:3 conversion.
+    // The operands in the adaptor are a vector<ValueRange>. For dims and device
+    // we have a 1:1 conversion.
+    // For simpler access fill a vector with the dynamic dims.
+    SmallVector<Value> dynDims, dynDevice;
+    for (auto dim : adaptor.getDimsDynamic()) {
+      // type conversion should be 1:1 for ints
+      assert(dim.size() == 1);
+      dynDims.emplace_back(dim[0]);
+    }
+    // same for device
+    for (auto device : adaptor.getDeviceDynamic()) {
+      assert(device.size() == 1);
+      dynDevice.emplace_back(device[0]);
+    }
+
+    // To keep the code simple, convert dims/device to values when they are
+    // attributes. Count on canonicalization to fold static values.
+    auto shape = getMixedAsValues(rewriter, loc, op.getDims(), dynDims, index);
+    auto multiIdx =
+        getMixedAsValues(rewriter, loc, adaptor.getDevice(), dynDevice, index);
+
+    // Get the MeshOp, the mesh shape is needed to compute the sharded shape.
+    SymbolTableCollection symbolTableCollection;
+    auto meshOp = getMesh(sharding, symbolTableCollection);
+    // For now we only support static mesh shapes
+    if (ShapedType::isDynamicShape(meshOp.getShape()))
+      return failure();
+
+    auto splitAxes = sharding.getSplitAxes().getAxes();
+    // shardedDimsOffsets are optional and might be Values (not attributes).
+    // Also, the shardId might be dynamic which means the position in the
+    // shardedDimsOffsets is not statically known. Create a tensor of the
+    // shardedDimsOffsets and later extract the offsets for computing the
+    // local shard-size.
+    Value shardedDimsOffs;
+    {
+      auto tmp = getMixedAsValues(
+          rewriter, loc, sharding.getStaticShardedDimsOffsets(),
+          sharding.getDynamicShardedDimsOffsets(), index);
+      if (!tmp.empty())
+        shardedDimsOffs = rewriter.create<tensor::FromElementsOp>(
+            loc, RankedTensorType::get({(int64_t)tmp.size()}, index), tmp);
+    }
+
+    // With static mesh shape the sizes of the split axes are known.
+    // Hence the start/pos for each split axis in shardedDimsOffsets can be
+    // computed statically.
+    int64_t pos = 0;
+    SmallVector<Value> shardShape;
+    Value zero =
+        rewriter.create<arith::ConstantOp>(loc, rewriter.getZeroAttr(index));
+    Value one =
+        rewriter.create<arith::ConstantOp>(loc, rewriter.getOneAttr(index));
+
+    // Iterate over the dimensions of the tensor shape, get their split Axes,
+    // and compute the sharded shape.
+    for (auto [i, dim] : llvm::enumerate(shape)) {
+      // Trailing dimensions might not be annotated.
+      if (i < splitAxes.size() && !splitAxes[i].empty()) {
+        auto axes = splitAxes[i];
+        // The current dimension might not be sharded.
+        // Create a value from the static position in shardedDimsOffsets.
+        Value posVal =
+            rewriter.create<arith::ConstantOp>(loc, rewriter.getIndexAttr(pos));
+        // Get the index of the local shard in the mesh axis.
+        Value idx = multiIdx[axes[0]];
+        auto _numShards =
+            collectiveProcessGroupSize(axes.asArrayRef(), meshOp.getShape());
+        if (shardedDimsOffs) {
+          // If sharded dims offsets are provided, use them to compute the
+          // sharded shape.
+          if (axes.size() > 1) {
+            return op->emitError() << "Only single axis sharding is "
+                                   << "supported for each dimension.";
+          }
+          idx = rewriter.create<arith::AddIOp>(loc, posVal, idx);
+          // Compute size = shardedDimsOffs[idx+1] - shardedDimsOffs[idx].
+          Value off =
+              rewriter.create<tensor::ExtractOp>(loc, shardedDimsOffs, idx);
+          idx = rewriter.create<arith::AddIOp>(loc, idx, one);
+          Value nextOff =
+              rewriter.create<tensor::ExtractOp>(loc, shardedDimsOffs, idx);
+          Value sz = rewriter.create<arith::SubIOp>(loc, nextOff, off);
+          shardShape.emplace_back(sz);
+        } else {
+          auto numShards = rewriter.create<arith::ConstantOp>(
+              loc, rewriter.getIndexAttr(_numShards));
+          // Compute shard dim size by distributing odd elements to trailing
+          // shards:
+          // sz = dim / numShards
+          //      + (idx >= (numShards - (dim % numShards)) ? 1 : 0)
+          Value sz = rewriter.create<arith::DivSIOp>(loc, dim, numShards);
+          Value sz1 = rewriter.create<arith::RemSIOp>(loc, dim, numShards);
+          sz1 = rewriter.create<arith::SubIOp>(loc, numShards, sz1);
+          auto cond = rewriter.create<arith::CmpIOp>(
+              loc, arith::CmpIPredicate::sge, idx, sz1);
+          Value odd = rewriter.create<arith::SelectOp>(loc, cond, one, zero);
+          sz = rewriter.create<arith::AddIOp>(loc, sz, odd);
+          shardShape.emplace_back(sz);
+        }
+        pos += _numShards + 1; // add one for the total size.
+      } // else no sharding if split axis is empty or no split axis
+      // If no size was added -> no sharding in this dimension.
+      if (shardShape.size() <= i)
+        shardShape.emplace_back(dim);
+    }
+    assert(shardShape.size() == shape.size());
+    rewriter.replaceOp(op, shardShape);
+    return success();
+  }
+};
+
+struct ConvertUpdateHaloOp : public OpConversionPattern<UpdateHaloOp> {
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(UpdateHaloOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
 
     // The input/output memref is assumed to be in C memory order.
     // Halos are exchanged as 2 blocks per dimension (one for each side: down
diff --git a/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp b/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
index 304ede195c762..3e9f86fde64f3 100644
--- a/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
+++ b/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
@@ -831,12 +831,19 @@ MeshSharding MeshSharding::get(::mlir::FlatSymbolRefAttr mesh_,
 // mesh.shard_shape
 //===----------------------------------------------------------------------===//
 
+void ShardShapeOp::getAsmResultNames(
+    function_ref<void(Value, StringRef)> setNameFn) {
+  setNameFn(getResult()[0], "shard_shape");
+}
+
 void ShardShapeOp::build(::mlir::OpBuilder &odsBuilder,
                          ::mlir::OperationState &odsState,
-                         ::llvm::ArrayRef<int64_t> shape,
-                         ::mlir::Value sharding, ::mlir::Value device) {
-  SmallVector<mlir::Type> resType(shape.size(), odsBuilder.getIndexType());
-  build(odsBuilder, odsState, resType, shape, sharding, device);
+                         ::llvm::ArrayRef<int64_t> dims,
+                         ArrayRef<Value> dims_dyn, ::mlir::Value sharding,
+                         ::mlir::ValueRange device) {
+  SmallVector<mlir::Type> resType(dims.size(), odsBuilder.getIndexType());
+  build(odsBuilder, odsState, resType, dims, dims_dyn, sharding,
+        SmallVector<int64_t>(device.size(), ShapedType::kDynamic), device);
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Tensor/Extensions/MeshShardingExtensions.cpp b/mlir/lib/Dialect/Tensor/Extensions/MeshShardingExtensions.cpp
index b2acbf20b3fb9..b3d69eb5e1a23 100644
--- a/mlir/lib/Dialect/Tensor/Extensions/MeshShardingExtensions.cpp
+++ b/mlir/lib/Dialect/Tensor/Extensions/MeshShardingExtensions.cpp
@@ -50,10 +50,10 @@ struct CreatorOpShardingInterface
                         IRMapping &spmdizationMap,
                         SymbolTableCollection &symbolTable,
                         OpBuilder &builder) const {
-    auto shardType = cast<ShapedType>(mesh::shardType(
-        op->getResult(0).getType(),
-        mesh::getMesh(op, resultShardings[0].getMeshAttr(), symbolTable),
-        resultShardings[0]));
+    auto mesh =
+        mesh::getMesh(op, resultShardings[0].getMeshAttr(), symbolTable);
+    auto shardType = cast<ShapedType>(
+        mesh::shardType(op->getResult(0).getType(), mesh, resultShardings[0]));
     Operation *newOp = nullptr;
     // if the sharding introduces a new dynamic dimension, we take it from
     // the dynamic sharding info. For now bail out if it's not
@@ -66,18 +66,19 @@ struct CreatorOpShardingInterface
       assert(oldType.getRank() == shardType.getRank());
       int currOldOprndNum = -1;
       mesh::ShardShapeOp shapeForDevice;
-      Value device;
+      ValueRange device;
       Operation *newSharding = nullptr;
       for (auto i = 0; i < oldType.getRank(); ++i) {
         if (!oldType.isDynamicDim(i) && shardType.isDynamicDim(i)) {
           if (!newSharding) {
             newSharding =
                 builder.create<ShardingOp>(op->getLoc(), resultShardings[0]);
-            device = builder.create<mesh::ProcessLinearIndexOp>(
-                op->getLoc(), resultShardings[0].getMesh());
+            device =
+                builder.create<mesh::ProcessMultiIndexOp>(op->getLoc(), mesh)
+                    .getResults();
             shapeForDevice = builder.create<mesh::ShardShapeOp>(
-                op->getLoc(), oldType.getShape(), newSharding->getResult(0),
-                device);
+                op->getLoc(), oldType.getShape(), spmdizedOperands,
+                newSharding->getResult(0), device);
           }
           newOperands.emplace_back(shapeForDevice.getResult()[i]);
         } else if (oldType.isDynamicDim(i)) {
diff --git a/mlir/test/Conversion/MeshToMPI/convert-shardshape-to-mpi.mlir b/mlir/test/Conversion/MeshToMPI/convert-shardshape-to-mpi.mlir
new file mode 100644
index 0000000000000..1cc848333ced2
--- /dev/null
+++ b/mlir/test/Conversion/MeshToMPI/convert-shardshape-to-mpi.mlir
@@ -0,0 +1,75 @@
+// RUN: mlir-opt %s --convert-mesh-to-mpi -canonicalize | FileCheck %s
+
+module attributes { mpi.dlti = #dlti.map<"MPI:comm_world_rank" = 24> } {
+
+  // CHECK: mesh.mesh @mesh0
+  mesh.mesh @mesh0(shape = 3x4x5)
+  
+  // Notice: comm_world_rank/linear index 24 is multiindex [1, 0, 4] in @mesh0
+
+  // all shards are equal
+  // CHECK-LABEL: func.func @shard_shape_equal() -> (index, index, index) {
+  func.func @shard_shape_equal() -> (index, index, index) {
+    %sharding = mesh.sharding @mesh0 split_axes = [[0], [1], [2]] : !mesh.sharding
+    %0:3 = mesh.process_multi_index on @mesh0 : index, index, index
+    %c9 = arith.constant 9 : index
+    %c12 = arith.constant 12 : index
+    // CHECK: [[vc3:%.*]] = arith.constant 3 : index
+    %1:3 = mesh.shard_shape dims = [%c9, %c12, 15] sharding = %sharding device = [%0#0, %0#1, %0#2] : index, index, index
+    // CHECK: return [[vc3]], [[vc3]], [[vc3]] : index, index, index
+    return %1#0, %1#1, %1#2 : index, index, index
+  }
+
+  // last shard in last dim gets an extra element
+  // CHECK-LABEL: func.func @shard_shape_odd_1() -> (index, index, index) {
+  func.func @shard_shape_odd_1() -> (index, index, index) {
+    %sharding = mesh.sharding @mesh0 split_axes = [[0], [1], [2]] : !mesh.sharding
+    %0:3 = mesh.process_multi_index on @mesh0 : index, index, index
+    %c9 = arith.constant 9 : index
+    %c12 = arith.constant 12 : index
+    // CHECK-DAG: [[vc3:%.*]] = arith.constant 3 : index
+    // CHECK-DAG: [[vc4:%.*]] = arith.constant 4 : index
+    %1:3 = mesh.shard_shape dims = [%c9, %c12, 16] sharding = %sharding device = [%0#0, %0#1, %0#2] : index, index, index
+    // CHECK: return [[vc3]], [[vc3]], [[vc4]] : index, index, index
+    return %1#0, %1#1, %1#2 : index, index, index
+  }
+
+  // all except first shard in second dim get an extra element
+  // CHECK-LABEL: func.func @shard_shape_odd_2() -> (index, index, index) {
+  func.func @shard_shape_odd_2() -> (index, index, index) {
+    %sharding = mesh.sharding @mesh0 split_axes = [[0], [1], [2]] : !mesh.sharding
+    %0:3 = mesh.process_multi_index on @mesh0 : index, index, index
+    %c9 = arith.constant 9 : index
+    // CHECK: [[vc3:%.*]] = arith.constant 3 : index
+    %1:3 = mesh.shard_shape dims = [%c9, 15, 15] sharding = %sharding device = [%0#0, %0#1, %0#2] : index, index, index
+    // CHECK: return [[vc3]], [[vc3]], [[vc3]] : index, index, index
+    return %1#0, %1#1, %1#2 : index, index, index
+  }
+
+  // all except first shard in first dim get an extra element
+  // CHECK-LABEL: func.func @shard_shape_odd_3() -> (index, index, index) {
+  func.func @shard_shape_odd_3() -> (index, index, index) {
+    %sharding = mesh.sharding @mesh0 split_axes = [[0], [1], [2]] : !mesh.sharding
+    %0:3 = mesh.process_multi_index on @mesh0 : index, index, index
+    // CHECK-DAG: [[vc3:%.*]] = arith.constant 3 : index
+    // CHECK-DAG: [[vc4:%.*]] = arith.constant 4 : index
+    %1:3 = mesh.shard_shape dims = [11, 12, 15] sharding = %sharding device = [%0#0, %0#1, %0#2] : index, index, index
+    // CHECK: return [[vc4]], [[vc3]], [[vc3]] : index, index, index
+    return %1#0, %1#1, %1#2 : index, index, index
+  }
+
+  // extract from sharded_dims_offsets
+  // CHECK-LABEL: func.func @shard_shape_sharded_dims_offs() -> (index, index, index) {
+  func.func @shard_shape_sharded_dims_offs() -> (index, index, index) {
+    %sharding = mesh.sharding @mesh0 split_axes = [[0], [1], [2]]
+        sharded_dims_offsets = [0, 1, 4, 9, 0, 2, 6, 12, 12, 0, 3, 6, 9, 12, 15]: !mesh.sharding
+    %0:3 = mesh.process_multi_index on @mesh0 : index, index, index
+    %c9 = arith.constant 9 : index
+    %c12 = arith.constant 12 : index
+    // CHECK: [[vc3:%.*]] = arith.constant 3 : index
+    // CHECK: [[vc2:%.*]] = arith.constant 2 : index
+    %1:3 = mesh.shard_shape dims = [%c9, %c12, 15] sharding = %sharding device = [%0#0, %0#1, %0#2] : index, index, index
+    // CHECK: return [[vc3]], [[vc2]], [[vc3]] : index, index, index
+    return %1#0, %1#1, %1#2 : index, index, index
+  }
+}
\ No newline at end of file
diff --git a/mlir/test/Dialect/Mesh/ops.mlir b/mlir/test/Dialect/Mesh/ops.mlir
index 43a75bf3d8040..3d133f2255772 100644
--- a/mlir/test/Dialect/Mesh/ops.mlir
+++ b/mlir/test/Dialect/Mesh/ops.mlir
@@ -157,10 +157,12 @@ func.func @mesh_shard_shape() {
   %c3 = arith.constant 3 : index
   // CHECK-NEXT: %[[S:.*]] = mesh.sharding @mesh0 split_axes = {{\[\[}}]] : !mesh.sharding
   %s = mesh.sharding @mesh0 split_axes = [[]] : !mesh.sharding
-  // CHECK-NEXT: mesh.shard_shape 8x? %[[S]] %[[C3]] : index, index
-  %shp:2 = mesh.shard_shape 8x? %s %c3 : index, index
-  // CHECK-NEXT: mesh.shard_shape 8x4 %[[S]] %[[C3]] : index, index
-  %shp1:2 = mesh.shard_shape 8x4 %s %c3 : index, index
+  // CHECK-NEXT: mesh.shard_shape dims = [8, %[[C3]]
+  // CHECK-SAME: ] sharding = %[[S]] device = [%[[C3]]
+  // CHECK-SAME: ] : index, index
+  %shp:2 = mesh.shard_shape dims = [8, %c3] sharding = %s device = [%c3] : index, index
+  // CHECK-NEXT: mesh.shard_shape dims = [8, 4] sharding = %[[S]] device = [3] : index, index
+  %shp1:2 = mesh.shard_shape dims = [8, 4] sharding = %s device = [3] : index, index
   return
 }
 
diff --git a/mlir/test/Dialect/Tensor/mesh-spmdization.mlir b/mlir/test/Dialect/Tensor/mesh-spmdization.mlir
index 5443eea83aa2d..01cf5972177f4 100644
--- a/mlir/test/Dialect/Tensor/mesh-spmdization.mlir
+++ b/mlir/test/Dialect/Tensor/mesh-spmdization.mlir
@@ -10,8 +10,9 @@ func.func @tensor_empty_static_sharded_dims_offsets() -> () {
   %sharding = mesh.sharding @mesh_1d_4 split_axes = [[0]] sharded_dims_offsets = [0, 1, 4, 7, 8] : !mesh.sharding
   %sharded= mesh.shard %b to %sharding : tensor<8x16xf32>
   // CHECK:  %[[sharding:.*]] = mesh.sharding @mesh_1d_4 split_axes = {{\[\[}}0]] sharded_dims_offsets = [0, 1, 4, 7, 8] : !mesh.sharding
-  // CHECK:  %[[proc_linear_idx:.*]] = mesh.process_linear_index on @mesh_1d_4 : index
-  // CHECK:  %[[V0:.*]]:2 = mesh.shard_shape 8x16 %[[sharding]] %[[proc_linear_idx]] : index, index
+  // CHECK:  %[[proc_multi_idx:.*]] = mesh.process_multi_index on @mesh_1d_4 : index
+  // CHECK:  %[[V0:.*]]:2 = mesh.shard_shape dims = [8, 16] sharding = %[[sharding]] device = [%[[proc_multi_idx]]
+  // CHECK-SAME: ] : index, index
   // CHECK:  tensor.empty(%[[V0]]#0) : tensor<?x16xf32>
 
   return
@@ -24,8 +25,10 @@ func.func @tensor_empty_dynamic_sharded_dims_offsets(%arg0 : index) -> () {
   %sharding = mesh.sharding @mesh_1d_4 split_axes = [[0]] sharded_dims_offsets = [0, 1, 4, 7, 8] : !mesh.sharding
   %sharded= mesh.shard %b to %sharding : tensor<8x?xf32>
   // CHECK:  %[[sharding:.*]] = mesh.sharding @mesh_1d_4 split_axes = {{\[\[}}0]] sharded_dims_offsets = [0, 1, 4, 7, 8] : !mesh.sharding
-  // CHECK:  %[[proc_linear_idx:.*]] = mesh.process_linear_index on @mesh_1d_4 : index
-  // CHECK:  %[[V0:.*]]:2 = mesh.shard_shape 8x? %[[sharding]] %[[proc_linear_idx]] : index, index
+  // CHECK:  %[[proc_multi_idx:.*]] = mesh.process_multi_index on @mesh_1d_4 : index
+  // CHECK:  %[[V0:.*]]:2 = mesh.shard_shape dims = [8, %[[A0]]
+  // CHECK-SAME: ] sharding = %[[sharding]] device = [%[[proc_multi_idx]]
+  // CHECK-SAME: ] : index, index
   // CHECK:  tensor.empty(%[[V0]]#0, %[[A0]]) : tensor<?x?xf32>
 
   return

>From 3428e5064fd30e30598f74aeb9cb5b6b7c4aec45 Mon Sep 17 00:00:00 2001
From: "Schlimbach, Frank" <frank.schlimbach at intel.com>
Date: Thu, 27 Feb 2025 17:55:27 +0100
Subject: [PATCH 4/6] using DLTI instead of global symbol for static rank in
 comm_world

---
 mlir/include/mlir/Conversion/Passes.td        |  8 ++--
 mlir/lib/Conversion/MeshToMPI/CMakeLists.txt  |  1 +
 mlir/lib/Conversion/MeshToMPI/MeshToMPI.cpp   | 39 ++++++++++---------
 .../MeshToMPI/convert-mesh-to-mpi.mlir        | 33 ++++++++--------
 4 files changed, 42 insertions(+), 39 deletions(-)

diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td
index cccdf0a8518bf..6074e0e8d822c 100644
--- a/mlir/include/mlir/Conversion/Passes.td
+++ b/mlir/include/mlir/Conversion/Passes.td
@@ -881,10 +881,10 @@ def ConvertMeshToMPIPass : Pass<"convert-mesh-to-mpi"> {
   let description = [{
     This pass converts communication operations from the Mesh dialect to the
     MPI dialect.
-    If it finds a global named "static_mpi_rank" it will use that splat value
-    instead of calling MPI_Comm_rank. This allows optimizations like constant
-    shape propagation and fusion because shard/partition sizes depend on the
-    rank.
+    If it finds the DLTI attribute "MPI:comm_world_rank" on the module it will
+    use that integer value instead of calling MPI_Comm_rank. This allows
+    optimizations like constant shape propagation and fusion because
+    shard/partition sizes depend on the rank.
   }];
   let dependentDialects = [
     "memref::MemRefDialect",
diff --git a/mlir/lib/Conversion/MeshToMPI/CMakeLists.txt b/mlir/lib/Conversion/MeshToMPI/CMakeLists.txt
index 95815a683f6d6..15560aa61e145 100644
--- a/mlir/lib/Conversion/MeshToMPI/CMakeLists.txt
+++ b/mlir/lib/Conversion/MeshToMPI/CMakeLists.txt
@@ -11,6 +11,7 @@ add_mlir_conversion_library(MLIRMeshToMPI
   Core
 
   LINK_LIBS PUBLIC
+  MLIRDLTIDialect
   MLIRFuncDialect
   MLIRIR
   MLIRLinalgTransforms
diff --git a/mlir/lib/Conversion/MeshToMPI/MeshToMPI.cpp b/mlir/lib/Conversion/MeshToMPI/MeshToMPI.cpp
index 3f95b803bc2ed..ed8fe80b18efe 100644
--- a/mlir/lib/Conversion/MeshToMPI/MeshToMPI.cpp
+++ b/mlir/lib/Conversion/MeshToMPI/MeshToMPI.cpp
@@ -284,32 +284,33 @@ struct ConvertProcessMultiIndexOp
     }
 
     rewriter.replaceOp(op, mIdx);
-    return mlir::success();
+    return success();
   }
 };
 
-struct ConvertProcessLinearIndexOp
-    : public mlir::OpRewritePattern<mlir::mesh::ProcessLinearIndexOp> {
-  using OpRewritePattern::OpRewritePattern;
+class ConvertProcessLinearIndexOp
+    : public OpConversionPattern<ProcessLinearIndexOp> {
+  int64_t worldRank; // rank in MPI_COMM_WORLD if available, else < 0
 
-  mlir::LogicalResult
-  matchAndRewrite(mlir::mesh::ProcessLinearIndexOp op,
-                  mlir::PatternRewriter &rewriter) const override {
+public:
+  using OpConversionPattern::OpConversionPattern;
 
-    // Finds a global named "static_mpi_rank" it will use that splat value.
-    // Otherwise it defaults to mpi.comm_rank.
+  // Constructor accepting worldRank
+  ConvertProcessLinearIndexOp(const TypeConverter &typeConverter,
+                              MLIRContext *context, int64_t worldRank_ = -1)
+      : OpConversionPattern(typeConverter, context), worldRank(worldRank_) {}
 
-    auto loc = op.getLoc();
-    auto rankOpName = StringAttr::get(op->getContext(), "static_mpi_rank");
-    if (auto globalOp = SymbolTable::lookupNearestSymbolFrom<memref::GlobalOp>(
-            op, rankOpName)) {
-      if (auto initTnsr = globalOp.getInitialValueAttr()) {
-        auto val = cast<DenseElementsAttr>(initTnsr).getSplatValue<int64_t>();
-        rewriter.replaceOp(op,
-                           rewriter.create<arith::ConstantIndexOp>(loc, val));
-        return mlir::success();
-      }
+  LogicalResult
+  matchAndRewrite(ProcessLinearIndexOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+
+    Location loc = op.getLoc();
+    if (worldRank >= 0) { // if rank in MPI_COMM_WORLD is known -> use it
+      rewriter.replaceOpWithNewOp<arith::ConstantIndexOp>(op, worldRank);
+      return success();
     }
+
+    // Otherwise call create mpi::CommRankOp
     auto rank =
         rewriter
             .create<mpi::CommRankOp>(
diff --git a/mlir/test/Conversion/MeshToMPI/convert-mesh-to-mpi.mlir b/mlir/test/Conversion/MeshToMPI/convert-mesh-to-mpi.mlir
index 90bd80472c2b9..4e60c6f0d4e44 100644
--- a/mlir/test/Conversion/MeshToMPI/convert-mesh-to-mpi.mlir
+++ b/mlir/test/Conversion/MeshToMPI/convert-mesh-to-mpi.mlir
@@ -60,23 +60,24 @@ func.func @neighbors_dim2(%arg0 : tensor<120x120x120xi8>) -> (index, index) {
 
 // -----
 // CHECK: mesh.mesh @mesh0
-mesh.mesh @mesh0(shape = 3x4x5)
-memref.global constant @static_mpi_rank : memref<index> = dense<24>
-func.func @process_multi_index() -> (index, index, index) {
-  // CHECK-DAG: %[[c4:.*]] = arith.constant 4 : index
-  // CHECK-DAG: %[[c0:.*]] = arith.constant 0 : index
-  // CHECK-DAG: %[[c1:.*]] = arith.constant 1 : index
-  %0:3 = mesh.process_multi_index on @mesh0 axes = [] : index, index, index
-  // CHECK: return %[[c1]], %[[c0]], %[[c4]] : index, index, index
-  return %0#0, %0#1, %0#2 : index, index, index
-}
+module attributes { mpi.dlti = #dlti.map<"MPI:comm_world_rank" = 24> } {
+  mesh.mesh @mesh0(shape = 3x4x5)
+  func.func @process_multi_index() -> (index, index, index) {
+    // CHECK-DAG: %[[c4:.*]] = arith.constant 4 : index
+    // CHECK-DAG: %[[c0:.*]] = arith.constant 0 : index
+    // CHECK-DAG: %[[c1:.*]] = arith.constant 1 : index
+    %0:3 = mesh.process_multi_index on @mesh0 axes = [] : index, index, index
+    // CHECK: return %[[c1]], %[[c0]], %[[c4]] : index, index, index
+    return %0#0, %0#1, %0#2 : index, index, index
+  }
 
-// CHECK-LABEL: func @process_linear_index
-func.func @process_linear_index() -> index {
-  // CHECK: %[[c24:.*]] = arith.constant 24 : index
-  %0 = mesh.process_linear_index on @mesh0 : index
-  // CHECK: return %[[c24]] : index
-  return %0 : index
+  // CHECK-LABEL: func @process_linear_index
+  func.func @process_linear_index() -> index {
+    // CHECK: %[[c24:.*]] = arith.constant 24 : index
+    %0 = mesh.process_linear_index on @mesh0 : index
+    // CHECK: return %[[c24]] : index
+    return %0 : index
+  }
 }
 
 // -----

>From ce882e60d9933008fcc61b906eab25aeaa05b529 Mon Sep 17 00:00:00 2001
From: "Schlimbach, Frank" <frank.schlimbach at intel.com>
Date: Thu, 27 Feb 2025 17:57:23 +0100
Subject: [PATCH 5/6] cleanup, aligning to conventions

---
 mlir/lib/Conversion/MeshToMPI/MeshToMPI.cpp | 86 ++++++++++-----------
 1 file changed, 39 insertions(+), 47 deletions(-)

diff --git a/mlir/lib/Conversion/MeshToMPI/MeshToMPI.cpp b/mlir/lib/Conversion/MeshToMPI/MeshToMPI.cpp
index ed8fe80b18efe..49860b96d3685 100644
--- a/mlir/lib/Conversion/MeshToMPI/MeshToMPI.cpp
+++ b/mlir/lib/Conversion/MeshToMPI/MeshToMPI.cpp
@@ -76,23 +76,22 @@ static SmallVector<Value> linearToMultiIndex(Location loc, OpBuilder b,
 
   for (int i = n - 1; i >= 0; --i) {
     multiIndex[i] = b.create<arith::RemSIOp>(loc, linearIndex, dimensions[i]);
-    if (i > 0) {
+    if (i > 0)
       linearIndex = b.create<arith::DivSIOp>(loc, linearIndex, dimensions[i]);
-    }
   }
 
   return multiIndex;
 }
 
-// Create operations converting a multi-dimensional index to a linear index
+/// Create operations converting a multi-dimensional index to a linear index.
 Value multiToLinearIndex(Location loc, OpBuilder b, ValueRange multiIndex,
                          ValueRange dimensions) {
 
-  auto linearIndex = b.create<arith::ConstantIndexOp>(loc, 0).getResult();
-  auto stride = b.create<arith::ConstantIndexOp>(loc, 1).getResult();
+  Value linearIndex = b.create<arith::ConstantIndexOp>(loc, 0);
+  Value stride = b.create<arith::ConstantIndexOp>(loc, 1);
 
   for (int i = multiIndex.size() - 1; i >= 0; --i) {
-    auto off = b.create<arith::MulIOp>(loc, multiIndex[i], stride);
+    Value off = b.create<arith::MulIOp>(loc, multiIndex[i], stride);
     linearIndex = b.create<arith::AddIOp>(loc, linearIndex, off);
     stride = b.create<arith::MulIOp>(loc, stride, dimensions[i]);
   }
@@ -247,34 +246,32 @@ struct ConvertShardingOp : public OpConversionPattern<ShardingOp> {
 };
 
 struct ConvertProcessMultiIndexOp
-    : public mlir::OpRewritePattern<mlir::mesh::ProcessMultiIndexOp> {
-  using OpRewritePattern::OpRewritePattern;
+    : public OpConversionPattern<ProcessMultiIndexOp> {
+  using OpConversionPattern::OpConversionPattern;
 
-  mlir::LogicalResult
-  matchAndRewrite(mlir::mesh::ProcessMultiIndexOp op,
-                  mlir::PatternRewriter &rewriter) const override {
+  LogicalResult
+  matchAndRewrite(ProcessMultiIndexOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
 
     // Currently converts its linear index to a multi-dimensional index.
 
     SymbolTableCollection symbolTableCollection;
-    auto loc = op.getLoc();
+    Location loc = op.getLoc();
     auto meshOp = getMesh(op, symbolTableCollection);
     // For now we only support static mesh shapes
-    if (ShapedType::isDynamicShape(meshOp.getShape())) {
-      return mlir::failure();
-    }
+    if (ShapedType::isDynamicShape(meshOp.getShape()))
+      return failure();
 
     SmallVector<Value> dims;
     llvm::transform(
         meshOp.getShape(), std::back_inserter(dims), [&](int64_t i) {
           return rewriter.create<arith::ConstantIndexOp>(loc, i).getResult();
         });
-    auto rank =
-        rewriter.create<ProcessLinearIndexOp>(op.getLoc(), meshOp).getResult();
+    Value rank = rewriter.create<ProcessLinearIndexOp>(op.getLoc(), meshOp);
     auto mIdx = linearToMultiIndex(loc, rewriter, rank, dims);
 
     // optionally extract subset of mesh axes
-    auto axes = op.getAxes();
+    auto axes = adaptor.getAxes();
     if (!axes.empty()) {
       SmallVector<Value> subIndex;
       for (auto axis : axes) {
@@ -319,44 +316,43 @@ class ConvertProcessLinearIndexOp
             .getRank();
     rewriter.replaceOpWithNewOp<arith::IndexCastOp>(op, rewriter.getIndexType(),
                                                     rank);
-    return mlir::success();
+    return success();
   }
 };
 
 struct ConvertNeighborsLinearIndicesOp
-    : public mlir::OpRewritePattern<mlir::mesh::NeighborsLinearIndicesOp> {
-  using OpRewritePattern::OpRewritePattern;
+    : public OpConversionPattern<NeighborsLinearIndicesOp> {
+  using OpConversionPattern::OpConversionPattern;
 
-  mlir::LogicalResult
-  matchAndRewrite(mlir::mesh::NeighborsLinearIndicesOp op,
-                  mlir::PatternRewriter &rewriter) const override {
+  LogicalResult
+  matchAndRewrite(NeighborsLinearIndicesOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
 
     // Computes the neighbors indices along a split axis by simply
     // adding/subtracting 1 to the current index in that dimension.
     // Assigns -1 if neighbor is out of bounds.
 
-    auto axes = op.getSplitAxes();
+    auto axes = adaptor.getSplitAxes();
     // For now only single axis sharding is supported
-    if (axes.size() != 1) {
-      return mlir::failure();
-    }
+    if (axes.size() != 1)
+      return failure();
 
-    auto loc = op.getLoc();
+    Location loc = op.getLoc();
     SymbolTableCollection symbolTableCollection;
     auto meshOp = getMesh(op, symbolTableCollection);
-    auto mIdx = op.getDevice();
+    auto mIdx = adaptor.getDevice();
     auto orgIdx = mIdx[axes[0]];
     SmallVector<Value> dims;
     llvm::transform(
         meshOp.getShape(), std::back_inserter(dims), [&](int64_t i) {
           return rewriter.create<arith::ConstantIndexOp>(loc, i).getResult();
         });
-    auto dimSz = dims[axes[0]];
-    auto one = rewriter.create<arith::ConstantIndexOp>(loc, 1).getResult();
-    auto minus1 = rewriter.create<arith::ConstantIndexOp>(loc, -1).getResult();
-    auto atBorder = rewriter.create<arith::CmpIOp>(
+    Value dimSz = dims[axes[0]];
+    Value one = rewriter.create<arith::ConstantIndexOp>(loc, 1);
+    Value minus1 = rewriter.create<arith::ConstantIndexOp>(loc, -1);
+    Value atBorder = rewriter.create<arith::CmpIOp>(
         loc, arith::CmpIPredicate::sle, orgIdx,
-        rewriter.create<arith::ConstantIndexOp>(loc, 0).getResult());
+        rewriter.create<arith::ConstantIndexOp>(loc, 0));
     auto down = rewriter.create<scf::IfOp>(
         loc, atBorder,
         [&](OpBuilder &builder, Location loc) {
@@ -598,11 +594,10 @@ struct ConvertUpdateHaloOp : public OpConversionPattern<UpdateHaloOp> {
     // we need the actual shape to compute offsets and sizes
     for (auto i = 0; i < rank; ++i) {
       auto s = dstShape[i];
-      if (ShapedType::isDynamic(s)) {
+      if (ShapedType::isDynamic(s))
         shape[i] = rewriter.create<memref::DimOp>(loc, array, s).getResult();
-      } else {
+      else
         shape[i] = rewriter.getIndexAttr(s);
-      }
 
       if ((size_t)i < opSplitAxes.size() && !opSplitAxes[i].empty()) {
         ++currHaloDim;
@@ -610,11 +605,9 @@ struct ConvertUpdateHaloOp : public OpConversionPattern<UpdateHaloOp> {
         offsets[i] = haloSizes[currHaloDim * 2];
 
         // prepare shape and offsets of highest dim's halo exchange
-        auto _haloSz =
-            rewriter
-                .create<arith::AddIOp>(loc, toValue(haloSizes[currHaloDim * 2]),
-                                       toValue(haloSizes[currHaloDim * 2 + 1]))
-                .getResult();
+        Value _haloSz = rewriter.create<arith::AddIOp>(
+            loc, toValue(haloSizes[currHaloDim * 2]),
+            toValue(haloSizes[currHaloDim * 2 + 1]));
         // the halo shape of lower dims exlude the halos
         dimSizes[i] =
             rewriter.create<arith::SubIOp>(loc, toValue(shape[i]), _haloSz)
@@ -625,9 +618,9 @@ struct ConvertUpdateHaloOp : public OpConversionPattern<UpdateHaloOp> {
     }
 
     auto tagAttr = rewriter.getI32IntegerAttr(91); // we just pick something
-    auto tag = rewriter.create<::mlir::arith::ConstantOp>(loc, tagAttr);
+    auto tag = rewriter.create<arith::ConstantOp>(loc, tagAttr);
     auto zeroAttr = rewriter.getI32IntegerAttr(0); // for detecting v<0
-    auto zero = rewriter.create<::mlir::arith::ConstantOp>(loc, zeroAttr);
+    auto zero = rewriter.create<arith::ConstantOp>(loc, zeroAttr);
 
     SmallVector<Type> indexResultTypes(meshOp.getShape().size(),
                                        rewriter.getIndexType());
@@ -637,9 +630,8 @@ struct ConvertUpdateHaloOp : public OpConversionPattern<UpdateHaloOp> {
     // traverse all split axes from high to low dim
     for (ssize_t dim = opSplitAxes.size() - 1; dim >= 0; --dim) {
       auto splitAxes = opSplitAxes[dim];
-      if (splitAxes.empty()) {
+      if (splitAxes.empty())
         continue;
-      }
       assert(currHaloDim >= 0 && (size_t)currHaloDim < haloSizes.size() / 2);
       // Get the linearized ids of the neighbors (down and up) for the
       // given split

>From 276e5e620af442b93ac1c58dab0aae880708fa7c Mon Sep 17 00:00:00 2001
From: Frank Schlimbach <frank.schlimbach at intel.com>
Date: Fri, 28 Feb 2025 10:01:03 +0100
Subject: [PATCH 6/6] Apply suggestions from code review

Co-authored-by: Christian Ulmann <christianulmann at gmail.com>
---
 mlir/lib/Conversion/MeshToMPI/MeshToMPI.cpp   | 119 +++++++++---------
 .../MeshToMPI/convert-shardshape-to-mpi.mlir  |   6 +-
 2 files changed, 63 insertions(+), 62 deletions(-)

diff --git a/mlir/lib/Conversion/MeshToMPI/MeshToMPI.cpp b/mlir/lib/Conversion/MeshToMPI/MeshToMPI.cpp
index 49860b96d3685..76b1978e6a025 100644
--- a/mlir/lib/Conversion/MeshToMPI/MeshToMPI.cpp
+++ b/mlir/lib/Conversion/MeshToMPI/MeshToMPI.cpp
@@ -45,7 +45,8 @@ using namespace mlir;
 using namespace mesh;
 
 namespace {
-/// Convert vec of OpFoldResults (ints) into vector of Values.
+/// Converts a vector of OpFoldResults (ints) into vector of Values of the
+/// provided type.
 static SmallVector<Value> getMixedAsValues(OpBuilder b, const Location &loc,
                                            llvm::ArrayRef<int64_t> statics,
                                            ValueRange dynamics,
@@ -55,14 +56,15 @@ static SmallVector<Value> getMixedAsValues(OpBuilder b, const Location &loc,
   Type i64 = b.getI64Type();
   if (!type)
     type = i64;
-  assert(i64 == type || b.getIndexType() == type);
+  assert((i64 == type || b.getIndexType() == type) &&
+         "expected an i64 or an intex type");
   for (auto s : statics) {
-    values.emplace_back(
-        ShapedType::isDynamic(s)
-            ? *(dyn++)
-            : b.create<arith::ConstantOp>(loc, type,
-                                          i64 == type ? b.getI64IntegerAttr(s)
-                                                      : b.getIndexAttr(s)));
+    if (s == ShapedType::kDynamic) {
+      values.emplace_back(*(dyn++));
+    } else {
+      TypedAttr val = type == i64 ? b.getI64IntegerAttr(s) : b.getIndexAttr(s);
+      values.emplace_back(b.create<arith::ConstantOp>(loc, type, val));
+    }
   }
   return values;
 };
@@ -129,9 +131,8 @@ struct ConvertShardingOp : public OpConversionPattern<ShardingOp> {
                   ConversionPatternRewriter &rewriter) const override {
     auto splitAxes = op.getSplitAxes().getAxes();
     int64_t maxNAxes = 0;
-    for (auto axes : splitAxes) {
+    for (auto axes : splitAxes)
       maxNAxes = std::max<int64_t>(maxNAxes, axes.size());
-    }
 
     // To hold the split axes, create empty 2d tensor with shape
     // {splitAxes.size(), max-size-of-split-groups}.
@@ -139,23 +140,24 @@ struct ConvertShardingOp : public OpConversionPattern<ShardingOp> {
     Location loc = op.getLoc();
     auto i16 = rewriter.getI16Type();
     auto i64 = rewriter.getI64Type();
-    int64_t shape[] = {static_cast<int64_t>(splitAxes.size()), maxNAxes};
+    std::array<int64_t, 2> shape = {static_cast<int64_t>(splitAxes.size()),
+                                    maxNAxes};
     Value resSplitAxes = rewriter.create<tensor::EmptyOp>(loc, shape, i16);
-    auto attr = IntegerAttr::get(i16, 0xffff);
+    auto attr = IntegerAttr::get(i16, -1);
     Value fillValue = rewriter.create<arith::ConstantOp>(loc, i16, attr);
     resSplitAxes = rewriter.create<linalg::FillOp>(loc, fillValue, resSplitAxes)
                        .getResult(0);
 
     // explicitly write values into tensor row by row
-    int64_t strides[] = {1, 1};
+    std::array<int64_t, 2> strides = {1, 1};
     int64_t nSplits = 0;
     ValueRange empty = {};
     for (auto [i, axes] : llvm::enumerate(splitAxes)) {
       int64_t size = axes.size();
       if (size > 0)
         ++nSplits;
-      int64_t offs[] = {(int64_t)i, 0};
-      int64_t sizes[] = {1, size};
+      std::array<int64_t, 2> offs = {(int64_t)i, 0};
+      std::array<int64_t, 2> sizes = {1, size};
       auto tensorType = RankedTensorType::get({size}, i16);
       auto attrs = DenseIntElementsAttr::get(tensorType, axes.asArrayRef());
       auto vals = rewriter.create<arith::ConstantOp>(loc, tensorType, attrs);
@@ -165,7 +167,7 @@ struct ConvertShardingOp : public OpConversionPattern<ShardingOp> {
 
     // To hold halos sizes, create 2d Tensor with shape {nSplits, 2}.
     // Store the halo sizes in the tensor.
-    auto haloSizes =
+    SmallVector<Value> haloSizes =
         getMixedAsValues(rewriter, loc, adaptor.getStaticHaloSizes(),
                          adaptor.getDynamicHaloSizes());
     auto type = RankedTensorType::get({nSplits, 2}, i64);
@@ -190,7 +192,7 @@ struct ConvertShardingOp : public OpConversionPattern<ShardingOp> {
     } else {
       SymbolTableCollection symbolTableCollection;
       auto meshOp = getMesh(op, symbolTableCollection);
-      auto maxSplitSize = 0;
+      int64_t maxSplitSize = 0;
       for (auto axes : splitAxes) {
         int64_t splitSize =
             collectiveProcessGroupSize(axes.asArrayRef(), meshOp.getShape());
@@ -206,7 +208,7 @@ struct ConvertShardingOp : public OpConversionPattern<ShardingOp> {
           loc, i64, rewriter.getI64IntegerAttr(ShapedType::kDynamic));
       resOffsets =
           rewriter.create<linalg::FillOp>(loc, zero, resOffsets).getResult(0);
-      auto offsets =
+      SmallVector<Value> offsets =
           getMixedAsValues(rewriter, loc, adaptor.getStaticShardedDimsOffsets(),
                            adaptor.getDynamicShardedDimsOffsets());
       int64_t curr = 0;
@@ -217,8 +219,8 @@ struct ConvertShardingOp : public OpConversionPattern<ShardingOp> {
         ++splitSize; // add one for the total size
         ArrayRef<Value> values(&offsets[curr], splitSize);
         Value vals = rewriter.create<tensor::FromElementsOp>(loc, values);
-        int64_t offs[] = {(int64_t)i, 0};
-        int64_t sizes[] = {1, splitSize};
+        std::array<int64_t, 2> offs = {static_cast<int64_t>(i), 0};
+        std::array<int64_t, 2> sizes = {1, splitSize};
         resOffsets = rewriter.create<tensor::InsertSliceOp>(
             loc, vals, resOffsets, empty, empty, empty, offs, sizes, strides);
         curr += splitSize;
@@ -275,9 +277,9 @@ struct ConvertProcessMultiIndexOp
     if (!axes.empty()) {
       SmallVector<Value> subIndex;
       for (auto axis : axes) {
-        subIndex.push_back(mIdx[axis]);
+        subIndex.emplace_back(mIdx[axis]);
       }
-      mIdx = subIndex;
+      mIdx = std::move(subIndex);
     }
 
     rewriter.replaceOp(op, mIdx);
@@ -294,8 +296,8 @@ class ConvertProcessLinearIndexOp
 
   // Constructor accepting worldRank
   ConvertProcessLinearIndexOp(const TypeConverter &typeConverter,
-                              MLIRContext *context, int64_t worldRank_ = -1)
-      : OpConversionPattern(typeConverter, context), worldRank(worldRank_) {}
+                              MLIRContext *context, int64_t worldRank = -1)
+      : OpConversionPattern(typeConverter, context), worldRank(worldRank) {}
 
   LogicalResult
   matchAndRewrite(ProcessLinearIndexOp op, OpAdaptor adaptor,
@@ -308,12 +310,11 @@ class ConvertProcessLinearIndexOp
     }
 
     // Otherwise call create mpi::CommRankOp
-    auto rank =
-        rewriter
-            .create<mpi::CommRankOp>(
-                op.getLoc(), TypeRange{mpi::RetvalType::get(op->getContext()),
+    auto rank = rewriter
+                    .create<mpi::CommRankOp>(
+                        loc, TypeRange{mpi::RetvalType::get(op->getContext()),
                                        rewriter.getI32Type()})
-            .getRank();
+                    .getRank();
     rewriter.replaceOpWithNewOp<arith::IndexCastOp>(op, rewriter.getIndexType(),
                                                     rank);
     return success();
@@ -400,11 +401,11 @@ struct ConvertShardShapeOp : public OpConversionPattern<ShardShapeOp> {
     }
 
     // Compute the sharded shape by applying the sharding to the input shape.
-    // Without shardedDimsOffsets in the sharding, the shard shape is computed
-    // by dividing the dimension size by the number of shards in that dimension
-    // (which is given by the size of the mesh axes provided in split-axes).
-    // Odd elements get distributed to trailing shards.
-    // If a shardedDimsOffsets is provided, the shard shape is computed by
+    // If shardedDimsOffsets is not defined in the sharding, the shard shape is
+    // computed by dividing the dimension size by the number of shards in that
+    // dimension (which is given by the size of the mesh axes provided in
+    // split-axes). Odd elements get distributed to trailing shards. If a
+    // shardedDimsOffsets is provided, the shard shape is computed by
     // subtracting the offset of the current shard from the offset of the next
     // shard.
 
@@ -429,8 +430,9 @@ struct ConvertShardShapeOp : public OpConversionPattern<ShardShapeOp> {
 
     // To keep the code simple, convert dims/device to values when they are
     // attributes. Count on canonicalization to fold static values.
-    auto shape = getMixedAsValues(rewriter, loc, op.getDims(), dynDims, index);
-    auto multiIdx =
+    SmallVector<Value> shape =
+        getMixedAsValues(rewriter, loc, op.getDims(), dynDims, index);
+    SmallVector<Value> multiIdx =
         getMixedAsValues(rewriter, loc, adaptor.getDevice(), dynDevice, index);
 
     // Get the MeshOp, the mesh shape is needed to compute the sharded shape.
@@ -448,7 +450,7 @@ struct ConvertShardShapeOp : public OpConversionPattern<ShardShapeOp> {
     // local shard-size.
     Value shardedDimsOffs;
     {
-      auto tmp = getMixedAsValues(
+      SmallVector<Value> tmp = getMixedAsValues(
           rewriter, loc, sharding.getStaticShardedDimsOffsets(),
           sharding.getDynamicShardedDimsOffsets(), index);
       if (!tmp.empty())
@@ -478,7 +480,7 @@ struct ConvertShardShapeOp : public OpConversionPattern<ShardShapeOp> {
             rewriter.create<arith::ConstantOp>(loc, rewriter.getIndexAttr(pos));
         // Get the index of the local shard in the mesh axis.
         Value idx = multiIdx[axes[0]];
-        auto _numShards =
+        auto numShards =
             collectiveProcessGroupSize(axes.asArrayRef(), meshOp.getShape());
         if (shardedDimsOffs) {
           // If sharded dims offsets are provided, use them to compute the
@@ -497,22 +499,22 @@ struct ConvertShardShapeOp : public OpConversionPattern<ShardShapeOp> {
           Value sz = rewriter.create<arith::SubIOp>(loc, nextOff, off);
           shardShape.emplace_back(sz);
         } else {
-          auto numShards = rewriter.create<arith::ConstantOp>(
-              loc, rewriter.getIndexAttr(_numShards));
+          Value numShardsVal = rewriter.create<arith::ConstantOp>(
+              loc, rewriter.getIndexAttr(numShards));
           // Compute shard dim size by distributing odd elements to trailing
           // shards:
           // sz = dim / numShards
           //      + (idx >= (numShards - (dim % numShards)) ? 1 : 0)
-          Value sz = rewriter.create<arith::DivSIOp>(loc, dim, numShards);
-          Value sz1 = rewriter.create<arith::RemSIOp>(loc, dim, numShards);
-          sz1 = rewriter.create<arith::SubIOp>(loc, numShards, sz1);
+          Value sz = rewriter.create<arith::DivSIOp>(loc, dim, numShardsVal);
+          Value sz1 = rewriter.create<arith::RemSIOp>(loc, dim, numShardsVal);
+          sz1 = rewriter.create<arith::SubIOp>(loc, numShardsVal, sz1);
           auto cond = rewriter.create<arith::CmpIOp>(
               loc, arith::CmpIPredicate::sge, idx, sz1);
           Value odd = rewriter.create<arith::SelectOp>(loc, cond, one, zero);
           sz = rewriter.create<arith::AddIOp>(loc, sz, odd);
           shardShape.emplace_back(sz);
         }
-        pos += _numShards + 1; // add one for the total size.
+        pos += numShards + 1; // add one for the total size.
       } // else no sharding if split axis is empty or no split axis
       // If no size was added -> no sharding in this dimension.
       if (shardShape.size() <= i)
@@ -698,25 +700,24 @@ struct ConvertUpdateHaloOp : public OpConversionPattern<UpdateHaloOp> {
         offsets[dim] = orgOffset;
       };
 
-      auto get_i32val = [&](OpFoldResult &v) {
-        return isa<Value>(v)
-                   ? cast<Value>(v)
-                   : rewriter.create<arith::ConstantOp>(
-                         loc,
-                         rewriter.getI32IntegerAttr(
-                             cast<IntegerAttr>(cast<Attribute>(v)).getInt()));
-      };
-
-      for (int i = 0; i < 2; ++i) {
-        Value haloSz = get_i32val(haloSizes[currHaloDim * 2 + i]);
+      auto doSendRecv = [&](int upOrDown) {
+        OpFoldResult &v = haloSizes[currHaloDim * 2 + upOrDown];
+        Value haloSz = dyn_cast<Value>(v);
+        if (!haloSz)
+          haloSz = rewriter.create<arith::ConstantOp>(
+              loc, rewriter.getI32IntegerAttr(
+                       cast<IntegerAttr>(cast<Attribute>(v)).getInt()));
         auto hasSize = rewriter.create<arith::CmpIOp>(
             loc, arith::CmpIPredicate::sgt, haloSz, zero);
         rewriter.create<scf::IfOp>(loc, hasSize,
                                    [&](OpBuilder &builder, Location loc) {
-                                     genSendRecv(i > 0);
+                                     genSendRecv(upOrDown > 0);
                                      builder.create<scf::YieldOp>(loc);
                                    });
-      }
+      };
+
+      doSendRecv(0);
+      doSendRecv(1);
 
       // the shape for lower dims include higher dims' halos
       dimSizes[dim] = shape[dim];
@@ -775,8 +776,8 @@ struct ConvertMeshToMPIPass
            SmallVectorImpl<Type> &results) -> std::optional<LogicalResult> {
           auto i16 = IntegerType::get(type.getContext(), 16);
           auto i64 = IntegerType::get(type.getContext(), 64);
-          std::array<int64_t, 2> shp{ShapedType::kDynamic,
-                                     ShapedType::kDynamic};
+          std::array<int64_t, 2> shp = {ShapedType::kDynamic,
+                                        ShapedType::kDynamic};
           results.emplace_back(RankedTensorType::get(shp, i16));
           results.emplace_back(RankedTensorType::get(shp, i64)); // actually ?x2
           results.emplace_back(RankedTensorType::get(shp, i64));
diff --git a/mlir/test/Conversion/MeshToMPI/convert-shardshape-to-mpi.mlir b/mlir/test/Conversion/MeshToMPI/convert-shardshape-to-mpi.mlir
index 1cc848333ced2..156bbfb54845b 100644
--- a/mlir/test/Conversion/MeshToMPI/convert-shardshape-to-mpi.mlir
+++ b/mlir/test/Conversion/MeshToMPI/convert-shardshape-to-mpi.mlir
@@ -34,7 +34,7 @@ module attributes { mpi.dlti = #dlti.map<"MPI:comm_world_rank" = 24> } {
     return %1#0, %1#1, %1#2 : index, index, index
   }
 
-  // all except first shard in second dim get an extra element
+  // In the second dimension the shard sizes are now [3 4 4 4]
   // CHECK-LABEL: func.func @shard_shape_odd_2() -> (index, index, index) {
   func.func @shard_shape_odd_2() -> (index, index, index) {
     %sharding = mesh.sharding @mesh0 split_axes = [[0], [1], [2]] : !mesh.sharding
@@ -46,7 +46,7 @@ module attributes { mpi.dlti = #dlti.map<"MPI:comm_world_rank" = 24> } {
     return %1#0, %1#1, %1#2 : index, index, index
   }
 
-  // all except first shard in first dim get an extra element
+  // In the first dimension the shard sizes are now [3 4 4]
   // CHECK-LABEL: func.func @shard_shape_odd_3() -> (index, index, index) {
   func.func @shard_shape_odd_3() -> (index, index, index) {
     %sharding = mesh.sharding @mesh0 split_axes = [[0], [1], [2]] : !mesh.sharding
@@ -72,4 +72,4 @@ module attributes { mpi.dlti = #dlti.map<"MPI:comm_world_rank" = 24> } {
     // CHECK: return [[vc3]], [[vc2]], [[vc3]] : index, index, index
     return %1#0, %1#1, %1#2 : index, index, index
   }
-}
\ No newline at end of file
+}



More information about the Mlir-commits mailing list