[Mlir-commits] [mlir] [mlir][shard, mpi] Fixing lowering allgather shard->mpi->llvm (PR #178870)

Frank Schlimbach llvmlistbot at llvm.org
Fri Jan 30 09:14:12 PST 2026


https://github.com/fschlimb updated https://github.com/llvm/llvm-project/pull/178870

>From d743121f2dc16d5f1fca00c7d7129a49fa3a993e Mon Sep 17 00:00:00 2001
From: "Schlimbach, Frank" <frank.schlimbach at intel.com>
Date: Fri, 30 Jan 2026 01:36:35 -0800
Subject: [PATCH 1/4] Add support for gather_axis != 0 in shard.all_gather

---
 mlir/include/mlir/Dialect/MPI/IR/MPIOps.td    |  1 +
 mlir/lib/Conversion/MPIToLLVM/MPIToLLVM.cpp   | 38 +++++---
 mlir/lib/Conversion/ShardToMPI/ShardToMPI.cpp | 96 +++++++++++++++++--
 mlir/lib/Dialect/MPI/IR/MPIOps.cpp            | 58 +++++++----
 mlir/test/Conversion/MPIToLLVM/mpitollvm.mlir | 42 +++++---
 .../ShardToMPI/convert-shard-to-mpi.mlir      | 42 ++++++--
 mlir/test/Dialect/MPI/canonicalize.mlir       | 14 +++
 7 files changed, 230 insertions(+), 61 deletions(-)
 create mode 100644 mlir/test/Dialect/MPI/canonicalize.mlir

diff --git a/mlir/include/mlir/Dialect/MPI/IR/MPIOps.td b/mlir/include/mlir/Dialect/MPI/IR/MPIOps.td
index b24c12459475e..d9e47ea3f6bfe 100644
--- a/mlir/include/mlir/Dialect/MPI/IR/MPIOps.td
+++ b/mlir/include/mlir/Dialect/MPI/IR/MPIOps.td
@@ -96,6 +96,7 @@ def MPI_CommSizeOp : MPI_Op<"comm_size", [Pure]> {
   );
 
   let assemblyFormat = "`(` $comm `)` attr-dict `:` type(results)";
+  let hasCanonicalizer = 1;
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Conversion/MPIToLLVM/MPIToLLVM.cpp b/mlir/lib/Conversion/MPIToLLVM/MPIToLLVM.cpp
index 4a1c5d1f7846c..9ddef319ad71c 100644
--- a/mlir/lib/Conversion/MPIToLLVM/MPIToLLVM.cpp
+++ b/mlir/lib/Conversion/MPIToLLVM/MPIToLLVM.cpp
@@ -51,7 +51,8 @@ static LLVM::LLVMFuncOp getOrDefineFunction(ModuleOp &moduleOp,
 
 std::pair<Value, Value> getRawPtrAndSize(const Location loc,
                                          ConversionPatternRewriter &rewriter,
-                                         Value memRef, Type elType) {
+                                         Value memRef, int64_t rank,
+                                         Type elType) {
   Type ptrType = LLVM::LLVMPointerType::get(rewriter.getContext());
   Value dataPtr =
       LLVM::ExtractValueOp::create(rewriter, loc, ptrType, memRef, 1);
@@ -59,11 +60,16 @@ std::pair<Value, Value> getRawPtrAndSize(const Location loc,
                                               rewriter.getI64Type(), memRef, 2);
   Value resPtr =
       LLVM::GEPOp::create(rewriter, loc, ptrType, elType, dataPtr, offset);
-  Value size;
+  Value size = LLVM::ConstantOp::create(rewriter, loc, rewriter.getI32Type(),
+                                        rewriter.getIndexAttr(1));
   if (cast<LLVM::LLVMStructType>(memRef.getType()).getBody().size() > 3) {
-    size = LLVM::ExtractValueOp::create(rewriter, loc, memRef,
-                                        ArrayRef<int64_t>{3, 0});
-    size = LLVM::TruncOp::create(rewriter, loc, rewriter.getI32Type(), size);
+    for (int64_t i = 0; i < rank; ++i) {
+      Value dim = LLVM::ExtractValueOp::create(rewriter, loc, memRef,
+                                               ArrayRef<int64_t>{3, i});
+      dim = LLVM::TruncOp::create(rewriter, loc, rewriter.getI32Type(), dim);
+      size =
+          LLVM::MulOp::create(rewriter, loc, rewriter.getI32Type(), dim, size);
+    }
   } else {
     size = arith::ConstantIntOp::create(rewriter, loc, 1, 32);
   }
@@ -396,7 +402,7 @@ class OMPIImplTraits : public MPIImplTraits {
 };
 
 std::unique_ptr<MPIImplTraits> MPIImplTraits::get(ModuleOp &moduleOp) {
-  auto attr = dlti::query(*&moduleOp, {"MPI:Implementation"}, true);
+  auto attr = dlti::query(*&moduleOp, {"MPI:Implementation"}, false);
   if (failed(attr))
     return std::make_unique<MPICHImplTraits>(moduleOp);
   auto strAttr = dyn_cast<StringAttr>(attr.value());
@@ -634,7 +640,7 @@ struct CommSizeOpLowering : public ConvertOpToLLVMPattern<mpi::CommSizeOp> {
         LLVM::LLVMFunctionType::get(i32, {comm.getType(), ptrType});
     // get or create function declaration:
     LLVM::LLVMFuncOp initDecl = getOrDefineFunction(
-        moduleOp, loc, rewriter, "MPI_Comm_Size", SizeFuncType);
+        moduleOp, loc, rewriter, "MPI_Comm_size", SizeFuncType);
 
     // replace with function call
     auto one = LLVM::ConstantOp::create(rewriter, loc, i32, 1);
@@ -675,6 +681,7 @@ struct SendOpLowering : public ConvertOpToLLVMPattern<mpi::SendOp> {
     MLIRContext *context = rewriter.getContext();
     Type i32 = rewriter.getI32Type();
     Type elemType = op.getRef().getType().getElementType();
+    int64_t rank = op.getRef().getType().getRank();
 
     // ptrType `!llvm.ptr`
     Type ptrType = LLVM::LLVMPointerType::get(context);
@@ -684,7 +691,7 @@ struct SendOpLowering : public ConvertOpToLLVMPattern<mpi::SendOp> {
 
     // get MPI_COMM_WORLD, dataType and pointer
     auto [dataPtr, size] =
-        getRawPtrAndSize(loc, rewriter, adaptor.getRef(), elemType);
+        getRawPtrAndSize(loc, rewriter, adaptor.getRef(), rank, elemType);
     auto mpiTraits = MPIImplTraits::get(moduleOp);
     Value dataType = mpiTraits->getDataType(loc, rewriter, elemType);
     Value comm = mpiTraits->castComm(loc, rewriter, adaptor.getComm());
@@ -727,6 +734,7 @@ struct RecvOpLowering : public ConvertOpToLLVMPattern<mpi::RecvOp> {
     Type i32 = rewriter.getI32Type();
     Type i64 = rewriter.getI64Type();
     Type elemType = op.getRef().getType().getElementType();
+    int64_t rank = op.getRef().getType().getRank();
 
     // ptrType `!llvm.ptr`
     Type ptrType = LLVM::LLVMPointerType::get(context);
@@ -736,7 +744,7 @@ struct RecvOpLowering : public ConvertOpToLLVMPattern<mpi::RecvOp> {
 
     // get MPI_COMM_WORLD, dataType, status_ignore and pointer
     auto [dataPtr, size] =
-        getRawPtrAndSize(loc, rewriter, adaptor.getRef(), elemType);
+        getRawPtrAndSize(loc, rewriter, adaptor.getRef(), rank, elemType);
     auto mpiTraits = MPIImplTraits::get(moduleOp);
     Value dataType = mpiTraits->getDataType(loc, rewriter, elemType);
     Value comm = mpiTraits->castComm(loc, rewriter, adaptor.getComm());
@@ -782,10 +790,12 @@ struct AllGatherOpLowering : public ConvertOpToLLVMPattern<mpi::AllGatherOp> {
     MLIRContext *context = rewriter.getContext();
     Type sElemType = op.getSendbuf().getType().getElementType();
     Type rElemType = op.getRecvbuf().getType().getElementType();
+    int64_t sRank = op.getSendbuf().getType().getRank();
+    int64_t rRank = op.getRecvbuf().getType().getRank();
     auto [sendPtr, sendSize] =
-        getRawPtrAndSize(loc, rewriter, adaptor.getSendbuf(), sElemType);
+        getRawPtrAndSize(loc, rewriter, adaptor.getSendbuf(), sRank, sElemType);
     auto [recvPtr, recvSize] =
-        getRawPtrAndSize(loc, rewriter, adaptor.getRecvbuf(), rElemType);
+        getRawPtrAndSize(loc, rewriter, adaptor.getRecvbuf(), rRank, rElemType);
 
     auto moduleOp = op->getParentOfType<ModuleOp>();
     auto mpiTraits = MPIImplTraits::get(moduleOp);
@@ -843,15 +853,17 @@ struct AllReduceOpLowering : public ConvertOpToLLVMPattern<mpi::AllReduceOp> {
     Type i32 = rewriter.getI32Type();
     Type i64 = rewriter.getI64Type();
     Type elemType = op.getSendbuf().getType().getElementType();
+    int64_t sRank = op.getSendbuf().getType().getRank();
+    int64_t rRank = op.getRecvbuf().getType().getRank();
 
     // ptrType `!llvm.ptr`
     Type ptrType = LLVM::LLVMPointerType::get(context);
     auto moduleOp = op->getParentOfType<ModuleOp>();
     auto mpiTraits = MPIImplTraits::get(moduleOp);
     auto [sendPtr, sendSize] =
-        getRawPtrAndSize(loc, rewriter, adaptor.getSendbuf(), elemType);
+        getRawPtrAndSize(loc, rewriter, adaptor.getSendbuf(), sRank, elemType);
     auto [recvPtr, recvSize] =
-        getRawPtrAndSize(loc, rewriter, adaptor.getRecvbuf(), elemType);
+        getRawPtrAndSize(loc, rewriter, adaptor.getRecvbuf(), rRank, elemType);
 
     // If input and output are the same, request in-place operation.
     if (adaptor.getSendbuf() == adaptor.getRecvbuf()) {
diff --git a/mlir/lib/Conversion/ShardToMPI/ShardToMPI.cpp b/mlir/lib/Conversion/ShardToMPI/ShardToMPI.cpp
index 87ae28892fcf7..45e07a1ce5fb9 100644
--- a/mlir/lib/Conversion/ShardToMPI/ShardToMPI.cpp
+++ b/mlir/lib/Conversion/ShardToMPI/ShardToMPI.cpp
@@ -30,6 +30,7 @@
 #include "mlir/Dialect/Shard/Transforms/Transforms.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "mlir/Dialect/Utils/StaticValueUtils.h"
+#include "mlir/IR/AffineMap.h"
 #include "mlir/IR/Builders.h"
 #include "mlir/IR/BuiltinAttributes.h"
 #include "mlir/IR/BuiltinTypes.h"
@@ -620,6 +621,13 @@ struct ConvertAllReduceOp : public CommOpPattern<AllReduceOp> {
 struct ConvertAllGatherOp : public CommOpPattern<AllGatherOp> {
   using CommOpPattern::CommOpPattern;
 
+  // shard.allgather always concatenates along a specified gather-axis.
+  // mpi.allgather always concatenates along the first dimension and
+  // there is no MPI operation that allows gathering along an arbitrary axis.
+  // Hence, if gather-axis!=0, we need to create a temporary buffer
+  // where we gather along the first dimension and then copy from that
+  // buffer to the final output along the specified gather-axis.
+
   LogicalResult
   matchAndRewrite(AllGatherOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
@@ -637,19 +645,95 @@ struct ConvertAllGatherOp : public CommOpPattern<AllGatherOp> {
     if (!memref::isStaticShapeAndContiguousRowMajor(outType))
       return op.emitError(
           "Expected static shaped memref in contiguous row-major layout.");
+    int64_t gatherAxis = adaptor.getGatherAxisAttr().getInt();
+    auto ctx = op->getContext();
 
     // Get the right communicator
     Value comm = getComm(*gridOp, adaptor.getGridAxes(), iBuilder);
-    // Allocate output buffer
-    Value output = memref::AllocOp::create(iBuilder, outType);
+
+    Value nRanks =
+        mpi::CommSizeOp::create(iBuilder, iBuilder.getI32Type(), comm)
+            .getSize();
+    nRanks =
+        arith::IndexCastOp::create(iBuilder, iBuilder.getIndexType(), nRanks);
+
+    Value tmpOutput, gatherDimSz;
+    if (gatherAxis == 0) {
+      tmpOutput = memref::AllocOp::create(iBuilder, outType);
+    } else {
+      // MPI's allgather always concatenates along the first dimension.
+      // Create a memref type for the output buffer with adjusted (expanded)
+      // shape.
+      SmallVector<int64_t> gatherShape(1, ShapedType::kDynamic);
+      llvm::append_range(gatherShape, outType.getShape());
+      gatherShape[gatherAxis + 1] = ShapedType::kDynamic;
+      MemRefType gatherType =
+          MemRefType::get(gatherShape, outType.getElementType());
+      gatherDimSz = arith::ConstantIndexOp::create(
+          iBuilder, outType.getDimSize(gatherAxis));
+      gatherDimSz = arith::DivSIOp::create(iBuilder, iBuilder.getIndexType(),
+                                           gatherDimSz, nRanks);
+      // Allocate output buffer
+      tmpOutput =
+          memref::AllocOp::create(iBuilder, gatherType, {nRanks, gatherDimSz});
+    }
     // Create the MPI AllGather operation.
-    mpi::AllGatherOp::create(iBuilder, TypeRange(), input, output, comm);
+    mpi::AllGatherOp::create(iBuilder, TypeRange(), input, tmpOutput, comm);
+
+    // If gather-axis!=0, copy from gathered buffer to output with the right
+    // layout.
+    Value finalOutput = tmpOutput;
+    if (gatherAxis != 0) {
+      int64_t nSrcDims = cast<ShapedType>(tmpOutput.getType()).getRank();
+      assert(nSrcDims == outType.getRank() + 1 &&
+             "Expected gathered type to have rank one more than output type.");
+
+      // Create affine map for copying from gathered buffer to output.
+      SmallVector<AffineExpr> dims;
+      dims.reserve(nSrcDims);
+      for (unsigned i = 0; i < nSrcDims; ++i)
+        dims.emplace_back(getAffineDimExpr(i, ctx));
+      AffineExpr s = getAffineSymbolExpr(0, ctx);
+      SmallVector<AffineExpr> results;
+      results.reserve(nSrcDims);
+      for (unsigned i = 0; i < nSrcDims - 1; ++i) {
+        if (i == gatherAxis)
+          results.emplace_back(dims[0] * s + dims[gatherAxis + 1]);
+        else
+          results.emplace_back(dims[i + 1]);
+      }
+      auto affineMap = AffineMap::get(nSrcDims, /*symbols=*/1, results, ctx);
+
+      finalOutput = memref::AllocOp::create(iBuilder, outType);
+
+      // Now build a loop nest to copy from gathered buffer to finalOutput
+      // It would be nicer to just use a memref.transpose/collapse_shape op but
+      // these currently only support simpler cases.
+      Value zero = arith::ConstantIndexOp::create(iBuilder, 0);
+      SmallVector<Value> lbs(nSrcDims, zero);
+      SmallVector<Value> ubs;
+      for (int64_t d = 0; d < nSrcDims; ++d)
+        ubs.emplace_back(memref::DimOp::create(iBuilder, tmpOutput, d));
+      SmallVector<int64_t> steps(nSrcDims, 1);
+      auto emitCopy = [&](OpBuilder &builder, Location loc, ValueRange ivs) {
+        Value v = memref::LoadOp::create(iBuilder, tmpOutput, ivs);
+        // set symbol value
+        SmallVector<Value> ivss(ivs.begin(), ivs.end());
+        ivss.emplace_back(gatherDimSz);
+        affine::AffineStoreOp::create(iBuilder, v, finalOutput, affineMap,
+                                      ivss);
+      };
+      affine::buildAffineLoopNest(iBuilder, op->getLoc(), lbs, ubs, steps,
+                                  emitCopy);
+
+      memref::DeallocOp::create(iBuilder, tmpOutput);
+    }
 
     // If the destination is a tensor, cast it to a tensor
     if (isa<RankedTensorType>(op.getType()))
-      output = bufferization::ToTensorOp::create(iBuilder, op.getType(), output,
-                                                 true);
-    rewriter.replaceOp(op, output);
+      finalOutput = bufferization::ToTensorOp::create(iBuilder, op.getType(),
+                                                      finalOutput, true);
+    rewriter.replaceOp(op, finalOutput);
     return success();
   }
 };
diff --git a/mlir/lib/Dialect/MPI/IR/MPIOps.cpp b/mlir/lib/Dialect/MPI/IR/MPIOps.cpp
index f52c3f99189d2..e47a1f62d76d0 100644
--- a/mlir/lib/Dialect/MPI/IR/MPIOps.cpp
+++ b/mlir/lib/Dialect/MPI/IR/MPIOps.cpp
@@ -43,33 +43,46 @@ struct FoldCast final : public mlir::OpRewritePattern<OpT> {
   }
 };
 
+template <typename OpT>
+static LogicalResult FoldToDLTIConst(OpT op, const char *key,
+                                     mlir::PatternRewriter &b) {
+  // Only fold queries on MPI_COMM_WORLD whose retval is absent or unused.
+  Value retVal = op.getRetval();
+  if (!op.getComm().template getDefiningOp<mlir::mpi::CommWorldOp>() ||
+      (retVal && !retVal.use_empty()))
+    return mlir::failure();
+  // Look up the integer DLTI entry `key` and materialize it as an i32.
+  auto dltiAttr = dlti::query(op, {key}, false);
+  if (failed(dltiAttr))
+    return mlir::failure();
+  if (!isa<IntegerAttr>(dltiAttr.value()))
+    return op->emitError() << "Expected an integer attribute for " << key;
+  Value res = arith::ConstantOp::create(
+      b, op.getLoc(), b.getI32Type(),
+      b.getI32IntegerAttr(cast<IntegerAttr>(dltiAttr.value()).getInt()));
+  if (retVal) // unused (checked above), so it is safely erased with the op
+    b.replaceOp(op, {retVal, res});
+  else
+    b.replaceOp(op, res);
+  return mlir::success();
+}
+
 struct FoldRank final : public mlir::OpRewritePattern<mlir::mpi::CommRankOp> {
   using mlir::OpRewritePattern<mlir::mpi::CommRankOp>::OpRewritePattern;
-
   LogicalResult matchAndRewrite(mlir::mpi::CommRankOp op,
                                 mlir::PatternRewriter &b) const override {
-    auto comm = op.getComm();
-    if (!comm.getDefiningOp<mlir::mpi::CommWorldOp>())
-      return mlir::failure();
-
-    // Try to get DLTI attribute for MPI:comm_world_rank
-    // If found, set worldRank to the value of the attribute.
-    auto dltiAttr = dlti::query(op, {"MPI:comm_world_rank"}, false);
-    if (failed(dltiAttr))
-      return mlir::failure();
-    if (!isa<IntegerAttr>(dltiAttr.value()))
-      return op->emitError()
-             << "Expected an integer attribute for MPI:comm_world_rank";
-    Value res = arith::ConstantIndexOp::create(
-        b, op.getLoc(), cast<IntegerAttr>(dltiAttr.value()).getInt());
-    if (Value retVal = op.getRetval())
-      b.replaceOp(op, {retVal, res});
-    else
-      b.replaceOp(op, res);
-    return mlir::success();
+    return FoldToDLTIConst(op, "MPI:comm_world_rank", b);
   }
 };
 
+struct FoldSize final : public mlir::OpRewritePattern<mlir::mpi::CommSizeOp> {
+  using OpRewritePattern::OpRewritePattern;
+  // Fold comm_size(comm_world) to the "MPI:comm_world_size" DLTI value.
+  LogicalResult matchAndRewrite(mlir::mpi::CommSizeOp sizeOp,
+                                mlir::PatternRewriter &rw) const override {
+    return FoldToDLTIConst(sizeOp, "MPI:comm_world_size", rw);
+  }
+};
 } // namespace
 
 void mlir::mpi::SendOp::getCanonicalizationPatterns(
@@ -97,6 +110,11 @@ void mlir::mpi::CommRankOp::getCanonicalizationPatterns(
   results.add<FoldRank>(context);
 }
 
+void mlir::mpi::CommSizeOp::getCanonicalizationPatterns(
+    mlir::RewritePatternSet &results, mlir::MLIRContext *context) {
+  results.add<FoldSize>(context); // DLTI-based comm_size(comm_world) fold
+}
+
 //===----------------------------------------------------------------------===//
 // TableGen'd op method definitions
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Conversion/MPIToLLVM/mpitollvm.mlir b/mlir/test/Conversion/MPIToLLVM/mpitollvm.mlir
index 4c1beee2fe144..c6d4ca1eb2e33 100644
--- a/mlir/test/Conversion/MPIToLLVM/mpitollvm.mlir
+++ b/mlir/test/Conversion/MPIToLLVM/mpitollvm.mlir
@@ -4,7 +4,7 @@
 // CHECK: module attributes {dlti.map = #dlti.map<"MPI:Implementation" = "MPICH">} {
 // CHECK: llvm.func @MPI_Finalize() -> i32
 // CHECK: llvm.func @MPI_Allreduce(!llvm.ptr, !llvm.ptr, i32, i32, i32, i32) -> i32
-// CHECK: llvm.func @MPI_Comm_Size(i32, !llvm.ptr) -> i32
+// CHECK: llvm.func @MPI_Comm_size(i32, !llvm.ptr) -> i32
 // CHECK: llvm.func @MPI_Allgather(!llvm.ptr, i32, i32, !llvm.ptr, i32, i32, i32) -> i32
 // CHECK: llvm.func @MPI_Comm_split(i32, i32, i32, !llvm.ptr) -> i32
 // CHECK: llvm.func @MPI_Recv(!llvm.ptr, i32, i32, i32, i32, i32, !llvm.ptr) -> i32
@@ -40,7 +40,8 @@ module attributes {dlti.map = #dlti.map<"MPI:Implementation" = "MPICH">} {
     // CHECK: [[v14:%.*]] = llvm.extractvalue [[v5]][2] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
     // CHECK: [[v15:%.*]] = llvm.getelementptr [[v13]][[[v14]]] : (!llvm.ptr, i64) -> !llvm.ptr, f32
     // CHECK: [[v16:%.*]] = llvm.extractvalue [[v5]][3, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
-    // CHECK: [[v17:%.*]] = llvm.trunc [[v16]] : i64 to i32
+    // CHECK: [[v17a:%.*]] = llvm.trunc [[v16]] : i64 to i32
+    // CHECK: [[v17:%.*]] = llvm.mul [[v17a]]
     // CHECK: [[v18:%.*]] = llvm.mlir.constant(1275069450 : i32) : i32
     // CHECK: [[comm_1:%.*]] = llvm.trunc [[comm]] : i64 to i32
     // CHECK: [[v20:%.*]] = llvm.call @MPI_Send([[v15]], [[v17]], [[v18]], [[v12]], [[v12]], [[comm_1]]) : (!llvm.ptr, i32, i32, i32, i32, i32) -> i32
@@ -50,7 +51,8 @@ module attributes {dlti.map = #dlti.map<"MPI:Implementation" = "MPICH">} {
     // CHECK: [[v22:%.*]] = llvm.extractvalue [[v5]][2] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
     // CHECK: [[v23:%.*]] = llvm.getelementptr [[v21]][[[v22]]] : (!llvm.ptr, i64) -> !llvm.ptr, f32
     // CHECK: [[v24:%.*]] = llvm.extractvalue [[v5]][3, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
-    // CHECK: [[v25:%.*]] = llvm.trunc [[v24]] : i64 to i32
+    // CHECK: [[v25a:%.*]] = llvm.trunc [[v24]] : i64 to i32
+    // CHECK: [[v25:%.*]] = llvm.mul [[v25a]]
     // CHECK: [[v26:%.*]] = llvm.mlir.constant(1275069450 : i32) : i32
     // CHECK: [[comm_2:%.*]] = llvm.trunc [[comm]] : i64 to i32
     // CHECK: [[v28:%.*]] = llvm.call @MPI_Send([[v23]], [[v25]], [[v26]], [[v12]], [[v12]], [[comm_2]]) : (!llvm.ptr, i32, i32, i32, i32, i32) -> i32
@@ -60,7 +62,8 @@ module attributes {dlti.map = #dlti.map<"MPI:Implementation" = "MPICH">} {
     // CHECK: [[v30:%.*]] = llvm.extractvalue [[v5]][2] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
     // CHECK: [[v31:%.*]] = llvm.getelementptr [[v29]][[[v30]]] : (!llvm.ptr, i64) -> !llvm.ptr, f32
     // CHECK: [[v32:%.*]] = llvm.extractvalue [[v5]][3, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
-    // CHECK: [[v33:%.*]] = llvm.trunc [[v32]] : i64 to i32
+    // CHECK: [[v33a:%.*]] = llvm.trunc [[v32]] : i64 to i32
+    // CHECK: [[v33:%.*]] = llvm.mul [[v33a]]
     // CHECK: [[v34:%.*]] = llvm.mlir.constant(1275069450 : i32) : i32
     // CHECK: [[comm_3:%.*]] = llvm.trunc [[comm]] : i64 to i32
     // CHECK: [[v36:%.*]] = llvm.mlir.constant(1 : i64) : i64
@@ -72,7 +75,8 @@ module attributes {dlti.map = #dlti.map<"MPI:Implementation" = "MPICH">} {
     // CHECK: [[v40:%.*]] = llvm.extractvalue [[v5]][2] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
     // CHECK: [[v41:%.*]] = llvm.getelementptr [[v39]][[[v40]]] : (!llvm.ptr, i64) -> !llvm.ptr, f32
     // CHECK: [[v42:%.*]] = llvm.extractvalue [[v5]][3, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
-    // CHECK: [[v43:%.*]] = llvm.trunc [[v42]] : i64 to i32
+    // CHECK: [[v43a:%.*]] = llvm.trunc [[v42]] : i64 to i32
+    // CHECK: [[v43:%.*]] = llvm.mul [[v43a]]
     // CHECK: [[v44:%.*]] = llvm.mlir.constant(1275069450 : i32) : i32
     // CHECK: [[comm_4:%.*]] = llvm.trunc [[comm]] : i64 to i32
     // CHECK: [[v46:%.*]] = llvm.mlir.constant(1 : i64) : i64
@@ -98,12 +102,14 @@ module attributes {dlti.map = #dlti.map<"MPI:Implementation" = "MPICH">} {
     // CHECK: [[v60:%.*]] = llvm.extractvalue [[v5]][2] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
     // CHECK: [[v61:%.*]] = llvm.getelementptr [[v59]][[[v60]]] : (!llvm.ptr, i64) -> !llvm.ptr, f32
     // CHECK: [[v62:%.*]] = llvm.extractvalue [[v5]][3, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
-    // CHECK: [[v63:%.*]] = llvm.trunc [[v62]] : i64 to i32
+    // CHECK: [[v63a:%.*]] = llvm.trunc [[v62]] : i64 to i32
+    // CHECK: [[v63:%.*]] = llvm.mul [[v63a]]
     // CHECK: [[v64:%.*]] = llvm.extractvalue [[v5]][1] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
     // CHECK: [[v65:%.*]] = llvm.extractvalue [[v5]][2] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
     // CHECK: [[v66:%.*]] = llvm.getelementptr [[v64]][[[v65]]] : (!llvm.ptr, i64) -> !llvm.ptr, f32
     // CHECK: [[v67:%.*]] = llvm.extractvalue [[v5]][3, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
-    // CHECK: [[v68:%.*]] = llvm.trunc [[v67]] : i64 to i32
+    // CHECK: [[v68a:%.*]] = llvm.trunc [[v67]] : i64 to i32
+    // CHECK: [[v68:%.*]] = llvm.mul [[v68a]]
     // CHECK: [[ip:%.*]] = llvm.mlir.constant(-1 : i64) : i64
     // CHECK: [[ipp:%.*]] = llvm.inttoptr [[ip]] : i64 to !llvm.ptr
     // CHECK: [[v69:%.*]] = llvm.mlir.constant(1275069450 : i32) : i32
@@ -127,7 +133,7 @@ module attributes {dlti.map = #dlti.map<"MPI:Implementation" = "MPICH">} {
 // CHECK: llvm.func @MPI_Comm_split(!llvm.ptr, i32, i32, !llvm.ptr) -> i32
 // CHECK: llvm.func @MPI_Allreduce(!llvm.ptr, !llvm.ptr, i32, !llvm.ptr, !llvm.ptr, !llvm.ptr) -> i32
 // CHECK: llvm.mlir.global external @ompi_mpi_sum() {addr_space = 0 : i32} : !llvm.struct<"ompi_predefined_op_t", opaque>
-// CHECK: llvm.func @MPI_Comm_Size(!llvm.ptr, !llvm.ptr) -> i32
+// CHECK: llvm.func @MPI_Comm_size(!llvm.ptr, !llvm.ptr) -> i32
 // CHECK: llvm.func @MPI_Allgather(!llvm.ptr, i32, !llvm.ptr, !llvm.ptr, i32, !llvm.ptr, !llvm.ptr) -> i32
 // CHECK: llvm.func @MPI_Recv(!llvm.ptr, i32, !llvm.ptr, i32, i32, !llvm.ptr, !llvm.ptr) -> i32
 // CHECK: llvm.func @MPI_Send(!llvm.ptr, i32, !llvm.ptr, i32, i32, !llvm.ptr) -> i32
@@ -164,7 +170,8 @@ module attributes { dlti.map = #dlti.map<"MPI:Implementation" = "OpenMPI"> } {
     // CHECK: [[v14:%.*]] = llvm.extractvalue [[v5]][2] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
     // CHECK: [[v15:%.*]] = llvm.getelementptr [[v13]][[[v14]]] : (!llvm.ptr, i64) -> !llvm.ptr, f32
     // CHECK: [[v16:%.*]] = llvm.extractvalue [[v5]][3, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
-    // CHECK: [[v17:%.*]] = llvm.trunc [[v16]] : i64 to i32
+    // CHECK: [[v17a:%.*]] = llvm.trunc [[v16]] : i64 to i32
+    // CHECK: [[v17:%.*]] = llvm.mul [[v17a]]
     // CHECK: [[v18:%.*]] = llvm.mlir.addressof @ompi_mpi_float : !llvm.ptr
     // CHECK: [[v19:%.*]] = llvm.inttoptr [[comm]] : i64 to !llvm.ptr
     // CHECK: [[v20:%.*]] = llvm.call @MPI_Send([[v15]], [[v17]], [[v18]], [[v12]], [[v12]], [[v19]]) : (!llvm.ptr, i32, !llvm.ptr, i32, i32, !llvm.ptr) -> i32
@@ -174,7 +181,8 @@ module attributes { dlti.map = #dlti.map<"MPI:Implementation" = "OpenMPI"> } {
     // CHECK: [[v22:%.*]] = llvm.extractvalue [[v5]][2] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
     // CHECK: [[v23:%.*]] = llvm.getelementptr [[v21]][[[v22]]] : (!llvm.ptr, i64) -> !llvm.ptr, f32
     // CHECK: [[v24:%.*]] = llvm.extractvalue [[v5]][3, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
-    // CHECK: [[v25:%.*]] = llvm.trunc [[v24]] : i64 to i32
+    // CHECK: [[v25a:%.*]] = llvm.trunc [[v24]] : i64 to i32
+    // CHECK: [[v25:%.*]] = llvm.mul [[v25a]]
     // CHECK: [[v26:%.*]] = llvm.mlir.addressof @ompi_mpi_float : !llvm.ptr
     // CHECK: [[v27:%.*]] = llvm.inttoptr [[comm]] : i64 to !llvm.ptr
     // CHECK: [[v28:%.*]] = llvm.call @MPI_Send([[v23]], [[v25]], [[v26]], [[v12]], [[v12]], [[v27]]) : (!llvm.ptr, i32, !llvm.ptr, i32, i32, !llvm.ptr) -> i32
@@ -184,7 +192,8 @@ module attributes { dlti.map = #dlti.map<"MPI:Implementation" = "OpenMPI"> } {
     // CHECK: [[v30:%.*]] = llvm.extractvalue [[v5]][2] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
     // CHECK: [[v31:%.*]] = llvm.getelementptr [[v29]][[[v30]]] : (!llvm.ptr, i64) -> !llvm.ptr, f32
     // CHECK: [[v32:%.*]] = llvm.extractvalue [[v5]][3, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
-    // CHECK: [[v33:%.*]] = llvm.trunc [[v32]] : i64 to i32
+    // CHECK: [[v33a:%.*]] = llvm.trunc [[v32]] : i64 to i32
+    // CHECK: [[v33:%.*]] = llvm.mul [[v33a]]
     // CHECK: [[v34:%.*]] = llvm.mlir.addressof @ompi_mpi_float : !llvm.ptr
     // CHECK: [[v35:%.*]] = llvm.inttoptr [[comm]] : i64 to !llvm.ptr
     // CHECK: [[v36:%.*]] = llvm.mlir.constant(0 : i64) : i64
@@ -196,7 +205,8 @@ module attributes { dlti.map = #dlti.map<"MPI:Implementation" = "OpenMPI"> } {
     // CHECK: [[v40:%.*]] = llvm.extractvalue [[v5]][2] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
     // CHECK: [[v41:%.*]] = llvm.getelementptr [[v39]][[[v40]]] : (!llvm.ptr, i64) -> !llvm.ptr, f32
     // CHECK: [[v42:%.*]] = llvm.extractvalue [[v5]][3, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
-    // CHECK: [[v43:%.*]] = llvm.trunc [[v42]] : i64 to i32
+    // CHECK: [[v43a:%.*]] = llvm.trunc [[v42]] : i64 to i32
+    // CHECK: [[v43:%.*]] = llvm.mul [[v43a]]
     // CHECK: [[v44:%.*]] = llvm.mlir.addressof @ompi_mpi_float : !llvm.ptr
     // CHECK: [[v45:%.*]] = llvm.inttoptr [[comm]] : i64 to !llvm.ptr
     // CHECK: [[v46:%.*]] = llvm.mlir.constant(0 : i64) : i64
@@ -204,7 +214,7 @@ module attributes { dlti.map = #dlti.map<"MPI:Implementation" = "OpenMPI"> } {
     // CHECK: [[v48:%.*]] = llvm.call @MPI_Recv([[v41]], [[v43]], [[v44]], [[v12]], [[v12]], [[v45]], [[v47]]) : (!llvm.ptr, i32, !llvm.ptr, i32, i32, !llvm.ptr, !llvm.ptr) -> i32
     %2 = mpi.recv(%arg0, %rank, %rank, %comm) : memref<100xf32>, i32, i32 -> !mpi.retval
     
-    // CHECK: llvm.call @MPI_Comm_Size({{.*}}) : (!llvm.ptr, !llvm.ptr) -> i32
+    // CHECK: llvm.call @MPI_Comm_size({{.*}}) : (!llvm.ptr, !llvm.ptr) -> i32
     // CHECK: llvm.udiv {{.*}} : i32
     // CHECK: llvm.call @MPI_Allgather({{.*}} : (!llvm.ptr, i32, !llvm.ptr, !llvm.ptr, i32, !llvm.ptr, !llvm.ptr) -> i32
     mpi.allgather(%arg0, %arg0, %comm) : memref<100xf32>, memref<100xf32>
@@ -213,12 +223,14 @@ module attributes { dlti.map = #dlti.map<"MPI:Implementation" = "OpenMPI"> } {
     // CHECK: [[v50:%.*]] = llvm.extractvalue [[v5]][2] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
     // CHECK: [[v51:%.*]] = llvm.getelementptr [[v49]][[[v50]]] : (!llvm.ptr, i64) -> !llvm.ptr, f32
     // CHECK: [[v52:%.*]] = llvm.extractvalue [[v5]][3, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
-    // CHECK: [[v53:%.*]] = llvm.trunc [[v52]] : i64 to i32
+    // CHECK: [[v53a:%.*]] = llvm.trunc [[v52]] : i64 to i32
+    // CHECK: [[v53:%.*]] = llvm.mul [[v53a]]
     // CHECK: [[v54:%.*]] = llvm.extractvalue [[v5]][1] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
     // CHECK: [[v55:%.*]] = llvm.extractvalue [[v5]][2] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
     // CHECK: [[v56:%.*]] = llvm.getelementptr [[v54]][[[v55]]] : (!llvm.ptr, i64) -> !llvm.ptr, f32
     // CHECK: [[v57:%.*]] = llvm.extractvalue [[v5]][3, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
-    // CHECK: [[v58:%.*]] = llvm.trunc [[v57]] : i64 to i32
+    // CHECK: [[v58a:%.*]] = llvm.trunc [[v57]] : i64 to i32
+    // CHECK: [[v58:%.*]] = llvm.mul [[v58a]]
     // CHECK: [[ip:%.*]] = llvm.mlir.constant(1 : i64) : i64
     // CHECK: [[ipp:%.*]] = llvm.inttoptr [[ip]] : i64 to !llvm.ptr
     // CHECK: [[v59:%.*]] = llvm.mlir.addressof @ompi_mpi_float : !llvm.ptr
diff --git a/mlir/test/Conversion/ShardToMPI/convert-shard-to-mpi.mlir b/mlir/test/Conversion/ShardToMPI/convert-shard-to-mpi.mlir
index 9a8ad5eea1c7b..200ef488a1a1b 100644
--- a/mlir/test/Conversion/ShardToMPI/convert-shard-to-mpi.mlir
+++ b/mlir/test/Conversion/ShardToMPI/convert-shard-to-mpi.mlir
@@ -155,14 +155,28 @@ module attributes { mpi.dlti = #dlti.map<"MPI:comm_world_rank" = 7> } {
       %arg0 : tensor<3x4xf32>) -> tensor<3x20xf32> {
     // CHECK-DAG: [[vc2_i32:%.*]] = arith.constant 2 : i32
     // CHECK-DAG: [[vc1_i32:%.*]] = arith.constant 1 : i32
+    // CHECK-DAG: [[vc20:%.*]] = arith.constant 20 : index
     // CHECK: [[v0:%.*]] = bufferization.to_buffer [[varg0]] : tensor<3x4xf32> to memref<3x4xf32>
     // CHECK: [[v1:%.*]] = mpi.comm_world : !mpi.comm
     // CHECK: [[vnewcomm:%.*]] = mpi.comm_split([[v1]], [[vc1_i32]], [[vc2_i32]]) : !mpi.comm
-    // CHECK: [[valloc:%.*]] = memref.alloc() : memref<3x20xf32>
-    // CHECK: mpi.allgather([[v0]], [[valloc]], [[vnewcomm]]) : memref<3x4xf32>, memref<3x20xf32>
-    // CHECK: [[v2:%.*]] = bufferization.to_tensor [[valloc]] restrict : memref<3x20xf32> to tensor<3x20xf32>
+    // CHECK: [[vsize:%.*]] = mpi.comm_size([[vnewcomm]]) : i32
+    // CHECK: [[v2:%.*]] = arith.index_cast [[vsize]] : i32 to index
+    // CHECK: [[v3:%.*]] = arith.divsi [[vc20]], [[v2]] : index
+    // CHECK: [[valloc:%.*]] = memref.alloc([[v2]], [[v3]]) : memref<?x3x?xf32>
+    // CHECK: mpi.allgather([[v0]], [[valloc]], [[vnewcomm]]) : memref<3x4xf32>, memref<?x3x?xf32>
+    // CHECK: [[valloc_0:%.*]] = memref.alloc() : memref<3x20xf32>
+    // CHECK: affine.for [[varg1:%.*]] = 0 to [[v2]] {
+      // CHECK: affine.for [[varg2:%.*]] = 0 to 3 {
+        // CHECK: affine.for [[varg3:%.*]] = 0 to [[v3]] {
+          // CHECK: [[v5:%.*]] = memref.load [[valloc]][[[varg1]], [[varg2]], [[varg3]]] : memref<?x3x?xf32>
+          // CHECK: affine.store [[v5]], [[valloc_0]][[[varg2]], [[varg1]] * symbol([[v3]]) + [[varg3]]] : memref<3x20xf32>
+        // CHECK: }
+      // CHECK: }
+    // CHECK: }
+    // CHECK: memref.dealloc [[valloc]] : memref<?x3x?xf32>
+    // CHECK: [[v4:%.*]] = bufferization.to_tensor [[valloc_0]] restrict : memref<3x20xf32> to tensor<3x20xf32>
     %0 = shard.all_gather %arg0 on @grid0 grid_axes = [2] gather_axis = 1 : tensor<3x4xf32> -> tensor<3x20xf32>
-    // CHECK: return [[v2]] : tensor<3x20xf32>
+    // CHECK: return [[v4]] : tensor<3x20xf32>
     return %0 : tensor<3x20xf32>
   }
 
@@ -173,12 +187,26 @@ module attributes { mpi.dlti = #dlti.map<"MPI:comm_world_rank" = 7> } {
       %arg0 : memref<3x4xf32>) -> memref<3x20xf32> {
     // CHECK-DAG: [[vc1_i32:%.*]] = arith.constant 1 : i32
     // CHECK-DAG: [[vc2_i32:%.*]] = arith.constant 2 : i32
+    // CHECK-DAG: [[vc20:%.*]] = arith.constant 20 : index
     // CHECK: [[v0:%.*]] = mpi.comm_world : !mpi.comm
     // CHECK: [[vnewcomm:%.*]] = mpi.comm_split([[v0]], [[vc1_i32]], [[vc2_i32]]) : !mpi.comm
-    // CHECK: [[valloc:%.*]] = memref.alloc() : memref<3x20xf32>
-    // CHECK: mpi.allgather([[varg0]], [[valloc]], [[vnewcomm]]) : memref<3x4xf32>, memref<3x20xf32>
+    // CHECK: [[vsize:%.*]] = mpi.comm_size([[vnewcomm]]) : i32
+    // CHECK: [[v1:%.*]] = arith.index_cast [[vsize]] : i32 to index
+    // CHECK: [[v2:%.*]] = arith.divsi [[vc20]], [[v1]] : index
+    // CHECK: [[valloc:%.*]] = memref.alloc([[v1]], [[v2]]) : memref<?x3x?xf32>
+    // CHECK: mpi.allgather([[varg0]], [[valloc]], [[vnewcomm]]) : memref<3x4xf32>, memref<?x3x?xf32>
+    // CHECK: [[valloc_0:%.*]] = memref.alloc() : memref<3x20xf32>
+    // CHECK: affine.for [[varg1:%.*]] = 0 to [[v1]] {
+      // CHECK: affine.for [[varg2:%.*]] = 0 to 3 {
+        // CHECK: affine.for [[varg3:%.*]] = 0 to [[v2]] {
+          // CHECK: [[v3:%.*]] = memref.load [[valloc]][[[varg1]], [[varg2]], [[varg3]]] : memref<?x3x?xf32>
+          // CHECK: affine.store [[v3]], [[valloc_0]][[[varg2]], [[varg1]] * symbol([[v2]]) + [[varg3]]] : memref<3x20xf32>
+        // CHECK: }
+      // CHECK: }
+    // CHECK: }
+    // CHECK: memref.dealloc [[valloc]] : memref<?x3x?xf32>
     %0 = shard.all_gather %arg0 on @grid0 grid_axes = [2] gather_axis = 1 : memref<3x4xf32> -> memref<3x20xf32>
-    // CHECK: return [[valloc]] : memref<3x20xf32>
+    // CHECK: return [[valloc_0]] : memref<3x20xf32>
     return %0 : memref<3x20xf32>
   }
 }
diff --git a/mlir/test/Dialect/MPI/canonicalize.mlir b/mlir/test/Dialect/MPI/canonicalize.mlir
new file mode 100644
index 0000000000000..43787a7fad23f
--- /dev/null
+++ b/mlir/test/Dialect/MPI/canonicalize.mlir
@@ -0,0 +1,14 @@
+// RUN: mlir-opt %s -canonicalize | mlir-opt | FileCheck %s
+
+module attributes {mpi.dlti = #dlti.map<"MPI:comm_world_size" = 12, "MPI:comm_world_rank" = 5> } {
+  // CHECK-LABEL: func.func @mpi_test
+  func.func @mpi_test(%ref : memref<100xf32>) -> (i32, i32) {
+    %comm = mpi.comm_world : !mpi.comm
+    // CHECK: [[s:%.*]] = arith.constant 12 : i32
+    %sz = mpi.comm_size(%comm) : i32
+    // CHECK: [[r:%.*]] = arith.constant 5 : i32
+    %rk = mpi.comm_rank(%comm) : i32
+    // CHECK: return [[s]], [[r]] : i32, i32
+    return %sz, %rk : i32, i32
+  }
+}
\ No newline at end of file

>From 7bd1a8eea04bd722903e5ff4f035964aaa17bd45 Mon Sep 17 00:00:00 2001
From: "Schlimbach, Frank" <frank.schlimbach at intel.com>
Date: Fri, 30 Jan 2026 03:14:28 -0800
Subject: [PATCH 2/4] adding mlp tests for partition and shardtompi

---
 mlir/lib/Conversion/ShardToMPI/ShardToMPI.cpp |  2 +-
 .../ShardToMPI/convert-shard-to-mpi.mlir      | 58 +++++++++++++++++++
 mlir/test/Dialect/MPI/canonicalize.mlir       |  2 +-
 mlir/test/Dialect/Shard/partition.mlir        | 56 +++++++++++++++++-
 4 files changed, 114 insertions(+), 4 deletions(-)

diff --git a/mlir/lib/Conversion/ShardToMPI/ShardToMPI.cpp b/mlir/lib/Conversion/ShardToMPI/ShardToMPI.cpp
index 45e07a1ce5fb9..c765ad5a579c8 100644
--- a/mlir/lib/Conversion/ShardToMPI/ShardToMPI.cpp
+++ b/mlir/lib/Conversion/ShardToMPI/ShardToMPI.cpp
@@ -621,7 +621,7 @@ struct ConvertAllReduceOp : public CommOpPattern<AllReduceOp> {
 struct ConvertAllGatherOp : public CommOpPattern<AllGatherOp> {
   using CommOpPattern::CommOpPattern;
 
-  // shard.allgather always concatenates along a specified gather-axis.
+  // shard.allgather concatenates along a specified gather-axis.
   // mpi.allgather always concatenates along the first dimension and
   // there is no MPI operation that allows gathering along an arbitrary axis.
   // Hence, if gather-axis!=0, we need to create a temporary buffer
diff --git a/mlir/test/Conversion/ShardToMPI/convert-shard-to-mpi.mlir b/mlir/test/Conversion/ShardToMPI/convert-shard-to-mpi.mlir
index 200ef488a1a1b..4ac4a69dd5b18 100644
--- a/mlir/test/Conversion/ShardToMPI/convert-shard-to-mpi.mlir
+++ b/mlir/test/Conversion/ShardToMPI/convert-shard-to-mpi.mlir
@@ -438,3 +438,61 @@ func.func @return_sharding_offs(%arg0: tensor<?x?xf32>) -> (tensor<?x?xf32>, !sh
   // CHECK: return [[varg0]], [[vcast]], [[vcast_6]], [[vcast_7]] : tensor<?x?xf32>, tensor<?x?xi16>, tensor<?x?xi64>, tensor<?x?xi64>
   return %arg0, %sharding : tensor<?x?xf32>, !shard.sharding
 }
+
+// -----
+shard.grid @grid_1d_4(shape = 4)
+// CHECK-LABEL: func.func @mlp_1dgrid(
+// CHECK-SAME: [[varg0:%.*]]: tensor<512x512xf32>, [[varg1:%.*]]: tensor<2048x256xf32>, [[varg2:%.*]]: tensor<256x2048xf32>) -> tensor<512x2048xf32>
+func.func @mlp_1dgrid(%arg0: tensor<512x512xf32>, %arg1: tensor<2048x256xf32>, %arg2: tensor<256x2048xf32>) -> tensor<512x2048xf32> attributes {llvm.emit_c_interface} {
+  // CHECK: [[vcst:%.*]] = arith.constant 0.000000e+00 : f32
+  %cst = arith.constant 0.000000e+00 : f32
+  %c0 = arith.constant 0 : index
+  // CHECK: [[v0:%.*]] = bufferization.to_buffer [[varg0]] : tensor<512x512xf32> to memref<512x512xf32>
+  // CHECK: [[v1:%.*]] = mpi.comm_world : !mpi.comm
+  // CHECK: [[vsize:%.*]] = mpi.comm_size
+  // CHECK: [[v2:%.*]] = arith.index_cast [[vsize]] : i32 to index
+  // CHECK: [[v3:%.*]] = arith.divsi
+  // CHECK: [[valloc:%.*]] = memref.alloc([[v2]], [[v3]]) : memref<?x512x?xf32>
+  // CHECK: mpi.allgather([[v0]], [[valloc]], [[v1]]) : memref<512x512xf32>, memref<?x512x?xf32>
+  // CHECK: [[valloc_0:%.*]] = memref.alloc() : memref<512x2048xf32>
+  // CHECK: affine.for [[varg3:%.*]] = 0 to [[v2]] {
+    // CHECK: affine.for [[varg4:%.*]] = 0 to 512 {
+      // CHECK: affine.for [[varg5:%.*]] = 0 to [[v3]] {
+        // CHECK: [[v19:%.*]] = memref.load [[valloc]][[[varg3]], [[varg4]], [[varg5]]] : memref<?x512x?xf32>
+        // CHECK: affine.store [[v19]], [[valloc_0]][[[varg4]], [[varg3]] * symbol([[v3]]) + [[varg5]]] : memref<512x2048xf32>
+  // CHECK: memref.dealloc [[valloc]] : memref<?x512x?xf32>
+  // CHECK: [[v4:%.*]] = bufferization.to_tensor [[valloc_0]] restrict : memref<512x2048xf32> to tensor<512x2048xf32>
+  %all_gather = shard.all_gather %arg0 on @grid_1d_4 grid_axes = [0] gather_axis = 1 : tensor<512x512xf32> -> tensor<512x2048xf32>
+  // CHECK: [[v5:%.*]] = tensor.empty() : tensor<512x256xf32>
+  %0 = tensor.empty() : tensor<512x256xf32>
+  // CHECK: [[v6:%.*]] = linalg.fill ins([[vcst]] : f32) outs([[v5]] : tensor<512x256xf32>) -> tensor<512x256xf32>
+  %1 = linalg.fill ins(%cst : f32) outs(%0 : tensor<512x256xf32>) -> tensor<512x256xf32>
+  // CHECK: [[v7:%.*]] = linalg.matmul ins([[v4]], [[varg1]] : tensor<512x2048xf32>, tensor<2048x256xf32>) outs([[v6]] : tensor<512x256xf32>) -> tensor<512x256xf32>
+  %2 = linalg.matmul ins(%all_gather, %arg1 : tensor<512x2048xf32>, tensor<2048x256xf32>) outs(%1 : tensor<512x256xf32>) -> tensor<512x256xf32>
+  // CHECK: [[v8:%.*]] = tosa.sigmoid [[v7]] : (tensor<512x256xf32>) -> tensor<512x256xf32>
+  %3 = tosa.sigmoid %2 : (tensor<512x256xf32>) -> tensor<512x256xf32>
+  %4 = tensor.empty() : tensor<512x2048xf32>
+  %5 = linalg.fill ins(%cst : f32) outs(%4 : tensor<512x2048xf32>) -> tensor<512x2048xf32>
+  %proc_linear_idx = shard.process_multi_index on @grid_1d_4 axes = [0] : index
+  %grid_shape = shard.grid_shape @grid_1d_4 axes = [0] : index
+  %6 = arith.cmpi eq, %proc_linear_idx, %c0 : index
+  // CHECK: [[v14:%.*]] = scf.if
+  %7 = scf.if %6 -> (tensor<512x2048xf32>) {
+    scf.yield %5 : tensor<512x2048xf32>
+  } else {
+    %9 = tensor.empty() : tensor<512x2048xf32>
+    %10 = linalg.fill ins(%cst : f32) outs(%9 : tensor<512x2048xf32>) -> tensor<512x2048xf32>
+    scf.yield %10 : tensor<512x2048xf32>
+  }
+  // CHECK: [[v15:%.*]] = linalg.matmul ins([[v8]], [[varg2]] : tensor<512x256xf32>, tensor<256x2048xf32>) outs([[v14]] : tensor<512x2048xf32>) -> tensor<512x2048xf32>
+  %8 = linalg.matmul ins(%3, %arg2 : tensor<512x256xf32>, tensor<256x2048xf32>) outs(%7 : tensor<512x2048xf32>) -> tensor<512x2048xf32>
+  // CHECK: [[v16:%.*]] = bufferization.to_buffer
+  // CHECK: [[valloc_1:%.*]] = memref.alloc() : memref<512x2048xf32>
+  // CHECK: linalg.copy ins([[v16]] : memref<512x2048xf32>) outs([[valloc_1]] : memref<512x2048xf32>)
+  // CHECK: [[v17:%.*]] = mpi.comm_world : !mpi.comm
+  // CHECK: mpi.allreduce([[valloc_1]], [[valloc_1]], MPI_SUM, [[v17]]) : memref<512x2048xf32>, memref<512x2048xf32>
+  // CHECK: [[v18:%.*]] = bufferization.to_tensor [[valloc_1]] restrict : memref<512x2048xf32> to tensor<512x2048xf32>
+  %all_reduce = shard.all_reduce %8 on @grid_1d_4 grid_axes = [0] : tensor<512x2048xf32> -> tensor<512x2048xf32>
+  // CHECK: return [[v18]] : tensor<512x2048xf32>
+  return %all_reduce : tensor<512x2048xf32>
+}
diff --git a/mlir/test/Dialect/MPI/canonicalize.mlir b/mlir/test/Dialect/MPI/canonicalize.mlir
index 43787a7fad23f..3523d46e21219 100644
--- a/mlir/test/Dialect/MPI/canonicalize.mlir
+++ b/mlir/test/Dialect/MPI/canonicalize.mlir
@@ -11,4 +11,4 @@ module attributes {mpi.dlti = #dlti.map<"MPI:comm_world_size" = 12, "MPI:comm_wo
     // CHECK: return [[s]], [[r]] : i32, i32
     return %sz, %rk : i32, i32
   }
-}
\ No newline at end of file
+}
diff --git a/mlir/test/Dialect/Shard/partition.mlir b/mlir/test/Dialect/Shard/partition.mlir
index 0f293a39608e3..cd9fa2215e0ee 100644
--- a/mlir/test/Dialect/Shard/partition.mlir
+++ b/mlir/test/Dialect/Shard/partition.mlir
@@ -3,6 +3,7 @@
 // RUN:   %s | FileCheck %s
 
 shard.grid @grid_1d(shape = 2)
+shard.grid @grid_1d_4(shape = 4)
 
 // CHECK-LABEL: func @return_sharding
 func.func @return_sharding(
@@ -204,8 +205,6 @@ func.func @incomplete_sharding(
   return %3 : tensor<8x16xf32>
 }
 
-shard.grid @grid_1d_4(shape = 4)
-
 // CHECK-LABEL: func @ew_chain_with_halo
 func.func @ew_chain_with_halo(
   // CHECK-SAME: %[[IN1:[A-Za-z0-9_]+]]: tensor<5x16xf32>
@@ -318,3 +317,56 @@ func.func @test_reduce_1d(%arg0: tensor<6x6xi32>) -> (tensor<6xi32>) {
   // CHECK: return %[[reduced]] : tensor<3xi32>
   return %sharded_ret : tensor<6xi32>
 }
+
+// CHECK-LABEL: func.func @mlp_1dgrid
+// CHECK-SAME: [[varg0:%.*]]: tensor<512x512xf32>, [[varg1:%.*]]: tensor<2048x256xf32>, [[varg2:%.*]]: tensor<256x2048xf32>) -> tensor<512x2048xf32>
+func.func @mlp_1dgrid(%arg0: tensor<512x2048xf32>, %arg1: tensor<2048x1024xf32>, %arg2: tensor<1024x2048xf32>) -> tensor<512x2048xf32> attributes {llvm.emit_c_interface} {
+  // CHECK: [[vcst:%.*]] = arith.constant 0.000000e+00 : f32
+  %sharding = shard.sharding @grid_1d_4 split_axes = [[], [0]] : !shard.sharding
+  %sharding_0 = shard.sharding @grid_1d_4 split_axes = [[0], []] : !shard.sharding
+  %sharding_1 = shard.sharding @grid_1d_4 split_axes = [[]] : !shard.sharding
+  %sharding_2 = shard.sharding @grid_1d_4 split_axes = [[], [0]] : !shard.sharding
+  %sharding_3 = shard.sharding @grid_1d_4 split_axes = [[], [0]] : !shard.sharding
+  %sharding_4 = shard.sharding @grid_1d_4 split_axes = [[0], []] : !shard.sharding
+  %sharding_5 = shard.sharding @grid_1d_4 split_axes = [[]] : !shard.sharding
+  %sharding_annotated = shard.shard %arg0 to %sharding_2 : tensor<512x2048xf32>
+  %sharding_annotated_6 = shard.shard %arg1 to %sharding_3 : tensor<2048x1024xf32>
+  %sharding_annotated_7 = shard.shard %arg2 to %sharding_4 : tensor<1024x2048xf32>
+  // CHECK-DAG: [[v0:%.*]] = tensor.empty() : tensor<512x256xf32>
+  %0 = tensor.empty() : tensor<512x1024xf32>
+  %sharding_annotated_8 = shard.shard %0 to %sharding : tensor<512x1024xf32>
+  %cst = arith.constant 0.000000e+00 : f32
+  %sharding_annotated_9 = shard.shard %sharding_annotated_8 to %sharding annotate_for_users : tensor<512x1024xf32>
+  // CHECK-DAG: [[v1:%.*]] = linalg.fill ins([[vcst]] : f32) outs([[v0]] : tensor<512x256xf32>) -> tensor<512x256xf32>
+  %1 = linalg.fill ins(%cst : f32) outs(%sharding_annotated_9 : tensor<512x1024xf32>) -> tensor<512x1024xf32>
+  %sharding_annotated_10 = shard.shard %1 to %sharding : tensor<512x1024xf32>
+  // CHECK-DAG: [[vall_gather:%.*]] = shard.all_gather [[varg0]] on @grid_1d_4 grid_axes = [0] gather_axis = 1 : tensor<512x512xf32> -> tensor<512x2048xf32>
+  %sharding_annotated_11 = shard.shard %sharding_annotated to %sharding_1 annotate_for_users : tensor<512x2048xf32>
+  %sharding_annotated_12 = shard.shard %sharding_annotated_6 to %sharding annotate_for_users : tensor<2048x1024xf32>
+  %sharding_annotated_13 = shard.shard %sharding_annotated_10 to %sharding annotate_for_users : tensor<512x1024xf32>
+  // CHECK: [[v2:%.*]] = linalg.matmul ins([[vall_gather]], [[varg1]] : tensor<512x2048xf32>, tensor<2048x256xf32>) outs([[v1]] : tensor<512x256xf32>) -> tensor<512x256xf32>
+  %2 = linalg.matmul ins(%sharding_annotated_11, %sharding_annotated_12 : tensor<512x2048xf32>, tensor<2048x1024xf32>) outs(%sharding_annotated_13 : tensor<512x1024xf32>) -> tensor<512x1024xf32>
+  %sharding_annotated_14 = shard.shard %2 to %sharding : tensor<512x1024xf32>
+  %sharding_annotated_15 = shard.shard %sharding_annotated_14 to %sharding annotate_for_users : tensor<512x1024xf32>
+  // CHECK: [[v3:%.*]] = tosa.sigmoid [[v2]] : (tensor<512x256xf32>) -> tensor<512x256xf32>
+  %3 = tosa.sigmoid %sharding_annotated_15 : (tensor<512x1024xf32>) -> tensor<512x1024xf32>
+  %sharding_annotated_16 = shard.shard %3 to %sharding : tensor<512x1024xf32>
+  // CHECK: [[v9:%.*]] = tensor.empty() : tensor<512x2048xf32>
+  %4 = tensor.empty() : tensor<512x2048xf32>
+  %sharding_annotated_17 = shard.shard %4 to %sharding_1 : tensor<512x2048xf32>
+  %sharding_annotated_18 = shard.shard %sharding_annotated_17 to %sharding_1 annotate_for_users : tensor<512x2048xf32>
+  // CHECK: [[v10:%.*]] = linalg.fill ins([[vcst]] : f32) outs([[v9]] : tensor<512x2048xf32>) -> tensor<512x2048xf32>
+  %5 = linalg.fill ins(%cst : f32) outs(%sharding_annotated_18 : tensor<512x2048xf32>) -> tensor<512x2048xf32>
+  %sharding_annotated_19 = shard.shard %5 to %sharding_1 : tensor<512x2048xf32>
+  %sharding_annotated_20 = shard.shard %sharding_annotated_16 to %sharding annotate_for_users : tensor<512x1024xf32>
+  %sharding_annotated_21 = shard.shard %sharding_annotated_7 to %sharding_0 annotate_for_users : tensor<1024x2048xf32>
+  %sharding_annotated_22 = shard.shard %sharding_annotated_19 to %sharding_1 annotate_for_users : tensor<512x2048xf32>
+  // CHECK: [[v7:%.*]] = scf.if
+  // CHECK: [[v8:%.*]] = linalg.matmul ins([[v3]], [[varg2]] : tensor<512x256xf32>, tensor<256x2048xf32>) outs([[v7]] : tensor<512x2048xf32>) -> tensor<512x2048xf32>
+  %6 = linalg.matmul ins(%sharding_annotated_20, %sharding_annotated_21 : tensor<512x1024xf32>, tensor<1024x2048xf32>) outs(%sharding_annotated_22 : tensor<512x2048xf32>) -> tensor<512x2048xf32>
+  %sharding_annotated_23 = shard.shard %6 to %sharding_1 : tensor<512x2048xf32>
+  // CHECK: [[vall_reduce:%.*]] = shard.all_reduce [[v8]] on @grid_1d_4 grid_axes = [0] : tensor<512x2048xf32> -> tensor<512x2048xf32>
+  %sharding_annotated_24 = shard.shard %sharding_annotated_23 to %sharding_5 annotate_for_users : tensor<512x2048xf32>
+  // CHECK: return [[vall_reduce]] : tensor<512x2048xf32>
+  return %sharding_annotated_24 : tensor<512x2048xf32>
+}

>From 1fd20a067462dbe2bd679bcc0e5c6d243535e7b9 Mon Sep 17 00:00:00 2001
From: "Schlimbach, Frank" <frank.schlimbach at intel.com>
Date: Fri, 30 Jan 2026 03:34:16 -0800
Subject: [PATCH 3/4]  delete redundant *&

---
 mlir/lib/Conversion/MPIToLLVM/MPIToLLVM.cpp | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/mlir/lib/Conversion/MPIToLLVM/MPIToLLVM.cpp b/mlir/lib/Conversion/MPIToLLVM/MPIToLLVM.cpp
index 9ddef319ad71c..548481250e4f6 100644
--- a/mlir/lib/Conversion/MPIToLLVM/MPIToLLVM.cpp
+++ b/mlir/lib/Conversion/MPIToLLVM/MPIToLLVM.cpp
@@ -402,7 +402,7 @@ class OMPIImplTraits : public MPIImplTraits {
 };
 
 std::unique_ptr<MPIImplTraits> MPIImplTraits::get(ModuleOp &moduleOp) {
-  auto attr = dlti::query(*&moduleOp, {"MPI:Implementation"}, false);
+  auto attr = dlti::query(moduleOp, {"MPI:Implementation"}, false);
   if (failed(attr))
     return std::make_unique<MPICHImplTraits>(moduleOp);
   auto strAttr = dyn_cast<StringAttr>(attr.value());

>From 0449fa9a6b85d5f71295c2ebc28bb19b83ae8550 Mon Sep 17 00:00:00 2001
From: "Schlimbach, Frank" <frank.schlimbach at intel.com>
Date: Fri, 30 Jan 2026 09:12:14 -0800
Subject: [PATCH 4/4] folding CommSize

---
 mlir/include/mlir/Dialect/MPI/IR/Utils.h      | 43 +++++++++++++++++++
 mlir/lib/Conversion/MPIToLLVM/MPIToLLVM.cpp   | 15 ++++++-
 mlir/lib/Dialect/MPI/IR/MPIOps.cpp            | 27 +-----------
 mlir/test/Conversion/MPIToLLVM/mpitollvm.mlir | 16 +++++++
 4 files changed, 73 insertions(+), 28 deletions(-)
 create mode 100644 mlir/include/mlir/Dialect/MPI/IR/Utils.h

diff --git a/mlir/include/mlir/Dialect/MPI/IR/Utils.h b/mlir/include/mlir/Dialect/MPI/IR/Utils.h
new file mode 100644
index 0000000000000..9ff78e9e092fc
--- /dev/null
+++ b/mlir/include/mlir/Dialect/MPI/IR/Utils.h
@@ -0,0 +1,58 @@
+//===- Utils.h - MPI dialect --------------------------------------*- C++-*-==//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+#ifndef MLIR_DIALECT_MPI_IR_UTILS_H_
+#define MLIR_DIALECT_MPI_IR_UTILS_H_
+
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/DLTI/DLTI.h"
+#include "mlir/Dialect/MPI/IR/MPI.h"
+#include "mlir/IR/PatternMatch.h"
+
+namespace mlir {
+namespace mpi {
+
+/// Replaces `op` with an i32 `arith.constant` holding the integer stored under
+/// `key` in the DLTI specs visible from `op`.
+///
+/// `OpT` is expected to provide `getComm()`, an optional `getRetval()` result
+/// and one further (i32) result that receives the constant. The fold only
+/// applies when the communicator is produced by `mpi.comm_world` and the
+/// retval, if present, is unused. Returns failure and leaves `op` untouched
+/// otherwise or when `key` is not found; emits an error (and returns failure)
+/// if the entry exists but is not an integer attribute.
+template <typename OpT>
+LogicalResult FoldToDLTIConst(OpT op, const char *key,
+                              mlir::PatternRewriter &b) {
+  if (!op.getComm().template getDefiningOp<mlir::mpi::CommWorldOp>())
+    return mlir::failure();
+
+  // A used retval cannot be replaced (its only candidate replacement is the
+  // retval itself, which would dangle once `op` is erased), so keep `op`.
+  Value retVal = op.getRetval();
+  if (retVal && !retVal.use_empty())
+    return mlir::failure();
+
+  auto dltiAttr = dlti::query(op, {key}, false);
+  if (failed(dltiAttr))
+    return mlir::failure();
+  auto intAttr = dyn_cast<IntegerAttr>(dltiAttr.value());
+  if (!intAttr)
+    return op->emitError() << "Expected an integer attribute for " << key;
+
+  Value res = arith::ConstantOp::create(b, op.getLoc(), b.getI32Type(),
+                                        b.getI32IntegerAttr(intAttr.getInt()));
+  if (retVal)
+    b.replaceOp(op, {retVal, res}); // retVal has no uses (checked above).
+  else
+    b.replaceOp(op, res);
+  return mlir::success();
+}
+} // namespace mpi
+} // namespace mlir
+
+#endif // MLIR_DIALECT_MPI_IR_UTILS_H_
diff --git a/mlir/lib/Conversion/MPIToLLVM/MPIToLLVM.cpp b/mlir/lib/Conversion/MPIToLLVM/MPIToLLVM.cpp
index 548481250e4f6..0dbc0a126a5c6 100644
--- a/mlir/lib/Conversion/MPIToLLVM/MPIToLLVM.cpp
+++ b/mlir/lib/Conversion/MPIToLLVM/MPIToLLVM.cpp
@@ -20,6 +20,7 @@
 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
 #include "mlir/Dialect/LLVMIR/LLVMTypes.h"
 #include "mlir/Dialect/MPI/IR/MPI.h"
+#include "mlir/Dialect/MPI/IR/Utils.h"
 #include "mlir/Transforms/DialectConversion.h"
 #include <memory>
 
@@ -614,6 +615,35 @@ struct CommRankOpLowering : public ConvertOpToLLVMPattern<mpi::CommRankOp> {
 // CommSizeOpLowering
 //===----------------------------------------------------------------------===//
 
+/// Returns the size of a communicator as an i32 value.
+///
+/// When `commOrg` (the communicator before conversion) is produced by
+/// `mpi.comm_world` and an integer "MPI:comm_world_size" entry is visible in
+/// the DLTI specs enclosing the insertion point, the size is emitted as an
+/// `llvm.mlir.constant`, so no runtime MPI_Comm_size call is generated.
+/// Otherwise (including a non-integer entry) an `mpi.comm_size` on the
+/// converted communicator `commAdapt` is created instead.
+static Value createOrFoldCommSize(ConversionPatternRewriter &rewriter,
+                                  Location loc, Value commOrg,
+                                  Value commAdapt) {
+  auto i32 = rewriter.getI32Type();
+  if (commOrg.getDefiningOp<mpi::CommWorldOp>()) {
+    // Query the scope the lowered op lives in directly, rather than creating
+    // a temporary mpi.comm_size just to try folding it (and erasing it again
+    // on failure). The constant is returned directly so callers never hold a
+    // result of a replaced op.
+    Operation *scope = rewriter.getInsertionBlock()->getParentOp();
+    auto worldSize = dlti::query(scope, {"MPI:comm_world_size"}, false);
+    if (succeeded(worldSize))
+      if (auto intAttr = dyn_cast<IntegerAttr>(worldSize.value()))
+        return LLVM::ConstantOp::create(
+                   rewriter, loc, i32,
+                   rewriter.getI32IntegerAttr(intAttr.getInt()))
+            .getResult();
+  }
+  return mpi::CommSizeOp::create(rewriter, loc, i32, commAdapt).getSize();
+}
+
 struct CommSizeOpLowering : public ConvertOpToLLVMPattern<mpi::CommSizeOp> {
   using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;
 
@@ -818,8 +830,7 @@ struct AllGatherOpLowering : public ConvertOpToLLVMPattern<mpi::AllGatherOp> {
 
     // count_recv is the number of elements received from each rank, not total
     Value nRanks =
-        mpi::CommSizeOp::create(rewriter, loc, i32, adaptor.getComm())
-            .getSize();
+        createOrFoldCommSize(rewriter, loc, op.getComm(), adaptor.getComm());
     Value recvCountPerRank =
         LLVM::UDivOp::create(rewriter, loc, i32, recvSize, nRanks);
 
diff --git a/mlir/lib/Dialect/MPI/IR/MPIOps.cpp b/mlir/lib/Dialect/MPI/IR/MPIOps.cpp
index e47a1f62d76d0..6cca853071dc2 100644
--- a/mlir/lib/Dialect/MPI/IR/MPIOps.cpp
+++ b/mlir/lib/Dialect/MPI/IR/MPIOps.cpp
@@ -6,12 +6,11 @@
 //
 //===----------------------------------------------------------------------===//
 
-#include "mlir/Dialect/DLTI/DLTI.h"
 #include "mlir/Dialect/MPI/IR/MPI.h"
+#include "mlir/Dialect/MPI/IR/Utils.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/IR/Builders.h"
 #include "mlir/IR/BuiltinAttributes.h"
-#include "mlir/IR/PatternMatch.h"
 
 using namespace mlir;
 using namespace mlir::mpi;
@@ -43,30 +42,6 @@ struct FoldCast final : public mlir::OpRewritePattern<OpT> {
   }
 };
 
-template <typename OpT>
-static LogicalResult FoldToDLTIConst(OpT op, const char *key,
-                                     mlir::PatternRewriter &b) {
-  auto comm = op.getComm();
-  if (!comm.template getDefiningOp<mlir::mpi::CommWorldOp>())
-    return mlir::failure();
-
-  // Try to get DLTI attribute for MPI:comm_world_rank
-  // If found, set worldRank to the value of the attribute.
-  auto dltiAttr = dlti::query(op, {key}, false);
-  if (failed(dltiAttr))
-    return mlir::failure();
-  if (!isa<IntegerAttr>(dltiAttr.value()))
-    return op->emitError() << "Expected an integer attribute for " << key;
-  Value res = arith::ConstantOp::create(
-      b, op.getLoc(), b.getI32Type(),
-      b.getI32IntegerAttr(cast<IntegerAttr>(dltiAttr.value()).getInt()));
-  if (Value retVal = op.getRetval())
-    b.replaceOp(op, {retVal, res});
-  else
-    b.replaceOp(op, res);
-  return mlir::success();
-}
-
 struct FoldRank final : public mlir::OpRewritePattern<mlir::mpi::CommRankOp> {
   using mlir::OpRewritePattern<mlir::mpi::CommRankOp>::OpRewritePattern;
   LogicalResult matchAndRewrite(mlir::mpi::CommRankOp op,
diff --git a/mlir/test/Conversion/MPIToLLVM/mpitollvm.mlir b/mlir/test/Conversion/MPIToLLVM/mpitollvm.mlir
index c6d4ca1eb2e33..9ec81c53b41f8 100644
--- a/mlir/test/Conversion/MPIToLLVM/mpitollvm.mlir
+++ b/mlir/test/Conversion/MPIToLLVM/mpitollvm.mlir
@@ -95,6 +95,7 @@ module attributes {dlti.map = #dlti.map<"MPI:Implementation" = "MPICH">} {
     // CHECK: [[v57:%.*]] = llvm.load [[v55]] : !llvm.ptr -> i32
     %split = mpi.comm_split(%comm, %color, %key) : !mpi.comm
 
+    // CHECK: llvm.call @MPI_Comm_size
     // CHECK: llvm.call @MPI_Allgather({{.*}} : (!llvm.ptr, i32, i32, !llvm.ptr, i32, i32, i32) -> i32
     %err3 = mpi.allgather(%arg0, %arg0, %comm) : memref<100xf32>, memref<100xf32> -> !mpi.retval
 
@@ -256,3 +257,18 @@ module attributes { dlti.map = #dlti.map<"MPI:Implementation" = "OpenMPI"> } {
     return
   }
 }
+
+// -----
+
+module attributes {mpi.dlti = #dlti.map<"MPI:Implementation" = "MPICH", "MPI:comm_world_size" = 4, "MPI:comm_world_rank" = 1> } {
+  // CHECK: llvm.func @mpi_test_fold
+  func.func @mpi_test_fold(%arg0: memref<100xf32>) {
+    // CHECK: [[comm:%.*]] = llvm.mlir.constant(1140850688 : i64) : i64
+    %comm = mpi.comm_world : !mpi.comm
+
+    // CHECK-NOT: llvm.call @MPI_Comm_size
+    // CHECK: llvm.call @MPI_Allgather({{.*}} : (!llvm.ptr, i32, i32, !llvm.ptr, i32, i32, i32) -> i32
+    %err3 = mpi.allgather(%arg0, %arg0, %comm) : memref<100xf32>, memref<100xf32> -> !mpi.retval
+    return
+  }
+}



More information about the Mlir-commits mailing list