[Mlir-commits] [mlir] [mlir][shard, mpi] Fixing lowering allgather shard->mpi->llvm (PR #178870)
Frank Schlimbach
llvmlistbot at llvm.org
Fri Jan 30 09:14:12 PST 2026
https://github.com/fschlimb updated https://github.com/llvm/llvm-project/pull/178870
>From d743121f2dc16d5f1fca00c7d7129a49fa3a993e Mon Sep 17 00:00:00 2001
From: "Schlimbach, Frank" <frank.schlimbach at intel.com>
Date: Fri, 30 Jan 2026 01:36:35 -0800
Subject: [PATCH 1/4] adding support for gather_axis!=0 in shard.allgather
---
mlir/include/mlir/Dialect/MPI/IR/MPIOps.td | 1 +
mlir/lib/Conversion/MPIToLLVM/MPIToLLVM.cpp | 38 +++++---
mlir/lib/Conversion/ShardToMPI/ShardToMPI.cpp | 96 +++++++++++++++++--
mlir/lib/Dialect/MPI/IR/MPIOps.cpp | 58 +++++++----
mlir/test/Conversion/MPIToLLVM/mpitollvm.mlir | 42 +++++---
.../ShardToMPI/convert-shard-to-mpi.mlir | 42 ++++++--
mlir/test/Dialect/MPI/canonicalize.mlir | 14 +++
7 files changed, 230 insertions(+), 61 deletions(-)
create mode 100644 mlir/test/Dialect/MPI/canonicalize.mlir
diff --git a/mlir/include/mlir/Dialect/MPI/IR/MPIOps.td b/mlir/include/mlir/Dialect/MPI/IR/MPIOps.td
index b24c12459475e..d9e47ea3f6bfe 100644
--- a/mlir/include/mlir/Dialect/MPI/IR/MPIOps.td
+++ b/mlir/include/mlir/Dialect/MPI/IR/MPIOps.td
@@ -96,6 +96,7 @@ def MPI_CommSizeOp : MPI_Op<"comm_size", [Pure]> {
);
let assemblyFormat = "`(` $comm `)` attr-dict `:` type(results)";
+ let hasCanonicalizer = 1;
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Conversion/MPIToLLVM/MPIToLLVM.cpp b/mlir/lib/Conversion/MPIToLLVM/MPIToLLVM.cpp
index 4a1c5d1f7846c..9ddef319ad71c 100644
--- a/mlir/lib/Conversion/MPIToLLVM/MPIToLLVM.cpp
+++ b/mlir/lib/Conversion/MPIToLLVM/MPIToLLVM.cpp
@@ -51,7 +51,8 @@ static LLVM::LLVMFuncOp getOrDefineFunction(ModuleOp &moduleOp,
std::pair<Value, Value> getRawPtrAndSize(const Location loc,
ConversionPatternRewriter &rewriter,
- Value memRef, Type elType) {
+ Value memRef, int64_t rank,
+ Type elType) {
Type ptrType = LLVM::LLVMPointerType::get(rewriter.getContext());
Value dataPtr =
LLVM::ExtractValueOp::create(rewriter, loc, ptrType, memRef, 1);
@@ -59,11 +60,18 @@ std::pair<Value, Value> getRawPtrAndSize(const Location loc,
rewriter.getI64Type(), memRef, 2);
Value resPtr =
LLVM::GEPOp::create(rewriter, loc, ptrType, elType, dataPtr, offset);
- Value size;
+ // Element count = product of all dim sizes (memref descriptor field 3),
+ // accumulated in i32 to match the i32 count argument of the MPI calls.
+ Value size = LLVM::ConstantOp::create(rewriter, loc, rewriter.getI32Type(),
+ rewriter.getI32IntegerAttr(1));
if (cast<LLVM::LLVMStructType>(memRef.getType()).getBody().size() > 3) {
- size = LLVM::ExtractValueOp::create(rewriter, loc, memRef,
- ArrayRef<int64_t>{3, 0});
- size = LLVM::TruncOp::create(rewriter, loc, rewriter.getI32Type(), size);
+ for (int64_t i = 0; i < rank; ++i) {
+ Value dim = LLVM::ExtractValueOp::create(rewriter, loc, memRef,
+ ArrayRef<int64_t>{3, i});
+ dim = LLVM::TruncOp::create(rewriter, loc, rewriter.getI32Type(), dim);
+ size =
+ LLVM::MulOp::create(rewriter, loc, rewriter.getI32Type(), dim, size);
+ }
} else {
size = arith::ConstantIntOp::create(rewriter, loc, 1, 32);
}
@@ -396,7 +402,7 @@ class OMPIImplTraits : public MPIImplTraits {
};
std::unique_ptr<MPIImplTraits> MPIImplTraits::get(ModuleOp &moduleOp) {
- auto attr = dlti::query(*&moduleOp, {"MPI:Implementation"}, true);
+ auto attr = dlti::query(*&moduleOp, {"MPI:Implementation"}, false);
if (failed(attr))
return std::make_unique<MPICHImplTraits>(moduleOp);
auto strAttr = dyn_cast<StringAttr>(attr.value());
@@ -634,7 +640,7 @@ struct CommSizeOpLowering : public ConvertOpToLLVMPattern<mpi::CommSizeOp> {
LLVM::LLVMFunctionType::get(i32, {comm.getType(), ptrType});
// get or create function declaration:
LLVM::LLVMFuncOp initDecl = getOrDefineFunction(
- moduleOp, loc, rewriter, "MPI_Comm_Size", SizeFuncType);
+ moduleOp, loc, rewriter, "MPI_Comm_size", SizeFuncType);
// replace with function call
auto one = LLVM::ConstantOp::create(rewriter, loc, i32, 1);
@@ -675,6 +681,7 @@ struct SendOpLowering : public ConvertOpToLLVMPattern<mpi::SendOp> {
MLIRContext *context = rewriter.getContext();
Type i32 = rewriter.getI32Type();
Type elemType = op.getRef().getType().getElementType();
+ int64_t rank = op.getRef().getType().getRank();
// ptrType `!llvm.ptr`
Type ptrType = LLVM::LLVMPointerType::get(context);
@@ -684,7 +691,7 @@ struct SendOpLowering : public ConvertOpToLLVMPattern<mpi::SendOp> {
// get MPI_COMM_WORLD, dataType and pointer
auto [dataPtr, size] =
- getRawPtrAndSize(loc, rewriter, adaptor.getRef(), elemType);
+ getRawPtrAndSize(loc, rewriter, adaptor.getRef(), rank, elemType);
auto mpiTraits = MPIImplTraits::get(moduleOp);
Value dataType = mpiTraits->getDataType(loc, rewriter, elemType);
Value comm = mpiTraits->castComm(loc, rewriter, adaptor.getComm());
@@ -727,6 +734,7 @@ struct RecvOpLowering : public ConvertOpToLLVMPattern<mpi::RecvOp> {
Type i32 = rewriter.getI32Type();
Type i64 = rewriter.getI64Type();
Type elemType = op.getRef().getType().getElementType();
+ int64_t rank = op.getRef().getType().getRank();
// ptrType `!llvm.ptr`
Type ptrType = LLVM::LLVMPointerType::get(context);
@@ -736,7 +744,7 @@ struct RecvOpLowering : public ConvertOpToLLVMPattern<mpi::RecvOp> {
// get MPI_COMM_WORLD, dataType, status_ignore and pointer
auto [dataPtr, size] =
- getRawPtrAndSize(loc, rewriter, adaptor.getRef(), elemType);
+ getRawPtrAndSize(loc, rewriter, adaptor.getRef(), rank, elemType);
auto mpiTraits = MPIImplTraits::get(moduleOp);
Value dataType = mpiTraits->getDataType(loc, rewriter, elemType);
Value comm = mpiTraits->castComm(loc, rewriter, adaptor.getComm());
@@ -782,10 +790,12 @@ struct AllGatherOpLowering : public ConvertOpToLLVMPattern<mpi::AllGatherOp> {
MLIRContext *context = rewriter.getContext();
Type sElemType = op.getSendbuf().getType().getElementType();
Type rElemType = op.getRecvbuf().getType().getElementType();
+ int64_t sRank = op.getSendbuf().getType().getRank();
+ int64_t rRank = op.getRecvbuf().getType().getRank();
auto [sendPtr, sendSize] =
- getRawPtrAndSize(loc, rewriter, adaptor.getSendbuf(), sElemType);
+ getRawPtrAndSize(loc, rewriter, adaptor.getSendbuf(), sRank, sElemType);
auto [recvPtr, recvSize] =
- getRawPtrAndSize(loc, rewriter, adaptor.getRecvbuf(), rElemType);
+ getRawPtrAndSize(loc, rewriter, adaptor.getRecvbuf(), rRank, rElemType);
auto moduleOp = op->getParentOfType<ModuleOp>();
auto mpiTraits = MPIImplTraits::get(moduleOp);
@@ -843,15 +853,17 @@ struct AllReduceOpLowering : public ConvertOpToLLVMPattern<mpi::AllReduceOp> {
Type i32 = rewriter.getI32Type();
Type i64 = rewriter.getI64Type();
Type elemType = op.getSendbuf().getType().getElementType();
+ int64_t sRank = op.getSendbuf().getType().getRank();
+ int64_t rRank = op.getRecvbuf().getType().getRank();
// ptrType `!llvm.ptr`
Type ptrType = LLVM::LLVMPointerType::get(context);
auto moduleOp = op->getParentOfType<ModuleOp>();
auto mpiTraits = MPIImplTraits::get(moduleOp);
auto [sendPtr, sendSize] =
- getRawPtrAndSize(loc, rewriter, adaptor.getSendbuf(), elemType);
+ getRawPtrAndSize(loc, rewriter, adaptor.getSendbuf(), sRank, elemType);
auto [recvPtr, recvSize] =
- getRawPtrAndSize(loc, rewriter, adaptor.getRecvbuf(), elemType);
+ getRawPtrAndSize(loc, rewriter, adaptor.getRecvbuf(), rRank, elemType);
// If input and output are the same, request in-place operation.
if (adaptor.getSendbuf() == adaptor.getRecvbuf()) {
diff --git a/mlir/lib/Conversion/ShardToMPI/ShardToMPI.cpp b/mlir/lib/Conversion/ShardToMPI/ShardToMPI.cpp
index 87ae28892fcf7..45e07a1ce5fb9 100644
--- a/mlir/lib/Conversion/ShardToMPI/ShardToMPI.cpp
+++ b/mlir/lib/Conversion/ShardToMPI/ShardToMPI.cpp
@@ -30,6 +30,7 @@
#include "mlir/Dialect/Shard/Transforms/Transforms.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Utils/StaticValueUtils.h"
+#include "mlir/IR/AffineMap.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypes.h"
@@ -620,6 +621,13 @@ struct ConvertAllReduceOp : public CommOpPattern<AllReduceOp> {
struct ConvertAllGatherOp : public CommOpPattern<AllGatherOp> {
using CommOpPattern::CommOpPattern;
+ // shard.allgather always concatenates along a specified gather-axis.
+ // mpi.allgather always concatenates along the first dimension and
+ // there is no MPI operation that allows gathering along an arbitrary axis.
+ // Hence, if gather-axis!=0, we need to create a temporary buffer
+ // where we gather along the first dimension and then copy from that
+ // buffer to the final output along the specified gather-axis.
+
LogicalResult
matchAndRewrite(AllGatherOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
@@ -637,19 +645,95 @@ struct ConvertAllGatherOp : public CommOpPattern<AllGatherOp> {
if (!memref::isStaticShapeAndContiguousRowMajor(outType))
return op.emitError(
"Expected static shaped memref in contiguous row-major layout.");
+ int64_t gatherAxis = adaptor.getGatherAxisAttr().getInt();
+ auto ctx = op->getContext();
// Get the right communicator
Value comm = getComm(*gridOp, adaptor.getGridAxes(), iBuilder);
- // Allocate output buffer
- Value output = memref::AllocOp::create(iBuilder, outType);
+
+ Value nRanks =
+ mpi::CommSizeOp::create(iBuilder, iBuilder.getI32Type(), comm)
+ .getSize();
+ nRanks =
+ arith::IndexCastOp::create(iBuilder, iBuilder.getIndexType(), nRanks);
+
+ Value tmpOutput, gatherDimSz;
+ if (gatherAxis == 0) {
+ tmpOutput = memref::AllocOp::create(iBuilder, outType);
+ } else {
+ // MPI's allgather always concatenates along the first dimension.
+ // Create a memref type for the output buffer with adjusted (expanded)
+ // shape.
+ SmallVector<int64_t> gatherShape(1, ShapedType::kDynamic);
+ llvm::append_range(gatherShape, outType.getShape());
+ gatherShape[gatherAxis + 1] = ShapedType::kDynamic;
+ MemRefType gatherType =
+ MemRefType::get(gatherShape, outType.getElementType());
+ gatherDimSz = arith::ConstantIndexOp::create(
+ iBuilder, outType.getDimSize(gatherAxis));
+ gatherDimSz = arith::DivSIOp::create(iBuilder, iBuilder.getIndexType(),
+ gatherDimSz, nRanks);
+ // Allocate output buffer
+ tmpOutput =
+ memref::AllocOp::create(iBuilder, gatherType, {nRanks, gatherDimSz});
+ }
// Create the MPI AllGather operation.
- mpi::AllGatherOp::create(iBuilder, TypeRange(), input, output, comm);
+ mpi::AllGatherOp::create(iBuilder, TypeRange(), input, tmpOutput, comm);
+
+ // If gather-axis!=0, copy from gathered buffer to output with the right
+ // layout.
+ Value finalOutput = tmpOutput;
+ if (gatherAxis != 0) {
+ int64_t nSrcDims = cast<ShapedType>(tmpOutput.getType()).getRank();
+ assert(nSrcDims == outType.getRank() + 1 &&
+ "Expected gathered type to have rank one more than output type.");
+
+ // Create affine map for copying from gathered buffer to output.
+ SmallVector<AffineExpr> dims;
+ dims.reserve(nSrcDims);
+ for (unsigned i = 0; i < nSrcDims; ++i)
+ dims.emplace_back(getAffineDimExpr(i, ctx));
+ AffineExpr s = getAffineSymbolExpr(0, ctx);
+ SmallVector<AffineExpr> results;
+ results.reserve(nSrcDims);
+ for (unsigned i = 0; i < nSrcDims - 1; ++i) {
+ if (i == gatherAxis)
+ results.emplace_back(dims[0] * s + dims[gatherAxis + 1]);
+ else
+ results.emplace_back(dims[i + 1]);
+ }
+ auto affineMap = AffineMap::get(nSrcDims, /*symbols=*/1, results, ctx);
+
+ finalOutput = memref::AllocOp::create(iBuilder, outType);
+
+ // Now build a loop nest to copy from gathered buffer to finalOutput
+ // It would be nicer to just use a memref.transpose/collapse_shape op but
+ // these currently only support simpler cases.
+ Value zero = arith::ConstantIndexOp::create(iBuilder, 0);
+ SmallVector<Value> lbs(nSrcDims, zero);
+ SmallVector<Value> ubs;
+ for (int64_t d = 0; d < nSrcDims; ++d)
+ ubs.emplace_back(memref::DimOp::create(iBuilder, tmpOutput, d));
+ SmallVector<int64_t> steps(nSrcDims, 1);
+ auto emitCopy = [&](OpBuilder &builder, Location loc, ValueRange ivs) {
+ // Use the nested builder/loc so the ops land in the innermost loop body.
+ Value v = memref::LoadOp::create(builder, loc, tmpOutput, ivs);
+ SmallVector<Value> ivss(ivs.begin(), ivs.end());
+ ivss.emplace_back(gatherDimSz); // symbol s0 = per-rank gather size
+ affine::AffineStoreOp::create(builder, loc, v, finalOutput, affineMap,
+ ivss);
+ };
+ affine::buildAffineLoopNest(iBuilder, op->getLoc(), lbs, ubs, steps,
+ emitCopy);
+
+ memref::DeallocOp::create(iBuilder, tmpOutput);
+ }
// If the destination is a tensor, cast it to a tensor
if (isa<RankedTensorType>(op.getType()))
- output = bufferization::ToTensorOp::create(iBuilder, op.getType(), output,
- true);
- rewriter.replaceOp(op, output);
+ finalOutput = bufferization::ToTensorOp::create(iBuilder, op.getType(),
+ finalOutput, true);
+ rewriter.replaceOp(op, finalOutput);
return success();
}
};
diff --git a/mlir/lib/Dialect/MPI/IR/MPIOps.cpp b/mlir/lib/Dialect/MPI/IR/MPIOps.cpp
index f52c3f99189d2..e47a1f62d76d0 100644
--- a/mlir/lib/Dialect/MPI/IR/MPIOps.cpp
+++ b/mlir/lib/Dialect/MPI/IR/MPIOps.cpp
@@ -43,33 +43,46 @@ struct FoldCast final : public mlir::OpRewritePattern<OpT> {
}
};
+/// Folds `op` on mpi.comm_world to an i32 constant read from DLTI `key`.
+/// Bails out if op's optional retval has uses: nothing can replace it.
+template <typename OpT>
+static LogicalResult FoldToDLTIConst(OpT op, StringRef key,
+ mlir::PatternRewriter &b) {
+ Value retVal = op.getRetval();
+ if (!op.getComm().template getDefiningOp<mlir::mpi::CommWorldOp>() ||
+ (retVal && !retVal.use_empty()))
+ return mlir::failure();
+ auto dltiAttr = dlti::query(op, {key}, false);
+ if (failed(dltiAttr))
+ return mlir::failure();
+ auto intAttr = dyn_cast<IntegerAttr>(dltiAttr.value());
+ if (!intAttr)
+ return op->emitError() << "Expected an integer attribute for " << key;
+ Value res = arith::ConstantOp::create(b, op.getLoc(), b.getI32Type(),
+ b.getI32IntegerAttr(intAttr.getInt()));
+ if (retVal) // Unused (checked above), so mapping it to itself is safe.
+ b.replaceOp(op, {retVal, res});
+ else
+ b.replaceOp(op, res);
+ return mlir::success();
+}
+
+/// Folds mpi.comm_rank on mpi.comm_world to the i32 constant from the
+/// "MPI:comm_world_rank" DLTI entry (delegates to FoldToDLTIConst).
struct FoldRank final : public mlir::OpRewritePattern<mlir::mpi::CommRankOp> {
using mlir::OpRewritePattern<mlir::mpi::CommRankOp>::OpRewritePattern;
-
LogicalResult matchAndRewrite(mlir::mpi::CommRankOp op,
mlir::PatternRewriter &b) const override {
- auto comm = op.getComm();
- if (!comm.getDefiningOp<mlir::mpi::CommWorldOp>())
- return mlir::failure();
-
- // Try to get DLTI attribute for MPI:comm_world_rank
- // If found, set worldRank to the value of the attribute.
- auto dltiAttr = dlti::query(op, {"MPI:comm_world_rank"}, false);
- if (failed(dltiAttr))
- return mlir::failure();
- if (!isa<IntegerAttr>(dltiAttr.value()))
- return op->emitError()
- << "Expected an integer attribute for MPI:comm_world_rank";
- Value res = arith::ConstantIndexOp::create(
- b, op.getLoc(), cast<IntegerAttr>(dltiAttr.value()).getInt());
- if (Value retVal = op.getRetval())
- b.replaceOp(op, {retVal, res});
- else
- b.replaceOp(op, res);
- return mlir::success();
+ return FoldToDLTIConst(op, "MPI:comm_world_rank", b);
}
};
+/// Canonicalization: replaces mpi.comm_size on mpi.comm_world with the i32
+/// constant from the "MPI:comm_world_size" DLTI entry, when one is present.
+struct FoldSize final : public mlir::OpRewritePattern<mlir::mpi::CommSizeOp> {
+ using OpRewritePattern::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(mlir::mpi::CommSizeOp sizeOp,
+ mlir::PatternRewriter &rewriter) const override {
+ return FoldToDLTIConst(sizeOp, "MPI:comm_world_size", rewriter);
+ }
+};
} // namespace
void mlir::mpi::SendOp::getCanonicalizationPatterns(
@@ -97,6 +110,11 @@ void mlir::mpi::CommRankOp::getCanonicalizationPatterns(
results.add<FoldRank>(context);
}
+// Registers FoldSize, which folds mpi.comm_size on mpi.comm_world to the
+// "MPI:comm_world_size" DLTI value when that entry is available.
+void mlir::mpi::CommSizeOp::getCanonicalizationPatterns(
+ mlir::RewritePatternSet &results, mlir::MLIRContext *context) {
+ results.add<FoldSize>(context);
+}
+
//===----------------------------------------------------------------------===//
// TableGen'd op method definitions
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Conversion/MPIToLLVM/mpitollvm.mlir b/mlir/test/Conversion/MPIToLLVM/mpitollvm.mlir
index 4c1beee2fe144..c6d4ca1eb2e33 100644
--- a/mlir/test/Conversion/MPIToLLVM/mpitollvm.mlir
+++ b/mlir/test/Conversion/MPIToLLVM/mpitollvm.mlir
@@ -4,7 +4,7 @@
// CHECK: module attributes {dlti.map = #dlti.map<"MPI:Implementation" = "MPICH">} {
// CHECK: llvm.func @MPI_Finalize() -> i32
// CHECK: llvm.func @MPI_Allreduce(!llvm.ptr, !llvm.ptr, i32, i32, i32, i32) -> i32
-// CHECK: llvm.func @MPI_Comm_Size(i32, !llvm.ptr) -> i32
+// CHECK: llvm.func @MPI_Comm_size(i32, !llvm.ptr) -> i32
// CHECK: llvm.func @MPI_Allgather(!llvm.ptr, i32, i32, !llvm.ptr, i32, i32, i32) -> i32
// CHECK: llvm.func @MPI_Comm_split(i32, i32, i32, !llvm.ptr) -> i32
// CHECK: llvm.func @MPI_Recv(!llvm.ptr, i32, i32, i32, i32, i32, !llvm.ptr) -> i32
@@ -40,7 +40,8 @@ module attributes {dlti.map = #dlti.map<"MPI:Implementation" = "MPICH">} {
// CHECK: [[v14:%.*]] = llvm.extractvalue [[v5]][2] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
// CHECK: [[v15:%.*]] = llvm.getelementptr [[v13]][[[v14]]] : (!llvm.ptr, i64) -> !llvm.ptr, f32
// CHECK: [[v16:%.*]] = llvm.extractvalue [[v5]][3, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
- // CHECK: [[v17:%.*]] = llvm.trunc [[v16]] : i64 to i32
+ // CHECK: [[v17a:%.*]] = llvm.trunc [[v16]] : i64 to i32
+ // CHECK: [[v17:%.*]] = llvm.mul [[v17a]]
// CHECK: [[v18:%.*]] = llvm.mlir.constant(1275069450 : i32) : i32
// CHECK: [[comm_1:%.*]] = llvm.trunc [[comm]] : i64 to i32
// CHECK: [[v20:%.*]] = llvm.call @MPI_Send([[v15]], [[v17]], [[v18]], [[v12]], [[v12]], [[comm_1]]) : (!llvm.ptr, i32, i32, i32, i32, i32) -> i32
@@ -50,7 +51,8 @@ module attributes {dlti.map = #dlti.map<"MPI:Implementation" = "MPICH">} {
// CHECK: [[v22:%.*]] = llvm.extractvalue [[v5]][2] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
// CHECK: [[v23:%.*]] = llvm.getelementptr [[v21]][[[v22]]] : (!llvm.ptr, i64) -> !llvm.ptr, f32
// CHECK: [[v24:%.*]] = llvm.extractvalue [[v5]][3, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
- // CHECK: [[v25:%.*]] = llvm.trunc [[v24]] : i64 to i32
+ // CHECK: [[v25a:%.*]] = llvm.trunc [[v24]] : i64 to i32
+ // CHECK: [[v25:%.*]] = llvm.mul [[v25a]]
// CHECK: [[v26:%.*]] = llvm.mlir.constant(1275069450 : i32) : i32
// CHECK: [[comm_2:%.*]] = llvm.trunc [[comm]] : i64 to i32
// CHECK: [[v28:%.*]] = llvm.call @MPI_Send([[v23]], [[v25]], [[v26]], [[v12]], [[v12]], [[comm_2]]) : (!llvm.ptr, i32, i32, i32, i32, i32) -> i32
@@ -60,7 +62,8 @@ module attributes {dlti.map = #dlti.map<"MPI:Implementation" = "MPICH">} {
// CHECK: [[v30:%.*]] = llvm.extractvalue [[v5]][2] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
// CHECK: [[v31:%.*]] = llvm.getelementptr [[v29]][[[v30]]] : (!llvm.ptr, i64) -> !llvm.ptr, f32
// CHECK: [[v32:%.*]] = llvm.extractvalue [[v5]][3, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
- // CHECK: [[v33:%.*]] = llvm.trunc [[v32]] : i64 to i32
+ // CHECK: [[v33a:%.*]] = llvm.trunc [[v32]] : i64 to i32
+ // CHECK: [[v33:%.*]] = llvm.mul [[v33a]]
// CHECK: [[v34:%.*]] = llvm.mlir.constant(1275069450 : i32) : i32
// CHECK: [[comm_3:%.*]] = llvm.trunc [[comm]] : i64 to i32
// CHECK: [[v36:%.*]] = llvm.mlir.constant(1 : i64) : i64
@@ -72,7 +75,8 @@ module attributes {dlti.map = #dlti.map<"MPI:Implementation" = "MPICH">} {
// CHECK: [[v40:%.*]] = llvm.extractvalue [[v5]][2] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
// CHECK: [[v41:%.*]] = llvm.getelementptr [[v39]][[[v40]]] : (!llvm.ptr, i64) -> !llvm.ptr, f32
// CHECK: [[v42:%.*]] = llvm.extractvalue [[v5]][3, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
- // CHECK: [[v43:%.*]] = llvm.trunc [[v42]] : i64 to i32
+ // CHECK: [[v43a:%.*]] = llvm.trunc [[v42]] : i64 to i32
+ // CHECK: [[v43:%.*]] = llvm.mul [[v43a]]
// CHECK: [[v44:%.*]] = llvm.mlir.constant(1275069450 : i32) : i32
// CHECK: [[comm_4:%.*]] = llvm.trunc [[comm]] : i64 to i32
// CHECK: [[v46:%.*]] = llvm.mlir.constant(1 : i64) : i64
@@ -98,12 +102,14 @@ module attributes {dlti.map = #dlti.map<"MPI:Implementation" = "MPICH">} {
// CHECK: [[v60:%.*]] = llvm.extractvalue [[v5]][2] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
// CHECK: [[v61:%.*]] = llvm.getelementptr [[v59]][[[v60]]] : (!llvm.ptr, i64) -> !llvm.ptr, f32
// CHECK: [[v62:%.*]] = llvm.extractvalue [[v5]][3, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
- // CHECK: [[v63:%.*]] = llvm.trunc [[v62]] : i64 to i32
+ // CHECK: [[v63a:%.*]] = llvm.trunc [[v62]] : i64 to i32
+ // CHECK: [[v63:%.*]] = llvm.mul [[v63a]]
// CHECK: [[v64:%.*]] = llvm.extractvalue [[v5]][1] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
// CHECK: [[v65:%.*]] = llvm.extractvalue [[v5]][2] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
// CHECK: [[v66:%.*]] = llvm.getelementptr [[v64]][[[v65]]] : (!llvm.ptr, i64) -> !llvm.ptr, f32
// CHECK: [[v67:%.*]] = llvm.extractvalue [[v5]][3, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
- // CHECK: [[v68:%.*]] = llvm.trunc [[v67]] : i64 to i32
+ // CHECK: [[v68a:%.*]] = llvm.trunc [[v67]] : i64 to i32
+ // CHECK: [[v68:%.*]] = llvm.mul [[v68a]]
// CHECK: [[ip:%.*]] = llvm.mlir.constant(-1 : i64) : i64
// CHECK: [[ipp:%.*]] = llvm.inttoptr [[ip]] : i64 to !llvm.ptr
// CHECK: [[v69:%.*]] = llvm.mlir.constant(1275069450 : i32) : i32
@@ -127,7 +133,7 @@ module attributes {dlti.map = #dlti.map<"MPI:Implementation" = "MPICH">} {
// CHECK: llvm.func @MPI_Comm_split(!llvm.ptr, i32, i32, !llvm.ptr) -> i32
// CHECK: llvm.func @MPI_Allreduce(!llvm.ptr, !llvm.ptr, i32, !llvm.ptr, !llvm.ptr, !llvm.ptr) -> i32
// CHECK: llvm.mlir.global external @ompi_mpi_sum() {addr_space = 0 : i32} : !llvm.struct<"ompi_predefined_op_t", opaque>
-// CHECK: llvm.func @MPI_Comm_Size(!llvm.ptr, !llvm.ptr) -> i32
+// CHECK: llvm.func @MPI_Comm_size(!llvm.ptr, !llvm.ptr) -> i32
// CHECK: llvm.func @MPI_Allgather(!llvm.ptr, i32, !llvm.ptr, !llvm.ptr, i32, !llvm.ptr, !llvm.ptr) -> i32
// CHECK: llvm.func @MPI_Recv(!llvm.ptr, i32, !llvm.ptr, i32, i32, !llvm.ptr, !llvm.ptr) -> i32
// CHECK: llvm.func @MPI_Send(!llvm.ptr, i32, !llvm.ptr, i32, i32, !llvm.ptr) -> i32
@@ -164,7 +170,8 @@ module attributes { dlti.map = #dlti.map<"MPI:Implementation" = "OpenMPI"> } {
// CHECK: [[v14:%.*]] = llvm.extractvalue [[v5]][2] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
// CHECK: [[v15:%.*]] = llvm.getelementptr [[v13]][[[v14]]] : (!llvm.ptr, i64) -> !llvm.ptr, f32
// CHECK: [[v16:%.*]] = llvm.extractvalue [[v5]][3, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
- // CHECK: [[v17:%.*]] = llvm.trunc [[v16]] : i64 to i32
+ // CHECK: [[v17a:%.*]] = llvm.trunc [[v16]] : i64 to i32
+ // CHECK: [[v17:%.*]] = llvm.mul [[v17a]]
// CHECK: [[v18:%.*]] = llvm.mlir.addressof @ompi_mpi_float : !llvm.ptr
// CHECK: [[v19:%.*]] = llvm.inttoptr [[comm]] : i64 to !llvm.ptr
// CHECK: [[v20:%.*]] = llvm.call @MPI_Send([[v15]], [[v17]], [[v18]], [[v12]], [[v12]], [[v19]]) : (!llvm.ptr, i32, !llvm.ptr, i32, i32, !llvm.ptr) -> i32
@@ -174,7 +181,8 @@ module attributes { dlti.map = #dlti.map<"MPI:Implementation" = "OpenMPI"> } {
// CHECK: [[v22:%.*]] = llvm.extractvalue [[v5]][2] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
// CHECK: [[v23:%.*]] = llvm.getelementptr [[v21]][[[v22]]] : (!llvm.ptr, i64) -> !llvm.ptr, f32
// CHECK: [[v24:%.*]] = llvm.extractvalue [[v5]][3, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
- // CHECK: [[v25:%.*]] = llvm.trunc [[v24]] : i64 to i32
+ // CHECK: [[v25a:%.*]] = llvm.trunc [[v24]] : i64 to i32
+ // CHECK: [[v25:%.*]] = llvm.mul [[v25a]]
// CHECK: [[v26:%.*]] = llvm.mlir.addressof @ompi_mpi_float : !llvm.ptr
// CHECK: [[v27:%.*]] = llvm.inttoptr [[comm]] : i64 to !llvm.ptr
// CHECK: [[v28:%.*]] = llvm.call @MPI_Send([[v23]], [[v25]], [[v26]], [[v12]], [[v12]], [[v27]]) : (!llvm.ptr, i32, !llvm.ptr, i32, i32, !llvm.ptr) -> i32
@@ -184,7 +192,8 @@ module attributes { dlti.map = #dlti.map<"MPI:Implementation" = "OpenMPI"> } {
// CHECK: [[v30:%.*]] = llvm.extractvalue [[v5]][2] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
// CHECK: [[v31:%.*]] = llvm.getelementptr [[v29]][[[v30]]] : (!llvm.ptr, i64) -> !llvm.ptr, f32
// CHECK: [[v32:%.*]] = llvm.extractvalue [[v5]][3, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
- // CHECK: [[v33:%.*]] = llvm.trunc [[v32]] : i64 to i32
+ // CHECK: [[v33a:%.*]] = llvm.trunc [[v32]] : i64 to i32
+ // CHECK: [[v33:%.*]] = llvm.mul [[v33a]]
// CHECK: [[v34:%.*]] = llvm.mlir.addressof @ompi_mpi_float : !llvm.ptr
// CHECK: [[v35:%.*]] = llvm.inttoptr [[comm]] : i64 to !llvm.ptr
// CHECK: [[v36:%.*]] = llvm.mlir.constant(0 : i64) : i64
@@ -196,7 +205,8 @@ module attributes { dlti.map = #dlti.map<"MPI:Implementation" = "OpenMPI"> } {
// CHECK: [[v40:%.*]] = llvm.extractvalue [[v5]][2] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
// CHECK: [[v41:%.*]] = llvm.getelementptr [[v39]][[[v40]]] : (!llvm.ptr, i64) -> !llvm.ptr, f32
// CHECK: [[v42:%.*]] = llvm.extractvalue [[v5]][3, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
- // CHECK: [[v43:%.*]] = llvm.trunc [[v42]] : i64 to i32
+ // CHECK: [[v43a:%.*]] = llvm.trunc [[v42]] : i64 to i32
+ // CHECK: [[v43:%.*]] = llvm.mul [[v43a]]
// CHECK: [[v44:%.*]] = llvm.mlir.addressof @ompi_mpi_float : !llvm.ptr
// CHECK: [[v45:%.*]] = llvm.inttoptr [[comm]] : i64 to !llvm.ptr
// CHECK: [[v46:%.*]] = llvm.mlir.constant(0 : i64) : i64
@@ -204,7 +214,7 @@ module attributes { dlti.map = #dlti.map<"MPI:Implementation" = "OpenMPI"> } {
// CHECK: [[v48:%.*]] = llvm.call @MPI_Recv([[v41]], [[v43]], [[v44]], [[v12]], [[v12]], [[v45]], [[v47]]) : (!llvm.ptr, i32, !llvm.ptr, i32, i32, !llvm.ptr, !llvm.ptr) -> i32
%2 = mpi.recv(%arg0, %rank, %rank, %comm) : memref<100xf32>, i32, i32 -> !mpi.retval
- // CHECK: llvm.call @MPI_Comm_Size({{.*}}) : (!llvm.ptr, !llvm.ptr) -> i32
+ // CHECK: llvm.call @MPI_Comm_size({{.*}}) : (!llvm.ptr, !llvm.ptr) -> i32
// CHECK: llvm.udiv {{.*}} : i32
// CHECK: llvm.call @MPI_Allgather({{.*}} : (!llvm.ptr, i32, !llvm.ptr, !llvm.ptr, i32, !llvm.ptr, !llvm.ptr) -> i32
mpi.allgather(%arg0, %arg0, %comm) : memref<100xf32>, memref<100xf32>
@@ -213,12 +223,14 @@ module attributes { dlti.map = #dlti.map<"MPI:Implementation" = "OpenMPI"> } {
// CHECK: [[v50:%.*]] = llvm.extractvalue [[v5]][2] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
// CHECK: [[v51:%.*]] = llvm.getelementptr [[v49]][[[v50]]] : (!llvm.ptr, i64) -> !llvm.ptr, f32
// CHECK: [[v52:%.*]] = llvm.extractvalue [[v5]][3, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
- // CHECK: [[v53:%.*]] = llvm.trunc [[v52]] : i64 to i32
+ // CHECK: [[v53a:%.*]] = llvm.trunc [[v52]] : i64 to i32
+ // CHECK: [[v53:%.*]] = llvm.mul [[v53a]]
// CHECK: [[v54:%.*]] = llvm.extractvalue [[v5]][1] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
// CHECK: [[v55:%.*]] = llvm.extractvalue [[v5]][2] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
// CHECK: [[v56:%.*]] = llvm.getelementptr [[v54]][[[v55]]] : (!llvm.ptr, i64) -> !llvm.ptr, f32
// CHECK: [[v57:%.*]] = llvm.extractvalue [[v5]][3, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
- // CHECK: [[v58:%.*]] = llvm.trunc [[v57]] : i64 to i32
+ // CHECK: [[v58a:%.*]] = llvm.trunc [[v57]] : i64 to i32
+ // CHECK: [[v58:%.*]] = llvm.mul [[v58a]]
// CHECK: [[ip:%.*]] = llvm.mlir.constant(1 : i64) : i64
// CHECK: [[ipp:%.*]] = llvm.inttoptr [[ip]] : i64 to !llvm.ptr
// CHECK: [[v59:%.*]] = llvm.mlir.addressof @ompi_mpi_float : !llvm.ptr
diff --git a/mlir/test/Conversion/ShardToMPI/convert-shard-to-mpi.mlir b/mlir/test/Conversion/ShardToMPI/convert-shard-to-mpi.mlir
index 9a8ad5eea1c7b..200ef488a1a1b 100644
--- a/mlir/test/Conversion/ShardToMPI/convert-shard-to-mpi.mlir
+++ b/mlir/test/Conversion/ShardToMPI/convert-shard-to-mpi.mlir
@@ -155,14 +155,28 @@ module attributes { mpi.dlti = #dlti.map<"MPI:comm_world_rank" = 7> } {
%arg0 : tensor<3x4xf32>) -> tensor<3x20xf32> {
// CHECK-DAG: [[vc2_i32:%.*]] = arith.constant 2 : i32
// CHECK-DAG: [[vc1_i32:%.*]] = arith.constant 1 : i32
+ // CHECK-DAG: [[vc20:%.*]] = arith.constant 20 : index
// CHECK: [[v0:%.*]] = bufferization.to_buffer [[varg0]] : tensor<3x4xf32> to memref<3x4xf32>
// CHECK: [[v1:%.*]] = mpi.comm_world : !mpi.comm
// CHECK: [[vnewcomm:%.*]] = mpi.comm_split([[v1]], [[vc1_i32]], [[vc2_i32]]) : !mpi.comm
- // CHECK: [[valloc:%.*]] = memref.alloc() : memref<3x20xf32>
- // CHECK: mpi.allgather([[v0]], [[valloc]], [[vnewcomm]]) : memref<3x4xf32>, memref<3x20xf32>
- // CHECK: [[v2:%.*]] = bufferization.to_tensor [[valloc]] restrict : memref<3x20xf32> to tensor<3x20xf32>
+ // CHECK: [[vsize:%.*]] = mpi.comm_size([[vnewcomm]]) : i32
+ // CHECK: [[v2:%.*]] = arith.index_cast [[vsize]] : i32 to index
+ // CHECK: [[v3:%.*]] = arith.divsi [[vc20]], [[v2]] : index
+ // CHECK: [[valloc:%.*]] = memref.alloc([[v2]], [[v3]]) : memref<?x3x?xf32>
+ // CHECK: mpi.allgather([[v0]], [[valloc]], [[vnewcomm]]) : memref<3x4xf32>, memref<?x3x?xf32>
+ // CHECK: [[valloc_0:%.*]] = memref.alloc() : memref<3x20xf32>
+ // CHECK: affine.for [[varg1:%.*]] = 0 to [[v2]] {
+ // CHECK: affine.for [[varg2:%.*]] = 0 to 3 {
+ // CHECK: affine.for [[varg3:%.*]] = 0 to [[v3]] {
+ // CHECK: [[v5:%.*]] = memref.load [[valloc]][[[varg1]], [[varg2]], [[varg3]]] : memref<?x3x?xf32>
+ // CHECK: affine.store [[v5]], [[valloc_0]][[[varg2]], [[varg1]] * symbol([[v3]]) + [[varg3]]] : memref<3x20xf32>
+ // CHECK: }
+ // CHECK: }
+ // CHECK: }
+ // CHECK: memref.dealloc [[valloc]] : memref<?x3x?xf32>
+ // CHECK: [[v4:%.*]] = bufferization.to_tensor [[valloc_0]] restrict : memref<3x20xf32> to tensor<3x20xf32>
%0 = shard.all_gather %arg0 on @grid0 grid_axes = [2] gather_axis = 1 : tensor<3x4xf32> -> tensor<3x20xf32>
- // CHECK: return [[v2]] : tensor<3x20xf32>
+ // CHECK: return [[v4]] : tensor<3x20xf32>
return %0 : tensor<3x20xf32>
}
@@ -173,12 +187,26 @@ module attributes { mpi.dlti = #dlti.map<"MPI:comm_world_rank" = 7> } {
%arg0 : memref<3x4xf32>) -> memref<3x20xf32> {
// CHECK-DAG: [[vc1_i32:%.*]] = arith.constant 1 : i32
// CHECK-DAG: [[vc2_i32:%.*]] = arith.constant 2 : i32
+ // CHECK-DAG: [[vc20:%.*]] = arith.constant 20 : index
// CHECK: [[v0:%.*]] = mpi.comm_world : !mpi.comm
// CHECK: [[vnewcomm:%.*]] = mpi.comm_split([[v0]], [[vc1_i32]], [[vc2_i32]]) : !mpi.comm
- // CHECK: [[valloc:%.*]] = memref.alloc() : memref<3x20xf32>
- // CHECK: mpi.allgather([[varg0]], [[valloc]], [[vnewcomm]]) : memref<3x4xf32>, memref<3x20xf32>
+ // CHECK: [[vsize:%.*]] = mpi.comm_size([[vnewcomm]]) : i32
+ // CHECK: [[v1:%.*]] = arith.index_cast [[vsize]] : i32 to index
+ // CHECK: [[v2:%.*]] = arith.divsi [[vc20]], [[v1]] : index
+ // CHECK: [[valloc:%.*]] = memref.alloc([[v1]], [[v2]]) : memref<?x3x?xf32>
+ // CHECK: mpi.allgather([[varg0]], [[valloc]], [[vnewcomm]]) : memref<3x4xf32>, memref<?x3x?xf32>
+ // CHECK: [[valloc_0:%.*]] = memref.alloc() : memref<3x20xf32>
+ // CHECK: affine.for [[varg1:%.*]] = 0 to [[v1]] {
+ // CHECK: affine.for [[varg2:%.*]] = 0 to 3 {
+ // CHECK: affine.for [[varg3:%.*]] = 0 to [[v2]] {
+ // CHECK: [[v3:%.*]] = memref.load [[valloc]][[[varg1]], [[varg2]], [[varg3]]] : memref<?x3x?xf32>
+ // CHECK: affine.store [[v3]], [[valloc_0]][[[varg2]], [[varg1]] * symbol([[v2]]) + [[varg3]]] : memref<3x20xf32>
+ // CHECK: }
+ // CHECK: }
+ // CHECK: }
+ // CHECK: memref.dealloc [[valloc]] : memref<?x3x?xf32>
%0 = shard.all_gather %arg0 on @grid0 grid_axes = [2] gather_axis = 1 : memref<3x4xf32> -> memref<3x20xf32>
- // CHECK: return [[valloc]] : memref<3x20xf32>
+ // CHECK: return [[valloc_0]] : memref<3x20xf32>
return %0 : memref<3x20xf32>
}
}
diff --git a/mlir/test/Dialect/MPI/canonicalize.mlir b/mlir/test/Dialect/MPI/canonicalize.mlir
new file mode 100644
index 0000000000000..43787a7fad23f
--- /dev/null
+++ b/mlir/test/Dialect/MPI/canonicalize.mlir
@@ -0,0 +1,14 @@
+// RUN: mlir-opt %s -canonicalize | mlir-opt | FileCheck %s
+
+module attributes {mpi.dlti = #dlti.map<"MPI:comm_world_size" = 12, "MPI:comm_world_rank" = 5> } {
+ // mpi.comm_size / mpi.comm_rank on mpi.comm_world must fold to the values of
+ // the "MPI:comm_world_size" / "MPI:comm_world_rank" entries in mpi.dlti.
+ // The relative order in which the canonicalizer materializes the two
+ // constants is not part of the contract, so they are matched unordered.
+ // CHECK-LABEL: func.func @mpi_test
+ func.func @mpi_test(%ref : memref<100xf32>) -> (i32, i32) {
+ %comm = mpi.comm_world : !mpi.comm
+ // CHECK-DAG: [[s:%.*]] = arith.constant 12 : i32
+ %sz = mpi.comm_size(%comm) : i32
+ // CHECK-DAG: [[r:%.*]] = arith.constant 5 : i32
+ %rk = mpi.comm_rank(%comm) : i32
+ // CHECK: return [[s]], [[r]] : i32, i32
+ return %sz, %rk : i32, i32
+ }
+}
\ No newline at end of file
>From 7bd1a8eea04bd722903e5ff4f035964aaa17bd45 Mon Sep 17 00:00:00 2001
From: "Schlimbach, Frank" <frank.schlimbach at intel.com>
Date: Fri, 30 Jan 2026 03:14:28 -0800
Subject: [PATCH 2/4] adding mlp tests for partition and shardtompi
---
mlir/lib/Conversion/ShardToMPI/ShardToMPI.cpp | 2 +-
.../ShardToMPI/convert-shard-to-mpi.mlir | 58 +++++++++++++++++++
mlir/test/Dialect/MPI/canonicalize.mlir | 2 +-
mlir/test/Dialect/Shard/partition.mlir | 56 +++++++++++++++++-
4 files changed, 114 insertions(+), 4 deletions(-)
diff --git a/mlir/lib/Conversion/ShardToMPI/ShardToMPI.cpp b/mlir/lib/Conversion/ShardToMPI/ShardToMPI.cpp
index 45e07a1ce5fb9..c765ad5a579c8 100644
--- a/mlir/lib/Conversion/ShardToMPI/ShardToMPI.cpp
+++ b/mlir/lib/Conversion/ShardToMPI/ShardToMPI.cpp
@@ -621,7 +621,7 @@ struct ConvertAllReduceOp : public CommOpPattern<AllReduceOp> {
struct ConvertAllGatherOp : public CommOpPattern<AllGatherOp> {
using CommOpPattern::CommOpPattern;
- // shard.allgather always concatenates along a specified gather-axis.
+ // shard.allgather concatenates along a specified gather-axis.
// mpi.allgather always concatenates along the first dimension and
// there is no MPI operation that allows gathering along an arbitrary axis.
// Hence, if gather-axis!=0, we need to create a temporary buffer
diff --git a/mlir/test/Conversion/ShardToMPI/convert-shard-to-mpi.mlir b/mlir/test/Conversion/ShardToMPI/convert-shard-to-mpi.mlir
index 200ef488a1a1b..4ac4a69dd5b18 100644
--- a/mlir/test/Conversion/ShardToMPI/convert-shard-to-mpi.mlir
+++ b/mlir/test/Conversion/ShardToMPI/convert-shard-to-mpi.mlir
@@ -438,3 +438,61 @@ func.func @return_sharding_offs(%arg0: tensor<?x?xf32>) -> (tensor<?x?xf32>, !sh
// CHECK: return [[varg0]], [[vcast]], [[vcast_6]], [[vcast_7]] : tensor<?x?xf32>, tensor<?x?xi16>, tensor<?x?xi64>, tensor<?x?xi64>
return %arg0, %sharding : tensor<?x?xf32>, !shard.sharding
}
+
+// -----
+// 1-D MLP on a 4-process grid, presumably the partitioned form of
+// @mlp_1dgrid in partition.mlir. Verifies that shard.all_gather with
+// gather_axis = 1 lowers to an mpi.allgather into a temporary ?x512x? buffer
+// (leading dim = comm size, trailing dim from a divsi by it), followed by an
+// affine copy-loop nest into the 512x2048 result and a dealloc of the
+// temporary. Also verifies that shard.all_reduce lowers to a copy into a fresh
+// buffer plus an in-place mpi.allreduce with MPI_SUM.
+shard.grid @grid_1d_4(shape = 4)
+// CHECK-LABEL: func.func @mlp_1dgrid(
+// CHECK-SAME: [[varg0:%.*]]: tensor<512x512xf32>, [[varg1:%.*]]: tensor<2048x256xf32>, [[varg2:%.*]]: tensor<256x2048xf32>) -> tensor<512x2048xf32>
+func.func @mlp_1dgrid(%arg0: tensor<512x512xf32>, %arg1: tensor<2048x256xf32>, %arg2: tensor<256x2048xf32>) -> tensor<512x2048xf32> attributes {llvm.emit_c_interface} {
+ // CHECK: [[vcst:%.*]] = arith.constant 0.000000e+00 : f32
+ %cst = arith.constant 0.000000e+00 : f32
+ %c0 = arith.constant 0 : index
+ // CHECK: [[v0:%.*]] = bufferization.to_buffer [[varg0]] : tensor<512x512xf32> to memref<512x512xf32>
+ // CHECK: [[v1:%.*]] = mpi.comm_world : !mpi.comm
+ // CHECK: [[vsize:%.*]] = mpi.comm_size
+ // CHECK: [[v2:%.*]] = arith.index_cast [[vsize]] : i32 to index
+ // CHECK: [[v3:%.*]] = arith.divsi
+ // CHECK: [[valloc:%.*]] = memref.alloc([[v2]], [[v3]]) : memref<?x512x?xf32>
+ // CHECK: mpi.allgather([[v0]], [[valloc]], [[v1]]) : memref<512x512xf32>, memref<?x512x?xf32>
+ // CHECK: [[valloc_0:%.*]] = memref.alloc() : memref<512x2048xf32>
+ // CHECK: affine.for [[varg3:%.*]] = 0 to [[v2]] {
+ // CHECK: affine.for [[varg4:%.*]] = 0 to 512 {
+ // CHECK: affine.for [[varg5:%.*]] = 0 to [[v3]] {
+ // CHECK: [[v19:%.*]] = memref.load [[valloc]][[[varg3]], [[varg4]], [[varg5]]] : memref<?x512x?xf32>
+ // CHECK: affine.store [[v19]], [[valloc_0]][[[varg4]], [[varg3]] * symbol([[v3]]) + [[varg5]]] : memref<512x2048xf32>
+ // CHECK: memref.dealloc [[valloc]] : memref<?x512x?xf32>
+ // CHECK: [[v4:%.*]] = bufferization.to_tensor [[valloc_0]] restrict : memref<512x2048xf32> to tensor<512x2048xf32>
+ %all_gather = shard.all_gather %arg0 on @grid_1d_4 grid_axes = [0] gather_axis = 1 : tensor<512x512xf32> -> tensor<512x2048xf32>
+ // CHECK: [[v5:%.*]] = tensor.empty() : tensor<512x256xf32>
+ %0 = tensor.empty() : tensor<512x256xf32>
+ // CHECK: [[v6:%.*]] = linalg.fill ins([[vcst]] : f32) outs([[v5]] : tensor<512x256xf32>) -> tensor<512x256xf32>
+ %1 = linalg.fill ins(%cst : f32) outs(%0 : tensor<512x256xf32>) -> tensor<512x256xf32>
+ // CHECK: [[v7:%.*]] = linalg.matmul ins([[v4]], [[varg1]] : tensor<512x2048xf32>, tensor<2048x256xf32>) outs([[v6]] : tensor<512x256xf32>) -> tensor<512x256xf32>
+ %2 = linalg.matmul ins(%all_gather, %arg1 : tensor<512x2048xf32>, tensor<2048x256xf32>) outs(%1 : tensor<512x256xf32>) -> tensor<512x256xf32>
+ // CHECK: [[v8:%.*]] = tosa.sigmoid [[v7]] : (tensor<512x256xf32>) -> tensor<512x256xf32>
+ %3 = tosa.sigmoid %2 : (tensor<512x256xf32>) -> tensor<512x256xf32>
+ %4 = tensor.empty() : tensor<512x2048xf32>
+ %5 = linalg.fill ins(%cst : f32) outs(%4 : tensor<512x2048xf32>) -> tensor<512x2048xf32>
+ // NOTE(review): %grid_shape below has no uses in this function.
+ %proc_linear_idx = shard.process_multi_index on @grid_1d_4 axes = [0] : index
+ %grid_shape = shard.grid_shape @grid_1d_4 axes = [0] : index
+ %6 = arith.cmpi eq, %proc_linear_idx, %c0 : index
+ // CHECK: [[v14:%.*]] = scf.if
+ %7 = scf.if %6 -> (tensor<512x2048xf32>) {
+ scf.yield %5 : tensor<512x2048xf32>
+ } else {
+ %9 = tensor.empty() : tensor<512x2048xf32>
+ %10 = linalg.fill ins(%cst : f32) outs(%9 : tensor<512x2048xf32>) -> tensor<512x2048xf32>
+ scf.yield %10 : tensor<512x2048xf32>
+ }
+ // CHECK: [[v15:%.*]] = linalg.matmul ins([[v8]], [[varg2]] : tensor<512x256xf32>, tensor<256x2048xf32>) outs([[v14]] : tensor<512x2048xf32>) -> tensor<512x2048xf32>
+ %8 = linalg.matmul ins(%3, %arg2 : tensor<512x256xf32>, tensor<256x2048xf32>) outs(%7 : tensor<512x2048xf32>) -> tensor<512x2048xf32>
+ // CHECK: [[v16:%.*]] = bufferization.to_buffer
+ // CHECK: [[valloc_1:%.*]] = memref.alloc() : memref<512x2048xf32>
+ // CHECK: linalg.copy ins([[v16]] : memref<512x2048xf32>) outs([[valloc_1]] : memref<512x2048xf32>)
+ // CHECK: [[v17:%.*]] = mpi.comm_world : !mpi.comm
+ // CHECK: mpi.allreduce([[valloc_1]], [[valloc_1]], MPI_SUM, [[v17]]) : memref<512x2048xf32>, memref<512x2048xf32>
+ // CHECK: [[v18:%.*]] = bufferization.to_tensor [[valloc_1]] restrict : memref<512x2048xf32> to tensor<512x2048xf32>
+ %all_reduce = shard.all_reduce %8 on @grid_1d_4 grid_axes = [0] : tensor<512x2048xf32> -> tensor<512x2048xf32>
+ // CHECK: return [[v18]] : tensor<512x2048xf32>
+ return %all_reduce : tensor<512x2048xf32>
+}
diff --git a/mlir/test/Dialect/MPI/canonicalize.mlir b/mlir/test/Dialect/MPI/canonicalize.mlir
index 43787a7fad23f..3523d46e21219 100644
--- a/mlir/test/Dialect/MPI/canonicalize.mlir
+++ b/mlir/test/Dialect/MPI/canonicalize.mlir
@@ -11,4 +11,4 @@ module attributes {mpi.dlti = #dlti.map<"MPI:comm_world_size" = 12, "MPI:comm_wo
// CHECK: return [[s]], [[r]] : i32, i32
return %sz, %rk : i32, i32
}
-}
\ No newline at end of file
+}
diff --git a/mlir/test/Dialect/Shard/partition.mlir b/mlir/test/Dialect/Shard/partition.mlir
index 0f293a39608e3..cd9fa2215e0ee 100644
--- a/mlir/test/Dialect/Shard/partition.mlir
+++ b/mlir/test/Dialect/Shard/partition.mlir
@@ -3,6 +3,7 @@
// RUN: %s | FileCheck %s
shard.grid @grid_1d(shape = 2)
+shard.grid @grid_1d_4(shape = 4)
// CHECK-LABEL: func @return_sharding
func.func @return_sharding(
@@ -204,8 +205,6 @@ func.func @incomplete_sharding(
return %3 : tensor<8x16xf32>
}
-shard.grid @grid_1d_4(shape = 4)
-
// CHECK-LABEL: func @ew_chain_with_halo
func.func @ew_chain_with_halo(
// CHECK-SAME: %[[IN1:[A-Za-z0-9_]+]]: tensor<5x16xf32>
@@ -318,3 +317,56 @@ func.func @test_reduce_1d(%arg0: tensor<6x6xi32>) -> (tensor<6xi32>) {
// CHECK: return %[[reduced]] : tensor<3xi32>
return %sharded_ret : tensor<6xi32>
}
+
+// 1-D MLP partitioned over @grid_1d_4 (4 processes).
+// - %arg0 is split on dim 1 (split_axes = [[], [0]]). It is all_gathered
+//   along gather_axis = 1 before the first matmul, whose lhs is annotated
+//   as replicated.
+// - The second matmul contracts over dim 1024, which is split on both
+//   operands. Its replicated output therefore holds partial sums and is
+//   combined with an all_reduce.
+// - The expected function signature uses the per-process (local) shapes,
+//   i.e. each split dim divided by 4.
+// CHECK-LABEL: func.func @mlp_1dgrid
+// CHECK-SAME: [[varg0:%.*]]: tensor<512x512xf32>, [[varg1:%.*]]: tensor<2048x256xf32>, [[varg2:%.*]]: tensor<256x2048xf32>) -> tensor<512x2048xf32>
+func.func @mlp_1dgrid(%arg0: tensor<512x2048xf32>, %arg1: tensor<2048x1024xf32>, %arg2: tensor<1024x2048xf32>) -> tensor<512x2048xf32> attributes {llvm.emit_c_interface} {
+ // CHECK: [[vcst:%.*]] = arith.constant 0.000000e+00 : f32
+ // NOTE(review): %sharding_2/%sharding_3 duplicate %sharding, %sharding_4 duplicates %sharding_0 and %sharding_5 duplicates %sharding_1; harmless, presumably generated input.
+ %sharding = shard.sharding @grid_1d_4 split_axes = [[], [0]] : !shard.sharding
+ %sharding_0 = shard.sharding @grid_1d_4 split_axes = [[0], []] : !shard.sharding
+ %sharding_1 = shard.sharding @grid_1d_4 split_axes = [[]] : !shard.sharding
+ %sharding_2 = shard.sharding @grid_1d_4 split_axes = [[], [0]] : !shard.sharding
+ %sharding_3 = shard.sharding @grid_1d_4 split_axes = [[], [0]] : !shard.sharding
+ %sharding_4 = shard.sharding @grid_1d_4 split_axes = [[0], []] : !shard.sharding
+ %sharding_5 = shard.sharding @grid_1d_4 split_axes = [[]] : !shard.sharding
+ %sharding_annotated = shard.shard %arg0 to %sharding_2 : tensor<512x2048xf32>
+ %sharding_annotated_6 = shard.shard %arg1 to %sharding_3 : tensor<2048x1024xf32>
+ %sharding_annotated_7 = shard.shard %arg2 to %sharding_4 : tensor<1024x2048xf32>
+ // CHECK-DAG: [[v0:%.*]] = tensor.empty() : tensor<512x256xf32>
+ %0 = tensor.empty() : tensor<512x1024xf32>
+ %sharding_annotated_8 = shard.shard %0 to %sharding : tensor<512x1024xf32>
+ %cst = arith.constant 0.000000e+00 : f32
+ %sharding_annotated_9 = shard.shard %sharding_annotated_8 to %sharding annotate_for_users : tensor<512x1024xf32>
+ // CHECK-DAG: [[v1:%.*]] = linalg.fill ins([[vcst]] : f32) outs([[v0]] : tensor<512x256xf32>) -> tensor<512x256xf32>
+ %1 = linalg.fill ins(%cst : f32) outs(%sharding_annotated_9 : tensor<512x1024xf32>) -> tensor<512x1024xf32>
+ %sharding_annotated_10 = shard.shard %1 to %sharding : tensor<512x1024xf32>
+ // CHECK-DAG: [[vall_gather:%.*]] = shard.all_gather [[varg0]] on @grid_1d_4 grid_axes = [0] gather_axis = 1 : tensor<512x512xf32> -> tensor<512x2048xf32>
+ %sharding_annotated_11 = shard.shard %sharding_annotated to %sharding_1 annotate_for_users : tensor<512x2048xf32>
+ %sharding_annotated_12 = shard.shard %sharding_annotated_6 to %sharding annotate_for_users : tensor<2048x1024xf32>
+ %sharding_annotated_13 = shard.shard %sharding_annotated_10 to %sharding annotate_for_users : tensor<512x1024xf32>
+ // CHECK: [[v2:%.*]] = linalg.matmul ins([[vall_gather]], [[varg1]] : tensor<512x2048xf32>, tensor<2048x256xf32>) outs([[v1]] : tensor<512x256xf32>) -> tensor<512x256xf32>
+ %2 = linalg.matmul ins(%sharding_annotated_11, %sharding_annotated_12 : tensor<512x2048xf32>, tensor<2048x1024xf32>) outs(%sharding_annotated_13 : tensor<512x1024xf32>) -> tensor<512x1024xf32>
+ %sharding_annotated_14 = shard.shard %2 to %sharding : tensor<512x1024xf32>
+ %sharding_annotated_15 = shard.shard %sharding_annotated_14 to %sharding annotate_for_users : tensor<512x1024xf32>
+ // CHECK: [[v3:%.*]] = tosa.sigmoid [[v2]] : (tensor<512x256xf32>) -> tensor<512x256xf32>
+ %3 = tosa.sigmoid %sharding_annotated_15 : (tensor<512x1024xf32>) -> tensor<512x1024xf32>
+ %sharding_annotated_16 = shard.shard %3 to %sharding : tensor<512x1024xf32>
+ // CHECK: [[v9:%.*]] = tensor.empty() : tensor<512x2048xf32>
+ %4 = tensor.empty() : tensor<512x2048xf32>
+ %sharding_annotated_17 = shard.shard %4 to %sharding_1 : tensor<512x2048xf32>
+ %sharding_annotated_18 = shard.shard %sharding_annotated_17 to %sharding_1 annotate_for_users : tensor<512x2048xf32>
+ // CHECK: [[v10:%.*]] = linalg.fill ins([[vcst]] : f32) outs([[v9]] : tensor<512x2048xf32>) -> tensor<512x2048xf32>
+ %5 = linalg.fill ins(%cst : f32) outs(%sharding_annotated_18 : tensor<512x2048xf32>) -> tensor<512x2048xf32>
+ %sharding_annotated_19 = shard.shard %5 to %sharding_1 : tensor<512x2048xf32>
+ %sharding_annotated_20 = shard.shard %sharding_annotated_16 to %sharding annotate_for_users : tensor<512x1024xf32>
+ %sharding_annotated_21 = shard.shard %sharding_annotated_7 to %sharding_0 annotate_for_users : tensor<1024x2048xf32>
+ %sharding_annotated_22 = shard.shard %sharding_annotated_19 to %sharding_1 annotate_for_users : tensor<512x2048xf32>
+ // CHECK: [[v7:%.*]] = scf.if
+ // CHECK: [[v8:%.*]] = linalg.matmul ins([[v3]], [[varg2]] : tensor<512x256xf32>, tensor<256x2048xf32>) outs([[v7]] : tensor<512x2048xf32>) -> tensor<512x2048xf32>
+ %6 = linalg.matmul ins(%sharding_annotated_20, %sharding_annotated_21 : tensor<512x1024xf32>, tensor<1024x2048xf32>) outs(%sharding_annotated_22 : tensor<512x2048xf32>) -> tensor<512x2048xf32>
+ %sharding_annotated_23 = shard.shard %6 to %sharding_1 : tensor<512x2048xf32>
+ // CHECK: [[vall_reduce:%.*]] = shard.all_reduce [[v8]] on @grid_1d_4 grid_axes = [0] : tensor<512x2048xf32> -> tensor<512x2048xf32>
+ %sharding_annotated_24 = shard.shard %sharding_annotated_23 to %sharding_5 annotate_for_users : tensor<512x2048xf32>
+ // CHECK: return [[vall_reduce]] : tensor<512x2048xf32>
+ return %sharding_annotated_24 : tensor<512x2048xf32>
+}
>From 1fd20a067462dbe2bd679bcc0e5c6d243535e7b9 Mon Sep 17 00:00:00 2001
From: "Schlimbach, Frank" <frank.schlimbach at intel.com>
Date: Fri, 30 Jan 2026 03:34:16 -0800
Subject: [PATCH 3/4] delete redundant *&
---
mlir/lib/Conversion/MPIToLLVM/MPIToLLVM.cpp | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/mlir/lib/Conversion/MPIToLLVM/MPIToLLVM.cpp b/mlir/lib/Conversion/MPIToLLVM/MPIToLLVM.cpp
index 9ddef319ad71c..548481250e4f6 100644
--- a/mlir/lib/Conversion/MPIToLLVM/MPIToLLVM.cpp
+++ b/mlir/lib/Conversion/MPIToLLVM/MPIToLLVM.cpp
@@ -402,7 +402,7 @@ class OMPIImplTraits : public MPIImplTraits {
};
std::unique_ptr<MPIImplTraits> MPIImplTraits::get(ModuleOp &moduleOp) {
- auto attr = dlti::query(*&moduleOp, {"MPI:Implementation"}, false);
+ auto attr = dlti::query(moduleOp, {"MPI:Implementation"}, false);
if (failed(attr))
return std::make_unique<MPICHImplTraits>(moduleOp);
auto strAttr = dyn_cast<StringAttr>(attr.value());
>From 0449fa9a6b85d5f71295c2ebc28bb19b83ae8550 Mon Sep 17 00:00:00 2001
From: "Schlimbach, Frank" <frank.schlimbach at intel.com>
Date: Fri, 30 Jan 2026 09:12:14 -0800
Subject: [PATCH 4/4] folding CommSize
---
mlir/include/mlir/Dialect/MPI/IR/Utils.h | 43 +++++++++++++++++++
mlir/lib/Conversion/MPIToLLVM/MPIToLLVM.cpp | 15 ++++++-
mlir/lib/Dialect/MPI/IR/MPIOps.cpp | 27 +-----------
mlir/test/Conversion/MPIToLLVM/mpitollvm.mlir | 16 +++++++
4 files changed, 73 insertions(+), 28 deletions(-)
create mode 100644 mlir/include/mlir/Dialect/MPI/IR/Utils.h
diff --git a/mlir/include/mlir/Dialect/MPI/IR/Utils.h b/mlir/include/mlir/Dialect/MPI/IR/Utils.h
new file mode 100644
index 0000000000000..9ff78e9e092fc
--- /dev/null
+++ b/mlir/include/mlir/Dialect/MPI/IR/Utils.h
@@ -0,0 +1,43 @@
+//===- Utils.h - MPI dialect --------------------------------------*- C++-*-==//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+#ifndef MLIR_DIALECT_MPI_IR_UTILS_H_
+#define MLIR_DIALECT_MPI_IR_UTILS_H_
+
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/DLTI/DLTI.h"
+#include "mlir/IR/PatternMatch.h"
+
+namespace mlir {
+namespace mpi {
+/// Folds an MPI query op on `mpi.comm_world` (e.g. mpi.comm_rank or
+/// mpi.comm_size) into an i32 `arith.constant`. The value comes from the DLTI
+/// entry `key` (e.g. "MPI:comm_world_rank"), looked up via dlti::query from
+/// `op`.
+///
+/// Assumes `OpT` has a `comm` operand, an optional `retval` result and one
+/// further i32 result (TODO confirm when adding new users).
+///
+/// Returns failure and leaves `op` untouched in three cases:
+/// - the communicator is not produced by `mpi.comm_world`;
+/// - the `retval` result has uses. There is no constant to replace it with,
+///   and erasing `op` would leave those uses dangling;
+/// - no DLTI entry `key` is found.
+/// Emits an error and returns failure if the entry is not an IntegerAttr.
+///
+/// NOTE(review): this header names mpi::CommWorldOp but does not include
+/// "mlir/Dialect/MPI/IR/MPI.h"; it relies on includers doing so first.
+template <typename OpT>
+LogicalResult FoldToDLTIConst(OpT op, const char *key,
+                              mlir::PatternRewriter &b) {
+  auto comm = op.getComm();
+  if (!comm.template getDefiningOp<mlir::mpi::CommWorldOp>())
+    return mlir::failure();
+
+  // A used !mpi.retval cannot be folded away.
+  Value retVal = op.getRetval();
+  if (retVal && !retVal.use_empty())
+    return mlir::failure();
+
+  auto dltiAttr = dlti::query(op, {key}, false);
+  if (failed(dltiAttr))
+    return mlir::failure();
+  auto intAttr = dyn_cast<IntegerAttr>(dltiAttr.value());
+  if (!intAttr)
+    return op->emitError() << "Expected an integer attribute for " << key;
+  Value res = arith::ConstantOp::create(b, op.getLoc(), b.getI32Type(),
+                                        b.getI32IntegerAttr(intAttr.getInt()));
+  if (retVal) {
+    // retVal is unused (checked above); mapping it onto itself only provides
+    // one replacement value per result of `op`.
+    b.replaceOp(op, {retVal, res});
+  } else {
+    b.replaceOp(op, res);
+  }
+  return mlir::success();
+}
+} // namespace mpi
+} // namespace mlir
+
+#endif // MLIR_DIALECT_MPI_IR_UTILS_H_
diff --git a/mlir/lib/Conversion/MPIToLLVM/MPIToLLVM.cpp b/mlir/lib/Conversion/MPIToLLVM/MPIToLLVM.cpp
index 548481250e4f6..0dbc0a126a5c6 100644
--- a/mlir/lib/Conversion/MPIToLLVM/MPIToLLVM.cpp
+++ b/mlir/lib/Conversion/MPIToLLVM/MPIToLLVM.cpp
@@ -20,6 +20,7 @@
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/LLVMIR/LLVMTypes.h"
#include "mlir/Dialect/MPI/IR/MPI.h"
+#include "mlir/Dialect/MPI/IR/Utils.h"
#include "mlir/Transforms/DialectConversion.h"
#include <memory>
@@ -614,6 +615,17 @@ struct CommRankOpLowering : public ConvertOpToLLVMPattern<mpi::CommRankOp> {
// CommSizeOpLowering
//===----------------------------------------------------------------------===//
+/// Returns the number of ranks of a communicator as an i32 value.
+///
+/// `commOrg` is the communicator as it appears in the original (unconverted)
+/// IR; `commAdapt` is its type-converted counterpart from the adaptor.
+///
+/// If `commOrg` is produced by `mpi.comm_world` and an integer DLTI entry
+/// "MPI:comm_world_size" is visible from the current insertion scope, the size
+/// is materialized directly as an `llvm.mlir.constant`. No temporary op is
+/// created and replaced, so the returned value never belongs to a replaced op.
+/// Otherwise an `mpi.comm_size` on `commAdapt` is created, which the
+/// conversion then legalizes (e.g. via CommSizeOpLowering).
+static Value createOrFoldCommSize(ConversionPatternRewriter &rewriter,
+                                  Location loc, Value commOrg,
+                                  Value commAdapt) {
+  Type i32 = rewriter.getI32Type();
+  if (commOrg.getDefiningOp<mpi::CommWorldOp>()) {
+    // Query from the op we insert into, i.e. the same scope from which a
+    // comm_size op created at the insertion point would be looked up.
+    Operation *scope = rewriter.getInsertionBlock()->getParentOp();
+    FailureOr<Attribute> sizeAttr =
+        dlti::query(scope, {"MPI:comm_world_size"}, false);
+    if (succeeded(sizeAttr)) {
+      if (auto intAttr = dyn_cast<IntegerAttr>(*sizeAttr))
+        return LLVM::ConstantOp::create(rewriter, loc, i32, intAttr.getInt())
+            .getResult();
+      emitError(loc) << "Expected an integer attribute for MPI:comm_world_size";
+    }
+  }
+  return mpi::CommSizeOp::create(rewriter, loc, i32, commAdapt).getSize();
+}
+
struct CommSizeOpLowering : public ConvertOpToLLVMPattern<mpi::CommSizeOp> {
using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;
@@ -818,8 +830,7 @@ struct AllGatherOpLowering : public ConvertOpToLLVMPattern<mpi::AllGatherOp> {
// count_recv is the number of elements received from each rank, not total
Value nRanks =
- mpi::CommSizeOp::create(rewriter, loc, i32, adaptor.getComm())
- .getSize();
+ createOrFoldCommSize(rewriter, loc, op.getComm(), adaptor.getComm());
Value recvCountPerRank =
LLVM::UDivOp::create(rewriter, loc, i32, recvSize, nRanks);
diff --git a/mlir/lib/Dialect/MPI/IR/MPIOps.cpp b/mlir/lib/Dialect/MPI/IR/MPIOps.cpp
index e47a1f62d76d0..6cca853071dc2 100644
--- a/mlir/lib/Dialect/MPI/IR/MPIOps.cpp
+++ b/mlir/lib/Dialect/MPI/IR/MPIOps.cpp
@@ -6,12 +6,11 @@
//
//===----------------------------------------------------------------------===//
-#include "mlir/Dialect/DLTI/DLTI.h"
#include "mlir/Dialect/MPI/IR/MPI.h"
+#include "mlir/Dialect/MPI/IR/Utils.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinAttributes.h"
-#include "mlir/IR/PatternMatch.h"
using namespace mlir;
using namespace mlir::mpi;
@@ -43,30 +42,6 @@ struct FoldCast final : public mlir::OpRewritePattern<OpT> {
}
};
-template <typename OpT>
-static LogicalResult FoldToDLTIConst(OpT op, const char *key,
- mlir::PatternRewriter &b) {
- auto comm = op.getComm();
- if (!comm.template getDefiningOp<mlir::mpi::CommWorldOp>())
- return mlir::failure();
-
- // Try to get DLTI attribute for MPI:comm_world_rank
- // If found, set worldRank to the value of the attribute.
- auto dltiAttr = dlti::query(op, {key}, false);
- if (failed(dltiAttr))
- return mlir::failure();
- if (!isa<IntegerAttr>(dltiAttr.value()))
- return op->emitError() << "Expected an integer attribute for " << key;
- Value res = arith::ConstantOp::create(
- b, op.getLoc(), b.getI32Type(),
- b.getI32IntegerAttr(cast<IntegerAttr>(dltiAttr.value()).getInt()));
- if (Value retVal = op.getRetval())
- b.replaceOp(op, {retVal, res});
- else
- b.replaceOp(op, res);
- return mlir::success();
-}
-
struct FoldRank final : public mlir::OpRewritePattern<mlir::mpi::CommRankOp> {
using mlir::OpRewritePattern<mlir::mpi::CommRankOp>::OpRewritePattern;
LogicalResult matchAndRewrite(mlir::mpi::CommRankOp op,
diff --git a/mlir/test/Conversion/MPIToLLVM/mpitollvm.mlir b/mlir/test/Conversion/MPIToLLVM/mpitollvm.mlir
index c6d4ca1eb2e33..9ec81c53b41f8 100644
--- a/mlir/test/Conversion/MPIToLLVM/mpitollvm.mlir
+++ b/mlir/test/Conversion/MPIToLLVM/mpitollvm.mlir
@@ -95,6 +95,7 @@ module attributes {dlti.map = #dlti.map<"MPI:Implementation" = "MPICH">} {
// CHECK: [[v57:%.*]] = llvm.load [[v55]] : !llvm.ptr -> i32
%split = mpi.comm_split(%comm, %color, %key) : !mpi.comm
+ // CHECK: llvm.call @MPI_Comm_size
// CHECK: llvm.call @MPI_Allgather({{.*}} : (!llvm.ptr, i32, i32, !llvm.ptr, i32, i32, i32) -> i32
%err3 = mpi.allgather(%arg0, %arg0, %comm) : memref<100xf32>, memref<100xf32> -> !mpi.retval
@@ -256,3 +257,18 @@ module attributes { dlti.map = #dlti.map<"MPI:Implementation" = "OpenMPI"> } {
return
}
}
+
+// -----
+
+// The module's mpi.dlti provides "MPI:comm_world_size", so the rank count the
+// allgather lowering needs for comm_world is folded to a constant and no
+// MPI_Comm_size call may be emitted before the MPI_Allgather call.
+module attributes {mpi.dlti = #dlti.map<"MPI:Implementation" = "MPICH", "MPI:comm_world_size" = 4, "MPI:comm_world_rank" = 1> } {
+ // CHECK: llvm.func @mpi_test_fold
+ func.func @mpi_test_fold(%arg0: memref<100xf32>) {
+ // CHECK: [[comm:%.*]] = llvm.mlir.constant(1140850688 : i64) : i64
+ %comm = mpi.comm_world : !mpi.comm
+
+ // CHECK-NOT: llvm.call @MPI_Comm_size
+ // CHECK: llvm.call @MPI_Allgather({{.*}} : (!llvm.ptr, i32, i32, !llvm.ptr, i32, i32, i32) -> i32
+ %err3 = mpi.allgather(%arg0, %arg0, %comm) : memref<100xf32>, memref<100xf32> -> !mpi.retval
+ return
+ }
+}
More information about the Mlir-commits
mailing list