[Mlir-commits] [mlir] [mlir][mesh, mpi] Lower allreduce (PR #144060)

Frank Schlimbach llvmlistbot at llvm.org
Fri Jun 13 05:17:05 PDT 2025


https://github.com/fschlimb created https://github.com/llvm/llvm-project/pull/144060

Adding lowering mesh.allreduce to mpi.allreduce.
Minor restructuring to increase code reuse.

>From 59fce0fc3186498d01509c4933256ff4fdf796a9 Mon Sep 17 00:00:00 2001
From: "Schlimbach, Frank" <frank.schlimbach at intel.com>
Date: Thu, 27 Mar 2025 15:01:42 +0100
Subject: [PATCH 1/2] Lowering mesh.all_reduce to mpi calls.

---
 mlir/include/mlir/Conversion/Passes.td        |   2 +
 mlir/include/mlir/Dialect/MPI/IR/MPI.h        |   1 +
 mlir/include/mlir/Dialect/MPI/IR/MPIOps.td    |  10 +-
 mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td  |   4 +-
 .../Dialect/Mesh/Transforms/Simplifications.h |  10 +-
 .../mlir/Dialect/Mesh/Transforms/Transforms.h |   4 +
 mlir/lib/Conversion/MeshToMPI/MeshToMPI.cpp   | 168 +++++++--
 mlir/lib/Dialect/MPI/IR/MPIOps.cpp            |  38 ++
 mlir/lib/Dialect/Mesh/IR/MeshOps.cpp          |  25 ++
 .../Dialect/Mesh/Transforms/Transforms.cpp    |  16 +-
 .../MeshToMPI/convert-mesh-to-mpi.mlir        | 357 ++++++++++--------
 mlir/test/Dialect/Mesh/spmdization.mlir       |  30 ++
 12 files changed, 460 insertions(+), 205 deletions(-)

diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td
index b496ee0114910..5a864865adffc 100644
--- a/mlir/include/mlir/Conversion/Passes.td
+++ b/mlir/include/mlir/Conversion/Passes.td
@@ -905,6 +905,8 @@ def ConvertMeshToMPIPass : Pass<"convert-mesh-to-mpi"> {
     shard/partition sizes depend on the rank.
   }];
   let dependentDialects = [
+    "affine::AffineDialect",
+    "arith::ArithDialect",
     "memref::MemRefDialect",
     "mpi::MPIDialect",
     "scf::SCFDialect",
diff --git a/mlir/include/mlir/Dialect/MPI/IR/MPI.h b/mlir/include/mlir/Dialect/MPI/IR/MPI.h
index f06b911ce3fe3..2b6743cd008c6 100644
--- a/mlir/include/mlir/Dialect/MPI/IR/MPI.h
+++ b/mlir/include/mlir/Dialect/MPI/IR/MPI.h
@@ -12,6 +12,7 @@
 #include "mlir/IR/Dialect.h"
 #include "mlir/IR/OpDefinition.h"
 #include "mlir/IR/OpImplementation.h"
+#include "mlir/Interfaces/SideEffectInterfaces.h"
 
 //===----------------------------------------------------------------------===//
 // MPIDialect
diff --git a/mlir/include/mlir/Dialect/MPI/IR/MPIOps.td b/mlir/include/mlir/Dialect/MPI/IR/MPIOps.td
index d78aa92d201e7..c14837f6961eb 100644
--- a/mlir/include/mlir/Dialect/MPI/IR/MPIOps.td
+++ b/mlir/include/mlir/Dialect/MPI/IR/MPIOps.td
@@ -11,6 +11,7 @@
 
 include "mlir/Dialect/MPI/IR/MPI.td"
 include "mlir/Dialect/MPI/IR/MPITypes.td"
+include "mlir/Interfaces/SideEffectInterfaces.td"
 
 class MPI_Op<string mnemonic, list<Trait> traits = []>
     : Op<MPI_Dialect, mnemonic, traits>;
@@ -41,7 +42,7 @@ def MPI_InitOp : MPI_Op<"init", []> {
 // CommWorldOp
 //===----------------------------------------------------------------------===//
 
-def MPI_CommWorldOp : MPI_Op<"comm_world", []> {
+def MPI_CommWorldOp : MPI_Op<"comm_world", [Pure]> {
   let summary = "Get the World communicator, equivalent to `MPI_COMM_WORLD`";
   let description = [{
     This operation returns the predefined MPI_COMM_WORLD communicator.
@@ -56,7 +57,7 @@ def MPI_CommWorldOp : MPI_Op<"comm_world", []> {
 // CommRankOp
 //===----------------------------------------------------------------------===//
 
-def MPI_CommRankOp : MPI_Op<"comm_rank", []> {
+def MPI_CommRankOp : MPI_Op<"comm_rank", [Pure]> {
   let summary = "Get the current rank, equivalent to "
                 "`MPI_Comm_rank(comm, &rank)`";
   let description = [{
@@ -72,13 +73,14 @@ def MPI_CommRankOp : MPI_Op<"comm_rank", []> {
   );
 
   let assemblyFormat = "`(` $comm `)` attr-dict `:` type(results)";
+  let hasCanonicalizer = 1;
 }
 
 //===----------------------------------------------------------------------===//
 // CommSizeOp
 //===----------------------------------------------------------------------===//
 
-def MPI_CommSizeOp : MPI_Op<"comm_size", []> {
+def MPI_CommSizeOp : MPI_Op<"comm_size", [Pure]> {
   let summary = "Get the size of the group associated to the communicator, "
                 "equivalent to `MPI_Comm_size(comm, &size)`";
   let description = [{
@@ -100,7 +102,7 @@ def MPI_CommSizeOp : MPI_Op<"comm_size", []> {
 // CommSplitOp
 //===----------------------------------------------------------------------===//
 
-def MPI_CommSplitOp : MPI_Op<"comm_split", []> {
+def MPI_CommSplitOp : MPI_Op<"comm_split", [Pure]> {
   let summary = "Partition the group associated with the given communicator into "
                 "disjoint subgroups";
   let description = [{
diff --git a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td
index f59c4c4c67517..ef57751d7e179 100644
--- a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td
+++ b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td
@@ -584,11 +584,11 @@ def Mesh_AllReduceOp : Mesh_CollectiveCommunicationOpBase<"all_reduce", [
     ```
   }];
   let arguments = !con(commonArgs, (ins
-    AnyRankedTensor:$input,
+    AnyTypeOf<[AnyNon0RankedMemRef, AnyNon0RankedTensor]>:$input,
     DefaultValuedAttr<Mesh_ReductionKindAttr, "::mlir::mesh::ReductionKind::Sum">:$reduction
   ));
   let results = (outs
-    AnyRankedTensor:$result
+    AnyTypeOf<[AnyNon0RankedMemRef, AnyNon0RankedTensor]>:$result
   );
   let assemblyFormat = [{
     $input `on` $mesh (`mesh_axes` `=` $mesh_axes^)? (`reduction` `=` $reduction^)?
diff --git a/mlir/include/mlir/Dialect/Mesh/Transforms/Simplifications.h b/mlir/include/mlir/Dialect/Mesh/Transforms/Simplifications.h
index c64da29ca6412..3f1041cb25103 100644
--- a/mlir/include/mlir/Dialect/Mesh/Transforms/Simplifications.h
+++ b/mlir/include/mlir/Dialect/Mesh/Transforms/Simplifications.h
@@ -62,9 +62,9 @@ void populateAllReduceEndomorphismSimplificationPatterns(
   auto isEndomorphismOp = [reduction](Operation *op,
                                       std::optional<Operation *> referenceOp) {
     auto allReduceOp = llvm::dyn_cast<AllReduceOp>(op);
-    if (!allReduceOp ||
-        allReduceOp.getInput().getType().getElementType() !=
-            allReduceOp.getResult().getType().getElementType() ||
+    auto inType = cast<ShapedType>(allReduceOp.getInput().getType());
+    auto outType = cast<ShapedType>(allReduceOp.getResult().getType());
+    if (!allReduceOp || inType.getElementType() != outType.getElementType() ||
         allReduceOp.getReduction() != reduction) {
       return false;
     }
@@ -83,9 +83,9 @@ void populateAllReduceEndomorphismSimplificationPatterns(
     }
 
     auto refAllReduceOp = llvm::dyn_cast<AllReduceOp>(referenceOp.value());
+    auto refType = cast<ShapedType>(refAllReduceOp.getResult().getType());
     return refAllReduceOp->getAttrs() == allReduceOp->getAttrs() &&
-           allReduceOp.getInput().getType().getElementType() ==
-               refAllReduceOp.getInput().getType().getElementType();
+           inType.getElementType() == refType.getElementType();
   };
   auto isAlgebraicOp = [](Operation *op) {
     return static_cast<bool>(llvm::dyn_cast<AlgebraicOp>(op));
diff --git a/mlir/include/mlir/Dialect/Mesh/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Mesh/Transforms/Transforms.h
index be82e2af399dc..5a1154bf9166e 100644
--- a/mlir/include/mlir/Dialect/Mesh/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Mesh/Transforms/Transforms.h
@@ -42,6 +42,10 @@ createCollectiveProcessGroupSize(MeshOp mesh, ArrayRef<MeshAxis> axes,
 TypedValue<IndexType> createProcessLinearIndex(StringRef mesh,
                                                ArrayRef<MeshAxis> meshAxes,
                                                ImplicitLocOpBuilder &builder);
+TypedValue<IndexType>
+createProcessLinearIndex(StringRef mesh, ValueRange processInGroupMultiIndex,
+                         ArrayRef<MeshAxis> meshAxes,
+                         ImplicitLocOpBuilder &builder);
 
 } // namespace mesh
 } // namespace mlir
diff --git a/mlir/lib/Conversion/MeshToMPI/MeshToMPI.cpp b/mlir/lib/Conversion/MeshToMPI/MeshToMPI.cpp
index 823d4d644f586..65469261d0957 100644
--- a/mlir/lib/Conversion/MeshToMPI/MeshToMPI.cpp
+++ b/mlir/lib/Conversion/MeshToMPI/MeshToMPI.cpp
@@ -12,9 +12,9 @@
 
 #include "mlir/Conversion/MeshToMPI/MeshToMPI.h"
 
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
 #include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/Bufferization/IR/Bufferization.h"
-#include "mlir/Dialect/DLTI/DLTI.h"
 #include "mlir/Dialect/Func/IR/FuncOps.h"
 #include "mlir/Dialect/Func/Transforms/FuncConversions.h"
 #include "mlir/Dialect/Linalg/IR/Linalg.h"
@@ -22,6 +22,8 @@
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/Dialect/Mesh/IR/MeshDialect.h"
 #include "mlir/Dialect/Mesh/IR/MeshOps.h"
+#include "mlir/Dialect/Mesh/Transforms/Simplifications.h"
+#include "mlir/Dialect/Mesh/Transforms/Transforms.h"
 #include "mlir/Dialect/SCF/IR/SCF.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "mlir/Dialect/Utils/StaticValueUtils.h"
@@ -289,27 +291,15 @@ struct ConvertProcessMultiIndexOp
 
 class ConvertProcessLinearIndexOp
     : public OpConversionPattern<ProcessLinearIndexOp> {
-  int64_t worldRank; // rank in MPI_COMM_WORLD if available, else < 0
 
 public:
   using OpConversionPattern::OpConversionPattern;
 
-  // Constructor accepting worldRank
-  ConvertProcessLinearIndexOp(const TypeConverter &typeConverter,
-                              MLIRContext *context, int64_t worldRank = -1)
-      : OpConversionPattern(typeConverter, context), worldRank(worldRank) {}
-
   LogicalResult
   matchAndRewrite(ProcessLinearIndexOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
-
+    // Create mpi::CommRankOp
     Location loc = op.getLoc();
-    if (worldRank >= 0) { // if rank in MPI_COMM_WORLD is known -> use it
-      rewriter.replaceOpWithNewOp<arith::ConstantIndexOp>(op, worldRank);
-      return success();
-    }
-
-    // Otherwise call create mpi::CommRankOp
     auto ctx = op.getContext();
     Value commWorld =
         rewriter.create<mpi::CommWorldOp>(loc, mpi::CommType::get(ctx));
@@ -529,6 +519,121 @@ struct ConvertShardShapeOp : public OpConversionPattern<ShardShapeOp> {
   }
 };
 
+static mpi::MPI_OpClassEnumAttr getMPIReduction(ReductionKindAttr kind) {
+  auto ctx = kind.getContext();
+  switch (kind.getValue()) {
+  case ReductionKind::Sum:
+    return mpi::MPI_OpClassEnumAttr::get(ctx, mpi::MPI_OpClassEnum::MPI_SUM);
+  case ReductionKind::Product:
+    return mpi::MPI_OpClassEnumAttr::get(ctx, mpi::MPI_OpClassEnum::MPI_PROD);
+  case ReductionKind::Min:
+    return mpi::MPI_OpClassEnumAttr::get(ctx, mpi::MPI_OpClassEnum::MPI_MIN);
+  case ReductionKind::Max:
+    return mpi::MPI_OpClassEnumAttr::get(ctx, mpi::MPI_OpClassEnum::MPI_MAX);
+  case ReductionKind::BitwiseAnd:
+    return mpi::MPI_OpClassEnumAttr::get(ctx, mpi::MPI_OpClassEnum::MPI_BAND);
+  case ReductionKind::BitwiseOr:
+    return mpi::MPI_OpClassEnumAttr::get(ctx, mpi::MPI_OpClassEnum::MPI_BOR);
+  case ReductionKind::BitwiseXor:
+    return mpi::MPI_OpClassEnumAttr::get(ctx, mpi::MPI_OpClassEnum::MPI_BXOR);
+  default:
+    assert(false && "Unknown/unsupported reduction kind");
+  }
+}
+
+struct ConvertAllReduceOp : public OpConversionPattern<AllReduceOp> {
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(AllReduceOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    SymbolTableCollection symbolTableCollection;
+    auto mesh = adaptor.getMesh();
+    auto meshOp = getMesh(op, symbolTableCollection);
+    if (!meshOp)
+      return op->emitError() << "No mesh found for AllReduceOp";
+    if (ShapedType::isDynamicShape(meshOp.getShape()))
+      return op->emitError()
+             << "Dynamic mesh shape not supported in AllReduceOp";
+
+    ImplicitLocOpBuilder iBuilder(op.getLoc(), rewriter);
+    Value input = adaptor.getInput();
+    auto inputShape = cast<ShapedType>(input.getType()).getShape();
+
+    // If the source is a memref, cast it to a tensor.
+    if (isa<RankedTensorType>(input.getType())) {
+      auto memrefType = MemRefType::get(
+          inputShape, cast<ShapedType>(input.getType()).getElementType());
+      input = iBuilder.create<bufferization::ToMemrefOp>(memrefType, input);
+    }
+    MemRefType inType = cast<MemRefType>(input.getType());
+
+    // Get the actual shape to allocate the buffer.
+    SmallVector<OpFoldResult> shape(inType.getRank());
+    for (auto i = 0; i < inType.getRank(); ++i) {
+      auto s = inputShape[i];
+      if (ShapedType::isDynamic(s))
+        shape[i] = iBuilder.create<memref::DimOp>(input, s).getResult();
+      else
+        shape[i] = iBuilder.getIndexAttr(s);
+    }
+
+    // Allocate buffer and copy input to buffer.
+    Value buffer = iBuilder.create<memref::AllocOp>(
+        shape, cast<ShapedType>(op.getType()).getElementType());
+    iBuilder.create<linalg::CopyOp>(input, buffer);
+
+    // Get an MPI_Comm_split for the AllReduce operation.
+    // The color is the linear index of the process in the mesh along the
+    // non-reduced axes. The key is the linear index of the process in the mesh
+    // along the reduced axes.
+    SmallVector<Type> indexResultTypes(meshOp.getShape().size(),
+                                       iBuilder.getIndexType());
+    SmallVector<Value> myMultiIndex =
+        iBuilder.create<ProcessMultiIndexOp>(indexResultTypes, mesh)
+            .getResult();
+    Value zero = iBuilder.create<arith::ConstantIndexOp>(0);
+    SmallVector<Value> multiKey(myMultiIndex.size(), zero);
+
+    auto redAxes = adaptor.getMeshAxes();
+    for (auto axis : redAxes) {
+      multiKey[axis] = myMultiIndex[axis];
+      myMultiIndex[axis] = zero;
+    }
+
+    Value color =
+        createProcessLinearIndex(mesh, myMultiIndex, redAxes, iBuilder);
+    color = iBuilder.create<arith::IndexCastOp>(iBuilder.getI32Type(), color);
+    Value key = createProcessLinearIndex(mesh, multiKey, redAxes, iBuilder);
+    key = iBuilder.create<arith::IndexCastOp>(iBuilder.getI32Type(), key);
+
+    // Finally split the communicator
+    auto commType = mpi::CommType::get(op->getContext());
+    Value commWorld = iBuilder.create<mpi::CommWorldOp>(commType);
+    auto comm =
+        iBuilder.create<mpi::CommSplitOp>(commType, commWorld, color, key)
+            .getNewcomm();
+
+    // collapse shape to 1d
+    ReassociationIndices reassociation(inType.getRank());
+    std::iota(reassociation.begin(), reassociation.end(), 0);
+    Value buffer1d = iBuilder.create<memref::CollapseShapeOp>(
+        buffer, ArrayRef<ReassociationIndices>(reassociation));
+
+    // Create the MPI AllReduce operation.
+    iBuilder.create<mpi::AllReduceOp>(
+        TypeRange(), buffer1d, buffer1d,
+        getMPIReduction(adaptor.getReductionAttr()), comm);
+
+    // If the destination is a memref, cast it to a tensor
+    if (isa<RankedTensorType>(op.getType()))
+      buffer = iBuilder.create<bufferization::ToTensorOp>(buffer, true);
+
+    rewriter.replaceOp(op, buffer);
+    return success();
+  }
+};
+
 struct ConvertUpdateHaloOp : public OpConversionPattern<UpdateHaloOp> {
   using OpConversionPattern::OpConversionPattern;
 
@@ -573,10 +678,10 @@ struct ConvertUpdateHaloOp : public OpConversionPattern<UpdateHaloOp> {
     Value array = dest;
     if (isa<RankedTensorType>(array.getType())) {
       // If the destination is a memref, we need to cast it to a tensor
-      auto tensorType = MemRefType::get(
+      auto mmemrefType = MemRefType::get(
           dstShape, cast<ShapedType>(array.getType()).getElementType());
       array =
-          rewriter.create<bufferization::ToBufferOp>(loc, tensorType, array);
+          rewriter.create<bufferization::ToMemrefOp>(loc, mmemrefType, array);
     }
     auto rank = cast<ShapedType>(array.getType()).getRank();
     auto opSplitAxes = adaptor.getSplitAxes().getAxes();
@@ -753,22 +858,6 @@ struct ConvertMeshToMPIPass
 
   /// Run the dialect converter on the module.
   void runOnOperation() override {
-    uint64_t worldRank = -1;
-    // Try to get DLTI attribute for MPI:comm_world_rank
-    // If found, set worldRank to the value of the attribute.
-    {
-      auto dltiAttr =
-          dlti::query(getOperation(), {"MPI:comm_world_rank"}, false);
-      if (succeeded(dltiAttr)) {
-        if (!isa<IntegerAttr>(dltiAttr.value())) {
-          getOperation()->emitError()
-              << "Expected an integer attribute for MPI:comm_world_rank";
-          return signalPassFailure();
-        }
-        worldRank = cast<IntegerAttr>(dltiAttr.value()).getInt();
-      }
-    }
-
     auto *ctxt = &getContext();
     RewritePatternSet patterns(ctxt);
     ConversionTarget target(getContext());
@@ -819,10 +908,10 @@ struct ConvertMeshToMPIPass
     // ...except the global MeshOp
     target.addLegalOp<mesh::MeshOp>();
     // Allow all the stuff that our patterns will convert to
-    target.addLegalDialect<BuiltinDialect, mpi::MPIDialect, scf::SCFDialect,
-                           arith::ArithDialect, tensor::TensorDialect,
-                           bufferization::BufferizationDialect,
-                           linalg::LinalgDialect, memref::MemRefDialect>();
+    target.addLegalDialect<
+        BuiltinDialect, mpi::MPIDialect, scf::SCFDialect, arith::ArithDialect,
+        tensor::TensorDialect, bufferization::BufferizationDialect,
+        linalg::LinalgDialect, memref::MemRefDialect, affine::AffineDialect>();
     // Make sure the function signature, calls etc. are legal
     target.addDynamicallyLegalOp<func::FuncOp>([&](func::FuncOp op) {
       return typeConverter.isSignatureLegal(op.getFunctionType());
@@ -832,9 +921,10 @@ struct ConvertMeshToMPIPass
 
     patterns.add<ConvertUpdateHaloOp, ConvertNeighborsLinearIndicesOp,
                  ConvertProcessMultiIndexOp, ConvertGetShardingOp,
-                 ConvertShardingOp, ConvertShardShapeOp>(typeConverter, ctxt);
-    // ConvertProcessLinearIndexOp accepts an optional worldRank
-    patterns.add<ConvertProcessLinearIndexOp>(typeConverter, ctxt, worldRank);
+                 ConvertShardingOp, ConvertShardShapeOp, ConvertAllReduceOp,
+                 ConvertProcessLinearIndexOp>(typeConverter, ctxt);
+    SymbolTableCollection symbolTableCollection;
+    mlir::mesh::populateFoldingPatterns(patterns, symbolTableCollection);
 
     populateFunctionOpInterfaceTypeConversionPattern<func::FuncOp>(
         patterns, typeConverter);
diff --git a/mlir/lib/Dialect/MPI/IR/MPIOps.cpp b/mlir/lib/Dialect/MPI/IR/MPIOps.cpp
index 56d8edfbcc025..6d445ca0e4099 100644
--- a/mlir/lib/Dialect/MPI/IR/MPIOps.cpp
+++ b/mlir/lib/Dialect/MPI/IR/MPIOps.cpp
@@ -6,6 +6,7 @@
 //
 //===----------------------------------------------------------------------===//
 
+#include "mlir/Dialect/DLTI/DLTI.h"
 #include "mlir/Dialect/MPI/IR/MPI.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/IR/Builders.h"
@@ -41,6 +42,38 @@ struct FoldCast final : public mlir::OpRewritePattern<OpT> {
     return mlir::success();
   }
 };
+
+struct FoldRank final : public mlir::OpRewritePattern<mlir::mpi::CommRankOp> {
+  using mlir::OpRewritePattern<mlir::mpi::CommRankOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(mlir::mpi::CommRankOp op,
+                                mlir::PatternRewriter &b) const override {
+    auto comm = op.getComm();
+    if (!comm.getDefiningOp<mlir::mpi::CommWorldOp>()) {
+      return mlir::failure();
+    }
+
+    // Try to get DLTI attribute for MPI:comm_world_rank
+    // If found, set worldRank to the value of the attribute.
+    {
+      auto dltiAttr = dlti::query(op, {"MPI:comm_world_rank"}, false);
+      if (failed(dltiAttr))
+        return mlir::failure();
+      if (!isa<IntegerAttr>(dltiAttr.value())) {
+        return op->emitError()
+               << "Expected an integer attribute for MPI:comm_world_rank";
+      }
+      Value res = b.create<arith::ConstantIndexOp>(
+          op.getLoc(), cast<IntegerAttr>(dltiAttr.value()).getInt());
+      if (Value retVal = op.getRetval())
+        b.replaceOp(op, {retVal, res});
+      else
+        b.replaceOp(op, res);
+      return mlir::success();
+    }
+  }
+};
+
 } // namespace
 
 void mlir::mpi::SendOp::getCanonicalizationPatterns(
@@ -63,6 +96,11 @@ void mlir::mpi::IRecvOp::getCanonicalizationPatterns(
   results.add<FoldCast<mlir::mpi::IRecvOp>>(context);
 }
 
+void mlir::mpi::CommRankOp::getCanonicalizationPatterns(
+    mlir::RewritePatternSet &results, mlir::MLIRContext *context) {
+  results.add<FoldRank>(context);
+}
+
 //===----------------------------------------------------------------------===//
 // TableGen'd op method definitions
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp b/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
index 304cb55a35086..b84de2b716b32 100644
--- a/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
+++ b/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
@@ -75,6 +75,31 @@ static DimensionSize operator*(DimensionSize lhs, DimensionSize rhs) {
   return lhs.value() * rhs.value();
 }
 
+/// Converts a vector of OpFoldResults (ints) into vector of Values of the
+/// provided type.
+SmallVector<Value> mlir::mesh::getMixedAsValues(OpBuilder b,
+                                                const Location &loc,
+                                                llvm::ArrayRef<int64_t> statics,
+                                                ValueRange dynamics,
+                                                Type type) {
+  SmallVector<Value> values;
+  auto dyn = dynamics.begin();
+  Type i64 = b.getI64Type();
+  if (!type)
+    type = i64;
+  assert((i64 == type || b.getIndexType() == type) &&
+         "expected an i64 or an intex type");
+  for (auto s : statics) {
+    if (s == ShapedType::kDynamic) {
+      values.emplace_back(*(dyn++));
+    } else {
+      TypedAttr val = type == i64 ? b.getI64IntegerAttr(s) : b.getIndexAttr(s);
+      values.emplace_back(b.create<arith::ConstantOp>(loc, type, val));
+    }
+  }
+  return values;
+}
+
 //===----------------------------------------------------------------------===//
 // Inliner
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Mesh/Transforms/Transforms.cpp b/mlir/lib/Dialect/Mesh/Transforms/Transforms.cpp
index 447668cc0ea50..228c56fdf4db8 100644
--- a/mlir/lib/Dialect/Mesh/Transforms/Transforms.cpp
+++ b/mlir/lib/Dialect/Mesh/Transforms/Transforms.cpp
@@ -207,11 +207,10 @@ createCollectiveProcessGroupSize(MeshOp mesh, ArrayRef<MeshAxis> axes,
       builder.getIndexType()));
 }
 
-TypedValue<IndexType> createProcessLinearIndex(StringRef mesh,
-                                               ArrayRef<MeshAxis> meshAxes,
-                                               ImplicitLocOpBuilder &builder) {
-  ResultRange processInGroupMultiIndex =
-      builder.create<ProcessMultiIndexOp>(mesh, meshAxes).getResults();
+TypedValue<IndexType>
+createProcessLinearIndex(StringRef mesh, ValueRange processInGroupMultiIndex,
+                         ArrayRef<MeshAxis> meshAxes,
+                         ImplicitLocOpBuilder &builder) {
   Operation::result_range processGroupShape =
       builder.create<MeshShapeOp>(mesh, meshAxes).getResult();
   OpFoldResult processInGroupLinearIndex = affine::linearizeIndex(
@@ -220,4 +219,11 @@ TypedValue<IndexType> createProcessLinearIndex(StringRef mesh,
   return cast<TypedValue<IndexType>>(cast<Value>(processInGroupLinearIndex));
 }
 
+TypedValue<IndexType> createProcessLinearIndex(StringRef mesh,
+                                               ArrayRef<MeshAxis> meshAxes,
+                                               ImplicitLocOpBuilder &builder) {
+  return createProcessLinearIndex(
+      mesh, builder.create<ProcessMultiIndexOp>(mesh, meshAxes).getResults(),
+      meshAxes, builder);
+}
 } // namespace mlir::mesh
diff --git a/mlir/test/Conversion/MeshToMPI/convert-mesh-to-mpi.mlir b/mlir/test/Conversion/MeshToMPI/convert-mesh-to-mpi.mlir
index d314ad3ac30fd..848413e7addde 100644
--- a/mlir/test/Conversion/MeshToMPI/convert-mesh-to-mpi.mlir
+++ b/mlir/test/Conversion/MeshToMPI/convert-mesh-to-mpi.mlir
@@ -80,6 +80,63 @@ module attributes { mpi.dlti = #dlti.map<"MPI:comm_world_rank" = 24> } {
   }
 }
 
+// -----
+module attributes { mpi.dlti = #dlti.map<"MPI:comm_world_rank" = 7> } {
+  mesh.mesh @mesh0(shape = 3x4x5)
+  // CHECK-LABEL: func.func @allreduce_tensor(
+  func.func @allreduce_tensor(
+    // CHECK-SAME: [[varg0:%.*]]: tensor<3x4xf32>
+    %arg0 : tensor<3x4xf32>) -> tensor<3x4xf32> {
+    // CHECK-DAG: [[vc4_i32:%.*]] = arith.constant 4 : i32
+    // CHECK-DAG: [[vc2_i32:%.*]] = arith.constant 2 : i32
+    // CHECK: [[v0:%.*]] = bufferization.to_memref [[varg0]] : tensor<3x4xf32> to memref<3x4xf32>
+    // CHECK: [[valloc:%.*]] = memref.alloc() : memref<3x4xf32>
+    // CHECK: linalg.copy ins([[v0]] : memref<3x4xf32>) outs([[valloc]] : memref<3x4xf32>)
+    // CHECK: [[v1:%.*]] = mpi.comm_world : !mpi.comm
+    // CHECK: [[vnewcomm:%.*]] = mpi.comm_split([[v1]], [[vc2_i32]], [[vc4_i32]]) : !mpi.comm
+    // CHECK: [[vcollapse_shape:%.*]] = memref.collapse_shape [[valloc]] {{\[\[}}0, 1]] : memref<3x4xf32> into memref<12xf32>
+    // CHECK: mpi.allreduce([[vcollapse_shape]], [[vcollapse_shape]], MPI_MAX, [[vnewcomm]]) : memref<12xf32>, memref<12xf32>
+    // CHECK: [[v2:%.*]] = bufferization.to_tensor [[valloc]] restrict : memref<3x4xf32> to tensor<3x4xf32>
+    %0 = mesh.all_reduce %arg0 on @mesh0 mesh_axes = [0, 1] reduction = max : tensor<3x4xf32> -> tensor<3x4xf32>
+    // CHECK: return [[v2]] : tensor<3x4xf32>
+    return %0 : tensor<3x4xf32>
+  }
+
+  // CHECK-LABEL: func.func @allreduce_memref(
+  func.func @allreduce_memref(
+    // CHECK-SAME: [[varg0:%.*]]: memref<3x4xf32>
+    %arg0 : memref<3x4xf32>) -> memref<3x4xf32> {
+    // CHECK: [[vc4_i32:%.*]] = arith.constant 4 : i32
+    // CHECK: [[vc2_i32:%.*]] = arith.constant 2 : i32
+    // CHECK: [[valloc:%.*]] = memref.alloc() : memref<3x4xf32>
+    // CHECK: linalg.copy ins([[varg0]] : memref<3x4xf32>) outs([[valloc]] : memref<3x4xf32>)
+    // CHECK: [[v0:%.*]] = mpi.comm_world : !mpi.comm
+    // CHECK: [[vnewcomm:%.*]] = mpi.comm_split([[v0]], [[vc2_i32]], [[vc4_i32]]) : !mpi.comm
+    // CHECK: [[vcollapse_shape:%.*]] = memref.collapse_shape [[valloc]] {{\[\[}}0, 1]] : memref<3x4xf32> into memref<12xf32>
+    // CHECK: mpi.allreduce([[vcollapse_shape]], [[vcollapse_shape]], MPI_MAX, [[vnewcomm]]) : memref<12xf32>, memref<12xf32>
+    %0 = mesh.all_reduce %arg0 on @mesh0 mesh_axes = [0, 1] reduction = max : memref<3x4xf32> -> memref<3x4xf32>
+    // CHECK: return [[valloc]] : memref<3x4xf32>
+    return %0 : memref<3x4xf32>
+  }
+
+  // CHECK-LABEL: func.func @allreduce_new_type(
+  func.func @allreduce_new_type(
+    // CHECK-SAME: [[varg0:%.*]]: memref<3x4xf32>
+    %arg0 : memref<3x4xf32>) -> memref<3x4xf64> {
+    // CHECK: [[vc4_i32:%.*]] = arith.constant 4 : i32
+    // CHECK: [[vc2_i32:%.*]] = arith.constant 2 : i32
+    // CHECK: [[valloc:%.*]] = memref.alloc() : memref<3x4xf64>
+    // CHECK: linalg.copy ins([[varg0]] : memref<3x4xf32>) outs([[valloc]] : memref<3x4xf64>)
+    // CHECK: [[v0:%.*]] = mpi.comm_world : !mpi.comm
+    // CHECK: [[vnewcomm:%.*]] = mpi.comm_split([[v0]], [[vc2_i32]], [[vc4_i32]]) : !mpi.comm
+    // CHECK: [[vcollapse_shape:%.*]] = memref.collapse_shape [[valloc]] {{\[\[}}0, 1]] : memref<3x4xf64> into memref<12xf64>
+    // CHECK: mpi.allreduce([[vcollapse_shape]], [[vcollapse_shape]], MPI_MAX, [[vnewcomm]]) : memref<12xf64>, memref<12xf64>
+    %0 = mesh.all_reduce %arg0 on @mesh0 mesh_axes = [0, 1] reduction = max : memref<3x4xf32> -> memref<3x4xf64>
+    // CHECK: return [[valloc]] : memref<3x4xf64>
+    return %0 : memref<3x4xf64>
+  }
+}
+
 // -----
 mesh.mesh @mesh0(shape = 3x4x5)
 // CHECK-LABEL: func @update_halo_1d_first
@@ -91,13 +148,13 @@ func.func @update_halo_1d_first(
   // CHECK-SAME: : memref<2x120x120xi8>, i32, i32
   // CHECK: mpi.recv(
   // CHECK-SAME: : memref<2x120x120xi8>, i32, i32
-  // CHECK-NEXT: memref.subview [[arg0]][0, 0, 0] [2, 120, 120] [1, 1, 1] : memref<120x120x120xi8> to memref<2x120x120xi8
+  // CHECK: memref.subview [[arg0]][0, 0, 0] [2, 120, 120] [1, 1, 1] : memref<120x120x120xi8> to memref<2x120x120xi8
   // CHECK: memref.subview [[arg0]][2, 0, 0] [3, 120, 120] [1, 1, 1] : memref<120x120x120xi8> to memref<3x120x120xi8
   // CHECK: mpi.send(
   // CHECK-SAME: : memref<3x120x120xi8>, i32, i32
   // CHECK: mpi.recv(
   // CHECK-SAME: : memref<3x120x120xi8>, i32, i32
-  // CHECK-NEXT: memref.subview [[arg0]][117, 0, 0] [3, 120, 120] [1, 1, 1] : memref<120x120x120xi8> to memref<3x120x120xi8
+  // CHECK: memref.subview [[arg0]][117, 0, 0] [3, 120, 120] [1, 1, 1] : memref<120x120x120xi8> to memref<3x120x120xi8
   %res = mesh.update_halo %arg0 on @mesh0 split_axes = [[0]] halo_sizes = [2, 3] : memref<120x120x120xi8>
   // CHECK: return [[res:%.*]] : memref<120x120x120xi8>
   return %res : memref<120x120x120xi8>
@@ -110,18 +167,18 @@ module attributes { mpi.dlti = #dlti.map<"MPI:comm_world_rank" = 1> } {
   func.func @update_halo_1d_with_zero (
     // CHECK-SAME: [[varg0:%.*]]: memref<120x120x120xi8>
     %arg0 : memref<120x120x120xi8>) -> memref<120x120x120xi8> {
-    // CHECK: [[vc91_i32:%.*]] = arith.constant 91 : i32
-    // CHECK-NEXT: [[vc0_i32:%.*]] = arith.constant 0 : i32
-    // CHECK-NEXT: [[vc2_i32:%.*]] = arith.constant 2 : i32
-    // CHECK-NEXT: [[v0:%.*]] = mpi.comm_world : !mpi.comm
-    // CHECK-NEXT: [[valloc:%.*]] = memref.alloc() : memref<2x120x120xi8>
-    // CHECK-NEXT: [[vsubview:%.*]] = memref.subview [[varg0]][118, 0, 0] [2, 120, 120] [1, 1, 1] : memref<120x120x120xi8> to memref<2x120x120xi8, strided<[14400, 120, 1], offset: 1699200>>
-    // CHECK-NEXT: memref.copy [[vsubview]], [[valloc]] : memref<2x120x120xi8, strided<[14400, 120, 1], offset: 1699200>> to memref<2x120x120xi8>
-    // CHECK-NEXT: mpi.send([[valloc]], [[vc91_i32]], [[vc2_i32]], [[v0]]) : memref<2x120x120xi8>, i32, i32
-    // CHECK-NEXT: mpi.recv([[valloc]], [[vc91_i32]], [[vc0_i32]], [[v0]]) : memref<2x120x120xi8>, i32, i32
-    // CHECK-NEXT: [[vsubview_0:%.*]] = memref.subview [[varg0]][0, 0, 0] [2, 120, 120] [1, 1, 1] : memref<120x120x120xi8> to memref<2x120x120xi8, strided<[14400, 120, 1]>>
-    // CHECK-NEXT: memref.copy [[valloc]], [[vsubview_0]] : memref<2x120x120xi8> to memref<2x120x120xi8, strided<[14400, 120, 1]>>
-    // CHECK-NEXT: memref.dealloc [[valloc]] : memref<2x120x120xi8>
+    // CHECK-DAG: [[vc91_i32:%.*]] = arith.constant 91 : i32
+    // CHECK-DAG: [[vc0_i32:%.*]] = arith.constant 0 : i32
+    // CHECK-DAG: [[vc2_i32:%.*]] = arith.constant 2 : i32
+    // CHECK: [[v0:%.*]] = mpi.comm_world : !mpi.comm
+    // CHECK: [[valloc:%.*]] = memref.alloc() : memref<2x120x120xi8>
+    // CHECK: [[vsubview:%.*]] = memref.subview [[varg0]][118, 0, 0] [2, 120, 120] [1, 1, 1] : memref<120x120x120xi8> to memref<2x120x120xi8, strided<[14400, 120, 1], offset: 1699200>>
+    // CHECK: memref.copy [[vsubview]], [[valloc]] : memref<2x120x120xi8, strided<[14400, 120, 1], offset: 1699200>> to memref<2x120x120xi8>
+    // CHECK: mpi.send([[valloc]], [[vc91_i32]], [[vc2_i32]], [[v0]]) : memref<2x120x120xi8>, i32, i32
+    // CHECK: mpi.recv([[valloc]], [[vc91_i32]], [[vc0_i32]], [[v0]]) : memref<2x120x120xi8>, i32, i32
+    // CHECK: [[vsubview_0:%.*]] = memref.subview [[varg0]][0, 0, 0] [2, 120, 120] [1, 1, 1] : memref<120x120x120xi8> to memref<2x120x120xi8, strided<[14400, 120, 1]>>
+    // CHECK: memref.copy [[valloc]], [[vsubview_0]] : memref<2x120x120xi8> to memref<2x120x120xi8, strided<[14400, 120, 1]>>
+    // CHECK: memref.dealloc [[valloc]] : memref<2x120x120xi8>
     %res = mesh.update_halo %arg0 on @mesh0 split_axes = [[0]] halo_sizes = [2, 0] : memref<120x120x120xi8>
     // CHECK: return [[varg0]] : memref<120x120x120xi8>
     return %res : memref<120x120x120xi8>
@@ -135,50 +192,50 @@ module attributes { mpi.dlti = #dlti.map<"MPI:comm_world_rank" = 24> } {
   func.func @update_halo_3d(
     // CHECK-SAME: [[varg0:%.*]]: memref<120x120x120xi8>
     %arg0 : memref<120x120x120xi8>) -> memref<120x120x120xi8> {
-    // CHECK-NEXT: [[vc23_i32:%.*]] = arith.constant 23 : i32
-    // CHECK-NEXT: [[vc29_i32:%.*]] = arith.constant 29 : i32
-    // CHECK-NEXT: [[vc91_i32:%.*]] = arith.constant 91 : i32
-    // CHECK-NEXT: [[vc4_i32:%.*]] = arith.constant 4 : i32
-    // CHECK-NEXT: [[vc44_i32:%.*]] = arith.constant 44 : i32
-    // CHECK-NEXT: [[v0:%.*]] = mpi.comm_world : !mpi.comm
-    // CHECK-NEXT: [[valloc:%.*]] = memref.alloc() : memref<117x113x5xi8>
-    // CHECK-NEXT: [[vsubview:%.*]] = memref.subview [[varg0]][1, 3, 109] [117, 113, 5] [1, 1, 1] : memref<120x120x120xi8> to memref<117x113x5xi8, strided<[14400, 120, 1], offset: 14869>>
-    // CHECK-NEXT: memref.copy [[vsubview]], [[valloc]] : memref<117x113x5xi8, strided<[14400, 120, 1], offset: 14869>> to memref<117x113x5xi8>
-    // CHECK-NEXT: mpi.send([[valloc]], [[vc91_i32]], [[vc44_i32]], [[v0]]) : memref<117x113x5xi8>, i32, i32
-    // CHECK-NEXT: mpi.recv([[valloc]], [[vc91_i32]], [[vc4_i32]], [[v0]]) : memref<117x113x5xi8>, i32, i32
-    // CHECK-NEXT: [[vsubview_0:%.*]] = memref.subview [[varg0]][1, 3, 0] [117, 113, 5] [1, 1, 1] : memref<120x120x120xi8> to memref<117x113x5xi8, strided<[14400, 120, 1], offset: 14760>>
-    // CHECK-NEXT: memref.copy [[valloc]], [[vsubview_0]] : memref<117x113x5xi8> to memref<117x113x5xi8, strided<[14400, 120, 1], offset: 14760>>
-    // CHECK-NEXT: memref.dealloc [[valloc]] : memref<117x113x5xi8>
-    // CHECK-NEXT: [[valloc_1:%.*]] = memref.alloc() : memref<117x113x6xi8>
-    // CHECK-NEXT: [[vsubview_2:%.*]] = memref.subview [[varg0]][1, 3, 5] [117, 113, 6] [1, 1, 1] : memref<120x120x120xi8> to memref<117x113x6xi8, strided<[14400, 120, 1], offset: 14765>>
-    // CHECK-NEXT: memref.copy [[vsubview_2]], [[valloc_1]] : memref<117x113x6xi8, strided<[14400, 120, 1], offset: 14765>> to memref<117x113x6xi8>
-    // CHECK-NEXT: mpi.send([[valloc_1]], [[vc91_i32]], [[vc4_i32]], [[v0]]) : memref<117x113x6xi8>, i32, i32
-    // CHECK-NEXT: mpi.recv([[valloc_1]], [[vc91_i32]], [[vc44_i32]], [[v0]]) : memref<117x113x6xi8>, i32, i32
-    // CHECK-NEXT: [[vsubview_3:%.*]] = memref.subview [[varg0]][1, 3, 114] [117, 113, 6] [1, 1, 1] : memref<120x120x120xi8> to memref<117x113x6xi8, strided<[14400, 120, 1], offset: 14874>>
-    // CHECK-NEXT: memref.copy [[valloc_1]], [[vsubview_3]] : memref<117x113x6xi8> to memref<117x113x6xi8, strided<[14400, 120, 1], offset: 14874>>
-    // CHECK-NEXT: memref.dealloc [[valloc_1]] : memref<117x113x6xi8>
-    // CHECK-NEXT: [[v1:%.*]] = mpi.comm_world : !mpi.comm
-    // CHECK-NEXT: [[valloc_4:%.*]] = memref.alloc() : memref<117x3x120xi8>
-    // CHECK-NEXT: [[vsubview_5:%.*]] = memref.subview [[varg0]][1, 113, 0] [117, 3, 120] [1, 1, 1] : memref<120x120x120xi8> to memref<117x3x120xi8, strided<[14400, 120, 1], offset: 27960>>
-    // CHECK-NEXT: memref.copy [[vsubview_5]], [[valloc_4]] : memref<117x3x120xi8, strided<[14400, 120, 1], offset: 27960>> to memref<117x3x120xi8>
-    // CHECK-NEXT: mpi.send([[valloc_4]], [[vc91_i32]], [[vc29_i32]], [[v1]]) : memref<117x3x120xi8>, i32, i32
-    // CHECK-NEXT: memref.dealloc [[valloc_4]] : memref<117x3x120xi8>
-    // CHECK-NEXT: [[valloc_6:%.*]] = memref.alloc() : memref<117x4x120xi8>
-    // CHECK-NEXT: mpi.recv([[valloc_6]], [[vc91_i32]], [[vc29_i32]], [[v1]]) : memref<117x4x120xi8>, i32, i32
-    // CHECK-NEXT: [[vsubview_7:%.*]] = memref.subview [[varg0]][1, 116, 0] [117, 4, 120] [1, 1, 1] : memref<120x120x120xi8> to memref<117x4x120xi8, strided<[14400, 120, 1], offset: 28320>>
-    // CHECK-NEXT: memref.copy [[valloc_6]], [[vsubview_7]] : memref<117x4x120xi8> to memref<117x4x120xi8, strided<[14400, 120, 1], offset: 28320>>
-    // CHECK-NEXT: memref.dealloc [[valloc_6]] : memref<117x4x120xi8>
-    // CHECK-NEXT: [[v2:%.*]] = mpi.comm_world : !mpi.comm
-    // CHECK-NEXT: [[valloc_8:%.*]] = memref.alloc() : memref<1x120x120xi8>
-    // CHECK-NEXT: mpi.recv([[valloc_8]], [[vc91_i32]], [[vc23_i32]], [[v2]]) : memref<1x120x120xi8>, i32, i32
-    // CHECK-NEXT: [[vsubview_9:%.*]] = memref.subview [[varg0]][0, 0, 0] [1, 120, 120] [1, 1, 1] : memref<120x120x120xi8> to memref<1x120x120xi8, strided<[14400, 120, 1]>>
-    // CHECK-NEXT: memref.copy [[valloc_8]], [[vsubview_9]] : memref<1x120x120xi8> to memref<1x120x120xi8, strided<[14400, 120, 1]>>
-    // CHECK-NEXT: memref.dealloc [[valloc_8]] : memref<1x120x120xi8>
-    // CHECK-NEXT: [[valloc_10:%.*]] = memref.alloc() : memref<2x120x120xi8>
-    // CHECK-NEXT: [[vsubview_11:%.*]] = memref.subview [[varg0]][1, 0, 0] [2, 120, 120] [1, 1, 1] : memref<120x120x120xi8> to memref<2x120x120xi8, strided<[14400, 120, 1], offset: 14400>>
-    // CHECK-NEXT: memref.copy [[vsubview_11]], [[valloc_10]] : memref<2x120x120xi8, strided<[14400, 120, 1], offset: 14400>> to memref<2x120x120xi8>
-    // CHECK-NEXT: mpi.send([[valloc_10]], [[vc91_i32]], [[vc23_i32]], [[v2]]) : memref<2x120x120xi8>, i32, i32
-    // CHECK-NEXT: memref.dealloc [[valloc_10]] : memref<2x120x120xi8>
+    // CHECK-DAG: [[vc23_i32:%.*]] = arith.constant 23 : i32
+    // CHECK-DAG: [[vc29_i32:%.*]] = arith.constant 29 : i32
+    // CHECK-DAG: [[vc91_i32:%.*]] = arith.constant 91 : i32
+    // CHECK-DAG: [[vc4_i32:%.*]] = arith.constant 4 : i32
+    // CHECK-DAG: [[vc44_i32:%.*]] = arith.constant 44 : i32
+    // CHECK: [[v0:%.*]] = mpi.comm_world : !mpi.comm
+    // CHECK: [[valloc:%.*]] = memref.alloc() : memref<117x113x5xi8>
+    // CHECK: [[vsubview:%.*]] = memref.subview [[varg0]][1, 3, 109] [117, 113, 5] [1, 1, 1] : memref<120x120x120xi8> to memref<117x113x5xi8, strided<[14400, 120, 1], offset: 14869>>
+    // CHECK: memref.copy [[vsubview]], [[valloc]] : memref<117x113x5xi8, strided<[14400, 120, 1], offset: 14869>> to memref<117x113x5xi8>
+    // CHECK: mpi.send([[valloc]], [[vc91_i32]], [[vc44_i32]], [[v0]]) : memref<117x113x5xi8>, i32, i32
+    // CHECK: mpi.recv([[valloc]], [[vc91_i32]], [[vc4_i32]], [[v0]]) : memref<117x113x5xi8>, i32, i32
+    // CHECK: [[vsubview_0:%.*]] = memref.subview [[varg0]][1, 3, 0] [117, 113, 5] [1, 1, 1] : memref<120x120x120xi8> to memref<117x113x5xi8, strided<[14400, 120, 1], offset: 14760>>
+    // CHECK: memref.copy [[valloc]], [[vsubview_0]] : memref<117x113x5xi8> to memref<117x113x5xi8, strided<[14400, 120, 1], offset: 14760>>
+    // CHECK: memref.dealloc [[valloc]] : memref<117x113x5xi8>
+    // CHECK: [[valloc_1:%.*]] = memref.alloc() : memref<117x113x6xi8>
+    // CHECK: [[vsubview_2:%.*]] = memref.subview [[varg0]][1, 3, 5] [117, 113, 6] [1, 1, 1] : memref<120x120x120xi8> to memref<117x113x6xi8, strided<[14400, 120, 1], offset: 14765>>
+    // CHECK: memref.copy [[vsubview_2]], [[valloc_1]] : memref<117x113x6xi8, strided<[14400, 120, 1], offset: 14765>> to memref<117x113x6xi8>
+    // CHECK: mpi.send([[valloc_1]], [[vc91_i32]], [[vc4_i32]], [[v0]]) : memref<117x113x6xi8>, i32, i32
+    // CHECK: mpi.recv([[valloc_1]], [[vc91_i32]], [[vc44_i32]], [[v0]]) : memref<117x113x6xi8>, i32, i32
+    // CHECK: [[vsubview_3:%.*]] = memref.subview [[varg0]][1, 3, 114] [117, 113, 6] [1, 1, 1] : memref<120x120x120xi8> to memref<117x113x6xi8, strided<[14400, 120, 1], offset: 14874>>
+    // CHECK: memref.copy [[valloc_1]], [[vsubview_3]] : memref<117x113x6xi8> to memref<117x113x6xi8, strided<[14400, 120, 1], offset: 14874>>
+    // CHECK: memref.dealloc [[valloc_1]] : memref<117x113x6xi8>
+    // CHECK: [[v1:%.*]] = mpi.comm_world : !mpi.comm
+    // CHECK: [[valloc_4:%.*]] = memref.alloc() : memref<117x3x120xi8>
+    // CHECK: [[vsubview_5:%.*]] = memref.subview [[varg0]][1, 113, 0] [117, 3, 120] [1, 1, 1] : memref<120x120x120xi8> to memref<117x3x120xi8, strided<[14400, 120, 1], offset: 27960>>
+    // CHECK: memref.copy [[vsubview_5]], [[valloc_4]] : memref<117x3x120xi8, strided<[14400, 120, 1], offset: 27960>> to memref<117x3x120xi8>
+    // CHECK: mpi.send([[valloc_4]], [[vc91_i32]], [[vc29_i32]], [[v1]]) : memref<117x3x120xi8>, i32, i32
+    // CHECK: memref.dealloc [[valloc_4]] : memref<117x3x120xi8>
+    // CHECK: [[valloc_6:%.*]] = memref.alloc() : memref<117x4x120xi8>
+    // CHECK: mpi.recv([[valloc_6]], [[vc91_i32]], [[vc29_i32]], [[v1]]) : memref<117x4x120xi8>, i32, i32
+    // CHECK: [[vsubview_7:%.*]] = memref.subview [[varg0]][1, 116, 0] [117, 4, 120] [1, 1, 1] : memref<120x120x120xi8> to memref<117x4x120xi8, strided<[14400, 120, 1], offset: 28320>>
+    // CHECK: memref.copy [[valloc_6]], [[vsubview_7]] : memref<117x4x120xi8> to memref<117x4x120xi8, strided<[14400, 120, 1], offset: 28320>>
+    // CHECK: memref.dealloc [[valloc_6]] : memref<117x4x120xi8>
+    // CHECK: [[v2:%.*]] = mpi.comm_world : !mpi.comm
+    // CHECK: [[valloc_8:%.*]] = memref.alloc() : memref<1x120x120xi8>
+    // CHECK: mpi.recv([[valloc_8]], [[vc91_i32]], [[vc23_i32]], [[v2]]) : memref<1x120x120xi8>, i32, i32
+    // CHECK: [[vsubview_9:%.*]] = memref.subview [[varg0]][0, 0, 0] [1, 120, 120] [1, 1, 1] : memref<120x120x120xi8> to memref<1x120x120xi8, strided<[14400, 120, 1]>>
+    // CHECK: memref.copy [[valloc_8]], [[vsubview_9]] : memref<1x120x120xi8> to memref<1x120x120xi8, strided<[14400, 120, 1]>>
+    // CHECK: memref.dealloc [[valloc_8]] : memref<1x120x120xi8>
+    // CHECK: [[valloc_10:%.*]] = memref.alloc() : memref<2x120x120xi8>
+    // CHECK: [[vsubview_11:%.*]] = memref.subview [[varg0]][1, 0, 0] [2, 120, 120] [1, 1, 1] : memref<120x120x120xi8> to memref<2x120x120xi8, strided<[14400, 120, 1], offset: 14400>>
+    // CHECK: memref.copy [[vsubview_11]], [[valloc_10]] : memref<2x120x120xi8, strided<[14400, 120, 1], offset: 14400>> to memref<2x120x120xi8>
+    // CHECK: mpi.send([[valloc_10]], [[vc91_i32]], [[vc23_i32]], [[v2]]) : memref<2x120x120xi8>, i32, i32
+    // CHECK: memref.dealloc [[valloc_10]] : memref<2x120x120xi8>
     %res = mesh.update_halo %arg0 on @mesh0 split_axes = [[2], [1], [0]] halo_sizes = [1, 2, 3, 4, 5, 6] : memref<120x120x120xi8>
     // CHECK: return [[varg0]] : memref<120x120x120xi8>
     return %res : memref<120x120x120xi8>
@@ -188,54 +245,54 @@ module attributes { mpi.dlti = #dlti.map<"MPI:comm_world_rank" = 24> } {
   func.func @update_halo_3d_tensor(
     // CHECK-SAME: [[varg0:%.*]]: tensor<120x120x120xi8>
     %arg0 : tensor<120x120x120xi8>) -> tensor<120x120x120xi8> {
-    // CHECK-NEXT: [[vc23_i32:%.*]] = arith.constant 23 : i32
-    // CHECK-NEXT: [[vc29_i32:%.*]] = arith.constant 29 : i32
-    // CHECK-NEXT: [[vc44_i32:%.*]] = arith.constant 44 : i32
-    // CHECK-NEXT: [[vc4_i32:%.*]] = arith.constant 4 : i32
-    // CHECK-NEXT: [[vc91_i32:%.*]] = arith.constant 91 : i32
-    // CHECK-NEXT: [[v0:%.*]] = bufferization.to_buffer [[varg0]] : tensor<120x120x120xi8> to memref<120x120x120xi8>
-    // CHECK-NEXT: [[v1:%.*]] = mpi.comm_world : !mpi.comm
-    // CHECK-NEXT: [[valloc:%.*]] = memref.alloc() : memref<117x113x5xi8>
-    // CHECK-NEXT: [[vsubview:%.*]] = memref.subview [[v0]][1, 3, 109] [117, 113, 5] [1, 1, 1] : memref<120x120x120xi8> to memref<117x113x5xi8, strided<[14400, 120, 1], offset: 14869>>
-    // CHECK-NEXT: memref.copy [[vsubview]], [[valloc]] : memref<117x113x5xi8, strided<[14400, 120, 1], offset: 14869>> to memref<117x113x5xi8>
-    // CHECK-NEXT: mpi.send([[valloc]], [[vc91_i32]], [[vc44_i32]], [[v1]]) : memref<117x113x5xi8>, i32, i32
-    // CHECK-NEXT: mpi.recv([[valloc]], [[vc91_i32]], [[vc4_i32]], [[v1]]) : memref<117x113x5xi8>, i32, i32
-    // CHECK-NEXT: [[vsubview_0:%.*]] = memref.subview [[v0]][1, 3, 0] [117, 113, 5] [1, 1, 1] : memref<120x120x120xi8> to memref<117x113x5xi8, strided<[14400, 120, 1], offset: 14760>>
-    // CHECK-NEXT: memref.copy [[valloc]], [[vsubview_0]] : memref<117x113x5xi8> to memref<117x113x5xi8, strided<[14400, 120, 1], offset: 14760>>
-    // CHECK-NEXT: memref.dealloc [[valloc]] : memref<117x113x5xi8>
-    // CHECK-NEXT: [[valloc_1:%.*]] = memref.alloc() : memref<117x113x6xi8>
-    // CHECK-NEXT: [[vsubview_2:%.*]] = memref.subview [[v0]][1, 3, 5] [117, 113, 6] [1, 1, 1] : memref<120x120x120xi8> to memref<117x113x6xi8, strided<[14400, 120, 1], offset: 14765>>
-    // CHECK-NEXT: memref.copy [[vsubview_2]], [[valloc_1]] : memref<117x113x6xi8, strided<[14400, 120, 1], offset: 14765>> to memref<117x113x6xi8>
-    // CHECK-NEXT: mpi.send([[valloc_1]], [[vc91_i32]], [[vc4_i32]], [[v1]]) : memref<117x113x6xi8>, i32, i32
-    // CHECK-NEXT: mpi.recv([[valloc_1]], [[vc91_i32]], [[vc44_i32]], [[v1]]) : memref<117x113x6xi8>, i32, i32
-    // CHECK-NEXT: [[vsubview_3:%.*]] = memref.subview [[v0]][1, 3, 114] [117, 113, 6] [1, 1, 1] : memref<120x120x120xi8> to memref<117x113x6xi8, strided<[14400, 120, 1], offset: 14874>>
-    // CHECK-NEXT: memref.copy [[valloc_1]], [[vsubview_3]] : memref<117x113x6xi8> to memref<117x113x6xi8, strided<[14400, 120, 1], offset: 14874>>
-    // CHECK-NEXT: memref.dealloc [[valloc_1]] : memref<117x113x6xi8>
-    // CHECK-NEXT: [[v2:%.*]] = mpi.comm_world : !mpi.comm
-    // CHECK-NEXT: [[valloc_4:%.*]] = memref.alloc() : memref<117x3x120xi8>
-    // CHECK-NEXT: [[vsubview_5:%.*]] = memref.subview [[v0]][1, 113, 0] [117, 3, 120] [1, 1, 1] : memref<120x120x120xi8> to memref<117x3x120xi8, strided<[14400, 120, 1], offset: 27960>>
-    // CHECK-NEXT: memref.copy [[vsubview_5]], [[valloc_4]] : memref<117x3x120xi8, strided<[14400, 120, 1], offset: 27960>> to memref<117x3x120xi8>
-    // CHECK-NEXT: mpi.send([[valloc_4]], [[vc91_i32]], [[vc29_i32]], [[v2]]) : memref<117x3x120xi8>, i32, i32
-    // CHECK-NEXT: memref.dealloc [[valloc_4]] : memref<117x3x120xi8>
-    // CHECK-NEXT: [[valloc_6:%.*]] = memref.alloc() : memref<117x4x120xi8>
-    // CHECK-NEXT: mpi.recv([[valloc_6]], [[vc91_i32]], [[vc29_i32]], [[v2]]) : memref<117x4x120xi8>, i32, i32
-    // CHECK-NEXT: [[vsubview_7:%.*]] = memref.subview [[v0]][1, 116, 0] [117, 4, 120] [1, 1, 1] : memref<120x120x120xi8> to memref<117x4x120xi8, strided<[14400, 120, 1], offset: 28320>>
-    // CHECK-NEXT: memref.copy [[valloc_6]], [[vsubview_7]] : memref<117x4x120xi8> to memref<117x4x120xi8, strided<[14400, 120, 1], offset: 28320>>
-    // CHECK-NEXT: memref.dealloc [[valloc_6]] : memref<117x4x120xi8>
-    // CHECK-NEXT: [[v3:%.*]] = mpi.comm_world : !mpi.comm
-    // CHECK-NEXT: [[valloc_8:%.*]] = memref.alloc() : memref<1x120x120xi8>
-    // CHECK-NEXT: mpi.recv([[valloc_8]], [[vc91_i32]], [[vc23_i32]], [[v3]]) : memref<1x120x120xi8>, i32, i32
-    // CHECK-NEXT: [[vsubview_9:%.*]] = memref.subview [[v0]][0, 0, 0] [1, 120, 120] [1, 1, 1] : memref<120x120x120xi8> to memref<1x120x120xi8, strided<[14400, 120, 1]>>
-    // CHECK-NEXT: memref.copy [[valloc_8]], [[vsubview_9]] : memref<1x120x120xi8> to memref<1x120x120xi8, strided<[14400, 120, 1]>>
-    // CHECK-NEXT: memref.dealloc [[valloc_8]] : memref<1x120x120xi8>
-    // CHECK-NEXT: [[valloc_10:%.*]] = memref.alloc() : memref<2x120x120xi8>
-    // CHECK-NEXT: [[vsubview_11:%.*]] = memref.subview [[v0]][1, 0, 0] [2, 120, 120] [1, 1, 1] : memref<120x120x120xi8> to memref<2x120x120xi8, strided<[14400, 120, 1], offset: 14400>>
-    // CHECK-NEXT: memref.copy [[vsubview_11]], [[valloc_10]] : memref<2x120x120xi8, strided<[14400, 120, 1], offset: 14400>> to memref<2x120x120xi8>
-    // CHECK-NEXT: mpi.send([[valloc_10]], [[vc91_i32]], [[vc23_i32]], [[v3]]) : memref<2x120x120xi8>, i32, i32
-    // CHECK-NEXT: memref.dealloc [[valloc_10]] : memref<2x120x120xi8>
-    // CHECK-NEXT: [[v4:%.*]] = bufferization.to_tensor [[v0]] restrict writable : memref<120x120x120xi8> to tensor<120x120x120xi8>
+    // CHECK-DAG: [[vc23_i32:%.*]] = arith.constant 23 : i32
+    // CHECK-DAG: [[vc29_i32:%.*]] = arith.constant 29 : i32
+    // CHECK-DAG: [[vc44_i32:%.*]] = arith.constant 44 : i32
+    // CHECK-DAG: [[vc4_i32:%.*]] = arith.constant 4 : i32
+    // CHECK-DAG: [[vc91_i32:%.*]] = arith.constant 91 : i32
+    // CHECK: [[v0:%.*]] = bufferization.to_memref [[varg0]] : tensor<120x120x120xi8> to memref<120x120x120xi8>
+    // CHECK: [[v1:%.*]] = mpi.comm_world : !mpi.comm
+    // CHECK: [[valloc:%.*]] = memref.alloc() : memref<117x113x5xi8>
+    // CHECK: [[vsubview:%.*]] = memref.subview [[v0]][1, 3, 109] [117, 113, 5] [1, 1, 1] : memref<120x120x120xi8> to memref<117x113x5xi8, strided<[14400, 120, 1], offset: 14869>>
+    // CHECK: memref.copy [[vsubview]], [[valloc]] : memref<117x113x5xi8, strided<[14400, 120, 1], offset: 14869>> to memref<117x113x5xi8>
+    // CHECK: mpi.send([[valloc]], [[vc91_i32]], [[vc44_i32]], [[v1]]) : memref<117x113x5xi8>, i32, i32
+    // CHECK: mpi.recv([[valloc]], [[vc91_i32]], [[vc4_i32]], [[v1]]) : memref<117x113x5xi8>, i32, i32
+    // CHECK: [[vsubview_0:%.*]] = memref.subview [[v0]][1, 3, 0] [117, 113, 5] [1, 1, 1] : memref<120x120x120xi8> to memref<117x113x5xi8, strided<[14400, 120, 1], offset: 14760>>
+    // CHECK: memref.copy [[valloc]], [[vsubview_0]] : memref<117x113x5xi8> to memref<117x113x5xi8, strided<[14400, 120, 1], offset: 14760>>
+    // CHECK: memref.dealloc [[valloc]] : memref<117x113x5xi8>
+    // CHECK: [[valloc_1:%.*]] = memref.alloc() : memref<117x113x6xi8>
+    // CHECK: [[vsubview_2:%.*]] = memref.subview [[v0]][1, 3, 5] [117, 113, 6] [1, 1, 1] : memref<120x120x120xi8> to memref<117x113x6xi8, strided<[14400, 120, 1], offset: 14765>>
+    // CHECK: memref.copy [[vsubview_2]], [[valloc_1]] : memref<117x113x6xi8, strided<[14400, 120, 1], offset: 14765>> to memref<117x113x6xi8>
+    // CHECK: mpi.send([[valloc_1]], [[vc91_i32]], [[vc4_i32]], [[v1]]) : memref<117x113x6xi8>, i32, i32
+    // CHECK: mpi.recv([[valloc_1]], [[vc91_i32]], [[vc44_i32]], [[v1]]) : memref<117x113x6xi8>, i32, i32
+    // CHECK: [[vsubview_3:%.*]] = memref.subview [[v0]][1, 3, 114] [117, 113, 6] [1, 1, 1] : memref<120x120x120xi8> to memref<117x113x6xi8, strided<[14400, 120, 1], offset: 14874>>
+    // CHECK: memref.copy [[valloc_1]], [[vsubview_3]] : memref<117x113x6xi8> to memref<117x113x6xi8, strided<[14400, 120, 1], offset: 14874>>
+    // CHECK: memref.dealloc [[valloc_1]] : memref<117x113x6xi8>
+    // CHECK: [[v2:%.*]] = mpi.comm_world : !mpi.comm
+    // CHECK: [[valloc_4:%.*]] = memref.alloc() : memref<117x3x120xi8>
+    // CHECK: [[vsubview_5:%.*]] = memref.subview [[v0]][1, 113, 0] [117, 3, 120] [1, 1, 1] : memref<120x120x120xi8> to memref<117x3x120xi8, strided<[14400, 120, 1], offset: 27960>>
+    // CHECK: memref.copy [[vsubview_5]], [[valloc_4]] : memref<117x3x120xi8, strided<[14400, 120, 1], offset: 27960>> to memref<117x3x120xi8>
+    // CHECK: mpi.send([[valloc_4]], [[vc91_i32]], [[vc29_i32]], [[v2]]) : memref<117x3x120xi8>, i32, i32
+    // CHECK: memref.dealloc [[valloc_4]] : memref<117x3x120xi8>
+    // CHECK: [[valloc_6:%.*]] = memref.alloc() : memref<117x4x120xi8>
+    // CHECK: mpi.recv([[valloc_6]], [[vc91_i32]], [[vc29_i32]], [[v2]]) : memref<117x4x120xi8>, i32, i32
+    // CHECK: [[vsubview_7:%.*]] = memref.subview [[v0]][1, 116, 0] [117, 4, 120] [1, 1, 1] : memref<120x120x120xi8> to memref<117x4x120xi8, strided<[14400, 120, 1], offset: 28320>>
+    // CHECK: memref.copy [[valloc_6]], [[vsubview_7]] : memref<117x4x120xi8> to memref<117x4x120xi8, strided<[14400, 120, 1], offset: 28320>>
+    // CHECK: memref.dealloc [[valloc_6]] : memref<117x4x120xi8>
+    // CHECK: [[v3:%.*]] = mpi.comm_world : !mpi.comm
+    // CHECK: [[valloc_8:%.*]] = memref.alloc() : memref<1x120x120xi8>
+    // CHECK: mpi.recv([[valloc_8]], [[vc91_i32]], [[vc23_i32]], [[v3]]) : memref<1x120x120xi8>, i32, i32
+    // CHECK: [[vsubview_9:%.*]] = memref.subview [[v0]][0, 0, 0] [1, 120, 120] [1, 1, 1] : memref<120x120x120xi8> to memref<1x120x120xi8, strided<[14400, 120, 1]>>
+    // CHECK: memref.copy [[valloc_8]], [[vsubview_9]] : memref<1x120x120xi8> to memref<1x120x120xi8, strided<[14400, 120, 1]>>
+    // CHECK: memref.dealloc [[valloc_8]] : memref<1x120x120xi8>
+    // CHECK: [[valloc_10:%.*]] = memref.alloc() : memref<2x120x120xi8>
+    // CHECK: [[vsubview_11:%.*]] = memref.subview [[v0]][1, 0, 0] [2, 120, 120] [1, 1, 1] : memref<120x120x120xi8> to memref<2x120x120xi8, strided<[14400, 120, 1], offset: 14400>>
+    // CHECK: memref.copy [[vsubview_11]], [[valloc_10]] : memref<2x120x120xi8, strided<[14400, 120, 1], offset: 14400>> to memref<2x120x120xi8>
+    // CHECK: mpi.send([[valloc_10]], [[vc91_i32]], [[vc23_i32]], [[v3]]) : memref<2x120x120xi8>, i32, i32
+    // CHECK: memref.dealloc [[valloc_10]] : memref<2x120x120xi8>
+    // CHECK: [[v4:%.*]] = bufferization.to_tensor [[v0]] restrict writable : memref<120x120x120xi8> to tensor<120x120x120xi8>
     %res = mesh.update_halo %arg0 on @mesh0 split_axes = [[2], [1], [0]] halo_sizes = [1, 2, 3, 4, 5, 6] : tensor<120x120x120xi8>
-    // CHECK-NEXT: return [[v4]] : tensor<120x120x120xi8>
+    // CHECK: return [[v4]] : tensor<120x120x120xi8>
     return %res : tensor<120x120x120xi8>
   }
 }
@@ -246,19 +303,19 @@ mesh.mesh @mesh0(shape = 2x2x4)
 // CHECK-SAME: [[varg0:%.*]]: tensor<2x4xf32>) -> (tensor<2x4xf32>, tensor<?x?xi16>, tensor<?x?xi64>, tensor<?x?xi64>) {
 func.func @return_sharding(%arg0: tensor<2x4xf32>) -> (tensor<2x4xf32>, !mesh.sharding) {
   %sharding = mesh.sharding @mesh0 split_axes = [[0, 1], [2]] : !mesh.sharding
-  // CHECK-NEXT: [[vcst:%.*]] = arith.constant dense<2> : tensor<1xi16>
-  // CHECK-NEXT: [[vcst_0:%.*]] = arith.constant dense<[0, 1]> : tensor<2xi16>
-  // CHECK-NEXT: [[vcm1_i16:%.*]] = arith.constant -1 : i16
-  // CHECK-NEXT: [[v0:%.*]] = tensor.empty() : tensor<2x2xi16>
-  // CHECK-NEXT: [[v1:%.*]] = linalg.fill ins([[vcm1_i16]] : i16) outs([[v0]] : tensor<2x2xi16>) -> tensor<2x2xi16>
-  // CHECK-NEXT: [[vinserted_slice:%.*]] = tensor.insert_slice [[vcst_0]] into [[v1]][0, 0] [1, 2] [1, 1] : tensor<2xi16> into tensor<2x2xi16>
-  // CHECK-NEXT: [[vinserted_slice_1:%.*]] = tensor.insert_slice [[vcst]] into [[vinserted_slice]][1, 0] [1, 1] [1, 1] : tensor<1xi16> into tensor<2x2xi16>
-  // CHECK-NEXT: [[v2:%.*]] = tensor.empty() : tensor<0x0xi64>
-  // CHECK-NEXT: [[v3:%.*]] = tensor.empty() : tensor<0x0xi64>
-  // CHECK-NEXT: [[vcast:%.*]] = tensor.cast [[vinserted_slice_1]] : tensor<2x2xi16> to tensor<?x?xi16>
-  // CHECK-NEXT: [[vcast_2:%.*]] = tensor.cast [[v2]] : tensor<0x0xi64> to tensor<?x?xi64>
-  // CHECK-NEXT: [[vcast_3:%.*]] = tensor.cast [[v3]] : tensor<0x0xi64> to tensor<?x?xi64>
-  // CHECK-NEXT: return [[varg0]], [[vcast]], [[vcast_2]], [[vcast_3]] : tensor<2x4xf32>, tensor<?x?xi16>, tensor<?x?xi64>, tensor<?x?xi64>
+  // CHECK: [[vcst:%.*]] = arith.constant dense<2> : tensor<1xi16>
+  // CHECK: [[vcst_0:%.*]] = arith.constant dense<[0, 1]> : tensor<2xi16>
+  // CHECK: [[vcm1_i16:%.*]] = arith.constant -1 : i16
+  // CHECK: [[v0:%.*]] = tensor.empty() : tensor<2x2xi16>
+  // CHECK: [[v1:%.*]] = linalg.fill ins([[vcm1_i16]] : i16) outs([[v0]] : tensor<2x2xi16>) -> tensor<2x2xi16>
+  // CHECK: [[vinserted_slice:%.*]] = tensor.insert_slice [[vcst_0]] into [[v1]][0, 0] [1, 2] [1, 1] : tensor<2xi16> into tensor<2x2xi16>
+  // CHECK: [[vinserted_slice_1:%.*]] = tensor.insert_slice [[vcst]] into [[vinserted_slice]][1, 0] [1, 1] [1, 1] : tensor<1xi16> into tensor<2x2xi16>
+  // CHECK: [[v2:%.*]] = tensor.empty() : tensor<0x0xi64>
+  // CHECK: [[v3:%.*]] = tensor.empty() : tensor<0x0xi64>
+  // CHECK: [[vcast:%.*]] = tensor.cast [[vinserted_slice_1]] : tensor<2x2xi16> to tensor<?x?xi16>
+  // CHECK: [[vcast_2:%.*]] = tensor.cast [[v2]] : tensor<0x0xi64> to tensor<?x?xi64>
+  // CHECK: [[vcast_3:%.*]] = tensor.cast [[v3]] : tensor<0x0xi64> to tensor<?x?xi64>
+  // CHECK: return [[varg0]], [[vcast]], [[vcast_2]], [[vcast_3]] : tensor<2x4xf32>, tensor<?x?xi16>, tensor<?x?xi64>, tensor<?x?xi64>
   return %arg0, %sharding : tensor<2x4xf32>, !mesh.sharding
 }
 
@@ -266,19 +323,19 @@ func.func @return_sharding(%arg0: tensor<2x4xf32>) -> (tensor<2x4xf32>, !mesh.sh
 // CHECK-SAME: [[varg0:%.*]]: tensor<6x8xf32>) -> (tensor<6x8xf32>, tensor<?x?xi16>, tensor<?x?xi64>, tensor<?x?xi64>) {
 func.func @return_sharding_halos(%arg0: tensor<6x8xf32>) -> (tensor<6x8xf32>, !mesh.sharding) {
   %sharding = mesh.sharding @mesh0 split_axes = [[0, 1], [2]] halo_sizes = [0, 4, 3, 1] : !mesh.sharding
-  // CHECK-NEXT: [[vcst:%.*]] = arith.constant dense<{{\[\[}}0, 4], [3, 1]]> : tensor<2x2xi64>
-  // CHECK-NEXT: [[vcst_0:%.*]] = arith.constant dense<2> : tensor<1xi16>
-  // CHECK-NEXT: [[vcst_1:%.*]] = arith.constant dense<[0, 1]> : tensor<2xi16>
-  // CHECK-NEXT: [[vcm1_i16:%.*]] = arith.constant -1 : i16
-  // CHECK-NEXT: [[v0:%.*]] = tensor.empty() : tensor<2x2xi16>
-  // CHECK-NEXT: [[v1:%.*]] = linalg.fill ins([[vcm1_i16]] : i16) outs([[v0]] : tensor<2x2xi16>) -> tensor<2x2xi16>
-  // CHECK-NEXT: [[vinserted_slice:%.*]] = tensor.insert_slice [[vcst_1]] into [[v1]][0, 0] [1, 2] [1, 1] : tensor<2xi16> into tensor<2x2xi16>
-  // CHECK-NEXT: [[vinserted_slice_2:%.*]] = tensor.insert_slice [[vcst_0]] into [[vinserted_slice]][1, 0] [1, 1] [1, 1] : tensor<1xi16> into tensor<2x2xi16>
-  // CHECK-NEXT: [[v2:%.*]] = tensor.empty() : tensor<0x0xi64>
-  // CHECK-NEXT: [[vcast:%.*]] = tensor.cast [[vinserted_slice_2]] : tensor<2x2xi16> to tensor<?x?xi16>
-  // CHECK-NEXT: [[vcast_3:%.*]] = tensor.cast [[vcst]] : tensor<2x2xi64> to tensor<?x?xi64>
-  // CHECK-NEXT: [[vcast_4:%.*]] = tensor.cast [[v2]] : tensor<0x0xi64> to tensor<?x?xi64>
-  // CHECK-NEXT: return [[varg0]], [[vcast]], [[vcast_3]], [[vcast_4]] : tensor<6x8xf32>, tensor<?x?xi16>, tensor<?x?xi64>, tensor<?x?xi64>
+  // CHECK: [[vcst:%.*]] = arith.constant dense<{{\[\[}}0, 4], [3, 1]]> : tensor<2x2xi64>
+  // CHECK: [[vcst_0:%.*]] = arith.constant dense<2> : tensor<1xi16>
+  // CHECK: [[vcst_1:%.*]] = arith.constant dense<[0, 1]> : tensor<2xi16>
+  // CHECK: [[vcm1_i16:%.*]] = arith.constant -1 : i16
+  // CHECK: [[v0:%.*]] = tensor.empty() : tensor<2x2xi16>
+  // CHECK: [[v1:%.*]] = linalg.fill ins([[vcm1_i16]] : i16) outs([[v0]] : tensor<2x2xi16>) -> tensor<2x2xi16>
+  // CHECK: [[vinserted_slice:%.*]] = tensor.insert_slice [[vcst_1]] into [[v1]][0, 0] [1, 2] [1, 1] : tensor<2xi16> into tensor<2x2xi16>
+  // CHECK: [[vinserted_slice_2:%.*]] = tensor.insert_slice [[vcst_0]] into [[vinserted_slice]][1, 0] [1, 1] [1, 1] : tensor<1xi16> into tensor<2x2xi16>
+  // CHECK: [[v2:%.*]] = tensor.empty() : tensor<0x0xi64>
+  // CHECK: [[vcast:%.*]] = tensor.cast [[vinserted_slice_2]] : tensor<2x2xi16> to tensor<?x?xi16>
+  // CHECK: [[vcast_3:%.*]] = tensor.cast [[vcst]] : tensor<2x2xi64> to tensor<?x?xi64>
+  // CHECK: [[vcast_4:%.*]] = tensor.cast [[v2]] : tensor<0x0xi64> to tensor<?x?xi64>
+  // CHECK: return [[varg0]], [[vcast]], [[vcast_3]], [[vcast_4]] : tensor<6x8xf32>, tensor<?x?xi16>, tensor<?x?xi64>, tensor<?x?xi64>
   return %arg0, %sharding : tensor<6x8xf32>, !mesh.sharding
 }
 
@@ -286,24 +343,24 @@ func.func @return_sharding_halos(%arg0: tensor<6x8xf32>) -> (tensor<6x8xf32>, !m
 // CHECK-SAME: [[varg0:%.*]]: tensor<?x?xf32>) -> (tensor<?x?xf32>, tensor<?x?xi16>, tensor<?x?xi64>, tensor<?x?xi64>) {
 func.func @return_sharding_offs(%arg0: tensor<?x?xf32>) -> (tensor<?x?xf32>, !mesh.sharding) {
   %sharding = mesh.sharding @mesh0 split_axes = [[0, 1], [2]] sharded_dims_offsets = [0, 3, 5, 7, 8, 0, 0, 5, 10, 16] : !mesh.sharding
-  // CHECK-NEXT: [[vcst:%.*]] = arith.constant dense<[0, 0, 5, 10, 16]> : tensor<5xi64>
-  // CHECK-NEXT: [[vcst_0:%.*]] = arith.constant dense<[0, 3, 5, 7, 8]> : tensor<5xi64>
-  // CHECK-NEXT: [[vcm9223372036854775808_i64:%.*]] = arith.constant -9223372036854775808 : i64
-  // CHECK-NEXT: [[vcst_1:%.*]] = arith.constant dense<2> : tensor<1xi16>
-  // CHECK-NEXT: [[vcst_2:%.*]] = arith.constant dense<[0, 1]> : tensor<2xi16>
-  // CHECK-NEXT: [[vcm1_i16:%.*]] = arith.constant -1 : i16
-  // CHECK-NEXT: [[v0:%.*]] = tensor.empty() : tensor<2x2xi16>
-  // CHECK-NEXT: [[v1:%.*]] = linalg.fill ins([[vcm1_i16]] : i16) outs([[v0]] : tensor<2x2xi16>) -> tensor<2x2xi16>
-  // CHECK-NEXT: [[vinserted_slice:%.*]] = tensor.insert_slice [[vcst_2]] into [[v1]][0, 0] [1, 2] [1, 1] : tensor<2xi16> into tensor<2x2xi16>
-  // CHECK-NEXT: [[vinserted_slice_3:%.*]] = tensor.insert_slice [[vcst_1]] into [[vinserted_slice]][1, 0] [1, 1] [1, 1] : tensor<1xi16> into tensor<2x2xi16>
-  // CHECK-NEXT: [[v2:%.*]] = tensor.empty() : tensor<0x0xi64>
-  // CHECK-NEXT: [[v3:%.*]] = tensor.empty() : tensor<2x5xi64>
-  // CHECK-NEXT: [[v4:%.*]] = linalg.fill ins([[vcm9223372036854775808_i64]] : i64) outs([[v3]] : tensor<2x5xi64>) -> tensor<2x5xi64>
-  // CHECK-NEXT: [[vinserted_slice_4:%.*]] = tensor.insert_slice [[vcst_0]] into [[v4]][0, 0] [1, 5] [1, 1] : tensor<5xi64> into tensor<2x5xi64>
-  // CHECK-NEXT: [[vinserted_slice_5:%.*]] = tensor.insert_slice [[vcst]] into [[vinserted_slice_4]][1, 0] [1, 5] [1, 1] : tensor<5xi64> into tensor<2x5xi64>
-  // CHECK-NEXT: [[vcast:%.*]] = tensor.cast [[vinserted_slice_3]] : tensor<2x2xi16> to tensor<?x?xi16>
-  // CHECK-NEXT: [[vcast_6:%.*]] = tensor.cast [[v2]] : tensor<0x0xi64> to tensor<?x?xi64>
-  // CHECK-NEXT: [[vcast_7:%.*]] = tensor.cast [[vinserted_slice_5]] : tensor<2x5xi64> to tensor<?x?xi64>
-  // CHECK-NEXT: return [[varg0]], [[vcast]], [[vcast_6]], [[vcast_7]] : tensor<?x?xf32>, tensor<?x?xi16>, tensor<?x?xi64>, tensor<?x?xi64>
+  // CHECK: [[vcst:%.*]] = arith.constant dense<[0, 0, 5, 10, 16]> : tensor<5xi64>
+  // CHECK: [[vcst_0:%.*]] = arith.constant dense<[0, 3, 5, 7, 8]> : tensor<5xi64>
+  // CHECK: [[vcm9223372036854775808_i64:%.*]] = arith.constant -9223372036854775808 : i64
+  // CHECK: [[vcst_1:%.*]] = arith.constant dense<2> : tensor<1xi16>
+  // CHECK: [[vcst_2:%.*]] = arith.constant dense<[0, 1]> : tensor<2xi16>
+  // CHECK: [[vcm1_i16:%.*]] = arith.constant -1 : i16
+  // CHECK: [[v0:%.*]] = tensor.empty() : tensor<2x2xi16>
+  // CHECK: [[v1:%.*]] = linalg.fill ins([[vcm1_i16]] : i16) outs([[v0]] : tensor<2x2xi16>) -> tensor<2x2xi16>
+  // CHECK: [[vinserted_slice:%.*]] = tensor.insert_slice [[vcst_2]] into [[v1]][0, 0] [1, 2] [1, 1] : tensor<2xi16> into tensor<2x2xi16>
+  // CHECK: [[vinserted_slice_3:%.*]] = tensor.insert_slice [[vcst_1]] into [[vinserted_slice]][1, 0] [1, 1] [1, 1] : tensor<1xi16> into tensor<2x2xi16>
+  // CHECK: [[v2:%.*]] = tensor.empty() : tensor<0x0xi64>
+  // CHECK: [[v3:%.*]] = tensor.empty() : tensor<2x5xi64>
+  // CHECK: [[v4:%.*]] = linalg.fill ins([[vcm9223372036854775808_i64]] : i64) outs([[v3]] : tensor<2x5xi64>) -> tensor<2x5xi64>
+  // CHECK: [[vinserted_slice_4:%.*]] = tensor.insert_slice [[vcst_0]] into [[v4]][0, 0] [1, 5] [1, 1] : tensor<5xi64> into tensor<2x5xi64>
+  // CHECK: [[vinserted_slice_5:%.*]] = tensor.insert_slice [[vcst]] into [[vinserted_slice_4]][1, 0] [1, 5] [1, 1] : tensor<5xi64> into tensor<2x5xi64>
+  // CHECK: [[vcast:%.*]] = tensor.cast [[vinserted_slice_3]] : tensor<2x2xi16> to tensor<?x?xi16>
+  // CHECK: [[vcast_6:%.*]] = tensor.cast [[v2]] : tensor<0x0xi64> to tensor<?x?xi64>
+  // CHECK: [[vcast_7:%.*]] = tensor.cast [[vinserted_slice_5]] : tensor<2x5xi64> to tensor<?x?xi64>
+  // CHECK: return [[varg0]], [[vcast]], [[vcast_6]], [[vcast_7]] : tensor<?x?xf32>, tensor<?x?xi16>, tensor<?x?xi64>, tensor<?x?xi64>
   return %arg0, %sharding : tensor<?x?xf32>, !mesh.sharding
 }
diff --git a/mlir/test/Dialect/Mesh/spmdization.mlir b/mlir/test/Dialect/Mesh/spmdization.mlir
index 5c9fd29444f04..f928f53313e1e 100644
--- a/mlir/test/Dialect/Mesh/spmdization.mlir
+++ b/mlir/test/Dialect/Mesh/spmdization.mlir
@@ -241,6 +241,36 @@ func.func @ew_chain_with_halo(
   return %sharding_annotated_6 : tensor<8x16xf32>
 }
 
+// CHECK-LABEL: func @ew_chain_with_halo_odd
+func.func @ew_chain_with_halo_odd(
+  // CHECK-SAME: %[[IN1:[A-Za-z0-9_]+]]: tensor<?x16xf32>
+  %arg0: tensor<11x16xf32>)
+  // CHECK-SAME: -> tensor<?x16xf32>
+   -> tensor<11x16xf32> {
+  %ssharding_annotated = mesh.sharding @mesh_1d_4 split_axes = [[0]] halo_sizes = [2, 1] : !mesh.sharding
+  %sharding_annotated = mesh.shard %arg0 to %ssharding_annotated  annotate_for_users : tensor<11x16xf32>
+  // CHECK: %[[TMP1:.*]] = tosa.tanh %[[IN1]] : (tensor<?x16xf32>) -> tensor<?x16xf32>
+  %0 = tosa.tanh %sharding_annotated : (tensor<11x16xf32>) -> tensor<11x16xf32>
+  %ssharding_annotated_0 = mesh.sharding @mesh_1d_4 split_axes = [[0]] halo_sizes = [2, 1] : !mesh.sharding
+  %sharding_annotated_0 = mesh.shard %0 to %ssharding_annotated_0  : tensor<11x16xf32>
+  %ssharding_annotated_1 = mesh.sharding @mesh_1d_4 split_axes = [[0]] halo_sizes = [2, 1] : !mesh.sharding
+  %sharding_annotated_1 = mesh.shard %sharding_annotated_0 to %ssharding_annotated_1  annotate_for_users : tensor<11x16xf32>
+  // CHECK-NEXT: %[[TMP2:.*]] = tosa.abs %[[TMP1]] : (tensor<?x16xf32>) -> tensor<?x16xf32>
+  %1 = tosa.abs %sharding_annotated_1 : (tensor<11x16xf32>) -> tensor<11x16xf32>
+  %ssharding_annotated_2 = mesh.sharding @mesh_1d_4 split_axes = [[0]] halo_sizes = [2, 1] : !mesh.sharding
+  %sharding_annotated_2 = mesh.shard %1 to %ssharding_annotated_2  : tensor<11x16xf32>
+  %ssharding_annotated_4 = mesh.sharding @mesh_1d_4 split_axes = [[0]] halo_sizes = [2, 1] : !mesh.sharding
+  %sharding_annotated_4 = mesh.shard %sharding_annotated_2 to %ssharding_annotated_4  annotate_for_users : tensor<11x16xf32>
+  // CHECK-NEXT: %[[TMP3:.*]] = tosa.ceil %[[TMP2]] : (tensor<?x16xf32>) -> tensor<?x16xf32>
+  %2 = tosa.ceil %sharding_annotated_4 : (tensor<11x16xf32>) -> tensor<11x16xf32>
+  %ssharding_annotated_5 = mesh.sharding @mesh_1d_4 split_axes = [[0]] halo_sizes = [2, 1] : !mesh.sharding
+  %sharding_annotated_5 = mesh.shard %2 to %ssharding_annotated_5  : tensor<11x16xf32>
+  %ssharding_annotated_6 = mesh.sharding @mesh_1d_4 split_axes = [[0]] halo_sizes = [2, 1] : !mesh.sharding
+  %sharding_annotated_6 = mesh.shard %sharding_annotated_5 to %ssharding_annotated_6  annotate_for_users : tensor<11x16xf32>
+  // CHECK-NEXT: return %[[TMP3]] : tensor<?x16xf32>
+  return %sharding_annotated_6 : tensor<11x16xf32>
+}
+
 // CHECK-LABEL: func @test_shard_update_halo
 // CHECK-SAME: %[[IN1:[A-Za-z0-9_]+]]: tensor<300x1200xi64>
 func.func @test_shard_update_halo(%arg0: tensor<1200x1200xi64>) -> tensor<1200x1200xi64> {

>From b07686ca25fec076dc5fff49c8829cc55ac718d7 Mon Sep 17 00:00:00 2001
From: "Schlimbach, Frank" <frank.schlimbach at intel.com>
Date: Wed, 2 Apr 2025 18:14:04 +0200
Subject: [PATCH 2/2] fixing in-place mesh.all_reduce

---
 mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td    |  4 ++--
 mlir/lib/Conversion/MeshToMPI/MeshToMPI.cpp     | 13 ++++++++-----
 mlir/lib/Dialect/Mesh/Transforms/Transforms.cpp |  6 +++++-
 3 files changed, 15 insertions(+), 8 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td
index ef57751d7e179..ac05ee243d7be 100644
--- a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td
+++ b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td
@@ -584,11 +584,11 @@ def Mesh_AllReduceOp : Mesh_CollectiveCommunicationOpBase<"all_reduce", [
     ```
   }];
   let arguments = !con(commonArgs, (ins
-    AnyTypeOf<[AnyNon0RankedMemRef, AnyNon0RankedTensor]>:$input,
+    AnyTypeOf<[AnyMemRef, AnyRankedTensor]>:$input,
     DefaultValuedAttr<Mesh_ReductionKindAttr, "::mlir::mesh::ReductionKind::Sum">:$reduction
   ));
   let results = (outs
-    AnyTypeOf<[AnyNon0RankedMemRef, AnyNon0RankedTensor]>:$result
+    AnyTypeOf<[AnyMemRef, AnyRankedTensor]>:$result
   );
   let assemblyFormat = [{
     $input `on` $mesh (`mesh_axes` `=` $mesh_axes^)? (`reduction` `=` $reduction^)?
diff --git a/mlir/lib/Conversion/MeshToMPI/MeshToMPI.cpp b/mlir/lib/Conversion/MeshToMPI/MeshToMPI.cpp
index 65469261d0957..521569e69b61a 100644
--- a/mlir/lib/Conversion/MeshToMPI/MeshToMPI.cpp
+++ b/mlir/lib/Conversion/MeshToMPI/MeshToMPI.cpp
@@ -614,11 +614,14 @@ struct ConvertAllReduceOp : public OpConversionPattern<AllReduceOp> {
         iBuilder.create<mpi::CommSplitOp>(commType, commWorld, color, key)
             .getNewcomm();
 
-    // collapse shape to 1d
-    ReassociationIndices reassociation(inType.getRank());
-    std::iota(reassociation.begin(), reassociation.end(), 0);
-    Value buffer1d = iBuilder.create<memref::CollapseShapeOp>(
-        buffer, ArrayRef<ReassociationIndices>(reassociation));
+    Value buffer1d = buffer;
+    // Collapse shape to 1d if needed
+    if (inType.getRank() > 1) {
+      ReassociationIndices reassociation(inType.getRank());
+      std::iota(reassociation.begin(), reassociation.end(), 0);
+      buffer1d = iBuilder.create<memref::CollapseShapeOp>(
+          buffer, ArrayRef<ReassociationIndices>(reassociation));
+    }
 
     // Create the MPI AllReduce operation.
     iBuilder.create<mpi::AllReduceOp>(
diff --git a/mlir/lib/Dialect/Mesh/Transforms/Transforms.cpp b/mlir/lib/Dialect/Mesh/Transforms/Transforms.cpp
index 228c56fdf4db8..244b52a0c1b5c 100644
--- a/mlir/lib/Dialect/Mesh/Transforms/Transforms.cpp
+++ b/mlir/lib/Dialect/Mesh/Transforms/Transforms.cpp
@@ -216,7 +216,11 @@ createProcessLinearIndex(StringRef mesh, ValueRange processInGroupMultiIndex,
   OpFoldResult processInGroupLinearIndex = affine::linearizeIndex(
       llvm::to_vector_of<OpFoldResult>(processInGroupMultiIndex),
       llvm::to_vector_of<OpFoldResult>(processGroupShape), builder);
-  return cast<TypedValue<IndexType>>(cast<Value>(processInGroupLinearIndex));
+  Value res = dyn_cast<Value>(processInGroupLinearIndex);
+  if (!res)
+    res = builder.create<arith::ConstantIndexOp>(
+        cast<IntegerAttr>(cast<Attribute>(processInGroupLinearIndex)).getInt());
+  return cast<TypedValue<IndexType>>(res);
 }
 
 TypedValue<IndexType> createProcessLinearIndex(StringRef mesh,



More information about the Mlir-commits mailing list